From e59911286711ff5e2416ed915e8d407565de5968 Mon Sep 17 00:00:00 2001 From: Yong Wen Chua Date: Thu, 13 Jul 2017 15:37:15 +0800 Subject: [PATCH] First commit --- .gitignore | 3 + Cargo.toml | 33 ++ LICENSE | 201 +++++++++ build.rs | 86 ++++ rust-toolchain | 1 + src/lib.rs | 1055 ++++++++++++++++++++++++++++++++++++++++++++ src/test_macros.rs | 33 ++ 7 files changed, 1412 insertions(+) create mode 100644 .gitignore create mode 100644 Cargo.toml create mode 100644 LICENSE create mode 100644 build.rs create mode 100644 rust-toolchain create mode 100644 src/lib.rs create mode 100644 src/test_macros.rs diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..4308d82 --- /dev/null +++ b/.gitignore @@ -0,0 +1,3 @@ +target/ +**/*.rs.bk +Cargo.lock diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..1ff5259 --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,33 @@ +[package] +name = "rocket_cors" +version = "0.1.0" +license = "Apache-2.0" +authors = ["Yong Wen Chua "] +build = "build.rs" +description = "Cross-origin resource sharing (CORS) for Rocket.rs applications" +homepage = "https://github.com/lawliet89/rocket_cors" +repository = "https://github.com/lawliet89/rocket_cors" +documentation = "https://docs.rs/rocket_cors/" +keywords = ["rocket", "cors"] +categories = ["web-programming"] + +[badges] +travis-ci = { repository = "lawliet89/rocket_cors" } + +[dependencies] +log = "0.3" +rocket = { git = "https://github.com/SergioBenitez/Rocket", rev = "aa51fe0" } +serde = "1.0" +serde_derive = "1.0" +unicase="1.4" +url = "1.5.1" +url_serde = "0.2.0" + +[build-dependencies] +ansi_term = "0.9" +version_check = "0.1" + +[dev-dependencies] +hyper = "0.10" +rocket_codegen = { git = "https://github.com/SergioBenitez/Rocket", rev = "aa51fe0" } +serde_json = "1.0" diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..8dada3e --- /dev/null +++ b/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "{}" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright {yyyy} {name of copyright owner} + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/build.rs b/build.rs new file mode 100644 index 0000000..729aec6 --- /dev/null +++ b/build.rs @@ -0,0 +1,86 @@ +//! This tiny build script ensures that the crate is not compiled with an +//! incompatible version of rust. +//! This scipt was stolen from `rocket_codegen`. + +extern crate ansi_term; +extern crate version_check; + +use ansi_term::Color::{Red, Yellow, Blue, White}; +use version_check::{is_nightly, is_min_version, is_min_date}; + +// Specifies the minimum nightly version that is targetted +// Note that sometimes the `rustc` date might be older than the nightly version, +// usually one day older +const MIN_DATE: &'static str = "2017-07-12"; +const MIN_VERSION: &'static str = "1.20.0-nightly"; + +// Convenience macro for writing to stderr. +macro_rules! printerr { + ($($arg:tt)*) => ({ + use std::io::prelude::*; + write!(&mut ::std::io::stderr(), "{}\n", format_args!($($arg)*)) + .expect("Failed to write to stderr.") + }) +} + +fn main() { + let ok_nightly = is_nightly(); + let ok_version = is_min_version(MIN_VERSION); + let ok_date = is_min_date(MIN_DATE); + + let print_version_err = |version: &str, date: &str| { + printerr!( + "{} {}. {} {}.", + White.paint("Installed version is:"), + Yellow.paint(format!("{} ({})", version, date)), + White.paint("Minimum required:"), + Yellow.paint(format!("{} ({})", MIN_VERSION, MIN_DATE)) + ); + }; + + match (ok_nightly, ok_version, ok_date) { + (Some(is_nightly), Some((ok_version, version)), Some((ok_date, date))) => { + if !is_nightly { + printerr!( + "{} {}", + Red.bold().paint("Error:"), + White.paint("rowdy requires a nightly version of Rust.") + ); + print_version_err(&*version, &*date); + printerr!( + "{}{}{}", + Blue.paint("See the README ("), + White.paint("https://github.com/lawliet89/rowdy"), + Blue.paint(") for more information.") + ); + panic!("Aborting compilation due to incompatible compiler.") + } + + if !ok_version || !ok_date { + printerr!( + "{} {}", + Red.bold().paint("Error:"), + White.paint("rowdy requires a more recent version of rustc.") + ); + printerr!( + "{}{}{}", + Blue.paint("Use `"), + White.paint("rustup update"), + Blue.paint("` or your preferred method to update Rust.") + ); + print_version_err(&*version, &*date); + panic!("Aborting compilation due to incompatible compiler.") + } + } + _ => { + println!( + "cargo:warning={}", + "rowdy was unable to check rustc compatibility." + ); + println!( + "cargo:warning={}", + "Build may fail due to incompatible rustc version." + ); + } + } +} diff --git a/rust-toolchain b/rust-toolchain new file mode 100644 index 0000000..8dd0f7b --- /dev/null +++ b/rust-toolchain @@ -0,0 +1 @@ +nightly-2017-07-13 diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..ffa812a --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,1055 @@ +//! Cross-origin resource sharing (CORS) for Rocket.rs applications +//! +//! Rocket (as of v0.2) does not have middleware support. Support for it is (supposedly) +//! on the way. In the mean time, we adopt an +//! [example implementation](https://github.com/SergioBenitez/Rocket/pull/141) to nest +//! `Responders` to acheive the same effect in the short run. +//! +//! # Examples +//! ``` +//! #![feature(plugin, custom_derive)] +//! #![plugin(rocket_codegen)] +//! extern crate hyper; +//! extern crate rocket; +//! extern crate rocket_cors; +//! +//! use std::str::FromStr; +//! +//! use rocket::State; +//! use rocket::http::Method::*; +//! use rocket::http::{Header, Status}; +//! use rocket::local::Client; +//! use rocket_cors::*; +//! +//! #[options("/")] +//! fn cors_options(origin: Option, +//! method: AccessControlRequestMethod, +//! headers: AccessControlRequestHeaders, +//! options: State) +//! -> Result, Error> { +//! options.preflight(origin, &method, Some(&headers)) +//! } +//! +//! #[get("/")] +//! fn cors(origin: Option, options: State) +//! -> Result, Error> +//! { +//! options.respond("Hello CORS", origin) +//! } +//! +//! # fn main() { +//! let (allowed_origins, failed_origins) = +//! AllowedOrigins::new_from_str_list(&["https://www.acme.com"]); +//! assert!(failed_origins.is_empty()); +//! let cors_options = rocket_cors::Options { +//! allowed_origins: allowed_origins, +//! allowed_methods: [Get].iter().cloned().collect(), +//! allowed_headers: ["Authorization"].iter().map(|s| s.to_string().into()).collect(), +//! allow_credentials: true, +//! ..Default::default() +//! }; +//! let rocket = rocket::ignite().mount("/", routes![cors, cors_options]).manage(cors_options); +//! let client = Client::new(rocket).unwrap(); +//! +//! // `Options` pre-flight checks +//! let origin_header = +//! Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap()); +//! let method_header = +//! Header::from(hyper::header::AccessControlRequestMethod(hyper::method::Method::Get)); +//! let request_headers = +//! hyper::header::AccessControlRequestHeaders( +//! vec![FromStr::from_str("Authorization").unwrap()]); +//! let request_headers = Header::from(request_headers); +//! let req = +//! client.options("/").header(origin_header).header(method_header).header(request_headers); +//! +//! let response = req.dispatch(); +//! assert_eq!(response.status(), Status::Ok); +//! +//! // "Actual" request +//! let origin_header = +//! Header::from(hyper::header::Origin::from_str("https://www.acme.com").unwrap()); +//! let authorization = Header::new("Authorization", "let me in"); +//! let req = client.get("/").header(origin_header).header(authorization); +//! +//! let mut response = req.dispatch(); +//! assert_eq!(response.status(), Status::Ok); +//! let body_str = response.body().and_then(|body| body.into_string()); +//! assert_eq!(body_str, Some("Hello CORS".to_string())); +//! # } +//! ``` + + +#![allow( + legacy_directory_ownership, + missing_copy_implementations, + missing_debug_implementations, + unknown_lints, + unsafe_code, +)] +#![deny( + const_err, + dead_code, + deprecated, + exceeding_bitshifts, + fat_ptr_transmutes, + improper_ctypes, + missing_docs, + mutable_transmutes, + no_mangle_const_items, + non_camel_case_types, + non_shorthand_field_patterns, + non_upper_case_globals, + overflowing_literals, + path_statements, + plugin_as_library, + private_no_mangle_fns, + private_no_mangle_statics, + stable_features, + trivial_casts, + trivial_numeric_casts, + unconditional_recursion, + unknown_crate_types, + unreachable_code, + unused_allocation, + unused_assignments, + unused_attributes, + unused_comparisons, + unused_extern_crates, + unused_features, + unused_imports, + unused_import_braces, + unused_qualifications, + unused_must_use, + unused_mut, + unused_parens, + unused_results, + unused_unsafe, + unused_variables, + variant_size_differences, + warnings, + while_true, +)] + +#![cfg_attr(test, feature(plugin, custom_derive))] +#![cfg_attr(test, plugin(rocket_codegen))] +#![doc(test(attr(allow(unused_variables), deny(warnings))))] + +#[macro_use] +extern crate log; +#[macro_use] +extern crate rocket; +#[macro_use] +extern crate serde_derive; +extern crate unicase; +extern crate url; +extern crate url_serde; + +#[cfg(test)] +extern crate hyper; + +use std::collections::{HashSet, HashMap}; +use std::error; +use std::fmt; +use std::ops::Deref; +use std::str::FromStr; + +use rocket::request::{self, Request, FromRequest}; +use rocket::response::{self, Responder}; +use rocket::http::{Method, Status}; +use rocket::Outcome; +use unicase::UniCase; + +#[cfg(test)] +#[macro_use] +mod test_macros; + +/// CORS related error +#[derive(Debug)] +pub enum Error { + /// The HTTP request header `Origin` is required but was not provided + MissingOrigin, + /// The HTTP request header `Origin` could not be parsed correctly. + BadOrigin(url::ParseError), + /// The request header `Access-Control-Request-Method` is required but is missing + MissingRequestMethod, + /// The request header `Access-Control-Request-Method` has an invalid value + BadRequestMethod(rocket::Error), + /// The request header `Access-Control-Request-Headers` is required but is missing. + MissingRequestHeaders, + /// Origin is not allowed to make this request + OriginNotAllowed, + /// Requested method is not allowed + MethodNotAllowed, + /// One or more headers requested are not allowed + HeadersNotAllowed, +} + +impl error::Error for Error { + fn description(&self) -> &str { + match *self { + Error::MissingOrigin => "The request header `Origin` is required but is missing", + Error::BadOrigin(_) => "The request header `Origin` contains an invalid URL", + Error::MissingRequestMethod => { + "The request header `Access-Control-Request-Method` \ + is required but is missing" + } + Error::BadRequestMethod(_) => { + "The request header `Access-Control-Request-Method` has an invalid value" + } + Error::MissingRequestHeaders => { + "The request header `Access-Control-Request-Headers` \ + is required but is missing" + } + Error::OriginNotAllowed => "Origin is not allowed to request", + Error::MethodNotAllowed => "Method is not allowed", + Error::HeadersNotAllowed => "Headers are not allowed", + } + } + + fn cause(&self) -> Option<&error::Error> { + match *self { + Error::BadOrigin(ref e) => Some(e), + _ => Some(self), + } + } +} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match *self { + Error::BadOrigin(ref e) => fmt::Display::fmt(e, f), + Error::BadRequestMethod(ref e) => fmt::Debug::fmt(e, f), + _ => write!(f, "{}", error::Error::description(self)), + } + } +} + +impl<'r> Responder<'r> for Error { + fn respond_to(self, _: &Request) -> Result, Status> { + error_!("CORS Error: {:?}", self); + Err(match self { + Error::MissingOrigin | Error::OriginNotAllowed | Error::MethodNotAllowed | + Error::HeadersNotAllowed => Status::Forbidden, + _ => Status::BadRequest, + }) + } +} + +/// A wrapped `url::Url` to allow for deserialization +#[derive(Eq, PartialEq, Clone, Hash, Debug, Serialize, Deserialize)] +pub struct Url( + #[serde(with = "url_serde")] + url::Url +); + +impl fmt::Display for Url { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + self.0.fmt(f) + } +} + +impl Deref for Url { + type Target = url::Url; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl FromStr for Url { + type Err = url::ParseError; + + fn from_str(input: &str) -> Result { + let url = url::Url::from_str(input)?; + Ok(Url(url)) + } +} + +impl<'a, 'r> FromRequest<'a, 'r> for Url { + type Error = Error; + + fn from_request(request: &'a Request<'r>) -> request::Outcome { + match request.headers().get_one("Origin") { + Some(origin) => { + match Self::from_str(origin) { + Ok(origin) => Outcome::Success(origin), + Err(e) => Outcome::Failure((Status::BadRequest, Error::BadOrigin(e))), + } + } + None => Outcome::Forward(()), + } + } +} + +/// The `Origin` request header used in CORS +pub type Origin = Url; + +/// The `Access-Control-Request-Method` request header +#[derive(Debug)] +pub struct AccessControlRequestMethod(pub Method); + +impl FromStr for AccessControlRequestMethod { + type Err = rocket::Error; + + fn from_str(method: &str) -> Result { + Ok(AccessControlRequestMethod(Method::from_str(method)?)) + } +} + +impl<'a, 'r> FromRequest<'a, 'r> for AccessControlRequestMethod { + type Error = Error; + + fn from_request(request: &'a Request<'r>) -> request::Outcome { + match request.headers().get_one("Access-Control-Request-Method") { + Some(request_method) => { + match Self::from_str(request_method) { + Ok(request_method) => Outcome::Success(request_method), + Err(e) => Outcome::Failure((Status::BadRequest, Error::BadRequestMethod(e))), + } + } + None => Outcome::Failure((Status::BadRequest, Error::MissingRequestMethod)), + } + } +} + +type HeaderFieldNamesSet = HashSet>; + +/// The `Access-Control-Request-Headers` request header +#[derive(Debug)] +pub struct AccessControlRequestHeaders(pub HeaderFieldNamesSet); + +/// Will never fail +impl FromStr for AccessControlRequestHeaders { + type Err = (); + + /// Will never fail + fn from_str(headers: &str) -> Result { + if headers.trim().is_empty() { + return Ok(AccessControlRequestHeaders(HashSet::new())); + } + + let set: HeaderFieldNamesSet = headers + .split(',') + .map(|header| UniCase(header.trim().to_string())) + .collect(); + Ok(AccessControlRequestHeaders(set)) + } +} + +impl<'a, 'r> FromRequest<'a, 'r> for AccessControlRequestHeaders { + type Error = Error; + + fn from_request(request: &'a Request<'r>) -> request::Outcome { + match request.headers().get_one("Access-Control-Request-Headers") { + Some(request_headers) => { + match Self::from_str(request_headers) { + Ok(request_headers) => Outcome::Success(request_headers), + Err(()) => { + unreachable!("`AccessControlRequestHeaders::from_str` should never fail") + } + } + } + None => Outcome::Failure((Status::BadRequest, Error::MissingRequestHeaders)), + } + } +} + +/// Origins that are allowed to issue CORS request. This is needed for browser +/// access to the authentication server, but tools like `curl` +/// do not obey nor enforce the CORS convention. +/// +/// This enum (de)serialized as an [untagged](https://serde.rs/enum-representations.html) +/// enum variant. +/// +/// # Examples +/// ## Allow all origins +/// ```json +/// { "allowed_origins": null } +/// ``` +/// ``` +/// extern crate rocket_cors; +/// #[macro_use] +/// extern crate serde_derive; +/// extern crate serde_json; +/// +/// use rocket_cors::*; +/// +/// # fn main() { +/// #[derive(Serialize, Deserialize)] +/// struct Test { +/// allowed_origins: AllowedOrigins +/// } +/// +/// let json = r#"{ "allowed_origins": null }"#; +/// let deserialized: Test = serde_json::from_str(json).unwrap(); +/// # } +/// ``` +/// ## Allow specific origins +/// +/// ```json +/// { "allowed_origins": ["http://127.0.0.1:8000/","https://foobar.com/"] } +/// ``` +/// +/// ``` +/// extern crate rocket_cors; +/// #[macro_use] +/// extern crate serde_derive; +/// extern crate serde_json; +/// +/// use rocket_cors::*; +/// +/// # fn main() { +/// #[derive(Serialize, Deserialize)] +/// struct Test { +/// allowed_origins: AllowedOrigins +/// } +/// +/// let json = r#"{ "allowed_origins": ["http://127.0.0.1:8000/","https://foobar.com/"] }"#; +/// let deserialized: Test = serde_json::from_str(json).unwrap(); +/// # } +#[derive(Clone, Debug, Serialize, Deserialize)] +#[serde(untagged)] +pub enum AllowedOrigins { + /// All origins are allowed. Equivalent to the "*" value. + All, + /// Only origins listed are allowed. + Some(HashSet), +} + +impl Default for AllowedOrigins { + fn default() -> Self { + AllowedOrigins::All + } +} + +impl AllowedOrigins { + /// New `AllowedOrigins` from a list of URL strings. + /// Returns a tuple where the first element is the struct `AllowedOrigins`, + /// and the second element + /// is a map of strings which failed to parse into URLs and their associated parse errors. + pub fn new_from_str_list(urls: &[&str]) -> (Self, HashMap) { + let (ok_set, error_map): (Vec<_>, Vec<_>) = urls.iter() + .map(|s| (s.to_string(), Url::from_str(s))) + .partition(|&(_, ref r)| r.is_ok()); + + let error_map = error_map + .into_iter() + .map(|(s, r)| (s.to_string(), r.unwrap_err())) + .collect(); + + let ok_set = ok_set.into_iter().map(|(_, r)| r.unwrap()).collect(); + + (AllowedOrigins::Some(ok_set), error_map) + } +} + +/// Options to aid in the building of a CORS response during pre-flight or after. +/// See module level documentation for usage examples. +#[derive(Clone, Debug, Default)] +pub struct Options { + /// Origins that are allowed to make requests. + /// Will be verified against the `Origin` request header. + pub allowed_origins: AllowedOrigins, + /// Methods that the clients are allowed to request in. + /// Will be verified against the `Access-Control-Request-Method` request header + /// during pre-flight only. + pub allowed_methods: HashSet, + /// Headers that the clients are allowed to request in. + /// Will be verified against the `Access-Control-Request-Headers` request header + /// during pre-flight only. + pub allowed_headers: HeaderFieldNamesSet, + /// The `Access-Control-Allow-Credentials` response header + pub allow_credentials: bool, + /// The `Access-Control-Expose-Headers` responde header + pub expose_headers: HashSet, + /// The `Access-Control-Max-Age` response header + pub max_age: Option, +} + +impl Options { + /// Construct a preflight response based on the options. Will return an `Err` + /// if any of the preflight checks + /// fail. + pub fn preflight( + &self, + origin: Option, + method: &AccessControlRequestMethod, + headers: Option<&AccessControlRequestHeaders>, + ) -> Result, Error> { + + + match origin { + None => Err(Error::MissingOrigin), + Some(origin) => { + let response = Response::<()>::allowed_origin((), &origin, &self.allowed_origins)? + .allowed_methods(method, self.allowed_methods.clone())?; + + match headers { + Some(headers) => { + self.append(response.allowed_headers(headers, &self.allowed_headers)) + } + None => Ok(response), + } + } + } + } + + /// Respond to a request based on the settings. + /// If the `Origin` is not provided, then this request was not made by a browser and there is no + /// CORS enforcement. + pub fn respond<'r, R: Responder<'r>>( + &self, + responder: R, + origin: Option, + ) -> Result, Error> { + match origin { + None => Ok(Response::::any(responder)), + Some(origin) => { + self.append(Response::::allowed_origin( + responder, + &origin, + &self.allowed_origins, + )) + } + } + } + + fn append<'r, R: Responder<'r>>( + &self, + response: Result, Error>, + ) -> Result, Error> { + Ok( + response? + .credentials(self.allow_credentials) + .exposed_headers( + self.expose_headers + .iter() + .map(|s| &**s) + .collect::>() + .as_slice(), + ) + .max_age(self.max_age), + ) + } +} + +/// A CORS Response which wraps another struct which implements `Responder`. You will typically +/// use [`Options`] instead to verify and build the response instead of this directly. +/// See module level documentation for usage examples. +pub struct Response { + responder: R, + allow_origin: String, + allow_methods: HashSet, + allow_headers: HeaderFieldNamesSet, + allow_credentials: bool, + expose_headers: HeaderFieldNamesSet, + max_age: Option, +} + +impl<'r, R: Responder<'r>> Response { + /// Consumes the responder and origin and returns basic CORS + fn origin(responder: R, origin: &str) -> Self { + Self { + allow_origin: origin.to_string(), + allow_headers: HashSet::new(), + allow_methods: HashSet::new(), + responder: responder, + allow_credentials: false, + expose_headers: HashSet::new(), + max_age: None, + } + } + /// Consumes the responder and based on the provided list of allowed origins, + /// check if the requested origin is allowed. + /// Useful for pre-flight and during requests + pub fn allowed_origin( + responder: R, + origin: &Origin, + allowed_origins: &AllowedOrigins, + ) -> Result { + match *allowed_origins { + AllowedOrigins::All => Ok(Self::any(responder)), + AllowedOrigins::Some(ref allowed_origins) => { + let origin = origin.origin().unicode_serialization(); + + let allowed_origins: HashSet<_> = allowed_origins + .iter() + .map(|o| o.origin().unicode_serialization()) + .collect(); + let _ = allowed_origins.get(&origin).ok_or_else( + || Error::OriginNotAllowed, + )?; + Ok(Self::origin(responder, &origin)) + } + } + } + + /// Consumes responder and returns CORS with any origin + pub fn any(responder: R) -> Self { + Self::origin(responder, "*") + } + + /// Consumes the CORS, set allow_credentials to + /// new value and returns changed CORS + pub fn credentials(mut self, value: bool) -> Self { + self.allow_credentials = value; + self + } + + /// Consumes the CORS, set expose_headers to + /// passed headers and returns changed CORS + pub fn exposed_headers(mut self, headers: &[&str]) -> Self { + self.expose_headers = headers.into_iter().map(|s| s.to_string().into()).collect(); + self + } + + /// Consumes the CORS, set max_age to + /// passed value and returns changed CORS + pub fn max_age(mut self, value: Option) -> Self { + self.max_age = value; + self + } + + /// Consumes the CORS, set allow_methods to + /// passed methods and returns changed CORS + fn methods(mut self, methods: HashSet) -> Self { + self.allow_methods = methods; + self + } + + /// Consumes the CORS, check if requested method is allowed. + /// Useful for pre-flight checks + pub fn allowed_methods( + self, + method: &AccessControlRequestMethod, + allowed_methods: HashSet, + ) -> Result { + let &AccessControlRequestMethod(ref request_method) = method; + if !allowed_methods.iter().any(|m| m == request_method) { + Err(Error::MethodNotAllowed)? + } + Ok(self.methods(allowed_methods)) + } + + /// Consumes the CORS, set allow_headers to + /// passed headers and returns changed CORS + fn headers(mut self, headers: &[&str]) -> Self { + self.allow_headers = headers.into_iter().map(|s| s.to_string().into()).collect(); + self + } + + /// Consumes the CORS, check if requested headersa are allowed. + /// Useful for pre-flight checks + pub fn allowed_headers( + self, + headers: &AccessControlRequestHeaders, + allowed_headers: &HeaderFieldNamesSet, + ) -> Result { + let &AccessControlRequestHeaders(ref headers) = headers; + if !headers.is_empty() && !headers.is_subset(allowed_headers) { + Err(Error::HeadersNotAllowed)? + } + Ok( + self.headers( + allowed_headers + .iter() + .map(|s| &**s.deref()) + .collect::>() + .as_slice(), + ), + ) + } +} + +impl<'r, R: Responder<'r>> Responder<'r> for Response { + fn respond_to(self, request: &Request) -> response::Result<'r> { + let mut response = response::Response::build_from(self.responder.respond_to(request)?) + .raw_header("Access-Control-Allow-Origin", self.allow_origin) + .finalize(); + + if self.allow_credentials { + response.set_raw_header("Access-Control-Allow-Credentials", "true"); + } else { + response.set_raw_header("Access-Control-Allow-Credentials", "false"); + } + + if !self.expose_headers.is_empty() { + let headers: Vec = self.expose_headers + .into_iter() + .map(|s| s.deref().to_string()) + .collect(); + let headers = headers.join(", "); + + response.set_raw_header("Access-Control-Expose-Headers", headers); + } + + if !self.allow_headers.is_empty() { + let headers: Vec = self.allow_headers + .into_iter() + .map(|s| s.deref().to_string()) + .collect(); + let headers = headers.join(", "); + + response.set_raw_header("Access-Control-Allow-Headers", headers); + } + + + if !self.allow_methods.is_empty() { + let methods: Vec<_> = self.allow_methods.into_iter().map(|m| m.as_str()).collect(); + let methods = methods.join(", "); + + response.set_raw_header("Access-Control-Allow-Methods", methods); + } + + if self.max_age.is_some() { + let max_age = self.max_age.unwrap(); + response.set_raw_header("Access-Control-Max-Age", max_age.to_string()); + } + + Ok(response) + } +} + +#[cfg(test)] +#[allow(unmounted_route)] +mod tests { + use std::str::FromStr; + + use hyper; + use rocket; + use rocket::local::Client; + use rocket::http::Method; + use rocket::http::{Header, Status}; + use rocket::State; + + use super::*; + + #[test] + fn origin_header_conversion() { + let url = "https://foo.bar.xyz"; + let _ = not_err!(Origin::from_str(url)); + + let url = "https://foo.bar.xyz/path/somewhere"; // this should never really be used + let _ = not_err!(Origin::from_str(url)); + + let url = "invalid_url"; + let _ = is_err!(Origin::from_str(url)); + } + + #[test] + fn request_method_conversion() { + let method = "POST"; + let parsed_method = not_err!(AccessControlRequestMethod::from_str(method)); + assert_matches!(parsed_method, AccessControlRequestMethod(Method::Post)); + + let method = "options"; + let parsed_method = not_err!(AccessControlRequestMethod::from_str(method)); + assert_matches!(parsed_method, AccessControlRequestMethod(Method::Options)); + + let method = "INVALID"; + let _ = is_err!(AccessControlRequestMethod::from_str(method)); + } + + #[test] + fn request_headers_conversion() { + let headers = ["foo", "bar", "baz"]; + let parsed_headers = not_err!(AccessControlRequestHeaders::from_str(&headers.join(", "))); + let expected_headers: HeaderFieldNamesSet = + headers.iter().map(|s| s.to_string().into()).collect(); + let AccessControlRequestHeaders(actual_headers) = parsed_headers; + assert_eq!(actual_headers, expected_headers); + } + + #[get("/request_headers")] + #[allow(needless_pass_by_value)] + fn request_headers( + origin: Origin, + method: AccessControlRequestMethod, + headers: AccessControlRequestHeaders, + ) -> String { + let AccessControlRequestMethod(method) = method; + let AccessControlRequestHeaders(headers) = headers; + let mut headers = headers + .iter() + .map(|s| s.deref().to_string()) + .collect::>(); + headers.sort(); + format!("{}\n{}\n{}", origin, method, headers.join(", ")) + } + + /// Tests that all the headers are parsed correcly in a HTTP request + #[test] + fn request_headers_round_trip_smoke_test() { + let rocket = rocket::ignite().mount("/", routes![request_headers]); + let client = not_err!(Client::new(rocket)); + + let origin_header = Header::from(not_err!( + hyper::header::Origin::from_str("https://foo.bar.xyz") + )); + let method_header = Header::from(hyper::header::AccessControlRequestMethod( + hyper::method::Method::Get, + )); + let request_headers = hyper::header::AccessControlRequestHeaders(vec![ + FromStr::from_str("accept-language").unwrap(), + FromStr::from_str("X-Ping").unwrap(), + ]); + let request_headers = Header::from(request_headers); + let req = client + .get("/request_headers") + .header(origin_header) + .header(method_header) + .header(request_headers); + let mut response = req.dispatch(); + + assert_eq!(Status::Ok, response.status()); + let body_str = not_none!(response.body().and_then(|body| body.into_string())); + let expected_body = r#"https://foo.bar.xyz/ +GET +X-Ping, accept-language"#; + assert_eq!(expected_body, body_str); + } + + #[get("/any")] + #[cfg_attr(feature = "clippy_lints", allow(needless_pass_by_value))] + fn any() -> Response<&'static str> { + Response::any("Hello, world!") + } + + #[test] + fn response_any_origin_smoke_test() { + let rocket = rocket::ignite().mount("/", routes![any]); + let client = not_err!(Client::new(rocket)); + + let req = client.get("/any"); + let mut response = req.dispatch(); + + assert_eq!(Status::Ok, response.status()); + let body_str = response.body().and_then(|body| body.into_string()); + let values: Vec<_> = response + .headers() + .get("Access-Control-Allow-Origin") + .collect(); + assert_eq!(values, vec!["*"]); + assert_eq!(body_str, Some("Hello, world!".to_string())); + } + + #[options("/")] + #[allow(needless_pass_by_value)] + fn cors_options( + origin: Option, + method: AccessControlRequestMethod, + headers: AccessControlRequestHeaders, + options: State, + ) -> Result, Error> { + options.preflight(origin, &method, Some(&headers)) + } + + #[get("/")] + #[allow(needless_pass_by_value)] + fn cors( + origin: Option, + options: State, + ) -> Result, Error> { + options.respond("Hello CORS", origin) + } + + fn make_cors_options() -> Options { + let (allowed_origins, failed_origins) = + AllowedOrigins::new_from_str_list(&["https://www.acme.com"]); + assert!(failed_origins.is_empty()); + + Options { + allowed_origins: allowed_origins, + allowed_methods: [Method::Get].iter().cloned().collect(), + allowed_headers: ["Authorization"] + .iter() + .map(|s| s.to_string().into()) + .collect(), + allow_credentials: true, + ..Default::default() + } + } + + #[test] + fn cors_options_check() { + let rocket = rocket::ignite() + .mount("/", routes![cors, cors_options]) + .manage(make_cors_options()); + let client = not_err!(Client::new(rocket)); + + let origin_header = Header::from(not_err!( + hyper::header::Origin::from_str("https://www.acme.com") + )); + let method_header = Header::from(hyper::header::AccessControlRequestMethod( + hyper::method::Method::Get, + )); + let request_headers = hyper::header::AccessControlRequestHeaders( + vec![FromStr::from_str("Authorization").unwrap()], + ); + let request_headers = Header::from(request_headers); + let req = client + .options("/") + .header(origin_header) + .header(method_header) + .header(request_headers); + + let response = req.dispatch(); + assert_eq!(response.status(), Status::Ok); + } + + #[test] + fn cors_get_check() { + let rocket = rocket::ignite() + .mount("/", routes![cors, cors_options]) + .manage(make_cors_options()); + let client = not_err!(Client::new(rocket)); + + let origin_header = Header::from(not_err!( + hyper::header::Origin::from_str("https://www.acme.com") + )); + let authorization = Header::new("Authorization", "let me in"); + let req = client.get("/").header(origin_header).header(authorization); + + let mut response = req.dispatch(); + assert_eq!(response.status(), Status::Ok); + let body_str = response.body().and_then(|body| body.into_string()); + assert_eq!(body_str, Some("Hello CORS".to_string())); + } + + /// This test is to check that non CORS compliant requests to GET should still work. (i.e. curl) + #[test] + fn cors_get_no_origin() { + let rocket = rocket::ignite() + .mount("/", routes![cors, cors_options]) + .manage(make_cors_options()); + let client = not_err!(Client::new(rocket)); + + let authorization = Header::new("Authorization", "let me in"); + let req = client.get("/").header(authorization); + + let mut response = req.dispatch(); + assert_eq!(response.status(), Status::Ok); + let body_str = response.body().and_then(|body| body.into_string()); + assert_eq!(body_str, Some("Hello CORS".to_string())); + } + + #[test] + fn cors_options_bad_origin() { + let rocket = rocket::ignite() + .mount("/", routes![cors, cors_options]) + .manage(make_cors_options()); + let client = not_err!(Client::new(rocket)); + + let origin_header = Header::from(not_err!(hyper::header::Origin::from_str( + "https://www.bad-origin.com", + ))); + let method_header = Header::from(hyper::header::AccessControlRequestMethod( + hyper::method::Method::Get, + )); + let request_headers = hyper::header::AccessControlRequestHeaders( + vec![FromStr::from_str("Authorization").unwrap()], + ); + let request_headers = Header::from(request_headers); + let req = client + .options("/") + .header(origin_header) + .header(method_header) + .header(request_headers); + + let response = req.dispatch(); + assert_eq!(response.status(), Status::Forbidden); + } + + #[test] + fn cors_options_missing_origin() { + let rocket = rocket::ignite() + .mount("/", routes![cors, cors_options]) + .manage(make_cors_options()); + let client = not_err!(Client::new(rocket)); + + let method_header = Header::from(hyper::header::AccessControlRequestMethod( + hyper::method::Method::Get, + )); + let request_headers = hyper::header::AccessControlRequestHeaders( + vec![FromStr::from_str("Authorization").unwrap()], + ); + let request_headers = Header::from(request_headers); + let req = client.options("/").header(method_header).header( + request_headers, + ); + + let response = req.dispatch(); + assert_eq!(response.status(), Status::Forbidden); + } + + #[test] + fn cors_options_bad_request_method() { + let rocket = rocket::ignite() + .mount("/", routes![cors, cors_options]) + .manage(make_cors_options()); + let client = not_err!(Client::new(rocket)); + + let origin_header = Header::from(not_err!( + hyper::header::Origin::from_str("https://www.acme.com") + )); + let method_header = Header::from(hyper::header::AccessControlRequestMethod( + hyper::method::Method::Post, + )); + let request_headers = hyper::header::AccessControlRequestHeaders( + vec![FromStr::from_str("Authorization").unwrap()], + ); + let request_headers = Header::from(request_headers); + let req = client + .options("/") + .header(origin_header) + .header(method_header) + .header(request_headers); + + let response = req.dispatch(); + assert_eq!(response.status(), Status::Forbidden); + } + + #[test] + fn cors_options_bad_request_header() { + let rocket = rocket::ignite() + .mount("/", routes![cors, cors_options]) + .manage(make_cors_options()); + let client = not_err!(Client::new(rocket)); + + let origin_header = Header::from(not_err!( + hyper::header::Origin::from_str("https://www.acme.com") + )); + let method_header = Header::from(hyper::header::AccessControlRequestMethod( + hyper::method::Method::Get, + )); + let request_headers = + hyper::header::AccessControlRequestHeaders(vec![FromStr::from_str("Foobar").unwrap()]); + let request_headers = Header::from(request_headers); + let req = client + .options("/") + .header(origin_header) + .header(method_header) + .header(request_headers); + + let response = req.dispatch(); + assert_eq!(response.status(), Status::Forbidden); + } + + #[test] + fn cors_get_bad_origin() { + let rocket = rocket::ignite() + .mount("/", routes![cors, cors_options]) + .manage(make_cors_options()); + let client = not_err!(Client::new(rocket)); + + let origin_header = Header::from(not_err!(hyper::header::Origin::from_str( + "https://www.bad-origin.com", + ))); + let authorization = Header::new("Authorization", "let me in"); + let req = client.get("/").header(origin_header).header(authorization); + + let response = req.dispatch(); + assert_eq!(response.status(), Status::Forbidden); + } +} diff --git a/src/test_macros.rs b/src/test_macros.rs new file mode 100644 index 0000000..ff19c9f --- /dev/null +++ b/src/test_macros.rs @@ -0,0 +1,33 @@ +macro_rules! not_err { + ($e:expr) => (match $e { + Ok(e) => e, + Err(e) => panic!("{} failed with {:?}", stringify!($e), e), + }) +} + +macro_rules! is_err { + ($e:expr) => (match $e { + Ok(e) => panic!("{} did not return with an error, but with {:?}", stringify!($e), e), + Err(e) => e, + }) +} + +macro_rules! not_none { + ($e:expr) => (match $e { + Some(e) => e, + None => panic!("{} failed with None", stringify!($e)), + }) +} + +macro_rules! assert_matches { + ($e: expr, $p: pat) => (assert_matches!($e, $p, ())); + ($e: expr, $p: pat, $f: expr) => (match $e { + $p => $f, + e => { + panic!( + "{}: Expected pattern {} \ndoes not match {:?}", + stringify!($e), stringify!($p), e + ) + } + }) +}