diff --git a/src/bot.rs b/src/bot.rs index 7e26a8c..48506c5 100644 --- a/src/bot.rs +++ b/src/bot.rs @@ -120,20 +120,40 @@ impl DiceBot { }) } - /// Logs the bot into Matrix and listens for events until program + /// Logs in to matrix and potentially records a new device ID. If + /// no device ID is found in the database, a new one will be + /// generated by the matrix SDK, and we will store it. + async fn login(&self, client: &Client) -> Result<(), BotError> { + let username = self.config.matrix_username(); + let password = self.config.matrix_password(); + + // Pull device ID from database, if it exists. Then write it + // to DB if the library generated one for us. + let device_id: Option = self.db.state.get_device_id()?; + let device_id: Option<&str> = device_id.as_deref(); + + client + .login(username, password, device_id, Some("matrix dice bot")) + .await?; + + if device_id.is_none() { + let device_id = client.device_id().await.ok_or(BotError::NoDeviceIdFound)?; + self.db.state.set_device_id(device_id.as_str())?; + info!("Recorded new device ID: {}", device_id.as_str()); + } else { + info!("Using existing device ID: {}", device_id.unwrap()); + } + + info!("Logged in as {}", username); + Ok(()) + } + + /// Logs the bot in to Matrix and listens for events until program /// terminated, or a panic occurs. Originally adapted from the /// matrix-rust-sdk command bot example. pub async fn run(self) -> Result<(), BotError> { - let username = &self.config.matrix_username(); - let password = &self.config.matrix_password(); - - //TODO provide a device id from config. let client = self.client.clone(); - client - .login(username, password, None, Some("matrix dice bot")) - .await?; - - info!("Logged in as {}", username); + self.login(&client).await?; // Initial sync without event handler prevents responding to // messages received while bot was offline. TODO: selectively diff --git a/src/db.rs b/src/db.rs index cf8c05b..b50e92a 100644 --- a/src/db.rs +++ b/src/db.rs @@ -1,6 +1,7 @@ use crate::db::errors::{DataError, MigrationError}; use crate::db::migrations::{get_migration_version, Migrations}; use crate::db::rooms::Rooms; +use crate::db::state::DbState; use crate::db::variables::Variables; use log::info; use sled::{Config, Db}; @@ -11,6 +12,7 @@ pub mod errors; pub mod migrations; pub mod rooms; pub mod schema; +pub mod state; pub mod variables; #[derive(Clone)] @@ -19,6 +21,7 @@ pub struct Database { pub(crate) variables: Variables, pub(crate) migrations: Migrations, pub(crate) rooms: Rooms, + pub(crate) state: DbState, } impl Database { @@ -30,6 +33,7 @@ impl Database { variables: Variables::new(&db)?, migrations: Migrations(migrations), rooms: Rooms::new(&db)?, + state: DbState::new(&db)?, }; //Start any event handlers. diff --git a/src/db/errors.rs b/src/db/errors.rs index 6cfe7f7..8b1757f 100644 --- a/src/db/errors.rs +++ b/src/db/errors.rs @@ -29,8 +29,11 @@ pub enum DataError { #[error("unexpected or corruptd data bytes")] InvalidValue, + #[error("expected string ref, but utf8 schema was violated: {0}")] + Utf8RefSchemaViolation(#[from] std::str::Utf8Error), + #[error("expected string, but utf8 schema was violated: {0}")] - Utf8chemaViolation(#[from] std::str::Utf8Error), + Utf8SchemaViolation(#[from] std::string::FromUtf8Error), #[error("internal database error: {0}")] InternalError(#[from] sled::Error), diff --git a/src/db/state.rs b/src/db/state.rs new file mode 100644 index 0000000..cf0f84a --- /dev/null +++ b/src/db/state.rs @@ -0,0 +1,88 @@ +use crate::db::errors::DataError; +use sled::Tree; + +#[derive(Clone)] +pub struct DbState { + /// Tree of simple key-values for global state values that persist + /// between restarts (e.g. device ID). + pub(in crate::db) global_metadata: Tree, +} + +const DEVICE_ID_KEY: &'static [u8] = b"device_id"; + +impl DbState { + pub(in crate::db) fn new(db: &sled::Db) -> Result { + Ok(DbState { + global_metadata: db.open_tree("global_metadata")?, + }) + } + + pub fn get_device_id(&self) -> Result, DataError> { + self.global_metadata + .get(DEVICE_ID_KEY)? + .map(|v| String::from_utf8(v.to_vec())) + .transpose() + .map_err(|e| e.into()) + } + + pub fn set_device_id(&self, device_id: &str) -> Result<(), DataError> { + self.global_metadata + .insert(DEVICE_ID_KEY, device_id.as_bytes())?; + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use sled::Config; + + fn create_test_instance() -> DbState { + let config = Config::new().temporary(true); + let db = config.open().unwrap(); + DbState::new(&db).unwrap() + } + + #[test] + fn set_device_id_works() { + let state = create_test_instance(); + let result = state.set_device_id("test-device"); + assert!(result.is_ok()); + } + + #[test] + fn set_device_id_can_overwrite() { + let state = create_test_instance(); + state.set_device_id("test-device").expect("insert 1 failed"); + let result = state.set_device_id("test-device2"); + assert!(result.is_ok()); + } + + #[test] + fn get_device_id_returns_some_when_set() { + let state = create_test_instance(); + + state + .set_device_id("test-device") + .expect("could not store device id properly"); + + let device_id = state.get_device_id(); + + assert!(device_id.is_ok()); + + let device_id = device_id.unwrap(); + assert!(device_id.is_some()); + assert_eq!("test-device", device_id.unwrap()); + } + + #[test] + fn get_device_id_returns_none_when_unset() { + let state = create_test_instance(); + let device_id = state.get_device_id(); + assert!(device_id.is_ok()); + + let device_id = device_id.unwrap(); + assert!(device_id.is_none()); + } +} diff --git a/src/error.rs b/src/error.rs index a2f5540..39ae7de 100644 --- a/src/error.rs +++ b/src/error.rs @@ -12,6 +12,9 @@ pub enum BotError { #[error("the sync token could not be retrieved")] SyncTokenRequired, + #[error("could not retrieve device id")] + NoDeviceIdFound, + #[error("command error: {0}")] CommandError(#[from] CommandError),