Fix tests
This commit is contained in:
parent
bc16568e8b
commit
3349c972cf
|
@ -12,7 +12,7 @@ fn cors<'a>() -> &'a str {
|
|||
}
|
||||
|
||||
fn main() -> Result<(), Error> {
|
||||
let allowed_origins = AllowedOrigins::some(&["https://www.acme.com"]);
|
||||
let allowed_origins = AllowedOrigins::some_exact(&["https://www.acme.com"]);
|
||||
|
||||
// You can also deserialize this
|
||||
let cors = rocket_cors::CorsOptions {
|
||||
|
|
|
@ -36,7 +36,7 @@ fn manual(cors: Guard<'_>) -> Responder<'_, &str> {
|
|||
}
|
||||
|
||||
fn main() -> Result<(), Error> {
|
||||
let allowed_origins = AllowedOrigins::some(&["https://www.acme.com"]);
|
||||
let allowed_origins = AllowedOrigins::some_exact(&["https://www.acme.com"]);
|
||||
|
||||
// You can also deserialize this
|
||||
let cors = rocket_cors::CorsOptions {
|
||||
|
|
|
@ -13,7 +13,7 @@ fn main() {
|
|||
// The default demonstrates the "All" serialization of several of the settings
|
||||
let default: CorsOptions = Default::default();
|
||||
|
||||
let allowed_origins = AllowedOrigins::some(&["https://www.acme.com"]);
|
||||
let allowed_origins = AllowedOrigins::some_exact(&["https://www.acme.com"]);
|
||||
|
||||
let options = cors::CorsOptions {
|
||||
allowed_origins: allowed_origins,
|
||||
|
|
|
@ -59,7 +59,7 @@ fn owned_options<'r>() -> impl Responder<'r> {
|
|||
}
|
||||
|
||||
fn cors_options() -> CorsOptions {
|
||||
let allowed_origins = AllowedOrigins::some(&["https://www.acme.com"]);
|
||||
let allowed_origins = AllowedOrigins::some_exact(&["https://www.acme.com"]);
|
||||
|
||||
// You can also deserialize this
|
||||
rocket_cors::CorsOptions {
|
||||
|
|
|
@ -36,7 +36,7 @@ fn ping_options<'r>() -> impl Responder<'r> {
|
|||
|
||||
/// Returns the "application wide" Cors struct
|
||||
fn cors_options() -> CorsOptions {
|
||||
let allowed_origins = AllowedOrigins::some(&["https://www.acme.com"]);
|
||||
let allowed_origins = AllowedOrigins::some_exact(&["https://www.acme.com"]);
|
||||
|
||||
// You can also deserialize this
|
||||
rocket_cors::CorsOptions {
|
||||
|
|
|
@ -138,10 +138,10 @@ mod tests {
|
|||
|
||||
use crate::{AllowedHeaders, AllowedOrigins, Cors, CorsOptions};
|
||||
|
||||
const CORS_ROOT: &'static str = "/my_cors";
|
||||
const CORS_ROOT: &str = "/my_cors";
|
||||
|
||||
fn make_cors_options() -> Cors {
|
||||
let allowed_origins = AllowedOrigins::some(&["https://www.acme.com"]);
|
||||
let allowed_origins = AllowedOrigins::some_exact(&["https://www.acme.com"]);
|
||||
|
||||
CorsOptions {
|
||||
allowed_origins,
|
||||
|
|
|
@ -73,6 +73,23 @@ pub enum Origin {
|
|||
Parsed(url::Origin),
|
||||
}
|
||||
|
||||
impl Origin {
|
||||
/// Perform an
|
||||
/// [ASCII serialization](https://html.spec.whatwg.org/multipage/#ascii-serialisation-of-an-origin)
|
||||
/// of this origin.
|
||||
pub fn ascii_serialization(&self) -> String {
|
||||
self.to_string()
|
||||
}
|
||||
|
||||
/// Returns whether the origin was parsed as non-opaque
|
||||
pub fn is_tuple(&self) -> bool {
|
||||
match self {
|
||||
Origin::Null => false,
|
||||
Origin::Parsed(ref parsed) => parsed.is_tuple(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl FromStr for Origin {
|
||||
type Err = crate::Error;
|
||||
|
||||
|
|
92
src/lib.rs
92
src/lib.rs
|
@ -308,6 +308,8 @@ pub enum Error {
|
|||
MissingOrigin,
|
||||
/// The HTTP request header `Origin` could not be parsed correctly.
|
||||
BadOrigin(url::ParseError),
|
||||
/// The configured Allowed Origin is opaque and cannot be parsed.
|
||||
OpaqueAllowedOrigin(String),
|
||||
/// The request header `Access-Control-Request-Method` is required but is missing
|
||||
MissingRequestMethod,
|
||||
/// The request header `Access-Control-Request-Method` has an invalid value
|
||||
|
@ -376,7 +378,8 @@ impl fmt::Display for Error {
|
|||
}
|
||||
Error::MissingInjectedHeader => write!(f,
|
||||
"The `on_response` handler of Fairing could not find the injected header from the \
|
||||
Request. Either some other fairing has removed it, or this is a bug.")
|
||||
Request. Either some other fairing has removed it, or this is a bug."),
|
||||
Error::OpaqueAllowedOrigin(ref origin) => write!(f, "The configured Origin '{}' not have a parsable Origin. Use a regex instead.", origin),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -597,6 +600,9 @@ pub struct Origins {
|
|||
pub exact: Option<HashSet<String>>,
|
||||
/// Origins that will be matched via __any__ regex in this list. These __must__ be valid Regex
|
||||
/// that will be parsed and validated when creating [`Cors`].
|
||||
/// The regex will be matched according to the
|
||||
/// [ASCII serialization](https://html.spec.whatwg.org/multipage/#ascii-serialisation-of-an-origin)
|
||||
/// of the incoming Origin.
|
||||
#[cfg_attr(feature = "serialization", serde(default))]
|
||||
pub regex: Option<HashSet<String>>,
|
||||
}
|
||||
|
@ -610,13 +616,24 @@ pub(crate) struct ParsedAllowedOrigins {
|
|||
|
||||
impl ParsedAllowedOrigins {
|
||||
fn parse(origins: &Origins) -> Result<Self, Error> {
|
||||
let exact: Result<_, Error> = match &origins.exact {
|
||||
let exact: Result<HashSet<url::Origin>, Error> = match &origins.exact {
|
||||
Some(exact) => exact.iter().map(|url| to_origin(url.as_str())).collect(),
|
||||
None => Ok(Default::default()),
|
||||
};
|
||||
let exact = exact?;
|
||||
|
||||
// Let's check if any of them is Opaque
|
||||
exact.iter().try_for_each(|url| {
|
||||
if !url.is_tuple() {
|
||||
Err(Error::OpaqueAllowedOrigin(url.ascii_serialization()))
|
||||
} else {
|
||||
Ok(())
|
||||
}
|
||||
})?;
|
||||
|
||||
Ok(Self {
|
||||
allow_null: origins.allow_null,
|
||||
exact: exact?,
|
||||
exact,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -1332,9 +1349,8 @@ enum ValidationResult {
|
|||
Request { origin: String },
|
||||
}
|
||||
|
||||
/// Convert a str to Origin
|
||||
/// Convert a str to a URL Origin
|
||||
fn to_origin<S: AsRef<str>>(origin: S) -> Result<url::Origin, Error> {
|
||||
// What to do about Opaque origins?
|
||||
Ok(url::Url::parse(origin.as_ref())?.origin())
|
||||
}
|
||||
|
||||
|
@ -1727,8 +1743,12 @@ mod tests {
|
|||
use super::*;
|
||||
use crate::http::Method;
|
||||
|
||||
fn to_parsed_origin<S: AsRef<str>>(origin: S) -> Result<Origin, Error> {
|
||||
Origin::from_str(origin.as_ref())
|
||||
}
|
||||
|
||||
fn make_cors_options() -> CorsOptions {
|
||||
let allowed_origins = AllowedOrigins::some(&["https://www.acme.com"]);
|
||||
let allowed_origins = AllowedOrigins::some_exact(&["https://www.acme.com"]);
|
||||
|
||||
CorsOptions {
|
||||
allowed_origins,
|
||||
|
@ -1821,7 +1841,7 @@ mod tests {
|
|||
#[test]
|
||||
fn validate_origin_allows_all_origins() {
|
||||
let url = "https://www.example.com";
|
||||
let origin = not_err!(to_origin(&url));
|
||||
let origin = not_err!(to_parsed_origin(&url));
|
||||
let allowed_origins = AllOrSome::All;
|
||||
|
||||
not_err!(validate_origin(&origin, &allowed_origins));
|
||||
|
@ -1830,8 +1850,8 @@ mod tests {
|
|||
#[test]
|
||||
fn validate_origin_allows_origin() {
|
||||
let url = "https://www.example.com";
|
||||
let origin = not_err!(to_origin(&url));
|
||||
let allowed_origins = not_err!(parse_origins(&AllowedOrigins::some(&[
|
||||
let origin = not_err!(to_parsed_origin(&url));
|
||||
let allowed_origins = not_err!(parse_origins(&AllowedOrigins::some_exact(&[
|
||||
"https://www.example.com"
|
||||
])));
|
||||
|
||||
|
@ -1849,8 +1869,10 @@ mod tests {
|
|||
];
|
||||
|
||||
for (url, allowed_origin) in cases {
|
||||
let origin = not_err!(to_origin(&url));
|
||||
let allowed_origins = not_err!(parse_origins(&AllowedOrigins::some(&[allowed_origin])));
|
||||
let origin = not_err!(to_parsed_origin(&url));
|
||||
let allowed_origins = not_err!(parse_origins(&AllowedOrigins::some_exact(&[
|
||||
allowed_origin
|
||||
])));
|
||||
|
||||
not_err!(validate_origin(&origin, &allowed_origins));
|
||||
}
|
||||
|
@ -1860,8 +1882,8 @@ mod tests {
|
|||
#[should_panic(expected = "OriginNotAllowed")]
|
||||
fn validate_origin_rejects_invalid_origin() {
|
||||
let url = "https://www.acme.com";
|
||||
let origin = not_err!(to_origin(&url));
|
||||
let allowed_origins = not_err!(parse_origins(&AllowedOrigins::some(&[
|
||||
let origin = not_err!(to_parsed_origin(&url));
|
||||
let allowed_origins = not_err!(parse_origins(&AllowedOrigins::some_exact(&[
|
||||
"https://www.example.com"
|
||||
])));
|
||||
|
||||
|
@ -1871,7 +1893,7 @@ mod tests {
|
|||
#[test]
|
||||
fn response_sets_allow_origin_without_vary_correctly() {
|
||||
let response = Response::new();
|
||||
let response = response.origin(&to_origin("https://www.example.com").unwrap(), false);
|
||||
let response = response.origin("https://www.example.com", false);
|
||||
|
||||
// Build response and check built response header
|
||||
let expected_header = vec!["https://www.example.com"];
|
||||
|
@ -1888,7 +1910,7 @@ mod tests {
|
|||
#[test]
|
||||
fn response_sets_allow_origin_with_vary_correctly() {
|
||||
let response = Response::new();
|
||||
let response = response.origin(&to_origin("https://www.example.com").unwrap(), true);
|
||||
let response = response.origin("https://www.example.com", true);
|
||||
|
||||
// Build response and check built response header
|
||||
let expected_header = vec!["https://www.example.com"];
|
||||
|
@ -1915,27 +1937,11 @@ mod tests {
|
|||
assert_eq!(expected_header, actual_header);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn response_sets_allow_origin_with_ascii_serialization() {
|
||||
let response = Response::new();
|
||||
let response = response.origin(&to_origin("https://аpple.com").unwrap(), false);
|
||||
|
||||
// Build response and check built response header
|
||||
// This is "punycode"
|
||||
let expected_header = vec!["https://xn--pple-43d.com"];
|
||||
let response = response.response(response::Response::new());
|
||||
let actual_header: Vec<_> = response
|
||||
.headers()
|
||||
.get("Access-Control-Allow-Origin")
|
||||
.collect();
|
||||
assert_eq!(expected_header, actual_header);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn response_sets_exposed_headers_correctly() {
|
||||
let headers = vec!["Bar", "Baz", "Foo"];
|
||||
let response = Response::new();
|
||||
let response = response.origin(&to_origin("https://www.example.com").unwrap(), false);
|
||||
let response = response.origin("https://www.example.com", false);
|
||||
let response = response.exposed_headers(&headers);
|
||||
|
||||
// Build response and check built response header
|
||||
|
@ -1957,7 +1963,7 @@ mod tests {
|
|||
#[test]
|
||||
fn response_sets_max_age_correctly() {
|
||||
let response = Response::new();
|
||||
let response = response.origin(&to_origin("https://www.example.com").unwrap(), false);
|
||||
let response = response.origin("https://www.example.com", false);
|
||||
|
||||
let response = response.max_age(Some(42));
|
||||
|
||||
|
@ -1971,7 +1977,7 @@ mod tests {
|
|||
#[test]
|
||||
fn response_does_not_set_max_age_when_none() {
|
||||
let response = Response::new();
|
||||
let response = response.origin(&to_origin("https://www.example.com").unwrap(), false);
|
||||
let response = response.origin("https://www.example.com", false);
|
||||
|
||||
let response = response.max_age(None);
|
||||
|
||||
|
@ -2084,7 +2090,7 @@ mod tests {
|
|||
.finalize();
|
||||
|
||||
let response = Response::new();
|
||||
let response = response.origin(&to_origin("https://www.example.com").unwrap(), false);
|
||||
let response = response.origin("https://www.example.com", false);
|
||||
let response = response.response(original);
|
||||
// Check CORS header
|
||||
let expected_header = vec!["https://www.example.com"];
|
||||
|
@ -2160,7 +2166,7 @@ mod tests {
|
|||
|
||||
let result = validate(&cors, request.inner()).expect("to not fail");
|
||||
let expected_result = ValidationResult::Preflight {
|
||||
origin: to_origin("https://www.acme.com").unwrap(),
|
||||
origin: "https://www.acme.com".to_string(),
|
||||
// Checks that only a subset of allowed headers are returned
|
||||
// -- i.e. whatever is requested for
|
||||
headers: Some(FromStr::from_str("Authorization").unwrap()),
|
||||
|
@ -2195,7 +2201,7 @@ mod tests {
|
|||
|
||||
let result = validate(&cors, request.inner()).expect("to not fail");
|
||||
let expected_result = ValidationResult::Preflight {
|
||||
origin: to_origin("https://www.example.com").unwrap(),
|
||||
origin: "https://www.example.com".to_string(),
|
||||
headers: Some(FromStr::from_str("Authorization").unwrap()),
|
||||
};
|
||||
|
||||
|
@ -2313,7 +2319,7 @@ mod tests {
|
|||
|
||||
let result = validate(&cors, request.inner()).expect("to not fail");
|
||||
let expected_result = ValidationResult::Request {
|
||||
origin: to_origin("https://www.acme.com").unwrap(),
|
||||
origin: "https://www.acme.com".to_string(),
|
||||
};
|
||||
|
||||
assert_eq!(expected_result, result);
|
||||
|
@ -2332,7 +2338,7 @@ mod tests {
|
|||
|
||||
let result = validate(&cors, request.inner()).expect("to not fail");
|
||||
let expected_result = ValidationResult::Request {
|
||||
origin: to_origin("https://www.example.com").unwrap(),
|
||||
origin: "https://www.example.com".to_string(),
|
||||
};
|
||||
|
||||
assert_eq!(expected_result, result);
|
||||
|
@ -2388,7 +2394,7 @@ mod tests {
|
|||
let response = validate_and_build(&cors, request.inner()).expect("to not fail");
|
||||
|
||||
let expected_response = Response::new()
|
||||
.origin(&to_origin("https://www.acme.com").unwrap(), false)
|
||||
.origin("https://www.acme.com", false)
|
||||
.headers(&["Authorization"])
|
||||
.methods(&options.allowed_methods)
|
||||
.credentials(options.allow_credentials)
|
||||
|
@ -2428,7 +2434,7 @@ mod tests {
|
|||
let response = validate_and_build(&cors, request.inner()).expect("to not fail");
|
||||
|
||||
let expected_response = Response::new()
|
||||
.origin(&to_origin("https://www.acme.com").unwrap(), true)
|
||||
.origin("https://www.acme.com", true)
|
||||
.headers(&["Authorization"])
|
||||
.methods(&options.allowed_methods)
|
||||
.credentials(options.allow_credentials)
|
||||
|
@ -2489,7 +2495,7 @@ mod tests {
|
|||
|
||||
let response = validate_and_build(&cors, request.inner()).expect("to not fail");
|
||||
let expected_response = Response::new()
|
||||
.origin(&to_origin("https://www.acme.com").unwrap(), false)
|
||||
.origin("https://www.acme.com", false)
|
||||
.credentials(options.allow_credentials)
|
||||
.exposed_headers(&["Content-Type", "X-Custom"]);
|
||||
|
||||
|
@ -2512,7 +2518,7 @@ mod tests {
|
|||
|
||||
let response = validate_and_build(&cors, request.inner()).expect("to not fail");
|
||||
let expected_response = Response::new()
|
||||
.origin(&to_origin("https://www.acme.com").unwrap(), true)
|
||||
.origin("https://www.acme.com", true)
|
||||
.credentials(options.allow_credentials)
|
||||
.exposed_headers(&["Content-Type", "X-Custom"]);
|
||||
|
||||
|
|
|
@ -22,7 +22,7 @@ fn panicking_route() {
|
|||
}
|
||||
|
||||
fn make_cors() -> Cors {
|
||||
let allowed_origins = AllowedOrigins::some(&["https://www.acme.com"]);
|
||||
let allowed_origins = AllowedOrigins::some_exact(&["https://www.acme.com"]);
|
||||
|
||||
CorsOptions {
|
||||
allowed_origins: allowed_origins,
|
||||
|
|
|
@ -60,7 +60,7 @@ fn state<'r>(cors: cors::Guard<'r>, _state: State<'r, SomeState>) -> cors::Respo
|
|||
}
|
||||
|
||||
fn make_cors() -> cors::Cors {
|
||||
let allowed_origins = cors::AllowedOrigins::some(&["https://www.acme.com"]);
|
||||
let allowed_origins = cors::AllowedOrigins::some_exact(&["https://www.acme.com"]);
|
||||
|
||||
cors::CorsOptions {
|
||||
allowed_origins: allowed_origins,
|
||||
|
|
|
@ -66,7 +66,7 @@ fn borrow<'r>(options: State<'r, Cors>, test_state: State<'r, TestState>) -> imp
|
|||
}
|
||||
|
||||
fn make_cors_options() -> CorsOptions {
|
||||
let allowed_origins = AllowedOrigins::some(&["https://www.acme.com"]);
|
||||
let allowed_origins = AllowedOrigins::some_exact(&["https://www.acme.com"]);
|
||||
|
||||
CorsOptions {
|
||||
allowed_origins: allowed_origins,
|
||||
|
@ -78,7 +78,7 @@ fn make_cors_options() -> CorsOptions {
|
|||
}
|
||||
|
||||
fn make_different_cors_options() -> CorsOptions {
|
||||
let allowed_origins = AllowedOrigins::some(&["https://www.example.com"]);
|
||||
let allowed_origins = AllowedOrigins::some_exact(&["https://www.example.com"]);
|
||||
|
||||
CorsOptions {
|
||||
allowed_origins: allowed_origins,
|
||||
|
|
|
@ -40,7 +40,7 @@ fn ping_options<'r>() -> impl Responder<'r> {
|
|||
|
||||
/// Returns the "application wide" Cors struct
|
||||
fn cors_options() -> CorsOptions {
|
||||
let allowed_origins = AllowedOrigins::some(&["https://www.acme.com"]);
|
||||
let allowed_origins = AllowedOrigins::some_exact(&["https://www.acme.com"]);
|
||||
|
||||
// You can also deserialize this
|
||||
rocket_cors::CorsOptions {
|
||||
|
|
Loading…
Reference in New Issue