Use async version from rocket's master branch

* switch rocket version to master branch
  (use release version once async is available)
* adapt code to incorporate changes from rocket and hyper

Co-authored-by: Maximilian Köstler <maximilian@koestler.hamburg>
This commit is contained in:
Henning Holm 2020-08-11 10:36:40 +02:00
parent 7e13b63313
commit 2046f3c7c0
13 changed files with 690 additions and 655 deletions

View File

@ -22,7 +22,7 @@ serialization = ["serde", "serde_derive", "unicase_serde"]
[dependencies]
regex = "1.1"
rocket = { version = "0.4.2", default-features = false }
rocket = { git="https://github.com/SergioBenitez/Rocket.git", default-features = false }
log = "0.4"
unicase = "2.0"
url = "2.1.0"

View File

@ -2,16 +2,19 @@
use rocket;
use rocket_cors;
use std::error::Error;
use rocket::http::Method;
use rocket::{get, routes};
use rocket_cors::{AllowedHeaders, AllowedOrigins, Error};
use rocket_cors::{AllowedHeaders, AllowedOrigins};
#[get("/")]
fn cors<'a>() -> &'a str {
"Hello CORS"
}
fn main() -> Result<(), Error> {
#[rocket::main]
async fn main() -> Result<(), Box<dyn Error>> {
let allowed_origins = AllowedOrigins::some_exact(&["https://www.acme.com"]);
// You can also deserialize this
@ -27,7 +30,8 @@ fn main() -> Result<(), Error> {
rocket::ignite()
.mount("/", routes![cors])
.attach(cors)
.launch();
.launch()
.await?;
Ok(())
}

View File

@ -2,16 +2,17 @@
use rocket;
use rocket_cors;
use std::error::Error;
use std::io::Cursor;
use rocket::http::Method;
use rocket::Response;
use rocket::{get, options, routes};
use rocket_cors::{AllowedHeaders, AllowedOrigins, Error, Guard, Responder};
use rocket_cors::{AllowedHeaders, AllowedOrigins, Guard, Responder};
/// Using a `Responder` -- the usual way you would use this
#[get("/")]
fn responder(cors: Guard<'_>) -> Responder<'_, &str> {
fn responder(cors: Guard<'_>) -> Responder<'_, '_, &str> {
cors.responder("Hello CORS!")
}
@ -19,23 +20,25 @@ fn responder(cors: Guard<'_>) -> Responder<'_, &str> {
#[get("/response")]
fn response(cors: Guard<'_>) -> Response<'_> {
let mut response = Response::new();
response.set_sized_body(Cursor::new("Hello CORS!"));
let body = "Hello CORS!";
response.set_sized_body(body.len(), Cursor::new(body));
cors.response(response)
}
/// Manually mount an OPTIONS route for your own handling
#[options("/manual")]
fn manual_options(cors: Guard<'_>) -> Responder<'_, &str> {
fn manual_options(cors: Guard<'_>) -> Responder<'_, '_, &str> {
cors.responder("Manual OPTIONS preflight handling")
}
/// Manually mount an OPTIONS route for your own handling
#[get("/manual")]
fn manual(cors: Guard<'_>) -> Responder<'_, &str> {
fn manual(cors: Guard<'_>) -> Responder<'_, '_, &str> {
cors.responder("Manual OPTIONS preflight handling")
}
fn main() -> Result<(), Error> {
#[rocket::main]
async fn main() -> Result<(), Box<dyn Error>> {
let allowed_origins = AllowedOrigins::some_exact(&["https://www.acme.com"]);
// You can also deserialize this
@ -55,7 +58,8 @@ fn main() -> Result<(), Error> {
// You can also manually mount an OPTIONS route that will be used instead
.mount("/", routes![manual, manual_options])
.manage(cors)
.launch();
.launch()
.await?;
Ok(())
}

View File

@ -4,6 +4,7 @@ use rocket_cors;
use std::io::Cursor;
use rocket::error::Error;
use rocket::http::Method;
use rocket::response::Responder;
use rocket::{get, options, routes, Response, State};
@ -17,7 +18,7 @@ use rocket_cors::{AllowedHeaders, AllowedOrigins, Cors, CorsOptions};
/// Note that the `'r` lifetime annotation is not requred here because `State` borrows with lifetime
/// `'r` and so does `Responder`!
#[get("/")]
fn borrowed(options: State<'_, Cors>) -> impl Responder<'_> {
fn borrowed(options: State<'_, Cors>) -> impl Responder<'_, '_> {
options
.inner()
.respond_borrowed(|guard| guard.responder("Hello CORS"))
@ -27,9 +28,10 @@ fn borrowed(options: State<'_, Cors>) -> impl Responder<'_> {
/// Note that the `'r` lifetime annotation is not requred here because `State` borrows with lifetime
/// `'r` and so does `Responder`!
#[get("/response")]
fn response(options: State<'_, Cors>) -> impl Responder<'_> {
fn response(options: State<'_, Cors>) -> impl Responder<'_, '_> {
let mut response = Response::new();
response.set_sized_body(Cursor::new("Hello CORS!"));
let body = "Hello CORS!";
response.set_sized_body(body.len(), Cursor::new(body));
options
.inner()
@ -43,7 +45,7 @@ fn response(options: State<'_, Cors>) -> impl Responder<'_> {
/// when the settings you want to use for a route is not the same as the rest of the application
/// (which you might have put in Rocket's state).
#[get("/owned")]
fn owned<'r>() -> impl Responder<'r> {
fn owned<'r, 'o: 'r>() -> impl Responder<'r, 'o> {
let options = cors_options().to_cors()?;
options.respond_owned(|guard| guard.responder("Hello CORS"))
}
@ -53,7 +55,7 @@ fn owned<'r>() -> impl Responder<'r> {
/// These routes can just return the unit type `()`
/// Note that the `'r` lifetime is needed because the compiler cannot elide anything.
#[options("/owned")]
fn owned_options<'r>() -> impl Responder<'r> {
fn owned_options<'r, 'o: 'r>() -> impl Responder<'r, 'o> {
let options = cors_options().to_cors()?;
options.respond_owned(|guard| guard.responder(()))
}
@ -71,10 +73,12 @@ fn cors_options() -> CorsOptions {
}
}
fn main() {
#[rocket::main]
async fn main() -> Result<(), Error> {
rocket::ignite()
.mount("/", routes![borrowed, response, owned, owned_options,])
.mount("/", rocket_cors::catch_all_options_routes()) // mount the catch all routes
.manage(cors_options().to_cors().expect("To not fail"))
.launch();
.launch()
.await
}

View File

@ -7,6 +7,7 @@
use rocket;
use rocket_cors;
use rocket::error::Error;
use rocket::http::Method;
use rocket::response::Responder;
use rocket::{get, options, routes};
@ -14,13 +15,13 @@ use rocket_cors::{AllowedHeaders, AllowedOrigins, CorsOptions, Guard};
/// The "usual" app route
#[get("/")]
fn app(cors: Guard<'_>) -> rocket_cors::Responder<'_, &str> {
fn app(cors: Guard<'_>) -> rocket_cors::Responder<'_, '_, &str> {
cors.responder("Hello CORS!")
}
/// The special "ping" route
#[get("/ping")]
fn ping<'r>() -> impl Responder<'r> {
fn ping<'r, 'o: 'r>() -> impl Responder<'r, 'o> {
let cors = cors_options_all().to_cors()?;
cors.respond_owned(|guard| guard.responder("Pong!"))
}
@ -29,7 +30,7 @@ fn ping<'r>() -> impl Responder<'r> {
/// that is not in Rocket's managed state.
/// These routes can just return the unit type `()`
#[options("/ping")]
fn ping_options<'r>() -> impl Responder<'r> {
fn ping_options<'r, 'o: 'r>() -> impl Responder<'r, 'o> {
let cors = cors_options_all().to_cors()?;
cors.respond_owned(|guard| guard.responder(()))
}
@ -57,10 +58,12 @@ fn cors_options_all() -> CorsOptions {
Default::default()
}
fn main() {
#[rocket::main]
async fn main() -> Result<(), Error> {
rocket::ignite()
.mount("/", routes![app, ping, ping_options,])
.mount("/", rocket_cors::catch_all_options_routes()) // mount the catch all routes
.manage(cors_options().to_cors().expect("To not fail"))
.launch();
.launch()
.await
}

View File

@ -2,7 +2,7 @@
use ::log::{error, info};
use rocket::http::{self, uri::Origin, Status};
use rocket::{self, error_, info_, log_, Outcome, Request};
use rocket::{self, error_, info_, log_, outcome::Outcome, Request};
use crate::{
actual_request_response, origin, preflight_response, request_headers, validate, Cors, Error,
@ -14,25 +14,32 @@ enum CorsValidation {
Failure,
}
/// Route for Fairing error handling
pub(crate) fn fairing_error_route<'r>(
request: &'r Request<'_>,
_: rocket::Data,
) -> rocket::handler::Outcome<'r> {
let status = request
.get_param::<u16>(0)
.unwrap_or(Ok(0))
.unwrap_or_else(|e| {
error_!("Fairing Error Handling Route error: {:?}", e);
500
});
let status = Status::from_code(status).unwrap_or_else(|| Status::InternalServerError);
Outcome::Failure(status)
/// Create a `Handler` for Fairing error handling
#[derive(Clone)]
struct FairingErrorRoute {}
#[rocket::async_trait]
impl rocket::handler::Handler for FairingErrorRoute {
async fn handle<'r, 's: 'r>(
&'s self,
request: &'r Request<'_>,
_: rocket::Data,
) -> rocket::handler::Outcome<'r> {
let status = request
.get_param::<u16>(0)
.unwrap_or(Ok(0))
.unwrap_or_else(|e| {
error_!("Fairing Error Handling Route error: {:?}", e);
500
});
let status = Status::from_code(status).unwrap_or_else(|| Status::InternalServerError);
Outcome::Failure(status)
}
}
/// Create a new `Route` for Fairing handling
fn fairing_route(rank: isize) -> rocket::Route {
rocket::Route::ranked(rank, http::Method::Get, "/<status>", fairing_error_route)
rocket::Route::ranked(rank, http::Method::Get, "/<status>", FairingErrorRoute {})
}
/// Modifies a `Request` to route to Fairing error handler
@ -90,6 +97,7 @@ fn on_response_wrapper(
Ok(())
}
#[rocket::async_trait]
impl rocket::fairing::Fairing for Cors {
fn info(&self) -> rocket::fairing::Info {
rocket::fairing::Info {
@ -100,14 +108,14 @@ impl rocket::fairing::Fairing for Cors {
}
}
fn on_attach(&self, rocket: rocket::Rocket) -> Result<rocket::Rocket, rocket::Rocket> {
async fn on_attach(&self, rocket: rocket::Rocket) -> Result<rocket::Rocket, rocket::Rocket> {
Ok(rocket.mount(
&self.fairing_route_base,
vec![fairing_route(self.fairing_route_rank)],
))
}
fn on_request(&self, request: &mut Request<'_>, _: &rocket::Data) {
async fn on_request(&self, request: &mut Request<'_>, _: &rocket::Data) {
let result = match validate(self, request) {
Ok(_) => CorsValidation::Success,
Err(err) => {
@ -121,7 +129,7 @@ impl rocket::fairing::Fairing for Cors {
let _ = request.local_cache(|| result);
}
fn on_response(&self, request: &Request<'_>, response: &mut rocket::Response<'_>) {
async fn on_response<'r>(&self, request: &'r Request<'_>, response: &mut rocket::Response<'r>) {
if let Err(err) = on_response_wrapper(self, request, response) {
error_!("Fairings on_response error: {}\nMost likely a bug", err);
response.set_status(Status::InternalServerError);
@ -133,7 +141,7 @@ impl rocket::fairing::Fairing for Cors {
#[cfg(test)]
mod tests {
use rocket::http::{Method, Status};
use rocket::local::Client;
use rocket::local::blocking::Client;
use rocket::Rocket;
use crate::{AllowedHeaders, AllowedOrigins, Cors, CorsOptions};
@ -161,7 +169,8 @@ mod tests {
}
#[test]
fn fairing_error_route_returns_passed_in_status() {
#[allow(non_snake_case)]
fn FairingErrorRoute_returns_passed_in_status() {
let client = Client::new(rocket(make_cors_options())).expect("to not fail");
let request = client.get(format!("{}/403", CORS_ROOT));
let response = request.dispatch();
@ -169,19 +178,22 @@ mod tests {
}
#[test]
fn fairing_error_route_returns_500_for_unknown_status() {
#[allow(non_snake_case)]
fn FairingErrorRoute_returns_500_for_unknown_status() {
let client = Client::new(rocket(make_cors_options())).expect("to not fail");
let request = client.get(format!("{}/999", CORS_ROOT));
let response = request.dispatch();
assert_eq!(Status::InternalServerError, response.status());
}
#[test]
fn error_route_is_mounted_on_attach() {
let rocket = rocket(make_cors_options());
#[rocket::async_test]
async fn error_route_is_mounted_on_attach() {
let mut rocket = rocket(make_cors_options());
let expected_uri = format!("{}/<status>", CORS_ROOT);
let error_route = rocket
.inspect()
.await
.routes()
.find(|r| r.method == Method::Get && r.uri.to_string() == expected_uri);
assert!(error_route.is_some());

View File

@ -7,7 +7,7 @@ use std::str::FromStr;
use rocket::http::Status;
use rocket::request::{self, FromRequest};
use rocket::{self, Outcome};
use rocket::{self, outcome::Outcome};
#[cfg(feature = "serialization")]
use serde_derive::{Deserialize, Serialize};
use unicase::UniCase;
@ -91,6 +91,24 @@ impl Origin {
Origin::Opaque(_) => false,
}
}
/// Derives an instance of `Self` from the incoming request metadata.
///
/// If the derivation is successful, an outcome of `Success` is returned. If
/// the derivation fails in an unrecoverable fashion, `Failure` is returned.
/// `Forward` is returned to indicate that the request should be forwarded
/// to other matching routes, if any.
pub fn from_request_sync(
request: &'_ rocket::Request<'_>,
) -> request::Outcome<Self, crate::Error> {
match request.headers().get_one("Origin") {
Some(origin) => match Self::from_str(origin) {
Ok(origin) => Outcome::Success(origin),
Err(e) => Outcome::Failure((Status::BadRequest, e)),
},
None => Outcome::Forward(()),
}
}
}
impl FromStr for Origin {
@ -118,19 +136,17 @@ impl fmt::Display for Origin {
}
}
#[rocket::async_trait]
impl<'a, 'r> FromRequest<'a, 'r> for Origin {
type Error = crate::Error;
fn from_request(request: &'a rocket::Request<'r>) -> request::Outcome<Self, crate::Error> {
match request.headers().get_one("Origin") {
Some(origin) => match Self::from_str(origin) {
Ok(origin) => Outcome::Success(origin),
Err(e) => Outcome::Failure((Status::BadRequest, e)),
},
None => Outcome::Forward(()),
}
async fn from_request(
request: &'a rocket::Request<'r>,
) -> request::Outcome<Self, crate::Error> {
Origin::from_request_sync(request)
}
}
/// The `Access-Control-Request-Method` request header
///
/// You can use this as a rocket [Request Guard](https://rocket.rs/guide/requests/#request-guards)
@ -138,18 +154,16 @@ impl<'a, 'r> FromRequest<'a, 'r> for Origin {
#[derive(Debug)]
pub struct AccessControlRequestMethod(pub crate::Method);
impl FromStr for AccessControlRequestMethod {
type Err = ();
fn from_str(method: &str) -> Result<Self, Self::Err> {
Ok(AccessControlRequestMethod(crate::Method::from_str(method)?))
}
}
impl<'a, 'r> FromRequest<'a, 'r> for AccessControlRequestMethod {
type Error = crate::Error;
fn from_request(request: &'a rocket::Request<'r>) -> request::Outcome<Self, crate::Error> {
impl AccessControlRequestMethod {
/// Derives an instance of `Self` from the incoming request metadata.
///
/// If the derivation is successful, an outcome of `Success` is returned. If
/// the derivation fails in an unrecoverable fashion, `Failure` is returned.
/// `Forward` is returned to indicate that the request should be forwarded
/// to other matching routes, if any.
pub fn from_request_sync(
request: &'_ rocket::Request<'_>,
) -> request::Outcome<Self, crate::Error> {
match request.headers().get_one("Access-Control-Request-Method") {
Some(request_method) => match Self::from_str(request_method) {
Ok(request_method) => Outcome::Success(request_method),
@ -160,6 +174,25 @@ impl<'a, 'r> FromRequest<'a, 'r> for AccessControlRequestMethod {
}
}
impl FromStr for AccessControlRequestMethod {
type Err = ();
fn from_str(method: &str) -> Result<Self, Self::Err> {
Ok(AccessControlRequestMethod(crate::Method::from_str(method)?))
}
}
#[rocket::async_trait]
impl<'a, 'r> FromRequest<'a, 'r> for AccessControlRequestMethod {
type Error = crate::Error;
async fn from_request(
request: &'a rocket::Request<'r>,
) -> request::Outcome<Self, crate::Error> {
AccessControlRequestMethod::from_request_sync(request)
}
}
/// The `Access-Control-Request-Headers` request header
///
/// You can use this as a rocket [Request Guard](https://rocket.rs/guide/requests/#request-guards)
@ -167,6 +200,28 @@ impl<'a, 'r> FromRequest<'a, 'r> for AccessControlRequestMethod {
#[derive(Eq, PartialEq, Debug)]
pub struct AccessControlRequestHeaders(pub HeaderFieldNamesSet);
impl AccessControlRequestHeaders {
/// Derives an instance of `Self` from the incoming request metadata.
///
/// If the derivation is successful, an outcome of `Success` is returned. If
/// the derivation fails in an unrecoverable fashion, `Failure` is returned.
/// `Forward` is returned to indicate that the request should be forwarded
/// to other matching routes, if any.
pub fn from_request_sync(
request: &'_ rocket::Request<'_>,
) -> request::Outcome<Self, crate::Error> {
match request.headers().get_one("Access-Control-Request-Headers") {
Some(request_headers) => match Self::from_str(request_headers) {
Ok(request_headers) => Outcome::Success(request_headers),
Err(()) => {
unreachable!("`AccessControlRequestHeaders::from_str` should never fail")
}
},
None => Outcome::Forward(()),
}
}
}
/// Will never fail
impl FromStr for AccessControlRequestHeaders {
type Err = ();
@ -185,19 +240,14 @@ impl FromStr for AccessControlRequestHeaders {
}
}
#[rocket::async_trait]
impl<'a, 'r> FromRequest<'a, 'r> for AccessControlRequestHeaders {
type Error = crate::Error;
fn from_request(request: &'a rocket::Request<'r>) -> request::Outcome<Self, crate::Error> {
match request.headers().get_one("Access-Control-Request-Headers") {
Some(request_headers) => match Self::from_str(request_headers) {
Ok(request_headers) => Outcome::Success(request_headers),
Err(()) => {
unreachable!("`AccessControlRequestHeaders::from_str` should never fail")
}
},
None => Outcome::Forward(()),
}
async fn from_request(
request: &'a rocket::Request<'r>,
) -> request::Outcome<Self, crate::Error> {
AccessControlRequestHeaders::from_request_sync(request)
}
}
@ -207,7 +257,8 @@ mod tests {
use rocket;
use rocket::http::hyper;
use rocket::local::Client;
use rocket::http::Header;
use rocket::local::blocking::Client;
use super::*;
@ -277,11 +328,10 @@ mod tests {
let client = make_client();
let mut request = client.get("/");
let origin = hyper::header::Origin::new("https", "www.example.com", None);
let origin = Header::new(hyper::header::ORIGIN.as_str(), "https://www.example.com");
request.add_header(origin);
let outcome: request::Outcome<Origin, crate::Error> =
FromRequest::from_request(request.inner());
let outcome = Origin::from_request_sync(request.inner());
let parsed_header = assert_matches!(outcome, Outcome::Success(s), s);
assert_eq!(
"https://www.example.com",
@ -313,10 +363,12 @@ mod tests {
fn request_method_parsing() {
let client = make_client();
let mut request = client.get("/");
let method = hyper::header::AccessControlRequestMethod(hyper::Method::Get);
let method = Header::new(
hyper::header::ACCESS_CONTROL_REQUEST_METHOD.as_str(),
hyper::Method::GET.as_str(),
);
request.add_header(method);
let outcome: request::Outcome<AccessControlRequestMethod, crate::Error> =
FromRequest::from_request(request.inner());
let outcome = AccessControlRequestMethod::from_request_sync(request.inner());
let parsed_header = assert_matches!(outcome, Outcome::Success(s), s);
let AccessControlRequestMethod(parsed_method) = parsed_header;
@ -337,13 +389,12 @@ mod tests {
fn request_headers_parsing() {
let client = make_client();
let mut request = client.get("/");
let headers = hyper::header::AccessControlRequestHeaders(vec![
FromStr::from_str("accept-language").unwrap(),
FromStr::from_str("date").unwrap(),
]);
let headers = Header::new(
hyper::header::ACCESS_CONTROL_REQUEST_HEADERS.as_str(),
"accept-language, date",
);
request.add_header(headers);
let outcome: request::Outcome<AccessControlRequestHeaders, crate::Error> =
FromRequest::from_request(request.inner());
let outcome = AccessControlRequestHeaders::from_request_sync(request.inner());
let parsed_header = assert_matches!(outcome, Outcome::Success(s), s);
let AccessControlRequestHeaders(parsed_headers) = parsed_header;

View File

@ -285,7 +285,7 @@ use regex::RegexSet;
use rocket::http::{self, Status};
use rocket::request::{FromRequest, Request};
use rocket::response;
use rocket::{debug_, error_, info_, log_, Outcome, State};
use rocket::{debug_, error_, info_, log_, outcome::Outcome, State};
#[cfg(feature = "serialization")]
use serde_derive::{Deserialize, Serialize};
@ -417,8 +417,8 @@ impl error::Error for Error {
}
}
impl<'r> response::Responder<'r> for Error {
fn respond_to(self, _: &Request<'_>) -> Result<response::Response<'r>, Status> {
impl<'r, 'o: 'r> response::Responder<'r, 'o> for Error {
fn respond_to(self, _: &Request<'_>) -> Result<response::Response<'o>, Status> {
error_!("CORS Error: {}", self);
Err(self.status())
}
@ -1201,7 +1201,7 @@ impl CorsOptions {
}
/// Sets the rank of the fairing route
pub fn fairing_route_rank(mut self, fairing_route_rank: isize) -> Self {
pub fn fairing_route_rank(mut self, fairing_route_rank: isize) -> Self {
self.fairing_route_rank = fairing_route_rank;
self
}
@ -1256,10 +1256,13 @@ impl Cors {
/// passed in to include the CORS headers.
///
/// See the documentation at the [crate root](index.html) for usage information.
pub fn respond_owned<'r, F, R>(self, handler: F) -> Result<ManualResponder<'r, F, R>, Error>
pub fn respond_owned<'r, 'o: 'r, F, R>(
self,
handler: F,
) -> Result<ManualResponder<'r, F, R>, Error>
where
F: FnOnce(Guard<'r>) -> R + 'r,
R: response::Responder<'r>,
R: response::Responder<'r, 'o>,
{
Ok(ManualResponder::new(Cow::Owned(self), handler))
}
@ -1276,13 +1279,13 @@ impl Cors {
/// passed in to include the CORS headers.
///
/// See the documentation at the [crate root](index.html) for usage information.
pub fn respond_borrowed<'r, F, R>(
pub fn respond_borrowed<'r, 'o: 'r, F, R>(
&'r self,
handler: F,
) -> Result<ManualResponder<'r, F, R>, Error>
where
F: FnOnce(Guard<'r>) -> R + 'r,
R: response::Responder<'r>,
R: response::Responder<'r, 'o>,
{
Ok(ManualResponder::new(Cow::Borrowed(self), handler))
}
@ -1375,7 +1378,10 @@ impl Response {
/// Consumes the `Response` and return a `Responder` that wraps a
/// provided `rocket:response::Responder` with CORS headers
pub fn responder<'r, R: response::Responder<'r>>(self, responder: R) -> Responder<'r, R> {
pub fn responder<'r, 'o: 'r, R: response::Responder<'r, 'o>>(
self,
responder: R,
) -> Responder<'r, 'o, R> {
Responder::new(responder, self)
}
@ -1486,7 +1492,7 @@ pub struct Guard<'r> {
marker: PhantomData<&'r Response>,
}
impl<'r> Guard<'r> {
impl<'r, 'o: 'r> Guard<'r> {
fn new(response: Response) -> Self {
Self {
response,
@ -1496,7 +1502,7 @@ impl<'r> Guard<'r> {
/// Consumes the Guard and return a `Responder` that wraps a
/// provided `rocket:response::Responder` with CORS headers
pub fn responder<R: response::Responder<'r>>(self, responder: R) -> Responder<'r, R> {
pub fn responder<R: response::Responder<'r, 'o>>(self, responder: R) -> Responder<'r, 'o, R> {
self.response.responder(responder)
}
@ -1509,11 +1515,12 @@ impl<'r> Guard<'r> {
}
}
#[rocket::async_trait]
impl<'a, 'r> FromRequest<'a, 'r> for Guard<'r> {
type Error = Error;
fn from_request(request: &'a Request<'r>) -> rocket::request::Outcome<Self, Self::Error> {
let options = match request.guard::<State<'_, Cors>>() {
async fn from_request(request: &'a Request<'r>) -> rocket::request::Outcome<Self, Self::Error> {
let options = match request.guard::<State<'_, Cors>>().await {
Outcome::Success(options) => options,
_ => {
let error = Error::MissingCorsInRocketState;
@ -1545,13 +1552,13 @@ impl<'a, 'r> FromRequest<'a, 'r> for Guard<'r> {
///
/// See the documentation at the [crate root](index.html) for usage information.
#[derive(Debug)]
pub struct Responder<'r, R> {
pub struct Responder<'r, 'o, R> {
responder: R,
cors_response: Response,
marker: PhantomData<dyn response::Responder<'r>>,
marker: PhantomData<dyn response::Responder<'r, 'o>>,
}
impl<'r, R: response::Responder<'r>> Responder<'r, R> {
impl<'r, 'o: 'r, R: response::Responder<'r, 'o>> Responder<'r, 'o, R> {
fn new(responder: R, cors_response: Response) -> Self {
Self {
responder,
@ -1561,15 +1568,17 @@ impl<'r, R: response::Responder<'r>> Responder<'r, R> {
}
/// Respond to a request
fn respond(self, request: &Request<'_>) -> response::Result<'r> {
fn respond(self, request: &'r Request<'_>) -> response::Result<'o> {
let mut response = self.responder.respond_to(request)?; // handle status errors?
self.cors_response.merge(&mut response);
Ok(response)
}
}
impl<'r, R: response::Responder<'r>> response::Responder<'r> for Responder<'r, R> {
fn respond_to(self, request: &Request<'_>) -> response::Result<'r> {
impl<'r, 'o: 'r, R: response::Responder<'r, 'o>> response::Responder<'r, 'o>
for Responder<'r, 'o, R>
{
fn respond_to(self, request: &'r Request<'_>) -> response::Result<'o> {
self.respond(request)
}
}
@ -1583,10 +1592,10 @@ pub struct ManualResponder<'r, F, R> {
marker: PhantomData<R>,
}
impl<'r, F, R> ManualResponder<'r, F, R>
impl<'r, 'o: 'r, F, R> ManualResponder<'r, F, R>
where
F: FnOnce(Guard<'r>) -> R + 'r,
R: response::Responder<'r>,
R: response::Responder<'r, 'o>,
{
/// Create a new manual responder by passing in either a borrowed or owned `Cors` option.
///
@ -1607,12 +1616,12 @@ where
}
}
impl<'r, F, R> response::Responder<'r> for ManualResponder<'r, F, R>
impl<'r, 'o: 'r, F, R> response::Responder<'r, 'o> for ManualResponder<'r, F, R>
where
F: FnOnce(Guard<'r>) -> R + 'r,
R: response::Responder<'r>,
R: response::Responder<'r, 'o>,
{
fn respond_to(self, request: &Request<'_>) -> response::Result<'r> {
fn respond_to(self, request: &'r Request<'_>) -> response::Result<'o> {
let guard = match self.build_guard(request) {
Ok(guard) => guard,
Err(err) => {
@ -1759,7 +1768,7 @@ fn validate_allowed_headers(
/// Gets the `Origin` request header from the request
fn origin(request: &Request<'_>) -> Result<Option<Origin>, Error> {
match Origin::from_request(request) {
match Origin::from_request_sync(request) {
Outcome::Forward(()) => Ok(None),
Outcome::Success(origin) => Ok(Some(origin)),
Outcome::Failure((_, err)) => Err(err),
@ -1768,7 +1777,7 @@ fn origin(request: &Request<'_>) -> Result<Option<Origin>, Error> {
/// Gets the `Access-Control-Request-Method` request header from the request
fn request_method(request: &Request<'_>) -> Result<Option<AccessControlRequestMethod>, Error> {
match AccessControlRequestMethod::from_request(request) {
match AccessControlRequestMethod::from_request_sync(request) {
Outcome::Forward(()) => Ok(None),
Outcome::Success(method) => Ok(Some(method)),
Outcome::Failure((_, err)) => Err(err),
@ -1777,7 +1786,7 @@ fn request_method(request: &Request<'_>) -> Result<Option<AccessControlRequestMe
/// Gets the `Access-Control-Request-Headers` request header from the request
fn request_headers(request: &Request<'_>) -> Result<Option<AccessControlRequestHeaders>, Error> {
match AccessControlRequestHeaders::from_request(request) {
match AccessControlRequestHeaders::from_request_sync(request) {
Outcome::Forward(()) => Ok(None),
Outcome::Success(geaders) => Ok(Some(geaders)),
Outcome::Failure((_, err)) => Err(err),
@ -1984,34 +1993,41 @@ pub fn catch_all_options_routes() -> Vec<rocket::Route> {
isize::max_value(),
http::Method::Options,
"/",
catch_all_options_route_handler,
CatchAllOptionsRouteHandler {},
),
rocket::Route::ranked(
isize::max_value(),
http::Method::Options,
"/<catch_all_options_route..>",
catch_all_options_route_handler,
CatchAllOptionsRouteHandler {},
),
]
}
/// Handler for the "catch all options route"
fn catch_all_options_route_handler<'r>(
request: &'r Request<'_>,
_: rocket::Data,
) -> rocket::handler::Outcome<'r> {
let guard: Guard<'_> = match request.guard() {
Outcome::Success(guard) => guard,
Outcome::Failure((status, _)) => return rocket::handler::Outcome::failure(status),
Outcome::Forward(()) => unreachable!("Should not be reachable"),
};
#[derive(Clone)]
struct CatchAllOptionsRouteHandler {}
info_!(
"\"Catch all\" handling of CORS `OPTIONS` preflight for request {}",
request
);
#[rocket::async_trait]
impl rocket::handler::Handler for CatchAllOptionsRouteHandler {
async fn handle<'r, 's: 'r>(
&'s self,
request: &'r Request<'_>,
_: rocket::Data,
) -> rocket::handler::Outcome<'r> {
let guard: Guard<'_> = match request.guard().await {
Outcome::Success(guard) => guard,
Outcome::Failure((status, _)) => return rocket::handler::Outcome::failure(status),
Outcome::Forward(()) => unreachable!("Should not be reachable"),
};
rocket::handler::Outcome::from(request, guard.responder(()))
info_!(
"\"Catch all\" handling of CORS `OPTIONS` preflight for request {}",
request
);
rocket::handler::Outcome::from(request, guard.responder(()))
}
}
#[cfg(test)]
@ -2020,7 +2036,7 @@ mod tests {
use rocket::http::hyper;
use rocket::http::Header;
use rocket::local::Client;
use rocket::local::blocking::Client;
#[cfg(feature = "serialization")]
use serde_json;
@ -2084,10 +2100,20 @@ mod tests {
let allowed_origins = AllowedOrigins::some_exact(&["https://www.acme.com"]);
let cors_options_from_builder = CorsOptions::default()
.allowed_origins(allowed_origins)
.allowed_methods(vec![http::Method::Get].into_iter().map(From::from).collect())
.allowed_methods(
vec![http::Method::Get]
.into_iter()
.map(From::from)
.collect(),
)
.allowed_headers(AllowedHeaders::some(&[&"Authorization", "Accept"]))
.allow_credentials(true)
.expose_headers(["Content-Type", "X-Custom"].iter().map(|s| (*s).to_string()).collect());
.expose_headers(
["Content-Type", "X-Custom"]
.iter()
.map(|s| (*s).to_string())
.collect(),
);
assert_eq!(cors_options_from_builder, make_cors_options());
}
@ -2508,11 +2534,12 @@ mod tests {
fn response_build_removes_existing_cors_headers_and_keeps_others() {
use std::io::Cursor;
let body = "Brewing the best coffee!";
let original = response::Response::build()
.status(Status::ImATeapot)
.raw_header("X-Teapot-Make", "Rocket")
.raw_header("Access-Control-Max-Age", "42")
.sized_body(Cursor::new("Brewing the best coffee!"))
.sized_body(body.len(), Cursor::new(body))
.finalize();
let response = Response::new();
@ -2573,16 +2600,15 @@ mod tests {
let cors = make_cors_options().to_cors().expect("To not fail");
let client = make_client();
let origin_header =
Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap());
let method_header = Header::from(hyper::header::AccessControlRequestMethod(
hyper::Method::Get,
));
let request_headers =
hyper::header::AccessControlRequestHeaders(vec![
FromStr::from_str("Authorization").unwrap()
]);
let request_headers = Header::from(request_headers);
let origin_header = Header::new(hyper::header::ORIGIN.as_str(), "https://www.acme.com");
let method_header = Header::new(
hyper::header::ACCESS_CONTROL_REQUEST_METHOD.as_str(),
hyper::Method::GET.as_str(),
);
let request_headers = Header::new(
hyper::header::ACCESS_CONTROL_REQUEST_HEADERS.as_str(),
"Authorization",
);
let request = client
.options("/")
@ -2608,16 +2634,15 @@ mod tests {
let cors = options.to_cors().expect("To not fail");
let client = make_client();
let origin_header =
Header::from(hyper::header::Origin::from_str("https://www.example.com").unwrap());
let method_header = Header::from(hyper::header::AccessControlRequestMethod(
hyper::Method::Get,
));
let request_headers =
hyper::header::AccessControlRequestHeaders(vec![
FromStr::from_str("Authorization").unwrap()
]);
let request_headers = Header::from(request_headers);
let origin_header = Header::new(hyper::header::ORIGIN.as_str(), "https://www.example.com");
let method_header = Header::new(
hyper::header::ACCESS_CONTROL_REQUEST_METHOD.as_str(),
hyper::Method::GET.as_str(),
);
let request_headers = Header::new(
hyper::header::ACCESS_CONTROL_REQUEST_HEADERS.as_str(),
"Authorization",
);
let request = client
.options("/")
@ -2640,16 +2665,15 @@ mod tests {
let cors = make_cors_options().to_cors().expect("To not fail");
let client = make_client();
let origin_header =
Header::from(hyper::header::Origin::from_str("https://www.example.com").unwrap());
let method_header = Header::from(hyper::header::AccessControlRequestMethod(
hyper::Method::Get,
));
let request_headers =
hyper::header::AccessControlRequestHeaders(vec![
FromStr::from_str("Authorization").unwrap()
]);
let request_headers = Header::from(request_headers);
let origin_header = Header::new(hyper::header::ORIGIN.as_str(), "https://www.example.com");
let method_header = Header::new(
hyper::header::ACCESS_CONTROL_REQUEST_METHOD.as_str(),
hyper::Method::GET.as_str(),
);
let request_headers = Header::new(
hyper::header::ACCESS_CONTROL_REQUEST_HEADERS.as_str(),
"Authorization",
);
let request = client
.options("/")
@ -2666,13 +2690,11 @@ mod tests {
let cors = make_cors_options().to_cors().expect("To not fail");
let client = make_client();
let origin_header =
Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap());
let request_headers =
hyper::header::AccessControlRequestHeaders(vec![
FromStr::from_str("Authorization").unwrap()
]);
let request_headers = Header::from(request_headers);
let origin_header = Header::new(hyper::header::ORIGIN.as_str(), "https://www.acme.com");
let request_headers = Header::new(
hyper::header::ACCESS_CONTROL_REQUEST_HEADERS.as_str(),
"Authorization",
);
let request = client
.options("/")
@ -2688,16 +2710,15 @@ mod tests {
let cors = make_cors_options().to_cors().expect("To not fail");
let client = make_client();
let origin_header =
Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap());
let method_header = Header::from(hyper::header::AccessControlRequestMethod(
hyper::Method::Post,
));
let request_headers =
hyper::header::AccessControlRequestHeaders(vec![
FromStr::from_str("Authorization").unwrap()
]);
let request_headers = Header::from(request_headers);
let origin_header = Header::new(hyper::header::ORIGIN.as_str(), "https://www.acme.com");
let method_header = Header::new(
hyper::header::ACCESS_CONTROL_REQUEST_METHOD.as_str(),
hyper::Method::POST.as_str(),
);
let request_headers = Header::new(
hyper::header::ACCESS_CONTROL_REQUEST_HEADERS.as_str(),
"Authorization",
);
let request = client
.options("/")
@ -2714,16 +2735,15 @@ mod tests {
let cors = make_cors_options().to_cors().expect("To not fail");
let client = make_client();
let origin_header =
Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap());
let method_header = Header::from(hyper::header::AccessControlRequestMethod(
hyper::Method::Get,
));
let request_headers = hyper::header::AccessControlRequestHeaders(vec![
FromStr::from_str("Authorization").unwrap(),
FromStr::from_str("X-NOT-ALLOWED").unwrap(),
]);
let request_headers = Header::from(request_headers);
let origin_header = Header::new(hyper::header::ORIGIN.as_str(), "https://www.acme.com");
let method_header = Header::new(
hyper::header::ACCESS_CONTROL_REQUEST_METHOD.as_str(),
hyper::Method::GET.as_str(),
);
let request_headers = Header::new(
hyper::header::ACCESS_CONTROL_REQUEST_HEADERS.as_str(),
"Authorization, X-NOT-ALLOWED",
);
let request = client
.options("/")
@ -2739,8 +2759,7 @@ mod tests {
let cors = make_cors_options().to_cors().expect("To not fail");
let client = make_client();
let origin_header =
Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap());
let origin_header = Header::new(hyper::header::ORIGIN.as_str(), "https://www.acme.com");
let request = client.get("/").header(origin_header);
let result = validate(&cors, request.inner()).expect("to not fail");
@ -2758,8 +2777,7 @@ mod tests {
let cors = options.to_cors().expect("To not fail");
let client = make_client();
let origin_header =
Header::from(hyper::header::Origin::from_str("https://www.example.com").unwrap());
let origin_header = Header::new(hyper::header::ORIGIN.as_str(), "https://www.example.com");
let request = client.get("/").header(origin_header);
let result = validate(&cors, request.inner()).expect("to not fail");
@ -2776,8 +2794,7 @@ mod tests {
let cors = make_cors_options().to_cors().expect("To not fail");
let client = make_client();
let origin_header =
Header::from(hyper::header::Origin::from_str("https://www.example.com").unwrap());
let origin_header = Header::new(hyper::header::ORIGIN.as_str(), "https://www.example.com");
let request = client.get("/").header(origin_header);
let _ = validate(&cors, request.inner()).unwrap();
@ -2800,16 +2817,15 @@ mod tests {
let cors = options.to_cors().expect("To not fail");
let client = make_client();
let origin_header =
Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap());
let method_header = Header::from(hyper::header::AccessControlRequestMethod(
hyper::Method::Get,
));
let request_headers =
hyper::header::AccessControlRequestHeaders(vec![
FromStr::from_str("Authorization").unwrap()
]);
let request_headers = Header::from(request_headers);
let origin_header = Header::new(hyper::header::ORIGIN.as_str(), "https://www.acme.com");
let method_header = Header::new(
hyper::header::ACCESS_CONTROL_REQUEST_METHOD.as_str(),
hyper::Method::GET.as_str(),
);
let request_headers = Header::new(
hyper::header::ACCESS_CONTROL_REQUEST_HEADERS.as_str(),
"Authorization",
);
let request = client
.options("/")
@ -2840,16 +2856,15 @@ mod tests {
let client = make_client();
let origin_header =
Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap());
let method_header = Header::from(hyper::header::AccessControlRequestMethod(
hyper::Method::Get,
));
let request_headers =
hyper::header::AccessControlRequestHeaders(vec![
FromStr::from_str("Authorization").unwrap()
]);
let request_headers = Header::from(request_headers);
let origin_header = Header::new(hyper::header::ORIGIN.as_str(), "https://www.acme.com");
let method_header = Header::new(
hyper::header::ACCESS_CONTROL_REQUEST_METHOD.as_str(),
hyper::Method::GET.as_str(),
);
let request_headers = Header::new(
hyper::header::ACCESS_CONTROL_REQUEST_HEADERS.as_str(),
"Authorization",
);
let request = client
.options("/")
@ -2880,16 +2895,15 @@ mod tests {
let client = make_client();
let origin_header =
Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap());
let method_header = Header::from(hyper::header::AccessControlRequestMethod(
hyper::Method::Get,
));
let request_headers =
hyper::header::AccessControlRequestHeaders(vec![
FromStr::from_str("Authorization").unwrap()
]);
let request_headers = Header::from(request_headers);
let origin_header = Header::new(hyper::header::ORIGIN.as_str(), "https://www.acme.com");
let method_header = Header::new(
hyper::header::ACCESS_CONTROL_REQUEST_METHOD.as_str(),
hyper::Method::GET.as_str(),
);
let request_headers = Header::new(
hyper::header::ACCESS_CONTROL_REQUEST_HEADERS.as_str(),
"Authorization",
);
let request = client
.options("/")
@ -2915,8 +2929,7 @@ mod tests {
let cors = options.to_cors().expect("To not fail");
let client = make_client();
let origin_header =
Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap());
let origin_header = Header::new(hyper::header::ORIGIN.as_str(), "https://www.acme.com");
let request = client.get("/").header(origin_header);
let response = validate_and_build(&cors, request.inner()).expect("to not fail");
@ -2938,8 +2951,7 @@ mod tests {
let client = make_client();
let origin_header =
Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap());
let origin_header = Header::new(hyper::header::ORIGIN.as_str(), "https://www.acme.com");
let request = client.get("/").header(origin_header);
let response = validate_and_build(&cors, request.inner()).expect("to not fail");
@ -2961,8 +2973,7 @@ mod tests {
let client = make_client();
let origin_header =
Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap());
let origin_header = Header::new(hyper::header::ORIGIN.as_str(), "https://www.acme.com");
let request = client.get("/").header(origin_header);
let response = validate_and_build(&cors, request.inner()).expect("to not fail");

View File

@ -1,12 +1,9 @@
//! This crate tests using `rocket_cors` using Fairings
#![feature(proc_macro_hygiene, decl_macro)]
use std::str::FromStr;
use rocket::http::hyper;
use rocket::http::Method;
use rocket::http::{Header, Status};
use rocket::local::Client;
use rocket::response::Body;
use rocket::local::blocking::Client;
use rocket::{get, routes};
use rocket_cors::*;
@ -45,16 +42,15 @@ fn smoke_test() {
let client = Client::new(rocket()).unwrap();
// `Options` pre-flight checks
let origin_header =
Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap());
let method_header = Header::from(hyper::header::AccessControlRequestMethod(
hyper::Method::Get,
));
let request_headers =
hyper::header::AccessControlRequestHeaders(vec![
FromStr::from_str("Authorization").unwrap()
]);
let request_headers = Header::from(request_headers);
let origin_header = Header::new(hyper::header::ORIGIN.as_str(), "https://www.acme.com");
let method_header = Header::new(
hyper::header::ACCESS_CONTROL_REQUEST_METHOD.as_str(),
hyper::Method::GET.as_str(),
);
let request_headers = Header::new(
hyper::header::ACCESS_CONTROL_REQUEST_HEADERS.as_str(),
"Authorization",
);
let req = client
.options("/")
.header(origin_header)
@ -65,37 +61,34 @@ fn smoke_test() {
assert!(response.status().class().is_success());
// "Actual" request
let origin_header =
Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap());
let origin_header = Header::new(hyper::header::ORIGIN.as_str(), "https://www.acme.com");
let authorization = Header::new("Authorization", "let me in");
let req = client.get("/").header(origin_header).header(authorization);
let mut response = req.dispatch();
let response = req.dispatch();
assert!(response.status().class().is_success());
let body_str = response.body().and_then(Body::into_string);
assert_eq!(body_str, Some("Hello CORS".to_string()));
let origin_header = response
.headers()
.get_one("Access-Control-Allow-Origin")
.expect("to exist");
assert_eq!("https://www.acme.com", origin_header);
let body_str = response.into_string();
assert_eq!(body_str, Some("Hello CORS".to_string()));
}
#[test]
fn cors_options_check() {
let client = Client::new(rocket()).unwrap();
let origin_header =
Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap());
let method_header = Header::from(hyper::header::AccessControlRequestMethod(
hyper::Method::Get,
));
let request_headers =
hyper::header::AccessControlRequestHeaders(vec![
FromStr::from_str("Authorization").unwrap()
]);
let request_headers = Header::from(request_headers);
let origin_header = Header::new(hyper::header::ORIGIN.as_str(), "https://www.acme.com");
let method_header = Header::new(
hyper::header::ACCESS_CONTROL_REQUEST_METHOD.as_str(),
hyper::Method::GET.as_str(),
);
let request_headers = Header::new(
hyper::header::ACCESS_CONTROL_REQUEST_HEADERS.as_str(),
"Authorization",
);
let req = client
.options("/")
.header(origin_header)
@ -116,21 +109,19 @@ fn cors_options_check() {
fn cors_get_check() {
let client = Client::new(rocket()).unwrap();
let origin_header =
Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap());
let origin_header = Header::new(hyper::header::ORIGIN.as_str(), "https://www.acme.com");
let authorization = Header::new("Authorization", "let me in");
let req = client.get("/").header(origin_header).header(authorization);
let mut response = req.dispatch();
let response = req.dispatch();
assert!(response.status().class().is_success());
let body_str = response.body().and_then(Body::into_string);
assert_eq!(body_str, Some("Hello CORS".to_string()));
let origin_header = response
.headers()
.get_one("Access-Control-Allow-Origin")
.expect("to exist");
assert_eq!("https://www.acme.com", origin_header);
let body_str = response.into_string();
assert_eq!(body_str, Some("Hello CORS".to_string()));
}
/// This test is to check that non CORS compliant requests to GET should still work. (i.e. curl)
@ -141,9 +132,9 @@ fn cors_get_no_origin() {
let authorization = Header::new("Authorization", "let me in");
let req = client.get("/").header(authorization);
let mut response = req.dispatch();
let response = req.dispatch();
assert!(response.status().class().is_success());
let body_str = response.body().and_then(Body::into_string);
let body_str = response.into_string();
assert_eq!(body_str, Some("Hello CORS".to_string()));
}
@ -151,16 +142,15 @@ fn cors_get_no_origin() {
fn cors_options_bad_origin() {
let client = Client::new(rocket()).unwrap();
let origin_header =
Header::from(hyper::header::Origin::from_str("https://www.bad-origin.com").unwrap());
let method_header = Header::from(hyper::header::AccessControlRequestMethod(
hyper::Method::Get,
));
let request_headers =
hyper::header::AccessControlRequestHeaders(vec![
FromStr::from_str("Authorization").unwrap()
]);
let request_headers = Header::from(request_headers);
let origin_header = Header::new(hyper::header::ORIGIN.as_str(), "https://www.bad-origin.com");
let method_header = Header::new(
hyper::header::ACCESS_CONTROL_REQUEST_METHOD.as_str(),
hyper::Method::GET.as_str(),
);
let request_headers = Header::new(
hyper::header::ACCESS_CONTROL_REQUEST_HEADERS.as_str(),
"Authorization",
);
let req = client
.options("/")
.header(origin_header)
@ -176,14 +166,14 @@ fn cors_options_bad_origin() {
fn cors_options_missing_origin() {
let client = Client::new(rocket()).unwrap();
let method_header = Header::from(hyper::header::AccessControlRequestMethod(
hyper::Method::Get,
));
let request_headers =
hyper::header::AccessControlRequestHeaders(vec![
FromStr::from_str("Authorization").unwrap()
]);
let request_headers = Header::from(request_headers);
let method_header = Header::new(
hyper::header::ACCESS_CONTROL_REQUEST_METHOD.as_str(),
hyper::Method::GET.as_str(),
);
let request_headers = Header::new(
hyper::header::ACCESS_CONTROL_REQUEST_HEADERS.as_str(),
"Authorization",
);
let req = client
.options("/")
.header(method_header)
@ -202,16 +192,15 @@ fn cors_options_missing_origin() {
fn cors_options_bad_request_method() {
let client = Client::new(rocket()).unwrap();
let origin_header =
Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap());
let method_header = Header::from(hyper::header::AccessControlRequestMethod(
hyper::Method::Post,
));
let request_headers =
hyper::header::AccessControlRequestHeaders(vec![
FromStr::from_str("Authorization").unwrap()
]);
let request_headers = Header::from(request_headers);
let origin_header = Header::new(hyper::header::ORIGIN.as_str(), "https://www.acme.com");
let method_header = Header::new(
hyper::header::ACCESS_CONTROL_REQUEST_METHOD.as_str(),
hyper::Method::POST.as_str(),
);
let request_headers = Header::new(
hyper::header::ACCESS_CONTROL_REQUEST_HEADERS.as_str(),
"Authorization",
);
let req = client
.options("/")
.header(origin_header)
@ -230,14 +219,15 @@ fn cors_options_bad_request_method() {
fn cors_options_bad_request_header() {
let client = Client::new(rocket()).unwrap();
let origin_header =
Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap());
let method_header = Header::from(hyper::header::AccessControlRequestMethod(
hyper::Method::Get,
));
let request_headers =
hyper::header::AccessControlRequestHeaders(vec![FromStr::from_str("Foobar").unwrap()]);
let request_headers = Header::from(request_headers);
let origin_header = Header::new(hyper::header::ORIGIN.as_str(), "https://www.acme.com");
let method_header = Header::new(
hyper::header::ACCESS_CONTROL_REQUEST_METHOD.as_str(),
hyper::Method::GET.as_str(),
);
let request_headers = Header::new(
hyper::header::ACCESS_CONTROL_REQUEST_HEADERS.as_str(),
"Foobar",
);
let req = client
.options("/")
.header(origin_header)
@ -256,8 +246,7 @@ fn cors_options_bad_request_header() {
fn cors_get_bad_origin() {
let client = Client::new(rocket()).unwrap();
let origin_header =
Header::from(hyper::header::Origin::from_str("https://www.bad-origin.com").unwrap());
let origin_header = Header::new(hyper::header::ORIGIN.as_str(), "https://www.bad-origin.com");
let authorization = Header::new("Authorization", "let me in");
let req = client.get("/").header(origin_header).header(authorization);
@ -276,16 +265,15 @@ fn cors_get_bad_origin() {
fn routes_failing_checks_are_not_executed() {
let client = Client::new(rocket()).unwrap();
let origin_header =
Header::from(hyper::header::Origin::from_str("https://www.bad-origin.com").unwrap());
let method_header = Header::from(hyper::header::AccessControlRequestMethod(
hyper::Method::Get,
));
let request_headers =
hyper::header::AccessControlRequestHeaders(vec![
FromStr::from_str("Authorization").unwrap()
]);
let request_headers = Header::from(request_headers);
let origin_header = Header::new(hyper::header::ORIGIN.as_str(), "https://www.bad-origin.com");
let method_header = Header::new(
hyper::header::ACCESS_CONTROL_REQUEST_METHOD.as_str(),
hyper::Method::GET.as_str(),
);
let request_headers = Header::new(
hyper::header::ACCESS_CONTROL_REQUEST_HEADERS.as_str(),
"Authorization",
);
let req = client
.options("/panic")
.header(origin_header)

View File

@ -2,18 +2,15 @@
#![feature(proc_macro_hygiene, decl_macro)]
use rocket_cors as cors;
use std::str::FromStr;
use rocket::http::hyper;
use rocket::http::Method;
use rocket::http::{Header, Status};
use rocket::local::Client;
use rocket::response::Body;
use rocket::local::blocking::Client;
use rocket::{get, options, routes};
use rocket::{Response, State};
#[get("/")]
fn cors(cors: cors::Guard<'_>) -> cors::Responder<'_, &str> {
fn cors(cors: cors::Guard<'_>) -> cors::Responder<'_, '_, &str> {
cors.responder("Hello CORS")
}
@ -24,13 +21,13 @@ fn panicking_route(_cors: cors::Guard<'_>) {
/// Manually specify our own OPTIONS route
#[options("/manual")]
fn cors_manual_options(cors: cors::Guard<'_>) -> cors::Responder<'_, &str> {
fn cors_manual_options(cors: cors::Guard<'_>) -> cors::Responder<'_, '_, &str> {
cors.responder("Manual CORS Preflight")
}
/// Manually specify our own OPTIONS route
#[get("/manual")]
fn cors_manual(cors: cors::Guard<'_>) -> cors::Responder<'_, &str> {
fn cors_manual(cors: cors::Guard<'_>) -> cors::Responder<'_, '_, &str> {
cors.responder("Hello CORS")
}
@ -42,20 +39,23 @@ fn response(cors: cors::Guard<'_>) -> Response<'_> {
/// `Responder` with String
#[get("/responder/string")]
fn responder_string(cors: cors::Guard<'_>) -> cors::Responder<'_, String> {
fn responder_string(cors: cors::Guard<'_>) -> cors::Responder<'_, 'static, String> {
cors.responder("Hello CORS".to_string())
}
/// `Responder` with 'static ()
#[get("/responder/unit")]
fn responder_unit(cors: cors::Guard<'_>) -> cors::Responder<'_, ()> {
fn responder_unit(cors: cors::Guard<'_>) -> cors::Responder<'_, 'static, ()> {
cors.responder(())
}
struct SomeState;
/// Borrow `SomeState` from Rocket
#[get("/state")]
fn state<'r>(cors: cors::Guard<'r>, _state: State<'r, SomeState>) -> cors::Responder<'r, &'r str> {
fn state<'r, 'o: 'r>(
cors: cors::Guard<'r>,
_state: State<'r, SomeState>,
) -> cors::Responder<'r, 'o, &'r str> {
cors.responder("hmm")
}
@ -92,16 +92,15 @@ fn smoke_test() {
let client = Client::new(rocket).unwrap();
// `Options` pre-flight checks
let origin_header =
Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap());
let method_header = Header::from(hyper::header::AccessControlRequestMethod(
hyper::Method::Get,
));
let request_headers =
hyper::header::AccessControlRequestHeaders(vec![
FromStr::from_str("Authorization").unwrap()
]);
let request_headers = Header::from(request_headers);
let origin_header = Header::new(hyper::header::ORIGIN.as_str(), "https://www.acme.com");
let method_header = Header::new(
hyper::header::ACCESS_CONTROL_REQUEST_METHOD.as_str(),
hyper::Method::GET.as_str(),
);
let request_headers = Header::new(
hyper::header::ACCESS_CONTROL_REQUEST_HEADERS.as_str(),
"Authorization",
);
let req = client
.options("/")
.header(origin_header)
@ -112,21 +111,19 @@ fn smoke_test() {
assert!(response.status().class().is_success());
// "Actual" request
let origin_header =
Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap());
let origin_header = Header::new(hyper::header::ORIGIN.as_str(), "https://www.acme.com");
let authorization = Header::new("Authorization", "let me in");
let req = client.get("/").header(origin_header).header(authorization);
let mut response = req.dispatch();
let response = req.dispatch();
assert!(response.status().class().is_success());
let body_str = response.body().and_then(Body::into_string);
assert_eq!(body_str, Some("Hello CORS".to_string()));
let origin_header = response
.headers()
.get_one("Access-Control-Allow-Origin")
.expect("to exist");
assert_eq!("https://www.acme.com", origin_header);
let body_str = response.into_string();
assert_eq!(body_str, Some("Hello CORS".to_string()));
}
/// Check the "catch all" OPTIONS route works for `/`
@ -135,16 +132,15 @@ fn cors_options_catch_all_check() {
let rocket = make_rocket();
let client = Client::new(rocket).unwrap();
let origin_header =
Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap());
let method_header = Header::from(hyper::header::AccessControlRequestMethod(
hyper::Method::Get,
));
let request_headers =
hyper::header::AccessControlRequestHeaders(vec![
FromStr::from_str("Authorization").unwrap()
]);
let request_headers = Header::from(request_headers);
let origin_header = Header::new(hyper::header::ORIGIN.as_str(), "https://www.acme.com");
let method_header = Header::new(
hyper::header::ACCESS_CONTROL_REQUEST_METHOD.as_str(),
hyper::Method::GET.as_str(),
);
let request_headers = Header::new(
hyper::header::ACCESS_CONTROL_REQUEST_HEADERS.as_str(),
"Authorization",
);
let req = client
.options("/")
.header(origin_header)
@ -167,16 +163,15 @@ fn cors_options_catch_all_check_other_routes() {
let rocket = make_rocket();
let client = Client::new(rocket).unwrap();
let origin_header =
Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap());
let method_header = Header::from(hyper::header::AccessControlRequestMethod(
hyper::Method::Get,
));
let request_headers =
hyper::header::AccessControlRequestHeaders(vec![
FromStr::from_str("Authorization").unwrap()
]);
let request_headers = Header::from(request_headers);
let origin_header = Header::new(hyper::header::ORIGIN.as_str(), "https://www.acme.com");
let method_header = Header::new(
hyper::header::ACCESS_CONTROL_REQUEST_METHOD.as_str(),
hyper::Method::GET.as_str(),
);
let request_headers = Header::new(
hyper::header::ACCESS_CONTROL_REQUEST_HEADERS.as_str(),
"Authorization",
);
let req = client
.options("/response/unit")
.header(origin_header)
@ -198,21 +193,19 @@ fn cors_get_check() {
let rocket = make_rocket();
let client = Client::new(rocket).unwrap();
let origin_header =
Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap());
let origin_header = Header::new(hyper::header::ORIGIN.as_str(), "https://www.acme.com");
let authorization = Header::new("Authorization", "let me in");
let req = client.get("/").header(origin_header).header(authorization);
let mut response = req.dispatch();
let response = req.dispatch();
assert!(response.status().class().is_success());
let body_str = response.body().and_then(Body::into_string);
assert_eq!(body_str, Some("Hello CORS".to_string()));
let origin_header = response
.headers()
.get_one("Access-Control-Allow-Origin")
.expect("to exist");
assert_eq!("https://www.acme.com", origin_header);
let body_str = response.into_string();
assert_eq!(body_str, Some("Hello CORS".to_string()));
}
/// This test is to check that non CORS compliant requests to GET should still work. (i.e. curl)
@ -224,14 +217,14 @@ fn cors_get_no_origin() {
let authorization = Header::new("Authorization", "let me in");
let req = client.get("/").header(authorization);
let mut response = req.dispatch();
let response = req.dispatch();
assert!(response.status().class().is_success());
let body_str = response.body().and_then(Body::into_string);
assert_eq!(body_str, Some("Hello CORS".to_string()));
assert!(response
.headers()
.get_one("Access-Control-Allow-Origin")
.is_none());
let body_str = response.into_string();
assert_eq!(body_str, Some("Hello CORS".to_string()));
}
#[test]
@ -239,16 +232,15 @@ fn cors_options_bad_origin() {
let rocket = make_rocket();
let client = Client::new(rocket).unwrap();
let origin_header =
Header::from(hyper::header::Origin::from_str("https://www.bad-origin.com").unwrap());
let method_header = Header::from(hyper::header::AccessControlRequestMethod(
hyper::Method::Get,
));
let request_headers =
hyper::header::AccessControlRequestHeaders(vec![
FromStr::from_str("Authorization").unwrap()
]);
let request_headers = Header::from(request_headers);
let origin_header = Header::new(hyper::header::ORIGIN.as_str(), "https://www.bad-origin.com");
let method_header = Header::new(
hyper::header::ACCESS_CONTROL_REQUEST_METHOD.as_str(),
hyper::Method::GET.as_str(),
);
let request_headers = Header::new(
hyper::header::ACCESS_CONTROL_REQUEST_HEADERS.as_str(),
"Authorization",
);
let req = client
.options("/")
.header(origin_header)
@ -268,14 +260,14 @@ fn cors_options_missing_origin() {
let rocket = make_rocket();
let client = Client::new(rocket).unwrap();
let method_header = Header::from(hyper::header::AccessControlRequestMethod(
hyper::Method::Get,
));
let request_headers =
hyper::header::AccessControlRequestHeaders(vec![
FromStr::from_str("Authorization").unwrap()
]);
let request_headers = Header::from(request_headers);
let method_header = Header::new(
hyper::header::ACCESS_CONTROL_REQUEST_METHOD.as_str(),
hyper::Method::GET.as_str(),
);
let request_headers = Header::new(
hyper::header::ACCESS_CONTROL_REQUEST_HEADERS.as_str(),
"Authorization",
);
let req = client
.options("/")
.header(method_header)
@ -294,16 +286,15 @@ fn cors_options_bad_request_method() {
let rocket = make_rocket();
let client = Client::new(rocket).unwrap();
let origin_header =
Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap());
let method_header = Header::from(hyper::header::AccessControlRequestMethod(
hyper::Method::Post,
));
let request_headers =
hyper::header::AccessControlRequestHeaders(vec![
FromStr::from_str("Authorization").unwrap()
]);
let request_headers = Header::from(request_headers);
let origin_header = Header::new(hyper::header::ORIGIN.as_str(), "https://www.acme.com");
let method_header = Header::new(
hyper::header::ACCESS_CONTROL_REQUEST_METHOD.as_str(),
hyper::Method::POST.as_str(),
);
let request_headers = Header::new(
hyper::header::ACCESS_CONTROL_REQUEST_HEADERS.as_str(),
"Authorization",
);
let req = client
.options("/")
.header(origin_header)
@ -323,14 +314,15 @@ fn cors_options_bad_request_header() {
let rocket = make_rocket();
let client = Client::new(rocket).unwrap();
let origin_header =
Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap());
let method_header = Header::from(hyper::header::AccessControlRequestMethod(
hyper::Method::Get,
));
let request_headers =
hyper::header::AccessControlRequestHeaders(vec![FromStr::from_str("Foobar").unwrap()]);
let request_headers = Header::from(request_headers);
let origin_header = Header::new(hyper::header::ORIGIN.as_str(), "https://www.acme.com");
let method_header = Header::new(
hyper::header::ACCESS_CONTROL_REQUEST_METHOD.as_str(),
hyper::Method::GET.as_str(),
);
let request_headers = Header::new(
hyper::header::ACCESS_CONTROL_REQUEST_HEADERS.as_str(),
"Foobar",
);
let req = client
.options("/")
.header(origin_header)
@ -350,8 +342,7 @@ fn cors_get_bad_origin() {
let rocket = make_rocket();
let client = Client::new(rocket).unwrap();
let origin_header =
Header::from(hyper::header::Origin::from_str("https://www.bad-origin.com").unwrap());
let origin_header = Header::new(hyper::header::ORIGIN.as_str(), "https://www.bad-origin.com");
let authorization = Header::new("Authorization", "let me in");
let req = client.get("/").header(origin_header).header(authorization);
@ -371,8 +362,7 @@ fn routes_failing_checks_are_not_executed() {
let rocket = make_rocket();
let client = Client::new(rocket).unwrap();
let origin_header =
Header::from(hyper::header::Origin::from_str("https://www.bad-origin.com").unwrap());
let origin_header = Header::new(hyper::header::ORIGIN.as_str(), "https://www.bad-origin.com");
let authorization = Header::new("Authorization", "let me in");
let req = client.get("/").header(origin_header).header(authorization);
@ -391,30 +381,28 @@ fn overridden_options_routes_are_used() {
let rocket = make_rocket();
let client = Client::new(rocket).unwrap();
let origin_header =
Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap());
let method_header = Header::from(hyper::header::AccessControlRequestMethod(
hyper::Method::Get,
));
let request_headers =
hyper::header::AccessControlRequestHeaders(vec![
FromStr::from_str("Authorization").unwrap()
]);
let request_headers = Header::from(request_headers);
let origin_header = Header::new(hyper::header::ORIGIN.as_str(), "https://www.acme.com");
let method_header = Header::new(
hyper::header::ACCESS_CONTROL_REQUEST_METHOD.as_str(),
hyper::Method::GET.as_str(),
);
let request_headers = Header::new(
hyper::header::ACCESS_CONTROL_REQUEST_HEADERS.as_str(),
"Authorization",
);
let req = client
.options("/manual")
.header(origin_header)
.header(method_header)
.header(request_headers);
let mut response = req.dispatch();
let body_str = response.body().and_then(Body::into_string);
let response = req.dispatch();
assert!(response.status().class().is_success());
assert_eq!(body_str, Some("Manual CORS Preflight".to_string()));
let origin_header = response
.headers()
.get_one("Access-Control-Allow-Origin")
.expect("to exist");
assert_eq!("https://www.acme.com", origin_header);
let body_str = response.into_string();
assert_eq!(body_str, Some("Manual CORS Preflight".to_string()));
}

View File

@ -1,12 +1,10 @@
//! This crate tests that all the request headers are parsed correctly in the round trip
#![feature(proc_macro_hygiene, decl_macro)]
use std::ops::Deref;
use std::str::FromStr;
use rocket::http::hyper;
use rocket::http::Header;
use rocket::local::Client;
use rocket::response::Body;
use rocket::local::blocking::Client;
use rocket::{get, routes};
use rocket_cors::headers::*;
@ -32,30 +30,27 @@ fn request_headers_round_trip_smoke_test() {
let rocket = rocket::ignite().mount("/", routes![request_headers]);
let client = Client::new(rocket).expect("A valid Rocket client");
let origin_header =
Header::from(hyper::header::Origin::from_str("https://foo.bar.xyz").unwrap());
let method_header = Header::from(hyper::header::AccessControlRequestMethod(
hyper::Method::Get,
));
let request_headers = hyper::header::AccessControlRequestHeaders(vec![
FromStr::from_str("accept-language").unwrap(),
FromStr::from_str("X-Ping").unwrap(),
]);
let request_headers = Header::from(request_headers);
let origin_header = Header::new(hyper::header::ORIGIN.as_str(), "https://foo.bar.xyz");
let method_header = Header::new(
hyper::header::ACCESS_CONTROL_REQUEST_METHOD.as_str(),
hyper::Method::GET.as_str(),
);
let request_headers = Header::new(
hyper::header::ACCESS_CONTROL_REQUEST_HEADERS.as_str(),
"accept-language, X-Ping",
);
let req = client
.get("/request_headers")
.header(origin_header)
.header(method_header)
.header(request_headers);
let mut response = req.dispatch();
let response = req.dispatch();
assert!(response.status().class().is_success());
let body_str = response
.body()
.and_then(Body::into_string)
.expect("Non-empty body");
let body_str = response.into_string();
let expected_body = r#"https://foo.bar.xyz
GET
X-Ping, accept-language"#;
assert_eq!(expected_body, body_str);
X-Ping, accept-language"#
.to_string();
assert_eq!(body_str, Some(expected_body));
}

View File

@ -1,12 +1,9 @@
//! This crate tests using `rocket_cors` using manual mode
#![feature(proc_macro_hygiene, decl_macro)]
use std::str::FromStr;
use rocket::http::hyper;
use rocket::http::Method;
use rocket::http::{Header, Status};
use rocket::local::Client;
use rocket::response::Body;
use rocket::local::blocking::Client;
use rocket::response::Responder;
use rocket::State;
use rocket::{get, options, routes};
@ -14,14 +11,14 @@ use rocket_cors::*;
/// Using a borrowed `Cors`
#[get("/")]
fn cors(options: State<'_, Cors>) -> impl Responder<'_> {
fn cors(options: State<'_, Cors>) -> impl Responder<'_, '_> {
options
.inner()
.respond_borrowed(|guard| guard.responder("Hello CORS"))
}
#[get("/panic")]
fn panicking_route(options: State<'_, Cors>) -> impl Responder<'_> {
fn panicking_route(options: State<'_, Cors>) -> impl Responder<'_, '_> {
options.inner().respond_borrowed(|_| {
panic!("This route will panic");
})
@ -29,7 +26,7 @@ fn panicking_route(options: State<'_, Cors>) -> impl Responder<'_> {
/// Respond with an owned option instead
#[options("/owned")]
fn owned_options<'r>() -> impl Responder<'r> {
fn owned_options<'r, 'o: 'r>() -> impl Responder<'r, 'o> {
let borrow = make_different_cors_options().to_cors()?;
borrow.respond_owned(|guard| guard.responder("Manual CORS Preflight"))
@ -37,7 +34,7 @@ fn owned_options<'r>() -> impl Responder<'r> {
/// Respond with an owned option instead
#[get("/owned")]
fn owned<'r>() -> impl Responder<'r> {
fn owned<'r, 'o: 'r>() -> impl Responder<'r, 'o> {
let borrow = make_different_cors_options().to_cors()?;
borrow.respond_owned(|guard| guard.responder("Hello CORS Owned"))
@ -47,7 +44,7 @@ fn owned<'r>() -> impl Responder<'r> {
/// `Responder` with String
#[get("/")]
fn responder_string(options: State<'_, Cors>) -> impl Responder<'_> {
fn responder_string(options: State<'_, Cors>) -> impl Responder<'_, '_> {
options
.inner()
.respond_borrowed(|guard| guard.responder("Hello CORS".to_string()))
@ -56,7 +53,10 @@ fn responder_string(options: State<'_, Cors>) -> impl Responder<'_> {
struct TestState;
/// Borrow something else from Rocket with lifetime `'r`
#[get("/")]
fn borrow<'r>(options: State<'r, Cors>, test_state: State<'r, TestState>) -> impl Responder<'r> {
fn borrow<'r, 'o: 'r>(
options: State<'r, Cors>,
test_state: State<'r, TestState>,
) -> impl Responder<'r, 'o> {
let borrow = test_state.inner();
options.inner().respond_borrowed(move |guard| {
let _ = borrow;
@ -101,16 +101,15 @@ fn smoke_test() {
let client = Client::new(rocket()).unwrap();
// `Options` pre-flight checks
let origin_header =
Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap());
let method_header = Header::from(hyper::header::AccessControlRequestMethod(
hyper::Method::Get,
));
let request_headers =
hyper::header::AccessControlRequestHeaders(vec![
FromStr::from_str("Authorization").unwrap()
]);
let request_headers = Header::from(request_headers);
let origin_header = Header::new(hyper::header::ORIGIN.as_str(), "https://www.acme.com");
let method_header = Header::new(
hyper::header::ACCESS_CONTROL_REQUEST_METHOD.as_str(),
hyper::Method::GET.as_str(),
);
let request_headers = Header::new(
hyper::header::ACCESS_CONTROL_REQUEST_HEADERS.as_str(),
"Authorization",
);
let req = client
.options("/")
.header(origin_header)
@ -121,37 +120,34 @@ fn smoke_test() {
assert!(response.status().class().is_success());
// "Actual" request
let origin_header =
Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap());
let origin_header = Header::new(hyper::header::ORIGIN.as_str(), "https://www.acme.com");
let authorization = Header::new("Authorization", "let me in");
let req = client.get("/").header(origin_header).header(authorization);
let mut response = req.dispatch();
let response = req.dispatch();
assert!(response.status().class().is_success());
let body_str = response.body().and_then(Body::into_string);
assert_eq!(body_str, Some("Hello CORS".to_string()));
let origin_header = response
.headers()
.get_one("Access-Control-Allow-Origin")
.expect("to exist");
assert_eq!("https://www.acme.com", origin_header);
let body_str = response.into_string();
assert_eq!(body_str, Some("Hello CORS".to_string()));
}
#[test]
fn cors_options_borrowed_check() {
let client = Client::new(rocket()).unwrap();
let origin_header =
Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap());
let method_header = Header::from(hyper::header::AccessControlRequestMethod(
hyper::Method::Get,
));
let request_headers =
hyper::header::AccessControlRequestHeaders(vec![
FromStr::from_str("Authorization").unwrap()
]);
let request_headers = Header::from(request_headers);
let origin_header = Header::new(hyper::header::ORIGIN.as_str(), "https://www.acme.com");
let method_header = Header::new(
hyper::header::ACCESS_CONTROL_REQUEST_METHOD.as_str(),
hyper::Method::GET.as_str(),
);
let request_headers = Header::new(
hyper::header::ACCESS_CONTROL_REQUEST_HEADERS.as_str(),
"Authorization",
);
let req = client
.options("/")
.header(origin_header)
@ -172,21 +168,19 @@ fn cors_options_borrowed_check() {
fn cors_get_borrowed_check() {
let client = Client::new(rocket()).unwrap();
let origin_header =
Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap());
let origin_header = Header::new(hyper::header::ORIGIN.as_str(), "https://www.acme.com");
let authorization = Header::new("Authorization", "let me in");
let req = client.get("/").header(origin_header).header(authorization);
let mut response = req.dispatch();
let response = req.dispatch();
assert!(response.status().class().is_success());
let body_str = response.body().and_then(Body::into_string);
assert_eq!(body_str, Some("Hello CORS".to_string()));
let origin_header = response
.headers()
.get_one("Access-Control-Allow-Origin")
.expect("to exist");
assert_eq!("https://www.acme.com", origin_header);
let body_str = response.into_string();
assert_eq!(body_str, Some("Hello CORS".to_string()));
}
/// This test is to check that non CORS compliant requests to GET should still work. (i.e. curl)
@ -197,9 +191,9 @@ fn cors_get_no_origin() {
let authorization = Header::new("Authorization", "let me in");
let req = client.get("/").header(authorization);
let mut response = req.dispatch();
let response = req.dispatch();
assert!(response.status().class().is_success());
let body_str = response.body().and_then(Body::into_string);
let body_str = response.into_string();
assert_eq!(body_str, Some("Hello CORS".to_string()));
}
@ -207,16 +201,15 @@ fn cors_get_no_origin() {
fn cors_options_bad_origin() {
let client = Client::new(rocket()).unwrap();
let origin_header =
Header::from(hyper::header::Origin::from_str("https://www.bad-origin.com").unwrap());
let method_header = Header::from(hyper::header::AccessControlRequestMethod(
hyper::Method::Get,
));
let request_headers =
hyper::header::AccessControlRequestHeaders(vec![
FromStr::from_str("Authorization").unwrap()
]);
let request_headers = Header::from(request_headers);
let origin_header = Header::new(hyper::header::ORIGIN.as_str(), "https://www.bad-origin.com");
let method_header = Header::new(
hyper::header::ACCESS_CONTROL_REQUEST_METHOD.as_str(),
hyper::Method::GET.as_str(),
);
let request_headers = Header::new(
hyper::header::ACCESS_CONTROL_REQUEST_HEADERS.as_str(),
"Authorization",
);
let req = client
.options("/")
.header(origin_header)
@ -231,14 +224,14 @@ fn cors_options_bad_origin() {
fn cors_options_missing_origin() {
let client = Client::new(rocket()).unwrap();
let method_header = Header::from(hyper::header::AccessControlRequestMethod(
hyper::Method::Get,
));
let request_headers =
hyper::header::AccessControlRequestHeaders(vec![
FromStr::from_str("Authorization").unwrap()
]);
let request_headers = Header::from(request_headers);
let method_header = Header::new(
hyper::header::ACCESS_CONTROL_REQUEST_METHOD.as_str(),
hyper::Method::GET.as_str(),
);
let request_headers = Header::new(
hyper::header::ACCESS_CONTROL_REQUEST_HEADERS.as_str(),
"Authorization",
);
let req = client
.options("/")
.header(method_header)
@ -256,16 +249,15 @@ fn cors_options_missing_origin() {
fn cors_options_bad_request_method() {
let client = Client::new(rocket()).unwrap();
let origin_header =
Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap());
let method_header = Header::from(hyper::header::AccessControlRequestMethod(
hyper::Method::Post,
));
let request_headers =
hyper::header::AccessControlRequestHeaders(vec![
FromStr::from_str("Authorization").unwrap()
]);
let request_headers = Header::from(request_headers);
let origin_header = Header::new(hyper::header::ORIGIN.as_str(), "https://www.acme.com");
let method_header = Header::new(
hyper::header::ACCESS_CONTROL_REQUEST_METHOD.as_str(),
hyper::Method::POST.as_str(),
);
let request_headers = Header::new(
hyper::header::ACCESS_CONTROL_REQUEST_HEADERS.as_str(),
"Authorization",
);
let req = client
.options("/")
.header(origin_header)
@ -284,14 +276,15 @@ fn cors_options_bad_request_method() {
fn cors_options_bad_request_header() {
let client = Client::new(rocket()).unwrap();
let origin_header =
Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap());
let method_header = Header::from(hyper::header::AccessControlRequestMethod(
hyper::Method::Get,
));
let request_headers =
hyper::header::AccessControlRequestHeaders(vec![FromStr::from_str("Foobar").unwrap()]);
let request_headers = Header::from(request_headers);
let origin_header = Header::new(hyper::header::ORIGIN.as_str(), "https://www.acme.com");
let method_header = Header::new(
hyper::header::ACCESS_CONTROL_REQUEST_METHOD.as_str(),
hyper::Method::GET.as_str(),
);
let request_headers = Header::new(
hyper::header::ACCESS_CONTROL_REQUEST_HEADERS.as_str(),
"Foobar",
);
let req = client
.options("/")
.header(origin_header)
@ -310,8 +303,7 @@ fn cors_options_bad_request_header() {
fn cors_get_bad_origin() {
let client = Client::new(rocket()).unwrap();
let origin_header =
Header::from(hyper::header::Origin::from_str("https://www.bad-origin.com").unwrap());
let origin_header = Header::new(hyper::header::ORIGIN.as_str(), "https://www.bad-origin.com");
let authorization = Header::new("Authorization", "let me in");
let req = client.get("/").header(origin_header).header(authorization);
@ -330,16 +322,15 @@ fn cors_get_bad_origin() {
fn routes_failing_checks_are_not_executed() {
let client = Client::new(rocket()).unwrap();
let origin_header =
Header::from(hyper::header::Origin::from_str("https://www.bad-origin.com").unwrap());
let method_header = Header::from(hyper::header::AccessControlRequestMethod(
hyper::Method::Get,
));
let request_headers =
hyper::header::AccessControlRequestHeaders(vec![
FromStr::from_str("Authorization").unwrap()
]);
let request_headers = Header::from(request_headers);
let origin_header = Header::new(hyper::header::ORIGIN.as_str(), "https://www.bad-origin.com");
let method_header = Header::new(
hyper::header::ACCESS_CONTROL_REQUEST_METHOD.as_str(),
hyper::Method::GET.as_str(),
);
let request_headers = Header::new(
hyper::header::ACCESS_CONTROL_REQUEST_HEADERS.as_str(),
"Authorization",
);
let req = client
.options("/panic")
.header(origin_header)
@ -360,32 +351,31 @@ fn cors_options_owned_check() {
let rocket = rocket();
let client = Client::new(rocket).unwrap();
let origin_header =
Header::from(hyper::header::Origin::from_str("https://www.example.com").unwrap());
let method_header = Header::from(hyper::header::AccessControlRequestMethod(
hyper::Method::Get,
));
let request_headers =
hyper::header::AccessControlRequestHeaders(vec![
FromStr::from_str("Authorization").unwrap()
]);
let request_headers = Header::from(request_headers);
let origin_header = Header::new(hyper::header::ORIGIN.as_str(), "https://www.example.com");
let method_header = Header::new(
hyper::header::ACCESS_CONTROL_REQUEST_METHOD.as_str(),
hyper::Method::GET.as_str(),
);
let request_headers = Header::new(
hyper::header::ACCESS_CONTROL_REQUEST_HEADERS.as_str(),
"Authorization",
);
let req = client
.options("/owned")
.header(origin_header)
.header(method_header)
.header(request_headers);
let mut response = req.dispatch();
let body_str = response.body().and_then(Body::into_string);
let response = req.dispatch();
assert!(response.status().class().is_success());
assert_eq!(body_str, Some("Manual CORS Preflight".to_string()));
let origin_header = response
.headers()
.get_one("Access-Control-Allow-Origin")
.expect("to exist");
assert_eq!("https://www.example.com", origin_header);
let body_str = response.into_string();
assert_eq!(body_str, Some("Manual CORS Preflight".to_string()));
}
/// Owned manual response works
@ -393,22 +383,20 @@ fn cors_options_owned_check() {
fn cors_get_owned_check() {
let client = Client::new(rocket()).unwrap();
let origin_header =
Header::from(hyper::header::Origin::from_str("https://www.example.com").unwrap());
let origin_header = Header::new(hyper::header::ORIGIN.as_str(), "https://www.example.com");
let authorization = Header::new("Authorization", "let me in");
let req = client
.get("/owned")
.header(origin_header)
.header(authorization);
let mut response = req.dispatch();
let response = req.dispatch();
assert!(response.status().class().is_success());
let body_str = response.body().and_then(Body::into_string);
assert_eq!(body_str, Some("Hello CORS Owned".to_string()));
let origin_header = response
.headers()
.get_one("Access-Control-Allow-Origin")
.expect("to exist");
assert_eq!("https://www.example.com", origin_header);
let body_str = response.into_string();
assert_eq!(body_str, Some("Hello CORS Owned".to_string()));
}

View File

@ -5,12 +5,9 @@
#![feature(proc_macro_hygiene, decl_macro)]
use rocket_cors;
use std::str::FromStr;
use rocket::http::hyper;
use rocket::http::{Header, Method, Status};
use rocket::local::Client;
use rocket::response::Body;
use rocket::local::blocking::Client;
use rocket::response::Responder;
use rocket::{get, options, routes};
@ -18,13 +15,13 @@ use rocket_cors::{AllowedHeaders, AllowedOrigins, CorsOptions, Guard};
/// The "usual" app route
#[get("/")]
fn app(cors: Guard<'_>) -> rocket_cors::Responder<'_, &str> {
fn app(cors: Guard<'_>) -> rocket_cors::Responder<'_, '_, &str> {
cors.responder("Hello CORS!")
}
/// The special "ping" route
#[get("/ping")]
fn ping<'r>() -> impl Responder<'r> {
fn ping<'r, 'o: 'r>() -> impl Responder<'r, 'o> {
let cors = cors_options_all().to_cors()?;
cors.respond_owned(|guard| guard.responder("Pong!"))
}
@ -33,7 +30,7 @@ fn ping<'r>() -> impl Responder<'r> {
/// that is not in Rocket's managed state.
/// These routes can just return the unit type `()`
#[options("/ping")]
fn ping_options<'r>() -> impl Responder<'r> {
fn ping_options<'r, 'o: 'r>() -> impl Responder<'r, 'o> {
let cors = cors_options_all().to_cors()?;
cors.respond_owned(|guard| guard.responder(()))
}
@ -73,16 +70,15 @@ fn smoke_test() {
let client = Client::new(rocket()).unwrap();
// `Options` pre-flight checks
let origin_header =
Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap());
let method_header = Header::from(hyper::header::AccessControlRequestMethod(
hyper::Method::Get,
));
let request_headers =
hyper::header::AccessControlRequestHeaders(vec![
FromStr::from_str("Authorization").unwrap()
]);
let request_headers = Header::from(request_headers);
let origin_header = Header::new(hyper::header::ORIGIN.as_str(), "https://www.acme.com");
let method_header = Header::new(
hyper::header::ACCESS_CONTROL_REQUEST_METHOD.as_str(),
hyper::Method::GET.as_str(),
);
let request_headers = Header::new(
hyper::header::ACCESS_CONTROL_REQUEST_HEADERS.as_str(),
"Authorization",
);
let req = client
.options("/")
.header(origin_header)
@ -93,37 +89,34 @@ fn smoke_test() {
assert!(response.status().class().is_success());
// "Actual" request
let origin_header =
Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap());
let origin_header = Header::new(hyper::header::ORIGIN.as_str(), "https://www.acme.com");
let authorization = Header::new("Authorization", "let me in");
let req = client.get("/").header(origin_header).header(authorization);
let mut response = req.dispatch();
let response = req.dispatch();
assert!(response.status().class().is_success());
let body_str = response.body().and_then(Body::into_string);
assert_eq!(body_str, Some("Hello CORS!".to_string()));
let origin_header = response
.headers()
.get_one("Access-Control-Allow-Origin")
.expect("to exist");
assert_eq!("https://www.acme.com", origin_header);
let body_str = response.into_string();
assert_eq!(body_str, Some("Hello CORS!".to_string()));
}
#[test]
fn cors_options_check() {
let client = Client::new(rocket()).unwrap();
let origin_header =
Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap());
let method_header = Header::from(hyper::header::AccessControlRequestMethod(
hyper::Method::Get,
));
let request_headers =
hyper::header::AccessControlRequestHeaders(vec![
FromStr::from_str("Authorization").unwrap()
]);
let request_headers = Header::from(request_headers);
let origin_header = Header::new(hyper::header::ORIGIN.as_str(), "https://www.acme.com");
let method_header = Header::new(
hyper::header::ACCESS_CONTROL_REQUEST_METHOD.as_str(),
hyper::Method::GET.as_str(),
);
let request_headers = Header::new(
hyper::header::ACCESS_CONTROL_REQUEST_HEADERS.as_str(),
"Authorization",
);
let req = client
.options("/")
.header(origin_header)
@ -144,21 +137,19 @@ fn cors_options_check() {
fn cors_get_check() {
let client = Client::new(rocket()).unwrap();
let origin_header =
Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap());
let origin_header = Header::new(hyper::header::ORIGIN.as_str(), "https://www.acme.com");
let authorization = Header::new("Authorization", "let me in");
let req = client.get("/").header(origin_header).header(authorization);
let mut response = req.dispatch();
let response = req.dispatch();
assert!(response.status().class().is_success());
let body_str = response.body().and_then(Body::into_string);
assert_eq!(body_str, Some("Hello CORS!".to_string()));
let origin_header = response
.headers()
.get_one("Access-Control-Allow-Origin")
.expect("to exist");
assert_eq!("https://www.acme.com", origin_header);
let body_str = response.into_string();
assert_eq!(body_str, Some("Hello CORS!".to_string()));
}
/// This test is to check that non CORS compliant requests to GET should still work. (i.e. curl)
@ -169,9 +160,9 @@ fn cors_get_no_origin() {
let authorization = Header::new("Authorization", "let me in");
let req = client.get("/").header(authorization);
let mut response = req.dispatch();
let response = req.dispatch();
assert!(response.status().class().is_success());
let body_str = response.body().and_then(Body::into_string);
let body_str = response.into_string();
assert_eq!(body_str, Some("Hello CORS!".to_string()));
}
@ -179,16 +170,15 @@ fn cors_get_no_origin() {
fn cors_options_bad_origin() {
let client = Client::new(rocket()).unwrap();
let origin_header =
Header::from(hyper::header::Origin::from_str("https://www.bad-origin.com").unwrap());
let method_header = Header::from(hyper::header::AccessControlRequestMethod(
hyper::Method::Get,
));
let request_headers =
hyper::header::AccessControlRequestHeaders(vec![
FromStr::from_str("Authorization").unwrap()
]);
let request_headers = Header::from(request_headers);
let origin_header = Header::new(hyper::header::ORIGIN.as_str(), "https://www.bad-origin.com");
let method_header = Header::new(
hyper::header::ACCESS_CONTROL_REQUEST_METHOD.as_str(),
hyper::Method::GET.as_str(),
);
let request_headers = Header::new(
hyper::header::ACCESS_CONTROL_REQUEST_HEADERS.as_str(),
"Authorization",
);
let req = client
.options("/")
.header(origin_header)
@ -203,14 +193,14 @@ fn cors_options_bad_origin() {
fn cors_options_missing_origin() {
let client = Client::new(rocket()).unwrap();
let method_header = Header::from(hyper::header::AccessControlRequestMethod(
hyper::Method::Get,
));
let request_headers =
hyper::header::AccessControlRequestHeaders(vec![
FromStr::from_str("Authorization").unwrap()
]);
let request_headers = Header::from(request_headers);
let method_header = Header::new(
hyper::header::ACCESS_CONTROL_REQUEST_METHOD.as_str(),
hyper::Method::GET.as_str(),
);
let request_headers = Header::new(
hyper::header::ACCESS_CONTROL_REQUEST_HEADERS.as_str(),
"Authorization",
);
let req = client
.options("/")
.header(method_header)
@ -228,16 +218,15 @@ fn cors_options_missing_origin() {
fn cors_options_bad_request_method() {
let client = Client::new(rocket()).unwrap();
let origin_header =
Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap());
let method_header = Header::from(hyper::header::AccessControlRequestMethod(
hyper::Method::Post,
));
let request_headers =
hyper::header::AccessControlRequestHeaders(vec![
FromStr::from_str("Authorization").unwrap()
]);
let request_headers = Header::from(request_headers);
let origin_header = Header::new(hyper::header::ORIGIN.as_str(), "https://www.acme.com");
let method_header = Header::new(
hyper::header::ACCESS_CONTROL_REQUEST_METHOD.as_str(),
hyper::Method::POST.as_str(),
);
let request_headers = Header::new(
hyper::header::ACCESS_CONTROL_REQUEST_HEADERS.as_str(),
"Authorization",
);
let req = client
.options("/")
.header(origin_header)
@ -256,14 +245,15 @@ fn cors_options_bad_request_method() {
fn cors_options_bad_request_header() {
let client = Client::new(rocket()).unwrap();
let origin_header =
Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap());
let method_header = Header::from(hyper::header::AccessControlRequestMethod(
hyper::Method::Get,
));
let request_headers =
hyper::header::AccessControlRequestHeaders(vec![FromStr::from_str("Foobar").unwrap()]);
let request_headers = Header::from(request_headers);
let origin_header = Header::new(hyper::header::ORIGIN.as_str(), "https://www.acme.com");
let method_header = Header::new(
hyper::header::ACCESS_CONTROL_REQUEST_METHOD.as_str(),
hyper::Method::GET.as_str(),
);
let request_headers = Header::new(
hyper::header::ACCESS_CONTROL_REQUEST_HEADERS.as_str(),
"Foobar",
);
let req = client
.options("/")
.header(origin_header)
@ -282,8 +272,7 @@ fn cors_options_bad_request_header() {
fn cors_get_bad_origin() {
let client = Client::new(rocket()).unwrap();
let origin_header =
Header::from(hyper::header::Origin::from_str("https://www.bad-origin.com").unwrap());
let origin_header = Header::new(hyper::header::ORIGIN.as_str(), "https://www.bad-origin.com");
let authorization = Header::new("Authorization", "let me in");
let req = client.get("/").header(origin_header).header(authorization);
@ -300,11 +289,11 @@ fn cors_get_bad_origin() {
fn cors_options_ping_check() {
let client = Client::new(rocket()).unwrap();
let origin_header =
Header::from(hyper::header::Origin::from_str("https://www.example.com").unwrap());
let method_header = Header::from(hyper::header::AccessControlRequestMethod(
hyper::Method::Get,
));
let origin_header = Header::new(hyper::header::ORIGIN.as_str(), "https://www.example.com");
let method_header = Header::new(
hyper::header::ACCESS_CONTROL_REQUEST_METHOD.as_str(),
hyper::Method::GET.as_str(),
);
let req = client
.options("/ping")
@ -326,19 +315,17 @@ fn cors_options_ping_check() {
fn cors_get_ping_check() {
let client = Client::new(rocket()).unwrap();
let origin_header =
Header::from(hyper::header::Origin::from_str("https://www.example.com").unwrap());
let origin_header = Header::new(hyper::header::ORIGIN.as_str(), "https://www.example.com");
let req = client.get("/ping").header(origin_header);
let mut response = req.dispatch();
let response = req.dispatch();
assert!(response.status().class().is_success());
let body_str = response.body().and_then(Body::into_string);
assert_eq!(body_str, Some("Pong!".to_string()));
let origin_header = response
.headers()
.get_one("Access-Control-Allow-Origin")
.expect("to exist");
assert_eq!("https://www.example.com", origin_header);
let body_str = response.into_string();
assert_eq!(body_str, Some("Pong!".to_string()));
}