Response no longer validates

This commit is contained in:
Yong Wen Chua 2017-07-16 21:31:13 +08:00
parent abda99d71f
commit 6746e835e7
2 changed files with 178 additions and 382 deletions

View File

@ -437,7 +437,7 @@ impl fairing::Fairing for Cors {
use rocket::response::Responder;
// Build and merge CORS response
match build_cors_response(self, request, response) {
match build_cors_response(self, request) {
Err(err) => {
// CORS error -- overwrite the original response
let error_response = match err.respond_to(request) {
@ -451,7 +451,9 @@ impl fairing::Fairing for Cors {
};
response.merge(error_response)
}
Ok(()) => {
Ok(cors_response) => {
cors_response.merge(response);
// If this was an OPTIONS request and no route can be found, we should turn this
// into a HTTP 204 with no content body.
// This allows the user to not have to specify an OPTIONS route for everything.
@ -508,10 +510,14 @@ impl<'a, 'r: 'a, R: response::Responder<'r>> Responder<'a, 'r, R> {
fn respond(self, request: &Request) -> response::Result<'r> {
let mut response = self.responder.respond_to(request)?; // handle status errors?
match build_cors_response(self.options, request, &mut response) {
Ok(()) => Ok(response),
match build_cors_response(self.options, request) {
Ok(cors_response) => {
cors_response.merge(&mut response);
Ok(response)
},
Err(e) => response::Responder::respond_to(e, request),
}
}
}
@ -569,48 +575,10 @@ impl Response {
self
}
/// Consumes the responder and based on the provided list of allowed origins,
/// check if the requested origin is allowed.
/// Useful for pre-flight and during requests
fn allowed_origin(
self,
origin: &Origin,
allowed_origins: &AllOrSome<HashSet<Url>>,
send_wildcard: bool,
) -> Result<Self, Error> {
let origin = origin.origin().unicode_serialization();
match *allowed_origins {
// Always matching is acceptable since the list of origins can be unbounded.
AllOrSome::All => {
if send_wildcard {
Ok(self.any())
} else {
Ok(self.origin(&origin, true))
}
}
AllOrSome::Some(ref allowed_origins) => {
let allowed_origins: HashSet<_> = allowed_origins
.iter()
.map(|o| o.origin().unicode_serialization())
.collect();
let _ = allowed_origins.get(&origin).ok_or_else(
|| Error::OriginNotAllowed,
)?;
Ok(self.origin(&origin, false))
}
}
}
/// Consumes the Response and validate whether credentials can be allowed
fn credentials(mut self, value: bool) -> Result<Self, Error> {
if value {
if let Some(AllOrSome::All) = self.allow_origin {
Err(Error::CredentialsWithWildcardOrigin)?;
}
}
/// Consumes the Response and set credentials
fn credentials(mut self, value: bool) -> Self {
self.allow_credentials = value;
Ok(self)
self
}
/// Consumes the CORS, set expose_headers to
@ -634,22 +602,6 @@ impl Response {
self
}
/// Consumes the CORS, check if requested method is allowed.
/// Useful for pre-flight checks
fn allowed_methods(
self,
method: &AccessControlRequestMethod,
allowed_methods: &HashSet<Method>,
) -> Result<Self, Error> {
let &AccessControlRequestMethod(ref request_method) = method;
if !allowed_methods.iter().any(|m| m == request_method) {
Err(Error::MethodNotAllowed)?
}
// TODO: Subset to route? Or just the method requested for?
Ok(self.methods(&allowed_methods))
}
/// Consumes the CORS, set allow_headers to
/// passed headers and returns changed CORS
fn headers(mut self, headers: &[&str]) -> Self {
@ -657,35 +609,6 @@ impl Response {
self
}
/// Consumes the CORS, check if requested headers are allowed.
/// Useful for pre-flight checks
fn allowed_headers(
self,
headers: &AccessControlRequestHeaders,
allowed_headers: &AllOrSome<HashSet<HeaderFieldName>>,
) -> Result<Self, Error> {
let &AccessControlRequestHeaders(ref headers) = headers;
match *allowed_headers {
AllOrSome::All => {}
AllOrSome::Some(ref allowed_headers) => {
if !headers.is_empty() && !headers.is_subset(allowed_headers) {
Err(Error::HeadersNotAllowed)?
}
}
};
Ok(
self.headers(
headers
.iter()
.map(|s| &**s.deref())
.collect::<Vec<&str>>()
.as_slice(),
),
)
}
/// Builds a `rocket::Response` from this struct based off some base `rocket::Response`
///
/// This will overwrite any existing CORS headers
@ -788,12 +711,11 @@ pub fn respond<'a, 'r: 'a, R: response::Responder<'r>>(
fn build_cors_response(
options: &Cors,
request: &Request,
response: &mut rocket::Response,
) -> Result<(), Error> {
) -> Result<Response, Error> {
// Existing CORS response?
if has_allow_origin(response) {
return Ok(());
}
// if has_allow_origin(response) {
// return Ok(());
// }
// 1. If the Origin header is not present terminate this set of steps.
// The request is outside the scope of this specification.
@ -801,7 +723,7 @@ fn build_cors_response(
let origin = match origin {
None => {
// Not a CORS request
return Ok(());
return Ok(Response::new());
}
Some(origin) => origin,
};
@ -816,10 +738,60 @@ fn build_cors_response(
_ => actual_request(options, origin),
}?;
cors_response.merge(response);
Ok(cors_response)
}
/// Consumes the responder and based on the provided list of allowed origins,
/// check if the requested origin is allowed.
/// Useful for pre-flight and during requests
fn validate_origin(
origin: &Origin,
allowed_origins: &AllOrSome<HashSet<Url>>,
) -> Result<(), Error> {
match *allowed_origins {
// Always matching is acceptable since the list of origins can be unbounded.
AllOrSome::All => Ok(()),
AllOrSome::Some(ref allowed_origins) => {
allowed_origins
.get(origin)
.and_then(|_| Some(()))
.ok_or_else(|| Error::OriginNotAllowed)
}
}
}
/// Validate allowed methods
fn validate_allowed_method(
method: &AccessControlRequestMethod,
allowed_methods: &HashSet<Method>,
) -> Result<(), Error> {
let &AccessControlRequestMethod(ref request_method) = method;
if !allowed_methods.iter().any(|m| m == request_method) {
Err(Error::MethodNotAllowed)?
}
// TODO: Subset to route? Or just the method requested for?
Ok(())
}
/// Validate allowed headers
fn validate_allowed_headers(
headers: &AccessControlRequestHeaders,
allowed_headers: &AllOrSome<HashSet<HeaderFieldName>>,
) -> Result<(), Error> {
let &AccessControlRequestHeaders(ref headers) = headers;
match *allowed_headers {
AllOrSome::All => Ok(()),
AllOrSome::Some(ref allowed_headers) => {
if !headers.is_empty() && !headers.is_subset(allowed_headers) {
Err(Error::HeadersNotAllowed)?
}
Ok(())
}
}
}
/// Gets the `Origin` request header from the request
fn origin(request: &Request) -> Result<Option<Origin>, Error> {
match Origin::from_request(request) {
@ -863,18 +835,24 @@ fn preflight(
method: Option<AccessControlRequestMethod>,
headers: Option<AccessControlRequestHeaders>,
) -> Result<Response, Error> {
options.validate()?;
let response = Response::new();
// Note: All header parse failures are dealt with in the `FromRequest` trait implementation
// 2. If the value of the Origin header is not a case-sensitive match for any of the values
// in list of origins do not set any additional headers and terminate this set of steps.
let response = response.allowed_origin(
&origin,
&options.allowed_origins,
options.send_wildcard,
)?;
validate_origin(&origin, &options.allowed_origins)?;
let response = match options.allowed_origins {
AllOrSome::All => {
if options.send_wildcard {
response.any()
} else {
response.origin(origin.as_str(), true)
}
}
AllOrSome::Some(_) => response.origin(origin.as_str(), false),
};
// 3. Let `method` be the value as result of parsing the Access-Control-Request-Method
// header.
@ -894,13 +872,23 @@ fn preflight(
// 5. If method is not a case-sensitive match for any of the values in list of methods
// do not set any additional headers and terminate this set of steps.
let response = response.allowed_methods(&method, &options.allowed_methods)?;
validate_allowed_method(&method, &options.allowed_methods)?;
let response = response.methods(&options.allowed_methods);
// 6. If any of the header field-names is not a ASCII case-insensitive match for any of the
// values in list of headers do not set any additional headers and terminate this set of
// steps.
let response = if let Some(headers) = headers {
response.allowed_headers(&headers, &options.allowed_headers)?
let response = if let Some(ref headers) = headers {
validate_allowed_headers(headers, &options.allowed_headers)?;
let &AccessControlRequestHeaders(ref headers) = headers;
response.headers(
headers
.iter()
.map(|s| &**s.deref())
.collect::<Vec<&str>>()
.as_slice(),
)
} else {
response
};
@ -913,7 +901,8 @@ fn preflight(
// with either the value of the Origin header or the string "*" as value.
// Note: The string "*" cannot be used for a resource that supports credentials.
let response = response.credentials(options.allow_credentials)?;
// Validation has been done in options.validate
let response = response.credentials(options.allow_credentials);
// 8. Optionally add a single Access-Control-Max-Age header
// with as value the amount of seconds the user agent is allowed to cache the result of the
@ -949,6 +938,7 @@ fn preflight(
/// If the `Origin` is not provided, then this request was not made by a browser and there is no
/// CORS enforcement.
fn actual_request(options: &Cors, origin: Origin) -> Result<Response, Error> {
options.validate()?;
let response = Response::new();
// Note: All header parse failures are dealt with in the `FromRequest` trait implementation
@ -957,11 +947,17 @@ fn actual_request(options: &Cors, origin: Origin) -> Result<Response, Error> {
// in list of origins, do not set any additional headers and terminate this set of steps.
// Always matching is acceptable since the list of origins can be unbounded.
let response = response.allowed_origin(
&origin,
&options.allowed_origins,
options.send_wildcard,
)?;
validate_origin(&origin, &options.allowed_origins)?;
let response = match options.allowed_origins {
AllOrSome::All => {
if options.send_wildcard {
response.any()
} else {
response.origin(origin.as_str(), true)
}
}
AllOrSome::Some(_) => response.origin(origin.as_str(), false),
};
// 3. If the resource supports credentials add a single Access-Control-Allow-Origin header,
// with the value of the Origin header as value, and add a
@ -971,7 +967,8 @@ fn actual_request(options: &Cors, origin: Origin) -> Result<Response, Error> {
// with either the value of the Origin header or the string "*" as value.
// Note: The string "*" cannot be used for a resource that supports credentials.
let response = response.credentials(options.allow_credentials)?;
// Validation has been done in options.validate
let response = response.credentials(options.allow_credentials);
// 4. If the list of exposed headers is not empty add one or more
// Access-Control-Expose-Headers headers, with as values the header field names given in
@ -1033,65 +1030,15 @@ mod tests {
cors.validate().unwrap();
}
// The following tests check `Response`'s validation
// The following tests check validation
#[test]
fn response_allows_all_origin_with_wildcard() {
fn validate_origin_allows_all_origins() {
let url = "https://www.example.com";
let origin = Origin::from_str(url).unwrap();
let allowed_origins = AllOrSome::All;
let send_wildcard = true;
let response = Response::new();
let response = not_err!(response.allowed_origin(
&origin,
&allowed_origins,
send_wildcard,
));
assert_matches!(response.allow_origin, Some(AllOrSome::All));
assert_eq!(response.vary_origin, false);
// Build response and check built response header
let expected_header = vec!["*"];
let response = response.build(response::Response::new());
let actual_header: Vec<_> = response
.headers()
.get("Access-Control-Allow-Origin")
.collect();
assert_eq!(expected_header, actual_header);
}
#[test]
fn response_allows_all_origin_with_echoed_domain() {
let url = "https://www.example.com";
let origin = Origin::from_str(url).unwrap();
let allowed_origins = AllOrSome::All;
let send_wildcard = false;
let response = Response::new();
let response = not_err!(response.allowed_origin(
&origin,
&allowed_origins,
send_wildcard,
));
let actual_origin = assert_matches!(
response.allow_origin,
Some(AllOrSome::Some(ref origin)),
origin
);
assert_eq!(url, actual_origin);
assert_eq!(response.vary_origin, true);
// Build response and check built response header
let expected_header = vec![url];
let response = response.build(response::Response::new());
let actual_header: Vec<_> = response
.headers()
.get("Access-Control-Allow-Origin")
.collect();
assert_eq!(expected_header, actual_header);
not_err!(validate_origin(&origin, &allowed_origins));
}
#[test]
@ -1101,32 +1048,8 @@ mod tests {
let (allowed_origins, failed_origins) =
AllOrSome::new_from_str_list(&["https://www.example.com"]);
assert!(failed_origins.is_empty());
let send_wildcard = false;
let response = Response::new();
let response = not_err!(response.allowed_origin(
&origin,
&allowed_origins,
send_wildcard,
));
let actual_origin = assert_matches!(
response.allow_origin,
Some(AllOrSome::Some(ref origin)),
origin
);
assert_eq!(url, actual_origin);
assert_eq!(response.vary_origin, false);
// Build response and check built response header
let expected_header = vec![url];
let response = response.build(response::Response::new());
let actual_header: Vec<_> = response
.headers()
.get("Access-Control-Allow-Origin")
.collect();
assert_eq!(expected_header, actual_header);
not_err!(validate_origin(&origin, &allowed_origins));
}
#[test]
@ -1137,41 +1060,8 @@ mod tests {
let (allowed_origins, failed_origins) =
AllOrSome::new_from_str_list(&["https://www.example.com"]);
assert!(failed_origins.is_empty());
let send_wildcard = false;
let response = Response::new();
let _ = response
.allowed_origin(&origin, &allowed_origins, send_wildcard)
.unwrap();
}
#[test]
#[should_panic(expected = "CredentialsWithWildcardOrigin")]
fn response_credentials_does_not_allow_wildcard_with_all_origins() {
let response = Response::new();
let response = response.any();
let _ = response.credentials(true).unwrap();
}
#[test]
fn response_credentials_allows_specific_origins() {
let response = Response::new();
let response = response.origin("https://www.example.com", false);
let response = response.credentials(true).expect(
"to allow specific origins",
);
assert_eq!(response.allow_credentials, true);
// Build response and check built response header
let expected_header = vec!["true"];
let response = response.build(response::Response::new());
let actual_header: Vec<_> = response
.headers()
.get("Access-Control-Allow-Credentials")
.collect();
assert_eq!(expected_header, actual_header);
validate_origin(&origin, &allowed_origins).unwrap();
}
#[test]
@ -1220,114 +1110,63 @@ mod tests {
// Build response and check built response header
let response = response.build(response::Response::new());
assert!(response
assert!(
response
.headers()
.get("Access-Control-Max-Age")
.next().is_none())
}
/// When all headers are allowed, tests that the requested headers are echoed back
#[test]
fn response_allowed_headers_echoes_back_requested_headers() {
let allowed_headers = AllOrSome::All;
let requested_headers = vec!["Bar", "Foo"];
let response = Response::new();
let response = response.origin("https://www.example.com", false);
let response = response
.allowed_headers(
&FromStr::from_str(&requested_headers.join(",")).unwrap(),
&allowed_headers,
.next()
.is_none()
)
.expect("to not fail");
// Build response and check built response header
let response = response.build(response::Response::new());
let actual_header: Vec<_> = response
.headers()
.get("Access-Control-Allow-Headers")
.collect();
assert_eq!(1, actual_header.len());
let mut actual_headers: Vec<String> = actual_header[0]
.split(',')
.map(|header| header.trim().to_string())
.collect();
actual_headers.sort();
assert_eq!(requested_headers, actual_headers);
}
#[test]
fn response_allowed_methods_sets_headers_properly() {
let allowed_methods = vec![
Method::Get,
Method::Head,
Method::Post,
].into_iter()
fn allowed_methods_validated_correctly() {
let allowed_methods = vec![Method::Get, Method::Head, Method::Post]
.into_iter()
.collect();
let method = "GET";
let response = Response::new();
let response = response.origin("https://www.example.com", false);
let response = response
.allowed_methods(
not_err!(validate_allowed_method(
&FromStr::from_str(method).expect("not to fail"),
&allowed_methods,
)
.expect("not to fail");
// Build response and check built response header
let response = response.build(response::Response::new());
let actual_header: Vec<_> = response
.headers()
.get("Access-Control-Allow-Methods")
.collect();
assert_eq!(1, actual_header.len());
let mut actual_headers: Vec<String> = actual_header[0]
.split(',')
.map(|header| header.trim().to_string())
.collect();
actual_headers.sort();
let mut expected_headers: Vec<_> = allowed_methods.iter().map(|m| m.as_str()).collect();
expected_headers.sort();
assert_eq!(expected_headers, actual_headers);
));
}
#[test]
#[should_panic(expected = "MethodNotAllowed")]
fn response_allowed_method_errors_on_disallowed_method() {
let allowed_methods = vec![
Method::Get,
Method::Head,
Method::Post,
].into_iter()
fn allowed_methods_errors_on_disallowed_method() {
let allowed_methods = vec![Method::Get, Method::Head, Method::Post]
.into_iter()
.collect();
let method = "DELETE";
let response = Response::new();
let response = response.origin("https://www.example.com", false);
let _ = response
.allowed_methods(
validate_allowed_method(
&FromStr::from_str(method).expect("not to fail"),
&allowed_methods,
)
.unwrap();
).unwrap()
}
#[test]
fn all_allowed_headers_are_validated_correctly() {
let allowed_headers = AllOrSome::All;
let requested_headers = vec!["Bar", "Foo"];
not_err!(validate_allowed_headers(
&FromStr::from_str(&requested_headers.join(",")).unwrap(),
&allowed_headers,
));
}
/// `Response::allowed_headers` should check that headers are allowed, and only
/// echoes back the list that is actually requested for and not the whole list
#[test]
fn response_allowed_headers_validates_and_echoes_requested_headers() {
fn allowed_headers_are_validated_correctly() {
let allowed_headers = vec!["Bar", "Baz", "Foo"];
let requested_headers = vec!["Bar", "Foo"];
let response = Response::new();
let response = response.origin("https://www.example.com", false);
let response = response
.allowed_headers(
not_err!(validate_allowed_headers(
&FromStr::from_str(&requested_headers.join(",")).unwrap(),
&AllOrSome::Some(
allowed_headers
@ -1335,35 +1174,16 @@ mod tests {
.map(|s| FromStr::from_str(*s).unwrap())
.collect(),
),
)
.expect("to not fail");
// Build response and check built response header
let response = response.build(response::Response::new());
let actual_header: Vec<_> = response
.headers()
.get("Access-Control-Allow-Headers")
.collect();
assert_eq!(1, actual_header.len());
let mut actual_headers: Vec<String> = actual_header[0]
.split(',')
.map(|header| header.trim().to_string())
.collect();
actual_headers.sort();
assert_eq!(requested_headers, actual_headers);
));
}
#[test]
#[should_panic(expected = "HeadersNotAllowed")]
fn response_allowed_headers_errors_on_non_subset() {
fn allowed_headers_errors_on_non_subset() {
let allowed_headers = vec!["Bar", "Baz", "Foo"];
let requested_headers = vec!["Bar", "Foo", "Unknown"];
let response = Response::new();
let response = response.origin("https://www.example.com", false);
let _ = response
.allowed_headers(
validate_allowed_headers(
&FromStr::from_str(&requested_headers.join(",")).unwrap(),
&AllOrSome::Some(
allowed_headers
@ -1371,8 +1191,7 @@ mod tests {
.map(|s| FromStr::from_str(*s).unwrap())
.collect(),
),
)
.unwrap();
).unwrap();
}
@ -1413,42 +1232,19 @@ mod tests {
assert_eq!(expected_header, actual_header);
// Check that `Access-Control-Max-Age` is removed
assert!(response.headers().get("Access-Control-Max-Age").next().is_none());
assert!(
response
.headers()
.get("Access-Control-Max-Age")
.next()
.is_none()
);
}
// The following tests check that preflight checks are done properly
// TODO: Preflight tests
// TODO: Actual requests tests
// fn make_cors_options() -> Cors {
// let (allowed_origins, failed_origins) =
// AllOrSome::new_from_str_list(&["https://www.acme.com"]);
// assert!(failed_origins.is_empty());
// Cors {
// allowed_origins: allowed_origins,
// allowed_methods: [Method::Get].iter().cloned().collect(),
// allowed_headers: AllOrSome::Some(
// ["Authorization"]
// .into_iter()
// .map(|s| s.to_string().into())
// .collect(),
// ),
// allow_credentials: true,
// ..Default::default()
// }
// }
// /// Tests that non CORS preflight are let through without modification
// #[test]
// fn preflight_missing_origins_are_let_through() {
// let options = make_cors_options();
// let client = make_client();
// let request = client.get("/");
// let response = options.preflight((), None, None, None).expect("not to fail");
// let headers: Vec<_> = response.headers().iter().collect();
// assert_eq!(headers.len(), 0);
// }
// Origin all (wildcard + echoed with Vary). Origin Echo
}

View File

@ -9,7 +9,7 @@ use std::ops::Deref;
use std::str::FromStr;
use rocket::local::Client;
use rocket::http::{Header, Status};
use rocket::http::Header;
use rocket_cors::headers::*;
#[get("/request_headers")]