//! CORS specific Request Headers

use std::collections::HashSet;
use std::fmt;
use std::ops::Deref;
use std::str::FromStr;

use rocket::http::Status;
use rocket::request::{self, FromRequest};
use rocket::{self, outcome::Outcome};
#[cfg(feature = "serialization")]
use serde_derive::{Deserialize, Serialize};
use unicase::UniCase;

/// A case insensitive header name
///
/// Equality and hashing go through the wrapped `UniCase<String>`, so two names
/// that differ only in letter case compare equal.
#[derive(Eq, PartialEq, Clone, Debug, Hash)]
#[cfg_attr(feature = "serialization", derive(Serialize, Deserialize))]
pub struct HeaderFieldName(
    #[cfg_attr(feature = "serialization", serde(with = "unicase_serde::unicase"))] UniCase<String>,
);

impl Deref for HeaderFieldName {
    type Target = String;

    /// Borrows the header name as the original (case-preserved) `String`.
    fn deref(&self) -> &Self::Target {
        self.0.deref()
    }
}

impl fmt::Display for HeaderFieldName {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        self.0.fmt(f)
    }
}

impl<'a> From<&'a str> for HeaderFieldName {
    fn from(s: &'a str) -> Self {
        HeaderFieldName(From::from(s))
    }
}

// No lifetime parameter needed: `String` is owned.
impl From<String> for HeaderFieldName {
    fn from(s: String) -> Self {
        HeaderFieldName(From::from(s))
    }
}

impl FromStr for HeaderFieldName {
    type Err = <String as FromStr>::Err;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        Ok(HeaderFieldName(FromStr::from_str(s)?))
    }
}

/// A set of case insensitive header names
pub type HeaderFieldNamesSet = HashSet<HeaderFieldName>;

/// The `Origin` request header used in CORS
///
/// You can use this as a rocket [Request Guard](https://rocket.rs/guide/requests/#request-guards)
/// to ensure that `Origin` is passed in correctly.
///
/// Reference: [Mozilla](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Origin)
#[derive(Eq, PartialEq, Clone, Hash, Debug)]
pub enum Origin {
    /// A `null` Origin
    Null,
    /// A well-formed origin that was parsed by [`url::Url::origin`]
    Parsed(url::Origin),
    /// An unknown "opaque" origin that could not be parsed
    Opaque(String),
}

impl Origin {
    /// Perform an
    /// [ASCII serialization](https://html.spec.whatwg.org/multipage/#ascii-serialisation-of-an-origin)
    /// of this origin.
    ///
    /// This is the same text produced by the `Display` implementation.
    pub fn ascii_serialization(&self) -> String {
        self.to_string()
    }

    /// Returns whether the origin was parsed as non-opaque
    ///
    /// `Null` and `Opaque` origins are never tuples; a `Parsed` origin defers to
    /// [`url::Origin::is_tuple`].
    pub fn is_tuple(&self) -> bool {
        match self {
            Origin::Parsed(parsed) => parsed.is_tuple(),
            Origin::Null | Origin::Opaque(_) => false,
        }
    }

    /// Derives an instance of `Self` from the incoming request metadata.
    ///
    /// If the derivation is successful, an outcome of `Success` is returned. If
    /// the derivation fails in an unrecoverable fashion, `Failure` is returned.
    /// `Forward` is returned to indicate that the request should be forwarded
    /// to other matching routes, if any.
    ///
    /// Concretely: a missing `Origin` header forwards; a present header that
    /// [`Origin::from_str`] rejects fails with `Status::BadRequest` and the
    /// parse error.
    pub fn from_request_sync(
        request: &'_ rocket::Request<'_>,
    ) -> request::Outcome<Self, crate::Error> {
        match request.headers().get_one("Origin") {
            Some(origin) => match Self::from_str(origin) {
                Ok(origin) => Outcome::Success(origin),
                Err(e) => Outcome::Failure((Status::BadRequest, e)),
            },
            None => Outcome::Forward(()),
        }
    }
}

impl FromStr for Origin {
    type Err = crate::Error;

    /// Parses an `Origin` header value.
    ///
    /// The literal `null` (matched case-insensitively) becomes [`Origin::Null`].
    /// Anything else goes through `crate::to_origin`: tuple origins are kept as
    /// [`Origin::Parsed`], while opaque ones keep the caller's raw input text in
    /// [`Origin::Opaque`] so it can be echoed back unchanged.
    ///
    /// # Errors
    ///
    /// Propagates the error from `crate::to_origin` when the input cannot be parsed.
    fn from_str(input: &str) -> Result<Self, Self::Err> {
        // ASCII case folding is sufficient for the literal "null" and, unlike
        // `to_lowercase`, does not allocate on every request.
        if input.eq_ignore_ascii_case("null") {
            return Ok(Origin::Null);
        }
        match crate::to_origin(input)? {
            url::Origin::Opaque(_) => Ok(Origin::Opaque(input.to_string())),
            parsed @ url::Origin::Tuple(..) => Ok(Origin::Parsed(parsed)),
        }
    }
}

