Move some code to headers module
This commit is contained in:
parent
ad25352bdd
commit
6f1a24e12d
|
@ -0,0 +1,253 @@
|
|||
//! CORS specific Request Headers
|
||||
|
||||
use std::collections::HashSet;
|
||||
use std::fmt;
|
||||
use std::ops::Deref;
|
||||
use std::str::FromStr;
|
||||
|
||||
use rocket::{self, Outcome};
|
||||
use rocket::http::{Method, Status};
|
||||
use rocket::request::{self, FromRequest};
|
||||
use unicase::UniCase;
|
||||
use url;
|
||||
use url_serde;
|
||||
|
||||
pub(crate) type HeaderFieldName = UniCase<String>;
|
||||
pub(crate) type HeaderFieldNamesSet = HashSet<HeaderFieldName>;
|
||||
|
||||
/// A wrapped `url::Url` to allow for deserialization
|
||||
#[derive(Eq, PartialEq, Clone, Hash, Debug, Serialize, Deserialize)]
|
||||
pub struct Url(
|
||||
#[serde(with = "url_serde")]
|
||||
url::Url
|
||||
);
|
||||
|
||||
impl fmt::Display for Url {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
self.0.fmt(f)
|
||||
}
|
||||
}
|
||||
|
||||
impl Deref for Url {
|
||||
type Target = url::Url;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl FromStr for Url {
|
||||
type Err = url::ParseError;
|
||||
|
||||
fn from_str(input: &str) -> Result<Self, Self::Err> {
|
||||
let url = url::Url::from_str(input)?;
|
||||
Ok(Url(url))
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, 'r> FromRequest<'a, 'r> for Url {
|
||||
type Error = ::Error;
|
||||
|
||||
fn from_request(request: &'a rocket::Request<'r>) -> request::Outcome<Self, ::Error> {
|
||||
match request.headers().get_one("Origin") {
|
||||
Some(origin) => {
|
||||
match Self::from_str(origin) {
|
||||
Ok(origin) => Outcome::Success(origin),
|
||||
Err(e) => Outcome::Failure((Status::BadRequest, ::Error::BadOrigin(e))),
|
||||
}
|
||||
}
|
||||
None => Outcome::Forward(()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// The `Origin` request header used in CORS
|
||||
///
|
||||
/// You can use this as a rocket [Request Guard](https://rocket.rs/guide/requests/#request-guards)
|
||||
/// to ensure that `Origin` is passed in correctly.
|
||||
pub type Origin = Url;
|
||||
|
||||
/// The `Access-Control-Request-Method` request header
|
||||
///
|
||||
/// You can use this as a rocket [Request Guard](https://rocket.rs/guide/requests/#request-guards)
|
||||
/// to ensure that the header is passed in correctly.
|
||||
#[derive(Debug)]
|
||||
pub struct AccessControlRequestMethod(pub Method);
|
||||
|
||||
impl FromStr for AccessControlRequestMethod {
|
||||
type Err = rocket::Error;
|
||||
|
||||
fn from_str(method: &str) -> Result<Self, Self::Err> {
|
||||
Ok(AccessControlRequestMethod(Method::from_str(method)?))
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, 'r> FromRequest<'a, 'r> for AccessControlRequestMethod {
|
||||
type Error = ::Error;
|
||||
|
||||
fn from_request(request: &'a rocket::Request<'r>) -> request::Outcome<Self, ::Error> {
|
||||
match request.headers().get_one("Access-Control-Request-Method") {
|
||||
Some(request_method) => {
|
||||
match Self::from_str(request_method) {
|
||||
Ok(request_method) => Outcome::Success(request_method),
|
||||
Err(e) => Outcome::Failure((Status::BadRequest, ::Error::BadRequestMethod(e))),
|
||||
}
|
||||
}
|
||||
None => Outcome::Forward(()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// The `Access-Control-Request-Headers` request header
|
||||
///
|
||||
/// You can use this as a rocket [Request Guard](https://rocket.rs/guide/requests/#request-guards)
|
||||
/// to ensure that the header is passed in correctly.
|
||||
#[derive(Debug)]
|
||||
pub struct AccessControlRequestHeaders(pub HeaderFieldNamesSet);
|
||||
|
||||
/// Will never fail
|
||||
impl FromStr for AccessControlRequestHeaders {
|
||||
type Err = ();
|
||||
|
||||
/// Will never fail
|
||||
fn from_str(headers: &str) -> Result<Self, Self::Err> {
|
||||
if headers.trim().is_empty() {
|
||||
return Ok(AccessControlRequestHeaders(HashSet::new()));
|
||||
}
|
||||
|
||||
let set: HeaderFieldNamesSet = headers
|
||||
.split(',')
|
||||
.map(|header| UniCase(header.trim().to_string()))
|
||||
.collect();
|
||||
Ok(AccessControlRequestHeaders(set))
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, 'r> FromRequest<'a, 'r> for AccessControlRequestHeaders {
|
||||
type Error = ::Error;
|
||||
|
||||
fn from_request(request: &'a rocket::Request<'r>) -> request::Outcome<Self, ::Error> {
|
||||
match request.headers().get_one("Access-Control-Request-Headers") {
|
||||
Some(request_headers) => {
|
||||
match Self::from_str(request_headers) {
|
||||
Ok(request_headers) => Outcome::Success(request_headers),
|
||||
Err(()) => {
|
||||
unreachable!("`AccessControlRequestHeaders::from_str` should never fail")
|
||||
}
|
||||
}
|
||||
}
|
||||
None => Outcome::Forward(()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[allow(unmounted_route)]
|
||||
mod tests {
|
||||
use std::str::FromStr;
|
||||
|
||||
use hyper;
|
||||
use rocket;
|
||||
use rocket::local::Client;
|
||||
use rocket::http::Method;
|
||||
|
||||
use super::*;
|
||||
|
||||
/// Make a client with no routes for unit testing
|
||||
fn make_client() -> Client {
|
||||
let rocket = rocket::ignite();
|
||||
Client::new(rocket).expect("valid rocket instance")
|
||||
}
|
||||
|
||||
// The following tests check that CORS Request headers are parsed correctly
|
||||
|
||||
#[test]
|
||||
fn origin_header_conversion() {
|
||||
let url = "https://foo.bar.xyz";
|
||||
let parsed = not_err!(Origin::from_str(url));
|
||||
let expected = not_err!(Url::from_str(url));
|
||||
assert_eq!(parsed, expected);
|
||||
|
||||
let url = "https://foo.bar.xyz/path/somewhere"; // this should never really be used
|
||||
let parsed = not_err!(Origin::from_str(url));
|
||||
let expected = not_err!(Url::from_str(url));
|
||||
assert_eq!(parsed, expected);
|
||||
|
||||
let url = "invalid_url";
|
||||
let _ = is_err!(Origin::from_str(url));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn origin_header_parsing() {
|
||||
let client = make_client();
|
||||
let mut request = client.get("/");
|
||||
|
||||
let origin = hyper::header::Origin::new("https", "www.example.com", None);
|
||||
request.add_header(origin);
|
||||
|
||||
let outcome: request::Outcome<Origin, ::Error> = FromRequest::from_request(request.inner());
|
||||
let parsed_header = assert_matches!(outcome, Outcome::Success(s), s);
|
||||
assert_eq!("https://www.example.com/", parsed_header.as_str());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn request_method_conversion() {
|
||||
let method = "POST";
|
||||
let parsed_method = not_err!(AccessControlRequestMethod::from_str(method));
|
||||
assert_matches!(parsed_method, AccessControlRequestMethod(Method::Post));
|
||||
|
||||
let method = "options";
|
||||
let parsed_method = not_err!(AccessControlRequestMethod::from_str(method));
|
||||
assert_matches!(parsed_method, AccessControlRequestMethod(Method::Options));
|
||||
|
||||
let method = "INVALID";
|
||||
let _ = is_err!(AccessControlRequestMethod::from_str(method));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn request_method_parsing() {
|
||||
let client = make_client();
|
||||
let mut request = client.get("/");
|
||||
let method = hyper::header::AccessControlRequestMethod(hyper::method::Method::Get);
|
||||
request.add_header(method);
|
||||
let outcome: request::Outcome<AccessControlRequestMethod, ::Error> =
|
||||
FromRequest::from_request(request.inner());
|
||||
|
||||
let parsed_header = assert_matches!(outcome, Outcome::Success(s), s);
|
||||
let AccessControlRequestMethod(parsed_method) = parsed_header;
|
||||
assert_eq!("GET", parsed_method.as_str());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn request_headers_conversion() {
|
||||
let headers = ["foo", "bar", "baz"];
|
||||
let parsed_headers = not_err!(AccessControlRequestHeaders::from_str(&headers.join(", ")));
|
||||
let expected_headers: HeaderFieldNamesSet =
|
||||
headers.iter().map(|s| s.to_string().into()).collect();
|
||||
let AccessControlRequestHeaders(actual_headers) = parsed_headers;
|
||||
assert_eq!(actual_headers, expected_headers);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn request_headers_parsing() {
|
||||
let client = make_client();
|
||||
let mut request = client.get("/");
|
||||
let headers = hyper::header::AccessControlRequestHeaders(vec![
|
||||
FromStr::from_str("accept-language").unwrap(),
|
||||
FromStr::from_str("date").unwrap(),
|
||||
]);
|
||||
request.add_header(headers);
|
||||
let outcome: request::Outcome<AccessControlRequestHeaders, ::Error> =
|
||||
FromRequest::from_request(request.inner());
|
||||
|
||||
let parsed_header = assert_matches!(outcome, Outcome::Success(s), s);
|
||||
let AccessControlRequestHeaders(parsed_headers) = parsed_header;
|
||||
let mut parsed_headers: Vec<String> =
|
||||
parsed_headers.iter().map(|s| s.to_string()).collect();
|
||||
parsed_headers.sort();
|
||||
assert_eq!(
|
||||
vec!["accept-language".to_string(), "date".to_string()],
|
||||
parsed_headers
|
||||
);
|
||||
}
|
||||
}
|
243
src/lib.rs
243
src/lib.rs
|
@ -107,6 +107,15 @@ extern crate url_serde;
|
|||
#[cfg(test)]
|
||||
extern crate hyper;
|
||||
|
||||
#[cfg(test)]
|
||||
#[macro_use]
|
||||
mod test_macros;
|
||||
|
||||
pub mod headers;
|
||||
|
||||
// Public exports
|
||||
pub use headers::Url;
|
||||
|
||||
use std::collections::{HashSet, HashMap};
|
||||
use std::error;
|
||||
use std::fmt;
|
||||
|
@ -116,13 +125,11 @@ use std::str::FromStr;
|
|||
|
||||
use rocket::{Outcome, State};
|
||||
use rocket::http::{Method, Status};
|
||||
use rocket::request::{self, Request, FromRequest};
|
||||
use rocket::request::{Request, FromRequest};
|
||||
use rocket::response;
|
||||
use unicase::UniCase;
|
||||
|
||||
#[cfg(test)]
|
||||
#[macro_use]
|
||||
mod test_macros;
|
||||
use headers::{HeaderFieldName, HeaderFieldNamesSet, Origin, AccessControlRequestHeaders,
|
||||
AccessControlRequestMethod};
|
||||
|
||||
/// Errors during operations
|
||||
///
|
||||
|
@ -212,129 +219,6 @@ impl<'r> response::Responder<'r> for Error {
|
|||
}
|
||||
}
|
||||
|
||||
/// A wrapped `url::Url` to allow for deserialization
|
||||
#[derive(Eq, PartialEq, Clone, Hash, Debug, Serialize, Deserialize)]
|
||||
pub struct Url(
|
||||
#[serde(with = "url_serde")]
|
||||
url::Url
|
||||
);
|
||||
|
||||
impl fmt::Display for Url {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
self.0.fmt(f)
|
||||
}
|
||||
}
|
||||
|
||||
impl Deref for Url {
|
||||
type Target = url::Url;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl FromStr for Url {
|
||||
type Err = url::ParseError;
|
||||
|
||||
fn from_str(input: &str) -> Result<Self, Self::Err> {
|
||||
let url = url::Url::from_str(input)?;
|
||||
Ok(Url(url))
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, 'r> FromRequest<'a, 'r> for Url {
|
||||
type Error = Error;
|
||||
|
||||
fn from_request(request: &'a Request<'r>) -> request::Outcome<Self, Error> {
|
||||
match request.headers().get_one("Origin") {
|
||||
Some(origin) => {
|
||||
match Self::from_str(origin) {
|
||||
Ok(origin) => Outcome::Success(origin),
|
||||
Err(e) => Outcome::Failure((Status::BadRequest, Error::BadOrigin(e))),
|
||||
}
|
||||
}
|
||||
None => Outcome::Forward(()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// The `Origin` request header used in CORS
|
||||
///
|
||||
/// You can use this as a rocket [Request Guard](https://rocket.rs/guide/requests/#request-guards)
|
||||
/// to ensure that Origins are passed in correctly.
|
||||
pub type Origin = Url;
|
||||
|
||||
/// The `Access-Control-Request-Method` request header
|
||||
#[derive(Debug)]
|
||||
pub struct AccessControlRequestMethod(pub Method);
|
||||
|
||||
impl FromStr for AccessControlRequestMethod {
|
||||
type Err = rocket::Error;
|
||||
|
||||
fn from_str(method: &str) -> Result<Self, Self::Err> {
|
||||
Ok(AccessControlRequestMethod(Method::from_str(method)?))
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, 'r> FromRequest<'a, 'r> for AccessControlRequestMethod {
|
||||
type Error = Error;
|
||||
|
||||
fn from_request(request: &'a Request<'r>) -> request::Outcome<Self, Error> {
|
||||
match request.headers().get_one("Access-Control-Request-Method") {
|
||||
Some(request_method) => {
|
||||
match Self::from_str(request_method) {
|
||||
Ok(request_method) => Outcome::Success(request_method),
|
||||
Err(e) => Outcome::Failure((Status::BadRequest, Error::BadRequestMethod(e))),
|
||||
}
|
||||
}
|
||||
None => Outcome::Forward(()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type HeaderFieldName = UniCase<String>;
|
||||
type HeaderFieldNamesSet = HashSet<HeaderFieldName>;
|
||||
|
||||
/// The `Access-Control-Request-Headers` request header
|
||||
#[derive(Debug)]
|
||||
pub struct AccessControlRequestHeaders(pub HeaderFieldNamesSet);
|
||||
|
||||
/// Will never fail
|
||||
impl FromStr for AccessControlRequestHeaders {
|
||||
type Err = ();
|
||||
|
||||
/// Will never fail
|
||||
fn from_str(headers: &str) -> Result<Self, Self::Err> {
|
||||
if headers.trim().is_empty() {
|
||||
return Ok(AccessControlRequestHeaders(HashSet::new()));
|
||||
}
|
||||
|
||||
let set: HeaderFieldNamesSet = headers
|
||||
.split(',')
|
||||
.map(|header| UniCase(header.trim().to_string()))
|
||||
.collect();
|
||||
Ok(AccessControlRequestHeaders(set))
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, 'r> FromRequest<'a, 'r> for AccessControlRequestHeaders {
|
||||
type Error = Error;
|
||||
|
||||
fn from_request(request: &'a Request<'r>) -> request::Outcome<Self, Error> {
|
||||
match request.headers().get_one("Access-Control-Request-Headers") {
|
||||
Some(request_headers) => {
|
||||
match Self::from_str(request_headers) {
|
||||
Ok(request_headers) => Outcome::Success(request_headers),
|
||||
Err(()) => {
|
||||
unreachable!("`AccessControlRequestHeaders::from_str` should never fail")
|
||||
}
|
||||
}
|
||||
}
|
||||
None => Outcome::Forward(()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// An enum signifying that some of type T is allowed, or `All` (everything is allowed).
|
||||
///
|
||||
/// `Default` is implemented for this enum and is `All`.
|
||||
|
@ -1021,112 +905,9 @@ impl Response {
|
|||
#[allow(unmounted_route)]
|
||||
mod tests {
|
||||
use std::str::FromStr;
|
||||
|
||||
use hyper;
|
||||
use rocket;
|
||||
use rocket::local::Client;
|
||||
use rocket::http::Method;
|
||||
|
||||
use super::*;
|
||||
|
||||
/// Make a client with no routes for unit testing
|
||||
fn make_client() -> Client {
|
||||
let rocket = rocket::ignite();
|
||||
Client::new(rocket).expect("valid rocket instance")
|
||||
}
|
||||
|
||||
// The following tests check that CORS Request headers are parsed correctly
|
||||
|
||||
#[test]
|
||||
fn origin_header_conversion() {
|
||||
let url = "https://foo.bar.xyz";
|
||||
let parsed = not_err!(Origin::from_str(url));
|
||||
let expected = not_err!(Url::from_str(url));
|
||||
assert_eq!(parsed, expected);
|
||||
|
||||
let url = "https://foo.bar.xyz/path/somewhere"; // this should never really be used
|
||||
let parsed = not_err!(Origin::from_str(url));
|
||||
let expected = not_err!(Url::from_str(url));
|
||||
assert_eq!(parsed, expected);
|
||||
|
||||
let url = "invalid_url";
|
||||
let _ = is_err!(Origin::from_str(url));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn origin_header_parsing() {
|
||||
let client = make_client();
|
||||
let mut request = client.get("/");
|
||||
|
||||
let origin = hyper::header::Origin::new("https", "www.example.com", None);
|
||||
request.add_header(origin);
|
||||
|
||||
let outcome: request::Outcome<Origin, Error> = FromRequest::from_request(request.inner());
|
||||
let parsed_header = assert_matches!(outcome, Outcome::Success(s), s);
|
||||
assert_eq!("https://www.example.com/", parsed_header.as_str());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn request_method_conversion() {
|
||||
let method = "POST";
|
||||
let parsed_method = not_err!(AccessControlRequestMethod::from_str(method));
|
||||
assert_matches!(parsed_method, AccessControlRequestMethod(Method::Post));
|
||||
|
||||
let method = "options";
|
||||
let parsed_method = not_err!(AccessControlRequestMethod::from_str(method));
|
||||
assert_matches!(parsed_method, AccessControlRequestMethod(Method::Options));
|
||||
|
||||
let method = "INVALID";
|
||||
let _ = is_err!(AccessControlRequestMethod::from_str(method));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn request_method_parsing() {
|
||||
let client = make_client();
|
||||
let mut request = client.get("/");
|
||||
let method = hyper::header::AccessControlRequestMethod(hyper::method::Method::Get);
|
||||
request.add_header(method);
|
||||
let outcome: request::Outcome<AccessControlRequestMethod, Error> =
|
||||
FromRequest::from_request(request.inner());
|
||||
|
||||
let parsed_header = assert_matches!(outcome, Outcome::Success(s), s);
|
||||
let AccessControlRequestMethod(parsed_method) = parsed_header;
|
||||
assert_eq!("GET", parsed_method.as_str());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn request_headers_conversion() {
|
||||
let headers = ["foo", "bar", "baz"];
|
||||
let parsed_headers = not_err!(AccessControlRequestHeaders::from_str(&headers.join(", ")));
|
||||
let expected_headers: HeaderFieldNamesSet =
|
||||
headers.iter().map(|s| s.to_string().into()).collect();
|
||||
let AccessControlRequestHeaders(actual_headers) = parsed_headers;
|
||||
assert_eq!(actual_headers, expected_headers);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn request_headers_parsing() {
|
||||
let client = make_client();
|
||||
let mut request = client.get("/");
|
||||
let headers = hyper::header::AccessControlRequestHeaders(vec![
|
||||
FromStr::from_str("accept-language").unwrap(),
|
||||
FromStr::from_str("date").unwrap(),
|
||||
]);
|
||||
request.add_header(headers);
|
||||
let outcome: request::Outcome<AccessControlRequestHeaders, Error> =
|
||||
FromRequest::from_request(request.inner());
|
||||
|
||||
let parsed_header = assert_matches!(outcome, Outcome::Success(s), s);
|
||||
let AccessControlRequestHeaders(parsed_headers) = parsed_header;
|
||||
let mut parsed_headers: Vec<String> =
|
||||
parsed_headers.iter().map(|s| s.to_string()).collect();
|
||||
parsed_headers.sort();
|
||||
assert_eq!(
|
||||
vec!["accept-language".to_string(), "date".to_string()],
|
||||
parsed_headers
|
||||
);
|
||||
}
|
||||
|
||||
// The following tests check `Response`'s validation
|
||||
|
||||
#[test]
|
||||
|
|
|
@ -10,7 +10,7 @@ use std::str::FromStr;
|
|||
|
||||
use rocket::local::Client;
|
||||
use rocket::http::{Header, Status};
|
||||
use rocket_cors::*;
|
||||
use rocket_cors::headers::*;
|
||||
|
||||
#[get("/request_headers")]
|
||||
fn request_headers(
|
||||
|
|
Loading…
Reference in New Issue