Compare commits

..

59 Commits

Author SHA1 Message Date
projectmoon 86df3c5d1f Do not process commands coming from ourselves (help text)
continuous-integration/drone/push Build is passing Details
2024-09-26 09:18:42 +02:00
projectmoon 38a7e50c5c Don't forget to update xbps on final stage too
continuous-integration/drone/push Build is passing Details
2024-09-25 23:06:20 +02:00
projectmoon e309fd1fc6 Sync xbps and update it before everything else.
continuous-integration/drone/push Build is failing Details
2024-09-25 22:56:02 +02:00
projectmoon 9262fe2cac move xbps update after sync
continuous-integration/drone/push Build is failing Details
2024-09-25 22:41:54 +02:00
projectmoon 724a781e7c Attempt to correct error in docker image
continuous-integration/drone/push Build is failing Details
2024-09-25 22:30:30 +02:00
projectmoon ef074beb96 Drone: Update to Rust 1.80 builder
continuous-integration/drone/push Build is failing Details
continuous-integration/drone Build is failing Details
2024-09-25 21:56:03 +02:00
projectmoon 81a69f329a Update for Rust 1.80.x
continuous-integration/drone/push Build is failing Details
2024-09-25 20:57:19 +02:00
projectmoon c9e7efa61d update to sqlx 0.6
continuous-integration/drone/pr Build is passing Details
continuous-integration/drone/push Build is passing Details
2023-04-13 21:12:04 +02:00
projectmoon f295f2b7b6 Update to Matrix SDK 0.6 (#98)
continuous-integration/drone/push Build is passing Details
Quite a few changes involved. Mostly variable renames and a few changes to `await`s.

Not ready yet because bot cannot login due to some arcane error of `expected value at line 1 column 1`.

Co-authored-by: projectmoon <projectmoon@noreply.git.agnos.is>
Reviewed-on: #98
2023-04-13 19:04:48 +00:00
projectmoon 090ce9be45 Add help topic for variables
continuous-integration/drone/push Build is passing Details
Fixes #60
2023-04-05 07:59:13 +02:00
projectmoon 2a6dff3e07 Update cargo deps
continuous-integration/drone/push Build is failing Details
2023-04-05 07:58:26 +02:00
projectmoon 952f35d53a Rust 1.68 (#99)
continuous-integration/drone/push Build is failing Details
Update to Rust 1.68

Co-authored-by: projectmoon <projectmoon@noreply.git.agnos.is>
Reviewed-on: #99
2023-04-05 05:57:16 +00:00
projectmoon 552daa4746 Add a game system column to room info (#95)
continuous-integration/drone/push Build is passing Details
Adds a new enum and table in preparation for storing game information about a specific room.

Reviewed-on: #95
2022-02-02 20:56:50 +00:00
projectmoon c514b85510 Change modifier order in Cthulhu
continuous-integration/drone/push Build is passing Details
2021-11-06 21:23:51 +00:00
projectmoon 6eb81f43d5 Change CofD modifiers to come after dice pool 2021-11-06 21:23:51 +00:00
projectmoon 44b1e0f649 Switch to working (but somewhat bigger) Void docker image
continuous-integration/drone/push Build is passing Details
2021-11-06 13:47:23 +00:00
projectmoon a8ccdc9cce Update rust test image version for CI.
continuous-integration/drone/push Build is passing Details
2021-11-05 19:37:59 +00:00
projectmoon 13ce7b3ee6 Readme update (aka force build)
continuous-integration/drone/push Build is failing Details
2021-11-05 17:53:53 +00:00
projectmoon 6f09a11586 Upgrade to matrix SDK 0.4. 2021-11-05 15:34:16 +00:00
projectmoon ee3ec18e06 Refactor keep-drop parsing into function, better error handling. (#93)
continuous-integration/drone/push Build is passing Details
This commit refactors the keep-drop parsing into two separate
functions: one for extracting keep-drop text, and one for actually
doing something with the extracted values. An intermediate enum is
introduced to contain extracted text, instead of relying on Ok/Err
values directly for figuring out what to do with the values.

This allows us to express "this behavior is correct, and all others
are not" instead of using a "fall back to secondary functionality"
approach.

Reviewed-on: #93
Co-Authored-By: projectmoon <projectmoon@noreply.git.agnos.is>
Co-Committed-By: projectmoon <projectmoon@noreply.git.agnos.is>
2021-09-30 21:16:00 +00:00
projectmoon 126548d868 Do not panic on invalid dice/sides amount for keep/drop.
continuous-integration/drone/push Build is passing Details
Insted of unwrap(), map error to a nom parser error. Not the best-est
solution, but it is functional. The TooLarge value seems appropriate.
2021-09-26 14:15:12 +00:00
Matthew Sparks 7e7e9e534e Adding None enum to keep/drop, cleaning up matches
continuous-integration/drone/pr Build is passing Details
continuous-integration/drone/push Build is passing Details
2021-09-24 23:03:20 -04:00
Matthew Sparks 2d9853fbf0 Updating README for new drop command
continuous-integration/drone/pr Build is passing Details
2021-09-17 23:15:55 -04:00
Matthew Sparks 3d6210b32d Adding enum for exclusive drop/keep
continuous-integration/drone/pr Build is passing Details
2021-09-17 23:11:13 -04:00
Matthew Sparks 8b5973475f Forgot to fix tests, fixing keep/drop Err case
continuous-integration/drone/pr Build is passing Details
2021-09-17 22:18:23 -04:00
Matthew Sparks 1992ef4e08 Updating roll doc
continuous-integration/drone/pr Build is failing Details
2021-09-17 22:08:51 -04:00
Matthew Sparks f904e3a948 Updating match blocks for keep/drop
continuous-integration/drone/pr Build is failing Details
2021-09-17 21:45:30 -04:00
Matthew Sparks 8317f40f61 Updating README for keep/drop
continuous-integration/drone/pr Build is passing Details
2021-09-16 23:25:26 -04:00
Matthew Sparks 069ee47364 Adding drop function 2021-09-16 22:55:11 -04:00
Matthew Sparks dc242182f4 Fix string comparison in keep/count check, and add test cases 2021-09-07 23:59:49 -04:00
Matthew Sparks 15163ac11d Adding calculations for keep, and adding validation on keep input 2021-09-07 22:10:14 -04:00
Matthew Sparks 1860eaf378 Adding parsing for keeping highest dice 2021-09-06 21:43:46 -04:00
Matthew Sparks 2654887d8c Initial commit to add keep to dice struct and preserve parser test cases 2021-09-06 21:43:46 -04:00
projectmoon 125f3d0cee Fix drone yml to produce docker images again.
continuous-integration/drone/push Build is passing Details
2021-09-06 23:58:05 +00:00
projectmoon a4c3d34a97 Version 0.13.1
continuous-integration/drone/push Build is passing Details
2021-09-06 22:21:24 +00:00
projectmoon 86fbb05e54 Run Drone CI on tags
continuous-integration/drone/push Build is passing Details
2021-09-06 22:18:06 +00:00
projectmoon 661a943672 Readme Updates (#91)
continuous-integration/drone/push Build was killed Details
Add contributing information.

Add support/community section.

Add matrix room badge
Co-Authored-By: projectmoon <projectmoon@noreply.git.agnos.is>
Co-Committed-By: projectmoon <projectmoon@noreply.git.agnos.is>
2021-09-06 22:15:20 +00:00
projectmoon d65715dee6 Remove example room ID from tonic_client
continuous-integration/drone/push Build is passing Details
2021-09-05 20:38:45 +00:00
projectmoon 55a3bfb861 Update readme for crates.io installation. 2021-09-05 20:38:09 +00:00
projectmoon 0050810182 Fix dicebot readme link
continuous-integration/drone/push Build is passing Details
2021-09-05 20:22:42 +00:00
projectmoon 3ba546d4a4 Add metadata to rpc package.
continuous-integration/drone/push Build is passing Details
2021-09-05 20:14:56 +00:00
projectmoon ffded7b572 Add metadata to rpc package. 2021-09-05 20:14:13 +00:00
projectmoon cf93d14913 Version 0.13.0
continuous-integration/drone/push Build is passing Details
2021-09-05 19:08:27 +00:00
projectmoon cf6dd96b34 Update sqlx and refinery to newer versions (#88)
continuous-integration/drone/push Build is passing Details
For some reason, also required rewriting database tests to deal with
tempfile deleting files after scope drop. This never used to occur,
but now it does! So now the unit tests are in a closure where the temp
file is dropped at the end of the test. Really should just use sqlx
migrations, then we can get an in-memory database.

Co-authored-by: projectmoon <projectmoon@agnos.is>
Reviewed-on: #88
Co-Authored-By: projectmoon <projectmoon@noreply.git.agnos.is>
Co-Committed-By: projectmoon <projectmoon@noreply.git.agnos.is>
2021-09-05 07:56:41 +00:00
projectmoon c8c6f4d6f0 Fix dependency specification for rpc crate in dicebot.
continuous-integration/drone/push Build is passing Details
2021-09-04 23:24:52 +00:00
projectmoon 2488429edb Version 0.12.0
continuous-integration/drone/push Build is passing Details
2021-09-04 22:23:36 +00:00
projectmoon f68d5ffcc1 Update to versioned matrix SDK.
continuous-integration/drone/push Build is passing Details
2021-09-04 21:37:49 +00:00
projectmoon 473e899275 Merge branch 'kg333-master'
continuous-integration/drone/push Build is passing Details
Merge PR #43 from github to fix docker build.
2021-09-03 09:33:02 +00:00
projectmoon 1f03837bfe Merge branch 'master' of https://github.com/kg333/matrix-dicebot into kg333-master 2021-09-03 09:32:48 +00:00
projectmoon 0059e3d133 Revert "Initial prototype of web UI and web API."
continuous-integration/drone/push Build is failing Details
This reverts commit cab856241d.
2021-09-03 09:29:52 +00:00
matthew 915b82d0aa Updating GPG key server; sks-keyservers.net is offline permanently 2021-08-28 00:12:12 +00:00
projectmoon cab856241d Initial prototype of web UI and web API.
continuous-integration/drone/push Build is failing Details
This commit shuffles the entire repository around into multiple crates, bringing with it an in-progress web UI and web AI. It was merged prematurely to allow for dependency upgrades of the matrix SDK.

The build should still only produce the dicebot image.
Co-Authored-By: projectmoon <projectmoon@noreply.git.agnos.is>
Co-Committed-By: projectmoon <projectmoon@noreply.git.agnos.is>
2021-07-15 15:04:50 +00:00
projectmoon 764426382a Convert project to workspace with Tonic for gRPC. (#84)
continuous-integration/drone/push Build is passing Details
Convert project to workspace with Tonic for gRPC.

This commit adds an RPC service to the dicebot, allowing external
applications to control it. The project was converted to a cargo
workspace to house the protobuf definitions in a common crate
(tenebrous-rpc), so that clients and servers can make use of these
protobuf definitions.
Co-Authored-By: projectmoon <projectmoon@noreply.git.agnos.is>
Co-Committed-By: projectmoon <projectmoon@noreply.git.agnos.is>
2021-06-02 21:09:58 +00:00
projectmoon b4321721c4 Minor documentation update.
continuous-integration/drone/push Build is passing Details
2021-05-30 22:53:56 +00:00
projectmoon 494d28486e Remove Box<dyn Command> conversion impls for map in macro.
continuous-integration/drone/push Build is passing Details
2021-05-30 22:49:28 +00:00
projectmoon b7393c1907 Use active room in relevant commands.
continuous-integration/drone/pr Build is passing Details
continuous-integration/drone/push Build is passing Details
2021-05-30 14:19:13 +00:00
projectmoon 3d2eb14cd3 Change room in context to origin_room, add active_room.
The context now knows about origin room (the room where the command
was executed), and the "active room," which is the room that the user
wants the command to apply to. If no active room is defined, then the
origin room acts as the active room. In a public room with the bot,
the active room is also the same as the origin room.
2021-05-30 14:18:56 +00:00
projectmoon 53339282e0 Actually set room when running SetRoomCommand (#79)
continuous-integration/drone/push Build is passing Details
Also sort rooms in get_rooms_for_user for consistency.
Co-Authored-By: projectmoon <projectmoon@noreply.git.agnos.is>
Co-Committed-By: projectmoon <projectmoon@noreply.git.agnos.is>
2021-05-29 20:26:20 +00:00
projectmoon 7050cf037a Remove return statements in Fuseable impl for room search.
continuous-integration/drone/push Build is passing Details
2021-05-29 14:49:24 +00:00
75 changed files with 4273 additions and 2427 deletions

View File

@ -3,18 +3,20 @@ name: build-and-test
steps: steps:
- name: test - name: test
image: rust:1.51 image: rust:1.80
commands: commands:
- apt-get update - apt-get update
- apt-get install -y cmake - apt-get install -y cmake
- rustup component add rustfmt
- cargo build --verbose --all - cargo build --verbose --all
- cargo test --verbose --all - cargo test --verbose --all
- name: docker - name: docker
image: plugins/docker image: plugins/docker
when: when:
branch: ref:
- master - refs/tags/v*
- refs/heads/master
settings: settings:
auto_tag: true auto_tag: true
username: username:

3218
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -1,47 +1,6 @@
[package] [workspace]
name = "tenebrous-dicebot"
version = "0.10.0"
authors = ["Taylor C. Richberger <taywee@gmx.com>", "projectmoon <projectmoon@agnos.is>"]
edition = "2018"
license = 'AGPL-3.0-or-later'
description = 'An async Matrix dice bot for role-playing games'
readme = 'README.md'
repository = 'https://git.agnos.is/projectmoon/matrix-dicebot'
keywords = ["games", "dice", "matrix", "bot"]
categories = ["games"]
[dependencies] members = [
log = "0.4" "dicebot",
tracing-subscriber = "0.2" "rpc"
toml = "0.5" ]
nom = "5"
rand = "0.8"
rust-argon2 = "0.8"
thiserror = "1.0"
itertools = "0.10"
async-trait = "0.1"
url = "2.1"
dirs = "3.0"
indoc = "1.0"
combine = "4.5"
futures = "0.3"
html2text = "0.2"
phf = { version = "0.8", features = ["macros"] }
matrix-sdk = { git = "https://github.com/matrix-org/matrix-rust-sdk", branch = "master" }
refinery = { version = "0.5", features = ["rusqlite"]}
barrel = { version = "0.6", features = ["sqlite3"] }
tempfile = "3"
substring = "1.4"
fuse-rust = "0.2"
[dependencies.sqlx]
version = "0.5"
features = [ "offline", "sqlite", "runtime-tokio-native-tls" ]
[dependencies.serde]
version = "1"
features = ['derive']
[dependencies.tokio]
version = "1"
features = [ "full" ]

View File

@ -1,16 +1,15 @@
# Builder image with development dependencies. # Builder image with development dependencies.
FROM bougyman/voidlinux:glibc as builder FROM ghcr.io/void-linux/void-linux:latest-mini-x86_64 as builder
RUN xbps-install -S
RUN xbps-install -yu xbps
RUN xbps-install -Syu RUN xbps-install -Syu
RUN xbps-install -Sy base-devel rustup cargo cmake wget gnupg RUN xbps-install -Sy base-devel rustup cmake wget gnupg
RUN xbps-install -Sy openssl-devel libstdc++-devel RUN xbps-install -Sy openssl-devel libstdc++-devel
RUN rustup-init -qy RUN rustup-init -qy
# Install tini for signal processing and zombie killing # Install tini for signal processing and zombie killing
ENV TINI_VERSION v0.19.0 ENV TINI_VERSION v0.19.0
ADD https://github.com/krallin/tini/releases/download/${TINI_VERSION}/tini /usr/local/bin/tini ADD https://github.com/krallin/tini/releases/download/${TINI_VERSION}/tini /usr/local/bin/tini
ADD https://github.com/krallin/tini/releases/download/${TINI_VERSION}/tini.asc /tini.asc
RUN gpg --batch --keyserver hkp://p80.pool.sks-keyservers.net:80 --recv-keys 595E85A6B1B4779EA4DAAEC70B588DFF0527A9B7 \
&& gpg --batch --verify /tini.asc /usr/local/bin/tini
RUN chmod +x /usr/local/bin/tini RUN chmod +x /usr/local/bin/tini
# Build dicebot # Build dicebot
@ -20,7 +19,10 @@ ADD . ./
RUN . /root/.cargo/env && cargo build --release RUN . /root/.cargo/env && cargo build --release
# Final image # Final image
FROM bougyman/voidlinux:tiny FROM ghcr.io/void-linux/void-linux:latest-mini-x86_64
RUN xbps-install -S
RUN xbps-install -yu xbps
RUN xbps-install -Syu
RUN xbps-install -Sy ca-certificates libstdc++ RUN xbps-install -Sy ca-certificates libstdc++
COPY --from=builder \ COPY --from=builder \
/root/src/target/release/dicebot \ /root/src/target/release/dicebot \

View File

@ -1,6 +1,7 @@
# Tenebrous Dicebot # Tenebrous Dicebot
[![Build Status](https://drone.agnos.is/api/badges/projectmoon/tenebrous-dicebot/status.svg)](https://drone.agnos.is/projectmoon/tenebrous-dicebot) [![Build Status](https://drone.agnos.is/api/badges/projectmoon/tenebrous-dicebot/status.svg)](https://drone.agnos.is/projectmoon/tenebrous-dicebot)
[![Matrix Chat](https://img.shields.io/matrix/tenebrous:agnos.is?label=matrix&server_fqdn=matrix.org)][matrix-room]
_This repository is hosted on [Agnos.is Git][main-repo] and mirrored _This repository is hosted on [Agnos.is Git][main-repo] and mirrored
to [GitHub][github-repo]._ to [GitHub][github-repo]._
@ -24,6 +25,23 @@ System.
* Works in encrypted or unencrypted Matrix rooms. * Works in encrypted or unencrypted Matrix rooms.
* Storing variables created by the user. * Storing variables created by the user.
## Support and Community
The project has a Matrix room at [#tenebrous:agnos.is][matrix-room].
It is also possible to make a post in [GitHub
Discussions][github-discussions].
For reporting bugs, we prefer that you open an issue on
[git.agnos.is][agnosis-git-issues]. However, you may also open an
issue on [GitHub][github-issues].
### Development and Contributions
All development occurs on [git.agnos.is][main-repo]. If you wish to
contribute, please open a pull request there. In some cases, pull
requests from GitHub may be accepted. All contributions must be
licensed under [AGPL 3.0 or later][agpl] to be accepted.
## Building and Installation ## Building and Installation
### Docker Image ### Docker Image
@ -46,6 +64,17 @@ root of the repository.
After pulling or building the image, see [instructions on how to use After pulling or building the image, see [instructions on how to use
the Docker image](#running-the-bot). the Docker image](#running-the-bot).
### Install from crates.io
The project can be from [crates.io][crates-io]. To install it, execute
`cargo install tenebrous-dicebot`. This will make the following
executables available on your system:
* `dicebot`: Main dicebot executable.
* `dicebot-cmd`: Run dicebot commands from the command line.
* `dicebot_migrate`: Standalone database migrator (not required).
* `tonic_client`: Test client for the gRPC connection (not required).
### Build from Source ### Build from Source
Precompiled executables are not yet available. Clone this repository Precompiled executables are not yet available. Clone this repository
@ -89,8 +118,16 @@ expressions.
!r 3d12 - 5d2 + 3 - 7d3 + 20d20 !r 3d12 - 5d2 + 3 - 7d3 + 20d20
``` ```
This system does not yet have the capability to handle things like D&D #### Keep/Drop Dice
5e advantage or disadvantage. The bot supports either keeping the highest dice in a roll, or
dropping the highest dice in a roll. This allows the bot to handle
things like D&D 5e advantage or disadvantage.
```
!roll 2d20k1
!r 2d20dh1 + 5
!r 10d10k5 + 10d10dh5 - 2
```
### Storytelling System ### Storytelling System
@ -241,6 +278,7 @@ The most basic plans are:
* Perhaps some sort of character sheet integration. But for that, we * Perhaps some sort of character sheet integration. But for that, we
would need a sheet service. would need a sheet service.
* Use environment variables instead of config file in Docker image. * Use environment variables instead of config file in Docker image.
* Per-system game rules.
## Credits ## Credits
@ -254,3 +292,9 @@ support added for Chronicles of Darkness and Call of Cthulhu.
[main-repo]: https://git.agnos.is/projectmoon/tenebrous-dicebot [main-repo]: https://git.agnos.is/projectmoon/tenebrous-dicebot
[github-repo]: https://github.com/ProjectMoon/matrix-dicebot [github-repo]: https://github.com/ProjectMoon/matrix-dicebot
[roadmap]: https://git.agnos.is/projectmoon/tenebrous-dicebot/wiki/Roadmap [roadmap]: https://git.agnos.is/projectmoon/tenebrous-dicebot/wiki/Roadmap
[crates-io]: https://crates.io/crates/tenebrous-dicebot
[matrix-room]: https://matrix.to/#/#tenebrous:agnos.is
[agnosis-git-issues]: https://git.agnos.is/projectmoon/tenebrous-dicebot/issues
[github-discussions]: https://github.com/ProjectMoon/matrix-dicebot/discussions
[github-issues]: https://github.com/ProjectMoon/matrix-dicebot/issues
[agpl]: https://www.gnu.org/licenses/agpl-3.0.en.html

57
dicebot/Cargo.toml Normal file
View File

@ -0,0 +1,57 @@
[package]
name = "tenebrous-dicebot"
version = "0.13.2"
rust-version = "1.68"
authors = ["projectmoon <projectmoon@agnos.is>", "Taylor C. Richberger <taywee@gmx.com>"]
edition = "2018"
license = 'AGPL-3.0-or-later'
description = 'An async Matrix dice bot for role-playing games'
readme = '../README.md'
repository = 'https://git.agnos.is/projectmoon/matrix-dicebot'
keywords = ["games", "dice", "matrix", "bot"]
categories = ["games"]
[build-dependencies]
tonic-build = "0.4"
[dependencies]
# indexmap version locked fixes a dependency cycle.
# indexmap = "=1.6.2"
log = "0.4"
tracing-subscriber = "0.2"
toml = "0.5"
nom = "5"
rand = "0.8"
rust-argon2 = "0.8"
thiserror = "1.0"
itertools = "0.10"
async-trait = "0.1"
url = "2.1"
dirs = "3.0"
indoc = "1.0"
combine = "4.5"
futures = "0.3"
html2text = "0.2"
phf = { version = "0.8", features = ["macros"] }
matrix-sdk = { version = "0.6" }
refinery = { version = "0.8", features = ["rusqlite"]}
barrel = { version = "0.7", features = ["sqlite3"] }
strum = { version = "0.22", features = ["derive"] }
tempfile = "3"
substring = "1.4"
fuse-rust = "0.2"
tonic = "0.4"
prost = "0.7"
tenebrous-rpc = { path = "../rpc", version = "0.1.0" }
[dependencies.sqlx]
version = "0.6"
features = [ "offline", "sqlite", "runtime-tokio-native-tls" ]
[dependencies.serde]
version = "1"
features = ['derive']
[dependencies.tokio]
version = "1"
features = [ "full" ]

View File

@ -6,23 +6,52 @@
use std::fmt; use std::fmt;
use std::ops::{Deref, DerefMut}; use std::ops::{Deref, DerefMut};
//Old stuff, for regular dice rolling. To be moved elsewhere. /// A basic dice roll, in XdY notation, like "1d4" or "3d6".
/// Optionally supports D&D advantage/disadvantge keep-or-drop
/// functionality.
#[derive(Debug, PartialEq, Eq, Clone, Copy)] #[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub struct Dice { pub struct Dice {
pub(crate) count: u32, pub(crate) count: u32,
pub(crate) sides: u32, pub(crate) sides: u32,
pub(crate) keep_drop: KeepOrDrop,
}
/// Enum indicating how to handle bonuses or penalties using extra
/// dice. If set to Keep, the roll will keep the highest X number of
/// dice in the roll, and add those together. If set to Drop, the
/// opposite is performed, and the lowest X number of dice are added
/// instead. If set to None, then all dice in the roll are added up as
/// normal.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum KeepOrDrop {
/// Keep only the X highest dice for adding up to the total.
Keep(u32),
/// Keep only the X lowest dice (i.e. drop the highest) for adding
/// up to the total.
Drop(u32),
/// Add up all dice in the roll for the total.
None,
} }
impl fmt::Display for Dice { impl fmt::Display for Dice {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}d{}", self.count, self.sides) match self.keep_drop {
KeepOrDrop::Keep(keep) => write!(f, "{}d{}k{}", self.count, self.sides, keep),
KeepOrDrop::Drop(drop) => write!(f, "{}d{}dh{}", self.count, self.sides, drop),
KeepOrDrop::None => write!(f, "{}d{}", self.count, self.sides),
}
} }
} }
impl Dice { impl Dice {
pub fn new(count: u32, sides: u32) -> Dice { pub fn new(count: u32, sides: u32, keep_drop: KeepOrDrop) -> Dice {
Dice { count, sides } Dice {
count,
sides,
keep_drop,
}
} }
} }

360
dicebot/src/basic/parser.rs Normal file
View File

@ -0,0 +1,360 @@
/**
* In addition to the terms of the AGPL, this file is governed by the
* terms of the MIT license, from the original axfive-matrix-dicebot
* project.
*/
use nom::bytes::complete::take_while;
use nom::error::ErrorKind as NomErrorKind;
use nom::Err as NomErr;
use nom::{
alt, bytes::complete::tag, character::complete::digit1, complete, many0, named,
sequence::tuple, tag, IResult,
};
use super::dice::*;
//******************************
//Legacy Code
//******************************
fn is_whitespace(input: char) -> bool {
input == ' ' || input == '\n' || input == '\t' || input == '\r'
}
/// Eat whitespace, returning it
pub fn eat_whitespace(input: &str) -> IResult<&str, &str> {
let (input, whitespace) = take_while(is_whitespace)(input)?;
Ok((input, whitespace))
}
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
enum Sign {
Plus,
Minus,
}
/// Intermediate parsed value for a keep-drop expression to indicate
/// which one it is.
enum ParsedKeepOrDrop<'a> {
Keep(&'a str),
Drop(&'a str),
NotPresent,
}
macro_rules! too_big {
($input: expr) => {
NomErr::Error(($input, NomErrorKind::TooLarge))
};
}
/// Parse a dice expression. Does not eat whitespace
fn parse_dice(input: &str) -> IResult<&str, Dice> {
let (input, (count, _, sides)) = tuple((digit1, tag("d"), digit1))(input)?;
let count: u32 = count.parse().map_err(|_| too_big!(count))?;
let sides = sides.parse().map_err(|_| too_big!(sides))?;
let (input, keep_drop) = parse_keep_or_drop(input, count)?;
Ok((input, Dice::new(count, sides, keep_drop)))
}
/// Extract keep/drop number as a string. Fails if the value is not a
/// string.
fn parse_keep_or_drop_text<'a>(
symbol: &'a str,
input: &'a str,
) -> IResult<&'a str, ParsedKeepOrDrop<'a>> {
let (parsed_kd, input) = match tuple::<&str, _, (_, _), _>((tag(symbol), digit1))(input) {
// if ok, one of the expressions is present
Ok((rest, (_, kd_expr))) => match symbol {
"k" => (ParsedKeepOrDrop::Keep(kd_expr), rest),
"dh" => (ParsedKeepOrDrop::Drop(kd_expr), rest),
_ => panic!("Unrecogized keep-drop symbol: {}", symbol),
},
// otherwise absent (attempt to keep all dice)
Err(_) => (ParsedKeepOrDrop::NotPresent, input),
};
Ok((input, parsed_kd))
}
/// Parse keep/drop expression, which consits of "k" or "dh" following
/// a dice expression. For example, "1d4h3" or "1d4dh2".
fn parse_keep_or_drop<'a>(input: &'a str, count: u32) -> IResult<&'a str, KeepOrDrop> {
let (input, keep) = parse_keep_or_drop_text("k", input)?;
let (input, drop) = parse_keep_or_drop_text("dh", input)?;
use ParsedKeepOrDrop::*;
let keep_drop: KeepOrDrop = match (keep, drop) {
//Potential valid Keep expression.
(Keep(keep), NotPresent) => match keep.parse().map_err(|_| too_big!(input))? {
_i if _i > count || _i == 0 => Ok(KeepOrDrop::None),
i => Ok(KeepOrDrop::Keep(i)),
},
//Potential valid Drop expression.
(NotPresent, Drop(drop)) => match drop.parse().map_err(|_| too_big!(input))? {
_i if _i >= count => Ok(KeepOrDrop::None),
i => Ok(KeepOrDrop::Drop(i)),
},
//No Keep or Drop specified; regular behavior.
(NotPresent, NotPresent) => Ok(KeepOrDrop::None),
//Anything else is an error.
_ => Err(NomErr::Error((input, NomErrorKind::Many1))),
}?;
Ok((input, keep_drop))
}
// Parse a single digit expression. Does not eat whitespace
fn parse_bonus(input: &str) -> IResult<&str, u32> {
let (input, bonus) = digit1(input)?;
Ok((input, bonus.parse().unwrap()))
}
// Parse a sign expression. Eats whitespace.
fn parse_sign(input: &str) -> IResult<&str, Sign> {
let (input, _) = eat_whitespace(input)?;
named!(sign(&str) -> Sign, alt!(
complete!(tag!("+")) => { |_| Sign::Plus } |
complete!(tag!("-")) => { |_| Sign::Minus }
));
let (input, sign) = sign(input)?;
Ok((input, sign))
}
// Parse an element expression. Eats whitespace.
fn parse_element(input: &str) -> IResult<&str, Element> {
let (input, _) = eat_whitespace(input)?;
named!(element(&str) -> Element, alt!(
parse_dice => { |d| Element::Dice(d) } |
parse_bonus => { |b| Element::Bonus(b) }
));
let (input, element) = element(input)?;
Ok((input, element))
}
// Parse a signed element expression. Eats whitespace.
fn parse_signed_element(input: &str) -> IResult<&str, SignedElement> {
let (input, _) = eat_whitespace(input)?;
let (input, sign) = parse_sign(input)?;
let (input, _) = eat_whitespace(input)?;
let (input, element) = parse_element(input)?;
let element = match sign {
Sign::Plus => SignedElement::Positive(element),
Sign::Minus => SignedElement::Negative(element),
};
Ok((input, element))
}
// Parse a full element expression. Eats whitespace.
pub fn parse_element_expression(input: &str) -> IResult<&str, ElementExpression> {
named!(first_element(&str) -> SignedElement, alt!(
parse_signed_element => { |e| e } |
parse_element => { |e| SignedElement::Positive(e) }
));
let (input, first) = first_element(input)?;
let (input, rest) = if input.trim().is_empty() {
(input, vec![first])
} else {
named!(rest_elements(&str) -> Vec<SignedElement>, many0!(parse_signed_element));
let (input, mut rest) = rest_elements(input)?;
rest.insert(0, first);
(input, rest)
};
Ok((input, ElementExpression(rest)))
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn dice_test() {
assert_eq!(
parse_dice("2d4"),
Ok(("", Dice::new(2, 4, KeepOrDrop::None)))
);
assert_eq!(
parse_dice("20d40"),
Ok(("", Dice::new(20, 40, KeepOrDrop::None)))
);
assert_eq!(
parse_dice("8d7"),
Ok(("", Dice::new(8, 7, KeepOrDrop::None)))
);
assert_eq!(
parse_dice("2d20k1"),
Ok(("", Dice::new(2, 20, KeepOrDrop::Keep(1))))
);
assert_eq!(
parse_dice("100d10k90"),
Ok(("", Dice::new(100, 10, KeepOrDrop::Keep(90))))
);
assert_eq!(
parse_dice("11d10k10"),
Ok(("", Dice::new(11, 10, KeepOrDrop::Keep(10))))
);
assert_eq!(
parse_dice("12d10k11"),
Ok(("", Dice::new(12, 10, KeepOrDrop::Keep(11))))
);
assert_eq!(
parse_dice("12d10k13"),
Ok(("", Dice::new(12, 10, KeepOrDrop::None)))
);
assert_eq!(
parse_dice("12d10k0"),
Ok(("", Dice::new(12, 10, KeepOrDrop::None)))
);
assert_eq!(
parse_dice("20d40dh5"),
Ok(("", Dice::new(20, 40, KeepOrDrop::Drop(5))))
);
assert_eq!(
parse_dice("8d7dh9"),
Ok(("", Dice::new(8, 7, KeepOrDrop::None)))
);
assert_eq!(
parse_dice("8d7dh8"),
Ok(("", Dice::new(8, 7, KeepOrDrop::None)))
);
}
#[test]
fn cant_have_both_keep_and_drop_test() {
let res = parse_dice("1d4k3dh2");
assert!(res.is_err());
match res {
Err(NomErr::Error((_, kind))) => {
assert_eq!(kind, NomErrorKind::Many1);
}
_ => panic!("Got success, expected error"),
}
}
#[test]
fn big_number_of_dice_doesnt_crash_test() {
let res = parse_dice("64378631476346123874527551481376547657868536d4");
assert!(res.is_err());
match res {
Err(NomErr::Error((input, kind))) => {
assert_eq!(kind, NomErrorKind::TooLarge);
assert_eq!(input, "64378631476346123874527551481376547657868536");
}
_ => panic!("Got success, expected error"),
}
}
#[test]
fn big_number_of_sides_doesnt_crash_test() {
let res = parse_dice("1d423562312587425472658956278456298376234876");
assert!(res.is_err());
match res {
Err(NomErr::Error((input, kind))) => {
assert_eq!(kind, NomErrorKind::TooLarge);
assert_eq!(input, "423562312587425472658956278456298376234876");
}
_ => panic!("Got success, expected error"),
}
}
#[test]
fn element_test() {
assert_eq!(
parse_element(" \t\n\r\n 8d7 \n"),
Ok((" \n", Element::Dice(Dice::new(8, 7, KeepOrDrop::None))))
);
assert_eq!(
parse_element(" \t\n\r\n 3d20k2 \n"),
Ok((" \n", Element::Dice(Dice::new(3, 20, KeepOrDrop::Keep(2)))))
);
assert_eq!(
parse_element(" \t\n\r\n 8 \n"),
Ok((" \n", Element::Bonus(8)))
);
}
#[test]
fn signed_element_test() {
assert_eq!(
parse_signed_element("+ 7"),
Ok(("", SignedElement::Positive(Element::Bonus(7))))
);
assert_eq!(
parse_signed_element(" \t\n\r\n- 8 \n"),
Ok((" \n", SignedElement::Negative(Element::Bonus(8))))
);
assert_eq!(
parse_signed_element(" \t\n\r\n- 8d4 \n"),
Ok((
" \n",
SignedElement::Negative(Element::Dice(Dice::new(8, 4, KeepOrDrop::None)))
))
);
assert_eq!(
parse_signed_element(" \t\n\r\n- 8d4k4 \n"),
Ok((
" \n",
SignedElement::Negative(Element::Dice(Dice::new(8, 4, KeepOrDrop::Keep(4))))
))
);
assert_eq!(
parse_signed_element(" \t\n\r\n+ 8d4 \n"),
Ok((
" \n",
SignedElement::Positive(Element::Dice(Dice::new(8, 4, KeepOrDrop::None)))
))
);
}
#[test]
fn element_expression_test() {
assert_eq!(
parse_element_expression("8d4"),
Ok((
"",
ElementExpression(vec![SignedElement::Positive(Element::Dice(Dice::new(
8,
4,
KeepOrDrop::None
)))])
))
);
assert_eq!(
parse_element_expression("\t2d20k1 + 5"),
Ok((
"",
ElementExpression(vec![
SignedElement::Positive(Element::Dice(Dice::new(2, 20, KeepOrDrop::Keep(1)))),
SignedElement::Positive(Element::Bonus(5)),
])
))
);
assert_eq!(
parse_element_expression(" - 8d4 \n "),
Ok((
" \n ",
ElementExpression(vec![SignedElement::Negative(Element::Dice(Dice::new(
8,
4,
KeepOrDrop::None
)))])
))
);
assert_eq!(
parse_element_expression("\t3d4k2 + 7 - 5 - 6d12dh3 + 1d1 + 53 1d5 "),
Ok((
" 1d5 ",
ElementExpression(vec![
SignedElement::Positive(Element::Dice(Dice::new(3, 4, KeepOrDrop::Keep(2)))),
SignedElement::Positive(Element::Bonus(7)),
SignedElement::Negative(Element::Bonus(5)),
SignedElement::Negative(Element::Dice(Dice::new(6, 12, KeepOrDrop::Drop(3)))),
SignedElement::Positive(Element::Dice(Dice::new(1, 1, KeepOrDrop::None))),
SignedElement::Positive(Element::Bonus(53)),
])
))
);
}
}

View File

@ -4,6 +4,7 @@
* project. * project.
*/ */
use crate::basic::dice; use crate::basic::dice;
use crate::basic::dice::KeepOrDrop;
use rand::prelude::*; use rand::prelude::*;
use std::fmt; use std::fmt;
use std::ops::{Deref, DerefMut}; use std::ops::{Deref, DerefMut};
@ -19,15 +20,27 @@ pub trait Rolled {
} }
#[derive(Debug, PartialEq, Eq, Clone)] #[derive(Debug, PartialEq, Eq, Clone)]
pub struct DiceRoll(pub Vec<u32>); /// array of rolls in order, how many dice to keep, and how many to drop
/// keep indicates how many of the highest dice to keep
/// drop indicates how many of the highest dice to drop
pub struct DiceRoll (pub Vec<u32>, usize, usize);
impl DiceRoll { impl DiceRoll {
pub fn rolls(&self) -> &[u32] { pub fn rolls(&self) -> &[u32] {
&self.0 &self.0
} }
pub fn keep(&self) -> usize {
self.1
}
pub fn drop(&self) -> usize {
self.2
}
// only count kept dice in total
pub fn total(&self) -> u32 { pub fn total(&self) -> u32 {
self.0.iter().sum() self.0[self.2..self.1].iter().sum()
} }
} }
@ -41,11 +54,21 @@ impl fmt::Display for DiceRoll {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", self.rolled_value())?; write!(f, "{}", self.rolled_value())?;
let rolls = self.rolls(); let rolls = self.rolls();
let mut iter = rolls.iter(); let keep = self.keep();
let drop = self.drop();
let mut iter = rolls.iter().enumerate();
if let Some(first) = iter.next() { if let Some(first) = iter.next() {
write!(f, " ({}", first)?; if drop != 0 {
write!(f, " ([{}]", first.1)?;
} else {
write!(f, " ({}", first.1)?;
}
for roll in iter { for roll in iter {
write!(f, " + {}", roll)?; if roll.0 >= keep || roll.0 < drop {
write!(f, " + [{}]", roll.1)?;
} else {
write!(f, " + {}", roll.1)?;
}
} }
write!(f, ")")?; write!(f, ")")?;
} }
@ -58,11 +81,17 @@ impl Roll for dice::Dice {
fn roll(&self) -> DiceRoll { fn roll(&self) -> DiceRoll {
let mut rng = rand::thread_rng(); let mut rng = rand::thread_rng();
let rolls: Vec<_> = (0..self.count) let mut rolls: Vec<_> = (0..self.count)
.map(|_| rng.gen_range(1..=self.sides)) .map(|_| rng.gen_range(1..=self.sides))
.collect(); .collect();
// sort rolls in descending order
rolls.sort_by(|a, b| b.cmp(a));
DiceRoll(rolls) match self.keep_drop {
KeepOrDrop::Keep(k) => DiceRoll(rolls,k as usize, 0),
KeepOrDrop::Drop(dh) => DiceRoll(rolls,self.count as usize, dh as usize),
KeepOrDrop::None => DiceRoll(rolls,self.count as usize, 0),
}
} }
} }
@ -198,18 +227,26 @@ mod tests {
use super::*; use super::*;
#[test] #[test]
fn dice_roll_display_test() { fn dice_roll_display_test() {
assert_eq!(DiceRoll(vec![1, 3, 4]).to_string(), "8 (1 + 3 + 4)"); assert_eq!(DiceRoll(vec![1, 3, 4], 3, 0).to_string(), "8 (1 + 3 + 4)");
assert_eq!(DiceRoll(vec![]).to_string(), "0"); assert_eq!(DiceRoll(vec![], 0, 0).to_string(), "0");
assert_eq!( assert_eq!(
DiceRoll(vec![4, 7, 2, 10]).to_string(), DiceRoll(vec![4, 7, 2, 10], 4, 0).to_string(),
"23 (4 + 7 + 2 + 10)" "23 (4 + 7 + 2 + 10)"
); );
assert_eq!(
DiceRoll(vec![20, 13, 11, 10], 3, 0).to_string(),
"44 (20 + 13 + 11 + [10])"
);
assert_eq!(
DiceRoll(vec![20, 13, 11, 10], 4, 1).to_string(),
"34 ([20] + 13 + 11 + 10)"
);
} }
#[test] #[test]
fn element_roll_display_test() { fn element_roll_display_test() {
assert_eq!( assert_eq!(
ElementRoll::Dice(DiceRoll(vec![1, 3, 4])).to_string(), ElementRoll::Dice(DiceRoll(vec![1, 3, 4], 3, 0)).to_string(),
"8 (1 + 3 + 4)" "8 (1 + 3 + 4)"
); );
assert_eq!(ElementRoll::Bonus(7).to_string(), "7"); assert_eq!(ElementRoll::Bonus(7).to_string(), "7");
@ -218,11 +255,11 @@ mod tests {
#[test] #[test]
fn signed_element_roll_display_test() { fn signed_element_roll_display_test() {
assert_eq!( assert_eq!(
SignedElementRoll::Positive(ElementRoll::Dice(DiceRoll(vec![1, 3, 4]))).to_string(), SignedElementRoll::Positive(ElementRoll::Dice(DiceRoll(vec![1, 3, 4], 3, 0))).to_string(),
"8 (1 + 3 + 4)" "8 (1 + 3 + 4)"
); );
assert_eq!( assert_eq!(
SignedElementRoll::Negative(ElementRoll::Dice(DiceRoll(vec![1, 3, 4]))).to_string(), SignedElementRoll::Negative(ElementRoll::Dice(DiceRoll(vec![1, 3, 4], 3, 0))).to_string(),
"-8 (1 + 3 + 4)" "-8 (1 + 3 + 4)"
); );
assert_eq!( assert_eq!(
@ -239,14 +276,14 @@ mod tests {
fn element_expression_roll_display_test() { fn element_expression_roll_display_test() {
assert_eq!( assert_eq!(
ElementExpressionRoll(vec![SignedElementRoll::Positive(ElementRoll::Dice( ElementExpressionRoll(vec![SignedElementRoll::Positive(ElementRoll::Dice(
DiceRoll(vec![1, 3, 4]) DiceRoll(vec![1, 3, 4], 3, 0)
)),]) )),])
.to_string(), .to_string(),
"8 (1 + 3 + 4)" "8 (1 + 3 + 4)"
); );
assert_eq!( assert_eq!(
ElementExpressionRoll(vec![SignedElementRoll::Negative(ElementRoll::Dice( ElementExpressionRoll(vec![SignedElementRoll::Negative(ElementRoll::Dice(
DiceRoll(vec![1, 3, 4]) DiceRoll(vec![1, 3, 4], 3, 0)
)),]) )),])
.to_string(), .to_string(),
"-8 (1 + 3 + 4)" "-8 (1 + 3 + 4)"
@ -263,8 +300,8 @@ mod tests {
); );
assert_eq!( assert_eq!(
ElementExpressionRoll(vec![ ElementExpressionRoll(vec![
SignedElementRoll::Positive(ElementRoll::Dice(DiceRoll(vec![1, 3, 4]))), SignedElementRoll::Positive(ElementRoll::Dice(DiceRoll(vec![1, 3, 4], 3, 0))),
SignedElementRoll::Negative(ElementRoll::Dice(DiceRoll(vec![1, 2]))), SignedElementRoll::Negative(ElementRoll::Dice(DiceRoll(vec![1, 2], 2, 0))),
SignedElementRoll::Positive(ElementRoll::Bonus(4)), SignedElementRoll::Positive(ElementRoll::Bonus(4)),
SignedElementRoll::Negative(ElementRoll::Bonus(7)), SignedElementRoll::Negative(ElementRoll::Bonus(7)),
]) ])
@ -273,13 +310,33 @@ mod tests {
); );
assert_eq!( assert_eq!(
ElementExpressionRoll(vec![ ElementExpressionRoll(vec![
SignedElementRoll::Negative(ElementRoll::Dice(DiceRoll(vec![1, 3, 4]))), SignedElementRoll::Negative(ElementRoll::Dice(DiceRoll(vec![1, 3, 4], 3, 0))),
SignedElementRoll::Positive(ElementRoll::Dice(DiceRoll(vec![1, 2]))), SignedElementRoll::Positive(ElementRoll::Dice(DiceRoll(vec![1, 2], 2, 0))),
SignedElementRoll::Negative(ElementRoll::Bonus(4)), SignedElementRoll::Negative(ElementRoll::Bonus(4)),
SignedElementRoll::Positive(ElementRoll::Bonus(7)), SignedElementRoll::Positive(ElementRoll::Bonus(7)),
]) ])
.to_string(), .to_string(),
"-2 (-8 (1 + 3 + 4) + 3 (1 + 2) - 4 + 7)" "-2 (-8 (1 + 3 + 4) + 3 (1 + 2) - 4 + 7)"
); );
assert_eq!(
ElementExpressionRoll(vec![
SignedElementRoll::Negative(ElementRoll::Dice(DiceRoll(vec![4, 3, 1], 3, 0))),
SignedElementRoll::Positive(ElementRoll::Dice(DiceRoll(vec![12, 2], 1, 0))),
SignedElementRoll::Negative(ElementRoll::Bonus(4)),
SignedElementRoll::Positive(ElementRoll::Bonus(7)),
])
.to_string(),
"7 (-8 (4 + 3 + 1) + 12 (12 + [2]) - 4 + 7)"
);
assert_eq!(
ElementExpressionRoll(vec![
SignedElementRoll::Negative(ElementRoll::Dice(DiceRoll(vec![4, 3, 1], 3, 1))),
SignedElementRoll::Positive(ElementRoll::Dice(DiceRoll(vec![12, 2], 2, 0))),
SignedElementRoll::Negative(ElementRoll::Bonus(4)),
SignedElementRoll::Positive(ElementRoll::Bonus(7)),
])
.to_string(),
"13 (-4 ([4] + 3 + 1) + 14 (12 + 2) - 4 + 7)"
);
} }
} }

View File

@ -1,4 +1,5 @@
use matrix_sdk::identifiers::room_id; use matrix_sdk::ruma::room_id;
use matrix_sdk::Client;
use tenebrous_dicebot::commands; use tenebrous_dicebot::commands;
use tenebrous_dicebot::commands::ResponseExtractor; use tenebrous_dicebot::commands::ResponseExtractor;
use tenebrous_dicebot::context::{Context, RoomContext}; use tenebrous_dicebot::context::{Context, RoomContext};
@ -26,11 +27,15 @@ async fn main() -> Result<(), BotError> {
.await?; .await?;
let context = Context { let context = Context {
db: db, db,
account: Account::default(), account: Account::default(),
matrix_client: &matrix_sdk::Client::new(homeserver) matrix_client: Client::new(homeserver).await.expect("Could not create matrix client"),
.expect("Could not create matrix client"), origin_room: RoomContext {
room: RoomContext { id: &room_id!("!fakeroomid:example.com"),
display_name: "fake room".to_owned(),
secure: false,
},
active_room: RoomContext {
id: &room_id!("!fakeroomid:example.com"), id: &room_id!("!fakeroomid:example.com"),
display_name: "fake room".to_owned(), display_name: "fake room".to_owned(),
secure: false, secure: false,

View File

@ -1,21 +1,36 @@
//Needed for nested Result handling from tokio. Probably can go away after 1.47.0. //Needed for nested Result handling from tokio. Probably can go away after 1.47.0.
#![type_length_limit = "7605144"] #![type_length_limit = "7605144"]
use futures::try_join;
use log::error; use log::error;
use matrix_sdk::Client;
use std::env; use std::env;
use std::sync::{Arc, RwLock}; use std::sync::{Arc, RwLock};
use tenebrous_dicebot::bot::DiceBot; use tenebrous_dicebot::bot::DiceBot;
use tenebrous_dicebot::config::*; use tenebrous_dicebot::config::*;
use tenebrous_dicebot::db::sqlite::Database; use tenebrous_dicebot::db::sqlite::Database;
use tenebrous_dicebot::error::BotError; use tenebrous_dicebot::error::BotError;
use tenebrous_dicebot::rpc;
use tenebrous_dicebot::state::DiceBotState; use tenebrous_dicebot::state::DiceBotState;
use tracing_subscriber::filter::EnvFilter; use tracing_subscriber::filter::EnvFilter;
/// Attempt to create config object and ddatabase connection pool from
/// the given config path. An error is returned if config creation or
/// database pool creation fails for some reason.
async fn init(config_path: &str) -> Result<(Arc<Config>, Database, Client), BotError> {
let cfg = read_config(config_path)?;
let cfg = Arc::new(cfg);
let sqlite_path = format!("{}/dicebot.sqlite", cfg.database_path());
let db = Database::new(&sqlite_path).await?;
let client = tenebrous_dicebot::matrix::create_client(&cfg).await?;
Ok((cfg, db, client))
}
#[tokio::main] #[tokio::main]
async fn main() { async fn main() -> Result<(), BotError> {
let filter = if env::var("RUST_LOG").is_ok() { let filter = if env::var("RUST_LOG").is_ok() {
EnvFilter::from_default_env() EnvFilter::from_default_env()
} else { } else {
EnvFilter::new("tenebrous_dicebot=info,dicebot=info,refinery=info") EnvFilter::new("tonic=info,tenebrous_dicebot=info,dicebot=info,refinery=info")
}; };
tracing_subscriber::fmt().with_env_filter(filter).init(); tracing_subscriber::fmt().with_env_filter(filter).init();
@ -23,7 +38,9 @@ async fn main() {
match run().await { match run().await {
Ok(_) => (), Ok(_) => (),
Err(e) => error!("Error: {}", e), Err(e) => error!("Error: {}", e),
}; }
Ok(())
} }
async fn run() -> Result<(), BotError> { async fn run() -> Result<(), BotError> {
@ -32,12 +49,22 @@ async fn run() -> Result<(), BotError> {
.next() .next()
.expect("Need a config as an argument"); .expect("Need a config as an argument");
let cfg = Arc::new(read_config(config_path)?); let (cfg, db, client) = init(&config_path).await?;
let sqlite_path = format!("{}/dicebot.sqlite", cfg.database_path()); let grpc = rpc::serve_grpc(&cfg, &db, &client);
let db = Database::new(&sqlite_path).await?; let bot = run_bot(&cfg, &db, &client);
match try_join!(bot, grpc) {
Ok(_) => (),
Err(e) => error!("Error: {:?}", e),
};
Ok(())
}
async fn run_bot(cfg: &Arc<Config>, db: &Database, client: &Client) -> Result<(), BotError> {
let state = Arc::new(RwLock::new(DiceBotState::new(&cfg))); let state = Arc::new(RwLock::new(DiceBotState::new(&cfg)));
match DiceBot::new(&cfg, &state, &db) { match DiceBot::new(cfg, &state, db, client) {
Ok(bot) => bot.run().await?, Ok(bot) => bot.run().await?,
Err(e) => println!("Error connecting: {:?}", e), Err(e) => println!("Error connecting: {:?}", e),
}; };

View File

@ -0,0 +1,33 @@
use tenebrous_rpc::protos::dicebot::UserIdRequest;
use tenebrous_rpc::protos::dicebot::{dicebot_client::DicebotClient};
use tonic::{metadata::MetadataValue, transport::Channel, Request};
async fn create_client(
shared_secret: &str,
) -> Result<DicebotClient<Channel>, Box<dyn std::error::Error>> {
let channel = Channel::from_static("http://0.0.0.0:9090")
.connect()
.await?;
let bearer = MetadataValue::from_str(&format!("Bearer {}", shared_secret))?;
let client = DicebotClient::with_interceptor(channel, move |mut req: Request<()>| {
req.metadata_mut().insert("authorization", bearer.clone());
Ok(req)
});
Ok(client)
}
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
let mut client = create_client("example-key").await?;
let request = tonic::Request::new(UserIdRequest {
user_id: "@projectmoon:agnos.is".into(),
});
let response = client.rooms_for_user(request).await?.into_inner();
println!("Rooms: {:?}", response.rooms);
Ok(())
}

View File

@ -1,12 +1,17 @@
use crate::commands::{execute_command, ExecutionResult, ResponseExtractor};
use crate::context::{Context, RoomContext}; use crate::context::{Context, RoomContext};
use crate::db::sqlite::Database; use crate::db::sqlite::Database;
use crate::error::BotError; use crate::error::BotError;
use crate::logic; use crate::logic;
use crate::matrix; use crate::matrix;
use crate::{
commands::{execute_command, ExecutionResult, ResponseExtractor},
models::Account,
};
use futures::stream::{self, StreamExt}; use futures::stream::{self, StreamExt};
use matrix_sdk::{self, identifiers::EventId, room::Joined, Client}; use matrix_sdk::ruma::{OwnedEventId, RoomId};
use matrix_sdk::{self, room::Joined, Client};
use std::clone::Clone; use std::clone::Clone;
use std::convert::TryFrom;
/// Handle responding to a single command being executed. Wil print /// Handle responding to a single command being executed. Wil print
/// out the full result of that command. /// out the full result of that command.
@ -15,7 +20,7 @@ pub(super) async fn handle_single_result(
cmd_result: &ExecutionResult, cmd_result: &ExecutionResult,
respond_to: &str, respond_to: &str,
room: &Joined, room: &Joined,
event_id: EventId, event_id: OwnedEventId,
) { ) {
let html = cmd_result.message_html(respond_to); let html = cmd_result.message_html(respond_to);
let plain = cmd_result.message_plain(respond_to); let plain = cmd_result.message_plain(respond_to);
@ -95,24 +100,57 @@ pub(super) async fn handle_multiple_results(
matrix::send_message(client, room.room_id(), (&message, &plain), None).await; matrix::send_message(client, room.room_id(), (&message, &plain), None).await;
} }
/// Create a context for command execution. Can fai if the room /// Map an account's active room value to an actual matrix room, if
/// context creation fails. /// the account has an active room. This only retrieves the
async fn create_context<'a>( /// user-specified active room, and doesn't perform any further
db: &'a Database, /// filtering.
client: &'a Client, fn get_account_active_room(client: &Client, account: &Account) -> Result<Option<Joined>, BotError> {
room: &'a Joined, let active_room = account
sender: &'a str, .registered_user()
command: &'a str, .and_then(|u| u.active_room.as_deref())
) -> Result<Context<'a>, BotError> { .map(|room_id| <&RoomId>::try_from(room_id))
let room_ctx = RoomContext::new(room, sender).await?; .transpose()?
Ok(Context { .and_then(|active_room_id| client.get_joined_room(active_room_id));
Ok(active_room)
}
/// Execute a single command in the list of commands. Can fail if the
/// Account value cannot be created/fetched from the database, or if
/// room display names cannot be calculated. Otherwise, the success or
/// error of command execution itself is returned.
async fn execute_single_command(
command: &str,
db: &Database,
client: &Client,
origin_room: &Joined,
sender: &str,
) -> ExecutionResult {
let origin_ctx = RoomContext::new(origin_room, sender).await?;
let account = logic::get_account(db, sender).await?;
let active_room = get_account_active_room(client, &account)?;
// Active room is used in secure command-issuing rooms. In
// "public" rooms, where other users are, treat origin as the
// active room.
let active_room = active_room
.as_ref()
.filter(|_| origin_ctx.secure)
.unwrap_or(origin_room);
let active_ctx = RoomContext::new(active_room, sender).await?;
let ctx = Context {
account,
db: db.clone(), db: db.clone(),
matrix_client: client, matrix_client: client.clone(),
room: room_ctx, origin_room: origin_ctx,
username: &sender, username: &sender,
account: logic::get_account(db, &sender).await?, active_room: active_ctx,
message_body: &command, message_body: &command,
}) };
execute_command(&ctx).await
} }
/// Attempt to execute all commands sent to the bot in a message. This /// Attempt to execute all commands sent to the bot in a message. This
@ -127,13 +165,8 @@ pub(super) async fn execute(
) -> Vec<(String, ExecutionResult)> { ) -> Vec<(String, ExecutionResult)> {
stream::iter(commands) stream::iter(commands)
.then(|command| async move { .then(|command| async move {
match create_context(db, client, room, sender, command).await { let result = execute_single_command(command, db, client, room, sender).await;
Err(e) => (command.to_owned(), Err(e)), (command.to_owned(), result)
Ok(ctx) => {
let cmd_result = execute_command(&ctx).await;
(command.to_owned(), cmd_result)
}
}
}) })
.collect() .collect()
.await .await

View File

@ -0,0 +1,163 @@
use super::DiceBot;
use crate::db::sqlite::Database;
use crate::db::Rooms;
use crate::error::BotError;
use log::{debug, error, info, warn};
use matrix_sdk::ruma::events::room::member::RoomMemberEventContent;
use matrix_sdk::ruma::events::{StrippedStateEvent, SyncMessageLikeEvent};
use matrix_sdk::{self, room::Room, ruma::events::room::message::RoomMessageEventContent};
use matrix_sdk::{Client, DisplayName};
use std::ops::Sub;
use std::time::UNIX_EPOCH;
use std::time::{Duration, SystemTime};
/// Check if a message is recent enough to actually process. If the
/// message is within "oldest_message_age" seconds, this function
/// returns true. If it's older than that, it returns false and logs a
/// debug message.
fn check_message_age(
event: &SyncMessageLikeEvent<RoomMessageEventContent>,
oldest_message_age: u64,
) -> bool {
let sending_time = event
.origin_server_ts()
.to_system_time()
.unwrap_or(UNIX_EPOCH);
let oldest_timestamp = SystemTime::now().sub(Duration::from_secs(oldest_message_age));
if sending_time > oldest_timestamp {
true
} else {
let age = match oldest_timestamp.duration_since(sending_time) {
Ok(n) => format!("{} seconds too old", n.as_secs()),
Err(_) => "before the UNIX epoch".to_owned(),
};
debug!("Ignoring message because it is {}: {:?}", age, event);
false
}
}
/// Determine whether or not to process a received message. This check
/// is necessary in addition to the event processing check because we
/// may receive message events when entering a room for the first
/// time, and we don't want to respond to things before the bot was in
/// the channel, but we do want to respond to things that were sent if
/// the bot left and rejoined quickly.
async fn should_process_message<'a>(
bot: &DiceBot,
event: &SyncMessageLikeEvent<RoomMessageEventContent>,
) -> Result<(String, String), BotError> {
//Ignore messages that are older than configured duration.
if !check_message_age(event, bot.config.oldest_message_age()) {
let state_check = bot.state.read().unwrap();
if !((*state_check).logged_skipped_old_messages()) {
drop(state_check);
let mut state = bot.state.write().unwrap();
(*state).skipped_old_messages();
}
return Err(BotError::ShouldNotProcessError);
}
let msg_body: String = event
.as_original()
.map(|e| e.content.body())
.map(str::to_string)
.unwrap_or_else(|| String::new());
let sender_username: String = format!(
"@{}:{}",
event.sender().localpart(),
event.sender().server_name()
);
// Do not process messages from the bot itself. Otherwise it might
// try to execute its own commands.
let bot_username = bot
.client
.user_id()
.map(|u| format!("@{}:{}", u.localpart(), u.server_name()))
.unwrap_or_default();
if sender_username == bot_username {
return Err(BotError::ShouldNotProcessError);
}
Ok((msg_body, sender_username))
}
async fn should_process_event(db: &Database, room_id: &str, event_id: &str) -> bool {
db.should_process(room_id, event_id)
.await
.unwrap_or_else(|e| {
error!(
"Database error when checking if we should process an event: {}",
e.to_string()
);
false
})
}
pub(super) async fn on_stripped_state_member(
event: StrippedStateEvent<RoomMemberEventContent>,
client: Client,
room: Room,
) {
let room = match room {
Room::Invited(invited_room) => invited_room,
_ => return,
};
if room.own_user_id().as_str() != event.state_key {
return;
}
info!(
"Autojoining room {}",
room.display_name()
.await
.ok()
.unwrap_or_else(|| DisplayName::Named("[error]".to_string()))
);
if let Err(e) = client.join_room_by_id(&room.room_id()).await {
warn!("Could not join room: {}", e.to_string())
}
}
pub(super) async fn on_room_message(
event: SyncMessageLikeEvent<RoomMessageEventContent>,
room: Room,
bot: DiceBot,
) {
let room = match room {
Room::Joined(joined_room) => joined_room,
_ => return,
};
let room_id = room.room_id().as_str();
if !should_process_event(&bot.db, room_id, event.event_id().as_str()).await {
return;
}
let (msg_body, sender_username) =
if let Ok((msg_body, sender_username)) = should_process_message(&bot, &event).await {
(msg_body, sender_username)
} else {
return;
};
let results = bot
.execute_commands(&room, &sender_username, &msg_body)
.await;
bot.handle_results(
&room,
&sender_username,
event.event_id().to_owned(),
results,
)
.await;
}

View File

@ -4,13 +4,15 @@ use crate::db::sqlite::Database;
use crate::db::DbState; use crate::db::DbState;
use crate::error::BotError; use crate::error::BotError;
use crate::state::DiceBotState; use crate::state::DiceBotState;
use dirs;
use log::info; use log::info;
use matrix_sdk::{self, identifiers::EventId, room::Joined, Client, ClientConfig, SyncSettings}; use matrix_sdk::room::Room;
use matrix_sdk::ruma::events::room::message::RoomMessageEventContent;
use matrix_sdk::ruma::events::SyncMessageLikeEvent;
use matrix_sdk::ruma::OwnedEventId;
use matrix_sdk::{self, room::Joined, Client};
use matrix_sdk::config::SyncSettings;
use std::clone::Clone; use std::clone::Clone;
use std::path::PathBuf;
use std::sync::{Arc, RwLock}; use std::sync::{Arc, RwLock};
use url::Url;
mod command_execution; mod command_execution;
pub mod event_handlers; pub mod event_handlers;
@ -21,6 +23,7 @@ const MAX_COMMANDS_PER_MESSAGE: usize = 50;
/// The DiceBot struct represents an active dice bot. The bot is not /// The DiceBot struct represents an active dice bot. The bot is not
/// connected to Matrix until its run() function is called. /// connected to Matrix until its run() function is called.
#[derive(Clone)]
pub struct DiceBot { pub struct DiceBot {
/// A reference to the configuration read in on application start. /// A reference to the configuration read in on application start.
config: Arc<Config>, config: Arc<Config>,
@ -35,22 +38,6 @@ pub struct DiceBot {
db: Database, db: Database,
} }
fn cache_dir() -> Result<PathBuf, BotError> {
let mut dir = dirs::cache_dir().ok_or(BotError::NoCacheDirectoryError)?;
dir.push("matrix-dicebot");
Ok(dir)
}
/// Creates the matrix client.
fn create_client(config: &Config) -> Result<Client, BotError> {
let cache_dir = cache_dir()?;
//let store = JsonStore::open(&cache_dir)?;
let client_config = ClientConfig::new().store_path(cache_dir);
let homeserver_url = Url::parse(&config.matrix_homeserver())?;
Ok(Client::new_with_config(homeserver_url, client_config)?)
}
impl DiceBot { impl DiceBot {
/// Create a new dicebot with the given configuration and state /// Create a new dicebot with the given configuration and state
/// actor. This function returns a Result because it is possible /// actor. This function returns a Result because it is possible
@ -60,9 +47,10 @@ impl DiceBot {
config: &Arc<Config>, config: &Arc<Config>,
state: &Arc<RwLock<DiceBotState>>, state: &Arc<RwLock<DiceBotState>>,
db: &Database, db: &Database,
client: &Client,
) -> Result<Self, BotError> { ) -> Result<Self, BotError> {
Ok(DiceBot { Ok(DiceBot {
client: create_client(&config)?, client: client.clone(),
config: config.clone(), config: config.clone(),
state: state.clone(), state: state.clone(),
db: db.clone(), db: db.clone(),
@ -81,12 +69,14 @@ impl DiceBot {
let device_id: Option<String> = self.db.get_device_id().await?; let device_id: Option<String> = self.db.get_device_id().await?;
let device_id: Option<&str> = device_id.as_deref(); let device_id: Option<&str> = device_id.as_deref();
client let no_device_ld_login = || client.login_username(username, password);
.login(username, password, device_id, Some("matrix dice bot")) let device_id_login = |id| client.login_username(username, password).device_id(id);
.await?; let login = device_id.map_or_else(no_device_ld_login, device_id_login);
login.send().await?;
if device_id.is_none() { if device_id.is_none() {
let device_id = client.device_id().await.ok_or(BotError::NoDeviceIdFound)?; let device_id = client.device_id().ok_or(BotError::NoDeviceIdFound)?;
self.db.set_device_id(device_id.as_str()).await?; self.db.set_device_id(device_id.as_str()).await?;
info!("Recorded new device ID: {}", device_id.as_str()); info!("Recorded new device ID: {}", device_id.as_str());
} else { } else {
@ -97,19 +87,35 @@ impl DiceBot {
Ok(()) Ok(())
} }
async fn bind_events(&self) {
//on room message: need closure to pass bot ref in.
self.client
.add_event_handler({
let bot: DiceBot = self.clone();
move |event: SyncMessageLikeEvent<RoomMessageEventContent>, room: Room| {
let bot = bot.clone();
async move { event_handlers::on_room_message(event, room, bot).await }
}
});
//auto-join handler
self.client
.add_event_handler(event_handlers::on_stripped_state_member);
}
/// Logs the bot in to Matrix and listens for events until program /// Logs the bot in to Matrix and listens for events until program
/// terminated, or a panic occurs. Originally adapted from the /// terminated, or a panic occurs. Originally adapted from the
/// matrix-rust-sdk command bot example. /// matrix-rust-sdk command bot example.
pub async fn run(self) -> Result<(), BotError> { pub async fn run(self) -> Result<(), BotError> {
let client = self.client.clone(); let client = self.client.clone();
self.login(&client).await?; self.login(&client).await?;
self.bind_events().await;
client.set_event_handler(Box::new(self)).await;
info!("Listening for commands"); info!("Listening for commands");
// TODO replace with sync_with_callback for cleaner shutdown // TODO replace with sync_with_callback for cleaner shutdown
// process. // process.
client.sync(SyncSettings::default()).await; client.sync(SyncSettings::default()).await?;
Ok(()) Ok(())
} }
@ -139,7 +145,7 @@ impl DiceBot {
&self, &self,
room: &Joined, room: &Joined,
sender_username: &str, sender_username: &str,
event_id: EventId, event_id: OwnedEventId,
results: Vec<(String, ExecutionResult)>, results: Vec<(String, ExecutionResult)>,
) { ) {
if results.len() >= 1 { if results.len() >= 1 {

View File

@ -332,7 +332,7 @@ mod tests {
macro_rules! dummy_room { macro_rules! dummy_room {
() => { () => {
crate::context::RoomContext { crate::context::RoomContext {
id: &matrix_sdk::identifiers::room_id!("!fakeroomid:example.com"), id: &matrix_sdk::ruma::room_id!("!fakeroomid:example.com"),
display_name: "displayname".to_owned(), display_name: "displayname".to_owned(),
secure: false, secure: false,
} }
@ -485,8 +485,9 @@ mod tests {
let ctx = Context { let ctx = Context {
account: crate::models::Account::default(), account: crate::models::Account::default(),
db: db, db: db,
matrix_client: &matrix_sdk::Client::new(homeserver).unwrap(), matrix_client: matrix_sdk::Client::new(homeserver).await.unwrap(),
room: dummy_room!(), origin_room: dummy_room!(),
active_room: dummy_room!(),
username: "username", username: "username",
message_body: "message", message_body: "message",
}; };
@ -526,8 +527,9 @@ mod tests {
let ctx = Context { let ctx = Context {
account: crate::models::Account::default(), account: crate::models::Account::default(),
db: db, db: db,
matrix_client: &matrix_sdk::Client::new(homeserver).unwrap(), matrix_client: matrix_sdk::Client::new(homeserver).await.unwrap(),
room: dummy_room!(), origin_room: dummy_room!(),
active_room: dummy_room!(),
username: "username", username: "username",
message_body: "message", message_body: "message",
}; };
@ -564,15 +566,21 @@ mod tests {
let ctx = Context { let ctx = Context {
account: crate::models::Account::default(), account: crate::models::Account::default(),
db: db.clone(), db: db.clone(),
matrix_client: &matrix_sdk::Client::new(homeserver).unwrap(), matrix_client: matrix_sdk::Client::new(homeserver).await.unwrap(),
room: dummy_room!(), origin_room: dummy_room!(),
active_room: dummy_room!(),
username: "username", username: "username",
message_body: "message", message_body: "message",
}; };
db.set_user_variable(&ctx.username, &ctx.room.id.as_str(), "myvariable", 10) db.set_user_variable(
.await &ctx.username,
.expect("could not set myvariable to 10"); &ctx.origin_room.id.as_str(),
"myvariable",
10,
)
.await
.expect("could not set myvariable to 10");
let amounts = vec![Amount { let amounts = vec![Amount {
operator: Operator::Plus, operator: Operator::Plus,

View File

@ -45,13 +45,13 @@ pub fn parse_modifiers(input: &str) -> Result<DicePoolModifiers, DiceParsingErro
let (result, rest) = parser.parse(input)?; let (result, rest) = parser.parse(input)?;
if rest.len() == 0 { if rest.len() == 0 {
convert_to_info(&result) convert_to_modifiers(&result)
} else { } else {
Err(DiceParsingError::UnconsumedInput) Err(DiceParsingError::UnconsumedInput)
} }
} }
fn convert_to_info(parsed: &Vec<ParsedInfo>) -> Result<DicePoolModifiers, DiceParsingError> { fn convert_to_modifiers(parsed: &Vec<ParsedInfo>) -> Result<DicePoolModifiers, DiceParsingError> {
use ParsedInfo::*; use ParsedInfo::*;
if parsed.len() == 0 { if parsed.len() == 0 {
Ok(DicePoolModifiers::default()) Ok(DicePoolModifiers::default())
@ -79,19 +79,8 @@ fn convert_to_info(parsed: &Vec<ParsedInfo>) -> Result<DicePoolModifiers, DicePa
} }
pub fn parse_dice_pool(input: &str) -> Result<DicePool, BotError> { pub fn parse_dice_pool(input: &str) -> Result<DicePool, BotError> {
//The "modifiers:" part is optional. Assume amounts if no modifier let (amounts, modifiers_str) = parse_amounts(input)?;
//section found.
let split = input.split(":").collect::<Vec<_>>();
let (modifiers_str, amounts_str) = (match split[..] {
[amounts] => Ok(("", amounts)),
[modifiers, amounts] => Ok((modifiers, amounts)),
_ => Err(BotError::DiceParsingError(
DiceParsingError::UnconsumedInput,
)),
})?;
let modifiers = parse_modifiers(modifiers_str)?; let modifiers = parse_modifiers(modifiers_str)?;
let amounts = parse_amounts(&amounts_str)?;
Ok(DicePool::new(amounts, modifiers)) Ok(DicePool::new(amounts, modifiers))
} }
@ -175,7 +164,7 @@ mod tests {
#[test] #[test]
fn dice_pool_number_with_quality() { fn dice_pool_number_with_quality() {
let result = parse_dice_pool("n:8"); let result = parse_dice_pool("8 n");
assert!(result.is_ok()); assert!(result.is_ok());
assert_eq!( assert_eq!(
result.unwrap(), result.unwrap(),
@ -186,7 +175,7 @@ mod tests {
#[test] #[test]
fn dice_pool_number_with_success_change() { fn dice_pool_number_with_success_change() {
let modifiers = DicePoolModifiers::custom_exceptional_on(3); let modifiers = DicePoolModifiers::custom_exceptional_on(3);
let result = parse_dice_pool("s3:8"); let result = parse_dice_pool("8 s3");
assert!(result.is_ok()); assert!(result.is_ok());
assert_eq!(result.unwrap(), DicePool::easy_with_modifiers(8, modifiers)); assert_eq!(result.unwrap(), DicePool::easy_with_modifiers(8, modifiers));
} }
@ -194,7 +183,7 @@ mod tests {
#[test] #[test]
fn dice_pool_with_quality_and_success_change() { fn dice_pool_with_quality_and_success_change() {
let modifiers = DicePoolModifiers::custom(DicePoolQuality::Rote, 3); let modifiers = DicePoolModifiers::custom(DicePoolQuality::Rote, 3);
let result = parse_dice_pool("rs3:8"); let result = parse_dice_pool("8 rs3");
assert!(result.is_ok()); assert!(result.is_ok());
assert_eq!(result.unwrap(), DicePool::easy_with_modifiers(8, modifiers)); assert_eq!(result.unwrap(), DicePool::easy_with_modifiers(8, modifiers));
} }
@ -224,20 +213,20 @@ mod tests {
let expected = DicePool::new(amounts, modifiers); let expected = DicePool::new(amounts, modifiers);
let result = parse_dice_pool("rs3:8+10-2+varname"); let result = parse_dice_pool("8+10-2+varname rs3");
assert!(result.is_ok()); assert!(result.is_ok());
assert_eq!(result.unwrap(), expected); assert_eq!(result.unwrap(), expected);
let result = parse_dice_pool("rs3:8+10- 2 + varname"); let result = parse_dice_pool("8+10- 2 + varname rs3");
assert!(result.is_ok()); assert!(result.is_ok());
assert_eq!(result.unwrap(), expected); assert_eq!(result.unwrap(), expected);
let result = parse_dice_pool("rs3 : 8+ 10 -2 + varname"); let result = parse_dice_pool("8+ 10 -2 + varname rs3");
assert!(result.is_ok()); assert!(result.is_ok());
assert_eq!(result.unwrap(), expected); assert_eq!(result.unwrap(), expected);
//This one has tabs in it. //This one has tabs in it.
let result = parse_dice_pool(" r s3 : 8 + 10 -2 + varname"); let result = parse_dice_pool(" 8 + 10 -2 + varname r s3");
assert!(result.is_ok()); assert!(result.is_ok());
assert_eq!(result.unwrap(), expected); assert_eq!(result.unwrap(), expected);
} }

View File

@ -10,12 +10,6 @@ use std::convert::TryFrom;
pub struct RollCommand(pub ElementExpression); pub struct RollCommand(pub ElementExpression);
impl From<RollCommand> for Box<dyn Command> {
fn from(cmd: RollCommand) -> Self {
Box::new(cmd)
}
}
impl TryFrom<String> for RollCommand { impl TryFrom<String> for RollCommand {
type Error = BotError; type Error = BotError;

View File

@ -15,12 +15,6 @@ impl PoolRollCommand {
} }
} }
impl From<PoolRollCommand> for Box<dyn Command> {
fn from(cmd: PoolRollCommand) -> Self {
Box::new(cmd)
}
}
impl TryFrom<String> for PoolRollCommand { impl TryFrom<String> for PoolRollCommand {
type Error = BotError; type Error = BotError;

View File

@ -11,12 +11,6 @@ use std::convert::TryFrom;
pub struct CthRoll(pub DiceRoll); pub struct CthRoll(pub DiceRoll);
impl From<CthRoll> for Box<dyn Command> {
fn from(cmd: CthRoll) -> Self {
Box::new(cmd)
}
}
impl TryFrom<String> for CthRoll { impl TryFrom<String> for CthRoll {
type Error = BotError; type Error = BotError;
@ -51,12 +45,6 @@ impl Command for CthRoll {
pub struct CthAdvanceRoll(pub AdvancementRoll); pub struct CthAdvanceRoll(pub AdvancementRoll);
impl From<CthAdvanceRoll> for Box<dyn Command> {
fn from(cmd: CthAdvanceRoll) -> Self {
Box::new(cmd)
}
}
impl TryFrom<String> for CthAdvanceRoll { impl TryFrom<String> for CthAdvanceRoll {
type Error = BotError; type Error = BotError;

View File

@ -9,12 +9,6 @@ use std::convert::{Into, TryFrom};
pub struct RegisterCommand; pub struct RegisterCommand;
impl From<RegisterCommand> for Box<dyn Command> {
fn from(cmd: RegisterCommand) -> Self {
Box::new(cmd)
}
}
impl TryFrom<String> for RegisterCommand { impl TryFrom<String> for RegisterCommand {
type Error = BotError; type Error = BotError;
@ -56,12 +50,6 @@ impl Command for RegisterCommand {
pub struct UnlinkCommand(pub String); pub struct UnlinkCommand(pub String);
impl From<UnlinkCommand> for Box<dyn Command> {
fn from(cmd: UnlinkCommand) -> Self {
Box::new(cmd)
}
}
impl TryFrom<String> for UnlinkCommand { impl TryFrom<String> for UnlinkCommand {
type Error = BotError; type Error = BotError;
@ -99,12 +87,6 @@ impl Command for UnlinkCommand {
pub struct LinkCommand(pub String); pub struct LinkCommand(pub String);
impl From<LinkCommand> for Box<dyn Command> {
fn from(cmd: LinkCommand) -> Self {
Box::new(cmd)
}
}
impl TryFrom<String> for LinkCommand { impl TryFrom<String> for LinkCommand {
type Error = BotError; type Error = BotError;
@ -144,12 +126,6 @@ impl Command for LinkCommand {
pub struct CheckCommand; pub struct CheckCommand;
impl From<CheckCommand> for Box<dyn Command> {
fn from(cmd: CheckCommand) -> Self {
Box::new(cmd)
}
}
impl TryFrom<String> for CheckCommand { impl TryFrom<String> for CheckCommand {
type Error = BotError; type Error = BotError;
@ -174,14 +150,17 @@ impl Command for CheckCommand {
match user { match user {
Some(user) => match user.password { Some(user) => match user.password {
Some(_) => Execution::success( Some(_) => Execution::success(
"Account exists, and is available to external applications with a password. If you forgot your password, change it with !link.".to_string(), "Account exists, and is available to external applications with a password. \
If you forgot your password, change it with !link."
.to_string(),
), ),
None => Execution::success( None => Execution::success(
"Account exists, but is not available to external applications.".to_string(), "Account exists, but is not available to external applications.".to_string(),
), ),
}, },
None => Execution::success( None => Execution::success(
"No account registered. Only simple commands in public rooms are available.".to_string(), "No account registered. Only simple commands in public rooms are available."
.to_string(),
), ),
} }
} }
@ -189,12 +168,6 @@ impl Command for CheckCommand {
pub struct UnregisterCommand; pub struct UnregisterCommand;
impl From<UnregisterCommand> for Box<dyn Command> {
fn from(cmd: UnregisterCommand) -> Self {
Box::new(cmd)
}
}
impl TryFrom<String> for UnregisterCommand { impl TryFrom<String> for UnregisterCommand {
type Error = BotError; type Error = BotError;

View File

@ -7,12 +7,6 @@ use std::convert::TryFrom;
pub struct HelpCommand(pub Option<HelpTopic>); pub struct HelpCommand(pub Option<HelpTopic>);
impl From<HelpCommand> for Box<dyn Command> {
fn from(cmd: HelpCommand) -> Self {
Box::new(cmd)
}
}
impl TryFrom<String> for HelpCommand { impl TryFrom<String> for HelpCommand {
type Error = BotError; type Error = BotError;

View File

@ -112,9 +112,8 @@ fn execution_allowed(cmd: &(impl Command + ?Sized), ctx: &Context<'_>) -> Result
} }
/// Attempt to execute a command, and return the content that should /// Attempt to execute a command, and return the content that should
/// go back to Matrix, if the command was executed (successfully or /// go back to Matrix, if the command was executed, whether or not the
/// not). If a command is determined to be ignored, this function will /// command was successful.
/// return None, signifying that we should not send a response.
pub async fn execute_command(ctx: &Context<'_>) -> ExecutionResult { pub async fn execute_command(ctx: &Context<'_>) -> ExecutionResult {
let cmd = parser::parse_command(&ctx.message_body)?; let cmd = parser::parse_command(&ctx.message_body)?;
@ -146,13 +145,13 @@ fn log_command(cmd: &(impl Command + ?Sized), ctx: &Context, result: &ExecutionR
Ok(_) => { Ok(_) => {
info!( info!(
"[{}] {} <{}{}> - success", "[{}] {} <{}{}> - success",
ctx.room.display_name, ctx.username, command, dots ctx.origin_room.display_name, ctx.username, command, dots
); );
} }
Err(e) => { Err(e) => {
error!( error!(
"[{}] {} <{}{}> - {}", "[{}] {} <{}{}> - {}",
ctx.room.display_name, ctx.username, command, dots, e ctx.origin_room.display_name, ctx.username, command, dots, e
); );
} }
}; };
@ -163,11 +162,12 @@ mod tests {
use super::*; use super::*;
use management::RegisterCommand; use management::RegisterCommand;
use url::Url; use url::Url;
use matrix_sdk::ruma::room_id;
macro_rules! dummy_room { macro_rules! dummy_room {
() => { () => {
crate::context::RoomContext { crate::context::RoomContext {
id: &matrix_sdk::identifiers::room_id!("!fakeroomid:example.com"), id: &room_id!("!fakeroomid:example.com"),
display_name: "displayname".to_owned(), display_name: "displayname".to_owned(),
secure: false, secure: false,
} }
@ -177,7 +177,7 @@ mod tests {
macro_rules! secure_room { macro_rules! secure_room {
() => { () => {
crate::context::RoomContext { crate::context::RoomContext {
id: &matrix_sdk::identifiers::room_id!("!fakeroomid:example.com"), id: &room_id!("!fakeroomid:example.com"),
display_name: "displayname".to_owned(), display_name: "displayname".to_owned(),
secure: true, secure: true,
} }
@ -196,8 +196,9 @@ mod tests {
let ctx = Context { let ctx = Context {
account: crate::models::Account::default(), account: crate::models::Account::default(),
db: db, db: db,
matrix_client: &matrix_sdk::Client::new(homeserver).unwrap(), matrix_client: matrix_sdk::Client::new(homeserver).await.unwrap(),
room: secure_room!(), origin_room: secure_room!(),
active_room: secure_room!(),
username: "myusername", username: "myusername",
message_body: "!notacommand", message_body: "!notacommand",
}; };
@ -218,8 +219,9 @@ mod tests {
let ctx = Context { let ctx = Context {
account: crate::models::Account::default(), account: crate::models::Account::default(),
db: db, db: db,
matrix_client: &matrix_sdk::Client::new(homeserver).unwrap(), matrix_client: matrix_sdk::Client::new(homeserver).await.unwrap(),
room: secure_room!(), origin_room: secure_room!(),
active_room: secure_room!(),
username: "myusername", username: "myusername",
message_body: "!notacommand", message_body: "!notacommand",
}; };
@ -240,8 +242,9 @@ mod tests {
let ctx = Context { let ctx = Context {
account: crate::models::Account::default(), account: crate::models::Account::default(),
db: db, db: db,
matrix_client: &matrix_sdk::Client::new(homeserver).unwrap(), matrix_client: matrix_sdk::Client::new(homeserver).await.unwrap(),
room: dummy_room!(), origin_room: dummy_room!(),
active_room: dummy_room!(),
username: "myusername", username: "myusername",
message_body: "!notacommand", message_body: "!notacommand",
}; };
@ -262,8 +265,9 @@ mod tests {
let ctx = Context { let ctx = Context {
account: crate::models::Account::default(), account: crate::models::Account::default(),
db: db, db: db,
matrix_client: &matrix_sdk::Client::new(homeserver).unwrap(), matrix_client: matrix_sdk::Client::new(homeserver).await.unwrap(),
room: dummy_room!(), origin_room: dummy_room!(),
active_room: dummy_room!(),
username: "myusername", username: "myusername",
message_body: "!notacommand", message_body: "!notacommand",
}; };
@ -284,8 +288,9 @@ mod tests {
let ctx = Context { let ctx = Context {
account: crate::models::Account::default(), account: crate::models::Account::default(),
db: db, db: db,
matrix_client: &matrix_sdk::Client::new(homeserver).unwrap(), matrix_client: matrix_sdk::Client::new(homeserver).await.unwrap(),
room: dummy_room!(), origin_room: dummy_room!(),
active_room: dummy_room!(),
username: "myusername", username: "myusername",
message_body: "!notacommand", message_body: "!notacommand",
}; };

View File

@ -65,7 +65,7 @@ fn split_command(input: &str) -> Result<(String, String), CommandParsingError> {
/// boilerplate. /// boilerplate.
macro_rules! convert_to { macro_rules! convert_to {
($type:ident, $input: expr) => { ($type:ident, $input: expr) => {
$type::try_from($input).map(Into::into) $type::try_from($input).map(|cmd| Box::new(cmd) as Box<dyn Command>)
}; };
} }
@ -81,7 +81,7 @@ pub fn parse_command(input: &str) -> Result<Box<dyn Command>, BotError> {
"del" => convert_to!(DeleteVariableCommand, cmd_input), "del" => convert_to!(DeleteVariableCommand, cmd_input),
"r" | "roll" => convert_to!(RollCommand, cmd_input), "r" | "roll" => convert_to!(RollCommand, cmd_input),
"rp" | "pool" => convert_to!(PoolRollCommand, cmd_input), "rp" | "pool" => convert_to!(PoolRollCommand, cmd_input),
"chance" => PoolRollCommand::chance_die().map(Into::into), "chance" => PoolRollCommand::chance_die().map(|cmd| Box::new(cmd) as Box<dyn Command>),
"cthroll" => convert_to!(CthRoll, cmd_input), "cthroll" => convert_to!(CthRoll, cmd_input),
"cthadv" | "ctharoll" => convert_to!(CthAdvanceRoll, cmd_input), "cthadv" | "ctharoll" => convert_to!(CthAdvanceRoll, cmd_input),
"help" => convert_to!(HelpCommand, cmd_input), "help" => convert_to!(HelpCommand, cmd_input),
@ -221,9 +221,9 @@ mod tests {
#[test] #[test]
fn pool_whitespace_test() { fn pool_whitespace_test() {
parse_command("!pool ns3:8 ").expect("was error"); parse_command("!pool 8 ns3 ").expect("was error");
parse_command(" !pool ns3:8").expect("was error"); parse_command(" !pool 8 ns3").expect("was error");
parse_command(" !pool ns3:8 ").expect("was error"); parse_command(" !pool 8 ns3 ").expect("was error");
} }
#[test] #[test]

View File

@ -1,11 +1,12 @@
use super::{Command, Execution, ExecutionResult}; use super::{Command, Execution, ExecutionResult};
use crate::context::Context; use crate::context::Context;
use crate::db::Users;
use crate::error::BotError; use crate::error::BotError;
use crate::matrix; use crate::matrix;
use async_trait::async_trait; use async_trait::async_trait;
use fuse_rust::{Fuse, FuseProperty, Fuseable}; use fuse_rust::{Fuse, FuseProperty, Fuseable};
use futures::stream::{self, StreamExt, TryStreamExt}; use futures::stream::{self, StreamExt, TryStreamExt};
use matrix_sdk::{identifiers::UserId, Client}; use matrix_sdk::{ruma::OwnedUserId, Client};
use std::convert::TryFrom; use std::convert::TryFrom;
/// Holds matrix room ID and display name as strings, for use with /// Holds matrix room ID and display name as strings, for use with
@ -20,17 +21,17 @@ struct RoomNameAndId {
/// searching room display names directly. /// searching room display names directly.
impl Fuseable for RoomNameAndId { impl Fuseable for RoomNameAndId {
fn properties(&self) -> Vec<FuseProperty> { fn properties(&self) -> Vec<FuseProperty> {
return vec![FuseProperty { vec![FuseProperty {
value: String::from("name"), value: String::from("name"),
weight: 1.0, weight: 1.0,
}]; }]
} }
fn lookup(&self, key: &str) -> Option<&str> { fn lookup(&self, key: &str) -> Option<&str> {
return match key { match key {
"name" => Some(&self.name), "name" => Some(&self.name),
_ => None, _ => None,
}; }
} }
} }
@ -61,29 +62,29 @@ async fn get_rooms_for_user(
client: &Client, client: &Client,
user_id: &str, user_id: &str,
) -> Result<Vec<RoomNameAndId>, BotError> { ) -> Result<Vec<RoomNameAndId>, BotError> {
let user_id = UserId::try_from(user_id)?; let user_id = OwnedUserId::try_from(user_id)?;
let rooms_for_user = matrix::get_rooms_for_user(client, &user_id).await?; let rooms_for_user = matrix::get_rooms_for_user(client, &user_id).await?;
let rooms_for_user: Vec<RoomNameAndId> = stream::iter(rooms_for_user) let mut rooms_for_user: Vec<RoomNameAndId> = stream::iter(rooms_for_user)
.filter_map(|room| async move { .filter_map(|room| async move {
Some(room.display_name().await.map(|room_name| RoomNameAndId { Some(room.display_name().await.map(|room_name| RoomNameAndId {
id: room.room_id().to_string(), id: room.room_id().to_string(),
name: room_name, name: room_name.to_string(),
})) }))
}) })
.try_collect() .try_collect()
.await?; .await?;
//Alphabetically descending, symbols first, ignore case.
let sort = |r1: &RoomNameAndId, r2: &RoomNameAndId| {
r1.name.to_lowercase().cmp(&r2.name.to_lowercase())
};
rooms_for_user.sort_by(sort);
Ok(rooms_for_user) Ok(rooms_for_user)
} }
pub struct ListRoomsCommand; pub struct ListRoomsCommand;
impl From<ListRoomsCommand> for Box<dyn Command> {
fn from(cmd: ListRoomsCommand) -> Self {
Box::new(cmd)
}
}
impl TryFrom<String> for ListRoomsCommand { impl TryFrom<String> for ListRoomsCommand {
type Error = BotError; type Error = BotError;
@ -103,7 +104,7 @@ impl Command for ListRoomsCommand {
} }
async fn execute(&self, ctx: &Context<'_>) -> ExecutionResult { async fn execute(&self, ctx: &Context<'_>) -> ExecutionResult {
let rooms_for_user: Vec<String> = get_rooms_for_user(ctx.matrix_client, ctx.username) let rooms_for_user: Vec<String> = get_rooms_for_user(&ctx.matrix_client, ctx.username)
.await .await
.map(|rooms| { .map(|rooms| {
rooms rooms
@ -119,12 +120,6 @@ impl Command for ListRoomsCommand {
pub struct SetRoomCommand(String); pub struct SetRoomCommand(String);
impl From<SetRoomCommand> for Box<dyn Command> {
fn from(cmd: SetRoomCommand) -> Self {
Box::new(cmd)
}
}
impl TryFrom<String> for SetRoomCommand { impl TryFrom<String> for SetRoomCommand {
type Error = BotError; type Error = BotError;
@ -144,10 +139,22 @@ impl Command for SetRoomCommand {
} }
async fn execute(&self, ctx: &Context<'_>) -> ExecutionResult { async fn execute(&self, ctx: &Context<'_>) -> ExecutionResult {
let rooms_for_user = get_rooms_for_user(ctx.matrix_client, ctx.username).await?; if !ctx.account.is_registered() {
return Err(BotError::AccountDoesNotExist);
}
let rooms_for_user = get_rooms_for_user(&ctx.matrix_client, ctx.username).await?;
let room = search_for_room(&rooms_for_user, &self.0); let room = search_for_room(&rooms_for_user, &self.0);
if let Some(room) = room { if let Some(room) = room {
let mut new_user = ctx
.account
.registered_user()
.cloned()
.ok_or(BotError::AccountDoesNotExist)?;
new_user.active_room = Some(room.id.clone());
ctx.db.upsert_user(&new_user).await?;
Execution::success(format!(r#"Active room set to "{}""#, room.name)) Execution::success(format!(r#"Active room set to "{}""#, room.name))
} else { } else {
Err(BotError::RoomDoesNotExist) Err(BotError::RoomDoesNotExist)

View File

@ -8,12 +8,6 @@ use std::convert::TryFrom;
pub struct GetAllVariablesCommand; pub struct GetAllVariablesCommand;
impl From<GetAllVariablesCommand> for Box<dyn Command> {
fn from(cmd: GetAllVariablesCommand) -> Self {
Box::new(cmd)
}
}
impl TryFrom<String> for GetAllVariablesCommand { impl TryFrom<String> for GetAllVariablesCommand {
type Error = BotError; type Error = BotError;
@ -35,7 +29,7 @@ impl Command for GetAllVariablesCommand {
async fn execute(&self, ctx: &Context<'_>) -> ExecutionResult { async fn execute(&self, ctx: &Context<'_>) -> ExecutionResult {
let variables = ctx let variables = ctx
.db .db
.get_user_variables(&ctx.username, ctx.room_id().as_str()) .get_user_variables(&ctx.username, ctx.active_room_id().as_str())
.await?; .await?;
let mut variable_list: Vec<String> = variables let mut variable_list: Vec<String> = variables
@ -57,12 +51,6 @@ impl Command for GetAllVariablesCommand {
pub struct GetVariableCommand(pub String); pub struct GetVariableCommand(pub String);
impl From<GetVariableCommand> for Box<dyn Command> {
fn from(cmd: GetVariableCommand) -> Self {
Box::new(cmd)
}
}
impl TryFrom<String> for GetVariableCommand { impl TryFrom<String> for GetVariableCommand {
type Error = BotError; type Error = BotError;
@ -85,7 +73,7 @@ impl Command for GetVariableCommand {
let name = &self.0; let name = &self.0;
let result = ctx let result = ctx
.db .db
.get_user_variable(&ctx.username, ctx.room_id().as_str(), name) .get_user_variable(&ctx.username, ctx.active_room_id().as_str(), name)
.await; .await;
let value = match result { let value = match result {
@ -101,12 +89,6 @@ impl Command for GetVariableCommand {
pub struct SetVariableCommand(pub String, pub i32); pub struct SetVariableCommand(pub String, pub i32);
impl From<SetVariableCommand> for Box<dyn Command> {
fn from(cmd: SetVariableCommand) -> Self {
Box::new(cmd)
}
}
impl TryFrom<String> for SetVariableCommand { impl TryFrom<String> for SetVariableCommand {
type Error = BotError; type Error = BotError;
@ -131,7 +113,7 @@ impl Command for SetVariableCommand {
let value = self.1; let value = self.1;
ctx.db ctx.db
.set_user_variable(&ctx.username, ctx.room_id().as_str(), name, value) .set_user_variable(&ctx.username, ctx.active_room_id().as_str(), name, value)
.await?; .await?;
let content = format!("{} = {}", name, value); let content = format!("{} = {}", name, value);
@ -142,12 +124,6 @@ impl Command for SetVariableCommand {
pub struct DeleteVariableCommand(pub String); pub struct DeleteVariableCommand(pub String);
impl From<DeleteVariableCommand> for Box<dyn Command> {
fn from(cmd: DeleteVariableCommand) -> Self {
Box::new(cmd)
}
}
impl TryFrom<String> for DeleteVariableCommand { impl TryFrom<String> for DeleteVariableCommand {
type Error = BotError; type Error = BotError;
@ -170,7 +146,7 @@ impl Command for DeleteVariableCommand {
let name = &self.0; let name = &self.0;
let result = ctx let result = ctx
.db .db
.delete_user_variable(&ctx.username, ctx.room_id().as_str(), name) .delete_user_variable(&ctx.username, ctx.active_room_id().as_str(), name)
.await; .await;
let value = match result { let value = match result {

View File

@ -4,10 +4,6 @@ use std::fs;
use std::path::PathBuf; use std::path::PathBuf;
use thiserror::Error; use thiserror::Error;
/// Shortcut to defining db migration versions. Will probably
/// eventually be moved to a config file.
const MIGRATION_VERSION: u32 = 5;
#[derive(Error, Debug)] #[derive(Error, Debug)]
pub enum ConfigError { pub enum ConfigError {
#[error("i/o error: {0}")] #[error("i/o error: {0}")]
@ -53,10 +49,19 @@ fn db_path_from_env() -> String {
} }
/// The "bot" section of the config file, for bot settings. /// The "bot" section of the config file, for bot settings.
#[derive(Serialize, Deserialize, Clone, Debug)] #[derive(Serialize, Deserialize, Clone, Debug, Default)]
struct BotConfig { struct BotConfig {
/// How far back from current time should we process a message? /// How far back from current time should we process a message?
oldest_message_age: Option<u64>, oldest_message_age: Option<u64>,
/// What address and port to run the RPC service on. If not
/// specified, RPC will not be enabled.
rpc_addr: Option<String>,
/// The shared secret key between the bot and any RPC clients that
/// want to connect to it. The RPC server will reject any clients
/// that don't present the shared key.
rpc_key: Option<String>,
} }
/// The "database" section of the config file. /// The "database" section of the config file.
@ -84,6 +89,18 @@ impl BotConfig {
self.oldest_message_age self.oldest_message_age
.unwrap_or(DEFAULT_OLDEST_MESSAGE_AGE) .unwrap_or(DEFAULT_OLDEST_MESSAGE_AGE)
} }
#[inline]
#[must_use]
fn rpc_addr(&self) -> Option<String> {
self.rpc_addr.clone()
}
#[inline]
#[must_use]
fn rpc_key(&self) -> Option<String> {
self.rpc_key.clone()
}
} }
/// Represents the toml config file for the dicebot. The sections of /// Represents the toml config file for the dicebot. The sections of
@ -128,15 +145,6 @@ impl Config {
.unwrap_or_else(|| db_path_from_env()) .unwrap_or_else(|| db_path_from_env())
} }
/// The current migration version we expect of the database. If
/// this number is higher than the one in the database, we will
/// execute migrations to update the data.
#[inline]
#[must_use]
pub fn migration_version(&self) -> u32 {
MIGRATION_VERSION
}
/// Figure out the allowed oldest message age, in seconds. This will /// Figure out the allowed oldest message age, in seconds. This will
/// be the defined oldest message age in the bot config, if the bot /// be the defined oldest message age in the bot config, if the bot
/// configuration and associated "oldest_message_age" setting are /// configuration and associated "oldest_message_age" setting are
@ -150,6 +158,18 @@ impl Config {
.map(|bc| bc.oldest_message_age()) .map(|bc| bc.oldest_message_age())
.unwrap_or(DEFAULT_OLDEST_MESSAGE_AGE) .unwrap_or(DEFAULT_OLDEST_MESSAGE_AGE)
} }
#[inline]
#[must_use]
pub fn rpc_addr(&self) -> Option<String> {
self.bot.as_ref().and_then(|bc| bc.rpc_addr())
}
#[inline]
#[must_use]
pub fn rpc_key(&self) -> Option<String> {
self.bot.as_ref().and_then(|bc| bc.rpc_key())
}
} }
#[cfg(test)] #[cfg(test)]
@ -169,6 +189,7 @@ mod tests {
}), }),
bot: Some(BotConfig { bot: Some(BotConfig {
oldest_message_age: None, oldest_message_age: None,
..Default::default()
}), }),
}; };

View File

@ -1,8 +1,8 @@
use crate::db::sqlite::Database; use crate::db::sqlite::Database;
use crate::error::BotError; use crate::error::BotError;
use crate::models::Account; use crate::models::Account;
use matrix_sdk::identifiers::{RoomId, UserId};
use matrix_sdk::room::Joined; use matrix_sdk::room::Joined;
use matrix_sdk::ruma::{RoomId, UserId};
use matrix_sdk::Client; use matrix_sdk::Client;
use std::convert::TryFrom; use std::convert::TryFrom;
@ -11,20 +11,25 @@ use std::convert::TryFrom;
#[derive(Clone)] #[derive(Clone)]
pub struct Context<'a> { pub struct Context<'a> {
pub db: Database, pub db: Database,
pub matrix_client: &'a Client, pub matrix_client: Client,
pub room: RoomContext<'a>, pub origin_room: RoomContext<'a>,
pub active_room: RoomContext<'a>,
pub username: &'a str, pub username: &'a str,
pub message_body: &'a str, pub message_body: &'a str,
pub account: Account, pub account: Account,
} }
impl Context<'_> { impl Context<'_> {
pub fn active_room_id(&self) -> &RoomId {
self.active_room.id
}
pub fn room_id(&self) -> &RoomId { pub fn room_id(&self) -> &RoomId {
self.room.id self.origin_room.id
} }
pub fn is_secure(&self) -> bool { pub fn is_secure(&self) -> bool {
self.room.secure self.origin_room.secure
} }
} }
@ -38,15 +43,22 @@ pub struct RoomContext<'a> {
impl RoomContext<'_> { impl RoomContext<'_> {
pub async fn new_with_name<'a>( pub async fn new_with_name<'a>(
room: &'a Joined, room: &'a Joined,
display_name: String,
sending_user: &str, sending_user: &str,
) -> Result<RoomContext<'a>, BotError> { ) -> Result<RoomContext<'a>, BotError> {
// TODO is_direct is a hack; should set rooms to Direct // TODO is_direct is a hack; the bot should set eligible rooms
// Message upon joining, if other contact has requested it. // to Direct Message upon joining, if other contact has
// Waiting on SDK support. // requested it. Waiting on SDK support.
let sending_user = UserId::try_from(sending_user)?; let display_name =
let user_in_room = room.get_member(&sending_user).await.ok().is_some(); room
let is_direct = room.joined_members().await?.len() == 2; .display_name()
.await
.ok()
.map(|d| d.to_string())
.unwrap_or_default();
let sending_user = <&UserId>::try_from(sending_user)?;
let user_in_room = room.get_member(sending_user).await.ok().is_some();
let is_direct = room.active_members().await?.len() == 2;
Ok(RoomContext { Ok(RoomContext {
id: room.room_id(), id: room.room_id(),
@ -57,17 +69,8 @@ impl RoomContext<'_> {
pub async fn new<'a>( pub async fn new<'a>(
room: &'a Joined, room: &'a Joined,
sending_user: &str, sending_user: &'a str,
) -> Result<RoomContext<'a>, BotError> { ) -> Result<RoomContext<'a>, BotError> {
Self::new_with_name( Self::new_with_name(room, sending_user).await
&room,
room.display_name()
.await
.ok()
.unwrap_or_default()
.to_string(),
sending_user,
)
.await
} }
} }

View File

@ -270,7 +270,7 @@ macro_rules! is_variable {
element: Element::Variable(_), element: Element::Variable(_),
.. ..
} }
); )
}; };
} }
@ -380,7 +380,12 @@ async fn update_skill(ctx: &Context<'_>, variable: &str, value: u32) -> Result<(
use std::convert::TryInto; use std::convert::TryInto;
let value: i32 = value.try_into()?; let value: i32 = value.try_into()?;
ctx.db ctx.db
.set_user_variable(&ctx.username, &ctx.room_id().as_str(), variable, value) .set_user_variable(
&ctx.username,
&ctx.active_room_id().as_str(),
variable,
value,
)
.await?; .await?;
Ok(()) Ok(())
} }
@ -422,11 +427,12 @@ mod tests {
use crate::db::sqlite::Database; use crate::db::sqlite::Database;
use crate::parser::dice::{Amount, Element, Operator}; use crate::parser::dice::{Amount, Element, Operator};
use url::Url; use url::Url;
use matrix_sdk::ruma::room_id;
macro_rules! dummy_room { macro_rules! dummy_room {
() => { () => {
crate::context::RoomContext { crate::context::RoomContext {
id: &matrix_sdk::identifiers::room_id!("!fakeroomid:example.com"), id: &room_id!("!fakeroomid:example.com"),
display_name: "displayname".to_owned(), display_name: "displayname".to_owned(),
secure: false, secure: false,
} }
@ -506,8 +512,9 @@ mod tests {
let ctx = Context { let ctx = Context {
account: crate::models::Account::default(), account: crate::models::Account::default(),
db: db, db: db,
matrix_client: &matrix_sdk::Client::new(homeserver).unwrap(), matrix_client: matrix_sdk::Client::new(homeserver).await.unwrap(),
room: dummy_room!(), origin_room: dummy_room!(),
active_room: dummy_room!(),
username: "username", username: "username",
message_body: "message", message_body: "message",
}; };
@ -543,8 +550,9 @@ mod tests {
let ctx = Context { let ctx = Context {
account: crate::models::Account::default(), account: crate::models::Account::default(),
db: db, db: db,
matrix_client: &matrix_sdk::Client::new(homeserver).unwrap(), matrix_client: matrix_sdk::Client::new(homeserver).await.unwrap(),
room: dummy_room!(), origin_room: dummy_room!(),
active_room: dummy_room!(),
username: "username", username: "username",
message_body: "message", message_body: "message",
}; };
@ -580,8 +588,9 @@ mod tests {
let ctx = Context { let ctx = Context {
account: crate::models::Account::default(), account: crate::models::Account::default(),
db: db, db: db,
matrix_client: &matrix_sdk::Client::new(homeserver).unwrap(), matrix_client: matrix_sdk::Client::new(homeserver).await.unwrap(),
room: dummy_room!(), origin_room: dummy_room!(),
active_room: dummy_room!(),
username: "username", username: "username",
message_body: "message", message_body: "message",
}; };

View File

@ -4,16 +4,13 @@ use crate::parser::dice::DiceParsingError;
//TOOD convert these to use parse_amounts from the common dice code. //TOOD convert these to use parse_amounts from the common dice code.
fn parse_modifier(input: &str) -> Result<DiceRollModifier, DiceParsingError> { fn parse_modifier(input: &str) -> Result<DiceRollModifier, DiceParsingError> {
if input.ends_with("bb") { match input.trim() {
Ok(DiceRollModifier::TwoBonus) "bb" => Ok(DiceRollModifier::TwoBonus),
} else if input.ends_with("b") { "b" => Ok(DiceRollModifier::OneBonus),
Ok(DiceRollModifier::OneBonus) "pp" => Ok(DiceRollModifier::TwoPenalty),
} else if input.ends_with("pp") { "p" => Ok(DiceRollModifier::OnePenalty),
Ok(DiceRollModifier::TwoPenalty) "" => Ok(DiceRollModifier::Normal),
} else if input.ends_with("p") { _ => Err(DiceParsingError::InvalidModifiers),
Ok(DiceRollModifier::OnePenalty)
} else {
Ok(DiceRollModifier::Normal)
} }
} }
@ -21,32 +18,70 @@ fn parse_modifier(input: &str) -> Result<DiceRollModifier, DiceParsingError> {
//Split based on :, send first part to parse_modifier. //Split based on :, send first part to parse_modifier.
//Send second part to parse_amounts //Send second part to parse_amounts
pub fn parse_regular_roll(input: &str) -> Result<DiceRoll, DiceParsingError> { pub fn parse_regular_roll(input: &str) -> Result<DiceRoll, DiceParsingError> {
let input: Vec<&str> = input.trim().split(":").collect(); let (amount, modifiers_str) = crate::parser::dice::parse_single_amount(input)?;
let (modifiers_str, amounts_str) = match input[..] {
[amounts] => Ok(("", amounts)),
[modifiers, amounts] => Ok((modifiers, amounts)),
_ => Err(DiceParsingError::UnconsumedInput),
}?;
let modifier = parse_modifier(modifiers_str)?; let modifier = parse_modifier(modifiers_str)?;
let amount = crate::parser::dice::parse_single_amount(amounts_str)?;
Ok(DiceRoll { modifier, amount }) Ok(DiceRoll { modifier, amount })
} }
pub fn parse_advancement_roll(input: &str) -> Result<AdvancementRoll, DiceParsingError> { pub fn parse_advancement_roll(input: &str) -> Result<AdvancementRoll, DiceParsingError> {
let input = input.trim(); let input = input.trim();
let amounts = crate::parser::dice::parse_single_amount(input)?; let (amounts, unconsumed_input) = crate::parser::dice::parse_single_amount(input)?;
Ok(AdvancementRoll { if unconsumed_input.len() == 0 {
existing_skill: amounts, Ok(AdvancementRoll {
}) existing_skill: amounts,
})
} else {
Err(DiceParsingError::InvalidAmount)
}
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use crate::parser::dice::{Amount, Element, Operator}; use crate::parser::dice::{Amount, DiceParsingError, Element, Operator};
#[test]
fn parse_modifier_rejects_bad_value() {
let modifier = parse_modifier("qqq");
assert!(matches!(modifier, Err(DiceParsingError::InvalidModifiers)))
}
#[test]
fn parse_modifier_accepts_one_bonus() {
let modifier = parse_modifier("b");
assert!(matches!(modifier, Ok(DiceRollModifier::OneBonus)))
}
#[test]
fn parse_modifier_accepts_two_bonus() {
let modifier = parse_modifier("bb");
assert!(matches!(modifier, Ok(DiceRollModifier::TwoBonus)))
}
#[test]
fn parse_modifier_accepts_two_penalty() {
let modifier = parse_modifier("pp");
assert!(matches!(modifier, Ok(DiceRollModifier::TwoPenalty)))
}
#[test]
fn parse_modifier_accepts_one_penalty() {
let modifier = parse_modifier("p");
assert!(matches!(modifier, Ok(DiceRollModifier::OnePenalty)))
}
#[test]
fn parse_modifier_accepts_normal() {
let modifier = parse_modifier("");
assert!(matches!(modifier, Ok(DiceRollModifier::Normal)))
}
#[test]
fn parse_modifier_accepts_normal_unaffected_by_whitespace() {
let modifier = parse_modifier(" ");
assert!(matches!(modifier, Ok(DiceRollModifier::Normal)))
}
#[test] #[test]
fn regular_roll_accepts_single_number() { fn regular_roll_accepts_single_number() {
@ -72,7 +107,7 @@ mod tests {
#[test] #[test]
fn regular_roll_accepts_two_bonus() { fn regular_roll_accepts_two_bonus() {
let result = parse_regular_roll("bb:60"); let result = parse_regular_roll("60 bb");
assert!(result.is_ok()); assert!(result.is_ok());
assert_eq!( assert_eq!(
DiceRoll { DiceRoll {
@ -88,7 +123,7 @@ mod tests {
#[test] #[test]
fn regular_roll_accepts_one_bonus() { fn regular_roll_accepts_one_bonus() {
let result = parse_regular_roll("b:60"); let result = parse_regular_roll("60 b");
assert!(result.is_ok()); assert!(result.is_ok());
assert_eq!( assert_eq!(
DiceRoll { DiceRoll {
@ -104,7 +139,7 @@ mod tests {
#[test] #[test]
fn regular_roll_accepts_two_penalty() { fn regular_roll_accepts_two_penalty() {
let result = parse_regular_roll("pp:60"); let result = parse_regular_roll("60 pp");
assert!(result.is_ok()); assert!(result.is_ok());
assert_eq!( assert_eq!(
DiceRoll { DiceRoll {
@ -120,7 +155,7 @@ mod tests {
#[test] #[test]
fn regular_roll_accepts_one_penalty() { fn regular_roll_accepts_one_penalty() {
let result = parse_regular_roll("p:60"); let result = parse_regular_roll("60 p");
assert!(result.is_ok()); assert!(result.is_ok());
assert_eq!( assert_eq!(
DiceRoll { DiceRoll {
@ -140,21 +175,21 @@ mod tests {
assert!(parse_regular_roll(" 60").is_ok()); assert!(parse_regular_roll(" 60").is_ok());
assert!(parse_regular_roll(" 60 ").is_ok()); assert!(parse_regular_roll(" 60 ").is_ok());
assert!(parse_regular_roll("bb:60 ").is_ok()); assert!(parse_regular_roll("60bb ").is_ok());
assert!(parse_regular_roll(" bb:60").is_ok()); assert!(parse_regular_roll(" 60 bb").is_ok());
assert!(parse_regular_roll(" bb:60 ").is_ok()); assert!(parse_regular_roll(" 60 bb ").is_ok());
assert!(parse_regular_roll("b:60 ").is_ok()); assert!(parse_regular_roll("60b ").is_ok());
assert!(parse_regular_roll(" b:60").is_ok()); assert!(parse_regular_roll(" 60 b").is_ok());
assert!(parse_regular_roll(" b:60 ").is_ok()); assert!(parse_regular_roll(" 60 b ").is_ok());
assert!(parse_regular_roll("pp:60 ").is_ok()); assert!(parse_regular_roll("60pp ").is_ok());
assert!(parse_regular_roll(" pp:60").is_ok()); assert!(parse_regular_roll(" 60 pp").is_ok());
assert!(parse_regular_roll(" pp:60 ").is_ok()); assert!(parse_regular_roll(" 60 pp ").is_ok());
assert!(parse_regular_roll("p:60 ").is_ok()); assert!(parse_regular_roll("60p ").is_ok());
assert!(parse_regular_roll(" p:60").is_ok()); assert!(parse_regular_roll(" 60p ").is_ok());
assert!(parse_regular_roll(" p:60 ").is_ok()); assert!(parse_regular_roll(" 60 p ").is_ok());
} }
#[test] #[test]

View File

@ -0,0 +1,22 @@
use crate::systems::GameSystem;
use barrel::backend::Sqlite;
use barrel::{types, types::Type, Migration};
use itertools::Itertools;
use strum::IntoEnumIterator;
fn primary_id() -> Type {
types::text().unique(true).primary(true).nullable(false)
}
pub fn migration() -> String {
let mut m = Migration::new();
//Normally we would add a CHECK clause here, but types::custom requires a 'static string.
//Which means we can't automagically generate one from the enum.
m.create_table("room_info", move |t| {
t.add_column("room_id", primary_id());
t.add_column("game_system", types::text().nullable(false));
});
m.make::<Sqlite>()
}

View File

@ -0,0 +1 @@

View File

@ -5,7 +5,7 @@ use sqlx::ConnectOptions;
use std::str::FromStr; use std::str::FromStr;
use thiserror::Error; use thiserror::Error;
pub mod migrations; //pub mod migrations;
#[derive(Error, Debug)] #[derive(Error, Debug)]
pub enum MigrationError { pub enum MigrationError {
@ -16,6 +16,11 @@ pub enum MigrationError {
RefineryError(#[from] refinery::Error), RefineryError(#[from] refinery::Error),
} }
mod embedded {
use refinery::embed_migrations;
embed_migrations!("src/db/sqlite/migrator/migrations");
}
/// Run database migrations against the sqlite database. /// Run database migrations against the sqlite database.
pub async fn migrate(db_path: &str) -> Result<(), MigrationError> { pub async fn migrate(db_path: &str) -> Result<(), MigrationError> {
//Create database if missing. //Create database if missing.
@ -28,6 +33,6 @@ pub async fn migrate(db_path: &str) -> Result<(), MigrationError> {
let mut conn = Config::new(ConfigDbType::Sqlite).set_db_path(db_path); let mut conn = Config::new(ConfigDbType::Sqlite).set_db_path(db_path);
info!("Running migrations"); info!("Running migrations");
migrations::runner().run(&mut conn)?; embedded::migrations::runner().run(&mut conn)?;
Ok(()) Ok(())
} }

View File

@ -53,34 +53,41 @@ impl Rooms for Database {
mod tests { mod tests {
use crate::db::sqlite::Database; use crate::db::sqlite::Database;
use crate::db::Rooms; use crate::db::Rooms;
use std::future::Future;
async fn create_db() -> Database { async fn with_db<Fut>(f: impl FnOnce(Database) -> Fut)
where
Fut: Future<Output = ()>,
{
let db_path = tempfile::NamedTempFile::new_in(".").unwrap(); let db_path = tempfile::NamedTempFile::new_in(".").unwrap();
crate::db::sqlite::migrator::migrate(db_path.path().to_str().unwrap()) crate::db::sqlite::migrator::migrate(db_path.path().to_str().unwrap())
.await .await
.unwrap(); .unwrap();
Database::new(db_path.path().to_str().unwrap()) let db = Database::new(db_path.path().to_str().unwrap())
.await .await
.unwrap() .unwrap();
f(db).await;
} }
#[tokio::test(flavor = "multi_thread", worker_threads = 1)] #[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn should_process_test() { async fn should_process_test() {
let db = create_db().await; with_db(|db| async move {
let first_check = db
.should_process("myroom", "myeventid")
.await
.expect("should_process failed in first insert");
let first_check = db assert_eq!(first_check, true);
.should_process("myroom", "myeventid")
.await
.expect("should_process failed in first insert");
assert_eq!(first_check, true); let second_check = db
.should_process("myroom", "myeventid")
.await
.expect("should_process failed in first insert");
let second_check = db assert_eq!(second_check, false);
.should_process("myroom", "myeventid") })
.await .await;
.expect("should_process failed in first insert");
assert_eq!(second_check, false);
} }
} }

View File

@ -37,54 +37,64 @@ impl DbState for Database {
mod tests { mod tests {
use crate::db::sqlite::Database; use crate::db::sqlite::Database;
use crate::db::DbState; use crate::db::DbState;
use std::future::Future;
async fn create_db() -> Database { async fn with_db<Fut>(f: impl FnOnce(Database) -> Fut)
where
Fut: Future<Output = ()>,
{
let db_path = tempfile::NamedTempFile::new_in(".").unwrap(); let db_path = tempfile::NamedTempFile::new_in(".").unwrap();
crate::db::sqlite::migrator::migrate(db_path.path().to_str().unwrap()) crate::db::sqlite::migrator::migrate(db_path.path().to_str().unwrap())
.await .await
.unwrap(); .unwrap();
Database::new(db_path.path().to_str().unwrap()) let db = Database::new(db_path.path().to_str().unwrap())
.await .await
.unwrap() .unwrap();
f(db).await;
} }
#[tokio::test(flavor = "multi_thread", worker_threads = 1)] #[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn set_and_get_device_id() { async fn set_and_get_device_id() {
let db = create_db().await; with_db(|db| async move {
db.set_device_id("device_id")
.await
.expect("Could not set device ID");
db.set_device_id("device_id") let device_id = db.get_device_id().await.expect("Could not get device ID");
.await
.expect("Could not set device ID");
let device_id = db.get_device_id().await.expect("Could not get device ID"); assert!(device_id.is_some());
assert_eq!(device_id.unwrap(), "device_id");
assert!(device_id.is_some()); })
assert_eq!(device_id.unwrap(), "device_id"); .await;
} }
#[tokio::test(flavor = "multi_thread", worker_threads = 1)] #[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn no_device_id_set_returns_none() { async fn no_device_id_set_returns_none() {
let db = create_db().await; with_db(|db| async move {
let device_id = db.get_device_id().await.expect("Could not get device ID"); let device_id = db.get_device_id().await.expect("Could not get device ID");
assert!(device_id.is_none()); assert!(device_id.is_none());
})
.await;
} }
#[tokio::test(flavor = "multi_thread", worker_threads = 1)] #[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn can_update_device_id() { async fn can_update_device_id() {
let db = create_db().await; with_db(|db| async move {
db.set_device_id("device_id")
.await
.expect("Could not set device ID");
db.set_device_id("device_id") db.set_device_id("device_id2")
.await .await
.expect("Could not set device ID"); .expect("Could not set device ID");
db.set_device_id("device_id2") let device_id = db.get_device_id().await.expect("Could not get device ID");
.await
.expect("Could not set device ID");
let device_id = db.get_device_id().await.expect("Could not get device ID"); assert!(device_id.is_some());
assert_eq!(device_id.unwrap(), "device_id2");
assert!(device_id.is_some()); })
assert_eq!(device_id.unwrap(), "device_id2"); .await;
} }
} }

View File

@ -0,0 +1,361 @@
use super::Database;
use crate::db::{errors::DataError, Users};
use crate::error::BotError;
use crate::models::User;
use async_trait::async_trait;
#[async_trait]
impl Users for Database {
async fn upsert_user(&self, user: &User) -> Result<(), DataError> {
let mut tx = self.conn.begin().await?;
sqlx::query!(
r#"INSERT INTO accounts (user_id, password, account_status)
VALUES (?, ?, ?)
ON CONFLICT(user_id) DO
UPDATE SET password = ?, account_status = ?"#,
user.username,
user.password,
user.account_status,
user.password,
user.account_status
)
.execute(&mut tx)
.await?;
sqlx::query!(
r#"INSERT INTO user_state (user_id, active_room)
VALUES (?, ?)
ON CONFLICT(user_id) DO
UPDATE SET active_room = ?"#,
user.username,
user.active_room,
user.active_room
)
.execute(&mut tx)
.await?;
tx.commit().await?;
Ok(())
}
async fn delete_user(&self, username: &str) -> Result<(), DataError> {
let mut tx = self.conn.begin().await?;
sqlx::query!(r#"DELETE FROM accounts WHERE user_id = ?"#, username)
.execute(&mut tx)
.await?;
sqlx::query!(r#"DELETE FROM user_state WHERE user_id = ?"#, username)
.execute(&mut tx)
.await?;
tx.commit().await?;
Ok(())
}
async fn get_user(&self, username: &str) -> Result<Option<User>, DataError> {
// Should be query_as! macro, but the left join breaks it with a
// non existing error message.
let user_row: Option<User> = sqlx::query_as(
r#"SELECT
a.user_id as "username",
a.password,
s.active_room,
COALESCE(a.account_status, 'not_registered') as "account_status"
FROM accounts a
LEFT JOIN user_state s on a.user_id = s.user_id
WHERE a.user_id = ?"#,
)
.bind(username)
.fetch_optional(&self.conn)
.await?;
Ok(user_row)
}
async fn authenticate_user(
&self,
username: &str,
raw_password: &str,
) -> Result<Option<User>, BotError> {
let user = self.get_user(username).await?;
Ok(user.filter(|u| u.verify_password(raw_password)))
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::db::sqlite::Database;
use crate::db::Users;
use crate::models::AccountStatus;
use std::future::Future;
async fn with_db<Fut>(f: impl FnOnce(Database) -> Fut)
where
Fut: Future<Output = ()>,
{
let db_path = tempfile::NamedTempFile::new_in(".").unwrap();
crate::db::sqlite::migrator::migrate(db_path.path().to_str().unwrap())
.await
.unwrap();
let db = Database::new(db_path.path().to_str().unwrap())
.await
.unwrap();
f(db).await;
}
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn create_and_get_full_user_test() {
with_db(|db| async move {
let insert_result = db
.upsert_user(&User {
username: "myuser".to_string(),
password: Some("abc".to_string()),
account_status: AccountStatus::Registered,
active_room: Some("myroom".to_string()),
})
.await;
assert!(insert_result.is_ok());
let user = db
.get_user("myuser")
.await
.expect("User retrieval query failed");
assert!(user.is_some());
let user = user.unwrap();
assert_eq!(user.username, "myuser");
assert_eq!(user.password, Some("abc".to_string()));
assert_eq!(user.account_status, AccountStatus::Registered);
assert_eq!(user.active_room, Some("myroom".to_string()));
})
.await;
}
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn can_get_user_with_no_state_record() {
with_db(|db| async move {
let insert_result = db
.upsert_user(&User {
username: "myuser".to_string(),
password: Some("abc".to_string()),
account_status: AccountStatus::AwaitingActivation,
active_room: Some("myroom".to_string()),
})
.await;
assert!(insert_result.is_ok());
sqlx::query("DELETE FROM user_state")
.execute(&db.conn)
.await
.expect("Could not delete from user_state table.");
let user = db
.get_user("myuser")
.await
.expect("User retrieval query failed");
assert!(user.is_some());
let user = user.unwrap();
assert_eq!(user.username, "myuser");
assert_eq!(user.password, Some("abc".to_string()));
assert_eq!(user.account_status, AccountStatus::AwaitingActivation);
//These should be default values because the state record is missing.
assert_eq!(user.active_room, None);
})
.await;
}
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn can_insert_without_password() {
with_db(|db| async move {
let insert_result = db
.upsert_user(&User {
username: "myuser".to_string(),
password: None,
..Default::default()
})
.await;
assert!(insert_result.is_ok());
let user = db
.get_user("myuser")
.await
.expect("User retrieval query failed");
assert!(user.is_some());
let user = user.unwrap();
assert_eq!(user.username, "myuser");
assert_eq!(user.password, None);
})
.await;
}
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn can_insert_without_active_room() {
with_db(|db| async move {
let insert_result = db
.upsert_user(&User {
username: "myuser".to_string(),
active_room: None,
..Default::default()
})
.await;
assert!(insert_result.is_ok());
let user = db
.get_user("myuser")
.await
.expect("User retrieval query failed");
assert!(user.is_some());
let user = user.unwrap();
assert_eq!(user.username, "myuser");
assert_eq!(user.active_room, None);
})
.await;
}
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn can_update_user() {
with_db(|db| async move {
let insert_result1 = db
.upsert_user(&User {
username: "myuser".to_string(),
password: Some("abc".to_string()),
..Default::default()
})
.await;
assert!(insert_result1.is_ok());
let insert_result2 = db
.upsert_user(&User {
username: "myuser".to_string(),
password: Some("123".to_string()),
active_room: Some("room".to_string()),
account_status: AccountStatus::AwaitingActivation,
})
.await;
assert!(insert_result2.is_ok());
let user = db
.get_user("myuser")
.await
.expect("User retrieval query failed");
assert!(user.is_some());
let user = user.unwrap();
assert_eq!(user.username, "myuser");
//From second upsert
assert_eq!(user.password, Some("123".to_string()));
assert_eq!(user.active_room, Some("room".to_string()));
assert_eq!(user.account_status, AccountStatus::AwaitingActivation);
})
.await;
}
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn can_delete_user() {
with_db(|db| async move {
let insert_result = db
.upsert_user(&User {
username: "myuser".to_string(),
password: Some("abc".to_string()),
..Default::default()
})
.await;
assert!(insert_result.is_ok());
db.delete_user("myuser")
.await
.expect("User deletion query failed");
let user = db
.get_user("myuser")
.await
.expect("User retrieval query failed");
assert!(user.is_none());
})
.await;
}
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn username_not_in_db_returns_none() {
with_db(|db| async move {
let user = db
.get_user("does not exist")
.await
.expect("Get user query failure");
assert!(user.is_none());
})
.await;
}
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn authenticate_user_is_some_with_valid_password() {
with_db(|db| async move {
let insert_result = db
.upsert_user(&User {
username: "myuser".to_string(),
password: Some(
crate::logic::hash_password("abc").expect("password hash error!"),
),
..Default::default()
})
.await;
assert!(insert_result.is_ok());
let user = db
.authenticate_user("myuser", "abc")
.await
.expect("User retrieval query failed");
assert!(user.is_some());
let user = user.unwrap();
assert_eq!(user.username, "myuser");
})
.await;
}
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn authenticate_user_is_none_with_wrong_password() {
with_db(|db| async move {
let insert_result = db
.upsert_user(&User {
username: "myuser".to_string(),
password: Some(
crate::logic::hash_password("abc").expect("password hash error!"),
),
..Default::default()
})
.await;
assert!(insert_result.is_ok());
let user = db
.authenticate_user("myuser", "wrong-password")
.await
.expect("User retrieval query failed");
assert!(user.is_none());
})
.await;
}
}

View File

@ -102,143 +102,156 @@ mod tests {
use super::*; use super::*;
use crate::db::sqlite::Database; use crate::db::sqlite::Database;
use crate::db::Variables; use crate::db::Variables;
use std::future::Future;
async fn create_db() -> Database { async fn with_db<Fut>(f: impl FnOnce(Database) -> Fut)
where
Fut: Future<Output = ()>,
{
let db_path = tempfile::NamedTempFile::new_in(".").unwrap(); let db_path = tempfile::NamedTempFile::new_in(".").unwrap();
crate::db::sqlite::migrator::migrate(db_path.path().to_str().unwrap()) crate::db::sqlite::migrator::migrate(db_path.path().to_str().unwrap())
.await .await
.unwrap(); .unwrap();
Database::new(db_path.path().to_str().unwrap()) let db = Database::new(db_path.path().to_str().unwrap())
.await .await
.unwrap() .unwrap();
f(db).await;
} }
#[tokio::test(flavor = "multi_thread", worker_threads = 1)] #[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn set_and_get_variable_test() { async fn set_and_get_variable_test() {
let db = create_db().await; with_db(|db| async move {
db.set_user_variable("myuser", "myroom", "myvariable", 1)
.await
.expect("Could not set variable");
db.set_user_variable("myuser", "myroom", "myvariable", 1) let value = db
.await .get_user_variable("myuser", "myroom", "myvariable")
.expect("Could not set variable"); .await
.expect("Could not get variable");
let value = db assert_eq!(value, 1);
.get_user_variable("myuser", "myroom", "myvariable") })
.await .await;
.expect("Could not get variable");
assert_eq!(value, 1);
} }
#[tokio::test(flavor = "multi_thread", worker_threads = 1)] #[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn get_missing_variable_test() { async fn get_missing_variable_test() {
let db = create_db().await; with_db(|db| async move {
let value = db.get_user_variable("myuser", "myroom", "myvariable").await;
let value = db.get_user_variable("myuser", "myroom", "myvariable").await; assert!(value.is_err());
assert!(matches!(
assert!(value.is_err()); value.err().unwrap(),
assert!(matches!( DataError::KeyDoesNotExist(_)
value.err().unwrap(), ));
DataError::KeyDoesNotExist(_) })
)); .await;
} }
#[tokio::test(flavor = "multi_thread", worker_threads = 1)] #[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn get_other_user_variable_test() { async fn get_other_user_variable_test() {
let db = create_db().await; with_db(|db| async move {
db.set_user_variable("myuser1", "myroom", "myvariable", 1)
.await
.expect("Could not set variable");
db.set_user_variable("myuser1", "myroom", "myvariable", 1) let value = db
.await .get_user_variable("myuser2", "myroom", "myvariable")
.expect("Could not set variable"); .await;
let value = db assert!(value.is_err());
.get_user_variable("myuser2", "myroom", "myvariable") assert!(matches!(
.await; value.err().unwrap(),
DataError::KeyDoesNotExist(_)
assert!(value.is_err()); ));
assert!(matches!( })
value.err().unwrap(), .await;
DataError::KeyDoesNotExist(_)
));
} }
#[tokio::test(flavor = "multi_thread", worker_threads = 1)] #[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn count_variables_test() { async fn count_variables_test() {
let db = create_db().await; with_db(|db| async move {
for variable_name in &["var1", "var2", "var3"] {
db.set_user_variable("myuser", "myroom", variable_name, 1)
.await
.expect("Could not set variable");
}
for variable_name in &["var1", "var2", "var3"] { let count = db
db.set_user_variable("myuser", "myroom", variable_name, 1) .get_variable_count("myuser", "myroom")
.await .await
.expect("Could not set variable"); .expect("Could not get count.");
}
let count = db assert_eq!(count, 3);
.get_variable_count("myuser", "myroom") })
.await .await;
.expect("Could not get count.");
assert_eq!(count, 3);
} }
#[tokio::test(flavor = "multi_thread", worker_threads = 1)] #[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn count_variables_respects_user_id() { async fn count_variables_respects_user_id() {
let db = create_db().await; with_db(|db| async move {
for variable_name in &["var1", "var2", "var3"] {
db.set_user_variable("different-user", "myroom", variable_name, 1)
.await
.expect("Could not set variable");
}
for variable_name in &["var1", "var2", "var3"] { let count = db
db.set_user_variable("different-user", "myroom", variable_name, 1) .get_variable_count("myuser", "myroom")
.await .await
.expect("Could not set variable"); .expect("Could not get count.");
}
let count = db assert_eq!(count, 0);
.get_variable_count("myuser", "myroom") })
.await .await;
.expect("Could not get count.");
assert_eq!(count, 0);
} }
#[tokio::test(flavor = "multi_thread", worker_threads = 1)] #[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn count_variables_respects_room_id() { async fn count_variables_respects_room_id() {
let db = create_db().await; with_db(|db| async move {
for variable_name in &["var1", "var2", "var3"] {
db.set_user_variable("myuser", "different-room", variable_name, 1)
.await
.expect("Could not set variable");
}
for variable_name in &["var1", "var2", "var3"] { let count = db
db.set_user_variable("myuser", "different-room", variable_name, 1) .get_variable_count("myuser", "myroom")
.await .await
.expect("Could not set variable"); .expect("Could not get count.");
}
let count = db assert_eq!(count, 0);
.get_variable_count("myuser", "myroom") })
.await .await;
.expect("Could not get count.");
assert_eq!(count, 0);
} }
#[tokio::test(flavor = "multi_thread", worker_threads = 1)] #[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn delete_variable_test() { async fn delete_variable_test() {
let db = create_db().await; with_db(|db| async move {
for variable_name in &["var1", "var2", "var3"] {
db.set_user_variable("myuser", "myroom", variable_name, 1)
.await
.expect("Could not set variable");
}
for variable_name in &["var1", "var2", "var3"] { db.delete_user_variable("myuser", "myroom", "var1")
db.set_user_variable("myuser", "myroom", variable_name, 1)
.await .await
.expect("Could not set variable"); .expect("Could not delete variable.");
}
db.delete_user_variable("myuser", "myroom", "var1") let count = db
.await .get_variable_count("myuser", "myroom")
.expect("Could not delete variable."); .await
.expect("Could not get count");
let count = db assert_eq!(count, 2);
.get_variable_count("myuser", "myroom")
.await
.expect("Could not get count");
assert_eq!(count, 2); let var1 = db.get_user_variable("myuser", "myroom", "var1").await;
assert!(var1.is_err());
let var1 = db.get_user_variable("myuser", "myroom", "var1").await; assert!(matches!(var1.err().unwrap(), DataError::KeyDoesNotExist(_)));
assert!(var1.is_err()); })
assert!(matches!(var1.err().unwrap(), DataError::KeyDoesNotExist(_))); .await;
} }
} }

View File

@ -1,7 +1,10 @@
use std::net::AddrParseError;
use crate::commands::CommandError; use crate::commands::CommandError;
use crate::config::ConfigError; use crate::config::ConfigError;
use crate::db::errors::DataError; use crate::db::errors::DataError;
use thiserror::Error; use thiserror::Error;
use tonic::metadata::errors::InvalidMetadataValue;
#[derive(Error, Debug)] #[derive(Error, Debug)]
pub enum BotError { pub enum BotError {
@ -15,6 +18,12 @@ pub enum BotError {
#[error("could not retrieve device id")] #[error("could not retrieve device id")]
NoDeviceIdFound, NoDeviceIdFound,
#[error("could not build client: {0}")]
ClientBuildError(#[from] matrix_sdk::ClientBuildError),
#[error("could not open matrix store: {0}")]
OpenStoreError(#[from] matrix_sdk::store::OpenStoreError),
#[error("command error: {0}")] #[error("command error: {0}")]
CommandError(#[from] CommandError), CommandError(#[from] CommandError),
@ -30,15 +39,15 @@ pub enum BotError {
#[error("could not parse URL")] #[error("could not parse URL")]
UrlParseError(#[from] url::ParseError), UrlParseError(#[from] url::ParseError),
#[error("could not parse ID")]
IdParseError(#[from] matrix_sdk::ruma::IdParseError),
#[error("error in matrix state store: {0}")] #[error("error in matrix state store: {0}")]
MatrixStateStoreError(#[from] matrix_sdk::StoreError), MatrixStateStoreError(#[from] matrix_sdk::StoreError),
#[error("uncategorized matrix SDK error: {0}")] #[error("uncategorized matrix SDK error: {0}")]
MatrixError(#[from] matrix_sdk::Error), MatrixError(#[from] matrix_sdk::Error),
#[error("uncategorized matrix SDK base error: {0}")]
MatrixBaseError(#[from] matrix_sdk::BaseError),
#[error("future canceled")] #[error("future canceled")]
FutureCanceledError, FutureCanceledError,
@ -76,8 +85,8 @@ pub enum BotError {
#[error("could not convert to proper integer type")] #[error("could not convert to proper integer type")]
TryFromIntError(#[from] std::num::TryFromIntError), TryFromIntError(#[from] std::num::TryFromIntError),
#[error("identifier error: {0}")] // #[error("identifier error: {0}")]
IdentifierError(#[from] matrix_sdk::identifiers::Error), // IdentifierError(#[from] matrix_sdk::ruma::Error),
#[error("password creation error: {0}")] #[error("password creation error: {0}")]
PasswordCreationError(argon2::Error), PasswordCreationError(argon2::Error),
@ -93,6 +102,15 @@ pub enum BotError {
#[error("room name or id does not exist")] #[error("room name or id does not exist")]
RoomDoesNotExist, RoomDoesNotExist,
#[error("tonic transport error: {0}")]
TonicTransportError(#[from] tonic::transport::Error),
#[error("address parsing error: {0}")]
AddressParseError(#[from] AddrParseError),
#[error("invalid metadata value: {0}")]
TonicInvalidMetadata(#[from] InvalidMetadataValue),
} }
#[derive(Error, Debug)] #[derive(Error, Debug)]

View File

@ -6,6 +6,9 @@ pub fn parse_help_topic(input: &str) -> Option<HelpTopic> {
"dicepool" => Some(HelpTopic::DicePool), "dicepool" => Some(HelpTopic::DicePool),
"dice" => Some(HelpTopic::RollingDice), "dice" => Some(HelpTopic::RollingDice),
"cthulhu" => Some(HelpTopic::Cthulhu), "cthulhu" => Some(HelpTopic::Cthulhu),
"variables" => Some(HelpTopic::Variables),
"var" => Some(HelpTopic::Variables),
"variable" => Some(HelpTopic::Variables),
"" => Some(HelpTopic::General), "" => Some(HelpTopic::General),
_ => None, _ => None,
} }
@ -16,6 +19,7 @@ pub enum HelpTopic {
DicePool, DicePool,
Cthulhu, Cthulhu,
RollingDice, RollingDice,
Variables,
General, General,
} }
@ -101,6 +105,34 @@ Note: If !cthadv is given a variable, and the roll is successful, it will
update the variable with the new skill. update the variable with the new skill.
"}; "};
const VARIABLES_HELP: &'static str = indoc! {"
Variables
Commands: !get, !set, !variables
Manage variables that can be substituted into roll commands.
Examples: !get myvar, !set myvar 10
!get <variable> = show variable of the given name
!set <variable> <num> = set a variable to a number
The !variables command will list all variables for the room. The
variables command cna be used in a secure room to avoid spamming the
actual room that the variable is set in.
Variable names can be used in all types of dice rolls:
!pool myvar + 3
!roll myvar
There are some limitations on variables: they cannot themselves be
dice expressions (i.e. can only be numbers), and they must be uniquely
parseable in an expression (i.e 'myvard6' does not work for the !roll
command).
"};
const GENERAL_HELP: &'static str = indoc! {" const GENERAL_HELP: &'static str = indoc! {"
General Help General Help
@ -117,6 +149,7 @@ impl HelpTopic {
HelpTopic::DicePool => DICEPOOL_HELP, HelpTopic::DicePool => DICEPOOL_HELP,
HelpTopic::Cthulhu => CTHULHU_HELP, HelpTopic::Cthulhu => CTHULHU_HELP,
HelpTopic::RollingDice => DICE_HELP, HelpTopic::RollingDice => DICE_HELP,
HelpTopic::Variables => VARIABLES_HELP,
HelpTopic::General => GENERAL_HELP, HelpTopic::General => GENERAL_HELP,
} }
} }

View File

@ -12,4 +12,6 @@ pub mod logic;
pub mod matrix; pub mod matrix;
pub mod models; pub mod models;
mod parser; mod parser;
pub mod rpc;
pub mod state; pub mod state;
pub mod systems;

View File

@ -27,7 +27,7 @@ pub async fn calculate_dice_amount(amounts: &[Amount], ctx: &Context<'_>) -> Res
let stream = stream::iter(amounts); let stream = stream::iter(amounts);
let variables = &ctx let variables = &ctx
.db .db
.get_user_variables(&ctx.username, ctx.room_id().as_str()) .get_user_variables(&ctx.username, ctx.active_room_id().as_str())
.await?; .await?;
use DiceRollingError::VariableNotFound; use DiceRollingError::VariableNotFound;
@ -71,53 +71,61 @@ mod tests {
use super::*; use super::*;
use crate::db::Users; use crate::db::Users;
use crate::models::{AccountStatus, User}; use crate::models::{AccountStatus, User};
use std::future::Future;
async fn create_db() -> Database { async fn with_db<Fut>(f: impl FnOnce(Database) -> Fut)
where
Fut: Future<Output = ()>,
{
let db_path = tempfile::NamedTempFile::new_in(".").unwrap(); let db_path = tempfile::NamedTempFile::new_in(".").unwrap();
crate::db::sqlite::migrator::migrate(db_path.path().to_str().unwrap()) crate::db::sqlite::migrator::migrate(db_path.path().to_str().unwrap())
.await .await
.unwrap(); .unwrap();
Database::new(db_path.path().to_str().unwrap()) let db = Database::new(db_path.path().to_str().unwrap())
.await .await
.unwrap() .unwrap();
f(db).await;
} }
#[tokio::test(flavor = "multi_thread", worker_threads = 1)] #[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn get_account_no_user_exists() { async fn get_account_no_user_exists() {
let db = create_db().await; with_db(|db| async move {
let account = get_account(&db, "@test:example.com")
.await
.expect("Account retrieval didn't work");
let account = get_account(&db, "@test:example.com") assert!(matches!(account, Account::Transient(_)));
.await
.expect("Account retrieval didn't work");
assert!(matches!(account, Account::Transient(_))); let user = account.transient_user().unwrap();
assert_eq!(user.username, "@test:example.com");
let user = account.transient_user().unwrap(); })
assert_eq!(user.username, "@test:example.com"); .await;
} }
#[tokio::test(flavor = "multi_thread", worker_threads = 1)] #[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn get_or_create_user_when_user_exists() { async fn get_or_create_user_when_user_exists() {
let db = create_db().await; with_db(|db| async move {
let user = User {
username: "myuser".to_string(),
password: Some("abc".to_string()),
account_status: AccountStatus::Registered,
active_room: Some("myroom".to_string()),
};
let user = User { let insert_result = db.upsert_user(&user).await;
username: "myuser".to_string(), assert!(insert_result.is_ok());
password: Some("abc".to_string()),
account_status: AccountStatus::Registered,
active_room: Some("myroom".to_string()),
};
let insert_result = db.upsert_user(&user).await; let account = get_account(&db, "myuser")
assert!(insert_result.is_ok()); .await
.expect("Account retrieval did not work");
let account = get_account(&db, "myuser") assert!(matches!(account, Account::Registered(_)));
.await
.expect("Account retrieval did not work");
assert!(matches!(account, Account::Registered(_))); let user_again = account.registered_user().unwrap();
assert_eq!(user, *user_again);
let user_again = account.registered_user().unwrap(); })
assert_eq!(user, user_again); .await;
} }
} }

View File

@ -1,14 +1,22 @@
use std::path::PathBuf;
use futures::stream::{self, StreamExt, TryStreamExt}; use futures::stream::{self, StreamExt, TryStreamExt};
use log::error; use log::error;
use matrix_sdk::{events::room::message::NoticeMessageEventContent, room::Joined}; use matrix_sdk::ruma::events::room::message::{InReplyTo, RoomMessageEventContent, Relation};
use matrix_sdk::{ use matrix_sdk::ruma::events::AnyMessageLikeEventContent;
events::room::message::{InReplyTo, Relation}, use matrix_sdk::ruma::{RoomId, OwnedEventId, OwnedUserId};
events::room::message::{MessageEventContent, MessageType}, use matrix_sdk::Client;
events::AnyMessageEventContent, use matrix_sdk::Error as MatrixError;
identifiers::EventId, use matrix_sdk::room::Joined;
Error as MatrixError, use url::Url;
};
use matrix_sdk::{identifiers::RoomId, identifiers::UserId, Client}; use crate::{config::Config, error::BotError};
fn cache_dir() -> Result<PathBuf, BotError> {
let mut dir = dirs::cache_dir().ok_or(BotError::NoCacheDirectoryError)?;
dir.push("matrix-dicebot");
Ok(dir)
}
/// Extracts more detailed error messages out of a matrix SDK error. /// Extracts more detailed error messages out of a matrix SDK error.
fn extract_error_message(error: MatrixError) -> String { fn extract_error_message(error: MatrixError) -> String {
@ -20,6 +28,19 @@ fn extract_error_message(error: MatrixError) -> String {
} }
} }
/// Creates the matrix client.
pub async fn create_client(config: &Config) -> Result<Client, BotError> {
let cache_dir = cache_dir()?;
let homeserver_url = Url::parse(&config.matrix_homeserver())?;
let client = Client::builder()
.sled_store(cache_dir, None)?
.homeserver_url(homeserver_url).build()
.await?;
Ok(client)
}
/// Retrieve a list of users in a given room. /// Retrieve a list of users in a given room.
pub async fn get_users_in_room( pub async fn get_users_in_room(
client: &Client, client: &Client,
@ -39,7 +60,7 @@ pub async fn get_users_in_room(
pub async fn get_rooms_for_user( pub async fn get_rooms_for_user(
client: &Client, client: &Client,
user: &UserId, user: &OwnedUserId,
) -> Result<Vec<Joined>, MatrixError> { ) -> Result<Vec<Joined>, MatrixError> {
// Carries errors through, in case we cannot load joined user IDs // Carries errors through, in case we cannot load joined user IDs
// from the room for some reason. // from the room for some reason.
@ -67,7 +88,7 @@ pub async fn send_message(
client: &Client, client: &Client,
room_id: &RoomId, room_id: &RoomId,
message: (&str, &str), message: (&str, &str),
reply_to: Option<EventId>, reply_to: Option<OwnedEventId>,
) { ) {
let (html, plain) = message; let (html, plain) = message;
let room = match client.get_joined_room(room_id) { let room = match client.get_joined_room(room_id) {
@ -75,15 +96,13 @@ pub async fn send_message(
_ => return, _ => return,
}; };
let mut content = MessageEventContent::new(MessageType::Notice( let mut content = RoomMessageEventContent::notice_html(plain.trim(), html);
NoticeMessageEventContent::html(plain.trim(), html),
));
content.relates_to = reply_to.map(|event_id| Relation::Reply { content.relates_to = reply_to.map(|event_id| Relation::Reply {
in_reply_to: InReplyTo::new(event_id), in_reply_to: InReplyTo::new(event_id)
}); });
let content = AnyMessageEventContent::RoomMessage(content); let content = AnyMessageLikeEventContent::RoomMessage(content);
let result = room.send(content, None).await; let result = room.send(content, None).await;

View File

@ -60,9 +60,9 @@ impl Account {
/// Consume self into an Option<User> instance, which will be Some /// Consume self into an Option<User> instance, which will be Some
/// if this account has a registered user, and None otherwise. /// if this account has a registered user, and None otherwise.
pub fn registered_user(self) -> Option<User> { pub fn registered_user(&self) -> Option<&User> {
match self { match self {
Self::Registered(user) => Some(user), Self::Registered(ref user) => Some(user),
_ => None, _ => None,
} }
} }

View File

@ -151,8 +151,9 @@ where
/// should not have an operator, but every one after that should. /// should not have an operator, but every one after that should.
/// Accepts expressions like "8", "10 + variablename", "variablename - /// Accepts expressions like "8", "10 + variablename", "variablename -
/// 3", etc. This function is currently common to systems that don't /// 3", etc. This function is currently common to systems that don't
/// deal with XdY rolls. Support for that will be added later. /// deal with XdY rolls. Support for that will be added later. Returns
pub fn parse_amounts(input: &str) -> ParseResult<Vec<Amount>> { /// parsed amounts and unconsumed input (e.g. roll modifiers).
pub fn parse_amounts(input: &str) -> ParseResult<(Vec<Amount>, &str)> {
let input = input.trim(); let input = input.trim();
let remaining_amounts = many(amount_parser()).map(|amounts: Vec<ParseResult<Amount>>| amounts); let remaining_amounts = many(amount_parser()).map(|amounts: Vec<ParseResult<Amount>>| amounts);
@ -169,31 +170,23 @@ pub fn parse_amounts(input: &str) -> ParseResult<Vec<Amount>> {
(amounts, results.1) (amounts, results.1)
})?; })?;
if rest.len() == 0 { // Any ParseResult errors will short-circuit the collect.
// Any ParseResult errors will short-circuit the collect. let results: Vec<Amount> = results.into_iter().collect::<ParseResult<_>>()?;
results.into_iter().collect() Ok((results, rest))
} else {
Err(DiceParsingError::UnconsumedInput)
}
} }
/// Parse an expression that expects a single number or variable. No /// Parse an expression that expects a single number or variable. No
/// operators are allowed. This function is common to systems that /// operators are allowed. This function is common to systems that
/// don't deal with XdY rolls. Currently. this function does not /// don't deal with XdY rolls. Currently. this function does not
/// support parsing negative numbers. /// support parsing negative numbers. Returns the parsed amount and
pub fn parse_single_amount(input: &str) -> ParseResult<Amount> { /// any unconsumed input (useful for dice roll modifiers).
pub fn parse_single_amount(input: &str) -> ParseResult<(Amount, &str)> {
// TODO add support for negative numbers, as technically they // TODO add support for negative numbers, as technically they
// should be allowed. // should be allowed.
let input = input.trim(); let input = input.trim();
let mut parser = first_amount_parser().map(|amount: ParseResult<Amount>| amount); let mut parser = first_amount_parser().map(|amount: ParseResult<Amount>| amount);
let (result, rest) = parser.parse(input)?; let (result, rest) = parser.parse(input)?;
Ok((result?, rest))
if rest.len() == 0 {
result
} else {
Err(DiceParsingError::UnconsumedInput)
}
} }
#[cfg(test)] #[cfg(test)]
@ -206,10 +199,13 @@ mod parse_single_amount_tests {
assert!(result.is_ok()); assert!(result.is_ok());
assert_eq!( assert_eq!(
result.unwrap(), result.unwrap(),
Amount { (
operator: Operator::Plus, Amount {
element: Element::Variable("abc".to_string()) operator: Operator::Plus,
} element: Element::Variable("abc".to_string())
},
""
)
) )
} }
@ -233,24 +229,15 @@ mod parse_single_amount_tests {
assert!(result.is_ok()); assert!(result.is_ok());
assert_eq!( assert_eq!(
result.unwrap(), result.unwrap(),
Amount { (
operator: Operator::Plus, Amount {
element: Element::Number(1) operator: Operator::Plus,
} element: Element::Number(1)
},
""
)
) )
} }
#[test]
fn parse_multiple_elements_test() {
let result = parse_single_amount("1+abc");
assert!(result.is_err());
let result = parse_single_amount("abc+1");
assert!(result.is_err());
let result = parse_single_amount("-1-abc");
assert!(result.is_err());
}
} }
#[cfg(test)] #[cfg(test)]
@ -263,20 +250,26 @@ mod parse_many_amounts_tests {
assert!(result.is_ok()); assert!(result.is_ok());
assert_eq!( assert_eq!(
result.unwrap(), result.unwrap(),
vec![Amount { (
operator: Operator::Plus, vec![Amount {
element: Element::Number(1) operator: Operator::Plus,
}] element: Element::Number(1)
}],
""
)
); );
let result = parse_amounts("10"); let result = parse_amounts("10");
assert!(result.is_ok()); assert!(result.is_ok());
assert_eq!( assert_eq!(
result.unwrap(), result.unwrap(),
vec![Amount { (
operator: Operator::Plus, vec![Amount {
element: Element::Number(10) operator: Operator::Plus,
}] element: Element::Number(10)
}],
""
)
); );
} }
@ -295,20 +288,26 @@ mod parse_many_amounts_tests {
assert!(result.is_ok()); assert!(result.is_ok());
assert_eq!( assert_eq!(
result.unwrap(), result.unwrap(),
vec![Amount { (
operator: Operator::Plus, vec![Amount {
element: Element::Variable("asdf".to_string()) operator: Operator::Plus,
}] element: Element::Variable("asdf".to_string())
}],
""
)
); );
let result = parse_amounts("nosis"); let result = parse_amounts("nosis");
assert!(result.is_ok()); assert!(result.is_ok());
assert_eq!( assert_eq!(
result.unwrap(), result.unwrap(),
vec![Amount { (
operator: Operator::Plus, vec![Amount {
element: Element::Variable("nosis".to_string()) operator: Operator::Plus,
}] element: Element::Variable("nosis".to_string())
}],
""
)
); );
} }

50
dicebot/src/rpc/mod.rs Normal file
View File

@ -0,0 +1,50 @@
use crate::error::BotError;
use crate::{config::Config, db::sqlite::Database};
use log::{info, warn};
use matrix_sdk::Client;
use service::DicebotRpcService;
use std::sync::Arc;
use tenebrous_rpc::protos::dicebot::dicebot_server::DicebotServer;
use tonic::{metadata::MetadataValue, transport::Server, Request, Status};
pub(crate) mod service;
pub async fn serve_grpc(
config: &Arc<Config>,
db: &Database,
client: &Client,
) -> Result<(), BotError> {
match config.rpc_addr().zip(config.rpc_key()) {
Some((addr, rpc_key)) => {
let expected_bearer = MetadataValue::from_str(&format!("Bearer {}", rpc_key))?;
let addr = addr.parse()?;
let rpc_service = DicebotRpcService {
db: db.clone(),
config: config.clone(),
client: client.clone(),
};
info!("Serving Dicebot gRPC service on {}", addr);
let interceptor = move |req: Request<()>| match req.metadata().get("authorization") {
Some(bearer) if bearer == expected_bearer => Ok(req),
_ => Err(Status::unauthenticated("No valid auth token")),
};
let server = DicebotServer::with_interceptor(rpc_service, interceptor);
Server::builder()
.add_service(server)
.serve(addr)
.await
.map_err(|e| e.into())
}
_ => noop().await,
}
}
pub async fn noop() -> Result<(), BotError> {
warn!("RPC address or shared secret not specified. Not enabling gRPC.");
Ok(())
}

117
dicebot/src/rpc/service.rs Normal file
View File

@ -0,0 +1,117 @@
use crate::db::{errors::DataError, Variables};
use crate::error::BotError;
use crate::matrix;
use crate::{config::Config, db::sqlite::Database};
use futures::stream;
use futures::{StreamExt, TryFutureExt, TryStreamExt};
use matrix_sdk::ruma::OwnedUserId;
use matrix_sdk::{room::Joined, Client};
use std::convert::TryFrom;
use std::sync::Arc;
use tenebrous_rpc::protos::dicebot::{
dicebot_server::Dicebot, rooms_list_reply::Room, GetAllVariablesReply, GetAllVariablesRequest,
RoomsListReply, SetVariableReply, SetVariableRequest, UserIdRequest,
};
use tenebrous_rpc::protos::dicebot::{GetVariableReply, GetVariableRequest};
use tonic::{Code, Request, Response, Status};
impl From<BotError> for Status {
fn from(error: BotError) -> Status {
Status::new(Code::Internal, error.to_string())
}
}
impl From<DataError> for Status {
fn from(error: DataError) -> Status {
Status::new(Code::Internal, error.to_string())
}
}
#[derive(Clone)]
pub(super) struct DicebotRpcService {
pub(super) config: Arc<Config>,
pub(super) db: Database,
pub(super) client: Client,
}
#[tonic::async_trait]
impl Dicebot for DicebotRpcService {
async fn set_variable(
&self,
request: Request<SetVariableRequest>,
) -> Result<Response<SetVariableReply>, Status> {
let SetVariableRequest {
user_id,
room_id,
variable_name,
value,
} = request.into_inner();
self.db
.set_user_variable(&user_id, &room_id, &variable_name, value)
.await?;
Ok(Response::new(SetVariableReply { success: true }))
}
async fn get_variable(
&self,
request: Request<GetVariableRequest>,
) -> Result<Response<GetVariableReply>, Status> {
let request = request.into_inner();
let value = self
.db
.get_user_variable(&request.user_id, &request.room_id, &request.variable_name)
.await?;
Ok(Response::new(GetVariableReply { value }))
}
async fn get_all_variables(
&self,
request: Request<GetAllVariablesRequest>,
) -> Result<Response<GetAllVariablesReply>, Status> {
let request = request.into_inner();
let variables = self
.db
.get_user_variables(&request.user_id, &request.room_id)
.await?;
Ok(Response::new(GetAllVariablesReply { variables }))
}
async fn rooms_for_user(
&self,
request: Request<UserIdRequest>,
) -> Result<Response<RoomsListReply>, Status> {
let UserIdRequest { user_id } = request.into_inner();
let user_id = OwnedUserId::try_from(user_id).map_err(BotError::from)?;
let rooms_for_user = matrix::get_rooms_for_user(&self.client, &user_id)
.err_into::<BotError>()
.await?;
let mut rooms: Vec<Room> = stream::iter(rooms_for_user)
.filter_map(|room: Joined| async move {
let room: Result<Room, _> = room.display_name().await.map(|room_name| Room {
room_id: room.room_id().to_string(),
display_name: room_name.to_string(),
});
Some(room)
})
.err_into::<BotError>()
.try_collect()
.await?;
let sort = |r1: &Room, r2: &Room| {
r1.display_name
.to_lowercase()
.cmp(&r2.display_name.to_lowercase())
};
rooms.sort_by(sort);
Ok(Response::new(RoomsListReply { rooms }))
}
}

View File

@ -0,0 +1,21 @@
use strum::{AsRefStr, Display, EnumIter, EnumString};
#[derive(EnumString, EnumIter, AsRefStr, Display)]
pub(crate) enum GameSystem {
ChroniclesOfDarkness,
Changeling,
MageTheAwakening,
WerewolfTheForsaken,
DeviantTheRenegades,
MummyTheCurse,
PrometheanTheCreated,
CallOfCthulhu,
DungeonsAndDragons5e,
DungeonsAndDragons4e,
DungeonsAndDragons35e,
DungeonsAndDragons2e,
DungeonsAndDragons1e,
None,
}
impl GameSystem {}

18
rpc/Cargo.toml Normal file
View File

@ -0,0 +1,18 @@
[package]
name = "tenebrous-rpc"
version = "0.1.0"
authors = ["projectmoon <projectmoon@agnos.is>"]
edition = "2018"
description = "gRPC protobuf models for Tenebrous."
homepage = "https://git.agnos.is/projectmoon/tenebrous-dicebot"
repository = "https://git.agnos.is/projectmoon/tenebrous-dicebot"
license = "AGPL-3.0-or-later"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[build-dependencies]
tonic-build = "0.4"
[dependencies]
tonic = "0.4"
prost = "0.7"

4
rpc/build.rs Normal file
View File

@ -0,0 +1,4 @@
fn main() -> Result<(), Box<dyn std::error::Error>> {
tonic_build::compile_protos("protos/dicebot.proto")?;
Ok(())
}

52
rpc/protos/dicebot.proto Normal file
View File

@ -0,0 +1,52 @@
syntax = "proto3";
package dicebot;
service Dicebot {
rpc GetVariable(GetVariableRequest) returns (GetVariableReply);
rpc GetAllVariables(GetAllVariablesRequest) returns (GetAllVariablesReply);
rpc SetVariable(SetVariableRequest) returns (SetVariableReply);
rpc RoomsForUser(UserIdRequest) returns (RoomsListReply);
}
message GetVariableRequest {
string user_id = 1;
string room_id = 2;
string variable_name = 3;
}
message GetVariableReply {
int32 value = 1;
}
message GetAllVariablesRequest {
string user_id = 1;
string room_id = 2;
}
message GetAllVariablesReply {
map<string, int32> variables = 1;
}
message SetVariableRequest {
string user_id = 1;
string room_id = 2;
string variable_name = 3;
int32 value = 4;
}
message SetVariableReply {
bool success = 1;
}
message UserIdRequest {
string user_id = 1;
}
message RoomsListReply {
message Room {
string room_id = 1;
string display_name = 2;
}
repeated Room rooms = 1;
}

5
rpc/src/lib.rs Normal file
View File

@ -0,0 +1,5 @@
pub mod protos {
pub mod dicebot {
tonic::include_proto!("dicebot");
}
}

View File

@ -1,189 +0,0 @@
/**
* In addition to the terms of the AGPL, this file is governed by the
* terms of the MIT license, from the original axfive-matrix-dicebot
* project.
*/
use nom::bytes::complete::take_while;
use nom::{
alt, bytes::complete::tag, character::complete::digit1, complete, many0, named,
sequence::tuple, tag, IResult,
};
use super::dice::*;
//******************************
//Legacy Code
//******************************
fn is_whitespace(input: char) -> bool {
input == ' ' || input == '\n' || input == '\t' || input == '\r'
}
/// Eat whitespace, returning it
pub fn eat_whitespace(input: &str) -> IResult<&str, &str> {
let (input, whitespace) = take_while(is_whitespace)(input)?;
Ok((input, whitespace))
}
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
enum Sign {
Plus,
Minus,
}
// Parse a dice expression. Does not eat whitespace
fn parse_dice(input: &str) -> IResult<&str, Dice> {
let (input, (count, _, sides)) = tuple((digit1, tag("d"), digit1))(input)?;
Ok((
input,
Dice::new(count.parse().unwrap(), sides.parse().unwrap()),
))
}
// Parse a single digit expression. Does not eat whitespace
fn parse_bonus(input: &str) -> IResult<&str, u32> {
let (input, bonus) = digit1(input)?;
Ok((input, bonus.parse().unwrap()))
}
// Parse a sign expression. Eats whitespace.
fn parse_sign(input: &str) -> IResult<&str, Sign> {
let (input, _) = eat_whitespace(input)?;
named!(sign(&str) -> Sign, alt!(
complete!(tag!("+")) => { |_| Sign::Plus } |
complete!(tag!("-")) => { |_| Sign::Minus }
));
let (input, sign) = sign(input)?;
Ok((input, sign))
}
// Parse an element expression. Eats whitespace.
fn parse_element(input: &str) -> IResult<&str, Element> {
let (input, _) = eat_whitespace(input)?;
named!(element(&str) -> Element, alt!(
parse_dice => { |d| Element::Dice(d) } |
parse_bonus => { |b| Element::Bonus(b) }
));
let (input, element) = element(input)?;
Ok((input, element))
}
// Parse a signed element expression. Eats whitespace.
fn parse_signed_element(input: &str) -> IResult<&str, SignedElement> {
let (input, _) = eat_whitespace(input)?;
let (input, sign) = parse_sign(input)?;
let (input, _) = eat_whitespace(input)?;
let (input, element) = parse_element(input)?;
let element = match sign {
Sign::Plus => SignedElement::Positive(element),
Sign::Minus => SignedElement::Negative(element),
};
Ok((input, element))
}
// Parse a full element expression. Eats whitespace.
pub fn parse_element_expression(input: &str) -> IResult<&str, ElementExpression> {
named!(first_element(&str) -> SignedElement, alt!(
parse_signed_element => { |e| e } |
parse_element => { |e| SignedElement::Positive(e) }
));
let (input, first) = first_element(input)?;
let (input, rest) = if input.trim().is_empty() {
(input, vec![first])
} else {
named!(rest_elements(&str) -> Vec<SignedElement>, many0!(parse_signed_element));
let (input, mut rest) = rest_elements(input)?;
rest.insert(0, first);
(input, rest)
};
Ok((input, ElementExpression(rest)))
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn dice_test() {
assert_eq!(parse_dice("2d4"), Ok(("", Dice::new(2, 4))));
assert_eq!(parse_dice("20d40"), Ok(("", Dice::new(20, 40))));
assert_eq!(parse_dice("8d7"), Ok(("", Dice::new(8, 7))));
}
#[test]
fn element_test() {
assert_eq!(
parse_element(" \t\n\r\n 8d7 \n"),
Ok((" \n", Element::Dice(Dice::new(8, 7))))
);
assert_eq!(
parse_element(" \t\n\r\n 8 \n"),
Ok((" \n", Element::Bonus(8)))
);
}
#[test]
fn signed_element_test() {
assert_eq!(
parse_signed_element("+ 7"),
Ok(("", SignedElement::Positive(Element::Bonus(7))))
);
assert_eq!(
parse_signed_element(" \t\n\r\n- 8 \n"),
Ok((" \n", SignedElement::Negative(Element::Bonus(8))))
);
assert_eq!(
parse_signed_element(" \t\n\r\n- 8d4 \n"),
Ok((
" \n",
SignedElement::Negative(Element::Dice(Dice::new(8, 4)))
))
);
assert_eq!(
parse_signed_element(" \t\n\r\n+ 8d4 \n"),
Ok((
" \n",
SignedElement::Positive(Element::Dice(Dice::new(8, 4)))
))
);
}
#[test]
fn element_expression_test() {
assert_eq!(
parse_element_expression("8d4"),
Ok((
"",
ElementExpression(vec![SignedElement::Positive(Element::Dice(Dice::new(
8, 4
)))])
))
);
assert_eq!(
parse_element_expression(" - 8d4 \n "),
Ok((
" \n ",
ElementExpression(vec![SignedElement::Negative(Element::Dice(Dice::new(
8, 4
)))])
))
);
assert_eq!(
parse_element_expression("\t3d4 + 7 - 5 - 6d12 + 1d1 + 53 1d5 "),
Ok((
" 1d5 ",
ElementExpression(vec![
SignedElement::Positive(Element::Dice(Dice::new(3, 4))),
SignedElement::Positive(Element::Bonus(7)),
SignedElement::Negative(Element::Bonus(5)),
SignedElement::Negative(Element::Dice(Dice::new(6, 12))),
SignedElement::Positive(Element::Dice(Dice::new(1, 1))),
SignedElement::Positive(Element::Bonus(53)),
])
))
);
}
}

View File

@ -1,158 +0,0 @@
use super::DiceBot;
use crate::db::sqlite::Database;
use crate::db::Rooms;
use crate::error::BotError;
use async_trait::async_trait;
use log::{debug, error, info, warn};
use matrix_sdk::{
self,
events::{
room::member::MemberEventContent,
room::message::{MessageEventContent, MessageType, TextMessageEventContent},
StrippedStateEvent, SyncMessageEvent,
},
room::Room,
EventHandler,
};
use std::ops::Sub;
use std::time::{Duration, SystemTime};
use std::{clone::Clone, time::UNIX_EPOCH};
/// Check if a message is recent enough to actually process. If the
/// message is within "oldest_message_age" seconds, this function
/// returns true. If it's older than that, it returns false and logs a
/// debug message.
fn check_message_age(
event: &SyncMessageEvent<MessageEventContent>,
oldest_message_age: u64,
) -> bool {
let sending_time = event
.origin_server_ts
.to_system_time()
.unwrap_or(UNIX_EPOCH);
let oldest_timestamp = SystemTime::now().sub(Duration::from_secs(oldest_message_age));
if sending_time > oldest_timestamp {
true
} else {
let age = match oldest_timestamp.duration_since(sending_time) {
Ok(n) => format!("{} seconds too old", n.as_secs()),
Err(_) => "before the UNIX epoch".to_owned(),
};
debug!("Ignoring message because it is {}: {:?}", age, event);
false
}
}
/// Determine whether or not to process a received message. This check
/// is necessary in addition to the event processing check because we
/// may receive message events when entering a room for the first
/// time, and we don't want to respond to things before the bot was in
/// the channel, but we do want to respond to things that were sent if
/// the bot left and rejoined quickly.
async fn should_process_message<'a>(
bot: &DiceBot,
event: &SyncMessageEvent<MessageEventContent>,
) -> Result<(String, String), BotError> {
//Ignore messages that are older than configured duration.
if !check_message_age(event, bot.config.oldest_message_age()) {
let state_check = bot.state.read().unwrap();
if !((*state_check).logged_skipped_old_messages()) {
drop(state_check);
let mut state = bot.state.write().unwrap();
(*state).skipped_old_messages();
}
return Err(BotError::ShouldNotProcessError);
}
let (msg_body, sender_username) = if let SyncMessageEvent {
content:
MessageEventContent {
msgtype: MessageType::Text(TextMessageEventContent { body, .. }),
..
},
sender,
..
} = event
{
(
body.clone(),
format!("@{}:{}", sender.localpart(), sender.server_name()),
)
} else {
(String::new(), String::new())
};
Ok((msg_body, sender_username))
}
async fn should_process_event(db: &Database, room_id: &str, event_id: &str) -> bool {
db.should_process(room_id, event_id)
.await
.unwrap_or_else(|e| {
error!(
"Database error when checking if we should process an event: {}",
e.to_string()
);
false
})
}
/// This event emitter listens for messages with dice rolling commands.
/// Originally adapted from the matrix-rust-sdk examples.
#[async_trait]
impl EventHandler for DiceBot {
async fn on_stripped_state_member(
&self,
room: Room,
event: &StrippedStateEvent<MemberEventContent>,
_: Option<MemberEventContent>,
) {
let room = match room {
Room::Invited(invited_room) => invited_room,
_ => return,
};
if room.own_user_id().as_str() != event.state_key {
return;
}
info!(
"Autojoining room {}",
room.display_name().await.ok().unwrap_or_default()
);
if let Err(e) = self.client.join_room_by_id(&room.room_id()).await {
warn!("Could not join room: {}", e.to_string())
}
}
async fn on_room_message(&self, room: Room, event: &SyncMessageEvent<MessageEventContent>) {
let room = match room {
Room::Joined(joined_room) => joined_room,
_ => return,
};
let room_id = room.room_id().as_str();
if !should_process_event(&self.db, room_id, event.event_id.as_str()).await {
return;
}
let (msg_body, sender_username) =
if let Ok((msg_body, sender_username)) = should_process_message(self, &event).await {
(msg_body, sender_username)
} else {
return;
};
let results = self
.execute_commands(&room, &sender_username, &msg_body)
.await;
self.handle_results(&room, &sender_username, event.event_id.clone(), results)
.await;
}
}

View File

@ -1,2 +0,0 @@
use refinery::include_migration_mods;
include_migration_mods!("src/db/sqlite/migrator/migrations");

View File

@ -1,341 +0,0 @@
use super::Database;
use crate::db::{errors::DataError, Users};
use crate::error::BotError;
use crate::models::User;
use async_trait::async_trait;
#[async_trait]
impl Users for Database {
async fn upsert_user(&self, user: &User) -> Result<(), DataError> {
let mut tx = self.conn.begin().await?;
sqlx::query!(
r#"INSERT INTO accounts (user_id, password, account_status)
VALUES (?, ?, ?)
ON CONFLICT(user_id) DO
UPDATE SET password = ?, account_status = ?"#,
user.username,
user.password,
user.account_status,
user.password,
user.account_status
)
.execute(&mut tx)
.await?;
sqlx::query!(
r#"INSERT INTO user_state (user_id, active_room)
VALUES (?, ?)
ON CONFLICT(user_id) DO
UPDATE SET active_room = ?"#,
user.username,
user.active_room,
user.active_room
)
.execute(&mut tx)
.await?;
tx.commit().await?;
Ok(())
}
async fn delete_user(&self, username: &str) -> Result<(), DataError> {
let mut tx = self.conn.begin().await?;
sqlx::query!(r#"DELETE FROM accounts WHERE user_id = ?"#, username)
.execute(&mut tx)
.await?;
sqlx::query!(r#"DELETE FROM user_state WHERE user_id = ?"#, username)
.execute(&mut tx)
.await?;
tx.commit().await?;
Ok(())
}
async fn get_user(&self, username: &str) -> Result<Option<User>, DataError> {
// Should be query_as! macro, but the left join breaks it with a
// non existing error message.
let user_row: Option<User> = sqlx::query_as(
r#"SELECT
a.user_id as "username",
a.password,
s.active_room,
COALESCE(a.account_status, 'not_registered') as "account_status"
FROM accounts a
LEFT JOIN user_state s on a.user_id = s.user_id
WHERE a.user_id = ?"#,
)
.bind(username)
.fetch_optional(&self.conn)
.await?;
Ok(user_row)
}
async fn authenticate_user(
&self,
username: &str,
raw_password: &str,
) -> Result<Option<User>, BotError> {
let user = self.get_user(username).await?;
Ok(user.filter(|u| u.verify_password(raw_password)))
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::db::sqlite::Database;
use crate::db::Users;
use crate::models::AccountStatus;
async fn create_db() -> Database {
let db_path = tempfile::NamedTempFile::new_in(".").unwrap();
crate::db::sqlite::migrator::migrate(db_path.path().to_str().unwrap())
.await
.unwrap();
Database::new(db_path.path().to_str().unwrap())
.await
.unwrap()
}
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn create_and_get_full_user_test() {
let db = create_db().await;
let insert_result = db
.upsert_user(&User {
username: "myuser".to_string(),
password: Some("abc".to_string()),
account_status: AccountStatus::Registered,
active_room: Some("myroom".to_string()),
})
.await;
assert!(insert_result.is_ok());
let user = db
.get_user("myuser")
.await
.expect("User retrieval query failed");
assert!(user.is_some());
let user = user.unwrap();
assert_eq!(user.username, "myuser");
assert_eq!(user.password, Some("abc".to_string()));
assert_eq!(user.account_status, AccountStatus::Registered);
assert_eq!(user.active_room, Some("myroom".to_string()));
}
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn can_get_user_with_no_state_record() {
let db = create_db().await;
let insert_result = db
.upsert_user(&User {
username: "myuser".to_string(),
password: Some("abc".to_string()),
account_status: AccountStatus::AwaitingActivation,
active_room: Some("myroom".to_string()),
})
.await;
assert!(insert_result.is_ok());
sqlx::query("DELETE FROM user_state")
.execute(&db.conn)
.await
.expect("Could not delete from user_state table.");
let user = db
.get_user("myuser")
.await
.expect("User retrieval query failed");
assert!(user.is_some());
let user = user.unwrap();
assert_eq!(user.username, "myuser");
assert_eq!(user.password, Some("abc".to_string()));
assert_eq!(user.account_status, AccountStatus::AwaitingActivation);
//These should be default values because the state record is missing.
assert_eq!(user.active_room, None);
}
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn can_insert_without_password() {
let db = create_db().await;
let insert_result = db
.upsert_user(&User {
username: "myuser".to_string(),
password: None,
..Default::default()
})
.await;
assert!(insert_result.is_ok());
let user = db
.get_user("myuser")
.await
.expect("User retrieval query failed");
assert!(user.is_some());
let user = user.unwrap();
assert_eq!(user.username, "myuser");
assert_eq!(user.password, None);
}
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn can_insert_without_active_room() {
let db = create_db().await;
let insert_result = db
.upsert_user(&User {
username: "myuser".to_string(),
active_room: None,
..Default::default()
})
.await;
assert!(insert_result.is_ok());
let user = db
.get_user("myuser")
.await
.expect("User retrieval query failed");
assert!(user.is_some());
let user = user.unwrap();
assert_eq!(user.username, "myuser");
assert_eq!(user.active_room, None);
}
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn can_update_user() {
let db = create_db().await;
let insert_result1 = db
.upsert_user(&User {
username: "myuser".to_string(),
password: Some("abc".to_string()),
..Default::default()
})
.await;
assert!(insert_result1.is_ok());
let insert_result2 = db
.upsert_user(&User {
username: "myuser".to_string(),
password: Some("123".to_string()),
active_room: Some("room".to_string()),
account_status: AccountStatus::AwaitingActivation,
})
.await;
assert!(insert_result2.is_ok());
let user = db
.get_user("myuser")
.await
.expect("User retrieval query failed");
assert!(user.is_some());
let user = user.unwrap();
assert_eq!(user.username, "myuser");
//From second upsert
assert_eq!(user.password, Some("123".to_string()));
assert_eq!(user.active_room, Some("room".to_string()));
assert_eq!(user.account_status, AccountStatus::AwaitingActivation);
}
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn can_delete_user() {
let db = create_db().await;
let insert_result = db
.upsert_user(&User {
username: "myuser".to_string(),
password: Some("abc".to_string()),
..Default::default()
})
.await;
assert!(insert_result.is_ok());
db.delete_user("myuser")
.await
.expect("User deletion query failed");
let user = db
.get_user("myuser")
.await
.expect("User retrieval query failed");
assert!(user.is_none());
}
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn username_not_in_db_returns_none() {
let db = create_db().await;
let user = db
.get_user("does not exist")
.await
.expect("Get user query failure");
assert!(user.is_none());
}
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn authenticate_user_is_some_with_valid_password() {
let db = create_db().await;
let insert_result = db
.upsert_user(&User {
username: "myuser".to_string(),
password: Some(crate::logic::hash_password("abc").expect("password hash error!")),
..Default::default()
})
.await;
assert!(insert_result.is_ok());
let user = db
.authenticate_user("myuser", "abc")
.await
.expect("User retrieval query failed");
assert!(user.is_some());
let user = user.unwrap();
assert_eq!(user.username, "myuser");
}
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn authenticate_user_is_none_with_wrong_password() {
let db = create_db().await;
let insert_result = db
.upsert_user(&User {
username: "myuser".to_string(),
password: Some(crate::logic::hash_password("abc").expect("password hash error!")),
..Default::default()
})
.await;
assert!(insert_result.is_ok());
let user = db
.authenticate_user("myuser", "wrong-password")
.await
.expect("User retrieval query failed");
assert!(user.is_none());
}
}