Separate out `Origin`
This commit is contained in:
parent
d7e5153e27
commit
2b5dfede54
|
@ -53,7 +53,7 @@ fn on_response_wrapper(
|
||||||
// Not a CORS request
|
// Not a CORS request
|
||||||
return Ok(());
|
return Ok(());
|
||||||
}
|
}
|
||||||
Some(origin) => crate::to_origin(origin)?,
|
Some(origin) => origin,
|
||||||
};
|
};
|
||||||
|
|
||||||
let result = request.local_cache(|| unreachable!("This should not be executed so late"));
|
let result = request.local_cache(|| unreachable!("This should not be executed so late"));
|
||||||
|
|
|
@ -64,34 +64,27 @@ pub type HeaderFieldNamesSet = HashSet<HeaderFieldName>;
|
||||||
/// You can use this as a rocket [Request Guard](https://rocket.rs/guide/requests/#request-guards)
|
/// You can use this as a rocket [Request Guard](https://rocket.rs/guide/requests/#request-guards)
|
||||||
/// to ensure that `Origin` is passed in correctly.
|
/// to ensure that `Origin` is passed in correctly.
|
||||||
#[derive(Eq, PartialEq, Clone, Hash, Debug)]
|
#[derive(Eq, PartialEq, Clone, Hash, Debug)]
|
||||||
#[cfg_attr(feature = "serialization", derive(Serialize, Deserialize))]
|
pub struct Origin(pub url::Origin);
|
||||||
pub struct Origin(pub String);
|
|
||||||
|
|
||||||
impl FromStr for Origin {
|
impl FromStr for Origin {
|
||||||
type Err = !;
|
type Err = crate::Error;
|
||||||
|
|
||||||
fn from_str(input: &str) -> Result<Self, Self::Err> {
|
fn from_str(input: &str) -> Result<Self, Self::Err> {
|
||||||
Ok(Origin(input.to_string()))
|
Ok(Origin(crate::to_origin(input)?))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Deref for Origin {
|
impl Deref for Origin {
|
||||||
type Target = str;
|
type Target = url::Origin;
|
||||||
|
|
||||||
fn deref(&self) -> &Self::Target {
|
fn deref(&self) -> &Self::Target {
|
||||||
&self.0
|
&self.0
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl AsRef<str> for Origin {
|
|
||||||
fn as_ref(&self) -> &str {
|
|
||||||
self
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl fmt::Display for Origin {
|
impl fmt::Display for Origin {
|
||||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||||
self.as_ref().fmt(f)
|
write!(f, "{}", self.ascii_serialization())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -100,10 +93,10 @@ impl<'a, 'r> FromRequest<'a, 'r> for Origin {
|
||||||
|
|
||||||
fn from_request(request: &'a rocket::Request<'r>) -> request::Outcome<Self, crate::Error> {
|
fn from_request(request: &'a rocket::Request<'r>) -> request::Outcome<Self, crate::Error> {
|
||||||
match request.headers().get_one("Origin") {
|
match request.headers().get_one("Origin") {
|
||||||
Some(origin) => {
|
Some(origin) => match Self::from_str(origin) {
|
||||||
let Ok(origin) = Self::from_str(origin);
|
Ok(origin) => Outcome::Success(origin),
|
||||||
Outcome::Success(origin)
|
Err(e) => Outcome::Failure((Status::BadRequest, e)),
|
||||||
}
|
},
|
||||||
None => Outcome::Forward(()),
|
None => Outcome::Forward(()),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -199,17 +192,17 @@ mod tests {
|
||||||
#[test]
|
#[test]
|
||||||
fn origin_header_conversion() {
|
fn origin_header_conversion() {
|
||||||
let url = "https://foo.bar.xyz";
|
let url = "https://foo.bar.xyz";
|
||||||
let Ok(parsed) = Origin::from_str(url);
|
let parsed = not_err!(Origin::from_str(url));
|
||||||
assert_eq!(parsed.as_ref(), url);
|
assert_eq!(parsed.ascii_serialization(), url);
|
||||||
|
|
||||||
let url = "https://foo.bar.xyz/path/somewhere"; // this should never really be used
|
// this should never really be sent by a compliant user agent
|
||||||
let Ok(parsed) = Origin::from_str(url);
|
let url = "https://foo.bar.xyz/path/somewhere";
|
||||||
assert_eq!(parsed.as_ref(), url);
|
let parsed = not_err!(Origin::from_str(url));
|
||||||
|
let expected = "https://foo.bar.xyz";
|
||||||
|
assert_eq!(parsed.ascii_serialization(), expected);
|
||||||
|
|
||||||
// Validation is not done now
|
|
||||||
let url = "invalid_url";
|
let url = "invalid_url";
|
||||||
let Ok(parsed) = Origin::from_str(url);
|
let _ = is_err!(Origin::from_str(url));
|
||||||
assert_eq!(parsed.as_ref(), url);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
@ -223,7 +216,7 @@ mod tests {
|
||||||
let outcome: request::Outcome<Origin, crate::Error> =
|
let outcome: request::Outcome<Origin, crate::Error> =
|
||||||
FromRequest::from_request(request.inner());
|
FromRequest::from_request(request.inner());
|
||||||
let parsed_header = assert_matches!(outcome, Outcome::Success(s), s);
|
let parsed_header = assert_matches!(outcome, Outcome::Success(s), s);
|
||||||
assert_eq!("https://www.example.com", parsed_header.as_ref());
|
assert_eq!("https://www.example.com", parsed_header.ascii_serialization());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|
33
src/lib.rs
33
src/lib.rs
|
@ -267,8 +267,6 @@ See the [example](https://github.com/lawliet89/rocket_cors/blob/master/examples/
|
||||||
intra_doc_link_resolution_failure
|
intra_doc_link_resolution_failure
|
||||||
)]
|
)]
|
||||||
#![doc(test(attr(allow(unused_variables), deny(warnings))))]
|
#![doc(test(attr(allow(unused_variables), deny(warnings))))]
|
||||||
#![feature(never_type)]
|
|
||||||
#![feature(exhaustive_patterns)]
|
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
#[macro_use]
|
#[macro_use]
|
||||||
|
@ -533,21 +531,14 @@ mod method_serde {
|
||||||
/// let all_origins = AllowedOrigins::all();
|
/// let all_origins = AllowedOrigins::all();
|
||||||
/// let some_origins = AllowedOrigins::some(&["https://www.acme.com"]);
|
/// let some_origins = AllowedOrigins::some(&["https://www.acme.com"]);
|
||||||
/// ```
|
/// ```
|
||||||
pub type AllowedOrigins = AllOrSome<HashSet<Origin>>;
|
pub type AllowedOrigins = AllOrSome<HashSet<String>>;
|
||||||
|
|
||||||
impl AllowedOrigins {
|
impl AllowedOrigins {
|
||||||
/// Allows some origins
|
/// Allows some origins
|
||||||
///
|
///
|
||||||
/// Validation is not performed at this stage, but at a later stage.
|
/// Validation is not performed at this stage, but at a later stage.
|
||||||
pub fn some(urls: &[&str]) -> Self {
|
pub fn some(urls: &[&str]) -> Self {
|
||||||
AllOrSome::Some(
|
AllOrSome::Some(urls.iter().map(|s| s.to_string()).collect())
|
||||||
urls.iter()
|
|
||||||
.map(|s| {
|
|
||||||
let Ok(s) = FromStr::from_str(s);
|
|
||||||
s
|
|
||||||
})
|
|
||||||
.collect(),
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Allows all origins
|
/// Allows all origins
|
||||||
|
@ -1308,7 +1299,7 @@ fn validate(options: &Cors, request: &Request<'_>) -> Result<ValidationResult, E
|
||||||
// Not a CORS request
|
// Not a CORS request
|
||||||
return Ok(ValidationResult::None);
|
return Ok(ValidationResult::None);
|
||||||
}
|
}
|
||||||
Some(origin) => to_origin(origin)?,
|
Some(origin) => origin,
|
||||||
};
|
};
|
||||||
|
|
||||||
// Check if the request verb is an OPTION or something else
|
// Check if the request verb is an OPTION or something else
|
||||||
|
@ -1317,11 +1308,16 @@ fn validate(options: &Cors, request: &Request<'_>) -> Result<ValidationResult, E
|
||||||
let method = request_method(request)?;
|
let method = request_method(request)?;
|
||||||
let headers = request_headers(request)?;
|
let headers = request_headers(request)?;
|
||||||
preflight_validate(options, &origin, &method, &headers)?;
|
preflight_validate(options, &origin, &method, &headers)?;
|
||||||
Ok(ValidationResult::Preflight { origin, headers })
|
Ok(ValidationResult::Preflight {
|
||||||
|
origin: origin.deref().clone(),
|
||||||
|
headers,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
_ => {
|
_ => {
|
||||||
actual_request_validate(options, &origin)?;
|
actual_request_validate(options, &origin)?;
|
||||||
Ok(ValidationResult::Request { origin })
|
Ok(ValidationResult::Request {
|
||||||
|
origin: origin.deref().clone(),
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1705,8 +1701,7 @@ mod tests {
|
||||||
#[test]
|
#[test]
|
||||||
fn validate_origin_allows_all_origins() {
|
fn validate_origin_allows_all_origins() {
|
||||||
let url = "https://www.example.com";
|
let url = "https://www.example.com";
|
||||||
let Ok(origin) = Origin::from_str(url);
|
let origin = not_err!(to_origin(&url));
|
||||||
let origin = not_err!(to_origin(&origin));
|
|
||||||
let allowed_origins = AllOrSome::All;
|
let allowed_origins = AllOrSome::All;
|
||||||
|
|
||||||
not_err!(validate_origin(&origin, &allowed_origins));
|
not_err!(validate_origin(&origin, &allowed_origins));
|
||||||
|
@ -1715,8 +1710,7 @@ mod tests {
|
||||||
#[test]
|
#[test]
|
||||||
fn validate_origin_allows_origin() {
|
fn validate_origin_allows_origin() {
|
||||||
let url = "https://www.example.com";
|
let url = "https://www.example.com";
|
||||||
let Ok(origin) = Origin::from_str(url);
|
let origin = not_err!(to_origin(&url));
|
||||||
let origin = not_err!(to_origin(&origin));
|
|
||||||
let allowed_origins = not_err!(parse_origins(&AllowedOrigins::some(&[
|
let allowed_origins = not_err!(parse_origins(&AllowedOrigins::some(&[
|
||||||
"https://www.example.com"
|
"https://www.example.com"
|
||||||
])));
|
])));
|
||||||
|
@ -1728,8 +1722,7 @@ mod tests {
|
||||||
#[should_panic(expected = "OriginNotAllowed")]
|
#[should_panic(expected = "OriginNotAllowed")]
|
||||||
fn validate_origin_rejects_invalid_origin() {
|
fn validate_origin_rejects_invalid_origin() {
|
||||||
let url = "https://www.acme.com";
|
let url = "https://www.acme.com";
|
||||||
let Ok(origin) = Origin::from_str(url);
|
let origin = not_err!(to_origin(&url));
|
||||||
let origin = not_err!(to_origin(&origin));
|
|
||||||
let allowed_origins = not_err!(parse_origins(&AllowedOrigins::some(&[
|
let allowed_origins = not_err!(parse_origins(&AllowedOrigins::some(&[
|
||||||
"https://www.example.com"
|
"https://www.example.com"
|
||||||
])));
|
])));
|
||||||
|
|
Loading…
Reference in New Issue