Add serde pass between fairings

This commit is contained in:
Yong Wen Chua 2017-07-17 17:36:41 +08:00
parent c6403fcffd
commit 56de116595
4 changed files with 185 additions and 85 deletions

View File

@ -15,8 +15,10 @@ categories = ["web-programming"]
travis-ci = { repository = "lawliet89/rocket_cors" }
[dependencies]
base64 = "0.6.0"
log = "0.3"
rocket = "0.3"
rmp-serde = "0.13.4"
serde = "1.0"
serde_derive = "1.0"
unicase = "2.0"

156
src/fairing.rs Normal file
View File

@ -0,0 +1,156 @@
//! Fairing implementation
use base64;
use rocket::{self, Request, Outcome};
use rocket::http::{self, Status, Header};
use rmps;
use {Cors, Response, Error, build_cors_response};
static HEADER_NAME: &'static str = "ROCKET-CORS";
/// Type of the Request header the `on_request` fairing handler will inject into requests
/// for `on_response` to deal with
pub(crate) type CorsInjectedHeader = Result<Response, String>;
/// Route for Fairing error handling
pub(crate) fn fairing_error_route<'r>(
request: &'r Request,
_: rocket::Data,
) -> rocket::handler::Outcome<'r> {
let status = request.get_param::<u16>(0).unwrap_or_else(|e| {
error_!("Fairing Error Handling Route error: {:?}", e);
500
});
let status = Status::from_code(status).unwrap_or_else(|| Status::InternalServerError);
Outcome::Failure(status)
}
/// Create a new `Route` for Fairing handling
fn fairing_route() -> rocket::Route {
rocket::Route::new(http::Method::Get, "/<status>", fairing_error_route)
}
/// Modifies a `Request` to route to Fairing error handler
fn route_to_fairing_error_handler(options: &Cors, status: u16, request: &mut Request) {
request.set_method(http::Method::Get);
request.set_uri(format!("{}/{}", options.fairing_route_base, status));
}
/// Inject `CorsInjectedHeader` into the request header
fn inject_request_header(
response: &CorsInjectedHeader,
request: &mut Request,
) -> Result<(), Error> {
let serialized = rmps::to_vec(response).map_err(Error::RmpSerializationError)?;
let base64 = base64::encode_config(&serialized, base64::URL_SAFE);
request.replace_header(Header::new(HEADER_NAME, base64));
Ok(())
}
/// Extract the injected `CorsInjectedHeader`
fn extract_request_header(request: &Request) -> Result<Option<CorsInjectedHeader>, Error> {
let header = match request.headers().get_one(HEADER_NAME) {
Some(header) => header,
None => return Ok(None),
};
let bytes = base64::decode_config(header, base64::URL_SAFE).map_err(
Error::Base64DecodeError,
)?;
let deserialized: CorsInjectedHeader = rmps::from_slice(&bytes).map_err(
Error::RmpDeserializationError,
)?;
Ok(Some(deserialized))
}
impl rocket::fairing::Fairing for Cors {
fn info(&self) -> rocket::fairing::Info {
rocket::fairing::Info {
name: "CORS",
kind: rocket::fairing::Kind::Attach | rocket::fairing::Kind::Request |
rocket::fairing::Kind::Response,
}
}
fn on_attach(&self, rocket: rocket::Rocket) -> Result<rocket::Rocket, rocket::Rocket> {
match self.validate() {
Ok(()) => {
Ok(rocket.mount(&self.fairing_route_base, vec![fairing_route()]))
}
Err(e) => {
error_!("Error attaching CORS fairing: {}", e);
Err(rocket)
}
}
}
fn on_request(&self, request: &mut Request, _: &rocket::Data) {
// Build and merge CORS response
// Type annotation is for sanity check
let cors_response = build_cors_response(self, request);
if let Err(ref err) = cors_response {
error_!("CORS Error: {}", err);
let status = err.status();
route_to_fairing_error_handler(self, status.code, request);
}
let cors_response = cors_response.map_err(|e| e.to_string());
if let Err(err) = inject_request_header(&cors_response, request) {
// Internal server error -- probably a bug
error_!(
"Fairing had an error injecting headers: {}\nThis might be a bug. Please report.",
err
);
let status = err.status();
route_to_fairing_error_handler(self, status.code, request);
}
}
fn on_response(&self, request: &Request, response: &mut rocket::Response) {
let header = match extract_request_header(request) {
Err(err) => {
// We have a bug
error_!(
"Fairing had an error extracting headers: {}\nThis might be a bug. \
Please report.",
err
);
// Let's respond with an internal server error
response.set_status(Status::InternalServerError);
let _ = response.take_body();
return;
}
Ok(header) => header,
};
let header = match header {
None => {
// This is not a CORS request
return;
}
Some(header) => header,
};
match header {
Err(_) => {
// We have dealt with this already
}
Ok(cors_response) => {
cors_response.merge(response);
// If this was an OPTIONS request and no route can be found, we should turn this
// into a HTTP 204 with no content body.
// This allows the user to not have to specify an OPTIONS route for everything.
//
// TODO: Is there anyway we can make this smarter? Only modify status codes for
// requests where an actual route exist?
if request.method() == http::Method::Options && request.route().is_none() {
response.set_status(Status::NoContent);
let _ = response.take_body();
}
}
}
}
}

