Create `CorsOptions` (#57)

* Create `CorsOptions`

- `Cors` will be, in the future, an "optimised" or "compiled" version of
`CorsOptions`
- For now, `Cors` simply clones `CorsOptions` and `Deref`s to
`CorsOptions`.

* Update examples

* Remove usage of `Self` struct constructors
This commit is contained in:
Yong Wen Chua 2018-12-19 01:29:26 +01:00 committed by GitHub
parent aa15af333f
commit c86bb44529
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 179 additions and 464 deletions

View File

@ -4,28 +4,31 @@ use rocket_cors;
use rocket::http::Method; use rocket::http::Method;
use rocket::{get, routes}; use rocket::{get, routes};
use rocket_cors::{AllowedHeaders, AllowedOrigins}; use rocket_cors::{AllowedHeaders, AllowedOrigins, Error};
#[get("/")] #[get("/")]
fn cors<'a>() -> &'a str { fn cors<'a>() -> &'a str {
"Hello CORS" "Hello CORS"
} }
fn main() { fn main() -> Result<(), Error> {
let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.acme.com"]); let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.acme.com"]);
assert!(failed_origins.is_empty()); assert!(failed_origins.is_empty());
// You can also deserialize this // You can also deserialize this
let options = rocket_cors::Cors { let cors = rocket_cors::CorsOptions {
allowed_origins: allowed_origins, allowed_origins: allowed_origins,
allowed_methods: vec![Method::Get].into_iter().map(From::from).collect(), allowed_methods: vec![Method::Get].into_iter().map(From::from).collect(),
allowed_headers: AllowedHeaders::some(&["Authorization", "Accept"]), allowed_headers: AllowedHeaders::some(&["Authorization", "Accept"]),
allow_credentials: true, allow_credentials: true,
..Default::default() ..Default::default()
}; }
.to_cors()?;
rocket::ignite() rocket::ignite()
.mount("/", routes![cors]) .mount("/", routes![cors])
.attach(options) .attach(cors)
.launch(); .launch();
Ok(())
} }

View File

@ -7,7 +7,7 @@ use std::io::Cursor;
use rocket::http::Method; use rocket::http::Method;
use rocket::Response; use rocket::Response;
use rocket::{get, options, routes}; use rocket::{get, options, routes};
use rocket_cors::{AllowedHeaders, AllowedOrigins, Guard, Responder}; use rocket_cors::{AllowedHeaders, AllowedOrigins, Error, Guard, Responder};
/// Using a `Responder` -- the usual way you would use this /// Using a `Responder` -- the usual way you would use this
#[get("/")] #[get("/")]
@ -35,18 +35,19 @@ fn manual(cors: Guard<'_>) -> Responder<'_, &str> {
cors.responder("Manual OPTIONS preflight handling") cors.responder("Manual OPTIONS preflight handling")
} }
fn main() { fn main() -> Result<(), Error> {
let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.acme.com"]); let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.acme.com"]);
assert!(failed_origins.is_empty()); assert!(failed_origins.is_empty());
// You can also deserialize this // You can also deserialize this
let options = rocket_cors::Cors { let cors = rocket_cors::CorsOptions {
allowed_origins: allowed_origins, allowed_origins: allowed_origins,
allowed_methods: vec![Method::Get].into_iter().map(From::from).collect(), allowed_methods: vec![Method::Get].into_iter().map(From::from).collect(),
allowed_headers: AllowedHeaders::some(&["Authorization", "Accept"]), allowed_headers: AllowedHeaders::some(&["Authorization", "Accept"]),
allow_credentials: true, allow_credentials: true,
..Default::default() ..Default::default()
}; }
.to_cors()?;
rocket::ignite() rocket::ignite()
.mount("/", routes![responder, response]) .mount("/", routes![responder, response])
@ -54,6 +55,8 @@ fn main() {
.mount("/", rocket_cors::catch_all_options_routes()) .mount("/", rocket_cors::catch_all_options_routes())
// You can also manually mount an OPTIONS route that will be used instead // You can also manually mount an OPTIONS route that will be used instead
.mount("/", routes![manual, manual_options]) .mount("/", routes![manual, manual_options])
.manage(options) .manage(cors)
.launch(); .launch();
Ok(())
} }

View File

