diff --git a/src/fairing.rs b/src/fairing.rs new file mode 100644 index 0000000..f2bb393 --- /dev/null +++ b/src/fairing.rs @@ -0,0 +1,107 @@ +//! Fairing implementation +use rocket::{self, Request, Outcome}; +use rocket::http::{self, Status}; + +use {Cors, Error, validate, preflight_response, actual_request_response, origin, request_headers}; + +/// Route for Fairing error handling +pub(crate) fn fairing_error_route<'r>( + request: &'r Request, + _: rocket::Data, +) -> rocket::handler::Outcome<'r> { + let status = request.get_param::(0).unwrap_or_else(|e| { + error_!("Fairing Error Handling Route error: {:?}", e); + 500 + }); + let status = Status::from_code(status).unwrap_or_else(|| Status::InternalServerError); + Outcome::Failure(status) +} + +/// Create a new `Route` for Fairing handling +fn fairing_route() -> rocket::Route { + rocket::Route::new(http::Method::Get, "/", fairing_error_route) +} + +/// Modifies a `Request` to route to Fairing error handler +fn route_to_fairing_error_handler(options: &Cors, status: u16, request: &mut Request) { + request.set_method(http::Method::Get); + request.set_uri(format!("{}/{}", options.fairing_route_base, status)); +} + +fn on_response_wrapper( + options: &Cors, + request: &Request, + response: &mut rocket::Response, +) -> Result<(), Error> { + let origin = match origin(request)? { + None => { + // Not a CORS request + return Ok(()); + } + Some(origin) => origin, + }; + + let cors_response = if request.method() == http::Method::Options { + let headers = request_headers(request)?; + preflight_response(options, origin, headers) + } else { + actual_request_response(options, origin) + }; + + cors_response.merge(response); + + // If this was an OPTIONS request and no route can be found, we should turn this + // into a HTTP 204 with no content body. + // This allows the user to not have to specify an OPTIONS route for everything. + // + // TODO: Is there anyway we can make this smarter? Only modify status codes for + // requests where an actual route exist? + if request.method() == http::Method::Options && request.method() == http::Method::Options && + request.route().is_none() + { + response.set_status(Status::NoContent); + let _ = response.take_body(); + } + Ok(()) +} + +impl rocket::fairing::Fairing for Cors { + fn info(&self) -> rocket::fairing::Info { + rocket::fairing::Info { + name: "CORS", + kind: rocket::fairing::Kind::Attach | rocket::fairing::Kind::Request | + rocket::fairing::Kind::Response, + } + } + + fn on_attach(&self, rocket: rocket::Rocket) -> Result { + match self.validate() { + Ok(()) => { + Ok(rocket.mount(&self.fairing_route_base, vec![fairing_route()])) + } + Err(e) => { + error_!("Error attaching CORS fairing: {}", e); + Err(rocket) + } + } + } + + fn on_request(&self, request: &mut Request, _: &rocket::Data) { + // Build and merge CORS response + // Type annotation is for sanity check + let cors_response = validate(self, request); + if let Err(ref err) = cors_response { + error_!("CORS Error: {}", err); + let status = err.status(); + route_to_fairing_error_handler(self, status.code, request); + } + } + + fn on_response(&self, request: &Request, response: &mut rocket::Response) { + if let Err(err) = on_response_wrapper(self, request, response) { + error_!("Fairings on_response error: {}\nMost likely a bug", err); + response.set_status(Status::InternalServerError); + let _ = response.take_body(); + } + } +} diff --git a/src/lib.rs b/src/lib.rs index 3fc236e..4b294df 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -115,6 +115,7 @@ extern crate serde_json; #[cfg(test)] #[macro_use] mod test_macros; +mod fairing; pub mod headers; @@ -130,7 +131,6 @@ use std::str::FromStr; use rocket::{Outcome, State}; use rocket::http::{self, Status}; -use rocket::fairing; use rocket::request::{Request, FromRequest}; use rocket::response; use serde::{Serialize, Deserialize}; @@ -374,7 +374,7 @@ impl<'de> Deserialize<'de> for Method { /// /// [`Default`](https://doc.rust-lang.org/std/default/trait.Default.html) is implemented for this /// struct. The default for each field is described in the docuementation for the field. -#[derive(Eq, PartialEq, Serialize, Deserialize, Clone, Debug)] +#[derive(Serialize, Deserialize, Eq, PartialEq, Clone, Debug)] pub struct Cors { /// Origins that are allowed to make requests. /// Will be verified against the `Origin` request header. @@ -506,7 +506,7 @@ impl Cors { /// `Guard` type. This is useful if you want an even more ad-hoc based approach to respond to /// CORS by using a `Cors` that is not in Rocket's managed state. pub fn guard<'a, 'r>(&'a self, request: &'a Request<'r>) -> Result, Error> { - let response = build_cors_response(self, request)?; + let response = validate_and_build(self, request)?; Ok(Guard::new(response)) } @@ -520,88 +520,6 @@ impl Cors { Ok(()) } - - /// Create a new `Route` for Fairing handling - fn fairing_route(&self) -> rocket::Route { - rocket::Route::new(http::Method::Get, "/", fairing_error_route) - } - - /// Modifies a `Request` to route to Fairing error handler - fn route_to_fairing_error_handler(&self, status: u16, request: &mut Request) { - request.set_method(http::Method::Get); - request.set_uri(format!("{}/{}", self.fairing_route_base, status)); - } -} - -impl fairing::Fairing for Cors { - fn info(&self) -> fairing::Info { - fairing::Info { - name: "CORS", - kind: fairing::Kind::Attach | fairing::Kind::Request | fairing::Kind::Response, - } - } - - fn on_attach(&self, rocket: rocket::Rocket) -> Result { - match self.validate() { - Ok(()) => { - Ok(rocket.mount(&self.fairing_route_base, vec![self.fairing_route()])) - } - Err(e) => { - error_!("Error attaching CORS fairing: {}", e); - Err(rocket) - } - } - } - - fn on_request(&self, request: &mut Request, _: &rocket::Data) { - // Build and merge CORS response - match build_cors_response(self, request) { - Err(err) => { - error_!("CORS Error: {}", err); - let status = err.status(); - self.route_to_fairing_error_handler(status.code, request); - } - Ok(cors_response) => { - // TODO: How to pass response downstream? - let _ = cors_response; - } - }; - } - - fn on_response(&self, request: &Request, response: &mut rocket::Response) { - // Build and merge CORS response - match build_cors_response(self, request) { - Err(_) => { - // We have dealt with this already - } - Ok(cors_response) => { - cors_response.merge(response); - - // If this was an OPTIONS request and no route can be found, we should turn this - // into a HTTP 204 with no content body. - // This allows the user to not have to specify an OPTIONS route for everything. - // - // TODO: Is there anyway we can make this smarter? Only modify status codes for - // requests where an actual route exist? - if request.method() == http::Method::Options && request.route().is_none() { - response.set_status(Status::NoContent); - let _ = response.take_body(); - } - } - }; - - - } -} - -/// Route for Fairing error handling -fn fairing_error_route<'r>(request: &'r Request, _: rocket::Data) -> rocket::handler::Outcome<'r> { - let status = request.get_param::(0).unwrap_or_else(|e| { - error_!("Fairing Error Handling Route error: {:?}", e); - 500 - }); - let status = Status::from_code(status).unwrap_or_else(|| Status::InternalServerError); - Outcome::Failure(status) } /// A CORS Response which provides the following CORS headers: @@ -775,11 +693,11 @@ impl Response { } /// Validate and create a new CORS Response from a request and settings - pub fn build_cors_response<'a, 'r>( + pub fn validate_and_build<'a, 'r>( options: &'a Cors, request: &'a Request<'r>, ) -> Result { - build_cors_response(options, request) + validate_and_build(options, request) } } @@ -829,7 +747,7 @@ impl<'a, 'r> FromRequest<'a, 'r> for Guard<'r> { } }; - match Response::build_cors_response(&options, request) { + match Response::validate_and_build(&options, request) { Ok(response) => Outcome::Success(Self::new(response)), Err(error) => Outcome::Failure((error.status(), error)), } @@ -878,35 +796,60 @@ impl<'r, R: response::Responder<'r>> response::Responder<'r> for Responder<'r, R } } -/// Validates a request for CORS and returns a CORS Response -fn build_cors_response(options: &Cors, request: &Request) -> Result { - // Existing CORS response? - // if has_allow_origin(response) { - // return Ok(()); - // } +/// Result of CORS validation. +/// +/// The variants hold enough information to build a response to the validation result +enum ValidationResult { + /// Not a CORS request + None, + /// Successful preflight request + Preflight { + origin: Origin, + headers: Option, + }, + /// Successful actual request + Request { origin: Origin }, +} +/// Validates a request for CORS and returns a CORS Response +fn validate_and_build(options: &Cors, request: &Request) -> Result { + let result = validate(options, request)?; + + Ok(match result { + ValidationResult::None => Response::new(), + ValidationResult::Preflight { origin, headers } => { + preflight_response(options, origin, headers) + } + ValidationResult::Request { origin } => actual_request_response(options, origin), + }) +} + +/// Validate a CORS request +fn validate(options: &Cors, request: &Request) -> Result { // 1. If the Origin header is not present terminate this set of steps. // The request is outside the scope of this specification. let origin = origin(request)?; let origin = match origin { None => { // Not a CORS request - return Ok(Response::new()); + return Ok(ValidationResult::None); } Some(origin) => origin, }; // Check if the request verb is an OPTION or something else - let cors_response = match request.method() { + match request.method() { http::Method::Options => { let method = request_method(request)?; let headers = request_headers(request)?; - preflight(options, origin, method, headers) + preflight_validate(options, &origin, &method, &headers)?; + Ok(ValidationResult::Preflight { origin, headers }) } - _ => actual_request(options, origin), - }?; - - Ok(cors_response) + _ => { + actual_request_validate(options, &origin)?; + Ok(ValidationResult::Request { origin }) + } + } } /// Consumes the responder and based on the provided list of allowed origins, @@ -987,35 +930,24 @@ fn request_headers(request: &Request) -> Result, - headers: Option, -) -> Result { - options.validate()?; - let response = Response::new(); + origin: &Origin, + method: &Option, + headers: &Option, +) -> Result<(), Error> { + + options.validate()?; // Fast-forward check for #7 // Note: All header parse failures are dealt with in the `FromRequest` trait implementation // 2. If the value of the Origin header is not a case-sensitive match for any of the values // in list of origins do not set any additional headers and terminate this set of steps. validate_origin(&origin, &options.allowed_origins)?; - let response = match options.allowed_origins { - AllOrSome::All => { - if options.send_wildcard { - response.any() - } else { - response.origin(origin.as_str(), true) - } - } - AllOrSome::Some(_) => response.origin(origin.as_str(), false), - }; // 3. Let `method` be the value as result of parsing the Access-Control-Request-Method // header. @@ -1023,7 +955,7 @@ fn preflight( // do not set any additional headers and terminate this set of steps. // The request is outside the scope of this specification. - let method = method.ok_or_else(|| Error::MissingRequestMethod)?; + let method = method.as_ref().ok_or_else(|| Error::MissingRequestMethod)?; // 4. Let header field-names be the values as result of parsing the // Access-Control-Request-Headers headers. @@ -1035,26 +967,29 @@ fn preflight( // 5. If method is not a case-sensitive match for any of the values in list of methods // do not set any additional headers and terminate this set of steps. - validate_allowed_method(&method, &options.allowed_methods)?; - let response = response.methods(&options.allowed_methods); + validate_allowed_method(method, &options.allowed_methods)?; // 6. If any of the header field-names is not a ASCII case-insensitive match for any of the // values in list of headers do not set any additional headers and terminate this set of // steps. - let response = if let Some(ref headers) = headers { + if let &Some(ref headers) = headers { validate_allowed_headers(headers, &options.allowed_headers)?; - let &AccessControlRequestHeaders(ref headers) = headers; - response.headers( - headers - .iter() - .map(|s| &**s.deref()) - .collect::>() - .as_slice(), - ) - } else { - response - }; + } + + Ok(()) +} + +/// Build a response for pre-flight checks +/// +/// This implementation references the +/// [W3C recommendation](https://www.w3.org/TR/cors/#resource-preflight-requests). +fn preflight_response( + options: &Cors, + origin: Origin, + headers: Option, +) -> Response { + let response = Response::new(); // 7. If the resource supports credentials add a single Access-Control-Allow-Origin header, // with the value of the Origin header as value, and add a @@ -1065,6 +1000,16 @@ fn preflight( // Note: The string "*" cannot be used for a resource that supports credentials. // Validation has been done in options.validate + let response = match options.allowed_origins { + AllOrSome::All => { + if options.send_wildcard { + response.any() + } else { + response.origin(origin.as_str(), true) + } + } + AllOrSome::Some(_) => response.origin(origin.as_str(), false), + }; let response = response.credentials(options.allow_credentials); // 8. Optionally add a single Access-Control-Max-Age header @@ -1080,7 +1025,7 @@ fn preflight( // simply returning the method indicated by Access-Control-Request-Method // (if supported) can be enough. - // Done above + let response = response.methods(&options.allowed_methods); // 10. If each of the header field-names is a simple header and none is Content-Type, // this step may be skipped. @@ -1092,17 +1037,29 @@ fn preflight( // Since the list of headers can be unbounded, simply returning supported headers // from Access-Control-Allow-Headers can be enough. - // Done above -- we do not do anything special with simple headers + // We do not do anything special with simple headers + let response = if let Some(ref headers) = headers { + let &AccessControlRequestHeaders(ref headers) = headers; + response.headers( + headers + .iter() + .map(|s| &**s.deref()) + .collect::>() + .as_slice(), + ) + } else { + response + }; - Ok(response) + response } -/// Respond to an actual request based on the settings. -/// If the `Origin` is not provided, then this request was not made by a browser and there is no -/// CORS enforcement. -fn actual_request(options: &Cors, origin: Origin) -> Result { +/// Do checks for an actual request +/// +/// This implementation references the +/// [W3C recommendation](https://www.w3.org/TR/cors/#resource-requests). +fn actual_request_validate(options: &Cors, origin: &Origin) -> Result<(), Error> { options.validate()?; - let response = Response::new(); // Note: All header parse failures are dealt with in the `FromRequest` trait implementation @@ -1111,6 +1068,27 @@ fn actual_request(options: &Cors, origin: Origin) -> Result { // Always matching is acceptable since the list of origins can be unbounded. validate_origin(&origin, &options.allowed_origins)?; + + Ok(()) +} + +/// Build the response for an actual request +/// +/// This implementation references the +/// [W3C recommendation](https://www.w3.org/TR/cors/#resource-requests). +fn actual_request_response(options: &Cors, origin: Origin) -> Response { + let response = Response::new(); + + // 3. If the resource supports credentials add a single Access-Control-Allow-Origin header, + // with the value of the Origin header as value, and add a + // single Access-Control-Allow-Credentials header with the case-sensitive string "true" as + // value. + // Otherwise, add a single Access-Control-Allow-Origin header, + // with either the value of the Origin header or the string "*" as value. + // Note: The string "*" cannot be used for a resource that supports credentials. + + // Validation has been done in options.validate + let response = match options.allowed_origins { AllOrSome::All => { if options.send_wildcard { @@ -1122,15 +1100,6 @@ fn actual_request(options: &Cors, origin: Origin) -> Result { AllOrSome::Some(_) => response.origin(origin.as_str(), false), }; - // 3. If the resource supports credentials add a single Access-Control-Allow-Origin header, - // with the value of the Origin header as value, and add a - // single Access-Control-Allow-Credentials header with the case-sensitive string "true" as - // value. - // Otherwise, add a single Access-Control-Allow-Origin header, - // with either the value of the Origin header or the string "*" as value. - // Note: The string "*" cannot be used for a resource that supports credentials. - - // Validation has been done in options.validate let response = response.credentials(options.allow_credentials); // 4. If the list of exposed headers is not empty add one or more @@ -1148,7 +1117,8 @@ fn actual_request(options: &Cors, origin: Origin) -> Result { .collect::>() .as_slice(), ); - Ok(response) + + response } #[cfg(test)] diff --git a/tests/fairings.rs b/tests/fairing.rs similarity index 98% rename from tests/fairings.rs rename to tests/fairing.rs index e223bef..7518ee1 100644 --- a/tests/fairings.rs +++ b/tests/fairing.rs @@ -164,6 +164,7 @@ fn cors_options_bad_origin() { assert_eq!(response.status(), Status::Forbidden); } +/// Unlike the "ad-hoc" mode, this should return 404 because we don't have such a route #[test] fn cors_options_missing_origin() { let client = Client::new(rocket()).unwrap(); @@ -180,7 +181,7 @@ fn cors_options_missing_origin() { ); let response = req.dispatch(); - assert!(response.status().class().is_success()); + assert_eq!(response.status(), Status::NotFound); } #[test]