From ca096ceb281ce3f4282b7fa38e0c00c843e6cc3a Mon Sep 17 00:00:00 2001 From: Yong Wen Chua Date: Fri, 14 Jul 2017 13:54:34 +0800 Subject: [PATCH] Extract headers integration tests --- src/lib.rs | 106 +++++++++++++++++++++++++-------------------- src/test_macros.rs | 7 --- tests/headers.rs | 63 +++++++++++++++++++++++++++ tests/routes.rs | 2 +- 4 files changed, 124 insertions(+), 54 deletions(-) create mode 100644 tests/headers.rs diff --git a/src/lib.rs b/src/lib.rs index 515430f..b2cb4f6 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -261,6 +261,9 @@ impl<'a, 'r> FromRequest<'a, 'r> for Url { /// The `Origin` request header used in CORS +/// +/// You can use this as a rocket [Request Guard](https://rocket.rs/guide/requests/#request-guards) +/// to ensure that Origins are passed in correctly. pub type Origin = Url; /// The `Access-Control-Request-Method` request header @@ -948,22 +951,45 @@ mod tests { use rocket; use rocket::local::Client; use rocket::http::Method; - use rocket::http::{Header, Status}; + use rocket::http::Status; use super::*; + /// Make a client with no routes for unit testing + fn make_client() -> Client { + let rocket = rocket::ignite(); + Client::new(rocket).expect("valid rocket instance") + } + #[test] fn origin_header_conversion() { let url = "https://foo.bar.xyz"; - let _ = not_err!(Origin::from_str(url)); + let parsed = not_err!(Origin::from_str(url)); + let expected = not_err!(Url::from_str(url)); + assert_eq!(parsed, expected); let url = "https://foo.bar.xyz/path/somewhere"; // this should never really be used - let _ = not_err!(Origin::from_str(url)); + let parsed = not_err!(Origin::from_str(url)); + let expected = not_err!(Url::from_str(url)); + assert_eq!(parsed, expected); let url = "invalid_url"; let _ = is_err!(Origin::from_str(url)); } + #[test] + fn origin_header_parsing() { + let client = make_client(); + let mut request = client.get("/"); + + let origin = hyper::header::Origin::new("https", "www.example.com", None); + request.add_header(origin); + + let outcome: request::Outcome = FromRequest::from_request(request.inner()); + let parsed_header = assert_matches!(outcome, Outcome::Success(s), s); + assert_eq!("https://www.example.com/", parsed_header.as_str()); + } + #[test] fn request_method_conversion() { let method = "POST"; @@ -978,6 +1004,20 @@ mod tests { let _ = is_err!(AccessControlRequestMethod::from_str(method)); } + #[test] + fn request_method_parsing() { + let client = make_client(); + let mut request = client.get("/"); + let method = hyper::header::AccessControlRequestMethod(hyper::method::Method::Get); + request.add_header(method); + let outcome: request::Outcome = + FromRequest::from_request(request.inner()); + + let parsed_header = assert_matches!(outcome, Outcome::Success(s), s); + let AccessControlRequestMethod(parsed_method) = parsed_header; + assert_eq!("GET", parsed_method.as_str()); + } + #[test] fn request_headers_conversion() { let headers = ["foo", "bar", "baz"]; @@ -988,53 +1028,27 @@ mod tests { assert_eq!(actual_headers, expected_headers); } - #[get("/request_headers")] - #[allow(needless_pass_by_value)] - fn request_headers( - origin: Origin, - method: AccessControlRequestMethod, - headers: AccessControlRequestHeaders, - ) -> String { - let AccessControlRequestMethod(method) = method; - let AccessControlRequestHeaders(headers) = headers; - let mut headers = headers - .iter() - .map(|s| s.deref().to_string()) - .collect::>(); - headers.sort(); - format!("{}\n{}\n{}", origin, method, headers.join(", ")) - } - - /// Tests that all the headers are parsed correcly in a HTTP request #[test] - fn request_headers_round_trip_smoke_test() { - let rocket = rocket::ignite().mount("/", routes![request_headers]); - let client = not_err!(Client::new(rocket)); - - let origin_header = Header::from(not_err!( - hyper::header::Origin::from_str("https://foo.bar.xyz") - )); - let method_header = Header::from(hyper::header::AccessControlRequestMethod( - hyper::method::Method::Get, - )); - let request_headers = hyper::header::AccessControlRequestHeaders(vec![ + fn request_headers_parsing() { + let client = make_client(); + let mut request = client.get("/"); + let headers = hyper::header::AccessControlRequestHeaders(vec![ FromStr::from_str("accept-language").unwrap(), - FromStr::from_str("X-Ping").unwrap(), + FromStr::from_str("date").unwrap(), ]); - let request_headers = Header::from(request_headers); - let req = client - .get("/request_headers") - .header(origin_header) - .header(method_header) - .header(request_headers); - let mut response = req.dispatch(); + request.add_header(headers); + let outcome: request::Outcome = + FromRequest::from_request(request.inner()); - assert_eq!(Status::Ok, response.status()); - let body_str = not_none!(response.body().and_then(|body| body.into_string())); - let expected_body = r#"https://foo.bar.xyz/ -GET -X-Ping, accept-language"#; - assert_eq!(expected_body, body_str); + let parsed_header = assert_matches!(outcome, Outcome::Success(s), s); + let AccessControlRequestHeaders(parsed_headers) = parsed_header; + let mut parsed_headers: Vec = + parsed_headers.iter().map(|s| s.to_string()).collect(); + parsed_headers.sort(); + assert_eq!( + vec!["accept-language".to_string(), "date".to_string()], + parsed_headers + ); } #[get("/any")] diff --git a/src/test_macros.rs b/src/test_macros.rs index ff19c9f..43c3b87 100644 --- a/src/test_macros.rs +++ b/src/test_macros.rs @@ -12,13 +12,6 @@ macro_rules! is_err { }) } -macro_rules! not_none { - ($e:expr) => (match $e { - Some(e) => e, - None => panic!("{} failed with None", stringify!($e)), - }) -} - macro_rules! assert_matches { ($e: expr, $p: pat) => (assert_matches!($e, $p, ())); ($e: expr, $p: pat, $f: expr) => (match $e { diff --git a/tests/headers.rs b/tests/headers.rs new file mode 100644 index 0000000..d56f9e2 --- /dev/null +++ b/tests/headers.rs @@ -0,0 +1,63 @@ +//! This crate tests that all the request headers are parsed correctly in the round trip +#![feature(plugin, custom_derive)] +#![plugin(rocket_codegen)] +extern crate hyper; +extern crate rocket; +extern crate rocket_cors; + +use std::ops::Deref; +use std::str::FromStr; + +use rocket::local::Client; +use rocket::http::{Header, Status}; +use rocket_cors::*; + +#[get("/request_headers")] +fn request_headers( + origin: Origin, + method: AccessControlRequestMethod, + headers: AccessControlRequestHeaders, +) -> String { + let AccessControlRequestMethod(method) = method; + let AccessControlRequestHeaders(headers) = headers; + let mut headers = headers + .iter() + .map(|s| s.deref().to_string()) + .collect::>(); + headers.sort(); + format!("{}\n{}\n{}", origin, method, headers.join(", ")) +} + +/// Tests that all the headers are parsed correcly in a HTTP request +#[test] +fn request_headers_round_trip_smoke_test() { + let rocket = rocket::ignite().mount("/", routes![request_headers]); + let client = Client::new(rocket).expect("A valid Rocket client"); + + let origin_header = Header::from( + hyper::header::Origin::from_str("https://foo.bar.xyz").unwrap(), + ); + let method_header = Header::from(hyper::header::AccessControlRequestMethod( + hyper::method::Method::Get, + )); + let request_headers = hyper::header::AccessControlRequestHeaders(vec![ + FromStr::from_str("accept-language").unwrap(), + FromStr::from_str("X-Ping").unwrap(), + ]); + let request_headers = Header::from(request_headers); + let req = client + .get("/request_headers") + .header(origin_header) + .header(method_header) + .header(request_headers); + let mut response = req.dispatch(); + + assert_eq!(Status::Ok, response.status()); + let body_str = response.body().and_then(|body| body.into_string()).expect( + "Non-empty body", + ); + let expected_body = r#"https://foo.bar.xyz/ +GET +X-Ping, accept-language"#; + assert_eq!(expected_body, body_str); +} diff --git a/tests/routes.rs b/tests/routes.rs index 6b2b346..1617572 100644 --- a/tests/routes.rs +++ b/tests/routes.rs @@ -1,4 +1,4 @@ -//! This crate tests using rocket_cors using the "classic"" per-route handling +//! This crate tests using rocket_cors using the "classic" per-route handling #![feature(plugin, custom_derive)] #![plugin(rocket_codegen)]