@ -6,17 +6,17 @@
use rocket_cors as cors; use rocket_cors as cors;
use serde_json; use serde_json;
use crate::cors::{AllowedHeaders, AllowedOrigins, Cors}; use crate::cors::{AllowedHeaders, AllowedOrigins, CorsOptions};
use rocket::http::Method; use rocket::http::Method;
fn main() { fn main() {
// The default demonstrates the "All" serialization of several of the settings // The default demonstrates the "All" serialization of several of the settings
let default: Cors = Default::default(); let default: CorsOptions = Default::default();
let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.acme.com"]); let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.acme.com"]);
assert!(failed_origins.is_empty()); assert!(failed_origins.is_empty());
let options = cors::Cors { let options = cors::CorsOptions {
allowed_origins: allowed_origins, allowed_origins: allowed_origins,
allowed_methods: vec![Method::Get, Method::Post, Method::Delete] allowed_methods: vec![Method::Get, Method::Post, Method::Delete]
.into_iter() .into_iter()

View File

@ -7,9 +7,13 @@ use std::io::Cursor;
use rocket::http::Method; use rocket::http::Method;
use rocket::response::Responder; use rocket::response::Responder;
use rocket::{get, options, routes, Response, State}; use rocket::{get, options, routes, Response, State};
use rocket_cors::{AllowedHeaders, AllowedOrigins, Cors}; use rocket_cors::{AllowedHeaders, AllowedOrigins, Cors, CorsOptions};
/// Using a borrowed Cors /// Using a borrowed Cors
///
/// You might want to borrow the `Cors` struct from Rocket's state, for example. Unless you have
/// special handling, you might want to use the Guard method instead which has less hassle.
///
/// Note that the `'r` lifetime annotation is not requred here because `State` borrows with lifetime /// Note that the `'r` lifetime annotation is not requred here because `State` borrows with lifetime
/// `'r` and so does `Responder`! /// `'r` and so does `Responder`!
#[get("/")] #[get("/")]
@ -34,9 +38,13 @@ fn response(options: State<'_, Cors>) -> impl Responder<'_> {
/// Create and use an ad-hoc Cors /// Create and use an ad-hoc Cors
/// Note that the `'r` lifetime is needed because the compiler cannot elide anything. /// Note that the `'r` lifetime is needed because the compiler cannot elide anything.
///
/// This is the most likely scenario when you want to have manual CORS validation. You can use this
/// when the settings you want to use for a route is not the same as the rest of the application
/// (which you might have put in Rocket's state).
#[get("/owned")] #[get("/owned")]
fn owned<'r>() -> impl Responder<'r> { fn owned<'r>() -> impl Responder<'r> {
let options = cors_options(); let options = cors_options().to_cors()?;
options.respond_owned(|guard| guard.responder("Hello CORS")) options.respond_owned(|guard| guard.responder("Hello CORS"))
} }
@ -46,16 +54,16 @@ fn owned<'r>() -> impl Responder<'r> {
/// Note that the `'r` lifetime is needed because the compiler cannot elide anything. /// Note that the `'r` lifetime is needed because the compiler cannot elide anything.
#[options("/owned")] #[options("/owned")]
fn owned_options<'r>() -> impl Responder<'r> { fn owned_options<'r>() -> impl Responder<'r> {
let options = cors_options(); let options = cors_options().to_cors()?;
options.respond_owned(|guard| guard.responder(())) options.respond_owned(|guard| guard.responder(()))
} }
fn cors_options() -> Cors { fn cors_options() -> CorsOptions {
let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.acme.com"]); let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.acme.com"]);
assert!(failed_origins.is_empty()); assert!(failed_origins.is_empty());
// You can also deserialize this // You can also deserialize this
rocket_cors::Cors { rocket_cors::CorsOptions {
allowed_origins: allowed_origins, allowed_origins: allowed_origins,
allowed_methods: vec![Method::Get].into_iter().map(From::from).collect(), allowed_methods: vec![Method::Get].into_iter().map(From::from).collect(),
allowed_headers: AllowedHeaders::some(&["Authorization", "Accept"]), allowed_headers: AllowedHeaders::some(&["Authorization", "Accept"]),
@ -68,6 +76,6 @@ fn main() {
rocket::ignite() rocket::ignite()
.mount("/", routes![borrowed, response, owned, owned_options,]) .mount("/", routes![borrowed, response, owned, owned_options,])
.mount("/", rocket_cors::catch_all_options_routes()) // mount the catch all routes .mount("/", rocket_cors::catch_all_options_routes()) // mount the catch all routes
.manage(cors_options()) .manage(cors_options().to_cors().expect("To not fail"))
.launch(); .launch();
} }

View File

@ -10,7 +10,7 @@ use rocket_cors;
use rocket::http::Method; use rocket::http::Method;
use rocket::response::Responder; use rocket::response::Responder;
use rocket::{get, options, routes}; use rocket::{get, options, routes};
use rocket_cors::{AllowedHeaders, AllowedOrigins, Cors, Guard}; use rocket_cors::{AllowedHeaders, AllowedOrigins, CorsOptions, Guard};
/// The "usual" app route /// The "usual" app route
#[get("/")] #[get("/")]
@ -21,8 +21,8 @@ fn app(cors: Guard<'_>) -> rocket_cors::Responder<'_, &str> {
/// The special "ping" route /// The special "ping" route
#[get("/ping")] #[get("/ping")]
fn ping<'r>() -> impl Responder<'r> { fn ping<'r>() -> impl Responder<'r> {
let options = cors_options_all(); let cors = cors_options_all().to_cors()?;
options.respond_owned(|guard| guard.responder("Pong!")) cors.respond_owned(|guard| guard.responder("Pong!"))
} }
/// You need to define an OPTIONS route for preflight checks if you want to use `Cors` struct /// You need to define an OPTIONS route for preflight checks if you want to use `Cors` struct
@ -30,17 +30,17 @@ fn ping<'r>() -> impl Responder<'r> {
/// These routes can just return the unit type `()` /// These routes can just return the unit type `()`
#[options("/ping")] #[options("/ping")]
fn ping_options<'r>() -> impl Responder<'r> { fn ping_options<'r>() -> impl Responder<'r> {
let options = cors_options_all(); let cors = cors_options_all().to_cors()?;
options.respond_owned(|guard| guard.responder(())) cors.respond_owned(|guard| guard.responder(()))
} }
/// Returns the "application wide" Cors struct /// Returns the "application wide" Cors struct
fn cors_options() -> Cors { fn cors_options() -> CorsOptions {
let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.acme.com"]); let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.acme.com"]);
assert!(failed_origins.is_empty()); assert!(failed_origins.is_empty());
// You can also deserialize this // You can also deserialize this
rocket_cors::Cors { rocket_cors::CorsOptions {
allowed_origins: allowed_origins, allowed_origins: allowed_origins,
allowed_methods: vec![Method::Get].into_iter().map(From::from).collect(), allowed_methods: vec![Method::Get].into_iter().map(From::from).collect(),
allowed_headers: AllowedHeaders::some(&["Authorization", "Accept"]), allowed_headers: AllowedHeaders::some(&["Authorization", "Accept"]),
@ -53,7 +53,7 @@ fn cors_options() -> Cors {
/// ///
/// Note: In your real application, you might want to use something like `lazy_static` to generate /// Note: In your real application, you might want to use something like `lazy_static` to generate
/// a `&'static` reference to this instead of creating a new struct on every request. /// a `&'static` reference to this instead of creating a new struct on every request.
fn cors_options_all() -> Cors { fn cors_options_all() -> CorsOptions {
// You can also deserialize this // You can also deserialize this
Default::default() Default::default()
} }
@ -62,6 +62,6 @@ fn main() {
rocket::ignite() rocket::ignite()
.mount("/", routes![app, ping, ping_options,]) .mount("/", routes![app, ping, ping_options,])
.mount("/", rocket_cors::catch_all_options_routes()) // mount the catch all routes .mount("/", rocket_cors::catch_all_options_routes()) // mount the catch all routes
.manage(cors_options()) .manage(cors_options().to_cors().expect("To not fail"))
.launch(); .launch();
} }

View File

@ -100,16 +100,10 @@ impl rocket::fairing::Fairing for Cors {
} }
fn on_attach(&self, rocket: rocket::Rocket) -> Result<rocket::Rocket, rocket::Rocket> { fn on_attach(&self, rocket: rocket::Rocket) -> Result<rocket::Rocket, rocket::Rocket> {
match self.validate() { Ok(rocket.mount(
Ok(()) => Ok(rocket.mount( &self.fairing_route_base,
&self.fairing_route_base, vec![fairing_route(self.fairing_route_rank)],
vec![fairing_route(self.fairing_route_rank)], ))
)),
Err(e) => {
error_!("Error attaching CORS fairing: {}", e);
Err(rocket)
}
}
} }
fn on_request(&self, request: &mut Request<'_>, _: &rocket::Data) { fn on_request(&self, request: &mut Request<'_>, _: &rocket::Data) {
@ -141,7 +135,7 @@ mod tests {
use rocket::local::Client; use rocket::local::Client;
use rocket::Rocket; use rocket::Rocket;
use crate::{AllOrSome, AllowedHeaders, AllowedOrigins, Cors}; use crate::{AllowedHeaders, AllowedOrigins, Cors, CorsOptions};
const CORS_ROOT: &'static str = "/my_cors"; const CORS_ROOT: &'static str = "/my_cors";
@ -149,7 +143,7 @@ mod tests {
let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.acme.com"]); let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.acme.com"]);
assert!(failed_origins.is_empty()); assert!(failed_origins.is_empty());
Cors { CorsOptions {
allowed_origins: allowed_origins, allowed_origins: allowed_origins,
allowed_methods: vec![Method::Get].into_iter().map(From::from).collect(), allowed_methods: vec![Method::Get].into_iter().map(From::from).collect(),
allowed_headers: AllowedHeaders::some(&["Authorization", "Accept"]), allowed_headers: AllowedHeaders::some(&["Authorization", "Accept"]),
@ -158,6 +152,8 @@ mod tests {
..Default::default() ..Default::default()
} }
.to_cors()
.expect("Not to fail")
} }
fn rocket(fairing: Cors) -> Rocket { fn rocket(fairing: Cors) -> Rocket {
@ -191,15 +187,5 @@ mod tests {
assert!(error_route.is_some()); assert!(error_route.is_some());
} }
#[test]
#[should_panic(expected = "launch fairing failure")]
fn options_are_validated_on_attach() {
let mut options = make_cors_options();
options.allowed_origins = AllOrSome::All;
options.send_wildcard = true;
let _ = rocket(options).launch();
}
// Rest of the things can only be tested in integration tests // Rest of the things can only be tested in integration tests
} }

View File

@ -42,8 +42,8 @@ rocket_cors = { git = "https://github.com/lawliet89/rocket_cors", branch = "mast
## Features ## Features
By default, a `serialization` feature is enabled in this crate that allows you to (de)serialize By default, a `serialization` feature is enabled in this crate that allows you to (de)serialize
the `Cors` struct that is described below. If you would like to disable this, simply change the [`CorsOptions`] struct that is described below. If you would like to disable this, simply
your `Cargo.toml` to: change your `Cargo.toml` to:
```toml ```toml
rocket_cors = { version = "0.4.0", default-features = false } rocket_cors = { version = "0.4.0", default-features = false }
@ -51,8 +51,9 @@ rocket_cors = { version = "0.4.0", default-features = false }
## Usage ## Usage
Before you can add CORS responses to your application, you need to create a `Cors` struct that Before you can add CORS responses to your application, you need to create a [`CorsOptions`]
will hold the settings. struct that will hold the settings. Then, you need to create a [`Cors`] struct using
[`CorsOptions::to_cors`] which will validate and optimise the settings for Rocket to use.
Each of the examples can be run off the repository via `cargo run --example xxx` where `xxx` is Each of the examples can be run off the repository via `cargo run --example xxx` where `xxx` is
@ -60,14 +61,19 @@ Each of the examples can be run off the repository via `cargo run --example xxx`
- `guard` - `guard`
- `manual` - `manual`
### `Cors` Struct ### `CorsOptions` Struct
The [`Cors` struct](Cors) contains the settings for CORS requests to be validated The [`CorsOptiopns`] struct contains the settings for CORS requests to be validated
and for responses to be generated. Defaults are defined for every field in the struct, and and for responses to be generated. Defaults are defined for every field in the struct, and
are documented on the [`Cors` struct](Cors) page. You can also deserialize are documented on the [`CorsOptiopns`] page. You can also deserialize
the struct from some format like JSON, YAML or TOML when the default `serialization` feature the struct from some format like JSON, YAML or TOML when the default `serialization` feature
is enabled. is enabled.
### `Cors` Struct
The [`Cors`] struct is what will be used with Rocket. After creating or deserializing a
[`CorsOptions`] struct, use [`CorsOptions::to_cors`] to create a [`Cors`] struct.
### Three modes of operation ### Three modes of operation
You can add CORS to your routes via one of three ways, in descending order of ease and in You can add CORS to your routes via one of three ways, in descending order of ease and in
@ -100,43 +106,11 @@ routes for your application, and the checks are done transparently.
However, you can only have one set of settings that must apply to all routes. You cannot opt However, you can only have one set of settings that must apply to all routes. You cannot opt
any route out of CORS checks. any route out of CORS checks.
To use this, simply create a [`Cors` struct](Cors) and then To use this, simply create a [`Cors`] from [`CorsOptions::to_cors`] and then
[`attach`](https://api.rocket.rs/rocket/struct.Rocket.html#method.attach) it to Rocket. [`attach`](https://api.rocket.rs/rocket/struct.Rocket.html#method.attach) it to Rocket.
```rust,no_run Refer to the [example](https://github.com/lawliet89/rocket_cors/blob/master/examples/fairing.rs).
#![feature(proc_macro_hygiene, decl_macro)]
extern crate rocket;
extern crate rocket_cors;
use rocket::{get, routes};
use rocket::http::Method;
use rocket_cors::{AllowedOrigins, AllowedHeaders};
#[get("/")]
fn cors<'a>() -> &'a str {
"Hello CORS"
}
fn main() {
let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.acme.com"]);
assert!(failed_origins.is_empty());
// You can also deserialize this
let options = rocket_cors::Cors {
allowed_origins: allowed_origins,
allowed_methods: vec![Method::Get].into_iter().map(From::from).collect(),
allowed_headers: AllowedHeaders::some(&["Authorization", "Accept"]),
allow_credentials: true,
..Default::default()
};
rocket::ignite()
.mount("/", routes![cors])
.attach(options)
.launch();
}
```
#### Injected Route #### Injected Route
The fairing implementation will inject a route during attachment to Rocket. This route is used The fairing implementation will inject a route during attachment to Rocket. This route is used
@ -154,7 +128,7 @@ The only way to do this is to hijack the request and route it to our own injecte
handle errors. Rocket does not allow Fairings to stop the processing of a route. handle errors. Rocket does not allow Fairings to stop the processing of a route.
You can configure the behaviour of the injected route through a couple of fields in the You can configure the behaviour of the injected route through a couple of fields in the
[`Cors` struct](Cors). [`CorsOptions`].
### Request Guard ### Request Guard
@ -165,7 +139,7 @@ requests. The `OPTIONS` routes are used for CORS preflight checks.
You will have to do the following: You will have to do the following:
- Create a [`Cors` struct](Cors) and during Rocket's ignite, add the struct to - Create a [`Cors`] from [`CorsOptions`] and during Rocket's ignite, add the struct to
Rocket's [managed state](https://rocket.rs/guide/state/#managed-state). Rocket's [managed state](https://rocket.rs/guide/state/#managed-state).
- For all the routes that you want to enforce CORS on, you can mount either some - For all the routes that you want to enforce CORS on, you can mount either some
[catch all route](catch_all_options_routes) or define your own route for the OPTIONS [catch all route](catch_all_options_routes) or define your own route for the OPTIONS
@ -178,69 +152,7 @@ error handling in case of errors.
- In your routes, to add CORS headers to your responses, use the appropriate functions on the - In your routes, to add CORS headers to your responses, use the appropriate functions on the
[`Guard`](Guard) for a `Response` or a `Responder`. [`Guard`](Guard) for a `Response` or a `Responder`.
```rust,no_run Refer to the [example](https://github.com/lawliet89/rocket_cors/blob/master/examples/guard.rs).
#![feature(proc_macro_hygiene, decl_macro)]
extern crate rocket;
extern crate rocket_cors;
use std::io::Cursor;
use rocket::{Response, get, options, routes};
use rocket::http::Method;
use rocket_cors::{Guard, AllowedOrigins, AllowedHeaders, Responder};
/// Using a `Responder` -- the usual way you would use this
#[get("/")]
fn responder(cors: Guard) -> Responder<&str> {
cors.responder("Hello CORS!")
}
/// Using a `Response` instead of a `Responder`. You generally won't have to do this.
#[get("/response")]
fn response(cors: Guard) -> Response {
let mut response = Response::new();
response.set_sized_body(Cursor::new("Hello CORS!"));
cors.response(response)
}
/// Manually mount an OPTIONS route for your own handling
#[options("/manual")]
fn manual_options(cors: Guard) -> Responder<&str> {
cors.responder("Manual OPTIONS preflight handling")
}
/// Manually mount an OPTIONS route for your own handling
#[get("/manual")]
fn manual(cors: Guard) -> Responder<&str> {
cors.responder("Manual OPTIONS preflight handling")
}
fn main() {
let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.acme.com"]);
assert!(failed_origins.is_empty());
// You can also deserialize this
let options = rocket_cors::Cors {
allowed_origins: allowed_origins,
allowed_methods: vec![Method::Get].into_iter().map(From::from).collect(),
allowed_headers: AllowedHeaders::some(&["Authorization", "Accept"]),
allow_credentials: true,
..Default::default()
};
rocket::ignite()
.mount(
"/",
routes![responder, response],
)
// Mount the routes to catch all the OPTIONS pre-flight requests
.mount("/", rocket_cors::catch_all_options_routes())
// You can also manually mount an OPTIONS route that will be used instead
.mount("/", routes![manual, manual_options])
.manage(options)
.launch();
}
```
## Truly Manual ## Truly Manual
@ -261,17 +173,17 @@ has been validated. If validation fails, the closure will not be run. You should
that has any side effects or with an appreciable computation cost inside this handler. that has any side effects or with an appreciable computation cost inside this handler.
### Steps to perform: ### Steps to perform:
- You will first need to have a `Cors` struct ready. This struct can be borrowed with a lifetime - You will first need to have a [`Cors`] struct ready. This struct can be borrowed with a lifetime
at least as long as `'r` which is the lifetime of a Rocket request. `'static` works too. at least as long as `'r` which is the lifetime of a Rocket request. `'static` works too.
In this case, you might as well use the `Guard` method above and place the `Cors` struct in In this case, you might as well use the `Guard` method above and place the `Cors` struct in
Rocket's [state](https://rocket.rs/guide/state/). Rocket's [state](https://rocket.rs/guide/state/).
Alternatively, you can create a `Cors` struct directly in the route. Alternatively, you can create a [`Cors`] struct directly in the route.
- Your routes _might_ need to have a `'r` lifetime and return `impl Responder<'r>`. See below. - Your routes _might_ need to have a `'r` lifetime and return `impl Responder<'r>`. See below.
- Using the `Cors` struct, use either the - Using the [`Cors`] struct, use either the
[`respond_owned`](Cors#method.respond_owned) or [`Cors::respond_owned`] or
[`respond_borrowed`](Cors#method.respond_borrowed) function and pass in a handler [`Cors::respond_borrowed`] function and pass in a handler
that will be executed once CORS validation is successful. that will be executed once CORS validation is successful.
- Your handler will be passed a [`Guard`](Guard) which you will have to use to - Your handler will be passed a [`Guard`] which you will have to use to
add CORS headers into your own response. add CORS headers into your own response.
- You will have to manually define your own `OPTIONS` routes. - You will have to manually define your own `OPTIONS` routes.
@ -291,127 +203,7 @@ required.
You can see examples when the lifetime annotation is required (or not) in `examples/manual.rs`. You can see examples when the lifetime annotation is required (or not) in `examples/manual.rs`.
### Owned example See the [example](https://github.com/lawliet89/rocket_cors/blob/master/examples/manual.rs).
This is the most likely scenario when you want to have manual CORS validation. You can use this
when the settings you want to use for a route is not the same as the rest of the application
(which you might have put in Rocket's state).
```rust,no_run
#![feature(proc_macro_hygiene, decl_macro)]
extern crate rocket;
extern crate rocket_cors;
use rocket::{get, options, routes};
use rocket::http::Method;
use rocket::response::Responder;
use rocket_cors::{Cors, AllowedOrigins, AllowedHeaders};
/// Create and use an ad-hoc Cors
#[get("/owned")]
fn owned<'r>() -> impl Responder<'r> {
let options = cors_options();
options.respond_owned(|guard| guard.responder("Hello CORS"))
}
/// You need to define an OPTIONS route for preflight checks.
/// These routes can just return the unit type `()`
#[options("/owned")]
fn owned_options<'r>() -> impl Responder<'r> {
let options = cors_options();
options.respond_owned(|guard| guard.responder(()))
}
fn cors_options() -> Cors {
let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.acme.com"]);
assert!(failed_origins.is_empty());
// You can also deserialize this
rocket_cors::Cors {
allowed_origins: allowed_origins,
allowed_methods: vec![Method::Get].into_iter().map(From::from).collect(),
allowed_headers: AllowedHeaders::some(&["Authorization", "Accept"]),
allow_credentials: true,
..Default::default()
}
}
fn main() {
rocket::ignite()
.mount(
"/",
routes![
owned,
owned_options,
],
)
.manage(cors_options())
.launch();
}
```
### Borrowed Example
You might want to borrow the `Cors` struct from Rocket's state, for example. Unless you have
special handling, you might want to use the Guard method instead which has less hassle.
```rust,no_run
#![feature(proc_macro_hygiene, decl_macro)]
extern crate rocket;
extern crate rocket_cors;
use std::io::Cursor;
use rocket::{State, Response, get, routes};
use rocket::http::Method;
use rocket::response::Responder;
use rocket_cors::{Cors, AllowedOrigins, AllowedHeaders};
/// Using a borrowed Cors
#[get("/")]
fn borrowed(options: State<Cors>) -> impl Responder {
options.inner().respond_borrowed(
|guard| guard.responder("Hello CORS"),
)
}
/// Using a `Response` instead of a `Responder`. You generally won't have to do this.
#[get("/response")]
fn response(options: State<Cors>) -> impl Responder {
let mut response = Response::new();
response.set_sized_body(Cursor::new("Hello CORS!"));
options.inner().respond_borrowed(
move |guard| guard.response(response),
)
}
fn cors_options() -> Cors {
let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.acme.com"]);
assert!(failed_origins.is_empty());
// You can also deserialize this
rocket_cors::Cors {
allowed_origins: allowed_origins,
allowed_methods: vec![Method::Get].into_iter().map(From::from).collect(),
allowed_headers: AllowedHeaders::some(&["Authorization", "Accept"]),
allow_credentials: true,
..Default::default()
}
}
fn main() {
rocket::ignite()
.mount(
"/",
routes![
borrowed,
response,
],
)
.mount("/", rocket_cors::catch_all_options_routes()) // mount the catch all routes
.manage(cors_options())
.launch();
}
```
## Mixing Guard and Manual ## Mixing Guard and Manual
@ -419,83 +211,7 @@ You can mix `Guard` and `Truly Manual` modes together for your application. For
application might restrict the Origins that can access it, except for one `ping` route that application might restrict the Origins that can access it, except for one `ping` route that
allows all access. allows all access.
You can run the example code below with `cargo run --example mix`. See the [example](https://github.com/lawliet89/rocket_cors/blob/master/examples/guard.rs).
```rust,no_run
#![feature(proc_macro_hygiene, decl_macro)]
extern crate rocket;
extern crate rocket_cors;
use rocket::{get, options, routes};
use rocket::http::Method;
use rocket::response::Responder;
use rocket_cors::{Cors, Guard, AllowedOrigins, AllowedHeaders};
/// The "usual" app route
#[get("/")]
fn app(cors: Guard) -> rocket_cors::Responder<&str> {
cors.responder("Hello CORS!")
}
/// The special "ping" route
#[get("/ping")]
fn ping<'r>() -> impl Responder<'r> {
let options = cors_options_all();
options.respond_owned(|guard| guard.responder("Pong!"))
}
/// You need to define an OPTIONS route for preflight checks if you want to use `Cors` struct
/// that is not in Rocket's managed state.
/// These routes can just return the unit type `()`
#[options("/ping")]
fn ping_options<'r>() -> impl Responder<'r> {
let options = cors_options_all();
options.respond_owned(|guard| guard.responder(()))
}
/// Returns the "application wide" Cors struct
fn cors_options() -> Cors {
let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.acme.com"]);
assert!(failed_origins.is_empty());
// You can also deserialize this
rocket_cors::Cors {
allowed_origins: allowed_origins,
allowed_methods: vec![Method::Get].into_iter().map(From::from).collect(),
allowed_headers: AllowedHeaders::some(&["Authorization", "Accept"]),
allow_credentials: true,
..Default::default()
}
}
/// A special struct that allows all origins
///
/// Note: In your real application, you might want to use something like `lazy_static` to
/// generate a `&'static` reference to this instead of creating a new struct on every request.
fn cors_options_all() -> Cors {
// You can also deserialize this
rocket_cors::Cors {
allowed_methods: vec![Method::Get].into_iter().map(From::from).collect(),
..Default::default()
}
}
fn main() {
rocket::ignite()
.mount(
"/",
routes![
app,
ping,
ping_options,
],
)
.mount("/", rocket_cors::catch_all_options_routes()) // mount the catch all routes
.manage(cors_options())
.launch();
}
```
## Reference ## Reference
- [Fetch CORS Specification](https://fetch.spec.whatwg.org/#cors-protocol) - [Fetch CORS Specification](https://fetch.spec.whatwg.org/#cors-protocol)
@ -666,7 +382,6 @@ impl fmt::Display for Error {
} }
} }
impl error::Error for Error { impl error::Error for Error {
fn cause(&self) -> Option<&dyn error::Error> { fn cause(&self) -> Option<&dyn error::Error> {
match *self { match *self {
@ -881,10 +596,7 @@ impl AllowedHeaders {
} }
} }
/// Response generator and [Fairing](https://rocket.rs/guide/fairings/) for CORS /// Configuration options for CORS request handling.
///
/// This struct can be as Fairing or in an ad-hoc manner to generate CORS response. See the
/// documentation at the [crate root](index.html) for usage information.
/// ///
/// You create a new copy of this struct by defining the configurations in the fields below. /// You create a new copy of this struct by defining the configurations in the fields below.
/// This struct can also be deserialized by serde with the `serialization` feature which is /// This struct can also be deserialized by serde with the `serialization` feature which is
@ -893,6 +605,8 @@ impl AllowedHeaders {
/// [`Default`](https://doc.rust-lang.org/std/default/trait.Default.html) is implemented for this /// [`Default`](https://doc.rust-lang.org/std/default/trait.Default.html) is implemented for this
/// struct. The default for each field is described in the docuementation for the field. /// struct. The default for each field is described in the docuementation for the field.
/// ///
/// Before you can use this with Rocket, you will need to call the [`CorsOptions::to_cors`] method.
///
/// # Examples /// # Examples
/// ///
/// You can run an example from the repository to demonstrate the JSON serialization with /// You can run an example from the repository to demonstrate the JSON serialization with
@ -900,7 +614,7 @@ impl AllowedHeaders {
/// ///
/// ## Pure default /// ## Pure default
/// ```rust /// ```rust
/// let default = rocket_cors::Cors::default(); /// let default = rocket_cors::CorsOptions::default();
/// ``` /// ```
/// ///
/// ## JSON Examples /// ## JSON Examples
@ -959,7 +673,7 @@ impl AllowedHeaders {
/// ``` /// ```
#[derive(Eq, PartialEq, Clone, Debug)] #[derive(Eq, PartialEq, Clone, Debug)]
#[cfg_attr(feature = "serialization", derive(Serialize, Deserialize))] #[cfg_attr(feature = "serialization", derive(Serialize, Deserialize))]
pub struct Cors { pub struct CorsOptions {
/// Origins that are allowed to make requests. /// Origins that are allowed to make requests.
/// Will be verified against the `Origin` request header. /// Will be verified against the `Origin` request header.
/// ///
@ -986,7 +700,7 @@ pub struct Cors {
/// Defaults to `[GET, HEAD, POST, OPTIONS, PUT, PATCH, DELETE]` /// Defaults to `[GET, HEAD, POST, OPTIONS, PUT, PATCH, DELETE]`
#[cfg_attr( #[cfg_attr(
feature = "serialization", feature = "serialization",
serde(default = "Cors::default_allowed_methods") serde(default = "CorsOptions::default_allowed_methods")
)] )]
pub allowed_methods: AllowedMethods, pub allowed_methods: AllowedMethods,
/// The list of header field names which can be used when this resource is accessed by allowed /// The list of header field names which can be used when this resource is accessed by allowed
@ -1048,7 +762,7 @@ pub struct Cors {
/// Defaults to "/cors" /// Defaults to "/cors"
#[cfg_attr( #[cfg_attr(
feature = "serialization", feature = "serialization",
serde(default = "Cors::default_fairing_route_base") serde(default = "CorsOptions::default_fairing_route_base")
)] )]
pub fairing_route_base: String, pub fairing_route_base: String,
/// When used as Fairing, Cors will need to redirect failed CORS checks to a custom route /// When used as Fairing, Cors will need to redirect failed CORS checks to a custom route
@ -1058,12 +772,12 @@ pub struct Cors {
/// Defaults to 0 /// Defaults to 0
#[cfg_attr( #[cfg_attr(
feature = "serialization", feature = "serialization",
serde(default = "Cors::default_fairing_route_rank") serde(default = "CorsOptions::default_fairing_route_rank")
)] )]
pub fairing_route_rank: isize, pub fairing_route_rank: isize,
} }
impl Default for Cors { impl Default for CorsOptions {
fn default() -> Self { fn default() -> Self {
Self { Self {
allowed_origins: Default::default(), allowed_origins: Default::default(),
@ -1079,7 +793,7 @@ impl Default for Cors {
} }
} }
impl Cors { impl CorsOptions {
fn default_allowed_methods() -> HashSet<Method> { fn default_allowed_methods() -> HashSet<Method> {
use rocket::http::Method; use rocket::http::Method;
@ -1116,6 +830,36 @@ impl Cors {
Ok(()) Ok(())
} }
/// Creates a [`Cors`] struct that can be used to respond to requests or as a Rocket Fairing
pub fn to_cors(&self) -> Result<Cors, Error> {
Cors::from_options(self)
}
}
/// Response generator and [Fairing](https://rocket.rs/guide/fairings/) for CORS
///
/// This struct can be as Fairing or in an ad-hoc manner to generate CORS response. See the
/// documentation at the [crate root](index.html) for usage information.
///
/// This struct can be created by using [`CorsOptions::to_cors`] or [`Cors::from_options`].
#[derive(Clone, Debug)]
pub struct Cors(CorsOptions);
impl Deref for Cors {
type Target = CorsOptions;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl Cors {
/// Create a `Cors` struct from a [`CorsOptions`]
pub fn from_options(options: &CorsOptions) -> Result<Self, Error> {
options.validate()?;
Ok(Cors(options.clone()))
}
/// Manually respond to a request with CORS checks and headers using an Owned `Cors`. /// Manually respond to a request with CORS checks and headers using an Owned `Cors`.
/// ///
/// Use this variant when your `Cors` struct will not live at least as long as the whole `'r` /// Use this variant when your `Cors` struct will not live at least as long as the whole `'r`
@ -1131,7 +875,6 @@ impl Cors {
F: FnOnce(Guard<'r>) -> R + 'r, F: FnOnce(Guard<'r>) -> R + 'r,
R: response::Responder<'r>, R: response::Responder<'r>,
{ {
self.validate()?;
Ok(ManualResponder::new(Cow::Owned(self), handler)) Ok(ManualResponder::new(Cow::Owned(self), handler))
} }
@ -1155,7 +898,6 @@ impl Cors {
F: FnOnce(Guard<'r>) -> R + 'r, F: FnOnce(Guard<'r>) -> R + 'r,
R: response::Responder<'r>, R: response::Responder<'r>,
{ {
self.validate()?;
Ok(ManualResponder::new(Cow::Borrowed(self), handler)) Ok(ManualResponder::new(Cow::Borrowed(self), handler))
} }
} }
@ -1877,11 +1619,11 @@ mod tests {
use super::*; use super::*;
use crate::http::Method; use crate::http::Method;
fn make_cors_options() -> Cors { fn make_cors_options() -> CorsOptions {
let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.acme.com"]); let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.acme.com"]);
assert!(failed_origins.is_empty()); assert!(failed_origins.is_empty());
Cors { CorsOptions {
allowed_origins: allowed_origins, allowed_origins: allowed_origins,
allowed_methods: vec![http::Method::Get] allowed_methods: vec![http::Method::Get]
.into_iter() .into_iter()
@ -1897,7 +1639,7 @@ mod tests {
} }
} }
fn make_invalid_options() -> Cors { fn make_invalid_options() -> CorsOptions {
let mut cors = make_cors_options(); let mut cors = make_cors_options();
cors.allow_credentials = true; cors.allow_credentials = true;
cors.allowed_origins = AllOrSome::All; cors.allowed_origins = AllOrSome::All;
@ -1930,8 +1672,8 @@ mod tests {
#[cfg(feature = "serialization")] #[cfg(feature = "serialization")]
#[test] #[test]
fn cors_default_deserialization_is_correct() { fn cors_default_deserialization_is_correct() {
let deserialized: Cors = serde_json::from_str("{}").expect("To not fail"); let deserialized: CorsOptions = serde_json::from_str("{}").expect("To not fail");
assert_eq!(deserialized, Cors::default()); assert_eq!(deserialized, CorsOptions::default());
} }
// The following tests check validation // The following tests check validation
@ -2251,7 +1993,7 @@ mod tests {
#[test] #[test]
fn preflight_validated_correctly() { fn preflight_validated_correctly() {
let options = make_cors_options(); let cors = make_cors_options().to_cors().expect("To not fail");
let client = make_client(); let client = make_client();
let origin_header = let origin_header =
@ -2271,7 +2013,7 @@ mod tests {
.header(method_header) .header(method_header)
.header(request_headers); .header(request_headers);
let result = validate(&options, request.inner()).expect("to not fail"); let result = validate(&cors, request.inner()).expect("to not fail");
let expected_result = ValidationResult::Preflight { let expected_result = ValidationResult::Preflight {
origin: FromStr::from_str("https://www.acme.com").unwrap(), origin: FromStr::from_str("https://www.acme.com").unwrap(),
// Checks that only a subset of allowed headers are returned // Checks that only a subset of allowed headers are returned
@ -2282,36 +2024,11 @@ mod tests {
assert_eq!(expected_result, result); assert_eq!(expected_result, result);
} }
#[test]
#[should_panic(expected = "CredentialsWithWildcardOrigin")]
fn preflight_validation_errors_on_invalid_options() {
let options = make_invalid_options();
let client = make_client();
let origin_header =
Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap());
let method_header = Header::from(hyper::header::AccessControlRequestMethod(
hyper::method::Method::Get,
));
let request_headers =
hyper::header::AccessControlRequestHeaders(vec![
FromStr::from_str("Authorization").unwrap()
]);
let request_headers = Header::from(request_headers);
let request = client
.options("/")
.header(origin_header)
.header(method_header)
.header(request_headers);
let _ = validate(&options, request.inner()).unwrap();
}
#[test] #[test]
fn preflight_validation_allows_all_origin() { fn preflight_validation_allows_all_origin() {
let mut options = make_cors_options(); let mut options = make_cors_options();
options.allowed_origins = AllOrSome::All; options.allowed_origins = AllOrSome::All;
let cors = options.to_cors().expect("To not fail");
let client = make_client(); let client = make_client();
let origin_header = let origin_header =
@ -2331,7 +2048,7 @@ mod tests {
.header(method_header) .header(method_header)
.header(request_headers); .header(request_headers);
let result = validate(&options, request.inner()).expect("to not fail"); let result = validate(&cors, request.inner()).expect("to not fail");
let expected_result = ValidationResult::Preflight { let expected_result = ValidationResult::Preflight {
origin: FromStr::from_str("https://www.example.com").unwrap(), origin: FromStr::from_str("https://www.example.com").unwrap(),
headers: Some(FromStr::from_str("Authorization").unwrap()), headers: Some(FromStr::from_str("Authorization").unwrap()),
@ -2343,7 +2060,7 @@ mod tests {
#[test] #[test]
#[should_panic(expected = "OriginNotAllowed")] #[should_panic(expected = "OriginNotAllowed")]
fn preflight_validation_errors_on_invalid_origin() { fn preflight_validation_errors_on_invalid_origin() {
let options = make_cors_options(); let cors = make_cors_options().to_cors().expect("To not fail");
let client = make_client(); let client = make_client();
let origin_header = let origin_header =
@ -2363,13 +2080,13 @@ mod tests {
.header(method_header) .header(method_header)
.header(request_headers); .header(request_headers);
let _ = validate(&options, request.inner()).unwrap(); let _ = validate(&cors, request.inner()).unwrap();
} }
#[test] #[test]
#[should_panic(expected = "MissingRequestMethod")] #[should_panic(expected = "MissingRequestMethod")]
fn preflight_validation_errors_on_missing_request_method() { fn preflight_validation_errors_on_missing_request_method() {
let options = make_cors_options(); let cors = make_cors_options().to_cors().expect("To not fail");
let client = make_client(); let client = make_client();
let origin_header = let origin_header =
@ -2385,13 +2102,13 @@ mod tests {
.header(origin_header) .header(origin_header)
.header(request_headers); .header(request_headers);
let _ = validate(&options, request.inner()).unwrap(); let _ = validate(&cors, request.inner()).unwrap();
} }
#[test] #[test]
#[should_panic(expected = "MethodNotAllowed")] #[should_panic(expected = "MethodNotAllowed")]
fn preflight_validation_errors_on_disallowed_method() { fn preflight_validation_errors_on_disallowed_method() {
let options = make_cors_options(); let cors = make_cors_options().to_cors().expect("To not fail");
let client = make_client(); let client = make_client();
let origin_header = let origin_header =
@ -2411,13 +2128,13 @@ mod tests {
.header(method_header) .header(method_header)
.header(request_headers); .header(request_headers);
let _ = validate(&options, request.inner()).unwrap(); let _ = validate(&cors, request.inner()).unwrap();
} }
#[test] #[test]
#[should_panic(expected = "HeadersNotAllowed")] #[should_panic(expected = "HeadersNotAllowed")]
fn preflight_validation_errors_on_disallowed_headers() { fn preflight_validation_errors_on_disallowed_headers() {
let options = make_cors_options(); let cors = make_cors_options().to_cors().expect("To not fail");
let client = make_client(); let client = make_client();
let origin_header = let origin_header =
@ -2437,19 +2154,19 @@ mod tests {
.header(method_header) .header(method_header)
.header(request_headers); .header(request_headers);
let _ = validate(&options, request.inner()).unwrap(); let _ = validate(&cors, request.inner()).unwrap();
} }
#[test] #[test]
fn actual_request_validated_correctly() { fn actual_request_validated_correctly() {
let options = make_cors_options(); let cors = make_cors_options().to_cors().expect("To not fail");
let client = make_client(); let client = make_client();
let origin_header = let origin_header =
Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap()); Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap());
let request = client.get("/").header(origin_header); let request = client.get("/").header(origin_header);
let result = validate(&options, request.inner()).expect("to not fail"); let result = validate(&cors, request.inner()).expect("to not fail");
let expected_result = ValidationResult::Request { let expected_result = ValidationResult::Request {
origin: FromStr::from_str("https://www.acme.com").unwrap(), origin: FromStr::from_str("https://www.acme.com").unwrap(),
}; };
@ -2457,30 +2174,18 @@ mod tests {
assert_eq!(expected_result, result); assert_eq!(expected_result, result);
} }
#[test]
#[should_panic(expected = "CredentialsWithWildcardOrigin")]
fn actual_request_validation_errors_on_invalid_options() {
let options = make_invalid_options();
let client = make_client();
let origin_header =
Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap());
let request = client.get("/").header(origin_header);
let _ = validate(&options, request.inner()).unwrap();
}
#[test] #[test]
fn actual_request_validation_allows_all_origin() { fn actual_request_validation_allows_all_origin() {
let mut options = make_cors_options(); let mut options = make_cors_options();
options.allowed_origins = AllOrSome::All; options.allowed_origins = AllOrSome::All;
let cors = options.to_cors().expect("To not fail");
let client = make_client(); let client = make_client();
let origin_header = let origin_header =
Header::from(hyper::header::Origin::from_str("https://www.example.com").unwrap()); Header::from(hyper::header::Origin::from_str("https://www.example.com").unwrap());
let request = client.get("/").header(origin_header); let request = client.get("/").header(origin_header);
let result = validate(&options, request.inner()).expect("to not fail"); let result = validate(&cors, request.inner()).expect("to not fail");
let expected_result = ValidationResult::Request { let expected_result = ValidationResult::Request {
origin: FromStr::from_str("https://www.example.com").unwrap(), origin: FromStr::from_str("https://www.example.com").unwrap(),
}; };
@ -2491,23 +2196,23 @@ mod tests {
#[test] #[test]
#[should_panic(expected = "OriginNotAllowed")] #[should_panic(expected = "OriginNotAllowed")]
fn actual_request_validation_errors_on_incorrect_origin() { fn actual_request_validation_errors_on_incorrect_origin() {
let options = make_cors_options(); let cors = make_cors_options().to_cors().expect("To not fail");
let client = make_client(); let client = make_client();
let origin_header = let origin_header =
Header::from(hyper::header::Origin::from_str("https://www.example.com").unwrap()); Header::from(hyper::header::Origin::from_str("https://www.example.com").unwrap());
let request = client.get("/").header(origin_header); let request = client.get("/").header(origin_header);
let _ = validate(&options, request.inner()).unwrap(); let _ = validate(&cors, request.inner()).unwrap();
} }
#[test] #[test]
fn non_cors_request_return_empty_response() { fn non_cors_request_return_empty_response() {
let options = make_cors_options(); let cors = make_cors_options().to_cors().expect("To not fail");
let client = make_client(); let client = make_client();
let request = client.options("/"); let request = client.options("/");
let response = validate_and_build(&options, request.inner()).expect("to not fail"); let response = validate_and_build(&cors, request.inner()).expect("to not fail");
let expected_response = Response::new(); let expected_response = Response::new();
assert_eq!(expected_response, response); assert_eq!(expected_response, response);
} }
@ -2515,6 +2220,7 @@ mod tests {
#[test] #[test]
fn preflight_validated_and_built_correctly() { fn preflight_validated_and_built_correctly() {
let options = make_cors_options(); let options = make_cors_options();
let cors = options.to_cors().expect("To not fail");
let client = make_client(); let client = make_client();
let origin_header = let origin_header =
@ -2534,7 +2240,7 @@ mod tests {
.header(method_header) .header(method_header)
.header(request_headers); .header(request_headers);
let response = validate_and_build(&options, request.inner()).expect("to not fail"); let response = validate_and_build(&cors, request.inner()).expect("to not fail");
let expected_response = Response::new() let expected_response = Response::new()
.origin(&FromStr::from_str("https://www.acme.com/").unwrap(), false) .origin(&FromStr::from_str("https://www.acme.com/").unwrap(), false)
@ -2553,6 +2259,7 @@ mod tests {
let mut options = make_cors_options(); let mut options = make_cors_options();
options.allowed_origins = AllOrSome::All; options.allowed_origins = AllOrSome::All;
options.send_wildcard = false; options.send_wildcard = false;
let cors = options.to_cors().expect("To not fail");
let client = make_client(); let client = make_client();
@ -2573,7 +2280,7 @@ mod tests {
.header(method_header) .header(method_header)
.header(request_headers); .header(request_headers);
let response = validate_and_build(&options, request.inner()).expect("to not fail"); let response = validate_and_build(&cors, request.inner()).expect("to not fail");
let expected_response = Response::new() let expected_response = Response::new()
.origin(&FromStr::from_str("https://www.acme.com/").unwrap(), true) .origin(&FromStr::from_str("https://www.acme.com/").unwrap(), true)
@ -2592,6 +2299,7 @@ mod tests {
options.allowed_origins = AllOrSome::All; options.allowed_origins = AllOrSome::All;
options.send_wildcard = true; options.send_wildcard = true;
options.allow_credentials = false; options.allow_credentials = false;
let cors = options.to_cors().expect("To not fail");
let client = make_client(); let client = make_client();
@ -2612,7 +2320,7 @@ mod tests {
.header(method_header) .header(method_header)
.header(request_headers); .header(request_headers);
let response = validate_and_build(&options, request.inner()).expect("to not fail"); let response = validate_and_build(&cors, request.inner()).expect("to not fail");
let expected_response = Response::new() let expected_response = Response::new()
.any() .any()
@ -2627,13 +2335,14 @@ mod tests {
#[test] #[test]
fn actual_request_validated_and_built_correctly() { fn actual_request_validated_and_built_correctly() {
let options = make_cors_options(); let options = make_cors_options();
let cors = options.to_cors().expect("To not fail");
let client = make_client(); let client = make_client();
let origin_header = let origin_header =
Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap()); Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap());
let request = client.get("/").header(origin_header); let request = client.get("/").header(origin_header);
let response = validate_and_build(&options, request.inner()).expect("to not fail"); let response = validate_and_build(&cors, request.inner()).expect("to not fail");
let expected_response = Response::new() let expected_response = Response::new()
.origin(&FromStr::from_str("https://www.acme.com/").unwrap(), false) .origin(&FromStr::from_str("https://www.acme.com/").unwrap(), false)
.credentials(options.allow_credentials) .credentials(options.allow_credentials)
@ -2648,6 +2357,7 @@ mod tests {
options.allowed_origins = AllOrSome::All; options.allowed_origins = AllOrSome::All;
options.send_wildcard = false; options.send_wildcard = false;
options.allow_credentials = false; options.allow_credentials = false;
let cors = options.to_cors().expect("To not fail");
let client = make_client(); let client = make_client();
@ -2655,7 +2365,7 @@ mod tests {
Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap()); Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap());
let request = client.get("/").header(origin_header); let request = client.get("/").header(origin_header);
let response = validate_and_build(&options, request.inner()).expect("to not fail"); let response = validate_and_build(&cors, request.inner()).expect("to not fail");
let expected_response = Response::new() let expected_response = Response::new()
.origin(&FromStr::from_str("https://www.acme.com/").unwrap(), true) .origin(&FromStr::from_str("https://www.acme.com/").unwrap(), true)
.credentials(options.allow_credentials) .credentials(options.allow_credentials)
@ -2670,6 +2380,7 @@ mod tests {
options.allowed_origins = AllOrSome::All; options.allowed_origins = AllOrSome::All;
options.send_wildcard = true; options.send_wildcard = true;
options.allow_credentials = false; options.allow_credentials = false;
let cors = options.to_cors().expect("To not fail");
let client = make_client(); let client = make_client();
@ -2677,7 +2388,7 @@ mod tests {
Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap()); Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap());
let request = client.get("/").header(origin_header); let request = client.get("/").header(origin_header);
let response = validate_and_build(&options, request.inner()).expect("to not fail"); let response = validate_and_build(&cors, request.inner()).expect("to not fail");
let expected_response = Response::new() let expected_response = Response::new()
.any() .any()
.credentials(options.allow_credentials) .credentials(options.allow_credentials)

View File

@ -21,23 +21,25 @@ fn panicking_route() {
panic!("This route will panic"); panic!("This route will panic");
} }
fn make_cors_options() -> Cors { fn make_cors() -> Cors {
let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.acme.com"]); let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.acme.com"]);
assert!(failed_origins.is_empty()); assert!(failed_origins.is_empty());
Cors { CorsOptions {
allowed_origins: allowed_origins, allowed_origins: allowed_origins,
allowed_methods: vec![Method::Get].into_iter().map(From::from).collect(), allowed_methods: vec![Method::Get].into_iter().map(From::from).collect(),
allowed_headers: AllowedHeaders::some(&["Authorization", "Accept"]), allowed_headers: AllowedHeaders::some(&["Authorization", "Accept"]),
allow_credentials: true, allow_credentials: true,
..Default::default() ..Default::default()
} }
.to_cors()
.expect("To not fail")
} }
fn rocket() -> rocket::Rocket { fn rocket() -> rocket::Rocket {
rocket::ignite() rocket::ignite()
.mount("/", routes![cors, panicking_route]) .mount("/", routes![cors, panicking_route])
.attach(make_cors_options()) .attach(make_cors())
} }
#[test] #[test]

View File

@ -59,17 +59,19 @@ fn state<'r>(cors: cors::Guard<'r>, _state: State<'r, SomeState>) -> cors::Respo
cors.responder("hmm") cors.responder("hmm")
} }
fn make_cors_options() -> cors::Cors { fn make_cors() -> cors::Cors {
let (allowed_origins, failed_origins) = cors::AllowedOrigins::some(&["https://www.acme.com"]); let (allowed_origins, failed_origins) = cors::AllowedOrigins::some(&["https://www.acme.com"]);
assert!(failed_origins.is_empty()); assert!(failed_origins.is_empty());
cors::Cors { cors::CorsOptions {
allowed_origins: allowed_origins, allowed_origins: allowed_origins,
allowed_methods: vec![Method::Get].into_iter().map(From::from).collect(), allowed_methods: vec![Method::Get].into_iter().map(From::from).collect(),
allowed_headers: cors::AllowedHeaders::some(&["Authorization", "Accept"]), allowed_headers: cors::AllowedHeaders::some(&["Authorization", "Accept"]),
allow_credentials: true, allow_credentials: true,
..Default::default() ..Default::default()
} }
.to_cors()
.expect("To not fail")
} }
fn make_rocket() -> rocket::Rocket { fn make_rocket() -> rocket::Rocket {
@ -81,7 +83,7 @@ fn make_rocket() -> rocket::Rocket {
) )
.mount("/", cors::catch_all_options_routes()) // mount the catch all routes .mount("/", cors::catch_all_options_routes()) // mount the catch all routes
.mount("/", routes![cors_manual, cors_manual_options]) // manual OPTIOONS routes .mount("/", routes![cors_manual, cors_manual_options]) // manual OPTIOONS routes
.manage(make_cors_options()) .manage(make_cors())
.manage(SomeState) .manage(SomeState)
} }

View File

@ -31,7 +31,7 @@ fn panicking_route(options: State<'_, Cors>) -> impl Responder<'_> {
/// Respond with an owned option instead /// Respond with an owned option instead
#[options("/owned")] #[options("/owned")]
fn owned_options<'r>() -> impl Responder<'r> { fn owned_options<'r>() -> impl Responder<'r> {
let borrow = make_different_cors_options(); let borrow = make_different_cors_options().to_cors()?;
borrow.respond_owned(|guard| guard.responder("Manual CORS Preflight")) borrow.respond_owned(|guard| guard.responder("Manual CORS Preflight"))
} }
@ -39,7 +39,7 @@ fn owned_options<'r>() -> impl Responder<'r> {
/// Respond with an owned option instead /// Respond with an owned option instead
#[get("/owned")] #[get("/owned")]
fn owned<'r>() -> impl Responder<'r> { fn owned<'r>() -> impl Responder<'r> {
let borrow = make_different_cors_options(); let borrow = make_different_cors_options().to_cors()?;
borrow.respond_owned(|guard| guard.responder("Hello CORS Owned")) borrow.respond_owned(|guard| guard.responder("Hello CORS Owned"))
} }
@ -65,11 +65,11 @@ fn borrow<'r>(options: State<'r, Cors>, test_state: State<'r, TestState>) -> imp
}) })
} }
fn make_cors_options() -> Cors { fn make_cors_options() -> CorsOptions {
let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.acme.com"]); let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.acme.com"]);
assert!(failed_origins.is_empty()); assert!(failed_origins.is_empty());
Cors { CorsOptions {
allowed_origins: allowed_origins, allowed_origins: allowed_origins,
allowed_methods: vec![Method::Get].into_iter().map(From::from).collect(), allowed_methods: vec![Method::Get].into_iter().map(From::from).collect(),
allowed_headers: AllowedHeaders::some(&["Authorization", "Accept"]), allowed_headers: AllowedHeaders::some(&["Authorization", "Accept"]),
@ -78,11 +78,11 @@ fn make_cors_options() -> Cors {
} }
} }
fn make_different_cors_options() -> Cors { fn make_different_cors_options() -> CorsOptions {
let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.example.com"]); let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.example.com"]);
assert!(failed_origins.is_empty()); assert!(failed_origins.is_empty());
Cors { CorsOptions {
allowed_origins: allowed_origins, allowed_origins: allowed_origins,
allowed_methods: vec![Method::Get].into_iter().map(From::from).collect(), allowed_methods: vec![Method::Get].into_iter().map(From::from).collect(),
allowed_headers: AllowedHeaders::some(&["Authorization", "Accept"]), allowed_headers: AllowedHeaders::some(&["Authorization", "Accept"]),
@ -96,7 +96,7 @@ fn rocket() -> rocket::Rocket {
.mount("/", routes![cors, panicking_route]) .mount("/", routes![cors, panicking_route])
.mount("/", routes![owned, owned_options]) .mount("/", routes![owned, owned_options])
.mount("/", catch_all_options_routes()) // mount the catch all routes .mount("/", catch_all_options_routes()) // mount the catch all routes
.manage(make_cors_options()) .manage(make_cors_options().to_cors().expect("Not to fail"))
} }
#[test] #[test]

View File

@ -14,7 +14,7 @@ use rocket::http::{Header, Method, Status};
use rocket::local::Client; use rocket::local::Client;
use rocket::response::Responder; use rocket::response::Responder;
use rocket_cors::{AllowedHeaders, AllowedOrigins, Cors, Guard}; use rocket_cors::{AllowedHeaders, AllowedOrigins, CorsOptions, Guard};
/// The "usual" app route /// The "usual" app route
#[get("/")] #[get("/")]
@ -25,8 +25,8 @@ fn app(cors: Guard<'_>) -> rocket_cors::Responder<'_, &str> {
/// The special "ping" route /// The special "ping" route
#[get("/ping")] #[get("/ping")]
fn ping<'r>() -> impl Responder<'r> { fn ping<'r>() -> impl Responder<'r> {
let options = cors_options_all(); let cors = cors_options_all().to_cors()?;
options.respond_owned(|guard| guard.responder("Pong!")) cors.respond_owned(|guard| guard.responder("Pong!"))
} }
/// You need to define an OPTIONS route for preflight checks if you want to use `Cors` struct /// You need to define an OPTIONS route for preflight checks if you want to use `Cors` struct
@ -34,17 +34,17 @@ fn ping<'r>() -> impl Responder<'r> {
/// These routes can just return the unit type `()` /// These routes can just return the unit type `()`
#[options("/ping")] #[options("/ping")]
fn ping_options<'r>() -> impl Responder<'r> { fn ping_options<'r>() -> impl Responder<'r> {
let options = cors_options_all(); let cors = cors_options_all().to_cors()?;
options.respond_owned(|guard| guard.responder(())) cors.respond_owned(|guard| guard.responder(()))
} }
/// Returns the "application wide" Cors struct /// Returns the "application wide" Cors struct
fn cors_options() -> Cors { fn cors_options() -> CorsOptions {
let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.acme.com"]); let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.acme.com"]);
assert!(failed_origins.is_empty()); assert!(failed_origins.is_empty());
// You can also deserialize this // You can also deserialize this
rocket_cors::Cors { rocket_cors::CorsOptions {
allowed_origins: allowed_origins, allowed_origins: allowed_origins,
allowed_methods: vec![Method::Get].into_iter().map(From::from).collect(), allowed_methods: vec![Method::Get].into_iter().map(From::from).collect(),
allowed_headers: AllowedHeaders::some(&["Authorization", "Accept"]), allowed_headers: AllowedHeaders::some(&["Authorization", "Accept"]),
@ -57,7 +57,7 @@ fn cors_options() -> Cors {
/// ///
/// Note: In your real application, you might want to use something like `lazy_static` to generate /// Note: In your real application, you might want to use something like `lazy_static` to generate
/// a `&'static` reference to this instead of creating a new struct on every request. /// a `&'static` reference to this instead of creating a new struct on every request.
fn cors_options_all() -> Cors { fn cors_options_all() -> CorsOptions {
// You can also deserialize this // You can also deserialize this
Default::default() Default::default()
} }
@ -66,7 +66,7 @@ fn rocket() -> rocket::Rocket {
rocket::ignite() rocket::ignite()
.mount("/", routes![app, ping, ping_options,]) .mount("/", routes![app, ping, ping_options,])
.mount("/", rocket_cors::catch_all_options_routes()) // mount the catch all routes .mount("/", rocket_cors::catch_all_options_routes()) // mount the catch all routes
.manage(cors_options()) .manage(cors_options().to_cors().expect("Not to fail"))
} }
#[test] #[test]