From 0ae92d311febfcbb02d0a7dad1a99fd3cee31934 Mon Sep 17 00:00:00 2001
From: projectmoon <projectmoon@agnos.is>
Date: Sat, 17 Oct 2020 20:24:24 +0000
Subject: [PATCH] Fully async dice rolling. Also remove more unnecessary stuff.

---
 src/cofd/dice.rs | 99 ++++++++++++++++++++----------------------------
 src/commands.rs  | 42 +++++++++++---------
 src/db.rs        |  6 +--
 3 files changed, 68 insertions(+), 79 deletions(-)

diff --git a/src/cofd/dice.rs b/src/cofd/dice.rs
index 960dbb3..b522569 100644
--- a/src/cofd/dice.rs
+++ b/src/cofd/dice.rs
@@ -1,7 +1,8 @@
 use crate::context::Context;
 use crate::db::DataError::KeyDoesNotExist;
 use crate::error::BotError;
-use crate::roll::{Roll, Rolled};
+use crate::roll::Rolled;
+use futures::stream::{self, StreamExt, TryStreamExt};
 use itertools::Itertools;
 use std::convert::TryFrom;
 use std::fmt;
@@ -119,17 +120,18 @@ pub struct DicePool {
     pub(crate) modifiers: DicePoolModifiers,
 }
 
