diff --git a/src/lib.rs b/src/lib.rs index 3e48528..8ec8530 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -237,6 +237,21 @@ impl Default for AllOrSome { } } +impl AllOrSome { + /// Returns whether this is an `All` variant + pub fn is_all(&self) -> bool { + match *self { + AllOrSome::All => true, + AllOrSome::Some(_) => false, + } + } + + /// Returns whether this is a `Some` variant + pub fn is_some(&self) -> bool { + !self.is_all() + } +} + impl AllOrSome> { /// New `AllOrSome` from a list of URL strings. /// Returns a tuple where the first element is the struct `AllOrSome`, @@ -400,6 +415,17 @@ impl Cors { ].into_iter() .collect() } + + /// Validates if any of the settings are disallowed or incorrect + /// + /// This is run during initial Fairing attachment + pub fn validate(&self) -> Result<(), Error> { + if self.allowed_origins.is_all() && self.send_wildcard && self.allow_credentials { + Err(Error::CredentialsWithWildcardOrigin)?; + } + + Ok(()) + } } /// A CORS [Responder](https://rocket.rs/guide/responses/#responder) @@ -908,6 +934,41 @@ mod tests { use rocket::http::Method; use super::*; + fn make_cors_options() -> Cors { + let (allowed_origins, failed_origins) = + AllOrSome::new_from_str_list(&["https://www.acme.com"]); + assert!(failed_origins.is_empty()); + + Cors { + allowed_origins: allowed_origins, + allowed_methods: [Method::Get].iter().cloned().collect(), + allowed_headers: AllOrSome::Some( + ["Authorization"] + .into_iter() + .map(|s| s.to_string().into()) + .collect(), + ), + allow_credentials: true, + ..Default::default() + } + } + + #[test] + fn cors_is_validated() { + assert!(make_cors_options().validate().is_ok()) + } + + #[test] + #[should_panic(expected = "CredentialsWithWildcardOrigin")] + fn cors_validates_illegal_allow_credentials() { + let mut cors = make_cors_options(); + cors.allow_credentials = true; + cors.allowed_origins = AllOrSome::All; + cors.send_wildcard = true; + + cors.validate().unwrap(); + } + // The following tests check `Response`'s validation #[test] diff --git a/tests/routes.rs b/tests/routes.rs index dc421f9..57c4bea 100644 --- a/tests/routes.rs +++ b/tests/routes.rs @@ -1,4 +1,4 @@ -//! This crate tests using rocket_cors using the "classic" per-route handling +//! This crate tests using rocket_cors using the "classic" ad-hoc per-route handling #![feature(plugin, custom_derive)] #![plugin(rocket_codegen)]