Move some code to headers module
This commit is contained in:
parent
ad25352bdd
commit
6f1a24e12d
|
@ -0,0 +1,253 @@
|
||||||
|
//! CORS specific Request Headers
|
||||||
|
|
||||||
|
use std::collections::HashSet;
|
||||||
|
use std::fmt;
|
||||||
|
use std::ops::Deref;
|
||||||
|
use std::str::FromStr;
|
||||||
|
|
||||||
|
use rocket::{self, Outcome};
|
||||||
|
use rocket::http::{Method, Status};
|
||||||
|
use rocket::request::{self, FromRequest};
|
||||||
|
use unicase::UniCase;
|
||||||
|
use url;
|
||||||
|
use url_serde;
|
||||||
|
|
||||||
|
pub(crate) type HeaderFieldName = UniCase<String>;
|
||||||
|
pub(crate) type HeaderFieldNamesSet = HashSet<HeaderFieldName>;
|
||||||
|
|
||||||
|
/// A wrapped `url::Url` to allow for deserialization
|
||||||
|
#[derive(Eq, PartialEq, Clone, Hash, Debug, Serialize, Deserialize)]
|
||||||
|
pub struct Url(
|
||||||
|
#[serde(with = "url_serde")]
|
||||||
|
url::Url
|
||||||
|
);
|
||||||
|
|
||||||
|
impl fmt::Display for Url {
|
||||||
|
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||||
|
self.0.fmt(f)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Deref for Url {
|
||||||
|
type Target = url::Url;
|
||||||
|
|
||||||
|
fn deref(&self) -> &Self::Target {
|
||||||
|
&self.0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl FromStr for Url {
|
||||||
|
type Err = url::ParseError;
|
||||||
|
|
||||||
|
fn from_str(input: &str) -> Result<Self, Self::Err> {
|
||||||
|
let url = url::Url::from_str(input)?;
|
||||||
|
Ok(Url(url))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, 'r> FromRequest<'a, 'r> for Url {
|
||||||
|
type Error = ::Error;
|
||||||
|
|
||||||
|
fn from_request(request: &'a rocket::Request<'r>) -> request::Outcome<Self, ::Error> {
|
||||||
|
match request.headers().get_one("Origin") {
|
||||||
|
Some(origin) => {
|
||||||
|
match Self::from_str(origin) {
|
||||||
|
Ok(origin) => Outcome::Success(origin),
|
||||||
|
Err(e) => Outcome::Failure((Status::BadRequest, ::Error::BadOrigin(e))),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
None => Outcome::Forward(()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The `Origin` request header used in CORS
|
||||||
|
///
|
||||||
|
/// You can use this as a rocket [Request Guard](https://rocket.rs/guide/requests/#request-guards)
|
||||||
|
/// to ensure that `Origin` is passed in correctly.
|
||||||
|
pub type Origin = Url;
|
||||||
|
|
||||||
|
/// The `Access-Control-Request-Method` request header
|
||||||
|
///
|
||||||
|
/// You can use this as a rocket [Request Guard](https://rocket.rs/guide/requests/#request-guards)
|
||||||
|
/// to ensure that the header is passed in correctly.
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct AccessControlRequestMethod(pub Method);
|
||||||
|
|
||||||
|
impl FromStr for AccessControlRequestMethod {
|
||||||
|
type Err = rocket::Error;
|
||||||
|
|
||||||
|
fn from_str(method: &str) -> Result<Self, Self::Err> {
|
||||||
|
Ok(AccessControlRequestMethod(Method::from_str(method)?))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, 'r> FromRequest<'a, 'r> for AccessControlRequestMethod {
|
||||||
|
type Error = ::Error;
|
||||||
|
|
||||||
|
fn from_request(request: &'a rocket::Request<'r>) -> request::Outcome<Self, ::Error> {
|
||||||
|
match request.headers().get_one("Access-Control-Request-Method") {
|
||||||
|
Some(request_method) => {
|
||||||
|
match Self::from_str(request_method) {
|
||||||
|
Ok(request_method) => Outcome::Success(request_method),
|
||||||
|
Err(e) => Outcome::Failure((Status::BadRequest, ::Error::BadRequestMethod(e))),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
None => Outcome::Forward(()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The `Access-Control-Request-Headers` request header
|
||||||
|
///
|
||||||
|
/// You can use this as a rocket [Request Guard](https://rocket.rs/guide/requests/#request-guards)
|
||||||
|
/// to ensure that the header is passed in correctly.
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct AccessControlRequestHeaders(pub HeaderFieldNamesSet);
|
||||||
|
|
||||||
|
/// Will never fail
|
||||||
|
impl FromStr for AccessControlRequestHeaders {
|
||||||
|
type Err = ();
|
||||||
|
|
||||||
|
/// Will never fail
|
||||||
|
fn from_str(headers: &str) -> Result<Self, Self::Err> {
|
||||||
|
if headers.trim().is_empty() {
|
||||||
|
return Ok(AccessControlRequestHeaders(HashSet::new()));
|
||||||
|
}
|
||||||
|
|
||||||
|
let set: HeaderFieldNamesSet = headers
|
||||||
|
.split(',')
|
||||||
|
.map(|header| UniCase(header.trim().to_string()))
|
||||||
|
.collect();
|
||||||
|
Ok(AccessControlRequestHeaders(set))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, 'r> FromRequest<'a, 'r> for AccessControlRequestHeaders {
|
||||||
|
type Error = ::Error;
|
||||||
|
|
||||||
|
fn from_request(request: &'a rocket::Request<'r>) -> request::Outcome<Self, ::Error> {
|
||||||
|
match request.headers().get_one("Access-Control-Request-Headers") {
|
||||||
|
Some(request_headers) => {
|
||||||
|
match Self::from_str(request_headers) {
|
||||||
|
Ok(request_headers) => Outcome::Success(request_headers),
|
||||||
|
Err(()) => {
|
||||||
|
unreachable!("`AccessControlRequestHeaders::from_str` should never fail")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
None => Outcome::Forward(()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
#[allow(unmounted_route)]
|
||||||
|
mod tests {
|
||||||
|
use std::str::FromStr;
|
||||||
|
|
||||||
|
use hyper;
|
||||||
|
use rocket;
|
||||||
|
use rocket::local::Client;
|
||||||
|
use rocket::http::Method;
|
||||||
|
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
/// Make a client with no routes for unit testing
|
||||||
|
fn make_client() -> Client {
|
||||||
|
let rocket = rocket::ignite();
|
||||||
|
Client::new(rocket).expect("valid rocket instance")
|
||||||
|
}
|
||||||
|
|
||||||
|
// The following tests check that CORS Request headers are parsed correctly
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn origin_header_conversion() {
|
||||||
|
let url = "https://foo.bar.xyz";
|
||||||
|
let parsed = not_err!(Origin::from_str(url));
|
||||||
|
let expected = not_err!(Url::from_str(url));
|
||||||
|
assert_eq!(parsed, expected);
|
||||||
|
|
||||||
|
let url = "https://foo.bar.xyz/path/somewhere"; // this should never really be used
|
||||||
|
let parsed = not_err!(Origin::from_str(url));
|
||||||
|
let expected = not_err!(Url::from_str(url));
|
||||||
|
assert_eq!(parsed, expected);
|
||||||
|
|
||||||
|
let url = "invalid_url";
|
||||||
|
let _ = is_err!(Origin::from_str(url));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn origin_header_parsing() {
|
||||||
|
let client = make_client();
|
||||||
|
let mut request = client.get("/");
|
||||||
|
|
||||||
|
let origin = hyper::header::Origin::new("https", "www.example.com", None);
|
||||||
|
request.add_header(origin);
|
||||||
|
|
||||||
|
let outcome: request::Outcome<Origin, ::Error> = FromRequest::from_request(request.inner());
|
||||||
|
let parsed_header = assert_matches!(outcome, Outcome::Success(s), s);
|
||||||
|
assert_eq!("https://www.example.com/", parsed_header.as_str());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn request_method_conversion() {
|
||||||
|
let method = "POST";
|
||||||
|
let parsed_method = not_err!(AccessControlRequestMethod::from_str(method));
|
||||||
|
assert_matches!(parsed_method, AccessControlRequestMethod(Method::Post));
|
||||||
|
|
||||||
|
let method = "options";
|
||||||
|
let parsed_method = not_err!(AccessControlRequestMethod::from_str(method));
|
||||||
|
assert_matches!(parsed_method, AccessControlRequestMethod(Method::Options));
|
||||||
|
|
||||||
|
let method = "INVALID";
|
||||||
|
let _ = is_err!(AccessControlRequestMethod::from_str(method));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn request_method_parsing() {
|
||||||
|
let client = make_client();
|
||||||
|
let mut request = client.get("/");
|
||||||
|
let method = hyper::header::AccessControlRequestMethod(hyper::method::Method::Get);
|
||||||
|
request.add_header(method);
|
||||||
|
let outcome: request::Outcome<AccessControlRequestMethod, ::Error> =
|
||||||
|
FromRequest::from_request(request.inner());
|
||||||
|
|
||||||
|
let parsed_header = assert_matches!(outcome, Outcome::Success(s), s);
|
||||||
|
let AccessControlRequestMethod(parsed_method) = parsed_header;
|
||||||
|
assert_eq!("GET", parsed_method.as_str());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn request_headers_conversion() {
|
||||||
|
let headers = ["foo", "bar", "baz"];
|
||||||
|
let parsed_headers = not_err!(AccessControlRequestHeaders::from_str(&headers.join(", ")));
|
||||||
|
let expected_headers: HeaderFieldNamesSet =
|
||||||
|
headers.iter().map(|s| s.to_string().into()).collect();
|
||||||
|
let AccessControlRequestHeaders(actual_headers) = parsed_headers;
|
||||||
|
assert_eq!(actual_headers, expected_headers);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn request_headers_parsing() {
|
||||||
|
let client = make_client();
|
||||||
|
let mut request = client.get("/");
|
||||||
|
let headers = hyper::header::AccessControlRequestHeaders(vec![
|
||||||
|
FromStr::from_str("accept-language").unwrap(),
|
||||||
|
FromStr::from_str("date").unwrap(),
|
||||||
|
]);
|
||||||
|
request.add_header(headers);
|
||||||
|
let outcome: request::Outcome<AccessControlRequestHeaders, ::Error> =
|
||||||
|
FromRequest::from_request(request.inner());
|
||||||
|
|
||||||
|
let parsed_header = assert_matches!(outcome, Outcome::Success(s), s);
|
||||||
|
let AccessControlRequestHeaders(parsed_headers) = parsed_header;
|
||||||
|
let mut parsed_headers: Vec<String> =
|
||||||
|
parsed_headers.iter().map(|s| s.to_string()).collect();
|
||||||
|
parsed_headers.sort();
|
||||||
|
assert_eq!(
|
||||||
|
vec!["accept-language".to_string(), "date".to_string()],
|
||||||
|
parsed_headers
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
243
src/lib.rs
243
src/lib.rs
|
@ -107,6 +107,15 @@ extern crate url_serde;
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
extern crate hyper;
|
extern crate hyper;
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
#[macro_use]
|
||||||
|
mod test_macros;
|
||||||
|
|
||||||
|
pub mod headers;
|
||||||
|
|
||||||
|
// Public exports
|
||||||
|
pub use headers::Url;
|
||||||
|
|
||||||
use std::collections::{HashSet, HashMap};
|
use std::collections::{HashSet, HashMap};
|
||||||
use std::error;
|
use std::error;
|
||||||
use std::fmt;
|
use std::fmt;
|
||||||
|
@ -116,13 +125,11 @@ use std::str::FromStr;
|
||||||
|
|
||||||
use rocket::{Outcome, State};
|
use rocket::{Outcome, State};
|
||||||
use rocket::http::{Method, Status};
|
use rocket::http::{Method, Status};
|
||||||
use rocket::request::{self, Request, FromRequest};
|
use rocket::request::{Request, FromRequest};
|
||||||
use rocket::response;
|
use rocket::response;
|
||||||
use unicase::UniCase;
|
|
||||||
|
|
||||||
#[cfg(test)]
|
use headers::{HeaderFieldName, HeaderFieldNamesSet, Origin, AccessControlRequestHeaders,
|
||||||
#[macro_use]
|
AccessControlRequestMethod};
|
||||||
mod test_macros;
|
|
||||||
|
|
||||||
/// Errors during operations
|
/// Errors during operations
|
||||||
///
|
///
|
||||||
|
@ -212,129 +219,6 @@ impl<'r> response::Responder<'r> for Error {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// A wrapped `url::Url` to allow for deserialization
|
|
||||||
#[derive(Eq, PartialEq, Clone, Hash, Debug, Serialize, Deserialize)]
|
|
||||||
pub struct Url(
|
|
||||||
#[serde(with = "url_serde")]
|
|
||||||
url::Url
|
|
||||||
);
|
|
||||||
|
|
||||||
impl fmt::Display for Url {
|
|
||||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
|
||||||
self.0.fmt(f)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Deref for Url {
|
|
||||||
type Target = url::Url;
|
|
||||||
|
|
||||||
fn deref(&self) -> &Self::Target {
|
|
||||||
&self.0
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl FromStr for Url {
|
|
||||||
type Err = url::ParseError;
|
|
||||||
|
|
||||||
fn from_str(input: &str) -> Result<Self, Self::Err> {
|
|
||||||
let url = url::Url::from_str(input)?;
|
|
||||||
Ok(Url(url))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<'a, 'r> FromRequest<'a, 'r> for Url {
|
|
||||||
type Error = Error;
|
|
||||||
|
|
||||||
fn from_request(request: &'a Request<'r>) -> request::Outcome<Self, Error> {
|
|
||||||
match request.headers().get_one("Origin") {
|
|
||||||
Some(origin) => {
|
|
||||||
match Self::from_str(origin) {
|
|
||||||
Ok(origin) => Outcome::Success(origin),
|
|
||||||
Err(e) => Outcome::Failure((Status::BadRequest, Error::BadOrigin(e))),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
None => Outcome::Forward(()),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// The `Origin` request header used in CORS
|
|
||||||
///
|
|
||||||
/// You can use this as a rocket [Request Guard](https://rocket.rs/guide/requests/#request-guards)
|
|
||||||
/// to ensure that Origins are passed in correctly.
|
|
||||||
pub type Origin = Url;
|
|
||||||
|
|
||||||
/// The `Access-Control-Request-Method` request header
|
|
||||||
#[derive(Debug)]
|
|
||||||
pub struct AccessControlRequestMethod(pub Method);
|
|
||||||
|
|
||||||
impl FromStr for AccessControlRequestMethod {
|
|
||||||
type Err = rocket::Error;
|
|
||||||
|
|
||||||
fn from_str(method: &str) -> Result<Self, Self::Err> {
|
|
||||||
Ok(AccessControlRequestMethod(Method::from_str(method)?))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<'a, 'r> FromRequest<'a, 'r> for AccessControlRequestMethod {
|
|
||||||
type Error = Error;
|
|
||||||
|
|
||||||
fn from_request(request: &'a Request<'r>) -> request::Outcome<Self, Error> {
|
|
||||||
match request.headers().get_one("Access-Control-Request-Method") {
|
|
||||||
Some(request_method) => {
|
|
||||||
match Self::from_str(request_method) {
|
|
||||||
Ok(request_method) => Outcome::Success(request_method),
|
|
||||||
Err(e) => Outcome::Failure((Status::BadRequest, Error::BadRequestMethod(e))),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
None => Outcome::Forward(()),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type HeaderFieldName = UniCase<String>;
|
|
||||||
type HeaderFieldNamesSet = HashSet<HeaderFieldName>;
|
|
||||||
|
|
||||||
/// The `Access-Control-Request-Headers` request header
|
|
||||||
#[derive(Debug)]
|
|
||||||
pub struct AccessControlRequestHeaders(pub HeaderFieldNamesSet);
|
|
||||||
|
|
||||||
/// Will never fail
|
|
||||||
impl FromStr for AccessControlRequestHeaders {
|
|
||||||
type Err = ();
|
|
||||||
|
|
||||||
/// Will never fail
|
|
||||||
fn from_str(headers: &str) -> Result<Self, Self::Err> {
|
|
||||||
if headers.trim().is_empty() {
|
|
||||||
return Ok(AccessControlRequestHeaders(HashSet::new()));
|
|
||||||
}
|
|
||||||
|
|
||||||
let set: HeaderFieldNamesSet = headers
|
|
||||||
.split(',')
|
|
||||||
.map(|header| UniCase(header.trim().to_string()))
|
|
||||||
.collect();
|
|
||||||
Ok(AccessControlRequestHeaders(set))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<'a, 'r> FromRequest<'a, 'r> for AccessControlRequestHeaders {
|
|
||||||
type Error = Error;
|
|
||||||
|
|
||||||
fn from_request(request: &'a Request<'r>) -> request::Outcome<Self, Error> {
|
|
||||||
match request.headers().get_one("Access-Control-Request-Headers") {
|
|
||||||
Some(request_headers) => {
|
|
||||||
match Self::from_str(request_headers) {
|
|
||||||
Ok(request_headers) => Outcome::Success(request_headers),
|
|
||||||
Err(()) => {
|
|
||||||
unreachable!("`AccessControlRequestHeaders::from_str` should never fail")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
None => Outcome::Forward(()),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// An enum signifying that some of type T is allowed, or `All` (everything is allowed).
|
/// An enum signifying that some of type T is allowed, or `All` (everything is allowed).
|
||||||
///
|
///
|
||||||
/// `Default` is implemented for this enum and is `All`.
|
/// `Default` is implemented for this enum and is `All`.
|
||||||
|
@ -1021,112 +905,9 @@ impl Response {
|
||||||
#[allow(unmounted_route)]
|
#[allow(unmounted_route)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use std::str::FromStr;
|
use std::str::FromStr;
|
||||||
|
|
||||||
use hyper;
|
|
||||||
use rocket;
|
|
||||||
use rocket::local::Client;
|
|
||||||
use rocket::http::Method;
|
use rocket::http::Method;
|
||||||
|
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
/// Make a client with no routes for unit testing
|
|
||||||
fn make_client() -> Client {
|
|
||||||
let rocket = rocket::ignite();
|
|
||||||
Client::new(rocket).expect("valid rocket instance")
|
|
||||||
}
|
|
||||||
|
|
||||||
// The following tests check that CORS Request headers are parsed correctly
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn origin_header_conversion() {
|
|
||||||
let url = "https://foo.bar.xyz";
|
|
||||||
let parsed = not_err!(Origin::from_str(url));
|
|
||||||
let expected = not_err!(Url::from_str(url));
|
|
||||||
assert_eq!(parsed, expected);
|
|
||||||
|
|
||||||
let url = "https://foo.bar.xyz/path/somewhere"; // this should never really be used
|
|
||||||
let parsed = not_err!(Origin::from_str(url));
|
|
||||||
let expected = not_err!(Url::from_str(url));
|
|
||||||
assert_eq!(parsed, expected);
|
|
||||||
|
|
||||||
let url = "invalid_url";
|
|
||||||
let _ = is_err!(Origin::from_str(url));
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn origin_header_parsing() {
|
|
||||||
let client = make_client();
|
|
||||||
let mut request = client.get("/");
|
|
||||||
|
|
||||||
let origin = hyper::header::Origin::new("https", "www.example.com", None);
|
|
||||||
request.add_header(origin);
|
|
||||||
|
|
||||||
let outcome: request::Outcome<Origin, Error> = FromRequest::from_request(request.inner());
|
|
||||||
let parsed_header = assert_matches!(outcome, Outcome::Success(s), s);
|
|
||||||
assert_eq!("https://www.example.com/", parsed_header.as_str());
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn request_method_conversion() {
|
|
||||||
let method = "POST";
|
|
||||||
let parsed_method = not_err!(AccessControlRequestMethod::from_str(method));
|
|
||||||
assert_matches!(parsed_method, AccessControlRequestMethod(Method::Post));
|
|
||||||
|
|
||||||
let method = "options";
|
|
||||||
let parsed_method = not_err!(AccessControlRequestMethod::from_str(method));
|
|
||||||
assert_matches!(parsed_method, AccessControlRequestMethod(Method::Options));
|
|
||||||
|
|
||||||
let method = "INVALID";
|
|
||||||
let _ = is_err!(AccessControlRequestMethod::from_str(method));
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn request_method_parsing() {
|
|
||||||
let client = make_client();
|
|
||||||
let mut request = client.get("/");
|
|
||||||
let method = hyper::header::AccessControlRequestMethod(hyper::method::Method::Get);
|
|
||||||
request.add_header(method);
|
|
||||||
let outcome: request::Outcome<AccessControlRequestMethod, Error> =
|
|
||||||
FromRequest::from_request(request.inner());
|
|
||||||
|
|
||||||
let parsed_header = assert_matches!(outcome, Outcome::Success(s), s);
|
|
||||||
let AccessControlRequestMethod(parsed_method) = parsed_header;
|
|
||||||
assert_eq!("GET", parsed_method.as_str());
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn request_headers_conversion() {
|
|
||||||
let headers = ["foo", "bar", "baz"];
|
|
||||||
let parsed_headers = not_err!(AccessControlRequestHeaders::from_str(&headers.join(", ")));
|
|
||||||
let expected_headers: HeaderFieldNamesSet =
|
|
||||||
headers.iter().map(|s| s.to_string().into()).collect();
|
|
||||||
let AccessControlRequestHeaders(actual_headers) = parsed_headers;
|
|
||||||
assert_eq!(actual_headers, expected_headers);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn request_headers_parsing() {
|
|
||||||
let client = make_client();
|
|
||||||
let mut request = client.get("/");
|
|
||||||
let headers = hyper::header::AccessControlRequestHeaders(vec![
|
|
||||||
FromStr::from_str("accept-language").unwrap(),
|
|
||||||
FromStr::from_str("date").unwrap(),
|
|
||||||
]);
|
|
||||||
request.add_header(headers);
|
|
||||||
let outcome: request::Outcome<AccessControlRequestHeaders, Error> =
|
|
||||||
FromRequest::from_request(request.inner());
|
|
||||||
|
|
||||||
let parsed_header = assert_matches!(outcome, Outcome::Success(s), s);
|
|
||||||
let AccessControlRequestHeaders(parsed_headers) = parsed_header;
|
|
||||||
let mut parsed_headers: Vec<String> =
|
|
||||||
parsed_headers.iter().map(|s| s.to_string()).collect();
|
|
||||||
parsed_headers.sort();
|
|
||||||
assert_eq!(
|
|
||||||
vec!["accept-language".to_string(), "date".to_string()],
|
|
||||||
parsed_headers
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
// The following tests check `Response`'s validation
|
// The following tests check `Response`'s validation
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|
|
@ -10,7 +10,7 @@ use std::str::FromStr;
|
||||||
|
|
||||||
use rocket::local::Client;
|
use rocket::local::Client;
|
||||||
use rocket::http::{Header, Status};
|
use rocket::http::{Header, Status};
|
||||||
use rocket_cors::*;
|
use rocket_cors::headers::*;
|
||||||
|
|
||||||
#[get("/request_headers")]
|
#[get("/request_headers")]
|
||||||
fn request_headers(
|
fn request_headers(
|
||||||
|
|
Loading…
Reference in New Issue