Compare commits

..

No commits in common. "402f236ba79cf5582635cf2ca5a5e64f061be85b" and "a33367fadac3f4525fb67f6fd1f6c5bd7aedb6fd" have entirely different histories.

29 changed files with 1928 additions and 105 deletions

41
Cargo.lock generated
View File

@ -180,6 +180,15 @@ version = "0.13.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "904dfeac50f3cdaba28fc6f57fdcddb75f49ed61346676a78c4ffe55877802fd" checksum = "904dfeac50f3cdaba28fc6f57fdcddb75f49ed61346676a78c4ffe55877802fd"
[[package]]
name = "bincode"
version = "1.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b1f45e9417d87227c7a56d22e471c6206462cba514c7590c09aff4cf6d1ddcad"
dependencies = [
"serde",
]
[[package]] [[package]]
name = "bitflags" name = "bitflags"
version = "1.2.1" version = "1.2.1"
@ -1194,6 +1203,12 @@ version = "2.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b16bd47d9e329435e309c58469fe0791c2d0d1ba96ec0954152a5ae2b04387dc" checksum = "b16bd47d9e329435e309c58469fe0791c2d0d1ba96ec0954152a5ae2b04387dc"
[[package]]
name = "memmem"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a64a92489e2744ce060c349162be1c5f33c6969234104dbd99ddb5feb08b8c15"
[[package]] [[package]]
name = "memoffset" name = "memoffset"
version = "0.6.3" version = "0.6.3"
@ -2498,6 +2513,8 @@ version = "0.10.0"
dependencies = [ dependencies = [
"async-trait", "async-trait",
"barrel", "barrel",
"bincode",
"byteorder",
"combine", "combine",
"dirs", "dirs",
"futures", "futures",
@ -2506,11 +2523,13 @@ dependencies = [
"itertools", "itertools",
"log", "log",
"matrix-sdk", "matrix-sdk",
"memmem",
"nom 5.1.2", "nom 5.1.2",
"phf", "phf",
"rand 0.8.3", "rand 0.8.3",
"refinery", "refinery",
"serde", "serde",
"sled",
"sqlx", "sqlx",
"tempfile", "tempfile",
"thiserror", "thiserror",
@ -2518,6 +2537,7 @@ dependencies = [
"toml", "toml",
"tracing-subscriber", "tracing-subscriber",
"url", "url",
"zerocopy",
] ]
[[package]] [[package]]
@ -3073,6 +3093,27 @@ dependencies = [
"time 0.1.43", "time 0.1.43",
] ]
[[package]]
name = "zerocopy"
version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5e59ec1d2457bd6c0dd89b50e7d9d6b0b647809bf3f0a59ac85557046950b7b2"
dependencies = [
"byteorder",
"zerocopy-derive",
]
[[package]]
name = "zerocopy-derive"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0af017aca1fa6181f5dd7a802456fe6f7666ecdcc18d0910431f0fc89d474e51"
dependencies = [
"proc-macro2",
"syn",
"synstructure",
]
[[package]] [[package]]
name = "zeroize" name = "zeroize"
version = "1.3.0" version = "1.3.0"

View File

@ -27,7 +27,12 @@ url = "2.1"
dirs = "3.0" dirs = "3.0"
indoc = "1.0" indoc = "1.0"
combine = "4.5" combine = "4.5"
sled = "0.34"
zerocopy = "0.5"
byteorder = "1.3"
futures = "0.3" futures = "0.3"
memmem = "0.1"
bincode = "1.3"
html2text = "0.2" html2text = "0.2"
phf = { version = "0.8", features = ["macros"] } phf = { version = "0.8", features = ["macros"] }
matrix-sdk = { git = "https://github.com/matrix-org/matrix-rust-sdk", branch = "master" } matrix-sdk = { git = "https://github.com/matrix-org/matrix-rust-sdk", branch = "master" }

37
src/bin/migrate_sled.rs Normal file
View File

@ -0,0 +1,37 @@
use tenebrous_dicebot::db::sqlite::{Database as SqliteDatabase, Variables};
use tenebrous_dicebot::db::Database;
use tenebrous_dicebot::error::BotError;
#[tokio::main]
async fn main() -> Result<(), BotError> {
let sled_path = std::env::args()
.skip(1)
.next()
.expect("Need a path to a Sled database as an arument.");
let sqlite_path = std::env::args()
.skip(2)
.next()
.expect("Need a path to an sqlite database as an arument.");
let db = Database::new(&sled_path)?;
let all_variables = db.variables.get_all_variables()?;
let sql_db = SqliteDatabase::new(&sqlite_path).await?;
for var in all_variables {
if let ((username, room_id, variable_name), value) = var {
println!(
"Migrating {}::{}::{} = {} to sql",
username, room_id, variable_name, value
);
sql_db
.set_user_variable(&username, &room_id, &variable_name, value)
.await;
}
}
Ok(())
}

View File

@ -2,7 +2,7 @@ use crate::commands::{execute_command, ExecutionError, ExecutionResult, Response
use crate::config::*; use crate::config::*;
use crate::context::{Context, RoomContext}; use crate::context::{Context, RoomContext};
use crate::db::sqlite::Database; use crate::db::sqlite::Database;
use crate::db::DbState; use crate::db::sqlite::DbState;
use crate::error::BotError; use crate::error::BotError;
use crate::matrix; use crate::matrix;
use crate::state::DiceBotState; use crate::state::DiceBotState;

View File

@ -5,7 +5,7 @@
*/ */
use super::DiceBot; use super::DiceBot;
use crate::db::sqlite::Database; use crate::db::sqlite::Database;
use crate::db::Rooms; use crate::db::sqlite::Rooms;
use crate::error::BotError; use crate::error::BotError;
use crate::logic::record_room_information; use crate::logic::record_room_information;
use async_trait::async_trait; use async_trait::async_trait;

View File

@ -326,7 +326,7 @@ pub async fn roll_pool(pool: &DicePoolWithContext<'_>) -> Result<RolledDicePool,
mod tests { mod tests {
use super::*; use super::*;
use crate::db::sqlite::Database; use crate::db::sqlite::Database;
use crate::db::Variables; use crate::db::sqlite::Variables;
use url::Url; use url::Url;
macro_rules! dummy_room { macro_rules! dummy_room {

View File

@ -2,7 +2,7 @@ use crate::context::Context;
use crate::error::BotError; use crate::error::BotError;
use async_trait::async_trait; use async_trait::async_trait;
use thiserror::Error; use thiserror::Error;
use BotError::DataError; use BotError::{DataError, SqliteDataError};
pub mod basic_rolling; pub mod basic_rolling;
pub mod cofd; pub mod cofd;
@ -55,6 +55,12 @@ impl From<crate::db::errors::DataError> for ExecutionError {
} }
} }
/// Lets sqlite-layer data errors be propagated with `?` from command
/// code: the error is wrapped in `BotError::SqliteDataError` and then
/// in `ExecutionError`.
impl From<crate::db::sqlite::errors::DataError> for ExecutionError {
    fn from(error: crate::db::sqlite::errors::DataError) -> Self {
        Self(SqliteDataError(error))
    }
}
impl ExecutionError { impl ExecutionError {
/// Error message in bolded HTML. /// Error message in bolded HTML.
pub fn html(&self) -> String { pub fn html(&self) -> String {

View File

@ -1,7 +1,8 @@
use super::{Command, Execution, ExecutionResult}; use super::{Command, Execution, ExecutionResult};
use crate::context::Context; use crate::context::Context;
use crate::db::errors::DataError; use crate::db::sqlite::errors::DataError;
use crate::db::Variables; use crate::db::sqlite::Variables;
use crate::db::variables::UserAndRoom;
use async_trait::async_trait; use async_trait::async_trait;
pub struct GetAllVariablesCommand; pub struct GetAllVariablesCommand;

View File

@ -1,7 +1,7 @@
use crate::context::Context; use crate::db::sqlite::Variables;
use crate::db::Variables;
use crate::error::{BotError, DiceRollingError}; use crate::error::{BotError, DiceRollingError};
use crate::parser::{Amount, Element}; use crate::parser::{Amount, Element};
use crate::{context::Context, db::variables::UserAndRoom};
use crate::{dice::calculate_single_die_amount, parser::DiceParsingError}; use crate::{dice::calculate_single_die_amount, parser::DiceParsingError};
use rand::rngs::StdRng; use rand::rngs::StdRng;
use rand::Rng; use rand::Rng;

98
src/db.rs Normal file
View File

@ -0,0 +1,98 @@
use crate::db::errors::{DataError, MigrationError};
use crate::db::migrations::{get_migration_version, Migrations};
use crate::db::rooms::Rooms;
use crate::db::state::DbState;
use crate::db::variables::Variables;
use log::info;
use sled::{Config, Db};
use std::path::Path;
pub mod data_migrations;
pub mod errors;
pub mod migrations;
pub mod rooms;
pub mod schema;
pub mod sqlite;
pub mod state;
pub mod variables;
/// Handle to the sled database together with the typed views that are
/// built from it when the database is opened.
#[derive(Clone)]
pub struct Database {
    /// The underlying sled database handle.
    db: Db,
    /// User variable storage.
    pub variables: Variables,
    /// Migration metadata (wraps the "migrations" tree).
    pub(crate) migrations: Migrations,
    /// Room membership, room info and processed-event tracking.
    pub(crate) rooms: Rooms,
    /// Bot state storage.
    pub(crate) state: DbState,
}
impl Database {
    /// Build a `Database` from an already opened sled handle, opening
    /// every tree the bot needs and starting the room event handler.
    fn new_db(db: sled::Db) -> Result<Database, DataError> {
        let migrations = Migrations(db.open_tree("migrations")?);
        let variables = Variables::new(&db)?;
        let rooms = Rooms::new(&db)?;
        let state = DbState::new(&db)?;

        //Start any event handlers.
        rooms.start_handler();

        info!("Opened database.");
        Ok(Database {
            db,
            variables,
            migrations,
            rooms,
            state,
        })
    }

    /// Open (or create) the sled database stored at `path`.
    pub fn new<P: AsRef<Path>>(path: P) -> Result<Database, DataError> {
        Self::new_db(sled::open(path)?)
    }

    /// Open a temporary sled database.
    pub fn new_temp() -> Result<Database, DataError> {
        let temp_db = Config::new().temporary(true).open()?;
        Self::new_db(temp_db)
    }

    /// Bring the stored data up to `to_version` by running every
    /// registered data migration between the current version and the
    /// target, in order. The stored version is bumped after each
    /// successful migration.
    ///
    /// # Errors
    /// Returns `MigrationError::CannotDowngrade` (as a `DataError`) if
    /// the database is already at a newer version, and forwards any
    /// error raised while looking up or running a migration.
    pub fn migrate(&self, to_version: u32) -> Result<(), DataError> {
        let db_version = get_migration_version(self)?;

        // Downgrades are not supported.
        if db_version > to_version {
            return Err(MigrationError::CannotDowngrade.into());
        }

        // Already at the requested version: nothing to do.
        if db_version == to_version {
            info!("No database migrations needed.");
            return Ok(());
        }

        info!(
            "Migrating database from version {} to version {}",
            db_version, to_version
        );

        // Every version after the current one, up to and including the target.
        let versions_to_run: Vec<u32> = ((db_version + 1)..=to_version).collect();
        let migrations = data_migrations::get_migrations(&versions_to_run)?;

        for (version, (migration_func, name)) in versions_to_run.iter().zip(migrations) {
            //This needs to be transactional on migrations
            //keyspace. abort on migration func error.
            info!("Applying migration {} :: {}", version, name);
            migration_func(self)?;
            self.migrations.set_migration_version(*version)?;
        }

        info!("Done applying migrations.");
        Ok(())
    }
}

28
src/db/data_migrations.rs Normal file
View File

@ -0,0 +1,28 @@
use crate::db::errors::{DataError, MigrationError};
use crate::db::variables::migrations::*;
use crate::db::Database;
use phf::phf_map;
/// A data migration: the function that performs it, paired with a
/// human-readable name.
pub(super) type DataMigration = (fn(&Database) -> Result<(), DataError>, &'static str);

/// Compile-time table of every known data migration, keyed by the
/// schema version that the migration produces.
static MIGRATIONS: phf::Map<u32, DataMigration> = phf_map! {
    1u32 => (add_room_user_variable_count::migrate, "add_room_user_variable_count"),
    2u32 => (delete_v0_schema, "delete_v0_schema"),
    3u32 => (delete_variable_count, "delete_variable_count"),
    4u32 => (change_delineator_delimiter::migrate, "change_delineator_delimiter"),
    5u32 => (change_tree_structure::migrate, "change_tree_structure"),
};
/// Look up the migration registered for each requested version,
/// preserving the order of `versions`.
///
/// # Errors
/// Returns `MigrationError::MigrationNotFound` for the first version
/// that has no registered migration.
pub fn get_migrations(versions: &[u32]) -> Result<Vec<DataMigration>, MigrationError> {
    versions
        .iter()
        .map(|version| {
            MIGRATIONS
                .get(version)
                .copied()
                .ok_or(MigrationError::MigrationNotFound(*version))
        })
        .collect()
}

View File

@ -1,6 +1,20 @@
use std::num::TryFromIntError; use sled::transaction::{TransactionError, UnabortableTransactionError};
use thiserror::Error; use thiserror::Error;
/// Errors that can occur while selecting or running data migrations.
#[derive(Error, Debug)]
pub enum MigrationError {
    /// The database is already at a newer version than the one requested.
    #[error("cannot downgrade to an older database version")]
    CannotDowngrade,
    /// No migration is registered for the given version number.
    #[error("migration for version {0} not defined")]
    MigrationNotFound(u32),
    /// A migration ran but failed; the payload is the failure description.
    #[error("migration failed: {0}")]
    MigrationFailed(String),
}
//TODO better combining of key and value in certain errors (namely
//I32SchemaViolation).
#[derive(Error, Debug)] #[derive(Error, Debug)]
pub enum DataError { pub enum DataError {
#[error("value does not exist for key: {0}")] #[error("value does not exist for key: {0}")]
@ -12,6 +26,9 @@ pub enum DataError {
#[error("expected i32, but i32 schema was violated")] #[error("expected i32, but i32 schema was violated")]
I32SchemaViolation, I32SchemaViolation,
#[error("parse error")]
ParseError(#[from] std::num::ParseIntError),
#[error("unexpected or corruptd data bytes")] #[error("unexpected or corruptd data bytes")]
InvalidValue, InvalidValue,
@ -21,12 +38,44 @@ pub enum DataError {
#[error("expected string, but utf8 schema was violated: {0}")] #[error("expected string, but utf8 schema was violated: {0}")]
Utf8SchemaViolation(#[from] std::string::FromUtf8Error), Utf8SchemaViolation(#[from] std::string::FromUtf8Error),
#[error("data migration error: {0}")]
MigrationError(#[from] crate::db::sqlite::migrator::MigrationError),
#[error("internal database error: {0}")] #[error("internal database error: {0}")]
SqlxError(#[from] sqlx::Error), InternalError(#[from] sled::Error),
#[error("numeric conversion error")] #[error("transaction error: {0}")]
NumericConversionError(#[from] TryFromIntError), TransactionError(#[from] sled::transaction::TransactionError),
#[error("unabortable transaction error: {0}")]
UnabortableTransactionError(#[from] UnabortableTransactionError),
#[error("data migration error: {0}")]
MigrationError(#[from] MigrationError),
#[error("deserialization error: {0}")]
DeserializationError(#[from] bincode::Error),
}
/// Flattens `sled::transaction::TransactionError<DataError>` into a
/// plain `DataError`. The enum only carries a (non-generic)
/// transaction error variant, so the recursive error type has to be
/// unwrapped by hand: an aborted transaction yields the `DataError`
/// that caused the abort, and a storage failure is re-wrapped in
/// `DataError::TransactionError`.
impl From<TransactionError<DataError>> for DataError {
    fn from(error: TransactionError<DataError>) -> Self {
        let storage_err = match error {
            // The abort payload is already the error we want.
            TransactionError::Abort(data_err) => return data_err,
            TransactionError::Storage(storage_err) => storage_err,
        };

        DataError::TransactionError(TransactionError::Storage(storage_err))
    }
}
/// Automatically aborts transactions that hit a DataError by using
/// the try (question mark) operator when this trait implementation is
/// in scope.
impl From<DataError> for sled::transaction::ConflictableTransactionError<DataError> {
fn from(error: DataError) -> Self {
sled::transaction::ConflictableTransactionError::Abort(error)
}
} }

54
src/db/migrations.rs Normal file
View File

@ -0,0 +1,54 @@
use crate::db::errors::DataError;
use crate::db::schema::convert_u32;
use crate::db::Database;
use byteorder::LittleEndian;
use sled::Tree;
use zerocopy::byteorder::U32;
use zerocopy::AsBytes;
//This file is for controlling the migration info stored in the
//database, not actually running migrations.
/// Newtype around the sled tree that stores migration metadata
/// (currently just the migration version).
#[derive(Clone)]
pub struct Migrations(pub(super) Tree);
/// Separator placed between a keyspace and a key name.
const COLON: &[u8] = b":";
/// Keyspace that holds database metadata.
const METADATA_SPACE: &str = "metadata";
/// Name of the metadata entry that stores the migration version.
const MIGRATION_KEY: &str = "migration_version";

/// Build the raw key `keyspace:key_name` as bytes.
fn to_key(keyspace: &str, key_name: &str) -> Vec<u8> {
    [keyspace.as_bytes(), COLON, key_name.as_bytes()].concat()
}

/// Build the raw key for `key_name` inside the metadata keyspace.
fn metadata_key(key_name: &str) -> Vec<u8> {
    to_key(METADATA_SPACE, key_name)
}
impl Migrations {
pub(super) fn set_migration_version(&self, version: u32) -> Result<(), DataError> {
//Rust cannot type infer this transaction
let result: Result<_, sled::transaction::TransactionError<DataError>> =
self.0.transaction(|tx| {
let key = metadata_key(MIGRATION_KEY);
let db_value: U32<LittleEndian> = U32::new(version);
tx.insert(key, db_value.as_bytes())?;
Ok(())
});
result?;
Ok(())
}
}
/// Read the migration version stored in the database. A database with
/// no stored version is reported as version 0.
pub(super) fn get_migration_version(db: &Database) -> Result<u32, DataError> {
    let stored = db.migrations.0.get(metadata_key(MIGRATION_KEY))?;
    stored.map_or(Ok(0), |bytes| convert_u32(&bytes))
}

View File

@ -1,69 +0,0 @@
use async_trait::async_trait;
use errors::DataError;
use std::collections::{HashMap, HashSet};
use crate::models::RoomInfo;
pub mod errors;
pub mod sqlite;
#[async_trait]
pub(crate) trait DbState {
async fn get_device_id(&self) -> Result<Option<String>, DataError>;
async fn set_device_id(&self, device_id: &str) -> Result<(), DataError>;
}
#[async_trait]
pub(crate) trait Rooms {
async fn should_process(&self, room_id: &str, event_id: &str) -> Result<bool, DataError>;
async fn insert_room_info(&self, info: &RoomInfo) -> Result<(), DataError>;
async fn get_room_info(&self, room_id: &str) -> Result<Option<RoomInfo>, DataError>;
async fn get_rooms_for_user(&self, user_id: &str) -> Result<HashSet<String>, DataError>;
async fn get_users_in_room(&self, room_id: &str) -> Result<HashSet<String>, DataError>;
async fn add_user_to_room(&self, username: &str, room_id: &str) -> Result<(), DataError>;
async fn remove_user_from_room(&self, username: &str, room_id: &str) -> Result<(), DataError>;
async fn clear_info(&self, room_id: &str) -> Result<(), DataError>;
}
// TODO move this up to the top once we delete sled. Traits will be the
// main API, then we can have different impls for different DBs.
#[async_trait]
pub trait Variables {
async fn get_user_variables(
&self,
user: &str,
room_id: &str,
) -> Result<HashMap<String, i32>, DataError>;
async fn get_variable_count(&self, user: &str, room_id: &str) -> Result<i32, DataError>;
async fn get_user_variable(
&self,
user: &str,
room_id: &str,
variable_name: &str,
) -> Result<i32, DataError>;
async fn set_user_variable(
&self,
user: &str,
room_id: &str,
variable_name: &str,
value: i32,
) -> Result<(), DataError>;
async fn delete_user_variable(
&self,
user: &str,
room_id: &str,
variable_name: &str,
) -> Result<(), DataError>;
}

515
src/db/rooms.rs Normal file
View File

@ -0,0 +1,515 @@
use crate::db::errors::DataError;
use crate::db::schema::convert_u64;
use crate::models::RoomInfo;
use byteorder::BigEndian;
use log::{debug, error, log_enabled};
use sled::transaction::TransactionalTree;
use sled::Transactional;
use sled::{CompareAndSwapError, Tree};
use std::collections::HashSet;
use std::str;
use std::time::{SystemTime, UNIX_EPOCH};
use tokio::task::JoinHandle;
use zerocopy::byteorder::U64;
use zerocopy::AsBytes;
/// Sled trees that track room metadata, room membership, and which
/// events have already been seen.
#[derive(Clone)]
pub struct Rooms {
    /// Room ID -> RoomInfo struct (single entries).
    /// Key is just room ID as bytes.
    pub(in crate::db) roomid_roominfo: Tree,
    /// Room ID -> list of usernames in room.
    pub(in crate::db) roomid_usernames: Tree,
    /// Username -> list of room IDs user is in.
    pub(in crate::db) username_roomids: Tree,
    /// Room ID(str) 0xff event ID(str) -> timestamp. Records event
    /// IDs that we have received, so we do not process twice.
    pub(in crate::db) roomeventid_timestamp: Tree,
    /// Room ID(str) 0xff timestamp(u64) -> event ID. Records event
    /// IDs with timestamp as the primary key instead. Exists to allow
    /// easy scanning of old roomeventid_timestamp records for
    /// removal. Be careful with u64, it can have 0xff and 0xfe bytes.
    /// A simple split on 0xff will not work with this key. Instead,
    /// it is meant to be split on the first 0xff byte only, and
    /// separated into room ID and timestamp.
    pub(in crate::db) roomtimestamp_eventid: Tree,
}
/// A borrowed reference to either a regular sled `Tree` or a
/// `TransactionalTree`, so that helper functions can be written once
/// and used both inside and outside of transactions.
#[derive(Clone, Copy)]
enum TxableTree<'a> {
    /// A plain tree; operations on it are not part of a transaction.
    Tree(&'a Tree),
    /// A tree handle that belongs to a running transaction.
    Tx(&'a TransactionalTree),
}
/// Wrap a plain tree reference. Implemented as `From` rather than a
/// hand-written `Into`: the standard blanket impl still provides
/// `Into<TxableTree>` for `&Tree`, so generic callers bounded on
/// `Into` are unaffected.
impl<'a> From<&'a Tree> for TxableTree<'a> {
    fn from(tree: &'a Tree) -> Self {
        TxableTree::Tree(tree)
    }
}
/// Wrap a transactional tree reference. Implemented as `From` rather
/// than a hand-written `Into`: the standard blanket impl still
/// provides `Into<TxableTree>` for `&TransactionalTree`, so generic
/// callers bounded on `Into` are unaffected.
impl<'a> From<&'a TransactionalTree> for TxableTree<'a> {
    fn from(tx: &'a TransactionalTree) -> Self {
        TxableTree::Tx(tx)
    }
}
/// Helpers for a sled tree whose values are bincode-serialized
/// `HashSet<String>`s. Every helper accepts either a plain `Tree` or
/// a `TransactionalTree`. The add/remove helpers are a read followed
/// by a write, so they are only atomic when called with a
/// transactional tree.
mod hashset_tree {
    use super::*;

    /// Serialize `set` and store it under `key`, replacing any
    /// previous value.
    fn insert_set<'a, T: Into<TxableTree<'a>>>(
        tree: T,
        key: &[u8],
        set: HashSet<String>,
    ) -> Result<(), DataError> {
        let encoded = bincode::serialize(&set)?;

        match tree.into() {
            TxableTree::Tree(plain) => plain.insert(key, encoded)?,
            TxableTree::Tx(txn) => txn.insert(key, encoded)?,
        };

        Ok(())
    }

    /// Load the set stored under `key`. A missing key yields an empty
    /// set.
    pub(super) fn get_set<'a, T: Into<TxableTree<'a>>>(
        tree: T,
        key: &[u8],
    ) -> Result<HashSet<String>, DataError> {
        let stored = match tree.into() {
            TxableTree::Tree(plain) => plain.get(key)?,
            TxableTree::Tx(txn) => txn.get(key)?,
        };

        match stored {
            Some(bytes) => Ok(bincode::deserialize::<HashSet<String>>(&bytes)?),
            None => Ok(HashSet::new()),
        }
    }

    /// Remove `value_to_remove` from the set stored under `key` and
    /// write the set back.
    pub(super) fn remove_from_set<'a, T: Into<TxableTree<'a>> + Copy>(
        tree: T,
        key: &[u8],
        value_to_remove: &str,
    ) -> Result<(), DataError> {
        let mut members = get_set(tree, key)?;
        members.remove(value_to_remove);
        insert_set(tree, key, members)
    }

    /// Add `value_to_add` to the set stored under `key` and write the
    /// set back.
    pub(super) fn add_to_set<'a, T: Into<TxableTree<'a>> + Copy>(
        tree: T,
        key: &[u8],
        value_to_add: String,
    ) -> Result<(), DataError> {
        let mut members = get_set(tree, key)?;
        members.insert(value_to_add);
        insert_set(tree, key, members)
    }
}
/// Maintenance of the "timestamp index", i.e. the
/// roomtimestamp_eventid tree on `Rooms`. These functions are driven
/// by the event watcher in the `Rooms` impl and live in their own
/// module only so they can be unit tested.
mod timestamp_index {
    use super::*;

    /// Mirror one insert on the roomeventid_timestamp tree into the
    /// timestamp index.
    ///
    /// `key` is the key of the original insert (room ID, 0xff, event
    /// ID) and `timestamp_bytes` is its value. The index entry is
    /// keyed by room ID, 0xff, timestamp and holds the set of event
    /// IDs received at that time.
    ///
    /// # Errors
    /// Returns `DataError::InvalidValue` unless `key` splits into
    /// exactly two parts on 0xff, and forwards UTF-8 and storage
    /// errors.
    pub(super) fn insert(
        roomtimestamp_eventid: &Tree,
        key: &[u8],
        timestamp_bytes: &[u8],
    ) -> Result<(), DataError> {
        let mut parts = key.split(|&b| b == 0xff);
        let (room_id, event_id) = match (parts.next(), parts.next(), parts.next()) {
            (Some(room_id), Some(event_id), None) => (room_id, event_id),
            _ => return Err(DataError::InvalidValue),
        };

        let mut ts_key = Vec::with_capacity(room_id.len() + 1 + timestamp_bytes.len());
        ts_key.extend_from_slice(room_id);
        ts_key.push(0xff);
        ts_key.extend_from_slice(timestamp_bytes);

        log_index_record(room_id, event_id, timestamp_bytes);

        let event_id = str::from_utf8(event_id)?;
        hashset_tree::add_to_set(roomtimestamp_eventid, &ts_key, event_id.to_owned())
    }

    /// Emit a debug log line describing the index entry being written.
    /// Undecodable parts are replaced by placeholders.
    fn log_index_record(room_id: &[u8], event_id: &[u8], timestamp: &[u8]) {
        if !log_enabled!(log::Level::Debug) {
            return;
        }

        debug!(
            "Recording event {} | {} received at {} in timestamp index.",
            str::from_utf8(room_id).unwrap_or("[invalid room id]"),
            str::from_utf8(event_id).unwrap_or("[invalid event id]"),
            convert_u64(timestamp).unwrap_or(0)
        );
    }
}
impl Rooms {
/// Open every tree used for room tracking on the given database.
pub(in crate::db) fn new(db: &sled::Db) -> Result<Rooms, sled::Error> {
    let roomid_roominfo = db.open_tree("roomid_roominfo")?;
    let roomid_usernames = db.open_tree("roomid_usernames")?;
    let username_roomids = db.open_tree("username_roomids")?;
    let roomeventid_timestamp = db.open_tree("roomeventid_timestamp")?;
    let roomtimestamp_eventid = db.open_tree("roomtimestamp_eventid")?;

    Ok(Rooms {
        roomid_roominfo,
        roomid_usernames,
        username_roomids,
        roomeventid_timestamp,
        roomtimestamp_eventid,
    })
}
/// Spawn a background task that watches the roomeventid_timestamp
/// tree for inserts (as made by `should_process`) and mirrors each
/// one into the timestamp index, keyed by timestamp instead of event
/// ID. Failures to update the index are logged and do not stop the
/// task.
pub(in crate::db) fn start_handler(&self) -> JoinHandle<()> {
    // The spawned task needs its own handles to the trees.
    let source_tree = self.roomeventid_timestamp.clone();
    let index_tree = self.roomtimestamp_eventid.clone();

    tokio::spawn(async move {
        let mut subscriber = source_tree.watch_prefix(b"");

        // TODO make this handler receive kill messages somehow so
        // we can unit test it and gracefully shut it down.
        while let Some(event) = (&mut subscriber).await {
            if let sled::Event::Insert { key, value } = event {
                if let Err(e) = timestamp_index::insert(&index_tree, &key, &value) {
                    error!("Unable to update the timestamp index: {}", e);
                }
            }
        }
    })
}
/// Decide whether an event in a room should be processed. The event
/// is recorded with a compare-and-swap that only succeeds when no
/// entry exists yet: `true` is returned if this call recorded the
/// event, `false` if the database had already seen it. The recorded
/// value is the (system-local) timestamp in epoch seconds.
///
/// # Panics
/// Panics if the system clock reports a time before the Unix epoch.
pub fn should_process(&self, room_id: &str, event_id: &str) -> Result<bool, DataError> {
    // Key layout: room ID, 0xff, event ID.
    let mut key = Vec::with_capacity(room_id.len() + 1 + event_id.len());
    key.extend_from_slice(room_id.as_bytes());
    key.push(0xff);
    key.extend_from_slice(event_id.as_bytes());

    let now_secs = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .expect("Clock has gone backwards")
        .as_secs();
    let timestamp = U64::<BigEndian>::new(now_secs);

    // Expecting "no current value" makes the swap fail for known events.
    let swap_result = self.roomeventid_timestamp.compare_and_swap(
        key,
        None::<&[u8]>,
        Some(timestamp.as_bytes()),
    )?;

    Ok(swap_result.is_ok())
}
/// Store `info` under its room ID, replacing any existing entry for
/// that room.
pub fn insert_room_info(&self, info: &RoomInfo) -> Result<(), DataError> {
    let encoded = bincode::serialize(info)?;
    self.roomid_roominfo
        .insert(info.room_id.as_bytes(), encoded)?;
    Ok(())
}
/// Load the stored info for `room_id`, or `None` if nothing is stored
/// for that room.
pub fn get_room_info(&self, room_id: &str) -> Result<Option<RoomInfo>, DataError> {
    match self.roomid_roominfo.get(room_id.as_bytes())? {
        Some(bytes) => Ok(Some(bincode::deserialize::<RoomInfo>(&bytes)?)),
        None => Ok(None),
    }
}
/// Return the set of room IDs recorded for `username`. A user with no
/// entry yields an empty set.
pub fn get_rooms_for_user(&self, username: &str) -> Result<HashSet<String>, DataError> {
    hashset_tree::get_set(&self.username_roomids, username.as_bytes())
}
/// Return the set of usernames recorded for `room_id`. A room with no
/// entry yields an empty set.
pub fn get_users_in_room(&self, room_id: &str) -> Result<HashSet<String>, DataError> {
    hashset_tree::get_set(&self.roomid_usernames, room_id.as_bytes())
}
/// Record that `username` is in `room_id`. Both directions of the
/// mapping (user -> rooms and room -> users) are updated inside a
/// single transaction spanning the two trees.
pub fn add_user_to_room(&self, username: &str, room_id: &str) -> Result<(), DataError> {
    debug!("Adding user {} to room {}", username, room_id);
    (&self.username_roomids, &self.roomid_usernames).transaction(
        |(tx_username_rooms, tx_room_usernames)| {
            // user -> rooms
            let username_key = &username.as_bytes();
            hashset_tree::add_to_set(tx_username_rooms, username_key, room_id.to_owned())?;
            // room -> users
            let roomid_key = &room_id.as_bytes();
            hashset_tree::add_to_set(tx_room_usernames, roomid_key, username.to_owned())?;
            Ok(())
        },
    )?;
    Ok(())
}
/// Record that `username` is no longer in `room_id`. Both directions
/// of the mapping (user -> rooms and room -> users) are updated
/// inside a single transaction spanning the two trees.
pub fn remove_user_from_room(&self, username: &str, room_id: &str) -> Result<(), DataError> {
    debug!("Removing user {} from room {}", username, room_id);
    (&self.username_roomids, &self.roomid_usernames).transaction(
        |(tx_username_rooms, tx_room_usernames)| {
            // user -> rooms
            let username_key = &username.as_bytes();
            hashset_tree::remove_from_set(tx_username_rooms, username_key, room_id)?;
            // room -> users
            let roomid_key = &room_id.as_bytes();
            hashset_tree::remove_from_set(tx_room_usernames, roomid_key, username)?;
            Ok(())
        },
    )?;
    Ok(())
}
/// Forget the membership data for `room_id`: the room is removed from
/// the room list of every user recorded in it, and the room's own
/// user list is deleted. Both trees are updated in one transaction.
/// The stored `RoomInfo` for the room is not removed (see the TODO
/// below).
pub fn clear_info(&self, room_id: &str) -> Result<(), DataError> {
    debug!("Clearing all information for room {}", room_id);
    (&self.username_roomids, &self.roomid_usernames).transaction(
        |(tx_username_roomids, tx_roomid_usernames)| {
            let roomid_key = room_id.as_bytes();
            let users_in_room = hashset_tree::get_set(tx_roomid_usernames, roomid_key)?;
            //Remove the room ID from every user's room ID list.
            for username in users_in_room {
                hashset_tree::remove_from_set(
                    tx_username_roomids,
                    username.as_bytes(),
                    room_id,
                )?;
            }
            //Remove this room entry for the room ID -> username tree.
            tx_roomid_usernames.remove(roomid_key)?;
            //TODO: delete roominfo struct from room info tree.
            Ok(())
        },
    )?;
    Ok(())
}
}
/// Unit tests for `Rooms` and the timestamp index helpers. Every test
/// runs against its own temporary sled database.
#[cfg(test)]
mod tests {
    use super::*;
    use sled::Config;

    /// Create a `Rooms` instance backed by a temporary database.
    fn create_test_instance() -> Rooms {
        let config = Config::new().temporary(true);
        let db = config.open().unwrap();
        Rooms::new(&db).unwrap()
    }

    /// Adding a user must update both directions of the mapping.
    #[test]
    fn add_user_to_room() {
        let rooms = create_test_instance();
        rooms
            .add_user_to_room("testuser", "myroom")
            .expect("Could not add user to room");

        let users_in_room = rooms
            .get_users_in_room("myroom")
            .expect("Could not retrieve users in room");

        let rooms_for_user = rooms
            .get_rooms_for_user("testuser")
            .expect("Could not get rooms for user");

        let expected_users_in_room: HashSet<String> =
            vec![String::from("testuser")].into_iter().collect();

        let expected_rooms_for_user: HashSet<String> =
            vec![String::from("myroom")].into_iter().collect();

        assert_eq!(expected_users_in_room, users_in_room);
        assert_eq!(expected_rooms_for_user, rooms_for_user);
    }

    /// Removing a user must clear both directions of the mapping.
    #[test]
    fn remove_user_from_room() {
        let rooms = create_test_instance();
        rooms
            .add_user_to_room("testuser", "myroom")
            .expect("Could not add user to room");

        rooms
            .remove_user_from_room("testuser", "myroom")
            .expect("Could not remove user from room");

        let users_in_room = rooms
            .get_users_in_room("myroom")
            .expect("Could not retrieve users in room");

        let rooms_for_user = rooms
            .get_rooms_for_user("testuser")
            .expect("Could not get rooms for user");

        assert_eq!(HashSet::new(), users_in_room);
        assert_eq!(HashSet::new(), rooms_for_user);
    }

    /// Stored room info can be read back unchanged.
    #[test]
    fn insert_room_info_works() {
        let rooms = create_test_instance();

        let info = RoomInfo {
            room_id: matrix_sdk::identifiers::room_id!("!fakeroom:example.com")
                .as_str()
                .to_owned(),
            room_name: "fake room name".to_owned(),
        };

        rooms
            .insert_room_info(&info)
            .expect("Could not insert room info");

        let found_info = rooms
            .get_room_info("!fakeroom:example.com")
            .expect("Error loading room info");

        assert!(found_info.is_some());
        assert_eq!(info, found_info.unwrap());
    }

    /// Inserting info for the same room twice overwrites the first entry.
    #[test]
    fn insert_room_info_updates_data() {
        let rooms = create_test_instance();

        let mut info = RoomInfo {
            room_id: matrix_sdk::identifiers::room_id!("!fakeroom:example.com")
                .as_str()
                .to_owned(),
            room_name: "fake room name".to_owned(),
        };

        rooms
            .insert_room_info(&info)
            .expect("Could not insert room info");

        //Update info to have a new room name before inserting again.
        info.room_name = "new fake room name".to_owned();

        rooms
            .insert_room_info(&info)
            .expect("Could not insert room info");

        let found_info = rooms
            .get_room_info("!fakeroom:example.com")
            .expect("Error loading room info");

        assert!(found_info.is_some());
        assert_eq!(info, found_info.unwrap());
    }

    /// Looking up an unknown room yields `None`, not an error.
    #[test]
    fn get_room_info_none_when_room_does_not_exist() {
        let rooms = create_test_instance();

        let found_info = rooms
            .get_room_info("!fakeroom:example.com")
            .expect("Error loading room info");

        assert!(found_info.is_none());
    }

    /// Clearing one room must leave the user's other rooms untouched.
    #[test]
    fn clear_info_modifies_removes_requested_room() {
        let rooms = create_test_instance();

        rooms
            .add_user_to_room("testuser", "myroom1")
            .expect("Could not add user to room1");

        rooms
            .add_user_to_room("testuser", "myroom2")
            .expect("Could not add user to room2");

        rooms
            .clear_info("myroom1")
            .expect("Could not clear room info");

        let users_in_room1 = rooms
            .get_users_in_room("myroom1")
            .expect("Could not retrieve users in room1");

        let users_in_room2 = rooms
            .get_users_in_room("myroom2")
            .expect("Could not retrieve users in room2");

        let rooms_for_user = rooms
            .get_rooms_for_user("testuser")
            .expect("Could not get rooms for user");

        let expected_users_in_room2: HashSet<String> =
            vec![String::from("testuser")].into_iter().collect();

        let expected_rooms_for_user: HashSet<String> =
            vec![String::from("myroom2")].into_iter().collect();

        assert_eq!(HashSet::new(), users_in_room1);
        assert_eq!(expected_users_in_room2, users_in_room2);
        assert_eq!(expected_rooms_for_user, rooms_for_user);
    }

    /// An insert into the timestamp index is retrievable under the
    /// room ID / timestamp key.
    #[test]
    fn insert_to_timestamp_index() {
        let rooms = create_test_instance();

        // Insertion into timestamp index based on data that would go
        // into main room x eventID -> timestamp tree.
        let mut key = b"myroom".to_vec();
        key.push(0xff);
        key.extend_from_slice(b"myeventid");

        let timestamp: U64<BigEndian> = U64::new(1000);

        let result = timestamp_index::insert(
            &rooms.roomtimestamp_eventid,
            key.as_bytes(),
            timestamp.as_bytes(),
        );

        assert!(result.is_ok());

        // Retrieval of data from the timestamp index tree.
        let mut ts_key = b"myroom".to_vec();
        ts_key.push(0xff);
        ts_key.extend_from_slice(timestamp.as_bytes());

        let expected_events: HashSet<String> =
            vec![String::from("myeventid")].into_iter().collect();

        let event_ids = hashset_tree::get_set(&rooms.roomtimestamp_eventid, &ts_key)
            .expect("Could not get set out of Tree");

        assert_eq!(expected_events, event_ids);
    }
}

52
src/db/schema.rs Normal file
View File

@ -0,0 +1,52 @@
use crate::db::errors::DataError;
use byteorder::{BigEndian, LittleEndian};
use zerocopy::byteorder::{I32, U32, U64};
use zerocopy::LayoutVerified;
/// User variables are stored as little-endian 32-bit integers in the
/// database. This type alias makes the database code more pleasant to
/// read.
type LittleEndianI32Layout<'a> = LayoutVerified<&'a [u8], I32<LittleEndian>>;
/// Zero-copy view of a byte slice as a little-endian u32.
type LittleEndianU32Layout<'a> = LayoutVerified<&'a [u8], U32<LittleEndian>>;
/// Zero-copy view of a byte slice as a little-endian u64. Not
/// referenced by the conversion functions in this file, hence the
/// `dead_code` allowance.
#[allow(dead_code)]
type LittleEndianU64Layout<'a> = LayoutVerified<&'a [u8], U64<LittleEndian>>;
/// Zero-copy view of a byte slice as a big-endian u64.
type BigEndianU64Layout<'a> = LayoutVerified<&'a [u8], U64<BigEndian>>;
/// Interpret `raw_value` as a little-endian i32 using zero-copy
/// deserialization.
///
/// # Errors
/// Returns `DataError::I32SchemaViolation` if the bytes do not have
/// the layout of an i32.
pub(super) fn convert_i32(raw_value: &[u8]) -> Result<i32, DataError> {
    LittleEndianI32Layout::new_unaligned(raw_value)
        .map(|layout| (*layout).get())
        .ok_or(DataError::I32SchemaViolation)
}
/// Decode a little-endian `u32` from `raw_value` using a zero-copy
/// layout view.
///
/// NOTE(review): a failed conversion is reported as
/// `DataError::I32SchemaViolation` even though the target type is
/// `u32`; confirm whether callers rely on that variant before changing
/// it.
pub(super) fn convert_u32(raw_value: &[u8]) -> Result<u32, DataError> {
    LittleEndianU32Layout::new_unaligned(raw_value)
        .map(|view| (*view).get())
        .ok_or(DataError::I32SchemaViolation)
}
/// Decode a big-endian `u64` from `raw_value` using a zero-copy layout
/// view.
///
/// NOTE(review): a failed conversion is reported as
/// `DataError::I32SchemaViolation` even though the target type is
/// `u64`; confirm whether callers rely on that variant before changing
/// it.
#[allow(dead_code)]
pub(super) fn convert_u64(raw_value: &[u8]) -> Result<u64, DataError> {
    let view = match BigEndianU64Layout::new_unaligned(raw_value) {
        Some(view) => view,
        None => return Err(DataError::I32SchemaViolation),
    };
    let decoded: U64<BigEndian> = *view;
    Ok(decoded.get())
}

72
src/db/sqlite/errors.rs Normal file
View File

@ -0,0 +1,72 @@
use std::num::TryFromIntError;
use sled::transaction::{TransactionError, UnabortableTransactionError};
use thiserror::Error;
#[derive(Error, Debug)]
pub enum DataError {
#[error("value does not exist for key: {0}")]
KeyDoesNotExist(String),
#[error("too many entries")]
TooManyEntries,
#[error("expected i32, but i32 schema was violated")]
I32SchemaViolation,
#[error("unexpected or corruptd data bytes")]
InvalidValue,
#[error("expected string ref, but utf8 schema was violated: {0}")]
Utf8RefSchemaViolation(#[from] std::str::Utf8Error),
#[error("expected string, but utf8 schema was violated: {0}")]
Utf8SchemaViolation(#[from] std::string::FromUtf8Error),
#[error("internal database error: {0}")]
InternalError(#[from] sled::Error),
#[error("transaction error: {0}")]
TransactionError(#[from] sled::transaction::TransactionError),
#[error("unabortable transaction error: {0}")]
UnabortableTransactionError(#[from] UnabortableTransactionError),
#[error("data migration error: {0}")]
MigrationError(#[from] super::migrator::MigrationError),
#[error("deserialization error: {0}")]
DeserializationError(#[from] bincode::Error),
#[error("sqlx error: {0}")]
SqlxError(#[from] sqlx::Error),
#[error("numeric conversion error")]
NumericConversionError(#[from] TryFromIntError),
}
/// Flattens the recursive error type that sled transactions produce.
/// `DataError` has a transaction variant, but the only conversion that
/// needs care is the one from
/// sled::transaction::TransactionError<DataError>: an aborted
/// transaction yields the `DataError` it was aborted with, while a
/// storage failure is forwarded unchanged inside
/// `DataError::TransactionError`.
impl From<TransactionError<DataError>> for DataError {
    fn from(error: TransactionError<DataError>) -> Self {
        match error {
            TransactionError::Storage(cause) => {
                Self::TransactionError(TransactionError::Storage(cause))
            }
            TransactionError::Abort(inner) => inner,
        }
    }
}
/// Lets the try (question mark) operator abort a transaction when a
/// `DataError` occurs inside it: with this implementation in scope the
/// error is converted into an `Abort` carrying the original value.
impl From<DataError> for sled::transaction::ConflictableTransactionError<DataError> {
    fn from(error: DataError) -> Self {
        Self::Abort(error)
    }
}

View File

@ -1,6 +1,5 @@
use barrel::backend::Sqlite; use barrel::backend::Sqlite;
use barrel::{types, Migration}; use barrel::{types, types::Type, Migration};
pub fn migration() -> String { pub fn migration() -> String {
let mut m = Migration::new(); let mut m = Migration::new();

View File

@ -1,5 +1,5 @@
use barrel::backend::Sqlite; use barrel::backend::Sqlite;
use barrel::{types, Migration}; use barrel::{types, types::Type, Migration};
pub fn migration() -> String { pub fn migration() -> String {
let mut m = Migration::new(); let mut m = Migration::new();

View File

@ -1,14 +1,80 @@
use crate::db::errors::DataError; use async_trait::async_trait;
use errors::DataError;
use sqlx::sqlite::{SqliteConnectOptions, SqlitePool, SqlitePoolOptions}; use sqlx::sqlite::{SqliteConnectOptions, SqlitePool, SqlitePoolOptions};
use sqlx::ConnectOptions; use sqlx::ConnectOptions;
use std::clone::Clone; use std::clone::Clone;
use std::collections::{HashMap, HashSet};
use std::str::FromStr; use std::str::FromStr;
use crate::models::RoomInfo;
pub mod errors;
pub mod migrator; pub mod migrator;
pub mod rooms; pub mod rooms;
pub mod state; pub mod state;
pub mod variables; pub mod variables;
/// Asynchronous access to the stored device ID.
#[async_trait]
pub(crate) trait DbState {
/// Fetch the stored device ID. The `Option` presumably is `None` when
/// no ID has been stored yet (confirm against the implementation).
async fn get_device_id(&self) -> Result<Option<String>, DataError>;
/// Store `device_id`.
async fn set_device_id(&self, device_id: &str) -> Result<(), DataError>;
}
/// Asynchronous storage interface for room information, room
/// membership and per-room event bookkeeping. Only the signatures are
/// visible here; the notes below are inferred from names and should be
/// confirmed against the implementation.
#[async_trait]
pub(crate) trait Rooms {
/// Decide whether the event `event_id` in `room_id` should be
/// processed. NOTE(review): presumably `false` for events already
/// seen; verify.
async fn should_process(&self, room_id: &str, event_id: &str) -> Result<bool, DataError>;
/// Store `info`.
async fn insert_room_info(&self, info: &RoomInfo) -> Result<(), DataError>;
/// Look up the stored info for `room_id`, if any.
async fn get_room_info(&self, room_id: &str) -> Result<Option<RoomInfo>, DataError>;
/// Rooms recorded for `user_id`, returned as a set of strings.
async fn get_rooms_for_user(&self, user_id: &str) -> Result<HashSet<String>, DataError>;
/// Users recorded for `room_id`, returned as a set of strings.
async fn get_users_in_room(&self, room_id: &str) -> Result<HashSet<String>, DataError>;
/// Record `username` as being in `room_id`.
async fn add_user_to_room(&self, username: &str, room_id: &str) -> Result<(), DataError>;
/// Remove the record of `username` being in `room_id`.
async fn remove_user_from_room(&self, username: &str, room_id: &str) -> Result<(), DataError>;
/// Clear stored information for `room_id`. NOTE(review): exactly what
/// is cleared is not visible from the signature.
async fn clear_info(&self, room_id: &str) -> Result<(), DataError>;
}
// TODO move this up to the top once we delete sled. Traits will be the
// main API, then we can have different impls for different DBs.
/// Asynchronous storage interface for `i32` user variables, addressed
/// by user, room and variable name.
#[async_trait]
pub trait Variables {
/// All variables of `user` in `room_id`, keyed by variable name.
async fn get_user_variables(
&self,
user: &str,
room_id: &str,
) -> Result<HashMap<String, i32>, DataError>;
/// Number of variables `user` has in `room_id`.
async fn get_variable_count(&self, user: &str, room_id: &str) -> Result<i32, DataError>;
/// Value of `variable_name` for `user` in `room_id`.
/// NOTE(review): how a missing variable is reported depends on the
/// implementation; confirm there.
async fn get_user_variable(
&self,
user: &str,
room_id: &str,
variable_name: &str,
) -> Result<i32, DataError>;
/// Store `value` under `variable_name` for `user` in `room_id`.
async fn set_user_variable(
&self,
user: &str,
room_id: &str,
variable_name: &str,
value: i32,
) -> Result<(), DataError>;
/// Delete `variable_name` for `user` in `room_id`.
async fn delete_user_variable(
&self,
user: &str,
room_id: &str,
variable_name: &str,
) -> Result<(), DataError>;
}
pub struct Database { pub struct Database {
conn: SqlitePool, conn: SqlitePool,
} }

View File

@ -1,9 +1,9 @@
use super::Database; use super::errors::DataError;
use crate::db::{errors::DataError, Rooms}; use super::{Database, Rooms};
use crate::models::RoomInfo; use crate::models::RoomInfo;
use async_trait::async_trait; use async_trait::async_trait;
use sqlx::SqlitePool; use sqlx::SqlitePool;
use std::collections::HashSet; use std::collections::{HashMap, HashSet};
use std::time::{SystemTime, UNIX_EPOCH}; use std::time::{SystemTime, UNIX_EPOCH};
async fn record_event(conn: &SqlitePool, room_id: &str, event_id: &str) -> Result<(), DataError> { async fn record_event(conn: &SqlitePool, room_id: &str, event_id: &str) -> Result<(), DataError> {
@ -148,9 +148,8 @@ impl Rooms for Database {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::super::Rooms;
use super::*; use super::*;
use crate::db::sqlite::Database;
use crate::db::Rooms;
async fn create_db() -> Database { async fn create_db() -> Database {
let db_path = tempfile::NamedTempFile::new_in(".").unwrap(); let db_path = tempfile::NamedTempFile::new_in(".").unwrap();

View File

@ -1,5 +1,5 @@
use super::Database; use super::errors::DataError;
use crate::db::{errors::DataError, DbState}; use super::{Database, DbState};
use async_trait::async_trait; use async_trait::async_trait;
#[async_trait] #[async_trait]
@ -35,8 +35,8 @@ impl DbState for Database {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use crate::db::sqlite::Database; use super::super::DbState;
use crate::db::DbState; use super::*;
async fn create_db() -> Database { async fn create_db() -> Database {
let db_path = tempfile::NamedTempFile::new_in(".").unwrap(); let db_path = tempfile::NamedTempFile::new_in(".").unwrap();

View File

@ -1,8 +1,13 @@
use super::Database; use super::errors::DataError;
use crate::db::{errors::DataError, Variables}; use super::{Database, Variables};
use async_trait::async_trait; use async_trait::async_trait;
use std::collections::HashMap; use std::collections::HashMap;
struct UserVariableRow {
key: String,
value: i32,
}
#[async_trait] #[async_trait]
impl Variables for Database { impl Variables for Database {
async fn get_user_variables( async fn get_user_variables(
@ -99,9 +104,8 @@ impl Variables for Database {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::super::Variables;
use super::*; use super::*;
use crate::db::sqlite::Database;
use crate::db::Variables;
async fn create_db() -> Database { async fn create_db() -> Database {
let db_path = tempfile::NamedTempFile::new_in(".").unwrap(); let db_path = tempfile::NamedTempFile::new_in(".").unwrap();
@ -116,6 +120,7 @@ mod tests {
#[tokio::test(flavor = "multi_thread", worker_threads = 1)] #[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn set_and_get_variable_test() { async fn set_and_get_variable_test() {
use super::super::Variables;
let db = create_db().await; let db = create_db().await;
db.set_user_variable("myuser", "myroom", "myvariable", 1) db.set_user_variable("myuser", "myroom", "myvariable", 1)
@ -132,6 +137,7 @@ mod tests {
#[tokio::test(flavor = "multi_thread", worker_threads = 1)] #[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn get_missing_variable_test() { async fn get_missing_variable_test() {
use super::super::Variables;
let db = create_db().await; let db = create_db().await;
let value = db.get_user_variable("myuser", "myroom", "myvariable").await; let value = db.get_user_variable("myuser", "myroom", "myvariable").await;
@ -145,6 +151,7 @@ mod tests {
#[tokio::test(flavor = "multi_thread", worker_threads = 1)] #[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn get_other_user_variable_test() { async fn get_other_user_variable_test() {
use super::super::Variables;
let db = create_db().await; let db = create_db().await;
db.set_user_variable("myuser1", "myroom", "myvariable", 1) db.set_user_variable("myuser1", "myroom", "myvariable", 1)

88
src/db/state.rs Normal file
View File

@ -0,0 +1,88 @@
use crate::db::errors::DataError;
use sled::Tree;
/// Handle to the sled tree that stores global bot state. The
/// constructor opens the tree named `global_metadata`.
#[derive(Clone)]
pub struct DbState {
/// Tree of simple key-values for global state values that persist
/// between restarts (e.g. device ID).
pub(in crate::db) global_metadata: Tree,
}
/// Key under which the device ID is stored in the `global_metadata`
/// tree.
const DEVICE_ID_KEY: &[u8] = b"device_id";

impl DbState {
    /// Open the `global_metadata` tree of `db` and wrap it.
    ///
    /// Returns the sled error if the tree cannot be opened.
    pub(in crate::db) fn new(db: &sled::Db) -> Result<DbState, sled::Error> {
        let global_metadata = db.open_tree("global_metadata")?;
        Ok(DbState { global_metadata })
    }

    /// Read the stored device ID.
    ///
    /// Returns `Ok(None)` when nothing is stored under the device ID
    /// key, and an error when the lookup fails or the stored bytes are
    /// not valid UTF-8.
    pub fn get_device_id(&self) -> Result<Option<String>, DataError> {
        let stored = self.global_metadata.get(DEVICE_ID_KEY)?;
        let device_id = stored
            .map(|raw| String::from_utf8(raw.to_vec()))
            .transpose()?;
        Ok(device_id)
    }

    /// Store `device_id`, replacing any previously stored value.
    pub fn set_device_id(&self, device_id: &str) -> Result<(), DataError> {
        self.global_metadata
            .insert(DEVICE_ID_KEY, device_id.as_bytes())?;
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use sled::Config;

    /// Build a `DbState` backed by a temporary sled database.
    fn create_test_instance() -> DbState {
        let db = Config::new().temporary(true).open().unwrap();
        DbState::new(&db).unwrap()
    }

    #[test]
    fn set_device_id_works() {
        let state = create_test_instance();
        assert!(state.set_device_id("test-device").is_ok());
    }

    #[test]
    fn set_device_id_can_overwrite() {
        let state = create_test_instance();
        state.set_device_id("test-device").expect("insert 1 failed");
        assert!(state.set_device_id("test-device2").is_ok());
    }

    #[test]
    fn get_device_id_returns_some_when_set() {
        let state = create_test_instance();
        state
            .set_device_id("test-device")
            .expect("could not store device id properly");

        let fetched = state.get_device_id();
        assert!(fetched.is_ok());

        let fetched = fetched.unwrap();
        assert!(fetched.is_some());
        assert_eq!("test-device", fetched.unwrap());
    }

    #[test]
    fn get_device_id_returns_none_when_unset() {
        let state = create_test_instance();

        let fetched = state.get_device_id();
        assert!(fetched.is_ok());
        assert!(fetched.unwrap().is_none());
    }
}

410
src/db/variables.rs Normal file
View File

@ -0,0 +1,410 @@
use crate::db::errors::DataError;
use crate::db::schema::convert_i32;
use byteorder::LittleEndian;
use sled::transaction::{abort, TransactionalTree};
use sled::Transactional;
use sled::Tree;
use std::collections::HashMap;
use std::convert::From;
use std::str;
use zerocopy::byteorder::I32;
use zerocopy::AsBytes;
use super::errors;
pub(super) mod migrations;
/// Handles to the two sled trees used for user variables: one holding
/// the variables themselves and one holding a per-user/room count.
#[derive(Clone)]
pub struct Variables {
// Values are little-endian i32. Originally described as
// "room id + username + variable = i32"; the keys written by
// set_user_variable are the UserAndRoom bytes (first field, 0xfe,
// second field), then 0xff and the variable name.
pub(in crate::db) room_user_variables: Tree,
// Values are little-endian i32 counts. Originally described as
// "room id + username = i32"; the keys are the UserAndRoom bytes
// (first field, 0xfe, second field).
pub(in crate::db) room_user_variable_count: Tree,
}
/// Request something by a username and room ID.
pub struct UserAndRoom<'a>(pub &'a str, pub &'a str);

/// Serialize a `UserAndRoom` as its first field, a single 0xfe
/// separator byte, and its second field.
fn to_vec(value: &UserAndRoom<'_>) -> Vec<u8> {
    let UserAndRoom(first, second) = value;
    [first.as_bytes(), &[0xfe_u8][..], second.as_bytes()].concat()
}

/// Owned conversion into the database key bytes.
impl From<UserAndRoom<'_>> for Vec<u8> {
    fn from(value: UserAndRoom<'_>) -> Vec<u8> {
        to_vec(&value)
    }
}

/// Borrowed conversion into the database key bytes.
impl From<&UserAndRoom<'_>> for Vec<u8> {
    fn from(value: &UserAndRoom<'_>) -> Vec<u8> {
        to_vec(value)
    }
}
/// Alter the stored variable count for `user_and_room` by `amount`
/// and return the new count. Must be called from inside a transaction
/// so the read-modify-write is atomic.
///
/// A missing entry counts as 0. The result is clamped so it never
/// goes below 0, and the addition saturates instead of overflowing
/// when the stored value is near the `i32` limits.
///
/// Returns an error if the tree cannot be read or written, or if the
/// stored bytes are not a valid `i32`.
fn alter_room_variable_count(
    room_variable_count: &TransactionalTree,
    user_and_room: &UserAndRoom<'_>,
    amount: i32,
) -> Result<i32, DataError> {
    let key: Vec<u8> = user_and_room.into();

    let current = match room_variable_count.get(&key)? {
        Some(bytes) => convert_i32(&bytes)?,
        None => 0,
    };
    let new_count = current.saturating_add(amount).max(0);

    let db_value: I32<LittleEndian> = I32::new(new_count);
    room_variable_count.insert(key, db_value.as_bytes())?;
    Ok(new_count)
}
/// Key tuple produced by `Variables::get_all_variables`: the first
/// three segments of a stored key, in stored order.
///
/// NOTE(review): this was documented as "Room ID, Username, Variable
/// Name", but keys are written from `UserAndRoom` (described as "a
/// username and room ID") followed by the variable name, so the first
/// element appears to be the username and the second the room ID.
/// Confirm against the callers.
pub type AllVariablesKey = (String, String, String);
impl Variables {
pub(in crate::db) fn new(db: &sled::Db) -> Result<Variables, sled::Error> {
Ok(Variables {
room_user_variables: db.open_tree("variables")?,
room_user_variable_count: db.open_tree("room_user_variable_count")?,
})
}
pub fn get_all_variables(&self) -> Result<HashMap<AllVariablesKey, i32>, DataError> {
use std::convert::TryFrom;
let variables: Result<Vec<(AllVariablesKey, i32)>, DataError> = self
.room_user_variables
.scan_prefix("")
.map(|entry| match entry {
Ok((key, raw_value)) => {
let keys: Vec<_> = key
.split(|&b| b == 0xfe || b == 0xff)
.map(|b| str::from_utf8(b))
.collect();
if let &[Ok(room_id), Ok(username), Ok(variable_name), ..] = keys.as_slice() {
Ok((
(
room_id.to_owned(),
username.to_owned(),
variable_name.to_owned(),
),
convert_i32(&raw_value)?,
))
} else {
Err(errors::DataError::InvalidValue)
}
}
Err(e) => Err(e.into()),
})
.collect();
// Convert tuples to hash map with collect(), inferred via
// return type.
variables.map(|entries| entries.into_iter().collect())
}
pub fn get_user_variables(
&self,
key: &UserAndRoom<'_>,
) -> Result<HashMap<String, i32>, DataError> {
let mut prefix: Vec<u8> = key.into();
prefix.push(0xff);
let prefix_len: usize = prefix.len();
let variables: Result<Vec<(String, i32)>, DataError> = self
.room_user_variables
.scan_prefix(prefix)
.map(|entry| match entry {
Ok((key, raw_value)) => {
//Strips room and username from key, leaving behind name.
let variable_name = str::from_utf8(&key[prefix_len..])?;
Ok((variable_name.to_owned(), convert_i32(&raw_value)?))
}
Err(e) => Err(e.into()),
})
.collect();
// Convert tuples to hash map with collect(), inferred via
// return type.
variables.map(|entries| entries.into_iter().collect())
}
pub fn get_variable_count(&self, user_and_room: &UserAndRoom<'_>) -> Result<i32, DataError> {
let key: Vec<u8> = user_and_room.into();
match self.room_user_variable_count.get(&key)? {
Some(raw_value) => convert_i32(&raw_value),
None => Ok(0),
}
}
pub fn get_user_variable(
&self,
user_and_room: &UserAndRoom<'_>,
variable_name: &str,
) -> Result<i32, DataError> {
let mut key: Vec<u8> = user_and_room.into();
key.push(0xff);
key.extend_from_slice(variable_name.as_bytes());
match self.room_user_variables.get(&key)? {
Some(raw_value) => convert_i32(&raw_value),
_ => Err(DataError::KeyDoesNotExist(variable_name.to_owned())),
}
}
pub fn set_user_variable(
&self,
user_and_room: &UserAndRoom<'_>,
variable_name: &str,
value: i32,
) -> Result<(), DataError> {
if self.get_variable_count(user_and_room)? >= 100 {
return Err(DataError::TooManyEntries);
}
(&self.room_user_variables, &self.room_user_variable_count).transaction(
|(tx_vars, tx_counts)| {
let mut key: Vec<u8> = user_and_room.into();
key.push(0xff);
key.extend_from_slice(variable_name.as_bytes());
let db_value: I32<LittleEndian> = I32::new(value);
let old_value = tx_vars.insert(key, db_value.as_bytes())?;
//Only increment variable count on new keys.
if let None = old_value {
if let Err(e) = alter_room_variable_count(&tx_counts, &user_and_room, 1) {
return abort(e);
}
}
Ok(())
},
)?;
Ok(())
}
pub fn delete_user_variable(
&self,
user_and_room: &UserAndRoom<'_>,
variable_name: &str,
) -> Result<(), DataError> {
(&self.room_user_variables, &self.room_user_variable_count).transaction(
|(tx_vars, tx_counts)| {
let mut key: Vec<u8> = user_and_room.into();
key.push(0xff);
key.extend_from_slice(variable_name.as_bytes());
if let Some(_) = tx_vars.remove(key)? {
if let Err(e) = alter_room_variable_count(&tx_counts, user_and_room, -1) {
return abort(e);
}
} else {
return abort(DataError::KeyDoesNotExist(variable_name.to_owned()));
}
Ok(())
},
)?;
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use sled::Config;

    /// User/room pair shared by every test in this module.
    const KEY: UserAndRoom<'static> = UserAndRoom("username", "room");

    /// Build a `Variables` instance backed by a temporary sled database.
    fn create_test_instance() -> Variables {
        let db = Config::new().temporary(true).open().unwrap();
        Variables::new(&db).unwrap()
    }

    /// Apply `amount` to the stored count for `KEY` inside a transaction.
    fn alter_count(variables: &Variables, amount: i32) {
        variables
            .room_user_variable_count
            .transaction(|tx| match alter_room_variable_count(tx, &KEY, amount) {
                Err(e) => abort(e),
                _ => Ok(()),
            })
            .expect("got transaction failure");
    }

    /// Read the stored count for `KEY`.
    fn current_count(variables: &Variables) -> i32 {
        variables
            .get_variable_count(&KEY)
            .expect("could not get variable count")
    }

    //Room Variable count tests

    #[test]
    fn alter_room_variable_count_test() {
        let variables = create_test_instance();

        //addition
        alter_count(&variables, 5);
        assert_eq!(5, current_count(&variables));

        //subtraction
        alter_count(&variables, -3);
        assert_eq!(2, current_count(&variables));
    }

    #[test]
    fn alter_room_variable_count_cannot_go_below_0_test() {
        let variables = create_test_instance();
        alter_count(&variables, -1000);
        assert_eq!(0, current_count(&variables));
    }

    #[test]
    fn empty_db_reports_0_room_variable_count_test() {
        let variables = create_test_instance();
        assert_eq!(0, current_count(&variables));
    }

    #[test]
    fn set_user_variable_increments_count() {
        let variables = create_test_instance();

        variables
            .set_user_variable(&KEY, "myvariable", 5)
            .expect("could not insert variable");

        assert_eq!(1, current_count(&variables));
    }

    #[test]
    fn update_user_variable_does_not_increment_count() {
        let variables = create_test_instance();

        variables
            .set_user_variable(&KEY, "myvariable", 5)
            .expect("could not insert variable");
        variables
            .set_user_variable(&KEY, "myvariable", 10)
            .expect("could not update variable");

        assert_eq!(1, current_count(&variables));
    }

    // Set/get/delete variable tests

    #[test]
    fn set_and_get_variable_test() {
        let variables = create_test_instance();

        variables
            .set_user_variable(&KEY, "myvariable", 5)
            .expect("could not insert variable");

        let stored = variables
            .get_user_variable(&KEY, "myvariable")
            .expect("could not get value");
        assert_eq!(5, stored);
    }

    #[test]
    fn cannot_set_more_than_100_variables_per_room() {
        let variables = create_test_instance();

        (0..100).for_each(|c| {
            variables
                .set_user_variable(&KEY, &format!("myvariable{}", c), 5)
                .expect("could not insert variable");
        });

        let outcome = variables.set_user_variable(&KEY, "myvariable101", 5);
        assert!(outcome.is_err());
        assert!(matches!(outcome, Err(DataError::TooManyEntries)));
    }

    #[test]
    fn delete_variable_test() {
        let variables = create_test_instance();

        variables
            .set_user_variable(&KEY, "myvariable", 5)
            .expect("could not insert variable");
        variables
            .delete_user_variable(&KEY, "myvariable")
            .expect("could not delete value");

        let outcome = variables.get_user_variable(&KEY, "myvariable");
        assert!(outcome.is_err());
        assert!(matches!(outcome, Err(DataError::KeyDoesNotExist(_))));
    }

    #[test]
    fn get_missing_variable_returns_key_does_not_exist() {
        let variables = create_test_instance();

        let outcome = variables.get_user_variable(&KEY, "myvariable");
        assert!(outcome.is_err());
        assert!(matches!(outcome, Err(DataError::KeyDoesNotExist(_))));
    }

    #[test]
    fn remove_missing_variable_returns_key_does_not_exist() {
        let variables = create_test_instance();

        let outcome = variables.delete_user_variable(&KEY, "myvariable");
        assert!(outcome.is_err());
        assert!(matches!(outcome, Err(DataError::KeyDoesNotExist(_))));
    }
}

View File

@ -0,0 +1,354 @@
use super::*;
use crate::db::errors::{DataError, MigrationError};
use crate::db::Database;
use byteorder::LittleEndian;
use memmem::{Searcher, TwoWaySearcher};
use sled::transaction::TransactionError;
use sled::{Batch, IVec};
use std::collections::HashMap;
use zerocopy::byteorder::{I32, U32};
use zerocopy::AsBytes;
/// Migration that computes a variable count for each room/user pair
/// from the keys under the "variables" prefix and stores it under a
/// dedicated "variable_count" key.
pub(in crate::db) mod add_room_user_variable_count {
use super::*;
//Not to be confused with the super::RoomAndUser delineator.
// NOTE(review): no `RoomAndUser` is visible in the parent modules; this
// presumably refers to `UserAndRoom` in the variables module.
#[derive(PartialEq, Eq, std::hash::Hash)]
struct RoomAndUser {
room_id: String,
username: String,
}
/// Create a version 0 user variable key: room ID, username and
/// variable name concatenated with no separator bytes.
fn v0_variable_key(info: &RoomAndUser, variable_name: &str) -> Vec<u8> {
let mut key = vec![];
key.extend_from_slice(info.room_id.as_bytes());
key.extend_from_slice(info.username.as_bytes());
key.extend_from_slice(variable_name.as_bytes());
key
}
/// Extract the room ID and username from a scanned entry. The key is
/// split on 0xff and must yield exactly four segments, of which the
/// first is ignored and the remaining three must be valid UTF-8.
///
/// Returns `MigrationFailed("a key violates utf8 schema")` when the key
/// does not have that shape, and
/// `MigrationFailed("encountered unexpected key")` when the scanned
/// entry itself is an error.
fn map_value_to_room_and_user(
entry: sled::Result<(IVec, IVec)>,
) -> Result<RoomAndUser, MigrationError> {
if let Ok((key, _)) = entry {
let keys: Vec<Result<&str, _>> = key
.split(|&b| b == 0xff)
.map(|b| str::from_utf8(b))
.collect();
if let &[_, Ok(room_id), Ok(username), Ok(_variable)] = keys.as_slice() {
Ok(RoomAndUser {
room_id: room_id.to_owned(),
username: username.to_owned(),
})
} else {
Err(MigrationError::MigrationFailed(
"a key violates utf8 schema".to_string(),
))
}
} else {
Err(MigrationError::MigrationFailed(
"encountered unexpected key".to_string(),
))
}
}
/// Build the key that holds a room/user variable count:
/// "variables" 0xff room ID 0xff username 0xff "variable_count".
fn create_key(room_id: &str, username: &str) -> Vec<u8> {
let mut key = b"variables".to_vec();
key.push(0xff);
key.extend_from_slice(room_id.as_bytes());
key.push(0xff);
key.extend_from_slice(username.as_bytes());
key.push(0xff);
key.extend_from_slice(b"variable_count");
key
}
/// Run the migration: count the entries under the "variables" prefix
/// per room/user pair, then, in one transaction on the variables tree,
/// store each count as a little-endian u32 under the key from
/// `create_key` and remove the version 0 "variable_count" key.
///
/// Fails without writing anything if any scanned key cannot be parsed.
pub(in crate::db) fn migrate(db: &Database) -> Result<(), DataError> {
let tree = &db.variables.room_user_variables;
let prefix = b"variables";
//Extract a vec of tuples, consisting of room id + username.
let results: Vec<RoomAndUser> = tree
.scan_prefix(prefix)
.map(map_value_to_room_and_user)
.collect::<Result<Vec<_>, MigrationError>>()?;
// One map entry per room/user pair; the value is the number of
// scanned keys that belong to that pair.
let counts: HashMap<RoomAndUser, u32> =
results
.into_iter()
.fold(HashMap::new(), |mut count_map, room_and_user| {
let count = count_map.entry(room_and_user).or_insert(0);
*count += 1;
count_map
});
//Start a transaction on the variables tree.
let tx_result: Result<_, TransactionError<DataError>> =
db.variables.room_user_variables.transaction(|tx_vars| {
let batch = counts.iter().fold(Batch::default(), |mut batch, entry| {
let (info, count) = entry;
//Add variable count according to new schema.
let key = create_key(&info.room_id, &info.username);
let db_value: U32<LittleEndian> = U32::new(*count);
batch.insert(key, db_value.as_bytes());
//Delete the old variable_count variable if exists.
let old_key = v0_variable_key(&info, "variable_count");
batch.remove(old_key);
batch
});
tx_vars.apply_batch(&batch)?;
Ok(())
});
tx_result?; //For some reason, it cannot infer the type
Ok(())
}
}
/// Delete every key in the user variables tree that contains no 0xff
/// separator byte, i.e. keys that predate the separator-based schemas.
///
/// All removals are applied as one batch. A scan error is returned to
/// the caller before anything is deleted, instead of silently ending
/// the scan and applying a partial batch.
pub(in crate::db) fn delete_v0_schema(db: &Database) -> Result<(), DataError> {
    let tree = &db.variables.room_user_variables;
    let mut batch = Batch::default();

    for entry in tree.scan_prefix("") {
        let (key, _) = entry?;
        if !key.contains(&0xff) {
            batch.remove(key);
        }
    }

    tree.apply_batch(batch)?;
    Ok(())
}
/// Delete every key under the "variables" prefix of the user variables
/// tree that ends with "variable_count".
///
/// The suffix test uses the slice `ends_with` check, so a key is
/// matched even when "variable_count" also occurs earlier in it. All
/// removals are applied as one batch. A scan error is returned to the
/// caller before anything is deleted, instead of silently ending the
/// scan and applying a partial batch.
pub(in crate::db) fn delete_variable_count(db: &Database) -> Result<(), DataError> {
    const SUFFIX: &[u8] = b"variable_count";

    let tree = &db.variables.room_user_variables;
    let prefix = b"variables";
    let mut batch = Batch::default();

    for entry in tree.scan_prefix(prefix) {
        let (key, _) = entry?;
        if key.ends_with(SUFFIX) {
            batch.remove(key);
        }
    }

    tree.apply_batch(batch)?;
    Ok(())
}
/// Migration that changes the separator between room ID and username
/// in user variable keys from 0xff to 0xfe. The separator before the
/// variable name stays 0xff.
pub(in crate::db) mod change_delineator_delimiter {
    use super::*;

    /// An entry in the room user variables keyspace.
    struct UserVariableEntry {
        room_id: String,
        username: String,
        variable_name: String,
        value: IVec,
    }

    /// Extract keys and values from the variables keyspace according
    /// to the v1 schema: four segments separated by 0xff, of which the
    /// first (the prefix) is ignored.
    ///
    /// Returns `MigrationFailed("a key violates utf8 schema")` when the
    /// key does not have that shape, and
    /// `MigrationFailed("encountered unexpected key")` when the scanned
    /// entry itself is an error.
    fn extract_v1_entries(
        entry: sled::Result<(IVec, IVec)>,
    ) -> Result<UserVariableEntry, MigrationError> {
        if let Ok((key, value)) = entry {
            let keys: Vec<Result<&str, _>> = key
                .split(|&b| b == 0xff)
                .map(|b| str::from_utf8(b))
                .collect();

            if let &[_, Ok(room_id), Ok(username), Ok(variable)] = keys.as_slice() {
                Ok(UserVariableEntry {
                    room_id: room_id.to_owned(),
                    username: username.to_owned(),
                    variable_name: variable.to_owned(),
                    value,
                })
            } else {
                Err(MigrationError::MigrationFailed(
                    "a key violates utf8 schema".to_string(),
                ))
            }
        } else {
            Err(MigrationError::MigrationFailed(
                "encountered unexpected key".to_string(),
            ))
        }
    }

    /// Create an old key, where delineator is separated by 0xff.
    /// `prefix` must already end with the 0xff separator.
    fn create_old_key(prefix: &[u8], insert: &UserVariableEntry) -> Vec<u8> {
        let mut key = vec![];
        key.extend_from_slice(prefix); //prefix already has 0xff.
        key.extend_from_slice(insert.room_id.as_bytes());
        key.push(0xff);
        key.extend_from_slice(insert.username.as_bytes());
        key.push(0xff);
        key.extend_from_slice(insert.variable_name.as_bytes());
        key
    }

    /// Create a new key, where delineator is separated by 0xfe.
    /// `prefix` must already end with the 0xff separator.
    fn create_new_key(prefix: &[u8], insert: &UserVariableEntry) -> Vec<u8> {
        let mut key = vec![];
        key.extend_from_slice(prefix); //prefix already has 0xff.
        key.extend_from_slice(insert.room_id.as_bytes());
        key.push(0xfe);
        key.extend_from_slice(insert.username.as_bytes());
        key.push(0xff);
        key.extend_from_slice(insert.variable_name.as_bytes());
        key
    }

    /// Run the migration: for every entry under the "variables"
    /// prefix, remove the 0xff-delimited key and insert the same value
    /// under the 0xfe-delimited key, all in one batch.
    ///
    /// Fails without writing anything if any scanned key cannot be
    /// parsed.
    pub fn migrate(db: &Database) -> Result<(), DataError> {
        let tree = &db.variables.room_user_variables;
        let prefix = b"variables";

        // The key helpers expect the prefix to include its trailing
        // 0xff separator. Without it the rebuilt old key never matches
        // a stored key, so nothing would be removed.
        let mut key_prefix = prefix.to_vec();
        key_prefix.push(0xff);

        let results: Vec<UserVariableEntry> = tree
            .scan_prefix(prefix)
            .map(extract_v1_entries)
            .collect::<Result<Vec<_>, MigrationError>>()?;

        let mut batch = Batch::default();

        for insert in results {
            let old = create_old_key(&key_prefix, &insert);
            let new = create_new_key(&key_prefix, &insert);

            batch.remove(old);
            batch.insert(new, insert.value);
        }

        tree.apply_batch(batch)?;
        Ok(())
    }
}
/// Move the user variable entries into two tree structures, with yet
/// another key format change. Now there is one tree for variable
/// counts, and one tree for actual user variables. Keys in the user
/// variable tree were changed to be username-first, then room ID.
/// They are still separated by 0xfe, while the variable name is
/// separated by 0xff. Variable count now stores just
/// USERNAME0xfeROOM_ID and a count in its own tree. This enables
/// public use of a strongly typed UserAndRoom struct for getting
/// variables.
pub(in crate::db) mod change_tree_structure {
    use super::*;

    /// An entry in the room user variables keyspace.
    struct UserVariableEntry {
        room_id: String,
        username: String,
        variable_name: String,
        value: IVec,
    }

    /// Extract keys and values from the variables keyspace as left by
    /// the delimiter migration: four segments separated by 0xff or
    /// 0xfe, of which the first (the prefix) is ignored.
    ///
    /// Returns `MigrationFailed("a key violates utf8 schema")` when the
    /// key does not have that shape, and
    /// `MigrationFailed("encountered unexpected key")` when the scanned
    /// entry itself is an error.
    fn extract_v1_entries(
        entry: sled::Result<(IVec, IVec)>,
    ) -> Result<UserVariableEntry, MigrationError> {
        if let Ok((key, value)) = entry {
            let keys: Vec<Result<&str, _>> = key
                .split(|&b| b == 0xff || b == 0xfe)
                .map(|b| str::from_utf8(b))
                .collect();

            if let &[_, Ok(room_id), Ok(username), Ok(variable)] = keys.as_slice() {
                Ok(UserVariableEntry {
                    room_id: room_id.to_owned(),
                    username: username.to_owned(),
                    variable_name: variable.to_owned(),
                    value,
                })
            } else {
                Err(MigrationError::MigrationFailed(
                    "a key violates utf8 schema".to_string(),
                ))
            }
        } else {
            Err(MigrationError::MigrationFailed(
                "encountered unexpected key".to_string(),
            ))
        }
    }

    /// Create an old key, of "variables" 0xff "room id" 0xfe "username" 0xff "variablename".
    /// `prefix` must already end with the 0xff separator.
    fn create_old_key(prefix: &[u8], insert: &UserVariableEntry) -> Vec<u8> {
        let mut key = vec![];
        key.extend_from_slice(prefix); //prefix already has 0xff.
        key.extend_from_slice(insert.room_id.as_bytes());
        // Room ID and username are separated by 0xfe in the old format.
        key.push(0xfe);
        key.extend_from_slice(insert.username.as_bytes());
        key.push(0xff);
        key.extend_from_slice(insert.variable_name.as_bytes());
        key
    }

    /// Create a new key, of "username" 0xfe "room id" 0xff "variablename".
    fn create_new_key(insert: &UserVariableEntry) -> Vec<u8> {
        let mut key = vec![];
        key.extend_from_slice(insert.username.as_bytes());
        key.push(0xfe);
        key.extend_from_slice(insert.room_id.as_bytes());
        key.push(0xff);
        key.extend_from_slice(insert.variable_name.as_bytes());
        key
    }

    /// Run the migration: re-key every entry under the "variables"
    /// prefix to the username-first format, and write one count per
    /// username/room pair into the count tree as a little-endian i32.
    ///
    /// The variables batch is applied before the count batch. Fails
    /// without writing anything if any scanned key cannot be parsed.
    pub fn migrate(db: &Database) -> Result<(), DataError> {
        let variables_tree = &db.variables.room_user_variables;
        let count_tree = &db.variables.room_user_variable_count;
        let prefix = b"variables";

        // The old-key helper expects the prefix to include its trailing
        // 0xff separator. Without it the rebuilt old key never matches
        // a stored key, so old entries would be left behind.
        let mut key_prefix = prefix.to_vec();
        key_prefix.push(0xff);

        let results: Vec<UserVariableEntry> = variables_tree
            .scan_prefix(prefix)
            .map(extract_v1_entries)
            .collect::<Result<Vec<_>, MigrationError>>()?;

        let mut counts: HashMap<(String, String), i32> = HashMap::new();
        let mut batch = Batch::default();

        for insert in results {
            let count = counts
                .entry((insert.username.clone(), insert.room_id.clone()))
                .or_insert(0);
            *count += 1;

            let old = create_old_key(&key_prefix, &insert);
            let new = create_new_key(&insert);

            batch.remove(old);
            batch.insert(new, insert.value);
        }

        let mut count_batch = Batch::default();
        for ((username, room_id), count) in counts {
            let mut key = username.as_bytes().to_vec();
            key.push(0xfe);
            key.extend_from_slice(room_id.as_bytes());

            let db_value: I32<LittleEndian> = I32::new(count);
            count_batch.insert(key, db_value.as_bytes());
        }

        variables_tree.apply_batch(batch)?;
        count_tree.apply_batch(count_batch)?;

        Ok(())
    }
}

View File

@ -1,5 +1,6 @@
use crate::context::Context; use crate::context::Context;
use crate::db::Variables; use crate::db::sqlite::Variables;
use crate::db::variables::UserAndRoom;
use crate::error::BotError; use crate::error::BotError;
use crate::error::DiceRollingError; use crate::error::DiceRollingError;
use crate::parser::Amount; use crate::parser::Amount;

View File

@ -1,6 +1,6 @@
use crate::commands::CommandError;
use crate::config::ConfigError; use crate::config::ConfigError;
use crate::db::errors::DataError; use crate::db::errors::DataError;
use crate::{commands::CommandError, db::sqlite::migrator};
use thiserror::Error; use thiserror::Error;
#[derive(Error, Debug)] #[derive(Error, Debug)]
@ -21,6 +21,9 @@ pub enum BotError {
#[error("database error: {0}")] #[error("database error: {0}")]
DataError(#[from] DataError), DataError(#[from] DataError),
#[error("sqlite database error: {0}")]
SqliteDataError(#[from] crate::db::sqlite::errors::DataError),
#[error("the message should not be processed because it failed validation")] #[error("the message should not be processed because it failed validation")]
ShouldNotProcessError, ShouldNotProcessError,
@ -70,6 +73,12 @@ pub enum BotError {
#[error("variables not yet supported")] #[error("variables not yet supported")]
VariablesNotSupported, VariablesNotSupported,
#[error("database error")]
DatabaseError(#[from] sled::Error),
#[error("database migration error: {0}")]
SqliteError(#[from] migrator::MigrationError),
#[error("too many commands or message was too large")] #[error("too many commands or message was too large")]
MessageTooLarge, MessageTooLarge,

View File

@ -1,8 +1,9 @@
use crate::db::Rooms; use crate::db::sqlite::errors::DataError;
use crate::db::sqlite::Rooms;
use crate::error::BotError; use crate::error::BotError;
use crate::matrix; use crate::matrix;
use crate::models::RoomInfo; use crate::models::RoomInfo;
use futures::stream::{self, StreamExt}; use futures::stream::{self, StreamExt, TryStreamExt};
use matrix_sdk::{self, identifiers::RoomId, Client}; use matrix_sdk::{self, identifiers::RoomId, Client};
/// Record the information about a room, including users in it. /// Record the information about a room, including users in it.