diff --git a/src/lib.rs b/src/lib.rs index f43e024..ecc19e3 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -266,7 +266,7 @@ while_true, )] -#![cfg_attr(test, feature(plugin, custom_derive))] +#![cfg_attr(test, feature(plugin))] #![cfg_attr(test, plugin(rocket_codegen))] #![doc(test(attr(allow(unused_variables), deny(warnings))))] @@ -837,8 +837,52 @@ impl Cors { Ok(()) } + + /// Manually respond to a request with CORS checks and headers using an Owned `Cors`. + /// + /// Use this variant when your `Cors` struct will not live at least as long as the whole `'r` + /// lifetime of the request. + /// + /// After the CORS checks are done, the passed in handler closure will be run to generate a + /// final response. You will have to merge your response with the `Guard` that you have been + /// passed in to include the CORS headers. + /// + /// See the documentation at the [crate root](index.html) for usage information. + pub fn respond_owned<'r, F, R>(self, handler: F) -> Result, Error> + where + F: Fn(Guard<'r>) -> R + 'r, + R: response::Responder<'r>, + { + self.validate()?; + Ok(ManualResponder::new(Cow::Owned(self), handler)) + } + + /// Manually respond to a request with CORS checks and headers using a borrowed `Cors`. + /// + /// Use this variant when your `Cors` struct will live at least as long as the whole `'r` + /// lifetime of the request. If you are getting your `Cors` from Rocket's state, you will have + /// to use the [`inner` function](https://api.rocket.rs/rocket/struct.State.html#method.inner) + /// to get a longer borrowed lifetime. + /// + /// After the CORS checks are done, the passed in handler closure will be run to generate a + /// final response. You will have to merge your response with the `Guard` that you have been + /// passed in to include the CORS headers. + /// + /// See the documentation at the [crate root](index.html) for usage information. + pub fn respond_borrowed<'r, F, R>( + &'r self, + handler: F, + ) -> Result, Error> + where + F: Fn(Guard<'r>) -> R + 'r, + R: response::Responder<'r>, + { + self.validate()?; + Ok(ManualResponder::new(Cow::Borrowed(self), handler)) + } } + /// A CORS Response which provides the following CORS headers: /// /// - `Access-Control-Allow-Origin` @@ -1123,28 +1167,31 @@ impl<'r, R: response::Responder<'r>> response::Responder<'r> for Responder<'r, R } } -/// The type of closure that will be used to generate a Rocket `Responder` -/// after passing CORS checks. -/// This is used in the "truly manual" mode of CORS handling. -/// -/// See the documentation at the [crate root](index.html) for usage information. -pub type GuardHandler<'r, R> = Fn(Guard<'r>) -> R; - /// A Manual Responder used in the "truly manual" mode of operation. /// /// See the documentation at the [crate root](index.html) for usage information. -pub struct ManualResponder<'r, R> { +pub struct ManualResponder<'r, F, R> { options: Cow<'r, Cors>, - handler: Box>, + handler: F, + marker: PhantomData, } -impl<'r, R: response::Responder<'r>> ManualResponder<'r, R> { +impl<'r, F, R> ManualResponder<'r, F, R> +where + F: Fn(Guard<'r>) -> R + 'r, + R: response::Responder<'r>, +{ /// Create a new manual responder by passing in either a borrowed or owned `Cors` option. /// /// A borrowed `Cors` option must live for the entirety of the `'r` lifetime which is the /// lifetime of the entire Rocket request. - pub fn new(options: Cow<'r, Cors>, handler: Box>) -> Self { - Self { options, handler } + fn new(options: Cow<'r, Cors>, handler: F) -> Self { + let marker = PhantomData; + Self { + options, + handler, + marker, + } } fn build_guard(&self, request: &Request) -> Result, Error> { @@ -1153,7 +1200,11 @@ impl<'r, R: response::Responder<'r>> ManualResponder<'r, R> { } } -impl<'r, R: response::Responder<'r>> response::Responder<'r> for ManualResponder<'r, R> { +impl<'r, F, R> response::Responder<'r> for ManualResponder<'r, F, R> +where + F: Fn(Guard<'r>) -> R + 'r, + R: response::Responder<'r>, +{ fn respond_to(self, request: &Request) -> response::Result<'r> { let guard = match self.build_guard(request) { Ok(guard) => guard, diff --git a/tests/manual.rs b/tests/manual.rs new file mode 100644 index 0000000..bea020b --- /dev/null +++ b/tests/manual.rs @@ -0,0 +1,372 @@ +//! This crate tests using rocket_cors using manual mode + +#![feature(plugin, custom_derive, conservative_impl_trait)] +#![plugin(rocket_codegen)] +extern crate hyper; +extern crate rocket; +extern crate rocket_cors; + +use std::str::FromStr; + +use rocket::State; +use rocket::http::Method; +use rocket::http::{Header, Status}; +use rocket::local::Client; +use rocket::response::Responder; +use rocket_cors::*; + +/// Using a borrowed `Cors` +#[options("/")] +fn cors_options<'r>(options: State<'r, Cors>) -> impl Responder<'r> { + options.inner().respond_borrowed( + |guard| guard.responder(()), + ) +} + +/// Using a borrowed `Cors` +#[get("/")] +fn cors<'r>(options: State<'r, Cors>) -> impl Responder<'r> { + options.inner().respond_borrowed( + |guard| guard.responder("Hello CORS"), + ) +} + +#[options("/panic")] +fn panicking_route_options<'r>(options: State<'r, Cors>) -> impl Responder<'r> { + options.inner().respond_borrowed( + |guard| guard.responder(()), + ) +} + +#[get("/panic")] +fn panicking_route<'r>(options: State<'r, Cors>) -> impl Responder<'r> { + options.inner().respond_borrowed(|_| -> () { + panic!("This route will panic"); + }) +} + +// The following routes tests that the routes can be compiled with manual CORS + +/// `Responder` with String +#[allow(unmounted_route)] +#[get("/")] +fn responder_string<'r>(options: State<'r, Cors>) -> impl Responder<'r> { + options.inner().respond_borrowed(|guard| { + guard.responder("Hello CORS".to_string()) + }) +} + +struct TestState; +/// Borrow something else from Rocket with lifetime `'r` +#[allow(unmounted_route)] +#[get("/")] +fn borrow<'r>(options: State<'r, Cors>, test_state: State<'r, TestState>) -> impl Responder<'r> { + let borrow = test_state.inner(); + options.inner().respond_borrowed(move |guard| { + let _ = borrow; + guard.responder("Hello CORS".to_string()) + }) +} + +/// Respond with an owned option instead +#[allow(unmounted_route)] +#[get("/")] +fn owned<'r>() -> impl Responder<'r> { + let borrow = make_cors_options(); + + borrow.respond_owned(|guard| guard.responder("Hello CORS")) +} + +fn make_cors_options() -> Cors { + let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.acme.com"]); + assert!(failed_origins.is_empty()); + + Cors { + allowed_origins: allowed_origins, + allowed_methods: vec![Method::Get].into_iter().map(From::from).collect(), + allowed_headers: AllowedHeaders::some(&["Authorization", "Accept"]), + allow_credentials: true, + ..Default::default() + } +} + +fn rocket() -> rocket::Rocket { + rocket::ignite() + .mount( + "/", + routes![cors, cors_options, panicking_route, panicking_route_options], + ) + .manage(make_cors_options()) + .attach(make_cors_options()) +} + +#[test] +fn smoke_test() { + let client = Client::new(rocket()).unwrap(); + + // `Options` pre-flight checks + let origin_header = Header::from( + hyper::header::Origin::from_str("https://www.acme.com").unwrap(), + ); + let method_header = Header::from(hyper::header::AccessControlRequestMethod( + hyper::method::Method::Get, + )); + let request_headers = hyper::header::AccessControlRequestHeaders( + vec![FromStr::from_str("Authorization").unwrap()], + ); + let request_headers = Header::from(request_headers); + let req = client + .options("/") + .header(origin_header) + .header(method_header) + .header(request_headers); + + let response = req.dispatch(); + assert!(response.status().class().is_success()); + + // "Actual" request + let origin_header = Header::from( + hyper::header::Origin::from_str("https://www.acme.com").unwrap(), + ); + let authorization = Header::new("Authorization", "let me in"); + let req = client.get("/").header(origin_header).header(authorization); + + let mut response = req.dispatch(); + assert!(response.status().class().is_success()); + let body_str = response.body().and_then(|body| body.into_string()); + assert_eq!(body_str, Some("Hello CORS".to_string())); + + let origin_header = response + .headers() + .get_one("Access-Control-Allow-Origin") + .expect("to exist"); + assert_eq!("https://www.acme.com", origin_header); +} + +#[test] +fn cors_options_check() { + let client = Client::new(rocket()).unwrap(); + + let origin_header = Header::from( + hyper::header::Origin::from_str("https://www.acme.com").unwrap(), + ); + let method_header = Header::from(hyper::header::AccessControlRequestMethod( + hyper::method::Method::Get, + )); + let request_headers = hyper::header::AccessControlRequestHeaders( + vec![FromStr::from_str("Authorization").unwrap()], + ); + let request_headers = Header::from(request_headers); + let req = client + .options("/") + .header(origin_header) + .header(method_header) + .header(request_headers); + + let response = req.dispatch(); + assert!(response.status().class().is_success()); + + let origin_header = response + .headers() + .get_one("Access-Control-Allow-Origin") + .expect("to exist"); + assert_eq!("https://www.acme.com", origin_header); +} + +#[test] +fn cors_get_check() { + let client = Client::new(rocket()).unwrap(); + + let origin_header = Header::from( + hyper::header::Origin::from_str("https://www.acme.com").unwrap(), + ); + let authorization = Header::new("Authorization", "let me in"); + let req = client.get("/").header(origin_header).header(authorization); + + let mut response = req.dispatch(); + println!("{:?}", response); + assert!(response.status().class().is_success()); + let body_str = response.body().and_then(|body| body.into_string()); + assert_eq!(body_str, Some("Hello CORS".to_string())); + + let origin_header = response + .headers() + .get_one("Access-Control-Allow-Origin") + .expect("to exist"); + assert_eq!("https://www.acme.com", origin_header); +} + +/// This test is to check that non CORS compliant requests to GET should still work. (i.e. curl) +#[test] +fn cors_get_no_origin() { + let client = Client::new(rocket()).unwrap(); + + let authorization = Header::new("Authorization", "let me in"); + let req = client.get("/").header(authorization); + + let mut response = req.dispatch(); + assert!(response.status().class().is_success()); + let body_str = response.body().and_then(|body| body.into_string()); + assert_eq!(body_str, Some("Hello CORS".to_string())); +} + +#[test] +fn cors_options_bad_origin() { + let client = Client::new(rocket()).unwrap(); + + let origin_header = Header::from( + hyper::header::Origin::from_str("https://www.bad-origin.com").unwrap(), + ); + let method_header = Header::from(hyper::header::AccessControlRequestMethod( + hyper::method::Method::Get, + )); + let request_headers = hyper::header::AccessControlRequestHeaders( + vec![FromStr::from_str("Authorization").unwrap()], + ); + let request_headers = Header::from(request_headers); + let req = client + .options("/") + .header(origin_header) + .header(method_header) + .header(request_headers); + + let response = req.dispatch(); + assert_eq!(response.status(), Status::Forbidden); +} + +#[test] +fn cors_options_missing_origin() { + let client = Client::new(rocket()).unwrap(); + + let method_header = Header::from(hyper::header::AccessControlRequestMethod( + hyper::method::Method::Get, + )); + let request_headers = hyper::header::AccessControlRequestHeaders( + vec![FromStr::from_str("Authorization").unwrap()], + ); + let request_headers = Header::from(request_headers); + let req = client.options("/").header(method_header).header( + request_headers, + ); + + let response = req.dispatch(); + assert!(response.status().class().is_success()); + assert!( + response + .headers() + .get_one("Access-Control-Allow-Origin") + .is_none() + ); +} + +#[test] +fn cors_options_bad_request_method() { + let client = Client::new(rocket()).unwrap(); + + let origin_header = Header::from( + hyper::header::Origin::from_str("https://www.acme.com").unwrap(), + ); + let method_header = Header::from(hyper::header::AccessControlRequestMethod( + hyper::method::Method::Post, + )); + let request_headers = hyper::header::AccessControlRequestHeaders( + vec![FromStr::from_str("Authorization").unwrap()], + ); + let request_headers = Header::from(request_headers); + let req = client + .options("/") + .header(origin_header) + .header(method_header) + .header(request_headers); + + let response = req.dispatch(); + assert_eq!(response.status(), Status::Forbidden); + assert!( + response + .headers() + .get_one("Access-Control-Allow-Origin") + .is_none() + ); +} + +#[test] +fn cors_options_bad_request_header() { + let client = Client::new(rocket()).unwrap(); + + let origin_header = Header::from( + hyper::header::Origin::from_str("https://www.acme.com").unwrap(), + ); + let method_header = Header::from(hyper::header::AccessControlRequestMethod( + hyper::method::Method::Get, + )); + let request_headers = + hyper::header::AccessControlRequestHeaders(vec![FromStr::from_str("Foobar").unwrap()]); + let request_headers = Header::from(request_headers); + let req = client + .options("/") + .header(origin_header) + .header(method_header) + .header(request_headers); + + let response = req.dispatch(); + assert_eq!(response.status(), Status::Forbidden); + assert!( + response + .headers() + .get_one("Access-Control-Allow-Origin") + .is_none() + ); +} + +#[test] +fn cors_get_bad_origin() { + let client = Client::new(rocket()).unwrap(); + + let origin_header = Header::from( + hyper::header::Origin::from_str("https://www.bad-origin.com").unwrap(), + ); + let authorization = Header::new("Authorization", "let me in"); + let req = client.get("/").header(origin_header).header(authorization); + + let response = req.dispatch(); + assert_eq!(response.status(), Status::Forbidden); + assert!( + response + .headers() + .get_one("Access-Control-Allow-Origin") + .is_none() + ); +} + +/// This test ensures that on a failing CORS request, the route (along with its side effects) +/// should never be executed. +/// The route used will panic if executed +#[test] +fn routes_failing_checks_are_not_executed() { + let client = Client::new(rocket()).unwrap(); + + let origin_header = Header::from( + hyper::header::Origin::from_str("https://www.bad-origin.com").unwrap(), + ); + let method_header = Header::from(hyper::header::AccessControlRequestMethod( + hyper::method::Method::Get, + )); + let request_headers = hyper::header::AccessControlRequestHeaders( + vec![FromStr::from_str("Authorization").unwrap()], + ); + let request_headers = Header::from(request_headers); + let req = client + .options("/panic") + .header(origin_header) + .header(method_header) + .header(request_headers); + + let response = req.dispatch(); + assert_eq!(response.status(), Status::Forbidden); + assert!( + response + .headers() + .get_one("Access-Control-Allow-Origin") + .is_none() + ); +}