From 7dbc22b523fcf09361be61653d5d6b326e001bd1 Mon Sep 17 00:00:00 2001 From: Yong Wen Chua Date: Sat, 15 Jul 2017 01:38:13 +0800 Subject: [PATCH] Delay CORS checks and response until `Responder::respond_to` is invoked (#6) * Delay checking of CORS to just before responding * Lifetime issues * Use State::inner() * Fix lifetime issues * Bump Rocket * Document 'static limitation And link to https://github.com/SergioBenitez/Rocket/pull/345 * Remove extraneous comments --- Cargo.toml | 4 +- README.md | 2 +- src/lib.rs | 498 +++++++++++++++++++++++++++--------------------- tests/routes.rs | 17 +- 4 files changed, 286 insertions(+), 235 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 1ff5259..362c11b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,7 +16,7 @@ travis-ci = { repository = "lawliet89/rocket_cors" } [dependencies] log = "0.3" -rocket = { git = "https://github.com/SergioBenitez/Rocket", rev = "aa51fe0" } +rocket = { git = "https://github.com/SergioBenitez/Rocket", rev = "51a465f2cc88d537079133bcdfec37d029070dcd" } serde = "1.0" serde_derive = "1.0" unicase="1.4" @@ -29,5 +29,5 @@ version_check = "0.1" [dev-dependencies] hyper = "0.10" -rocket_codegen = { git = "https://github.com/SergioBenitez/Rocket", rev = "aa51fe0" } +rocket_codegen = { git = "https://github.com/SergioBenitez/Rocket", rev = "51a465f2cc88d537079133bcdfec37d029070dcd" } serde_json = "1.0" diff --git a/README.md b/README.md index 1ef30df..425e6ac 100644 --- a/README.md +++ b/README.md @@ -27,7 +27,7 @@ In particular, `rocket_cors` is currently targetted for `nightly-2017-07-13`. Rocket > 0.3 is needed. At this moment, `0.3` is not released, and this crate will not be published to Crates.io until Rocket 0.3 is released to Crates.io. -We currently tie this crate to revision [aa51fe0](https://github.com/SergioBenitez/Rocket/tree/aa51fe0) of Rocket. +We currently tie this crate to revision [51a465f2cc88d537079133bcdfec37d029070dcd](https://github.com/SergioBenitez/Rocket/tree/51a465f2cc88d537079133bcdfec37d029070dcd) of Rocket. ## Installation diff --git a/src/lib.rs b/src/lib.rs index 6b215ed..c227ec5 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -29,7 +29,7 @@ //! to Crates.io until Rocket 0.3 is released to Crates.io. //! //! We currently tie this crate to revision -//! [aa51fe0](https://github.com/SergioBenitez/Rocket/tree/aa51fe0) of Rocket. +//! [51a465f2cc88d537079133bcdfec37d029070dcd](https://github.com/SergioBenitez/Rocket/tree/51a465f2cc88d537079133bcdfec37d029070dcd) of Rocket. //! //! ## Installation //! @@ -118,20 +118,27 @@ extern crate hyper; use std::collections::{HashSet, HashMap}; use std::error; use std::fmt; +use std::marker::PhantomData; use std::ops::Deref; use std::str::FromStr; -use rocket::request::{self, Request, FromRequest}; -use rocket::response::{self, Responder}; +use rocket::{Outcome, State}; use rocket::http::{Method, Status}; -use rocket::Outcome; +use rocket::request::{self, Request, FromRequest}; +use rocket::response; use unicase::UniCase; #[cfg(test)] #[macro_use] mod test_macros; -/// CORS related error +/// Errors during operations +/// +/// This enum implements `rocket::response::Responder` which will return an appropriate status code +/// while printing out the error in the console. +/// Because these errors are usually the result of an error while trying to respond to a CORS +/// request, CORS headers cannot be added to the response and your applications requesting CORS +/// will not be able to see the status code. #[derive(Debug)] pub enum Error { /// The HTTP request header `Origin` is required but was not provided @@ -201,7 +208,7 @@ impl fmt::Display for Error { } } -impl<'r> Responder<'r> for Error { +impl<'r> response::Responder<'r> for Error { fn respond_to(self, _: &Request) -> Result, Status> { error_!("CORS Error: {:?}", self); Err(match self { @@ -259,7 +266,6 @@ impl<'a, 'r> FromRequest<'a, 'r> for Url { } } - /// The `Origin` request header used in CORS /// /// You can use this as a rocket [Request Guard](https://rocket.rs/guide/requests/#request-guards) @@ -374,7 +380,11 @@ impl AllOrSome> { } } -/// Configuration options to for building CORS preflight or actual responses. +/// Responder and Fairing for CORS +/// +/// This struct can be used as Fairing for Rocket, or as an ad-hoc responder for any CORS requests. +/// You create a new copy of this struct by defining the configurations in the fields below. +/// This struct can also be deserialized by serde. /// /// [`Default`](https://doc.rust-lang.org/std/default/trait.Default.html) is implemented for this /// struct. The default for each field is described in the docuementation for the field. @@ -473,7 +483,30 @@ impl Default for Options { } } +/// Ad-hoc per route CORS response to requests +/// +/// Note: If you use this, the lifetime parameter `'r` of your `rocket:::response::Responder<'r>` +/// CANNOT be `'static`. This is because the code generated by Rocket will implicitly try to +/// to restrain the `Request` object passed to the route to `&'static Request`, and it is not +/// possible to have such a reference. +/// See [this PR on Rocket](https://github.com/SergioBenitez/Rocket/pull/345). +pub fn respond<'a, 'r: 'a, R: response::Responder<'r>>( + options: State<'a, Options>, + responder: R, +) -> Responder<'a, 'r, R> { + options.inner().respond(responder) +} + impl Options { + /// Wrap any `Rocket::Response` and respond with CORS headers. + /// This is only used for ad-hoc route CORS response + fn respond<'a, 'r: 'a, R: response::Responder<'r>>( + &'a self, + responder: R, + ) -> Responder<'a, 'r, R> { + Responder::new(responder, self) + } + fn default_allowed_methods() -> HashSet { vec![ Method::Get, @@ -486,40 +519,138 @@ impl Options { ].into_iter() .collect() } +} + +/// A CORS Responder which will inspect the incoming requests and respond accoridingly. +/// +/// If the wrapped `Responder` already has the `Access-Control-Allow-Origin` header set, +/// this responder will leave the response untouched. +/// This allows for chaining of several CORS responders. +/// +/// Otherwise, the following headers may be set for the final Rocket `Response`, overwriting any +/// existing headers defined: +/// +/// - `Access-Control-Allow-Origin` +/// - `Access-Control-Expose-Headers` +/// - `Access-Control-Max-Age` +/// - `Access-Control-Allow-Credentials` +/// - `Access-Control-Allow-Methods` +/// - `Access-Control-Allow-Headers` +/// - `Vary` +#[derive(Debug)] +pub struct Responder<'a, 'r: 'a, R> { + responder: R, + options: &'a Options, + marker: PhantomData>, +} + +impl<'a, 'r: 'a, R: response::Responder<'r>> Responder<'a, 'r, R> { + fn new(responder: R, options: &'a Options) -> Self { + Self { + responder, + options, + marker: PhantomData, + } + } + + /// Respond to a request + fn respond(self, request: &Request) -> response::Result<'r> { + match self.build_cors_response(request) { + Ok(response) => response, + Err(e) => response::Responder::respond_to(e, request), + } + } + + /// Build a CORS response and merge with an existing `rocket::Response` for the request + fn build_cors_response(self, request: &Request) -> Result, Error> { + let original_response = match self.responder.respond_to(request) { + Ok(response) => response, + Err(status) => return Ok(Err(status)), // TODO: Handle this? + }; + + // Existing CORS response? + if Self::has_allow_origin(&original_response) { + return Ok(Ok(original_response)); + } + + // 1. If the Origin header is not present terminate this set of steps. + // The request is outside the scope of this specification. + let origin = Self::origin(request)?; + let origin = match origin { + None => { + // Not a CORS request + return Ok(Ok(original_response)); + } + Some(origin) => origin, + }; + + // Check if the request verb is an OPTION or something else + let cors_response = match request.method() { + Method::Options => { + let method = Self::request_method(request)?; + let headers = Self::request_headers(request)?; + Self::preflight(&self.options, origin, method, headers) + } + _ => Self::actual_request(&self.options, origin), + }?; + + Ok(Ok(cors_response.build(original_response))) + } + + /// Gets the `Origin` request header from the request + fn origin(request: &Request) -> Result, Error> { + match Origin::from_request(request) { + Outcome::Forward(()) => Ok(None), + Outcome::Success(origin) => Ok(Some(origin)), + Outcome::Failure((_, err)) => Err(err), + } + } + + /// Gets the `Access-Control-Request-Method` request header from the request + fn request_method(request: &Request) -> Result, Error> { + match AccessControlRequestMethod::from_request(request) { + Outcome::Forward(()) => Ok(None), + Outcome::Success(method) => Ok(Some(method)), + Outcome::Failure((_, err)) => Err(err), + } + } + + /// Gets the `Access-Control-Request-Headers` request header from the request + fn request_headers(request: &Request) -> Result, Error> { + match AccessControlRequestHeaders::from_request(request) { + Outcome::Forward(()) => Ok(None), + Outcome::Success(geaders) => Ok(Some(geaders)), + Outcome::Failure((_, err)) => Err(err), + } + } + + /// Checks if an existing Response already has the header `Access-Control-Allow-Origin` + fn has_allow_origin(response: &response::Response<'r>) -> bool { + response.headers().get("Access-Control-Allow-Origin").next() != None + } /// Construct a preflight response based on the options. Will return an `Err` /// if any of the preflight checks fail. /// /// This implementation references the /// [W3C recommendation](https://www.w3.org/TR/cors/#resource-preflight-requests). - pub fn preflight<'r, R: Responder<'r>>( - &self, - responder: R, - origin: Option, + fn preflight( + options: &Options, + origin: Origin, method: Option, headers: Option, - ) -> Result, Error> { + ) -> Result { - let response = Response::new(responder); + let response = Response::new(); // Note: All header parse failures are dealt with in the `FromRequest` trait implementation - // 1. If the Origin header is not present terminate this set of steps. - // The request is outside the scope of this specification. - let origin = match origin { - None => { - // Not a CORS request - return Ok(response); - } - Some(origin) => origin, - }; - // 2. If the value of the Origin header is not a case-sensitive match for any of the values // in list of origins do not set any additional headers and terminate this set of steps. let response = response.allowed_origin( &origin, - &self.allowed_origins, - self.send_wildcard, + &options.allowed_origins, + options.send_wildcard, )?; // 3. Let `method` be the value as result of parsing the Access-Control-Request-Method @@ -540,13 +671,13 @@ impl Options { // 5. If method is not a case-sensitive match for any of the values in list of methods // do not set any additional headers and terminate this set of steps. - let response = response.allowed_methods(&method, &self.allowed_methods)?; + let response = response.allowed_methods(&method, &options.allowed_methods)?; // 6. If any of the header field-names is not a ASCII case-insensitive match for any of the // values in list of headers do not set any additional headers and terminate this set of // steps. let response = if let Some(headers) = headers { - response.allowed_headers(&headers, &self.allowed_headers)? + response.allowed_headers(&headers, &options.allowed_headers)? } else { response }; @@ -559,12 +690,12 @@ impl Options { // with either the value of the Origin header or the string "*" as value. // Note: The string "*" cannot be used for a resource that supports credentials. - let response = response.credentials(self.allow_credentials)?; + let response = response.credentials(options.allow_credentials)?; // 8. Optionally add a single Access-Control-Max-Age header // with as value the amount of seconds the user agent is allowed to cache the result of the // request. - let response = response.max_age(self.max_age); + let response = response.max_age(options.max_age); // 9. If method is a simple method this step may be skipped. // Add one or more Access-Control-Allow-Methods headers consisting of @@ -591,36 +722,22 @@ impl Options { Ok(response) } - /// Respond to a request based on the settings. + /// Respond to an actual request based on the settings. /// If the `Origin` is not provided, then this request was not made by a browser and there is no /// CORS enforcement. - pub fn respond<'r, R: Responder<'r>>( - &self, - responder: R, - origin: Option, - ) -> Result, Error> { - let response = Response::new(responder); + fn actual_request(options: &Options, origin: Origin) -> Result { + let response = Response::new(); // Note: All header parse failures are dealt with in the `FromRequest` trait implementation - // 1. If the Origin header is not present terminate this set of steps. - // The request is outside the scope of this specification. - let origin = match origin { - None => { - // Not a CORS request - return Ok(response); - } - Some(origin) => origin, - }; - // 2. If the value of the Origin header is not a case-sensitive match for any of the values // in list of origins, do not set any additional headers and terminate this set of steps. // Always matching is acceptable since the list of origins can be unbounded. let response = response.allowed_origin( &origin, - &self.allowed_origins, - self.send_wildcard, + &options.allowed_origins, + options.send_wildcard, )?; // 3. If the resource supports credentials add a single Access-Control-Allow-Origin header, @@ -631,7 +748,7 @@ impl Options { // with either the value of the Origin header or the string "*" as value. // Note: The string "*" cannot be used for a resource that supports credentials. - let response = response.credentials(self.allow_credentials)?; + let response = response.credentials(options.allow_credentials)?; // 4. If the list of exposed headers is not empty add one or more // Access-Control-Expose-Headers headers, with as values the header field names given in @@ -641,7 +758,8 @@ impl Options { // and url is a case-sensitive match for the URL of the resource. let response = response.exposed_headers( - self.expose_headers + options + .expose_headers .iter() .map(|s| &**s) .collect::>() @@ -651,16 +769,14 @@ impl Options { } } -/// A CORS Response which wraps another struct which implements `Responder`. You will typically -/// use [`Options`] instead to verify and build the response instead of this directly. -/// See module level documentation for usage examples. -/// -/// If the wrapped `Responder` already has the `Access-Control-Allow-Origin` header set, -/// this responder will leave the response untouched. -/// This allows for chaining of several CORS responders. -/// -/// Otherwise, the following headers may be set for the final Rocket `Response`, overwriting any -/// existing headers defined: +impl<'a, 'r: 'a, R: response::Responder<'r>> response::Responder<'r> for Responder<'a, 'r, R> { + fn respond_to(self, request: &Request) -> response::Result<'r> { + self.respond(request) + } +} + + +/// A CORS Response which provides the following CORS headers: /// /// - `Access-Control-Allow-Origin` /// - `Access-Control-Expose-Headers` @@ -670,8 +786,7 @@ impl Options { /// - `Access-Control-Allow-Headers` /// - `Vary` #[derive(Debug)] -pub struct Response { - responder: R, +struct Response { allow_origin: Option>, allow_methods: HashSet, allow_headers: HeaderFieldNamesSet, @@ -681,14 +796,13 @@ pub struct Response { vary_origin: bool, } -impl<'r, R: Responder<'r>> Response { +impl Response { /// Consumes the responder and return an empty `Response` - fn new(responder: R) -> Self { + fn new() -> Self { Self { allow_origin: None, allow_headers: HashSet::new(), allow_methods: HashSet::new(), - responder, allow_credentials: false, expose_headers: HashSet::new(), max_age: None, @@ -826,15 +940,18 @@ impl<'r, R: Responder<'r>> Response { ) } - /// Builds a `rocket::Response` from this struct containing only the CORS headers. + /// Builds a `rocket::Response` from this struct based off some base `rocket::Response` + /// + /// This will overwrite any existing CORS headers #[allow(unused_results)] - fn build(&self) -> response::Response<'r> { - let mut builder = response::Response::build(); + fn build<'r>(&self, base: response::Response<'r>) -> response::Response<'r> { + let mut response = response::Response::build_from(base).finalize(); + // TODO: We should be able to remove this let origin = match self.allow_origin { None => { // This is not a CORS response - return builder.finalize(); + return response; } Some(ref origin) => origin, }; @@ -844,10 +961,12 @@ impl<'r, R: Responder<'r>> Response { AllOrSome::Some(ref origin) => origin.to_string(), }; - builder.raw_header("Access-Control-Allow-Origin", origin); + response.set_raw_header("Access-Control-Allow-Origin", origin); if self.allow_credentials { - builder.raw_header("Access-Control-Allow-Credentials", "true"); + response.set_raw_header("Access-Control-Allow-Credentials", "true"); + } else { + response.remove_header("Access-Control-Allow-Credentials"); } if !self.expose_headers.is_empty() { @@ -857,7 +976,9 @@ impl<'r, R: Responder<'r>> Response { .collect(); let headers = headers.join(", "); - builder.raw_header("Access-Control-Expose-Headers", headers); + response.set_raw_header("Access-Control-Expose-Headers", headers); + } else { + response.remove_header("Access-Control-Expose-Headers"); } if !self.allow_headers.is_empty() { @@ -867,76 +988,34 @@ impl<'r, R: Responder<'r>> Response { .collect(); let headers = headers.join(", "); - builder.raw_header("Access-Control-Allow-Headers", headers); + response.set_raw_header("Access-Control-Allow-Headers", headers); + } else { + response.remove_header("Access-Control-Allow-Headers"); } if !self.allow_methods.is_empty() { let methods: Vec<_> = self.allow_methods.iter().map(|m| m.as_str()).collect(); let methods = methods.join(", "); - builder.raw_header("Access-Control-Allow-Methods", methods); + response.set_raw_header("Access-Control-Allow-Methods", methods); + } else { + response.remove_header("Access-Control-Allow-Methods"); } if self.max_age.is_some() { let max_age = self.max_age.unwrap(); - builder.raw_header("Access-Control-Max-Age", max_age.to_string()); + response.set_raw_header("Access-Control-Max-Age", max_age.to_string()); + } else { + response.remove_header("Access-Control-Max-Age"); } if self.vary_origin { - builder.raw_header("Vary", "Origin"); + response.set_raw_header("Vary", "Origin"); + } else { + response.remove_header("Vary"); } - builder.finalize() - } - - /// Merge a `wrapped` Response with a `cors` response - /// - /// If the `wrapped` response has the `Access-Control-Allow-Origin` header already defined, - /// it will be left untouched. This allows for chaining of several CORS responders. - /// - /// Otherwise, the merging will be done according to the rules of `rocket::Response::merge`. - fn merge( - mut wrapped: response::Response<'r>, - cors: response::Response<'r>, - ) -> response::Response<'r> { - - let existing_cors = { - wrapped.headers().get("Access-Control-Allow-Origin").next() == None - }; - - if existing_cors { - wrapped.merge(cors); - } - - wrapped - } - - /// Finalize the Response by merging the CORS header with the wrapped `Responder - /// - /// If the original response has the `Access-Control-Allow-Origin` header already defined, - /// it will be left untouched.This allows for chaining of several CORS responders. - /// - /// Otherwise, the following headers may be set for the final Rocket `Response`, overwriting any - /// existing headers defined: - /// - /// - `Access-Control-Allow-Origin` - /// - `Access-Control-Expose-Headers` - /// - `Access-Control-Max-Age` - /// - `Access-Control-Allow-Credentials` - /// - `Access-Control-Allow-Methods` - /// - `Access-Control-Allow-Headers` - /// - `Vary` - fn finalize(self, request: &Request) -> response::Result<'r> { - let cors_response = self.build(); - let original_response = self.responder.respond_to(request)?; - - Ok(Self::merge(original_response, cors_response)) - } -} - -impl<'r, R: Responder<'r>> Responder<'r> for Response { - fn respond_to(self, request: &Request) -> response::Result<'r> { - self.finalize(request) + response } } @@ -1059,7 +1138,7 @@ mod tests { let allowed_origins = AllOrSome::All; let send_wildcard = true; - let response = Response::new(()); + let response = Response::new(); let response = not_err!(response.allowed_origin( &origin, &allowed_origins, @@ -1071,7 +1150,7 @@ mod tests { // Build response and check built response header let expected_header = vec!["*"]; - let response = response.build(); + let response = response.build(response::Response::new()); let actual_header: Vec<_> = response .headers() .get("Access-Control-Allow-Origin") @@ -1086,7 +1165,7 @@ mod tests { let allowed_origins = AllOrSome::All; let send_wildcard = false; - let response = Response::new(()); + let response = Response::new(); let response = not_err!(response.allowed_origin( &origin, &allowed_origins, @@ -1103,7 +1182,7 @@ mod tests { // Build response and check built response header let expected_header = vec![url]; - let response = response.build(); + let response = response.build(response::Response::new()); let actual_header: Vec<_> = response .headers() .get("Access-Control-Allow-Origin") @@ -1120,7 +1199,7 @@ mod tests { assert!(failed_origins.is_empty()); let send_wildcard = false; - let response = Response::new(()); + let response = Response::new(); let response = not_err!(response.allowed_origin( &origin, &allowed_origins, @@ -1138,7 +1217,7 @@ mod tests { // Build response and check built response header let expected_header = vec![url]; - let response = response.build(); + let response = response.build(response::Response::new()); let actual_header: Vec<_> = response .headers() .get("Access-Control-Allow-Origin") @@ -1156,7 +1235,7 @@ mod tests { assert!(failed_origins.is_empty()); let send_wildcard = false; - let response = Response::new(()); + let response = Response::new(); let _ = response .allowed_origin(&origin, &allowed_origins, send_wildcard) .unwrap(); @@ -1165,7 +1244,7 @@ mod tests { #[test] #[should_panic(expected = "CredentialsWithWildcardOrigin")] fn response_credentials_does_not_allow_wildcard_with_all_origins() { - let response = Response::new(()); + let response = Response::new(); let response = response.any(); let _ = response.credentials(true).unwrap(); @@ -1173,7 +1252,7 @@ mod tests { #[test] fn response_credentials_allows_specific_origins() { - let response = Response::new(()); + let response = Response::new(); let response = response.origin("https://www.example.com", false); let response = response.credentials(true).expect( @@ -1183,7 +1262,7 @@ mod tests { // Build response and check built response header let expected_header = vec!["true"]; - let response = response.build(); + let response = response.build(response::Response::new()); let actual_header: Vec<_> = response .headers() .get("Access-Control-Allow-Credentials") @@ -1194,12 +1273,12 @@ mod tests { #[test] fn response_sets_exposed_headers_correctly() { let headers = vec!["Bar", "Baz", "Foo"]; - let response = Response::new(()); + let response = Response::new(); let response = response.origin("https://www.example.com", false); let response = response.exposed_headers(&headers); // Build response and check built response header - let response = response.build(); + let response = response.build(response::Response::new()); let actual_header: Vec<_> = response .headers() .get("Access-Control-Expose-Headers") @@ -1216,27 +1295,27 @@ mod tests { #[test] fn response_sets_max_age_correctly() { - let response = Response::new(()); + let response = Response::new(); let response = response.origin("https://www.example.com", false); let response = response.max_age(Some(42)); // Build response and check built response header let expected_header = vec!["42"]; - let response = response.build(); + let response = response.build(response::Response::new()); let actual_header: Vec<_> = response.headers().get("Access-Control-Max-Age").collect(); assert_eq!(expected_header, actual_header); } #[test] fn response_does_not_set_max_age_when_none() { - let response = Response::new(()); + let response = Response::new(); let response = response.origin("https://www.example.com", false); let response = response.max_age(None); // Build response and check built response header - let response = response.build(); + let response = response.build(response::Response::new()); assert!(response .headers() .get("Access-Control-Max-Age") @@ -1249,7 +1328,7 @@ mod tests { let allowed_headers = AllOrSome::All; let requested_headers = vec!["Bar", "Foo"]; - let response = Response::new(()); + let response = Response::new(); let response = response.origin("https://www.example.com", false); let response = response .allowed_headers( @@ -1259,7 +1338,7 @@ mod tests { .expect("to not fail"); // Build response and check built response header - let response = response.build(); + let response = response.build(response::Response::new()); let actual_header: Vec<_> = response .headers() .get("Access-Control-Allow-Headers") @@ -1285,7 +1364,7 @@ mod tests { let method = "GET"; - let response = Response::new(()); + let response = Response::new(); let response = response.origin("https://www.example.com", false); let response = response .allowed_methods( @@ -1295,7 +1374,7 @@ mod tests { .expect("not to fail"); // Build response and check built response header - let response = response.build(); + let response = response.build(response::Response::new()); let actual_header: Vec<_> = response .headers() .get("Access-Control-Allow-Methods") @@ -1324,7 +1403,7 @@ mod tests { let method = "DELETE"; - let response = Response::new(()); + let response = Response::new(); let response = response.origin("https://www.example.com", false); let _ = response .allowed_methods( @@ -1341,7 +1420,7 @@ mod tests { let allowed_headers = vec!["Bar", "Baz", "Foo"]; let requested_headers = vec!["Bar", "Foo"]; - let response = Response::new(()); + let response = Response::new(); let response = response.origin("https://www.example.com", false); let response = response .allowed_headers( @@ -1356,7 +1435,7 @@ mod tests { .expect("to not fail"); // Build response and check built response header - let response = response.build(); + let response = response.build(response::Response::new()); let actual_header: Vec<_> = response .headers() .get("Access-Control-Allow-Headers") @@ -1377,7 +1456,7 @@ mod tests { let allowed_headers = vec!["Bar", "Baz", "Foo"]; let requested_headers = vec!["Bar", "Foo", "Unknown"]; - let response = Response::new(()); + let response = Response::new(); let response = response.origin("https://www.example.com", false); let _ = response .allowed_headers( @@ -1395,98 +1474,77 @@ mod tests { #[test] fn response_does_not_build_if_origin_is_not_set() { - let response = Response::new(()); - let response = response.build(); + let response = Response::new(); + let response = response.build(response::Response::new()); let headers: Vec<_> = response.headers().iter().collect(); assert_eq!(headers.len(), 0); } - // Note: Correct operation of Response::build is tested in the tests above for each of the - // individual headers - #[test] - fn response_merges_correctly() { + fn response_build_removes_existing_cors_headers_and_keeps_others() { use std::io::Cursor; - use rocket::http::Status; - let wrapped = response::Response::build() + let original = response::Response::build() .status(Status::ImATeapot) .raw_header("X-Teapot-Make", "Rocket") + .raw_header("Access-Control-Max-Age", "42") .sized_body(Cursor::new("Brewing the best coffee!")) .finalize(); - let response = Response::new(()); - let response = response.origin("https://www.acme.com", false); - - let mut response = Response::::merge(wrapped, response.build()); - assert_eq!(response.status(), Status::ImATeapot); - assert_eq!(response.body_string(), Some("Brewing the best coffee!".to_string())); - + let response = Response::new(); + let response = response.origin("https://www.example.com", false); + let response = response.build(original); // Check CORS header - let expected_header = vec!["https://www.acme.com"]; - let actual_header: Vec<_> = response - .headers() - .get("Access-Control-Allow-Origin") - .collect(); - assert_eq!(expected_header, actual_header); - - // Check other header - let expected_header = vec!["Rocket"]; - let actual_header: Vec<_> = response.headers().get("X-Teapot-Make").collect(); - assert_eq!(expected_header, actual_header); - } - - #[test] - fn response_does_not_merge_existing_cors() { - let wrapped = response::Response::build() - .raw_header("Access-Control-Allow-Origin", "https://www.example.com") - .finalize(); - - let response = Response::new(()); - let response = response.origin("https://www.acme.com", false); - - let response = Response::<()>::merge(wrapped, response.build()); let expected_header = vec!["https://www.example.com"]; let actual_header: Vec<_> = response .headers() .get("Access-Control-Allow-Origin") .collect(); assert_eq!(expected_header, actual_header); - } - - #[test] - fn response_finalize_smoke_test() { - use std::io::Cursor; - use rocket::http::Status; - - let wrapped = response::Response::build() - .status(Status::ImATeapot) - .raw_header("X-Teapot-Make", "Rocket") - .sized_body(Cursor::new("Brewing the best coffee!")) - .finalize(); - - let response = Response::new(wrapped); - let response = response.origin("https://www.acme.com", false); - - let client = make_client(); - let request = client.get("/"); - let mut response = response.finalize(request.inner()).expect("not to fail"); - - assert_eq!(response.status(), Status::ImATeapot); - assert_eq!(response.body_string(), Some("Brewing the best coffee!".to_string())); - - // Check CORS header - let expected_header = vec!["https://www.acme.com"]; - let actual_header: Vec<_> = response - .headers() - .get("Access-Control-Allow-Origin") - .collect(); - assert_eq!(expected_header, actual_header); // Check other header let expected_header = vec!["Rocket"]; let actual_header: Vec<_> = response.headers().get("X-Teapot-Make").collect(); assert_eq!(expected_header, actual_header); + + // Check that `Access-Control-Max-Age` is removed + assert!(response.headers().get("Access-Control-Max-Age").next().is_none()); + + } + + // The following tests check that preflight checks are done properly + + // fn make_cors_options() -> Options { + // let (allowed_origins, failed_origins) = + // AllOrSome::new_from_str_list(&["https://www.acme.com"]); + // assert!(failed_origins.is_empty()); + + // Options { + // allowed_origins: allowed_origins, + // allowed_methods: [Method::Get].iter().cloned().collect(), + // allowed_headers: AllOrSome::Some( + // ["Authorization"] + // .into_iter() + // .map(|s| s.to_string().into()) + // .collect(), + // ), + // allow_credentials: true, + // ..Default::default() + // } + // } + + // /// Tests that non CORS preflight are let through without modification + // #[test] + // fn preflight_missing_origins_are_let_through() { + // let options = make_cors_options(); + // let client = make_client(); + // let request = client.get("/"); + + // let response = options.preflight((), None, None, None).expect("not to fail"); + + // let headers: Vec<_> = response.headers().iter().collect(); + // assert_eq!(headers.len(), 0); + // } } diff --git a/tests/routes.rs b/tests/routes.rs index 1617572..310a925 100644 --- a/tests/routes.rs +++ b/tests/routes.rs @@ -15,21 +15,13 @@ use rocket::local::Client; use rocket_cors::*; #[options("/")] -fn cors_options( - origin: Option, - method: Option, - headers: Option, - options: State, -) -> Result, Error> { - options.preflight((), origin, method, headers) +fn cors_options(options: State) -> Responder<&str> { + rocket_cors::respond(options, "") } #[get("/")] -fn cors( - origin: Option, - options: State, -) -> Result, Error> { - options.respond("Hello CORS", origin) +fn cors(options: State) -> Responder<&str> { + rocket_cors::respond(options, "Hello CORS") } fn make_cors_options() -> Options { @@ -146,6 +138,7 @@ fn cors_get_check() { let req = client.get("/").header(origin_header).header(authorization); let mut response = req.dispatch(); + println!("{:?}", response); assert_eq!(response.status(), Status::Ok); let body_str = response.body().and_then(|body| body.into_string()); assert_eq!(body_str, Some("Hello CORS".to_string()));