diff --git a/src/lib.rs b/src/lib.rs index d98845f..3229ca7 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -161,6 +161,22 @@ pub enum Error { /// /// This is a misconfiguration. Check the docuemntation for `Cors`. CredentialsWithWildcardOrigin, + /// A CORS Request Guard was used, but no CORS Options was available in Rocket's state + /// + /// This is a misconfiguration. Use `Rocket::manage` to add a CORS options to managed state. + MissingCorsInRocketState, +} + +impl Error { + fn status(&self) -> Status { + match *self { + Error::MissingOrigin | Error::OriginNotAllowed | Error::MethodNotAllowed | + Error::HeadersNotAllowed => Status::Forbidden, + Error::CredentialsWithWildcardOrigin | + Error::MissingCorsInRocketState => Status::InternalServerError, + _ => Status::BadRequest, + } + } } impl error::Error for Error { @@ -186,6 +202,9 @@ impl error::Error for Error { "Credentials are allowed, but the Origin is set to \"*\". \ This is not allowed by W3C" } + Error::MissingCorsInRocketState => { + "A CORS Request Guard was used, but no CORS Options was available in Rocket's state" + } } } @@ -211,19 +230,14 @@ impl fmt::Display for Error { impl<'r> response::Responder<'r> for Error { fn respond_to(self, _: &Request) -> Result, Status> { error_!("CORS Error: {:?}", self); - Err(match self { - Error::MissingOrigin | Error::OriginNotAllowed | Error::MethodNotAllowed | - Error::HeadersNotAllowed => Status::Forbidden, - Error::CredentialsWithWildcardOrigin => Status::InternalServerError, - _ => Status::BadRequest, - }) + Err(self.status()) } } /// An enum signifying that some of type T is allowed, or `All` (everything is allowed). /// /// `Default` is implemented for this enum and is `All`. -#[derive(Clone, Debug, Serialize, Deserialize)] +#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)] #[serde(untagged)] pub enum AllOrSome { /// Everything is allowed. Usually equivalent to the "*" value. @@ -274,9 +288,9 @@ impl AllOrSome> { } } -/// Responder generator and [Fairing](https://rocket.rs/guide/fairings/) for CORS +/// Response generator and [Fairing](https://rocket.rs/guide/fairings/) for CORS /// -/// This struct can be used as Fairing for Rocket, or as an ad-hoc responder for any CORS requests. +/// This struct can be as Fairing or in an ad-hoc manner to generate CORS response. /// /// You create a new copy of this struct by defining the configurations in the fields below. /// This struct can also be deserialized by serde. @@ -381,15 +395,6 @@ impl Default for Cors { } impl Cors { - /// Wrap any `Rocket::Response` and respond with CORS headers. - /// This is only used for ad-hoc route CORS response - fn respond<'a, 'r: 'a, R: response::Responder<'r>>( - &'a self, - responder: R, - ) -> Responder<'a, 'r, R> { - Responder::new(responder, self) - } - fn default_allowed_methods() -> HashSet { vec![ Method::Get, @@ -403,6 +408,17 @@ impl Cors { .collect() } + /// Build a CORS `Response` to an incoming request. + /// + /// The `Response` should be merged with an + /// existing `Rocket::Response` or `rocket::response::Responder`. + /// + /// This is only used for ad-hoc route CORS response + pub fn build<'a, 'r>(&'a self, request: &'a Request<'r>) -> Result { + build_cors_response(self, request) + } + + /// Validates if any of the settings are disallowed or incorrect /// /// This is run during initial Fairing attachment @@ -473,61 +489,6 @@ impl fairing::Fairing for Cors { } } -/// A CORS [Responder](https://rocket.rs/guide/responses/#responder) -/// which will inspect the incoming requests and respond accordingly. -/// -/// If the wrapped `Responder` already has the `Access-Control-Allow-Origin` header set, -/// this responder will leave the response untouched. -/// This allows for chaining of several CORS responders. -/// -/// Otherwise, the following headers may be set for the final Rocket `Response`, overwriting any -/// existing headers defined: -/// -/// - `Access-Control-Allow-Origin` -/// - `Access-Control-Expose-Headers` -/// - `Access-Control-Max-Age` -/// - `Access-Control-Allow-Credentials` -/// - `Access-Control-Allow-Methods` -/// - `Access-Control-Allow-Headers` -/// - `Vary` -#[derive(Debug)] -pub struct Responder<'a, 'r: 'a, R> { - responder: R, - options: &'a Cors, - marker: PhantomData>, -} - -impl<'a, 'r: 'a, R: response::Responder<'r>> Responder<'a, 'r, R> { - fn new(responder: R, options: &'a Cors) -> Self { - Self { - responder, - options, - marker: PhantomData, - } - } - - /// Respond to a request - fn respond(self, request: &Request) -> response::Result<'r> { - let mut response = self.responder.respond_to(request)?; // handle status errors? - - match build_cors_response(self.options, request) { - Ok(cors_response) => { - cors_response.merge(&mut response); - Ok(response) - }, - Err(e) => response::Responder::respond_to(e, request), - } - - } -} - -impl<'a, 'r: 'a, R: response::Responder<'r>> response::Responder<'r> for Responder<'a, 'r, R> { - fn respond_to(self, request: &Request) -> response::Result<'r> { - self.respond(request) - } -} - - /// A CORS Response which provides the following CORS headers: /// /// - `Access-Control-Allow-Origin` @@ -537,8 +498,8 @@ impl<'a, 'r: 'a, R: response::Responder<'r>> response::Responder<'r> for Respond /// - `Access-Control-Allow-Methods` /// - `Access-Control-Allow-Headers` /// - `Vary` -#[derive(Debug)] -struct Response { +#[derive(Eq, PartialEq, Debug)] +pub struct Response { allow_origin: Option>, allow_methods: HashSet, allow_headers: HeaderFieldNamesSet, @@ -549,7 +510,7 @@ struct Response { } impl Response { - /// Consumes the responder and return an empty `Response` + /// Create an empty `Response` fn new() -> Self { Self { allow_origin: None, @@ -609,11 +570,17 @@ impl Response { self } - /// Builds a `rocket::Response` from this struct based off some base `rocket::Response` + /// Consumes the `Response` and return a `Responder` that wraps a + /// provided `rocket:response::Responder` with CORS headers + pub fn responder<'r, R: response::Responder<'r>>(self, responder: R) -> Responder<'r, R> { + Responder::new(responder, self) + } + + /// Merge a `rocket::Response` with this CORS response. This is usually used in the final step + /// of a route to return a value for the route. /// /// This will overwrite any existing CORS headers - #[cfg(test)] - fn build<'r>(&self, base: response::Response<'r>) -> response::Response<'r> { + pub fn respond<'r>(&self, base: response::Response<'r>) -> response::Response<'r> { let mut response = response::Response::build_from(base).finalize(); self.merge(&mut response); response @@ -691,27 +658,79 @@ impl Response { response.remove_header("Vary"); } } + + /// Validate and create a new CORS Response from a request and settings + pub fn build_cors_response<'a, 'r>( + options: &'a Cors, + request: &'a Request<'r>, + ) -> Result { + build_cors_response(options, request) + } } -/// Ad-hoc per route CORS response to requests +impl<'a, 'r> FromRequest<'a, 'r> for Response { + type Error = Error; + + fn from_request(request: &'a Request<'r>) -> rocket::request::Outcome { + let options = match request.guard::>() { + Outcome::Success(options) => options, + _ => { + let error = Error::MissingCorsInRocketState; + return Outcome::Failure((error.status(), error)); + } + }; + + match Self::build_cors_response(&options, request) { + Ok(response) => Outcome::Success(response), + Err(error) => Outcome::Failure((error.status(), error)), + } + } +} + +/// A [`Responder`](https://rocket.rs/guide/responses/#responder) which will simply wraps another +/// `Responder` with CORS headers. /// -/// Note: If you use this, the lifetime parameter `'r` of your `rocket:::response::Responder<'r>` -/// CANNOT be `'static`. This is because the code generated by Rocket will implicitly try to -/// to restrain the `Request` object passed to the route to `&'static Request`, and it is not -/// possible to have such a reference. -/// See [this PR on Rocket](https://github.com/SergioBenitez/Rocket/pull/345). -pub fn respond<'a, 'r: 'a, R: response::Responder<'r>>( - options: State<'a, Cors>, +/// The following CORS headers will be overwritten: +/// +/// - `Access-Control-Allow-Origin` +/// - `Access-Control-Expose-Headers` +/// - `Access-Control-Max-Age` +/// - `Access-Control-Allow-Credentials` +/// - `Access-Control-Allow-Methods` +/// - `Access-Control-Allow-Headers` +/// - `Vary` +#[derive(Debug)] +pub struct Responder<'r, R> { responder: R, -) -> Responder<'a, 'r, R> { - options.inner().respond(responder) + cors_response: Response, + marker: PhantomData>, +} + +impl<'r, R: response::Responder<'r>> Responder<'r, R> { + fn new(responder: R, cors_response: Response) -> Self { + Self { + responder, + cors_response, + marker: PhantomData, + } + } + + /// Respond to a request + fn respond(self, request: &Request) -> response::Result<'r> { + let mut response = self.responder.respond_to(request)?; // handle status errors? + self.cors_response.merge(&mut response); + Ok(response) + } +} + +impl<'r, R: response::Responder<'r>> response::Responder<'r> for Responder<'r, R> { + fn respond_to(self, request: &Request) -> response::Result<'r> { + self.respond(request) + } } /// Build a CORS response and merge with an existing `rocket::Response` for the request -fn build_cors_response( - options: &Cors, - request: &Request, -) -> Result { +fn build_cors_response(options: &Cors, request: &Request) -> Result { // Existing CORS response? // if has_allow_origin(response) { // return Ok(()); @@ -1072,7 +1091,7 @@ mod tests { let response = response.exposed_headers(&headers); // Build response and check built response header - let response = response.build(response::Response::new()); + let response = response.respond(response::Response::new()); let actual_header: Vec<_> = response .headers() .get("Access-Control-Expose-Headers") @@ -1096,7 +1115,7 @@ mod tests { // Build response and check built response header let expected_header = vec!["42"]; - let response = response.build(response::Response::new()); + let response = response.respond(response::Response::new()); let actual_header: Vec<_> = response.headers().get("Access-Control-Max-Age").collect(); assert_eq!(expected_header, actual_header); } @@ -1109,7 +1128,7 @@ mod tests { let response = response.max_age(None); // Build response and check built response header - let response = response.build(response::Response::new()); + let response = response.respond(response::Response::new()); assert!( response .headers() @@ -1198,7 +1217,7 @@ mod tests { #[test] fn response_does_not_build_if_origin_is_not_set() { let response = Response::new(); - let response = response.build(response::Response::new()); + let response = response.respond(response::Response::new()); let headers: Vec<_> = response.headers().iter().collect(); assert_eq!(headers.len(), 0); @@ -1217,7 +1236,7 @@ mod tests { let response = Response::new(); let response = response.origin("https://www.example.com", false); - let response = response.build(original); + let response = response.respond(original); // Check CORS header let expected_header = vec!["https://www.example.com"]; let actual_header: Vec<_> = response diff --git a/tests/routes.rs b/tests/ad_hoc.rs similarity index 92% rename from tests/routes.rs rename to tests/ad_hoc.rs index 997382d..dae5f93 100644 --- a/tests/routes.rs +++ b/tests/ad_hoc.rs @@ -4,34 +4,33 @@ #![plugin(rocket_codegen)] extern crate hyper; extern crate rocket; -extern crate rocket_cors; +extern crate rocket_cors as cors; use std::str::FromStr; -use rocket::State; use rocket::http::Method; use rocket::http::{Header, Status}; use rocket::local::Client; -use rocket_cors::*; #[options("/")] -fn cors_options(options: State) -> Responder<&str> { - rocket_cors::respond(options, "") +fn cors_options<'a>(cors: cors::Response) -> cors::Responder<'a, &'a str> { + cors.responder("") } #[get("/")] -fn cors(options: State) -> Responder<&str> { - rocket_cors::respond(options, "Hello CORS") +fn cors<'a>(cors: cors::Response) -> cors::Responder<'a, &'a str> { + cors.responder("Hello CORS") } -fn make_cors_options() -> Cors { - let (allowed_origins, failed_origins) = AllOrSome::new_from_str_list(&["https://www.acme.com"]); +fn make_cors_options() -> cors::Cors { + let (allowed_origins, failed_origins) = + cors::AllOrSome::new_from_str_list(&["https://www.acme.com"]); assert!(failed_origins.is_empty()); - Cors { + cors::Cors { allowed_origins: allowed_origins, allowed_methods: [Method::Get].iter().cloned().collect(), - allowed_headers: AllOrSome::Some( + allowed_headers: cors::AllOrSome::Some( ["Authorization"] .into_iter() .map(|s| s.to_string().into()) @@ -44,12 +43,13 @@ fn make_cors_options() -> Cors { #[test] fn smoke_test() { - let (allowed_origins, failed_origins) = AllOrSome::new_from_str_list(&["https://www.acme.com"]); + let (allowed_origins, failed_origins) = + cors::AllOrSome::new_from_str_list(&["https://www.acme.com"]); assert!(failed_origins.is_empty()); - let cors_options = rocket_cors::Cors { + let cors_options = cors::Cors { allowed_origins: allowed_origins, allowed_methods: [Method::Get].iter().cloned().collect(), - allowed_headers: AllOrSome::Some( + allowed_headers: cors::AllOrSome::Some( ["Authorization"] .iter() .map(|s| s.to_string().into())