diff --git a/README.md b/README.md index 6bd7aa4..1ef30df 100644 --- a/README.md +++ b/README.md @@ -34,11 +34,15 @@ We currently tie this crate to revision [aa51fe0](https://github.com/SergioBenit To use the latest `master` branch, for example: ```toml -biscuit = { git = "https://github.com/lawliet89/rocket_cors", branch = "master" } +rocket_cors = { git = "https://github.com/lawliet89/rocket_cors", branch = "master" } ``` + +## Reference + +- [W3C CORS Recommendation](https://www.w3.org/TR/cors/#resource-processing-model) diff --git a/src/lib.rs b/src/lib.rs index ffa812a..2bc4b78 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,84 +1,50 @@ -//! Cross-origin resource sharing (CORS) for Rocket.rs applications +//! [![Build Status](https://travis-ci.org/lawliet89/rocket_cors.svg)](https://travis-ci.org/lawliet89/rocket_cors) +//! [![Dependency Status](https://dependencyci.com/github/lawliet89/rocket_cors/badge)](https://dependencyci.com/github/lawliet89/rocket_cors) +//! [![Repository](https://img.shields.io/github/tag/lawliet89/rocket_cors.svg)](https://github.com/lawliet89/rocket_cors) +//! +//! //! -//! Rocket (as of v0.2) does not have middleware support. Support for it is (supposedly) -//! on the way. In the mean time, we adopt an -//! [example implementation](https://github.com/SergioBenitez/Rocket/pull/141) to nest -//! `Responders` to acheive the same effect in the short run. +//! - Documentation: stable | [master branch](https://lawliet89.github.io/rocket_cors) //! -//! # Examples +//! Cross-origin resource sharing (CORS) for [Rocket](https://rocket.rs/) applications +//! +//! ## Requirements +//! +//! - Nightly Rust +//! - Rocket > 0.3 +//! +//! ### Nightly Rust +//! +//! Rocket requires nightly Rust. You should probably install Rust with +//! [rustup](https://www.rustup.rs/), then override the code directory to use nightly instead of +//! stable. See +//! [installation instructions](https://rocket.rs/guide/getting-started/#installing-rust). +//! +//! In particular, `rocket_cors` is currently targetted for `nightly-2017-07-13`. +//! +//! ### Rocket > 0.3 +//! +//! Rocket > 0.3 is needed. At this moment, `0.3` is not released, and this crate will not be +//! published +//! to Crates.io until Rocket 0.3 is released to Crates.io. +//! +//! We currently tie this crate to revision +//! [aa51fe0](https://github.com/SergioBenitez/Rocket/tree/aa51fe0) of Rocket. +//! +//! ## Installation +//! +//! +//! +//! To use the latest `master` branch, for example: +//! +//! ```toml +//! rocket_cors = { git = "https://github.com/lawliet89/rocket_cors", branch = "master" } //! ``` -//! #![feature(plugin, custom_derive)] -//! #![plugin(rocket_codegen)] -//! extern crate hyper; -//! extern crate rocket; -//! extern crate rocket_cors; //! -//! use std::str::FromStr; -//! -//! use rocket::State; -//! use rocket::http::Method::*; -//! use rocket::http::{Header, Status}; -//! use rocket::local::Client; -//! use rocket_cors::*; -//! -//! #[options("/")] -//! fn cors_options(origin: Option, -//! method: AccessControlRequestMethod, -//! headers: AccessControlRequestHeaders, -//! options: State) -//! -> Result, Error> { -//! options.preflight(origin, &method, Some(&headers)) -//! } -//! -//! #[get("/")] -//! fn cors(origin: Option, options: State) -//! -> Result, Error> -//! { -//! options.respond("Hello CORS", origin) -//! } -//! -//! # fn main() { -//! let (allowed_origins, failed_origins) = -//! AllowedOrigins::new_from_str_list(&["https://www.acme.com"]); -//! assert!(failed_origins.is_empty()); -//! let cors_options = rocket_cors::Options { -//! allowed_origins: allowed_origins, -//! allowed_methods: [Get].iter().cloned().collect(), -//! allowed_headers: ["Authorization"].iter().map(|s| s.to_string().into()).collect(), -//! allow_credentials: true, -//! ..Default::default() -//! }; -//! let rocket = rocket::ignite().mount("/", routes![cors, cors_options]).manage(cors_options); -//! let client = Client::new(rocket).unwrap(); -//! -//! // `Options` pre-flight checks -//! let origin_header = -//! Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap()); -//! let method_header = -//! Header::from(hyper::header::AccessControlRequestMethod(hyper::method::Method::Get)); -//! let request_headers = -//! hyper::header::AccessControlRequestHeaders( -//! vec![FromStr::from_str("Authorization").unwrap()]); -//! let request_headers = Header::from(request_headers); -//! let req = -//! client.options("/").header(origin_header).header(method_header).header(request_headers); -//! -//! let response = req.dispatch(); -//! assert_eq!(response.status(), Status::Ok); -//! -//! // "Actual" request -//! let origin_header = -//! Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap()); -//! let authorization = Header::new("Authorization", "let me in"); -//! let req = client.get("/").header(origin_header).header(authorization); -//! -//! let mut response = req.dispatch(); -//! assert_eq!(response.status(), Status::Ok); -//! let body_str = response.body().and_then(|body| body.into_string()); -//! assert_eq!(body_str, Some("Hello CORS".to_string())); -//! # } -//! ``` - #![allow( legacy_directory_ownership, @@ -183,6 +149,10 @@ pub enum Error { MethodNotAllowed, /// One or more headers requested are not allowed HeadersNotAllowed, + /// Credentials are allowed, but the Origin is set to "*". This is not allowed by W3C + /// + /// This is a misconfiguration. Check the docuemntation for `Options`. + CredentialsWithWildcardOrigin, } impl error::Error for Error { @@ -204,6 +174,11 @@ impl error::Error for Error { Error::OriginNotAllowed => "Origin is not allowed to request", Error::MethodNotAllowed => "Method is not allowed", Error::HeadersNotAllowed => "Headers are not allowed", + Error::CredentialsWithWildcardOrigin => { + "Credentials are allowed, but the Origin is set to \"*\". \ + This is not allowed by W3C" + } + } } @@ -231,6 +206,7 @@ impl<'r> Responder<'r> for Error { Err(match self { Error::MissingOrigin | Error::OriginNotAllowed | Error::MethodNotAllowed | Error::HeadersNotAllowed => Status::Forbidden, + Error::CredentialsWithWildcardOrigin => Status::InternalServerError, _ => Status::BadRequest, }) } @@ -308,12 +284,13 @@ impl<'a, 'r> FromRequest<'a, 'r> for AccessControlRequestMethod { Err(e) => Outcome::Failure((Status::BadRequest, Error::BadRequestMethod(e))), } } - None => Outcome::Failure((Status::BadRequest, Error::MissingRequestMethod)), + None => Outcome::Forward(()), } } } -type HeaderFieldNamesSet = HashSet>; +type HeaderFieldName = UniCase; +type HeaderFieldNamesSet = HashSet; /// The `Access-Control-Request-Headers` request header #[derive(Debug)] @@ -350,82 +327,30 @@ impl<'a, 'r> FromRequest<'a, 'r> for AccessControlRequestHeaders { } } } - None => Outcome::Failure((Status::BadRequest, Error::MissingRequestHeaders)), + None => Outcome::Forward(()), } } } -/// Origins that are allowed to issue CORS request. This is needed for browser -/// access to the authentication server, but tools like `curl` -/// do not obey nor enforce the CORS convention. -/// -/// This enum (de)serialized as an [untagged](https://serde.rs/enum-representations.html) -/// enum variant. -/// -/// # Examples -/// ## Allow all origins -/// ```json -/// { "allowed_origins": null } -/// ``` -/// ``` -/// extern crate rocket_cors; -/// #[macro_use] -/// extern crate serde_derive; -/// extern crate serde_json; -/// -/// use rocket_cors::*; -/// -/// # fn main() { -/// #[derive(Serialize, Deserialize)] -/// struct Test { -/// allowed_origins: AllowedOrigins -/// } -/// -/// let json = r#"{ "allowed_origins": null }"#; -/// let deserialized: Test = serde_json::from_str(json).unwrap(); -/// # } -/// ``` -/// ## Allow specific origins -/// -/// ```json -/// { "allowed_origins": ["http://127.0.0.1:8000/","https://foobar.com/"] } -/// ``` -/// -/// ``` -/// extern crate rocket_cors; -/// #[macro_use] -/// extern crate serde_derive; -/// extern crate serde_json; -/// -/// use rocket_cors::*; -/// -/// # fn main() { -/// #[derive(Serialize, Deserialize)] -/// struct Test { -/// allowed_origins: AllowedOrigins -/// } -/// -/// let json = r#"{ "allowed_origins": ["http://127.0.0.1:8000/","https://foobar.com/"] }"#; -/// let deserialized: Test = serde_json::from_str(json).unwrap(); -/// # } +/// An enum signifying that some of type T is allowed, or `All` (everything is allowed). #[derive(Clone, Debug, Serialize, Deserialize)] #[serde(untagged)] -pub enum AllowedOrigins { - /// All origins are allowed. Equivalent to the "*" value. +pub enum AllOrSome { + /// Everything is allowed. Usually equivalent to the "*" value. All, - /// Only origins listed are allowed. - Some(HashSet), + /// Only some of `T` is allowed + Some(T), } -impl Default for AllowedOrigins { +impl Default for AllOrSome { fn default() -> Self { - AllowedOrigins::All + AllOrSome::All } } -impl AllowedOrigins { - /// New `AllowedOrigins` from a list of URL strings. - /// Returns a tuple where the first element is the struct `AllowedOrigins`, +impl AllOrSome> { + /// New `AllOrSome` from a list of URL strings. + /// Returns a tuple where the first element is the struct `AllOrSome`, /// and the second element /// is a map of strings which failed to parse into URLs and their associated parse errors. pub fn new_from_str_list(urls: &[&str]) -> (Self, HashMap) { @@ -440,59 +365,234 @@ impl AllowedOrigins { let ok_set = ok_set.into_iter().map(|(_, r)| r.unwrap()).collect(); - (AllowedOrigins::Some(ok_set), error_map) + (AllOrSome::Some(ok_set), error_map) } } -/// Options to aid in the building of a CORS response during pre-flight or after. -/// See module level documentation for usage examples. -#[derive(Clone, Debug, Default)] +/// Configuration options to for building CORS preflight or actual responses. +/// +/// [`Default`](https://doc.rust-lang.org/std/default/trait.Default.html) is implemented for this +/// struct. The default for each field is described in the docuementation for the field. +#[derive(Clone, Debug)] pub struct Options { /// Origins that are allowed to make requests. /// Will be verified against the `Origin` request header. - pub allowed_origins: AllowedOrigins, - /// Methods that the clients are allowed to request in. - /// Will be verified against the `Access-Control-Request-Method` request header - /// during pre-flight only. + /// + /// When `All` is set, and `send_wildcard` is set, "*" will be sent in + /// the `Access-Control-Allow-Origin` response header. Otherwise, the client's `Origin` request + /// header will be echoed back in the `Access-Control-Allow-Origin` response header. + /// + /// When `Some` is set, the client's `Origin` request header will be checked in a + /// case-sensitive manner. + /// + /// This is the `list of origins` in the + /// [Resource Processing Model](https://www.w3.org/TR/cors/#resource-processing-model). + /// + /// + /// This field defaults to `All`. + /// # Examples + /// ## Allow all origins + /// ```json + /// { "allowed_origins": null } + /// + /// ## Allow specific origins + /// + /// ```json + /// { "allowed_origins": ["http://127.0.0.1:8000/","https://foobar.com/"] } + /// ``` + // #[serde(default)] + pub allowed_origins: AllOrSome>, + /// The list of methods which the allowed origins are allowed to access for + /// non-simple requests. + /// + /// This is the `list of methods` in the + /// [Resource Processing Model](https://www.w3.org/TR/cors/#resource-processing-model). + /// + /// Defaults to `[GET, HEAD, POST, OPTIONS, PUT, PATCH, DELETE]` + // #[serde(default = "Options::default_allowed_methods")] pub allowed_methods: HashSet, - /// Headers that the clients are allowed to request in. - /// Will be verified against the `Access-Control-Request-Headers` request header - /// during pre-flight only. - pub allowed_headers: HeaderFieldNamesSet, - /// The `Access-Control-Allow-Credentials` response header + /// The list of header field names which can be used when this resource is accessed by allowed + /// origins. + /// + /// If `All` is set, whatever is requested by the client in `Access-Control-Request-Headers` + /// will be echoed back in the `Access-Control-Allow-Headers` header. + /// + /// This is the `list of headers` in the + /// [Resource Processing Model](https://www.w3.org/TR/cors/#resource-processing-model). + /// + /// Defaults to `All`. + pub allowed_headers: AllOrSome>, + /// Allows users to make authenticated requests. + /// If true, injects the `Access-Control-Allow-Credentials` header in responses. + /// This allows cookies and credentials to be submitted across domains. + /// + /// This **CANNOT** be used in conjunction with `allowed_origins` set to `All` and + /// `send_wildcard` set to `true`. Depending on the mode of usage, this will either result + /// in an `Error::CredentialsWithWildcardOrigin` error during Rocket launch or runtime. + /// + /// Defaults to `false`. pub allow_credentials: bool, - /// The `Access-Control-Expose-Headers` responde header + /// The list of headers which are safe to expose to the API of a CORS API specification. + /// This corresponds to the `Access-Control-Expose-Headers` responde header. + /// + /// This is the `list of exposed headers` in the + /// [Resource Processing Model](https://www.w3.org/TR/cors/#resource-processing-model). + /// + /// This defaults to an empty set. pub expose_headers: HashSet, - /// The `Access-Control-Max-Age` response header + /// The maximum time for which this CORS request maybe cached. This value is set as the + /// `Access-Control-Max-Age` header. + /// + /// This defaults to `None` (unset). pub max_age: Option, + /// If true, and the `allowed_origins` parameter is `All`, a wildcard + /// `Access-Control-Allow-Origin` response header is sent, rather than the request’s + /// `Origin` header. + /// + /// This is the `supports credentials flag` in the + /// [Resource Processing Model](https://www.w3.org/TR/cors/#resource-processing-model). + /// + /// This **CANNOT** be used in conjunction with `allowed_origins` set to `All` and + /// `allow_credentials` set to `true`. Depending on the mode of usage, this will either result + /// in an `Error::CredentialsWithWildcardOrigin` error during Rocket launch or runtime. + /// + /// Defaults to `false`. + pub send_wildcard: bool, +} + +impl Default for Options { + fn default() -> Self { + Self { + allowed_origins: Default::default(), + allowed_methods: Self::default_allowed_methods(), + allowed_headers: Default::default(), + allow_credentials: Default::default(), + expose_headers: Default::default(), + max_age: Default::default(), + send_wildcard: Default::default(), + } + } } impl Options { + fn default_allowed_methods() -> HashSet { + vec![ + Method::Get, + Method::Head, + Method::Post, + Method::Options, + Method::Put, + Method::Patch, + Method::Delete, + ].into_iter() + .collect() + } + /// Construct a preflight response based on the options. Will return an `Err` - /// if any of the preflight checks - /// fail. - pub fn preflight( + /// if any of the preflight checks fail. + /// + /// This implementation references the + /// [W3C recommendation](https://www.w3.org/TR/cors/#resource-preflight-requests). + pub fn preflight<'r, R: Responder<'r>>( &self, + responder: R, origin: Option, - method: &AccessControlRequestMethod, - headers: Option<&AccessControlRequestHeaders>, - ) -> Result, Error> { + method: Option, + headers: Option, + ) -> Result, Error> { + let response = Response::new(responder); - match origin { - None => Err(Error::MissingOrigin), - Some(origin) => { - let response = Response::<()>::allowed_origin((), &origin, &self.allowed_origins)? - .allowed_methods(method, self.allowed_methods.clone())?; + // Note: All header parse failures are dealt with in the `FromRequest` trait implementation - match headers { - Some(headers) => { - self.append(response.allowed_headers(headers, &self.allowed_headers)) - } - None => Ok(response), - } + // 1. If the Origin header is not present terminate this set of steps. + // The request is outside the scope of this specification. + let origin = match origin { + None => { + // Not a CORS request + return Ok(response); } - } + Some(origin) => origin, + }; + + // 2. If the value of the Origin header is not a case-sensitive match for any of the values + // in list of origins do not set any additional headers and terminate this set of steps. + let response = response.allowed_origin( + &origin, + &self.allowed_origins, + self.send_wildcard, + )?; + + // 3. Let `method` be the value as result of parsing the Access-Control-Request-Method + // header. + // If there is no Access-Control-Request-Method header or if parsing failed, + // do not set any additional headers and terminate this set of steps. + // The request is outside the scope of this specification. + + let method = method.ok_or_else(|| Error::MissingRequestMethod)?; + + // 4. Let header field-names be the values as result of parsing the + // Access-Control-Request-Headers headers. + // If there are no Access-Control-Request-Headers headers + // let header field-names be the empty list. + // If parsing failed do not set any additional headers and terminate this set of steps. + // The request is outside the scope of this specification. + + // 5. If method is not a case-sensitive match for any of the values in list of methods + // do not set any additional headers and terminate this set of steps. + + let response = response.allowed_methods( + &method, + self.allowed_methods.clone(), + )?; + + // 6. If any of the header field-names is not a ASCII case-insensitive match for any of the + // values in list of headers do not set any additional headers and terminate this set of + // steps. + let response = if let Some(headers) = headers { + response.allowed_headers(&headers, &self.allowed_headers)? + } else { + response + }; + + // 7. If the resource supports credentials add a single Access-Control-Allow-Origin header, + // with the value of the Origin header as value, and add a + // single Access-Control-Allow-Credentials header with the case-sensitive string "true" as + // value. + // Otherwise, add a single Access-Control-Allow-Origin header, + // with either the value of the Origin header or the string "*" as value. + // Note: The string "*" cannot be used for a resource that supports credentials. + + let response = response.credentials(self.allow_credentials)?; + + // 8. Optionally add a single Access-Control-Max-Age header + // with as value the amount of seconds the user agent is allowed to cache the result of the + // request. + let response = response.max_age(self.max_age); + + // 9. If method is a simple method this step may be skipped. + // Add one or more Access-Control-Allow-Methods headers consisting of + // (a subset of) the list of methods. + // If a method is a simple method it does not need to be listed, but this is not prohibited. + // Since the list of methods can be unbounded, + // simply returning the method indicated by Access-Control-Request-Method + // (if supported) can be enough. + + // Done above + + // 10. If each of the header field-names is a simple header and none is Content-Type, + // this step may be skipped. + // Add one or more Access-Control-Allow-Headers headers consisting of (a subset of) + // the list of headers. + // If a header field name is a simple header and is not Content-Type, + // it is not required to be listed. Content-Type is to be listed as only a + // subset of its values makes it qualify as simple header. + // Since the list of headers can be unbounded, simply returning supported headers + // from Access-Control-Allow-Headers can be enough. + + // Done above -- we do not do anything special with simple headers + + Ok(response) } /// Respond to a request based on the settings. @@ -503,34 +603,55 @@ impl Options { responder: R, origin: Option, ) -> Result, Error> { - match origin { - None => Ok(Response::::any(responder)), - Some(origin) => { - self.append(Response::::allowed_origin( - responder, - &origin, - &self.allowed_origins, - )) - } - } - } + let response = Response::new(responder); - fn append<'r, R: Responder<'r>>( - &self, - response: Result, Error>, - ) -> Result, Error> { - Ok( - response? - .credentials(self.allow_credentials) - .exposed_headers( - self.expose_headers - .iter() - .map(|s| &**s) - .collect::>() - .as_slice(), - ) - .max_age(self.max_age), - ) + // Note: All header parse failures are dealt with in the `FromRequest` trait implementation + + // 1. If the Origin header is not present terminate this set of steps. + // The request is outside the scope of this specification. + let origin = match origin { + None => { + // Not a CORS request + return Ok(response); + } + Some(origin) => origin, + }; + + // 2. If the value of the Origin header is not a case-sensitive match for any of the values + // in list of origins, do not set any additional headers and terminate this set of steps. + // Always matching is acceptable since the list of origins can be unbounded. + + let response = response.allowed_origin( + &origin, + &self.allowed_origins, + self.send_wildcard, + )?; + + // 3. If the resource supports credentials add a single Access-Control-Allow-Origin header, + // with the value of the Origin header as value, and add a + // single Access-Control-Allow-Credentials header with the case-sensitive string "true" as + // value. + // Otherwise, add a single Access-Control-Allow-Origin header, + // with either the value of the Origin header or the string "*" as value. + // Note: The string "*" cannot be used for a resource that supports credentials. + + let response = response.credentials(self.allow_credentials)?; + + // 4. If the list of exposed headers is not empty add one or more + // Access-Control-Expose-Headers headers, with as values the header field names given in + // the list of exposed headers. + // By not adding the appropriate headers resource can also clear the preflight result cache + // of all entries where origin is a case-sensitive match for the value of the Origin header + // and url is a case-sensitive match for the URL of the resource. + + let response = response.exposed_headers( + self.expose_headers + .iter() + .map(|s| &**s) + .collect::>() + .as_slice(), + ); + Ok(response) } } @@ -539,40 +660,62 @@ impl Options { /// See module level documentation for usage examples. pub struct Response { responder: R, - allow_origin: String, + allow_origin: Option>, allow_methods: HashSet, allow_headers: HeaderFieldNamesSet, allow_credentials: bool, expose_headers: HeaderFieldNamesSet, max_age: Option, + vary_origin: bool, } impl<'r, R: Responder<'r>> Response { - /// Consumes the responder and origin and returns basic CORS - fn origin(responder: R, origin: &str) -> Self { + /// Consumes the responder and return an empty `Response` + fn new(responder: R) -> Self { Self { - allow_origin: origin.to_string(), + allow_origin: None, allow_headers: HashSet::new(), allow_methods: HashSet::new(), - responder: responder, + responder, allow_credentials: false, expose_headers: HashSet::new(), max_age: None, + vary_origin: false, } } + + /// Consumes the `Response` and return an altered response with origin and `vary_origin` set + fn origin(mut self, origin: &str, vary_origin: bool) -> Self { + self.allow_origin = Some(AllOrSome::Some(origin.to_string())); + self.vary_origin = vary_origin; + self + } + + /// Consumes the `Response` and return an altered response with origin set to "*" + fn any(self) -> Self { + self.origin("*", false) + } + /// Consumes the responder and based on the provided list of allowed origins, /// check if the requested origin is allowed. /// Useful for pre-flight and during requests - pub fn allowed_origin( - responder: R, + fn allowed_origin( + self, origin: &Origin, - allowed_origins: &AllowedOrigins, + allowed_origins: &AllOrSome>, + send_wildcard: bool, ) -> Result { + let origin = origin.origin().unicode_serialization(); match *allowed_origins { - AllowedOrigins::All => Ok(Self::any(responder)), - AllowedOrigins::Some(ref allowed_origins) => { - let origin = origin.origin().unicode_serialization(); - + // Always matching is acceptable since the list of origins can be unbounded. + AllOrSome::All => { + if send_wildcard { + Ok(self.any()) + } else { + Ok(self.origin(&origin, true)) + } + } + AllOrSome::Some(ref allowed_origins) => { let allowed_origins: HashSet<_> = allowed_origins .iter() .map(|o| o.origin().unicode_serialization()) @@ -580,33 +723,33 @@ impl<'r, R: Responder<'r>> Response { let _ = allowed_origins.get(&origin).ok_or_else( || Error::OriginNotAllowed, )?; - Ok(Self::origin(responder, &origin)) + Ok(self.origin(&origin, false)) } } } - /// Consumes responder and returns CORS with any origin - pub fn any(responder: R) -> Self { - Self::origin(responder, "*") - } + /// Consumes the Response and validate whether credentials can be allowed + fn credentials(mut self, value: bool) -> Result { + if value { + if let Some(AllOrSome::All) = self.allow_origin { + Err(Error::CredentialsWithWildcardOrigin)?; + } + } - /// Consumes the CORS, set allow_credentials to - /// new value and returns changed CORS - pub fn credentials(mut self, value: bool) -> Self { self.allow_credentials = value; - self + Ok(self) } /// Consumes the CORS, set expose_headers to /// passed headers and returns changed CORS - pub fn exposed_headers(mut self, headers: &[&str]) -> Self { + fn exposed_headers(mut self, headers: &[&str]) -> Self { self.expose_headers = headers.into_iter().map(|s| s.to_string().into()).collect(); self } /// Consumes the CORS, set max_age to /// passed value and returns changed CORS - pub fn max_age(mut self, value: Option) -> Self { + fn max_age(mut self, value: Option) -> Self { self.max_age = value; self } @@ -620,7 +763,7 @@ impl<'r, R: Responder<'r>> Response { /// Consumes the CORS, check if requested method is allowed. /// Useful for pre-flight checks - pub fn allowed_methods( + fn allowed_methods( self, method: &AccessControlRequestMethod, allowed_methods: HashSet, @@ -629,6 +772,8 @@ impl<'r, R: Responder<'r>> Response { if !allowed_methods.iter().any(|m| m == request_method) { Err(Error::MethodNotAllowed)? } + + // TODO: Subset to route? Or just the method requested for? Ok(self.methods(allowed_methods)) } @@ -639,20 +784,27 @@ impl<'r, R: Responder<'r>> Response { self } - /// Consumes the CORS, check if requested headersa are allowed. + /// Consumes the CORS, check if requested headers are allowed. /// Useful for pre-flight checks - pub fn allowed_headers( + fn allowed_headers( self, headers: &AccessControlRequestHeaders, - allowed_headers: &HeaderFieldNamesSet, + allowed_headers: &AllOrSome>, ) -> Result { let &AccessControlRequestHeaders(ref headers) = headers; - if !headers.is_empty() && !headers.is_subset(allowed_headers) { - Err(Error::HeadersNotAllowed)? - } + + match *allowed_headers { + AllOrSome::All => {} + AllOrSome::Some(ref allowed_headers) => { + if !headers.is_empty() && !headers.is_subset(allowed_headers) { + Err(Error::HeadersNotAllowed)? + } + } + }; + Ok( self.headers( - allowed_headers + headers .iter() .map(|s| &**s.deref()) .collect::>() @@ -663,15 +815,29 @@ impl<'r, R: Responder<'r>> Response { } impl<'r, R: Responder<'r>> Responder<'r> for Response { + #[allow(unused_results)] fn respond_to(self, request: &Request) -> response::Result<'r> { - let mut response = response::Response::build_from(self.responder.respond_to(request)?) - .raw_header("Access-Control-Allow-Origin", self.allow_origin) - .finalize(); + use std::borrow::Cow; + + let mut builder = response::Response::build_from(self.responder.respond_to(request)?); + + let origin = match self.allow_origin { + None => { + // This is not a CORS response + return Ok(builder.finalize()); + } + Some(origin) => origin, + }; + + let origin: Cow = match origin { + AllOrSome::All => Into::into("*"), + AllOrSome::Some(origin) => Into::into(origin), + }; + + builder.raw_header("Access-Control-Allow-Origin", origin); if self.allow_credentials { - response.set_raw_header("Access-Control-Allow-Credentials", "true"); - } else { - response.set_raw_header("Access-Control-Allow-Credentials", "false"); + builder.raw_header("Access-Control-Allow-Credentials", "true"); } if !self.expose_headers.is_empty() { @@ -681,7 +847,7 @@ impl<'r, R: Responder<'r>> Responder<'r> for Response { .collect(); let headers = headers.join(", "); - response.set_raw_header("Access-Control-Expose-Headers", headers); + builder.raw_header("Access-Control-Expose-Headers", headers); } if !self.allow_headers.is_empty() { @@ -691,7 +857,7 @@ impl<'r, R: Responder<'r>> Responder<'r> for Response { .collect(); let headers = headers.join(", "); - response.set_raw_header("Access-Control-Allow-Headers", headers); + builder.raw_header("Access-Control-Allow-Headers", headers); } @@ -699,15 +865,19 @@ impl<'r, R: Responder<'r>> Responder<'r> for Response { let methods: Vec<_> = self.allow_methods.into_iter().map(|m| m.as_str()).collect(); let methods = methods.join(", "); - response.set_raw_header("Access-Control-Allow-Methods", methods); + builder.raw_header("Access-Control-Allow-Methods", methods); } if self.max_age.is_some() { let max_age = self.max_age.unwrap(); - response.set_raw_header("Access-Control-Max-Age", max_age.to_string()); + builder.raw_header("Access-Control-Max-Age", max_age.to_string()); } - Ok(response) + if self.vary_origin { + builder.raw_header("Vary", "Origin"); + } + + Ok(builder.finalize()) } } @@ -813,7 +983,7 @@ X-Ping, accept-language"#; #[get("/any")] #[cfg_attr(feature = "clippy_lints", allow(needless_pass_by_value))] fn any() -> Response<&'static str> { - Response::any("Hello, world!") + Response::new("Hello, world!").any() } #[test] @@ -838,11 +1008,11 @@ X-Ping, accept-language"#; #[allow(needless_pass_by_value)] fn cors_options( origin: Option, - method: AccessControlRequestMethod, - headers: AccessControlRequestHeaders, + method: Option, + headers: Option, options: State, ) -> Result, Error> { - options.preflight(origin, &method, Some(&headers)) + options.preflight((), origin, method, headers) } #[get("/")] @@ -856,16 +1026,18 @@ X-Ping, accept-language"#; fn make_cors_options() -> Options { let (allowed_origins, failed_origins) = - AllowedOrigins::new_from_str_list(&["https://www.acme.com"]); + AllOrSome::new_from_str_list(&["https://www.acme.com"]); assert!(failed_origins.is_empty()); Options { allowed_origins: allowed_origins, allowed_methods: [Method::Get].iter().cloned().collect(), - allowed_headers: ["Authorization"] - .iter() - .map(|s| s.to_string().into()) - .collect(), + allowed_headers: AllOrSome::Some( + ["Authorization"] + .into_iter() + .map(|s| s.to_string().into()) + .collect(), + ), allow_credentials: true, ..Default::default() } @@ -980,7 +1152,7 @@ X-Ping, accept-language"#; ); let response = req.dispatch(); - assert_eq!(response.status(), Status::Forbidden); + assert_eq!(response.status(), Status::Ok); } #[test]