diff --git a/Cargo.toml b/Cargo.toml index 7a38136..0cdf73b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -31,3 +31,4 @@ version_check = "0.1" hyper = "0.10" rocket_codegen = "0.3" serde_json = "1.0" +serde_test = "1.0" diff --git a/src/headers.rs b/src/headers.rs index 0ac7d1e..1cd71b0 100644 --- a/src/headers.rs +++ b/src/headers.rs @@ -6,7 +6,7 @@ use std::ops::Deref; use std::str::FromStr; use rocket::{self, Outcome}; -use rocket::http::{Method, Status}; +use rocket::http::Status; use rocket::request::{self, FromRequest}; use unicase::UniCase; use url; @@ -72,13 +72,13 @@ pub type Origin = Url; /// You can use this as a rocket [Request Guard](https://rocket.rs/guide/requests/#request-guards) /// to ensure that the header is passed in correctly. #[derive(Debug)] -pub struct AccessControlRequestMethod(pub Method); +pub struct AccessControlRequestMethod(pub ::Method); impl FromStr for AccessControlRequestMethod { type Err = rocket::Error; fn from_str(method: &str) -> Result { - Ok(AccessControlRequestMethod(Method::from_str(method)?)) + Ok(AccessControlRequestMethod(::Method::from_str(method)?)) } } @@ -149,7 +149,6 @@ mod tests { use hyper; use rocket; use rocket::local::Client; - use rocket::http::Method; use super::*; @@ -194,11 +193,17 @@ mod tests { fn request_method_conversion() { let method = "POST"; let parsed_method = not_err!(AccessControlRequestMethod::from_str(method)); - assert_matches!(parsed_method, AccessControlRequestMethod(Method::Post)); + assert_matches!( + parsed_method, + AccessControlRequestMethod(::Method(rocket::http::Method::Post)) + ); let method = "options"; let parsed_method = not_err!(AccessControlRequestMethod::from_str(method)); - assert_matches!(parsed_method, AccessControlRequestMethod(Method::Options)); + assert_matches!( + parsed_method, + AccessControlRequestMethod(::Method(rocket::http::Method::Options)) + ); let method = "INVALID"; let _ = is_err!(AccessControlRequestMethod::from_str(method)); diff --git a/src/lib.rs b/src/lib.rs index 3e48528..7361929 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -97,7 +97,7 @@ extern crate log; #[macro_use] extern crate rocket; -// extern crate serde; +extern crate serde; #[macro_use] extern crate serde_derive; extern crate unicase; @@ -106,6 +106,8 @@ extern crate url_serde; #[cfg(test)] extern crate hyper; +#[cfg(test)] +extern crate serde_test; #[cfg(test)] #[macro_use] @@ -124,9 +126,10 @@ use std::ops::Deref; use std::str::FromStr; use rocket::{Outcome, State}; -use rocket::http::{Method, Status}; +use rocket::http::{self, Status}; use rocket::request::{Request, FromRequest}; use rocket::response; +use serde::{Serialize, Deserialize}; use headers::{HeaderFieldName, HeaderFieldNamesSet, Origin, AccessControlRequestHeaders, AccessControlRequestMethod}; @@ -258,6 +261,78 @@ impl AllOrSome> { } } +/// A newtype wrapper around `rocket::http::Method` to allow for serde deserialization +#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)] +pub struct Method(http::Method); + +impl FromStr for Method { + type Err = rocket::Error; + + fn from_str(s: &str) -> Result { + let method = http::Method::from_str(s)?; + Ok(Method(method)) + } +} + +impl Deref for Method { + type Target = http::Method; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl From for Method { + fn from(method: http::Method) -> Self { + Method(method) + } +} + +impl fmt::Display for Method { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + fmt::Display::fmt(&self.0, f) + } +} + +impl Serialize for Method { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + serializer.serialize_str(self.as_str()) + } +} + +impl<'de> Deserialize<'de> for Method { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + use serde::de::{self, Visitor}; + + struct MethodVisitor; + impl<'de> Visitor<'de> for MethodVisitor { + type Value = Method; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("a string containing a HTTP Verb") + } + + fn visit_str(self, s: &str) -> Result + where + E: de::Error, + { + match Self::Value::from_str(s) { + Ok(value) => Ok(value), + Err(e) => Err(de::Error::custom(format!("{:?}", e))), + } + } + } + + deserializer.deserialize_string(MethodVisitor) + } +} + /// Responder generator and [Fairing](https://rocket.rs/guide/fairings/) for CORS /// /// This struct can be used as Fairing for Rocket, or as an ad-hoc responder for any CORS requests. @@ -389,6 +464,8 @@ impl Cors { } fn default_allowed_methods() -> HashSet { + use rocket::http::Method; + vec![ Method::Get, Method::Head, @@ -398,6 +475,7 @@ impl Cors { Method::Patch, Method::Delete, ].into_iter() + .map(From::from) .collect() } } @@ -468,7 +546,7 @@ impl<'a, 'r: 'a, R: response::Responder<'r>> Responder<'a, 'r, R> { // Check if the request verb is an OPTION or something else let cors_response = match request.method() { - Method::Options => { + http::Method::Options => { let method = Self::request_method(request)?; let headers = Self::request_headers(request)?; Self::preflight(&self.options, origin, method, headers) @@ -905,7 +983,7 @@ impl Response { #[allow(unmounted_route)] mod tests { use std::str::FromStr; - use rocket::http::Method; + use http::Method; use super::*; // The following tests check `Response`'s validation @@ -1139,6 +1217,7 @@ mod tests { Method::Head, Method::Post, ].into_iter() + .map(From::from) .collect(); let method = "GET"; @@ -1178,6 +1257,7 @@ mod tests { Method::Head, Method::Post, ].into_iter() + .map(From::from) .collect(); let method = "DELETE"; @@ -1293,6 +1373,28 @@ mod tests { } + #[derive(Debug, PartialEq, Serialize, Deserialize)] + struct MethodTest { + method: ::Method, + } + + #[test] + fn method_serde_roundtrip() { + use serde_test::{Token, assert_tokens}; + + let test = MethodTest { method: From::from(http::Method::Get) }; + + assert_tokens( + &test, + &[ + Token::Struct { name: "MethodTest", len: 1 }, + Token::Str("method"), + Token::Str("GET"), + Token::StructEnd, + ], + ); + } + // The following tests check that preflight checks are done properly // fn make_cors_options() -> Cors { diff --git a/tests/routes.rs b/tests/routes.rs index dc421f9..505dad8 100644 --- a/tests/routes.rs +++ b/tests/routes.rs @@ -30,7 +30,7 @@ fn make_cors_options() -> Cors { Cors { allowed_origins: allowed_origins, - allowed_methods: [Method::Get].iter().cloned().collect(), + allowed_methods: vec![Method::Get].into_iter().map(From::from).collect(), allowed_headers: AllOrSome::Some( ["Authorization"] .into_iter() @@ -48,7 +48,7 @@ fn smoke_test() { assert!(failed_origins.is_empty()); let cors_options = rocket_cors::Cors { allowed_origins: allowed_origins, - allowed_methods: [Method::Get].iter().cloned().collect(), + allowed_methods: vec![Method::Get].into_iter().map(From::from).collect(), allowed_headers: AllOrSome::Some( ["Authorization"] .iter()