Extract headers integration tests

This commit is contained in:
Yong Wen Chua 2017-07-14 13:54:34 +08:00
parent 29952e182d
commit ca096ceb28
4 changed files with 124 additions and 54 deletions

View File

@ -261,6 +261,9 @@ impl<'a, 'r> FromRequest<'a, 'r> for Url {
/// The `Origin` request header used in CORS /// The `Origin` request header used in CORS
///
/// You can use this as a rocket [Request Guard](https://rocket.rs/guide/requests/#request-guards)
/// to ensure that Origins are passed in correctly.
pub type Origin = Url; pub type Origin = Url;
/// The `Access-Control-Request-Method` request header /// The `Access-Control-Request-Method` request header
@ -948,22 +951,45 @@ mod tests {
use rocket; use rocket;
use rocket::local::Client; use rocket::local::Client;
use rocket::http::Method; use rocket::http::Method;
use rocket::http::{Header, Status}; use rocket::http::Status;
use super::*; use super::*;
/// Make a client with no routes for unit testing
fn make_client() -> Client {
let rocket = rocket::ignite();
Client::new(rocket).expect("valid rocket instance")
}
#[test] #[test]
fn origin_header_conversion() { fn origin_header_conversion() {
let url = "https://foo.bar.xyz"; let url = "https://foo.bar.xyz";
let _ = not_err!(Origin::from_str(url)); let parsed = not_err!(Origin::from_str(url));
let expected = not_err!(Url::from_str(url));
assert_eq!(parsed, expected);
let url = "https://foo.bar.xyz/path/somewhere"; // this should never really be used let url = "https://foo.bar.xyz/path/somewhere"; // this should never really be used
let _ = not_err!(Origin::from_str(url)); let parsed = not_err!(Origin::from_str(url));
let expected = not_err!(Url::from_str(url));
assert_eq!(parsed, expected);
let url = "invalid_url"; let url = "invalid_url";
let _ = is_err!(Origin::from_str(url)); let _ = is_err!(Origin::from_str(url));
} }
#[test]
fn origin_header_parsing() {
let client = make_client();
let mut request = client.get("/");
let origin = hyper::header::Origin::new("https", "www.example.com", None);
request.add_header(origin);
let outcome: request::Outcome<Origin, Error> = FromRequest::from_request(request.inner());
let parsed_header = assert_matches!(outcome, Outcome::Success(s), s);
assert_eq!("https://www.example.com/", parsed_header.as_str());
}
#[test] #[test]
fn request_method_conversion() { fn request_method_conversion() {
let method = "POST"; let method = "POST";
@ -978,6 +1004,20 @@ mod tests {
let _ = is_err!(AccessControlRequestMethod::from_str(method)); let _ = is_err!(AccessControlRequestMethod::from_str(method));
} }
#[test]
fn request_method_parsing() {
let client = make_client();
let mut request = client.get("/");
let method = hyper::header::AccessControlRequestMethod(hyper::method::Method::Get);
request.add_header(method);
let outcome: request::Outcome<AccessControlRequestMethod, Error> =
FromRequest::from_request(request.inner());
let parsed_header = assert_matches!(outcome, Outcome::Success(s), s);
let AccessControlRequestMethod(parsed_method) = parsed_header;
assert_eq!("GET", parsed_method.as_str());
}
#[test] #[test]
fn request_headers_conversion() { fn request_headers_conversion() {
let headers = ["foo", "bar", "baz"]; let headers = ["foo", "bar", "baz"];
@ -988,53 +1028,27 @@ mod tests {
assert_eq!(actual_headers, expected_headers); assert_eq!(actual_headers, expected_headers);
} }
#[get("/request_headers")]
#[allow(needless_pass_by_value)]
fn request_headers(
origin: Origin,
method: AccessControlRequestMethod,
headers: AccessControlRequestHeaders,
) -> String {
let AccessControlRequestMethod(method) = method;
let AccessControlRequestHeaders(headers) = headers;
let mut headers = headers
.iter()
.map(|s| s.deref().to_string())
.collect::<Vec<String>>();
headers.sort();
format!("{}\n{}\n{}", origin, method, headers.join(", "))
}
/// Tests that all the headers are parsed correcly in a HTTP request
#[test] #[test]
fn request_headers_round_trip_smoke_test() { fn request_headers_parsing() {
let rocket = rocket::ignite().mount("/", routes![request_headers]); let client = make_client();
let client = not_err!(Client::new(rocket)); let mut request = client.get("/");
let headers = hyper::header::AccessControlRequestHeaders(vec![
let origin_header = Header::from(not_err!(
hyper::header::Origin::from_str("https://foo.bar.xyz")
));
let method_header = Header::from(hyper::header::AccessControlRequestMethod(
hyper::method::Method::Get,
));
let request_headers = hyper::header::AccessControlRequestHeaders(vec![
FromStr::from_str("accept-language").unwrap(), FromStr::from_str("accept-language").unwrap(),
FromStr::from_str("X-Ping").unwrap(), FromStr::from_str("date").unwrap(),
]); ]);
let request_headers = Header::from(request_headers); request.add_header(headers);
let req = client let outcome: request::Outcome<AccessControlRequestHeaders, Error> =
.get("/request_headers") FromRequest::from_request(request.inner());
.header(origin_header)
.header(method_header)
.header(request_headers);
let mut response = req.dispatch();
assert_eq!(Status::Ok, response.status()); let parsed_header = assert_matches!(outcome, Outcome::Success(s), s);
let body_str = not_none!(response.body().and_then(|body| body.into_string())); let AccessControlRequestHeaders(parsed_headers) = parsed_header;
let expected_body = r#"https://foo.bar.xyz/ let mut parsed_headers: Vec<String> =
GET parsed_headers.iter().map(|s| s.to_string()).collect();
X-Ping, accept-language"#; parsed_headers.sort();
assert_eq!(expected_body, body_str); assert_eq!(
vec!["accept-language".to_string(), "date".to_string()],
parsed_headers
);
} }
#[get("/any")] #[get("/any")]

View File

@ -12,13 +12,6 @@ macro_rules! is_err {
}) })
} }
macro_rules! not_none {
($e:expr) => (match $e {
Some(e) => e,
None => panic!("{} failed with None", stringify!($e)),
})
}
macro_rules! assert_matches { macro_rules! assert_matches {
($e: expr, $p: pat) => (assert_matches!($e, $p, ())); ($e: expr, $p: pat) => (assert_matches!($e, $p, ()));
($e: expr, $p: pat, $f: expr) => (match $e { ($e: expr, $p: pat, $f: expr) => (match $e {

63
tests/headers.rs Normal file
View File

@ -0,0 +1,63 @@
//! This crate tests that all the request headers are parsed correctly in the round trip
#![feature(plugin, custom_derive)]
#![plugin(rocket_codegen)]
extern crate hyper;
extern crate rocket;
extern crate rocket_cors;
use std::ops::Deref;
use std::str::FromStr;
use rocket::local::Client;
use rocket::http::{Header, Status};
use rocket_cors::*;
#[get("/request_headers")]
fn request_headers(
origin: Origin,
method: AccessControlRequestMethod,
headers: AccessControlRequestHeaders,
) -> String {
let AccessControlRequestMethod(method) = method;
let AccessControlRequestHeaders(headers) = headers;
let mut headers = headers
.iter()
.map(|s| s.deref().to_string())
.collect::<Vec<String>>();
headers.sort();
format!("{}\n{}\n{}", origin, method, headers.join(", "))
}
/// Tests that all the headers are parsed correcly in a HTTP request
#[test]
fn request_headers_round_trip_smoke_test() {
let rocket = rocket::ignite().mount("/", routes![request_headers]);
let client = Client::new(rocket).expect("A valid Rocket client");
let origin_header = Header::from(
hyper::header::Origin::from_str("https://foo.bar.xyz").unwrap(),
);
let method_header = Header::from(hyper::header::AccessControlRequestMethod(
hyper::method::Method::Get,
));
let request_headers = hyper::header::AccessControlRequestHeaders(vec![
FromStr::from_str("accept-language").unwrap(),
FromStr::from_str("X-Ping").unwrap(),
]);
let request_headers = Header::from(request_headers);
let req = client
.get("/request_headers")
.header(origin_header)
.header(method_header)
.header(request_headers);
let mut response = req.dispatch();
assert_eq!(Status::Ok, response.status());
let body_str = response.body().and_then(|body| body.into_string()).expect(
"Non-empty body",
);
let expected_body = r#"https://foo.bar.xyz/
GET
X-Ping, accept-language"#;
assert_eq!(expected_body, body_str);
}

View File

@ -1,4 +1,4 @@
//! This crate tests using rocket_cors using the "classic"" per-route handling //! This crate tests using rocket_cors using the "classic" per-route handling
#![feature(plugin, custom_derive)] #![feature(plugin, custom_derive)]
#![plugin(rocket_codegen)] #![plugin(rocket_codegen)]