Response unit tests
This commit is contained in:
parent
ca096ceb28
commit
f1391281cd
468
src/lib.rs
468
src/lib.rs
|
@ -540,10 +540,7 @@ impl Options {
|
|||
// 5. If method is not a case-sensitive match for any of the values in list of methods
|
||||
// do not set any additional headers and terminate this set of steps.
|
||||
|
||||
let response = response.allowed_methods(
|
||||
&method,
|
||||
self.allowed_methods.clone(),
|
||||
)?;
|
||||
let response = response.allowed_methods(&method, &self.allowed_methods)?;
|
||||
|
||||
// 6. If any of the header field-names is not a ASCII case-insensitive match for any of the
|
||||
// values in list of headers do not set any additional headers and terminate this set of
|
||||
|
@ -672,6 +669,7 @@ impl Options {
|
|||
/// - `Access-Control-Allow-Methods`
|
||||
/// - `Access-Control-Allow-Headers`
|
||||
/// - `Vary`
|
||||
#[derive(Debug)]
|
||||
pub struct Response<R> {
|
||||
responder: R,
|
||||
allow_origin: Option<AllOrSome<String>>,
|
||||
|
@ -706,8 +704,9 @@ impl<'r, R: Responder<'r>> Response<R> {
|
|||
}
|
||||
|
||||
/// Consumes the `Response` and return an altered response with origin set to "*"
|
||||
fn any(self) -> Self {
|
||||
self.origin("*", false)
|
||||
fn any(mut self) -> Self {
|
||||
self.allow_origin = Some(AllOrSome::All);
|
||||
self
|
||||
}
|
||||
|
||||
/// Consumes the responder and based on the provided list of allowed origins,
|
||||
|
@ -770,8 +769,8 @@ impl<'r, R: Responder<'r>> Response<R> {
|
|||
|
||||
/// Consumes the CORS, set allow_methods to
|
||||
/// passed methods and returns changed CORS
|
||||
fn methods(mut self, methods: HashSet<Method>) -> Self {
|
||||
self.allow_methods = methods;
|
||||
fn methods(mut self, methods: &HashSet<Method>) -> Self {
|
||||
self.allow_methods = methods.clone();
|
||||
self
|
||||
}
|
||||
|
||||
|
@ -780,7 +779,7 @@ impl<'r, R: Responder<'r>> Response<R> {
|
|||
fn allowed_methods(
|
||||
self,
|
||||
method: &AccessControlRequestMethod,
|
||||
allowed_methods: HashSet<Method>,
|
||||
allowed_methods: &HashSet<Method>,
|
||||
) -> Result<Self, Error> {
|
||||
let &AccessControlRequestMethod(ref request_method) = method;
|
||||
if !allowed_methods.iter().any(|m| m == request_method) {
|
||||
|
@ -788,7 +787,7 @@ impl<'r, R: Responder<'r>> Response<R> {
|
|||
}
|
||||
|
||||
// TODO: Subset to route? Or just the method requested for?
|
||||
Ok(self.methods(allowed_methods))
|
||||
Ok(self.methods(&allowed_methods))
|
||||
}
|
||||
|
||||
/// Consumes the CORS, set allow_headers to
|
||||
|
@ -871,7 +870,6 @@ impl<'r, R: Responder<'r>> Response<R> {
|
|||
builder.raw_header("Access-Control-Allow-Headers", headers);
|
||||
}
|
||||
|
||||
|
||||
if !self.allow_methods.is_empty() {
|
||||
let methods: Vec<_> = self.allow_methods.iter().map(|m| m.as_str()).collect();
|
||||
let methods = methods.join(", ");
|
||||
|
@ -951,7 +949,6 @@ mod tests {
|
|||
use rocket;
|
||||
use rocket::local::Client;
|
||||
use rocket::http::Method;
|
||||
use rocket::http::Status;
|
||||
|
||||
use super::*;
|
||||
|
||||
|
@ -961,6 +958,8 @@ mod tests {
|
|||
Client::new(rocket).expect("valid rocket instance")
|
||||
}
|
||||
|
||||
// The following tests check that CORS Request headers are parsed correctly
|
||||
|
||||
#[test]
|
||||
fn origin_header_conversion() {
|
||||
let url = "https://foo.bar.xyz";
|
||||
|
@ -1051,26 +1050,443 @@ mod tests {
|
|||
);
|
||||
}
|
||||
|
||||
#[get("/any")]
|
||||
fn any() -> Response<&'static str> {
|
||||
Response::new("Hello, world!").any()
|
||||
}
|
||||
// The following tests check `Response`'s validation
|
||||
|
||||
#[test]
|
||||
fn response_any_origin_smoke_test() {
|
||||
let rocket = rocket::ignite().mount("/", routes![any]);
|
||||
let client = not_err!(Client::new(rocket));
|
||||
fn response_allows_all_origin_with_wildcard() {
|
||||
let url = "https://www.example.com";
|
||||
let origin = Origin::from_str(url).unwrap();
|
||||
let allowed_origins = AllOrSome::All;
|
||||
let send_wildcard = true;
|
||||
|
||||
let req = client.get("/any");
|
||||
let mut response = req.dispatch();
|
||||
let response = Response::new(());
|
||||
let response = not_err!(response.allowed_origin(
|
||||
&origin,
|
||||
&allowed_origins,
|
||||
send_wildcard,
|
||||
));
|
||||
|
||||
assert_eq!(Status::Ok, response.status());
|
||||
let body_str = response.body().and_then(|body| body.into_string());
|
||||
let values: Vec<_> = response
|
||||
assert_matches!(response.allow_origin, Some(AllOrSome::All));
|
||||
assert_eq!(response.vary_origin, false);
|
||||
|
||||
// Build response and check built response header
|
||||
let expected_header = vec!["*"];
|
||||
let response = response.build();
|
||||
let actual_header: Vec<_> = response
|
||||
.headers()
|
||||
.get("Access-Control-Allow-Origin")
|
||||
.collect();
|
||||
assert_eq!(values, vec!["*"]);
|
||||
assert_eq!(body_str, Some("Hello, world!".to_string()));
|
||||
assert_eq!(expected_header, actual_header);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn response_allows_all_origin_with_echoed_domain() {
|
||||
let url = "https://www.example.com";
|
||||
let origin = Origin::from_str(url).unwrap();
|
||||
let allowed_origins = AllOrSome::All;
|
||||
let send_wildcard = false;
|
||||
|
||||
let response = Response::new(());
|
||||
let response = not_err!(response.allowed_origin(
|
||||
&origin,
|
||||
&allowed_origins,
|
||||
send_wildcard,
|
||||
));
|
||||
|
||||
let actual_origin = assert_matches!(
|
||||
response.allow_origin,
|
||||
Some(AllOrSome::Some(ref origin)),
|
||||
origin
|
||||
);
|
||||
assert_eq!(url, actual_origin);
|
||||
assert_eq!(response.vary_origin, true);
|
||||
|
||||
// Build response and check built response header
|
||||
let expected_header = vec![url];
|
||||
let response = response.build();
|
||||
let actual_header: Vec<_> = response
|
||||
.headers()
|
||||
.get("Access-Control-Allow-Origin")
|
||||
.collect();
|
||||
assert_eq!(expected_header, actual_header);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn response_allows_origin() {
|
||||
let url = "https://www.example.com";
|
||||
let origin = Origin::from_str(url).unwrap();
|
||||
let (allowed_origins, failed_origins) =
|
||||
AllOrSome::new_from_str_list(&["https://www.example.com"]);
|
||||
assert!(failed_origins.is_empty());
|
||||
let send_wildcard = false;
|
||||
|
||||
let response = Response::new(());
|
||||
let response = not_err!(response.allowed_origin(
|
||||
&origin,
|
||||
&allowed_origins,
|
||||
send_wildcard,
|
||||
));
|
||||
|
||||
let actual_origin = assert_matches!(
|
||||
response.allow_origin,
|
||||
Some(AllOrSome::Some(ref origin)),
|
||||
origin
|
||||
);
|
||||
|
||||
assert_eq!(url, actual_origin);
|
||||
assert_eq!(response.vary_origin, false);
|
||||
|
||||
// Build response and check built response header
|
||||
let expected_header = vec![url];
|
||||
let response = response.build();
|
||||
let actual_header: Vec<_> = response
|
||||
.headers()
|
||||
.get("Access-Control-Allow-Origin")
|
||||
.collect();
|
||||
assert_eq!(expected_header, actual_header);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic(expected = "OriginNotAllowed")]
|
||||
fn response_rejects_invalid_origin() {
|
||||
let url = "https://www.acme.com";
|
||||
let origin = Origin::from_str(url).unwrap();
|
||||
let (allowed_origins, failed_origins) =
|
||||
AllOrSome::new_from_str_list(&["https://www.example.com"]);
|
||||
assert!(failed_origins.is_empty());
|
||||
let send_wildcard = false;
|
||||
|
||||
let response = Response::new(());
|
||||
let _ = response
|
||||
.allowed_origin(&origin, &allowed_origins, send_wildcard)
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic(expected = "CredentialsWithWildcardOrigin")]
|
||||
fn response_credentials_does_not_allow_wildcard_with_all_origins() {
|
||||
let response = Response::new(());
|
||||
let response = response.any();
|
||||
|
||||
let _ = response.credentials(true).unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn response_credentials_allows_specific_origins() {
|
||||
let response = Response::new(());
|
||||
let response = response.origin("https://www.example.com", false);
|
||||
|
||||
let response = response.credentials(true).expect(
|
||||
"to allow specific origins",
|
||||
);
|
||||
assert_eq!(response.allow_credentials, true);
|
||||
|
||||
// Build response and check built response header
|
||||
let expected_header = vec!["true"];
|
||||
let response = response.build();
|
||||
let actual_header: Vec<_> = response
|
||||
.headers()
|
||||
.get("Access-Control-Allow-Credentials")
|
||||
.collect();
|
||||
assert_eq!(expected_header, actual_header);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn response_sets_exposed_headers_correctly() {
|
||||
let headers = vec!["Bar", "Baz", "Foo"];
|
||||
let response = Response::new(());
|
||||
let response = response.origin("https://www.example.com", false);
|
||||
let response = response.exposed_headers(&headers);
|
||||
|
||||
// Build response and check built response header
|
||||
let response = response.build();
|
||||
let actual_header: Vec<_> = response
|
||||
.headers()
|
||||
.get("Access-Control-Expose-Headers")
|
||||
.collect();
|
||||
|
||||
assert_eq!(1, actual_header.len());
|
||||
let mut actual_headers: Vec<String> = actual_header[0]
|
||||
.split(',')
|
||||
.map(|header| header.trim().to_string())
|
||||
.collect();
|
||||
actual_headers.sort();
|
||||
assert_eq!(headers, actual_headers);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn response_sets_max_age_correctly() {
|
||||
let response = Response::new(());
|
||||
let response = response.origin("https://www.example.com", false);
|
||||
|
||||
let response = response.max_age(Some(42));
|
||||
|
||||
// Build response and check built response header
|
||||
let expected_header = vec!["42"];
|
||||
let response = response.build();
|
||||
let actual_header: Vec<_> = response.headers().get("Access-Control-Max-Age").collect();
|
||||
assert_eq!(expected_header, actual_header);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn response_does_not_set_max_age_when_none() {
|
||||
let response = Response::new(());
|
||||
let response = response.origin("https://www.example.com", false);
|
||||
|
||||
let response = response.max_age(None);
|
||||
|
||||
// Build response and check built response header
|
||||
let response = response.build();
|
||||
assert!(response
|
||||
.headers()
|
||||
.get("Access-Control-Max-Age")
|
||||
.next().is_none())
|
||||
}
|
||||
|
||||
/// When all headers are allowed, tests that the requested headers are echoed back
|
||||
#[test]
|
||||
fn response_allowed_headers_echoes_back_requested_headers() {
|
||||
let allowed_headers = AllOrSome::All;
|
||||
let requested_headers = vec!["Bar", "Foo"];
|
||||
|
||||
let response = Response::new(());
|
||||
let response = response.origin("https://www.example.com", false);
|
||||
let response = response
|
||||
.allowed_headers(
|
||||
&FromStr::from_str(&requested_headers.join(",")).unwrap(),
|
||||
&allowed_headers,
|
||||
)
|
||||
.expect("to not fail");
|
||||
|
||||
// Build response and check built response header
|
||||
let response = response.build();
|
||||
let actual_header: Vec<_> = response
|
||||
.headers()
|
||||
.get("Access-Control-Allow-Headers")
|
||||
.collect();
|
||||
|
||||
assert_eq!(1, actual_header.len());
|
||||
let mut actual_headers: Vec<String> = actual_header[0]
|
||||
.split(',')
|
||||
.map(|header| header.trim().to_string())
|
||||
.collect();
|
||||
actual_headers.sort();
|
||||
assert_eq!(requested_headers, actual_headers);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn response_allowed_methods_sets_headers_properly() {
|
||||
let allowed_methods = vec![
|
||||
Method::Get,
|
||||
Method::Head,
|
||||
Method::Post,
|
||||
].into_iter()
|
||||
.collect();
|
||||
|
||||
let method = "GET";
|
||||
|
||||
let response = Response::new(());
|
||||
let response = response.origin("https://www.example.com", false);
|
||||
let response = response
|
||||
.allowed_methods(
|
||||
&FromStr::from_str(method).expect("not to fail"),
|
||||
&allowed_methods,
|
||||
)
|
||||
.expect("not to fail");
|
||||
|
||||
// Build response and check built response header
|
||||
let response = response.build();
|
||||
let actual_header: Vec<_> = response
|
||||
.headers()
|
||||
.get("Access-Control-Allow-Methods")
|
||||
.collect();
|
||||
|
||||
assert_eq!(1, actual_header.len());
|
||||
let mut actual_headers: Vec<String> = actual_header[0]
|
||||
.split(',')
|
||||
.map(|header| header.trim().to_string())
|
||||
.collect();
|
||||
actual_headers.sort();
|
||||
let mut expected_headers: Vec<_> = allowed_methods.iter().map(|m| m.as_str()).collect();
|
||||
expected_headers.sort();
|
||||
assert_eq!(expected_headers, actual_headers);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic(expected = "MethodNotAllowed")]
|
||||
fn response_allowed_method_errors_on_disallowed_method() {
|
||||
let allowed_methods = vec![
|
||||
Method::Get,
|
||||
Method::Head,
|
||||
Method::Post,
|
||||
].into_iter()
|
||||
.collect();
|
||||
|
||||
let method = "DELETE";
|
||||
|
||||
let response = Response::new(());
|
||||
let response = response.origin("https://www.example.com", false);
|
||||
let _ = response
|
||||
.allowed_methods(
|
||||
&FromStr::from_str(method).expect("not to fail"),
|
||||
&allowed_methods,
|
||||
)
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
/// `Response::allowed_headers` should check that headers are allowed, and only
|
||||
/// echoes back the list that is actually requested for and not the whole list
|
||||
#[test]
|
||||
fn response_allowed_headers_validates_and_echoes_requested_headers() {
|
||||
let allowed_headers = vec!["Bar", "Baz", "Foo"];
|
||||
let requested_headers = vec!["Bar", "Foo"];
|
||||
|
||||
let response = Response::new(());
|
||||
let response = response.origin("https://www.example.com", false);
|
||||
let response = response
|
||||
.allowed_headers(
|
||||
&FromStr::from_str(&requested_headers.join(",")).unwrap(),
|
||||
&AllOrSome::Some(
|
||||
allowed_headers
|
||||
.iter()
|
||||
.map(|s| FromStr::from_str(*s).unwrap())
|
||||
.collect(),
|
||||
),
|
||||
)
|
||||
.expect("to not fail");
|
||||
|
||||
// Build response and check built response header
|
||||
let response = response.build();
|
||||
let actual_header: Vec<_> = response
|
||||
.headers()
|
||||
.get("Access-Control-Allow-Headers")
|
||||
.collect();
|
||||
|
||||
assert_eq!(1, actual_header.len());
|
||||
let mut actual_headers: Vec<String> = actual_header[0]
|
||||
.split(',')
|
||||
.map(|header| header.trim().to_string())
|
||||
.collect();
|
||||
actual_headers.sort();
|
||||
assert_eq!(requested_headers, actual_headers);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic(expected = "HeadersNotAllowed")]
|
||||
fn response_allowed_headers_errors_on_non_subset() {
|
||||
let allowed_headers = vec!["Bar", "Baz", "Foo"];
|
||||
let requested_headers = vec!["Bar", "Foo", "Unknown"];
|
||||
|
||||
let response = Response::new(());
|
||||
let response = response.origin("https://www.example.com", false);
|
||||
let _ = response
|
||||
.allowed_headers(
|
||||
&FromStr::from_str(&requested_headers.join(",")).unwrap(),
|
||||
&AllOrSome::Some(
|
||||
allowed_headers
|
||||
.iter()
|
||||
.map(|s| FromStr::from_str(*s).unwrap())
|
||||
.collect(),
|
||||
),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn response_does_not_build_if_origin_is_not_set() {
|
||||
let response = Response::new(());
|
||||
let response = response.build();
|
||||
|
||||
let headers: Vec<_> = response.headers().iter().collect();
|
||||
assert_eq!(headers.len(), 0);
|
||||
}
|
||||
|
||||
// Note: Correct operation of Response::build is tested in the tests above for each of the
|
||||
// individual headers
|
||||
|
||||
#[test]
|
||||
fn response_merges_correctly() {
|
||||
use std::io::Cursor;
|
||||
use rocket::http::Status;
|
||||
|
||||
let wrapped = response::Response::build()
|
||||
.status(Status::ImATeapot)
|
||||
.raw_header("X-Teapot-Make", "Rocket")
|
||||
.sized_body(Cursor::new("Brewing the best coffee!"))
|
||||
.finalize();
|
||||
|
||||
let response = Response::new(());
|
||||
let response = response.origin("https://www.acme.com", false);
|
||||
|
||||
let mut response = Response::<String>::merge(wrapped, response.build());
|
||||
assert_eq!(response.status(), Status::ImATeapot);
|
||||
assert_eq!(response.body_string(), Some("Brewing the best coffee!".to_string()));
|
||||
|
||||
// Check CORS header
|
||||
let expected_header = vec!["https://www.acme.com"];
|
||||
let actual_header: Vec<_> = response
|
||||
.headers()
|
||||
.get("Access-Control-Allow-Origin")
|
||||
.collect();
|
||||
assert_eq!(expected_header, actual_header);
|
||||
|
||||
// Check other header
|
||||
let expected_header = vec!["Rocket"];
|
||||
let actual_header: Vec<_> = response.headers().get("X-Teapot-Make").collect();
|
||||
assert_eq!(expected_header, actual_header);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn response_does_not_merge_existing_cors() {
|
||||
let wrapped = response::Response::build()
|
||||
.raw_header("Access-Control-Allow-Origin", "https://www.example.com")
|
||||
.finalize();
|
||||
|
||||
let response = Response::new(());
|
||||
let response = response.origin("https://www.acme.com", false);
|
||||
|
||||
let response = Response::<()>::merge(wrapped, response.build());
|
||||
let expected_header = vec!["https://www.example.com"];
|
||||
let actual_header: Vec<_> = response
|
||||
.headers()
|
||||
.get("Access-Control-Allow-Origin")
|
||||
.collect();
|
||||
assert_eq!(expected_header, actual_header);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn response_finalize_smoke_test() {
|
||||
use std::io::Cursor;
|
||||
use rocket::http::Status;
|
||||
|
||||
let wrapped = response::Response::build()
|
||||
.status(Status::ImATeapot)
|
||||
.raw_header("X-Teapot-Make", "Rocket")
|
||||
.sized_body(Cursor::new("Brewing the best coffee!"))
|
||||
.finalize();
|
||||
|
||||
let response = Response::new(wrapped);
|
||||
let response = response.origin("https://www.acme.com", false);
|
||||
|
||||
let client = make_client();
|
||||
let request = client.get("/");
|
||||
let mut response = response.finalize(request.inner()).expect("not to fail");
|
||||
|
||||
assert_eq!(response.status(), Status::ImATeapot);
|
||||
assert_eq!(response.body_string(), Some("Brewing the best coffee!".to_string()));
|
||||
|
||||
// Check CORS header
|
||||
let expected_header = vec!["https://www.acme.com"];
|
||||
let actual_header: Vec<_> = response
|
||||
.headers()
|
||||
.get("Access-Control-Allow-Origin")
|
||||
.collect();
|
||||
assert_eq!(expected_header, actual_header);
|
||||
|
||||
// Check other header
|
||||
let expected_header = vec!["Rocket"];
|
||||
let actual_header: Vec<_> = response.headers().get("X-Teapot-Make").collect();
|
||||
assert_eq!(expected_header, actual_header);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -28,7 +28,7 @@ fn request_headers(
|
|||
format!("{}\n{}\n{}", origin, method, headers.join(", "))
|
||||
}
|
||||
|
||||
/// Tests that all the headers are parsed correcly in a HTTP request
|
||||
/// Tests that all the request headers are parsed correcly in a HTTP request
|
||||
#[test]
|
||||
fn request_headers_round_trip_smoke_test() {
|
||||
let rocket = rocket::ignite().mount("/", routes![request_headers]);
|
||||
|
|
Loading…
Reference in New Issue