Ad-hoc response now use response guards
This commit is contained in:
parent
6746e835e7
commit
f45fa4df04
215
src/lib.rs
215
src/lib.rs
|
@ -161,6 +161,22 @@ pub enum Error {
|
|||
///
|
||||
/// This is a misconfiguration. Check the docuemntation for `Cors`.
|
||||
CredentialsWithWildcardOrigin,
|
||||
/// A CORS Request Guard was used, but no CORS Options was available in Rocket's state
|
||||
///
|
||||
/// This is a misconfiguration. Use `Rocket::manage` to add a CORS options to managed state.
|
||||
MissingCorsInRocketState,
|
||||
}
|
||||
|
||||
impl Error {
|
||||
fn status(&self) -> Status {
|
||||
match *self {
|
||||
Error::MissingOrigin | Error::OriginNotAllowed | Error::MethodNotAllowed |
|
||||
Error::HeadersNotAllowed => Status::Forbidden,
|
||||
Error::CredentialsWithWildcardOrigin |
|
||||
Error::MissingCorsInRocketState => Status::InternalServerError,
|
||||
_ => Status::BadRequest,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl error::Error for Error {
|
||||
|
@ -186,6 +202,9 @@ impl error::Error for Error {
|
|||
"Credentials are allowed, but the Origin is set to \"*\". \
|
||||
This is not allowed by W3C"
|
||||
}
|
||||
Error::MissingCorsInRocketState => {
|
||||
"A CORS Request Guard was used, but no CORS Options was available in Rocket's state"
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
@ -211,19 +230,14 @@ impl fmt::Display for Error {
|
|||
impl<'r> response::Responder<'r> for Error {
|
||||
fn respond_to(self, _: &Request) -> Result<response::Response<'r>, Status> {
|
||||
error_!("CORS Error: {:?}", self);
|
||||
Err(match self {
|
||||
Error::MissingOrigin | Error::OriginNotAllowed | Error::MethodNotAllowed |
|
||||
Error::HeadersNotAllowed => Status::Forbidden,
|
||||
Error::CredentialsWithWildcardOrigin => Status::InternalServerError,
|
||||
_ => Status::BadRequest,
|
||||
})
|
||||
Err(self.status())
|
||||
}
|
||||
}
|
||||
|
||||
/// An enum signifying that some of type T is allowed, or `All` (everything is allowed).
|
||||
///
|
||||
/// `Default` is implemented for this enum and is `All`.
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
|
||||
#[serde(untagged)]
|
||||
pub enum AllOrSome<T> {
|
||||
/// Everything is allowed. Usually equivalent to the "*" value.
|
||||
|
@ -274,9 +288,9 @@ impl AllOrSome<HashSet<Url>> {
|
|||
}
|
||||
}
|
||||
|
||||
/// Responder generator and [Fairing](https://rocket.rs/guide/fairings/) for CORS
|
||||
/// Response generator and [Fairing](https://rocket.rs/guide/fairings/) for CORS
|
||||
///
|
||||
/// This struct can be used as Fairing for Rocket, or as an ad-hoc responder for any CORS requests.
|
||||
/// This struct can be as Fairing or in an ad-hoc manner to generate CORS response.
|
||||
///
|
||||
/// You create a new copy of this struct by defining the configurations in the fields below.
|
||||
/// This struct can also be deserialized by serde.
|
||||
|
@ -381,15 +395,6 @@ impl Default for Cors {
|
|||
}
|
||||
|
||||
impl Cors {
|
||||
/// Wrap any `Rocket::Response` and respond with CORS headers.
|
||||
/// This is only used for ad-hoc route CORS response
|
||||
fn respond<'a, 'r: 'a, R: response::Responder<'r>>(
|
||||
&'a self,
|
||||
responder: R,
|
||||
) -> Responder<'a, 'r, R> {
|
||||
Responder::new(responder, self)
|
||||
}
|
||||
|
||||
fn default_allowed_methods() -> HashSet<Method> {
|
||||
vec![
|
||||
Method::Get,
|
||||
|
@ -403,6 +408,17 @@ impl Cors {
|
|||
.collect()
|
||||
}
|
||||
|
||||
/// Build a CORS `Response` to an incoming request.
|
||||
///
|
||||
/// The `Response` should be merged with an
|
||||
/// existing `Rocket::Response` or `rocket::response::Responder`.
|
||||
///
|
||||
/// This is only used for ad-hoc route CORS response
|
||||
pub fn build<'a, 'r>(&'a self, request: &'a Request<'r>) -> Result<Response, Error> {
|
||||
build_cors_response(self, request)
|
||||
}
|
||||
|
||||
|
||||
/// Validates if any of the settings are disallowed or incorrect
|
||||
///
|
||||
/// This is run during initial Fairing attachment
|
||||
|
@ -473,61 +489,6 @@ impl fairing::Fairing for Cors {
|
|||
}
|
||||
}
|
||||
|
||||
/// A CORS [Responder](https://rocket.rs/guide/responses/#responder)
|
||||
/// which will inspect the incoming requests and respond accordingly.
|
||||
///
|
||||
/// If the wrapped `Responder` already has the `Access-Control-Allow-Origin` header set,
|
||||
/// this responder will leave the response untouched.
|
||||
/// This allows for chaining of several CORS responders.
|
||||
///
|
||||
/// Otherwise, the following headers may be set for the final Rocket `Response`, overwriting any
|
||||
/// existing headers defined:
|
||||
///
|
||||
/// - `Access-Control-Allow-Origin`
|
||||
/// - `Access-Control-Expose-Headers`
|
||||
/// - `Access-Control-Max-Age`
|
||||
/// - `Access-Control-Allow-Credentials`
|
||||
/// - `Access-Control-Allow-Methods`
|
||||
/// - `Access-Control-Allow-Headers`
|
||||
/// - `Vary`
|
||||
#[derive(Debug)]
|
||||
pub struct Responder<'a, 'r: 'a, R> {
|
||||
responder: R,
|
||||
options: &'a Cors,
|
||||
marker: PhantomData<response::Responder<'r>>,
|
||||
}
|
||||
|
||||
impl<'a, 'r: 'a, R: response::Responder<'r>> Responder<'a, 'r, R> {
|
||||
fn new(responder: R, options: &'a Cors) -> Self {
|
||||
Self {
|
||||
responder,
|
||||
options,
|
||||
marker: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
/// Respond to a request
|
||||
fn respond(self, request: &Request) -> response::Result<'r> {
|
||||
let mut response = self.responder.respond_to(request)?; // handle status errors?
|
||||
|
||||
match build_cors_response(self.options, request) {
|
||||
Ok(cors_response) => {
|
||||
cors_response.merge(&mut response);
|
||||
Ok(response)
|
||||
},
|
||||
Err(e) => response::Responder::respond_to(e, request),
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, 'r: 'a, R: response::Responder<'r>> response::Responder<'r> for Responder<'a, 'r, R> {
|
||||
fn respond_to(self, request: &Request) -> response::Result<'r> {
|
||||
self.respond(request)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/// A CORS Response which provides the following CORS headers:
|
||||
///
|
||||
/// - `Access-Control-Allow-Origin`
|
||||
|
@ -537,8 +498,8 @@ impl<'a, 'r: 'a, R: response::Responder<'r>> response::Responder<'r> for Respond
|
|||
/// - `Access-Control-Allow-Methods`
|
||||
/// - `Access-Control-Allow-Headers`
|
||||
/// - `Vary`
|
||||
#[derive(Debug)]
|
||||
struct Response {
|
||||
#[derive(Eq, PartialEq, Debug)]
|
||||
pub struct Response {
|
||||
allow_origin: Option<AllOrSome<String>>,
|
||||
allow_methods: HashSet<Method>,
|
||||
allow_headers: HeaderFieldNamesSet,
|
||||
|
@ -549,7 +510,7 @@ struct Response {
|
|||
}
|
||||
|
||||
impl Response {
|
||||
/// Consumes the responder and return an empty `Response`
|
||||
/// Create an empty `Response`
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
allow_origin: None,
|
||||
|
@ -609,11 +570,17 @@ impl Response {
|
|||
self
|
||||
}
|
||||
|
||||
/// Builds a `rocket::Response` from this struct based off some base `rocket::Response`
|
||||
/// Consumes the `Response` and return a `Responder` that wraps a
|
||||
/// provided `rocket:response::Responder` with CORS headers
|
||||
pub fn responder<'r, R: response::Responder<'r>>(self, responder: R) -> Responder<'r, R> {
|
||||
Responder::new(responder, self)
|
||||
}
|
||||
|
||||
/// Merge a `rocket::Response` with this CORS response. This is usually used in the final step
|
||||
/// of a route to return a value for the route.
|
||||
///
|
||||
/// This will overwrite any existing CORS headers
|
||||
#[cfg(test)]
|
||||
fn build<'r>(&self, base: response::Response<'r>) -> response::Response<'r> {
|
||||
pub fn respond<'r>(&self, base: response::Response<'r>) -> response::Response<'r> {
|
||||
let mut response = response::Response::build_from(base).finalize();
|
||||
self.merge(&mut response);
|
||||
response
|
||||
|
@ -691,27 +658,79 @@ impl Response {
|
|||
response.remove_header("Vary");
|
||||
}
|
||||
}
|
||||
|
||||
/// Validate and create a new CORS Response from a request and settings
|
||||
pub fn build_cors_response<'a, 'r>(
|
||||
options: &'a Cors,
|
||||
request: &'a Request<'r>,
|
||||
) -> Result<Self, Error> {
|
||||
build_cors_response(options, request)
|
||||
}
|
||||
}
|
||||
|
||||
/// Ad-hoc per route CORS response to requests
|
||||
impl<'a, 'r> FromRequest<'a, 'r> for Response {
|
||||
type Error = Error;
|
||||
|
||||
fn from_request(request: &'a Request<'r>) -> rocket::request::Outcome<Self, Self::Error> {
|
||||
let options = match request.guard::<State<Cors>>() {
|
||||
Outcome::Success(options) => options,
|
||||
_ => {
|
||||
let error = Error::MissingCorsInRocketState;
|
||||
return Outcome::Failure((error.status(), error));
|
||||
}
|
||||
};
|
||||
|
||||
match Self::build_cors_response(&options, request) {
|
||||
Ok(response) => Outcome::Success(response),
|
||||
Err(error) => Outcome::Failure((error.status(), error)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A [`Responder`](https://rocket.rs/guide/responses/#responder) which will simply wraps another
|
||||
/// `Responder` with CORS headers.
|
||||
///
|
||||
/// Note: If you use this, the lifetime parameter `'r` of your `rocket:::response::Responder<'r>`
|
||||
/// CANNOT be `'static`. This is because the code generated by Rocket will implicitly try to
|
||||
/// to restrain the `Request` object passed to the route to `&'static Request`, and it is not
|
||||
/// possible to have such a reference.
|
||||
/// See [this PR on Rocket](https://github.com/SergioBenitez/Rocket/pull/345).
|
||||
pub fn respond<'a, 'r: 'a, R: response::Responder<'r>>(
|
||||
options: State<'a, Cors>,
|
||||
/// The following CORS headers will be overwritten:
|
||||
///
|
||||
/// - `Access-Control-Allow-Origin`
|
||||
/// - `Access-Control-Expose-Headers`
|
||||
/// - `Access-Control-Max-Age`
|
||||
/// - `Access-Control-Allow-Credentials`
|
||||
/// - `Access-Control-Allow-Methods`
|
||||
/// - `Access-Control-Allow-Headers`
|
||||
/// - `Vary`
|
||||
#[derive(Debug)]
|
||||
pub struct Responder<'r, R> {
|
||||
responder: R,
|
||||
) -> Responder<'a, 'r, R> {
|
||||
options.inner().respond(responder)
|
||||
cors_response: Response,
|
||||
marker: PhantomData<response::Responder<'r>>,
|
||||
}
|
||||
|
||||
impl<'r, R: response::Responder<'r>> Responder<'r, R> {
|
||||
fn new(responder: R, cors_response: Response) -> Self {
|
||||
Self {
|
||||
responder,
|
||||
cors_response,
|
||||
marker: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
/// Respond to a request
|
||||
fn respond(self, request: &Request) -> response::Result<'r> {
|
||||
let mut response = self.responder.respond_to(request)?; // handle status errors?
|
||||
self.cors_response.merge(&mut response);
|
||||
Ok(response)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'r, R: response::Responder<'r>> response::Responder<'r> for Responder<'r, R> {
|
||||
fn respond_to(self, request: &Request) -> response::Result<'r> {
|
||||
self.respond(request)
|
||||
}
|
||||
}
|
||||
|
||||
/// Build a CORS response and merge with an existing `rocket::Response` for the request
|
||||
fn build_cors_response(
|
||||
options: &Cors,
|
||||
request: &Request,
|
||||
) -> Result<Response, Error> {
|
||||
fn build_cors_response(options: &Cors, request: &Request) -> Result<Response, Error> {
|
||||
// Existing CORS response?
|
||||
// if has_allow_origin(response) {
|
||||
// return Ok(());
|
||||
|
@ -1072,7 +1091,7 @@ mod tests {
|
|||
let response = response.exposed_headers(&headers);
|
||||
|
||||
// Build response and check built response header
|
||||
let response = response.build(response::Response::new());
|
||||
let response = response.respond(response::Response::new());
|
||||
let actual_header: Vec<_> = response
|
||||
.headers()
|
||||
.get("Access-Control-Expose-Headers")
|
||||
|
@ -1096,7 +1115,7 @@ mod tests {
|
|||
|
||||
// Build response and check built response header
|
||||
let expected_header = vec!["42"];
|
||||
let response = response.build(response::Response::new());
|
||||
let response = response.respond(response::Response::new());
|
||||
let actual_header: Vec<_> = response.headers().get("Access-Control-Max-Age").collect();
|
||||
assert_eq!(expected_header, actual_header);
|
||||
}
|
||||
|
@ -1109,7 +1128,7 @@ mod tests {
|
|||
let response = response.max_age(None);
|
||||
|
||||
// Build response and check built response header
|
||||
let response = response.build(response::Response::new());
|
||||
let response = response.respond(response::Response::new());
|
||||
assert!(
|
||||
response
|
||||
.headers()
|
||||
|
@ -1198,7 +1217,7 @@ mod tests {
|
|||
#[test]
|
||||
fn response_does_not_build_if_origin_is_not_set() {
|
||||
let response = Response::new();
|
||||
let response = response.build(response::Response::new());
|
||||
let response = response.respond(response::Response::new());
|
||||
|
||||
let headers: Vec<_> = response.headers().iter().collect();
|
||||
assert_eq!(headers.len(), 0);
|
||||
|
@ -1217,7 +1236,7 @@ mod tests {
|
|||
|
||||
let response = Response::new();
|
||||
let response = response.origin("https://www.example.com", false);
|
||||
let response = response.build(original);
|
||||
let response = response.respond(original);
|
||||
// Check CORS header
|
||||
let expected_header = vec!["https://www.example.com"];
|
||||
let actual_header: Vec<_> = response
|
||||
|
|
|
@ -4,34 +4,33 @@
|
|||
#![plugin(rocket_codegen)]
|
||||
extern crate hyper;
|
||||
extern crate rocket;
|
||||
extern crate rocket_cors;
|
||||
extern crate rocket_cors as cors;
|
||||
|
||||
use std::str::FromStr;
|
||||
|
||||
use rocket::State;
|
||||
use rocket::http::Method;
|
||||
use rocket::http::{Header, Status};
|
||||
use rocket::local::Client;
|
||||
use rocket_cors::*;
|
||||
|
||||
#[options("/")]
|
||||
fn cors_options(options: State<rocket_cors::Cors>) -> Responder<&str> {
|
||||
rocket_cors::respond(options, "")
|
||||
fn cors_options<'a>(cors: cors::Response) -> cors::Responder<'a, &'a str> {
|
||||
cors.responder("")
|
||||
}
|
||||
|
||||
#[get("/")]
|
||||
fn cors(options: State<rocket_cors::Cors>) -> Responder<&str> {
|
||||
rocket_cors::respond(options, "Hello CORS")
|
||||
fn cors<'a>(cors: cors::Response) -> cors::Responder<'a, &'a str> {
|
||||
cors.responder("Hello CORS")
|
||||
}
|
||||
|
||||
fn make_cors_options() -> Cors {
|
||||
let (allowed_origins, failed_origins) = AllOrSome::new_from_str_list(&["https://www.acme.com"]);
|
||||
fn make_cors_options() -> cors::Cors {
|
||||
let (allowed_origins, failed_origins) =
|
||||
cors::AllOrSome::new_from_str_list(&["https://www.acme.com"]);
|
||||
assert!(failed_origins.is_empty());
|
||||
|
||||
Cors {
|
||||
cors::Cors {
|
||||
allowed_origins: allowed_origins,
|
||||
allowed_methods: [Method::Get].iter().cloned().collect(),
|
||||
allowed_headers: AllOrSome::Some(
|
||||
allowed_headers: cors::AllOrSome::Some(
|
||||
["Authorization"]
|
||||
.into_iter()
|
||||
.map(|s| s.to_string().into())
|
||||
|
@ -44,12 +43,13 @@ fn make_cors_options() -> Cors {
|
|||
|
||||
#[test]
|
||||
fn smoke_test() {
|
||||
let (allowed_origins, failed_origins) = AllOrSome::new_from_str_list(&["https://www.acme.com"]);
|
||||
let (allowed_origins, failed_origins) =
|
||||
cors::AllOrSome::new_from_str_list(&["https://www.acme.com"]);
|
||||
assert!(failed_origins.is_empty());
|
||||
let cors_options = rocket_cors::Cors {
|
||||
let cors_options = cors::Cors {
|
||||
allowed_origins: allowed_origins,
|
||||
allowed_methods: [Method::Get].iter().cloned().collect(),
|
||||
allowed_headers: AllOrSome::Some(
|
||||
allowed_headers: cors::AllOrSome::Some(
|
||||
["Authorization"]
|
||||
.iter()
|
||||
.map(|s| s.to_string().into())
|
Loading…
Reference in New Issue