diff --git a/.gitignore b/.gitignore index cbd44e8..95c619a 100644 --- a/.gitignore +++ b/.gitignore @@ -6,3 +6,6 @@ cache *.tar *.tar.gz test-db/ + +# We store a disabled async test in this file +bigboy diff --git a/src/cofd/dice.rs b/src/cofd/dice.rs index b522569..008d45e 100644 --- a/src/cofd/dice.rs +++ b/src/cofd/dice.rs @@ -122,12 +122,24 @@ pub struct DicePool { async fn calculate_dice_amount<'a>(pool: &'a DicePoolWithContext<'a>) -> Result { let stream = stream::iter(&pool.0.amounts); + let variables = pool + .1 + .db + .get_user_variables(&pool.1.room_id, &pool.1.username) + .await?; + let variables = &variables; + + use DiceRollingError::VariableNotFound; let dice_amount: Result = stream .then(|amount| async move { match &amount.element { Element::Number(num_dice) => Ok(*num_dice * amount.operator.mult()), - Element::Variable(variable) => handle_variable(&pool.1, &variable).await, + Element::Variable(variable) => variables + .get(variable) + .ok_or(VariableNotFound(variable.clone().to_string())) + .map(|i| *i) + .map_err(|e| e.into()), } }) .try_fold(0, |total, num_dice| async move { Ok(total + num_dice) }) @@ -352,16 +364,6 @@ fn roll_die(roller: &mut R, pool: &DicePool) -> Vec { results } -async fn handle_variable(ctx: &Context, variable: &str) -> Result { - ctx.db - .get_user_variable(&ctx.room_id, &ctx.username, variable) - .await - .map_err(|e| match e { - KeyDoesNotExist(_) => DiceRollingError::VariableNotFound(variable.to_owned()).into(), - _ => e.into(), - }) -} - fn roll_dice<'a, R: DieRoller>(pool: &DicePool, num_dice: i32, roller: &mut R) -> Vec { (0..num_dice) .flat_map(|_| roll_die(roller, &pool)) diff --git a/src/db.rs b/src/db.rs index d0a091b..9d9f687 100644 --- a/src/db.rs +++ b/src/db.rs @@ -1,5 +1,6 @@ use byteorder::LittleEndian; use sled::{Db, IVec}; +use std::collections::HashMap; use thiserror::Error; use zerocopy::byteorder::I32; use zerocopy::{AsBytes, LayoutVerified}; @@ -14,13 +15,18 @@ pub struct Database { db: Db, } +//TODO better combining of key and value in certain errors (namely +//I32SchemaViolation). #[derive(Error, Debug)] pub enum DataError { #[error("value does not exist for key: {0}")] KeyDoesNotExist(String), - #[error("key violates expected schema: {0}")] - SchemaViolation(String), + #[error("expected i32, but i32 schema was violated")] + I32SchemaViolation, + + #[error("expected string, but utf8 schema was violated: {0}")] + Utf8chemaViolation(#[from] std::str::Utf8Error), #[error("internal database error: {0}")] InternalError(#[from] sled::Error), @@ -34,11 +40,56 @@ fn to_key(room_id: &str, username: &str, variable_name: &str) -> Vec { key } +fn to_prefix(room_id: &str, username: &str) -> Vec { + let mut prefix = vec![]; + prefix.extend_from_slice(room_id.as_bytes()); + prefix.extend_from_slice(username.as_bytes()); + prefix +} + +fn convert(raw_value: &[u8]) -> Result { + let layout = LittleEndianI32Layout::new_unaligned(raw_value.as_ref()); + + if let Some(layout) = layout { + let value: I32 = *layout; + Ok(value.get()) + } else { + Err(DataError::I32SchemaViolation) + } +} + impl Database { pub fn new(db: &Db) -> Database { Database { db: db.clone() } } + pub async fn get_user_variables( + &self, + room_id: &str, + username: &str, + ) -> Result, DataError> { + let prefix = to_prefix(&room_id, &username); + let prefix_len: usize = prefix.len(); + + let variables: Result, DataError> = self + .db + .scan_prefix(prefix) + .map(|entry| match entry { + Ok((key, raw_value)) => { + //Strips room and username from key, leaving + //behind name. + let variable_name = std::str::from_utf8(&key[prefix_len..])?; + Ok((variable_name.to_owned(), convert(&raw_value)?)) + } + Err(e) => Err(e.into()), + }) + .collect(); + + //Convert to hash map. Can we do this in the first mapping + //step instead? + variables.map(|entries| entries.into_iter().collect()) + } + pub async fn get_user_variable( &self, room_id: &str, @@ -48,14 +99,7 @@ impl Database { let key = to_key(room_id, username, variable_name); if let Some(raw_value) = self.db.get(&key)? { - let layout = LittleEndianI32Layout::new_unaligned(raw_value.as_ref()); - - if let Some(layout) = layout { - let value: I32 = *layout; - Ok(value.get()) - } else { - Err(DataError::SchemaViolation(String::from_utf8(key).unwrap())) - } + convert(&raw_value) } else { Err(DataError::KeyDoesNotExist(String::from_utf8(key).unwrap())) }