Fix an issue where Fairing on_response will inject CORS headers into failed CORS requests

This commit is contained in:
Yong Wen Chua 2017-07-19 09:51:31 +08:00
parent 539157e0f0
commit fcd83e8fb5
7 changed files with 189 additions and 13 deletions

View File

@ -1,6 +1,6 @@
[package] [package]
name = "rocket_cors" name = "rocket_cors"
version = "0.1.1" version = "0.1.2"
license = "Apache-2.0" license = "Apache-2.0"
authors = ["Yong Wen Chua <me@yongwen.xyz>"] authors = ["Yong Wen Chua <me@yongwen.xyz>"]
build = "build.rs" build = "build.rs"

View File

@ -29,7 +29,7 @@ might work, but it's not guaranteed.
Add the following to Cargo.toml: Add the following to Cargo.toml:
```toml ```toml
rocket_cors = "0.1.1" rocket_cors = "0.1.2"
``` ```
To use the latest `master` branch, for example: To use the latest `master` branch, for example:

View File

@ -57,7 +57,10 @@ fn main() {
}; };
rocket::ignite() rocket::ignite()
.mount("/", routes![responder, responder_options, response, response_options]) .mount(
"/",
routes![responder, responder_options, response, response_options],
)
.manage(options) .manage(options)
.launch(); .launch();
} }

View File

@ -1,9 +1,39 @@
//! Fairing implementation //! Fairing implementation
use rocket::{self, Request, Outcome}; use rocket::{self, Request, Outcome};
use rocket::http::{self, Status}; use rocket::http::{self, Status, Header};
use {Cors, Error, validate, preflight_response, actual_request_response, origin, request_headers}; use {Cors, Error, validate, preflight_response, actual_request_response, origin, request_headers};
/// An injected header to quickly give the result of CORS
static CORS_HEADER: &str = "ROCKET-CORS";
enum InjectedHeader {
Success,
Failure,
}
impl InjectedHeader {
fn to_str(&self) -> &'static str {
match *self {
InjectedHeader::Success => "Success",
InjectedHeader::Failure => "Failure",
}
}
fn from_str(s: &str) -> Result<Self, Error> {
match s {
"Success" => Ok(InjectedHeader::Success),
"Failure" => Ok(InjectedHeader::Failure),
other => {
error_!(
"Unknown injected header encountered: {}\nThis is probably a bug.",
other
);
Err(Error::UnknownInjectedHeader)
}
}
}
}
/// Route for Fairing error handling /// Route for Fairing error handling
pub(crate) fn fairing_error_route<'r>( pub(crate) fn fairing_error_route<'r>(
request: &'r Request, request: &'r Request,
@ -28,6 +58,11 @@ fn route_to_fairing_error_handler(options: &Cors, status: u16, request: &mut Req
request.set_uri(format!("{}/{}", options.fairing_route_base, status)); request.set_uri(format!("{}/{}", options.fairing_route_base, status));
} }
/// Inject a header into the Request with result
fn inject_request_header(header: InjectedHeader, request: &mut Request) {
request.replace_header(Header::new(CORS_HEADER, header.to_str()));
}
fn on_response_wrapper( fn on_response_wrapper(
options: &Cors, options: &Cors,
request: &Request, request: &Request,
@ -41,6 +76,17 @@ fn on_response_wrapper(
Some(origin) => origin, Some(origin) => origin,
}; };
// Get validation result from injected header
let injected_header = request.headers().get_one(CORS_HEADER).ok_or_else(|| {
Error::MissingInjectedHeader
})?;
let result = InjectedHeader::from_str(injected_header)?;
if let InjectedHeader::Failure = result {
// Nothing else for us to do
return Ok(());
}
let cors_response = if request.method() == http::Method::Options { let cors_response = if request.method() == http::Method::Options {
let headers = request_headers(request)?; let headers = request_headers(request)?;
preflight_response(options, origin, headers) preflight_response(options, origin, headers)
@ -87,13 +133,17 @@ impl rocket::fairing::Fairing for Cors {
} }
fn on_request(&self, request: &mut Request, _: &rocket::Data) { fn on_request(&self, request: &mut Request, _: &rocket::Data) {
// Build and merge CORS response let injected_header = match validate(self, request) {
let cors_response = validate(self, request); Ok(_) => InjectedHeader::Success,
if let Err(ref err) = cors_response { Err(err) => {
error_!("CORS Error: {}", err); error_!("CORS Error: {}", err);
let status = err.status(); let status = err.status();
route_to_fairing_error_handler(self, status.code, request); route_to_fairing_error_handler(self, status.code, request);
} InjectedHeader::Failure
}
};
inject_request_header(injected_header, request);
} }
fn on_response(&self, request: &Request, response: &mut rocket::Response) { fn on_response(&self, request: &Request, response: &mut rocket::Response) {

View File

@ -27,7 +27,7 @@
//! Add the following to Cargo.toml: //! Add the following to Cargo.toml:
//! //!
//! ```toml //! ```toml
//! rocket_cors = "0.1.1" //! rocket_cors = "0.1.2"
//! ``` //! ```
//! //!
//! To use the latest `master` branch, for example: //! To use the latest `master` branch, for example:
@ -355,6 +355,12 @@ pub enum Error {
/// ///
/// This is a misconfiguration. Use `Rocket::manage` to add a CORS options to managed state. /// This is a misconfiguration. Use `Rocket::manage` to add a CORS options to managed state.
MissingCorsInRocketState, MissingCorsInRocketState,
/// The `on_response` handler of Fairing could not find the injected header from the Request.
/// Either some other fairing has removed it, or this is a bug.
MissingInjectedHeader,
/// The `on_response` handler of Fairing found an unknown injected header value from the
/// Request. Either some other fairing has modified it, or this is a bug.
UnknownInjectedHeader,
} }
impl Error { impl Error {
@ -363,7 +369,9 @@ impl Error {
Error::MissingOrigin | Error::OriginNotAllowed | Error::MethodNotAllowed | Error::MissingOrigin | Error::OriginNotAllowed | Error::MethodNotAllowed |
Error::HeadersNotAllowed => Status::Forbidden, Error::HeadersNotAllowed => Status::Forbidden,
Error::CredentialsWithWildcardOrigin | Error::CredentialsWithWildcardOrigin |
Error::MissingCorsInRocketState => Status::InternalServerError, Error::MissingCorsInRocketState |
Error::MissingInjectedHeader |
Error::UnknownInjectedHeader => Status::InternalServerError,
_ => Status::BadRequest, _ => Status::BadRequest,
} }
} }
@ -395,6 +403,14 @@ impl error::Error for Error {
Error::MissingCorsInRocketState => { Error::MissingCorsInRocketState => {
"A CORS Request Guard was used, but no CORS Options was available in Rocket's state" "A CORS Request Guard was used, but no CORS Options was available in Rocket's state"
} }
Error::MissingInjectedHeader => {
"The `on_response` handler of Fairing could not find the injected header from the \
Request. Either some other fairing has removed it, or this is a bug."
}
Error::UnknownInjectedHeader => {
"The `on_response` handler of Fairing found an unknown injected header value from \
the Request. Either some other fairing has modified it, or this is a bug."
}
} }
} }