-fn calculate_dice_amount(pool: &DicePoolWithContext) -> Result<i32, BotError> {
-    let dice_amount: Result<i32, BotError> = pool
-        .0
-        .amounts
-        .iter()
-        .map(|amount| match &amount.element {
-            Element::Number(num_dice) => Ok(*num_dice * amount.operator.mult()),
-            Element::Variable(variable) => handle_variable(&pool.1, &variable),
+async fn calculate_dice_amount<'a>(pool: &'a DicePoolWithContext<'a>) -> Result<i32, BotError> {
+    let stream = stream::iter(&pool.0.amounts);
+
+    let dice_amount: Result<i32, BotError> = stream
+        .then(|amount| async move {
+            match &amount.element {
+                Element::Number(num_dice) => Ok(*num_dice * amount.operator.mult()),
+                Element::Variable(variable) => handle_variable(&pool.1, &variable).await,
+            }
         })
-        .collect::<Result<Vec<i32>, _>>()
-        .map(|numbers| numbers.iter().sum());
+        .try_fold(0, |total, num_dice| async move { Ok(total + num_dice) })
+        .await;
 
     dice_amount
 }
@@ -257,14 +259,6 @@ impl DicePoolRoll {
 /// Attach a Context to a dice pool. Needed for database access.
 pub struct DicePoolWithContext<'a>(pub &'a DicePool, pub &'a Context);
 
-impl Roll for DicePoolWithContext<'_> {
-    type Output = Result<RolledDicePool, BotError>;
-
-    fn roll(&self) -> Result<RolledDicePool, BotError> {
-        roll_dice(self, &mut RngDieRoller(rand::thread_rng()))
-    }
-}
-
 impl Rolled for DicePoolRoll {
     fn rolled_value(&self) -> i32 {
         self.successes()
@@ -358,29 +352,32 @@ fn roll_die<R: DieRoller>(roller: &mut R, pool: &DicePool) -> Vec<i32> {
     results
 }
 
-fn handle_variable(ctx: &Context, variable: &str) -> Result<i32, BotError> {
+async fn handle_variable(ctx: &Context, variable: &str) -> Result<i32, BotError> {
     ctx.db
         .get_user_variable(&ctx.room_id, &ctx.username, variable)
+        .await
         .map_err(|e| match e {
             KeyDoesNotExist(_) => DiceRollingError::VariableNotFound(variable.to_owned()).into(),
             _ => e.into(),
         })
 }
 
+fn roll_dice<'a, R: DieRoller>(pool: &DicePool, num_dice: i32, roller: &mut R) -> Vec<i32> {
+    (0..num_dice)
+        .flat_map(|_| roll_die(roller, &pool))
+        .collect()
+}
+
 ///Roll the dice in a dice pool, according to behavior documented in the various rolling
 ///methods.
-fn roll_dice<R: DieRoller>(
-    pool: &DicePoolWithContext,
-    roller: &mut R,
-) -> Result<RolledDicePool, BotError> {
+pub async fn roll_pool(pool: &DicePoolWithContext<'_>) -> Result<RolledDicePool, BotError> {
     if pool.0.amounts.len() > 100 {
         return Err(DiceRollingError::ExpressionTooLarge.into());
     }
 
-    let num_dice = calculate_dice_amount(&pool)?;
-    let rolls: Vec<i32> = (0..num_dice)
-        .flat_map(|_| roll_die(roller, &pool.0))
-        .collect();
+    let num_dice = calculate_dice_amount(&pool).await?;
+    let mut roller = RngDieRoller(rand::thread_rng());
+    let rolls = roll_dice(&pool.0, num_dice, &mut roller);
     Ok(RolledDicePool::from(&pool.0, num_dice, rolls))
 }
 
@@ -510,36 +507,23 @@ mod tests {
 
     #[test]
     pub fn no_explode_roll_test() {
-        let db = Database::new(&sled::open(tempdir().unwrap()).unwrap());
-        let ctx = Context::new(&db, "roomid", "username", "message");
         let pool = DicePool::easy_pool(1, DicePoolQuality::NoExplode);
-        let pool_with_ctx = DicePoolWithContext(&pool, &ctx);
-
         let mut roller = SequentialDieRoller::new(vec![10, 8]);
-        let result = roll_dice(&pool_with_ctx, &mut roller);
-        assert!(result.is_ok());
-
-        let roll = result.unwrap().roll;
-        assert_eq!(vec![10], roll.rolls());
+        let roll = roll_dice(&pool, 1, &mut roller);
+        assert_eq!(vec![10], roll);
     }
 
-    #[test]
-    pub fn number_of_dice_equality_test() {
-        let db = Database::new(&sled::open(tempdir().unwrap()).unwrap());
-        let ctx = Context::new(&db, "roomid", "username", "message");
-        let pool = DicePool::easy_pool(5, DicePoolQuality::NoExplode);
-        let pool_with_ctx = DicePoolWithContext(&pool, &ctx);
-
-        let mut roller = SequentialDieRoller::new(vec![1, 2, 3, 4, 5]);
-        let result = roll_dice(&pool_with_ctx, &mut roller);
-        assert!(result.is_ok());
-
-        let roll = result.unwrap();
-        assert_eq!(5, roll.num_dice);
+    #[tokio::test]
+    async fn number_of_dice_equality_test() {
+        let num_dice = 5;
+        let rolls = vec![1, 2, 3, 4, 5];
+        let pool = DicePool::easy_pool(5, DicePoolQuality::TenAgain);
+        let rolled_pool = RolledDicePool::from(&pool, num_dice, rolls);
+        assert_eq!(5, rolled_pool.num_dice);
     }
 
-    #[test]
-    fn rejects_large_expression_test() {
+    #[tokio::test]
+    async fn rejects_large_expression_test() {
         let db = Database::new(&sled::open(tempdir().unwrap()).unwrap());
         let ctx = Context::new(&db, "roomid", "username", "message");
 
@@ -554,9 +538,7 @@ mod tests {
 
         let pool = DicePool::new(amounts, DicePoolModifiers::default());
         let pool_with_ctx = DicePoolWithContext(&pool, &ctx);
-
-        let mut roller = SequentialDieRoller::new(vec![1, 2, 3, 4, 5]);
-        let result = roll_dice(&pool_with_ctx, &mut roller);
+        let result = roll_pool(&pool_with_ctx).await;
         assert!(matches!(
             result,
             Err(BotError::DiceRollingError(
@@ -565,12 +547,13 @@ mod tests {
         ));
     }
 
-    #[test]
-    fn can_resolve_variables_test() {
+    #[tokio::test]
+    async fn can_resolve_variables_test() {
         let db = Database::new(&sled::open(tempdir().unwrap()).unwrap());
         let ctx = Context::new(&db, "roomid", "username", "message");
 
         db.set_user_variable(&ctx.room_id, &ctx.username, "myvariable", 10)
+            .await
             .expect("could not set myvariable to 10");
 
         let amounts = vec![Amount {
@@ -581,7 +564,7 @@ mod tests {
         let pool = DicePool::new(amounts, DicePoolModifiers::default());
         let pool_with_ctx = DicePoolWithContext(&pool, &ctx);
 
-        assert_eq!(calculate_dice_amount(&pool_with_ctx).unwrap(), 10);
+        assert_eq!(calculate_dice_amount(&pool_with_ctx).await.unwrap(), 10);
     }
 
     //DicePool tests
diff --git a/src/commands.rs b/src/commands.rs
index 7edfe3d..ffe6d2c 100644
--- a/src/commands.rs
+++ b/src/commands.rs
@@ -1,4 +1,4 @@
-use crate::cofd::dice::{DicePool, DicePoolWithContext};
+use crate::cofd::dice::{roll_pool, DicePool, DicePoolWithContext};
 use crate::context::Context;
 use crate::db::DataError;
 use crate::dice::ElementExpression;
@@ -69,7 +69,7 @@ impl Command for PoolRollCommand {
 
     async fn execute(&self, ctx: &Context) -> Execution {
         let pool_with_ctx = DicePoolWithContext(&self.0, ctx);
-        let roll_result = pool_with_ctx.roll();
+        let roll_result = roll_pool(&pool_with_ctx).await;
 
         let (plain, html) = match roll_result {
             Ok(rolled_pool) => {
@@ -121,7 +121,11 @@ impl Command for GetVariableCommand {
 
     async fn execute(&self, ctx: &Context) -> Execution {
         let name = &self.0;
-        let value = match ctx.db.get_user_variable(&ctx.room_id, &ctx.username, name) {
+        let value = match ctx
+            .db
+            .get_user_variable(&ctx.room_id, &ctx.username, name)
+            .await
+        {
             Ok(num) => format!("{} = {}", name, num),
             Err(DataError::KeyDoesNotExist(_)) => format!("{} is not set", name),
             Err(e) => format!("error getting {}: {}", name, e),
@@ -146,7 +150,8 @@ impl Command for SetVariableCommand {
         let value = self.1;
         let result = ctx
             .db
-            .set_user_variable(&ctx.room_id, &ctx.username, name, value);
+            .set_user_variable(&ctx.room_id, &ctx.username, name, value)
+            .await;
 
         let content = match result {
             Ok(_) => format!("{} = {}", name, value),
@@ -172,6 +177,7 @@ impl Command for DeleteVariableCommand {
         let value = match ctx
             .db
             .delete_user_variable(&ctx.room_id, &ctx.username, name)
+            .await
         {
             Ok(()) => format!("{} now unset", name),
             Err(DataError::KeyDoesNotExist(_)) => format!("{} is not currently set", name),
@@ -235,39 +241,39 @@ mod tests {
 
     #[test]
     fn chance_die_is_not_malformed() {
-        assert!(Command::parse("!chance").is_ok());
+        assert!(parse("!chance").is_ok());
     }
 
     #[test]
     fn roll_malformed_expression_test() {
-        assert!(Command::parse("!roll 1d20asdlfkj").is_err());
-        assert!(Command::parse("!roll 1d20asdlfkj   ").is_err());
+        assert!(parse("!roll 1d20asdlfkj").is_err());
+        assert!(parse("!roll 1d20asdlfkj   ").is_err());
     }
 
     #[test]
     fn roll_dice_pool_malformed_expression_test() {
-        assert!(Command::parse("!pool 8abc").is_err());
-        assert!(Command::parse("!pool 8abc    ").is_err());
+        assert!(parse("!pool 8abc").is_err());
+        assert!(parse("!pool 8abc    ").is_err());
     }
 
     #[test]
     fn pool_whitespace_test() {
-        Command::parse("!pool ns3:8   ").expect("was error");
-        Command::parse("   !pool ns3:8").expect("was error");
-        Command::parse("   !pool ns3:8   ").expect("was error");
+        parse("!pool ns3:8   ").expect("was error");
+        parse("   !pool ns3:8").expect("was error");
+        parse("   !pool ns3:8   ").expect("was error");
     }
 
     #[test]
     fn help_whitespace_test() {
-        Command::parse("!help stuff   ").expect("was error");
-        Command::parse("   !help stuff").expect("was error");
-        Command::parse("   !help stuff   ").expect("was error");
+        parse("!help stuff   ").expect("was error");
+        parse("   !help stuff").expect("was error");
+        parse("   !help stuff   ").expect("was error");
     }
 
     #[test]
     fn roll_whitespace_test() {
-        Command::parse("!roll 1d4 + 5d6 -3   ").expect("was error");
-        Command::parse("!roll 1d4 + 5d6 -3   ").expect("was error");
-        Command::parse("   !roll 1d4 + 5d6 -3   ").expect("was error");
+        parse("!roll 1d4 + 5d6 -3   ").expect("was error");
+        parse("!roll 1d4 + 5d6 -3   ").expect("was error");
+        parse("   !roll 1d4 + 5d6 -3   ").expect("was error");
     }
 }
diff --git a/src/db.rs b/src/db.rs
index f3fa959..d0a091b 100644
--- a/src/db.rs
+++ b/src/db.rs
@@ -39,7 +39,7 @@ impl Database {
         Database { db: db.clone() }
     }
 
-    pub fn get_user_variable(
+    pub async fn get_user_variable(
         &self,
         room_id: &str,
         username: &str,
@@ -61,7 +61,7 @@ impl Database {
         }
     }
 
-    pub fn set_user_variable(
+    pub async fn set_user_variable(
         &self,
         room_id: &str,
         username: &str,
@@ -74,7 +74,7 @@ impl Database {
         Ok(())
     }
 
-    pub fn delete_user_variable(
+    pub async fn delete_user_variable(
         &self,
         room_id: &str,
         username: &str,