Add validate to Cors

This commit is contained in:
Yong Wen Chua 2017-07-16 14:24:00 +08:00
parent 6f1a24e12d
commit ce4eaf84ff
2 changed files with 62 additions and 1 deletions

View File

@ -237,6 +237,21 @@ impl<T> Default for AllOrSome<T> {
}
}
impl<T> AllOrSome<T> {
/// Returns whether this is an `All` variant
pub fn is_all(&self) -> bool {
match *self {
AllOrSome::All => true,
AllOrSome::Some(_) => false,
}
}
/// Returns whether this is a `Some` variant
pub fn is_some(&self) -> bool {
!self.is_all()
}
}
impl AllOrSome<HashSet<Url>> {
/// New `AllOrSome` from a list of URL strings.
/// Returns a tuple where the first element is the struct `AllOrSome`,
@ -400,6 +415,17 @@ impl Cors {
].into_iter()
.collect()
}
/// Validates if any of the settings are disallowed or incorrect
///
/// This is run during initial Fairing attachment
pub fn validate(&self) -> Result<(), Error> {
if self.allowed_origins.is_all() && self.send_wildcard && self.allow_credentials {
Err(Error::CredentialsWithWildcardOrigin)?;
}
Ok(())
}
}
/// A CORS [Responder](https://rocket.rs/guide/responses/#responder)
@ -908,6 +934,41 @@ mod tests {
use rocket::http::Method;
use super::*;
fn make_cors_options() -> Cors {
let (allowed_origins, failed_origins) =
AllOrSome::new_from_str_list(&["https://www.acme.com"]);
assert!(failed_origins.is_empty());
Cors {
allowed_origins: allowed_origins,
allowed_methods: [Method::Get].iter().cloned().collect(),
allowed_headers: AllOrSome::Some(
["Authorization"]
.into_iter()
.map(|s| s.to_string().into())
.collect(),
),
allow_credentials: true,
..Default::default()
}
}
#[test]
fn cors_is_validated() {
assert!(make_cors_options().validate().is_ok())
}
#[test]
#[should_panic(expected = "CredentialsWithWildcardOrigin")]
fn cors_validates_illegal_allow_credentials() {
let mut cors = make_cors_options();
cors.allow_credentials = true;
cors.allowed_origins = AllOrSome::All;
cors.send_wildcard = true;
cors.validate().unwrap();
}
// The following tests check `Response`'s validation
#[test]

View File

@ -1,4 +1,4 @@
//! This crate tests using rocket_cors using the "classic" per-route handling
//! This crate tests using rocket_cors using the "classic" ad-hoc per-route handling
#![feature(plugin, custom_derive)]
#![plugin(rocket_codegen)]