diff --git a/src/bot.rs b/src/bot.rs index 0aaf6f4..8928b2a 100644 --- a/src/bot.rs +++ b/src/bot.rs @@ -4,27 +4,25 @@ use crate::context::Context; use crate::db::Database; use crate::error::BotError; use crate::state::DiceBotState; -use async_trait::async_trait; use dirs; -use log::{debug, error, info, warn}; +use log::{error, info}; use matrix_sdk::Error as MatrixError; use matrix_sdk::{ self, events::{ - room::member::{MemberEventContent, MembershipState}, - room::message::{MessageEventContent, NoticeMessageEventContent, TextMessageEventContent}, - AnyMessageEventContent, StrippedStateEvent, SyncMessageEvent, SyncStateEvent, + room::message::{MessageEventContent, NoticeMessageEventContent}, + AnyMessageEventContent, }, - Client, ClientConfig, EventEmitter, JsonStore, Room, SyncRoom, SyncSettings, + Client, ClientConfig, JsonStore, Room, SyncSettings, }; //use matrix_sdk_common_macros::async_trait; use std::clone::Clone; -use std::ops::Sub; use std::path::PathBuf; use std::sync::{Arc, RwLock}; -use std::time::{Duration, SystemTime}; use url::Url; +pub mod event_handlers; + /// The DiceBot struct represents an active dice bot. The bot is not /// connected to Matrix until its run() function is called. pub struct DiceBot { @@ -167,135 +165,3 @@ impl DiceBot { } } } - -/// Check if a message is recent enough to actually process. If the -/// message is within "oldest_message_age" seconds, this function -/// returns true. If it's older than that, it returns false and logs a -/// debug message. -fn check_message_age( - event: &SyncMessageEvent, - oldest_message_age: u64, -) -> bool { - let sending_time = event.origin_server_ts; - let oldest_timestamp = SystemTime::now().sub(Duration::new(oldest_message_age, 0)); - - if sending_time > oldest_timestamp { - true - } else { - let age = match oldest_timestamp.duration_since(sending_time) { - Ok(n) => format!("{} seconds too old", n.as_secs()), - Err(_) => "before the UNIX epoch".to_owned(), - }; - - debug!("Ignoring message because it is {}: {:?}", age, event); - false - } -} - -async fn should_process<'a>( - bot: &DiceBot, - event: &SyncMessageEvent, -) -> Result<(String, String), BotError> { - //Ignore messages that are older than configured duration. - if !check_message_age(event, bot.config.oldest_message_age()) { - let state_check = bot.state.read().unwrap(); - if !((*state_check).logged_skipped_old_messages()) { - drop(state_check); - let mut state = bot.state.write().unwrap(); - (*state).skipped_old_messages(); - } - - return Err(BotError::ShouldNotProcessError); - } - - let (msg_body, sender_username) = if let SyncMessageEvent { - content: MessageEventContent::Text(TextMessageEventContent { body, .. }), - sender, - .. - } = event - { - ( - body.clone(), - format!("@{}:{}", sender.localpart(), sender.server_name()), - ) - } else { - (String::new(), String::new()) - }; - - Ok((msg_body, sender_username)) -} - -/// This event emitter listens for messages with dice rolling commands. -/// Originally adapted from the matrix-rust-sdk examples. -#[async_trait] -impl EventEmitter for DiceBot { - async fn on_room_member( - &self, - room: SyncRoom, - room_member: &SyncStateEvent, - ) { - //When joining a channel, we get join events from other users. - //content is MemberContent, and it has a membership type. - - //Ignore if state_key is our username, because we only care about other users. - let event_affects_us = if let Some(our_user_id) = self.client.user_id().await { - room_member.state_key == our_user_id - } else { - false - }; - - let should_add = match room_member.content.membership { - MembershipState::Join => true, - MembershipState::Leave | MembershipState::Ban => false, - _ => return, - }; - - //if event affects us and is leave/ban, delete all our info. - //if event does not affect us, delete info only for that user. - - //TODO replace with call to new db.rooms thing. - println!( - "member {} recorded with action {:?} to/from db.", - room_member.state_key, should_add - ); - } - - async fn on_stripped_state_member( - &self, - room: SyncRoom, - room_member: &StrippedStateEvent, - _: Option, - ) { - if let SyncRoom::Invited(room) = room { - if let Some(user_id) = self.client.user_id().await { - if room_member.state_key != user_id { - return; - } - } - - let room = room.read().await; - info!("Autojoining room {}", room.display_name()); - - if let Err(e) = self.client.join_room_by_id(&room.room_id).await { - warn!("Could not join room: {}", e.to_string()) - } - } - } - - async fn on_room_message(&self, room: SyncRoom, event: &SyncMessageEvent) { - if let SyncRoom::Joined(room) = room { - let (msg_body, sender_username) = - if let Ok((msg_body, sender_username)) = should_process(self, &event).await { - (msg_body, sender_username) - } else { - return; - }; - - //we clone here to hold the lock for as little time as possible. - let real_room = room.read().await.clone(); - - self.execute_commands(&real_room, &sender_username, &msg_body) - .await; - } - } -} diff --git a/src/bot/event_handlers.rs b/src/bot/event_handlers.rs new file mode 100644 index 0000000..526bb64 --- /dev/null +++ b/src/bot/event_handlers.rs @@ -0,0 +1,163 @@ +use crate::error::BotError; +use async_trait::async_trait; +use log::{debug, error, info, warn}; +use matrix_sdk::{ + self, + events::{ + room::member::{MemberEventContent, MembershipState}, + room::message::{MessageEventContent, TextMessageEventContent}, + StrippedStateEvent, SyncMessageEvent, SyncStateEvent, + }, + EventEmitter, SyncRoom, +}; +//use matrix_sdk_common_macros::async_trait; +use super::DiceBot; +use std::clone::Clone; +use std::ops::Sub; +use std::time::{Duration, SystemTime}; + +/// Check if a message is recent enough to actually process. If the +/// message is within "oldest_message_age" seconds, this function +/// returns true. If it's older than that, it returns false and logs a +/// debug message. +fn check_message_age( + event: &SyncMessageEvent, + oldest_message_age: u64, +) -> bool { + let sending_time = event.origin_server_ts; + let oldest_timestamp = SystemTime::now().sub(Duration::new(oldest_message_age, 0)); + + if sending_time > oldest_timestamp { + true + } else { + let age = match oldest_timestamp.duration_since(sending_time) { + Ok(n) => format!("{} seconds too old", n.as_secs()), + Err(_) => "before the UNIX epoch".to_owned(), + }; + + debug!("Ignoring message because it is {}: {:?}", age, event); + false + } +} + +async fn should_process<'a>( + bot: &DiceBot, + event: &SyncMessageEvent, +) -> Result<(String, String), BotError> { + //Ignore messages that are older than configured duration. + if !check_message_age(event, bot.config.oldest_message_age()) { + let state_check = bot.state.read().unwrap(); + if !((*state_check).logged_skipped_old_messages()) { + drop(state_check); + let mut state = bot.state.write().unwrap(); + (*state).skipped_old_messages(); + } + + return Err(BotError::ShouldNotProcessError); + } + + let (msg_body, sender_username) = if let SyncMessageEvent { + content: MessageEventContent::Text(TextMessageEventContent { body, .. }), + sender, + .. + } = event + { + ( + body.clone(), + format!("@{}:{}", sender.localpart(), sender.server_name()), + ) + } else { + (String::new(), String::new()) + }; + + Ok((msg_body, sender_username)) +} + +/// This event emitter listens for messages with dice rolling commands. +/// Originally adapted from the matrix-rust-sdk examples. +#[async_trait] +impl EventEmitter for DiceBot { + async fn on_room_member( + &self, + room: SyncRoom, + room_member: &SyncStateEvent, + ) { + if let SyncRoom::Joined(room) = room { + let event_affects_us = if let Some(our_user_id) = self.client.user_id().await { + room_member.state_key == our_user_id + } else { + false + }; + + let adding_user = match room_member.content.membership { + MembershipState::Join => true, + MembershipState::Leave | MembershipState::Ban => false, + _ => return, + }; + + //Clone to avoid holding lock. + let room = room.read().await.clone(); + let (room_id, username) = (room.room_id.as_str(), &room_member.state_key); + + let result = if event_affects_us && !adding_user { + debug!("Clearing all information for room ID {}", room_id); + self.db.rooms.clear_info(room_id) + } else if !event_affects_us && adding_user { + debug!("Adding {} to room ID {}", username, room_id); + self.db.rooms.add_user_to_room(username, room_id) + } else if !event_affects_us && !adding_user { + debug!("Removing {} from room ID {}", username, room_id); + self.db.rooms.remove_user_from_room(username, room_id) + } else { + debug!("Ignoring a room member event: {:#?}", room_member); + Ok(()) + }; + + if let Err(e) = result { + error!("Could not update room information: {}", e.to_string()); + } else { + debug!("Successfully processed room member update."); + } + } + } + + async fn on_stripped_state_member( + &self, + room: SyncRoom, + room_member: &StrippedStateEvent, + _: Option, + ) { + if let SyncRoom::Invited(room) = room { + if let Some(user_id) = self.client.user_id().await { + if room_member.state_key != user_id { + return; + } + } + + //Clone to avoid holding lock. + let room = room.read().await.clone(); + info!("Autojoining room {}", room.display_name()); + + if let Err(e) = self.client.join_room_by_id(&room.room_id).await { + warn!("Could not join room: {}", e.to_string()) + } + } + } + + async fn on_room_message(&self, room: SyncRoom, event: &SyncMessageEvent) { + if let SyncRoom::Joined(room) = room { + let (msg_body, sender_username) = + if let Ok((msg_body, sender_username)) = should_process(self, &event).await { + (msg_body, sender_username) + } else { + return; + }; + + //we clone here to hold the lock for as little time as possible. + let real_room = room.read().await.clone(); + + self.execute_commands(&real_room, &sender_username, &msg_body) + .await; + } + } +} diff --git a/src/db/errors.rs b/src/db/errors.rs index cb25c9b..63fd2a5 100644 --- a/src/db/errors.rs +++ b/src/db/errors.rs @@ -62,13 +62,11 @@ impl From> for DataError { } } -// impl From> for DataError { -// fn from(error: ConflictableTransactionError) -> Self { -// match error { -// ConflictableTransactionError::Abort(data_err) => data_err, -// ConflictableTransactionError::Storage(storage_err) => { -// DataError::TransactionError(TransactionError::Storage(storage_err)) -// } -// } -// } -// } +/// Automatically aborts transactions that hit a DataError by using +/// the try (question mark) operator when this trait implementation is +/// in scope. +impl From for sled::transaction::ConflictableTransactionError { + fn from(error: DataError) -> Self { + sled::transaction::ConflictableTransactionError::Abort(error) + } +} diff --git a/src/db/rooms.rs b/src/db/rooms.rs index 3d507fb..61ccd57 100644 --- a/src/db/rooms.rs +++ b/src/db/rooms.rs @@ -1,15 +1,9 @@ use crate::db::errors::DataError; -use crate::db::schema::convert_i32; -use byteorder::LittleEndian; -use sled::transaction::{abort, TransactionalTree}; +use sled::transaction::TransactionalTree; use sled::Transactional; use sled::Tree; -use std::collections::HashMap; use std::collections::HashSet; -use std::convert::From; use std::str; -use zerocopy::byteorder::I32; -use zerocopy::AsBytes; #[derive(Clone)] pub struct Rooms { @@ -45,6 +39,46 @@ pub struct Rooms { // to_vec(value) // } // } +enum TxableTree<'a> { + Tree(&'a Tree), + Tx(&'a TransactionalTree), +} + +impl<'a> Into> for &'a Tree { + fn into(self) -> TxableTree<'a> { + TxableTree::Tree(self) + } +} + +impl<'a> Into> for &'a TransactionalTree { + fn into(self) -> TxableTree<'a> { + TxableTree::Tx(self) + } +} + +fn get_set<'a, T: Into>>(tree: T, key: &[u8]) -> Result, DataError> { + let set: HashSet = match tree.into() { + TxableTree::Tree(tree) => tree.get(key)?, + TxableTree::Tx(tx) => tx.get(key)?, + } + .map(|bytes| bincode::deserialize::>(&bytes)) + .unwrap_or(Ok(HashSet::new()))?; + + Ok(set) +} + +fn insert_set<'a, T: Into>>( + tree: T, + key: &[u8], + set: HashSet, +) -> Result<(), DataError> { + let serialized = bincode::serialize(&set)?; + match tree.into() { + TxableTree::Tree(tree) => tree.insert(key, serialized)?, + TxableTree::Tx(tx) => tx.insert(key, serialized)?, + }; + Ok(()) +} impl Rooms { pub(in crate::db) fn new(db: &sled::Db) -> Result { @@ -55,28 +89,121 @@ impl Rooms { }) } - pub fn add_user_to_room(&self, username: &str, room_id: &str) -> Result<(), DataError> { - //in txn: - //get or create list of users in room - //get or create list of rooms user is in - //deserialize/create set and add username to set for roomid - //deserialize/create set and add roomid to set for username - //store both again - let user_to_rooms: HashSet = self - .username_roomids - .get(username.as_bytes())? - .map(|bytes| bincode::deserialize::>(&bytes)) - .unwrap_or(Ok(HashSet::new()))?; + pub fn get_rooms_for_user(&self, username: &str) -> Result, DataError> { + get_set(&self.username_roomids, username.as_bytes()) + } + + pub fn get_users_in_room(&self, room_id: &str) -> Result, DataError> { + get_set(&self.roomid_usernames, room_id.as_bytes()) + } + + pub fn add_user_to_room(&self, username: &str, room_id: &str) -> Result<(), DataError> { + (&self.username_roomids, &self.roomid_usernames).transaction( + |(tx_username_rooms, tx_room_usernames)| { + let username_key = &username.as_bytes(); + let mut user_to_rooms = get_set(tx_username_rooms, username_key)?; + user_to_rooms.insert(room_id.to_string()); + insert_set(tx_username_rooms, username_key, user_to_rooms)?; + + let roomid_key = &room_id.as_bytes(); + let mut room_to_users = get_set(tx_room_usernames, roomid_key)?; + room_to_users.insert(username.to_string()); + insert_set(tx_room_usernames, roomid_key, room_to_users)?; + + Ok(()) + }, + )?; - let room_to_users: HashSet = self - .roomid_usernames - .get(room_id.as_bytes())? - .map(|bytes| bincode::deserialize::>(&bytes)) - .unwrap_or(Ok(HashSet::new()))?; Ok(()) } pub fn remove_user_from_room(&self, username: &str, room_id: &str) -> Result<(), DataError> { + (&self.username_roomids, &self.roomid_usernames).transaction( + |(tx_username_rooms, tx_room_usernames)| { + let username_key = &username.as_bytes(); + let mut user_to_rooms = get_set(tx_username_rooms, username_key)?; + user_to_rooms.remove(room_id); + insert_set(tx_username_rooms, username_key, user_to_rooms)?; + + let roomid_key = &room_id.as_bytes(); + let mut room_to_users = get_set(tx_room_usernames, roomid_key)?; + room_to_users.remove(username); + insert_set(tx_room_usernames, roomid_key, room_to_users)?; + + Ok(()) + }, + )?; + + Ok(()) + } + + pub fn clear_info(&self, _room_id: &str) -> Result<(), DataError> { + //TODO implement me + //when bot leaves a room, it must, atomically: + // - delete roominfo struct from room info tree. + // - load list of users it knows about in room. + // - remove room id from every user's list. (cannot reuse existing fn because atomicity) + // - delete list of users in room from tree. Ok(()) } } + +#[cfg(test)] +mod tests { + use super::*; + use sled::Config; + + fn create_test_instance() -> Rooms { + let config = Config::new().temporary(true); + let db = config.open().unwrap(); + Rooms::new(&db).unwrap() + } + + #[test] + fn add_user_to_room() { + let rooms = create_test_instance(); + rooms + .add_user_to_room("testuser", "myroom") + .expect("Could not add user to room"); + + let users_in_room = rooms + .get_users_in_room("myroom") + .expect("Could not retrieve users in room"); + + let rooms_for_user = rooms + .get_rooms_for_user("testuser") + .expect("Could not get rooms for user"); + + let expected_users_in_room: HashSet = + vec![String::from("testuser")].into_iter().collect(); + + let expected_rooms_for_user: HashSet = + vec![String::from("myroom")].into_iter().collect(); + + assert_eq!(expected_users_in_room, users_in_room); + assert_eq!(expected_rooms_for_user, rooms_for_user); + } + + #[test] + fn remove_user_from_room() { + let rooms = create_test_instance(); + rooms + .add_user_to_room("testuser", "myroom") + .expect("Could not add user to room"); + + rooms + .remove_user_from_room("testuser", "myroom") + .expect("Could not remove user from room"); + + let users_in_room = rooms + .get_users_in_room("myroom") + .expect("Could not retrieve users in room"); + + let rooms_for_user = rooms + .get_rooms_for_user("testuser") + .expect("Could not get rooms for user"); + + assert_eq!(HashSet::new(), users_in_room); + assert_eq!(HashSet::new(), rooms_for_user); + } +} diff --git a/src/lib.rs b/src/lib.rs index 3b78482..50ed009 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -9,6 +9,7 @@ pub mod db; pub mod dice; pub mod error; mod help; +pub mod models; mod parser; pub mod state; pub mod variables; diff --git a/src/models.rs b/src/models.rs new file mode 100644 index 0000000..5d76f96 --- /dev/null +++ b/src/models.rs @@ -0,0 +1,5 @@ +/// RoomInfo has basic metadata about a room: its name, ID, etc. +pub struct RoomInfo { + pub room_id: String, + pub room_name: String, +}