diff --git a/examples/fairing.rs b/examples/fairing.rs index c640078..c82a89e 100644 --- a/examples/fairing.rs +++ b/examples/fairing.rs @@ -4,28 +4,31 @@ use rocket_cors; use rocket::http::Method; use rocket::{get, routes}; -use rocket_cors::{AllowedHeaders, AllowedOrigins}; +use rocket_cors::{AllowedHeaders, AllowedOrigins, Error}; #[get("/")] fn cors<'a>() -> &'a str { "Hello CORS" } -fn main() { +fn main() -> Result<(), Error> { let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.acme.com"]); assert!(failed_origins.is_empty()); // You can also deserialize this - let options = rocket_cors::Cors { + let cors = rocket_cors::CorsOptions { allowed_origins: allowed_origins, allowed_methods: vec![Method::Get].into_iter().map(From::from).collect(), allowed_headers: AllowedHeaders::some(&["Authorization", "Accept"]), allow_credentials: true, ..Default::default() - }; + } + .to_cors()?; rocket::ignite() .mount("/", routes![cors]) - .attach(options) + .attach(cors) .launch(); + + Ok(()) } diff --git a/examples/guard.rs b/examples/guard.rs index 3971e64..20d8705 100644 --- a/examples/guard.rs +++ b/examples/guard.rs @@ -7,7 +7,7 @@ use std::io::Cursor; use rocket::http::Method; use rocket::Response; use rocket::{get, options, routes}; -use rocket_cors::{AllowedHeaders, AllowedOrigins, Guard, Responder}; +use rocket_cors::{AllowedHeaders, AllowedOrigins, Error, Guard, Responder}; /// Using a `Responder` -- the usual way you would use this #[get("/")] @@ -35,18 +35,19 @@ fn manual(cors: Guard<'_>) -> Responder<'_, &str> { cors.responder("Manual OPTIONS preflight handling") } -fn main() { +fn main() -> Result<(), Error> { let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.acme.com"]); assert!(failed_origins.is_empty()); // You can also deserialize this - let options = rocket_cors::Cors { + let cors = rocket_cors::CorsOptions { allowed_origins: allowed_origins, allowed_methods: vec![Method::Get].into_iter().map(From::from).collect(), allowed_headers: AllowedHeaders::some(&["Authorization", "Accept"]), allow_credentials: true, ..Default::default() - }; + } + .to_cors()?; rocket::ignite() .mount("/", routes![responder, response]) @@ -54,6 +55,8 @@ fn main() { .mount("/", rocket_cors::catch_all_options_routes()) // You can also manually mount an OPTIONS route that will be used instead .mount("/", routes![manual, manual_options]) - .manage(options) + .manage(cors) .launch(); + + Ok(()) } diff --git a/examples/json.rs b/examples/json.rs index 08131ca..e62be45 100644 --- a/examples/json.rs +++ b/examples/json.rs @@ -6,17 +6,17 @@ use rocket_cors as cors; use serde_json; -use crate::cors::{AllowedHeaders, AllowedOrigins, Cors}; +use crate::cors::{AllowedHeaders, AllowedOrigins, CorsOptions}; use rocket::http::Method; fn main() { // The default demonstrates the "All" serialization of several of the settings - let default: Cors = Default::default(); + let default: CorsOptions = Default::default(); let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.acme.com"]); assert!(failed_origins.is_empty()); - let options = cors::Cors { + let options = cors::CorsOptions { allowed_origins: allowed_origins, allowed_methods: vec![Method::Get, Method::Post, Method::Delete] .into_iter() diff --git a/examples/manual.rs b/examples/manual.rs index 5b4bc05..3bac6df 100644 --- a/examples/manual.rs +++ b/examples/manual.rs @@ -7,9 +7,13 @@ use std::io::Cursor; use rocket::http::Method; use rocket::response::Responder; use rocket::{get, options, routes, Response, State}; -use rocket_cors::{AllowedHeaders, AllowedOrigins, Cors}; +use rocket_cors::{AllowedHeaders, AllowedOrigins, Cors, CorsOptions}; /// Using a borrowed Cors +/// +/// You might want to borrow the `Cors` struct from Rocket's state, for example. Unless you have +/// special handling, you might want to use the Guard method instead which has less hassle. +/// /// Note that the `'r` lifetime annotation is not requred here because `State` borrows with lifetime /// `'r` and so does `Responder`! #[get("/")] @@ -34,9 +38,13 @@ fn response(options: State<'_, Cors>) -> impl Responder<'_> { /// Create and use an ad-hoc Cors /// Note that the `'r` lifetime is needed because the compiler cannot elide anything. +/// +/// This is the most likely scenario when you want to have manual CORS validation. You can use this +/// when the settings you want to use for a route is not the same as the rest of the application +/// (which you might have put in Rocket's state). #[get("/owned")] fn owned<'r>() -> impl Responder<'r> { - let options = cors_options(); + let options = cors_options().to_cors()?; options.respond_owned(|guard| guard.responder("Hello CORS")) } @@ -46,16 +54,16 @@ fn owned<'r>() -> impl Responder<'r> { /// Note that the `'r` lifetime is needed because the compiler cannot elide anything. #[options("/owned")] fn owned_options<'r>() -> impl Responder<'r> { - let options = cors_options(); + let options = cors_options().to_cors()?; options.respond_owned(|guard| guard.responder(())) } -fn cors_options() -> Cors { +fn cors_options() -> CorsOptions { let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.acme.com"]); assert!(failed_origins.is_empty()); // You can also deserialize this - rocket_cors::Cors { + rocket_cors::CorsOptions { allowed_origins: allowed_origins, allowed_methods: vec![Method::Get].into_iter().map(From::from).collect(), allowed_headers: AllowedHeaders::some(&["Authorization", "Accept"]), @@ -68,6 +76,6 @@ fn main() { rocket::ignite() .mount("/", routes![borrowed, response, owned, owned_options,]) .mount("/", rocket_cors::catch_all_options_routes()) // mount the catch all routes - .manage(cors_options()) + .manage(cors_options().to_cors().expect("To not fail")) .launch(); } diff --git a/examples/mix.rs b/examples/mix.rs index 0ba1815..217c61d 100644 --- a/examples/mix.rs +++ b/examples/mix.rs @@ -10,7 +10,7 @@ use rocket_cors; use rocket::http::Method; use rocket::response::Responder; use rocket::{get, options, routes}; -use rocket_cors::{AllowedHeaders, AllowedOrigins, Cors, Guard}; +use rocket_cors::{AllowedHeaders, AllowedOrigins, CorsOptions, Guard}; /// The "usual" app route #[get("/")] @@ -21,8 +21,8 @@ fn app(cors: Guard<'_>) -> rocket_cors::Responder<'_, &str> { /// The special "ping" route #[get("/ping")] fn ping<'r>() -> impl Responder<'r> { - let options = cors_options_all(); - options.respond_owned(|guard| guard.responder("Pong!")) + let cors = cors_options_all().to_cors()?; + cors.respond_owned(|guard| guard.responder("Pong!")) } /// You need to define an OPTIONS route for preflight checks if you want to use `Cors` struct @@ -30,17 +30,17 @@ fn ping<'r>() -> impl Responder<'r> { /// These routes can just return the unit type `()` #[options("/ping")] fn ping_options<'r>() -> impl Responder<'r> { - let options = cors_options_all(); - options.respond_owned(|guard| guard.responder(())) + let cors = cors_options_all().to_cors()?; + cors.respond_owned(|guard| guard.responder(())) } /// Returns the "application wide" Cors struct -fn cors_options() -> Cors { +fn cors_options() -> CorsOptions { let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.acme.com"]); assert!(failed_origins.is_empty()); // You can also deserialize this - rocket_cors::Cors { + rocket_cors::CorsOptions { allowed_origins: allowed_origins, allowed_methods: vec![Method::Get].into_iter().map(From::from).collect(), allowed_headers: AllowedHeaders::some(&["Authorization", "Accept"]), @@ -53,7 +53,7 @@ fn cors_options() -> Cors { /// /// Note: In your real application, you might want to use something like `lazy_static` to generate /// a `&'static` reference to this instead of creating a new struct on every request. -fn cors_options_all() -> Cors { +fn cors_options_all() -> CorsOptions { // You can also deserialize this Default::default() } @@ -62,6 +62,6 @@ fn main() { rocket::ignite() .mount("/", routes![app, ping, ping_options,]) .mount("/", rocket_cors::catch_all_options_routes()) // mount the catch all routes - .manage(cors_options()) + .manage(cors_options().to_cors().expect("To not fail")) .launch(); } diff --git a/src/fairing.rs b/src/fairing.rs index 18a6834..991d988 100644 --- a/src/fairing.rs +++ b/src/fairing.rs @@ -100,16 +100,10 @@ impl rocket::fairing::Fairing for Cors { } fn on_attach(&self, rocket: rocket::Rocket) -> Result { - match self.validate() { - Ok(()) => Ok(rocket.mount( - &self.fairing_route_base, - vec![fairing_route(self.fairing_route_rank)], - )), - Err(e) => { - error_!("Error attaching CORS fairing: {}", e); - Err(rocket) - } - } + Ok(rocket.mount( + &self.fairing_route_base, + vec![fairing_route(self.fairing_route_rank)], + )) } fn on_request(&self, request: &mut Request<'_>, _: &rocket::Data) { @@ -141,7 +135,7 @@ mod tests { use rocket::local::Client; use rocket::Rocket; - use crate::{AllOrSome, AllowedHeaders, AllowedOrigins, Cors}; + use crate::{AllowedHeaders, AllowedOrigins, Cors, CorsOptions}; const CORS_ROOT: &'static str = "/my_cors"; @@ -149,7 +143,7 @@ mod tests { let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.acme.com"]); assert!(failed_origins.is_empty()); - Cors { + CorsOptions { allowed_origins: allowed_origins, allowed_methods: vec![Method::Get].into_iter().map(From::from).collect(), allowed_headers: AllowedHeaders::some(&["Authorization", "Accept"]), @@ -158,6 +152,8 @@ mod tests { ..Default::default() } + .to_cors() + .expect("Not to fail") } fn rocket(fairing: Cors) -> Rocket { @@ -191,15 +187,5 @@ mod tests { assert!(error_route.is_some()); } - #[test] - #[should_panic(expected = "launch fairing failure")] - fn options_are_validated_on_attach() { - let mut options = make_cors_options(); - options.allowed_origins = AllOrSome::All; - options.send_wildcard = true; - - let _ = rocket(options).launch(); - } - // Rest of the things can only be tested in integration tests } diff --git a/src/lib.rs b/src/lib.rs index 85f7abe..bff0e7b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -42,8 +42,8 @@ rocket_cors = { git = "https://github.com/lawliet89/rocket_cors", branch = "mast ## Features By default, a `serialization` feature is enabled in this crate that allows you to (de)serialize -the `Cors` struct that is described below. If you would like to disable this, simply change -your `Cargo.toml` to: +the [`CorsOptions`] struct that is described below. If you would like to disable this, simply +change your `Cargo.toml` to: ```toml rocket_cors = { version = "0.4.0", default-features = false } @@ -51,8 +51,9 @@ rocket_cors = { version = "0.4.0", default-features = false } ## Usage -Before you can add CORS responses to your application, you need to create a `Cors` struct that -will hold the settings. +Before you can add CORS responses to your application, you need to create a [`CorsOptions`] +struct that will hold the settings. Then, you need to create a [`Cors`] struct using +[`CorsOptions::to_cors`] which will validate and optimise the settings for Rocket to use. Each of the examples can be run off the repository via `cargo run --example xxx` where `xxx` is @@ -60,14 +61,19 @@ Each of the examples can be run off the repository via `cargo run --example xxx` - `guard` - `manual` -### `Cors` Struct +### `CorsOptions` Struct -The [`Cors` struct](Cors) contains the settings for CORS requests to be validated +The [`CorsOptiopns`] struct contains the settings for CORS requests to be validated and for responses to be generated. Defaults are defined for every field in the struct, and -are documented on the [`Cors` struct](Cors) page. You can also deserialize +are documented on the [`CorsOptiopns`] page. You can also deserialize the struct from some format like JSON, YAML or TOML when the default `serialization` feature is enabled. +### `Cors` Struct + +The [`Cors`] struct is what will be used with Rocket. After creating or deserializing a +[`CorsOptions`] struct, use [`CorsOptions::to_cors`] to create a [`Cors`] struct. + ### Three modes of operation You can add CORS to your routes via one of three ways, in descending order of ease and in @@ -100,43 +106,11 @@ routes for your application, and the checks are done transparently. However, you can only have one set of settings that must apply to all routes. You cannot opt any route out of CORS checks. -To use this, simply create a [`Cors` struct](Cors) and then +To use this, simply create a [`Cors`] from [`CorsOptions::to_cors`] and then [`attach`](https://api.rocket.rs/rocket/struct.Rocket.html#method.attach) it to Rocket. -```rust,no_run -#![feature(proc_macro_hygiene, decl_macro)] -extern crate rocket; -extern crate rocket_cors; +Refer to the [example](https://github.com/lawliet89/rocket_cors/blob/master/examples/fairing.rs). -use rocket::{get, routes}; -use rocket::http::Method; -use rocket_cors::{AllowedOrigins, AllowedHeaders}; - -#[get("/")] -fn cors<'a>() -> &'a str { - "Hello CORS" -} - -fn main() { - let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.acme.com"]); - assert!(failed_origins.is_empty()); - - // You can also deserialize this - let options = rocket_cors::Cors { - allowed_origins: allowed_origins, - allowed_methods: vec![Method::Get].into_iter().map(From::from).collect(), - allowed_headers: AllowedHeaders::some(&["Authorization", "Accept"]), - allow_credentials: true, - ..Default::default() - }; - - rocket::ignite() - .mount("/", routes![cors]) - .attach(options) - .launch(); -} - -``` #### Injected Route The fairing implementation will inject a route during attachment to Rocket. This route is used @@ -154,7 +128,7 @@ The only way to do this is to hijack the request and route it to our own injecte handle errors. Rocket does not allow Fairings to stop the processing of a route. You can configure the behaviour of the injected route through a couple of fields in the -[`Cors` struct](Cors). +[`CorsOptions`]. ### Request Guard @@ -165,7 +139,7 @@ requests. The `OPTIONS` routes are used for CORS preflight checks. You will have to do the following: -- Create a [`Cors` struct](Cors) and during Rocket's ignite, add the struct to +- Create a [`Cors`] from [`CorsOptions`] and during Rocket's ignite, add the struct to Rocket's [managed state](https://rocket.rs/guide/state/#managed-state). - For all the routes that you want to enforce CORS on, you can mount either some [catch all route](catch_all_options_routes) or define your own route for the OPTIONS @@ -178,69 +152,7 @@ error handling in case of errors. - In your routes, to add CORS headers to your responses, use the appropriate functions on the [`Guard`](Guard) for a `Response` or a `Responder`. -```rust,no_run -#![feature(proc_macro_hygiene, decl_macro)] -extern crate rocket; -extern crate rocket_cors; - -use std::io::Cursor; - -use rocket::{Response, get, options, routes}; -use rocket::http::Method; -use rocket_cors::{Guard, AllowedOrigins, AllowedHeaders, Responder}; - -/// Using a `Responder` -- the usual way you would use this -#[get("/")] -fn responder(cors: Guard) -> Responder<&str> { - cors.responder("Hello CORS!") -} - -/// Using a `Response` instead of a `Responder`. You generally won't have to do this. -#[get("/response")] -fn response(cors: Guard) -> Response { - let mut response = Response::new(); - response.set_sized_body(Cursor::new("Hello CORS!")); - cors.response(response) -} - -/// Manually mount an OPTIONS route for your own handling -#[options("/manual")] -fn manual_options(cors: Guard) -> Responder<&str> { - cors.responder("Manual OPTIONS preflight handling") -} - -/// Manually mount an OPTIONS route for your own handling -#[get("/manual")] -fn manual(cors: Guard) -> Responder<&str> { - cors.responder("Manual OPTIONS preflight handling") -} - -fn main() { - let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.acme.com"]); - assert!(failed_origins.is_empty()); - - // You can also deserialize this - let options = rocket_cors::Cors { - allowed_origins: allowed_origins, - allowed_methods: vec![Method::Get].into_iter().map(From::from).collect(), - allowed_headers: AllowedHeaders::some(&["Authorization", "Accept"]), - allow_credentials: true, - ..Default::default() - }; - - rocket::ignite() - .mount( - "/", - routes![responder, response], - ) - // Mount the routes to catch all the OPTIONS pre-flight requests - .mount("/", rocket_cors::catch_all_options_routes()) - // You can also manually mount an OPTIONS route that will be used instead - .mount("/", routes![manual, manual_options]) - .manage(options) - .launch(); -} -``` +Refer to the [example](https://github.com/lawliet89/rocket_cors/blob/master/examples/guard.rs). ## Truly Manual @@ -261,17 +173,17 @@ has been validated. If validation fails, the closure will not be run. You should that has any side effects or with an appreciable computation cost inside this handler. ### Steps to perform: -- You will first need to have a `Cors` struct ready. This struct can be borrowed with a lifetime +- You will first need to have a [`Cors`] struct ready. This struct can be borrowed with a lifetime at least as long as `'r` which is the lifetime of a Rocket request. `'static` works too. In this case, you might as well use the `Guard` method above and place the `Cors` struct in Rocket's [state](https://rocket.rs/guide/state/). -Alternatively, you can create a `Cors` struct directly in the route. +Alternatively, you can create a [`Cors`] struct directly in the route. - Your routes _might_ need to have a `'r` lifetime and return `impl Responder<'r>`. See below. -- Using the `Cors` struct, use either the -[`respond_owned`](Cors#method.respond_owned) or -[`respond_borrowed`](Cors#method.respond_borrowed) function and pass in a handler +- Using the [`Cors`] struct, use either the +[`Cors::respond_owned`] or +[`Cors::respond_borrowed`] function and pass in a handler that will be executed once CORS validation is successful. -- Your handler will be passed a [`Guard`](Guard) which you will have to use to +- Your handler will be passed a [`Guard`] which you will have to use to add CORS headers into your own response. - You will have to manually define your own `OPTIONS` routes. @@ -291,127 +203,7 @@ required. You can see examples when the lifetime annotation is required (or not) in `examples/manual.rs`. -### Owned example -This is the most likely scenario when you want to have manual CORS validation. You can use this -when the settings you want to use for a route is not the same as the rest of the application -(which you might have put in Rocket's state). - -```rust,no_run -#![feature(proc_macro_hygiene, decl_macro)] -extern crate rocket; -extern crate rocket_cors; - -use rocket::{get, options, routes}; -use rocket::http::Method; -use rocket::response::Responder; -use rocket_cors::{Cors, AllowedOrigins, AllowedHeaders}; - -/// Create and use an ad-hoc Cors -#[get("/owned")] -fn owned<'r>() -> impl Responder<'r> { - let options = cors_options(); - options.respond_owned(|guard| guard.responder("Hello CORS")) -} - -/// You need to define an OPTIONS route for preflight checks. -/// These routes can just return the unit type `()` -#[options("/owned")] -fn owned_options<'r>() -> impl Responder<'r> { - let options = cors_options(); - options.respond_owned(|guard| guard.responder(())) -} - -fn cors_options() -> Cors { - let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.acme.com"]); - assert!(failed_origins.is_empty()); - - // You can also deserialize this - rocket_cors::Cors { - allowed_origins: allowed_origins, - allowed_methods: vec![Method::Get].into_iter().map(From::from).collect(), - allowed_headers: AllowedHeaders::some(&["Authorization", "Accept"]), - allow_credentials: true, - ..Default::default() - } -} - -fn main() { - rocket::ignite() - .mount( - "/", - routes![ - owned, - owned_options, - ], - ) - .manage(cors_options()) - .launch(); -} -``` - -### Borrowed Example -You might want to borrow the `Cors` struct from Rocket's state, for example. Unless you have -special handling, you might want to use the Guard method instead which has less hassle. - -```rust,no_run -#![feature(proc_macro_hygiene, decl_macro)] -extern crate rocket; -extern crate rocket_cors; - -use std::io::Cursor; - -use rocket::{State, Response, get, routes}; -use rocket::http::Method; -use rocket::response::Responder; -use rocket_cors::{Cors, AllowedOrigins, AllowedHeaders}; - -/// Using a borrowed Cors -#[get("/")] -fn borrowed(options: State) -> impl Responder { - options.inner().respond_borrowed( - |guard| guard.responder("Hello CORS"), - ) -} - -/// Using a `Response` instead of a `Responder`. You generally won't have to do this. -#[get("/response")] -fn response(options: State) -> impl Responder { - let mut response = Response::new(); - response.set_sized_body(Cursor::new("Hello CORS!")); - - options.inner().respond_borrowed( - move |guard| guard.response(response), - ) -} - -fn cors_options() -> Cors { - let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.acme.com"]); - assert!(failed_origins.is_empty()); - - // You can also deserialize this - rocket_cors::Cors { - allowed_origins: allowed_origins, - allowed_methods: vec![Method::Get].into_iter().map(From::from).collect(), - allowed_headers: AllowedHeaders::some(&["Authorization", "Accept"]), - allow_credentials: true, - ..Default::default() - } -} - -fn main() { - rocket::ignite() - .mount( - "/", - routes![ - borrowed, - response, - ], - ) - .mount("/", rocket_cors::catch_all_options_routes()) // mount the catch all routes - .manage(cors_options()) - .launch(); -} -``` +See the [example](https://github.com/lawliet89/rocket_cors/blob/master/examples/manual.rs). ## Mixing Guard and Manual @@ -419,83 +211,7 @@ You can mix `Guard` and `Truly Manual` modes together for your application. For application might restrict the Origins that can access it, except for one `ping` route that allows all access. -You can run the example code below with `cargo run --example mix`. - -```rust,no_run -#![feature(proc_macro_hygiene, decl_macro)] -extern crate rocket; -extern crate rocket_cors; - -use rocket::{get, options, routes}; -use rocket::http::Method; -use rocket::response::Responder; -use rocket_cors::{Cors, Guard, AllowedOrigins, AllowedHeaders}; - -/// The "usual" app route -#[get("/")] -fn app(cors: Guard) -> rocket_cors::Responder<&str> { - cors.responder("Hello CORS!") -} - -/// The special "ping" route -#[get("/ping")] -fn ping<'r>() -> impl Responder<'r> { - let options = cors_options_all(); - options.respond_owned(|guard| guard.responder("Pong!")) -} - -/// You need to define an OPTIONS route for preflight checks if you want to use `Cors` struct -/// that is not in Rocket's managed state. -/// These routes can just return the unit type `()` -#[options("/ping")] -fn ping_options<'r>() -> impl Responder<'r> { - let options = cors_options_all(); - options.respond_owned(|guard| guard.responder(())) -} - -/// Returns the "application wide" Cors struct -fn cors_options() -> Cors { - let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.acme.com"]); - assert!(failed_origins.is_empty()); - - // You can also deserialize this - rocket_cors::Cors { - allowed_origins: allowed_origins, - allowed_methods: vec![Method::Get].into_iter().map(From::from).collect(), - allowed_headers: AllowedHeaders::some(&["Authorization", "Accept"]), - allow_credentials: true, - ..Default::default() - } -} - -/// A special struct that allows all origins -/// -/// Note: In your real application, you might want to use something like `lazy_static` to -/// generate a `&'static` reference to this instead of creating a new struct on every request. -fn cors_options_all() -> Cors { - // You can also deserialize this - rocket_cors::Cors { - allowed_methods: vec![Method::Get].into_iter().map(From::from).collect(), - ..Default::default() - } -} - -fn main() { - rocket::ignite() - .mount( - "/", - routes![ - app, - ping, - ping_options, - ], - ) - .mount("/", rocket_cors::catch_all_options_routes()) // mount the catch all routes - .manage(cors_options()) - .launch(); -} - -``` +See the [example](https://github.com/lawliet89/rocket_cors/blob/master/examples/guard.rs). ## Reference - [Fetch CORS Specification](https://fetch.spec.whatwg.org/#cors-protocol) @@ -666,7 +382,6 @@ impl fmt::Display for Error { } } - impl error::Error for Error { fn cause(&self) -> Option<&dyn error::Error> { match *self { @@ -881,10 +596,7 @@ impl AllowedHeaders { } } -/// Response generator and [Fairing](https://rocket.rs/guide/fairings/) for CORS -/// -/// This struct can be as Fairing or in an ad-hoc manner to generate CORS response. See the -/// documentation at the [crate root](index.html) for usage information. +/// Configuration options for CORS request handling. /// /// You create a new copy of this struct by defining the configurations in the fields below. /// This struct can also be deserialized by serde with the `serialization` feature which is @@ -893,6 +605,8 @@ impl AllowedHeaders { /// [`Default`](https://doc.rust-lang.org/std/default/trait.Default.html) is implemented for this /// struct. The default for each field is described in the docuementation for the field. /// +/// Before you can use this with Rocket, you will need to call the [`CorsOptions::to_cors`] method. +/// /// # Examples /// /// You can run an example from the repository to demonstrate the JSON serialization with @@ -900,7 +614,7 @@ impl AllowedHeaders { /// /// ## Pure default /// ```rust -/// let default = rocket_cors::Cors::default(); +/// let default = rocket_cors::CorsOptions::default(); /// ``` /// /// ## JSON Examples @@ -959,7 +673,7 @@ impl AllowedHeaders { /// ``` #[derive(Eq, PartialEq, Clone, Debug)] #[cfg_attr(feature = "serialization", derive(Serialize, Deserialize))] -pub struct Cors { +pub struct CorsOptions { /// Origins that are allowed to make requests. /// Will be verified against the `Origin` request header. /// @@ -986,7 +700,7 @@ pub struct Cors { /// Defaults to `[GET, HEAD, POST, OPTIONS, PUT, PATCH, DELETE]` #[cfg_attr( feature = "serialization", - serde(default = "Cors::default_allowed_methods") + serde(default = "CorsOptions::default_allowed_methods") )] pub allowed_methods: AllowedMethods, /// The list of header field names which can be used when this resource is accessed by allowed @@ -1048,7 +762,7 @@ pub struct Cors { /// Defaults to "/cors" #[cfg_attr( feature = "serialization", - serde(default = "Cors::default_fairing_route_base") + serde(default = "CorsOptions::default_fairing_route_base") )] pub fairing_route_base: String, /// When used as Fairing, Cors will need to redirect failed CORS checks to a custom route @@ -1058,12 +772,12 @@ pub struct Cors { /// Defaults to 0 #[cfg_attr( feature = "serialization", - serde(default = "Cors::default_fairing_route_rank") + serde(default = "CorsOptions::default_fairing_route_rank") )] pub fairing_route_rank: isize, } -impl Default for Cors { +impl Default for CorsOptions { fn default() -> Self { Self { allowed_origins: Default::default(), @@ -1079,7 +793,7 @@ impl Default for Cors { } } -impl Cors { +impl CorsOptions { fn default_allowed_methods() -> HashSet { use rocket::http::Method; @@ -1116,6 +830,36 @@ impl Cors { Ok(()) } + /// Creates a [`Cors`] struct that can be used to respond to requests or as a Rocket Fairing + pub fn to_cors(&self) -> Result { + Cors::from_options(self) + } +} + +/// Response generator and [Fairing](https://rocket.rs/guide/fairings/) for CORS +/// +/// This struct can be as Fairing or in an ad-hoc manner to generate CORS response. See the +/// documentation at the [crate root](index.html) for usage information. +/// +/// This struct can be created by using [`CorsOptions::to_cors`] or [`Cors::from_options`]. +#[derive(Clone, Debug)] +pub struct Cors(CorsOptions); + +impl Deref for Cors { + type Target = CorsOptions; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl Cors { + /// Create a `Cors` struct from a [`CorsOptions`] + pub fn from_options(options: &CorsOptions) -> Result { + options.validate()?; + Ok(Cors(options.clone())) + } + /// Manually respond to a request with CORS checks and headers using an Owned `Cors`. /// /// Use this variant when your `Cors` struct will not live at least as long as the whole `'r` @@ -1131,7 +875,6 @@ impl Cors { F: FnOnce(Guard<'r>) -> R + 'r, R: response::Responder<'r>, { - self.validate()?; Ok(ManualResponder::new(Cow::Owned(self), handler)) } @@ -1155,7 +898,6 @@ impl Cors { F: FnOnce(Guard<'r>) -> R + 'r, R: response::Responder<'r>, { - self.validate()?; Ok(ManualResponder::new(Cow::Borrowed(self), handler)) } } @@ -1877,11 +1619,11 @@ mod tests { use super::*; use crate::http::Method; - fn make_cors_options() -> Cors { + fn make_cors_options() -> CorsOptions { let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.acme.com"]); assert!(failed_origins.is_empty()); - Cors { + CorsOptions { allowed_origins: allowed_origins, allowed_methods: vec![http::Method::Get] .into_iter() @@ -1897,7 +1639,7 @@ mod tests { } } - fn make_invalid_options() -> Cors { + fn make_invalid_options() -> CorsOptions { let mut cors = make_cors_options(); cors.allow_credentials = true; cors.allowed_origins = AllOrSome::All; @@ -1930,8 +1672,8 @@ mod tests { #[cfg(feature = "serialization")] #[test] fn cors_default_deserialization_is_correct() { - let deserialized: Cors = serde_json::from_str("{}").expect("To not fail"); - assert_eq!(deserialized, Cors::default()); + let deserialized: CorsOptions = serde_json::from_str("{}").expect("To not fail"); + assert_eq!(deserialized, CorsOptions::default()); } // The following tests check validation @@ -2251,7 +1993,7 @@ mod tests { #[test] fn preflight_validated_correctly() { - let options = make_cors_options(); + let cors = make_cors_options().to_cors().expect("To not fail"); let client = make_client(); let origin_header = @@ -2271,7 +2013,7 @@ mod tests { .header(method_header) .header(request_headers); - let result = validate(&options, request.inner()).expect("to not fail"); + let result = validate(&cors, request.inner()).expect("to not fail"); let expected_result = ValidationResult::Preflight { origin: FromStr::from_str("https://www.acme.com").unwrap(), // Checks that only a subset of allowed headers are returned @@ -2282,36 +2024,11 @@ mod tests { assert_eq!(expected_result, result); } - #[test] - #[should_panic(expected = "CredentialsWithWildcardOrigin")] - fn preflight_validation_errors_on_invalid_options() { - let options = make_invalid_options(); - let client = make_client(); - - let origin_header = - Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap()); - let method_header = Header::from(hyper::header::AccessControlRequestMethod( - hyper::method::Method::Get, - )); - let request_headers = - hyper::header::AccessControlRequestHeaders(vec![ - FromStr::from_str("Authorization").unwrap() - ]); - let request_headers = Header::from(request_headers); - - let request = client - .options("/") - .header(origin_header) - .header(method_header) - .header(request_headers); - - let _ = validate(&options, request.inner()).unwrap(); - } - #[test] fn preflight_validation_allows_all_origin() { let mut options = make_cors_options(); options.allowed_origins = AllOrSome::All; + let cors = options.to_cors().expect("To not fail"); let client = make_client(); let origin_header = @@ -2331,7 +2048,7 @@ mod tests { .header(method_header) .header(request_headers); - let result = validate(&options, request.inner()).expect("to not fail"); + let result = validate(&cors, request.inner()).expect("to not fail"); let expected_result = ValidationResult::Preflight { origin: FromStr::from_str("https://www.example.com").unwrap(), headers: Some(FromStr::from_str("Authorization").unwrap()), @@ -2343,7 +2060,7 @@ mod tests { #[test] #[should_panic(expected = "OriginNotAllowed")] fn preflight_validation_errors_on_invalid_origin() { - let options = make_cors_options(); + let cors = make_cors_options().to_cors().expect("To not fail"); let client = make_client(); let origin_header = @@ -2363,13 +2080,13 @@ mod tests { .header(method_header) .header(request_headers); - let _ = validate(&options, request.inner()).unwrap(); + let _ = validate(&cors, request.inner()).unwrap(); } #[test] #[should_panic(expected = "MissingRequestMethod")] fn preflight_validation_errors_on_missing_request_method() { - let options = make_cors_options(); + let cors = make_cors_options().to_cors().expect("To not fail"); let client = make_client(); let origin_header = @@ -2385,13 +2102,13 @@ mod tests { .header(origin_header) .header(request_headers); - let _ = validate(&options, request.inner()).unwrap(); + let _ = validate(&cors, request.inner()).unwrap(); } #[test] #[should_panic(expected = "MethodNotAllowed")] fn preflight_validation_errors_on_disallowed_method() { - let options = make_cors_options(); + let cors = make_cors_options().to_cors().expect("To not fail"); let client = make_client(); let origin_header = @@ -2411,13 +2128,13 @@ mod tests { .header(method_header) .header(request_headers); - let _ = validate(&options, request.inner()).unwrap(); + let _ = validate(&cors, request.inner()).unwrap(); } #[test] #[should_panic(expected = "HeadersNotAllowed")] fn preflight_validation_errors_on_disallowed_headers() { - let options = make_cors_options(); + let cors = make_cors_options().to_cors().expect("To not fail"); let client = make_client(); let origin_header = @@ -2437,19 +2154,19 @@ mod tests { .header(method_header) .header(request_headers); - let _ = validate(&options, request.inner()).unwrap(); + let _ = validate(&cors, request.inner()).unwrap(); } #[test] fn actual_request_validated_correctly() { - let options = make_cors_options(); + let cors = make_cors_options().to_cors().expect("To not fail"); let client = make_client(); let origin_header = Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap()); let request = client.get("/").header(origin_header); - let result = validate(&options, request.inner()).expect("to not fail"); + let result = validate(&cors, request.inner()).expect("to not fail"); let expected_result = ValidationResult::Request { origin: FromStr::from_str("https://www.acme.com").unwrap(), }; @@ -2457,30 +2174,18 @@ mod tests { assert_eq!(expected_result, result); } - #[test] - #[should_panic(expected = "CredentialsWithWildcardOrigin")] - fn actual_request_validation_errors_on_invalid_options() { - let options = make_invalid_options(); - let client = make_client(); - - let origin_header = - Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap()); - let request = client.get("/").header(origin_header); - - let _ = validate(&options, request.inner()).unwrap(); - } - #[test] fn actual_request_validation_allows_all_origin() { let mut options = make_cors_options(); options.allowed_origins = AllOrSome::All; + let cors = options.to_cors().expect("To not fail"); let client = make_client(); let origin_header = Header::from(hyper::header::Origin::from_str("https://www.example.com").unwrap()); let request = client.get("/").header(origin_header); - let result = validate(&options, request.inner()).expect("to not fail"); + let result = validate(&cors, request.inner()).expect("to not fail"); let expected_result = ValidationResult::Request { origin: FromStr::from_str("https://www.example.com").unwrap(), }; @@ -2491,23 +2196,23 @@ mod tests { #[test] #[should_panic(expected = "OriginNotAllowed")] fn actual_request_validation_errors_on_incorrect_origin() { - let options = make_cors_options(); + let cors = make_cors_options().to_cors().expect("To not fail"); let client = make_client(); let origin_header = Header::from(hyper::header::Origin::from_str("https://www.example.com").unwrap()); let request = client.get("/").header(origin_header); - let _ = validate(&options, request.inner()).unwrap(); + let _ = validate(&cors, request.inner()).unwrap(); } #[test] fn non_cors_request_return_empty_response() { - let options = make_cors_options(); + let cors = make_cors_options().to_cors().expect("To not fail"); let client = make_client(); let request = client.options("/"); - let response = validate_and_build(&options, request.inner()).expect("to not fail"); + let response = validate_and_build(&cors, request.inner()).expect("to not fail"); let expected_response = Response::new(); assert_eq!(expected_response, response); } @@ -2515,6 +2220,7 @@ mod tests { #[test] fn preflight_validated_and_built_correctly() { let options = make_cors_options(); + let cors = options.to_cors().expect("To not fail"); let client = make_client(); let origin_header = @@ -2534,7 +2240,7 @@ mod tests { .header(method_header) .header(request_headers); - let response = validate_and_build(&options, request.inner()).expect("to not fail"); + let response = validate_and_build(&cors, request.inner()).expect("to not fail"); let expected_response = Response::new() .origin(&FromStr::from_str("https://www.acme.com/").unwrap(), false) @@ -2553,6 +2259,7 @@ mod tests { let mut options = make_cors_options(); options.allowed_origins = AllOrSome::All; options.send_wildcard = false; + let cors = options.to_cors().expect("To not fail"); let client = make_client(); @@ -2573,7 +2280,7 @@ mod tests { .header(method_header) .header(request_headers); - let response = validate_and_build(&options, request.inner()).expect("to not fail"); + let response = validate_and_build(&cors, request.inner()).expect("to not fail"); let expected_response = Response::new() .origin(&FromStr::from_str("https://www.acme.com/").unwrap(), true) @@ -2592,6 +2299,7 @@ mod tests { options.allowed_origins = AllOrSome::All; options.send_wildcard = true; options.allow_credentials = false; + let cors = options.to_cors().expect("To not fail"); let client = make_client(); @@ -2612,7 +2320,7 @@ mod tests { .header(method_header) .header(request_headers); - let response = validate_and_build(&options, request.inner()).expect("to not fail"); + let response = validate_and_build(&cors, request.inner()).expect("to not fail"); let expected_response = Response::new() .any() @@ -2627,13 +2335,14 @@ mod tests { #[test] fn actual_request_validated_and_built_correctly() { let options = make_cors_options(); + let cors = options.to_cors().expect("To not fail"); let client = make_client(); let origin_header = Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap()); let request = client.get("/").header(origin_header); - let response = validate_and_build(&options, request.inner()).expect("to not fail"); + let response = validate_and_build(&cors, request.inner()).expect("to not fail"); let expected_response = Response::new() .origin(&FromStr::from_str("https://www.acme.com/").unwrap(), false) .credentials(options.allow_credentials) @@ -2648,6 +2357,7 @@ mod tests { options.allowed_origins = AllOrSome::All; options.send_wildcard = false; options.allow_credentials = false; + let cors = options.to_cors().expect("To not fail"); let client = make_client(); @@ -2655,7 +2365,7 @@ mod tests { Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap()); let request = client.get("/").header(origin_header); - let response = validate_and_build(&options, request.inner()).expect("to not fail"); + let response = validate_and_build(&cors, request.inner()).expect("to not fail"); let expected_response = Response::new() .origin(&FromStr::from_str("https://www.acme.com/").unwrap(), true) .credentials(options.allow_credentials) @@ -2670,6 +2380,7 @@ mod tests { options.allowed_origins = AllOrSome::All; options.send_wildcard = true; options.allow_credentials = false; + let cors = options.to_cors().expect("To not fail"); let client = make_client(); @@ -2677,7 +2388,7 @@ mod tests { Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap()); let request = client.get("/").header(origin_header); - let response = validate_and_build(&options, request.inner()).expect("to not fail"); + let response = validate_and_build(&cors, request.inner()).expect("to not fail"); let expected_response = Response::new() .any() .credentials(options.allow_credentials) diff --git a/tests/fairing.rs b/tests/fairing.rs index 7d45232..5e46436 100644 --- a/tests/fairing.rs +++ b/tests/fairing.rs @@ -21,23 +21,25 @@ fn panicking_route() { panic!("This route will panic"); } -fn make_cors_options() -> Cors { +fn make_cors() -> Cors { let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.acme.com"]); assert!(failed_origins.is_empty()); - Cors { + CorsOptions { allowed_origins: allowed_origins, allowed_methods: vec![Method::Get].into_iter().map(From::from).collect(), allowed_headers: AllowedHeaders::some(&["Authorization", "Accept"]), allow_credentials: true, ..Default::default() } + .to_cors() + .expect("To not fail") } fn rocket() -> rocket::Rocket { rocket::ignite() .mount("/", routes![cors, panicking_route]) - .attach(make_cors_options()) + .attach(make_cors()) } #[test] diff --git a/tests/guard.rs b/tests/guard.rs index 13e7a42..cee87d3 100644 --- a/tests/guard.rs +++ b/tests/guard.rs @@ -59,17 +59,19 @@ fn state<'r>(cors: cors::Guard<'r>, _state: State<'r, SomeState>) -> cors::Respo cors.responder("hmm") } -fn make_cors_options() -> cors::Cors { +fn make_cors() -> cors::Cors { let (allowed_origins, failed_origins) = cors::AllowedOrigins::some(&["https://www.acme.com"]); assert!(failed_origins.is_empty()); - cors::Cors { + cors::CorsOptions { allowed_origins: allowed_origins, allowed_methods: vec![Method::Get].into_iter().map(From::from).collect(), allowed_headers: cors::AllowedHeaders::some(&["Authorization", "Accept"]), allow_credentials: true, ..Default::default() } + .to_cors() + .expect("To not fail") } fn make_rocket() -> rocket::Rocket { @@ -81,7 +83,7 @@ fn make_rocket() -> rocket::Rocket { ) .mount("/", cors::catch_all_options_routes()) // mount the catch all routes .mount("/", routes![cors_manual, cors_manual_options]) // manual OPTIOONS routes - .manage(make_cors_options()) + .manage(make_cors()) .manage(SomeState) } diff --git a/tests/manual.rs b/tests/manual.rs index 9adcf36..cdbb463 100644 --- a/tests/manual.rs +++ b/tests/manual.rs @@ -31,7 +31,7 @@ fn panicking_route(options: State<'_, Cors>) -> impl Responder<'_> { /// Respond with an owned option instead #[options("/owned")] fn owned_options<'r>() -> impl Responder<'r> { - let borrow = make_different_cors_options(); + let borrow = make_different_cors_options().to_cors()?; borrow.respond_owned(|guard| guard.responder("Manual CORS Preflight")) } @@ -39,7 +39,7 @@ fn owned_options<'r>() -> impl Responder<'r> { /// Respond with an owned option instead #[get("/owned")] fn owned<'r>() -> impl Responder<'r> { - let borrow = make_different_cors_options(); + let borrow = make_different_cors_options().to_cors()?; borrow.respond_owned(|guard| guard.responder("Hello CORS Owned")) } @@ -65,11 +65,11 @@ fn borrow<'r>(options: State<'r, Cors>, test_state: State<'r, TestState>) -> imp }) } -fn make_cors_options() -> Cors { +fn make_cors_options() -> CorsOptions { let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.acme.com"]); assert!(failed_origins.is_empty()); - Cors { + CorsOptions { allowed_origins: allowed_origins, allowed_methods: vec![Method::Get].into_iter().map(From::from).collect(), allowed_headers: AllowedHeaders::some(&["Authorization", "Accept"]), @@ -78,11 +78,11 @@ fn make_cors_options() -> Cors { } } -fn make_different_cors_options() -> Cors { +fn make_different_cors_options() -> CorsOptions { let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.example.com"]); assert!(failed_origins.is_empty()); - Cors { + CorsOptions { allowed_origins: allowed_origins, allowed_methods: vec![Method::Get].into_iter().map(From::from).collect(), allowed_headers: AllowedHeaders::some(&["Authorization", "Accept"]), @@ -96,7 +96,7 @@ fn rocket() -> rocket::Rocket { .mount("/", routes![cors, panicking_route]) .mount("/", routes![owned, owned_options]) .mount("/", catch_all_options_routes()) // mount the catch all routes - .manage(make_cors_options()) + .manage(make_cors_options().to_cors().expect("Not to fail")) } #[test] diff --git a/tests/mix.rs b/tests/mix.rs index cbe600a..4f46663 100644 --- a/tests/mix.rs +++ b/tests/mix.rs @@ -14,7 +14,7 @@ use rocket::http::{Header, Method, Status}; use rocket::local::Client; use rocket::response::Responder; -use rocket_cors::{AllowedHeaders, AllowedOrigins, Cors, Guard}; +use rocket_cors::{AllowedHeaders, AllowedOrigins, CorsOptions, Guard}; /// The "usual" app route #[get("/")] @@ -25,8 +25,8 @@ fn app(cors: Guard<'_>) -> rocket_cors::Responder<'_, &str> { /// The special "ping" route #[get("/ping")] fn ping<'r>() -> impl Responder<'r> { - let options = cors_options_all(); - options.respond_owned(|guard| guard.responder("Pong!")) + let cors = cors_options_all().to_cors()?; + cors.respond_owned(|guard| guard.responder("Pong!")) } /// You need to define an OPTIONS route for preflight checks if you want to use `Cors` struct @@ -34,17 +34,17 @@ fn ping<'r>() -> impl Responder<'r> { /// These routes can just return the unit type `()` #[options("/ping")] fn ping_options<'r>() -> impl Responder<'r> { - let options = cors_options_all(); - options.respond_owned(|guard| guard.responder(())) + let cors = cors_options_all().to_cors()?; + cors.respond_owned(|guard| guard.responder(())) } /// Returns the "application wide" Cors struct -fn cors_options() -> Cors { +fn cors_options() -> CorsOptions { let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.acme.com"]); assert!(failed_origins.is_empty()); // You can also deserialize this - rocket_cors::Cors { + rocket_cors::CorsOptions { allowed_origins: allowed_origins, allowed_methods: vec![Method::Get].into_iter().map(From::from).collect(), allowed_headers: AllowedHeaders::some(&["Authorization", "Accept"]), @@ -57,7 +57,7 @@ fn cors_options() -> Cors { /// /// Note: In your real application, you might want to use something like `lazy_static` to generate /// a `&'static` reference to this instead of creating a new struct on every request. -fn cors_options_all() -> Cors { +fn cors_options_all() -> CorsOptions { // You can also deserialize this Default::default() } @@ -66,7 +66,7 @@ fn rocket() -> rocket::Rocket { rocket::ignite() .mount("/", routes![app, ping, ping_options,]) .mount("/", rocket_cors::catch_all_options_routes()) // mount the catch all routes - .manage(cors_options()) + .manage(cors_options().to_cors().expect("Not to fail")) } #[test]