From 88640afaae3a876dada951f1a1992ee2df3b5ac0 Mon Sep 17 00:00:00 2001 From: Yong Wen Chua Date: Sun, 16 Jul 2017 14:41:33 +0800 Subject: [PATCH] Skeleton --- src/lib.rs | 23 +++++ tests/fairings.rs | 240 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 263 insertions(+) create mode 100644 tests/fairings.rs diff --git a/src/lib.rs b/src/lib.rs index 8ec8530..eec9df0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -124,6 +124,7 @@ use std::ops::Deref; use std::str::FromStr; use rocket::{Outcome, State}; +use rocket::fairing; use rocket::http::{Method, Status}; use rocket::request::{Request, FromRequest}; use rocket::response; @@ -428,6 +429,28 @@ impl Cors { } } +impl fairing::Fairing for Cors { + fn info(&self) -> fairing::Info { + fairing::Info { + name: "CORS", + kind: fairing::Kind::Attach | fairing::Kind::Response, + } + } + + fn on_attach(&self, rocket: rocket::Rocket) -> Result { + match self.validate() { + Ok(()) => Ok(rocket), + Err(e) => { + error_!("Error attaching CORS fairing: {}", e); + Err(rocket) + } + } + } + + fn on_response(&self, request: &Request, response: &mut rocket::Response) { + } +} + /// A CORS [Responder](https://rocket.rs/guide/responses/#responder) /// which will inspect the incoming requests and respond accordingly. /// diff --git a/tests/fairings.rs b/tests/fairings.rs new file mode 100644 index 0000000..549221f --- /dev/null +++ b/tests/fairings.rs @@ -0,0 +1,240 @@ +//! This crate tests using rocket_cors using Fairings + +#![feature(plugin, custom_derive)] +#![plugin(rocket_codegen)] +extern crate hyper; +extern crate rocket; +extern crate rocket_cors; + +use std::str::FromStr; + +use rocket::http::Method; +use rocket::http::{Header, Status}; +use rocket::local::Client; +use rocket_cors::*; + +#[get("/")] +fn cors<'a>() -> &'a str { + "Hello CORS" +} + +fn make_cors_options() -> Cors { + let (allowed_origins, failed_origins) = AllOrSome::new_from_str_list(&["https://www.acme.com"]); + assert!(failed_origins.is_empty()); + + Cors { + allowed_origins: allowed_origins, + allowed_methods: [Method::Get].iter().cloned().collect(), + allowed_headers: AllOrSome::Some( + ["Authorization"] + .into_iter() + .map(|s| s.to_string().into()) + .collect(), + ), + allow_credentials: true, + ..Default::default() + } +} + +fn rocket() -> rocket::Rocket { + rocket::ignite() + .mount("/", routes![cors]) + .attach(make_cors_options()) +} + +#[test] +fn smoke_test() { + let client = Client::new(rocket()).unwrap(); + + // `Options` pre-flight checks + let origin_header = Header::from( + hyper::header::Origin::from_str("https://www.acme.com").unwrap(), + ); + let method_header = Header::from(hyper::header::AccessControlRequestMethod( + hyper::method::Method::Get, + )); + let request_headers = hyper::header::AccessControlRequestHeaders( + vec![FromStr::from_str("Authorization").unwrap()], + ); + let request_headers = Header::from(request_headers); + let req = client + .options("/") + .header(origin_header) + .header(method_header) + .header(request_headers); + + let response = req.dispatch(); + assert_eq!(response.status(), Status::Ok); + + // "Actual" request + let origin_header = Header::from( + hyper::header::Origin::from_str("https://www.acme.com").unwrap(), + ); + let authorization = Header::new("Authorization", "let me in"); + let req = client.get("/").header(origin_header).header(authorization); + + let mut response = req.dispatch(); + assert_eq!(response.status(), Status::Ok); + let body_str = response.body().and_then(|body| body.into_string()); + assert_eq!(body_str, Some("Hello CORS".to_string())); + +} + +#[test] +fn cors_options_check() { + let client = Client::new(rocket()).unwrap(); + + let origin_header = Header::from( + hyper::header::Origin::from_str("https://www.acme.com").unwrap(), + ); + let method_header = Header::from(hyper::header::AccessControlRequestMethod( + hyper::method::Method::Get, + )); + let request_headers = hyper::header::AccessControlRequestHeaders( + vec![FromStr::from_str("Authorization").unwrap()], + ); + let request_headers = Header::from(request_headers); + let req = client + .options("/") + .header(origin_header) + .header(method_header) + .header(request_headers); + + let response = req.dispatch(); + assert_eq!(response.status(), Status::Ok); +} + +#[test] +fn cors_get_check() { + let client = Client::new(rocket()).unwrap(); + + let origin_header = Header::from( + hyper::header::Origin::from_str("https://www.acme.com").unwrap(), + ); + let authorization = Header::new("Authorization", "let me in"); + let req = client.get("/").header(origin_header).header(authorization); + + let mut response = req.dispatch(); + println!("{:?}", response); + assert_eq!(response.status(), Status::Ok); + let body_str = response.body().and_then(|body| body.into_string()); + assert_eq!(body_str, Some("Hello CORS".to_string())); +} + +/// This test is to check that non CORS compliant requests to GET should still work. (i.e. curl) +#[test] +fn cors_get_no_origin() { + let client = Client::new(rocket()).unwrap(); + + let authorization = Header::new("Authorization", "let me in"); + let req = client.get("/").header(authorization); + + let mut response = req.dispatch(); + assert_eq!(response.status(), Status::Ok); + let body_str = response.body().and_then(|body| body.into_string()); + assert_eq!(body_str, Some("Hello CORS".to_string())); +} + +#[test] +fn cors_options_bad_origin() { + let client = Client::new(rocket()).unwrap(); + + let origin_header = Header::from( + hyper::header::Origin::from_str("https://www.bad-origin.com").unwrap(), + ); + let method_header = Header::from(hyper::header::AccessControlRequestMethod( + hyper::method::Method::Get, + )); + let request_headers = hyper::header::AccessControlRequestHeaders( + vec![FromStr::from_str("Authorization").unwrap()], + ); + let request_headers = Header::from(request_headers); + let req = client + .options("/") + .header(origin_header) + .header(method_header) + .header(request_headers); + + let response = req.dispatch(); + assert_eq!(response.status(), Status::Forbidden); +} + +#[test] +fn cors_options_missing_origin() { + let client = Client::new(rocket()).unwrap(); + + let method_header = Header::from(hyper::header::AccessControlRequestMethod( + hyper::method::Method::Get, + )); + let request_headers = hyper::header::AccessControlRequestHeaders( + vec![FromStr::from_str("Authorization").unwrap()], + ); + let request_headers = Header::from(request_headers); + let req = client.options("/").header(method_header).header( + request_headers, + ); + + let response = req.dispatch(); + assert_eq!(response.status(), Status::Ok); +} + +#[test] +fn cors_options_bad_request_method() { + let client = Client::new(rocket()).unwrap(); + + let origin_header = Header::from( + hyper::header::Origin::from_str("https://www.acme.com").unwrap(), + ); + let method_header = Header::from(hyper::header::AccessControlRequestMethod( + hyper::method::Method::Post, + )); + let request_headers = hyper::header::AccessControlRequestHeaders( + vec![FromStr::from_str("Authorization").unwrap()], + ); + let request_headers = Header::from(request_headers); + let req = client + .options("/") + .header(origin_header) + .header(method_header) + .header(request_headers); + + let response = req.dispatch(); + assert_eq!(response.status(), Status::Forbidden); +} + +#[test] +fn cors_options_bad_request_header() { + let client = Client::new(rocket()).unwrap(); + + let origin_header = Header::from( + hyper::header::Origin::from_str("https://www.acme.com").unwrap(), + ); + let method_header = Header::from(hyper::header::AccessControlRequestMethod( + hyper::method::Method::Get, + )); + let request_headers = + hyper::header::AccessControlRequestHeaders(vec![FromStr::from_str("Foobar").unwrap()]); + let request_headers = Header::from(request_headers); + let req = client + .options("/") + .header(origin_header) + .header(method_header) + .header(request_headers); + + let response = req.dispatch(); + assert_eq!(response.status(), Status::Forbidden); +} + +#[test] +fn cors_get_bad_origin() { + let client = Client::new(rocket()).unwrap(); + + let origin_header = Header::from( + hyper::header::Origin::from_str("https://www.bad-origin.com").unwrap(), + ); + let authorization = Header::new("Authorization", "let me in"); + let req = client.get("/").header(origin_header).header(authorization); + + let response = req.dispatch(); + assert_eq!(response.status(), Status::Forbidden); +}