diff --git a/src/lib.rs b/src/lib.rs index 544adae..f863c8c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -501,7 +501,8 @@ //! ``` //! //! ## Reference -//! - [CORS Specification](https://fetch.spec.whatwg.org/#cors-protocol) +//! - [Fetch CORS Specification](https://fetch.spec.whatwg.org/#cors-protocol) +//! - [Supplanted W3C CORS Specification](https://www.w3.org/TR/cors/) //! - [Resource Advice](https://w3c.github.io/webappsec-cors-for-developers/#resources) #![allow(legacy_directory_ownership, missing_copy_implementations, missing_debug_implementations, @@ -1280,9 +1281,11 @@ impl Response { Some(ref origin) => origin, }; + // Origin should be ASCII serialized + // c.f. https://html.spec.whatwg.org/multipage/origin.html#ascii-serialisation-of-an-origin let origin = match *origin { AllOrSome::All => "*".to_string(), - AllOrSome::Some(ref origin) => origin.origin().unicode_serialization(), + AllOrSome::Some(ref origin) => origin.origin().ascii_serialization(), }; let _ = response.set_raw_header("Access-Control-Allow-Origin", origin); @@ -1636,7 +1639,8 @@ fn request_headers(request: &Request) -> Result Result<(), Error> { options.validate()?; @@ -1775,7 +1781,8 @@ fn actual_request_validate(options: &Cors, origin: &Origin) -> Result<(), Error> /// Build the response for an actual request /// /// This implementation references the -/// [W3C recommendation](https://www.w3.org/TR/cors/#resource-requests). +/// [W3C recommendation](https://www.w3.org/TR/cors/#resource-requests) +/// and [Fetch specification](https://fetch.spec.whatwg.org/#cors-preflight-fetch) fn actual_request_response(options: &Cors, origin: &Origin) -> Response { let response = Response::new(); @@ -1944,7 +1951,7 @@ mod tests { } #[test] - fn response_allows_origin() { + fn validate_origin_allows_origin() { let url = "https://www.example.com"; let origin = Origin::from_str(url).unwrap(); let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.example.com"]); @@ -1955,7 +1962,7 @@ mod tests { #[test] #[should_panic(expected = "OriginNotAllowed")] - fn response_rejects_invalid_origin() { + fn validate_origin_rejects_invalid_origin() { let url = "https://www.acme.com"; let origin = Origin::from_str(url).unwrap(); let (allowed_origins, failed_origins) = AllowedOrigins::some(&["https://www.example.com"]); @@ -1964,6 +1971,65 @@ mod tests { validate_origin(&origin, &allowed_origins).unwrap(); } + #[test] + fn response_sets_allow_origin_without_vary_correctly() { + let response = Response::new(); + let response = response.origin( + &FromStr::from_str("https://www.example.com").unwrap(), + false, + ); + + // Build response and check built response header + let expected_header = vec!["https://www.example.com"]; + let response = response.response(response::Response::new()); + let actual_header: Vec<_> = response.headers().get("Access-Control-Allow-Origin").collect(); + assert_eq!(expected_header, actual_header); + + assert!(response.headers().get("Vary").next().is_none()); + } + + #[test] + fn response_sets_allow_origin_with_vary_correctly() { + let response = Response::new(); + let response = response.origin( + &FromStr::from_str("https://www.example.com").unwrap(), + true, + ); + + // Build response and check built response header + let expected_header = vec!["https://www.example.com"]; + let response = response.response(response::Response::new()); + let actual_header: Vec<_> = response.headers().get("Access-Control-Allow-Origin").collect(); + assert_eq!(expected_header, actual_header); + } + + #[test] + fn response_sets_any_origin_correctly() { + let response = Response::new(); + let response = response.any(); + + // Build response and check built response header + let expected_header = vec!["*"]; + let response = response.response(response::Response::new()); + let actual_header: Vec<_> = response.headers().get("Access-Control-Allow-Origin").collect(); + assert_eq!(expected_header, actual_header); + } + + #[test] + fn response_sets_allow_origin_with_ascii_serialization() { + let response = Response::new(); + let response = response.origin( + &FromStr::from_str("https://аpple.com").unwrap(), + false, + ); + + // Build response and check built response header + let expected_header = vec!["https://xn--pple-43d.com"]; + let response = response.response(response::Response::new()); + let actual_header: Vec<_> = response.headers().get("Access-Control-Allow-Origin").collect(); + assert_eq!(expected_header, actual_header); + } + #[test] fn response_sets_exposed_headers_correctly() { let headers = vec!["Bar", "Baz", "Foo"];