From 6a70a8b70b6e63dec0a700f309724425fcc2549c Mon Sep 17 00:00:00 2001 From: Yong Wen Chua Date: Tue, 12 Mar 2019 14:49:29 +0800 Subject: [PATCH] Support regex --- Cargo.toml | 1 + src/lib.rs | 111 +++++++++++++++++++++++++++++++++++++++++------------ 2 files changed, 87 insertions(+), 25 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 69c9221..f45ac48 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,6 +21,7 @@ default = ["serialization"] serialization = ["serde", "serde_derive", "unicase_serde"] [dependencies] +regex = "1.1" rocket = "0.4.0" log = "0.3" unicase = "2.0" diff --git a/src/lib.rs b/src/lib.rs index 7dfb355..57ba833 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -283,6 +283,7 @@ use std::ops::Deref; use std::str::FromStr; use ::log::{error, info, log}; +use regex::RegexSet; use rocket::http::{self, Status}; use rocket::request::{FromRequest, Request}; use rocket::response; @@ -320,6 +321,8 @@ pub enum Error { OriginNotAllowed(String), /// Requested method is not allowed MethodNotAllowed(String), + /// A regular expression compilation error + RegexError(regex::Error), /// One or more headers requested are not allowed HeadersNotAllowed, /// Credentials are allowed, but the Origin is set to "*". This is not allowed by W3C @@ -380,6 +383,7 @@ impl fmt::Display for Error { "The `on_response` handler of Fairing could not find the injected header from the \ Request. Either some other fairing has removed it, or this is a bug."), Error::OpaqueAllowedOrigin(ref origin) => write!(f, "The configured Origin '{}' not have a parsable Origin. Use a regex instead.", origin), + Error::RegexError(ref e) => write!(f, "{}", e), } } } @@ -406,6 +410,12 @@ impl From for Error { } } +impl From for Error { + fn from(error: regex::Error) -> Self { + Error::RegexError(error) + } +} + /// An enum signifying that some of type T is allowed, or `All` (everything is allowed). /// /// `Default` is implemented for this enum and is `All`. @@ -537,20 +547,23 @@ mod method_serde { pub type AllowedOrigins = AllOrSome; impl AllowedOrigins { - /// Allows some _exact_ origins + /// Allows some origins /// /// Validation is not performed at this stage, but at a later stage. - #[deprecated(since = "0.5.0", note = "use `some_exact` instead")] - pub fn some(urls: &[&str]) -> Self { - Self::some_exact(urls) + pub fn some, S2: AsRef>(exact: &[S1], regex: &[S2]) -> Self { + AllOrSome::Some(Origins { + exact: Some(exact.iter().map(|s| s.as_ref().to_string()).collect()), + regex: Some(regex.iter().map(|s| s.as_ref().to_string()).collect()), + ..Default::default() + }) } /// Allows some _exact_ origins /// /// Validation is not performed at this stage, but at a later stage. - pub fn some_exact>(urls: &[S]) -> Self { + pub fn some_exact>(exact: &[S]) -> Self { AllOrSome::Some(Origins { - exact: Some(urls.iter().map(|s| s.as_ref().to_string()).collect()), + exact: Some(exact.iter().map(|s| s.as_ref().to_string()).collect()), ..Default::default() }) } @@ -615,10 +628,11 @@ pub struct Origins { } /// Parsed set of configured allowed origins -#[derive(Clone, Debug, Eq, PartialEq)] +#[derive(Clone, Debug)] pub(crate) struct ParsedAllowedOrigins { pub allow_null: bool, pub exact: HashSet, + pub regex: Option, } impl ParsedAllowedOrigins { @@ -638,11 +652,42 @@ impl ParsedAllowedOrigins { } })?; + let regex = match &origins.regex { + None => None, + Some(ref regex) => Some(RegexSet::new(regex)?), + }; + Ok(Self { allow_null: origins.allow_null, exact, + regex, }) } + + fn verify(&self, origin: &Origin) -> bool { + info_!("Verifying origin: {}", origin); + match origin { + Origin::Null => { + info_!("Origin is null. Allowing? {}", self.allow_null); + self.allow_null + } + Origin::Parsed(ref parsed) => { + // Verify by exact, then regex + if self.exact.get(parsed).is_some() { + info_!("Origin has an exact match"); + return true; + } + if let Some(regex_set) = &self.regex { + let regex_match = regex_set.is_match(&parsed.ascii_serialization()); + info_!("Origin has a regex match? {}", regex_match); + return regex_match; + } + + info!("Origin does not match anything"); + false + } + } + } } /// A list of allowed methods @@ -930,7 +975,7 @@ impl CorsOptions { /// documentation at the [crate root](index.html) for usage information. /// /// This struct can be created by using [`CorsOptions::to_cors`] or [`Cors::from_options`]. -#[derive(Clone, Debug, Eq, PartialEq)] +#[derive(Clone, Debug)] pub struct Cors { pub(crate) allowed_origins: AllOrSome, pub(crate) allowed_methods: AllowedMethods, @@ -1429,24 +1474,13 @@ fn validate_origin( match *allowed_origins { // Always matching is acceptable since the list of origins can be unbounded. AllOrSome::All => Ok(()), - // AllOrSome::Some(ref allowed_origins) => allowed_origins - // .get(origin) - // .and_then(|_| Some(())) - // .ok_or_else(|| Error::OriginNotAllowed(origin.clone())), - AllOrSome::Some(ref allowed_origins) => match origin { - Origin::Null => { - if allowed_origins.allow_null { - Ok(()) - } else { - Err(Error::OriginNotAllowed(origin.to_string())) - } + AllOrSome::Some(ref allowed_origins) => { + if allowed_origins.verify(origin) { + Ok(()) + } else { + Err(Error::OriginNotAllowed(origin.to_string())) } - Origin::Parsed(ref parsed) => allowed_origins - .exact - .get(parsed) - .and_then(|_| Some(())) - .ok_or_else(|| Error::OriginNotAllowed(origin.to_string())), - }, + } } } @@ -1944,6 +1978,33 @@ mod tests { } } + #[test] + fn validate_origin_validates_regex() { + let url = "https://www.example-something.com"; + let origin = not_err!(to_parsed_origin(&url)); + let allowed_origins = not_err!(parse_origins(&AllowedOrigins::some_regex(&[ + "^https://www.example-[A-z0-9]+.com$" + ]))); + + not_err!(validate_origin(&origin, &allowed_origins)); + } + + #[test] + fn validate_origin_validates_mixed_settings() { + let allowed_origins = not_err!(parse_origins(&AllowedOrigins::some( + &["https://www.acme.com"], + &["^https://www.example-[A-z0-9]+.com$"] + ))); + + let url = "https://www.example-something123.com"; + let origin = not_err!(to_parsed_origin(&url)); + not_err!(validate_origin(&origin, &allowed_origins)); + + let url = "https://www.acme.com"; + let origin = not_err!(to_parsed_origin(&url)); + not_err!(validate_origin(&origin, &allowed_origins)); + } + #[test] #[should_panic(expected = "OriginNotAllowed")] fn validate_origin_rejects_invalid_origin() {