impl fmt::Display for Origin {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Origin::Null => write!(f, "null"),
            Origin::Parsed(ref parsed) => write!(f, "{}", parsed.ascii_serialization()),
            Origin::Opaque(ref opaque) => write!(f, "{}", opaque),
        }
    }
}

#[rocket::async_trait]
impl<'r> FromRequest<'r> for Origin {
    type Error = crate::Error;

    async fn from_request(
        request: &'r rocket::Request<'_>,
    ) -> request::Outcome<Self, crate::Error> {
        Origin::from_request_sync(request)
    }
}

/// The `Access-Control-Request-Method` request header
///
/// You can use this as a rocket [Request Guard](https://rocket.rs/guide/requests/#request-guards)
/// to ensure that the header is passed in correctly.
#[derive(Debug)] pub struct AccessControlRequestMethod(pub crate::Method); impl AccessControlRequestMethod { /// Derives an instance of `Self` from the incoming request metadata. /// /// If the derivation is successful, an outcome of `Success` is returned. If /// the derivation fails in an unrecoverable fashion, `Failure` is returned. /// `Forward` is returned to indicate that the request should be forwarded /// to other matching routes, if any. pub fn from_request_sync( request: &'_ rocket::Request<'_>, ) -> request::Outcome { match request.headers().get_one("Access-Control-Request-Method") { Some(request_method) => match Self::from_str(request_method) { Ok(request_method) => Outcome::Success(request_method), Err(_) => Outcome::Failure((Status::BadRequest, crate::Error::BadRequestMethod)), }, None => Outcome::Forward(()), } } } impl FromStr for AccessControlRequestMethod { type Err = (); fn from_str(method: &str) -> Result { Ok(AccessControlRequestMethod(crate::Method::from_str(method)?)) } } #[rocket::async_trait] impl<'r> FromRequest<'r> for AccessControlRequestMethod { type Error = crate::Error; async fn from_request( request: &'r rocket::Request<'_>, ) -> request::Outcome { AccessControlRequestMethod::from_request_sync(request) } } /// The `Access-Control-Request-Headers` request header /// /// You can use this as a rocket [Request Guard](https://rocket.rs/guide/requests/#request-guards) /// to ensure that the header is passed in correctly. #[derive(Eq, PartialEq, Debug)] pub struct AccessControlRequestHeaders(pub HeaderFieldNamesSet); impl AccessControlRequestHeaders { /// Derives an instance of `Self` from the incoming request metadata. /// /// If the derivation is successful, an outcome of `Success` is returned. If /// the derivation fails in an unrecoverable fashion, `Failure` is returned. /// `Forward` is returned to indicate that the request should be forwarded /// to other matching routes, if any. 
pub fn from_request_sync( request: &'_ rocket::Request<'_>, ) -> request::Outcome { match request.headers().get_one("Access-Control-Request-Headers") { Some(request_headers) => match Self::from_str(request_headers) { Ok(request_headers) => Outcome::Success(request_headers), Err(()) => { unreachable!("`AccessControlRequestHeaders::from_str` should never fail") } }, None => Outcome::Forward(()), } } } /// Will never fail impl FromStr for AccessControlRequestHeaders { type Err = (); /// Will never fail fn from_str(headers: &str) -> Result { if headers.trim().is_empty() { return Ok(AccessControlRequestHeaders(HashSet::new())); } let set: HeaderFieldNamesSet = headers .split(',') .map(|header| From::from(header.trim().to_string())) .collect(); Ok(AccessControlRequestHeaders(set)) } } #[rocket::async_trait] impl<'r> FromRequest<'r> for AccessControlRequestHeaders { type Error = crate::Error; async fn from_request( request: &'r rocket::Request<'_>, ) -> request::Outcome { AccessControlRequestHeaders::from_request_sync(request) } } #[cfg(test)] mod tests { use std::str::FromStr; use rocket::http::hyper; use rocket::http::Header; use rocket::local::blocking::Client; static ORIGIN: hyper::HeaderName = hyper::header::ORIGIN; static ACCESS_CONTROL_REQUEST_METHOD: hyper::HeaderName = hyper::header::ACCESS_CONTROL_REQUEST_METHOD; static ACCESS_CONTROL_REQUEST_HEADERS: hyper::HeaderName = hyper::header::ACCESS_CONTROL_REQUEST_HEADERS; use super::*; /// Make a client with no routes for unit testing fn make_client() -> Client { let rocket = rocket::build(); Client::tracked(rocket).expect("valid rocket instance") } // `Origin::from_str` tests #[test] fn origin_is_parsed_properly() { let url = "https://foo.bar.xyz"; let parsed = not_err!(Origin::from_str(url)); assert_eq!(parsed.ascii_serialization(), url); } #[test] fn origin_parsing_strips_paths() { // this should never really be sent by a compliant user agent let url = "https://foo.bar.xyz/path/somewhere"; let parsed = 
not_err!(Origin::from_str(url)); let expected = "https://foo.bar.xyz"; assert_eq!(parsed.ascii_serialization(), expected); } #[test] #[should_panic(expected = "BadOrigin")] fn origin_parsing_disallows_invalid_origins() { let url = "invalid_url"; let _ = Origin::from_str(url).unwrap(); } #[test] fn origin_parses_opaque_origins() { let url = "blob://foobar"; let parsed = not_err!(Origin::from_str(url)); assert!(!parsed.is_tuple()); } // The following tests check that CORS Request headers are parsed correctly #[test] fn origin_header_conversion() { let url = "https://foo.bar.xyz"; let parsed = not_err!(Origin::from_str(url)); assert_eq!(parsed.ascii_serialization(), url); let url = "https://foo.bar.xyz:1234"; let parsed = not_err!(Origin::from_str(url)); assert_eq!(parsed.ascii_serialization(), url); // this should never really be sent by a compliant user agent let url = "https://foo.bar.xyz/path/somewhere"; let parsed = not_err!(Origin::from_str(url)); let expected = "https://foo.bar.xyz"; assert_eq!(parsed.ascii_serialization(), expected); let url = "invalid_url"; let _ = is_err!(Origin::from_str(url)); } #[test] fn origin_header_parsing() { let client = make_client(); let mut request = client.get("/"); let origin = Header::new(ORIGIN.as_str(), "https://www.example.com"); request.add_header(origin); let outcome = Origin::from_request_sync(request.inner()); let parsed_header = assert_matches!(outcome, Outcome::Success(s), s); assert_eq!( "https://www.example.com", parsed_header.ascii_serialization() ); } #[test] fn request_method_conversion() { let method = "POST"; let parsed_method = not_err!(AccessControlRequestMethod::from_str(method)); assert_matches!( parsed_method, AccessControlRequestMethod(crate::Method(rocket::http::Method::Post)) ); let method = "options"; let parsed_method = not_err!(AccessControlRequestMethod::from_str(method)); assert_matches!( parsed_method, AccessControlRequestMethod(crate::Method(rocket::http::Method::Options)) ); let method = 
"INVALID"; is_err!(AccessControlRequestMethod::from_str(method)); } #[test] fn request_method_parsing() { let client = make_client(); let mut request = client.get("/"); let method = Header::new( ACCESS_CONTROL_REQUEST_METHOD.as_str(), hyper::Method::GET.as_str(), ); request.add_header(method); let outcome = AccessControlRequestMethod::from_request_sync(request.inner()); let parsed_header = assert_matches!(outcome, Outcome::Success(s), s); let AccessControlRequestMethod(parsed_method) = parsed_header; assert_eq!("GET", parsed_method.as_str()); } #[test] fn request_headers_conversion() { let headers = ["foo", "bar", "baz"]; let parsed_headers = not_err!(AccessControlRequestHeaders::from_str(&headers.join(", "))); let expected_headers: HeaderFieldNamesSet = headers.iter().map(|s| (*s).to_string().into()).collect(); let AccessControlRequestHeaders(actual_headers) = parsed_headers; assert_eq!(actual_headers, expected_headers); } #[test] fn request_headers_parsing() { let client = make_client(); let mut request = client.get("/"); let headers = Header::new( ACCESS_CONTROL_REQUEST_HEADERS.as_str(), "accept-language, date", ); request.add_header(headers); let outcome = AccessControlRequestHeaders::from_request_sync(request.inner()); let parsed_header = assert_matches!(outcome, Outcome::Success(s), s); let AccessControlRequestHeaders(parsed_headers) = parsed_header; let mut parsed_headers: Vec = parsed_headers.iter().map(ToString::to_string).collect(); parsed_headers.sort(); assert_eq!( vec!["accept-language".to_string(), "date".to_string()], parsed_headers ); } }