diff --git a/.gitignore b/.gitignore index 5030d40..6cb4892 100644 --- a/.gitignore +++ b/.gitignore @@ -11,3 +11,4 @@ bot-db* bigboy .#* *.sqlite +.tmp* diff --git a/Cargo.lock b/Cargo.lock index 7e0aea0..3c6318c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2512,6 +2512,7 @@ dependencies = [ "serde", "sled", "sqlx", + "tempfile", "thiserror", "tokio", "toml", diff --git a/Cargo.toml b/Cargo.toml index acb3a22..93eb562 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -50,3 +50,6 @@ features = ['derive'] [dependencies.tokio] version = "1" features = [ "full" ] + +[dev-dependencies] +tempfile = "3" \ No newline at end of file diff --git a/src/bin/dicebot-cmd.rs b/src/bin/dicebot-cmd.rs index 79ce2b9..b80fc12 100644 --- a/src/bin/dicebot-cmd.rs +++ b/src/bin/dicebot-cmd.rs @@ -2,7 +2,7 @@ use matrix_sdk::identifiers::room_id; use tenebrous_dicebot::commands; use tenebrous_dicebot::commands::ResponseExtractor; use tenebrous_dicebot::context::{Context, RoomContext}; -use tenebrous_dicebot::db::Database; +use tenebrous_dicebot::db::sqlite::Database; use tenebrous_dicebot::error::BotError; use url::Url; @@ -17,7 +17,7 @@ async fn main() -> Result<(), BotError> { let homeserver = Url::parse("http://example.com")?; let context = Context { - db: Database::new_temp()?, + db: Database::new_temp().await?, matrix_client: &matrix_sdk::Client::new(homeserver) .expect("Could not create matrix client"), room: RoomContext { diff --git a/src/bin/dicebot.rs b/src/bin/dicebot.rs index 8ed69ff..d4ee65f 100644 --- a/src/bin/dicebot.rs +++ b/src/bin/dicebot.rs @@ -5,7 +5,7 @@ use log::error; use std::sync::{Arc, RwLock}; use tenebrous_dicebot::bot::DiceBot; use tenebrous_dicebot::config::*; -use tenebrous_dicebot::db::Database; +use tenebrous_dicebot::db::sqlite::Database; use tenebrous_dicebot::error::BotError; use tenebrous_dicebot::migrator; use tenebrous_dicebot::state::DiceBotState; @@ -29,12 +29,10 @@ async fn run() -> Result<(), BotError> { .expect("Need a config as an argument"); let cfg = Arc::new(read_config(config_path)?); - let db = Database::new(&cfg.database_path())?; + let sqlite_path = format!("{}/dicebot.sqlite", cfg.database_path()); + let db = Database::new(&sqlite_path).await?; let state = Arc::new(RwLock::new(DiceBotState::new(&cfg))); - db.migrate(cfg.migration_version())?; - - let sqlite_path = format!("{}/dicebot.sqlite", cfg.database_path()); migrator::migrate(&sqlite_path).await?; match DiceBot::new(&cfg, &state, &db) { diff --git a/src/bot.rs b/src/bot.rs index b3d8531..d7e11d1 100644 --- a/src/bot.rs +++ b/src/bot.rs @@ -1,7 +1,8 @@ use crate::commands::{execute_command, ExecutionError, ExecutionResult, ResponseExtractor}; use crate::config::*; use crate::context::{Context, RoomContext}; -use crate::db::Database; +use crate::db::sqlite::Database; +use crate::db::sqlite::DbState; use crate::error::BotError; use crate::matrix; use crate::state::DiceBotState; @@ -134,7 +135,7 @@ impl DiceBot { // Pull device ID from database, if it exists. Then write it // to DB if the library generated one for us. - let device_id: Option = self.db.state.get_device_id()?; + let device_id: Option = self.db.get_device_id().await?; let device_id: Option<&str> = device_id.as_deref(); client @@ -143,7 +144,7 @@ impl DiceBot { if device_id.is_none() { let device_id = client.device_id().await.ok_or(BotError::NoDeviceIdFound)?; - self.db.state.set_device_id(device_id.as_str())?; + self.db.set_device_id(device_id.as_str()).await?; info!("Recorded new device ID: {}", device_id.as_str()); } else { info!("Using existing device ID: {}", device_id.unwrap()); diff --git a/src/bot/event_handlers.rs b/src/bot/event_handlers.rs index 98f4072..7ceb1b5 100644 --- a/src/bot/event_handlers.rs +++ b/src/bot/event_handlers.rs @@ -4,7 +4,8 @@ * SDK example code. */ use super::DiceBot; -use crate::db::Database; +use crate::db::sqlite::Database; +use crate::db::sqlite::Rooms; use crate::error::BotError; use crate::logic::record_room_information; use async_trait::async_trait; @@ -90,9 +91,9 @@ async fn should_process_message<'a>( Ok((msg_body, sender_username)) } -fn should_process_event(db: &Database, room_id: &str, event_id: &str) -> bool { - db.rooms - .should_process(room_id, event_id) +async fn should_process_event(db: &Database, room_id: &str, event_id: &str) -> bool { + db.should_process(room_id, event_id) + .await .unwrap_or_else(|e| { error!( "Database error when checking if we should process an event: {}", @@ -116,7 +117,7 @@ impl EventHandler for DiceBot { let room_id_str = room_id.as_str(); let username = &event.state_key; - if !should_process_event(&self.db, room_id_str, event.event_id.as_str()) { + if !should_process_event(&self.db, room_id_str, event.event_id.as_str()).await { return; } @@ -135,7 +136,7 @@ impl EventHandler for DiceBot { let result = if event_affects_us && !adding_user { info!("Clearing all information for room ID {}", room_id); - self.db.rooms.clear_info(room_id_str) + self.db.clear_info(room_id_str).await } else if event_affects_us && adding_user { info!("Joined room {}; recording room information", room_id); record_room_information( @@ -148,10 +149,10 @@ impl EventHandler for DiceBot { .await } else if !event_affects_us && adding_user { info!("Adding user {} to room ID {}", username, room_id); - self.db.rooms.add_user_to_room(username, room_id_str) + self.db.add_user_to_room(username, room_id_str).await } else if !event_affects_us && !adding_user { info!("Removing user {} from room ID {}", username, room_id); - self.db.rooms.remove_user_from_room(username, room_id_str) + self.db.remove_user_from_room(username, room_id_str).await } else { debug!("Ignoring a room member event: {:#?}", event); Ok(()) @@ -196,7 +197,7 @@ impl EventHandler for DiceBot { }; let room_id = room.room_id().as_str(); - if !should_process_event(&self.db, room_id, event.event_id.as_str()) { + if !should_process_event(&self.db, room_id, event.event_id.as_str()).await { return; } diff --git a/src/cofd/dice.rs b/src/cofd/dice.rs index 5cf86d2..f269b1c 100644 --- a/src/cofd/dice.rs +++ b/src/cofd/dice.rs @@ -325,7 +325,8 @@ pub async fn roll_pool(pool: &DicePoolWithContext<'_>) -> Result for ExecutionError { fn from(error: crate::db::errors::DataError) -> Self { - Self(BotError::DataError(error)) + Self(DataError(error)) + } +} + +impl From for ExecutionError { + fn from(error: crate::db::sqlite::errors::DataError) -> Self { + Self(SqliteDataError(error)) } } @@ -129,9 +136,9 @@ mod tests { )); } - #[tokio::test] + #[tokio::test(flavor = "multi_thread", worker_threads = 1)] async fn unrecognized_command() { - let db = crate::db::Database::new_temp().unwrap(); + let db = crate::db::sqlite::Database::new_temp().await.unwrap(); let homeserver = Url::parse("http://example.com").unwrap(); let ctx = Context { db: db, diff --git a/src/commands/variables.rs b/src/commands/variables.rs index 326041d..c03ca70 100644 --- a/src/commands/variables.rs +++ b/src/commands/variables.rs @@ -1,6 +1,7 @@ use super::{Command, Execution, ExecutionResult}; use crate::context::Context; -use crate::db::errors::DataError; +use crate::db::sqlite::errors::DataError; +use crate::db::sqlite::Variables; use crate::db::variables::UserAndRoom; use async_trait::async_trait; @@ -13,8 +14,10 @@ impl Command for GetAllVariablesCommand { } async fn execute(&self, ctx: &Context<'_>) -> ExecutionResult { - let key = UserAndRoom(&ctx.username, &ctx.room_id().as_str()); - let variables = ctx.db.variables.get_user_variables(&key)?; + let variables = ctx + .db + .get_user_variables(&ctx.username, ctx.room_id().as_str()) + .await?; let mut variable_list: Vec = variables .into_iter() @@ -43,8 +46,10 @@ impl Command for GetVariableCommand { async fn execute(&self, ctx: &Context<'_>) -> ExecutionResult { let name = &self.0; - let key = UserAndRoom(&ctx.username, &ctx.room_id().as_str()); - let result = ctx.db.variables.get_user_variable(&key, name); + let result = ctx + .db + .get_user_variable(&ctx.username, ctx.room_id().as_str(), name) + .await; let value = match result { Ok(num) => format!("{} = {}", name, num), @@ -68,9 +73,10 @@ impl Command for SetVariableCommand { async fn execute(&self, ctx: &Context<'_>) -> ExecutionResult { let name = &self.0; let value = self.1; - let key = UserAndRoom(&ctx.username, ctx.room_id().as_str()); - ctx.db.variables.set_user_variable(&key, name, value)?; + ctx.db + .set_user_variable(&ctx.username, ctx.room_id().as_str(), name, value) + .await?; let content = format!("{} = {}", name, value); let html = format!("Set Variable: {}", content); @@ -88,8 +94,10 @@ impl Command for DeleteVariableCommand { async fn execute(&self, ctx: &Context<'_>) -> ExecutionResult { let name = &self.0; - let key = UserAndRoom(&ctx.username, ctx.room_id().as_str()); - let result = ctx.db.variables.delete_user_variable(&key, name); + let result = ctx + .db + .delete_user_variable(&ctx.username, ctx.room_id().as_str(), name) + .await; let value = match result { Ok(()) => format!("{} now unset", name), diff --git a/src/context.rs b/src/context.rs index e971b39..04f2b40 100644 --- a/src/context.rs +++ b/src/context.rs @@ -1,4 +1,4 @@ -use crate::db::Database; +use crate::db::sqlite::Database; use matrix_sdk::identifiers::RoomId; use matrix_sdk::room::Joined; use matrix_sdk::Client; diff --git a/src/cthulhu/dice.rs b/src/cthulhu/dice.rs index e13cf84..95a0469 100644 --- a/src/cthulhu/dice.rs +++ b/src/cthulhu/dice.rs @@ -1,7 +1,11 @@ +use crate::db::sqlite::Variables; use crate::error::{BotError, DiceRollingError}; use crate::parser::{Amount, Element}; use crate::{context::Context, db::variables::UserAndRoom}; use crate::{dice::calculate_single_die_amount, parser::DiceParsingError}; +use rand::rngs::StdRng; +use rand::Rng; +use rand::SeedableRng; use std::convert::TryFrom; use std::fmt; @@ -270,10 +274,11 @@ macro_rules! is_variable { }; } -///A version of DieRoller that uses a rand::Rng to roll numbers. -struct RngDieRoller(R); +/// A die roller than can have an RNG implementation injected, but +/// must be thread-safe. Required for the async dice rolling code. +struct RngDieRoller(R); -impl DieRoller for RngDieRoller { +impl DieRoller for RngDieRoller { fn roll(&mut self) -> u32 { self.0.gen_range(0..=9) } @@ -361,7 +366,7 @@ pub async fn regular_roll( let target = calculate_single_die_amount(&roll_with_ctx.0.amount, roll_with_ctx.1).await?; let target = u32::try_from(target).map_err(|_| DiceRollingError::InvalidAmount)?; - let mut roller = RngDieRoller(rand::thread_rng()); + let mut roller = RngDieRoller::(SeedableRng::from_entropy()); let rolled_dice = roll_regular_dice(&roll_with_ctx.0.modifier, target, &mut roller); Ok(ExecutedDiceRoll { @@ -371,11 +376,12 @@ pub async fn regular_roll( }) } -fn update_skill(ctx: &Context, variable: &str, value: u32) -> Result<(), BotError> { +async fn update_skill(ctx: &Context<'_>, variable: &str, value: u32) -> Result<(), BotError> { use std::convert::TryInto; let value: i32 = value.try_into()?; - let key = UserAndRoom(ctx.username, ctx.room_id().as_str()); - ctx.db.variables.set_user_variable(&key, variable, value)?; + ctx.db + .set_user_variable(&ctx.username, &ctx.room_id().as_str(), variable, value) + .await?; Ok(()) } @@ -397,12 +403,14 @@ pub async fn advancement_roll( return Err(DiceRollingError::InvalidAmount.into()); } - let mut roller = RngDieRoller(rand::thread_rng()); + let mut roller = RngDieRoller::(SeedableRng::from_entropy()); let roll = roll_advancement_dice(target, &mut roller); + drop(roller); + if roll.successful && is_variable!(existing_skill) { let variable_name: &str = extract_variable(existing_skill)?; - update_skill(roll_with_ctx.1, variable_name, roll.new_skill_amount())?; + update_skill(roll_with_ctx.1, variable_name, roll.new_skill_amount()).await?; } Ok(ExecutedAdvancementRoll { target, roll }) @@ -411,7 +419,7 @@ pub async fn advancement_roll( #[cfg(test)] mod tests { use super::*; - use crate::db::Database; + use crate::db::sqlite::Database; use crate::parser::{Amount, Element, Operator}; use url::Url; @@ -474,7 +482,7 @@ mod tests { assert!(matches!(result, Err(DiceParsingError::WrongElementType))); } - #[tokio::test] + #[tokio::test(flavor = "multi_thread", worker_threads = 1)] async fn regular_roll_rejects_negative_numbers() { let roll = DiceRoll { amount: Amount { @@ -484,7 +492,15 @@ mod tests { modifier: DiceRollModifier::Normal, }; - let db = Database::new_temp().unwrap(); + let db_path = tempfile::NamedTempFile::new_in(".").unwrap(); + crate::migrator::migrate(db_path.path().to_str().unwrap()) + .await + .unwrap(); + + let db = Database::new(db_path.path().to_str().unwrap()) + .await + .unwrap(); + let homeserver = Url::parse("http://example.com").unwrap(); let ctx = Context { db: db, @@ -503,7 +519,7 @@ mod tests { )); } - #[tokio::test] + #[tokio::test(flavor = "multi_thread", worker_threads = 1)] async fn advancement_roll_rejects_negative_numbers() { let roll = AdvancementRoll { existing_skill: Amount { @@ -512,7 +528,15 @@ mod tests { }, }; - let db = Database::new_temp().unwrap(); + let db_path = tempfile::NamedTempFile::new_in(".").unwrap(); + crate::migrator::migrate(db_path.path().to_str().unwrap()) + .await + .unwrap(); + + let db = Database::new(db_path.path().to_str().unwrap()) + .await + .unwrap(); + let homeserver = Url::parse("http://example.com").unwrap(); let ctx = Context { db: db, @@ -531,7 +555,7 @@ mod tests { )); } - #[tokio::test] + #[tokio::test(flavor = "multi_thread", worker_threads = 1)] async fn advancement_roll_rejects_big_numbers() { let roll = AdvancementRoll { existing_skill: Amount { @@ -540,7 +564,15 @@ mod tests { }, }; - let db = Database::new_temp().unwrap(); + let db_path = tempfile::NamedTempFile::new_in(".").unwrap(); + crate::migrator::migrate(db_path.path().to_str().unwrap()) + .await + .unwrap(); + + let db = Database::new(db_path.path().to_str().unwrap()) + .await + .unwrap(); + let homeserver = Url::parse("http://example.com").unwrap(); let ctx = Context { db: db, diff --git a/src/db/sqlite/mod.rs b/src/db/sqlite/mod.rs index ac49130..426e2eb 100644 --- a/src/db/sqlite/mod.rs +++ b/src/db/sqlite/mod.rs @@ -2,14 +2,43 @@ use async_trait::async_trait; use errors::DataError; use sqlx::sqlite::{SqliteConnectOptions, SqlitePool, SqlitePoolOptions}; use sqlx::ConnectOptions; -use sqlx::Connection; -use std::collections::HashMap; -use std::path::Path; +use std::clone::Clone; +use std::collections::{HashMap, HashSet}; use std::str::FromStr; +use crate::models::RoomInfo; + pub mod errors; +pub mod rooms; +pub mod state; pub mod variables; +#[async_trait] +pub(crate) trait DbState { + async fn get_device_id(&self) -> Result, DataError>; + + async fn set_device_id(&self, device_id: &str) -> Result<(), DataError>; +} + +#[async_trait] +pub(crate) trait Rooms { + async fn should_process(&self, room_id: &str, event_id: &str) -> Result; + + async fn insert_room_info(&self, info: &RoomInfo) -> Result<(), DataError>; + + async fn get_room_info(&self, room_id: &str) -> Result, DataError>; + + async fn get_rooms_for_user(&self, user_id: &str) -> Result, DataError>; + + async fn get_users_in_room(&self, room_id: &str) -> Result, DataError>; + + async fn add_user_to_room(&self, username: &str, room_id: &str) -> Result<(), DataError>; + + async fn remove_user_from_room(&self, username: &str, room_id: &str) -> Result<(), DataError>; + + async fn clear_info(&self, room_id: &str) -> Result<(), DataError>; +} + // TODO move this up to the top once we delete sled. Traits will be the // main API, then we can have different impls for different DBs. #[async_trait] @@ -52,7 +81,6 @@ pub struct Database { impl Database { fn new_db(conn: SqlitePool) -> Result { let database = Database { conn: conn.clone() }; - Ok(database) } @@ -78,3 +106,11 @@ impl Database { Self::new("sqlite::memory:").await } } + +impl Clone for Database { + fn clone(&self) -> Self { + Database { + conn: self.conn.clone(), + } + } +} diff --git a/src/db/sqlite/rooms.rs b/src/db/sqlite/rooms.rs new file mode 100644 index 0000000..797c0d3 --- /dev/null +++ b/src/db/sqlite/rooms.rs @@ -0,0 +1,43 @@ +use super::errors::DataError; +use super::{Database, Rooms}; +use crate::models::RoomInfo; +use async_trait::async_trait; +use std::collections::{HashMap, HashSet}; + +#[async_trait] +impl Rooms for Database { + async fn should_process(&self, room_id: &str, event_id: &str) -> Result { + Ok(true) + } + + async fn insert_room_info(&self, info: &RoomInfo) -> Result<(), DataError> { + Ok(()) + } + + async fn get_room_info(&self, room_id: &str) -> Result, DataError> { + Ok(Some(RoomInfo { + room_id: "".to_string(), + room_name: "".to_string(), + })) + } + + async fn get_rooms_for_user(&self, user_id: &str) -> Result, DataError> { + Ok(HashSet::new()) + } + + async fn get_users_in_room(&self, room_id: &str) -> Result, DataError> { + Ok(HashSet::new()) + } + + async fn add_user_to_room(&self, username: &str, room_id: &str) -> Result<(), DataError> { + Ok(()) + } + + async fn remove_user_from_room(&self, username: &str, room_id: &str) -> Result<(), DataError> { + Ok(()) + } + + async fn clear_info(&self, room_id: &str) -> Result<(), DataError> { + Ok(()) + } +} diff --git a/src/db/sqlite/state.rs b/src/db/sqlite/state.rs new file mode 100644 index 0000000..74a2a3f --- /dev/null +++ b/src/db/sqlite/state.rs @@ -0,0 +1,14 @@ +use super::errors::DataError; +use super::{Database, DbState}; +use async_trait::async_trait; + +#[async_trait] +impl DbState for Database { + async fn get_device_id(&self) -> Result, DataError> { + Ok(None) + } + + async fn set_device_id(&self, device_id: &str) -> Result<(), DataError> { + Ok(()) + } +} diff --git a/src/db/sqlite/variables.rs b/src/db/sqlite/variables.rs index 6e30edd..0aa023d 100644 --- a/src/db/sqlite/variables.rs +++ b/src/db/sqlite/variables.rs @@ -58,7 +58,7 @@ impl Variables for Database { ) -> Result<(), DataError> { sqlx::query( "INSERT INTO user_variables - (user_id, room_id, variable_name, value) + (user_id, room_id, key, value) values (?, ?, ?, ?)", ) .bind(user) diff --git a/src/dice.rs b/src/dice.rs index 5b18867..897225d 100644 --- a/src/dice.rs +++ b/src/dice.rs @@ -1,4 +1,5 @@ use crate::context::Context; +use crate::db::sqlite::Variables; use crate::db::variables::UserAndRoom; use crate::error::BotError; use crate::error::DiceRollingError; @@ -22,8 +23,10 @@ pub async fn calculate_single_die_amount( /// it cannot find a variable defined, or if the database errors. pub async fn calculate_dice_amount(amounts: &[Amount], ctx: &Context<'_>) -> Result { let stream = stream::iter(amounts); - let key = UserAndRoom(&ctx.username, ctx.room_id().as_str()); - let variables = &ctx.db.variables.get_user_variables(&key)?; + let variables = &ctx + .db + .get_user_variables(&ctx.username, ctx.room_id().as_str()) + .await?; use DiceRollingError::VariableNotFound; let dice_amount: i32 = stream diff --git a/src/error.rs b/src/error.rs index ca7bd93..3efd786 100644 --- a/src/error.rs +++ b/src/error.rs @@ -21,6 +21,9 @@ pub enum BotError { #[error("database error: {0}")] DataError(#[from] DataError), + #[error("sqlite database error: {0}")] + SqliteDataError(#[from] crate::db::sqlite::errors::DataError), + #[error("the message should not be processed because it failed validation")] ShouldNotProcessError, diff --git a/src/logic.rs b/src/logic.rs index 3b03190..8375754 100644 --- a/src/logic.rs +++ b/src/logic.rs @@ -1,12 +1,14 @@ -use crate::db::errors::DataError; +use crate::db::sqlite::errors::DataError; +use crate::db::sqlite::Rooms; use crate::matrix; use crate::models::RoomInfo; +use futures::stream::{self, StreamExt, TryStreamExt}; use matrix_sdk::{self, identifiers::RoomId, Client}; /// Record the information about a room, including users in it. pub async fn record_room_information( client: &Client, - db: &crate::db::Database, + db: &crate::db::sqlite::Database, room_id: &RoomId, room_display_name: &str, our_username: &str, @@ -21,11 +23,19 @@ pub async fn record_room_information( // TODO this and the username adding should be one whole // transaction in the db. - db.rooms.insert_room_info(&info)?; + db.insert_room_info(&info).await?; - usernames + let filtered_usernames = usernames .into_iter() - .filter(|username| username != our_username) - .map(|username| db.rooms.add_user_to_room(&username, room_id_str)) - .collect() //Make use of collect impl on Result. + .filter(|username| username != our_username); + + // Async collect into vec of results, then use into_iter of result + // to go to from Result> to just Result<()>. Easier than + // attempting to async-collect our way to a single Result<()>. + stream::iter(filtered_usernames) + .then(|username| async move { db.add_user_to_room(&username, &room_id_str).await }) + .collect::>>() + .await + .into_iter() + .collect() }