Support Regex in origins configuration (#62)

* Refactor Origin

* Fix tests

* Fix tests

* Add JSON deserialization test

* Support regex

* Fix wording

* Fix wording
This commit is contained in:
Yong Wen Chua 2019-03-12 15:05:40 +08:00 committed by GitHub
parent f9bffe77d6
commit 6f56109d77
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 395 additions and 148 deletions

View File

@ -1,6 +1,6 @@
[package] [package]
name = "rocket_cors" name = "rocket_cors"
version = "0.4.0" version = "0.5.0"
license = "MIT/Apache-2.0" license = "MIT/Apache-2.0"
authors = ["Yong Wen Chua <me@yongwen.xyz>"] authors = ["Yong Wen Chua <me@yongwen.xyz>"]
description = "Cross-origin resource sharing (CORS) for Rocket.rs applications" description = "Cross-origin resource sharing (CORS) for Rocket.rs applications"
@ -21,6 +21,7 @@ default = ["serialization"]
serialization = ["serde", "serde_derive", "unicase_serde"] serialization = ["serde", "serde_derive", "unicase_serde"]
[dependencies] [dependencies]
regex = "1.1"
rocket = "0.4.0" rocket = "0.4.0"
log = "0.3" log = "0.3"
unicase = "2.0" unicase = "2.0"

View File

@ -1,7 +1,6 @@
# rocket_cors # rocket_cors
[![Build Status](https://travis-ci.org/lawliet89/rocket_cors.svg)](https://travis-ci.org/lawliet89/rocket_cors) [![Build Status](https://travis-ci.org/lawliet89/rocket_cors.svg)](https://travis-ci.org/lawliet89/rocket_cors)
[![Dependency Status](https://dependencyci.com/github/lawliet89/rocket_cors/badge)](https://dependencyci.com/github/lawliet89/rocket_cors)
[![Repository](https://img.shields.io/github/tag/lawliet89/rocket_cors.svg)](https://github.com/lawliet89/rocket_cors) [![Repository](https://img.shields.io/github/tag/lawliet89/rocket_cors.svg)](https://github.com/lawliet89/rocket_cors)
[![Crates.io](https://img.shields.io/crates/v/rocket_cors.svg)](https://crates.io/crates/rocket_cors) [![Crates.io](https://img.shields.io/crates/v/rocket_cors.svg)](https://crates.io/crates/rocket_cors)
@ -31,7 +30,7 @@ work, but they are subject to the minimum that Rocket sets.
Add the following to Cargo.toml: Add the following to Cargo.toml:
```toml ```toml
rocket_cors = "0.4.0" rocket_cors = "0.5.0"
``` ```
To use the latest `master` branch, for example: To use the latest `master` branch, for example:

View File

@ -12,11 +12,11 @@ fn cors<'a>() -> &'a str {
} }
fn main() -> Result<(), Error> { fn main() -> Result<(), Error> {
let allowed_origins = AllowedOrigins::some(&["https://www.acme.com"]); let allowed_origins = AllowedOrigins::some_exact(&["https://www.acme.com"]);
// You can also deserialize this // You can also deserialize this
let cors = rocket_cors::CorsOptions { let cors = rocket_cors::CorsOptions {
allowed_origins: allowed_origins, allowed_origins,
allowed_methods: vec![Method::Get].into_iter().map(From::from).collect(), allowed_methods: vec![Method::Get].into_iter().map(From::from).collect(),
allowed_headers: AllowedHeaders::some(&["Authorization", "Accept"]), allowed_headers: AllowedHeaders::some(&["Authorization", "Accept"]),
allow_credentials: true, allow_credentials: true,

View File

@ -36,11 +36,11 @@ fn manual(cors: Guard<'_>) -> Responder<'_, &str> {
} }
fn main() -> Result<(), Error> { fn main() -> Result<(), Error> {
let allowed_origins = AllowedOrigins::some(&["https://www.acme.com"]); let allowed_origins = AllowedOrigins::some_exact(&["https://www.acme.com"]);
// You can also deserialize this // You can also deserialize this
let cors = rocket_cors::CorsOptions { let cors = rocket_cors::CorsOptions {
allowed_origins: allowed_origins, allowed_origins,
allowed_methods: vec![Method::Get].into_iter().map(From::from).collect(), allowed_methods: vec![Method::Get].into_iter().map(From::from).collect(),
allowed_headers: AllowedHeaders::some(&["Authorization", "Accept"]), allowed_headers: AllowedHeaders::some(&["Authorization", "Accept"]),
allow_credentials: true, allow_credentials: true,

View File

@ -13,10 +13,10 @@ fn main() {
// The default demonstrates the "All" serialization of several of the settings // The default demonstrates the "All" serialization of several of the settings
let default: CorsOptions = Default::default(); let default: CorsOptions = Default::default();
let allowed_origins = AllowedOrigins::some(&["https://www.acme.com"]); let allowed_origins = AllowedOrigins::some_exact(&["https://www.acme.com"]);
let options = cors::CorsOptions { let options = cors::CorsOptions {
allowed_origins: allowed_origins, allowed_origins,
allowed_methods: vec![Method::Get, Method::Post, Method::Delete] allowed_methods: vec![Method::Get, Method::Post, Method::Delete]
.into_iter() .into_iter()
.map(From::from) .map(From::from)

View File

@ -59,11 +59,11 @@ fn owned_options<'r>() -> impl Responder<'r> {
} }
fn cors_options() -> CorsOptions { fn cors_options() -> CorsOptions {
let allowed_origins = AllowedOrigins::some(&["https://www.acme.com"]); let allowed_origins = AllowedOrigins::some_exact(&["https://www.acme.com"]);
// You can also deserialize this // You can also deserialize this
rocket_cors::CorsOptions { rocket_cors::CorsOptions {
allowed_origins: allowed_origins, allowed_origins,
allowed_methods: vec![Method::Get].into_iter().map(From::from).collect(), allowed_methods: vec![Method::Get].into_iter().map(From::from).collect(),
allowed_headers: AllowedHeaders::some(&["Authorization", "Accept"]), allowed_headers: AllowedHeaders::some(&["Authorization", "Accept"]),
allow_credentials: true, allow_credentials: true,

View File

@ -36,11 +36,11 @@ fn ping_options<'r>() -> impl Responder<'r> {
/// Returns the "application wide" Cors struct /// Returns the "application wide" Cors struct
fn cors_options() -> CorsOptions { fn cors_options() -> CorsOptions {
let allowed_origins = AllowedOrigins::some(&["https://www.acme.com"]); let allowed_origins = AllowedOrigins::some_exact(&["https://www.acme.com"]);
// You can also deserialize this // You can also deserialize this
rocket_cors::CorsOptions { rocket_cors::CorsOptions {
allowed_origins: allowed_origins, allowed_origins,
allowed_methods: vec![Method::Get].into_iter().map(From::from).collect(), allowed_methods: vec![Method::Get].into_iter().map(From::from).collect(),
allowed_headers: AllowedHeaders::some(&["Authorization", "Accept"]), allowed_headers: AllowedHeaders::some(&["Authorization", "Accept"]),
allow_credentials: true, allow_credentials: true,

View File

@ -63,6 +63,7 @@ fn on_response_wrapper(
return Ok(()); return Ok(());
} }
let origin = origin.to_string();
let cors_response = if request.method() == http::Method::Options { let cors_response = if request.method() == http::Method::Options {
let headers = request_headers(request)?; let headers = request_headers(request)?;
preflight_response(options, &origin, headers.as_ref()) preflight_response(options, &origin, headers.as_ref())
@ -137,10 +138,10 @@ mod tests {
use crate::{AllowedHeaders, AllowedOrigins, Cors, CorsOptions}; use crate::{AllowedHeaders, AllowedOrigins, Cors, CorsOptions};
const CORS_ROOT: &'static str = "/my_cors"; const CORS_ROOT: &str = "/my_cors";
fn make_cors_options() -> Cors { fn make_cors_options() -> Cors {
let allowed_origins = AllowedOrigins::some(&["https://www.acme.com"]); let allowed_origins = AllowedOrigins::some_exact(&["https://www.acme.com"]);
CorsOptions { CorsOptions {
allowed_origins, allowed_origins,

View File

@ -63,28 +63,51 @@ pub type HeaderFieldNamesSet = HashSet<HeaderFieldName>;
/// ///
/// You can use this as a rocket [Request Guard](https://rocket.rs/guide/requests/#request-guards) /// You can use this as a rocket [Request Guard](https://rocket.rs/guide/requests/#request-guards)
/// to ensure that `Origin` is passed in correctly. /// to ensure that `Origin` is passed in correctly.
///
/// Reference: [Mozilla](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Origin)
#[derive(Eq, PartialEq, Clone, Hash, Debug)] #[derive(Eq, PartialEq, Clone, Hash, Debug)]
pub struct Origin(pub url::Origin); pub enum Origin {
/// A `null` Origin
Null,
/// A well-formed origin that was parsed by [`url::Url::origin`]
Parsed(url::Origin),
}
impl Origin {
/// Perform an
/// [ASCII serialization](https://html.spec.whatwg.org/multipage/#ascii-serialisation-of-an-origin)
/// of this origin.
pub fn ascii_serialization(&self) -> String {
self.to_string()
}
/// Returns whether the origin was parsed as non-opaque
pub fn is_tuple(&self) -> bool {
match self {
Origin::Null => false,
Origin::Parsed(ref parsed) => parsed.is_tuple(),
}
}
}
impl FromStr for Origin { impl FromStr for Origin {
type Err = crate::Error; type Err = crate::Error;
fn from_str(input: &str) -> Result<Self, Self::Err> { fn from_str(input: &str) -> Result<Self, Self::Err> {
Ok(Origin(crate::to_origin(input)?)) if input.to_lowercase() == "null" {
Ok(Origin::Null)
} else {
Ok(Origin::Parsed(crate::to_origin(input)?))
} }
} }
impl Deref for Origin {
type Target = url::Origin;
fn deref(&self) -> &Self::Target {
&self.0
}
} }
impl fmt::Display for Origin { impl fmt::Display for Origin {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{}", self.ascii_serialization()) match self {
Origin::Null => write!(f, "null"),
Origin::Parsed(ref parsed) => write!(f, "{}", parsed.ascii_serialization()),
}
} }
} }
@ -195,6 +218,10 @@ mod tests {
let parsed = not_err!(Origin::from_str(url)); let parsed = not_err!(Origin::from_str(url));
assert_eq!(parsed.ascii_serialization(), url); assert_eq!(parsed.ascii_serialization(), url);
let url = "https://foo.bar.xyz:1234";
let parsed = not_err!(Origin::from_str(url));
assert_eq!(parsed.ascii_serialization(), url);
// this should never really be sent by a compliant user agent // this should never really be sent by a compliant user agent
let url = "https://foo.bar.xyz/path/somewhere"; let url = "https://foo.bar.xyz/path/somewhere";
let parsed = not_err!(Origin::from_str(url)); let parsed = not_err!(Origin::from_str(url));
@ -239,7 +266,7 @@ mod tests {
); );
let method = "INVALID"; let method = "INVALID";
let _ = is_err!(AccessControlRequestMethod::from_str(method)); is_err!(AccessControlRequestMethod::from_str(method));
} }
#[test] #[test]
@ -281,7 +308,7 @@ mod tests {
let parsed_header = assert_matches!(outcome, Outcome::Success(s), s); let parsed_header = assert_matches!(outcome, Outcome::Success(s), s);
let AccessControlRequestHeaders(parsed_headers) = parsed_header; let AccessControlRequestHeaders(parsed_headers) = parsed_header;
let mut parsed_headers: Vec<String> = let mut parsed_headers: Vec<String> =
parsed_headers.iter().map(|s| s.to_string()).collect(); parsed_headers.iter().map(ToString::to_string).collect();
parsed_headers.sort(); parsed_headers.sort();
assert_eq!( assert_eq!(
vec!["accept-language".to_string(), "date".to_string()], vec!["accept-language".to_string(), "date".to_string()],

View File

@ -1,6 +1,5 @@
/*! /*!
[![Build Status](https://travis-ci.org/lawliet89/rocket_cors.svg)](https://travis-ci.org/lawliet89/rocket_cors) [![Build Status](https://travis-ci.org/lawliet89/rocket_cors.svg)](https://travis-ci.org/lawliet89/rocket_cors)
[![Dependency Status](https://dependencyci.com/github/lawliet89/rocket_cors/badge)](https://dependencyci.com/github/lawliet89/rocket_cors)
[![Repository](https://img.shields.io/github/tag/lawliet89/rocket_cors.svg)](https://github.com/lawliet89/rocket_cors) [![Repository](https://img.shields.io/github/tag/lawliet89/rocket_cors.svg)](https://github.com/lawliet89/rocket_cors)
[![Crates.io](https://img.shields.io/crates/v/rocket_cors.svg)](https://crates.io/crates/rocket_cors) [![Crates.io](https://img.shields.io/crates/v/rocket_cors.svg)](https://crates.io/crates/rocket_cors)
@ -30,7 +29,7 @@ might work, but they are subject to the minimum that Rocket sets.
Add the following to Cargo.toml: Add the following to Cargo.toml:
```toml ```toml
rocket_cors = "0.4.0" rocket_cors = "0.5.0"
``` ```
To use the latest `master` branch, for example: To use the latest `master` branch, for example:
@ -46,7 +45,7 @@ the [`CorsOptions`] struct that is described below. If you would like to disable
change your `Cargo.toml` to: change your `Cargo.toml` to:
```toml ```toml
rocket_cors = { version = "0.4.0", default-features = false } rocket_cors = { version = "0.5.0", default-features = false }
``` ```
## Usage ## Usage
@ -63,9 +62,9 @@ Each of the examples can be run off the repository via `cargo run --example xxx`
### `CorsOptions` Struct ### `CorsOptions` Struct
The [`CorsOptiopns`] struct contains the settings for CORS requests to be validated The [`CorsOptions`] struct contains the settings for CORS requests to be validated
and for responses to be generated. Defaults are defined for every field in the struct, and and for responses to be generated. Defaults are defined for every field in the struct, and
are documented on the [`CorsOptiopns`] page. You can also deserialize are documented on the [`CorsOptions`] page. You can also deserialize
the struct from some format like JSON, YAML or TOML when the default `serialization` feature the struct from some format like JSON, YAML or TOML when the default `serialization` feature
is enabled. is enabled.
@ -284,6 +283,7 @@ use std::ops::Deref;
use std::str::FromStr; use std::str::FromStr;
use ::log::{error, info, log}; use ::log::{error, info, log};
use regex::RegexSet;
use rocket::http::{self, Status}; use rocket::http::{self, Status};
use rocket::request::{FromRequest, Request}; use rocket::request::{FromRequest, Request};
use rocket::response; use rocket::response;
@ -309,6 +309,8 @@ pub enum Error {
MissingOrigin, MissingOrigin,
/// The HTTP request header `Origin` could not be parsed correctly. /// The HTTP request header `Origin` could not be parsed correctly.
BadOrigin(url::ParseError), BadOrigin(url::ParseError),
/// The configured Allowed Origin is opaque and cannot be parsed.
OpaqueAllowedOrigin(String),
/// The request header `Access-Control-Request-Method` is required but is missing /// The request header `Access-Control-Request-Method` is required but is missing
MissingRequestMethod, MissingRequestMethod,
/// The request header `Access-Control-Request-Method` has an invalid value /// The request header `Access-Control-Request-Method` has an invalid value
@ -316,9 +318,11 @@ pub enum Error {
/// The request header `Access-Control-Request-Headers` is required but is missing. /// The request header `Access-Control-Request-Headers` is required but is missing.
MissingRequestHeaders, MissingRequestHeaders,
/// Origin is not allowed to make this request /// Origin is not allowed to make this request
OriginNotAllowed(url::Origin), OriginNotAllowed(String),
/// Requested method is not allowed /// Requested method is not allowed
MethodNotAllowed(String), MethodNotAllowed(String),
/// A regular expression compilation error
RegexError(regex::Error),
/// One or more headers requested are not allowed /// One or more headers requested are not allowed
HeadersNotAllowed, HeadersNotAllowed,
/// Credentials are allowed, but the Origin is set to "*". This is not allowed by W3C /// Credentials are allowed, but the Origin is set to "*". This is not allowed by W3C
@ -365,7 +369,7 @@ impl fmt::Display for Error {
"The request header `Access-Control-Request-Headers` \ "The request header `Access-Control-Request-Headers` \
is required but is missing") is required but is missing")
} }
Error::OriginNotAllowed(origin) => write!(f, "Origin '{}' is not allowed to request", origin.ascii_serialization()), Error::OriginNotAllowed(origin) => write!(f, "Origin '{}' is not allowed to request", origin),
Error::MethodNotAllowed(method) => write!(f, "Method '{}' is not allowed", &method), Error::MethodNotAllowed(method) => write!(f, "Method '{}' is not allowed", &method),
Error::HeadersNotAllowed => write!(f, "Headers are not allowed"), Error::HeadersNotAllowed => write!(f, "Headers are not allowed"),
Error::CredentialsWithWildcardOrigin => { write!(f, Error::CredentialsWithWildcardOrigin => { write!(f,
@ -377,7 +381,9 @@ impl fmt::Display for Error {
} }
Error::MissingInjectedHeader => write!(f, Error::MissingInjectedHeader => write!(f,
"The `on_response` handler of Fairing could not find the injected header from the \ "The `on_response` handler of Fairing could not find the injected header from the \
Request. Either some other fairing has removed it, or this is a bug.") Request. Either some other fairing has removed it, or this is a bug."),
Error::OpaqueAllowedOrigin(ref origin) => write!(f, "The configured Origin '{}' not have a parsable Origin. Use a regex instead.", origin),
Error::RegexError(ref e) => write!(f, "{}", e),
} }
} }
} }
@ -404,6 +410,12 @@ impl From<url::ParseError> for Error {
} }
} }
impl From<regex::Error> for Error {
fn from(error: regex::Error) -> Self {
Error::RegexError(error)
}
}
/// An enum signifying that some of type T is allowed, or `All` (everything is allowed). /// An enum signifying that some of type T is allowed, or `All` (everything is allowed).
/// ///
/// `Default` is implemented for this enum and is `All`. /// `Default` is implemented for this enum and is `All`.
@ -529,16 +541,47 @@ mod method_serde {
/// use rocket_cors::AllowedOrigins; /// use rocket_cors::AllowedOrigins;
/// ///
/// let all_origins = AllowedOrigins::all(); /// let all_origins = AllowedOrigins::all();
/// let some_origins = AllowedOrigins::some(&["https://www.acme.com"]); /// let some_origins = AllowedOrigins::some_exact(&["https://www.acme.com"]);
/// let null_origins = AllowedOrigins::some_null();
/// ``` /// ```
pub type AllowedOrigins = AllOrSome<HashSet<String>>; pub type AllowedOrigins = AllOrSome<Origins>;
impl AllowedOrigins { impl AllowedOrigins {
/// Allows some origins /// Allows some origins
/// ///
/// Validation is not performed at this stage, but at a later stage. /// Validation is not performed at this stage, but at a later stage.
pub fn some(urls: &[&str]) -> Self { pub fn some<S1: AsRef<str>, S2: AsRef<str>>(exact: &[S1], regex: &[S2]) -> Self {
AllOrSome::Some(urls.iter().map(|s| s.to_string()).collect()) AllOrSome::Some(Origins {
exact: Some(exact.iter().map(|s| s.as_ref().to_string()).collect()),
regex: Some(regex.iter().map(|s| s.as_ref().to_string()).collect()),
..Default::default()
})
}
/// Allows some _exact_ origins
///
/// Validation is not performed at this stage, but at a later stage.
pub fn some_exact<S: AsRef<str>>(exact: &[S]) -> Self {
AllOrSome::Some(Origins {
exact: Some(exact.iter().map(|s| s.as_ref().to_string()).collect()),
..Default::default()
})
}
/// Allow some __regex__ origins
pub fn some_regex<S: AsRef<str>>(regex: &[S]) -> Self {
AllOrSome::Some(Origins {
regex: Some(regex.iter().map(|s| s.as_ref().to_string()).collect()),
..Default::default()
})
}
/// Allow some `null` origins
pub fn some_null() -> Self {
AllOrSome::Some(Origins {
allow_null: true,
..Default::default()
})
} }
/// Allows all origins /// Allows all origins
@ -547,6 +590,105 @@ impl AllowedOrigins {
} }
} }
/// Origins that are allowed to make CORS requests.
///
/// An origin is defined according to the defined
/// [syntax](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Origin).
///
/// Origins can be specified as an exact match or using regex.
///
/// These Origins are specified as logical `ORs`. That is, if any of the origins match, the entire
/// request is considered to be valid.
#[derive(Clone, PartialEq, Eq, Debug, Default)]
#[cfg_attr(feature = "serialization", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serialization", serde(default))]
pub struct Origins {
/// Whether null origins are accepted
#[cfg_attr(feature = "serialization", serde(default))]
pub allow_null: bool,
/// Origins that must be matched exactly as provided.
///
/// These __must__ be valid URL strings that will be parsed and validated when
/// creating [`Cors`].
#[cfg_attr(feature = "serialization", serde(default))]
pub exact: Option<HashSet<String>>,
/// Origins that will be matched via __any__ regex in this list.
///
/// These __must__ be valid Regex that will be parsed and validated when creating [`Cors`].
///
/// The regex will be matched according to the
/// [ASCII serialization](https://html.spec.whatwg.org/multipage/#ascii-serialisation-of-an-origin)
/// of the incoming Origin.
///
/// For more information on the syntax of Regex in Rust, see the
/// [documentation](https://docs.rs/regex).
#[cfg_attr(feature = "serialization", serde(default))]
pub regex: Option<HashSet<String>>,
}
/// Parsed set of configured allowed origins
#[derive(Clone, Debug)]
pub(crate) struct ParsedAllowedOrigins {
pub allow_null: bool,
pub exact: HashSet<url::Origin>,
pub regex: Option<RegexSet>,
}
impl ParsedAllowedOrigins {
fn parse(origins: &Origins) -> Result<Self, Error> {
let exact: Result<HashSet<url::Origin>, Error> = match &origins.exact {
Some(exact) => exact.iter().map(|url| to_origin(url.as_str())).collect(),
None => Ok(Default::default()),
};
let exact = exact?;
// Let's check if any of them is Opaque
exact.iter().try_for_each(|url| {
if !url.is_tuple() {
Err(Error::OpaqueAllowedOrigin(url.ascii_serialization()))
} else {
Ok(())
}
})?;
let regex = match &origins.regex {
None => None,
Some(ref regex) => Some(RegexSet::new(regex)?),
};
Ok(Self {
allow_null: origins.allow_null,
exact,
regex,
})
}
fn verify(&self, origin: &Origin) -> bool {
info_!("Verifying origin: {}", origin);
match origin {
Origin::Null => {
info_!("Origin is null. Allowing? {}", self.allow_null);
self.allow_null
}
Origin::Parsed(ref parsed) => {
// Verify by exact, then regex
if self.exact.get(parsed).is_some() {
info_!("Origin has an exact match");
return true;
}
if let Some(regex_set) = &self.regex {
let regex_match = regex_set.is_match(&parsed.ascii_serialization());
info_!("Origin has a regex match? {}", regex_match);
return regex_match;
}
info!("Origin does not match anything");
false
}
}
}
}
/// A list of allowed methods /// A list of allowed methods
/// ///
/// The [list](https://api.rocket.rs/rocket/http/enum.Method.html) /// The [list](https://api.rocket.rs/rocket/http/enum.Method.html)
@ -636,9 +778,10 @@ impl AllowedHeaders {
/// ```json /// ```json
/// { /// {
/// "allowed_origins": { /// "allowed_origins": {
/// "Some": [ /// "Some": {
/// "https://www.acme.com" /// "exact": ["https://www.acme.com"],
/// ] /// "regex": ["^https://www.example-[A-z0-9]*.com$"]
/// }
/// }, /// },
/// "allowed_methods": [ /// "allowed_methods": [
/// "POST", /// "POST",
@ -831,9 +974,9 @@ impl CorsOptions {
/// documentation at the [crate root](index.html) for usage information. /// documentation at the [crate root](index.html) for usage information.
/// ///
/// This struct can be created by using [`CorsOptions::to_cors`] or [`Cors::from_options`]. /// This struct can be created by using [`CorsOptions::to_cors`] or [`Cors::from_options`].
#[derive(Clone, Debug, Eq, PartialEq)] #[derive(Clone, Debug)]
pub struct Cors { pub struct Cors {
pub(crate) allowed_origins: AllOrSome<HashSet<url::Origin>>, pub(crate) allowed_origins: AllOrSome<ParsedAllowedOrigins>,
pub(crate) allowed_methods: AllowedMethods, pub(crate) allowed_methods: AllowedMethods,
pub(crate) allowed_headers: AllOrSome<HashSet<HeaderFieldName>>, pub(crate) allowed_headers: AllOrSome<HashSet<HeaderFieldName>>,
pub(crate) allow_credentials: bool, pub(crate) allow_credentials: bool,
@ -921,7 +1064,7 @@ impl Cors {
/// You can get this struct by using `Cors::validate_request` in an ad-hoc manner. /// You can get this struct by using `Cors::validate_request` in an ad-hoc manner.
#[derive(Eq, PartialEq, Debug)] #[derive(Eq, PartialEq, Debug)]
pub(crate) struct Response { pub(crate) struct Response {
allow_origin: Option<AllOrSome<url::Origin>>, allow_origin: Option<AllOrSome<String>>,
allow_methods: HashSet<Method>, allow_methods: HashSet<Method>,
allow_headers: HeaderFieldNamesSet, allow_headers: HeaderFieldNamesSet,
allow_credentials: bool, allow_credentials: bool,
@ -945,8 +1088,8 @@ impl Response {
} }
/// Consumes the `Response` and return an altered response with origin and `vary_origin` set /// Consumes the `Response` and return an altered response with origin and `vary_origin` set
fn origin(mut self, origin: &url::Origin, vary_origin: bool) -> Self { fn origin(mut self, origin: &str, vary_origin: bool) -> Self {
self.allow_origin = Some(AllOrSome::Some(origin.clone())); self.allow_origin = Some(AllOrSome::Some(origin.to_string()));
self.vary_origin = vary_origin; self.vary_origin = vary_origin;
self self
} }
@ -1022,7 +1165,7 @@ impl Response {
let origin = match *origin { let origin = match *origin {
AllOrSome::All => "*".to_string(), AllOrSome::All => "*".to_string(),
AllOrSome::Some(ref origin) => origin.ascii_serialization(), AllOrSome::Some(ref origin) => origin.to_string(),
}; };
let _ = response.set_raw_header("Access-Control-Allow-Origin", origin); let _ = response.set_raw_header("Access-Control-Allow-Origin", origin);
@ -1251,27 +1394,25 @@ enum ValidationResult {
None, None,
/// Successful preflight request /// Successful preflight request
Preflight { Preflight {
origin: url::Origin, origin: String,
headers: Option<AccessControlRequestHeaders>, headers: Option<AccessControlRequestHeaders>,
}, },
/// Successful actual request /// Successful actual request
Request { origin: url::Origin }, Request { origin: String },
} }
/// Convert a str to Origin /// Convert a str to a URL Origin
fn to_origin<S: AsRef<str>>(origin: S) -> Result<url::Origin, Error> { fn to_origin<S: AsRef<str>>(origin: S) -> Result<url::Origin, Error> {
// What to do about Opaque origins?
Ok(url::Url::parse(origin.as_ref())?.origin()) Ok(url::Url::parse(origin.as_ref())?.origin())
} }
/// Parse and process allowed origins /// Parse and process allowed origins
fn parse_origins(origins: &AllowedOrigins) -> Result<AllOrSome<HashSet<url::Origin>>, Error> { fn parse_origins(origins: &AllowedOrigins) -> Result<AllOrSome<ParsedAllowedOrigins>, Error> {
match origins { match origins {
AllOrSome::All => Ok(AllOrSome::All), AllOrSome::All => Ok(AllOrSome::All),
AllOrSome::Some(ref origins) => { AllOrSome::Some(origins) => {
let parsed: Result<HashSet<url::Origin>, Error> = let parsed = ParsedAllowedOrigins::parse(origins)?;
origins.iter().map(to_origin).collect(); Ok(AllOrSome::Some(parsed))
Ok(AllOrSome::Some(parsed?))
} }
} }
} }
@ -1309,14 +1450,14 @@ fn validate(options: &Cors, request: &Request<'_>) -> Result<ValidationResult, E
let headers = request_headers(request)?; let headers = request_headers(request)?;
preflight_validate(options, &origin, &method, &headers)?; preflight_validate(options, &origin, &method, &headers)?;
Ok(ValidationResult::Preflight { Ok(ValidationResult::Preflight {
origin: origin.deref().clone(), origin: origin.to_string(),
headers, headers,
}) })
} }
_ => { _ => {
actual_request_validate(options, &origin)?; actual_request_validate(options, &origin)?;
Ok(ValidationResult::Request { Ok(ValidationResult::Request {
origin: origin.deref().clone(), origin: origin.to_string(),
}) })
} }
} }
@ -1326,16 +1467,19 @@ fn validate(options: &Cors, request: &Request<'_>) -> Result<ValidationResult, E
/// check if the requested origin is allowed. /// check if the requested origin is allowed.
/// Useful for pre-flight and during requests /// Useful for pre-flight and during requests
fn validate_origin( fn validate_origin(
origin: &url::Origin, origin: &Origin,
allowed_origins: &AllOrSome<HashSet<url::Origin>>, allowed_origins: &AllOrSome<ParsedAllowedOrigins>,
) -> Result<(), Error> { ) -> Result<(), Error> {
match *allowed_origins { match *allowed_origins {
// Always matching is acceptable since the list of origins can be unbounded. // Always matching is acceptable since the list of origins can be unbounded.
AllOrSome::All => Ok(()), AllOrSome::All => Ok(()),
AllOrSome::Some(ref allowed_origins) => allowed_origins AllOrSome::Some(ref allowed_origins) => {
.get(origin) if allowed_origins.verify(origin) {
.and_then(|_| Some(())) Ok(())
.ok_or_else(|| Error::OriginNotAllowed(origin.clone())), } else {
Err(Error::OriginNotAllowed(origin.to_string()))
}
}
} }
} }
@ -1405,7 +1549,7 @@ fn request_headers(request: &Request<'_>) -> Result<Option<AccessControlRequestH
/// and [Fetch specification](https://fetch.spec.whatwg.org/#cors-preflight-fetch) /// and [Fetch specification](https://fetch.spec.whatwg.org/#cors-preflight-fetch)
fn preflight_validate( fn preflight_validate(
options: &Cors, options: &Cors,
origin: &url::Origin, origin: &Origin,
method: &Option<AccessControlRequestMethod>, method: &Option<AccessControlRequestMethod>,
headers: &Option<AccessControlRequestHeaders>, headers: &Option<AccessControlRequestHeaders>,
) -> Result<(), Error> { ) -> Result<(), Error> {
@ -1453,7 +1597,7 @@ fn preflight_validate(
/// and [Fetch specification](https://fetch.spec.whatwg.org/#cors-preflight-fetch). /// and [Fetch specification](https://fetch.spec.whatwg.org/#cors-preflight-fetch).
fn preflight_response( fn preflight_response(
options: &Cors, options: &Cors,
origin: &url::Origin, origin: &str,
headers: Option<&AccessControlRequestHeaders>, headers: Option<&AccessControlRequestHeaders>,
) -> Response { ) -> Response {
let response = Response::new(); let response = Response::new();
@ -1524,7 +1668,7 @@ fn preflight_response(
/// This implementation references the /// This implementation references the
/// [W3C recommendation](https://www.w3.org/TR/cors/#resource-requests) /// [W3C recommendation](https://www.w3.org/TR/cors/#resource-requests)
/// and [Fetch specification](https://fetch.spec.whatwg.org/#cors-preflight-fetch). /// and [Fetch specification](https://fetch.spec.whatwg.org/#cors-preflight-fetch).
fn actual_request_validate(options: &Cors, origin: &url::Origin) -> Result<(), Error> { fn actual_request_validate(options: &Cors, origin: &Origin) -> Result<(), Error> {
// Note: All header parse failures are dealt with in the `FromRequest` trait implementation // Note: All header parse failures are dealt with in the `FromRequest` trait implementation
// 2. If the value of the Origin header is not a case-sensitive match for any of the values // 2. If the value of the Origin header is not a case-sensitive match for any of the values
@ -1541,7 +1685,7 @@ fn actual_request_validate(options: &Cors, origin: &url::Origin) -> Result<(), E
/// This implementation references the /// This implementation references the
/// [W3C recommendation](https://www.w3.org/TR/cors/#resource-requests) /// [W3C recommendation](https://www.w3.org/TR/cors/#resource-requests)
/// and [Fetch specification](https://fetch.spec.whatwg.org/#cors-preflight-fetch) /// and [Fetch specification](https://fetch.spec.whatwg.org/#cors-preflight-fetch)
fn actual_request_response(options: &Cors, origin: &url::Origin) -> Response { fn actual_request_response(options: &Cors, origin: &str) -> Response {
let response = Response::new(); let response = Response::new();
// 3. If the resource supports credentials add a single Access-Control-Allow-Origin header, // 3. If the resource supports credentials add a single Access-Control-Allow-Origin header,
@ -1640,8 +1784,12 @@ mod tests {
use super::*; use super::*;
use crate::http::Method; use crate::http::Method;
fn to_parsed_origin<S: AsRef<str>>(origin: S) -> Result<Origin, Error> {
Origin::from_str(origin.as_ref())
}
fn make_cors_options() -> CorsOptions { fn make_cors_options() -> CorsOptions {
let allowed_origins = AllowedOrigins::some(&["https://www.acme.com"]); let allowed_origins = AllowedOrigins::some_exact(&["https://www.acme.com"]);
CorsOptions { CorsOptions {
allowed_origins, allowed_origins,
@ -1652,8 +1800,8 @@ mod tests {
allowed_headers: AllowedHeaders::some(&[&"Authorization", "Accept"]), allowed_headers: AllowedHeaders::some(&[&"Authorization", "Accept"]),
allow_credentials: true, allow_credentials: true,
expose_headers: ["Content-Type", "X-Custom"] expose_headers: ["Content-Type", "X-Custom"]
.into_iter() .iter()
.map(|s| s.to_string().into()) .map(|s| s.to_string())
.collect(), .collect(),
..Default::default() ..Default::default()
} }
@ -1727,6 +1875,64 @@ mod tests {
fn cors_default_deserialization_is_correct() { fn cors_default_deserialization_is_correct() {
let deserialized: CorsOptions = serde_json::from_str("{}").expect("To not fail"); let deserialized: CorsOptions = serde_json::from_str("{}").expect("To not fail");
assert_eq!(deserialized, CorsOptions::default()); assert_eq!(deserialized, CorsOptions::default());
let expected_json = r#"
{
"allowed_origins": "All",
"allowed_methods": [
"POST",
"PATCH",
"PUT",
"DELETE",
"HEAD",
"OPTIONS",
"GET"
],
"allowed_headers": "All",
"allow_credentials": false,
"expose_headers": [],
"max_age": null,
"send_wildcard": false,
"fairing_route_base": "/cors",
"fairing_route_rank": 0
}
"#;
let actual: CorsOptions = serde_json::from_str(expected_json).expect("to not fail");
assert_eq!(actual, CorsOptions::default());
}
/// Checks that the example provided can actually be deserialized
#[cfg(feature = "serialization")]
#[test]
fn cors_options_example_can_be_deserialized() {
let json = r#"{
"allowed_origins": {
"Some": {
"exact": ["https://www.acme.com"],
"regex": ["^https://www.example-[A-z0-9]*.com$"]
}
},
"allowed_methods": [
"POST",
"DELETE",
"GET"
],
"allowed_headers": {
"Some": [
"Accept",
"Authorization"
]
},
"allow_credentials": true,
"expose_headers": [
"Content-Type",
"X-Custom"
],
"max_age": 42,
"send_wildcard": false,
"fairing_route_base": "/mycors"
}"#;
let _: CorsOptions = serde_json::from_str(json).expect("to not fail");
} }
// The following tests check validation // The following tests check validation
@ -1734,7 +1940,7 @@ mod tests {
#[test] #[test]
fn validate_origin_allows_all_origins() { fn validate_origin_allows_all_origins() {
let url = "https://www.example.com"; let url = "https://www.example.com";
let origin = not_err!(to_origin(&url)); let origin = not_err!(to_parsed_origin(&url));
let allowed_origins = AllOrSome::All; let allowed_origins = AllOrSome::All;
not_err!(validate_origin(&origin, &allowed_origins)); not_err!(validate_origin(&origin, &allowed_origins));
@ -1743,8 +1949,8 @@ mod tests {
#[test] #[test]
fn validate_origin_allows_origin() { fn validate_origin_allows_origin() {
let url = "https://www.example.com"; let url = "https://www.example.com";
let origin = not_err!(to_origin(&url)); let origin = not_err!(to_parsed_origin(&url));
let allowed_origins = not_err!(parse_origins(&AllowedOrigins::some(&[ let allowed_origins = not_err!(parse_origins(&AllowedOrigins::some_exact(&[
"https://www.example.com" "https://www.example.com"
]))); ])));
@ -1762,19 +1968,48 @@ mod tests {
]; ];
for (url, allowed_origin) in cases { for (url, allowed_origin) in cases {
let origin = not_err!(to_origin(&url)); let origin = not_err!(to_parsed_origin(&url));
let allowed_origins = not_err!(parse_origins(&AllowedOrigins::some(&[allowed_origin]))); let allowed_origins = not_err!(parse_origins(&AllowedOrigins::some_exact(&[
allowed_origin
])));
not_err!(validate_origin(&origin, &allowed_origins)); not_err!(validate_origin(&origin, &allowed_origins));
} }
} }
#[test]
fn validate_origin_validates_regex() {
let url = "https://www.example-something.com";
let origin = not_err!(to_parsed_origin(&url));
let allowed_origins = not_err!(parse_origins(&AllowedOrigins::some_regex(&[
"^https://www.example-[A-z0-9]+.com$"
])));
not_err!(validate_origin(&origin, &allowed_origins));
}
#[test]
fn validate_origin_validates_mixed_settings() {
let allowed_origins = not_err!(parse_origins(&AllowedOrigins::some(
&["https://www.acme.com"],
&["^https://www.example-[A-z0-9]+.com$"]
)));
let url = "https://www.example-something123.com";
let origin = not_err!(to_parsed_origin(&url));
not_err!(validate_origin(&origin, &allowed_origins));
let url = "https://www.acme.com";
let origin = not_err!(to_parsed_origin(&url));
not_err!(validate_origin(&origin, &allowed_origins));
}
#[test] #[test]
#[should_panic(expected = "OriginNotAllowed")] #[should_panic(expected = "OriginNotAllowed")]
fn validate_origin_rejects_invalid_origin() { fn validate_origin_rejects_invalid_origin() {
let url = "https://www.acme.com"; let url = "https://www.acme.com";
let origin = not_err!(to_origin(&url)); let origin = not_err!(to_parsed_origin(&url));
let allowed_origins = not_err!(parse_origins(&AllowedOrigins::some(&[ let allowed_origins = not_err!(parse_origins(&AllowedOrigins::some_exact(&[
"https://www.example.com" "https://www.example.com"
]))); ])));
@ -1784,7 +2019,7 @@ mod tests {
#[test] #[test]
fn response_sets_allow_origin_without_vary_correctly() { fn response_sets_allow_origin_without_vary_correctly() {
let response = Response::new(); let response = Response::new();
let response = response.origin(&to_origin("https://www.example.com").unwrap(), false); let response = response.origin("https://www.example.com", false);
// Build response and check built response header // Build response and check built response header
let expected_header = vec!["https://www.example.com"]; let expected_header = vec!["https://www.example.com"];
@ -1801,7 +2036,7 @@ mod tests {
#[test] #[test]
fn response_sets_allow_origin_with_vary_correctly() { fn response_sets_allow_origin_with_vary_correctly() {
let response = Response::new(); let response = Response::new();
let response = response.origin(&to_origin("https://www.example.com").unwrap(), true); let response = response.origin("https://www.example.com", true);
// Build response and check built response header // Build response and check built response header
let expected_header = vec!["https://www.example.com"]; let expected_header = vec!["https://www.example.com"];
@ -1828,27 +2063,11 @@ mod tests {
assert_eq!(expected_header, actual_header); assert_eq!(expected_header, actual_header);
} }
#[test]
fn response_sets_allow_origin_with_ascii_serialization() {
let response = Response::new();
let response = response.origin(&to_origin("https://аpple.com").unwrap(), false);
// Build response and check built response header
// This is "punycode"
let expected_header = vec!["https://xn--pple-43d.com"];
let response = response.response(response::Response::new());
let actual_header: Vec<_> = response
.headers()
.get("Access-Control-Allow-Origin")
.collect();
assert_eq!(expected_header, actual_header);
}
#[test] #[test]
fn response_sets_exposed_headers_correctly() { fn response_sets_exposed_headers_correctly() {
let headers = vec!["Bar", "Baz", "Foo"]; let headers = vec!["Bar", "Baz", "Foo"];
let response = Response::new(); let response = Response::new();
let response = response.origin(&to_origin("https://www.example.com").unwrap(), false); let response = response.origin("https://www.example.com", false);
let response = response.exposed_headers(&headers); let response = response.exposed_headers(&headers);
// Build response and check built response header // Build response and check built response header
@ -1870,7 +2089,7 @@ mod tests {
#[test] #[test]
fn response_sets_max_age_correctly() { fn response_sets_max_age_correctly() {
let response = Response::new(); let response = Response::new();
let response = response.origin(&to_origin("https://www.example.com").unwrap(), false); let response = response.origin("https://www.example.com", false);
let response = response.max_age(Some(42)); let response = response.max_age(Some(42));
@ -1884,7 +2103,7 @@ mod tests {
#[test] #[test]
fn response_does_not_set_max_age_when_none() { fn response_does_not_set_max_age_when_none() {
let response = Response::new(); let response = Response::new();
let response = response.origin(&to_origin("https://www.example.com").unwrap(), false); let response = response.origin("https://www.example.com", false);
let response = response.max_age(None); let response = response.max_age(None);
@ -1997,7 +2216,7 @@ mod tests {
.finalize(); .finalize();
let response = Response::new(); let response = Response::new();
let response = response.origin(&to_origin("https://www.example.com").unwrap(), false); let response = response.origin("https://www.example.com", false);
let response = response.response(original); let response = response.response(original);
// Check CORS header // Check CORS header
let expected_header = vec!["https://www.example.com"]; let expected_header = vec!["https://www.example.com"];
@ -2073,7 +2292,7 @@ mod tests {
let result = validate(&cors, request.inner()).expect("to not fail"); let result = validate(&cors, request.inner()).expect("to not fail");
let expected_result = ValidationResult::Preflight { let expected_result = ValidationResult::Preflight {
origin: to_origin("https://www.acme.com").unwrap(), origin: "https://www.acme.com".to_string(),
// Checks that only a subset of allowed headers are returned // Checks that only a subset of allowed headers are returned
// -- i.e. whatever is requested for // -- i.e. whatever is requested for
headers: Some(FromStr::from_str("Authorization").unwrap()), headers: Some(FromStr::from_str("Authorization").unwrap()),
@ -2108,7 +2327,7 @@ mod tests {
let result = validate(&cors, request.inner()).expect("to not fail"); let result = validate(&cors, request.inner()).expect("to not fail");
let expected_result = ValidationResult::Preflight { let expected_result = ValidationResult::Preflight {
origin: to_origin("https://www.example.com").unwrap(), origin: "https://www.example.com".to_string(),
headers: Some(FromStr::from_str("Authorization").unwrap()), headers: Some(FromStr::from_str("Authorization").unwrap()),
}; };
@ -2226,7 +2445,7 @@ mod tests {
let result = validate(&cors, request.inner()).expect("to not fail"); let result = validate(&cors, request.inner()).expect("to not fail");
let expected_result = ValidationResult::Request { let expected_result = ValidationResult::Request {
origin: to_origin("https://www.acme.com").unwrap(), origin: "https://www.acme.com".to_string(),
}; };
assert_eq!(expected_result, result); assert_eq!(expected_result, result);
@ -2245,7 +2464,7 @@ mod tests {
let result = validate(&cors, request.inner()).expect("to not fail"); let result = validate(&cors, request.inner()).expect("to not fail");
let expected_result = ValidationResult::Request { let expected_result = ValidationResult::Request {
origin: to_origin("https://www.example.com").unwrap(), origin: "https://www.example.com".to_string(),
}; };
assert_eq!(expected_result, result); assert_eq!(expected_result, result);
@ -2301,7 +2520,7 @@ mod tests {
let response = validate_and_build(&cors, request.inner()).expect("to not fail"); let response = validate_and_build(&cors, request.inner()).expect("to not fail");
let expected_response = Response::new() let expected_response = Response::new()
.origin(&to_origin("https://www.acme.com").unwrap(), false) .origin("https://www.acme.com", false)
.headers(&["Authorization"]) .headers(&["Authorization"])
.methods(&options.allowed_methods) .methods(&options.allowed_methods)
.credentials(options.allow_credentials) .credentials(options.allow_credentials)
@ -2341,7 +2560,7 @@ mod tests {
let response = validate_and_build(&cors, request.inner()).expect("to not fail"); let response = validate_and_build(&cors, request.inner()).expect("to not fail");
let expected_response = Response::new() let expected_response = Response::new()
.origin(&to_origin("https://www.acme.com").unwrap(), true) .origin("https://www.acme.com", true)
.headers(&["Authorization"]) .headers(&["Authorization"])
.methods(&options.allowed_methods) .methods(&options.allowed_methods)
.credentials(options.allow_credentials) .credentials(options.allow_credentials)
@ -2402,7 +2621,7 @@ mod tests {
let response = validate_and_build(&cors, request.inner()).expect("to not fail"); let response = validate_and_build(&cors, request.inner()).expect("to not fail");
let expected_response = Response::new() let expected_response = Response::new()
.origin(&to_origin("https://www.acme.com").unwrap(), false) .origin("https://www.acme.com", false)
.credentials(options.allow_credentials) .credentials(options.allow_credentials)
.exposed_headers(&["Content-Type", "X-Custom"]); .exposed_headers(&["Content-Type", "X-Custom"]);
@ -2425,7 +2644,7 @@ mod tests {
let response = validate_and_build(&cors, request.inner()).expect("to not fail"); let response = validate_and_build(&cors, request.inner()).expect("to not fail");
let expected_response = Response::new() let expected_response = Response::new()
.origin(&to_origin("https://www.acme.com").unwrap(), true) .origin("https://www.acme.com", true)
.credentials(options.allow_credentials) .credentials(options.allow_credentials)
.exposed_headers(&["Content-Type", "X-Custom"]); .exposed_headers(&["Content-Type", "X-Custom"]);

View File

@ -1,14 +1,14 @@
//! This crate tests using `rocket_cors` using Fairings //! This crate tests using `rocket_cors` using Fairings
#![feature(proc_macro_hygiene, decl_macro)] #![feature(proc_macro_hygiene, decl_macro)]
use hyper; use hyper;
#[macro_use]
extern crate rocket;
use std::str::FromStr; use std::str::FromStr;
use rocket::http::Method; use rocket::http::Method;
use rocket::http::{Header, Status}; use rocket::http::{Header, Status};
use rocket::local::Client; use rocket::local::Client;
use rocket::response::Body;
use rocket::{get, routes};
use rocket_cors::*; use rocket_cors::*;
#[get("/")] #[get("/")]
@ -22,10 +22,10 @@ fn panicking_route() {
} }
fn make_cors() -> Cors { fn make_cors() -> Cors {
let allowed_origins = AllowedOrigins::some(&["https://www.acme.com"]); let allowed_origins = AllowedOrigins::some_exact(&["https://www.acme.com"]);
CorsOptions { CorsOptions {
allowed_origins: allowed_origins, allowed_origins,
allowed_methods: vec![Method::Get].into_iter().map(From::from).collect(), allowed_methods: vec![Method::Get].into_iter().map(From::from).collect(),
allowed_headers: AllowedHeaders::some(&["Authorization", "Accept"]), allowed_headers: AllowedHeaders::some(&["Authorization", "Accept"]),
allow_credentials: true, allow_credentials: true,
@ -73,7 +73,7 @@ fn smoke_test() {
let mut response = req.dispatch(); let mut response = req.dispatch();
assert!(response.status().class().is_success()); assert!(response.status().class().is_success());
let body_str = response.body().and_then(|body| body.into_string()); let body_str = response.body().and_then(Body::into_string);
assert_eq!(body_str, Some("Hello CORS".to_string())); assert_eq!(body_str, Some("Hello CORS".to_string()));
let origin_header = response let origin_header = response
@ -124,7 +124,7 @@ fn cors_get_check() {
let mut response = req.dispatch(); let mut response = req.dispatch();
assert!(response.status().class().is_success()); assert!(response.status().class().is_success());
let body_str = response.body().and_then(|body| body.into_string()); let body_str = response.body().and_then(Body::into_string);
assert_eq!(body_str, Some("Hello CORS".to_string())); assert_eq!(body_str, Some("Hello CORS".to_string()));
let origin_header = response let origin_header = response
@ -144,7 +144,7 @@ fn cors_get_no_origin() {
let mut response = req.dispatch(); let mut response = req.dispatch();
assert!(response.status().class().is_success()); assert!(response.status().class().is_success());
let body_str = response.body().and_then(|body| body.into_string()); let body_str = response.body().and_then(Body::into_string);
assert_eq!(body_str, Some("Hello CORS".to_string())); assert_eq!(body_str, Some("Hello CORS".to_string()));
} }

View File

@ -1,8 +1,6 @@
//! This crate tests using `rocket_cors` using the per-route handling with request guard //! This crate tests using `rocket_cors` using the per-route handling with request guard
#![feature(proc_macro_hygiene, decl_macro)] #![feature(proc_macro_hygiene, decl_macro)]
use hyper; use hyper;
#[macro_use]
extern crate rocket;
use rocket_cors as cors; use rocket_cors as cors;
use std::str::FromStr; use std::str::FromStr;
@ -10,6 +8,8 @@ use std::str::FromStr;
use rocket::http::Method; use rocket::http::Method;
use rocket::http::{Header, Status}; use rocket::http::{Header, Status};
use rocket::local::Client; use rocket::local::Client;
use rocket::response::Body;
use rocket::{get, options, routes};
use rocket::{Response, State}; use rocket::{Response, State};
#[get("/")] #[get("/")]
@ -60,10 +60,10 @@ fn state<'r>(cors: cors::Guard<'r>, _state: State<'r, SomeState>) -> cors::Respo
} }
fn make_cors() -> cors::Cors { fn make_cors() -> cors::Cors {
let allowed_origins = cors::AllowedOrigins::some(&["https://www.acme.com"]); let allowed_origins = cors::AllowedOrigins::some_exact(&["https://www.acme.com"]);
cors::CorsOptions { cors::CorsOptions {
allowed_origins: allowed_origins, allowed_origins,
allowed_methods: vec![Method::Get].into_iter().map(From::from).collect(), allowed_methods: vec![Method::Get].into_iter().map(From::from).collect(),
allowed_headers: cors::AllowedHeaders::some(&["Authorization", "Accept"]), allowed_headers: cors::AllowedHeaders::some(&["Authorization", "Accept"]),
allow_credentials: true, allow_credentials: true,
@ -119,7 +119,7 @@ fn smoke_test() {
let mut response = req.dispatch(); let mut response = req.dispatch();
assert!(response.status().class().is_success()); assert!(response.status().class().is_success());
let body_str = response.body().and_then(|body| body.into_string()); let body_str = response.body().and_then(Body::into_string);
assert_eq!(body_str, Some("Hello CORS".to_string())); assert_eq!(body_str, Some("Hello CORS".to_string()));
let origin_header = response let origin_header = response
@ -205,7 +205,7 @@ fn cors_get_check() {
let mut response = req.dispatch(); let mut response = req.dispatch();
assert!(response.status().class().is_success()); assert!(response.status().class().is_success());
let body_str = response.body().and_then(|body| body.into_string()); let body_str = response.body().and_then(Body::into_string);
assert_eq!(body_str, Some("Hello CORS".to_string())); assert_eq!(body_str, Some("Hello CORS".to_string()));
let origin_header = response let origin_header = response
@ -226,7 +226,7 @@ fn cors_get_no_origin() {
let mut response = req.dispatch(); let mut response = req.dispatch();
assert!(response.status().class().is_success()); assert!(response.status().class().is_success());
let body_str = response.body().and_then(|body| body.into_string()); let body_str = response.body().and_then(Body::into_string);
assert_eq!(body_str, Some("Hello CORS".to_string())); assert_eq!(body_str, Some("Hello CORS".to_string()));
assert!(response assert!(response
.headers() .headers()
@ -408,7 +408,7 @@ fn overridden_options_routes_are_used() {
.header(request_headers); .header(request_headers);
let mut response = req.dispatch(); let mut response = req.dispatch();
let body_str = response.body().and_then(|body| body.into_string()); let body_str = response.body().and_then(Body::into_string);
assert!(response.status().class().is_success()); assert!(response.status().class().is_success());
assert_eq!(body_str, Some("Manual CORS Preflight".to_string())); assert_eq!(body_str, Some("Manual CORS Preflight".to_string()));

View File

@ -1,14 +1,14 @@
//! This crate tests that all the request headers are parsed correctly in the round trip //! This crate tests that all the request headers are parsed correctly in the round trip
#![feature(proc_macro_hygiene, decl_macro)] #![feature(proc_macro_hygiene, decl_macro)]
use hyper; use hyper;
#[macro_use]
extern crate rocket;
use std::ops::Deref; use std::ops::Deref;
use std::str::FromStr; use std::str::FromStr;
use rocket::http::Header; use rocket::http::Header;
use rocket::local::Client; use rocket::local::Client;
use rocket::response::Body;
use rocket::{get, routes};
use rocket_cors::headers::*; use rocket_cors::headers::*;
#[get("/request_headers")] #[get("/request_headers")]
@ -53,7 +53,7 @@ fn request_headers_round_trip_smoke_test() {
assert!(response.status().class().is_success()); assert!(response.status().class().is_success());
let body_str = response let body_str = response
.body() .body()
.and_then(|body| body.into_string()) .and_then(Body::into_string)
.expect("Non-empty body"); .expect("Non-empty body");
let expected_body = r#"https://foo.bar.xyz let expected_body = r#"https://foo.bar.xyz
GET GET

View File

@ -1,16 +1,16 @@
//! This crate tests using `rocket_cors` using manual mode //! This crate tests using `rocket_cors` using manual mode
#![feature(proc_macro_hygiene, decl_macro)] #![feature(proc_macro_hygiene, decl_macro)]
use hyper; use hyper;
#[macro_use]
extern crate rocket;
use std::str::FromStr; use std::str::FromStr;
use rocket::http::Method; use rocket::http::Method;
use rocket::http::{Header, Status}; use rocket::http::{Header, Status};
use rocket::local::Client; use rocket::local::Client;
use rocket::response::Body;
use rocket::response::Responder; use rocket::response::Responder;
use rocket::State; use rocket::State;
use rocket::{get, options, routes};
use rocket_cors::*; use rocket_cors::*;
/// Using a borrowed `Cors` /// Using a borrowed `Cors`
@ -23,7 +23,7 @@ fn cors(options: State<'_, Cors>) -> impl Responder<'_> {
#[get("/panic")] #[get("/panic")]
fn panicking_route(options: State<'_, Cors>) -> impl Responder<'_> { fn panicking_route(options: State<'_, Cors>) -> impl Responder<'_> {
options.inner().respond_borrowed(|_| -> () { options.inner().respond_borrowed(|_| {
panic!("This route will panic"); panic!("This route will panic");
}) })
} }
@ -66,10 +66,10 @@ fn borrow<'r>(options: State<'r, Cors>, test_state: State<'r, TestState>) -> imp
} }
fn make_cors_options() -> CorsOptions { fn make_cors_options() -> CorsOptions {
let allowed_origins = AllowedOrigins::some(&["https://www.acme.com"]); let allowed_origins = AllowedOrigins::some_exact(&["https://www.acme.com"]);
CorsOptions { CorsOptions {
allowed_origins: allowed_origins, allowed_origins,
allowed_methods: vec![Method::Get].into_iter().map(From::from).collect(), allowed_methods: vec![Method::Get].into_iter().map(From::from).collect(),
allowed_headers: AllowedHeaders::some(&["Authorization", "Accept"]), allowed_headers: AllowedHeaders::some(&["Authorization", "Accept"]),
allow_credentials: true, allow_credentials: true,
@ -78,10 +78,10 @@ fn make_cors_options() -> CorsOptions {
} }
fn make_different_cors_options() -> CorsOptions { fn make_different_cors_options() -> CorsOptions {
let allowed_origins = AllowedOrigins::some(&["https://www.example.com"]); let allowed_origins = AllowedOrigins::some_exact(&["https://www.example.com"]);
CorsOptions { CorsOptions {
allowed_origins: allowed_origins, allowed_origins,
allowed_methods: vec![Method::Get].into_iter().map(From::from).collect(), allowed_methods: vec![Method::Get].into_iter().map(From::from).collect(),
allowed_headers: AllowedHeaders::some(&["Authorization", "Accept"]), allowed_headers: AllowedHeaders::some(&["Authorization", "Accept"]),
allow_credentials: true, allow_credentials: true,
@ -129,7 +129,7 @@ fn smoke_test() {
let mut response = req.dispatch(); let mut response = req.dispatch();
assert!(response.status().class().is_success()); assert!(response.status().class().is_success());
let body_str = response.body().and_then(|body| body.into_string()); let body_str = response.body().and_then(Body::into_string);
assert_eq!(body_str, Some("Hello CORS".to_string())); assert_eq!(body_str, Some("Hello CORS".to_string()));
let origin_header = response let origin_header = response
@ -180,7 +180,7 @@ fn cors_get_borrowed_check() {
let mut response = req.dispatch(); let mut response = req.dispatch();
assert!(response.status().class().is_success()); assert!(response.status().class().is_success());
let body_str = response.body().and_then(|body| body.into_string()); let body_str = response.body().and_then(Body::into_string);
assert_eq!(body_str, Some("Hello CORS".to_string())); assert_eq!(body_str, Some("Hello CORS".to_string()));
let origin_header = response let origin_header = response
@ -200,7 +200,7 @@ fn cors_get_no_origin() {
let mut response = req.dispatch(); let mut response = req.dispatch();
assert!(response.status().class().is_success()); assert!(response.status().class().is_success());
let body_str = response.body().and_then(|body| body.into_string()); let body_str = response.body().and_then(Body::into_string);
assert_eq!(body_str, Some("Hello CORS".to_string())); assert_eq!(body_str, Some("Hello CORS".to_string()));
} }
@ -378,7 +378,7 @@ fn cors_options_owned_check() {
.header(request_headers); .header(request_headers);
let mut response = req.dispatch(); let mut response = req.dispatch();
let body_str = response.body().and_then(|body| body.into_string()); let body_str = response.body().and_then(Body::into_string);
assert!(response.status().class().is_success()); assert!(response.status().class().is_success());
assert_eq!(body_str, Some("Manual CORS Preflight".to_string())); assert_eq!(body_str, Some("Manual CORS Preflight".to_string()));
@ -404,7 +404,7 @@ fn cors_get_owned_check() {
let mut response = req.dispatch(); let mut response = req.dispatch();
assert!(response.status().class().is_success()); assert!(response.status().class().is_success());
let body_str = response.body().and_then(|body| body.into_string()); let body_str = response.body().and_then(Body::into_string);
assert_eq!(body_str, Some("Hello CORS Owned".to_string())); assert_eq!(body_str, Some("Hello CORS Owned".to_string()));
let origin_header = response let origin_header = response

View File

@ -4,15 +4,15 @@
//! `ping` route that you want to allow all Origins to access. //! `ping` route that you want to allow all Origins to access.
#![feature(proc_macro_hygiene, decl_macro)] #![feature(proc_macro_hygiene, decl_macro)]
use hyper; use hyper;
#[macro_use]
extern crate rocket;
use rocket_cors; use rocket_cors;
use std::str::FromStr; use std::str::FromStr;
use rocket::http::{Header, Method, Status}; use rocket::http::{Header, Method, Status};
use rocket::local::Client; use rocket::local::Client;
use rocket::response::Body;
use rocket::response::Responder; use rocket::response::Responder;
use rocket::{get, options, routes};
use rocket_cors::{AllowedHeaders, AllowedOrigins, CorsOptions, Guard}; use rocket_cors::{AllowedHeaders, AllowedOrigins, CorsOptions, Guard};
@ -40,11 +40,11 @@ fn ping_options<'r>() -> impl Responder<'r> {
/// Returns the "application wide" Cors struct /// Returns the "application wide" Cors struct
fn cors_options() -> CorsOptions { fn cors_options() -> CorsOptions {
let allowed_origins = AllowedOrigins::some(&["https://www.acme.com"]); let allowed_origins = AllowedOrigins::some_exact(&["https://www.acme.com"]);
// You can also deserialize this // You can also deserialize this
rocket_cors::CorsOptions { rocket_cors::CorsOptions {
allowed_origins: allowed_origins, allowed_origins,
allowed_methods: vec![Method::Get].into_iter().map(From::from).collect(), allowed_methods: vec![Method::Get].into_iter().map(From::from).collect(),
allowed_headers: AllowedHeaders::some(&["Authorization", "Accept"]), allowed_headers: AllowedHeaders::some(&["Authorization", "Accept"]),
allow_credentials: true, allow_credentials: true,
@ -100,7 +100,7 @@ fn smoke_test() {
let mut response = req.dispatch(); let mut response = req.dispatch();
assert!(response.status().class().is_success()); assert!(response.status().class().is_success());
let body_str = response.body().and_then(|body| body.into_string()); let body_str = response.body().and_then(Body::into_string);
assert_eq!(body_str, Some("Hello CORS!".to_string())); assert_eq!(body_str, Some("Hello CORS!".to_string()));
let origin_header = response let origin_header = response
@ -151,7 +151,7 @@ fn cors_get_check() {
let mut response = req.dispatch(); let mut response = req.dispatch();
assert!(response.status().class().is_success()); assert!(response.status().class().is_success());
let body_str = response.body().and_then(|body| body.into_string()); let body_str = response.body().and_then(Body::into_string);
assert_eq!(body_str, Some("Hello CORS!".to_string())); assert_eq!(body_str, Some("Hello CORS!".to_string()));
let origin_header = response let origin_header = response
@ -171,7 +171,7 @@ fn cors_get_no_origin() {
let mut response = req.dispatch(); let mut response = req.dispatch();
assert!(response.status().class().is_success()); assert!(response.status().class().is_success());
let body_str = response.body().and_then(|body| body.into_string()); let body_str = response.body().and_then(Body::into_string);
assert_eq!(body_str, Some("Hello CORS!".to_string())); assert_eq!(body_str, Some("Hello CORS!".to_string()));
} }
@ -333,7 +333,7 @@ fn cors_get_ping_check() {
let mut response = req.dispatch(); let mut response = req.dispatch();
assert!(response.status().class().is_success()); assert!(response.status().class().is_success());
let body_str = response.body().and_then(|body| body.into_string()); let body_str = response.body().and_then(Body::into_string);
assert_eq!(body_str, Some("Pong!".to_string())); assert_eq!(body_str, Some("Pong!".to_string()));
let origin_header = response let origin_header = response