diff --git a/src/lib.rs b/src/lib.rs index 3e48528..d359bf5 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -124,6 +124,7 @@ use std::ops::Deref; use std::str::FromStr; use rocket::{Outcome, State}; +use rocket::fairing; use rocket::http::{Method, Status}; use rocket::request::{Request, FromRequest}; use rocket::response; @@ -160,6 +161,22 @@ pub enum Error { /// /// This is a misconfiguration. Check the docuemntation for `Cors`. CredentialsWithWildcardOrigin, + /// A CORS Request Guard was used, but no CORS Options was available in Rocket's state + /// + /// This is a misconfiguration. Use `Rocket::manage` to add a CORS options to managed state. + MissingCorsInRocketState, +} + +impl Error { + fn status(&self) -> Status { + match *self { + Error::MissingOrigin | Error::OriginNotAllowed | Error::MethodNotAllowed | + Error::HeadersNotAllowed => Status::Forbidden, + Error::CredentialsWithWildcardOrigin | + Error::MissingCorsInRocketState => Status::InternalServerError, + _ => Status::BadRequest, + } + } } impl error::Error for Error { @@ -185,7 +202,9 @@ impl error::Error for Error { "Credentials are allowed, but the Origin is set to \"*\". \ This is not allowed by W3C" } - + Error::MissingCorsInRocketState => { + "A CORS Request Guard was used, but no CORS Options was available in Rocket's state" + } } } @@ -209,20 +228,15 @@ impl fmt::Display for Error { impl<'r> response::Responder<'r> for Error { fn respond_to(self, _: &Request) -> Result, Status> { - error_!("CORS Error: {:?}", self); - Err(match self { - Error::MissingOrigin | Error::OriginNotAllowed | Error::MethodNotAllowed | - Error::HeadersNotAllowed => Status::Forbidden, - Error::CredentialsWithWildcardOrigin => Status::InternalServerError, - _ => Status::BadRequest, - }) + error_!("CORS Error: {}", self); + Err(self.status()) } } /// An enum signifying that some of type T is allowed, or `All` (everything is allowed). /// /// `Default` is implemented for this enum and is `All`. -#[derive(Clone, Debug, Serialize, Deserialize)] +#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)] #[serde(untagged)] pub enum AllOrSome { /// Everything is allowed. Usually equivalent to the "*" value. @@ -237,6 +251,21 @@ impl Default for AllOrSome { } } +impl AllOrSome { + /// Returns whether this is an `All` variant + pub fn is_all(&self) -> bool { + match *self { + AllOrSome::All => true, + AllOrSome::Some(_) => false, + } + } + + /// Returns whether this is a `Some` variant + pub fn is_some(&self) -> bool { + !self.is_all() + } +} + impl AllOrSome> { /// New `AllOrSome` from a list of URL strings. /// Returns a tuple where the first element is the struct `AllOrSome`, @@ -258,9 +287,9 @@ impl AllOrSome> { } } -/// Responder generator and [Fairing](https://rocket.rs/guide/fairings/) for CORS +/// Response generator and [Fairing](https://rocket.rs/guide/fairings/) for CORS /// -/// This struct can be used as Fairing for Rocket, or as an ad-hoc responder for any CORS requests. +/// This struct can be as Fairing or in an ad-hoc manner to generate CORS response. /// /// You create a new copy of this struct by defining the configurations in the fields below. /// This struct can also be deserialized by serde. @@ -348,6 +377,13 @@ pub struct Cors { /// Defaults to `false`. // #[serde(default)] pub send_wildcard: bool, + /// When used as Fairing, Cors will need to redirect failed CORS checks to a custom route to + /// be mounted by the fairing. Specify the base the route so that it doesn't clash with any + /// of your existing routes. + /// + /// Defaults to "/cors" + // #[serde(default = "Cors::default_fairing_route_base")] + pub fairing_route_base: String, } impl Default for Cors { @@ -360,34 +396,12 @@ impl Default for Cors { expose_headers: Default::default(), max_age: Default::default(), send_wildcard: Default::default(), + fairing_route_base: Self::default_fairing_route_base(), } } } -/// Ad-hoc per route CORS response to requests -/// -/// Note: If you use this, the lifetime parameter `'r` of your `rocket:::response::Responder<'r>` -/// CANNOT be `'static`. This is because the code generated by Rocket will implicitly try to -/// to restrain the `Request` object passed to the route to `&'static Request`, and it is not -/// possible to have such a reference. -/// See [this PR on Rocket](https://github.com/SergioBenitez/Rocket/pull/345). -pub fn respond<'a, 'r: 'a, R: response::Responder<'r>>( - options: State<'a, Cors>, - responder: R, -) -> Responder<'a, 'r, R> { - options.inner().respond(responder) -} - impl Cors { - /// Wrap any `Rocket::Response` and respond with CORS headers. - /// This is only used for ad-hoc route CORS response - fn respond<'a, 'r: 'a, R: response::Responder<'r>>( - &'a self, - responder: R, - ) -> Responder<'a, 'r, R> { - Responder::new(responder, self) - } - fn default_allowed_methods() -> HashSet { vec![ Method::Get, @@ -400,264 +414,115 @@ impl Cors { ].into_iter() .collect() } -} -/// A CORS [Responder](https://rocket.rs/guide/responses/#responder) -/// which will inspect the incoming requests and respond accordingly. -/// -/// If the wrapped `Responder` already has the `Access-Control-Allow-Origin` header set, -/// this responder will leave the response untouched. -/// This allows for chaining of several CORS responders. -/// -/// Otherwise, the following headers may be set for the final Rocket `Response`, overwriting any -/// existing headers defined: -/// -/// - `Access-Control-Allow-Origin` -/// - `Access-Control-Expose-Headers` -/// - `Access-Control-Max-Age` -/// - `Access-Control-Allow-Credentials` -/// - `Access-Control-Allow-Methods` -/// - `Access-Control-Allow-Headers` -/// - `Vary` -#[derive(Debug)] -pub struct Responder<'a, 'r: 'a, R> { - responder: R, - options: &'a Cors, - marker: PhantomData>, -} - -impl<'a, 'r: 'a, R: response::Responder<'r>> Responder<'a, 'r, R> { - fn new(responder: R, options: &'a Cors) -> Self { - Self { - responder, - options, - marker: PhantomData, - } + fn default_fairing_route_base() -> String { + "/cors".to_string() } - /// Respond to a request - fn respond(self, request: &Request) -> response::Result<'r> { - match self.build_cors_response(request) { - Ok(response) => response, - Err(e) => response::Responder::respond_to(e, request), - } - } - - /// Build a CORS response and merge with an existing `rocket::Response` for the request - fn build_cors_response(self, request: &Request) -> Result, Error> { - let original_response = match self.responder.respond_to(request) { - Ok(response) => response, - Err(status) => return Ok(Err(status)), // TODO: Handle this? - }; - - // Existing CORS response? - if Self::has_allow_origin(&original_response) { - return Ok(Ok(original_response)); - } - - // 1. If the Origin header is not present terminate this set of steps. - // The request is outside the scope of this specification. - let origin = Self::origin(request)?; - let origin = match origin { - None => { - // Not a CORS request - return Ok(Ok(original_response)); - } - Some(origin) => origin, - }; - - // Check if the request verb is an OPTION or something else - let cors_response = match request.method() { - Method::Options => { - let method = Self::request_method(request)?; - let headers = Self::request_headers(request)?; - Self::preflight(&self.options, origin, method, headers) - } - _ => Self::actual_request(&self.options, origin), - }?; - - Ok(Ok(cors_response.build(original_response))) - } - - /// Gets the `Origin` request header from the request - fn origin(request: &Request) -> Result, Error> { - match Origin::from_request(request) { - Outcome::Forward(()) => Ok(None), - Outcome::Success(origin) => Ok(Some(origin)), - Outcome::Failure((_, err)) => Err(err), - } - } - - /// Gets the `Access-Control-Request-Method` request header from the request - fn request_method(request: &Request) -> Result, Error> { - match AccessControlRequestMethod::from_request(request) { - Outcome::Forward(()) => Ok(None), - Outcome::Success(method) => Ok(Some(method)), - Outcome::Failure((_, err)) => Err(err), - } - } - - /// Gets the `Access-Control-Request-Headers` request header from the request - fn request_headers(request: &Request) -> Result, Error> { - match AccessControlRequestHeaders::from_request(request) { - Outcome::Forward(()) => Ok(None), - Outcome::Success(geaders) => Ok(Some(geaders)), - Outcome::Failure((_, err)) => Err(err), - } - } - - /// Checks if an existing Response already has the header `Access-Control-Allow-Origin` - fn has_allow_origin(response: &response::Response<'r>) -> bool { - response.headers().get("Access-Control-Allow-Origin").next() != None - } - - /// Construct a preflight response based on the options. Will return an `Err` - /// if any of the preflight checks fail. + /// Build a CORS `Guard` to an incoming request. /// - /// This implementation references the - /// [W3C recommendation](https://www.w3.org/TR/cors/#resource-preflight-requests). - fn preflight( - options: &Cors, - origin: Origin, - method: Option, - headers: Option, - ) -> Result { + /// You will usually not have to use this function but simply place a route argument for the + /// `Guard` type. This is useful if you want an even more ad-hoc based approach to respond to + /// CORS by using a `Cors` that is not in Rocket's managed state. + pub fn guard<'a, 'r>(&'a self, request: &'a Request<'r>) -> Result, Error> { + let response = build_cors_response(self, request)?; + Ok(Guard::new(response)) + } - let response = Response::new(); + /// Validates if any of the settings are disallowed or incorrect + /// + /// This is run during initial Fairing attachment + pub fn validate(&self) -> Result<(), Error> { + if self.allowed_origins.is_all() && self.send_wildcard && self.allow_credentials { + Err(Error::CredentialsWithWildcardOrigin)?; + } - // Note: All header parse failures are dealt with in the `FromRequest` trait implementation + Ok(()) + } - // 2. If the value of the Origin header is not a case-sensitive match for any of the values - // in list of origins do not set any additional headers and terminate this set of steps. - let response = response.allowed_origin( - &origin, - &options.allowed_origins, - options.send_wildcard, - )?; + /// Create a new `Route` for Fairing handling + fn fairing_route(&self) -> rocket::Route { + rocket::Route::new(Method::Get, "/", fairing_error_route) + } - // 3. Let `method` be the value as result of parsing the Access-Control-Request-Method - // header. - // If there is no Access-Control-Request-Method header or if parsing failed, - // do not set any additional headers and terminate this set of steps. - // The request is outside the scope of this specification. + /// Modifies a `Request` to route to Fairing error handler + fn route_to_fairing_error_handler(&self, status: u16, request: &mut Request) { + request.set_method(Method::Get); + request.set_uri(format!("{}/{}", self.fairing_route_base, status)); + } +} - let method = method.ok_or_else(|| Error::MissingRequestMethod)?; +impl fairing::Fairing for Cors { + fn info(&self) -> fairing::Info { + fairing::Info { + name: "CORS", + kind: fairing::Kind::Attach | fairing::Kind::Request | fairing::Kind::Response, + } + } - // 4. Let header field-names be the values as result of parsing the - // Access-Control-Request-Headers headers. - // If there are no Access-Control-Request-Headers headers - // let header field-names be the empty list. - // If parsing failed do not set any additional headers and terminate this set of steps. - // The request is outside the scope of this specification. + fn on_attach(&self, rocket: rocket::Rocket) -> Result { + match self.validate() { + Ok(()) => { + Ok(rocket.mount(&self.fairing_route_base, vec![self.fairing_route()])) + } + Err(e) => { + error_!("Error attaching CORS fairing: {}", e); + Err(rocket) + } + } + } - // 5. If method is not a case-sensitive match for any of the values in list of methods - // do not set any additional headers and terminate this set of steps. + fn on_request(&self, request: &mut Request, _: &rocket::Data) { + // Build and merge CORS response + match build_cors_response(self, request) { + Err(err) => { + error_!("CORS Error: {}", err); + let status = err.status(); + self.route_to_fairing_error_handler(status.code, request); + } + Ok(cors_response) => { + // TODO: How to pass response downstream? + let _ = cors_response; + } + }; + } - let response = response.allowed_methods(&method, &options.allowed_methods)?; + fn on_response(&self, request: &Request, response: &mut rocket::Response) { + // Build and merge CORS response + match build_cors_response(self, request) { + Err(_) => { + // We have dealt with this already + } + Ok(cors_response) => { + cors_response.merge(response); - // 6. If any of the header field-names is not a ASCII case-insensitive match for any of the - // values in list of headers do not set any additional headers and terminate this set of - // steps. - let response = if let Some(headers) = headers { - response.allowed_headers(&headers, &options.allowed_headers)? - } else { - response + // If this was an OPTIONS request and no route can be found, we should turn this + // into a HTTP 204 with no content body. + // This allows the user to not have to specify an OPTIONS route for everything. + // + // TODO: Is there anyway we can make this smarter? Only modify status codes for + // requests where an actual route exist? + if request.method() == Method::Options && request.route().is_none() { + response.set_status(Status::NoContent); + let _ = response.take_body(); + } + } }; - // 7. If the resource supports credentials add a single Access-Control-Allow-Origin header, - // with the value of the Origin header as value, and add a - // single Access-Control-Allow-Credentials header with the case-sensitive string "true" as - // value. - // Otherwise, add a single Access-Control-Allow-Origin header, - // with either the value of the Origin header or the string "*" as value. - // Note: The string "*" cannot be used for a resource that supports credentials. - let response = response.credentials(options.allow_credentials)?; - - // 8. Optionally add a single Access-Control-Max-Age header - // with as value the amount of seconds the user agent is allowed to cache the result of the - // request. - let response = response.max_age(options.max_age); - - // 9. If method is a simple method this step may be skipped. - // Add one or more Access-Control-Allow-Methods headers consisting of - // (a subset of) the list of methods. - // If a method is a simple method it does not need to be listed, but this is not prohibited. - // Since the list of methods can be unbounded, - // simply returning the method indicated by Access-Control-Request-Method - // (if supported) can be enough. - - // Done above - - // 10. If each of the header field-names is a simple header and none is Content-Type, - // this step may be skipped. - // Add one or more Access-Control-Allow-Headers headers consisting of (a subset of) - // the list of headers. - // If a header field name is a simple header and is not Content-Type, - // it is not required to be listed. Content-Type is to be listed as only a - // subset of its values makes it qualify as simple header. - // Since the list of headers can be unbounded, simply returning supported headers - // from Access-Control-Allow-Headers can be enough. - - // Done above -- we do not do anything special with simple headers - - Ok(response) - } - - /// Respond to an actual request based on the settings. - /// If the `Origin` is not provided, then this request was not made by a browser and there is no - /// CORS enforcement. - fn actual_request(options: &Cors, origin: Origin) -> Result { - let response = Response::new(); - - // Note: All header parse failures are dealt with in the `FromRequest` trait implementation - - // 2. If the value of the Origin header is not a case-sensitive match for any of the values - // in list of origins, do not set any additional headers and terminate this set of steps. - // Always matching is acceptable since the list of origins can be unbounded. - - let response = response.allowed_origin( - &origin, - &options.allowed_origins, - options.send_wildcard, - )?; - - // 3. If the resource supports credentials add a single Access-Control-Allow-Origin header, - // with the value of the Origin header as value, and add a - // single Access-Control-Allow-Credentials header with the case-sensitive string "true" as - // value. - // Otherwise, add a single Access-Control-Allow-Origin header, - // with either the value of the Origin header or the string "*" as value. - // Note: The string "*" cannot be used for a resource that supports credentials. - - let response = response.credentials(options.allow_credentials)?; - - // 4. If the list of exposed headers is not empty add one or more - // Access-Control-Expose-Headers headers, with as values the header field names given in - // the list of exposed headers. - // By not adding the appropriate headers resource can also clear the preflight result cache - // of all entries where origin is a case-sensitive match for the value of the Origin header - // and url is a case-sensitive match for the URL of the resource. - - let response = response.exposed_headers( - options - .expose_headers - .iter() - .map(|s| &**s) - .collect::>() - .as_slice(), - ); - Ok(response) } } -impl<'a, 'r: 'a, R: response::Responder<'r>> response::Responder<'r> for Responder<'a, 'r, R> { - fn respond_to(self, request: &Request) -> response::Result<'r> { - self.respond(request) - } +/// Route for Fairing error handling +fn fairing_error_route<'r>(request: &'r Request, _: rocket::Data) -> rocket::handler::Outcome<'r> { + let status = request.get_param::(0).unwrap_or_else(|e| { + error_!("Fairing Error Handling Route error: {:?}", e); + 500 + }); + let status = Status::from_code(status).unwrap_or_else(|| Status::InternalServerError); + Outcome::Failure(status) } - /// A CORS Response which provides the following CORS headers: /// /// - `Access-Control-Allow-Origin` @@ -667,7 +532,7 @@ impl<'a, 'r: 'a, R: response::Responder<'r>> response::Responder<'r> for Respond /// - `Access-Control-Allow-Methods` /// - `Access-Control-Allow-Headers` /// - `Vary` -#[derive(Debug)] +#[derive(Eq, PartialEq, Debug)] struct Response { allow_origin: Option>, allow_methods: HashSet, @@ -679,7 +544,7 @@ struct Response { } impl Response { - /// Consumes the responder and return an empty `Response` + /// Create an empty `Response` fn new() -> Self { Self { allow_origin: None, @@ -705,48 +570,10 @@ impl Response { self } - /// Consumes the responder and based on the provided list of allowed origins, - /// check if the requested origin is allowed. - /// Useful for pre-flight and during requests - fn allowed_origin( - self, - origin: &Origin, - allowed_origins: &AllOrSome>, - send_wildcard: bool, - ) -> Result { - let origin = origin.origin().unicode_serialization(); - match *allowed_origins { - // Always matching is acceptable since the list of origins can be unbounded. - AllOrSome::All => { - if send_wildcard { - Ok(self.any()) - } else { - Ok(self.origin(&origin, true)) - } - } - AllOrSome::Some(ref allowed_origins) => { - let allowed_origins: HashSet<_> = allowed_origins - .iter() - .map(|o| o.origin().unicode_serialization()) - .collect(); - let _ = allowed_origins.get(&origin).ok_or_else( - || Error::OriginNotAllowed, - )?; - Ok(self.origin(&origin, false)) - } - } - } - - /// Consumes the Response and validate whether credentials can be allowed - fn credentials(mut self, value: bool) -> Result { - if value { - if let Some(AllOrSome::All) = self.allow_origin { - Err(Error::CredentialsWithWildcardOrigin)?; - } - } - + /// Consumes the Response and set credentials + fn credentials(mut self, value: bool) -> Self { self.allow_credentials = value; - Ok(self) + self } /// Consumes the CORS, set expose_headers to @@ -770,22 +597,6 @@ impl Response { self } - /// Consumes the CORS, check if requested method is allowed. - /// Useful for pre-flight checks - fn allowed_methods( - self, - method: &AccessControlRequestMethod, - allowed_methods: &HashSet, - ) -> Result { - let &AccessControlRequestMethod(ref request_method) = method; - if !allowed_methods.iter().any(|m| m == request_method) { - Err(Error::MethodNotAllowed)? - } - - // TODO: Subset to route? Or just the method requested for? - Ok(self.methods(&allowed_methods)) - } - /// Consumes the CORS, set allow_headers to /// passed headers and returns changed CORS fn headers(mut self, headers: &[&str]) -> Self { @@ -793,47 +604,31 @@ impl Response { self } - /// Consumes the CORS, check if requested headers are allowed. - /// Useful for pre-flight checks - fn allowed_headers( - self, - headers: &AccessControlRequestHeaders, - allowed_headers: &AllOrSome>, - ) -> Result { - let &AccessControlRequestHeaders(ref headers) = headers; - - match *allowed_headers { - AllOrSome::All => {} - AllOrSome::Some(ref allowed_headers) => { - if !headers.is_empty() && !headers.is_subset(allowed_headers) { - Err(Error::HeadersNotAllowed)? - } - } - }; - - Ok( - self.headers( - headers - .iter() - .map(|s| &**s.deref()) - .collect::>() - .as_slice(), - ), - ) + /// Consumes the `Response` and return a `Responder` that wraps a + /// provided `rocket:response::Responder` with CORS headers + pub fn responder<'r, R: response::Responder<'r>>(self, responder: R) -> Responder<'r, R> { + Responder::new(responder, self) } - /// Builds a `rocket::Response` from this struct based off some base `rocket::Response` + /// Merge a `rocket::Response` with this CORS response. This is usually used in the final step + /// of a route to return a value for the route. /// /// This will overwrite any existing CORS headers - #[allow(unused_results)] - fn build<'r>(&self, base: response::Response<'r>) -> response::Response<'r> { + pub fn response<'r>(&self, base: response::Response<'r>) -> response::Response<'r> { let mut response = response::Response::build_from(base).finalize(); + self.merge(&mut response); + response + } + /// Merge CORS headers with an existing `rocket::Response`. + /// + /// This will overwrite any existing CORS headers + fn merge(&self, response: &mut response::Response) { // TODO: We should be able to remove this let origin = match self.allow_origin { None => { // This is not a CORS response - return response; + return; } Some(ref origin) => origin, }; @@ -896,9 +691,383 @@ impl Response { } else { response.remove_header("Vary"); } - - response } + + /// Validate and create a new CORS Response from a request and settings + pub fn build_cors_response<'a, 'r>( + options: &'a Cors, + request: &'a Request<'r>, + ) -> Result { + build_cors_response(options, request) + } +} + + +/// A [request guard](https://rocket.rs/guide/requests/#request-guards) to check CORS headers +/// before a route is run. Will not execute the route if checks fail +/// +// In essence, this is just a wrapper around `Response` with a `'r` borrowed lifetime so users +// don't have to keep specifying the lifetimes in their routes +pub struct Guard<'r> { + response: Response, + marker: PhantomData<&'r Response>, +} + +impl<'r> Guard<'r> { + fn new(response: Response) -> Self { + Self { + response, + marker: PhantomData, + } + } + + /// Consumes the Guard and return a `Responder` that wraps a + /// provided `rocket:response::Responder` with CORS headers + pub fn responder>(self, responder: R) -> Responder<'r, R> { + self.response.responder(responder) + } + + /// Merge a `rocket::Response` with this CORS Guard. This is usually used in the final step + /// of a route to return a value for the route. + /// + /// This will overwrite any existing CORS headers + pub fn response(&self, base: response::Response<'r>) -> response::Response<'r> { + self.response.response(base) + } +} + +impl<'a, 'r> FromRequest<'a, 'r> for Guard<'r> { + type Error = Error; + + fn from_request(request: &'a Request<'r>) -> rocket::request::Outcome { + let options = match request.guard::>() { + Outcome::Success(options) => options, + _ => { + let error = Error::MissingCorsInRocketState; + return Outcome::Failure((error.status(), error)); + } + }; + + match Response::build_cors_response(&options, request) { + Ok(response) => Outcome::Success(Self::new(response)), + Err(error) => Outcome::Failure((error.status(), error)), + } + } +} + +/// A [`Responder`](https://rocket.rs/guide/responses/#responder) which will simply wraps another +/// `Responder` with CORS headers. +/// +/// The following CORS headers will be overwritten: +/// +/// - `Access-Control-Allow-Origin` +/// - `Access-Control-Expose-Headers` +/// - `Access-Control-Max-Age` +/// - `Access-Control-Allow-Credentials` +/// - `Access-Control-Allow-Methods` +/// - `Access-Control-Allow-Headers` +/// - `Vary` +#[derive(Debug)] +pub struct Responder<'r, R> { + responder: R, + cors_response: Response, + marker: PhantomData>, +} + +impl<'r, R: response::Responder<'r>> Responder<'r, R> { + fn new(responder: R, cors_response: Response) -> Self { + Self { + responder, + cors_response, + marker: PhantomData, + } + } + + /// Respond to a request + fn respond(self, request: &Request) -> response::Result<'r> { + let mut response = self.responder.respond_to(request)?; // handle status errors? + self.cors_response.merge(&mut response); + Ok(response) + } +} + +impl<'r, R: response::Responder<'r>> response::Responder<'r> for Responder<'r, R> { + fn respond_to(self, request: &Request) -> response::Result<'r> { + self.respond(request) + } +} + +/// Validates a request for CORS and returns a CORS Response +fn build_cors_response(options: &Cors, request: &Request) -> Result { + // Existing CORS response? + // if has_allow_origin(response) { + // return Ok(()); + // } + + // 1. If the Origin header is not present terminate this set of steps. + // The request is outside the scope of this specification. + let origin = origin(request)?; + let origin = match origin { + None => { + // Not a CORS request + return Ok(Response::new()); + } + Some(origin) => origin, + }; + + // Check if the request verb is an OPTION or something else + let cors_response = match request.method() { + Method::Options => { + let method = request_method(request)?; + let headers = request_headers(request)?; + preflight(options, origin, method, headers) + } + _ => actual_request(options, origin), + }?; + + Ok(cors_response) +} + +/// Consumes the responder and based on the provided list of allowed origins, +/// check if the requested origin is allowed. +/// Useful for pre-flight and during requests +fn validate_origin( + origin: &Origin, + allowed_origins: &AllOrSome>, +) -> Result<(), Error> { + match *allowed_origins { + // Always matching is acceptable since the list of origins can be unbounded. + AllOrSome::All => Ok(()), + AllOrSome::Some(ref allowed_origins) => { + allowed_origins + .get(origin) + .and_then(|_| Some(())) + .ok_or_else(|| Error::OriginNotAllowed) + } + } +} + +/// Validate allowed methods +fn validate_allowed_method( + method: &AccessControlRequestMethod, + allowed_methods: &HashSet, +) -> Result<(), Error> { + let &AccessControlRequestMethod(ref request_method) = method; + if !allowed_methods.iter().any(|m| m == request_method) { + Err(Error::MethodNotAllowed)? + } + + // TODO: Subset to route? Or just the method requested for? + Ok(()) +} + +/// Validate allowed headers +fn validate_allowed_headers( + headers: &AccessControlRequestHeaders, + allowed_headers: &AllOrSome>, +) -> Result<(), Error> { + let &AccessControlRequestHeaders(ref headers) = headers; + + match *allowed_headers { + AllOrSome::All => Ok(()), + AllOrSome::Some(ref allowed_headers) => { + if !headers.is_empty() && !headers.is_subset(allowed_headers) { + Err(Error::HeadersNotAllowed)? + } + Ok(()) + } + } +} + +/// Gets the `Origin` request header from the request +fn origin(request: &Request) -> Result, Error> { + match Origin::from_request(request) { + Outcome::Forward(()) => Ok(None), + Outcome::Success(origin) => Ok(Some(origin)), + Outcome::Failure((_, err)) => Err(err), + } +} + +/// Gets the `Access-Control-Request-Method` request header from the request +fn request_method(request: &Request) -> Result, Error> { + match AccessControlRequestMethod::from_request(request) { + Outcome::Forward(()) => Ok(None), + Outcome::Success(method) => Ok(Some(method)), + Outcome::Failure((_, err)) => Err(err), + } +} + +/// Gets the `Access-Control-Request-Headers` request header from the request +fn request_headers(request: &Request) -> Result, Error> { + match AccessControlRequestHeaders::from_request(request) { + Outcome::Forward(()) => Ok(None), + Outcome::Success(geaders) => Ok(Some(geaders)), + Outcome::Failure((_, err)) => Err(err), + } +} + +/// Construct a preflight response based on the options. Will return an `Err` +/// if any of the preflight checks fail. +/// +/// This implementation references the +/// [W3C recommendation](https://www.w3.org/TR/cors/#resource-preflight-requests). +fn preflight( + options: &Cors, + origin: Origin, + method: Option, + headers: Option, +) -> Result { + options.validate()?; + let response = Response::new(); + + // Note: All header parse failures are dealt with in the `FromRequest` trait implementation + + // 2. If the value of the Origin header is not a case-sensitive match for any of the values + // in list of origins do not set any additional headers and terminate this set of steps. + validate_origin(&origin, &options.allowed_origins)?; + let response = match options.allowed_origins { + AllOrSome::All => { + if options.send_wildcard { + response.any() + } else { + response.origin(origin.as_str(), true) + } + } + AllOrSome::Some(_) => response.origin(origin.as_str(), false), + }; + + // 3. Let `method` be the value as result of parsing the Access-Control-Request-Method + // header. + // If there is no Access-Control-Request-Method header or if parsing failed, + // do not set any additional headers and terminate this set of steps. + // The request is outside the scope of this specification. + + let method = method.ok_or_else(|| Error::MissingRequestMethod)?; + + // 4. Let header field-names be the values as result of parsing the + // Access-Control-Request-Headers headers. + // If there are no Access-Control-Request-Headers headers + // let header field-names be the empty list. + // If parsing failed do not set any additional headers and terminate this set of steps. + // The request is outside the scope of this specification. + + // 5. If method is not a case-sensitive match for any of the values in list of methods + // do not set any additional headers and terminate this set of steps. + + validate_allowed_method(&method, &options.allowed_methods)?; + let response = response.methods(&options.allowed_methods); + + // 6. If any of the header field-names is not a ASCII case-insensitive match for any of the + // values in list of headers do not set any additional headers and terminate this set of + // steps. + + let response = if let Some(ref headers) = headers { + validate_allowed_headers(headers, &options.allowed_headers)?; + let &AccessControlRequestHeaders(ref headers) = headers; + response.headers( + headers + .iter() + .map(|s| &**s.deref()) + .collect::>() + .as_slice(), + ) + } else { + response + }; + + // 7. If the resource supports credentials add a single Access-Control-Allow-Origin header, + // with the value of the Origin header as value, and add a + // single Access-Control-Allow-Credentials header with the case-sensitive string "true" as + // value. + // Otherwise, add a single Access-Control-Allow-Origin header, + // with either the value of the Origin header or the string "*" as value. + // Note: The string "*" cannot be used for a resource that supports credentials. + + // Validation has been done in options.validate + let response = response.credentials(options.allow_credentials); + + // 8. Optionally add a single Access-Control-Max-Age header + // with as value the amount of seconds the user agent is allowed to cache the result of the + // request. + let response = response.max_age(options.max_age); + + // 9. If method is a simple method this step may be skipped. + // Add one or more Access-Control-Allow-Methods headers consisting of + // (a subset of) the list of methods. + // If a method is a simple method it does not need to be listed, but this is not prohibited. + // Since the list of methods can be unbounded, + // simply returning the method indicated by Access-Control-Request-Method + // (if supported) can be enough. + + // Done above + + // 10. If each of the header field-names is a simple header and none is Content-Type, + // this step may be skipped. + // Add one or more Access-Control-Allow-Headers headers consisting of (a subset of) + // the list of headers. + // If a header field name is a simple header and is not Content-Type, + // it is not required to be listed. Content-Type is to be listed as only a + // subset of its values makes it qualify as simple header. + // Since the list of headers can be unbounded, simply returning supported headers + // from Access-Control-Allow-Headers can be enough. + + // Done above -- we do not do anything special with simple headers + + Ok(response) +} + +/// Respond to an actual request based on the settings. +/// If the `Origin` is not provided, then this request was not made by a browser and there is no +/// CORS enforcement. +fn actual_request(options: &Cors, origin: Origin) -> Result { + options.validate()?; + let response = Response::new(); + + // Note: All header parse failures are dealt with in the `FromRequest` trait implementation + + // 2. If the value of the Origin header is not a case-sensitive match for any of the values + // in list of origins, do not set any additional headers and terminate this set of steps. + // Always matching is acceptable since the list of origins can be unbounded. + + validate_origin(&origin, &options.allowed_origins)?; + let response = match options.allowed_origins { + AllOrSome::All => { + if options.send_wildcard { + response.any() + } else { + response.origin(origin.as_str(), true) + } + } + AllOrSome::Some(_) => response.origin(origin.as_str(), false), + }; + + // 3. If the resource supports credentials add a single Access-Control-Allow-Origin header, + // with the value of the Origin header as value, and add a + // single Access-Control-Allow-Credentials header with the case-sensitive string "true" as + // value. + // Otherwise, add a single Access-Control-Allow-Origin header, + // with either the value of the Origin header or the string "*" as value. + // Note: The string "*" cannot be used for a resource that supports credentials. + + // Validation has been done in options.validate + let response = response.credentials(options.allow_credentials); + + // 4. If the list of exposed headers is not empty add one or more + // Access-Control-Expose-Headers headers, with as values the header field names given in + // the list of exposed headers. + // By not adding the appropriate headers resource can also clear the preflight result cache + // of all entries where origin is a case-sensitive match for the value of the Origin header + // and url is a case-sensitive match for the URL of the resource. + + let response = response.exposed_headers( + options + .expose_headers + .iter() + .map(|s| &**s) + .collect::>() + .as_slice(), + ); + Ok(response) } #[cfg(test)] @@ -908,65 +1077,50 @@ mod tests { use rocket::http::Method; use super::*; - // The following tests check `Response`'s validation + fn make_cors_options() -> Cors { + let (allowed_origins, failed_origins) = + AllOrSome::new_from_str_list(&["https://www.acme.com"]); + assert!(failed_origins.is_empty()); - #[test] - fn response_allows_all_origin_with_wildcard() { - let url = "https://www.example.com"; - let origin = Origin::from_str(url).unwrap(); - let allowed_origins = AllOrSome::All; - let send_wildcard = true; - - let response = Response::new(); - let response = not_err!(response.allowed_origin( - &origin, - &allowed_origins, - send_wildcard, - )); - - assert_matches!(response.allow_origin, Some(AllOrSome::All)); - assert_eq!(response.vary_origin, false); - - // Build response and check built response header - let expected_header = vec!["*"]; - let response = response.build(response::Response::new()); - let actual_header: Vec<_> = response - .headers() - .get("Access-Control-Allow-Origin") - .collect(); - assert_eq!(expected_header, actual_header); + Cors { + allowed_origins: allowed_origins, + allowed_methods: [Method::Get].iter().cloned().collect(), + allowed_headers: AllOrSome::Some( + ["Authorization"] + .into_iter() + .map(|s| s.to_string().into()) + .collect(), + ), + allow_credentials: true, + ..Default::default() + } } #[test] - fn response_allows_all_origin_with_echoed_domain() { + fn cors_is_validated() { + assert!(make_cors_options().validate().is_ok()) + } + + #[test] + #[should_panic(expected = "CredentialsWithWildcardOrigin")] + fn cors_validates_illegal_allow_credentials() { + let mut cors = make_cors_options(); + cors.allow_credentials = true; + cors.allowed_origins = AllOrSome::All; + cors.send_wildcard = true; + + cors.validate().unwrap(); + } + + // The following tests check validation + + #[test] + fn validate_origin_allows_all_origins() { let url = "https://www.example.com"; let origin = Origin::from_str(url).unwrap(); let allowed_origins = AllOrSome::All; - let send_wildcard = false; - let response = Response::new(); - let response = not_err!(response.allowed_origin( - &origin, - &allowed_origins, - send_wildcard, - )); - - let actual_origin = assert_matches!( - response.allow_origin, - Some(AllOrSome::Some(ref origin)), - origin - ); - assert_eq!(url, actual_origin); - assert_eq!(response.vary_origin, true); - - // Build response and check built response header - let expected_header = vec![url]; - let response = response.build(response::Response::new()); - let actual_header: Vec<_> = response - .headers() - .get("Access-Control-Allow-Origin") - .collect(); - assert_eq!(expected_header, actual_header); + not_err!(validate_origin(&origin, &allowed_origins)); } #[test] @@ -976,32 +1130,8 @@ mod tests { let (allowed_origins, failed_origins) = AllOrSome::new_from_str_list(&["https://www.example.com"]); assert!(failed_origins.is_empty()); - let send_wildcard = false; - let response = Response::new(); - let response = not_err!(response.allowed_origin( - &origin, - &allowed_origins, - send_wildcard, - )); - - let actual_origin = assert_matches!( - response.allow_origin, - Some(AllOrSome::Some(ref origin)), - origin - ); - - assert_eq!(url, actual_origin); - assert_eq!(response.vary_origin, false); - - // Build response and check built response header - let expected_header = vec![url]; - let response = response.build(response::Response::new()); - let actual_header: Vec<_> = response - .headers() - .get("Access-Control-Allow-Origin") - .collect(); - assert_eq!(expected_header, actual_header); + not_err!(validate_origin(&origin, &allowed_origins)); } #[test] @@ -1012,41 +1142,8 @@ mod tests { let (allowed_origins, failed_origins) = AllOrSome::new_from_str_list(&["https://www.example.com"]); assert!(failed_origins.is_empty()); - let send_wildcard = false; - let response = Response::new(); - let _ = response - .allowed_origin(&origin, &allowed_origins, send_wildcard) - .unwrap(); - } - - #[test] - #[should_panic(expected = "CredentialsWithWildcardOrigin")] - fn response_credentials_does_not_allow_wildcard_with_all_origins() { - let response = Response::new(); - let response = response.any(); - - let _ = response.credentials(true).unwrap(); - } - - #[test] - fn response_credentials_allows_specific_origins() { - let response = Response::new(); - let response = response.origin("https://www.example.com", false); - - let response = response.credentials(true).expect( - "to allow specific origins", - ); - assert_eq!(response.allow_credentials, true); - - // Build response and check built response header - let expected_header = vec!["true"]; - let response = response.build(response::Response::new()); - let actual_header: Vec<_> = response - .headers() - .get("Access-Control-Allow-Credentials") - .collect(); - assert_eq!(expected_header, actual_header); + validate_origin(&origin, &allowed_origins).unwrap(); } #[test] @@ -1057,7 +1154,7 @@ mod tests { let response = response.exposed_headers(&headers); // Build response and check built response header - let response = response.build(response::Response::new()); + let response = response.response(response::Response::new()); let actual_header: Vec<_> = response .headers() .get("Access-Control-Expose-Headers") @@ -1081,7 +1178,7 @@ mod tests { // Build response and check built response header let expected_header = vec!["42"]; - let response = response.build(response::Response::new()); + let response = response.response(response::Response::new()); let actual_header: Vec<_> = response.headers().get("Access-Control-Max-Age").collect(); assert_eq!(expected_header, actual_header); } @@ -1094,167 +1191,96 @@ mod tests { let response = response.max_age(None); // Build response and check built response header - let response = response.build(response::Response::new()); - assert!(response - .headers() - .get("Access-Control-Max-Age") - .next().is_none()) - } - - /// When all headers are allowed, tests that the requested headers are echoed back - #[test] - fn response_allowed_headers_echoes_back_requested_headers() { - let allowed_headers = AllOrSome::All; - let requested_headers = vec!["Bar", "Foo"]; - - let response = Response::new(); - let response = response.origin("https://www.example.com", false); - let response = response - .allowed_headers( - &FromStr::from_str(&requested_headers.join(",")).unwrap(), - &allowed_headers, - ) - .expect("to not fail"); - - // Build response and check built response header - let response = response.build(response::Response::new()); - let actual_header: Vec<_> = response - .headers() - .get("Access-Control-Allow-Headers") - .collect(); - - assert_eq!(1, actual_header.len()); - let mut actual_headers: Vec = actual_header[0] - .split(',') - .map(|header| header.trim().to_string()) - .collect(); - actual_headers.sort(); - assert_eq!(requested_headers, actual_headers); + let response = response.response(response::Response::new()); + assert!( + response + .headers() + .get("Access-Control-Max-Age") + .next() + .is_none() + ) } #[test] - fn response_allowed_methods_sets_headers_properly() { - let allowed_methods = vec![ - Method::Get, - Method::Head, - Method::Post, - ].into_iter() + fn allowed_methods_validated_correctly() { + let allowed_methods = vec![Method::Get, Method::Head, Method::Post] + .into_iter() .collect(); let method = "GET"; - let response = Response::new(); - let response = response.origin("https://www.example.com", false); - let response = response - .allowed_methods( - &FromStr::from_str(method).expect("not to fail"), - &allowed_methods, - ) - .expect("not to fail"); - - // Build response and check built response header - let response = response.build(response::Response::new()); - let actual_header: Vec<_> = response - .headers() - .get("Access-Control-Allow-Methods") - .collect(); - - assert_eq!(1, actual_header.len()); - let mut actual_headers: Vec = actual_header[0] - .split(',') - .map(|header| header.trim().to_string()) - .collect(); - actual_headers.sort(); - let mut expected_headers: Vec<_> = allowed_methods.iter().map(|m| m.as_str()).collect(); - expected_headers.sort(); - assert_eq!(expected_headers, actual_headers); + not_err!(validate_allowed_method( + &FromStr::from_str(method).expect("not to fail"), + &allowed_methods, + )); } #[test] #[should_panic(expected = "MethodNotAllowed")] - fn response_allowed_method_errors_on_disallowed_method() { - let allowed_methods = vec![ - Method::Get, - Method::Head, - Method::Post, - ].into_iter() + fn allowed_methods_errors_on_disallowed_method() { + let allowed_methods = vec![Method::Get, Method::Head, Method::Post] + .into_iter() .collect(); let method = "DELETE"; - let response = Response::new(); - let response = response.origin("https://www.example.com", false); - let _ = response - .allowed_methods( - &FromStr::from_str(method).expect("not to fail"), - &allowed_methods, - ) - .unwrap(); + validate_allowed_method( + &FromStr::from_str(method).expect("not to fail"), + &allowed_methods, + ).unwrap() + } + + #[test] + fn all_allowed_headers_are_validated_correctly() { + let allowed_headers = AllOrSome::All; + let requested_headers = vec!["Bar", "Foo"]; + + not_err!(validate_allowed_headers( + &FromStr::from_str(&requested_headers.join(",")).unwrap(), + &allowed_headers, + )); } /// `Response::allowed_headers` should check that headers are allowed, and only /// echoes back the list that is actually requested for and not the whole list #[test] - fn response_allowed_headers_validates_and_echoes_requested_headers() { + fn allowed_headers_are_validated_correctly() { let allowed_headers = vec!["Bar", "Baz", "Foo"]; let requested_headers = vec!["Bar", "Foo"]; - let response = Response::new(); - let response = response.origin("https://www.example.com", false); - let response = response - .allowed_headers( - &FromStr::from_str(&requested_headers.join(",")).unwrap(), - &AllOrSome::Some( - allowed_headers - .iter() - .map(|s| FromStr::from_str(*s).unwrap()) - .collect(), - ), - ) - .expect("to not fail"); - - // Build response and check built response header - let response = response.build(response::Response::new()); - let actual_header: Vec<_> = response - .headers() - .get("Access-Control-Allow-Headers") - .collect(); - - assert_eq!(1, actual_header.len()); - let mut actual_headers: Vec = actual_header[0] - .split(',') - .map(|header| header.trim().to_string()) - .collect(); - actual_headers.sort(); - assert_eq!(requested_headers, actual_headers); + not_err!(validate_allowed_headers( + &FromStr::from_str(&requested_headers.join(",")).unwrap(), + &AllOrSome::Some( + allowed_headers + .iter() + .map(|s| FromStr::from_str(*s).unwrap()) + .collect(), + ), + )); } #[test] #[should_panic(expected = "HeadersNotAllowed")] - fn response_allowed_headers_errors_on_non_subset() { + fn allowed_headers_errors_on_non_subset() { let allowed_headers = vec!["Bar", "Baz", "Foo"]; let requested_headers = vec!["Bar", "Foo", "Unknown"]; - let response = Response::new(); - let response = response.origin("https://www.example.com", false); - let _ = response - .allowed_headers( - &FromStr::from_str(&requested_headers.join(",")).unwrap(), - &AllOrSome::Some( - allowed_headers - .iter() - .map(|s| FromStr::from_str(*s).unwrap()) - .collect(), - ), - ) - .unwrap(); + validate_allowed_headers( + &FromStr::from_str(&requested_headers.join(",")).unwrap(), + &AllOrSome::Some( + allowed_headers + .iter() + .map(|s| FromStr::from_str(*s).unwrap()) + .collect(), + ), + ).unwrap(); } #[test] fn response_does_not_build_if_origin_is_not_set() { let response = Response::new(); - let response = response.build(response::Response::new()); + let response = response.response(response::Response::new()); let headers: Vec<_> = response.headers().iter().collect(); assert_eq!(headers.len(), 0); @@ -1273,7 +1299,7 @@ mod tests { let response = Response::new(); let response = response.origin("https://www.example.com", false); - let response = response.build(original); + let response = response.response(original); // Check CORS header let expected_header = vec!["https://www.example.com"]; let actual_header: Vec<_> = response @@ -1288,42 +1314,19 @@ mod tests { assert_eq!(expected_header, actual_header); // Check that `Access-Control-Max-Age` is removed - assert!(response.headers().get("Access-Control-Max-Age").next().is_none()); + assert!( + response + .headers() + .get("Access-Control-Max-Age") + .next() + .is_none() + ); } - // The following tests check that preflight checks are done properly + // TODO: Preflight tests + // TODO: Actual requests tests - // fn make_cors_options() -> Cors { - // let (allowed_origins, failed_origins) = - // AllOrSome::new_from_str_list(&["https://www.acme.com"]); - // assert!(failed_origins.is_empty()); - - // Cors { - // allowed_origins: allowed_origins, - // allowed_methods: [Method::Get].iter().cloned().collect(), - // allowed_headers: AllOrSome::Some( - // ["Authorization"] - // .into_iter() - // .map(|s| s.to_string().into()) - // .collect(), - // ), - // allow_credentials: true, - // ..Default::default() - // } - // } - - // /// Tests that non CORS preflight are let through without modification - // #[test] - // fn preflight_missing_origins_are_let_through() { - // let options = make_cors_options(); - // let client = make_client(); - // let request = client.get("/"); - - // let response = options.preflight((), None, None, None).expect("not to fail"); - - // let headers: Vec<_> = response.headers().iter().collect(); - // assert_eq!(headers.len(), 0); - // } + // Origin all (wildcard + echoed with Vary). Origin Echo } diff --git a/tests/routes.rs b/tests/ad_hoc.rs similarity index 81% rename from tests/routes.rs rename to tests/ad_hoc.rs index dc421f9..ac49944 100644 --- a/tests/routes.rs +++ b/tests/ad_hoc.rs @@ -1,37 +1,68 @@ -//! This crate tests using rocket_cors using the "classic" per-route handling +//! This crate tests using rocket_cors using the "classic" ad-hoc per-route handling #![feature(plugin, custom_derive)] #![plugin(rocket_codegen)] extern crate hyper; extern crate rocket; -extern crate rocket_cors; +extern crate rocket_cors as cors; use std::str::FromStr; -use rocket::State; +use rocket::{Response, State}; use rocket::http::Method; use rocket::http::{Header, Status}; use rocket::local::Client; -use rocket_cors::*; #[options("/")] -fn cors_options(options: State) -> Responder<&str> { - rocket_cors::respond(options, "") +fn cors_options(cors: cors::Guard) -> cors::Responder<&str> { + cors.responder("") } #[get("/")] -fn cors(options: State) -> Responder<&str> { - rocket_cors::respond(options, "Hello CORS") +fn cors(cors: cors::Guard) -> cors::Responder<&str> { + cors.responder("Hello CORS") } -fn make_cors_options() -> Cors { - let (allowed_origins, failed_origins) = AllOrSome::new_from_str_list(&["https://www.acme.com"]); +// The following routes tests that the routes can be compiled with ad-hoc CORS Response/Responders + +/// Using a `Response` instead of a `Responder` +#[allow(unmounted_route)] +#[get("/")] +fn response(cors: cors::Guard) -> Response { + cors.response(Response::new()) +} + +/// `Responder` with String +#[allow(unmounted_route)] +#[get("/")] +fn responder_string(cors: cors::Guard) -> cors::Responder { + cors.responder("Hello CORS".to_string()) +} + +/// `Responder` with 'static () +#[allow(unmounted_route)] +#[get("/")] +fn responder_unit(cors: cors::Guard) -> cors::Responder<()> { + cors.responder(()) +} + +struct SomeState; +/// Borrow `SomeState` from Rocket +#[allow(unmounted_route)] +#[get("/")] +fn state<'r>(cors: cors::Guard<'r>, _state: State<'r, SomeState>) -> cors::Responder<'r, &'r str> { + cors.responder("hmm") +} + +fn make_cors_options() -> cors::Cors { + let (allowed_origins, failed_origins) = + cors::AllOrSome::new_from_str_list(&["https://www.acme.com"]); assert!(failed_origins.is_empty()); - Cors { + cors::Cors { allowed_origins: allowed_origins, allowed_methods: [Method::Get].iter().cloned().collect(), - allowed_headers: AllOrSome::Some( + allowed_headers: cors::AllOrSome::Some( ["Authorization"] .into_iter() .map(|s| s.to_string().into()) @@ -44,12 +75,13 @@ fn make_cors_options() -> Cors { #[test] fn smoke_test() { - let (allowed_origins, failed_origins) = AllOrSome::new_from_str_list(&["https://www.acme.com"]); + let (allowed_origins, failed_origins) = + cors::AllOrSome::new_from_str_list(&["https://www.acme.com"]); assert!(failed_origins.is_empty()); - let cors_options = rocket_cors::Cors { + let cors_options = cors::Cors { allowed_origins: allowed_origins, allowed_methods: [Method::Get].iter().cloned().collect(), - allowed_headers: AllOrSome::Some( + allowed_headers: cors::AllOrSome::Some( ["Authorization"] .iter() .map(|s| s.to_string().into()) @@ -81,7 +113,7 @@ fn smoke_test() { .header(request_headers); let response = req.dispatch(); - assert_eq!(response.status(), Status::Ok); + assert!(response.status().class().is_success()); // "Actual" request let origin_header = Header::from( @@ -91,7 +123,7 @@ fn smoke_test() { let req = client.get("/").header(origin_header).header(authorization); let mut response = req.dispatch(); - assert_eq!(response.status(), Status::Ok); + assert!(response.status().class().is_success()); let body_str = response.body().and_then(|body| body.into_string()); assert_eq!(body_str, Some("Hello CORS".to_string())); @@ -121,7 +153,7 @@ fn cors_options_check() { .header(request_headers); let response = req.dispatch(); - assert_eq!(response.status(), Status::Ok); + assert!(response.status().class().is_success()); } #[test] @@ -139,7 +171,7 @@ fn cors_get_check() { let mut response = req.dispatch(); println!("{:?}", response); - assert_eq!(response.status(), Status::Ok); + assert!(response.status().class().is_success()); let body_str = response.body().and_then(|body| body.into_string()); assert_eq!(body_str, Some("Hello CORS".to_string())); } @@ -156,7 +188,7 @@ fn cors_get_no_origin() { let req = client.get("/").header(authorization); let mut response = req.dispatch(); - assert_eq!(response.status(), Status::Ok); + assert!(response.status().class().is_success()); let body_str = response.body().and_then(|body| body.into_string()); assert_eq!(body_str, Some("Hello CORS".to_string())); } @@ -207,7 +239,7 @@ fn cors_options_missing_origin() { ); let response = req.dispatch(); - assert_eq!(response.status(), Status::Ok); + assert!(response.status().class().is_success()); } #[test] diff --git a/tests/fairings.rs b/tests/fairings.rs new file mode 100644 index 0000000..4fdfd68 --- /dev/null +++ b/tests/fairings.rs @@ -0,0 +1,240 @@ +//! This crate tests using rocket_cors using Fairings + +#![feature(plugin, custom_derive)] +#![plugin(rocket_codegen)] +extern crate hyper; +extern crate rocket; +extern crate rocket_cors; + +use std::str::FromStr; + +use rocket::http::Method; +use rocket::http::{Header, Status}; +use rocket::local::Client; +use rocket_cors::*; + +#[get("/")] +fn cors<'a>() -> &'a str { + "Hello CORS" +} + +fn make_cors_options() -> Cors { + let (allowed_origins, failed_origins) = AllOrSome::new_from_str_list(&["https://www.acme.com"]); + assert!(failed_origins.is_empty()); + + Cors { + allowed_origins: allowed_origins, + allowed_methods: [Method::Get].iter().cloned().collect(), + allowed_headers: AllOrSome::Some( + ["Authorization"] + .into_iter() + .map(|s| s.to_string().into()) + .collect(), + ), + allow_credentials: true, + ..Default::default() + } +} + +fn rocket() -> rocket::Rocket { + rocket::ignite().mount("/", routes![cors]).attach( + make_cors_options(), + ) +} + +#[test] +fn smoke_test() { + let client = Client::new(rocket()).unwrap(); + + // `Options` pre-flight checks + let origin_header = Header::from( + hyper::header::Origin::from_str("https://www.acme.com").unwrap(), + ); + let method_header = Header::from(hyper::header::AccessControlRequestMethod( + hyper::method::Method::Get, + )); + let request_headers = hyper::header::AccessControlRequestHeaders( + vec![FromStr::from_str("Authorization").unwrap()], + ); + let request_headers = Header::from(request_headers); + let req = client + .options("/") + .header(origin_header) + .header(method_header) + .header(request_headers); + + let response = req.dispatch(); + assert!(response.status().class().is_success()); + + // "Actual" request + let origin_header = Header::from( + hyper::header::Origin::from_str("https://www.acme.com").unwrap(), + ); + let authorization = Header::new("Authorization", "let me in"); + let req = client.get("/").header(origin_header).header(authorization); + + let mut response = req.dispatch(); + assert!(response.status().class().is_success()); + let body_str = response.body().and_then(|body| body.into_string()); + assert_eq!(body_str, Some("Hello CORS".to_string())); + +} + +#[test] +fn cors_options_check() { + let client = Client::new(rocket()).unwrap(); + + let origin_header = Header::from( + hyper::header::Origin::from_str("https://www.acme.com").unwrap(), + ); + let method_header = Header::from(hyper::header::AccessControlRequestMethod( + hyper::method::Method::Get, + )); + let request_headers = hyper::header::AccessControlRequestHeaders( + vec![FromStr::from_str("Authorization").unwrap()], + ); + let request_headers = Header::from(request_headers); + let req = client + .options("/") + .header(origin_header) + .header(method_header) + .header(request_headers); + + let response = req.dispatch(); + assert!(response.status().class().is_success()); +} + +#[test] +fn cors_get_check() { + let client = Client::new(rocket()).unwrap(); + + let origin_header = Header::from( + hyper::header::Origin::from_str("https://www.acme.com").unwrap(), + ); + let authorization = Header::new("Authorization", "let me in"); + let req = client.get("/").header(origin_header).header(authorization); + + let mut response = req.dispatch(); + println!("{:?}", response); + assert!(response.status().class().is_success()); + let body_str = response.body().and_then(|body| body.into_string()); + assert_eq!(body_str, Some("Hello CORS".to_string())); +} + +/// This test is to check that non CORS compliant requests to GET should still work. (i.e. curl) +#[test] +fn cors_get_no_origin() { + let client = Client::new(rocket()).unwrap(); + + let authorization = Header::new("Authorization", "let me in"); + let req = client.get("/").header(authorization); + + let mut response = req.dispatch(); + assert!(response.status().class().is_success()); + let body_str = response.body().and_then(|body| body.into_string()); + assert_eq!(body_str, Some("Hello CORS".to_string())); +} + +#[test] +fn cors_options_bad_origin() { + let client = Client::new(rocket()).unwrap(); + + let origin_header = Header::from( + hyper::header::Origin::from_str("https://www.bad-origin.com").unwrap(), + ); + let method_header = Header::from(hyper::header::AccessControlRequestMethod( + hyper::method::Method::Get, + )); + let request_headers = hyper::header::AccessControlRequestHeaders( + vec![FromStr::from_str("Authorization").unwrap()], + ); + let request_headers = Header::from(request_headers); + let req = client + .options("/") + .header(origin_header) + .header(method_header) + .header(request_headers); + + let response = req.dispatch(); + assert_eq!(response.status(), Status::Forbidden); +} + +#[test] +fn cors_options_missing_origin() { + let client = Client::new(rocket()).unwrap(); + + let method_header = Header::from(hyper::header::AccessControlRequestMethod( + hyper::method::Method::Get, + )); + let request_headers = hyper::header::AccessControlRequestHeaders( + vec![FromStr::from_str("Authorization").unwrap()], + ); + let request_headers = Header::from(request_headers); + let req = client.options("/").header(method_header).header( + request_headers, + ); + + let response = req.dispatch(); + assert!(response.status().class().is_success()); +} + +#[test] +fn cors_options_bad_request_method() { + let client = Client::new(rocket()).unwrap(); + + let origin_header = Header::from( + hyper::header::Origin::from_str("https://www.acme.com").unwrap(), + ); + let method_header = Header::from(hyper::header::AccessControlRequestMethod( + hyper::method::Method::Post, + )); + let request_headers = hyper::header::AccessControlRequestHeaders( + vec![FromStr::from_str("Authorization").unwrap()], + ); + let request_headers = Header::from(request_headers); + let req = client + .options("/") + .header(origin_header) + .header(method_header) + .header(request_headers); + + let response = req.dispatch(); + assert_eq!(response.status(), Status::Forbidden); +} + +#[test] +fn cors_options_bad_request_header() { + let client = Client::new(rocket()).unwrap(); + + let origin_header = Header::from( + hyper::header::Origin::from_str("https://www.acme.com").unwrap(), + ); + let method_header = Header::from(hyper::header::AccessControlRequestMethod( + hyper::method::Method::Get, + )); + let request_headers = + hyper::header::AccessControlRequestHeaders(vec![FromStr::from_str("Foobar").unwrap()]); + let request_headers = Header::from(request_headers); + let req = client + .options("/") + .header(origin_header) + .header(method_header) + .header(request_headers); + + let response = req.dispatch(); + assert_eq!(response.status(), Status::Forbidden); +} + +#[test] +fn cors_get_bad_origin() { + let client = Client::new(rocket()).unwrap(); + + let origin_header = Header::from( + hyper::header::Origin::from_str("https://www.bad-origin.com").unwrap(), + ); + let authorization = Header::new("Authorization", "let me in"); + let req = client.get("/").header(origin_header).header(authorization); + + let response = req.dispatch(); + assert_eq!(response.status(), Status::Forbidden); +} diff --git a/tests/headers.rs b/tests/headers.rs index 9be98da..544d823 100644 --- a/tests/headers.rs +++ b/tests/headers.rs @@ -9,7 +9,7 @@ use std::ops::Deref; use std::str::FromStr; use rocket::local::Client; -use rocket::http::{Header, Status}; +use rocket::http::Header; use rocket_cors::headers::*; #[get("/request_headers")] @@ -52,7 +52,7 @@ fn request_headers_round_trip_smoke_test() { .header(request_headers); let mut response = req.dispatch(); - assert_eq!(Status::Ok, response.status()); + assert!(response.status().class().is_success()); let body_str = response.body().and_then(|body| body.into_string()).expect( "Non-empty body", );