Convert to SQLx and SQLite #64
|
@ -0,0 +1,2 @@
|
|||
DATABASE_URL="sqlite://test-db/dicebot.sqlite"
|
||||
SQLX_OFFLINE="true"
|
|
@ -10,3 +10,5 @@ bot-db*
|
|||
# We store a disabled async test in this file
|
||||
bigboy
|
||||
.#*
|
||||
*.sqlite
|
||||
.tmp*
|
||||
|
|
File diff suppressed because it is too large
Load Diff
13
Cargo.toml
13
Cargo.toml
|
@ -10,9 +10,13 @@ repository = 'https://git.agnos.is/projectmoon/matrix-dicebot'
|
|||
keywords = ["games", "dice", "matrix", "bot"]
|
||||
categories = ["games"]
|
||||
|
||||
[[bin]]
|
||||
name = "dicebot-migrate"
|
||||
path = "src/migrate_cli.rs"
|
||||
|
||||
[dependencies]
|
||||
log = "0.4"
|
||||
env_logger = "0.8"
|
||||
tracing-subscriber = "0.2"
|
||||
toml = "0.5"
|
||||
nom = "5"
|
||||
rand = "0.8"
|
||||
|
@ -32,6 +36,13 @@ bincode = "1.3"
|
|||
html2text = "0.2"
|
||||
phf = { version = "0.8", features = ["macros"] }
|
||||
matrix-sdk = { git = "https://github.com/matrix-org/matrix-rust-sdk", branch = "master" }
|
||||
refinery = { version = "0.5", features = ["rusqlite"]}
|
||||
barrel = { version = "0.6", features = ["sqlite3"] }
|
||||
tempfile = "3"
|
||||
|
||||
[dependencies.sqlx]
|
||||
version = "0.5"
|
||||
features = [ "offline", "sqlite", "runtime-tokio-native-tls" ]
|
||||
|
||||
[dependencies.serde]
|
||||
version = "1"
|
||||
|
|
|
@ -0,0 +1,159 @@
|
|||
{
|
||||
"db": "SQLite",
|
||||
"19d89370cac05c1bc4de0eb3508712da9ca133b1cf9445b5407d238f89c3ab0c": {
|
||||
"query": "SELECT device_id FROM bot_state limit 1",
|
||||
"describe": {
|
||||
"columns": [
|
||||
{
|
||||
"name": "device_id",
|
||||
"ordinal": 0,
|
||||
"type_info": "Text"
|
||||
}
|
||||
],
|
||||
"parameters": {
|
||||
"Right": 0
|
||||
},
|
||||
"nullable": [
|
||||
false
|
||||
]
|
||||
}
|
||||
},
|
||||
"59313c67900a1a9399389720b522e572f181ae503559cd2b49d6305acb9e2207": {
|
||||
"query": "SELECT key, value as \"value: i32\" FROM user_variables\n WHERE room_id = ? AND user_id = ?",
|
||||
"describe": {
|
||||
"columns": [
|
||||
{
|
||||
"name": "key",
|
||||
"ordinal": 0,
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"name": "value: i32",
|
||||
"ordinal": 1,
|
||||
"type_info": "Int64"
|
||||
}
|
||||
],
|
||||
"parameters": {
|
||||
"Right": 2
|
||||
},
|
||||
"nullable": [
|
||||
false,
|
||||
false
|
||||
]
|
||||
}
|
||||
},
|
||||
"636b1b868eaf04cd234fbf17747d94a66e81f7bc1b060ba14151dbfaf40eeefc": {
|
||||
"query": "SELECT value as \"value: i32\" FROM user_variables\n WHERE user_id = ? AND room_id = ? AND key = ?",
|
||||
"describe": {
|
||||
"columns": [
|
||||
{
|
||||
"name": "value: i32",
|
||||
"ordinal": 0,
|
||||
"type_info": "Int64"
|
||||
}
|
||||
],
|
||||
"parameters": {
|
||||
"Right": 3
|
||||
},
|
||||
"nullable": [
|
||||
false
|
||||
]
|
||||
}
|
||||
},
|
||||
"711d222911c1258365a6a0de1fe00eeec4686fd3589e976e225ad599e7cfc75d": {
|
||||
"query": "SELECT count(*) as \"count: i32\" FROM user_variables\n WHERE room_id = ? and user_id = ?",
|
||||
"describe": {
|
||||
"columns": [
|
||||
{
|
||||
"name": "count: i32",
|
||||
"ordinal": 0,
|
||||
"type_info": "Int"
|
||||
}
|
||||
],
|
||||
"parameters": {
|
||||
"Right": 2
|
||||
},
|
||||
"nullable": [
|
||||
false
|
||||
]
|
||||
}
|
||||
},
|
||||
"7248c8ae30bbe4bc5866e80cc277312c7f8cb9af5a8801fd8eaf178fd99eae18": {
|
||||
"query": "SELECT room_id FROM room_users\n WHERE username = ?",
|
||||
"describe": {
|
||||
"columns": [
|
||||
{
|
||||
"name": "room_id",
|
||||
"ordinal": 0,
|
||||
"type_info": "Text"
|
||||
}
|
||||
],
|
||||
"parameters": {
|
||||
"Right": 1
|
||||
},
|
||||
"nullable": [
|
||||
false
|
||||
]
|
||||
}
|
||||
},
|
||||
"97f5d58f62baca51efd8c295ca6737d1240923c69c973621cd0a718ac9eed99f": {
|
||||
"query": "SELECT room_id, room_name FROM room_info\n WHERE room_id = ?",
|
||||
"describe": {
|
||||
"columns": [
|
||||
{
|
||||
"name": "room_id",
|
||||
"ordinal": 0,
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"name": "room_name",
|
||||
"ordinal": 1,
|
||||
"type_info": "Text"
|
||||
}
|
||||
],
|
||||
"parameters": {
|
||||
"Right": 1
|
||||
},
|
||||
"nullable": [
|
||||
false,
|
||||
false
|
||||
]
|
||||
}
|
||||
},
|
||||
"b302d586e5ac4c72c2970361ea5a5936c0b8c6dad10033c626a0ce0404cadb25": {
|
||||
"query": "SELECT username FROM room_users\n WHERE room_id = ?",
|
||||
"describe": {
|
||||
"columns": [
|
||||
{
|
||||
"name": "username",
|
||||
"ordinal": 0,
|
||||
"type_info": "Text"
|
||||
}
|
||||
],
|
||||
"parameters": {
|
||||
"Right": 1
|
||||
},
|
||||
"nullable": [
|
||||
false
|
||||
]
|
||||
}
|
||||
},
|
||||
"bba0fc255e7c30d1d2d9468c68ba38db6e8a13be035aa1152933ba9247b14f8c": {
|
||||
"query": "SELECT event_id FROM room_events\n WHERE room_id = ? AND event_id = ?",
|
||||
"describe": {
|
||||
"columns": [
|
||||
{
|
||||
"name": "event_id",
|
||||
"ordinal": 0,
|
||||
"type_info": "Text"
|
||||
}
|
||||
],
|
||||
"parameters": {
|
||||
"Right": 2
|
||||
},
|
||||
"nullable": [
|
||||
false
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
|
@ -2,7 +2,7 @@ use matrix_sdk::identifiers::room_id;
|
|||
use tenebrous_dicebot::commands;
|
||||
use tenebrous_dicebot::commands::ResponseExtractor;
|
||||
use tenebrous_dicebot::context::{Context, RoomContext};
|
||||
use tenebrous_dicebot::db::Database;
|
||||
use tenebrous_dicebot::db::sqlite::Database;
|
||||
use tenebrous_dicebot::error::BotError;
|
||||
use url::Url;
|
||||
|
||||
|
@ -15,9 +15,17 @@ async fn main() -> Result<(), BotError> {
|
|||
};
|
||||
|
||||
let homeserver = Url::parse("http://example.com")?;
|
||||
let db_path = tempfile::NamedTempFile::new_in(".").unwrap();
|
||||
let db = Database::new(
|
||||
db_path
|
||||
.path()
|
||||
.to_str()
|
||||
.expect("Could not get path to temporary db"),
|
||||
)
|
||||
.await?;
|
||||
|
||||
let context = Context {
|
||||
db: Database::new_temp()?,
|
||||
db: db,
|
||||
matrix_client: &matrix_sdk::Client::new(homeserver)
|
||||
.expect("Could not create matrix client"),
|
||||
room: RoomContext {
|
||||
|
|
|
@ -1,20 +1,25 @@
|
|||
//Needed for nested Result handling from tokio. Probably can go away after 1.47.0.
|
||||
#![type_length_limit = "7605144"]
|
||||
use env_logger::Env;
|
||||
use log::error;
|
||||
use std::env;
|
||||
use std::sync::{Arc, RwLock};
|
||||
use tenebrous_dicebot::bot::DiceBot;
|
||||
use tenebrous_dicebot::config::*;
|
||||
use tenebrous_dicebot::db::Database;
|
||||
use tenebrous_dicebot::db::sqlite::Database;
|
||||
use tenebrous_dicebot::error::BotError;
|
||||
use tenebrous_dicebot::state::DiceBotState;
|
||||
use tracing_subscriber::filter::EnvFilter;
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
env_logger::Builder::from_env(
|
||||
Env::default().default_filter_or("tenebrous_dicebot=info,dicebot=info"),
|
||||
)
|
||||
.init();
|
||||
let filter = if env::var("RUST_LOG").is_ok() {
|
||||
EnvFilter::from_default_env()
|
||||
} else {
|
||||
EnvFilter::new("tenebrous_dicebot=info,dicebot=info,refinery=info")
|
||||
};
|
||||
|
||||
tracing_subscriber::fmt().with_env_filter(filter).init();
|
||||
|
||||
match run().await {
|
||||
Ok(_) => (),
|
||||
Err(e) => error!("Error: {}", e),
|
||||
|
@ -28,11 +33,10 @@ async fn run() -> Result<(), BotError> {
|
|||
.expect("Need a config as an argument");
|
||||
|
||||
let cfg = Arc::new(read_config(config_path)?);
|
||||
let db = Database::new(&cfg.database_path())?;
|
||||
let sqlite_path = format!("{}/dicebot.sqlite", cfg.database_path());
|
||||
let db = Database::new(&sqlite_path).await?;
|
||||
let state = Arc::new(RwLock::new(DiceBotState::new(&cfg)));
|
||||
|
||||
db.migrate(cfg.migration_version())?;
|
||||
|
||||
match DiceBot::new(&cfg, &state, &db) {
|
||||
Ok(bot) => bot.run().await?,
|
||||
Err(e) => println!("Error connecting: {:?}", e),
|
||||
|
|
|
@ -0,0 +1,37 @@
|
|||
use tenebrous_dicebot::db::sqlite::{Database as SqliteDatabase, Variables};
|
||||
use tenebrous_dicebot::db::Database;
|
||||
use tenebrous_dicebot::error::BotError;
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), BotError> {
|
||||
let sled_path = std::env::args()
|
||||
.skip(1)
|
||||
.next()
|
||||
.expect("Need a path to a Sled database as an arument.");
|
||||
|
||||
let sqlite_path = std::env::args()
|
||||
.skip(2)
|
||||
.next()
|
||||
.expect("Need a path to an sqlite database as an arument.");
|
||||
|
||||
let db = Database::new(&sled_path)?;
|
||||
|
||||
let all_variables = db.variables.get_all_variables()?;
|
||||
|
||||
let sql_db = SqliteDatabase::new(&sqlite_path).await?;
|
||||
|
||||
for var in all_variables {
|
||||
if let ((username, room_id, variable_name), value) = var {
|
||||
println!(
|
||||
"Migrating {}::{}::{} = {} to sql",
|
||||
username, room_id, variable_name, value
|
||||
);
|
||||
|
||||
sql_db
|
||||
.set_user_variable(&username, &room_id, &variable_name, value)
|
||||
.await;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
35
src/bot.rs
35
src/bot.rs
|
@ -1,13 +1,14 @@
|
|||
use crate::commands::{execute_command, ExecutionError, ExecutionResult, ResponseExtractor};
|
||||
use crate::config::*;
|
||||
use crate::context::{Context, RoomContext};
|
||||
use crate::db::Database;
|
||||
use crate::db::sqlite::Database;
|
||||
use crate::db::sqlite::DbState;
|
||||
use crate::error::BotError;
|
||||
use crate::matrix;
|
||||
use crate::state::DiceBotState;
|
||||
use dirs;
|
||||
use futures::stream::{self, StreamExt};
|
||||
use log::info;
|
||||
use log::{error, info};
|
||||
use matrix_sdk::{self, identifiers::EventId, room::Joined, Client, ClientConfig, SyncSettings};
|
||||
use std::clone::Clone;
|
||||
use std::path::PathBuf;
|
||||
|
@ -61,6 +62,13 @@ async fn handle_single_result(
|
|||
room: &Joined,
|
||||
event_id: EventId,
|
||||
) {
|
||||
if cmd_result.is_err() {
|
||||
error!(
|
||||
"Command execution error: {}",
|
||||
cmd_result.as_ref().err().unwrap()
|
||||
);
|
||||
}
|
||||
|
||||
let html = cmd_result.message_html(respond_to);
|
||||
matrix::send_message(client, room.room_id(), &html, Some(event_id)).await;
|
||||
}
|
||||
|
@ -86,6 +94,10 @@ async fn handle_multiple_results(
|
|||
})
|
||||
.collect();
|
||||
|
||||
for result in errors.iter() {
|
||||
error!("Command execution error: '{}' - {}", result.0, result.1);
|
||||
}
|
||||
|
||||
let message = if errors.len() == 0 {
|
||||
format!("{}: Executed {} commands", respond_to, results.len())
|
||||
} else {
|
||||
|
@ -134,7 +146,7 @@ impl DiceBot {
|
|||
|
||||
// Pull device ID from database, if it exists. Then write it
|
||||
// to DB if the library generated one for us.
|
||||
let device_id: Option<String> = self.db.state.get_device_id()?;
|
||||
let device_id: Option<String> = self.db.get_device_id().await?;
|
||||
let device_id: Option<&str> = device_id.as_deref();
|
||||
|
||||
client
|
||||
|
@ -143,7 +155,7 @@ impl DiceBot {
|
|||
|
||||
if device_id.is_none() {
|
||||
let device_id = client.device_id().await.ok_or(BotError::NoDeviceIdFound)?;
|
||||
self.db.state.set_device_id(device_id.as_str())?;
|
||||
self.db.set_device_id(device_id.as_str()).await?;
|
||||
info!("Recorded new device ID: {}", device_id.as_str());
|
||||
} else {
|
||||
info!("Using existing device ID: {}", device_id.unwrap());
|
||||
|
@ -160,25 +172,12 @@ impl DiceBot {
|
|||
let client = self.client.clone();
|
||||
self.login(&client).await?;
|
||||
|
||||
// Initial sync without event handler prevents responding to
|
||||
// messages received while bot was offline. TODO: selectively
|
||||
// respond to old messages? e.g. comands missed while offline.
|
||||
info!("Performing intial sync (no commands will be responded to)");
|
||||
self.client.sync_once(SyncSettings::default()).await?;
|
||||
|
||||
client.set_event_handler(Box::new(self)).await;
|
||||
info!("Listening for commands");
|
||||
|
||||
let token = client
|
||||
.sync_token()
|
||||
.await
|
||||
.ok_or(BotError::SyncTokenRequired)?;
|
||||
|
||||
let settings = SyncSettings::default().token(token);
|
||||
|
||||
// TODO replace with sync_with_callback for cleaner shutdown
|
||||
// process.
|
||||
client.sync(settings).await;
|
||||
client.sync(SyncSettings::default()).await;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
|
|
@ -4,7 +4,8 @@
|
|||
* SDK example code.
|
||||
*/
|
||||
use super::DiceBot;
|
||||
use crate::db::Database;
|
||||
use crate::db::sqlite::Database;
|
||||
use crate::db::sqlite::Rooms;
|
||||
use crate::error::BotError;
|
||||
use crate::logic::record_room_information;
|
||||
use async_trait::async_trait;
|
||||
|
@ -19,9 +20,9 @@ use matrix_sdk::{
|
|||
room::Room,
|
||||
EventHandler,
|
||||
};
|
||||
use std::clone::Clone;
|
||||
use std::ops::Sub;
|
||||
use std::time::{Duration, SystemTime};
|
||||
use std::{clone::Clone, time::UNIX_EPOCH};
|
||||
|
||||
/// Check if a message is recent enough to actually process. If the
|
||||
/// message is within "oldest_message_age" seconds, this function
|
||||
|
@ -31,7 +32,11 @@ fn check_message_age(
|
|||
event: &SyncMessageEvent<MessageEventContent>,
|
||||
oldest_message_age: u64,
|
||||
) -> bool {
|
||||
let sending_time = event.origin_server_ts;
|
||||
let sending_time = event
|
||||
.origin_server_ts
|
||||
.to_system_time()
|
||||
.unwrap_or(UNIX_EPOCH);
|
||||
|
||||
let oldest_timestamp = SystemTime::now().sub(Duration::from_secs(oldest_message_age));
|
||||
|
||||
if sending_time > oldest_timestamp {
|
||||
|
@ -90,9 +95,9 @@ async fn should_process_message<'a>(
|
|||
Ok((msg_body, sender_username))
|
||||
}
|
||||
|
||||
fn should_process_event(db: &Database, room_id: &str, event_id: &str) -> bool {
|
||||
db.rooms
|
||||
.should_process(room_id, event_id)
|
||||
async fn should_process_event(db: &Database, room_id: &str, event_id: &str) -> bool {
|
||||
db.should_process(room_id, event_id)
|
||||
.await
|
||||
.unwrap_or_else(|e| {
|
||||
error!(
|
||||
"Database error when checking if we should process an event: {}",
|
||||
|
@ -116,7 +121,7 @@ impl EventHandler for DiceBot {
|
|||
let room_id_str = room_id.as_str();
|
||||
let username = &event.state_key;
|
||||
|
||||
if !should_process_event(&self.db, room_id_str, event.event_id.as_str()) {
|
||||
if !should_process_event(&self.db, room_id_str, event.event_id.as_str()).await {
|
||||
return;
|
||||
}
|
||||
|
||||
|
@ -126,17 +131,20 @@ impl EventHandler for DiceBot {
|
|||
false
|
||||
};
|
||||
|
||||
// user_joing is true if a user is joining this room, and
|
||||
// false if they have left for some reason. This user may be
|
||||
// us, or another user in the room.
|
||||
use MembershipChange::*;
|
||||
let adding_user = match event.membership_change() {
|
||||
let user_joining = match event.membership_change() {
|
||||
Joined => true,
|
||||
Banned | Left | Kicked | KickedAndBanned => false,
|
||||
_ => return,
|
||||
};
|
||||
|
||||
let result = if event_affects_us && !adding_user {
|
||||
let result = if event_affects_us && !user_joining {
|
||||
info!("Clearing all information for room ID {}", room_id);
|
||||
self.db.rooms.clear_info(room_id_str)
|
||||
} else if event_affects_us && adding_user {
|
||||
self.db.clear_info(room_id_str).await.map_err(|e| e.into())
|
||||
} else if event_affects_us && user_joining {
|
||||
info!("Joined room {}; recording room information", room_id);
|
||||
record_room_information(
|
||||
&self.client,
|
||||
|
@ -146,12 +154,18 @@ impl EventHandler for DiceBot {
|
|||
&event.state_key,
|
||||
)
|
||||
.await
|
||||
} else if !event_affects_us && adding_user {
|
||||
} else if !event_affects_us && user_joining {
|
||||
info!("Adding user {} to room ID {}", username, room_id);
|
||||
self.db.rooms.add_user_to_room(username, room_id_str)
|
||||
} else if !event_affects_us && !adding_user {
|
||||
self.db
|
||||
.add_user_to_room(username, room_id_str)
|
||||
.await
|
||||
.map_err(|e| e.into())
|
||||
} else if !event_affects_us && !user_joining {
|
||||
info!("Removing user {} from room ID {}", username, room_id);
|
||||
self.db.rooms.remove_user_from_room(username, room_id_str)
|
||||
self.db
|
||||
.remove_user_from_room(username, room_id_str)
|
||||
.await
|
||||
.map_err(|e| e.into())
|
||||
} else {
|
||||
debug!("Ignoring a room member event: {:#?}", event);
|
||||
Ok(())
|
||||
|
@ -196,7 +210,7 @@ impl EventHandler for DiceBot {
|
|||
};
|
||||
|
||||
let room_id = room.room_id().as_str();
|
||||
if !should_process_event(&self.db, room_id, event.event_id.as_str()) {
|
||||
if !should_process_event(&self.db, room_id, event.event_id.as_str()).await {
|
||||
return;
|
||||
}
|
||||
|
||||
|
|
|
@ -325,7 +325,8 @@ pub async fn roll_pool(pool: &DicePoolWithContext<'_>) -> Result<RolledDicePool,
|
|||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::db::Database;
|
||||
use crate::db::sqlite::Database;
|
||||
use crate::db::sqlite::Variables;
|
||||
use url::Url;
|
||||
|
||||
macro_rules! dummy_room {
|
||||
|
@ -463,8 +464,8 @@ mod tests {
|
|||
assert_eq!(vec![10], roll);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn number_of_dice_equality_test() {
|
||||
#[test]
|
||||
fn number_of_dice_equality_test() {
|
||||
let num_dice = 5;
|
||||
let rolls = vec![1, 2, 3, 4, 5];
|
||||
let pool = DicePool::easy_pool(5, DicePoolQuality::TenAgain);
|
||||
|
@ -472,10 +473,14 @@ mod tests {
|
|||
assert_eq!(5, rolled_pool.num_dice);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
|
||||
async fn rejects_large_expression_test() {
|
||||
let db_path = tempfile::NamedTempFile::new_in(".").unwrap();
|
||||
let homeserver = Url::parse("http://example.com").unwrap();
|
||||
let db = Database::new_temp().unwrap();
|
||||
let db = Database::new(db_path.path().to_str().unwrap())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let ctx = Context {
|
||||
db: db,
|
||||
matrix_client: &matrix_sdk::Client::new(homeserver).unwrap(),
|
||||
|
@ -504,9 +509,17 @@ mod tests {
|
|||
));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
|
||||
async fn converts_to_chance_die_test() {
|
||||
let db = Database::new_temp().unwrap();
|
||||
let db_path = tempfile::NamedTempFile::new_in(".").unwrap();
|
||||
crate::db::sqlite::migrator::migrate(db_path.path().to_str().unwrap())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let db = Database::new(db_path.path().to_str().unwrap())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let homeserver = Url::parse("http://example.com").unwrap();
|
||||
let ctx = Context {
|
||||
db: db,
|
||||
|
@ -533,11 +546,17 @@ mod tests {
|
|||
assert_eq!(1, roll.num_dice);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
|
||||
async fn can_resolve_variables_test() {
|
||||
use crate::db::variables::UserAndRoom;
|
||||
let db_path = tempfile::NamedTempFile::new_in(".").unwrap();
|
||||
crate::db::sqlite::migrator::migrate(db_path.path().to_str().unwrap())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let db = Database::new(db_path.path().to_str().unwrap())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let db = Database::new_temp().unwrap();
|
||||
let homeserver = Url::parse("http://example.com").unwrap();
|
||||
let ctx = Context {
|
||||
db: db.clone(),
|
||||
|
@ -547,10 +566,8 @@ mod tests {
|
|||
message_body: "message",
|
||||
};
|
||||
|
||||
let user_and_room = UserAndRoom(&ctx.username, &ctx.room.id.as_str());
|
||||
|
||||
db.variables
|
||||
.set_user_variable(&user_and_room, "myvariable", 10)
|
||||
db.set_user_variable(&ctx.username, &ctx.room.id.as_str(), "myvariable", 10)
|
||||
.await
|
||||
.expect("could not set myvariable to 10");
|
||||
|
||||
let amounts = vec![Amount {
|
||||
|
|
|
@ -2,6 +2,7 @@ use crate::context::Context;
|
|||
use crate::error::BotError;
|
||||
use async_trait::async_trait;
|
||||
use thiserror::Error;
|
||||
use BotError::{DataError, SqliteDataError};
|
||||
|
||||
pub mod basic_rolling;
|
||||
pub mod cofd;
|
||||
|
@ -50,7 +51,13 @@ pub struct ExecutionError(#[from] pub BotError);
|
|||
|
||||
impl From<crate::db::errors::DataError> for ExecutionError {
|
||||
fn from(error: crate::db::errors::DataError) -> Self {
|
||||
Self(BotError::DataError(error))
|
||||
Self(DataError(error))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<crate::db::sqlite::errors::DataError> for ExecutionError {
|
||||
fn from(error: crate::db::sqlite::errors::DataError) -> Self {
|
||||
Self(SqliteDataError(error))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -129,10 +136,15 @@ mod tests {
|
|||
));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
|
||||
async fn unrecognized_command() {
|
||||
let db = crate::db::Database::new_temp().unwrap();
|
||||
let db_path = tempfile::NamedTempFile::new_in(".").unwrap();
|
||||
let db = crate::db::sqlite::Database::new(db_path.path().to_str().unwrap())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let homeserver = Url::parse("http://example.com").unwrap();
|
||||
|
||||
let ctx = Context {
|
||||
db: db,
|
||||
matrix_client: &matrix_sdk::Client::new(homeserver).unwrap(),
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
use super::{Command, Execution, ExecutionResult};
|
||||
use crate::context::Context;
|
||||
use crate::db::errors::DataError;
|
||||
use crate::db::sqlite::errors::DataError;
|
||||
use crate::db::sqlite::Variables;
|
||||
use crate::db::variables::UserAndRoom;
|
||||
use async_trait::async_trait;
|
||||
|
||||
|
@ -13,8 +14,10 @@ impl Command for GetAllVariablesCommand {
|
|||
}
|
||||
|
||||
async fn execute(&self, ctx: &Context<'_>) -> ExecutionResult {
|
||||
let key = UserAndRoom(&ctx.username, &ctx.room_id().as_str());
|
||||
let variables = ctx.db.variables.get_user_variables(&key)?;
|
||||
let variables = ctx
|
||||
.db
|
||||
.get_user_variables(&ctx.username, ctx.room_id().as_str())
|
||||
.await?;
|
||||
|
||||
let mut variable_list: Vec<String> = variables
|
||||
.into_iter()
|
||||
|
@ -43,8 +46,10 @@ impl Command for GetVariableCommand {
|
|||
|
||||
async fn execute(&self, ctx: &Context<'_>) -> ExecutionResult {
|
||||
let name = &self.0;
|
||||
let key = UserAndRoom(&ctx.username, &ctx.room_id().as_str());
|
||||
let result = ctx.db.variables.get_user_variable(&key, name);
|
||||
let result = ctx
|
||||
.db
|
||||
.get_user_variable(&ctx.username, ctx.room_id().as_str(), name)
|
||||
.await;
|
||||
|
||||
let value = match result {
|
||||
Ok(num) => format!("{} = {}", name, num),
|
||||
|
@ -68,9 +73,10 @@ impl Command for SetVariableCommand {
|
|||
async fn execute(&self, ctx: &Context<'_>) -> ExecutionResult {
|
||||
let name = &self.0;
|
||||
let value = self.1;
|
||||
let key = UserAndRoom(&ctx.username, ctx.room_id().as_str());
|
||||
|
||||
ctx.db.variables.set_user_variable(&key, name, value)?;
|
||||
ctx.db
|
||||
.set_user_variable(&ctx.username, ctx.room_id().as_str(), name, value)
|
||||
.await?;
|
||||
|
||||
let content = format!("{} = {}", name, value);
|
||||
let html = format!("<strong>Set Variable:</strong> {}", content);
|
||||
|
@ -88,8 +94,10 @@ impl Command for DeleteVariableCommand {
|
|||
|
||||
async fn execute(&self, ctx: &Context<'_>) -> ExecutionResult {
|
||||
let name = &self.0;
|
||||
let key = UserAndRoom(&ctx.username, ctx.room_id().as_str());
|
||||
let result = ctx.db.variables.delete_user_variable(&key, name);
|
||||
let result = ctx
|
||||
.db
|
||||
.delete_user_variable(&ctx.username, ctx.room_id().as_str(), name)
|
||||
.await;
|
||||
|
||||
let value = match result {
|
||||
Ok(()) => format!("{} now unset", name),
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
use crate::db::Database;
|
||||
use crate::db::sqlite::Database;
|
||||
use matrix_sdk::identifiers::RoomId;
|
||||
use matrix_sdk::room::Joined;
|
||||
use matrix_sdk::Client;
|
||||
|
|
|
@ -1,7 +1,11 @@
|
|||
use crate::db::sqlite::Variables;
|
||||
use crate::error::{BotError, DiceRollingError};
|
||||
use crate::parser::{Amount, Element};
|
||||
use crate::{context::Context, db::variables::UserAndRoom};
|
||||
use crate::{dice::calculate_single_die_amount, parser::DiceParsingError};
|
||||
use rand::rngs::StdRng;
|
||||
use rand::Rng;
|
||||
use rand::SeedableRng;
|
||||
use std::convert::TryFrom;
|
||||
use std::fmt;
|
||||
|
||||
|
@ -270,10 +274,11 @@ macro_rules! is_variable {
|
|||
};
|
||||
}
|
||||
|
||||
///A version of DieRoller that uses a rand::Rng to roll numbers.
|
||||
struct RngDieRoller<R: rand::Rng>(R);
|
||||
/// A die roller than can have an RNG implementation injected, but
|
||||
/// must be thread-safe. Required for the async dice rolling code.
|
||||
struct RngDieRoller<R: Rng + ?Sized + Send>(R);
|
||||
|
||||
impl<R: rand::Rng> DieRoller for RngDieRoller<R> {
|
||||
impl<R: Rng + ?Sized + Send> DieRoller for RngDieRoller<R> {
|
||||
fn roll(&mut self) -> u32 {
|
||||
self.0.gen_range(0..=9)
|
||||
}
|
||||
|
@ -361,7 +366,7 @@ pub async fn regular_roll(
|
|||
let target = calculate_single_die_amount(&roll_with_ctx.0.amount, roll_with_ctx.1).await?;
|
||||
let target = u32::try_from(target).map_err(|_| DiceRollingError::InvalidAmount)?;
|
||||
|
||||
let mut roller = RngDieRoller(rand::thread_rng());
|
||||
let mut roller = RngDieRoller::<StdRng>(SeedableRng::from_entropy());
|
||||
let rolled_dice = roll_regular_dice(&roll_with_ctx.0.modifier, target, &mut roller);
|
||||
|
||||
Ok(ExecutedDiceRoll {
|
||||
|
@ -371,11 +376,12 @@ pub async fn regular_roll(
|
|||
})
|
||||
}
|
||||
|
||||
fn update_skill(ctx: &Context, variable: &str, value: u32) -> Result<(), BotError> {
|
||||
async fn update_skill(ctx: &Context<'_>, variable: &str, value: u32) -> Result<(), BotError> {
|
||||
use std::convert::TryInto;
|
||||
let value: i32 = value.try_into()?;
|
||||
let key = UserAndRoom(ctx.username, ctx.room_id().as_str());
|
||||
ctx.db.variables.set_user_variable(&key, variable, value)?;
|
||||
ctx.db
|
||||
.set_user_variable(&ctx.username, &ctx.room_id().as_str(), variable, value)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
@ -397,12 +403,14 @@ pub async fn advancement_roll(
|
|||
return Err(DiceRollingError::InvalidAmount.into());
|
||||
}
|
||||
|
||||
let mut roller = RngDieRoller(rand::thread_rng());
|
||||
let mut roller = RngDieRoller::<StdRng>(SeedableRng::from_entropy());
|
||||
let roll = roll_advancement_dice(target, &mut roller);
|
||||
|
||||
drop(roller);
|
||||
|
||||
if roll.successful && is_variable!(existing_skill) {
|
||||
let variable_name: &str = extract_variable(existing_skill)?;
|
||||
update_skill(roll_with_ctx.1, variable_name, roll.new_skill_amount())?;
|
||||
update_skill(roll_with_ctx.1, variable_name, roll.new_skill_amount()).await?;
|
||||
}
|
||||
|
||||
Ok(ExecutedAdvancementRoll { target, roll })
|
||||
|
@ -411,7 +419,7 @@ pub async fn advancement_roll(
|
|||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::db::Database;
|
||||
use crate::db::sqlite::Database;
|
||||
use crate::parser::{Amount, Element, Operator};
|
||||
use url::Url;
|
||||
|
||||
|
@ -474,7 +482,7 @@ mod tests {
|
|||
assert!(matches!(result, Err(DiceParsingError::WrongElementType)));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
|
||||
async fn regular_roll_rejects_negative_numbers() {
|
||||
let roll = DiceRoll {
|
||||
amount: Amount {
|
||||
|
@ -484,7 +492,15 @@ mod tests {
|
|||
modifier: DiceRollModifier::Normal,
|
||||
};
|
||||
|
||||
let db = Database::new_temp().unwrap();
|
||||
let db_path = tempfile::NamedTempFile::new_in(".").unwrap();
|
||||
crate::db::sqlite::migrator::migrate(db_path.path().to_str().unwrap())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let db = Database::new(db_path.path().to_str().unwrap())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let homeserver = Url::parse("http://example.com").unwrap();
|
||||
let ctx = Context {
|
||||
db: db,
|
||||
|
@ -503,7 +519,7 @@ mod tests {
|
|||
));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
|
||||
async fn advancement_roll_rejects_negative_numbers() {
|
||||
let roll = AdvancementRoll {
|
||||
existing_skill: Amount {
|
||||
|
@ -512,7 +528,15 @@ mod tests {
|
|||
},
|
||||
};
|
||||
|
||||
let db = Database::new_temp().unwrap();
|
||||
let db_path = tempfile::NamedTempFile::new_in(".").unwrap();
|
||||
crate::db::sqlite::migrator::migrate(db_path.path().to_str().unwrap())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let db = Database::new(db_path.path().to_str().unwrap())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let homeserver = Url::parse("http://example.com").unwrap();
|
||||
let ctx = Context {
|
||||
db: db,
|
||||
|
@ -531,7 +555,7 @@ mod tests {
|
|||
));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
|
||||
async fn advancement_roll_rejects_big_numbers() {
|
||||
let roll = AdvancementRoll {
|
||||
existing_skill: Amount {
|
||||
|
@ -540,7 +564,15 @@ mod tests {
|
|||
},
|
||||
};
|
||||
|
||||
let db = Database::new_temp().unwrap();
|
||||
let db_path = tempfile::NamedTempFile::new_in(".").unwrap();
|
||||
crate::db::sqlite::migrator::migrate(db_path.path().to_str().unwrap())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let db = Database::new(db_path.path().to_str().unwrap())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let homeserver = Url::parse("http://example.com").unwrap();
|
||||
let ctx = Context {
|
||||
db: db,
|
||||
|
|
|
@ -12,13 +12,14 @@ pub mod errors;
|
|||
pub mod migrations;
|
||||
pub mod rooms;
|
||||
pub mod schema;
|
||||
pub mod sqlite;
|
||||
pub mod state;
|
||||
pub mod variables;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Database {
|
||||
db: Db,
|
||||
pub(crate) variables: Variables,
|
||||
pub variables: Variables,
|
||||
pub(crate) migrations: Migrations,
|
||||
pub(crate) rooms: Rooms,
|
||||
pub(crate) state: DbState,
|
||||
|
|
|
@ -26,6 +26,9 @@ pub enum DataError {
|
|||
#[error("expected i32, but i32 schema was violated")]
|
||||
I32SchemaViolation,
|
||||
|
||||
#[error("parse error")]
|
||||
ParseError(#[from] std::num::ParseIntError),
|
||||
|
||||
#[error("unexpected or corruptd data bytes")]
|
||||
InvalidValue,
|
||||
|
||||
|
|
|
@ -0,0 +1,72 @@
|
|||
use std::num::TryFromIntError;
|
||||
|
||||
use sled::transaction::{TransactionError, UnabortableTransactionError};
|
||||
use thiserror::Error;
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
pub enum DataError {
|
||||
#[error("value does not exist for key: {0}")]
|
||||
KeyDoesNotExist(String),
|
||||
|
||||
#[error("too many entries")]
|
||||
TooManyEntries,
|
||||
|
||||
#[error("expected i32, but i32 schema was violated")]
|
||||
I32SchemaViolation,
|
||||
|
||||
#[error("unexpected or corruptd data bytes")]
|
||||
InvalidValue,
|
||||
|
||||
#[error("expected string ref, but utf8 schema was violated: {0}")]
|
||||
Utf8RefSchemaViolation(#[from] std::str::Utf8Error),
|
||||
|
||||
#[error("expected string, but utf8 schema was violated: {0}")]
|
||||
Utf8SchemaViolation(#[from] std::string::FromUtf8Error),
|
||||
|
||||
#[error("internal database error: {0}")]
|
||||
InternalError(#[from] sled::Error),
|
||||
|
||||
#[error("transaction error: {0}")]
|
||||
TransactionError(#[from] sled::transaction::TransactionError),
|
||||
|
||||
#[error("unabortable transaction error: {0}")]
|
||||
UnabortableTransactionError(#[from] UnabortableTransactionError),
|
||||
|
||||
#[error("data migration error: {0}")]
|
||||
MigrationError(#[from] super::migrator::MigrationError),
|
||||
|
||||
#[error("deserialization error: {0}")]
|
||||
DeserializationError(#[from] bincode::Error),
|
||||
|
||||
#[error("sqlx error: {0}")]
|
||||
SqlxError(#[from] sqlx::Error),
|
||||
|
||||
#[error("numeric conversion error")]
|
||||
NumericConversionError(#[from] TryFromIntError),
|
||||
}
|
||||
|
||||
/// This From implementation is necessary to deal with the recursive
|
||||
/// error type in the error enum. We defined a transaction error, but
|
||||
/// the only place we use it is when converting from
|
||||
/// sled::transaction::TransactionError<DataError>. This converter
|
||||
/// extracts the inner data error from transaction aborted errors, and
|
||||
/// forwards anything else onward as-is, but wrapped in DataError.
|
||||
impl From<TransactionError<DataError>> for DataError {
|
||||
fn from(error: TransactionError<DataError>) -> Self {
|
||||
match error {
|
||||
TransactionError::Abort(data_err) => data_err,
|
||||
TransactionError::Storage(storage_err) => {
|
||||
DataError::TransactionError(TransactionError::Storage(storage_err))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Automatically aborts transactions that hit a DataError by using
|
||||
/// the try (question mark) operator when this trait implementation is
|
||||
/// in scope.
|
||||
impl From<DataError> for sled::transaction::ConflictableTransactionError<DataError> {
|
||||
fn from(error: DataError) -> Self {
|
||||
sled::transaction::ConflictableTransactionError::Abort(error)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,15 @@
|
|||
use barrel::backend::Sqlite;
|
||||
use barrel::{types, Migration};
|
||||
|
||||
pub fn migration() -> String {
|
||||
let mut m = Migration::new();
|
||||
|
||||
m.create_table("user_variables", |t| {
|
||||
t.add_column("room_id", types::text());
|
||||
t.add_column("user_id", types::text());
|
||||
t.add_column("key", types::text());
|
||||
t.add_column("value", types::integer());
|
||||
});
|
||||
|
||||
m.make::<Sqlite>()
|
||||
}
|
|
@ -0,0 +1,18 @@
|
|||
use barrel::backend::Sqlite;
|
||||
use barrel::{types, types::Type, Migration};
|
||||
|
||||
fn primary_uuid() -> Type {
|
||||
types::text().unique(true).primary(true).nullable(false)
|
||||
}
|
||||
|
||||
pub fn migration() -> String {
|
||||
let mut m = Migration::new();
|
||||
|
||||
//Table for basic room information: room ID, room name
|
||||
m.create_table("room_info", move |t| {
|
||||
t.add_column("room_id", primary_uuid());
|
||||
t.add_column("room_name", types::text());
|
||||
});
|
||||
|
||||
m.make::<Sqlite>()
|
||||
}
|
|
@ -0,0 +1,13 @@
|
|||
use barrel::backend::Sqlite;
|
||||
use barrel::{types, Migration};
|
||||
|
||||
pub fn migration() -> String {
|
||||
let mut m = Migration::new();
|
||||
|
||||
//Basic state table with only device ID for now. Uses only one row.
|
||||
m.create_table("bot_state", move |t| {
|
||||
t.add_column("device_id", types::text());
|
||||
});
|
||||
|
||||
m.make::<Sqlite>()
|
||||
}
|
|
@ -0,0 +1,22 @@
|
|||
use barrel::backend::Sqlite;
|
||||
use barrel::{types, types::Type, Migration};
|
||||
pub fn migration() -> String {
|
||||
let mut m = Migration::new();
|
||||
|
||||
//Table of room ID, event ID, event timestamp
|
||||
m.create_table("room_events", move |t| {
|
||||
t.add_column("room_id", types::text().nullable(false));
|
||||
t.add_column("event_id", types::text().nullable(false));
|
||||
t.add_column("event_timestamp", types::integer());
|
||||
});
|
||||
|
||||
let mut res = m.make::<Sqlite>();
|
||||
|
||||
//This is a hack that gives us a composite primary key.
|
||||
if res.ends_with(");") {
|
||||
res.pop();
|
||||
res.pop();
|
||||
}
|
||||
|
||||
format!("{}, PRIMARY KEY (room_id, event_id));", res)
|
||||
}
|
|
@ -0,0 +1,22 @@
|
|||
use barrel::backend::Sqlite;
|
||||
use barrel::{types, types::Type, Migration};
|
||||
|
||||
pub fn migration() -> String {
|
||||
let mut m = Migration::new();
|
||||
|
||||
//Table of users in rooms.
|
||||
m.create_table("room_users", move |t| {
|
||||
t.add_column("room_id", types::text());
|
||||
t.add_column("username", types::text());
|
||||
});
|
||||
|
||||
let mut res = m.make::<Sqlite>();
|
||||
|
||||
//This is a hack that gives us a composite primary key.
|
||||
if res.ends_with(");") {
|
||||
res.pop();
|
||||
res.pop();
|
||||
}
|
||||
|
||||
format!("{}, PRIMARY KEY (room_id, username));", res)
|
||||
}
|
|
@ -0,0 +1,2 @@
|
|||
use refinery::include_migration_mods;
|
||||
include_migration_mods!("src/db/sqlite/migrator/migrations");
|
|
@ -0,0 +1,33 @@
|
|||
use log::info;
|
||||
use refinery::config::{Config, ConfigDbType};
|
||||
use sqlx::sqlite::SqliteConnectOptions;
|
||||
use sqlx::ConnectOptions;
|
||||
use std::str::FromStr;
|
||||
use thiserror::Error;
|
||||
|
||||
pub mod migrations;
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
pub enum MigrationError {
|
||||
#[error("sqlite connection error: {0}")]
|
||||
SqlxError(#[from] sqlx::Error),
|
||||
|
||||
#[error("refinery migration error: {0}")]
|
||||
RefineryError(#[from] refinery::Error),
|
||||
}
|
||||
|
||||
/// Run database migrations against the sqlite database.
|
||||
pub async fn migrate(db_path: &str) -> Result<(), MigrationError> {
|
||||
//Create database if missing.
|
||||
let conn = SqliteConnectOptions::from_str(&format!("sqlite://{}", db_path))?
|
||||
.create_if_missing(true)
|
||||
.connect()
|
||||
.await?;
|
||||
|
||||
drop(conn);
|
||||
|
||||
let mut conn = Config::new(ConfigDbType::Sqlite).set_db_path(db_path);
|
||||
info!("Running migrations");
|
||||
migrations::runner().run(&mut conn)?;
|
||||
Ok(())
|
||||
}
|
|
@ -0,0 +1,116 @@
|
|||
use async_trait::async_trait;
|
||||
use errors::DataError;
|
||||
use sqlx::sqlite::{SqliteConnectOptions, SqlitePool, SqlitePoolOptions};
|
||||
use sqlx::ConnectOptions;
|
||||
use std::clone::Clone;
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::str::FromStr;
|
||||
|
||||
use crate::models::RoomInfo;
|
||||
|
||||
pub mod errors;
|
||||
pub mod migrator;
|
||||
pub mod rooms;
|
||||
pub mod state;
|
||||
pub mod variables;
|
||||
|
||||
#[async_trait]
|
||||
pub(crate) trait DbState {
|
||||
async fn get_device_id(&self) -> Result<Option<String>, DataError>;
|
||||
|
||||
async fn set_device_id(&self, device_id: &str) -> Result<(), DataError>;
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
pub(crate) trait Rooms {
|
||||
async fn should_process(&self, room_id: &str, event_id: &str) -> Result<bool, DataError>;
|
||||
|
||||
async fn insert_room_info(&self, info: &RoomInfo) -> Result<(), DataError>;
|
||||
|
||||
async fn get_room_info(&self, room_id: &str) -> Result<Option<RoomInfo>, DataError>;
|
||||
|
||||
async fn get_rooms_for_user(&self, user_id: &str) -> Result<HashSet<String>, DataError>;
|
||||
|
||||
async fn get_users_in_room(&self, room_id: &str) -> Result<HashSet<String>, DataError>;
|
||||
|
||||
async fn add_user_to_room(&self, username: &str, room_id: &str) -> Result<(), DataError>;
|
||||
|
||||
async fn remove_user_from_room(&self, username: &str, room_id: &str) -> Result<(), DataError>;
|
||||
|
||||
async fn clear_info(&self, room_id: &str) -> Result<(), DataError>;
|
||||
}
|
||||
|
||||
// TODO move this up to the top once we delete sled. Traits will be the
|
||||
// main API, then we can have different impls for different DBs.
|
||||
#[async_trait]
|
||||
pub trait Variables {
|
||||
async fn get_user_variables(
|
||||
&self,
|
||||
user: &str,
|
||||
room_id: &str,
|
||||
) -> Result<HashMap<String, i32>, DataError>;
|
||||
|
||||
async fn get_variable_count(&self, user: &str, room_id: &str) -> Result<i32, DataError>;
|
||||
|
||||
async fn get_user_variable(
|
||||
&self,
|
||||
user: &str,
|
||||
room_id: &str,
|
||||
variable_name: &str,
|
||||
) -> Result<i32, DataError>;
|
||||
|
||||
async fn set_user_variable(
|
||||
&self,
|
||||
user: &str,
|
||||
room_id: &str,
|
||||
variable_name: &str,
|
||||
value: i32,
|
||||
) -> Result<(), DataError>;
|
||||
|
||||
async fn delete_user_variable(
|
||||
&self,
|
||||
user: &str,
|
||||
room_id: &str,
|
||||
variable_name: &str,
|
||||
) -> Result<(), DataError>;
|
||||
}
|
||||
|
||||
pub struct Database {
|
||||
conn: SqlitePool,
|
||||
}
|
||||
|
||||
impl Database {
|
||||
fn new_db(conn: SqlitePool) -> Result<Database, DataError> {
|
||||
let database = Database { conn: conn.clone() };
|
||||
Ok(database)
|
||||
}
|
||||
|
||||
pub async fn new(path: &str) -> Result<Database, DataError> {
|
||||
//Create database if missing.
|
||||
let conn = SqliteConnectOptions::from_str(path)?
|
||||
.create_if_missing(true)
|
||||
.connect()
|
||||
.await?;
|
||||
|
||||
drop(conn);
|
||||
|
||||
//Migrate database.
|
||||
migrator::migrate(&path).await?;
|
||||
|
||||
//Return actual conncetion pool.
|
||||
let conn = SqlitePoolOptions::new()
|
||||
.max_connections(5)
|
||||
.connect(path)
|
||||
.await?;
|
||||
|
||||
Self::new_db(conn)
|
||||
}
|
||||
}
|
||||
|
||||
impl Clone for Database {
|
||||
fn clone(&self) -> Self {
|
||||
Database {
|
||||
conn: self.conn.clone(),
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,379 @@
|
|||
use super::errors::DataError;
|
||||
use super::{Database, Rooms};
|
||||
use crate::models::RoomInfo;
|
||||
use async_trait::async_trait;
|
||||
use sqlx::SqlitePool;
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
|
||||
async fn record_event(conn: &SqlitePool, room_id: &str, event_id: &str) -> Result<(), DataError> {
|
||||
use std::convert::TryFrom;
|
||||
let now: i64 = i64::try_from(
|
||||
SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.expect("Clock has gone backwards")
|
||||
.as_secs(),
|
||||
)?;
|
||||
|
||||
sqlx::query(
|
||||
r#"INSERT INTO room_events
|
||||
(room_id, event_id, event_timestamp)
|
||||
VALUES (?, ?, ?)"#,
|
||||
)
|
||||
.bind(room_id)
|
||||
.bind(event_id)
|
||||
.bind(now)
|
||||
.execute(conn)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Rooms for Database {
|
||||
async fn should_process(&self, room_id: &str, event_id: &str) -> Result<bool, DataError> {
|
||||
let row = sqlx::query!(
|
||||
r#"SELECT event_id FROM room_events
|
||||
WHERE room_id = ? AND event_id = ?"#,
|
||||
room_id,
|
||||
event_id
|
||||
)
|
||||
.fetch_optional(&self.conn)
|
||||
.await?;
|
||||
|
||||
match row {
|
||||
Some(_) => Ok(false),
|
||||
None => {
|
||||
record_event(&self.conn, room_id, event_id).await?;
|
||||
Ok(true)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn insert_room_info(&self, info: &RoomInfo) -> Result<(), DataError> {
|
||||
sqlx::query(
|
||||
r#"INSERT INTO room_info (room_id, room_name) VALUES (?, ?)
|
||||
ON CONFLICT(room_id) DO UPDATE SET room_name = ?"#,
|
||||
)
|
||||
.bind(&info.room_id)
|
||||
.bind(&info.room_name)
|
||||
.bind(&info.room_name)
|
||||
.execute(&self.conn)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn get_room_info(&self, room_id: &str) -> Result<Option<RoomInfo>, DataError> {
|
||||
let info = sqlx::query!(
|
||||
r#"SELECT room_id, room_name FROM room_info
|
||||
WHERE room_id = ?"#,
|
||||
room_id
|
||||
)
|
||||
.fetch_optional(&self.conn)
|
||||
.await?;
|
||||
|
||||
Ok(info.map(|i| RoomInfo {
|
||||
room_id: i.room_id,
|
||||
room_name: i.room_name,
|
||||
}))
|
||||
}
|
||||
|
||||
async fn get_rooms_for_user(&self, user_id: &str) -> Result<HashSet<String>, DataError> {
|
||||
let room_ids = sqlx::query!(
|
||||
r#"SELECT room_id FROM room_users
|
||||
WHERE username = ?"#,
|
||||
user_id
|
||||
)
|
||||
.fetch_all(&self.conn)
|
||||
.await?;
|
||||
|
||||
Ok(room_ids.into_iter().map(|row| row.room_id).collect())
|
||||
}
|
||||
|
||||
async fn get_users_in_room(&self, room_id: &str) -> Result<HashSet<String>, DataError> {
|
||||
let usernames = sqlx::query!(
|
||||
r#"SELECT username FROM room_users
|
||||
WHERE room_id = ?"#,
|
||||
room_id
|
||||
)
|
||||
.fetch_all(&self.conn)
|
||||
.await?;
|
||||
|
||||
Ok(usernames.into_iter().map(|row| row.username).collect())
|
||||
}
|
||||
|
||||
async fn add_user_to_room(&self, username: &str, room_id: &str) -> Result<(), DataError> {
|
||||
sqlx::query(
|
||||
"INSERT INTO room_users (room_id, username) VALUES (?, ?)
|
||||
ON CONFLICT DO NOTHING",
|
||||
)
|
||||
.bind(room_id)
|
||||
.bind(username)
|
||||
.execute(&self.conn)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn remove_user_from_room(&self, username: &str, room_id: &str) -> Result<(), DataError> {
|
||||
sqlx::query("DELETE FROM room_users where username = ? AND room_id = ?")
|
||||
.bind(username)
|
||||
.bind(room_id)
|
||||
.execute(&self.conn)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn clear_info(&self, room_id: &str) -> Result<(), DataError> {
|
||||
// We do not clear event history here, because if we rejoin a
|
||||
// room, we would re-process events we've already seen.
|
||||
let mut tx = self.conn.begin().await?;
|
||||
|
||||
sqlx::query("DELETE FROM room_info where room_id = ?")
|
||||
.bind(room_id)
|
||||
.execute(&mut tx)
|
||||
.await?;
|
||||
|
||||
sqlx::query("DELETE FROM room_users where room_id = ?")
|
||||
.bind(room_id)
|
||||
.execute(&mut tx)
|
||||
.await?;
|
||||
|
||||
tx.commit().await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::super::Rooms;
|
||||
use super::*;
|
||||
|
||||
async fn create_db() -> Database {
|
||||
let db_path = tempfile::NamedTempFile::new_in(".").unwrap();
|
||||
crate::db::sqlite::migrator::migrate(db_path.path().to_str().unwrap())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
Database::new(db_path.path().to_str().unwrap())
|
||||
.await
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
|
||||
async fn should_process_test() {
|
||||
let db = create_db().await;
|
||||
|
||||
let first_check = db
|
||||
.should_process("myroom", "myeventid")
|
||||
.await
|
||||
.expect("should_process failed in first insert");
|
||||
|
||||
assert_eq!(first_check, true);
|
||||
|
||||
let second_check = db
|
||||
.should_process("myroom", "myeventid")
|
||||
.await
|
||||
.expect("should_process failed in first insert");
|
||||
|
||||
assert_eq!(second_check, false);
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
|
||||
async fn insert_and_get_room_info_test() {
|
||||
let db = create_db().await;
|
||||
|
||||
let info = RoomInfo {
|
||||
room_id: "myroomid".to_string(),
|
||||
room_name: "myroomname".to_string(),
|
||||
};
|
||||
|
||||
db.insert_room_info(&info)
|
||||
.await
|
||||
.expect("Could not insert room info.");
|
||||
|
||||
let retrieved_info = db
|
||||
.get_room_info("myroomid")
|
||||
.await
|
||||
.expect("Could not retrieve room info.");
|
||||
|
||||
assert!(retrieved_info.is_some());
|
||||
assert_eq!(info, retrieved_info.unwrap());
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
|
||||
async fn insert_room_info_updates_existing() {
|
||||
let db = create_db().await;
|
||||
|
||||
let info1 = RoomInfo {
|
||||
room_id: "myroomid".to_string(),
|
||||
room_name: "myroomname".to_string(),
|
||||
};
|
||||
|
||||
db.insert_room_info(&info1)
|
||||
.await
|
||||
.expect("Could not insert room info1.");
|
||||
|
||||
let info2 = RoomInfo {
|
||||
room_id: "myroomid".to_string(),
|
||||
room_name: "myroomname2".to_string(),
|
||||
};
|
||||
|
||||
db.insert_room_info(&info2)
|
||||
.await
|
||||
.expect("Could not update room info after first insert");
|
||||
|
||||
let retrieved_info = db
|
||||
.get_room_info("myroomid")
|
||||
.await
|
||||
.expect("Could not get room info");
|
||||
|
||||
assert!(retrieved_info.is_some());
|
||||
let retrieved_info = retrieved_info.unwrap();
|
||||
|
||||
assert_eq!(retrieved_info.room_id, "myroomid");
|
||||
assert_eq!(retrieved_info.room_name, "myroomname2");
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
|
||||
async fn add_user_to_room_test() {
|
||||
let db = create_db().await;
|
||||
|
||||
db.add_user_to_room("myuser", "myroom")
|
||||
.await
|
||||
.expect("Could not add user to room.");
|
||||
|
||||
let users_in_room = db
|
||||
.get_users_in_room("myroom")
|
||||
.await
|
||||
.expect("Could not get users in room.");
|
||||
|
||||
assert_eq!(users_in_room.len(), 1);
|
||||
assert!(users_in_room.contains("myuser"));
|
||||
|
||||
let rooms_for_user = db
|
||||
.get_rooms_for_user("myuser")
|
||||
.await
|
||||
.expect("Could not get rooms for user");
|
||||
|
||||
assert_eq!(rooms_for_user.len(), 1);
|
||||
assert!(rooms_for_user.contains("myroom"));
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
|
||||
async fn add_user_to_room_does_not_have_constraint_violation() {
|
||||
let db = create_db().await;
|
||||
|
||||
db.add_user_to_room("myuser", "myroom")
|
||||
.await
|
||||
.expect("Could not add user to room.");
|
||||
|
||||
let second_attempt = db.add_user_to_room("myuser", "myroom").await;
|
||||
|
||||
assert!(second_attempt.is_ok());
|
||||
|
||||
let users_in_room = db
|
||||
.get_users_in_room("myroom")
|
||||
.await
|
||||
.expect("Could not get users in room.");
|
||||
|
||||
assert_eq!(users_in_room.len(), 1);
|
||||
assert!(users_in_room.contains("myuser"));
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
|
||||
async fn remove_user_from_room_test() {
|
||||
let db = create_db().await;
|
||||
|
||||
db.add_user_to_room("myuser", "myroom")
|
||||
.await
|
||||
.expect("Could not add user to room.");
|
||||
|
||||
let remove_attempt = db.remove_user_from_room("myuser", "myroom").await;
|
||||
|
||||
assert!(remove_attempt.is_ok());
|
||||
|
||||
let users_in_room = db
|
||||
.get_users_in_room("myroom")
|
||||
.await
|
||||
.expect("Could not get users in room.");
|
||||
|
||||
assert_eq!(users_in_room.len(), 0);
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
|
||||
async fn clear_info_does_not_delete_other_rooms() {
|
||||
let db = create_db().await;
|
||||
|
||||
let info1 = RoomInfo {
|
||||
room_id: "myroomid".to_string(),
|
||||
room_name: "myroomname".to_string(),
|
||||
};
|
||||
|
||||
let info2 = RoomInfo {
|
||||
room_id: "myroomid2".to_string(),
|
||||
room_name: "myroomname2".to_string(),
|
||||
};
|
||||
|
||||
db.insert_room_info(&info1)
|
||||
.await
|
||||
.expect("Could not insert room info1.");
|
||||
|
||||
db.insert_room_info(&info2)
|
||||
.await
|
||||
.expect("Could not insert room info2.");
|
||||
|
||||
db.add_user_to_room("myuser", &info1.room_id)
|
||||
.await
|
||||
.expect("Could not add user to room.");
|
||||
|
||||
db.clear_info(&info1.room_id)
|
||||
.await
|
||||
.expect("Could not clear room info1");
|
||||
|
||||
let room_info2 = db
|
||||
.get_room_info(&info2.room_id)
|
||||
.await
|
||||
.expect("Could not get room info2.");
|
||||
|
||||
assert!(room_info2.is_some());
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
|
||||
async fn clear_info_test() {
|
||||
let db = create_db().await;
|
||||
|
||||
let info = RoomInfo {
|
||||
room_id: "myroomid".to_string(),
|
||||
room_name: "myroomname".to_string(),
|
||||
};
|
||||
|
||||
db.insert_room_info(&info)
|
||||
.await
|
||||
.expect("Could not insert room info.");
|
||||
|
||||
db.add_user_to_room("myuser", &info.room_id)
|
||||
.await
|
||||
.expect("Could not add user to room.");
|
||||
|
||||
db.clear_info(&info.room_id)
|
||||
.await
|
||||
.expect("Could not clear room info");
|
||||
|
||||
let users_in_room = db
|
||||
.get_users_in_room(&info.room_id)
|
||||
.await
|
||||
.expect("Could not get users in room.");
|
||||
|
||||
assert_eq!(users_in_room.len(), 0);
|
||||
|
||||
let room_info = db
|
||||
.get_room_info(&info.room_id)
|
||||
.await
|
||||
.expect("Could not get room info.");
|
||||
|
||||
assert!(room_info.is_none());
|
||||
}
|
||||
}
|
|
@ -0,0 +1,90 @@
|
|||
use super::errors::DataError;
|
||||
use super::{Database, DbState};
|
||||
use async_trait::async_trait;
|
||||
|
||||
#[async_trait]
|
||||
impl DbState for Database {
|
||||
async fn get_device_id(&self) -> Result<Option<String>, DataError> {
|
||||
let state = sqlx::query!(r#"SELECT device_id FROM bot_state limit 1"#)
|
||||
.fetch_optional(&self.conn)
|
||||
.await?;
|
||||
|
||||
Ok(state.map(|s| s.device_id))
|
||||
}
|
||||
|
||||
async fn set_device_id(&self, device_id: &str) -> Result<(), DataError> {
|
||||
// This will have to be updated if we ever add another column
|
||||
// to this table!
|
||||
sqlx::query("DELETE FROM bot_state")
|
||||
.execute(&self.conn)
|
||||
.await
|
||||
.ok();
|
||||
|
||||
sqlx::query(
|
||||
r#"INSERT INTO bot_state
|
||||
(device_id)
|
||||
VALUES (?)"#,
|
||||
)
|
||||
.bind(device_id)
|
||||
.execute(&self.conn)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::super::DbState;
|
||||
use super::*;
|
||||
|
||||
async fn create_db() -> Database {
|
||||
let db_path = tempfile::NamedTempFile::new_in(".").unwrap();
|
||||
crate::db::sqlite::migrator::migrate(db_path.path().to_str().unwrap())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
Database::new(db_path.path().to_str().unwrap())
|
||||
.await
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
|
||||
async fn set_and_get_device_id() {
|
||||
let db = create_db().await;
|
||||
|
||||
db.set_device_id("device_id")
|
||||
.await
|
||||
.expect("Could not set device ID");
|
||||
|
||||
let device_id = db.get_device_id().await.expect("Could not get device ID");
|
||||
|
||||
assert!(device_id.is_some());
|
||||
assert_eq!(device_id.unwrap(), "device_id");
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
|
||||
async fn no_device_id_set_returns_none() {
|
||||
let db = create_db().await;
|
||||
let device_id = db.get_device_id().await.expect("Could not get device ID");
|
||||
assert!(device_id.is_none());
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
|
||||
async fn can_update_device_id() {
|
||||
let db = create_db().await;
|
||||
|
||||
db.set_device_id("device_id")
|
||||
.await
|
||||
.expect("Could not set device ID");
|
||||
|
||||
db.set_device_id("device_id2")
|
||||
.await
|
||||
.expect("Could not set device ID");
|
||||
|
||||
let device_id = db.get_device_id().await.expect("Could not get device ID");
|
||||
|
||||
assert!(device_id.is_some());
|
||||
assert_eq!(device_id.unwrap(), "device_id2");
|
||||
}
|
||||
}
|
|
@ -0,0 +1,251 @@
|
|||
use super::errors::DataError;
|
||||
use super::{Database, Variables};
|
||||
use async_trait::async_trait;
|
||||
use std::collections::HashMap;
|
||||
|
||||
struct UserVariableRow {
|
||||
key: String,
|
||||
value: i32,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Variables for Database {
|
||||
async fn get_user_variables(
|
||||
&self,
|
||||
user: &str,
|
||||
room_id: &str,
|
||||
) -> Result<HashMap<String, i32>, DataError> {
|
||||
let rows = sqlx::query!(
|
||||
r#"SELECT key, value as "value: i32" FROM user_variables
|
||||
WHERE room_id = ? AND user_id = ?"#,
|
||||
room_id,
|
||||
user
|
||||
)
|
||||
.fetch_all(&self.conn)
|
||||
.await?;
|
||||
|
||||
Ok(rows.into_iter().map(|row| (row.key, row.value)).collect())
|
||||
}
|
||||
|
||||
async fn get_variable_count(&self, user: &str, room_id: &str) -> Result<i32, DataError> {
|
||||
let row = sqlx::query!(
|
||||
r#"SELECT count(*) as "count: i32" FROM user_variables
|
||||
WHERE room_id = ? and user_id = ?"#,
|
||||
room_id,
|
||||
user
|
||||
)
|
||||
.fetch_optional(&self.conn)
|
||||
.await?;
|
||||
|
||||
Ok(row.map(|r| r.count).unwrap_or(0))
|
||||
}
|
||||
|
||||
async fn get_user_variable(
|
||||
&self,
|
||||
user: &str,
|
||||
room_id: &str,
|
||||
variable_name: &str,
|
||||
) -> Result<i32, DataError> {
|
||||
let row = sqlx::query!(
|
||||
r#"SELECT value as "value: i32" FROM user_variables
|
||||
WHERE user_id = ? AND room_id = ? AND key = ?"#,
|
||||
user,
|
||||
room_id,
|
||||
variable_name
|
||||
)
|
||||
.fetch_optional(&self.conn)
|
||||
.await?;
|
||||
|
||||
row.map(|r| r.value)
|
||||
.ok_or_else(|| DataError::KeyDoesNotExist(variable_name.to_string()))
|
||||
}
|
||||
|
||||
async fn set_user_variable(
|
||||
&self,
|
||||
user: &str,
|
||||
room_id: &str,
|
||||
variable_name: &str,
|
||||
value: i32,
|
||||
) -> Result<(), DataError> {
|
||||
sqlx::query(
|
||||
"INSERT INTO user_variables
|
||||
(user_id, room_id, key, value)
|
||||
values (?, ?, ?, ?)",
|
||||
)
|
||||
.bind(user)
|
||||
.bind(room_id)
|
||||
.bind(variable_name)
|
||||
.bind(value)
|
||||
.execute(&self.conn)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn delete_user_variable(
|
||||
&self,
|
||||
user: &str,
|
||||
room_id: &str,
|
||||
variable_name: &str,
|
||||
) -> Result<(), DataError> {
|
||||
sqlx::query(
|
||||
"DELETE FROM user_variables
|
||||
WHERE user_id = ? AND room_id = ? AND key = ?",
|
||||
)
|
||||
.bind(user)
|
||||
.bind(room_id)
|
||||
.bind(variable_name)
|
||||
.execute(&self.conn)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::super::Variables;
|
||||
use super::*;
|
||||
|
||||
async fn create_db() -> Database {
|
||||
let db_path = tempfile::NamedTempFile::new_in(".").unwrap();
|
||||
crate::db::sqlite::migrator::migrate(db_path.path().to_str().unwrap())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
Database::new(db_path.path().to_str().unwrap())
|
||||
.await
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
|
||||
async fn set_and_get_variable_test() {
|
||||
use super::super::Variables;
|
||||
let db = create_db().await;
|
||||
|
||||
db.set_user_variable("myuser", "myroom", "myvariable", 1)
|
||||
.await
|
||||
.expect("Could not set variable");
|
||||
|
||||
let value = db
|
||||
.get_user_variable("myuser", "myroom", "myvariable")
|
||||
.await
|
||||
.expect("Could not get variable");
|
||||
|
||||
assert_eq!(value, 1);
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
|
||||
async fn get_missing_variable_test() {
|
||||
use super::super::Variables;
|
||||
let db = create_db().await;
|
||||
|
||||
let value = db.get_user_variable("myuser", "myroom", "myvariable").await;
|
||||
|
||||
assert!(value.is_err());
|
||||
assert!(matches!(
|
||||
value.err().unwrap(),
|
||||
DataError::KeyDoesNotExist(_)
|
||||
));
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
|
||||
async fn get_other_user_variable_test() {
|
||||
use super::super::Variables;
|
||||
let db = create_db().await;
|
||||
|
||||
db.set_user_variable("myuser1", "myroom", "myvariable", 1)
|
||||
.await
|
||||
.expect("Could not set variable");
|
||||
|
||||
let value = db
|
||||
.get_user_variable("myuser2", "myroom", "myvariable")
|
||||
.await;
|
||||
|
||||
assert!(value.is_err());
|
||||
assert!(matches!(
|
||||
value.err().unwrap(),
|
||||
DataError::KeyDoesNotExist(_)
|
||||
));
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
|
||||
async fn count_variables_test() {
|
||||
let db = create_db().await;
|
||||
|
||||
for variable_name in &["var1", "var2", "var3"] {
|
||||
db.set_user_variable("myuser", "myroom", variable_name, 1)
|
||||
.await
|
||||
.expect("Could not set variable");
|
||||
}
|
||||
|
||||
let count = db
|
||||
.get_variable_count("myuser", "myroom")
|
||||
.await
|
||||
.expect("Could not get count.");
|
||||
|
||||
assert_eq!(count, 3);
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
|
||||
async fn count_variables_respects_user_id() {
|
||||
let db = create_db().await;
|
||||
|
||||
for variable_name in &["var1", "var2", "var3"] {
|
||||
db.set_user_variable("different-user", "myroom", variable_name, 1)
|
||||
.await
|
||||
.expect("Could not set variable");
|
||||
}
|
||||
|
||||
let count = db
|
||||
.get_variable_count("myuser", "myroom")
|
||||
.await
|
||||
.expect("Could not get count.");
|
||||
|
||||
assert_eq!(count, 0);
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
|
||||
async fn count_variables_respects_room_id() {
|
||||
let db = create_db().await;
|
||||
|
||||
for variable_name in &["var1", "var2", "var3"] {
|
||||
db.set_user_variable("myuser", "different-room", variable_name, 1)
|
||||
.await
|
||||
.expect("Could not set variable");
|
||||
}
|
||||
|
||||
let count = db
|
||||
.get_variable_count("myuser", "myroom")
|
||||
.await
|
||||
.expect("Could not get count.");
|
||||
|
||||
assert_eq!(count, 0);
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
|
||||
async fn delete_variable_test() {
|
||||
let db = create_db().await;
|
||||
|
||||
for variable_name in &["var1", "var2", "var3"] {
|
||||
db.set_user_variable("myuser", "myroom", variable_name, 1)
|
||||
.await
|
||||
.expect("Could not set variable");
|
||||
}
|
||||
|
||||
db.delete_user_variable("myuser", "myroom", "var1")
|
||||
.await
|
||||
.expect("Could not delete variable.");
|
||||
|
||||
let count = db
|
||||
.get_variable_count("myuser", "myroom")
|
||||
.await
|
||||
.expect("Could not get count");
|
||||
|
||||
assert_eq!(count, 2);
|
||||
|
||||
let var1 = db.get_user_variable("myuser", "myroom", "var1").await;
|
||||
assert!(var1.is_err());
|
||||
assert!(matches!(var1.err().unwrap(), DataError::KeyDoesNotExist(_)));
|
||||
}
|
||||
}
|
|
@ -10,6 +10,8 @@ use std::str;
|
|||
use zerocopy::byteorder::I32;
|
||||
use zerocopy::AsBytes;
|
||||
|
||||
use super::errors;
|
||||
|
||||
pub(super) mod migrations;
|
||||
|
||||
#[derive(Clone)]
|
||||
|
@ -67,6 +69,9 @@ fn alter_room_variable_count(
|
|||
Ok(new_count)
|
||||
}
|
||||
|
||||
/// Room ID, Username, Variable Name
|
||||
pub type AllVariablesKey = (String, String, String);
|
||||
|
||||
impl Variables {
|
||||
pub(in crate::db) fn new(db: &sled::Db) -> Result<Variables, sled::Error> {
|
||||
Ok(Variables {
|
||||
|
@ -75,6 +80,40 @@ impl Variables {
|
|||
})
|
||||
}
|
||||
|
||||
pub fn get_all_variables(&self) -> Result<HashMap<AllVariablesKey, i32>, DataError> {
|
||||
use std::convert::TryFrom;
|
||||
let variables: Result<Vec<(AllVariablesKey, i32)>, DataError> = self
|
||||
.room_user_variables
|
||||
.scan_prefix("")
|
||||
.map(|entry| match entry {
|
||||
Ok((key, raw_value)) => {
|
||||
let keys: Vec<_> = key
|
||||
.split(|&b| b == 0xfe || b == 0xff)
|
||||
.map(|b| str::from_utf8(b))
|
||||
.collect();
|
||||
|
||||
if let &[Ok(room_id), Ok(username), Ok(variable_name), ..] = keys.as_slice() {
|
||||
Ok((
|
||||
(
|
||||
room_id.to_owned(),
|
||||
username.to_owned(),
|
||||
variable_name.to_owned(),
|
||||
),
|
||||
convert_i32(&raw_value)?,
|
||||
))
|
||||
} else {
|
||||
Err(errors::DataError::InvalidValue)
|
||||
}
|
||||
}
|
||||
Err(e) => Err(e.into()),
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Convert tuples to hash map with collect(), inferred via
|
||||
// return type.
|
||||
variables.map(|entries| entries.into_iter().collect())
|
||||
}
|
||||
|
||||
pub fn get_user_variables(
|
||||
&self,
|
||||
key: &UserAndRoom<'_>,
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
use crate::context::Context;
|
||||
use crate::db::sqlite::Variables;
|
||||
use crate::db::variables::UserAndRoom;
|
||||
use crate::error::BotError;
|
||||
use crate::error::DiceRollingError;
|
||||
|
@ -22,8 +23,10 @@ pub async fn calculate_single_die_amount(
|
|||
/// it cannot find a variable defined, or if the database errors.
|
||||
pub async fn calculate_dice_amount(amounts: &[Amount], ctx: &Context<'_>) -> Result<i32, BotError> {
|
||||
let stream = stream::iter(amounts);
|
||||
let key = UserAndRoom(&ctx.username, ctx.room_id().as_str());
|
||||
let variables = &ctx.db.variables.get_user_variables(&key)?;
|
||||
let variables = &ctx
|
||||
.db
|
||||
.get_user_variables(&ctx.username, ctx.room_id().as_str())
|
||||
.await?;
|
||||
|
||||
use DiceRollingError::VariableNotFound;
|
||||
let dice_amount: i32 = stream
|
||||
|
|
12
src/error.rs
12
src/error.rs
|
@ -1,6 +1,6 @@
|
|||
use crate::commands::CommandError;
|
||||
use crate::config::ConfigError;
|
||||
use crate::db::errors::DataError;
|
||||
use crate::{commands::CommandError, db::sqlite::migrator};
|
||||
use thiserror::Error;
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
|
@ -21,6 +21,9 @@ pub enum BotError {
|
|||
#[error("database error: {0}")]
|
||||
DataError(#[from] DataError),
|
||||
|
||||
#[error("sqlite database error: {0}")]
|
||||
SqliteDataError(#[from] crate::db::sqlite::errors::DataError),
|
||||
|
||||
#[error("the message should not be processed because it failed validation")]
|
||||
ShouldNotProcessError,
|
||||
|
||||
|
@ -33,10 +36,10 @@ pub enum BotError {
|
|||
#[error("error in matrix state store: {0}")]
|
||||
MatrixStateStoreError(#[from] matrix_sdk::StoreError),
|
||||
|
||||
#[error("uncategorized matrix SDK error")]
|
||||
#[error("uncategorized matrix SDK error: {0}")]
|
||||
MatrixError(#[from] matrix_sdk::Error),
|
||||
|
||||
#[error("uncategorized matrix SDK base error")]
|
||||
#[error("uncategorized matrix SDK base error: {0}")]
|
||||
MatrixBaseError(#[from] matrix_sdk::BaseError),
|
||||
|
||||
#[error("future canceled")]
|
||||
|
@ -73,6 +76,9 @@ pub enum BotError {
|
|||
#[error("database error")]
|
||||
DatabaseError(#[from] sled::Error),
|
||||
|
||||
#[error("database migration error: {0}")]
|
||||
SqliteError(#[from] migrator::MigrationError),
|
||||
|
||||
#[error("too many commands or message was too large")]
|
||||
MessageTooLarge,
|
||||
|
||||
|
|
36
src/logic.rs
36
src/logic.rs
|
@ -1,18 +1,24 @@
|
|||
use crate::db::errors::DataError;
|
||||
use crate::db::sqlite::errors::DataError;
|
||||
use crate::db::sqlite::Rooms;
|
||||
use crate::error::BotError;
|
||||
use crate::matrix;
|
||||
use crate::models::RoomInfo;
|
||||
use futures::stream::{self, StreamExt, TryStreamExt};
|
||||
use matrix_sdk::{self, identifiers::RoomId, Client};
|
||||
|
||||
/// Record the information about a room, including users in it.
|
||||
pub async fn record_room_information(
|
||||
client: &Client,
|
||||
db: &crate::db::Database,
|
||||
db: &crate::db::sqlite::Database,
|
||||
room_id: &RoomId,
|
||||
room_display_name: &str,
|
||||
our_username: &str,
|
||||
) -> Result<(), DataError> {
|
||||
) -> Result<(), BotError> {
|
||||
//Clear out any old room info first.
|
||||
db.clear_info(room_id.as_str()).await?;
|
||||
|
||||
let room_id_str = room_id.as_str();
|
||||
let usernames = matrix::get_users_in_room(&client, &room_id).await;
|
||||
let usernames = matrix::get_users_in_room(&client, &room_id).await?;
|
||||
|
||||
let info = RoomInfo {
|
||||
room_id: room_id_str.to_owned(),
|
||||
|
@ -21,11 +27,23 @@ pub async fn record_room_information(
|
|||
|
||||
// TODO this and the username adding should be one whole
|
||||
// transaction in the db.
|
||||
db.rooms.insert_room_info(&info)?;
|
||||
db.insert_room_info(&info).await?;
|
||||
|
||||
usernames
|
||||
let filtered_usernames = usernames
|
||||
.into_iter()
|
||||
.filter(|username| username != our_username)
|
||||
.map(|username| db.rooms.add_user_to_room(&username, room_id_str))
|
||||
.collect() //Make use of collect impl on Result.
|
||||
.filter(|username| username != our_username);
|
||||
|
||||
// Async collect into vec of results, then use into_iter of result
|
||||
// to go to from Result<Vec<()>> to just Result<()>. Easier than
|
||||
// attempting to async-collect our way to a single Result<()>.
|
||||
stream::iter(filtered_usernames)
|
||||
.then(|username| async move {
|
||||
db.add_user_to_room(&username, &room_id_str)
|
||||
.await
|
||||
.map_err(|e| e.into())
|
||||
})
|
||||
.collect::<Vec<Result<(), BotError>>>()
|
||||
.await
|
||||
.into_iter()
|
||||
.collect()
|
||||
}
|
||||
|
|
|
@ -20,16 +20,19 @@ fn extract_error_message(error: MatrixError) -> String {
|
|||
}
|
||||
|
||||
/// Retrieve a list of users in a given room.
|
||||
pub async fn get_users_in_room(client: &Client, room_id: &RoomId) -> Vec<String> {
|
||||
pub async fn get_users_in_room(
|
||||
client: &Client,
|
||||
room_id: &RoomId,
|
||||
) -> Result<Vec<String>, MatrixError> {
|
||||
if let Some(joined_room) = client.get_joined_room(room_id) {
|
||||
let members = joined_room.joined_members().await.ok().unwrap_or_default();
|
||||
let members = joined_room.joined_members().await?;
|
||||
|
||||
members
|
||||
Ok(members
|
||||
.into_iter()
|
||||
.map(|member| member.user_id().to_string())
|
||||
.collect()
|
||||
.collect())
|
||||
} else {
|
||||
vec![]
|
||||
Ok(vec![])
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -50,7 +53,7 @@ pub async fn send_message(
|
|||
));
|
||||
|
||||
content.relates_to = reply_to.map(|event_id| Relation::Reply {
|
||||
in_reply_to: InReplyTo { event_id },
|
||||
in_reply_to: InReplyTo::new(event_id),
|
||||
});
|
||||
|
||||
let content = AnyMessageEventContent::RoomMessage(content);
|
||||
|
|
|
@ -0,0 +1,16 @@
|
|||
use std::env;
|
||||
use tenebrous_dicebot::db::sqlite::migrator;
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), migrator::MigrationError> {
|
||||
let args: Vec<String> = env::args().collect();
|
||||
let db_path: &str = match &args[..] {
|
||||
[_, path] => path.as_ref(),
|
||||
[_, _, ..] => panic!("Expected exactly 0 or 1 argument"),
|
||||
_ => "dicebot.sqlite",
|
||||
};
|
||||
|
||||
println!("Using database: {}", db_path);
|
||||
|
||||
migrator::migrate(db_path).await
|
||||
}
|
Loading…
Reference in New Issue