From 6f1a24e12dec25228e687f9aa327959bc4727b51 Mon Sep 17 00:00:00 2001 From: Yong Wen Chua Date: Sat, 15 Jul 2017 11:18:37 +0800 Subject: [PATCH] Move some code to headers module --- src/headers.rs | 253 +++++++++++++++++++++++++++++++++++++++++++++++ src/lib.rs | 243 +++------------------------------------------ tests/headers.rs | 2 +- 3 files changed, 266 insertions(+), 232 deletions(-) create mode 100644 src/headers.rs diff --git a/src/headers.rs b/src/headers.rs new file mode 100644 index 0000000..0ac7d1e --- /dev/null +++ b/src/headers.rs @@ -0,0 +1,253 @@ +//! CORS specific Request Headers + +use std::collections::HashSet; +use std::fmt; +use std::ops::Deref; +use std::str::FromStr; + +use rocket::{self, Outcome}; +use rocket::http::{Method, Status}; +use rocket::request::{self, FromRequest}; +use unicase::UniCase; +use url; +use url_serde; + +pub(crate) type HeaderFieldName = UniCase; +pub(crate) type HeaderFieldNamesSet = HashSet; + +/// A wrapped `url::Url` to allow for deserialization +#[derive(Eq, PartialEq, Clone, Hash, Debug, Serialize, Deserialize)] +pub struct Url( + #[serde(with = "url_serde")] + url::Url +); + +impl fmt::Display for Url { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + self.0.fmt(f) + } +} + +impl Deref for Url { + type Target = url::Url; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl FromStr for Url { + type Err = url::ParseError; + + fn from_str(input: &str) -> Result { + let url = url::Url::from_str(input)?; + Ok(Url(url)) + } +} + +impl<'a, 'r> FromRequest<'a, 'r> for Url { + type Error = ::Error; + + fn from_request(request: &'a rocket::Request<'r>) -> request::Outcome { + match request.headers().get_one("Origin") { + Some(origin) => { + match Self::from_str(origin) { + Ok(origin) => Outcome::Success(origin), + Err(e) => Outcome::Failure((Status::BadRequest, ::Error::BadOrigin(e))), + } + } + None => Outcome::Forward(()), + } + } +} + +/// The `Origin` request header used in CORS +/// +/// You can use this as a rocket [Request Guard](https://rocket.rs/guide/requests/#request-guards) +/// to ensure that `Origin` is passed in correctly. +pub type Origin = Url; + +/// The `Access-Control-Request-Method` request header +/// +/// You can use this as a rocket [Request Guard](https://rocket.rs/guide/requests/#request-guards) +/// to ensure that the header is passed in correctly. +#[derive(Debug)] +pub struct AccessControlRequestMethod(pub Method); + +impl FromStr for AccessControlRequestMethod { + type Err = rocket::Error; + + fn from_str(method: &str) -> Result { + Ok(AccessControlRequestMethod(Method::from_str(method)?)) + } +} + +impl<'a, 'r> FromRequest<'a, 'r> for AccessControlRequestMethod { + type Error = ::Error; + + fn from_request(request: &'a rocket::Request<'r>) -> request::Outcome { + match request.headers().get_one("Access-Control-Request-Method") { + Some(request_method) => { + match Self::from_str(request_method) { + Ok(request_method) => Outcome::Success(request_method), + Err(e) => Outcome::Failure((Status::BadRequest, ::Error::BadRequestMethod(e))), + } + } + None => Outcome::Forward(()), + } + } +} + +/// The `Access-Control-Request-Headers` request header +/// +/// You can use this as a rocket [Request Guard](https://rocket.rs/guide/requests/#request-guards) +/// to ensure that the header is passed in correctly. +#[derive(Debug)] +pub struct AccessControlRequestHeaders(pub HeaderFieldNamesSet); + +/// Will never fail +impl FromStr for AccessControlRequestHeaders { + type Err = (); + + /// Will never fail + fn from_str(headers: &str) -> Result { + if headers.trim().is_empty() { + return Ok(AccessControlRequestHeaders(HashSet::new())); + } + + let set: HeaderFieldNamesSet = headers + .split(',') + .map(|header| UniCase(header.trim().to_string())) + .collect(); + Ok(AccessControlRequestHeaders(set)) + } +} + +impl<'a, 'r> FromRequest<'a, 'r> for AccessControlRequestHeaders { + type Error = ::Error; + + fn from_request(request: &'a rocket::Request<'r>) -> request::Outcome { + match request.headers().get_one("Access-Control-Request-Headers") { + Some(request_headers) => { + match Self::from_str(request_headers) { + Ok(request_headers) => Outcome::Success(request_headers), + Err(()) => { + unreachable!("`AccessControlRequestHeaders::from_str` should never fail") + } + } + } + None => Outcome::Forward(()), + } + } +} + +#[cfg(test)] +#[allow(unmounted_route)] +mod tests { + use std::str::FromStr; + + use hyper; + use rocket; + use rocket::local::Client; + use rocket::http::Method; + + use super::*; + + /// Make a client with no routes for unit testing + fn make_client() -> Client { + let rocket = rocket::ignite(); + Client::new(rocket).expect("valid rocket instance") + } + + // The following tests check that CORS Request headers are parsed correctly + + #[test] + fn origin_header_conversion() { + let url = "https://foo.bar.xyz"; + let parsed = not_err!(Origin::from_str(url)); + let expected = not_err!(Url::from_str(url)); + assert_eq!(parsed, expected); + + let url = "https://foo.bar.xyz/path/somewhere"; // this should never really be used + let parsed = not_err!(Origin::from_str(url)); + let expected = not_err!(Url::from_str(url)); + assert_eq!(parsed, expected); + + let url = "invalid_url"; + let _ = is_err!(Origin::from_str(url)); + } + + #[test] + fn origin_header_parsing() { + let client = make_client(); + let mut request = client.get("/"); + + let origin = hyper::header::Origin::new("https", "www.example.com", None); + request.add_header(origin); + + let outcome: request::Outcome = FromRequest::from_request(request.inner()); + let parsed_header = assert_matches!(outcome, Outcome::Success(s), s); + assert_eq!("https://www.example.com/", parsed_header.as_str()); + } + + #[test] + fn request_method_conversion() { + let method = "POST"; + let parsed_method = not_err!(AccessControlRequestMethod::from_str(method)); + assert_matches!(parsed_method, AccessControlRequestMethod(Method::Post)); + + let method = "options"; + let parsed_method = not_err!(AccessControlRequestMethod::from_str(method)); + assert_matches!(parsed_method, AccessControlRequestMethod(Method::Options)); + + let method = "INVALID"; + let _ = is_err!(AccessControlRequestMethod::from_str(method)); + } + + #[test] + fn request_method_parsing() { + let client = make_client(); + let mut request = client.get("/"); + let method = hyper::header::AccessControlRequestMethod(hyper::method::Method::Get); + request.add_header(method); + let outcome: request::Outcome = + FromRequest::from_request(request.inner()); + + let parsed_header = assert_matches!(outcome, Outcome::Success(s), s); + let AccessControlRequestMethod(parsed_method) = parsed_header; + assert_eq!("GET", parsed_method.as_str()); + } + + #[test] + fn request_headers_conversion() { + let headers = ["foo", "bar", "baz"]; + let parsed_headers = not_err!(AccessControlRequestHeaders::from_str(&headers.join(", "))); + let expected_headers: HeaderFieldNamesSet = + headers.iter().map(|s| s.to_string().into()).collect(); + let AccessControlRequestHeaders(actual_headers) = parsed_headers; + assert_eq!(actual_headers, expected_headers); + } + + #[test] + fn request_headers_parsing() { + let client = make_client(); + let mut request = client.get("/"); + let headers = hyper::header::AccessControlRequestHeaders(vec![ + FromStr::from_str("accept-language").unwrap(), + FromStr::from_str("date").unwrap(), + ]); + request.add_header(headers); + let outcome: request::Outcome = + FromRequest::from_request(request.inner()); + + let parsed_header = assert_matches!(outcome, Outcome::Success(s), s); + let AccessControlRequestHeaders(parsed_headers) = parsed_header; + let mut parsed_headers: Vec = + parsed_headers.iter().map(|s| s.to_string()).collect(); + parsed_headers.sort(); + assert_eq!( + vec!["accept-language".to_string(), "date".to_string()], + parsed_headers + ); + } +} diff --git a/src/lib.rs b/src/lib.rs index 15523c3..3e48528 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -107,6 +107,15 @@ extern crate url_serde; #[cfg(test)] extern crate hyper; +#[cfg(test)] +#[macro_use] +mod test_macros; + +pub mod headers; + +// Public exports +pub use headers::Url; + use std::collections::{HashSet, HashMap}; use std::error; use std::fmt; @@ -116,13 +125,11 @@ use std::str::FromStr; use rocket::{Outcome, State}; use rocket::http::{Method, Status}; -use rocket::request::{self, Request, FromRequest}; +use rocket::request::{Request, FromRequest}; use rocket::response; -use unicase::UniCase; -#[cfg(test)] -#[macro_use] -mod test_macros; +use headers::{HeaderFieldName, HeaderFieldNamesSet, Origin, AccessControlRequestHeaders, + AccessControlRequestMethod}; /// Errors during operations /// @@ -212,129 +219,6 @@ impl<'r> response::Responder<'r> for Error { } } -/// A wrapped `url::Url` to allow for deserialization -#[derive(Eq, PartialEq, Clone, Hash, Debug, Serialize, Deserialize)] -pub struct Url( - #[serde(with = "url_serde")] - url::Url -); - -impl fmt::Display for Url { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - self.0.fmt(f) - } -} - -impl Deref for Url { - type Target = url::Url; - - fn deref(&self) -> &Self::Target { - &self.0 - } -} - -impl FromStr for Url { - type Err = url::ParseError; - - fn from_str(input: &str) -> Result { - let url = url::Url::from_str(input)?; - Ok(Url(url)) - } -} - -impl<'a, 'r> FromRequest<'a, 'r> for Url { - type Error = Error; - - fn from_request(request: &'a Request<'r>) -> request::Outcome { - match request.headers().get_one("Origin") { - Some(origin) => { - match Self::from_str(origin) { - Ok(origin) => Outcome::Success(origin), - Err(e) => Outcome::Failure((Status::BadRequest, Error::BadOrigin(e))), - } - } - None => Outcome::Forward(()), - } - } -} - -/// The `Origin` request header used in CORS -/// -/// You can use this as a rocket [Request Guard](https://rocket.rs/guide/requests/#request-guards) -/// to ensure that Origins are passed in correctly. -pub type Origin = Url; - -/// The `Access-Control-Request-Method` request header -#[derive(Debug)] -pub struct AccessControlRequestMethod(pub Method); - -impl FromStr for AccessControlRequestMethod { - type Err = rocket::Error; - - fn from_str(method: &str) -> Result { - Ok(AccessControlRequestMethod(Method::from_str(method)?)) - } -} - -impl<'a, 'r> FromRequest<'a, 'r> for AccessControlRequestMethod { - type Error = Error; - - fn from_request(request: &'a Request<'r>) -> request::Outcome { - match request.headers().get_one("Access-Control-Request-Method") { - Some(request_method) => { - match Self::from_str(request_method) { - Ok(request_method) => Outcome::Success(request_method), - Err(e) => Outcome::Failure((Status::BadRequest, Error::BadRequestMethod(e))), - } - } - None => Outcome::Forward(()), - } - } -} - -type HeaderFieldName = UniCase; -type HeaderFieldNamesSet = HashSet; - -/// The `Access-Control-Request-Headers` request header -#[derive(Debug)] -pub struct AccessControlRequestHeaders(pub HeaderFieldNamesSet); - -/// Will never fail -impl FromStr for AccessControlRequestHeaders { - type Err = (); - - /// Will never fail - fn from_str(headers: &str) -> Result { - if headers.trim().is_empty() { - return Ok(AccessControlRequestHeaders(HashSet::new())); - } - - let set: HeaderFieldNamesSet = headers - .split(',') - .map(|header| UniCase(header.trim().to_string())) - .collect(); - Ok(AccessControlRequestHeaders(set)) - } -} - -impl<'a, 'r> FromRequest<'a, 'r> for AccessControlRequestHeaders { - type Error = Error; - - fn from_request(request: &'a Request<'r>) -> request::Outcome { - match request.headers().get_one("Access-Control-Request-Headers") { - Some(request_headers) => { - match Self::from_str(request_headers) { - Ok(request_headers) => Outcome::Success(request_headers), - Err(()) => { - unreachable!("`AccessControlRequestHeaders::from_str` should never fail") - } - } - } - None => Outcome::Forward(()), - } - } -} - /// An enum signifying that some of type T is allowed, or `All` (everything is allowed). /// /// `Default` is implemented for this enum and is `All`. @@ -1021,112 +905,9 @@ impl Response { #[allow(unmounted_route)] mod tests { use std::str::FromStr; - - use hyper; - use rocket; - use rocket::local::Client; use rocket::http::Method; - use super::*; - /// Make a client with no routes for unit testing - fn make_client() -> Client { - let rocket = rocket::ignite(); - Client::new(rocket).expect("valid rocket instance") - } - - // The following tests check that CORS Request headers are parsed correctly - - #[test] - fn origin_header_conversion() { - let url = "https://foo.bar.xyz"; - let parsed = not_err!(Origin::from_str(url)); - let expected = not_err!(Url::from_str(url)); - assert_eq!(parsed, expected); - - let url = "https://foo.bar.xyz/path/somewhere"; // this should never really be used - let parsed = not_err!(Origin::from_str(url)); - let expected = not_err!(Url::from_str(url)); - assert_eq!(parsed, expected); - - let url = "invalid_url"; - let _ = is_err!(Origin::from_str(url)); - } - - #[test] - fn origin_header_parsing() { - let client = make_client(); - let mut request = client.get("/"); - - let origin = hyper::header::Origin::new("https", "www.example.com", None); - request.add_header(origin); - - let outcome: request::Outcome = FromRequest::from_request(request.inner()); - let parsed_header = assert_matches!(outcome, Outcome::Success(s), s); - assert_eq!("https://www.example.com/", parsed_header.as_str()); - } - - #[test] - fn request_method_conversion() { - let method = "POST"; - let parsed_method = not_err!(AccessControlRequestMethod::from_str(method)); - assert_matches!(parsed_method, AccessControlRequestMethod(Method::Post)); - - let method = "options"; - let parsed_method = not_err!(AccessControlRequestMethod::from_str(method)); - assert_matches!(parsed_method, AccessControlRequestMethod(Method::Options)); - - let method = "INVALID"; - let _ = is_err!(AccessControlRequestMethod::from_str(method)); - } - - #[test] - fn request_method_parsing() { - let client = make_client(); - let mut request = client.get("/"); - let method = hyper::header::AccessControlRequestMethod(hyper::method::Method::Get); - request.add_header(method); - let outcome: request::Outcome = - FromRequest::from_request(request.inner()); - - let parsed_header = assert_matches!(outcome, Outcome::Success(s), s); - let AccessControlRequestMethod(parsed_method) = parsed_header; - assert_eq!("GET", parsed_method.as_str()); - } - - #[test] - fn request_headers_conversion() { - let headers = ["foo", "bar", "baz"]; - let parsed_headers = not_err!(AccessControlRequestHeaders::from_str(&headers.join(", "))); - let expected_headers: HeaderFieldNamesSet = - headers.iter().map(|s| s.to_string().into()).collect(); - let AccessControlRequestHeaders(actual_headers) = parsed_headers; - assert_eq!(actual_headers, expected_headers); - } - - #[test] - fn request_headers_parsing() { - let client = make_client(); - let mut request = client.get("/"); - let headers = hyper::header::AccessControlRequestHeaders(vec![ - FromStr::from_str("accept-language").unwrap(), - FromStr::from_str("date").unwrap(), - ]); - request.add_header(headers); - let outcome: request::Outcome = - FromRequest::from_request(request.inner()); - - let parsed_header = assert_matches!(outcome, Outcome::Success(s), s); - let AccessControlRequestHeaders(parsed_headers) = parsed_header; - let mut parsed_headers: Vec = - parsed_headers.iter().map(|s| s.to_string()).collect(); - parsed_headers.sort(); - assert_eq!( - vec!["accept-language".to_string(), "date".to_string()], - parsed_headers - ); - } - // The following tests check `Response`'s validation #[test] diff --git a/tests/headers.rs b/tests/headers.rs index a81e167..9be98da 100644 --- a/tests/headers.rs +++ b/tests/headers.rs @@ -10,7 +10,7 @@ use std::str::FromStr; use rocket::local::Client; use rocket::http::{Header, Status}; -use rocket_cors::*; +use rocket_cors::headers::*; #[get("/request_headers")] fn request_headers(