Refactor to separate out validation from response building step 1
This commit is contained in:
parent
05b969e735
commit
41f5ac11d8
|
@ -2,7 +2,7 @@
|
|||
use rocket::{self, Request, Outcome};
|
||||
use rocket::http::{self, Status};
|
||||
|
||||
use {Cors, build_cors_response};
|
||||
use {Cors, Error, validate, preflight_response, actual_request_response, origin, request_headers};
|
||||
|
||||
/// Route for Fairing error handling
|
||||
pub(crate) fn fairing_error_route<'r>(
|
||||
|
@ -28,6 +28,43 @@ fn route_to_fairing_error_handler(options: &Cors, status: u16, request: &mut Req
|
|||
request.set_uri(format!("{}/{}", options.fairing_route_base, status));
|
||||
}
|
||||
|
||||
fn on_response_wrapper(
|
||||
options: &Cors,
|
||||
request: &Request,
|
||||
response: &mut rocket::Response,
|
||||
) -> Result<(), Error> {
|
||||
let origin = match origin(request)? {
|
||||
None => {
|
||||
// Not a CORS request
|
||||
return Ok(());
|
||||
}
|
||||
Some(origin) => origin,
|
||||
};
|
||||
|
||||
let cors_response = if request.method() == http::Method::Options {
|
||||
let headers = request_headers(request)?;
|
||||
preflight_response(options, origin, headers)
|
||||
} else {
|
||||
actual_request_response(options, origin)
|
||||
};
|
||||
|
||||
cors_response.merge(response);
|
||||
|
||||
// If this was an OPTIONS request and no route can be found, we should turn this
|
||||
// into a HTTP 204 with no content body.
|
||||
// This allows the user to not have to specify an OPTIONS route for everything.
|
||||
//
|
||||
// TODO: Is there anyway we can make this smarter? Only modify status codes for
|
||||
// requests where an actual route exist?
|
||||
if request.method() == http::Method::Options && request.method() == http::Method::Options &&
|
||||
request.route().is_none()
|
||||
{
|
||||
response.set_status(Status::NoContent);
|
||||
let _ = response.take_body();
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
impl rocket::fairing::Fairing for Cors {
|
||||
fn info(&self) -> rocket::fairing::Info {
|
||||
rocket::fairing::Info {
|
||||
|
@ -52,7 +89,7 @@ impl rocket::fairing::Fairing for Cors {
|
|||
fn on_request(&self, request: &mut Request, _: &rocket::Data) {
|
||||
// Build and merge CORS response
|
||||
// Type annotation is for sanity check
|
||||
let cors_response = build_cors_response(self, request);
|
||||
let cors_response = validate(self, request);
|
||||
if let Err(ref err) = cors_response {
|
||||
error_!("CORS Error: {}", err);
|
||||
let status = err.status();
|
||||
|
@ -61,25 +98,10 @@ impl rocket::fairing::Fairing for Cors {
|
|||
}
|
||||
|
||||
fn on_response(&self, request: &Request, response: &mut rocket::Response) {
|
||||
// Rebuild the response
|
||||
match build_cors_response(self, request) {
|
||||
Err(_) => {
|
||||
// We have dealt with this already
|
||||
}
|
||||
Ok(cors_response) => {
|
||||
cors_response.merge(response);
|
||||
|
||||
// If this was an OPTIONS request and no route can be found, we should turn this
|
||||
// into a HTTP 204 with no content body.
|
||||
// This allows the user to not have to specify an OPTIONS route for everything.
|
||||
//
|
||||
// TODO: Is there anyway we can make this smarter? Only modify status codes for
|
||||
// requests where an actual route exist?
|
||||
if request.method() == http::Method::Options && request.route().is_none() {
|
||||
response.set_status(Status::NoContent);
|
||||
let _ = response.take_body();
|
||||
}
|
||||
}
|
||||
if let Err(err) = on_response_wrapper(self, request, response) {
|
||||
error_!("Fairings on_response error: {}\nMost likely a bug", err);
|
||||
response.set_status(Status::InternalServerError);
|
||||
let _ = response.take_body();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
190
src/lib.rs
190
src/lib.rs
|
@ -506,7 +506,7 @@ impl Cors {
|
|||
/// `Guard` type. This is useful if you want an even more ad-hoc based approach to respond to
|
||||
/// CORS by using a `Cors` that is not in Rocket's managed state.
|
||||
pub fn guard<'a, 'r>(&'a self, request: &'a Request<'r>) -> Result<Guard<'r>, Error> {
|
||||
let response = build_cors_response(self, request)?;
|
||||
let response = validate_and_build(self, request)?;
|
||||
Ok(Guard::new(response))
|
||||
}
|
||||
|
||||
|
@ -693,11 +693,11 @@ impl Response {
|
|||
}
|
||||
|
||||
/// Validate and create a new CORS Response from a request and settings
|
||||
pub fn build_cors_response<'a, 'r>(
|
||||
pub fn validate_and_build<'a, 'r>(
|
||||
options: &'a Cors,
|
||||
request: &'a Request<'r>,
|
||||
) -> Result<Self, Error> {
|
||||
build_cors_response(options, request)
|
||||
validate_and_build(options, request)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -747,7 +747,7 @@ impl<'a, 'r> FromRequest<'a, 'r> for Guard<'r> {
|
|||
}
|
||||
};
|
||||
|
||||
match Response::build_cors_response(&options, request) {
|
||||
match Response::validate_and_build(&options, request) {
|
||||
Ok(response) => Outcome::Success(Self::new(response)),
|
||||
Err(error) => Outcome::Failure((error.status(), error)),
|
||||
}
|
||||
|
@ -796,35 +796,60 @@ impl<'r, R: response::Responder<'r>> response::Responder<'r> for Responder<'r, R
|
|||
}
|
||||
}
|
||||
|
||||
/// Validates a request for CORS and returns a CORS Response
|
||||
fn build_cors_response(options: &Cors, request: &Request) -> Result<Response, Error> {
|
||||
// Existing CORS response?
|
||||
// if has_allow_origin(response) {
|
||||
// return Ok(());
|
||||
// }
|
||||
/// Result of CORS validation.
|
||||
///
|
||||
/// The variants hold enough information to build a response to the validation result
|
||||
enum ValidationResult {
|
||||
/// Not a CORS request
|
||||
None,
|
||||
/// Successful preflight request
|
||||
Preflight {
|
||||
origin: Origin,
|
||||
headers: Option<AccessControlRequestHeaders>,
|
||||
},
|
||||
/// Successful actual request
|
||||
Request { origin: Origin },
|
||||
}
|
||||
|
||||
/// Validates a request for CORS and returns a CORS Response
|
||||
fn validate_and_build(options: &Cors, request: &Request) -> Result<Response, Error> {
|
||||
let result = validate(options, request)?;
|
||||
|
||||
Ok(match result {
|
||||
ValidationResult::None => Response::new(),
|
||||
ValidationResult::Preflight { origin, headers } => {
|
||||
preflight_response(options, origin, headers)
|
||||
}
|
||||
ValidationResult::Request { origin } => actual_request_response(options, origin),
|
||||
})
|
||||
}
|
||||
|
||||
/// Validate a CORS request
|
||||
fn validate(options: &Cors, request: &Request) -> Result<ValidationResult, Error> {
|
||||
// 1. If the Origin header is not present terminate this set of steps.
|
||||
// The request is outside the scope of this specification.
|
||||
let origin = origin(request)?;
|
||||
let origin = match origin {
|
||||
None => {
|
||||
// Not a CORS request
|
||||
return Ok(Response::new());
|
||||
return Ok(ValidationResult::None);
|
||||
}
|
||||
Some(origin) => origin,
|
||||
};
|
||||
|
||||
// Check if the request verb is an OPTION or something else
|
||||
let cors_response = match request.method() {
|
||||
match request.method() {
|
||||
http::Method::Options => {
|
||||
let method = request_method(request)?;
|
||||
let headers = request_headers(request)?;
|
||||
preflight(options, origin, method, headers)
|
||||
preflight_validate(options, &origin, &method, &headers)?;
|
||||
Ok(ValidationResult::Preflight { origin, headers })
|
||||
}
|
||||
_ => actual_request(options, origin),
|
||||
}?;
|
||||
|
||||
Ok(cors_response)
|
||||
_ => {
|
||||
actual_request_validate(options, &origin)?;
|
||||
Ok(ValidationResult::Request { origin })
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Consumes the responder and based on the provided list of allowed origins,
|
||||
|
@ -905,35 +930,24 @@ fn request_headers(request: &Request) -> Result<Option<AccessControlRequestHeade
|
|||
}
|
||||
}
|
||||
|
||||
/// Construct a preflight response based on the options. Will return an `Err`
|
||||
/// if any of the preflight checks fail.
|
||||
/// Do pre-flight validation checks
|
||||
///
|
||||
/// This implementation references the
|
||||
/// [W3C recommendation](https://www.w3.org/TR/cors/#resource-preflight-requests).
|
||||
fn preflight(
|
||||
fn preflight_validate(
|
||||
options: &Cors,
|
||||
origin: Origin,
|
||||
method: Option<AccessControlRequestMethod>,
|
||||
headers: Option<AccessControlRequestHeaders>,
|
||||
) -> Result<Response, Error> {
|
||||
options.validate()?;
|
||||
let response = Response::new();
|
||||
origin: &Origin,
|
||||
method: &Option<AccessControlRequestMethod>,
|
||||
headers: &Option<AccessControlRequestHeaders>,
|
||||
) -> Result<(), Error> {
|
||||
|
||||
options.validate()?; // Fast-forward check for #7
|
||||
|
||||
// Note: All header parse failures are dealt with in the `FromRequest` trait implementation
|
||||
|
||||
// 2. If the value of the Origin header is not a case-sensitive match for any of the values
|
||||
// in list of origins do not set any additional headers and terminate this set of steps.
|
||||
validate_origin(&origin, &options.allowed_origins)?;
|
||||
let response = match options.allowed_origins {
|
||||
AllOrSome::All => {
|
||||
if options.send_wildcard {
|
||||
response.any()
|
||||
} else {
|
||||
response.origin(origin.as_str(), true)
|
||||
}
|
||||
}
|
||||
AllOrSome::Some(_) => response.origin(origin.as_str(), false),
|
||||
};
|
||||
|
||||
// 3. Let `method` be the value as result of parsing the Access-Control-Request-Method
|
||||
// header.
|
||||
|
@ -941,7 +955,7 @@ fn preflight(
|
|||
// do not set any additional headers and terminate this set of steps.
|
||||
// The request is outside the scope of this specification.
|
||||
|
||||
let method = method.ok_or_else(|| Error::MissingRequestMethod)?;
|
||||
let method = method.as_ref().ok_or_else(|| Error::MissingRequestMethod)?;
|
||||
|
||||
// 4. Let header field-names be the values as result of parsing the
|
||||
// Access-Control-Request-Headers headers.
|
||||
|
@ -953,26 +967,29 @@ fn preflight(
|
|||
// 5. If method is not a case-sensitive match for any of the values in list of methods
|
||||
// do not set any additional headers and terminate this set of steps.
|
||||
|
||||
validate_allowed_method(&method, &options.allowed_methods)?;
|
||||
let response = response.methods(&options.allowed_methods);
|
||||
validate_allowed_method(method, &options.allowed_methods)?;
|
||||
|
||||
// 6. If any of the header field-names is not a ASCII case-insensitive match for any of the
|
||||
// values in list of headers do not set any additional headers and terminate this set of
|
||||
// steps.
|
||||
|
||||
let response = if let Some(ref headers) = headers {
|
||||
if let &Some(ref headers) = headers {
|
||||
validate_allowed_headers(headers, &options.allowed_headers)?;
|
||||
let &AccessControlRequestHeaders(ref headers) = headers;
|
||||
response.headers(
|
||||
headers
|
||||
.iter()
|
||||
.map(|s| &**s.deref())
|
||||
.collect::<Vec<&str>>()
|
||||
.as_slice(),
|
||||
)
|
||||
} else {
|
||||
response
|
||||
};
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Build a response for pre-flight checks
|
||||
///
|
||||
/// This implementation references the
|
||||
/// [W3C recommendation](https://www.w3.org/TR/cors/#resource-preflight-requests).
|
||||
fn preflight_response(
|
||||
options: &Cors,
|
||||
origin: Origin,
|
||||
headers: Option<AccessControlRequestHeaders>,
|
||||
) -> Response {
|
||||
let response = Response::new();
|
||||
|
||||
// 7. If the resource supports credentials add a single Access-Control-Allow-Origin header,
|
||||
// with the value of the Origin header as value, and add a
|
||||
|
@ -983,6 +1000,16 @@ fn preflight(
|
|||
// Note: The string "*" cannot be used for a resource that supports credentials.
|
||||
|
||||
// Validation has been done in options.validate
|
||||
let response = match options.allowed_origins {
|
||||
AllOrSome::All => {
|
||||
if options.send_wildcard {
|
||||
response.any()
|
||||
} else {
|
||||
response.origin(origin.as_str(), true)
|
||||
}
|
||||
}
|
||||
AllOrSome::Some(_) => response.origin(origin.as_str(), false),
|
||||
};
|
||||
let response = response.credentials(options.allow_credentials);
|
||||
|
||||
// 8. Optionally add a single Access-Control-Max-Age header
|
||||
|
@ -998,7 +1025,7 @@ fn preflight(
|
|||
// simply returning the method indicated by Access-Control-Request-Method
|
||||
// (if supported) can be enough.
|
||||
|
||||
// Done above
|
||||
let response = response.methods(&options.allowed_methods);
|
||||
|
||||
// 10. If each of the header field-names is a simple header and none is Content-Type,
|
||||
// this step may be skipped.
|
||||
|
@ -1010,17 +1037,29 @@ fn preflight(
|
|||
// Since the list of headers can be unbounded, simply returning supported headers
|
||||
// from Access-Control-Allow-Headers can be enough.
|
||||
|
||||
// Done above -- we do not do anything special with simple headers
|
||||
// We do not do anything special with simple headers
|
||||
let response = if let Some(ref headers) = headers {
|
||||
let &AccessControlRequestHeaders(ref headers) = headers;
|
||||
response.headers(
|
||||
headers
|
||||
.iter()
|
||||
.map(|s| &**s.deref())
|
||||
.collect::<Vec<&str>>()
|
||||
.as_slice(),
|
||||
)
|
||||
} else {
|
||||
response
|
||||
};
|
||||
|
||||
Ok(response)
|
||||
response
|
||||
}
|
||||
|
||||
/// Respond to an actual request based on the settings.
|
||||
/// If the `Origin` is not provided, then this request was not made by a browser and there is no
|
||||
/// CORS enforcement.
|
||||
fn actual_request(options: &Cors, origin: Origin) -> Result<Response, Error> {
|
||||
/// Do checks for an actual request
|
||||
///
|
||||
/// This implementation references the
|
||||
/// [W3C recommendation](https://www.w3.org/TR/cors/#resource-requests).
|
||||
fn actual_request_validate(options: &Cors, origin: &Origin) -> Result<(), Error> {
|
||||
options.validate()?;
|
||||
let response = Response::new();
|
||||
|
||||
// Note: All header parse failures are dealt with in the `FromRequest` trait implementation
|
||||
|
||||
|
@ -1029,6 +1068,27 @@ fn actual_request(options: &Cors, origin: Origin) -> Result<Response, Error> {
|
|||
// Always matching is acceptable since the list of origins can be unbounded.
|
||||
|
||||
validate_origin(&origin, &options.allowed_origins)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Build the response for an actual request
|
||||
///
|
||||
/// This implementation references the
|
||||
/// [W3C recommendation](https://www.w3.org/TR/cors/#resource-requests).
|
||||
fn actual_request_response(options: &Cors, origin: Origin) -> Response {
|
||||
let response = Response::new();
|
||||
|
||||
// 3. If the resource supports credentials add a single Access-Control-Allow-Origin header,
|
||||
// with the value of the Origin header as value, and add a
|
||||
// single Access-Control-Allow-Credentials header with the case-sensitive string "true" as
|
||||
// value.
|
||||
// Otherwise, add a single Access-Control-Allow-Origin header,
|
||||
// with either the value of the Origin header or the string "*" as value.
|
||||
// Note: The string "*" cannot be used for a resource that supports credentials.
|
||||
|
||||
// Validation has been done in options.validate
|
||||
|
||||
let response = match options.allowed_origins {
|
||||
AllOrSome::All => {
|
||||
if options.send_wildcard {
|
||||
|
@ -1040,15 +1100,6 @@ fn actual_request(options: &Cors, origin: Origin) -> Result<Response, Error> {
|
|||
AllOrSome::Some(_) => response.origin(origin.as_str(), false),
|
||||
};
|
||||
|
||||
// 3. If the resource supports credentials add a single Access-Control-Allow-Origin header,
|
||||
// with the value of the Origin header as value, and add a
|
||||
// single Access-Control-Allow-Credentials header with the case-sensitive string "true" as
|
||||
// value.
|
||||
// Otherwise, add a single Access-Control-Allow-Origin header,
|
||||
// with either the value of the Origin header or the string "*" as value.
|
||||
// Note: The string "*" cannot be used for a resource that supports credentials.
|
||||
|
||||
// Validation has been done in options.validate
|
||||
let response = response.credentials(options.allow_credentials);
|
||||
|
||||
// 4. If the list of exposed headers is not empty add one or more
|
||||
|
@ -1066,7 +1117,8 @@ fn actual_request(options: &Cors, origin: Origin) -> Result<Response, Error> {
|
|||
.collect::<Vec<&str>>()
|
||||
.as_slice(),
|
||||
);
|
||||
Ok(response)
|
||||
|
||||
response
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
|
|
@ -172,6 +172,7 @@ fn cors_options_bad_origin() {
|
|||
assert_eq!(response.status(), Status::Forbidden);
|
||||
}
|
||||
|
||||
/// Unlike the "ad-hoc" mode, this should return 404 because we don't have such a route
|
||||
#[test]
|
||||
fn cors_options_missing_origin() {
|
||||
let client = Client::new(rocket()).unwrap();
|
||||
|
@ -188,7 +189,7 @@ fn cors_options_missing_origin() {
|
|||
);
|
||||
|
||||
let response = req.dispatch();
|
||||
assert!(response.status().class().is_success());
|
||||
assert_eq!(response.status(), Status::NotFound);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
Loading…
Reference in New Issue