Use async version from rocket's master branch (#81)
* Use hyper re-export from rocket_http This way, the hyper version corresponding to the current rocket version is used for the tests. * Use async version from rocket's master branch * switch rocket version to master branch (use release version once async is available) * adapt code to incorporate changes from rocket and hyper * Make Clippy happy again * Make crate compile on Rust stable Rocket meanwhile works on Rust stable, so there is no reason to be limited to nightly. * Fix GitHub CI build * Use stable branch of Rust instead of broken minimum required nightly version. * Disable fail-fast to reveal all problems at once. * Remove deletion of rust-toolchain file as the file is no longer required/existing. Co-authored-by: Maximilian Köstler <maximilian@koestler.hamburg>
This commit is contained in:
parent
172f423887
commit
fae7ccf9ce
|
@ -11,8 +11,8 @@ jobs:
|
||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
rust:
|
rust:
|
||||||
|
- stable
|
||||||
- nightly
|
- nightly
|
||||||
- nightly-2019-05-21 # MSRV
|
|
||||||
os:
|
os:
|
||||||
- ubuntu-latest
|
- ubuntu-latest
|
||||||
- windows-latest
|
- windows-latest
|
||||||
|
@ -22,6 +22,8 @@ jobs:
|
||||||
- "--all-features"
|
- "--all-features"
|
||||||
- "--no-default-features"
|
- "--no-default-features"
|
||||||
|
|
||||||
|
fail-fast: false
|
||||||
|
|
||||||
runs-on: ${{ matrix.os }}
|
runs-on: ${{ matrix.os }}
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v1
|
- uses: actions/checkout@v1
|
||||||
|
@ -35,9 +37,6 @@ jobs:
|
||||||
override: true
|
override: true
|
||||||
components: rustfmt, clippy
|
components: rustfmt, clippy
|
||||||
|
|
||||||
- name: Remove Rust Toolchain file
|
|
||||||
run: rm rust-toolchain
|
|
||||||
|
|
||||||
- uses: actions-rs/cargo@v1
|
- uses: actions-rs/cargo@v1
|
||||||
name: Clippy Lint
|
name: Clippy Lint
|
||||||
with:
|
with:
|
||||||
|
|
|
@ -22,7 +22,7 @@ serialization = ["serde", "serde_derive", "unicase_serde"]
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
regex = "1.1"
|
regex = "1.1"
|
||||||
rocket = { version = "0.4.2", default-features = false }
|
rocket = { git="https://github.com/SergioBenitez/Rocket.git", default-features = false }
|
||||||
log = "0.4"
|
log = "0.4"
|
||||||
unicase = "2.0"
|
unicase = "2.0"
|
||||||
url = "2.1.0"
|
url = "2.1.0"
|
||||||
|
@ -33,7 +33,6 @@ serde_derive = { version = "1.0", optional = true }
|
||||||
unicase_serde = { version = "0.1.0", optional = true }
|
unicase_serde = { version = "0.1.0", optional = true }
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
hyper = "0.10"
|
|
||||||
serde_json = "1.0"
|
serde_json = "1.0"
|
||||||
serde_test = "1.0"
|
serde_test = "1.0"
|
||||||
|
|
||||||
|
|
|
@ -1,17 +1,16 @@
|
||||||
#![feature(proc_macro_hygiene, decl_macro)]
|
use std::error::Error;
|
||||||
use rocket;
|
|
||||||
use rocket_cors;
|
|
||||||
|
|
||||||
use rocket::http::Method;
|
use rocket::http::Method;
|
||||||
use rocket::{get, routes};
|
use rocket::{get, routes};
|
||||||
use rocket_cors::{AllowedHeaders, AllowedOrigins, Error};
|
use rocket_cors::{AllowedHeaders, AllowedOrigins};
|
||||||
|
|
||||||
#[get("/")]
|
#[get("/")]
|
||||||
fn cors<'a>() -> &'a str {
|
fn cors<'a>() -> &'a str {
|
||||||
"Hello CORS"
|
"Hello CORS"
|
||||||
}
|
}
|
||||||
|
|
||||||
fn main() -> Result<(), Error> {
|
#[rocket::main]
|
||||||
|
async fn main() -> Result<(), Box<dyn Error>> {
|
||||||
let allowed_origins = AllowedOrigins::some_exact(&["https://www.acme.com"]);
|
let allowed_origins = AllowedOrigins::some_exact(&["https://www.acme.com"]);
|
||||||
|
|
||||||
// You can also deserialize this
|
// You can also deserialize this
|
||||||
|
@ -27,7 +26,8 @@ fn main() -> Result<(), Error> {
|
||||||
rocket::ignite()
|
rocket::ignite()
|
||||||
.mount("/", routes![cors])
|
.mount("/", routes![cors])
|
||||||
.attach(cors)
|
.attach(cors)
|
||||||
.launch();
|
.launch()
|
||||||
|
.await?;
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,17 +1,14 @@
|
||||||
#![feature(proc_macro_hygiene, decl_macro)]
|
use std::error::Error;
|
||||||
use rocket;
|
|
||||||
use rocket_cors;
|
|
||||||
|
|
||||||
use std::io::Cursor;
|
use std::io::Cursor;
|
||||||
|
|
||||||
use rocket::http::Method;
|
use rocket::http::Method;
|
||||||
use rocket::Response;
|
use rocket::Response;
|
||||||
use rocket::{get, options, routes};
|
use rocket::{get, options, routes};
|
||||||
use rocket_cors::{AllowedHeaders, AllowedOrigins, Error, Guard, Responder};
|
use rocket_cors::{AllowedHeaders, AllowedOrigins, Guard, Responder};
|
||||||
|
|
||||||
/// Using a `Responder` -- the usual way you would use this
|
/// Using a `Responder` -- the usual way you would use this
|
||||||
#[get("/")]
|
#[get("/")]
|
||||||
fn responder(cors: Guard<'_>) -> Responder<'_, &str> {
|
fn responder(cors: Guard<'_>) -> Responder<'_, '_, &str> {
|
||||||
cors.responder("Hello CORS!")
|
cors.responder("Hello CORS!")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -19,23 +16,25 @@ fn responder(cors: Guard<'_>) -> Responder<'_, &str> {
|
||||||
#[get("/response")]
|
#[get("/response")]
|
||||||
fn response(cors: Guard<'_>) -> Response<'_> {
|
fn response(cors: Guard<'_>) -> Response<'_> {
|
||||||
let mut response = Response::new();
|
let mut response = Response::new();
|
||||||
response.set_sized_body(Cursor::new("Hello CORS!"));
|
let body = "Hello CORS!";
|
||||||
|
response.set_sized_body(body.len(), Cursor::new(body));
|
||||||
cors.response(response)
|
cors.response(response)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Manually mount an OPTIONS route for your own handling
|
/// Manually mount an OPTIONS route for your own handling
|
||||||
#[options("/manual")]
|
#[options("/manual")]
|
||||||
fn manual_options(cors: Guard<'_>) -> Responder<'_, &str> {
|
fn manual_options(cors: Guard<'_>) -> Responder<'_, '_, &str> {
|
||||||
cors.responder("Manual OPTIONS preflight handling")
|
cors.responder("Manual OPTIONS preflight handling")
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Manually mount an OPTIONS route for your own handling
|
/// Manually mount an OPTIONS route for your own handling
|
||||||
#[get("/manual")]
|
#[get("/manual")]
|
||||||
fn manual(cors: Guard<'_>) -> Responder<'_, &str> {
|
fn manual(cors: Guard<'_>) -> Responder<'_, '_, &str> {
|
||||||
cors.responder("Manual OPTIONS preflight handling")
|
cors.responder("Manual OPTIONS preflight handling")
|
||||||
}
|
}
|
||||||
|
|
||||||
fn main() -> Result<(), Error> {
|
#[rocket::main]
|
||||||
|
async fn main() -> Result<(), Box<dyn Error>> {
|
||||||
let allowed_origins = AllowedOrigins::some_exact(&["https://www.acme.com"]);
|
let allowed_origins = AllowedOrigins::some_exact(&["https://www.acme.com"]);
|
||||||
|
|
||||||
// You can also deserialize this
|
// You can also deserialize this
|
||||||
|
@ -55,7 +54,8 @@ fn main() -> Result<(), Error> {
|
||||||
// You can also manually mount an OPTIONS route that will be used instead
|
// You can also manually mount an OPTIONS route that will be used instead
|
||||||
.mount("/", routes![manual, manual_options])
|
.mount("/", routes![manual, manual_options])
|
||||||
.manage(cors)
|
.manage(cors)
|
||||||
.launch();
|
.launch()
|
||||||
|
.await?;
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,10 +1,7 @@
|
||||||
//! This example is to demonstrate the JSON serialization and deserialization of the Cors settings
|
//! This example is to demonstrate the JSON serialization and deserialization of the Cors settings
|
||||||
//!
|
//!
|
||||||
//! Note: This requires the `serialization` feature which is enabled by default.
|
//! Note: This requires the `serialization` feature which is enabled by default.
|
||||||
#![feature(proc_macro_hygiene, decl_macro)]
|
|
||||||
|
|
||||||
use rocket_cors as cors;
|
use rocket_cors as cors;
|
||||||
use serde_json;
|
|
||||||
|
|
||||||
use crate::cors::{AllowedHeaders, AllowedOrigins, CorsOptions};
|
use crate::cors::{AllowedHeaders, AllowedOrigins, CorsOptions};
|
||||||
use rocket::http::Method;
|
use rocket::http::Method;
|
||||||
|
|
|
@ -1,9 +1,6 @@
|
||||||
#![feature(proc_macro_hygiene, decl_macro)]
|
|
||||||
use rocket;
|
|
||||||
use rocket_cors;
|
|
||||||
|
|
||||||
use std::io::Cursor;
|
use std::io::Cursor;
|
||||||
|
|
||||||
|
use rocket::error::Error;
|
||||||
use rocket::http::Method;
|
use rocket::http::Method;
|
||||||
use rocket::response::Responder;
|
use rocket::response::Responder;
|
||||||
use rocket::{get, options, routes, Response, State};
|
use rocket::{get, options, routes, Response, State};
|
||||||
|
@ -17,7 +14,7 @@ use rocket_cors::{AllowedHeaders, AllowedOrigins, Cors, CorsOptions};
|
||||||
/// Note that the `'r` lifetime annotation is not requred here because `State` borrows with lifetime
|
/// Note that the `'r` lifetime annotation is not requred here because `State` borrows with lifetime
|
||||||
/// `'r` and so does `Responder`!
|
/// `'r` and so does `Responder`!
|
||||||
#[get("/")]
|
#[get("/")]
|
||||||
fn borrowed(options: State<'_, Cors>) -> impl Responder<'_> {
|
fn borrowed(options: State<'_, Cors>) -> impl Responder<'_, '_> {
|
||||||
options
|
options
|
||||||
.inner()
|
.inner()
|
||||||
.respond_borrowed(|guard| guard.responder("Hello CORS"))
|
.respond_borrowed(|guard| guard.responder("Hello CORS"))
|
||||||
|
@ -27,9 +24,10 @@ fn borrowed(options: State<'_, Cors>) -> impl Responder<'_> {
|
||||||
/// Note that the `'r` lifetime annotation is not requred here because `State` borrows with lifetime
|
/// Note that the `'r` lifetime annotation is not requred here because `State` borrows with lifetime
|
||||||
/// `'r` and so does `Responder`!
|
/// `'r` and so does `Responder`!
|
||||||
#[get("/response")]
|
#[get("/response")]
|
||||||
fn response(options: State<'_, Cors>) -> impl Responder<'_> {
|
fn response(options: State<'_, Cors>) -> impl Responder<'_, '_> {
|
||||||
let mut response = Response::new();
|
let mut response = Response::new();
|
||||||
response.set_sized_body(Cursor::new("Hello CORS!"));
|
let body = "Hello CORS!";
|
||||||
|
response.set_sized_body(body.len(), Cursor::new(body));
|
||||||
|
|
||||||
options
|
options
|
||||||
.inner()
|
.inner()
|
||||||
|
@ -43,7 +41,7 @@ fn response(options: State<'_, Cors>) -> impl Responder<'_> {
|
||||||
/// when the settings you want to use for a route is not the same as the rest of the application
|
/// when the settings you want to use for a route is not the same as the rest of the application
|
||||||
/// (which you might have put in Rocket's state).
|
/// (which you might have put in Rocket's state).
|
||||||
#[get("/owned")]
|
#[get("/owned")]
|
||||||
fn owned<'r>() -> impl Responder<'r> {
|
fn owned<'r, 'o: 'r>() -> impl Responder<'r, 'o> {
|
||||||
let options = cors_options().to_cors()?;
|
let options = cors_options().to_cors()?;
|
||||||
options.respond_owned(|guard| guard.responder("Hello CORS"))
|
options.respond_owned(|guard| guard.responder("Hello CORS"))
|
||||||
}
|
}
|
||||||
|
@ -53,7 +51,7 @@ fn owned<'r>() -> impl Responder<'r> {
|
||||||
/// These routes can just return the unit type `()`
|
/// These routes can just return the unit type `()`
|
||||||
/// Note that the `'r` lifetime is needed because the compiler cannot elide anything.
|
/// Note that the `'r` lifetime is needed because the compiler cannot elide anything.
|
||||||
#[options("/owned")]
|
#[options("/owned")]
|
||||||
fn owned_options<'r>() -> impl Responder<'r> {
|
fn owned_options<'r, 'o: 'r>() -> impl Responder<'r, 'o> {
|
||||||
let options = cors_options().to_cors()?;
|
let options = cors_options().to_cors()?;
|
||||||
options.respond_owned(|guard| guard.responder(()))
|
options.respond_owned(|guard| guard.responder(()))
|
||||||
}
|
}
|
||||||
|
@ -71,10 +69,12 @@ fn cors_options() -> CorsOptions {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn main() {
|
#[rocket::main]
|
||||||
|
async fn main() -> Result<(), Error> {
|
||||||
rocket::ignite()
|
rocket::ignite()
|
||||||
.mount("/", routes![borrowed, response, owned, owned_options,])
|
.mount("/", routes![borrowed, response, owned, owned_options,])
|
||||||
.mount("/", rocket_cors::catch_all_options_routes()) // mount the catch all routes
|
.mount("/", rocket_cors::catch_all_options_routes()) // mount the catch all routes
|
||||||
.manage(cors_options().to_cors().expect("To not fail"))
|
.manage(cors_options().to_cors().expect("To not fail"))
|
||||||
.launch();
|
.launch()
|
||||||
|
.await
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,10 +3,7 @@
|
||||||
//! In this example, you typically have an application wide `Cors` struct except for one specific
|
//! In this example, you typically have an application wide `Cors` struct except for one specific
|
||||||
//! `ping` route that you want to allow all Origins to access.
|
//! `ping` route that you want to allow all Origins to access.
|
||||||
|
|
||||||
#![feature(proc_macro_hygiene, decl_macro)]
|
use rocket::error::Error;
|
||||||
use rocket;
|
|
||||||
use rocket_cors;
|
|
||||||
|
|
||||||
use rocket::http::Method;
|
use rocket::http::Method;
|
||||||
use rocket::response::Responder;
|
use rocket::response::Responder;
|
||||||
use rocket::{get, options, routes};
|
use rocket::{get, options, routes};
|
||||||
|
@ -14,13 +11,13 @@ use rocket_cors::{AllowedHeaders, AllowedOrigins, CorsOptions, Guard};
|
||||||
|
|
||||||
/// The "usual" app route
|
/// The "usual" app route
|
||||||
#[get("/")]
|
#[get("/")]
|
||||||
fn app(cors: Guard<'_>) -> rocket_cors::Responder<'_, &str> {
|
fn app(cors: Guard<'_>) -> rocket_cors::Responder<'_, '_, &str> {
|
||||||
cors.responder("Hello CORS!")
|
cors.responder("Hello CORS!")
|
||||||
}
|
}
|
||||||
|
|
||||||
/// The special "ping" route
|
/// The special "ping" route
|
||||||
#[get("/ping")]
|
#[get("/ping")]
|
||||||
fn ping<'r>() -> impl Responder<'r> {
|
fn ping<'r, 'o: 'r>() -> impl Responder<'r, 'o> {
|
||||||
let cors = cors_options_all().to_cors()?;
|
let cors = cors_options_all().to_cors()?;
|
||||||
cors.respond_owned(|guard| guard.responder("Pong!"))
|
cors.respond_owned(|guard| guard.responder("Pong!"))
|
||||||
}
|
}
|
||||||
|
@ -29,7 +26,7 @@ fn ping<'r>() -> impl Responder<'r> {
|
||||||
/// that is not in Rocket's managed state.
|
/// that is not in Rocket's managed state.
|
||||||
/// These routes can just return the unit type `()`
|
/// These routes can just return the unit type `()`
|
||||||
#[options("/ping")]
|
#[options("/ping")]
|
||||||
fn ping_options<'r>() -> impl Responder<'r> {
|
fn ping_options<'r, 'o: 'r>() -> impl Responder<'r, 'o> {
|
||||||
let cors = cors_options_all().to_cors()?;
|
let cors = cors_options_all().to_cors()?;
|
||||||
cors.respond_owned(|guard| guard.responder(()))
|
cors.respond_owned(|guard| guard.responder(()))
|
||||||
}
|
}
|
||||||
|
@ -57,10 +54,12 @@ fn cors_options_all() -> CorsOptions {
|
||||||
Default::default()
|
Default::default()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn main() {
|
#[rocket::main]
|
||||||
|
async fn main() -> Result<(), Error> {
|
||||||
rocket::ignite()
|
rocket::ignite()
|
||||||
.mount("/", routes![app, ping, ping_options,])
|
.mount("/", routes![app, ping, ping_options,])
|
||||||
.mount("/", rocket_cors::catch_all_options_routes()) // mount the catch all routes
|
.mount("/", rocket_cors::catch_all_options_routes()) // mount the catch all routes
|
||||||
.manage(cors_options().to_cors().expect("To not fail"))
|
.manage(cors_options().to_cors().expect("To not fail"))
|
||||||
.launch();
|
.launch()
|
||||||
|
.await
|
||||||
}
|
}
|
||||||
|
|
|
@ -1 +0,0 @@
|
||||||
nightly
|
|
|
@ -2,7 +2,7 @@
|
||||||
|
|
||||||
use ::log::{error, info};
|
use ::log::{error, info};
|
||||||
use rocket::http::{self, uri::Origin, Status};
|
use rocket::http::{self, uri::Origin, Status};
|
||||||
use rocket::{self, error_, info_, log_, Outcome, Request};
|
use rocket::{self, error_, info_, log_, outcome::Outcome, Request};
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
actual_request_response, origin, preflight_response, request_headers, validate, Cors, Error,
|
actual_request_response, origin, preflight_response, request_headers, validate, Cors, Error,
|
||||||
|
@ -14,8 +14,14 @@ enum CorsValidation {
|
||||||
Failure,
|
Failure,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Route for Fairing error handling
|
/// Create a `Handler` for Fairing error handling
|
||||||
pub(crate) fn fairing_error_route<'r>(
|
#[derive(Clone)]
|
||||||
|
struct FairingErrorRoute {}
|
||||||
|
|
||||||
|
#[rocket::async_trait]
|
||||||
|
impl rocket::handler::Handler for FairingErrorRoute {
|
||||||
|
async fn handle<'r, 's: 'r>(
|
||||||
|
&'s self,
|
||||||
request: &'r Request<'_>,
|
request: &'r Request<'_>,
|
||||||
_: rocket::Data,
|
_: rocket::Data,
|
||||||
) -> rocket::handler::Outcome<'r> {
|
) -> rocket::handler::Outcome<'r> {
|
||||||
|
@ -29,10 +35,11 @@ pub(crate) fn fairing_error_route<'r>(
|
||||||
let status = Status::from_code(status).unwrap_or_else(|| Status::InternalServerError);
|
let status = Status::from_code(status).unwrap_or_else(|| Status::InternalServerError);
|
||||||
Outcome::Failure(status)
|
Outcome::Failure(status)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Create a new `Route` for Fairing handling
|
/// Create a new `Route` for Fairing handling
|
||||||
fn fairing_route(rank: isize) -> rocket::Route {
|
fn fairing_route(rank: isize) -> rocket::Route {
|
||||||
rocket::Route::ranked(rank, http::Method::Get, "/<status>", fairing_error_route)
|
rocket::Route::ranked(rank, http::Method::Get, "/<status>", FairingErrorRoute {})
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Modifies a `Request` to route to Fairing error handler
|
/// Modifies a `Request` to route to Fairing error handler
|
||||||
|
@ -90,6 +97,7 @@ fn on_response_wrapper(
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[rocket::async_trait]
|
||||||
impl rocket::fairing::Fairing for Cors {
|
impl rocket::fairing::Fairing for Cors {
|
||||||
fn info(&self) -> rocket::fairing::Info {
|
fn info(&self) -> rocket::fairing::Info {
|
||||||
rocket::fairing::Info {
|
rocket::fairing::Info {
|
||||||
|
@ -100,14 +108,14 @@ impl rocket::fairing::Fairing for Cors {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn on_attach(&self, rocket: rocket::Rocket) -> Result<rocket::Rocket, rocket::Rocket> {
|
async fn on_attach(&self, rocket: rocket::Rocket) -> Result<rocket::Rocket, rocket::Rocket> {
|
||||||
Ok(rocket.mount(
|
Ok(rocket.mount(
|
||||||
&self.fairing_route_base,
|
&self.fairing_route_base,
|
||||||
vec![fairing_route(self.fairing_route_rank)],
|
vec![fairing_route(self.fairing_route_rank)],
|
||||||
))
|
))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn on_request(&self, request: &mut Request<'_>, _: &rocket::Data) {
|
async fn on_request(&self, request: &mut Request<'_>, _: &rocket::Data) {
|
||||||
let result = match validate(self, request) {
|
let result = match validate(self, request) {
|
||||||
Ok(_) => CorsValidation::Success,
|
Ok(_) => CorsValidation::Success,
|
||||||
Err(err) => {
|
Err(err) => {
|
||||||
|
@ -121,7 +129,7 @@ impl rocket::fairing::Fairing for Cors {
|
||||||
let _ = request.local_cache(|| result);
|
let _ = request.local_cache(|| result);
|
||||||
}
|
}
|
||||||
|
|
||||||
fn on_response(&self, request: &Request<'_>, response: &mut rocket::Response<'_>) {
|
async fn on_response<'r>(&self, request: &'r Request<'_>, response: &mut rocket::Response<'r>) {
|
||||||
if let Err(err) = on_response_wrapper(self, request, response) {
|
if let Err(err) = on_response_wrapper(self, request, response) {
|
||||||
error_!("Fairings on_response error: {}\nMost likely a bug", err);
|
error_!("Fairings on_response error: {}\nMost likely a bug", err);
|
||||||
response.set_status(Status::InternalServerError);
|
response.set_status(Status::InternalServerError);
|
||||||
|
@ -133,7 +141,7 @@ impl rocket::fairing::Fairing for Cors {
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use rocket::http::{Method, Status};
|
use rocket::http::{Method, Status};
|
||||||
use rocket::local::Client;
|
use rocket::local::blocking::Client;
|
||||||
use rocket::Rocket;
|
use rocket::Rocket;
|
||||||
|
|
||||||
use crate::{AllowedHeaders, AllowedOrigins, Cors, CorsOptions};
|
use crate::{AllowedHeaders, AllowedOrigins, Cors, CorsOptions};
|
||||||
|
@ -161,7 +169,8 @@ mod tests {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn fairing_error_route_returns_passed_in_status() {
|
#[allow(non_snake_case)]
|
||||||
|
fn FairingErrorRoute_returns_passed_in_status() {
|
||||||
let client = Client::new(rocket(make_cors_options())).expect("to not fail");
|
let client = Client::new(rocket(make_cors_options())).expect("to not fail");
|
||||||
let request = client.get(format!("{}/403", CORS_ROOT));
|
let request = client.get(format!("{}/403", CORS_ROOT));
|
||||||
let response = request.dispatch();
|
let response = request.dispatch();
|
||||||
|
@ -169,19 +178,22 @@ mod tests {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn fairing_error_route_returns_500_for_unknown_status() {
|
#[allow(non_snake_case)]
|
||||||
|
fn FairingErrorRoute_returns_500_for_unknown_status() {
|
||||||
let client = Client::new(rocket(make_cors_options())).expect("to not fail");
|
let client = Client::new(rocket(make_cors_options())).expect("to not fail");
|
||||||
let request = client.get(format!("{}/999", CORS_ROOT));
|
let request = client.get(format!("{}/999", CORS_ROOT));
|
||||||
let response = request.dispatch();
|
let response = request.dispatch();
|
||||||
assert_eq!(Status::InternalServerError, response.status());
|
assert_eq!(Status::InternalServerError, response.status());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[rocket::async_test]
|
||||||
fn error_route_is_mounted_on_attach() {
|
async fn error_route_is_mounted_on_attach() {
|
||||||
let rocket = rocket(make_cors_options());
|
let mut rocket = rocket(make_cors_options());
|
||||||
|
|
||||||
let expected_uri = format!("{}/<status>", CORS_ROOT);
|
let expected_uri = format!("{}/<status>", CORS_ROOT);
|
||||||
let error_route = rocket
|
let error_route = rocket
|
||||||
|
.inspect()
|
||||||
|
.await
|
||||||
.routes()
|
.routes()
|
||||||
.find(|r| r.method == Method::Get && r.uri.to_string() == expected_uri);
|
.find(|r| r.method == Method::Get && r.uri.to_string() == expected_uri);
|
||||||
assert!(error_route.is_some());
|
assert!(error_route.is_some());
|
||||||
|
|
151
src/headers.rs
151
src/headers.rs
|
@ -7,14 +7,11 @@ use std::str::FromStr;
|
||||||
|
|
||||||
use rocket::http::Status;
|
use rocket::http::Status;
|
||||||
use rocket::request::{self, FromRequest};
|
use rocket::request::{self, FromRequest};
|
||||||
use rocket::{self, Outcome};
|
use rocket::{self, outcome::Outcome};
|
||||||
#[cfg(feature = "serialization")]
|
#[cfg(feature = "serialization")]
|
||||||
use serde_derive::{Deserialize, Serialize};
|
use serde_derive::{Deserialize, Serialize};
|
||||||
use unicase::UniCase;
|
use unicase::UniCase;
|
||||||
|
|
||||||
#[cfg(feature = "serialization")]
|
|
||||||
use unicase_serde;
|
|
||||||
|
|
||||||
/// A case insensitive header name
|
/// A case insensitive header name
|
||||||
#[derive(Eq, PartialEq, Clone, Debug, Hash)]
|
#[derive(Eq, PartialEq, Clone, Debug, Hash)]
|
||||||
#[cfg_attr(feature = "serialization", derive(Serialize, Deserialize))]
|
#[cfg_attr(feature = "serialization", derive(Serialize, Deserialize))]
|
||||||
|
@ -91,6 +88,24 @@ impl Origin {
|
||||||
Origin::Opaque(_) => false,
|
Origin::Opaque(_) => false,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Derives an instance of `Self` from the incoming request metadata.
|
||||||
|
///
|
||||||
|
/// If the derivation is successful, an outcome of `Success` is returned. If
|
||||||
|
/// the derivation fails in an unrecoverable fashion, `Failure` is returned.
|
||||||
|
/// `Forward` is returned to indicate that the request should be forwarded
|
||||||
|
/// to other matching routes, if any.
|
||||||
|
pub fn from_request_sync(
|
||||||
|
request: &'_ rocket::Request<'_>,
|
||||||
|
) -> request::Outcome<Self, crate::Error> {
|
||||||
|
match request.headers().get_one("Origin") {
|
||||||
|
Some(origin) => match Self::from_str(origin) {
|
||||||
|
Ok(origin) => Outcome::Success(origin),
|
||||||
|
Err(e) => Outcome::Failure((Status::BadRequest, e)),
|
||||||
|
},
|
||||||
|
None => Outcome::Forward(()),
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl FromStr for Origin {
|
impl FromStr for Origin {
|
||||||
|
@ -118,19 +133,17 @@ impl fmt::Display for Origin {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[rocket::async_trait]
|
||||||
impl<'a, 'r> FromRequest<'a, 'r> for Origin {
|
impl<'a, 'r> FromRequest<'a, 'r> for Origin {
|
||||||
type Error = crate::Error;
|
type Error = crate::Error;
|
||||||
|
|
||||||
fn from_request(request: &'a rocket::Request<'r>) -> request::Outcome<Self, crate::Error> {
|
async fn from_request(
|
||||||
match request.headers().get_one("Origin") {
|
request: &'a rocket::Request<'r>,
|
||||||
Some(origin) => match Self::from_str(origin) {
|
) -> request::Outcome<Self, crate::Error> {
|
||||||
Ok(origin) => Outcome::Success(origin),
|
Origin::from_request_sync(request)
|
||||||
Err(e) => Outcome::Failure((Status::BadRequest, e)),
|
|
||||||
},
|
|
||||||
None => Outcome::Forward(()),
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// The `Access-Control-Request-Method` request header
|
/// The `Access-Control-Request-Method` request header
|
||||||
///
|
///
|
||||||
/// You can use this as a rocket [Request Guard](https://rocket.rs/guide/requests/#request-guards)
|
/// You can use this as a rocket [Request Guard](https://rocket.rs/guide/requests/#request-guards)
|
||||||
|
@ -138,18 +151,16 @@ impl<'a, 'r> FromRequest<'a, 'r> for Origin {
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct AccessControlRequestMethod(pub crate::Method);
|
pub struct AccessControlRequestMethod(pub crate::Method);
|
||||||
|
|
||||||
impl FromStr for AccessControlRequestMethod {
|
impl AccessControlRequestMethod {
|
||||||
type Err = ();
|
/// Derives an instance of `Self` from the incoming request metadata.
|
||||||
|
///
|
||||||
fn from_str(method: &str) -> Result<Self, Self::Err> {
|
/// If the derivation is successful, an outcome of `Success` is returned. If
|
||||||
Ok(AccessControlRequestMethod(crate::Method::from_str(method)?))
|
/// the derivation fails in an unrecoverable fashion, `Failure` is returned.
|
||||||
}
|
/// `Forward` is returned to indicate that the request should be forwarded
|
||||||
}
|
/// to other matching routes, if any.
|
||||||
|
pub fn from_request_sync(
|
||||||
impl<'a, 'r> FromRequest<'a, 'r> for AccessControlRequestMethod {
|
request: &'_ rocket::Request<'_>,
|
||||||
type Error = crate::Error;
|
) -> request::Outcome<Self, crate::Error> {
|
||||||
|
|
||||||
fn from_request(request: &'a rocket::Request<'r>) -> request::Outcome<Self, crate::Error> {
|
|
||||||
match request.headers().get_one("Access-Control-Request-Method") {
|
match request.headers().get_one("Access-Control-Request-Method") {
|
||||||
Some(request_method) => match Self::from_str(request_method) {
|
Some(request_method) => match Self::from_str(request_method) {
|
||||||
Ok(request_method) => Outcome::Success(request_method),
|
Ok(request_method) => Outcome::Success(request_method),
|
||||||
|
@ -160,6 +171,25 @@ impl<'a, 'r> FromRequest<'a, 'r> for AccessControlRequestMethod {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl FromStr for AccessControlRequestMethod {
|
||||||
|
type Err = ();
|
||||||
|
|
||||||
|
fn from_str(method: &str) -> Result<Self, Self::Err> {
|
||||||
|
Ok(AccessControlRequestMethod(crate::Method::from_str(method)?))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[rocket::async_trait]
|
||||||
|
impl<'a, 'r> FromRequest<'a, 'r> for AccessControlRequestMethod {
|
||||||
|
type Error = crate::Error;
|
||||||
|
|
||||||
|
async fn from_request(
|
||||||
|
request: &'a rocket::Request<'r>,
|
||||||
|
) -> request::Outcome<Self, crate::Error> {
|
||||||
|
AccessControlRequestMethod::from_request_sync(request)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// The `Access-Control-Request-Headers` request header
|
/// The `Access-Control-Request-Headers` request header
|
||||||
///
|
///
|
||||||
/// You can use this as a rocket [Request Guard](https://rocket.rs/guide/requests/#request-guards)
|
/// You can use this as a rocket [Request Guard](https://rocket.rs/guide/requests/#request-guards)
|
||||||
|
@ -167,6 +197,28 @@ impl<'a, 'r> FromRequest<'a, 'r> for AccessControlRequestMethod {
|
||||||
#[derive(Eq, PartialEq, Debug)]
|
#[derive(Eq, PartialEq, Debug)]
|
||||||
pub struct AccessControlRequestHeaders(pub HeaderFieldNamesSet);
|
pub struct AccessControlRequestHeaders(pub HeaderFieldNamesSet);
|
||||||
|
|
||||||
|
impl AccessControlRequestHeaders {
|
||||||
|
/// Derives an instance of `Self` from the incoming request metadata.
|
||||||
|
///
|
||||||
|
/// If the derivation is successful, an outcome of `Success` is returned. If
|
||||||
|
/// the derivation fails in an unrecoverable fashion, `Failure` is returned.
|
||||||
|
/// `Forward` is returned to indicate that the request should be forwarded
|
||||||
|
/// to other matching routes, if any.
|
||||||
|
pub fn from_request_sync(
|
||||||
|
request: &'_ rocket::Request<'_>,
|
||||||
|
) -> request::Outcome<Self, crate::Error> {
|
||||||
|
match request.headers().get_one("Access-Control-Request-Headers") {
|
||||||
|
Some(request_headers) => match Self::from_str(request_headers) {
|
||||||
|
Ok(request_headers) => Outcome::Success(request_headers),
|
||||||
|
Err(()) => {
|
||||||
|
unreachable!("`AccessControlRequestHeaders::from_str` should never fail")
|
||||||
|
}
|
||||||
|
},
|
||||||
|
None => Outcome::Forward(()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Will never fail
|
/// Will never fail
|
||||||
impl FromStr for AccessControlRequestHeaders {
|
impl FromStr for AccessControlRequestHeaders {
|
||||||
type Err = ();
|
type Err = ();
|
||||||
|
@ -185,19 +237,14 @@ impl FromStr for AccessControlRequestHeaders {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[rocket::async_trait]
|
||||||
impl<'a, 'r> FromRequest<'a, 'r> for AccessControlRequestHeaders {
|
impl<'a, 'r> FromRequest<'a, 'r> for AccessControlRequestHeaders {
|
||||||
type Error = crate::Error;
|
type Error = crate::Error;
|
||||||
|
|
||||||
fn from_request(request: &'a rocket::Request<'r>) -> request::Outcome<Self, crate::Error> {
|
async fn from_request(
|
||||||
match request.headers().get_one("Access-Control-Request-Headers") {
|
request: &'a rocket::Request<'r>,
|
||||||
Some(request_headers) => match Self::from_str(request_headers) {
|
) -> request::Outcome<Self, crate::Error> {
|
||||||
Ok(request_headers) => Outcome::Success(request_headers),
|
AccessControlRequestHeaders::from_request_sync(request)
|
||||||
Err(()) => {
|
|
||||||
unreachable!("`AccessControlRequestHeaders::from_str` should never fail")
|
|
||||||
}
|
|
||||||
},
|
|
||||||
None => Outcome::Forward(()),
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -205,9 +252,15 @@ impl<'a, 'r> FromRequest<'a, 'r> for AccessControlRequestHeaders {
|
||||||
mod tests {
|
mod tests {
|
||||||
use std::str::FromStr;
|
use std::str::FromStr;
|
||||||
|
|
||||||
use hyper;
|
use rocket::http::hyper;
|
||||||
use rocket;
|
use rocket::http::Header;
|
||||||
use rocket::local::Client;
|
use rocket::local::blocking::Client;
|
||||||
|
|
||||||
|
static ORIGIN: hyper::HeaderName = hyper::header::ORIGIN;
|
||||||
|
static ACCESS_CONTROL_REQUEST_METHOD: hyper::HeaderName =
|
||||||
|
hyper::header::ACCESS_CONTROL_REQUEST_METHOD;
|
||||||
|
static ACCESS_CONTROL_REQUEST_HEADERS: hyper::HeaderName =
|
||||||
|
hyper::header::ACCESS_CONTROL_REQUEST_HEADERS;
|
||||||
|
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
|
@ -277,11 +330,10 @@ mod tests {
|
||||||
let client = make_client();
|
let client = make_client();
|
||||||
let mut request = client.get("/");
|
let mut request = client.get("/");
|
||||||
|
|
||||||
let origin = hyper::header::Origin::new("https", "www.example.com", None);
|
let origin = Header::new(ORIGIN.as_str(), "https://www.example.com");
|
||||||
request.add_header(origin);
|
request.add_header(origin);
|
||||||
|
|
||||||
let outcome: request::Outcome<Origin, crate::Error> =
|
let outcome = Origin::from_request_sync(request.inner());
|
||||||
FromRequest::from_request(request.inner());
|
|
||||||
let parsed_header = assert_matches!(outcome, Outcome::Success(s), s);
|
let parsed_header = assert_matches!(outcome, Outcome::Success(s), s);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
"https://www.example.com",
|
"https://www.example.com",
|
||||||
|
@ -313,10 +365,12 @@ mod tests {
|
||||||
fn request_method_parsing() {
|
fn request_method_parsing() {
|
||||||
let client = make_client();
|
let client = make_client();
|
||||||
let mut request = client.get("/");
|
let mut request = client.get("/");
|
||||||
let method = hyper::header::AccessControlRequestMethod(hyper::method::Method::Get);
|
let method = Header::new(
|
||||||
|
ACCESS_CONTROL_REQUEST_METHOD.as_str(),
|
||||||
|
hyper::Method::GET.as_str(),
|
||||||
|
);
|
||||||
request.add_header(method);
|
request.add_header(method);
|
||||||
let outcome: request::Outcome<AccessControlRequestMethod, crate::Error> =
|
let outcome = AccessControlRequestMethod::from_request_sync(request.inner());
|
||||||
FromRequest::from_request(request.inner());
|
|
||||||
|
|
||||||
let parsed_header = assert_matches!(outcome, Outcome::Success(s), s);
|
let parsed_header = assert_matches!(outcome, Outcome::Success(s), s);
|
||||||
let AccessControlRequestMethod(parsed_method) = parsed_header;
|
let AccessControlRequestMethod(parsed_method) = parsed_header;
|
||||||
|
@ -337,13 +391,12 @@ mod tests {
|
||||||
fn request_headers_parsing() {
|
fn request_headers_parsing() {
|
||||||
let client = make_client();
|
let client = make_client();
|
||||||
let mut request = client.get("/");
|
let mut request = client.get("/");
|
||||||
let headers = hyper::header::AccessControlRequestHeaders(vec![
|
let headers = Header::new(
|
||||||
FromStr::from_str("accept-language").unwrap(),
|
ACCESS_CONTROL_REQUEST_HEADERS.as_str(),
|
||||||
FromStr::from_str("date").unwrap(),
|
"accept-language, date",
|
||||||
]);
|
);
|
||||||
request.add_header(headers);
|
request.add_header(headers);
|
||||||
let outcome: request::Outcome<AccessControlRequestHeaders, crate::Error> =
|
let outcome = AccessControlRequestHeaders::from_request_sync(request.inner());
|
||||||
FromRequest::from_request(request.inner());
|
|
||||||
|
|
||||||
let parsed_header = assert_matches!(outcome, Outcome::Success(s), s);
|
let parsed_header = assert_matches!(outcome, Outcome::Success(s), s);
|
||||||
let AccessControlRequestHeaders(parsed_headers) = parsed_header;
|
let AccessControlRequestHeaders(parsed_headers) = parsed_header;
|
||||||
|
|
264
src/lib.rs
264
src/lib.rs
|
@ -261,7 +261,7 @@ See the [example](https://github.com/lawliet89/rocket_cors/blob/master/examples/
|
||||||
missing_debug_implementations,
|
missing_debug_implementations,
|
||||||
unknown_lints,
|
unknown_lints,
|
||||||
unsafe_code,
|
unsafe_code,
|
||||||
intra_doc_link_resolution_failure
|
broken_intra_doc_links
|
||||||
)]
|
)]
|
||||||
#![doc(test(attr(allow(unused_variables), deny(warnings))))]
|
#![doc(test(attr(allow(unused_variables), deny(warnings))))]
|
||||||
|
|
||||||
|
@ -285,7 +285,7 @@ use regex::RegexSet;
|
||||||
use rocket::http::{self, Status};
|
use rocket::http::{self, Status};
|
||||||
use rocket::request::{FromRequest, Request};
|
use rocket::request::{FromRequest, Request};
|
||||||
use rocket::response;
|
use rocket::response;
|
||||||
use rocket::{debug_, error_, info_, log_, Outcome, State};
|
use rocket::{debug_, error_, info_, log_, outcome::Outcome, State};
|
||||||
#[cfg(feature = "serialization")]
|
#[cfg(feature = "serialization")]
|
||||||
use serde_derive::{Deserialize, Serialize};
|
use serde_derive::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
@ -417,8 +417,8 @@ impl error::Error for Error {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'r> response::Responder<'r> for Error {
|
impl<'r, 'o: 'r> response::Responder<'r, 'o> for Error {
|
||||||
fn respond_to(self, _: &Request<'_>) -> Result<response::Response<'r>, Status> {
|
fn respond_to(self, _: &Request<'_>) -> Result<response::Response<'o>, Status> {
|
||||||
error_!("CORS Error: {}", self);
|
error_!("CORS Error: {}", self);
|
||||||
Err(self.status())
|
Err(self.status())
|
||||||
}
|
}
|
||||||
|
@ -1256,10 +1256,13 @@ impl Cors {
|
||||||
/// passed in to include the CORS headers.
|
/// passed in to include the CORS headers.
|
||||||
///
|
///
|
||||||
/// See the documentation at the [crate root](index.html) for usage information.
|
/// See the documentation at the [crate root](index.html) for usage information.
|
||||||
pub fn respond_owned<'r, F, R>(self, handler: F) -> Result<ManualResponder<'r, F, R>, Error>
|
pub fn respond_owned<'r, 'o: 'r, F, R>(
|
||||||
|
self,
|
||||||
|
handler: F,
|
||||||
|
) -> Result<ManualResponder<'r, F, R>, Error>
|
||||||
where
|
where
|
||||||
F: FnOnce(Guard<'r>) -> R + 'r,
|
F: FnOnce(Guard<'r>) -> R + 'r,
|
||||||
R: response::Responder<'r>,
|
R: response::Responder<'r, 'o>,
|
||||||
{
|
{
|
||||||
Ok(ManualResponder::new(Cow::Owned(self), handler))
|
Ok(ManualResponder::new(Cow::Owned(self), handler))
|
||||||
}
|
}
|
||||||
|
@ -1276,13 +1279,13 @@ impl Cors {
|
||||||
/// passed in to include the CORS headers.
|
/// passed in to include the CORS headers.
|
||||||
///
|
///
|
||||||
/// See the documentation at the [crate root](index.html) for usage information.
|
/// See the documentation at the [crate root](index.html) for usage information.
|
||||||
pub fn respond_borrowed<'r, F, R>(
|
pub fn respond_borrowed<'r, 'o: 'r, F, R>(
|
||||||
&'r self,
|
&'r self,
|
||||||
handler: F,
|
handler: F,
|
||||||
) -> Result<ManualResponder<'r, F, R>, Error>
|
) -> Result<ManualResponder<'r, F, R>, Error>
|
||||||
where
|
where
|
||||||
F: FnOnce(Guard<'r>) -> R + 'r,
|
F: FnOnce(Guard<'r>) -> R + 'r,
|
||||||
R: response::Responder<'r>,
|
R: response::Responder<'r, 'o>,
|
||||||
{
|
{
|
||||||
Ok(ManualResponder::new(Cow::Borrowed(self), handler))
|
Ok(ManualResponder::new(Cow::Borrowed(self), handler))
|
||||||
}
|
}
|
||||||
|
@ -1375,7 +1378,10 @@ impl Response {
|
||||||
|
|
||||||
/// Consumes the `Response` and return a `Responder` that wraps a
|
/// Consumes the `Response` and return a `Responder` that wraps a
|
||||||
/// provided `rocket:response::Responder` with CORS headers
|
/// provided `rocket:response::Responder` with CORS headers
|
||||||
pub fn responder<'r, R: response::Responder<'r>>(self, responder: R) -> Responder<'r, R> {
|
pub fn responder<'r, 'o: 'r, R: response::Responder<'r, 'o>>(
|
||||||
|
self,
|
||||||
|
responder: R,
|
||||||
|
) -> Responder<'r, 'o, R> {
|
||||||
Responder::new(responder, self)
|
Responder::new(responder, self)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1486,7 +1492,7 @@ pub struct Guard<'r> {
|
||||||
marker: PhantomData<&'r Response>,
|
marker: PhantomData<&'r Response>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'r> Guard<'r> {
|
impl<'r, 'o: 'r> Guard<'r> {
|
||||||
fn new(response: Response) -> Self {
|
fn new(response: Response) -> Self {
|
||||||
Self {
|
Self {
|
||||||
response,
|
response,
|
||||||
|
@ -1496,7 +1502,7 @@ impl<'r> Guard<'r> {
|
||||||
|
|
||||||
/// Consumes the Guard and return a `Responder` that wraps a
|
/// Consumes the Guard and return a `Responder` that wraps a
|
||||||
/// provided `rocket:response::Responder` with CORS headers
|
/// provided `rocket:response::Responder` with CORS headers
|
||||||
pub fn responder<R: response::Responder<'r>>(self, responder: R) -> Responder<'r, R> {
|
pub fn responder<R: response::Responder<'r, 'o>>(self, responder: R) -> Responder<'r, 'o, R> {
|
||||||
self.response.responder(responder)
|
self.response.responder(responder)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1509,11 +1515,12 @@ impl<'r> Guard<'r> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[rocket::async_trait]
|
||||||
impl<'a, 'r> FromRequest<'a, 'r> for Guard<'r> {
|
impl<'a, 'r> FromRequest<'a, 'r> for Guard<'r> {
|
||||||
type Error = Error;
|
type Error = Error;
|
||||||
|
|
||||||
fn from_request(request: &'a Request<'r>) -> rocket::request::Outcome<Self, Self::Error> {
|
async fn from_request(request: &'a Request<'r>) -> rocket::request::Outcome<Self, Self::Error> {
|
||||||
let options = match request.guard::<State<'_, Cors>>() {
|
let options = match request.guard::<State<'_, Cors>>().await {
|
||||||
Outcome::Success(options) => options,
|
Outcome::Success(options) => options,
|
||||||
_ => {
|
_ => {
|
||||||
let error = Error::MissingCorsInRocketState;
|
let error = Error::MissingCorsInRocketState;
|
||||||
|
@ -1545,13 +1552,13 @@ impl<'a, 'r> FromRequest<'a, 'r> for Guard<'r> {
|
||||||
///
|
///
|
||||||
/// See the documentation at the [crate root](index.html) for usage information.
|
/// See the documentation at the [crate root](index.html) for usage information.
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct Responder<'r, R> {
|
pub struct Responder<'r, 'o, R> {
|
||||||
responder: R,
|
responder: R,
|
||||||
cors_response: Response,
|
cors_response: Response,
|
||||||
marker: PhantomData<dyn response::Responder<'r>>,
|
marker: PhantomData<dyn response::Responder<'r, 'o>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'r, R: response::Responder<'r>> Responder<'r, R> {
|
impl<'r, 'o: 'r, R: response::Responder<'r, 'o>> Responder<'r, 'o, R> {
|
||||||
fn new(responder: R, cors_response: Response) -> Self {
|
fn new(responder: R, cors_response: Response) -> Self {
|
||||||
Self {
|
Self {
|
||||||
responder,
|
responder,
|
||||||
|
@ -1561,15 +1568,17 @@ impl<'r, R: response::Responder<'r>> Responder<'r, R> {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Respond to a request
|
/// Respond to a request
|
||||||
fn respond(self, request: &Request<'_>) -> response::Result<'r> {
|
fn respond(self, request: &'r Request<'_>) -> response::Result<'o> {
|
||||||
let mut response = self.responder.respond_to(request)?; // handle status errors?
|
let mut response = self.responder.respond_to(request)?; // handle status errors?
|
||||||
self.cors_response.merge(&mut response);
|
self.cors_response.merge(&mut response);
|
||||||
Ok(response)
|
Ok(response)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'r, R: response::Responder<'r>> response::Responder<'r> for Responder<'r, R> {
|
impl<'r, 'o: 'r, R: response::Responder<'r, 'o>> response::Responder<'r, 'o>
|
||||||
fn respond_to(self, request: &Request<'_>) -> response::Result<'r> {
|
for Responder<'r, 'o, R>
|
||||||
|
{
|
||||||
|
fn respond_to(self, request: &'r Request<'_>) -> response::Result<'o> {
|
||||||
self.respond(request)
|
self.respond(request)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1583,10 +1592,10 @@ pub struct ManualResponder<'r, F, R> {
|
||||||
marker: PhantomData<R>,
|
marker: PhantomData<R>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'r, F, R> ManualResponder<'r, F, R>
|
impl<'r, 'o: 'r, F, R> ManualResponder<'r, F, R>
|
||||||
where
|
where
|
||||||
F: FnOnce(Guard<'r>) -> R + 'r,
|
F: FnOnce(Guard<'r>) -> R + 'r,
|
||||||
R: response::Responder<'r>,
|
R: response::Responder<'r, 'o>,
|
||||||
{
|
{
|
||||||
/// Create a new manual responder by passing in either a borrowed or owned `Cors` option.
|
/// Create a new manual responder by passing in either a borrowed or owned `Cors` option.
|
||||||
///
|
///
|
||||||
|
@ -1607,12 +1616,12 @@ where
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'r, F, R> response::Responder<'r> for ManualResponder<'r, F, R>
|
impl<'r, 'o: 'r, F, R> response::Responder<'r, 'o> for ManualResponder<'r, F, R>
|
||||||
where
|
where
|
||||||
F: FnOnce(Guard<'r>) -> R + 'r,
|
F: FnOnce(Guard<'r>) -> R + 'r,
|
||||||
R: response::Responder<'r>,
|
R: response::Responder<'r, 'o>,
|
||||||
{
|
{
|
||||||
fn respond_to(self, request: &Request<'_>) -> response::Result<'r> {
|
fn respond_to(self, request: &'r Request<'_>) -> response::Result<'o> {
|
||||||
let guard = match self.build_guard(request) {
|
let guard = match self.build_guard(request) {
|
||||||
Ok(guard) => guard,
|
Ok(guard) => guard,
|
||||||
Err(err) => {
|
Err(err) => {
|
||||||
|
@ -1759,7 +1768,7 @@ fn validate_allowed_headers(
|
||||||
|
|
||||||
/// Gets the `Origin` request header from the request
|
/// Gets the `Origin` request header from the request
|
||||||
fn origin(request: &Request<'_>) -> Result<Option<Origin>, Error> {
|
fn origin(request: &Request<'_>) -> Result<Option<Origin>, Error> {
|
||||||
match Origin::from_request(request) {
|
match Origin::from_request_sync(request) {
|
||||||
Outcome::Forward(()) => Ok(None),
|
Outcome::Forward(()) => Ok(None),
|
||||||
Outcome::Success(origin) => Ok(Some(origin)),
|
Outcome::Success(origin) => Ok(Some(origin)),
|
||||||
Outcome::Failure((_, err)) => Err(err),
|
Outcome::Failure((_, err)) => Err(err),
|
||||||
|
@ -1768,7 +1777,7 @@ fn origin(request: &Request<'_>) -> Result<Option<Origin>, Error> {
|
||||||
|
|
||||||
/// Gets the `Access-Control-Request-Method` request header from the request
|
/// Gets the `Access-Control-Request-Method` request header from the request
|
||||||
fn request_method(request: &Request<'_>) -> Result<Option<AccessControlRequestMethod>, Error> {
|
fn request_method(request: &Request<'_>) -> Result<Option<AccessControlRequestMethod>, Error> {
|
||||||
match AccessControlRequestMethod::from_request(request) {
|
match AccessControlRequestMethod::from_request_sync(request) {
|
||||||
Outcome::Forward(()) => Ok(None),
|
Outcome::Forward(()) => Ok(None),
|
||||||
Outcome::Success(method) => Ok(Some(method)),
|
Outcome::Success(method) => Ok(Some(method)),
|
||||||
Outcome::Failure((_, err)) => Err(err),
|
Outcome::Failure((_, err)) => Err(err),
|
||||||
|
@ -1777,7 +1786,7 @@ fn request_method(request: &Request<'_>) -> Result<Option<AccessControlRequestMe
|
||||||
|
|
||||||
/// Gets the `Access-Control-Request-Headers` request header from the request
|
/// Gets the `Access-Control-Request-Headers` request header from the request
|
||||||
fn request_headers(request: &Request<'_>) -> Result<Option<AccessControlRequestHeaders>, Error> {
|
fn request_headers(request: &Request<'_>) -> Result<Option<AccessControlRequestHeaders>, Error> {
|
||||||
match AccessControlRequestHeaders::from_request(request) {
|
match AccessControlRequestHeaders::from_request_sync(request) {
|
||||||
Outcome::Forward(()) => Ok(None),
|
Outcome::Forward(()) => Ok(None),
|
||||||
Outcome::Success(geaders) => Ok(Some(geaders)),
|
Outcome::Success(geaders) => Ok(Some(geaders)),
|
||||||
Outcome::Failure((_, err)) => Err(err),
|
Outcome::Failure((_, err)) => Err(err),
|
||||||
|
@ -1984,23 +1993,29 @@ pub fn catch_all_options_routes() -> Vec<rocket::Route> {
|
||||||
isize::max_value(),
|
isize::max_value(),
|
||||||
http::Method::Options,
|
http::Method::Options,
|
||||||
"/",
|
"/",
|
||||||
catch_all_options_route_handler,
|
CatchAllOptionsRouteHandler {},
|
||||||
),
|
),
|
||||||
rocket::Route::ranked(
|
rocket::Route::ranked(
|
||||||
isize::max_value(),
|
isize::max_value(),
|
||||||
http::Method::Options,
|
http::Method::Options,
|
||||||
"/<catch_all_options_route..>",
|
"/<catch_all_options_route..>",
|
||||||
catch_all_options_route_handler,
|
CatchAllOptionsRouteHandler {},
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Handler for the "catch all options route"
|
/// Handler for the "catch all options route"
|
||||||
fn catch_all_options_route_handler<'r>(
|
#[derive(Clone)]
|
||||||
|
struct CatchAllOptionsRouteHandler {}
|
||||||
|
|
||||||
|
#[rocket::async_trait]
|
||||||
|
impl rocket::handler::Handler for CatchAllOptionsRouteHandler {
|
||||||
|
async fn handle<'r, 's: 'r>(
|
||||||
|
&'s self,
|
||||||
request: &'r Request<'_>,
|
request: &'r Request<'_>,
|
||||||
_: rocket::Data,
|
_: rocket::Data,
|
||||||
) -> rocket::handler::Outcome<'r> {
|
) -> rocket::handler::Outcome<'r> {
|
||||||
let guard: Guard<'_> = match request.guard() {
|
let guard: Guard<'_> = match request.guard().await {
|
||||||
Outcome::Success(guard) => guard,
|
Outcome::Success(guard) => guard,
|
||||||
Outcome::Failure((status, _)) => return rocket::handler::Outcome::failure(status),
|
Outcome::Failure((status, _)) => return rocket::handler::Outcome::failure(status),
|
||||||
Outcome::Forward(()) => unreachable!("Should not be reachable"),
|
Outcome::Forward(()) => unreachable!("Should not be reachable"),
|
||||||
|
@ -2013,19 +2028,25 @@ fn catch_all_options_route_handler<'r>(
|
||||||
|
|
||||||
rocket::handler::Outcome::from(request, guard.responder(()))
|
rocket::handler::Outcome::from(request, guard.responder(()))
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use std::str::FromStr;
|
use std::str::FromStr;
|
||||||
|
|
||||||
|
use rocket::http::hyper;
|
||||||
use rocket::http::Header;
|
use rocket::http::Header;
|
||||||
use rocket::local::Client;
|
use rocket::local::blocking::Client;
|
||||||
#[cfg(feature = "serialization")]
|
|
||||||
use serde_json;
|
|
||||||
|
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::http::Method;
|
use crate::http::Method;
|
||||||
|
|
||||||
|
static ORIGIN: hyper::HeaderName = hyper::header::ORIGIN;
|
||||||
|
static ACCESS_CONTROL_REQUEST_METHOD: hyper::HeaderName =
|
||||||
|
hyper::header::ACCESS_CONTROL_REQUEST_METHOD;
|
||||||
|
static ACCESS_CONTROL_REQUEST_HEADERS: hyper::HeaderName =
|
||||||
|
hyper::header::ACCESS_CONTROL_REQUEST_HEADERS;
|
||||||
|
|
||||||
fn to_parsed_origin<S: AsRef<str>>(origin: S) -> Result<Origin, Error> {
|
fn to_parsed_origin<S: AsRef<str>>(origin: S) -> Result<Origin, Error> {
|
||||||
Origin::from_str(origin.as_ref())
|
Origin::from_str(origin.as_ref())
|
||||||
}
|
}
|
||||||
|
@ -2083,10 +2104,20 @@ mod tests {
|
||||||
let allowed_origins = AllowedOrigins::some_exact(&["https://www.acme.com"]);
|
let allowed_origins = AllowedOrigins::some_exact(&["https://www.acme.com"]);
|
||||||
let cors_options_from_builder = CorsOptions::default()
|
let cors_options_from_builder = CorsOptions::default()
|
||||||
.allowed_origins(allowed_origins)
|
.allowed_origins(allowed_origins)
|
||||||
.allowed_methods(vec![http::Method::Get].into_iter().map(From::from).collect())
|
.allowed_methods(
|
||||||
|
vec![http::Method::Get]
|
||||||
|
.into_iter()
|
||||||
|
.map(From::from)
|
||||||
|
.collect(),
|
||||||
|
)
|
||||||
.allowed_headers(AllowedHeaders::some(&[&"Authorization", "Accept"]))
|
.allowed_headers(AllowedHeaders::some(&[&"Authorization", "Accept"]))
|
||||||
.allow_credentials(true)
|
.allow_credentials(true)
|
||||||
.expose_headers(["Content-Type", "X-Custom"].iter().map(|s| (*s).to_string()).collect());
|
.expose_headers(
|
||||||
|
["Content-Type", "X-Custom"]
|
||||||
|
.iter()
|
||||||
|
.map(|s| (*s).to_string())
|
||||||
|
.collect(),
|
||||||
|
);
|
||||||
assert_eq!(cors_options_from_builder, make_cors_options());
|
assert_eq!(cors_options_from_builder, make_cors_options());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2507,11 +2538,12 @@ mod tests {
|
||||||
fn response_build_removes_existing_cors_headers_and_keeps_others() {
|
fn response_build_removes_existing_cors_headers_and_keeps_others() {
|
||||||
use std::io::Cursor;
|
use std::io::Cursor;
|
||||||
|
|
||||||
|
let body = "Brewing the best coffee!";
|
||||||
let original = response::Response::build()
|
let original = response::Response::build()
|
||||||
.status(Status::ImATeapot)
|
.status(Status::ImATeapot)
|
||||||
.raw_header("X-Teapot-Make", "Rocket")
|
.raw_header("X-Teapot-Make", "Rocket")
|
||||||
.raw_header("Access-Control-Max-Age", "42")
|
.raw_header("Access-Control-Max-Age", "42")
|
||||||
.sized_body(Cursor::new("Brewing the best coffee!"))
|
.sized_body(body.len(), Cursor::new(body))
|
||||||
.finalize();
|
.finalize();
|
||||||
|
|
||||||
let response = Response::new();
|
let response = Response::new();
|
||||||
|
@ -2572,16 +2604,12 @@ mod tests {
|
||||||
let cors = make_cors_options().to_cors().expect("To not fail");
|
let cors = make_cors_options().to_cors().expect("To not fail");
|
||||||
let client = make_client();
|
let client = make_client();
|
||||||
|
|
||||||
let origin_header =
|
let origin_header = Header::new(ORIGIN.as_str(), "https://www.acme.com");
|
||||||
Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap());
|
let method_header = Header::new(
|
||||||
let method_header = Header::from(hyper::header::AccessControlRequestMethod(
|
ACCESS_CONTROL_REQUEST_METHOD.as_str(),
|
||||||
hyper::method::Method::Get,
|
hyper::Method::GET.as_str(),
|
||||||
));
|
);
|
||||||
let request_headers =
|
let request_headers = Header::new(ACCESS_CONTROL_REQUEST_HEADERS.as_str(), "Authorization");
|
||||||
hyper::header::AccessControlRequestHeaders(vec![
|
|
||||||
FromStr::from_str("Authorization").unwrap()
|
|
||||||
]);
|
|
||||||
let request_headers = Header::from(request_headers);
|
|
||||||
|
|
||||||
let request = client
|
let request = client
|
||||||
.options("/")
|
.options("/")
|
||||||
|
@ -2607,16 +2635,12 @@ mod tests {
|
||||||
let cors = options.to_cors().expect("To not fail");
|
let cors = options.to_cors().expect("To not fail");
|
||||||
let client = make_client();
|
let client = make_client();
|
||||||
|
|
||||||
let origin_header =
|
let origin_header = Header::new(ORIGIN.as_str(), "https://www.example.com");
|
||||||
Header::from(hyper::header::Origin::from_str("https://www.example.com").unwrap());
|
let method_header = Header::new(
|
||||||
let method_header = Header::from(hyper::header::AccessControlRequestMethod(
|
ACCESS_CONTROL_REQUEST_METHOD.as_str(),
|
||||||
hyper::method::Method::Get,
|
hyper::Method::GET.as_str(),
|
||||||
));
|
);
|
||||||
let request_headers =
|
let request_headers = Header::new(ACCESS_CONTROL_REQUEST_HEADERS.as_str(), "Authorization");
|
||||||
hyper::header::AccessControlRequestHeaders(vec![
|
|
||||||
FromStr::from_str("Authorization").unwrap()
|
|
||||||
]);
|
|
||||||
let request_headers = Header::from(request_headers);
|
|
||||||
|
|
||||||
let request = client
|
let request = client
|
||||||
.options("/")
|
.options("/")
|
||||||
|
@ -2639,16 +2663,12 @@ mod tests {
|
||||||
let cors = make_cors_options().to_cors().expect("To not fail");
|
let cors = make_cors_options().to_cors().expect("To not fail");
|
||||||
let client = make_client();
|
let client = make_client();
|
||||||
|
|
||||||
let origin_header =
|
let origin_header = Header::new(ORIGIN.as_str(), "https://www.example.com");
|
||||||
Header::from(hyper::header::Origin::from_str("https://www.example.com").unwrap());
|
let method_header = Header::new(
|
||||||
let method_header = Header::from(hyper::header::AccessControlRequestMethod(
|
ACCESS_CONTROL_REQUEST_METHOD.as_str(),
|
||||||
hyper::method::Method::Get,
|
hyper::Method::GET.as_str(),
|
||||||
));
|
);
|
||||||
let request_headers =
|
let request_headers = Header::new(ACCESS_CONTROL_REQUEST_HEADERS.as_str(), "Authorization");
|
||||||
hyper::header::AccessControlRequestHeaders(vec![
|
|
||||||
FromStr::from_str("Authorization").unwrap()
|
|
||||||
]);
|
|
||||||
let request_headers = Header::from(request_headers);
|
|
||||||
|
|
||||||
let request = client
|
let request = client
|
||||||
.options("/")
|
.options("/")
|
||||||
|
@ -2665,13 +2685,8 @@ mod tests {
|
||||||
let cors = make_cors_options().to_cors().expect("To not fail");
|
let cors = make_cors_options().to_cors().expect("To not fail");
|
||||||
let client = make_client();
|
let client = make_client();
|
||||||
|
|
||||||
let origin_header =
|
let origin_header = Header::new(ORIGIN.as_str(), "https://www.acme.com");
|
||||||
Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap());
|
let request_headers = Header::new(ACCESS_CONTROL_REQUEST_HEADERS.as_str(), "Authorization");
|
||||||
let request_headers =
|
|
||||||
hyper::header::AccessControlRequestHeaders(vec![
|
|
||||||
FromStr::from_str("Authorization").unwrap()
|
|
||||||
]);
|
|
||||||
let request_headers = Header::from(request_headers);
|
|
||||||
|
|
||||||
let request = client
|
let request = client
|
||||||
.options("/")
|
.options("/")
|
||||||
|
@ -2687,16 +2702,12 @@ mod tests {
|
||||||
let cors = make_cors_options().to_cors().expect("To not fail");
|
let cors = make_cors_options().to_cors().expect("To not fail");
|
||||||
let client = make_client();
|
let client = make_client();
|
||||||
|
|
||||||
let origin_header =
|
let origin_header = Header::new(ORIGIN.as_str(), "https://www.acme.com");
|
||||||
Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap());
|
let method_header = Header::new(
|
||||||
let method_header = Header::from(hyper::header::AccessControlRequestMethod(
|
ACCESS_CONTROL_REQUEST_METHOD.as_str(),
|
||||||
hyper::method::Method::Post,
|
hyper::Method::POST.as_str(),
|
||||||
));
|
);
|
||||||
let request_headers =
|
let request_headers = Header::new(ACCESS_CONTROL_REQUEST_HEADERS.as_str(), "Authorization");
|
||||||
hyper::header::AccessControlRequestHeaders(vec![
|
|
||||||
FromStr::from_str("Authorization").unwrap()
|
|
||||||
]);
|
|
||||||
let request_headers = Header::from(request_headers);
|
|
||||||
|
|
||||||
let request = client
|
let request = client
|
||||||
.options("/")
|
.options("/")
|
||||||
|
@ -2713,16 +2724,15 @@ mod tests {
|
||||||
let cors = make_cors_options().to_cors().expect("To not fail");
|
let cors = make_cors_options().to_cors().expect("To not fail");
|
||||||
let client = make_client();
|
let client = make_client();
|
||||||
|
|
||||||
let origin_header =
|
let origin_header = Header::new(ORIGIN.as_str(), "https://www.acme.com");
|
||||||
Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap());
|
let method_header = Header::new(
|
||||||
let method_header = Header::from(hyper::header::AccessControlRequestMethod(
|
ACCESS_CONTROL_REQUEST_METHOD.as_str(),
|
||||||
hyper::method::Method::Get,
|
hyper::Method::GET.as_str(),
|
||||||
));
|
);
|
||||||
let request_headers = hyper::header::AccessControlRequestHeaders(vec![
|
let request_headers = Header::new(
|
||||||
FromStr::from_str("Authorization").unwrap(),
|
ACCESS_CONTROL_REQUEST_HEADERS.as_str(),
|
||||||
FromStr::from_str("X-NOT-ALLOWED").unwrap(),
|
"Authorization, X-NOT-ALLOWED",
|
||||||
]);
|
);
|
||||||
let request_headers = Header::from(request_headers);
|
|
||||||
|
|
||||||
let request = client
|
let request = client
|
||||||
.options("/")
|
.options("/")
|
||||||
|
@ -2738,8 +2748,7 @@ mod tests {
|
||||||
let cors = make_cors_options().to_cors().expect("To not fail");
|
let cors = make_cors_options().to_cors().expect("To not fail");
|
||||||
let client = make_client();
|
let client = make_client();
|
||||||
|
|
||||||
let origin_header =
|
let origin_header = Header::new(ORIGIN.as_str(), "https://www.acme.com");
|
||||||
Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap());
|
|
||||||
let request = client.get("/").header(origin_header);
|
let request = client.get("/").header(origin_header);
|
||||||
|
|
||||||
let result = validate(&cors, request.inner()).expect("to not fail");
|
let result = validate(&cors, request.inner()).expect("to not fail");
|
||||||
|
@ -2757,8 +2766,7 @@ mod tests {
|
||||||
let cors = options.to_cors().expect("To not fail");
|
let cors = options.to_cors().expect("To not fail");
|
||||||
let client = make_client();
|
let client = make_client();
|
||||||
|
|
||||||
let origin_header =
|
let origin_header = Header::new(ORIGIN.as_str(), "https://www.example.com");
|
||||||
Header::from(hyper::header::Origin::from_str("https://www.example.com").unwrap());
|
|
||||||
let request = client.get("/").header(origin_header);
|
let request = client.get("/").header(origin_header);
|
||||||
|
|
||||||
let result = validate(&cors, request.inner()).expect("to not fail");
|
let result = validate(&cors, request.inner()).expect("to not fail");
|
||||||
|
@ -2775,8 +2783,7 @@ mod tests {
|
||||||
let cors = make_cors_options().to_cors().expect("To not fail");
|
let cors = make_cors_options().to_cors().expect("To not fail");
|
||||||
let client = make_client();
|
let client = make_client();
|
||||||
|
|
||||||
let origin_header =
|
let origin_header = Header::new(ORIGIN.as_str(), "https://www.example.com");
|
||||||
Header::from(hyper::header::Origin::from_str("https://www.example.com").unwrap());
|
|
||||||
let request = client.get("/").header(origin_header);
|
let request = client.get("/").header(origin_header);
|
||||||
|
|
||||||
let _ = validate(&cors, request.inner()).unwrap();
|
let _ = validate(&cors, request.inner()).unwrap();
|
||||||
|
@ -2799,16 +2806,12 @@ mod tests {
|
||||||
let cors = options.to_cors().expect("To not fail");
|
let cors = options.to_cors().expect("To not fail");
|
||||||
let client = make_client();
|
let client = make_client();
|
||||||
|
|
||||||
let origin_header =
|
let origin_header = Header::new(ORIGIN.as_str(), "https://www.acme.com");
|
||||||
Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap());
|
let method_header = Header::new(
|
||||||
let method_header = Header::from(hyper::header::AccessControlRequestMethod(
|
ACCESS_CONTROL_REQUEST_METHOD.as_str(),
|
||||||
hyper::method::Method::Get,
|
hyper::Method::GET.as_str(),
|
||||||
));
|
);
|
||||||
let request_headers =
|
let request_headers = Header::new(ACCESS_CONTROL_REQUEST_HEADERS.as_str(), "Authorization");
|
||||||
hyper::header::AccessControlRequestHeaders(vec![
|
|
||||||
FromStr::from_str("Authorization").unwrap()
|
|
||||||
]);
|
|
||||||
let request_headers = Header::from(request_headers);
|
|
||||||
|
|
||||||
let request = client
|
let request = client
|
||||||
.options("/")
|
.options("/")
|
||||||
|
@ -2839,16 +2842,12 @@ mod tests {
|
||||||
|
|
||||||
let client = make_client();
|
let client = make_client();
|
||||||
|
|
||||||
let origin_header =
|
let origin_header = Header::new(ORIGIN.as_str(), "https://www.acme.com");
|
||||||
Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap());
|
let method_header = Header::new(
|
||||||
let method_header = Header::from(hyper::header::AccessControlRequestMethod(
|
ACCESS_CONTROL_REQUEST_METHOD.as_str(),
|
||||||
hyper::method::Method::Get,
|
hyper::Method::GET.as_str(),
|
||||||
));
|
);
|
||||||
let request_headers =
|
let request_headers = Header::new(ACCESS_CONTROL_REQUEST_HEADERS.as_str(), "Authorization");
|
||||||
hyper::header::AccessControlRequestHeaders(vec![
|
|
||||||
FromStr::from_str("Authorization").unwrap()
|
|
||||||
]);
|
|
||||||
let request_headers = Header::from(request_headers);
|
|
||||||
|
|
||||||
let request = client
|
let request = client
|
||||||
.options("/")
|
.options("/")
|
||||||
|
@ -2879,16 +2878,12 @@ mod tests {
|
||||||
|
|
||||||
let client = make_client();
|
let client = make_client();
|
||||||
|
|
||||||
let origin_header =
|
let origin_header = Header::new(ORIGIN.as_str(), "https://www.acme.com");
|
||||||
Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap());
|
let method_header = Header::new(
|
||||||
let method_header = Header::from(hyper::header::AccessControlRequestMethod(
|
ACCESS_CONTROL_REQUEST_METHOD.as_str(),
|
||||||
hyper::method::Method::Get,
|
hyper::Method::GET.as_str(),
|
||||||
));
|
);
|
||||||
let request_headers =
|
let request_headers = Header::new(ACCESS_CONTROL_REQUEST_HEADERS.as_str(), "Authorization");
|
||||||
hyper::header::AccessControlRequestHeaders(vec![
|
|
||||||
FromStr::from_str("Authorization").unwrap()
|
|
||||||
]);
|
|
||||||
let request_headers = Header::from(request_headers);
|
|
||||||
|
|
||||||
let request = client
|
let request = client
|
||||||
.options("/")
|
.options("/")
|
||||||
|
@ -2914,8 +2909,7 @@ mod tests {
|
||||||
let cors = options.to_cors().expect("To not fail");
|
let cors = options.to_cors().expect("To not fail");
|
||||||
let client = make_client();
|
let client = make_client();
|
||||||
|
|
||||||
let origin_header =
|
let origin_header = Header::new(ORIGIN.as_str(), "https://www.acme.com");
|
||||||
Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap());
|
|
||||||
let request = client.get("/").header(origin_header);
|
let request = client.get("/").header(origin_header);
|
||||||
|
|
||||||
let response = validate_and_build(&cors, request.inner()).expect("to not fail");
|
let response = validate_and_build(&cors, request.inner()).expect("to not fail");
|
||||||
|
@ -2937,8 +2931,7 @@ mod tests {
|
||||||
|
|
||||||
let client = make_client();
|
let client = make_client();
|
||||||
|
|
||||||
let origin_header =
|
let origin_header = Header::new(ORIGIN.as_str(), "https://www.acme.com");
|
||||||
Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap());
|
|
||||||
let request = client.get("/").header(origin_header);
|
let request = client.get("/").header(origin_header);
|
||||||
|
|
||||||
let response = validate_and_build(&cors, request.inner()).expect("to not fail");
|
let response = validate_and_build(&cors, request.inner()).expect("to not fail");
|
||||||
|
@ -2960,8 +2953,7 @@ mod tests {
|
||||||
|
|
||||||
let client = make_client();
|
let client = make_client();
|
||||||
|
|
||||||
let origin_header =
|
let origin_header = Header::new(ORIGIN.as_str(), "https://www.acme.com");
|
||||||
Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap());
|
|
||||||
let request = client.get("/").header(origin_header);
|
let request = client.get("/").header(origin_header);
|
||||||
|
|
||||||
let response = validate_and_build(&cors, request.inner()).expect("to not fail");
|
let response = validate_and_build(&cors, request.inner()).expect("to not fail");
|
||||||
|
|
151
tests/fairing.rs
151
tests/fairing.rs
|
@ -1,23 +1,24 @@
|
||||||
//! This crate tests using `rocket_cors` using Fairings
|
//! This crate tests using `rocket_cors` using Fairings
|
||||||
#![feature(proc_macro_hygiene, decl_macro)]
|
use rocket::http::hyper;
|
||||||
use hyper;
|
|
||||||
|
|
||||||
use std::str::FromStr;
|
|
||||||
|
|
||||||
use rocket::http::Method;
|
use rocket::http::Method;
|
||||||
use rocket::http::{Header, Status};
|
use rocket::http::{Header, Status};
|
||||||
use rocket::local::Client;
|
use rocket::local::blocking::Client;
|
||||||
use rocket::response::Body;
|
|
||||||
use rocket::{get, routes};
|
use rocket::{get, routes};
|
||||||
use rocket_cors::*;
|
use rocket_cors::*;
|
||||||
|
|
||||||
|
static ORIGIN: hyper::HeaderName = hyper::header::ORIGIN;
|
||||||
|
static ACCESS_CONTROL_REQUEST_METHOD: hyper::HeaderName =
|
||||||
|
hyper::header::ACCESS_CONTROL_REQUEST_METHOD;
|
||||||
|
static ACCESS_CONTROL_REQUEST_HEADERS: hyper::HeaderName =
|
||||||
|
hyper::header::ACCESS_CONTROL_REQUEST_HEADERS;
|
||||||
|
|
||||||
#[get("/")]
|
#[get("/")]
|
||||||
fn cors<'a>() -> &'a str {
|
fn cors<'a>() -> &'a str {
|
||||||
"Hello CORS"
|
"Hello CORS"
|
||||||
}
|
}
|
||||||
|
|
||||||
#[get("/panic")]
|
#[get("/panic")]
|
||||||
fn panicking_route() {
|
fn panicking_route<'a>() -> &'a str {
|
||||||
panic!("This route will panic");
|
panic!("This route will panic");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -46,16 +47,12 @@ fn smoke_test() {
|
||||||
let client = Client::new(rocket()).unwrap();
|
let client = Client::new(rocket()).unwrap();
|
||||||
|
|
||||||
// `Options` pre-flight checks
|
// `Options` pre-flight checks
|
||||||
let origin_header =
|
let origin_header = Header::new(ORIGIN.as_str(), "https://www.acme.com");
|
||||||
Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap());
|
let method_header = Header::new(
|
||||||
let method_header = Header::from(hyper::header::AccessControlRequestMethod(
|
ACCESS_CONTROL_REQUEST_METHOD.as_str(),
|
||||||
hyper::method::Method::Get,
|
hyper::Method::GET.as_str(),
|
||||||
));
|
);
|
||||||
let request_headers =
|
let request_headers = Header::new(ACCESS_CONTROL_REQUEST_HEADERS.as_str(), "Authorization");
|
||||||
hyper::header::AccessControlRequestHeaders(vec![
|
|
||||||
FromStr::from_str("Authorization").unwrap()
|
|
||||||
]);
|
|
||||||
let request_headers = Header::from(request_headers);
|
|
||||||
let req = client
|
let req = client
|
||||||
.options("/")
|
.options("/")
|
||||||
.header(origin_header)
|
.header(origin_header)
|
||||||
|
@ -66,37 +63,31 @@ fn smoke_test() {
|
||||||
assert!(response.status().class().is_success());
|
assert!(response.status().class().is_success());
|
||||||
|
|
||||||
// "Actual" request
|
// "Actual" request
|
||||||
let origin_header =
|
let origin_header = Header::new(ORIGIN.as_str(), "https://www.acme.com");
|
||||||
Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap());
|
|
||||||
let authorization = Header::new("Authorization", "let me in");
|
let authorization = Header::new("Authorization", "let me in");
|
||||||
let req = client.get("/").header(origin_header).header(authorization);
|
let req = client.get("/").header(origin_header).header(authorization);
|
||||||
|
|
||||||
let mut response = req.dispatch();
|
let response = req.dispatch();
|
||||||
assert!(response.status().class().is_success());
|
assert!(response.status().class().is_success());
|
||||||
let body_str = response.body().and_then(Body::into_string);
|
|
||||||
assert_eq!(body_str, Some("Hello CORS".to_string()));
|
|
||||||
|
|
||||||
let origin_header = response
|
let origin_header = response
|
||||||
.headers()
|
.headers()
|
||||||
.get_one("Access-Control-Allow-Origin")
|
.get_one("Access-Control-Allow-Origin")
|
||||||
.expect("to exist");
|
.expect("to exist");
|
||||||
assert_eq!("https://www.acme.com", origin_header);
|
assert_eq!("https://www.acme.com", origin_header);
|
||||||
|
let body_str = response.into_string();
|
||||||
|
assert_eq!(body_str, Some("Hello CORS".to_string()));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn cors_options_check() {
|
fn cors_options_check() {
|
||||||
let client = Client::new(rocket()).unwrap();
|
let client = Client::new(rocket()).unwrap();
|
||||||
|
|
||||||
let origin_header =
|
let origin_header = Header::new(ORIGIN.as_str(), "https://www.acme.com");
|
||||||
Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap());
|
let method_header = Header::new(
|
||||||
let method_header = Header::from(hyper::header::AccessControlRequestMethod(
|
ACCESS_CONTROL_REQUEST_METHOD.as_str(),
|
||||||
hyper::method::Method::Get,
|
hyper::Method::GET.as_str(),
|
||||||
));
|
);
|
||||||
let request_headers =
|
let request_headers = Header::new(ACCESS_CONTROL_REQUEST_HEADERS.as_str(), "Authorization");
|
||||||
hyper::header::AccessControlRequestHeaders(vec![
|
|
||||||
FromStr::from_str("Authorization").unwrap()
|
|
||||||
]);
|
|
||||||
let request_headers = Header::from(request_headers);
|
|
||||||
let req = client
|
let req = client
|
||||||
.options("/")
|
.options("/")
|
||||||
.header(origin_header)
|
.header(origin_header)
|
||||||
|
@ -117,21 +108,19 @@ fn cors_options_check() {
|
||||||
fn cors_get_check() {
|
fn cors_get_check() {
|
||||||
let client = Client::new(rocket()).unwrap();
|
let client = Client::new(rocket()).unwrap();
|
||||||
|
|
||||||
let origin_header =
|
let origin_header = Header::new(ORIGIN.as_str(), "https://www.acme.com");
|
||||||
Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap());
|
|
||||||
let authorization = Header::new("Authorization", "let me in");
|
let authorization = Header::new("Authorization", "let me in");
|
||||||
let req = client.get("/").header(origin_header).header(authorization);
|
let req = client.get("/").header(origin_header).header(authorization);
|
||||||
|
|
||||||
let mut response = req.dispatch();
|
let response = req.dispatch();
|
||||||
assert!(response.status().class().is_success());
|
assert!(response.status().class().is_success());
|
||||||
let body_str = response.body().and_then(Body::into_string);
|
|
||||||
assert_eq!(body_str, Some("Hello CORS".to_string()));
|
|
||||||
|
|
||||||
let origin_header = response
|
let origin_header = response
|
||||||
.headers()
|
.headers()
|
||||||
.get_one("Access-Control-Allow-Origin")
|
.get_one("Access-Control-Allow-Origin")
|
||||||
.expect("to exist");
|
.expect("to exist");
|
||||||
assert_eq!("https://www.acme.com", origin_header);
|
assert_eq!("https://www.acme.com", origin_header);
|
||||||
|
let body_str = response.into_string();
|
||||||
|
assert_eq!(body_str, Some("Hello CORS".to_string()));
|
||||||
}
|
}
|
||||||
|
|
||||||
/// This test is to check that non CORS compliant requests to GET should still work. (i.e. curl)
|
/// This test is to check that non CORS compliant requests to GET should still work. (i.e. curl)
|
||||||
|
@ -142,9 +131,9 @@ fn cors_get_no_origin() {
|
||||||
let authorization = Header::new("Authorization", "let me in");
|
let authorization = Header::new("Authorization", "let me in");
|
||||||
let req = client.get("/").header(authorization);
|
let req = client.get("/").header(authorization);
|
||||||
|
|
||||||
let mut response = req.dispatch();
|
let response = req.dispatch();
|
||||||
assert!(response.status().class().is_success());
|
assert!(response.status().class().is_success());
|
||||||
let body_str = response.body().and_then(Body::into_string);
|
let body_str = response.into_string();
|
||||||
assert_eq!(body_str, Some("Hello CORS".to_string()));
|
assert_eq!(body_str, Some("Hello CORS".to_string()));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -152,16 +141,12 @@ fn cors_get_no_origin() {
|
||||||
fn cors_options_bad_origin() {
|
fn cors_options_bad_origin() {
|
||||||
let client = Client::new(rocket()).unwrap();
|
let client = Client::new(rocket()).unwrap();
|
||||||
|
|
||||||
let origin_header =
|
let origin_header = Header::new(ORIGIN.as_str(), "https://www.bad-origin.com");
|
||||||
Header::from(hyper::header::Origin::from_str("https://www.bad-origin.com").unwrap());
|
let method_header = Header::new(
|
||||||
let method_header = Header::from(hyper::header::AccessControlRequestMethod(
|
ACCESS_CONTROL_REQUEST_METHOD.as_str(),
|
||||||
hyper::method::Method::Get,
|
hyper::Method::GET.as_str(),
|
||||||
));
|
);
|
||||||
let request_headers =
|
let request_headers = Header::new(ACCESS_CONTROL_REQUEST_HEADERS.as_str(), "Authorization");
|
||||||
hyper::header::AccessControlRequestHeaders(vec![
|
|
||||||
FromStr::from_str("Authorization").unwrap()
|
|
||||||
]);
|
|
||||||
let request_headers = Header::from(request_headers);
|
|
||||||
let req = client
|
let req = client
|
||||||
.options("/")
|
.options("/")
|
||||||
.header(origin_header)
|
.header(origin_header)
|
||||||
|
@ -177,14 +162,11 @@ fn cors_options_bad_origin() {
|
||||||
fn cors_options_missing_origin() {
|
fn cors_options_missing_origin() {
|
||||||
let client = Client::new(rocket()).unwrap();
|
let client = Client::new(rocket()).unwrap();
|
||||||
|
|
||||||
let method_header = Header::from(hyper::header::AccessControlRequestMethod(
|
let method_header = Header::new(
|
||||||
hyper::method::Method::Get,
|
ACCESS_CONTROL_REQUEST_METHOD.as_str(),
|
||||||
));
|
hyper::Method::GET.as_str(),
|
||||||
let request_headers =
|
);
|
||||||
hyper::header::AccessControlRequestHeaders(vec![
|
let request_headers = Header::new(ACCESS_CONTROL_REQUEST_HEADERS.as_str(), "Authorization");
|
||||||
FromStr::from_str("Authorization").unwrap()
|
|
||||||
]);
|
|
||||||
let request_headers = Header::from(request_headers);
|
|
||||||
let req = client
|
let req = client
|
||||||
.options("/")
|
.options("/")
|
||||||
.header(method_header)
|
.header(method_header)
|
||||||
|
@ -203,16 +185,12 @@ fn cors_options_missing_origin() {
|
||||||
fn cors_options_bad_request_method() {
|
fn cors_options_bad_request_method() {
|
||||||
let client = Client::new(rocket()).unwrap();
|
let client = Client::new(rocket()).unwrap();
|
||||||
|
|
||||||
let origin_header =
|
let origin_header = Header::new(ORIGIN.as_str(), "https://www.acme.com");
|
||||||
Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap());
|
let method_header = Header::new(
|
||||||
let method_header = Header::from(hyper::header::AccessControlRequestMethod(
|
ACCESS_CONTROL_REQUEST_METHOD.as_str(),
|
||||||
hyper::method::Method::Post,
|
hyper::Method::POST.as_str(),
|
||||||
));
|
);
|
||||||
let request_headers =
|
let request_headers = Header::new(ACCESS_CONTROL_REQUEST_HEADERS.as_str(), "Authorization");
|
||||||
hyper::header::AccessControlRequestHeaders(vec![
|
|
||||||
FromStr::from_str("Authorization").unwrap()
|
|
||||||
]);
|
|
||||||
let request_headers = Header::from(request_headers);
|
|
||||||
let req = client
|
let req = client
|
||||||
.options("/")
|
.options("/")
|
||||||
.header(origin_header)
|
.header(origin_header)
|
||||||
|
@ -231,14 +209,12 @@ fn cors_options_bad_request_method() {
|
||||||
fn cors_options_bad_request_header() {
|
fn cors_options_bad_request_header() {
|
||||||
let client = Client::new(rocket()).unwrap();
|
let client = Client::new(rocket()).unwrap();
|
||||||
|
|
||||||
let origin_header =
|
let origin_header = Header::new(ORIGIN.as_str(), "https://www.acme.com");
|
||||||
Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap());
|
let method_header = Header::new(
|
||||||
let method_header = Header::from(hyper::header::AccessControlRequestMethod(
|
ACCESS_CONTROL_REQUEST_METHOD.as_str(),
|
||||||
hyper::method::Method::Get,
|
hyper::Method::GET.as_str(),
|
||||||
));
|
);
|
||||||
let request_headers =
|
let request_headers = Header::new(ACCESS_CONTROL_REQUEST_HEADERS.as_str(), "Foobar");
|
||||||
hyper::header::AccessControlRequestHeaders(vec![FromStr::from_str("Foobar").unwrap()]);
|
|
||||||
let request_headers = Header::from(request_headers);
|
|
||||||
let req = client
|
let req = client
|
||||||
.options("/")
|
.options("/")
|
||||||
.header(origin_header)
|
.header(origin_header)
|
||||||
|
@ -257,8 +233,7 @@ fn cors_options_bad_request_header() {
|
||||||
fn cors_get_bad_origin() {
|
fn cors_get_bad_origin() {
|
||||||
let client = Client::new(rocket()).unwrap();
|
let client = Client::new(rocket()).unwrap();
|
||||||
|
|
||||||
let origin_header =
|
let origin_header = Header::new(ORIGIN.as_str(), "https://www.bad-origin.com");
|
||||||
Header::from(hyper::header::Origin::from_str("https://www.bad-origin.com").unwrap());
|
|
||||||
let authorization = Header::new("Authorization", "let me in");
|
let authorization = Header::new("Authorization", "let me in");
|
||||||
let req = client.get("/").header(origin_header).header(authorization);
|
let req = client.get("/").header(origin_header).header(authorization);
|
||||||
|
|
||||||
|
@ -277,16 +252,12 @@ fn cors_get_bad_origin() {
|
||||||
fn routes_failing_checks_are_not_executed() {
|
fn routes_failing_checks_are_not_executed() {
|
||||||
let client = Client::new(rocket()).unwrap();
|
let client = Client::new(rocket()).unwrap();
|
||||||
|
|
||||||
let origin_header =
|
let origin_header = Header::new(ORIGIN.as_str(), "https://www.bad-origin.com");
|
||||||
Header::from(hyper::header::Origin::from_str("https://www.bad-origin.com").unwrap());
|
let method_header = Header::new(
|
||||||
let method_header = Header::from(hyper::header::AccessControlRequestMethod(
|
ACCESS_CONTROL_REQUEST_METHOD.as_str(),
|
||||||
hyper::method::Method::Get,
|
hyper::Method::GET.as_str(),
|
||||||
));
|
);
|
||||||
let request_headers =
|
let request_headers = Header::new(ACCESS_CONTROL_REQUEST_HEADERS.as_str(), "Authorization");
|
||||||
hyper::header::AccessControlRequestHeaders(vec![
|
|
||||||
FromStr::from_str("Authorization").unwrap()
|
|
||||||
]);
|
|
||||||
let request_headers = Header::from(request_headers);
|
|
||||||
let req = client
|
let req = client
|
||||||
.options("/panic")
|
.options("/panic")
|
||||||
.header(origin_header)
|
.header(origin_header)
|
||||||
|
|
193
tests/guard.rs
193
tests/guard.rs
|
@ -1,36 +1,38 @@
|
||||||
//! This crate tests using `rocket_cors` using the per-route handling with request guard
|
//! This crate tests using `rocket_cors` using the per-route handling with request guard
|
||||||
#![feature(proc_macro_hygiene, decl_macro)]
|
|
||||||
use hyper;
|
|
||||||
use rocket_cors as cors;
|
use rocket_cors as cors;
|
||||||
|
|
||||||
use std::str::FromStr;
|
use rocket::http::hyper;
|
||||||
|
|
||||||
use rocket::http::Method;
|
use rocket::http::Method;
|
||||||
use rocket::http::{Header, Status};
|
use rocket::http::{Header, Status};
|
||||||
use rocket::local::Client;
|
use rocket::local::blocking::Client;
|
||||||
use rocket::response::Body;
|
|
||||||
use rocket::{get, options, routes};
|
use rocket::{get, options, routes};
|
||||||
use rocket::{Response, State};
|
use rocket::{Response, State};
|
||||||
|
|
||||||
|
static ORIGIN: hyper::HeaderName = hyper::header::ORIGIN;
|
||||||
|
static ACCESS_CONTROL_REQUEST_METHOD: hyper::HeaderName =
|
||||||
|
hyper::header::ACCESS_CONTROL_REQUEST_METHOD;
|
||||||
|
static ACCESS_CONTROL_REQUEST_HEADERS: hyper::HeaderName =
|
||||||
|
hyper::header::ACCESS_CONTROL_REQUEST_HEADERS;
|
||||||
|
|
||||||
#[get("/")]
|
#[get("/")]
|
||||||
fn cors(cors: cors::Guard<'_>) -> cors::Responder<'_, &str> {
|
fn cors(cors: cors::Guard<'_>) -> cors::Responder<'_, '_, &str> {
|
||||||
cors.responder("Hello CORS")
|
cors.responder("Hello CORS")
|
||||||
}
|
}
|
||||||
|
|
||||||
#[get("/panic")]
|
#[get("/panic")]
|
||||||
fn panicking_route(_cors: cors::Guard<'_>) {
|
fn panicking_route(_cors: cors::Guard<'_>) -> cors::Responder<'_, '_, &str> {
|
||||||
panic!("This route will panic");
|
panic!("This route will panic");
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Manually specify our own OPTIONS route
|
/// Manually specify our own OPTIONS route
|
||||||
#[options("/manual")]
|
#[options("/manual")]
|
||||||
fn cors_manual_options(cors: cors::Guard<'_>) -> cors::Responder<'_, &str> {
|
fn cors_manual_options(cors: cors::Guard<'_>) -> cors::Responder<'_, '_, &str> {
|
||||||
cors.responder("Manual CORS Preflight")
|
cors.responder("Manual CORS Preflight")
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Manually specify our own OPTIONS route
|
/// Manually specify our own OPTIONS route
|
||||||
#[get("/manual")]
|
#[get("/manual")]
|
||||||
fn cors_manual(cors: cors::Guard<'_>) -> cors::Responder<'_, &str> {
|
fn cors_manual(cors: cors::Guard<'_>) -> cors::Responder<'_, '_, &str> {
|
||||||
cors.responder("Hello CORS")
|
cors.responder("Hello CORS")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -42,20 +44,23 @@ fn response(cors: cors::Guard<'_>) -> Response<'_> {
|
||||||
|
|
||||||
/// `Responder` with String
|
/// `Responder` with String
|
||||||
#[get("/responder/string")]
|
#[get("/responder/string")]
|
||||||
fn responder_string(cors: cors::Guard<'_>) -> cors::Responder<'_, String> {
|
fn responder_string(cors: cors::Guard<'_>) -> cors::Responder<'_, 'static, String> {
|
||||||
cors.responder("Hello CORS".to_string())
|
cors.responder("Hello CORS".to_string())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// `Responder` with 'static ()
|
/// `Responder` with 'static ()
|
||||||
#[get("/responder/unit")]
|
#[get("/responder/unit")]
|
||||||
fn responder_unit(cors: cors::Guard<'_>) -> cors::Responder<'_, ()> {
|
fn responder_unit(cors: cors::Guard<'_>) -> cors::Responder<'_, 'static, ()> {
|
||||||
cors.responder(())
|
cors.responder(())
|
||||||
}
|
}
|
||||||
|
|
||||||
struct SomeState;
|
struct SomeState;
|
||||||
/// Borrow `SomeState` from Rocket
|
/// Borrow `SomeState` from Rocket
|
||||||
#[get("/state")]
|
#[get("/state")]
|
||||||
fn state<'r>(cors: cors::Guard<'r>, _state: State<'r, SomeState>) -> cors::Responder<'r, &'r str> {
|
fn state<'r, 'o: 'r>(
|
||||||
|
cors: cors::Guard<'r>,
|
||||||
|
_state: State<'r, SomeState>,
|
||||||
|
) -> cors::Responder<'r, 'o, &'r str> {
|
||||||
cors.responder("hmm")
|
cors.responder("hmm")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -92,16 +97,12 @@ fn smoke_test() {
|
||||||
let client = Client::new(rocket).unwrap();
|
let client = Client::new(rocket).unwrap();
|
||||||
|
|
||||||
// `Options` pre-flight checks
|
// `Options` pre-flight checks
|
||||||
let origin_header =
|
let origin_header = Header::new(ORIGIN.as_str(), "https://www.acme.com");
|
||||||
Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap());
|
let method_header = Header::new(
|
||||||
let method_header = Header::from(hyper::header::AccessControlRequestMethod(
|
ACCESS_CONTROL_REQUEST_METHOD.as_str(),
|
||||||
hyper::method::Method::Get,
|
hyper::Method::GET.as_str(),
|
||||||
));
|
);
|
||||||
let request_headers =
|
let request_headers = Header::new(ACCESS_CONTROL_REQUEST_HEADERS.as_str(), "Authorization");
|
||||||
hyper::header::AccessControlRequestHeaders(vec![
|
|
||||||
FromStr::from_str("Authorization").unwrap()
|
|
||||||
]);
|
|
||||||
let request_headers = Header::from(request_headers);
|
|
||||||
let req = client
|
let req = client
|
||||||
.options("/")
|
.options("/")
|
||||||
.header(origin_header)
|
.header(origin_header)
|
||||||
|
@ -112,21 +113,19 @@ fn smoke_test() {
|
||||||
assert!(response.status().class().is_success());
|
assert!(response.status().class().is_success());
|
||||||
|
|
||||||
// "Actual" request
|
// "Actual" request
|
||||||
let origin_header =
|
let origin_header = Header::new(ORIGIN.as_str(), "https://www.acme.com");
|
||||||
Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap());
|
|
||||||
let authorization = Header::new("Authorization", "let me in");
|
let authorization = Header::new("Authorization", "let me in");
|
||||||
let req = client.get("/").header(origin_header).header(authorization);
|
let req = client.get("/").header(origin_header).header(authorization);
|
||||||
|
|
||||||
let mut response = req.dispatch();
|
let response = req.dispatch();
|
||||||
assert!(response.status().class().is_success());
|
assert!(response.status().class().is_success());
|
||||||
let body_str = response.body().and_then(Body::into_string);
|
|
||||||
assert_eq!(body_str, Some("Hello CORS".to_string()));
|
|
||||||
|
|
||||||
let origin_header = response
|
let origin_header = response
|
||||||
.headers()
|
.headers()
|
||||||
.get_one("Access-Control-Allow-Origin")
|
.get_one("Access-Control-Allow-Origin")
|
||||||
.expect("to exist");
|
.expect("to exist");
|
||||||
assert_eq!("https://www.acme.com", origin_header);
|
assert_eq!("https://www.acme.com", origin_header);
|
||||||
|
let body_str = response.into_string();
|
||||||
|
assert_eq!(body_str, Some("Hello CORS".to_string()));
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Check the "catch all" OPTIONS route works for `/`
|
/// Check the "catch all" OPTIONS route works for `/`
|
||||||
|
@ -135,16 +134,12 @@ fn cors_options_catch_all_check() {
|
||||||
let rocket = make_rocket();
|
let rocket = make_rocket();
|
||||||
let client = Client::new(rocket).unwrap();
|
let client = Client::new(rocket).unwrap();
|
||||||
|
|
||||||
let origin_header =
|
let origin_header = Header::new(ORIGIN.as_str(), "https://www.acme.com");
|
||||||
Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap());
|
let method_header = Header::new(
|
||||||
let method_header = Header::from(hyper::header::AccessControlRequestMethod(
|
ACCESS_CONTROL_REQUEST_METHOD.as_str(),
|
||||||
hyper::method::Method::Get,
|
hyper::Method::GET.as_str(),
|
||||||
));
|
);
|
||||||
let request_headers =
|
let request_headers = Header::new(ACCESS_CONTROL_REQUEST_HEADERS.as_str(), "Authorization");
|
||||||
hyper::header::AccessControlRequestHeaders(vec![
|
|
||||||
FromStr::from_str("Authorization").unwrap()
|
|
||||||
]);
|
|
||||||
let request_headers = Header::from(request_headers);
|
|
||||||
let req = client
|
let req = client
|
||||||
.options("/")
|
.options("/")
|
||||||
.header(origin_header)
|
.header(origin_header)
|
||||||
|
@ -167,16 +162,12 @@ fn cors_options_catch_all_check_other_routes() {
|
||||||
let rocket = make_rocket();
|
let rocket = make_rocket();
|
||||||
let client = Client::new(rocket).unwrap();
|
let client = Client::new(rocket).unwrap();
|
||||||
|
|
||||||
let origin_header =
|
let origin_header = Header::new(ORIGIN.as_str(), "https://www.acme.com");
|
||||||
Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap());
|
let method_header = Header::new(
|
||||||
let method_header = Header::from(hyper::header::AccessControlRequestMethod(
|
ACCESS_CONTROL_REQUEST_METHOD.as_str(),
|
||||||
hyper::method::Method::Get,
|
hyper::Method::GET.as_str(),
|
||||||
));
|
);
|
||||||
let request_headers =
|
let request_headers = Header::new(ACCESS_CONTROL_REQUEST_HEADERS.as_str(), "Authorization");
|
||||||
hyper::header::AccessControlRequestHeaders(vec![
|
|
||||||
FromStr::from_str("Authorization").unwrap()
|
|
||||||
]);
|
|
||||||
let request_headers = Header::from(request_headers);
|
|
||||||
let req = client
|
let req = client
|
||||||
.options("/response/unit")
|
.options("/response/unit")
|
||||||
.header(origin_header)
|
.header(origin_header)
|
||||||
|
@ -198,21 +189,19 @@ fn cors_get_check() {
|
||||||
let rocket = make_rocket();
|
let rocket = make_rocket();
|
||||||
let client = Client::new(rocket).unwrap();
|
let client = Client::new(rocket).unwrap();
|
||||||
|
|
||||||
let origin_header =
|
let origin_header = Header::new(ORIGIN.as_str(), "https://www.acme.com");
|
||||||
Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap());
|
|
||||||
let authorization = Header::new("Authorization", "let me in");
|
let authorization = Header::new("Authorization", "let me in");
|
||||||
let req = client.get("/").header(origin_header).header(authorization);
|
let req = client.get("/").header(origin_header).header(authorization);
|
||||||
|
|
||||||
let mut response = req.dispatch();
|
let response = req.dispatch();
|
||||||
assert!(response.status().class().is_success());
|
assert!(response.status().class().is_success());
|
||||||
let body_str = response.body().and_then(Body::into_string);
|
|
||||||
assert_eq!(body_str, Some("Hello CORS".to_string()));
|
|
||||||
|
|
||||||
let origin_header = response
|
let origin_header = response
|
||||||
.headers()
|
.headers()
|
||||||
.get_one("Access-Control-Allow-Origin")
|
.get_one("Access-Control-Allow-Origin")
|
||||||
.expect("to exist");
|
.expect("to exist");
|
||||||
assert_eq!("https://www.acme.com", origin_header);
|
assert_eq!("https://www.acme.com", origin_header);
|
||||||
|
let body_str = response.into_string();
|
||||||
|
assert_eq!(body_str, Some("Hello CORS".to_string()));
|
||||||
}
|
}
|
||||||
|
|
||||||
/// This test is to check that non CORS compliant requests to GET should still work. (i.e. curl)
|
/// This test is to check that non CORS compliant requests to GET should still work. (i.e. curl)
|
||||||
|
@ -224,14 +213,14 @@ fn cors_get_no_origin() {
|
||||||
let authorization = Header::new("Authorization", "let me in");
|
let authorization = Header::new("Authorization", "let me in");
|
||||||
let req = client.get("/").header(authorization);
|
let req = client.get("/").header(authorization);
|
||||||
|
|
||||||
let mut response = req.dispatch();
|
let response = req.dispatch();
|
||||||
assert!(response.status().class().is_success());
|
assert!(response.status().class().is_success());
|
||||||
let body_str = response.body().and_then(Body::into_string);
|
|
||||||
assert_eq!(body_str, Some("Hello CORS".to_string()));
|
|
||||||
assert!(response
|
assert!(response
|
||||||
.headers()
|
.headers()
|
||||||
.get_one("Access-Control-Allow-Origin")
|
.get_one("Access-Control-Allow-Origin")
|
||||||
.is_none());
|
.is_none());
|
||||||
|
let body_str = response.into_string();
|
||||||
|
assert_eq!(body_str, Some("Hello CORS".to_string()));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
@ -239,16 +228,12 @@ fn cors_options_bad_origin() {
|
||||||
let rocket = make_rocket();
|
let rocket = make_rocket();
|
||||||
let client = Client::new(rocket).unwrap();
|
let client = Client::new(rocket).unwrap();
|
||||||
|
|
||||||
let origin_header =
|
let origin_header = Header::new(ORIGIN.as_str(), "https://www.bad-origin.com");
|
||||||
Header::from(hyper::header::Origin::from_str("https://www.bad-origin.com").unwrap());
|
let method_header = Header::new(
|
||||||
let method_header = Header::from(hyper::header::AccessControlRequestMethod(
|
ACCESS_CONTROL_REQUEST_METHOD.as_str(),
|
||||||
hyper::method::Method::Get,
|
hyper::Method::GET.as_str(),
|
||||||
));
|
);
|
||||||
let request_headers =
|
let request_headers = Header::new(ACCESS_CONTROL_REQUEST_HEADERS.as_str(), "Authorization");
|
||||||
hyper::header::AccessControlRequestHeaders(vec![
|
|
||||||
FromStr::from_str("Authorization").unwrap()
|
|
||||||
]);
|
|
||||||
let request_headers = Header::from(request_headers);
|
|
||||||
let req = client
|
let req = client
|
||||||
.options("/")
|
.options("/")
|
||||||
.header(origin_header)
|
.header(origin_header)
|
||||||
|
@ -268,14 +253,11 @@ fn cors_options_missing_origin() {
|
||||||
let rocket = make_rocket();
|
let rocket = make_rocket();
|
||||||
let client = Client::new(rocket).unwrap();
|
let client = Client::new(rocket).unwrap();
|
||||||
|
|
||||||
let method_header = Header::from(hyper::header::AccessControlRequestMethod(
|
let method_header = Header::new(
|
||||||
hyper::method::Method::Get,
|
ACCESS_CONTROL_REQUEST_METHOD.as_str(),
|
||||||
));
|
hyper::Method::GET.as_str(),
|
||||||
let request_headers =
|
);
|
||||||
hyper::header::AccessControlRequestHeaders(vec![
|
let request_headers = Header::new(ACCESS_CONTROL_REQUEST_HEADERS.as_str(), "Authorization");
|
||||||
FromStr::from_str("Authorization").unwrap()
|
|
||||||
]);
|
|
||||||
let request_headers = Header::from(request_headers);
|
|
||||||
let req = client
|
let req = client
|
||||||
.options("/")
|
.options("/")
|
||||||
.header(method_header)
|
.header(method_header)
|
||||||
|
@ -294,16 +276,12 @@ fn cors_options_bad_request_method() {
|
||||||
let rocket = make_rocket();
|
let rocket = make_rocket();
|
||||||
let client = Client::new(rocket).unwrap();
|
let client = Client::new(rocket).unwrap();
|
||||||
|
|
||||||
let origin_header =
|
let origin_header = Header::new(ORIGIN.as_str(), "https://www.acme.com");
|
||||||
Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap());
|
let method_header = Header::new(
|
||||||
let method_header = Header::from(hyper::header::AccessControlRequestMethod(
|
ACCESS_CONTROL_REQUEST_METHOD.as_str(),
|
||||||
hyper::method::Method::Post,
|
hyper::Method::POST.as_str(),
|
||||||
));
|
);
|
||||||
let request_headers =
|
let request_headers = Header::new(ACCESS_CONTROL_REQUEST_HEADERS.as_str(), "Authorization");
|
||||||
hyper::header::AccessControlRequestHeaders(vec![
|
|
||||||
FromStr::from_str("Authorization").unwrap()
|
|
||||||
]);
|
|
||||||
let request_headers = Header::from(request_headers);
|
|
||||||
let req = client
|
let req = client
|
||||||
.options("/")
|
.options("/")
|
||||||
.header(origin_header)
|
.header(origin_header)
|
||||||
|
@ -323,14 +301,12 @@ fn cors_options_bad_request_header() {
|
||||||
let rocket = make_rocket();
|
let rocket = make_rocket();
|
||||||
let client = Client::new(rocket).unwrap();
|
let client = Client::new(rocket).unwrap();
|
||||||
|
|
||||||
let origin_header =
|
let origin_header = Header::new(ORIGIN.as_str(), "https://www.acme.com");
|
||||||
Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap());
|
let method_header = Header::new(
|
||||||
let method_header = Header::from(hyper::header::AccessControlRequestMethod(
|
ACCESS_CONTROL_REQUEST_METHOD.as_str(),
|
||||||
hyper::method::Method::Get,
|
hyper::Method::GET.as_str(),
|
||||||
));
|
);
|
||||||
let request_headers =
|
let request_headers = Header::new(ACCESS_CONTROL_REQUEST_HEADERS.as_str(), "Foobar");
|
||||||
hyper::header::AccessControlRequestHeaders(vec![FromStr::from_str("Foobar").unwrap()]);
|
|
||||||
let request_headers = Header::from(request_headers);
|
|
||||||
let req = client
|
let req = client
|
||||||
.options("/")
|
.options("/")
|
||||||
.header(origin_header)
|
.header(origin_header)
|
||||||
|
@ -350,8 +326,7 @@ fn cors_get_bad_origin() {
|
||||||
let rocket = make_rocket();
|
let rocket = make_rocket();
|
||||||
let client = Client::new(rocket).unwrap();
|
let client = Client::new(rocket).unwrap();
|
||||||
|
|
||||||
let origin_header =
|
let origin_header = Header::new(ORIGIN.as_str(), "https://www.bad-origin.com");
|
||||||
Header::from(hyper::header::Origin::from_str("https://www.bad-origin.com").unwrap());
|
|
||||||
let authorization = Header::new("Authorization", "let me in");
|
let authorization = Header::new("Authorization", "let me in");
|
||||||
let req = client.get("/").header(origin_header).header(authorization);
|
let req = client.get("/").header(origin_header).header(authorization);
|
||||||
|
|
||||||
|
@ -371,8 +346,7 @@ fn routes_failing_checks_are_not_executed() {
|
||||||
let rocket = make_rocket();
|
let rocket = make_rocket();
|
||||||
let client = Client::new(rocket).unwrap();
|
let client = Client::new(rocket).unwrap();
|
||||||
|
|
||||||
let origin_header =
|
let origin_header = Header::new(ORIGIN.as_str(), "https://www.bad-origin.com");
|
||||||
Header::from(hyper::header::Origin::from_str("https://www.bad-origin.com").unwrap());
|
|
||||||
let authorization = Header::new("Authorization", "let me in");
|
let authorization = Header::new("Authorization", "let me in");
|
||||||
let req = client.get("/").header(origin_header).header(authorization);
|
let req = client.get("/").header(origin_header).header(authorization);
|
||||||
|
|
||||||
|
@ -391,30 +365,25 @@ fn overridden_options_routes_are_used() {
|
||||||
let rocket = make_rocket();
|
let rocket = make_rocket();
|
||||||
let client = Client::new(rocket).unwrap();
|
let client = Client::new(rocket).unwrap();
|
||||||
|
|
||||||
let origin_header =
|
let origin_header = Header::new(ORIGIN.as_str(), "https://www.acme.com");
|
||||||
Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap());
|
let method_header = Header::new(
|
||||||
let method_header = Header::from(hyper::header::AccessControlRequestMethod(
|
ACCESS_CONTROL_REQUEST_METHOD.as_str(),
|
||||||
hyper::method::Method::Get,
|
hyper::Method::GET.as_str(),
|
||||||
));
|
);
|
||||||
let request_headers =
|
let request_headers = Header::new(ACCESS_CONTROL_REQUEST_HEADERS.as_str(), "Authorization");
|
||||||
hyper::header::AccessControlRequestHeaders(vec![
|
|
||||||
FromStr::from_str("Authorization").unwrap()
|
|
||||||
]);
|
|
||||||
let request_headers = Header::from(request_headers);
|
|
||||||
let req = client
|
let req = client
|
||||||
.options("/manual")
|
.options("/manual")
|
||||||
.header(origin_header)
|
.header(origin_header)
|
||||||
.header(method_header)
|
.header(method_header)
|
||||||
.header(request_headers);
|
.header(request_headers);
|
||||||
|
|
||||||
let mut response = req.dispatch();
|
let response = req.dispatch();
|
||||||
let body_str = response.body().and_then(Body::into_string);
|
|
||||||
assert!(response.status().class().is_success());
|
assert!(response.status().class().is_success());
|
||||||
assert_eq!(body_str, Some("Manual CORS Preflight".to_string()));
|
|
||||||
|
|
||||||
let origin_header = response
|
let origin_header = response
|
||||||
.headers()
|
.headers()
|
||||||
.get_one("Access-Control-Allow-Origin")
|
.get_one("Access-Control-Allow-Origin")
|
||||||
.expect("to exist");
|
.expect("to exist");
|
||||||
assert_eq!("https://www.acme.com", origin_header);
|
assert_eq!("https://www.acme.com", origin_header);
|
||||||
|
let body_str = response.into_string();
|
||||||
|
assert_eq!(body_str, Some("Manual CORS Preflight".to_string()));
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,16 +1,18 @@
|
||||||
//! This crate tests that all the request headers are parsed correctly in the round trip
|
//! This crate tests that all the request headers are parsed correctly in the round trip
|
||||||
#![feature(proc_macro_hygiene, decl_macro)]
|
|
||||||
use hyper;
|
|
||||||
|
|
||||||
use std::ops::Deref;
|
use std::ops::Deref;
|
||||||
use std::str::FromStr;
|
|
||||||
|
|
||||||
|
use rocket::http::hyper;
|
||||||
use rocket::http::Header;
|
use rocket::http::Header;
|
||||||
use rocket::local::Client;
|
use rocket::local::blocking::Client;
|
||||||
use rocket::response::Body;
|
|
||||||
use rocket::{get, routes};
|
use rocket::{get, routes};
|
||||||
use rocket_cors::headers::*;
|
use rocket_cors::headers::*;
|
||||||
|
|
||||||
|
static ORIGIN: hyper::HeaderName = hyper::header::ORIGIN;
|
||||||
|
static ACCESS_CONTROL_REQUEST_METHOD: hyper::HeaderName =
|
||||||
|
hyper::header::ACCESS_CONTROL_REQUEST_METHOD;
|
||||||
|
static ACCESS_CONTROL_REQUEST_HEADERS: hyper::HeaderName =
|
||||||
|
hyper::header::ACCESS_CONTROL_REQUEST_HEADERS;
|
||||||
|
|
||||||
#[get("/request_headers")]
|
#[get("/request_headers")]
|
||||||
fn request_headers(
|
fn request_headers(
|
||||||
origin: Origin,
|
origin: Origin,
|
||||||
|
@ -33,30 +35,27 @@ fn request_headers_round_trip_smoke_test() {
|
||||||
let rocket = rocket::ignite().mount("/", routes![request_headers]);
|
let rocket = rocket::ignite().mount("/", routes![request_headers]);
|
||||||
let client = Client::new(rocket).expect("A valid Rocket client");
|
let client = Client::new(rocket).expect("A valid Rocket client");
|
||||||
|
|
||||||
let origin_header =
|
let origin_header = Header::new(ORIGIN.as_str(), "https://foo.bar.xyz");
|
||||||
Header::from(hyper::header::Origin::from_str("https://foo.bar.xyz").unwrap());
|
let method_header = Header::new(
|
||||||
let method_header = Header::from(hyper::header::AccessControlRequestMethod(
|
ACCESS_CONTROL_REQUEST_METHOD.as_str(),
|
||||||
hyper::method::Method::Get,
|
hyper::Method::GET.as_str(),
|
||||||
));
|
);
|
||||||
let request_headers = hyper::header::AccessControlRequestHeaders(vec![
|
let request_headers = Header::new(
|
||||||
FromStr::from_str("accept-language").unwrap(),
|
ACCESS_CONTROL_REQUEST_HEADERS.as_str(),
|
||||||
FromStr::from_str("X-Ping").unwrap(),
|
"accept-language, X-Ping",
|
||||||
]);
|
);
|
||||||
let request_headers = Header::from(request_headers);
|
|
||||||
let req = client
|
let req = client
|
||||||
.get("/request_headers")
|
.get("/request_headers")
|
||||||
.header(origin_header)
|
.header(origin_header)
|
||||||
.header(method_header)
|
.header(method_header)
|
||||||
.header(request_headers);
|
.header(request_headers);
|
||||||
let mut response = req.dispatch();
|
let response = req.dispatch();
|
||||||
|
|
||||||
assert!(response.status().class().is_success());
|
assert!(response.status().class().is_success());
|
||||||
let body_str = response
|
let body_str = response.into_string();
|
||||||
.body()
|
|
||||||
.and_then(Body::into_string)
|
|
||||||
.expect("Non-empty body");
|
|
||||||
let expected_body = r#"https://foo.bar.xyz
|
let expected_body = r#"https://foo.bar.xyz
|
||||||
GET
|
GET
|
||||||
X-Ping, accept-language"#;
|
X-Ping, accept-language"#
|
||||||
assert_eq!(expected_body, body_str);
|
.to_string();
|
||||||
|
assert_eq!(body_str, Some(expected_body));
|
||||||
}
|
}
|
||||||
|
|
200
tests/manual.rs
200
tests/manual.rs
|
@ -1,28 +1,29 @@
|
||||||
//! This crate tests using `rocket_cors` using manual mode
|
//! This crate tests using `rocket_cors` using manual mode
|
||||||
#![feature(proc_macro_hygiene, decl_macro)]
|
use rocket::http::hyper;
|
||||||
use hyper;
|
|
||||||
|
|
||||||
use std::str::FromStr;
|
|
||||||
|
|
||||||
use rocket::http::Method;
|
use rocket::http::Method;
|
||||||
use rocket::http::{Header, Status};
|
use rocket::http::{Header, Status};
|
||||||
use rocket::local::Client;
|
use rocket::local::blocking::Client;
|
||||||
use rocket::response::Body;
|
|
||||||
use rocket::response::Responder;
|
use rocket::response::Responder;
|
||||||
use rocket::State;
|
use rocket::State;
|
||||||
use rocket::{get, options, routes};
|
use rocket::{get, options, routes};
|
||||||
use rocket_cors::*;
|
use rocket_cors::*;
|
||||||
|
|
||||||
|
static ORIGIN: hyper::HeaderName = hyper::header::ORIGIN;
|
||||||
|
static ACCESS_CONTROL_REQUEST_METHOD: hyper::HeaderName =
|
||||||
|
hyper::header::ACCESS_CONTROL_REQUEST_METHOD;
|
||||||
|
static ACCESS_CONTROL_REQUEST_HEADERS: hyper::HeaderName =
|
||||||
|
hyper::header::ACCESS_CONTROL_REQUEST_HEADERS;
|
||||||
|
|
||||||
/// Using a borrowed `Cors`
|
/// Using a borrowed `Cors`
|
||||||
#[get("/")]
|
#[get("/")]
|
||||||
fn cors(options: State<'_, Cors>) -> impl Responder<'_> {
|
fn cors(options: State<'_, Cors>) -> impl Responder<'_, '_> {
|
||||||
options
|
options
|
||||||
.inner()
|
.inner()
|
||||||
.respond_borrowed(|guard| guard.responder("Hello CORS"))
|
.respond_borrowed(|guard| guard.responder("Hello CORS"))
|
||||||
}
|
}
|
||||||
|
|
||||||
#[get("/panic")]
|
#[get("/panic")]
|
||||||
fn panicking_route(options: State<'_, Cors>) -> impl Responder<'_> {
|
fn panicking_route(options: State<'_, Cors>) -> impl Responder<'_, '_> {
|
||||||
options.inner().respond_borrowed(|_| {
|
options.inner().respond_borrowed(|_| {
|
||||||
panic!("This route will panic");
|
panic!("This route will panic");
|
||||||
})
|
})
|
||||||
|
@ -30,7 +31,7 @@ fn panicking_route(options: State<'_, Cors>) -> impl Responder<'_> {
|
||||||
|
|
||||||
/// Respond with an owned option instead
|
/// Respond with an owned option instead
|
||||||
#[options("/owned")]
|
#[options("/owned")]
|
||||||
fn owned_options<'r>() -> impl Responder<'r> {
|
fn owned_options<'r, 'o: 'r>() -> impl Responder<'r, 'o> {
|
||||||
let borrow = make_different_cors_options().to_cors()?;
|
let borrow = make_different_cors_options().to_cors()?;
|
||||||
|
|
||||||
borrow.respond_owned(|guard| guard.responder("Manual CORS Preflight"))
|
borrow.respond_owned(|guard| guard.responder("Manual CORS Preflight"))
|
||||||
|
@ -38,7 +39,7 @@ fn owned_options<'r>() -> impl Responder<'r> {
|
||||||
|
|
||||||
/// Respond with an owned option instead
|
/// Respond with an owned option instead
|
||||||
#[get("/owned")]
|
#[get("/owned")]
|
||||||
fn owned<'r>() -> impl Responder<'r> {
|
fn owned<'r, 'o: 'r>() -> impl Responder<'r, 'o> {
|
||||||
let borrow = make_different_cors_options().to_cors()?;
|
let borrow = make_different_cors_options().to_cors()?;
|
||||||
|
|
||||||
borrow.respond_owned(|guard| guard.responder("Hello CORS Owned"))
|
borrow.respond_owned(|guard| guard.responder("Hello CORS Owned"))
|
||||||
|
@ -48,7 +49,8 @@ fn owned<'r>() -> impl Responder<'r> {
|
||||||
|
|
||||||
/// `Responder` with String
|
/// `Responder` with String
|
||||||
#[get("/")]
|
#[get("/")]
|
||||||
fn responder_string(options: State<'_, Cors>) -> impl Responder<'_> {
|
#[allow(dead_code)]
|
||||||
|
fn responder_string(options: State<'_, Cors>) -> impl Responder<'_, '_> {
|
||||||
options
|
options
|
||||||
.inner()
|
.inner()
|
||||||
.respond_borrowed(|guard| guard.responder("Hello CORS".to_string()))
|
.respond_borrowed(|guard| guard.responder("Hello CORS".to_string()))
|
||||||
|
@ -57,7 +59,11 @@ fn responder_string(options: State<'_, Cors>) -> impl Responder<'_> {
|
||||||
struct TestState;
|
struct TestState;
|
||||||
/// Borrow something else from Rocket with lifetime `'r`
|
/// Borrow something else from Rocket with lifetime `'r`
|
||||||
#[get("/")]
|
#[get("/")]
|
||||||
fn borrow<'r>(options: State<'r, Cors>, test_state: State<'r, TestState>) -> impl Responder<'r> {
|
#[allow(dead_code)]
|
||||||
|
fn borrow<'r, 'o: 'r>(
|
||||||
|
options: State<'r, Cors>,
|
||||||
|
test_state: State<'r, TestState>,
|
||||||
|
) -> impl Responder<'r, 'o> {
|
||||||
let borrow = test_state.inner();
|
let borrow = test_state.inner();
|
||||||
options.inner().respond_borrowed(move |guard| {
|
options.inner().respond_borrowed(move |guard| {
|
||||||
let _ = borrow;
|
let _ = borrow;
|
||||||
|
@ -102,16 +108,12 @@ fn smoke_test() {
|
||||||
let client = Client::new(rocket()).unwrap();
|
let client = Client::new(rocket()).unwrap();
|
||||||
|
|
||||||
// `Options` pre-flight checks
|
// `Options` pre-flight checks
|
||||||
let origin_header =
|
let origin_header = Header::new(ORIGIN.as_str(), "https://www.acme.com");
|
||||||
Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap());
|
let method_header = Header::new(
|
||||||
let method_header = Header::from(hyper::header::AccessControlRequestMethod(
|
ACCESS_CONTROL_REQUEST_METHOD.as_str(),
|
||||||
hyper::method::Method::Get,
|
hyper::Method::GET.as_str(),
|
||||||
));
|
);
|
||||||
let request_headers =
|
let request_headers = Header::new(ACCESS_CONTROL_REQUEST_HEADERS.as_str(), "Authorization");
|
||||||
hyper::header::AccessControlRequestHeaders(vec![
|
|
||||||
FromStr::from_str("Authorization").unwrap()
|
|
||||||
]);
|
|
||||||
let request_headers = Header::from(request_headers);
|
|
||||||
let req = client
|
let req = client
|
||||||
.options("/")
|
.options("/")
|
||||||
.header(origin_header)
|
.header(origin_header)
|
||||||
|
@ -122,37 +124,31 @@ fn smoke_test() {
|
||||||
assert!(response.status().class().is_success());
|
assert!(response.status().class().is_success());
|
||||||
|
|
||||||
// "Actual" request
|
// "Actual" request
|
||||||
let origin_header =
|
let origin_header = Header::new(ORIGIN.as_str(), "https://www.acme.com");
|
||||||
Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap());
|
|
||||||
let authorization = Header::new("Authorization", "let me in");
|
let authorization = Header::new("Authorization", "let me in");
|
||||||
let req = client.get("/").header(origin_header).header(authorization);
|
let req = client.get("/").header(origin_header).header(authorization);
|
||||||
|
|
||||||
let mut response = req.dispatch();
|
let response = req.dispatch();
|
||||||
assert!(response.status().class().is_success());
|
assert!(response.status().class().is_success());
|
||||||
let body_str = response.body().and_then(Body::into_string);
|
|
||||||
assert_eq!(body_str, Some("Hello CORS".to_string()));
|
|
||||||
|
|
||||||
let origin_header = response
|
let origin_header = response
|
||||||
.headers()
|
.headers()
|
||||||
.get_one("Access-Control-Allow-Origin")
|
.get_one("Access-Control-Allow-Origin")
|
||||||
.expect("to exist");
|
.expect("to exist");
|
||||||
assert_eq!("https://www.acme.com", origin_header);
|
assert_eq!("https://www.acme.com", origin_header);
|
||||||
|
let body_str = response.into_string();
|
||||||
|
assert_eq!(body_str, Some("Hello CORS".to_string()));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn cors_options_borrowed_check() {
|
fn cors_options_borrowed_check() {
|
||||||
let client = Client::new(rocket()).unwrap();
|
let client = Client::new(rocket()).unwrap();
|
||||||
|
|
||||||
let origin_header =
|
let origin_header = Header::new(ORIGIN.as_str(), "https://www.acme.com");
|
||||||
Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap());
|
let method_header = Header::new(
|
||||||
let method_header = Header::from(hyper::header::AccessControlRequestMethod(
|
ACCESS_CONTROL_REQUEST_METHOD.as_str(),
|
||||||
hyper::method::Method::Get,
|
hyper::Method::GET.as_str(),
|
||||||
));
|
);
|
||||||
let request_headers =
|
let request_headers = Header::new(ACCESS_CONTROL_REQUEST_HEADERS.as_str(), "Authorization");
|
||||||
hyper::header::AccessControlRequestHeaders(vec![
|
|
||||||
FromStr::from_str("Authorization").unwrap()
|
|
||||||
]);
|
|
||||||
let request_headers = Header::from(request_headers);
|
|
||||||
let req = client
|
let req = client
|
||||||
.options("/")
|
.options("/")
|
||||||
.header(origin_header)
|
.header(origin_header)
|
||||||
|
@ -173,21 +169,19 @@ fn cors_options_borrowed_check() {
|
||||||
fn cors_get_borrowed_check() {
|
fn cors_get_borrowed_check() {
|
||||||
let client = Client::new(rocket()).unwrap();
|
let client = Client::new(rocket()).unwrap();
|
||||||
|
|
||||||
let origin_header =
|
let origin_header = Header::new(ORIGIN.as_str(), "https://www.acme.com");
|
||||||
Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap());
|
|
||||||
let authorization = Header::new("Authorization", "let me in");
|
let authorization = Header::new("Authorization", "let me in");
|
||||||
let req = client.get("/").header(origin_header).header(authorization);
|
let req = client.get("/").header(origin_header).header(authorization);
|
||||||
|
|
||||||
let mut response = req.dispatch();
|
let response = req.dispatch();
|
||||||
assert!(response.status().class().is_success());
|
assert!(response.status().class().is_success());
|
||||||
let body_str = response.body().and_then(Body::into_string);
|
|
||||||
assert_eq!(body_str, Some("Hello CORS".to_string()));
|
|
||||||
|
|
||||||
let origin_header = response
|
let origin_header = response
|
||||||
.headers()
|
.headers()
|
||||||
.get_one("Access-Control-Allow-Origin")
|
.get_one("Access-Control-Allow-Origin")
|
||||||
.expect("to exist");
|
.expect("to exist");
|
||||||
assert_eq!("https://www.acme.com", origin_header);
|
assert_eq!("https://www.acme.com", origin_header);
|
||||||
|
let body_str = response.into_string();
|
||||||
|
assert_eq!(body_str, Some("Hello CORS".to_string()));
|
||||||
}
|
}
|
||||||
|
|
||||||
/// This test is to check that non CORS compliant requests to GET should still work. (i.e. curl)
|
/// This test is to check that non CORS compliant requests to GET should still work. (i.e. curl)
|
||||||
|
@ -198,9 +192,9 @@ fn cors_get_no_origin() {
|
||||||
let authorization = Header::new("Authorization", "let me in");
|
let authorization = Header::new("Authorization", "let me in");
|
||||||
let req = client.get("/").header(authorization);
|
let req = client.get("/").header(authorization);
|
||||||
|
|
||||||
let mut response = req.dispatch();
|
let response = req.dispatch();
|
||||||
assert!(response.status().class().is_success());
|
assert!(response.status().class().is_success());
|
||||||
let body_str = response.body().and_then(Body::into_string);
|
let body_str = response.into_string();
|
||||||
assert_eq!(body_str, Some("Hello CORS".to_string()));
|
assert_eq!(body_str, Some("Hello CORS".to_string()));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -208,16 +202,12 @@ fn cors_get_no_origin() {
|
||||||
fn cors_options_bad_origin() {
|
fn cors_options_bad_origin() {
|
||||||
let client = Client::new(rocket()).unwrap();
|
let client = Client::new(rocket()).unwrap();
|
||||||
|
|
||||||
let origin_header =
|
let origin_header = Header::new(ORIGIN.as_str(), "https://www.bad-origin.com");
|
||||||
Header::from(hyper::header::Origin::from_str("https://www.bad-origin.com").unwrap());
|
let method_header = Header::new(
|
||||||
let method_header = Header::from(hyper::header::AccessControlRequestMethod(
|
ACCESS_CONTROL_REQUEST_METHOD.as_str(),
|
||||||
hyper::method::Method::Get,
|
hyper::Method::GET.as_str(),
|
||||||
));
|
);
|
||||||
let request_headers =
|
let request_headers = Header::new(ACCESS_CONTROL_REQUEST_HEADERS.as_str(), "Authorization");
|
||||||
hyper::header::AccessControlRequestHeaders(vec![
|
|
||||||
FromStr::from_str("Authorization").unwrap()
|
|
||||||
]);
|
|
||||||
let request_headers = Header::from(request_headers);
|
|
||||||
let req = client
|
let req = client
|
||||||
.options("/")
|
.options("/")
|
||||||
.header(origin_header)
|
.header(origin_header)
|
||||||
|
@ -232,14 +222,11 @@ fn cors_options_bad_origin() {
|
||||||
fn cors_options_missing_origin() {
|
fn cors_options_missing_origin() {
|
||||||
let client = Client::new(rocket()).unwrap();
|
let client = Client::new(rocket()).unwrap();
|
||||||
|
|
||||||
let method_header = Header::from(hyper::header::AccessControlRequestMethod(
|
let method_header = Header::new(
|
||||||
hyper::method::Method::Get,
|
ACCESS_CONTROL_REQUEST_METHOD.as_str(),
|
||||||
));
|
hyper::Method::GET.as_str(),
|
||||||
let request_headers =
|
);
|
||||||
hyper::header::AccessControlRequestHeaders(vec![
|
let request_headers = Header::new(ACCESS_CONTROL_REQUEST_HEADERS.as_str(), "Authorization");
|
||||||
FromStr::from_str("Authorization").unwrap()
|
|
||||||
]);
|
|
||||||
let request_headers = Header::from(request_headers);
|
|
||||||
let req = client
|
let req = client
|
||||||
.options("/")
|
.options("/")
|
||||||
.header(method_header)
|
.header(method_header)
|
||||||
|
@ -257,16 +244,12 @@ fn cors_options_missing_origin() {
|
||||||
fn cors_options_bad_request_method() {
|
fn cors_options_bad_request_method() {
|
||||||
let client = Client::new(rocket()).unwrap();
|
let client = Client::new(rocket()).unwrap();
|
||||||
|
|
||||||
let origin_header =
|
let origin_header = Header::new(ORIGIN.as_str(), "https://www.acme.com");
|
||||||
Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap());
|
let method_header = Header::new(
|
||||||
let method_header = Header::from(hyper::header::AccessControlRequestMethod(
|
ACCESS_CONTROL_REQUEST_METHOD.as_str(),
|
||||||
hyper::method::Method::Post,
|
hyper::Method::POST.as_str(),
|
||||||
));
|
);
|
||||||
let request_headers =
|
let request_headers = Header::new(ACCESS_CONTROL_REQUEST_HEADERS.as_str(), "Authorization");
|
||||||
hyper::header::AccessControlRequestHeaders(vec![
|
|
||||||
FromStr::from_str("Authorization").unwrap()
|
|
||||||
]);
|
|
||||||
let request_headers = Header::from(request_headers);
|
|
||||||
let req = client
|
let req = client
|
||||||
.options("/")
|
.options("/")
|
||||||
.header(origin_header)
|
.header(origin_header)
|
||||||
|
@ -285,14 +268,12 @@ fn cors_options_bad_request_method() {
|
||||||
fn cors_options_bad_request_header() {
|
fn cors_options_bad_request_header() {
|
||||||
let client = Client::new(rocket()).unwrap();
|
let client = Client::new(rocket()).unwrap();
|
||||||
|
|
||||||
let origin_header =
|
let origin_header = Header::new(ORIGIN.as_str(), "https://www.acme.com");
|
||||||
Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap());
|
let method_header = Header::new(
|
||||||
let method_header = Header::from(hyper::header::AccessControlRequestMethod(
|
ACCESS_CONTROL_REQUEST_METHOD.as_str(),
|
||||||
hyper::method::Method::Get,
|
hyper::Method::GET.as_str(),
|
||||||
));
|
);
|
||||||
let request_headers =
|
let request_headers = Header::new(ACCESS_CONTROL_REQUEST_HEADERS.as_str(), "Foobar");
|
||||||
hyper::header::AccessControlRequestHeaders(vec![FromStr::from_str("Foobar").unwrap()]);
|
|
||||||
let request_headers = Header::from(request_headers);
|
|
||||||
let req = client
|
let req = client
|
||||||
.options("/")
|
.options("/")
|
||||||
.header(origin_header)
|
.header(origin_header)
|
||||||
|
@ -311,8 +292,7 @@ fn cors_options_bad_request_header() {
|
||||||
fn cors_get_bad_origin() {
|
fn cors_get_bad_origin() {
|
||||||
let client = Client::new(rocket()).unwrap();
|
let client = Client::new(rocket()).unwrap();
|
||||||
|
|
||||||
let origin_header =
|
let origin_header = Header::new(ORIGIN.as_str(), "https://www.bad-origin.com");
|
||||||
Header::from(hyper::header::Origin::from_str("https://www.bad-origin.com").unwrap());
|
|
||||||
let authorization = Header::new("Authorization", "let me in");
|
let authorization = Header::new("Authorization", "let me in");
|
||||||
let req = client.get("/").header(origin_header).header(authorization);
|
let req = client.get("/").header(origin_header).header(authorization);
|
||||||
|
|
||||||
|
@ -331,16 +311,12 @@ fn cors_get_bad_origin() {
|
||||||
fn routes_failing_checks_are_not_executed() {
|
fn routes_failing_checks_are_not_executed() {
|
||||||
let client = Client::new(rocket()).unwrap();
|
let client = Client::new(rocket()).unwrap();
|
||||||
|
|
||||||
let origin_header =
|
let origin_header = Header::new(ORIGIN.as_str(), "https://www.bad-origin.com");
|
||||||
Header::from(hyper::header::Origin::from_str("https://www.bad-origin.com").unwrap());
|
let method_header = Header::new(
|
||||||
let method_header = Header::from(hyper::header::AccessControlRequestMethod(
|
ACCESS_CONTROL_REQUEST_METHOD.as_str(),
|
||||||
hyper::method::Method::Get,
|
hyper::Method::GET.as_str(),
|
||||||
));
|
);
|
||||||
let request_headers =
|
let request_headers = Header::new(ACCESS_CONTROL_REQUEST_HEADERS.as_str(), "Authorization");
|
||||||
hyper::header::AccessControlRequestHeaders(vec![
|
|
||||||
FromStr::from_str("Authorization").unwrap()
|
|
||||||
]);
|
|
||||||
let request_headers = Header::from(request_headers);
|
|
||||||
let req = client
|
let req = client
|
||||||
.options("/panic")
|
.options("/panic")
|
||||||
.header(origin_header)
|
.header(origin_header)
|
||||||
|
@ -361,32 +337,28 @@ fn cors_options_owned_check() {
|
||||||
let rocket = rocket();
|
let rocket = rocket();
|
||||||
let client = Client::new(rocket).unwrap();
|
let client = Client::new(rocket).unwrap();
|
||||||
|
|
||||||
let origin_header =
|
let origin_header = Header::new(ORIGIN.as_str(), "https://www.example.com");
|
||||||
Header::from(hyper::header::Origin::from_str("https://www.example.com").unwrap());
|
let method_header = Header::new(
|
||||||
let method_header = Header::from(hyper::header::AccessControlRequestMethod(
|
ACCESS_CONTROL_REQUEST_METHOD.as_str(),
|
||||||
hyper::method::Method::Get,
|
hyper::Method::GET.as_str(),
|
||||||
));
|
);
|
||||||
let request_headers =
|
let request_headers = Header::new(ACCESS_CONTROL_REQUEST_HEADERS.as_str(), "Authorization");
|
||||||
hyper::header::AccessControlRequestHeaders(vec![
|
|
||||||
FromStr::from_str("Authorization").unwrap()
|
|
||||||
]);
|
|
||||||
let request_headers = Header::from(request_headers);
|
|
||||||
let req = client
|
let req = client
|
||||||
.options("/owned")
|
.options("/owned")
|
||||||
.header(origin_header)
|
.header(origin_header)
|
||||||
.header(method_header)
|
.header(method_header)
|
||||||
.header(request_headers);
|
.header(request_headers);
|
||||||
|
|
||||||
let mut response = req.dispatch();
|
let response = req.dispatch();
|
||||||
let body_str = response.body().and_then(Body::into_string);
|
|
||||||
assert!(response.status().class().is_success());
|
assert!(response.status().class().is_success());
|
||||||
assert_eq!(body_str, Some("Manual CORS Preflight".to_string()));
|
|
||||||
|
|
||||||
let origin_header = response
|
let origin_header = response
|
||||||
.headers()
|
.headers()
|
||||||
.get_one("Access-Control-Allow-Origin")
|
.get_one("Access-Control-Allow-Origin")
|
||||||
.expect("to exist");
|
.expect("to exist");
|
||||||
assert_eq!("https://www.example.com", origin_header);
|
assert_eq!("https://www.example.com", origin_header);
|
||||||
|
|
||||||
|
let body_str = response.into_string();
|
||||||
|
assert_eq!(body_str, Some("Manual CORS Preflight".to_string()));
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Owned manual response works
|
/// Owned manual response works
|
||||||
|
@ -394,22 +366,20 @@ fn cors_options_owned_check() {
|
||||||
fn cors_get_owned_check() {
|
fn cors_get_owned_check() {
|
||||||
let client = Client::new(rocket()).unwrap();
|
let client = Client::new(rocket()).unwrap();
|
||||||
|
|
||||||
let origin_header =
|
let origin_header = Header::new(ORIGIN.as_str(), "https://www.example.com");
|
||||||
Header::from(hyper::header::Origin::from_str("https://www.example.com").unwrap());
|
|
||||||
let authorization = Header::new("Authorization", "let me in");
|
let authorization = Header::new("Authorization", "let me in");
|
||||||
let req = client
|
let req = client
|
||||||
.get("/owned")
|
.get("/owned")
|
||||||
.header(origin_header)
|
.header(origin_header)
|
||||||
.header(authorization);
|
.header(authorization);
|
||||||
|
|
||||||
let mut response = req.dispatch();
|
let response = req.dispatch();
|
||||||
assert!(response.status().class().is_success());
|
assert!(response.status().class().is_success());
|
||||||
let body_str = response.body().and_then(Body::into_string);
|
|
||||||
assert_eq!(body_str, Some("Hello CORS Owned".to_string()));
|
|
||||||
|
|
||||||
let origin_header = response
|
let origin_header = response
|
||||||
.headers()
|
.headers()
|
||||||
.get_one("Access-Control-Allow-Origin")
|
.get_one("Access-Control-Allow-Origin")
|
||||||
.expect("to exist");
|
.expect("to exist");
|
||||||
assert_eq!("https://www.example.com", origin_header);
|
assert_eq!("https://www.example.com", origin_header);
|
||||||
|
let body_str = response.into_string();
|
||||||
|
assert_eq!(body_str, Some("Hello CORS Owned".to_string()));
|
||||||
}
|
}
|
||||||
|
|
160
tests/mix.rs
160
tests/mix.rs
|
@ -2,29 +2,29 @@
|
||||||
//!
|
//!
|
||||||
//! In this example, you typically have an application wide `Cors` struct except for one specific
|
//! In this example, you typically have an application wide `Cors` struct except for one specific
|
||||||
//! `ping` route that you want to allow all Origins to access.
|
//! `ping` route that you want to allow all Origins to access.
|
||||||
#![feature(proc_macro_hygiene, decl_macro)]
|
use rocket::http::hyper;
|
||||||
use hyper;
|
|
||||||
use rocket_cors;
|
|
||||||
|
|
||||||
use std::str::FromStr;
|
|
||||||
|
|
||||||
use rocket::http::{Header, Method, Status};
|
use rocket::http::{Header, Method, Status};
|
||||||
use rocket::local::Client;
|
use rocket::local::blocking::Client;
|
||||||
use rocket::response::Body;
|
|
||||||
use rocket::response::Responder;
|
use rocket::response::Responder;
|
||||||
use rocket::{get, options, routes};
|
use rocket::{get, options, routes};
|
||||||
|
|
||||||
use rocket_cors::{AllowedHeaders, AllowedOrigins, CorsOptions, Guard};
|
use rocket_cors::{AllowedHeaders, AllowedOrigins, CorsOptions, Guard};
|
||||||
|
|
||||||
|
static ORIGIN: hyper::HeaderName = hyper::header::ORIGIN;
|
||||||
|
static ACCESS_CONTROL_REQUEST_METHOD: hyper::HeaderName =
|
||||||
|
hyper::header::ACCESS_CONTROL_REQUEST_METHOD;
|
||||||
|
static ACCESS_CONTROL_REQUEST_HEADERS: hyper::HeaderName =
|
||||||
|
hyper::header::ACCESS_CONTROL_REQUEST_HEADERS;
|
||||||
|
|
||||||
/// The "usual" app route
|
/// The "usual" app route
|
||||||
#[get("/")]
|
#[get("/")]
|
||||||
fn app(cors: Guard<'_>) -> rocket_cors::Responder<'_, &str> {
|
fn app(cors: Guard<'_>) -> rocket_cors::Responder<'_, '_, &str> {
|
||||||
cors.responder("Hello CORS!")
|
cors.responder("Hello CORS!")
|
||||||
}
|
}
|
||||||
|
|
||||||
/// The special "ping" route
|
/// The special "ping" route
|
||||||
#[get("/ping")]
|
#[get("/ping")]
|
||||||
fn ping<'r>() -> impl Responder<'r> {
|
fn ping<'r, 'o: 'r>() -> impl Responder<'r, 'o> {
|
||||||
let cors = cors_options_all().to_cors()?;
|
let cors = cors_options_all().to_cors()?;
|
||||||
cors.respond_owned(|guard| guard.responder("Pong!"))
|
cors.respond_owned(|guard| guard.responder("Pong!"))
|
||||||
}
|
}
|
||||||
|
@ -33,7 +33,7 @@ fn ping<'r>() -> impl Responder<'r> {
|
||||||
/// that is not in Rocket's managed state.
|
/// that is not in Rocket's managed state.
|
||||||
/// These routes can just return the unit type `()`
|
/// These routes can just return the unit type `()`
|
||||||
#[options("/ping")]
|
#[options("/ping")]
|
||||||
fn ping_options<'r>() -> impl Responder<'r> {
|
fn ping_options<'r, 'o: 'r>() -> impl Responder<'r, 'o> {
|
||||||
let cors = cors_options_all().to_cors()?;
|
let cors = cors_options_all().to_cors()?;
|
||||||
cors.respond_owned(|guard| guard.responder(()))
|
cors.respond_owned(|guard| guard.responder(()))
|
||||||
}
|
}
|
||||||
|
@ -73,16 +73,12 @@ fn smoke_test() {
|
||||||
let client = Client::new(rocket()).unwrap();
|
let client = Client::new(rocket()).unwrap();
|
||||||
|
|
||||||
// `Options` pre-flight checks
|
// `Options` pre-flight checks
|
||||||
let origin_header =
|
let origin_header = Header::new(ORIGIN.as_str(), "https://www.acme.com");
|
||||||
Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap());
|
let method_header = Header::new(
|
||||||
let method_header = Header::from(hyper::header::AccessControlRequestMethod(
|
ACCESS_CONTROL_REQUEST_METHOD.as_str(),
|
||||||
hyper::method::Method::Get,
|
hyper::Method::GET.as_str(),
|
||||||
));
|
);
|
||||||
let request_headers =
|
let request_headers = Header::new(ACCESS_CONTROL_REQUEST_HEADERS.as_str(), "Authorization");
|
||||||
hyper::header::AccessControlRequestHeaders(vec![
|
|
||||||
FromStr::from_str("Authorization").unwrap()
|
|
||||||
]);
|
|
||||||
let request_headers = Header::from(request_headers);
|
|
||||||
let req = client
|
let req = client
|
||||||
.options("/")
|
.options("/")
|
||||||
.header(origin_header)
|
.header(origin_header)
|
||||||
|
@ -93,37 +89,31 @@ fn smoke_test() {
|
||||||
assert!(response.status().class().is_success());
|
assert!(response.status().class().is_success());
|
||||||
|
|
||||||
// "Actual" request
|
// "Actual" request
|
||||||
let origin_header =
|
let origin_header = Header::new(ORIGIN.as_str(), "https://www.acme.com");
|
||||||
Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap());
|
|
||||||
let authorization = Header::new("Authorization", "let me in");
|
let authorization = Header::new("Authorization", "let me in");
|
||||||
let req = client.get("/").header(origin_header).header(authorization);
|
let req = client.get("/").header(origin_header).header(authorization);
|
||||||
|
|
||||||
let mut response = req.dispatch();
|
let response = req.dispatch();
|
||||||
assert!(response.status().class().is_success());
|
assert!(response.status().class().is_success());
|
||||||
let body_str = response.body().and_then(Body::into_string);
|
|
||||||
assert_eq!(body_str, Some("Hello CORS!".to_string()));
|
|
||||||
|
|
||||||
let origin_header = response
|
let origin_header = response
|
||||||
.headers()
|
.headers()
|
||||||
.get_one("Access-Control-Allow-Origin")
|
.get_one("Access-Control-Allow-Origin")
|
||||||
.expect("to exist");
|
.expect("to exist");
|
||||||
assert_eq!("https://www.acme.com", origin_header);
|
assert_eq!("https://www.acme.com", origin_header);
|
||||||
|
let body_str = response.into_string();
|
||||||
|
assert_eq!(body_str, Some("Hello CORS!".to_string()));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn cors_options_check() {
|
fn cors_options_check() {
|
||||||
let client = Client::new(rocket()).unwrap();
|
let client = Client::new(rocket()).unwrap();
|
||||||
|
|
||||||
let origin_header =
|
let origin_header = Header::new(ORIGIN.as_str(), "https://www.acme.com");
|
||||||
Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap());
|
let method_header = Header::new(
|
||||||
let method_header = Header::from(hyper::header::AccessControlRequestMethod(
|
ACCESS_CONTROL_REQUEST_METHOD.as_str(),
|
||||||
hyper::method::Method::Get,
|
hyper::Method::GET.as_str(),
|
||||||
));
|
);
|
||||||
let request_headers =
|
let request_headers = Header::new(ACCESS_CONTROL_REQUEST_HEADERS.as_str(), "Authorization");
|
||||||
hyper::header::AccessControlRequestHeaders(vec![
|
|
||||||
FromStr::from_str("Authorization").unwrap()
|
|
||||||
]);
|
|
||||||
let request_headers = Header::from(request_headers);
|
|
||||||
let req = client
|
let req = client
|
||||||
.options("/")
|
.options("/")
|
||||||
.header(origin_header)
|
.header(origin_header)
|
||||||
|
@ -144,21 +134,19 @@ fn cors_options_check() {
|
||||||
fn cors_get_check() {
|
fn cors_get_check() {
|
||||||
let client = Client::new(rocket()).unwrap();
|
let client = Client::new(rocket()).unwrap();
|
||||||
|
|
||||||
let origin_header =
|
let origin_header = Header::new(ORIGIN.as_str(), "https://www.acme.com");
|
||||||
Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap());
|
|
||||||
let authorization = Header::new("Authorization", "let me in");
|
let authorization = Header::new("Authorization", "let me in");
|
||||||
let req = client.get("/").header(origin_header).header(authorization);
|
let req = client.get("/").header(origin_header).header(authorization);
|
||||||
|
|
||||||
let mut response = req.dispatch();
|
let response = req.dispatch();
|
||||||
assert!(response.status().class().is_success());
|
assert!(response.status().class().is_success());
|
||||||
let body_str = response.body().and_then(Body::into_string);
|
|
||||||
assert_eq!(body_str, Some("Hello CORS!".to_string()));
|
|
||||||
|
|
||||||
let origin_header = response
|
let origin_header = response
|
||||||
.headers()
|
.headers()
|
||||||
.get_one("Access-Control-Allow-Origin")
|
.get_one("Access-Control-Allow-Origin")
|
||||||
.expect("to exist");
|
.expect("to exist");
|
||||||
assert_eq!("https://www.acme.com", origin_header);
|
assert_eq!("https://www.acme.com", origin_header);
|
||||||
|
let body_str = response.into_string();
|
||||||
|
assert_eq!(body_str, Some("Hello CORS!".to_string()));
|
||||||
}
|
}
|
||||||
|
|
||||||
/// This test is to check that non CORS compliant requests to GET should still work. (i.e. curl)
|
/// This test is to check that non CORS compliant requests to GET should still work. (i.e. curl)
|
||||||
|
@ -169,9 +157,9 @@ fn cors_get_no_origin() {
|
||||||
let authorization = Header::new("Authorization", "let me in");
|
let authorization = Header::new("Authorization", "let me in");
|
||||||
let req = client.get("/").header(authorization);
|
let req = client.get("/").header(authorization);
|
||||||
|
|
||||||
let mut response = req.dispatch();
|
let response = req.dispatch();
|
||||||
assert!(response.status().class().is_success());
|
assert!(response.status().class().is_success());
|
||||||
let body_str = response.body().and_then(Body::into_string);
|
let body_str = response.into_string();
|
||||||
assert_eq!(body_str, Some("Hello CORS!".to_string()));
|
assert_eq!(body_str, Some("Hello CORS!".to_string()));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -179,16 +167,12 @@ fn cors_get_no_origin() {
|
||||||
fn cors_options_bad_origin() {
|
fn cors_options_bad_origin() {
|
||||||
let client = Client::new(rocket()).unwrap();
|
let client = Client::new(rocket()).unwrap();
|
||||||
|
|
||||||
let origin_header =
|
let origin_header = Header::new(ORIGIN.as_str(), "https://www.bad-origin.com");
|
||||||
Header::from(hyper::header::Origin::from_str("https://www.bad-origin.com").unwrap());
|
let method_header = Header::new(
|
||||||
let method_header = Header::from(hyper::header::AccessControlRequestMethod(
|
ACCESS_CONTROL_REQUEST_METHOD.as_str(),
|
||||||
hyper::method::Method::Get,
|
hyper::Method::GET.as_str(),
|
||||||
));
|
);
|
||||||
let request_headers =
|
let request_headers = Header::new(ACCESS_CONTROL_REQUEST_HEADERS.as_str(), "Authorization");
|
||||||
hyper::header::AccessControlRequestHeaders(vec![
|
|
||||||
FromStr::from_str("Authorization").unwrap()
|
|
||||||
]);
|
|
||||||
let request_headers = Header::from(request_headers);
|
|
||||||
let req = client
|
let req = client
|
||||||
.options("/")
|
.options("/")
|
||||||
.header(origin_header)
|
.header(origin_header)
|
||||||
|
@ -203,14 +187,11 @@ fn cors_options_bad_origin() {
|
||||||
fn cors_options_missing_origin() {
|
fn cors_options_missing_origin() {
|
||||||
let client = Client::new(rocket()).unwrap();
|
let client = Client::new(rocket()).unwrap();
|
||||||
|
|
||||||
let method_header = Header::from(hyper::header::AccessControlRequestMethod(
|
let method_header = Header::new(
|
||||||
hyper::method::Method::Get,
|
ACCESS_CONTROL_REQUEST_METHOD.as_str(),
|
||||||
));
|
hyper::Method::GET.as_str(),
|
||||||
let request_headers =
|
);
|
||||||
hyper::header::AccessControlRequestHeaders(vec![
|
let request_headers = Header::new(ACCESS_CONTROL_REQUEST_HEADERS.as_str(), "Authorization");
|
||||||
FromStr::from_str("Authorization").unwrap()
|
|
||||||
]);
|
|
||||||
let request_headers = Header::from(request_headers);
|
|
||||||
let req = client
|
let req = client
|
||||||
.options("/")
|
.options("/")
|
||||||
.header(method_header)
|
.header(method_header)
|
||||||
|
@ -228,16 +209,12 @@ fn cors_options_missing_origin() {
|
||||||
fn cors_options_bad_request_method() {
|
fn cors_options_bad_request_method() {
|
||||||
let client = Client::new(rocket()).unwrap();
|
let client = Client::new(rocket()).unwrap();
|
||||||
|
|
||||||
let origin_header =
|
let origin_header = Header::new(ORIGIN.as_str(), "https://www.acme.com");
|
||||||
Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap());
|
let method_header = Header::new(
|
||||||
let method_header = Header::from(hyper::header::AccessControlRequestMethod(
|
ACCESS_CONTROL_REQUEST_METHOD.as_str(),
|
||||||
hyper::method::Method::Post,
|
hyper::Method::POST.as_str(),
|
||||||
));
|
);
|
||||||
let request_headers =
|
let request_headers = Header::new(ACCESS_CONTROL_REQUEST_HEADERS.as_str(), "Authorization");
|
||||||
hyper::header::AccessControlRequestHeaders(vec![
|
|
||||||
FromStr::from_str("Authorization").unwrap()
|
|
||||||
]);
|
|
||||||
let request_headers = Header::from(request_headers);
|
|
||||||
let req = client
|
let req = client
|
||||||
.options("/")
|
.options("/")
|
||||||
.header(origin_header)
|
.header(origin_header)
|
||||||
|
@ -256,14 +233,12 @@ fn cors_options_bad_request_method() {
|
||||||
fn cors_options_bad_request_header() {
|
fn cors_options_bad_request_header() {
|
||||||
let client = Client::new(rocket()).unwrap();
|
let client = Client::new(rocket()).unwrap();
|
||||||
|
|
||||||
let origin_header =
|
let origin_header = Header::new(ORIGIN.as_str(), "https://www.acme.com");
|
||||||
Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap());
|
let method_header = Header::new(
|
||||||
let method_header = Header::from(hyper::header::AccessControlRequestMethod(
|
ACCESS_CONTROL_REQUEST_METHOD.as_str(),
|
||||||
hyper::method::Method::Get,
|
hyper::Method::GET.as_str(),
|
||||||
));
|
);
|
||||||
let request_headers =
|
let request_headers = Header::new(ACCESS_CONTROL_REQUEST_HEADERS.as_str(), "Foobar");
|
||||||
hyper::header::AccessControlRequestHeaders(vec![FromStr::from_str("Foobar").unwrap()]);
|
|
||||||
let request_headers = Header::from(request_headers);
|
|
||||||
let req = client
|
let req = client
|
||||||
.options("/")
|
.options("/")
|
||||||
.header(origin_header)
|
.header(origin_header)
|
||||||
|
@ -282,8 +257,7 @@ fn cors_options_bad_request_header() {
|
||||||
fn cors_get_bad_origin() {
|
fn cors_get_bad_origin() {
|
||||||
let client = Client::new(rocket()).unwrap();
|
let client = Client::new(rocket()).unwrap();
|
||||||
|
|
||||||
let origin_header =
|
let origin_header = Header::new(ORIGIN.as_str(), "https://www.bad-origin.com");
|
||||||
Header::from(hyper::header::Origin::from_str("https://www.bad-origin.com").unwrap());
|
|
||||||
let authorization = Header::new("Authorization", "let me in");
|
let authorization = Header::new("Authorization", "let me in");
|
||||||
let req = client.get("/").header(origin_header).header(authorization);
|
let req = client.get("/").header(origin_header).header(authorization);
|
||||||
|
|
||||||
|
@ -300,11 +274,11 @@ fn cors_get_bad_origin() {
|
||||||
fn cors_options_ping_check() {
|
fn cors_options_ping_check() {
|
||||||
let client = Client::new(rocket()).unwrap();
|
let client = Client::new(rocket()).unwrap();
|
||||||
|
|
||||||
let origin_header =
|
let origin_header = Header::new(ORIGIN.as_str(), "https://www.example.com");
|
||||||
Header::from(hyper::header::Origin::from_str("https://www.example.com").unwrap());
|
let method_header = Header::new(
|
||||||
let method_header = Header::from(hyper::header::AccessControlRequestMethod(
|
ACCESS_CONTROL_REQUEST_METHOD.as_str(),
|
||||||
hyper::method::Method::Get,
|
hyper::Method::GET.as_str(),
|
||||||
));
|
);
|
||||||
|
|
||||||
let req = client
|
let req = client
|
||||||
.options("/ping")
|
.options("/ping")
|
||||||
|
@ -326,19 +300,17 @@ fn cors_options_ping_check() {
|
||||||
fn cors_get_ping_check() {
|
fn cors_get_ping_check() {
|
||||||
let client = Client::new(rocket()).unwrap();
|
let client = Client::new(rocket()).unwrap();
|
||||||
|
|
||||||
let origin_header =
|
let origin_header = Header::new(ORIGIN.as_str(), "https://www.example.com");
|
||||||
Header::from(hyper::header::Origin::from_str("https://www.example.com").unwrap());
|
|
||||||
|
|
||||||
let req = client.get("/ping").header(origin_header);
|
let req = client.get("/ping").header(origin_header);
|
||||||
|
|
||||||
let mut response = req.dispatch();
|
let response = req.dispatch();
|
||||||
assert!(response.status().class().is_success());
|
assert!(response.status().class().is_success());
|
||||||
let body_str = response.body().and_then(Body::into_string);
|
|
||||||
assert_eq!(body_str, Some("Pong!".to_string()));
|
|
||||||
|
|
||||||
let origin_header = response
|
let origin_header = response
|
||||||
.headers()
|
.headers()
|
||||||
.get_one("Access-Control-Allow-Origin")
|
.get_one("Access-Control-Allow-Origin")
|
||||||
.expect("to exist");
|
.expect("to exist");
|
||||||
assert_eq!("https://www.example.com", origin_header);
|
assert_eq!("https://www.example.com", origin_header);
|
||||||
|
let body_str = response.into_string();
|
||||||
|
assert_eq!(body_str, Some("Pong!".to_string()));
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue