Delay CORS checks and response until `Responder::respond_to` is invoked (#6)
* Delay checking of CORS to just before responding * Lifetime issues * Use State::inner() * Fix lifetime issues * Bump Rocket * Document 'static limitation And link to https://github.com/SergioBenitez/Rocket/pull/345 * Remove extraneous comments
This commit is contained in:
parent
16b89ab31c
commit
7dbc22b523
|
@ -16,7 +16,7 @@ travis-ci = { repository = "lawliet89/rocket_cors" }
|
|||
|
||||
[dependencies]
|
||||
log = "0.3"
|
||||
rocket = { git = "https://github.com/SergioBenitez/Rocket", rev = "aa51fe0" }
|
||||
rocket = { git = "https://github.com/SergioBenitez/Rocket", rev = "51a465f2cc88d537079133bcdfec37d029070dcd" }
|
||||
serde = "1.0"
|
||||
serde_derive = "1.0"
|
||||
unicase="1.4"
|
||||
|
@ -29,5 +29,5 @@ version_check = "0.1"
|
|||
|
||||
[dev-dependencies]
|
||||
hyper = "0.10"
|
||||
rocket_codegen = { git = "https://github.com/SergioBenitez/Rocket", rev = "aa51fe0" }
|
||||
rocket_codegen = { git = "https://github.com/SergioBenitez/Rocket", rev = "51a465f2cc88d537079133bcdfec37d029070dcd" }
|
||||
serde_json = "1.0"
|
||||
|
|
|
@ -27,7 +27,7 @@ In particular, `rocket_cors` is currently targetted for `nightly-2017-07-13`.
|
|||
Rocket > 0.3 is needed. At this moment, `0.3` is not released, and this crate will not be published
|
||||
to Crates.io until Rocket 0.3 is released to Crates.io.
|
||||
|
||||
We currently tie this crate to revision [aa51fe0](https://github.com/SergioBenitez/Rocket/tree/aa51fe0) of Rocket.
|
||||
We currently tie this crate to revision [51a465f2cc88d537079133bcdfec37d029070dcd](https://github.com/SergioBenitez/Rocket/tree/51a465f2cc88d537079133bcdfec37d029070dcd) of Rocket.
|
||||
|
||||
## Installation
|
||||
|
||||
|
|
498
src/lib.rs
498
src/lib.rs
|
@ -29,7 +29,7 @@
|
|||
//! to Crates.io until Rocket 0.3 is released to Crates.io.
|
||||
//!
|
||||
//! We currently tie this crate to revision
|
||||
//! [aa51fe0](https://github.com/SergioBenitez/Rocket/tree/aa51fe0) of Rocket.
|
||||
//! [51a465f2cc88d537079133bcdfec37d029070dcd](https://github.com/SergioBenitez/Rocket/tree/51a465f2cc88d537079133bcdfec37d029070dcd) of Rocket.
|
||||
//!
|
||||
//! ## Installation
|
||||
//!
|
||||
|
@ -118,20 +118,27 @@ extern crate hyper;
|
|||
use std::collections::{HashSet, HashMap};
|
||||
use std::error;
|
||||
use std::fmt;
|
||||
use std::marker::PhantomData;
|
||||
use std::ops::Deref;
|
||||
use std::str::FromStr;
|
||||
|
||||
use rocket::request::{self, Request, FromRequest};
|
||||
use rocket::response::{self, Responder};
|
||||
use rocket::{Outcome, State};
|
||||
use rocket::http::{Method, Status};
|
||||
use rocket::Outcome;
|
||||
use rocket::request::{self, Request, FromRequest};
|
||||
use rocket::response;
|
||||
use unicase::UniCase;
|
||||
|
||||
#[cfg(test)]
|
||||
#[macro_use]
|
||||
mod test_macros;
|
||||
|
||||
/// CORS related error
|
||||
/// Errors during operations
|
||||
///
|
||||
/// This enum implements `rocket::response::Responder` which will return an appropriate status code
|
||||
/// while printing out the error in the console.
|
||||
/// Because these errors are usually the result of an error while trying to respond to a CORS
|
||||
/// request, CORS headers cannot be added to the response and your applications requesting CORS
|
||||
/// will not be able to see the status code.
|
||||
#[derive(Debug)]
|
||||
pub enum Error {
|
||||
/// The HTTP request header `Origin` is required but was not provided
|
||||
|
@ -201,7 +208,7 @@ impl fmt::Display for Error {
|
|||
}
|
||||
}
|
||||
|
||||
impl<'r> Responder<'r> for Error {
|
||||
impl<'r> response::Responder<'r> for Error {
|
||||
fn respond_to(self, _: &Request) -> Result<response::Response<'r>, Status> {
|
||||
error_!("CORS Error: {:?}", self);
|
||||
Err(match self {
|
||||
|
@ -259,7 +266,6 @@ impl<'a, 'r> FromRequest<'a, 'r> for Url {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
/// The `Origin` request header used in CORS
|
||||
///
|
||||
/// You can use this as a rocket [Request Guard](https://rocket.rs/guide/requests/#request-guards)
|
||||
|
@ -374,7 +380,11 @@ impl AllOrSome<HashSet<Url>> {
|
|||
}
|
||||
}
|
||||
|
||||
/// Configuration options to for building CORS preflight or actual responses.
|
||||
/// Responder and Fairing for CORS
|
||||
///
|
||||
/// This struct can be used as Fairing for Rocket, or as an ad-hoc responder for any CORS requests.
|
||||
/// You create a new copy of this struct by defining the configurations in the fields below.
|
||||
/// This struct can also be deserialized by serde.
|
||||
///
|
||||
/// [`Default`](https://doc.rust-lang.org/std/default/trait.Default.html) is implemented for this
|
||||
/// struct. The default for each field is described in the docuementation for the field.
|
||||
|
@ -473,7 +483,30 @@ impl Default for Options {
|
|||
}
|
||||
}
|
||||
|
||||
/// Ad-hoc per route CORS response to requests
|
||||
///
|
||||
/// Note: If you use this, the lifetime parameter `'r` of your `rocket:::response::Responder<'r>`
|
||||
/// CANNOT be `'static`. This is because the code generated by Rocket will implicitly try to
|
||||
/// to restrain the `Request` object passed to the route to `&'static Request`, and it is not
|
||||
/// possible to have such a reference.
|
||||
/// See [this PR on Rocket](https://github.com/SergioBenitez/Rocket/pull/345).
|
||||
pub fn respond<'a, 'r: 'a, R: response::Responder<'r>>(
|
||||
options: State<'a, Options>,
|
||||
responder: R,
|
||||
) -> Responder<'a, 'r, R> {
|
||||
options.inner().respond(responder)
|
||||
}
|
||||
|
||||
impl Options {
|
||||
/// Wrap any `Rocket::Response` and respond with CORS headers.
|
||||
/// This is only used for ad-hoc route CORS response
|
||||
fn respond<'a, 'r: 'a, R: response::Responder<'r>>(
|
||||
&'a self,
|
||||
responder: R,
|
||||
) -> Responder<'a, 'r, R> {
|
||||
Responder::new(responder, self)
|
||||
}
|
||||
|
||||
fn default_allowed_methods() -> HashSet<Method> {
|
||||
vec![
|
||||
Method::Get,
|
||||
|
@ -486,40 +519,138 @@ impl Options {
|
|||
].into_iter()
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
/// A CORS Responder which will inspect the incoming requests and respond accoridingly.
|
||||
///
|
||||
/// If the wrapped `Responder` already has the `Access-Control-Allow-Origin` header set,
|
||||
/// this responder will leave the response untouched.
|
||||
/// This allows for chaining of several CORS responders.
|
||||
///
|
||||
/// Otherwise, the following headers may be set for the final Rocket `Response`, overwriting any
|
||||
/// existing headers defined:
|
||||
///
|
||||
/// - `Access-Control-Allow-Origin`
|
||||
/// - `Access-Control-Expose-Headers`
|
||||
/// - `Access-Control-Max-Age`
|
||||
/// - `Access-Control-Allow-Credentials`
|
||||
/// - `Access-Control-Allow-Methods`
|
||||
/// - `Access-Control-Allow-Headers`
|
||||
/// - `Vary`
|
||||
#[derive(Debug)]
|
||||
pub struct Responder<'a, 'r: 'a, R> {
|
||||
responder: R,
|
||||
options: &'a Options,
|
||||
marker: PhantomData<response::Responder<'r>>,
|
||||
}
|
||||
|
||||
impl<'a, 'r: 'a, R: response::Responder<'r>> Responder<'a, 'r, R> {
|
||||
fn new(responder: R, options: &'a Options) -> Self {
|
||||
Self {
|
||||
responder,
|
||||
options,
|
||||
marker: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
/// Respond to a request
|
||||
fn respond(self, request: &Request) -> response::Result<'r> {
|
||||
match self.build_cors_response(request) {
|
||||
Ok(response) => response,
|
||||
Err(e) => response::Responder::respond_to(e, request),
|
||||
}
|
||||
}
|
||||
|
||||
/// Build a CORS response and merge with an existing `rocket::Response` for the request
|
||||
fn build_cors_response(self, request: &Request) -> Result<response::Result<'r>, Error> {
|
||||
let original_response = match self.responder.respond_to(request) {
|
||||
Ok(response) => response,
|
||||
Err(status) => return Ok(Err(status)), // TODO: Handle this?
|
||||
};
|
||||
|
||||
// Existing CORS response?
|
||||
if Self::has_allow_origin(&original_response) {
|
||||
return Ok(Ok(original_response));
|
||||
}
|
||||
|
||||
// 1. If the Origin header is not present terminate this set of steps.
|
||||
// The request is outside the scope of this specification.
|
||||
let origin = Self::origin(request)?;
|
||||
let origin = match origin {
|
||||
None => {
|
||||
// Not a CORS request
|
||||
return Ok(Ok(original_response));
|
||||
}
|
||||
Some(origin) => origin,
|
||||
};
|
||||
|
||||
// Check if the request verb is an OPTION or something else
|
||||
let cors_response = match request.method() {
|
||||
Method::Options => {
|
||||
let method = Self::request_method(request)?;
|
||||
let headers = Self::request_headers(request)?;
|
||||
Self::preflight(&self.options, origin, method, headers)
|
||||
}
|
||||
_ => Self::actual_request(&self.options, origin),
|
||||
}?;
|
||||
|
||||
Ok(Ok(cors_response.build(original_response)))
|
||||
}
|
||||
|
||||
/// Gets the `Origin` request header from the request
|
||||
fn origin(request: &Request) -> Result<Option<Origin>, Error> {
|
||||
match Origin::from_request(request) {
|
||||
Outcome::Forward(()) => Ok(None),
|
||||
Outcome::Success(origin) => Ok(Some(origin)),
|
||||
Outcome::Failure((_, err)) => Err(err),
|
||||
}
|
||||
}
|
||||
|
||||
/// Gets the `Access-Control-Request-Method` request header from the request
|
||||
fn request_method(request: &Request) -> Result<Option<AccessControlRequestMethod>, Error> {
|
||||
match AccessControlRequestMethod::from_request(request) {
|
||||
Outcome::Forward(()) => Ok(None),
|
||||
Outcome::Success(method) => Ok(Some(method)),
|
||||
Outcome::Failure((_, err)) => Err(err),
|
||||
}
|
||||
}
|
||||
|
||||
/// Gets the `Access-Control-Request-Headers` request header from the request
|
||||
fn request_headers(request: &Request) -> Result<Option<AccessControlRequestHeaders>, Error> {
|
||||
match AccessControlRequestHeaders::from_request(request) {
|
||||
Outcome::Forward(()) => Ok(None),
|
||||
Outcome::Success(geaders) => Ok(Some(geaders)),
|
||||
Outcome::Failure((_, err)) => Err(err),
|
||||
}
|
||||
}
|
||||
|
||||
/// Checks if an existing Response already has the header `Access-Control-Allow-Origin`
|
||||
fn has_allow_origin(response: &response::Response<'r>) -> bool {
|
||||
response.headers().get("Access-Control-Allow-Origin").next() != None
|
||||
}
|
||||
|
||||
/// Construct a preflight response based on the options. Will return an `Err`
|
||||
/// if any of the preflight checks fail.
|
||||
///
|
||||
/// This implementation references the
|
||||
/// [W3C recommendation](https://www.w3.org/TR/cors/#resource-preflight-requests).
|
||||
pub fn preflight<'r, R: Responder<'r>>(
|
||||
&self,
|
||||
responder: R,
|
||||
origin: Option<Origin>,
|
||||
fn preflight(
|
||||
options: &Options,
|
||||
origin: Origin,
|
||||
method: Option<AccessControlRequestMethod>,
|
||||
headers: Option<AccessControlRequestHeaders>,
|
||||
) -> Result<Response<R>, Error> {
|
||||
) -> Result<Response, Error> {
|
||||
|
||||
let response = Response::new(responder);
|
||||
let response = Response::new();
|
||||
|
||||
// Note: All header parse failures are dealt with in the `FromRequest` trait implementation
|
||||
|
||||
// 1. If the Origin header is not present terminate this set of steps.
|
||||
// The request is outside the scope of this specification.
|
||||
let origin = match origin {
|
||||
None => {
|
||||
// Not a CORS request
|
||||
return Ok(response);
|
||||
}
|
||||
Some(origin) => origin,
|
||||
};
|
||||
|
||||
// 2. If the value of the Origin header is not a case-sensitive match for any of the values
|
||||
// in list of origins do not set any additional headers and terminate this set of steps.
|
||||
let response = response.allowed_origin(
|
||||
&origin,
|
||||
&self.allowed_origins,
|
||||
self.send_wildcard,
|
||||
&options.allowed_origins,
|
||||
options.send_wildcard,
|
||||
)?;
|
||||
|
||||
// 3. Let `method` be the value as result of parsing the Access-Control-Request-Method
|
||||
|
@ -540,13 +671,13 @@ impl Options {
|
|||
// 5. If method is not a case-sensitive match for any of the values in list of methods
|
||||
// do not set any additional headers and terminate this set of steps.
|
||||
|
||||
let response = response.allowed_methods(&method, &self.allowed_methods)?;
|
||||
let response = response.allowed_methods(&method, &options.allowed_methods)?;
|
||||
|
||||
// 6. If any of the header field-names is not a ASCII case-insensitive match for any of the
|
||||
// values in list of headers do not set any additional headers and terminate this set of
|
||||
// steps.
|
||||
let response = if let Some(headers) = headers {
|
||||
response.allowed_headers(&headers, &self.allowed_headers)?
|
||||
response.allowed_headers(&headers, &options.allowed_headers)?
|
||||
} else {
|
||||
response
|
||||
};
|
||||
|
@ -559,12 +690,12 @@ impl Options {
|
|||
// with either the value of the Origin header or the string "*" as value.
|
||||
// Note: The string "*" cannot be used for a resource that supports credentials.
|
||||
|
||||
let response = response.credentials(self.allow_credentials)?;
|
||||
let response = response.credentials(options.allow_credentials)?;
|
||||
|
||||
// 8. Optionally add a single Access-Control-Max-Age header
|
||||
// with as value the amount of seconds the user agent is allowed to cache the result of the
|
||||
// request.
|
||||
let response = response.max_age(self.max_age);
|
||||
let response = response.max_age(options.max_age);
|
||||
|
||||
// 9. If method is a simple method this step may be skipped.
|
||||
// Add one or more Access-Control-Allow-Methods headers consisting of
|
||||
|
@ -591,36 +722,22 @@ impl Options {
|
|||
Ok(response)
|
||||
}
|
||||
|
||||
/// Respond to a request based on the settings.
|
||||
/// Respond to an actual request based on the settings.
|
||||
/// If the `Origin` is not provided, then this request was not made by a browser and there is no
|
||||
/// CORS enforcement.
|
||||
pub fn respond<'r, R: Responder<'r>>(
|
||||
&self,
|
||||
responder: R,
|
||||
origin: Option<Origin>,
|
||||
) -> Result<Response<R>, Error> {
|
||||
let response = Response::new(responder);
|
||||
fn actual_request(options: &Options, origin: Origin) -> Result<Response, Error> {
|
||||
let response = Response::new();
|
||||
|
||||
// Note: All header parse failures are dealt with in the `FromRequest` trait implementation
|
||||
|
||||
// 1. If the Origin header is not present terminate this set of steps.
|
||||
// The request is outside the scope of this specification.
|
||||
let origin = match origin {
|
||||
None => {
|
||||
// Not a CORS request
|
||||
return Ok(response);
|
||||
}
|
||||
Some(origin) => origin,
|
||||
};
|
||||
|
||||
// 2. If the value of the Origin header is not a case-sensitive match for any of the values
|
||||
// in list of origins, do not set any additional headers and terminate this set of steps.
|
||||
// Always matching is acceptable since the list of origins can be unbounded.
|
||||
|
||||
let response = response.allowed_origin(
|
||||
&origin,
|
||||
&self.allowed_origins,
|
||||
self.send_wildcard,
|
||||
&options.allowed_origins,
|
||||
options.send_wildcard,
|
||||
)?;
|
||||
|
||||
// 3. If the resource supports credentials add a single Access-Control-Allow-Origin header,
|
||||
|
@ -631,7 +748,7 @@ impl Options {
|
|||
// with either the value of the Origin header or the string "*" as value.
|
||||
// Note: The string "*" cannot be used for a resource that supports credentials.
|
||||
|
||||
let response = response.credentials(self.allow_credentials)?;
|
||||
let response = response.credentials(options.allow_credentials)?;
|
||||
|
||||
// 4. If the list of exposed headers is not empty add one or more
|
||||
// Access-Control-Expose-Headers headers, with as values the header field names given in
|
||||
|
@ -641,7 +758,8 @@ impl Options {
|
|||
// and url is a case-sensitive match for the URL of the resource.
|
||||
|
||||
let response = response.exposed_headers(
|
||||
self.expose_headers
|
||||
options
|
||||
.expose_headers
|
||||
.iter()
|
||||
.map(|s| &**s)
|
||||
.collect::<Vec<&str>>()
|
||||
|
@ -651,16 +769,14 @@ impl Options {
|
|||
}
|
||||
}
|
||||
|
||||
/// A CORS Response which wraps another struct which implements `Responder`. You will typically
|
||||
/// use [`Options`] instead to verify and build the response instead of this directly.
|
||||
/// See module level documentation for usage examples.
|
||||
///
|
||||
/// If the wrapped `Responder` already has the `Access-Control-Allow-Origin` header set,
|
||||
/// this responder will leave the response untouched.
|
||||
/// This allows for chaining of several CORS responders.
|
||||
///
|
||||
/// Otherwise, the following headers may be set for the final Rocket `Response`, overwriting any
|
||||
/// existing headers defined:
|
||||
impl<'a, 'r: 'a, R: response::Responder<'r>> response::Responder<'r> for Responder<'a, 'r, R> {
|
||||
fn respond_to(self, request: &Request) -> response::Result<'r> {
|
||||
self.respond(request)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/// A CORS Response which provides the following CORS headers:
|
||||
///
|
||||
/// - `Access-Control-Allow-Origin`
|
||||
/// - `Access-Control-Expose-Headers`
|
||||
|
@ -670,8 +786,7 @@ impl Options {
|
|||
/// - `Access-Control-Allow-Headers`
|
||||
/// - `Vary`
|
||||
#[derive(Debug)]
|
||||
pub struct Response<R> {
|
||||
responder: R,
|
||||
struct Response {
|
||||
allow_origin: Option<AllOrSome<String>>,
|
||||
allow_methods: HashSet<Method>,
|
||||
allow_headers: HeaderFieldNamesSet,
|
||||
|
@ -681,14 +796,13 @@ pub struct Response<R> {
|
|||
vary_origin: bool,
|
||||
}
|
||||
|
||||
impl<'r, R: Responder<'r>> Response<R> {
|
||||
impl Response {
|
||||
/// Consumes the responder and return an empty `Response`
|
||||
fn new(responder: R) -> Self {
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
allow_origin: None,
|
||||
allow_headers: HashSet::new(),
|
||||
allow_methods: HashSet::new(),
|
||||
responder,
|
||||
allow_credentials: false,
|
||||
expose_headers: HashSet::new(),
|
||||
max_age: None,
|
||||
|
@ -826,15 +940,18 @@ impl<'r, R: Responder<'r>> Response<R> {
|
|||
)
|
||||
}
|
||||
|
||||
/// Builds a `rocket::Response` from this struct containing only the CORS headers.
|
||||
/// Builds a `rocket::Response` from this struct based off some base `rocket::Response`
|
||||
///
|
||||
/// This will overwrite any existing CORS headers
|
||||
#[allow(unused_results)]
|
||||
fn build(&self) -> response::Response<'r> {
|
||||
let mut builder = response::Response::build();
|
||||
fn build<'r>(&self, base: response::Response<'r>) -> response::Response<'r> {
|
||||
let mut response = response::Response::build_from(base).finalize();
|
||||
|
||||
// TODO: We should be able to remove this
|
||||
let origin = match self.allow_origin {
|
||||
None => {
|
||||
// This is not a CORS response
|
||||
return builder.finalize();
|
||||
return response;
|
||||
}
|
||||
Some(ref origin) => origin,
|
||||
};
|
||||
|
@ -844,10 +961,12 @@ impl<'r, R: Responder<'r>> Response<R> {
|
|||
AllOrSome::Some(ref origin) => origin.to_string(),
|
||||
};
|
||||
|
||||
builder.raw_header("Access-Control-Allow-Origin", origin);
|
||||
response.set_raw_header("Access-Control-Allow-Origin", origin);
|
||||
|
||||
if self.allow_credentials {
|
||||
builder.raw_header("Access-Control-Allow-Credentials", "true");
|
||||
response.set_raw_header("Access-Control-Allow-Credentials", "true");
|
||||
} else {
|
||||
response.remove_header("Access-Control-Allow-Credentials");
|
||||
}
|
||||
|
||||
if !self.expose_headers.is_empty() {
|
||||
|
@ -857,7 +976,9 @@ impl<'r, R: Responder<'r>> Response<R> {
|
|||
.collect();
|
||||
let headers = headers.join(", ");
|
||||
|
||||
builder.raw_header("Access-Control-Expose-Headers", headers);
|
||||
response.set_raw_header("Access-Control-Expose-Headers", headers);
|
||||
} else {
|
||||
response.remove_header("Access-Control-Expose-Headers");
|
||||
}
|
||||
|
||||
if !self.allow_headers.is_empty() {
|
||||
|
@ -867,76 +988,34 @@ impl<'r, R: Responder<'r>> Response<R> {
|
|||
.collect();
|
||||
let headers = headers.join(", ");
|
||||
|
||||
builder.raw_header("Access-Control-Allow-Headers", headers);
|
||||
response.set_raw_header("Access-Control-Allow-Headers", headers);
|
||||
} else {
|
||||
response.remove_header("Access-Control-Allow-Headers");
|
||||
}
|
||||
|
||||
if !self.allow_methods.is_empty() {
|
||||
let methods: Vec<_> = self.allow_methods.iter().map(|m| m.as_str()).collect();
|
||||
let methods = methods.join(", ");
|
||||
|
||||
builder.raw_header("Access-Control-Allow-Methods", methods);
|
||||
response.set_raw_header("Access-Control-Allow-Methods", methods);
|
||||
} else {
|
||||
response.remove_header("Access-Control-Allow-Methods");
|
||||
}
|
||||
|
||||
if self.max_age.is_some() {
|
||||
let max_age = self.max_age.unwrap();
|
||||
builder.raw_header("Access-Control-Max-Age", max_age.to_string());
|
||||
response.set_raw_header("Access-Control-Max-Age", max_age.to_string());
|
||||
} else {
|
||||
response.remove_header("Access-Control-Max-Age");
|
||||
}
|
||||
|
||||
if self.vary_origin {
|
||||
builder.raw_header("Vary", "Origin");
|
||||
response.set_raw_header("Vary", "Origin");
|
||||
} else {
|
||||
response.remove_header("Vary");
|
||||
}
|
||||
|
||||
builder.finalize()
|
||||
}
|
||||
|
||||
/// Merge a `wrapped` Response with a `cors` response
|
||||
///
|
||||
/// If the `wrapped` response has the `Access-Control-Allow-Origin` header already defined,
|
||||
/// it will be left untouched. This allows for chaining of several CORS responders.
|
||||
///
|
||||
/// Otherwise, the merging will be done according to the rules of `rocket::Response::merge`.
|
||||
fn merge(
|
||||
mut wrapped: response::Response<'r>,
|
||||
cors: response::Response<'r>,
|
||||
) -> response::Response<'r> {
|
||||
|
||||
let existing_cors = {
|
||||
wrapped.headers().get("Access-Control-Allow-Origin").next() == None
|
||||
};
|
||||
|
||||
if existing_cors {
|
||||
wrapped.merge(cors);
|
||||
}
|
||||
|
||||
wrapped
|
||||
}
|
||||
|
||||
/// Finalize the Response by merging the CORS header with the wrapped `Responder
|
||||
///
|
||||
/// If the original response has the `Access-Control-Allow-Origin` header already defined,
|
||||
/// it will be left untouched.This allows for chaining of several CORS responders.
|
||||
///
|
||||
/// Otherwise, the following headers may be set for the final Rocket `Response`, overwriting any
|
||||
/// existing headers defined:
|
||||
///
|
||||
/// - `Access-Control-Allow-Origin`
|
||||
/// - `Access-Control-Expose-Headers`
|
||||
/// - `Access-Control-Max-Age`
|
||||
/// - `Access-Control-Allow-Credentials`
|
||||
/// - `Access-Control-Allow-Methods`
|
||||
/// - `Access-Control-Allow-Headers`
|
||||
/// - `Vary`
|
||||
fn finalize(self, request: &Request) -> response::Result<'r> {
|
||||
let cors_response = self.build();
|
||||
let original_response = self.responder.respond_to(request)?;
|
||||
|
||||
Ok(Self::merge(original_response, cors_response))
|
||||
}
|
||||
}
|
||||
|
||||
impl<'r, R: Responder<'r>> Responder<'r> for Response<R> {
|
||||
fn respond_to(self, request: &Request) -> response::Result<'r> {
|
||||
self.finalize(request)
|
||||
response
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1059,7 +1138,7 @@ mod tests {
|
|||
let allowed_origins = AllOrSome::All;
|
||||
let send_wildcard = true;
|
||||
|
||||
let response = Response::new(());
|
||||
let response = Response::new();
|
||||
let response = not_err!(response.allowed_origin(
|
||||
&origin,
|
||||
&allowed_origins,
|
||||
|
@ -1071,7 +1150,7 @@ mod tests {
|
|||
|
||||
// Build response and check built response header
|
||||
let expected_header = vec!["*"];
|
||||
let response = response.build();
|
||||
let response = response.build(response::Response::new());
|
||||
let actual_header: Vec<_> = response
|
||||
.headers()
|
||||
.get("Access-Control-Allow-Origin")
|
||||
|
@ -1086,7 +1165,7 @@ mod tests {
|
|||
let allowed_origins = AllOrSome::All;
|
||||
let send_wildcard = false;
|
||||
|
||||
let response = Response::new(());
|
||||
let response = Response::new();
|
||||
let response = not_err!(response.allowed_origin(
|
||||
&origin,
|
||||
&allowed_origins,
|
||||
|
@ -1103,7 +1182,7 @@ mod tests {
|
|||
|
||||
// Build response and check built response header
|
||||
let expected_header = vec![url];
|
||||
let response = response.build();
|
||||
let response = response.build(response::Response::new());
|
||||
let actual_header: Vec<_> = response
|
||||
.headers()
|
||||
.get("Access-Control-Allow-Origin")
|
||||
|
@ -1120,7 +1199,7 @@ mod tests {
|
|||
assert!(failed_origins.is_empty());
|
||||
let send_wildcard = false;
|
||||
|
||||
let response = Response::new(());
|
||||
let response = Response::new();
|
||||
let response = not_err!(response.allowed_origin(
|
||||
&origin,
|
||||
&allowed_origins,
|
||||
|
@ -1138,7 +1217,7 @@ mod tests {
|
|||
|
||||
// Build response and check built response header
|
||||
let expected_header = vec![url];
|
||||
let response = response.build();
|
||||
let response = response.build(response::Response::new());
|
||||
let actual_header: Vec<_> = response
|
||||
.headers()
|
||||
.get("Access-Control-Allow-Origin")
|
||||
|
@ -1156,7 +1235,7 @@ mod tests {
|
|||
assert!(failed_origins.is_empty());
|
||||
let send_wildcard = false;
|
||||
|
||||
let response = Response::new(());
|
||||
let response = Response::new();
|
||||
let _ = response
|
||||
.allowed_origin(&origin, &allowed_origins, send_wildcard)
|
||||
.unwrap();
|
||||
|
@ -1165,7 +1244,7 @@ mod tests {
|
|||
#[test]
|
||||
#[should_panic(expected = "CredentialsWithWildcardOrigin")]
|
||||
fn response_credentials_does_not_allow_wildcard_with_all_origins() {
|
||||
let response = Response::new(());
|
||||
let response = Response::new();
|
||||
let response = response.any();
|
||||
|
||||
let _ = response.credentials(true).unwrap();
|
||||
|
@ -1173,7 +1252,7 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn response_credentials_allows_specific_origins() {
|
||||
let response = Response::new(());
|
||||
let response = Response::new();
|
||||
let response = response.origin("https://www.example.com", false);
|
||||
|
||||
let response = response.credentials(true).expect(
|
||||
|
@ -1183,7 +1262,7 @@ mod tests {
|
|||
|
||||
// Build response and check built response header
|
||||
let expected_header = vec!["true"];
|
||||
let response = response.build();
|
||||
let response = response.build(response::Response::new());
|
||||
let actual_header: Vec<_> = response
|
||||
.headers()
|
||||
.get("Access-Control-Allow-Credentials")
|
||||
|
@ -1194,12 +1273,12 @@ mod tests {
|
|||
#[test]
|
||||
fn response_sets_exposed_headers_correctly() {
|
||||
let headers = vec!["Bar", "Baz", "Foo"];
|
||||
let response = Response::new(());
|
||||
let response = Response::new();
|
||||
let response = response.origin("https://www.example.com", false);
|
||||
let response = response.exposed_headers(&headers);
|
||||
|
||||
// Build response and check built response header
|
||||
let response = response.build();
|
||||
let response = response.build(response::Response::new());
|
||||
let actual_header: Vec<_> = response
|
||||
.headers()
|
||||
.get("Access-Control-Expose-Headers")
|
||||
|
@ -1216,27 +1295,27 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn response_sets_max_age_correctly() {
|
||||
let response = Response::new(());
|
||||
let response = Response::new();
|
||||
let response = response.origin("https://www.example.com", false);
|
||||
|
||||
let response = response.max_age(Some(42));
|
||||
|
||||
// Build response and check built response header
|
||||
let expected_header = vec!["42"];
|
||||
let response = response.build();
|
||||
let response = response.build(response::Response::new());
|
||||
let actual_header: Vec<_> = response.headers().get("Access-Control-Max-Age").collect();
|
||||
assert_eq!(expected_header, actual_header);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn response_does_not_set_max_age_when_none() {
|
||||
let response = Response::new(());
|
||||
let response = Response::new();
|
||||
let response = response.origin("https://www.example.com", false);
|
||||
|
||||
let response = response.max_age(None);
|
||||
|
||||
// Build response and check built response header
|
||||
let response = response.build();
|
||||
let response = response.build(response::Response::new());
|
||||
assert!(response
|
||||
.headers()
|
||||
.get("Access-Control-Max-Age")
|
||||
|
@ -1249,7 +1328,7 @@ mod tests {
|
|||
let allowed_headers = AllOrSome::All;
|
||||
let requested_headers = vec!["Bar", "Foo"];
|
||||
|
||||
let response = Response::new(());
|
||||
let response = Response::new();
|
||||
let response = response.origin("https://www.example.com", false);
|
||||
let response = response
|
||||
.allowed_headers(
|
||||
|
@ -1259,7 +1338,7 @@ mod tests {
|
|||
.expect("to not fail");
|
||||
|
||||
// Build response and check built response header
|
||||
let response = response.build();
|
||||
let response = response.build(response::Response::new());
|
||||
let actual_header: Vec<_> = response
|
||||
.headers()
|
||||
.get("Access-Control-Allow-Headers")
|
||||
|
@ -1285,7 +1364,7 @@ mod tests {
|
|||
|
||||
let method = "GET";
|
||||
|
||||
let response = Response::new(());
|
||||
let response = Response::new();
|
||||
let response = response.origin("https://www.example.com", false);
|
||||
let response = response
|
||||
.allowed_methods(
|
||||
|
@ -1295,7 +1374,7 @@ mod tests {
|
|||
.expect("not to fail");
|
||||
|
||||
// Build response and check built response header
|
||||
let response = response.build();
|
||||
let response = response.build(response::Response::new());
|
||||
let actual_header: Vec<_> = response
|
||||
.headers()
|
||||
.get("Access-Control-Allow-Methods")
|
||||
|
@ -1324,7 +1403,7 @@ mod tests {
|
|||
|
||||
let method = "DELETE";
|
||||
|
||||
let response = Response::new(());
|
||||
let response = Response::new();
|
||||
let response = response.origin("https://www.example.com", false);
|
||||
let _ = response
|
||||
.allowed_methods(
|
||||
|
@ -1341,7 +1420,7 @@ mod tests {
|
|||
let allowed_headers = vec!["Bar", "Baz", "Foo"];
|
||||
let requested_headers = vec!["Bar", "Foo"];
|
||||
|
||||
let response = Response::new(());
|
||||
let response = Response::new();
|
||||
let response = response.origin("https://www.example.com", false);
|
||||
let response = response
|
||||
.allowed_headers(
|
||||
|
@ -1356,7 +1435,7 @@ mod tests {
|
|||
.expect("to not fail");
|
||||
|
||||
// Build response and check built response header
|
||||
let response = response.build();
|
||||
let response = response.build(response::Response::new());
|
||||
let actual_header: Vec<_> = response
|
||||
.headers()
|
||||
.get("Access-Control-Allow-Headers")
|
||||
|
@ -1377,7 +1456,7 @@ mod tests {
|
|||
let allowed_headers = vec!["Bar", "Baz", "Foo"];
|
||||
let requested_headers = vec!["Bar", "Foo", "Unknown"];
|
||||
|
||||
let response = Response::new(());
|
||||
let response = Response::new();
|
||||
let response = response.origin("https://www.example.com", false);
|
||||
let _ = response
|
||||
.allowed_headers(
|
||||
|
@ -1395,98 +1474,77 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn response_does_not_build_if_origin_is_not_set() {
|
||||
let response = Response::new(());
|
||||
let response = response.build();
|
||||
let response = Response::new();
|
||||
let response = response.build(response::Response::new());
|
||||
|
||||
let headers: Vec<_> = response.headers().iter().collect();
|
||||
assert_eq!(headers.len(), 0);
|
||||
}
|
||||
|
||||
// Note: Correct operation of Response::build is tested in the tests above for each of the
|
||||
// individual headers
|
||||
|
||||
#[test]
|
||||
fn response_merges_correctly() {
|
||||
fn response_build_removes_existing_cors_headers_and_keeps_others() {
|
||||
use std::io::Cursor;
|
||||
use rocket::http::Status;
|
||||
|
||||
let wrapped = response::Response::build()
|
||||
let original = response::Response::build()
|
||||
.status(Status::ImATeapot)
|
||||
.raw_header("X-Teapot-Make", "Rocket")
|
||||
.raw_header("Access-Control-Max-Age", "42")
|
||||
.sized_body(Cursor::new("Brewing the best coffee!"))
|
||||
.finalize();
|
||||
|
||||
let response = Response::new(());
|
||||
let response = response.origin("https://www.acme.com", false);
|
||||
|
||||
let mut response = Response::<String>::merge(wrapped, response.build());
|
||||
assert_eq!(response.status(), Status::ImATeapot);
|
||||
assert_eq!(response.body_string(), Some("Brewing the best coffee!".to_string()));
|
||||
|
||||
let response = Response::new();
|
||||
let response = response.origin("https://www.example.com", false);
|
||||
let response = response.build(original);
|
||||
// Check CORS header
|
||||
let expected_header = vec!["https://www.acme.com"];
|
||||
let actual_header: Vec<_> = response
|
||||
.headers()
|
||||
.get("Access-Control-Allow-Origin")
|
||||
.collect();
|
||||
assert_eq!(expected_header, actual_header);
|
||||
|
||||
// Check other header
|
||||
let expected_header = vec!["Rocket"];
|
||||
let actual_header: Vec<_> = response.headers().get("X-Teapot-Make").collect();
|
||||
assert_eq!(expected_header, actual_header);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn response_does_not_merge_existing_cors() {
|
||||
let wrapped = response::Response::build()
|
||||
.raw_header("Access-Control-Allow-Origin", "https://www.example.com")
|
||||
.finalize();
|
||||
|
||||
let response = Response::new(());
|
||||
let response = response.origin("https://www.acme.com", false);
|
||||
|
||||
let response = Response::<()>::merge(wrapped, response.build());
|
||||
let expected_header = vec!["https://www.example.com"];
|
||||
let actual_header: Vec<_> = response
|
||||
.headers()
|
||||
.get("Access-Control-Allow-Origin")
|
||||
.collect();
|
||||
assert_eq!(expected_header, actual_header);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn response_finalize_smoke_test() {
|
||||
use std::io::Cursor;
|
||||
use rocket::http::Status;
|
||||
|
||||
let wrapped = response::Response::build()
|
||||
.status(Status::ImATeapot)
|
||||
.raw_header("X-Teapot-Make", "Rocket")
|
||||
.sized_body(Cursor::new("Brewing the best coffee!"))
|
||||
.finalize();
|
||||
|
||||
let response = Response::new(wrapped);
|
||||
let response = response.origin("https://www.acme.com", false);
|
||||
|
||||
let client = make_client();
|
||||
let request = client.get("/");
|
||||
let mut response = response.finalize(request.inner()).expect("not to fail");
|
||||
|
||||
assert_eq!(response.status(), Status::ImATeapot);
|
||||
assert_eq!(response.body_string(), Some("Brewing the best coffee!".to_string()));
|
||||
|
||||
// Check CORS header
|
||||
let expected_header = vec!["https://www.acme.com"];
|
||||
let actual_header: Vec<_> = response
|
||||
.headers()
|
||||
.get("Access-Control-Allow-Origin")
|
||||
.collect();
|
||||
assert_eq!(expected_header, actual_header);
|
||||
|
||||
// Check other header
|
||||
let expected_header = vec!["Rocket"];
|
||||
let actual_header: Vec<_> = response.headers().get("X-Teapot-Make").collect();
|
||||
assert_eq!(expected_header, actual_header);
|
||||
|
||||
// Check that `Access-Control-Max-Age` is removed
|
||||
assert!(response.headers().get("Access-Control-Max-Age").next().is_none());
|
||||
|
||||
|
||||
}
|
||||
|
||||
// The following tests check that preflight checks are done properly
|
||||
|
||||
// fn make_cors_options() -> Options {
|
||||
// let (allowed_origins, failed_origins) =
|
||||
// AllOrSome::new_from_str_list(&["https://www.acme.com"]);
|
||||
// assert!(failed_origins.is_empty());
|
||||
|
||||
// Options {
|
||||
// allowed_origins: allowed_origins,
|
||||
// allowed_methods: [Method::Get].iter().cloned().collect(),
|
||||
// allowed_headers: AllOrSome::Some(
|
||||
// ["Authorization"]
|
||||
// .into_iter()
|
||||
// .map(|s| s.to_string().into())
|
||||
// .collect(),
|
||||
// ),
|
||||
// allow_credentials: true,
|
||||
// ..Default::default()
|
||||
// }
|
||||
// }
|
||||
|
||||
// /// Tests that non CORS preflight are let through without modification
|
||||
// #[test]
|
||||
// fn preflight_missing_origins_are_let_through() {
|
||||
// let options = make_cors_options();
|
||||
// let client = make_client();
|
||||
// let request = client.get("/");
|
||||
|
||||
// let response = options.preflight((), None, None, None).expect("not to fail");
|
||||
|
||||
// let headers: Vec<_> = response.headers().iter().collect();
|
||||
// assert_eq!(headers.len(), 0);
|
||||
// }
|
||||
}
|
||||
|
|
|
@ -15,21 +15,13 @@ use rocket::local::Client;
|
|||
use rocket_cors::*;
|
||||
|
||||
#[options("/")]
|
||||
fn cors_options(
|
||||
origin: Option<Origin>,
|
||||
method: Option<AccessControlRequestMethod>,
|
||||
headers: Option<AccessControlRequestHeaders>,
|
||||
options: State<rocket_cors::Options>,
|
||||
) -> Result<Response<()>, Error> {
|
||||
options.preflight((), origin, method, headers)
|
||||
fn cors_options(options: State<rocket_cors::Options>) -> Responder<&str> {
|
||||
rocket_cors::respond(options, "")
|
||||
}
|
||||
|
||||
#[get("/")]
|
||||
fn cors(
|
||||
origin: Option<Origin>,
|
||||
options: State<rocket_cors::Options>,
|
||||
) -> Result<Response<&'static str>, Error> {
|
||||
options.respond("Hello CORS", origin)
|
||||
fn cors(options: State<rocket_cors::Options>) -> Responder<&str> {
|
||||
rocket_cors::respond(options, "Hello CORS")
|
||||
}
|
||||
|
||||
fn make_cors_options() -> Options {
|
||||
|
@ -146,6 +138,7 @@ fn cors_get_check() {
|
|||
let req = client.get("/").header(origin_header).header(authorization);
|
||||
|
||||
let mut response = req.dispatch();
|
||||
println!("{:?}", response);
|
||||
assert_eq!(response.status(), Status::Ok);
|
||||
let body_str = response.body().and_then(|body| body.into_string());
|
||||
assert_eq!(body_str, Some("Hello CORS".to_string()));
|
||||
|
|
Loading…
Reference in New Issue