Delay CORS checks and response until `Responder::respond_to` is invoked (#6)

* Delay checking of CORS to just before responding

* Lifetime issues

* Use State::inner()

* Fix lifetime issues

* Bump Rocket

* Document 'static limitation

And link to https://github.com/SergioBenitez/Rocket/pull/345

* Remove extraneous comments
This commit is contained in:
Yong Wen Chua 2017-07-15 01:38:13 +08:00 committed by GitHub
parent 16b89ab31c
commit 7dbc22b523
4 changed files with 286 additions and 235 deletions

View File

@ -16,7 +16,7 @@ travis-ci = { repository = "lawliet89/rocket_cors" }
[dependencies]
log = "0.3"
rocket = { git = "https://github.com/SergioBenitez/Rocket", rev = "aa51fe0" }
rocket = { git = "https://github.com/SergioBenitez/Rocket", rev = "51a465f2cc88d537079133bcdfec37d029070dcd" }
serde = "1.0"
serde_derive = "1.0"
unicase="1.4"
@ -29,5 +29,5 @@ version_check = "0.1"
[dev-dependencies]
hyper = "0.10"
rocket_codegen = { git = "https://github.com/SergioBenitez/Rocket", rev = "aa51fe0" }
rocket_codegen = { git = "https://github.com/SergioBenitez/Rocket", rev = "51a465f2cc88d537079133bcdfec37d029070dcd" }
serde_json = "1.0"

View File

@ -27,7 +27,7 @@ In particular, `rocket_cors` is currently targetted for `nightly-2017-07-13`.
Rocket > 0.3 is needed. At this moment, `0.3` is not released, and this crate will not be published
to Crates.io until Rocket 0.3 is released to Crates.io.
We currently tie this crate to revision [aa51fe0](https://github.com/SergioBenitez/Rocket/tree/aa51fe0) of Rocket.
We currently tie this crate to revision [51a465f2cc88d537079133bcdfec37d029070dcd](https://github.com/SergioBenitez/Rocket/tree/51a465f2cc88d537079133bcdfec37d029070dcd) of Rocket.
## Installation

View File

@ -29,7 +29,7 @@
//! to Crates.io until Rocket 0.3 is released to Crates.io.
//!
//! We currently tie this crate to revision
//! [aa51fe0](https://github.com/SergioBenitez/Rocket/tree/aa51fe0) of Rocket.
//! [51a465f2cc88d537079133bcdfec37d029070dcd](https://github.com/SergioBenitez/Rocket/tree/51a465f2cc88d537079133bcdfec37d029070dcd) of Rocket.
//!
//! ## Installation
//!
@ -118,20 +118,27 @@ extern crate hyper;
use std::collections::{HashSet, HashMap};
use std::error;
use std::fmt;
use std::marker::PhantomData;
use std::ops::Deref;
use std::str::FromStr;
use rocket::request::{self, Request, FromRequest};
use rocket::response::{self, Responder};
use rocket::{Outcome, State};
use rocket::http::{Method, Status};
use rocket::Outcome;
use rocket::request::{self, Request, FromRequest};
use rocket::response;
use unicase::UniCase;
#[cfg(test)]
#[macro_use]
mod test_macros;
/// CORS related error
/// Errors during operations
///
/// This enum implements `rocket::response::Responder` which will return an appropriate status code
/// while printing out the error in the console.
/// Because these errors are usually the result of an error while trying to respond to a CORS
/// request, CORS headers cannot be added to the response and your applications requesting CORS
/// will not be able to see the status code.
#[derive(Debug)]
pub enum Error {
/// The HTTP request header `Origin` is required but was not provided
@ -201,7 +208,7 @@ impl fmt::Display for Error {
}
}
impl<'r> Responder<'r> for Error {
impl<'r> response::Responder<'r> for Error {
fn respond_to(self, _: &Request) -> Result<response::Response<'r>, Status> {
error_!("CORS Error: {:?}", self);
Err(match self {
@ -259,7 +266,6 @@ impl<'a, 'r> FromRequest<'a, 'r> for Url {
}
}
/// The `Origin` request header used in CORS
///
/// You can use this as a rocket [Request Guard](https://rocket.rs/guide/requests/#request-guards)
@ -374,7 +380,11 @@ impl AllOrSome<HashSet<Url>> {
}
}
/// Configuration options to for building CORS preflight or actual responses.
/// Responder and Fairing for CORS
///
/// This struct can be used as Fairing for Rocket, or as an ad-hoc responder for any CORS requests.
/// You create a new copy of this struct by defining the configurations in the fields below.
/// This struct can also be deserialized by serde.
///
/// [`Default`](https://doc.rust-lang.org/std/default/trait.Default.html) is implemented for this
/// struct. The default for each field is described in the docuementation for the field.
@ -473,7 +483,30 @@ impl Default for Options {
}
}
/// Ad-hoc per route CORS response to requests
///
/// Note: If you use this, the lifetime parameter `'r` of your `rocket:::response::Responder<'r>`
/// CANNOT be `'static`. This is because the code generated by Rocket will implicitly try to
/// to restrain the `Request` object passed to the route to `&'static Request`, and it is not
/// possible to have such a reference.
/// See [this PR on Rocket](https://github.com/SergioBenitez/Rocket/pull/345).
pub fn respond<'a, 'r: 'a, R: response::Responder<'r>>(
options: State<'a, Options>,
responder: R,
) -> Responder<'a, 'r, R> {
options.inner().respond(responder)
}
impl Options {
/// Wrap any `Rocket::Response` and respond with CORS headers.
/// This is only used for ad-hoc route CORS response
fn respond<'a, 'r: 'a, R: response::Responder<'r>>(
&'a self,
responder: R,
) -> Responder<'a, 'r, R> {
Responder::new(responder, self)
}
fn default_allowed_methods() -> HashSet<Method> {
vec![
Method::Get,
@ -486,40 +519,138 @@ impl Options {
].into_iter()
.collect()
}
}
/// A CORS Responder which will inspect the incoming requests and respond accoridingly.
///
/// If the wrapped `Responder` already has the `Access-Control-Allow-Origin` header set,
/// this responder will leave the response untouched.
/// This allows for chaining of several CORS responders.
///
/// Otherwise, the following headers may be set for the final Rocket `Response`, overwriting any
/// existing headers defined:
///
/// - `Access-Control-Allow-Origin`
/// - `Access-Control-Expose-Headers`
/// - `Access-Control-Max-Age`
/// - `Access-Control-Allow-Credentials`
/// - `Access-Control-Allow-Methods`
/// - `Access-Control-Allow-Headers`
/// - `Vary`
#[derive(Debug)]
pub struct Responder<'a, 'r: 'a, R> {
responder: R,
options: &'a Options,
marker: PhantomData<response::Responder<'r>>,
}
impl<'a, 'r: 'a, R: response::Responder<'r>> Responder<'a, 'r, R> {
fn new(responder: R, options: &'a Options) -> Self {
Self {
responder,
options,
marker: PhantomData,
}
}
/// Respond to a request
fn respond(self, request: &Request) -> response::Result<'r> {
match self.build_cors_response(request) {
Ok(response) => response,
Err(e) => response::Responder::respond_to(e, request),
}
}
/// Build a CORS response and merge with an existing `rocket::Response` for the request
fn build_cors_response(self, request: &Request) -> Result<response::Result<'r>, Error> {
let original_response = match self.responder.respond_to(request) {
Ok(response) => response,
Err(status) => return Ok(Err(status)), // TODO: Handle this?
};
// Existing CORS response?
if Self::has_allow_origin(&original_response) {
return Ok(Ok(original_response));
}
// 1. If the Origin header is not present terminate this set of steps.
// The request is outside the scope of this specification.
let origin = Self::origin(request)?;
let origin = match origin {
None => {
// Not a CORS request
return Ok(Ok(original_response));
}
Some(origin) => origin,
};
// Check if the request verb is an OPTION or something else
let cors_response = match request.method() {
Method::Options => {
let method = Self::request_method(request)?;
let headers = Self::request_headers(request)?;
Self::preflight(&self.options, origin, method, headers)
}
_ => Self::actual_request(&self.options, origin),
}?;
Ok(Ok(cors_response.build(original_response)))
}
/// Gets the `Origin` request header from the request
fn origin(request: &Request) -> Result<Option<Origin>, Error> {
match Origin::from_request(request) {
Outcome::Forward(()) => Ok(None),
Outcome::Success(origin) => Ok(Some(origin)),
Outcome::Failure((_, err)) => Err(err),
}
}
/// Gets the `Access-Control-Request-Method` request header from the request
fn request_method(request: &Request) -> Result<Option<AccessControlRequestMethod>, Error> {
match AccessControlRequestMethod::from_request(request) {
Outcome::Forward(()) => Ok(None),
Outcome::Success(method) => Ok(Some(method)),
Outcome::Failure((_, err)) => Err(err),
}
}
/// Gets the `Access-Control-Request-Headers` request header from the request
fn request_headers(request: &Request) -> Result<Option<AccessControlRequestHeaders>, Error> {
match AccessControlRequestHeaders::from_request(request) {
Outcome::Forward(()) => Ok(None),
Outcome::Success(geaders) => Ok(Some(geaders)),
Outcome::Failure((_, err)) => Err(err),
}
}
/// Checks if an existing Response already has the header `Access-Control-Allow-Origin`
fn has_allow_origin(response: &response::Response<'r>) -> bool {
response.headers().get("Access-Control-Allow-Origin").next() != None
}
/// Construct a preflight response based on the options. Will return an `Err`
/// if any of the preflight checks fail.
///
/// This implementation references the
/// [W3C recommendation](https://www.w3.org/TR/cors/#resource-preflight-requests).
pub fn preflight<'r, R: Responder<'r>>(
&self,
responder: R,
origin: Option<Origin>,
fn preflight(
options: &Options,
origin: Origin,
method: Option<AccessControlRequestMethod>,
headers: Option<AccessControlRequestHeaders>,
) -> Result<Response<R>, Error> {
) -> Result<Response, Error> {
let response = Response::new(responder);
let response = Response::new();
// Note: All header parse failures are dealt with in the `FromRequest` trait implementation
// 1. If the Origin header is not present terminate this set of steps.
// The request is outside the scope of this specification.
let origin = match origin {
None => {
// Not a CORS request
return Ok(response);
}
Some(origin) => origin,
};
// 2. If the value of the Origin header is not a case-sensitive match for any of the values
// in list of origins do not set any additional headers and terminate this set of steps.
let response = response.allowed_origin(
&origin,
&self.allowed_origins,
self.send_wildcard,
&options.allowed_origins,
options.send_wildcard,
)?;
// 3. Let `method` be the value as result of parsing the Access-Control-Request-Method
@ -540,13 +671,13 @@ impl Options {
// 5. If method is not a case-sensitive match for any of the values in list of methods
// do not set any additional headers and terminate this set of steps.
let response = response.allowed_methods(&method, &self.allowed_methods)?;
let response = response.allowed_methods(&method, &options.allowed_methods)?;
// 6. If any of the header field-names is not a ASCII case-insensitive match for any of the
// values in list of headers do not set any additional headers and terminate this set of
// steps.
let response = if let Some(headers) = headers {
response.allowed_headers(&headers, &self.allowed_headers)?
response.allowed_headers(&headers, &options.allowed_headers)?
} else {
response
};
@ -559,12 +690,12 @@ impl Options {
// with either the value of the Origin header or the string "*" as value.
// Note: The string "*" cannot be used for a resource that supports credentials.
let response = response.credentials(self.allow_credentials)?;
let response = response.credentials(options.allow_credentials)?;
// 8. Optionally add a single Access-Control-Max-Age header
// with as value the amount of seconds the user agent is allowed to cache the result of the
// request.
let response = response.max_age(self.max_age);
let response = response.max_age(options.max_age);
// 9. If method is a simple method this step may be skipped.
// Add one or more Access-Control-Allow-Methods headers consisting of
@ -591,36 +722,22 @@ impl Options {
Ok(response)
}
/// Respond to a request based on the settings.
/// Respond to an actual request based on the settings.
/// If the `Origin` is not provided, then this request was not made by a browser and there is no
/// CORS enforcement.
pub fn respond<'r, R: Responder<'r>>(
&self,
responder: R,
origin: Option<Origin>,
) -> Result<Response<R>, Error> {
let response = Response::new(responder);
fn actual_request(options: &Options, origin: Origin) -> Result<Response, Error> {
let response = Response::new();
// Note: All header parse failures are dealt with in the `FromRequest` trait implementation
// 1. If the Origin header is not present terminate this set of steps.
// The request is outside the scope of this specification.
let origin = match origin {
None => {
// Not a CORS request
return Ok(response);
}
Some(origin) => origin,
};
// 2. If the value of the Origin header is not a case-sensitive match for any of the values
// in list of origins, do not set any additional headers and terminate this set of steps.
// Always matching is acceptable since the list of origins can be unbounded.
let response = response.allowed_origin(
&origin,
&self.allowed_origins,
self.send_wildcard,
&options.allowed_origins,
options.send_wildcard,
)?;
// 3. If the resource supports credentials add a single Access-Control-Allow-Origin header,
@ -631,7 +748,7 @@ impl Options {
// with either the value of the Origin header or the string "*" as value.
// Note: The string "*" cannot be used for a resource that supports credentials.
let response = response.credentials(self.allow_credentials)?;
let response = response.credentials(options.allow_credentials)?;
// 4. If the list of exposed headers is not empty add one or more
// Access-Control-Expose-Headers headers, with as values the header field names given in
@ -641,7 +758,8 @@ impl Options {
// and url is a case-sensitive match for the URL of the resource.
let response = response.exposed_headers(
self.expose_headers
options
.expose_headers
.iter()
.map(|s| &**s)
.collect::<Vec<&str>>()
@ -651,16 +769,14 @@ impl Options {
}
}
/// A CORS Response which wraps another struct which implements `Responder`. You will typically
/// use [`Options`] instead to verify and build the response instead of this directly.
/// See module level documentation for usage examples.
///
/// If the wrapped `Responder` already has the `Access-Control-Allow-Origin` header set,
/// this responder will leave the response untouched.
/// This allows for chaining of several CORS responders.
///
/// Otherwise, the following headers may be set for the final Rocket `Response`, overwriting any
/// existing headers defined:
impl<'a, 'r: 'a, R: response::Responder<'r>> response::Responder<'r> for Responder<'a, 'r, R> {
fn respond_to(self, request: &Request) -> response::Result<'r> {
self.respond(request)
}
}
/// A CORS Response which provides the following CORS headers:
///
/// - `Access-Control-Allow-Origin`
/// - `Access-Control-Expose-Headers`
@ -670,8 +786,7 @@ impl Options {
/// - `Access-Control-Allow-Headers`
/// - `Vary`
#[derive(Debug)]
pub struct Response<R> {
responder: R,
struct Response {
allow_origin: Option<AllOrSome<String>>,
allow_methods: HashSet<Method>,
allow_headers: HeaderFieldNamesSet,
@ -681,14 +796,13 @@ pub struct Response<R> {
vary_origin: bool,
}
impl<'r, R: Responder<'r>> Response<R> {
impl Response {
/// Consumes the responder and return an empty `Response`
fn new(responder: R) -> Self {
fn new() -> Self {
Self {
allow_origin: None,
allow_headers: HashSet::new(),
allow_methods: HashSet::new(),
responder,
allow_credentials: false,
expose_headers: HashSet::new(),
max_age: None,
@ -826,15 +940,18 @@ impl<'r, R: Responder<'r>> Response<R> {
)
}
/// Builds a `rocket::Response` from this struct containing only the CORS headers.
/// Builds a `rocket::Response` from this struct based off some base `rocket::Response`
///
/// This will overwrite any existing CORS headers
#[allow(unused_results)]
fn build(&self) -> response::Response<'r> {
let mut builder = response::Response::build();
fn build<'r>(&self, base: response::Response<'r>) -> response::Response<'r> {
let mut response = response::Response::build_from(base).finalize();
// TODO: We should be able to remove this
let origin = match self.allow_origin {
None => {
// This is not a CORS response
return builder.finalize();
return response;
}
Some(ref origin) => origin,
};
@ -844,10 +961,12 @@ impl<'r, R: Responder<'r>> Response<R> {
AllOrSome::Some(ref origin) => origin.to_string(),
};
builder.raw_header("Access-Control-Allow-Origin", origin);
response.set_raw_header("Access-Control-Allow-Origin", origin);
if self.allow_credentials {
builder.raw_header("Access-Control-Allow-Credentials", "true");
response.set_raw_header("Access-Control-Allow-Credentials", "true");
} else {
response.remove_header("Access-Control-Allow-Credentials");
}
if !self.expose_headers.is_empty() {
@ -857,7 +976,9 @@ impl<'r, R: Responder<'r>> Response<R> {
.collect();
let headers = headers.join(", ");
builder.raw_header("Access-Control-Expose-Headers", headers);
response.set_raw_header("Access-Control-Expose-Headers", headers);
} else {
response.remove_header("Access-Control-Expose-Headers");
}
if !self.allow_headers.is_empty() {
@ -867,76 +988,34 @@ impl<'r, R: Responder<'r>> Response<R> {
.collect();
let headers = headers.join(", ");
builder.raw_header("Access-Control-Allow-Headers", headers);
response.set_raw_header("Access-Control-Allow-Headers", headers);
} else {
response.remove_header("Access-Control-Allow-Headers");
}
if !self.allow_methods.is_empty() {
let methods: Vec<_> = self.allow_methods.iter().map(|m| m.as_str()).collect();
let methods = methods.join(", ");
builder.raw_header("Access-Control-Allow-Methods", methods);
response.set_raw_header("Access-Control-Allow-Methods", methods);
} else {
response.remove_header("Access-Control-Allow-Methods");
}
if self.max_age.is_some() {
let max_age = self.max_age.unwrap();
builder.raw_header("Access-Control-Max-Age", max_age.to_string());
response.set_raw_header("Access-Control-Max-Age", max_age.to_string());
} else {
response.remove_header("Access-Control-Max-Age");
}
if self.vary_origin {
builder.raw_header("Vary", "Origin");
response.set_raw_header("Vary", "Origin");
} else {
response.remove_header("Vary");
}
builder.finalize()
}
/// Merge a `wrapped` Response with a `cors` response
///
/// If the `wrapped` response has the `Access-Control-Allow-Origin` header already defined,
/// it will be left untouched. This allows for chaining of several CORS responders.
///
/// Otherwise, the merging will be done according to the rules of `rocket::Response::merge`.
fn merge(
mut wrapped: response::Response<'r>,
cors: response::Response<'r>,
) -> response::Response<'r> {
let existing_cors = {
wrapped.headers().get("Access-Control-Allow-Origin").next() == None
};
if existing_cors {
wrapped.merge(cors);
}
wrapped
}
/// Finalize the Response by merging the CORS header with the wrapped `Responder
///
/// If the original response has the `Access-Control-Allow-Origin` header already defined,
/// it will be left untouched.This allows for chaining of several CORS responders.
///
/// Otherwise, the following headers may be set for the final Rocket `Response`, overwriting any
/// existing headers defined:
///
/// - `Access-Control-Allow-Origin`
/// - `Access-Control-Expose-Headers`
/// - `Access-Control-Max-Age`
/// - `Access-Control-Allow-Credentials`
/// - `Access-Control-Allow-Methods`
/// - `Access-Control-Allow-Headers`
/// - `Vary`
fn finalize(self, request: &Request) -> response::Result<'r> {
let cors_response = self.build();
let original_response = self.responder.respond_to(request)?;
Ok(Self::merge(original_response, cors_response))
}
}
impl<'r, R: Responder<'r>> Responder<'r> for Response<R> {
fn respond_to(self, request: &Request) -> response::Result<'r> {
self.finalize(request)
response
}
}
@ -1059,7 +1138,7 @@ mod tests {
let allowed_origins = AllOrSome::All;
let send_wildcard = true;
let response = Response::new(());
let response = Response::new();
let response = not_err!(response.allowed_origin(
&origin,
&allowed_origins,
@ -1071,7 +1150,7 @@ mod tests {
// Build response and check built response header
let expected_header = vec!["*"];
let response = response.build();
let response = response.build(response::Response::new());
let actual_header: Vec<_> = response
.headers()
.get("Access-Control-Allow-Origin")
@ -1086,7 +1165,7 @@ mod tests {
let allowed_origins = AllOrSome::All;
let send_wildcard = false;
let response = Response::new(());
let response = Response::new();
let response = not_err!(response.allowed_origin(
&origin,
&allowed_origins,
@ -1103,7 +1182,7 @@ mod tests {
// Build response and check built response header
let expected_header = vec![url];
let response = response.build();
let response = response.build(response::Response::new());
let actual_header: Vec<_> = response
.headers()
.get("Access-Control-Allow-Origin")
@ -1120,7 +1199,7 @@ mod tests {
assert!(failed_origins.is_empty());
let send_wildcard = false;
let response = Response::new(());
let response = Response::new();
let response = not_err!(response.allowed_origin(
&origin,
&allowed_origins,
@ -1138,7 +1217,7 @@ mod tests {
// Build response and check built response header
let expected_header = vec![url];
let response = response.build();
let response = response.build(response::Response::new());
let actual_header: Vec<_> = response
.headers()
.get("Access-Control-Allow-Origin")
@ -1156,7 +1235,7 @@ mod tests {
assert!(failed_origins.is_empty());
let send_wildcard = false;
let response = Response::new(());
let response = Response::new();
let _ = response
.allowed_origin(&origin, &allowed_origins, send_wildcard)
.unwrap();
@ -1165,7 +1244,7 @@ mod tests {
#[test]
#[should_panic(expected = "CredentialsWithWildcardOrigin")]
fn response_credentials_does_not_allow_wildcard_with_all_origins() {
let response = Response::new(());
let response = Response::new();
let response = response.any();
let _ = response.credentials(true).unwrap();
@ -1173,7 +1252,7 @@ mod tests {
#[test]
fn response_credentials_allows_specific_origins() {
let response = Response::new(());
let response = Response::new();
let response = response.origin("https://www.example.com", false);
let response = response.credentials(true).expect(
@ -1183,7 +1262,7 @@ mod tests {
// Build response and check built response header
let expected_header = vec!["true"];
let response = response.build();
let response = response.build(response::Response::new());
let actual_header: Vec<_> = response
.headers()
.get("Access-Control-Allow-Credentials")
@ -1194,12 +1273,12 @@ mod tests {
#[test]
fn response_sets_exposed_headers_correctly() {
let headers = vec!["Bar", "Baz", "Foo"];
let response = Response::new(());
let response = Response::new();
let response = response.origin("https://www.example.com", false);
let response = response.exposed_headers(&headers);
// Build response and check built response header
let response = response.build();
let response = response.build(response::Response::new());
let actual_header: Vec<_> = response
.headers()
.get("Access-Control-Expose-Headers")
@ -1216,27 +1295,27 @@ mod tests {
#[test]
fn response_sets_max_age_correctly() {
let response = Response::new(());
let response = Response::new();
let response = response.origin("https://www.example.com", false);
let response = response.max_age(Some(42));
// Build response and check built response header
let expected_header = vec!["42"];
let response = response.build();
let response = response.build(response::Response::new());
let actual_header: Vec<_> = response.headers().get("Access-Control-Max-Age").collect();
assert_eq!(expected_header, actual_header);
}
#[test]
fn response_does_not_set_max_age_when_none() {
let response = Response::new(());
let response = Response::new();
let response = response.origin("https://www.example.com", false);
let response = response.max_age(None);
// Build response and check built response header
let response = response.build();
let response = response.build(response::Response::new());
assert!(response
.headers()
.get("Access-Control-Max-Age")
@ -1249,7 +1328,7 @@ mod tests {
let allowed_headers = AllOrSome::All;
let requested_headers = vec!["Bar", "Foo"];
let response = Response::new(());
let response = Response::new();
let response = response.origin("https://www.example.com", false);
let response = response
.allowed_headers(
@ -1259,7 +1338,7 @@ mod tests {
.expect("to not fail");
// Build response and check built response header
let response = response.build();
let response = response.build(response::Response::new());
let actual_header: Vec<_> = response
.headers()
.get("Access-Control-Allow-Headers")
@ -1285,7 +1364,7 @@ mod tests {
let method = "GET";
let response = Response::new(());
let response = Response::new();
let response = response.origin("https://www.example.com", false);
let response = response
.allowed_methods(
@ -1295,7 +1374,7 @@ mod tests {
.expect("not to fail");
// Build response and check built response header
let response = response.build();
let response = response.build(response::Response::new());
let actual_header: Vec<_> = response
.headers()
.get("Access-Control-Allow-Methods")
@ -1324,7 +1403,7 @@ mod tests {
let method = "DELETE";
let response = Response::new(());
let response = Response::new();
let response = response.origin("https://www.example.com", false);
let _ = response
.allowed_methods(
@ -1341,7 +1420,7 @@ mod tests {
let allowed_headers = vec!["Bar", "Baz", "Foo"];
let requested_headers = vec!["Bar", "Foo"];
let response = Response::new(());
let response = Response::new();
let response = response.origin("https://www.example.com", false);
let response = response
.allowed_headers(
@ -1356,7 +1435,7 @@ mod tests {
.expect("to not fail");
// Build response and check built response header
let response = response.build();
let response = response.build(response::Response::new());
let actual_header: Vec<_> = response
.headers()
.get("Access-Control-Allow-Headers")
@ -1377,7 +1456,7 @@ mod tests {
let allowed_headers = vec!["Bar", "Baz", "Foo"];
let requested_headers = vec!["Bar", "Foo", "Unknown"];
let response = Response::new(());
let response = Response::new();
let response = response.origin("https://www.example.com", false);
let _ = response
.allowed_headers(
@ -1395,98 +1474,77 @@ mod tests {
#[test]
fn response_does_not_build_if_origin_is_not_set() {
let response = Response::new(());
let response = response.build();
let response = Response::new();
let response = response.build(response::Response::new());
let headers: Vec<_> = response.headers().iter().collect();
assert_eq!(headers.len(), 0);
}
// Note: Correct operation of Response::build is tested in the tests above for each of the
// individual headers
#[test]
fn response_merges_correctly() {
fn response_build_removes_existing_cors_headers_and_keeps_others() {
use std::io::Cursor;
use rocket::http::Status;
let wrapped = response::Response::build()
let original = response::Response::build()
.status(Status::ImATeapot)
.raw_header("X-Teapot-Make", "Rocket")
.raw_header("Access-Control-Max-Age", "42")
.sized_body(Cursor::new("Brewing the best coffee!"))
.finalize();
let response = Response::new(());
let response = response.origin("https://www.acme.com", false);
let mut response = Response::<String>::merge(wrapped, response.build());
assert_eq!(response.status(), Status::ImATeapot);
assert_eq!(response.body_string(), Some("Brewing the best coffee!".to_string()));
let response = Response::new();
let response = response.origin("https://www.example.com", false);
let response = response.build(original);
// Check CORS header
let expected_header = vec!["https://www.acme.com"];
let actual_header: Vec<_> = response
.headers()
.get("Access-Control-Allow-Origin")
.collect();
assert_eq!(expected_header, actual_header);
// Check other header
let expected_header = vec!["Rocket"];
let actual_header: Vec<_> = response.headers().get("X-Teapot-Make").collect();
assert_eq!(expected_header, actual_header);
}
#[test]
fn response_does_not_merge_existing_cors() {
let wrapped = response::Response::build()
.raw_header("Access-Control-Allow-Origin", "https://www.example.com")
.finalize();
let response = Response::new(());
let response = response.origin("https://www.acme.com", false);
let response = Response::<()>::merge(wrapped, response.build());
let expected_header = vec!["https://www.example.com"];
let actual_header: Vec<_> = response
.headers()
.get("Access-Control-Allow-Origin")
.collect();
assert_eq!(expected_header, actual_header);
}
#[test]
fn response_finalize_smoke_test() {
use std::io::Cursor;
use rocket::http::Status;
let wrapped = response::Response::build()
.status(Status::ImATeapot)
.raw_header("X-Teapot-Make", "Rocket")
.sized_body(Cursor::new("Brewing the best coffee!"))
.finalize();
let response = Response::new(wrapped);
let response = response.origin("https://www.acme.com", false);
let client = make_client();
let request = client.get("/");
let mut response = response.finalize(request.inner()).expect("not to fail");
assert_eq!(response.status(), Status::ImATeapot);
assert_eq!(response.body_string(), Some("Brewing the best coffee!".to_string()));
// Check CORS header
let expected_header = vec!["https://www.acme.com"];
let actual_header: Vec<_> = response
.headers()
.get("Access-Control-Allow-Origin")
.collect();
assert_eq!(expected_header, actual_header);
// Check other header
let expected_header = vec!["Rocket"];
let actual_header: Vec<_> = response.headers().get("X-Teapot-Make").collect();
assert_eq!(expected_header, actual_header);
// Check that `Access-Control-Max-Age` is removed
assert!(response.headers().get("Access-Control-Max-Age").next().is_none());
}
// The following tests check that preflight checks are done properly
// fn make_cors_options() -> Options {
// let (allowed_origins, failed_origins) =
// AllOrSome::new_from_str_list(&["https://www.acme.com"]);
// assert!(failed_origins.is_empty());
// Options {
// allowed_origins: allowed_origins,
// allowed_methods: [Method::Get].iter().cloned().collect(),
// allowed_headers: AllOrSome::Some(
// ["Authorization"]
// .into_iter()
// .map(|s| s.to_string().into())
// .collect(),
// ),
// allow_credentials: true,
// ..Default::default()
// }
// }
// /// Tests that non CORS preflight are let through without modification
// #[test]
// fn preflight_missing_origins_are_let_through() {
// let options = make_cors_options();
// let client = make_client();
// let request = client.get("/");
// let response = options.preflight((), None, None, None).expect("not to fail");
// let headers: Vec<_> = response.headers().iter().collect();
// assert_eq!(headers.len(), 0);
// }
}

View File

@ -15,21 +15,13 @@ use rocket::local::Client;
use rocket_cors::*;
#[options("/")]
fn cors_options(
origin: Option<Origin>,
method: Option<AccessControlRequestMethod>,
headers: Option<AccessControlRequestHeaders>,
options: State<rocket_cors::Options>,
) -> Result<Response<()>, Error> {
options.preflight((), origin, method, headers)
fn cors_options(options: State<rocket_cors::Options>) -> Responder<&str> {
rocket_cors::respond(options, "")
}
#[get("/")]
fn cors(
origin: Option<Origin>,
options: State<rocket_cors::Options>,
) -> Result<Response<&'static str>, Error> {
options.respond("Hello CORS", origin)
fn cors(options: State<rocket_cors::Options>) -> Responder<&str> {
rocket_cors::respond(options, "Hello CORS")
}
fn make_cors_options() -> Options {
@ -146,6 +138,7 @@ fn cors_get_check() {
let req = client.get("/").header(origin_header).header(authorization);
let mut response = req.dispatch();
println!("{:?}", response);
assert_eq!(response.status(), Status::Ok);
let body_str = response.body().and_then(|body| body.into_string());
assert_eq!(body_str, Some("Hello CORS".to_string()));