Refactor Origins to better support additional use cases (#59)
* Specify an internal structure for Cors * Use type alias * Refactor Origin validation * Separate out `Origin` * Add tests
This commit is contained in:
parent
9d8f7aa6f7
commit
f9bffe77d6
|
@ -18,7 +18,7 @@ travis-ci = { repository = "lawliet89/rocket_cors" }
|
||||||
default = ["serialization"]
|
default = ["serialization"]
|
||||||
|
|
||||||
# Serialization and deserialization support for settings
|
# Serialization and deserialization support for settings
|
||||||
serialization = ["serde", "serde_derive", "unicase_serde", "url_serde"]
|
serialization = ["serde", "serde_derive", "unicase_serde"]
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
rocket = "0.4.0"
|
rocket = "0.4.0"
|
||||||
|
@ -30,7 +30,6 @@ url = "1.7.2"
|
||||||
serde = { version = "1.0", optional = true }
|
serde = { version = "1.0", optional = true }
|
||||||
serde_derive = { version = "1.0", optional = true }
|
serde_derive = { version = "1.0", optional = true }
|
||||||
unicase_serde = { version = "0.1.0", optional = true }
|
unicase_serde = { version = "0.1.0", optional = true }
|
||||||
url_serde = { version = "0.2.0", optional = true }
|
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
hyper = "0.10"
|
hyper = "0.10"
|
||||||
|
|
|
@ -12,8 +12,7 @@ fn cors<'a>() -> &'a str {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn main() -> Result<(), Error> {
|
fn main() -> Result<(), Error> {
|
||||||
let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.acme.com"]);
|
let allowed_origins = AllowedOrigins::some(&["https://www.acme.com"]);
|
||||||
assert!(failed_origins.is_empty());
|
|
||||||
|
|
||||||
// You can also deserialize this
|
// You can also deserialize this
|
||||||
let cors = rocket_cors::CorsOptions {
|
let cors = rocket_cors::CorsOptions {
|
||||||
|
|
|
@ -36,8 +36,7 @@ fn manual(cors: Guard<'_>) -> Responder<'_, &str> {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn main() -> Result<(), Error> {
|
fn main() -> Result<(), Error> {
|
||||||
let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.acme.com"]);
|
let allowed_origins = AllowedOrigins::some(&["https://www.acme.com"]);
|
||||||
assert!(failed_origins.is_empty());
|
|
||||||
|
|
||||||
// You can also deserialize this
|
// You can also deserialize this
|
||||||
let cors = rocket_cors::CorsOptions {
|
let cors = rocket_cors::CorsOptions {
|
||||||
|
|
|
@ -13,8 +13,7 @@ fn main() {
|
||||||
// The default demonstrates the "All" serialization of several of the settings
|
// The default demonstrates the "All" serialization of several of the settings
|
||||||
let default: CorsOptions = Default::default();
|
let default: CorsOptions = Default::default();
|
||||||
|
|
||||||
let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.acme.com"]);
|
let allowed_origins = AllowedOrigins::some(&["https://www.acme.com"]);
|
||||||
assert!(failed_origins.is_empty());
|
|
||||||
|
|
||||||
let options = cors::CorsOptions {
|
let options = cors::CorsOptions {
|
||||||
allowed_origins: allowed_origins,
|
allowed_origins: allowed_origins,
|
||||||
|
|
|
@ -59,8 +59,7 @@ fn owned_options<'r>() -> impl Responder<'r> {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn cors_options() -> CorsOptions {
|
fn cors_options() -> CorsOptions {
|
||||||
let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.acme.com"]);
|
let allowed_origins = AllowedOrigins::some(&["https://www.acme.com"]);
|
||||||
assert!(failed_origins.is_empty());
|
|
||||||
|
|
||||||
// You can also deserialize this
|
// You can also deserialize this
|
||||||
rocket_cors::CorsOptions {
|
rocket_cors::CorsOptions {
|
||||||
|
|
|
@ -36,8 +36,7 @@ fn ping_options<'r>() -> impl Responder<'r> {
|
||||||
|
|
||||||
/// Returns the "application wide" Cors struct
|
/// Returns the "application wide" Cors struct
|
||||||
fn cors_options() -> CorsOptions {
|
fn cors_options() -> CorsOptions {
|
||||||
let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.acme.com"]);
|
let allowed_origins = AllowedOrigins::some(&["https://www.acme.com"]);
|
||||||
assert!(failed_origins.is_empty());
|
|
||||||
|
|
||||||
// You can also deserialize this
|
// You can also deserialize this
|
||||||
rocket_cors::CorsOptions {
|
rocket_cors::CorsOptions {
|
||||||
|
|
|
@ -140,8 +140,7 @@ mod tests {
|
||||||
const CORS_ROOT: &'static str = "/my_cors";
|
const CORS_ROOT: &'static str = "/my_cors";
|
||||||
|
|
||||||
fn make_cors_options() -> Cors {
|
fn make_cors_options() -> Cors {
|
||||||
let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.acme.com"]);
|
let allowed_origins = AllowedOrigins::some(&["https://www.acme.com"]);
|
||||||
assert!(failed_origins.is_empty());
|
|
||||||
|
|
||||||
CorsOptions {
|
CorsOptions {
|
||||||
allowed_origins,
|
allowed_origins,
|
||||||
|
|
|
@ -11,12 +11,9 @@ use rocket::{self, Outcome};
|
||||||
#[cfg(feature = "serialization")]
|
#[cfg(feature = "serialization")]
|
||||||
use serde_derive::{Deserialize, Serialize};
|
use serde_derive::{Deserialize, Serialize};
|
||||||
use unicase::UniCase;
|
use unicase::UniCase;
|
||||||
use url;
|
|
||||||
|
|
||||||
#[cfg(feature = "serialization")]
|
#[cfg(feature = "serialization")]
|
||||||
use unicase_serde;
|
use unicase_serde;
|
||||||
#[cfg(feature = "serialization")]
|
|
||||||
use url_serde;
|
|
||||||
|
|
||||||
/// A case insensitive header name
|
/// A case insensitive header name
|
||||||
#[derive(Eq, PartialEq, Clone, Debug, Hash)]
|
#[derive(Eq, PartialEq, Clone, Debug, Hash)]
|
||||||
|
@ -62,54 +59,48 @@ impl FromStr for HeaderFieldName {
|
||||||
/// A set of case insensitive header names
|
/// A set of case insensitive header names
|
||||||
pub type HeaderFieldNamesSet = HashSet<HeaderFieldName>;
|
pub type HeaderFieldNamesSet = HashSet<HeaderFieldName>;
|
||||||
|
|
||||||
/// A wrapped `url::Url` to allow for deserialization
|
/// The `Origin` request header used in CORS
|
||||||
|
///
|
||||||
|
/// You can use this as a rocket [Request Guard](https://rocket.rs/guide/requests/#request-guards)
|
||||||
|
/// to ensure that `Origin` is passed in correctly.
|
||||||
#[derive(Eq, PartialEq, Clone, Hash, Debug)]
|
#[derive(Eq, PartialEq, Clone, Hash, Debug)]
|
||||||
#[cfg_attr(feature = "serialization", derive(Serialize, Deserialize))]
|
pub struct Origin(pub url::Origin);
|
||||||
pub struct Url(#[cfg_attr(feature = "serialization", serde(with = "url_serde"))] url::Url);
|
|
||||||
|
|
||||||
impl fmt::Display for Url {
|
impl FromStr for Origin {
|
||||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
type Err = crate::Error;
|
||||||
self.0.fmt(f)
|
|
||||||
|
fn from_str(input: &str) -> Result<Self, Self::Err> {
|
||||||
|
Ok(Origin(crate::to_origin(input)?))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Deref for Url {
|
impl Deref for Origin {
|
||||||
type Target = url::Url;
|
type Target = url::Origin;
|
||||||
|
|
||||||
fn deref(&self) -> &Self::Target {
|
fn deref(&self) -> &Self::Target {
|
||||||
&self.0
|
&self.0
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl FromStr for Url {
|
impl fmt::Display for Origin {
|
||||||
type Err = url::ParseError;
|
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||||
|
write!(f, "{}", self.ascii_serialization())
|
||||||
fn from_str(input: &str) -> Result<Self, Self::Err> {
|
|
||||||
let url = url::Url::from_str(input)?;
|
|
||||||
Ok(Url(url))
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, 'r> FromRequest<'a, 'r> for Url {
|
impl<'a, 'r> FromRequest<'a, 'r> for Origin {
|
||||||
type Error = crate::Error;
|
type Error = crate::Error;
|
||||||
|
|
||||||
fn from_request(request: &'a rocket::Request<'r>) -> request::Outcome<Self, crate::Error> {
|
fn from_request(request: &'a rocket::Request<'r>) -> request::Outcome<Self, crate::Error> {
|
||||||
match request.headers().get_one("Origin") {
|
match request.headers().get_one("Origin") {
|
||||||
Some(origin) => match Self::from_str(origin) {
|
Some(origin) => match Self::from_str(origin) {
|
||||||
Ok(origin) => Outcome::Success(origin),
|
Ok(origin) => Outcome::Success(origin),
|
||||||
Err(e) => Outcome::Failure((Status::BadRequest, crate::Error::BadOrigin(e))),
|
Err(e) => Outcome::Failure((Status::BadRequest, e)),
|
||||||
},
|
},
|
||||||
None => Outcome::Forward(()),
|
None => Outcome::Forward(()),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// The `Origin` request header used in CORS
|
|
||||||
///
|
|
||||||
/// You can use this as a rocket [Request Guard](https://rocket.rs/guide/requests/#request-guards)
|
|
||||||
/// to ensure that `Origin` is passed in correctly.
|
|
||||||
pub type Origin = Url;
|
|
||||||
|
|
||||||
/// The `Access-Control-Request-Method` request header
|
/// The `Access-Control-Request-Method` request header
|
||||||
///
|
///
|
||||||
/// You can use this as a rocket [Request Guard](https://rocket.rs/guide/requests/#request-guards)
|
/// You can use this as a rocket [Request Guard](https://rocket.rs/guide/requests/#request-guards)
|
||||||
|
@ -202,13 +193,13 @@ mod tests {
|
||||||
fn origin_header_conversion() {
|
fn origin_header_conversion() {
|
||||||
let url = "https://foo.bar.xyz";
|
let url = "https://foo.bar.xyz";
|
||||||
let parsed = not_err!(Origin::from_str(url));
|
let parsed = not_err!(Origin::from_str(url));
|
||||||
let expected = not_err!(Url::from_str(url));
|
assert_eq!(parsed.ascii_serialization(), url);
|
||||||
assert_eq!(parsed, expected);
|
|
||||||
|
|
||||||
let url = "https://foo.bar.xyz/path/somewhere"; // this should never really be used
|
// this should never really be sent by a compliant user agent
|
||||||
|
let url = "https://foo.bar.xyz/path/somewhere";
|
||||||
let parsed = not_err!(Origin::from_str(url));
|
let parsed = not_err!(Origin::from_str(url));
|
||||||
let expected = not_err!(Url::from_str(url));
|
let expected = "https://foo.bar.xyz";
|
||||||
assert_eq!(parsed, expected);
|
assert_eq!(parsed.ascii_serialization(), expected);
|
||||||
|
|
||||||
let url = "invalid_url";
|
let url = "invalid_url";
|
||||||
let _ = is_err!(Origin::from_str(url));
|
let _ = is_err!(Origin::from_str(url));
|
||||||
|
@ -225,7 +216,10 @@ mod tests {
|
||||||
let outcome: request::Outcome<Origin, crate::Error> =
|
let outcome: request::Outcome<Origin, crate::Error> =
|
||||||
FromRequest::from_request(request.inner());
|
FromRequest::from_request(request.inner());
|
||||||
let parsed_header = assert_matches!(outcome, Outcome::Success(s), s);
|
let parsed_header = assert_matches!(outcome, Outcome::Success(s), s);
|
||||||
assert_eq!("https://www.example.com/", parsed_header.as_str());
|
assert_eq!(
|
||||||
|
"https://www.example.com",
|
||||||
|
parsed_header.ascii_serialization()
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|
260
src/lib.rs
260
src/lib.rs
|
@ -276,7 +276,7 @@ mod fairing;
|
||||||
pub mod headers;
|
pub mod headers;
|
||||||
|
|
||||||
use std::borrow::Cow;
|
use std::borrow::Cow;
|
||||||
use std::collections::{HashMap, HashSet};
|
use std::collections::HashSet;
|
||||||
use std::error;
|
use std::error;
|
||||||
use std::fmt;
|
use std::fmt;
|
||||||
use std::marker::PhantomData;
|
use std::marker::PhantomData;
|
||||||
|
@ -293,7 +293,7 @@ use serde_derive::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::headers::{
|
use crate::headers::{
|
||||||
AccessControlRequestHeaders, AccessControlRequestMethod, HeaderFieldName, HeaderFieldNamesSet,
|
AccessControlRequestHeaders, AccessControlRequestMethod, HeaderFieldName, HeaderFieldNamesSet,
|
||||||
Origin, Url,
|
Origin,
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Errors during operations
|
/// Errors during operations
|
||||||
|
@ -316,7 +316,7 @@ pub enum Error {
|
||||||
/// The request header `Access-Control-Request-Headers` is required but is missing.
|
/// The request header `Access-Control-Request-Headers` is required but is missing.
|
||||||
MissingRequestHeaders,
|
MissingRequestHeaders,
|
||||||
/// Origin is not allowed to make this request
|
/// Origin is not allowed to make this request
|
||||||
OriginNotAllowed(String),
|
OriginNotAllowed(url::Origin),
|
||||||
/// Requested method is not allowed
|
/// Requested method is not allowed
|
||||||
MethodNotAllowed(String),
|
MethodNotAllowed(String),
|
||||||
/// One or more headers requested are not allowed
|
/// One or more headers requested are not allowed
|
||||||
|
@ -365,7 +365,7 @@ impl fmt::Display for Error {
|
||||||
"The request header `Access-Control-Request-Headers` \
|
"The request header `Access-Control-Request-Headers` \
|
||||||
is required but is missing")
|
is required but is missing")
|
||||||
}
|
}
|
||||||
Error::OriginNotAllowed(origin) => write!(f, "Origin '{}' is not allowed to request", &origin),
|
Error::OriginNotAllowed(origin) => write!(f, "Origin '{}' is not allowed to request", origin.ascii_serialization()),
|
||||||
Error::MethodNotAllowed(method) => write!(f, "Method '{}' is not allowed", &method),
|
Error::MethodNotAllowed(method) => write!(f, "Method '{}' is not allowed", &method),
|
||||||
Error::HeadersNotAllowed => write!(f, "Headers are not allowed"),
|
Error::HeadersNotAllowed => write!(f, "Headers are not allowed"),
|
||||||
Error::CredentialsWithWildcardOrigin => { write!(f,
|
Error::CredentialsWithWildcardOrigin => { write!(f,
|
||||||
|
@ -398,6 +398,12 @@ impl<'r> response::Responder<'r> for Error {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl From<url::ParseError> for Error {
|
||||||
|
fn from(error: url::ParseError) -> Self {
|
||||||
|
Error::BadOrigin(error)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// An enum signifying that some of type T is allowed, or `All` (everything is allowed).
|
/// An enum signifying that some of type T is allowed, or `All` (everything is allowed).
|
||||||
///
|
///
|
||||||
/// `Default` is implemented for this enum and is `All`.
|
/// `Default` is implemented for this enum and is `All`.
|
||||||
|
@ -523,31 +529,16 @@ mod method_serde {
|
||||||
/// use rocket_cors::AllowedOrigins;
|
/// use rocket_cors::AllowedOrigins;
|
||||||
///
|
///
|
||||||
/// let all_origins = AllowedOrigins::all();
|
/// let all_origins = AllowedOrigins::all();
|
||||||
/// let (some_origins, failed_origins) = AllowedOrigins::some(&["https://www.acme.com"]);
|
/// let some_origins = AllowedOrigins::some(&["https://www.acme.com"]);
|
||||||
/// assert!(failed_origins.is_empty());
|
|
||||||
/// ```
|
/// ```
|
||||||
pub type AllowedOrigins = AllOrSome<HashSet<Url>>;
|
pub type AllowedOrigins = AllOrSome<HashSet<String>>;
|
||||||
|
|
||||||
impl AllowedOrigins {
|
impl AllowedOrigins {
|
||||||
/// Allows some origins
|
/// Allows some origins
|
||||||
///
|
///
|
||||||
/// Returns a tuple where the first element is the struct `AllowedOrigins`,
|
/// Validation is not performed at this stage, but at a later stage.
|
||||||
/// and the second element
|
pub fn some(urls: &[&str]) -> Self {
|
||||||
/// is a map of strings which failed to parse into URLs and their associated parse errors.
|
AllOrSome::Some(urls.iter().map(|s| s.to_string()).collect())
|
||||||
pub fn some(urls: &[&str]) -> (Self, HashMap<String, url::ParseError>) {
|
|
||||||
let (ok_set, error_map): (Vec<_>, Vec<_>) = urls
|
|
||||||
.iter()
|
|
||||||
.map(|s| (s.to_string(), Url::from_str(s)))
|
|
||||||
.partition(|&(_, ref r)| r.is_ok());
|
|
||||||
|
|
||||||
let error_map = error_map
|
|
||||||
.into_iter()
|
|
||||||
.map(|(s, r)| (s.to_string(), r.unwrap_err()))
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
let ok_set = ok_set.into_iter().map(|(_, r)| r.unwrap()).collect();
|
|
||||||
|
|
||||||
(AllOrSome::Some(ok_set), error_map)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Allows all origins
|
/// Allows all origins
|
||||||
|
@ -646,7 +637,7 @@ impl AllowedHeaders {
|
||||||
/// {
|
/// {
|
||||||
/// "allowed_origins": {
|
/// "allowed_origins": {
|
||||||
/// "Some": [
|
/// "Some": [
|
||||||
/// "https://www.acme.com/"
|
/// "https://www.acme.com"
|
||||||
/// ]
|
/// ]
|
||||||
/// },
|
/// },
|
||||||
/// "allowed_methods": [
|
/// "allowed_methods": [
|
||||||
|
@ -714,7 +705,7 @@ pub struct CorsOptions {
|
||||||
///
|
///
|
||||||
/// Defaults to `All`.
|
/// Defaults to `All`.
|
||||||
#[cfg_attr(feature = "serialization", serde(default))]
|
#[cfg_attr(feature = "serialization", serde(default))]
|
||||||
pub allowed_headers: AllOrSome<HashSet<HeaderFieldName>>,
|
pub allowed_headers: AllowedHeaders,
|
||||||
/// Allows users to make authenticated requests.
|
/// Allows users to make authenticated requests.
|
||||||
/// If true, injects the `Access-Control-Allow-Credentials` header in responses.
|
/// If true, injects the `Access-Control-Allow-Credentials` header in responses.
|
||||||
/// This allows cookies and credentials to be submitted across domains.
|
/// This allows cookies and credentials to be submitted across domains.
|
||||||
|
@ -819,9 +810,7 @@ impl CorsOptions {
|
||||||
0
|
0
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Validates if any of the settings are disallowed or incorrect
|
/// Validates if any of the settings are disallowed, incorrect, or illegal
|
||||||
///
|
|
||||||
/// This is run during initial Fairing attachment
|
|
||||||
pub fn validate(&self) -> Result<(), Error> {
|
pub fn validate(&self) -> Result<(), Error> {
|
||||||
if self.allowed_origins.is_all() && self.send_wildcard && self.allow_credentials {
|
if self.allowed_origins.is_all() && self.send_wildcard && self.allow_credentials {
|
||||||
Err(Error::CredentialsWithWildcardOrigin)?;
|
Err(Error::CredentialsWithWildcardOrigin)?;
|
||||||
|
@ -842,22 +831,37 @@ impl CorsOptions {
|
||||||
/// documentation at the [crate root](index.html) for usage information.
|
/// documentation at the [crate root](index.html) for usage information.
|
||||||
///
|
///
|
||||||
/// This struct can be created by using [`CorsOptions::to_cors`] or [`Cors::from_options`].
|
/// This struct can be created by using [`CorsOptions::to_cors`] or [`Cors::from_options`].
|
||||||
#[derive(Clone, Debug)]
|
#[derive(Clone, Debug, Eq, PartialEq)]
|
||||||
pub struct Cors(CorsOptions);
|
pub struct Cors {
|
||||||
|
pub(crate) allowed_origins: AllOrSome<HashSet<url::Origin>>,
|
||||||
impl Deref for Cors {
|
pub(crate) allowed_methods: AllowedMethods,
|
||||||
type Target = CorsOptions;
|
pub(crate) allowed_headers: AllOrSome<HashSet<HeaderFieldName>>,
|
||||||
|
pub(crate) allow_credentials: bool,
|
||||||
fn deref(&self) -> &Self::Target {
|
pub(crate) expose_headers: HashSet<String>,
|
||||||
&self.0
|
pub(crate) max_age: Option<usize>,
|
||||||
}
|
pub(crate) send_wildcard: bool,
|
||||||
|
pub(crate) fairing_route_base: String,
|
||||||
|
pub(crate) fairing_route_rank: isize,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Cors {
|
impl Cors {
|
||||||
/// Create a `Cors` struct from a [`CorsOptions`]
|
/// Create a `Cors` struct from a [`CorsOptions`]
|
||||||
pub fn from_options(options: &CorsOptions) -> Result<Self, Error> {
|
pub fn from_options(options: &CorsOptions) -> Result<Self, Error> {
|
||||||
options.validate()?;
|
options.validate()?;
|
||||||
Ok(Cors(options.clone()))
|
|
||||||
|
let allowed_origins = parse_origins(&options.allowed_origins)?;
|
||||||
|
|
||||||
|
Ok(Cors {
|
||||||
|
allowed_origins,
|
||||||
|
allowed_methods: options.allowed_methods.clone(),
|
||||||
|
allowed_headers: options.allowed_headers.clone(),
|
||||||
|
allow_credentials: options.allow_credentials,
|
||||||
|
expose_headers: options.expose_headers.clone(),
|
||||||
|
max_age: options.max_age,
|
||||||
|
send_wildcard: options.send_wildcard,
|
||||||
|
fairing_route_base: options.fairing_route_base.clone(),
|
||||||
|
fairing_route_rank: options.fairing_route_rank,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Manually respond to a request with CORS checks and headers using an Owned `Cors`.
|
/// Manually respond to a request with CORS checks and headers using an Owned `Cors`.
|
||||||
|
@ -917,7 +921,7 @@ impl Cors {
|
||||||
/// You can get this struct by using `Cors::validate_request` in an ad-hoc manner.
|
/// You can get this struct by using `Cors::validate_request` in an ad-hoc manner.
|
||||||
#[derive(Eq, PartialEq, Debug)]
|
#[derive(Eq, PartialEq, Debug)]
|
||||||
pub(crate) struct Response {
|
pub(crate) struct Response {
|
||||||
allow_origin: Option<AllOrSome<Url>>,
|
allow_origin: Option<AllOrSome<url::Origin>>,
|
||||||
allow_methods: HashSet<Method>,
|
allow_methods: HashSet<Method>,
|
||||||
allow_headers: HeaderFieldNamesSet,
|
allow_headers: HeaderFieldNamesSet,
|
||||||
allow_credentials: bool,
|
allow_credentials: bool,
|
||||||
|
@ -941,7 +945,7 @@ impl Response {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Consumes the `Response` and return an altered response with origin and `vary_origin` set
|
/// Consumes the `Response` and return an altered response with origin and `vary_origin` set
|
||||||
fn origin(mut self, origin: &Url, vary_origin: bool) -> Self {
|
fn origin(mut self, origin: &url::Origin, vary_origin: bool) -> Self {
|
||||||
self.allow_origin = Some(AllOrSome::Some(origin.clone()));
|
self.allow_origin = Some(AllOrSome::Some(origin.clone()));
|
||||||
self.vary_origin = vary_origin;
|
self.vary_origin = vary_origin;
|
||||||
self
|
self
|
||||||
|
@ -1016,11 +1020,9 @@ impl Response {
|
||||||
Some(ref origin) => origin,
|
Some(ref origin) => origin,
|
||||||
};
|
};
|
||||||
|
|
||||||
// Origin should be ASCII serialized
|
|
||||||
// c.f. https://html.spec.whatwg.org/multipage/origin.html#ascii-serialisation-of-an-origin
|
|
||||||
let origin = match *origin {
|
let origin = match *origin {
|
||||||
AllOrSome::All => "*".to_string(),
|
AllOrSome::All => "*".to_string(),
|
||||||
AllOrSome::Some(ref origin) => origin.origin().ascii_serialization(),
|
AllOrSome::Some(ref origin) => origin.ascii_serialization(),
|
||||||
};
|
};
|
||||||
|
|
||||||
let _ = response.set_raw_header("Access-Control-Allow-Origin", origin);
|
let _ = response.set_raw_header("Access-Control-Allow-Origin", origin);
|
||||||
|
@ -1249,11 +1251,29 @@ enum ValidationResult {
|
||||||
None,
|
None,
|
||||||
/// Successful preflight request
|
/// Successful preflight request
|
||||||
Preflight {
|
Preflight {
|
||||||
origin: Origin,
|
origin: url::Origin,
|
||||||
headers: Option<AccessControlRequestHeaders>,
|
headers: Option<AccessControlRequestHeaders>,
|
||||||
},
|
},
|
||||||
/// Successful actual request
|
/// Successful actual request
|
||||||
Request { origin: Origin },
|
Request { origin: url::Origin },
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Convert a str to Origin
|
||||||
|
fn to_origin<S: AsRef<str>>(origin: S) -> Result<url::Origin, Error> {
|
||||||
|
// What to do about Opaque origins?
|
||||||
|
Ok(url::Url::parse(origin.as_ref())?.origin())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Parse and process allowed origins
|
||||||
|
fn parse_origins(origins: &AllowedOrigins) -> Result<AllOrSome<HashSet<url::Origin>>, Error> {
|
||||||
|
match origins {
|
||||||
|
AllOrSome::All => Ok(AllOrSome::All),
|
||||||
|
AllOrSome::Some(ref origins) => {
|
||||||
|
let parsed: Result<HashSet<url::Origin>, Error> =
|
||||||
|
origins.iter().map(to_origin).collect();
|
||||||
|
Ok(AllOrSome::Some(parsed?))
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Validates a request for CORS and returns a CORS Response
|
/// Validates a request for CORS and returns a CORS Response
|
||||||
|
@ -1288,11 +1308,16 @@ fn validate(options: &Cors, request: &Request<'_>) -> Result<ValidationResult, E
|
||||||
let method = request_method(request)?;
|
let method = request_method(request)?;
|
||||||
let headers = request_headers(request)?;
|
let headers = request_headers(request)?;
|
||||||
preflight_validate(options, &origin, &method, &headers)?;
|
preflight_validate(options, &origin, &method, &headers)?;
|
||||||
Ok(ValidationResult::Preflight { origin, headers })
|
Ok(ValidationResult::Preflight {
|
||||||
|
origin: origin.deref().clone(),
|
||||||
|
headers,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
_ => {
|
_ => {
|
||||||
actual_request_validate(options, &origin)?;
|
actual_request_validate(options, &origin)?;
|
||||||
Ok(ValidationResult::Request { origin })
|
Ok(ValidationResult::Request {
|
||||||
|
origin: origin.deref().clone(),
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1301,8 +1326,8 @@ fn validate(options: &Cors, request: &Request<'_>) -> Result<ValidationResult, E
|
||||||
/// check if the requested origin is allowed.
|
/// check if the requested origin is allowed.
|
||||||
/// Useful for pre-flight and during requests
|
/// Useful for pre-flight and during requests
|
||||||
fn validate_origin(
|
fn validate_origin(
|
||||||
origin: &Origin,
|
origin: &url::Origin,
|
||||||
allowed_origins: &AllOrSome<HashSet<Url>>,
|
allowed_origins: &AllOrSome<HashSet<url::Origin>>,
|
||||||
) -> Result<(), Error> {
|
) -> Result<(), Error> {
|
||||||
match *allowed_origins {
|
match *allowed_origins {
|
||||||
// Always matching is acceptable since the list of origins can be unbounded.
|
// Always matching is acceptable since the list of origins can be unbounded.
|
||||||
|
@ -1310,14 +1335,14 @@ fn validate_origin(
|
||||||
AllOrSome::Some(ref allowed_origins) => allowed_origins
|
AllOrSome::Some(ref allowed_origins) => allowed_origins
|
||||||
.get(origin)
|
.get(origin)
|
||||||
.and_then(|_| Some(()))
|
.and_then(|_| Some(()))
|
||||||
.ok_or_else(|| Error::OriginNotAllowed(origin.to_string())),
|
.ok_or_else(|| Error::OriginNotAllowed(origin.clone())),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Validate allowed methods
|
/// Validate allowed methods
|
||||||
fn validate_allowed_method(
|
fn validate_allowed_method(
|
||||||
method: &AccessControlRequestMethod,
|
method: &AccessControlRequestMethod,
|
||||||
allowed_methods: &HashSet<Method>,
|
allowed_methods: &AllowedMethods,
|
||||||
) -> Result<(), Error> {
|
) -> Result<(), Error> {
|
||||||
let &AccessControlRequestMethod(ref request_method) = method;
|
let &AccessControlRequestMethod(ref request_method) = method;
|
||||||
if !allowed_methods.iter().any(|m| m == request_method) {
|
if !allowed_methods.iter().any(|m| m == request_method) {
|
||||||
|
@ -1331,7 +1356,7 @@ fn validate_allowed_method(
|
||||||
/// Validate allowed headers
|
/// Validate allowed headers
|
||||||
fn validate_allowed_headers(
|
fn validate_allowed_headers(
|
||||||
headers: &AccessControlRequestHeaders,
|
headers: &AccessControlRequestHeaders,
|
||||||
allowed_headers: &AllOrSome<HashSet<HeaderFieldName>>,
|
allowed_headers: &AllowedHeaders,
|
||||||
) -> Result<(), Error> {
|
) -> Result<(), Error> {
|
||||||
let &AccessControlRequestHeaders(ref headers) = headers;
|
let &AccessControlRequestHeaders(ref headers) = headers;
|
||||||
|
|
||||||
|
@ -1380,12 +1405,10 @@ fn request_headers(request: &Request<'_>) -> Result<Option<AccessControlRequestH
|
||||||
/// and [Fetch specification](https://fetch.spec.whatwg.org/#cors-preflight-fetch)
|
/// and [Fetch specification](https://fetch.spec.whatwg.org/#cors-preflight-fetch)
|
||||||
fn preflight_validate(
|
fn preflight_validate(
|
||||||
options: &Cors,
|
options: &Cors,
|
||||||
origin: &Origin,
|
origin: &url::Origin,
|
||||||
method: &Option<AccessControlRequestMethod>,
|
method: &Option<AccessControlRequestMethod>,
|
||||||
headers: &Option<AccessControlRequestHeaders>,
|
headers: &Option<AccessControlRequestHeaders>,
|
||||||
) -> Result<(), Error> {
|
) -> Result<(), Error> {
|
||||||
options.validate()?; // Fast-forward check for #7
|
|
||||||
|
|
||||||
// Note: All header parse failures are dealt with in the `FromRequest` trait implementation
|
// Note: All header parse failures are dealt with in the `FromRequest` trait implementation
|
||||||
|
|
||||||
// 2. If the value of the Origin header is not a case-sensitive match for any of the values
|
// 2. If the value of the Origin header is not a case-sensitive match for any of the values
|
||||||
|
@ -1430,7 +1453,7 @@ fn preflight_validate(
|
||||||
/// and [Fetch specification](https://fetch.spec.whatwg.org/#cors-preflight-fetch).
|
/// and [Fetch specification](https://fetch.spec.whatwg.org/#cors-preflight-fetch).
|
||||||
fn preflight_response(
|
fn preflight_response(
|
||||||
options: &Cors,
|
options: &Cors,
|
||||||
origin: &Origin,
|
origin: &url::Origin,
|
||||||
headers: Option<&AccessControlRequestHeaders>,
|
headers: Option<&AccessControlRequestHeaders>,
|
||||||
) -> Response {
|
) -> Response {
|
||||||
let response = Response::new();
|
let response = Response::new();
|
||||||
|
@ -1501,9 +1524,7 @@ fn preflight_response(
|
||||||
/// This implementation references the
|
/// This implementation references the
|
||||||
/// [W3C recommendation](https://www.w3.org/TR/cors/#resource-requests)
|
/// [W3C recommendation](https://www.w3.org/TR/cors/#resource-requests)
|
||||||
/// and [Fetch specification](https://fetch.spec.whatwg.org/#cors-preflight-fetch).
|
/// and [Fetch specification](https://fetch.spec.whatwg.org/#cors-preflight-fetch).
|
||||||
fn actual_request_validate(options: &Cors, origin: &Origin) -> Result<(), Error> {
|
fn actual_request_validate(options: &Cors, origin: &url::Origin) -> Result<(), Error> {
|
||||||
options.validate()?;
|
|
||||||
|
|
||||||
// Note: All header parse failures are dealt with in the `FromRequest` trait implementation
|
// Note: All header parse failures are dealt with in the `FromRequest` trait implementation
|
||||||
|
|
||||||
// 2. If the value of the Origin header is not a case-sensitive match for any of the values
|
// 2. If the value of the Origin header is not a case-sensitive match for any of the values
|
||||||
|
@ -1520,7 +1541,7 @@ fn actual_request_validate(options: &Cors, origin: &Origin) -> Result<(), Error>
|
||||||
/// This implementation references the
|
/// This implementation references the
|
||||||
/// [W3C recommendation](https://www.w3.org/TR/cors/#resource-requests)
|
/// [W3C recommendation](https://www.w3.org/TR/cors/#resource-requests)
|
||||||
/// and [Fetch specification](https://fetch.spec.whatwg.org/#cors-preflight-fetch)
|
/// and [Fetch specification](https://fetch.spec.whatwg.org/#cors-preflight-fetch)
|
||||||
fn actual_request_response(options: &Cors, origin: &Origin) -> Response {
|
fn actual_request_response(options: &Cors, origin: &url::Origin) -> Response {
|
||||||
let response = Response::new();
|
let response = Response::new();
|
||||||
|
|
||||||
// 3. If the resource supports credentials add a single Access-Control-Allow-Origin header,
|
// 3. If the resource supports credentials add a single Access-Control-Allow-Origin header,
|
||||||
|
@ -1620,8 +1641,7 @@ mod tests {
|
||||||
use crate::http::Method;
|
use crate::http::Method;
|
||||||
|
|
||||||
fn make_cors_options() -> CorsOptions {
|
fn make_cors_options() -> CorsOptions {
|
||||||
let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.acme.com"]);
|
let allowed_origins = AllowedOrigins::some(&["https://www.acme.com"]);
|
||||||
assert!(failed_origins.is_empty());
|
|
||||||
|
|
||||||
CorsOptions {
|
CorsOptions {
|
||||||
allowed_origins,
|
allowed_origins,
|
||||||
|
@ -1653,6 +1673,39 @@ mod tests {
|
||||||
Client::new(rocket).expect("valid rocket instance")
|
Client::new(rocket).expect("valid rocket instance")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// `to_origin` tests
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn origin_is_parsed_properly() {
|
||||||
|
let url = "https://foo.bar.xyz";
|
||||||
|
let parsed = not_err!(Origin::from_str(url));
|
||||||
|
assert_eq!(parsed.ascii_serialization(), url);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn origin_parsing_strips_paths() {
|
||||||
|
// this should never really be sent by a compliant user agent
|
||||||
|
let url = "https://foo.bar.xyz/path/somewhere";
|
||||||
|
let parsed = not_err!(Origin::from_str(url));
|
||||||
|
let expected = "https://foo.bar.xyz";
|
||||||
|
assert_eq!(parsed.ascii_serialization(), expected);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
#[should_panic(expected = "BadOrigin")]
|
||||||
|
fn origin_parsing_disallows_invalid_origins() {
|
||||||
|
let url = "invalid_url";
|
||||||
|
let _ = Origin::from_str(url).unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn origin_parses_opaque_origins() {
|
||||||
|
let url = "blob://foobar";
|
||||||
|
let parsed = not_err!(Origin::from_str(url));
|
||||||
|
|
||||||
|
assert!(!parsed.is_tuple());
|
||||||
|
}
|
||||||
|
|
||||||
// CORS options test
|
// CORS options test
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
@ -1681,7 +1734,7 @@ mod tests {
|
||||||
#[test]
|
#[test]
|
||||||
fn validate_origin_allows_all_origins() {
|
fn validate_origin_allows_all_origins() {
|
||||||
let url = "https://www.example.com";
|
let url = "https://www.example.com";
|
||||||
let origin = Origin::from_str(url).unwrap();
|
let origin = not_err!(to_origin(&url));
|
||||||
let allowed_origins = AllOrSome::All;
|
let allowed_origins = AllOrSome::All;
|
||||||
|
|
||||||
not_err!(validate_origin(&origin, &allowed_origins));
|
not_err!(validate_origin(&origin, &allowed_origins));
|
||||||
|
@ -1690,20 +1743,40 @@ mod tests {
|
||||||
#[test]
|
#[test]
|
||||||
fn validate_origin_allows_origin() {
|
fn validate_origin_allows_origin() {
|
||||||
let url = "https://www.example.com";
|
let url = "https://www.example.com";
|
||||||
let origin = Origin::from_str(url).unwrap();
|
let origin = not_err!(to_origin(&url));
|
||||||
let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.example.com"]);
|
let allowed_origins = not_err!(parse_origins(&AllowedOrigins::some(&[
|
||||||
assert!(failed_origins.is_empty());
|
"https://www.example.com"
|
||||||
|
])));
|
||||||
|
|
||||||
not_err!(validate_origin(&origin, &allowed_origins));
|
not_err!(validate_origin(&origin, &allowed_origins));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn validate_origin_handles_punycode_properly() {
|
||||||
|
// Test a variety of scenarios where the Origin and settings are in punycode, or not
|
||||||
|
let cases = vec![
|
||||||
|
("https://аpple.com", "https://аpple.com"),
|
||||||
|
("https://аpple.com", "https://xn--pple-43d.com"),
|
||||||
|
("https://xn--pple-43d.com", "https://аpple.com"),
|
||||||
|
("https://xn--pple-43d.com", "https://xn--pple-43d.com"),
|
||||||
|
];
|
||||||
|
|
||||||
|
for (url, allowed_origin) in cases {
|
||||||
|
let origin = not_err!(to_origin(&url));
|
||||||
|
let allowed_origins = not_err!(parse_origins(&AllowedOrigins::some(&[allowed_origin])));
|
||||||
|
|
||||||
|
not_err!(validate_origin(&origin, &allowed_origins));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
#[should_panic(expected = "OriginNotAllowed")]
|
#[should_panic(expected = "OriginNotAllowed")]
|
||||||
fn validate_origin_rejects_invalid_origin() {
|
fn validate_origin_rejects_invalid_origin() {
|
||||||
let url = "https://www.acme.com";
|
let url = "https://www.acme.com";
|
||||||
let origin = Origin::from_str(url).unwrap();
|
let origin = not_err!(to_origin(&url));
|
||||||
let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.example.com"]);
|
let allowed_origins = not_err!(parse_origins(&AllowedOrigins::some(&[
|
||||||
assert!(failed_origins.is_empty());
|
"https://www.example.com"
|
||||||
|
])));
|
||||||
|
|
||||||
validate_origin(&origin, &allowed_origins).unwrap();
|
validate_origin(&origin, &allowed_origins).unwrap();
|
||||||
}
|
}
|
||||||
|
@ -1711,10 +1784,7 @@ mod tests {
|
||||||
#[test]
|
#[test]
|
||||||
fn response_sets_allow_origin_without_vary_correctly() {
|
fn response_sets_allow_origin_without_vary_correctly() {
|
||||||
let response = Response::new();
|
let response = Response::new();
|
||||||
let response = response.origin(
|
let response = response.origin(&to_origin("https://www.example.com").unwrap(), false);
|
||||||
&FromStr::from_str("https://www.example.com").unwrap(),
|
|
||||||
false,
|
|
||||||
);
|
|
||||||
|
|
||||||
// Build response and check built response header
|
// Build response and check built response header
|
||||||
let expected_header = vec!["https://www.example.com"];
|
let expected_header = vec!["https://www.example.com"];
|
||||||
|
@ -1731,8 +1801,7 @@ mod tests {
|
||||||
#[test]
|
#[test]
|
||||||
fn response_sets_allow_origin_with_vary_correctly() {
|
fn response_sets_allow_origin_with_vary_correctly() {
|
||||||
let response = Response::new();
|
let response = Response::new();
|
||||||
let response =
|
let response = response.origin(&to_origin("https://www.example.com").unwrap(), true);
|
||||||
response.origin(&FromStr::from_str("https://www.example.com").unwrap(), true);
|
|
||||||
|
|
||||||
// Build response and check built response header
|
// Build response and check built response header
|
||||||
let expected_header = vec!["https://www.example.com"];
|
let expected_header = vec!["https://www.example.com"];
|
||||||
|
@ -1762,9 +1831,10 @@ mod tests {
|
||||||
#[test]
|
#[test]
|
||||||
fn response_sets_allow_origin_with_ascii_serialization() {
|
fn response_sets_allow_origin_with_ascii_serialization() {
|
||||||
let response = Response::new();
|
let response = Response::new();
|
||||||
let response = response.origin(&FromStr::from_str("https://аpple.com").unwrap(), false);
|
let response = response.origin(&to_origin("https://аpple.com").unwrap(), false);
|
||||||
|
|
||||||
// Build response and check built response header
|
// Build response and check built response header
|
||||||
|
// This is "punycode"
|
||||||
let expected_header = vec!["https://xn--pple-43d.com"];
|
let expected_header = vec!["https://xn--pple-43d.com"];
|
||||||
let response = response.response(response::Response::new());
|
let response = response.response(response::Response::new());
|
||||||
let actual_header: Vec<_> = response
|
let actual_header: Vec<_> = response
|
||||||
|
@ -1778,10 +1848,7 @@ mod tests {
|
||||||
fn response_sets_exposed_headers_correctly() {
|
fn response_sets_exposed_headers_correctly() {
|
||||||
let headers = vec!["Bar", "Baz", "Foo"];
|
let headers = vec!["Bar", "Baz", "Foo"];
|
||||||
let response = Response::new();
|
let response = Response::new();
|
||||||
let response = response.origin(
|
let response = response.origin(&to_origin("https://www.example.com").unwrap(), false);
|
||||||
&FromStr::from_str("https://www.example.com").unwrap(),
|
|
||||||
false,
|
|
||||||
);
|
|
||||||
let response = response.exposed_headers(&headers);
|
let response = response.exposed_headers(&headers);
|
||||||
|
|
||||||
// Build response and check built response header
|
// Build response and check built response header
|
||||||
|
@ -1803,10 +1870,7 @@ mod tests {
|
||||||
#[test]
|
#[test]
|
||||||
fn response_sets_max_age_correctly() {
|
fn response_sets_max_age_correctly() {
|
||||||
let response = Response::new();
|
let response = Response::new();
|
||||||
let response = response.origin(
|
let response = response.origin(&to_origin("https://www.example.com").unwrap(), false);
|
||||||
&FromStr::from_str("https://www.example.com").unwrap(),
|
|
||||||
false,
|
|
||||||
);
|
|
||||||
|
|
||||||
let response = response.max_age(Some(42));
|
let response = response.max_age(Some(42));
|
||||||
|
|
||||||
|
@ -1820,10 +1884,7 @@ mod tests {
|
||||||
#[test]
|
#[test]
|
||||||
fn response_does_not_set_max_age_when_none() {
|
fn response_does_not_set_max_age_when_none() {
|
||||||
let response = Response::new();
|
let response = Response::new();
|
||||||
let response = response.origin(
|
let response = response.origin(&to_origin("https://www.example.com").unwrap(), false);
|
||||||
&FromStr::from_str("https://www.example.com").unwrap(),
|
|
||||||
false,
|
|
||||||
);
|
|
||||||
|
|
||||||
let response = response.max_age(None);
|
let response = response.max_age(None);
|
||||||
|
|
||||||
|
@ -1936,10 +1997,7 @@ mod tests {
|
||||||
.finalize();
|
.finalize();
|
||||||
|
|
||||||
let response = Response::new();
|
let response = Response::new();
|
||||||
let response = response.origin(
|
let response = response.origin(&to_origin("https://www.example.com").unwrap(), false);
|
||||||
&FromStr::from_str("https://www.example.com").unwrap(),
|
|
||||||
false,
|
|
||||||
);
|
|
||||||
let response = response.response(original);
|
let response = response.response(original);
|
||||||
// Check CORS header
|
// Check CORS header
|
||||||
let expected_header = vec!["https://www.example.com"];
|
let expected_header = vec!["https://www.example.com"];
|
||||||
|
@ -2015,7 +2073,7 @@ mod tests {
|
||||||
|
|
||||||
let result = validate(&cors, request.inner()).expect("to not fail");
|
let result = validate(&cors, request.inner()).expect("to not fail");
|
||||||
let expected_result = ValidationResult::Preflight {
|
let expected_result = ValidationResult::Preflight {
|
||||||
origin: FromStr::from_str("https://www.acme.com").unwrap(),
|
origin: to_origin("https://www.acme.com").unwrap(),
|
||||||
// Checks that only a subset of allowed headers are returned
|
// Checks that only a subset of allowed headers are returned
|
||||||
// -- i.e. whatever is requested for
|
// -- i.e. whatever is requested for
|
||||||
headers: Some(FromStr::from_str("Authorization").unwrap()),
|
headers: Some(FromStr::from_str("Authorization").unwrap()),
|
||||||
|
@ -2050,7 +2108,7 @@ mod tests {
|
||||||
|
|
||||||
let result = validate(&cors, request.inner()).expect("to not fail");
|
let result = validate(&cors, request.inner()).expect("to not fail");
|
||||||
let expected_result = ValidationResult::Preflight {
|
let expected_result = ValidationResult::Preflight {
|
||||||
origin: FromStr::from_str("https://www.example.com").unwrap(),
|
origin: to_origin("https://www.example.com").unwrap(),
|
||||||
headers: Some(FromStr::from_str("Authorization").unwrap()),
|
headers: Some(FromStr::from_str("Authorization").unwrap()),
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -2168,7 +2226,7 @@ mod tests {
|
||||||
|
|
||||||
let result = validate(&cors, request.inner()).expect("to not fail");
|
let result = validate(&cors, request.inner()).expect("to not fail");
|
||||||
let expected_result = ValidationResult::Request {
|
let expected_result = ValidationResult::Request {
|
||||||
origin: FromStr::from_str("https://www.acme.com").unwrap(),
|
origin: to_origin("https://www.acme.com").unwrap(),
|
||||||
};
|
};
|
||||||
|
|
||||||
assert_eq!(expected_result, result);
|
assert_eq!(expected_result, result);
|
||||||
|
@ -2187,7 +2245,7 @@ mod tests {
|
||||||
|
|
||||||
let result = validate(&cors, request.inner()).expect("to not fail");
|
let result = validate(&cors, request.inner()).expect("to not fail");
|
||||||
let expected_result = ValidationResult::Request {
|
let expected_result = ValidationResult::Request {
|
||||||
origin: FromStr::from_str("https://www.example.com").unwrap(),
|
origin: to_origin("https://www.example.com").unwrap(),
|
||||||
};
|
};
|
||||||
|
|
||||||
assert_eq!(expected_result, result);
|
assert_eq!(expected_result, result);
|
||||||
|
@ -2243,7 +2301,7 @@ mod tests {
|
||||||
let response = validate_and_build(&cors, request.inner()).expect("to not fail");
|
let response = validate_and_build(&cors, request.inner()).expect("to not fail");
|
||||||
|
|
||||||
let expected_response = Response::new()
|
let expected_response = Response::new()
|
||||||
.origin(&FromStr::from_str("https://www.acme.com/").unwrap(), false)
|
.origin(&to_origin("https://www.acme.com").unwrap(), false)
|
||||||
.headers(&["Authorization"])
|
.headers(&["Authorization"])
|
||||||
.methods(&options.allowed_methods)
|
.methods(&options.allowed_methods)
|
||||||
.credentials(options.allow_credentials)
|
.credentials(options.allow_credentials)
|
||||||
|
@ -2283,7 +2341,7 @@ mod tests {
|
||||||
let response = validate_and_build(&cors, request.inner()).expect("to not fail");
|
let response = validate_and_build(&cors, request.inner()).expect("to not fail");
|
||||||
|
|
||||||
let expected_response = Response::new()
|
let expected_response = Response::new()
|
||||||
.origin(&FromStr::from_str("https://www.acme.com/").unwrap(), true)
|
.origin(&to_origin("https://www.acme.com").unwrap(), true)
|
||||||
.headers(&["Authorization"])
|
.headers(&["Authorization"])
|
||||||
.methods(&options.allowed_methods)
|
.methods(&options.allowed_methods)
|
||||||
.credentials(options.allow_credentials)
|
.credentials(options.allow_credentials)
|
||||||
|
@ -2344,7 +2402,7 @@ mod tests {
|
||||||
|
|
||||||
let response = validate_and_build(&cors, request.inner()).expect("to not fail");
|
let response = validate_and_build(&cors, request.inner()).expect("to not fail");
|
||||||
let expected_response = Response::new()
|
let expected_response = Response::new()
|
||||||
.origin(&FromStr::from_str("https://www.acme.com/").unwrap(), false)
|
.origin(&to_origin("https://www.acme.com").unwrap(), false)
|
||||||
.credentials(options.allow_credentials)
|
.credentials(options.allow_credentials)
|
||||||
.exposed_headers(&["Content-Type", "X-Custom"]);
|
.exposed_headers(&["Content-Type", "X-Custom"]);
|
||||||
|
|
||||||
|
@ -2367,7 +2425,7 @@ mod tests {
|
||||||
|
|
||||||
let response = validate_and_build(&cors, request.inner()).expect("to not fail");
|
let response = validate_and_build(&cors, request.inner()).expect("to not fail");
|
||||||
let expected_response = Response::new()
|
let expected_response = Response::new()
|
||||||
.origin(&FromStr::from_str("https://www.acme.com/").unwrap(), true)
|
.origin(&to_origin("https://www.acme.com").unwrap(), true)
|
||||||
.credentials(options.allow_credentials)
|
.credentials(options.allow_credentials)
|
||||||
.exposed_headers(&["Content-Type", "X-Custom"]);
|
.exposed_headers(&["Content-Type", "X-Custom"]);
|
||||||
|
|
||||||
|
|
|
@ -22,8 +22,7 @@ fn panicking_route() {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn make_cors() -> Cors {
|
fn make_cors() -> Cors {
|
||||||
let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.acme.com"]);
|
let allowed_origins = AllowedOrigins::some(&["https://www.acme.com"]);
|
||||||
assert!(failed_origins.is_empty());
|
|
||||||
|
|
||||||
CorsOptions {
|
CorsOptions {
|
||||||
allowed_origins: allowed_origins,
|
allowed_origins: allowed_origins,
|
||||||
|
|
|
@ -60,8 +60,7 @@ fn state<'r>(cors: cors::Guard<'r>, _state: State<'r, SomeState>) -> cors::Respo
|
||||||
}
|
}
|
||||||
|
|
||||||
fn make_cors() -> cors::Cors {
|
fn make_cors() -> cors::Cors {
|
||||||
let (allowed_origins, failed_origins) = cors::AllowedOrigins::some(&["https://www.acme.com"]);
|
let allowed_origins = cors::AllowedOrigins::some(&["https://www.acme.com"]);
|
||||||
assert!(failed_origins.is_empty());
|
|
||||||
|
|
||||||
cors::CorsOptions {
|
cors::CorsOptions {
|
||||||
allowed_origins: allowed_origins,
|
allowed_origins: allowed_origins,
|
||||||
|
|
|
@ -55,7 +55,7 @@ fn request_headers_round_trip_smoke_test() {
|
||||||
.body()
|
.body()
|
||||||
.and_then(|body| body.into_string())
|
.and_then(|body| body.into_string())
|
||||||
.expect("Non-empty body");
|
.expect("Non-empty body");
|
||||||
let expected_body = r#"https://foo.bar.xyz/
|
let expected_body = r#"https://foo.bar.xyz
|
||||||
GET
|
GET
|
||||||
X-Ping, accept-language"#;
|
X-Ping, accept-language"#;
|
||||||
assert_eq!(expected_body, body_str);
|
assert_eq!(expected_body, body_str);
|
||||||
|
|
|
@ -66,8 +66,7 @@ fn borrow<'r>(options: State<'r, Cors>, test_state: State<'r, TestState>) -> imp
|
||||||
}
|
}
|
||||||
|
|
||||||
fn make_cors_options() -> CorsOptions {
|
fn make_cors_options() -> CorsOptions {
|
||||||
let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.acme.com"]);
|
let allowed_origins = AllowedOrigins::some(&["https://www.acme.com"]);
|
||||||
assert!(failed_origins.is_empty());
|
|
||||||
|
|
||||||
CorsOptions {
|
CorsOptions {
|
||||||
allowed_origins: allowed_origins,
|
allowed_origins: allowed_origins,
|
||||||
|
@ -79,8 +78,7 @@ fn make_cors_options() -> CorsOptions {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn make_different_cors_options() -> CorsOptions {
|
fn make_different_cors_options() -> CorsOptions {
|
||||||
let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.example.com"]);
|
let allowed_origins = AllowedOrigins::some(&["https://www.example.com"]);
|
||||||
assert!(failed_origins.is_empty());
|
|
||||||
|
|
||||||
CorsOptions {
|
CorsOptions {
|
||||||
allowed_origins: allowed_origins,
|
allowed_origins: allowed_origins,
|
||||||
|
|
|
@ -40,8 +40,7 @@ fn ping_options<'r>() -> impl Responder<'r> {
|
||||||
|
|
||||||
/// Returns the "application wide" Cors struct
|
/// Returns the "application wide" Cors struct
|
||||||
fn cors_options() -> CorsOptions {
|
fn cors_options() -> CorsOptions {
|
||||||
let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.acme.com"]);
|
let allowed_origins = AllowedOrigins::some(&["https://www.acme.com"]);
|
||||||
assert!(failed_origins.is_empty());
|
|
||||||
|
|
||||||
// You can also deserialize this
|
// You can also deserialize this
|
||||||
rocket_cors::CorsOptions {
|
rocket_cors::CorsOptions {
|
||||||
|
|
Loading…
Reference in New Issue