Move some code to headers module

This commit is contained in:
Yong Wen Chua 2017-07-15 11:18:37 +08:00
parent ad25352bdd
commit 6f1a24e12d
3 changed files with 266 additions and 232 deletions

253
src/headers.rs Normal file
View File

@ -0,0 +1,253 @@
//! CORS specific Request Headers
use std::collections::HashSet;
use std::fmt;
use std::ops::Deref;
use std::str::FromStr;
use rocket::{self, Outcome};
use rocket::http::{Method, Status};
use rocket::request::{self, FromRequest};
use unicase::UniCase;
use url;
use url_serde;
pub(crate) type HeaderFieldName = UniCase<String>;
pub(crate) type HeaderFieldNamesSet = HashSet<HeaderFieldName>;
/// A wrapped `url::Url` to allow for deserialization
#[derive(Eq, PartialEq, Clone, Hash, Debug, Serialize, Deserialize)]
pub struct Url(
#[serde(with = "url_serde")]
url::Url
);
impl fmt::Display for Url {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
self.0.fmt(f)
}
}
impl Deref for Url {
type Target = url::Url;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl FromStr for Url {
type Err = url::ParseError;
fn from_str(input: &str) -> Result<Self, Self::Err> {
let url = url::Url::from_str(input)?;
Ok(Url(url))
}
}
impl<'a, 'r> FromRequest<'a, 'r> for Url {
type Error = ::Error;
fn from_request(request: &'a rocket::Request<'r>) -> request::Outcome<Self, ::Error> {
match request.headers().get_one("Origin") {
Some(origin) => {
match Self::from_str(origin) {
Ok(origin) => Outcome::Success(origin),
Err(e) => Outcome::Failure((Status::BadRequest, ::Error::BadOrigin(e))),
}
}
None => Outcome::Forward(()),
}
}
}
/// The `Origin` request header used in CORS
///
/// You can use this as a rocket [Request Guard](https://rocket.rs/guide/requests/#request-guards)
/// to ensure that `Origin` is passed in correctly.
pub type Origin = Url;
/// The `Access-Control-Request-Method` request header
///
/// You can use this as a rocket [Request Guard](https://rocket.rs/guide/requests/#request-guards)
/// to ensure that the header is passed in correctly.
#[derive(Debug)]
pub struct AccessControlRequestMethod(pub Method);
impl FromStr for AccessControlRequestMethod {
type Err = rocket::Error;
fn from_str(method: &str) -> Result<Self, Self::Err> {
Ok(AccessControlRequestMethod(Method::from_str(method)?))
}
}
impl<'a, 'r> FromRequest<'a, 'r> for AccessControlRequestMethod {
type Error = ::Error;
fn from_request(request: &'a rocket::Request<'r>) -> request::Outcome<Self, ::Error> {
match request.headers().get_one("Access-Control-Request-Method") {
Some(request_method) => {
match Self::from_str(request_method) {
Ok(request_method) => Outcome::Success(request_method),
Err(e) => Outcome::Failure((Status::BadRequest, ::Error::BadRequestMethod(e))),
}
}
None => Outcome::Forward(()),
}
}
}
/// The `Access-Control-Request-Headers` request header
///
/// You can use this as a rocket [Request Guard](https://rocket.rs/guide/requests/#request-guards)
/// to ensure that the header is passed in correctly.
#[derive(Debug)]
pub struct AccessControlRequestHeaders(pub HeaderFieldNamesSet);
/// Will never fail
impl FromStr for AccessControlRequestHeaders {
type Err = ();
/// Will never fail
fn from_str(headers: &str) -> Result<Self, Self::Err> {
if headers.trim().is_empty() {
return Ok(AccessControlRequestHeaders(HashSet::new()));
}
let set: HeaderFieldNamesSet = headers
.split(',')
.map(|header| UniCase(header.trim().to_string()))
.collect();
Ok(AccessControlRequestHeaders(set))
}
}
impl<'a, 'r> FromRequest<'a, 'r> for AccessControlRequestHeaders {
type Error = ::Error;
fn from_request(request: &'a rocket::Request<'r>) -> request::Outcome<Self, ::Error> {
match request.headers().get_one("Access-Control-Request-Headers") {
Some(request_headers) => {
match Self::from_str(request_headers) {
Ok(request_headers) => Outcome::Success(request_headers),
Err(()) => {
unreachable!("`AccessControlRequestHeaders::from_str` should never fail")
}
}
}
None => Outcome::Forward(()),
}
}
}
#[cfg(test)]
#[allow(unmounted_route)]
mod tests {
use std::str::FromStr;
use hyper;
use rocket;
use rocket::local::Client;
use rocket::http::Method;
use super::*;
/// Make a client with no routes for unit testing
fn make_client() -> Client {
let rocket = rocket::ignite();
Client::new(rocket).expect("valid rocket instance")
}
// The following tests check that CORS Request headers are parsed correctly
#[test]
fn origin_header_conversion() {
let url = "https://foo.bar.xyz";
let parsed = not_err!(Origin::from_str(url));
let expected = not_err!(Url::from_str(url));
assert_eq!(parsed, expected);
let url = "https://foo.bar.xyz/path/somewhere"; // this should never really be used
let parsed = not_err!(Origin::from_str(url));
let expected = not_err!(Url::from_str(url));
assert_eq!(parsed, expected);
let url = "invalid_url";
let _ = is_err!(Origin::from_str(url));
}
#[test]
fn origin_header_parsing() {
let client = make_client();
let mut request = client.get("/");
let origin = hyper::header::Origin::new("https", "www.example.com", None);
request.add_header(origin);
let outcome: request::Outcome<Origin, ::Error> = FromRequest::from_request(request.inner());
let parsed_header = assert_matches!(outcome, Outcome::Success(s), s);
assert_eq!("https://www.example.com/", parsed_header.as_str());
}
#[test]
fn request_method_conversion() {
let method = "POST";
let parsed_method = not_err!(AccessControlRequestMethod::from_str(method));
assert_matches!(parsed_method, AccessControlRequestMethod(Method::Post));
let method = "options";
let parsed_method = not_err!(AccessControlRequestMethod::from_str(method));
assert_matches!(parsed_method, AccessControlRequestMethod(Method::Options));
let method = "INVALID";
let _ = is_err!(AccessControlRequestMethod::from_str(method));
}
#[test]
fn request_method_parsing() {
let client = make_client();
let mut request = client.get("/");
let method = hyper::header::AccessControlRequestMethod(hyper::method::Method::Get);
request.add_header(method);
let outcome: request::Outcome<AccessControlRequestMethod, ::Error> =
FromRequest::from_request(request.inner());
let parsed_header = assert_matches!(outcome, Outcome::Success(s), s);
let AccessControlRequestMethod(parsed_method) = parsed_header;
assert_eq!("GET", parsed_method.as_str());
}
#[test]
fn request_headers_conversion() {
let headers = ["foo", "bar", "baz"];
let parsed_headers = not_err!(AccessControlRequestHeaders::from_str(&headers.join(", ")));
let expected_headers: HeaderFieldNamesSet =
headers.iter().map(|s| s.to_string().into()).collect();
let AccessControlRequestHeaders(actual_headers) = parsed_headers;
assert_eq!(actual_headers, expected_headers);
}
#[test]
fn request_headers_parsing() {
let client = make_client();
let mut request = client.get("/");
let headers = hyper::header::AccessControlRequestHeaders(vec![
FromStr::from_str("accept-language").unwrap(),
FromStr::from_str("date").unwrap(),
]);
request.add_header(headers);
let outcome: request::Outcome<AccessControlRequestHeaders, ::Error> =
FromRequest::from_request(request.inner());
let parsed_header = assert_matches!(outcome, Outcome::Success(s), s);
let AccessControlRequestHeaders(parsed_headers) = parsed_header;
let mut parsed_headers: Vec<String> =
parsed_headers.iter().map(|s| s.to_string()).collect();
parsed_headers.sort();
assert_eq!(
vec!["accept-language".to_string(), "date".to_string()],
parsed_headers
);
}
}

View File

@ -107,6 +107,15 @@ extern crate url_serde;
#[cfg(test)]
extern crate hyper;
#[cfg(test)]
#[macro_use]
mod test_macros;
pub mod headers;
// Public exports
pub use headers::Url;
use std::collections::{HashSet, HashMap};
use std::error;
use std::fmt;
@ -116,13 +125,11 @@ use std::str::FromStr;
use rocket::{Outcome, State};
use rocket::http::{Method, Status};
use rocket::request::{self, Request, FromRequest};
use rocket::request::{Request, FromRequest};
use rocket::response;
use unicase::UniCase;
#[cfg(test)]
#[macro_use]
mod test_macros;
use headers::{HeaderFieldName, HeaderFieldNamesSet, Origin, AccessControlRequestHeaders,
AccessControlRequestMethod};
/// Errors during operations
///
@ -212,129 +219,6 @@ impl<'r> response::Responder<'r> for Error {
}
}
/// A wrapped `url::Url` to allow for deserialization
#[derive(Eq, PartialEq, Clone, Hash, Debug, Serialize, Deserialize)]
pub struct Url(
#[serde(with = "url_serde")]
url::Url
);
impl fmt::Display for Url {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
self.0.fmt(f)
}
}
impl Deref for Url {
type Target = url::Url;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl FromStr for Url {
type Err = url::ParseError;
fn from_str(input: &str) -> Result<Self, Self::Err> {
let url = url::Url::from_str(input)?;
Ok(Url(url))
}
}
impl<'a, 'r> FromRequest<'a, 'r> for Url {
type Error = Error;
fn from_request(request: &'a Request<'r>) -> request::Outcome<Self, Error> {
match request.headers().get_one("Origin") {
Some(origin) => {
match Self::from_str(origin) {
Ok(origin) => Outcome::Success(origin),
Err(e) => Outcome::Failure((Status::BadRequest, Error::BadOrigin(e))),
}
}
None => Outcome::Forward(()),
}
}
}
/// The `Origin` request header used in CORS
///
/// You can use this as a rocket [Request Guard](https://rocket.rs/guide/requests/#request-guards)
/// to ensure that Origins are passed in correctly.
pub type Origin = Url;
/// The `Access-Control-Request-Method` request header
#[derive(Debug)]
pub struct AccessControlRequestMethod(pub Method);
impl FromStr for AccessControlRequestMethod {
type Err = rocket::Error;
fn from_str(method: &str) -> Result<Self, Self::Err> {
Ok(AccessControlRequestMethod(Method::from_str(method)?))
}
}
impl<'a, 'r> FromRequest<'a, 'r> for AccessControlRequestMethod {
type Error = Error;
fn from_request(request: &'a Request<'r>) -> request::Outcome<Self, Error> {
match request.headers().get_one("Access-Control-Request-Method") {
Some(request_method) => {
match Self::from_str(request_method) {
Ok(request_method) => Outcome::Success(request_method),
Err(e) => Outcome::Failure((Status::BadRequest, Error::BadRequestMethod(e))),
}
}
None => Outcome::Forward(()),
}
}
}
type HeaderFieldName = UniCase<String>;
type HeaderFieldNamesSet = HashSet<HeaderFieldName>;
/// The `Access-Control-Request-Headers` request header
#[derive(Debug)]
pub struct AccessControlRequestHeaders(pub HeaderFieldNamesSet);
/// Will never fail
impl FromStr for AccessControlRequestHeaders {
type Err = ();
/// Will never fail
fn from_str(headers: &str) -> Result<Self, Self::Err> {
if headers.trim().is_empty() {
return Ok(AccessControlRequestHeaders(HashSet::new()));
}
let set: HeaderFieldNamesSet = headers
.split(',')
.map(|header| UniCase(header.trim().to_string()))
.collect();
Ok(AccessControlRequestHeaders(set))
}
}
impl<'a, 'r> FromRequest<'a, 'r> for AccessControlRequestHeaders {
type Error = Error;
fn from_request(request: &'a Request<'r>) -> request::Outcome<Self, Error> {
match request.headers().get_one("Access-Control-Request-Headers") {
Some(request_headers) => {
match Self::from_str(request_headers) {
Ok(request_headers) => Outcome::Success(request_headers),
Err(()) => {
unreachable!("`AccessControlRequestHeaders::from_str` should never fail")
}
}
}
None => Outcome::Forward(()),
}
}
}
/// An enum signifying that some of type T is allowed, or `All` (everything is allowed).
///
/// `Default` is implemented for this enum and is `All`.
@ -1021,112 +905,9 @@ impl Response {
#[allow(unmounted_route)]
mod tests {
use std::str::FromStr;
use hyper;
use rocket;
use rocket::local::Client;
use rocket::http::Method;
use super::*;
/// Make a client with no routes for unit testing
fn make_client() -> Client {
let rocket = rocket::ignite();
Client::new(rocket).expect("valid rocket instance")
}
// The following tests check that CORS Request headers are parsed correctly
#[test]
fn origin_header_conversion() {
let url = "https://foo.bar.xyz";
let parsed = not_err!(Origin::from_str(url));
let expected = not_err!(Url::from_str(url));
assert_eq!(parsed, expected);
let url = "https://foo.bar.xyz/path/somewhere"; // this should never really be used
let parsed = not_err!(Origin::from_str(url));
let expected = not_err!(Url::from_str(url));
assert_eq!(parsed, expected);
let url = "invalid_url";
let _ = is_err!(Origin::from_str(url));
}
#[test]
fn origin_header_parsing() {
let client = make_client();
let mut request = client.get("/");
let origin = hyper::header::Origin::new("https", "www.example.com", None);
request.add_header(origin);
let outcome: request::Outcome<Origin, Error> = FromRequest::from_request(request.inner());
let parsed_header = assert_matches!(outcome, Outcome::Success(s), s);
assert_eq!("https://www.example.com/", parsed_header.as_str());
}
#[test]
fn request_method_conversion() {
let method = "POST";
let parsed_method = not_err!(AccessControlRequestMethod::from_str(method));
assert_matches!(parsed_method, AccessControlRequestMethod(Method::Post));
let method = "options";
let parsed_method = not_err!(AccessControlRequestMethod::from_str(method));
assert_matches!(parsed_method, AccessControlRequestMethod(Method::Options));
let method = "INVALID";
let _ = is_err!(AccessControlRequestMethod::from_str(method));
}
#[test]
fn request_method_parsing() {
let client = make_client();
let mut request = client.get("/");
let method = hyper::header::AccessControlRequestMethod(hyper::method::Method::Get);
request.add_header(method);
let outcome: request::Outcome<AccessControlRequestMethod, Error> =
FromRequest::from_request(request.inner());
let parsed_header = assert_matches!(outcome, Outcome::Success(s), s);
let AccessControlRequestMethod(parsed_method) = parsed_header;
assert_eq!("GET", parsed_method.as_str());
}
#[test]
fn request_headers_conversion() {
let headers = ["foo", "bar", "baz"];
let parsed_headers = not_err!(AccessControlRequestHeaders::from_str(&headers.join(", ")));
let expected_headers: HeaderFieldNamesSet =
headers.iter().map(|s| s.to_string().into()).collect();
let AccessControlRequestHeaders(actual_headers) = parsed_headers;
assert_eq!(actual_headers, expected_headers);
}
#[test]
fn request_headers_parsing() {
let client = make_client();
let mut request = client.get("/");
let headers = hyper::header::AccessControlRequestHeaders(vec![
FromStr::from_str("accept-language").unwrap(),
FromStr::from_str("date").unwrap(),
]);
request.add_header(headers);
let outcome: request::Outcome<AccessControlRequestHeaders, Error> =
FromRequest::from_request(request.inner());
let parsed_header = assert_matches!(outcome, Outcome::Success(s), s);
let AccessControlRequestHeaders(parsed_headers) = parsed_header;
let mut parsed_headers: Vec<String> =
parsed_headers.iter().map(|s| s.to_string()).collect();
parsed_headers.sort();
assert_eq!(
vec!["accept-language".to_string(), "date".to_string()],
parsed_headers
);
}
// The following tests check `Response`'s validation
#[test]

View File

@ -10,7 +10,7 @@ use std::str::FromStr;
use rocket::local::Client;
use rocket::http::{Header, Status};
use rocket_cors::*;
use rocket_cors::headers::*;
#[get("/request_headers")]
fn request_headers(