diff --git a/examples/fairing.rs b/examples/fairing.rs index 393ce4f..56cdd8d 100644 --- a/examples/fairing.rs +++ b/examples/fairing.rs @@ -12,7 +12,7 @@ fn cors<'a>() -> &'a str { } fn main() -> Result<(), Error> { - let allowed_origins = AllowedOrigins::some(&["https://www.acme.com"]); + let allowed_origins = AllowedOrigins::some_exact(&["https://www.acme.com"]); // You can also deserialize this let cors = rocket_cors::CorsOptions { diff --git a/examples/guard.rs b/examples/guard.rs index 2e507e6..cd3d2d0 100644 --- a/examples/guard.rs +++ b/examples/guard.rs @@ -36,7 +36,7 @@ fn manual(cors: Guard<'_>) -> Responder<'_, &str> { } fn main() -> Result<(), Error> { - let allowed_origins = AllowedOrigins::some(&["https://www.acme.com"]); + let allowed_origins = AllowedOrigins::some_exact(&["https://www.acme.com"]); // You can also deserialize this let cors = rocket_cors::CorsOptions { diff --git a/examples/json.rs b/examples/json.rs index c79000f..e77c1d9 100644 --- a/examples/json.rs +++ b/examples/json.rs @@ -13,7 +13,7 @@ fn main() { // The default demonstrates the "All" serialization of several of the settings let default: CorsOptions = Default::default(); - let allowed_origins = AllowedOrigins::some(&["https://www.acme.com"]); + let allowed_origins = AllowedOrigins::some_exact(&["https://www.acme.com"]); let options = cors::CorsOptions { allowed_origins: allowed_origins, diff --git a/examples/manual.rs b/examples/manual.rs index 35b4274..0768d2e 100644 --- a/examples/manual.rs +++ b/examples/manual.rs @@ -59,7 +59,7 @@ fn owned_options<'r>() -> impl Responder<'r> { } fn cors_options() -> CorsOptions { - let allowed_origins = AllowedOrigins::some(&["https://www.acme.com"]); + let allowed_origins = AllowedOrigins::some_exact(&["https://www.acme.com"]); // You can also deserialize this rocket_cors::CorsOptions { diff --git a/examples/mix.rs b/examples/mix.rs index 3a243e4..9bae94f 100644 --- a/examples/mix.rs +++ b/examples/mix.rs @@ -36,7 +36,7 @@ fn ping_options<'r>() -> impl Responder<'r> { /// Returns the "application wide" Cors struct fn cors_options() -> CorsOptions { - let allowed_origins = AllowedOrigins::some(&["https://www.acme.com"]); + let allowed_origins = AllowedOrigins::some_exact(&["https://www.acme.com"]); // You can also deserialize this rocket_cors::CorsOptions { diff --git a/src/fairing.rs b/src/fairing.rs index 5fe40b2..3d96adc 100644 --- a/src/fairing.rs +++ b/src/fairing.rs @@ -138,10 +138,10 @@ mod tests { use crate::{AllowedHeaders, AllowedOrigins, Cors, CorsOptions}; - const CORS_ROOT: &'static str = "/my_cors"; + const CORS_ROOT: &str = "/my_cors"; fn make_cors_options() -> Cors { - let allowed_origins = AllowedOrigins::some(&["https://www.acme.com"]); + let allowed_origins = AllowedOrigins::some_exact(&["https://www.acme.com"]); CorsOptions { allowed_origins, diff --git a/src/headers.rs b/src/headers.rs index 501e008..b947fbf 100644 --- a/src/headers.rs +++ b/src/headers.rs @@ -73,6 +73,23 @@ pub enum Origin { Parsed(url::Origin), } +impl Origin { + /// Perform an + /// [ASCII serialization](https://html.spec.whatwg.org/multipage/#ascii-serialisation-of-an-origin) + /// of this origin. + pub fn ascii_serialization(&self) -> String { + self.to_string() + } + + /// Returns whether the origin was parsed as non-opaque + pub fn is_tuple(&self) -> bool { + match self { + Origin::Null => false, + Origin::Parsed(ref parsed) => parsed.is_tuple(), + } + } +} + impl FromStr for Origin { type Err = crate::Error; diff --git a/src/lib.rs b/src/lib.rs index 1145096..a7e323e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -308,6 +308,8 @@ pub enum Error { MissingOrigin, /// The HTTP request header `Origin` could not be parsed correctly. BadOrigin(url::ParseError), + /// The configured Allowed Origin is opaque and cannot be parsed. + OpaqueAllowedOrigin(String), /// The request header `Access-Control-Request-Method` is required but is missing MissingRequestMethod, /// The request header `Access-Control-Request-Method` has an invalid value @@ -376,7 +378,8 @@ impl fmt::Display for Error { } Error::MissingInjectedHeader => write!(f, "The `on_response` handler of Fairing could not find the injected header from the \ - Request. Either some other fairing has removed it, or this is a bug.") + Request. Either some other fairing has removed it, or this is a bug."), + Error::OpaqueAllowedOrigin(ref origin) => write!(f, "The configured Origin '{}' not have a parsable Origin. Use a regex instead.", origin), } } } @@ -597,6 +600,9 @@ pub struct Origins { pub exact: Option>, /// Origins that will be matched via __any__ regex in this list. These __must__ be valid Regex /// that will be parsed and validated when creating [`Cors`]. + /// The regex will be matched according to the + /// [ASCII serialization](https://html.spec.whatwg.org/multipage/#ascii-serialisation-of-an-origin) + /// of the incoming Origin. #[cfg_attr(feature = "serialization", serde(default))] pub regex: Option>, } @@ -610,13 +616,24 @@ pub(crate) struct ParsedAllowedOrigins { impl ParsedAllowedOrigins { fn parse(origins: &Origins) -> Result { - let exact: Result<_, Error> = match &origins.exact { + let exact: Result, Error> = match &origins.exact { Some(exact) => exact.iter().map(|url| to_origin(url.as_str())).collect(), None => Ok(Default::default()), }; + let exact = exact?; + + // Let's check if any of them is Opaque + exact.iter().try_for_each(|url| { + if !url.is_tuple() { + Err(Error::OpaqueAllowedOrigin(url.ascii_serialization())) + } else { + Ok(()) + } + })?; + Ok(Self { allow_null: origins.allow_null, - exact: exact?, + exact, }) } } @@ -1332,9 +1349,8 @@ enum ValidationResult { Request { origin: String }, } -/// Convert a str to Origin +/// Convert a str to a URL Origin fn to_origin>(origin: S) -> Result { - // What to do about Opaque origins? Ok(url::Url::parse(origin.as_ref())?.origin()) } @@ -1727,8 +1743,12 @@ mod tests { use super::*; use crate::http::Method; + fn to_parsed_origin>(origin: S) -> Result { + Origin::from_str(origin.as_ref()) + } + fn make_cors_options() -> CorsOptions { - let allowed_origins = AllowedOrigins::some(&["https://www.acme.com"]); + let allowed_origins = AllowedOrigins::some_exact(&["https://www.acme.com"]); CorsOptions { allowed_origins, @@ -1821,7 +1841,7 @@ mod tests { #[test] fn validate_origin_allows_all_origins() { let url = "https://www.example.com"; - let origin = not_err!(to_origin(&url)); + let origin = not_err!(to_parsed_origin(&url)); let allowed_origins = AllOrSome::All; not_err!(validate_origin(&origin, &allowed_origins)); @@ -1830,8 +1850,8 @@ mod tests { #[test] fn validate_origin_allows_origin() { let url = "https://www.example.com"; - let origin = not_err!(to_origin(&url)); - let allowed_origins = not_err!(parse_origins(&AllowedOrigins::some(&[ + let origin = not_err!(to_parsed_origin(&url)); + let allowed_origins = not_err!(parse_origins(&AllowedOrigins::some_exact(&[ "https://www.example.com" ]))); @@ -1849,8 +1869,10 @@ mod tests { ]; for (url, allowed_origin) in cases { - let origin = not_err!(to_origin(&url)); - let allowed_origins = not_err!(parse_origins(&AllowedOrigins::some(&[allowed_origin]))); + let origin = not_err!(to_parsed_origin(&url)); + let allowed_origins = not_err!(parse_origins(&AllowedOrigins::some_exact(&[ + allowed_origin + ]))); not_err!(validate_origin(&origin, &allowed_origins)); } @@ -1860,8 +1882,8 @@ mod tests { #[should_panic(expected = "OriginNotAllowed")] fn validate_origin_rejects_invalid_origin() { let url = "https://www.acme.com"; - let origin = not_err!(to_origin(&url)); - let allowed_origins = not_err!(parse_origins(&AllowedOrigins::some(&[ + let origin = not_err!(to_parsed_origin(&url)); + let allowed_origins = not_err!(parse_origins(&AllowedOrigins::some_exact(&[ "https://www.example.com" ]))); @@ -1871,7 +1893,7 @@ mod tests { #[test] fn response_sets_allow_origin_without_vary_correctly() { let response = Response::new(); - let response = response.origin(&to_origin("https://www.example.com").unwrap(), false); + let response = response.origin("https://www.example.com", false); // Build response and check built response header let expected_header = vec!["https://www.example.com"]; @@ -1888,7 +1910,7 @@ mod tests { #[test] fn response_sets_allow_origin_with_vary_correctly() { let response = Response::new(); - let response = response.origin(&to_origin("https://www.example.com").unwrap(), true); + let response = response.origin("https://www.example.com", true); // Build response and check built response header let expected_header = vec!["https://www.example.com"]; @@ -1915,27 +1937,11 @@ mod tests { assert_eq!(expected_header, actual_header); } - #[test] - fn response_sets_allow_origin_with_ascii_serialization() { - let response = Response::new(); - let response = response.origin(&to_origin("https://аpple.com").unwrap(), false); - - // Build response and check built response header - // This is "punycode" - let expected_header = vec!["https://xn--pple-43d.com"]; - let response = response.response(response::Response::new()); - let actual_header: Vec<_> = response - .headers() - .get("Access-Control-Allow-Origin") - .collect(); - assert_eq!(expected_header, actual_header); - } - #[test] fn response_sets_exposed_headers_correctly() { let headers = vec!["Bar", "Baz", "Foo"]; let response = Response::new(); - let response = response.origin(&to_origin("https://www.example.com").unwrap(), false); + let response = response.origin("https://www.example.com", false); let response = response.exposed_headers(&headers); // Build response and check built response header @@ -1957,7 +1963,7 @@ mod tests { #[test] fn response_sets_max_age_correctly() { let response = Response::new(); - let response = response.origin(&to_origin("https://www.example.com").unwrap(), false); + let response = response.origin("https://www.example.com", false); let response = response.max_age(Some(42)); @@ -1971,7 +1977,7 @@ mod tests { #[test] fn response_does_not_set_max_age_when_none() { let response = Response::new(); - let response = response.origin(&to_origin("https://www.example.com").unwrap(), false); + let response = response.origin("https://www.example.com", false); let response = response.max_age(None); @@ -2084,7 +2090,7 @@ mod tests { .finalize(); let response = Response::new(); - let response = response.origin(&to_origin("https://www.example.com").unwrap(), false); + let response = response.origin("https://www.example.com", false); let response = response.response(original); // Check CORS header let expected_header = vec!["https://www.example.com"]; @@ -2160,7 +2166,7 @@ mod tests { let result = validate(&cors, request.inner()).expect("to not fail"); let expected_result = ValidationResult::Preflight { - origin: to_origin("https://www.acme.com").unwrap(), + origin: "https://www.acme.com".to_string(), // Checks that only a subset of allowed headers are returned // -- i.e. whatever is requested for headers: Some(FromStr::from_str("Authorization").unwrap()), @@ -2195,7 +2201,7 @@ mod tests { let result = validate(&cors, request.inner()).expect("to not fail"); let expected_result = ValidationResult::Preflight { - origin: to_origin("https://www.example.com").unwrap(), + origin: "https://www.example.com".to_string(), headers: Some(FromStr::from_str("Authorization").unwrap()), }; @@ -2313,7 +2319,7 @@ mod tests { let result = validate(&cors, request.inner()).expect("to not fail"); let expected_result = ValidationResult::Request { - origin: to_origin("https://www.acme.com").unwrap(), + origin: "https://www.acme.com".to_string(), }; assert_eq!(expected_result, result); @@ -2332,7 +2338,7 @@ mod tests { let result = validate(&cors, request.inner()).expect("to not fail"); let expected_result = ValidationResult::Request { - origin: to_origin("https://www.example.com").unwrap(), + origin: "https://www.example.com".to_string(), }; assert_eq!(expected_result, result); @@ -2388,7 +2394,7 @@ mod tests { let response = validate_and_build(&cors, request.inner()).expect("to not fail"); let expected_response = Response::new() - .origin(&to_origin("https://www.acme.com").unwrap(), false) + .origin("https://www.acme.com", false) .headers(&["Authorization"]) .methods(&options.allowed_methods) .credentials(options.allow_credentials) @@ -2428,7 +2434,7 @@ mod tests { let response = validate_and_build(&cors, request.inner()).expect("to not fail"); let expected_response = Response::new() - .origin(&to_origin("https://www.acme.com").unwrap(), true) + .origin("https://www.acme.com", true) .headers(&["Authorization"]) .methods(&options.allowed_methods) .credentials(options.allow_credentials) @@ -2489,7 +2495,7 @@ mod tests { let response = validate_and_build(&cors, request.inner()).expect("to not fail"); let expected_response = Response::new() - .origin(&to_origin("https://www.acme.com").unwrap(), false) + .origin("https://www.acme.com", false) .credentials(options.allow_credentials) .exposed_headers(&["Content-Type", "X-Custom"]); @@ -2512,7 +2518,7 @@ mod tests { let response = validate_and_build(&cors, request.inner()).expect("to not fail"); let expected_response = Response::new() - .origin(&to_origin("https://www.acme.com").unwrap(), true) + .origin("https://www.acme.com", true) .credentials(options.allow_credentials) .exposed_headers(&["Content-Type", "X-Custom"]); diff --git a/tests/fairing.rs b/tests/fairing.rs index 509ec9b..943cb26 100644 --- a/tests/fairing.rs +++ b/tests/fairing.rs @@ -22,7 +22,7 @@ fn panicking_route() { } fn make_cors() -> Cors { - let allowed_origins = AllowedOrigins::some(&["https://www.acme.com"]); + let allowed_origins = AllowedOrigins::some_exact(&["https://www.acme.com"]); CorsOptions { allowed_origins: allowed_origins, diff --git a/tests/guard.rs b/tests/guard.rs index eb30da4..c3010e4 100644 --- a/tests/guard.rs +++ b/tests/guard.rs @@ -60,7 +60,7 @@ fn state<'r>(cors: cors::Guard<'r>, _state: State<'r, SomeState>) -> cors::Respo } fn make_cors() -> cors::Cors { - let allowed_origins = cors::AllowedOrigins::some(&["https://www.acme.com"]); + let allowed_origins = cors::AllowedOrigins::some_exact(&["https://www.acme.com"]); cors::CorsOptions { allowed_origins: allowed_origins, diff --git a/tests/manual.rs b/tests/manual.rs index 27cb58d..54b08ba 100644 --- a/tests/manual.rs +++ b/tests/manual.rs @@ -66,7 +66,7 @@ fn borrow<'r>(options: State<'r, Cors>, test_state: State<'r, TestState>) -> imp } fn make_cors_options() -> CorsOptions { - let allowed_origins = AllowedOrigins::some(&["https://www.acme.com"]); + let allowed_origins = AllowedOrigins::some_exact(&["https://www.acme.com"]); CorsOptions { allowed_origins: allowed_origins, @@ -78,7 +78,7 @@ fn make_cors_options() -> CorsOptions { } fn make_different_cors_options() -> CorsOptions { - let allowed_origins = AllowedOrigins::some(&["https://www.example.com"]); + let allowed_origins = AllowedOrigins::some_exact(&["https://www.example.com"]); CorsOptions { allowed_origins: allowed_origins, diff --git a/tests/mix.rs b/tests/mix.rs index 87a969b..25dd65a 100644 --- a/tests/mix.rs +++ b/tests/mix.rs @@ -40,7 +40,7 @@ fn ping_options<'r>() -> impl Responder<'r> { /// Returns the "application wide" Cors struct fn cors_options() -> CorsOptions { - let allowed_origins = AllowedOrigins::some(&["https://www.acme.com"]); + let allowed_origins = AllowedOrigins::some_exact(&["https://www.acme.com"]); // You can also deserialize this rocket_cors::CorsOptions {