Add Convenience typedefs and functions (#19)
* Add JSON documentation and convenience AllowedOrigin * Add `AllowedHeaders` * Add AllowedHeaders * Fix tests
This commit is contained in:
parent
fcd83e8fb5
commit
0a94dfe22a
|
@ -4,7 +4,7 @@ extern crate rocket;
|
|||
extern crate rocket_cors;
|
||||
|
||||
use rocket::http::Method;
|
||||
use rocket_cors::AllOrSome;
|
||||
use rocket_cors::{AllowedOrigins, AllowedHeaders};
|
||||
|
||||
#[get("/")]
|
||||
fn cors<'a>() -> &'a str {
|
||||
|
@ -12,19 +12,14 @@ fn cors<'a>() -> &'a str {
|
|||
}
|
||||
|
||||
fn main() {
|
||||
let (allowed_origins, failed_origins) = AllOrSome::new_from_str_list(&["https://www.acme.com"]);
|
||||
let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.acme.com"]);
|
||||
assert!(failed_origins.is_empty());
|
||||
|
||||
// You can also deserialize this
|
||||
let options = rocket_cors::Cors {
|
||||
allowed_origins: allowed_origins,
|
||||
allowed_methods: vec![Method::Get].into_iter().map(From::from).collect(),
|
||||
allowed_headers: AllOrSome::Some(
|
||||
["Authorization", "Accept"]
|
||||
.into_iter()
|
||||
.map(|s| s.to_string().into())
|
||||
.collect(),
|
||||
),
|
||||
allowed_headers: AllowedHeaders::some(&["Authorization", "Accept"]),
|
||||
allow_credentials: true,
|
||||
..Default::default()
|
||||
};
|
||||
|
|
|
@ -7,7 +7,7 @@ use std::io::Cursor;
|
|||
|
||||
use rocket::Response;
|
||||
use rocket::http::Method;
|
||||
use rocket_cors::{Guard, AllOrSome, Responder};
|
||||
use rocket_cors::{Guard, AllowedOrigins, AllowedHeaders, Responder};
|
||||
|
||||
/// Using a `Responder` -- the usual way you would use this
|
||||
#[get("/")]
|
||||
|
@ -39,19 +39,14 @@ fn response_options(cors: Guard) -> Response {
|
|||
}
|
||||
|
||||
fn main() {
|
||||
let (allowed_origins, failed_origins) = AllOrSome::new_from_str_list(&["https://www.acme.com"]);
|
||||
let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.acme.com"]);
|
||||
assert!(failed_origins.is_empty());
|
||||
|
||||
// You can also deserialize this
|
||||
let options = rocket_cors::Cors {
|
||||
allowed_origins: allowed_origins,
|
||||
allowed_methods: vec![Method::Get].into_iter().map(From::from).collect(),
|
||||
allowed_headers: AllOrSome::Some(
|
||||
["Authorization", "Accept"]
|
||||
.into_iter()
|
||||
.map(|s| s.to_string().into())
|
||||
.collect(),
|
||||
),
|
||||
allowed_headers: AllowedHeaders::some(&["Authorization", "Accept"]),
|
||||
allow_credentials: true,
|
||||
..Default::default()
|
||||
};
|
||||
|
|
|
@ -0,0 +1,38 @@
|
|||
//! This example is to demonstrate the JSON serialization and deserialization of the Cors settings
|
||||
extern crate rocket;
|
||||
extern crate rocket_cors as cors;
|
||||
extern crate serde_json;
|
||||
|
||||
use rocket::http::Method;
|
||||
use cors::{Cors, AllowedOrigins, AllowedHeaders};
|
||||
|
||||
fn main() {
|
||||
// The default demonstrates the "All" serialization of several of the settings
|
||||
let default: Cors = Default::default();
|
||||
|
||||
let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.acme.com"]);
|
||||
assert!(failed_origins.is_empty());
|
||||
|
||||
let options = cors::Cors {
|
||||
allowed_origins: allowed_origins,
|
||||
allowed_methods: vec![Method::Get, Method::Post, Method::Delete]
|
||||
.into_iter()
|
||||
.map(From::from)
|
||||
.collect(),
|
||||
allowed_headers: AllowedHeaders::some(&["Authorization", "Accept"]),
|
||||
allow_credentials: true,
|
||||
expose_headers: ["Content-Type", "X-Custom"]
|
||||
.iter()
|
||||
.map(ToString::to_string)
|
||||
.collect(),
|
||||
max_age: Some(42),
|
||||
send_wildcard: false,
|
||||
fairing_route_base: "/mycors".to_string(),
|
||||
};
|
||||
|
||||
println!("Default settings");
|
||||
println!("{}", serde_json::to_string_pretty(&default).unwrap());
|
||||
|
||||
println!("Defined settings");
|
||||
println!("{}", serde_json::to_string_pretty(&options).unwrap());
|
||||
}
|
|
@ -161,24 +161,18 @@ mod tests {
|
|||
use rocket::http::{Method, Status};
|
||||
use rocket::local::Client;
|
||||
|
||||
use {Cors, AllOrSome};
|
||||
use {Cors, AllOrSome, AllowedOrigins, AllowedHeaders};
|
||||
|
||||
const CORS_ROOT: &'static str = "/my_cors";
|
||||
|
||||
fn make_cors_options() -> Cors {
|
||||
let (allowed_origins, failed_origins) =
|
||||
AllOrSome::new_from_str_list(&["https://www.acme.com"]);
|
||||
let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.acme.com"]);
|
||||
assert!(failed_origins.is_empty());
|
||||
|
||||
Cors {
|
||||
allowed_origins: allowed_origins,
|
||||
allowed_methods: vec![Method::Get].into_iter().map(From::from).collect(),
|
||||
allowed_headers: AllOrSome::Some(
|
||||
["Authorization"]
|
||||
.into_iter()
|
||||
.map(|s| s.to_string().into())
|
||||
.collect(),
|
||||
),
|
||||
allowed_headers: AllowedHeaders::some(&["Authorization", "Accept"]),
|
||||
allow_credentials: true,
|
||||
fairing_route_base: CORS_ROOT.to_string(),
|
||||
|
||||
|
|
190
src/lib.rs
190
src/lib.rs
|
@ -96,7 +96,7 @@
|
|||
//! extern crate rocket_cors;
|
||||
//!
|
||||
//! use rocket::http::Method;
|
||||
//! use rocket_cors::AllOrSome;
|
||||
//! use rocket_cors::{AllowedOrigins, AllowedHeaders};
|
||||
//!
|
||||
//! #[get("/")]
|
||||
//! fn cors<'a>() -> &'a str {
|
||||
|
@ -104,19 +104,14 @@
|
|||
//! }
|
||||
//!
|
||||
//! fn main() {
|
||||
//! let (allowed_origins, failed_origins) = AllOrSome::new_from_str_list(&["https://www.acme.com"]);
|
||||
//! let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.acme.com"]);
|
||||
//! assert!(failed_origins.is_empty());
|
||||
//!
|
||||
//! // You can also deserialize this
|
||||
//! let options = rocket_cors::Cors {
|
||||
//! allowed_origins: allowed_origins,
|
||||
//! allowed_methods: vec![Method::Get].into_iter().map(From::from).collect(),
|
||||
//! allowed_headers: AllOrSome::Some(
|
||||
//! ["Authorization", "Accept"]
|
||||
//! .into_iter()
|
||||
//! .map(|s| s.to_string().into())
|
||||
//! .collect(),
|
||||
//! ),
|
||||
//! allowed_headers: AllowedHeaders::some(&["Authorization", "Accept"]),
|
||||
//! allow_credentials: true,
|
||||
//! ..Default::default()
|
||||
//! };
|
||||
|
@ -164,7 +159,7 @@
|
|||
//!
|
||||
//! use rocket::Response;
|
||||
//! use rocket::http::Method;
|
||||
//! use rocket_cors::{Guard, AllOrSome, Responder};
|
||||
//! use rocket_cors::{Guard, AllowedOrigins, AllowedHeaders, Responder};
|
||||
//!
|
||||
//! /// Using a `Responder` -- the usual way you would use this
|
||||
//! #[get("/")]
|
||||
|
@ -196,19 +191,14 @@
|
|||
//! }
|
||||
//!
|
||||
//! fn main() {
|
||||
//! let (allowed_origins, failed_origins) = AllOrSome::new_from_str_list(&["https://www.acme.com"]);
|
||||
//! let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.acme.com"]);
|
||||
//! assert!(failed_origins.is_empty());
|
||||
//!
|
||||
//! // You can also deserialize this
|
||||
//! let options = rocket_cors::Cors {
|
||||
//! allowed_origins: allowed_origins,
|
||||
//! allowed_methods: vec![Method::Get].into_iter().map(From::from).collect(),
|
||||
//! allowed_headers: AllOrSome::Some(
|
||||
//! ["Authorization", "Accept"]
|
||||
//! .into_iter()
|
||||
//! .map(|s| s.to_string().into())
|
||||
//! .collect(),
|
||||
//! ),
|
||||
//! allowed_headers: AllowedHeaders::some(&["Authorization", "Accept"]),
|
||||
//! allow_credentials: true,
|
||||
//! ..Default::default()
|
||||
//! };
|
||||
|
@ -442,8 +432,10 @@ impl<'r> response::Responder<'r> for Error {
|
|||
/// An enum signifying that some of type T is allowed, or `All` (everything is allowed).
|
||||
///
|
||||
/// `Default` is implemented for this enum and is `All`.
|
||||
///
|
||||
/// This enum is serialized and deserialized
|
||||
/// ["Externally tagged"](https://serde.rs/enum-representations.html)
|
||||
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
|
||||
#[serde(untagged)]
|
||||
pub enum AllOrSome<T> {
|
||||
/// Everything is allowed. Usually equivalent to the "*" value.
|
||||
All,
|
||||
|
@ -473,6 +465,7 @@ impl<T> AllOrSome<T> {
|
|||
}
|
||||
|
||||
impl AllOrSome<HashSet<Url>> {
|
||||
#[deprecated(since = "0.1.3", note = "please use `AllowedOrigins::Some` instead")]
|
||||
/// New `AllOrSome` from a list of URL strings.
|
||||
/// Returns a tuple where the first element is the struct `AllOrSome`,
|
||||
/// and the second element
|
||||
|
@ -565,6 +558,85 @@ impl<'de> Deserialize<'de> for Method {
|
|||
}
|
||||
}
|
||||
|
||||
/// A list of allowed origins. Either Some origins are allowed, or all origins are allowed.
|
||||
///
|
||||
/// # Examples
|
||||
/// ```rust
|
||||
/// use rocket_cors::AllowedOrigins;
|
||||
///
|
||||
/// let all_origins = AllowedOrigins::all();
|
||||
/// let (some_origins, failed_origins) = AllowedOrigins::some(&["https://www.acme.com"]);
|
||||
/// assert!(failed_origins.is_empty());
|
||||
/// ```
|
||||
pub type AllowedOrigins = AllOrSome<HashSet<Url>>;
|
||||
|
||||
impl AllowedOrigins {
|
||||
/// Allows some origins
|
||||
///
|
||||
/// Returns a tuple where the first element is the struct `AllowedOrigins`,
|
||||
/// and the second element
|
||||
/// is a map of strings which failed to parse into URLs and their associated parse errors.
|
||||
pub fn some(urls: &[&str]) -> (Self, HashMap<String, url::ParseError>) {
|
||||
let (ok_set, error_map): (Vec<_>, Vec<_>) = urls.iter()
|
||||
.map(|s| (s.to_string(), Url::from_str(s)))
|
||||
.partition(|&(_, ref r)| r.is_ok());
|
||||
|
||||
let error_map = error_map
|
||||
.into_iter()
|
||||
.map(|(s, r)| (s.to_string(), r.unwrap_err()))
|
||||
.collect();
|
||||
|
||||
let ok_set = ok_set.into_iter().map(|(_, r)| r.unwrap()).collect();
|
||||
|
||||
(AllOrSome::Some(ok_set), error_map)
|
||||
}
|
||||
|
||||
/// Allows all origins
|
||||
pub fn all() -> Self {
|
||||
AllOrSome::All
|
||||
}
|
||||
}
|
||||
|
||||
/// A list of allowed methods
|
||||
///
|
||||
/// The [list](https://api.rocket.rs/rocket/http/enum.Method.html)
|
||||
/// of methods is whatever is supported by Rocket.
|
||||
///
|
||||
/// # Example
|
||||
/// ```rust
|
||||
/// use std::str::FromStr;
|
||||
/// use rocket_cors::AllowedMethods;
|
||||
///
|
||||
/// let allowed_methods: AllowedMethods = ["Get", "Post", "Delete"]
|
||||
/// .iter()
|
||||
/// .map(|s| FromStr::from_str(s).unwrap())
|
||||
/// .collect();
|
||||
/// ```
|
||||
pub type AllowedMethods = HashSet<Method>;
|
||||
|
||||
/// A list of allowed headers
|
||||
///
|
||||
/// # Examples
|
||||
/// ```rust
|
||||
/// use rocket_cors::AllowedHeaders;
|
||||
///
|
||||
/// let all_headers = AllowedHeaders::all();
|
||||
/// let some_headers = AllowedHeaders::some(&["Authorization", "Accept"]);
|
||||
/// ```
|
||||
pub type AllowedHeaders = AllOrSome<HashSet<HeaderFieldName>>;
|
||||
|
||||
impl AllowedHeaders {
|
||||
/// Allow some headers
|
||||
pub fn some(headers: &[&str]) -> Self {
|
||||
AllOrSome::Some(headers.iter().map(|s| s.to_string().into()).collect())
|
||||
}
|
||||
|
||||
/// Allows all headers
|
||||
pub fn all() -> Self {
|
||||
AllOrSome::All
|
||||
}
|
||||
}
|
||||
|
||||
/// Response generator and [Fairing](https://rocket.rs/guide/fairings/) for CORS
|
||||
///
|
||||
/// This struct can be as Fairing or in an ad-hoc manner to generate CORS response. See the
|
||||
|
@ -575,6 +647,70 @@ impl<'de> Deserialize<'de> for Method {
|
|||
///
|
||||
/// [`Default`](https://doc.rust-lang.org/std/default/trait.Default.html) is implemented for this
|
||||
/// struct. The default for each field is described in the docuementation for the field.
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// You can run an example from the repository to demonstrate the JSON serialization with
|
||||
/// `cargo run --example json`.
|
||||
///
|
||||
/// ## Pure default
|
||||
/// ```rust
|
||||
/// let default = rocket_cors::Cors::default();
|
||||
/// ```
|
||||
///
|
||||
/// ## JSON Examples
|
||||
/// ### Default
|
||||
///
|
||||
/// ```json
|
||||
/// {
|
||||
/// "allowed_origins": "All",
|
||||
/// "allowed_methods": [
|
||||
/// "POST",
|
||||
/// "PATCH",
|
||||
/// "PUT",
|
||||
/// "DELETE",
|
||||
/// "HEAD",
|
||||
/// "OPTIONS",
|
||||
/// "GET"
|
||||
/// ],
|
||||
/// "allowed_headers": "All",
|
||||
/// "allow_credentials": false,
|
||||
/// "expose_headers": [],
|
||||
/// "max_age": null,
|
||||
/// "send_wildcard": false,
|
||||
/// "fairing_route_base": "/cors"
|
||||
/// }
|
||||
/// ```
|
||||
/// ### Defined
|
||||
/// ```json
|
||||
/// {
|
||||
/// "allowed_origins": {
|
||||
/// "Some": [
|
||||
/// "https://www.acme.com/"
|
||||
/// ]
|
||||
/// },
|
||||
/// "allowed_methods": [
|
||||
/// "POST",
|
||||
/// "DELETE",
|
||||
/// "GET"
|
||||
/// ],
|
||||
/// "allowed_headers": {
|
||||
/// "Some": [
|
||||
/// "Accept",
|
||||
/// "Authorization"
|
||||
/// ]
|
||||
/// },
|
||||
/// "allow_credentials": true,
|
||||
/// "expose_headers": [
|
||||
/// "Content-Type",
|
||||
/// "X-Custom"
|
||||
/// ],
|
||||
/// "max_age": 42,
|
||||
/// "send_wildcard": false,
|
||||
/// "fairing_route_base": "/mycors"
|
||||
/// }
|
||||
///
|
||||
/// ```
|
||||
#[derive(Serialize, Deserialize, Eq, PartialEq, Clone, Debug)]
|
||||
pub struct Cors {
|
||||
/// Origins that are allowed to make requests.
|
||||
|
@ -594,7 +730,7 @@ pub struct Cors {
|
|||
///
|
||||
/// ```
|
||||
#[serde(default)]
|
||||
pub allowed_origins: AllOrSome<HashSet<Url>>,
|
||||
pub allowed_origins: AllowedOrigins,
|
||||
/// The list of methods which the allowed origins are allowed to access for
|
||||
/// non-simple requests.
|
||||
///
|
||||
|
@ -603,7 +739,7 @@ pub struct Cors {
|
|||
///
|
||||
/// Defaults to `[GET, HEAD, POST, OPTIONS, PUT, PATCH, DELETE]`
|
||||
#[serde(default = "Cors::default_allowed_methods")]
|
||||
pub allowed_methods: HashSet<Method>,
|
||||
pub allowed_methods: AllowedMethods,
|
||||
/// The list of header field names which can be used when this resource is accessed by allowed
|
||||
/// origins.
|
||||
///
|
||||
|
@ -1348,8 +1484,7 @@ mod tests {
|
|||
use http::Method;
|
||||
|
||||
fn make_cors_options() -> Cors {
|
||||
let (allowed_origins, failed_origins) =
|
||||
AllOrSome::new_from_str_list(&["https://www.acme.com"]);
|
||||
let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.acme.com"]);
|
||||
assert!(failed_origins.is_empty());
|
||||
|
||||
Cors {
|
||||
|
@ -1358,12 +1493,7 @@ mod tests {
|
|||
.into_iter()
|
||||
.map(From::from)
|
||||
.collect(),
|
||||
allowed_headers: AllOrSome::Some(
|
||||
["Authorization", "Accept"]
|
||||
.into_iter()
|
||||
.map(|s| s.to_string().into())
|
||||
.collect(),
|
||||
),
|
||||
allowed_headers: AllowedHeaders::some(&[&"Authorization", "Accept"]),
|
||||
allow_credentials: true,
|
||||
expose_headers: ["Content-Type", "X-Custom"]
|
||||
.into_iter()
|
||||
|
@ -1424,8 +1554,7 @@ mod tests {
|
|||
fn response_allows_origin() {
|
||||
let url = "https://www.example.com";
|
||||
let origin = Origin::from_str(url).unwrap();
|
||||
let (allowed_origins, failed_origins) =
|
||||
AllOrSome::new_from_str_list(&["https://www.example.com"]);
|
||||
let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.example.com"]);
|
||||
assert!(failed_origins.is_empty());
|
||||
|
||||
not_err!(validate_origin(&origin, &allowed_origins));
|
||||
|
@ -1436,8 +1565,7 @@ mod tests {
|
|||
fn response_rejects_invalid_origin() {
|
||||
let url = "https://www.acme.com";
|
||||
let origin = Origin::from_str(url).unwrap();
|
||||
let (allowed_origins, failed_origins) =
|
||||
AllOrSome::new_from_str_list(&["https://www.example.com"]);
|
||||
let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.example.com"]);
|
||||
assert!(failed_origins.is_empty());
|
||||
|
||||
validate_origin(&origin, &allowed_origins).unwrap();
|
||||
|
|
|
@ -24,18 +24,13 @@ fn panicking_route() {
|
|||
}
|
||||
|
||||
fn make_cors_options() -> Cors {
|
||||
let (allowed_origins, failed_origins) = AllOrSome::new_from_str_list(&["https://www.acme.com"]);
|
||||
let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.acme.com"]);
|
||||
assert!(failed_origins.is_empty());
|
||||
|
||||
Cors {
|
||||
allowed_origins: allowed_origins,
|
||||
allowed_methods: vec![Method::Get].into_iter().map(From::from).collect(),
|
||||
allowed_headers: AllOrSome::Some(
|
||||
["Authorization", "Accept"]
|
||||
.into_iter()
|
||||
.map(|s| s.to_string().into())
|
||||
.collect(),
|
||||
),
|
||||
allowed_headers: AllowedHeaders::some(&["Authorization", "Accept"]),
|
||||
allow_credentials: true,
|
||||
..Default::default()
|
||||
}
|
||||
|
|
|
@ -60,19 +60,13 @@ fn state<'r>(cors: cors::Guard<'r>, _state: State<'r, SomeState>) -> cors::Respo
|
|||
}
|
||||
|
||||
fn make_cors_options() -> cors::Cors {
|
||||
let (allowed_origins, failed_origins) =
|
||||
cors::AllOrSome::new_from_str_list(&["https://www.acme.com"]);
|
||||
let (allowed_origins, failed_origins) = cors::AllowedOrigins::some(&["https://www.acme.com"]);
|
||||
assert!(failed_origins.is_empty());
|
||||
|
||||
cors::Cors {
|
||||
allowed_origins: allowed_origins,
|
||||
allowed_methods: vec![Method::Get].into_iter().map(From::from).collect(),
|
||||
allowed_headers: cors::AllOrSome::Some(
|
||||
["Authorization", "Accept"]
|
||||
.into_iter()
|
||||
.map(|s| s.to_string().into())
|
||||
.collect(),
|
||||
),
|
||||
allowed_headers: cors::AllowedHeaders::some(&["Authorization", "Accept"]),
|
||||
allow_credentials: true,
|
||||
..Default::default()
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue