diff --git a/src/bot.rs b/src/bot.rs index a687551..8928b2a 100644 --- a/src/bot.rs +++ b/src/bot.rs @@ -4,27 +4,25 @@ use crate::context::Context; use crate::db::Database; use crate::error::BotError; use crate::state::DiceBotState; -use async_trait::async_trait; use dirs; -use log::{debug, error, info, warn}; +use log::{error, info}; use matrix_sdk::Error as MatrixError; use matrix_sdk::{ self, events::{ - room::member::MemberEventContent, - room::message::{MessageEventContent, NoticeMessageEventContent, TextMessageEventContent}, - AnyMessageEventContent, StrippedStateEvent, SyncMessageEvent, + room::message::{MessageEventContent, NoticeMessageEventContent}, + AnyMessageEventContent, }, - Client, ClientConfig, EventEmitter, JsonStore, Room, SyncRoom, SyncSettings, + Client, ClientConfig, JsonStore, Room, SyncSettings, }; //use matrix_sdk_common_macros::async_trait; use std::clone::Clone; -use std::ops::Sub; use std::path::PathBuf; use std::sync::{Arc, RwLock}; -use std::time::{Duration, SystemTime}; use url::Url; +pub mod event_handlers; + /// The DiceBot struct represents an active dice bot. The bot is not /// connected to Matrix until its run() function is called. pub struct DiceBot { @@ -167,104 +165,3 @@ impl DiceBot { } } } - -/// Check if a message is recent enough to actually process. If the -/// message is within "oldest_message_age" seconds, this function -/// returns true. If it's older than that, it returns false and logs a -/// debug message. -fn check_message_age( - event: &SyncMessageEvent, - oldest_message_age: u64, -) -> bool { - let sending_time = event.origin_server_ts; - let oldest_timestamp = SystemTime::now().sub(Duration::new(oldest_message_age, 0)); - - if sending_time > oldest_timestamp { - true - } else { - let age = match oldest_timestamp.duration_since(sending_time) { - Ok(n) => format!("{} seconds too old", n.as_secs()), - Err(_) => "before the UNIX epoch".to_owned(), - }; - - debug!("Ignoring message because it is {}: {:?}", age, event); - false - } -} - -async fn should_process<'a>( - bot: &DiceBot, - event: &SyncMessageEvent, -) -> Result<(String, String), BotError> { - //Ignore messages that are older than configured duration. - if !check_message_age(event, bot.config.oldest_message_age()) { - let state_check = bot.state.read().unwrap(); - if !((*state_check).logged_skipped_old_messages()) { - drop(state_check); - let mut state = bot.state.write().unwrap(); - (*state).skipped_old_messages(); - } - - return Err(BotError::ShouldNotProcessError); - } - - let (msg_body, sender_username) = if let SyncMessageEvent { - content: MessageEventContent::Text(TextMessageEventContent { body, .. }), - sender, - .. - } = event - { - ( - body.clone(), - format!("@{}:{}", sender.localpart(), sender.server_name()), - ) - } else { - (String::new(), String::new()) - }; - - Ok((msg_body, sender_username)) -} - -/// This event emitter listens for messages with dice rolling commands. -/// Originally adapted from the matrix-rust-sdk examples. -#[async_trait] -impl EventEmitter for DiceBot { - async fn on_stripped_state_member( - &self, - room: SyncRoom, - room_member: &StrippedStateEvent, - _: Option, - ) { - if let SyncRoom::Invited(room) = room { - if let Some(user_id) = self.client.user_id().await { - if room_member.state_key != user_id { - return; - } - } - - let room = room.read().await; - info!("Autojoining room {}", room.display_name()); - - if let Err(e) = self.client.join_room_by_id(&room.room_id).await { - warn!("Could not join room: {}", e.to_string()) - } - } - } - - async fn on_room_message(&self, room: SyncRoom, event: &SyncMessageEvent) { - if let SyncRoom::Joined(room) = room { - let (msg_body, sender_username) = - if let Ok((msg_body, sender_username)) = should_process(self, &event).await { - (msg_body, sender_username) - } else { - return; - }; - - //we clone here to hold the lock for as little time as possible. - let real_room = room.read().await.clone(); - - self.execute_commands(&real_room, &sender_username, &msg_body) - .await; - } - } -} diff --git a/src/bot/event_handlers.rs b/src/bot/event_handlers.rs new file mode 100644 index 0000000..6b1b4c5 --- /dev/null +++ b/src/bot/event_handlers.rs @@ -0,0 +1,210 @@ +use crate::db::Database; +use crate::error::BotError; +use async_trait::async_trait; +use log::{debug, error, info, warn}; +use matrix_sdk::{ + self, + events::{ + room::member::{MemberEventContent, MembershipChange}, + room::message::{MessageEventContent, TextMessageEventContent}, + StrippedStateEvent, SyncMessageEvent, SyncStateEvent, + }, + identifiers::RoomId, + Client, EventEmitter, SyncRoom, +}; +//use matrix_sdk_common_macros::async_trait; +use super::DiceBot; +use std::clone::Clone; +use std::ops::Sub; +use std::time::{Duration, SystemTime}; + +/// Check if a message is recent enough to actually process. If the +/// message is within "oldest_message_age" seconds, this function +/// returns true. If it's older than that, it returns false and logs a +/// debug message. +fn check_message_age( + event: &SyncMessageEvent, + oldest_message_age: u64, +) -> bool { + let sending_time = event.origin_server_ts; + let oldest_timestamp = SystemTime::now().sub(Duration::new(oldest_message_age, 0)); + + if sending_time > oldest_timestamp { + true + } else { + let age = match oldest_timestamp.duration_since(sending_time) { + Ok(n) => format!("{} seconds too old", n.as_secs()), + Err(_) => "before the UNIX epoch".to_owned(), + }; + + debug!("Ignoring message because it is {}: {:?}", age, event); + false + } +} + +/// Determine whether or not to process a received message. This check +/// is necessary in addition to the event processing check because we +/// may receive message events when entering a room for the first +/// time, and we don't want to respond to things before the bot was in +/// the channel, but we do want to respond to things that were sent if +/// the bot left and rejoined quickly. +async fn should_process_message<'a>( + bot: &DiceBot, + event: &SyncMessageEvent, +) -> Result<(String, String), BotError> { + //Ignore messages that are older than configured duration. + if !check_message_age(event, bot.config.oldest_message_age()) { + let state_check = bot.state.read().unwrap(); + if !((*state_check).logged_skipped_old_messages()) { + drop(state_check); + let mut state = bot.state.write().unwrap(); + (*state).skipped_old_messages(); + } + + return Err(BotError::ShouldNotProcessError); + } + + let (msg_body, sender_username) = if let SyncMessageEvent { + content: MessageEventContent::Text(TextMessageEventContent { body, .. }), + sender, + .. + } = event + { + ( + body.clone(), + format!("@{}:{}", sender.localpart(), sender.server_name()), + ) + } else { + (String::new(), String::new()) + }; + + Ok((msg_body, sender_username)) +} + +fn should_process_event(db: &Database, room_id: &str, event_id: &str) -> bool { + db.rooms + .should_process(room_id, event_id) + .unwrap_or_else(|e| { + error!( + "Database error when checking if we should process an event: {}", + e.to_string() + ); + false + }) +} + +async fn get_users_in_room(client: &Client, room_id: &RoomId) -> Vec { + if let Some(joined_room) = client.get_joined_room(room_id).await { + let joined_room: matrix_sdk::Room = joined_room.read().await.clone(); + joined_room + .joined_members + .keys() + .map(|user_id| format!("@{}:{}", user_id.localpart(), user_id.server_name())) + .collect() + } else { + vec![] + } +} + +/// This event emitter listens for messages with dice rolling commands. +/// Originally adapted from the matrix-rust-sdk examples. +#[async_trait] +impl EventEmitter for DiceBot { + async fn on_room_member(&self, room: SyncRoom, event: &SyncStateEvent) { + if let SyncRoom::Joined(room) | SyncRoom::Left(room) = room { + //Clone to avoid holding lock. + let room = room.read().await.clone(); + let (room_id, username) = (room.room_id.as_str(), &event.state_key); + + if !should_process_event(&self.db, room_id, event.event_id.as_str()) { + return; + } + + let event_affects_us = if let Some(our_user_id) = self.client.user_id().await { + event.state_key == our_user_id + } else { + false + }; + + use MembershipChange::*; + let adding_user = match event.membership_change() { + Joined => true, + Banned | Left | Kicked | KickedAndBanned => false, + _ => return, + }; + + let result = if event_affects_us && !adding_user { + info!("Clearing all information for room ID {}", room_id); + self.db.rooms.clear_info(room_id) + } else if event_affects_us && adding_user { + info!("Joined room {}; recording user information", room_id); + let usernames = get_users_in_room(&self.client, &room.room_id).await; + usernames + .into_iter() + .filter(|username| username != &event.state_key) + .map(|username| self.db.rooms.add_user_to_room(&username, room_id)) + .collect() //Make use of collect impl on Result. + } else if !event_affects_us && adding_user { + info!("Adding user {} to room ID {}", username, room_id); + self.db.rooms.add_user_to_room(username, room_id) + } else if !event_affects_us && !adding_user { + info!("Removing user {} from room ID {}", username, room_id); + self.db.rooms.remove_user_from_room(username, room_id) + } else { + debug!("Ignoring a room member event: {:#?}", event); + Ok(()) + }; + + if let Err(e) = result { + error!("Could not update room information: {}", e.to_string()); + } else { + debug!("Successfully processed room member update."); + } + } + } + + async fn on_stripped_state_member( + &self, + room: SyncRoom, + event: &StrippedStateEvent, + _: Option, + ) { + if let SyncRoom::Invited(room) = room { + if let Some(user_id) = self.client.user_id().await { + if event.state_key != user_id { + return; + } + } + + //Clone to avoid holding lock. + let room = room.read().await.clone(); + info!("Autojoining room {}", room.display_name()); + + if let Err(e) = self.client.join_room_by_id(&room.room_id).await { + warn!("Could not join room: {}", e.to_string()) + } + } + } + + async fn on_room_message(&self, room: SyncRoom, event: &SyncMessageEvent) { + if let SyncRoom::Joined(room) = room { + //Clone to avoid holding lock. + let room = room.read().await.clone(); + let room_id = room.room_id.as_str(); + if !should_process_event(&self.db, room_id, event.event_id.as_str()) { + return; + } + + let (msg_body, sender_username) = if let Ok((msg_body, sender_username)) = + should_process_message(self, &event).await + { + (msg_body, sender_username) + } else { + return; + }; + + self.execute_commands(&room, &sender_username, &msg_body) + .await; + } + } +} diff --git a/src/db.rs b/src/db.rs index 62a94e6..cf8c05b 100644 --- a/src/db.rs +++ b/src/db.rs @@ -1,5 +1,6 @@ use crate::db::errors::{DataError, MigrationError}; use crate::db::migrations::{get_migration_version, Migrations}; +use crate::db::rooms::Rooms; use crate::db::variables::Variables; use log::info; use sled::{Config, Db}; @@ -8,6 +9,7 @@ use std::path::Path; pub mod data_migrations; pub mod errors; pub mod migrations; +pub mod rooms; pub mod schema; pub mod variables; @@ -16,17 +18,25 @@ pub struct Database { db: Db, pub(crate) variables: Variables, pub(crate) migrations: Migrations, + pub(crate) rooms: Rooms, } impl Database { fn new_db(db: sled::Db) -> Result { let migrations = db.open_tree("migrations")?; - Ok(Database { + let database = Database { db: db.clone(), variables: Variables::new(&db)?, migrations: Migrations(migrations), - }) + rooms: Rooms::new(&db)?, + }; + + //Start any event handlers. + database.rooms.start_handler(); + + info!("Opened database."); + Ok(database) } pub fn new>(path: P) -> Result { diff --git a/src/db/errors.rs b/src/db/errors.rs index cb25c9b..6cfe7f7 100644 --- a/src/db/errors.rs +++ b/src/db/errors.rs @@ -26,6 +26,9 @@ pub enum DataError { #[error("expected i32, but i32 schema was violated")] I32SchemaViolation, + #[error("unexpected or corruptd data bytes")] + InvalidValue, + #[error("expected string, but utf8 schema was violated: {0}")] Utf8chemaViolation(#[from] std::str::Utf8Error), @@ -62,13 +65,11 @@ impl From> for DataError { } } -// impl From> for DataError { -// fn from(error: ConflictableTransactionError) -> Self { -// match error { -// ConflictableTransactionError::Abort(data_err) => data_err, -// ConflictableTransactionError::Storage(storage_err) => { -// DataError::TransactionError(TransactionError::Storage(storage_err)) -// } -// } -// } -// } +/// Automatically aborts transactions that hit a DataError by using +/// the try (question mark) operator when this trait implementation is +/// in scope. +impl From for sled::transaction::ConflictableTransactionError { + fn from(error: DataError) -> Self { + sled::transaction::ConflictableTransactionError::Abort(error) + } +} diff --git a/src/db/rooms.rs b/src/db/rooms.rs new file mode 100644 index 0000000..02a80ea --- /dev/null +++ b/src/db/rooms.rs @@ -0,0 +1,430 @@ +use crate::db::errors::DataError; +use crate::db::schema::convert_u64; +use byteorder::BigEndian; +use log::{debug, error, log_enabled}; +use sled::transaction::TransactionalTree; +use sled::Transactional; +use sled::{CompareAndSwapError, Tree}; +use std::collections::HashSet; +use std::str; +use std::time::{SystemTime, UNIX_EPOCH}; +use tokio::task::JoinHandle; +use zerocopy::byteorder::U64; +use zerocopy::AsBytes; + +#[derive(Clone)] +pub struct Rooms { + /// Room ID -> RoomInfo struct (single entries) + pub(in crate::db) roomid_roominfo: Tree, + + /// Room ID -> list of usernames in room. + pub(in crate::db) roomid_usernames: Tree, + + /// Username -> list of room IDs user is in. + pub(in crate::db) username_roomids: Tree, + + /// Room ID(str) 0xff event ID(str) -> timestamp. Records event + /// IDs that we have received, so we do not process twice. + pub(in crate::db) roomeventid_timestamp: Tree, + + /// Room ID(str) 0xff timestamp(u64) -> event ID. Records event + /// IDs with timestamp as the primary key instead. Exists to allow + /// easy scanning of old roomeventid_timestamp records for + /// removal. Be careful with u64, it can have 0xff and 0xfe bytes. + /// A simple split on 0xff will not work with this key. Instead, + /// it is meant to be split on the first 0xff byte only, and + /// separated into room ID and timestamp. + pub(in crate::db) roomtimestamp_eventid: Tree, +} + +/// An enum that can hold either a regular sled Tree, or a +/// Transactional tree. +#[derive(Clone, Copy)] +enum TxableTree<'a> { + Tree(&'a Tree), + Tx(&'a TransactionalTree), +} + +impl<'a> Into> for &'a Tree { + fn into(self) -> TxableTree<'a> { + TxableTree::Tree(self) + } +} + +impl<'a> Into> for &'a TransactionalTree { + fn into(self) -> TxableTree<'a> { + TxableTree::Tx(self) + } +} + +/// A set of functions that can be used with a sled::Tree that stores +/// HashSets as its values. Atomicity is partially handled. If the +/// Tree is a transactional tree, operations will be atomic. +/// Otherwise, there is a potential non-atomic step. +mod hashset_tree { + use super::*; + + fn insert_set<'a, T: Into>>( + tree: T, + key: &[u8], + set: HashSet, + ) -> Result<(), DataError> { + let serialized = bincode::serialize(&set)?; + match tree.into() { + TxableTree::Tree(tree) => tree.insert(key, serialized)?, + TxableTree::Tx(tx) => tx.insert(key, serialized)?, + }; + Ok(()) + } + + pub(super) fn get_set<'a, T: Into>>( + tree: T, + key: &[u8], + ) -> Result, DataError> { + let set: HashSet = match tree.into() { + TxableTree::Tree(tree) => tree.get(key)?, + TxableTree::Tx(tx) => tx.get(key)?, + } + .map(|bytes| bincode::deserialize::>(&bytes)) + .unwrap_or(Ok(HashSet::new()))?; + + Ok(set) + } + + pub(super) fn remove_from_set<'a, T: Into> + Copy>( + tree: T, + key: &[u8], + value_to_remove: &str, + ) -> Result<(), DataError> { + let mut set = get_set(tree, key)?; + set.remove(value_to_remove); + insert_set(tree, key, set)?; + Ok(()) + } + + pub(super) fn add_to_set<'a, T: Into> + Copy>( + tree: T, + key: &[u8], + value_to_add: String, + ) -> Result<(), DataError> { + let mut set = get_set(tree, key)?; + set.insert(value_to_add); + insert_set(tree, key, set)?; + Ok(()) + } +} + +/// Functions that specifically relate to the "timestamp index" tree, +/// which is stored on the Rooms instance as a tree called +/// roomtimestamp_eventid. Tightly coupled to the event watcher in the +/// Rooms impl, and only factored out for unit testing. +mod timestamp_index { + use super::*; + + /// Insert an entry from the main roomeventid_timestamp Tree into + /// the timestamp index. Keys in this Tree are stored as room ID + /// 0xff timestamp, with the value being a hashset of event IDs + /// received at the time. The parameters come from an insert to + /// that Tree, where the key is room ID 0xff event ID, and the + /// value is the timestamp. + pub(super) fn insert( + roomtimestamp_eventid: &Tree, + key: &[u8], + timestamp_bytes: &[u8], + ) -> Result<(), DataError> { + let parts: Vec<&[u8]> = key.split(|&b| b == 0xff).collect(); + if let [room_id, event_id] = parts[..] { + let mut ts_key = room_id.to_vec(); + ts_key.push(0xff); + ts_key.extend_from_slice(×tamp_bytes); + log_index_record(room_id, event_id, ×tamp_bytes); + + let event_id = str::from_utf8(event_id)?; + hashset_tree::add_to_set(roomtimestamp_eventid, &ts_key, event_id.to_owned())?; + Ok(()) + } else { + Err(DataError::InvalidValue) + } + } + + /// Log a debug message. + fn log_index_record(room_id: &[u8], event_id: &[u8], timestamp: &[u8]) { + if log_enabled!(log::Level::Debug) { + debug!( + "Recording event {} | {} received at {} in timestamp index.", + str::from_utf8(room_id).unwrap_or("[invalid room id]"), + str::from_utf8(event_id).unwrap_or("[invalid event id]"), + convert_u64(timestamp).unwrap_or(0) + ); + } + } +} + +impl Rooms { + pub(in crate::db) fn new(db: &sled::Db) -> Result { + Ok(Rooms { + roomid_roominfo: db.open_tree("roomid_roominfo")?, + roomid_usernames: db.open_tree("roomid_usernames")?, + username_roomids: db.open_tree("username_roomids")?, + roomeventid_timestamp: db.open_tree("roomeventid_timestamp")?, + roomtimestamp_eventid: db.open_tree("roomtimestamp_eventid")?, + }) + } + + /// Start an event subscriber that listens for inserts made by the + /// `should_process` function. This event handler duplicates the + /// entry by timestamp instead of event ID. + pub(in crate::db) fn start_handler(&self) -> JoinHandle<()> { + //Clone due to lifetime requirements. + let roomeventid_timestamp = self.roomeventid_timestamp.clone(); + let roomtimestamp_eventid = self.roomtimestamp_eventid.clone(); + + tokio::spawn(async move { + let mut subscriber = roomeventid_timestamp.watch_prefix(b""); + + // TODO make this handler receive kill messages somehow so + // we can unit test it and gracefully shut it down. + while let Some(event) = (&mut subscriber).await { + if let sled::Event::Insert { key, value } = event { + match timestamp_index::insert(&roomtimestamp_eventid, &key, &value) { + Err(e) => { + error!("Unable to update the timestamp index: {}", e); + } + _ => (), + } + } + } + }) + } + + /// Determine if an event in a room should be processed. The event + /// is atomically recorded and true returned if the database has + /// not seen tis event yet. If the event already exists in the + /// database, the function returns false. Events are recorded by + /// this function by inserting the (system-local) timestamp in + /// epoch seconds. + pub fn should_process(&self, room_id: &str, event_id: &str) -> Result { + let mut key = room_id.as_bytes().to_vec(); + key.push(0xff); + key.extend_from_slice(event_id.as_bytes()); + + let timestamp: U64 = U64::new( + SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("Clock has gone backwards") + .as_secs(), + ); + + match self.roomeventid_timestamp.compare_and_swap( + key, + None as Option<&[u8]>, + Some(timestamp.as_bytes()), + )? { + Ok(()) => Ok(true), + Err(CompareAndSwapError { .. }) => Ok(false), + } + } + + pub fn get_rooms_for_user(&self, username: &str) -> Result, DataError> { + hashset_tree::get_set(&self.username_roomids, username.as_bytes()) + } + + pub fn get_users_in_room(&self, room_id: &str) -> Result, DataError> { + hashset_tree::get_set(&self.roomid_usernames, room_id.as_bytes()) + } + + pub fn add_user_to_room(&self, username: &str, room_id: &str) -> Result<(), DataError> { + debug!("Adding user {} to room {}", username, room_id); + (&self.username_roomids, &self.roomid_usernames).transaction( + |(tx_username_rooms, tx_room_usernames)| { + let username_key = &username.as_bytes(); + hashset_tree::add_to_set(tx_username_rooms, username_key, room_id.to_owned())?; + + let roomid_key = &room_id.as_bytes(); + hashset_tree::add_to_set(tx_room_usernames, roomid_key, username.to_owned())?; + + Ok(()) + }, + )?; + + Ok(()) + } + + pub fn remove_user_from_room(&self, username: &str, room_id: &str) -> Result<(), DataError> { + debug!("Removing user {} from room {}", username, room_id); + (&self.username_roomids, &self.roomid_usernames).transaction( + |(tx_username_rooms, tx_room_usernames)| { + let username_key = &username.as_bytes(); + hashset_tree::remove_from_set(tx_username_rooms, username_key, room_id)?; + + let roomid_key = &room_id.as_bytes(); + hashset_tree::remove_from_set(tx_room_usernames, roomid_key, username)?; + + Ok(()) + }, + )?; + + Ok(()) + } + + pub fn clear_info(&self, room_id: &str) -> Result<(), DataError> { + debug!("Clearing all information for room {}", room_id); + (&self.username_roomids, &self.roomid_usernames).transaction( + |(tx_username_roomids, tx_roomid_usernames)| { + let roomid_key = room_id.as_bytes(); + let users_in_room = hashset_tree::get_set(tx_roomid_usernames, roomid_key)?; + + //Remove the room ID from every user's room ID list. + for username in users_in_room { + hashset_tree::remove_from_set( + tx_username_roomids, + username.as_bytes(), + room_id, + )?; + } + + //Remove this room entry for the room ID -> username tree. + tx_roomid_usernames.remove(roomid_key)?; + + //TODO: delete roominfo struct from room info tree. + Ok(()) + }, + )?; + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use sled::Config; + + fn create_test_instance() -> Rooms { + let config = Config::new().temporary(true); + let db = config.open().unwrap(); + Rooms::new(&db).unwrap() + } + + #[test] + fn add_user_to_room() { + let rooms = create_test_instance(); + + rooms + .add_user_to_room("testuser", "myroom") + .expect("Could not add user to room"); + + let users_in_room = rooms + .get_users_in_room("myroom") + .expect("Could not retrieve users in room"); + + let rooms_for_user = rooms + .get_rooms_for_user("testuser") + .expect("Could not get rooms for user"); + + let expected_users_in_room: HashSet = + vec![String::from("testuser")].into_iter().collect(); + + let expected_rooms_for_user: HashSet = + vec![String::from("myroom")].into_iter().collect(); + + assert_eq!(expected_users_in_room, users_in_room); + assert_eq!(expected_rooms_for_user, rooms_for_user); + } + + #[test] + fn remove_user_from_room() { + let rooms = create_test_instance(); + + rooms + .add_user_to_room("testuser", "myroom") + .expect("Could not add user to room"); + + rooms + .remove_user_from_room("testuser", "myroom") + .expect("Could not remove user from room"); + + let users_in_room = rooms + .get_users_in_room("myroom") + .expect("Could not retrieve users in room"); + + let rooms_for_user = rooms + .get_rooms_for_user("testuser") + .expect("Could not get rooms for user"); + + assert_eq!(HashSet::new(), users_in_room); + assert_eq!(HashSet::new(), rooms_for_user); + } + + #[test] + fn clear_info() { + let rooms = create_test_instance(); + + rooms + .add_user_to_room("testuser", "myroom1") + .expect("Could not add user to room1"); + + rooms + .add_user_to_room("testuser", "myroom2") + .expect("Could not add user to room2"); + + rooms + .clear_info("myroom1") + .expect("Could not clear room info"); + + let users_in_room1 = rooms + .get_users_in_room("myroom1") + .expect("Could not retrieve users in room1"); + + let users_in_room2 = rooms + .get_users_in_room("myroom2") + .expect("Could not retrieve users in room2"); + + let rooms_for_user = rooms + .get_rooms_for_user("testuser") + .expect("Could not get rooms for user"); + + let expected_users_in_room2: HashSet = + vec![String::from("testuser")].into_iter().collect(); + + let expected_rooms_for_user: HashSet = + vec![String::from("myroom2")].into_iter().collect(); + + assert_eq!(HashSet::new(), users_in_room1); + assert_eq!(expected_users_in_room2, users_in_room2); + assert_eq!(expected_rooms_for_user, rooms_for_user); + } + + #[test] + fn insert_to_timestamp_index() { + let rooms = create_test_instance(); + + // Insertion into timestamp index based on data that would go + // into main room x eventID -> timestamp tree. + let mut key = b"myroom".to_vec(); + key.push(0xff); + key.extend_from_slice(b"myeventid"); + + let timestamp: U64 = U64::new(1000); + + let result = timestamp_index::insert( + &rooms.roomtimestamp_eventid, + key.as_bytes(), + timestamp.as_bytes(), + ); + + assert!(result.is_ok()); + + // Retrieval of data from the timestamp index tree. + let mut ts_key = b"myroom".to_vec(); + ts_key.push(0xff); + ts_key.extend_from_slice(timestamp.as_bytes()); + + let expected_events: HashSet = + vec![String::from("myeventid")].into_iter().collect(); + + let event_ids = hashset_tree::get_set(&rooms.roomtimestamp_eventid, &ts_key) + .expect("Could not get set out of Tree"); + assert_eq!(expected_events, event_ids); + } +} diff --git a/src/db/schema.rs b/src/db/schema.rs index 61eec0a..35a31ef 100644 --- a/src/db/schema.rs +++ b/src/db/schema.rs @@ -1,6 +1,6 @@ use crate::db::errors::DataError; -use byteorder::LittleEndian; -use zerocopy::byteorder::{I32, U32}; +use byteorder::{BigEndian, LittleEndian}; +use zerocopy::byteorder::{I32, U32, U64}; use zerocopy::LayoutVerified; /// User variables are stored as little-endian 32-bit integers in the @@ -10,6 +10,11 @@ type LittleEndianI32Layout<'a> = LayoutVerified<&'a [u8], I32>; type LittleEndianU32Layout<'a> = LayoutVerified<&'a [u8], U32>; +#[allow(dead_code)] +type LittleEndianU64Layout<'a> = LayoutVerified<&'a [u8], U64>; + +type BigEndianU64Layout<'a> = LayoutVerified<&'a [u8], U64>; + /// Convert bytes to an i32 with zero-copy deserialization. An error /// is returned if the bytes do not represent an i32. pub(super) fn convert_i32(raw_value: &[u8]) -> Result { @@ -33,3 +38,15 @@ pub(super) fn convert_u32(raw_value: &[u8]) -> Result { Err(DataError::I32SchemaViolation) } } + +#[allow(dead_code)] +pub(super) fn convert_u64(raw_value: &[u8]) -> Result { + let layout = BigEndianU64Layout::new_unaligned(raw_value.as_ref()); + + if let Some(layout) = layout { + let value: U64 = *layout; + Ok(value.get()) + } else { + Err(DataError::I32SchemaViolation) + } +} diff --git a/src/lib.rs b/src/lib.rs index 3b78482..50ed009 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -9,6 +9,7 @@ pub mod db; pub mod dice; pub mod error; mod help; +pub mod models; mod parser; pub mod state; pub mod variables; diff --git a/src/models.rs b/src/models.rs new file mode 100644 index 0000000..5d76f96 --- /dev/null +++ b/src/models.rs @@ -0,0 +1,5 @@ +/// RoomInfo has basic metadata about a room: its name, ID, etc. +pub struct RoomInfo { + pub room_id: String, + pub room_name: String, +}