Refactor to separate out validation from response building step 1

This commit is contained in:
Yong Wen Chua 2017-07-17 19:19:51 +08:00
parent 05b969e735
commit 41f5ac11d8
3 changed files with 166 additions and 91 deletions

View File

@ -2,7 +2,7 @@
use rocket::{self, Request, Outcome}; use rocket::{self, Request, Outcome};
use rocket::http::{self, Status}; use rocket::http::{self, Status};
use {Cors, build_cors_response}; use {Cors, Error, validate, preflight_response, actual_request_response, origin, request_headers};
/// Route for Fairing error handling /// Route for Fairing error handling
pub(crate) fn fairing_error_route<'r>( pub(crate) fn fairing_error_route<'r>(
@ -28,6 +28,43 @@ fn route_to_fairing_error_handler(options: &Cors, status: u16, request: &mut Req
request.set_uri(format!("{}/{}", options.fairing_route_base, status)); request.set_uri(format!("{}/{}", options.fairing_route_base, status));
} }
fn on_response_wrapper(
options: &Cors,
request: &Request,
response: &mut rocket::Response,
) -> Result<(), Error> {
let origin = match origin(request)? {
None => {
// Not a CORS request
return Ok(());
}
Some(origin) => origin,
};
let cors_response = if request.method() == http::Method::Options {
let headers = request_headers(request)?;
preflight_response(options, origin, headers)
} else {
actual_request_response(options, origin)
};
cors_response.merge(response);
// If this was an OPTIONS request and no route can be found, we should turn this
// into a HTTP 204 with no content body.
// This allows the user to not have to specify an OPTIONS route for everything.
//
// TODO: Is there anyway we can make this smarter? Only modify status codes for
// requests where an actual route exist?
if request.method() == http::Method::Options && request.method() == http::Method::Options &&
request.route().is_none()
{
response.set_status(Status::NoContent);
let _ = response.take_body();
}
Ok(())
}
impl rocket::fairing::Fairing for Cors { impl rocket::fairing::Fairing for Cors {
fn info(&self) -> rocket::fairing::Info { fn info(&self) -> rocket::fairing::Info {
rocket::fairing::Info { rocket::fairing::Info {
@ -52,7 +89,7 @@ impl rocket::fairing::Fairing for Cors {
fn on_request(&self, request: &mut Request, _: &rocket::Data) { fn on_request(&self, request: &mut Request, _: &rocket::Data) {
// Build and merge CORS response // Build and merge CORS response
// Type annotation is for sanity check // Type annotation is for sanity check
let cors_response = build_cors_response(self, request); let cors_response = validate(self, request);
if let Err(ref err) = cors_response { if let Err(ref err) = cors_response {
error_!("CORS Error: {}", err); error_!("CORS Error: {}", err);
let status = err.status(); let status = err.status();
@ -61,25 +98,10 @@ impl rocket::fairing::Fairing for Cors {
} }
fn on_response(&self, request: &Request, response: &mut rocket::Response) { fn on_response(&self, request: &Request, response: &mut rocket::Response) {
// Rebuild the response if let Err(err) = on_response_wrapper(self, request, response) {
match build_cors_response(self, request) { error_!("Fairings on_response error: {}\nMost likely a bug", err);
Err(_) => { response.set_status(Status::InternalServerError);
// We have dealt with this already
}
Ok(cors_response) => {
cors_response.merge(response);
// If this was an OPTIONS request and no route can be found, we should turn this
// into a HTTP 204 with no content body.
// This allows the user to not have to specify an OPTIONS route for everything.
//
// TODO: Is there anyway we can make this smarter? Only modify status codes for
// requests where an actual route exist?
if request.method() == http::Method::Options && request.route().is_none() {
response.set_status(Status::NoContent);
let _ = response.take_body(); let _ = response.take_body();
} }
} }
} }
}
}

View File

@ -506,7 +506,7 @@ impl Cors {
/// `Guard` type. This is useful if you want an even more ad-hoc based approach to respond to /// `Guard` type. This is useful if you want an even more ad-hoc based approach to respond to
/// CORS by using a `Cors` that is not in Rocket's managed state. /// CORS by using a `Cors` that is not in Rocket's managed state.
pub fn guard<'a, 'r>(&'a self, request: &'a Request<'r>) -> Result<Guard<'r>, Error> { pub fn guard<'a, 'r>(&'a self, request: &'a Request<'r>) -> Result<Guard<'r>, Error> {
let response = build_cors_response(self, request)?; let response = validate_and_build(self, request)?;
Ok(Guard::new(response)) Ok(Guard::new(response))
} }
@ -693,11 +693,11 @@ impl Response {
} }
/// Validate and create a new CORS Response from a request and settings /// Validate and create a new CORS Response from a request and settings
pub fn build_cors_response<'a, 'r>( pub fn validate_and_build<'a, 'r>(
options: &'a Cors, options: &'a Cors,
request: &'a Request<'r>, request: &'a Request<'r>,
) -> Result<Self, Error> { ) -> Result<Self, Error> {
build_cors_response(options, request) validate_and_build(options, request)
} }
} }
@ -747,7 +747,7 @@ impl<'a, 'r> FromRequest<'a, 'r> for Guard<'r> {
} }
}; };
match Response::build_cors_response(&options, request) { match Response::validate_and_build(&options, request) {
Ok(response) => Outcome::Success(Self::new(response)), Ok(response) => Outcome::Success(Self::new(response)),
Err(error) => Outcome::Failure((error.status(), error)), Err(error) => Outcome::Failure((error.status(), error)),
} }
@ -796,35 +796,60 @@ impl<'r, R: response::Responder<'r>> response::Responder<'r> for Responder<'r, R
} }
} }
/// Validates a request for CORS and returns a CORS Response /// Result of CORS validation.
fn build_cors_response(options: &Cors, request: &Request) -> Result<Response, Error> { ///
// Existing CORS response? /// The variants hold enough information to build a response to the validation result
// if has_allow_origin(response) { enum ValidationResult {
// return Ok(()); /// Not a CORS request
// } None,
/// Successful preflight request
Preflight {
origin: Origin,
headers: Option<AccessControlRequestHeaders>,
},
/// Successful actual request
Request { origin: Origin },
}
/// Validates a request for CORS and returns a CORS Response
fn validate_and_build(options: &Cors, request: &Request) -> Result<Response, Error> {
let result = validate(options, request)?;
Ok(match result {
ValidationResult::None => Response::new(),
ValidationResult::Preflight { origin, headers } => {
preflight_response(options, origin, headers)
}
ValidationResult::Request { origin } => actual_request_response(options, origin),
})
}
/// Validate a CORS request
fn validate(options: &Cors, request: &Request) -> Result<ValidationResult, Error> {
// 1. If the Origin header is not present terminate this set of steps. // 1. If the Origin header is not present terminate this set of steps.
// The request is outside the scope of this specification. // The request is outside the scope of this specification.
let origin = origin(request)?; let origin = origin(request)?;
let origin = match origin { let origin = match origin {
None => { None => {
// Not a CORS request // Not a CORS request
return Ok(Response::new()); return Ok(ValidationResult::None);
} }
Some(origin) => origin, Some(origin) => origin,
}; };
// Check if the request verb is an OPTION or something else // Check if the request verb is an OPTION or something else
let cors_response = match request.method() { match request.method() {
http::Method::Options => { http::Method::Options => {
let method = request_method(request)?; let method = request_method(request)?;
let headers = request_headers(request)?; let headers = request_headers(request)?;
preflight(options, origin, method, headers) preflight_validate(options, &origin, &method, &headers)?;
Ok(ValidationResult::Preflight { origin, headers })
}
_ => {
actual_request_validate(options, &origin)?;
Ok(ValidationResult::Request { origin })
}
} }
_ => actual_request(options, origin),
}?;
Ok(cors_response)
} }
/// Consumes the responder and based on the provided list of allowed origins, /// Consumes the responder and based on the provided list of allowed origins,
@ -905,35 +930,24 @@ fn request_headers(request: &Request) -> Result<Option<AccessControlRequestHeade
} }
} }
/// Construct a preflight response based on the options. Will return an `Err` /// Do pre-flight validation checks
/// if any of the preflight checks fail.
/// ///
/// This implementation references the /// This implementation references the
/// [W3C recommendation](https://www.w3.org/TR/cors/#resource-preflight-requests). /// [W3C recommendation](https://www.w3.org/TR/cors/#resource-preflight-requests).
fn preflight( fn preflight_validate(
options: &Cors, options: &Cors,
origin: Origin, origin: &Origin,
method: Option<AccessControlRequestMethod>, method: &Option<AccessControlRequestMethod>,
headers: Option<AccessControlRequestHeaders>, headers: &Option<AccessControlRequestHeaders>,
) -> Result<Response, Error> { ) -> Result<(), Error> {
options.validate()?;
let response = Response::new(); options.validate()?; // Fast-forward check for #7
// Note: All header parse failures are dealt with in the `FromRequest` trait implementation // Note: All header parse failures are dealt with in the `FromRequest` trait implementation
// 2. If the value of the Origin header is not a case-sensitive match for any of the values // 2. If the value of the Origin header is not a case-sensitive match for any of the values
// in list of origins do not set any additional headers and terminate this set of steps. // in list of origins do not set any additional headers and terminate this set of steps.
validate_origin(&origin, &options.allowed_origins)?; validate_origin(&origin, &options.allowed_origins)?;
let response = match options.allowed_origins {
AllOrSome::All => {
if options.send_wildcard {
response.any()
} else {
response.origin(origin.as_str(), true)
}
}
AllOrSome::Some(_) => response.origin(origin.as_str(), false),
};
// 3. Let `method` be the value as result of parsing the Access-Control-Request-Method // 3. Let `method` be the value as result of parsing the Access-Control-Request-Method
// header. // header.
@ -941,7 +955,7 @@ fn preflight(
// do not set any additional headers and terminate this set of steps. // do not set any additional headers and terminate this set of steps.
// The request is outside the scope of this specification. // The request is outside the scope of this specification.
let method = method.ok_or_else(|| Error::MissingRequestMethod)?; let method = method.as_ref().ok_or_else(|| Error::MissingRequestMethod)?;
// 4. Let header field-names be the values as result of parsing the // 4. Let header field-names be the values as result of parsing the
// Access-Control-Request-Headers headers. // Access-Control-Request-Headers headers.
@ -953,26 +967,29 @@ fn preflight(
// 5. If method is not a case-sensitive match for any of the values in list of methods // 5. If method is not a case-sensitive match for any of the values in list of methods
// do not set any additional headers and terminate this set of steps. // do not set any additional headers and terminate this set of steps.
validate_allowed_method(&method, &options.allowed_methods)?; validate_allowed_method(method, &options.allowed_methods)?;
let response = response.methods(&options.allowed_methods);
// 6. If any of the header field-names is not a ASCII case-insensitive match for any of the // 6. If any of the header field-names is not a ASCII case-insensitive match for any of the
// values in list of headers do not set any additional headers and terminate this set of // values in list of headers do not set any additional headers and terminate this set of
// steps. // steps.
let response = if let Some(ref headers) = headers { if let &Some(ref headers) = headers {
validate_allowed_headers(headers, &options.allowed_headers)?; validate_allowed_headers(headers, &options.allowed_headers)?;
let &AccessControlRequestHeaders(ref headers) = headers; }
response.headers(
headers Ok(())
.iter() }
.map(|s| &**s.deref())
.collect::<Vec<&str>>() /// Build a response for pre-flight checks
.as_slice(), ///
) /// This implementation references the
} else { /// [W3C recommendation](https://www.w3.org/TR/cors/#resource-preflight-requests).
response fn preflight_response(
}; options: &Cors,
origin: Origin,
headers: Option<AccessControlRequestHeaders>,
) -> Response {
let response = Response::new();
// 7. If the resource supports credentials add a single Access-Control-Allow-Origin header, // 7. If the resource supports credentials add a single Access-Control-Allow-Origin header,
// with the value of the Origin header as value, and add a // with the value of the Origin header as value, and add a
@ -983,6 +1000,16 @@ fn preflight(
// Note: The string "*" cannot be used for a resource that supports credentials. // Note: The string "*" cannot be used for a resource that supports credentials.
// Validation has been done in options.validate // Validation has been done in options.validate
let response = match options.allowed_origins {
AllOrSome::All => {
if options.send_wildcard {
response.any()
} else {
response.origin(origin.as_str(), true)
}
}
AllOrSome::Some(_) => response.origin(origin.as_str(), false),
};
let response = response.credentials(options.allow_credentials); let response = response.credentials(options.allow_credentials);
// 8. Optionally add a single Access-Control-Max-Age header // 8. Optionally add a single Access-Control-Max-Age header
@ -998,7 +1025,7 @@ fn preflight(
// simply returning the method indicated by Access-Control-Request-Method // simply returning the method indicated by Access-Control-Request-Method
// (if supported) can be enough. // (if supported) can be enough.
// Done above let response = response.methods(&options.allowed_methods);
// 10. If each of the header field-names is a simple header and none is Content-Type, // 10. If each of the header field-names is a simple header and none is Content-Type,
// this step may be skipped. // this step may be skipped.
@ -1010,17 +1037,29 @@ fn preflight(
// Since the list of headers can be unbounded, simply returning supported headers // Since the list of headers can be unbounded, simply returning supported headers
// from Access-Control-Allow-Headers can be enough. // from Access-Control-Allow-Headers can be enough.
// Done above -- we do not do anything special with simple headers // We do not do anything special with simple headers
let response = if let Some(ref headers) = headers {
let &AccessControlRequestHeaders(ref headers) = headers;
response.headers(
headers
.iter()
.map(|s| &**s.deref())
.collect::<Vec<&str>>()
.as_slice(),
)
} else {
response
};
Ok(response) response
} }
/// Respond to an actual request based on the settings. /// Do checks for an actual request
/// If the `Origin` is not provided, then this request was not made by a browser and there is no ///
/// CORS enforcement. /// This implementation references the
fn actual_request(options: &Cors, origin: Origin) -> Result<Response, Error> { /// [W3C recommendation](https://www.w3.org/TR/cors/#resource-requests).
fn actual_request_validate(options: &Cors, origin: &Origin) -> Result<(), Error> {
options.validate()?; options.validate()?;
let response = Response::new();
// Note: All header parse failures are dealt with in the `FromRequest` trait implementation // Note: All header parse failures are dealt with in the `FromRequest` trait implementation
@ -1029,6 +1068,27 @@ fn actual_request(options: &Cors, origin: Origin) -> Result<Response, Error> {
// Always matching is acceptable since the list of origins can be unbounded. // Always matching is acceptable since the list of origins can be unbounded.
validate_origin(&origin, &options.allowed_origins)?; validate_origin(&origin, &options.allowed_origins)?;
Ok(())
}
/// Build the response for an actual request
///
/// This implementation references the
/// [W3C recommendation](https://www.w3.org/TR/cors/#resource-requests).
fn actual_request_response(options: &Cors, origin: Origin) -> Response {
let response = Response::new();
// 3. If the resource supports credentials add a single Access-Control-Allow-Origin header,
// with the value of the Origin header as value, and add a
// single Access-Control-Allow-Credentials header with the case-sensitive string "true" as
// value.
// Otherwise, add a single Access-Control-Allow-Origin header,
// with either the value of the Origin header or the string "*" as value.
// Note: The string "*" cannot be used for a resource that supports credentials.
// Validation has been done in options.validate
let response = match options.allowed_origins { let response = match options.allowed_origins {
AllOrSome::All => { AllOrSome::All => {
if options.send_wildcard { if options.send_wildcard {
@ -1040,15 +1100,6 @@ fn actual_request(options: &Cors, origin: Origin) -> Result<Response, Error> {
AllOrSome::Some(_) => response.origin(origin.as_str(), false), AllOrSome::Some(_) => response.origin(origin.as_str(), false),
}; };
// 3. If the resource supports credentials add a single Access-Control-Allow-Origin header,
// with the value of the Origin header as value, and add a
// single Access-Control-Allow-Credentials header with the case-sensitive string "true" as
// value.
// Otherwise, add a single Access-Control-Allow-Origin header,
// with either the value of the Origin header or the string "*" as value.
// Note: The string "*" cannot be used for a resource that supports credentials.
// Validation has been done in options.validate
let response = response.credentials(options.allow_credentials); let response = response.credentials(options.allow_credentials);
// 4. If the list of exposed headers is not empty add one or more // 4. If the list of exposed headers is not empty add one or more
@ -1066,7 +1117,8 @@ fn actual_request(options: &Cors, origin: Origin) -> Result<Response, Error> {
.collect::<Vec<&str>>() .collect::<Vec<&str>>()
.as_slice(), .as_slice(),
); );
Ok(response)
response
} }
#[cfg(test)] #[cfg(test)]

View File

@ -172,6 +172,7 @@ fn cors_options_bad_origin() {
assert_eq!(response.status(), Status::Forbidden); assert_eq!(response.status(), Status::Forbidden);
} }
/// Unlike the "ad-hoc" mode, this should return 404 because we don't have such a route
#[test] #[test]
fn cors_options_missing_origin() { fn cors_options_missing_origin() {
let client = Client::new(rocket()).unwrap(); let client = Client::new(rocket()).unwrap();
@ -188,7 +189,7 @@ fn cors_options_missing_origin() {
); );
let response = req.dispatch(); let response = req.dispatch();
assert!(response.status().class().is_success()); assert_eq!(response.status(), Status::NotFound);
} }
#[test] #[test]