diff --git a/examples/fairing.rs b/examples/fairing.rs index 81ebf51..e666270 100644 --- a/examples/fairing.rs +++ b/examples/fairing.rs @@ -4,7 +4,7 @@ extern crate rocket; extern crate rocket_cors; use rocket::http::Method; -use rocket_cors::AllOrSome; +use rocket_cors::{AllowedOrigins, AllowedHeaders}; #[get("/")] fn cors<'a>() -> &'a str { @@ -12,19 +12,14 @@ fn cors<'a>() -> &'a str { } fn main() { - let (allowed_origins, failed_origins) = AllOrSome::new_from_str_list(&["https://www.acme.com"]); + let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.acme.com"]); assert!(failed_origins.is_empty()); // You can also deserialize this let options = rocket_cors::Cors { allowed_origins: allowed_origins, allowed_methods: vec![Method::Get].into_iter().map(From::from).collect(), - allowed_headers: AllOrSome::Some( - ["Authorization", "Accept"] - .into_iter() - .map(|s| s.to_string().into()) - .collect(), - ), + allowed_headers: AllowedHeaders::some(&["Authorization", "Accept"]), allow_credentials: true, ..Default::default() }; diff --git a/examples/guard.rs b/examples/guard.rs index 4a27455..3b7c7b3 100644 --- a/examples/guard.rs +++ b/examples/guard.rs @@ -7,7 +7,7 @@ use std::io::Cursor; use rocket::Response; use rocket::http::Method; -use rocket_cors::{Guard, AllOrSome, Responder}; +use rocket_cors::{Guard, AllowedOrigins, AllowedHeaders, Responder}; /// Using a `Responder` -- the usual way you would use this #[get("/")] @@ -39,19 +39,14 @@ fn response_options(cors: Guard) -> Response { } fn main() { - let (allowed_origins, failed_origins) = AllOrSome::new_from_str_list(&["https://www.acme.com"]); + let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.acme.com"]); assert!(failed_origins.is_empty()); // You can also deserialize this let options = rocket_cors::Cors { allowed_origins: allowed_origins, allowed_methods: vec![Method::Get].into_iter().map(From::from).collect(), - allowed_headers: AllOrSome::Some( - ["Authorization", "Accept"] - .into_iter() - .map(|s| s.to_string().into()) - .collect(), - ), + allowed_headers: AllowedHeaders::some(&["Authorization", "Accept"]), allow_credentials: true, ..Default::default() }; diff --git a/examples/json.rs b/examples/json.rs new file mode 100644 index 0000000..b4086cb --- /dev/null +++ b/examples/json.rs @@ -0,0 +1,38 @@ +//! This example is to demonstrate the JSON serialization and deserialization of the Cors settings +extern crate rocket; +extern crate rocket_cors as cors; +extern crate serde_json; + +use rocket::http::Method; +use cors::{Cors, AllowedOrigins, AllowedHeaders}; + +fn main() { + // The default demonstrates the "All" serialization of several of the settings + let default: Cors = Default::default(); + + let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.acme.com"]); + assert!(failed_origins.is_empty()); + + let options = cors::Cors { + allowed_origins: allowed_origins, + allowed_methods: vec![Method::Get, Method::Post, Method::Delete] + .into_iter() + .map(From::from) + .collect(), + allowed_headers: AllowedHeaders::some(&["Authorization", "Accept"]), + allow_credentials: true, + expose_headers: ["Content-Type", "X-Custom"] + .iter() + .map(ToString::to_string) + .collect(), + max_age: Some(42), + send_wildcard: false, + fairing_route_base: "/mycors".to_string(), + }; + + println!("Default settings"); + println!("{}", serde_json::to_string_pretty(&default).unwrap()); + + println!("Defined settings"); + println!("{}", serde_json::to_string_pretty(&options).unwrap()); +} diff --git a/src/fairing.rs b/src/fairing.rs index 4510581..7c961b1 100644 --- a/src/fairing.rs +++ b/src/fairing.rs @@ -161,24 +161,18 @@ mod tests { use rocket::http::{Method, Status}; use rocket::local::Client; - use {Cors, AllOrSome}; + use {Cors, AllOrSome, AllowedOrigins, AllowedHeaders}; const CORS_ROOT: &'static str = "/my_cors"; fn make_cors_options() -> Cors { - let (allowed_origins, failed_origins) = - AllOrSome::new_from_str_list(&["https://www.acme.com"]); + let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.acme.com"]); assert!(failed_origins.is_empty()); Cors { allowed_origins: allowed_origins, allowed_methods: vec![Method::Get].into_iter().map(From::from).collect(), - allowed_headers: AllOrSome::Some( - ["Authorization"] - .into_iter() - .map(|s| s.to_string().into()) - .collect(), - ), + allowed_headers: AllowedHeaders::some(&["Authorization", "Accept"]), allow_credentials: true, fairing_route_base: CORS_ROOT.to_string(), diff --git a/src/lib.rs b/src/lib.rs index 2df1698..9d86b88 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -96,7 +96,7 @@ //! extern crate rocket_cors; //! //! use rocket::http::Method; -//! use rocket_cors::AllOrSome; +//! use rocket_cors::{AllowedOrigins, AllowedHeaders}; //! //! #[get("/")] //! fn cors<'a>() -> &'a str { @@ -104,19 +104,14 @@ //! } //! //! fn main() { -//! let (allowed_origins, failed_origins) = AllOrSome::new_from_str_list(&["https://www.acme.com"]); +//! let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.acme.com"]); //! assert!(failed_origins.is_empty()); //! //! // You can also deserialize this //! let options = rocket_cors::Cors { //! allowed_origins: allowed_origins, //! allowed_methods: vec![Method::Get].into_iter().map(From::from).collect(), -//! allowed_headers: AllOrSome::Some( -//! ["Authorization", "Accept"] -//! .into_iter() -//! .map(|s| s.to_string().into()) -//! .collect(), -//! ), +//! allowed_headers: AllowedHeaders::some(&["Authorization", "Accept"]), //! allow_credentials: true, //! ..Default::default() //! }; @@ -164,7 +159,7 @@ //! //! use rocket::Response; //! use rocket::http::Method; -//! use rocket_cors::{Guard, AllOrSome, Responder}; +//! use rocket_cors::{Guard, AllowedOrigins, AllowedHeaders, Responder}; //! //! /// Using a `Responder` -- the usual way you would use this //! #[get("/")] @@ -196,19 +191,14 @@ //! } //! //! fn main() { -//! let (allowed_origins, failed_origins) = AllOrSome::new_from_str_list(&["https://www.acme.com"]); +//! let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.acme.com"]); //! assert!(failed_origins.is_empty()); //! //! // You can also deserialize this //! let options = rocket_cors::Cors { //! allowed_origins: allowed_origins, //! allowed_methods: vec![Method::Get].into_iter().map(From::from).collect(), -//! allowed_headers: AllOrSome::Some( -//! ["Authorization", "Accept"] -//! .into_iter() -//! .map(|s| s.to_string().into()) -//! .collect(), -//! ), +//! allowed_headers: AllowedHeaders::some(&["Authorization", "Accept"]), //! allow_credentials: true, //! ..Default::default() //! }; @@ -442,8 +432,10 @@ impl<'r> response::Responder<'r> for Error { /// An enum signifying that some of type T is allowed, or `All` (everything is allowed). /// /// `Default` is implemented for this enum and is `All`. +/// +/// This enum is serialized and deserialized +/// ["Externally tagged"](https://serde.rs/enum-representations.html) #[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)] -#[serde(untagged)] pub enum AllOrSome { /// Everything is allowed. Usually equivalent to the "*" value. All, @@ -473,6 +465,7 @@ impl AllOrSome { } impl AllOrSome> { + #[deprecated(since = "0.1.3", note = "please use `AllowedOrigins::Some` instead")] /// New `AllOrSome` from a list of URL strings. /// Returns a tuple where the first element is the struct `AllOrSome`, /// and the second element @@ -565,6 +558,85 @@ impl<'de> Deserialize<'de> for Method { } } +/// A list of allowed origins. Either Some origins are allowed, or all origins are allowed. +/// +/// # Examples +/// ```rust +/// use rocket_cors::AllowedOrigins; +/// +/// let all_origins = AllowedOrigins::all(); +/// let (some_origins, failed_origins) = AllowedOrigins::some(&["https://www.acme.com"]); +/// assert!(failed_origins.is_empty()); +/// ``` +pub type AllowedOrigins = AllOrSome>; + +impl AllowedOrigins { + /// Allows some origins + /// + /// Returns a tuple where the first element is the struct `AllowedOrigins`, + /// and the second element + /// is a map of strings which failed to parse into URLs and their associated parse errors. + pub fn some(urls: &[&str]) -> (Self, HashMap) { + let (ok_set, error_map): (Vec<_>, Vec<_>) = urls.iter() + .map(|s| (s.to_string(), Url::from_str(s))) + .partition(|&(_, ref r)| r.is_ok()); + + let error_map = error_map + .into_iter() + .map(|(s, r)| (s.to_string(), r.unwrap_err())) + .collect(); + + let ok_set = ok_set.into_iter().map(|(_, r)| r.unwrap()).collect(); + + (AllOrSome::Some(ok_set), error_map) + } + + /// Allows all origins + pub fn all() -> Self { + AllOrSome::All + } +} + +/// A list of allowed methods +/// +/// The [list](https://api.rocket.rs/rocket/http/enum.Method.html) +/// of methods is whatever is supported by Rocket. +/// +/// # Example +/// ```rust +/// use std::str::FromStr; +/// use rocket_cors::AllowedMethods; +/// +/// let allowed_methods: AllowedMethods = ["Get", "Post", "Delete"] +/// .iter() +/// .map(|s| FromStr::from_str(s).unwrap()) +/// .collect(); +/// ``` +pub type AllowedMethods = HashSet; + +/// A list of allowed headers +/// +/// # Examples +/// ```rust +/// use rocket_cors::AllowedHeaders; +/// +/// let all_headers = AllowedHeaders::all(); +/// let some_headers = AllowedHeaders::some(&["Authorization", "Accept"]); +/// ``` +pub type AllowedHeaders = AllOrSome>; + +impl AllowedHeaders { + /// Allow some headers + pub fn some(headers: &[&str]) -> Self { + AllOrSome::Some(headers.iter().map(|s| s.to_string().into()).collect()) + } + + /// Allows all headers + pub fn all() -> Self { + AllOrSome::All + } +} + /// Response generator and [Fairing](https://rocket.rs/guide/fairings/) for CORS /// /// This struct can be as Fairing or in an ad-hoc manner to generate CORS response. See the @@ -575,6 +647,70 @@ impl<'de> Deserialize<'de> for Method { /// /// [`Default`](https://doc.rust-lang.org/std/default/trait.Default.html) is implemented for this /// struct. The default for each field is described in the docuementation for the field. +/// +/// # Examples +/// +/// You can run an example from the repository to demonstrate the JSON serialization with +/// `cargo run --example json`. +/// +/// ## Pure default +/// ```rust +/// let default = rocket_cors::Cors::default(); +/// ``` +/// +/// ## JSON Examples +/// ### Default +/// +/// ```json +/// { +/// "allowed_origins": "All", +/// "allowed_methods": [ +/// "POST", +/// "PATCH", +/// "PUT", +/// "DELETE", +/// "HEAD", +/// "OPTIONS", +/// "GET" +/// ], +/// "allowed_headers": "All", +/// "allow_credentials": false, +/// "expose_headers": [], +/// "max_age": null, +/// "send_wildcard": false, +/// "fairing_route_base": "/cors" +/// } +/// ``` +/// ### Defined +/// ```json +/// { +/// "allowed_origins": { +/// "Some": [ +/// "https://www.acme.com/" +/// ] +/// }, +/// "allowed_methods": [ +/// "POST", +/// "DELETE", +/// "GET" +/// ], +/// "allowed_headers": { +/// "Some": [ +/// "Accept", +/// "Authorization" +/// ] +/// }, +/// "allow_credentials": true, +/// "expose_headers": [ +/// "Content-Type", +/// "X-Custom" +/// ], +/// "max_age": 42, +/// "send_wildcard": false, +/// "fairing_route_base": "/mycors" +/// } +/// +/// ``` #[derive(Serialize, Deserialize, Eq, PartialEq, Clone, Debug)] pub struct Cors { /// Origins that are allowed to make requests. @@ -594,7 +730,7 @@ pub struct Cors { /// /// ``` #[serde(default)] - pub allowed_origins: AllOrSome>, + pub allowed_origins: AllowedOrigins, /// The list of methods which the allowed origins are allowed to access for /// non-simple requests. /// @@ -603,7 +739,7 @@ pub struct Cors { /// /// Defaults to `[GET, HEAD, POST, OPTIONS, PUT, PATCH, DELETE]` #[serde(default = "Cors::default_allowed_methods")] - pub allowed_methods: HashSet, + pub allowed_methods: AllowedMethods, /// The list of header field names which can be used when this resource is accessed by allowed /// origins. /// @@ -1348,8 +1484,7 @@ mod tests { use http::Method; fn make_cors_options() -> Cors { - let (allowed_origins, failed_origins) = - AllOrSome::new_from_str_list(&["https://www.acme.com"]); + let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.acme.com"]); assert!(failed_origins.is_empty()); Cors { @@ -1358,12 +1493,7 @@ mod tests { .into_iter() .map(From::from) .collect(), - allowed_headers: AllOrSome::Some( - ["Authorization", "Accept"] - .into_iter() - .map(|s| s.to_string().into()) - .collect(), - ), + allowed_headers: AllowedHeaders::some(&[&"Authorization", "Accept"]), allow_credentials: true, expose_headers: ["Content-Type", "X-Custom"] .into_iter() @@ -1424,8 +1554,7 @@ mod tests { fn response_allows_origin() { let url = "https://www.example.com"; let origin = Origin::from_str(url).unwrap(); - let (allowed_origins, failed_origins) = - AllOrSome::new_from_str_list(&["https://www.example.com"]); + let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.example.com"]); assert!(failed_origins.is_empty()); not_err!(validate_origin(&origin, &allowed_origins)); @@ -1436,8 +1565,7 @@ mod tests { fn response_rejects_invalid_origin() { let url = "https://www.acme.com"; let origin = Origin::from_str(url).unwrap(); - let (allowed_origins, failed_origins) = - AllOrSome::new_from_str_list(&["https://www.example.com"]); + let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.example.com"]); assert!(failed_origins.is_empty()); validate_origin(&origin, &allowed_origins).unwrap(); diff --git a/tests/fairing.rs b/tests/fairing.rs index b2607c4..acdd3fe 100644 --- a/tests/fairing.rs +++ b/tests/fairing.rs @@ -24,18 +24,13 @@ fn panicking_route() { } fn make_cors_options() -> Cors { - let (allowed_origins, failed_origins) = AllOrSome::new_from_str_list(&["https://www.acme.com"]); + let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.acme.com"]); assert!(failed_origins.is_empty()); Cors { allowed_origins: allowed_origins, allowed_methods: vec![Method::Get].into_iter().map(From::from).collect(), - allowed_headers: AllOrSome::Some( - ["Authorization", "Accept"] - .into_iter() - .map(|s| s.to_string().into()) - .collect(), - ), + allowed_headers: AllowedHeaders::some(&["Authorization", "Accept"]), allow_credentials: true, ..Default::default() } diff --git a/tests/guard.rs b/tests/guard.rs index 45f4da0..d7377ab 100644 --- a/tests/guard.rs +++ b/tests/guard.rs @@ -60,19 +60,13 @@ fn state<'r>(cors: cors::Guard<'r>, _state: State<'r, SomeState>) -> cors::Respo } fn make_cors_options() -> cors::Cors { - let (allowed_origins, failed_origins) = - cors::AllOrSome::new_from_str_list(&["https://www.acme.com"]); + let (allowed_origins, failed_origins) = cors::AllowedOrigins::some(&["https://www.acme.com"]); assert!(failed_origins.is_empty()); cors::Cors { allowed_origins: allowed_origins, allowed_methods: vec![Method::Get].into_iter().map(From::from).collect(), - allowed_headers: cors::AllOrSome::Some( - ["Authorization", "Accept"] - .into_iter() - .map(|s| s.to_string().into()) - .collect(), - ), + allowed_headers: cors::AllowedHeaders::some(&["Authorization", "Accept"]), allow_credentials: true, ..Default::default() }