diff --git a/Cargo.toml b/Cargo.toml index 7a38136..0bb6cd3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,7 +19,8 @@ log = "0.3" rocket = "0.3" serde = "1.0" serde_derive = "1.0" -unicase="1.4" +unicase = "2.0" +unicase_serde = "0.1.0" url = "1.5.1" url_serde = "0.2.0" @@ -31,3 +32,4 @@ version_check = "0.1" hyper = "0.10" rocket_codegen = "0.3" serde_json = "1.0" +serde_test = "1.0" diff --git a/src/headers.rs b/src/headers.rs index 0ac7d1e..b85e730 100644 --- a/src/headers.rs +++ b/src/headers.rs @@ -6,14 +6,56 @@ use std::ops::Deref; use std::str::FromStr; use rocket::{self, Outcome}; -use rocket::http::{Method, Status}; +use rocket::http::Status; use rocket::request::{self, FromRequest}; use unicase::UniCase; +use unicase_serde; use url; use url_serde; -pub(crate) type HeaderFieldName = UniCase; -pub(crate) type HeaderFieldNamesSet = HashSet; +/// A case insensitive header name +#[derive(Serialize, Deserialize, Eq, PartialEq, Clone, Debug, Hash)] +pub struct HeaderFieldName( + #[serde(with = "unicase_serde::unicase")] + UniCase +); + +impl Deref for HeaderFieldName { + type Target = String; + + fn deref(&self) -> &Self::Target { + self.0.deref() + } +} + +impl fmt::Display for HeaderFieldName { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + self.0.fmt(f) + } +} + +impl<'a> From<&'a str> for HeaderFieldName { + fn from(s: &'a str) -> Self { + HeaderFieldName(From::from(s)) + } +} + +impl<'a> From for HeaderFieldName { + fn from(s: String) -> Self { + HeaderFieldName(From::from(s)) + } +} + +impl FromStr for HeaderFieldName { + type Err = ::Err; + + fn from_str(s: &str) -> Result { + Ok(HeaderFieldName(FromStr::from_str(s)?)) + } +} + +/// A set of case insensitive header names +pub type HeaderFieldNamesSet = HashSet; /// A wrapped `url::Url` to allow for deserialization #[derive(Eq, PartialEq, Clone, Hash, Debug, Serialize, Deserialize)] @@ -72,13 +114,13 @@ pub type Origin = Url; /// You can use this as a rocket [Request Guard](https://rocket.rs/guide/requests/#request-guards) /// to ensure that the header is passed in correctly. #[derive(Debug)] -pub struct AccessControlRequestMethod(pub Method); +pub struct AccessControlRequestMethod(pub ::Method); impl FromStr for AccessControlRequestMethod { type Err = rocket::Error; fn from_str(method: &str) -> Result { - Ok(AccessControlRequestMethod(Method::from_str(method)?)) + Ok(AccessControlRequestMethod(::Method::from_str(method)?)) } } @@ -117,7 +159,7 @@ impl FromStr for AccessControlRequestHeaders { let set: HeaderFieldNamesSet = headers .split(',') - .map(|header| UniCase(header.trim().to_string())) + .map(|header| From::from(header.trim().to_string())) .collect(); Ok(AccessControlRequestHeaders(set)) } @@ -149,7 +191,6 @@ mod tests { use hyper; use rocket; use rocket::local::Client; - use rocket::http::Method; use super::*; @@ -194,11 +235,17 @@ mod tests { fn request_method_conversion() { let method = "POST"; let parsed_method = not_err!(AccessControlRequestMethod::from_str(method)); - assert_matches!(parsed_method, AccessControlRequestMethod(Method::Post)); + assert_matches!( + parsed_method, + AccessControlRequestMethod(::Method(rocket::http::Method::Post)) + ); let method = "options"; let parsed_method = not_err!(AccessControlRequestMethod::from_str(method)); - assert_matches!(parsed_method, AccessControlRequestMethod(Method::Options)); + assert_matches!( + parsed_method, + AccessControlRequestMethod(::Method(rocket::http::Method::Options)) + ); let method = "INVALID"; let _ = is_err!(AccessControlRequestMethod::from_str(method)); diff --git a/src/lib.rs b/src/lib.rs index d359bf5..3fc236e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -97,15 +97,20 @@ extern crate log; #[macro_use] extern crate rocket; -// extern crate serde; +extern crate serde; #[macro_use] extern crate serde_derive; extern crate unicase; +extern crate unicase_serde; extern crate url; extern crate url_serde; #[cfg(test)] extern crate hyper; +#[cfg(test)] +extern crate serde_test; +#[cfg(test)] +extern crate serde_json; #[cfg(test)] #[macro_use] @@ -124,10 +129,11 @@ use std::ops::Deref; use std::str::FromStr; use rocket::{Outcome, State}; +use rocket::http::{self, Status}; use rocket::fairing; -use rocket::http::{Method, Status}; use rocket::request::{Request, FromRequest}; use rocket::response; +use serde::{Serialize, Deserialize}; use headers::{HeaderFieldName, HeaderFieldNamesSet, Origin, AccessControlRequestHeaders, AccessControlRequestMethod}; @@ -287,6 +293,78 @@ impl AllOrSome> { } } +/// A wrapper type around `rocket::http::Method` to support serialization and deserialization +#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)] +pub struct Method(http::Method); + +impl FromStr for Method { + type Err = rocket::Error; + + fn from_str(s: &str) -> Result { + let method = http::Method::from_str(s)?; + Ok(Method(method)) + } +} + +impl Deref for Method { + type Target = http::Method; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl From for Method { + fn from(method: http::Method) -> Self { + Method(method) + } +} + +impl fmt::Display for Method { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + fmt::Display::fmt(&self.0, f) + } +} + +impl Serialize for Method { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + serializer.serialize_str(self.as_str()) + } +} + +impl<'de> Deserialize<'de> for Method { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + use serde::de::{self, Visitor}; + + struct MethodVisitor; + impl<'de> Visitor<'de> for MethodVisitor { + type Value = Method; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("a string containing a HTTP Verb") + } + + fn visit_str(self, s: &str) -> Result + where + E: de::Error, + { + match Self::Value::from_str(s) { + Ok(value) => Ok(value), + Err(e) => Err(de::Error::custom(format!("{:?}", e))), + } + } + } + + deserializer.deserialize_string(MethodVisitor) + } +} + /// Response generator and [Fairing](https://rocket.rs/guide/fairings/) for CORS /// /// This struct can be as Fairing or in an ad-hoc manner to generate CORS response. @@ -296,7 +374,7 @@ impl AllOrSome> { /// /// [`Default`](https://doc.rust-lang.org/std/default/trait.Default.html) is implemented for this /// struct. The default for each field is described in the docuementation for the field. -#[derive(Clone, Debug)] +#[derive(Eq, PartialEq, Serialize, Deserialize, Clone, Debug)] pub struct Cors { /// Origins that are allowed to make requests. /// Will be verified against the `Origin` request header. @@ -314,7 +392,7 @@ pub struct Cors { /// Defaults to `All`. /// /// ``` - // #[serde(default)] + #[serde(default)] pub allowed_origins: AllOrSome>, /// The list of methods which the allowed origins are allowed to access for /// non-simple requests. @@ -323,7 +401,7 @@ pub struct Cors { /// [Resource Processing Model](https://www.w3.org/TR/cors/#resource-processing-model). /// /// Defaults to `[GET, HEAD, POST, OPTIONS, PUT, PATCH, DELETE]` - // #[serde(default = "Cors::default_allowed_methods")] + #[serde(default = "Cors::default_allowed_methods")] pub allowed_methods: HashSet, /// The list of header field names which can be used when this resource is accessed by allowed /// origins. @@ -335,7 +413,7 @@ pub struct Cors { /// [Resource Processing Model](https://www.w3.org/TR/cors/#resource-processing-model). /// /// Defaults to `All`. - // #[serde(default)] + #[serde(default)] pub allowed_headers: AllOrSome>, /// Allows users to make authenticated requests. /// If true, injects the `Access-Control-Allow-Credentials` header in responses. @@ -346,7 +424,7 @@ pub struct Cors { /// in an `Error::CredentialsWithWildcardOrigin` error during Rocket launch or runtime. /// /// Defaults to `false`. - // #[serde(default)] + #[serde(default)] pub allow_credentials: bool, /// The list of headers which are safe to expose to the API of a CORS API specification. /// This corresponds to the `Access-Control-Expose-Headers` responde header. @@ -355,13 +433,13 @@ pub struct Cors { /// [Resource Processing Model](https://www.w3.org/TR/cors/#resource-processing-model). /// /// This defaults to an empty set. - // #[serde(default)] + #[serde(default)] pub expose_headers: HashSet, /// The maximum time for which this CORS request maybe cached. This value is set as the /// `Access-Control-Max-Age` header. /// /// This defaults to `None` (unset). - // #[serde(default)] + #[serde(default)] pub max_age: Option, /// If true, and the `allowed_origins` parameter is `All`, a wildcard /// `Access-Control-Allow-Origin` response header is sent, rather than the request’s @@ -375,14 +453,14 @@ pub struct Cors { /// in an `Error::CredentialsWithWildcardOrigin` error during Rocket launch or runtime. /// /// Defaults to `false`. - // #[serde(default)] + #[serde(default)] pub send_wildcard: bool, /// When used as Fairing, Cors will need to redirect failed CORS checks to a custom route to /// be mounted by the fairing. Specify the base the route so that it doesn't clash with any /// of your existing routes. /// /// Defaults to "/cors" - // #[serde(default = "Cors::default_fairing_route_base")] + #[serde(default = "Cors::default_fairing_route_base")] pub fairing_route_base: String, } @@ -403,6 +481,8 @@ impl Default for Cors { impl Cors { fn default_allowed_methods() -> HashSet { + use rocket::http::Method; + vec![ Method::Get, Method::Head, @@ -412,6 +492,7 @@ impl Cors { Method::Patch, Method::Delete, ].into_iter() + .map(From::from) .collect() } @@ -442,12 +523,12 @@ impl Cors { /// Create a new `Route` for Fairing handling fn fairing_route(&self) -> rocket::Route { - rocket::Route::new(Method::Get, "/", fairing_error_route) + rocket::Route::new(http::Method::Get, "/", fairing_error_route) } /// Modifies a `Request` to route to Fairing error handler fn route_to_fairing_error_handler(&self, status: u16, request: &mut Request) { - request.set_method(Method::Get); + request.set_method(http::Method::Get); request.set_uri(format!("{}/{}", self.fairing_route_base, status)); } } @@ -502,7 +583,7 @@ impl fairing::Fairing for Cors { // // TODO: Is there anyway we can make this smarter? Only modify status codes for // requests where an actual route exist? - if request.method() == Method::Options && request.route().is_none() { + if request.method() == http::Method::Options && request.route().is_none() { response.set_status(Status::NoContent); let _ = response.take_body(); } @@ -817,7 +898,7 @@ fn build_cors_response(options: &Cors, request: &Request) -> Result { + http::Method::Options => { let method = request_method(request)?; let headers = request_headers(request)?; preflight(options, origin, method, headers) @@ -1074,7 +1155,8 @@ fn actual_request(options: &Cors, origin: Origin) -> Result { #[allow(unmounted_route)] mod tests { use std::str::FromStr; - use rocket::http::Method; + use serde_json; + use http::Method; use super::*; fn make_cors_options() -> Cors { @@ -1084,7 +1166,10 @@ mod tests { Cors { allowed_origins: allowed_origins, - allowed_methods: [Method::Get].iter().cloned().collect(), + allowed_methods: vec![http::Method::Get] + .into_iter() + .map(From::from) + .collect(), allowed_headers: AllOrSome::Some( ["Authorization"] .into_iter() @@ -1096,6 +1181,8 @@ mod tests { } } + // CORS options test + #[test] fn cors_is_validated() { assert!(make_cors_options().validate().is_ok()) @@ -1112,6 +1199,13 @@ mod tests { cors.validate().unwrap(); } + /// Check that the the default deserialization matches the one returned by `Default::default` + #[test] + fn cors_default_deserialization_is_correct() { + let deserialized: Cors = serde_json::from_str("{}").expect("To not fail"); + assert_eq!(deserialized, Cors::default()); + } + // The following tests check validation #[test] @@ -1205,6 +1299,7 @@ mod tests { fn allowed_methods_validated_correctly() { let allowed_methods = vec![Method::Get, Method::Head, Method::Post] .into_iter() + .map(From::from) .collect(); let method = "GET"; @@ -1220,6 +1315,7 @@ mod tests { fn allowed_methods_errors_on_disallowed_method() { let allowed_methods = vec![Method::Get, Method::Head, Method::Post] .into_iter() + .map(From::from) .collect(); let method = "DELETE"; @@ -1325,6 +1421,31 @@ mod tests { } + #[derive(Debug, PartialEq, Serialize, Deserialize)] + struct MethodTest { + method: ::Method, + } + + #[test] + fn method_serde_roundtrip() { + use serde_test::{Token, assert_tokens}; + + let test = MethodTest { method: From::from(http::Method::Get) }; + + assert_tokens( + &test, + &[ + Token::Struct { + name: "MethodTest", + len: 1, + }, + Token::Str("method"), + Token::Str("GET"), + Token::StructEnd, + ], + ); + } + // TODO: Preflight tests // TODO: Actual requests tests diff --git a/tests/ad_hoc.rs b/tests/ad_hoc.rs index 47aeccf..22cdfe1 100644 --- a/tests/ad_hoc.rs +++ b/tests/ad_hoc.rs @@ -66,7 +66,7 @@ fn make_cors_options() -> cors::Cors { cors::Cors { allowed_origins: allowed_origins, - allowed_methods: [Method::Get].iter().cloned().collect(), + allowed_methods: vec![Method::Get].into_iter().map(From::from).collect(), allowed_headers: cors::AllOrSome::Some( ["Authorization"] .into_iter() diff --git a/tests/fairings.rs b/tests/fairings.rs index 641f4ef..e223bef 100644 --- a/tests/fairings.rs +++ b/tests/fairings.rs @@ -29,7 +29,7 @@ fn make_cors_options() -> Cors { Cors { allowed_origins: allowed_origins, - allowed_methods: [Method::Get].iter().cloned().collect(), + allowed_methods: vec![Method::Get].into_iter().map(From::from).collect(), allowed_headers: AllOrSome::Some( ["Authorization"] .into_iter() @@ -42,9 +42,9 @@ fn make_cors_options() -> Cors { } fn rocket() -> rocket::Rocket { - rocket::ignite().mount("/", routes![cors, panicking_route]).attach( - make_cors_options(), - ) + rocket::ignite() + .mount("/", routes![cors, panicking_route]) + .attach(make_cors_options()) } #[test]