diff --git a/src/lib.rs b/src/lib.rs index fb29fb4..d98845f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -437,7 +437,7 @@ impl fairing::Fairing for Cors { use rocket::response::Responder; // Build and merge CORS response - match build_cors_response(self, request, response) { + match build_cors_response(self, request) { Err(err) => { // CORS error -- overwrite the original response let error_response = match err.respond_to(request) { @@ -451,7 +451,9 @@ impl fairing::Fairing for Cors { }; response.merge(error_response) } - Ok(()) => { + Ok(cors_response) => { + cors_response.merge(response); + // If this was an OPTIONS request and no route can be found, we should turn this // into a HTTP 204 with no content body. // This allows the user to not have to specify an OPTIONS route for everything. @@ -508,10 +510,14 @@ impl<'a, 'r: 'a, R: response::Responder<'r>> Responder<'a, 'r, R> { fn respond(self, request: &Request) -> response::Result<'r> { let mut response = self.responder.respond_to(request)?; // handle status errors? - match build_cors_response(self.options, request, &mut response) { - Ok(()) => Ok(response), + match build_cors_response(self.options, request) { + Ok(cors_response) => { + cors_response.merge(&mut response); + Ok(response) + }, Err(e) => response::Responder::respond_to(e, request), } + } } @@ -569,48 +575,10 @@ impl Response { self } - /// Consumes the responder and based on the provided list of allowed origins, - /// check if the requested origin is allowed. - /// Useful for pre-flight and during requests - fn allowed_origin( - self, - origin: &Origin, - allowed_origins: &AllOrSome>, - send_wildcard: bool, - ) -> Result { - let origin = origin.origin().unicode_serialization(); - match *allowed_origins { - // Always matching is acceptable since the list of origins can be unbounded. - AllOrSome::All => { - if send_wildcard { - Ok(self.any()) - } else { - Ok(self.origin(&origin, true)) - } - } - AllOrSome::Some(ref allowed_origins) => { - let allowed_origins: HashSet<_> = allowed_origins - .iter() - .map(|o| o.origin().unicode_serialization()) - .collect(); - let _ = allowed_origins.get(&origin).ok_or_else( - || Error::OriginNotAllowed, - )?; - Ok(self.origin(&origin, false)) - } - } - } - - /// Consumes the Response and validate whether credentials can be allowed - fn credentials(mut self, value: bool) -> Result { - if value { - if let Some(AllOrSome::All) = self.allow_origin { - Err(Error::CredentialsWithWildcardOrigin)?; - } - } - + /// Consumes the Response and set credentials + fn credentials(mut self, value: bool) -> Self { self.allow_credentials = value; - Ok(self) + self } /// Consumes the CORS, set expose_headers to @@ -634,22 +602,6 @@ impl Response { self } - /// Consumes the CORS, check if requested method is allowed. - /// Useful for pre-flight checks - fn allowed_methods( - self, - method: &AccessControlRequestMethod, - allowed_methods: &HashSet, - ) -> Result { - let &AccessControlRequestMethod(ref request_method) = method; - if !allowed_methods.iter().any(|m| m == request_method) { - Err(Error::MethodNotAllowed)? - } - - // TODO: Subset to route? Or just the method requested for? - Ok(self.methods(&allowed_methods)) - } - /// Consumes the CORS, set allow_headers to /// passed headers and returns changed CORS fn headers(mut self, headers: &[&str]) -> Self { @@ -657,35 +609,6 @@ impl Response { self } - /// Consumes the CORS, check if requested headers are allowed. - /// Useful for pre-flight checks - fn allowed_headers( - self, - headers: &AccessControlRequestHeaders, - allowed_headers: &AllOrSome>, - ) -> Result { - let &AccessControlRequestHeaders(ref headers) = headers; - - match *allowed_headers { - AllOrSome::All => {} - AllOrSome::Some(ref allowed_headers) => { - if !headers.is_empty() && !headers.is_subset(allowed_headers) { - Err(Error::HeadersNotAllowed)? - } - } - }; - - Ok( - self.headers( - headers - .iter() - .map(|s| &**s.deref()) - .collect::>() - .as_slice(), - ), - ) - } - /// Builds a `rocket::Response` from this struct based off some base `rocket::Response` /// /// This will overwrite any existing CORS headers @@ -788,12 +711,11 @@ pub fn respond<'a, 'r: 'a, R: response::Responder<'r>>( fn build_cors_response( options: &Cors, request: &Request, - response: &mut rocket::Response, -) -> Result<(), Error> { +) -> Result { // Existing CORS response? - if has_allow_origin(response) { - return Ok(()); - } + // if has_allow_origin(response) { + // return Ok(()); + // } // 1. If the Origin header is not present terminate this set of steps. // The request is outside the scope of this specification. @@ -801,7 +723,7 @@ fn build_cors_response( let origin = match origin { None => { // Not a CORS request - return Ok(()); + return Ok(Response::new()); } Some(origin) => origin, }; @@ -816,10 +738,60 @@ fn build_cors_response( _ => actual_request(options, origin), }?; - cors_response.merge(response); + Ok(cors_response) +} + +/// Consumes the responder and based on the provided list of allowed origins, +/// check if the requested origin is allowed. +/// Useful for pre-flight and during requests +fn validate_origin( + origin: &Origin, + allowed_origins: &AllOrSome>, +) -> Result<(), Error> { + match *allowed_origins { + // Always matching is acceptable since the list of origins can be unbounded. + AllOrSome::All => Ok(()), + AllOrSome::Some(ref allowed_origins) => { + allowed_origins + .get(origin) + .and_then(|_| Some(())) + .ok_or_else(|| Error::OriginNotAllowed) + } + } +} + +/// Validate allowed methods +fn validate_allowed_method( + method: &AccessControlRequestMethod, + allowed_methods: &HashSet, +) -> Result<(), Error> { + let &AccessControlRequestMethod(ref request_method) = method; + if !allowed_methods.iter().any(|m| m == request_method) { + Err(Error::MethodNotAllowed)? + } + + // TODO: Subset to route? Or just the method requested for? Ok(()) } +/// Validate allowed headers +fn validate_allowed_headers( + headers: &AccessControlRequestHeaders, + allowed_headers: &AllOrSome>, +) -> Result<(), Error> { + let &AccessControlRequestHeaders(ref headers) = headers; + + match *allowed_headers { + AllOrSome::All => Ok(()), + AllOrSome::Some(ref allowed_headers) => { + if !headers.is_empty() && !headers.is_subset(allowed_headers) { + Err(Error::HeadersNotAllowed)? + } + Ok(()) + } + } +} + /// Gets the `Origin` request header from the request fn origin(request: &Request) -> Result, Error> { match Origin::from_request(request) { @@ -863,18 +835,24 @@ fn preflight( method: Option, headers: Option, ) -> Result { - + options.validate()?; let response = Response::new(); // Note: All header parse failures are dealt with in the `FromRequest` trait implementation // 2. If the value of the Origin header is not a case-sensitive match for any of the values // in list of origins do not set any additional headers and terminate this set of steps. - let response = response.allowed_origin( - &origin, - &options.allowed_origins, - options.send_wildcard, - )?; + validate_origin(&origin, &options.allowed_origins)?; + let response = match options.allowed_origins { + AllOrSome::All => { + if options.send_wildcard { + response.any() + } else { + response.origin(origin.as_str(), true) + } + } + AllOrSome::Some(_) => response.origin(origin.as_str(), false), + }; // 3. Let `method` be the value as result of parsing the Access-Control-Request-Method // header. @@ -894,13 +872,23 @@ fn preflight( // 5. If method is not a case-sensitive match for any of the values in list of methods // do not set any additional headers and terminate this set of steps. - let response = response.allowed_methods(&method, &options.allowed_methods)?; + validate_allowed_method(&method, &options.allowed_methods)?; + let response = response.methods(&options.allowed_methods); // 6. If any of the header field-names is not a ASCII case-insensitive match for any of the // values in list of headers do not set any additional headers and terminate this set of // steps. - let response = if let Some(headers) = headers { - response.allowed_headers(&headers, &options.allowed_headers)? + + let response = if let Some(ref headers) = headers { + validate_allowed_headers(headers, &options.allowed_headers)?; + let &AccessControlRequestHeaders(ref headers) = headers; + response.headers( + headers + .iter() + .map(|s| &**s.deref()) + .collect::>() + .as_slice(), + ) } else { response }; @@ -913,7 +901,8 @@ fn preflight( // with either the value of the Origin header or the string "*" as value. // Note: The string "*" cannot be used for a resource that supports credentials. - let response = response.credentials(options.allow_credentials)?; + // Validation has been done in options.validate + let response = response.credentials(options.allow_credentials); // 8. Optionally add a single Access-Control-Max-Age header // with as value the amount of seconds the user agent is allowed to cache the result of the @@ -949,6 +938,7 @@ fn preflight( /// If the `Origin` is not provided, then this request was not made by a browser and there is no /// CORS enforcement. fn actual_request(options: &Cors, origin: Origin) -> Result { + options.validate()?; let response = Response::new(); // Note: All header parse failures are dealt with in the `FromRequest` trait implementation @@ -957,11 +947,17 @@ fn actual_request(options: &Cors, origin: Origin) -> Result { // in list of origins, do not set any additional headers and terminate this set of steps. // Always matching is acceptable since the list of origins can be unbounded. - let response = response.allowed_origin( - &origin, - &options.allowed_origins, - options.send_wildcard, - )?; + validate_origin(&origin, &options.allowed_origins)?; + let response = match options.allowed_origins { + AllOrSome::All => { + if options.send_wildcard { + response.any() + } else { + response.origin(origin.as_str(), true) + } + } + AllOrSome::Some(_) => response.origin(origin.as_str(), false), + }; // 3. If the resource supports credentials add a single Access-Control-Allow-Origin header, // with the value of the Origin header as value, and add a @@ -971,7 +967,8 @@ fn actual_request(options: &Cors, origin: Origin) -> Result { // with either the value of the Origin header or the string "*" as value. // Note: The string "*" cannot be used for a resource that supports credentials. - let response = response.credentials(options.allow_credentials)?; + // Validation has been done in options.validate + let response = response.credentials(options.allow_credentials); // 4. If the list of exposed headers is not empty add one or more // Access-Control-Expose-Headers headers, with as values the header field names given in @@ -1033,65 +1030,15 @@ mod tests { cors.validate().unwrap(); } - // The following tests check `Response`'s validation + // The following tests check validation #[test] - fn response_allows_all_origin_with_wildcard() { + fn validate_origin_allows_all_origins() { let url = "https://www.example.com"; let origin = Origin::from_str(url).unwrap(); let allowed_origins = AllOrSome::All; - let send_wildcard = true; - let response = Response::new(); - let response = not_err!(response.allowed_origin( - &origin, - &allowed_origins, - send_wildcard, - )); - - assert_matches!(response.allow_origin, Some(AllOrSome::All)); - assert_eq!(response.vary_origin, false); - - // Build response and check built response header - let expected_header = vec!["*"]; - let response = response.build(response::Response::new()); - let actual_header: Vec<_> = response - .headers() - .get("Access-Control-Allow-Origin") - .collect(); - assert_eq!(expected_header, actual_header); - } - - #[test] - fn response_allows_all_origin_with_echoed_domain() { - let url = "https://www.example.com"; - let origin = Origin::from_str(url).unwrap(); - let allowed_origins = AllOrSome::All; - let send_wildcard = false; - - let response = Response::new(); - let response = not_err!(response.allowed_origin( - &origin, - &allowed_origins, - send_wildcard, - )); - - let actual_origin = assert_matches!( - response.allow_origin, - Some(AllOrSome::Some(ref origin)), - origin - ); - assert_eq!(url, actual_origin); - assert_eq!(response.vary_origin, true); - - // Build response and check built response header - let expected_header = vec![url]; - let response = response.build(response::Response::new()); - let actual_header: Vec<_> = response - .headers() - .get("Access-Control-Allow-Origin") - .collect(); - assert_eq!(expected_header, actual_header); + not_err!(validate_origin(&origin, &allowed_origins)); } #[test] @@ -1101,32 +1048,8 @@ mod tests { let (allowed_origins, failed_origins) = AllOrSome::new_from_str_list(&["https://www.example.com"]); assert!(failed_origins.is_empty()); - let send_wildcard = false; - let response = Response::new(); - let response = not_err!(response.allowed_origin( - &origin, - &allowed_origins, - send_wildcard, - )); - - let actual_origin = assert_matches!( - response.allow_origin, - Some(AllOrSome::Some(ref origin)), - origin - ); - - assert_eq!(url, actual_origin); - assert_eq!(response.vary_origin, false); - - // Build response and check built response header - let expected_header = vec![url]; - let response = response.build(response::Response::new()); - let actual_header: Vec<_> = response - .headers() - .get("Access-Control-Allow-Origin") - .collect(); - assert_eq!(expected_header, actual_header); + not_err!(validate_origin(&origin, &allowed_origins)); } #[test] @@ -1137,41 +1060,8 @@ mod tests { let (allowed_origins, failed_origins) = AllOrSome::new_from_str_list(&["https://www.example.com"]); assert!(failed_origins.is_empty()); - let send_wildcard = false; - let response = Response::new(); - let _ = response - .allowed_origin(&origin, &allowed_origins, send_wildcard) - .unwrap(); - } - - #[test] - #[should_panic(expected = "CredentialsWithWildcardOrigin")] - fn response_credentials_does_not_allow_wildcard_with_all_origins() { - let response = Response::new(); - let response = response.any(); - - let _ = response.credentials(true).unwrap(); - } - - #[test] - fn response_credentials_allows_specific_origins() { - let response = Response::new(); - let response = response.origin("https://www.example.com", false); - - let response = response.credentials(true).expect( - "to allow specific origins", - ); - assert_eq!(response.allow_credentials, true); - - // Build response and check built response header - let expected_header = vec!["true"]; - let response = response.build(response::Response::new()); - let actual_header: Vec<_> = response - .headers() - .get("Access-Control-Allow-Credentials") - .collect(); - assert_eq!(expected_header, actual_header); + validate_origin(&origin, &allowed_origins).unwrap(); } #[test] @@ -1220,159 +1110,88 @@ mod tests { // Build response and check built response header let response = response.build(response::Response::new()); - assert!(response - .headers() - .get("Access-Control-Max-Age") - .next().is_none()) - } - - /// When all headers are allowed, tests that the requested headers are echoed back - #[test] - fn response_allowed_headers_echoes_back_requested_headers() { - let allowed_headers = AllOrSome::All; - let requested_headers = vec!["Bar", "Foo"]; - - let response = Response::new(); - let response = response.origin("https://www.example.com", false); - let response = response - .allowed_headers( - &FromStr::from_str(&requested_headers.join(",")).unwrap(), - &allowed_headers, - ) - .expect("to not fail"); - - // Build response and check built response header - let response = response.build(response::Response::new()); - let actual_header: Vec<_> = response - .headers() - .get("Access-Control-Allow-Headers") - .collect(); - - assert_eq!(1, actual_header.len()); - let mut actual_headers: Vec = actual_header[0] - .split(',') - .map(|header| header.trim().to_string()) - .collect(); - actual_headers.sort(); - assert_eq!(requested_headers, actual_headers); + assert!( + response + .headers() + .get("Access-Control-Max-Age") + .next() + .is_none() + ) } #[test] - fn response_allowed_methods_sets_headers_properly() { - let allowed_methods = vec![ - Method::Get, - Method::Head, - Method::Post, - ].into_iter() + fn allowed_methods_validated_correctly() { + let allowed_methods = vec![Method::Get, Method::Head, Method::Post] + .into_iter() .collect(); let method = "GET"; - let response = Response::new(); - let response = response.origin("https://www.example.com", false); - let response = response - .allowed_methods( - &FromStr::from_str(method).expect("not to fail"), - &allowed_methods, - ) - .expect("not to fail"); - - // Build response and check built response header - let response = response.build(response::Response::new()); - let actual_header: Vec<_> = response - .headers() - .get("Access-Control-Allow-Methods") - .collect(); - - assert_eq!(1, actual_header.len()); - let mut actual_headers: Vec = actual_header[0] - .split(',') - .map(|header| header.trim().to_string()) - .collect(); - actual_headers.sort(); - let mut expected_headers: Vec<_> = allowed_methods.iter().map(|m| m.as_str()).collect(); - expected_headers.sort(); - assert_eq!(expected_headers, actual_headers); + not_err!(validate_allowed_method( + &FromStr::from_str(method).expect("not to fail"), + &allowed_methods, + )); } #[test] #[should_panic(expected = "MethodNotAllowed")] - fn response_allowed_method_errors_on_disallowed_method() { - let allowed_methods = vec![ - Method::Get, - Method::Head, - Method::Post, - ].into_iter() + fn allowed_methods_errors_on_disallowed_method() { + let allowed_methods = vec![Method::Get, Method::Head, Method::Post] + .into_iter() .collect(); let method = "DELETE"; - let response = Response::new(); - let response = response.origin("https://www.example.com", false); - let _ = response - .allowed_methods( - &FromStr::from_str(method).expect("not to fail"), - &allowed_methods, - ) - .unwrap(); + validate_allowed_method( + &FromStr::from_str(method).expect("not to fail"), + &allowed_methods, + ).unwrap() + } + + #[test] + fn all_allowed_headers_are_validated_correctly() { + let allowed_headers = AllOrSome::All; + let requested_headers = vec!["Bar", "Foo"]; + + not_err!(validate_allowed_headers( + &FromStr::from_str(&requested_headers.join(",")).unwrap(), + &allowed_headers, + )); } /// `Response::allowed_headers` should check that headers are allowed, and only /// echoes back the list that is actually requested for and not the whole list #[test] - fn response_allowed_headers_validates_and_echoes_requested_headers() { + fn allowed_headers_are_validated_correctly() { let allowed_headers = vec!["Bar", "Baz", "Foo"]; let requested_headers = vec!["Bar", "Foo"]; - let response = Response::new(); - let response = response.origin("https://www.example.com", false); - let response = response - .allowed_headers( - &FromStr::from_str(&requested_headers.join(",")).unwrap(), - &AllOrSome::Some( - allowed_headers - .iter() - .map(|s| FromStr::from_str(*s).unwrap()) - .collect(), - ), - ) - .expect("to not fail"); - - // Build response and check built response header - let response = response.build(response::Response::new()); - let actual_header: Vec<_> = response - .headers() - .get("Access-Control-Allow-Headers") - .collect(); - - assert_eq!(1, actual_header.len()); - let mut actual_headers: Vec = actual_header[0] - .split(',') - .map(|header| header.trim().to_string()) - .collect(); - actual_headers.sort(); - assert_eq!(requested_headers, actual_headers); + not_err!(validate_allowed_headers( + &FromStr::from_str(&requested_headers.join(",")).unwrap(), + &AllOrSome::Some( + allowed_headers + .iter() + .map(|s| FromStr::from_str(*s).unwrap()) + .collect(), + ), + )); } #[test] #[should_panic(expected = "HeadersNotAllowed")] - fn response_allowed_headers_errors_on_non_subset() { + fn allowed_headers_errors_on_non_subset() { let allowed_headers = vec!["Bar", "Baz", "Foo"]; let requested_headers = vec!["Bar", "Foo", "Unknown"]; - let response = Response::new(); - let response = response.origin("https://www.example.com", false); - let _ = response - .allowed_headers( - &FromStr::from_str(&requested_headers.join(",")).unwrap(), - &AllOrSome::Some( - allowed_headers - .iter() - .map(|s| FromStr::from_str(*s).unwrap()) - .collect(), - ), - ) - .unwrap(); + validate_allowed_headers( + &FromStr::from_str(&requested_headers.join(",")).unwrap(), + &AllOrSome::Some( + allowed_headers + .iter() + .map(|s| FromStr::from_str(*s).unwrap()) + .collect(), + ), + ).unwrap(); } @@ -1413,42 +1232,19 @@ mod tests { assert_eq!(expected_header, actual_header); // Check that `Access-Control-Max-Age` is removed - assert!(response.headers().get("Access-Control-Max-Age").next().is_none()); + assert!( + response + .headers() + .get("Access-Control-Max-Age") + .next() + .is_none() + ); } - // The following tests check that preflight checks are done properly + // TODO: Preflight tests + // TODO: Actual requests tests - // fn make_cors_options() -> Cors { - // let (allowed_origins, failed_origins) = - // AllOrSome::new_from_str_list(&["https://www.acme.com"]); - // assert!(failed_origins.is_empty()); - - // Cors { - // allowed_origins: allowed_origins, - // allowed_methods: [Method::Get].iter().cloned().collect(), - // allowed_headers: AllOrSome::Some( - // ["Authorization"] - // .into_iter() - // .map(|s| s.to_string().into()) - // .collect(), - // ), - // allow_credentials: true, - // ..Default::default() - // } - // } - - // /// Tests that non CORS preflight are let through without modification - // #[test] - // fn preflight_missing_origins_are_let_through() { - // let options = make_cors_options(); - // let client = make_client(); - // let request = client.get("/"); - - // let response = options.preflight((), None, None, None).expect("not to fail"); - - // let headers: Vec<_> = response.headers().iter().collect(); - // assert_eq!(headers.len(), 0); - // } + // Origin all (wildcard + echoed with Vary). Origin Echo } diff --git a/tests/headers.rs b/tests/headers.rs index ae5212a..544d823 100644 --- a/tests/headers.rs +++ b/tests/headers.rs @@ -9,7 +9,7 @@ use std::ops::Deref; use std::str::FromStr; use rocket::local::Client; -use rocket::http::{Header, Status}; +use rocket::http::Header; use rocket_cors::headers::*; #[get("/request_headers")]