Add remaining missing unit tests (#12)

This commit is contained in:
Yong Wen Chua 2017-07-18 13:11:30 +08:00 committed by GitHub
parent 6389f6d1c6
commit 35e9665628
5 changed files with 561 additions and 24 deletions

View File

@ -88,7 +88,6 @@ impl rocket::fairing::Fairing for Cors {
fn on_request(&self, request: &mut Request, _: &rocket::Data) {
// Build and merge CORS response
// Type annotation is for sanity check
let cors_response = validate(self, request);
if let Err(ref err) = cors_response {
error_!("CORS Error: {}", err);
@ -105,3 +104,78 @@ impl rocket::fairing::Fairing for Cors {
}
}
}
#[cfg(test)]
mod tests {
use rocket::Rocket;
use rocket::http::{Method, Status};
use rocket::local::Client;
use {Cors, AllOrSome};
const CORS_ROOT: &'static str = "/my_cors";
fn make_cors_options() -> Cors {
let (allowed_origins, failed_origins) =
AllOrSome::new_from_str_list(&["https://www.acme.com"]);
assert!(failed_origins.is_empty());
Cors {
allowed_origins: allowed_origins,
allowed_methods: vec![Method::Get].into_iter().map(From::from).collect(),
allowed_headers: AllOrSome::Some(
["Authorization"]
.into_iter()
.map(|s| s.to_string().into())
.collect(),
),
allow_credentials: true,
fairing_route_base: CORS_ROOT.to_string(),
..Default::default()
}
}
fn rocket(fairing: Cors) -> Rocket {
Rocket::ignite().attach(fairing)
}
#[test]
fn fairing_error_route_returns_passed_in_status() {
let client = Client::new(rocket(make_cors_options())).expect("to not fail");
let request = client.get(format!("{}/403", CORS_ROOT));
let response = request.dispatch();
assert_eq!(Status::Forbidden, response.status());
}
#[test]
fn fairing_error_route_returns_500_for_unknown_status() {
let client = Client::new(rocket(make_cors_options())).expect("to not fail");
let request = client.get(format!("{}/999", CORS_ROOT));
let response = request.dispatch();
assert_eq!(Status::InternalServerError, response.status());
}
#[test]
fn error_route_is_mounted_on_attach() {
let rocket = rocket(make_cors_options());
let expected_uri = format!("{}/<status>", CORS_ROOT);
let error_route = rocket.routes().find(|r| {
r.method == Method::Get && r.uri.as_str() == expected_uri
});
assert!(error_route.is_some());
}
#[test]
#[should_panic(expected = "launch fairing failure")]
fn options_are_validated_on_attach() {
let mut options = make_cors_options();
options.allowed_origins = AllOrSome::All;
options.send_wildcard = true;
let _ = rocket(options).launch();
}
// Rest of the things can only be tested in integration tests
}

View File

@ -144,7 +144,7 @@ impl<'a, 'r> FromRequest<'a, 'r> for AccessControlRequestMethod {
///
/// You can use this as a rocket [Request Guard](https://rocket.rs/guide/requests/#request-guards)
/// to ensure that the header is passed in correctly.
#[derive(Debug)]
#[derive(Eq, PartialEq, Debug)]
pub struct AccessControlRequestHeaders(pub HeaderFieldNamesSet);
/// Will never fail
@ -184,7 +184,6 @@ impl<'a, 'r> FromRequest<'a, 'r> for AccessControlRequestHeaders {
}
#[cfg(test)]
#[allow(unmounted_route)]
mod tests {
use std::str::FromStr;

View File

@ -500,14 +500,17 @@ impl Cors {
"/cors".to_string()
}
/// Build a CORS `Guard` to an incoming request.
/// Validate a request and then return a CORS Response
///
/// You will usually not have to use this function but simply place a route argument for the
/// `Guard` type. This is useful if you want an even more ad-hoc based approach to respond to
/// You will usually not have to use this function but simply place a r
/// equest guard in the route argument for the `Guard` type.
///
/// This is useful if you want an even more ad-hoc based approach to respond to
/// CORS by using a `Cors` that is not in Rocket's managed state.
pub fn guard<'a, 'r>(&'a self, request: &'a Request<'r>) -> Result<Guard<'r>, Error> {
#[doc(hidden)] // Need to figure out a way to do this
pub fn validate_request<'a, 'r>(&'a self, request: &'a Request<'r>) -> Result<Response, Error> {
let response = validate_and_build(self, request)?;
Ok(Guard::new(response))
Ok(response)
}
/// Validates if any of the settings are disallowed or incorrect
@ -531,8 +534,11 @@ impl Cors {
/// - `Access-Control-Allow-Methods`
/// - `Access-Control-Allow-Headers`
/// - `Vary`
///
/// You can get this struct by using `Cors::validate_request` in an ad-hoc manner.
#[doc(hidden)]
#[derive(Eq, PartialEq, Debug)]
struct Response {
pub struct Response {
allow_origin: Option<AllOrSome<String>>,
allow_methods: HashSet<Method>,
allow_headers: HeaderFieldNamesSet,
@ -705,8 +711,8 @@ impl Response {
/// A [request guard](https://rocket.rs/guide/requests/#request-guards) to check CORS headers
/// before a route is run. Will not execute the route if checks fail
///
// In essence, this is just a wrapper around `Response` with a `'r` borrowed lifetime so users
// don't have to keep specifying the lifetimes in their routes
/// In essence, this is just a wrapper around `Response` with a `'r` borrowed lifetime so users
/// don't have to keep specifying the lifetimes in their routes
pub struct Guard<'r> {
response: Response,
marker: PhantomData<&'r Response>,
@ -799,6 +805,7 @@ impl<'r, R: response::Responder<'r>> response::Responder<'r> for Responder<'r, R
/// Result of CORS validation.
///
/// The variants hold enough information to build a response to the validation result
#[derive(Debug, Eq, PartialEq)]
enum ValidationResult {
/// Not a CORS request
None,
@ -1122,12 +1129,15 @@ fn actual_request_response(options: &Cors, origin: Origin) -> Response {
}
#[cfg(test)]
#[allow(unmounted_route)]
mod tests {
use std::str::FromStr;
use rocket::local::Client;
use rocket::http::Header;
use serde_json;
use http::Method;
use super::*;
use http::Method;
fn make_cors_options() -> Cors {
let (allowed_origins, failed_origins) =
@ -1141,16 +1151,34 @@ mod tests {
.map(From::from)
.collect(),
allowed_headers: AllOrSome::Some(
["Authorization"]
["Authorization", "Accept"]
.into_iter()
.map(|s| s.to_string().into())
.collect(),
),
allow_credentials: true,
expose_headers: ["Content-Type", "X-Custom"]
.into_iter()
.map(|s| s.to_string().into())
.collect(),
..Default::default()
}
}
fn make_invalid_options() -> Cors {
let mut cors = make_cors_options();
cors.allow_credentials = true;
cors.allowed_origins = AllOrSome::All;
cors.send_wildcard = true;
cors
}
/// Make a client with no routes for unit testing
fn make_client() -> Client {
let rocket = rocket::ignite();
Client::new(rocket).expect("valid rocket instance")
}
// CORS options test
#[test]
@ -1161,10 +1189,7 @@ mod tests {
#[test]
#[should_panic(expected = "CredentialsWithWildcardOrigin")]
fn cors_validates_illegal_allow_credentials() {
let mut cors = make_cors_options();
cors.allow_credentials = true;
cors.allowed_origins = AllOrSome::All;
cors.send_wildcard = true;
let cors = make_invalid_options();
cors.validate().unwrap();
}
@ -1416,8 +1441,447 @@ mod tests {
);
}
// TODO: Preflight tests
// TODO: Actual requests tests
#[test]
fn preflight_validated_correctly() {
let options = make_cors_options();
let client = make_client();
// Origin all (wildcard + echoed with Vary). Origin Echo
let origin_header = Header::from(
hyper::header::Origin::from_str("https://www.acme.com").unwrap(),
);
let method_header = Header::from(hyper::header::AccessControlRequestMethod(
hyper::method::Method::Get,
));
let request_headers = hyper::header::AccessControlRequestHeaders(
vec![FromStr::from_str("Authorization").unwrap()],
);
let request_headers = Header::from(request_headers);
let request = client
.options("/")
.header(origin_header)
.header(method_header)
.header(request_headers);
let result = validate(&options, request.inner()).expect("to not fail");
let expected_result = ValidationResult::Preflight {
origin: FromStr::from_str("https://www.acme.com").unwrap(),
// Checks that only a subset of allowed headers are returned
// -- i.e. whatever is requested for
headers: Some(FromStr::from_str("Authorization").unwrap()),
};
assert_eq!(expected_result, result);
}
#[test]
#[should_panic(expected = "CredentialsWithWildcardOrigin")]
fn preflight_validation_errors_on_invalid_options() {
let options = make_invalid_options();
let client = make_client();
let origin_header = Header::from(
hyper::header::Origin::from_str("https://www.acme.com").unwrap(),
);
let method_header = Header::from(hyper::header::AccessControlRequestMethod(
hyper::method::Method::Get,
));
let request_headers = hyper::header::AccessControlRequestHeaders(
vec![FromStr::from_str("Authorization").unwrap()],
);
let request_headers = Header::from(request_headers);
let request = client
.options("/")
.header(origin_header)
.header(method_header)
.header(request_headers);
let _ = validate(&options, request.inner()).unwrap();
}
#[test]
fn preflight_validation_allows_all_origin() {
let mut options = make_cors_options();
options.allowed_origins = AllOrSome::All;
let client = make_client();
let origin_header = Header::from(
hyper::header::Origin::from_str("https://www.example.com").unwrap(),
);
let method_header = Header::from(hyper::header::AccessControlRequestMethod(
hyper::method::Method::Get,
));
let request_headers = hyper::header::AccessControlRequestHeaders(
vec![FromStr::from_str("Authorization").unwrap()],
);
let request_headers = Header::from(request_headers);
let request = client
.options("/")
.header(origin_header)
.header(method_header)
.header(request_headers);
let result = validate(&options, request.inner()).expect("to not fail");
let expected_result = ValidationResult::Preflight {
origin: FromStr::from_str("https://www.example.com").unwrap(),
headers: Some(FromStr::from_str("Authorization").unwrap()),
};
assert_eq!(expected_result, result);
}
#[test]
#[should_panic(expected = "OriginNotAllowed")]
fn preflight_validation_errors_on_invalid_origin() {
let options = make_cors_options();
let client = make_client();
let origin_header = Header::from(
hyper::header::Origin::from_str("https://www.example.com").unwrap(),
);
let method_header = Header::from(hyper::header::AccessControlRequestMethod(
hyper::method::Method::Get,
));
let request_headers = hyper::header::AccessControlRequestHeaders(
vec![FromStr::from_str("Authorization").unwrap()],
);
let request_headers = Header::from(request_headers);
let request = client
.options("/")
.header(origin_header)
.header(method_header)
.header(request_headers);
let _ = validate(&options, request.inner()).unwrap();
}
#[test]
#[should_panic(expected = "MissingRequestMethod")]
fn preflight_validation_errors_on_missing_request_method() {
let options = make_cors_options();
let client = make_client();
let origin_header = Header::from(
hyper::header::Origin::from_str("https://www.acme.com").unwrap(),
);
let request_headers = hyper::header::AccessControlRequestHeaders(
vec![FromStr::from_str("Authorization").unwrap()],
);
let request_headers = Header::from(request_headers);
let request = client.options("/").header(origin_header).header(
request_headers,
);
let _ = validate(&options, request.inner()).unwrap();
}
#[test]
#[should_panic(expected = "MethodNotAllowed")]
fn preflight_validation_errors_on_disallowed_method() {
let options = make_cors_options();
let client = make_client();
let origin_header = Header::from(
hyper::header::Origin::from_str("https://www.acme.com").unwrap(),
);
let method_header = Header::from(hyper::header::AccessControlRequestMethod(
hyper::method::Method::Post,
));
let request_headers = hyper::header::AccessControlRequestHeaders(
vec![FromStr::from_str("Authorization").unwrap()],
);
let request_headers = Header::from(request_headers);
let request = client
.options("/")
.header(origin_header)
.header(method_header)
.header(request_headers);
let _ = validate(&options, request.inner()).unwrap();
}
#[test]
#[should_panic(expected = "HeadersNotAllowed")]
fn preflight_validation_errors_on_disallowed_headers() {
let options = make_cors_options();
let client = make_client();
let origin_header = Header::from(
hyper::header::Origin::from_str("https://www.acme.com").unwrap(),
);
let method_header = Header::from(hyper::header::AccessControlRequestMethod(
hyper::method::Method::Get,
));
let request_headers = hyper::header::AccessControlRequestHeaders(vec![
FromStr::from_str("Authorization").unwrap(),
FromStr::from_str("X-NOT-ALLOWED").unwrap(),
]);
let request_headers = Header::from(request_headers);
let request = client
.options("/")
.header(origin_header)
.header(method_header)
.header(request_headers);
let _ = validate(&options, request.inner()).unwrap();
}
#[test]
fn actual_request_validated_correctly() {
let options = make_cors_options();
let client = make_client();
let origin_header = Header::from(
hyper::header::Origin::from_str("https://www.acme.com").unwrap(),
);
let request = client.get("/").header(origin_header);
let result = validate(&options, request.inner()).expect("to not fail");
let expected_result = ValidationResult::Request {
origin: FromStr::from_str("https://www.acme.com").unwrap(),
};
assert_eq!(expected_result, result);
}
#[test]
#[should_panic(expected = "CredentialsWithWildcardOrigin")]
fn actual_request_validation_errors_on_invalid_options() {
let options = make_invalid_options();
let client = make_client();
let origin_header = Header::from(
hyper::header::Origin::from_str("https://www.acme.com").unwrap(),
);
let request = client.get("/").header(origin_header);
let _ = validate(&options, request.inner()).unwrap();
}
#[test]
fn actual_request_validation_allows_all_origin() {
let mut options = make_cors_options();
options.allowed_origins = AllOrSome::All;
let client = make_client();
let origin_header = Header::from(
hyper::header::Origin::from_str("https://www.example.com").unwrap(),
);
let request = client.get("/").header(origin_header);
let result = validate(&options, request.inner()).expect("to not fail");
let expected_result = ValidationResult::Request {
origin: FromStr::from_str("https://www.example.com").unwrap(),
};
assert_eq!(expected_result, result);
}
#[test]
#[should_panic(expected = "OriginNotAllowed")]
fn actual_request_validation_errors_on_incorrect_origin() {
let options = make_cors_options();
let client = make_client();
let origin_header = Header::from(
hyper::header::Origin::from_str("https://www.example.com").unwrap(),
);
let request = client.get("/").header(origin_header);
let _ = validate(&options, request.inner()).unwrap();
}
#[test]
fn non_cors_request_return_empty_response() {
let options = make_cors_options();
let client = make_client();
let request = client.options("/");
let response = validate_and_build(&options, request.inner()).expect("to not fail");
let expected_response = Response::new();
assert_eq!(expected_response, response);
}
#[test]
fn preflight_validated_and_built_correctly() {
let options = make_cors_options();
let client = make_client();
let origin_header = Header::from(
hyper::header::Origin::from_str("https://www.acme.com").unwrap(),
);
let method_header = Header::from(hyper::header::AccessControlRequestMethod(
hyper::method::Method::Get,
));
let request_headers = hyper::header::AccessControlRequestHeaders(
vec![FromStr::from_str("Authorization").unwrap()],
);
let request_headers = Header::from(request_headers);
let request = client
.options("/")
.header(origin_header)
.header(method_header)
.header(request_headers);
let response = validate_and_build(&options, request.inner()).expect("to not fail");
let expected_response = Response::new()
.origin("https://www.acme.com/", false)
.headers(&["Authorization"])
.methods(&options.allowed_methods)
.credentials(options.allow_credentials)
.max_age(options.max_age);
assert_eq!(expected_response, response);
}
/// Tests that when All origins are allowed and send_wildcard disabled, the vary header is set
/// in the response and the requested origin is echoed
#[test]
fn preflight_all_origins_with_vary() {
let mut options = make_cors_options();
options.allowed_origins = AllOrSome::All;
options.send_wildcard = false;
let client = make_client();
let origin_header = Header::from(
hyper::header::Origin::from_str("https://www.acme.com").unwrap(),
);
let method_header = Header::from(hyper::header::AccessControlRequestMethod(
hyper::method::Method::Get,
));
let request_headers = hyper::header::AccessControlRequestHeaders(
vec![FromStr::from_str("Authorization").unwrap()],
);
let request_headers = Header::from(request_headers);
let request = client
.options("/")
.header(origin_header)
.header(method_header)
.header(request_headers);
let response = validate_and_build(&options, request.inner()).expect("to not fail");
let expected_response = Response::new()
.origin("https://www.acme.com/", true)
.headers(&["Authorization"])
.methods(&options.allowed_methods)
.credentials(options.allow_credentials)
.max_age(options.max_age);
assert_eq!(expected_response, response);
}
/// Tests that when All origins are allowed and send_wildcard enabled, the origin is set to "*"
#[test]
fn preflight_all_origins_with_wildcard() {
let mut options = make_cors_options();
options.allowed_origins = AllOrSome::All;
options.send_wildcard = true;
options.allow_credentials = false;
let client = make_client();
let origin_header = Header::from(
hyper::header::Origin::from_str("https://www.acme.com").unwrap(),
);
let method_header = Header::from(hyper::header::AccessControlRequestMethod(
hyper::method::Method::Get,
));
let request_headers = hyper::header::AccessControlRequestHeaders(
vec![FromStr::from_str("Authorization").unwrap()],
);
let request_headers = Header::from(request_headers);
let request = client
.options("/")
.header(origin_header)
.header(method_header)
.header(request_headers);
let response = validate_and_build(&options, request.inner()).expect("to not fail");
let expected_response = Response::new()
.any()
.headers(&["Authorization"])
.methods(&options.allowed_methods)
.credentials(options.allow_credentials)
.max_age(options.max_age);
assert_eq!(expected_response, response);
}
#[test]
fn actual_request_validated_and_built_correctly() {
let options = make_cors_options();
let client = make_client();
let origin_header = Header::from(
hyper::header::Origin::from_str("https://www.acme.com").unwrap(),
);
let request = client.get("/").header(origin_header);
let response = validate_and_build(&options, request.inner()).expect("to not fail");
let expected_response = Response::new()
.origin("https://www.acme.com/", false)
.credentials(options.allow_credentials)
.exposed_headers(&["Content-Type", "X-Custom"]);
assert_eq!(expected_response, response);
}
#[test]
fn actual_request_all_origins_with_vary() {
let mut options = make_cors_options();
options.allowed_origins = AllOrSome::All;
options.send_wildcard = false;
options.allow_credentials = false;
let client = make_client();
let origin_header = Header::from(
hyper::header::Origin::from_str("https://www.acme.com").unwrap(),
);
let request = client.get("/").header(origin_header);
let response = validate_and_build(&options, request.inner()).expect("to not fail");
let expected_response = Response::new()
.origin("https://www.acme.com/", true)
.credentials(options.allow_credentials)
.exposed_headers(&["Content-Type", "X-Custom"]);
assert_eq!(expected_response, response);
}
#[test]
fn actual_request_all_origins_with_wildcard() {
let mut options = make_cors_options();
options.allowed_origins = AllOrSome::All;
options.send_wildcard = true;
options.allow_credentials = false;
let client = make_client();
let origin_header = Header::from(
hyper::header::Origin::from_str("https://www.acme.com").unwrap(),
);
let request = client.get("/").header(origin_header);
let response = validate_and_build(&options, request.inner()).expect("to not fail");
let expected_response = Response::new()
.any()
.credentials(options.allow_credentials)
.exposed_headers(&["Content-Type", "X-Custom"]);
assert_eq!(expected_response, response);
}
}

View File

@ -31,7 +31,7 @@ fn make_cors_options() -> Cors {
allowed_origins: allowed_origins,
allowed_methods: vec![Method::Get].into_iter().map(From::from).collect(),
allowed_headers: AllOrSome::Some(
["Authorization"]
["Authorization", "Accept"]
.into_iter()
.map(|s| s.to_string().into())
.collect(),

View File

@ -1,4 +1,4 @@
//! This crate tests using rocket_cors using the "classic" ad-hoc per-route handling
//! This crate tests using rocket_cors using the per-route handling with request guard
#![feature(plugin, custom_derive)]
#![plugin(rocket_codegen)]
@ -68,7 +68,7 @@ fn make_cors_options() -> cors::Cors {
allowed_origins: allowed_origins,
allowed_methods: vec![Method::Get].into_iter().map(From::from).collect(),
allowed_headers: cors::AllOrSome::Some(
["Authorization"]
["Authorization", "Accept"]
.into_iter()
.map(|s| s.to_string().into())
.collect(),