diff --git a/src/lib.rs b/src/lib.rs index b2cb4f6..6b215ed 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -540,10 +540,7 @@ impl Options { // 5. If method is not a case-sensitive match for any of the values in list of methods // do not set any additional headers and terminate this set of steps. - let response = response.allowed_methods( - &method, - self.allowed_methods.clone(), - )?; + let response = response.allowed_methods(&method, &self.allowed_methods)?; // 6. If any of the header field-names is not a ASCII case-insensitive match for any of the // values in list of headers do not set any additional headers and terminate this set of @@ -672,6 +669,7 @@ impl Options { /// - `Access-Control-Allow-Methods` /// - `Access-Control-Allow-Headers` /// - `Vary` +#[derive(Debug)] pub struct Response { responder: R, allow_origin: Option>, @@ -706,8 +704,9 @@ impl<'r, R: Responder<'r>> Response { } /// Consumes the `Response` and return an altered response with origin set to "*" - fn any(self) -> Self { - self.origin("*", false) + fn any(mut self) -> Self { + self.allow_origin = Some(AllOrSome::All); + self } /// Consumes the responder and based on the provided list of allowed origins, @@ -770,8 +769,8 @@ impl<'r, R: Responder<'r>> Response { /// Consumes the CORS, set allow_methods to /// passed methods and returns changed CORS - fn methods(mut self, methods: HashSet) -> Self { - self.allow_methods = methods; + fn methods(mut self, methods: &HashSet) -> Self { + self.allow_methods = methods.clone(); self } @@ -780,7 +779,7 @@ impl<'r, R: Responder<'r>> Response { fn allowed_methods( self, method: &AccessControlRequestMethod, - allowed_methods: HashSet, + allowed_methods: &HashSet, ) -> Result { let &AccessControlRequestMethod(ref request_method) = method; if !allowed_methods.iter().any(|m| m == request_method) { @@ -788,7 +787,7 @@ impl<'r, R: Responder<'r>> Response { } // TODO: Subset to route? Or just the method requested for? - Ok(self.methods(allowed_methods)) + Ok(self.methods(&allowed_methods)) } /// Consumes the CORS, set allow_headers to @@ -871,7 +870,6 @@ impl<'r, R: Responder<'r>> Response { builder.raw_header("Access-Control-Allow-Headers", headers); } - if !self.allow_methods.is_empty() { let methods: Vec<_> = self.allow_methods.iter().map(|m| m.as_str()).collect(); let methods = methods.join(", "); @@ -951,7 +949,6 @@ mod tests { use rocket; use rocket::local::Client; use rocket::http::Method; - use rocket::http::Status; use super::*; @@ -961,6 +958,8 @@ mod tests { Client::new(rocket).expect("valid rocket instance") } + // The following tests check that CORS Request headers are parsed correctly + #[test] fn origin_header_conversion() { let url = "https://foo.bar.xyz"; @@ -1051,26 +1050,443 @@ mod tests { ); } - #[get("/any")] - fn any() -> Response<&'static str> { - Response::new("Hello, world!").any() - } + // The following tests check `Response`'s validation #[test] - fn response_any_origin_smoke_test() { - let rocket = rocket::ignite().mount("/", routes![any]); - let client = not_err!(Client::new(rocket)); + fn response_allows_all_origin_with_wildcard() { + let url = "https://www.example.com"; + let origin = Origin::from_str(url).unwrap(); + let allowed_origins = AllOrSome::All; + let send_wildcard = true; - let req = client.get("/any"); - let mut response = req.dispatch(); + let response = Response::new(()); + let response = not_err!(response.allowed_origin( + &origin, + &allowed_origins, + send_wildcard, + )); - assert_eq!(Status::Ok, response.status()); - let body_str = response.body().and_then(|body| body.into_string()); - let values: Vec<_> = response + assert_matches!(response.allow_origin, Some(AllOrSome::All)); + assert_eq!(response.vary_origin, false); + + // Build response and check built response header + let expected_header = vec!["*"]; + let response = response.build(); + let actual_header: Vec<_> = response .headers() .get("Access-Control-Allow-Origin") .collect(); - assert_eq!(values, vec!["*"]); - assert_eq!(body_str, Some("Hello, world!".to_string())); + assert_eq!(expected_header, actual_header); + } + + #[test] + fn response_allows_all_origin_with_echoed_domain() { + let url = "https://www.example.com"; + let origin = Origin::from_str(url).unwrap(); + let allowed_origins = AllOrSome::All; + let send_wildcard = false; + + let response = Response::new(()); + let response = not_err!(response.allowed_origin( + &origin, + &allowed_origins, + send_wildcard, + )); + + let actual_origin = assert_matches!( + response.allow_origin, + Some(AllOrSome::Some(ref origin)), + origin + ); + assert_eq!(url, actual_origin); + assert_eq!(response.vary_origin, true); + + // Build response and check built response header + let expected_header = vec![url]; + let response = response.build(); + let actual_header: Vec<_> = response + .headers() + .get("Access-Control-Allow-Origin") + .collect(); + assert_eq!(expected_header, actual_header); + } + + #[test] + fn response_allows_origin() { + let url = "https://www.example.com"; + let origin = Origin::from_str(url).unwrap(); + let (allowed_origins, failed_origins) = + AllOrSome::new_from_str_list(&["https://www.example.com"]); + assert!(failed_origins.is_empty()); + let send_wildcard = false; + + let response = Response::new(()); + let response = not_err!(response.allowed_origin( + &origin, + &allowed_origins, + send_wildcard, + )); + + let actual_origin = assert_matches!( + response.allow_origin, + Some(AllOrSome::Some(ref origin)), + origin + ); + + assert_eq!(url, actual_origin); + assert_eq!(response.vary_origin, false); + + // Build response and check built response header + let expected_header = vec![url]; + let response = response.build(); + let actual_header: Vec<_> = response + .headers() + .get("Access-Control-Allow-Origin") + .collect(); + assert_eq!(expected_header, actual_header); + } + + #[test] + #[should_panic(expected = "OriginNotAllowed")] + fn response_rejects_invalid_origin() { + let url = "https://www.acme.com"; + let origin = Origin::from_str(url).unwrap(); + let (allowed_origins, failed_origins) = + AllOrSome::new_from_str_list(&["https://www.example.com"]); + assert!(failed_origins.is_empty()); + let send_wildcard = false; + + let response = Response::new(()); + let _ = response + .allowed_origin(&origin, &allowed_origins, send_wildcard) + .unwrap(); + } + + #[test] + #[should_panic(expected = "CredentialsWithWildcardOrigin")] + fn response_credentials_does_not_allow_wildcard_with_all_origins() { + let response = Response::new(()); + let response = response.any(); + + let _ = response.credentials(true).unwrap(); + } + + #[test] + fn response_credentials_allows_specific_origins() { + let response = Response::new(()); + let response = response.origin("https://www.example.com", false); + + let response = response.credentials(true).expect( + "to allow specific origins", + ); + assert_eq!(response.allow_credentials, true); + + // Build response and check built response header + let expected_header = vec!["true"]; + let response = response.build(); + let actual_header: Vec<_> = response + .headers() + .get("Access-Control-Allow-Credentials") + .collect(); + assert_eq!(expected_header, actual_header); + } + + #[test] + fn response_sets_exposed_headers_correctly() { + let headers = vec!["Bar", "Baz", "Foo"]; + let response = Response::new(()); + let response = response.origin("https://www.example.com", false); + let response = response.exposed_headers(&headers); + + // Build response and check built response header + let response = response.build(); + let actual_header: Vec<_> = response + .headers() + .get("Access-Control-Expose-Headers") + .collect(); + + assert_eq!(1, actual_header.len()); + let mut actual_headers: Vec = actual_header[0] + .split(',') + .map(|header| header.trim().to_string()) + .collect(); + actual_headers.sort(); + assert_eq!(headers, actual_headers); + } + + #[test] + fn response_sets_max_age_correctly() { + let response = Response::new(()); + let response = response.origin("https://www.example.com", false); + + let response = response.max_age(Some(42)); + + // Build response and check built response header + let expected_header = vec!["42"]; + let response = response.build(); + let actual_header: Vec<_> = response.headers().get("Access-Control-Max-Age").collect(); + assert_eq!(expected_header, actual_header); + } + + #[test] + fn response_does_not_set_max_age_when_none() { + let response = Response::new(()); + let response = response.origin("https://www.example.com", false); + + let response = response.max_age(None); + + // Build response and check built response header + let response = response.build(); + assert!(response + .headers() + .get("Access-Control-Max-Age") + .next().is_none()) + } + + /// When all headers are allowed, tests that the requested headers are echoed back + #[test] + fn response_allowed_headers_echoes_back_requested_headers() { + let allowed_headers = AllOrSome::All; + let requested_headers = vec!["Bar", "Foo"]; + + let response = Response::new(()); + let response = response.origin("https://www.example.com", false); + let response = response + .allowed_headers( + &FromStr::from_str(&requested_headers.join(",")).unwrap(), + &allowed_headers, + ) + .expect("to not fail"); + + // Build response and check built response header + let response = response.build(); + let actual_header: Vec<_> = response + .headers() + .get("Access-Control-Allow-Headers") + .collect(); + + assert_eq!(1, actual_header.len()); + let mut actual_headers: Vec = actual_header[0] + .split(',') + .map(|header| header.trim().to_string()) + .collect(); + actual_headers.sort(); + assert_eq!(requested_headers, actual_headers); + } + + #[test] + fn response_allowed_methods_sets_headers_properly() { + let allowed_methods = vec![ + Method::Get, + Method::Head, + Method::Post, + ].into_iter() + .collect(); + + let method = "GET"; + + let response = Response::new(()); + let response = response.origin("https://www.example.com", false); + let response = response + .allowed_methods( + &FromStr::from_str(method).expect("not to fail"), + &allowed_methods, + ) + .expect("not to fail"); + + // Build response and check built response header + let response = response.build(); + let actual_header: Vec<_> = response + .headers() + .get("Access-Control-Allow-Methods") + .collect(); + + assert_eq!(1, actual_header.len()); + let mut actual_headers: Vec = actual_header[0] + .split(',') + .map(|header| header.trim().to_string()) + .collect(); + actual_headers.sort(); + let mut expected_headers: Vec<_> = allowed_methods.iter().map(|m| m.as_str()).collect(); + expected_headers.sort(); + assert_eq!(expected_headers, actual_headers); + } + + #[test] + #[should_panic(expected = "MethodNotAllowed")] + fn response_allowed_method_errors_on_disallowed_method() { + let allowed_methods = vec![ + Method::Get, + Method::Head, + Method::Post, + ].into_iter() + .collect(); + + let method = "DELETE"; + + let response = Response::new(()); + let response = response.origin("https://www.example.com", false); + let _ = response + .allowed_methods( + &FromStr::from_str(method).expect("not to fail"), + &allowed_methods, + ) + .unwrap(); + } + + /// `Response::allowed_headers` should check that headers are allowed, and only + /// echoes back the list that is actually requested for and not the whole list + #[test] + fn response_allowed_headers_validates_and_echoes_requested_headers() { + let allowed_headers = vec!["Bar", "Baz", "Foo"]; + let requested_headers = vec!["Bar", "Foo"]; + + let response = Response::new(()); + let response = response.origin("https://www.example.com", false); + let response = response + .allowed_headers( + &FromStr::from_str(&requested_headers.join(",")).unwrap(), + &AllOrSome::Some( + allowed_headers + .iter() + .map(|s| FromStr::from_str(*s).unwrap()) + .collect(), + ), + ) + .expect("to not fail"); + + // Build response and check built response header + let response = response.build(); + let actual_header: Vec<_> = response + .headers() + .get("Access-Control-Allow-Headers") + .collect(); + + assert_eq!(1, actual_header.len()); + let mut actual_headers: Vec = actual_header[0] + .split(',') + .map(|header| header.trim().to_string()) + .collect(); + actual_headers.sort(); + assert_eq!(requested_headers, actual_headers); + } + + #[test] + #[should_panic(expected = "HeadersNotAllowed")] + fn response_allowed_headers_errors_on_non_subset() { + let allowed_headers = vec!["Bar", "Baz", "Foo"]; + let requested_headers = vec!["Bar", "Foo", "Unknown"]; + + let response = Response::new(()); + let response = response.origin("https://www.example.com", false); + let _ = response + .allowed_headers( + &FromStr::from_str(&requested_headers.join(",")).unwrap(), + &AllOrSome::Some( + allowed_headers + .iter() + .map(|s| FromStr::from_str(*s).unwrap()) + .collect(), + ), + ) + .unwrap(); + + } + + #[test] + fn response_does_not_build_if_origin_is_not_set() { + let response = Response::new(()); + let response = response.build(); + + let headers: Vec<_> = response.headers().iter().collect(); + assert_eq!(headers.len(), 0); + } + + // Note: Correct operation of Response::build is tested in the tests above for each of the + // individual headers + + #[test] + fn response_merges_correctly() { + use std::io::Cursor; + use rocket::http::Status; + + let wrapped = response::Response::build() + .status(Status::ImATeapot) + .raw_header("X-Teapot-Make", "Rocket") + .sized_body(Cursor::new("Brewing the best coffee!")) + .finalize(); + + let response = Response::new(()); + let response = response.origin("https://www.acme.com", false); + + let mut response = Response::::merge(wrapped, response.build()); + assert_eq!(response.status(), Status::ImATeapot); + assert_eq!(response.body_string(), Some("Brewing the best coffee!".to_string())); + + // Check CORS header + let expected_header = vec!["https://www.acme.com"]; + let actual_header: Vec<_> = response + .headers() + .get("Access-Control-Allow-Origin") + .collect(); + assert_eq!(expected_header, actual_header); + + // Check other header + let expected_header = vec!["Rocket"]; + let actual_header: Vec<_> = response.headers().get("X-Teapot-Make").collect(); + assert_eq!(expected_header, actual_header); + } + + #[test] + fn response_does_not_merge_existing_cors() { + let wrapped = response::Response::build() + .raw_header("Access-Control-Allow-Origin", "https://www.example.com") + .finalize(); + + let response = Response::new(()); + let response = response.origin("https://www.acme.com", false); + + let response = Response::<()>::merge(wrapped, response.build()); + let expected_header = vec!["https://www.example.com"]; + let actual_header: Vec<_> = response + .headers() + .get("Access-Control-Allow-Origin") + .collect(); + assert_eq!(expected_header, actual_header); + } + + #[test] + fn response_finalize_smoke_test() { + use std::io::Cursor; + use rocket::http::Status; + + let wrapped = response::Response::build() + .status(Status::ImATeapot) + .raw_header("X-Teapot-Make", "Rocket") + .sized_body(Cursor::new("Brewing the best coffee!")) + .finalize(); + + let response = Response::new(wrapped); + let response = response.origin("https://www.acme.com", false); + + let client = make_client(); + let request = client.get("/"); + let mut response = response.finalize(request.inner()).expect("not to fail"); + + assert_eq!(response.status(), Status::ImATeapot); + assert_eq!(response.body_string(), Some("Brewing the best coffee!".to_string())); + + // Check CORS header + let expected_header = vec!["https://www.acme.com"]; + let actual_header: Vec<_> = response + .headers() + .get("Access-Control-Allow-Origin") + .collect(); + assert_eq!(expected_header, actual_header); + + // Check other header + let expected_header = vec!["Rocket"]; + let actual_header: Vec<_> = response.headers().get("X-Teapot-Make").collect(); + assert_eq!(expected_header, actual_header); } } diff --git a/tests/headers.rs b/tests/headers.rs index d56f9e2..a81e167 100644 --- a/tests/headers.rs +++ b/tests/headers.rs @@ -28,7 +28,7 @@ fn request_headers( format!("{}\n{}\n{}", origin, method, headers.join(", ")) } -/// Tests that all the headers are parsed correcly in a HTTP request +/// Tests that all the request headers are parsed correcly in a HTTP request #[test] fn request_headers_round_trip_smoke_test() { let rocket = rocket::ignite().mount("/", routes![request_headers]);