From 92d7775b938844ab66b4bfaf37fd6ad864f87214 Mon Sep 17 00:00:00 2001 From: Yong Wen Chua Date: Mon, 24 Jul 2017 13:11:10 +0800 Subject: [PATCH] "Truly manual" API (#22) * Experimental "truly manual" API * Add API for general usage * Add documentation and example Change Fn to FnOnce to allow for moving and consuming --- examples/fairing.rs | 2 +- examples/guard.rs | 6 +- examples/manual.rs | 94 +++++++++++ src/lib.rs | 320 ++++++++++++++++++++++++++++++++++--- tests/fairing.rs | 2 +- tests/guard.rs | 2 +- tests/headers.rs | 2 +- tests/manual.rs | 372 ++++++++++++++++++++++++++++++++++++++++++++ 8 files changed, 769 insertions(+), 31 deletions(-) create mode 100644 examples/manual.rs create mode 100644 tests/manual.rs diff --git a/examples/fairing.rs b/examples/fairing.rs index e666270..eca79b1 100644 --- a/examples/fairing.rs +++ b/examples/fairing.rs @@ -1,4 +1,4 @@ -#![feature(plugin, custom_derive)] +#![feature(plugin)] #![plugin(rocket_codegen)] extern crate rocket; extern crate rocket_cors; diff --git a/examples/guard.rs b/examples/guard.rs index 3b7c7b3..e014e8b 100644 --- a/examples/guard.rs +++ b/examples/guard.rs @@ -1,4 +1,4 @@ -#![feature(plugin, custom_derive)] +#![feature(plugin)] #![plugin(rocket_codegen)] extern crate rocket; extern crate rocket_cors; @@ -23,7 +23,7 @@ fn responder_options(cors: Guard) -> Responder<()> { } /// Using a `Response` instead of a `Responder`. You generally won't have to do this. -#[get("/responder")] +#[get("/response")] fn response(cors: Guard) -> Response { let mut response = Response::new(); response.set_sized_body(Cursor::new("Hello CORS!")); @@ -32,7 +32,7 @@ fn response(cors: Guard) -> Response { /// You need to define an OPTIONS route for preflight checks. /// These routes can just return the unit type `()` -#[options("/responder")] +#[options("/response")] fn response_options(cors: Guard) -> Response { let response = Response::new(); cors.response(response) diff --git a/examples/manual.rs b/examples/manual.rs new file mode 100644 index 0000000..67a3e65 --- /dev/null +++ b/examples/manual.rs @@ -0,0 +1,94 @@ +#![feature(plugin, conservative_impl_trait)] +#![plugin(rocket_codegen)] +extern crate rocket; +extern crate rocket_cors; + +use std::io::Cursor; + +use rocket::{State, Response}; +use rocket::http::Method; +use rocket::response::Responder; +use rocket_cors::{Cors, AllowedOrigins, AllowedHeaders}; + +/// Using a borrowed Cors +#[get("/")] +fn borrowed<'r>(options: State<'r, Cors>) -> impl Responder<'r> { + options.inner().respond_borrowed( + |guard| guard.responder("Hello CORS"), + ) +} + +/// You need to define an OPTIONS route for preflight checks. +/// These routes can just return the unit type `()` +#[options("/")] +fn borrowed_options<'r>(options: State<'r, Cors>) -> impl Responder<'r> { + options.inner().respond_borrowed( + |guard| guard.responder(()), + ) +} + +/// Using a `Response` instead of a `Responder`. You generally won't have to do this. +#[get("/response")] +fn response<'r>(options: State<'r, Cors>) -> impl Responder<'r> { + let mut response = Response::new(); + response.set_sized_body(Cursor::new("Hello CORS!")); + + options.inner().respond_borrowed( + move |guard| guard.response(response), + ) +} + +/// You need to define an OPTIONS route for preflight checks. +/// These routes can just return the unit type `()` +#[options("/response")] +fn response_options<'r>(options: State<'r, Cors>) -> impl Responder<'r> { + options.inner().respond_borrowed( + move |guard| guard.response(Response::new()), + ) +} + +/// Create and use an ad-hoc Cors +#[get("/owned")] +fn owned<'r>() -> impl Responder<'r> { + let options = cors_options(); + options.respond_owned(|guard| guard.responder("Hello CORS")) +} + +/// You need to define an OPTIONS route for preflight checks. +/// These routes can just return the unit type `()` +#[options("/owned")] +fn owned_options<'r>() -> impl Responder<'r> { + let options = cors_options(); + options.respond_owned(|guard| guard.responder(())) +} + +fn cors_options() -> Cors { + let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.acme.com"]); + assert!(failed_origins.is_empty()); + + // You can also deserialize this + rocket_cors::Cors { + allowed_origins: allowed_origins, + allowed_methods: vec![Method::Get].into_iter().map(From::from).collect(), + allowed_headers: AllowedHeaders::some(&["Authorization", "Accept"]), + allow_credentials: true, + ..Default::default() + } +} + +fn main() { + rocket::ignite() + .mount( + "/", + routes![ + borrowed, + borrowed_options, + response, + response_options, + owned, + owned_options, + ], + ) + .manage(cors_options()) + .launch(); +} diff --git a/src/lib.rs b/src/lib.rs index f3113fc..9704943 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -61,7 +61,7 @@ //! //! - Fairing (should only used exclusively) //! - Request Guard -//! - Truly Manual (not supported yet, [#13](https://github.com/lawliet89/rocket_cors/issues/13)) +//! - Truly Manual //! //! Unfortunately, you cannot mix and match Fairing with any other of the methods, due to the //! limitation of Rocket's fairing API. That is, the checks for Fairing will always happen first, @@ -213,7 +213,189 @@ //! //! ## Truly Manual //! -//! This is not supported yet. See [#13](https://github.com/lawliet89/rocket_cors/issues/13). +//! This mode is the most difficult to use but offers the most amount of flexibility. +//! You might have to understand how the library works internally to know how to use this mode. +//! In exchange, you can selectively choose which routes to offer CORS protection to, and you +//! can mix and match CORS settings for the routes. You can combine usage of this mode with +//! "guard" to offer a mix of ease of use and flexibility. +//! +//! You really do not need to use this unless you have a truly ad-hoc need to respond to CORS +//! differently in a route. For example, you have a `ping` endpoint that allows all origins but +//! the rest of your routes do not. +//! +//! ### Handler +//! +//! This mode requires that you pass in a closure that will be lazily evaluated once a CORS request +//! has been validated. If validation fails, the closure will not be run. You should put any code +//! that has any side effects or with an appreciable computation cost inside this handler. +//! +//! ### Steps to perform: +//! - Your crate will need to enable the +//! [`conservative_impl_trait`](https://github.com/rust-lang/rfcs/blob/master/text/1522-conservative-impl-trait.md) +//! feature. You can use `#![feature(conservative_impl_trait)]` at your crate root. +//! Otherwise, the return type of your routes will be unspecifiable. +//! - You will first need to have a `Cors` struct ready. This struct can be borrowed with a lifetime +//! at least as long as `'r` which is the lifetime of a Rocket request. `'static` works too. +//! In this case, you might as well use the `Guard` method above and place the `Cors` struct in +//! Rocket's [state](https://rocket.rs/guide/state/). +//! Alternatively, you can create a `Cors` struct directly in the route. +//! - Your routes will need to have a `'r` lifetime and return `impl Responder<'r>`. +//! - Using the `Cors` struct, use either the +//! [`respond_owned`](struct.Cors.html#method.respond_owned) or +//! [`respond_borrowed`](struct.Cors.html#method.respond_borrowed) function and pass in a handler +//! that will be executed once CORS validation is successful. +//! - Your handler will be passed a [`Guard`](struct.Guard.html) which you will have to use to +//! add CORS headers into your own response. +//! - You will have to manually define your own `OPTIONS` routes. +//! +//! ### Notes about route lifetime +//! It is unfortunate that you have to manually specify the `'r` lifetimes in your routes. +//! Leaving out the lifetime will result in a +//! [compiler panic](https://github.com/rust-lang/rust/issues/43380). Even if the panic is fixed, +//! it is not known if we can exclude the lifetime because lifetimes are _elided_ in Rust, +//! not inferred. +//! +//! ### Owned example +//! This is the most likely scenario when you want to have manual CORS validation. You can use this +//! when the settings you want to use for a route is not the same as the rest of the application +//! (which you might have put in Rocket's state). +//! +//! ```rust,no_run +//! #![feature(plugin, conservative_impl_trait)] +//! #![plugin(rocket_codegen)] +//! extern crate rocket; +//! extern crate rocket_cors; +//! +//! use rocket::http::Method; +//! use rocket::response::Responder; +//! use rocket_cors::{Cors, AllowedOrigins, AllowedHeaders}; +//! +//! /// Create and use an ad-hoc Cors +//! #[get("/owned")] +//! fn owned<'r>() -> impl Responder<'r> { +//! let options = cors_options(); +//! options.respond_owned(|guard| guard.responder("Hello CORS")) +//! } +//! +//! /// You need to define an OPTIONS route for preflight checks. +//! /// These routes can just return the unit type `()` +//! #[options("/owned")] +//! fn owned_options<'r>() -> impl Responder<'r> { +//! let options = cors_options(); +//! options.respond_owned(|guard| guard.responder(())) +//! } +//! +//! fn cors_options() -> Cors { +//! let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.acme.com"]); +//! assert!(failed_origins.is_empty()); +//! +//! // You can also deserialize this +//! rocket_cors::Cors { +//! allowed_origins: allowed_origins, +//! allowed_methods: vec![Method::Get].into_iter().map(From::from).collect(), +//! allowed_headers: AllowedHeaders::some(&["Authorization", "Accept"]), +//! allow_credentials: true, +//! ..Default::default() +//! } +//! } +//! +//! fn main() { +//! rocket::ignite() +//! .mount( +//! "/", +//! routes![ +//! owned, +//! owned_options, +//! ], +//! ) +//! .manage(cors_options()) +//! .launch(); +//! } +//! ``` +//! +//! ### Borrowed Example +//! You might want to borrow the `Cors` struct from Rocket's state, for example. Unless you have +//! special handling, you might want to use the Guard method instead which has less hassle. +//! +//! ```rust,no_run +//! #![feature(plugin, conservative_impl_trait)] +//! #![plugin(rocket_codegen)] +//! extern crate rocket; +//! extern crate rocket_cors; +//! +//! use std::io::Cursor; +//! +//! use rocket::{State, Response}; +//! use rocket::http::Method; +//! use rocket::response::Responder; +//! use rocket_cors::{Cors, AllowedOrigins, AllowedHeaders}; +//! +//! /// Using a borrowed Cors +//! #[get("/")] +//! fn borrowed<'r>(options: State<'r, Cors>) -> impl Responder<'r> { +//! options.inner().respond_borrowed( +//! |guard| guard.responder("Hello CORS"), +//! ) +//! } +//! +//! /// You need to define an OPTIONS route for preflight checks. +//! /// These routes can just return the unit type `()` +//! #[options("/")] +//! fn borrowed_options<'r>(options: State<'r, Cors>) -> impl Responder<'r> { +//! options.inner().respond_borrowed( +//! |guard| guard.responder(()), +//! ) +//! } +//! +//! /// Using a `Response` instead of a `Responder`. You generally won't have to do this. +//! #[get("/response")] +//! fn response<'r>(options: State<'r, Cors>) -> impl Responder<'r> { +//! let mut response = Response::new(); +//! response.set_sized_body(Cursor::new("Hello CORS!")); +//! +//! options.inner().respond_borrowed( +//! move |guard| guard.response(response), +//! ) +//! } +//! +//! /// You need to define an OPTIONS route for preflight checks. +//! /// These routes can just return the unit type `()` +//! #[options("/response")] +//! fn response_options<'r>(options: State<'r, Cors>) -> impl Responder<'r> { +//! options.inner().respond_borrowed( +//! move |guard| guard.response(Response::new()), +//! ) +//! } +//! +//! fn cors_options() -> Cors { +//! let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.acme.com"]); +//! assert!(failed_origins.is_empty()); +//! +//! // You can also deserialize this +//! rocket_cors::Cors { +//! allowed_origins: allowed_origins, +//! allowed_methods: vec![Method::Get].into_iter().map(From::from).collect(), +//! allowed_headers: AllowedHeaders::some(&["Authorization", "Accept"]), +//! allow_credentials: true, +//! ..Default::default() +//! } +//! } +//! +//! fn main() { +//! rocket::ignite() +//! .mount( +//! "/", +//! routes![ +//! borrowed, +//! borrowed_options, +//! response, +//! response_options, +//! ], +//! ) +//! .manage(cors_options()) +//! .launch(); +//! } +//! ``` #![allow( legacy_directory_ownership, @@ -266,7 +448,7 @@ while_true, )] -#![cfg_attr(test, feature(plugin, custom_derive))] +#![cfg_attr(test, feature(plugin))] #![cfg_attr(test, plugin(rocket_codegen))] #![doc(test(attr(allow(unused_variables), deny(warnings))))] @@ -296,6 +478,7 @@ mod fairing; pub mod headers; +use std::borrow::Cow; use std::collections::{HashSet, HashMap}; use std::error; use std::fmt; @@ -826,19 +1009,6 @@ impl Cors { "/cors".to_string() } - /// Validate a request and then return a CORS Response - /// - /// You will usually not have to use this function but simply place a r - /// equest guard in the route argument for the `Guard` type. - /// - /// This is useful if you want an even more ad-hoc based approach to respond to - /// CORS by using a `Cors` that is not in Rocket's managed state. - #[doc(hidden)] // Need to figure out a way to do this - pub fn validate_request<'a, 'r>(&'a self, request: &'a Request<'r>) -> Result { - let response = validate_and_build(self, request)?; - Ok(response) - } - /// Validates if any of the settings are disallowed or incorrect /// /// This is run during initial Fairing attachment @@ -849,8 +1019,52 @@ impl Cors { Ok(()) } + + /// Manually respond to a request with CORS checks and headers using an Owned `Cors`. + /// + /// Use this variant when your `Cors` struct will not live at least as long as the whole `'r` + /// lifetime of the request. + /// + /// After the CORS checks are done, the passed in handler closure will be run to generate a + /// final response. You will have to merge your response with the `Guard` that you have been + /// passed in to include the CORS headers. + /// + /// See the documentation at the [crate root](index.html) for usage information. + pub fn respond_owned<'r, F, R>(self, handler: F) -> Result, Error> + where + F: FnOnce(Guard<'r>) -> R + 'r, + R: response::Responder<'r>, + { + self.validate()?; + Ok(ManualResponder::new(Cow::Owned(self), handler)) + } + + /// Manually respond to a request with CORS checks and headers using a borrowed `Cors`. + /// + /// Use this variant when your `Cors` struct will live at least as long as the whole `'r` + /// lifetime of the request. If you are getting your `Cors` from Rocket's state, you will have + /// to use the [`inner` function](https://api.rocket.rs/rocket/struct.State.html#method.inner) + /// to get a longer borrowed lifetime. + /// + /// After the CORS checks are done, the passed in handler closure will be run to generate a + /// final response. You will have to merge your response with the `Guard` that you have been + /// passed in to include the CORS headers. + /// + /// See the documentation at the [crate root](index.html) for usage information. + pub fn respond_borrowed<'r, F, R>( + &'r self, + handler: F, + ) -> Result, Error> + where + F: FnOnce(Guard<'r>) -> R + 'r, + R: response::Responder<'r>, + { + self.validate()?; + Ok(ManualResponder::new(Cow::Borrowed(self), handler)) + } } + /// A CORS Response which provides the following CORS headers: /// /// - `Access-Control-Allow-Origin` @@ -1135,6 +1349,56 @@ impl<'r, R: response::Responder<'r>> response::Responder<'r> for Responder<'r, R } } +/// A Manual Responder used in the "truly manual" mode of operation. +/// +/// See the documentation at the [crate root](index.html) for usage information. +pub struct ManualResponder<'r, F, R> { + options: Cow<'r, Cors>, + handler: F, + marker: PhantomData, +} + +impl<'r, F, R> ManualResponder<'r, F, R> +where + F: FnOnce(Guard<'r>) -> R + 'r, + R: response::Responder<'r>, +{ + /// Create a new manual responder by passing in either a borrowed or owned `Cors` option. + /// + /// A borrowed `Cors` option must live for the entirety of the `'r` lifetime which is the + /// lifetime of the entire Rocket request. + fn new(options: Cow<'r, Cors>, handler: F) -> Self { + let marker = PhantomData; + Self { + options, + handler, + marker, + } + } + + fn build_guard(&self, request: &Request) -> Result, Error> { + let response = Response::validate_and_build(&self.options, request)?; + Ok(Guard::new(response)) + } +} + +impl<'r, F, R> response::Responder<'r> for ManualResponder<'r, F, R> +where + F: FnOnce(Guard<'r>) -> R + 'r, + R: response::Responder<'r>, +{ + fn respond_to(self, request: &Request) -> response::Result<'r> { + let guard = match self.build_guard(request) { + Ok(guard) => guard, + Err(err) => { + error_!("CORS error: {}", err); + return Err(err.status()); + } + }; + (self.handler)(guard).respond_to(request) + } +} + /// Result of CORS validation. /// /// The variants hold enough information to build a response to the validation result @@ -1564,8 +1828,10 @@ mod tests { fn response_sets_exposed_headers_correctly() { let headers = vec!["Bar", "Baz", "Foo"]; let response = Response::new(); - let response = - response.origin(&FromStr::from_str("https://www.example.com").unwrap(), false); + let response = response.origin( + &FromStr::from_str("https://www.example.com").unwrap(), + false, + ); let response = response.exposed_headers(&headers); // Build response and check built response header @@ -1587,8 +1853,10 @@ mod tests { #[test] fn response_sets_max_age_correctly() { let response = Response::new(); - let response = - response.origin(&FromStr::from_str("https://www.example.com").unwrap(), false); + let response = response.origin( + &FromStr::from_str("https://www.example.com").unwrap(), + false, + ); let response = response.max_age(Some(42)); @@ -1602,8 +1870,10 @@ mod tests { #[test] fn response_does_not_set_max_age_when_none() { let response = Response::new(); - let response = - response.origin(&FromStr::from_str("https://www.example.com").unwrap(), false); + let response = response.origin( + &FromStr::from_str("https://www.example.com").unwrap(), + false, + ); let response = response.max_age(None); @@ -1717,8 +1987,10 @@ mod tests { .finalize(); let response = Response::new(); - let response = - response.origin(&FromStr::from_str("https://www.example.com").unwrap(), false); + let response = response.origin( + &FromStr::from_str("https://www.example.com").unwrap(), + false, + ); let response = response.response(original); // Check CORS header let expected_header = vec!["https://www.example.com"]; diff --git a/tests/fairing.rs b/tests/fairing.rs index c54f3d9..9d0c365 100644 --- a/tests/fairing.rs +++ b/tests/fairing.rs @@ -1,6 +1,6 @@ //! This crate tests using rocket_cors using Fairings -#![feature(plugin, custom_derive)] +#![feature(plugin)] #![plugin(rocket_codegen)] extern crate hyper; extern crate rocket; diff --git a/tests/guard.rs b/tests/guard.rs index ec17d8d..15c9901 100644 --- a/tests/guard.rs +++ b/tests/guard.rs @@ -1,6 +1,6 @@ //! This crate tests using rocket_cors using the per-route handling with request guard -#![feature(plugin, custom_derive)] +#![feature(plugin)] #![plugin(rocket_codegen)] extern crate hyper; extern crate rocket; diff --git a/tests/headers.rs b/tests/headers.rs index 544d823..573197a 100644 --- a/tests/headers.rs +++ b/tests/headers.rs @@ -1,5 +1,5 @@ //! This crate tests that all the request headers are parsed correctly in the round trip -#![feature(plugin, custom_derive)] +#![feature(plugin)] #![plugin(rocket_codegen)] extern crate hyper; extern crate rocket; diff --git a/tests/manual.rs b/tests/manual.rs new file mode 100644 index 0000000..cb21444 --- /dev/null +++ b/tests/manual.rs @@ -0,0 +1,372 @@ +//! This crate tests using rocket_cors using manual mode + +#![feature(plugin, conservative_impl_trait)] +#![plugin(rocket_codegen)] +extern crate hyper; +extern crate rocket; +extern crate rocket_cors; + +use std::str::FromStr; + +use rocket::State; +use rocket::http::Method; +use rocket::http::{Header, Status}; +use rocket::local::Client; +use rocket::response::Responder; +use rocket_cors::*; + +/// Using a borrowed `Cors` +#[options("/")] +fn cors_options<'r>(options: State<'r, Cors>) -> impl Responder<'r> { + options.inner().respond_borrowed( + |guard| guard.responder(()), + ) +} + +/// Using a borrowed `Cors` +#[get("/")] +fn cors<'r>(options: State<'r, Cors>) -> impl Responder<'r> { + options.inner().respond_borrowed( + |guard| guard.responder("Hello CORS"), + ) +} + +#[options("/panic")] +fn panicking_route_options<'r>(options: State<'r, Cors>) -> impl Responder<'r> { + options.inner().respond_borrowed( + |guard| guard.responder(()), + ) +} + +#[get("/panic")] +fn panicking_route<'r>(options: State<'r, Cors>) -> impl Responder<'r> { + options.inner().respond_borrowed(|_| -> () { + panic!("This route will panic"); + }) +} + +// The following routes tests that the routes can be compiled with manual CORS + +/// `Responder` with String +#[allow(unmounted_route)] +#[get("/")] +fn responder_string<'r>(options: State<'r, Cors>) -> impl Responder<'r> { + options.inner().respond_borrowed(|guard| { + guard.responder("Hello CORS".to_string()) + }) +} + +struct TestState; +/// Borrow something else from Rocket with lifetime `'r` +#[allow(unmounted_route)] +#[get("/")] +fn borrow<'r>(options: State<'r, Cors>, test_state: State<'r, TestState>) -> impl Responder<'r> { + let borrow = test_state.inner(); + options.inner().respond_borrowed(move |guard| { + let _ = borrow; + guard.responder("Hello CORS".to_string()) + }) +} + +/// Respond with an owned option instead +#[allow(unmounted_route)] +#[get("/")] +fn owned<'r>() -> impl Responder<'r> { + let borrow = make_cors_options(); + + borrow.respond_owned(|guard| guard.responder("Hello CORS")) +} + +fn make_cors_options() -> Cors { + let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.acme.com"]); + assert!(failed_origins.is_empty()); + + Cors { + allowed_origins: allowed_origins, + allowed_methods: vec![Method::Get].into_iter().map(From::from).collect(), + allowed_headers: AllowedHeaders::some(&["Authorization", "Accept"]), + allow_credentials: true, + ..Default::default() + } +} + +fn rocket() -> rocket::Rocket { + rocket::ignite() + .mount( + "/", + routes![cors, cors_options, panicking_route, panicking_route_options], + ) + .manage(make_cors_options()) + .attach(make_cors_options()) +} + +#[test] +fn smoke_test() { + let client = Client::new(rocket()).unwrap(); + + // `Options` pre-flight checks + let origin_header = Header::from( + hyper::header::Origin::from_str("https://www.acme.com").unwrap(), + ); + let method_header = Header::from(hyper::header::AccessControlRequestMethod( + hyper::method::Method::Get, + )); + let request_headers = hyper::header::AccessControlRequestHeaders( + vec![FromStr::from_str("Authorization").unwrap()], + ); + let request_headers = Header::from(request_headers); + let req = client + .options("/") + .header(origin_header) + .header(method_header) + .header(request_headers); + + let response = req.dispatch(); + assert!(response.status().class().is_success()); + + // "Actual" request + let origin_header = Header::from( + hyper::header::Origin::from_str("https://www.acme.com").unwrap(), + ); + let authorization = Header::new("Authorization", "let me in"); + let req = client.get("/").header(origin_header).header(authorization); + + let mut response = req.dispatch(); + assert!(response.status().class().is_success()); + let body_str = response.body().and_then(|body| body.into_string()); + assert_eq!(body_str, Some("Hello CORS".to_string())); + + let origin_header = response + .headers() + .get_one("Access-Control-Allow-Origin") + .expect("to exist"); + assert_eq!("https://www.acme.com", origin_header); +} + +#[test] +fn cors_options_check() { + let client = Client::new(rocket()).unwrap(); + + let origin_header = Header::from( + hyper::header::Origin::from_str("https://www.acme.com").unwrap(), + ); + let method_header = Header::from(hyper::header::AccessControlRequestMethod( + hyper::method::Method::Get, + )); + let request_headers = hyper::header::AccessControlRequestHeaders( + vec![FromStr::from_str("Authorization").unwrap()], + ); + let request_headers = Header::from(request_headers); + let req = client + .options("/") + .header(origin_header) + .header(method_header) + .header(request_headers); + + let response = req.dispatch(); + assert!(response.status().class().is_success()); + + let origin_header = response + .headers() + .get_one("Access-Control-Allow-Origin") + .expect("to exist"); + assert_eq!("https://www.acme.com", origin_header); +} + +#[test] +fn cors_get_check() { + let client = Client::new(rocket()).unwrap(); + + let origin_header = Header::from( + hyper::header::Origin::from_str("https://www.acme.com").unwrap(), + ); + let authorization = Header::new("Authorization", "let me in"); + let req = client.get("/").header(origin_header).header(authorization); + + let mut response = req.dispatch(); + println!("{:?}", response); + assert!(response.status().class().is_success()); + let body_str = response.body().and_then(|body| body.into_string()); + assert_eq!(body_str, Some("Hello CORS".to_string())); + + let origin_header = response + .headers() + .get_one("Access-Control-Allow-Origin") + .expect("to exist"); + assert_eq!("https://www.acme.com", origin_header); +} + +/// This test is to check that non CORS compliant requests to GET should still work. (i.e. curl) +#[test] +fn cors_get_no_origin() { + let client = Client::new(rocket()).unwrap(); + + let authorization = Header::new("Authorization", "let me in"); + let req = client.get("/").header(authorization); + + let mut response = req.dispatch(); + assert!(response.status().class().is_success()); + let body_str = response.body().and_then(|body| body.into_string()); + assert_eq!(body_str, Some("Hello CORS".to_string())); +} + +#[test] +fn cors_options_bad_origin() { + let client = Client::new(rocket()).unwrap(); + + let origin_header = Header::from( + hyper::header::Origin::from_str("https://www.bad-origin.com").unwrap(), + ); + let method_header = Header::from(hyper::header::AccessControlRequestMethod( + hyper::method::Method::Get, + )); + let request_headers = hyper::header::AccessControlRequestHeaders( + vec![FromStr::from_str("Authorization").unwrap()], + ); + let request_headers = Header::from(request_headers); + let req = client + .options("/") + .header(origin_header) + .header(method_header) + .header(request_headers); + + let response = req.dispatch(); + assert_eq!(response.status(), Status::Forbidden); +} + +#[test] +fn cors_options_missing_origin() { + let client = Client::new(rocket()).unwrap(); + + let method_header = Header::from(hyper::header::AccessControlRequestMethod( + hyper::method::Method::Get, + )); + let request_headers = hyper::header::AccessControlRequestHeaders( + vec![FromStr::from_str("Authorization").unwrap()], + ); + let request_headers = Header::from(request_headers); + let req = client.options("/").header(method_header).header( + request_headers, + ); + + let response = req.dispatch(); + assert!(response.status().class().is_success()); + assert!( + response + .headers() + .get_one("Access-Control-Allow-Origin") + .is_none() + ); +} + +#[test] +fn cors_options_bad_request_method() { + let client = Client::new(rocket()).unwrap(); + + let origin_header = Header::from( + hyper::header::Origin::from_str("https://www.acme.com").unwrap(), + ); + let method_header = Header::from(hyper::header::AccessControlRequestMethod( + hyper::method::Method::Post, + )); + let request_headers = hyper::header::AccessControlRequestHeaders( + vec![FromStr::from_str("Authorization").unwrap()], + ); + let request_headers = Header::from(request_headers); + let req = client + .options("/") + .header(origin_header) + .header(method_header) + .header(request_headers); + + let response = req.dispatch(); + assert_eq!(response.status(), Status::Forbidden); + assert!( + response + .headers() + .get_one("Access-Control-Allow-Origin") + .is_none() + ); +} + +#[test] +fn cors_options_bad_request_header() { + let client = Client::new(rocket()).unwrap(); + + let origin_header = Header::from( + hyper::header::Origin::from_str("https://www.acme.com").unwrap(), + ); + let method_header = Header::from(hyper::header::AccessControlRequestMethod( + hyper::method::Method::Get, + )); + let request_headers = + hyper::header::AccessControlRequestHeaders(vec![FromStr::from_str("Foobar").unwrap()]); + let request_headers = Header::from(request_headers); + let req = client + .options("/") + .header(origin_header) + .header(method_header) + .header(request_headers); + + let response = req.dispatch(); + assert_eq!(response.status(), Status::Forbidden); + assert!( + response + .headers() + .get_one("Access-Control-Allow-Origin") + .is_none() + ); +} + +#[test] +fn cors_get_bad_origin() { + let client = Client::new(rocket()).unwrap(); + + let origin_header = Header::from( + hyper::header::Origin::from_str("https://www.bad-origin.com").unwrap(), + ); + let authorization = Header::new("Authorization", "let me in"); + let req = client.get("/").header(origin_header).header(authorization); + + let response = req.dispatch(); + assert_eq!(response.status(), Status::Forbidden); + assert!( + response + .headers() + .get_one("Access-Control-Allow-Origin") + .is_none() + ); +} + +/// This test ensures that on a failing CORS request, the route (along with its side effects) +/// should never be executed. +/// The route used will panic if executed +#[test] +fn routes_failing_checks_are_not_executed() { + let client = Client::new(rocket()).unwrap(); + + let origin_header = Header::from( + hyper::header::Origin::from_str("https://www.bad-origin.com").unwrap(), + ); + let method_header = Header::from(hyper::header::AccessControlRequestMethod( + hyper::method::Method::Get, + )); + let request_headers = hyper::header::AccessControlRequestHeaders( + vec![FromStr::from_str("Authorization").unwrap()], + ); + let request_headers = Header::from(request_headers); + let req = client + .options("/panic") + .header(origin_header) + .header(method_header) + .header(request_headers); + + let response = req.dispatch(); + assert_eq!(response.status(), Status::Forbidden); + assert!( + response + .headers() + .get_one("Access-Control-Allow-Origin") + .is_none() + ); +}