diff --git a/src/cofd/dice.rs b/src/cofd/dice.rs index 960dbb3..b522569 100644 --- a/src/cofd/dice.rs +++ b/src/cofd/dice.rs @@ -1,7 +1,8 @@ use crate::context::Context; use crate::db::DataError::KeyDoesNotExist; use crate::error::BotError; -use crate::roll::{Roll, Rolled}; +use crate::roll::Rolled; +use futures::stream::{self, StreamExt, TryStreamExt}; use itertools::Itertools; use std::convert::TryFrom; use std::fmt; @@ -119,17 +120,18 @@ pub struct DicePool { pub(crate) modifiers: DicePoolModifiers, } -fn calculate_dice_amount(pool: &DicePoolWithContext) -> Result { - let dice_amount: Result = pool - .0 - .amounts - .iter() - .map(|amount| match &amount.element { - Element::Number(num_dice) => Ok(*num_dice * amount.operator.mult()), - Element::Variable(variable) => handle_variable(&pool.1, &variable), +async fn calculate_dice_amount<'a>(pool: &'a DicePoolWithContext<'a>) -> Result { + let stream = stream::iter(&pool.0.amounts); + + let dice_amount: Result = stream + .then(|amount| async move { + match &amount.element { + Element::Number(num_dice) => Ok(*num_dice * amount.operator.mult()), + Element::Variable(variable) => handle_variable(&pool.1, &variable).await, + } }) - .collect::, _>>() - .map(|numbers| numbers.iter().sum()); + .try_fold(0, |total, num_dice| async move { Ok(total + num_dice) }) + .await; dice_amount } @@ -257,14 +259,6 @@ impl DicePoolRoll { /// Attach a Context to a dice pool. Needed for database access. pub struct DicePoolWithContext<'a>(pub &'a DicePool, pub &'a Context); -impl Roll for DicePoolWithContext<'_> { - type Output = Result; - - fn roll(&self) -> Result { - roll_dice(self, &mut RngDieRoller(rand::thread_rng())) - } -} - impl Rolled for DicePoolRoll { fn rolled_value(&self) -> i32 { self.successes() @@ -358,29 +352,32 @@ fn roll_die(roller: &mut R, pool: &DicePool) -> Vec { results } -fn handle_variable(ctx: &Context, variable: &str) -> Result { +async fn handle_variable(ctx: &Context, variable: &str) -> Result { ctx.db .get_user_variable(&ctx.room_id, &ctx.username, variable) + .await .map_err(|e| match e { KeyDoesNotExist(_) => DiceRollingError::VariableNotFound(variable.to_owned()).into(), _ => e.into(), }) } +fn roll_dice<'a, R: DieRoller>(pool: &DicePool, num_dice: i32, roller: &mut R) -> Vec { + (0..num_dice) + .flat_map(|_| roll_die(roller, &pool)) + .collect() +} + ///Roll the dice in a dice pool, according to behavior documented in the various rolling ///methods. -fn roll_dice( - pool: &DicePoolWithContext, - roller: &mut R, -) -> Result { +pub async fn roll_pool(pool: &DicePoolWithContext<'_>) -> Result { if pool.0.amounts.len() > 100 { return Err(DiceRollingError::ExpressionTooLarge.into()); } - let num_dice = calculate_dice_amount(&pool)?; - let rolls: Vec = (0..num_dice) - .flat_map(|_| roll_die(roller, &pool.0)) - .collect(); + let num_dice = calculate_dice_amount(&pool).await?; + let mut roller = RngDieRoller(rand::thread_rng()); + let rolls = roll_dice(&pool.0, num_dice, &mut roller); Ok(RolledDicePool::from(&pool.0, num_dice, rolls)) } @@ -510,36 +507,23 @@ mod tests { #[test] pub fn no_explode_roll_test() { - let db = Database::new(&sled::open(tempdir().unwrap()).unwrap()); - let ctx = Context::new(&db, "roomid", "username", "message"); let pool = DicePool::easy_pool(1, DicePoolQuality::NoExplode); - let pool_with_ctx = DicePoolWithContext(&pool, &ctx); - let mut roller = SequentialDieRoller::new(vec![10, 8]); - let result = roll_dice(&pool_with_ctx, &mut roller); - assert!(result.is_ok()); - - let roll = result.unwrap().roll; - assert_eq!(vec![10], roll.rolls()); + let roll = roll_dice(&pool, 1, &mut roller); + assert_eq!(vec![10], roll); } - #[test] - pub fn number_of_dice_equality_test() { - let db = Database::new(&sled::open(tempdir().unwrap()).unwrap()); - let ctx = Context::new(&db, "roomid", "username", "message"); - let pool = DicePool::easy_pool(5, DicePoolQuality::NoExplode); - let pool_with_ctx = DicePoolWithContext(&pool, &ctx); - - let mut roller = SequentialDieRoller::new(vec![1, 2, 3, 4, 5]); - let result = roll_dice(&pool_with_ctx, &mut roller); - assert!(result.is_ok()); - - let roll = result.unwrap(); - assert_eq!(5, roll.num_dice); + #[tokio::test] + async fn number_of_dice_equality_test() { + let num_dice = 5; + let rolls = vec![1, 2, 3, 4, 5]; + let pool = DicePool::easy_pool(5, DicePoolQuality::TenAgain); + let rolled_pool = RolledDicePool::from(&pool, num_dice, rolls); + assert_eq!(5, rolled_pool.num_dice); } - #[test] - fn rejects_large_expression_test() { + #[tokio::test] + async fn rejects_large_expression_test() { let db = Database::new(&sled::open(tempdir().unwrap()).unwrap()); let ctx = Context::new(&db, "roomid", "username", "message"); @@ -554,9 +538,7 @@ mod tests { let pool = DicePool::new(amounts, DicePoolModifiers::default()); let pool_with_ctx = DicePoolWithContext(&pool, &ctx); - - let mut roller = SequentialDieRoller::new(vec![1, 2, 3, 4, 5]); - let result = roll_dice(&pool_with_ctx, &mut roller); + let result = roll_pool(&pool_with_ctx).await; assert!(matches!( result, Err(BotError::DiceRollingError( @@ -565,12 +547,13 @@ mod tests { )); } - #[test] - fn can_resolve_variables_test() { + #[tokio::test] + async fn can_resolve_variables_test() { let db = Database::new(&sled::open(tempdir().unwrap()).unwrap()); let ctx = Context::new(&db, "roomid", "username", "message"); db.set_user_variable(&ctx.room_id, &ctx.username, "myvariable", 10) + .await .expect("could not set myvariable to 10"); let amounts = vec![Amount { @@ -581,7 +564,7 @@ mod tests { let pool = DicePool::new(amounts, DicePoolModifiers::default()); let pool_with_ctx = DicePoolWithContext(&pool, &ctx); - assert_eq!(calculate_dice_amount(&pool_with_ctx).unwrap(), 10); + assert_eq!(calculate_dice_amount(&pool_with_ctx).await.unwrap(), 10); } //DicePool tests diff --git a/src/commands.rs b/src/commands.rs index 7edfe3d..ffe6d2c 100644 --- a/src/commands.rs +++ b/src/commands.rs @@ -1,4 +1,4 @@ -use crate::cofd::dice::{DicePool, DicePoolWithContext}; +use crate::cofd::dice::{roll_pool, DicePool, DicePoolWithContext}; use crate::context::Context; use crate::db::DataError; use crate::dice::ElementExpression; @@ -69,7 +69,7 @@ impl Command for PoolRollCommand { async fn execute(&self, ctx: &Context) -> Execution { let pool_with_ctx = DicePoolWithContext(&self.0, ctx); - let roll_result = pool_with_ctx.roll(); + let roll_result = roll_pool(&pool_with_ctx).await; let (plain, html) = match roll_result { Ok(rolled_pool) => { @@ -121,7 +121,11 @@ impl Command for GetVariableCommand { async fn execute(&self, ctx: &Context) -> Execution { let name = &self.0; - let value = match ctx.db.get_user_variable(&ctx.room_id, &ctx.username, name) { + let value = match ctx + .db + .get_user_variable(&ctx.room_id, &ctx.username, name) + .await + { Ok(num) => format!("{} = {}", name, num), Err(DataError::KeyDoesNotExist(_)) => format!("{} is not set", name), Err(e) => format!("error getting {}: {}", name, e), @@ -146,7 +150,8 @@ impl Command for SetVariableCommand { let value = self.1; let result = ctx .db - .set_user_variable(&ctx.room_id, &ctx.username, name, value); + .set_user_variable(&ctx.room_id, &ctx.username, name, value) + .await; let content = match result { Ok(_) => format!("{} = {}", name, value), @@ -172,6 +177,7 @@ impl Command for DeleteVariableCommand { let value = match ctx .db .delete_user_variable(&ctx.room_id, &ctx.username, name) + .await { Ok(()) => format!("{} now unset", name), Err(DataError::KeyDoesNotExist(_)) => format!("{} is not currently set", name), @@ -235,39 +241,39 @@ mod tests { #[test] fn chance_die_is_not_malformed() { - assert!(Command::parse("!chance").is_ok()); + assert!(parse("!chance").is_ok()); } #[test] fn roll_malformed_expression_test() { - assert!(Command::parse("!roll 1d20asdlfkj").is_err()); - assert!(Command::parse("!roll 1d20asdlfkj ").is_err()); + assert!(parse("!roll 1d20asdlfkj").is_err()); + assert!(parse("!roll 1d20asdlfkj ").is_err()); } #[test] fn roll_dice_pool_malformed_expression_test() { - assert!(Command::parse("!pool 8abc").is_err()); - assert!(Command::parse("!pool 8abc ").is_err()); + assert!(parse("!pool 8abc").is_err()); + assert!(parse("!pool 8abc ").is_err()); } #[test] fn pool_whitespace_test() { - Command::parse("!pool ns3:8 ").expect("was error"); - Command::parse(" !pool ns3:8").expect("was error"); - Command::parse(" !pool ns3:8 ").expect("was error"); + parse("!pool ns3:8 ").expect("was error"); + parse(" !pool ns3:8").expect("was error"); + parse(" !pool ns3:8 ").expect("was error"); } #[test] fn help_whitespace_test() { - Command::parse("!help stuff ").expect("was error"); - Command::parse(" !help stuff").expect("was error"); - Command::parse(" !help stuff ").expect("was error"); + parse("!help stuff ").expect("was error"); + parse(" !help stuff").expect("was error"); + parse(" !help stuff ").expect("was error"); } #[test] fn roll_whitespace_test() { - Command::parse("!roll 1d4 + 5d6 -3 ").expect("was error"); - Command::parse("!roll 1d4 + 5d6 -3 ").expect("was error"); - Command::parse(" !roll 1d4 + 5d6 -3 ").expect("was error"); + parse("!roll 1d4 + 5d6 -3 ").expect("was error"); + parse("!roll 1d4 + 5d6 -3 ").expect("was error"); + parse(" !roll 1d4 + 5d6 -3 ").expect("was error"); } } diff --git a/src/db.rs b/src/db.rs index f3fa959..d0a091b 100644 --- a/src/db.rs +++ b/src/db.rs @@ -39,7 +39,7 @@ impl Database { Database { db: db.clone() } } - pub fn get_user_variable( + pub async fn get_user_variable( &self, room_id: &str, username: &str, @@ -61,7 +61,7 @@ impl Database { } } - pub fn set_user_variable( + pub async fn set_user_variable( &self, room_id: &str, username: &str, @@ -74,7 +74,7 @@ impl Database { Ok(()) } - pub fn delete_user_variable( + pub async fn delete_user_variable( &self, room_id: &str, username: &str,