diff --git a/src/fairing.rs b/src/fairing.rs index 79be9c3..7f497dd 100644 --- a/src/fairing.rs +++ b/src/fairing.rs @@ -53,7 +53,7 @@ fn on_response_wrapper( // Not a CORS request return Ok(()); } - Some(origin) => crate::to_origin(origin)?, + Some(origin) => origin, }; let result = request.local_cache(|| unreachable!("This should not be executed so late")); diff --git a/src/headers.rs b/src/headers.rs index 625c80a..a191640 100644 --- a/src/headers.rs +++ b/src/headers.rs @@ -64,34 +64,27 @@ pub type HeaderFieldNamesSet = HashSet; /// You can use this as a rocket [Request Guard](https://rocket.rs/guide/requests/#request-guards) /// to ensure that `Origin` is passed in correctly. #[derive(Eq, PartialEq, Clone, Hash, Debug)] -#[cfg_attr(feature = "serialization", derive(Serialize, Deserialize))] -pub struct Origin(pub String); +pub struct Origin(pub url::Origin); impl FromStr for Origin { - type Err = !; + type Err = crate::Error; fn from_str(input: &str) -> Result { - Ok(Origin(input.to_string())) + Ok(Origin(crate::to_origin(input)?)) } } impl Deref for Origin { - type Target = str; + type Target = url::Origin; fn deref(&self) -> &Self::Target { &self.0 } } -impl AsRef for Origin { - fn as_ref(&self) -> &str { - self - } -} - impl fmt::Display for Origin { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - self.as_ref().fmt(f) + write!(f, "{}", self.ascii_serialization()) } } @@ -100,10 +93,10 @@ impl<'a, 'r> FromRequest<'a, 'r> for Origin { fn from_request(request: &'a rocket::Request<'r>) -> request::Outcome { match request.headers().get_one("Origin") { - Some(origin) => { - let Ok(origin) = Self::from_str(origin); - Outcome::Success(origin) - } + Some(origin) => match Self::from_str(origin) { + Ok(origin) => Outcome::Success(origin), + Err(e) => Outcome::Failure((Status::BadRequest, e)), + }, None => Outcome::Forward(()), } } @@ -199,17 +192,17 @@ mod tests { #[test] fn origin_header_conversion() { let url = "https://foo.bar.xyz"; - let Ok(parsed) = Origin::from_str(url); - assert_eq!(parsed.as_ref(), url); + let parsed = not_err!(Origin::from_str(url)); + assert_eq!(parsed.ascii_serialization(), url); - let url = "https://foo.bar.xyz/path/somewhere"; // this should never really be used - let Ok(parsed) = Origin::from_str(url); - assert_eq!(parsed.as_ref(), url); + // this should never really be sent by a compliant user agent + let url = "https://foo.bar.xyz/path/somewhere"; + let parsed = not_err!(Origin::from_str(url)); + let expected = "https://foo.bar.xyz"; + assert_eq!(parsed.ascii_serialization(), expected); - // Validation is not done now let url = "invalid_url"; - let Ok(parsed) = Origin::from_str(url); - assert_eq!(parsed.as_ref(), url); + let _ = is_err!(Origin::from_str(url)); } #[test] @@ -223,7 +216,7 @@ mod tests { let outcome: request::Outcome = FromRequest::from_request(request.inner()); let parsed_header = assert_matches!(outcome, Outcome::Success(s), s); - assert_eq!("https://www.example.com", parsed_header.as_ref()); + assert_eq!("https://www.example.com", parsed_header.ascii_serialization()); } #[test] diff --git a/src/lib.rs b/src/lib.rs index 081b13c..ea20549 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -267,8 +267,6 @@ See the [example](https://github.com/lawliet89/rocket_cors/blob/master/examples/ intra_doc_link_resolution_failure )] #![doc(test(attr(allow(unused_variables), deny(warnings))))] -#![feature(never_type)] -#![feature(exhaustive_patterns)] #[cfg(test)] #[macro_use] @@ -533,21 +531,14 @@ mod method_serde { /// let all_origins = AllowedOrigins::all(); /// let some_origins = AllowedOrigins::some(&["https://www.acme.com"]); /// ``` -pub type AllowedOrigins = AllOrSome>; +pub type AllowedOrigins = AllOrSome>; impl AllowedOrigins { /// Allows some origins /// /// Validation is not performed at this stage, but at a later stage. pub fn some(urls: &[&str]) -> Self { - AllOrSome::Some( - urls.iter() - .map(|s| { - let Ok(s) = FromStr::from_str(s); - s - }) - .collect(), - ) + AllOrSome::Some(urls.iter().map(|s| s.to_string()).collect()) } /// Allows all origins @@ -1308,7 +1299,7 @@ fn validate(options: &Cors, request: &Request<'_>) -> Result to_origin(origin)?, + Some(origin) => origin, }; // Check if the request verb is an OPTION or something else @@ -1317,11 +1308,16 @@ fn validate(options: &Cors, request: &Request<'_>) -> Result { actual_request_validate(options, &origin)?; - Ok(ValidationResult::Request { origin }) + Ok(ValidationResult::Request { + origin: origin.deref().clone(), + }) } } } @@ -1705,8 +1701,7 @@ mod tests { #[test] fn validate_origin_allows_all_origins() { let url = "https://www.example.com"; - let Ok(origin) = Origin::from_str(url); - let origin = not_err!(to_origin(&origin)); + let origin = not_err!(to_origin(&url)); let allowed_origins = AllOrSome::All; not_err!(validate_origin(&origin, &allowed_origins)); @@ -1715,8 +1710,7 @@ mod tests { #[test] fn validate_origin_allows_origin() { let url = "https://www.example.com"; - let Ok(origin) = Origin::from_str(url); - let origin = not_err!(to_origin(&origin)); + let origin = not_err!(to_origin(&url)); let allowed_origins = not_err!(parse_origins(&AllowedOrigins::some(&[ "https://www.example.com" ]))); @@ -1728,8 +1722,7 @@ mod tests { #[should_panic(expected = "OriginNotAllowed")] fn validate_origin_rejects_invalid_origin() { let url = "https://www.acme.com"; - let Ok(origin) = Origin::from_str(url); - let origin = not_err!(to_origin(&origin)); + let origin = not_err!(to_origin(&url)); let allowed_origins = not_err!(parse_origins(&AllowedOrigins::some(&[ "https://www.example.com" ])));