From d7e5153e2724112a9315d6a9d669cf10a50bb6bb Mon Sep 17 00:00:00 2001 From: Yong Wen Chua Date: Wed, 19 Dec 2018 11:08:30 +0800 Subject: [PATCH] Refactor Origin validation --- Cargo.toml | 3 +- examples/fairing.rs | 3 +- examples/guard.rs | 3 +- examples/json.rs | 3 +- examples/manual.rs | 3 +- examples/mix.rs | 3 +- src/fairing.rs | 5 +- src/headers.rs | 70 +++++++++--------- src/lib.rs | 174 +++++++++++++++++++++++--------------------- tests/fairing.rs | 3 +- tests/guard.rs | 3 +- tests/headers.rs | 2 +- tests/manual.rs | 6 +- tests/mix.rs | 3 +- 14 files changed, 138 insertions(+), 146 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index a86bb2e..2ae7377 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,7 +18,7 @@ travis-ci = { repository = "lawliet89/rocket_cors" } default = ["serialization"] # Serialization and deserialization support for settings -serialization = ["serde", "serde_derive", "unicase_serde", "url_serde"] +serialization = ["serde", "serde_derive", "unicase_serde"] [dependencies] rocket = "0.4.0" @@ -30,7 +30,6 @@ url = "1.7.2" serde = { version = "1.0", optional = true } serde_derive = { version = "1.0", optional = true } unicase_serde = { version = "0.1.0", optional = true } -url_serde = { version = "0.2.0", optional = true } [dev-dependencies] hyper = "0.10" diff --git a/examples/fairing.rs b/examples/fairing.rs index c82a89e..393ce4f 100644 --- a/examples/fairing.rs +++ b/examples/fairing.rs @@ -12,8 +12,7 @@ fn cors<'a>() -> &'a str { } fn main() -> Result<(), Error> { - let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.acme.com"]); - assert!(failed_origins.is_empty()); + let allowed_origins = AllowedOrigins::some(&["https://www.acme.com"]); // You can also deserialize this let cors = rocket_cors::CorsOptions { diff --git a/examples/guard.rs b/examples/guard.rs index 20d8705..2e507e6 100644 --- a/examples/guard.rs +++ b/examples/guard.rs @@ -36,8 +36,7 @@ fn manual(cors: Guard<'_>) -> Responder<'_, &str> { } fn main() -> Result<(), Error> { - let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.acme.com"]); - assert!(failed_origins.is_empty()); + let allowed_origins = AllowedOrigins::some(&["https://www.acme.com"]); // You can also deserialize this let cors = rocket_cors::CorsOptions { diff --git a/examples/json.rs b/examples/json.rs index e62be45..c79000f 100644 --- a/examples/json.rs +++ b/examples/json.rs @@ -13,8 +13,7 @@ fn main() { // The default demonstrates the "All" serialization of several of the settings let default: CorsOptions = Default::default(); - let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.acme.com"]); - assert!(failed_origins.is_empty()); + let allowed_origins = AllowedOrigins::some(&["https://www.acme.com"]); let options = cors::CorsOptions { allowed_origins: allowed_origins, diff --git a/examples/manual.rs b/examples/manual.rs index 3bac6df..35b4274 100644 --- a/examples/manual.rs +++ b/examples/manual.rs @@ -59,8 +59,7 @@ fn owned_options<'r>() -> impl Responder<'r> { } fn cors_options() -> CorsOptions { - let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.acme.com"]); - assert!(failed_origins.is_empty()); + let allowed_origins = AllowedOrigins::some(&["https://www.acme.com"]); // You can also deserialize this rocket_cors::CorsOptions { diff --git a/examples/mix.rs b/examples/mix.rs index 217c61d..3a243e4 100644 --- a/examples/mix.rs +++ b/examples/mix.rs @@ -36,8 +36,7 @@ fn ping_options<'r>() -> impl Responder<'r> { /// Returns the "application wide" Cors struct fn cors_options() -> CorsOptions { - let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.acme.com"]); - assert!(failed_origins.is_empty()); + let allowed_origins = AllowedOrigins::some(&["https://www.acme.com"]); // You can also deserialize this rocket_cors::CorsOptions { diff --git a/src/fairing.rs b/src/fairing.rs index c94e32a..79be9c3 100644 --- a/src/fairing.rs +++ b/src/fairing.rs @@ -53,7 +53,7 @@ fn on_response_wrapper( // Not a CORS request return Ok(()); } - Some(origin) => origin, + Some(origin) => crate::to_origin(origin)?, }; let result = request.local_cache(|| unreachable!("This should not be executed so late")); @@ -140,8 +140,7 @@ mod tests { const CORS_ROOT: &'static str = "/my_cors"; fn make_cors_options() -> Cors { - let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.acme.com"]); - assert!(failed_origins.is_empty()); + let allowed_origins = AllowedOrigins::some(&["https://www.acme.com"]); CorsOptions { allowed_origins, diff --git a/src/headers.rs b/src/headers.rs index 7685dfd..625c80a 100644 --- a/src/headers.rs +++ b/src/headers.rs @@ -11,12 +11,9 @@ use rocket::{self, Outcome}; #[cfg(feature = "serialization")] use serde_derive::{Deserialize, Serialize}; use unicase::UniCase; -use url; #[cfg(feature = "serialization")] use unicase_serde; -#[cfg(feature = "serialization")] -use url_serde; /// A case insensitive header name #[derive(Eq, PartialEq, Clone, Debug, Hash)] @@ -62,54 +59,55 @@ impl FromStr for HeaderFieldName { /// A set of case insensitive header names pub type HeaderFieldNamesSet = HashSet; -/// A wrapped `url::Url` to allow for deserialization +/// The `Origin` request header used in CORS +/// +/// You can use this as a rocket [Request Guard](https://rocket.rs/guide/requests/#request-guards) +/// to ensure that `Origin` is passed in correctly. #[derive(Eq, PartialEq, Clone, Hash, Debug)] #[cfg_attr(feature = "serialization", derive(Serialize, Deserialize))] -pub struct Url(#[cfg_attr(feature = "serialization", serde(with = "url_serde"))] url::Url); +pub struct Origin(pub String); -impl fmt::Display for Url { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - self.0.fmt(f) +impl FromStr for Origin { + type Err = !; + + fn from_str(input: &str) -> Result { + Ok(Origin(input.to_string())) } } -impl Deref for Url { - type Target = url::Url; +impl Deref for Origin { + type Target = str; fn deref(&self) -> &Self::Target { &self.0 } } -impl FromStr for Url { - type Err = url::ParseError; - - fn from_str(input: &str) -> Result { - let url = url::Url::from_str(input)?; - Ok(Url(url)) +impl AsRef for Origin { + fn as_ref(&self) -> &str { + self } } -impl<'a, 'r> FromRequest<'a, 'r> for Url { +impl fmt::Display for Origin { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + self.as_ref().fmt(f) + } +} + +impl<'a, 'r> FromRequest<'a, 'r> for Origin { type Error = crate::Error; fn from_request(request: &'a rocket::Request<'r>) -> request::Outcome { match request.headers().get_one("Origin") { - Some(origin) => match Self::from_str(origin) { - Ok(origin) => Outcome::Success(origin), - Err(e) => Outcome::Failure((Status::BadRequest, crate::Error::BadOrigin(e))), - }, + Some(origin) => { + let Ok(origin) = Self::from_str(origin); + Outcome::Success(origin) + } None => Outcome::Forward(()), } } } - -/// The `Origin` request header used in CORS -/// -/// You can use this as a rocket [Request Guard](https://rocket.rs/guide/requests/#request-guards) -/// to ensure that `Origin` is passed in correctly. -pub type Origin = Url; - /// The `Access-Control-Request-Method` request header /// /// You can use this as a rocket [Request Guard](https://rocket.rs/guide/requests/#request-guards) @@ -201,17 +199,17 @@ mod tests { #[test] fn origin_header_conversion() { let url = "https://foo.bar.xyz"; - let parsed = not_err!(Origin::from_str(url)); - let expected = not_err!(Url::from_str(url)); - assert_eq!(parsed, expected); + let Ok(parsed) = Origin::from_str(url); + assert_eq!(parsed.as_ref(), url); let url = "https://foo.bar.xyz/path/somewhere"; // this should never really be used - let parsed = not_err!(Origin::from_str(url)); - let expected = not_err!(Url::from_str(url)); - assert_eq!(parsed, expected); + let Ok(parsed) = Origin::from_str(url); + assert_eq!(parsed.as_ref(), url); + // Validation is not done now let url = "invalid_url"; - let _ = is_err!(Origin::from_str(url)); + let Ok(parsed) = Origin::from_str(url); + assert_eq!(parsed.as_ref(), url); } #[test] @@ -225,7 +223,7 @@ mod tests { let outcome: request::Outcome = FromRequest::from_request(request.inner()); let parsed_header = assert_matches!(outcome, Outcome::Success(s), s); - assert_eq!("https://www.example.com/", parsed_header.as_str()); + assert_eq!("https://www.example.com", parsed_header.as_ref()); } #[test] diff --git a/src/lib.rs b/src/lib.rs index 8bc9935..081b13c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -267,6 +267,8 @@ See the [example](https://github.com/lawliet89/rocket_cors/blob/master/examples/ intra_doc_link_resolution_failure )] #![doc(test(attr(allow(unused_variables), deny(warnings))))] +#![feature(never_type)] +#![feature(exhaustive_patterns)] #[cfg(test)] #[macro_use] @@ -276,7 +278,7 @@ mod fairing; pub mod headers; use std::borrow::Cow; -use std::collections::{HashMap, HashSet}; +use std::collections::HashSet; use std::error; use std::fmt; use std::marker::PhantomData; @@ -293,7 +295,7 @@ use serde_derive::{Deserialize, Serialize}; use crate::headers::{ AccessControlRequestHeaders, AccessControlRequestMethod, HeaderFieldName, HeaderFieldNamesSet, - Origin, Url, + Origin, }; /// Errors during operations @@ -316,7 +318,7 @@ pub enum Error { /// The request header `Access-Control-Request-Headers` is required but is missing. MissingRequestHeaders, /// Origin is not allowed to make this request - OriginNotAllowed(String), + OriginNotAllowed(url::Origin), /// Requested method is not allowed MethodNotAllowed(String), /// One or more headers requested are not allowed @@ -365,7 +367,7 @@ impl fmt::Display for Error { "The request header `Access-Control-Request-Headers` \ is required but is missing") } - Error::OriginNotAllowed(origin) => write!(f, "Origin '{}' is not allowed to request", &origin), + Error::OriginNotAllowed(origin) => write!(f, "Origin '{}' is not allowed to request", origin.ascii_serialization()), Error::MethodNotAllowed(method) => write!(f, "Method '{}' is not allowed", &method), Error::HeadersNotAllowed => write!(f, "Headers are not allowed"), Error::CredentialsWithWildcardOrigin => { write!(f, @@ -398,6 +400,12 @@ impl<'r> response::Responder<'r> for Error { } } +impl From for Error { + fn from(error: url::ParseError) -> Self { + Error::BadOrigin(error) + } +} + /// An enum signifying that some of type T is allowed, or `All` (everything is allowed). /// /// `Default` is implemented for this enum and is `All`. @@ -523,31 +531,23 @@ mod method_serde { /// use rocket_cors::AllowedOrigins; /// /// let all_origins = AllowedOrigins::all(); -/// let (some_origins, failed_origins) = AllowedOrigins::some(&["https://www.acme.com"]); -/// assert!(failed_origins.is_empty()); +/// let some_origins = AllowedOrigins::some(&["https://www.acme.com"]); /// ``` pub type AllowedOrigins = AllOrSome>; impl AllowedOrigins { /// Allows some origins /// - /// Returns a tuple where the first element is the struct `AllowedOrigins`, - /// and the second element - /// is a map of strings which failed to parse into URLs and their associated parse errors. - pub fn some(urls: &[&str]) -> (Self, HashMap) { - let (ok_set, error_map): (Vec<_>, Vec<_>) = urls - .iter() - .map(|s| (s.to_string(), Url::from_str(s))) - .partition(|&(_, ref r)| r.is_ok()); - - let error_map = error_map - .into_iter() - .map(|(s, r)| (s.to_string(), r.unwrap_err())) - .collect(); - - let ok_set = ok_set.into_iter().map(|(_, r)| r.unwrap()).collect(); - - (AllOrSome::Some(ok_set), error_map) + /// Validation is not performed at this stage, but at a later stage. + pub fn some(urls: &[&str]) -> Self { + AllOrSome::Some( + urls.iter() + .map(|s| { + let Ok(s) = FromStr::from_str(s); + s + }) + .collect(), + ) } /// Allows all origins @@ -646,7 +646,7 @@ impl AllowedHeaders { /// { /// "allowed_origins": { /// "Some": [ -/// "https://www.acme.com/" +/// "https://www.acme.com" /// ] /// }, /// "allowed_methods": [ @@ -819,9 +819,7 @@ impl CorsOptions { 0 } - /// Validates if any of the settings are disallowed or incorrect - /// - /// This is run during initial Fairing attachment + /// Validates if any of the settings are disallowed, incorrect, or illegal pub fn validate(&self) -> Result<(), Error> { if self.allowed_origins.is_all() && self.send_wildcard && self.allow_credentials { Err(Error::CredentialsWithWildcardOrigin)?; @@ -844,7 +842,7 @@ impl CorsOptions { /// This struct can be created by using [`CorsOptions::to_cors`] or [`Cors::from_options`]. #[derive(Clone, Debug, Eq, PartialEq)] pub struct Cors { - pub(crate) allowed_origins: AllowedOrigins, + pub(crate) allowed_origins: AllOrSome>, pub(crate) allowed_methods: AllowedMethods, pub(crate) allowed_headers: AllOrSome>, pub(crate) allow_credentials: bool, @@ -859,8 +857,11 @@ impl Cors { /// Create a `Cors` struct from a [`CorsOptions`] pub fn from_options(options: &CorsOptions) -> Result { options.validate()?; + + let allowed_origins = parse_origins(&options.allowed_origins)?; + Ok(Cors { - allowed_origins: options.allowed_origins.clone(), + allowed_origins, allowed_methods: options.allowed_methods.clone(), allowed_headers: options.allowed_headers.clone(), allow_credentials: options.allow_credentials, @@ -929,7 +930,7 @@ impl Cors { /// You can get this struct by using `Cors::validate_request` in an ad-hoc manner. #[derive(Eq, PartialEq, Debug)] pub(crate) struct Response { - allow_origin: Option>, + allow_origin: Option>, allow_methods: HashSet, allow_headers: HeaderFieldNamesSet, allow_credentials: bool, @@ -953,7 +954,7 @@ impl Response { } /// Consumes the `Response` and return an altered response with origin and `vary_origin` set - fn origin(mut self, origin: &Url, vary_origin: bool) -> Self { + fn origin(mut self, origin: &url::Origin, vary_origin: bool) -> Self { self.allow_origin = Some(AllOrSome::Some(origin.clone())); self.vary_origin = vary_origin; self @@ -1028,11 +1029,9 @@ impl Response { Some(ref origin) => origin, }; - // Origin should be ASCII serialized - // c.f. https://html.spec.whatwg.org/multipage/origin.html#ascii-serialisation-of-an-origin let origin = match *origin { AllOrSome::All => "*".to_string(), - AllOrSome::Some(ref origin) => origin.origin().ascii_serialization(), + AllOrSome::Some(ref origin) => origin.ascii_serialization(), }; let _ = response.set_raw_header("Access-Control-Allow-Origin", origin); @@ -1261,11 +1260,29 @@ enum ValidationResult { None, /// Successful preflight request Preflight { - origin: Origin, + origin: url::Origin, headers: Option, }, /// Successful actual request - Request { origin: Origin }, + Request { origin: url::Origin }, +} + +/// Convert a str to Origin +fn to_origin>(origin: S) -> Result { + // What to do about Opaque origins? + Ok(url::Url::parse(origin.as_ref())?.origin()) +} + +/// Parse and process allowed origins +fn parse_origins(origins: &AllowedOrigins) -> Result>, Error> { + match origins { + AllOrSome::All => Ok(AllOrSome::All), + AllOrSome::Some(ref origins) => { + let parsed: Result, Error> = + origins.iter().map(to_origin).collect(); + Ok(AllOrSome::Some(parsed?)) + } + } } /// Validates a request for CORS and returns a CORS Response @@ -1291,7 +1308,7 @@ fn validate(options: &Cors, request: &Request<'_>) -> Result origin, + Some(origin) => to_origin(origin)?, }; // Check if the request verb is an OPTION or something else @@ -1313,8 +1330,8 @@ fn validate(options: &Cors, request: &Request<'_>) -> Result>, ) -> Result<(), Error> { match *allowed_origins { // Always matching is acceptable since the list of origins can be unbounded. @@ -1322,7 +1339,7 @@ fn validate_origin( AllOrSome::Some(ref allowed_origins) => allowed_origins .get(origin) .and_then(|_| Some(())) - .ok_or_else(|| Error::OriginNotAllowed(origin.to_string())), + .ok_or_else(|| Error::OriginNotAllowed(origin.clone())), } } @@ -1392,7 +1409,7 @@ fn request_headers(request: &Request<'_>) -> Result, headers: &Option, ) -> Result<(), Error> { @@ -1440,7 +1457,7 @@ fn preflight_validate( /// and [Fetch specification](https://fetch.spec.whatwg.org/#cors-preflight-fetch). fn preflight_response( options: &Cors, - origin: &Origin, + origin: &url::Origin, headers: Option<&AccessControlRequestHeaders>, ) -> Response { let response = Response::new(); @@ -1511,7 +1528,7 @@ fn preflight_response( /// This implementation references the /// [W3C recommendation](https://www.w3.org/TR/cors/#resource-requests) /// and [Fetch specification](https://fetch.spec.whatwg.org/#cors-preflight-fetch). -fn actual_request_validate(options: &Cors, origin: &Origin) -> Result<(), Error> { +fn actual_request_validate(options: &Cors, origin: &url::Origin) -> Result<(), Error> { // Note: All header parse failures are dealt with in the `FromRequest` trait implementation // 2. If the value of the Origin header is not a case-sensitive match for any of the values @@ -1528,7 +1545,7 @@ fn actual_request_validate(options: &Cors, origin: &Origin) -> Result<(), Error> /// This implementation references the /// [W3C recommendation](https://www.w3.org/TR/cors/#resource-requests) /// and [Fetch specification](https://fetch.spec.whatwg.org/#cors-preflight-fetch) -fn actual_request_response(options: &Cors, origin: &Origin) -> Response { +fn actual_request_response(options: &Cors, origin: &url::Origin) -> Response { let response = Response::new(); // 3. If the resource supports credentials add a single Access-Control-Allow-Origin header, @@ -1628,8 +1645,7 @@ mod tests { use crate::http::Method; fn make_cors_options() -> CorsOptions { - let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.acme.com"]); - assert!(failed_origins.is_empty()); + let allowed_origins = AllowedOrigins::some(&["https://www.acme.com"]); CorsOptions { allowed_origins, @@ -1689,7 +1705,8 @@ mod tests { #[test] fn validate_origin_allows_all_origins() { let url = "https://www.example.com"; - let origin = Origin::from_str(url).unwrap(); + let Ok(origin) = Origin::from_str(url); + let origin = not_err!(to_origin(&origin)); let allowed_origins = AllOrSome::All; not_err!(validate_origin(&origin, &allowed_origins)); @@ -1698,9 +1715,11 @@ mod tests { #[test] fn validate_origin_allows_origin() { let url = "https://www.example.com"; - let origin = Origin::from_str(url).unwrap(); - let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.example.com"]); - assert!(failed_origins.is_empty()); + let Ok(origin) = Origin::from_str(url); + let origin = not_err!(to_origin(&origin)); + let allowed_origins = not_err!(parse_origins(&AllowedOrigins::some(&[ + "https://www.example.com" + ]))); not_err!(validate_origin(&origin, &allowed_origins)); } @@ -1709,9 +1728,11 @@ mod tests { #[should_panic(expected = "OriginNotAllowed")] fn validate_origin_rejects_invalid_origin() { let url = "https://www.acme.com"; - let origin = Origin::from_str(url).unwrap(); - let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.example.com"]); - assert!(failed_origins.is_empty()); + let Ok(origin) = Origin::from_str(url); + let origin = not_err!(to_origin(&origin)); + let allowed_origins = not_err!(parse_origins(&AllowedOrigins::some(&[ + "https://www.example.com" + ]))); validate_origin(&origin, &allowed_origins).unwrap(); } @@ -1719,10 +1740,7 @@ mod tests { #[test] fn response_sets_allow_origin_without_vary_correctly() { let response = Response::new(); - let response = response.origin( - &FromStr::from_str("https://www.example.com").unwrap(), - false, - ); + let response = response.origin(&to_origin("https://www.example.com").unwrap(), false); // Build response and check built response header let expected_header = vec!["https://www.example.com"]; @@ -1739,8 +1757,7 @@ mod tests { #[test] fn response_sets_allow_origin_with_vary_correctly() { let response = Response::new(); - let response = - response.origin(&FromStr::from_str("https://www.example.com").unwrap(), true); + let response = response.origin(&to_origin("https://www.example.com").unwrap(), true); // Build response and check built response header let expected_header = vec!["https://www.example.com"]; @@ -1770,9 +1787,10 @@ mod tests { #[test] fn response_sets_allow_origin_with_ascii_serialization() { let response = Response::new(); - let response = response.origin(&FromStr::from_str("https://аpple.com").unwrap(), false); + let response = response.origin(&to_origin("https://аpple.com").unwrap(), false); // Build response and check built response header + // This is "punycode" let expected_header = vec!["https://xn--pple-43d.com"]; let response = response.response(response::Response::new()); let actual_header: Vec<_> = response @@ -1786,10 +1804,7 @@ mod tests { fn response_sets_exposed_headers_correctly() { let headers = vec!["Bar", "Baz", "Foo"]; let response = Response::new(); - let response = response.origin( - &FromStr::from_str("https://www.example.com").unwrap(), - false, - ); + let response = response.origin(&to_origin("https://www.example.com").unwrap(), false); let response = response.exposed_headers(&headers); // Build response and check built response header @@ -1811,10 +1826,7 @@ mod tests { #[test] fn response_sets_max_age_correctly() { let response = Response::new(); - let response = response.origin( - &FromStr::from_str("https://www.example.com").unwrap(), - false, - ); + let response = response.origin(&to_origin("https://www.example.com").unwrap(), false); let response = response.max_age(Some(42)); @@ -1828,10 +1840,7 @@ mod tests { #[test] fn response_does_not_set_max_age_when_none() { let response = Response::new(); - let response = response.origin( - &FromStr::from_str("https://www.example.com").unwrap(), - false, - ); + let response = response.origin(&to_origin("https://www.example.com").unwrap(), false); let response = response.max_age(None); @@ -1944,10 +1953,7 @@ mod tests { .finalize(); let response = Response::new(); - let response = response.origin( - &FromStr::from_str("https://www.example.com").unwrap(), - false, - ); + let response = response.origin(&to_origin("https://www.example.com").unwrap(), false); let response = response.response(original); // Check CORS header let expected_header = vec!["https://www.example.com"]; @@ -2023,7 +2029,7 @@ mod tests { let result = validate(&cors, request.inner()).expect("to not fail"); let expected_result = ValidationResult::Preflight { - origin: FromStr::from_str("https://www.acme.com").unwrap(), + origin: to_origin("https://www.acme.com").unwrap(), // Checks that only a subset of allowed headers are returned // -- i.e. whatever is requested for headers: Some(FromStr::from_str("Authorization").unwrap()), @@ -2058,7 +2064,7 @@ mod tests { let result = validate(&cors, request.inner()).expect("to not fail"); let expected_result = ValidationResult::Preflight { - origin: FromStr::from_str("https://www.example.com").unwrap(), + origin: to_origin("https://www.example.com").unwrap(), headers: Some(FromStr::from_str("Authorization").unwrap()), }; @@ -2176,7 +2182,7 @@ mod tests { let result = validate(&cors, request.inner()).expect("to not fail"); let expected_result = ValidationResult::Request { - origin: FromStr::from_str("https://www.acme.com").unwrap(), + origin: to_origin("https://www.acme.com").unwrap(), }; assert_eq!(expected_result, result); @@ -2195,7 +2201,7 @@ mod tests { let result = validate(&cors, request.inner()).expect("to not fail"); let expected_result = ValidationResult::Request { - origin: FromStr::from_str("https://www.example.com").unwrap(), + origin: to_origin("https://www.example.com").unwrap(), }; assert_eq!(expected_result, result); @@ -2251,7 +2257,7 @@ mod tests { let response = validate_and_build(&cors, request.inner()).expect("to not fail"); let expected_response = Response::new() - .origin(&FromStr::from_str("https://www.acme.com/").unwrap(), false) + .origin(&to_origin("https://www.acme.com").unwrap(), false) .headers(&["Authorization"]) .methods(&options.allowed_methods) .credentials(options.allow_credentials) @@ -2291,7 +2297,7 @@ mod tests { let response = validate_and_build(&cors, request.inner()).expect("to not fail"); let expected_response = Response::new() - .origin(&FromStr::from_str("https://www.acme.com/").unwrap(), true) + .origin(&to_origin("https://www.acme.com").unwrap(), true) .headers(&["Authorization"]) .methods(&options.allowed_methods) .credentials(options.allow_credentials) @@ -2352,7 +2358,7 @@ mod tests { let response = validate_and_build(&cors, request.inner()).expect("to not fail"); let expected_response = Response::new() - .origin(&FromStr::from_str("https://www.acme.com/").unwrap(), false) + .origin(&to_origin("https://www.acme.com").unwrap(), false) .credentials(options.allow_credentials) .exposed_headers(&["Content-Type", "X-Custom"]); @@ -2375,7 +2381,7 @@ mod tests { let response = validate_and_build(&cors, request.inner()).expect("to not fail"); let expected_response = Response::new() - .origin(&FromStr::from_str("https://www.acme.com/").unwrap(), true) + .origin(&to_origin("https://www.acme.com").unwrap(), true) .credentials(options.allow_credentials) .exposed_headers(&["Content-Type", "X-Custom"]); diff --git a/tests/fairing.rs b/tests/fairing.rs index 5e46436..509ec9b 100644 --- a/tests/fairing.rs +++ b/tests/fairing.rs @@ -22,8 +22,7 @@ fn panicking_route() { } fn make_cors() -> Cors { - let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.acme.com"]); - assert!(failed_origins.is_empty()); + let allowed_origins = AllowedOrigins::some(&["https://www.acme.com"]); CorsOptions { allowed_origins: allowed_origins, diff --git a/tests/guard.rs b/tests/guard.rs index cee87d3..eb30da4 100644 --- a/tests/guard.rs +++ b/tests/guard.rs @@ -60,8 +60,7 @@ fn state<'r>(cors: cors::Guard<'r>, _state: State<'r, SomeState>) -> cors::Respo } fn make_cors() -> cors::Cors { - let (allowed_origins, failed_origins) = cors::AllowedOrigins::some(&["https://www.acme.com"]); - assert!(failed_origins.is_empty()); + let allowed_origins = cors::AllowedOrigins::some(&["https://www.acme.com"]); cors::CorsOptions { allowed_origins: allowed_origins, diff --git a/tests/headers.rs b/tests/headers.rs index 825c3f5..c6b63f8 100644 --- a/tests/headers.rs +++ b/tests/headers.rs @@ -55,7 +55,7 @@ fn request_headers_round_trip_smoke_test() { .body() .and_then(|body| body.into_string()) .expect("Non-empty body"); - let expected_body = r#"https://foo.bar.xyz/ + let expected_body = r#"https://foo.bar.xyz GET X-Ping, accept-language"#; assert_eq!(expected_body, body_str); diff --git a/tests/manual.rs b/tests/manual.rs index cdbb463..27cb58d 100644 --- a/tests/manual.rs +++ b/tests/manual.rs @@ -66,8 +66,7 @@ fn borrow<'r>(options: State<'r, Cors>, test_state: State<'r, TestState>) -> imp } fn make_cors_options() -> CorsOptions { - let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.acme.com"]); - assert!(failed_origins.is_empty()); + let allowed_origins = AllowedOrigins::some(&["https://www.acme.com"]); CorsOptions { allowed_origins: allowed_origins, @@ -79,8 +78,7 @@ fn make_cors_options() -> CorsOptions { } fn make_different_cors_options() -> CorsOptions { - let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.example.com"]); - assert!(failed_origins.is_empty()); + let allowed_origins = AllowedOrigins::some(&["https://www.example.com"]); CorsOptions { allowed_origins: allowed_origins, diff --git a/tests/mix.rs b/tests/mix.rs index 4f46663..87a969b 100644 --- a/tests/mix.rs +++ b/tests/mix.rs @@ -40,8 +40,7 @@ fn ping_options<'r>() -> impl Responder<'r> { /// Returns the "application wide" Cors struct fn cors_options() -> CorsOptions { - let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.acme.com"]); - assert!(failed_origins.is_empty()); + let allowed_origins = AllowedOrigins::some(&["https://www.acme.com"]); // You can also deserialize this rocket_cors::CorsOptions {