Fix an issue where Fairing on_response will inject CORS headers into failed CORS requests
This commit is contained in:
parent
539157e0f0
commit
fcd83e8fb5
|
@ -1,6 +1,6 @@
|
|||
[package]
|
||||
name = "rocket_cors"
|
||||
version = "0.1.1"
|
||||
version = "0.1.2"
|
||||
license = "Apache-2.0"
|
||||
authors = ["Yong Wen Chua <me@yongwen.xyz>"]
|
||||
build = "build.rs"
|
||||
|
|
|
@ -29,7 +29,7 @@ might work, but it's not guaranteed.
|
|||
Add the following to Cargo.toml:
|
||||
|
||||
```toml
|
||||
rocket_cors = "0.1.1"
|
||||
rocket_cors = "0.1.2"
|
||||
```
|
||||
|
||||
To use the latest `master` branch, for example:
|
||||
|
|
|
@ -57,7 +57,10 @@ fn main() {
|
|||
};
|
||||
|
||||
rocket::ignite()
|
||||
.mount("/", routes![responder, responder_options, response, response_options])
|
||||
.mount(
|
||||
"/",
|
||||
routes![responder, responder_options, response, response_options],
|
||||
)
|
||||
.manage(options)
|
||||
.launch();
|
||||
}
|
||||
|
|
|
@ -1,9 +1,39 @@
|
|||
//! Fairing implementation
|
||||
use rocket::{self, Request, Outcome};
|
||||
use rocket::http::{self, Status};
|
||||
use rocket::http::{self, Status, Header};
|
||||
|
||||
use {Cors, Error, validate, preflight_response, actual_request_response, origin, request_headers};
|
||||
|
||||
/// An injected header to quickly give the result of CORS
|
||||
static CORS_HEADER: &str = "ROCKET-CORS";
|
||||
enum InjectedHeader {
|
||||
Success,
|
||||
Failure,
|
||||
}
|
||||
|
||||
impl InjectedHeader {
|
||||
fn to_str(&self) -> &'static str {
|
||||
match *self {
|
||||
InjectedHeader::Success => "Success",
|
||||
InjectedHeader::Failure => "Failure",
|
||||
}
|
||||
}
|
||||
|
||||
fn from_str(s: &str) -> Result<Self, Error> {
|
||||
match s {
|
||||
"Success" => Ok(InjectedHeader::Success),
|
||||
"Failure" => Ok(InjectedHeader::Failure),
|
||||
other => {
|
||||
error_!(
|
||||
"Unknown injected header encountered: {}\nThis is probably a bug.",
|
||||
other
|
||||
);
|
||||
Err(Error::UnknownInjectedHeader)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Route for Fairing error handling
|
||||
pub(crate) fn fairing_error_route<'r>(
|
||||
request: &'r Request,
|
||||
|
@ -28,6 +58,11 @@ fn route_to_fairing_error_handler(options: &Cors, status: u16, request: &mut Req
|
|||
request.set_uri(format!("{}/{}", options.fairing_route_base, status));
|
||||
}
|
||||
|
||||
/// Inject a header into the Request with result
|
||||
fn inject_request_header(header: InjectedHeader, request: &mut Request) {
|
||||
request.replace_header(Header::new(CORS_HEADER, header.to_str()));
|
||||
}
|
||||
|
||||
fn on_response_wrapper(
|
||||
options: &Cors,
|
||||
request: &Request,
|
||||
|
@ -41,6 +76,17 @@ fn on_response_wrapper(
|
|||
Some(origin) => origin,
|
||||
};
|
||||
|
||||
// Get validation result from injected header
|
||||
let injected_header = request.headers().get_one(CORS_HEADER).ok_or_else(|| {
|
||||
Error::MissingInjectedHeader
|
||||
})?;
|
||||
let result = InjectedHeader::from_str(injected_header)?;
|
||||
|
||||
if let InjectedHeader::Failure = result {
|
||||
// Nothing else for us to do
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let cors_response = if request.method() == http::Method::Options {
|
||||
let headers = request_headers(request)?;
|
||||
preflight_response(options, origin, headers)
|
||||
|
@ -87,13 +133,17 @@ impl rocket::fairing::Fairing for Cors {
|
|||
}
|
||||
|
||||
fn on_request(&self, request: &mut Request, _: &rocket::Data) {
|
||||
// Build and merge CORS response
|
||||
let cors_response = validate(self, request);
|
||||
if let Err(ref err) = cors_response {
|
||||
error_!("CORS Error: {}", err);
|
||||
let status = err.status();
|
||||
route_to_fairing_error_handler(self, status.code, request);
|
||||
}
|
||||
let injected_header = match validate(self, request) {
|
||||
Ok(_) => InjectedHeader::Success,
|
||||
Err(err) => {
|
||||
error_!("CORS Error: {}", err);
|
||||
let status = err.status();
|
||||
route_to_fairing_error_handler(self, status.code, request);
|
||||
InjectedHeader::Failure
|
||||
}
|
||||
};
|
||||
|
||||
inject_request_header(injected_header, request);
|
||||
}
|
||||
|
||||
fn on_response(&self, request: &Request, response: &mut rocket::Response) {
|
||||
|
|
20
src/lib.rs
20
src/lib.rs
|
@ -27,7 +27,7 @@
|
|||
//! Add the following to Cargo.toml:
|
||||
//!
|
||||
//! ```toml
|
||||
//! rocket_cors = "0.1.1"
|
||||
//! rocket_cors = "0.1.2"
|
||||
//! ```
|
||||
//!
|
||||
//! To use the latest `master` branch, for example:
|
||||
|
@ -355,6 +355,12 @@ pub enum Error {
|
|||
///
|
||||
/// This is a misconfiguration. Use `Rocket::manage` to add a CORS options to managed state.
|
||||
MissingCorsInRocketState,
|
||||
/// The `on_response` handler of Fairing could not find the injected header from the Request.
|
||||
/// Either some other fairing has removed it, or this is a bug.
|
||||
MissingInjectedHeader,
|
||||
/// The `on_response` handler of Fairing found an unknown injected header value from the
|
||||
/// Request. Either some other fairing has modified it, or this is a bug.
|
||||
UnknownInjectedHeader,
|
||||
}
|
||||
|
||||
impl Error {
|
||||
|
@ -363,7 +369,9 @@ impl Error {
|
|||
Error::MissingOrigin | Error::OriginNotAllowed | Error::MethodNotAllowed |
|
||||
Error::HeadersNotAllowed => Status::Forbidden,
|
||||
Error::CredentialsWithWildcardOrigin |
|
||||
Error::MissingCorsInRocketState => Status::InternalServerError,
|
||||
Error::MissingCorsInRocketState |
|
||||
Error::MissingInjectedHeader |
|
||||
Error::UnknownInjectedHeader => Status::InternalServerError,
|
||||
_ => Status::BadRequest,
|
||||
}
|
||||
}
|
||||
|
@ -395,6 +403,14 @@ impl error::Error for Error {
|
|||
Error::MissingCorsInRocketState => {
|
||||
"A CORS Request Guard was used, but no CORS Options was available in Rocket's state"
|
||||
}
|
||||
Error::MissingInjectedHeader => {
|
||||
"The `on_response` handler of Fairing could not find the injected header from the \
|
||||
Request. Either some other fairing has removed it, or this is a bug."
|
||||
}
|
||||
Error::UnknownInjectedHeader => {
|
||||
"The `on_response` handler of Fairing found an unknown injected header value from \
|
||||
the Request. Either some other fairing has modified it, or this is a bug."
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -83,6 +83,11 @@ fn smoke_test() {
|
|||
let body_str = response.body().and_then(|body| body.into_string());
|
||||
assert_eq!(body_str, Some("Hello CORS".to_string()));
|
||||
|
||||
let origin_header = response
|
||||
.headers()
|
||||
.get_one("Access-Control-Allow-Origin")
|
||||
.expect("to exist");
|
||||
assert_eq!("https://www.acme.com/", origin_header);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
@ -107,6 +112,12 @@ fn cors_options_check() {
|
|||
|
||||
let response = req.dispatch();
|
||||
assert!(response.status().class().is_success());
|
||||
|
||||
let origin_header = response
|
||||
.headers()
|
||||
.get_one("Access-Control-Allow-Origin")
|
||||
.expect("to exist");
|
||||
assert_eq!("https://www.acme.com/", origin_header);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
@ -124,6 +135,12 @@ fn cors_get_check() {
|
|||
assert!(response.status().class().is_success());
|
||||
let body_str = response.body().and_then(|body| body.into_string());
|
||||
assert_eq!(body_str, Some("Hello CORS".to_string()));
|
||||
|
||||
let origin_header = response
|
||||
.headers()
|
||||
.get_one("Access-Control-Allow-Origin")
|
||||
.expect("to exist");
|
||||
assert_eq!("https://www.acme.com/", origin_header);
|
||||
}
|
||||
|
||||
/// This test is to check that non CORS compliant requests to GET should still work. (i.e. curl)
|
||||
|
@ -182,6 +199,13 @@ fn cors_options_missing_origin() {
|
|||
|
||||
let response = req.dispatch();
|
||||
assert_eq!(response.status(), Status::NotFound);
|
||||
|
||||
assert!(
|
||||
response
|
||||
.headers()
|
||||
.get_one("Access-Control-Allow-Origin")
|
||||
.is_none()
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
@ -206,6 +230,12 @@ fn cors_options_bad_request_method() {
|
|||
|
||||
let response = req.dispatch();
|
||||
assert_eq!(response.status(), Status::Forbidden);
|
||||
assert!(
|
||||
response
|
||||
.headers()
|
||||
.get_one("Access-Control-Allow-Origin")
|
||||
.is_none()
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
@ -229,6 +259,12 @@ fn cors_options_bad_request_header() {
|
|||
|
||||
let response = req.dispatch();
|
||||
assert_eq!(response.status(), Status::Forbidden);
|
||||
assert!(
|
||||
response
|
||||
.headers()
|
||||
.get_one("Access-Control-Allow-Origin")
|
||||
.is_none()
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
@ -243,6 +279,12 @@ fn cors_get_bad_origin() {
|
|||
|
||||
let response = req.dispatch();
|
||||
assert_eq!(response.status(), Status::Forbidden);
|
||||
assert!(
|
||||
response
|
||||
.headers()
|
||||
.get_one("Access-Control-Allow-Origin")
|
||||
.is_none()
|
||||
);
|
||||
}
|
||||
|
||||
/// This test ensures that on a failing CORS request, the route (along with its side effects)
|
||||
|
@ -270,4 +312,10 @@ fn routes_failing_checks_are_not_executed() {
|
|||
|
||||
let response = req.dispatch();
|
||||
assert_eq!(response.status(), Status::Forbidden);
|
||||
assert!(
|
||||
response
|
||||
.headers()
|
||||
.get_one("Access-Control-Allow-Origin")
|
||||
.is_none()
|
||||
);
|
||||
}
|
||||
|
|
|
@ -121,6 +121,11 @@ fn smoke_test() {
|
|||
let body_str = response.body().and_then(|body| body.into_string());
|
||||
assert_eq!(body_str, Some("Hello CORS".to_string()));
|
||||
|
||||
let origin_header = response
|
||||
.headers()
|
||||
.get_one("Access-Control-Allow-Origin")
|
||||
.expect("to exist");
|
||||
assert_eq!("https://www.acme.com/", origin_header);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
@ -146,6 +151,12 @@ fn cors_options_check() {
|
|||
|
||||
let response = req.dispatch();
|
||||
assert!(response.status().class().is_success());
|
||||
|
||||
let origin_header = response
|
||||
.headers()
|
||||
.get_one("Access-Control-Allow-Origin")
|
||||
.expect("to exist");
|
||||
assert_eq!("https://www.acme.com/", origin_header);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
@ -164,6 +175,12 @@ fn cors_get_check() {
|
|||
assert!(response.status().class().is_success());
|
||||
let body_str = response.body().and_then(|body| body.into_string());
|
||||
assert_eq!(body_str, Some("Hello CORS".to_string()));
|
||||
|
||||
let origin_header = response
|
||||
.headers()
|
||||
.get_one("Access-Control-Allow-Origin")
|
||||
.expect("to exist");
|
||||
assert_eq!("https://www.acme.com/", origin_header);
|
||||
}
|
||||
|
||||
/// This test is to check that non CORS compliant requests to GET should still work. (i.e. curl)
|
||||
|
@ -179,6 +196,12 @@ fn cors_get_no_origin() {
|
|||
assert!(response.status().class().is_success());
|
||||
let body_str = response.body().and_then(|body| body.into_string());
|
||||
assert_eq!(body_str, Some("Hello CORS".to_string()));
|
||||
assert!(
|
||||
response
|
||||
.headers()
|
||||
.get_one("Access-Control-Allow-Origin")
|
||||
.is_none()
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
@ -204,6 +227,12 @@ fn cors_options_bad_origin() {
|
|||
|
||||
let response = req.dispatch();
|
||||
assert_eq!(response.status(), Status::Forbidden);
|
||||
assert!(
|
||||
response
|
||||
.headers()
|
||||
.get_one("Access-Control-Allow-Origin")
|
||||
.is_none()
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
@ -224,6 +253,12 @@ fn cors_options_missing_origin() {
|
|||
|
||||
let response = req.dispatch();
|
||||
assert!(response.status().class().is_success());
|
||||
assert!(
|
||||
response
|
||||
.headers()
|
||||
.get_one("Access-Control-Allow-Origin")
|
||||
.is_none()
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
@ -249,6 +284,12 @@ fn cors_options_bad_request_method() {
|
|||
|
||||
let response = req.dispatch();
|
||||
assert_eq!(response.status(), Status::Forbidden);
|
||||
assert!(
|
||||
response
|
||||
.headers()
|
||||
.get_one("Access-Control-Allow-Origin")
|
||||
.is_none()
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
@ -273,6 +314,12 @@ fn cors_options_bad_request_header() {
|
|||
|
||||
let response = req.dispatch();
|
||||
assert_eq!(response.status(), Status::Forbidden);
|
||||
assert!(
|
||||
response
|
||||
.headers()
|
||||
.get_one("Access-Control-Allow-Origin")
|
||||
.is_none()
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
@ -288,6 +335,12 @@ fn cors_get_bad_origin() {
|
|||
|
||||
let response = req.dispatch();
|
||||
assert_eq!(response.status(), Status::Forbidden);
|
||||
assert!(
|
||||
response
|
||||
.headers()
|
||||
.get_one("Access-Control-Allow-Origin")
|
||||
.is_none()
|
||||
);
|
||||
}
|
||||
|
||||
/// This test ensures that on a failing CORS request, the route (along with its side effects)
|
||||
|
@ -306,4 +359,10 @@ fn routes_failing_checks_are_not_executed() {
|
|||
|
||||
let response = req.dispatch();
|
||||
assert_eq!(response.status(), Status::Forbidden);
|
||||
assert!(
|
||||
response
|
||||
.headers()
|
||||
.get_one("Access-Control-Allow-Origin")
|
||||
.is_none()
|
||||
);
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue