diff --git a/tests/ad_hoc.rs b/tests/ad_hoc.rs index ac49944..47aeccf 100644 --- a/tests/ad_hoc.rs +++ b/tests/ad_hoc.rs @@ -23,6 +23,11 @@ fn cors(cors: cors::Guard) -> cors::Responder<&str> { cors.responder("Hello CORS") } +#[get("/panic")] +fn panicking_route(_cors: cors::Guard) { + panic!("This route will panic"); +} + // The following routes tests that the routes can be compiled with ad-hoc CORS Response/Responders /// Using a `Response` instead of a `Responder` @@ -73,26 +78,15 @@ fn make_cors_options() -> cors::Cors { } } +fn make_rocket() -> rocket::Rocket { + rocket::ignite() + .mount("/", routes![cors, cors_options, panicking_route]) + .manage(make_cors_options()) +} + #[test] fn smoke_test() { - let (allowed_origins, failed_origins) = - cors::AllOrSome::new_from_str_list(&["https://www.acme.com"]); - assert!(failed_origins.is_empty()); - let cors_options = cors::Cors { - allowed_origins: allowed_origins, - allowed_methods: [Method::Get].iter().cloned().collect(), - allowed_headers: cors::AllOrSome::Some( - ["Authorization"] - .iter() - .map(|s| s.to_string().into()) - .collect(), - ), - allow_credentials: true, - ..Default::default() - }; - let rocket = rocket::ignite() - .mount("/", routes![cors, cors_options]) - .manage(cors_options); + let rocket = make_rocket(); let client = Client::new(rocket).unwrap(); // `Options` pre-flight checks @@ -131,9 +125,7 @@ fn smoke_test() { #[test] fn cors_options_check() { - let rocket = rocket::ignite() - .mount("/", routes![cors, cors_options]) - .manage(make_cors_options()); + let rocket = make_rocket(); let client = Client::new(rocket).unwrap(); let origin_header = Header::from( @@ -158,9 +150,7 @@ fn cors_options_check() { #[test] fn cors_get_check() { - let rocket = rocket::ignite() - .mount("/", routes![cors, cors_options]) - .manage(make_cors_options()); + let rocket = make_rocket(); let client = Client::new(rocket).unwrap(); let origin_header = Header::from( @@ -179,9 +169,7 @@ fn cors_get_check() { /// This test is to check that non CORS compliant requests to GET should still work. (i.e. curl) #[test] fn cors_get_no_origin() { - let rocket = rocket::ignite() - .mount("/", routes![cors, cors_options]) - .manage(make_cors_options()); + let rocket = make_rocket(); let client = Client::new(rocket).unwrap(); let authorization = Header::new("Authorization", "let me in"); @@ -195,9 +183,7 @@ fn cors_get_no_origin() { #[test] fn cors_options_bad_origin() { - let rocket = rocket::ignite() - .mount("/", routes![cors, cors_options]) - .manage(make_cors_options()); + let rocket = make_rocket(); let client = Client::new(rocket).unwrap(); let origin_header = Header::from( @@ -222,9 +208,7 @@ fn cors_options_bad_origin() { #[test] fn cors_options_missing_origin() { - let rocket = rocket::ignite() - .mount("/", routes![cors, cors_options]) - .manage(make_cors_options()); + let rocket = make_rocket(); let client = Client::new(rocket).unwrap(); let method_header = Header::from(hyper::header::AccessControlRequestMethod( @@ -244,9 +228,7 @@ fn cors_options_missing_origin() { #[test] fn cors_options_bad_request_method() { - let rocket = rocket::ignite() - .mount("/", routes![cors, cors_options]) - .manage(make_cors_options()); + let rocket = make_rocket(); let client = Client::new(rocket).unwrap(); let origin_header = Header::from( @@ -271,9 +253,7 @@ fn cors_options_bad_request_method() { #[test] fn cors_options_bad_request_header() { - let rocket = rocket::ignite() - .mount("/", routes![cors, cors_options]) - .manage(make_cors_options()); + let rocket = make_rocket(); let client = Client::new(rocket).unwrap(); let origin_header = Header::from( @@ -297,9 +277,25 @@ fn cors_options_bad_request_header() { #[test] fn cors_get_bad_origin() { - let rocket = rocket::ignite() - .mount("/", routes![cors, cors_options]) - .manage(make_cors_options()); + let rocket = make_rocket(); + let client = Client::new(rocket).unwrap(); + + let origin_header = Header::from( + hyper::header::Origin::from_str("https://www.bad-origin.com").unwrap(), + ); + let authorization = Header::new("Authorization", "let me in"); + let req = client.get("/").header(origin_header).header(authorization); + + let response = req.dispatch(); + assert_eq!(response.status(), Status::Forbidden); +} + +/// This test ensures that on a failing CORS request, the route (along with its side effects) +/// should never be executed. +/// The route used will panic if executed +#[test] +fn routes_failing_checks_are_not_executed() { + let rocket = make_rocket(); let client = Client::new(rocket).unwrap(); let origin_header = Header::from( diff --git a/tests/fairings.rs b/tests/fairings.rs index 4fdfd68..641f4ef 100644 --- a/tests/fairings.rs +++ b/tests/fairings.rs @@ -18,6 +18,11 @@ fn cors<'a>() -> &'a str { "Hello CORS" } +#[get("/panic")] +fn panicking_route() { + panic!("This route will panic"); +} + fn make_cors_options() -> Cors { let (allowed_origins, failed_origins) = AllOrSome::new_from_str_list(&["https://www.acme.com"]); assert!(failed_origins.is_empty()); @@ -37,7 +42,7 @@ fn make_cors_options() -> Cors { } fn rocket() -> rocket::Rocket { - rocket::ignite().mount("/", routes![cors]).attach( + rocket::ignite().mount("/", routes![cors, panicking_route]).attach( make_cors_options(), ) } @@ -238,3 +243,30 @@ fn cors_get_bad_origin() { let response = req.dispatch(); assert_eq!(response.status(), Status::Forbidden); } + +/// This test ensures that on a failing CORS request, the route (along with its side effects) +/// should never be executed. +/// The route used will panic if executed +#[test] +fn routes_failing_checks_are_not_executed() { + let client = Client::new(rocket()).unwrap(); + + let origin_header = Header::from( + hyper::header::Origin::from_str("https://www.bad-origin.com").unwrap(), + ); + let method_header = Header::from(hyper::header::AccessControlRequestMethod( + hyper::method::Method::Get, + )); + let request_headers = hyper::header::AccessControlRequestHeaders( + vec![FromStr::from_str("Authorization").unwrap()], + ); + let request_headers = Header::from(request_headers); + let req = client + .options("/panic") + .header(origin_header) + .header(method_header) + .header(request_headers); + + let response = req.dispatch(); + assert_eq!(response.status(), Status::Forbidden); +}