Separate out `Origin`

This commit is contained in:
Yong Wen Chua 2018-12-19 11:26:27 +08:00
parent d7e5153e27
commit 2b5dfede54
No known key found for this signature in database
GPG Key ID: EDC57EEC439CF10B
3 changed files with 32 additions and 46 deletions

View File

@ -53,7 +53,7 @@ fn on_response_wrapper(
// Not a CORS request // Not a CORS request
return Ok(()); return Ok(());
} }
Some(origin) => crate::to_origin(origin)?, Some(origin) => origin,
}; };
let result = request.local_cache(|| unreachable!("This should not be executed so late")); let result = request.local_cache(|| unreachable!("This should not be executed so late"));

View File

@ -64,34 +64,27 @@ pub type HeaderFieldNamesSet = HashSet<HeaderFieldName>;
/// You can use this as a rocket [Request Guard](https://rocket.rs/guide/requests/#request-guards) /// You can use this as a rocket [Request Guard](https://rocket.rs/guide/requests/#request-guards)
/// to ensure that `Origin` is passed in correctly. /// to ensure that `Origin` is passed in correctly.
#[derive(Eq, PartialEq, Clone, Hash, Debug)] #[derive(Eq, PartialEq, Clone, Hash, Debug)]
#[cfg_attr(feature = "serialization", derive(Serialize, Deserialize))] pub struct Origin(pub url::Origin);
pub struct Origin(pub String);
impl FromStr for Origin { impl FromStr for Origin {
type Err = !; type Err = crate::Error;
fn from_str(input: &str) -> Result<Self, Self::Err> { fn from_str(input: &str) -> Result<Self, Self::Err> {
Ok(Origin(input.to_string())) Ok(Origin(crate::to_origin(input)?))
} }
} }
impl Deref for Origin { impl Deref for Origin {
type Target = str; type Target = url::Origin;
fn deref(&self) -> &Self::Target { fn deref(&self) -> &Self::Target {
&self.0 &self.0
} }
} }
impl AsRef<str> for Origin {
fn as_ref(&self) -> &str {
self
}
}
impl fmt::Display for Origin { impl fmt::Display for Origin {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
self.as_ref().fmt(f) write!(f, "{}", self.ascii_serialization())
} }
} }
@ -100,10 +93,10 @@ impl<'a, 'r> FromRequest<'a, 'r> for Origin {
fn from_request(request: &'a rocket::Request<'r>) -> request::Outcome<Self, crate::Error> { fn from_request(request: &'a rocket::Request<'r>) -> request::Outcome<Self, crate::Error> {
match request.headers().get_one("Origin") { match request.headers().get_one("Origin") {
Some(origin) => { Some(origin) => match Self::from_str(origin) {
let Ok(origin) = Self::from_str(origin); Ok(origin) => Outcome::Success(origin),
Outcome::Success(origin) Err(e) => Outcome::Failure((Status::BadRequest, e)),
} },
None => Outcome::Forward(()), None => Outcome::Forward(()),
} }
} }
@ -199,17 +192,17 @@ mod tests {
#[test] #[test]
fn origin_header_conversion() { fn origin_header_conversion() {
let url = "https://foo.bar.xyz"; let url = "https://foo.bar.xyz";
let Ok(parsed) = Origin::from_str(url); let parsed = not_err!(Origin::from_str(url));
assert_eq!(parsed.as_ref(), url); assert_eq!(parsed.ascii_serialization(), url);
let url = "https://foo.bar.xyz/path/somewhere"; // this should never really be used // this should never really be sent by a compliant user agent
let Ok(parsed) = Origin::from_str(url); let url = "https://foo.bar.xyz/path/somewhere";
assert_eq!(parsed.as_ref(), url); let parsed = not_err!(Origin::from_str(url));
let expected = "https://foo.bar.xyz";
assert_eq!(parsed.ascii_serialization(), expected);
// Validation is not done now
let url = "invalid_url"; let url = "invalid_url";
let Ok(parsed) = Origin::from_str(url); let _ = is_err!(Origin::from_str(url));
assert_eq!(parsed.as_ref(), url);
} }
#[test] #[test]
@ -223,7 +216,7 @@ mod tests {
let outcome: request::Outcome<Origin, crate::Error> = let outcome: request::Outcome<Origin, crate::Error> =
FromRequest::from_request(request.inner()); FromRequest::from_request(request.inner());
let parsed_header = assert_matches!(outcome, Outcome::Success(s), s); let parsed_header = assert_matches!(outcome, Outcome::Success(s), s);
assert_eq!("https://www.example.com", parsed_header.as_ref()); assert_eq!("https://www.example.com", parsed_header.ascii_serialization());
} }
#[test] #[test]

View File

@ -267,8 +267,6 @@ See the [example](https://github.com/lawliet89/rocket_cors/blob/master/examples/
intra_doc_link_resolution_failure intra_doc_link_resolution_failure
)] )]
#![doc(test(attr(allow(unused_variables), deny(warnings))))] #![doc(test(attr(allow(unused_variables), deny(warnings))))]
#![feature(never_type)]
#![feature(exhaustive_patterns)]
#[cfg(test)] #[cfg(test)]
#[macro_use] #[macro_use]
@ -533,21 +531,14 @@ mod method_serde {
/// let all_origins = AllowedOrigins::all(); /// let all_origins = AllowedOrigins::all();
/// let some_origins = AllowedOrigins::some(&["https://www.acme.com"]); /// let some_origins = AllowedOrigins::some(&["https://www.acme.com"]);
/// ``` /// ```
pub type AllowedOrigins = AllOrSome<HashSet<Origin>>; pub type AllowedOrigins = AllOrSome<HashSet<String>>;
impl AllowedOrigins { impl AllowedOrigins {
/// Allows some origins /// Allows some origins
/// ///
/// Validation is not performed at this stage, but at a later stage. /// Validation is not performed at this stage, but at a later stage.
pub fn some(urls: &[&str]) -> Self { pub fn some(urls: &[&str]) -> Self {
AllOrSome::Some( AllOrSome::Some(urls.iter().map(|s| s.to_string()).collect())
urls.iter()
.map(|s| {
let Ok(s) = FromStr::from_str(s);
s
})
.collect(),
)
} }
/// Allows all origins /// Allows all origins
@ -1308,7 +1299,7 @@ fn validate(options: &Cors, request: &Request<'_>) -> Result<ValidationResult, E
// Not a CORS request // Not a CORS request
return Ok(ValidationResult::None); return Ok(ValidationResult::None);
} }
Some(origin) => to_origin(origin)?, Some(origin) => origin,
}; };
// Check if the request verb is an OPTION or something else // Check if the request verb is an OPTION or something else
@ -1317,11 +1308,16 @@ fn validate(options: &Cors, request: &Request<'_>) -> Result<ValidationResult, E
let method = request_method(request)?; let method = request_method(request)?;
let headers = request_headers(request)?; let headers = request_headers(request)?;
preflight_validate(options, &origin, &method, &headers)?; preflight_validate(options, &origin, &method, &headers)?;
Ok(ValidationResult::Preflight { origin, headers }) Ok(ValidationResult::Preflight {
origin: origin.deref().clone(),
headers,
})
} }
_ => { _ => {
actual_request_validate(options, &origin)?; actual_request_validate(options, &origin)?;
Ok(ValidationResult::Request { origin }) Ok(ValidationResult::Request {
origin: origin.deref().clone(),
})
} }
} }
} }
@ -1705,8 +1701,7 @@ mod tests {
#[test] #[test]
fn validate_origin_allows_all_origins() { fn validate_origin_allows_all_origins() {
let url = "https://www.example.com"; let url = "https://www.example.com";
let Ok(origin) = Origin::from_str(url); let origin = not_err!(to_origin(&url));
let origin = not_err!(to_origin(&origin));
let allowed_origins = AllOrSome::All; let allowed_origins = AllOrSome::All;
not_err!(validate_origin(&origin, &allowed_origins)); not_err!(validate_origin(&origin, &allowed_origins));
@ -1715,8 +1710,7 @@ mod tests {
#[test] #[test]
fn validate_origin_allows_origin() { fn validate_origin_allows_origin() {
let url = "https://www.example.com"; let url = "https://www.example.com";
let Ok(origin) = Origin::from_str(url); let origin = not_err!(to_origin(&url));
let origin = not_err!(to_origin(&origin));
let allowed_origins = not_err!(parse_origins(&AllowedOrigins::some(&[ let allowed_origins = not_err!(parse_origins(&AllowedOrigins::some(&[
"https://www.example.com" "https://www.example.com"
]))); ])));
@ -1728,8 +1722,7 @@ mod tests {
#[should_panic(expected = "OriginNotAllowed")] #[should_panic(expected = "OriginNotAllowed")]
fn validate_origin_rejects_invalid_origin() { fn validate_origin_rejects_invalid_origin() {
let url = "https://www.acme.com"; let url = "https://www.acme.com";
let Ok(origin) = Origin::from_str(url); let origin = not_err!(to_origin(&url));
let origin = not_err!(to_origin(&origin));
let allowed_origins = not_err!(parse_origins(&AllowedOrigins::some(&[ let allowed_origins = not_err!(parse_origins(&AllowedOrigins::some(&[
"https://www.example.com" "https://www.example.com"
]))); ])));