diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index a38ec65..5057750 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -11,8 +11,8 @@ jobs: strategy: matrix: rust: + - stable - nightly - - nightly-2019-05-21 # MSRV os: - ubuntu-latest - windows-latest @@ -22,6 +22,8 @@ jobs: - "--all-features" - "--no-default-features" + fail-fast: false + runs-on: ${{ matrix.os }} steps: - uses: actions/checkout@v1 @@ -35,9 +37,6 @@ jobs: override: true components: rustfmt, clippy - - name: Remove Rust Toolchain file - run: rm rust-toolchain - - uses: actions-rs/cargo@v1 name: Clippy Lint with: diff --git a/Cargo.toml b/Cargo.toml index 1d112e9..5d8693c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,7 +22,7 @@ serialization = ["serde", "serde_derive", "unicase_serde"] [dependencies] regex = "1.1" -rocket = { version = "0.4.2", default-features = false } +rocket = { git="https://github.com/SergioBenitez/Rocket.git", default-features = false } log = "0.4" unicase = "2.0" url = "2.1.0" @@ -33,7 +33,6 @@ serde_derive = { version = "1.0", optional = true } unicase_serde = { version = "0.1.0", optional = true } [dev-dependencies] -hyper = "0.10" serde_json = "1.0" serde_test = "1.0" diff --git a/examples/fairing.rs b/examples/fairing.rs index 814aa3c..9e33b4b 100644 --- a/examples/fairing.rs +++ b/examples/fairing.rs @@ -1,17 +1,16 @@ -#![feature(proc_macro_hygiene, decl_macro)] -use rocket; -use rocket_cors; +use std::error::Error; use rocket::http::Method; use rocket::{get, routes}; -use rocket_cors::{AllowedHeaders, AllowedOrigins, Error}; +use rocket_cors::{AllowedHeaders, AllowedOrigins}; #[get("/")] fn cors<'a>() -> &'a str { "Hello CORS" } -fn main() -> Result<(), Error> { +#[rocket::main] +async fn main() -> Result<(), Box> { let allowed_origins = AllowedOrigins::some_exact(&["https://www.acme.com"]); // You can also deserialize this @@ -27,7 +26,8 @@ fn main() -> Result<(), Error> { rocket::ignite() .mount("/", routes![cors]) .attach(cors) - .launch(); + .launch() + .await?; Ok(()) } diff --git a/examples/guard.rs b/examples/guard.rs index b8faa39..28f409a 100644 --- a/examples/guard.rs +++ b/examples/guard.rs @@ -1,17 +1,14 @@ -#![feature(proc_macro_hygiene, decl_macro)] -use rocket; -use rocket_cors; - +use std::error::Error; use std::io::Cursor; use rocket::http::Method; use rocket::Response; use rocket::{get, options, routes}; -use rocket_cors::{AllowedHeaders, AllowedOrigins, Error, Guard, Responder}; +use rocket_cors::{AllowedHeaders, AllowedOrigins, Guard, Responder}; /// Using a `Responder` -- the usual way you would use this #[get("/")] -fn responder(cors: Guard<'_>) -> Responder<'_, &str> { +fn responder(cors: Guard<'_>) -> Responder<'_, '_, &str> { cors.responder("Hello CORS!") } @@ -19,23 +16,25 @@ fn responder(cors: Guard<'_>) -> Responder<'_, &str> { #[get("/response")] fn response(cors: Guard<'_>) -> Response<'_> { let mut response = Response::new(); - response.set_sized_body(Cursor::new("Hello CORS!")); + let body = "Hello CORS!"; + response.set_sized_body(body.len(), Cursor::new(body)); cors.response(response) } /// Manually mount an OPTIONS route for your own handling #[options("/manual")] -fn manual_options(cors: Guard<'_>) -> Responder<'_, &str> { +fn manual_options(cors: Guard<'_>) -> Responder<'_, '_, &str> { cors.responder("Manual OPTIONS preflight handling") } /// Manually mount an OPTIONS route for your own handling #[get("/manual")] -fn manual(cors: Guard<'_>) -> Responder<'_, &str> { +fn manual(cors: Guard<'_>) -> Responder<'_, '_, &str> { cors.responder("Manual OPTIONS preflight handling") } -fn main() -> Result<(), Error> { +#[rocket::main] +async fn main() -> Result<(), Box> { let allowed_origins = AllowedOrigins::some_exact(&["https://www.acme.com"]); // You can also deserialize this @@ -55,7 +54,8 @@ fn main() -> Result<(), Error> { // You can also manually mount an OPTIONS route that will be used instead .mount("/", routes![manual, manual_options]) .manage(cors) - .launch(); + .launch() + .await?; Ok(()) } diff --git a/examples/json.rs b/examples/json.rs index 2835f6a..1e49cf1 100644 --- a/examples/json.rs +++ b/examples/json.rs @@ -1,10 +1,7 @@ //! This example is to demonstrate the JSON serialization and deserialization of the Cors settings //! //! Note: This requires the `serialization` feature which is enabled by default. -#![feature(proc_macro_hygiene, decl_macro)] - use rocket_cors as cors; -use serde_json; use crate::cors::{AllowedHeaders, AllowedOrigins, CorsOptions}; use rocket::http::Method; diff --git a/examples/manual.rs b/examples/manual.rs index db83c25..9749664 100644 --- a/examples/manual.rs +++ b/examples/manual.rs @@ -1,9 +1,6 @@ -#![feature(proc_macro_hygiene, decl_macro)] -use rocket; -use rocket_cors; - use std::io::Cursor; +use rocket::error::Error; use rocket::http::Method; use rocket::response::Responder; use rocket::{get, options, routes, Response, State}; @@ -17,7 +14,7 @@ use rocket_cors::{AllowedHeaders, AllowedOrigins, Cors, CorsOptions}; /// Note that the `'r` lifetime annotation is not requred here because `State` borrows with lifetime /// `'r` and so does `Responder`! #[get("/")] -fn borrowed(options: State<'_, Cors>) -> impl Responder<'_> { +fn borrowed(options: State<'_, Cors>) -> impl Responder<'_, '_> { options .inner() .respond_borrowed(|guard| guard.responder("Hello CORS")) @@ -27,9 +24,10 @@ fn borrowed(options: State<'_, Cors>) -> impl Responder<'_> { /// Note that the `'r` lifetime annotation is not requred here because `State` borrows with lifetime /// `'r` and so does `Responder`! #[get("/response")] -fn response(options: State<'_, Cors>) -> impl Responder<'_> { +fn response(options: State<'_, Cors>) -> impl Responder<'_, '_> { let mut response = Response::new(); - response.set_sized_body(Cursor::new("Hello CORS!")); + let body = "Hello CORS!"; + response.set_sized_body(body.len(), Cursor::new(body)); options .inner() @@ -43,7 +41,7 @@ fn response(options: State<'_, Cors>) -> impl Responder<'_> { /// when the settings you want to use for a route is not the same as the rest of the application /// (which you might have put in Rocket's state). #[get("/owned")] -fn owned<'r>() -> impl Responder<'r> { +fn owned<'r, 'o: 'r>() -> impl Responder<'r, 'o> { let options = cors_options().to_cors()?; options.respond_owned(|guard| guard.responder("Hello CORS")) } @@ -53,7 +51,7 @@ fn owned<'r>() -> impl Responder<'r> { /// These routes can just return the unit type `()` /// Note that the `'r` lifetime is needed because the compiler cannot elide anything. #[options("/owned")] -fn owned_options<'r>() -> impl Responder<'r> { +fn owned_options<'r, 'o: 'r>() -> impl Responder<'r, 'o> { let options = cors_options().to_cors()?; options.respond_owned(|guard| guard.responder(())) } @@ -71,10 +69,12 @@ fn cors_options() -> CorsOptions { } } -fn main() { +#[rocket::main] +async fn main() -> Result<(), Error> { rocket::ignite() .mount("/", routes![borrowed, response, owned, owned_options,]) .mount("/", rocket_cors::catch_all_options_routes()) // mount the catch all routes .manage(cors_options().to_cors().expect("To not fail")) - .launch(); + .launch() + .await } diff --git a/examples/mix.rs b/examples/mix.rs index 3f8d350..515bc94 100644 --- a/examples/mix.rs +++ b/examples/mix.rs @@ -3,10 +3,7 @@ //! In this example, you typically have an application wide `Cors` struct except for one specific //! `ping` route that you want to allow all Origins to access. -#![feature(proc_macro_hygiene, decl_macro)] -use rocket; -use rocket_cors; - +use rocket::error::Error; use rocket::http::Method; use rocket::response::Responder; use rocket::{get, options, routes}; @@ -14,13 +11,13 @@ use rocket_cors::{AllowedHeaders, AllowedOrigins, CorsOptions, Guard}; /// The "usual" app route #[get("/")] -fn app(cors: Guard<'_>) -> rocket_cors::Responder<'_, &str> { +fn app(cors: Guard<'_>) -> rocket_cors::Responder<'_, '_, &str> { cors.responder("Hello CORS!") } /// The special "ping" route #[get("/ping")] -fn ping<'r>() -> impl Responder<'r> { +fn ping<'r, 'o: 'r>() -> impl Responder<'r, 'o> { let cors = cors_options_all().to_cors()?; cors.respond_owned(|guard| guard.responder("Pong!")) } @@ -29,7 +26,7 @@ fn ping<'r>() -> impl Responder<'r> { /// that is not in Rocket's managed state. /// These routes can just return the unit type `()` #[options("/ping")] -fn ping_options<'r>() -> impl Responder<'r> { +fn ping_options<'r, 'o: 'r>() -> impl Responder<'r, 'o> { let cors = cors_options_all().to_cors()?; cors.respond_owned(|guard| guard.responder(())) } @@ -57,10 +54,12 @@ fn cors_options_all() -> CorsOptions { Default::default() } -fn main() { +#[rocket::main] +async fn main() -> Result<(), Error> { rocket::ignite() .mount("/", routes![app, ping, ping_options,]) .mount("/", rocket_cors::catch_all_options_routes()) // mount the catch all routes .manage(cors_options().to_cors().expect("To not fail")) - .launch(); + .launch() + .await } diff --git a/rust-toolchain b/rust-toolchain deleted file mode 100644 index bf867e0..0000000 --- a/rust-toolchain +++ /dev/null @@ -1 +0,0 @@ -nightly diff --git a/src/fairing.rs b/src/fairing.rs index 34afaf3..fac4f7a 100644 --- a/src/fairing.rs +++ b/src/fairing.rs @@ -2,7 +2,7 @@ use ::log::{error, info}; use rocket::http::{self, uri::Origin, Status}; -use rocket::{self, error_, info_, log_, Outcome, Request}; +use rocket::{self, error_, info_, log_, outcome::Outcome, Request}; use crate::{ actual_request_response, origin, preflight_response, request_headers, validate, Cors, Error, @@ -14,25 +14,32 @@ enum CorsValidation { Failure, } -/// Route for Fairing error handling -pub(crate) fn fairing_error_route<'r>( - request: &'r Request<'_>, - _: rocket::Data, -) -> rocket::handler::Outcome<'r> { - let status = request - .get_param::(0) - .unwrap_or(Ok(0)) - .unwrap_or_else(|e| { - error_!("Fairing Error Handling Route error: {:?}", e); - 500 - }); - let status = Status::from_code(status).unwrap_or_else(|| Status::InternalServerError); - Outcome::Failure(status) +/// Create a `Handler` for Fairing error handling +#[derive(Clone)] +struct FairingErrorRoute {} + +#[rocket::async_trait] +impl rocket::handler::Handler for FairingErrorRoute { + async fn handle<'r, 's: 'r>( + &'s self, + request: &'r Request<'_>, + _: rocket::Data, + ) -> rocket::handler::Outcome<'r> { + let status = request + .get_param::(0) + .unwrap_or(Ok(0)) + .unwrap_or_else(|e| { + error_!("Fairing Error Handling Route error: {:?}", e); + 500 + }); + let status = Status::from_code(status).unwrap_or_else(|| Status::InternalServerError); + Outcome::Failure(status) + } } /// Create a new `Route` for Fairing handling fn fairing_route(rank: isize) -> rocket::Route { - rocket::Route::ranked(rank, http::Method::Get, "/", fairing_error_route) + rocket::Route::ranked(rank, http::Method::Get, "/", FairingErrorRoute {}) } /// Modifies a `Request` to route to Fairing error handler @@ -90,6 +97,7 @@ fn on_response_wrapper( Ok(()) } +#[rocket::async_trait] impl rocket::fairing::Fairing for Cors { fn info(&self) -> rocket::fairing::Info { rocket::fairing::Info { @@ -100,14 +108,14 @@ impl rocket::fairing::Fairing for Cors { } } - fn on_attach(&self, rocket: rocket::Rocket) -> Result { + async fn on_attach(&self, rocket: rocket::Rocket) -> Result { Ok(rocket.mount( &self.fairing_route_base, vec![fairing_route(self.fairing_route_rank)], )) } - fn on_request(&self, request: &mut Request<'_>, _: &rocket::Data) { + async fn on_request(&self, request: &mut Request<'_>, _: &rocket::Data) { let result = match validate(self, request) { Ok(_) => CorsValidation::Success, Err(err) => { @@ -121,7 +129,7 @@ impl rocket::fairing::Fairing for Cors { let _ = request.local_cache(|| result); } - fn on_response(&self, request: &Request<'_>, response: &mut rocket::Response<'_>) { + async fn on_response<'r>(&self, request: &'r Request<'_>, response: &mut rocket::Response<'r>) { if let Err(err) = on_response_wrapper(self, request, response) { error_!("Fairings on_response error: {}\nMost likely a bug", err); response.set_status(Status::InternalServerError); @@ -133,7 +141,7 @@ impl rocket::fairing::Fairing for Cors { #[cfg(test)] mod tests { use rocket::http::{Method, Status}; - use rocket::local::Client; + use rocket::local::blocking::Client; use rocket::Rocket; use crate::{AllowedHeaders, AllowedOrigins, Cors, CorsOptions}; @@ -161,7 +169,8 @@ mod tests { } #[test] - fn fairing_error_route_returns_passed_in_status() { + #[allow(non_snake_case)] + fn FairingErrorRoute_returns_passed_in_status() { let client = Client::new(rocket(make_cors_options())).expect("to not fail"); let request = client.get(format!("{}/403", CORS_ROOT)); let response = request.dispatch(); @@ -169,19 +178,22 @@ mod tests { } #[test] - fn fairing_error_route_returns_500_for_unknown_status() { + #[allow(non_snake_case)] + fn FairingErrorRoute_returns_500_for_unknown_status() { let client = Client::new(rocket(make_cors_options())).expect("to not fail"); let request = client.get(format!("{}/999", CORS_ROOT)); let response = request.dispatch(); assert_eq!(Status::InternalServerError, response.status()); } - #[test] - fn error_route_is_mounted_on_attach() { - let rocket = rocket(make_cors_options()); + #[rocket::async_test] + async fn error_route_is_mounted_on_attach() { + let mut rocket = rocket(make_cors_options()); let expected_uri = format!("{}/", CORS_ROOT); let error_route = rocket + .inspect() + .await .routes() .find(|r| r.method == Method::Get && r.uri.to_string() == expected_uri); assert!(error_route.is_some()); diff --git a/src/headers.rs b/src/headers.rs index ac80e41..926ad51 100644 --- a/src/headers.rs +++ b/src/headers.rs @@ -7,14 +7,11 @@ use std::str::FromStr; use rocket::http::Status; use rocket::request::{self, FromRequest}; -use rocket::{self, Outcome}; +use rocket::{self, outcome::Outcome}; #[cfg(feature = "serialization")] use serde_derive::{Deserialize, Serialize}; use unicase::UniCase; -#[cfg(feature = "serialization")] -use unicase_serde; - /// A case insensitive header name #[derive(Eq, PartialEq, Clone, Debug, Hash)] #[cfg_attr(feature = "serialization", derive(Serialize, Deserialize))] @@ -91,6 +88,24 @@ impl Origin { Origin::Opaque(_) => false, } } + + /// Derives an instance of `Self` from the incoming request metadata. + /// + /// If the derivation is successful, an outcome of `Success` is returned. If + /// the derivation fails in an unrecoverable fashion, `Failure` is returned. + /// `Forward` is returned to indicate that the request should be forwarded + /// to other matching routes, if any. + pub fn from_request_sync( + request: &'_ rocket::Request<'_>, + ) -> request::Outcome { + match request.headers().get_one("Origin") { + Some(origin) => match Self::from_str(origin) { + Ok(origin) => Outcome::Success(origin), + Err(e) => Outcome::Failure((Status::BadRequest, e)), + }, + None => Outcome::Forward(()), + } + } } impl FromStr for Origin { @@ -118,19 +133,17 @@ impl fmt::Display for Origin { } } +#[rocket::async_trait] impl<'a, 'r> FromRequest<'a, 'r> for Origin { type Error = crate::Error; - fn from_request(request: &'a rocket::Request<'r>) -> request::Outcome { - match request.headers().get_one("Origin") { - Some(origin) => match Self::from_str(origin) { - Ok(origin) => Outcome::Success(origin), - Err(e) => Outcome::Failure((Status::BadRequest, e)), - }, - None => Outcome::Forward(()), - } + async fn from_request( + request: &'a rocket::Request<'r>, + ) -> request::Outcome { + Origin::from_request_sync(request) } } + /// The `Access-Control-Request-Method` request header /// /// You can use this as a rocket [Request Guard](https://rocket.rs/guide/requests/#request-guards) @@ -138,18 +151,16 @@ impl<'a, 'r> FromRequest<'a, 'r> for Origin { #[derive(Debug)] pub struct AccessControlRequestMethod(pub crate::Method); -impl FromStr for AccessControlRequestMethod { - type Err = (); - - fn from_str(method: &str) -> Result { - Ok(AccessControlRequestMethod(crate::Method::from_str(method)?)) - } -} - -impl<'a, 'r> FromRequest<'a, 'r> for AccessControlRequestMethod { - type Error = crate::Error; - - fn from_request(request: &'a rocket::Request<'r>) -> request::Outcome { +impl AccessControlRequestMethod { + /// Derives an instance of `Self` from the incoming request metadata. + /// + /// If the derivation is successful, an outcome of `Success` is returned. If + /// the derivation fails in an unrecoverable fashion, `Failure` is returned. + /// `Forward` is returned to indicate that the request should be forwarded + /// to other matching routes, if any. + pub fn from_request_sync( + request: &'_ rocket::Request<'_>, + ) -> request::Outcome { match request.headers().get_one("Access-Control-Request-Method") { Some(request_method) => match Self::from_str(request_method) { Ok(request_method) => Outcome::Success(request_method), @@ -160,6 +171,25 @@ impl<'a, 'r> FromRequest<'a, 'r> for AccessControlRequestMethod { } } +impl FromStr for AccessControlRequestMethod { + type Err = (); + + fn from_str(method: &str) -> Result { + Ok(AccessControlRequestMethod(crate::Method::from_str(method)?)) + } +} + +#[rocket::async_trait] +impl<'a, 'r> FromRequest<'a, 'r> for AccessControlRequestMethod { + type Error = crate::Error; + + async fn from_request( + request: &'a rocket::Request<'r>, + ) -> request::Outcome { + AccessControlRequestMethod::from_request_sync(request) + } +} + /// The `Access-Control-Request-Headers` request header /// /// You can use this as a rocket [Request Guard](https://rocket.rs/guide/requests/#request-guards) @@ -167,6 +197,28 @@ impl<'a, 'r> FromRequest<'a, 'r> for AccessControlRequestMethod { #[derive(Eq, PartialEq, Debug)] pub struct AccessControlRequestHeaders(pub HeaderFieldNamesSet); +impl AccessControlRequestHeaders { + /// Derives an instance of `Self` from the incoming request metadata. + /// + /// If the derivation is successful, an outcome of `Success` is returned. If + /// the derivation fails in an unrecoverable fashion, `Failure` is returned. + /// `Forward` is returned to indicate that the request should be forwarded + /// to other matching routes, if any. + pub fn from_request_sync( + request: &'_ rocket::Request<'_>, + ) -> request::Outcome { + match request.headers().get_one("Access-Control-Request-Headers") { + Some(request_headers) => match Self::from_str(request_headers) { + Ok(request_headers) => Outcome::Success(request_headers), + Err(()) => { + unreachable!("`AccessControlRequestHeaders::from_str` should never fail") + } + }, + None => Outcome::Forward(()), + } + } +} + /// Will never fail impl FromStr for AccessControlRequestHeaders { type Err = (); @@ -185,19 +237,14 @@ impl FromStr for AccessControlRequestHeaders { } } +#[rocket::async_trait] impl<'a, 'r> FromRequest<'a, 'r> for AccessControlRequestHeaders { type Error = crate::Error; - fn from_request(request: &'a rocket::Request<'r>) -> request::Outcome { - match request.headers().get_one("Access-Control-Request-Headers") { - Some(request_headers) => match Self::from_str(request_headers) { - Ok(request_headers) => Outcome::Success(request_headers), - Err(()) => { - unreachable!("`AccessControlRequestHeaders::from_str` should never fail") - } - }, - None => Outcome::Forward(()), - } + async fn from_request( + request: &'a rocket::Request<'r>, + ) -> request::Outcome { + AccessControlRequestHeaders::from_request_sync(request) } } @@ -205,9 +252,15 @@ impl<'a, 'r> FromRequest<'a, 'r> for AccessControlRequestHeaders { mod tests { use std::str::FromStr; - use hyper; - use rocket; - use rocket::local::Client; + use rocket::http::hyper; + use rocket::http::Header; + use rocket::local::blocking::Client; + + static ORIGIN: hyper::HeaderName = hyper::header::ORIGIN; + static ACCESS_CONTROL_REQUEST_METHOD: hyper::HeaderName = + hyper::header::ACCESS_CONTROL_REQUEST_METHOD; + static ACCESS_CONTROL_REQUEST_HEADERS: hyper::HeaderName = + hyper::header::ACCESS_CONTROL_REQUEST_HEADERS; use super::*; @@ -277,11 +330,10 @@ mod tests { let client = make_client(); let mut request = client.get("/"); - let origin = hyper::header::Origin::new("https", "www.example.com", None); + let origin = Header::new(ORIGIN.as_str(), "https://www.example.com"); request.add_header(origin); - let outcome: request::Outcome = - FromRequest::from_request(request.inner()); + let outcome = Origin::from_request_sync(request.inner()); let parsed_header = assert_matches!(outcome, Outcome::Success(s), s); assert_eq!( "https://www.example.com", @@ -313,10 +365,12 @@ mod tests { fn request_method_parsing() { let client = make_client(); let mut request = client.get("/"); - let method = hyper::header::AccessControlRequestMethod(hyper::method::Method::Get); + let method = Header::new( + ACCESS_CONTROL_REQUEST_METHOD.as_str(), + hyper::Method::GET.as_str(), + ); request.add_header(method); - let outcome: request::Outcome = - FromRequest::from_request(request.inner()); + let outcome = AccessControlRequestMethod::from_request_sync(request.inner()); let parsed_header = assert_matches!(outcome, Outcome::Success(s), s); let AccessControlRequestMethod(parsed_method) = parsed_header; @@ -337,13 +391,12 @@ mod tests { fn request_headers_parsing() { let client = make_client(); let mut request = client.get("/"); - let headers = hyper::header::AccessControlRequestHeaders(vec![ - FromStr::from_str("accept-language").unwrap(), - FromStr::from_str("date").unwrap(), - ]); + let headers = Header::new( + ACCESS_CONTROL_REQUEST_HEADERS.as_str(), + "accept-language, date", + ); request.add_header(headers); - let outcome: request::Outcome = - FromRequest::from_request(request.inner()); + let outcome = AccessControlRequestHeaders::from_request_sync(request.inner()); let parsed_header = assert_matches!(outcome, Outcome::Success(s), s); let AccessControlRequestHeaders(parsed_headers) = parsed_header; diff --git a/src/lib.rs b/src/lib.rs index 57e9413..7731c53 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -261,7 +261,7 @@ See the [example](https://github.com/lawliet89/rocket_cors/blob/master/examples/ missing_debug_implementations, unknown_lints, unsafe_code, - intra_doc_link_resolution_failure + broken_intra_doc_links )] #![doc(test(attr(allow(unused_variables), deny(warnings))))] @@ -285,7 +285,7 @@ use regex::RegexSet; use rocket::http::{self, Status}; use rocket::request::{FromRequest, Request}; use rocket::response; -use rocket::{debug_, error_, info_, log_, Outcome, State}; +use rocket::{debug_, error_, info_, log_, outcome::Outcome, State}; #[cfg(feature = "serialization")] use serde_derive::{Deserialize, Serialize}; @@ -417,8 +417,8 @@ impl error::Error for Error { } } -impl<'r> response::Responder<'r> for Error { - fn respond_to(self, _: &Request<'_>) -> Result, Status> { +impl<'r, 'o: 'r> response::Responder<'r, 'o> for Error { + fn respond_to(self, _: &Request<'_>) -> Result, Status> { error_!("CORS Error: {}", self); Err(self.status()) } @@ -1201,7 +1201,7 @@ impl CorsOptions { } /// Sets the rank of the fairing route - pub fn fairing_route_rank(mut self, fairing_route_rank: isize) -> Self { + pub fn fairing_route_rank(mut self, fairing_route_rank: isize) -> Self { self.fairing_route_rank = fairing_route_rank; self } @@ -1256,10 +1256,13 @@ impl Cors { /// passed in to include the CORS headers. /// /// See the documentation at the [crate root](index.html) for usage information. - pub fn respond_owned<'r, F, R>(self, handler: F) -> Result, Error> + pub fn respond_owned<'r, 'o: 'r, F, R>( + self, + handler: F, + ) -> Result, Error> where F: FnOnce(Guard<'r>) -> R + 'r, - R: response::Responder<'r>, + R: response::Responder<'r, 'o>, { Ok(ManualResponder::new(Cow::Owned(self), handler)) } @@ -1276,13 +1279,13 @@ impl Cors { /// passed in to include the CORS headers. /// /// See the documentation at the [crate root](index.html) for usage information. - pub fn respond_borrowed<'r, F, R>( + pub fn respond_borrowed<'r, 'o: 'r, F, R>( &'r self, handler: F, ) -> Result, Error> where F: FnOnce(Guard<'r>) -> R + 'r, - R: response::Responder<'r>, + R: response::Responder<'r, 'o>, { Ok(ManualResponder::new(Cow::Borrowed(self), handler)) } @@ -1375,7 +1378,10 @@ impl Response { /// Consumes the `Response` and return a `Responder` that wraps a /// provided `rocket:response::Responder` with CORS headers - pub fn responder<'r, R: response::Responder<'r>>(self, responder: R) -> Responder<'r, R> { + pub fn responder<'r, 'o: 'r, R: response::Responder<'r, 'o>>( + self, + responder: R, + ) -> Responder<'r, 'o, R> { Responder::new(responder, self) } @@ -1486,7 +1492,7 @@ pub struct Guard<'r> { marker: PhantomData<&'r Response>, } -impl<'r> Guard<'r> { +impl<'r, 'o: 'r> Guard<'r> { fn new(response: Response) -> Self { Self { response, @@ -1496,7 +1502,7 @@ impl<'r> Guard<'r> { /// Consumes the Guard and return a `Responder` that wraps a /// provided `rocket:response::Responder` with CORS headers - pub fn responder>(self, responder: R) -> Responder<'r, R> { + pub fn responder>(self, responder: R) -> Responder<'r, 'o, R> { self.response.responder(responder) } @@ -1509,11 +1515,12 @@ impl<'r> Guard<'r> { } } +#[rocket::async_trait] impl<'a, 'r> FromRequest<'a, 'r> for Guard<'r> { type Error = Error; - fn from_request(request: &'a Request<'r>) -> rocket::request::Outcome { - let options = match request.guard::>() { + async fn from_request(request: &'a Request<'r>) -> rocket::request::Outcome { + let options = match request.guard::>().await { Outcome::Success(options) => options, _ => { let error = Error::MissingCorsInRocketState; @@ -1545,13 +1552,13 @@ impl<'a, 'r> FromRequest<'a, 'r> for Guard<'r> { /// /// See the documentation at the [crate root](index.html) for usage information. #[derive(Debug)] -pub struct Responder<'r, R> { +pub struct Responder<'r, 'o, R> { responder: R, cors_response: Response, - marker: PhantomData>, + marker: PhantomData>, } -impl<'r, R: response::Responder<'r>> Responder<'r, R> { +impl<'r, 'o: 'r, R: response::Responder<'r, 'o>> Responder<'r, 'o, R> { fn new(responder: R, cors_response: Response) -> Self { Self { responder, @@ -1561,15 +1568,17 @@ impl<'r, R: response::Responder<'r>> Responder<'r, R> { } /// Respond to a request - fn respond(self, request: &Request<'_>) -> response::Result<'r> { + fn respond(self, request: &'r Request<'_>) -> response::Result<'o> { let mut response = self.responder.respond_to(request)?; // handle status errors? self.cors_response.merge(&mut response); Ok(response) } } -impl<'r, R: response::Responder<'r>> response::Responder<'r> for Responder<'r, R> { - fn respond_to(self, request: &Request<'_>) -> response::Result<'r> { +impl<'r, 'o: 'r, R: response::Responder<'r, 'o>> response::Responder<'r, 'o> + for Responder<'r, 'o, R> +{ + fn respond_to(self, request: &'r Request<'_>) -> response::Result<'o> { self.respond(request) } } @@ -1583,10 +1592,10 @@ pub struct ManualResponder<'r, F, R> { marker: PhantomData, } -impl<'r, F, R> ManualResponder<'r, F, R> +impl<'r, 'o: 'r, F, R> ManualResponder<'r, F, R> where F: FnOnce(Guard<'r>) -> R + 'r, - R: response::Responder<'r>, + R: response::Responder<'r, 'o>, { /// Create a new manual responder by passing in either a borrowed or owned `Cors` option. /// @@ -1607,12 +1616,12 @@ where } } -impl<'r, F, R> response::Responder<'r> for ManualResponder<'r, F, R> +impl<'r, 'o: 'r, F, R> response::Responder<'r, 'o> for ManualResponder<'r, F, R> where F: FnOnce(Guard<'r>) -> R + 'r, - R: response::Responder<'r>, + R: response::Responder<'r, 'o>, { - fn respond_to(self, request: &Request<'_>) -> response::Result<'r> { + fn respond_to(self, request: &'r Request<'_>) -> response::Result<'o> { let guard = match self.build_guard(request) { Ok(guard) => guard, Err(err) => { @@ -1759,7 +1768,7 @@ fn validate_allowed_headers( /// Gets the `Origin` request header from the request fn origin(request: &Request<'_>) -> Result, Error> { - match Origin::from_request(request) { + match Origin::from_request_sync(request) { Outcome::Forward(()) => Ok(None), Outcome::Success(origin) => Ok(Some(origin)), Outcome::Failure((_, err)) => Err(err), @@ -1768,7 +1777,7 @@ fn origin(request: &Request<'_>) -> Result, Error> { /// Gets the `Access-Control-Request-Method` request header from the request fn request_method(request: &Request<'_>) -> Result, Error> { - match AccessControlRequestMethod::from_request(request) { + match AccessControlRequestMethod::from_request_sync(request) { Outcome::Forward(()) => Ok(None), Outcome::Success(method) => Ok(Some(method)), Outcome::Failure((_, err)) => Err(err), @@ -1777,7 +1786,7 @@ fn request_method(request: &Request<'_>) -> Result) -> Result, Error> { - match AccessControlRequestHeaders::from_request(request) { + match AccessControlRequestHeaders::from_request_sync(request) { Outcome::Forward(()) => Ok(None), Outcome::Success(geaders) => Ok(Some(geaders)), Outcome::Failure((_, err)) => Err(err), @@ -1984,48 +1993,60 @@ pub fn catch_all_options_routes() -> Vec { isize::max_value(), http::Method::Options, "/", - catch_all_options_route_handler, + CatchAllOptionsRouteHandler {}, ), rocket::Route::ranked( isize::max_value(), http::Method::Options, "/", - catch_all_options_route_handler, + CatchAllOptionsRouteHandler {}, ), ] } /// Handler for the "catch all options route" -fn catch_all_options_route_handler<'r>( - request: &'r Request<'_>, - _: rocket::Data, -) -> rocket::handler::Outcome<'r> { - let guard: Guard<'_> = match request.guard() { - Outcome::Success(guard) => guard, - Outcome::Failure((status, _)) => return rocket::handler::Outcome::failure(status), - Outcome::Forward(()) => unreachable!("Should not be reachable"), - }; +#[derive(Clone)] +struct CatchAllOptionsRouteHandler {} - info_!( - "\"Catch all\" handling of CORS `OPTIONS` preflight for request {}", - request - ); +#[rocket::async_trait] +impl rocket::handler::Handler for CatchAllOptionsRouteHandler { + async fn handle<'r, 's: 'r>( + &'s self, + request: &'r Request<'_>, + _: rocket::Data, + ) -> rocket::handler::Outcome<'r> { + let guard: Guard<'_> = match request.guard().await { + Outcome::Success(guard) => guard, + Outcome::Failure((status, _)) => return rocket::handler::Outcome::failure(status), + Outcome::Forward(()) => unreachable!("Should not be reachable"), + }; - rocket::handler::Outcome::from(request, guard.responder(())) + info_!( + "\"Catch all\" handling of CORS `OPTIONS` preflight for request {}", + request + ); + + rocket::handler::Outcome::from(request, guard.responder(())) + } } #[cfg(test)] mod tests { use std::str::FromStr; + use rocket::http::hyper; use rocket::http::Header; - use rocket::local::Client; - #[cfg(feature = "serialization")] - use serde_json; + use rocket::local::blocking::Client; use super::*; use crate::http::Method; + static ORIGIN: hyper::HeaderName = hyper::header::ORIGIN; + static ACCESS_CONTROL_REQUEST_METHOD: hyper::HeaderName = + hyper::header::ACCESS_CONTROL_REQUEST_METHOD; + static ACCESS_CONTROL_REQUEST_HEADERS: hyper::HeaderName = + hyper::header::ACCESS_CONTROL_REQUEST_HEADERS; + fn to_parsed_origin>(origin: S) -> Result { Origin::from_str(origin.as_ref()) } @@ -2083,10 +2104,20 @@ mod tests { let allowed_origins = AllowedOrigins::some_exact(&["https://www.acme.com"]); let cors_options_from_builder = CorsOptions::default() .allowed_origins(allowed_origins) - .allowed_methods(vec![http::Method::Get].into_iter().map(From::from).collect()) + .allowed_methods( + vec![http::Method::Get] + .into_iter() + .map(From::from) + .collect(), + ) .allowed_headers(AllowedHeaders::some(&[&"Authorization", "Accept"])) .allow_credentials(true) - .expose_headers(["Content-Type", "X-Custom"].iter().map(|s| (*s).to_string()).collect()); + .expose_headers( + ["Content-Type", "X-Custom"] + .iter() + .map(|s| (*s).to_string()) + .collect(), + ); assert_eq!(cors_options_from_builder, make_cors_options()); } @@ -2507,11 +2538,12 @@ mod tests { fn response_build_removes_existing_cors_headers_and_keeps_others() { use std::io::Cursor; + let body = "Brewing the best coffee!"; let original = response::Response::build() .status(Status::ImATeapot) .raw_header("X-Teapot-Make", "Rocket") .raw_header("Access-Control-Max-Age", "42") - .sized_body(Cursor::new("Brewing the best coffee!")) + .sized_body(body.len(), Cursor::new(body)) .finalize(); let response = Response::new(); @@ -2572,16 +2604,12 @@ mod tests { let cors = make_cors_options().to_cors().expect("To not fail"); let client = make_client(); - let origin_header = - Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap()); - let method_header = Header::from(hyper::header::AccessControlRequestMethod( - hyper::method::Method::Get, - )); - let request_headers = - hyper::header::AccessControlRequestHeaders(vec![ - FromStr::from_str("Authorization").unwrap() - ]); - let request_headers = Header::from(request_headers); + let origin_header = Header::new(ORIGIN.as_str(), "https://www.acme.com"); + let method_header = Header::new( + ACCESS_CONTROL_REQUEST_METHOD.as_str(), + hyper::Method::GET.as_str(), + ); + let request_headers = Header::new(ACCESS_CONTROL_REQUEST_HEADERS.as_str(), "Authorization"); let request = client .options("/") @@ -2607,16 +2635,12 @@ mod tests { let cors = options.to_cors().expect("To not fail"); let client = make_client(); - let origin_header = - Header::from(hyper::header::Origin::from_str("https://www.example.com").unwrap()); - let method_header = Header::from(hyper::header::AccessControlRequestMethod( - hyper::method::Method::Get, - )); - let request_headers = - hyper::header::AccessControlRequestHeaders(vec![ - FromStr::from_str("Authorization").unwrap() - ]); - let request_headers = Header::from(request_headers); + let origin_header = Header::new(ORIGIN.as_str(), "https://www.example.com"); + let method_header = Header::new( + ACCESS_CONTROL_REQUEST_METHOD.as_str(), + hyper::Method::GET.as_str(), + ); + let request_headers = Header::new(ACCESS_CONTROL_REQUEST_HEADERS.as_str(), "Authorization"); let request = client .options("/") @@ -2639,16 +2663,12 @@ mod tests { let cors = make_cors_options().to_cors().expect("To not fail"); let client = make_client(); - let origin_header = - Header::from(hyper::header::Origin::from_str("https://www.example.com").unwrap()); - let method_header = Header::from(hyper::header::AccessControlRequestMethod( - hyper::method::Method::Get, - )); - let request_headers = - hyper::header::AccessControlRequestHeaders(vec![ - FromStr::from_str("Authorization").unwrap() - ]); - let request_headers = Header::from(request_headers); + let origin_header = Header::new(ORIGIN.as_str(), "https://www.example.com"); + let method_header = Header::new( + ACCESS_CONTROL_REQUEST_METHOD.as_str(), + hyper::Method::GET.as_str(), + ); + let request_headers = Header::new(ACCESS_CONTROL_REQUEST_HEADERS.as_str(), "Authorization"); let request = client .options("/") @@ -2665,13 +2685,8 @@ mod tests { let cors = make_cors_options().to_cors().expect("To not fail"); let client = make_client(); - let origin_header = - Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap()); - let request_headers = - hyper::header::AccessControlRequestHeaders(vec![ - FromStr::from_str("Authorization").unwrap() - ]); - let request_headers = Header::from(request_headers); + let origin_header = Header::new(ORIGIN.as_str(), "https://www.acme.com"); + let request_headers = Header::new(ACCESS_CONTROL_REQUEST_HEADERS.as_str(), "Authorization"); let request = client .options("/") @@ -2687,16 +2702,12 @@ mod tests { let cors = make_cors_options().to_cors().expect("To not fail"); let client = make_client(); - let origin_header = - Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap()); - let method_header = Header::from(hyper::header::AccessControlRequestMethod( - hyper::method::Method::Post, - )); - let request_headers = - hyper::header::AccessControlRequestHeaders(vec![ - FromStr::from_str("Authorization").unwrap() - ]); - let request_headers = Header::from(request_headers); + let origin_header = Header::new(ORIGIN.as_str(), "https://www.acme.com"); + let method_header = Header::new( + ACCESS_CONTROL_REQUEST_METHOD.as_str(), + hyper::Method::POST.as_str(), + ); + let request_headers = Header::new(ACCESS_CONTROL_REQUEST_HEADERS.as_str(), "Authorization"); let request = client .options("/") @@ -2713,16 +2724,15 @@ mod tests { let cors = make_cors_options().to_cors().expect("To not fail"); let client = make_client(); - let origin_header = - Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap()); - let method_header = Header::from(hyper::header::AccessControlRequestMethod( - hyper::method::Method::Get, - )); - let request_headers = hyper::header::AccessControlRequestHeaders(vec![ - FromStr::from_str("Authorization").unwrap(), - FromStr::from_str("X-NOT-ALLOWED").unwrap(), - ]); - let request_headers = Header::from(request_headers); + let origin_header = Header::new(ORIGIN.as_str(), "https://www.acme.com"); + let method_header = Header::new( + ACCESS_CONTROL_REQUEST_METHOD.as_str(), + hyper::Method::GET.as_str(), + ); + let request_headers = Header::new( + ACCESS_CONTROL_REQUEST_HEADERS.as_str(), + "Authorization, X-NOT-ALLOWED", + ); let request = client .options("/") @@ -2738,8 +2748,7 @@ mod tests { let cors = make_cors_options().to_cors().expect("To not fail"); let client = make_client(); - let origin_header = - Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap()); + let origin_header = Header::new(ORIGIN.as_str(), "https://www.acme.com"); let request = client.get("/").header(origin_header); let result = validate(&cors, request.inner()).expect("to not fail"); @@ -2757,8 +2766,7 @@ mod tests { let cors = options.to_cors().expect("To not fail"); let client = make_client(); - let origin_header = - Header::from(hyper::header::Origin::from_str("https://www.example.com").unwrap()); + let origin_header = Header::new(ORIGIN.as_str(), "https://www.example.com"); let request = client.get("/").header(origin_header); let result = validate(&cors, request.inner()).expect("to not fail"); @@ -2775,8 +2783,7 @@ mod tests { let cors = make_cors_options().to_cors().expect("To not fail"); let client = make_client(); - let origin_header = - Header::from(hyper::header::Origin::from_str("https://www.example.com").unwrap()); + let origin_header = Header::new(ORIGIN.as_str(), "https://www.example.com"); let request = client.get("/").header(origin_header); let _ = validate(&cors, request.inner()).unwrap(); @@ -2799,16 +2806,12 @@ mod tests { let cors = options.to_cors().expect("To not fail"); let client = make_client(); - let origin_header = - Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap()); - let method_header = Header::from(hyper::header::AccessControlRequestMethod( - hyper::method::Method::Get, - )); - let request_headers = - hyper::header::AccessControlRequestHeaders(vec![ - FromStr::from_str("Authorization").unwrap() - ]); - let request_headers = Header::from(request_headers); + let origin_header = Header::new(ORIGIN.as_str(), "https://www.acme.com"); + let method_header = Header::new( + ACCESS_CONTROL_REQUEST_METHOD.as_str(), + hyper::Method::GET.as_str(), + ); + let request_headers = Header::new(ACCESS_CONTROL_REQUEST_HEADERS.as_str(), "Authorization"); let request = client .options("/") @@ -2839,16 +2842,12 @@ mod tests { let client = make_client(); - let origin_header = - Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap()); - let method_header = Header::from(hyper::header::AccessControlRequestMethod( - hyper::method::Method::Get, - )); - let request_headers = - hyper::header::AccessControlRequestHeaders(vec![ - FromStr::from_str("Authorization").unwrap() - ]); - let request_headers = Header::from(request_headers); + let origin_header = Header::new(ORIGIN.as_str(), "https://www.acme.com"); + let method_header = Header::new( + ACCESS_CONTROL_REQUEST_METHOD.as_str(), + hyper::Method::GET.as_str(), + ); + let request_headers = Header::new(ACCESS_CONTROL_REQUEST_HEADERS.as_str(), "Authorization"); let request = client .options("/") @@ -2879,16 +2878,12 @@ mod tests { let client = make_client(); - let origin_header = - Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap()); - let method_header = Header::from(hyper::header::AccessControlRequestMethod( - hyper::method::Method::Get, - )); - let request_headers = - hyper::header::AccessControlRequestHeaders(vec![ - FromStr::from_str("Authorization").unwrap() - ]); - let request_headers = Header::from(request_headers); + let origin_header = Header::new(ORIGIN.as_str(), "https://www.acme.com"); + let method_header = Header::new( + ACCESS_CONTROL_REQUEST_METHOD.as_str(), + hyper::Method::GET.as_str(), + ); + let request_headers = Header::new(ACCESS_CONTROL_REQUEST_HEADERS.as_str(), "Authorization"); let request = client .options("/") @@ -2914,8 +2909,7 @@ mod tests { let cors = options.to_cors().expect("To not fail"); let client = make_client(); - let origin_header = - Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap()); + let origin_header = Header::new(ORIGIN.as_str(), "https://www.acme.com"); let request = client.get("/").header(origin_header); let response = validate_and_build(&cors, request.inner()).expect("to not fail"); @@ -2937,8 +2931,7 @@ mod tests { let client = make_client(); - let origin_header = - Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap()); + let origin_header = Header::new(ORIGIN.as_str(), "https://www.acme.com"); let request = client.get("/").header(origin_header); let response = validate_and_build(&cors, request.inner()).expect("to not fail"); @@ -2960,8 +2953,7 @@ mod tests { let client = make_client(); - let origin_header = - Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap()); + let origin_header = Header::new(ORIGIN.as_str(), "https://www.acme.com"); let request = client.get("/").header(origin_header); let response = validate_and_build(&cors, request.inner()).expect("to not fail"); diff --git a/tests/fairing.rs b/tests/fairing.rs index 055e015..c791374 100644 --- a/tests/fairing.rs +++ b/tests/fairing.rs @@ -1,23 +1,24 @@ //! This crate tests using `rocket_cors` using Fairings -#![feature(proc_macro_hygiene, decl_macro)] -use hyper; - -use std::str::FromStr; - +use rocket::http::hyper; use rocket::http::Method; use rocket::http::{Header, Status}; -use rocket::local::Client; -use rocket::response::Body; +use rocket::local::blocking::Client; use rocket::{get, routes}; use rocket_cors::*; +static ORIGIN: hyper::HeaderName = hyper::header::ORIGIN; +static ACCESS_CONTROL_REQUEST_METHOD: hyper::HeaderName = + hyper::header::ACCESS_CONTROL_REQUEST_METHOD; +static ACCESS_CONTROL_REQUEST_HEADERS: hyper::HeaderName = + hyper::header::ACCESS_CONTROL_REQUEST_HEADERS; + #[get("/")] fn cors<'a>() -> &'a str { "Hello CORS" } #[get("/panic")] -fn panicking_route() { +fn panicking_route<'a>() -> &'a str { panic!("This route will panic"); } @@ -46,16 +47,12 @@ fn smoke_test() { let client = Client::new(rocket()).unwrap(); // `Options` pre-flight checks - let origin_header = - Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap()); - let method_header = Header::from(hyper::header::AccessControlRequestMethod( - hyper::method::Method::Get, - )); - let request_headers = - hyper::header::AccessControlRequestHeaders(vec![ - FromStr::from_str("Authorization").unwrap() - ]); - let request_headers = Header::from(request_headers); + let origin_header = Header::new(ORIGIN.as_str(), "https://www.acme.com"); + let method_header = Header::new( + ACCESS_CONTROL_REQUEST_METHOD.as_str(), + hyper::Method::GET.as_str(), + ); + let request_headers = Header::new(ACCESS_CONTROL_REQUEST_HEADERS.as_str(), "Authorization"); let req = client .options("/") .header(origin_header) @@ -66,37 +63,31 @@ fn smoke_test() { assert!(response.status().class().is_success()); // "Actual" request - let origin_header = - Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap()); + let origin_header = Header::new(ORIGIN.as_str(), "https://www.acme.com"); let authorization = Header::new("Authorization", "let me in"); let req = client.get("/").header(origin_header).header(authorization); - let mut response = req.dispatch(); + let response = req.dispatch(); assert!(response.status().class().is_success()); - let body_str = response.body().and_then(Body::into_string); - assert_eq!(body_str, Some("Hello CORS".to_string())); - let origin_header = response .headers() .get_one("Access-Control-Allow-Origin") .expect("to exist"); assert_eq!("https://www.acme.com", origin_header); + let body_str = response.into_string(); + assert_eq!(body_str, Some("Hello CORS".to_string())); } #[test] fn cors_options_check() { let client = Client::new(rocket()).unwrap(); - let origin_header = - Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap()); - let method_header = Header::from(hyper::header::AccessControlRequestMethod( - hyper::method::Method::Get, - )); - let request_headers = - hyper::header::AccessControlRequestHeaders(vec![ - FromStr::from_str("Authorization").unwrap() - ]); - let request_headers = Header::from(request_headers); + let origin_header = Header::new(ORIGIN.as_str(), "https://www.acme.com"); + let method_header = Header::new( + ACCESS_CONTROL_REQUEST_METHOD.as_str(), + hyper::Method::GET.as_str(), + ); + let request_headers = Header::new(ACCESS_CONTROL_REQUEST_HEADERS.as_str(), "Authorization"); let req = client .options("/") .header(origin_header) @@ -117,21 +108,19 @@ fn cors_options_check() { fn cors_get_check() { let client = Client::new(rocket()).unwrap(); - let origin_header = - Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap()); + let origin_header = Header::new(ORIGIN.as_str(), "https://www.acme.com"); let authorization = Header::new("Authorization", "let me in"); let req = client.get("/").header(origin_header).header(authorization); - let mut response = req.dispatch(); + let response = req.dispatch(); assert!(response.status().class().is_success()); - let body_str = response.body().and_then(Body::into_string); - assert_eq!(body_str, Some("Hello CORS".to_string())); - let origin_header = response .headers() .get_one("Access-Control-Allow-Origin") .expect("to exist"); assert_eq!("https://www.acme.com", origin_header); + let body_str = response.into_string(); + assert_eq!(body_str, Some("Hello CORS".to_string())); } /// This test is to check that non CORS compliant requests to GET should still work. (i.e. curl) @@ -142,9 +131,9 @@ fn cors_get_no_origin() { let authorization = Header::new("Authorization", "let me in"); let req = client.get("/").header(authorization); - let mut response = req.dispatch(); + let response = req.dispatch(); assert!(response.status().class().is_success()); - let body_str = response.body().and_then(Body::into_string); + let body_str = response.into_string(); assert_eq!(body_str, Some("Hello CORS".to_string())); } @@ -152,16 +141,12 @@ fn cors_get_no_origin() { fn cors_options_bad_origin() { let client = Client::new(rocket()).unwrap(); - let origin_header = - Header::from(hyper::header::Origin::from_str("https://www.bad-origin.com").unwrap()); - let method_header = Header::from(hyper::header::AccessControlRequestMethod( - hyper::method::Method::Get, - )); - let request_headers = - hyper::header::AccessControlRequestHeaders(vec![ - FromStr::from_str("Authorization").unwrap() - ]); - let request_headers = Header::from(request_headers); + let origin_header = Header::new(ORIGIN.as_str(), "https://www.bad-origin.com"); + let method_header = Header::new( + ACCESS_CONTROL_REQUEST_METHOD.as_str(), + hyper::Method::GET.as_str(), + ); + let request_headers = Header::new(ACCESS_CONTROL_REQUEST_HEADERS.as_str(), "Authorization"); let req = client .options("/") .header(origin_header) @@ -177,14 +162,11 @@ fn cors_options_bad_origin() { fn cors_options_missing_origin() { let client = Client::new(rocket()).unwrap(); - let method_header = Header::from(hyper::header::AccessControlRequestMethod( - hyper::method::Method::Get, - )); - let request_headers = - hyper::header::AccessControlRequestHeaders(vec![ - FromStr::from_str("Authorization").unwrap() - ]); - let request_headers = Header::from(request_headers); + let method_header = Header::new( + ACCESS_CONTROL_REQUEST_METHOD.as_str(), + hyper::Method::GET.as_str(), + ); + let request_headers = Header::new(ACCESS_CONTROL_REQUEST_HEADERS.as_str(), "Authorization"); let req = client .options("/") .header(method_header) @@ -203,16 +185,12 @@ fn cors_options_missing_origin() { fn cors_options_bad_request_method() { let client = Client::new(rocket()).unwrap(); - let origin_header = - Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap()); - let method_header = Header::from(hyper::header::AccessControlRequestMethod( - hyper::method::Method::Post, - )); - let request_headers = - hyper::header::AccessControlRequestHeaders(vec![ - FromStr::from_str("Authorization").unwrap() - ]); - let request_headers = Header::from(request_headers); + let origin_header = Header::new(ORIGIN.as_str(), "https://www.acme.com"); + let method_header = Header::new( + ACCESS_CONTROL_REQUEST_METHOD.as_str(), + hyper::Method::POST.as_str(), + ); + let request_headers = Header::new(ACCESS_CONTROL_REQUEST_HEADERS.as_str(), "Authorization"); let req = client .options("/") .header(origin_header) @@ -231,14 +209,12 @@ fn cors_options_bad_request_method() { fn cors_options_bad_request_header() { let client = Client::new(rocket()).unwrap(); - let origin_header = - Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap()); - let method_header = Header::from(hyper::header::AccessControlRequestMethod( - hyper::method::Method::Get, - )); - let request_headers = - hyper::header::AccessControlRequestHeaders(vec![FromStr::from_str("Foobar").unwrap()]); - let request_headers = Header::from(request_headers); + let origin_header = Header::new(ORIGIN.as_str(), "https://www.acme.com"); + let method_header = Header::new( + ACCESS_CONTROL_REQUEST_METHOD.as_str(), + hyper::Method::GET.as_str(), + ); + let request_headers = Header::new(ACCESS_CONTROL_REQUEST_HEADERS.as_str(), "Foobar"); let req = client .options("/") .header(origin_header) @@ -257,8 +233,7 @@ fn cors_options_bad_request_header() { fn cors_get_bad_origin() { let client = Client::new(rocket()).unwrap(); - let origin_header = - Header::from(hyper::header::Origin::from_str("https://www.bad-origin.com").unwrap()); + let origin_header = Header::new(ORIGIN.as_str(), "https://www.bad-origin.com"); let authorization = Header::new("Authorization", "let me in"); let req = client.get("/").header(origin_header).header(authorization); @@ -277,16 +252,12 @@ fn cors_get_bad_origin() { fn routes_failing_checks_are_not_executed() { let client = Client::new(rocket()).unwrap(); - let origin_header = - Header::from(hyper::header::Origin::from_str("https://www.bad-origin.com").unwrap()); - let method_header = Header::from(hyper::header::AccessControlRequestMethod( - hyper::method::Method::Get, - )); - let request_headers = - hyper::header::AccessControlRequestHeaders(vec![ - FromStr::from_str("Authorization").unwrap() - ]); - let request_headers = Header::from(request_headers); + let origin_header = Header::new(ORIGIN.as_str(), "https://www.bad-origin.com"); + let method_header = Header::new( + ACCESS_CONTROL_REQUEST_METHOD.as_str(), + hyper::Method::GET.as_str(), + ); + let request_headers = Header::new(ACCESS_CONTROL_REQUEST_HEADERS.as_str(), "Authorization"); let req = client .options("/panic") .header(origin_header) diff --git a/tests/guard.rs b/tests/guard.rs index d929616..3d0c9a9 100644 --- a/tests/guard.rs +++ b/tests/guard.rs @@ -1,36 +1,38 @@ //! This crate tests using `rocket_cors` using the per-route handling with request guard -#![feature(proc_macro_hygiene, decl_macro)] -use hyper; use rocket_cors as cors; -use std::str::FromStr; - +use rocket::http::hyper; use rocket::http::Method; use rocket::http::{Header, Status}; -use rocket::local::Client; -use rocket::response::Body; +use rocket::local::blocking::Client; use rocket::{get, options, routes}; use rocket::{Response, State}; +static ORIGIN: hyper::HeaderName = hyper::header::ORIGIN; +static ACCESS_CONTROL_REQUEST_METHOD: hyper::HeaderName = + hyper::header::ACCESS_CONTROL_REQUEST_METHOD; +static ACCESS_CONTROL_REQUEST_HEADERS: hyper::HeaderName = + hyper::header::ACCESS_CONTROL_REQUEST_HEADERS; + #[get("/")] -fn cors(cors: cors::Guard<'_>) -> cors::Responder<'_, &str> { +fn cors(cors: cors::Guard<'_>) -> cors::Responder<'_, '_, &str> { cors.responder("Hello CORS") } #[get("/panic")] -fn panicking_route(_cors: cors::Guard<'_>) { +fn panicking_route(_cors: cors::Guard<'_>) -> cors::Responder<'_, '_, &str> { panic!("This route will panic"); } /// Manually specify our own OPTIONS route #[options("/manual")] -fn cors_manual_options(cors: cors::Guard<'_>) -> cors::Responder<'_, &str> { +fn cors_manual_options(cors: cors::Guard<'_>) -> cors::Responder<'_, '_, &str> { cors.responder("Manual CORS Preflight") } /// Manually specify our own OPTIONS route #[get("/manual")] -fn cors_manual(cors: cors::Guard<'_>) -> cors::Responder<'_, &str> { +fn cors_manual(cors: cors::Guard<'_>) -> cors::Responder<'_, '_, &str> { cors.responder("Hello CORS") } @@ -42,20 +44,23 @@ fn response(cors: cors::Guard<'_>) -> Response<'_> { /// `Responder` with String #[get("/responder/string")] -fn responder_string(cors: cors::Guard<'_>) -> cors::Responder<'_, String> { +fn responder_string(cors: cors::Guard<'_>) -> cors::Responder<'_, 'static, String> { cors.responder("Hello CORS".to_string()) } /// `Responder` with 'static () #[get("/responder/unit")] -fn responder_unit(cors: cors::Guard<'_>) -> cors::Responder<'_, ()> { +fn responder_unit(cors: cors::Guard<'_>) -> cors::Responder<'_, 'static, ()> { cors.responder(()) } struct SomeState; /// Borrow `SomeState` from Rocket #[get("/state")] -fn state<'r>(cors: cors::Guard<'r>, _state: State<'r, SomeState>) -> cors::Responder<'r, &'r str> { +fn state<'r, 'o: 'r>( + cors: cors::Guard<'r>, + _state: State<'r, SomeState>, +) -> cors::Responder<'r, 'o, &'r str> { cors.responder("hmm") } @@ -92,16 +97,12 @@ fn smoke_test() { let client = Client::new(rocket).unwrap(); // `Options` pre-flight checks - let origin_header = - Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap()); - let method_header = Header::from(hyper::header::AccessControlRequestMethod( - hyper::method::Method::Get, - )); - let request_headers = - hyper::header::AccessControlRequestHeaders(vec![ - FromStr::from_str("Authorization").unwrap() - ]); - let request_headers = Header::from(request_headers); + let origin_header = Header::new(ORIGIN.as_str(), "https://www.acme.com"); + let method_header = Header::new( + ACCESS_CONTROL_REQUEST_METHOD.as_str(), + hyper::Method::GET.as_str(), + ); + let request_headers = Header::new(ACCESS_CONTROL_REQUEST_HEADERS.as_str(), "Authorization"); let req = client .options("/") .header(origin_header) @@ -112,21 +113,19 @@ fn smoke_test() { assert!(response.status().class().is_success()); // "Actual" request - let origin_header = - Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap()); + let origin_header = Header::new(ORIGIN.as_str(), "https://www.acme.com"); let authorization = Header::new("Authorization", "let me in"); let req = client.get("/").header(origin_header).header(authorization); - let mut response = req.dispatch(); + let response = req.dispatch(); assert!(response.status().class().is_success()); - let body_str = response.body().and_then(Body::into_string); - assert_eq!(body_str, Some("Hello CORS".to_string())); - let origin_header = response .headers() .get_one("Access-Control-Allow-Origin") .expect("to exist"); assert_eq!("https://www.acme.com", origin_header); + let body_str = response.into_string(); + assert_eq!(body_str, Some("Hello CORS".to_string())); } /// Check the "catch all" OPTIONS route works for `/` @@ -135,16 +134,12 @@ fn cors_options_catch_all_check() { let rocket = make_rocket(); let client = Client::new(rocket).unwrap(); - let origin_header = - Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap()); - let method_header = Header::from(hyper::header::AccessControlRequestMethod( - hyper::method::Method::Get, - )); - let request_headers = - hyper::header::AccessControlRequestHeaders(vec![ - FromStr::from_str("Authorization").unwrap() - ]); - let request_headers = Header::from(request_headers); + let origin_header = Header::new(ORIGIN.as_str(), "https://www.acme.com"); + let method_header = Header::new( + ACCESS_CONTROL_REQUEST_METHOD.as_str(), + hyper::Method::GET.as_str(), + ); + let request_headers = Header::new(ACCESS_CONTROL_REQUEST_HEADERS.as_str(), "Authorization"); let req = client .options("/") .header(origin_header) @@ -167,16 +162,12 @@ fn cors_options_catch_all_check_other_routes() { let rocket = make_rocket(); let client = Client::new(rocket).unwrap(); - let origin_header = - Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap()); - let method_header = Header::from(hyper::header::AccessControlRequestMethod( - hyper::method::Method::Get, - )); - let request_headers = - hyper::header::AccessControlRequestHeaders(vec![ - FromStr::from_str("Authorization").unwrap() - ]); - let request_headers = Header::from(request_headers); + let origin_header = Header::new(ORIGIN.as_str(), "https://www.acme.com"); + let method_header = Header::new( + ACCESS_CONTROL_REQUEST_METHOD.as_str(), + hyper::Method::GET.as_str(), + ); + let request_headers = Header::new(ACCESS_CONTROL_REQUEST_HEADERS.as_str(), "Authorization"); let req = client .options("/response/unit") .header(origin_header) @@ -198,21 +189,19 @@ fn cors_get_check() { let rocket = make_rocket(); let client = Client::new(rocket).unwrap(); - let origin_header = - Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap()); + let origin_header = Header::new(ORIGIN.as_str(), "https://www.acme.com"); let authorization = Header::new("Authorization", "let me in"); let req = client.get("/").header(origin_header).header(authorization); - let mut response = req.dispatch(); + let response = req.dispatch(); assert!(response.status().class().is_success()); - let body_str = response.body().and_then(Body::into_string); - assert_eq!(body_str, Some("Hello CORS".to_string())); - let origin_header = response .headers() .get_one("Access-Control-Allow-Origin") .expect("to exist"); assert_eq!("https://www.acme.com", origin_header); + let body_str = response.into_string(); + assert_eq!(body_str, Some("Hello CORS".to_string())); } /// This test is to check that non CORS compliant requests to GET should still work. (i.e. curl) @@ -224,14 +213,14 @@ fn cors_get_no_origin() { let authorization = Header::new("Authorization", "let me in"); let req = client.get("/").header(authorization); - let mut response = req.dispatch(); + let response = req.dispatch(); assert!(response.status().class().is_success()); - let body_str = response.body().and_then(Body::into_string); - assert_eq!(body_str, Some("Hello CORS".to_string())); assert!(response .headers() .get_one("Access-Control-Allow-Origin") .is_none()); + let body_str = response.into_string(); + assert_eq!(body_str, Some("Hello CORS".to_string())); } #[test] @@ -239,16 +228,12 @@ fn cors_options_bad_origin() { let rocket = make_rocket(); let client = Client::new(rocket).unwrap(); - let origin_header = - Header::from(hyper::header::Origin::from_str("https://www.bad-origin.com").unwrap()); - let method_header = Header::from(hyper::header::AccessControlRequestMethod( - hyper::method::Method::Get, - )); - let request_headers = - hyper::header::AccessControlRequestHeaders(vec![ - FromStr::from_str("Authorization").unwrap() - ]); - let request_headers = Header::from(request_headers); + let origin_header = Header::new(ORIGIN.as_str(), "https://www.bad-origin.com"); + let method_header = Header::new( + ACCESS_CONTROL_REQUEST_METHOD.as_str(), + hyper::Method::GET.as_str(), + ); + let request_headers = Header::new(ACCESS_CONTROL_REQUEST_HEADERS.as_str(), "Authorization"); let req = client .options("/") .header(origin_header) @@ -268,14 +253,11 @@ fn cors_options_missing_origin() { let rocket = make_rocket(); let client = Client::new(rocket).unwrap(); - let method_header = Header::from(hyper::header::AccessControlRequestMethod( - hyper::method::Method::Get, - )); - let request_headers = - hyper::header::AccessControlRequestHeaders(vec![ - FromStr::from_str("Authorization").unwrap() - ]); - let request_headers = Header::from(request_headers); + let method_header = Header::new( + ACCESS_CONTROL_REQUEST_METHOD.as_str(), + hyper::Method::GET.as_str(), + ); + let request_headers = Header::new(ACCESS_CONTROL_REQUEST_HEADERS.as_str(), "Authorization"); let req = client .options("/") .header(method_header) @@ -294,16 +276,12 @@ fn cors_options_bad_request_method() { let rocket = make_rocket(); let client = Client::new(rocket).unwrap(); - let origin_header = - Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap()); - let method_header = Header::from(hyper::header::AccessControlRequestMethod( - hyper::method::Method::Post, - )); - let request_headers = - hyper::header::AccessControlRequestHeaders(vec![ - FromStr::from_str("Authorization").unwrap() - ]); - let request_headers = Header::from(request_headers); + let origin_header = Header::new(ORIGIN.as_str(), "https://www.acme.com"); + let method_header = Header::new( + ACCESS_CONTROL_REQUEST_METHOD.as_str(), + hyper::Method::POST.as_str(), + ); + let request_headers = Header::new(ACCESS_CONTROL_REQUEST_HEADERS.as_str(), "Authorization"); let req = client .options("/") .header(origin_header) @@ -323,14 +301,12 @@ fn cors_options_bad_request_header() { let rocket = make_rocket(); let client = Client::new(rocket).unwrap(); - let origin_header = - Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap()); - let method_header = Header::from(hyper::header::AccessControlRequestMethod( - hyper::method::Method::Get, - )); - let request_headers = - hyper::header::AccessControlRequestHeaders(vec![FromStr::from_str("Foobar").unwrap()]); - let request_headers = Header::from(request_headers); + let origin_header = Header::new(ORIGIN.as_str(), "https://www.acme.com"); + let method_header = Header::new( + ACCESS_CONTROL_REQUEST_METHOD.as_str(), + hyper::Method::GET.as_str(), + ); + let request_headers = Header::new(ACCESS_CONTROL_REQUEST_HEADERS.as_str(), "Foobar"); let req = client .options("/") .header(origin_header) @@ -350,8 +326,7 @@ fn cors_get_bad_origin() { let rocket = make_rocket(); let client = Client::new(rocket).unwrap(); - let origin_header = - Header::from(hyper::header::Origin::from_str("https://www.bad-origin.com").unwrap()); + let origin_header = Header::new(ORIGIN.as_str(), "https://www.bad-origin.com"); let authorization = Header::new("Authorization", "let me in"); let req = client.get("/").header(origin_header).header(authorization); @@ -371,8 +346,7 @@ fn routes_failing_checks_are_not_executed() { let rocket = make_rocket(); let client = Client::new(rocket).unwrap(); - let origin_header = - Header::from(hyper::header::Origin::from_str("https://www.bad-origin.com").unwrap()); + let origin_header = Header::new(ORIGIN.as_str(), "https://www.bad-origin.com"); let authorization = Header::new("Authorization", "let me in"); let req = client.get("/").header(origin_header).header(authorization); @@ -391,30 +365,25 @@ fn overridden_options_routes_are_used() { let rocket = make_rocket(); let client = Client::new(rocket).unwrap(); - let origin_header = - Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap()); - let method_header = Header::from(hyper::header::AccessControlRequestMethod( - hyper::method::Method::Get, - )); - let request_headers = - hyper::header::AccessControlRequestHeaders(vec![ - FromStr::from_str("Authorization").unwrap() - ]); - let request_headers = Header::from(request_headers); + let origin_header = Header::new(ORIGIN.as_str(), "https://www.acme.com"); + let method_header = Header::new( + ACCESS_CONTROL_REQUEST_METHOD.as_str(), + hyper::Method::GET.as_str(), + ); + let request_headers = Header::new(ACCESS_CONTROL_REQUEST_HEADERS.as_str(), "Authorization"); let req = client .options("/manual") .header(origin_header) .header(method_header) .header(request_headers); - let mut response = req.dispatch(); - let body_str = response.body().and_then(Body::into_string); + let response = req.dispatch(); assert!(response.status().class().is_success()); - assert_eq!(body_str, Some("Manual CORS Preflight".to_string())); - let origin_header = response .headers() .get_one("Access-Control-Allow-Origin") .expect("to exist"); assert_eq!("https://www.acme.com", origin_header); + let body_str = response.into_string(); + assert_eq!(body_str, Some("Manual CORS Preflight".to_string())); } diff --git a/tests/headers.rs b/tests/headers.rs index 603c25c..7becb6c 100644 --- a/tests/headers.rs +++ b/tests/headers.rs @@ -1,16 +1,18 @@ //! This crate tests that all the request headers are parsed correctly in the round trip -#![feature(proc_macro_hygiene, decl_macro)] -use hyper; - use std::ops::Deref; -use std::str::FromStr; +use rocket::http::hyper; use rocket::http::Header; -use rocket::local::Client; -use rocket::response::Body; +use rocket::local::blocking::Client; use rocket::{get, routes}; use rocket_cors::headers::*; +static ORIGIN: hyper::HeaderName = hyper::header::ORIGIN; +static ACCESS_CONTROL_REQUEST_METHOD: hyper::HeaderName = + hyper::header::ACCESS_CONTROL_REQUEST_METHOD; +static ACCESS_CONTROL_REQUEST_HEADERS: hyper::HeaderName = + hyper::header::ACCESS_CONTROL_REQUEST_HEADERS; + #[get("/request_headers")] fn request_headers( origin: Origin, @@ -33,30 +35,27 @@ fn request_headers_round_trip_smoke_test() { let rocket = rocket::ignite().mount("/", routes![request_headers]); let client = Client::new(rocket).expect("A valid Rocket client"); - let origin_header = - Header::from(hyper::header::Origin::from_str("https://foo.bar.xyz").unwrap()); - let method_header = Header::from(hyper::header::AccessControlRequestMethod( - hyper::method::Method::Get, - )); - let request_headers = hyper::header::AccessControlRequestHeaders(vec![ - FromStr::from_str("accept-language").unwrap(), - FromStr::from_str("X-Ping").unwrap(), - ]); - let request_headers = Header::from(request_headers); + let origin_header = Header::new(ORIGIN.as_str(), "https://foo.bar.xyz"); + let method_header = Header::new( + ACCESS_CONTROL_REQUEST_METHOD.as_str(), + hyper::Method::GET.as_str(), + ); + let request_headers = Header::new( + ACCESS_CONTROL_REQUEST_HEADERS.as_str(), + "accept-language, X-Ping", + ); let req = client .get("/request_headers") .header(origin_header) .header(method_header) .header(request_headers); - let mut response = req.dispatch(); + let response = req.dispatch(); assert!(response.status().class().is_success()); - let body_str = response - .body() - .and_then(Body::into_string) - .expect("Non-empty body"); + let body_str = response.into_string(); let expected_body = r#"https://foo.bar.xyz GET -X-Ping, accept-language"#; - assert_eq!(expected_body, body_str); +X-Ping, accept-language"# + .to_string(); + assert_eq!(body_str, Some(expected_body)); } diff --git a/tests/manual.rs b/tests/manual.rs index b7f4147..f01c915 100644 --- a/tests/manual.rs +++ b/tests/manual.rs @@ -1,28 +1,29 @@ //! This crate tests using `rocket_cors` using manual mode -#![feature(proc_macro_hygiene, decl_macro)] -use hyper; - -use std::str::FromStr; - +use rocket::http::hyper; use rocket::http::Method; use rocket::http::{Header, Status}; -use rocket::local::Client; -use rocket::response::Body; +use rocket::local::blocking::Client; use rocket::response::Responder; use rocket::State; use rocket::{get, options, routes}; use rocket_cors::*; +static ORIGIN: hyper::HeaderName = hyper::header::ORIGIN; +static ACCESS_CONTROL_REQUEST_METHOD: hyper::HeaderName = + hyper::header::ACCESS_CONTROL_REQUEST_METHOD; +static ACCESS_CONTROL_REQUEST_HEADERS: hyper::HeaderName = + hyper::header::ACCESS_CONTROL_REQUEST_HEADERS; + /// Using a borrowed `Cors` #[get("/")] -fn cors(options: State<'_, Cors>) -> impl Responder<'_> { +fn cors(options: State<'_, Cors>) -> impl Responder<'_, '_> { options .inner() .respond_borrowed(|guard| guard.responder("Hello CORS")) } #[get("/panic")] -fn panicking_route(options: State<'_, Cors>) -> impl Responder<'_> { +fn panicking_route(options: State<'_, Cors>) -> impl Responder<'_, '_> { options.inner().respond_borrowed(|_| { panic!("This route will panic"); }) @@ -30,7 +31,7 @@ fn panicking_route(options: State<'_, Cors>) -> impl Responder<'_> { /// Respond with an owned option instead #[options("/owned")] -fn owned_options<'r>() -> impl Responder<'r> { +fn owned_options<'r, 'o: 'r>() -> impl Responder<'r, 'o> { let borrow = make_different_cors_options().to_cors()?; borrow.respond_owned(|guard| guard.responder("Manual CORS Preflight")) @@ -38,7 +39,7 @@ fn owned_options<'r>() -> impl Responder<'r> { /// Respond with an owned option instead #[get("/owned")] -fn owned<'r>() -> impl Responder<'r> { +fn owned<'r, 'o: 'r>() -> impl Responder<'r, 'o> { let borrow = make_different_cors_options().to_cors()?; borrow.respond_owned(|guard| guard.responder("Hello CORS Owned")) @@ -48,7 +49,8 @@ fn owned<'r>() -> impl Responder<'r> { /// `Responder` with String #[get("/")] -fn responder_string(options: State<'_, Cors>) -> impl Responder<'_> { +#[allow(dead_code)] +fn responder_string(options: State<'_, Cors>) -> impl Responder<'_, '_> { options .inner() .respond_borrowed(|guard| guard.responder("Hello CORS".to_string())) @@ -57,7 +59,11 @@ fn responder_string(options: State<'_, Cors>) -> impl Responder<'_> { struct TestState; /// Borrow something else from Rocket with lifetime `'r` #[get("/")] -fn borrow<'r>(options: State<'r, Cors>, test_state: State<'r, TestState>) -> impl Responder<'r> { +#[allow(dead_code)] +fn borrow<'r, 'o: 'r>( + options: State<'r, Cors>, + test_state: State<'r, TestState>, +) -> impl Responder<'r, 'o> { let borrow = test_state.inner(); options.inner().respond_borrowed(move |guard| { let _ = borrow; @@ -102,16 +108,12 @@ fn smoke_test() { let client = Client::new(rocket()).unwrap(); // `Options` pre-flight checks - let origin_header = - Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap()); - let method_header = Header::from(hyper::header::AccessControlRequestMethod( - hyper::method::Method::Get, - )); - let request_headers = - hyper::header::AccessControlRequestHeaders(vec![ - FromStr::from_str("Authorization").unwrap() - ]); - let request_headers = Header::from(request_headers); + let origin_header = Header::new(ORIGIN.as_str(), "https://www.acme.com"); + let method_header = Header::new( + ACCESS_CONTROL_REQUEST_METHOD.as_str(), + hyper::Method::GET.as_str(), + ); + let request_headers = Header::new(ACCESS_CONTROL_REQUEST_HEADERS.as_str(), "Authorization"); let req = client .options("/") .header(origin_header) @@ -122,37 +124,31 @@ fn smoke_test() { assert!(response.status().class().is_success()); // "Actual" request - let origin_header = - Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap()); + let origin_header = Header::new(ORIGIN.as_str(), "https://www.acme.com"); let authorization = Header::new("Authorization", "let me in"); let req = client.get("/").header(origin_header).header(authorization); - let mut response = req.dispatch(); + let response = req.dispatch(); assert!(response.status().class().is_success()); - let body_str = response.body().and_then(Body::into_string); - assert_eq!(body_str, Some("Hello CORS".to_string())); - let origin_header = response .headers() .get_one("Access-Control-Allow-Origin") .expect("to exist"); assert_eq!("https://www.acme.com", origin_header); + let body_str = response.into_string(); + assert_eq!(body_str, Some("Hello CORS".to_string())); } #[test] fn cors_options_borrowed_check() { let client = Client::new(rocket()).unwrap(); - let origin_header = - Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap()); - let method_header = Header::from(hyper::header::AccessControlRequestMethod( - hyper::method::Method::Get, - )); - let request_headers = - hyper::header::AccessControlRequestHeaders(vec![ - FromStr::from_str("Authorization").unwrap() - ]); - let request_headers = Header::from(request_headers); + let origin_header = Header::new(ORIGIN.as_str(), "https://www.acme.com"); + let method_header = Header::new( + ACCESS_CONTROL_REQUEST_METHOD.as_str(), + hyper::Method::GET.as_str(), + ); + let request_headers = Header::new(ACCESS_CONTROL_REQUEST_HEADERS.as_str(), "Authorization"); let req = client .options("/") .header(origin_header) @@ -173,21 +169,19 @@ fn cors_options_borrowed_check() { fn cors_get_borrowed_check() { let client = Client::new(rocket()).unwrap(); - let origin_header = - Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap()); + let origin_header = Header::new(ORIGIN.as_str(), "https://www.acme.com"); let authorization = Header::new("Authorization", "let me in"); let req = client.get("/").header(origin_header).header(authorization); - let mut response = req.dispatch(); + let response = req.dispatch(); assert!(response.status().class().is_success()); - let body_str = response.body().and_then(Body::into_string); - assert_eq!(body_str, Some("Hello CORS".to_string())); - let origin_header = response .headers() .get_one("Access-Control-Allow-Origin") .expect("to exist"); assert_eq!("https://www.acme.com", origin_header); + let body_str = response.into_string(); + assert_eq!(body_str, Some("Hello CORS".to_string())); } /// This test is to check that non CORS compliant requests to GET should still work. (i.e. curl) @@ -198,9 +192,9 @@ fn cors_get_no_origin() { let authorization = Header::new("Authorization", "let me in"); let req = client.get("/").header(authorization); - let mut response = req.dispatch(); + let response = req.dispatch(); assert!(response.status().class().is_success()); - let body_str = response.body().and_then(Body::into_string); + let body_str = response.into_string(); assert_eq!(body_str, Some("Hello CORS".to_string())); } @@ -208,16 +202,12 @@ fn cors_get_no_origin() { fn cors_options_bad_origin() { let client = Client::new(rocket()).unwrap(); - let origin_header = - Header::from(hyper::header::Origin::from_str("https://www.bad-origin.com").unwrap()); - let method_header = Header::from(hyper::header::AccessControlRequestMethod( - hyper::method::Method::Get, - )); - let request_headers = - hyper::header::AccessControlRequestHeaders(vec![ - FromStr::from_str("Authorization").unwrap() - ]); - let request_headers = Header::from(request_headers); + let origin_header = Header::new(ORIGIN.as_str(), "https://www.bad-origin.com"); + let method_header = Header::new( + ACCESS_CONTROL_REQUEST_METHOD.as_str(), + hyper::Method::GET.as_str(), + ); + let request_headers = Header::new(ACCESS_CONTROL_REQUEST_HEADERS.as_str(), "Authorization"); let req = client .options("/") .header(origin_header) @@ -232,14 +222,11 @@ fn cors_options_bad_origin() { fn cors_options_missing_origin() { let client = Client::new(rocket()).unwrap(); - let method_header = Header::from(hyper::header::AccessControlRequestMethod( - hyper::method::Method::Get, - )); - let request_headers = - hyper::header::AccessControlRequestHeaders(vec![ - FromStr::from_str("Authorization").unwrap() - ]); - let request_headers = Header::from(request_headers); + let method_header = Header::new( + ACCESS_CONTROL_REQUEST_METHOD.as_str(), + hyper::Method::GET.as_str(), + ); + let request_headers = Header::new(ACCESS_CONTROL_REQUEST_HEADERS.as_str(), "Authorization"); let req = client .options("/") .header(method_header) @@ -257,16 +244,12 @@ fn cors_options_missing_origin() { fn cors_options_bad_request_method() { let client = Client::new(rocket()).unwrap(); - let origin_header = - Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap()); - let method_header = Header::from(hyper::header::AccessControlRequestMethod( - hyper::method::Method::Post, - )); - let request_headers = - hyper::header::AccessControlRequestHeaders(vec![ - FromStr::from_str("Authorization").unwrap() - ]); - let request_headers = Header::from(request_headers); + let origin_header = Header::new(ORIGIN.as_str(), "https://www.acme.com"); + let method_header = Header::new( + ACCESS_CONTROL_REQUEST_METHOD.as_str(), + hyper::Method::POST.as_str(), + ); + let request_headers = Header::new(ACCESS_CONTROL_REQUEST_HEADERS.as_str(), "Authorization"); let req = client .options("/") .header(origin_header) @@ -285,14 +268,12 @@ fn cors_options_bad_request_method() { fn cors_options_bad_request_header() { let client = Client::new(rocket()).unwrap(); - let origin_header = - Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap()); - let method_header = Header::from(hyper::header::AccessControlRequestMethod( - hyper::method::Method::Get, - )); - let request_headers = - hyper::header::AccessControlRequestHeaders(vec![FromStr::from_str("Foobar").unwrap()]); - let request_headers = Header::from(request_headers); + let origin_header = Header::new(ORIGIN.as_str(), "https://www.acme.com"); + let method_header = Header::new( + ACCESS_CONTROL_REQUEST_METHOD.as_str(), + hyper::Method::GET.as_str(), + ); + let request_headers = Header::new(ACCESS_CONTROL_REQUEST_HEADERS.as_str(), "Foobar"); let req = client .options("/") .header(origin_header) @@ -311,8 +292,7 @@ fn cors_options_bad_request_header() { fn cors_get_bad_origin() { let client = Client::new(rocket()).unwrap(); - let origin_header = - Header::from(hyper::header::Origin::from_str("https://www.bad-origin.com").unwrap()); + let origin_header = Header::new(ORIGIN.as_str(), "https://www.bad-origin.com"); let authorization = Header::new("Authorization", "let me in"); let req = client.get("/").header(origin_header).header(authorization); @@ -331,16 +311,12 @@ fn cors_get_bad_origin() { fn routes_failing_checks_are_not_executed() { let client = Client::new(rocket()).unwrap(); - let origin_header = - Header::from(hyper::header::Origin::from_str("https://www.bad-origin.com").unwrap()); - let method_header = Header::from(hyper::header::AccessControlRequestMethod( - hyper::method::Method::Get, - )); - let request_headers = - hyper::header::AccessControlRequestHeaders(vec![ - FromStr::from_str("Authorization").unwrap() - ]); - let request_headers = Header::from(request_headers); + let origin_header = Header::new(ORIGIN.as_str(), "https://www.bad-origin.com"); + let method_header = Header::new( + ACCESS_CONTROL_REQUEST_METHOD.as_str(), + hyper::Method::GET.as_str(), + ); + let request_headers = Header::new(ACCESS_CONTROL_REQUEST_HEADERS.as_str(), "Authorization"); let req = client .options("/panic") .header(origin_header) @@ -361,32 +337,28 @@ fn cors_options_owned_check() { let rocket = rocket(); let client = Client::new(rocket).unwrap(); - let origin_header = - Header::from(hyper::header::Origin::from_str("https://www.example.com").unwrap()); - let method_header = Header::from(hyper::header::AccessControlRequestMethod( - hyper::method::Method::Get, - )); - let request_headers = - hyper::header::AccessControlRequestHeaders(vec![ - FromStr::from_str("Authorization").unwrap() - ]); - let request_headers = Header::from(request_headers); + let origin_header = Header::new(ORIGIN.as_str(), "https://www.example.com"); + let method_header = Header::new( + ACCESS_CONTROL_REQUEST_METHOD.as_str(), + hyper::Method::GET.as_str(), + ); + let request_headers = Header::new(ACCESS_CONTROL_REQUEST_HEADERS.as_str(), "Authorization"); let req = client .options("/owned") .header(origin_header) .header(method_header) .header(request_headers); - let mut response = req.dispatch(); - let body_str = response.body().and_then(Body::into_string); + let response = req.dispatch(); assert!(response.status().class().is_success()); - assert_eq!(body_str, Some("Manual CORS Preflight".to_string())); - let origin_header = response .headers() .get_one("Access-Control-Allow-Origin") .expect("to exist"); assert_eq!("https://www.example.com", origin_header); + + let body_str = response.into_string(); + assert_eq!(body_str, Some("Manual CORS Preflight".to_string())); } /// Owned manual response works @@ -394,22 +366,20 @@ fn cors_options_owned_check() { fn cors_get_owned_check() { let client = Client::new(rocket()).unwrap(); - let origin_header = - Header::from(hyper::header::Origin::from_str("https://www.example.com").unwrap()); + let origin_header = Header::new(ORIGIN.as_str(), "https://www.example.com"); let authorization = Header::new("Authorization", "let me in"); let req = client .get("/owned") .header(origin_header) .header(authorization); - let mut response = req.dispatch(); + let response = req.dispatch(); assert!(response.status().class().is_success()); - let body_str = response.body().and_then(Body::into_string); - assert_eq!(body_str, Some("Hello CORS Owned".to_string())); - let origin_header = response .headers() .get_one("Access-Control-Allow-Origin") .expect("to exist"); assert_eq!("https://www.example.com", origin_header); + let body_str = response.into_string(); + assert_eq!(body_str, Some("Hello CORS Owned".to_string())); } diff --git a/tests/mix.rs b/tests/mix.rs index 08ffbf0..d0bc6e0 100644 --- a/tests/mix.rs +++ b/tests/mix.rs @@ -2,29 +2,29 @@ //! //! In this example, you typically have an application wide `Cors` struct except for one specific //! `ping` route that you want to allow all Origins to access. -#![feature(proc_macro_hygiene, decl_macro)] -use hyper; -use rocket_cors; - -use std::str::FromStr; - +use rocket::http::hyper; use rocket::http::{Header, Method, Status}; -use rocket::local::Client; -use rocket::response::Body; +use rocket::local::blocking::Client; use rocket::response::Responder; use rocket::{get, options, routes}; use rocket_cors::{AllowedHeaders, AllowedOrigins, CorsOptions, Guard}; +static ORIGIN: hyper::HeaderName = hyper::header::ORIGIN; +static ACCESS_CONTROL_REQUEST_METHOD: hyper::HeaderName = + hyper::header::ACCESS_CONTROL_REQUEST_METHOD; +static ACCESS_CONTROL_REQUEST_HEADERS: hyper::HeaderName = + hyper::header::ACCESS_CONTROL_REQUEST_HEADERS; + /// The "usual" app route #[get("/")] -fn app(cors: Guard<'_>) -> rocket_cors::Responder<'_, &str> { +fn app(cors: Guard<'_>) -> rocket_cors::Responder<'_, '_, &str> { cors.responder("Hello CORS!") } /// The special "ping" route #[get("/ping")] -fn ping<'r>() -> impl Responder<'r> { +fn ping<'r, 'o: 'r>() -> impl Responder<'r, 'o> { let cors = cors_options_all().to_cors()?; cors.respond_owned(|guard| guard.responder("Pong!")) } @@ -33,7 +33,7 @@ fn ping<'r>() -> impl Responder<'r> { /// that is not in Rocket's managed state. /// These routes can just return the unit type `()` #[options("/ping")] -fn ping_options<'r>() -> impl Responder<'r> { +fn ping_options<'r, 'o: 'r>() -> impl Responder<'r, 'o> { let cors = cors_options_all().to_cors()?; cors.respond_owned(|guard| guard.responder(())) } @@ -73,16 +73,12 @@ fn smoke_test() { let client = Client::new(rocket()).unwrap(); // `Options` pre-flight checks - let origin_header = - Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap()); - let method_header = Header::from(hyper::header::AccessControlRequestMethod( - hyper::method::Method::Get, - )); - let request_headers = - hyper::header::AccessControlRequestHeaders(vec![ - FromStr::from_str("Authorization").unwrap() - ]); - let request_headers = Header::from(request_headers); + let origin_header = Header::new(ORIGIN.as_str(), "https://www.acme.com"); + let method_header = Header::new( + ACCESS_CONTROL_REQUEST_METHOD.as_str(), + hyper::Method::GET.as_str(), + ); + let request_headers = Header::new(ACCESS_CONTROL_REQUEST_HEADERS.as_str(), "Authorization"); let req = client .options("/") .header(origin_header) @@ -93,37 +89,31 @@ fn smoke_test() { assert!(response.status().class().is_success()); // "Actual" request - let origin_header = - Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap()); + let origin_header = Header::new(ORIGIN.as_str(), "https://www.acme.com"); let authorization = Header::new("Authorization", "let me in"); let req = client.get("/").header(origin_header).header(authorization); - let mut response = req.dispatch(); + let response = req.dispatch(); assert!(response.status().class().is_success()); - let body_str = response.body().and_then(Body::into_string); - assert_eq!(body_str, Some("Hello CORS!".to_string())); - let origin_header = response .headers() .get_one("Access-Control-Allow-Origin") .expect("to exist"); assert_eq!("https://www.acme.com", origin_header); + let body_str = response.into_string(); + assert_eq!(body_str, Some("Hello CORS!".to_string())); } #[test] fn cors_options_check() { let client = Client::new(rocket()).unwrap(); - let origin_header = - Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap()); - let method_header = Header::from(hyper::header::AccessControlRequestMethod( - hyper::method::Method::Get, - )); - let request_headers = - hyper::header::AccessControlRequestHeaders(vec![ - FromStr::from_str("Authorization").unwrap() - ]); - let request_headers = Header::from(request_headers); + let origin_header = Header::new(ORIGIN.as_str(), "https://www.acme.com"); + let method_header = Header::new( + ACCESS_CONTROL_REQUEST_METHOD.as_str(), + hyper::Method::GET.as_str(), + ); + let request_headers = Header::new(ACCESS_CONTROL_REQUEST_HEADERS.as_str(), "Authorization"); let req = client .options("/") .header(origin_header) @@ -144,21 +134,19 @@ fn cors_options_check() { fn cors_get_check() { let client = Client::new(rocket()).unwrap(); - let origin_header = - Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap()); + let origin_header = Header::new(ORIGIN.as_str(), "https://www.acme.com"); let authorization = Header::new("Authorization", "let me in"); let req = client.get("/").header(origin_header).header(authorization); - let mut response = req.dispatch(); + let response = req.dispatch(); assert!(response.status().class().is_success()); - let body_str = response.body().and_then(Body::into_string); - assert_eq!(body_str, Some("Hello CORS!".to_string())); - let origin_header = response .headers() .get_one("Access-Control-Allow-Origin") .expect("to exist"); assert_eq!("https://www.acme.com", origin_header); + let body_str = response.into_string(); + assert_eq!(body_str, Some("Hello CORS!".to_string())); } /// This test is to check that non CORS compliant requests to GET should still work. (i.e. curl) @@ -169,9 +157,9 @@ fn cors_get_no_origin() { let authorization = Header::new("Authorization", "let me in"); let req = client.get("/").header(authorization); - let mut response = req.dispatch(); + let response = req.dispatch(); assert!(response.status().class().is_success()); - let body_str = response.body().and_then(Body::into_string); + let body_str = response.into_string(); assert_eq!(body_str, Some("Hello CORS!".to_string())); } @@ -179,16 +167,12 @@ fn cors_get_no_origin() { fn cors_options_bad_origin() { let client = Client::new(rocket()).unwrap(); - let origin_header = - Header::from(hyper::header::Origin::from_str("https://www.bad-origin.com").unwrap()); - let method_header = Header::from(hyper::header::AccessControlRequestMethod( - hyper::method::Method::Get, - )); - let request_headers = - hyper::header::AccessControlRequestHeaders(vec![ - FromStr::from_str("Authorization").unwrap() - ]); - let request_headers = Header::from(request_headers); + let origin_header = Header::new(ORIGIN.as_str(), "https://www.bad-origin.com"); + let method_header = Header::new( + ACCESS_CONTROL_REQUEST_METHOD.as_str(), + hyper::Method::GET.as_str(), + ); + let request_headers = Header::new(ACCESS_CONTROL_REQUEST_HEADERS.as_str(), "Authorization"); let req = client .options("/") .header(origin_header) @@ -203,14 +187,11 @@ fn cors_options_bad_origin() { fn cors_options_missing_origin() { let client = Client::new(rocket()).unwrap(); - let method_header = Header::from(hyper::header::AccessControlRequestMethod( - hyper::method::Method::Get, - )); - let request_headers = - hyper::header::AccessControlRequestHeaders(vec![ - FromStr::from_str("Authorization").unwrap() - ]); - let request_headers = Header::from(request_headers); + let method_header = Header::new( + ACCESS_CONTROL_REQUEST_METHOD.as_str(), + hyper::Method::GET.as_str(), + ); + let request_headers = Header::new(ACCESS_CONTROL_REQUEST_HEADERS.as_str(), "Authorization"); let req = client .options("/") .header(method_header) @@ -228,16 +209,12 @@ fn cors_options_missing_origin() { fn cors_options_bad_request_method() { let client = Client::new(rocket()).unwrap(); - let origin_header = - Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap()); - let method_header = Header::from(hyper::header::AccessControlRequestMethod( - hyper::method::Method::Post, - )); - let request_headers = - hyper::header::AccessControlRequestHeaders(vec![ - FromStr::from_str("Authorization").unwrap() - ]); - let request_headers = Header::from(request_headers); + let origin_header = Header::new(ORIGIN.as_str(), "https://www.acme.com"); + let method_header = Header::new( + ACCESS_CONTROL_REQUEST_METHOD.as_str(), + hyper::Method::POST.as_str(), + ); + let request_headers = Header::new(ACCESS_CONTROL_REQUEST_HEADERS.as_str(), "Authorization"); let req = client .options("/") .header(origin_header) @@ -256,14 +233,12 @@ fn cors_options_bad_request_method() { fn cors_options_bad_request_header() { let client = Client::new(rocket()).unwrap(); - let origin_header = - Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap()); - let method_header = Header::from(hyper::header::AccessControlRequestMethod( - hyper::method::Method::Get, - )); - let request_headers = - hyper::header::AccessControlRequestHeaders(vec![FromStr::from_str("Foobar").unwrap()]); - let request_headers = Header::from(request_headers); + let origin_header = Header::new(ORIGIN.as_str(), "https://www.acme.com"); + let method_header = Header::new( + ACCESS_CONTROL_REQUEST_METHOD.as_str(), + hyper::Method::GET.as_str(), + ); + let request_headers = Header::new(ACCESS_CONTROL_REQUEST_HEADERS.as_str(), "Foobar"); let req = client .options("/") .header(origin_header) @@ -282,8 +257,7 @@ fn cors_options_bad_request_header() { fn cors_get_bad_origin() { let client = Client::new(rocket()).unwrap(); - let origin_header = - Header::from(hyper::header::Origin::from_str("https://www.bad-origin.com").unwrap()); + let origin_header = Header::new(ORIGIN.as_str(), "https://www.bad-origin.com"); let authorization = Header::new("Authorization", "let me in"); let req = client.get("/").header(origin_header).header(authorization); @@ -300,11 +274,11 @@ fn cors_get_bad_origin() { fn cors_options_ping_check() { let client = Client::new(rocket()).unwrap(); - let origin_header = - Header::from(hyper::header::Origin::from_str("https://www.example.com").unwrap()); - let method_header = Header::from(hyper::header::AccessControlRequestMethod( - hyper::method::Method::Get, - )); + let origin_header = Header::new(ORIGIN.as_str(), "https://www.example.com"); + let method_header = Header::new( + ACCESS_CONTROL_REQUEST_METHOD.as_str(), + hyper::Method::GET.as_str(), + ); let req = client .options("/ping") @@ -326,19 +300,17 @@ fn cors_options_ping_check() { fn cors_get_ping_check() { let client = Client::new(rocket()).unwrap(); - let origin_header = - Header::from(hyper::header::Origin::from_str("https://www.example.com").unwrap()); + let origin_header = Header::new(ORIGIN.as_str(), "https://www.example.com"); let req = client.get("/ping").header(origin_header); - let mut response = req.dispatch(); + let response = req.dispatch(); assert!(response.status().class().is_success()); - let body_str = response.body().and_then(Body::into_string); - assert_eq!(body_str, Some("Pong!".to_string())); - let origin_header = response .headers() .get_one("Access-Control-Allow-Origin") .expect("to exist"); assert_eq!("https://www.example.com", origin_header); + let body_str = response.into_string(); + assert_eq!(body_str, Some("Pong!".to_string())); }