Fix tests

This commit is contained in:
Yong Wen Chua 2019-03-12 14:00:34 +08:00
parent bc16568e8b
commit 3349c972cf
No known key found for this signature in database
GPG Key ID: A70BD30B21497EA9
12 changed files with 78 additions and 55 deletions

View File

@ -12,7 +12,7 @@ fn cors<'a>() -> &'a str {
} }
fn main() -> Result<(), Error> { fn main() -> Result<(), Error> {
let allowed_origins = AllowedOrigins::some(&["https://www.acme.com"]); let allowed_origins = AllowedOrigins::some_exact(&["https://www.acme.com"]);
// You can also deserialize this // You can also deserialize this
let cors = rocket_cors::CorsOptions { let cors = rocket_cors::CorsOptions {

View File

@ -36,7 +36,7 @@ fn manual(cors: Guard<'_>) -> Responder<'_, &str> {
} }
fn main() -> Result<(), Error> { fn main() -> Result<(), Error> {
let allowed_origins = AllowedOrigins::some(&["https://www.acme.com"]); let allowed_origins = AllowedOrigins::some_exact(&["https://www.acme.com"]);
// You can also deserialize this // You can also deserialize this
let cors = rocket_cors::CorsOptions { let cors = rocket_cors::CorsOptions {

View File

@ -13,7 +13,7 @@ fn main() {
// The default demonstrates the "All" serialization of several of the settings // The default demonstrates the "All" serialization of several of the settings
let default: CorsOptions = Default::default(); let default: CorsOptions = Default::default();
let allowed_origins = AllowedOrigins::some(&["https://www.acme.com"]); let allowed_origins = AllowedOrigins::some_exact(&["https://www.acme.com"]);
let options = cors::CorsOptions { let options = cors::CorsOptions {
allowed_origins: allowed_origins, allowed_origins: allowed_origins,

View File

@ -59,7 +59,7 @@ fn owned_options<'r>() -> impl Responder<'r> {
} }
fn cors_options() -> CorsOptions { fn cors_options() -> CorsOptions {
let allowed_origins = AllowedOrigins::some(&["https://www.acme.com"]); let allowed_origins = AllowedOrigins::some_exact(&["https://www.acme.com"]);
// You can also deserialize this // You can also deserialize this
rocket_cors::CorsOptions { rocket_cors::CorsOptions {

View File

@ -36,7 +36,7 @@ fn ping_options<'r>() -> impl Responder<'r> {
/// Returns the "application wide" Cors struct /// Returns the "application wide" Cors struct
fn cors_options() -> CorsOptions { fn cors_options() -> CorsOptions {
let allowed_origins = AllowedOrigins::some(&["https://www.acme.com"]); let allowed_origins = AllowedOrigins::some_exact(&["https://www.acme.com"]);
// You can also deserialize this // You can also deserialize this
rocket_cors::CorsOptions { rocket_cors::CorsOptions {

View File

@ -138,10 +138,10 @@ mod tests {
use crate::{AllowedHeaders, AllowedOrigins, Cors, CorsOptions}; use crate::{AllowedHeaders, AllowedOrigins, Cors, CorsOptions};
const CORS_ROOT: &'static str = "/my_cors"; const CORS_ROOT: &str = "/my_cors";
fn make_cors_options() -> Cors { fn make_cors_options() -> Cors {
let allowed_origins = AllowedOrigins::some(&["https://www.acme.com"]); let allowed_origins = AllowedOrigins::some_exact(&["https://www.acme.com"]);
CorsOptions { CorsOptions {
allowed_origins, allowed_origins,

View File

@ -73,6 +73,23 @@ pub enum Origin {
Parsed(url::Origin), Parsed(url::Origin),
} }
impl Origin {
/// Perform an
/// [ASCII serialization](https://html.spec.whatwg.org/multipage/#ascii-serialisation-of-an-origin)
/// of this origin.
pub fn ascii_serialization(&self) -> String {
self.to_string()
}
/// Returns whether the origin was parsed as non-opaque
pub fn is_tuple(&self) -> bool {
match self {
Origin::Null => false,
Origin::Parsed(ref parsed) => parsed.is_tuple(),
}
}
}
impl FromStr for Origin { impl FromStr for Origin {
type Err = crate::Error; type Err = crate::Error;

View File

@ -308,6 +308,8 @@ pub enum Error {
MissingOrigin, MissingOrigin,
/// The HTTP request header `Origin` could not be parsed correctly. /// The HTTP request header `Origin` could not be parsed correctly.
BadOrigin(url::ParseError), BadOrigin(url::ParseError),
/// The configured Allowed Origin is opaque and cannot be parsed.
OpaqueAllowedOrigin(String),
/// The request header `Access-Control-Request-Method` is required but is missing /// The request header `Access-Control-Request-Method` is required but is missing
MissingRequestMethod, MissingRequestMethod,
/// The request header `Access-Control-Request-Method` has an invalid value /// The request header `Access-Control-Request-Method` has an invalid value
@ -376,7 +378,8 @@ impl fmt::Display for Error {
} }
Error::MissingInjectedHeader => write!(f, Error::MissingInjectedHeader => write!(f,
"The `on_response` handler of Fairing could not find the injected header from the \ "The `on_response` handler of Fairing could not find the injected header from the \
Request. Either some other fairing has removed it, or this is a bug.") Request. Either some other fairing has removed it, or this is a bug."),
Error::OpaqueAllowedOrigin(ref origin) => write!(f, "The configured Origin '{}' not have a parsable Origin. Use a regex instead.", origin),
} }
} }
} }
@ -597,6 +600,9 @@ pub struct Origins {
pub exact: Option<HashSet<String>>, pub exact: Option<HashSet<String>>,
/// Origins that will be matched via __any__ regex in this list. These __must__ be valid Regex /// Origins that will be matched via __any__ regex in this list. These __must__ be valid Regex
/// that will be parsed and validated when creating [`Cors`]. /// that will be parsed and validated when creating [`Cors`].
/// The regex will be matched according to the
/// [ASCII serialization](https://html.spec.whatwg.org/multipage/#ascii-serialisation-of-an-origin)
/// of the incoming Origin.
#[cfg_attr(feature = "serialization", serde(default))] #[cfg_attr(feature = "serialization", serde(default))]
pub regex: Option<HashSet<String>>, pub regex: Option<HashSet<String>>,
} }
@ -610,13 +616,24 @@ pub(crate) struct ParsedAllowedOrigins {
impl ParsedAllowedOrigins { impl ParsedAllowedOrigins {
fn parse(origins: &Origins) -> Result<Self, Error> { fn parse(origins: &Origins) -> Result<Self, Error> {
let exact: Result<_, Error> = match &origins.exact { let exact: Result<HashSet<url::Origin>, Error> = match &origins.exact {
Some(exact) => exact.iter().map(|url| to_origin(url.as_str())).collect(), Some(exact) => exact.iter().map(|url| to_origin(url.as_str())).collect(),
None => Ok(Default::default()), None => Ok(Default::default()),
}; };
let exact = exact?;
// Let's check if any of them is Opaque
exact.iter().try_for_each(|url| {
if !url.is_tuple() {
Err(Error::OpaqueAllowedOrigin(url.ascii_serialization()))
} else {
Ok(())
}
})?;
Ok(Self { Ok(Self {
allow_null: origins.allow_null, allow_null: origins.allow_null,
exact: exact?, exact,
}) })
} }
} }
@ -1332,9 +1349,8 @@ enum ValidationResult {
Request { origin: String }, Request { origin: String },
} }
/// Convert a str to Origin /// Convert a str to a URL Origin
fn to_origin<S: AsRef<str>>(origin: S) -> Result<url::Origin, Error> { fn to_origin<S: AsRef<str>>(origin: S) -> Result<url::Origin, Error> {
// What to do about Opaque origins?
Ok(url::Url::parse(origin.as_ref())?.origin()) Ok(url::Url::parse(origin.as_ref())?.origin())
} }
@ -1727,8 +1743,12 @@ mod tests {
use super::*; use super::*;
use crate::http::Method; use crate::http::Method;
fn to_parsed_origin<S: AsRef<str>>(origin: S) -> Result<Origin, Error> {
Origin::from_str(origin.as_ref())
}
fn make_cors_options() -> CorsOptions { fn make_cors_options() -> CorsOptions {
let allowed_origins = AllowedOrigins::some(&["https://www.acme.com"]); let allowed_origins = AllowedOrigins::some_exact(&["https://www.acme.com"]);
CorsOptions { CorsOptions {
allowed_origins, allowed_origins,
@ -1821,7 +1841,7 @@ mod tests {
#[test] #[test]
fn validate_origin_allows_all_origins() { fn validate_origin_allows_all_origins() {
let url = "https://www.example.com"; let url = "https://www.example.com";
let origin = not_err!(to_origin(&url)); let origin = not_err!(to_parsed_origin(&url));
let allowed_origins = AllOrSome::All; let allowed_origins = AllOrSome::All;
not_err!(validate_origin(&origin, &allowed_origins)); not_err!(validate_origin(&origin, &allowed_origins));
@ -1830,8 +1850,8 @@ mod tests {
#[test] #[test]
fn validate_origin_allows_origin() { fn validate_origin_allows_origin() {
let url = "https://www.example.com"; let url = "https://www.example.com";
let origin = not_err!(to_origin(&url)); let origin = not_err!(to_parsed_origin(&url));
let allowed_origins = not_err!(parse_origins(&AllowedOrigins::some(&[ let allowed_origins = not_err!(parse_origins(&AllowedOrigins::some_exact(&[
"https://www.example.com" "https://www.example.com"
]))); ])));
@ -1849,8 +1869,10 @@ mod tests {
]; ];
for (url, allowed_origin) in cases { for (url, allowed_origin) in cases {
let origin = not_err!(to_origin(&url)); let origin = not_err!(to_parsed_origin(&url));
let allowed_origins = not_err!(parse_origins(&AllowedOrigins::some(&[allowed_origin]))); let allowed_origins = not_err!(parse_origins(&AllowedOrigins::some_exact(&[
allowed_origin
])));
not_err!(validate_origin(&origin, &allowed_origins)); not_err!(validate_origin(&origin, &allowed_origins));
} }
@ -1860,8 +1882,8 @@ mod tests {
#[should_panic(expected = "OriginNotAllowed")] #[should_panic(expected = "OriginNotAllowed")]
fn validate_origin_rejects_invalid_origin() { fn validate_origin_rejects_invalid_origin() {
let url = "https://www.acme.com"; let url = "https://www.acme.com";
let origin = not_err!(to_origin(&url)); let origin = not_err!(to_parsed_origin(&url));
let allowed_origins = not_err!(parse_origins(&AllowedOrigins::some(&[ let allowed_origins = not_err!(parse_origins(&AllowedOrigins::some_exact(&[
"https://www.example.com" "https://www.example.com"
]))); ])));
@ -1871,7 +1893,7 @@ mod tests {
#[test] #[test]
fn response_sets_allow_origin_without_vary_correctly() { fn response_sets_allow_origin_without_vary_correctly() {
let response = Response::new(); let response = Response::new();
let response = response.origin(&to_origin("https://www.example.com").unwrap(), false); let response = response.origin("https://www.example.com", false);
// Build response and check built response header // Build response and check built response header
let expected_header = vec!["https://www.example.com"]; let expected_header = vec!["https://www.example.com"];
@ -1888,7 +1910,7 @@ mod tests {
#[test] #[test]
fn response_sets_allow_origin_with_vary_correctly() { fn response_sets_allow_origin_with_vary_correctly() {
let response = Response::new(); let response = Response::new();
let response = response.origin(&to_origin("https://www.example.com").unwrap(), true); let response = response.origin("https://www.example.com", true);
// Build response and check built response header // Build response and check built response header
let expected_header = vec!["https://www.example.com"]; let expected_header = vec!["https://www.example.com"];
@ -1915,27 +1937,11 @@ mod tests {
assert_eq!(expected_header, actual_header); assert_eq!(expected_header, actual_header);
} }
#[test]
fn response_sets_allow_origin_with_ascii_serialization() {
let response = Response::new();
let response = response.origin(&to_origin("https://аpple.com").unwrap(), false);
// Build response and check built response header
// This is "punycode"
let expected_header = vec!["https://xn--pple-43d.com"];
let response = response.response(response::Response::new());
let actual_header: Vec<_> = response
.headers()
.get("Access-Control-Allow-Origin")
.collect();
assert_eq!(expected_header, actual_header);
}
#[test] #[test]
fn response_sets_exposed_headers_correctly() { fn response_sets_exposed_headers_correctly() {
let headers = vec!["Bar", "Baz", "Foo"]; let headers = vec!["Bar", "Baz", "Foo"];
let response = Response::new(); let response = Response::new();
let response = response.origin(&to_origin("https://www.example.com").unwrap(), false); let response = response.origin("https://www.example.com", false);
let response = response.exposed_headers(&headers); let response = response.exposed_headers(&headers);
// Build response and check built response header // Build response and check built response header
@ -1957,7 +1963,7 @@ mod tests {
#[test] #[test]
fn response_sets_max_age_correctly() { fn response_sets_max_age_correctly() {
let response = Response::new(); let response = Response::new();
let response = response.origin(&to_origin("https://www.example.com").unwrap(), false); let response = response.origin("https://www.example.com", false);
let response = response.max_age(Some(42)); let response = response.max_age(Some(42));
@ -1971,7 +1977,7 @@ mod tests {
#[test] #[test]
fn response_does_not_set_max_age_when_none() { fn response_does_not_set_max_age_when_none() {
let response = Response::new(); let response = Response::new();
let response = response.origin(&to_origin("https://www.example.com").unwrap(), false); let response = response.origin("https://www.example.com", false);
let response = response.max_age(None); let response = response.max_age(None);
@ -2084,7 +2090,7 @@ mod tests {
.finalize(); .finalize();
let response = Response::new(); let response = Response::new();
let response = response.origin(&to_origin("https://www.example.com").unwrap(), false); let response = response.origin("https://www.example.com", false);
let response = response.response(original); let response = response.response(original);
// Check CORS header // Check CORS header
let expected_header = vec!["https://www.example.com"]; let expected_header = vec!["https://www.example.com"];
@ -2160,7 +2166,7 @@ mod tests {
let result = validate(&cors, request.inner()).expect("to not fail"); let result = validate(&cors, request.inner()).expect("to not fail");
let expected_result = ValidationResult::Preflight { let expected_result = ValidationResult::Preflight {
origin: to_origin("https://www.acme.com").unwrap(), origin: "https://www.acme.com".to_string(),
// Checks that only a subset of allowed headers are returned // Checks that only a subset of allowed headers are returned
// -- i.e. whatever is requested for // -- i.e. whatever is requested for
headers: Some(FromStr::from_str("Authorization").unwrap()), headers: Some(FromStr::from_str("Authorization").unwrap()),
@ -2195,7 +2201,7 @@ mod tests {
let result = validate(&cors, request.inner()).expect("to not fail"); let result = validate(&cors, request.inner()).expect("to not fail");
let expected_result = ValidationResult::Preflight { let expected_result = ValidationResult::Preflight {
origin: to_origin("https://www.example.com").unwrap(), origin: "https://www.example.com".to_string(),
headers: Some(FromStr::from_str("Authorization").unwrap()), headers: Some(FromStr::from_str("Authorization").unwrap()),
}; };
@ -2313,7 +2319,7 @@ mod tests {
let result = validate(&cors, request.inner()).expect("to not fail"); let result = validate(&cors, request.inner()).expect("to not fail");
let expected_result = ValidationResult::Request { let expected_result = ValidationResult::Request {
origin: to_origin("https://www.acme.com").unwrap(), origin: "https://www.acme.com".to_string(),
}; };
assert_eq!(expected_result, result); assert_eq!(expected_result, result);
@ -2332,7 +2338,7 @@ mod tests {
let result = validate(&cors, request.inner()).expect("to not fail"); let result = validate(&cors, request.inner()).expect("to not fail");
let expected_result = ValidationResult::Request { let expected_result = ValidationResult::Request {
origin: to_origin("https://www.example.com").unwrap(), origin: "https://www.example.com".to_string(),
}; };
assert_eq!(expected_result, result); assert_eq!(expected_result, result);
@ -2388,7 +2394,7 @@ mod tests {
let response = validate_and_build(&cors, request.inner()).expect("to not fail"); let response = validate_and_build(&cors, request.inner()).expect("to not fail");
let expected_response = Response::new() let expected_response = Response::new()
.origin(&to_origin("https://www.acme.com").unwrap(), false) .origin("https://www.acme.com", false)
.headers(&["Authorization"]) .headers(&["Authorization"])
.methods(&options.allowed_methods) .methods(&options.allowed_methods)
.credentials(options.allow_credentials) .credentials(options.allow_credentials)
@ -2428,7 +2434,7 @@ mod tests {
let response = validate_and_build(&cors, request.inner()).expect("to not fail"); let response = validate_and_build(&cors, request.inner()).expect("to not fail");
let expected_response = Response::new() let expected_response = Response::new()
.origin(&to_origin("https://www.acme.com").unwrap(), true) .origin("https://www.acme.com", true)
.headers(&["Authorization"]) .headers(&["Authorization"])
.methods(&options.allowed_methods) .methods(&options.allowed_methods)
.credentials(options.allow_credentials) .credentials(options.allow_credentials)
@ -2489,7 +2495,7 @@ mod tests {
let response = validate_and_build(&cors, request.inner()).expect("to not fail"); let response = validate_and_build(&cors, request.inner()).expect("to not fail");
let expected_response = Response::new() let expected_response = Response::new()
.origin(&to_origin("https://www.acme.com").unwrap(), false) .origin("https://www.acme.com", false)
.credentials(options.allow_credentials) .credentials(options.allow_credentials)
.exposed_headers(&["Content-Type", "X-Custom"]); .exposed_headers(&["Content-Type", "X-Custom"]);
@ -2512,7 +2518,7 @@ mod tests {
let response = validate_and_build(&cors, request.inner()).expect("to not fail"); let response = validate_and_build(&cors, request.inner()).expect("to not fail");
let expected_response = Response::new() let expected_response = Response::new()
.origin(&to_origin("https://www.acme.com").unwrap(), true) .origin("https://www.acme.com", true)
.credentials(options.allow_credentials) .credentials(options.allow_credentials)
.exposed_headers(&["Content-Type", "X-Custom"]); .exposed_headers(&["Content-Type", "X-Custom"]);

View File

@ -22,7 +22,7 @@ fn panicking_route() {
} }
fn make_cors() -> Cors { fn make_cors() -> Cors {
let allowed_origins = AllowedOrigins::some(&["https://www.acme.com"]); let allowed_origins = AllowedOrigins::some_exact(&["https://www.acme.com"]);
CorsOptions { CorsOptions {
allowed_origins: allowed_origins, allowed_origins: allowed_origins,

View File

@ -60,7 +60,7 @@ fn state<'r>(cors: cors::Guard<'r>, _state: State<'r, SomeState>) -> cors::Respo
} }
fn make_cors() -> cors::Cors { fn make_cors() -> cors::Cors {
let allowed_origins = cors::AllowedOrigins::some(&["https://www.acme.com"]); let allowed_origins = cors::AllowedOrigins::some_exact(&["https://www.acme.com"]);
cors::CorsOptions { cors::CorsOptions {
allowed_origins: allowed_origins, allowed_origins: allowed_origins,

View File

@ -66,7 +66,7 @@ fn borrow<'r>(options: State<'r, Cors>, test_state: State<'r, TestState>) -> imp
} }
fn make_cors_options() -> CorsOptions { fn make_cors_options() -> CorsOptions {
let allowed_origins = AllowedOrigins::some(&["https://www.acme.com"]); let allowed_origins = AllowedOrigins::some_exact(&["https://www.acme.com"]);
CorsOptions { CorsOptions {
allowed_origins: allowed_origins, allowed_origins: allowed_origins,
@ -78,7 +78,7 @@ fn make_cors_options() -> CorsOptions {
} }
fn make_different_cors_options() -> CorsOptions { fn make_different_cors_options() -> CorsOptions {
let allowed_origins = AllowedOrigins::some(&["https://www.example.com"]); let allowed_origins = AllowedOrigins::some_exact(&["https://www.example.com"]);
CorsOptions { CorsOptions {
allowed_origins: allowed_origins, allowed_origins: allowed_origins,

View File

@ -40,7 +40,7 @@ fn ping_options<'r>() -> impl Responder<'r> {
/// Returns the "application wide" Cors struct /// Returns the "application wide" Cors struct
fn cors_options() -> CorsOptions { fn cors_options() -> CorsOptions {
let allowed_origins = AllowedOrigins::some(&["https://www.acme.com"]); let allowed_origins = AllowedOrigins::some_exact(&["https://www.acme.com"]);
// You can also deserialize this // You can also deserialize this
rocket_cors::CorsOptions { rocket_cors::CorsOptions {