From f9bffe77d6117f19194f5261fd4f8c71a8e3f56c Mon Sep 17 00:00:00 2001 From: Yong Wen Chua Date: Tue, 12 Mar 2019 09:58:51 +0800 Subject: [PATCH] Refactor Origins to better support additional use cases (#59) * Specify an internal structure for Cors * Use type alias * Refactor Origin validation * Separate out `Origin` * Add tests --- Cargo.toml | 3 +- examples/fairing.rs | 3 +- examples/guard.rs | 3 +- examples/json.rs | 3 +- examples/manual.rs | 3 +- examples/mix.rs | 3 +- src/fairing.rs | 3 +- src/headers.rs | 58 +++++----- src/lib.rs | 260 +++++++++++++++++++++++++++----------------- tests/fairing.rs | 3 +- tests/guard.rs | 3 +- tests/headers.rs | 2 +- tests/manual.rs | 6 +- tests/mix.rs | 3 +- 14 files changed, 198 insertions(+), 158 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index a86bb2e..2ae7377 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,7 +18,7 @@ travis-ci = { repository = "lawliet89/rocket_cors" } default = ["serialization"] # Serialization and deserialization support for settings -serialization = ["serde", "serde_derive", "unicase_serde", "url_serde"] +serialization = ["serde", "serde_derive", "unicase_serde"] [dependencies] rocket = "0.4.0" @@ -30,7 +30,6 @@ url = "1.7.2" serde = { version = "1.0", optional = true } serde_derive = { version = "1.0", optional = true } unicase_serde = { version = "0.1.0", optional = true } -url_serde = { version = "0.2.0", optional = true } [dev-dependencies] hyper = "0.10" diff --git a/examples/fairing.rs b/examples/fairing.rs index c82a89e..393ce4f 100644 --- a/examples/fairing.rs +++ b/examples/fairing.rs @@ -12,8 +12,7 @@ fn cors<'a>() -> &'a str { } fn main() -> Result<(), Error> { - let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.acme.com"]); - assert!(failed_origins.is_empty()); + let allowed_origins = AllowedOrigins::some(&["https://www.acme.com"]); // You can also deserialize this let cors = rocket_cors::CorsOptions { diff --git a/examples/guard.rs b/examples/guard.rs index 20d8705..2e507e6 100644 --- a/examples/guard.rs +++ b/examples/guard.rs @@ -36,8 +36,7 @@ fn manual(cors: Guard<'_>) -> Responder<'_, &str> { } fn main() -> Result<(), Error> { - let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.acme.com"]); - assert!(failed_origins.is_empty()); + let allowed_origins = AllowedOrigins::some(&["https://www.acme.com"]); // You can also deserialize this let cors = rocket_cors::CorsOptions { diff --git a/examples/json.rs b/examples/json.rs index e62be45..c79000f 100644 --- a/examples/json.rs +++ b/examples/json.rs @@ -13,8 +13,7 @@ fn main() { // The default demonstrates the "All" serialization of several of the settings let default: CorsOptions = Default::default(); - let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.acme.com"]); - assert!(failed_origins.is_empty()); + let allowed_origins = AllowedOrigins::some(&["https://www.acme.com"]); let options = cors::CorsOptions { allowed_origins: allowed_origins, diff --git a/examples/manual.rs b/examples/manual.rs index 3bac6df..35b4274 100644 --- a/examples/manual.rs +++ b/examples/manual.rs @@ -59,8 +59,7 @@ fn owned_options<'r>() -> impl Responder<'r> { } fn cors_options() -> CorsOptions { - let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.acme.com"]); - assert!(failed_origins.is_empty()); + let allowed_origins = AllowedOrigins::some(&["https://www.acme.com"]); // You can also deserialize this rocket_cors::CorsOptions { diff --git a/examples/mix.rs b/examples/mix.rs index 217c61d..3a243e4 100644 --- a/examples/mix.rs +++ b/examples/mix.rs @@ -36,8 +36,7 @@ fn ping_options<'r>() -> impl Responder<'r> { /// Returns the "application wide" Cors struct fn cors_options() -> CorsOptions { - let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.acme.com"]); - assert!(failed_origins.is_empty()); + let allowed_origins = AllowedOrigins::some(&["https://www.acme.com"]); // You can also deserialize this rocket_cors::CorsOptions { diff --git a/src/fairing.rs b/src/fairing.rs index c94e32a..7f497dd 100644 --- a/src/fairing.rs +++ b/src/fairing.rs @@ -140,8 +140,7 @@ mod tests { const CORS_ROOT: &'static str = "/my_cors"; fn make_cors_options() -> Cors { - let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.acme.com"]); - assert!(failed_origins.is_empty()); + let allowed_origins = AllowedOrigins::some(&["https://www.acme.com"]); CorsOptions { allowed_origins, diff --git a/src/headers.rs b/src/headers.rs index 7685dfd..6e819e9 100644 --- a/src/headers.rs +++ b/src/headers.rs @@ -11,12 +11,9 @@ use rocket::{self, Outcome}; #[cfg(feature = "serialization")] use serde_derive::{Deserialize, Serialize}; use unicase::UniCase; -use url; #[cfg(feature = "serialization")] use unicase_serde; -#[cfg(feature = "serialization")] -use url_serde; /// A case insensitive header name #[derive(Eq, PartialEq, Clone, Debug, Hash)] @@ -62,54 +59,48 @@ impl FromStr for HeaderFieldName { /// A set of case insensitive header names pub type HeaderFieldNamesSet = HashSet; -/// A wrapped `url::Url` to allow for deserialization +/// The `Origin` request header used in CORS +/// +/// You can use this as a rocket [Request Guard](https://rocket.rs/guide/requests/#request-guards) +/// to ensure that `Origin` is passed in correctly. #[derive(Eq, PartialEq, Clone, Hash, Debug)] -#[cfg_attr(feature = "serialization", derive(Serialize, Deserialize))] -pub struct Url(#[cfg_attr(feature = "serialization", serde(with = "url_serde"))] url::Url); +pub struct Origin(pub url::Origin); -impl fmt::Display for Url { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - self.0.fmt(f) +impl FromStr for Origin { + type Err = crate::Error; + + fn from_str(input: &str) -> Result { + Ok(Origin(crate::to_origin(input)?)) } } -impl Deref for Url { - type Target = url::Url; +impl Deref for Origin { + type Target = url::Origin; fn deref(&self) -> &Self::Target { &self.0 } } -impl FromStr for Url { - type Err = url::ParseError; - - fn from_str(input: &str) -> Result { - let url = url::Url::from_str(input)?; - Ok(Url(url)) +impl fmt::Display for Origin { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}", self.ascii_serialization()) } } -impl<'a, 'r> FromRequest<'a, 'r> for Url { +impl<'a, 'r> FromRequest<'a, 'r> for Origin { type Error = crate::Error; fn from_request(request: &'a rocket::Request<'r>) -> request::Outcome { match request.headers().get_one("Origin") { Some(origin) => match Self::from_str(origin) { Ok(origin) => Outcome::Success(origin), - Err(e) => Outcome::Failure((Status::BadRequest, crate::Error::BadOrigin(e))), + Err(e) => Outcome::Failure((Status::BadRequest, e)), }, None => Outcome::Forward(()), } } } - -/// The `Origin` request header used in CORS -/// -/// You can use this as a rocket [Request Guard](https://rocket.rs/guide/requests/#request-guards) -/// to ensure that `Origin` is passed in correctly. -pub type Origin = Url; - /// The `Access-Control-Request-Method` request header /// /// You can use this as a rocket [Request Guard](https://rocket.rs/guide/requests/#request-guards) @@ -202,13 +193,13 @@ mod tests { fn origin_header_conversion() { let url = "https://foo.bar.xyz"; let parsed = not_err!(Origin::from_str(url)); - let expected = not_err!(Url::from_str(url)); - assert_eq!(parsed, expected); + assert_eq!(parsed.ascii_serialization(), url); - let url = "https://foo.bar.xyz/path/somewhere"; // this should never really be used + // this should never really be sent by a compliant user agent + let url = "https://foo.bar.xyz/path/somewhere"; let parsed = not_err!(Origin::from_str(url)); - let expected = not_err!(Url::from_str(url)); - assert_eq!(parsed, expected); + let expected = "https://foo.bar.xyz"; + assert_eq!(parsed.ascii_serialization(), expected); let url = "invalid_url"; let _ = is_err!(Origin::from_str(url)); @@ -225,7 +216,10 @@ mod tests { let outcome: request::Outcome = FromRequest::from_request(request.inner()); let parsed_header = assert_matches!(outcome, Outcome::Success(s), s); - assert_eq!("https://www.example.com/", parsed_header.as_str()); + assert_eq!( + "https://www.example.com", + parsed_header.ascii_serialization() + ); } #[test] diff --git a/src/lib.rs b/src/lib.rs index 827340f..0378016 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -276,7 +276,7 @@ mod fairing; pub mod headers; use std::borrow::Cow; -use std::collections::{HashMap, HashSet}; +use std::collections::HashSet; use std::error; use std::fmt; use std::marker::PhantomData; @@ -293,7 +293,7 @@ use serde_derive::{Deserialize, Serialize}; use crate::headers::{ AccessControlRequestHeaders, AccessControlRequestMethod, HeaderFieldName, HeaderFieldNamesSet, - Origin, Url, + Origin, }; /// Errors during operations @@ -316,7 +316,7 @@ pub enum Error { /// The request header `Access-Control-Request-Headers` is required but is missing. MissingRequestHeaders, /// Origin is not allowed to make this request - OriginNotAllowed(String), + OriginNotAllowed(url::Origin), /// Requested method is not allowed MethodNotAllowed(String), /// One or more headers requested are not allowed @@ -365,7 +365,7 @@ impl fmt::Display for Error { "The request header `Access-Control-Request-Headers` \ is required but is missing") } - Error::OriginNotAllowed(origin) => write!(f, "Origin '{}' is not allowed to request", &origin), + Error::OriginNotAllowed(origin) => write!(f, "Origin '{}' is not allowed to request", origin.ascii_serialization()), Error::MethodNotAllowed(method) => write!(f, "Method '{}' is not allowed", &method), Error::HeadersNotAllowed => write!(f, "Headers are not allowed"), Error::CredentialsWithWildcardOrigin => { write!(f, @@ -398,6 +398,12 @@ impl<'r> response::Responder<'r> for Error { } } +impl From for Error { + fn from(error: url::ParseError) -> Self { + Error::BadOrigin(error) + } +} + /// An enum signifying that some of type T is allowed, or `All` (everything is allowed). /// /// `Default` is implemented for this enum and is `All`. @@ -523,31 +529,16 @@ mod method_serde { /// use rocket_cors::AllowedOrigins; /// /// let all_origins = AllowedOrigins::all(); -/// let (some_origins, failed_origins) = AllowedOrigins::some(&["https://www.acme.com"]); -/// assert!(failed_origins.is_empty()); +/// let some_origins = AllowedOrigins::some(&["https://www.acme.com"]); /// ``` -pub type AllowedOrigins = AllOrSome>; +pub type AllowedOrigins = AllOrSome>; impl AllowedOrigins { /// Allows some origins /// - /// Returns a tuple where the first element is the struct `AllowedOrigins`, - /// and the second element - /// is a map of strings which failed to parse into URLs and their associated parse errors. - pub fn some(urls: &[&str]) -> (Self, HashMap) { - let (ok_set, error_map): (Vec<_>, Vec<_>) = urls - .iter() - .map(|s| (s.to_string(), Url::from_str(s))) - .partition(|&(_, ref r)| r.is_ok()); - - let error_map = error_map - .into_iter() - .map(|(s, r)| (s.to_string(), r.unwrap_err())) - .collect(); - - let ok_set = ok_set.into_iter().map(|(_, r)| r.unwrap()).collect(); - - (AllOrSome::Some(ok_set), error_map) + /// Validation is not performed at this stage, but at a later stage. + pub fn some(urls: &[&str]) -> Self { + AllOrSome::Some(urls.iter().map(|s| s.to_string()).collect()) } /// Allows all origins @@ -646,7 +637,7 @@ impl AllowedHeaders { /// { /// "allowed_origins": { /// "Some": [ -/// "https://www.acme.com/" +/// "https://www.acme.com" /// ] /// }, /// "allowed_methods": [ @@ -714,7 +705,7 @@ pub struct CorsOptions { /// /// Defaults to `All`. #[cfg_attr(feature = "serialization", serde(default))] - pub allowed_headers: AllOrSome>, + pub allowed_headers: AllowedHeaders, /// Allows users to make authenticated requests. /// If true, injects the `Access-Control-Allow-Credentials` header in responses. /// This allows cookies and credentials to be submitted across domains. @@ -819,9 +810,7 @@ impl CorsOptions { 0 } - /// Validates if any of the settings are disallowed or incorrect - /// - /// This is run during initial Fairing attachment + /// Validates if any of the settings are disallowed, incorrect, or illegal pub fn validate(&self) -> Result<(), Error> { if self.allowed_origins.is_all() && self.send_wildcard && self.allow_credentials { Err(Error::CredentialsWithWildcardOrigin)?; @@ -842,22 +831,37 @@ impl CorsOptions { /// documentation at the [crate root](index.html) for usage information. /// /// This struct can be created by using [`CorsOptions::to_cors`] or [`Cors::from_options`]. -#[derive(Clone, Debug)] -pub struct Cors(CorsOptions); - -impl Deref for Cors { - type Target = CorsOptions; - - fn deref(&self) -> &Self::Target { - &self.0 - } +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct Cors { + pub(crate) allowed_origins: AllOrSome>, + pub(crate) allowed_methods: AllowedMethods, + pub(crate) allowed_headers: AllOrSome>, + pub(crate) allow_credentials: bool, + pub(crate) expose_headers: HashSet, + pub(crate) max_age: Option, + pub(crate) send_wildcard: bool, + pub(crate) fairing_route_base: String, + pub(crate) fairing_route_rank: isize, } impl Cors { /// Create a `Cors` struct from a [`CorsOptions`] pub fn from_options(options: &CorsOptions) -> Result { options.validate()?; - Ok(Cors(options.clone())) + + let allowed_origins = parse_origins(&options.allowed_origins)?; + + Ok(Cors { + allowed_origins, + allowed_methods: options.allowed_methods.clone(), + allowed_headers: options.allowed_headers.clone(), + allow_credentials: options.allow_credentials, + expose_headers: options.expose_headers.clone(), + max_age: options.max_age, + send_wildcard: options.send_wildcard, + fairing_route_base: options.fairing_route_base.clone(), + fairing_route_rank: options.fairing_route_rank, + }) } /// Manually respond to a request with CORS checks and headers using an Owned `Cors`. @@ -917,7 +921,7 @@ impl Cors { /// You can get this struct by using `Cors::validate_request` in an ad-hoc manner. #[derive(Eq, PartialEq, Debug)] pub(crate) struct Response { - allow_origin: Option>, + allow_origin: Option>, allow_methods: HashSet, allow_headers: HeaderFieldNamesSet, allow_credentials: bool, @@ -941,7 +945,7 @@ impl Response { } /// Consumes the `Response` and return an altered response with origin and `vary_origin` set - fn origin(mut self, origin: &Url, vary_origin: bool) -> Self { + fn origin(mut self, origin: &url::Origin, vary_origin: bool) -> Self { self.allow_origin = Some(AllOrSome::Some(origin.clone())); self.vary_origin = vary_origin; self @@ -1016,11 +1020,9 @@ impl Response { Some(ref origin) => origin, }; - // Origin should be ASCII serialized - // c.f. https://html.spec.whatwg.org/multipage/origin.html#ascii-serialisation-of-an-origin let origin = match *origin { AllOrSome::All => "*".to_string(), - AllOrSome::Some(ref origin) => origin.origin().ascii_serialization(), + AllOrSome::Some(ref origin) => origin.ascii_serialization(), }; let _ = response.set_raw_header("Access-Control-Allow-Origin", origin); @@ -1249,11 +1251,29 @@ enum ValidationResult { None, /// Successful preflight request Preflight { - origin: Origin, + origin: url::Origin, headers: Option, }, /// Successful actual request - Request { origin: Origin }, + Request { origin: url::Origin }, +} + +/// Convert a str to Origin +fn to_origin>(origin: S) -> Result { + // What to do about Opaque origins? + Ok(url::Url::parse(origin.as_ref())?.origin()) +} + +/// Parse and process allowed origins +fn parse_origins(origins: &AllowedOrigins) -> Result>, Error> { + match origins { + AllOrSome::All => Ok(AllOrSome::All), + AllOrSome::Some(ref origins) => { + let parsed: Result, Error> = + origins.iter().map(to_origin).collect(); + Ok(AllOrSome::Some(parsed?)) + } + } } /// Validates a request for CORS and returns a CORS Response @@ -1288,11 +1308,16 @@ fn validate(options: &Cors, request: &Request<'_>) -> Result { actual_request_validate(options, &origin)?; - Ok(ValidationResult::Request { origin }) + Ok(ValidationResult::Request { + origin: origin.deref().clone(), + }) } } } @@ -1301,8 +1326,8 @@ fn validate(options: &Cors, request: &Request<'_>) -> Result>, + origin: &url::Origin, + allowed_origins: &AllOrSome>, ) -> Result<(), Error> { match *allowed_origins { // Always matching is acceptable since the list of origins can be unbounded. @@ -1310,14 +1335,14 @@ fn validate_origin( AllOrSome::Some(ref allowed_origins) => allowed_origins .get(origin) .and_then(|_| Some(())) - .ok_or_else(|| Error::OriginNotAllowed(origin.to_string())), + .ok_or_else(|| Error::OriginNotAllowed(origin.clone())), } } /// Validate allowed methods fn validate_allowed_method( method: &AccessControlRequestMethod, - allowed_methods: &HashSet, + allowed_methods: &AllowedMethods, ) -> Result<(), Error> { let &AccessControlRequestMethod(ref request_method) = method; if !allowed_methods.iter().any(|m| m == request_method) { @@ -1331,7 +1356,7 @@ fn validate_allowed_method( /// Validate allowed headers fn validate_allowed_headers( headers: &AccessControlRequestHeaders, - allowed_headers: &AllOrSome>, + allowed_headers: &AllowedHeaders, ) -> Result<(), Error> { let &AccessControlRequestHeaders(ref headers) = headers; @@ -1380,12 +1405,10 @@ fn request_headers(request: &Request<'_>) -> Result, headers: &Option, ) -> Result<(), Error> { - options.validate()?; // Fast-forward check for #7 - // Note: All header parse failures are dealt with in the `FromRequest` trait implementation // 2. If the value of the Origin header is not a case-sensitive match for any of the values @@ -1430,7 +1453,7 @@ fn preflight_validate( /// and [Fetch specification](https://fetch.spec.whatwg.org/#cors-preflight-fetch). fn preflight_response( options: &Cors, - origin: &Origin, + origin: &url::Origin, headers: Option<&AccessControlRequestHeaders>, ) -> Response { let response = Response::new(); @@ -1501,9 +1524,7 @@ fn preflight_response( /// This implementation references the /// [W3C recommendation](https://www.w3.org/TR/cors/#resource-requests) /// and [Fetch specification](https://fetch.spec.whatwg.org/#cors-preflight-fetch). -fn actual_request_validate(options: &Cors, origin: &Origin) -> Result<(), Error> { - options.validate()?; - +fn actual_request_validate(options: &Cors, origin: &url::Origin) -> Result<(), Error> { // Note: All header parse failures are dealt with in the `FromRequest` trait implementation // 2. If the value of the Origin header is not a case-sensitive match for any of the values @@ -1520,7 +1541,7 @@ fn actual_request_validate(options: &Cors, origin: &Origin) -> Result<(), Error> /// This implementation references the /// [W3C recommendation](https://www.w3.org/TR/cors/#resource-requests) /// and [Fetch specification](https://fetch.spec.whatwg.org/#cors-preflight-fetch) -fn actual_request_response(options: &Cors, origin: &Origin) -> Response { +fn actual_request_response(options: &Cors, origin: &url::Origin) -> Response { let response = Response::new(); // 3. If the resource supports credentials add a single Access-Control-Allow-Origin header, @@ -1620,8 +1641,7 @@ mod tests { use crate::http::Method; fn make_cors_options() -> CorsOptions { - let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.acme.com"]); - assert!(failed_origins.is_empty()); + let allowed_origins = AllowedOrigins::some(&["https://www.acme.com"]); CorsOptions { allowed_origins, @@ -1653,6 +1673,39 @@ mod tests { Client::new(rocket).expect("valid rocket instance") } + // `to_origin` tests + + #[test] + fn origin_is_parsed_properly() { + let url = "https://foo.bar.xyz"; + let parsed = not_err!(Origin::from_str(url)); + assert_eq!(parsed.ascii_serialization(), url); + } + + #[test] + fn origin_parsing_strips_paths() { + // this should never really be sent by a compliant user agent + let url = "https://foo.bar.xyz/path/somewhere"; + let parsed = not_err!(Origin::from_str(url)); + let expected = "https://foo.bar.xyz"; + assert_eq!(parsed.ascii_serialization(), expected); + } + + #[test] + #[should_panic(expected = "BadOrigin")] + fn origin_parsing_disallows_invalid_origins() { + let url = "invalid_url"; + let _ = Origin::from_str(url).unwrap(); + } + + #[test] + fn origin_parses_opaque_origins() { + let url = "blob://foobar"; + let parsed = not_err!(Origin::from_str(url)); + + assert!(!parsed.is_tuple()); + } + // CORS options test #[test] @@ -1681,7 +1734,7 @@ mod tests { #[test] fn validate_origin_allows_all_origins() { let url = "https://www.example.com"; - let origin = Origin::from_str(url).unwrap(); + let origin = not_err!(to_origin(&url)); let allowed_origins = AllOrSome::All; not_err!(validate_origin(&origin, &allowed_origins)); @@ -1690,20 +1743,40 @@ mod tests { #[test] fn validate_origin_allows_origin() { let url = "https://www.example.com"; - let origin = Origin::from_str(url).unwrap(); - let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.example.com"]); - assert!(failed_origins.is_empty()); + let origin = not_err!(to_origin(&url)); + let allowed_origins = not_err!(parse_origins(&AllowedOrigins::some(&[ + "https://www.example.com" + ]))); not_err!(validate_origin(&origin, &allowed_origins)); } + #[test] + fn validate_origin_handles_punycode_properly() { + // Test a variety of scenarios where the Origin and settings are in punycode, or not + let cases = vec![ + ("https://аpple.com", "https://аpple.com"), + ("https://аpple.com", "https://xn--pple-43d.com"), + ("https://xn--pple-43d.com", "https://аpple.com"), + ("https://xn--pple-43d.com", "https://xn--pple-43d.com"), + ]; + + for (url, allowed_origin) in cases { + let origin = not_err!(to_origin(&url)); + let allowed_origins = not_err!(parse_origins(&AllowedOrigins::some(&[allowed_origin]))); + + not_err!(validate_origin(&origin, &allowed_origins)); + } + } + #[test] #[should_panic(expected = "OriginNotAllowed")] fn validate_origin_rejects_invalid_origin() { let url = "https://www.acme.com"; - let origin = Origin::from_str(url).unwrap(); - let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.example.com"]); - assert!(failed_origins.is_empty()); + let origin = not_err!(to_origin(&url)); + let allowed_origins = not_err!(parse_origins(&AllowedOrigins::some(&[ + "https://www.example.com" + ]))); validate_origin(&origin, &allowed_origins).unwrap(); } @@ -1711,10 +1784,7 @@ mod tests { #[test] fn response_sets_allow_origin_without_vary_correctly() { let response = Response::new(); - let response = response.origin( - &FromStr::from_str("https://www.example.com").unwrap(), - false, - ); + let response = response.origin(&to_origin("https://www.example.com").unwrap(), false); // Build response and check built response header let expected_header = vec!["https://www.example.com"]; @@ -1731,8 +1801,7 @@ mod tests { #[test] fn response_sets_allow_origin_with_vary_correctly() { let response = Response::new(); - let response = - response.origin(&FromStr::from_str("https://www.example.com").unwrap(), true); + let response = response.origin(&to_origin("https://www.example.com").unwrap(), true); // Build response and check built response header let expected_header = vec!["https://www.example.com"]; @@ -1762,9 +1831,10 @@ mod tests { #[test] fn response_sets_allow_origin_with_ascii_serialization() { let response = Response::new(); - let response = response.origin(&FromStr::from_str("https://аpple.com").unwrap(), false); + let response = response.origin(&to_origin("https://аpple.com").unwrap(), false); // Build response and check built response header + // This is "punycode" let expected_header = vec!["https://xn--pple-43d.com"]; let response = response.response(response::Response::new()); let actual_header: Vec<_> = response @@ -1778,10 +1848,7 @@ mod tests { fn response_sets_exposed_headers_correctly() { let headers = vec!["Bar", "Baz", "Foo"]; let response = Response::new(); - let response = response.origin( - &FromStr::from_str("https://www.example.com").unwrap(), - false, - ); + let response = response.origin(&to_origin("https://www.example.com").unwrap(), false); let response = response.exposed_headers(&headers); // Build response and check built response header @@ -1803,10 +1870,7 @@ mod tests { #[test] fn response_sets_max_age_correctly() { let response = Response::new(); - let response = response.origin( - &FromStr::from_str("https://www.example.com").unwrap(), - false, - ); + let response = response.origin(&to_origin("https://www.example.com").unwrap(), false); let response = response.max_age(Some(42)); @@ -1820,10 +1884,7 @@ mod tests { #[test] fn response_does_not_set_max_age_when_none() { let response = Response::new(); - let response = response.origin( - &FromStr::from_str("https://www.example.com").unwrap(), - false, - ); + let response = response.origin(&to_origin("https://www.example.com").unwrap(), false); let response = response.max_age(None); @@ -1936,10 +1997,7 @@ mod tests { .finalize(); let response = Response::new(); - let response = response.origin( - &FromStr::from_str("https://www.example.com").unwrap(), - false, - ); + let response = response.origin(&to_origin("https://www.example.com").unwrap(), false); let response = response.response(original); // Check CORS header let expected_header = vec!["https://www.example.com"]; @@ -2015,7 +2073,7 @@ mod tests { let result = validate(&cors, request.inner()).expect("to not fail"); let expected_result = ValidationResult::Preflight { - origin: FromStr::from_str("https://www.acme.com").unwrap(), + origin: to_origin("https://www.acme.com").unwrap(), // Checks that only a subset of allowed headers are returned // -- i.e. whatever is requested for headers: Some(FromStr::from_str("Authorization").unwrap()), @@ -2050,7 +2108,7 @@ mod tests { let result = validate(&cors, request.inner()).expect("to not fail"); let expected_result = ValidationResult::Preflight { - origin: FromStr::from_str("https://www.example.com").unwrap(), + origin: to_origin("https://www.example.com").unwrap(), headers: Some(FromStr::from_str("Authorization").unwrap()), }; @@ -2168,7 +2226,7 @@ mod tests { let result = validate(&cors, request.inner()).expect("to not fail"); let expected_result = ValidationResult::Request { - origin: FromStr::from_str("https://www.acme.com").unwrap(), + origin: to_origin("https://www.acme.com").unwrap(), }; assert_eq!(expected_result, result); @@ -2187,7 +2245,7 @@ mod tests { let result = validate(&cors, request.inner()).expect("to not fail"); let expected_result = ValidationResult::Request { - origin: FromStr::from_str("https://www.example.com").unwrap(), + origin: to_origin("https://www.example.com").unwrap(), }; assert_eq!(expected_result, result); @@ -2243,7 +2301,7 @@ mod tests { let response = validate_and_build(&cors, request.inner()).expect("to not fail"); let expected_response = Response::new() - .origin(&FromStr::from_str("https://www.acme.com/").unwrap(), false) + .origin(&to_origin("https://www.acme.com").unwrap(), false) .headers(&["Authorization"]) .methods(&options.allowed_methods) .credentials(options.allow_credentials) @@ -2283,7 +2341,7 @@ mod tests { let response = validate_and_build(&cors, request.inner()).expect("to not fail"); let expected_response = Response::new() - .origin(&FromStr::from_str("https://www.acme.com/").unwrap(), true) + .origin(&to_origin("https://www.acme.com").unwrap(), true) .headers(&["Authorization"]) .methods(&options.allowed_methods) .credentials(options.allow_credentials) @@ -2344,7 +2402,7 @@ mod tests { let response = validate_and_build(&cors, request.inner()).expect("to not fail"); let expected_response = Response::new() - .origin(&FromStr::from_str("https://www.acme.com/").unwrap(), false) + .origin(&to_origin("https://www.acme.com").unwrap(), false) .credentials(options.allow_credentials) .exposed_headers(&["Content-Type", "X-Custom"]); @@ -2367,7 +2425,7 @@ mod tests { let response = validate_and_build(&cors, request.inner()).expect("to not fail"); let expected_response = Response::new() - .origin(&FromStr::from_str("https://www.acme.com/").unwrap(), true) + .origin(&to_origin("https://www.acme.com").unwrap(), true) .credentials(options.allow_credentials) .exposed_headers(&["Content-Type", "X-Custom"]); diff --git a/tests/fairing.rs b/tests/fairing.rs index 5e46436..509ec9b 100644 --- a/tests/fairing.rs +++ b/tests/fairing.rs @@ -22,8 +22,7 @@ fn panicking_route() { } fn make_cors() -> Cors { - let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.acme.com"]); - assert!(failed_origins.is_empty()); + let allowed_origins = AllowedOrigins::some(&["https://www.acme.com"]); CorsOptions { allowed_origins: allowed_origins, diff --git a/tests/guard.rs b/tests/guard.rs index cee87d3..eb30da4 100644 --- a/tests/guard.rs +++ b/tests/guard.rs @@ -60,8 +60,7 @@ fn state<'r>(cors: cors::Guard<'r>, _state: State<'r, SomeState>) -> cors::Respo } fn make_cors() -> cors::Cors { - let (allowed_origins, failed_origins) = cors::AllowedOrigins::some(&["https://www.acme.com"]); - assert!(failed_origins.is_empty()); + let allowed_origins = cors::AllowedOrigins::some(&["https://www.acme.com"]); cors::CorsOptions { allowed_origins: allowed_origins, diff --git a/tests/headers.rs b/tests/headers.rs index 825c3f5..c6b63f8 100644 --- a/tests/headers.rs +++ b/tests/headers.rs @@ -55,7 +55,7 @@ fn request_headers_round_trip_smoke_test() { .body() .and_then(|body| body.into_string()) .expect("Non-empty body"); - let expected_body = r#"https://foo.bar.xyz/ + let expected_body = r#"https://foo.bar.xyz GET X-Ping, accept-language"#; assert_eq!(expected_body, body_str); diff --git a/tests/manual.rs b/tests/manual.rs index cdbb463..27cb58d 100644 --- a/tests/manual.rs +++ b/tests/manual.rs @@ -66,8 +66,7 @@ fn borrow<'r>(options: State<'r, Cors>, test_state: State<'r, TestState>) -> imp } fn make_cors_options() -> CorsOptions { - let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.acme.com"]); - assert!(failed_origins.is_empty()); + let allowed_origins = AllowedOrigins::some(&["https://www.acme.com"]); CorsOptions { allowed_origins: allowed_origins, @@ -79,8 +78,7 @@ fn make_cors_options() -> CorsOptions { } fn make_different_cors_options() -> CorsOptions { - let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.example.com"]); - assert!(failed_origins.is_empty()); + let allowed_origins = AllowedOrigins::some(&["https://www.example.com"]); CorsOptions { allowed_origins: allowed_origins, diff --git a/tests/mix.rs b/tests/mix.rs index 4f46663..87a969b 100644 --- a/tests/mix.rs +++ b/tests/mix.rs @@ -40,8 +40,7 @@ fn ping_options<'r>() -> impl Responder<'r> { /// Returns the "application wide" Cors struct fn cors_options() -> CorsOptions { - let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.acme.com"]); - assert!(failed_origins.is_empty()); + let allowed_origins = AllowedOrigins::some(&["https://www.acme.com"]); // You can also deserialize this rocket_cors::CorsOptions {