Refactor Origin validation

This commit is contained in:
Yong Wen Chua 2018-12-19 11:08:30 +08:00
parent f4b30501df
commit d7e5153e27
No known key found for this signature in database
GPG Key ID: EDC57EEC439CF10B
14 changed files with 138 additions and 146 deletions

View File

@ -18,7 +18,7 @@ travis-ci = { repository = "lawliet89/rocket_cors" }
default = ["serialization"] default = ["serialization"]
# Serialization and deserialization support for settings # Serialization and deserialization support for settings
serialization = ["serde", "serde_derive", "unicase_serde", "url_serde"] serialization = ["serde", "serde_derive", "unicase_serde"]
[dependencies] [dependencies]
rocket = "0.4.0" rocket = "0.4.0"
@ -30,7 +30,6 @@ url = "1.7.2"
serde = { version = "1.0", optional = true } serde = { version = "1.0", optional = true }
serde_derive = { version = "1.0", optional = true } serde_derive = { version = "1.0", optional = true }
unicase_serde = { version = "0.1.0", optional = true } unicase_serde = { version = "0.1.0", optional = true }
url_serde = { version = "0.2.0", optional = true }
[dev-dependencies] [dev-dependencies]
hyper = "0.10" hyper = "0.10"

View File

@ -12,8 +12,7 @@ fn cors<'a>() -> &'a str {
} }
fn main() -> Result<(), Error> { fn main() -> Result<(), Error> {
let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.acme.com"]); let allowed_origins = AllowedOrigins::some(&["https://www.acme.com"]);
assert!(failed_origins.is_empty());
// You can also deserialize this // You can also deserialize this
let cors = rocket_cors::CorsOptions { let cors = rocket_cors::CorsOptions {

View File

@ -36,8 +36,7 @@ fn manual(cors: Guard<'_>) -> Responder<'_, &str> {
} }
fn main() -> Result<(), Error> { fn main() -> Result<(), Error> {
let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.acme.com"]); let allowed_origins = AllowedOrigins::some(&["https://www.acme.com"]);
assert!(failed_origins.is_empty());
// You can also deserialize this // You can also deserialize this
let cors = rocket_cors::CorsOptions { let cors = rocket_cors::CorsOptions {

View File

@ -13,8 +13,7 @@ fn main() {
// The default demonstrates the "All" serialization of several of the settings // The default demonstrates the "All" serialization of several of the settings
let default: CorsOptions = Default::default(); let default: CorsOptions = Default::default();
let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.acme.com"]); let allowed_origins = AllowedOrigins::some(&["https://www.acme.com"]);
assert!(failed_origins.is_empty());
let options = cors::CorsOptions { let options = cors::CorsOptions {
allowed_origins: allowed_origins, allowed_origins: allowed_origins,

View File

@ -59,8 +59,7 @@ fn owned_options<'r>() -> impl Responder<'r> {
} }
fn cors_options() -> CorsOptions { fn cors_options() -> CorsOptions {
let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.acme.com"]); let allowed_origins = AllowedOrigins::some(&["https://www.acme.com"]);
assert!(failed_origins.is_empty());
// You can also deserialize this // You can also deserialize this
rocket_cors::CorsOptions { rocket_cors::CorsOptions {

View File

@ -36,8 +36,7 @@ fn ping_options<'r>() -> impl Responder<'r> {
/// Returns the "application wide" Cors struct /// Returns the "application wide" Cors struct
fn cors_options() -> CorsOptions { fn cors_options() -> CorsOptions {
let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.acme.com"]); let allowed_origins = AllowedOrigins::some(&["https://www.acme.com"]);
assert!(failed_origins.is_empty());
// You can also deserialize this // You can also deserialize this
rocket_cors::CorsOptions { rocket_cors::CorsOptions {

View File

@ -53,7 +53,7 @@ fn on_response_wrapper(
// Not a CORS request // Not a CORS request
return Ok(()); return Ok(());
} }
Some(origin) => origin, Some(origin) => crate::to_origin(origin)?,
}; };
let result = request.local_cache(|| unreachable!("This should not be executed so late")); let result = request.local_cache(|| unreachable!("This should not be executed so late"));
@ -140,8 +140,7 @@ mod tests {
const CORS_ROOT: &'static str = "/my_cors"; const CORS_ROOT: &'static str = "/my_cors";
fn make_cors_options() -> Cors { fn make_cors_options() -> Cors {
let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.acme.com"]); let allowed_origins = AllowedOrigins::some(&["https://www.acme.com"]);
assert!(failed_origins.is_empty());
CorsOptions { CorsOptions {
allowed_origins, allowed_origins,

View File

@ -11,12 +11,9 @@ use rocket::{self, Outcome};
#[cfg(feature = "serialization")] #[cfg(feature = "serialization")]
use serde_derive::{Deserialize, Serialize}; use serde_derive::{Deserialize, Serialize};
use unicase::UniCase; use unicase::UniCase;
use url;
#[cfg(feature = "serialization")] #[cfg(feature = "serialization")]
use unicase_serde; use unicase_serde;
#[cfg(feature = "serialization")]
use url_serde;
/// A case insensitive header name /// A case insensitive header name
#[derive(Eq, PartialEq, Clone, Debug, Hash)] #[derive(Eq, PartialEq, Clone, Debug, Hash)]
@ -62,54 +59,55 @@ impl FromStr for HeaderFieldName {
/// A set of case insensitive header names /// A set of case insensitive header names
pub type HeaderFieldNamesSet = HashSet<HeaderFieldName>; pub type HeaderFieldNamesSet = HashSet<HeaderFieldName>;
/// A wrapped `url::Url` to allow for deserialization /// The `Origin` request header used in CORS
///
/// You can use this as a rocket [Request Guard](https://rocket.rs/guide/requests/#request-guards)
/// to ensure that `Origin` is passed in correctly.
#[derive(Eq, PartialEq, Clone, Hash, Debug)] #[derive(Eq, PartialEq, Clone, Hash, Debug)]
#[cfg_attr(feature = "serialization", derive(Serialize, Deserialize))] #[cfg_attr(feature = "serialization", derive(Serialize, Deserialize))]
pub struct Url(#[cfg_attr(feature = "serialization", serde(with = "url_serde"))] url::Url); pub struct Origin(pub String);
impl fmt::Display for Url { impl FromStr for Origin {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { type Err = !;
self.0.fmt(f)
fn from_str(input: &str) -> Result<Self, Self::Err> {
Ok(Origin(input.to_string()))
} }
} }
impl Deref for Url { impl Deref for Origin {
type Target = url::Url; type Target = str;
fn deref(&self) -> &Self::Target { fn deref(&self) -> &Self::Target {
&self.0 &self.0
} }
} }
impl FromStr for Url { impl AsRef<str> for Origin {
type Err = url::ParseError; fn as_ref(&self) -> &str {
self
fn from_str(input: &str) -> Result<Self, Self::Err> {
let url = url::Url::from_str(input)?;
Ok(Url(url))
} }
} }
impl<'a, 'r> FromRequest<'a, 'r> for Url { impl fmt::Display for Origin {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
self.as_ref().fmt(f)
}
}
impl<'a, 'r> FromRequest<'a, 'r> for Origin {
type Error = crate::Error; type Error = crate::Error;
fn from_request(request: &'a rocket::Request<'r>) -> request::Outcome<Self, crate::Error> { fn from_request(request: &'a rocket::Request<'r>) -> request::Outcome<Self, crate::Error> {
match request.headers().get_one("Origin") { match request.headers().get_one("Origin") {
Some(origin) => match Self::from_str(origin) { Some(origin) => {
Ok(origin) => Outcome::Success(origin), let Ok(origin) = Self::from_str(origin);
Err(e) => Outcome::Failure((Status::BadRequest, crate::Error::BadOrigin(e))), Outcome::Success(origin)
}, }
None => Outcome::Forward(()), None => Outcome::Forward(()),
} }
} }
} }
/// The `Origin` request header used in CORS
///
/// You can use this as a rocket [Request Guard](https://rocket.rs/guide/requests/#request-guards)
/// to ensure that `Origin` is passed in correctly.
pub type Origin = Url;
/// The `Access-Control-Request-Method` request header /// The `Access-Control-Request-Method` request header
/// ///
/// You can use this as a rocket [Request Guard](https://rocket.rs/guide/requests/#request-guards) /// You can use this as a rocket [Request Guard](https://rocket.rs/guide/requests/#request-guards)
@ -201,17 +199,17 @@ mod tests {
#[test] #[test]
fn origin_header_conversion() { fn origin_header_conversion() {
let url = "https://foo.bar.xyz"; let url = "https://foo.bar.xyz";
let parsed = not_err!(Origin::from_str(url)); let Ok(parsed) = Origin::from_str(url);
let expected = not_err!(Url::from_str(url)); assert_eq!(parsed.as_ref(), url);
assert_eq!(parsed, expected);
let url = "https://foo.bar.xyz/path/somewhere"; // this should never really be used let url = "https://foo.bar.xyz/path/somewhere"; // this should never really be used
let parsed = not_err!(Origin::from_str(url)); let Ok(parsed) = Origin::from_str(url);
let expected = not_err!(Url::from_str(url)); assert_eq!(parsed.as_ref(), url);
assert_eq!(parsed, expected);
// Validation is not done now
let url = "invalid_url"; let url = "invalid_url";
let _ = is_err!(Origin::from_str(url)); let Ok(parsed) = Origin::from_str(url);
assert_eq!(parsed.as_ref(), url);
} }
#[test] #[test]
@ -225,7 +223,7 @@ mod tests {
let outcome: request::Outcome<Origin, crate::Error> = let outcome: request::Outcome<Origin, crate::Error> =
FromRequest::from_request(request.inner()); FromRequest::from_request(request.inner());
let parsed_header = assert_matches!(outcome, Outcome::Success(s), s); let parsed_header = assert_matches!(outcome, Outcome::Success(s), s);
assert_eq!("https://www.example.com/", parsed_header.as_str()); assert_eq!("https://www.example.com", parsed_header.as_ref());
} }
#[test] #[test]

View File

@ -267,6 +267,8 @@ See the [example](https://github.com/lawliet89/rocket_cors/blob/master/examples/
intra_doc_link_resolution_failure intra_doc_link_resolution_failure
)] )]
#![doc(test(attr(allow(unused_variables), deny(warnings))))] #![doc(test(attr(allow(unused_variables), deny(warnings))))]
#![feature(never_type)]
#![feature(exhaustive_patterns)]
#[cfg(test)] #[cfg(test)]
#[macro_use] #[macro_use]
@ -276,7 +278,7 @@ mod fairing;
pub mod headers; pub mod headers;
use std::borrow::Cow; use std::borrow::Cow;
use std::collections::{HashMap, HashSet}; use std::collections::HashSet;
use std::error; use std::error;
use std::fmt; use std::fmt;
use std::marker::PhantomData; use std::marker::PhantomData;
@ -293,7 +295,7 @@ use serde_derive::{Deserialize, Serialize};
use crate::headers::{ use crate::headers::{
AccessControlRequestHeaders, AccessControlRequestMethod, HeaderFieldName, HeaderFieldNamesSet, AccessControlRequestHeaders, AccessControlRequestMethod, HeaderFieldName, HeaderFieldNamesSet,
Origin, Url, Origin,
}; };
/// Errors during operations /// Errors during operations
@ -316,7 +318,7 @@ pub enum Error {
/// The request header `Access-Control-Request-Headers` is required but is missing. /// The request header `Access-Control-Request-Headers` is required but is missing.
MissingRequestHeaders, MissingRequestHeaders,
/// Origin is not allowed to make this request /// Origin is not allowed to make this request
OriginNotAllowed(String), OriginNotAllowed(url::Origin),
/// Requested method is not allowed /// Requested method is not allowed
MethodNotAllowed(String), MethodNotAllowed(String),
/// One or more headers requested are not allowed /// One or more headers requested are not allowed
@ -365,7 +367,7 @@ impl fmt::Display for Error {
"The request header `Access-Control-Request-Headers` \ "The request header `Access-Control-Request-Headers` \
is required but is missing") is required but is missing")
} }
Error::OriginNotAllowed(origin) => write!(f, "Origin '{}' is not allowed to request", &origin), Error::OriginNotAllowed(origin) => write!(f, "Origin '{}' is not allowed to request", origin.ascii_serialization()),
Error::MethodNotAllowed(method) => write!(f, "Method '{}' is not allowed", &method), Error::MethodNotAllowed(method) => write!(f, "Method '{}' is not allowed", &method),
Error::HeadersNotAllowed => write!(f, "Headers are not allowed"), Error::HeadersNotAllowed => write!(f, "Headers are not allowed"),
Error::CredentialsWithWildcardOrigin => { write!(f, Error::CredentialsWithWildcardOrigin => { write!(f,
@ -398,6 +400,12 @@ impl<'r> response::Responder<'r> for Error {
} }
} }
impl From<url::ParseError> for Error {
fn from(error: url::ParseError) -> Self {
Error::BadOrigin(error)
}
}
/// An enum signifying that some of type T is allowed, or `All` (everything is allowed). /// An enum signifying that some of type T is allowed, or `All` (everything is allowed).
/// ///
/// `Default` is implemented for this enum and is `All`. /// `Default` is implemented for this enum and is `All`.
@ -523,31 +531,23 @@ mod method_serde {
/// use rocket_cors::AllowedOrigins; /// use rocket_cors::AllowedOrigins;
/// ///
/// let all_origins = AllowedOrigins::all(); /// let all_origins = AllowedOrigins::all();
/// let (some_origins, failed_origins) = AllowedOrigins::some(&["https://www.acme.com"]); /// let some_origins = AllowedOrigins::some(&["https://www.acme.com"]);
/// assert!(failed_origins.is_empty());
/// ``` /// ```
pub type AllowedOrigins = AllOrSome<HashSet<Origin>>; pub type AllowedOrigins = AllOrSome<HashSet<Origin>>;
impl AllowedOrigins { impl AllowedOrigins {
/// Allows some origins /// Allows some origins
/// ///
/// Returns a tuple where the first element is the struct `AllowedOrigins`, /// Validation is not performed at this stage, but at a later stage.
/// and the second element pub fn some(urls: &[&str]) -> Self {
/// is a map of strings which failed to parse into URLs and their associated parse errors. AllOrSome::Some(
pub fn some(urls: &[&str]) -> (Self, HashMap<String, url::ParseError>) { urls.iter()
let (ok_set, error_map): (Vec<_>, Vec<_>) = urls .map(|s| {
.iter() let Ok(s) = FromStr::from_str(s);
.map(|s| (s.to_string(), Url::from_str(s))) s
.partition(|&(_, ref r)| r.is_ok()); })
.collect(),
let error_map = error_map )
.into_iter()
.map(|(s, r)| (s.to_string(), r.unwrap_err()))
.collect();
let ok_set = ok_set.into_iter().map(|(_, r)| r.unwrap()).collect();
(AllOrSome::Some(ok_set), error_map)
} }
/// Allows all origins /// Allows all origins
@ -646,7 +646,7 @@ impl AllowedHeaders {
/// { /// {
/// "allowed_origins": { /// "allowed_origins": {
/// "Some": [ /// "Some": [
/// "https://www.acme.com/" /// "https://www.acme.com"
/// ] /// ]
/// }, /// },
/// "allowed_methods": [ /// "allowed_methods": [
@ -819,9 +819,7 @@ impl CorsOptions {
0 0
} }
/// Validates if any of the settings are disallowed or incorrect /// Validates if any of the settings are disallowed, incorrect, or illegal
///
/// This is run during initial Fairing attachment
pub fn validate(&self) -> Result<(), Error> { pub fn validate(&self) -> Result<(), Error> {
if self.allowed_origins.is_all() && self.send_wildcard && self.allow_credentials { if self.allowed_origins.is_all() && self.send_wildcard && self.allow_credentials {
Err(Error::CredentialsWithWildcardOrigin)?; Err(Error::CredentialsWithWildcardOrigin)?;
@ -844,7 +842,7 @@ impl CorsOptions {
/// This struct can be created by using [`CorsOptions::to_cors`] or [`Cors::from_options`]. /// This struct can be created by using [`CorsOptions::to_cors`] or [`Cors::from_options`].
#[derive(Clone, Debug, Eq, PartialEq)] #[derive(Clone, Debug, Eq, PartialEq)]
pub struct Cors { pub struct Cors {
pub(crate) allowed_origins: AllowedOrigins, pub(crate) allowed_origins: AllOrSome<HashSet<url::Origin>>,
pub(crate) allowed_methods: AllowedMethods, pub(crate) allowed_methods: AllowedMethods,
pub(crate) allowed_headers: AllOrSome<HashSet<HeaderFieldName>>, pub(crate) allowed_headers: AllOrSome<HashSet<HeaderFieldName>>,
pub(crate) allow_credentials: bool, pub(crate) allow_credentials: bool,
@ -859,8 +857,11 @@ impl Cors {
/// Create a `Cors` struct from a [`CorsOptions`] /// Create a `Cors` struct from a [`CorsOptions`]
pub fn from_options(options: &CorsOptions) -> Result<Self, Error> { pub fn from_options(options: &CorsOptions) -> Result<Self, Error> {
options.validate()?; options.validate()?;
let allowed_origins = parse_origins(&options.allowed_origins)?;
Ok(Cors { Ok(Cors {
allowed_origins: options.allowed_origins.clone(), allowed_origins,
allowed_methods: options.allowed_methods.clone(), allowed_methods: options.allowed_methods.clone(),
allowed_headers: options.allowed_headers.clone(), allowed_headers: options.allowed_headers.clone(),
allow_credentials: options.allow_credentials, allow_credentials: options.allow_credentials,
@ -929,7 +930,7 @@ impl Cors {
/// You can get this struct by using `Cors::validate_request` in an ad-hoc manner. /// You can get this struct by using `Cors::validate_request` in an ad-hoc manner.
#[derive(Eq, PartialEq, Debug)] #[derive(Eq, PartialEq, Debug)]
pub(crate) struct Response { pub(crate) struct Response {
allow_origin: Option<AllOrSome<Url>>, allow_origin: Option<AllOrSome<url::Origin>>,
allow_methods: HashSet<Method>, allow_methods: HashSet<Method>,
allow_headers: HeaderFieldNamesSet, allow_headers: HeaderFieldNamesSet,
allow_credentials: bool, allow_credentials: bool,
@ -953,7 +954,7 @@ impl Response {
} }
/// Consumes the `Response` and return an altered response with origin and `vary_origin` set /// Consumes the `Response` and return an altered response with origin and `vary_origin` set
fn origin(mut self, origin: &Url, vary_origin: bool) -> Self { fn origin(mut self, origin: &url::Origin, vary_origin: bool) -> Self {
self.allow_origin = Some(AllOrSome::Some(origin.clone())); self.allow_origin = Some(AllOrSome::Some(origin.clone()));
self.vary_origin = vary_origin; self.vary_origin = vary_origin;
self self
@ -1028,11 +1029,9 @@ impl Response {
Some(ref origin) => origin, Some(ref origin) => origin,
}; };
// Origin should be ASCII serialized
// c.f. https://html.spec.whatwg.org/multipage/origin.html#ascii-serialisation-of-an-origin
let origin = match *origin { let origin = match *origin {
AllOrSome::All => "*".to_string(), AllOrSome::All => "*".to_string(),
AllOrSome::Some(ref origin) => origin.origin().ascii_serialization(), AllOrSome::Some(ref origin) => origin.ascii_serialization(),
}; };
let _ = response.set_raw_header("Access-Control-Allow-Origin", origin); let _ = response.set_raw_header("Access-Control-Allow-Origin", origin);
@ -1261,11 +1260,29 @@ enum ValidationResult {
None, None,
/// Successful preflight request /// Successful preflight request
Preflight { Preflight {
origin: Origin, origin: url::Origin,
headers: Option<AccessControlRequestHeaders>, headers: Option<AccessControlRequestHeaders>,
}, },
/// Successful actual request /// Successful actual request
Request { origin: Origin }, Request { origin: url::Origin },
}
/// Convert a str to Origin
fn to_origin<S: AsRef<str>>(origin: S) -> Result<url::Origin, Error> {
// What to do about Opaque origins?
Ok(url::Url::parse(origin.as_ref())?.origin())
}
/// Parse and process allowed origins
fn parse_origins(origins: &AllowedOrigins) -> Result<AllOrSome<HashSet<url::Origin>>, Error> {
match origins {
AllOrSome::All => Ok(AllOrSome::All),
AllOrSome::Some(ref origins) => {
let parsed: Result<HashSet<url::Origin>, Error> =
origins.iter().map(to_origin).collect();
Ok(AllOrSome::Some(parsed?))
}
}
} }
/// Validates a request for CORS and returns a CORS Response /// Validates a request for CORS and returns a CORS Response
@ -1291,7 +1308,7 @@ fn validate(options: &Cors, request: &Request<'_>) -> Result<ValidationResult, E
// Not a CORS request // Not a CORS request
return Ok(ValidationResult::None); return Ok(ValidationResult::None);
} }
Some(origin) => origin, Some(origin) => to_origin(origin)?,
}; };
// Check if the request verb is an OPTION or something else // Check if the request verb is an OPTION or something else
@ -1313,8 +1330,8 @@ fn validate(options: &Cors, request: &Request<'_>) -> Result<ValidationResult, E
/// check if the requested origin is allowed. /// check if the requested origin is allowed.
/// Useful for pre-flight and during requests /// Useful for pre-flight and during requests
fn validate_origin( fn validate_origin(
origin: &Origin, origin: &url::Origin,
allowed_origins: &AllowedOrigins, allowed_origins: &AllOrSome<HashSet<url::Origin>>,
) -> Result<(), Error> { ) -> Result<(), Error> {
match *allowed_origins { match *allowed_origins {
// Always matching is acceptable since the list of origins can be unbounded. // Always matching is acceptable since the list of origins can be unbounded.
@ -1322,7 +1339,7 @@ fn validate_origin(
AllOrSome::Some(ref allowed_origins) => allowed_origins AllOrSome::Some(ref allowed_origins) => allowed_origins
.get(origin) .get(origin)
.and_then(|_| Some(())) .and_then(|_| Some(()))
.ok_or_else(|| Error::OriginNotAllowed(origin.to_string())), .ok_or_else(|| Error::OriginNotAllowed(origin.clone())),
} }
} }
@ -1392,7 +1409,7 @@ fn request_headers(request: &Request<'_>) -> Result<Option<AccessControlRequestH
/// and [Fetch specification](https://fetch.spec.whatwg.org/#cors-preflight-fetch) /// and [Fetch specification](https://fetch.spec.whatwg.org/#cors-preflight-fetch)
fn preflight_validate( fn preflight_validate(
options: &Cors, options: &Cors,
origin: &Origin, origin: &url::Origin,
method: &Option<AccessControlRequestMethod>, method: &Option<AccessControlRequestMethod>,
headers: &Option<AccessControlRequestHeaders>, headers: &Option<AccessControlRequestHeaders>,
) -> Result<(), Error> { ) -> Result<(), Error> {
@ -1440,7 +1457,7 @@ fn preflight_validate(
/// and [Fetch specification](https://fetch.spec.whatwg.org/#cors-preflight-fetch). /// and [Fetch specification](https://fetch.spec.whatwg.org/#cors-preflight-fetch).
fn preflight_response( fn preflight_response(
options: &Cors, options: &Cors,
origin: &Origin, origin: &url::Origin,
headers: Option<&AccessControlRequestHeaders>, headers: Option<&AccessControlRequestHeaders>,
) -> Response { ) -> Response {
let response = Response::new(); let response = Response::new();
@ -1511,7 +1528,7 @@ fn preflight_response(
/// This implementation references the /// This implementation references the
/// [W3C recommendation](https://www.w3.org/TR/cors/#resource-requests) /// [W3C recommendation](https://www.w3.org/TR/cors/#resource-requests)
/// and [Fetch specification](https://fetch.spec.whatwg.org/#cors-preflight-fetch). /// and [Fetch specification](https://fetch.spec.whatwg.org/#cors-preflight-fetch).
fn actual_request_validate(options: &Cors, origin: &Origin) -> Result<(), Error> { fn actual_request_validate(options: &Cors, origin: &url::Origin) -> Result<(), Error> {
// Note: All header parse failures are dealt with in the `FromRequest` trait implementation // Note: All header parse failures are dealt with in the `FromRequest` trait implementation
// 2. If the value of the Origin header is not a case-sensitive match for any of the values // 2. If the value of the Origin header is not a case-sensitive match for any of the values
@ -1528,7 +1545,7 @@ fn actual_request_validate(options: &Cors, origin: &Origin) -> Result<(), Error>
/// This implementation references the /// This implementation references the
/// [W3C recommendation](https://www.w3.org/TR/cors/#resource-requests) /// [W3C recommendation](https://www.w3.org/TR/cors/#resource-requests)
/// and [Fetch specification](https://fetch.spec.whatwg.org/#cors-preflight-fetch) /// and [Fetch specification](https://fetch.spec.whatwg.org/#cors-preflight-fetch)
fn actual_request_response(options: &Cors, origin: &Origin) -> Response { fn actual_request_response(options: &Cors, origin: &url::Origin) -> Response {
let response = Response::new(); let response = Response::new();
// 3. If the resource supports credentials add a single Access-Control-Allow-Origin header, // 3. If the resource supports credentials add a single Access-Control-Allow-Origin header,
@ -1628,8 +1645,7 @@ mod tests {
use crate::http::Method; use crate::http::Method;
fn make_cors_options() -> CorsOptions { fn make_cors_options() -> CorsOptions {
let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.acme.com"]); let allowed_origins = AllowedOrigins::some(&["https://www.acme.com"]);
assert!(failed_origins.is_empty());
CorsOptions { CorsOptions {
allowed_origins, allowed_origins,
@ -1689,7 +1705,8 @@ mod tests {
#[test] #[test]
fn validate_origin_allows_all_origins() { fn validate_origin_allows_all_origins() {
let url = "https://www.example.com"; let url = "https://www.example.com";
let origin = Origin::from_str(url).unwrap(); let Ok(origin) = Origin::from_str(url);
let origin = not_err!(to_origin(&origin));
let allowed_origins = AllOrSome::All; let allowed_origins = AllOrSome::All;
not_err!(validate_origin(&origin, &allowed_origins)); not_err!(validate_origin(&origin, &allowed_origins));
@ -1698,9 +1715,11 @@ mod tests {
#[test] #[test]
fn validate_origin_allows_origin() { fn validate_origin_allows_origin() {
let url = "https://www.example.com"; let url = "https://www.example.com";
let origin = Origin::from_str(url).unwrap(); let Ok(origin) = Origin::from_str(url);
let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.example.com"]); let origin = not_err!(to_origin(&origin));
assert!(failed_origins.is_empty()); let allowed_origins = not_err!(parse_origins(&AllowedOrigins::some(&[
"https://www.example.com"
])));
not_err!(validate_origin(&origin, &allowed_origins)); not_err!(validate_origin(&origin, &allowed_origins));
} }
@ -1709,9 +1728,11 @@ mod tests {
#[should_panic(expected = "OriginNotAllowed")] #[should_panic(expected = "OriginNotAllowed")]
fn validate_origin_rejects_invalid_origin() { fn validate_origin_rejects_invalid_origin() {
let url = "https://www.acme.com"; let url = "https://www.acme.com";
let origin = Origin::from_str(url).unwrap(); let Ok(origin) = Origin::from_str(url);
let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.example.com"]); let origin = not_err!(to_origin(&origin));
assert!(failed_origins.is_empty()); let allowed_origins = not_err!(parse_origins(&AllowedOrigins::some(&[
"https://www.example.com"
])));
validate_origin(&origin, &allowed_origins).unwrap(); validate_origin(&origin, &allowed_origins).unwrap();
} }
@ -1719,10 +1740,7 @@ mod tests {
#[test] #[test]
fn response_sets_allow_origin_without_vary_correctly() { fn response_sets_allow_origin_without_vary_correctly() {
let response = Response::new(); let response = Response::new();
let response = response.origin( let response = response.origin(&to_origin("https://www.example.com").unwrap(), false);
&FromStr::from_str("https://www.example.com").unwrap(),
false,
);
// Build response and check built response header // Build response and check built response header
let expected_header = vec!["https://www.example.com"]; let expected_header = vec!["https://www.example.com"];
@ -1739,8 +1757,7 @@ mod tests {
#[test] #[test]
fn response_sets_allow_origin_with_vary_correctly() { fn response_sets_allow_origin_with_vary_correctly() {
let response = Response::new(); let response = Response::new();
let response = let response = response.origin(&to_origin("https://www.example.com").unwrap(), true);
response.origin(&FromStr::from_str("https://www.example.com").unwrap(), true);
// Build response and check built response header // Build response and check built response header
let expected_header = vec!["https://www.example.com"]; let expected_header = vec!["https://www.example.com"];
@ -1770,9 +1787,10 @@ mod tests {
#[test] #[test]
fn response_sets_allow_origin_with_ascii_serialization() { fn response_sets_allow_origin_with_ascii_serialization() {
let response = Response::new(); let response = Response::new();
let response = response.origin(&FromStr::from_str("https://аpple.com").unwrap(), false); let response = response.origin(&to_origin("https://аpple.com").unwrap(), false);
// Build response and check built response header // Build response and check built response header
// This is "punycode"
let expected_header = vec!["https://xn--pple-43d.com"]; let expected_header = vec!["https://xn--pple-43d.com"];
let response = response.response(response::Response::new()); let response = response.response(response::Response::new());
let actual_header: Vec<_> = response let actual_header: Vec<_> = response
@ -1786,10 +1804,7 @@ mod tests {
fn response_sets_exposed_headers_correctly() { fn response_sets_exposed_headers_correctly() {
let headers = vec!["Bar", "Baz", "Foo"]; let headers = vec!["Bar", "Baz", "Foo"];
let response = Response::new(); let response = Response::new();
let response = response.origin( let response = response.origin(&to_origin("https://www.example.com").unwrap(), false);
&FromStr::from_str("https://www.example.com").unwrap(),
false,
);
let response = response.exposed_headers(&headers); let response = response.exposed_headers(&headers);
// Build response and check built response header // Build response and check built response header
@ -1811,10 +1826,7 @@ mod tests {
#[test] #[test]
fn response_sets_max_age_correctly() { fn response_sets_max_age_correctly() {
let response = Response::new(); let response = Response::new();
let response = response.origin( let response = response.origin(&to_origin("https://www.example.com").unwrap(), false);
&FromStr::from_str("https://www.example.com").unwrap(),
false,
);
let response = response.max_age(Some(42)); let response = response.max_age(Some(42));
@ -1828,10 +1840,7 @@ mod tests {
#[test] #[test]
fn response_does_not_set_max_age_when_none() { fn response_does_not_set_max_age_when_none() {
let response = Response::new(); let response = Response::new();
let response = response.origin( let response = response.origin(&to_origin("https://www.example.com").unwrap(), false);
&FromStr::from_str("https://www.example.com").unwrap(),
false,
);
let response = response.max_age(None); let response = response.max_age(None);
@ -1944,10 +1953,7 @@ mod tests {
.finalize(); .finalize();
let response = Response::new(); let response = Response::new();
let response = response.origin( let response = response.origin(&to_origin("https://www.example.com").unwrap(), false);
&FromStr::from_str("https://www.example.com").unwrap(),
false,
);
let response = response.response(original); let response = response.response(original);
// Check CORS header // Check CORS header
let expected_header = vec!["https://www.example.com"]; let expected_header = vec!["https://www.example.com"];
@ -2023,7 +2029,7 @@ mod tests {
let result = validate(&cors, request.inner()).expect("to not fail"); let result = validate(&cors, request.inner()).expect("to not fail");
let expected_result = ValidationResult::Preflight { let expected_result = ValidationResult::Preflight {
origin: FromStr::from_str("https://www.acme.com").unwrap(), origin: to_origin("https://www.acme.com").unwrap(),
// Checks that only a subset of allowed headers are returned // Checks that only a subset of allowed headers are returned
// -- i.e. whatever is requested for // -- i.e. whatever is requested for
headers: Some(FromStr::from_str("Authorization").unwrap()), headers: Some(FromStr::from_str("Authorization").unwrap()),
@ -2058,7 +2064,7 @@ mod tests {
let result = validate(&cors, request.inner()).expect("to not fail"); let result = validate(&cors, request.inner()).expect("to not fail");
let expected_result = ValidationResult::Preflight { let expected_result = ValidationResult::Preflight {
origin: FromStr::from_str("https://www.example.com").unwrap(), origin: to_origin("https://www.example.com").unwrap(),
headers: Some(FromStr::from_str("Authorization").unwrap()), headers: Some(FromStr::from_str("Authorization").unwrap()),
}; };
@ -2176,7 +2182,7 @@ mod tests {
let result = validate(&cors, request.inner()).expect("to not fail"); let result = validate(&cors, request.inner()).expect("to not fail");
let expected_result = ValidationResult::Request { let expected_result = ValidationResult::Request {
origin: FromStr::from_str("https://www.acme.com").unwrap(), origin: to_origin("https://www.acme.com").unwrap(),
}; };
assert_eq!(expected_result, result); assert_eq!(expected_result, result);
@ -2195,7 +2201,7 @@ mod tests {
let result = validate(&cors, request.inner()).expect("to not fail"); let result = validate(&cors, request.inner()).expect("to not fail");
let expected_result = ValidationResult::Request { let expected_result = ValidationResult::Request {
origin: FromStr::from_str("https://www.example.com").unwrap(), origin: to_origin("https://www.example.com").unwrap(),
}; };
assert_eq!(expected_result, result); assert_eq!(expected_result, result);
@ -2251,7 +2257,7 @@ mod tests {
let response = validate_and_build(&cors, request.inner()).expect("to not fail"); let response = validate_and_build(&cors, request.inner()).expect("to not fail");
let expected_response = Response::new() let expected_response = Response::new()
.origin(&FromStr::from_str("https://www.acme.com/").unwrap(), false) .origin(&to_origin("https://www.acme.com").unwrap(), false)
.headers(&["Authorization"]) .headers(&["Authorization"])
.methods(&options.allowed_methods) .methods(&options.allowed_methods)
.credentials(options.allow_credentials) .credentials(options.allow_credentials)
@ -2291,7 +2297,7 @@ mod tests {
let response = validate_and_build(&cors, request.inner()).expect("to not fail"); let response = validate_and_build(&cors, request.inner()).expect("to not fail");
let expected_response = Response::new() let expected_response = Response::new()
.origin(&FromStr::from_str("https://www.acme.com/").unwrap(), true) .origin(&to_origin("https://www.acme.com").unwrap(), true)
.headers(&["Authorization"]) .headers(&["Authorization"])
.methods(&options.allowed_methods) .methods(&options.allowed_methods)
.credentials(options.allow_credentials) .credentials(options.allow_credentials)
@ -2352,7 +2358,7 @@ mod tests {
let response = validate_and_build(&cors, request.inner()).expect("to not fail"); let response = validate_and_build(&cors, request.inner()).expect("to not fail");
let expected_response = Response::new() let expected_response = Response::new()
.origin(&FromStr::from_str("https://www.acme.com/").unwrap(), false) .origin(&to_origin("https://www.acme.com").unwrap(), false)
.credentials(options.allow_credentials) .credentials(options.allow_credentials)
.exposed_headers(&["Content-Type", "X-Custom"]); .exposed_headers(&["Content-Type", "X-Custom"]);
@ -2375,7 +2381,7 @@ mod tests {
let response = validate_and_build(&cors, request.inner()).expect("to not fail"); let response = validate_and_build(&cors, request.inner()).expect("to not fail");
let expected_response = Response::new() let expected_response = Response::new()
.origin(&FromStr::from_str("https://www.acme.com/").unwrap(), true) .origin(&to_origin("https://www.acme.com").unwrap(), true)
.credentials(options.allow_credentials) .credentials(options.allow_credentials)
.exposed_headers(&["Content-Type", "X-Custom"]); .exposed_headers(&["Content-Type", "X-Custom"]);

View File

@ -22,8 +22,7 @@ fn panicking_route() {
} }
fn make_cors() -> Cors { fn make_cors() -> Cors {
let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.acme.com"]); let allowed_origins = AllowedOrigins::some(&["https://www.acme.com"]);
assert!(failed_origins.is_empty());
CorsOptions { CorsOptions {
allowed_origins: allowed_origins, allowed_origins: allowed_origins,

View File

@ -60,8 +60,7 @@ fn state<'r>(cors: cors::Guard<'r>, _state: State<'r, SomeState>) -> cors::Respo
} }
fn make_cors() -> cors::Cors { fn make_cors() -> cors::Cors {
let (allowed_origins, failed_origins) = cors::AllowedOrigins::some(&["https://www.acme.com"]); let allowed_origins = cors::AllowedOrigins::some(&["https://www.acme.com"]);
assert!(failed_origins.is_empty());
cors::CorsOptions { cors::CorsOptions {
allowed_origins: allowed_origins, allowed_origins: allowed_origins,

View File

@ -55,7 +55,7 @@ fn request_headers_round_trip_smoke_test() {
.body() .body()
.and_then(|body| body.into_string()) .and_then(|body| body.into_string())
.expect("Non-empty body"); .expect("Non-empty body");
let expected_body = r#"https://foo.bar.xyz/ let expected_body = r#"https://foo.bar.xyz
GET GET
X-Ping, accept-language"#; X-Ping, accept-language"#;
assert_eq!(expected_body, body_str); assert_eq!(expected_body, body_str);

View File

@ -66,8 +66,7 @@ fn borrow<'r>(options: State<'r, Cors>, test_state: State<'r, TestState>) -> imp
} }
fn make_cors_options() -> CorsOptions { fn make_cors_options() -> CorsOptions {
let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.acme.com"]); let allowed_origins = AllowedOrigins::some(&["https://www.acme.com"]);
assert!(failed_origins.is_empty());
CorsOptions { CorsOptions {
allowed_origins: allowed_origins, allowed_origins: allowed_origins,
@ -79,8 +78,7 @@ fn make_cors_options() -> CorsOptions {
} }
fn make_different_cors_options() -> CorsOptions { fn make_different_cors_options() -> CorsOptions {
let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.example.com"]); let allowed_origins = AllowedOrigins::some(&["https://www.example.com"]);
assert!(failed_origins.is_empty());
CorsOptions { CorsOptions {
allowed_origins: allowed_origins, allowed_origins: allowed_origins,

View File

@ -40,8 +40,7 @@ fn ping_options<'r>() -> impl Responder<'r> {
/// Returns the "application wide" Cors struct /// Returns the "application wide" Cors struct
fn cors_options() -> CorsOptions { fn cors_options() -> CorsOptions {
let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.acme.com"]); let allowed_origins = AllowedOrigins::some(&["https://www.acme.com"]);
assert!(failed_origins.is_empty());
// You can also deserialize this // You can also deserialize this
rocket_cors::CorsOptions { rocket_cors::CorsOptions {