diff --git a/src/fairing.rs b/src/fairing.rs index f2bb393..66703d5 100644 --- a/src/fairing.rs +++ b/src/fairing.rs @@ -88,7 +88,6 @@ impl rocket::fairing::Fairing for Cors { fn on_request(&self, request: &mut Request, _: &rocket::Data) { // Build and merge CORS response - // Type annotation is for sanity check let cors_response = validate(self, request); if let Err(ref err) = cors_response { error_!("CORS Error: {}", err); @@ -105,3 +104,78 @@ impl rocket::fairing::Fairing for Cors { } } } + +#[cfg(test)] +mod tests { + use rocket::Rocket; + use rocket::http::{Method, Status}; + use rocket::local::Client; + + use {Cors, AllOrSome}; + + const CORS_ROOT: &'static str = "/my_cors"; + + fn make_cors_options() -> Cors { + let (allowed_origins, failed_origins) = + AllOrSome::new_from_str_list(&["https://www.acme.com"]); + assert!(failed_origins.is_empty()); + + Cors { + allowed_origins: allowed_origins, + allowed_methods: vec![Method::Get].into_iter().map(From::from).collect(), + allowed_headers: AllOrSome::Some( + ["Authorization"] + .into_iter() + .map(|s| s.to_string().into()) + .collect(), + ), + allow_credentials: true, + fairing_route_base: CORS_ROOT.to_string(), + + ..Default::default() + } + } + + fn rocket(fairing: Cors) -> Rocket { + Rocket::ignite().attach(fairing) + } + + #[test] + fn fairing_error_route_returns_passed_in_status() { + let client = Client::new(rocket(make_cors_options())).expect("to not fail"); + let request = client.get(format!("{}/403", CORS_ROOT)); + let response = request.dispatch(); + assert_eq!(Status::Forbidden, response.status()); + } + + #[test] + fn fairing_error_route_returns_500_for_unknown_status() { + let client = Client::new(rocket(make_cors_options())).expect("to not fail"); + let request = client.get(format!("{}/999", CORS_ROOT)); + let response = request.dispatch(); + assert_eq!(Status::InternalServerError, response.status()); + } + + #[test] + fn error_route_is_mounted_on_attach() { + let rocket = rocket(make_cors_options()); + + let expected_uri = format!("{}/", CORS_ROOT); + let error_route = rocket.routes().find(|r| { + r.method == Method::Get && r.uri.as_str() == expected_uri + }); + assert!(error_route.is_some()); + } + + #[test] + #[should_panic(expected = "launch fairing failure")] + fn options_are_validated_on_attach() { + let mut options = make_cors_options(); + options.allowed_origins = AllOrSome::All; + options.send_wildcard = true; + + let _ = rocket(options).launch(); + } + + // Rest of the things can only be tested in integration tests +} diff --git a/src/headers.rs b/src/headers.rs index b85e730..71dfee7 100644 --- a/src/headers.rs +++ b/src/headers.rs @@ -144,7 +144,7 @@ impl<'a, 'r> FromRequest<'a, 'r> for AccessControlRequestMethod { /// /// You can use this as a rocket [Request Guard](https://rocket.rs/guide/requests/#request-guards) /// to ensure that the header is passed in correctly. -#[derive(Debug)] +#[derive(Eq, PartialEq, Debug)] pub struct AccessControlRequestHeaders(pub HeaderFieldNamesSet); /// Will never fail @@ -184,7 +184,6 @@ impl<'a, 'r> FromRequest<'a, 'r> for AccessControlRequestHeaders { } #[cfg(test)] -#[allow(unmounted_route)] mod tests { use std::str::FromStr; diff --git a/src/lib.rs b/src/lib.rs index 4b294df..6181ffd 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -500,14 +500,17 @@ impl Cors { "/cors".to_string() } - /// Build a CORS `Guard` to an incoming request. + /// Validate a request and then return a CORS Response /// - /// You will usually not have to use this function but simply place a route argument for the - /// `Guard` type. This is useful if you want an even more ad-hoc based approach to respond to + /// You will usually not have to use this function but simply place a r + /// equest guard in the route argument for the `Guard` type. + /// + /// This is useful if you want an even more ad-hoc based approach to respond to /// CORS by using a `Cors` that is not in Rocket's managed state. - pub fn guard<'a, 'r>(&'a self, request: &'a Request<'r>) -> Result, Error> { + #[doc(hidden)] // Need to figure out a way to do this + pub fn validate_request<'a, 'r>(&'a self, request: &'a Request<'r>) -> Result { let response = validate_and_build(self, request)?; - Ok(Guard::new(response)) + Ok(response) } /// Validates if any of the settings are disallowed or incorrect @@ -531,8 +534,11 @@ impl Cors { /// - `Access-Control-Allow-Methods` /// - `Access-Control-Allow-Headers` /// - `Vary` +/// +/// You can get this struct by using `Cors::validate_request` in an ad-hoc manner. +#[doc(hidden)] #[derive(Eq, PartialEq, Debug)] -struct Response { +pub struct Response { allow_origin: Option>, allow_methods: HashSet, allow_headers: HeaderFieldNamesSet, @@ -705,8 +711,8 @@ impl Response { /// A [request guard](https://rocket.rs/guide/requests/#request-guards) to check CORS headers /// before a route is run. Will not execute the route if checks fail /// -// In essence, this is just a wrapper around `Response` with a `'r` borrowed lifetime so users -// don't have to keep specifying the lifetimes in their routes +/// In essence, this is just a wrapper around `Response` with a `'r` borrowed lifetime so users +/// don't have to keep specifying the lifetimes in their routes pub struct Guard<'r> { response: Response, marker: PhantomData<&'r Response>, @@ -799,6 +805,7 @@ impl<'r, R: response::Responder<'r>> response::Responder<'r> for Responder<'r, R /// Result of CORS validation. /// /// The variants hold enough information to build a response to the validation result +#[derive(Debug, Eq, PartialEq)] enum ValidationResult { /// Not a CORS request None, @@ -1122,12 +1129,15 @@ fn actual_request_response(options: &Cors, origin: Origin) -> Response { } #[cfg(test)] -#[allow(unmounted_route)] mod tests { use std::str::FromStr; + + use rocket::local::Client; + use rocket::http::Header; use serde_json; - use http::Method; + use super::*; + use http::Method; fn make_cors_options() -> Cors { let (allowed_origins, failed_origins) = @@ -1141,16 +1151,34 @@ mod tests { .map(From::from) .collect(), allowed_headers: AllOrSome::Some( - ["Authorization"] + ["Authorization", "Accept"] .into_iter() .map(|s| s.to_string().into()) .collect(), ), allow_credentials: true, + expose_headers: ["Content-Type", "X-Custom"] + .into_iter() + .map(|s| s.to_string().into()) + .collect(), ..Default::default() } } + fn make_invalid_options() -> Cors { + let mut cors = make_cors_options(); + cors.allow_credentials = true; + cors.allowed_origins = AllOrSome::All; + cors.send_wildcard = true; + cors + } + + /// Make a client with no routes for unit testing + fn make_client() -> Client { + let rocket = rocket::ignite(); + Client::new(rocket).expect("valid rocket instance") + } + // CORS options test #[test] @@ -1161,10 +1189,7 @@ mod tests { #[test] #[should_panic(expected = "CredentialsWithWildcardOrigin")] fn cors_validates_illegal_allow_credentials() { - let mut cors = make_cors_options(); - cors.allow_credentials = true; - cors.allowed_origins = AllOrSome::All; - cors.send_wildcard = true; + let cors = make_invalid_options(); cors.validate().unwrap(); } @@ -1416,8 +1441,447 @@ mod tests { ); } - // TODO: Preflight tests - // TODO: Actual requests tests + #[test] + fn preflight_validated_correctly() { + let options = make_cors_options(); + let client = make_client(); - // Origin all (wildcard + echoed with Vary). Origin Echo + let origin_header = Header::from( + hyper::header::Origin::from_str("https://www.acme.com").unwrap(), + ); + let method_header = Header::from(hyper::header::AccessControlRequestMethod( + hyper::method::Method::Get, + )); + let request_headers = hyper::header::AccessControlRequestHeaders( + vec![FromStr::from_str("Authorization").unwrap()], + ); + let request_headers = Header::from(request_headers); + + let request = client + .options("/") + .header(origin_header) + .header(method_header) + .header(request_headers); + + let result = validate(&options, request.inner()).expect("to not fail"); + let expected_result = ValidationResult::Preflight { + origin: FromStr::from_str("https://www.acme.com").unwrap(), + // Checks that only a subset of allowed headers are returned + // -- i.e. whatever is requested for + headers: Some(FromStr::from_str("Authorization").unwrap()), + }; + + assert_eq!(expected_result, result); + } + + #[test] + #[should_panic(expected = "CredentialsWithWildcardOrigin")] + fn preflight_validation_errors_on_invalid_options() { + let options = make_invalid_options(); + let client = make_client(); + + let origin_header = Header::from( + hyper::header::Origin::from_str("https://www.acme.com").unwrap(), + ); + let method_header = Header::from(hyper::header::AccessControlRequestMethod( + hyper::method::Method::Get, + )); + let request_headers = hyper::header::AccessControlRequestHeaders( + vec![FromStr::from_str("Authorization").unwrap()], + ); + let request_headers = Header::from(request_headers); + + let request = client + .options("/") + .header(origin_header) + .header(method_header) + .header(request_headers); + + let _ = validate(&options, request.inner()).unwrap(); + } + + #[test] + fn preflight_validation_allows_all_origin() { + let mut options = make_cors_options(); + options.allowed_origins = AllOrSome::All; + let client = make_client(); + + let origin_header = Header::from( + hyper::header::Origin::from_str("https://www.example.com").unwrap(), + ); + let method_header = Header::from(hyper::header::AccessControlRequestMethod( + hyper::method::Method::Get, + )); + let request_headers = hyper::header::AccessControlRequestHeaders( + vec![FromStr::from_str("Authorization").unwrap()], + ); + let request_headers = Header::from(request_headers); + + let request = client + .options("/") + .header(origin_header) + .header(method_header) + .header(request_headers); + + let result = validate(&options, request.inner()).expect("to not fail"); + let expected_result = ValidationResult::Preflight { + origin: FromStr::from_str("https://www.example.com").unwrap(), + headers: Some(FromStr::from_str("Authorization").unwrap()), + }; + + assert_eq!(expected_result, result); + } + + #[test] + #[should_panic(expected = "OriginNotAllowed")] + fn preflight_validation_errors_on_invalid_origin() { + let options = make_cors_options(); + let client = make_client(); + + let origin_header = Header::from( + hyper::header::Origin::from_str("https://www.example.com").unwrap(), + ); + let method_header = Header::from(hyper::header::AccessControlRequestMethod( + hyper::method::Method::Get, + )); + let request_headers = hyper::header::AccessControlRequestHeaders( + vec![FromStr::from_str("Authorization").unwrap()], + ); + let request_headers = Header::from(request_headers); + + let request = client + .options("/") + .header(origin_header) + .header(method_header) + .header(request_headers); + + let _ = validate(&options, request.inner()).unwrap(); + } + + #[test] + #[should_panic(expected = "MissingRequestMethod")] + fn preflight_validation_errors_on_missing_request_method() { + let options = make_cors_options(); + let client = make_client(); + + let origin_header = Header::from( + hyper::header::Origin::from_str("https://www.acme.com").unwrap(), + ); + let request_headers = hyper::header::AccessControlRequestHeaders( + vec![FromStr::from_str("Authorization").unwrap()], + ); + let request_headers = Header::from(request_headers); + + let request = client.options("/").header(origin_header).header( + request_headers, + ); + + let _ = validate(&options, request.inner()).unwrap(); + } + + #[test] + #[should_panic(expected = "MethodNotAllowed")] + fn preflight_validation_errors_on_disallowed_method() { + let options = make_cors_options(); + let client = make_client(); + + let origin_header = Header::from( + hyper::header::Origin::from_str("https://www.acme.com").unwrap(), + ); + let method_header = Header::from(hyper::header::AccessControlRequestMethod( + hyper::method::Method::Post, + )); + let request_headers = hyper::header::AccessControlRequestHeaders( + vec![FromStr::from_str("Authorization").unwrap()], + ); + let request_headers = Header::from(request_headers); + + let request = client + .options("/") + .header(origin_header) + .header(method_header) + .header(request_headers); + + let _ = validate(&options, request.inner()).unwrap(); + } + + #[test] + #[should_panic(expected = "HeadersNotAllowed")] + fn preflight_validation_errors_on_disallowed_headers() { + let options = make_cors_options(); + let client = make_client(); + + let origin_header = Header::from( + hyper::header::Origin::from_str("https://www.acme.com").unwrap(), + ); + let method_header = Header::from(hyper::header::AccessControlRequestMethod( + hyper::method::Method::Get, + )); + let request_headers = hyper::header::AccessControlRequestHeaders(vec![ + FromStr::from_str("Authorization").unwrap(), + FromStr::from_str("X-NOT-ALLOWED").unwrap(), + ]); + let request_headers = Header::from(request_headers); + + let request = client + .options("/") + .header(origin_header) + .header(method_header) + .header(request_headers); + + let _ = validate(&options, request.inner()).unwrap(); + } + + #[test] + fn actual_request_validated_correctly() { + let options = make_cors_options(); + let client = make_client(); + + let origin_header = Header::from( + hyper::header::Origin::from_str("https://www.acme.com").unwrap(), + ); + let request = client.get("/").header(origin_header); + + let result = validate(&options, request.inner()).expect("to not fail"); + let expected_result = ValidationResult::Request { + origin: FromStr::from_str("https://www.acme.com").unwrap(), + }; + + assert_eq!(expected_result, result); + } + + #[test] + #[should_panic(expected = "CredentialsWithWildcardOrigin")] + fn actual_request_validation_errors_on_invalid_options() { + let options = make_invalid_options(); + let client = make_client(); + + let origin_header = Header::from( + hyper::header::Origin::from_str("https://www.acme.com").unwrap(), + ); + let request = client.get("/").header(origin_header); + + let _ = validate(&options, request.inner()).unwrap(); + } + + #[test] + fn actual_request_validation_allows_all_origin() { + let mut options = make_cors_options(); + options.allowed_origins = AllOrSome::All; + let client = make_client(); + + let origin_header = Header::from( + hyper::header::Origin::from_str("https://www.example.com").unwrap(), + ); + let request = client.get("/").header(origin_header); + + let result = validate(&options, request.inner()).expect("to not fail"); + let expected_result = ValidationResult::Request { + origin: FromStr::from_str("https://www.example.com").unwrap(), + }; + + assert_eq!(expected_result, result); + } + + #[test] + #[should_panic(expected = "OriginNotAllowed")] + fn actual_request_validation_errors_on_incorrect_origin() { + let options = make_cors_options(); + let client = make_client(); + + let origin_header = Header::from( + hyper::header::Origin::from_str("https://www.example.com").unwrap(), + ); + let request = client.get("/").header(origin_header); + + let _ = validate(&options, request.inner()).unwrap(); + } + + #[test] + fn non_cors_request_return_empty_response() { + let options = make_cors_options(); + let client = make_client(); + + let request = client.options("/"); + let response = validate_and_build(&options, request.inner()).expect("to not fail"); + let expected_response = Response::new(); + assert_eq!(expected_response, response); + } + + #[test] + fn preflight_validated_and_built_correctly() { + let options = make_cors_options(); + let client = make_client(); + + let origin_header = Header::from( + hyper::header::Origin::from_str("https://www.acme.com").unwrap(), + ); + let method_header = Header::from(hyper::header::AccessControlRequestMethod( + hyper::method::Method::Get, + )); + let request_headers = hyper::header::AccessControlRequestHeaders( + vec![FromStr::from_str("Authorization").unwrap()], + ); + let request_headers = Header::from(request_headers); + + let request = client + .options("/") + .header(origin_header) + .header(method_header) + .header(request_headers); + + let response = validate_and_build(&options, request.inner()).expect("to not fail"); + + let expected_response = Response::new() + .origin("https://www.acme.com/", false) + .headers(&["Authorization"]) + .methods(&options.allowed_methods) + .credentials(options.allow_credentials) + .max_age(options.max_age); + + assert_eq!(expected_response, response); + } + + /// Tests that when All origins are allowed and send_wildcard disabled, the vary header is set + /// in the response and the requested origin is echoed + #[test] + fn preflight_all_origins_with_vary() { + let mut options = make_cors_options(); + options.allowed_origins = AllOrSome::All; + options.send_wildcard = false; + + let client = make_client(); + + let origin_header = Header::from( + hyper::header::Origin::from_str("https://www.acme.com").unwrap(), + ); + let method_header = Header::from(hyper::header::AccessControlRequestMethod( + hyper::method::Method::Get, + )); + let request_headers = hyper::header::AccessControlRequestHeaders( + vec![FromStr::from_str("Authorization").unwrap()], + ); + let request_headers = Header::from(request_headers); + + let request = client + .options("/") + .header(origin_header) + .header(method_header) + .header(request_headers); + + let response = validate_and_build(&options, request.inner()).expect("to not fail"); + + let expected_response = Response::new() + .origin("https://www.acme.com/", true) + .headers(&["Authorization"]) + .methods(&options.allowed_methods) + .credentials(options.allow_credentials) + .max_age(options.max_age); + + assert_eq!(expected_response, response); + } + + /// Tests that when All origins are allowed and send_wildcard enabled, the origin is set to "*" + #[test] + fn preflight_all_origins_with_wildcard() { + let mut options = make_cors_options(); + options.allowed_origins = AllOrSome::All; + options.send_wildcard = true; + options.allow_credentials = false; + + let client = make_client(); + + let origin_header = Header::from( + hyper::header::Origin::from_str("https://www.acme.com").unwrap(), + ); + let method_header = Header::from(hyper::header::AccessControlRequestMethod( + hyper::method::Method::Get, + )); + let request_headers = hyper::header::AccessControlRequestHeaders( + vec![FromStr::from_str("Authorization").unwrap()], + ); + let request_headers = Header::from(request_headers); + + let request = client + .options("/") + .header(origin_header) + .header(method_header) + .header(request_headers); + + let response = validate_and_build(&options, request.inner()).expect("to not fail"); + + let expected_response = Response::new() + .any() + .headers(&["Authorization"]) + .methods(&options.allowed_methods) + .credentials(options.allow_credentials) + .max_age(options.max_age); + + assert_eq!(expected_response, response); + } + + #[test] + fn actual_request_validated_and_built_correctly() { + let options = make_cors_options(); + let client = make_client(); + + let origin_header = Header::from( + hyper::header::Origin::from_str("https://www.acme.com").unwrap(), + ); + let request = client.get("/").header(origin_header); + + let response = validate_and_build(&options, request.inner()).expect("to not fail"); + let expected_response = Response::new() + .origin("https://www.acme.com/", false) + .credentials(options.allow_credentials) + .exposed_headers(&["Content-Type", "X-Custom"]); + + assert_eq!(expected_response, response); + } + + #[test] + fn actual_request_all_origins_with_vary() { + let mut options = make_cors_options(); + options.allowed_origins = AllOrSome::All; + options.send_wildcard = false; + options.allow_credentials = false; + + let client = make_client(); + + let origin_header = Header::from( + hyper::header::Origin::from_str("https://www.acme.com").unwrap(), + ); + let request = client.get("/").header(origin_header); + + let response = validate_and_build(&options, request.inner()).expect("to not fail"); + let expected_response = Response::new() + .origin("https://www.acme.com/", true) + .credentials(options.allow_credentials) + .exposed_headers(&["Content-Type", "X-Custom"]); + + assert_eq!(expected_response, response); + } + + #[test] + fn actual_request_all_origins_with_wildcard() { + let mut options = make_cors_options(); + options.allowed_origins = AllOrSome::All; + options.send_wildcard = true; + options.allow_credentials = false; + + let client = make_client(); + + let origin_header = Header::from( + hyper::header::Origin::from_str("https://www.acme.com").unwrap(), + ); + let request = client.get("/").header(origin_header); + + let response = validate_and_build(&options, request.inner()).expect("to not fail"); + let expected_response = Response::new() + .any() + .credentials(options.allow_credentials) + .exposed_headers(&["Content-Type", "X-Custom"]); + + assert_eq!(expected_response, response); + } } diff --git a/tests/fairing.rs b/tests/fairing.rs index 7518ee1..9d579f0 100644 --- a/tests/fairing.rs +++ b/tests/fairing.rs @@ -31,7 +31,7 @@ fn make_cors_options() -> Cors { allowed_origins: allowed_origins, allowed_methods: vec![Method::Get].into_iter().map(From::from).collect(), allowed_headers: AllOrSome::Some( - ["Authorization"] + ["Authorization", "Accept"] .into_iter() .map(|s| s.to_string().into()) .collect(), diff --git a/tests/ad_hoc.rs b/tests/guard.rs similarity index 98% rename from tests/ad_hoc.rs rename to tests/guard.rs index 22cdfe1..091f09f 100644 --- a/tests/ad_hoc.rs +++ b/tests/guard.rs @@ -1,4 +1,4 @@ -//! This crate tests using rocket_cors using the "classic" ad-hoc per-route handling +//! This crate tests using rocket_cors using the per-route handling with request guard #![feature(plugin, custom_derive)] #![plugin(rocket_codegen)] @@ -68,7 +68,7 @@ fn make_cors_options() -> cors::Cors { allowed_origins: allowed_origins, allowed_methods: vec![Method::Get].into_iter().map(From::from).collect(), allowed_headers: cors::AllOrSome::Some( - ["Authorization"] + ["Authorization", "Accept"] .into_iter() .map(|s| s.to_string().into()) .collect(),