diff --git a/Cargo.toml b/Cargo.toml index fbf3f81..ab4c072 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "rocket_cors" -version = "0.1.1" +version = "0.1.2" license = "Apache-2.0" authors = ["Yong Wen Chua "] build = "build.rs" diff --git a/README.md b/README.md index 0bbd35c..9feab5c 100644 --- a/README.md +++ b/README.md @@ -29,7 +29,7 @@ might work, but it's not guaranteed. Add the following to Cargo.toml: ```toml -rocket_cors = "0.1.1" +rocket_cors = "0.1.2" ``` To use the latest `master` branch, for example: diff --git a/examples/guard.rs b/examples/guard.rs index e8696b6..4a27455 100644 --- a/examples/guard.rs +++ b/examples/guard.rs @@ -57,7 +57,10 @@ fn main() { }; rocket::ignite() - .mount("/", routes![responder, responder_options, response, response_options]) + .mount( + "/", + routes![responder, responder_options, response, response_options], + ) .manage(options) .launch(); } diff --git a/src/fairing.rs b/src/fairing.rs index 66703d5..4510581 100644 --- a/src/fairing.rs +++ b/src/fairing.rs @@ -1,9 +1,39 @@ //! Fairing implementation use rocket::{self, Request, Outcome}; -use rocket::http::{self, Status}; +use rocket::http::{self, Status, Header}; use {Cors, Error, validate, preflight_response, actual_request_response, origin, request_headers}; +/// An injected header to quickly give the result of CORS +static CORS_HEADER: &str = "ROCKET-CORS"; +enum InjectedHeader { + Success, + Failure, +} + +impl InjectedHeader { + fn to_str(&self) -> &'static str { + match *self { + InjectedHeader::Success => "Success", + InjectedHeader::Failure => "Failure", + } + } + + fn from_str(s: &str) -> Result { + match s { + "Success" => Ok(InjectedHeader::Success), + "Failure" => Ok(InjectedHeader::Failure), + other => { + error_!( + "Unknown injected header encountered: {}\nThis is probably a bug.", + other + ); + Err(Error::UnknownInjectedHeader) + } + } + } +} + /// Route for Fairing error handling pub(crate) fn fairing_error_route<'r>( request: &'r Request, @@ -28,6 +58,11 @@ fn route_to_fairing_error_handler(options: &Cors, status: u16, request: &mut Req request.set_uri(format!("{}/{}", options.fairing_route_base, status)); } +/// Inject a header into the Request with result +fn inject_request_header(header: InjectedHeader, request: &mut Request) { + request.replace_header(Header::new(CORS_HEADER, header.to_str())); +} + fn on_response_wrapper( options: &Cors, request: &Request, @@ -41,6 +76,17 @@ fn on_response_wrapper( Some(origin) => origin, }; + // Get validation result from injected header + let injected_header = request.headers().get_one(CORS_HEADER).ok_or_else(|| { + Error::MissingInjectedHeader + })?; + let result = InjectedHeader::from_str(injected_header)?; + + if let InjectedHeader::Failure = result { + // Nothing else for us to do + return Ok(()); + } + let cors_response = if request.method() == http::Method::Options { let headers = request_headers(request)?; preflight_response(options, origin, headers) @@ -87,13 +133,17 @@ impl rocket::fairing::Fairing for Cors { } fn on_request(&self, request: &mut Request, _: &rocket::Data) { - // Build and merge CORS response - let cors_response = validate(self, request); - if let Err(ref err) = cors_response { - error_!("CORS Error: {}", err); - let status = err.status(); - route_to_fairing_error_handler(self, status.code, request); - } + let injected_header = match validate(self, request) { + Ok(_) => InjectedHeader::Success, + Err(err) => { + error_!("CORS Error: {}", err); + let status = err.status(); + route_to_fairing_error_handler(self, status.code, request); + InjectedHeader::Failure + } + }; + + inject_request_header(injected_header, request); } fn on_response(&self, request: &Request, response: &mut rocket::Response) { diff --git a/src/lib.rs b/src/lib.rs index 809e631..2df1698 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -27,7 +27,7 @@ //! Add the following to Cargo.toml: //! //! ```toml -//! rocket_cors = "0.1.1" +//! rocket_cors = "0.1.2" //! ``` //! //! To use the latest `master` branch, for example: @@ -355,6 +355,12 @@ pub enum Error { /// /// This is a misconfiguration. Use `Rocket::manage` to add a CORS options to managed state. MissingCorsInRocketState, + /// The `on_response` handler of Fairing could not find the injected header from the Request. + /// Either some other fairing has removed it, or this is a bug. + MissingInjectedHeader, + /// The `on_response` handler of Fairing found an unknown injected header value from the + /// Request. Either some other fairing has modified it, or this is a bug. + UnknownInjectedHeader, } impl Error { @@ -363,7 +369,9 @@ impl Error { Error::MissingOrigin | Error::OriginNotAllowed | Error::MethodNotAllowed | Error::HeadersNotAllowed => Status::Forbidden, Error::CredentialsWithWildcardOrigin | - Error::MissingCorsInRocketState => Status::InternalServerError, + Error::MissingCorsInRocketState | + Error::MissingInjectedHeader | + Error::UnknownInjectedHeader => Status::InternalServerError, _ => Status::BadRequest, } } @@ -395,6 +403,14 @@ impl error::Error for Error { Error::MissingCorsInRocketState => { "A CORS Request Guard was used, but no CORS Options was available in Rocket's state" } + Error::MissingInjectedHeader => { + "The `on_response` handler of Fairing could not find the injected header from the \ + Request. Either some other fairing has removed it, or this is a bug." + } + Error::UnknownInjectedHeader => { + "The `on_response` handler of Fairing found an unknown injected header value from \ + the Request. Either some other fairing has modified it, or this is a bug." + } } } diff --git a/tests/fairing.rs b/tests/fairing.rs index 9d579f0..b2607c4 100644 --- a/tests/fairing.rs +++ b/tests/fairing.rs @@ -83,6 +83,11 @@ fn smoke_test() { let body_str = response.body().and_then(|body| body.into_string()); assert_eq!(body_str, Some("Hello CORS".to_string())); + let origin_header = response + .headers() + .get_one("Access-Control-Allow-Origin") + .expect("to exist"); + assert_eq!("https://www.acme.com/", origin_header); } #[test] @@ -107,6 +112,12 @@ fn cors_options_check() { let response = req.dispatch(); assert!(response.status().class().is_success()); + + let origin_header = response + .headers() + .get_one("Access-Control-Allow-Origin") + .expect("to exist"); + assert_eq!("https://www.acme.com/", origin_header); } #[test] @@ -124,6 +135,12 @@ fn cors_get_check() { assert!(response.status().class().is_success()); let body_str = response.body().and_then(|body| body.into_string()); assert_eq!(body_str, Some("Hello CORS".to_string())); + + let origin_header = response + .headers() + .get_one("Access-Control-Allow-Origin") + .expect("to exist"); + assert_eq!("https://www.acme.com/", origin_header); } /// This test is to check that non CORS compliant requests to GET should still work. (i.e. curl) @@ -182,6 +199,13 @@ fn cors_options_missing_origin() { let response = req.dispatch(); assert_eq!(response.status(), Status::NotFound); + + assert!( + response + .headers() + .get_one("Access-Control-Allow-Origin") + .is_none() + ); } #[test] @@ -206,6 +230,12 @@ fn cors_options_bad_request_method() { let response = req.dispatch(); assert_eq!(response.status(), Status::Forbidden); + assert!( + response + .headers() + .get_one("Access-Control-Allow-Origin") + .is_none() + ); } #[test] @@ -229,6 +259,12 @@ fn cors_options_bad_request_header() { let response = req.dispatch(); assert_eq!(response.status(), Status::Forbidden); + assert!( + response + .headers() + .get_one("Access-Control-Allow-Origin") + .is_none() + ); } #[test] @@ -243,6 +279,12 @@ fn cors_get_bad_origin() { let response = req.dispatch(); assert_eq!(response.status(), Status::Forbidden); + assert!( + response + .headers() + .get_one("Access-Control-Allow-Origin") + .is_none() + ); } /// This test ensures that on a failing CORS request, the route (along with its side effects) @@ -270,4 +312,10 @@ fn routes_failing_checks_are_not_executed() { let response = req.dispatch(); assert_eq!(response.status(), Status::Forbidden); + assert!( + response + .headers() + .get_one("Access-Control-Allow-Origin") + .is_none() + ); } diff --git a/tests/guard.rs b/tests/guard.rs index 091f09f..45f4da0 100644 --- a/tests/guard.rs +++ b/tests/guard.rs @@ -121,6 +121,11 @@ fn smoke_test() { let body_str = response.body().and_then(|body| body.into_string()); assert_eq!(body_str, Some("Hello CORS".to_string())); + let origin_header = response + .headers() + .get_one("Access-Control-Allow-Origin") + .expect("to exist"); + assert_eq!("https://www.acme.com/", origin_header); } #[test] @@ -146,6 +151,12 @@ fn cors_options_check() { let response = req.dispatch(); assert!(response.status().class().is_success()); + + let origin_header = response + .headers() + .get_one("Access-Control-Allow-Origin") + .expect("to exist"); + assert_eq!("https://www.acme.com/", origin_header); } #[test] @@ -164,6 +175,12 @@ fn cors_get_check() { assert!(response.status().class().is_success()); let body_str = response.body().and_then(|body| body.into_string()); assert_eq!(body_str, Some("Hello CORS".to_string())); + + let origin_header = response + .headers() + .get_one("Access-Control-Allow-Origin") + .expect("to exist"); + assert_eq!("https://www.acme.com/", origin_header); } /// This test is to check that non CORS compliant requests to GET should still work. (i.e. curl) @@ -179,6 +196,12 @@ fn cors_get_no_origin() { assert!(response.status().class().is_success()); let body_str = response.body().and_then(|body| body.into_string()); assert_eq!(body_str, Some("Hello CORS".to_string())); + assert!( + response + .headers() + .get_one("Access-Control-Allow-Origin") + .is_none() + ); } #[test] @@ -204,6 +227,12 @@ fn cors_options_bad_origin() { let response = req.dispatch(); assert_eq!(response.status(), Status::Forbidden); + assert!( + response + .headers() + .get_one("Access-Control-Allow-Origin") + .is_none() + ); } #[test] @@ -224,6 +253,12 @@ fn cors_options_missing_origin() { let response = req.dispatch(); assert!(response.status().class().is_success()); + assert!( + response + .headers() + .get_one("Access-Control-Allow-Origin") + .is_none() + ); } #[test] @@ -249,6 +284,12 @@ fn cors_options_bad_request_method() { let response = req.dispatch(); assert_eq!(response.status(), Status::Forbidden); + assert!( + response + .headers() + .get_one("Access-Control-Allow-Origin") + .is_none() + ); } #[test] @@ -273,6 +314,12 @@ fn cors_options_bad_request_header() { let response = req.dispatch(); assert_eq!(response.status(), Status::Forbidden); + assert!( + response + .headers() + .get_one("Access-Control-Allow-Origin") + .is_none() + ); } #[test] @@ -288,6 +335,12 @@ fn cors_get_bad_origin() { let response = req.dispatch(); assert_eq!(response.status(), Status::Forbidden); + assert!( + response + .headers() + .get_one("Access-Control-Allow-Origin") + .is_none() + ); } /// This test ensures that on a failing CORS request, the route (along with its side effects) @@ -306,4 +359,10 @@ fn routes_failing_checks_are_not_executed() { let response = req.dispatch(); assert_eq!(response.status(), Status::Forbidden); + assert!( + response + .headers() + .get_one("Access-Control-Allow-Origin") + .is_none() + ); }