diff --git a/Cargo.toml b/Cargo.toml index 21928f9..d7810ff 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,8 +15,10 @@ categories = ["web-programming"] travis-ci = { repository = "lawliet89/rocket_cors" } [dependencies] +base64 = "0.6.0" log = "0.3" rocket = "0.3" +rmp-serde = "0.13.4" serde = "1.0" serde_derive = "1.0" unicase = "2.0" diff --git a/src/fairing.rs b/src/fairing.rs new file mode 100644 index 0000000..212c888 --- /dev/null +++ b/src/fairing.rs @@ -0,0 +1,156 @@ +//! Fairing implementation +use base64; +use rocket::{self, Request, Outcome}; +use rocket::http::{self, Status, Header}; +use rmps; + +use {Cors, Response, Error, build_cors_response}; + +static HEADER_NAME: &'static str = "ROCKET-CORS"; + +/// Type of the Request header the `on_request` fairing handler will inject into requests +/// for `on_response` to deal with +pub(crate) type CorsInjectedHeader = Result; + +/// Route for Fairing error handling +pub(crate) fn fairing_error_route<'r>( + request: &'r Request, + _: rocket::Data, +) -> rocket::handler::Outcome<'r> { + let status = request.get_param::(0).unwrap_or_else(|e| { + error_!("Fairing Error Handling Route error: {:?}", e); + 500 + }); + let status = Status::from_code(status).unwrap_or_else(|| Status::InternalServerError); + Outcome::Failure(status) +} + +/// Create a new `Route` for Fairing handling +fn fairing_route() -> rocket::Route { + rocket::Route::new(http::Method::Get, "/", fairing_error_route) +} + +/// Modifies a `Request` to route to Fairing error handler +fn route_to_fairing_error_handler(options: &Cors, status: u16, request: &mut Request) { + request.set_method(http::Method::Get); + request.set_uri(format!("{}/{}", options.fairing_route_base, status)); +} + +/// Inject `CorsInjectedHeader` into the request header +fn inject_request_header( + response: &CorsInjectedHeader, + request: &mut Request, +) -> Result<(), Error> { + let serialized = rmps::to_vec(response).map_err(Error::RmpSerializationError)?; + let base64 = base64::encode_config(&serialized, base64::URL_SAFE); + request.replace_header(Header::new(HEADER_NAME, base64)); + Ok(()) +} + +/// Extract the injected `CorsInjectedHeader` +fn extract_request_header(request: &Request) -> Result, Error> { + let header = match request.headers().get_one(HEADER_NAME) { + Some(header) => header, + None => return Ok(None), + }; + + let bytes = base64::decode_config(header, base64::URL_SAFE).map_err( + Error::Base64DecodeError, + )?; + let deserialized: CorsInjectedHeader = rmps::from_slice(&bytes).map_err( + Error::RmpDeserializationError, + )?; + Ok(Some(deserialized)) +} + +impl rocket::fairing::Fairing for Cors { + fn info(&self) -> rocket::fairing::Info { + rocket::fairing::Info { + name: "CORS", + kind: rocket::fairing::Kind::Attach | rocket::fairing::Kind::Request | + rocket::fairing::Kind::Response, + } + } + + fn on_attach(&self, rocket: rocket::Rocket) -> Result { + match self.validate() { + Ok(()) => { + Ok(rocket.mount(&self.fairing_route_base, vec![fairing_route()])) + } + Err(e) => { + error_!("Error attaching CORS fairing: {}", e); + Err(rocket) + } + } + } + + fn on_request(&self, request: &mut Request, _: &rocket::Data) { + // Build and merge CORS response + // Type annotation is for sanity check + let cors_response = build_cors_response(self, request); + if let Err(ref err) = cors_response { + error_!("CORS Error: {}", err); + let status = err.status(); + route_to_fairing_error_handler(self, status.code, request); + } + + let cors_response = cors_response.map_err(|e| e.to_string()); + + if let Err(err) = inject_request_header(&cors_response, request) { + // Internal server error -- probably a bug + error_!( + "Fairing had an error injecting headers: {}\nThis might be a bug. Please report.", + err + ); + let status = err.status(); + route_to_fairing_error_handler(self, status.code, request); + } + } + + fn on_response(&self, request: &Request, response: &mut rocket::Response) { + let header = match extract_request_header(request) { + Err(err) => { + // We have a bug + error_!( + "Fairing had an error extracting headers: {}\nThis might be a bug. \ + Please report.", + err + ); + + // Let's respond with an internal server error + response.set_status(Status::InternalServerError); + let _ = response.take_body(); + return; + } + Ok(header) => header, + }; + + let header = match header { + None => { + // This is not a CORS request + return; + } + Some(header) => header, + }; + + match header { + Err(_) => { + // We have dealt with this already + } + Ok(cors_response) => { + cors_response.merge(response); + + // If this was an OPTIONS request and no route can be found, we should turn this + // into a HTTP 204 with no content body. + // This allows the user to not have to specify an OPTIONS route for everything. + // + // TODO: Is there anyway we can make this smarter? Only modify status codes for + // requests where an actual route exist? + if request.method() == http::Method::Options && request.route().is_none() { + response.set_status(Status::NoContent); + let _ = response.take_body(); + } + } + } + } +} diff --git a/src/lib.rs b/src/lib.rs index 3fc236e..a4309ef 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -93,8 +93,10 @@ #![cfg_attr(test, plugin(rocket_codegen))] #![doc(test(attr(allow(unused_variables), deny(warnings))))] +extern crate base64; #[macro_use] extern crate log; +extern crate rmp_serde as rmps; #[macro_use] extern crate rocket; extern crate serde; @@ -115,6 +117,7 @@ extern crate serde_json; #[cfg(test)] #[macro_use] mod test_macros; +mod fairing; pub mod headers; @@ -130,7 +133,6 @@ use std::str::FromStr; use rocket::{Outcome, State}; use rocket::http::{self, Status}; -use rocket::fairing; use rocket::request::{Request, FromRequest}; use rocket::response; use serde::{Serialize, Deserialize}; @@ -171,6 +173,12 @@ pub enum Error { /// /// This is a misconfiguration. Use `Rocket::manage` to add a CORS options to managed state. MissingCorsInRocketState, + /// An internal Base64 Decoding Error. This is likely a bug. + Base64DecodeError(base64::DecodeError), + /// An internal error serializing Rust MessagePack. This is likely a bug. + RmpSerializationError(rmps::encode::Error), + /// An internal error deserializing Rust MessagePack. This is likely a bug. + RmpDeserializationError(rmps::decode::Error), } impl Error { @@ -179,7 +187,8 @@ impl Error { Error::MissingOrigin | Error::OriginNotAllowed | Error::MethodNotAllowed | Error::HeadersNotAllowed => Status::Forbidden, Error::CredentialsWithWildcardOrigin | - Error::MissingCorsInRocketState => Status::InternalServerError, + Error::MissingCorsInRocketState | + Error::Base64DecodeError(_) => Status::InternalServerError, _ => Status::BadRequest, } } @@ -211,12 +220,24 @@ impl error::Error for Error { Error::MissingCorsInRocketState => { "A CORS Request Guard was used, but no CORS Options was available in Rocket's state" } + Error::Base64DecodeError(_) => { + "An internal Base64 Decoding Error. This is likely a bug." + } + Error::RmpSerializationError(_) => { + "An internal error serializing Rust MessagePack. This is likely a bug." + } + Error::RmpDeserializationError(_) => { + "An internal error deserializing Rust MessagePack. This is likely a bug." + } } } fn cause(&self) -> Option<&error::Error> { match *self { Error::BadOrigin(ref e) => Some(e), + Error::Base64DecodeError(ref e) => Some(e), + Error::RmpSerializationError(ref e) => Some(e), + Error::RmpDeserializationError(ref e) => Some(e), _ => Some(self), } } @@ -227,6 +248,9 @@ impl fmt::Display for Error { match *self { Error::BadOrigin(ref e) => fmt::Display::fmt(e, f), Error::BadRequestMethod(ref e) => fmt::Debug::fmt(e, f), + Error::Base64DecodeError(ref e) => fmt::Display::fmt(e, f), + Error::RmpSerializationError(ref e) => fmt::Display::fmt(e, f), + Error::RmpDeserializationError(ref e) => fmt::Display::fmt(e, f), _ => write!(f, "{}", error::Error::description(self)), } } @@ -520,88 +544,6 @@ impl Cors { Ok(()) } - - /// Create a new `Route` for Fairing handling - fn fairing_route(&self) -> rocket::Route { - rocket::Route::new(http::Method::Get, "/", fairing_error_route) - } - - /// Modifies a `Request` to route to Fairing error handler - fn route_to_fairing_error_handler(&self, status: u16, request: &mut Request) { - request.set_method(http::Method::Get); - request.set_uri(format!("{}/{}", self.fairing_route_base, status)); - } -} - -impl fairing::Fairing for Cors { - fn info(&self) -> fairing::Info { - fairing::Info { - name: "CORS", - kind: fairing::Kind::Attach | fairing::Kind::Request | fairing::Kind::Response, - } - } - - fn on_attach(&self, rocket: rocket::Rocket) -> Result { - match self.validate() { - Ok(()) => { - Ok(rocket.mount(&self.fairing_route_base, vec![self.fairing_route()])) - } - Err(e) => { - error_!("Error attaching CORS fairing: {}", e); - Err(rocket) - } - } - } - - fn on_request(&self, request: &mut Request, _: &rocket::Data) { - // Build and merge CORS response - match build_cors_response(self, request) { - Err(err) => { - error_!("CORS Error: {}", err); - let status = err.status(); - self.route_to_fairing_error_handler(status.code, request); - } - Ok(cors_response) => { - // TODO: How to pass response downstream? - let _ = cors_response; - } - }; - } - - fn on_response(&self, request: &Request, response: &mut rocket::Response) { - // Build and merge CORS response - match build_cors_response(self, request) { - Err(_) => { - // We have dealt with this already - } - Ok(cors_response) => { - cors_response.merge(response); - - // If this was an OPTIONS request and no route can be found, we should turn this - // into a HTTP 204 with no content body. - // This allows the user to not have to specify an OPTIONS route for everything. - // - // TODO: Is there anyway we can make this smarter? Only modify status codes for - // requests where an actual route exist? - if request.method() == http::Method::Options && request.route().is_none() { - response.set_status(Status::NoContent); - let _ = response.take_body(); - } - } - }; - - - } -} - -/// Route for Fairing error handling -fn fairing_error_route<'r>(request: &'r Request, _: rocket::Data) -> rocket::handler::Outcome<'r> { - let status = request.get_param::(0).unwrap_or_else(|e| { - error_!("Fairing Error Handling Route error: {:?}", e); - 500 - }); - let status = Status::from_code(status).unwrap_or_else(|| Status::InternalServerError); - Outcome::Failure(status) } /// A CORS Response which provides the following CORS headers: @@ -613,7 +555,7 @@ fn fairing_error_route<'r>(request: &'r Request, _: rocket::Data) -> rocket::han /// - `Access-Control-Allow-Methods` /// - `Access-Control-Allow-Headers` /// - `Vary` -#[derive(Eq, PartialEq, Debug)] +#[derive(Serialize, Deserialize, Eq, PartialEq, Debug)] struct Response { allow_origin: Option>, allow_methods: HashSet, diff --git a/tests/fairings.rs b/tests/fairing.rs similarity index 100% rename from tests/fairings.rs rename to tests/fairing.rs