From abda99d71fdef8e5f772fbaacfdbc0d9ebdd70e8 Mon Sep 17 00:00:00 2001 From: Yong Wen Chua Date: Sun, 16 Jul 2017 15:39:12 +0800 Subject: [PATCH] Errors are not handled properly --- src/lib.rs | 41 ++++++++++++++++++++++++++++++++++++++--- tests/fairings.rs | 12 ++++++------ tests/headers.rs | 2 +- tests/routes.rs | 12 ++++++------ 4 files changed, 51 insertions(+), 16 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index d07d7ee..fb29fb4 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -433,7 +433,42 @@ impl fairing::Fairing for Cors { } } - fn on_response(&self, _: &Request, _: &mut rocket::Response) {} + fn on_response(&self, request: &Request, response: &mut rocket::Response) { + use rocket::response::Responder; + + // Build and merge CORS response + match build_cors_response(self, request, response) { + Err(err) => { + // CORS error -- overwrite the original response + let error_response = match err.respond_to(request) { + Err(err) => { + unreachable!( + "Should not happen! The Error responder does not Err: {:?}", + err + ) + } + Ok(error_response) => error_response, + }; + response.merge(error_response) + } + Ok(()) => { + // If this was an OPTIONS request and no route can be found, we should turn this + // into a HTTP 204 with no content body. + // This allows the user to not have to specify an OPTIONS route for everything. + // + // TODO: Is there anyway we can make this smarter? Only modify status codes for + // requests where an actual route exist? + if has_allow_origin(&response) && request.method() == Method::Options && + request.route().is_none() + { + response.set_status(Status::NoContent); + let _ = response.take_body(); + } + } + }; + + + } } /// A CORS [Responder](https://rocket.rs/guide/responses/#responder) @@ -753,7 +788,7 @@ pub fn respond<'a, 'r: 'a, R: response::Responder<'r>>( fn build_cors_response( options: &Cors, request: &Request, - mut response: &mut rocket::Response, + response: &mut rocket::Response, ) -> Result<(), Error> { // Existing CORS response? if has_allow_origin(response) { @@ -781,7 +816,7 @@ fn build_cors_response( _ => actual_request(options, origin), }?; - cors_response.merge(&mut response); + cors_response.merge(response); Ok(()) } diff --git a/tests/fairings.rs b/tests/fairings.rs index 255e9c6..4fdfd68 100644 --- a/tests/fairings.rs +++ b/tests/fairings.rs @@ -64,7 +64,7 @@ fn smoke_test() { .header(request_headers); let response = req.dispatch(); - assert_eq!(response.status(), Status::Ok); + assert!(response.status().class().is_success()); // "Actual" request let origin_header = Header::from( @@ -74,7 +74,7 @@ fn smoke_test() { let req = client.get("/").header(origin_header).header(authorization); let mut response = req.dispatch(); - assert_eq!(response.status(), Status::Ok); + assert!(response.status().class().is_success()); let body_str = response.body().and_then(|body| body.into_string()); assert_eq!(body_str, Some("Hello CORS".to_string())); @@ -101,7 +101,7 @@ fn cors_options_check() { .header(request_headers); let response = req.dispatch(); - assert_eq!(response.status(), Status::Ok); + assert!(response.status().class().is_success()); } #[test] @@ -116,7 +116,7 @@ fn cors_get_check() { let mut response = req.dispatch(); println!("{:?}", response); - assert_eq!(response.status(), Status::Ok); + assert!(response.status().class().is_success()); let body_str = response.body().and_then(|body| body.into_string()); assert_eq!(body_str, Some("Hello CORS".to_string())); } @@ -130,7 +130,7 @@ fn cors_get_no_origin() { let req = client.get("/").header(authorization); let mut response = req.dispatch(); - assert_eq!(response.status(), Status::Ok); + assert!(response.status().class().is_success()); let body_str = response.body().and_then(|body| body.into_string()); assert_eq!(body_str, Some("Hello CORS".to_string())); } @@ -175,7 +175,7 @@ fn cors_options_missing_origin() { ); let response = req.dispatch(); - assert_eq!(response.status(), Status::Ok); + assert!(response.status().class().is_success()); } #[test] diff --git a/tests/headers.rs b/tests/headers.rs index 9be98da..ae5212a 100644 --- a/tests/headers.rs +++ b/tests/headers.rs @@ -52,7 +52,7 @@ fn request_headers_round_trip_smoke_test() { .header(request_headers); let mut response = req.dispatch(); - assert_eq!(Status::Ok, response.status()); + assert!(response.status().class().is_success()); let body_str = response.body().and_then(|body| body.into_string()).expect( "Non-empty body", ); diff --git a/tests/routes.rs b/tests/routes.rs index 57c4bea..997382d 100644 --- a/tests/routes.rs +++ b/tests/routes.rs @@ -81,7 +81,7 @@ fn smoke_test() { .header(request_headers); let response = req.dispatch(); - assert_eq!(response.status(), Status::Ok); + assert!(response.status().class().is_success()); // "Actual" request let origin_header = Header::from( @@ -91,7 +91,7 @@ fn smoke_test() { let req = client.get("/").header(origin_header).header(authorization); let mut response = req.dispatch(); - assert_eq!(response.status(), Status::Ok); + assert!(response.status().class().is_success()); let body_str = response.body().and_then(|body| body.into_string()); assert_eq!(body_str, Some("Hello CORS".to_string())); @@ -121,7 +121,7 @@ fn cors_options_check() { .header(request_headers); let response = req.dispatch(); - assert_eq!(response.status(), Status::Ok); + assert!(response.status().class().is_success()); } #[test] @@ -139,7 +139,7 @@ fn cors_get_check() { let mut response = req.dispatch(); println!("{:?}", response); - assert_eq!(response.status(), Status::Ok); + assert!(response.status().class().is_success()); let body_str = response.body().and_then(|body| body.into_string()); assert_eq!(body_str, Some("Hello CORS".to_string())); } @@ -156,7 +156,7 @@ fn cors_get_no_origin() { let req = client.get("/").header(authorization); let mut response = req.dispatch(); - assert_eq!(response.status(), Status::Ok); + assert!(response.status().class().is_success()); let body_str = response.body().and_then(|body| body.into_string()); assert_eq!(body_str, Some("Hello CORS".to_string())); } @@ -207,7 +207,7 @@ fn cors_options_missing_origin() { ); let response = req.dispatch(); - assert_eq!(response.status(), Status::Ok); + assert!(response.status().class().is_success()); } #[test]