Add Convenience typedefs and functions (#19)

* Add JSON documentation and convenience AllowedOrigin

* Add `AllowedHeaders`

* Add AllowedHeaders

* Fix tests
This commit is contained in:
Yong Wen Chua 2017-07-19 12:25:56 +08:00 committed by GitHub
parent fcd83e8fb5
commit 0a94dfe22a
7 changed files with 210 additions and 71 deletions

View File

@ -4,7 +4,7 @@ extern crate rocket;
extern crate rocket_cors; extern crate rocket_cors;
use rocket::http::Method; use rocket::http::Method;
use rocket_cors::AllOrSome; use rocket_cors::{AllowedOrigins, AllowedHeaders};
#[get("/")] #[get("/")]
fn cors<'a>() -> &'a str { fn cors<'a>() -> &'a str {
@ -12,19 +12,14 @@ fn cors<'a>() -> &'a str {
} }
fn main() { fn main() {
let (allowed_origins, failed_origins) = AllOrSome::new_from_str_list(&["https://www.acme.com"]); let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.acme.com"]);
assert!(failed_origins.is_empty()); assert!(failed_origins.is_empty());
// You can also deserialize this // You can also deserialize this
let options = rocket_cors::Cors { let options = rocket_cors::Cors {
allowed_origins: allowed_origins, allowed_origins: allowed_origins,
allowed_methods: vec![Method::Get].into_iter().map(From::from).collect(), allowed_methods: vec![Method::Get].into_iter().map(From::from).collect(),
allowed_headers: AllOrSome::Some( allowed_headers: AllowedHeaders::some(&["Authorization", "Accept"]),
["Authorization", "Accept"]
.into_iter()
.map(|s| s.to_string().into())
.collect(),
),
allow_credentials: true, allow_credentials: true,
..Default::default() ..Default::default()
}; };

View File

@ -7,7 +7,7 @@ use std::io::Cursor;
use rocket::Response; use rocket::Response;
use rocket::http::Method; use rocket::http::Method;
use rocket_cors::{Guard, AllOrSome, Responder}; use rocket_cors::{Guard, AllowedOrigins, AllowedHeaders, Responder};
/// Using a `Responder` -- the usual way you would use this /// Using a `Responder` -- the usual way you would use this
#[get("/")] #[get("/")]
@ -39,19 +39,14 @@ fn response_options(cors: Guard) -> Response {
} }
fn main() { fn main() {
let (allowed_origins, failed_origins) = AllOrSome::new_from_str_list(&["https://www.acme.com"]); let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.acme.com"]);
assert!(failed_origins.is_empty()); assert!(failed_origins.is_empty());
// You can also deserialize this // You can also deserialize this
let options = rocket_cors::Cors { let options = rocket_cors::Cors {
allowed_origins: allowed_origins, allowed_origins: allowed_origins,
allowed_methods: vec![Method::Get].into_iter().map(From::from).collect(), allowed_methods: vec![Method::Get].into_iter().map(From::from).collect(),
allowed_headers: AllOrSome::Some( allowed_headers: AllowedHeaders::some(&["Authorization", "Accept"]),
["Authorization", "Accept"]
.into_iter()
.map(|s| s.to_string().into())
.collect(),
),
allow_credentials: true, allow_credentials: true,
..Default::default() ..Default::default()
}; };

38
examples/json.rs Normal file
View File

@ -0,0 +1,38 @@
//! This example is to demonstrate the JSON serialization and deserialization of the Cors settings
extern crate rocket;
extern crate rocket_cors as cors;
extern crate serde_json;
use rocket::http::Method;
use cors::{Cors, AllowedOrigins, AllowedHeaders};
fn main() {
// The default demonstrates the "All" serialization of several of the settings
let default: Cors = Default::default();
let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.acme.com"]);
assert!(failed_origins.is_empty());
let options = cors::Cors {
allowed_origins: allowed_origins,
allowed_methods: vec![Method::Get, Method::Post, Method::Delete]
.into_iter()
.map(From::from)
.collect(),
allowed_headers: AllowedHeaders::some(&["Authorization", "Accept"]),
allow_credentials: true,
expose_headers: ["Content-Type", "X-Custom"]
.iter()
.map(ToString::to_string)
.collect(),
max_age: Some(42),
send_wildcard: false,
fairing_route_base: "/mycors".to_string(),
};
println!("Default settings");
println!("{}", serde_json::to_string_pretty(&default).unwrap());
println!("Defined settings");
println!("{}", serde_json::to_string_pretty(&options).unwrap());
}

View File

@ -161,24 +161,18 @@ mod tests {
use rocket::http::{Method, Status}; use rocket::http::{Method, Status};
use rocket::local::Client; use rocket::local::Client;
use {Cors, AllOrSome}; use {Cors, AllOrSome, AllowedOrigins, AllowedHeaders};
const CORS_ROOT: &'static str = "/my_cors"; const CORS_ROOT: &'static str = "/my_cors";
fn make_cors_options() -> Cors { fn make_cors_options() -> Cors {
let (allowed_origins, failed_origins) = let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.acme.com"]);
AllOrSome::new_from_str_list(&["https://www.acme.com"]);
assert!(failed_origins.is_empty()); assert!(failed_origins.is_empty());
Cors { Cors {
allowed_origins: allowed_origins, allowed_origins: allowed_origins,
allowed_methods: vec![Method::Get].into_iter().map(From::from).collect(), allowed_methods: vec![Method::Get].into_iter().map(From::from).collect(),
allowed_headers: AllOrSome::Some( allowed_headers: AllowedHeaders::some(&["Authorization", "Accept"]),
["Authorization"]
.into_iter()
.map(|s| s.to_string().into())
.collect(),
),
allow_credentials: true, allow_credentials: true,
fairing_route_base: CORS_ROOT.to_string(), fairing_route_base: CORS_ROOT.to_string(),

View File

@ -96,7 +96,7 @@
//! extern crate rocket_cors; //! extern crate rocket_cors;
//! //!
//! use rocket::http::Method; //! use rocket::http::Method;
//! use rocket_cors::AllOrSome; //! use rocket_cors::{AllowedOrigins, AllowedHeaders};
//! //!
//! #[get("/")] //! #[get("/")]
//! fn cors<'a>() -> &'a str { //! fn cors<'a>() -> &'a str {
@ -104,19 +104,14 @@
//! } //! }
//! //!
//! fn main() { //! fn main() {
//! let (allowed_origins, failed_origins) = AllOrSome::new_from_str_list(&["https://www.acme.com"]); //! let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.acme.com"]);
//! assert!(failed_origins.is_empty()); //! assert!(failed_origins.is_empty());
//! //!
//! // You can also deserialize this //! // You can also deserialize this
//! let options = rocket_cors::Cors { //! let options = rocket_cors::Cors {
//! allowed_origins: allowed_origins, //! allowed_origins: allowed_origins,
//! allowed_methods: vec![Method::Get].into_iter().map(From::from).collect(), //! allowed_methods: vec![Method::Get].into_iter().map(From::from).collect(),
//! allowed_headers: AllOrSome::Some( //! allowed_headers: AllowedHeaders::some(&["Authorization", "Accept"]),
//! ["Authorization", "Accept"]
//! .into_iter()
//! .map(|s| s.to_string().into())
//! .collect(),
//! ),
//! allow_credentials: true, //! allow_credentials: true,
//! ..Default::default() //! ..Default::default()
//! }; //! };
@ -164,7 +159,7 @@
//! //!
//! use rocket::Response; //! use rocket::Response;
//! use rocket::http::Method; //! use rocket::http::Method;
//! use rocket_cors::{Guard, AllOrSome, Responder}; //! use rocket_cors::{Guard, AllowedOrigins, AllowedHeaders, Responder};
//! //!
//! /// Using a `Responder` -- the usual way you would use this //! /// Using a `Responder` -- the usual way you would use this
//! #[get("/")] //! #[get("/")]
@ -196,19 +191,14 @@
//! } //! }
//! //!
//! fn main() { //! fn main() {
//! let (allowed_origins, failed_origins) = AllOrSome::new_from_str_list(&["https://www.acme.com"]); //! let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.acme.com"]);
//! assert!(failed_origins.is_empty()); //! assert!(failed_origins.is_empty());
//! //!
//! // You can also deserialize this //! // You can also deserialize this
//! let options = rocket_cors::Cors { //! let options = rocket_cors::Cors {
//! allowed_origins: allowed_origins, //! allowed_origins: allowed_origins,
//! allowed_methods: vec![Method::Get].into_iter().map(From::from).collect(), //! allowed_methods: vec![Method::Get].into_iter().map(From::from).collect(),
//! allowed_headers: AllOrSome::Some( //! allowed_headers: AllowedHeaders::some(&["Authorization", "Accept"]),
//! ["Authorization", "Accept"]
//! .into_iter()
//! .map(|s| s.to_string().into())
//! .collect(),
//! ),
//! allow_credentials: true, //! allow_credentials: true,
//! ..Default::default() //! ..Default::default()
//! }; //! };
@ -442,8 +432,10 @@ impl<'r> response::Responder<'r> for Error {
/// An enum signifying that some of type T is allowed, or `All` (everything is allowed). /// An enum signifying that some of type T is allowed, or `All` (everything is allowed).
/// ///
/// `Default` is implemented for this enum and is `All`. /// `Default` is implemented for this enum and is `All`.
///
/// This enum is serialized and deserialized
/// ["Externally tagged"](https://serde.rs/enum-representations.html)
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)] #[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
#[serde(untagged)]
pub enum AllOrSome<T> { pub enum AllOrSome<T> {
/// Everything is allowed. Usually equivalent to the "*" value. /// Everything is allowed. Usually equivalent to the "*" value.
All, All,
@ -473,6 +465,7 @@ impl<T> AllOrSome<T> {
} }
impl AllOrSome<HashSet<Url>> { impl AllOrSome<HashSet<Url>> {
#[deprecated(since = "0.1.3", note = "please use `AllowedOrigins::Some` instead")]
/// New `AllOrSome` from a list of URL strings. /// New `AllOrSome` from a list of URL strings.
/// Returns a tuple where the first element is the struct `AllOrSome`, /// Returns a tuple where the first element is the struct `AllOrSome`,
/// and the second element /// and the second element
@ -565,6 +558,85 @@ impl<'de> Deserialize<'de> for Method {
} }
} }
/// A list of allowed origins. Either Some origins are allowed, or all origins are allowed.
///
/// # Examples
/// ```rust
/// use rocket_cors::AllowedOrigins;
///
/// let all_origins = AllowedOrigins::all();
/// let (some_origins, failed_origins) = AllowedOrigins::some(&["https://www.acme.com"]);
/// assert!(failed_origins.is_empty());
/// ```
pub type AllowedOrigins = AllOrSome<HashSet<Url>>;
impl AllowedOrigins {
/// Allows some origins
///
/// Returns a tuple where the first element is the struct `AllowedOrigins`,
/// and the second element
/// is a map of strings which failed to parse into URLs and their associated parse errors.
pub fn some(urls: &[&str]) -> (Self, HashMap<String, url::ParseError>) {
let (ok_set, error_map): (Vec<_>, Vec<_>) = urls.iter()
.map(|s| (s.to_string(), Url::from_str(s)))
.partition(|&(_, ref r)| r.is_ok());
let error_map = error_map
.into_iter()
.map(|(s, r)| (s.to_string(), r.unwrap_err()))
.collect();
let ok_set = ok_set.into_iter().map(|(_, r)| r.unwrap()).collect();
(AllOrSome::Some(ok_set), error_map)
}
/// Allows all origins
pub fn all() -> Self {
AllOrSome::All
}
}
/// A list of allowed methods
///
/// The [list](https://api.rocket.rs/rocket/http/enum.Method.html)
/// of methods is whatever is supported by Rocket.
///
/// # Example
/// ```rust
/// use std::str::FromStr;
/// use rocket_cors::AllowedMethods;
///
/// let allowed_methods: AllowedMethods = ["Get", "Post", "Delete"]
/// .iter()
/// .map(|s| FromStr::from_str(s).unwrap())
/// .collect();
/// ```
pub type AllowedMethods = HashSet<Method>;
/// A list of allowed headers
///
/// # Examples
/// ```rust
/// use rocket_cors::AllowedHeaders;
///
/// let all_headers = AllowedHeaders::all();
/// let some_headers = AllowedHeaders::some(&["Authorization", "Accept"]);
/// ```
pub type AllowedHeaders = AllOrSome<HashSet<HeaderFieldName>>;
impl AllowedHeaders {
/// Allow some headers
pub fn some(headers: &[&str]) -> Self {
AllOrSome::Some(headers.iter().map(|s| s.to_string().into()).collect())
}
/// Allows all headers
pub fn all() -> Self {
AllOrSome::All
}
}
/// Response generator and [Fairing](https://rocket.rs/guide/fairings/) for CORS /// Response generator and [Fairing](https://rocket.rs/guide/fairings/) for CORS
/// ///
/// This struct can be as Fairing or in an ad-hoc manner to generate CORS response. See the /// This struct can be as Fairing or in an ad-hoc manner to generate CORS response. See the
@ -575,6 +647,70 @@ impl<'de> Deserialize<'de> for Method {
/// ///
/// [`Default`](https://doc.rust-lang.org/std/default/trait.Default.html) is implemented for this /// [`Default`](https://doc.rust-lang.org/std/default/trait.Default.html) is implemented for this
/// struct. The default for each field is described in the docuementation for the field. /// struct. The default for each field is described in the docuementation for the field.
///
/// # Examples
///
/// You can run an example from the repository to demonstrate the JSON serialization with
/// `cargo run --example json`.
///
/// ## Pure default
/// ```rust
/// let default = rocket_cors::Cors::default();
/// ```
///
/// ## JSON Examples
/// ### Default
///
/// ```json
/// {
/// "allowed_origins": "All",
/// "allowed_methods": [
/// "POST",
/// "PATCH",
/// "PUT",
/// "DELETE",
/// "HEAD",
/// "OPTIONS",
/// "GET"
/// ],
/// "allowed_headers": "All",
/// "allow_credentials": false,
/// "expose_headers": [],
/// "max_age": null,
/// "send_wildcard": false,
/// "fairing_route_base": "/cors"
/// }
/// ```
/// ### Defined
/// ```json
/// {
/// "allowed_origins": {
/// "Some": [
/// "https://www.acme.com/"
/// ]
/// },
/// "allowed_methods": [
/// "POST",
/// "DELETE",
/// "GET"
/// ],
/// "allowed_headers": {
/// "Some": [
/// "Accept",
/// "Authorization"
/// ]
/// },
/// "allow_credentials": true,
/// "expose_headers": [
/// "Content-Type",
/// "X-Custom"
/// ],
/// "max_age": 42,
/// "send_wildcard": false,
/// "fairing_route_base": "/mycors"
/// }
///
/// ```
#[derive(Serialize, Deserialize, Eq, PartialEq, Clone, Debug)] #[derive(Serialize, Deserialize, Eq, PartialEq, Clone, Debug)]
pub struct Cors { pub struct Cors {
/// Origins that are allowed to make requests. /// Origins that are allowed to make requests.
@ -594,7 +730,7 @@ pub struct Cors {
/// ///
/// ``` /// ```
#[serde(default)] #[serde(default)]
pub allowed_origins: AllOrSome<HashSet<Url>>, pub allowed_origins: AllowedOrigins,
/// The list of methods which the allowed origins are allowed to access for /// The list of methods which the allowed origins are allowed to access for
/// non-simple requests. /// non-simple requests.
/// ///
@ -603,7 +739,7 @@ pub struct Cors {
/// ///
/// Defaults to `[GET, HEAD, POST, OPTIONS, PUT, PATCH, DELETE]` /// Defaults to `[GET, HEAD, POST, OPTIONS, PUT, PATCH, DELETE]`
#[serde(default = "Cors::default_allowed_methods")] #[serde(default = "Cors::default_allowed_methods")]
pub allowed_methods: HashSet<Method>, pub allowed_methods: AllowedMethods,
/// The list of header field names which can be used when this resource is accessed by allowed /// The list of header field names which can be used when this resource is accessed by allowed
/// origins. /// origins.
/// ///
@ -1348,8 +1484,7 @@ mod tests {
use http::Method; use http::Method;
fn make_cors_options() -> Cors { fn make_cors_options() -> Cors {
let (allowed_origins, failed_origins) = let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.acme.com"]);
AllOrSome::new_from_str_list(&["https://www.acme.com"]);
assert!(failed_origins.is_empty()); assert!(failed_origins.is_empty());
Cors { Cors {
@ -1358,12 +1493,7 @@ mod tests {
.into_iter() .into_iter()
.map(From::from) .map(From::from)
.collect(), .collect(),
allowed_headers: AllOrSome::Some( allowed_headers: AllowedHeaders::some(&[&"Authorization", "Accept"]),
["Authorization", "Accept"]
.into_iter()
.map(|s| s.to_string().into())
.collect(),
),
allow_credentials: true, allow_credentials: true,
expose_headers: ["Content-Type", "X-Custom"] expose_headers: ["Content-Type", "X-Custom"]
.into_iter() .into_iter()
@ -1424,8 +1554,7 @@ mod tests {
fn response_allows_origin() { fn response_allows_origin() {
let url = "https://www.example.com"; let url = "https://www.example.com";
let origin = Origin::from_str(url).unwrap(); let origin = Origin::from_str(url).unwrap();
let (allowed_origins, failed_origins) = let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.example.com"]);
AllOrSome::new_from_str_list(&["https://www.example.com"]);
assert!(failed_origins.is_empty()); assert!(failed_origins.is_empty());
not_err!(validate_origin(&origin, &allowed_origins)); not_err!(validate_origin(&origin, &allowed_origins));
@ -1436,8 +1565,7 @@ mod tests {
fn response_rejects_invalid_origin() { fn response_rejects_invalid_origin() {
let url = "https://www.acme.com"; let url = "https://www.acme.com";
let origin = Origin::from_str(url).unwrap(); let origin = Origin::from_str(url).unwrap();
let (allowed_origins, failed_origins) = let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.example.com"]);
AllOrSome::new_from_str_list(&["https://www.example.com"]);
assert!(failed_origins.is_empty()); assert!(failed_origins.is_empty());
validate_origin(&origin, &allowed_origins).unwrap(); validate_origin(&origin, &allowed_origins).unwrap();

View File

@ -24,18 +24,13 @@ fn panicking_route() {
} }
fn make_cors_options() -> Cors { fn make_cors_options() -> Cors {
let (allowed_origins, failed_origins) = AllOrSome::new_from_str_list(&["https://www.acme.com"]); let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.acme.com"]);
assert!(failed_origins.is_empty()); assert!(failed_origins.is_empty());
Cors { Cors {
allowed_origins: allowed_origins, allowed_origins: allowed_origins,
allowed_methods: vec![Method::Get].into_iter().map(From::from).collect(), allowed_methods: vec![Method::Get].into_iter().map(From::from).collect(),
allowed_headers: AllOrSome::Some( allowed_headers: AllowedHeaders::some(&["Authorization", "Accept"]),
["Authorization", "Accept"]
.into_iter()
.map(|s| s.to_string().into())
.collect(),
),
allow_credentials: true, allow_credentials: true,
..Default::default() ..Default::default()
} }

View File

@ -60,19 +60,13 @@ fn state<'r>(cors: cors::Guard<'r>, _state: State<'r, SomeState>) -> cors::Respo
} }
fn make_cors_options() -> cors::Cors { fn make_cors_options() -> cors::Cors {
let (allowed_origins, failed_origins) = let (allowed_origins, failed_origins) = cors::AllowedOrigins::some(&["https://www.acme.com"]);
cors::AllOrSome::new_from_str_list(&["https://www.acme.com"]);
assert!(failed_origins.is_empty()); assert!(failed_origins.is_empty());
cors::Cors { cors::Cors {
allowed_origins: allowed_origins, allowed_origins: allowed_origins,
allowed_methods: vec![Method::Get].into_iter().map(From::from).collect(), allowed_methods: vec![Method::Get].into_iter().map(From::from).collect(),
allowed_headers: cors::AllOrSome::Some( allowed_headers: cors::AllowedHeaders::some(&["Authorization", "Accept"]),
["Authorization", "Accept"]
.into_iter()
.map(|s| s.to_string().into())
.collect(),
),
allow_credentials: true, allow_credentials: true,
..Default::default() ..Default::default()
} }