Fix an issue where Fairing on_response will inject CORS headers into failed CORS requests

This commit is contained in:
Yong Wen Chua 2017-07-19 09:51:31 +08:00
parent 539157e0f0
commit fcd83e8fb5
7 changed files with 189 additions and 13 deletions

View File

@ -1,6 +1,6 @@
[package]
name = "rocket_cors"
version = "0.1.1"
version = "0.1.2"
license = "Apache-2.0"
authors = ["Yong Wen Chua <me@yongwen.xyz>"]
build = "build.rs"

View File

@ -29,7 +29,7 @@ might work, but it's not guaranteed.
Add the following to Cargo.toml:
```toml
rocket_cors = "0.1.1"
rocket_cors = "0.1.2"
```
To use the latest `master` branch, for example:

View File

@ -57,7 +57,10 @@ fn main() {
};
rocket::ignite()
.mount("/", routes![responder, responder_options, response, response_options])
.mount(
"/",
routes![responder, responder_options, response, response_options],
)
.manage(options)
.launch();
}

View File

@ -1,9 +1,39 @@
//! Fairing implementation
use rocket::{self, Request, Outcome};
use rocket::http::{self, Status};
use rocket::http::{self, Status, Header};
use {Cors, Error, validate, preflight_response, actual_request_response, origin, request_headers};
/// An injected header to quickly give the result of CORS
static CORS_HEADER: &str = "ROCKET-CORS";
enum InjectedHeader {
Success,
Failure,
}
impl InjectedHeader {
fn to_str(&self) -> &'static str {
match *self {
InjectedHeader::Success => "Success",
InjectedHeader::Failure => "Failure",
}
}
fn from_str(s: &str) -> Result<Self, Error> {
match s {
"Success" => Ok(InjectedHeader::Success),
"Failure" => Ok(InjectedHeader::Failure),
other => {
error_!(
"Unknown injected header encountered: {}\nThis is probably a bug.",
other
);
Err(Error::UnknownInjectedHeader)
}
}
}
}
/// Route for Fairing error handling
pub(crate) fn fairing_error_route<'r>(
request: &'r Request,
@ -28,6 +58,11 @@ fn route_to_fairing_error_handler(options: &Cors, status: u16, request: &mut Req
request.set_uri(format!("{}/{}", options.fairing_route_base, status));
}
/// Inject a header into the Request with result
fn inject_request_header(header: InjectedHeader, request: &mut Request) {
request.replace_header(Header::new(CORS_HEADER, header.to_str()));
}
fn on_response_wrapper(
options: &Cors,
request: &Request,
@ -41,6 +76,17 @@ fn on_response_wrapper(
Some(origin) => origin,
};
// Get validation result from injected header
let injected_header = request.headers().get_one(CORS_HEADER).ok_or_else(|| {
Error::MissingInjectedHeader
})?;
let result = InjectedHeader::from_str(injected_header)?;
if let InjectedHeader::Failure = result {
// Nothing else for us to do
return Ok(());
}
let cors_response = if request.method() == http::Method::Options {
let headers = request_headers(request)?;
preflight_response(options, origin, headers)
@ -87,13 +133,17 @@ impl rocket::fairing::Fairing for Cors {
}
fn on_request(&self, request: &mut Request, _: &rocket::Data) {
// Build and merge CORS response
let cors_response = validate(self, request);
if let Err(ref err) = cors_response {
error_!("CORS Error: {}", err);
let status = err.status();
route_to_fairing_error_handler(self, status.code, request);
}
let injected_header = match validate(self, request) {
Ok(_) => InjectedHeader::Success,
Err(err) => {
error_!("CORS Error: {}", err);
let status = err.status();
route_to_fairing_error_handler(self, status.code, request);
InjectedHeader::Failure
}
};
inject_request_header(injected_header, request);
}
fn on_response(&self, request: &Request, response: &mut rocket::Response) {

View File

@ -27,7 +27,7 @@
//! Add the following to Cargo.toml:
//!
//! ```toml
//! rocket_cors = "0.1.1"
//! rocket_cors = "0.1.2"
//! ```
//!
//! To use the latest `master` branch, for example:
@ -355,6 +355,12 @@ pub enum Error {
///
/// This is a misconfiguration. Use `Rocket::manage` to add a CORS options to managed state.
MissingCorsInRocketState,
/// The `on_response` handler of Fairing could not find the injected header from the Request.
/// Either some other fairing has removed it, or this is a bug.
MissingInjectedHeader,
/// The `on_response` handler of Fairing found an unknown injected header value from the
/// Request. Either some other fairing has modified it, or this is a bug.
UnknownInjectedHeader,
}
impl Error {
@ -363,7 +369,9 @@ impl Error {
Error::MissingOrigin | Error::OriginNotAllowed | Error::MethodNotAllowed |
Error::HeadersNotAllowed => Status::Forbidden,
Error::CredentialsWithWildcardOrigin |
Error::MissingCorsInRocketState => Status::InternalServerError,
Error::MissingCorsInRocketState |
Error::MissingInjectedHeader |
Error::UnknownInjectedHeader => Status::InternalServerError,
_ => Status::BadRequest,
}
}
@ -395,6 +403,14 @@ impl error::Error for Error {
Error::MissingCorsInRocketState => {
"A CORS Request Guard was used, but no CORS Options was available in Rocket's state"
}
Error::MissingInjectedHeader => {
"The `on_response` handler of Fairing could not find the injected header from the \
Request. Either some other fairing has removed it, or this is a bug."
}
Error::UnknownInjectedHeader => {
"The `on_response` handler of Fairing found an unknown injected header value from \
the Request. Either some other fairing has modified it, or this is a bug."
}
}
}

View File

@ -83,6 +83,11 @@ fn smoke_test() {
let body_str = response.body().and_then(|body| body.into_string());
assert_eq!(body_str, Some("Hello CORS".to_string()));
let origin_header = response
.headers()
.get_one("Access-Control-Allow-Origin")
.expect("to exist");
assert_eq!("https://www.acme.com/", origin_header);
}
#[test]
@ -107,6 +112,12 @@ fn cors_options_check() {
let response = req.dispatch();
assert!(response.status().class().is_success());
let origin_header = response
.headers()
.get_one("Access-Control-Allow-Origin")
.expect("to exist");
assert_eq!("https://www.acme.com/", origin_header);
}
#[test]
@ -124,6 +135,12 @@ fn cors_get_check() {
assert!(response.status().class().is_success());
let body_str = response.body().and_then(|body| body.into_string());
assert_eq!(body_str, Some("Hello CORS".to_string()));
let origin_header = response
.headers()
.get_one("Access-Control-Allow-Origin")
.expect("to exist");
assert_eq!("https://www.acme.com/", origin_header);
}
/// This test is to check that non CORS compliant requests to GET should still work. (i.e. curl)
@ -182,6 +199,13 @@ fn cors_options_missing_origin() {
let response = req.dispatch();
assert_eq!(response.status(), Status::NotFound);
assert!(
response
.headers()
.get_one("Access-Control-Allow-Origin")
.is_none()
);
}
#[test]
@ -206,6 +230,12 @@ fn cors_options_bad_request_method() {
let response = req.dispatch();
assert_eq!(response.status(), Status::Forbidden);
assert!(
response
.headers()
.get_one("Access-Control-Allow-Origin")
.is_none()
);
}
#[test]
@ -229,6 +259,12 @@ fn cors_options_bad_request_header() {
let response = req.dispatch();
assert_eq!(response.status(), Status::Forbidden);
assert!(
response
.headers()
.get_one("Access-Control-Allow-Origin")
.is_none()
);
}
#[test]
@ -243,6 +279,12 @@ fn cors_get_bad_origin() {
let response = req.dispatch();
assert_eq!(response.status(), Status::Forbidden);
assert!(
response
.headers()
.get_one("Access-Control-Allow-Origin")
.is_none()
);
}
/// This test ensures that on a failing CORS request, the route (along with its side effects)
@ -270,4 +312,10 @@ fn routes_failing_checks_are_not_executed() {
let response = req.dispatch();
assert_eq!(response.status(), Status::Forbidden);
assert!(
response
.headers()
.get_one("Access-Control-Allow-Origin")
.is_none()
);
}

View File

@ -121,6 +121,11 @@ fn smoke_test() {
let body_str = response.body().and_then(|body| body.into_string());
assert_eq!(body_str, Some("Hello CORS".to_string()));
let origin_header = response
.headers()
.get_one("Access-Control-Allow-Origin")
.expect("to exist");
assert_eq!("https://www.acme.com/", origin_header);
}
#[test]
@ -146,6 +151,12 @@ fn cors_options_check() {
let response = req.dispatch();
assert!(response.status().class().is_success());
let origin_header = response
.headers()
.get_one("Access-Control-Allow-Origin")
.expect("to exist");
assert_eq!("https://www.acme.com/", origin_header);
}
#[test]
@ -164,6 +175,12 @@ fn cors_get_check() {
assert!(response.status().class().is_success());
let body_str = response.body().and_then(|body| body.into_string());
assert_eq!(body_str, Some("Hello CORS".to_string()));
let origin_header = response
.headers()
.get_one("Access-Control-Allow-Origin")
.expect("to exist");
assert_eq!("https://www.acme.com/", origin_header);
}
/// This test is to check that non CORS compliant requests to GET should still work. (i.e. curl)
@ -179,6 +196,12 @@ fn cors_get_no_origin() {
assert!(response.status().class().is_success());
let body_str = response.body().and_then(|body| body.into_string());
assert_eq!(body_str, Some("Hello CORS".to_string()));
assert!(
response
.headers()
.get_one("Access-Control-Allow-Origin")
.is_none()
);
}
#[test]
@ -204,6 +227,12 @@ fn cors_options_bad_origin() {
let response = req.dispatch();
assert_eq!(response.status(), Status::Forbidden);
assert!(
response
.headers()
.get_one("Access-Control-Allow-Origin")
.is_none()
);
}
#[test]
@ -224,6 +253,12 @@ fn cors_options_missing_origin() {
let response = req.dispatch();
assert!(response.status().class().is_success());
assert!(
response
.headers()
.get_one("Access-Control-Allow-Origin")
.is_none()
);
}
#[test]
@ -249,6 +284,12 @@ fn cors_options_bad_request_method() {
let response = req.dispatch();
assert_eq!(response.status(), Status::Forbidden);
assert!(
response
.headers()
.get_one("Access-Control-Allow-Origin")
.is_none()
);
}
#[test]
@ -273,6 +314,12 @@ fn cors_options_bad_request_header() {
let response = req.dispatch();
assert_eq!(response.status(), Status::Forbidden);
assert!(
response
.headers()
.get_one("Access-Control-Allow-Origin")
.is_none()
);
}
#[test]
@ -288,6 +335,12 @@ fn cors_get_bad_origin() {
let response = req.dispatch();
assert_eq!(response.status(), Status::Forbidden);
assert!(
response
.headers()
.get_one("Access-Control-Allow-Origin")
.is_none()
);
}
/// This test ensures that on a failing CORS request, the route (along with its side effects)
@ -306,4 +359,10 @@ fn routes_failing_checks_are_not_executed() {
let response = req.dispatch();
assert_eq!(response.status(), Status::Forbidden);
assert!(
response
.headers()
.get_one("Access-Control-Allow-Origin")
.is_none()
);
}