Merge pull request #5 from lawliet89/preserve-existing
Preserve existing CORS requests
This commit is contained in:
commit
16b89ab31c
664
src/lib.rs
664
src/lib.rs
|
@ -261,6 +261,9 @@ impl<'a, 'r> FromRequest<'a, 'r> for Url {
|
||||||
|
|
||||||
|
|
||||||
/// The `Origin` request header used in CORS
|
/// The `Origin` request header used in CORS
|
||||||
|
///
|
||||||
|
/// You can use this as a rocket [Request Guard](https://rocket.rs/guide/requests/#request-guards)
|
||||||
|
/// to ensure that Origins are passed in correctly.
|
||||||
pub type Origin = Url;
|
pub type Origin = Url;
|
||||||
|
|
||||||
/// The `Access-Control-Request-Method` request header
|
/// The `Access-Control-Request-Method` request header
|
||||||
|
@ -537,10 +540,7 @@ impl Options {
|
||||||
// 5. If method is not a case-sensitive match for any of the values in list of methods
|
// 5. If method is not a case-sensitive match for any of the values in list of methods
|
||||||
// do not set any additional headers and terminate this set of steps.
|
// do not set any additional headers and terminate this set of steps.
|
||||||
|
|
||||||
let response = response.allowed_methods(
|
let response = response.allowed_methods(&method, &self.allowed_methods)?;
|
||||||
&method,
|
|
||||||
self.allowed_methods.clone(),
|
|
||||||
)?;
|
|
||||||
|
|
||||||
// 6. If any of the header field-names is not a ASCII case-insensitive match for any of the
|
// 6. If any of the header field-names is not a ASCII case-insensitive match for any of the
|
||||||
// values in list of headers do not set any additional headers and terminate this set of
|
// values in list of headers do not set any additional headers and terminate this set of
|
||||||
|
@ -654,6 +654,22 @@ impl Options {
|
||||||
/// A CORS Response which wraps another struct which implements `Responder`. You will typically
|
/// A CORS Response which wraps another struct which implements `Responder`. You will typically
|
||||||
/// use [`Options`] instead to verify and build the response instead of this directly.
|
/// use [`Options`] instead to verify and build the response instead of this directly.
|
||||||
/// See module level documentation for usage examples.
|
/// See module level documentation for usage examples.
|
||||||
|
///
|
||||||
|
/// If the wrapped `Responder` already has the `Access-Control-Allow-Origin` header set,
|
||||||
|
/// this responder will leave the response untouched.
|
||||||
|
/// This allows for chaining of several CORS responders.
|
||||||
|
///
|
||||||
|
/// Otherwise, the following headers may be set for the final Rocket `Response`, overwriting any
|
||||||
|
/// existing headers defined:
|
||||||
|
///
|
||||||
|
/// - `Access-Control-Allow-Origin`
|
||||||
|
/// - `Access-Control-Expose-Headers`
|
||||||
|
/// - `Access-Control-Max-Age`
|
||||||
|
/// - `Access-Control-Allow-Credentials`
|
||||||
|
/// - `Access-Control-Allow-Methods`
|
||||||
|
/// - `Access-Control-Allow-Headers`
|
||||||
|
/// - `Vary`
|
||||||
|
#[derive(Debug)]
|
||||||
pub struct Response<R> {
|
pub struct Response<R> {
|
||||||
responder: R,
|
responder: R,
|
||||||
allow_origin: Option<AllOrSome<String>>,
|
allow_origin: Option<AllOrSome<String>>,
|
||||||
|
@ -688,8 +704,9 @@ impl<'r, R: Responder<'r>> Response<R> {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Consumes the `Response` and return an altered response with origin set to "*"
|
/// Consumes the `Response` and return an altered response with origin set to "*"
|
||||||
fn any(self) -> Self {
|
fn any(mut self) -> Self {
|
||||||
self.origin("*", false)
|
self.allow_origin = Some(AllOrSome::All);
|
||||||
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Consumes the responder and based on the provided list of allowed origins,
|
/// Consumes the responder and based on the provided list of allowed origins,
|
||||||
|
@ -752,8 +769,8 @@ impl<'r, R: Responder<'r>> Response<R> {
|
||||||
|
|
||||||
/// Consumes the CORS, set allow_methods to
|
/// Consumes the CORS, set allow_methods to
|
||||||
/// passed methods and returns changed CORS
|
/// passed methods and returns changed CORS
|
||||||
fn methods(mut self, methods: HashSet<Method>) -> Self {
|
fn methods(mut self, methods: &HashSet<Method>) -> Self {
|
||||||
self.allow_methods = methods;
|
self.allow_methods = methods.clone();
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -762,7 +779,7 @@ impl<'r, R: Responder<'r>> Response<R> {
|
||||||
fn allowed_methods(
|
fn allowed_methods(
|
||||||
self,
|
self,
|
||||||
method: &AccessControlRequestMethod,
|
method: &AccessControlRequestMethod,
|
||||||
allowed_methods: HashSet<Method>,
|
allowed_methods: &HashSet<Method>,
|
||||||
) -> Result<Self, Error> {
|
) -> Result<Self, Error> {
|
||||||
let &AccessControlRequestMethod(ref request_method) = method;
|
let &AccessControlRequestMethod(ref request_method) = method;
|
||||||
if !allowed_methods.iter().any(|m| m == request_method) {
|
if !allowed_methods.iter().any(|m| m == request_method) {
|
||||||
|
@ -770,7 +787,7 @@ impl<'r, R: Responder<'r>> Response<R> {
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: Subset to route? Or just the method requested for?
|
// TODO: Subset to route? Or just the method requested for?
|
||||||
Ok(self.methods(allowed_methods))
|
Ok(self.methods(&allowed_methods))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Consumes the CORS, set allow_headers to
|
/// Consumes the CORS, set allow_headers to
|
||||||
|
@ -808,26 +825,23 @@ impl<'r, R: Responder<'r>> Response<R> {
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
impl<'r, R: Responder<'r>> Responder<'r> for Response<R> {
|
/// Builds a `rocket::Response` from this struct containing only the CORS headers.
|
||||||
#[allow(unused_results)]
|
#[allow(unused_results)]
|
||||||
fn respond_to(self, request: &Request) -> response::Result<'r> {
|
fn build(&self) -> response::Response<'r> {
|
||||||
use std::borrow::Cow;
|
let mut builder = response::Response::build();
|
||||||
|
|
||||||
let mut builder = response::Response::build_from(self.responder.respond_to(request)?);
|
|
||||||
|
|
||||||
let origin = match self.allow_origin {
|
let origin = match self.allow_origin {
|
||||||
None => {
|
None => {
|
||||||
// This is not a CORS response
|
// This is not a CORS response
|
||||||
return Ok(builder.finalize());
|
return builder.finalize();
|
||||||
}
|
}
|
||||||
Some(origin) => origin,
|
Some(ref origin) => origin,
|
||||||
};
|
};
|
||||||
|
|
||||||
let origin: Cow<str> = match origin {
|
let origin = match *origin {
|
||||||
AllOrSome::All => Into::into("*"),
|
AllOrSome::All => "*".to_string(),
|
||||||
AllOrSome::Some(origin) => Into::into(origin),
|
AllOrSome::Some(ref origin) => origin.to_string(),
|
||||||
};
|
};
|
||||||
|
|
||||||
builder.raw_header("Access-Control-Allow-Origin", origin);
|
builder.raw_header("Access-Control-Allow-Origin", origin);
|
||||||
|
@ -838,7 +852,7 @@ impl<'r, R: Responder<'r>> Responder<'r> for Response<R> {
|
||||||
|
|
||||||
if !self.expose_headers.is_empty() {
|
if !self.expose_headers.is_empty() {
|
||||||
let headers: Vec<String> = self.expose_headers
|
let headers: Vec<String> = self.expose_headers
|
||||||
.into_iter()
|
.iter()
|
||||||
.map(|s| s.deref().to_string())
|
.map(|s| s.deref().to_string())
|
||||||
.collect();
|
.collect();
|
||||||
let headers = headers.join(", ");
|
let headers = headers.join(", ");
|
||||||
|
@ -848,7 +862,7 @@ impl<'r, R: Responder<'r>> Responder<'r> for Response<R> {
|
||||||
|
|
||||||
if !self.allow_headers.is_empty() {
|
if !self.allow_headers.is_empty() {
|
||||||
let headers: Vec<String> = self.allow_headers
|
let headers: Vec<String> = self.allow_headers
|
||||||
.into_iter()
|
.iter()
|
||||||
.map(|s| s.deref().to_string())
|
.map(|s| s.deref().to_string())
|
||||||
.collect();
|
.collect();
|
||||||
let headers = headers.join(", ");
|
let headers = headers.join(", ");
|
||||||
|
@ -856,9 +870,8 @@ impl<'r, R: Responder<'r>> Responder<'r> for Response<R> {
|
||||||
builder.raw_header("Access-Control-Allow-Headers", headers);
|
builder.raw_header("Access-Control-Allow-Headers", headers);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
if !self.allow_methods.is_empty() {
|
if !self.allow_methods.is_empty() {
|
||||||
let methods: Vec<_> = self.allow_methods.into_iter().map(|m| m.as_str()).collect();
|
let methods: Vec<_> = self.allow_methods.iter().map(|m| m.as_str()).collect();
|
||||||
let methods = methods.join(", ");
|
let methods = methods.join(", ");
|
||||||
|
|
||||||
builder.raw_header("Access-Control-Allow-Methods", methods);
|
builder.raw_header("Access-Control-Allow-Methods", methods);
|
||||||
|
@ -873,7 +886,57 @@ impl<'r, R: Responder<'r>> Responder<'r> for Response<R> {
|
||||||
builder.raw_header("Vary", "Origin");
|
builder.raw_header("Vary", "Origin");
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(builder.finalize())
|
builder.finalize()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Merge a `wrapped` Response with a `cors` response
|
||||||
|
///
|
||||||
|
/// If the `wrapped` response has the `Access-Control-Allow-Origin` header already defined,
|
||||||
|
/// it will be left untouched. This allows for chaining of several CORS responders.
|
||||||
|
///
|
||||||
|
/// Otherwise, the merging will be done according to the rules of `rocket::Response::merge`.
|
||||||
|
fn merge(
|
||||||
|
mut wrapped: response::Response<'r>,
|
||||||
|
cors: response::Response<'r>,
|
||||||
|
) -> response::Response<'r> {
|
||||||
|
|
||||||
|
let existing_cors = {
|
||||||
|
wrapped.headers().get("Access-Control-Allow-Origin").next() == None
|
||||||
|
};
|
||||||
|
|
||||||
|
if existing_cors {
|
||||||
|
wrapped.merge(cors);
|
||||||
|
}
|
||||||
|
|
||||||
|
wrapped
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Finalize the Response by merging the CORS header with the wrapped `Responder
|
||||||
|
///
|
||||||
|
/// If the original response has the `Access-Control-Allow-Origin` header already defined,
|
||||||
|
/// it will be left untouched.This allows for chaining of several CORS responders.
|
||||||
|
///
|
||||||
|
/// Otherwise, the following headers may be set for the final Rocket `Response`, overwriting any
|
||||||
|
/// existing headers defined:
|
||||||
|
///
|
||||||
|
/// - `Access-Control-Allow-Origin`
|
||||||
|
/// - `Access-Control-Expose-Headers`
|
||||||
|
/// - `Access-Control-Max-Age`
|
||||||
|
/// - `Access-Control-Allow-Credentials`
|
||||||
|
/// - `Access-Control-Allow-Methods`
|
||||||
|
/// - `Access-Control-Allow-Headers`
|
||||||
|
/// - `Vary`
|
||||||
|
fn finalize(self, request: &Request) -> response::Result<'r> {
|
||||||
|
let cors_response = self.build();
|
||||||
|
let original_response = self.responder.respond_to(request)?;
|
||||||
|
|
||||||
|
Ok(Self::merge(original_response, cors_response))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'r, R: Responder<'r>> Responder<'r> for Response<R> {
|
||||||
|
fn respond_to(self, request: &Request) -> response::Result<'r> {
|
||||||
|
self.finalize(request)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -886,22 +949,46 @@ mod tests {
|
||||||
use rocket;
|
use rocket;
|
||||||
use rocket::local::Client;
|
use rocket::local::Client;
|
||||||
use rocket::http::Method;
|
use rocket::http::Method;
|
||||||
use rocket::http::{Header, Status};
|
|
||||||
|
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
|
/// Make a client with no routes for unit testing
|
||||||
|
fn make_client() -> Client {
|
||||||
|
let rocket = rocket::ignite();
|
||||||
|
Client::new(rocket).expect("valid rocket instance")
|
||||||
|
}
|
||||||
|
|
||||||
|
// The following tests check that CORS Request headers are parsed correctly
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn origin_header_conversion() {
|
fn origin_header_conversion() {
|
||||||
let url = "https://foo.bar.xyz";
|
let url = "https://foo.bar.xyz";
|
||||||
let _ = not_err!(Origin::from_str(url));
|
let parsed = not_err!(Origin::from_str(url));
|
||||||
|
let expected = not_err!(Url::from_str(url));
|
||||||
|
assert_eq!(parsed, expected);
|
||||||
|
|
||||||
let url = "https://foo.bar.xyz/path/somewhere"; // this should never really be used
|
let url = "https://foo.bar.xyz/path/somewhere"; // this should never really be used
|
||||||
let _ = not_err!(Origin::from_str(url));
|
let parsed = not_err!(Origin::from_str(url));
|
||||||
|
let expected = not_err!(Url::from_str(url));
|
||||||
|
assert_eq!(parsed, expected);
|
||||||
|
|
||||||
let url = "invalid_url";
|
let url = "invalid_url";
|
||||||
let _ = is_err!(Origin::from_str(url));
|
let _ = is_err!(Origin::from_str(url));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn origin_header_parsing() {
|
||||||
|
let client = make_client();
|
||||||
|
let mut request = client.get("/");
|
||||||
|
|
||||||
|
let origin = hyper::header::Origin::new("https", "www.example.com", None);
|
||||||
|
request.add_header(origin);
|
||||||
|
|
||||||
|
let outcome: request::Outcome<Origin, Error> = FromRequest::from_request(request.inner());
|
||||||
|
let parsed_header = assert_matches!(outcome, Outcome::Success(s), s);
|
||||||
|
assert_eq!("https://www.example.com/", parsed_header.as_str());
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn request_method_conversion() {
|
fn request_method_conversion() {
|
||||||
let method = "POST";
|
let method = "POST";
|
||||||
|
@ -916,6 +1003,20 @@ mod tests {
|
||||||
let _ = is_err!(AccessControlRequestMethod::from_str(method));
|
let _ = is_err!(AccessControlRequestMethod::from_str(method));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn request_method_parsing() {
|
||||||
|
let client = make_client();
|
||||||
|
let mut request = client.get("/");
|
||||||
|
let method = hyper::header::AccessControlRequestMethod(hyper::method::Method::Get);
|
||||||
|
request.add_header(method);
|
||||||
|
let outcome: request::Outcome<AccessControlRequestMethod, Error> =
|
||||||
|
FromRequest::from_request(request.inner());
|
||||||
|
|
||||||
|
let parsed_header = assert_matches!(outcome, Outcome::Success(s), s);
|
||||||
|
let AccessControlRequestMethod(parsed_method) = parsed_header;
|
||||||
|
assert_eq!("GET", parsed_method.as_str());
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn request_headers_conversion() {
|
fn request_headers_conversion() {
|
||||||
let headers = ["foo", "bar", "baz"];
|
let headers = ["foo", "bar", "baz"];
|
||||||
|
@ -926,75 +1027,466 @@ mod tests {
|
||||||
assert_eq!(actual_headers, expected_headers);
|
assert_eq!(actual_headers, expected_headers);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[get("/request_headers")]
|
|
||||||
#[allow(needless_pass_by_value)]
|
|
||||||
fn request_headers(
|
|
||||||
origin: Origin,
|
|
||||||
method: AccessControlRequestMethod,
|
|
||||||
headers: AccessControlRequestHeaders,
|
|
||||||
) -> String {
|
|
||||||
let AccessControlRequestMethod(method) = method;
|
|
||||||
let AccessControlRequestHeaders(headers) = headers;
|
|
||||||
let mut headers = headers
|
|
||||||
.iter()
|
|
||||||
.map(|s| s.deref().to_string())
|
|
||||||
.collect::<Vec<String>>();
|
|
||||||
headers.sort();
|
|
||||||
format!("{}\n{}\n{}", origin, method, headers.join(", "))
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Tests that all the headers are parsed correcly in a HTTP request
|
|
||||||
#[test]
|
#[test]
|
||||||
fn request_headers_round_trip_smoke_test() {
|
fn request_headers_parsing() {
|
||||||
let rocket = rocket::ignite().mount("/", routes![request_headers]);
|
let client = make_client();
|
||||||
let client = not_err!(Client::new(rocket));
|
let mut request = client.get("/");
|
||||||
|
let headers = hyper::header::AccessControlRequestHeaders(vec![
|
||||||
let origin_header = Header::from(not_err!(
|
|
||||||
hyper::header::Origin::from_str("https://foo.bar.xyz")
|
|
||||||
));
|
|
||||||
let method_header = Header::from(hyper::header::AccessControlRequestMethod(
|
|
||||||
hyper::method::Method::Get,
|
|
||||||
));
|
|
||||||
let request_headers = hyper::header::AccessControlRequestHeaders(vec![
|
|
||||||
FromStr::from_str("accept-language").unwrap(),
|
FromStr::from_str("accept-language").unwrap(),
|
||||||
FromStr::from_str("X-Ping").unwrap(),
|
FromStr::from_str("date").unwrap(),
|
||||||
]);
|
]);
|
||||||
let request_headers = Header::from(request_headers);
|
request.add_header(headers);
|
||||||
let req = client
|
let outcome: request::Outcome<AccessControlRequestHeaders, Error> =
|
||||||
.get("/request_headers")
|
FromRequest::from_request(request.inner());
|
||||||
.header(origin_header)
|
|
||||||
.header(method_header)
|
|
||||||
.header(request_headers);
|
|
||||||
let mut response = req.dispatch();
|
|
||||||
|
|
||||||
assert_eq!(Status::Ok, response.status());
|
let parsed_header = assert_matches!(outcome, Outcome::Success(s), s);
|
||||||
let body_str = not_none!(response.body().and_then(|body| body.into_string()));
|
let AccessControlRequestHeaders(parsed_headers) = parsed_header;
|
||||||
let expected_body = r#"https://foo.bar.xyz/
|
let mut parsed_headers: Vec<String> =
|
||||||
GET
|
parsed_headers.iter().map(|s| s.to_string()).collect();
|
||||||
X-Ping, accept-language"#;
|
parsed_headers.sort();
|
||||||
assert_eq!(expected_body, body_str);
|
assert_eq!(
|
||||||
|
vec!["accept-language".to_string(), "date".to_string()],
|
||||||
|
parsed_headers
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[get("/any")]
|
// The following tests check `Response`'s validation
|
||||||
fn any() -> Response<&'static str> {
|
|
||||||
Response::new("Hello, world!").any()
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn response_any_origin_smoke_test() {
|
fn response_allows_all_origin_with_wildcard() {
|
||||||
let rocket = rocket::ignite().mount("/", routes![any]);
|
let url = "https://www.example.com";
|
||||||
let client = not_err!(Client::new(rocket));
|
let origin = Origin::from_str(url).unwrap();
|
||||||
|
let allowed_origins = AllOrSome::All;
|
||||||
|
let send_wildcard = true;
|
||||||
|
|
||||||
let req = client.get("/any");
|
let response = Response::new(());
|
||||||
let mut response = req.dispatch();
|
let response = not_err!(response.allowed_origin(
|
||||||
|
&origin,
|
||||||
|
&allowed_origins,
|
||||||
|
send_wildcard,
|
||||||
|
));
|
||||||
|
|
||||||
assert_eq!(Status::Ok, response.status());
|
assert_matches!(response.allow_origin, Some(AllOrSome::All));
|
||||||
let body_str = response.body().and_then(|body| body.into_string());
|
assert_eq!(response.vary_origin, false);
|
||||||
let values: Vec<_> = response
|
|
||||||
|
// Build response and check built response header
|
||||||
|
let expected_header = vec!["*"];
|
||||||
|
let response = response.build();
|
||||||
|
let actual_header: Vec<_> = response
|
||||||
.headers()
|
.headers()
|
||||||
.get("Access-Control-Allow-Origin")
|
.get("Access-Control-Allow-Origin")
|
||||||
.collect();
|
.collect();
|
||||||
assert_eq!(values, vec!["*"]);
|
assert_eq!(expected_header, actual_header);
|
||||||
assert_eq!(body_str, Some("Hello, world!".to_string()));
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn response_allows_all_origin_with_echoed_domain() {
|
||||||
|
let url = "https://www.example.com";
|
||||||
|
let origin = Origin::from_str(url).unwrap();
|
||||||
|
let allowed_origins = AllOrSome::All;
|
||||||
|
let send_wildcard = false;
|
||||||
|
|
||||||
|
let response = Response::new(());
|
||||||
|
let response = not_err!(response.allowed_origin(
|
||||||
|
&origin,
|
||||||
|
&allowed_origins,
|
||||||
|
send_wildcard,
|
||||||
|
));
|
||||||
|
|
||||||
|
let actual_origin = assert_matches!(
|
||||||
|
response.allow_origin,
|
||||||
|
Some(AllOrSome::Some(ref origin)),
|
||||||
|
origin
|
||||||
|
);
|
||||||
|
assert_eq!(url, actual_origin);
|
||||||
|
assert_eq!(response.vary_origin, true);
|
||||||
|
|
||||||
|
// Build response and check built response header
|
||||||
|
let expected_header = vec![url];
|
||||||
|
let response = response.build();
|
||||||
|
let actual_header: Vec<_> = response
|
||||||
|
.headers()
|
||||||
|
.get("Access-Control-Allow-Origin")
|
||||||
|
.collect();
|
||||||
|
assert_eq!(expected_header, actual_header);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn response_allows_origin() {
|
||||||
|
let url = "https://www.example.com";
|
||||||
|
let origin = Origin::from_str(url).unwrap();
|
||||||
|
let (allowed_origins, failed_origins) =
|
||||||
|
AllOrSome::new_from_str_list(&["https://www.example.com"]);
|
||||||
|
assert!(failed_origins.is_empty());
|
||||||
|
let send_wildcard = false;
|
||||||
|
|
||||||
|
let response = Response::new(());
|
||||||
|
let response = not_err!(response.allowed_origin(
|
||||||
|
&origin,
|
||||||
|
&allowed_origins,
|
||||||
|
send_wildcard,
|
||||||
|
));
|
||||||
|
|
||||||
|
let actual_origin = assert_matches!(
|
||||||
|
response.allow_origin,
|
||||||
|
Some(AllOrSome::Some(ref origin)),
|
||||||
|
origin
|
||||||
|
);
|
||||||
|
|
||||||
|
assert_eq!(url, actual_origin);
|
||||||
|
assert_eq!(response.vary_origin, false);
|
||||||
|
|
||||||
|
// Build response and check built response header
|
||||||
|
let expected_header = vec![url];
|
||||||
|
let response = response.build();
|
||||||
|
let actual_header: Vec<_> = response
|
||||||
|
.headers()
|
||||||
|
.get("Access-Control-Allow-Origin")
|
||||||
|
.collect();
|
||||||
|
assert_eq!(expected_header, actual_header);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
#[should_panic(expected = "OriginNotAllowed")]
|
||||||
|
fn response_rejects_invalid_origin() {
|
||||||
|
let url = "https://www.acme.com";
|
||||||
|
let origin = Origin::from_str(url).unwrap();
|
||||||
|
let (allowed_origins, failed_origins) =
|
||||||
|
AllOrSome::new_from_str_list(&["https://www.example.com"]);
|
||||||
|
assert!(failed_origins.is_empty());
|
||||||
|
let send_wildcard = false;
|
||||||
|
|
||||||
|
let response = Response::new(());
|
||||||
|
let _ = response
|
||||||
|
.allowed_origin(&origin, &allowed_origins, send_wildcard)
|
||||||
|
.unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
#[should_panic(expected = "CredentialsWithWildcardOrigin")]
|
||||||
|
fn response_credentials_does_not_allow_wildcard_with_all_origins() {
|
||||||
|
let response = Response::new(());
|
||||||
|
let response = response.any();
|
||||||
|
|
||||||
|
let _ = response.credentials(true).unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn response_credentials_allows_specific_origins() {
|
||||||
|
let response = Response::new(());
|
||||||
|
let response = response.origin("https://www.example.com", false);
|
||||||
|
|
||||||
|
let response = response.credentials(true).expect(
|
||||||
|
"to allow specific origins",
|
||||||
|
);
|
||||||
|
assert_eq!(response.allow_credentials, true);
|
||||||
|
|
||||||
|
// Build response and check built response header
|
||||||
|
let expected_header = vec!["true"];
|
||||||
|
let response = response.build();
|
||||||
|
let actual_header: Vec<_> = response
|
||||||
|
.headers()
|
||||||
|
.get("Access-Control-Allow-Credentials")
|
||||||
|
.collect();
|
||||||
|
assert_eq!(expected_header, actual_header);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn response_sets_exposed_headers_correctly() {
|
||||||
|
let headers = vec!["Bar", "Baz", "Foo"];
|
||||||
|
let response = Response::new(());
|
||||||
|
let response = response.origin("https://www.example.com", false);
|
||||||
|
let response = response.exposed_headers(&headers);
|
||||||
|
|
||||||
|
// Build response and check built response header
|
||||||
|
let response = response.build();
|
||||||
|
let actual_header: Vec<_> = response
|
||||||
|
.headers()
|
||||||
|
.get("Access-Control-Expose-Headers")
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
assert_eq!(1, actual_header.len());
|
||||||
|
let mut actual_headers: Vec<String> = actual_header[0]
|
||||||
|
.split(',')
|
||||||
|
.map(|header| header.trim().to_string())
|
||||||
|
.collect();
|
||||||
|
actual_headers.sort();
|
||||||
|
assert_eq!(headers, actual_headers);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn response_sets_max_age_correctly() {
|
||||||
|
let response = Response::new(());
|
||||||
|
let response = response.origin("https://www.example.com", false);
|
||||||
|
|
||||||
|
let response = response.max_age(Some(42));
|
||||||
|
|
||||||
|
// Build response and check built response header
|
||||||
|
let expected_header = vec!["42"];
|
||||||
|
let response = response.build();
|
||||||
|
let actual_header: Vec<_> = response.headers().get("Access-Control-Max-Age").collect();
|
||||||
|
assert_eq!(expected_header, actual_header);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn response_does_not_set_max_age_when_none() {
|
||||||
|
let response = Response::new(());
|
||||||
|
let response = response.origin("https://www.example.com", false);
|
||||||
|
|
||||||
|
let response = response.max_age(None);
|
||||||
|
|
||||||
|
// Build response and check built response header
|
||||||
|
let response = response.build();
|
||||||
|
assert!(response
|
||||||
|
.headers()
|
||||||
|
.get("Access-Control-Max-Age")
|
||||||
|
.next().is_none())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// When all headers are allowed, tests that the requested headers are echoed back
|
||||||
|
#[test]
|
||||||
|
fn response_allowed_headers_echoes_back_requested_headers() {
|
||||||
|
let allowed_headers = AllOrSome::All;
|
||||||
|
let requested_headers = vec!["Bar", "Foo"];
|
||||||
|
|
||||||
|
let response = Response::new(());
|
||||||
|
let response = response.origin("https://www.example.com", false);
|
||||||
|
let response = response
|
||||||
|
.allowed_headers(
|
||||||
|
&FromStr::from_str(&requested_headers.join(",")).unwrap(),
|
||||||
|
&allowed_headers,
|
||||||
|
)
|
||||||
|
.expect("to not fail");
|
||||||
|
|
||||||
|
// Build response and check built response header
|
||||||
|
let response = response.build();
|
||||||
|
let actual_header: Vec<_> = response
|
||||||
|
.headers()
|
||||||
|
.get("Access-Control-Allow-Headers")
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
assert_eq!(1, actual_header.len());
|
||||||
|
let mut actual_headers: Vec<String> = actual_header[0]
|
||||||
|
.split(',')
|
||||||
|
.map(|header| header.trim().to_string())
|
||||||
|
.collect();
|
||||||
|
actual_headers.sort();
|
||||||
|
assert_eq!(requested_headers, actual_headers);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn response_allowed_methods_sets_headers_properly() {
|
||||||
|
let allowed_methods = vec![
|
||||||
|
Method::Get,
|
||||||
|
Method::Head,
|
||||||
|
Method::Post,
|
||||||
|
].into_iter()
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let method = "GET";
|
||||||
|
|
||||||
|
let response = Response::new(());
|
||||||
|
let response = response.origin("https://www.example.com", false);
|
||||||
|
let response = response
|
||||||
|
.allowed_methods(
|
||||||
|
&FromStr::from_str(method).expect("not to fail"),
|
||||||
|
&allowed_methods,
|
||||||
|
)
|
||||||
|
.expect("not to fail");
|
||||||
|
|
||||||
|
// Build response and check built response header
|
||||||
|
let response = response.build();
|
||||||
|
let actual_header: Vec<_> = response
|
||||||
|
.headers()
|
||||||
|
.get("Access-Control-Allow-Methods")
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
assert_eq!(1, actual_header.len());
|
||||||
|
let mut actual_headers: Vec<String> = actual_header[0]
|
||||||
|
.split(',')
|
||||||
|
.map(|header| header.trim().to_string())
|
||||||
|
.collect();
|
||||||
|
actual_headers.sort();
|
||||||
|
let mut expected_headers: Vec<_> = allowed_methods.iter().map(|m| m.as_str()).collect();
|
||||||
|
expected_headers.sort();
|
||||||
|
assert_eq!(expected_headers, actual_headers);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
#[should_panic(expected = "MethodNotAllowed")]
|
||||||
|
fn response_allowed_method_errors_on_disallowed_method() {
|
||||||
|
let allowed_methods = vec![
|
||||||
|
Method::Get,
|
||||||
|
Method::Head,
|
||||||
|
Method::Post,
|
||||||
|
].into_iter()
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let method = "DELETE";
|
||||||
|
|
||||||
|
let response = Response::new(());
|
||||||
|
let response = response.origin("https://www.example.com", false);
|
||||||
|
let _ = response
|
||||||
|
.allowed_methods(
|
||||||
|
&FromStr::from_str(method).expect("not to fail"),
|
||||||
|
&allowed_methods,
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
/// `Response::allowed_headers` should check that headers are allowed, and only
|
||||||
|
/// echoes back the list that is actually requested for and not the whole list
|
||||||
|
#[test]
|
||||||
|
fn response_allowed_headers_validates_and_echoes_requested_headers() {
|
||||||
|
let allowed_headers = vec!["Bar", "Baz", "Foo"];
|
||||||
|
let requested_headers = vec!["Bar", "Foo"];
|
||||||
|
|
||||||
|
let response = Response::new(());
|
||||||
|
let response = response.origin("https://www.example.com", false);
|
||||||
|
let response = response
|
||||||
|
.allowed_headers(
|
||||||
|
&FromStr::from_str(&requested_headers.join(",")).unwrap(),
|
||||||
|
&AllOrSome::Some(
|
||||||
|
allowed_headers
|
||||||
|
.iter()
|
||||||
|
.map(|s| FromStr::from_str(*s).unwrap())
|
||||||
|
.collect(),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
.expect("to not fail");
|
||||||
|
|
||||||
|
// Build response and check built response header
|
||||||
|
let response = response.build();
|
||||||
|
let actual_header: Vec<_> = response
|
||||||
|
.headers()
|
||||||
|
.get("Access-Control-Allow-Headers")
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
assert_eq!(1, actual_header.len());
|
||||||
|
let mut actual_headers: Vec<String> = actual_header[0]
|
||||||
|
.split(',')
|
||||||
|
.map(|header| header.trim().to_string())
|
||||||
|
.collect();
|
||||||
|
actual_headers.sort();
|
||||||
|
assert_eq!(requested_headers, actual_headers);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
#[should_panic(expected = "HeadersNotAllowed")]
|
||||||
|
fn response_allowed_headers_errors_on_non_subset() {
|
||||||
|
let allowed_headers = vec!["Bar", "Baz", "Foo"];
|
||||||
|
let requested_headers = vec!["Bar", "Foo", "Unknown"];
|
||||||
|
|
||||||
|
let response = Response::new(());
|
||||||
|
let response = response.origin("https://www.example.com", false);
|
||||||
|
let _ = response
|
||||||
|
.allowed_headers(
|
||||||
|
&FromStr::from_str(&requested_headers.join(",")).unwrap(),
|
||||||
|
&AllOrSome::Some(
|
||||||
|
allowed_headers
|
||||||
|
.iter()
|
||||||
|
.map(|s| FromStr::from_str(*s).unwrap())
|
||||||
|
.collect(),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn response_does_not_build_if_origin_is_not_set() {
|
||||||
|
let response = Response::new(());
|
||||||
|
let response = response.build();
|
||||||
|
|
||||||
|
let headers: Vec<_> = response.headers().iter().collect();
|
||||||
|
assert_eq!(headers.len(), 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Note: Correct operation of Response::build is tested in the tests above for each of the
|
||||||
|
// individual headers
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn response_merges_correctly() {
|
||||||
|
use std::io::Cursor;
|
||||||
|
use rocket::http::Status;
|
||||||
|
|
||||||
|
let wrapped = response::Response::build()
|
||||||
|
.status(Status::ImATeapot)
|
||||||
|
.raw_header("X-Teapot-Make", "Rocket")
|
||||||
|
.sized_body(Cursor::new("Brewing the best coffee!"))
|
||||||
|
.finalize();
|
||||||
|
|
||||||
|
let response = Response::new(());
|
||||||
|
let response = response.origin("https://www.acme.com", false);
|
||||||
|
|
||||||
|
let mut response = Response::<String>::merge(wrapped, response.build());
|
||||||
|
assert_eq!(response.status(), Status::ImATeapot);
|
||||||
|
assert_eq!(response.body_string(), Some("Brewing the best coffee!".to_string()));
|
||||||
|
|
||||||
|
// Check CORS header
|
||||||
|
let expected_header = vec!["https://www.acme.com"];
|
||||||
|
let actual_header: Vec<_> = response
|
||||||
|
.headers()
|
||||||
|
.get("Access-Control-Allow-Origin")
|
||||||
|
.collect();
|
||||||
|
assert_eq!(expected_header, actual_header);
|
||||||
|
|
||||||
|
// Check other header
|
||||||
|
let expected_header = vec!["Rocket"];
|
||||||
|
let actual_header: Vec<_> = response.headers().get("X-Teapot-Make").collect();
|
||||||
|
assert_eq!(expected_header, actual_header);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn response_does_not_merge_existing_cors() {
|
||||||
|
let wrapped = response::Response::build()
|
||||||
|
.raw_header("Access-Control-Allow-Origin", "https://www.example.com")
|
||||||
|
.finalize();
|
||||||
|
|
||||||
|
let response = Response::new(());
|
||||||
|
let response = response.origin("https://www.acme.com", false);
|
||||||
|
|
||||||
|
let response = Response::<()>::merge(wrapped, response.build());
|
||||||
|
let expected_header = vec!["https://www.example.com"];
|
||||||
|
let actual_header: Vec<_> = response
|
||||||
|
.headers()
|
||||||
|
.get("Access-Control-Allow-Origin")
|
||||||
|
.collect();
|
||||||
|
assert_eq!(expected_header, actual_header);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn response_finalize_smoke_test() {
|
||||||
|
use std::io::Cursor;
|
||||||
|
use rocket::http::Status;
|
||||||
|
|
||||||
|
let wrapped = response::Response::build()
|
||||||
|
.status(Status::ImATeapot)
|
||||||
|
.raw_header("X-Teapot-Make", "Rocket")
|
||||||
|
.sized_body(Cursor::new("Brewing the best coffee!"))
|
||||||
|
.finalize();
|
||||||
|
|
||||||
|
let response = Response::new(wrapped);
|
||||||
|
let response = response.origin("https://www.acme.com", false);
|
||||||
|
|
||||||
|
let client = make_client();
|
||||||
|
let request = client.get("/");
|
||||||
|
let mut response = response.finalize(request.inner()).expect("not to fail");
|
||||||
|
|
||||||
|
assert_eq!(response.status(), Status::ImATeapot);
|
||||||
|
assert_eq!(response.body_string(), Some("Brewing the best coffee!".to_string()));
|
||||||
|
|
||||||
|
// Check CORS header
|
||||||
|
let expected_header = vec!["https://www.acme.com"];
|
||||||
|
let actual_header: Vec<_> = response
|
||||||
|
.headers()
|
||||||
|
.get("Access-Control-Allow-Origin")
|
||||||
|
.collect();
|
||||||
|
assert_eq!(expected_header, actual_header);
|
||||||
|
|
||||||
|
// Check other header
|
||||||
|
let expected_header = vec!["Rocket"];
|
||||||
|
let actual_header: Vec<_> = response.headers().get("X-Teapot-Make").collect();
|
||||||
|
assert_eq!(expected_header, actual_header);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -12,13 +12,6 @@ macro_rules! is_err {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
macro_rules! not_none {
|
|
||||||
($e:expr) => (match $e {
|
|
||||||
Some(e) => e,
|
|
||||||
None => panic!("{} failed with None", stringify!($e)),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
macro_rules! assert_matches {
|
macro_rules! assert_matches {
|
||||||
($e: expr, $p: pat) => (assert_matches!($e, $p, ()));
|
($e: expr, $p: pat) => (assert_matches!($e, $p, ()));
|
||||||
($e: expr, $p: pat, $f: expr) => (match $e {
|
($e: expr, $p: pat, $f: expr) => (match $e {
|
||||||
|
|
|
@ -0,0 +1,63 @@
|
||||||
|
//! This crate tests that all the request headers are parsed correctly in the round trip
|
||||||
|
#![feature(plugin, custom_derive)]
|
||||||
|
#![plugin(rocket_codegen)]
|
||||||
|
extern crate hyper;
|
||||||
|
extern crate rocket;
|
||||||
|
extern crate rocket_cors;
|
||||||
|
|
||||||
|
use std::ops::Deref;
|
||||||
|
use std::str::FromStr;
|
||||||
|
|
||||||
|
use rocket::local::Client;
|
||||||
|
use rocket::http::{Header, Status};
|
||||||
|
use rocket_cors::*;
|
||||||
|
|
||||||
|
#[get("/request_headers")]
|
||||||
|
fn request_headers(
|
||||||
|
origin: Origin,
|
||||||
|
method: AccessControlRequestMethod,
|
||||||
|
headers: AccessControlRequestHeaders,
|
||||||
|
) -> String {
|
||||||
|
let AccessControlRequestMethod(method) = method;
|
||||||
|
let AccessControlRequestHeaders(headers) = headers;
|
||||||
|
let mut headers = headers
|
||||||
|
.iter()
|
||||||
|
.map(|s| s.deref().to_string())
|
||||||
|
.collect::<Vec<String>>();
|
||||||
|
headers.sort();
|
||||||
|
format!("{}\n{}\n{}", origin, method, headers.join(", "))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Tests that all the request headers are parsed correcly in a HTTP request
|
||||||
|
#[test]
|
||||||
|
fn request_headers_round_trip_smoke_test() {
|
||||||
|
let rocket = rocket::ignite().mount("/", routes![request_headers]);
|
||||||
|
let client = Client::new(rocket).expect("A valid Rocket client");
|
||||||
|
|
||||||
|
let origin_header = Header::from(
|
||||||
|
hyper::header::Origin::from_str("https://foo.bar.xyz").unwrap(),
|
||||||
|
);
|
||||||
|
let method_header = Header::from(hyper::header::AccessControlRequestMethod(
|
||||||
|
hyper::method::Method::Get,
|
||||||
|
));
|
||||||
|
let request_headers = hyper::header::AccessControlRequestHeaders(vec![
|
||||||
|
FromStr::from_str("accept-language").unwrap(),
|
||||||
|
FromStr::from_str("X-Ping").unwrap(),
|
||||||
|
]);
|
||||||
|
let request_headers = Header::from(request_headers);
|
||||||
|
let req = client
|
||||||
|
.get("/request_headers")
|
||||||
|
.header(origin_header)
|
||||||
|
.header(method_header)
|
||||||
|
.header(request_headers);
|
||||||
|
let mut response = req.dispatch();
|
||||||
|
|
||||||
|
assert_eq!(Status::Ok, response.status());
|
||||||
|
let body_str = response.body().and_then(|body| body.into_string()).expect(
|
||||||
|
"Non-empty body",
|
||||||
|
);
|
||||||
|
let expected_body = r#"https://foo.bar.xyz/
|
||||||
|
GET
|
||||||
|
X-Ping, accept-language"#;
|
||||||
|
assert_eq!(expected_body, body_str);
|
||||||
|
}
|
|
@ -1,4 +1,4 @@
|
||||||
//! This crate tests using rocket_cors using the "classic"" per-route handling
|
//! This crate tests using rocket_cors using the "classic" per-route handling
|
||||||
|
|
||||||
#![feature(plugin, custom_derive)]
|
#![feature(plugin, custom_derive)]
|
||||||
#![plugin(rocket_codegen)]
|
#![plugin(rocket_codegen)]
|
||||||
|
|
Loading…
Reference in New Issue