Add remaining missing unit tests (#12)
This commit is contained in:
parent
6389f6d1c6
commit
35e9665628
|
@ -88,7 +88,6 @@ impl rocket::fairing::Fairing for Cors {
|
||||||
|
|
||||||
fn on_request(&self, request: &mut Request, _: &rocket::Data) {
|
fn on_request(&self, request: &mut Request, _: &rocket::Data) {
|
||||||
// Build and merge CORS response
|
// Build and merge CORS response
|
||||||
// Type annotation is for sanity check
|
|
||||||
let cors_response = validate(self, request);
|
let cors_response = validate(self, request);
|
||||||
if let Err(ref err) = cors_response {
|
if let Err(ref err) = cors_response {
|
||||||
error_!("CORS Error: {}", err);
|
error_!("CORS Error: {}", err);
|
||||||
|
@ -105,3 +104,78 @@ impl rocket::fairing::Fairing for Cors {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use rocket::Rocket;
|
||||||
|
use rocket::http::{Method, Status};
|
||||||
|
use rocket::local::Client;
|
||||||
|
|
||||||
|
use {Cors, AllOrSome};
|
||||||
|
|
||||||
|
const CORS_ROOT: &'static str = "/my_cors";
|
||||||
|
|
||||||
|
fn make_cors_options() -> Cors {
|
||||||
|
let (allowed_origins, failed_origins) =
|
||||||
|
AllOrSome::new_from_str_list(&["https://www.acme.com"]);
|
||||||
|
assert!(failed_origins.is_empty());
|
||||||
|
|
||||||
|
Cors {
|
||||||
|
allowed_origins: allowed_origins,
|
||||||
|
allowed_methods: vec![Method::Get].into_iter().map(From::from).collect(),
|
||||||
|
allowed_headers: AllOrSome::Some(
|
||||||
|
["Authorization"]
|
||||||
|
.into_iter()
|
||||||
|
.map(|s| s.to_string().into())
|
||||||
|
.collect(),
|
||||||
|
),
|
||||||
|
allow_credentials: true,
|
||||||
|
fairing_route_base: CORS_ROOT.to_string(),
|
||||||
|
|
||||||
|
..Default::default()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn rocket(fairing: Cors) -> Rocket {
|
||||||
|
Rocket::ignite().attach(fairing)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn fairing_error_route_returns_passed_in_status() {
|
||||||
|
let client = Client::new(rocket(make_cors_options())).expect("to not fail");
|
||||||
|
let request = client.get(format!("{}/403", CORS_ROOT));
|
||||||
|
let response = request.dispatch();
|
||||||
|
assert_eq!(Status::Forbidden, response.status());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn fairing_error_route_returns_500_for_unknown_status() {
|
||||||
|
let client = Client::new(rocket(make_cors_options())).expect("to not fail");
|
||||||
|
let request = client.get(format!("{}/999", CORS_ROOT));
|
||||||
|
let response = request.dispatch();
|
||||||
|
assert_eq!(Status::InternalServerError, response.status());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn error_route_is_mounted_on_attach() {
|
||||||
|
let rocket = rocket(make_cors_options());
|
||||||
|
|
||||||
|
let expected_uri = format!("{}/<status>", CORS_ROOT);
|
||||||
|
let error_route = rocket.routes().find(|r| {
|
||||||
|
r.method == Method::Get && r.uri.as_str() == expected_uri
|
||||||
|
});
|
||||||
|
assert!(error_route.is_some());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
#[should_panic(expected = "launch fairing failure")]
|
||||||
|
fn options_are_validated_on_attach() {
|
||||||
|
let mut options = make_cors_options();
|
||||||
|
options.allowed_origins = AllOrSome::All;
|
||||||
|
options.send_wildcard = true;
|
||||||
|
|
||||||
|
let _ = rocket(options).launch();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Rest of the things can only be tested in integration tests
|
||||||
|
}
|
||||||
|
|
|
@ -144,7 +144,7 @@ impl<'a, 'r> FromRequest<'a, 'r> for AccessControlRequestMethod {
|
||||||
///
|
///
|
||||||
/// You can use this as a rocket [Request Guard](https://rocket.rs/guide/requests/#request-guards)
|
/// You can use this as a rocket [Request Guard](https://rocket.rs/guide/requests/#request-guards)
|
||||||
/// to ensure that the header is passed in correctly.
|
/// to ensure that the header is passed in correctly.
|
||||||
#[derive(Debug)]
|
#[derive(Eq, PartialEq, Debug)]
|
||||||
pub struct AccessControlRequestHeaders(pub HeaderFieldNamesSet);
|
pub struct AccessControlRequestHeaders(pub HeaderFieldNamesSet);
|
||||||
|
|
||||||
/// Will never fail
|
/// Will never fail
|
||||||
|
@ -184,7 +184,6 @@ impl<'a, 'r> FromRequest<'a, 'r> for AccessControlRequestHeaders {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
#[allow(unmounted_route)]
|
|
||||||
mod tests {
|
mod tests {
|
||||||
use std::str::FromStr;
|
use std::str::FromStr;
|
||||||
|
|
||||||
|
|
500
src/lib.rs
500
src/lib.rs
|
@ -500,14 +500,17 @@ impl Cors {
|
||||||
"/cors".to_string()
|
"/cors".to_string()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Build a CORS `Guard` to an incoming request.
|
/// Validate a request and then return a CORS Response
|
||||||
///
|
///
|
||||||
/// You will usually not have to use this function but simply place a route argument for the
|
/// You will usually not have to use this function but simply place a r
|
||||||
/// `Guard` type. This is useful if you want an even more ad-hoc based approach to respond to
|
/// equest guard in the route argument for the `Guard` type.
|
||||||
|
///
|
||||||
|
/// This is useful if you want an even more ad-hoc based approach to respond to
|
||||||
/// CORS by using a `Cors` that is not in Rocket's managed state.
|
/// CORS by using a `Cors` that is not in Rocket's managed state.
|
||||||
pub fn guard<'a, 'r>(&'a self, request: &'a Request<'r>) -> Result<Guard<'r>, Error> {
|
#[doc(hidden)] // Need to figure out a way to do this
|
||||||
|
pub fn validate_request<'a, 'r>(&'a self, request: &'a Request<'r>) -> Result<Response, Error> {
|
||||||
let response = validate_and_build(self, request)?;
|
let response = validate_and_build(self, request)?;
|
||||||
Ok(Guard::new(response))
|
Ok(response)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Validates if any of the settings are disallowed or incorrect
|
/// Validates if any of the settings are disallowed or incorrect
|
||||||
|
@ -531,8 +534,11 @@ impl Cors {
|
||||||
/// - `Access-Control-Allow-Methods`
|
/// - `Access-Control-Allow-Methods`
|
||||||
/// - `Access-Control-Allow-Headers`
|
/// - `Access-Control-Allow-Headers`
|
||||||
/// - `Vary`
|
/// - `Vary`
|
||||||
|
///
|
||||||
|
/// You can get this struct by using `Cors::validate_request` in an ad-hoc manner.
|
||||||
|
#[doc(hidden)]
|
||||||
#[derive(Eq, PartialEq, Debug)]
|
#[derive(Eq, PartialEq, Debug)]
|
||||||
struct Response {
|
pub struct Response {
|
||||||
allow_origin: Option<AllOrSome<String>>,
|
allow_origin: Option<AllOrSome<String>>,
|
||||||
allow_methods: HashSet<Method>,
|
allow_methods: HashSet<Method>,
|
||||||
allow_headers: HeaderFieldNamesSet,
|
allow_headers: HeaderFieldNamesSet,
|
||||||
|
@ -705,8 +711,8 @@ impl Response {
|
||||||
/// A [request guard](https://rocket.rs/guide/requests/#request-guards) to check CORS headers
|
/// A [request guard](https://rocket.rs/guide/requests/#request-guards) to check CORS headers
|
||||||
/// before a route is run. Will not execute the route if checks fail
|
/// before a route is run. Will not execute the route if checks fail
|
||||||
///
|
///
|
||||||
// In essence, this is just a wrapper around `Response` with a `'r` borrowed lifetime so users
|
/// In essence, this is just a wrapper around `Response` with a `'r` borrowed lifetime so users
|
||||||
// don't have to keep specifying the lifetimes in their routes
|
/// don't have to keep specifying the lifetimes in their routes
|
||||||
pub struct Guard<'r> {
|
pub struct Guard<'r> {
|
||||||
response: Response,
|
response: Response,
|
||||||
marker: PhantomData<&'r Response>,
|
marker: PhantomData<&'r Response>,
|
||||||
|
@ -799,6 +805,7 @@ impl<'r, R: response::Responder<'r>> response::Responder<'r> for Responder<'r, R
|
||||||
/// Result of CORS validation.
|
/// Result of CORS validation.
|
||||||
///
|
///
|
||||||
/// The variants hold enough information to build a response to the validation result
|
/// The variants hold enough information to build a response to the validation result
|
||||||
|
#[derive(Debug, Eq, PartialEq)]
|
||||||
enum ValidationResult {
|
enum ValidationResult {
|
||||||
/// Not a CORS request
|
/// Not a CORS request
|
||||||
None,
|
None,
|
||||||
|
@ -1122,12 +1129,15 @@ fn actual_request_response(options: &Cors, origin: Origin) -> Response {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
#[allow(unmounted_route)]
|
|
||||||
mod tests {
|
mod tests {
|
||||||
use std::str::FromStr;
|
use std::str::FromStr;
|
||||||
|
|
||||||
|
use rocket::local::Client;
|
||||||
|
use rocket::http::Header;
|
||||||
use serde_json;
|
use serde_json;
|
||||||
use http::Method;
|
|
||||||
use super::*;
|
use super::*;
|
||||||
|
use http::Method;
|
||||||
|
|
||||||
fn make_cors_options() -> Cors {
|
fn make_cors_options() -> Cors {
|
||||||
let (allowed_origins, failed_origins) =
|
let (allowed_origins, failed_origins) =
|
||||||
|
@ -1141,16 +1151,34 @@ mod tests {
|
||||||
.map(From::from)
|
.map(From::from)
|
||||||
.collect(),
|
.collect(),
|
||||||
allowed_headers: AllOrSome::Some(
|
allowed_headers: AllOrSome::Some(
|
||||||
["Authorization"]
|
["Authorization", "Accept"]
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|s| s.to_string().into())
|
.map(|s| s.to_string().into())
|
||||||
.collect(),
|
.collect(),
|
||||||
),
|
),
|
||||||
allow_credentials: true,
|
allow_credentials: true,
|
||||||
|
expose_headers: ["Content-Type", "X-Custom"]
|
||||||
|
.into_iter()
|
||||||
|
.map(|s| s.to_string().into())
|
||||||
|
.collect(),
|
||||||
..Default::default()
|
..Default::default()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn make_invalid_options() -> Cors {
|
||||||
|
let mut cors = make_cors_options();
|
||||||
|
cors.allow_credentials = true;
|
||||||
|
cors.allowed_origins = AllOrSome::All;
|
||||||
|
cors.send_wildcard = true;
|
||||||
|
cors
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Make a client with no routes for unit testing
|
||||||
|
fn make_client() -> Client {
|
||||||
|
let rocket = rocket::ignite();
|
||||||
|
Client::new(rocket).expect("valid rocket instance")
|
||||||
|
}
|
||||||
|
|
||||||
// CORS options test
|
// CORS options test
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
@ -1161,10 +1189,7 @@ mod tests {
|
||||||
#[test]
|
#[test]
|
||||||
#[should_panic(expected = "CredentialsWithWildcardOrigin")]
|
#[should_panic(expected = "CredentialsWithWildcardOrigin")]
|
||||||
fn cors_validates_illegal_allow_credentials() {
|
fn cors_validates_illegal_allow_credentials() {
|
||||||
let mut cors = make_cors_options();
|
let cors = make_invalid_options();
|
||||||
cors.allow_credentials = true;
|
|
||||||
cors.allowed_origins = AllOrSome::All;
|
|
||||||
cors.send_wildcard = true;
|
|
||||||
|
|
||||||
cors.validate().unwrap();
|
cors.validate().unwrap();
|
||||||
}
|
}
|
||||||
|
@ -1416,8 +1441,447 @@ mod tests {
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: Preflight tests
|
#[test]
|
||||||
// TODO: Actual requests tests
|
fn preflight_validated_correctly() {
|
||||||
|
let options = make_cors_options();
|
||||||
|
let client = make_client();
|
||||||
|
|
||||||
// Origin all (wildcard + echoed with Vary). Origin Echo
|
let origin_header = Header::from(
|
||||||
|
hyper::header::Origin::from_str("https://www.acme.com").unwrap(),
|
||||||
|
);
|
||||||
|
let method_header = Header::from(hyper::header::AccessControlRequestMethod(
|
||||||
|
hyper::method::Method::Get,
|
||||||
|
));
|
||||||
|
let request_headers = hyper::header::AccessControlRequestHeaders(
|
||||||
|
vec![FromStr::from_str("Authorization").unwrap()],
|
||||||
|
);
|
||||||
|
let request_headers = Header::from(request_headers);
|
||||||
|
|
||||||
|
let request = client
|
||||||
|
.options("/")
|
||||||
|
.header(origin_header)
|
||||||
|
.header(method_header)
|
||||||
|
.header(request_headers);
|
||||||
|
|
||||||
|
let result = validate(&options, request.inner()).expect("to not fail");
|
||||||
|
let expected_result = ValidationResult::Preflight {
|
||||||
|
origin: FromStr::from_str("https://www.acme.com").unwrap(),
|
||||||
|
// Checks that only a subset of allowed headers are returned
|
||||||
|
// -- i.e. whatever is requested for
|
||||||
|
headers: Some(FromStr::from_str("Authorization").unwrap()),
|
||||||
|
};
|
||||||
|
|
||||||
|
assert_eq!(expected_result, result);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
#[should_panic(expected = "CredentialsWithWildcardOrigin")]
|
||||||
|
fn preflight_validation_errors_on_invalid_options() {
|
||||||
|
let options = make_invalid_options();
|
||||||
|
let client = make_client();
|
||||||
|
|
||||||
|
let origin_header = Header::from(
|
||||||
|
hyper::header::Origin::from_str("https://www.acme.com").unwrap(),
|
||||||
|
);
|
||||||
|
let method_header = Header::from(hyper::header::AccessControlRequestMethod(
|
||||||
|
hyper::method::Method::Get,
|
||||||
|
));
|
||||||
|
let request_headers = hyper::header::AccessControlRequestHeaders(
|
||||||
|
vec![FromStr::from_str("Authorization").unwrap()],
|
||||||
|
);
|
||||||
|
let request_headers = Header::from(request_headers);
|
||||||
|
|
||||||
|
let request = client
|
||||||
|
.options("/")
|
||||||
|
.header(origin_header)
|
||||||
|
.header(method_header)
|
||||||
|
.header(request_headers);
|
||||||
|
|
||||||
|
let _ = validate(&options, request.inner()).unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn preflight_validation_allows_all_origin() {
|
||||||
|
let mut options = make_cors_options();
|
||||||
|
options.allowed_origins = AllOrSome::All;
|
||||||
|
let client = make_client();
|
||||||
|
|
||||||
|
let origin_header = Header::from(
|
||||||
|
hyper::header::Origin::from_str("https://www.example.com").unwrap(),
|
||||||
|
);
|
||||||
|
let method_header = Header::from(hyper::header::AccessControlRequestMethod(
|
||||||
|
hyper::method::Method::Get,
|
||||||
|
));
|
||||||
|
let request_headers = hyper::header::AccessControlRequestHeaders(
|
||||||
|
vec![FromStr::from_str("Authorization").unwrap()],
|
||||||
|
);
|
||||||
|
let request_headers = Header::from(request_headers);
|
||||||
|
|
||||||
|
let request = client
|
||||||
|
.options("/")
|
||||||
|
.header(origin_header)
|
||||||
|
.header(method_header)
|
||||||
|
.header(request_headers);
|
||||||
|
|
||||||
|
let result = validate(&options, request.inner()).expect("to not fail");
|
||||||
|
let expected_result = ValidationResult::Preflight {
|
||||||
|
origin: FromStr::from_str("https://www.example.com").unwrap(),
|
||||||
|
headers: Some(FromStr::from_str("Authorization").unwrap()),
|
||||||
|
};
|
||||||
|
|
||||||
|
assert_eq!(expected_result, result);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
#[should_panic(expected = "OriginNotAllowed")]
|
||||||
|
fn preflight_validation_errors_on_invalid_origin() {
|
||||||
|
let options = make_cors_options();
|
||||||
|
let client = make_client();
|
||||||
|
|
||||||
|
let origin_header = Header::from(
|
||||||
|
hyper::header::Origin::from_str("https://www.example.com").unwrap(),
|
||||||
|
);
|
||||||
|
let method_header = Header::from(hyper::header::AccessControlRequestMethod(
|
||||||
|
hyper::method::Method::Get,
|
||||||
|
));
|
||||||
|
let request_headers = hyper::header::AccessControlRequestHeaders(
|
||||||
|
vec![FromStr::from_str("Authorization").unwrap()],
|
||||||
|
);
|
||||||
|
let request_headers = Header::from(request_headers);
|
||||||
|
|
||||||
|
let request = client
|
||||||
|
.options("/")
|
||||||
|
.header(origin_header)
|
||||||
|
.header(method_header)
|
||||||
|
.header(request_headers);
|
||||||
|
|
||||||
|
let _ = validate(&options, request.inner()).unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
#[should_panic(expected = "MissingRequestMethod")]
|
||||||
|
fn preflight_validation_errors_on_missing_request_method() {
|
||||||
|
let options = make_cors_options();
|
||||||
|
let client = make_client();
|
||||||
|
|
||||||
|
let origin_header = Header::from(
|
||||||
|
hyper::header::Origin::from_str("https://www.acme.com").unwrap(),
|
||||||
|
);
|
||||||
|
let request_headers = hyper::header::AccessControlRequestHeaders(
|
||||||
|
vec![FromStr::from_str("Authorization").unwrap()],
|
||||||
|
);
|
||||||
|
let request_headers = Header::from(request_headers);
|
||||||
|
|
||||||
|
let request = client.options("/").header(origin_header).header(
|
||||||
|
request_headers,
|
||||||
|
);
|
||||||
|
|
||||||
|
let _ = validate(&options, request.inner()).unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
#[should_panic(expected = "MethodNotAllowed")]
|
||||||
|
fn preflight_validation_errors_on_disallowed_method() {
|
||||||
|
let options = make_cors_options();
|
||||||
|
let client = make_client();
|
||||||
|
|
||||||
|
let origin_header = Header::from(
|
||||||
|
hyper::header::Origin::from_str("https://www.acme.com").unwrap(),
|
||||||
|
);
|
||||||
|
let method_header = Header::from(hyper::header::AccessControlRequestMethod(
|
||||||
|
hyper::method::Method::Post,
|
||||||
|
));
|
||||||
|
let request_headers = hyper::header::AccessControlRequestHeaders(
|
||||||
|
vec![FromStr::from_str("Authorization").unwrap()],
|
||||||
|
);
|
||||||
|
let request_headers = Header::from(request_headers);
|
||||||
|
|
||||||
|
let request = client
|
||||||
|
.options("/")
|
||||||
|
.header(origin_header)
|
||||||
|
.header(method_header)
|
||||||
|
.header(request_headers);
|
||||||
|
|
||||||
|
let _ = validate(&options, request.inner()).unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
#[should_panic(expected = "HeadersNotAllowed")]
|
||||||
|
fn preflight_validation_errors_on_disallowed_headers() {
|
||||||
|
let options = make_cors_options();
|
||||||
|
let client = make_client();
|
||||||
|
|
||||||
|
let origin_header = Header::from(
|
||||||
|
hyper::header::Origin::from_str("https://www.acme.com").unwrap(),
|
||||||
|
);
|
||||||
|
let method_header = Header::from(hyper::header::AccessControlRequestMethod(
|
||||||
|
hyper::method::Method::Get,
|
||||||
|
));
|
||||||
|
let request_headers = hyper::header::AccessControlRequestHeaders(vec![
|
||||||
|
FromStr::from_str("Authorization").unwrap(),
|
||||||
|
FromStr::from_str("X-NOT-ALLOWED").unwrap(),
|
||||||
|
]);
|
||||||
|
let request_headers = Header::from(request_headers);
|
||||||
|
|
||||||
|
let request = client
|
||||||
|
.options("/")
|
||||||
|
.header(origin_header)
|
||||||
|
.header(method_header)
|
||||||
|
.header(request_headers);
|
||||||
|
|
||||||
|
let _ = validate(&options, request.inner()).unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn actual_request_validated_correctly() {
|
||||||
|
let options = make_cors_options();
|
||||||
|
let client = make_client();
|
||||||
|
|
||||||
|
let origin_header = Header::from(
|
||||||
|
hyper::header::Origin::from_str("https://www.acme.com").unwrap(),
|
||||||
|
);
|
||||||
|
let request = client.get("/").header(origin_header);
|
||||||
|
|
||||||
|
let result = validate(&options, request.inner()).expect("to not fail");
|
||||||
|
let expected_result = ValidationResult::Request {
|
||||||
|
origin: FromStr::from_str("https://www.acme.com").unwrap(),
|
||||||
|
};
|
||||||
|
|
||||||
|
assert_eq!(expected_result, result);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
#[should_panic(expected = "CredentialsWithWildcardOrigin")]
|
||||||
|
fn actual_request_validation_errors_on_invalid_options() {
|
||||||
|
let options = make_invalid_options();
|
||||||
|
let client = make_client();
|
||||||
|
|
||||||
|
let origin_header = Header::from(
|
||||||
|
hyper::header::Origin::from_str("https://www.acme.com").unwrap(),
|
||||||
|
);
|
||||||
|
let request = client.get("/").header(origin_header);
|
||||||
|
|
||||||
|
let _ = validate(&options, request.inner()).unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn actual_request_validation_allows_all_origin() {
|
||||||
|
let mut options = make_cors_options();
|
||||||
|
options.allowed_origins = AllOrSome::All;
|
||||||
|
let client = make_client();
|
||||||
|
|
||||||
|
let origin_header = Header::from(
|
||||||
|
hyper::header::Origin::from_str("https://www.example.com").unwrap(),
|
||||||
|
);
|
||||||
|
let request = client.get("/").header(origin_header);
|
||||||
|
|
||||||
|
let result = validate(&options, request.inner()).expect("to not fail");
|
||||||
|
let expected_result = ValidationResult::Request {
|
||||||
|
origin: FromStr::from_str("https://www.example.com").unwrap(),
|
||||||
|
};
|
||||||
|
|
||||||
|
assert_eq!(expected_result, result);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
#[should_panic(expected = "OriginNotAllowed")]
|
||||||
|
fn actual_request_validation_errors_on_incorrect_origin() {
|
||||||
|
let options = make_cors_options();
|
||||||
|
let client = make_client();
|
||||||
|
|
||||||
|
let origin_header = Header::from(
|
||||||
|
hyper::header::Origin::from_str("https://www.example.com").unwrap(),
|
||||||
|
);
|
||||||
|
let request = client.get("/").header(origin_header);
|
||||||
|
|
||||||
|
let _ = validate(&options, request.inner()).unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn non_cors_request_return_empty_response() {
|
||||||
|
let options = make_cors_options();
|
||||||
|
let client = make_client();
|
||||||
|
|
||||||
|
let request = client.options("/");
|
||||||
|
let response = validate_and_build(&options, request.inner()).expect("to not fail");
|
||||||
|
let expected_response = Response::new();
|
||||||
|
assert_eq!(expected_response, response);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn preflight_validated_and_built_correctly() {
|
||||||
|
let options = make_cors_options();
|
||||||
|
let client = make_client();
|
||||||
|
|
||||||
|
let origin_header = Header::from(
|
||||||
|
hyper::header::Origin::from_str("https://www.acme.com").unwrap(),
|
||||||
|
);
|
||||||
|
let method_header = Header::from(hyper::header::AccessControlRequestMethod(
|
||||||
|
hyper::method::Method::Get,
|
||||||
|
));
|
||||||
|
let request_headers = hyper::header::AccessControlRequestHeaders(
|
||||||
|
vec![FromStr::from_str("Authorization").unwrap()],
|
||||||
|
);
|
||||||
|
let request_headers = Header::from(request_headers);
|
||||||
|
|
||||||
|
let request = client
|
||||||
|
.options("/")
|
||||||
|
.header(origin_header)
|
||||||
|
.header(method_header)
|
||||||
|
.header(request_headers);
|
||||||
|
|
||||||
|
let response = validate_and_build(&options, request.inner()).expect("to not fail");
|
||||||
|
|
||||||
|
let expected_response = Response::new()
|
||||||
|
.origin("https://www.acme.com/", false)
|
||||||
|
.headers(&["Authorization"])
|
||||||
|
.methods(&options.allowed_methods)
|
||||||
|
.credentials(options.allow_credentials)
|
||||||
|
.max_age(options.max_age);
|
||||||
|
|
||||||
|
assert_eq!(expected_response, response);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Tests that when All origins are allowed and send_wildcard disabled, the vary header is set
|
||||||
|
/// in the response and the requested origin is echoed
|
||||||
|
#[test]
|
||||||
|
fn preflight_all_origins_with_vary() {
|
||||||
|
let mut options = make_cors_options();
|
||||||
|
options.allowed_origins = AllOrSome::All;
|
||||||
|
options.send_wildcard = false;
|
||||||
|
|
||||||
|
let client = make_client();
|
||||||
|
|
||||||
|
let origin_header = Header::from(
|
||||||
|
hyper::header::Origin::from_str("https://www.acme.com").unwrap(),
|
||||||
|
);
|
||||||
|
let method_header = Header::from(hyper::header::AccessControlRequestMethod(
|
||||||
|
hyper::method::Method::Get,
|
||||||
|
));
|
||||||
|
let request_headers = hyper::header::AccessControlRequestHeaders(
|
||||||
|
vec![FromStr::from_str("Authorization").unwrap()],
|
||||||
|
);
|
||||||
|
let request_headers = Header::from(request_headers);
|
||||||
|
|
||||||
|
let request = client
|
||||||
|
.options("/")
|
||||||
|
.header(origin_header)
|
||||||
|
.header(method_header)
|
||||||
|
.header(request_headers);
|
||||||
|
|
||||||
|
let response = validate_and_build(&options, request.inner()).expect("to not fail");
|
||||||
|
|
||||||
|
let expected_response = Response::new()
|
||||||
|
.origin("https://www.acme.com/", true)
|
||||||
|
.headers(&["Authorization"])
|
||||||
|
.methods(&options.allowed_methods)
|
||||||
|
.credentials(options.allow_credentials)
|
||||||
|
.max_age(options.max_age);
|
||||||
|
|
||||||
|
assert_eq!(expected_response, response);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Tests that when All origins are allowed and send_wildcard enabled, the origin is set to "*"
|
||||||
|
#[test]
|
||||||
|
fn preflight_all_origins_with_wildcard() {
|
||||||
|
let mut options = make_cors_options();
|
||||||
|
options.allowed_origins = AllOrSome::All;
|
||||||
|
options.send_wildcard = true;
|
||||||
|
options.allow_credentials = false;
|
||||||
|
|
||||||
|
let client = make_client();
|
||||||
|
|
||||||
|
let origin_header = Header::from(
|
||||||
|
hyper::header::Origin::from_str("https://www.acme.com").unwrap(),
|
||||||
|
);
|
||||||
|
let method_header = Header::from(hyper::header::AccessControlRequestMethod(
|
||||||
|
hyper::method::Method::Get,
|
||||||
|
));
|
||||||
|
let request_headers = hyper::header::AccessControlRequestHeaders(
|
||||||
|
vec![FromStr::from_str("Authorization").unwrap()],
|
||||||
|
);
|
||||||
|
let request_headers = Header::from(request_headers);
|
||||||
|
|
||||||
|
let request = client
|
||||||
|
.options("/")
|
||||||
|
.header(origin_header)
|
||||||
|
.header(method_header)
|
||||||
|
.header(request_headers);
|
||||||
|
|
||||||
|
let response = validate_and_build(&options, request.inner()).expect("to not fail");
|
||||||
|
|
||||||
|
let expected_response = Response::new()
|
||||||
|
.any()
|
||||||
|
.headers(&["Authorization"])
|
||||||
|
.methods(&options.allowed_methods)
|
||||||
|
.credentials(options.allow_credentials)
|
||||||
|
.max_age(options.max_age);
|
||||||
|
|
||||||
|
assert_eq!(expected_response, response);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn actual_request_validated_and_built_correctly() {
|
||||||
|
let options = make_cors_options();
|
||||||
|
let client = make_client();
|
||||||
|
|
||||||
|
let origin_header = Header::from(
|
||||||
|
hyper::header::Origin::from_str("https://www.acme.com").unwrap(),
|
||||||
|
);
|
||||||
|
let request = client.get("/").header(origin_header);
|
||||||
|
|
||||||
|
let response = validate_and_build(&options, request.inner()).expect("to not fail");
|
||||||
|
let expected_response = Response::new()
|
||||||
|
.origin("https://www.acme.com/", false)
|
||||||
|
.credentials(options.allow_credentials)
|
||||||
|
.exposed_headers(&["Content-Type", "X-Custom"]);
|
||||||
|
|
||||||
|
assert_eq!(expected_response, response);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn actual_request_all_origins_with_vary() {
|
||||||
|
let mut options = make_cors_options();
|
||||||
|
options.allowed_origins = AllOrSome::All;
|
||||||
|
options.send_wildcard = false;
|
||||||
|
options.allow_credentials = false;
|
||||||
|
|
||||||
|
let client = make_client();
|
||||||
|
|
||||||
|
let origin_header = Header::from(
|
||||||
|
hyper::header::Origin::from_str("https://www.acme.com").unwrap(),
|
||||||
|
);
|
||||||
|
let request = client.get("/").header(origin_header);
|
||||||
|
|
||||||
|
let response = validate_and_build(&options, request.inner()).expect("to not fail");
|
||||||
|
let expected_response = Response::new()
|
||||||
|
.origin("https://www.acme.com/", true)
|
||||||
|
.credentials(options.allow_credentials)
|
||||||
|
.exposed_headers(&["Content-Type", "X-Custom"]);
|
||||||
|
|
||||||
|
assert_eq!(expected_response, response);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn actual_request_all_origins_with_wildcard() {
|
||||||
|
let mut options = make_cors_options();
|
||||||
|
options.allowed_origins = AllOrSome::All;
|
||||||
|
options.send_wildcard = true;
|
||||||
|
options.allow_credentials = false;
|
||||||
|
|
||||||
|
let client = make_client();
|
||||||
|
|
||||||
|
let origin_header = Header::from(
|
||||||
|
hyper::header::Origin::from_str("https://www.acme.com").unwrap(),
|
||||||
|
);
|
||||||
|
let request = client.get("/").header(origin_header);
|
||||||
|
|
||||||
|
let response = validate_and_build(&options, request.inner()).expect("to not fail");
|
||||||
|
let expected_response = Response::new()
|
||||||
|
.any()
|
||||||
|
.credentials(options.allow_credentials)
|
||||||
|
.exposed_headers(&["Content-Type", "X-Custom"]);
|
||||||
|
|
||||||
|
assert_eq!(expected_response, response);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -31,7 +31,7 @@ fn make_cors_options() -> Cors {
|
||||||
allowed_origins: allowed_origins,
|
allowed_origins: allowed_origins,
|
||||||
allowed_methods: vec![Method::Get].into_iter().map(From::from).collect(),
|
allowed_methods: vec![Method::Get].into_iter().map(From::from).collect(),
|
||||||
allowed_headers: AllOrSome::Some(
|
allowed_headers: AllOrSome::Some(
|
||||||
["Authorization"]
|
["Authorization", "Accept"]
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|s| s.to_string().into())
|
.map(|s| s.to_string().into())
|
||||||
.collect(),
|
.collect(),
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
//! This crate tests using rocket_cors using the "classic" ad-hoc per-route handling
|
//! This crate tests using rocket_cors using the per-route handling with request guard
|
||||||
|
|
||||||
#![feature(plugin, custom_derive)]
|
#![feature(plugin, custom_derive)]
|
||||||
#![plugin(rocket_codegen)]
|
#![plugin(rocket_codegen)]
|
||||||
|
@ -68,7 +68,7 @@ fn make_cors_options() -> cors::Cors {
|
||||||
allowed_origins: allowed_origins,
|
allowed_origins: allowed_origins,
|
||||||
allowed_methods: vec![Method::Get].into_iter().map(From::from).collect(),
|
allowed_methods: vec![Method::Get].into_iter().map(From::from).collect(),
|
||||||
allowed_headers: cors::AllOrSome::Some(
|
allowed_headers: cors::AllOrSome::Some(
|
||||||
["Authorization"]
|
["Authorization", "Accept"]
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|s| s.to_string().into())
|
.map(|s| s.to_string().into())
|
||||||
.collect(),
|
.collect(),
|
Loading…
Reference in New Issue