diff --git a/examples/guard.rs b/examples/guard.rs index e014e8b..2edb297 100644 --- a/examples/guard.rs +++ b/examples/guard.rs @@ -15,13 +15,6 @@ fn responder(cors: Guard) -> Responder<&str> { cors.responder("Hello CORS!") } -/// You need to define an OPTIONS route for preflight checks. -/// These routes can just return the unit type `()` -#[options("/")] -fn responder_options(cors: Guard) -> Responder<()> { - cors.responder(()) -} - /// Using a `Response` instead of a `Responder`. You generally won't have to do this. #[get("/response")] fn response(cors: Guard) -> Response { @@ -30,12 +23,16 @@ fn response(cors: Guard) -> Response { cors.response(response) } -/// You need to define an OPTIONS route for preflight checks. -/// These routes can just return the unit type `()` -#[options("/response")] -fn response_options(cors: Guard) -> Response { - let response = Response::new(); - cors.response(response) +/// Manually mount an OPTIONS route for your own handling +#[options("/manual")] +fn manual_options(cors: Guard) -> Responder<&str> { + cors.responder("Manual OPTIONS preflight handling") +} + +/// Manually mount an OPTIONS route for your own handling +#[get("/manual")] +fn manual(cors: Guard) -> Responder<&str> { + cors.responder("Manual OPTIONS preflight handling") } fn main() { @@ -54,8 +51,12 @@ fn main() { rocket::ignite() .mount( "/", - routes![responder, responder_options, response, response_options], + routes![responder, response], ) + // Mount the routes to catch all the OPTIONS pre-flight requests + .mount("/", rocket_cors::catch_all_options_routes()) + // You can also manually mount an OPTIONS route that will be used instead + .mount("/", routes![manual, manual_options]) .manage(options) .launch(); } diff --git a/examples/manual.rs b/examples/manual.rs index 67a3e65..8a8df54 100644 --- a/examples/manual.rs +++ b/examples/manual.rs @@ -18,15 +18,6 @@ fn borrowed<'r>(options: State<'r, Cors>) -> impl Responder<'r> { ) } -/// You need to define an OPTIONS route for preflight checks. -/// These routes can just return the unit type `()` -#[options("/")] -fn borrowed_options<'r>(options: State<'r, Cors>) -> impl Responder<'r> { - options.inner().respond_borrowed( - |guard| guard.responder(()), - ) -} - /// Using a `Response` instead of a `Responder`. You generally won't have to do this. #[get("/response")] fn response<'r>(options: State<'r, Cors>) -> impl Responder<'r> { @@ -38,15 +29,6 @@ fn response<'r>(options: State<'r, Cors>) -> impl Responder<'r> { ) } -/// You need to define an OPTIONS route for preflight checks. -/// These routes can just return the unit type `()` -#[options("/response")] -fn response_options<'r>(options: State<'r, Cors>) -> impl Responder<'r> { - options.inner().respond_borrowed( - move |guard| guard.response(Response::new()), - ) -} - /// Create and use an ad-hoc Cors #[get("/owned")] fn owned<'r>() -> impl Responder<'r> { @@ -54,7 +36,8 @@ fn owned<'r>() -> impl Responder<'r> { options.respond_owned(|guard| guard.responder("Hello CORS")) } -/// You need to define an OPTIONS route for preflight checks. +/// You need to define an OPTIONS route for preflight checks if you want to use `Cors` struct +/// that is not in Rocket's managed state. /// These routes can just return the unit type `()` #[options("/owned")] fn owned_options<'r>() -> impl Responder<'r> { @@ -82,13 +65,12 @@ fn main() { "/", routes![ borrowed, - borrowed_options, response, - response_options, owned, owned_options, ], ) + .mount("/", rocket_cors::catch_all_options_routes()) // mount the catch all routes .manage(cors_options()) .launch(); } diff --git a/src/lib.rs b/src/lib.rs index 9704943..49f4c6f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -76,7 +76,7 @@ //! |:---------------------------------------:|:-------:|:-------------:|:------:| //! | Must apply to all routes | ✔ | ✗ | ✗ | //! | Different settings for different routes | ✗ | ✗ | ✔ | -//! | Must define OPTIONS route | ✗ | ✔ | ✔ | +//! | May define custom OPTIONS routes | ✗ | ✔ | ✔ | //! //! ### Fairing //! @@ -128,16 +128,16 @@ //! //! Using request guard requires you to sacrifice the convenience of Fairings for being able to //! opt some routes out of CORS checks and enforcement. _BUT_ you are still restricted to only -//! one set of CORS settings and you will now have to define `OPTIONS` routes for all the routes -//! you want to have CORS checks on. The `OPTIONS` routes are used for CORS preflight checks. +//! one set of CORS settings and you have to mount additional routes to catch and process OPTIONS +//! requests. The `OPTIONS` routes are used for CORS preflight checks. //! //! You will have to do the following: //! //! - Create a [`Cors` struct](struct.Cors.html) and during Rocket's ignite, add the struct to //! Rocket's [managed state](https://rocket.rs/guide/state/#managed-state). -//! - For all the routes that you want to enforce CORS on, you have to define a `OPTIONS` route -//! for the path. You can use [dynamic segments](https://rocket.rs/guide/requests/#dynamic-segments) -//! to reduce the number of routes you have to define. +//! - For all the routes that you want to enforce CORS on, you can mount either some +//! [catch all route](fn.catch_all_options_routes.html) or define your own route for the OPTIONS +//! verb. //! - Then in all the routes you want to enforce CORS on, add a //! [Request Guard](https://rocket.rs/guide/requests/#request-guards) for the //! [`Guard`](struct.Guard.html) struct in the route arguments. You should not wrap this in an @@ -146,11 +146,8 @@ //! - In your routes, to add CORS headers to your responses, use the appropriate functions on the //! [`Guard`](struct.Guard.html) for a `Response` or a `Responder`. //! -//! You can mix this with the "manual" checks, but whichever `Response` is the last merged will -//! overwrite the previous CORS headers. -//! //! ```rust,no_run -//! #![feature(plugin, custom_derive)] +//! #![feature(plugin)] //! #![plugin(rocket_codegen)] //! extern crate rocket; //! extern crate rocket_cors; @@ -167,27 +164,24 @@ //! cors.responder("Hello CORS!") //! } //! -//! /// You need to define an OPTIONS route for preflight checks. -//! /// These routes can just return the unit type `()` -//! #[options("/")] -//! fn responder_options(cors: Guard) -> Responder<()> { -//! cors.responder(()) -//! } -//! //! /// Using a `Response` instead of a `Responder`. You generally won't have to do this. -//! #[get("/responder")] +//! #[get("/response")] //! fn response(cors: Guard) -> Response { //! let mut response = Response::new(); //! response.set_sized_body(Cursor::new("Hello CORS!")); //! cors.response(response) //! } //! -//! /// You need to define an OPTIONS route for preflight checks. -//! /// These routes can just return the unit type `()` -//! #[options("/responder")] -//! fn response_options(cors: Guard) -> Response { -//! let response = Response::new(); -//! cors.response(response) +//! /// Manually mount an OPTIONS route for your own handling +//! #[options("/manual")] +//! fn manual_options(cors: Guard) -> Responder<&str> { +//! cors.responder("Manual OPTIONS preflight handling") +//! } +//! +//! /// Manually mount an OPTIONS route for your own handling +//! #[get("/manual")] +//! fn manual(cors: Guard) -> Responder<&str> { +//! cors.responder("Manual OPTIONS preflight handling") //! } //! //! fn main() { @@ -204,11 +198,17 @@ //! }; //! //! rocket::ignite() -//! .mount("/", routes![responder, responder_options, response, response_options]) +//! .mount( +//! "/", +//! routes![responder, response], +//! ) +//! // Mount the routes to catch all the OPTIONS pre-flight requests +//! .mount("/", rocket_cors::catch_all_options_routes()) +//! // You can also manually mount an OPTIONS route that will be used instead +//! .mount("/", routes![manual, manual_options]) //! .manage(options) //! .launch(); //! } -//! //! ``` //! //! ## Truly Manual @@ -338,15 +338,6 @@ //! ) //! } //! -//! /// You need to define an OPTIONS route for preflight checks. -//! /// These routes can just return the unit type `()` -//! #[options("/")] -//! fn borrowed_options<'r>(options: State<'r, Cors>) -> impl Responder<'r> { -//! options.inner().respond_borrowed( -//! |guard| guard.responder(()), -//! ) -//! } -//! //! /// Using a `Response` instead of a `Responder`. You generally won't have to do this. //! #[get("/response")] //! fn response<'r>(options: State<'r, Cors>) -> impl Responder<'r> { @@ -358,15 +349,6 @@ //! ) //! } //! -//! /// You need to define an OPTIONS route for preflight checks. -//! /// These routes can just return the unit type `()` -//! #[options("/response")] -//! fn response_options<'r>(options: State<'r, Cors>) -> impl Responder<'r> { -//! options.inner().respond_borrowed( -//! move |guard| guard.response(Response::new()), -//! ) -//! } -//! //! fn cors_options() -> Cors { //! let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.acme.com"]); //! assert!(failed_origins.is_empty()); @@ -387,11 +369,10 @@ //! "/", //! routes![ //! borrowed, -//! borrowed_options, //! response, -//! response_options, //! ], //! ) +//! .mount("/", rocket_cors::catch_all_options_routes()) // mount the catch all routes //! .manage(cors_options()) //! .launch(); //! } @@ -1725,6 +1706,51 @@ fn actual_request_response(options: &Cors, origin: Origin) -> Response { response } +/// Returns "catch all" OPTIONS routes that you can mount to catch all OPTIONS request. Only works +/// if you have put a `Cors` struct into Rocket's managed state. +/// +/// This route has very high rank (and therefore low priority) of +/// [max value](https://doc.rust-lang.org/nightly/std/primitive.isize.html#method.max_value) +/// so you can define your own to override this route's behaviour. +/// +/// See the documentation at the [crate root](index.html) for usage information. +pub fn catch_all_options_routes() -> Vec { + vec![ + rocket::Route::ranked( + isize::max_value(), + http::Method::Options, + "/", + catch_all_options_route_handler + ), + rocket::Route::ranked( + isize::max_value(), + http::Method::Options, + "/", + catch_all_options_route_handler + ), + ] +} + +/// Handler for the "catch all options route" +fn catch_all_options_route_handler<'r>( + request: &'r Request, + _: rocket::Data, +) -> rocket::handler::Outcome<'r> { + + let guard: Guard = match request.guard() { + Outcome::Success(guard) => guard, + Outcome::Failure((status, _)) => return rocket::handler::Outcome::failure(status), + Outcome::Forward(()) => unreachable!("Should not be reachable"), + }; + + info_!( + "\"Catch all\" handling of CORS `OPTIONS` preflight for request {}", + request + ); + + rocket::handler::Outcome::from(request, guard.responder(())) +} + #[cfg(test)] mod tests { use std::str::FromStr; diff --git a/tests/guard.rs b/tests/guard.rs index 15c9901..bc0b5b1 100644 --- a/tests/guard.rs +++ b/tests/guard.rs @@ -13,11 +13,6 @@ use rocket::http::Method; use rocket::http::{Header, Status}; use rocket::local::Client; -#[options("/")] -fn cors_options(cors: cors::Guard) -> cors::Responder<&str> { - cors.responder("") -} - #[get("/")] fn cors(cors: cors::Guard) -> cors::Responder<&str> { cors.responder("Hello CORS") @@ -28,33 +23,39 @@ fn panicking_route(_cors: cors::Guard) { panic!("This route will panic"); } -// The following routes tests that the routes can be compiled with ad-hoc CORS Response/Responders +/// Manually specify our own OPTIONS route +#[options("/manual")] +fn cors_manual_options(cors: cors::Guard) -> cors::Responder<&str> { + cors.responder("Manual CORS Preflight") +} + +/// Manually specify our own OPTIONS route +#[get("/manual")] +fn cors_manual(cors: cors::Guard) -> cors::Responder<&str> { + cors.responder("Hello CORS") +} /// Using a `Response` instead of a `Responder` -#[allow(unmounted_route)] -#[get("/")] +#[get("/response")] fn response(cors: cors::Guard) -> Response { cors.response(Response::new()) } /// `Responder` with String -#[allow(unmounted_route)] -#[get("/")] +#[get("/responder/string")] fn responder_string(cors: cors::Guard) -> cors::Responder { cors.responder("Hello CORS".to_string()) } /// `Responder` with 'static () -#[allow(unmounted_route)] -#[get("/")] +#[get("/responder/unit")] fn responder_unit(cors: cors::Guard) -> cors::Responder<()> { cors.responder(()) } struct SomeState; /// Borrow `SomeState` from Rocket -#[allow(unmounted_route)] -#[get("/")] +#[get("/state")] fn state<'r>(cors: cors::Guard<'r>, _state: State<'r, SomeState>) -> cors::Responder<'r, &'r str> { cors.responder("hmm") } @@ -74,8 +75,12 @@ fn make_cors_options() -> cors::Cors { fn make_rocket() -> rocket::Rocket { rocket::ignite() - .mount("/", routes![cors, cors_options, panicking_route]) + .mount("/", routes![cors, panicking_route]) + .mount("/", routes![response, responder_string, responder_unit, state]) + .mount("/", cors::catch_all_options_routes()) // mount the catch all routes + .mount("/", routes![cors_manual, cors_manual_options]) // manual OPTIOONS routes .manage(make_cors_options()) + .manage(SomeState) } #[test] @@ -122,8 +127,9 @@ fn smoke_test() { assert_eq!("https://www.acme.com", origin_header); } +/// Check the "catch all" OPTIONS route works for `/` #[test] -fn cors_options_check() { +fn cors_options_catch_all_check() { let rocket = make_rocket(); let client = Client::new(rocket).unwrap(); @@ -153,6 +159,39 @@ fn cors_options_check() { assert_eq!("https://www.acme.com", origin_header); } + +/// Check the "catch all" OPTIONS route works for other routes +#[test] +fn cors_options_catch_all_check_other_routes() { + let rocket = make_rocket(); + let client = Client::new(rocket).unwrap(); + + let origin_header = Header::from( + hyper::header::Origin::from_str("https://www.acme.com").unwrap(), + ); + let method_header = Header::from(hyper::header::AccessControlRequestMethod( + hyper::method::Method::Get, + )); + let request_headers = hyper::header::AccessControlRequestHeaders( + vec![FromStr::from_str("Authorization").unwrap()], + ); + let request_headers = Header::from(request_headers); + let req = client + .options("/response/unit") + .header(origin_header) + .header(method_header) + .header(request_headers); + + let response = req.dispatch(); + assert!(response.status().class().is_success()); + + let origin_header = response + .headers() + .get_one("Access-Control-Allow-Origin") + .expect("to exist"); + assert_eq!("https://www.acme.com", origin_header); +} + #[test] fn cors_get_check() { let rocket = make_rocket(); @@ -360,3 +399,38 @@ fn routes_failing_checks_are_not_executed() { .is_none() ); } + +/// This test ensures that manually mounted CORS OPTIONS routes are used even in the presence of +/// a "catch all" route. +#[test] +fn overridden_options_routes_are_used() { + let rocket = make_rocket(); + let client = Client::new(rocket).unwrap(); + + let origin_header = Header::from( + hyper::header::Origin::from_str("https://www.acme.com").unwrap(), + ); + let method_header = Header::from(hyper::header::AccessControlRequestMethod( + hyper::method::Method::Get, + )); + let request_headers = hyper::header::AccessControlRequestHeaders( + vec![FromStr::from_str("Authorization").unwrap()], + ); + let request_headers = Header::from(request_headers); + let req = client + .options("/manual") + .header(origin_header) + .header(method_header) + .header(request_headers); + + let mut response = req.dispatch(); + let body_str = response.body().and_then(|body| body.into_string()); + assert!(response.status().class().is_success()); + assert_eq!(body_str, Some("Manual CORS Preflight".to_string())); + + let origin_header = response + .headers() + .get_one("Access-Control-Allow-Origin") + .expect("to exist"); + assert_eq!("https://www.acme.com", origin_header); +} diff --git a/tests/manual.rs b/tests/manual.rs index cb21444..bb963b9 100644 --- a/tests/manual.rs +++ b/tests/manual.rs @@ -15,14 +15,6 @@ use rocket::local::Client; use rocket::response::Responder; use rocket_cors::*; -/// Using a borrowed `Cors` -#[options("/")] -fn cors_options<'r>(options: State<'r, Cors>) -> impl Responder<'r> { - options.inner().respond_borrowed( - |guard| guard.responder(()), - ) -} - /// Using a borrowed `Cors` #[get("/")] fn cors<'r>(options: State<'r, Cors>) -> impl Responder<'r> { @@ -31,13 +23,6 @@ fn cors<'r>(options: State<'r, Cors>) -> impl Responder<'r> { ) } -#[options("/panic")] -fn panicking_route_options<'r>(options: State<'r, Cors>) -> impl Responder<'r> { - options.inner().respond_borrowed( - |guard| guard.responder(()), - ) -} - #[get("/panic")] fn panicking_route<'r>(options: State<'r, Cors>) -> impl Responder<'r> { options.inner().respond_borrowed(|_| -> () { @@ -45,6 +30,22 @@ fn panicking_route<'r>(options: State<'r, Cors>) -> impl Responder<'r> { }) } +/// Respond with an owned option instead +#[options("/owned")] +fn owned_options<'r>() -> impl Responder<'r> { + let borrow = make_different_cors_options(); + + borrow.respond_owned(|guard| guard.responder("Manual CORS Preflight")) +} + +/// Respond with an owned option instead +#[get("/owned")] +fn owned<'r>() -> impl Responder<'r> { + let borrow = make_different_cors_options(); + + borrow.respond_owned(|guard| guard.responder("Hello CORS Owned")) +} + // The following routes tests that the routes can be compiled with manual CORS /// `Responder` with String @@ -68,15 +69,6 @@ fn borrow<'r>(options: State<'r, Cors>, test_state: State<'r, TestState>) -> imp }) } -/// Respond with an owned option instead -#[allow(unmounted_route)] -#[get("/")] -fn owned<'r>() -> impl Responder<'r> { - let borrow = make_cors_options(); - - borrow.respond_owned(|guard| guard.responder("Hello CORS")) -} - fn make_cors_options() -> Cors { let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.acme.com"]); assert!(failed_origins.is_empty()); @@ -90,14 +82,25 @@ fn make_cors_options() -> Cors { } } +fn make_different_cors_options() -> Cors { + let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.example.com"]); + assert!(failed_origins.is_empty()); + + Cors { + allowed_origins: allowed_origins, + allowed_methods: vec![Method::Get].into_iter().map(From::from).collect(), + allowed_headers: AllowedHeaders::some(&["Authorization", "Accept"]), + allow_credentials: true, + ..Default::default() + } +} + fn rocket() -> rocket::Rocket { rocket::ignite() - .mount( - "/", - routes![cors, cors_options, panicking_route, panicking_route_options], - ) + .mount("/", routes![cors, panicking_route]) + .mount("/", routes![owned, owned_options]) + .mount("/", catch_all_options_routes()) // mount the catch all routes .manage(make_cors_options()) - .attach(make_cors_options()) } #[test] @@ -144,7 +147,7 @@ fn smoke_test() { } #[test] -fn cors_options_check() { +fn cors_options_borrowed_check() { let client = Client::new(rocket()).unwrap(); let origin_header = Header::from( @@ -174,7 +177,7 @@ fn cors_options_check() { } #[test] -fn cors_get_check() { +fn cors_get_borrowed_check() { let client = Client::new(rocket()).unwrap(); let origin_header = Header::from( @@ -370,3 +373,63 @@ fn routes_failing_checks_are_not_executed() { .is_none() ); } + +/// Manual OPTIONS routes are called +#[test] +fn cors_options_owned_check() { + let rocket = rocket(); + let client = Client::new(rocket).unwrap(); + + let origin_header = Header::from( + hyper::header::Origin::from_str("https://www.example.com").unwrap(), + ); + let method_header = Header::from(hyper::header::AccessControlRequestMethod( + hyper::method::Method::Get, + )); + let request_headers = hyper::header::AccessControlRequestHeaders( + vec![FromStr::from_str("Authorization").unwrap()], + ); + let request_headers = Header::from(request_headers); + let req = client + .options("/owned") + .header(origin_header) + .header(method_header) + .header(request_headers); + + let mut response = req.dispatch(); + let body_str = response.body().and_then(|body| body.into_string()); + assert!(response.status().class().is_success()); + assert_eq!(body_str, Some("Manual CORS Preflight".to_string())); + + let origin_header = response + .headers() + .get_one("Access-Control-Allow-Origin") + .expect("to exist"); + assert_eq!("https://www.example.com", origin_header); +} + +/// Owned manual response works +#[test] +fn cors_get_owned_check() { + let client = Client::new(rocket()).unwrap(); + + let origin_header = Header::from( + hyper::header::Origin::from_str("https://www.example.com").unwrap(), + ); + let authorization = Header::new("Authorization", "let me in"); + let req = client.get("/owned").header(origin_header).header( + authorization, + ); + + let mut response = req.dispatch(); + println!("{:?}", response); + assert!(response.status().class().is_success()); + let body_str = response.body().and_then(|body| body.into_string()); + assert_eq!(body_str, Some("Hello CORS Owned".to_string())); + + let origin_header = response + .headers() + .get_one("Access-Control-Allow-Origin") + .expect("to exist"); + assert_eq!("https://www.example.com", origin_header); +}