diff --git a/Cargo.lock b/Cargo.lock index 3ec633a..6417100 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -101,6 +101,12 @@ dependencies = [ "winapi", ] +[[package]] +name = "arrayref" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a4c527152e37cf757a3f78aae5a06fbeefdb07ccc535c980a3208ee3060dd544" + [[package]] name = "arrayvec" version = "0.5.2" @@ -198,6 +204,17 @@ dependencies = [ "wyz", ] +[[package]] +name = "blake2b_simd" +version = "0.5.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "afa748e348ad3be8263be728124b24a24f268266f6f5d58af9d75f6a40b5c587" +dependencies = [ + "arrayref", + "arrayvec", + "constant_time_eq", +] + [[package]] name = "block-buffer" version = "0.9.0" @@ -313,6 +330,12 @@ version = "0.4.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f92cfa0fd5690b3cf8c1ef2cabbd9b7ef22fa53cf5e1f92b05103f6d5d1cf6e7" +[[package]] +name = "constant_time_eq" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "245097e9a4535ee1e3e3931fcfcd55a796a44c643e8596ff6566d68f09b87bbc" + [[package]] name = "core-foundation" version = "0.9.1" @@ -2025,6 +2048,18 @@ dependencies = [ "smallvec", ] +[[package]] +name = "rust-argon2" +version = "0.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4b18820d944b33caa75a71378964ac46f58517c92b6ae5f762636247c09e78fb" +dependencies = [ + "base64", + "blake2b_simd", + "constant_time_eq", + "crossbeam-utils", +] + [[package]] name = "rustc_version" version = "0.2.3" @@ -2510,6 +2545,7 @@ dependencies = [ "phf", "rand 0.8.3", "refinery", + "rust-argon2", "serde", "sqlx", "tempfile", diff --git a/Cargo.toml b/Cargo.toml index 7d5bacc..ff312b4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,6 +16,7 @@ tracing-subscriber = "0.2" toml = "0.5" nom = "5" rand = "0.8" +rust-argon2 = "0.8" thiserror = "1.0" itertools = "0.10" async-trait = "0.1" diff --git a/src/commands/management.rs b/src/commands/management.rs index f4396d2..aaf8dc5 100644 --- a/src/commands/management.rs +++ b/src/commands/management.rs @@ -1,6 +1,9 @@ use super::{Command, Execution, ExecutionResult}; use crate::context::Context; -use crate::logic::record_room_information; +use crate::db::Users; +use crate::error::BotError::PasswordCreationError; +use crate::logic::{hash_password, record_room_information}; +use crate::models::User; use async_trait::async_trait; use matrix_sdk::identifiers::UserId; @@ -47,6 +50,13 @@ impl Command for RegisterCommand { } async fn execute(&self, ctx: &Context<'_>) -> ExecutionResult { - Execution::success("User account registered".to_string()) + let pw_hash = hash_password(&self.0).map_err(|e| PasswordCreationError(e))?; + let user = User { + username: ctx.username.to_owned(), + password: pw_hash, + }; + + ctx.db.upsert_user(&user).await?; + Execution::success("User account registered/updated".to_string()) } } diff --git a/src/db/mod.rs b/src/db/mod.rs index c2a107e..6e916f1 100644 --- a/src/db/mod.rs +++ b/src/db/mod.rs @@ -1,3 +1,5 @@ +use crate::error::BotError; +use crate::models::User; use async_trait::async_trait; use errors::DataError; use std::collections::{HashMap, HashSet}; @@ -14,6 +16,19 @@ pub(crate) trait DbState { async fn set_device_id(&self, device_id: &str) -> Result<(), DataError>; } +#[async_trait] +pub(crate) trait Users { + async fn upsert_user(&self, user: &User) -> Result<(), DataError>; + + async fn get_user(&self, username: &str) -> Result, DataError>; + + async fn authenticate_user( + &self, + username: &str, + raw_password: &str, + ) -> Result, BotError>; +} + #[async_trait] pub(crate) trait Rooms { async fn should_process(&self, room_id: &str, event_id: &str) -> Result; diff --git a/src/db/sqlite/mod.rs b/src/db/sqlite/mod.rs index 68f04c2..13df2c7 100644 --- a/src/db/sqlite/mod.rs +++ b/src/db/sqlite/mod.rs @@ -7,6 +7,7 @@ use std::str::FromStr; pub mod migrator; pub mod rooms; pub mod state; +pub mod users; pub mod variables; pub struct Database { diff --git a/src/db/sqlite/users.rs b/src/db/sqlite/users.rs new file mode 100644 index 0000000..78c6793 --- /dev/null +++ b/src/db/sqlite/users.rs @@ -0,0 +1,102 @@ +use super::Database; +use crate::db::{errors::DataError, Users}; +use crate::error::BotError; +use crate::models::User; +use async_trait::async_trait; + +#[async_trait] +impl Users for Database { + async fn upsert_user(&self, user: &User) -> Result<(), DataError> { + sqlx::query( + r#"INSERT INTO accounts (user_id, password) VALUES (?, ?) + ON CONFLICT(user_id) DO UPDATE SET password = ?"#, + ) + .bind(&user.username) + .bind(&user.password) + .bind(&user.password) + .execute(&self.conn) + .await?; + + Ok(()) + } + + async fn get_user(&self, username: &str) -> Result, DataError> { + let user_row = sqlx::query!( + r#"SELECT user_id, password FROM accounts + WHERE user_id = ?"#, + username + ) + .fetch_optional(&self.conn) + .await?; + + Ok(user_row.map(|u| User { + username: u.user_id, + password: u.password, + })) + } + + async fn authenticate_user( + &self, + username: &str, + raw_password: &str, + ) -> Result, BotError> { + let user = self.get_user(username).await?; + Ok(user.filter(|u| u.verify_password(raw_password))) + } +} + +#[cfg(test)] +mod tests { + use crate::db::sqlite::Database; + use crate::db::DbState; + + async fn create_db() -> Database { + let db_path = tempfile::NamedTempFile::new_in(".").unwrap(); + crate::db::sqlite::migrator::migrate(db_path.path().to_str().unwrap()) + .await + .unwrap(); + + Database::new(db_path.path().to_str().unwrap()) + .await + .unwrap() + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 1)] + async fn set_and_get_device_id() { + let db = create_db().await; + + db.set_device_id("device_id") + .await + .expect("Could not set device ID"); + + let device_id = db.get_device_id().await.expect("Could not get device ID"); + + assert!(device_id.is_some()); + assert_eq!(device_id.unwrap(), "device_id"); + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 1)] + async fn no_device_id_set_returns_none() { + let db = create_db().await; + let device_id = db.get_device_id().await.expect("Could not get device ID"); + assert!(device_id.is_none()); + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 1)] + async fn can_update_device_id() { + let db = create_db().await; + + db.set_device_id("device_id") + .await + .expect("Could not set device ID"); + + db.set_device_id("device_id2") + .await + .expect("Could not set device ID"); + + let device_id = db.get_device_id().await.expect("Could not get device ID"); + + assert!(device_id.is_some()); + assert_eq!(device_id.unwrap(), "device_id2"); + } +} diff --git a/src/error.rs b/src/error.rs index 505ac3c..6ed0be3 100644 --- a/src/error.rs +++ b/src/error.rs @@ -78,6 +78,9 @@ pub enum BotError { #[error("identifier error: {0}")] IdentifierError(#[from] matrix_sdk::identifiers::Error), + + #[error("password creation error: {0}")] + PasswordCreationError(argon2::Error), } #[derive(Error, Debug)] diff --git a/src/logic.rs b/src/logic.rs index 953e8ce..8ce0156 100644 --- a/src/logic.rs +++ b/src/logic.rs @@ -4,8 +4,10 @@ use crate::error::{BotError, DiceRollingError}; use crate::matrix; use crate::models::RoomInfo; use crate::parser::dice::{Amount, Element}; +use argon2::{self, Config, Error as ArgonError}; use futures::stream::{self, StreamExt, TryStreamExt}; use matrix_sdk::{self, identifiers::RoomId, Client}; +use rand::Rng; use std::slice; /// Record the information about a room, including users in it. @@ -86,3 +88,10 @@ pub async fn calculate_dice_amount(amounts: &[Amount], ctx: &Context<'_>) -> Res Ok(dice_amount) } + +/// Hash a password using the argon2 algorithm with a 16 byte salt. +pub(crate) fn hash_password(raw_password: &str) -> Result { + let salt = rand::thread_rng().gen::<[u8; 16]>(); + let config = Config::default(); + argon2::hash_encoded(raw_password.as_bytes(), &salt, &config) +} diff --git a/src/models.rs b/src/models.rs index 83802cb..20f7f98 100644 --- a/src/models.rs +++ b/src/models.rs @@ -6,3 +6,15 @@ pub struct RoomInfo { pub room_id: String, pub room_name: String, } + +#[derive(Eq, PartialEq, Debug)] +pub struct User { + pub username: String, + pub password: String, +} + +impl User { + pub fn verify_password(&self, raw_password: &str) -> bool { + argon2::verify_encoded(&self.password, raw_password.as_bytes()).unwrap_or(false) + } +}