diff --git a/src/bin/dicebot-cmd.rs b/src/bin/dicebot-cmd.rs index 3f63dd8..ed956f2 100644 --- a/src/bin/dicebot-cmd.rs +++ b/src/bin/dicebot-cmd.rs @@ -1,4 +1,5 @@ use matrix_sdk::identifiers::room_id; +use matrix_sdk::Client; use tenebrous_dicebot::commands; use tenebrous_dicebot::commands::ResponseExtractor; use tenebrous_dicebot::context::{Context, RoomContext}; @@ -26,11 +27,15 @@ async fn main() -> Result<(), BotError> { .await?; let context = Context { - db: db, + db, account: Account::default(), - matrix_client: &matrix_sdk::Client::new(homeserver) - .expect("Could not create matrix client"), - room: RoomContext { + matrix_client: Client::new(homeserver).expect("Could not create matrix client"), + origin_room: RoomContext { + id: &room_id!("!fakeroomid:example.com"), + display_name: "fake room".to_owned(), + secure: false, + }, + active_room: RoomContext { id: &room_id!("!fakeroomid:example.com"), display_name: "fake room".to_owned(), secure: false, diff --git a/src/bot/command_execution.rs b/src/bot/command_execution.rs index 1c72d1c..ad623d8 100644 --- a/src/bot/command_execution.rs +++ b/src/bot/command_execution.rs @@ -1,12 +1,21 @@ -use crate::commands::{execute_command, ExecutionResult, ResponseExtractor}; use crate::context::{Context, RoomContext}; use crate::db::sqlite::Database; use crate::error::BotError; use crate::logic; use crate::matrix; +use crate::{ + commands::{execute_command, ExecutionResult, ResponseExtractor}, + models::Account, +}; use futures::stream::{self, StreamExt}; -use matrix_sdk::{self, identifiers::EventId, room::Joined, Client}; +use matrix_sdk::{ + self, + identifiers::{EventId, RoomId}, + room::Joined, + Client, +}; use std::clone::Clone; +use std::convert::TryFrom; /// Handle responding to a single command being executed. Wil print /// out the full result of that command. @@ -95,24 +104,57 @@ pub(super) async fn handle_multiple_results( matrix::send_message(client, room.room_id(), (&message, &plain), None).await; } -/// Create a context for command execution. Can fai if the room -/// context creation fails. -async fn create_context<'a>( - db: &'a Database, - client: &'a Client, - room: &'a Joined, - sender: &'a str, - command: &'a str, -) -> Result, BotError> { - let room_ctx = RoomContext::new(room, sender).await?; - Ok(Context { +/// Map an account's active room value to an actual matrix room, if +/// the account has an active room. This only retrieves the +/// user-specified active room, and doesn't perform any further +/// filtering. +fn get_account_active_room(client: &Client, account: &Account) -> Result, BotError> { + let active_room = account + .registered_user() + .and_then(|u| u.active_room.as_deref()) + .map(|room_id| RoomId::try_from(room_id)) + .transpose()? + .and_then(|active_room_id| client.get_joined_room(&active_room_id)); + + Ok(active_room) +} + +/// Execute a single command in the list of commands. Can fail if the +/// Account value cannot be created/fetched from the database, or if +/// room display names cannot be calculated. Otherwise, the success or +/// error of command execution itself is returned. +async fn execute_single_command( + command: &str, + db: &Database, + client: &Client, + origin_room: &Joined, + sender: &str, +) -> ExecutionResult { + let origin_ctx = RoomContext::new(origin_room, sender).await?; + let account = logic::get_account(db, sender).await?; + let active_room = get_account_active_room(client, &account)?; + + // Active room is used in secure command-issuing rooms. In + // "public" rooms, where other users are, treat origin as the + // active room. + let active_room = active_room + .as_ref() + .filter(|_| origin_ctx.secure) + .unwrap_or(origin_room); + + let active_ctx = RoomContext::new(active_room, sender).await?; + + let ctx = Context { + account, db: db.clone(), - matrix_client: client, - room: room_ctx, + matrix_client: client.clone(), + origin_room: origin_ctx, username: &sender, - account: logic::get_account(db, &sender).await?, + active_room: active_ctx, message_body: &command, - }) + }; + + execute_command(&ctx).await } /// Attempt to execute all commands sent to the bot in a message. This @@ -127,13 +169,8 @@ pub(super) async fn execute( ) -> Vec<(String, ExecutionResult)> { stream::iter(commands) .then(|command| async move { - match create_context(db, client, room, sender, command).await { - Err(e) => (command.to_owned(), Err(e)), - Ok(ctx) => { - let cmd_result = execute_command(&ctx).await; - (command.to_owned(), cmd_result) - } - } + let result = execute_single_command(command, db, client, room, sender).await; + (command.to_owned(), result) }) .collect() .await diff --git a/src/cofd/dice.rs b/src/cofd/dice.rs index 976ec51..a218375 100644 --- a/src/cofd/dice.rs +++ b/src/cofd/dice.rs @@ -485,8 +485,9 @@ mod tests { let ctx = Context { account: crate::models::Account::default(), db: db, - matrix_client: &matrix_sdk::Client::new(homeserver).unwrap(), - room: dummy_room!(), + matrix_client: matrix_sdk::Client::new(homeserver).unwrap(), + origin_room: dummy_room!(), + active_room: dummy_room!(), username: "username", message_body: "message", }; @@ -526,8 +527,9 @@ mod tests { let ctx = Context { account: crate::models::Account::default(), db: db, - matrix_client: &matrix_sdk::Client::new(homeserver).unwrap(), - room: dummy_room!(), + matrix_client: matrix_sdk::Client::new(homeserver).unwrap(), + origin_room: dummy_room!(), + active_room: dummy_room!(), username: "username", message_body: "message", }; @@ -564,15 +566,21 @@ mod tests { let ctx = Context { account: crate::models::Account::default(), db: db.clone(), - matrix_client: &matrix_sdk::Client::new(homeserver).unwrap(), - room: dummy_room!(), + matrix_client: matrix_sdk::Client::new(homeserver).unwrap(), + origin_room: dummy_room!(), + active_room: dummy_room!(), username: "username", message_body: "message", }; - db.set_user_variable(&ctx.username, &ctx.room.id.as_str(), "myvariable", 10) - .await - .expect("could not set myvariable to 10"); + db.set_user_variable( + &ctx.username, + &ctx.origin_room.id.as_str(), + "myvariable", + 10, + ) + .await + .expect("could not set myvariable to 10"); let amounts = vec![Amount { operator: Operator::Plus, diff --git a/src/commands/mod.rs b/src/commands/mod.rs index 33f9acf..41042c5 100644 --- a/src/commands/mod.rs +++ b/src/commands/mod.rs @@ -146,13 +146,13 @@ fn log_command(cmd: &(impl Command + ?Sized), ctx: &Context, result: &ExecutionR Ok(_) => { info!( "[{}] {} <{}{}> - success", - ctx.room.display_name, ctx.username, command, dots + ctx.origin_room.display_name, ctx.username, command, dots ); } Err(e) => { error!( "[{}] {} <{}{}> - {}", - ctx.room.display_name, ctx.username, command, dots, e + ctx.origin_room.display_name, ctx.username, command, dots, e ); } }; @@ -196,8 +196,9 @@ mod tests { let ctx = Context { account: crate::models::Account::default(), db: db, - matrix_client: &matrix_sdk::Client::new(homeserver).unwrap(), - room: secure_room!(), + matrix_client: matrix_sdk::Client::new(homeserver).unwrap(), + origin_room: secure_room!(), + active_room: secure_room!(), username: "myusername", message_body: "!notacommand", }; @@ -218,8 +219,9 @@ mod tests { let ctx = Context { account: crate::models::Account::default(), db: db, - matrix_client: &matrix_sdk::Client::new(homeserver).unwrap(), - room: secure_room!(), + matrix_client: matrix_sdk::Client::new(homeserver).unwrap(), + origin_room: secure_room!(), + active_room: secure_room!(), username: "myusername", message_body: "!notacommand", }; @@ -240,8 +242,9 @@ mod tests { let ctx = Context { account: crate::models::Account::default(), db: db, - matrix_client: &matrix_sdk::Client::new(homeserver).unwrap(), - room: dummy_room!(), + matrix_client: matrix_sdk::Client::new(homeserver).unwrap(), + origin_room: dummy_room!(), + active_room: dummy_room!(), username: "myusername", message_body: "!notacommand", }; @@ -262,8 +265,9 @@ mod tests { let ctx = Context { account: crate::models::Account::default(), db: db, - matrix_client: &matrix_sdk::Client::new(homeserver).unwrap(), - room: dummy_room!(), + matrix_client: matrix_sdk::Client::new(homeserver).unwrap(), + origin_room: dummy_room!(), + active_room: dummy_room!(), username: "myusername", message_body: "!notacommand", }; @@ -284,8 +288,9 @@ mod tests { let ctx = Context { account: crate::models::Account::default(), db: db, - matrix_client: &matrix_sdk::Client::new(homeserver).unwrap(), - room: dummy_room!(), + matrix_client: matrix_sdk::Client::new(homeserver).unwrap(), + origin_room: dummy_room!(), + active_room: dummy_room!(), username: "myusername", message_body: "!notacommand", }; diff --git a/src/commands/rooms.rs b/src/commands/rooms.rs index 48b90f6..3cb3761 100644 --- a/src/commands/rooms.rs +++ b/src/commands/rooms.rs @@ -110,7 +110,7 @@ impl Command for ListRoomsCommand { } async fn execute(&self, ctx: &Context<'_>) -> ExecutionResult { - let rooms_for_user: Vec = get_rooms_for_user(ctx.matrix_client, ctx.username) + let rooms_for_user: Vec = get_rooms_for_user(&ctx.matrix_client, ctx.username) .await .map(|rooms| { rooms @@ -155,7 +155,7 @@ impl Command for SetRoomCommand { return Err(BotError::AccountDoesNotExist); } - let rooms_for_user = get_rooms_for_user(ctx.matrix_client, ctx.username).await?; + let rooms_for_user = get_rooms_for_user(&ctx.matrix_client, ctx.username).await?; let room = search_for_room(&rooms_for_user, &self.0); if let Some(room) = room { diff --git a/src/commands/variables.rs b/src/commands/variables.rs index 6b1a242..26124d5 100644 --- a/src/commands/variables.rs +++ b/src/commands/variables.rs @@ -35,7 +35,7 @@ impl Command for GetAllVariablesCommand { async fn execute(&self, ctx: &Context<'_>) -> ExecutionResult { let variables = ctx .db - .get_user_variables(&ctx.username, ctx.room_id().as_str()) + .get_user_variables(&ctx.username, ctx.active_room_id().as_str()) .await?; let mut variable_list: Vec = variables @@ -85,7 +85,7 @@ impl Command for GetVariableCommand { let name = &self.0; let result = ctx .db - .get_user_variable(&ctx.username, ctx.room_id().as_str(), name) + .get_user_variable(&ctx.username, ctx.active_room_id().as_str(), name) .await; let value = match result { @@ -131,7 +131,7 @@ impl Command for SetVariableCommand { let value = self.1; ctx.db - .set_user_variable(&ctx.username, ctx.room_id().as_str(), name, value) + .set_user_variable(&ctx.username, ctx.active_room_id().as_str(), name, value) .await?; let content = format!("{} = {}", name, value); @@ -170,7 +170,7 @@ impl Command for DeleteVariableCommand { let name = &self.0; let result = ctx .db - .delete_user_variable(&ctx.username, ctx.room_id().as_str(), name) + .delete_user_variable(&ctx.username, ctx.active_room_id().as_str(), name) .await; let value = match result { diff --git a/src/context.rs b/src/context.rs index 4dfe669..20b0f6f 100644 --- a/src/context.rs +++ b/src/context.rs @@ -11,20 +11,25 @@ use std::convert::TryFrom; #[derive(Clone)] pub struct Context<'a> { pub db: Database, - pub matrix_client: &'a Client, - pub room: RoomContext<'a>, + pub matrix_client: Client, + pub origin_room: RoomContext<'a>, + pub active_room: RoomContext<'a>, pub username: &'a str, pub message_body: &'a str, pub account: Account, } impl Context<'_> { + pub fn active_room_id(&self) -> &RoomId { + self.active_room.id + } + pub fn room_id(&self) -> &RoomId { - self.room.id + self.origin_room.id } pub fn is_secure(&self) -> bool { - self.room.secure + self.origin_room.secure } } @@ -38,15 +43,21 @@ pub struct RoomContext<'a> { impl RoomContext<'_> { pub async fn new_with_name<'a>( room: &'a Joined, - display_name: String, sending_user: &str, ) -> Result, BotError> { - // TODO is_direct is a hack; should set rooms to Direct - // Message upon joining, if other contact has requested it. - // Waiting on SDK support. + // TODO is_direct is a hack; the bot should set eligible rooms + // to Direct Message upon joining, if other contact has + // requested it. Waiting on SDK support. + let display_name = room + .display_name() + .await + .ok() + .unwrap_or_default() + .to_string(); + let sending_user = UserId::try_from(sending_user)?; let user_in_room = room.get_member(&sending_user).await.ok().is_some(); - let is_direct = room.joined_members().await?.len() == 2; + let is_direct = room.active_members().await?.len() == 2; Ok(RoomContext { id: room.room_id(), @@ -57,17 +68,8 @@ impl RoomContext<'_> { pub async fn new<'a>( room: &'a Joined, - sending_user: &str, + sending_user: &'a str, ) -> Result, BotError> { - Self::new_with_name( - &room, - room.display_name() - .await - .ok() - .unwrap_or_default() - .to_string(), - sending_user, - ) - .await + Self::new_with_name(room, sending_user).await } } diff --git a/src/cthulhu/dice.rs b/src/cthulhu/dice.rs index 73ad1fb..aa856ce 100644 --- a/src/cthulhu/dice.rs +++ b/src/cthulhu/dice.rs @@ -380,7 +380,12 @@ async fn update_skill(ctx: &Context<'_>, variable: &str, value: u32) -> Result<( use std::convert::TryInto; let value: i32 = value.try_into()?; ctx.db - .set_user_variable(&ctx.username, &ctx.room_id().as_str(), variable, value) + .set_user_variable( + &ctx.username, + &ctx.active_room_id().as_str(), + variable, + value, + ) .await?; Ok(()) } @@ -506,8 +511,9 @@ mod tests { let ctx = Context { account: crate::models::Account::default(), db: db, - matrix_client: &matrix_sdk::Client::new(homeserver).unwrap(), - room: dummy_room!(), + matrix_client: matrix_sdk::Client::new(homeserver).unwrap(), + origin_room: dummy_room!(), + active_room: dummy_room!(), username: "username", message_body: "message", }; @@ -543,8 +549,9 @@ mod tests { let ctx = Context { account: crate::models::Account::default(), db: db, - matrix_client: &matrix_sdk::Client::new(homeserver).unwrap(), - room: dummy_room!(), + matrix_client: matrix_sdk::Client::new(homeserver).unwrap(), + origin_room: dummy_room!(), + active_room: dummy_room!(), username: "username", message_body: "message", }; @@ -580,8 +587,9 @@ mod tests { let ctx = Context { account: crate::models::Account::default(), db: db, - matrix_client: &matrix_sdk::Client::new(homeserver).unwrap(), - room: dummy_room!(), + matrix_client: matrix_sdk::Client::new(homeserver).unwrap(), + origin_room: dummy_room!(), + active_room: dummy_room!(), username: "username", message_body: "message", }; diff --git a/src/logic.rs b/src/logic.rs index 58dfd06..70fa796 100644 --- a/src/logic.rs +++ b/src/logic.rs @@ -27,7 +27,7 @@ pub async fn calculate_dice_amount(amounts: &[Amount], ctx: &Context<'_>) -> Res let stream = stream::iter(amounts); let variables = &ctx .db - .get_user_variables(&ctx.username, ctx.room_id().as_str()) + .get_user_variables(&ctx.username, ctx.active_room_id().as_str()) .await?; use DiceRollingError::VariableNotFound;