Refactor Implementation (#3)

* Introduce AllOrSome enum

More AllOrSome usage

Allowed methods

Preflight response

Response

* Additional documentation

* Support non CORS requests
This commit is contained in:
Yong Wen Chua 2017-07-14 11:03:45 +08:00 committed by GitHub
parent c9190abdc4
commit dfc1cdfee0
2 changed files with 432 additions and 256 deletions

View File

@ -34,11 +34,15 @@ We currently tie this crate to revision [aa51fe0](https://github.com/SergioBenit
<!-- Add the following to Cargo.toml:
```toml
biscuit = "0.0.6"
rocket_cors = "0.0.6"
``` -->
To use the latest `master` branch, for example:
```toml
biscuit = { git = "https://github.com/lawliet89/rocket_cors", branch = "master" }
rocket_cors = { git = "https://github.com/lawliet89/rocket_cors", branch = "master" }
```
## Reference
- [W3C CORS Recommendation](https://www.w3.org/TR/cors/#resource-processing-model)

View File

@ -1,84 +1,50 @@
//! Cross-origin resource sharing (CORS) for Rocket.rs applications
//! [![Build Status](https://travis-ci.org/lawliet89/rocket_cors.svg)](https://travis-ci.org/lawliet89/rocket_cors)
//! [![Dependency Status](https://dependencyci.com/github/lawliet89/rocket_cors/badge)](https://dependencyci.com/github/lawliet89/rocket_cors)
//! [![Repository](https://img.shields.io/github/tag/lawliet89/rocket_cors.svg)](https://github.com/lawliet89/rocket_cors)
//! <!-- [![Crates.io](https://img.shields.io/crates/v/rocket_cors.svg)](https://crates.io/crates/rocket_cors) -->
//! <!-- [![Documentation](https://docs.rs/rocket_cors/badge.svg)](https://docs.rs/rocket_cors) -->
//!
//! Rocket (as of v0.2) does not have middleware support. Support for it is (supposedly)
//! on the way. In the mean time, we adopt an
//! [example implementation](https://github.com/SergioBenitez/Rocket/pull/141) to nest
//! `Responders` to acheive the same effect in the short run.
//! - Documentation: stable | [master branch](https://lawliet89.github.io/rocket_cors)
//!
//! # Examples
//! Cross-origin resource sharing (CORS) for [Rocket](https://rocket.rs/) applications
//!
//! ## Requirements
//!
//! - Nightly Rust
//! - Rocket > 0.3
//!
//! ### Nightly Rust
//!
//! Rocket requires nightly Rust. You should probably install Rust with
//! [rustup](https://www.rustup.rs/), then override the code directory to use nightly instead of
//! stable. See
//! [installation instructions](https://rocket.rs/guide/getting-started/#installing-rust).
//!
//! In particular, `rocket_cors` is currently targetted for `nightly-2017-07-13`.
//!
//! ### Rocket > 0.3
//!
//! Rocket > 0.3 is needed. At this moment, `0.3` is not released, and this crate will not be
//! published
//! to Crates.io until Rocket 0.3 is released to Crates.io.
//!
//! We currently tie this crate to revision
//! [aa51fe0](https://github.com/SergioBenitez/Rocket/tree/aa51fe0) of Rocket.
//!
//! ## Installation
//!
//! <!-- Add the following to Cargo.toml:
//!
//! ```toml
//! rocket_cors = "0.0.6"
//! ``` -->
//!
//! To use the latest `master` branch, for example:
//!
//! ```toml
//! rocket_cors = { git = "https://github.com/lawliet89/rocket_cors", branch = "master" }
//! ```
//! #![feature(plugin, custom_derive)]
//! #![plugin(rocket_codegen)]
//! extern crate hyper;
//! extern crate rocket;
//! extern crate rocket_cors;
//!
//! use std::str::FromStr;
//!
//! use rocket::State;
//! use rocket::http::Method::*;
//! use rocket::http::{Header, Status};
//! use rocket::local::Client;
//! use rocket_cors::*;
//!
//! #[options("/")]
//! fn cors_options(origin: Option<Origin>,
//! method: AccessControlRequestMethod,
//! headers: AccessControlRequestHeaders,
//! options: State<rocket_cors::Options>)
//! -> Result<Response<()>, Error> {
//! options.preflight(origin, &method, Some(&headers))
//! }
//!
//! #[get("/")]
//! fn cors(origin: Option<Origin>, options: State<rocket_cors::Options>)
//! -> Result<Response<&'static str>, Error>
//! {
//! options.respond("Hello CORS", origin)
//! }
//!
//! # fn main() {
//! let (allowed_origins, failed_origins) =
//! AllowedOrigins::new_from_str_list(&["https://www.acme.com"]);
//! assert!(failed_origins.is_empty());
//! let cors_options = rocket_cors::Options {
//! allowed_origins: allowed_origins,
//! allowed_methods: [Get].iter().cloned().collect(),
//! allowed_headers: ["Authorization"].iter().map(|s| s.to_string().into()).collect(),
//! allow_credentials: true,
//! ..Default::default()
//! };
//! let rocket = rocket::ignite().mount("/", routes![cors, cors_options]).manage(cors_options);
//! let client = Client::new(rocket).unwrap();
//!
//! // `Options` pre-flight checks
//! let origin_header =
//! Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap());
//! let method_header =
//! Header::from(hyper::header::AccessControlRequestMethod(hyper::method::Method::Get));
//! let request_headers =
//! hyper::header::AccessControlRequestHeaders(
//! vec![FromStr::from_str("Authorization").unwrap()]);
//! let request_headers = Header::from(request_headers);
//! let req =
//! client.options("/").header(origin_header).header(method_header).header(request_headers);
//!
//! let response = req.dispatch();
//! assert_eq!(response.status(), Status::Ok);
//!
//! // "Actual" request
//! let origin_header =
//! Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap());
//! let authorization = Header::new("Authorization", "let me in");
//! let req = client.get("/").header(origin_header).header(authorization);
//!
//! let mut response = req.dispatch();
//! assert_eq!(response.status(), Status::Ok);
//! let body_str = response.body().and_then(|body| body.into_string());
//! assert_eq!(body_str, Some("Hello CORS".to_string()));
//! # }
//! ```
#![allow(
legacy_directory_ownership,
@ -183,6 +149,10 @@ pub enum Error {
MethodNotAllowed,
/// One or more headers requested are not allowed
HeadersNotAllowed,
/// Credentials are allowed, but the Origin is set to "*". This is not allowed by W3C
///
/// This is a misconfiguration. Check the docuemntation for `Options`.
CredentialsWithWildcardOrigin,
}
impl error::Error for Error {
@ -204,6 +174,11 @@ impl error::Error for Error {
Error::OriginNotAllowed => "Origin is not allowed to request",
Error::MethodNotAllowed => "Method is not allowed",
Error::HeadersNotAllowed => "Headers are not allowed",
Error::CredentialsWithWildcardOrigin => {
"Credentials are allowed, but the Origin is set to \"*\". \
This is not allowed by W3C"
}
}
}
@ -231,6 +206,7 @@ impl<'r> Responder<'r> for Error {
Err(match self {
Error::MissingOrigin | Error::OriginNotAllowed | Error::MethodNotAllowed |
Error::HeadersNotAllowed => Status::Forbidden,
Error::CredentialsWithWildcardOrigin => Status::InternalServerError,
_ => Status::BadRequest,
})
}
@ -308,12 +284,13 @@ impl<'a, 'r> FromRequest<'a, 'r> for AccessControlRequestMethod {
Err(e) => Outcome::Failure((Status::BadRequest, Error::BadRequestMethod(e))),
}
}
None => Outcome::Failure((Status::BadRequest, Error::MissingRequestMethod)),
None => Outcome::Forward(()),
}
}
}
type HeaderFieldNamesSet = HashSet<UniCase<String>>;
type HeaderFieldName = UniCase<String>;
type HeaderFieldNamesSet = HashSet<HeaderFieldName>;
/// The `Access-Control-Request-Headers` request header
#[derive(Debug)]
@ -350,82 +327,30 @@ impl<'a, 'r> FromRequest<'a, 'r> for AccessControlRequestHeaders {
}
}
}
None => Outcome::Failure((Status::BadRequest, Error::MissingRequestHeaders)),
None => Outcome::Forward(()),
}
}
}
/// Origins that are allowed to issue CORS request. This is needed for browser
/// access to the authentication server, but tools like `curl`
/// do not obey nor enforce the CORS convention.
///
/// This enum (de)serialized as an [untagged](https://serde.rs/enum-representations.html)
/// enum variant.
///
/// # Examples
/// ## Allow all origins
/// ```json
/// { "allowed_origins": null }
/// ```
/// ```
/// extern crate rocket_cors;
/// #[macro_use]
/// extern crate serde_derive;
/// extern crate serde_json;
///
/// use rocket_cors::*;
///
/// # fn main() {
/// #[derive(Serialize, Deserialize)]
/// struct Test {
/// allowed_origins: AllowedOrigins
/// }
///
/// let json = r#"{ "allowed_origins": null }"#;
/// let deserialized: Test = serde_json::from_str(json).unwrap();
/// # }
/// ```
/// ## Allow specific origins
///
/// ```json
/// { "allowed_origins": ["http://127.0.0.1:8000/","https://foobar.com/"] }
/// ```
///
/// ```
/// extern crate rocket_cors;
/// #[macro_use]
/// extern crate serde_derive;
/// extern crate serde_json;
///
/// use rocket_cors::*;
///
/// # fn main() {
/// #[derive(Serialize, Deserialize)]
/// struct Test {
/// allowed_origins: AllowedOrigins
/// }
///
/// let json = r#"{ "allowed_origins": ["http://127.0.0.1:8000/","https://foobar.com/"] }"#;
/// let deserialized: Test = serde_json::from_str(json).unwrap();
/// # }
/// An enum signifying that some of type T is allowed, or `All` (everything is allowed).
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(untagged)]
pub enum AllowedOrigins {
/// All origins are allowed. Equivalent to the "*" value.
pub enum AllOrSome<T> {
/// Everything is allowed. Usually equivalent to the "*" value.
All,
/// Only origins listed are allowed.
Some(HashSet<Url>),
/// Only some of `T` is allowed
Some(T),
}
impl Default for AllowedOrigins {
impl<T> Default for AllOrSome<T> {
fn default() -> Self {
AllowedOrigins::All
AllOrSome::All
}
}
impl AllowedOrigins {
/// New `AllowedOrigins` from a list of URL strings.
/// Returns a tuple where the first element is the struct `AllowedOrigins`,
impl AllOrSome<HashSet<Url>> {
/// New `AllOrSome` from a list of URL strings.
/// Returns a tuple where the first element is the struct `AllOrSome`,
/// and the second element
/// is a map of strings which failed to parse into URLs and their associated parse errors.
pub fn new_from_str_list(urls: &[&str]) -> (Self, HashMap<String, url::ParseError>) {
@ -440,59 +365,234 @@ impl AllowedOrigins {
let ok_set = ok_set.into_iter().map(|(_, r)| r.unwrap()).collect();
(AllowedOrigins::Some(ok_set), error_map)
(AllOrSome::Some(ok_set), error_map)
}
}
/// Options to aid in the building of a CORS response during pre-flight or after.
/// See module level documentation for usage examples.
#[derive(Clone, Debug, Default)]
/// Configuration options to for building CORS preflight or actual responses.
///
/// [`Default`](https://doc.rust-lang.org/std/default/trait.Default.html) is implemented for this
/// struct. The default for each field is described in the docuementation for the field.
#[derive(Clone, Debug)]
pub struct Options {
/// Origins that are allowed to make requests.
/// Will be verified against the `Origin` request header.
pub allowed_origins: AllowedOrigins,
/// Methods that the clients are allowed to request in.
/// Will be verified against the `Access-Control-Request-Method` request header
/// during pre-flight only.
///
/// When `All` is set, and `send_wildcard` is set, "*" will be sent in
/// the `Access-Control-Allow-Origin` response header. Otherwise, the client's `Origin` request
/// header will be echoed back in the `Access-Control-Allow-Origin` response header.
///
/// When `Some` is set, the client's `Origin` request header will be checked in a
/// case-sensitive manner.
///
/// This is the `list of origins` in the
/// [Resource Processing Model](https://www.w3.org/TR/cors/#resource-processing-model).
///
///
/// This field defaults to `All`.
/// # Examples
/// ## Allow all origins
/// ```json
/// { "allowed_origins": null }
///
/// ## Allow specific origins
///
/// ```json
/// { "allowed_origins": ["http://127.0.0.1:8000/","https://foobar.com/"] }
/// ```
// #[serde(default)]
pub allowed_origins: AllOrSome<HashSet<Url>>,
/// The list of methods which the allowed origins are allowed to access for
/// non-simple requests.
///
/// This is the `list of methods` in the
/// [Resource Processing Model](https://www.w3.org/TR/cors/#resource-processing-model).
///
/// Defaults to `[GET, HEAD, POST, OPTIONS, PUT, PATCH, DELETE]`
// #[serde(default = "Options::default_allowed_methods")]
pub allowed_methods: HashSet<Method>,
/// Headers that the clients are allowed to request in.
/// Will be verified against the `Access-Control-Request-Headers` request header
/// during pre-flight only.
pub allowed_headers: HeaderFieldNamesSet,
/// The `Access-Control-Allow-Credentials` response header
/// The list of header field names which can be used when this resource is accessed by allowed
/// origins.
///
/// If `All` is set, whatever is requested by the client in `Access-Control-Request-Headers`
/// will be echoed back in the `Access-Control-Allow-Headers` header.
///
/// This is the `list of headers` in the
/// [Resource Processing Model](https://www.w3.org/TR/cors/#resource-processing-model).
///
/// Defaults to `All`.
pub allowed_headers: AllOrSome<HashSet<HeaderFieldName>>,
/// Allows users to make authenticated requests.
/// If true, injects the `Access-Control-Allow-Credentials` header in responses.
/// This allows cookies and credentials to be submitted across domains.
///
/// This **CANNOT** be used in conjunction with `allowed_origins` set to `All` and
/// `send_wildcard` set to `true`. Depending on the mode of usage, this will either result
/// in an `Error::CredentialsWithWildcardOrigin` error during Rocket launch or runtime.
///
/// Defaults to `false`.
pub allow_credentials: bool,
/// The `Access-Control-Expose-Headers` responde header
/// The list of headers which are safe to expose to the API of a CORS API specification.
/// This corresponds to the `Access-Control-Expose-Headers` responde header.
///
/// This is the `list of exposed headers` in the
/// [Resource Processing Model](https://www.w3.org/TR/cors/#resource-processing-model).
///
/// This defaults to an empty set.
pub expose_headers: HashSet<String>,
/// The `Access-Control-Max-Age` response header
/// The maximum time for which this CORS request maybe cached. This value is set as the
/// `Access-Control-Max-Age` header.
///
/// This defaults to `None` (unset).
pub max_age: Option<usize>,
/// If true, and the `allowed_origins` parameter is `All`, a wildcard
/// `Access-Control-Allow-Origin` response header is sent, rather than the requests
/// `Origin` header.
///
/// This is the `supports credentials flag` in the
/// [Resource Processing Model](https://www.w3.org/TR/cors/#resource-processing-model).
///
/// This **CANNOT** be used in conjunction with `allowed_origins` set to `All` and
/// `allow_credentials` set to `true`. Depending on the mode of usage, this will either result
/// in an `Error::CredentialsWithWildcardOrigin` error during Rocket launch or runtime.
///
/// Defaults to `false`.
pub send_wildcard: bool,
}
impl Default for Options {
fn default() -> Self {
Self {
allowed_origins: Default::default(),
allowed_methods: Self::default_allowed_methods(),
allowed_headers: Default::default(),
allow_credentials: Default::default(),
expose_headers: Default::default(),
max_age: Default::default(),
send_wildcard: Default::default(),
}
}
}
impl Options {
fn default_allowed_methods() -> HashSet<Method> {
vec![
Method::Get,
Method::Head,
Method::Post,
Method::Options,
Method::Put,
Method::Patch,
Method::Delete,
].into_iter()
.collect()
}
/// Construct a preflight response based on the options. Will return an `Err`
/// if any of the preflight checks
/// fail.
pub fn preflight(
/// if any of the preflight checks fail.
///
/// This implementation references the
/// [W3C recommendation](https://www.w3.org/TR/cors/#resource-preflight-requests).
pub fn preflight<'r, R: Responder<'r>>(
&self,
responder: R,
origin: Option<Origin>,
method: &AccessControlRequestMethod,
headers: Option<&AccessControlRequestHeaders>,
) -> Result<Response<()>, Error> {
method: Option<AccessControlRequestMethod>,
headers: Option<AccessControlRequestHeaders>,
) -> Result<Response<R>, Error> {
let response = Response::new(responder);
match origin {
None => Err(Error::MissingOrigin),
Some(origin) => {
let response = Response::<()>::allowed_origin((), &origin, &self.allowed_origins)?
.allowed_methods(method, self.allowed_methods.clone())?;
// Note: All header parse failures are dealt with in the `FromRequest` trait implementation
match headers {
Some(headers) => {
self.append(response.allowed_headers(headers, &self.allowed_headers))
}
None => Ok(response),
}
// 1. If the Origin header is not present terminate this set of steps.
// The request is outside the scope of this specification.
let origin = match origin {
None => {
// Not a CORS request
return Ok(response);
}
}
Some(origin) => origin,
};
// 2. If the value of the Origin header is not a case-sensitive match for any of the values
// in list of origins do not set any additional headers and terminate this set of steps.
let response = response.allowed_origin(
&origin,
&self.allowed_origins,
self.send_wildcard,
)?;
// 3. Let `method` be the value as result of parsing the Access-Control-Request-Method
// header.
// If there is no Access-Control-Request-Method header or if parsing failed,
// do not set any additional headers and terminate this set of steps.
// The request is outside the scope of this specification.
let method = method.ok_or_else(|| Error::MissingRequestMethod)?;
// 4. Let header field-names be the values as result of parsing the
// Access-Control-Request-Headers headers.
// If there are no Access-Control-Request-Headers headers
// let header field-names be the empty list.
// If parsing failed do not set any additional headers and terminate this set of steps.
// The request is outside the scope of this specification.
// 5. If method is not a case-sensitive match for any of the values in list of methods
// do not set any additional headers and terminate this set of steps.
let response = response.allowed_methods(
&method,
self.allowed_methods.clone(),
)?;
// 6. If any of the header field-names is not a ASCII case-insensitive match for any of the
// values in list of headers do not set any additional headers and terminate this set of
// steps.
let response = if let Some(headers) = headers {
response.allowed_headers(&headers, &self.allowed_headers)?
} else {
response
};
// 7. If the resource supports credentials add a single Access-Control-Allow-Origin header,
// with the value of the Origin header as value, and add a
// single Access-Control-Allow-Credentials header with the case-sensitive string "true" as
// value.
// Otherwise, add a single Access-Control-Allow-Origin header,
// with either the value of the Origin header or the string "*" as value.
// Note: The string "*" cannot be used for a resource that supports credentials.
let response = response.credentials(self.allow_credentials)?;
// 8. Optionally add a single Access-Control-Max-Age header
// with as value the amount of seconds the user agent is allowed to cache the result of the
// request.
let response = response.max_age(self.max_age);
// 9. If method is a simple method this step may be skipped.
// Add one or more Access-Control-Allow-Methods headers consisting of
// (a subset of) the list of methods.
// If a method is a simple method it does not need to be listed, but this is not prohibited.
// Since the list of methods can be unbounded,
// simply returning the method indicated by Access-Control-Request-Method
// (if supported) can be enough.
// Done above
// 10. If each of the header field-names is a simple header and none is Content-Type,
// this step may be skipped.
// Add one or more Access-Control-Allow-Headers headers consisting of (a subset of)
// the list of headers.
// If a header field name is a simple header and is not Content-Type,
// it is not required to be listed. Content-Type is to be listed as only a
// subset of its values makes it qualify as simple header.
// Since the list of headers can be unbounded, simply returning supported headers
// from Access-Control-Allow-Headers can be enough.
// Done above -- we do not do anything special with simple headers
Ok(response)
}
/// Respond to a request based on the settings.
@ -503,34 +603,55 @@ impl Options {
responder: R,
origin: Option<Origin>,
) -> Result<Response<R>, Error> {
match origin {
None => Ok(Response::<R>::any(responder)),
Some(origin) => {
self.append(Response::<R>::allowed_origin(
responder,
&origin,
&self.allowed_origins,
))
}
}
}
let response = Response::new(responder);
fn append<'r, R: Responder<'r>>(
&self,
response: Result<Response<R>, Error>,
) -> Result<Response<R>, Error> {
Ok(
response?
.credentials(self.allow_credentials)
.exposed_headers(
self.expose_headers
.iter()
.map(|s| &**s)
.collect::<Vec<&str>>()
.as_slice(),
)
.max_age(self.max_age),
)
// Note: All header parse failures are dealt with in the `FromRequest` trait implementation
// 1. If the Origin header is not present terminate this set of steps.
// The request is outside the scope of this specification.
let origin = match origin {
None => {
// Not a CORS request
return Ok(response);
}
Some(origin) => origin,
};
// 2. If the value of the Origin header is not a case-sensitive match for any of the values
// in list of origins, do not set any additional headers and terminate this set of steps.
// Always matching is acceptable since the list of origins can be unbounded.
let response = response.allowed_origin(
&origin,
&self.allowed_origins,
self.send_wildcard,
)?;
// 3. If the resource supports credentials add a single Access-Control-Allow-Origin header,
// with the value of the Origin header as value, and add a
// single Access-Control-Allow-Credentials header with the case-sensitive string "true" as
// value.
// Otherwise, add a single Access-Control-Allow-Origin header,
// with either the value of the Origin header or the string "*" as value.
// Note: The string "*" cannot be used for a resource that supports credentials.
let response = response.credentials(self.allow_credentials)?;
// 4. If the list of exposed headers is not empty add one or more
// Access-Control-Expose-Headers headers, with as values the header field names given in
// the list of exposed headers.
// By not adding the appropriate headers resource can also clear the preflight result cache
// of all entries where origin is a case-sensitive match for the value of the Origin header
// and url is a case-sensitive match for the URL of the resource.
let response = response.exposed_headers(
self.expose_headers
.iter()
.map(|s| &**s)
.collect::<Vec<&str>>()
.as_slice(),
);
Ok(response)
}
}
@ -539,40 +660,62 @@ impl Options {
/// See module level documentation for usage examples.
pub struct Response<R> {
responder: R,
allow_origin: String,
allow_origin: Option<AllOrSome<String>>,
allow_methods: HashSet<Method>,
allow_headers: HeaderFieldNamesSet,
allow_credentials: bool,
expose_headers: HeaderFieldNamesSet,
max_age: Option<usize>,
vary_origin: bool,
}
impl<'r, R: Responder<'r>> Response<R> {
/// Consumes the responder and origin and returns basic CORS
fn origin(responder: R, origin: &str) -> Self {
/// Consumes the responder and return an empty `Response`
fn new(responder: R) -> Self {
Self {
allow_origin: origin.to_string(),
allow_origin: None,
allow_headers: HashSet::new(),
allow_methods: HashSet::new(),
responder: responder,
responder,
allow_credentials: false,
expose_headers: HashSet::new(),
max_age: None,
vary_origin: false,
}
}
/// Consumes the `Response` and return an altered response with origin and `vary_origin` set
fn origin(mut self, origin: &str, vary_origin: bool) -> Self {
self.allow_origin = Some(AllOrSome::Some(origin.to_string()));
self.vary_origin = vary_origin;
self
}
/// Consumes the `Response` and return an altered response with origin set to "*"
fn any(self) -> Self {
self.origin("*", false)
}
/// Consumes the responder and based on the provided list of allowed origins,
/// check if the requested origin is allowed.
/// Useful for pre-flight and during requests
pub fn allowed_origin(
responder: R,
fn allowed_origin(
self,
origin: &Origin,
allowed_origins: &AllowedOrigins,
allowed_origins: &AllOrSome<HashSet<Url>>,
send_wildcard: bool,
) -> Result<Self, Error> {
let origin = origin.origin().unicode_serialization();
match *allowed_origins {
AllowedOrigins::All => Ok(Self::any(responder)),
AllowedOrigins::Some(ref allowed_origins) => {
let origin = origin.origin().unicode_serialization();
// Always matching is acceptable since the list of origins can be unbounded.
AllOrSome::All => {
if send_wildcard {
Ok(self.any())
} else {
Ok(self.origin(&origin, true))
}
}
AllOrSome::Some(ref allowed_origins) => {
let allowed_origins: HashSet<_> = allowed_origins
.iter()
.map(|o| o.origin().unicode_serialization())
@ -580,33 +723,33 @@ impl<'r, R: Responder<'r>> Response<R> {
let _ = allowed_origins.get(&origin).ok_or_else(
|| Error::OriginNotAllowed,
)?;
Ok(Self::origin(responder, &origin))
Ok(self.origin(&origin, false))
}
}
}
/// Consumes responder and returns CORS with any origin
pub fn any(responder: R) -> Self {
Self::origin(responder, "*")
}
/// Consumes the Response and validate whether credentials can be allowed
fn credentials(mut self, value: bool) -> Result<Self, Error> {
if value {
if let Some(AllOrSome::All) = self.allow_origin {
Err(Error::CredentialsWithWildcardOrigin)?;
}
}
/// Consumes the CORS, set allow_credentials to
/// new value and returns changed CORS
pub fn credentials(mut self, value: bool) -> Self {
self.allow_credentials = value;
self
Ok(self)
}
/// Consumes the CORS, set expose_headers to
/// passed headers and returns changed CORS
pub fn exposed_headers(mut self, headers: &[&str]) -> Self {
fn exposed_headers(mut self, headers: &[&str]) -> Self {
self.expose_headers = headers.into_iter().map(|s| s.to_string().into()).collect();
self
}
/// Consumes the CORS, set max_age to
/// passed value and returns changed CORS
pub fn max_age(mut self, value: Option<usize>) -> Self {
fn max_age(mut self, value: Option<usize>) -> Self {
self.max_age = value;
self
}
@ -620,7 +763,7 @@ impl<'r, R: Responder<'r>> Response<R> {
/// Consumes the CORS, check if requested method is allowed.
/// Useful for pre-flight checks
pub fn allowed_methods(
fn allowed_methods(
self,
method: &AccessControlRequestMethod,
allowed_methods: HashSet<Method>,
@ -629,6 +772,8 @@ impl<'r, R: Responder<'r>> Response<R> {
if !allowed_methods.iter().any(|m| m == request_method) {
Err(Error::MethodNotAllowed)?
}
// TODO: Subset to route? Or just the method requested for?
Ok(self.methods(allowed_methods))
}
@ -639,20 +784,27 @@ impl<'r, R: Responder<'r>> Response<R> {
self
}
/// Consumes the CORS, check if requested headersa are allowed.
/// Consumes the CORS, check if requested headers are allowed.
/// Useful for pre-flight checks
pub fn allowed_headers(
fn allowed_headers(
self,
headers: &AccessControlRequestHeaders,
allowed_headers: &HeaderFieldNamesSet,
allowed_headers: &AllOrSome<HashSet<HeaderFieldName>>,
) -> Result<Self, Error> {
let &AccessControlRequestHeaders(ref headers) = headers;
if !headers.is_empty() && !headers.is_subset(allowed_headers) {
Err(Error::HeadersNotAllowed)?
}
match *allowed_headers {
AllOrSome::All => {}
AllOrSome::Some(ref allowed_headers) => {
if !headers.is_empty() && !headers.is_subset(allowed_headers) {
Err(Error::HeadersNotAllowed)?
}
}
};
Ok(
self.headers(
allowed_headers
headers
.iter()
.map(|s| &**s.deref())
.collect::<Vec<&str>>()
@ -663,15 +815,29 @@ impl<'r, R: Responder<'r>> Response<R> {
}
impl<'r, R: Responder<'r>> Responder<'r> for Response<R> {
#[allow(unused_results)]
fn respond_to(self, request: &Request) -> response::Result<'r> {
let mut response = response::Response::build_from(self.responder.respond_to(request)?)
.raw_header("Access-Control-Allow-Origin", self.allow_origin)
.finalize();
use std::borrow::Cow;
let mut builder = response::Response::build_from(self.responder.respond_to(request)?);
let origin = match self.allow_origin {
None => {
// This is not a CORS response
return Ok(builder.finalize());
}
Some(origin) => origin,
};
let origin: Cow<str> = match origin {
AllOrSome::All => Into::into("*"),
AllOrSome::Some(origin) => Into::into(origin),
};
builder.raw_header("Access-Control-Allow-Origin", origin);
if self.allow_credentials {
response.set_raw_header("Access-Control-Allow-Credentials", "true");
} else {
response.set_raw_header("Access-Control-Allow-Credentials", "false");
builder.raw_header("Access-Control-Allow-Credentials", "true");
}
if !self.expose_headers.is_empty() {
@ -681,7 +847,7 @@ impl<'r, R: Responder<'r>> Responder<'r> for Response<R> {
.collect();
let headers = headers.join(", ");
response.set_raw_header("Access-Control-Expose-Headers", headers);
builder.raw_header("Access-Control-Expose-Headers", headers);
}
if !self.allow_headers.is_empty() {
@ -691,7 +857,7 @@ impl<'r, R: Responder<'r>> Responder<'r> for Response<R> {
.collect();
let headers = headers.join(", ");
response.set_raw_header("Access-Control-Allow-Headers", headers);
builder.raw_header("Access-Control-Allow-Headers", headers);
}
@ -699,15 +865,19 @@ impl<'r, R: Responder<'r>> Responder<'r> for Response<R> {
let methods: Vec<_> = self.allow_methods.into_iter().map(|m| m.as_str()).collect();
let methods = methods.join(", ");
response.set_raw_header("Access-Control-Allow-Methods", methods);
builder.raw_header("Access-Control-Allow-Methods", methods);
}
if self.max_age.is_some() {
let max_age = self.max_age.unwrap();
response.set_raw_header("Access-Control-Max-Age", max_age.to_string());
builder.raw_header("Access-Control-Max-Age", max_age.to_string());
}
Ok(response)
if self.vary_origin {
builder.raw_header("Vary", "Origin");
}
Ok(builder.finalize())
}
}
@ -813,7 +983,7 @@ X-Ping, accept-language"#;
#[get("/any")]
#[cfg_attr(feature = "clippy_lints", allow(needless_pass_by_value))]
fn any() -> Response<&'static str> {
Response::any("Hello, world!")
Response::new("Hello, world!").any()
}
#[test]
@ -838,11 +1008,11 @@ X-Ping, accept-language"#;
#[allow(needless_pass_by_value)]
fn cors_options(
origin: Option<Origin>,
method: AccessControlRequestMethod,
headers: AccessControlRequestHeaders,
method: Option<AccessControlRequestMethod>,
headers: Option<AccessControlRequestHeaders>,
options: State<Options>,
) -> Result<Response<()>, Error> {
options.preflight(origin, &method, Some(&headers))
options.preflight((), origin, method, headers)
}
#[get("/")]
@ -856,16 +1026,18 @@ X-Ping, accept-language"#;
fn make_cors_options() -> Options {
let (allowed_origins, failed_origins) =
AllowedOrigins::new_from_str_list(&["https://www.acme.com"]);
AllOrSome::new_from_str_list(&["https://www.acme.com"]);
assert!(failed_origins.is_empty());
Options {
allowed_origins: allowed_origins,
allowed_methods: [Method::Get].iter().cloned().collect(),
allowed_headers: ["Authorization"]
.iter()
.map(|s| s.to_string().into())
.collect(),
allowed_headers: AllOrSome::Some(
["Authorization"]
.into_iter()
.map(|s| s.to_string().into())
.collect(),
),
allow_credentials: true,
..Default::default()
}
@ -980,7 +1152,7 @@ X-Ping, accept-language"#;
);
let response = req.dispatch();
assert_eq!(response.status(), Status::Forbidden);
assert_eq!(response.status(), Status::Ok);
}
#[test]