View File

@ -83,6 +83,11 @@ fn smoke_test() {
let body_str = response.body().and_then(|body| body.into_string()); let body_str = response.body().and_then(|body| body.into_string());
assert_eq!(body_str, Some("Hello CORS".to_string())); assert_eq!(body_str, Some("Hello CORS".to_string()));
let origin_header = response
.headers()
.get_one("Access-Control-Allow-Origin")
.expect("to exist");
assert_eq!("https://www.acme.com/", origin_header);
} }
#[test] #[test]
@ -107,6 +112,12 @@ fn cors_options_check() {
let response = req.dispatch(); let response = req.dispatch();
assert!(response.status().class().is_success()); assert!(response.status().class().is_success());
let origin_header = response
.headers()
.get_one("Access-Control-Allow-Origin")
.expect("to exist");
assert_eq!("https://www.acme.com/", origin_header);
} }
#[test] #[test]
@ -124,6 +135,12 @@ fn cors_get_check() {
assert!(response.status().class().is_success()); assert!(response.status().class().is_success());
let body_str = response.body().and_then(|body| body.into_string()); let body_str = response.body().and_then(|body| body.into_string());
assert_eq!(body_str, Some("Hello CORS".to_string())); assert_eq!(body_str, Some("Hello CORS".to_string()));
let origin_header = response
.headers()
.get_one("Access-Control-Allow-Origin")
.expect("to exist");
assert_eq!("https://www.acme.com/", origin_header);
} }
/// This test is to check that non CORS compliant requests to GET should still work. (i.e. curl) /// This test is to check that non CORS compliant requests to GET should still work. (i.e. curl)
@ -182,6 +199,13 @@ fn cors_options_missing_origin() {
let response = req.dispatch(); let response = req.dispatch();
assert_eq!(response.status(), Status::NotFound); assert_eq!(response.status(), Status::NotFound);
assert!(
response
.headers()
.get_one("Access-Control-Allow-Origin")
.is_none()
);
} }
#[test] #[test]
@ -206,6 +230,12 @@ fn cors_options_bad_request_method() {
let response = req.dispatch(); let response = req.dispatch();
assert_eq!(response.status(), Status::Forbidden); assert_eq!(response.status(), Status::Forbidden);
assert!(
response
.headers()
.get_one("Access-Control-Allow-Origin")
.is_none()
);
} }
#[test] #[test]
@ -229,6 +259,12 @@ fn cors_options_bad_request_header() {
let response = req.dispatch(); let response = req.dispatch();
assert_eq!(response.status(), Status::Forbidden); assert_eq!(response.status(), Status::Forbidden);
assert!(
response
.headers()
.get_one("Access-Control-Allow-Origin")
.is_none()
);
} }
#[test] #[test]
@ -243,6 +279,12 @@ fn cors_get_bad_origin() {
let response = req.dispatch(); let response = req.dispatch();
assert_eq!(response.status(), Status::Forbidden); assert_eq!(response.status(), Status::Forbidden);
assert!(
response
.headers()
.get_one("Access-Control-Allow-Origin")
.is_none()
);
} }
/// This test ensures that on a failing CORS request, the route (along with its side effects) /// This test ensures that on a failing CORS request, the route (along with its side effects)
@ -270,4 +312,10 @@ fn routes_failing_checks_are_not_executed() {
let response = req.dispatch(); let response = req.dispatch();
assert_eq!(response.status(), Status::Forbidden); assert_eq!(response.status(), Status::Forbidden);
assert!(
response
.headers()
.get_one("Access-Control-Allow-Origin")
.is_none()
);
} }

