Response no longer validates
This commit is contained in:
parent
abda99d71f
commit
6746e835e7
558
src/lib.rs
558
src/lib.rs
|
@ -437,7 +437,7 @@ impl fairing::Fairing for Cors {
|
||||||
use rocket::response::Responder;
|
use rocket::response::Responder;
|
||||||
|
|
||||||
// Build and merge CORS response
|
// Build and merge CORS response
|
||||||
match build_cors_response(self, request, response) {
|
match build_cors_response(self, request) {
|
||||||
Err(err) => {
|
Err(err) => {
|
||||||
// CORS error -- overwrite the original response
|
// CORS error -- overwrite the original response
|
||||||
let error_response = match err.respond_to(request) {
|
let error_response = match err.respond_to(request) {
|
||||||
|
@ -451,7 +451,9 @@ impl fairing::Fairing for Cors {
|
||||||
};
|
};
|
||||||
response.merge(error_response)
|
response.merge(error_response)
|
||||||
}
|
}
|
||||||
Ok(()) => {
|
Ok(cors_response) => {
|
||||||
|
cors_response.merge(response);
|
||||||
|
|
||||||
// If this was an OPTIONS request and no route can be found, we should turn this
|
// If this was an OPTIONS request and no route can be found, we should turn this
|
||||||
// into a HTTP 204 with no content body.
|
// into a HTTP 204 with no content body.
|
||||||
// This allows the user to not have to specify an OPTIONS route for everything.
|
// This allows the user to not have to specify an OPTIONS route for everything.
|
||||||
|
@ -508,10 +510,14 @@ impl<'a, 'r: 'a, R: response::Responder<'r>> Responder<'a, 'r, R> {
|
||||||
fn respond(self, request: &Request) -> response::Result<'r> {
|
fn respond(self, request: &Request) -> response::Result<'r> {
|
||||||
let mut response = self.responder.respond_to(request)?; // handle status errors?
|
let mut response = self.responder.respond_to(request)?; // handle status errors?
|
||||||
|
|
||||||
match build_cors_response(self.options, request, &mut response) {
|
match build_cors_response(self.options, request) {
|
||||||
Ok(()) => Ok(response),
|
Ok(cors_response) => {
|
||||||
|
cors_response.merge(&mut response);
|
||||||
|
Ok(response)
|
||||||
|
},
|
||||||
Err(e) => response::Responder::respond_to(e, request),
|
Err(e) => response::Responder::respond_to(e, request),
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -569,48 +575,10 @@ impl Response {
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Consumes the responder and based on the provided list of allowed origins,
|
/// Consumes the Response and set credentials
|
||||||
/// check if the requested origin is allowed.
|
fn credentials(mut self, value: bool) -> Self {
|
||||||
/// Useful for pre-flight and during requests
|
|
||||||
fn allowed_origin(
|
|
||||||
self,
|
|
||||||
origin: &Origin,
|
|
||||||
allowed_origins: &AllOrSome<HashSet<Url>>,
|
|
||||||
send_wildcard: bool,
|
|
||||||
) -> Result<Self, Error> {
|
|
||||||
let origin = origin.origin().unicode_serialization();
|
|
||||||
match *allowed_origins {
|
|
||||||
// Always matching is acceptable since the list of origins can be unbounded.
|
|
||||||
AllOrSome::All => {
|
|
||||||
if send_wildcard {
|
|
||||||
Ok(self.any())
|
|
||||||
} else {
|
|
||||||
Ok(self.origin(&origin, true))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
AllOrSome::Some(ref allowed_origins) => {
|
|
||||||
let allowed_origins: HashSet<_> = allowed_origins
|
|
||||||
.iter()
|
|
||||||
.map(|o| o.origin().unicode_serialization())
|
|
||||||
.collect();
|
|
||||||
let _ = allowed_origins.get(&origin).ok_or_else(
|
|
||||||
|| Error::OriginNotAllowed,
|
|
||||||
)?;
|
|
||||||
Ok(self.origin(&origin, false))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Consumes the Response and validate whether credentials can be allowed
|
|
||||||
fn credentials(mut self, value: bool) -> Result<Self, Error> {
|
|
||||||
if value {
|
|
||||||
if let Some(AllOrSome::All) = self.allow_origin {
|
|
||||||
Err(Error::CredentialsWithWildcardOrigin)?;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
self.allow_credentials = value;
|
self.allow_credentials = value;
|
||||||
Ok(self)
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Consumes the CORS, set expose_headers to
|
/// Consumes the CORS, set expose_headers to
|
||||||
|
@ -634,22 +602,6 @@ impl Response {
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Consumes the CORS, check if requested method is allowed.
|
|
||||||
/// Useful for pre-flight checks
|
|
||||||
fn allowed_methods(
|
|
||||||
self,
|
|
||||||
method: &AccessControlRequestMethod,
|
|
||||||
allowed_methods: &HashSet<Method>,
|
|
||||||
) -> Result<Self, Error> {
|
|
||||||
let &AccessControlRequestMethod(ref request_method) = method;
|
|
||||||
if !allowed_methods.iter().any(|m| m == request_method) {
|
|
||||||
Err(Error::MethodNotAllowed)?
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: Subset to route? Or just the method requested for?
|
|
||||||
Ok(self.methods(&allowed_methods))
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Consumes the CORS, set allow_headers to
|
/// Consumes the CORS, set allow_headers to
|
||||||
/// passed headers and returns changed CORS
|
/// passed headers and returns changed CORS
|
||||||
fn headers(mut self, headers: &[&str]) -> Self {
|
fn headers(mut self, headers: &[&str]) -> Self {
|
||||||
|
@ -657,35 +609,6 @@ impl Response {
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Consumes the CORS, check if requested headers are allowed.
|
|
||||||
/// Useful for pre-flight checks
|
|
||||||
fn allowed_headers(
|
|
||||||
self,
|
|
||||||
headers: &AccessControlRequestHeaders,
|
|
||||||
allowed_headers: &AllOrSome<HashSet<HeaderFieldName>>,
|
|
||||||
) -> Result<Self, Error> {
|
|
||||||
let &AccessControlRequestHeaders(ref headers) = headers;
|
|
||||||
|
|
||||||
match *allowed_headers {
|
|
||||||
AllOrSome::All => {}
|
|
||||||
AllOrSome::Some(ref allowed_headers) => {
|
|
||||||
if !headers.is_empty() && !headers.is_subset(allowed_headers) {
|
|
||||||
Err(Error::HeadersNotAllowed)?
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
Ok(
|
|
||||||
self.headers(
|
|
||||||
headers
|
|
||||||
.iter()
|
|
||||||
.map(|s| &**s.deref())
|
|
||||||
.collect::<Vec<&str>>()
|
|
||||||
.as_slice(),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Builds a `rocket::Response` from this struct based off some base `rocket::Response`
|
/// Builds a `rocket::Response` from this struct based off some base `rocket::Response`
|
||||||
///
|
///
|
||||||
/// This will overwrite any existing CORS headers
|
/// This will overwrite any existing CORS headers
|
||||||
|
@ -788,12 +711,11 @@ pub fn respond<'a, 'r: 'a, R: response::Responder<'r>>(
|
||||||
fn build_cors_response(
|
fn build_cors_response(
|
||||||
options: &Cors,
|
options: &Cors,
|
||||||
request: &Request,
|
request: &Request,
|
||||||
response: &mut rocket::Response,
|
) -> Result<Response, Error> {
|
||||||
) -> Result<(), Error> {
|
|
||||||
// Existing CORS response?
|
// Existing CORS response?
|
||||||
if has_allow_origin(response) {
|
// if has_allow_origin(response) {
|
||||||
return Ok(());
|
// return Ok(());
|
||||||
}
|
// }
|
||||||
|
|
||||||
// 1. If the Origin header is not present terminate this set of steps.
|
// 1. If the Origin header is not present terminate this set of steps.
|
||||||
// The request is outside the scope of this specification.
|
// The request is outside the scope of this specification.
|
||||||
|
@ -801,7 +723,7 @@ fn build_cors_response(
|
||||||
let origin = match origin {
|
let origin = match origin {
|
||||||
None => {
|
None => {
|
||||||
// Not a CORS request
|
// Not a CORS request
|
||||||
return Ok(());
|
return Ok(Response::new());
|
||||||
}
|
}
|
||||||
Some(origin) => origin,
|
Some(origin) => origin,
|
||||||
};
|
};
|
||||||
|
@ -816,10 +738,60 @@ fn build_cors_response(
|
||||||
_ => actual_request(options, origin),
|
_ => actual_request(options, origin),
|
||||||
}?;
|
}?;
|
||||||
|
|
||||||
cors_response.merge(response);
|
Ok(cors_response)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Consumes the responder and based on the provided list of allowed origins,
|
||||||
|
/// check if the requested origin is allowed.
|
||||||
|
/// Useful for pre-flight and during requests
|
||||||
|
fn validate_origin(
|
||||||
|
origin: &Origin,
|
||||||
|
allowed_origins: &AllOrSome<HashSet<Url>>,
|
||||||
|
) -> Result<(), Error> {
|
||||||
|
match *allowed_origins {
|
||||||
|
// Always matching is acceptable since the list of origins can be unbounded.
|
||||||
|
AllOrSome::All => Ok(()),
|
||||||
|
AllOrSome::Some(ref allowed_origins) => {
|
||||||
|
allowed_origins
|
||||||
|
.get(origin)
|
||||||
|
.and_then(|_| Some(()))
|
||||||
|
.ok_or_else(|| Error::OriginNotAllowed)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Validate allowed methods
|
||||||
|
fn validate_allowed_method(
|
||||||
|
method: &AccessControlRequestMethod,
|
||||||
|
allowed_methods: &HashSet<Method>,
|
||||||
|
) -> Result<(), Error> {
|
||||||
|
let &AccessControlRequestMethod(ref request_method) = method;
|
||||||
|
if !allowed_methods.iter().any(|m| m == request_method) {
|
||||||
|
Err(Error::MethodNotAllowed)?
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: Subset to route? Or just the method requested for?
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Validate allowed headers
|
||||||
|
fn validate_allowed_headers(
|
||||||
|
headers: &AccessControlRequestHeaders,
|
||||||
|
allowed_headers: &AllOrSome<HashSet<HeaderFieldName>>,
|
||||||
|
) -> Result<(), Error> {
|
||||||
|
let &AccessControlRequestHeaders(ref headers) = headers;
|
||||||
|
|
||||||
|
match *allowed_headers {
|
||||||
|
AllOrSome::All => Ok(()),
|
||||||
|
AllOrSome::Some(ref allowed_headers) => {
|
||||||
|
if !headers.is_empty() && !headers.is_subset(allowed_headers) {
|
||||||
|
Err(Error::HeadersNotAllowed)?
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Gets the `Origin` request header from the request
|
/// Gets the `Origin` request header from the request
|
||||||
fn origin(request: &Request) -> Result<Option<Origin>, Error> {
|
fn origin(request: &Request) -> Result<Option<Origin>, Error> {
|
||||||
match Origin::from_request(request) {
|
match Origin::from_request(request) {
|
||||||
|
@ -863,18 +835,24 @@ fn preflight(
|
||||||
method: Option<AccessControlRequestMethod>,
|
method: Option<AccessControlRequestMethod>,
|
||||||
headers: Option<AccessControlRequestHeaders>,
|
headers: Option<AccessControlRequestHeaders>,
|
||||||
) -> Result<Response, Error> {
|
) -> Result<Response, Error> {
|
||||||
|
options.validate()?;
|
||||||
let response = Response::new();
|
let response = Response::new();
|
||||||
|
|
||||||
// Note: All header parse failures are dealt with in the `FromRequest` trait implementation
|
// Note: All header parse failures are dealt with in the `FromRequest` trait implementation
|
||||||
|
|
||||||
// 2. If the value of the Origin header is not a case-sensitive match for any of the values
|
// 2. If the value of the Origin header is not a case-sensitive match for any of the values
|
||||||
// in list of origins do not set any additional headers and terminate this set of steps.
|
// in list of origins do not set any additional headers and terminate this set of steps.
|
||||||
let response = response.allowed_origin(
|
validate_origin(&origin, &options.allowed_origins)?;
|
||||||
&origin,
|
let response = match options.allowed_origins {
|
||||||
&options.allowed_origins,
|
AllOrSome::All => {
|
||||||
options.send_wildcard,
|
if options.send_wildcard {
|
||||||
)?;
|
response.any()
|
||||||
|
} else {
|
||||||
|
response.origin(origin.as_str(), true)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
AllOrSome::Some(_) => response.origin(origin.as_str(), false),
|
||||||
|
};
|
||||||
|
|
||||||
// 3. Let `method` be the value as result of parsing the Access-Control-Request-Method
|
// 3. Let `method` be the value as result of parsing the Access-Control-Request-Method
|
||||||
// header.
|
// header.
|
||||||
|
@ -894,13 +872,23 @@ fn preflight(
|
||||||
// 5. If method is not a case-sensitive match for any of the values in list of methods
|
// 5. If method is not a case-sensitive match for any of the values in list of methods
|
||||||
// do not set any additional headers and terminate this set of steps.
|
// do not set any additional headers and terminate this set of steps.
|
||||||
|
|
||||||
let response = response.allowed_methods(&method, &options.allowed_methods)?;
|
validate_allowed_method(&method, &options.allowed_methods)?;
|
||||||
|
let response = response.methods(&options.allowed_methods);
|
||||||
|
|
||||||
// 6. If any of the header field-names is not a ASCII case-insensitive match for any of the
|
// 6. If any of the header field-names is not a ASCII case-insensitive match for any of the
|
||||||
// values in list of headers do not set any additional headers and terminate this set of
|
// values in list of headers do not set any additional headers and terminate this set of
|
||||||
// steps.
|
// steps.
|
||||||
let response = if let Some(headers) = headers {
|
|
||||||
response.allowed_headers(&headers, &options.allowed_headers)?
|
let response = if let Some(ref headers) = headers {
|
||||||
|
validate_allowed_headers(headers, &options.allowed_headers)?;
|
||||||
|
let &AccessControlRequestHeaders(ref headers) = headers;
|
||||||
|
response.headers(
|
||||||
|
headers
|
||||||
|
.iter()
|
||||||
|
.map(|s| &**s.deref())
|
||||||
|
.collect::<Vec<&str>>()
|
||||||
|
.as_slice(),
|
||||||
|
)
|
||||||
} else {
|
} else {
|
||||||
response
|
response
|
||||||
};
|
};
|
||||||
|
@ -913,7 +901,8 @@ fn preflight(
|
||||||
// with either the value of the Origin header or the string "*" as value.
|
// with either the value of the Origin header or the string "*" as value.
|
||||||
// Note: The string "*" cannot be used for a resource that supports credentials.
|
// Note: The string "*" cannot be used for a resource that supports credentials.
|
||||||
|
|
||||||
let response = response.credentials(options.allow_credentials)?;
|
// Validation has been done in options.validate
|
||||||
|
let response = response.credentials(options.allow_credentials);
|
||||||
|
|
||||||
// 8. Optionally add a single Access-Control-Max-Age header
|
// 8. Optionally add a single Access-Control-Max-Age header
|
||||||
// with as value the amount of seconds the user agent is allowed to cache the result of the
|
// with as value the amount of seconds the user agent is allowed to cache the result of the
|
||||||
|
@ -949,6 +938,7 @@ fn preflight(
|
||||||
/// If the `Origin` is not provided, then this request was not made by a browser and there is no
|
/// If the `Origin` is not provided, then this request was not made by a browser and there is no
|
||||||
/// CORS enforcement.
|
/// CORS enforcement.
|
||||||
fn actual_request(options: &Cors, origin: Origin) -> Result<Response, Error> {
|
fn actual_request(options: &Cors, origin: Origin) -> Result<Response, Error> {
|
||||||
|
options.validate()?;
|
||||||
let response = Response::new();
|
let response = Response::new();
|
||||||
|
|
||||||
// Note: All header parse failures are dealt with in the `FromRequest` trait implementation
|
// Note: All header parse failures are dealt with in the `FromRequest` trait implementation
|
||||||
|
@ -957,11 +947,17 @@ fn actual_request(options: &Cors, origin: Origin) -> Result<Response, Error> {
|
||||||
// in list of origins, do not set any additional headers and terminate this set of steps.
|
// in list of origins, do not set any additional headers and terminate this set of steps.
|
||||||
// Always matching is acceptable since the list of origins can be unbounded.
|
// Always matching is acceptable since the list of origins can be unbounded.
|
||||||
|
|
||||||
let response = response.allowed_origin(
|
validate_origin(&origin, &options.allowed_origins)?;
|
||||||
&origin,
|
let response = match options.allowed_origins {
|
||||||
&options.allowed_origins,
|
AllOrSome::All => {
|
||||||
options.send_wildcard,
|
if options.send_wildcard {
|
||||||
)?;
|
response.any()
|
||||||
|
} else {
|
||||||
|
response.origin(origin.as_str(), true)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
AllOrSome::Some(_) => response.origin(origin.as_str(), false),
|
||||||
|
};
|
||||||
|
|
||||||
// 3. If the resource supports credentials add a single Access-Control-Allow-Origin header,
|
// 3. If the resource supports credentials add a single Access-Control-Allow-Origin header,
|
||||||
// with the value of the Origin header as value, and add a
|
// with the value of the Origin header as value, and add a
|
||||||
|
@ -971,7 +967,8 @@ fn actual_request(options: &Cors, origin: Origin) -> Result<Response, Error> {
|
||||||
// with either the value of the Origin header or the string "*" as value.
|
// with either the value of the Origin header or the string "*" as value.
|
||||||
// Note: The string "*" cannot be used for a resource that supports credentials.
|
// Note: The string "*" cannot be used for a resource that supports credentials.
|
||||||
|
|
||||||
let response = response.credentials(options.allow_credentials)?;
|
// Validation has been done in options.validate
|
||||||
|
let response = response.credentials(options.allow_credentials);
|
||||||
|
|
||||||
// 4. If the list of exposed headers is not empty add one or more
|
// 4. If the list of exposed headers is not empty add one or more
|
||||||
// Access-Control-Expose-Headers headers, with as values the header field names given in
|
// Access-Control-Expose-Headers headers, with as values the header field names given in
|
||||||
|
@ -1033,65 +1030,15 @@ mod tests {
|
||||||
cors.validate().unwrap();
|
cors.validate().unwrap();
|
||||||
}
|
}
|
||||||
|
|
||||||
// The following tests check `Response`'s validation
|
// The following tests check validation
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn response_allows_all_origin_with_wildcard() {
|
fn validate_origin_allows_all_origins() {
|
||||||
let url = "https://www.example.com";
|
let url = "https://www.example.com";
|
||||||
let origin = Origin::from_str(url).unwrap();
|
let origin = Origin::from_str(url).unwrap();
|
||||||
let allowed_origins = AllOrSome::All;
|
let allowed_origins = AllOrSome::All;
|
||||||
let send_wildcard = true;
|
|
||||||
|
|
||||||
let response = Response::new();
|
not_err!(validate_origin(&origin, &allowed_origins));
|
||||||
let response = not_err!(response.allowed_origin(
|
|
||||||
&origin,
|
|
||||||
&allowed_origins,
|
|
||||||
send_wildcard,
|
|
||||||
));
|
|
||||||
|
|
||||||
assert_matches!(response.allow_origin, Some(AllOrSome::All));
|
|
||||||
assert_eq!(response.vary_origin, false);
|
|
||||||
|
|
||||||
// Build response and check built response header
|
|
||||||
let expected_header = vec!["*"];
|
|
||||||
let response = response.build(response::Response::new());
|
|
||||||
let actual_header: Vec<_> = response
|
|
||||||
.headers()
|
|
||||||
.get("Access-Control-Allow-Origin")
|
|
||||||
.collect();
|
|
||||||
assert_eq!(expected_header, actual_header);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn response_allows_all_origin_with_echoed_domain() {
|
|
||||||
let url = "https://www.example.com";
|
|
||||||
let origin = Origin::from_str(url).unwrap();
|
|
||||||
let allowed_origins = AllOrSome::All;
|
|
||||||
let send_wildcard = false;
|
|
||||||
|
|
||||||
let response = Response::new();
|
|
||||||
let response = not_err!(response.allowed_origin(
|
|
||||||
&origin,
|
|
||||||
&allowed_origins,
|
|
||||||
send_wildcard,
|
|
||||||
));
|
|
||||||
|
|
||||||
let actual_origin = assert_matches!(
|
|
||||||
response.allow_origin,
|
|
||||||
Some(AllOrSome::Some(ref origin)),
|
|
||||||
origin
|
|
||||||
);
|
|
||||||
assert_eq!(url, actual_origin);
|
|
||||||
assert_eq!(response.vary_origin, true);
|
|
||||||
|
|
||||||
// Build response and check built response header
|
|
||||||
let expected_header = vec![url];
|
|
||||||
let response = response.build(response::Response::new());
|
|
||||||
let actual_header: Vec<_> = response
|
|
||||||
.headers()
|
|
||||||
.get("Access-Control-Allow-Origin")
|
|
||||||
.collect();
|
|
||||||
assert_eq!(expected_header, actual_header);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
@ -1101,32 +1048,8 @@ mod tests {
|
||||||
let (allowed_origins, failed_origins) =
|
let (allowed_origins, failed_origins) =
|
||||||
AllOrSome::new_from_str_list(&["https://www.example.com"]);
|
AllOrSome::new_from_str_list(&["https://www.example.com"]);
|
||||||
assert!(failed_origins.is_empty());
|
assert!(failed_origins.is_empty());
|
||||||
let send_wildcard = false;
|
|
||||||
|
|
||||||
let response = Response::new();
|
not_err!(validate_origin(&origin, &allowed_origins));
|
||||||
let response = not_err!(response.allowed_origin(
|
|
||||||
&origin,
|
|
||||||
&allowed_origins,
|
|
||||||
send_wildcard,
|
|
||||||
));
|
|
||||||
|
|
||||||
let actual_origin = assert_matches!(
|
|
||||||
response.allow_origin,
|
|
||||||
Some(AllOrSome::Some(ref origin)),
|
|
||||||
origin
|
|
||||||
);
|
|
||||||
|
|
||||||
assert_eq!(url, actual_origin);
|
|
||||||
assert_eq!(response.vary_origin, false);
|
|
||||||
|
|
||||||
// Build response and check built response header
|
|
||||||
let expected_header = vec![url];
|
|
||||||
let response = response.build(response::Response::new());
|
|
||||||
let actual_header: Vec<_> = response
|
|
||||||
.headers()
|
|
||||||
.get("Access-Control-Allow-Origin")
|
|
||||||
.collect();
|
|
||||||
assert_eq!(expected_header, actual_header);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
@ -1137,41 +1060,8 @@ mod tests {
|
||||||
let (allowed_origins, failed_origins) =
|
let (allowed_origins, failed_origins) =
|
||||||
AllOrSome::new_from_str_list(&["https://www.example.com"]);
|
AllOrSome::new_from_str_list(&["https://www.example.com"]);
|
||||||
assert!(failed_origins.is_empty());
|
assert!(failed_origins.is_empty());
|
||||||
let send_wildcard = false;
|
|
||||||
|
|
||||||
let response = Response::new();
|
validate_origin(&origin, &allowed_origins).unwrap();
|
||||||
let _ = response
|
|
||||||
.allowed_origin(&origin, &allowed_origins, send_wildcard)
|
|
||||||
.unwrap();
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
#[should_panic(expected = "CredentialsWithWildcardOrigin")]
|
|
||||||
fn response_credentials_does_not_allow_wildcard_with_all_origins() {
|
|
||||||
let response = Response::new();
|
|
||||||
let response = response.any();
|
|
||||||
|
|
||||||
let _ = response.credentials(true).unwrap();
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn response_credentials_allows_specific_origins() {
|
|
||||||
let response = Response::new();
|
|
||||||
let response = response.origin("https://www.example.com", false);
|
|
||||||
|
|
||||||
let response = response.credentials(true).expect(
|
|
||||||
"to allow specific origins",
|
|
||||||
);
|
|
||||||
assert_eq!(response.allow_credentials, true);
|
|
||||||
|
|
||||||
// Build response and check built response header
|
|
||||||
let expected_header = vec!["true"];
|
|
||||||
let response = response.build(response::Response::new());
|
|
||||||
let actual_header: Vec<_> = response
|
|
||||||
.headers()
|
|
||||||
.get("Access-Control-Allow-Credentials")
|
|
||||||
.collect();
|
|
||||||
assert_eq!(expected_header, actual_header);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
@ -1220,159 +1110,88 @@ mod tests {
|
||||||
|
|
||||||
// Build response and check built response header
|
// Build response and check built response header
|
||||||
let response = response.build(response::Response::new());
|
let response = response.build(response::Response::new());
|
||||||
assert!(response
|
assert!(
|
||||||
.headers()
|
response
|
||||||
.get("Access-Control-Max-Age")
|
.headers()
|
||||||
.next().is_none())
|
.get("Access-Control-Max-Age")
|
||||||
}
|
.next()
|
||||||
|
.is_none()
|
||||||
/// When all headers are allowed, tests that the requested headers are echoed back
|
)
|
||||||
#[test]
|
|
||||||
fn response_allowed_headers_echoes_back_requested_headers() {
|
|
||||||
let allowed_headers = AllOrSome::All;
|
|
||||||
let requested_headers = vec!["Bar", "Foo"];
|
|
||||||
|
|
||||||
let response = Response::new();
|
|
||||||
let response = response.origin("https://www.example.com", false);
|
|
||||||
let response = response
|
|
||||||
.allowed_headers(
|
|
||||||
&FromStr::from_str(&requested_headers.join(",")).unwrap(),
|
|
||||||
&allowed_headers,
|
|
||||||
)
|
|
||||||
.expect("to not fail");
|
|
||||||
|
|
||||||
// Build response and check built response header
|
|
||||||
let response = response.build(response::Response::new());
|
|
||||||
let actual_header: Vec<_> = response
|
|
||||||
.headers()
|
|
||||||
.get("Access-Control-Allow-Headers")
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
assert_eq!(1, actual_header.len());
|
|
||||||
let mut actual_headers: Vec<String> = actual_header[0]
|
|
||||||
.split(',')
|
|
||||||
.map(|header| header.trim().to_string())
|
|
||||||
.collect();
|
|
||||||
actual_headers.sort();
|
|
||||||
assert_eq!(requested_headers, actual_headers);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn response_allowed_methods_sets_headers_properly() {
|
fn allowed_methods_validated_correctly() {
|
||||||
let allowed_methods = vec![
|
let allowed_methods = vec![Method::Get, Method::Head, Method::Post]
|
||||||
Method::Get,
|
.into_iter()
|
||||||
Method::Head,
|
|
||||||
Method::Post,
|
|
||||||
].into_iter()
|
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
let method = "GET";
|
let method = "GET";
|
||||||
|
|
||||||
let response = Response::new();
|
not_err!(validate_allowed_method(
|
||||||
let response = response.origin("https://www.example.com", false);
|
&FromStr::from_str(method).expect("not to fail"),
|
||||||
let response = response
|
&allowed_methods,
|
||||||
.allowed_methods(
|
));
|
||||||
&FromStr::from_str(method).expect("not to fail"),
|
|
||||||
&allowed_methods,
|
|
||||||
)
|
|
||||||
.expect("not to fail");
|
|
||||||
|
|
||||||
// Build response and check built response header
|
|
||||||
let response = response.build(response::Response::new());
|
|
||||||
let actual_header: Vec<_> = response
|
|
||||||
.headers()
|
|
||||||
.get("Access-Control-Allow-Methods")
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
assert_eq!(1, actual_header.len());
|
|
||||||
let mut actual_headers: Vec<String> = actual_header[0]
|
|
||||||
.split(',')
|
|
||||||
.map(|header| header.trim().to_string())
|
|
||||||
.collect();
|
|
||||||
actual_headers.sort();
|
|
||||||
let mut expected_headers: Vec<_> = allowed_methods.iter().map(|m| m.as_str()).collect();
|
|
||||||
expected_headers.sort();
|
|
||||||
assert_eq!(expected_headers, actual_headers);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
#[should_panic(expected = "MethodNotAllowed")]
|
#[should_panic(expected = "MethodNotAllowed")]
|
||||||
fn response_allowed_method_errors_on_disallowed_method() {
|
fn allowed_methods_errors_on_disallowed_method() {
|
||||||
let allowed_methods = vec![
|
let allowed_methods = vec![Method::Get, Method::Head, Method::Post]
|
||||||
Method::Get,
|
.into_iter()
|
||||||
Method::Head,
|
|
||||||
Method::Post,
|
|
||||||
].into_iter()
|
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
let method = "DELETE";
|
let method = "DELETE";
|
||||||
|
|
||||||
let response = Response::new();
|
validate_allowed_method(
|
||||||
let response = response.origin("https://www.example.com", false);
|
&FromStr::from_str(method).expect("not to fail"),
|
||||||
let _ = response
|
&allowed_methods,
|
||||||
.allowed_methods(
|
).unwrap()
|
||||||
&FromStr::from_str(method).expect("not to fail"),
|
}
|
||||||
&allowed_methods,
|
|
||||||
)
|
#[test]
|
||||||
.unwrap();
|
fn all_allowed_headers_are_validated_correctly() {
|
||||||
|
let allowed_headers = AllOrSome::All;
|
||||||
|
let requested_headers = vec!["Bar", "Foo"];
|
||||||
|
|
||||||
|
not_err!(validate_allowed_headers(
|
||||||
|
&FromStr::from_str(&requested_headers.join(",")).unwrap(),
|
||||||
|
&allowed_headers,
|
||||||
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
/// `Response::allowed_headers` should check that headers are allowed, and only
|
/// `Response::allowed_headers` should check that headers are allowed, and only
|
||||||
/// echoes back the list that is actually requested for and not the whole list
|
/// echoes back the list that is actually requested for and not the whole list
|
||||||
#[test]
|
#[test]
|
||||||
fn response_allowed_headers_validates_and_echoes_requested_headers() {
|
fn allowed_headers_are_validated_correctly() {
|
||||||
let allowed_headers = vec!["Bar", "Baz", "Foo"];
|
let allowed_headers = vec!["Bar", "Baz", "Foo"];
|
||||||
let requested_headers = vec!["Bar", "Foo"];
|
let requested_headers = vec!["Bar", "Foo"];
|
||||||
|
|
||||||
let response = Response::new();
|
not_err!(validate_allowed_headers(
|
||||||
let response = response.origin("https://www.example.com", false);
|
&FromStr::from_str(&requested_headers.join(",")).unwrap(),
|
||||||
let response = response
|
&AllOrSome::Some(
|
||||||
.allowed_headers(
|
allowed_headers
|
||||||
&FromStr::from_str(&requested_headers.join(",")).unwrap(),
|
.iter()
|
||||||
&AllOrSome::Some(
|
.map(|s| FromStr::from_str(*s).unwrap())
|
||||||
allowed_headers
|
.collect(),
|
||||||
.iter()
|
),
|
||||||
.map(|s| FromStr::from_str(*s).unwrap())
|
));
|
||||||
.collect(),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
.expect("to not fail");
|
|
||||||
|
|
||||||
// Build response and check built response header
|
|
||||||
let response = response.build(response::Response::new());
|
|
||||||
let actual_header: Vec<_> = response
|
|
||||||
.headers()
|
|
||||||
.get("Access-Control-Allow-Headers")
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
assert_eq!(1, actual_header.len());
|
|
||||||
let mut actual_headers: Vec<String> = actual_header[0]
|
|
||||||
.split(',')
|
|
||||||
.map(|header| header.trim().to_string())
|
|
||||||
.collect();
|
|
||||||
actual_headers.sort();
|
|
||||||
assert_eq!(requested_headers, actual_headers);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
#[should_panic(expected = "HeadersNotAllowed")]
|
#[should_panic(expected = "HeadersNotAllowed")]
|
||||||
fn response_allowed_headers_errors_on_non_subset() {
|
fn allowed_headers_errors_on_non_subset() {
|
||||||
let allowed_headers = vec!["Bar", "Baz", "Foo"];
|
let allowed_headers = vec!["Bar", "Baz", "Foo"];
|
||||||
let requested_headers = vec!["Bar", "Foo", "Unknown"];
|
let requested_headers = vec!["Bar", "Foo", "Unknown"];
|
||||||
|
|
||||||
let response = Response::new();
|
validate_allowed_headers(
|
||||||
let response = response.origin("https://www.example.com", false);
|
&FromStr::from_str(&requested_headers.join(",")).unwrap(),
|
||||||
let _ = response
|
&AllOrSome::Some(
|
||||||
.allowed_headers(
|
allowed_headers
|
||||||
&FromStr::from_str(&requested_headers.join(",")).unwrap(),
|
.iter()
|
||||||
&AllOrSome::Some(
|
.map(|s| FromStr::from_str(*s).unwrap())
|
||||||
allowed_headers
|
.collect(),
|
||||||
.iter()
|
),
|
||||||
.map(|s| FromStr::from_str(*s).unwrap())
|
).unwrap();
|
||||||
.collect(),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1413,42 +1232,19 @@ mod tests {
|
||||||
assert_eq!(expected_header, actual_header);
|
assert_eq!(expected_header, actual_header);
|
||||||
|
|
||||||
// Check that `Access-Control-Max-Age` is removed
|
// Check that `Access-Control-Max-Age` is removed
|
||||||
assert!(response.headers().get("Access-Control-Max-Age").next().is_none());
|
assert!(
|
||||||
|
response
|
||||||
|
.headers()
|
||||||
|
.get("Access-Control-Max-Age")
|
||||||
|
.next()
|
||||||
|
.is_none()
|
||||||
|
);
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// The following tests check that preflight checks are done properly
|
// TODO: Preflight tests
|
||||||
|
// TODO: Actual requests tests
|
||||||
|
|
||||||
// fn make_cors_options() -> Cors {
|
// Origin all (wildcard + echoed with Vary). Origin Echo
|
||||||
// let (allowed_origins, failed_origins) =
|
|
||||||
// AllOrSome::new_from_str_list(&["https://www.acme.com"]);
|
|
||||||
// assert!(failed_origins.is_empty());
|
|
||||||
|
|
||||||
// Cors {
|
|
||||||
// allowed_origins: allowed_origins,
|
|
||||||
// allowed_methods: [Method::Get].iter().cloned().collect(),
|
|
||||||
// allowed_headers: AllOrSome::Some(
|
|
||||||
// ["Authorization"]
|
|
||||||
// .into_iter()
|
|
||||||
// .map(|s| s.to_string().into())
|
|
||||||
// .collect(),
|
|
||||||
// ),
|
|
||||||
// allow_credentials: true,
|
|
||||||
// ..Default::default()
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
|
|
||||||
// /// Tests that non CORS preflight are let through without modification
|
|
||||||
// #[test]
|
|
||||||
// fn preflight_missing_origins_are_let_through() {
|
|
||||||
// let options = make_cors_options();
|
|
||||||
// let client = make_client();
|
|
||||||
// let request = client.get("/");
|
|
||||||
|
|
||||||
// let response = options.preflight((), None, None, None).expect("not to fail");
|
|
||||||
|
|
||||||
// let headers: Vec<_> = response.headers().iter().collect();
|
|
||||||
// assert_eq!(headers.len(), 0);
|
|
||||||
// }
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -9,7 +9,7 @@ use std::ops::Deref;
|
||||||
use std::str::FromStr;
|
use std::str::FromStr;
|
||||||
|
|
||||||
use rocket::local::Client;
|
use rocket::local::Client;
|
||||||
use rocket::http::{Header, Status};
|
use rocket::http::Header;
|
||||||
use rocket_cors::headers::*;
|
use rocket_cors::headers::*;
|
||||||
|
|
||||||
#[get("/request_headers")]
|
#[get("/request_headers")]
|
||||||
|
|
Loading…
Reference in New Issue