Support regex
This commit is contained in:
parent
43b92a3b4a
commit
6a70a8b70b
|
@ -21,6 +21,7 @@ default = ["serialization"]
|
|||
serialization = ["serde", "serde_derive", "unicase_serde"]
|
||||
|
||||
[dependencies]
|
||||
regex = "1.1"
|
||||
rocket = "0.4.0"
|
||||
log = "0.3"
|
||||
unicase = "2.0"
|
||||
|
|
103
src/lib.rs
103
src/lib.rs
|
@ -283,6 +283,7 @@ use std::ops::Deref;
|
|||
use std::str::FromStr;
|
||||
|
||||
use ::log::{error, info, log};
|
||||
use regex::RegexSet;
|
||||
use rocket::http::{self, Status};
|
||||
use rocket::request::{FromRequest, Request};
|
||||
use rocket::response;
|
||||
|
@ -320,6 +321,8 @@ pub enum Error {
|
|||
OriginNotAllowed(String),
|
||||
/// Requested method is not allowed
|
||||
MethodNotAllowed(String),
|
||||
/// A regular expression compilation error
|
||||
RegexError(regex::Error),
|
||||
/// One or more headers requested are not allowed
|
||||
HeadersNotAllowed,
|
||||
/// Credentials are allowed, but the Origin is set to "*". This is not allowed by W3C
|
||||
|
@ -380,6 +383,7 @@ impl fmt::Display for Error {
|
|||
"The `on_response` handler of Fairing could not find the injected header from the \
|
||||
Request. Either some other fairing has removed it, or this is a bug."),
|
||||
Error::OpaqueAllowedOrigin(ref origin) => write!(f, "The configured Origin '{}' not have a parsable Origin. Use a regex instead.", origin),
|
||||
Error::RegexError(ref e) => write!(f, "{}", e),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -406,6 +410,12 @@ impl From<url::ParseError> for Error {
|
|||
}
|
||||
}
|
||||
|
||||
impl From<regex::Error> for Error {
|
||||
fn from(error: regex::Error) -> Self {
|
||||
Error::RegexError(error)
|
||||
}
|
||||
}
|
||||
|
||||
/// An enum signifying that some of type T is allowed, or `All` (everything is allowed).
|
||||
///
|
||||
/// `Default` is implemented for this enum and is `All`.
|
||||
|
@ -537,20 +547,23 @@ mod method_serde {
|
|||
pub type AllowedOrigins = AllOrSome<Origins>;
|
||||
|
||||
impl AllowedOrigins {
|
||||
/// Allows some _exact_ origins
|
||||
/// Allows some origins
|
||||
///
|
||||
/// Validation is not performed at this stage, but at a later stage.
|
||||
#[deprecated(since = "0.5.0", note = "use `some_exact` instead")]
|
||||
pub fn some(urls: &[&str]) -> Self {
|
||||
Self::some_exact(urls)
|
||||
pub fn some<S1: AsRef<str>, S2: AsRef<str>>(exact: &[S1], regex: &[S2]) -> Self {
|
||||
AllOrSome::Some(Origins {
|
||||
exact: Some(exact.iter().map(|s| s.as_ref().to_string()).collect()),
|
||||
regex: Some(regex.iter().map(|s| s.as_ref().to_string()).collect()),
|
||||
..Default::default()
|
||||
})
|
||||
}
|
||||
|
||||
/// Allows some _exact_ origins
|
||||
///
|
||||
/// Validation is not performed at this stage, but at a later stage.
|
||||
pub fn some_exact<S: AsRef<str>>(urls: &[S]) -> Self {
|
||||
pub fn some_exact<S: AsRef<str>>(exact: &[S]) -> Self {
|
||||
AllOrSome::Some(Origins {
|
||||
exact: Some(urls.iter().map(|s| s.as_ref().to_string()).collect()),
|
||||
exact: Some(exact.iter().map(|s| s.as_ref().to_string()).collect()),
|
||||
..Default::default()
|
||||
})
|
||||
}
|
||||
|
@ -615,10 +628,11 @@ pub struct Origins {
|
|||
}
|
||||
|
||||
/// Parsed set of configured allowed origins
|
||||
#[derive(Clone, Debug, Eq, PartialEq)]
|
||||
#[derive(Clone, Debug)]
|
||||
pub(crate) struct ParsedAllowedOrigins {
|
||||
pub allow_null: bool,
|
||||
pub exact: HashSet<url::Origin>,
|
||||
pub regex: Option<RegexSet>,
|
||||
}
|
||||
|
||||
impl ParsedAllowedOrigins {
|
||||
|
@ -638,11 +652,42 @@ impl ParsedAllowedOrigins {
|
|||
}
|
||||
})?;
|
||||
|
||||
let regex = match &origins.regex {
|
||||
None => None,
|
||||
Some(ref regex) => Some(RegexSet::new(regex)?),
|
||||
};
|
||||
|
||||
Ok(Self {
|
||||
allow_null: origins.allow_null,
|
||||
exact,
|
||||
regex,
|
||||
})
|
||||
}
|
||||
|
||||
fn verify(&self, origin: &Origin) -> bool {
|
||||
info_!("Verifying origin: {}", origin);
|
||||
match origin {
|
||||
Origin::Null => {
|
||||
info_!("Origin is null. Allowing? {}", self.allow_null);
|
||||
self.allow_null
|
||||
}
|
||||
Origin::Parsed(ref parsed) => {
|
||||
// Verify by exact, then regex
|
||||
if self.exact.get(parsed).is_some() {
|
||||
info_!("Origin has an exact match");
|
||||
return true;
|
||||
}
|
||||
if let Some(regex_set) = &self.regex {
|
||||
let regex_match = regex_set.is_match(&parsed.ascii_serialization());
|
||||
info_!("Origin has a regex match? {}", regex_match);
|
||||
return regex_match;
|
||||
}
|
||||
|
||||
info!("Origin does not match anything");
|
||||
false
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A list of allowed methods
|
||||
|
@ -930,7 +975,7 @@ impl CorsOptions {
|
|||
/// documentation at the [crate root](index.html) for usage information.
|
||||
///
|
||||
/// This struct can be created by using [`CorsOptions::to_cors`] or [`Cors::from_options`].
|
||||
#[derive(Clone, Debug, Eq, PartialEq)]
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct Cors {
|
||||
pub(crate) allowed_origins: AllOrSome<ParsedAllowedOrigins>,
|
||||
pub(crate) allowed_methods: AllowedMethods,
|
||||
|
@ -1429,24 +1474,13 @@ fn validate_origin(
|
|||
match *allowed_origins {
|
||||
// Always matching is acceptable since the list of origins can be unbounded.
|
||||
AllOrSome::All => Ok(()),
|
||||
// AllOrSome::Some(ref allowed_origins) => allowed_origins
|
||||
// .get(origin)
|
||||
// .and_then(|_| Some(()))
|
||||
// .ok_or_else(|| Error::OriginNotAllowed(origin.clone())),
|
||||
AllOrSome::Some(ref allowed_origins) => match origin {
|
||||
Origin::Null => {
|
||||
if allowed_origins.allow_null {
|
||||
AllOrSome::Some(ref allowed_origins) => {
|
||||
if allowed_origins.verify(origin) {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(Error::OriginNotAllowed(origin.to_string()))
|
||||
}
|
||||
}
|
||||
Origin::Parsed(ref parsed) => allowed_origins
|
||||
.exact
|
||||
.get(parsed)
|
||||
.and_then(|_| Some(()))
|
||||
.ok_or_else(|| Error::OriginNotAllowed(origin.to_string())),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1944,6 +1978,33 @@ mod tests {
|
|||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validate_origin_validates_regex() {
|
||||
let url = "https://www.example-something.com";
|
||||
let origin = not_err!(to_parsed_origin(&url));
|
||||
let allowed_origins = not_err!(parse_origins(&AllowedOrigins::some_regex(&[
|
||||
"^https://www.example-[A-z0-9]+.com$"
|
||||
])));
|
||||
|
||||
not_err!(validate_origin(&origin, &allowed_origins));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validate_origin_validates_mixed_settings() {
|
||||
let allowed_origins = not_err!(parse_origins(&AllowedOrigins::some(
|
||||
&["https://www.acme.com"],
|
||||
&["^https://www.example-[A-z0-9]+.com$"]
|
||||
)));
|
||||
|
||||
let url = "https://www.example-something123.com";
|
||||
let origin = not_err!(to_parsed_origin(&url));
|
||||
not_err!(validate_origin(&origin, &allowed_origins));
|
||||
|
||||
let url = "https://www.acme.com";
|
||||
let origin = not_err!(to_parsed_origin(&url));
|
||||
not_err!(validate_origin(&origin, &allowed_origins));
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic(expected = "OriginNotAllowed")]
|
||||
fn validate_origin_rejects_invalid_origin() {
|
||||
|
|
Loading…
Reference in New Issue