diff --git a/src/headers.rs b/src/headers.rs index a191640..6e819e9 100644 --- a/src/headers.rs +++ b/src/headers.rs @@ -216,7 +216,10 @@ mod tests { let outcome: request::Outcome = FromRequest::from_request(request.inner()); let parsed_header = assert_matches!(outcome, Outcome::Success(s), s); - assert_eq!("https://www.example.com", parsed_header.ascii_serialization()); + assert_eq!( + "https://www.example.com", + parsed_header.ascii_serialization() + ); } #[test] diff --git a/src/lib.rs b/src/lib.rs index ea20549..0378016 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1673,6 +1673,39 @@ mod tests { Client::new(rocket).expect("valid rocket instance") } + // `to_origin` tests + + #[test] + fn origin_is_parsed_properly() { + let url = "https://foo.bar.xyz"; + let parsed = not_err!(Origin::from_str(url)); + assert_eq!(parsed.ascii_serialization(), url); + } + + #[test] + fn origin_parsing_strips_paths() { + // this should never really be sent by a compliant user agent + let url = "https://foo.bar.xyz/path/somewhere"; + let parsed = not_err!(Origin::from_str(url)); + let expected = "https://foo.bar.xyz"; + assert_eq!(parsed.ascii_serialization(), expected); + } + + #[test] + #[should_panic(expected = "BadOrigin")] + fn origin_parsing_disallows_invalid_origins() { + let url = "invalid_url"; + let _ = Origin::from_str(url).unwrap(); + } + + #[test] + fn origin_parses_opaque_origins() { + let url = "blob://foobar"; + let parsed = not_err!(Origin::from_str(url)); + + assert!(!parsed.is_tuple()); + } + // CORS options test #[test] @@ -1718,6 +1751,24 @@ mod tests { not_err!(validate_origin(&origin, &allowed_origins)); } + #[test] + fn validate_origin_handles_punycode_properly() { + // Test a variety of scenarios where the Origin and settings are in punycode, or not + let cases = vec![ + ("https://аpple.com", "https://аpple.com"), + ("https://аpple.com", "https://xn--pple-43d.com"), + ("https://xn--pple-43d.com", "https://аpple.com"), + ("https://xn--pple-43d.com", "https://xn--pple-43d.com"), + ]; + + for (url, allowed_origin) in cases { + let origin = not_err!(to_origin(&url)); + let allowed_origins = not_err!(parse_origins(&AllowedOrigins::some(&[allowed_origin]))); + + not_err!(validate_origin(&origin, &allowed_origins)); + } + } + #[test] #[should_panic(expected = "OriginNotAllowed")] fn validate_origin_rejects_invalid_origin() {