From 29952e182d7c794a2b6f2f1fd488618280c70289 Mon Sep 17 00:00:00 2001 From: Yong Wen Chua Date: Fri, 14 Jul 2017 13:29:54 +0800 Subject: [PATCH 1/3] Refactor response building to preserve existing CORS headers --- src/lib.rs | 92 +++++++++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 77 insertions(+), 15 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 5c37c17..515430f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -654,6 +654,21 @@ impl Options { /// A CORS Response which wraps another struct which implements `Responder`. You will typically /// use [`Options`] instead to verify and build the response instead of this directly. /// See module level documentation for usage examples. +/// +/// If the wrapped `Responder` already has the `Access-Control-Allow-Origin` header set, +/// this responder will leave the response untouched. +/// This allows for chaining of several CORS responders. +/// +/// Otherwise, the following headers may be set for the final Rocket `Response`, overwriting any +/// existing headers defined: +/// +/// - `Access-Control-Allow-Origin` +/// - `Access-Control-Expose-Headers` +/// - `Access-Control-Max-Age` +/// - `Access-Control-Allow-Credentials` +/// - `Access-Control-Allow-Methods` +/// - `Access-Control-Allow-Headers` +/// - `Vary` pub struct Response { responder: R, allow_origin: Option>, @@ -808,26 +823,23 @@ impl<'r, R: Responder<'r>> Response { ), ) } -} -impl<'r, R: Responder<'r>> Responder<'r> for Response { + /// Builds a `rocket::Response` from this struct containing only the CORS headers. #[allow(unused_results)] - fn respond_to(self, request: &Request) -> response::Result<'r> { - use std::borrow::Cow; - - let mut builder = response::Response::build_from(self.responder.respond_to(request)?); + fn build(&self) -> response::Response<'r> { + let mut builder = response::Response::build(); let origin = match self.allow_origin { None => { // This is not a CORS response - return Ok(builder.finalize()); + return builder.finalize(); } - Some(origin) => origin, + Some(ref origin) => origin, }; - let origin: Cow = match origin { - AllOrSome::All => Into::into("*"), - AllOrSome::Some(origin) => Into::into(origin), + let origin = match *origin { + AllOrSome::All => "*".to_string(), + AllOrSome::Some(ref origin) => origin.to_string(), }; builder.raw_header("Access-Control-Allow-Origin", origin); @@ -838,7 +850,7 @@ impl<'r, R: Responder<'r>> Responder<'r> for Response { if !self.expose_headers.is_empty() { let headers: Vec = self.expose_headers - .into_iter() + .iter() .map(|s| s.deref().to_string()) .collect(); let headers = headers.join(", "); @@ -848,7 +860,7 @@ impl<'r, R: Responder<'r>> Responder<'r> for Response { if !self.allow_headers.is_empty() { let headers: Vec = self.allow_headers - .into_iter() + .iter() .map(|s| s.deref().to_string()) .collect(); let headers = headers.join(", "); @@ -858,7 +870,7 @@ impl<'r, R: Responder<'r>> Responder<'r> for Response { if !self.allow_methods.is_empty() { - let methods: Vec<_> = self.allow_methods.into_iter().map(|m| m.as_str()).collect(); + let methods: Vec<_> = self.allow_methods.iter().map(|m| m.as_str()).collect(); let methods = methods.join(", "); builder.raw_header("Access-Control-Allow-Methods", methods); @@ -873,7 +885,57 @@ impl<'r, R: Responder<'r>> Responder<'r> for Response { builder.raw_header("Vary", "Origin"); } - Ok(builder.finalize()) + builder.finalize() + } + + /// Merge a `wrapped` Response with a `cors` response + /// + /// If the `wrapped` response has the `Access-Control-Allow-Origin` header already defined, + /// it will be left untouched. This allows for chaining of several CORS responders. + /// + /// Otherwise, the merging will be done according to the rules of `rocket::Response::merge`. + fn merge( + mut wrapped: response::Response<'r>, + cors: response::Response<'r>, + ) -> response::Response<'r> { + + let existing_cors = { + wrapped.headers().get("Access-Control-Allow-Origin").next() == None + }; + + if existing_cors { + wrapped.merge(cors); + } + + wrapped + } + + /// Finalize the Response by merging the CORS header with the wrapped `Responder + /// + /// If the original response has the `Access-Control-Allow-Origin` header already defined, + /// it will be left untouched.This allows for chaining of several CORS responders. + /// + /// Otherwise, the following headers may be set for the final Rocket `Response`, overwriting any + /// existing headers defined: + /// + /// - `Access-Control-Allow-Origin` + /// - `Access-Control-Expose-Headers` + /// - `Access-Control-Max-Age` + /// - `Access-Control-Allow-Credentials` + /// - `Access-Control-Allow-Methods` + /// - `Access-Control-Allow-Headers` + /// - `Vary` + fn finalize(self, request: &Request) -> response::Result<'r> { + let cors_response = self.build(); + let original_response = self.responder.respond_to(request)?; + + Ok(Self::merge(original_response, cors_response)) + } +} + +impl<'r, R: Responder<'r>> Responder<'r> for Response { + fn respond_to(self, request: &Request) -> response::Result<'r> { + self.finalize(request) } } From ca096ceb281ce3f4282b7fa38e0c00c843e6cc3a Mon Sep 17 00:00:00 2001 From: Yong Wen Chua Date: Fri, 14 Jul 2017 13:54:34 +0800 Subject: [PATCH 2/3] Extract headers integration tests --- src/lib.rs | 106 +++++++++++++++++++++++++-------------------- src/test_macros.rs | 7 --- tests/headers.rs | 63 +++++++++++++++++++++++++++ tests/routes.rs | 2 +- 4 files changed, 124 insertions(+), 54 deletions(-) create mode 100644 tests/headers.rs diff --git a/src/lib.rs b/src/lib.rs index 515430f..b2cb4f6 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -261,6 +261,9 @@ impl<'a, 'r> FromRequest<'a, 'r> for Url { /// The `Origin` request header used in CORS +/// +/// You can use this as a rocket [Request Guard](https://rocket.rs/guide/requests/#request-guards) +/// to ensure that Origins are passed in correctly. pub type Origin = Url; /// The `Access-Control-Request-Method` request header @@ -948,22 +951,45 @@ mod tests { use rocket; use rocket::local::Client; use rocket::http::Method; - use rocket::http::{Header, Status}; + use rocket::http::Status; use super::*; + /// Make a client with no routes for unit testing + fn make_client() -> Client { + let rocket = rocket::ignite(); + Client::new(rocket).expect("valid rocket instance") + } + #[test] fn origin_header_conversion() { let url = "https://foo.bar.xyz"; - let _ = not_err!(Origin::from_str(url)); + let parsed = not_err!(Origin::from_str(url)); + let expected = not_err!(Url::from_str(url)); + assert_eq!(parsed, expected); let url = "https://foo.bar.xyz/path/somewhere"; // this should never really be used - let _ = not_err!(Origin::from_str(url)); + let parsed = not_err!(Origin::from_str(url)); + let expected = not_err!(Url::from_str(url)); + assert_eq!(parsed, expected); let url = "invalid_url"; let _ = is_err!(Origin::from_str(url)); } + #[test] + fn origin_header_parsing() { + let client = make_client(); + let mut request = client.get("/"); + + let origin = hyper::header::Origin::new("https", "www.example.com", None); + request.add_header(origin); + + let outcome: request::Outcome = FromRequest::from_request(request.inner()); + let parsed_header = assert_matches!(outcome, Outcome::Success(s), s); + assert_eq!("https://www.example.com/", parsed_header.as_str()); + } + #[test] fn request_method_conversion() { let method = "POST"; @@ -978,6 +1004,20 @@ mod tests { let _ = is_err!(AccessControlRequestMethod::from_str(method)); } + #[test] + fn request_method_parsing() { + let client = make_client(); + let mut request = client.get("/"); + let method = hyper::header::AccessControlRequestMethod(hyper::method::Method::Get); + request.add_header(method); + let outcome: request::Outcome = + FromRequest::from_request(request.inner()); + + let parsed_header = assert_matches!(outcome, Outcome::Success(s), s); + let AccessControlRequestMethod(parsed_method) = parsed_header; + assert_eq!("GET", parsed_method.as_str()); + } + #[test] fn request_headers_conversion() { let headers = ["foo", "bar", "baz"]; @@ -988,53 +1028,27 @@ mod tests { assert_eq!(actual_headers, expected_headers); } - #[get("/request_headers")] - #[allow(needless_pass_by_value)] - fn request_headers( - origin: Origin, - method: AccessControlRequestMethod, - headers: AccessControlRequestHeaders, - ) -> String { - let AccessControlRequestMethod(method) = method; - let AccessControlRequestHeaders(headers) = headers; - let mut headers = headers - .iter() - .map(|s| s.deref().to_string()) - .collect::>(); - headers.sort(); - format!("{}\n{}\n{}", origin, method, headers.join(", ")) - } - - /// Tests that all the headers are parsed correcly in a HTTP request #[test] - fn request_headers_round_trip_smoke_test() { - let rocket = rocket::ignite().mount("/", routes![request_headers]); - let client = not_err!(Client::new(rocket)); - - let origin_header = Header::from(not_err!( - hyper::header::Origin::from_str("https://foo.bar.xyz") - )); - let method_header = Header::from(hyper::header::AccessControlRequestMethod( - hyper::method::Method::Get, - )); - let request_headers = hyper::header::AccessControlRequestHeaders(vec![ + fn request_headers_parsing() { + let client = make_client(); + let mut request = client.get("/"); + let headers = hyper::header::AccessControlRequestHeaders(vec![ FromStr::from_str("accept-language").unwrap(), - FromStr::from_str("X-Ping").unwrap(), + FromStr::from_str("date").unwrap(), ]); - let request_headers = Header::from(request_headers); - let req = client - .get("/request_headers") - .header(origin_header) - .header(method_header) - .header(request_headers); - let mut response = req.dispatch(); + request.add_header(headers); + let outcome: request::Outcome = + FromRequest::from_request(request.inner()); - assert_eq!(Status::Ok, response.status()); - let body_str = not_none!(response.body().and_then(|body| body.into_string())); - let expected_body = r#"https://foo.bar.xyz/ -GET -X-Ping, accept-language"#; - assert_eq!(expected_body, body_str); + let parsed_header = assert_matches!(outcome, Outcome::Success(s), s); + let AccessControlRequestHeaders(parsed_headers) = parsed_header; + let mut parsed_headers: Vec = + parsed_headers.iter().map(|s| s.to_string()).collect(); + parsed_headers.sort(); + assert_eq!( + vec!["accept-language".to_string(), "date".to_string()], + parsed_headers + ); } #[get("/any")] diff --git a/src/test_macros.rs b/src/test_macros.rs index ff19c9f..43c3b87 100644 --- a/src/test_macros.rs +++ b/src/test_macros.rs @@ -12,13 +12,6 @@ macro_rules! is_err { }) } -macro_rules! not_none { - ($e:expr) => (match $e { - Some(e) => e, - None => panic!("{} failed with None", stringify!($e)), - }) -} - macro_rules! assert_matches { ($e: expr, $p: pat) => (assert_matches!($e, $p, ())); ($e: expr, $p: pat, $f: expr) => (match $e { diff --git a/tests/headers.rs b/tests/headers.rs new file mode 100644 index 0000000..d56f9e2 --- /dev/null +++ b/tests/headers.rs @@ -0,0 +1,63 @@ +//! This crate tests that all the request headers are parsed correctly in the round trip +#![feature(plugin, custom_derive)] +#![plugin(rocket_codegen)] +extern crate hyper; +extern crate rocket; +extern crate rocket_cors; + +use std::ops::Deref; +use std::str::FromStr; + +use rocket::local::Client; +use rocket::http::{Header, Status}; +use rocket_cors::*; + +#[get("/request_headers")] +fn request_headers( + origin: Origin, + method: AccessControlRequestMethod, + headers: AccessControlRequestHeaders, +) -> String { + let AccessControlRequestMethod(method) = method; + let AccessControlRequestHeaders(headers) = headers; + let mut headers = headers + .iter() + .map(|s| s.deref().to_string()) + .collect::>(); + headers.sort(); + format!("{}\n{}\n{}", origin, method, headers.join(", ")) +} + +/// Tests that all the headers are parsed correcly in a HTTP request +#[test] +fn request_headers_round_trip_smoke_test() { + let rocket = rocket::ignite().mount("/", routes![request_headers]); + let client = Client::new(rocket).expect("A valid Rocket client"); + + let origin_header = Header::from( + hyper::header::Origin::from_str("https://foo.bar.xyz").unwrap(), + ); + let method_header = Header::from(hyper::header::AccessControlRequestMethod( + hyper::method::Method::Get, + )); + let request_headers = hyper::header::AccessControlRequestHeaders(vec![ + FromStr::from_str("accept-language").unwrap(), + FromStr::from_str("X-Ping").unwrap(), + ]); + let request_headers = Header::from(request_headers); + let req = client + .get("/request_headers") + .header(origin_header) + .header(method_header) + .header(request_headers); + let mut response = req.dispatch(); + + assert_eq!(Status::Ok, response.status()); + let body_str = response.body().and_then(|body| body.into_string()).expect( + "Non-empty body", + ); + let expected_body = r#"https://foo.bar.xyz/ +GET +X-Ping, accept-language"#; + assert_eq!(expected_body, body_str); +} diff --git a/tests/routes.rs b/tests/routes.rs index 6b2b346..1617572 100644 --- a/tests/routes.rs +++ b/tests/routes.rs @@ -1,4 +1,4 @@ -//! This crate tests using rocket_cors using the "classic"" per-route handling +//! This crate tests using rocket_cors using the "classic" per-route handling #![feature(plugin, custom_derive)] #![plugin(rocket_codegen)] From f1391281cd2463700ef1688b6e9f637b2f0c7f5f Mon Sep 17 00:00:00 2001 From: Yong Wen Chua Date: Fri, 14 Jul 2017 15:35:44 +0800 Subject: [PATCH 3/3] Response unit tests --- src/lib.rs | 468 ++++++++++++++++++++++++++++++++++++++++++++--- tests/headers.rs | 2 +- 2 files changed, 443 insertions(+), 27 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index b2cb4f6..6b215ed 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -540,10 +540,7 @@ impl Options { // 5. If method is not a case-sensitive match for any of the values in list of methods // do not set any additional headers and terminate this set of steps. - let response = response.allowed_methods( - &method, - self.allowed_methods.clone(), - )?; + let response = response.allowed_methods(&method, &self.allowed_methods)?; // 6. If any of the header field-names is not a ASCII case-insensitive match for any of the // values in list of headers do not set any additional headers and terminate this set of @@ -672,6 +669,7 @@ impl Options { /// - `Access-Control-Allow-Methods` /// - `Access-Control-Allow-Headers` /// - `Vary` +#[derive(Debug)] pub struct Response { responder: R, allow_origin: Option>, @@ -706,8 +704,9 @@ impl<'r, R: Responder<'r>> Response { } /// Consumes the `Response` and return an altered response with origin set to "*" - fn any(self) -> Self { - self.origin("*", false) + fn any(mut self) -> Self { + self.allow_origin = Some(AllOrSome::All); + self } /// Consumes the responder and based on the provided list of allowed origins, @@ -770,8 +769,8 @@ impl<'r, R: Responder<'r>> Response { /// Consumes the CORS, set allow_methods to /// passed methods and returns changed CORS - fn methods(mut self, methods: HashSet) -> Self { - self.allow_methods = methods; + fn methods(mut self, methods: &HashSet) -> Self { + self.allow_methods = methods.clone(); self } @@ -780,7 +779,7 @@ impl<'r, R: Responder<'r>> Response { fn allowed_methods( self, method: &AccessControlRequestMethod, - allowed_methods: HashSet, + allowed_methods: &HashSet, ) -> Result { let &AccessControlRequestMethod(ref request_method) = method; if !allowed_methods.iter().any(|m| m == request_method) { @@ -788,7 +787,7 @@ impl<'r, R: Responder<'r>> Response { } // TODO: Subset to route? Or just the method requested for? - Ok(self.methods(allowed_methods)) + Ok(self.methods(&allowed_methods)) } /// Consumes the CORS, set allow_headers to @@ -871,7 +870,6 @@ impl<'r, R: Responder<'r>> Response { builder.raw_header("Access-Control-Allow-Headers", headers); } - if !self.allow_methods.is_empty() { let methods: Vec<_> = self.allow_methods.iter().map(|m| m.as_str()).collect(); let methods = methods.join(", "); @@ -951,7 +949,6 @@ mod tests { use rocket; use rocket::local::Client; use rocket::http::Method; - use rocket::http::Status; use super::*; @@ -961,6 +958,8 @@ mod tests { Client::new(rocket).expect("valid rocket instance") } + // The following tests check that CORS Request headers are parsed correctly + #[test] fn origin_header_conversion() { let url = "https://foo.bar.xyz"; @@ -1051,26 +1050,443 @@ mod tests { ); } - #[get("/any")] - fn any() -> Response<&'static str> { - Response::new("Hello, world!").any() - } + // The following tests check `Response`'s validation #[test] - fn response_any_origin_smoke_test() { - let rocket = rocket::ignite().mount("/", routes![any]); - let client = not_err!(Client::new(rocket)); + fn response_allows_all_origin_with_wildcard() { + let url = "https://www.example.com"; + let origin = Origin::from_str(url).unwrap(); + let allowed_origins = AllOrSome::All; + let send_wildcard = true; - let req = client.get("/any"); - let mut response = req.dispatch(); + let response = Response::new(()); + let response = not_err!(response.allowed_origin( + &origin, + &allowed_origins, + send_wildcard, + )); - assert_eq!(Status::Ok, response.status()); - let body_str = response.body().and_then(|body| body.into_string()); - let values: Vec<_> = response + assert_matches!(response.allow_origin, Some(AllOrSome::All)); + assert_eq!(response.vary_origin, false); + + // Build response and check built response header + let expected_header = vec!["*"]; + let response = response.build(); + let actual_header: Vec<_> = response .headers() .get("Access-Control-Allow-Origin") .collect(); - assert_eq!(values, vec!["*"]); - assert_eq!(body_str, Some("Hello, world!".to_string())); + assert_eq!(expected_header, actual_header); + } + + #[test] + fn response_allows_all_origin_with_echoed_domain() { + let url = "https://www.example.com"; + let origin = Origin::from_str(url).unwrap(); + let allowed_origins = AllOrSome::All; + let send_wildcard = false; + + let response = Response::new(()); + let response = not_err!(response.allowed_origin( + &origin, + &allowed_origins, + send_wildcard, + )); + + let actual_origin = assert_matches!( + response.allow_origin, + Some(AllOrSome::Some(ref origin)), + origin + ); + assert_eq!(url, actual_origin); + assert_eq!(response.vary_origin, true); + + // Build response and check built response header + let expected_header = vec![url]; + let response = response.build(); + let actual_header: Vec<_> = response + .headers() + .get("Access-Control-Allow-Origin") + .collect(); + assert_eq!(expected_header, actual_header); + } + + #[test] + fn response_allows_origin() { + let url = "https://www.example.com"; + let origin = Origin::from_str(url).unwrap(); + let (allowed_origins, failed_origins) = + AllOrSome::new_from_str_list(&["https://www.example.com"]); + assert!(failed_origins.is_empty()); + let send_wildcard = false; + + let response = Response::new(()); + let response = not_err!(response.allowed_origin( + &origin, + &allowed_origins, + send_wildcard, + )); + + let actual_origin = assert_matches!( + response.allow_origin, + Some(AllOrSome::Some(ref origin)), + origin + ); + + assert_eq!(url, actual_origin); + assert_eq!(response.vary_origin, false); + + // Build response and check built response header + let expected_header = vec![url]; + let response = response.build(); + let actual_header: Vec<_> = response + .headers() + .get("Access-Control-Allow-Origin") + .collect(); + assert_eq!(expected_header, actual_header); + } + + #[test] + #[should_panic(expected = "OriginNotAllowed")] + fn response_rejects_invalid_origin() { + let url = "https://www.acme.com"; + let origin = Origin::from_str(url).unwrap(); + let (allowed_origins, failed_origins) = + AllOrSome::new_from_str_list(&["https://www.example.com"]); + assert!(failed_origins.is_empty()); + let send_wildcard = false; + + let response = Response::new(()); + let _ = response + .allowed_origin(&origin, &allowed_origins, send_wildcard) + .unwrap(); + } + + #[test] + #[should_panic(expected = "CredentialsWithWildcardOrigin")] + fn response_credentials_does_not_allow_wildcard_with_all_origins() { + let response = Response::new(()); + let response = response.any(); + + let _ = response.credentials(true).unwrap(); + } + + #[test] + fn response_credentials_allows_specific_origins() { + let response = Response::new(()); + let response = response.origin("https://www.example.com", false); + + let response = response.credentials(true).expect( + "to allow specific origins", + ); + assert_eq!(response.allow_credentials, true); + + // Build response and check built response header + let expected_header = vec!["true"]; + let response = response.build(); + let actual_header: Vec<_> = response + .headers() + .get("Access-Control-Allow-Credentials") + .collect(); + assert_eq!(expected_header, actual_header); + } + + #[test] + fn response_sets_exposed_headers_correctly() { + let headers = vec!["Bar", "Baz", "Foo"]; + let response = Response::new(()); + let response = response.origin("https://www.example.com", false); + let response = response.exposed_headers(&headers); + + // Build response and check built response header + let response = response.build(); + let actual_header: Vec<_> = response + .headers() + .get("Access-Control-Expose-Headers") + .collect(); + + assert_eq!(1, actual_header.len()); + let mut actual_headers: Vec = actual_header[0] + .split(',') + .map(|header| header.trim().to_string()) + .collect(); + actual_headers.sort(); + assert_eq!(headers, actual_headers); + } + + #[test] + fn response_sets_max_age_correctly() { + let response = Response::new(()); + let response = response.origin("https://www.example.com", false); + + let response = response.max_age(Some(42)); + + // Build response and check built response header + let expected_header = vec!["42"]; + let response = response.build(); + let actual_header: Vec<_> = response.headers().get("Access-Control-Max-Age").collect(); + assert_eq!(expected_header, actual_header); + } + + #[test] + fn response_does_not_set_max_age_when_none() { + let response = Response::new(()); + let response = response.origin("https://www.example.com", false); + + let response = response.max_age(None); + + // Build response and check built response header + let response = response.build(); + assert!(response + .headers() + .get("Access-Control-Max-Age") + .next().is_none()) + } + + /// When all headers are allowed, tests that the requested headers are echoed back + #[test] + fn response_allowed_headers_echoes_back_requested_headers() { + let allowed_headers = AllOrSome::All; + let requested_headers = vec!["Bar", "Foo"]; + + let response = Response::new(()); + let response = response.origin("https://www.example.com", false); + let response = response + .allowed_headers( + &FromStr::from_str(&requested_headers.join(",")).unwrap(), + &allowed_headers, + ) + .expect("to not fail"); + + // Build response and check built response header + let response = response.build(); + let actual_header: Vec<_> = response + .headers() + .get("Access-Control-Allow-Headers") + .collect(); + + assert_eq!(1, actual_header.len()); + let mut actual_headers: Vec = actual_header[0] + .split(',') + .map(|header| header.trim().to_string()) + .collect(); + actual_headers.sort(); + assert_eq!(requested_headers, actual_headers); + } + + #[test] + fn response_allowed_methods_sets_headers_properly() { + let allowed_methods = vec![ + Method::Get, + Method::Head, + Method::Post, + ].into_iter() + .collect(); + + let method = "GET"; + + let response = Response::new(()); + let response = response.origin("https://www.example.com", false); + let response = response + .allowed_methods( + &FromStr::from_str(method).expect("not to fail"), + &allowed_methods, + ) + .expect("not to fail"); + + // Build response and check built response header + let response = response.build(); + let actual_header: Vec<_> = response + .headers() + .get("Access-Control-Allow-Methods") + .collect(); + + assert_eq!(1, actual_header.len()); + let mut actual_headers: Vec = actual_header[0] + .split(',') + .map(|header| header.trim().to_string()) + .collect(); + actual_headers.sort(); + let mut expected_headers: Vec<_> = allowed_methods.iter().map(|m| m.as_str()).collect(); + expected_headers.sort(); + assert_eq!(expected_headers, actual_headers); + } + + #[test] + #[should_panic(expected = "MethodNotAllowed")] + fn response_allowed_method_errors_on_disallowed_method() { + let allowed_methods = vec![ + Method::Get, + Method::Head, + Method::Post, + ].into_iter() + .collect(); + + let method = "DELETE"; + + let response = Response::new(()); + let response = response.origin("https://www.example.com", false); + let _ = response + .allowed_methods( + &FromStr::from_str(method).expect("not to fail"), + &allowed_methods, + ) + .unwrap(); + } + + /// `Response::allowed_headers` should check that headers are allowed, and only + /// echoes back the list that is actually requested for and not the whole list + #[test] + fn response_allowed_headers_validates_and_echoes_requested_headers() { + let allowed_headers = vec!["Bar", "Baz", "Foo"]; + let requested_headers = vec!["Bar", "Foo"]; + + let response = Response::new(()); + let response = response.origin("https://www.example.com", false); + let response = response + .allowed_headers( + &FromStr::from_str(&requested_headers.join(",")).unwrap(), + &AllOrSome::Some( + allowed_headers + .iter() + .map(|s| FromStr::from_str(*s).unwrap()) + .collect(), + ), + ) + .expect("to not fail"); + + // Build response and check built response header + let response = response.build(); + let actual_header: Vec<_> = response + .headers() + .get("Access-Control-Allow-Headers") + .collect(); + + assert_eq!(1, actual_header.len()); + let mut actual_headers: Vec = actual_header[0] + .split(',') + .map(|header| header.trim().to_string()) + .collect(); + actual_headers.sort(); + assert_eq!(requested_headers, actual_headers); + } + + #[test] + #[should_panic(expected = "HeadersNotAllowed")] + fn response_allowed_headers_errors_on_non_subset() { + let allowed_headers = vec!["Bar", "Baz", "Foo"]; + let requested_headers = vec!["Bar", "Foo", "Unknown"]; + + let response = Response::new(()); + let response = response.origin("https://www.example.com", false); + let _ = response + .allowed_headers( + &FromStr::from_str(&requested_headers.join(",")).unwrap(), + &AllOrSome::Some( + allowed_headers + .iter() + .map(|s| FromStr::from_str(*s).unwrap()) + .collect(), + ), + ) + .unwrap(); + + } + + #[test] + fn response_does_not_build_if_origin_is_not_set() { + let response = Response::new(()); + let response = response.build(); + + let headers: Vec<_> = response.headers().iter().collect(); + assert_eq!(headers.len(), 0); + } + + // Note: Correct operation of Response::build is tested in the tests above for each of the + // individual headers + + #[test] + fn response_merges_correctly() { + use std::io::Cursor; + use rocket::http::Status; + + let wrapped = response::Response::build() + .status(Status::ImATeapot) + .raw_header("X-Teapot-Make", "Rocket") + .sized_body(Cursor::new("Brewing the best coffee!")) + .finalize(); + + let response = Response::new(()); + let response = response.origin("https://www.acme.com", false); + + let mut response = Response::::merge(wrapped, response.build()); + assert_eq!(response.status(), Status::ImATeapot); + assert_eq!(response.body_string(), Some("Brewing the best coffee!".to_string())); + + // Check CORS header + let expected_header = vec!["https://www.acme.com"]; + let actual_header: Vec<_> = response + .headers() + .get("Access-Control-Allow-Origin") + .collect(); + assert_eq!(expected_header, actual_header); + + // Check other header + let expected_header = vec!["Rocket"]; + let actual_header: Vec<_> = response.headers().get("X-Teapot-Make").collect(); + assert_eq!(expected_header, actual_header); + } + + #[test] + fn response_does_not_merge_existing_cors() { + let wrapped = response::Response::build() + .raw_header("Access-Control-Allow-Origin", "https://www.example.com") + .finalize(); + + let response = Response::new(()); + let response = response.origin("https://www.acme.com", false); + + let response = Response::<()>::merge(wrapped, response.build()); + let expected_header = vec!["https://www.example.com"]; + let actual_header: Vec<_> = response + .headers() + .get("Access-Control-Allow-Origin") + .collect(); + assert_eq!(expected_header, actual_header); + } + + #[test] + fn response_finalize_smoke_test() { + use std::io::Cursor; + use rocket::http::Status; + + let wrapped = response::Response::build() + .status(Status::ImATeapot) + .raw_header("X-Teapot-Make", "Rocket") + .sized_body(Cursor::new("Brewing the best coffee!")) + .finalize(); + + let response = Response::new(wrapped); + let response = response.origin("https://www.acme.com", false); + + let client = make_client(); + let request = client.get("/"); + let mut response = response.finalize(request.inner()).expect("not to fail"); + + assert_eq!(response.status(), Status::ImATeapot); + assert_eq!(response.body_string(), Some("Brewing the best coffee!".to_string())); + + // Check CORS header + let expected_header = vec!["https://www.acme.com"]; + let actual_header: Vec<_> = response + .headers() + .get("Access-Control-Allow-Origin") + .collect(); + assert_eq!(expected_header, actual_header); + + // Check other header + let expected_header = vec!["Rocket"]; + let actual_header: Vec<_> = response.headers().get("X-Teapot-Make").collect(); + assert_eq!(expected_header, actual_header); } } diff --git a/tests/headers.rs b/tests/headers.rs index d56f9e2..a81e167 100644 --- a/tests/headers.rs +++ b/tests/headers.rs @@ -28,7 +28,7 @@ fn request_headers( format!("{}\n{}\n{}", origin, method, headers.join(", ")) } -/// Tests that all the headers are parsed correcly in a HTTP request +/// Tests that all the request headers are parsed correcly in a HTTP request #[test] fn request_headers_round_trip_smoke_test() { let rocket = rocket::ignite().mount("/", routes![request_headers]);