View File

@ -93,8 +93,10 @@
#![cfg_attr(test, plugin(rocket_codegen))]
#![doc(test(attr(allow(unused_variables), deny(warnings))))]
extern crate base64;
#[macro_use]
extern crate log;
extern crate rmp_serde as rmps;
#[macro_use]
extern crate rocket;
extern crate serde;
@ -115,6 +117,7 @@ extern crate serde_json;
#[cfg(test)]
#[macro_use]
mod test_macros;
mod fairing;
pub mod headers;
@ -130,7 +133,6 @@ use std::str::FromStr;
use rocket::{Outcome, State};
use rocket::http::{self, Status};
use rocket::fairing;
use rocket::request::{Request, FromRequest};
use rocket::response;
use serde::{Serialize, Deserialize};
@ -171,6 +173,12 @@ pub enum Error {
///
/// This is a misconfiguration. Use `Rocket::manage` to add a CORS options to managed state.
MissingCorsInRocketState,
/// An internal Base64 Decoding Error. This is likely a bug.
Base64DecodeError(base64::DecodeError),
/// An internal error serializing Rust MessagePack. This is likely a bug.
RmpSerializationError(rmps::encode::Error),
/// An internal error deserializing Rust MessagePack. This is likely a bug.
RmpDeserializationError(rmps::decode::Error),
}
impl Error {
@ -179,7 +187,8 @@ impl Error {
Error::MissingOrigin | Error::OriginNotAllowed | Error::MethodNotAllowed |
Error::HeadersNotAllowed => Status::Forbidden,
Error::CredentialsWithWildcardOrigin |
Error::MissingCorsInRocketState => Status::InternalServerError,
Error::MissingCorsInRocketState |
Error::Base64DecodeError(_) => Status::InternalServerError,
_ => Status::BadRequest,
}
}
@ -211,12 +220,24 @@ impl error::Error for Error {
Error::MissingCorsInRocketState => {
"A CORS Request Guard was used, but no CORS Options was available in Rocket's state"
}
Error::Base64DecodeError(_) => {
"An internal Base64 Decoding Error. This is likely a bug."
}
Error::RmpSerializationError(_) => {
"An internal error serializing Rust MessagePack. This is likely a bug."
}
Error::RmpDeserializationError(_) => {
"An internal error deserializing Rust MessagePack. This is likely a bug."
}
}
}
fn cause(&self) -> Option<&error::Error> {
match *self {
Error::BadOrigin(ref e) => Some(e),
Error::Base64DecodeError(ref e) => Some(e),
Error::RmpSerializationError(ref e) => Some(e),
Error::RmpDeserializationError(ref e) => Some(e),
_ => Some(self),
}
}
@ -227,6 +248,9 @@ impl fmt::Display for Error {
match *self {
Error::BadOrigin(ref e) => fmt::Display::fmt(e, f),
Error::BadRequestMethod(ref e) => fmt::Debug::fmt(e, f),
Error::Base64DecodeError(ref e) => fmt::Display::fmt(e, f),
Error::RmpSerializationError(ref e) => fmt::Display::fmt(e, f),
Error::RmpDeserializationError(ref e) => fmt::Display::fmt(e, f),
_ => write!(f, "{}", error::Error::description(self)),
}
}
@ -520,88 +544,6 @@ impl Cors {
Ok(())
}
/// Create a new `Route` for Fairing handling
fn fairing_route(&self) -> rocket::Route {
rocket::Route::new(http::Method::Get, "/<status>", fairing_error_route)
}
/// Modifies a `Request` to route to Fairing error handler
fn route_to_fairing_error_handler(&self, status: u16, request: &mut Request) {
request.set_method(http::Method::Get);
request.set_uri(format!("{}/{}", self.fairing_route_base, status));
}
}
impl fairing::Fairing for Cors {
fn info(&self) -> fairing::Info {
fairing::Info {
name: "CORS",
kind: fairing::Kind::Attach | fairing::Kind::Request | fairing::Kind::Response,
}
}
fn on_attach(&self, rocket: rocket::Rocket) -> Result<rocket::Rocket, rocket::Rocket> {
match self.validate() {
Ok(()) => {
Ok(rocket.mount(&self.fairing_route_base, vec![self.fairing_route()]))
}
Err(e) => {
error_!("Error attaching CORS fairing: {}", e);
Err(rocket)
}
}
}
fn on_request(&self, request: &mut Request, _: &rocket::Data) {
// Build and merge CORS response
match build_cors_response(self, request) {
Err(err) => {
error_!("CORS Error: {}", err);
let status = err.status();
self.route_to_fairing_error_handler(status.code, request);
}
Ok(cors_response) => {
// TODO: How to pass response downstream?
let _ = cors_response;
}
};
}
fn on_response(&self, request: &Request, response: &mut rocket::Response) {
// Build and merge CORS response
match build_cors_response(self, request) {
Err(_) => {
// We have dealt with this already
}
Ok(cors_response) => {
cors_response.merge(response);
// If this was an OPTIONS request and no route can be found, we should turn this
// into a HTTP 204 with no content body.
// This allows the user to not have to specify an OPTIONS route for everything.
//
// TODO: Is there anyway we can make this smarter? Only modify status codes for
// requests where an actual route exist?
if request.method() == http::Method::Options && request.route().is_none() {
response.set_status(Status::NoContent);
let _ = response.take_body();
}
}
};
}
}
/// Route for Fairing error handling
fn fairing_error_route<'r>(request: &'r Request, _: rocket::Data) -> rocket::handler::Outcome<'r> {
let status = request.get_param::<u16>(0).unwrap_or_else(|e| {
error_!("Fairing Error Handling Route error: {:?}", e);
500
});
let status = Status::from_code(status).unwrap_or_else(|| Status::InternalServerError);
Outcome::Failure(status)
}
/// A CORS Response which provides the following CORS headers:
@ -613,7 +555,7 @@ fn fairing_error_route<'r>(request: &'r Request, _: rocket::Data) -> rocket::han
/// - `Access-Control-Allow-Methods`
/// - `Access-Control-Allow-Headers`
/// - `Vary`
#[derive(Eq, PartialEq, Debug)]
#[derive(Serialize, Deserialize, Eq, PartialEq, Debug)]
struct Response {
allow_origin: Option<AllOrSome<String>>,
allow_methods: HashSet<Method>,