diff --git a/Cargo.toml b/Cargo.toml index 2ae7377..f45ac48 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "rocket_cors" -version = "0.4.0" +version = "0.5.0" license = "MIT/Apache-2.0" authors = ["Yong Wen Chua "] description = "Cross-origin resource sharing (CORS) for Rocket.rs applications" @@ -21,6 +21,7 @@ default = ["serialization"] serialization = ["serde", "serde_derive", "unicase_serde"] [dependencies] +regex = "1.1" rocket = "0.4.0" log = "0.3" unicase = "2.0" diff --git a/README.md b/README.md index d63cdf3..5d2c2d0 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,6 @@ # rocket_cors [![Build Status](https://travis-ci.org/lawliet89/rocket_cors.svg)](https://travis-ci.org/lawliet89/rocket_cors) -[![Dependency Status](https://dependencyci.com/github/lawliet89/rocket_cors/badge)](https://dependencyci.com/github/lawliet89/rocket_cors) [![Repository](https://img.shields.io/github/tag/lawliet89/rocket_cors.svg)](https://github.com/lawliet89/rocket_cors) [![Crates.io](https://img.shields.io/crates/v/rocket_cors.svg)](https://crates.io/crates/rocket_cors) @@ -31,7 +30,7 @@ work, but they are subject to the minimum that Rocket sets. Add the following to Cargo.toml: ```toml -rocket_cors = "0.4.0" +rocket_cors = "0.5.0" ``` To use the latest `master` branch, for example: diff --git a/examples/fairing.rs b/examples/fairing.rs index 393ce4f..814aa3c 100644 --- a/examples/fairing.rs +++ b/examples/fairing.rs @@ -12,11 +12,11 @@ fn cors<'a>() -> &'a str { } fn main() -> Result<(), Error> { - let allowed_origins = AllowedOrigins::some(&["https://www.acme.com"]); + let allowed_origins = AllowedOrigins::some_exact(&["https://www.acme.com"]); // You can also deserialize this let cors = rocket_cors::CorsOptions { - allowed_origins: allowed_origins, + allowed_origins, allowed_methods: vec![Method::Get].into_iter().map(From::from).collect(), allowed_headers: AllowedHeaders::some(&["Authorization", "Accept"]), allow_credentials: true, diff --git a/examples/guard.rs b/examples/guard.rs index 2e507e6..b8faa39 100644 --- a/examples/guard.rs +++ b/examples/guard.rs @@ -36,11 +36,11 @@ fn manual(cors: Guard<'_>) -> Responder<'_, &str> { } fn main() -> Result<(), Error> { - let allowed_origins = AllowedOrigins::some(&["https://www.acme.com"]); + let allowed_origins = AllowedOrigins::some_exact(&["https://www.acme.com"]); // You can also deserialize this let cors = rocket_cors::CorsOptions { - allowed_origins: allowed_origins, + allowed_origins, allowed_methods: vec![Method::Get].into_iter().map(From::from).collect(), allowed_headers: AllowedHeaders::some(&["Authorization", "Accept"]), allow_credentials: true, diff --git a/examples/json.rs b/examples/json.rs index c79000f..5cee758 100644 --- a/examples/json.rs +++ b/examples/json.rs @@ -13,10 +13,10 @@ fn main() { // The default demonstrates the "All" serialization of several of the settings let default: CorsOptions = Default::default(); - let allowed_origins = AllowedOrigins::some(&["https://www.acme.com"]); + let allowed_origins = AllowedOrigins::some_exact(&["https://www.acme.com"]); let options = cors::CorsOptions { - allowed_origins: allowed_origins, + allowed_origins, allowed_methods: vec![Method::Get, Method::Post, Method::Delete] .into_iter() .map(From::from) diff --git a/examples/manual.rs b/examples/manual.rs index 35b4274..db83c25 100644 --- a/examples/manual.rs +++ b/examples/manual.rs @@ -59,11 +59,11 @@ fn owned_options<'r>() -> impl Responder<'r> { } fn cors_options() -> CorsOptions { - let allowed_origins = AllowedOrigins::some(&["https://www.acme.com"]); + let allowed_origins = AllowedOrigins::some_exact(&["https://www.acme.com"]); // You can also deserialize this rocket_cors::CorsOptions { - allowed_origins: allowed_origins, + allowed_origins, allowed_methods: vec![Method::Get].into_iter().map(From::from).collect(), allowed_headers: AllowedHeaders::some(&["Authorization", "Accept"]), allow_credentials: true, diff --git a/examples/mix.rs b/examples/mix.rs index 3a243e4..3f8d350 100644 --- a/examples/mix.rs +++ b/examples/mix.rs @@ -36,11 +36,11 @@ fn ping_options<'r>() -> impl Responder<'r> { /// Returns the "application wide" Cors struct fn cors_options() -> CorsOptions { - let allowed_origins = AllowedOrigins::some(&["https://www.acme.com"]); + let allowed_origins = AllowedOrigins::some_exact(&["https://www.acme.com"]); // You can also deserialize this rocket_cors::CorsOptions { - allowed_origins: allowed_origins, + allowed_origins, allowed_methods: vec![Method::Get].into_iter().map(From::from).collect(), allowed_headers: AllowedHeaders::some(&["Authorization", "Accept"]), allow_credentials: true, diff --git a/src/fairing.rs b/src/fairing.rs index 7f497dd..3d96adc 100644 --- a/src/fairing.rs +++ b/src/fairing.rs @@ -63,6 +63,7 @@ fn on_response_wrapper( return Ok(()); } + let origin = origin.to_string(); let cors_response = if request.method() == http::Method::Options { let headers = request_headers(request)?; preflight_response(options, &origin, headers.as_ref()) @@ -137,10 +138,10 @@ mod tests { use crate::{AllowedHeaders, AllowedOrigins, Cors, CorsOptions}; - const CORS_ROOT: &'static str = "/my_cors"; + const CORS_ROOT: &str = "/my_cors"; fn make_cors_options() -> Cors { - let allowed_origins = AllowedOrigins::some(&["https://www.acme.com"]); + let allowed_origins = AllowedOrigins::some_exact(&["https://www.acme.com"]); CorsOptions { allowed_origins, diff --git a/src/headers.rs b/src/headers.rs index 6e819e9..6e0a099 100644 --- a/src/headers.rs +++ b/src/headers.rs @@ -63,28 +63,51 @@ pub type HeaderFieldNamesSet = HashSet; /// /// You can use this as a rocket [Request Guard](https://rocket.rs/guide/requests/#request-guards) /// to ensure that `Origin` is passed in correctly. +/// +/// Reference: [Mozilla](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Origin) #[derive(Eq, PartialEq, Clone, Hash, Debug)] -pub struct Origin(pub url::Origin); +pub enum Origin { + /// A `null` Origin + Null, + /// A well-formed origin that was parsed by [`url::Url::origin`] + Parsed(url::Origin), +} + +impl Origin { + /// Perform an + /// [ASCII serialization](https://html.spec.whatwg.org/multipage/#ascii-serialisation-of-an-origin) + /// of this origin. + pub fn ascii_serialization(&self) -> String { + self.to_string() + } + + /// Returns whether the origin was parsed as non-opaque + pub fn is_tuple(&self) -> bool { + match self { + Origin::Null => false, + Origin::Parsed(ref parsed) => parsed.is_tuple(), + } + } +} impl FromStr for Origin { type Err = crate::Error; fn from_str(input: &str) -> Result { - Ok(Origin(crate::to_origin(input)?)) - } -} - -impl Deref for Origin { - type Target = url::Origin; - - fn deref(&self) -> &Self::Target { - &self.0 + if input.to_lowercase() == "null" { + Ok(Origin::Null) + } else { + Ok(Origin::Parsed(crate::to_origin(input)?)) + } } } impl fmt::Display for Origin { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{}", self.ascii_serialization()) + match self { + Origin::Null => write!(f, "null"), + Origin::Parsed(ref parsed) => write!(f, "{}", parsed.ascii_serialization()), + } } } @@ -195,6 +218,10 @@ mod tests { let parsed = not_err!(Origin::from_str(url)); assert_eq!(parsed.ascii_serialization(), url); + let url = "https://foo.bar.xyz:1234"; + let parsed = not_err!(Origin::from_str(url)); + assert_eq!(parsed.ascii_serialization(), url); + // this should never really be sent by a compliant user agent let url = "https://foo.bar.xyz/path/somewhere"; let parsed = not_err!(Origin::from_str(url)); @@ -239,7 +266,7 @@ mod tests { ); let method = "INVALID"; - let _ = is_err!(AccessControlRequestMethod::from_str(method)); + is_err!(AccessControlRequestMethod::from_str(method)); } #[test] @@ -281,7 +308,7 @@ mod tests { let parsed_header = assert_matches!(outcome, Outcome::Success(s), s); let AccessControlRequestHeaders(parsed_headers) = parsed_header; let mut parsed_headers: Vec = - parsed_headers.iter().map(|s| s.to_string()).collect(); + parsed_headers.iter().map(ToString::to_string).collect(); parsed_headers.sort(); assert_eq!( vec!["accept-language".to_string(), "date".to_string()], diff --git a/src/lib.rs b/src/lib.rs index 0378016..cc3fb17 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,6 +1,5 @@ /*! [![Build Status](https://travis-ci.org/lawliet89/rocket_cors.svg)](https://travis-ci.org/lawliet89/rocket_cors) -[![Dependency Status](https://dependencyci.com/github/lawliet89/rocket_cors/badge)](https://dependencyci.com/github/lawliet89/rocket_cors) [![Repository](https://img.shields.io/github/tag/lawliet89/rocket_cors.svg)](https://github.com/lawliet89/rocket_cors) [![Crates.io](https://img.shields.io/crates/v/rocket_cors.svg)](https://crates.io/crates/rocket_cors) @@ -30,7 +29,7 @@ might work, but they are subject to the minimum that Rocket sets. Add the following to Cargo.toml: ```toml -rocket_cors = "0.4.0" +rocket_cors = "0.5.0" ``` To use the latest `master` branch, for example: @@ -46,7 +45,7 @@ the [`CorsOptions`] struct that is described below. If you would like to disable change your `Cargo.toml` to: ```toml -rocket_cors = { version = "0.4.0", default-features = false } +rocket_cors = { version = "0.5.0", default-features = false } ``` ## Usage @@ -63,9 +62,9 @@ Each of the examples can be run off the repository via `cargo run --example xxx` ### `CorsOptions` Struct -The [`CorsOptiopns`] struct contains the settings for CORS requests to be validated +The [`CorsOptions`] struct contains the settings for CORS requests to be validated and for responses to be generated. Defaults are defined for every field in the struct, and -are documented on the [`CorsOptiopns`] page. You can also deserialize +are documented on the [`CorsOptions`] page. You can also deserialize the struct from some format like JSON, YAML or TOML when the default `serialization` feature is enabled. @@ -284,6 +283,7 @@ use std::ops::Deref; use std::str::FromStr; use ::log::{error, info, log}; +use regex::RegexSet; use rocket::http::{self, Status}; use rocket::request::{FromRequest, Request}; use rocket::response; @@ -309,6 +309,8 @@ pub enum Error { MissingOrigin, /// The HTTP request header `Origin` could not be parsed correctly. BadOrigin(url::ParseError), + /// The configured Allowed Origin is opaque and cannot be parsed. + OpaqueAllowedOrigin(String), /// The request header `Access-Control-Request-Method` is required but is missing MissingRequestMethod, /// The request header `Access-Control-Request-Method` has an invalid value @@ -316,9 +318,11 @@ pub enum Error { /// The request header `Access-Control-Request-Headers` is required but is missing. MissingRequestHeaders, /// Origin is not allowed to make this request - OriginNotAllowed(url::Origin), + OriginNotAllowed(String), /// Requested method is not allowed MethodNotAllowed(String), + /// A regular expression compilation error + RegexError(regex::Error), /// One or more headers requested are not allowed HeadersNotAllowed, /// Credentials are allowed, but the Origin is set to "*". This is not allowed by W3C @@ -365,7 +369,7 @@ impl fmt::Display for Error { "The request header `Access-Control-Request-Headers` \ is required but is missing") } - Error::OriginNotAllowed(origin) => write!(f, "Origin '{}' is not allowed to request", origin.ascii_serialization()), + Error::OriginNotAllowed(origin) => write!(f, "Origin '{}' is not allowed to request", origin), Error::MethodNotAllowed(method) => write!(f, "Method '{}' is not allowed", &method), Error::HeadersNotAllowed => write!(f, "Headers are not allowed"), Error::CredentialsWithWildcardOrigin => { write!(f, @@ -377,7 +381,9 @@ impl fmt::Display for Error { } Error::MissingInjectedHeader => write!(f, "The `on_response` handler of Fairing could not find the injected header from the \ - Request. Either some other fairing has removed it, or this is a bug.") + Request. Either some other fairing has removed it, or this is a bug."), + Error::OpaqueAllowedOrigin(ref origin) => write!(f, "The configured Origin '{}' not have a parsable Origin. Use a regex instead.", origin), + Error::RegexError(ref e) => write!(f, "{}", e), } } } @@ -404,6 +410,12 @@ impl From for Error { } } +impl From for Error { + fn from(error: regex::Error) -> Self { + Error::RegexError(error) + } +} + /// An enum signifying that some of type T is allowed, or `All` (everything is allowed). /// /// `Default` is implemented for this enum and is `All`. @@ -529,16 +541,47 @@ mod method_serde { /// use rocket_cors::AllowedOrigins; /// /// let all_origins = AllowedOrigins::all(); -/// let some_origins = AllowedOrigins::some(&["https://www.acme.com"]); +/// let some_origins = AllowedOrigins::some_exact(&["https://www.acme.com"]); +/// let null_origins = AllowedOrigins::some_null(); /// ``` -pub type AllowedOrigins = AllOrSome>; +pub type AllowedOrigins = AllOrSome; impl AllowedOrigins { /// Allows some origins /// /// Validation is not performed at this stage, but at a later stage. - pub fn some(urls: &[&str]) -> Self { - AllOrSome::Some(urls.iter().map(|s| s.to_string()).collect()) + pub fn some, S2: AsRef>(exact: &[S1], regex: &[S2]) -> Self { + AllOrSome::Some(Origins { + exact: Some(exact.iter().map(|s| s.as_ref().to_string()).collect()), + regex: Some(regex.iter().map(|s| s.as_ref().to_string()).collect()), + ..Default::default() + }) + } + + /// Allows some _exact_ origins + /// + /// Validation is not performed at this stage, but at a later stage. + pub fn some_exact>(exact: &[S]) -> Self { + AllOrSome::Some(Origins { + exact: Some(exact.iter().map(|s| s.as_ref().to_string()).collect()), + ..Default::default() + }) + } + + /// Allow some __regex__ origins + pub fn some_regex>(regex: &[S]) -> Self { + AllOrSome::Some(Origins { + regex: Some(regex.iter().map(|s| s.as_ref().to_string()).collect()), + ..Default::default() + }) + } + + /// Allow some `null` origins + pub fn some_null() -> Self { + AllOrSome::Some(Origins { + allow_null: true, + ..Default::default() + }) } /// Allows all origins @@ -547,6 +590,105 @@ impl AllowedOrigins { } } +/// Origins that are allowed to make CORS requests. +/// +/// An origin is defined according to the defined +/// [syntax](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Origin). +/// +/// Origins can be specified as an exact match or using regex. +/// +/// These Origins are specified as logical `ORs`. That is, if any of the origins match, the entire +/// request is considered to be valid. +#[derive(Clone, PartialEq, Eq, Debug, Default)] +#[cfg_attr(feature = "serialization", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "serialization", serde(default))] +pub struct Origins { + /// Whether null origins are accepted + #[cfg_attr(feature = "serialization", serde(default))] + pub allow_null: bool, + /// Origins that must be matched exactly as provided. + /// + /// These __must__ be valid URL strings that will be parsed and validated when + /// creating [`Cors`]. + #[cfg_attr(feature = "serialization", serde(default))] + pub exact: Option>, + /// Origins that will be matched via __any__ regex in this list. + /// + /// These __must__ be valid Regex that will be parsed and validated when creating [`Cors`]. + /// + /// The regex will be matched according to the + /// [ASCII serialization](https://html.spec.whatwg.org/multipage/#ascii-serialisation-of-an-origin) + /// of the incoming Origin. + /// + /// For more information on the syntax of Regex in Rust, see the + /// [documentation](https://docs.rs/regex). + #[cfg_attr(feature = "serialization", serde(default))] + pub regex: Option>, +} + +/// Parsed set of configured allowed origins +#[derive(Clone, Debug)] +pub(crate) struct ParsedAllowedOrigins { + pub allow_null: bool, + pub exact: HashSet, + pub regex: Option, +} + +impl ParsedAllowedOrigins { + fn parse(origins: &Origins) -> Result { + let exact: Result, Error> = match &origins.exact { + Some(exact) => exact.iter().map(|url| to_origin(url.as_str())).collect(), + None => Ok(Default::default()), + }; + let exact = exact?; + + // Let's check if any of them is Opaque + exact.iter().try_for_each(|url| { + if !url.is_tuple() { + Err(Error::OpaqueAllowedOrigin(url.ascii_serialization())) + } else { + Ok(()) + } + })?; + + let regex = match &origins.regex { + None => None, + Some(ref regex) => Some(RegexSet::new(regex)?), + }; + + Ok(Self { + allow_null: origins.allow_null, + exact, + regex, + }) + } + + fn verify(&self, origin: &Origin) -> bool { + info_!("Verifying origin: {}", origin); + match origin { + Origin::Null => { + info_!("Origin is null. Allowing? {}", self.allow_null); + self.allow_null + } + Origin::Parsed(ref parsed) => { + // Verify by exact, then regex + if self.exact.get(parsed).is_some() { + info_!("Origin has an exact match"); + return true; + } + if let Some(regex_set) = &self.regex { + let regex_match = regex_set.is_match(&parsed.ascii_serialization()); + info_!("Origin has a regex match? {}", regex_match); + return regex_match; + } + + info!("Origin does not match anything"); + false + } + } + } +} + /// A list of allowed methods /// /// The [list](https://api.rocket.rs/rocket/http/enum.Method.html) @@ -636,9 +778,10 @@ impl AllowedHeaders { /// ```json /// { /// "allowed_origins": { -/// "Some": [ -/// "https://www.acme.com" -/// ] +/// "Some": { +/// "exact": ["https://www.acme.com"], +/// "regex": ["^https://www.example-[A-z0-9]*.com$"] +/// } /// }, /// "allowed_methods": [ /// "POST", @@ -831,9 +974,9 @@ impl CorsOptions { /// documentation at the [crate root](index.html) for usage information. /// /// This struct can be created by using [`CorsOptions::to_cors`] or [`Cors::from_options`]. -#[derive(Clone, Debug, Eq, PartialEq)] +#[derive(Clone, Debug)] pub struct Cors { - pub(crate) allowed_origins: AllOrSome>, + pub(crate) allowed_origins: AllOrSome, pub(crate) allowed_methods: AllowedMethods, pub(crate) allowed_headers: AllOrSome>, pub(crate) allow_credentials: bool, @@ -921,7 +1064,7 @@ impl Cors { /// You can get this struct by using `Cors::validate_request` in an ad-hoc manner. #[derive(Eq, PartialEq, Debug)] pub(crate) struct Response { - allow_origin: Option>, + allow_origin: Option>, allow_methods: HashSet, allow_headers: HeaderFieldNamesSet, allow_credentials: bool, @@ -945,8 +1088,8 @@ impl Response { } /// Consumes the `Response` and return an altered response with origin and `vary_origin` set - fn origin(mut self, origin: &url::Origin, vary_origin: bool) -> Self { - self.allow_origin = Some(AllOrSome::Some(origin.clone())); + fn origin(mut self, origin: &str, vary_origin: bool) -> Self { + self.allow_origin = Some(AllOrSome::Some(origin.to_string())); self.vary_origin = vary_origin; self } @@ -1022,7 +1165,7 @@ impl Response { let origin = match *origin { AllOrSome::All => "*".to_string(), - AllOrSome::Some(ref origin) => origin.ascii_serialization(), + AllOrSome::Some(ref origin) => origin.to_string(), }; let _ = response.set_raw_header("Access-Control-Allow-Origin", origin); @@ -1251,27 +1394,25 @@ enum ValidationResult { None, /// Successful preflight request Preflight { - origin: url::Origin, + origin: String, headers: Option, }, /// Successful actual request - Request { origin: url::Origin }, + Request { origin: String }, } -/// Convert a str to Origin +/// Convert a str to a URL Origin fn to_origin>(origin: S) -> Result { - // What to do about Opaque origins? Ok(url::Url::parse(origin.as_ref())?.origin()) } /// Parse and process allowed origins -fn parse_origins(origins: &AllowedOrigins) -> Result>, Error> { +fn parse_origins(origins: &AllowedOrigins) -> Result, Error> { match origins { AllOrSome::All => Ok(AllOrSome::All), - AllOrSome::Some(ref origins) => { - let parsed: Result, Error> = - origins.iter().map(to_origin).collect(); - Ok(AllOrSome::Some(parsed?)) + AllOrSome::Some(origins) => { + let parsed = ParsedAllowedOrigins::parse(origins)?; + Ok(AllOrSome::Some(parsed)) } } } @@ -1309,14 +1450,14 @@ fn validate(options: &Cors, request: &Request<'_>) -> Result { actual_request_validate(options, &origin)?; Ok(ValidationResult::Request { - origin: origin.deref().clone(), + origin: origin.to_string(), }) } } @@ -1326,16 +1467,19 @@ fn validate(options: &Cors, request: &Request<'_>) -> Result>, + origin: &Origin, + allowed_origins: &AllOrSome, ) -> Result<(), Error> { match *allowed_origins { // Always matching is acceptable since the list of origins can be unbounded. AllOrSome::All => Ok(()), - AllOrSome::Some(ref allowed_origins) => allowed_origins - .get(origin) - .and_then(|_| Some(())) - .ok_or_else(|| Error::OriginNotAllowed(origin.clone())), + AllOrSome::Some(ref allowed_origins) => { + if allowed_origins.verify(origin) { + Ok(()) + } else { + Err(Error::OriginNotAllowed(origin.to_string())) + } + } } } @@ -1405,7 +1549,7 @@ fn request_headers(request: &Request<'_>) -> Result, headers: &Option, ) -> Result<(), Error> { @@ -1453,7 +1597,7 @@ fn preflight_validate( /// and [Fetch specification](https://fetch.spec.whatwg.org/#cors-preflight-fetch). fn preflight_response( options: &Cors, - origin: &url::Origin, + origin: &str, headers: Option<&AccessControlRequestHeaders>, ) -> Response { let response = Response::new(); @@ -1524,7 +1668,7 @@ fn preflight_response( /// This implementation references the /// [W3C recommendation](https://www.w3.org/TR/cors/#resource-requests) /// and [Fetch specification](https://fetch.spec.whatwg.org/#cors-preflight-fetch). -fn actual_request_validate(options: &Cors, origin: &url::Origin) -> Result<(), Error> { +fn actual_request_validate(options: &Cors, origin: &Origin) -> Result<(), Error> { // Note: All header parse failures are dealt with in the `FromRequest` trait implementation // 2. If the value of the Origin header is not a case-sensitive match for any of the values @@ -1541,7 +1685,7 @@ fn actual_request_validate(options: &Cors, origin: &url::Origin) -> Result<(), E /// This implementation references the /// [W3C recommendation](https://www.w3.org/TR/cors/#resource-requests) /// and [Fetch specification](https://fetch.spec.whatwg.org/#cors-preflight-fetch) -fn actual_request_response(options: &Cors, origin: &url::Origin) -> Response { +fn actual_request_response(options: &Cors, origin: &str) -> Response { let response = Response::new(); // 3. If the resource supports credentials add a single Access-Control-Allow-Origin header, @@ -1640,8 +1784,12 @@ mod tests { use super::*; use crate::http::Method; + fn to_parsed_origin>(origin: S) -> Result { + Origin::from_str(origin.as_ref()) + } + fn make_cors_options() -> CorsOptions { - let allowed_origins = AllowedOrigins::some(&["https://www.acme.com"]); + let allowed_origins = AllowedOrigins::some_exact(&["https://www.acme.com"]); CorsOptions { allowed_origins, @@ -1652,8 +1800,8 @@ mod tests { allowed_headers: AllowedHeaders::some(&[&"Authorization", "Accept"]), allow_credentials: true, expose_headers: ["Content-Type", "X-Custom"] - .into_iter() - .map(|s| s.to_string().into()) + .iter() + .map(|s| s.to_string()) .collect(), ..Default::default() } @@ -1727,6 +1875,64 @@ mod tests { fn cors_default_deserialization_is_correct() { let deserialized: CorsOptions = serde_json::from_str("{}").expect("To not fail"); assert_eq!(deserialized, CorsOptions::default()); + + let expected_json = r#" +{ + "allowed_origins": "All", + "allowed_methods": [ + "POST", + "PATCH", + "PUT", + "DELETE", + "HEAD", + "OPTIONS", + "GET" + ], + "allowed_headers": "All", + "allow_credentials": false, + "expose_headers": [], + "max_age": null, + "send_wildcard": false, + "fairing_route_base": "/cors", + "fairing_route_rank": 0 +} +"#; + let actual: CorsOptions = serde_json::from_str(expected_json).expect("to not fail"); + assert_eq!(actual, CorsOptions::default()); + } + + /// Checks that the example provided can actually be deserialized + #[cfg(feature = "serialization")] + #[test] + fn cors_options_example_can_be_deserialized() { + let json = r#"{ + "allowed_origins": { + "Some": { + "exact": ["https://www.acme.com"], + "regex": ["^https://www.example-[A-z0-9]*.com$"] + } + }, + "allowed_methods": [ + "POST", + "DELETE", + "GET" + ], + "allowed_headers": { + "Some": [ + "Accept", + "Authorization" + ] + }, + "allow_credentials": true, + "expose_headers": [ + "Content-Type", + "X-Custom" + ], + "max_age": 42, + "send_wildcard": false, + "fairing_route_base": "/mycors" +}"#; + let _: CorsOptions = serde_json::from_str(json).expect("to not fail"); } // The following tests check validation @@ -1734,7 +1940,7 @@ mod tests { #[test] fn validate_origin_allows_all_origins() { let url = "https://www.example.com"; - let origin = not_err!(to_origin(&url)); + let origin = not_err!(to_parsed_origin(&url)); let allowed_origins = AllOrSome::All; not_err!(validate_origin(&origin, &allowed_origins)); @@ -1743,8 +1949,8 @@ mod tests { #[test] fn validate_origin_allows_origin() { let url = "https://www.example.com"; - let origin = not_err!(to_origin(&url)); - let allowed_origins = not_err!(parse_origins(&AllowedOrigins::some(&[ + let origin = not_err!(to_parsed_origin(&url)); + let allowed_origins = not_err!(parse_origins(&AllowedOrigins::some_exact(&[ "https://www.example.com" ]))); @@ -1762,19 +1968,48 @@ mod tests { ]; for (url, allowed_origin) in cases { - let origin = not_err!(to_origin(&url)); - let allowed_origins = not_err!(parse_origins(&AllowedOrigins::some(&[allowed_origin]))); + let origin = not_err!(to_parsed_origin(&url)); + let allowed_origins = not_err!(parse_origins(&AllowedOrigins::some_exact(&[ + allowed_origin + ]))); not_err!(validate_origin(&origin, &allowed_origins)); } } + #[test] + fn validate_origin_validates_regex() { + let url = "https://www.example-something.com"; + let origin = not_err!(to_parsed_origin(&url)); + let allowed_origins = not_err!(parse_origins(&AllowedOrigins::some_regex(&[ + "^https://www.example-[A-z0-9]+.com$" + ]))); + + not_err!(validate_origin(&origin, &allowed_origins)); + } + + #[test] + fn validate_origin_validates_mixed_settings() { + let allowed_origins = not_err!(parse_origins(&AllowedOrigins::some( + &["https://www.acme.com"], + &["^https://www.example-[A-z0-9]+.com$"] + ))); + + let url = "https://www.example-something123.com"; + let origin = not_err!(to_parsed_origin(&url)); + not_err!(validate_origin(&origin, &allowed_origins)); + + let url = "https://www.acme.com"; + let origin = not_err!(to_parsed_origin(&url)); + not_err!(validate_origin(&origin, &allowed_origins)); + } + #[test] #[should_panic(expected = "OriginNotAllowed")] fn validate_origin_rejects_invalid_origin() { let url = "https://www.acme.com"; - let origin = not_err!(to_origin(&url)); - let allowed_origins = not_err!(parse_origins(&AllowedOrigins::some(&[ + let origin = not_err!(to_parsed_origin(&url)); + let allowed_origins = not_err!(parse_origins(&AllowedOrigins::some_exact(&[ "https://www.example.com" ]))); @@ -1784,7 +2019,7 @@ mod tests { #[test] fn response_sets_allow_origin_without_vary_correctly() { let response = Response::new(); - let response = response.origin(&to_origin("https://www.example.com").unwrap(), false); + let response = response.origin("https://www.example.com", false); // Build response and check built response header let expected_header = vec!["https://www.example.com"]; @@ -1801,7 +2036,7 @@ mod tests { #[test] fn response_sets_allow_origin_with_vary_correctly() { let response = Response::new(); - let response = response.origin(&to_origin("https://www.example.com").unwrap(), true); + let response = response.origin("https://www.example.com", true); // Build response and check built response header let expected_header = vec!["https://www.example.com"]; @@ -1828,27 +2063,11 @@ mod tests { assert_eq!(expected_header, actual_header); } - #[test] - fn response_sets_allow_origin_with_ascii_serialization() { - let response = Response::new(); - let response = response.origin(&to_origin("https://аpple.com").unwrap(), false); - - // Build response and check built response header - // This is "punycode" - let expected_header = vec!["https://xn--pple-43d.com"]; - let response = response.response(response::Response::new()); - let actual_header: Vec<_> = response - .headers() - .get("Access-Control-Allow-Origin") - .collect(); - assert_eq!(expected_header, actual_header); - } - #[test] fn response_sets_exposed_headers_correctly() { let headers = vec!["Bar", "Baz", "Foo"]; let response = Response::new(); - let response = response.origin(&to_origin("https://www.example.com").unwrap(), false); + let response = response.origin("https://www.example.com", false); let response = response.exposed_headers(&headers); // Build response and check built response header @@ -1870,7 +2089,7 @@ mod tests { #[test] fn response_sets_max_age_correctly() { let response = Response::new(); - let response = response.origin(&to_origin("https://www.example.com").unwrap(), false); + let response = response.origin("https://www.example.com", false); let response = response.max_age(Some(42)); @@ -1884,7 +2103,7 @@ mod tests { #[test] fn response_does_not_set_max_age_when_none() { let response = Response::new(); - let response = response.origin(&to_origin("https://www.example.com").unwrap(), false); + let response = response.origin("https://www.example.com", false); let response = response.max_age(None); @@ -1997,7 +2216,7 @@ mod tests { .finalize(); let response = Response::new(); - let response = response.origin(&to_origin("https://www.example.com").unwrap(), false); + let response = response.origin("https://www.example.com", false); let response = response.response(original); // Check CORS header let expected_header = vec!["https://www.example.com"]; @@ -2073,7 +2292,7 @@ mod tests { let result = validate(&cors, request.inner()).expect("to not fail"); let expected_result = ValidationResult::Preflight { - origin: to_origin("https://www.acme.com").unwrap(), + origin: "https://www.acme.com".to_string(), // Checks that only a subset of allowed headers are returned // -- i.e. whatever is requested for headers: Some(FromStr::from_str("Authorization").unwrap()), @@ -2108,7 +2327,7 @@ mod tests { let result = validate(&cors, request.inner()).expect("to not fail"); let expected_result = ValidationResult::Preflight { - origin: to_origin("https://www.example.com").unwrap(), + origin: "https://www.example.com".to_string(), headers: Some(FromStr::from_str("Authorization").unwrap()), }; @@ -2226,7 +2445,7 @@ mod tests { let result = validate(&cors, request.inner()).expect("to not fail"); let expected_result = ValidationResult::Request { - origin: to_origin("https://www.acme.com").unwrap(), + origin: "https://www.acme.com".to_string(), }; assert_eq!(expected_result, result); @@ -2245,7 +2464,7 @@ mod tests { let result = validate(&cors, request.inner()).expect("to not fail"); let expected_result = ValidationResult::Request { - origin: to_origin("https://www.example.com").unwrap(), + origin: "https://www.example.com".to_string(), }; assert_eq!(expected_result, result); @@ -2301,7 +2520,7 @@ mod tests { let response = validate_and_build(&cors, request.inner()).expect("to not fail"); let expected_response = Response::new() - .origin(&to_origin("https://www.acme.com").unwrap(), false) + .origin("https://www.acme.com", false) .headers(&["Authorization"]) .methods(&options.allowed_methods) .credentials(options.allow_credentials) @@ -2341,7 +2560,7 @@ mod tests { let response = validate_and_build(&cors, request.inner()).expect("to not fail"); let expected_response = Response::new() - .origin(&to_origin("https://www.acme.com").unwrap(), true) + .origin("https://www.acme.com", true) .headers(&["Authorization"]) .methods(&options.allowed_methods) .credentials(options.allow_credentials) @@ -2402,7 +2621,7 @@ mod tests { let response = validate_and_build(&cors, request.inner()).expect("to not fail"); let expected_response = Response::new() - .origin(&to_origin("https://www.acme.com").unwrap(), false) + .origin("https://www.acme.com", false) .credentials(options.allow_credentials) .exposed_headers(&["Content-Type", "X-Custom"]); @@ -2425,7 +2644,7 @@ mod tests { let response = validate_and_build(&cors, request.inner()).expect("to not fail"); let expected_response = Response::new() - .origin(&to_origin("https://www.acme.com").unwrap(), true) + .origin("https://www.acme.com", true) .credentials(options.allow_credentials) .exposed_headers(&["Content-Type", "X-Custom"]); diff --git a/tests/fairing.rs b/tests/fairing.rs index 509ec9b..055e015 100644 --- a/tests/fairing.rs +++ b/tests/fairing.rs @@ -1,14 +1,14 @@ //! This crate tests using `rocket_cors` using Fairings #![feature(proc_macro_hygiene, decl_macro)] use hyper; -#[macro_use] -extern crate rocket; use std::str::FromStr; use rocket::http::Method; use rocket::http::{Header, Status}; use rocket::local::Client; +use rocket::response::Body; +use rocket::{get, routes}; use rocket_cors::*; #[get("/")] @@ -22,10 +22,10 @@ fn panicking_route() { } fn make_cors() -> Cors { - let allowed_origins = AllowedOrigins::some(&["https://www.acme.com"]); + let allowed_origins = AllowedOrigins::some_exact(&["https://www.acme.com"]); CorsOptions { - allowed_origins: allowed_origins, + allowed_origins, allowed_methods: vec![Method::Get].into_iter().map(From::from).collect(), allowed_headers: AllowedHeaders::some(&["Authorization", "Accept"]), allow_credentials: true, @@ -73,7 +73,7 @@ fn smoke_test() { let mut response = req.dispatch(); assert!(response.status().class().is_success()); - let body_str = response.body().and_then(|body| body.into_string()); + let body_str = response.body().and_then(Body::into_string); assert_eq!(body_str, Some("Hello CORS".to_string())); let origin_header = response @@ -124,7 +124,7 @@ fn cors_get_check() { let mut response = req.dispatch(); assert!(response.status().class().is_success()); - let body_str = response.body().and_then(|body| body.into_string()); + let body_str = response.body().and_then(Body::into_string); assert_eq!(body_str, Some("Hello CORS".to_string())); let origin_header = response @@ -144,7 +144,7 @@ fn cors_get_no_origin() { let mut response = req.dispatch(); assert!(response.status().class().is_success()); - let body_str = response.body().and_then(|body| body.into_string()); + let body_str = response.body().and_then(Body::into_string); assert_eq!(body_str, Some("Hello CORS".to_string())); } diff --git a/tests/guard.rs b/tests/guard.rs index eb30da4..d929616 100644 --- a/tests/guard.rs +++ b/tests/guard.rs @@ -1,8 +1,6 @@ //! This crate tests using `rocket_cors` using the per-route handling with request guard #![feature(proc_macro_hygiene, decl_macro)] use hyper; -#[macro_use] -extern crate rocket; use rocket_cors as cors; use std::str::FromStr; @@ -10,6 +8,8 @@ use std::str::FromStr; use rocket::http::Method; use rocket::http::{Header, Status}; use rocket::local::Client; +use rocket::response::Body; +use rocket::{get, options, routes}; use rocket::{Response, State}; #[get("/")] @@ -60,10 +60,10 @@ fn state<'r>(cors: cors::Guard<'r>, _state: State<'r, SomeState>) -> cors::Respo } fn make_cors() -> cors::Cors { - let allowed_origins = cors::AllowedOrigins::some(&["https://www.acme.com"]); + let allowed_origins = cors::AllowedOrigins::some_exact(&["https://www.acme.com"]); cors::CorsOptions { - allowed_origins: allowed_origins, + allowed_origins, allowed_methods: vec![Method::Get].into_iter().map(From::from).collect(), allowed_headers: cors::AllowedHeaders::some(&["Authorization", "Accept"]), allow_credentials: true, @@ -119,7 +119,7 @@ fn smoke_test() { let mut response = req.dispatch(); assert!(response.status().class().is_success()); - let body_str = response.body().and_then(|body| body.into_string()); + let body_str = response.body().and_then(Body::into_string); assert_eq!(body_str, Some("Hello CORS".to_string())); let origin_header = response @@ -205,7 +205,7 @@ fn cors_get_check() { let mut response = req.dispatch(); assert!(response.status().class().is_success()); - let body_str = response.body().and_then(|body| body.into_string()); + let body_str = response.body().and_then(Body::into_string); assert_eq!(body_str, Some("Hello CORS".to_string())); let origin_header = response @@ -226,7 +226,7 @@ fn cors_get_no_origin() { let mut response = req.dispatch(); assert!(response.status().class().is_success()); - let body_str = response.body().and_then(|body| body.into_string()); + let body_str = response.body().and_then(Body::into_string); assert_eq!(body_str, Some("Hello CORS".to_string())); assert!(response .headers() @@ -408,7 +408,7 @@ fn overridden_options_routes_are_used() { .header(request_headers); let mut response = req.dispatch(); - let body_str = response.body().and_then(|body| body.into_string()); + let body_str = response.body().and_then(Body::into_string); assert!(response.status().class().is_success()); assert_eq!(body_str, Some("Manual CORS Preflight".to_string())); diff --git a/tests/headers.rs b/tests/headers.rs index c6b63f8..603c25c 100644 --- a/tests/headers.rs +++ b/tests/headers.rs @@ -1,14 +1,14 @@ //! This crate tests that all the request headers are parsed correctly in the round trip #![feature(proc_macro_hygiene, decl_macro)] use hyper; -#[macro_use] -extern crate rocket; use std::ops::Deref; use std::str::FromStr; use rocket::http::Header; use rocket::local::Client; +use rocket::response::Body; +use rocket::{get, routes}; use rocket_cors::headers::*; #[get("/request_headers")] @@ -53,7 +53,7 @@ fn request_headers_round_trip_smoke_test() { assert!(response.status().class().is_success()); let body_str = response .body() - .and_then(|body| body.into_string()) + .and_then(Body::into_string) .expect("Non-empty body"); let expected_body = r#"https://foo.bar.xyz GET diff --git a/tests/manual.rs b/tests/manual.rs index 27cb58d..b7f4147 100644 --- a/tests/manual.rs +++ b/tests/manual.rs @@ -1,16 +1,16 @@ //! This crate tests using `rocket_cors` using manual mode #![feature(proc_macro_hygiene, decl_macro)] use hyper; -#[macro_use] -extern crate rocket; use std::str::FromStr; use rocket::http::Method; use rocket::http::{Header, Status}; use rocket::local::Client; +use rocket::response::Body; use rocket::response::Responder; use rocket::State; +use rocket::{get, options, routes}; use rocket_cors::*; /// Using a borrowed `Cors` @@ -23,7 +23,7 @@ fn cors(options: State<'_, Cors>) -> impl Responder<'_> { #[get("/panic")] fn panicking_route(options: State<'_, Cors>) -> impl Responder<'_> { - options.inner().respond_borrowed(|_| -> () { + options.inner().respond_borrowed(|_| { panic!("This route will panic"); }) } @@ -66,10 +66,10 @@ fn borrow<'r>(options: State<'r, Cors>, test_state: State<'r, TestState>) -> imp } fn make_cors_options() -> CorsOptions { - let allowed_origins = AllowedOrigins::some(&["https://www.acme.com"]); + let allowed_origins = AllowedOrigins::some_exact(&["https://www.acme.com"]); CorsOptions { - allowed_origins: allowed_origins, + allowed_origins, allowed_methods: vec![Method::Get].into_iter().map(From::from).collect(), allowed_headers: AllowedHeaders::some(&["Authorization", "Accept"]), allow_credentials: true, @@ -78,10 +78,10 @@ fn make_cors_options() -> CorsOptions { } fn make_different_cors_options() -> CorsOptions { - let allowed_origins = AllowedOrigins::some(&["https://www.example.com"]); + let allowed_origins = AllowedOrigins::some_exact(&["https://www.example.com"]); CorsOptions { - allowed_origins: allowed_origins, + allowed_origins, allowed_methods: vec![Method::Get].into_iter().map(From::from).collect(), allowed_headers: AllowedHeaders::some(&["Authorization", "Accept"]), allow_credentials: true, @@ -129,7 +129,7 @@ fn smoke_test() { let mut response = req.dispatch(); assert!(response.status().class().is_success()); - let body_str = response.body().and_then(|body| body.into_string()); + let body_str = response.body().and_then(Body::into_string); assert_eq!(body_str, Some("Hello CORS".to_string())); let origin_header = response @@ -180,7 +180,7 @@ fn cors_get_borrowed_check() { let mut response = req.dispatch(); assert!(response.status().class().is_success()); - let body_str = response.body().and_then(|body| body.into_string()); + let body_str = response.body().and_then(Body::into_string); assert_eq!(body_str, Some("Hello CORS".to_string())); let origin_header = response @@ -200,7 +200,7 @@ fn cors_get_no_origin() { let mut response = req.dispatch(); assert!(response.status().class().is_success()); - let body_str = response.body().and_then(|body| body.into_string()); + let body_str = response.body().and_then(Body::into_string); assert_eq!(body_str, Some("Hello CORS".to_string())); } @@ -378,7 +378,7 @@ fn cors_options_owned_check() { .header(request_headers); let mut response = req.dispatch(); - let body_str = response.body().and_then(|body| body.into_string()); + let body_str = response.body().and_then(Body::into_string); assert!(response.status().class().is_success()); assert_eq!(body_str, Some("Manual CORS Preflight".to_string())); @@ -404,7 +404,7 @@ fn cors_get_owned_check() { let mut response = req.dispatch(); assert!(response.status().class().is_success()); - let body_str = response.body().and_then(|body| body.into_string()); + let body_str = response.body().and_then(Body::into_string); assert_eq!(body_str, Some("Hello CORS Owned".to_string())); let origin_header = response diff --git a/tests/mix.rs b/tests/mix.rs index 87a969b..08ffbf0 100644 --- a/tests/mix.rs +++ b/tests/mix.rs @@ -4,15 +4,15 @@ //! `ping` route that you want to allow all Origins to access. #![feature(proc_macro_hygiene, decl_macro)] use hyper; -#[macro_use] -extern crate rocket; use rocket_cors; use std::str::FromStr; use rocket::http::{Header, Method, Status}; use rocket::local::Client; +use rocket::response::Body; use rocket::response::Responder; +use rocket::{get, options, routes}; use rocket_cors::{AllowedHeaders, AllowedOrigins, CorsOptions, Guard}; @@ -40,11 +40,11 @@ fn ping_options<'r>() -> impl Responder<'r> { /// Returns the "application wide" Cors struct fn cors_options() -> CorsOptions { - let allowed_origins = AllowedOrigins::some(&["https://www.acme.com"]); + let allowed_origins = AllowedOrigins::some_exact(&["https://www.acme.com"]); // You can also deserialize this rocket_cors::CorsOptions { - allowed_origins: allowed_origins, + allowed_origins, allowed_methods: vec![Method::Get].into_iter().map(From::from).collect(), allowed_headers: AllowedHeaders::some(&["Authorization", "Accept"]), allow_credentials: true, @@ -100,7 +100,7 @@ fn smoke_test() { let mut response = req.dispatch(); assert!(response.status().class().is_success()); - let body_str = response.body().and_then(|body| body.into_string()); + let body_str = response.body().and_then(Body::into_string); assert_eq!(body_str, Some("Hello CORS!".to_string())); let origin_header = response @@ -151,7 +151,7 @@ fn cors_get_check() { let mut response = req.dispatch(); assert!(response.status().class().is_success()); - let body_str = response.body().and_then(|body| body.into_string()); + let body_str = response.body().and_then(Body::into_string); assert_eq!(body_str, Some("Hello CORS!".to_string())); let origin_header = response @@ -171,7 +171,7 @@ fn cors_get_no_origin() { let mut response = req.dispatch(); assert!(response.status().class().is_success()); - let body_str = response.body().and_then(|body| body.into_string()); + let body_str = response.body().and_then(Body::into_string); assert_eq!(body_str, Some("Hello CORS!".to_string())); } @@ -333,7 +333,7 @@ fn cors_get_ping_check() { let mut response = req.dispatch(); assert!(response.status().class().is_success()); - let body_str = response.body().and_then(|body| body.into_string()); + let body_str = response.body().and_then(Body::into_string); assert_eq!(body_str, Some("Pong!".to_string())); let origin_header = response