Support regex

This commit is contained in:
Yong Wen Chua 2019-03-12 14:49:29 +08:00
parent 43b92a3b4a
commit 6a70a8b70b
No known key found for this signature in database
GPG Key ID: A70BD30B21497EA9
2 changed files with 87 additions and 25 deletions

View File

@ -21,6 +21,7 @@ default = ["serialization"]
serialization = ["serde", "serde_derive", "unicase_serde"] serialization = ["serde", "serde_derive", "unicase_serde"]
[dependencies] [dependencies]
regex = "1.1"
rocket = "0.4.0" rocket = "0.4.0"
log = "0.3" log = "0.3"
unicase = "2.0" unicase = "2.0"

View File

@ -283,6 +283,7 @@ use std::ops::Deref;
use std::str::FromStr; use std::str::FromStr;
use ::log::{error, info, log}; use ::log::{error, info, log};
use regex::RegexSet;
use rocket::http::{self, Status}; use rocket::http::{self, Status};
use rocket::request::{FromRequest, Request}; use rocket::request::{FromRequest, Request};
use rocket::response; use rocket::response;
@ -320,6 +321,8 @@ pub enum Error {
OriginNotAllowed(String), OriginNotAllowed(String),
/// Requested method is not allowed /// Requested method is not allowed
MethodNotAllowed(String), MethodNotAllowed(String),
/// A regular expression compilation error
RegexError(regex::Error),
/// One or more headers requested are not allowed /// One or more headers requested are not allowed
HeadersNotAllowed, HeadersNotAllowed,
/// Credentials are allowed, but the Origin is set to "*". This is not allowed by W3C /// Credentials are allowed, but the Origin is set to "*". This is not allowed by W3C
@ -380,6 +383,7 @@ impl fmt::Display for Error {
"The `on_response` handler of Fairing could not find the injected header from the \ "The `on_response` handler of Fairing could not find the injected header from the \
Request. Either some other fairing has removed it, or this is a bug."), Request. Either some other fairing has removed it, or this is a bug."),
Error::OpaqueAllowedOrigin(ref origin) => write!(f, "The configured Origin '{}' not have a parsable Origin. Use a regex instead.", origin), Error::OpaqueAllowedOrigin(ref origin) => write!(f, "The configured Origin '{}' not have a parsable Origin. Use a regex instead.", origin),
Error::RegexError(ref e) => write!(f, "{}", e),
} }
} }
} }
@ -406,6 +410,12 @@ impl From<url::ParseError> for Error {
} }
} }
impl From<regex::Error> for Error {
fn from(error: regex::Error) -> Self {
Error::RegexError(error)
}
}
/// An enum signifying that some of type T is allowed, or `All` (everything is allowed). /// An enum signifying that some of type T is allowed, or `All` (everything is allowed).
/// ///
/// `Default` is implemented for this enum and is `All`. /// `Default` is implemented for this enum and is `All`.
@ -537,20 +547,23 @@ mod method_serde {
pub type AllowedOrigins = AllOrSome<Origins>; pub type AllowedOrigins = AllOrSome<Origins>;
impl AllowedOrigins { impl AllowedOrigins {
/// Allows some _exact_ origins /// Allows some origins
/// ///
/// Validation is not performed at this stage, but at a later stage. /// Validation is not performed at this stage, but at a later stage.
#[deprecated(since = "0.5.0", note = "use `some_exact` instead")] pub fn some<S1: AsRef<str>, S2: AsRef<str>>(exact: &[S1], regex: &[S2]) -> Self {
pub fn some(urls: &[&str]) -> Self { AllOrSome::Some(Origins {
Self::some_exact(urls) exact: Some(exact.iter().map(|s| s.as_ref().to_string()).collect()),
regex: Some(regex.iter().map(|s| s.as_ref().to_string()).collect()),
..Default::default()
})
} }
/// Allows some _exact_ origins /// Allows some _exact_ origins
/// ///
/// Validation is not performed at this stage, but at a later stage. /// Validation is not performed at this stage, but at a later stage.
pub fn some_exact<S: AsRef<str>>(urls: &[S]) -> Self { pub fn some_exact<S: AsRef<str>>(exact: &[S]) -> Self {
AllOrSome::Some(Origins { AllOrSome::Some(Origins {
exact: Some(urls.iter().map(|s| s.as_ref().to_string()).collect()), exact: Some(exact.iter().map(|s| s.as_ref().to_string()).collect()),
..Default::default() ..Default::default()
}) })
} }
@ -615,10 +628,11 @@ pub struct Origins {
} }
/// Parsed set of configured allowed origins /// Parsed set of configured allowed origins
#[derive(Clone, Debug, Eq, PartialEq)] #[derive(Clone, Debug)]
pub(crate) struct ParsedAllowedOrigins { pub(crate) struct ParsedAllowedOrigins {
pub allow_null: bool, pub allow_null: bool,
pub exact: HashSet<url::Origin>, pub exact: HashSet<url::Origin>,
pub regex: Option<RegexSet>,
} }
impl ParsedAllowedOrigins { impl ParsedAllowedOrigins {
@ -638,11 +652,42 @@ impl ParsedAllowedOrigins {
} }
})?; })?;
let regex = match &origins.regex {
None => None,
Some(ref regex) => Some(RegexSet::new(regex)?),
};
Ok(Self { Ok(Self {
allow_null: origins.allow_null, allow_null: origins.allow_null,
exact, exact,
regex,
}) })
} }
fn verify(&self, origin: &Origin) -> bool {
info_!("Verifying origin: {}", origin);
match origin {
Origin::Null => {
info_!("Origin is null. Allowing? {}", self.allow_null);
self.allow_null
}
Origin::Parsed(ref parsed) => {
// Verify by exact, then regex
if self.exact.get(parsed).is_some() {
info_!("Origin has an exact match");
return true;
}
if let Some(regex_set) = &self.regex {
let regex_match = regex_set.is_match(&parsed.ascii_serialization());
info_!("Origin has a regex match? {}", regex_match);
return regex_match;
}
info!("Origin does not match anything");
false
}
}
}
} }
/// A list of allowed methods /// A list of allowed methods
@ -930,7 +975,7 @@ impl CorsOptions {
/// documentation at the [crate root](index.html) for usage information. /// documentation at the [crate root](index.html) for usage information.
/// ///
/// This struct can be created by using [`CorsOptions::to_cors`] or [`Cors::from_options`]. /// This struct can be created by using [`CorsOptions::to_cors`] or [`Cors::from_options`].
#[derive(Clone, Debug, Eq, PartialEq)] #[derive(Clone, Debug)]
pub struct Cors { pub struct Cors {
pub(crate) allowed_origins: AllOrSome<ParsedAllowedOrigins>, pub(crate) allowed_origins: AllOrSome<ParsedAllowedOrigins>,
pub(crate) allowed_methods: AllowedMethods, pub(crate) allowed_methods: AllowedMethods,
@ -1429,24 +1474,13 @@ fn validate_origin(
match *allowed_origins { match *allowed_origins {
// Always matching is acceptable since the list of origins can be unbounded. // Always matching is acceptable since the list of origins can be unbounded.
AllOrSome::All => Ok(()), AllOrSome::All => Ok(()),
// AllOrSome::Some(ref allowed_origins) => allowed_origins AllOrSome::Some(ref allowed_origins) => {
// .get(origin) if allowed_origins.verify(origin) {
// .and_then(|_| Some(())) Ok(())
// .ok_or_else(|| Error::OriginNotAllowed(origin.clone())), } else {
AllOrSome::Some(ref allowed_origins) => match origin { Err(Error::OriginNotAllowed(origin.to_string()))
Origin::Null => {
if allowed_origins.allow_null {
Ok(())
} else {
Err(Error::OriginNotAllowed(origin.to_string()))
}
} }
Origin::Parsed(ref parsed) => allowed_origins }
.exact
.get(parsed)
.and_then(|_| Some(()))
.ok_or_else(|| Error::OriginNotAllowed(origin.to_string())),
},
} }
} }
@ -1944,6 +1978,33 @@ mod tests {
} }
} }
#[test]
fn validate_origin_validates_regex() {
let url = "https://www.example-something.com";
let origin = not_err!(to_parsed_origin(&url));
let allowed_origins = not_err!(parse_origins(&AllowedOrigins::some_regex(&[
"^https://www.example-[A-z0-9]+.com$"
])));
not_err!(validate_origin(&origin, &allowed_origins));
}
#[test]
fn validate_origin_validates_mixed_settings() {
let allowed_origins = not_err!(parse_origins(&AllowedOrigins::some(
&["https://www.acme.com"],
&["^https://www.example-[A-z0-9]+.com$"]
)));
let url = "https://www.example-something123.com";
let origin = not_err!(to_parsed_origin(&url));
not_err!(validate_origin(&origin, &allowed_origins));
let url = "https://www.acme.com";
let origin = not_err!(to_parsed_origin(&url));
not_err!(validate_origin(&origin, &allowed_origins));
}
#[test] #[test]
#[should_panic(expected = "OriginNotAllowed")] #[should_panic(expected = "OriginNotAllowed")]
fn validate_origin_rejects_invalid_origin() { fn validate_origin_rejects_invalid_origin() {