From bc16568e8bbf3789a493c2cd374fd9939ead2794 Mon Sep 17 00:00:00 2001 From: Yong Wen Chua Date: Tue, 12 Mar 2019 13:09:15 +0800 Subject: [PATCH] Refactor Origin --- Cargo.toml | 2 +- README.md | 3 +- src/fairing.rs | 1 + src/headers.rs | 32 +++++++---- src/lib.rs | 153 ++++++++++++++++++++++++++++++++++++++----------- 5 files changed, 144 insertions(+), 47 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 2ae7377..69c9221 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "rocket_cors" -version = "0.4.0" +version = "0.5.0" license = "MIT/Apache-2.0" authors = ["Yong Wen Chua "] description = "Cross-origin resource sharing (CORS) for Rocket.rs applications" diff --git a/README.md b/README.md index d63cdf3..5d2c2d0 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,6 @@ # rocket_cors [![Build Status](https://travis-ci.org/lawliet89/rocket_cors.svg)](https://travis-ci.org/lawliet89/rocket_cors) -[![Dependency Status](https://dependencyci.com/github/lawliet89/rocket_cors/badge)](https://dependencyci.com/github/lawliet89/rocket_cors) [![Repository](https://img.shields.io/github/tag/lawliet89/rocket_cors.svg)](https://github.com/lawliet89/rocket_cors) [![Crates.io](https://img.shields.io/crates/v/rocket_cors.svg)](https://crates.io/crates/rocket_cors) @@ -31,7 +30,7 @@ work, but they are subject to the minimum that Rocket sets. Add the following to Cargo.toml: ```toml -rocket_cors = "0.4.0" +rocket_cors = "0.5.0" ``` To use the latest `master` branch, for example: diff --git a/src/fairing.rs b/src/fairing.rs index 7f497dd..5fe40b2 100644 --- a/src/fairing.rs +++ b/src/fairing.rs @@ -63,6 +63,7 @@ fn on_response_wrapper( return Ok(()); } + let origin = origin.to_string(); let cors_response = if request.method() == http::Method::Options { let headers = request_headers(request)?; preflight_response(options, &origin, headers.as_ref()) diff --git a/src/headers.rs b/src/headers.rs index 6e819e9..501e008 100644 --- a/src/headers.rs +++ b/src/headers.rs @@ -63,28 +63,34 @@ pub type HeaderFieldNamesSet = HashSet; /// /// You can use this as a rocket [Request Guard](https://rocket.rs/guide/requests/#request-guards) /// to ensure that `Origin` is passed in correctly. +/// +/// Reference: [Mozilla](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Origin) #[derive(Eq, PartialEq, Clone, Hash, Debug)] -pub struct Origin(pub url::Origin); +pub enum Origin { + /// A `null` Origin + Null, + /// A well-formed origin that was parsed by [`url::Url::origin`] + Parsed(url::Origin), +} impl FromStr for Origin { type Err = crate::Error; fn from_str(input: &str) -> Result { - Ok(Origin(crate::to_origin(input)?)) - } -} - -impl Deref for Origin { - type Target = url::Origin; - - fn deref(&self) -> &Self::Target { - &self.0 + if input.to_lowercase() == "null" { + Ok(Origin::Null) + } else { + Ok(Origin::Parsed(crate::to_origin(input)?)) + } } } impl fmt::Display for Origin { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{}", self.ascii_serialization()) + match self { + Origin::Null => write!(f, "null"), + Origin::Parsed(ref parsed) => write!(f, "{}", parsed.ascii_serialization()), + } } } @@ -195,6 +201,10 @@ mod tests { let parsed = not_err!(Origin::from_str(url)); assert_eq!(parsed.ascii_serialization(), url); + let url = "https://foo.bar.xyz:1234"; + let parsed = not_err!(Origin::from_str(url)); + assert_eq!(parsed.ascii_serialization(), url); + // this should never really be sent by a compliant user agent let url = "https://foo.bar.xyz/path/somewhere"; let parsed = not_err!(Origin::from_str(url)); diff --git a/src/lib.rs b/src/lib.rs index 0378016..1145096 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,6 +1,5 @@ /*! [![Build Status](https://travis-ci.org/lawliet89/rocket_cors.svg)](https://travis-ci.org/lawliet89/rocket_cors) -[![Dependency Status](https://dependencyci.com/github/lawliet89/rocket_cors/badge)](https://dependencyci.com/github/lawliet89/rocket_cors) [![Repository](https://img.shields.io/github/tag/lawliet89/rocket_cors.svg)](https://github.com/lawliet89/rocket_cors) [![Crates.io](https://img.shields.io/crates/v/rocket_cors.svg)](https://crates.io/crates/rocket_cors) @@ -30,7 +29,7 @@ might work, but they are subject to the minimum that Rocket sets. Add the following to Cargo.toml: ```toml -rocket_cors = "0.4.0" +rocket_cors = "0.5.0" ``` To use the latest `master` branch, for example: @@ -46,7 +45,7 @@ the [`CorsOptions`] struct that is described below. If you would like to disable change your `Cargo.toml` to: ```toml -rocket_cors = { version = "0.4.0", default-features = false } +rocket_cors = { version = "0.5.0", default-features = false } ``` ## Usage @@ -316,7 +315,7 @@ pub enum Error { /// The request header `Access-Control-Request-Headers` is required but is missing. MissingRequestHeaders, /// Origin is not allowed to make this request - OriginNotAllowed(url::Origin), + OriginNotAllowed(String), /// Requested method is not allowed MethodNotAllowed(String), /// One or more headers requested are not allowed @@ -365,7 +364,7 @@ impl fmt::Display for Error { "The request header `Access-Control-Request-Headers` \ is required but is missing") } - Error::OriginNotAllowed(origin) => write!(f, "Origin '{}' is not allowed to request", origin.ascii_serialization()), + Error::OriginNotAllowed(origin) => write!(f, "Origin '{}' is not allowed to request", origin), Error::MethodNotAllowed(method) => write!(f, "Method '{}' is not allowed", &method), Error::HeadersNotAllowed => write!(f, "Headers are not allowed"), Error::CredentialsWithWildcardOrigin => { write!(f, @@ -529,16 +528,44 @@ mod method_serde { /// use rocket_cors::AllowedOrigins; /// /// let all_origins = AllowedOrigins::all(); -/// let some_origins = AllowedOrigins::some(&["https://www.acme.com"]); +/// let some_origins = AllowedOrigins::some_exact(&["https://www.acme.com"]); +/// let null_origins = AllowedOrigins::some_null(); /// ``` -pub type AllowedOrigins = AllOrSome>; +pub type AllowedOrigins = AllOrSome; impl AllowedOrigins { - /// Allows some origins + /// Allows some _exact_ origins /// /// Validation is not performed at this stage, but at a later stage. + #[deprecated(since = "0.5.0", note = "use `some_exact` instead")] pub fn some(urls: &[&str]) -> Self { - AllOrSome::Some(urls.iter().map(|s| s.to_string()).collect()) + Self::some_exact(urls) + } + + /// Allows some _exact_ origins + /// + /// Validation is not performed at this stage, but at a later stage. + pub fn some_exact>(urls: &[S]) -> Self { + AllOrSome::Some(Origins { + exact: Some(urls.iter().map(|s| s.as_ref().to_string()).collect()), + ..Default::default() + }) + } + + /// Allow some __regex__ origins + pub fn some_regex>(regex: &[S]) -> Self { + AllOrSome::Some(Origins { + regex: Some(regex.iter().map(|s| s.as_ref().to_string()).collect()), + ..Default::default() + }) + } + + /// Allow some `null` origins + pub fn some_null() -> Self { + AllOrSome::Some(Origins { + allow_null: true, + ..Default::default() + }) } /// Allows all origins @@ -547,6 +574,53 @@ impl AllowedOrigins { } } +/// A list of allows origins +/// +/// An origin is defined according +/// [syntax](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Origin) defined here. +/// +/// Origins can be specified as an exact match or via some other supported way according to the +/// fields of the struct. +/// +/// These Origins are specified as logical `ORs`. That is, if any of the origins match, the entire +/// request is considered to be valid. +#[derive(Clone, PartialEq, Eq, Debug, Default)] +#[cfg_attr(feature = "serialization", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "serialization", serde(default))] +pub struct Origins { + /// Whether null origins are accepted + #[cfg_attr(feature = "serialization", serde(default))] + pub allow_null: bool, + /// Origins that must be matched exactly as below. These __must__ be valid URL strings that will + /// be parsed and validated when creating [`Cors`]. + #[cfg_attr(feature = "serialization", serde(default))] + pub exact: Option>, + /// Origins that will be matched via __any__ regex in this list. These __must__ be valid Regex + /// that will be parsed and validated when creating [`Cors`]. + #[cfg_attr(feature = "serialization", serde(default))] + pub regex: Option>, +} + +/// Parsed set of configured allowed origins +#[derive(Clone, Debug, Eq, PartialEq)] +pub(crate) struct ParsedAllowedOrigins { + pub allow_null: bool, + pub exact: HashSet, +} + +impl ParsedAllowedOrigins { + fn parse(origins: &Origins) -> Result { + let exact: Result<_, Error> = match &origins.exact { + Some(exact) => exact.iter().map(|url| to_origin(url.as_str())).collect(), + None => Ok(Default::default()), + }; + Ok(Self { + allow_null: origins.allow_null, + exact: exact?, + }) + } +} + /// A list of allowed methods /// /// The [list](https://api.rocket.rs/rocket/http/enum.Method.html) @@ -833,7 +907,7 @@ impl CorsOptions { /// This struct can be created by using [`CorsOptions::to_cors`] or [`Cors::from_options`]. #[derive(Clone, Debug, Eq, PartialEq)] pub struct Cors { - pub(crate) allowed_origins: AllOrSome>, + pub(crate) allowed_origins: AllOrSome, pub(crate) allowed_methods: AllowedMethods, pub(crate) allowed_headers: AllOrSome>, pub(crate) allow_credentials: bool, @@ -921,7 +995,7 @@ impl Cors { /// You can get this struct by using `Cors::validate_request` in an ad-hoc manner. #[derive(Eq, PartialEq, Debug)] pub(crate) struct Response { - allow_origin: Option>, + allow_origin: Option>, allow_methods: HashSet, allow_headers: HeaderFieldNamesSet, allow_credentials: bool, @@ -945,8 +1019,8 @@ impl Response { } /// Consumes the `Response` and return an altered response with origin and `vary_origin` set - fn origin(mut self, origin: &url::Origin, vary_origin: bool) -> Self { - self.allow_origin = Some(AllOrSome::Some(origin.clone())); + fn origin(mut self, origin: &str, vary_origin: bool) -> Self { + self.allow_origin = Some(AllOrSome::Some(origin.to_string())); self.vary_origin = vary_origin; self } @@ -1022,7 +1096,7 @@ impl Response { let origin = match *origin { AllOrSome::All => "*".to_string(), - AllOrSome::Some(ref origin) => origin.ascii_serialization(), + AllOrSome::Some(ref origin) => origin.to_string(), }; let _ = response.set_raw_header("Access-Control-Allow-Origin", origin); @@ -1251,11 +1325,11 @@ enum ValidationResult { None, /// Successful preflight request Preflight { - origin: url::Origin, + origin: String, headers: Option, }, /// Successful actual request - Request { origin: url::Origin }, + Request { origin: String }, } /// Convert a str to Origin @@ -1265,13 +1339,12 @@ fn to_origin>(origin: S) -> Result { } /// Parse and process allowed origins -fn parse_origins(origins: &AllowedOrigins) -> Result>, Error> { +fn parse_origins(origins: &AllowedOrigins) -> Result, Error> { match origins { AllOrSome::All => Ok(AllOrSome::All), - AllOrSome::Some(ref origins) => { - let parsed: Result, Error> = - origins.iter().map(to_origin).collect(); - Ok(AllOrSome::Some(parsed?)) + AllOrSome::Some(origins) => { + let parsed = ParsedAllowedOrigins::parse(origins)?; + Ok(AllOrSome::Some(parsed)) } } } @@ -1309,14 +1382,14 @@ fn validate(options: &Cors, request: &Request<'_>) -> Result { actual_request_validate(options, &origin)?; Ok(ValidationResult::Request { - origin: origin.deref().clone(), + origin: origin.to_string(), }) } } @@ -1326,16 +1399,30 @@ fn validate(options: &Cors, request: &Request<'_>) -> Result>, + origin: &Origin, + allowed_origins: &AllOrSome, ) -> Result<(), Error> { match *allowed_origins { // Always matching is acceptable since the list of origins can be unbounded. AllOrSome::All => Ok(()), - AllOrSome::Some(ref allowed_origins) => allowed_origins - .get(origin) - .and_then(|_| Some(())) - .ok_or_else(|| Error::OriginNotAllowed(origin.clone())), + // AllOrSome::Some(ref allowed_origins) => allowed_origins + // .get(origin) + // .and_then(|_| Some(())) + // .ok_or_else(|| Error::OriginNotAllowed(origin.clone())), + AllOrSome::Some(ref allowed_origins) => match origin { + Origin::Null => { + if allowed_origins.allow_null { + Ok(()) + } else { + Err(Error::OriginNotAllowed(origin.to_string())) + } + } + Origin::Parsed(ref parsed) => allowed_origins + .exact + .get(parsed) + .and_then(|_| Some(())) + .ok_or_else(|| Error::OriginNotAllowed(origin.to_string())), + }, } } @@ -1405,7 +1492,7 @@ fn request_headers(request: &Request<'_>) -> Result, headers: &Option, ) -> Result<(), Error> { @@ -1453,7 +1540,7 @@ fn preflight_validate( /// and [Fetch specification](https://fetch.spec.whatwg.org/#cors-preflight-fetch). fn preflight_response( options: &Cors, - origin: &url::Origin, + origin: &str, headers: Option<&AccessControlRequestHeaders>, ) -> Response { let response = Response::new(); @@ -1524,7 +1611,7 @@ fn preflight_response( /// This implementation references the /// [W3C recommendation](https://www.w3.org/TR/cors/#resource-requests) /// and [Fetch specification](https://fetch.spec.whatwg.org/#cors-preflight-fetch). -fn actual_request_validate(options: &Cors, origin: &url::Origin) -> Result<(), Error> { +fn actual_request_validate(options: &Cors, origin: &Origin) -> Result<(), Error> { // Note: All header parse failures are dealt with in the `FromRequest` trait implementation // 2. If the value of the Origin header is not a case-sensitive match for any of the values @@ -1541,7 +1628,7 @@ fn actual_request_validate(options: &Cors, origin: &url::Origin) -> Result<(), E /// This implementation references the /// [W3C recommendation](https://www.w3.org/TR/cors/#resource-requests) /// and [Fetch specification](https://fetch.spec.whatwg.org/#cors-preflight-fetch) -fn actual_request_response(options: &Cors, origin: &url::Origin) -> Response { +fn actual_request_response(options: &Cors, origin: &str) -> Response { let response = Response::new(); // 3. If the resource supports credentials add a single Access-Control-Allow-Origin header,