Delay CORS checks and response until `Responder::respond_to` is invoked (#6)

* Delay checking of CORS to just before responding

* Lifetime issues

* Use State::inner()

* Fix lifetime issues

* Bump Rocket

* Document 'static limitation

And link to https://github.com/SergioBenitez/Rocket/pull/345

* Remove extraneous comments
This commit is contained in:
Yong Wen Chua 2017-07-15 01:38:13 +08:00 committed by GitHub
parent 16b89ab31c
commit 7dbc22b523
4 changed files with 286 additions and 235 deletions

View File

@ -16,7 +16,7 @@ travis-ci = { repository = "lawliet89/rocket_cors" }
[dependencies] [dependencies]
log = "0.3" log = "0.3"
rocket = { git = "https://github.com/SergioBenitez/Rocket", rev = "aa51fe0" } rocket = { git = "https://github.com/SergioBenitez/Rocket", rev = "51a465f2cc88d537079133bcdfec37d029070dcd" }
serde = "1.0" serde = "1.0"
serde_derive = "1.0" serde_derive = "1.0"
unicase="1.4" unicase="1.4"
@ -29,5 +29,5 @@ version_check = "0.1"
[dev-dependencies] [dev-dependencies]
hyper = "0.10" hyper = "0.10"
rocket_codegen = { git = "https://github.com/SergioBenitez/Rocket", rev = "aa51fe0" } rocket_codegen = { git = "https://github.com/SergioBenitez/Rocket", rev = "51a465f2cc88d537079133bcdfec37d029070dcd" }
serde_json = "1.0" serde_json = "1.0"

View File

@ -27,7 +27,7 @@ In particular, `rocket_cors` is currently targetted for `nightly-2017-07-13`.
Rocket > 0.3 is needed. At this moment, `0.3` is not released, and this crate will not be published Rocket > 0.3 is needed. At this moment, `0.3` is not released, and this crate will not be published
to Crates.io until Rocket 0.3 is released to Crates.io. to Crates.io until Rocket 0.3 is released to Crates.io.
We currently tie this crate to revision [aa51fe0](https://github.com/SergioBenitez/Rocket/tree/aa51fe0) of Rocket. We currently tie this crate to revision [51a465f2cc88d537079133bcdfec37d029070dcd](https://github.com/SergioBenitez/Rocket/tree/51a465f2cc88d537079133bcdfec37d029070dcd) of Rocket.
## Installation ## Installation

View File

@ -29,7 +29,7 @@
//! to Crates.io until Rocket 0.3 is released to Crates.io. //! to Crates.io until Rocket 0.3 is released to Crates.io.
//! //!
//! We currently tie this crate to revision //! We currently tie this crate to revision
//! [aa51fe0](https://github.com/SergioBenitez/Rocket/tree/aa51fe0) of Rocket. //! [51a465f2cc88d537079133bcdfec37d029070dcd](https://github.com/SergioBenitez/Rocket/tree/51a465f2cc88d537079133bcdfec37d029070dcd) of Rocket.
//! //!
//! ## Installation //! ## Installation
//! //!
@ -118,20 +118,27 @@ extern crate hyper;
use std::collections::{HashSet, HashMap}; use std::collections::{HashSet, HashMap};
use std::error; use std::error;
use std::fmt; use std::fmt;
use std::marker::PhantomData;
use std::ops::Deref; use std::ops::Deref;
use std::str::FromStr; use std::str::FromStr;
use rocket::request::{self, Request, FromRequest}; use rocket::{Outcome, State};
use rocket::response::{self, Responder};
use rocket::http::{Method, Status}; use rocket::http::{Method, Status};
use rocket::Outcome; use rocket::request::{self, Request, FromRequest};
use rocket::response;
use unicase::UniCase; use unicase::UniCase;
#[cfg(test)] #[cfg(test)]
#[macro_use] #[macro_use]
mod test_macros; mod test_macros;
/// CORS related error /// Errors during operations
///
/// This enum implements `rocket::response::Responder` which will return an appropriate status code
/// while printing out the error in the console.
/// Because these errors are usually the result of an error while trying to respond to a CORS
/// request, CORS headers cannot be added to the response and your applications requesting CORS
/// will not be able to see the status code.
#[derive(Debug)] #[derive(Debug)]
pub enum Error { pub enum Error {
/// The HTTP request header `Origin` is required but was not provided /// The HTTP request header `Origin` is required but was not provided
@ -201,7 +208,7 @@ impl fmt::Display for Error {
} }
} }
impl<'r> Responder<'r> for Error { impl<'r> response::Responder<'r> for Error {
fn respond_to(self, _: &Request) -> Result<response::Response<'r>, Status> { fn respond_to(self, _: &Request) -> Result<response::Response<'r>, Status> {
error_!("CORS Error: {:?}", self); error_!("CORS Error: {:?}", self);
Err(match self { Err(match self {
@ -259,7 +266,6 @@ impl<'a, 'r> FromRequest<'a, 'r> for Url {
} }
} }
/// The `Origin` request header used in CORS /// The `Origin` request header used in CORS
/// ///
/// You can use this as a rocket [Request Guard](https://rocket.rs/guide/requests/#request-guards) /// You can use this as a rocket [Request Guard](https://rocket.rs/guide/requests/#request-guards)
@ -374,7 +380,11 @@ impl AllOrSome<HashSet<Url>> {
} }
} }
/// Configuration options to for building CORS preflight or actual responses. /// Responder and Fairing for CORS
///
/// This struct can be used as Fairing for Rocket, or as an ad-hoc responder for any CORS requests.
/// You create a new copy of this struct by defining the configurations in the fields below.
/// This struct can also be deserialized by serde.
/// ///
/// [`Default`](https://doc.rust-lang.org/std/default/trait.Default.html) is implemented for this /// [`Default`](https://doc.rust-lang.org/std/default/trait.Default.html) is implemented for this
/// struct. The default for each field is described in the docuementation for the field. /// struct. The default for each field is described in the docuementation for the field.
@ -473,7 +483,30 @@ impl Default for Options {
} }
} }
/// Ad-hoc per route CORS response to requests
///
/// Note: If you use this, the lifetime parameter `'r` of your `rocket:::response::Responder<'r>`
/// CANNOT be `'static`. This is because the code generated by Rocket will implicitly try to
/// to restrain the `Request` object passed to the route to `&'static Request`, and it is not
/// possible to have such a reference.
/// See [this PR on Rocket](https://github.com/SergioBenitez/Rocket/pull/345).
pub fn respond<'a, 'r: 'a, R: response::Responder<'r>>(
options: State<'a, Options>,
responder: R,
) -> Responder<'a, 'r, R> {
options.inner().respond(responder)
}
impl Options { impl Options {
/// Wrap any `Rocket::Response` and respond with CORS headers.
/// This is only used for ad-hoc route CORS response
fn respond<'a, 'r: 'a, R: response::Responder<'r>>(
&'a self,
responder: R,
) -> Responder<'a, 'r, R> {
Responder::new(responder, self)
}
fn default_allowed_methods() -> HashSet<Method> { fn default_allowed_methods() -> HashSet<Method> {
vec![ vec![
Method::Get, Method::Get,
@ -486,40 +519,138 @@ impl Options {
].into_iter() ].into_iter()
.collect() .collect()
} }
}
/// A CORS Responder which will inspect the incoming requests and respond accoridingly.
///
/// If the wrapped `Responder` already has the `Access-Control-Allow-Origin` header set,
/// this responder will leave the response untouched.
/// This allows for chaining of several CORS responders.
///
/// Otherwise, the following headers may be set for the final Rocket `Response`, overwriting any
/// existing headers defined:
///
/// - `Access-Control-Allow-Origin`
/// - `Access-Control-Expose-Headers`
/// - `Access-Control-Max-Age`
/// - `Access-Control-Allow-Credentials`
/// - `Access-Control-Allow-Methods`
/// - `Access-Control-Allow-Headers`
/// - `Vary`
#[derive(Debug)]
pub struct Responder<'a, 'r: 'a, R> {
responder: R,
options: &'a Options,
marker: PhantomData<response::Responder<'r>>,
}
impl<'a, 'r: 'a, R: response::Responder<'r>> Responder<'a, 'r, R> {
fn new(responder: R, options: &'a Options) -> Self {
Self {
responder,
options,
marker: PhantomData,
}
}
/// Respond to a request
fn respond(self, request: &Request) -> response::Result<'r> {
match self.build_cors_response(request) {
Ok(response) => response,
Err(e) => response::Responder::respond_to(e, request),
}
}
/// Build a CORS response and merge with an existing `rocket::Response` for the request
fn build_cors_response(self, request: &Request) -> Result<response::Result<'r>, Error> {
let original_response = match self.responder.respond_to(request) {
Ok(response) => response,
Err(status) => return Ok(Err(status)), // TODO: Handle this?
};
// Existing CORS response?
if Self::has_allow_origin(&original_response) {
return Ok(Ok(original_response));
}
// 1. If the Origin header is not present terminate this set of steps.
// The request is outside the scope of this specification.
let origin = Self::origin(request)?;
let origin = match origin {
None => {
// Not a CORS request
return Ok(Ok(original_response));
}
Some(origin) => origin,
};
// Check if the request verb is an OPTION or something else
let cors_response = match request.method() {
Method::Options => {
let method = Self::request_method(request)?;
let headers = Self::request_headers(request)?;
Self::preflight(&self.options, origin, method, headers)
}
_ => Self::actual_request(&self.options, origin),
}?;
Ok(Ok(cors_response.build(original_response)))
}
/// Gets the `Origin` request header from the request
fn origin(request: &Request) -> Result<Option<Origin>, Error> {
match Origin::from_request(request) {
Outcome::Forward(()) => Ok(None),
Outcome::Success(origin) => Ok(Some(origin)),
Outcome::Failure((_, err)) => Err(err),
}
}
/// Gets the `Access-Control-Request-Method` request header from the request
fn request_method(request: &Request) -> Result<Option<AccessControlRequestMethod>, Error> {
match AccessControlRequestMethod::from_request(request) {
Outcome::Forward(()) => Ok(None),
Outcome::Success(method) => Ok(Some(method)),
Outcome::Failure((_, err)) => Err(err),
}
}
/// Gets the `Access-Control-Request-Headers` request header from the request
fn request_headers(request: &Request) -> Result<Option<AccessControlRequestHeaders>, Error> {
match AccessControlRequestHeaders::from_request(request) {
Outcome::Forward(()) => Ok(None),
Outcome::Success(geaders) => Ok(Some(geaders)),
Outcome::Failure((_, err)) => Err(err),
}
}
/// Checks if an existing Response already has the header `Access-Control-Allow-Origin`
fn has_allow_origin(response: &response::Response<'r>) -> bool {
response.headers().get("Access-Control-Allow-Origin").next() != None
}
/// Construct a preflight response based on the options. Will return an `Err` /// Construct a preflight response based on the options. Will return an `Err`
/// if any of the preflight checks fail. /// if any of the preflight checks fail.
/// ///
/// This implementation references the /// This implementation references the
/// [W3C recommendation](https://www.w3.org/TR/cors/#resource-preflight-requests). /// [W3C recommendation](https://www.w3.org/TR/cors/#resource-preflight-requests).
pub fn preflight<'r, R: Responder<'r>>( fn preflight(
&self, options: &Options,
responder: R, origin: Origin,
origin: Option<Origin>,
method: Option<AccessControlRequestMethod>, method: Option<AccessControlRequestMethod>,
headers: Option<AccessControlRequestHeaders>, headers: Option<AccessControlRequestHeaders>,
) -> Result<Response<R>, Error> { ) -> Result<Response, Error> {
let response = Response::new(responder); let response = Response::new();
// Note: All header parse failures are dealt with in the `FromRequest` trait implementation // Note: All header parse failures are dealt with in the `FromRequest` trait implementation
// 1. If the Origin header is not present terminate this set of steps.
// The request is outside the scope of this specification.
let origin = match origin {
None => {
// Not a CORS request
return Ok(response);
}
Some(origin) => origin,
};
// 2. If the value of the Origin header is not a case-sensitive match for any of the values // 2. If the value of the Origin header is not a case-sensitive match for any of the values
// in list of origins do not set any additional headers and terminate this set of steps. // in list of origins do not set any additional headers and terminate this set of steps.
let response = response.allowed_origin( let response = response.allowed_origin(
&origin, &origin,
&self.allowed_origins, &options.allowed_origins,
self.send_wildcard, options.send_wildcard,
)?; )?;
// 3. Let `method` be the value as result of parsing the Access-Control-Request-Method // 3. Let `method` be the value as result of parsing the Access-Control-Request-Method
@ -540,13 +671,13 @@ impl Options {
// 5. If method is not a case-sensitive match for any of the values in list of methods // 5. If method is not a case-sensitive match for any of the values in list of methods
// do not set any additional headers and terminate this set of steps. // do not set any additional headers and terminate this set of steps.
let response = response.allowed_methods(&method, &self.allowed_methods)?; let response = response.allowed_methods(&method, &options.allowed_methods)?;
// 6. If any of the header field-names is not a ASCII case-insensitive match for any of the // 6. If any of the header field-names is not a ASCII case-insensitive match for any of the
// values in list of headers do not set any additional headers and terminate this set of // values in list of headers do not set any additional headers and terminate this set of
// steps. // steps.
let response = if let Some(headers) = headers { let response = if let Some(headers) = headers {
response.allowed_headers(&headers, &self.allowed_headers)? response.allowed_headers(&headers, &options.allowed_headers)?
} else { } else {
response response
}; };
@ -559,12 +690,12 @@ impl Options {
// with either the value of the Origin header or the string "*" as value. // with either the value of the Origin header or the string "*" as value.
// Note: The string "*" cannot be used for a resource that supports credentials. // Note: The string "*" cannot be used for a resource that supports credentials.
let response = response.credentials(self.allow_credentials)?; let response = response.credentials(options.allow_credentials)?;
// 8. Optionally add a single Access-Control-Max-Age header // 8. Optionally add a single Access-Control-Max-Age header
// with as value the amount of seconds the user agent is allowed to cache the result of the // with as value the amount of seconds the user agent is allowed to cache the result of the
// request. // request.
let response = response.max_age(self.max_age); let response = response.max_age(options.max_age);
// 9. If method is a simple method this step may be skipped. // 9. If method is a simple method this step may be skipped.
// Add one or more Access-Control-Allow-Methods headers consisting of // Add one or more Access-Control-Allow-Methods headers consisting of
@ -591,36 +722,22 @@ impl Options {
Ok(response) Ok(response)
} }
/// Respond to a request based on the settings. /// Respond to an actual request based on the settings.
/// If the `Origin` is not provided, then this request was not made by a browser and there is no /// If the `Origin` is not provided, then this request was not made by a browser and there is no
/// CORS enforcement. /// CORS enforcement.
pub fn respond<'r, R: Responder<'r>>( fn actual_request(options: &Options, origin: Origin) -> Result<Response, Error> {
&self, let response = Response::new();
responder: R,
origin: Option<Origin>,
) -> Result<Response<R>, Error> {
let response = Response::new(responder);
// Note: All header parse failures are dealt with in the `FromRequest` trait implementation // Note: All header parse failures are dealt with in the `FromRequest` trait implementation
// 1. If the Origin header is not present terminate this set of steps.
// The request is outside the scope of this specification.
let origin = match origin {
None => {
// Not a CORS request
return Ok(response);
}
Some(origin) => origin,
};
// 2. If the value of the Origin header is not a case-sensitive match for any of the values // 2. If the value of the Origin header is not a case-sensitive match for any of the values
// in list of origins, do not set any additional headers and terminate this set of steps. // in list of origins, do not set any additional headers and terminate this set of steps.
// Always matching is acceptable since the list of origins can be unbounded. // Always matching is acceptable since the list of origins can be unbounded.
let response = response.allowed_origin( let response = response.allowed_origin(
&origin, &origin,
&self.allowed_origins, &options.allowed_origins,
self.send_wildcard, options.send_wildcard,
)?; )?;
// 3. If the resource supports credentials add a single Access-Control-Allow-Origin header, // 3. If the resource supports credentials add a single Access-Control-Allow-Origin header,
@ -631,7 +748,7 @@ impl Options {
// with either the value of the Origin header or the string "*" as value. // with either the value of the Origin header or the string "*" as value.
// Note: The string "*" cannot be used for a resource that supports credentials. // Note: The string "*" cannot be used for a resource that supports credentials.
let response = response.credentials(self.allow_credentials)?; let response = response.credentials(options.allow_credentials)?;
// 4. If the list of exposed headers is not empty add one or more // 4. If the list of exposed headers is not empty add one or more
// Access-Control-Expose-Headers headers, with as values the header field names given in // Access-Control-Expose-Headers headers, with as values the header field names given in
@ -641,7 +758,8 @@ impl Options {
// and url is a case-sensitive match for the URL of the resource. // and url is a case-sensitive match for the URL of the resource.
let response = response.exposed_headers( let response = response.exposed_headers(
self.expose_headers options
.expose_headers
.iter() .iter()
.map(|s| &**s) .map(|s| &**s)
.collect::<Vec<&str>>() .collect::<Vec<&str>>()
@ -651,16 +769,14 @@ impl Options {
} }
} }
/// A CORS Response which wraps another struct which implements `Responder`. You will typically impl<'a, 'r: 'a, R: response::Responder<'r>> response::Responder<'r> for Responder<'a, 'r, R> {
/// use [`Options`] instead to verify and build the response instead of this directly. fn respond_to(self, request: &Request) -> response::Result<'r> {
/// See module level documentation for usage examples. self.respond(request)
/// }
/// If the wrapped `Responder` already has the `Access-Control-Allow-Origin` header set, }
/// this responder will leave the response untouched.
/// This allows for chaining of several CORS responders.
/// /// A CORS Response which provides the following CORS headers:
/// Otherwise, the following headers may be set for the final Rocket `Response`, overwriting any
/// existing headers defined:
/// ///
/// - `Access-Control-Allow-Origin` /// - `Access-Control-Allow-Origin`
/// - `Access-Control-Expose-Headers` /// - `Access-Control-Expose-Headers`
@ -670,8 +786,7 @@ impl Options {
/// - `Access-Control-Allow-Headers` /// - `Access-Control-Allow-Headers`
/// - `Vary` /// - `Vary`
#[derive(Debug)] #[derive(Debug)]
pub struct Response<R> { struct Response {
responder: R,
allow_origin: Option<AllOrSome<String>>, allow_origin: Option<AllOrSome<String>>,
allow_methods: HashSet<Method>, allow_methods: HashSet<Method>,
allow_headers: HeaderFieldNamesSet, allow_headers: HeaderFieldNamesSet,
@ -681,14 +796,13 @@ pub struct Response<R> {
vary_origin: bool, vary_origin: bool,
} }
impl<'r, R: Responder<'r>> Response<R> { impl Response {
/// Consumes the responder and return an empty `Response` /// Consumes the responder and return an empty `Response`
fn new(responder: R) -> Self { fn new() -> Self {
Self { Self {
allow_origin: None, allow_origin: None,
allow_headers: HashSet::new(), allow_headers: HashSet::new(),
allow_methods: HashSet::new(), allow_methods: HashSet::new(),
responder,
allow_credentials: false, allow_credentials: false,
expose_headers: HashSet::new(), expose_headers: HashSet::new(),
max_age: None, max_age: None,
@ -826,15 +940,18 @@ impl<'r, R: Responder<'r>> Response<R> {
) )
} }
/// Builds a `rocket::Response` from this struct containing only the CORS headers. /// Builds a `rocket::Response` from this struct based off some base `rocket::Response`
///
/// This will overwrite any existing CORS headers
#[allow(unused_results)] #[allow(unused_results)]
fn build(&self) -> response::Response<'r> { fn build<'r>(&self, base: response::Response<'r>) -> response::Response<'r> {
let mut builder = response::Response::build(); let mut response = response::Response::build_from(base).finalize();
// TODO: We should be able to remove this
let origin = match self.allow_origin { let origin = match self.allow_origin {
None => { None => {
// This is not a CORS response // This is not a CORS response
return builder.finalize(); return response;
} }
Some(ref origin) => origin, Some(ref origin) => origin,
}; };
@ -844,10 +961,12 @@ impl<'r, R: Responder<'r>> Response<R> {
AllOrSome::Some(ref origin) => origin.to_string(), AllOrSome::Some(ref origin) => origin.to_string(),
}; };
builder.raw_header("Access-Control-Allow-Origin", origin); response.set_raw_header("Access-Control-Allow-Origin", origin);
if self.allow_credentials { if self.allow_credentials {
builder.raw_header("Access-Control-Allow-Credentials", "true"); response.set_raw_header("Access-Control-Allow-Credentials", "true");
} else {
response.remove_header("Access-Control-Allow-Credentials");
} }
if !self.expose_headers.is_empty() { if !self.expose_headers.is_empty() {
@ -857,7 +976,9 @@ impl<'r, R: Responder<'r>> Response<R> {
.collect(); .collect();
let headers = headers.join(", "); let headers = headers.join(", ");
builder.raw_header("Access-Control-Expose-Headers", headers); response.set_raw_header("Access-Control-Expose-Headers", headers);
} else {
response.remove_header("Access-Control-Expose-Headers");
} }
if !self.allow_headers.is_empty() { if !self.allow_headers.is_empty() {
@ -867,76 +988,34 @@ impl<'r, R: Responder<'r>> Response<R> {
.collect(); .collect();
let headers = headers.join(", "); let headers = headers.join(", ");
builder.raw_header("Access-Control-Allow-Headers", headers); response.set_raw_header("Access-Control-Allow-Headers", headers);
} else {
response.remove_header("Access-Control-Allow-Headers");
} }
if !self.allow_methods.is_empty() { if !self.allow_methods.is_empty() {
let methods: Vec<_> = self.allow_methods.iter().map(|m| m.as_str()).collect(); let methods: Vec<_> = self.allow_methods.iter().map(|m| m.as_str()).collect();
let methods = methods.join(", "); let methods = methods.join(", ");
builder.raw_header("Access-Control-Allow-Methods", methods); response.set_raw_header("Access-Control-Allow-Methods", methods);
} else {
response.remove_header("Access-Control-Allow-Methods");
} }
if self.max_age.is_some() { if self.max_age.is_some() {
let max_age = self.max_age.unwrap(); let max_age = self.max_age.unwrap();
builder.raw_header("Access-Control-Max-Age", max_age.to_string()); response.set_raw_header("Access-Control-Max-Age", max_age.to_string());
} else {
response.remove_header("Access-Control-Max-Age");
} }
if self.vary_origin { if self.vary_origin {
builder.raw_header("Vary", "Origin"); response.set_raw_header("Vary", "Origin");
} else {
response.remove_header("Vary");
} }
builder.finalize() response
}
/// Merge a `wrapped` Response with a `cors` response
///
/// If the `wrapped` response has the `Access-Control-Allow-Origin` header already defined,
/// it will be left untouched. This allows for chaining of several CORS responders.
///
/// Otherwise, the merging will be done according to the rules of `rocket::Response::merge`.
fn merge(
mut wrapped: response::Response<'r>,
cors: response::Response<'r>,
) -> response::Response<'r> {
let existing_cors = {
wrapped.headers().get("Access-Control-Allow-Origin").next() == None
};
if existing_cors {
wrapped.merge(cors);
}
wrapped
}
/// Finalize the Response by merging the CORS header with the wrapped `Responder
///
/// If the original response has the `Access-Control-Allow-Origin` header already defined,
/// it will be left untouched.This allows for chaining of several CORS responders.
///
/// Otherwise, the following headers may be set for the final Rocket `Response`, overwriting any
/// existing headers defined:
///
/// - `Access-Control-Allow-Origin`
/// - `Access-Control-Expose-Headers`
/// - `Access-Control-Max-Age`
/// - `Access-Control-Allow-Credentials`
/// - `Access-Control-Allow-Methods`
/// - `Access-Control-Allow-Headers`
/// - `Vary`
fn finalize(self, request: &Request) -> response::Result<'r> {
let cors_response = self.build();
let original_response = self.responder.respond_to(request)?;
Ok(Self::merge(original_response, cors_response))
}
}
impl<'r, R: Responder<'r>> Responder<'r> for Response<R> {
fn respond_to(self, request: &Request) -> response::Result<'r> {
self.finalize(request)
} }
} }
@ -1059,7 +1138,7 @@ mod tests {
let allowed_origins = AllOrSome::All; let allowed_origins = AllOrSome::All;
let send_wildcard = true; let send_wildcard = true;
let response = Response::new(()); let response = Response::new();
let response = not_err!(response.allowed_origin( let response = not_err!(response.allowed_origin(
&origin, &origin,
&allowed_origins, &allowed_origins,
@ -1071,7 +1150,7 @@ mod tests {
// Build response and check built response header // Build response and check built response header
let expected_header = vec!["*"]; let expected_header = vec!["*"];
let response = response.build(); let response = response.build(response::Response::new());
let actual_header: Vec<_> = response let actual_header: Vec<_> = response
.headers() .headers()
.get("Access-Control-Allow-Origin") .get("Access-Control-Allow-Origin")
@ -1086,7 +1165,7 @@ mod tests {
let allowed_origins = AllOrSome::All; let allowed_origins = AllOrSome::All;
let send_wildcard = false; let send_wildcard = false;
let response = Response::new(()); let response = Response::new();
let response = not_err!(response.allowed_origin( let response = not_err!(response.allowed_origin(
&origin, &origin,
&allowed_origins, &allowed_origins,
@ -1103,7 +1182,7 @@ mod tests {
// Build response and check built response header // Build response and check built response header
let expected_header = vec![url]; let expected_header = vec![url];
let response = response.build(); let response = response.build(response::Response::new());
let actual_header: Vec<_> = response let actual_header: Vec<_> = response
.headers() .headers()
.get("Access-Control-Allow-Origin") .get("Access-Control-Allow-Origin")
@ -1120,7 +1199,7 @@ mod tests {
assert!(failed_origins.is_empty()); assert!(failed_origins.is_empty());
let send_wildcard = false; let send_wildcard = false;
let response = Response::new(()); let response = Response::new();
let response = not_err!(response.allowed_origin( let response = not_err!(response.allowed_origin(
&origin, &origin,
&allowed_origins, &allowed_origins,
@ -1138,7 +1217,7 @@ mod tests {
// Build response and check built response header // Build response and check built response header
let expected_header = vec![url]; let expected_header = vec![url];
let response = response.build(); let response = response.build(response::Response::new());
let actual_header: Vec<_> = response let actual_header: Vec<_> = response
.headers() .headers()
.get("Access-Control-Allow-Origin") .get("Access-Control-Allow-Origin")
@ -1156,7 +1235,7 @@ mod tests {
assert!(failed_origins.is_empty()); assert!(failed_origins.is_empty());
let send_wildcard = false; let send_wildcard = false;
let response = Response::new(()); let response = Response::new();
let _ = response let _ = response
.allowed_origin(&origin, &allowed_origins, send_wildcard) .allowed_origin(&origin, &allowed_origins, send_wildcard)
.unwrap(); .unwrap();
@ -1165,7 +1244,7 @@ mod tests {
#[test] #[test]
#[should_panic(expected = "CredentialsWithWildcardOrigin")] #[should_panic(expected = "CredentialsWithWildcardOrigin")]
fn response_credentials_does_not_allow_wildcard_with_all_origins() { fn response_credentials_does_not_allow_wildcard_with_all_origins() {
let response = Response::new(()); let response = Response::new();
let response = response.any(); let response = response.any();
let _ = response.credentials(true).unwrap(); let _ = response.credentials(true).unwrap();
@ -1173,7 +1252,7 @@ mod tests {
#[test] #[test]
fn response_credentials_allows_specific_origins() { fn response_credentials_allows_specific_origins() {
let response = Response::new(()); let response = Response::new();
let response = response.origin("https://www.example.com", false); let response = response.origin("https://www.example.com", false);
let response = response.credentials(true).expect( let response = response.credentials(true).expect(
@ -1183,7 +1262,7 @@ mod tests {
// Build response and check built response header // Build response and check built response header
let expected_header = vec!["true"]; let expected_header = vec!["true"];
let response = response.build(); let response = response.build(response::Response::new());
let actual_header: Vec<_> = response let actual_header: Vec<_> = response
.headers() .headers()
.get("Access-Control-Allow-Credentials") .get("Access-Control-Allow-Credentials")
@ -1194,12 +1273,12 @@ mod tests {
#[test] #[test]
fn response_sets_exposed_headers_correctly() { fn response_sets_exposed_headers_correctly() {
let headers = vec!["Bar", "Baz", "Foo"]; let headers = vec!["Bar", "Baz", "Foo"];
let response = Response::new(()); let response = Response::new();
let response = response.origin("https://www.example.com", false); let response = response.origin("https://www.example.com", false);
let response = response.exposed_headers(&headers); let response = response.exposed_headers(&headers);
// Build response and check built response header // Build response and check built response header
let response = response.build(); let response = response.build(response::Response::new());
let actual_header: Vec<_> = response let actual_header: Vec<_> = response
.headers() .headers()
.get("Access-Control-Expose-Headers") .get("Access-Control-Expose-Headers")
@ -1216,27 +1295,27 @@ mod tests {
#[test] #[test]
fn response_sets_max_age_correctly() { fn response_sets_max_age_correctly() {
let response = Response::new(()); let response = Response::new();
let response = response.origin("https://www.example.com", false); let response = response.origin("https://www.example.com", false);
let response = response.max_age(Some(42)); let response = response.max_age(Some(42));
// Build response and check built response header // Build response and check built response header
let expected_header = vec!["42"]; let expected_header = vec!["42"];
let response = response.build(); let response = response.build(response::Response::new());
let actual_header: Vec<_> = response.headers().get("Access-Control-Max-Age").collect(); let actual_header: Vec<_> = response.headers().get("Access-Control-Max-Age").collect();
assert_eq!(expected_header, actual_header); assert_eq!(expected_header, actual_header);
} }
#[test] #[test]
fn response_does_not_set_max_age_when_none() { fn response_does_not_set_max_age_when_none() {
let response = Response::new(()); let response = Response::new();
let response = response.origin("https://www.example.com", false); let response = response.origin("https://www.example.com", false);
let response = response.max_age(None); let response = response.max_age(None);
// Build response and check built response header // Build response and check built response header
let response = response.build(); let response = response.build(response::Response::new());
assert!(response assert!(response
.headers() .headers()
.get("Access-Control-Max-Age") .get("Access-Control-Max-Age")
@ -1249,7 +1328,7 @@ mod tests {
let allowed_headers = AllOrSome::All; let allowed_headers = AllOrSome::All;
let requested_headers = vec!["Bar", "Foo"]; let requested_headers = vec!["Bar", "Foo"];
let response = Response::new(()); let response = Response::new();
let response = response.origin("https://www.example.com", false); let response = response.origin("https://www.example.com", false);
let response = response let response = response
.allowed_headers( .allowed_headers(
@ -1259,7 +1338,7 @@ mod tests {
.expect("to not fail"); .expect("to not fail");
// Build response and check built response header // Build response and check built response header
let response = response.build(); let response = response.build(response::Response::new());
let actual_header: Vec<_> = response let actual_header: Vec<_> = response
.headers() .headers()
.get("Access-Control-Allow-Headers") .get("Access-Control-Allow-Headers")
@ -1285,7 +1364,7 @@ mod tests {
let method = "GET"; let method = "GET";
let response = Response::new(()); let response = Response::new();
let response = response.origin("https://www.example.com", false); let response = response.origin("https://www.example.com", false);
let response = response let response = response
.allowed_methods( .allowed_methods(
@ -1295,7 +1374,7 @@ mod tests {
.expect("not to fail"); .expect("not to fail");
// Build response and check built response header // Build response and check built response header
let response = response.build(); let response = response.build(response::Response::new());
let actual_header: Vec<_> = response let actual_header: Vec<_> = response
.headers() .headers()
.get("Access-Control-Allow-Methods") .get("Access-Control-Allow-Methods")
@ -1324,7 +1403,7 @@ mod tests {
let method = "DELETE"; let method = "DELETE";
let response = Response::new(()); let response = Response::new();
let response = response.origin("https://www.example.com", false); let response = response.origin("https://www.example.com", false);
let _ = response let _ = response
.allowed_methods( .allowed_methods(
@ -1341,7 +1420,7 @@ mod tests {
let allowed_headers = vec!["Bar", "Baz", "Foo"]; let allowed_headers = vec!["Bar", "Baz", "Foo"];
let requested_headers = vec!["Bar", "Foo"]; let requested_headers = vec!["Bar", "Foo"];
let response = Response::new(()); let response = Response::new();
let response = response.origin("https://www.example.com", false); let response = response.origin("https://www.example.com", false);
let response = response let response = response
.allowed_headers( .allowed_headers(
@ -1356,7 +1435,7 @@ mod tests {
.expect("to not fail"); .expect("to not fail");
// Build response and check built response header // Build response and check built response header
let response = response.build(); let response = response.build(response::Response::new());
let actual_header: Vec<_> = response let actual_header: Vec<_> = response
.headers() .headers()
.get("Access-Control-Allow-Headers") .get("Access-Control-Allow-Headers")
@ -1377,7 +1456,7 @@ mod tests {
let allowed_headers = vec!["Bar", "Baz", "Foo"]; let allowed_headers = vec!["Bar", "Baz", "Foo"];
let requested_headers = vec!["Bar", "Foo", "Unknown"]; let requested_headers = vec!["Bar", "Foo", "Unknown"];
let response = Response::new(()); let response = Response::new();
let response = response.origin("https://www.example.com", false); let response = response.origin("https://www.example.com", false);
let _ = response let _ = response
.allowed_headers( .allowed_headers(
@ -1395,98 +1474,77 @@ mod tests {
#[test] #[test]
fn response_does_not_build_if_origin_is_not_set() { fn response_does_not_build_if_origin_is_not_set() {
let response = Response::new(()); let response = Response::new();
let response = response.build(); let response = response.build(response::Response::new());
let headers: Vec<_> = response.headers().iter().collect(); let headers: Vec<_> = response.headers().iter().collect();
assert_eq!(headers.len(), 0); assert_eq!(headers.len(), 0);
} }
// Note: Correct operation of Response::build is tested in the tests above for each of the
// individual headers
#[test] #[test]
fn response_merges_correctly() { fn response_build_removes_existing_cors_headers_and_keeps_others() {
use std::io::Cursor; use std::io::Cursor;
use rocket::http::Status;
let wrapped = response::Response::build() let original = response::Response::build()
.status(Status::ImATeapot) .status(Status::ImATeapot)
.raw_header("X-Teapot-Make", "Rocket") .raw_header("X-Teapot-Make", "Rocket")
.raw_header("Access-Control-Max-Age", "42")
.sized_body(Cursor::new("Brewing the best coffee!")) .sized_body(Cursor::new("Brewing the best coffee!"))
.finalize(); .finalize();
let response = Response::new(()); let response = Response::new();
let response = response.origin("https://www.acme.com", false); let response = response.origin("https://www.example.com", false);
let response = response.build(original);
let mut response = Response::<String>::merge(wrapped, response.build());
assert_eq!(response.status(), Status::ImATeapot);
assert_eq!(response.body_string(), Some("Brewing the best coffee!".to_string()));
// Check CORS header // Check CORS header
let expected_header = vec!["https://www.acme.com"];
let actual_header: Vec<_> = response
.headers()
.get("Access-Control-Allow-Origin")
.collect();
assert_eq!(expected_header, actual_header);
// Check other header
let expected_header = vec!["Rocket"];
let actual_header: Vec<_> = response.headers().get("X-Teapot-Make").collect();
assert_eq!(expected_header, actual_header);
}
#[test]
fn response_does_not_merge_existing_cors() {
let wrapped = response::Response::build()
.raw_header("Access-Control-Allow-Origin", "https://www.example.com")
.finalize();
let response = Response::new(());
let response = response.origin("https://www.acme.com", false);
let response = Response::<()>::merge(wrapped, response.build());
let expected_header = vec!["https://www.example.com"]; let expected_header = vec!["https://www.example.com"];
let actual_header: Vec<_> = response let actual_header: Vec<_> = response
.headers() .headers()
.get("Access-Control-Allow-Origin") .get("Access-Control-Allow-Origin")
.collect(); .collect();
assert_eq!(expected_header, actual_header); assert_eq!(expected_header, actual_header);
}
#[test]
fn response_finalize_smoke_test() {
use std::io::Cursor;
use rocket::http::Status;
let wrapped = response::Response::build()
.status(Status::ImATeapot)
.raw_header("X-Teapot-Make", "Rocket")
.sized_body(Cursor::new("Brewing the best coffee!"))
.finalize();
let response = Response::new(wrapped);
let response = response.origin("https://www.acme.com", false);
let client = make_client();
let request = client.get("/");
let mut response = response.finalize(request.inner()).expect("not to fail");
assert_eq!(response.status(), Status::ImATeapot);
assert_eq!(response.body_string(), Some("Brewing the best coffee!".to_string()));
// Check CORS header
let expected_header = vec!["https://www.acme.com"];
let actual_header: Vec<_> = response
.headers()
.get("Access-Control-Allow-Origin")
.collect();
assert_eq!(expected_header, actual_header);
// Check other header // Check other header
let expected_header = vec!["Rocket"]; let expected_header = vec!["Rocket"];
let actual_header: Vec<_> = response.headers().get("X-Teapot-Make").collect(); let actual_header: Vec<_> = response.headers().get("X-Teapot-Make").collect();
assert_eq!(expected_header, actual_header); assert_eq!(expected_header, actual_header);
// Check that `Access-Control-Max-Age` is removed
assert!(response.headers().get("Access-Control-Max-Age").next().is_none());
} }
// The following tests check that preflight checks are done properly
// fn make_cors_options() -> Options {
// let (allowed_origins, failed_origins) =
// AllOrSome::new_from_str_list(&["https://www.acme.com"]);
// assert!(failed_origins.is_empty());
// Options {
// allowed_origins: allowed_origins,
// allowed_methods: [Method::Get].iter().cloned().collect(),
// allowed_headers: AllOrSome::Some(
// ["Authorization"]
// .into_iter()
// .map(|s| s.to_string().into())
// .collect(),
// ),
// allow_credentials: true,
// ..Default::default()
// }
// }
// /// Tests that non CORS preflight are let through without modification
// #[test]
// fn preflight_missing_origins_are_let_through() {
// let options = make_cors_options();
// let client = make_client();
// let request = client.get("/");
// let response = options.preflight((), None, None, None).expect("not to fail");
// let headers: Vec<_> = response.headers().iter().collect();
// assert_eq!(headers.len(), 0);
// }
} }

View File

@ -15,21 +15,13 @@ use rocket::local::Client;
use rocket_cors::*; use rocket_cors::*;
#[options("/")] #[options("/")]
fn cors_options( fn cors_options(options: State<rocket_cors::Options>) -> Responder<&str> {
origin: Option<Origin>, rocket_cors::respond(options, "")
method: Option<AccessControlRequestMethod>,
headers: Option<AccessControlRequestHeaders>,
options: State<rocket_cors::Options>,
) -> Result<Response<()>, Error> {
options.preflight((), origin, method, headers)
} }
#[get("/")] #[get("/")]
fn cors( fn cors(options: State<rocket_cors::Options>) -> Responder<&str> {
origin: Option<Origin>, rocket_cors::respond(options, "Hello CORS")
options: State<rocket_cors::Options>,
) -> Result<Response<&'static str>, Error> {
options.respond("Hello CORS", origin)
} }
fn make_cors_options() -> Options { fn make_cors_options() -> Options {
@ -146,6 +138,7 @@ fn cors_get_check() {
let req = client.get("/").header(origin_header).header(authorization); let req = client.get("/").header(origin_header).header(authorization);
let mut response = req.dispatch(); let mut response = req.dispatch();
println!("{:?}", response);
assert_eq!(response.status(), Status::Ok); assert_eq!(response.status(), Status::Ok);
let body_str = response.body().and_then(|body| body.into_string()); let body_str = response.body().and_then(|body| body.into_string());
assert_eq!(body_str, Some("Hello CORS".to_string())); assert_eq!(body_str, Some("Hello CORS".to_string()));