View File

@ -121,6 +121,11 @@ fn smoke_test() {
let body_str = response.body().and_then(|body| body.into_string()); let body_str = response.body().and_then(|body| body.into_string());
assert_eq!(body_str, Some("Hello CORS".to_string())); assert_eq!(body_str, Some("Hello CORS".to_string()));
let origin_header = response
.headers()
.get_one("Access-Control-Allow-Origin")
.expect("to exist");
assert_eq!("https://www.acme.com/", origin_header);
} }
#[test] #[test]
@ -146,6 +151,12 @@ fn cors_options_check() {
let response = req.dispatch(); let response = req.dispatch();
assert!(response.status().class().is_success()); assert!(response.status().class().is_success());
let origin_header = response
.headers()
.get_one("Access-Control-Allow-Origin")
.expect("to exist");
assert_eq!("https://www.acme.com/", origin_header);
} }
#[test] #[test]
@ -164,6 +175,12 @@ fn cors_get_check() {
assert!(response.status().class().is_success()); assert!(response.status().class().is_success());
let body_str = response.body().and_then(|body| body.into_string()); let body_str = response.body().and_then(|body| body.into_string());
assert_eq!(body_str, Some("Hello CORS".to_string())); assert_eq!(body_str, Some("Hello CORS".to_string()));
let origin_header = response
.headers()
.get_one("Access-Control-Allow-Origin")
.expect("to exist");
assert_eq!("https://www.acme.com/", origin_header);
} }
/// This test is to check that non CORS compliant requests to GET should still work. (i.e. curl) /// This test is to check that non CORS compliant requests to GET should still work. (i.e. curl)
@ -179,6 +196,12 @@ fn cors_get_no_origin() {
assert!(response.status().class().is_success()); assert!(response.status().class().is_success());
let body_str = response.body().and_then(|body| body.into_string()); let body_str = response.body().and_then(|body| body.into_string());
assert_eq!(body_str, Some("Hello CORS".to_string())); assert_eq!(body_str, Some("Hello CORS".to_string()));
assert!(
response
.headers()
.get_one("Access-Control-Allow-Origin")
.is_none()
);
} }
#[test] #[test]
@ -204,6 +227,12 @@ fn cors_options_bad_origin() {
let response = req.dispatch(); let response = req.dispatch();
assert_eq!(response.status(), Status::Forbidden); assert_eq!(response.status(), Status::Forbidden);
assert!(
response
.headers()
.get_one("Access-Control-Allow-Origin")
.is_none()
);
} }
#[test] #[test]
@ -224,6 +253,12 @@ fn cors_options_missing_origin() {
let response = req.dispatch(); let response = req.dispatch();
assert!(response.status().class().is_success()); assert!(response.status().class().is_success());
assert!(
response
.headers()
.get_one("Access-Control-Allow-Origin")
.is_none()
);
} }
#[test] #[test]
@ -249,6 +284,12 @@ fn cors_options_bad_request_method() {
let response = req.dispatch(); let response = req.dispatch();
assert_eq!(response.status(), Status::Forbidden); assert_eq!(response.status(), Status::Forbidden);
assert!(
response
.headers()
.get_one("Access-Control-Allow-Origin")
.is_none()
);
} }
#[test] #[test]
@ -273,6 +314,12 @@ fn cors_options_bad_request_header() {
let response = req.dispatch(); let response = req.dispatch();
assert_eq!(response.status(), Status::Forbidden); assert_eq!(response.status(), Status::Forbidden);
assert!(
response
.headers()
.get_one("Access-Control-Allow-Origin")
.is_none()
);
} }
#[test] #[test]
@ -288,6 +335,12 @@ fn cors_get_bad_origin() {
let response = req.dispatch(); let response = req.dispatch();
assert_eq!(response.status(), Status::Forbidden); assert_eq!(response.status(), Status::Forbidden);
assert!(
response
.headers()
.get_one("Access-Control-Allow-Origin")
.is_none()
);
} }
/// This test ensures that on a failing CORS request, the route (along with its side effects) /// This test ensures that on a failing CORS request, the route (along with its side effects)
@ -306,4 +359,10 @@ fn routes_failing_checks_are_not_executed() {
let response = req.dispatch(); let response = req.dispatch();
assert_eq!(response.status(), Status::Forbidden); assert_eq!(response.status(), Status::Forbidden);
assert!(
response
.headers()
.get_one("Access-Control-Allow-Origin")
.is_none()
);
} }