diff --git a/Cargo.toml b/Cargo.toml index d7810ff..21928f9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,10 +15,8 @@ categories = ["web-programming"] travis-ci = { repository = "lawliet89/rocket_cors" } [dependencies] -base64 = "0.6.0" log = "0.3" rocket = "0.3" -rmp-serde = "0.13.4" serde = "1.0" serde_derive = "1.0" unicase = "2.0" diff --git a/src/fairing.rs b/src/fairing.rs index 212c888..5e85d87 100644 --- a/src/fairing.rs +++ b/src/fairing.rs @@ -1,16 +1,8 @@ //! Fairing implementation -use base64; use rocket::{self, Request, Outcome}; -use rocket::http::{self, Status, Header}; -use rmps; +use rocket::http::{self, Status}; -use {Cors, Response, Error, build_cors_response}; - -static HEADER_NAME: &'static str = "ROCKET-CORS"; - -/// Type of the Request header the `on_request` fairing handler will inject into requests -/// for `on_response` to deal with -pub(crate) type CorsInjectedHeader = Result; +use {Cors, build_cors_response}; /// Route for Fairing error handling pub(crate) fn fairing_error_route<'r>( @@ -36,33 +28,6 @@ fn route_to_fairing_error_handler(options: &Cors, status: u16, request: &mut Req request.set_uri(format!("{}/{}", options.fairing_route_base, status)); } -/// Inject `CorsInjectedHeader` into the request header -fn inject_request_header( - response: &CorsInjectedHeader, - request: &mut Request, -) -> Result<(), Error> { - let serialized = rmps::to_vec(response).map_err(Error::RmpSerializationError)?; - let base64 = base64::encode_config(&serialized, base64::URL_SAFE); - request.replace_header(Header::new(HEADER_NAME, base64)); - Ok(()) -} - -/// Extract the injected `CorsInjectedHeader` -fn extract_request_header(request: &Request) -> Result, Error> { - let header = match request.headers().get_one(HEADER_NAME) { - Some(header) => header, - None => return Ok(None), - }; - - let bytes = base64::decode_config(header, base64::URL_SAFE).map_err( - Error::Base64DecodeError, - )?; - let deserialized: CorsInjectedHeader = rmps::from_slice(&bytes).map_err( - Error::RmpDeserializationError, - )?; - Ok(Some(deserialized)) -} - impl rocket::fairing::Fairing for Cors { fn info(&self) -> rocket::fairing::Info { rocket::fairing::Info { @@ -93,47 +58,11 @@ impl rocket::fairing::Fairing for Cors { let status = err.status(); route_to_fairing_error_handler(self, status.code, request); } - - let cors_response = cors_response.map_err(|e| e.to_string()); - - if let Err(err) = inject_request_header(&cors_response, request) { - // Internal server error -- probably a bug - error_!( - "Fairing had an error injecting headers: {}\nThis might be a bug. Please report.", - err - ); - let status = err.status(); - route_to_fairing_error_handler(self, status.code, request); - } } fn on_response(&self, request: &Request, response: &mut rocket::Response) { - let header = match extract_request_header(request) { - Err(err) => { - // We have a bug - error_!( - "Fairing had an error extracting headers: {}\nThis might be a bug. \ - Please report.", - err - ); - - // Let's respond with an internal server error - response.set_status(Status::InternalServerError); - let _ = response.take_body(); - return; - } - Ok(header) => header, - }; - - let header = match header { - None => { - // This is not a CORS request - return; - } - Some(header) => header, - }; - - match header { + // Rebuild the response + match build_cors_response(self, request) { Err(_) => { // We have dealt with this already } diff --git a/src/lib.rs b/src/lib.rs index a4309ef..ec1e799 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -93,10 +93,8 @@ #![cfg_attr(test, plugin(rocket_codegen))] #![doc(test(attr(allow(unused_variables), deny(warnings))))] -extern crate base64; #[macro_use] extern crate log; -extern crate rmp_serde as rmps; #[macro_use] extern crate rocket; extern crate serde; @@ -173,12 +171,6 @@ pub enum Error { /// /// This is a misconfiguration. Use `Rocket::manage` to add a CORS options to managed state. MissingCorsInRocketState, - /// An internal Base64 Decoding Error. This is likely a bug. - Base64DecodeError(base64::DecodeError), - /// An internal error serializing Rust MessagePack. This is likely a bug. - RmpSerializationError(rmps::encode::Error), - /// An internal error deserializing Rust MessagePack. This is likely a bug. - RmpDeserializationError(rmps::decode::Error), } impl Error { @@ -187,8 +179,7 @@ impl Error { Error::MissingOrigin | Error::OriginNotAllowed | Error::MethodNotAllowed | Error::HeadersNotAllowed => Status::Forbidden, Error::CredentialsWithWildcardOrigin | - Error::MissingCorsInRocketState | - Error::Base64DecodeError(_) => Status::InternalServerError, + Error::MissingCorsInRocketState => Status::InternalServerError, _ => Status::BadRequest, } } @@ -220,24 +211,12 @@ impl error::Error for Error { Error::MissingCorsInRocketState => { "A CORS Request Guard was used, but no CORS Options was available in Rocket's state" } - Error::Base64DecodeError(_) => { - "An internal Base64 Decoding Error. This is likely a bug." - } - Error::RmpSerializationError(_) => { - "An internal error serializing Rust MessagePack. This is likely a bug." - } - Error::RmpDeserializationError(_) => { - "An internal error deserializing Rust MessagePack. This is likely a bug." - } } } fn cause(&self) -> Option<&error::Error> { match *self { Error::BadOrigin(ref e) => Some(e), - Error::Base64DecodeError(ref e) => Some(e), - Error::RmpSerializationError(ref e) => Some(e), - Error::RmpDeserializationError(ref e) => Some(e), _ => Some(self), } } @@ -248,9 +227,6 @@ impl fmt::Display for Error { match *self { Error::BadOrigin(ref e) => fmt::Display::fmt(e, f), Error::BadRequestMethod(ref e) => fmt::Debug::fmt(e, f), - Error::Base64DecodeError(ref e) => fmt::Display::fmt(e, f), - Error::RmpSerializationError(ref e) => fmt::Display::fmt(e, f), - Error::RmpDeserializationError(ref e) => fmt::Display::fmt(e, f), _ => write!(f, "{}", error::Error::description(self)), } } @@ -398,7 +374,7 @@ impl<'de> Deserialize<'de> for Method { /// /// [`Default`](https://doc.rust-lang.org/std/default/trait.Default.html) is implemented for this /// struct. The default for each field is described in the docuementation for the field. -#[derive(Eq, PartialEq, Serialize, Deserialize, Clone, Debug)] +#[derive(Serialize, Deserialize, Eq, PartialEq, Clone, Debug)] pub struct Cors { /// Origins that are allowed to make requests. /// Will be verified against the `Origin` request header. @@ -555,7 +531,7 @@ impl Cors { /// - `Access-Control-Allow-Methods` /// - `Access-Control-Allow-Headers` /// - `Vary` -#[derive(Serialize, Deserialize, Eq, PartialEq, Debug)] +#[derive(Eq, PartialEq, Debug)] struct Response { allow_origin: Option>, allow_methods: HashSet,