Refactor Origin
This commit is contained in:
parent
f9bffe77d6
commit
bc16568e8b
|
@ -1,6 +1,6 @@
|
|||
[package]
|
||||
name = "rocket_cors"
|
||||
version = "0.4.0"
|
||||
version = "0.5.0"
|
||||
license = "MIT/Apache-2.0"
|
||||
authors = ["Yong Wen Chua <me@yongwen.xyz>"]
|
||||
description = "Cross-origin resource sharing (CORS) for Rocket.rs applications"
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
# rocket_cors
|
||||
|
||||
[![Build Status](https://travis-ci.org/lawliet89/rocket_cors.svg)](https://travis-ci.org/lawliet89/rocket_cors)
|
||||
[![Dependency Status](https://dependencyci.com/github/lawliet89/rocket_cors/badge)](https://dependencyci.com/github/lawliet89/rocket_cors)
|
||||
[![Repository](https://img.shields.io/github/tag/lawliet89/rocket_cors.svg)](https://github.com/lawliet89/rocket_cors)
|
||||
[![Crates.io](https://img.shields.io/crates/v/rocket_cors.svg)](https://crates.io/crates/rocket_cors)
|
||||
|
||||
|
@ -31,7 +30,7 @@ work, but they are subject to the minimum that Rocket sets.
|
|||
Add the following to Cargo.toml:
|
||||
|
||||
```toml
|
||||
rocket_cors = "0.4.0"
|
||||
rocket_cors = "0.5.0"
|
||||
```
|
||||
|
||||
To use the latest `master` branch, for example:
|
||||
|
|
|
@ -63,6 +63,7 @@ fn on_response_wrapper(
|
|||
return Ok(());
|
||||
}
|
||||
|
||||
let origin = origin.to_string();
|
||||
let cors_response = if request.method() == http::Method::Options {
|
||||
let headers = request_headers(request)?;
|
||||
preflight_response(options, &origin, headers.as_ref())
|
||||
|
|
|
@ -63,28 +63,34 @@ pub type HeaderFieldNamesSet = HashSet<HeaderFieldName>;
|
|||
///
|
||||
/// You can use this as a rocket [Request Guard](https://rocket.rs/guide/requests/#request-guards)
|
||||
/// to ensure that `Origin` is passed in correctly.
|
||||
///
|
||||
/// Reference: [Mozilla](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Origin)
|
||||
#[derive(Eq, PartialEq, Clone, Hash, Debug)]
|
||||
pub struct Origin(pub url::Origin);
|
||||
pub enum Origin {
|
||||
/// A `null` Origin
|
||||
Null,
|
||||
/// A well-formed origin that was parsed by [`url::Url::origin`]
|
||||
Parsed(url::Origin),
|
||||
}
|
||||
|
||||
impl FromStr for Origin {
|
||||
type Err = crate::Error;
|
||||
|
||||
fn from_str(input: &str) -> Result<Self, Self::Err> {
|
||||
Ok(Origin(crate::to_origin(input)?))
|
||||
}
|
||||
}
|
||||
|
||||
impl Deref for Origin {
|
||||
type Target = url::Origin;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.0
|
||||
if input.to_lowercase() == "null" {
|
||||
Ok(Origin::Null)
|
||||
} else {
|
||||
Ok(Origin::Parsed(crate::to_origin(input)?))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for Origin {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
write!(f, "{}", self.ascii_serialization())
|
||||
match self {
|
||||
Origin::Null => write!(f, "null"),
|
||||
Origin::Parsed(ref parsed) => write!(f, "{}", parsed.ascii_serialization()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -195,6 +201,10 @@ mod tests {
|
|||
let parsed = not_err!(Origin::from_str(url));
|
||||
assert_eq!(parsed.ascii_serialization(), url);
|
||||
|
||||
let url = "https://foo.bar.xyz:1234";
|
||||
let parsed = not_err!(Origin::from_str(url));
|
||||
assert_eq!(parsed.ascii_serialization(), url);
|
||||
|
||||
// this should never really be sent by a compliant user agent
|
||||
let url = "https://foo.bar.xyz/path/somewhere";
|
||||
let parsed = not_err!(Origin::from_str(url));
|
||||
|
|
153
src/lib.rs
153
src/lib.rs
|
@ -1,6 +1,5 @@
|
|||
/*!
|
||||
[![Build Status](https://travis-ci.org/lawliet89/rocket_cors.svg)](https://travis-ci.org/lawliet89/rocket_cors)
|
||||
[![Dependency Status](https://dependencyci.com/github/lawliet89/rocket_cors/badge)](https://dependencyci.com/github/lawliet89/rocket_cors)
|
||||
[![Repository](https://img.shields.io/github/tag/lawliet89/rocket_cors.svg)](https://github.com/lawliet89/rocket_cors)
|
||||
[![Crates.io](https://img.shields.io/crates/v/rocket_cors.svg)](https://crates.io/crates/rocket_cors)
|
||||
|
||||
|
@ -30,7 +29,7 @@ might work, but they are subject to the minimum that Rocket sets.
|
|||
Add the following to Cargo.toml:
|
||||
|
||||
```toml
|
||||
rocket_cors = "0.4.0"
|
||||
rocket_cors = "0.5.0"
|
||||
```
|
||||
|
||||
To use the latest `master` branch, for example:
|
||||
|
@ -46,7 +45,7 @@ the [`CorsOptions`] struct that is described below. If you would like to disable
|
|||
change your `Cargo.toml` to:
|
||||
|
||||
```toml
|
||||
rocket_cors = { version = "0.4.0", default-features = false }
|
||||
rocket_cors = { version = "0.5.0", default-features = false }
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
@ -316,7 +315,7 @@ pub enum Error {
|
|||
/// The request header `Access-Control-Request-Headers` is required but is missing.
|
||||
MissingRequestHeaders,
|
||||
/// Origin is not allowed to make this request
|
||||
OriginNotAllowed(url::Origin),
|
||||
OriginNotAllowed(String),
|
||||
/// Requested method is not allowed
|
||||
MethodNotAllowed(String),
|
||||
/// One or more headers requested are not allowed
|
||||
|
@ -365,7 +364,7 @@ impl fmt::Display for Error {
|
|||
"The request header `Access-Control-Request-Headers` \
|
||||
is required but is missing")
|
||||
}
|
||||
Error::OriginNotAllowed(origin) => write!(f, "Origin '{}' is not allowed to request", origin.ascii_serialization()),
|
||||
Error::OriginNotAllowed(origin) => write!(f, "Origin '{}' is not allowed to request", origin),
|
||||
Error::MethodNotAllowed(method) => write!(f, "Method '{}' is not allowed", &method),
|
||||
Error::HeadersNotAllowed => write!(f, "Headers are not allowed"),
|
||||
Error::CredentialsWithWildcardOrigin => { write!(f,
|
||||
|
@ -529,16 +528,44 @@ mod method_serde {
|
|||
/// use rocket_cors::AllowedOrigins;
|
||||
///
|
||||
/// let all_origins = AllowedOrigins::all();
|
||||
/// let some_origins = AllowedOrigins::some(&["https://www.acme.com"]);
|
||||
/// let some_origins = AllowedOrigins::some_exact(&["https://www.acme.com"]);
|
||||
/// let null_origins = AllowedOrigins::some_null();
|
||||
/// ```
|
||||
pub type AllowedOrigins = AllOrSome<HashSet<String>>;
|
||||
pub type AllowedOrigins = AllOrSome<Origins>;
|
||||
|
||||
impl AllowedOrigins {
|
||||
/// Allows some origins
|
||||
/// Allows some _exact_ origins
|
||||
///
|
||||
/// Validation is not performed at this stage, but at a later stage.
|
||||
#[deprecated(since = "0.5.0", note = "use `some_exact` instead")]
|
||||
pub fn some(urls: &[&str]) -> Self {
|
||||
AllOrSome::Some(urls.iter().map(|s| s.to_string()).collect())
|
||||
Self::some_exact(urls)
|
||||
}
|
||||
|
||||
/// Allows some _exact_ origins
|
||||
///
|
||||
/// Validation is not performed at this stage, but at a later stage.
|
||||
pub fn some_exact<S: AsRef<str>>(urls: &[S]) -> Self {
|
||||
AllOrSome::Some(Origins {
|
||||
exact: Some(urls.iter().map(|s| s.as_ref().to_string()).collect()),
|
||||
..Default::default()
|
||||
})
|
||||
}
|
||||
|
||||
/// Allow some __regex__ origins
|
||||
pub fn some_regex<S: AsRef<str>>(regex: &[S]) -> Self {
|
||||
AllOrSome::Some(Origins {
|
||||
regex: Some(regex.iter().map(|s| s.as_ref().to_string()).collect()),
|
||||
..Default::default()
|
||||
})
|
||||
}
|
||||
|
||||
/// Allow some `null` origins
|
||||
pub fn some_null() -> Self {
|
||||
AllOrSome::Some(Origins {
|
||||
allow_null: true,
|
||||
..Default::default()
|
||||
})
|
||||
}
|
||||
|
||||
/// Allows all origins
|
||||
|
@ -547,6 +574,53 @@ impl AllowedOrigins {
|
|||
}
|
||||
}
|
||||
|
||||
/// A list of allows origins
|
||||
///
|
||||
/// An origin is defined according
|
||||
/// [syntax](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Origin) defined here.
|
||||
///
|
||||
/// Origins can be specified as an exact match or via some other supported way according to the
|
||||
/// fields of the struct.
|
||||
///
|
||||
/// These Origins are specified as logical `ORs`. That is, if any of the origins match, the entire
|
||||
/// request is considered to be valid.
|
||||
#[derive(Clone, PartialEq, Eq, Debug, Default)]
|
||||
#[cfg_attr(feature = "serialization", derive(Serialize, Deserialize))]
|
||||
#[cfg_attr(feature = "serialization", serde(default))]
|
||||
pub struct Origins {
|
||||
/// Whether null origins are accepted
|
||||
#[cfg_attr(feature = "serialization", serde(default))]
|
||||
pub allow_null: bool,
|
||||
/// Origins that must be matched exactly as below. These __must__ be valid URL strings that will
|
||||
/// be parsed and validated when creating [`Cors`].
|
||||
#[cfg_attr(feature = "serialization", serde(default))]
|
||||
pub exact: Option<HashSet<String>>,
|
||||
/// Origins that will be matched via __any__ regex in this list. These __must__ be valid Regex
|
||||
/// that will be parsed and validated when creating [`Cors`].
|
||||
#[cfg_attr(feature = "serialization", serde(default))]
|
||||
pub regex: Option<HashSet<String>>,
|
||||
}
|
||||
|
||||
/// Parsed set of configured allowed origins
|
||||
#[derive(Clone, Debug, Eq, PartialEq)]
|
||||
pub(crate) struct ParsedAllowedOrigins {
|
||||
pub allow_null: bool,
|
||||
pub exact: HashSet<url::Origin>,
|
||||
}
|
||||
|
||||
impl ParsedAllowedOrigins {
|
||||
fn parse(origins: &Origins) -> Result<Self, Error> {
|
||||
let exact: Result<_, Error> = match &origins.exact {
|
||||
Some(exact) => exact.iter().map(|url| to_origin(url.as_str())).collect(),
|
||||
None => Ok(Default::default()),
|
||||
};
|
||||
Ok(Self {
|
||||
allow_null: origins.allow_null,
|
||||
exact: exact?,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// A list of allowed methods
|
||||
///
|
||||
/// The [list](https://api.rocket.rs/rocket/http/enum.Method.html)
|
||||
|
@ -833,7 +907,7 @@ impl CorsOptions {
|
|||
/// This struct can be created by using [`CorsOptions::to_cors`] or [`Cors::from_options`].
|
||||
#[derive(Clone, Debug, Eq, PartialEq)]
|
||||
pub struct Cors {
|
||||
pub(crate) allowed_origins: AllOrSome<HashSet<url::Origin>>,
|
||||
pub(crate) allowed_origins: AllOrSome<ParsedAllowedOrigins>,
|
||||
pub(crate) allowed_methods: AllowedMethods,
|
||||
pub(crate) allowed_headers: AllOrSome<HashSet<HeaderFieldName>>,
|
||||
pub(crate) allow_credentials: bool,
|
||||
|
@ -921,7 +995,7 @@ impl Cors {
|
|||
/// You can get this struct by using `Cors::validate_request` in an ad-hoc manner.
|
||||
#[derive(Eq, PartialEq, Debug)]
|
||||
pub(crate) struct Response {
|
||||
allow_origin: Option<AllOrSome<url::Origin>>,
|
||||
allow_origin: Option<AllOrSome<String>>,
|
||||
allow_methods: HashSet<Method>,
|
||||
allow_headers: HeaderFieldNamesSet,
|
||||
allow_credentials: bool,
|
||||
|
@ -945,8 +1019,8 @@ impl Response {
|
|||
}
|
||||
|
||||
/// Consumes the `Response` and return an altered response with origin and `vary_origin` set
|
||||
fn origin(mut self, origin: &url::Origin, vary_origin: bool) -> Self {
|
||||
self.allow_origin = Some(AllOrSome::Some(origin.clone()));
|
||||
fn origin(mut self, origin: &str, vary_origin: bool) -> Self {
|
||||
self.allow_origin = Some(AllOrSome::Some(origin.to_string()));
|
||||
self.vary_origin = vary_origin;
|
||||
self
|
||||
}
|
||||
|
@ -1022,7 +1096,7 @@ impl Response {
|
|||
|
||||
let origin = match *origin {
|
||||
AllOrSome::All => "*".to_string(),
|
||||
AllOrSome::Some(ref origin) => origin.ascii_serialization(),
|
||||
AllOrSome::Some(ref origin) => origin.to_string(),
|
||||
};
|
||||
|
||||
let _ = response.set_raw_header("Access-Control-Allow-Origin", origin);
|
||||
|
@ -1251,11 +1325,11 @@ enum ValidationResult {
|
|||
None,
|
||||
/// Successful preflight request
|
||||
Preflight {
|
||||
origin: url::Origin,
|
||||
origin: String,
|
||||
headers: Option<AccessControlRequestHeaders>,
|
||||
},
|
||||
/// Successful actual request
|
||||
Request { origin: url::Origin },
|
||||
Request { origin: String },
|
||||
}
|
||||
|
||||
/// Convert a str to Origin
|
||||
|
@ -1265,13 +1339,12 @@ fn to_origin<S: AsRef<str>>(origin: S) -> Result<url::Origin, Error> {
|
|||
}
|
||||
|
||||
/// Parse and process allowed origins
|
||||
fn parse_origins(origins: &AllowedOrigins) -> Result<AllOrSome<HashSet<url::Origin>>, Error> {
|
||||
fn parse_origins(origins: &AllowedOrigins) -> Result<AllOrSome<ParsedAllowedOrigins>, Error> {
|
||||
match origins {
|
||||
AllOrSome::All => Ok(AllOrSome::All),
|
||||
AllOrSome::Some(ref origins) => {
|
||||
let parsed: Result<HashSet<url::Origin>, Error> =
|
||||
origins.iter().map(to_origin).collect();
|
||||
Ok(AllOrSome::Some(parsed?))
|
||||
AllOrSome::Some(origins) => {
|
||||
let parsed = ParsedAllowedOrigins::parse(origins)?;
|
||||
Ok(AllOrSome::Some(parsed))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1309,14 +1382,14 @@ fn validate(options: &Cors, request: &Request<'_>) -> Result<ValidationResult, E
|
|||
let headers = request_headers(request)?;
|
||||
preflight_validate(options, &origin, &method, &headers)?;
|
||||
Ok(ValidationResult::Preflight {
|
||||
origin: origin.deref().clone(),
|
||||
origin: origin.to_string(),
|
||||
headers,
|
||||
})
|
||||
}
|
||||
_ => {
|
||||
actual_request_validate(options, &origin)?;
|
||||
Ok(ValidationResult::Request {
|
||||
origin: origin.deref().clone(),
|
||||
origin: origin.to_string(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -1326,16 +1399,30 @@ fn validate(options: &Cors, request: &Request<'_>) -> Result<ValidationResult, E
|
|||
/// check if the requested origin is allowed.
|
||||
/// Useful for pre-flight and during requests
|
||||
fn validate_origin(
|
||||
origin: &url::Origin,
|
||||
allowed_origins: &AllOrSome<HashSet<url::Origin>>,
|
||||
origin: &Origin,
|
||||
allowed_origins: &AllOrSome<ParsedAllowedOrigins>,
|
||||
) -> Result<(), Error> {
|
||||
match *allowed_origins {
|
||||
// Always matching is acceptable since the list of origins can be unbounded.
|
||||
AllOrSome::All => Ok(()),
|
||||
AllOrSome::Some(ref allowed_origins) => allowed_origins
|
||||
.get(origin)
|
||||
.and_then(|_| Some(()))
|
||||
.ok_or_else(|| Error::OriginNotAllowed(origin.clone())),
|
||||
// AllOrSome::Some(ref allowed_origins) => allowed_origins
|
||||
// .get(origin)
|
||||
// .and_then(|_| Some(()))
|
||||
// .ok_or_else(|| Error::OriginNotAllowed(origin.clone())),
|
||||
AllOrSome::Some(ref allowed_origins) => match origin {
|
||||
Origin::Null => {
|
||||
if allowed_origins.allow_null {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(Error::OriginNotAllowed(origin.to_string()))
|
||||
}
|
||||
}
|
||||
Origin::Parsed(ref parsed) => allowed_origins
|
||||
.exact
|
||||
.get(parsed)
|
||||
.and_then(|_| Some(()))
|
||||
.ok_or_else(|| Error::OriginNotAllowed(origin.to_string())),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1405,7 +1492,7 @@ fn request_headers(request: &Request<'_>) -> Result<Option<AccessControlRequestH
|
|||
/// and [Fetch specification](https://fetch.spec.whatwg.org/#cors-preflight-fetch)
|
||||
fn preflight_validate(
|
||||
options: &Cors,
|
||||
origin: &url::Origin,
|
||||
origin: &Origin,
|
||||
method: &Option<AccessControlRequestMethod>,
|
||||
headers: &Option<AccessControlRequestHeaders>,
|
||||
) -> Result<(), Error> {
|
||||
|
@ -1453,7 +1540,7 @@ fn preflight_validate(
|
|||
/// and [Fetch specification](https://fetch.spec.whatwg.org/#cors-preflight-fetch).
|
||||
fn preflight_response(
|
||||
options: &Cors,
|
||||
origin: &url::Origin,
|
||||
origin: &str,
|
||||
headers: Option<&AccessControlRequestHeaders>,
|
||||
) -> Response {
|
||||
let response = Response::new();
|
||||
|
@ -1524,7 +1611,7 @@ fn preflight_response(
|
|||
/// This implementation references the
|
||||
/// [W3C recommendation](https://www.w3.org/TR/cors/#resource-requests)
|
||||
/// and [Fetch specification](https://fetch.spec.whatwg.org/#cors-preflight-fetch).
|
||||
fn actual_request_validate(options: &Cors, origin: &url::Origin) -> Result<(), Error> {
|
||||
fn actual_request_validate(options: &Cors, origin: &Origin) -> Result<(), Error> {
|
||||
// Note: All header parse failures are dealt with in the `FromRequest` trait implementation
|
||||
|
||||
// 2. If the value of the Origin header is not a case-sensitive match for any of the values
|
||||
|
@ -1541,7 +1628,7 @@ fn actual_request_validate(options: &Cors, origin: &url::Origin) -> Result<(), E
|
|||
/// This implementation references the
|
||||
/// [W3C recommendation](https://www.w3.org/TR/cors/#resource-requests)
|
||||
/// and [Fetch specification](https://fetch.spec.whatwg.org/#cors-preflight-fetch)
|
||||
fn actual_request_response(options: &Cors, origin: &url::Origin) -> Response {
|
||||
fn actual_request_response(options: &Cors, origin: &str) -> Response {
|
||||
let response = Response::new();
|
||||
|
||||
// 3. If the resource supports credentials add a single Access-Control-Allow-Origin header,
|
||||
|
|
Loading…
Reference in New Issue