Compare commits

...

101 Commits

Author SHA1 Message Date
projectmoon 86df3c5d1f Do not process commands coming from ourselves (help text)
continuous-integration/drone/push Build is passing Details
2024-09-26 09:18:42 +02:00
projectmoon 38a7e50c5c Don't forget to update xbps on final stage too
continuous-integration/drone/push Build is passing Details
2024-09-25 23:06:20 +02:00
projectmoon e309fd1fc6 Sync xbps and update it before everything else.
continuous-integration/drone/push Build is failing Details
2024-09-25 22:56:02 +02:00
projectmoon 9262fe2cac move xbps update after sync
continuous-integration/drone/push Build is failing Details
2024-09-25 22:41:54 +02:00
projectmoon 724a781e7c Attempt to correct error in docker image
continuous-integration/drone/push Build is failing Details
2024-09-25 22:30:30 +02:00
projectmoon ef074beb96 Drone: Update to Rust 1.80 builder
continuous-integration/drone/push Build is failing Details
continuous-integration/drone Build is failing Details
2024-09-25 21:56:03 +02:00
projectmoon 81a69f329a Update for Rust 1.80.x
continuous-integration/drone/push Build is failing Details
2024-09-25 20:57:19 +02:00
projectmoon c9e7efa61d update to sqlx 0.6
continuous-integration/drone/pr Build is passing Details
continuous-integration/drone/push Build is passing Details
2023-04-13 21:12:04 +02:00
projectmoon f295f2b7b6 Update to Matrix SDK 0.6 (#98)
continuous-integration/drone/push Build is passing Details
Quite a few changes involved. Mostly variable renames and a few changes to `await`s.

Not ready yet because bot cannot login due to some arcane error of `expected value at line 1 column 1`.

Co-authored-by: projectmoon <projectmoon@noreply.git.agnos.is>
Reviewed-on: projectmoon/tenebrous-dicebot#98
2023-04-13 19:04:48 +00:00
projectmoon 090ce9be45 Add help topic for variables
continuous-integration/drone/push Build is passing Details
Fixes #60
2023-04-05 07:59:13 +02:00
projectmoon 2a6dff3e07 Update cargo deps
continuous-integration/drone/push Build is failing Details
2023-04-05 07:58:26 +02:00
projectmoon 952f35d53a Rust 1.68 (#99)
continuous-integration/drone/push Build is failing Details
Update to Rust 1.68

Co-authored-by: projectmoon <projectmoon@noreply.git.agnos.is>
Reviewed-on: projectmoon/tenebrous-dicebot#99
2023-04-05 05:57:16 +00:00
projectmoon 552daa4746 Add a game system column to room info (#95)
continuous-integration/drone/push Build is passing Details
Adds a new enum and table in preparation for storing game information about a specific room.

Reviewed-on: projectmoon/tenebrous-dicebot#95
2022-02-02 20:56:50 +00:00
projectmoon c514b85510 Change modifier order in Cthulhu
continuous-integration/drone/push Build is passing Details
2021-11-06 21:23:51 +00:00
projectmoon 6eb81f43d5 Change CofD modifiers to come after dice pool 2021-11-06 21:23:51 +00:00
projectmoon 44b1e0f649 Switch to working (but somewhat bigger) Void docker image
continuous-integration/drone/push Build is passing Details
2021-11-06 13:47:23 +00:00
projectmoon a8ccdc9cce Update rust test image version for CI.
continuous-integration/drone/push Build is passing Details
2021-11-05 19:37:59 +00:00
projectmoon 13ce7b3ee6 Readme update (aka force build)
continuous-integration/drone/push Build is failing Details
2021-11-05 17:53:53 +00:00
projectmoon 6f09a11586 Upgrade to matrix SDK 0.4. 2021-11-05 15:34:16 +00:00
projectmoon ee3ec18e06 Refactor keep-drop parsing into function, better error handling. (#93)
continuous-integration/drone/push Build is passing Details
This commit refactors the keep-drop parsing into two separate
functions: one for extracting keep-drop text, and one for actually
doing something with the extracted values. An intermediate enum is
introduced to contain extracted text, instead of relying on Ok/Err
values directly for figuring out what to do with the values.

This allows us to express "this behavior is correct, and all others
are not" instead of using a "fall back to secondary functionality"
approach.

Reviewed-on: projectmoon/tenebrous-dicebot#93
Co-Authored-By: projectmoon <projectmoon@noreply.git.agnos.is>
Co-Committed-By: projectmoon <projectmoon@noreply.git.agnos.is>
2021-09-30 21:16:00 +00:00
projectmoon 126548d868 Do not panic on invalid dice/sides amount for keep/drop.
continuous-integration/drone/push Build is passing Details
Insted of unwrap(), map error to a nom parser error. Not the best-est
solution, but it is functional. The TooLarge value seems appropriate.
2021-09-26 14:15:12 +00:00
Matthew Sparks 7e7e9e534e Adding None enum to keep/drop, cleaning up matches
continuous-integration/drone/pr Build is passing Details
continuous-integration/drone/push Build is passing Details
2021-09-24 23:03:20 -04:00
Matthew Sparks 2d9853fbf0 Updating README for new drop command
continuous-integration/drone/pr Build is passing Details
2021-09-17 23:15:55 -04:00
Matthew Sparks 3d6210b32d Adding enum for exclusive drop/keep
continuous-integration/drone/pr Build is passing Details
2021-09-17 23:11:13 -04:00
Matthew Sparks 8b5973475f Forgot to fix tests, fixing keep/drop Err case
continuous-integration/drone/pr Build is passing Details
2021-09-17 22:18:23 -04:00
Matthew Sparks 1992ef4e08 Updating roll doc
continuous-integration/drone/pr Build is failing Details
2021-09-17 22:08:51 -04:00
Matthew Sparks f904e3a948 Updating match blocks for keep/drop
continuous-integration/drone/pr Build is failing Details
2021-09-17 21:45:30 -04:00
Matthew Sparks 8317f40f61 Updating README for keep/drop
continuous-integration/drone/pr Build is passing Details
2021-09-16 23:25:26 -04:00
Matthew Sparks 069ee47364 Adding drop function 2021-09-16 22:55:11 -04:00
Matthew Sparks dc242182f4 Fix string comparison in keep/count check, and add test cases 2021-09-07 23:59:49 -04:00
Matthew Sparks 15163ac11d Adding calculations for keep, and adding validation on keep input 2021-09-07 22:10:14 -04:00
Matthew Sparks 1860eaf378 Adding parsing for keeping highest dice 2021-09-06 21:43:46 -04:00
Matthew Sparks 2654887d8c Initial commit to add keep to dice struct and preserve parser test cases 2021-09-06 21:43:46 -04:00
projectmoon 125f3d0cee Fix drone yml to produce docker images again.
continuous-integration/drone/push Build is passing Details
2021-09-06 23:58:05 +00:00
projectmoon a4c3d34a97 Version 0.13.1
continuous-integration/drone/push Build is passing Details
2021-09-06 22:21:24 +00:00
projectmoon 86fbb05e54 Run Drone CI on tags
continuous-integration/drone/push Build is passing Details
2021-09-06 22:18:06 +00:00
projectmoon 661a943672 Readme Updates (#91)
continuous-integration/drone/push Build was killed Details
Add contributing information.

Add support/community section.

Add matrix room badge
Co-Authored-By: projectmoon <projectmoon@noreply.git.agnos.is>
Co-Committed-By: projectmoon <projectmoon@noreply.git.agnos.is>
2021-09-06 22:15:20 +00:00
projectmoon d65715dee6 Remove example room ID from tonic_client
continuous-integration/drone/push Build is passing Details
2021-09-05 20:38:45 +00:00
projectmoon 55a3bfb861 Update readme for crates.io installation. 2021-09-05 20:38:09 +00:00
projectmoon 0050810182 Fix dicebot readme link
continuous-integration/drone/push Build is passing Details
2021-09-05 20:22:42 +00:00
projectmoon 3ba546d4a4 Add metadata to rpc package.
continuous-integration/drone/push Build is passing Details
2021-09-05 20:14:56 +00:00
projectmoon ffded7b572 Add metadata to rpc package. 2021-09-05 20:14:13 +00:00
projectmoon cf93d14913 Version 0.13.0
continuous-integration/drone/push Build is passing Details
2021-09-05 19:08:27 +00:00
projectmoon cf6dd96b34 Update sqlx and refinery to newer versions (#88)
continuous-integration/drone/push Build is passing Details
For some reason, also required rewriting database tests to deal with
tempfile deleting files after scope drop. This never used to occur,
but now it does! So now the unit tests are in a closure where the temp
file is dropped at the end of the test. Really should just use sqlx
migrations, then we can get an in-memory database.

Co-authored-by: projectmoon <projectmoon@agnos.is>
Reviewed-on: projectmoon/tenebrous-dicebot#88
Co-Authored-By: projectmoon <projectmoon@noreply.git.agnos.is>
Co-Committed-By: projectmoon <projectmoon@noreply.git.agnos.is>
2021-09-05 07:56:41 +00:00
projectmoon c8c6f4d6f0 Fix dependency specification for rpc crate in dicebot.
continuous-integration/drone/push Build is passing Details
2021-09-04 23:24:52 +00:00
projectmoon 2488429edb Version 0.12.0
continuous-integration/drone/push Build is passing Details
2021-09-04 22:23:36 +00:00
projectmoon f68d5ffcc1 Update to versioned matrix SDK.
continuous-integration/drone/push Build is passing Details
2021-09-04 21:37:49 +00:00
projectmoon 473e899275 Merge branch 'kg333-master'
continuous-integration/drone/push Build is passing Details
Merge PR #43 from github to fix docker build.
2021-09-03 09:33:02 +00:00
projectmoon 1f03837bfe Merge branch 'master' of https://github.com/kg333/matrix-dicebot into kg333-master 2021-09-03 09:32:48 +00:00
projectmoon 0059e3d133 Revert "Initial prototype of web UI and web API."
continuous-integration/drone/push Build is failing Details
This reverts commit cab856241d.
2021-09-03 09:29:52 +00:00
matthew 915b82d0aa Updating GPG key server; sks-keyservers.net is offline permanently 2021-08-28 00:12:12 +00:00
projectmoon cab856241d Initial prototype of web UI and web API.
continuous-integration/drone/push Build is failing Details
This commit shuffles the entire repository around into multiple crates, bringing with it an in-progress web UI and web AI. It was merged prematurely to allow for dependency upgrades of the matrix SDK.

The build should still only produce the dicebot image.
Co-Authored-By: projectmoon <projectmoon@noreply.git.agnos.is>
Co-Committed-By: projectmoon <projectmoon@noreply.git.agnos.is>
2021-07-15 15:04:50 +00:00
projectmoon 764426382a Convert project to workspace with Tonic for gRPC. (#84)
continuous-integration/drone/push Build is passing Details
Convert project to workspace with Tonic for gRPC.

This commit adds an RPC service to the dicebot, allowing external
applications to control it. The project was converted to a cargo
workspace to house the protobuf definitions in a common crate
(tenebrous-rpc), so that clients and servers can make use of these
protobuf definitions.
Co-Authored-By: projectmoon <projectmoon@noreply.git.agnos.is>
Co-Committed-By: projectmoon <projectmoon@noreply.git.agnos.is>
2021-06-02 21:09:58 +00:00
projectmoon b4321721c4 Minor documentation update.
continuous-integration/drone/push Build is passing Details
2021-05-30 22:53:56 +00:00
projectmoon 494d28486e Remove Box<dyn Command> conversion impls for map in macro.
continuous-integration/drone/push Build is passing Details
2021-05-30 22:49:28 +00:00
projectmoon b7393c1907 Use active room in relevant commands.
continuous-integration/drone/pr Build is passing Details
continuous-integration/drone/push Build is passing Details
2021-05-30 14:19:13 +00:00
projectmoon 3d2eb14cd3 Change room in context to origin_room, add active_room.
The context now knows about origin room (the room where the command
was executed), and the "active room," which is the room that the user
wants the command to apply to. If no active room is defined, then the
origin room acts as the active room. In a public room with the bot,
the active room is also the same as the origin room.
2021-05-30 14:18:56 +00:00
projectmoon 53339282e0 Actually set room when running SetRoomCommand (#79)
continuous-integration/drone/push Build is passing Details
Also sort rooms in get_rooms_for_user for consistency.
Co-Authored-By: projectmoon <projectmoon@noreply.git.agnos.is>
Co-Committed-By: projectmoon <projectmoon@noreply.git.agnos.is>
2021-05-29 20:26:20 +00:00
projectmoon 7050cf037a Remove return statements in Fuseable impl for room search.
continuous-integration/drone/push Build is passing Details
2021-05-29 14:49:24 +00:00
projectmoon 0c0ddafd03 Search for rooms closure as a separate variable.
continuous-integration/drone/pr Build is passing Details
continuous-integration/drone/push Build is passing Details
2021-05-28 21:19:26 +00:00
projectmoon 7f0bdc1e82 Unit test for search_rooms
continuous-integration/drone/push Build is passing Details
continuous-integration/drone/pr Build is passing Details
2021-05-28 21:13:19 +00:00
projectmoon 0ca7ad4db0 Minor fix to command logging.
continuous-integration/drone/push Build is passing Details
continuous-integration/drone/pr Build is passing Details
2021-05-28 15:08:00 +00:00
projectmoon 59be127430 Implement set room command; common code for list and set rooms.
Adds fuzzy room search that can also set by exact ID, and refactors
the code to get room list for user into a common function and struct
for use by both commands.
2021-05-28 15:08:00 +00:00
projectmoon e9c0a184bd Show room list with preformatted text.
continuous-integration/drone/pr Build is passing Details
continuous-integration/drone/push Build is passing Details
2021-05-27 20:47:54 +00:00
projectmoon 589d0e0dbf From<String> for ListRoomsCommand
continuous-integration/drone/push Build is passing Details
continuous-integration/drone/pr Build is passing Details
2021-05-27 15:56:15 +00:00
projectmoon 892ccf73e3 Basic list rooms command. Needs formatting.
continuous-integration/drone/push Build is failing Details
continuous-integration/drone/pr Build is failing Details
2021-05-27 15:52:16 +00:00
projectmoon 896acee5ba Avoid cloned command input with From<String> instead of From<&str>.
continuous-integration/drone/push Build is passing Details
2021-05-27 15:50:43 +00:00
projectmoon d70df44d2a Remove MIT notice from bot event handlers
continuous-integration/drone/push Build is passing Details
2021-05-26 22:40:15 +00:00
projectmoon 5f15e62c6d Remove 'project' from intial informational text in license.
continuous-integration/drone/push Build is failing Details
2021-05-26 22:39:09 +00:00
projectmoon ed3b582aad Matrix SDK isn't MIT anymore.
continuous-integration/drone/push Build is passing Details
2021-05-26 22:35:12 +00:00
projectmoon 49db0062a3 Various improvements to bot responses.
continuous-integration/drone/push Build is passing Details
- Do not display username pill with quoted HTML replies.
 - Do not attempt to create matrix.to link in plain text replies.
 - Move plain text formatting responsibility outside of matrix
   send_message function.
2021-05-26 22:20:53 +00:00
projectmoon 4ae871224a Remove ExecutionError, as it is unnecessary.
continuous-integration/drone/push Build is passing Details
2021-05-26 21:25:32 +00:00
projectmoon 1ebd13e912 Change execution_allowed to a match for shorter reading.
continuous-integration/drone/push Build is passing Details
2021-05-26 21:12:21 +00:00
projectmoon 8f5b6f0636 Replace db query with simple in-memory check of if account already exists.
continuous-integration/drone/push Build is passing Details
2021-05-26 21:04:53 +00:00
projectmoon 5b3d174edc Separate registering and linking accounts.
continuous-integration/drone/pr Build is passing Details
continuous-integration/drone/push Build is passing Details
Can register an account with the bot to manage variables and stuff in
private room, and then separately "link" it with a password, which
makes it available to anything using the bot API (aka web app). Can
also unlink and unregister. Check command no longer validates
password. It just checks and reports your account status.
2021-05-26 15:28:59 +00:00
projectmoon 495df13fe6 Do not automatically create accounts; use enum to show this instead.
continuous-integration/drone/push Build is passing Details
continuous-integration/drone/pr Build is passing Details
Instead of automatically creating a user account entry for any user
executing a command, we use an Account enum which covers both
registered and "transient" unregistered users. If a user registers,
the context has the actual user instance available, with state and
everything. If a user is unregistered, then the account is considered
transient for the request, with only the username available.
2021-05-26 14:20:18 +00:00
projectmoon de92fc8488 Remove nested <p> tags in error messages.
continuous-integration/drone/push Build is passing Details
2021-05-26 07:06:00 +00:00
projectmoon b05129ad9f Localize all command parsing code into trait impls.
continuous-integration/drone/pr Build is passing Details
continuous-integration/drone/push Build is passing Details
This cleans up the command parser a lot, as all of the one or two line
functions and associated imports have been removed. Unfortunately it
does make the command files larger, as two trait impls are required:
one for converting to Box<dyn Command>, and one for converting from
&str to the command type.

Fixes #66.
2021-05-25 23:55:50 +00:00
projectmoon 5d002e5063 Add ability to store user active room, with skeleton accounts.
continuous-integration/drone/pr Build is passing Details
continuous-integration/drone/push Build is passing Details
- Adds a user_state table, currently only with active_room.
 - A user must have an account to take advantage of state.
 - Now, all users will get an 'account' even if they don't explicitly register.
 - Bonus: converts user queries to compile-time checked macros.

To support these automatically created "accounts," the accounts table
now also has an account_status column, indicating if the user is
registered or not (or pending activation--future use).

The User model has been updated with extra properties from the state,
and the user is now carrried in the Context during command execution.
A user is ensured to be created before executing the command.
2021-05-25 22:29:01 +00:00
projectmoon 849a1b6a14 Remove most of Room DB API
continuous-integration/drone/pr Build is passing Details
continuous-integration/drone/push Build is passing Details
2021-05-24 22:25:20 +00:00
projectmoon 97be5d5ccb Add migration to remove room state management tables.
continuous-integration/drone/push Build is failing Details
continuous-integration/drone/pr Build is failing Details
2021-05-24 22:10:41 +00:00
projectmoon 395753e8a9 Remove room state mgmt; let matrix SDK do it on-demand instead.
continuous-integration/drone/push Build is passing Details
continuous-integration/drone/pr Build is passing Details
Fixes #71.

Fixes #20.
2021-05-24 21:45:51 +00:00
projectmoon df0248d99a More useful account registration message.
continuous-integration/drone/push Build is passing Details
2021-05-23 13:58:58 +00:00
projectmoon 76214bc790 Add an account deletion command.
continuous-integration/drone/pr Build is passing Details
continuous-integration/drone/push Build is passing Details
2021-05-22 23:12:17 +00:00
projectmoon 921c4cd644 Update sqlx offline json for user query.
continuous-integration/drone/push Build is passing Details
continuous-integration/drone/pr Build is passing Details
2021-05-22 22:53:01 +00:00
projectmoon 8c2a90e86b Tests for secure commands and user DB API.
continuous-integration/drone/pr Build was killed Details
continuous-integration/drone/push Build is failing Details
2021-05-22 22:48:47 +00:00
projectmoon 926dae57fb Add check password command.
continuous-integration/drone/push Build is failing Details
continuous-integration/drone/pr Build is failing Details
2021-05-22 22:25:00 +00:00
projectmoon 4557498ac6 Improved command logging, sensitive to secure commands.
continuous-integration/drone/push Build is failing Details
continuous-integration/drone/pr Build is failing Details
2021-05-22 22:17:33 +00:00
projectmoon ca34841d86 Functional user account registration.
continuous-integration/drone/push Build is failing Details
continuous-integration/drone/pr Build is failing Details
2021-05-22 14:52:32 +00:00
projectmoon c1ec7366e4 Add user accounts, registration command, secure command valiation. 2021-05-22 14:01:16 +00:00
projectmoon a84d4fd869 Make command parsing case insensitive.
continuous-integration/drone/push Build is passing Details
2021-05-21 22:40:03 +00:00
projectmoon 34ee2c6e5d Consider command execution secure when proper conditions are met.
continuous-integration/drone/push Build is failing Details
- If the room is end-to-end encrypted.
 - If only the sending user and the bot are present in the room.

This lays groundwork for sensitive commands like registering a user
account with the bot.
2021-05-21 22:28:45 +00:00
projectmoon 9de74d05a9 Add an is_secure attribute for commands. 2021-05-21 15:32:08 +00:00
projectmoon 5643677627 Consolidate dice and variable parsers under parser module.
continuous-integration/drone/push Build is passing Details
2021-05-21 14:44:03 +00:00
projectmoon de63fd914e Move commands.rs to commands/mod.rs; move migrate_cli.rs. 2021-05-21 14:35:56 +00:00
projectmoon e73ad118b2 Move some declaration-only modules to mod.rs files in folders. 2021-05-21 14:30:46 +00:00
projectmoon 3d5cda39c8 Consolidate dice module into logic module. 2021-05-21 14:26:58 +00:00
projectmoon 402f236ba7 Remove sled and all related crates from dependencies.
continuous-integration/drone/push Build is passing Details
2021-05-21 14:21:22 +00:00
projectmoon 059538b95d Remove remaining warnings. 2021-05-21 14:14:03 +00:00
projectmoon 4de273db4a Remove sled code; promote sql to top level 2021-05-21 14:05:25 +00:00
projectmoon a33367fada Update dependencies to fix matrix SDK list users bug.
continuous-integration/drone/push Build is passing Details
2021-05-20 15:40:52 +00:00
97 changed files with 6165 additions and 5073 deletions

View File

@ -3,18 +3,20 @@ name: build-and-test
steps:
- name: test
image: rust:1.51
image: rust:1.80
commands:
- apt-get update
- apt-get install -y cmake
- rustup component add rustfmt
- cargo build --verbose --all
- cargo test --verbose --all
- name: docker
image: plugins/docker
when:
branch:
- master
ref:
- refs/tags/v*
- refs/heads/master
settings:
auto_tag: true
username:

3277
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -1,53 +1,6 @@
[package]
name = "tenebrous-dicebot"
version = "0.10.0"
authors = ["Taylor C. Richberger <taywee@gmx.com>", "projectmoon <projectmoon@agnos.is>"]
edition = "2018"
license = 'AGPL-3.0-or-later'
description = 'An async Matrix dice bot for role-playing games'
readme = 'README.md'
repository = 'https://git.agnos.is/projectmoon/matrix-dicebot'
keywords = ["games", "dice", "matrix", "bot"]
categories = ["games"]
[workspace]
[[bin]]
name = "dicebot-migrate"
path = "src/migrate_cli.rs"
[dependencies]
log = "0.4"
tracing-subscriber = "0.2"
toml = "0.5"
nom = "5"
rand = "0.8"
thiserror = "1.0"
itertools = "0.10"
async-trait = "0.1"
url = "2.1"
dirs = "3.0"
indoc = "1.0"
combine = "4.5"
sled = "0.34"
zerocopy = "0.5"
byteorder = "1.3"
futures = "0.3"
memmem = "0.1"
bincode = "1.3"
html2text = "0.2"
phf = { version = "0.8", features = ["macros"] }
matrix-sdk = { git = "https://github.com/matrix-org/matrix-rust-sdk", branch = "master" }
refinery = { version = "0.5", features = ["rusqlite"]}
barrel = { version = "0.6", features = ["sqlite3"] }
tempfile = "3"
[dependencies.sqlx]
version = "0.5"
features = [ "offline", "sqlite", "runtime-tokio-native-tls" ]
[dependencies.serde]
version = "1"
features = ['derive']
[dependencies.tokio]
version = "1"
features = [ "full" ]
members = [
"dicebot",
"rpc"
]

View File

@ -1,16 +1,15 @@
# Builder image with development dependencies.
FROM bougyman/voidlinux:glibc as builder
FROM ghcr.io/void-linux/void-linux:latest-mini-x86_64 as builder
RUN xbps-install -S
RUN xbps-install -yu xbps
RUN xbps-install -Syu
RUN xbps-install -Sy base-devel rustup cargo cmake wget gnupg
RUN xbps-install -Sy base-devel rustup cmake wget gnupg
RUN xbps-install -Sy openssl-devel libstdc++-devel
RUN rustup-init -qy
# Install tini for signal processing and zombie killing
ENV TINI_VERSION v0.19.0
ADD https://github.com/krallin/tini/releases/download/${TINI_VERSION}/tini /usr/local/bin/tini
ADD https://github.com/krallin/tini/releases/download/${TINI_VERSION}/tini.asc /tini.asc
RUN gpg --batch --keyserver hkp://p80.pool.sks-keyservers.net:80 --recv-keys 595E85A6B1B4779EA4DAAEC70B588DFF0527A9B7 \
&& gpg --batch --verify /tini.asc /usr/local/bin/tini
RUN chmod +x /usr/local/bin/tini
# Build dicebot
@ -20,7 +19,10 @@ ADD . ./
RUN . /root/.cargo/env && cargo build --release
# Final image
FROM bougyman/voidlinux:tiny
FROM ghcr.io/void-linux/void-linux:latest-mini-x86_64
RUN xbps-install -S
RUN xbps-install -yu xbps
RUN xbps-install -Syu
RUN xbps-install -Sy ca-certificates libstdc++
COPY --from=builder \
/root/src/target/release/dicebot \

14
LICENSE
View File

@ -1,16 +1,12 @@
This software project is governed by the terms of the Affero GNU
General Public License. Portions of the code come from the original
This software is governed by the terms of the Affero GNU General
Public License. Portions of the code come from the original
MIT-licensed project, and the terms of the MIT license also apply to
those portions. Some portions of the code are also taken from the Rust
Matrix SDK examples, which are governed by the MIT license. In files
that are partially or wholly subject to the MIT license in addition to
the Affero GNU General Public License, this is noted with a header at
the top of the file.
those portions. In files that are partially or wholly subject to the
MIT license in addition to the Affero GNU General Public License, this
is noted with a header at the top of the file.
Original upstream project: https://gitlab.com/Taywee/axfive-matrix-dicebot
Rust Matrix SDK: https://github.com/matrix-org/matrix-rust-sdk
For code from the original project that is governed by the MIT license
in addition to the Affero GNU General Public License, the following
terms apply:

View File

@ -1,6 +1,7 @@
# Tenebrous Dicebot
[![Build Status](https://drone.agnos.is/api/badges/projectmoon/tenebrous-dicebot/status.svg)](https://drone.agnos.is/projectmoon/tenebrous-dicebot)
[![Matrix Chat](https://img.shields.io/matrix/tenebrous:agnos.is?label=matrix&server_fqdn=matrix.org)][matrix-room]
_This repository is hosted on [Agnos.is Git][main-repo] and mirrored
to [GitHub][github-repo]._
@ -24,6 +25,23 @@ System.
* Works in encrypted or unencrypted Matrix rooms.
* Storing variables created by the user.
## Support and Community
The project has a Matrix room at [#tenebrous:agnos.is][matrix-room].
It is also possible to make a post in [GitHub
Discussions][github-discussions].
For reporting bugs, we prefer that you open an issue on
[git.agnos.is][agnosis-git-issues]. However, you may also open an
issue on [GitHub][github-issues].
### Development and Contributions
All development occurs on [git.agnos.is][main-repo]. If you wish to
contribute, please open a pull request there. In some cases, pull
requests from GitHub may be accepted. All contributions must be
licensed under [AGPL 3.0 or later][agpl] to be accepted.
## Building and Installation
### Docker Image
@ -46,6 +64,17 @@ root of the repository.
After pulling or building the image, see [instructions on how to use
the Docker image](#running-the-bot).
### Install from crates.io
The project can be from [crates.io][crates-io]. To install it, execute
`cargo install tenebrous-dicebot`. This will make the following
executables available on your system:
* `dicebot`: Main dicebot executable.
* `dicebot-cmd`: Run dicebot commands from the command line.
* `dicebot_migrate`: Standalone database migrator (not required).
* `tonic_client`: Test client for the gRPC connection (not required).
### Build from Source
Precompiled executables are not yet available. Clone this repository
@ -89,8 +118,16 @@ expressions.
!r 3d12 - 5d2 + 3 - 7d3 + 20d20
```
This system does not yet have the capability to handle things like D&D
5e advantage or disadvantage.
#### Keep/Drop Dice
The bot supports either keeping the highest dice in a roll, or
dropping the highest dice in a roll. This allows the bot to handle
things like D&D 5e advantage or disadvantage.
```
!roll 2d20k1
!r 2d20dh1 + 5
!r 10d10k5 + 10d10dh5 - 2
```
### Storytelling System
@ -241,6 +278,7 @@ The most basic plans are:
* Perhaps some sort of character sheet integration. But for that, we
would need a sheet service.
* Use environment variables instead of config file in Docker image.
* Per-system game rules.
## Credits
@ -254,3 +292,9 @@ support added for Chronicles of Darkness and Call of Cthulhu.
[main-repo]: https://git.agnos.is/projectmoon/tenebrous-dicebot
[github-repo]: https://github.com/ProjectMoon/matrix-dicebot
[roadmap]: https://git.agnos.is/projectmoon/tenebrous-dicebot/wiki/Roadmap
[crates-io]: https://crates.io/crates/tenebrous-dicebot
[matrix-room]: https://matrix.to/#/#tenebrous:agnos.is
[agnosis-git-issues]: https://git.agnos.is/projectmoon/tenebrous-dicebot/issues
[github-discussions]: https://github.com/ProjectMoon/matrix-dicebot/discussions
[github-issues]: https://github.com/ProjectMoon/matrix-dicebot/issues
[agpl]: https://www.gnu.org/licenses/agpl-3.0.en.html

57
dicebot/Cargo.toml Normal file
View File

@ -0,0 +1,57 @@
[package]
name = "tenebrous-dicebot"
version = "0.13.2"
rust-version = "1.68"
authors = ["projectmoon <projectmoon@agnos.is>", "Taylor C. Richberger <taywee@gmx.com>"]
edition = "2018"
license = 'AGPL-3.0-or-later'
description = 'An async Matrix dice bot for role-playing games'
readme = '../README.md'
repository = 'https://git.agnos.is/projectmoon/matrix-dicebot'
keywords = ["games", "dice", "matrix", "bot"]
categories = ["games"]
[build-dependencies]
tonic-build = "0.4"
[dependencies]
# indexmap version locked fixes a dependency cycle.
# indexmap = "=1.6.2"
log = "0.4"
tracing-subscriber = "0.2"
toml = "0.5"
nom = "5"
rand = "0.8"
rust-argon2 = "0.8"
thiserror = "1.0"
itertools = "0.10"
async-trait = "0.1"
url = "2.1"
dirs = "3.0"
indoc = "1.0"
combine = "4.5"
futures = "0.3"
html2text = "0.2"
phf = { version = "0.8", features = ["macros"] }
matrix-sdk = { version = "0.6" }
refinery = { version = "0.8", features = ["rusqlite"]}
barrel = { version = "0.7", features = ["sqlite3"] }
strum = { version = "0.22", features = ["derive"] }
tempfile = "3"
substring = "1.4"
fuse-rust = "0.2"
tonic = "0.4"
prost = "0.7"
tenebrous-rpc = { path = "../rpc", version = "0.1.0" }
[dependencies.sqlx]
version = "0.6"
features = [ "offline", "sqlite", "runtime-tokio-native-tls" ]
[dependencies.serde]
version = "1"
features = ['derive']
[dependencies.tokio]
version = "1"
features = [ "full" ]

View File

@ -18,6 +18,26 @@
]
}
},
"26903a92a7de34df3e227fe599e41ae1bb61612eb80befad398383af36df0ce4": {
"query": "DELETE FROM accounts WHERE user_id = ?",
"describe": {
"columns": [],
"parameters": {
"Right": 1
},
"nullable": []
}
},
"2d4a32735da04509c2e3c4f99bef79ef699964f58ae332b0611f3de088596e1e": {
"query": "INSERT INTO accounts (user_id, password, account_status)\n VALUES (?, ?, ?)\n ON CONFLICT(user_id) DO\n UPDATE SET password = ?, account_status = ?",
"describe": {
"columns": [],
"parameters": {
"Right": 5
},
"nullable": []
}
},
"59313c67900a1a9399389720b522e572f181ae503559cd2b49d6305acb9e2207": {
"query": "SELECT key, value as \"value: i32\" FROM user_variables\n WHERE room_id = ? AND user_id = ?",
"describe": {
@ -60,6 +80,16 @@
]
}
},
"667b26343ce44e1c48ac689ce887ef6a0558a2ce199f7372a5dce58672499c5a": {
"query": "INSERT INTO user_state (user_id, active_room)\n VALUES (?, ?)\n ON CONFLICT(user_id) DO\n UPDATE SET active_room = ?",
"describe": {
"columns": [],
"parameters": {
"Right": 3
},
"nullable": []
}
},
"711d222911c1258365a6a0de1fe00eeec4686fd3589e976e225ad599e7cfc75d": {
"query": "SELECT count(*) as \"count: i32\" FROM user_variables\n WHERE room_id = ? and user_id = ?",
"describe": {
@ -78,66 +108,6 @@
]
}
},
"7248c8ae30bbe4bc5866e80cc277312c7f8cb9af5a8801fd8eaf178fd99eae18": {
"query": "SELECT room_id FROM room_users\n WHERE username = ?",
"describe": {
"columns": [
{
"name": "room_id",
"ordinal": 0,
"type_info": "Text"
}
],
"parameters": {
"Right": 1
},
"nullable": [
false
]
}
},
"97f5d58f62baca51efd8c295ca6737d1240923c69c973621cd0a718ac9eed99f": {
"query": "SELECT room_id, room_name FROM room_info\n WHERE room_id = ?",
"describe": {
"columns": [
{
"name": "room_id",
"ordinal": 0,
"type_info": "Text"
},
{
"name": "room_name",
"ordinal": 1,
"type_info": "Text"
}
],
"parameters": {
"Right": 1
},
"nullable": [
false,
false
]
}
},
"b302d586e5ac4c72c2970361ea5a5936c0b8c6dad10033c626a0ce0404cadb25": {
"query": "SELECT username FROM room_users\n WHERE room_id = ?",
"describe": {
"columns": [
{
"name": "username",
"ordinal": 0,
"type_info": "Text"
}
],
"parameters": {
"Right": 1
},
"nullable": [
false
]
}
},
"bba0fc255e7c30d1d2d9468c68ba38db6e8a13be035aa1152933ba9247b14f8c": {
"query": "SELECT event_id FROM room_events\n WHERE room_id = ? AND event_id = ?",
"describe": {
@ -155,5 +125,15 @@
false
]
}
},
"dce9bb45cf954054a920ee8b53852c6d562e3588d76bbfaa1433d8309d4e4921": {
"query": "DELETE FROM user_state WHERE user_id = ?",
"describe": {
"columns": [],
"parameters": {
"Right": 1
},
"nullable": []
}
}
}

View File

@ -6,23 +6,52 @@
use std::fmt;
use std::ops::{Deref, DerefMut};
//Old stuff, for regular dice rolling. To be moved elsewhere.
/// A basic dice roll, in XdY notation, like "1d4" or "3d6".
/// Optionally supports D&D advantage/disadvantge keep-or-drop
/// functionality.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub struct Dice {
pub(crate) count: u32,
pub(crate) sides: u32,
pub(crate) keep_drop: KeepOrDrop,
}
/// Enum indicating how to handle bonuses or penalties using extra
/// dice. If set to Keep, the roll will keep the highest X number of
/// dice in the roll, and add those together. If set to Drop, the
/// opposite is performed, and the lowest X number of dice are added
/// instead. If set to None, then all dice in the roll are added up as
/// normal.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum KeepOrDrop {
/// Keep only the X highest dice for adding up to the total.
Keep(u32),
/// Keep only the X lowest dice (i.e. drop the highest) for adding
/// up to the total.
Drop(u32),
/// Add up all dice in the roll for the total.
None,
}
impl fmt::Display for Dice {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}d{}", self.count, self.sides)
match self.keep_drop {
KeepOrDrop::Keep(keep) => write!(f, "{}d{}k{}", self.count, self.sides, keep),
KeepOrDrop::Drop(drop) => write!(f, "{}d{}dh{}", self.count, self.sides, drop),
KeepOrDrop::None => write!(f, "{}d{}", self.count, self.sides),
}
}
}
impl Dice {
pub fn new(count: u32, sides: u32) -> Dice {
Dice { count, sides }
pub fn new(count: u32, sides: u32, keep_drop: KeepOrDrop) -> Dice {
Dice {
count,
sides,
keep_drop,
}
}
}

360
dicebot/src/basic/parser.rs Normal file
View File

@ -0,0 +1,360 @@
/**
* In addition to the terms of the AGPL, this file is governed by the
* terms of the MIT license, from the original axfive-matrix-dicebot
* project.
*/
use nom::bytes::complete::take_while;
use nom::error::ErrorKind as NomErrorKind;
use nom::Err as NomErr;
use nom::{
alt, bytes::complete::tag, character::complete::digit1, complete, many0, named,
sequence::tuple, tag, IResult,
};
use super::dice::*;
//******************************
//Legacy Code
//******************************
fn is_whitespace(input: char) -> bool {
input == ' ' || input == '\n' || input == '\t' || input == '\r'
}
/// Eat whitespace, returning it
pub fn eat_whitespace(input: &str) -> IResult<&str, &str> {
let (input, whitespace) = take_while(is_whitespace)(input)?;
Ok((input, whitespace))
}
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
enum Sign {
Plus,
Minus,
}
/// Intermediate parsed value for a keep-drop expression to indicate
/// which one it is.
enum ParsedKeepOrDrop<'a> {
Keep(&'a str),
Drop(&'a str),
NotPresent,
}
macro_rules! too_big {
($input: expr) => {
NomErr::Error(($input, NomErrorKind::TooLarge))
};
}
/// Parse a dice expression. Does not eat whitespace
fn parse_dice(input: &str) -> IResult<&str, Dice> {
let (input, (count, _, sides)) = tuple((digit1, tag("d"), digit1))(input)?;
let count: u32 = count.parse().map_err(|_| too_big!(count))?;
let sides = sides.parse().map_err(|_| too_big!(sides))?;
let (input, keep_drop) = parse_keep_or_drop(input, count)?;
Ok((input, Dice::new(count, sides, keep_drop)))
}
/// Extract keep/drop number as a string. Fails if the value is not a
/// string.
fn parse_keep_or_drop_text<'a>(
symbol: &'a str,
input: &'a str,
) -> IResult<&'a str, ParsedKeepOrDrop<'a>> {
let (parsed_kd, input) = match tuple::<&str, _, (_, _), _>((tag(symbol), digit1))(input) {
// if ok, one of the expressions is present
Ok((rest, (_, kd_expr))) => match symbol {
"k" => (ParsedKeepOrDrop::Keep(kd_expr), rest),
"dh" => (ParsedKeepOrDrop::Drop(kd_expr), rest),
_ => panic!("Unrecogized keep-drop symbol: {}", symbol),
},
// otherwise absent (attempt to keep all dice)
Err(_) => (ParsedKeepOrDrop::NotPresent, input),
};
Ok((input, parsed_kd))
}
/// Parse keep/drop expression, which consits of "k" or "dh" following
/// a dice expression. For example, "1d4h3" or "1d4dh2".
fn parse_keep_or_drop<'a>(input: &'a str, count: u32) -> IResult<&'a str, KeepOrDrop> {
let (input, keep) = parse_keep_or_drop_text("k", input)?;
let (input, drop) = parse_keep_or_drop_text("dh", input)?;
use ParsedKeepOrDrop::*;
let keep_drop: KeepOrDrop = match (keep, drop) {
//Potential valid Keep expression.
(Keep(keep), NotPresent) => match keep.parse().map_err(|_| too_big!(input))? {
_i if _i > count || _i == 0 => Ok(KeepOrDrop::None),
i => Ok(KeepOrDrop::Keep(i)),
},
//Potential valid Drop expression.
(NotPresent, Drop(drop)) => match drop.parse().map_err(|_| too_big!(input))? {
_i if _i >= count => Ok(KeepOrDrop::None),
i => Ok(KeepOrDrop::Drop(i)),
},
//No Keep or Drop specified; regular behavior.
(NotPresent, NotPresent) => Ok(KeepOrDrop::None),
//Anything else is an error.
_ => Err(NomErr::Error((input, NomErrorKind::Many1))),
}?;
Ok((input, keep_drop))
}
// Parse a single digit expression. Does not eat whitespace
fn parse_bonus(input: &str) -> IResult<&str, u32> {
let (input, bonus) = digit1(input)?;
Ok((input, bonus.parse().unwrap()))
}
// Parse a sign expression. Eats whitespace.
fn parse_sign(input: &str) -> IResult<&str, Sign> {
let (input, _) = eat_whitespace(input)?;
named!(sign(&str) -> Sign, alt!(
complete!(tag!("+")) => { |_| Sign::Plus } |
complete!(tag!("-")) => { |_| Sign::Minus }
));
let (input, sign) = sign(input)?;
Ok((input, sign))
}
// Parse an element expression. Eats whitespace.
fn parse_element(input: &str) -> IResult<&str, Element> {
let (input, _) = eat_whitespace(input)?;
named!(element(&str) -> Element, alt!(
parse_dice => { |d| Element::Dice(d) } |
parse_bonus => { |b| Element::Bonus(b) }
));
let (input, element) = element(input)?;
Ok((input, element))
}
// Parse a signed element expression. Eats whitespace.
fn parse_signed_element(input: &str) -> IResult<&str, SignedElement> {
let (input, _) = eat_whitespace(input)?;
let (input, sign) = parse_sign(input)?;
let (input, _) = eat_whitespace(input)?;
let (input, element) = parse_element(input)?;
let element = match sign {
Sign::Plus => SignedElement::Positive(element),
Sign::Minus => SignedElement::Negative(element),
};
Ok((input, element))
}
// Parse a full element expression. Eats whitespace.
pub fn parse_element_expression(input: &str) -> IResult<&str, ElementExpression> {
named!(first_element(&str) -> SignedElement, alt!(
parse_signed_element => { |e| e } |
parse_element => { |e| SignedElement::Positive(e) }
));
let (input, first) = first_element(input)?;
let (input, rest) = if input.trim().is_empty() {
(input, vec![first])
} else {
named!(rest_elements(&str) -> Vec<SignedElement>, many0!(parse_signed_element));
let (input, mut rest) = rest_elements(input)?;
rest.insert(0, first);
(input, rest)
};
Ok((input, ElementExpression(rest)))
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn dice_test() {
assert_eq!(
parse_dice("2d4"),
Ok(("", Dice::new(2, 4, KeepOrDrop::None)))
);
assert_eq!(
parse_dice("20d40"),
Ok(("", Dice::new(20, 40, KeepOrDrop::None)))
);
assert_eq!(
parse_dice("8d7"),
Ok(("", Dice::new(8, 7, KeepOrDrop::None)))
);
assert_eq!(
parse_dice("2d20k1"),
Ok(("", Dice::new(2, 20, KeepOrDrop::Keep(1))))
);
assert_eq!(
parse_dice("100d10k90"),
Ok(("", Dice::new(100, 10, KeepOrDrop::Keep(90))))
);
assert_eq!(
parse_dice("11d10k10"),
Ok(("", Dice::new(11, 10, KeepOrDrop::Keep(10))))
);
assert_eq!(
parse_dice("12d10k11"),
Ok(("", Dice::new(12, 10, KeepOrDrop::Keep(11))))
);
assert_eq!(
parse_dice("12d10k13"),
Ok(("", Dice::new(12, 10, KeepOrDrop::None)))
);
assert_eq!(
parse_dice("12d10k0"),
Ok(("", Dice::new(12, 10, KeepOrDrop::None)))
);
assert_eq!(
parse_dice("20d40dh5"),
Ok(("", Dice::new(20, 40, KeepOrDrop::Drop(5))))
);
assert_eq!(
parse_dice("8d7dh9"),
Ok(("", Dice::new(8, 7, KeepOrDrop::None)))
);
assert_eq!(
parse_dice("8d7dh8"),
Ok(("", Dice::new(8, 7, KeepOrDrop::None)))
);
}
#[test]
fn cant_have_both_keep_and_drop_test() {
let res = parse_dice("1d4k3dh2");
assert!(res.is_err());
match res {
Err(NomErr::Error((_, kind))) => {
assert_eq!(kind, NomErrorKind::Many1);
}
_ => panic!("Got success, expected error"),
}
}
#[test]
fn big_number_of_dice_doesnt_crash_test() {
let res = parse_dice("64378631476346123874527551481376547657868536d4");
assert!(res.is_err());
match res {
Err(NomErr::Error((input, kind))) => {
assert_eq!(kind, NomErrorKind::TooLarge);
assert_eq!(input, "64378631476346123874527551481376547657868536");
}
_ => panic!("Got success, expected error"),
}
}
#[test]
fn big_number_of_sides_doesnt_crash_test() {
let res = parse_dice("1d423562312587425472658956278456298376234876");
assert!(res.is_err());
match res {
Err(NomErr::Error((input, kind))) => {
assert_eq!(kind, NomErrorKind::TooLarge);
assert_eq!(input, "423562312587425472658956278456298376234876");
}
_ => panic!("Got success, expected error"),
}
}
#[test]
fn element_test() {
assert_eq!(
parse_element(" \t\n\r\n 8d7 \n"),
Ok((" \n", Element::Dice(Dice::new(8, 7, KeepOrDrop::None))))
);
assert_eq!(
parse_element(" \t\n\r\n 3d20k2 \n"),
Ok((" \n", Element::Dice(Dice::new(3, 20, KeepOrDrop::Keep(2)))))
);
assert_eq!(
parse_element(" \t\n\r\n 8 \n"),
Ok((" \n", Element::Bonus(8)))
);
}
#[test]
fn signed_element_test() {
assert_eq!(
parse_signed_element("+ 7"),
Ok(("", SignedElement::Positive(Element::Bonus(7))))
);
assert_eq!(
parse_signed_element(" \t\n\r\n- 8 \n"),
Ok((" \n", SignedElement::Negative(Element::Bonus(8))))
);
assert_eq!(
parse_signed_element(" \t\n\r\n- 8d4 \n"),
Ok((
" \n",
SignedElement::Negative(Element::Dice(Dice::new(8, 4, KeepOrDrop::None)))
))
);
assert_eq!(
parse_signed_element(" \t\n\r\n- 8d4k4 \n"),
Ok((
" \n",
SignedElement::Negative(Element::Dice(Dice::new(8, 4, KeepOrDrop::Keep(4))))
))
);
assert_eq!(
parse_signed_element(" \t\n\r\n+ 8d4 \n"),
Ok((
" \n",
SignedElement::Positive(Element::Dice(Dice::new(8, 4, KeepOrDrop::None)))
))
);
}
#[test]
fn element_expression_test() {
assert_eq!(
parse_element_expression("8d4"),
Ok((
"",
ElementExpression(vec![SignedElement::Positive(Element::Dice(Dice::new(
8,
4,
KeepOrDrop::None
)))])
))
);
assert_eq!(
parse_element_expression("\t2d20k1 + 5"),
Ok((
"",
ElementExpression(vec![
SignedElement::Positive(Element::Dice(Dice::new(2, 20, KeepOrDrop::Keep(1)))),
SignedElement::Positive(Element::Bonus(5)),
])
))
);
assert_eq!(
parse_element_expression(" - 8d4 \n "),
Ok((
" \n ",
ElementExpression(vec![SignedElement::Negative(Element::Dice(Dice::new(
8,
4,
KeepOrDrop::None
)))])
))
);
assert_eq!(
parse_element_expression("\t3d4k2 + 7 - 5 - 6d12dh3 + 1d1 + 53 1d5 "),
Ok((
" 1d5 ",
ElementExpression(vec![
SignedElement::Positive(Element::Dice(Dice::new(3, 4, KeepOrDrop::Keep(2)))),
SignedElement::Positive(Element::Bonus(7)),
SignedElement::Negative(Element::Bonus(5)),
SignedElement::Negative(Element::Dice(Dice::new(6, 12, KeepOrDrop::Drop(3)))),
SignedElement::Positive(Element::Dice(Dice::new(1, 1, KeepOrDrop::None))),
SignedElement::Positive(Element::Bonus(53)),
])
))
);
}
}

View File

@ -4,6 +4,7 @@
* project.
*/
use crate::basic::dice;
use crate::basic::dice::KeepOrDrop;
use rand::prelude::*;
use std::fmt;
use std::ops::{Deref, DerefMut};
@ -19,15 +20,27 @@ pub trait Rolled {
}
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct DiceRoll(pub Vec<u32>);
/// array of rolls in order, how many dice to keep, and how many to drop
/// keep indicates how many of the highest dice to keep
/// drop indicates how many of the highest dice to drop
pub struct DiceRoll (pub Vec<u32>, usize, usize);
impl DiceRoll {
pub fn rolls(&self) -> &[u32] {
&self.0
}
pub fn keep(&self) -> usize {
self.1
}
pub fn drop(&self) -> usize {
self.2
}
// only count kept dice in total
pub fn total(&self) -> u32 {
self.0.iter().sum()
self.0[self.2..self.1].iter().sum()
}
}
@ -41,11 +54,21 @@ impl fmt::Display for DiceRoll {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", self.rolled_value())?;
let rolls = self.rolls();
let mut iter = rolls.iter();
let keep = self.keep();
let drop = self.drop();
let mut iter = rolls.iter().enumerate();
if let Some(first) = iter.next() {
write!(f, " ({}", first)?;
if drop != 0 {
write!(f, " ([{}]", first.1)?;
} else {
write!(f, " ({}", first.1)?;
}
for roll in iter {
write!(f, " + {}", roll)?;
if roll.0 >= keep || roll.0 < drop {
write!(f, " + [{}]", roll.1)?;
} else {
write!(f, " + {}", roll.1)?;
}
}
write!(f, ")")?;
}
@ -58,11 +81,17 @@ impl Roll for dice::Dice {
fn roll(&self) -> DiceRoll {
let mut rng = rand::thread_rng();
let rolls: Vec<_> = (0..self.count)
let mut rolls: Vec<_> = (0..self.count)
.map(|_| rng.gen_range(1..=self.sides))
.collect();
// sort rolls in descending order
rolls.sort_by(|a, b| b.cmp(a));
DiceRoll(rolls)
match self.keep_drop {
KeepOrDrop::Keep(k) => DiceRoll(rolls,k as usize, 0),
KeepOrDrop::Drop(dh) => DiceRoll(rolls,self.count as usize, dh as usize),
KeepOrDrop::None => DiceRoll(rolls,self.count as usize, 0),
}
}
}
@ -198,18 +227,26 @@ mod tests {
use super::*;
#[test]
fn dice_roll_display_test() {
assert_eq!(DiceRoll(vec![1, 3, 4]).to_string(), "8 (1 + 3 + 4)");
assert_eq!(DiceRoll(vec![]).to_string(), "0");
assert_eq!(DiceRoll(vec![1, 3, 4], 3, 0).to_string(), "8 (1 + 3 + 4)");
assert_eq!(DiceRoll(vec![], 0, 0).to_string(), "0");
assert_eq!(
DiceRoll(vec![4, 7, 2, 10]).to_string(),
DiceRoll(vec![4, 7, 2, 10], 4, 0).to_string(),
"23 (4 + 7 + 2 + 10)"
);
assert_eq!(
DiceRoll(vec![20, 13, 11, 10], 3, 0).to_string(),
"44 (20 + 13 + 11 + [10])"
);
assert_eq!(
DiceRoll(vec![20, 13, 11, 10], 4, 1).to_string(),
"34 ([20] + 13 + 11 + 10)"
);
}
#[test]
fn element_roll_display_test() {
assert_eq!(
ElementRoll::Dice(DiceRoll(vec![1, 3, 4])).to_string(),
ElementRoll::Dice(DiceRoll(vec![1, 3, 4], 3, 0)).to_string(),
"8 (1 + 3 + 4)"
);
assert_eq!(ElementRoll::Bonus(7).to_string(), "7");
@ -218,11 +255,11 @@ mod tests {
#[test]
fn signed_element_roll_display_test() {
assert_eq!(
SignedElementRoll::Positive(ElementRoll::Dice(DiceRoll(vec![1, 3, 4]))).to_string(),
SignedElementRoll::Positive(ElementRoll::Dice(DiceRoll(vec![1, 3, 4], 3, 0))).to_string(),
"8 (1 + 3 + 4)"
);
assert_eq!(
SignedElementRoll::Negative(ElementRoll::Dice(DiceRoll(vec![1, 3, 4]))).to_string(),
SignedElementRoll::Negative(ElementRoll::Dice(DiceRoll(vec![1, 3, 4], 3, 0))).to_string(),
"-8 (1 + 3 + 4)"
);
assert_eq!(
@ -239,14 +276,14 @@ mod tests {
fn element_expression_roll_display_test() {
assert_eq!(
ElementExpressionRoll(vec![SignedElementRoll::Positive(ElementRoll::Dice(
DiceRoll(vec![1, 3, 4])
DiceRoll(vec![1, 3, 4], 3, 0)
)),])
.to_string(),
"8 (1 + 3 + 4)"
);
assert_eq!(
ElementExpressionRoll(vec![SignedElementRoll::Negative(ElementRoll::Dice(
DiceRoll(vec![1, 3, 4])
DiceRoll(vec![1, 3, 4], 3, 0)
)),])
.to_string(),
"-8 (1 + 3 + 4)"
@ -263,8 +300,8 @@ mod tests {
);
assert_eq!(
ElementExpressionRoll(vec![
SignedElementRoll::Positive(ElementRoll::Dice(DiceRoll(vec![1, 3, 4]))),
SignedElementRoll::Negative(ElementRoll::Dice(DiceRoll(vec![1, 2]))),
SignedElementRoll::Positive(ElementRoll::Dice(DiceRoll(vec![1, 3, 4], 3, 0))),
SignedElementRoll::Negative(ElementRoll::Dice(DiceRoll(vec![1, 2], 2, 0))),
SignedElementRoll::Positive(ElementRoll::Bonus(4)),
SignedElementRoll::Negative(ElementRoll::Bonus(7)),
])
@ -273,13 +310,33 @@ mod tests {
);
assert_eq!(
ElementExpressionRoll(vec![
SignedElementRoll::Negative(ElementRoll::Dice(DiceRoll(vec![1, 3, 4]))),
SignedElementRoll::Positive(ElementRoll::Dice(DiceRoll(vec![1, 2]))),
SignedElementRoll::Negative(ElementRoll::Dice(DiceRoll(vec![1, 3, 4], 3, 0))),
SignedElementRoll::Positive(ElementRoll::Dice(DiceRoll(vec![1, 2], 2, 0))),
SignedElementRoll::Negative(ElementRoll::Bonus(4)),
SignedElementRoll::Positive(ElementRoll::Bonus(7)),
])
.to_string(),
"-2 (-8 (1 + 3 + 4) + 3 (1 + 2) - 4 + 7)"
);
assert_eq!(
ElementExpressionRoll(vec![
SignedElementRoll::Negative(ElementRoll::Dice(DiceRoll(vec![4, 3, 1], 3, 0))),
SignedElementRoll::Positive(ElementRoll::Dice(DiceRoll(vec![12, 2], 1, 0))),
SignedElementRoll::Negative(ElementRoll::Bonus(4)),
SignedElementRoll::Positive(ElementRoll::Bonus(7)),
])
.to_string(),
"7 (-8 (4 + 3 + 1) + 12 (12 + [2]) - 4 + 7)"
);
assert_eq!(
ElementExpressionRoll(vec![
SignedElementRoll::Negative(ElementRoll::Dice(DiceRoll(vec![4, 3, 1], 3, 1))),
SignedElementRoll::Positive(ElementRoll::Dice(DiceRoll(vec![12, 2], 2, 0))),
SignedElementRoll::Negative(ElementRoll::Bonus(4)),
SignedElementRoll::Positive(ElementRoll::Bonus(7)),
])
.to_string(),
"13 (-4 ([4] + 3 + 1) + 14 (12 + 2) - 4 + 7)"
);
}
}

View File

@ -1,9 +1,11 @@
use matrix_sdk::identifiers::room_id;
use matrix_sdk::ruma::room_id;
use matrix_sdk::Client;
use tenebrous_dicebot::commands;
use tenebrous_dicebot::commands::ResponseExtractor;
use tenebrous_dicebot::context::{Context, RoomContext};
use tenebrous_dicebot::db::sqlite::Database;
use tenebrous_dicebot::error::BotError;
use tenebrous_dicebot::models::Account;
use url::Url;
#[tokio::main]
@ -25,12 +27,18 @@ async fn main() -> Result<(), BotError> {
.await?;
let context = Context {
db: db,
matrix_client: &matrix_sdk::Client::new(homeserver)
.expect("Could not create matrix client"),
room: RoomContext {
db,
account: Account::default(),
matrix_client: Client::new(homeserver).await.expect("Could not create matrix client"),
origin_room: RoomContext {
id: &room_id!("!fakeroomid:example.com"),
display_name: "fake room",
display_name: "fake room".to_owned(),
secure: false,
},
active_room: RoomContext {
id: &room_id!("!fakeroomid:example.com"),
display_name: "fake room".to_owned(),
secure: false,
},
username: "@localuser:example.com",
message_body: &input,

View File

@ -1,21 +1,36 @@
//Needed for nested Result handling from tokio. Probably can go away after 1.47.0.
#![type_length_limit = "7605144"]
use futures::try_join;
use log::error;
use matrix_sdk::Client;
use std::env;
use std::sync::{Arc, RwLock};
use tenebrous_dicebot::bot::DiceBot;
use tenebrous_dicebot::config::*;
use tenebrous_dicebot::db::sqlite::Database;
use tenebrous_dicebot::error::BotError;
use tenebrous_dicebot::rpc;
use tenebrous_dicebot::state::DiceBotState;
use tracing_subscriber::filter::EnvFilter;
/// Attempt to create config object and ddatabase connection pool from
/// the given config path. An error is returned if config creation or
/// database pool creation fails for some reason.
async fn init(config_path: &str) -> Result<(Arc<Config>, Database, Client), BotError> {
let cfg = read_config(config_path)?;
let cfg = Arc::new(cfg);
let sqlite_path = format!("{}/dicebot.sqlite", cfg.database_path());
let db = Database::new(&sqlite_path).await?;
let client = tenebrous_dicebot::matrix::create_client(&cfg).await?;
Ok((cfg, db, client))
}
#[tokio::main]
async fn main() {
async fn main() -> Result<(), BotError> {
let filter = if env::var("RUST_LOG").is_ok() {
EnvFilter::from_default_env()
} else {
EnvFilter::new("tenebrous_dicebot=info,dicebot=info,refinery=info")
EnvFilter::new("tonic=info,tenebrous_dicebot=info,dicebot=info,refinery=info")
};
tracing_subscriber::fmt().with_env_filter(filter).init();
@ -23,7 +38,9 @@ async fn main() {
match run().await {
Ok(_) => (),
Err(e) => error!("Error: {}", e),
};
}
Ok(())
}
async fn run() -> Result<(), BotError> {
@ -32,12 +49,22 @@ async fn run() -> Result<(), BotError> {
.next()
.expect("Need a config as an argument");
let cfg = Arc::new(read_config(config_path)?);
let sqlite_path = format!("{}/dicebot.sqlite", cfg.database_path());
let db = Database::new(&sqlite_path).await?;
let (cfg, db, client) = init(&config_path).await?;
let grpc = rpc::serve_grpc(&cfg, &db, &client);
let bot = run_bot(&cfg, &db, &client);
match try_join!(bot, grpc) {
Ok(_) => (),
Err(e) => error!("Error: {:?}", e),
};
Ok(())
}
async fn run_bot(cfg: &Arc<Config>, db: &Database, client: &Client) -> Result<(), BotError> {
let state = Arc::new(RwLock::new(DiceBotState::new(&cfg)));
match DiceBot::new(&cfg, &state, &db) {
match DiceBot::new(cfg, &state, db, client) {
Ok(bot) => bot.run().await?,
Err(e) => println!("Error connecting: {:?}", e),
};

View File

@ -0,0 +1,33 @@
use tenebrous_rpc::protos::dicebot::UserIdRequest;
use tenebrous_rpc::protos::dicebot::{dicebot_client::DicebotClient};
use tonic::{metadata::MetadataValue, transport::Channel, Request};
async fn create_client(
shared_secret: &str,
) -> Result<DicebotClient<Channel>, Box<dyn std::error::Error>> {
let channel = Channel::from_static("http://0.0.0.0:9090")
.connect()
.await?;
let bearer = MetadataValue::from_str(&format!("Bearer {}", shared_secret))?;
let client = DicebotClient::with_interceptor(channel, move |mut req: Request<()>| {
req.metadata_mut().insert("authorization", bearer.clone());
Ok(req)
});
Ok(client)
}
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
let mut client = create_client("example-key").await?;
let request = tonic::Request::new(UserIdRequest {
user_id: "@projectmoon:agnos.is".into(),
});
let response = client.rooms_for_user(request).await?.into_inner();
println!("Rooms: {:?}", response.rooms);
Ok(())
}

View File

@ -0,0 +1,173 @@
use crate::context::{Context, RoomContext};
use crate::db::sqlite::Database;
use crate::error::BotError;
use crate::logic;
use crate::matrix;
use crate::{
commands::{execute_command, ExecutionResult, ResponseExtractor},
models::Account,
};
use futures::stream::{self, StreamExt};
use matrix_sdk::ruma::{OwnedEventId, RoomId};
use matrix_sdk::{self, room::Joined, Client};
use std::clone::Clone;
use std::convert::TryFrom;
/// Handle responding to a single command being executed. Wil print
/// out the full result of that command.
pub(super) async fn handle_single_result(
client: &Client,
cmd_result: &ExecutionResult,
respond_to: &str,
room: &Joined,
event_id: OwnedEventId,
) {
let html = cmd_result.message_html(respond_to);
let plain = cmd_result.message_plain(respond_to);
matrix::send_message(client, room.room_id(), (&html, &plain), Some(event_id)).await;
}
/// Format failure messages nicely in either HTML or plain text. If
/// plain is true, plain-text will be returned. Otherwise, formatted
/// HTML.
fn format_failures(
errors: &[(&str, &BotError)],
commands_executed: usize,
respond_to: &str,
plain: bool,
) -> String {
let respond_to = match plain {
true => respond_to.to_owned(),
false => format!(
"<a href=\"https://matrix.to/#/{}\">{}</a>",
respond_to, respond_to
),
};
let failures: Vec<String> = errors
.iter()
.map(|&(cmd, err)| format!("<strong>{}:</strong> {}", cmd, err))
.collect();
let message = format!(
"{}: Executed {} commands ({} failed)\n\nFailures:\n{}",
respond_to,
commands_executed,
errors.len(),
failures.join("\n")
)
.replace("\n", "<br/>");
match plain {
true => html2text::from_read(message.as_bytes(), message.len()),
false => message,
}
}
/// Handle responding to multiple commands being executed. Will print
/// out how many commands succeeded and failed (if any failed).
pub(super) async fn handle_multiple_results(
client: &Client,
results: &[(String, ExecutionResult)],
respond_to: &str,
room: &Joined,
) {
let user_pill = format!(
"<a href=\"https://matrix.to/#/{}\">{}</a>",
respond_to, respond_to
);
let errors: Vec<(&str, &BotError)> = results
.into_iter()
.filter_map(|(cmd, result)| match result {
Err(e) => Some((cmd.as_ref(), e)),
_ => None,
})
.collect();
let (message, plain) = if errors.len() == 0 {
(
format!("{}: Executed {} commands", user_pill, results.len()),
format!("{}: Executed {} commands", respond_to, results.len()),
)
} else {
(
format_failures(&errors, results.len(), respond_to, false),
format_failures(&errors, results.len(), respond_to, true),
)
};
matrix::send_message(client, room.room_id(), (&message, &plain), None).await;
}
/// Map an account's active room value to an actual matrix room, if
/// the account has an active room. This only retrieves the
/// user-specified active room, and doesn't perform any further
/// filtering.
fn get_account_active_room(client: &Client, account: &Account) -> Result<Option<Joined>, BotError> {
let active_room = account
.registered_user()
.and_then(|u| u.active_room.as_deref())
.map(|room_id| <&RoomId>::try_from(room_id))
.transpose()?
.and_then(|active_room_id| client.get_joined_room(active_room_id));
Ok(active_room)
}
/// Execute a single command in the list of commands. Can fail if the
/// Account value cannot be created/fetched from the database, or if
/// room display names cannot be calculated. Otherwise, the success or
/// error of command execution itself is returned.
async fn execute_single_command(
command: &str,
db: &Database,
client: &Client,
origin_room: &Joined,
sender: &str,
) -> ExecutionResult {
let origin_ctx = RoomContext::new(origin_room, sender).await?;
let account = logic::get_account(db, sender).await?;
let active_room = get_account_active_room(client, &account)?;
// Active room is used in secure command-issuing rooms. In
// "public" rooms, where other users are, treat origin as the
// active room.
let active_room = active_room
.as_ref()
.filter(|_| origin_ctx.secure)
.unwrap_or(origin_room);
let active_ctx = RoomContext::new(active_room, sender).await?;
let ctx = Context {
account,
db: db.clone(),
matrix_client: client.clone(),
origin_room: origin_ctx,
username: &sender,
active_room: active_ctx,
message_body: &command,
};
execute_command(&ctx).await
}
/// Attempt to execute all commands sent to the bot in a message. This
/// asynchronously executes all commands given to it. A Vec of all
/// commands and their execution results are returned.
pub(super) async fn execute(
commands: Vec<&str>,
db: &Database,
client: &Client,
room: &Joined,
sender: &str,
) -> Vec<(String, ExecutionResult)> {
stream::iter(commands)
.then(|command| async move {
let result = execute_single_command(command, db, client, room, sender).await;
(command.to_owned(), result)
})
.collect()
.await
}

View File

@ -0,0 +1,163 @@
use super::DiceBot;
use crate::db::sqlite::Database;
use crate::db::Rooms;
use crate::error::BotError;
use log::{debug, error, info, warn};
use matrix_sdk::ruma::events::room::member::RoomMemberEventContent;
use matrix_sdk::ruma::events::{StrippedStateEvent, SyncMessageLikeEvent};
use matrix_sdk::{self, room::Room, ruma::events::room::message::RoomMessageEventContent};
use matrix_sdk::{Client, DisplayName};
use std::ops::Sub;
use std::time::UNIX_EPOCH;
use std::time::{Duration, SystemTime};
/// Check if a message is recent enough to actually process. If the
/// message is within "oldest_message_age" seconds, this function
/// returns true. If it's older than that, it returns false and logs a
/// debug message.
fn check_message_age(
event: &SyncMessageLikeEvent<RoomMessageEventContent>,
oldest_message_age: u64,
) -> bool {
let sending_time = event
.origin_server_ts()
.to_system_time()
.unwrap_or(UNIX_EPOCH);
let oldest_timestamp = SystemTime::now().sub(Duration::from_secs(oldest_message_age));
if sending_time > oldest_timestamp {
true
} else {
let age = match oldest_timestamp.duration_since(sending_time) {
Ok(n) => format!("{} seconds too old", n.as_secs()),
Err(_) => "before the UNIX epoch".to_owned(),
};
debug!("Ignoring message because it is {}: {:?}", age, event);
false
}
}
/// Determine whether or not to process a received message. This check
/// is necessary in addition to the event processing check because we
/// may receive message events when entering a room for the first
/// time, and we don't want to respond to things before the bot was in
/// the channel, but we do want to respond to things that were sent if
/// the bot left and rejoined quickly.
async fn should_process_message<'a>(
bot: &DiceBot,
event: &SyncMessageLikeEvent<RoomMessageEventContent>,
) -> Result<(String, String), BotError> {
//Ignore messages that are older than configured duration.
if !check_message_age(event, bot.config.oldest_message_age()) {
let state_check = bot.state.read().unwrap();
if !((*state_check).logged_skipped_old_messages()) {
drop(state_check);
let mut state = bot.state.write().unwrap();
(*state).skipped_old_messages();
}
return Err(BotError::ShouldNotProcessError);
}
let msg_body: String = event
.as_original()
.map(|e| e.content.body())
.map(str::to_string)
.unwrap_or_else(|| String::new());
let sender_username: String = format!(
"@{}:{}",
event.sender().localpart(),
event.sender().server_name()
);
// Do not process messages from the bot itself. Otherwise it might
// try to execute its own commands.
let bot_username = bot
.client
.user_id()
.map(|u| format!("@{}:{}", u.localpart(), u.server_name()))
.unwrap_or_default();
if sender_username == bot_username {
return Err(BotError::ShouldNotProcessError);
}
Ok((msg_body, sender_username))
}
async fn should_process_event(db: &Database, room_id: &str, event_id: &str) -> bool {
db.should_process(room_id, event_id)
.await
.unwrap_or_else(|e| {
error!(
"Database error when checking if we should process an event: {}",
e.to_string()
);
false
})
}
pub(super) async fn on_stripped_state_member(
event: StrippedStateEvent<RoomMemberEventContent>,
client: Client,
room: Room,
) {
let room = match room {
Room::Invited(invited_room) => invited_room,
_ => return,
};
if room.own_user_id().as_str() != event.state_key {
return;
}
info!(
"Autojoining room {}",
room.display_name()
.await
.ok()
.unwrap_or_else(|| DisplayName::Named("[error]".to_string()))
);
if let Err(e) = client.join_room_by_id(&room.room_id()).await {
warn!("Could not join room: {}", e.to_string())
}
}
pub(super) async fn on_room_message(
event: SyncMessageLikeEvent<RoomMessageEventContent>,
room: Room,
bot: DiceBot,
) {
let room = match room {
Room::Joined(joined_room) => joined_room,
_ => return,
};
let room_id = room.room_id().as_str();
if !should_process_event(&bot.db, room_id, event.event_id().as_str()).await {
return;
}
let (msg_body, sender_username) =
if let Ok((msg_body, sender_username)) = should_process_message(&bot, &event).await {
(msg_body, sender_username)
} else {
return;
};
let results = bot
.execute_commands(&room, &sender_username, &msg_body)
.await;
bot.handle_results(
&room,
&sender_username,
event.event_id().to_owned(),
results,
)
.await;
}

172
dicebot/src/bot/mod.rs Normal file
View File

@ -0,0 +1,172 @@
use crate::commands::ExecutionResult;
use crate::config::*;
use crate::db::sqlite::Database;
use crate::db::DbState;
use crate::error::BotError;
use crate::state::DiceBotState;
use log::info;
use matrix_sdk::room::Room;
use matrix_sdk::ruma::events::room::message::RoomMessageEventContent;
use matrix_sdk::ruma::events::SyncMessageLikeEvent;
use matrix_sdk::ruma::OwnedEventId;
use matrix_sdk::{self, room::Joined, Client};
use matrix_sdk::config::SyncSettings;
use std::clone::Clone;
use std::sync::{Arc, RwLock};
mod command_execution;
pub mod event_handlers;
/// How many commands can be in one message. If the amount is higher
/// than this, we reject execution.
const MAX_COMMANDS_PER_MESSAGE: usize = 50;
/// The DiceBot struct represents an active dice bot. The bot is not
/// connected to Matrix until its run() function is called.
#[derive(Clone)]
pub struct DiceBot {
/// A reference to the configuration read in on application start.
config: Arc<Config>,
/// The matrix client.
client: Client,
/// State of the dicebot
state: Arc<RwLock<DiceBotState>>,
/// Active database layer
db: Database,
}
impl DiceBot {
/// Create a new dicebot with the given configuration and state
/// actor. This function returns a Result because it is possible
/// for client creation to fail for some reason (e.g. invalid
/// homeserver URL).
pub fn new(
config: &Arc<Config>,
state: &Arc<RwLock<DiceBotState>>,
db: &Database,
client: &Client,
) -> Result<Self, BotError> {
Ok(DiceBot {
client: client.clone(),
config: config.clone(),
state: state.clone(),
db: db.clone(),
})
}
/// Logs in to matrix and potentially records a new device ID. If
/// no device ID is found in the database, a new one will be
/// generated by the matrix SDK, and we will store it.
async fn login(&self, client: &Client) -> Result<(), BotError> {
let username = self.config.matrix_username();
let password = self.config.matrix_password();
// Pull device ID from database, if it exists. Then write it
// to DB if the library generated one for us.
let device_id: Option<String> = self.db.get_device_id().await?;
let device_id: Option<&str> = device_id.as_deref();
let no_device_ld_login = || client.login_username(username, password);
let device_id_login = |id| client.login_username(username, password).device_id(id);
let login = device_id.map_or_else(no_device_ld_login, device_id_login);
login.send().await?;
if device_id.is_none() {
let device_id = client.device_id().ok_or(BotError::NoDeviceIdFound)?;
self.db.set_device_id(device_id.as_str()).await?;
info!("Recorded new device ID: {}", device_id.as_str());
} else {
info!("Using existing device ID: {}", device_id.unwrap());
}
info!("Logged in as {}", username);
Ok(())
}
async fn bind_events(&self) {
//on room message: need closure to pass bot ref in.
self.client
.add_event_handler({
let bot: DiceBot = self.clone();
move |event: SyncMessageLikeEvent<RoomMessageEventContent>, room: Room| {
let bot = bot.clone();
async move { event_handlers::on_room_message(event, room, bot).await }
}
});
//auto-join handler
self.client
.add_event_handler(event_handlers::on_stripped_state_member);
}
/// Logs the bot in to Matrix and listens for events until program
/// terminated, or a panic occurs. Originally adapted from the
/// matrix-rust-sdk command bot example.
pub async fn run(self) -> Result<(), BotError> {
let client = self.client.clone();
self.login(&client).await?;
self.bind_events().await;
info!("Listening for commands");
// TODO replace with sync_with_callback for cleaner shutdown
// process.
client.sync(SyncSettings::default()).await?;
Ok(())
}
async fn execute_commands(
&self,
room: &Joined,
sender: &str,
msg_body: &str,
) -> Vec<(String, ExecutionResult)> {
let commands: Vec<&str> = msg_body
.lines()
.filter(|line| line.starts_with("!"))
.take(MAX_COMMANDS_PER_MESSAGE + 1)
.collect();
//Up to 50 commands allowed, otherwise we send back an error.
let results: Vec<(String, ExecutionResult)> = if commands.len() < MAX_COMMANDS_PER_MESSAGE {
command_execution::execute(commands, &self.db, &self.client, room, sender).await
} else {
vec![("".to_owned(), Err(BotError::MessageTooLarge))]
};
results
}
pub async fn handle_results(
&self,
room: &Joined,
sender_username: &str,
event_id: OwnedEventId,
results: Vec<(String, ExecutionResult)>,
) {
if results.len() >= 1 {
if results.len() == 1 {
command_execution::handle_single_result(
&self.client,
&results[0].1,
sender_username,
&room,
event_id,
)
.await;
} else if results.len() > 1 {
command_execution::handle_multiple_results(
&self.client,
&results,
sender_username,
&room,
)
.await;
}
}
}
}

View File

@ -1,6 +1,6 @@
use crate::context::Context;
use crate::error::{BotError, DiceRollingError};
use crate::parser::{Amount, Element, Operator};
use crate::parser::dice::{Amount, Element, Operator};
use itertools::Itertools;
use std::convert::TryFrom;
use std::fmt;
@ -308,7 +308,7 @@ pub async fn roll_pool(pool: &DicePoolWithContext<'_>) -> Result<RolledDicePool,
return Err(DiceRollingError::ExpressionTooLarge.into());
}
let num_dice = crate::dice::calculate_dice_amount(&pool.0.amounts, &pool.1).await?;
let num_dice = crate::logic::calculate_dice_amount(&pool.0.amounts, &pool.1).await?;
let mut roller = RngDieRoller(rand::thread_rng());
if num_dice > 0 {
@ -326,14 +326,15 @@ pub async fn roll_pool(pool: &DicePoolWithContext<'_>) -> Result<RolledDicePool,
mod tests {
use super::*;
use crate::db::sqlite::Database;
use crate::db::sqlite::Variables;
use crate::db::Variables;
use url::Url;
macro_rules! dummy_room {
() => {
crate::context::RoomContext {
id: &matrix_sdk::identifiers::room_id!("!fakeroomid:example.com"),
display_name: "displayname",
id: &matrix_sdk::ruma::room_id!("!fakeroomid:example.com"),
display_name: "displayname".to_owned(),
secure: false,
}
};
}
@ -482,9 +483,11 @@ mod tests {
.unwrap();
let ctx = Context {
account: crate::models::Account::default(),
db: db,
matrix_client: &matrix_sdk::Client::new(homeserver).unwrap(),
room: dummy_room!(),
matrix_client: matrix_sdk::Client::new(homeserver).await.unwrap(),
origin_room: dummy_room!(),
active_room: dummy_room!(),
username: "username",
message_body: "message",
};
@ -522,9 +525,11 @@ mod tests {
let homeserver = Url::parse("http://example.com").unwrap();
let ctx = Context {
account: crate::models::Account::default(),
db: db,
matrix_client: &matrix_sdk::Client::new(homeserver).unwrap(),
room: dummy_room!(),
matrix_client: matrix_sdk::Client::new(homeserver).await.unwrap(),
origin_room: dummy_room!(),
active_room: dummy_room!(),
username: "username",
message_body: "message",
};
@ -559,16 +564,23 @@ mod tests {
let homeserver = Url::parse("http://example.com").unwrap();
let ctx = Context {
account: crate::models::Account::default(),
db: db.clone(),
matrix_client: &matrix_sdk::Client::new(homeserver).unwrap(),
room: dummy_room!(),
matrix_client: matrix_sdk::Client::new(homeserver).await.unwrap(),
origin_room: dummy_room!(),
active_room: dummy_room!(),
username: "username",
message_body: "message",
};
db.set_user_variable(&ctx.username, &ctx.room.id.as_str(), "myvariable", 10)
.await
.expect("could not set myvariable to 10");
db.set_user_variable(
&ctx.username,
&ctx.origin_room.id.as_str(),
"myvariable",
10,
)
.await
.expect("could not set myvariable to 10");
let amounts = vec![Amount {
operator: Operator::Plus,
@ -578,7 +590,7 @@ mod tests {
let pool = DicePool::new(amounts, DicePoolModifiers::default());
assert_eq!(
crate::dice::calculate_dice_amount(&pool.amounts, &ctx)
crate::logic::calculate_dice_amount(&pool.amounts, &ctx)
.await
.unwrap(),
10

View File

@ -1,6 +1,6 @@
use crate::cofd::dice::{DicePool, DicePoolModifiers, DicePoolQuality};
use crate::error::BotError;
use crate::parser::{parse_amounts, DiceParsingError};
use crate::parser::dice::{parse_amounts, DiceParsingError};
use combine::parser::char::{digit, spaces, string};
use combine::{choice, count, many1, one_of, Parser};
@ -45,13 +45,13 @@ pub fn parse_modifiers(input: &str) -> Result<DicePoolModifiers, DiceParsingErro
let (result, rest) = parser.parse(input)?;
if rest.len() == 0 {
convert_to_info(&result)
convert_to_modifiers(&result)
} else {
Err(DiceParsingError::UnconsumedInput)
}
}
fn convert_to_info(parsed: &Vec<ParsedInfo>) -> Result<DicePoolModifiers, DiceParsingError> {
fn convert_to_modifiers(parsed: &Vec<ParsedInfo>) -> Result<DicePoolModifiers, DiceParsingError> {
use ParsedInfo::*;
if parsed.len() == 0 {
Ok(DicePoolModifiers::default())
@ -79,19 +79,8 @@ fn convert_to_info(parsed: &Vec<ParsedInfo>) -> Result<DicePoolModifiers, DicePa
}
pub fn parse_dice_pool(input: &str) -> Result<DicePool, BotError> {
//The "modifiers:" part is optional. Assume amounts if no modifier
//section found.
let split = input.split(":").collect::<Vec<_>>();
let (modifiers_str, amounts_str) = (match split[..] {
[amounts] => Ok(("", amounts)),
[modifiers, amounts] => Ok((modifiers, amounts)),
_ => Err(BotError::DiceParsingError(
DiceParsingError::UnconsumedInput,
)),
})?;
let (amounts, modifiers_str) = parse_amounts(input)?;
let modifiers = parse_modifiers(modifiers_str)?;
let amounts = parse_amounts(&amounts_str)?;
Ok(DicePool::new(amounts, modifiers))
}
@ -175,7 +164,7 @@ mod tests {
#[test]
fn dice_pool_number_with_quality() {
let result = parse_dice_pool("n:8");
let result = parse_dice_pool("8 n");
assert!(result.is_ok());
assert_eq!(
result.unwrap(),
@ -186,7 +175,7 @@ mod tests {
#[test]
fn dice_pool_number_with_success_change() {
let modifiers = DicePoolModifiers::custom_exceptional_on(3);
let result = parse_dice_pool("s3:8");
let result = parse_dice_pool("8 s3");
assert!(result.is_ok());
assert_eq!(result.unwrap(), DicePool::easy_with_modifiers(8, modifiers));
}
@ -194,14 +183,14 @@ mod tests {
#[test]
fn dice_pool_with_quality_and_success_change() {
let modifiers = DicePoolModifiers::custom(DicePoolQuality::Rote, 3);
let result = parse_dice_pool("rs3:8");
let result = parse_dice_pool("8 rs3");
assert!(result.is_ok());
assert_eq!(result.unwrap(), DicePool::easy_with_modifiers(8, modifiers));
}
#[test]
fn dice_pool_complex_expression_test() {
use crate::parser::*;
use crate::parser::dice::*;
let modifiers = DicePoolModifiers::custom(DicePoolQuality::Rote, 3);
let amounts = vec![
Amount {
@ -224,20 +213,20 @@ mod tests {
let expected = DicePool::new(amounts, modifiers);
let result = parse_dice_pool("rs3:8+10-2+varname");
let result = parse_dice_pool("8+10-2+varname rs3");
assert!(result.is_ok());
assert_eq!(result.unwrap(), expected);
let result = parse_dice_pool("rs3:8+10- 2 + varname");
let result = parse_dice_pool("8+10- 2 + varname rs3");
assert!(result.is_ok());
assert_eq!(result.unwrap(), expected);
let result = parse_dice_pool("rs3 : 8+ 10 -2 + varname");
let result = parse_dice_pool("8+ 10 -2 + varname rs3");
assert!(result.is_ok());
assert_eq!(result.unwrap(), expected);
//This one has tabs in it.
let result = parse_dice_pool(" r s3 : 8 + 10 -2 + varname");
let result = parse_dice_pool(" 8 + 10 -2 + varname r s3");
assert!(result.is_ok());
assert_eq!(result.unwrap(), expected);
}

View File

@ -0,0 +1,48 @@
use super::{Command, Execution, ExecutionResult};
use crate::basic::dice::ElementExpression;
use crate::basic::parser::parse_element_expression;
use crate::basic::roll::Roll;
use crate::context::Context;
use crate::error::BotError;
use async_trait::async_trait;
use nom::Err as NomErr;
use std::convert::TryFrom;
pub struct RollCommand(pub ElementExpression);
impl TryFrom<String> for RollCommand {
type Error = BotError;
fn try_from(input: String) -> Result<Self, Self::Error> {
let result = parse_element_expression(&input);
match result {
Ok((rest, expression)) if rest.len() == 0 => Ok(RollCommand(expression)),
//"Legacy code boundary": translates Nom errors into BotErrors.
Ok(_) => Err(BotError::NomParserIncomplete),
Err(NomErr::Error(e)) => Err(BotError::NomParserError(e.1)),
Err(NomErr::Failure(e)) => Err(BotError::NomParserError(e.1)),
Err(NomErr::Incomplete(_)) => Err(BotError::NomParserIncomplete),
}
}
}
#[async_trait]
impl Command for RollCommand {
fn name(&self) -> &'static str {
"roll regular dice"
}
fn is_secure(&self) -> bool {
false
}
async fn execute(&self, _ctx: &Context<'_>) -> ExecutionResult {
let roll = self.0.roll();
let html = format!(
"<strong>Dice:</strong> {}</p><p><strong>Result</strong>: {}",
self.0, roll
);
Execution::success(html)
}
}

View File

@ -1,16 +1,39 @@
use super::{Command, Execution, ExecutionResult};
use crate::cofd::dice::{roll_pool, DicePool, DicePoolWithContext};
use crate::cofd::parser::{create_chance_die, parse_dice_pool};
use crate::context::Context;
use crate::error::BotError;
use async_trait::async_trait;
use std::convert::TryFrom;
pub struct PoolRollCommand(pub DicePool);
impl PoolRollCommand {
pub fn chance_die() -> Result<PoolRollCommand, BotError> {
let pool = create_chance_die()?;
Ok(PoolRollCommand(pool))
}
}
impl TryFrom<String> for PoolRollCommand {
type Error = BotError;
fn try_from(input: String) -> Result<Self, Self::Error> {
let pool = parse_dice_pool(&input)?;
Ok(PoolRollCommand(pool))
}
}
#[async_trait]
impl Command for PoolRollCommand {
fn name(&self) -> &'static str {
"roll dice pool"
}
fn is_secure(&self) -> bool {
false
}
async fn execute(&self, ctx: &Context<'_>) -> ExecutionResult {
let pool_with_ctx = DicePoolWithContext(&self.0, ctx);
let rolled_pool = roll_pool(&pool_with_ctx).await?;

View File

@ -4,16 +4,32 @@ use crate::cthulhu::dice::{
advancement_roll, regular_roll, AdvancementRoll, AdvancementRollWithContext, DiceRoll,
DiceRollWithContext,
};
use crate::cthulhu::parser::{parse_advancement_roll, parse_regular_roll};
use crate::error::BotError;
use async_trait::async_trait;
use std::convert::TryFrom;
pub struct CthRoll(pub DiceRoll);
impl TryFrom<String> for CthRoll {
type Error = BotError;
fn try_from(input: String) -> Result<Self, Self::Error> {
let roll = parse_regular_roll(&input)?;
Ok(CthRoll(roll))
}
}
#[async_trait]
impl Command for CthRoll {
fn name(&self) -> &'static str {
"roll percentile dice"
}
fn is_secure(&self) -> bool {
false
}
async fn execute(&self, ctx: &Context<'_>) -> ExecutionResult {
let roll_with_ctx = DiceRollWithContext(&self.0, ctx);
let executed_roll = regular_roll(&roll_with_ctx).await?;
@ -29,12 +45,25 @@ impl Command for CthRoll {
pub struct CthAdvanceRoll(pub AdvancementRoll);
impl TryFrom<String> for CthAdvanceRoll {
type Error = BotError;
fn try_from(input: String) -> Result<Self, Self::Error> {
let roll = parse_advancement_roll(&input)?;
Ok(CthAdvanceRoll(roll))
}
}
#[async_trait]
impl Command for CthAdvanceRoll {
fn name(&self) -> &'static str {
"roll skill advancement dice"
}
fn is_secure(&self) -> bool {
false
}
async fn execute(&self, ctx: &Context<'_>) -> ExecutionResult {
let roll_with_ctx = AdvancementRollWithContext(&self.0, ctx);
let executed_roll = advancement_roll(&roll_with_ctx).await?;

View File

@ -0,0 +1,200 @@
use super::{Command, Execution, ExecutionResult};
use crate::db::Users;
use crate::error::BotError::{AccountDoesNotExist, PasswordCreationError};
use crate::logic::hash_password;
use crate::models::{AccountStatus, User};
use crate::{context::Context, error::BotError};
use async_trait::async_trait;
use std::convert::{Into, TryFrom};
pub struct RegisterCommand;
impl TryFrom<String> for RegisterCommand {
type Error = BotError;
fn try_from(_: String) -> Result<Self, Self::Error> {
Ok(RegisterCommand)
}
}
#[async_trait]
impl Command for RegisterCommand {
fn name(&self) -> &'static str {
"register user account"
}
fn is_secure(&self) -> bool {
true
}
async fn execute(&self, ctx: &Context<'_>) -> ExecutionResult {
if ctx.account.is_registered() {
return Err(BotError::AccountAlreadyExists);
}
let user = User {
username: ctx.username.to_owned(),
password: None,
account_status: AccountStatus::Registered,
..Default::default()
};
ctx.db.upsert_user(&user).await?;
Execution::success(format!(
"User account {} registered for bot commands.",
ctx.username
))
}
}
pub struct UnlinkCommand(pub String);
impl TryFrom<String> for UnlinkCommand {
type Error = BotError;
fn try_from(value: String) -> Result<Self, Self::Error> {
Ok(UnlinkCommand(value))
}
}
#[async_trait]
impl Command for UnlinkCommand {
fn name(&self) -> &'static str {
"unlink user accountx from external applications"
}
fn is_secure(&self) -> bool {
true
}
async fn execute(&self, ctx: &Context<'_>) -> ExecutionResult {
let mut user = ctx
.db
.get_user(&ctx.username)
.await?
.ok_or(BotError::AccountDoesNotExist)?;
user.password = None;
ctx.db.upsert_user(&user).await?;
Execution::success(format!(
"Accounted {} is now inaccessible to external applications.",
ctx.username
))
}
}
pub struct LinkCommand(pub String);
impl TryFrom<String> for LinkCommand {
type Error = BotError;
fn try_from(value: String) -> Result<Self, Self::Error> {
Ok(LinkCommand(value))
}
}
#[async_trait]
impl Command for LinkCommand {
fn name(&self) -> &'static str {
"link user account to external applications"
}
fn is_secure(&self) -> bool {
true
}
async fn execute(&self, ctx: &Context<'_>) -> ExecutionResult {
let mut user = ctx
.db
.get_user(&ctx.username)
.await?
.ok_or(BotError::AccountDoesNotExist)?;
let pw_hash = hash_password(&self.0).map_err(|e| PasswordCreationError(e))?;
user.password = Some(pw_hash);
ctx.db.upsert_user(&user).await?;
Execution::success(format!(
"Accounted now available for external use. Please log in to \
external applications with username {} and the password you set.",
ctx.username
))
}
}
pub struct CheckCommand;
impl TryFrom<String> for CheckCommand {
type Error = BotError;
fn try_from(_: String) -> Result<Self, Self::Error> {
Ok(CheckCommand)
}
}
#[async_trait]
impl Command for CheckCommand {
fn name(&self) -> &'static str {
"check user account status"
}
fn is_secure(&self) -> bool {
true
}
async fn execute(&self, ctx: &Context<'_>) -> ExecutionResult {
let user = ctx.db.get_user(&ctx.username).await?;
match user {
Some(user) => match user.password {
Some(_) => Execution::success(
"Account exists, and is available to external applications with a password. \
If you forgot your password, change it with !link."
.to_string(),
),
None => Execution::success(
"Account exists, but is not available to external applications.".to_string(),
),
},
None => Execution::success(
"No account registered. Only simple commands in public rooms are available."
.to_string(),
),
}
}
}
pub struct UnregisterCommand;
impl TryFrom<String> for UnregisterCommand {
type Error = BotError;
fn try_from(_: String) -> Result<Self, Self::Error> {
Ok(UnregisterCommand)
}
}
#[async_trait]
impl Command for UnregisterCommand {
fn name(&self) -> &'static str {
"unregister user account"
}
fn is_secure(&self) -> bool {
true
}
async fn execute(&self, ctx: &Context<'_>) -> ExecutionResult {
let user = ctx.db.get_user(&ctx.username).await?;
match user {
Some(_) => {
ctx.db.delete_user(&ctx.username).await?;
Execution::success("Your user account has been removed.".to_string())
}
None => Err(AccountDoesNotExist.into()),
}
}
}

View File

@ -1,16 +1,31 @@
use super::{Command, Execution, ExecutionResult};
use crate::context::Context;
use crate::help::HelpTopic;
use crate::error::BotError;
use crate::help::{parse_help_topic, HelpTopic};
use async_trait::async_trait;
use std::convert::TryFrom;
pub struct HelpCommand(pub Option<HelpTopic>);
impl TryFrom<String> for HelpCommand {
type Error = BotError;
fn try_from(input: String) -> Result<Self, Self::Error> {
let topic = parse_help_topic(&input);
Ok(HelpCommand(topic))
}
}
#[async_trait]
impl Command for HelpCommand {
fn name(&self) -> &'static str {
"help information"
}
fn is_secure(&self) -> bool {
false
}
async fn execute(&self, _ctx: &Context<'_>) -> ExecutionResult {
let help = match &self.0 {
Some(topic) => topic.message(),

301
dicebot/src/commands/mod.rs Normal file
View File

@ -0,0 +1,301 @@
use crate::context::Context;
use crate::error::BotError;
use async_trait::async_trait;
use log::{error, info};
use thiserror::Error;
pub mod basic_rolling;
pub mod cofd;
pub mod cthulhu;
pub mod management;
pub mod misc;
pub mod parser;
pub mod rooms;
pub mod variables;
/// A custom error type specifically related to parsing command text.
/// Does not wrap an execution failure.
#[derive(Error, Debug)]
pub enum CommandError {
#[error("invalid command: {0}")]
InvalidCommand(String),
#[error("command can only be executed from encrypted direct message")]
InsecureExecution,
#[error("ignored command")]
IgnoredCommand,
}
/// A successfully executed command returns a message to be sent back
/// to the user in HTML (plain text used as a fallback by message
/// formatter).
#[derive(Debug)]
pub struct Execution {
html: String,
}
impl Execution {
pub fn success(html: String) -> ExecutionResult {
Ok(Execution { html })
}
/// Response message in HTML.
pub fn html(&self) -> String {
self.html.clone()
}
}
/// Wraps either a successful command execution response, or an error
/// that occurred.
pub type ExecutionResult = Result<Execution, BotError>;
/// Extract response messages out of a type, whether it is success or
/// failure.
pub trait ResponseExtractor {
/// HTML representation of the message, directly mentioning the
/// username.
fn message_html(&self, username: &str) -> String;
fn message_plain(&self, username: &str) -> String;
}
impl ResponseExtractor for ExecutionResult {
/// Error message in bolded HTML.
fn message_html(&self, username: &str) -> String {
// TODO use user display name too (element seems to render this
// without display name)
let username = format!(
"<a href=\"https://matrix.to/#/{}\">{}</a>",
username, username
);
match self {
Ok(resp) => format!("<p>{}</p>", resp.html).replace("\n", "<br/>"),
Err(e) => format!("<p>{}: <strong>{}</strong></p>", username, e).replace("\n", "<br/>"),
}
}
fn message_plain(&self, username: &str) -> String {
let message = match self {
Ok(resp) => format!("{}", resp.html),
Err(e) => format!("{}", e),
};
format!(
"{}:\n{}",
username,
html2text::from_read(message.as_bytes(), message.len())
)
}
}
/// The trait that any command that can be executed must implement.
#[async_trait]
pub trait Command: Send + Sync {
async fn execute(&self, ctx: &Context<'_>) -> ExecutionResult;
fn name(&self) -> &'static str;
fn is_secure(&self) -> bool;
}
/// Determine if we are allowed to execute this command. Currently the
/// rules are that secure commands must be executed in secure rooms
/// (encrypted + direct), and anything else can be executed where
/// ever. Later, we can add stuff like admin/regular user power
/// separation, etc.
fn execution_allowed(cmd: &(impl Command + ?Sized), ctx: &Context<'_>) -> Result<(), CommandError> {
match cmd {
cmd if cmd.is_secure() && ctx.is_secure() => Ok(()),
cmd if cmd.is_secure() && !ctx.is_secure() => Err(CommandError::InsecureExecution),
_ => Ok(()),
}
}
/// Attempt to execute a command, and return the content that should
/// go back to Matrix, if the command was executed, whether or not the
/// command was successful.
pub async fn execute_command(ctx: &Context<'_>) -> ExecutionResult {
let cmd = parser::parse_command(&ctx.message_body)?;
let result = match execution_allowed(cmd.as_ref(), ctx) {
Ok(_) => cmd.execute(ctx).await,
Err(e) => Err(e.into()),
};
log_command(cmd.as_ref(), ctx, &result);
result
}
/// Log result of an executed command.
fn log_command(cmd: &(impl Command + ?Sized), ctx: &Context, result: &ExecutionResult) {
use substring::Substring;
let command = match cmd.is_secure() {
true => cmd.name(),
false => ctx.message_body,
};
let dots = match command.len() {
_len if _len > 30 => "[...]",
_ => "",
};
let command = command.substring(0, 30);
match result {
Ok(_) => {
info!(
"[{}] {} <{}{}> - success",
ctx.origin_room.display_name, ctx.username, command, dots
);
}
Err(e) => {
error!(
"[{}] {} <{}{}> - {}",
ctx.origin_room.display_name, ctx.username, command, dots, e
);
}
};
}
#[cfg(test)]
mod tests {
use super::*;
use management::RegisterCommand;
use url::Url;
use matrix_sdk::ruma::room_id;
macro_rules! dummy_room {
() => {
crate::context::RoomContext {
id: &room_id!("!fakeroomid:example.com"),
display_name: "displayname".to_owned(),
secure: false,
}
};
}
macro_rules! secure_room {
() => {
crate::context::RoomContext {
id: &room_id!("!fakeroomid:example.com"),
display_name: "displayname".to_owned(),
secure: true,
}
};
}
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn secure_context_secure_command_allows_execution() {
let db_path = tempfile::NamedTempFile::new_in(".").unwrap();
let db = crate::db::sqlite::Database::new(db_path.path().to_str().unwrap())
.await
.unwrap();
let homeserver = Url::parse("http://example.com").unwrap();
let ctx = Context {
account: crate::models::Account::default(),
db: db,
matrix_client: matrix_sdk::Client::new(homeserver).await.unwrap(),
origin_room: secure_room!(),
active_room: secure_room!(),
username: "myusername",
message_body: "!notacommand",
};
let cmd = RegisterCommand;
assert_eq!(execution_allowed(&cmd, &ctx).is_ok(), true);
}
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn secure_context_insecure_command_allows_execution() {
let db_path = tempfile::NamedTempFile::new_in(".").unwrap();
let db = crate::db::sqlite::Database::new(db_path.path().to_str().unwrap())
.await
.unwrap();
let homeserver = Url::parse("http://example.com").unwrap();
let ctx = Context {
account: crate::models::Account::default(),
db: db,
matrix_client: matrix_sdk::Client::new(homeserver).await.unwrap(),
origin_room: secure_room!(),
active_room: secure_room!(),
username: "myusername",
message_body: "!notacommand",
};
let cmd = variables::GetVariableCommand("".to_owned());
assert_eq!(execution_allowed(&cmd, &ctx).is_ok(), true);
}
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn insecure_context_insecure_command_allows_execution() {
let db_path = tempfile::NamedTempFile::new_in(".").unwrap();
let db = crate::db::sqlite::Database::new(db_path.path().to_str().unwrap())
.await
.unwrap();
let homeserver = Url::parse("http://example.com").unwrap();
let ctx = Context {
account: crate::models::Account::default(),
db: db,
matrix_client: matrix_sdk::Client::new(homeserver).await.unwrap(),
origin_room: dummy_room!(),
active_room: dummy_room!(),
username: "myusername",
message_body: "!notacommand",
};
let cmd = variables::GetVariableCommand("".to_owned());
assert_eq!(execution_allowed(&cmd, &ctx).is_ok(), true);
}
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn insecure_context_secure_command_denies_execution() {
let db_path = tempfile::NamedTempFile::new_in(".").unwrap();
let db = crate::db::sqlite::Database::new(db_path.path().to_str().unwrap())
.await
.unwrap();
let homeserver = Url::parse("http://example.com").unwrap();
let ctx = Context {
account: crate::models::Account::default(),
db: db,
matrix_client: matrix_sdk::Client::new(homeserver).await.unwrap(),
origin_room: dummy_room!(),
active_room: dummy_room!(),
username: "myusername",
message_body: "!notacommand",
};
let cmd = RegisterCommand;
assert_eq!(execution_allowed(&cmd, &ctx).is_err(), true);
}
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn unrecognized_command() {
let db_path = tempfile::NamedTempFile::new_in(".").unwrap();
let db = crate::db::sqlite::Database::new(db_path.path().to_str().unwrap())
.await
.unwrap();
let homeserver = Url::parse("http://example.com").unwrap();
let ctx = Context {
account: crate::models::Account::default(),
db: db,
matrix_client: matrix_sdk::Client::new(homeserver).await.unwrap(),
origin_room: dummy_room!(),
active_room: dummy_room!(),
username: "myusername",
message_body: "!notacommand",
};
let result = execute_command(&ctx).await;
assert!(result.is_err());
}
}

View File

@ -3,26 +3,22 @@
* governed by the terms of the MIT license, from the original
* axfive-matrix-dicebot project.
*/
use crate::basic::parser::parse_element_expression;
use crate::cofd::parser::{create_chance_die, parse_dice_pool};
use crate::commands::{
basic_rolling::RollCommand,
cofd::PoolRollCommand,
cthulhu::{CthAdvanceRoll, CthRoll},
management::ResyncCommand,
management::{CheckCommand, LinkCommand, RegisterCommand, UnlinkCommand, UnregisterCommand},
misc::HelpCommand,
rooms::{ListRoomsCommand, SetRoomCommand},
variables::{
DeleteVariableCommand, GetAllVariablesCommand, GetVariableCommand, SetVariableCommand,
},
Command,
};
use crate::cthulhu::parser::{parse_advancement_roll, parse_regular_roll};
use crate::error::BotError;
use crate::help::parse_help_topic;
use crate::variables::parse_set_variable;
use combine::parser::char::{char, letter, space};
use combine::{any, many1, optional, Parser};
use nom::Err as NomErr;
use std::convert::TryFrom;
use thiserror::Error;
#[derive(Debug, Clone, PartialEq, Error)]
@ -34,65 +30,6 @@ pub enum CommandParsingError {
InternalParseError(#[from] combine::error::StringStreamError),
}
// Parse a roll expression.
fn parse_roll(input: &str) -> Result<Box<dyn Command>, BotError> {
let result = parse_element_expression(input);
match result {
Ok((rest, expression)) if rest.len() == 0 => Ok(Box::new(RollCommand(expression))),
//Legacy code boundary translates nom errors into BotErrors.
Ok(_) => Err(BotError::NomParserIncomplete),
Err(NomErr::Error(e)) => Err(BotError::NomParserError(e.1)),
Err(NomErr::Failure(e)) => Err(BotError::NomParserError(e.1)),
Err(NomErr::Incomplete(_)) => Err(BotError::NomParserIncomplete),
}
}
fn parse_get_variable_command(input: &str) -> Result<Box<dyn Command>, BotError> {
Ok(Box::new(GetVariableCommand(input.to_owned())))
}
fn parse_set_variable_command(input: &str) -> Result<Box<dyn Command>, BotError> {
let (variable_name, value) = parse_set_variable(input)?;
Ok(Box::new(SetVariableCommand(variable_name, value)))
}
fn parse_delete_variable_command(input: &str) -> Result<Box<dyn Command>, BotError> {
Ok(Box::new(DeleteVariableCommand(input.to_owned())))
}
fn parse_pool_roll(input: &str) -> Result<Box<dyn Command>, BotError> {
let pool = parse_dice_pool(input)?;
Ok(Box::new(PoolRollCommand(pool)))
}
fn parse_cth_roll(input: &str) -> Result<Box<dyn Command>, BotError> {
let roll = parse_regular_roll(input)?;
Ok(Box::new(CthRoll(roll)))
}
fn parse_cth_advancement_roll(input: &str) -> Result<Box<dyn Command>, BotError> {
let roll = parse_advancement_roll(input)?;
Ok(Box::new(CthAdvanceRoll(roll)))
}
fn chance_die() -> Result<Box<dyn Command>, BotError> {
let pool = create_chance_die()?;
Ok(Box::new(PoolRollCommand(pool)))
}
fn get_all_variables() -> Result<Box<dyn Command>, BotError> {
Ok(Box::new(GetAllVariablesCommand))
}
fn parse_resync() -> Result<Box<dyn Command>, BotError> {
Ok(Box::new(ResyncCommand))
}
fn help(topic: &str) -> Result<Box<dyn Command>, BotError> {
let topic = parse_help_topic(topic);
Ok(Box::new(HelpCommand(topic)))
}
/// Split an input string into its constituent command and "everything
/// else" parts. Extracts the command separately from its input (i.e.
/// rest of the line) and returns a tuple of (command_input, command).
@ -124,25 +61,37 @@ fn split_command(input: &str) -> Result<(String, String), CommandParsingError> {
Ok((command, command_input))
}
/// Atempt to convert text input to a Boxed command type. Shortens
/// boilerplate.
macro_rules! convert_to {
($type:ident, $input: expr) => {
$type::try_from($input).map(|cmd| Box::new(cmd) as Box<dyn Command>)
};
}
/// Potentially parse a command expression. If we recognize the
/// command, an error should be raised if the command is misparsed. If
/// we don't recognize the command, return an error.
pub fn parse_command(input: &str) -> Result<Box<dyn Command>, BotError> {
match split_command(input) {
Ok((cmd, cmd_input)) => match cmd.as_ref() {
"variables" => get_all_variables(),
"get" => parse_get_variable_command(&cmd_input),
"set" => parse_set_variable_command(&cmd_input),
"del" => parse_delete_variable_command(&cmd_input),
"resync" => parse_resync(),
"r" | "roll" => parse_roll(&cmd_input),
"rp" | "pool" => parse_pool_roll(&cmd_input),
"cthroll" | "cthRoll" => parse_cth_roll(&cmd_input),
"cthadv" | "ctharoll" | "cthAroll" | "cthARoll" => {
parse_cth_advancement_roll(&cmd_input)
}
"chance" => chance_die(),
"help" => help(&cmd_input),
Ok((cmd, cmd_input)) => match cmd.to_lowercase().as_ref() {
"variables" => convert_to!(GetAllVariablesCommand, cmd_input),
"get" => convert_to!(GetVariableCommand, cmd_input),
"set" => convert_to!(SetVariableCommand, cmd_input),
"del" => convert_to!(DeleteVariableCommand, cmd_input),
"r" | "roll" => convert_to!(RollCommand, cmd_input),
"rp" | "pool" => convert_to!(PoolRollCommand, cmd_input),
"chance" => PoolRollCommand::chance_die().map(|cmd| Box::new(cmd) as Box<dyn Command>),
"cthroll" => convert_to!(CthRoll, cmd_input),
"cthadv" | "ctharoll" => convert_to!(CthAdvanceRoll, cmd_input),
"help" => convert_to!(HelpCommand, cmd_input),
"register" => convert_to!(RegisterCommand, cmd_input),
"link" => convert_to!(LinkCommand, cmd_input),
"unlink" => convert_to!(UnlinkCommand, cmd_input),
"check" => convert_to!(CheckCommand, cmd_input),
"unregister" => convert_to!(UnregisterCommand, cmd_input),
"rooms" => convert_to!(ListRoomsCommand, cmd_input),
"room" => convert_to!(SetRoomCommand, cmd_input),
_ => Err(CommandParsingError::UnrecognizedCommand(cmd).into()),
},
//All other errors passed up.
@ -272,9 +221,9 @@ mod tests {
#[test]
fn pool_whitespace_test() {
parse_command("!pool ns3:8 ").expect("was error");
parse_command(" !pool ns3:8").expect("was error");
parse_command(" !pool ns3:8 ").expect("was error");
parse_command("!pool 8 ns3 ").expect("was error");
parse_command(" !pool 8 ns3").expect("was error");
parse_command(" !pool 8 ns3 ").expect("was error");
}
#[test]
@ -290,4 +239,9 @@ mod tests {
parse_command("!roll 1d4 + 5d6 -3 ").expect("was error");
parse_command(" !roll 1d4 + 5d6 -3 ").expect("was error");
}
#[test]
fn case_insensitive_test() {
parse_command("!CTHROLL 40").expect("command parsing is not case sensitive.");
}
}

View File

@ -0,0 +1,187 @@
use super::{Command, Execution, ExecutionResult};
use crate::context::Context;
use crate::db::Users;
use crate::error::BotError;
use crate::matrix;
use async_trait::async_trait;
use fuse_rust::{Fuse, FuseProperty, Fuseable};
use futures::stream::{self, StreamExt, TryStreamExt};
use matrix_sdk::{ruma::OwnedUserId, Client};
use std::convert::TryFrom;
/// Holds matrix room ID and display name as strings, for use with
/// searching. See search_for_room.
#[derive(Clone, Debug, Eq, PartialEq)]
struct RoomNameAndId {
id: String,
name: String,
}
/// Allows searching for a room name and ID struct, instead of just
/// searching room display names directly.
impl Fuseable for RoomNameAndId {
fn properties(&self) -> Vec<FuseProperty> {
vec![FuseProperty {
value: String::from("name"),
weight: 1.0,
}]
}
fn lookup(&self, key: &str) -> Option<&str> {
match key {
"name" => Some(&self.name),
_ => None,
}
}
}
/// Attempt to find a room by either name or Matrix Room ID query
/// string. It prefers the exact room ID first, and then falls back to
/// fuzzy searching based on room display name. The best match is
/// returned, or None if no matches were found.
fn search_for_room<'a>(
rooms_for_user: &'a [RoomNameAndId],
search_for: &str,
) -> Option<&'a RoomNameAndId> {
//Lowest score is the best match.
let best_fuzzy_match = || -> Option<&RoomNameAndId> {
Fuse::default()
.search_text_in_fuse_list(search_for, &rooms_for_user)
.into_iter()
.min_by(|r1, r2| r1.score.partial_cmp(&r2.score).unwrap())
.and_then(|result| rooms_for_user.get(result.index))
};
rooms_for_user
.iter()
.find(|room| room.id == search_for)
.or_else(best_fuzzy_match)
}
async fn get_rooms_for_user(
client: &Client,
user_id: &str,
) -> Result<Vec<RoomNameAndId>, BotError> {
let user_id = OwnedUserId::try_from(user_id)?;
let rooms_for_user = matrix::get_rooms_for_user(client, &user_id).await?;
let mut rooms_for_user: Vec<RoomNameAndId> = stream::iter(rooms_for_user)
.filter_map(|room| async move {
Some(room.display_name().await.map(|room_name| RoomNameAndId {
id: room.room_id().to_string(),
name: room_name.to_string(),
}))
})
.try_collect()
.await?;
//Alphabetically descending, symbols first, ignore case.
let sort = |r1: &RoomNameAndId, r2: &RoomNameAndId| {
r1.name.to_lowercase().cmp(&r2.name.to_lowercase())
};
rooms_for_user.sort_by(sort);
Ok(rooms_for_user)
}
pub struct ListRoomsCommand;
impl TryFrom<String> for ListRoomsCommand {
type Error = BotError;
fn try_from(_: String) -> Result<Self, Self::Error> {
Ok(ListRoomsCommand)
}
}
#[async_trait]
impl Command for ListRoomsCommand {
fn name(&self) -> &'static str {
"list rooms"
}
fn is_secure(&self) -> bool {
true
}
async fn execute(&self, ctx: &Context<'_>) -> ExecutionResult {
let rooms_for_user: Vec<String> = get_rooms_for_user(&ctx.matrix_client, ctx.username)
.await
.map(|rooms| {
rooms
.into_iter()
.map(|room| format!(" {} | {}", room.id, room.name))
.collect()
})?;
let html = format!("<pre>{}</pre>", rooms_for_user.join("\n"));
Execution::success(html)
}
}
pub struct SetRoomCommand(String);
impl TryFrom<String> for SetRoomCommand {
type Error = BotError;
fn try_from(input: String) -> Result<Self, Self::Error> {
Ok(SetRoomCommand(input))
}
}
#[async_trait]
impl Command for SetRoomCommand {
fn name(&self) -> &'static str {
"set active room"
}
fn is_secure(&self) -> bool {
true
}
async fn execute(&self, ctx: &Context<'_>) -> ExecutionResult {
if !ctx.account.is_registered() {
return Err(BotError::AccountDoesNotExist);
}
let rooms_for_user = get_rooms_for_user(&ctx.matrix_client, ctx.username).await?;
let room = search_for_room(&rooms_for_user, &self.0);
if let Some(room) = room {
let mut new_user = ctx
.account
.registered_user()
.cloned()
.ok_or(BotError::AccountDoesNotExist)?;
new_user.active_room = Some(room.id.clone());
ctx.db.upsert_user(&new_user).await?;
Execution::success(format!(r#"Active room set to "{}""#, room.name))
} else {
Err(BotError::RoomDoesNotExist)
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn set_room_prefers_room_id_over_name() {
let rooms = vec![
RoomNameAndId {
id: "roomid".to_string(),
name: "room_name".to_string(),
},
RoomNameAndId {
id: "anotherone".to_string(),
name: "roomid".to_string(),
},
];
let found_room = search_for_room(&rooms, "roomid");
assert!(found_room.is_some());
assert_eq!(found_room.unwrap(), &rooms[0]);
}
}

View File

@ -1,22 +1,35 @@
use super::{Command, Execution, ExecutionResult};
use crate::context::Context;
use crate::db::sqlite::errors::DataError;
use crate::db::sqlite::Variables;
use crate::db::variables::UserAndRoom;
use crate::db::errors::DataError;
use crate::db::Variables;
use crate::error::BotError;
use async_trait::async_trait;
use std::convert::TryFrom;
pub struct GetAllVariablesCommand;
impl TryFrom<String> for GetAllVariablesCommand {
type Error = BotError;
fn try_from(_: String) -> Result<Self, Self::Error> {
Ok(GetAllVariablesCommand)
}
}
#[async_trait]
impl Command for GetAllVariablesCommand {
fn name(&self) -> &'static str {
"get all variables"
}
fn is_secure(&self) -> bool {
false
}
async fn execute(&self, ctx: &Context<'_>) -> ExecutionResult {
let variables = ctx
.db
.get_user_variables(&ctx.username, ctx.room_id().as_str())
.get_user_variables(&ctx.username, ctx.active_room_id().as_str())
.await?;
let mut variable_list: Vec<String> = variables
@ -38,17 +51,29 @@ impl Command for GetAllVariablesCommand {
pub struct GetVariableCommand(pub String);
impl TryFrom<String> for GetVariableCommand {
type Error = BotError;
fn try_from(input: String) -> Result<Self, Self::Error> {
Ok(GetVariableCommand(input))
}
}
#[async_trait]
impl Command for GetVariableCommand {
fn name(&self) -> &'static str {
"retrieve variable value"
}
fn is_secure(&self) -> bool {
false
}
async fn execute(&self, ctx: &Context<'_>) -> ExecutionResult {
let name = &self.0;
let result = ctx
.db
.get_user_variable(&ctx.username, ctx.room_id().as_str(), name)
.get_user_variable(&ctx.username, ctx.active_room_id().as_str(), name)
.await;
let value = match result {
@ -64,18 +89,31 @@ impl Command for GetVariableCommand {
pub struct SetVariableCommand(pub String, pub i32);
impl TryFrom<String> for SetVariableCommand {
type Error = BotError;
fn try_from(input: String) -> Result<Self, Self::Error> {
let (variable_name, value) = crate::parser::variables::parse_set_variable(&input)?;
Ok(SetVariableCommand(variable_name, value))
}
}
#[async_trait]
impl Command for SetVariableCommand {
fn name(&self) -> &'static str {
"set variable value"
}
fn is_secure(&self) -> bool {
false
}
async fn execute(&self, ctx: &Context<'_>) -> ExecutionResult {
let name = &self.0;
let value = self.1;
ctx.db
.set_user_variable(&ctx.username, ctx.room_id().as_str(), name, value)
.set_user_variable(&ctx.username, ctx.active_room_id().as_str(), name, value)
.await?;
let content = format!("{} = {}", name, value);
@ -86,17 +124,29 @@ impl Command for SetVariableCommand {
pub struct DeleteVariableCommand(pub String);
impl TryFrom<String> for DeleteVariableCommand {
type Error = BotError;
fn try_from(input: String) -> Result<Self, Self::Error> {
Ok(DeleteVariableCommand(input))
}
}
#[async_trait]
impl Command for DeleteVariableCommand {
fn name(&self) -> &'static str {
"delete variable"
}
fn is_secure(&self) -> bool {
false
}
async fn execute(&self, ctx: &Context<'_>) -> ExecutionResult {
let name = &self.0;
let result = ctx
.db
.delete_user_variable(&ctx.username, ctx.room_id().as_str(), name)
.delete_user_variable(&ctx.username, ctx.active_room_id().as_str(), name)
.await;
let value = match result {

View File

@ -4,10 +4,6 @@ use std::fs;
use std::path::PathBuf;
use thiserror::Error;
/// Shortcut to defining db migration versions. Will probably
/// eventually be moved to a config file.
const MIGRATION_VERSION: u32 = 5;
#[derive(Error, Debug)]
pub enum ConfigError {
#[error("i/o error: {0}")]
@ -53,10 +49,19 @@ fn db_path_from_env() -> String {
}
/// The "bot" section of the config file, for bot settings.
#[derive(Serialize, Deserialize, Clone, Debug)]
#[derive(Serialize, Deserialize, Clone, Debug, Default)]
struct BotConfig {
/// How far back from current time should we process a message?
oldest_message_age: Option<u64>,
/// What address and port to run the RPC service on. If not
/// specified, RPC will not be enabled.
rpc_addr: Option<String>,
/// The shared secret key between the bot and any RPC clients that
/// want to connect to it. The RPC server will reject any clients
/// that don't present the shared key.
rpc_key: Option<String>,
}
/// The "database" section of the config file.
@ -84,6 +89,18 @@ impl BotConfig {
self.oldest_message_age
.unwrap_or(DEFAULT_OLDEST_MESSAGE_AGE)
}
#[inline]
#[must_use]
fn rpc_addr(&self) -> Option<String> {
self.rpc_addr.clone()
}
#[inline]
#[must_use]
fn rpc_key(&self) -> Option<String> {
self.rpc_key.clone()
}
}
/// Represents the toml config file for the dicebot. The sections of
@ -128,15 +145,6 @@ impl Config {
.unwrap_or_else(|| db_path_from_env())
}
/// The current migration version we expect of the database. If
/// this number is higher than the one in the database, we will
/// execute migrations to update the data.
#[inline]
#[must_use]
pub fn migration_version(&self) -> u32 {
MIGRATION_VERSION
}
/// Figure out the allowed oldest message age, in seconds. This will
/// be the defined oldest message age in the bot config, if the bot
/// configuration and associated "oldest_message_age" setting are
@ -150,6 +158,18 @@ impl Config {
.map(|bc| bc.oldest_message_age())
.unwrap_or(DEFAULT_OLDEST_MESSAGE_AGE)
}
#[inline]
#[must_use]
pub fn rpc_addr(&self) -> Option<String> {
self.bot.as_ref().and_then(|bc| bc.rpc_addr())
}
#[inline]
#[must_use]
pub fn rpc_key(&self) -> Option<String> {
self.bot.as_ref().and_then(|bc| bc.rpc_key())
}
}
#[cfg(test)]
@ -169,6 +189,7 @@ mod tests {
}),
bot: Some(BotConfig {
oldest_message_age: None,
..Default::default()
}),
};

76
dicebot/src/context.rs Normal file
View File

@ -0,0 +1,76 @@
use crate::db::sqlite::Database;
use crate::error::BotError;
use crate::models::Account;
use matrix_sdk::room::Joined;
use matrix_sdk::ruma::{RoomId, UserId};
use matrix_sdk::Client;
use std::convert::TryFrom;
/// A context carried through the system providing access to things
/// like the database.
#[derive(Clone)]
pub struct Context<'a> {
pub db: Database,
pub matrix_client: Client,
pub origin_room: RoomContext<'a>,
pub active_room: RoomContext<'a>,
pub username: &'a str,
pub message_body: &'a str,
pub account: Account,
}
impl Context<'_> {
pub fn active_room_id(&self) -> &RoomId {
self.active_room.id
}
pub fn room_id(&self) -> &RoomId {
self.origin_room.id
}
pub fn is_secure(&self) -> bool {
self.origin_room.secure
}
}
#[derive(Clone)]
pub struct RoomContext<'a> {
pub id: &'a RoomId,
pub display_name: String,
pub secure: bool,
}
impl RoomContext<'_> {
pub async fn new_with_name<'a>(
room: &'a Joined,
sending_user: &str,
) -> Result<RoomContext<'a>, BotError> {
// TODO is_direct is a hack; the bot should set eligible rooms
// to Direct Message upon joining, if other contact has
// requested it. Waiting on SDK support.
let display_name =
room
.display_name()
.await
.ok()
.map(|d| d.to_string())
.unwrap_or_default();
let sending_user = <&UserId>::try_from(sending_user)?;
let user_in_room = room.get_member(sending_user).await.ok().is_some();
let is_direct = room.active_members().await?.len() == 2;
Ok(RoomContext {
id: room.room_id(),
display_name,
secure: room.is_encrypted() && is_direct && user_in_room,
})
}
pub async fn new<'a>(
room: &'a Joined,
sending_user: &'a str,
) -> Result<RoomContext<'a>, BotError> {
Self::new_with_name(room, sending_user).await
}
}

View File

@ -1,8 +1,8 @@
use crate::db::sqlite::Variables;
use crate::context::Context;
use crate::db::Variables;
use crate::error::{BotError, DiceRollingError};
use crate::parser::{Amount, Element};
use crate::{context::Context, db::variables::UserAndRoom};
use crate::{dice::calculate_single_die_amount, parser::DiceParsingError};
use crate::logic::calculate_single_die_amount;
use crate::parser::dice::{Amount, DiceParsingError, Element};
use rand::rngs::StdRng;
use rand::Rng;
use rand::SeedableRng;
@ -270,7 +270,7 @@ macro_rules! is_variable {
element: Element::Variable(_),
..
}
);
)
};
}
@ -380,7 +380,12 @@ async fn update_skill(ctx: &Context<'_>, variable: &str, value: u32) -> Result<(
use std::convert::TryInto;
let value: i32 = value.try_into()?;
ctx.db
.set_user_variable(&ctx.username, &ctx.room_id().as_str(), variable, value)
.set_user_variable(
&ctx.username,
&ctx.active_room_id().as_str(),
variable,
value,
)
.await?;
Ok(())
}
@ -420,14 +425,16 @@ pub async fn advancement_roll(
mod tests {
use super::*;
use crate::db::sqlite::Database;
use crate::parser::{Amount, Element, Operator};
use crate::parser::dice::{Amount, Element, Operator};
use url::Url;
use matrix_sdk::ruma::room_id;
macro_rules! dummy_room {
() => {
crate::context::RoomContext {
id: &matrix_sdk::identifiers::room_id!("!fakeroomid:example.com"),
display_name: "displayname",
id: &room_id!("!fakeroomid:example.com"),
display_name: "displayname".to_owned(),
secure: false,
}
};
}
@ -503,9 +510,11 @@ mod tests {
let homeserver = Url::parse("http://example.com").unwrap();
let ctx = Context {
account: crate::models::Account::default(),
db: db,
matrix_client: &matrix_sdk::Client::new(homeserver).unwrap(),
room: dummy_room!(),
matrix_client: matrix_sdk::Client::new(homeserver).await.unwrap(),
origin_room: dummy_room!(),
active_room: dummy_room!(),
username: "username",
message_body: "message",
};
@ -539,9 +548,11 @@ mod tests {
let homeserver = Url::parse("http://example.com").unwrap();
let ctx = Context {
account: crate::models::Account::default(),
db: db,
matrix_client: &matrix_sdk::Client::new(homeserver).unwrap(),
room: dummy_room!(),
matrix_client: matrix_sdk::Client::new(homeserver).await.unwrap(),
origin_room: dummy_room!(),
active_room: dummy_room!(),
username: "username",
message_body: "message",
};
@ -575,9 +586,11 @@ mod tests {
let homeserver = Url::parse("http://example.com").unwrap();
let ctx = Context {
account: crate::models::Account::default(),
db: db,
matrix_client: &matrix_sdk::Client::new(homeserver).unwrap(),
room: dummy_room!(),
matrix_client: matrix_sdk::Client::new(homeserver).await.unwrap(),
origin_room: dummy_room!(),
active_room: dummy_room!(),
username: "username",
message_body: "message",
};

View File

@ -1,19 +1,16 @@
use super::dice::{AdvancementRoll, DiceRoll, DiceRollModifier};
use crate::parser::DiceParsingError;
use crate::parser::dice::DiceParsingError;
//TOOD convert these to use parse_amounts from the common dice code.
fn parse_modifier(input: &str) -> Result<DiceRollModifier, DiceParsingError> {
if input.ends_with("bb") {
Ok(DiceRollModifier::TwoBonus)
} else if input.ends_with("b") {
Ok(DiceRollModifier::OneBonus)
} else if input.ends_with("pp") {
Ok(DiceRollModifier::TwoPenalty)
} else if input.ends_with("p") {
Ok(DiceRollModifier::OnePenalty)
} else {
Ok(DiceRollModifier::Normal)
match input.trim() {
"bb" => Ok(DiceRollModifier::TwoBonus),
"b" => Ok(DiceRollModifier::OneBonus),
"pp" => Ok(DiceRollModifier::TwoPenalty),
"p" => Ok(DiceRollModifier::OnePenalty),
"" => Ok(DiceRollModifier::Normal),
_ => Err(DiceParsingError::InvalidModifiers),
}
}
@ -21,33 +18,70 @@ fn parse_modifier(input: &str) -> Result<DiceRollModifier, DiceParsingError> {
//Split based on :, send first part to parse_modifier.
//Send second part to parse_amounts
pub fn parse_regular_roll(input: &str) -> Result<DiceRoll, DiceParsingError> {
let input: Vec<&str> = input.trim().split(":").collect();
let (modifiers_str, amounts_str) = match input[..] {
[amounts] => Ok(("", amounts)),
[modifiers, amounts] => Ok((modifiers, amounts)),
_ => Err(DiceParsingError::UnconsumedInput),
}?;
let (amount, modifiers_str) = crate::parser::dice::parse_single_amount(input)?;
let modifier = parse_modifier(modifiers_str)?;
let amount = crate::parser::parse_single_amount(amounts_str)?;
Ok(DiceRoll { modifier, amount })
}
pub fn parse_advancement_roll(input: &str) -> Result<AdvancementRoll, DiceParsingError> {
let input = input.trim();
let amounts = crate::parser::parse_single_amount(input)?;
let (amounts, unconsumed_input) = crate::parser::dice::parse_single_amount(input)?;
Ok(AdvancementRoll {
existing_skill: amounts,
})
if unconsumed_input.len() == 0 {
Ok(AdvancementRoll {
existing_skill: amounts,
})
} else {
Err(DiceParsingError::InvalidAmount)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::parser::{Amount, Element, Operator};
use crate::parser::dice::{Amount, DiceParsingError, Element, Operator};
#[test]
fn parse_modifier_rejects_bad_value() {
let modifier = parse_modifier("qqq");
assert!(matches!(modifier, Err(DiceParsingError::InvalidModifiers)))
}
#[test]
fn parse_modifier_accepts_one_bonus() {
let modifier = parse_modifier("b");
assert!(matches!(modifier, Ok(DiceRollModifier::OneBonus)))
}
#[test]
fn parse_modifier_accepts_two_bonus() {
let modifier = parse_modifier("bb");
assert!(matches!(modifier, Ok(DiceRollModifier::TwoBonus)))
}
#[test]
fn parse_modifier_accepts_two_penalty() {
let modifier = parse_modifier("pp");
assert!(matches!(modifier, Ok(DiceRollModifier::TwoPenalty)))
}
#[test]
fn parse_modifier_accepts_one_penalty() {
let modifier = parse_modifier("p");
assert!(matches!(modifier, Ok(DiceRollModifier::OnePenalty)))
}
#[test]
fn parse_modifier_accepts_normal() {
let modifier = parse_modifier("");
assert!(matches!(modifier, Ok(DiceRollModifier::Normal)))
}
#[test]
fn parse_modifier_accepts_normal_unaffected_by_whitespace() {
let modifier = parse_modifier(" ");
assert!(matches!(modifier, Ok(DiceRollModifier::Normal)))
}
#[test]
fn regular_roll_accepts_single_number() {
@ -73,7 +107,7 @@ mod tests {
#[test]
fn regular_roll_accepts_two_bonus() {
let result = parse_regular_roll("bb:60");
let result = parse_regular_roll("60 bb");
assert!(result.is_ok());
assert_eq!(
DiceRoll {
@ -89,7 +123,7 @@ mod tests {
#[test]
fn regular_roll_accepts_one_bonus() {
let result = parse_regular_roll("b:60");
let result = parse_regular_roll("60 b");
assert!(result.is_ok());
assert_eq!(
DiceRoll {
@ -105,7 +139,7 @@ mod tests {
#[test]
fn regular_roll_accepts_two_penalty() {
let result = parse_regular_roll("pp:60");
let result = parse_regular_roll("60 pp");
assert!(result.is_ok());
assert_eq!(
DiceRoll {
@ -121,7 +155,7 @@ mod tests {
#[test]
fn regular_roll_accepts_one_penalty() {
let result = parse_regular_roll("p:60");
let result = parse_regular_roll("60 p");
assert!(result.is_ok());
assert_eq!(
DiceRoll {
@ -141,21 +175,21 @@ mod tests {
assert!(parse_regular_roll(" 60").is_ok());
assert!(parse_regular_roll(" 60 ").is_ok());
assert!(parse_regular_roll("bb:60 ").is_ok());
assert!(parse_regular_roll(" bb:60").is_ok());
assert!(parse_regular_roll(" bb:60 ").is_ok());
assert!(parse_regular_roll("60bb ").is_ok());
assert!(parse_regular_roll(" 60 bb").is_ok());
assert!(parse_regular_roll(" 60 bb ").is_ok());
assert!(parse_regular_roll("b:60 ").is_ok());
assert!(parse_regular_roll(" b:60").is_ok());
assert!(parse_regular_roll(" b:60 ").is_ok());
assert!(parse_regular_roll("60b ").is_ok());
assert!(parse_regular_roll(" 60 b").is_ok());
assert!(parse_regular_roll(" 60 b ").is_ok());
assert!(parse_regular_roll("pp:60 ").is_ok());
assert!(parse_regular_roll(" pp:60").is_ok());
assert!(parse_regular_roll(" pp:60 ").is_ok());
assert!(parse_regular_roll("60pp ").is_ok());
assert!(parse_regular_roll(" 60 pp").is_ok());
assert!(parse_regular_roll(" 60 pp ").is_ok());
assert!(parse_regular_roll("p:60 ").is_ok());
assert!(parse_regular_roll(" p:60").is_ok());
assert!(parse_regular_roll(" p:60 ").is_ok());
assert!(parse_regular_roll("60p ").is_ok());
assert!(parse_regular_roll(" 60p ").is_ok());
assert!(parse_regular_roll(" 60 p ").is_ok());
}
#[test]

32
dicebot/src/db/errors.rs Normal file
View File

@ -0,0 +1,32 @@
use std::num::TryFromIntError;
use thiserror::Error;
#[derive(Error, Debug)]
pub enum DataError {
#[error("value does not exist for key: {0}")]
KeyDoesNotExist(String),
#[error("too many entries")]
TooManyEntries,
#[error("expected i32, but i32 schema was violated")]
I32SchemaViolation,
#[error("unexpected or corruptd data bytes")]
InvalidValue,
#[error("expected string ref, but utf8 schema was violated: {0}")]
Utf8RefSchemaViolation(#[from] std::str::Utf8Error),
#[error("expected string, but utf8 schema was violated: {0}")]
Utf8SchemaViolation(#[from] std::string::FromUtf8Error),
#[error("data migration error: {0}")]
MigrationError(#[from] crate::db::sqlite::migrator::MigrationError),
#[error("internal database error: {0}")]
SqlxError(#[from] sqlx::Error),
#[error("numeric conversion error")]
NumericConversionError(#[from] TryFromIntError),
}

70
dicebot/src/db/mod.rs Normal file
View File

@ -0,0 +1,70 @@
use crate::error::BotError;
use crate::models::User;
use async_trait::async_trait;
use errors::DataError;
use std::collections::HashMap;
pub mod errors;
pub mod sqlite;
#[async_trait]
pub(crate) trait DbState {
async fn get_device_id(&self) -> Result<Option<String>, DataError>;
async fn set_device_id(&self, device_id: &str) -> Result<(), DataError>;
}
#[async_trait]
pub(crate) trait Users {
async fn upsert_user(&self, user: &User) -> Result<(), DataError>;
async fn get_user(&self, username: &str) -> Result<Option<User>, DataError>;
async fn delete_user(&self, username: &str) -> Result<(), DataError>;
async fn authenticate_user(
&self,
username: &str,
raw_password: &str,
) -> Result<Option<User>, BotError>;
}
#[async_trait]
pub(crate) trait Rooms {
async fn should_process(&self, room_id: &str, event_id: &str) -> Result<bool, DataError>;
}
// TODO move this up to the top once we delete sled. Traits will be the
// main API, then we can have different impls for different DBs.
#[async_trait]
pub trait Variables {
async fn get_user_variables(
&self,
user: &str,
room_id: &str,
) -> Result<HashMap<String, i32>, DataError>;
async fn get_variable_count(&self, user: &str, room_id: &str) -> Result<i32, DataError>;
async fn get_user_variable(
&self,
user: &str,
room_id: &str,
variable_name: &str,
) -> Result<i32, DataError>;
async fn set_user_variable(
&self,
user: &str,
room_id: &str,
variable_name: &str,
value: i32,
) -> Result<(), DataError>;
async fn delete_user_variable(
&self,
user: &str,
room_id: &str,
variable_name: &str,
) -> Result<(), DataError>;
}

View File

@ -0,0 +1,22 @@
use crate::systems::GameSystem;
use barrel::backend::Sqlite;
use barrel::{types, types::Type, Migration};
use itertools::Itertools;
use strum::IntoEnumIterator;
fn primary_id() -> Type {
types::text().unique(true).primary(true).nullable(false)
}
pub fn migration() -> String {
let mut m = Migration::new();
//Normally we would add a CHECK clause here, but types::custom requires a 'static string.
//Which means we can't automagically generate one from the enum.
m.create_table("room_info", move |t| {
t.add_column("room_id", primary_id());
t.add_column("game_system", types::text().nullable(false));
});
m.make::<Sqlite>()
}

View File

@ -1,5 +1,6 @@
use barrel::backend::Sqlite;
use barrel::{types, types::Type, Migration};
use barrel::{types, Migration};
pub fn migration() -> String {
let mut m = Migration::new();

View File

@ -1,5 +1,5 @@
use barrel::backend::Sqlite;
use barrel::{types, types::Type, Migration};
use barrel::{types, Migration};
pub fn migration() -> String {
let mut m = Migration::new();

View File

@ -0,0 +1,18 @@
use barrel::backend::Sqlite;
use barrel::{types, types::Type, Migration};
fn primary_uuid() -> Type {
types::text().unique(true).primary(true).nullable(false)
}
pub fn migration() -> String {
let mut m = Migration::new();
//Table of room ID, event ID, event timestamp
m.create_table("accounts", move |t| {
t.add_column("user_id", primary_uuid());
t.add_column("password", types::text().nullable(false));
});
m.make::<Sqlite>()
}

View File

@ -0,0 +1,10 @@
use barrel::backend::Sqlite;
use barrel::Migration;
pub fn migration() -> String {
let mut m = Migration::new();
m.drop_table_if_exists("room_info");
m.drop_table_if_exists("room_users");
m.make::<Sqlite>()
}

View File

@ -0,0 +1,18 @@
use barrel::backend::Sqlite;
use barrel::{types, types::Type, Migration};
fn primary_uuid() -> Type {
types::text().unique(true).primary(true).nullable(false)
}
pub fn migration() -> String {
let mut m = Migration::new();
// Keep track of contextual user state.
m.create_table("user_state", move |t| {
t.add_column("user_id", primary_uuid());
t.add_column("active_room", types::text().nullable(true));
});
m.make::<Sqlite>()
}

View File

@ -0,0 +1,17 @@
pub fn migration() -> String {
// sqlite does really support alter column, and barrel does not
// implement the required workaround, so we do it ourselves!
r#"
CREATE TABLE IF NOT EXISTS "accounts2" (
"user_id" TEXT PRIMARY KEY NOT NULL UNIQUE,
"password" TEXT NULL,
"account_status" TEXT NOT NULL CHECK(
account_status IN ('not_registered', 'registered', 'awaiting_activation'
))
);
INSERT INTO accounts2 select *, 'registered' FROM accounts;
DROP TABLE accounts;
ALTER TABLE accounts2 RENAME TO accounts;
"#
.to_string()
}

View File

@ -0,0 +1 @@

View File

@ -5,7 +5,7 @@ use sqlx::ConnectOptions;
use std::str::FromStr;
use thiserror::Error;
pub mod migrations;
//pub mod migrations;
#[derive(Error, Debug)]
pub enum MigrationError {
@ -16,6 +16,11 @@ pub enum MigrationError {
RefineryError(#[from] refinery::Error),
}
mod embedded {
use refinery::embed_migrations;
embed_migrations!("src/db/sqlite/migrator/migrations");
}
/// Run database migrations against the sqlite database.
pub async fn migrate(db_path: &str) -> Result<(), MigrationError> {
//Create database if missing.
@ -28,6 +33,6 @@ pub async fn migrate(db_path: &str) -> Result<(), MigrationError> {
let mut conn = Config::new(ConfigDbType::Sqlite).set_db_path(db_path);
info!("Running migrations");
migrations::runner().run(&mut conn)?;
embedded::migrations::runner().run(&mut conn)?;
Ok(())
}

View File

@ -0,0 +1,51 @@
use crate::db::errors::DataError;
use sqlx::sqlite::{SqliteConnectOptions, SqlitePool, SqlitePoolOptions};
use sqlx::ConnectOptions;
use std::clone::Clone;
use std::str::FromStr;
pub mod migrator;
pub mod rooms;
pub mod state;
pub mod users;
pub mod variables;
pub struct Database {
conn: SqlitePool,
}
impl Database {
fn new_db(conn: SqlitePool) -> Result<Database, DataError> {
let database = Database { conn: conn.clone() };
Ok(database)
}
pub async fn new(path: &str) -> Result<Database, DataError> {
//Create database if missing.
let conn = SqliteConnectOptions::from_str(path)?
.create_if_missing(true)
.connect()
.await?;
drop(conn);
//Migrate database.
migrator::migrate(&path).await?;
//Return actual conncetion pool.
let conn = SqlitePoolOptions::new()
.max_connections(5)
.connect(path)
.await?;
Self::new_db(conn)
}
}
impl Clone for Database {
fn clone(&self) -> Self {
Database {
conn: self.conn.clone(),
}
}
}

View File

@ -0,0 +1,93 @@
use super::Database;
use crate::db::{errors::DataError, Rooms};
use async_trait::async_trait;
use sqlx::SqlitePool;
use std::time::{SystemTime, UNIX_EPOCH};
async fn record_event(conn: &SqlitePool, room_id: &str, event_id: &str) -> Result<(), DataError> {
use std::convert::TryFrom;
let now: i64 = i64::try_from(
SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("Clock has gone backwards")
.as_secs(),
)?;
sqlx::query(
r#"INSERT INTO room_events
(room_id, event_id, event_timestamp)
VALUES (?, ?, ?)"#,
)
.bind(room_id)
.bind(event_id)
.bind(now)
.execute(conn)
.await?;
Ok(())
}
#[async_trait]
impl Rooms for Database {
async fn should_process(&self, room_id: &str, event_id: &str) -> Result<bool, DataError> {
let row = sqlx::query!(
r#"SELECT event_id FROM room_events
WHERE room_id = ? AND event_id = ?"#,
room_id,
event_id
)
.fetch_optional(&self.conn)
.await?;
match row {
Some(_) => Ok(false),
None => {
record_event(&self.conn, room_id, event_id).await?;
Ok(true)
}
}
}
}
#[cfg(test)]
mod tests {
use crate::db::sqlite::Database;
use crate::db::Rooms;
use std::future::Future;
async fn with_db<Fut>(f: impl FnOnce(Database) -> Fut)
where
Fut: Future<Output = ()>,
{
let db_path = tempfile::NamedTempFile::new_in(".").unwrap();
crate::db::sqlite::migrator::migrate(db_path.path().to_str().unwrap())
.await
.unwrap();
let db = Database::new(db_path.path().to_str().unwrap())
.await
.unwrap();
f(db).await;
}
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn should_process_test() {
with_db(|db| async move {
let first_check = db
.should_process("myroom", "myeventid")
.await
.expect("should_process failed in first insert");
assert_eq!(first_check, true);
let second_check = db
.should_process("myroom", "myeventid")
.await
.expect("should_process failed in first insert");
assert_eq!(second_check, false);
})
.await;
}
}

View File

@ -1,5 +1,5 @@
use super::errors::DataError;
use super::{Database, DbState};
use super::Database;
use crate::db::{errors::DataError, DbState};
use async_trait::async_trait;
#[async_trait]
@ -35,56 +35,66 @@ impl DbState for Database {
#[cfg(test)]
mod tests {
use super::super::DbState;
use super::*;
use crate::db::sqlite::Database;
use crate::db::DbState;
use std::future::Future;
async fn create_db() -> Database {
async fn with_db<Fut>(f: impl FnOnce(Database) -> Fut)
where
Fut: Future<Output = ()>,
{
let db_path = tempfile::NamedTempFile::new_in(".").unwrap();
crate::db::sqlite::migrator::migrate(db_path.path().to_str().unwrap())
.await
.unwrap();
Database::new(db_path.path().to_str().unwrap())
let db = Database::new(db_path.path().to_str().unwrap())
.await
.unwrap()
.unwrap();
f(db).await;
}
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn set_and_get_device_id() {
let db = create_db().await;
with_db(|db| async move {
db.set_device_id("device_id")
.await
.expect("Could not set device ID");
db.set_device_id("device_id")
.await
.expect("Could not set device ID");
let device_id = db.get_device_id().await.expect("Could not get device ID");
let device_id = db.get_device_id().await.expect("Could not get device ID");
assert!(device_id.is_some());
assert_eq!(device_id.unwrap(), "device_id");
assert!(device_id.is_some());
assert_eq!(device_id.unwrap(), "device_id");
})
.await;
}
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn no_device_id_set_returns_none() {
let db = create_db().await;
let device_id = db.get_device_id().await.expect("Could not get device ID");
assert!(device_id.is_none());
with_db(|db| async move {
let device_id = db.get_device_id().await.expect("Could not get device ID");
assert!(device_id.is_none());
})
.await;
}
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn can_update_device_id() {
let db = create_db().await;
with_db(|db| async move {
db.set_device_id("device_id")
.await
.expect("Could not set device ID");
db.set_device_id("device_id")
.await
.expect("Could not set device ID");
db.set_device_id("device_id2")
.await
.expect("Could not set device ID");
db.set_device_id("device_id2")
.await
.expect("Could not set device ID");
let device_id = db.get_device_id().await.expect("Could not get device ID");
let device_id = db.get_device_id().await.expect("Could not get device ID");
assert!(device_id.is_some());
assert_eq!(device_id.unwrap(), "device_id2");
assert!(device_id.is_some());
assert_eq!(device_id.unwrap(), "device_id2");
})
.await;
}
}

View File

@ -0,0 +1,361 @@
use super::Database;
use crate::db::{errors::DataError, Users};
use crate::error::BotError;
use crate::models::User;
use async_trait::async_trait;
#[async_trait]
impl Users for Database {
async fn upsert_user(&self, user: &User) -> Result<(), DataError> {
let mut tx = self.conn.begin().await?;
sqlx::query!(
r#"INSERT INTO accounts (user_id, password, account_status)
VALUES (?, ?, ?)
ON CONFLICT(user_id) DO
UPDATE SET password = ?, account_status = ?"#,
user.username,
user.password,
user.account_status,
user.password,
user.account_status
)
.execute(&mut tx)
.await?;
sqlx::query!(
r#"INSERT INTO user_state (user_id, active_room)
VALUES (?, ?)
ON CONFLICT(user_id) DO
UPDATE SET active_room = ?"#,
user.username,
user.active_room,
user.active_room
)
.execute(&mut tx)
.await?;
tx.commit().await?;
Ok(())
}
async fn delete_user(&self, username: &str) -> Result<(), DataError> {
let mut tx = self.conn.begin().await?;
sqlx::query!(r#"DELETE FROM accounts WHERE user_id = ?"#, username)
.execute(&mut tx)
.await?;
sqlx::query!(r#"DELETE FROM user_state WHERE user_id = ?"#, username)
.execute(&mut tx)
.await?;
tx.commit().await?;
Ok(())
}
async fn get_user(&self, username: &str) -> Result<Option<User>, DataError> {
// Should be query_as! macro, but the left join breaks it with a
// non existing error message.
let user_row: Option<User> = sqlx::query_as(
r#"SELECT
a.user_id as "username",
a.password,
s.active_room,
COALESCE(a.account_status, 'not_registered') as "account_status"
FROM accounts a
LEFT JOIN user_state s on a.user_id = s.user_id
WHERE a.user_id = ?"#,
)
.bind(username)
.fetch_optional(&self.conn)
.await?;
Ok(user_row)
}
async fn authenticate_user(
&self,
username: &str,
raw_password: &str,
) -> Result<Option<User>, BotError> {
let user = self.get_user(username).await?;
Ok(user.filter(|u| u.verify_password(raw_password)))
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::db::sqlite::Database;
use crate::db::Users;
use crate::models::AccountStatus;
use std::future::Future;
async fn with_db<Fut>(f: impl FnOnce(Database) -> Fut)
where
Fut: Future<Output = ()>,
{
let db_path = tempfile::NamedTempFile::new_in(".").unwrap();
crate::db::sqlite::migrator::migrate(db_path.path().to_str().unwrap())
.await
.unwrap();
let db = Database::new(db_path.path().to_str().unwrap())
.await
.unwrap();
f(db).await;
}
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn create_and_get_full_user_test() {
with_db(|db| async move {
let insert_result = db
.upsert_user(&User {
username: "myuser".to_string(),
password: Some("abc".to_string()),
account_status: AccountStatus::Registered,
active_room: Some("myroom".to_string()),
})
.await;
assert!(insert_result.is_ok());
let user = db
.get_user("myuser")
.await
.expect("User retrieval query failed");
assert!(user.is_some());
let user = user.unwrap();
assert_eq!(user.username, "myuser");
assert_eq!(user.password, Some("abc".to_string()));
assert_eq!(user.account_status, AccountStatus::Registered);
assert_eq!(user.active_room, Some("myroom".to_string()));
})
.await;
}
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn can_get_user_with_no_state_record() {
with_db(|db| async move {
let insert_result = db
.upsert_user(&User {
username: "myuser".to_string(),
password: Some("abc".to_string()),
account_status: AccountStatus::AwaitingActivation,
active_room: Some("myroom".to_string()),
})
.await;
assert!(insert_result.is_ok());
sqlx::query("DELETE FROM user_state")
.execute(&db.conn)
.await
.expect("Could not delete from user_state table.");
let user = db
.get_user("myuser")
.await
.expect("User retrieval query failed");
assert!(user.is_some());
let user = user.unwrap();
assert_eq!(user.username, "myuser");
assert_eq!(user.password, Some("abc".to_string()));
assert_eq!(user.account_status, AccountStatus::AwaitingActivation);
//These should be default values because the state record is missing.
assert_eq!(user.active_room, None);
})
.await;
}
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn can_insert_without_password() {
with_db(|db| async move {
let insert_result = db
.upsert_user(&User {
username: "myuser".to_string(),
password: None,
..Default::default()
})
.await;
assert!(insert_result.is_ok());
let user = db
.get_user("myuser")
.await
.expect("User retrieval query failed");
assert!(user.is_some());
let user = user.unwrap();
assert_eq!(user.username, "myuser");
assert_eq!(user.password, None);
})
.await;
}
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn can_insert_without_active_room() {
with_db(|db| async move {
let insert_result = db
.upsert_user(&User {
username: "myuser".to_string(),
active_room: None,
..Default::default()
})
.await;
assert!(insert_result.is_ok());
let user = db
.get_user("myuser")
.await
.expect("User retrieval query failed");
assert!(user.is_some());
let user = user.unwrap();
assert_eq!(user.username, "myuser");
assert_eq!(user.active_room, None);
})
.await;
}
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn can_update_user() {
with_db(|db| async move {
let insert_result1 = db
.upsert_user(&User {
username: "myuser".to_string(),
password: Some("abc".to_string()),
..Default::default()
})
.await;
assert!(insert_result1.is_ok());
let insert_result2 = db
.upsert_user(&User {
username: "myuser".to_string(),
password: Some("123".to_string()),
active_room: Some("room".to_string()),
account_status: AccountStatus::AwaitingActivation,
})
.await;
assert!(insert_result2.is_ok());
let user = db
.get_user("myuser")
.await
.expect("User retrieval query failed");
assert!(user.is_some());
let user = user.unwrap();
assert_eq!(user.username, "myuser");
//From second upsert
assert_eq!(user.password, Some("123".to_string()));
assert_eq!(user.active_room, Some("room".to_string()));
assert_eq!(user.account_status, AccountStatus::AwaitingActivation);
})
.await;
}
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn can_delete_user() {
with_db(|db| async move {
let insert_result = db
.upsert_user(&User {
username: "myuser".to_string(),
password: Some("abc".to_string()),
..Default::default()
})
.await;
assert!(insert_result.is_ok());
db.delete_user("myuser")
.await
.expect("User deletion query failed");
let user = db
.get_user("myuser")
.await
.expect("User retrieval query failed");
assert!(user.is_none());
})
.await;
}
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn username_not_in_db_returns_none() {
with_db(|db| async move {
let user = db
.get_user("does not exist")
.await
.expect("Get user query failure");
assert!(user.is_none());
})
.await;
}
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn authenticate_user_is_some_with_valid_password() {
with_db(|db| async move {
let insert_result = db
.upsert_user(&User {
username: "myuser".to_string(),
password: Some(
crate::logic::hash_password("abc").expect("password hash error!"),
),
..Default::default()
})
.await;
assert!(insert_result.is_ok());
let user = db
.authenticate_user("myuser", "abc")
.await
.expect("User retrieval query failed");
assert!(user.is_some());
let user = user.unwrap();
assert_eq!(user.username, "myuser");
})
.await;
}
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn authenticate_user_is_none_with_wrong_password() {
with_db(|db| async move {
let insert_result = db
.upsert_user(&User {
username: "myuser".to_string(),
password: Some(
crate::logic::hash_password("abc").expect("password hash error!"),
),
..Default::default()
})
.await;
assert!(insert_result.is_ok());
let user = db
.authenticate_user("myuser", "wrong-password")
.await
.expect("User retrieval query failed");
assert!(user.is_none());
})
.await;
}
}

View File

@ -1,13 +1,8 @@
use super::errors::DataError;
use super::{Database, Variables};
use super::Database;
use crate::db::{errors::DataError, Variables};
use async_trait::async_trait;
use std::collections::HashMap;
struct UserVariableRow {
key: String,
value: i32,
}
#[async_trait]
impl Variables for Database {
async fn get_user_variables(
@ -104,148 +99,159 @@ impl Variables for Database {
#[cfg(test)]
mod tests {
use super::super::Variables;
use super::*;
use crate::db::sqlite::Database;
use crate::db::Variables;
use std::future::Future;
async fn create_db() -> Database {
async fn with_db<Fut>(f: impl FnOnce(Database) -> Fut)
where
Fut: Future<Output = ()>,
{
let db_path = tempfile::NamedTempFile::new_in(".").unwrap();
crate::db::sqlite::migrator::migrate(db_path.path().to_str().unwrap())
.await
.unwrap();
Database::new(db_path.path().to_str().unwrap())
let db = Database::new(db_path.path().to_str().unwrap())
.await
.unwrap()
.unwrap();
f(db).await;
}
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn set_and_get_variable_test() {
use super::super::Variables;
let db = create_db().await;
with_db(|db| async move {
db.set_user_variable("myuser", "myroom", "myvariable", 1)
.await
.expect("Could not set variable");
db.set_user_variable("myuser", "myroom", "myvariable", 1)
.await
.expect("Could not set variable");
let value = db
.get_user_variable("myuser", "myroom", "myvariable")
.await
.expect("Could not get variable");
let value = db
.get_user_variable("myuser", "myroom", "myvariable")
.await
.expect("Could not get variable");
assert_eq!(value, 1);
assert_eq!(value, 1);
})
.await;
}
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn get_missing_variable_test() {
use super::super::Variables;
let db = create_db().await;
with_db(|db| async move {
let value = db.get_user_variable("myuser", "myroom", "myvariable").await;
let value = db.get_user_variable("myuser", "myroom", "myvariable").await;
assert!(value.is_err());
assert!(matches!(
value.err().unwrap(),
DataError::KeyDoesNotExist(_)
));
assert!(value.is_err());
assert!(matches!(
value.err().unwrap(),
DataError::KeyDoesNotExist(_)
));
})
.await;
}
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn get_other_user_variable_test() {
use super::super::Variables;
let db = create_db().await;
with_db(|db| async move {
db.set_user_variable("myuser1", "myroom", "myvariable", 1)
.await
.expect("Could not set variable");
db.set_user_variable("myuser1", "myroom", "myvariable", 1)
.await
.expect("Could not set variable");
let value = db
.get_user_variable("myuser2", "myroom", "myvariable")
.await;
let value = db
.get_user_variable("myuser2", "myroom", "myvariable")
.await;
assert!(value.is_err());
assert!(matches!(
value.err().unwrap(),
DataError::KeyDoesNotExist(_)
));
assert!(value.is_err());
assert!(matches!(
value.err().unwrap(),
DataError::KeyDoesNotExist(_)
));
})
.await;
}
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn count_variables_test() {
let db = create_db().await;
with_db(|db| async move {
for variable_name in &["var1", "var2", "var3"] {
db.set_user_variable("myuser", "myroom", variable_name, 1)
.await
.expect("Could not set variable");
}
for variable_name in &["var1", "var2", "var3"] {
db.set_user_variable("myuser", "myroom", variable_name, 1)
let count = db
.get_variable_count("myuser", "myroom")
.await
.expect("Could not set variable");
}
.expect("Could not get count.");
let count = db
.get_variable_count("myuser", "myroom")
.await
.expect("Could not get count.");
assert_eq!(count, 3);
assert_eq!(count, 3);
})
.await;
}
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn count_variables_respects_user_id() {
let db = create_db().await;
with_db(|db| async move {
for variable_name in &["var1", "var2", "var3"] {
db.set_user_variable("different-user", "myroom", variable_name, 1)
.await
.expect("Could not set variable");
}
for variable_name in &["var1", "var2", "var3"] {
db.set_user_variable("different-user", "myroom", variable_name, 1)
let count = db
.get_variable_count("myuser", "myroom")
.await
.expect("Could not set variable");
}
.expect("Could not get count.");
let count = db
.get_variable_count("myuser", "myroom")
.await
.expect("Could not get count.");
assert_eq!(count, 0);
assert_eq!(count, 0);
})
.await;
}
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn count_variables_respects_room_id() {
let db = create_db().await;
with_db(|db| async move {
for variable_name in &["var1", "var2", "var3"] {
db.set_user_variable("myuser", "different-room", variable_name, 1)
.await
.expect("Could not set variable");
}
for variable_name in &["var1", "var2", "var3"] {
db.set_user_variable("myuser", "different-room", variable_name, 1)
let count = db
.get_variable_count("myuser", "myroom")
.await
.expect("Could not set variable");
}
.expect("Could not get count.");
let count = db
.get_variable_count("myuser", "myroom")
.await
.expect("Could not get count.");
assert_eq!(count, 0);
assert_eq!(count, 0);
})
.await;
}
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn delete_variable_test() {
let db = create_db().await;
with_db(|db| async move {
for variable_name in &["var1", "var2", "var3"] {
db.set_user_variable("myuser", "myroom", variable_name, 1)
.await
.expect("Could not set variable");
}
for variable_name in &["var1", "var2", "var3"] {
db.set_user_variable("myuser", "myroom", variable_name, 1)
db.delete_user_variable("myuser", "myroom", "var1")
.await
.expect("Could not set variable");
}
.expect("Could not delete variable.");
db.delete_user_variable("myuser", "myroom", "var1")
.await
.expect("Could not delete variable.");
let count = db
.get_variable_count("myuser", "myroom")
.await
.expect("Could not get count");
let count = db
.get_variable_count("myuser", "myroom")
.await
.expect("Could not get count");
assert_eq!(count, 2);
assert_eq!(count, 2);
let var1 = db.get_user_variable("myuser", "myroom", "var1").await;
assert!(var1.is_err());
assert!(matches!(var1.err().unwrap(), DataError::KeyDoesNotExist(_)));
let var1 = db.get_user_variable("myuser", "myroom", "var1").await;
assert!(var1.is_err());
assert!(matches!(var1.err().unwrap(), DataError::KeyDoesNotExist(_)));
})
.await;
}
}

View File

@ -1,7 +1,10 @@
use std::net::AddrParseError;
use crate::commands::CommandError;
use crate::config::ConfigError;
use crate::db::errors::DataError;
use crate::{commands::CommandError, db::sqlite::migrator};
use thiserror::Error;
use tonic::metadata::errors::InvalidMetadataValue;
#[derive(Error, Debug)]
pub enum BotError {
@ -15,15 +18,18 @@ pub enum BotError {
#[error("could not retrieve device id")]
NoDeviceIdFound,
#[error("could not build client: {0}")]
ClientBuildError(#[from] matrix_sdk::ClientBuildError),
#[error("could not open matrix store: {0}")]
OpenStoreError(#[from] matrix_sdk::store::OpenStoreError),
#[error("command error: {0}")]
CommandError(#[from] CommandError),
#[error("database error: {0}")]
DataError(#[from] DataError),
#[error("sqlite database error: {0}")]
SqliteDataError(#[from] crate::db::sqlite::errors::DataError),
#[error("the message should not be processed because it failed validation")]
ShouldNotProcessError,
@ -33,15 +39,15 @@ pub enum BotError {
#[error("could not parse URL")]
UrlParseError(#[from] url::ParseError),
#[error("could not parse ID")]
IdParseError(#[from] matrix_sdk::ruma::IdParseError),
#[error("error in matrix state store: {0}")]
MatrixStateStoreError(#[from] matrix_sdk::StoreError),
#[error("uncategorized matrix SDK error: {0}")]
MatrixError(#[from] matrix_sdk::Error),
#[error("uncategorized matrix SDK base error: {0}")]
MatrixBaseError(#[from] matrix_sdk::BaseError),
#[error("future canceled")]
FutureCanceledError,
@ -53,7 +59,7 @@ pub enum BotError {
IoError(#[from] std::io::Error),
#[error("dice parsing error: {0}")]
DiceParsingError(#[from] crate::parser::DiceParsingError),
DiceParsingError(#[from] crate::parser::dice::DiceParsingError),
#[error("command parsing error: {0}")]
CommandParsingError(#[from] crate::commands::parser::CommandParsingError),
@ -62,7 +68,7 @@ pub enum BotError {
DiceRollingError(#[from] DiceRollingError),
#[error("variable parsing error: {0}")]
VariableParsingError(#[from] crate::variables::VariableParsingError),
VariableParsingError(#[from] crate::parser::variables::VariableParsingError),
#[error("legacy parsing error")]
NomParserError(nom::error::ErrorKind),
@ -73,17 +79,38 @@ pub enum BotError {
#[error("variables not yet supported")]
VariablesNotSupported,
#[error("database error")]
DatabaseError(#[from] sled::Error),
#[error("database migration error: {0}")]
SqliteError(#[from] migrator::MigrationError),
#[error("too many commands or message was too large")]
MessageTooLarge,
#[error("could not convert to proper integer type")]
TryFromIntError(#[from] std::num::TryFromIntError),
// #[error("identifier error: {0}")]
// IdentifierError(#[from] matrix_sdk::ruma::Error),
#[error("password creation error: {0}")]
PasswordCreationError(argon2::Error),
#[error("account does not exist, or password incorrect")]
AuthenticationError,
#[error("user account does not exist, try registering")]
AccountDoesNotExist,
#[error("user account already exists")]
AccountAlreadyExists,
#[error("room name or id does not exist")]
RoomDoesNotExist,
#[error("tonic transport error: {0}")]
TonicTransportError(#[from] tonic::transport::Error),
#[error("address parsing error: {0}")]
AddressParseError(#[from] AddrParseError),
#[error("invalid metadata value: {0}")]
TonicInvalidMetadata(#[from] InvalidMetadataValue),
}
#[derive(Error, Debug)]

View File

@ -6,6 +6,9 @@ pub fn parse_help_topic(input: &str) -> Option<HelpTopic> {
"dicepool" => Some(HelpTopic::DicePool),
"dice" => Some(HelpTopic::RollingDice),
"cthulhu" => Some(HelpTopic::Cthulhu),
"variables" => Some(HelpTopic::Variables),
"var" => Some(HelpTopic::Variables),
"variable" => Some(HelpTopic::Variables),
"" => Some(HelpTopic::General),
_ => None,
}
@ -16,6 +19,7 @@ pub enum HelpTopic {
DicePool,
Cthulhu,
RollingDice,
Variables,
General,
}
@ -101,6 +105,34 @@ Note: If !cthadv is given a variable, and the roll is successful, it will
update the variable with the new skill.
"};
const VARIABLES_HELP: &'static str = indoc! {"
Variables
Commands: !get, !set, !variables
Manage variables that can be substituted into roll commands.
Examples: !get myvar, !set myvar 10
!get <variable> = show variable of the given name
!set <variable> <num> = set a variable to a number
The !variables command will list all variables for the room. The
variables command cna be used in a secure room to avoid spamming the
actual room that the variable is set in.
Variable names can be used in all types of dice rolls:
!pool myvar + 3
!roll myvar
There are some limitations on variables: they cannot themselves be
dice expressions (i.e. can only be numbers), and they must be uniquely
parseable in an expression (i.e 'myvard6' does not work for the !roll
command).
"};
const GENERAL_HELP: &'static str = indoc! {"
General Help
@ -117,6 +149,7 @@ impl HelpTopic {
HelpTopic::DicePool => DICEPOOL_HELP,
HelpTopic::Cthulhu => CTHULHU_HELP,
HelpTopic::RollingDice => DICE_HELP,
HelpTopic::Variables => VARIABLES_HELP,
HelpTopic::General => GENERAL_HELP,
}
}

View File

@ -6,12 +6,12 @@ pub mod config;
pub mod context;
pub mod cthulhu;
pub mod db;
pub mod dice;
pub mod error;
mod help;
pub mod logic;
pub mod matrix;
pub mod models;
mod parser;
pub mod rpc;
pub mod state;
pub mod variables;
pub mod systems;

131
dicebot/src/logic.rs Normal file
View File

@ -0,0 +1,131 @@
use crate::error::{BotError, DiceRollingError};
use crate::parser::dice::{Amount, Element};
use crate::{context::Context, models::Account};
use crate::{
db::{sqlite::Database, Users, Variables},
models::TransientUser,
};
use argon2::{self, Config, Error as ArgonError};
use futures::stream::{self, StreamExt, TryStreamExt};
use rand::Rng;
use std::slice;
/// Calculate the amount of dice to roll by consulting the database
/// and replacing variables with corresponding the amount. Errors out
/// if it cannot find a variable defined, or if the database errors.
pub async fn calculate_single_die_amount(
amount: &Amount,
ctx: &Context<'_>,
) -> Result<i32, BotError> {
calculate_dice_amount(slice::from_ref(amount), ctx).await
}
/// Calculate the amount of dice to roll by consulting the database
/// and replacing variables with corresponding amounts. Errors out if
/// it cannot find a variable defined, or if the database errors.
pub async fn calculate_dice_amount(amounts: &[Amount], ctx: &Context<'_>) -> Result<i32, BotError> {
let stream = stream::iter(amounts);
let variables = &ctx
.db
.get_user_variables(&ctx.username, ctx.active_room_id().as_str())
.await?;
use DiceRollingError::VariableNotFound;
let dice_amount: i32 = stream
.then(|amount| async move {
match &amount.element {
Element::Number(num_dice) => Ok(num_dice * amount.operator.mult()),
Element::Variable(variable) => variables
.get(variable)
.ok_or_else(|| VariableNotFound(variable.clone()))
.map(|i| *i),
}
})
.try_fold(0, |total, num_dice| async move { Ok(total + num_dice) })
.await?;
Ok(dice_amount)
}
/// Hash a password using the argon2 algorithm with a 16 byte salt.
pub(crate) fn hash_password(raw_password: &str) -> Result<String, ArgonError> {
let salt = rand::thread_rng().gen::<[u8; 16]>();
let config = Config::default();
argon2::hash_encoded(raw_password.as_bytes(), &salt, &config)
}
pub(crate) async fn get_account(db: &Database, username: &str) -> Result<Account, BotError> {
Ok(db
.get_user(username)
.await?
.map(|user| Account::Registered(user))
.unwrap_or_else(|| {
Account::Transient(TransientUser {
username: username.to_owned(),
})
}))
}
#[cfg(test)]
mod tests {
use super::*;
use crate::db::Users;
use crate::models::{AccountStatus, User};
use std::future::Future;
async fn with_db<Fut>(f: impl FnOnce(Database) -> Fut)
where
Fut: Future<Output = ()>,
{
let db_path = tempfile::NamedTempFile::new_in(".").unwrap();
crate::db::sqlite::migrator::migrate(db_path.path().to_str().unwrap())
.await
.unwrap();
let db = Database::new(db_path.path().to_str().unwrap())
.await
.unwrap();
f(db).await;
}
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn get_account_no_user_exists() {
with_db(|db| async move {
let account = get_account(&db, "@test:example.com")
.await
.expect("Account retrieval didn't work");
assert!(matches!(account, Account::Transient(_)));
let user = account.transient_user().unwrap();
assert_eq!(user.username, "@test:example.com");
})
.await;
}
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn get_or_create_user_when_user_exists() {
with_db(|db| async move {
let user = User {
username: "myuser".to_string(),
password: Some("abc".to_string()),
account_status: AccountStatus::Registered,
active_room: Some("myroom".to_string()),
};
let insert_result = db.upsert_user(&user).await;
assert!(insert_result.is_ok());
let account = get_account(&db, "myuser")
.await
.expect("Account retrieval did not work");
assert!(matches!(account, Account::Registered(_)));
let user_again = account.registered_user().unwrap();
assert_eq!(user, *user_again);
})
.await;
}
}

113
dicebot/src/matrix.rs Normal file
View File

@ -0,0 +1,113 @@
use std::path::PathBuf;
use futures::stream::{self, StreamExt, TryStreamExt};
use log::error;
use matrix_sdk::ruma::events::room::message::{InReplyTo, RoomMessageEventContent, Relation};
use matrix_sdk::ruma::events::AnyMessageLikeEventContent;
use matrix_sdk::ruma::{RoomId, OwnedEventId, OwnedUserId};
use matrix_sdk::Client;
use matrix_sdk::Error as MatrixError;
use matrix_sdk::room::Joined;
use url::Url;
use crate::{config::Config, error::BotError};
fn cache_dir() -> Result<PathBuf, BotError> {
let mut dir = dirs::cache_dir().ok_or(BotError::NoCacheDirectoryError)?;
dir.push("matrix-dicebot");
Ok(dir)
}
/// Extracts more detailed error messages out of a matrix SDK error.
fn extract_error_message(error: MatrixError) -> String {
use matrix_sdk::{Error::Http, HttpError};
if let Http(HttpError::Api(ruma_err)) = error {
ruma_err.to_string()
} else {
error.to_string()
}
}
/// Creates the matrix client.
pub async fn create_client(config: &Config) -> Result<Client, BotError> {
let cache_dir = cache_dir()?;
let homeserver_url = Url::parse(&config.matrix_homeserver())?;
let client = Client::builder()
.sled_store(cache_dir, None)?
.homeserver_url(homeserver_url).build()
.await?;
Ok(client)
}
/// Retrieve a list of users in a given room.
pub async fn get_users_in_room(
client: &Client,
room_id: &RoomId,
) -> Result<Vec<String>, MatrixError> {
if let Some(joined_room) = client.get_joined_room(room_id) {
let members = joined_room.joined_members().await?;
Ok(members
.into_iter()
.map(|member| member.user_id().to_string())
.collect())
} else {
Ok(vec![])
}
}
pub async fn get_rooms_for_user(
client: &Client,
user: &OwnedUserId,
) -> Result<Vec<Joined>, MatrixError> {
// Carries errors through, in case we cannot load joined user IDs
// from the room for some reason.
let user_is_in_room = |room: Joined| async move {
match room.joined_user_ids().await {
Ok(users) => match users.contains(user) {
true => Some(Ok(room)),
false => None,
},
Err(e) => Some(Err(e)),
}
};
let rooms_for_user: Vec<Joined> = stream::iter(client.joined_rooms())
.filter_map(user_is_in_room)
.try_collect()
.await?;
Ok(rooms_for_user)
}
/// Send a message. The message is a tuple of HTML and plain text
/// responses.
pub async fn send_message(
client: &Client,
room_id: &RoomId,
message: (&str, &str),
reply_to: Option<OwnedEventId>,
) {
let (html, plain) = message;
let room = match client.get_joined_room(room_id) {
Some(room) => room,
_ => return,
};
let mut content = RoomMessageEventContent::notice_html(plain.trim(), html);
content.relates_to = reply_to.map(|event_id| Relation::Reply {
in_reply_to: InReplyTo::new(event_id)
});
let content = AnyMessageLikeEventContent::RoomMessage(content);
let result = room.send(content, None).await;
if let Err(e) = result {
let html = extract_error_message(e);
error!("Error sending html: {}", html);
};
}

157
dicebot/src/models.rs Normal file
View File

@ -0,0 +1,157 @@
use serde::{Deserialize, Serialize};
/// RoomInfo has basic metadata about a room: its name, ID, etc.
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
pub struct RoomInfo {
pub room_id: String,
pub room_name: String,
}
#[derive(Eq, PartialEq, Clone, Copy, Debug, sqlx::Type)]
#[sqlx(rename_all = "snake_case")]
pub enum AccountStatus {
/// Account is not registered, which means a transient "account"
/// with limited information exists only for the duration of the
/// command request.
NotRegistered,
/// User account is fully registered, either via Matrix directly,
/// or a web UI sign-up.
Registered,
/// Account is awaiting activation with a registration
/// code. Account cannot do privileged actions yet.
AwaitingActivation,
}
impl Default for AccountStatus {
fn default() -> Self {
AccountStatus::NotRegistered
}
}
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum Account {
/// A registered user account, stored in the database.
Registered(User),
/// A transient account. Not stored in the database. Represents a
/// user in a public channel that has not registered directly with
/// the bot yet.
Transient(TransientUser),
}
impl Account {
/// Whether or not this account is a registered user account.
pub fn is_registered(&self) -> bool {
matches!(self, Self::Registered(_))
}
/// Gets the account status. For registered users, this is their
/// actual account status (fully registered or awaiting
/// activation). For transient users, this is
/// AccountStatus::NotRegistered.
pub fn account_status(&self) -> AccountStatus {
match self {
Self::Registered(user) => user.account_status,
Self::Transient(_) => AccountStatus::NotRegistered,
}
}
/// Consume self into an Option<User> instance, which will be Some
/// if this account has a registered user, and None otherwise.
pub fn registered_user(&self) -> Option<&User> {
match self {
Self::Registered(ref user) => Some(user),
_ => None,
}
}
/// Consume self into an Option<TransientUser> instance, which
/// will be Some if this account has a non-registered user, and
/// None otherwise.
pub fn transient_user(self) -> Option<TransientUser> {
match self {
Self::Transient(user) => Some(user),
_ => None,
}
}
}
impl Default for Account {
fn default() -> Self {
Account::Transient(TransientUser {
username: "".to_string(),
})
}
}
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct TransientUser {
pub username: String,
}
#[derive(Eq, PartialEq, Clone, Debug, Default, sqlx::FromRow)]
pub struct User {
pub username: String,
pub password: Option<String>,
pub active_room: Option<String>,
pub account_status: AccountStatus,
}
impl User {
/// Create a new unregistered skeleton marker account for a
/// username.
pub fn unregistered(username: &str) -> User {
User {
username: username.to_owned(),
..Default::default()
}
}
pub fn verify_password(&self, raw_password: &str) -> bool {
self.password
.as_ref()
.and_then(|p| argon2::verify_encoded(p, raw_password.as_bytes()).ok())
.unwrap_or(false)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn verify_password_passes_with_correct_password() {
let user = User {
password: Some(
crate::logic::hash_password("mypassword").expect("Password hashing error!"),
),
..Default::default()
};
assert_eq!(user.verify_password("mypassword"), true);
}
#[test]
fn verify_password_fails_with_wrong_password() {
let user = User {
password: Some(
crate::logic::hash_password("mypassword").expect("Password hashing error!"),
),
..Default::default()
};
assert_eq!(user.verify_password("wrong-password"), false);
}
#[test]
fn verify_password_fails_with_no_password() {
let user = User {
password: None,
..Default::default()
};
assert_eq!(user.verify_password("wrong-password"), false);
}
}

View File

@ -27,7 +27,7 @@ pub enum DiceParsingError {
}
impl From<std::num::ParseIntError> for DiceParsingError {
fn from(_error: std::num::ParseIntError) -> Self {
fn from(_: std::num::ParseIntError) -> Self {
DiceParsingError::ConversionError
}
}
@ -151,8 +151,9 @@ where
/// should not have an operator, but every one after that should.
/// Accepts expressions like "8", "10 + variablename", "variablename -
/// 3", etc. This function is currently common to systems that don't
/// deal with XdY rolls. Support for that will be added later.
pub fn parse_amounts(input: &str) -> ParseResult<Vec<Amount>> {
/// deal with XdY rolls. Support for that will be added later. Returns
/// parsed amounts and unconsumed input (e.g. roll modifiers).
pub fn parse_amounts(input: &str) -> ParseResult<(Vec<Amount>, &str)> {
let input = input.trim();
let remaining_amounts = many(amount_parser()).map(|amounts: Vec<ParseResult<Amount>>| amounts);
@ -169,31 +170,23 @@ pub fn parse_amounts(input: &str) -> ParseResult<Vec<Amount>> {
(amounts, results.1)
})?;
if rest.len() == 0 {
// Any ParseResult errors will short-circuit the collect.
results.into_iter().collect()
} else {
Err(DiceParsingError::UnconsumedInput)
}
// Any ParseResult errors will short-circuit the collect.
let results: Vec<Amount> = results.into_iter().collect::<ParseResult<_>>()?;
Ok((results, rest))
}
/// Parse an expression that expects a single number or variable. No
/// operators are allowed. This function is common to systems that
/// don't deal with XdY rolls. Currently. this function does not
/// support parsing negative numbers.
pub fn parse_single_amount(input: &str) -> ParseResult<Amount> {
/// support parsing negative numbers. Returns the parsed amount and
/// any unconsumed input (useful for dice roll modifiers).
pub fn parse_single_amount(input: &str) -> ParseResult<(Amount, &str)> {
// TODO add support for negative numbers, as technically they
// should be allowed.
let input = input.trim();
let mut parser = first_amount_parser().map(|amount: ParseResult<Amount>| amount);
let (result, rest) = parser.parse(input)?;
if rest.len() == 0 {
result
} else {
Err(DiceParsingError::UnconsumedInput)
}
Ok((result?, rest))
}
#[cfg(test)]
@ -206,10 +199,13 @@ mod parse_single_amount_tests {
assert!(result.is_ok());
assert_eq!(
result.unwrap(),
Amount {
operator: Operator::Plus,
element: Element::Variable("abc".to_string())
}
(
Amount {
operator: Operator::Plus,
element: Element::Variable("abc".to_string())
},
""
)
)
}
@ -233,24 +229,15 @@ mod parse_single_amount_tests {
assert!(result.is_ok());
assert_eq!(
result.unwrap(),
Amount {
operator: Operator::Plus,
element: Element::Number(1)
}
(
Amount {
operator: Operator::Plus,
element: Element::Number(1)
},
""
)
)
}
#[test]
fn parse_multiple_elements_test() {
let result = parse_single_amount("1+abc");
assert!(result.is_err());
let result = parse_single_amount("abc+1");
assert!(result.is_err());
let result = parse_single_amount("-1-abc");
assert!(result.is_err());
}
}
#[cfg(test)]
@ -263,20 +250,26 @@ mod parse_many_amounts_tests {
assert!(result.is_ok());
assert_eq!(
result.unwrap(),
vec![Amount {
operator: Operator::Plus,
element: Element::Number(1)
}]
(
vec![Amount {
operator: Operator::Plus,
element: Element::Number(1)
}],
""
)
);
let result = parse_amounts("10");
assert!(result.is_ok());
assert_eq!(
result.unwrap(),
vec![Amount {
operator: Operator::Plus,
element: Element::Number(10)
}]
(
vec![Amount {
operator: Operator::Plus,
element: Element::Number(10)
}],
""
)
);
}
@ -295,20 +288,26 @@ mod parse_many_amounts_tests {
assert!(result.is_ok());
assert_eq!(
result.unwrap(),
vec![Amount {
operator: Operator::Plus,
element: Element::Variable("asdf".to_string())
}]
(
vec![Amount {
operator: Operator::Plus,
element: Element::Variable("asdf".to_string())
}],
""
)
);
let result = parse_amounts("nosis");
assert!(result.is_ok());
assert_eq!(
result.unwrap(),
vec![Amount {
operator: Operator::Plus,
element: Element::Variable("nosis".to_string())
}]
(
vec![Amount {
operator: Operator::Plus,
element: Element::Variable("nosis".to_string())
}],
""
)
);
}

View File

@ -0,0 +1,2 @@
pub mod dice;
pub mod variables;

50
dicebot/src/rpc/mod.rs Normal file
View File

@ -0,0 +1,50 @@
use crate::error::BotError;
use crate::{config::Config, db::sqlite::Database};
use log::{info, warn};
use matrix_sdk::Client;
use service::DicebotRpcService;
use std::sync::Arc;
use tenebrous_rpc::protos::dicebot::dicebot_server::DicebotServer;
use tonic::{metadata::MetadataValue, transport::Server, Request, Status};
pub(crate) mod service;
pub async fn serve_grpc(
config: &Arc<Config>,
db: &Database,
client: &Client,
) -> Result<(), BotError> {
match config.rpc_addr().zip(config.rpc_key()) {
Some((addr, rpc_key)) => {
let expected_bearer = MetadataValue::from_str(&format!("Bearer {}", rpc_key))?;
let addr = addr.parse()?;
let rpc_service = DicebotRpcService {
db: db.clone(),
config: config.clone(),
client: client.clone(),
};
info!("Serving Dicebot gRPC service on {}", addr);
let interceptor = move |req: Request<()>| match req.metadata().get("authorization") {
Some(bearer) if bearer == expected_bearer => Ok(req),
_ => Err(Status::unauthenticated("No valid auth token")),
};
let server = DicebotServer::with_interceptor(rpc_service, interceptor);
Server::builder()
.add_service(server)
.serve(addr)
.await
.map_err(|e| e.into())
}
_ => noop().await,
}
}
pub async fn noop() -> Result<(), BotError> {
warn!("RPC address or shared secret not specified. Not enabling gRPC.");
Ok(())
}

117
dicebot/src/rpc/service.rs Normal file
View File

@ -0,0 +1,117 @@
use crate::db::{errors::DataError, Variables};
use crate::error::BotError;
use crate::matrix;
use crate::{config::Config, db::sqlite::Database};
use futures::stream;
use futures::{StreamExt, TryFutureExt, TryStreamExt};
use matrix_sdk::ruma::OwnedUserId;
use matrix_sdk::{room::Joined, Client};
use std::convert::TryFrom;
use std::sync::Arc;
use tenebrous_rpc::protos::dicebot::{
dicebot_server::Dicebot, rooms_list_reply::Room, GetAllVariablesReply, GetAllVariablesRequest,
RoomsListReply, SetVariableReply, SetVariableRequest, UserIdRequest,
};
use tenebrous_rpc::protos::dicebot::{GetVariableReply, GetVariableRequest};
use tonic::{Code, Request, Response, Status};
impl From<BotError> for Status {
fn from(error: BotError) -> Status {
Status::new(Code::Internal, error.to_string())
}
}
impl From<DataError> for Status {
fn from(error: DataError) -> Status {
Status::new(Code::Internal, error.to_string())
}
}
#[derive(Clone)]
pub(super) struct DicebotRpcService {
pub(super) config: Arc<Config>,
pub(super) db: Database,
pub(super) client: Client,
}
#[tonic::async_trait]
impl Dicebot for DicebotRpcService {
async fn set_variable(
&self,
request: Request<SetVariableRequest>,
) -> Result<Response<SetVariableReply>, Status> {
let SetVariableRequest {
user_id,
room_id,
variable_name,
value,
} = request.into_inner();
self.db
.set_user_variable(&user_id, &room_id, &variable_name, value)
.await?;
Ok(Response::new(SetVariableReply { success: true }))
}
async fn get_variable(
&self,
request: Request<GetVariableRequest>,
) -> Result<Response<GetVariableReply>, Status> {
let request = request.into_inner();
let value = self
.db
.get_user_variable(&request.user_id, &request.room_id, &request.variable_name)
.await?;
Ok(Response::new(GetVariableReply { value }))
}
async fn get_all_variables(
&self,
request: Request<GetAllVariablesRequest>,
) -> Result<Response<GetAllVariablesReply>, Status> {
let request = request.into_inner();
let variables = self
.db
.get_user_variables(&request.user_id, &request.room_id)
.await?;
Ok(Response::new(GetAllVariablesReply { variables }))
}
async fn rooms_for_user(
&self,
request: Request<UserIdRequest>,
) -> Result<Response<RoomsListReply>, Status> {
let UserIdRequest { user_id } = request.into_inner();
let user_id = OwnedUserId::try_from(user_id).map_err(BotError::from)?;
let rooms_for_user = matrix::get_rooms_for_user(&self.client, &user_id)
.err_into::<BotError>()
.await?;
let mut rooms: Vec<Room> = stream::iter(rooms_for_user)
.filter_map(|room: Joined| async move {
let room: Result<Room, _> = room.display_name().await.map(|room_name| Room {
room_id: room.room_id().to_string(),
display_name: room_name.to_string(),
});
Some(room)
})
.err_into::<BotError>()
.try_collect()
.await?;
let sort = |r1: &Room, r2: &Room| {
r1.display_name
.to_lowercase()
.cmp(&r2.display_name.to_lowercase())
};
rooms.sort_by(sort);
Ok(Response::new(RoomsListReply { rooms }))
}
}

View File

@ -0,0 +1,21 @@
use strum::{AsRefStr, Display, EnumIter, EnumString};
#[derive(EnumString, EnumIter, AsRefStr, Display)]
pub(crate) enum GameSystem {
ChroniclesOfDarkness,
Changeling,
MageTheAwakening,
WerewolfTheForsaken,
DeviantTheRenegades,
MummyTheCurse,
PrometheanTheCreated,
CallOfCthulhu,
DungeonsAndDragons5e,
DungeonsAndDragons4e,
DungeonsAndDragons35e,
DungeonsAndDragons2e,
DungeonsAndDragons1e,
None,
}
impl GameSystem {}

18
rpc/Cargo.toml Normal file
View File

@ -0,0 +1,18 @@
[package]
name = "tenebrous-rpc"
version = "0.1.0"
authors = ["projectmoon <projectmoon@agnos.is>"]
edition = "2018"
description = "gRPC protobuf models for Tenebrous."
homepage = "https://git.agnos.is/projectmoon/tenebrous-dicebot"
repository = "https://git.agnos.is/projectmoon/tenebrous-dicebot"
license = "AGPL-3.0-or-later"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[build-dependencies]
tonic-build = "0.4"
[dependencies]
tonic = "0.4"
prost = "0.7"

4
rpc/build.rs Normal file
View File

@ -0,0 +1,4 @@
fn main() -> Result<(), Box<dyn std::error::Error>> {
tonic_build::compile_protos("protos/dicebot.proto")?;
Ok(())
}

52
rpc/protos/dicebot.proto Normal file
View File

@ -0,0 +1,52 @@
syntax = "proto3";
package dicebot;
service Dicebot {
rpc GetVariable(GetVariableRequest) returns (GetVariableReply);
rpc GetAllVariables(GetAllVariablesRequest) returns (GetAllVariablesReply);
rpc SetVariable(SetVariableRequest) returns (SetVariableReply);
rpc RoomsForUser(UserIdRequest) returns (RoomsListReply);
}
message GetVariableRequest {
string user_id = 1;
string room_id = 2;
string variable_name = 3;
}
message GetVariableReply {
int32 value = 1;
}
message GetAllVariablesRequest {
string user_id = 1;
string room_id = 2;
}
message GetAllVariablesReply {
map<string, int32> variables = 1;
}
message SetVariableRequest {
string user_id = 1;
string room_id = 2;
string variable_name = 3;
int32 value = 4;
}
message SetVariableReply {
bool success = 1;
}
message UserIdRequest {
string user_id = 1;
}
message RoomsListReply {
message Room {
string room_id = 1;
string display_name = 2;
}
repeated Room rooms = 1;
}

5
rpc/src/lib.rs Normal file
View File

@ -0,0 +1,5 @@
pub mod protos {
pub mod dicebot {
tonic::include_proto!("dicebot");
}
}

View File

@ -1,189 +0,0 @@
/**
* In addition to the terms of the AGPL, this file is governed by the
* terms of the MIT license, from the original axfive-matrix-dicebot
* project.
*/
use nom::bytes::complete::take_while;
use nom::{
alt, bytes::complete::tag, character::complete::digit1, complete, many0, named,
sequence::tuple, tag, IResult,
};
use super::dice::*;
//******************************
//Legacy Code
//******************************
fn is_whitespace(input: char) -> bool {
input == ' ' || input == '\n' || input == '\t' || input == '\r'
}
/// Eat whitespace, returning it
pub fn eat_whitespace(input: &str) -> IResult<&str, &str> {
let (input, whitespace) = take_while(is_whitespace)(input)?;
Ok((input, whitespace))
}
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
enum Sign {
Plus,
Minus,
}
// Parse a dice expression. Does not eat whitespace
fn parse_dice(input: &str) -> IResult<&str, Dice> {
let (input, (count, _, sides)) = tuple((digit1, tag("d"), digit1))(input)?;
Ok((
input,
Dice::new(count.parse().unwrap(), sides.parse().unwrap()),
))
}
// Parse a single digit expression. Does not eat whitespace
fn parse_bonus(input: &str) -> IResult<&str, u32> {
let (input, bonus) = digit1(input)?;
Ok((input, bonus.parse().unwrap()))
}
// Parse a sign expression. Eats whitespace.
fn parse_sign(input: &str) -> IResult<&str, Sign> {
let (input, _) = eat_whitespace(input)?;
named!(sign(&str) -> Sign, alt!(
complete!(tag!("+")) => { |_| Sign::Plus } |
complete!(tag!("-")) => { |_| Sign::Minus }
));
let (input, sign) = sign(input)?;
Ok((input, sign))
}
// Parse an element expression. Eats whitespace.
fn parse_element(input: &str) -> IResult<&str, Element> {
let (input, _) = eat_whitespace(input)?;
named!(element(&str) -> Element, alt!(
parse_dice => { |d| Element::Dice(d) } |
parse_bonus => { |b| Element::Bonus(b) }
));
let (input, element) = element(input)?;
Ok((input, element))
}
// Parse a signed element expression. Eats whitespace.
fn parse_signed_element(input: &str) -> IResult<&str, SignedElement> {
let (input, _) = eat_whitespace(input)?;
let (input, sign) = parse_sign(input)?;
let (input, _) = eat_whitespace(input)?;
let (input, element) = parse_element(input)?;
let element = match sign {
Sign::Plus => SignedElement::Positive(element),
Sign::Minus => SignedElement::Negative(element),
};
Ok((input, element))
}
// Parse a full element expression. Eats whitespace.
pub fn parse_element_expression(input: &str) -> IResult<&str, ElementExpression> {
named!(first_element(&str) -> SignedElement, alt!(
parse_signed_element => { |e| e } |
parse_element => { |e| SignedElement::Positive(e) }
));
let (input, first) = first_element(input)?;
let (input, rest) = if input.trim().is_empty() {
(input, vec![first])
} else {
named!(rest_elements(&str) -> Vec<SignedElement>, many0!(parse_signed_element));
let (input, mut rest) = rest_elements(input)?;
rest.insert(0, first);
(input, rest)
};
Ok((input, ElementExpression(rest)))
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn dice_test() {
assert_eq!(parse_dice("2d4"), Ok(("", Dice::new(2, 4))));
assert_eq!(parse_dice("20d40"), Ok(("", Dice::new(20, 40))));
assert_eq!(parse_dice("8d7"), Ok(("", Dice::new(8, 7))));
}
#[test]
fn element_test() {
assert_eq!(
parse_element(" \t\n\r\n 8d7 \n"),
Ok((" \n", Element::Dice(Dice::new(8, 7))))
);
assert_eq!(
parse_element(" \t\n\r\n 8 \n"),
Ok((" \n", Element::Bonus(8)))
);
}
#[test]
fn signed_element_test() {
assert_eq!(
parse_signed_element("+ 7"),
Ok(("", SignedElement::Positive(Element::Bonus(7))))
);
assert_eq!(
parse_signed_element(" \t\n\r\n- 8 \n"),
Ok((" \n", SignedElement::Negative(Element::Bonus(8))))
);
assert_eq!(
parse_signed_element(" \t\n\r\n- 8d4 \n"),
Ok((
" \n",
SignedElement::Negative(Element::Dice(Dice::new(8, 4)))
))
);
assert_eq!(
parse_signed_element(" \t\n\r\n+ 8d4 \n"),
Ok((
" \n",
SignedElement::Positive(Element::Dice(Dice::new(8, 4)))
))
);
}
#[test]
fn element_expression_test() {
assert_eq!(
parse_element_expression("8d4"),
Ok((
"",
ElementExpression(vec![SignedElement::Positive(Element::Dice(Dice::new(
8, 4
)))])
))
);
assert_eq!(
parse_element_expression(" - 8d4 \n "),
Ok((
" \n ",
ElementExpression(vec![SignedElement::Negative(Element::Dice(Dice::new(
8, 4
)))])
))
);
assert_eq!(
parse_element_expression("\t3d4 + 7 - 5 - 6d12 + 1d1 + 53 1d5 "),
Ok((
" 1d5 ",
ElementExpression(vec![
SignedElement::Positive(Element::Dice(Dice::new(3, 4))),
SignedElement::Positive(Element::Bonus(7)),
SignedElement::Negative(Element::Bonus(5)),
SignedElement::Negative(Element::Dice(Dice::new(6, 12))),
SignedElement::Positive(Element::Dice(Dice::new(1, 1))),
SignedElement::Positive(Element::Bonus(53)),
])
))
);
}
}

View File

@ -1,37 +0,0 @@
use tenebrous_dicebot::db::sqlite::{Database as SqliteDatabase, Variables};
use tenebrous_dicebot::db::Database;
use tenebrous_dicebot::error::BotError;
#[tokio::main]
async fn main() -> Result<(), BotError> {
let sled_path = std::env::args()
.skip(1)
.next()
.expect("Need a path to a Sled database as an arument.");
let sqlite_path = std::env::args()
.skip(2)
.next()
.expect("Need a path to an sqlite database as an arument.");
let db = Database::new(&sled_path)?;
let all_variables = db.variables.get_all_variables()?;
let sql_db = SqliteDatabase::new(&sqlite_path).await?;
for var in all_variables {
if let ((username, room_id, variable_name), value) = var {
println!(
"Migrating {}::{}::{} = {} to sql",
username, room_id, variable_name, value
);
sql_db
.set_user_variable(&username, &room_id, &variable_name, value)
.await;
}
}
Ok(())
}

View File

@ -1,248 +0,0 @@
use crate::commands::{execute_command, ExecutionError, ExecutionResult, ResponseExtractor};
use crate::config::*;
use crate::context::{Context, RoomContext};
use crate::db::sqlite::Database;
use crate::db::sqlite::DbState;
use crate::error::BotError;
use crate::matrix;
use crate::state::DiceBotState;
use dirs;
use futures::stream::{self, StreamExt};
use log::{error, info};
use matrix_sdk::{self, identifiers::EventId, room::Joined, Client, ClientConfig, SyncSettings};
use std::clone::Clone;
use std::path::PathBuf;
use std::sync::{Arc, RwLock};
use url::Url;
pub mod event_handlers;
/// How many commands can be in one message. If the amount is higher
/// than this, we reject execution.
const MAX_COMMANDS_PER_MESSAGE: usize = 50;
/// The DiceBot struct represents an active dice bot. The bot is not
/// connected to Matrix until its run() function is called.
pub struct DiceBot {
/// A reference to the configuration read in on application start.
config: Arc<Config>,
/// The matrix client.
client: Client,
/// State of the dicebot
state: Arc<RwLock<DiceBotState>>,
/// Active database layer
db: Database,
}
fn cache_dir() -> Result<PathBuf, BotError> {
let mut dir = dirs::cache_dir().ok_or(BotError::NoCacheDirectoryError)?;
dir.push("matrix-dicebot");
Ok(dir)
}
/// Creates the matrix client.
fn create_client(config: &Config) -> Result<Client, BotError> {
let cache_dir = cache_dir()?;
//let store = JsonStore::open(&cache_dir)?;
let client_config = ClientConfig::new().store_path(cache_dir);
let homeserver_url = Url::parse(&config.matrix_homeserver())?;
Ok(Client::new_with_config(homeserver_url, client_config)?)
}
/// Handle responding to a single command being executed. Wil print
/// out the full result of that command.
async fn handle_single_result(
client: &Client,
cmd_result: &ExecutionResult,
respond_to: &str,
room: &Joined,
event_id: EventId,
) {
if cmd_result.is_err() {
error!(
"Command execution error: {}",
cmd_result.as_ref().err().unwrap()
);
}
let html = cmd_result.message_html(respond_to);
matrix::send_message(client, room.room_id(), &html, Some(event_id)).await;
}
/// Handle responding to multiple commands being executed. Will print
/// out how many commands succeeded and failed (if any failed).
async fn handle_multiple_results(
client: &Client,
results: &[(String, ExecutionResult)],
respond_to: &str,
room: &Joined,
) {
let respond_to = format!(
"<a href=\"https://matrix.to/#/{}\">{}</a>",
respond_to, respond_to
);
let errors: Vec<(&str, &ExecutionError)> = results
.into_iter()
.filter_map(|(cmd, result)| match result {
Err(e) => Some((cmd.as_ref(), e)),
_ => None,
})
.collect();
for result in errors.iter() {
error!("Command execution error: '{}' - {}", result.0, result.1);
}
let message = if errors.len() == 0 {
format!("{}: Executed {} commands", respond_to, results.len())
} else {
let failures: Vec<String> = errors
.iter()
.map(|&(cmd, err)| format!("<strong>{}:</strong> {}", cmd, err))
.collect();
format!(
"{}: Executed {} commands ({} failed)\n\nFailures:\n{}",
respond_to,
results.len(),
errors.len(),
failures.join("\n")
)
.replace("\n", "<br/>")
};
matrix::send_message(client, room.room_id(), &message, None).await;
}
impl DiceBot {
/// Create a new dicebot with the given configuration and state
/// actor. This function returns a Result because it is possible
/// for client creation to fail for some reason (e.g. invalid
/// homeserver URL).
pub fn new(
config: &Arc<Config>,
state: &Arc<RwLock<DiceBotState>>,
db: &Database,
) -> Result<Self, BotError> {
Ok(DiceBot {
client: create_client(&config)?,
config: config.clone(),
state: state.clone(),
db: db.clone(),
})
}
/// Logs in to matrix and potentially records a new device ID. If
/// no device ID is found in the database, a new one will be
/// generated by the matrix SDK, and we will store it.
async fn login(&self, client: &Client) -> Result<(), BotError> {
let username = self.config.matrix_username();
let password = self.config.matrix_password();
// Pull device ID from database, if it exists. Then write it
// to DB if the library generated one for us.
let device_id: Option<String> = self.db.get_device_id().await?;
let device_id: Option<&str> = device_id.as_deref();
client
.login(username, password, device_id, Some("matrix dice bot"))
.await?;
if device_id.is_none() {
let device_id = client.device_id().await.ok_or(BotError::NoDeviceIdFound)?;
self.db.set_device_id(device_id.as_str()).await?;
info!("Recorded new device ID: {}", device_id.as_str());
} else {
info!("Using existing device ID: {}", device_id.unwrap());
}
info!("Logged in as {}", username);
Ok(())
}
/// Logs the bot in to Matrix and listens for events until program
/// terminated, or a panic occurs. Originally adapted from the
/// matrix-rust-sdk command bot example.
pub async fn run(self) -> Result<(), BotError> {
let client = self.client.clone();
self.login(&client).await?;
client.set_event_handler(Box::new(self)).await;
info!("Listening for commands");
// TODO replace with sync_with_callback for cleaner shutdown
// process.
client.sync(SyncSettings::default()).await;
Ok(())
}
async fn execute_commands(
&self,
room: &Joined,
sender_username: &str,
msg_body: &str,
) -> Vec<(String, ExecutionResult)> {
let room_name: &str = &room.display_name().await.ok().unwrap_or_default();
let commands: Vec<&str> = msg_body
.lines()
.filter(|line| line.starts_with("!"))
.take(MAX_COMMANDS_PER_MESSAGE + 1)
.collect();
//Up to 50 commands allowed, otherwise we send back an error.
let results: Vec<(String, ExecutionResult)> = if commands.len() < MAX_COMMANDS_PER_MESSAGE {
stream::iter(commands)
.then(|command| async move {
let ctx = Context {
db: self.db.clone(),
matrix_client: &self.client,
room: RoomContext::new_with_name(&room, room_name),
username: &sender_username,
message_body: &command,
};
let cmd_result = execute_command(&ctx).await;
info!("[{}] {} executed: {}", room_name, sender_username, command);
(command.to_owned(), cmd_result)
})
.collect()
.await
} else {
vec![(
"".to_owned(),
Err(ExecutionError(BotError::MessageTooLarge)),
)]
};
results
}
pub async fn handle_results(
&self,
room: &Joined,
sender_username: &str,
event_id: EventId,
results: Vec<(String, ExecutionResult)>,
) {
if results.len() >= 1 {
if results.len() == 1 {
handle_single_result(
&self.client,
&results[0].1,
sender_username,
&room,
event_id,
)
.await;
} else if results.len() > 1 {
handle_multiple_results(&self.client, &results, sender_username, &room).await;
}
}
}
}

View File

@ -1,231 +0,0 @@
/**
* In addition to the terms of the AGPL, portions of this file
* are governed by the terms of the MIT license, from the Rust Matrix
* SDK example code.
*/
use super::DiceBot;
use crate::db::sqlite::Database;
use crate::db::sqlite::Rooms;
use crate::error::BotError;
use crate::logic::record_room_information;
use async_trait::async_trait;
use log::{debug, error, info, warn};
use matrix_sdk::{
self,
events::{
room::member::{MemberEventContent, MembershipChange},
room::message::{MessageEventContent, MessageType, TextMessageEventContent},
StrippedStateEvent, SyncMessageEvent, SyncStateEvent,
},
room::Room,
EventHandler,
};
use std::ops::Sub;
use std::time::{Duration, SystemTime};
use std::{clone::Clone, time::UNIX_EPOCH};
/// Check if a message is recent enough to actually process. If the
/// message is within "oldest_message_age" seconds, this function
/// returns true. If it's older than that, it returns false and logs a
/// debug message.
fn check_message_age(
event: &SyncMessageEvent<MessageEventContent>,
oldest_message_age: u64,
) -> bool {
let sending_time = event
.origin_server_ts
.to_system_time()
.unwrap_or(UNIX_EPOCH);
let oldest_timestamp = SystemTime::now().sub(Duration::from_secs(oldest_message_age));
if sending_time > oldest_timestamp {
true
} else {
let age = match oldest_timestamp.duration_since(sending_time) {
Ok(n) => format!("{} seconds too old", n.as_secs()),
Err(_) => "before the UNIX epoch".to_owned(),
};
debug!("Ignoring message because it is {}: {:?}", age, event);
false
}
}
/// Determine whether or not to process a received message. This check
/// is necessary in addition to the event processing check because we
/// may receive message events when entering a room for the first
/// time, and we don't want to respond to things before the bot was in
/// the channel, but we do want to respond to things that were sent if
/// the bot left and rejoined quickly.
async fn should_process_message<'a>(
bot: &DiceBot,
event: &SyncMessageEvent<MessageEventContent>,
) -> Result<(String, String), BotError> {
//Ignore messages that are older than configured duration.
if !check_message_age(event, bot.config.oldest_message_age()) {
let state_check = bot.state.read().unwrap();
if !((*state_check).logged_skipped_old_messages()) {
drop(state_check);
let mut state = bot.state.write().unwrap();
(*state).skipped_old_messages();
}
return Err(BotError::ShouldNotProcessError);
}
let (msg_body, sender_username) = if let SyncMessageEvent {
content:
MessageEventContent {
msgtype: MessageType::Text(TextMessageEventContent { body, .. }),
..
},
sender,
..
} = event
{
(
body.clone(),
format!("@{}:{}", sender.localpart(), sender.server_name()),
)
} else {
(String::new(), String::new())
};
Ok((msg_body, sender_username))
}
async fn should_process_event(db: &Database, room_id: &str, event_id: &str) -> bool {
db.should_process(room_id, event_id)
.await
.unwrap_or_else(|e| {
error!(
"Database error when checking if we should process an event: {}",
e.to_string()
);
false
})
}
/// This event emitter listens for messages with dice rolling commands.
/// Originally adapted from the matrix-rust-sdk examples.
#[async_trait]
impl EventHandler for DiceBot {
async fn on_room_member(&self, room: Room, event: &SyncStateEvent<MemberEventContent>) {
//let room = Common::new(self.client.clone(), room);
let (room_id, room_display_name) = match room.display_name().await {
Ok(display_name) => (room.room_id(), display_name),
_ => return,
};
let room_id_str = room_id.as_str();
let username = &event.state_key;
if !should_process_event(&self.db, room_id_str, event.event_id.as_str()).await {
return;
}
let event_affects_us = if let Some(our_user_id) = self.client.user_id().await {
event.state_key == our_user_id
} else {
false
};
// user_joing is true if a user is joining this room, and
// false if they have left for some reason. This user may be
// us, or another user in the room.
use MembershipChange::*;
let user_joining = match event.membership_change() {
Joined => true,
Banned | Left | Kicked | KickedAndBanned => false,
_ => return,
};
let result = if event_affects_us && !user_joining {
info!("Clearing all information for room ID {}", room_id);
self.db.clear_info(room_id_str).await.map_err(|e| e.into())
} else if event_affects_us && user_joining {
info!("Joined room {}; recording room information", room_id);
record_room_information(
&self.client,
&self.db,
&room_id,
&room_display_name,
&event.state_key,
)
.await
} else if !event_affects_us && user_joining {
info!("Adding user {} to room ID {}", username, room_id);
self.db
.add_user_to_room(username, room_id_str)
.await
.map_err(|e| e.into())
} else if !event_affects_us && !user_joining {
info!("Removing user {} from room ID {}", username, room_id);
self.db
.remove_user_from_room(username, room_id_str)
.await
.map_err(|e| e.into())
} else {
debug!("Ignoring a room member event: {:#?}", event);
Ok(())
};
if let Err(e) = result {
error!("Could not update room information: {}", e.to_string());
} else {
debug!("Successfully processed room member update.");
}
}
async fn on_stripped_state_member(
&self,
room: Room,
event: &StrippedStateEvent<MemberEventContent>,
_: Option<MemberEventContent>,
) {
let room = match room {
Room::Invited(invited_room) => invited_room,
_ => return,
};
if room.own_user_id().as_str() != event.state_key {
return;
}
info!(
"Autojoining room {}",
room.display_name().await.ok().unwrap_or_default()
);
if let Err(e) = self.client.join_room_by_id(&room.room_id()).await {
warn!("Could not join room: {}", e.to_string())
}
}
async fn on_room_message(&self, room: Room, event: &SyncMessageEvent<MessageEventContent>) {
let room = match room {
Room::Joined(joined_room) => joined_room,
_ => return,
};
let room_id = room.room_id().as_str();
if !should_process_event(&self.db, room_id, event.event_id.as_str()).await {
return;
}
let (msg_body, sender_username) =
if let Ok((msg_body, sender_username)) = should_process_message(self, &event).await {
(msg_body, sender_username)
} else {
return;
};
let results = self
.execute_commands(&room, &sender_username, &msg_body)
.await;
self.handle_results(&room, &sender_username, event.event_id.clone(), results)
.await;
}
}

View File

@ -1,158 +0,0 @@
use crate::context::Context;
use crate::error::BotError;
use async_trait::async_trait;
use thiserror::Error;
use BotError::{DataError, SqliteDataError};
pub mod basic_rolling;
pub mod cofd;
pub mod cthulhu;
pub mod management;
pub mod misc;
pub mod parser;
pub mod variables;
/// A custom error type specifically related to parsing command text.
/// Does not wrap an execution failure.
#[derive(Error, Debug)]
pub enum CommandError {
#[error("invalid command: {0}")]
InvalidCommand(String),
#[error("ignored command")]
IgnoredCommand,
}
/// A successfully executed command returns a message to be sent back
/// to the user in HTML (plain text used as a fallback by message
/// formatter).
#[derive(Debug)]
pub struct Execution {
html: String,
}
impl Execution {
pub fn success(html: String) -> ExecutionResult {
Ok(Execution { html })
}
/// Response message in HTML.
pub fn html(&self) -> String {
self.html.clone()
}
}
/// Wraps a command execution failure. Provides HTML formatting for
/// any error message from the BotError type, similar to how Execution
/// provides formatting for successfully executed commands.
#[derive(Error, Debug)]
#[error("{0}")]
pub struct ExecutionError(#[from] pub BotError);
impl From<crate::db::errors::DataError> for ExecutionError {
fn from(error: crate::db::errors::DataError) -> Self {
Self(DataError(error))
}
}
impl From<crate::db::sqlite::errors::DataError> for ExecutionError {
fn from(error: crate::db::sqlite::errors::DataError) -> Self {
Self(SqliteDataError(error))
}
}
impl ExecutionError {
/// Error message in bolded HTML.
pub fn html(&self) -> String {
format!("<p><strong>{}</strong></p>", self.0)
}
}
/// Wraps either a successful command execution response, or an error
/// that occurred.
pub type ExecutionResult = Result<Execution, ExecutionError>;
/// Extract response messages out of a type, whether it is success or
/// failure.
pub trait ResponseExtractor {
/// HTML representation of the message, directly mentioning the
/// username.
fn message_html(&self, username: &str) -> String;
}
impl ResponseExtractor for ExecutionResult {
/// Error message in bolded HTML.
fn message_html(&self, username: &str) -> String {
// TODO use user display name too (element seems to render this
// without display name)
let username = format!(
"<a href=\"https://matrix.to/#/{}\">{}</a>",
username, username
);
match self {
Ok(resp) => format!("<p>{}</p><p>{}</p>", username, resp.html).replace("\n", "<br/>"),
Err(e) => format!("<p>{}</p><p>{}</p>", username, e.html()).replace("\n", "<br/>"),
}
}
}
/// The trait that any command that can be executed must implement.
#[async_trait]
pub trait Command: Send + Sync {
async fn execute(&self, ctx: &Context<'_>) -> ExecutionResult;
fn name(&self) -> &'static str;
}
/// Attempt to execute a command, and return the content that should
/// go back to Matrix, if the command was executed (successfully or
/// not). If a command is determined to be ignored, this function will
/// return None, signifying that we should not send a response.
pub async fn execute_command(ctx: &Context<'_>) -> ExecutionResult {
let cmd = parser::parse_command(&ctx.message_body)?;
cmd.execute(ctx).await
}
#[cfg(test)]
mod tests {
use super::*;
use url::Url;
macro_rules! dummy_room {
() => {
crate::context::RoomContext {
id: &matrix_sdk::identifiers::room_id!("!fakeroomid:example.com"),
display_name: "displayname",
}
};
}
#[test]
fn command_result_extractor_creates_bubble() {
let result = Execution::success("test".to_string());
let message = result.message_html("@myuser:example.com");
assert!(message.contains(
"<a href=\"https://matrix.to/#/@myuser:example.com\">@myuser:example.com</a>"
));
}
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn unrecognized_command() {
let db_path = tempfile::NamedTempFile::new_in(".").unwrap();
let db = crate::db::sqlite::Database::new(db_path.path().to_str().unwrap())
.await
.unwrap();
let homeserver = Url::parse("http://example.com").unwrap();
let ctx = Context {
db: db,
matrix_client: &matrix_sdk::Client::new(homeserver).unwrap(),
room: dummy_room!(),
username: "myusername",
message_body: "!notacommand",
};
let result = execute_command(&ctx).await;
assert!(result.is_err());
}
}

View File

@ -1,24 +0,0 @@
use super::{Command, Execution, ExecutionResult};
use crate::basic::dice::ElementExpression;
use crate::basic::roll::Roll;
use crate::context::Context;
use async_trait::async_trait;
pub struct RollCommand(pub ElementExpression);
#[async_trait]
impl Command for RollCommand {
fn name(&self) -> &'static str {
"roll regular dice"
}
async fn execute(&self, _ctx: &Context<'_>) -> ExecutionResult {
let roll = self.0.roll();
let html = format!(
"<strong>Dice:</strong> {}</p><p><strong>Result</strong>: {}",
self.0, roll
);
Execution::success(html)
}
}

View File

@ -1,31 +0,0 @@
use super::{Command, Execution, ExecutionResult};
use crate::context::Context;
use crate::logic::record_room_information;
use async_trait::async_trait;
use matrix_sdk::identifiers::UserId;
pub struct ResyncCommand;
#[async_trait]
impl Command for ResyncCommand {
fn name(&self) -> &'static str {
"resync room information"
}
async fn execute(&self, ctx: &Context<'_>) -> ExecutionResult {
let our_username: Option<UserId> = ctx.matrix_client.user_id().await;
let our_username: &str = our_username.as_ref().map_or("", UserId::as_str);
record_room_information(
ctx.matrix_client,
&ctx.db,
ctx.room_id(),
&ctx.room.display_name,
our_username,
)
.await?;
let message = "Room information resynced.".to_string();
Execution::success(message)
}
}

View File

@ -1,36 +0,0 @@
use crate::db::sqlite::Database;
use matrix_sdk::identifiers::RoomId;
use matrix_sdk::room::Joined;
use matrix_sdk::Client;
/// A context carried through the system providing access to things
/// like the database.
#[derive(Clone)]
pub struct Context<'a> {
pub db: Database,
pub matrix_client: &'a Client,
pub room: RoomContext<'a>,
pub username: &'a str,
pub message_body: &'a str,
}
impl Context<'_> {
pub fn room_id(&self) -> &RoomId {
self.room.id
}
}
#[derive(Clone)]
pub struct RoomContext<'a> {
pub id: &'a RoomId,
pub display_name: &'a str,
}
impl RoomContext<'_> {
pub fn new_with_name<'a>(room: &'a Joined, display_name: &'a str) -> RoomContext<'a> {
RoomContext {
id: room.room_id(),
display_name,
}
}
}

View File

@ -1,98 +0,0 @@
use crate::db::errors::{DataError, MigrationError};
use crate::db::migrations::{get_migration_version, Migrations};
use crate::db::rooms::Rooms;
use crate::db::state::DbState;
use crate::db::variables::Variables;
use log::info;
use sled::{Config, Db};
use std::path::Path;
pub mod data_migrations;
pub mod errors;
pub mod migrations;
pub mod rooms;
pub mod schema;
pub mod sqlite;
pub mod state;
pub mod variables;
#[derive(Clone)]
pub struct Database {
db: Db,
pub variables: Variables,
pub(crate) migrations: Migrations,
pub(crate) rooms: Rooms,
pub(crate) state: DbState,
}
impl Database {
fn new_db(db: sled::Db) -> Result<Database, DataError> {
let migrations = db.open_tree("migrations")?;
let database = Database {
db: db.clone(),
variables: Variables::new(&db)?,
migrations: Migrations(migrations),
rooms: Rooms::new(&db)?,
state: DbState::new(&db)?,
};
//Start any event handlers.
database.rooms.start_handler();
info!("Opened database.");
Ok(database)
}
pub fn new<P: AsRef<Path>>(path: P) -> Result<Database, DataError> {
let db = sled::open(path)?;
Self::new_db(db)
}
pub fn new_temp() -> Result<Database, DataError> {
let config = Config::new().temporary(true);
let db = config.open()?;
Self::new_db(db)
}
pub fn migrate(&self, to_version: u32) -> Result<(), DataError> {
//get version from db
let db_version = get_migration_version(&self)?;
if db_version < to_version {
info!(
"Migrating database from version {} to version {}",
db_version, to_version
);
//if db version < to_version, proceed
//produce range of db_version+1 .. to_version (inclusive)
let versions_to_run: Vec<u32> = ((db_version + 1)..=to_version).collect();
let migrations = data_migrations::get_migrations(&versions_to_run)?;
//execute each closure.
for (version, migration) in versions_to_run.iter().zip(migrations) {
let (migration_func, name) = migration;
//This needs to be transactional on migrations
//keyspace. abort on migration func error.
info!("Applying migration {} :: {}", version, name);
match migration_func(&self) {
Ok(_) => Ok(()),
Err(e) => Err(e),
}?;
self.migrations.set_migration_version(*version)?;
}
info!("Done applying migrations.");
Ok(())
} else if db_version > to_version {
//if db version > to_version, cannot downgrade error
Err(MigrationError::CannotDowngrade.into())
} else {
//if db version == to_version, do nothing
info!("No database migrations needed.");
Ok(())
}
}
}

View File

@ -1,28 +0,0 @@
use crate::db::errors::{DataError, MigrationError};
use crate::db::variables::migrations::*;
use crate::db::Database;
use phf::phf_map;
pub(super) type DataMigration = (fn(&Database) -> Result<(), DataError>, &'static str);
static MIGRATIONS: phf::Map<u32, DataMigration> = phf_map! {
1u32 => (add_room_user_variable_count::migrate, "add_room_user_variable_count"),
2u32 => (delete_v0_schema, "delete_v0_schema"),
3u32 => (delete_variable_count, "delete_variable_count"),
4u32 => (change_delineator_delimiter::migrate, "change_delineator_delimiter"),
5u32 => (change_tree_structure::migrate, "change_tree_structure"),
};
pub fn get_migrations(versions: &[u32]) -> Result<Vec<DataMigration>, MigrationError> {
let mut migrations: Vec<DataMigration> = vec![];
for version in versions {
match MIGRATIONS.get(version) {
Some(func) => migrations.push(*func),
None => return Err(MigrationError::MigrationNotFound(*version)),
}
}
Ok(migrations)
}

View File

@ -1,81 +0,0 @@
use sled::transaction::{TransactionError, UnabortableTransactionError};
use thiserror::Error;
#[derive(Error, Debug)]
pub enum MigrationError {
#[error("cannot downgrade to an older database version")]
CannotDowngrade,
#[error("migration for version {0} not defined")]
MigrationNotFound(u32),
#[error("migration failed: {0}")]
MigrationFailed(String),
}
//TODO better combining of key and value in certain errors (namely
//I32SchemaViolation).
#[derive(Error, Debug)]
pub enum DataError {
#[error("value does not exist for key: {0}")]
KeyDoesNotExist(String),
#[error("too many entries")]
TooManyEntries,
#[error("expected i32, but i32 schema was violated")]
I32SchemaViolation,
#[error("parse error")]
ParseError(#[from] std::num::ParseIntError),
#[error("unexpected or corruptd data bytes")]
InvalidValue,
#[error("expected string ref, but utf8 schema was violated: {0}")]
Utf8RefSchemaViolation(#[from] std::str::Utf8Error),
#[error("expected string, but utf8 schema was violated: {0}")]
Utf8SchemaViolation(#[from] std::string::FromUtf8Error),
#[error("internal database error: {0}")]
InternalError(#[from] sled::Error),
#[error("transaction error: {0}")]
TransactionError(#[from] sled::transaction::TransactionError),
#[error("unabortable transaction error: {0}")]
UnabortableTransactionError(#[from] UnabortableTransactionError),
#[error("data migration error: {0}")]
MigrationError(#[from] MigrationError),
#[error("deserialization error: {0}")]
DeserializationError(#[from] bincode::Error),
}
/// This From implementation is necessary to deal with the recursive
/// error type in the error enum. We defined a transaction error, but
/// the only place we use it is when converting from
/// sled::transaction::TransactionError<DataError>. This converter
/// extracts the inner data error from transaction aborted errors, and
/// forwards anything else onward as-is, but wrapped in DataError.
impl From<TransactionError<DataError>> for DataError {
fn from(error: TransactionError<DataError>) -> Self {
match error {
TransactionError::Abort(data_err) => data_err,
TransactionError::Storage(storage_err) => {
DataError::TransactionError(TransactionError::Storage(storage_err))
}
}
}
}
/// Automatically aborts transactions that hit a DataError by using
/// the try (question mark) operator when this trait implementation is
/// in scope.
impl From<DataError> for sled::transaction::ConflictableTransactionError<DataError> {
fn from(error: DataError) -> Self {
sled::transaction::ConflictableTransactionError::Abort(error)
}
}

View File

@ -1,54 +0,0 @@
use crate::db::errors::DataError;
use crate::db::schema::convert_u32;
use crate::db::Database;
use byteorder::LittleEndian;
use sled::Tree;
use zerocopy::byteorder::U32;
use zerocopy::AsBytes;
//This file is for controlling the migration info stored in the
//database, not actually running migrations.
#[derive(Clone)]
pub struct Migrations(pub(super) Tree);
const COLON: &'static [u8] = b":";
const METADATA_SPACE: &'static str = "metadata";
const MIGRATION_KEY: &'static str = "migration_version";
fn to_key(keyspace: &str, key_name: &str) -> Vec<u8> {
let mut key = vec![];
key.extend_from_slice(keyspace.as_bytes());
key.extend_from_slice(COLON);
key.extend_from_slice(key_name.as_bytes());
key
}
fn metadata_key(key_name: &str) -> Vec<u8> {
to_key(METADATA_SPACE, key_name)
}
impl Migrations {
pub(super) fn set_migration_version(&self, version: u32) -> Result<(), DataError> {
//Rust cannot type infer this transaction
let result: Result<_, sled::transaction::TransactionError<DataError>> =
self.0.transaction(|tx| {
let key = metadata_key(MIGRATION_KEY);
let db_value: U32<LittleEndian> = U32::new(version);
tx.insert(key, db_value.as_bytes())?;
Ok(())
});
result?;
Ok(())
}
}
pub(super) fn get_migration_version(db: &Database) -> Result<u32, DataError> {
let key = metadata_key(MIGRATION_KEY);
match db.migrations.0.get(key)? {
Some(bytes) => convert_u32(&bytes),
None => Ok(0),
}
}

View File

@ -1,515 +0,0 @@
use crate::db::errors::DataError;
use crate::db::schema::convert_u64;
use crate::models::RoomInfo;
use byteorder::BigEndian;
use log::{debug, error, log_enabled};
use sled::transaction::TransactionalTree;
use sled::Transactional;
use sled::{CompareAndSwapError, Tree};
use std::collections::HashSet;
use std::str;
use std::time::{SystemTime, UNIX_EPOCH};
use tokio::task::JoinHandle;
use zerocopy::byteorder::U64;
use zerocopy::AsBytes;
#[derive(Clone)]
pub struct Rooms {
/// Room ID -> RoomInfo struct (single entries).
/// Key is just room ID as bytes.
pub(in crate::db) roomid_roominfo: Tree,
/// Room ID -> list of usernames in room.
pub(in crate::db) roomid_usernames: Tree,
/// Username -> list of room IDs user is in.
pub(in crate::db) username_roomids: Tree,
/// Room ID(str) 0xff event ID(str) -> timestamp. Records event
/// IDs that we have received, so we do not process twice.
pub(in crate::db) roomeventid_timestamp: Tree,
/// Room ID(str) 0xff timestamp(u64) -> event ID. Records event
/// IDs with timestamp as the primary key instead. Exists to allow
/// easy scanning of old roomeventid_timestamp records for
/// removal. Be careful with u64, it can have 0xff and 0xfe bytes.
/// A simple split on 0xff will not work with this key. Instead,
/// it is meant to be split on the first 0xff byte only, and
/// separated into room ID and timestamp.
pub(in crate::db) roomtimestamp_eventid: Tree,
}
/// An enum that can hold either a regular sled Tree, or a
/// Transactional tree.
#[derive(Clone, Copy)]
enum TxableTree<'a> {
Tree(&'a Tree),
Tx(&'a TransactionalTree),
}
impl<'a> Into<TxableTree<'a>> for &'a Tree {
fn into(self) -> TxableTree<'a> {
TxableTree::Tree(self)
}
}
impl<'a> Into<TxableTree<'a>> for &'a TransactionalTree {
fn into(self) -> TxableTree<'a> {
TxableTree::Tx(self)
}
}
/// A set of functions that can be used with a sled::Tree that stores
/// HashSets as its values. Atomicity is partially handled. If the
/// Tree is a transactional tree, operations will be atomic.
/// Otherwise, there is a potential non-atomic step.
mod hashset_tree {
use super::*;
fn insert_set<'a, T: Into<TxableTree<'a>>>(
tree: T,
key: &[u8],
set: HashSet<String>,
) -> Result<(), DataError> {
let serialized = bincode::serialize(&set)?;
match tree.into() {
TxableTree::Tree(tree) => tree.insert(key, serialized)?,
TxableTree::Tx(tx) => tx.insert(key, serialized)?,
};
Ok(())
}
pub(super) fn get_set<'a, T: Into<TxableTree<'a>>>(
tree: T,
key: &[u8],
) -> Result<HashSet<String>, DataError> {
let set: HashSet<String> = match tree.into() {
TxableTree::Tree(tree) => tree.get(key)?,
TxableTree::Tx(tx) => tx.get(key)?,
}
.map(|bytes| bincode::deserialize::<HashSet<String>>(&bytes))
.unwrap_or(Ok(HashSet::new()))?;
Ok(set)
}
pub(super) fn remove_from_set<'a, T: Into<TxableTree<'a>> + Copy>(
tree: T,
key: &[u8],
value_to_remove: &str,
) -> Result<(), DataError> {
let mut set = get_set(tree, key)?;
set.remove(value_to_remove);
insert_set(tree, key, set)?;
Ok(())
}
pub(super) fn add_to_set<'a, T: Into<TxableTree<'a>> + Copy>(
tree: T,
key: &[u8],
value_to_add: String,
) -> Result<(), DataError> {
let mut set = get_set(tree, key)?;
set.insert(value_to_add);
insert_set(tree, key, set)?;
Ok(())
}
}
/// Functions that specifically relate to the "timestamp index" tree,
/// which is stored on the Rooms instance as a tree called
/// roomtimestamp_eventid. Tightly coupled to the event watcher in the
/// Rooms impl, and only factored out for unit testing.
mod timestamp_index {
use super::*;
/// Insert an entry from the main roomeventid_timestamp Tree into
/// the timestamp index. Keys in this Tree are stored as room ID
/// 0xff timestamp, with the value being a hashset of event IDs
/// received at the time. The parameters come from an insert to
/// that Tree, where the key is room ID 0xff event ID, and the
/// value is the timestamp.
pub(super) fn insert(
roomtimestamp_eventid: &Tree,
key: &[u8],
timestamp_bytes: &[u8],
) -> Result<(), DataError> {
let parts: Vec<&[u8]> = key.split(|&b| b == 0xff).collect();
if let [room_id, event_id] = parts[..] {
let mut ts_key = room_id.to_vec();
ts_key.push(0xff);
ts_key.extend_from_slice(&timestamp_bytes);
log_index_record(room_id, event_id, &timestamp_bytes);
let event_id = str::from_utf8(event_id)?;
hashset_tree::add_to_set(roomtimestamp_eventid, &ts_key, event_id.to_owned())?;
Ok(())
} else {
Err(DataError::InvalidValue)
}
}
/// Log a debug message.
fn log_index_record(room_id: &[u8], event_id: &[u8], timestamp: &[u8]) {
if log_enabled!(log::Level::Debug) {
debug!(
"Recording event {} | {} received at {} in timestamp index.",
str::from_utf8(room_id).unwrap_or("[invalid room id]"),
str::from_utf8(event_id).unwrap_or("[invalid event id]"),
convert_u64(timestamp).unwrap_or(0)
);
}
}
}
impl Rooms {
pub(in crate::db) fn new(db: &sled::Db) -> Result<Rooms, sled::Error> {
Ok(Rooms {
roomid_roominfo: db.open_tree("roomid_roominfo")?,
roomid_usernames: db.open_tree("roomid_usernames")?,
username_roomids: db.open_tree("username_roomids")?,
roomeventid_timestamp: db.open_tree("roomeventid_timestamp")?,
roomtimestamp_eventid: db.open_tree("roomtimestamp_eventid")?,
})
}
/// Start an event subscriber that listens for inserts made by the
/// `should_process` function. This event handler duplicates the
/// entry by timestamp instead of event ID.
pub(in crate::db) fn start_handler(&self) -> JoinHandle<()> {
//Clone due to lifetime requirements.
let roomeventid_timestamp = self.roomeventid_timestamp.clone();
let roomtimestamp_eventid = self.roomtimestamp_eventid.clone();
tokio::spawn(async move {
let mut subscriber = roomeventid_timestamp.watch_prefix(b"");
// TODO make this handler receive kill messages somehow so
// we can unit test it and gracefully shut it down.
while let Some(event) = (&mut subscriber).await {
if let sled::Event::Insert { key, value } = event {
match timestamp_index::insert(&roomtimestamp_eventid, &key, &value) {
Err(e) => {
error!("Unable to update the timestamp index: {}", e);
}
_ => (),
}
}
}
})
}
/// Determine if an event in a room should be processed. The event
/// is atomically recorded and true returned if the database has
/// not seen tis event yet. If the event already exists in the
/// database, the function returns false. Events are recorded by
/// this function by inserting the (system-local) timestamp in
/// epoch seconds.
pub fn should_process(&self, room_id: &str, event_id: &str) -> Result<bool, DataError> {
let mut key = room_id.as_bytes().to_vec();
key.push(0xff);
key.extend_from_slice(event_id.as_bytes());
let timestamp: U64<BigEndian> = U64::new(
SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("Clock has gone backwards")
.as_secs(),
);
match self.roomeventid_timestamp.compare_and_swap(
key,
None as Option<&[u8]>,
Some(timestamp.as_bytes()),
)? {
Ok(()) => Ok(true),
Err(CompareAndSwapError { .. }) => Ok(false),
}
}
pub fn insert_room_info(&self, info: &RoomInfo) -> Result<(), DataError> {
let key = info.room_id.as_bytes();
let serialized = bincode::serialize(&info)?;
self.roomid_roominfo.insert(key, serialized)?;
Ok(())
}
pub fn get_room_info(&self, room_id: &str) -> Result<Option<RoomInfo>, DataError> {
let key = room_id.as_bytes();
let room_info: Option<RoomInfo> = self
.roomid_roominfo
.get(key)?
.map(|bytes| bincode::deserialize(&bytes))
.transpose()?;
Ok(room_info)
}
pub fn get_rooms_for_user(&self, username: &str) -> Result<HashSet<String>, DataError> {
hashset_tree::get_set(&self.username_roomids, username.as_bytes())
}
pub fn get_users_in_room(&self, room_id: &str) -> Result<HashSet<String>, DataError> {
hashset_tree::get_set(&self.roomid_usernames, room_id.as_bytes())
}
pub fn add_user_to_room(&self, username: &str, room_id: &str) -> Result<(), DataError> {
debug!("Adding user {} to room {}", username, room_id);
(&self.username_roomids, &self.roomid_usernames).transaction(
|(tx_username_rooms, tx_room_usernames)| {
let username_key = &username.as_bytes();
hashset_tree::add_to_set(tx_username_rooms, username_key, room_id.to_owned())?;
let roomid_key = &room_id.as_bytes();
hashset_tree::add_to_set(tx_room_usernames, roomid_key, username.to_owned())?;
Ok(())
},
)?;
Ok(())
}
pub fn remove_user_from_room(&self, username: &str, room_id: &str) -> Result<(), DataError> {
debug!("Removing user {} from room {}", username, room_id);
(&self.username_roomids, &self.roomid_usernames).transaction(
|(tx_username_rooms, tx_room_usernames)| {
let username_key = &username.as_bytes();
hashset_tree::remove_from_set(tx_username_rooms, username_key, room_id)?;
let roomid_key = &room_id.as_bytes();
hashset_tree::remove_from_set(tx_room_usernames, roomid_key, username)?;
Ok(())
},
)?;
Ok(())
}
pub fn clear_info(&self, room_id: &str) -> Result<(), DataError> {
debug!("Clearing all information for room {}", room_id);
(&self.username_roomids, &self.roomid_usernames).transaction(
|(tx_username_roomids, tx_roomid_usernames)| {
let roomid_key = room_id.as_bytes();
let users_in_room = hashset_tree::get_set(tx_roomid_usernames, roomid_key)?;
//Remove the room ID from every user's room ID list.
for username in users_in_room {
hashset_tree::remove_from_set(
tx_username_roomids,
username.as_bytes(),
room_id,
)?;
}
//Remove this room entry for the room ID -> username tree.
tx_roomid_usernames.remove(roomid_key)?;
//TODO: delete roominfo struct from room info tree.
Ok(())
},
)?;
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
use sled::Config;
fn create_test_instance() -> Rooms {
let config = Config::new().temporary(true);
let db = config.open().unwrap();
Rooms::new(&db).unwrap()
}
#[test]
fn add_user_to_room() {
let rooms = create_test_instance();
rooms
.add_user_to_room("testuser", "myroom")
.expect("Could not add user to room");
let users_in_room = rooms
.get_users_in_room("myroom")
.expect("Could not retrieve users in room");
let rooms_for_user = rooms
.get_rooms_for_user("testuser")
.expect("Could not get rooms for user");
let expected_users_in_room: HashSet<String> =
vec![String::from("testuser")].into_iter().collect();
let expected_rooms_for_user: HashSet<String> =
vec![String::from("myroom")].into_iter().collect();
assert_eq!(expected_users_in_room, users_in_room);
assert_eq!(expected_rooms_for_user, rooms_for_user);
}
#[test]
fn remove_user_from_room() {
let rooms = create_test_instance();
rooms
.add_user_to_room("testuser", "myroom")
.expect("Could not add user to room");
rooms
.remove_user_from_room("testuser", "myroom")
.expect("Could not remove user from room");
let users_in_room = rooms
.get_users_in_room("myroom")
.expect("Could not retrieve users in room");
let rooms_for_user = rooms
.get_rooms_for_user("testuser")
.expect("Could not get rooms for user");
assert_eq!(HashSet::new(), users_in_room);
assert_eq!(HashSet::new(), rooms_for_user);
}
#[test]
fn insert_room_info_works() {
let rooms = create_test_instance();
let info = RoomInfo {
room_id: matrix_sdk::identifiers::room_id!("!fakeroom:example.com")
.as_str()
.to_owned(),
room_name: "fake room name".to_owned(),
};
rooms
.insert_room_info(&info)
.expect("Could insert room info");
let found_info = rooms
.get_room_info("!fakeroom:example.com")
.expect("Error loading room info");
assert!(found_info.is_some());
assert_eq!(info, found_info.unwrap());
}
#[test]
fn insert_room_info_updates_data() {
let rooms = create_test_instance();
let mut info = RoomInfo {
room_id: matrix_sdk::identifiers::room_id!("!fakeroom:example.com")
.as_str()
.to_owned(),
room_name: "fake room name".to_owned(),
};
rooms
.insert_room_info(&info)
.expect("Could insert room info");
//Update info to have a new room name before inserting again.
info.room_name = "new fake room name".to_owned();
rooms
.insert_room_info(&info)
.expect("Could insert room info");
let found_info = rooms
.get_room_info("!fakeroom:example.com")
.expect("Error loading room info");
assert!(found_info.is_some());
assert_eq!(info, found_info.unwrap());
}
#[test]
fn get_room_info_none_when_room_does_not_exist() {
let rooms = create_test_instance();
let found_info = rooms
.get_room_info("!fakeroom:example.com")
.expect("Error loading room info");
assert!(found_info.is_none());
}
#[test]
fn clear_info_modifies_removes_requested_room() {
let rooms = create_test_instance();
rooms
.add_user_to_room("testuser", "myroom1")
.expect("Could not add user to room1");
rooms
.add_user_to_room("testuser", "myroom2")
.expect("Could not add user to room2");
rooms
.clear_info("myroom1")
.expect("Could not clear room info");
let users_in_room1 = rooms
.get_users_in_room("myroom1")
.expect("Could not retrieve users in room1");
let users_in_room2 = rooms
.get_users_in_room("myroom2")
.expect("Could not retrieve users in room2");
let rooms_for_user = rooms
.get_rooms_for_user("testuser")
.expect("Could not get rooms for user");
let expected_users_in_room2: HashSet<String> =
vec![String::from("testuser")].into_iter().collect();
let expected_rooms_for_user: HashSet<String> =
vec![String::from("myroom2")].into_iter().collect();
assert_eq!(HashSet::new(), users_in_room1);
assert_eq!(expected_users_in_room2, users_in_room2);
assert_eq!(expected_rooms_for_user, rooms_for_user);
}
#[test]
fn insert_to_timestamp_index() {
let rooms = create_test_instance();
// Insertion into timestamp index based on data that would go
// into main room x eventID -> timestamp tree.
let mut key = b"myroom".to_vec();
key.push(0xff);
key.extend_from_slice(b"myeventid");
let timestamp: U64<BigEndian> = U64::new(1000);
let result = timestamp_index::insert(
&rooms.roomtimestamp_eventid,
key.as_bytes(),
timestamp.as_bytes(),
);
assert!(result.is_ok());
// Retrieval of data from the timestamp index tree.
let mut ts_key = b"myroom".to_vec();
ts_key.push(0xff);
ts_key.extend_from_slice(timestamp.as_bytes());
let expected_events: HashSet<String> =
vec![String::from("myeventid")].into_iter().collect();
let event_ids = hashset_tree::get_set(&rooms.roomtimestamp_eventid, &ts_key)
.expect("Could not get set out of Tree");
assert_eq!(expected_events, event_ids);
}
}

View File

@ -1,52 +0,0 @@
use crate::db::errors::DataError;
use byteorder::{BigEndian, LittleEndian};
use zerocopy::byteorder::{I32, U32, U64};
use zerocopy::LayoutVerified;
/// User variables are stored as little-endian 32-bit integers in the
/// database. This type alias makes the database code more pleasant to
/// read.
type LittleEndianI32Layout<'a> = LayoutVerified<&'a [u8], I32<LittleEndian>>;
type LittleEndianU32Layout<'a> = LayoutVerified<&'a [u8], U32<LittleEndian>>;
#[allow(dead_code)]
type LittleEndianU64Layout<'a> = LayoutVerified<&'a [u8], U64<LittleEndian>>;
type BigEndianU64Layout<'a> = LayoutVerified<&'a [u8], U64<BigEndian>>;
/// Convert bytes to an i32 with zero-copy deserialization. An error
/// is returned if the bytes do not represent an i32.
pub(super) fn convert_i32(raw_value: &[u8]) -> Result<i32, DataError> {
let layout = LittleEndianI32Layout::new_unaligned(raw_value.as_ref());
if let Some(layout) = layout {
let value: I32<LittleEndian> = *layout;
Ok(value.get())
} else {
Err(DataError::I32SchemaViolation)
}
}
pub(super) fn convert_u32(raw_value: &[u8]) -> Result<u32, DataError> {
let layout = LittleEndianU32Layout::new_unaligned(raw_value.as_ref());
if let Some(layout) = layout {
let value: U32<LittleEndian> = *layout;
Ok(value.get())
} else {
Err(DataError::I32SchemaViolation)
}
}
#[allow(dead_code)]
pub(super) fn convert_u64(raw_value: &[u8]) -> Result<u64, DataError> {
let layout = BigEndianU64Layout::new_unaligned(raw_value.as_ref());
if let Some(layout) = layout {
let value: U64<BigEndian> = *layout;
Ok(value.get())
} else {
Err(DataError::I32SchemaViolation)
}
}

View File

@ -1,72 +0,0 @@
use std::num::TryFromIntError;
use sled::transaction::{TransactionError, UnabortableTransactionError};
use thiserror::Error;
#[derive(Error, Debug)]
pub enum DataError {
#[error("value does not exist for key: {0}")]
KeyDoesNotExist(String),
#[error("too many entries")]
TooManyEntries,
#[error("expected i32, but i32 schema was violated")]
I32SchemaViolation,
#[error("unexpected or corruptd data bytes")]
InvalidValue,
#[error("expected string ref, but utf8 schema was violated: {0}")]
Utf8RefSchemaViolation(#[from] std::str::Utf8Error),
#[error("expected string, but utf8 schema was violated: {0}")]
Utf8SchemaViolation(#[from] std::string::FromUtf8Error),
#[error("internal database error: {0}")]
InternalError(#[from] sled::Error),
#[error("transaction error: {0}")]
TransactionError(#[from] sled::transaction::TransactionError),
#[error("unabortable transaction error: {0}")]
UnabortableTransactionError(#[from] UnabortableTransactionError),
#[error("data migration error: {0}")]
MigrationError(#[from] super::migrator::MigrationError),
#[error("deserialization error: {0}")]
DeserializationError(#[from] bincode::Error),
#[error("sqlx error: {0}")]
SqlxError(#[from] sqlx::Error),
#[error("numeric conversion error")]
NumericConversionError(#[from] TryFromIntError),
}
/// This From implementation is necessary to deal with the recursive
/// error type in the error enum. We defined a transaction error, but
/// the only place we use it is when converting from
/// sled::transaction::TransactionError<DataError>. This converter
/// extracts the inner data error from transaction aborted errors, and
/// forwards anything else onward as-is, but wrapped in DataError.
impl From<TransactionError<DataError>> for DataError {
fn from(error: TransactionError<DataError>) -> Self {
match error {
TransactionError::Abort(data_err) => data_err,
TransactionError::Storage(storage_err) => {
DataError::TransactionError(TransactionError::Storage(storage_err))
}
}
}
}
/// Automatically aborts transactions that hit a DataError by using
/// the try (question mark) operator when this trait implementation is
/// in scope.
impl From<DataError> for sled::transaction::ConflictableTransactionError<DataError> {
fn from(error: DataError) -> Self {
sled::transaction::ConflictableTransactionError::Abort(error)
}
}

View File

@ -1,2 +0,0 @@
use refinery::include_migration_mods;
include_migration_mods!("src/db/sqlite/migrator/migrations");

View File

@ -1,116 +0,0 @@
use async_trait::async_trait;
use errors::DataError;
use sqlx::sqlite::{SqliteConnectOptions, SqlitePool, SqlitePoolOptions};
use sqlx::ConnectOptions;
use std::clone::Clone;
use std::collections::{HashMap, HashSet};
use std::str::FromStr;
use crate::models::RoomInfo;
pub mod errors;
pub mod migrator;
pub mod rooms;
pub mod state;
pub mod variables;
#[async_trait]
pub(crate) trait DbState {
async fn get_device_id(&self) -> Result<Option<String>, DataError>;
async fn set_device_id(&self, device_id: &str) -> Result<(), DataError>;
}
#[async_trait]
pub(crate) trait Rooms {
async fn should_process(&self, room_id: &str, event_id: &str) -> Result<bool, DataError>;
async fn insert_room_info(&self, info: &RoomInfo) -> Result<(), DataError>;
async fn get_room_info(&self, room_id: &str) -> Result<Option<RoomInfo>, DataError>;
async fn get_rooms_for_user(&self, user_id: &str) -> Result<HashSet<String>, DataError>;
async fn get_users_in_room(&self, room_id: &str) -> Result<HashSet<String>, DataError>;
async fn add_user_to_room(&self, username: &str, room_id: &str) -> Result<(), DataError>;
async fn remove_user_from_room(&self, username: &str, room_id: &str) -> Result<(), DataError>;
async fn clear_info(&self, room_id: &str) -> Result<(), DataError>;
}
// TODO move this up to the top once we delete sled. Traits will be the
// main API, then we can have different impls for different DBs.
#[async_trait]
pub trait Variables {
async fn get_user_variables(
&self,
user: &str,
room_id: &str,
) -> Result<HashMap<String, i32>, DataError>;
async fn get_variable_count(&self, user: &str, room_id: &str) -> Result<i32, DataError>;
async fn get_user_variable(
&self,
user: &str,
room_id: &str,
variable_name: &str,
) -> Result<i32, DataError>;
async fn set_user_variable(
&self,
user: &str,
room_id: &str,
variable_name: &str,
value: i32,
) -> Result<(), DataError>;
async fn delete_user_variable(
&self,
user: &str,
room_id: &str,
variable_name: &str,
) -> Result<(), DataError>;
}
pub struct Database {
conn: SqlitePool,
}
impl Database {
fn new_db(conn: SqlitePool) -> Result<Database, DataError> {
let database = Database { conn: conn.clone() };
Ok(database)
}
pub async fn new(path: &str) -> Result<Database, DataError> {
//Create database if missing.
let conn = SqliteConnectOptions::from_str(path)?
.create_if_missing(true)
.connect()
.await?;
drop(conn);
//Migrate database.
migrator::migrate(&path).await?;
//Return actual conncetion pool.
let conn = SqlitePoolOptions::new()
.max_connections(5)
.connect(path)
.await?;
Self::new_db(conn)
}
}
impl Clone for Database {
fn clone(&self) -> Self {
Database {
conn: self.conn.clone(),
}
}
}

View File

@ -1,379 +0,0 @@
use super::errors::DataError;
use super::{Database, Rooms};
use crate::models::RoomInfo;
use async_trait::async_trait;
use sqlx::SqlitePool;
use std::collections::{HashMap, HashSet};
use std::time::{SystemTime, UNIX_EPOCH};
async fn record_event(conn: &SqlitePool, room_id: &str, event_id: &str) -> Result<(), DataError> {
use std::convert::TryFrom;
let now: i64 = i64::try_from(
SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("Clock has gone backwards")
.as_secs(),
)?;
sqlx::query(
r#"INSERT INTO room_events
(room_id, event_id, event_timestamp)
VALUES (?, ?, ?)"#,
)
.bind(room_id)
.bind(event_id)
.bind(now)
.execute(conn)
.await?;
Ok(())
}
#[async_trait]
impl Rooms for Database {
async fn should_process(&self, room_id: &str, event_id: &str) -> Result<bool, DataError> {
let row = sqlx::query!(
r#"SELECT event_id FROM room_events
WHERE room_id = ? AND event_id = ?"#,
room_id,
event_id
)
.fetch_optional(&self.conn)
.await?;
match row {
Some(_) => Ok(false),
None => {
record_event(&self.conn, room_id, event_id).await?;
Ok(true)
}
}
}
async fn insert_room_info(&self, info: &RoomInfo) -> Result<(), DataError> {
sqlx::query(
r#"INSERT INTO room_info (room_id, room_name) VALUES (?, ?)
ON CONFLICT(room_id) DO UPDATE SET room_name = ?"#,
)
.bind(&info.room_id)
.bind(&info.room_name)
.bind(&info.room_name)
.execute(&self.conn)
.await?;
Ok(())
}
async fn get_room_info(&self, room_id: &str) -> Result<Option<RoomInfo>, DataError> {
let info = sqlx::query!(
r#"SELECT room_id, room_name FROM room_info
WHERE room_id = ?"#,
room_id
)
.fetch_optional(&self.conn)
.await?;
Ok(info.map(|i| RoomInfo {
room_id: i.room_id,
room_name: i.room_name,
}))
}
async fn get_rooms_for_user(&self, user_id: &str) -> Result<HashSet<String>, DataError> {
let room_ids = sqlx::query!(
r#"SELECT room_id FROM room_users
WHERE username = ?"#,
user_id
)
.fetch_all(&self.conn)
.await?;
Ok(room_ids.into_iter().map(|row| row.room_id).collect())
}
async fn get_users_in_room(&self, room_id: &str) -> Result<HashSet<String>, DataError> {
let usernames = sqlx::query!(
r#"SELECT username FROM room_users
WHERE room_id = ?"#,
room_id
)
.fetch_all(&self.conn)
.await?;
Ok(usernames.into_iter().map(|row| row.username).collect())
}
async fn add_user_to_room(&self, username: &str, room_id: &str) -> Result<(), DataError> {
sqlx::query(
"INSERT INTO room_users (room_id, username) VALUES (?, ?)
ON CONFLICT DO NOTHING",
)
.bind(room_id)
.bind(username)
.execute(&self.conn)
.await?;
Ok(())
}
async fn remove_user_from_room(&self, username: &str, room_id: &str) -> Result<(), DataError> {
sqlx::query("DELETE FROM room_users where username = ? AND room_id = ?")
.bind(username)
.bind(room_id)
.execute(&self.conn)
.await?;
Ok(())
}
async fn clear_info(&self, room_id: &str) -> Result<(), DataError> {
// We do not clear event history here, because if we rejoin a
// room, we would re-process events we've already seen.
let mut tx = self.conn.begin().await?;
sqlx::query("DELETE FROM room_info where room_id = ?")
.bind(room_id)
.execute(&mut tx)
.await?;
sqlx::query("DELETE FROM room_users where room_id = ?")
.bind(room_id)
.execute(&mut tx)
.await?;
tx.commit().await?;
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::super::Rooms;
use super::*;
async fn create_db() -> Database {
let db_path = tempfile::NamedTempFile::new_in(".").unwrap();
crate::db::sqlite::migrator::migrate(db_path.path().to_str().unwrap())
.await
.unwrap();
Database::new(db_path.path().to_str().unwrap())
.await
.unwrap()
}
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn should_process_test() {
let db = create_db().await;
let first_check = db
.should_process("myroom", "myeventid")
.await
.expect("should_process failed in first insert");
assert_eq!(first_check, true);
let second_check = db
.should_process("myroom", "myeventid")
.await
.expect("should_process failed in first insert");
assert_eq!(second_check, false);
}
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn insert_and_get_room_info_test() {
let db = create_db().await;
let info = RoomInfo {
room_id: "myroomid".to_string(),
room_name: "myroomname".to_string(),
};
db.insert_room_info(&info)
.await
.expect("Could not insert room info.");
let retrieved_info = db
.get_room_info("myroomid")
.await
.expect("Could not retrieve room info.");
assert!(retrieved_info.is_some());
assert_eq!(info, retrieved_info.unwrap());
}
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn insert_room_info_updates_existing() {
let db = create_db().await;
let info1 = RoomInfo {
room_id: "myroomid".to_string(),
room_name: "myroomname".to_string(),
};
db.insert_room_info(&info1)
.await
.expect("Could not insert room info1.");
let info2 = RoomInfo {
room_id: "myroomid".to_string(),
room_name: "myroomname2".to_string(),
};
db.insert_room_info(&info2)
.await
.expect("Could not update room info after first insert");
let retrieved_info = db
.get_room_info("myroomid")
.await
.expect("Could not get room info");
assert!(retrieved_info.is_some());
let retrieved_info = retrieved_info.unwrap();
assert_eq!(retrieved_info.room_id, "myroomid");
assert_eq!(retrieved_info.room_name, "myroomname2");
}
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn add_user_to_room_test() {
let db = create_db().await;
db.add_user_to_room("myuser", "myroom")
.await
.expect("Could not add user to room.");
let users_in_room = db
.get_users_in_room("myroom")
.await
.expect("Could not get users in room.");
assert_eq!(users_in_room.len(), 1);
assert!(users_in_room.contains("myuser"));
let rooms_for_user = db
.get_rooms_for_user("myuser")
.await
.expect("Could not get rooms for user");
assert_eq!(rooms_for_user.len(), 1);
assert!(rooms_for_user.contains("myroom"));
}
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn add_user_to_room_does_not_have_constraint_violation() {
let db = create_db().await;
db.add_user_to_room("myuser", "myroom")
.await
.expect("Could not add user to room.");
let second_attempt = db.add_user_to_room("myuser", "myroom").await;
assert!(second_attempt.is_ok());
let users_in_room = db
.get_users_in_room("myroom")
.await
.expect("Could not get users in room.");
assert_eq!(users_in_room.len(), 1);
assert!(users_in_room.contains("myuser"));
}
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn remove_user_from_room_test() {
let db = create_db().await;
db.add_user_to_room("myuser", "myroom")
.await
.expect("Could not add user to room.");
let remove_attempt = db.remove_user_from_room("myuser", "myroom").await;
assert!(remove_attempt.is_ok());
let users_in_room = db
.get_users_in_room("myroom")
.await
.expect("Could not get users in room.");
assert_eq!(users_in_room.len(), 0);
}
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn clear_info_does_not_delete_other_rooms() {
let db = create_db().await;
let info1 = RoomInfo {
room_id: "myroomid".to_string(),
room_name: "myroomname".to_string(),
};
let info2 = RoomInfo {
room_id: "myroomid2".to_string(),
room_name: "myroomname2".to_string(),
};
db.insert_room_info(&info1)
.await
.expect("Could not insert room info1.");
db.insert_room_info(&info2)
.await
.expect("Could not insert room info2.");
db.add_user_to_room("myuser", &info1.room_id)
.await
.expect("Could not add user to room.");
db.clear_info(&info1.room_id)
.await
.expect("Could not clear room info1");
let room_info2 = db
.get_room_info(&info2.room_id)
.await
.expect("Could not get room info2.");
assert!(room_info2.is_some());
}
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn clear_info_test() {
let db = create_db().await;
let info = RoomInfo {
room_id: "myroomid".to_string(),
room_name: "myroomname".to_string(),
};
db.insert_room_info(&info)
.await
.expect("Could not insert room info.");
db.add_user_to_room("myuser", &info.room_id)
.await
.expect("Could not add user to room.");
db.clear_info(&info.room_id)
.await
.expect("Could not clear room info");
let users_in_room = db
.get_users_in_room(&info.room_id)
.await
.expect("Could not get users in room.");
assert_eq!(users_in_room.len(), 0);
let room_info = db
.get_room_info(&info.room_id)
.await
.expect("Could not get room info.");
assert!(room_info.is_none());
}
}

View File

@ -1,88 +0,0 @@
use crate::db::errors::DataError;
use sled::Tree;
#[derive(Clone)]
pub struct DbState {
/// Tree of simple key-values for global state values that persist
/// between restarts (e.g. device ID).
pub(in crate::db) global_metadata: Tree,
}
const DEVICE_ID_KEY: &'static [u8] = b"device_id";
impl DbState {
pub(in crate::db) fn new(db: &sled::Db) -> Result<DbState, sled::Error> {
Ok(DbState {
global_metadata: db.open_tree("global_metadata")?,
})
}
pub fn get_device_id(&self) -> Result<Option<String>, DataError> {
self.global_metadata
.get(DEVICE_ID_KEY)?
.map(|v| String::from_utf8(v.to_vec()))
.transpose()
.map_err(|e| e.into())
}
pub fn set_device_id(&self, device_id: &str) -> Result<(), DataError> {
self.global_metadata
.insert(DEVICE_ID_KEY, device_id.as_bytes())?;
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
use sled::Config;
fn create_test_instance() -> DbState {
let config = Config::new().temporary(true);
let db = config.open().unwrap();
DbState::new(&db).unwrap()
}
#[test]
fn set_device_id_works() {
let state = create_test_instance();
let result = state.set_device_id("test-device");
assert!(result.is_ok());
}
#[test]
fn set_device_id_can_overwrite() {
let state = create_test_instance();
state.set_device_id("test-device").expect("insert 1 failed");
let result = state.set_device_id("test-device2");
assert!(result.is_ok());
}
#[test]
fn get_device_id_returns_some_when_set() {
let state = create_test_instance();
state
.set_device_id("test-device")
.expect("could not store device id properly");
let device_id = state.get_device_id();
assert!(device_id.is_ok());
let device_id = device_id.unwrap();
assert!(device_id.is_some());
assert_eq!("test-device", device_id.unwrap());
}
#[test]
fn get_device_id_returns_none_when_unset() {
let state = create_test_instance();
let device_id = state.get_device_id();
assert!(device_id.is_ok());
let device_id = device_id.unwrap();
assert!(device_id.is_none());
}
}

View File

@ -1,410 +0,0 @@
use crate::db::errors::DataError;
use crate::db::schema::convert_i32;
use byteorder::LittleEndian;
use sled::transaction::{abort, TransactionalTree};
use sled::Transactional;
use sled::Tree;
use std::collections::HashMap;
use std::convert::From;
use std::str;
use zerocopy::byteorder::I32;
use zerocopy::AsBytes;
use super::errors;
pub(super) mod migrations;
#[derive(Clone)]
pub struct Variables {
//room id + username + variable = i32
pub(in crate::db) room_user_variables: Tree,
//room id + username = i32
pub(in crate::db) room_user_variable_count: Tree,
}
/// Request something by a username and room ID.
pub struct UserAndRoom<'a>(pub &'a str, pub &'a str);
fn to_vec(value: &UserAndRoom<'_>) -> Vec<u8> {
let mut bytes = vec![];
bytes.extend_from_slice(value.0.as_bytes());
bytes.push(0xfe);
bytes.extend_from_slice(value.1.as_bytes());
bytes
}
impl From<UserAndRoom<'_>> for Vec<u8> {
fn from(value: UserAndRoom) -> Vec<u8> {
to_vec(&value)
}
}
impl From<&UserAndRoom<'_>> for Vec<u8> {
fn from(value: &UserAndRoom) -> Vec<u8> {
to_vec(value)
}
}
/// Use a transaction to atomically alter the count of variables in
/// the database by the given amount. Count cannot go below 0.
fn alter_room_variable_count(
room_variable_count: &TransactionalTree,
user_and_room: &UserAndRoom<'_>,
amount: i32,
) -> Result<i32, DataError> {
let key: Vec<u8> = user_and_room.into();
let mut new_count = match room_variable_count.get(&key)? {
Some(bytes) => convert_i32(&bytes)? + amount,
None => amount,
};
if new_count < 0 {
new_count = 0;
}
let db_value: I32<LittleEndian> = I32::new(new_count);
room_variable_count.insert(key, db_value.as_bytes())?;
Ok(new_count)
}
/// Room ID, Username, Variable Name
pub type AllVariablesKey = (String, String, String);
impl Variables {
pub(in crate::db) fn new(db: &sled::Db) -> Result<Variables, sled::Error> {
Ok(Variables {
room_user_variables: db.open_tree("variables")?,
room_user_variable_count: db.open_tree("room_user_variable_count")?,
})
}
pub fn get_all_variables(&self) -> Result<HashMap<AllVariablesKey, i32>, DataError> {
use std::convert::TryFrom;
let variables: Result<Vec<(AllVariablesKey, i32)>, DataError> = self
.room_user_variables
.scan_prefix("")
.map(|entry| match entry {
Ok((key, raw_value)) => {
let keys: Vec<_> = key
.split(|&b| b == 0xfe || b == 0xff)
.map(|b| str::from_utf8(b))
.collect();
if let &[Ok(room_id), Ok(username), Ok(variable_name), ..] = keys.as_slice() {
Ok((
(
room_id.to_owned(),
username.to_owned(),
variable_name.to_owned(),
),
convert_i32(&raw_value)?,
))
} else {
Err(errors::DataError::InvalidValue)
}
}
Err(e) => Err(e.into()),
})
.collect();
// Convert tuples to hash map with collect(), inferred via
// return type.
variables.map(|entries| entries.into_iter().collect())
}
pub fn get_user_variables(
&self,
key: &UserAndRoom<'_>,
) -> Result<HashMap<String, i32>, DataError> {
let mut prefix: Vec<u8> = key.into();
prefix.push(0xff);
let prefix_len: usize = prefix.len();
let variables: Result<Vec<(String, i32)>, DataError> = self
.room_user_variables
.scan_prefix(prefix)
.map(|entry| match entry {
Ok((key, raw_value)) => {
//Strips room and username from key, leaving behind name.
let variable_name = str::from_utf8(&key[prefix_len..])?;
Ok((variable_name.to_owned(), convert_i32(&raw_value)?))
}
Err(e) => Err(e.into()),
})
.collect();
// Convert tuples to hash map with collect(), inferred via
// return type.
variables.map(|entries| entries.into_iter().collect())
}
pub fn get_variable_count(&self, user_and_room: &UserAndRoom<'_>) -> Result<i32, DataError> {
let key: Vec<u8> = user_and_room.into();
match self.room_user_variable_count.get(&key)? {
Some(raw_value) => convert_i32(&raw_value),
None => Ok(0),
}
}
pub fn get_user_variable(
&self,
user_and_room: &UserAndRoom<'_>,
variable_name: &str,
) -> Result<i32, DataError> {
let mut key: Vec<u8> = user_and_room.into();
key.push(0xff);
key.extend_from_slice(variable_name.as_bytes());
match self.room_user_variables.get(&key)? {
Some(raw_value) => convert_i32(&raw_value),
_ => Err(DataError::KeyDoesNotExist(variable_name.to_owned())),
}
}
pub fn set_user_variable(
&self,
user_and_room: &UserAndRoom<'_>,
variable_name: &str,
value: i32,
) -> Result<(), DataError> {
if self.get_variable_count(user_and_room)? >= 100 {
return Err(DataError::TooManyEntries);
}
(&self.room_user_variables, &self.room_user_variable_count).transaction(
|(tx_vars, tx_counts)| {
let mut key: Vec<u8> = user_and_room.into();
key.push(0xff);
key.extend_from_slice(variable_name.as_bytes());
let db_value: I32<LittleEndian> = I32::new(value);
let old_value = tx_vars.insert(key, db_value.as_bytes())?;
//Only increment variable count on new keys.
if let None = old_value {
if let Err(e) = alter_room_variable_count(&tx_counts, &user_and_room, 1) {
return abort(e);
}
}
Ok(())
},
)?;
Ok(())
}
pub fn delete_user_variable(
&self,
user_and_room: &UserAndRoom<'_>,
variable_name: &str,
) -> Result<(), DataError> {
(&self.room_user_variables, &self.room_user_variable_count).transaction(
|(tx_vars, tx_counts)| {
let mut key: Vec<u8> = user_and_room.into();
key.push(0xff);
key.extend_from_slice(variable_name.as_bytes());
if let Some(_) = tx_vars.remove(key)? {
if let Err(e) = alter_room_variable_count(&tx_counts, user_and_room, -1) {
return abort(e);
}
} else {
return abort(DataError::KeyDoesNotExist(variable_name.to_owned()));
}
Ok(())
},
)?;
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
use sled::Config;
fn create_test_instance() -> Variables {
let config = Config::new().temporary(true);
let db = config.open().unwrap();
Variables::new(&db).unwrap()
}
//Room Variable count tests
#[test]
fn alter_room_variable_count_test() {
let variables = create_test_instance();
let key = UserAndRoom("username", "room");
let alter_count = |amount: i32| {
variables
.room_user_variable_count
.transaction(|tx| match alter_room_variable_count(&tx, &key, amount) {
Err(e) => abort(e),
_ => Ok(()),
})
.expect("got transaction failure");
};
let get_count = |variables: &Variables| -> i32 {
variables
.get_variable_count(&key)
.expect("could not get variable count")
};
//addition
alter_count(5);
assert_eq!(5, get_count(&variables));
//subtraction
alter_count(-3);
assert_eq!(2, get_count(&variables));
}
#[test]
fn alter_room_variable_count_cannot_go_below_0_test() {
let variables = create_test_instance();
let key = UserAndRoom("username", "room");
variables
.room_user_variable_count
.transaction(|tx| match alter_room_variable_count(&tx, &key, -1000) {
Err(e) => abort(e),
_ => Ok(()),
})
.expect("got transaction failure");
let count = variables
.get_variable_count(&key)
.expect("could not get variable count");
assert_eq!(0, count);
}
#[test]
fn empty_db_reports_0_room_variable_count_test() {
let variables = create_test_instance();
let key = UserAndRoom("username", "room");
let count = variables
.get_variable_count(&key)
.expect("could not get variable count");
assert_eq!(0, count);
}
#[test]
fn set_user_variable_increments_count() {
let variables = create_test_instance();
let key = UserAndRoom("username", "room");
variables
.set_user_variable(&key, "myvariable", 5)
.expect("could not insert variable");
let count = variables
.get_variable_count(&key)
.expect("could not get variable count");
assert_eq!(1, count);
}
#[test]
fn update_user_variable_does_not_increment_count() {
let variables = create_test_instance();
let key = UserAndRoom("username", "room");
variables
.set_user_variable(&key, "myvariable", 5)
.expect("could not insert variable");
variables
.set_user_variable(&key, "myvariable", 10)
.expect("could not update variable");
let count = variables
.get_variable_count(&key)
.expect("could not get variable count");
assert_eq!(1, count);
}
// Set/get/delete variable tests
#[test]
fn set_and_get_variable_test() {
let variables = create_test_instance();
let key = UserAndRoom("username", "room");
variables
.set_user_variable(&key, "myvariable", 5)
.expect("could not insert variable");
let value = variables
.get_user_variable(&key, "myvariable")
.expect("could not get value");
assert_eq!(5, value);
}
#[test]
fn cannot_set_more_than_100_variables_per_room() {
let variables = create_test_instance();
let key = UserAndRoom("username", "room");
for c in 0..100 {
variables
.set_user_variable(&key, &format!("myvariable{}", c), 5)
.expect("could not insert variable");
}
let result = variables.set_user_variable(&key, "myvariable101", 5);
assert!(result.is_err());
assert!(matches!(result, Err(DataError::TooManyEntries)));
}
#[test]
fn delete_variable_test() {
let variables = create_test_instance();
let key = UserAndRoom("username", "room");
variables
.set_user_variable(&key, "myvariable", 5)
.expect("could not insert variable");
variables
.delete_user_variable(&key, "myvariable")
.expect("could not delete value");
let result = variables.get_user_variable(&key, "myvariable");
assert!(result.is_err());
assert!(matches!(result, Err(DataError::KeyDoesNotExist(_))));
}
#[test]
fn get_missing_variable_returns_key_does_not_exist() {
let variables = create_test_instance();
let key = UserAndRoom("username", "room");
let result = variables.get_user_variable(&key, "myvariable");
assert!(result.is_err());
assert!(matches!(result, Err(DataError::KeyDoesNotExist(_))));
}
#[test]
fn remove_missing_variable_returns_key_does_not_exist() {
let variables = create_test_instance();
let key = UserAndRoom("username", "room");
let result = variables.delete_user_variable(&key, "myvariable");
assert!(result.is_err());
assert!(matches!(result, Err(DataError::KeyDoesNotExist(_))));
}
}

View File

@ -1,354 +0,0 @@
use super::*;
use crate::db::errors::{DataError, MigrationError};
use crate::db::Database;
use byteorder::LittleEndian;
use memmem::{Searcher, TwoWaySearcher};
use sled::transaction::TransactionError;
use sled::{Batch, IVec};
use std::collections::HashMap;
use zerocopy::byteorder::{I32, U32};
use zerocopy::AsBytes;
pub(in crate::db) mod add_room_user_variable_count {
use super::*;
//Not to be confused with the super::RoomAndUser delineator.
#[derive(PartialEq, Eq, std::hash::Hash)]
struct RoomAndUser {
room_id: String,
username: String,
}
/// Create a version 0 user variable key.
fn v0_variable_key(info: &RoomAndUser, variable_name: &str) -> Vec<u8> {
let mut key = vec![];
key.extend_from_slice(info.room_id.as_bytes());
key.extend_from_slice(info.username.as_bytes());
key.extend_from_slice(variable_name.as_bytes());
key
}
fn map_value_to_room_and_user(
entry: sled::Result<(IVec, IVec)>,
) -> Result<RoomAndUser, MigrationError> {
if let Ok((key, _)) = entry {
let keys: Vec<Result<&str, _>> = key
.split(|&b| b == 0xff)
.map(|b| str::from_utf8(b))
.collect();
if let &[_, Ok(room_id), Ok(username), Ok(_variable)] = keys.as_slice() {
Ok(RoomAndUser {
room_id: room_id.to_owned(),
username: username.to_owned(),
})
} else {
Err(MigrationError::MigrationFailed(
"a key violates utf8 schema".to_string(),
))
}
} else {
Err(MigrationError::MigrationFailed(
"encountered unexpected key".to_string(),
))
}
}
fn create_key(room_id: &str, username: &str) -> Vec<u8> {
let mut key = b"variables".to_vec();
key.push(0xff);
key.extend_from_slice(room_id.as_bytes());
key.push(0xff);
key.extend_from_slice(username.as_bytes());
key.push(0xff);
key.extend_from_slice(b"variable_count");
key
}
pub(in crate::db) fn migrate(db: &Database) -> Result<(), DataError> {
let tree = &db.variables.room_user_variables;
let prefix = b"variables";
//Extract a vec of tuples, consisting of room id + username.
let results: Vec<RoomAndUser> = tree
.scan_prefix(prefix)
.map(map_value_to_room_and_user)
.collect::<Result<Vec<_>, MigrationError>>()?;
let counts: HashMap<RoomAndUser, u32> =
results
.into_iter()
.fold(HashMap::new(), |mut count_map, room_and_user| {
let count = count_map.entry(room_and_user).or_insert(0);
*count += 1;
count_map
});
//Start a transaction on the variables tree.
let tx_result: Result<_, TransactionError<DataError>> =
db.variables.room_user_variables.transaction(|tx_vars| {
let batch = counts.iter().fold(Batch::default(), |mut batch, entry| {
let (info, count) = entry;
//Add variable count according to new schema.
let key = create_key(&info.room_id, &info.username);
let db_value: U32<LittleEndian> = U32::new(*count);
batch.insert(key, db_value.as_bytes());
//Delete the old variable_count variable if exists.
let old_key = v0_variable_key(&info, "variable_count");
batch.remove(old_key);
batch
});
tx_vars.apply_batch(&batch)?;
Ok(())
});
tx_result?; //For some reason, it cannot infer the type
Ok(())
}
}
pub(in crate::db) fn delete_v0_schema(db: &Database) -> Result<(), DataError> {
let mut vars = db.variables.room_user_variables.scan_prefix("");
let mut batch = Batch::default();
while let Some(Ok((key, _))) = vars.next() {
let key = key.to_vec();
if !key.contains(&0xff) {
batch.remove(key);
}
}
db.variables.room_user_variables.apply_batch(batch)?;
Ok(())
}
pub(in crate::db) fn delete_variable_count(db: &Database) -> Result<(), DataError> {
let prefix = b"variables";
let mut vars = db.variables.room_user_variables.scan_prefix(prefix);
let mut batch = Batch::default();
while let Some(Ok((key, _))) = vars.next() {
let search = TwoWaySearcher::new(b"variable_count");
let ends_with = {
match search.search_in(&key) {
Some(index) => key.len() - index == b"variable_count".len(),
None => false,
}
};
if ends_with {
batch.remove(key);
}
}
db.variables.room_user_variables.apply_batch(batch)?;
Ok(())
}
pub(in crate::db) mod change_delineator_delimiter {
use super::*;
/// An entry in the room user variables keyspace.
struct UserVariableEntry {
room_id: String,
username: String,
variable_name: String,
value: IVec,
}
/// Extract keys and values from the variables keyspace according
/// to the v1 schema.
fn extract_v1_entries(
entry: sled::Result<(IVec, IVec)>,
) -> Result<UserVariableEntry, MigrationError> {
if let Ok((key, value)) = entry {
let keys: Vec<Result<&str, _>> = key
.split(|&b| b == 0xff)
.map(|b| str::from_utf8(b))
.collect();
if let &[_, Ok(room_id), Ok(username), Ok(variable)] = keys.as_slice() {
Ok(UserVariableEntry {
room_id: room_id.to_owned(),
username: username.to_owned(),
variable_name: variable.to_owned(),
value: value,
})
} else {
Err(MigrationError::MigrationFailed(
"a key violates utf8 schema".to_string(),
))
}
} else {
Err(MigrationError::MigrationFailed(
"encountered unexpected key".to_string(),
))
}
}
/// Create an old key, where delineator is separated by 0xff.
fn create_old_key(prefix: &[u8], insert: &UserVariableEntry) -> Vec<u8> {
let mut key = vec![];
key.extend_from_slice(&prefix); //prefix already has 0xff.
key.extend_from_slice(&insert.room_id.as_bytes());
key.push(0xff);
key.extend_from_slice(&insert.username.as_bytes());
key.push(0xff);
key.extend_from_slice(&insert.variable_name.as_bytes());
key
}
/// Create an old key, where delineator is separated by 0xfe.
fn create_new_key(prefix: &[u8], insert: &UserVariableEntry) -> Vec<u8> {
let mut key = vec![];
key.extend_from_slice(&prefix); //prefix already has 0xff.
key.extend_from_slice(&insert.room_id.as_bytes());
key.push(0xfe);
key.extend_from_slice(&insert.username.as_bytes());
key.push(0xff);
key.extend_from_slice(&insert.variable_name.as_bytes());
key
}
pub fn migrate(db: &Database) -> Result<(), DataError> {
let tree = &db.variables.room_user_variables;
let prefix = b"variables";
let results: Vec<UserVariableEntry> = tree
.scan_prefix(&prefix)
.map(extract_v1_entries)
.collect::<Result<Vec<_>, MigrationError>>()?;
let mut batch = Batch::default();
for insert in results {
let old = create_old_key(prefix, &insert);
let new = create_new_key(prefix, &insert);
batch.remove(old);
batch.insert(new, insert.value);
}
tree.apply_batch(batch)?;
Ok(())
}
}
/// Move the user variable entries into two tree structures, with yet
/// another key format change. Now there is one tree for variable
/// counts, and one tree for actual user variables. Keys in the user
/// variable tree were changed to be username-first, then room ID.
/// They are still separated by 0xfe, while the variable name is
/// separated by 0xff. Variable count now stores just
/// USERNAME0xfeROOM_ID and a count in its own tree. This enables
/// public use of a strongly typed UserAndRoom struct for getting
/// variables.
pub(in crate::db) mod change_tree_structure {
use super::*;
/// An entry in the room user variables keyspace.
struct UserVariableEntry {
room_id: String,
username: String,
variable_name: String,
value: IVec,
}
/// Extract keys and values from the variables keyspace according
/// to the v1 schema.
fn extract_v1_entries(
entry: sled::Result<(IVec, IVec)>,
) -> Result<UserVariableEntry, MigrationError> {
if let Ok((key, value)) = entry {
let keys: Vec<Result<&str, _>> = key
.split(|&b| b == 0xff || b == 0xfe)
.map(|b| str::from_utf8(b))
.collect();
if let &[_, Ok(room_id), Ok(username), Ok(variable)] = keys.as_slice() {
Ok(UserVariableEntry {
room_id: room_id.to_owned(),
username: username.to_owned(),
variable_name: variable.to_owned(),
value: value,
})
} else {
Err(MigrationError::MigrationFailed(
"a key violates utf8 schema".to_string(),
))
}
} else {
Err(MigrationError::MigrationFailed(
"encountered unexpected key".to_string(),
))
}
}
/// Create an old key, of "variables" 0xff "room id" 0xfe "username" 0xff "variablename".
fn create_old_key(prefix: &[u8], insert: &UserVariableEntry) -> Vec<u8> {
let mut key = vec![];
key.extend_from_slice(&prefix); //prefix already has 0xff.
key.extend_from_slice(&insert.room_id.as_bytes());
key.push(0xff);
key.extend_from_slice(&insert.username.as_bytes());
key.push(0xff);
key.extend_from_slice(&insert.variable_name.as_bytes());
key
}
/// Create a new key, of "username" 0xfe "room id" 0xff "variablename".
fn create_new_key(insert: &UserVariableEntry) -> Vec<u8> {
let mut key = vec![];
key.extend_from_slice(&insert.username.as_bytes());
key.push(0xfe);
key.extend_from_slice(&insert.room_id.as_bytes());
key.push(0xff);
key.extend_from_slice(&insert.variable_name.as_bytes());
key
}
pub fn migrate(db: &Database) -> Result<(), DataError> {
let variables_tree = &db.variables.room_user_variables;
let count_tree = &db.variables.room_user_variable_count;
let prefix = b"variables";
let results: Vec<UserVariableEntry> = variables_tree
.scan_prefix(&prefix)
.map(extract_v1_entries)
.collect::<Result<Vec<_>, MigrationError>>()?;
let mut counts: HashMap<(String, String), i32> = HashMap::new();
let mut batch = Batch::default();
for insert in results {
let count = counts
.entry((insert.username.clone(), insert.room_id.clone()))
.or_insert(0);
*count += 1;
let old = create_old_key(prefix, &insert);
let new = create_new_key(&insert);
batch.remove(old);
batch.insert(new, insert.value);
}
let mut count_batch = Batch::default();
counts.into_iter().for_each(|((username, room_id), count)| {
let mut key = username.as_bytes().to_vec();
key.push(0xfe);
key.extend_from_slice(room_id.as_bytes());
let db_value: I32<LittleEndian> = I32::new(count);
count_batch.insert(key, db_value.as_bytes());
});
variables_tree.apply_batch(batch)?;
count_tree.apply_batch(count_batch)?;
Ok(())
}
}

View File

@ -1,46 +0,0 @@
use crate::context::Context;
use crate::db::sqlite::Variables;
use crate::db::variables::UserAndRoom;
use crate::error::BotError;
use crate::error::DiceRollingError;
use crate::parser::Amount;
use crate::parser::Element as NewElement;
use futures::stream::{self, StreamExt, TryStreamExt};
use std::slice;
/// Calculate the amount of dice to roll by consulting the database
/// and replacing variables with corresponding the amount. Errors out
/// if it cannot find a variable defined, or if the database errors.
pub async fn calculate_single_die_amount(
amount: &Amount,
ctx: &Context<'_>,
) -> Result<i32, BotError> {
calculate_dice_amount(slice::from_ref(amount), ctx).await
}
/// Calculate the amount of dice to roll by consulting the database
/// and replacing variables with corresponding amounts. Errors out if
/// it cannot find a variable defined, or if the database errors.
pub async fn calculate_dice_amount(amounts: &[Amount], ctx: &Context<'_>) -> Result<i32, BotError> {
let stream = stream::iter(amounts);
let variables = &ctx
.db
.get_user_variables(&ctx.username, ctx.room_id().as_str())
.await?;
use DiceRollingError::VariableNotFound;
let dice_amount: i32 = stream
.then(|amount| async move {
match &amount.element {
NewElement::Number(num_dice) => Ok(num_dice * amount.operator.mult()),
NewElement::Variable(variable) => variables
.get(variable)
.ok_or_else(|| VariableNotFound(variable.clone()))
.map(|i| *i),
}
})
.try_fold(0, |total, num_dice| async move { Ok(total + num_dice) })
.await?;
Ok(dice_amount)
}

View File

@ -1,49 +0,0 @@
use crate::db::sqlite::errors::DataError;
use crate::db::sqlite::Rooms;
use crate::error::BotError;
use crate::matrix;
use crate::models::RoomInfo;
use futures::stream::{self, StreamExt, TryStreamExt};
use matrix_sdk::{self, identifiers::RoomId, Client};
/// Record the information about a room, including users in it.
pub async fn record_room_information(
client: &Client,
db: &crate::db::sqlite::Database,
room_id: &RoomId,
room_display_name: &str,
our_username: &str,
) -> Result<(), BotError> {
//Clear out any old room info first.
db.clear_info(room_id.as_str()).await?;
let room_id_str = room_id.as_str();
let usernames = matrix::get_users_in_room(&client, &room_id).await?;
let info = RoomInfo {
room_id: room_id_str.to_owned(),
room_name: room_display_name.to_owned(),
};
// TODO this and the username adding should be one whole
// transaction in the db.
db.insert_room_info(&info).await?;
let filtered_usernames = usernames
.into_iter()
.filter(|username| username != our_username);
// Async collect into vec of results, then use into_iter of result
// to go to from Result<Vec<()>> to just Result<()>. Easier than
// attempting to async-collect our way to a single Result<()>.
stream::iter(filtered_usernames)
.then(|username| async move {
db.add_user_to_room(&username, &room_id_str)
.await
.map_err(|e| e.into())
})
.collect::<Vec<Result<(), BotError>>>()
.await
.into_iter()
.collect()
}

View File

@ -1,67 +0,0 @@
use log::error;
use matrix_sdk::events::room::message::NoticeMessageEventContent;
use matrix_sdk::{
events::room::message::{InReplyTo, Relation},
events::room::message::{MessageEventContent, MessageType},
events::AnyMessageEventContent,
identifiers::EventId,
Error as MatrixError,
};
use matrix_sdk::{identifiers::RoomId, Client};
/// Extracts more detailed error messages out of a matrix SDK error.
fn extract_error_message(error: MatrixError) -> String {
use matrix_sdk::{Error::Http, HttpError};
if let Http(HttpError::Api(ruma_err)) = error {
ruma_err.to_string()
} else {
error.to_string()
}
}
/// Retrieve a list of users in a given room.
pub async fn get_users_in_room(
client: &Client,
room_id: &RoomId,
) -> Result<Vec<String>, MatrixError> {
if let Some(joined_room) = client.get_joined_room(room_id) {
let members = joined_room.joined_members().await?;
Ok(members
.into_iter()
.map(|member| member.user_id().to_string())
.collect())
} else {
Ok(vec![])
}
}
pub async fn send_message(
client: &Client,
room_id: &RoomId,
message: &str,
reply_to: Option<EventId>,
) {
let room = match client.get_joined_room(room_id) {
Some(room) => room,
_ => return,
};
let plain = html2text::from_read(message.as_bytes(), message.len());
let mut content = MessageEventContent::new(MessageType::Notice(
NoticeMessageEventContent::html(plain.trim(), message),
));
content.relates_to = reply_to.map(|event_id| Relation::Reply {
in_reply_to: InReplyTo::new(event_id),
});
let content = AnyMessageEventContent::RoomMessage(content);
let result = room.send(content, None).await;
if let Err(e) = result {
let message = extract_error_message(e);
error!("Error sending message: {}", message);
};
}

View File

@ -1,8 +0,0 @@
use serde::{Deserialize, Serialize};
/// RoomInfo has basic metadata about a room: its name, ID, etc.
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
pub struct RoomInfo {
pub room_id: String,
pub room_name: String,
}