Response no longer validates
This commit is contained in:
parent
abda99d71f
commit
6746e835e7
558
src/lib.rs
558
src/lib.rs
|
@ -437,7 +437,7 @@ impl fairing::Fairing for Cors {
|
|||
use rocket::response::Responder;
|
||||
|
||||
// Build and merge CORS response
|
||||
match build_cors_response(self, request, response) {
|
||||
match build_cors_response(self, request) {
|
||||
Err(err) => {
|
||||
// CORS error -- overwrite the original response
|
||||
let error_response = match err.respond_to(request) {
|
||||
|
@ -451,7 +451,9 @@ impl fairing::Fairing for Cors {
|
|||
};
|
||||
response.merge(error_response)
|
||||
}
|
||||
Ok(()) => {
|
||||
Ok(cors_response) => {
|
||||
cors_response.merge(response);
|
||||
|
||||
// If this was an OPTIONS request and no route can be found, we should turn this
|
||||
// into a HTTP 204 with no content body.
|
||||
// This allows the user to not have to specify an OPTIONS route for everything.
|
||||
|
@ -508,10 +510,14 @@ impl<'a, 'r: 'a, R: response::Responder<'r>> Responder<'a, 'r, R> {
|
|||
fn respond(self, request: &Request) -> response::Result<'r> {
|
||||
let mut response = self.responder.respond_to(request)?; // handle status errors?
|
||||
|
||||
match build_cors_response(self.options, request, &mut response) {
|
||||
Ok(()) => Ok(response),
|
||||
match build_cors_response(self.options, request) {
|
||||
Ok(cors_response) => {
|
||||
cors_response.merge(&mut response);
|
||||
Ok(response)
|
||||
},
|
||||
Err(e) => response::Responder::respond_to(e, request),
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -569,48 +575,10 @@ impl Response {
|
|||
self
|
||||
}
|
||||
|
||||
/// Consumes the responder and based on the provided list of allowed origins,
|
||||
/// check if the requested origin is allowed.
|
||||
/// Useful for pre-flight and during requests
|
||||
fn allowed_origin(
|
||||
self,
|
||||
origin: &Origin,
|
||||
allowed_origins: &AllOrSome<HashSet<Url>>,
|
||||
send_wildcard: bool,
|
||||
) -> Result<Self, Error> {
|
||||
let origin = origin.origin().unicode_serialization();
|
||||
match *allowed_origins {
|
||||
// Always matching is acceptable since the list of origins can be unbounded.
|
||||
AllOrSome::All => {
|
||||
if send_wildcard {
|
||||
Ok(self.any())
|
||||
} else {
|
||||
Ok(self.origin(&origin, true))
|
||||
}
|
||||
}
|
||||
AllOrSome::Some(ref allowed_origins) => {
|
||||
let allowed_origins: HashSet<_> = allowed_origins
|
||||
.iter()
|
||||
.map(|o| o.origin().unicode_serialization())
|
||||
.collect();
|
||||
let _ = allowed_origins.get(&origin).ok_or_else(
|
||||
|| Error::OriginNotAllowed,
|
||||
)?;
|
||||
Ok(self.origin(&origin, false))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Consumes the Response and validate whether credentials can be allowed
|
||||
fn credentials(mut self, value: bool) -> Result<Self, Error> {
|
||||
if value {
|
||||
if let Some(AllOrSome::All) = self.allow_origin {
|
||||
Err(Error::CredentialsWithWildcardOrigin)?;
|
||||
}
|
||||
}
|
||||
|
||||
/// Consumes the Response and set credentials
|
||||
fn credentials(mut self, value: bool) -> Self {
|
||||
self.allow_credentials = value;
|
||||
Ok(self)
|
||||
self
|
||||
}
|
||||
|
||||
/// Consumes the CORS, set expose_headers to
|
||||
|
@ -634,22 +602,6 @@ impl Response {
|
|||
self
|
||||
}
|
||||
|
||||
/// Consumes the CORS, check if requested method is allowed.
|
||||
/// Useful for pre-flight checks
|
||||
fn allowed_methods(
|
||||
self,
|
||||
method: &AccessControlRequestMethod,
|
||||
allowed_methods: &HashSet<Method>,
|
||||
) -> Result<Self, Error> {
|
||||
let &AccessControlRequestMethod(ref request_method) = method;
|
||||
if !allowed_methods.iter().any(|m| m == request_method) {
|
||||
Err(Error::MethodNotAllowed)?
|
||||
}
|
||||
|
||||
// TODO: Subset to route? Or just the method requested for?
|
||||
Ok(self.methods(&allowed_methods))
|
||||
}
|
||||
|
||||
/// Consumes the CORS, set allow_headers to
|
||||
/// passed headers and returns changed CORS
|
||||
fn headers(mut self, headers: &[&str]) -> Self {
|
||||
|
@ -657,35 +609,6 @@ impl Response {
|
|||
self
|
||||
}
|
||||
|
||||
/// Consumes the CORS, check if requested headers are allowed.
|
||||
/// Useful for pre-flight checks
|
||||
fn allowed_headers(
|
||||
self,
|
||||
headers: &AccessControlRequestHeaders,
|
||||
allowed_headers: &AllOrSome<HashSet<HeaderFieldName>>,
|
||||
) -> Result<Self, Error> {
|
||||
let &AccessControlRequestHeaders(ref headers) = headers;
|
||||
|
||||
match *allowed_headers {
|
||||
AllOrSome::All => {}
|
||||
AllOrSome::Some(ref allowed_headers) => {
|
||||
if !headers.is_empty() && !headers.is_subset(allowed_headers) {
|
||||
Err(Error::HeadersNotAllowed)?
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
Ok(
|
||||
self.headers(
|
||||
headers
|
||||
.iter()
|
||||
.map(|s| &**s.deref())
|
||||
.collect::<Vec<&str>>()
|
||||
.as_slice(),
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
/// Builds a `rocket::Response` from this struct based off some base `rocket::Response`
|
||||
///
|
||||
/// This will overwrite any existing CORS headers
|
||||
|
@ -788,12 +711,11 @@ pub fn respond<'a, 'r: 'a, R: response::Responder<'r>>(
|
|||
fn build_cors_response(
|
||||
options: &Cors,
|
||||
request: &Request,
|
||||
response: &mut rocket::Response,
|
||||
) -> Result<(), Error> {
|
||||
) -> Result<Response, Error> {
|
||||
// Existing CORS response?
|
||||
if has_allow_origin(response) {
|
||||
return Ok(());
|
||||
}
|
||||
// if has_allow_origin(response) {
|
||||
// return Ok(());
|
||||
// }
|
||||
|
||||
// 1. If the Origin header is not present terminate this set of steps.
|
||||
// The request is outside the scope of this specification.
|
||||
|
@ -801,7 +723,7 @@ fn build_cors_response(
|
|||
let origin = match origin {
|
||||
None => {
|
||||
// Not a CORS request
|
||||
return Ok(());
|
||||
return Ok(Response::new());
|
||||
}
|
||||
Some(origin) => origin,
|
||||
};
|
||||
|
@ -816,10 +738,60 @@ fn build_cors_response(
|
|||
_ => actual_request(options, origin),
|
||||
}?;
|
||||
|
||||
cors_response.merge(response);
|
||||
Ok(cors_response)
|
||||
}
|
||||
|
||||
/// Consumes the responder and based on the provided list of allowed origins,
|
||||
/// check if the requested origin is allowed.
|
||||
/// Useful for pre-flight and during requests
|
||||
fn validate_origin(
|
||||
origin: &Origin,
|
||||
allowed_origins: &AllOrSome<HashSet<Url>>,
|
||||
) -> Result<(), Error> {
|
||||
match *allowed_origins {
|
||||
// Always matching is acceptable since the list of origins can be unbounded.
|
||||
AllOrSome::All => Ok(()),
|
||||
AllOrSome::Some(ref allowed_origins) => {
|
||||
allowed_origins
|
||||
.get(origin)
|
||||
.and_then(|_| Some(()))
|
||||
.ok_or_else(|| Error::OriginNotAllowed)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Validate allowed methods
|
||||
fn validate_allowed_method(
|
||||
method: &AccessControlRequestMethod,
|
||||
allowed_methods: &HashSet<Method>,
|
||||
) -> Result<(), Error> {
|
||||
let &AccessControlRequestMethod(ref request_method) = method;
|
||||
if !allowed_methods.iter().any(|m| m == request_method) {
|
||||
Err(Error::MethodNotAllowed)?
|
||||
}
|
||||
|
||||
// TODO: Subset to route? Or just the method requested for?
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Validate allowed headers
|
||||
fn validate_allowed_headers(
|
||||
headers: &AccessControlRequestHeaders,
|
||||
allowed_headers: &AllOrSome<HashSet<HeaderFieldName>>,
|
||||
) -> Result<(), Error> {
|
||||
let &AccessControlRequestHeaders(ref headers) = headers;
|
||||
|
||||
match *allowed_headers {
|
||||
AllOrSome::All => Ok(()),
|
||||
AllOrSome::Some(ref allowed_headers) => {
|
||||
if !headers.is_empty() && !headers.is_subset(allowed_headers) {
|
||||
Err(Error::HeadersNotAllowed)?
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Gets the `Origin` request header from the request
|
||||
fn origin(request: &Request) -> Result<Option<Origin>, Error> {
|
||||
match Origin::from_request(request) {
|
||||
|
@ -863,18 +835,24 @@ fn preflight(
|
|||
method: Option<AccessControlRequestMethod>,
|
||||
headers: Option<AccessControlRequestHeaders>,
|
||||
) -> Result<Response, Error> {
|
||||
|
||||
options.validate()?;
|
||||
let response = Response::new();
|
||||
|
||||
// Note: All header parse failures are dealt with in the `FromRequest` trait implementation
|
||||
|
||||
// 2. If the value of the Origin header is not a case-sensitive match for any of the values
|
||||
// in list of origins do not set any additional headers and terminate this set of steps.
|
||||
let response = response.allowed_origin(
|
||||
&origin,
|
||||
&options.allowed_origins,
|
||||
options.send_wildcard,
|
||||
)?;
|
||||
validate_origin(&origin, &options.allowed_origins)?;
|
||||
let response = match options.allowed_origins {
|
||||
AllOrSome::All => {
|
||||
if options.send_wildcard {
|
||||
response.any()
|
||||
} else {
|
||||
response.origin(origin.as_str(), true)
|
||||
}
|
||||
}
|
||||
AllOrSome::Some(_) => response.origin(origin.as_str(), false),
|
||||
};
|
||||
|
||||
// 3. Let `method` be the value as result of parsing the Access-Control-Request-Method
|
||||
// header.
|
||||
|
@ -894,13 +872,23 @@ fn preflight(
|
|||
// 5. If method is not a case-sensitive match for any of the values in list of methods
|
||||
// do not set any additional headers and terminate this set of steps.
|
||||
|
||||
let response = response.allowed_methods(&method, &options.allowed_methods)?;
|
||||
validate_allowed_method(&method, &options.allowed_methods)?;
|
||||
let response = response.methods(&options.allowed_methods);
|
||||
|
||||
// 6. If any of the header field-names is not a ASCII case-insensitive match for any of the
|
||||
// values in list of headers do not set any additional headers and terminate this set of
|
||||
// steps.
|
||||
let response = if let Some(headers) = headers {
|
||||
response.allowed_headers(&headers, &options.allowed_headers)?
|
||||
|
||||
let response = if let Some(ref headers) = headers {
|
||||
validate_allowed_headers(headers, &options.allowed_headers)?;
|
||||
let &AccessControlRequestHeaders(ref headers) = headers;
|
||||
response.headers(
|
||||
headers
|
||||
.iter()
|
||||
.map(|s| &**s.deref())
|
||||
.collect::<Vec<&str>>()
|
||||
.as_slice(),
|
||||
)
|
||||
} else {
|
||||
response
|
||||
};
|
||||
|
@ -913,7 +901,8 @@ fn preflight(
|
|||
// with either the value of the Origin header or the string "*" as value.
|
||||
// Note: The string "*" cannot be used for a resource that supports credentials.
|
||||
|
||||
let response = response.credentials(options.allow_credentials)?;
|
||||
// Validation has been done in options.validate
|
||||
let response = response.credentials(options.allow_credentials);
|
||||
|
||||
// 8. Optionally add a single Access-Control-Max-Age header
|
||||
// with as value the amount of seconds the user agent is allowed to cache the result of the
|
||||
|
@ -949,6 +938,7 @@ fn preflight(
|
|||
/// If the `Origin` is not provided, then this request was not made by a browser and there is no
|
||||
/// CORS enforcement.
|
||||
fn actual_request(options: &Cors, origin: Origin) -> Result<Response, Error> {
|
||||
options.validate()?;
|
||||
let response = Response::new();
|
||||
|
||||
// Note: All header parse failures are dealt with in the `FromRequest` trait implementation
|
||||
|
@ -957,11 +947,17 @@ fn actual_request(options: &Cors, origin: Origin) -> Result<Response, Error> {
|
|||
// in list of origins, do not set any additional headers and terminate this set of steps.
|
||||
// Always matching is acceptable since the list of origins can be unbounded.
|
||||
|
||||
let response = response.allowed_origin(
|
||||
&origin,
|
||||
&options.allowed_origins,
|
||||
options.send_wildcard,
|
||||
)?;
|
||||
validate_origin(&origin, &options.allowed_origins)?;
|
||||
let response = match options.allowed_origins {
|
||||
AllOrSome::All => {
|
||||
if options.send_wildcard {
|
||||
response.any()
|
||||
} else {
|
||||
response.origin(origin.as_str(), true)
|
||||
}
|
||||
}
|
||||
AllOrSome::Some(_) => response.origin(origin.as_str(), false),
|
||||
};
|
||||
|
||||
// 3. If the resource supports credentials add a single Access-Control-Allow-Origin header,
|
||||
// with the value of the Origin header as value, and add a
|
||||
|
@ -971,7 +967,8 @@ fn actual_request(options: &Cors, origin: Origin) -> Result<Response, Error> {
|
|||
// with either the value of the Origin header or the string "*" as value.
|
||||
// Note: The string "*" cannot be used for a resource that supports credentials.
|
||||
|
||||
let response = response.credentials(options.allow_credentials)?;
|
||||
// Validation has been done in options.validate
|
||||
let response = response.credentials(options.allow_credentials);
|
||||
|
||||
// 4. If the list of exposed headers is not empty add one or more
|
||||
// Access-Control-Expose-Headers headers, with as values the header field names given in
|
||||
|
@ -1033,65 +1030,15 @@ mod tests {
|
|||
cors.validate().unwrap();
|
||||
}
|
||||
|
||||
// The following tests check `Response`'s validation
|
||||
// The following tests check validation
|
||||
|
||||
#[test]
|
||||
fn response_allows_all_origin_with_wildcard() {
|
||||
fn validate_origin_allows_all_origins() {
|
||||
let url = "https://www.example.com";
|
||||
let origin = Origin::from_str(url).unwrap();
|
||||
let allowed_origins = AllOrSome::All;
|
||||
let send_wildcard = true;
|
||||
|
||||
let response = Response::new();
|
||||
let response = not_err!(response.allowed_origin(
|
||||
&origin,
|
||||
&allowed_origins,
|
||||
send_wildcard,
|
||||
));
|
||||
|
||||
assert_matches!(response.allow_origin, Some(AllOrSome::All));
|
||||
assert_eq!(response.vary_origin, false);
|
||||
|
||||
// Build response and check built response header
|
||||
let expected_header = vec!["*"];
|
||||
let response = response.build(response::Response::new());
|
||||
let actual_header: Vec<_> = response
|
||||
.headers()
|
||||
.get("Access-Control-Allow-Origin")
|
||||
.collect();
|
||||
assert_eq!(expected_header, actual_header);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn response_allows_all_origin_with_echoed_domain() {
|
||||
let url = "https://www.example.com";
|
||||
let origin = Origin::from_str(url).unwrap();
|
||||
let allowed_origins = AllOrSome::All;
|
||||
let send_wildcard = false;
|
||||
|
||||
let response = Response::new();
|
||||
let response = not_err!(response.allowed_origin(
|
||||
&origin,
|
||||
&allowed_origins,
|
||||
send_wildcard,
|
||||
));
|
||||
|
||||
let actual_origin = assert_matches!(
|
||||
response.allow_origin,
|
||||
Some(AllOrSome::Some(ref origin)),
|
||||
origin
|
||||
);
|
||||
assert_eq!(url, actual_origin);
|
||||
assert_eq!(response.vary_origin, true);
|
||||
|
||||
// Build response and check built response header
|
||||
let expected_header = vec![url];
|
||||
let response = response.build(response::Response::new());
|
||||
let actual_header: Vec<_> = response
|
||||
.headers()
|
||||
.get("Access-Control-Allow-Origin")
|
||||
.collect();
|
||||
assert_eq!(expected_header, actual_header);
|
||||
not_err!(validate_origin(&origin, &allowed_origins));
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
@ -1101,32 +1048,8 @@ mod tests {
|
|||
let (allowed_origins, failed_origins) =
|
||||
AllOrSome::new_from_str_list(&["https://www.example.com"]);
|
||||
assert!(failed_origins.is_empty());
|
||||
let send_wildcard = false;
|
||||
|
||||
let response = Response::new();
|
||||
let response = not_err!(response.allowed_origin(
|
||||
&origin,
|
||||
&allowed_origins,
|
||||
send_wildcard,
|
||||
));
|
||||
|
||||
let actual_origin = assert_matches!(
|
||||
response.allow_origin,
|
||||
Some(AllOrSome::Some(ref origin)),
|
||||
origin
|
||||
);
|
||||
|
||||
assert_eq!(url, actual_origin);
|
||||
assert_eq!(response.vary_origin, false);
|
||||
|
||||
// Build response and check built response header
|
||||
let expected_header = vec![url];
|
||||
let response = response.build(response::Response::new());
|
||||
let actual_header: Vec<_> = response
|
||||
.headers()
|
||||
.get("Access-Control-Allow-Origin")
|
||||
.collect();
|
||||
assert_eq!(expected_header, actual_header);
|
||||
not_err!(validate_origin(&origin, &allowed_origins));
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
@ -1137,41 +1060,8 @@ mod tests {
|
|||
let (allowed_origins, failed_origins) =
|
||||
AllOrSome::new_from_str_list(&["https://www.example.com"]);
|
||||
assert!(failed_origins.is_empty());
|
||||
let send_wildcard = false;
|
||||
|
||||
let response = Response::new();
|
||||
let _ = response
|
||||
.allowed_origin(&origin, &allowed_origins, send_wildcard)
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic(expected = "CredentialsWithWildcardOrigin")]
|
||||
fn response_credentials_does_not_allow_wildcard_with_all_origins() {
|
||||
let response = Response::new();
|
||||
let response = response.any();
|
||||
|
||||
let _ = response.credentials(true).unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn response_credentials_allows_specific_origins() {
|
||||
let response = Response::new();
|
||||
let response = response.origin("https://www.example.com", false);
|
||||
|
||||
let response = response.credentials(true).expect(
|
||||
"to allow specific origins",
|
||||
);
|
||||
assert_eq!(response.allow_credentials, true);
|
||||
|
||||
// Build response and check built response header
|
||||
let expected_header = vec!["true"];
|
||||
let response = response.build(response::Response::new());
|
||||
let actual_header: Vec<_> = response
|
||||
.headers()
|
||||
.get("Access-Control-Allow-Credentials")
|
||||
.collect();
|
||||
assert_eq!(expected_header, actual_header);
|
||||
validate_origin(&origin, &allowed_origins).unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
@ -1220,159 +1110,88 @@ mod tests {
|
|||
|
||||
// Build response and check built response header
|
||||
let response = response.build(response::Response::new());
|
||||
assert!(response
|
||||
.headers()
|
||||
.get("Access-Control-Max-Age")
|
||||
.next().is_none())
|
||||
}
|
||||
|
||||
/// When all headers are allowed, tests that the requested headers are echoed back
|
||||
#[test]
|
||||
fn response_allowed_headers_echoes_back_requested_headers() {
|
||||
let allowed_headers = AllOrSome::All;
|
||||
let requested_headers = vec!["Bar", "Foo"];
|
||||
|
||||
let response = Response::new();
|
||||
let response = response.origin("https://www.example.com", false);
|
||||
let response = response
|
||||
.allowed_headers(
|
||||
&FromStr::from_str(&requested_headers.join(",")).unwrap(),
|
||||
&allowed_headers,
|
||||
)
|
||||
.expect("to not fail");
|
||||
|
||||
// Build response and check built response header
|
||||
let response = response.build(response::Response::new());
|
||||
let actual_header: Vec<_> = response
|
||||
.headers()
|
||||
.get("Access-Control-Allow-Headers")
|
||||
.collect();
|
||||
|
||||
assert_eq!(1, actual_header.len());
|
||||
let mut actual_headers: Vec<String> = actual_header[0]
|
||||
.split(',')
|
||||
.map(|header| header.trim().to_string())
|
||||
.collect();
|
||||
actual_headers.sort();
|
||||
assert_eq!(requested_headers, actual_headers);
|
||||
assert!(
|
||||
response
|
||||
.headers()
|
||||
.get("Access-Control-Max-Age")
|
||||
.next()
|
||||
.is_none()
|
||||
)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn response_allowed_methods_sets_headers_properly() {
|
||||
let allowed_methods = vec![
|
||||
Method::Get,
|
||||
Method::Head,
|
||||
Method::Post,
|
||||
].into_iter()
|
||||
fn allowed_methods_validated_correctly() {
|
||||
let allowed_methods = vec![Method::Get, Method::Head, Method::Post]
|
||||
.into_iter()
|
||||
.collect();
|
||||
|
||||
let method = "GET";
|
||||
|
||||
let response = Response::new();
|
||||
let response = response.origin("https://www.example.com", false);
|
||||
let response = response
|
||||
.allowed_methods(
|
||||
&FromStr::from_str(method).expect("not to fail"),
|
||||
&allowed_methods,
|
||||
)
|
||||
.expect("not to fail");
|
||||
|
||||
// Build response and check built response header
|
||||
let response = response.build(response::Response::new());
|
||||
let actual_header: Vec<_> = response
|
||||
.headers()
|
||||
.get("Access-Control-Allow-Methods")
|
||||
.collect();
|
||||
|
||||
assert_eq!(1, actual_header.len());
|
||||
let mut actual_headers: Vec<String> = actual_header[0]
|
||||
.split(',')
|
||||
.map(|header| header.trim().to_string())
|
||||
.collect();
|
||||
actual_headers.sort();
|
||||
let mut expected_headers: Vec<_> = allowed_methods.iter().map(|m| m.as_str()).collect();
|
||||
expected_headers.sort();
|
||||
assert_eq!(expected_headers, actual_headers);
|
||||
not_err!(validate_allowed_method(
|
||||
&FromStr::from_str(method).expect("not to fail"),
|
||||
&allowed_methods,
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic(expected = "MethodNotAllowed")]
|
||||
fn response_allowed_method_errors_on_disallowed_method() {
|
||||
let allowed_methods = vec![
|
||||
Method::Get,
|
||||
Method::Head,
|
||||
Method::Post,
|
||||
].into_iter()
|
||||
fn allowed_methods_errors_on_disallowed_method() {
|
||||
let allowed_methods = vec![Method::Get, Method::Head, Method::Post]
|
||||
.into_iter()
|
||||
.collect();
|
||||
|
||||
let method = "DELETE";
|
||||
|
||||
let response = Response::new();
|
||||
let response = response.origin("https://www.example.com", false);
|
||||
let _ = response
|
||||
.allowed_methods(
|
||||
&FromStr::from_str(method).expect("not to fail"),
|
||||
&allowed_methods,
|
||||
)
|
||||
.unwrap();
|
||||
validate_allowed_method(
|
||||
&FromStr::from_str(method).expect("not to fail"),
|
||||
&allowed_methods,
|
||||
).unwrap()
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn all_allowed_headers_are_validated_correctly() {
|
||||
let allowed_headers = AllOrSome::All;
|
||||
let requested_headers = vec!["Bar", "Foo"];
|
||||
|
||||
not_err!(validate_allowed_headers(
|
||||
&FromStr::from_str(&requested_headers.join(",")).unwrap(),
|
||||
&allowed_headers,
|
||||
));
|
||||
}
|
||||
|
||||
/// `Response::allowed_headers` should check that headers are allowed, and only
|
||||
/// echoes back the list that is actually requested for and not the whole list
|
||||
#[test]
|
||||
fn response_allowed_headers_validates_and_echoes_requested_headers() {
|
||||
fn allowed_headers_are_validated_correctly() {
|
||||
let allowed_headers = vec!["Bar", "Baz", "Foo"];
|
||||
let requested_headers = vec!["Bar", "Foo"];
|
||||
|
||||
let response = Response::new();
|
||||
let response = response.origin("https://www.example.com", false);
|
||||
let response = response
|
||||
.allowed_headers(
|
||||
&FromStr::from_str(&requested_headers.join(",")).unwrap(),
|
||||
&AllOrSome::Some(
|
||||
allowed_headers
|
||||
.iter()
|
||||
.map(|s| FromStr::from_str(*s).unwrap())
|
||||
.collect(),
|
||||
),
|
||||
)
|
||||
.expect("to not fail");
|
||||
|
||||
// Build response and check built response header
|
||||
let response = response.build(response::Response::new());
|
||||
let actual_header: Vec<_> = response
|
||||
.headers()
|
||||
.get("Access-Control-Allow-Headers")
|
||||
.collect();
|
||||
|
||||
assert_eq!(1, actual_header.len());
|
||||
let mut actual_headers: Vec<String> = actual_header[0]
|
||||
.split(',')
|
||||
.map(|header| header.trim().to_string())
|
||||
.collect();
|
||||
actual_headers.sort();
|
||||
assert_eq!(requested_headers, actual_headers);
|
||||
not_err!(validate_allowed_headers(
|
||||
&FromStr::from_str(&requested_headers.join(",")).unwrap(),
|
||||
&AllOrSome::Some(
|
||||
allowed_headers
|
||||
.iter()
|
||||
.map(|s| FromStr::from_str(*s).unwrap())
|
||||
.collect(),
|
||||
),
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic(expected = "HeadersNotAllowed")]
|
||||
fn response_allowed_headers_errors_on_non_subset() {
|
||||
fn allowed_headers_errors_on_non_subset() {
|
||||
let allowed_headers = vec!["Bar", "Baz", "Foo"];
|
||||
let requested_headers = vec!["Bar", "Foo", "Unknown"];
|
||||
|
||||
let response = Response::new();
|
||||
let response = response.origin("https://www.example.com", false);
|
||||
let _ = response
|
||||
.allowed_headers(
|
||||
&FromStr::from_str(&requested_headers.join(",")).unwrap(),
|
||||
&AllOrSome::Some(
|
||||
allowed_headers
|
||||
.iter()
|
||||
.map(|s| FromStr::from_str(*s).unwrap())
|
||||
.collect(),
|
||||
),
|
||||
)
|
||||
.unwrap();
|
||||
validate_allowed_headers(
|
||||
&FromStr::from_str(&requested_headers.join(",")).unwrap(),
|
||||
&AllOrSome::Some(
|
||||
allowed_headers
|
||||
.iter()
|
||||
.map(|s| FromStr::from_str(*s).unwrap())
|
||||
.collect(),
|
||||
),
|
||||
).unwrap();
|
||||
|
||||
}
|
||||
|
||||
|
@ -1413,42 +1232,19 @@ mod tests {
|
|||
assert_eq!(expected_header, actual_header);
|
||||
|
||||
// Check that `Access-Control-Max-Age` is removed
|
||||
assert!(response.headers().get("Access-Control-Max-Age").next().is_none());
|
||||
assert!(
|
||||
response
|
||||
.headers()
|
||||
.get("Access-Control-Max-Age")
|
||||
.next()
|
||||
.is_none()
|
||||
);
|
||||
|
||||
|
||||
}
|
||||
|
||||
// The following tests check that preflight checks are done properly
|
||||
// TODO: Preflight tests
|
||||
// TODO: Actual requests tests
|
||||
|
||||
// fn make_cors_options() -> Cors {
|
||||
// let (allowed_origins, failed_origins) =
|
||||
// AllOrSome::new_from_str_list(&["https://www.acme.com"]);
|
||||
// assert!(failed_origins.is_empty());
|
||||
|
||||
// Cors {
|
||||
// allowed_origins: allowed_origins,
|
||||
// allowed_methods: [Method::Get].iter().cloned().collect(),
|
||||
// allowed_headers: AllOrSome::Some(
|
||||
// ["Authorization"]
|
||||
// .into_iter()
|
||||
// .map(|s| s.to_string().into())
|
||||
// .collect(),
|
||||
// ),
|
||||
// allow_credentials: true,
|
||||
// ..Default::default()
|
||||
// }
|
||||
// }
|
||||
|
||||
// /// Tests that non CORS preflight are let through without modification
|
||||
// #[test]
|
||||
// fn preflight_missing_origins_are_let_through() {
|
||||
// let options = make_cors_options();
|
||||
// let client = make_client();
|
||||
// let request = client.get("/");
|
||||
|
||||
// let response = options.preflight((), None, None, None).expect("not to fail");
|
||||
|
||||
// let headers: Vec<_> = response.headers().iter().collect();
|
||||
// assert_eq!(headers.len(), 0);
|
||||
// }
|
||||
// Origin all (wildcard + echoed with Vary). Origin Echo
|
||||
}
|
||||
|
|
|
@ -9,7 +9,7 @@ use std::ops::Deref;
|
|||
use std::str::FromStr;
|
||||
|
||||
use rocket::local::Client;
|
||||
use rocket::http::{Header, Status};
|
||||
use rocket::http::Header;
|
||||
use rocket_cors::headers::*;
|
||||
|
||||
#[get("/request_headers")]
|
||||
|
|
Loading…
Reference in New Issue