diff --git a/src/bin/dicebot-cmd.rs b/src/bin/dicebot-cmd.rs index f90f178..3f63dd8 100644 --- a/src/bin/dicebot-cmd.rs +++ b/src/bin/dicebot-cmd.rs @@ -4,7 +4,7 @@ use tenebrous_dicebot::commands::ResponseExtractor; use tenebrous_dicebot::context::{Context, RoomContext}; use tenebrous_dicebot::db::sqlite::Database; use tenebrous_dicebot::error::BotError; -use tenebrous_dicebot::models::User; +use tenebrous_dicebot::models::Account; use url::Url; #[tokio::main] @@ -27,7 +27,7 @@ async fn main() -> Result<(), BotError> { let context = Context { db: db, - user: User::default(), + account: Account::default(), matrix_client: &matrix_sdk::Client::new(homeserver) .expect("Could not create matrix client"), room: RoomContext { diff --git a/src/bot/command_execution.rs b/src/bot/command_execution.rs index f7d7df9..9005206 100644 --- a/src/bot/command_execution.rs +++ b/src/bot/command_execution.rs @@ -1,8 +1,8 @@ use crate::commands::{execute_command, ExecutionError, ExecutionResult, ResponseExtractor}; use crate::context::{Context, RoomContext}; use crate::db::sqlite::Database; -use crate::db::Users; use crate::error::BotError; +use crate::logic; use crate::matrix; use futures::stream::{self, StreamExt}; use matrix_sdk::{self, identifiers::EventId, room::Joined, Client}; @@ -78,7 +78,7 @@ async fn create_context<'a>( matrix_client: client, room: room_ctx, username: &sender, - user: db.get_or_create_user(&sender).await?, + account: logic::get_account(db, &sender).await?, message_body: &command, }) } diff --git a/src/cofd/dice.rs b/src/cofd/dice.rs index 33c0b89..976ec51 100644 --- a/src/cofd/dice.rs +++ b/src/cofd/dice.rs @@ -483,7 +483,7 @@ mod tests { .unwrap(); let ctx = Context { - user: crate::models::User::default(), + account: crate::models::Account::default(), db: db, matrix_client: &matrix_sdk::Client::new(homeserver).unwrap(), room: dummy_room!(), @@ -524,7 +524,7 @@ mod tests { let homeserver = Url::parse("http://example.com").unwrap(); let ctx = Context { - user: crate::models::User::default(), + account: crate::models::Account::default(), db: db, matrix_client: &matrix_sdk::Client::new(homeserver).unwrap(), room: dummy_room!(), @@ -562,7 +562,7 @@ mod tests { let homeserver = Url::parse("http://example.com").unwrap(); let ctx = Context { - user: crate::models::User::default(), + account: crate::models::Account::default(), db: db.clone(), matrix_client: &matrix_sdk::Client::new(homeserver).unwrap(), room: dummy_room!(), diff --git a/src/commands/mod.rs b/src/commands/mod.rs index e390f40..52bd5c4 100644 --- a/src/commands/mod.rs +++ b/src/commands/mod.rs @@ -201,7 +201,7 @@ mod tests { let homeserver = Url::parse("http://example.com").unwrap(); let ctx = Context { - user: crate::models::User::default(), + account: crate::models::Account::default(), db: db, matrix_client: &matrix_sdk::Client::new(homeserver).unwrap(), room: secure_room!(), @@ -223,7 +223,7 @@ mod tests { let homeserver = Url::parse("http://example.com").unwrap(); let ctx = Context { - user: crate::models::User::default(), + account: crate::models::Account::default(), db: db, matrix_client: &matrix_sdk::Client::new(homeserver).unwrap(), room: secure_room!(), @@ -245,7 +245,7 @@ mod tests { let homeserver = Url::parse("http://example.com").unwrap(); let ctx = Context { - user: crate::models::User::default(), + account: crate::models::Account::default(), db: db, matrix_client: &matrix_sdk::Client::new(homeserver).unwrap(), room: dummy_room!(), @@ -267,7 +267,7 @@ mod tests { let homeserver = Url::parse("http://example.com").unwrap(); let ctx = Context { - user: crate::models::User::default(), + account: crate::models::Account::default(), db: db, matrix_client: &matrix_sdk::Client::new(homeserver).unwrap(), room: dummy_room!(), @@ -298,7 +298,7 @@ mod tests { let homeserver = Url::parse("http://example.com").unwrap(); let ctx = Context { - user: crate::models::User::default(), + account: crate::models::Account::default(), db: db, matrix_client: &matrix_sdk::Client::new(homeserver).unwrap(), room: dummy_room!(), diff --git a/src/context.rs b/src/context.rs index 86f40d4..7cd5837 100644 --- a/src/context.rs +++ b/src/context.rs @@ -1,6 +1,6 @@ use crate::db::sqlite::Database; use crate::error::BotError; -use crate::models::User; +use crate::models::{Account, User}; use matrix_sdk::identifiers::{RoomId, UserId}; use matrix_sdk::room::Joined; use matrix_sdk::Client; @@ -15,7 +15,7 @@ pub struct Context<'a> { pub room: RoomContext<'a>, pub username: &'a str, pub message_body: &'a str, - pub user: User, + pub account: Account, } impl Context<'_> { diff --git a/src/cthulhu/dice.rs b/src/cthulhu/dice.rs index 724fcdc..73ad1fb 100644 --- a/src/cthulhu/dice.rs +++ b/src/cthulhu/dice.rs @@ -504,7 +504,7 @@ mod tests { let homeserver = Url::parse("http://example.com").unwrap(); let ctx = Context { - user: crate::models::User::default(), + account: crate::models::Account::default(), db: db, matrix_client: &matrix_sdk::Client::new(homeserver).unwrap(), room: dummy_room!(), @@ -541,7 +541,7 @@ mod tests { let homeserver = Url::parse("http://example.com").unwrap(); let ctx = Context { - user: crate::models::User::default(), + account: crate::models::Account::default(), db: db, matrix_client: &matrix_sdk::Client::new(homeserver).unwrap(), room: dummy_room!(), @@ -578,7 +578,7 @@ mod tests { let homeserver = Url::parse("http://example.com").unwrap(); let ctx = Context { - user: crate::models::User::default(), + account: crate::models::Account::default(), db: db, matrix_client: &matrix_sdk::Client::new(homeserver).unwrap(), room: dummy_room!(), diff --git a/src/db/mod.rs b/src/db/mod.rs index 789ec45..f725a56 100644 --- a/src/db/mod.rs +++ b/src/db/mod.rs @@ -16,8 +16,6 @@ pub(crate) trait DbState { #[async_trait] pub(crate) trait Users { - async fn get_or_create_user(&self, username: &str) -> Result; - async fn upsert_user(&self, user: &User) -> Result<(), DataError>; async fn get_user(&self, username: &str) -> Result, DataError>; diff --git a/src/db/sqlite/users.rs b/src/db/sqlite/users.rs index b619c4b..ac6e819 100644 --- a/src/db/sqlite/users.rs +++ b/src/db/sqlite/users.rs @@ -76,21 +76,6 @@ impl Users for Database { Ok(user_row) } - //TODO should this logic be moved further up into logic.rs maybe? - async fn get_or_create_user(&self, username: &str) -> Result { - let maybe_user = self.get_user(username).await?; - - match maybe_user { - Some(user) => Ok(user), - None => { - info!("Creating unregistered account for {}", username); - let user = User::unregistered(&username); - self.upsert_user(&user).await?; - Ok(user) - } - } - } - async fn authenticate_user( &self, username: &str, @@ -119,48 +104,6 @@ mod tests { .unwrap() } - #[tokio::test(flavor = "multi_thread", worker_threads = 1)] - async fn get_or_create_user_no_user_exists() { - let db = create_db().await; - - let user = db - .get_or_create_user("@test:example.com") - .await - .expect("User creation didn't work."); - - assert_eq!(user.username, "@test:example.com"); - - let user_again = db - .get_user("@test:example.com") - .await - .expect("User retrieval didn't work.") - .expect("No user returned from option."); - - assert_eq!(user, user_again); - } - - #[tokio::test(flavor = "multi_thread", worker_threads = 1)] - async fn get_or_create_user_when_user_exists() { - let db = create_db().await; - - let user = User { - username: "myuser".to_string(), - password: Some("abc".to_string()), - account_status: AccountStatus::Registered, - active_room: Some("myroom".to_string()), - }; - - let insert_result = db.upsert_user(&user).await; - assert!(insert_result.is_ok()); - - let user_again = db - .get_or_create_user("myuser") - .await - .expect("User retrieval didn't work."); - - assert_eq!(user, user_again); - } - #[tokio::test(flavor = "multi_thread", worker_threads = 1)] async fn create_and_get_full_user_test() { let db = create_db().await; diff --git a/src/logic.rs b/src/logic.rs index a0386ca..c9e5c5d 100644 --- a/src/logic.rs +++ b/src/logic.rs @@ -1,7 +1,10 @@ -use crate::context::Context; -use crate::db::Variables; use crate::error::{BotError, DiceRollingError}; use crate::parser::dice::{Amount, Element}; +use crate::{context::Context, models::Account}; +use crate::{ + db::{sqlite::Database, Users, Variables}, + models::TransientUser, +}; use argon2::{self, Config, Error as ArgonError}; use futures::stream::{self, StreamExt, TryStreamExt}; use rand::Rng; @@ -50,3 +53,71 @@ pub(crate) fn hash_password(raw_password: &str) -> Result { let config = Config::default(); argon2::hash_encoded(raw_password.as_bytes(), &salt, &config) } + +pub(crate) async fn get_account(db: &Database, username: &str) -> Result { + Ok(db + .get_user(username) + .await? + .map(|user| Account::Registered(user)) + .unwrap_or_else(|| { + Account::Transient(TransientUser { + username: username.to_owned(), + }) + })) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::db::Users; + use crate::models::{AccountStatus, User}; + + async fn create_db() -> Database { + let db_path = tempfile::NamedTempFile::new_in(".").unwrap(); + crate::db::sqlite::migrator::migrate(db_path.path().to_str().unwrap()) + .await + .unwrap(); + + Database::new(db_path.path().to_str().unwrap()) + .await + .unwrap() + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 1)] + async fn get_account_no_user_exists() { + let db = create_db().await; + + let account = get_account(&db, "@test:example.com") + .await + .expect("Account retrieval didn't work"); + + assert!(matches!(account, Account::Transient(_))); + + let user = account.transient_user().unwrap(); + assert_eq!(user.username, "@test:example.com"); + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 1)] + async fn get_or_create_user_when_user_exists() { + let db = create_db().await; + + let user = User { + username: "myuser".to_string(), + password: Some("abc".to_string()), + account_status: AccountStatus::Registered, + active_room: Some("myroom".to_string()), + }; + + let insert_result = db.upsert_user(&user).await; + assert!(insert_result.is_ok()); + + let account = get_account(&db, "myuser") + .await + .expect("Account retrieval did not work"); + + assert!(matches!(account, Account::Registered(_))); + + let user_again = account.registered_user().unwrap(); + assert_eq!(user, user_again); + } +} diff --git a/src/models.rs b/src/models.rs index a528e27..2da71dc 100644 --- a/src/models.rs +++ b/src/models.rs @@ -10,9 +10,9 @@ pub struct RoomInfo { #[derive(Eq, PartialEq, Clone, Copy, Debug, sqlx::Type)] #[sqlx(rename_all = "snake_case")] pub enum AccountStatus { - /// User is not registered, which means the "account" only exists - /// for state management in the bot. No privileged actions - /// possible. + /// Account is not registered, which means a transient "account" + /// with limited information exists only for the duration of the + /// command request. NotRegistered, /// User account is fully registered, either via Matrix directly, @@ -30,6 +30,62 @@ impl Default for AccountStatus { } } +#[derive(Clone, Debug, Eq, PartialEq)] +pub enum Account { + /// A registered user account, stored in the database. + Registered(User), + + /// A transient account. Not stored in the database. Represents a + /// user in a public channel that has not registered directly with + /// the bot yet. + Transient(TransientUser), +} + +impl Account { + /// Gets the account status. For registered users, this is their + /// actual account status (fully registered or awaiting + /// activation). For transient users, this is + /// AccountStatus::NotRegistered. + pub fn account_status(&self) -> AccountStatus { + match self { + Self::Registered(user) => user.account_status, + Self::Transient(_) => AccountStatus::NotRegistered, + } + } + + /// Consume self into an Option instance, which will be Some + /// if this account has a registered user, and None otherwise. + pub fn registered_user(self) -> Option { + match self { + Self::Registered(user) => Some(user), + _ => None, + } + } + + /// Consume self into an Option instance, which + /// will be Some if this account has a non-registered user, and + /// None otherwise. + pub fn transient_user(self) -> Option { + match self { + Self::Transient(user) => Some(user), + _ => None, + } + } +} + +impl Default for Account { + fn default() -> Self { + Account::Transient(TransientUser { + username: "".to_string(), + }) + } +} + +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct TransientUser { + pub username: String, +} + #[derive(Eq, PartialEq, Clone, Debug, Default, sqlx::FromRow)] pub struct User { pub username: String,