From 41f5ac11d8003dddb72d9649a7a22fea0530b87f Mon Sep 17 00:00:00 2001 From: Yong Wen Chua Date: Mon, 17 Jul 2017 19:19:51 +0800 Subject: [PATCH] Refactor to separate out validation from response building step 1 --- src/fairing.rs | 64 ++++++++++------ src/lib.rs | 190 ++++++++++++++++++++++++++++++----------------- tests/fairing.rs | 3 +- 3 files changed, 166 insertions(+), 91 deletions(-) diff --git a/src/fairing.rs b/src/fairing.rs index 5e85d87..f2bb393 100644 --- a/src/fairing.rs +++ b/src/fairing.rs @@ -2,7 +2,7 @@ use rocket::{self, Request, Outcome}; use rocket::http::{self, Status}; -use {Cors, build_cors_response}; +use {Cors, Error, validate, preflight_response, actual_request_response, origin, request_headers}; /// Route for Fairing error handling pub(crate) fn fairing_error_route<'r>( @@ -28,6 +28,43 @@ fn route_to_fairing_error_handler(options: &Cors, status: u16, request: &mut Req request.set_uri(format!("{}/{}", options.fairing_route_base, status)); } +fn on_response_wrapper( + options: &Cors, + request: &Request, + response: &mut rocket::Response, +) -> Result<(), Error> { + let origin = match origin(request)? { + None => { + // Not a CORS request + return Ok(()); + } + Some(origin) => origin, + }; + + let cors_response = if request.method() == http::Method::Options { + let headers = request_headers(request)?; + preflight_response(options, origin, headers) + } else { + actual_request_response(options, origin) + }; + + cors_response.merge(response); + + // If this was an OPTIONS request and no route can be found, we should turn this + // into a HTTP 204 with no content body. + // This allows the user to not have to specify an OPTIONS route for everything. + // + // TODO: Is there anyway we can make this smarter? Only modify status codes for + // requests where an actual route exist? + if request.method() == http::Method::Options && request.method() == http::Method::Options && + request.route().is_none() + { + response.set_status(Status::NoContent); + let _ = response.take_body(); + } + Ok(()) +} + impl rocket::fairing::Fairing for Cors { fn info(&self) -> rocket::fairing::Info { rocket::fairing::Info { @@ -52,7 +89,7 @@ impl rocket::fairing::Fairing for Cors { fn on_request(&self, request: &mut Request, _: &rocket::Data) { // Build and merge CORS response // Type annotation is for sanity check - let cors_response = build_cors_response(self, request); + let cors_response = validate(self, request); if let Err(ref err) = cors_response { error_!("CORS Error: {}", err); let status = err.status(); @@ -61,25 +98,10 @@ impl rocket::fairing::Fairing for Cors { } fn on_response(&self, request: &Request, response: &mut rocket::Response) { - // Rebuild the response - match build_cors_response(self, request) { - Err(_) => { - // We have dealt with this already - } - Ok(cors_response) => { - cors_response.merge(response); - - // If this was an OPTIONS request and no route can be found, we should turn this - // into a HTTP 204 with no content body. - // This allows the user to not have to specify an OPTIONS route for everything. - // - // TODO: Is there anyway we can make this smarter? Only modify status codes for - // requests where an actual route exist? - if request.method() == http::Method::Options && request.route().is_none() { - response.set_status(Status::NoContent); - let _ = response.take_body(); - } - } + if let Err(err) = on_response_wrapper(self, request, response) { + error_!("Fairings on_response error: {}\nMost likely a bug", err); + response.set_status(Status::InternalServerError); + let _ = response.take_body(); } } } diff --git a/src/lib.rs b/src/lib.rs index ec1e799..4b294df 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -506,7 +506,7 @@ impl Cors { /// `Guard` type. This is useful if you want an even more ad-hoc based approach to respond to /// CORS by using a `Cors` that is not in Rocket's managed state. pub fn guard<'a, 'r>(&'a self, request: &'a Request<'r>) -> Result, Error> { - let response = build_cors_response(self, request)?; + let response = validate_and_build(self, request)?; Ok(Guard::new(response)) } @@ -693,11 +693,11 @@ impl Response { } /// Validate and create a new CORS Response from a request and settings - pub fn build_cors_response<'a, 'r>( + pub fn validate_and_build<'a, 'r>( options: &'a Cors, request: &'a Request<'r>, ) -> Result { - build_cors_response(options, request) + validate_and_build(options, request) } } @@ -747,7 +747,7 @@ impl<'a, 'r> FromRequest<'a, 'r> for Guard<'r> { } }; - match Response::build_cors_response(&options, request) { + match Response::validate_and_build(&options, request) { Ok(response) => Outcome::Success(Self::new(response)), Err(error) => Outcome::Failure((error.status(), error)), } @@ -796,35 +796,60 @@ impl<'r, R: response::Responder<'r>> response::Responder<'r> for Responder<'r, R } } -/// Validates a request for CORS and returns a CORS Response -fn build_cors_response(options: &Cors, request: &Request) -> Result { - // Existing CORS response? - // if has_allow_origin(response) { - // return Ok(()); - // } +/// Result of CORS validation. +/// +/// The variants hold enough information to build a response to the validation result +enum ValidationResult { + /// Not a CORS request + None, + /// Successful preflight request + Preflight { + origin: Origin, + headers: Option, + }, + /// Successful actual request + Request { origin: Origin }, +} +/// Validates a request for CORS and returns a CORS Response +fn validate_and_build(options: &Cors, request: &Request) -> Result { + let result = validate(options, request)?; + + Ok(match result { + ValidationResult::None => Response::new(), + ValidationResult::Preflight { origin, headers } => { + preflight_response(options, origin, headers) + } + ValidationResult::Request { origin } => actual_request_response(options, origin), + }) +} + +/// Validate a CORS request +fn validate(options: &Cors, request: &Request) -> Result { // 1. If the Origin header is not present terminate this set of steps. // The request is outside the scope of this specification. let origin = origin(request)?; let origin = match origin { None => { // Not a CORS request - return Ok(Response::new()); + return Ok(ValidationResult::None); } Some(origin) => origin, }; // Check if the request verb is an OPTION or something else - let cors_response = match request.method() { + match request.method() { http::Method::Options => { let method = request_method(request)?; let headers = request_headers(request)?; - preflight(options, origin, method, headers) + preflight_validate(options, &origin, &method, &headers)?; + Ok(ValidationResult::Preflight { origin, headers }) } - _ => actual_request(options, origin), - }?; - - Ok(cors_response) + _ => { + actual_request_validate(options, &origin)?; + Ok(ValidationResult::Request { origin }) + } + } } /// Consumes the responder and based on the provided list of allowed origins, @@ -905,35 +930,24 @@ fn request_headers(request: &Request) -> Result, - headers: Option, -) -> Result { - options.validate()?; - let response = Response::new(); + origin: &Origin, + method: &Option, + headers: &Option, +) -> Result<(), Error> { + + options.validate()?; // Fast-forward check for #7 // Note: All header parse failures are dealt with in the `FromRequest` trait implementation // 2. If the value of the Origin header is not a case-sensitive match for any of the values // in list of origins do not set any additional headers and terminate this set of steps. validate_origin(&origin, &options.allowed_origins)?; - let response = match options.allowed_origins { - AllOrSome::All => { - if options.send_wildcard { - response.any() - } else { - response.origin(origin.as_str(), true) - } - } - AllOrSome::Some(_) => response.origin(origin.as_str(), false), - }; // 3. Let `method` be the value as result of parsing the Access-Control-Request-Method // header. @@ -941,7 +955,7 @@ fn preflight( // do not set any additional headers and terminate this set of steps. // The request is outside the scope of this specification. - let method = method.ok_or_else(|| Error::MissingRequestMethod)?; + let method = method.as_ref().ok_or_else(|| Error::MissingRequestMethod)?; // 4. Let header field-names be the values as result of parsing the // Access-Control-Request-Headers headers. @@ -953,26 +967,29 @@ fn preflight( // 5. If method is not a case-sensitive match for any of the values in list of methods // do not set any additional headers and terminate this set of steps. - validate_allowed_method(&method, &options.allowed_methods)?; - let response = response.methods(&options.allowed_methods); + validate_allowed_method(method, &options.allowed_methods)?; // 6. If any of the header field-names is not a ASCII case-insensitive match for any of the // values in list of headers do not set any additional headers and terminate this set of // steps. - let response = if let Some(ref headers) = headers { + if let &Some(ref headers) = headers { validate_allowed_headers(headers, &options.allowed_headers)?; - let &AccessControlRequestHeaders(ref headers) = headers; - response.headers( - headers - .iter() - .map(|s| &**s.deref()) - .collect::>() - .as_slice(), - ) - } else { - response - }; + } + + Ok(()) +} + +/// Build a response for pre-flight checks +/// +/// This implementation references the +/// [W3C recommendation](https://www.w3.org/TR/cors/#resource-preflight-requests). +fn preflight_response( + options: &Cors, + origin: Origin, + headers: Option, +) -> Response { + let response = Response::new(); // 7. If the resource supports credentials add a single Access-Control-Allow-Origin header, // with the value of the Origin header as value, and add a @@ -983,6 +1000,16 @@ fn preflight( // Note: The string "*" cannot be used for a resource that supports credentials. // Validation has been done in options.validate + let response = match options.allowed_origins { + AllOrSome::All => { + if options.send_wildcard { + response.any() + } else { + response.origin(origin.as_str(), true) + } + } + AllOrSome::Some(_) => response.origin(origin.as_str(), false), + }; let response = response.credentials(options.allow_credentials); // 8. Optionally add a single Access-Control-Max-Age header @@ -998,7 +1025,7 @@ fn preflight( // simply returning the method indicated by Access-Control-Request-Method // (if supported) can be enough. - // Done above + let response = response.methods(&options.allowed_methods); // 10. If each of the header field-names is a simple header and none is Content-Type, // this step may be skipped. @@ -1010,17 +1037,29 @@ fn preflight( // Since the list of headers can be unbounded, simply returning supported headers // from Access-Control-Allow-Headers can be enough. - // Done above -- we do not do anything special with simple headers + // We do not do anything special with simple headers + let response = if let Some(ref headers) = headers { + let &AccessControlRequestHeaders(ref headers) = headers; + response.headers( + headers + .iter() + .map(|s| &**s.deref()) + .collect::>() + .as_slice(), + ) + } else { + response + }; - Ok(response) + response } -/// Respond to an actual request based on the settings. -/// If the `Origin` is not provided, then this request was not made by a browser and there is no -/// CORS enforcement. -fn actual_request(options: &Cors, origin: Origin) -> Result { +/// Do checks for an actual request +/// +/// This implementation references the +/// [W3C recommendation](https://www.w3.org/TR/cors/#resource-requests). +fn actual_request_validate(options: &Cors, origin: &Origin) -> Result<(), Error> { options.validate()?; - let response = Response::new(); // Note: All header parse failures are dealt with in the `FromRequest` trait implementation @@ -1029,6 +1068,27 @@ fn actual_request(options: &Cors, origin: Origin) -> Result { // Always matching is acceptable since the list of origins can be unbounded. validate_origin(&origin, &options.allowed_origins)?; + + Ok(()) +} + +/// Build the response for an actual request +/// +/// This implementation references the +/// [W3C recommendation](https://www.w3.org/TR/cors/#resource-requests). +fn actual_request_response(options: &Cors, origin: Origin) -> Response { + let response = Response::new(); + + // 3. If the resource supports credentials add a single Access-Control-Allow-Origin header, + // with the value of the Origin header as value, and add a + // single Access-Control-Allow-Credentials header with the case-sensitive string "true" as + // value. + // Otherwise, add a single Access-Control-Allow-Origin header, + // with either the value of the Origin header or the string "*" as value. + // Note: The string "*" cannot be used for a resource that supports credentials. + + // Validation has been done in options.validate + let response = match options.allowed_origins { AllOrSome::All => { if options.send_wildcard { @@ -1040,15 +1100,6 @@ fn actual_request(options: &Cors, origin: Origin) -> Result { AllOrSome::Some(_) => response.origin(origin.as_str(), false), }; - // 3. If the resource supports credentials add a single Access-Control-Allow-Origin header, - // with the value of the Origin header as value, and add a - // single Access-Control-Allow-Credentials header with the case-sensitive string "true" as - // value. - // Otherwise, add a single Access-Control-Allow-Origin header, - // with either the value of the Origin header or the string "*" as value. - // Note: The string "*" cannot be used for a resource that supports credentials. - - // Validation has been done in options.validate let response = response.credentials(options.allow_credentials); // 4. If the list of exposed headers is not empty add one or more @@ -1066,7 +1117,8 @@ fn actual_request(options: &Cors, origin: Origin) -> Result { .collect::>() .as_slice(), ); - Ok(response) + + response } #[cfg(test)] diff --git a/tests/fairing.rs b/tests/fairing.rs index 8bde6ce..d1f3c3c 100644 --- a/tests/fairing.rs +++ b/tests/fairing.rs @@ -172,6 +172,7 @@ fn cors_options_bad_origin() { assert_eq!(response.status(), Status::Forbidden); } +/// Unlike the "ad-hoc" mode, this should return 404 because we don't have such a route #[test] fn cors_options_missing_origin() { let client = Client::new(rocket()).unwrap(); @@ -188,7 +189,7 @@ fn cors_options_missing_origin() { ); let response = req.dispatch(); - assert!(response.status().class().is_success()); + assert_eq!(response.status(), Status::NotFound); } #[test]