Compare commits

..

No commits in common. "master" and "set-room-command" have entirely different histories.

75 changed files with 2417 additions and 4263 deletions

View File

@ -3,20 +3,18 @@ name: build-and-test
steps: steps:
- name: test - name: test
image: rust:1.80 image: rust:1.51
commands: commands:
- apt-get update - apt-get update
- apt-get install -y cmake - apt-get install -y cmake
- rustup component add rustfmt
- cargo build --verbose --all - cargo build --verbose --all
- cargo test --verbose --all - cargo test --verbose --all
- name: docker - name: docker
image: plugins/docker image: plugins/docker
when: when:
ref: branch:
- refs/tags/v* - master
- refs/heads/master
settings: settings:
auto_tag: true auto_tag: true
username: username:

3198
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -1,6 +1,47 @@
[workspace] [package]
name = "tenebrous-dicebot"
version = "0.10.0"
authors = ["Taylor C. Richberger <taywee@gmx.com>", "projectmoon <projectmoon@agnos.is>"]
edition = "2018"
license = 'AGPL-3.0-or-later'
description = 'An async Matrix dice bot for role-playing games'
readme = 'README.md'
repository = 'https://git.agnos.is/projectmoon/matrix-dicebot'
keywords = ["games", "dice", "matrix", "bot"]
categories = ["games"]
members = [ [dependencies]
"dicebot", log = "0.4"
"rpc" tracing-subscriber = "0.2"
] toml = "0.5"
nom = "5"
rand = "0.8"
rust-argon2 = "0.8"
thiserror = "1.0"
itertools = "0.10"
async-trait = "0.1"
url = "2.1"
dirs = "3.0"
indoc = "1.0"
combine = "4.5"
futures = "0.3"
html2text = "0.2"
phf = { version = "0.8", features = ["macros"] }
matrix-sdk = { git = "https://github.com/matrix-org/matrix-rust-sdk", branch = "master" }
refinery = { version = "0.5", features = ["rusqlite"]}
barrel = { version = "0.6", features = ["sqlite3"] }
tempfile = "3"
substring = "1.4"
fuse-rust = "0.2"
[dependencies.sqlx]
version = "0.5"
features = [ "offline", "sqlite", "runtime-tokio-native-tls" ]
[dependencies.serde]
version = "1"
features = ['derive']
[dependencies.tokio]
version = "1"
features = [ "full" ]

View File

@ -1,15 +1,16 @@
# Builder image with development dependencies. # Builder image with development dependencies.
FROM ghcr.io/void-linux/void-linux:latest-mini-x86_64 as builder FROM bougyman/voidlinux:glibc as builder
RUN xbps-install -S
RUN xbps-install -yu xbps
RUN xbps-install -Syu RUN xbps-install -Syu
RUN xbps-install -Sy base-devel rustup cmake wget gnupg RUN xbps-install -Sy base-devel rustup cargo cmake wget gnupg
RUN xbps-install -Sy openssl-devel libstdc++-devel RUN xbps-install -Sy openssl-devel libstdc++-devel
RUN rustup-init -qy RUN rustup-init -qy
# Install tini for signal processing and zombie killing # Install tini for signal processing and zombie killing
ENV TINI_VERSION v0.19.0 ENV TINI_VERSION v0.19.0
ADD https://github.com/krallin/tini/releases/download/${TINI_VERSION}/tini /usr/local/bin/tini ADD https://github.com/krallin/tini/releases/download/${TINI_VERSION}/tini /usr/local/bin/tini
ADD https://github.com/krallin/tini/releases/download/${TINI_VERSION}/tini.asc /tini.asc
RUN gpg --batch --keyserver hkp://p80.pool.sks-keyservers.net:80 --recv-keys 595E85A6B1B4779EA4DAAEC70B588DFF0527A9B7 \
&& gpg --batch --verify /tini.asc /usr/local/bin/tini
RUN chmod +x /usr/local/bin/tini RUN chmod +x /usr/local/bin/tini
# Build dicebot # Build dicebot
@ -19,10 +20,7 @@ ADD . ./
RUN . /root/.cargo/env && cargo build --release RUN . /root/.cargo/env && cargo build --release
# Final image # Final image
FROM ghcr.io/void-linux/void-linux:latest-mini-x86_64 FROM bougyman/voidlinux:tiny
RUN xbps-install -S
RUN xbps-install -yu xbps
RUN xbps-install -Syu
RUN xbps-install -Sy ca-certificates libstdc++ RUN xbps-install -Sy ca-certificates libstdc++
COPY --from=builder \ COPY --from=builder \
/root/src/target/release/dicebot \ /root/src/target/release/dicebot \

View File

@ -1,7 +1,6 @@
# Tenebrous Dicebot # Tenebrous Dicebot
[![Build Status](https://drone.agnos.is/api/badges/projectmoon/tenebrous-dicebot/status.svg)](https://drone.agnos.is/projectmoon/tenebrous-dicebot) [![Build Status](https://drone.agnos.is/api/badges/projectmoon/tenebrous-dicebot/status.svg)](https://drone.agnos.is/projectmoon/tenebrous-dicebot)
[![Matrix Chat](https://img.shields.io/matrix/tenebrous:agnos.is?label=matrix&server_fqdn=matrix.org)][matrix-room]
_This repository is hosted on [Agnos.is Git][main-repo] and mirrored _This repository is hosted on [Agnos.is Git][main-repo] and mirrored
to [GitHub][github-repo]._ to [GitHub][github-repo]._
@ -25,23 +24,6 @@ System.
* Works in encrypted or unencrypted Matrix rooms. * Works in encrypted or unencrypted Matrix rooms.
* Storing variables created by the user. * Storing variables created by the user.
## Support and Community
The project has a Matrix room at [#tenebrous:agnos.is][matrix-room].
It is also possible to make a post in [GitHub
Discussions][github-discussions].
For reporting bugs, we prefer that you open an issue on
[git.agnos.is][agnosis-git-issues]. However, you may also open an
issue on [GitHub][github-issues].
### Development and Contributions
All development occurs on [git.agnos.is][main-repo]. If you wish to
contribute, please open a pull request there. In some cases, pull
requests from GitHub may be accepted. All contributions must be
licensed under [AGPL 3.0 or later][agpl] to be accepted.
## Building and Installation ## Building and Installation
### Docker Image ### Docker Image
@ -64,17 +46,6 @@ root of the repository.
After pulling or building the image, see [instructions on how to use After pulling or building the image, see [instructions on how to use
the Docker image](#running-the-bot). the Docker image](#running-the-bot).
### Install from crates.io
The project can be from [crates.io][crates-io]. To install it, execute
`cargo install tenebrous-dicebot`. This will make the following
executables available on your system:
* `dicebot`: Main dicebot executable.
* `dicebot-cmd`: Run dicebot commands from the command line.
* `dicebot_migrate`: Standalone database migrator (not required).
* `tonic_client`: Test client for the gRPC connection (not required).
### Build from Source ### Build from Source
Precompiled executables are not yet available. Clone this repository Precompiled executables are not yet available. Clone this repository
@ -118,16 +89,8 @@ expressions.
!r 3d12 - 5d2 + 3 - 7d3 + 20d20 !r 3d12 - 5d2 + 3 - 7d3 + 20d20
``` ```
#### Keep/Drop Dice This system does not yet have the capability to handle things like D&D
The bot supports either keeping the highest dice in a roll, or 5e advantage or disadvantage.
dropping the highest dice in a roll. This allows the bot to handle
things like D&D 5e advantage or disadvantage.
```
!roll 2d20k1
!r 2d20dh1 + 5
!r 10d10k5 + 10d10dh5 - 2
```
### Storytelling System ### Storytelling System
@ -278,7 +241,6 @@ The most basic plans are:
* Perhaps some sort of character sheet integration. But for that, we * Perhaps some sort of character sheet integration. But for that, we
would need a sheet service. would need a sheet service.
* Use environment variables instead of config file in Docker image. * Use environment variables instead of config file in Docker image.
* Per-system game rules.
## Credits ## Credits
@ -292,9 +254,3 @@ support added for Chronicles of Darkness and Call of Cthulhu.
[main-repo]: https://git.agnos.is/projectmoon/tenebrous-dicebot [main-repo]: https://git.agnos.is/projectmoon/tenebrous-dicebot
[github-repo]: https://github.com/ProjectMoon/matrix-dicebot [github-repo]: https://github.com/ProjectMoon/matrix-dicebot
[roadmap]: https://git.agnos.is/projectmoon/tenebrous-dicebot/wiki/Roadmap [roadmap]: https://git.agnos.is/projectmoon/tenebrous-dicebot/wiki/Roadmap
[crates-io]: https://crates.io/crates/tenebrous-dicebot
[matrix-room]: https://matrix.to/#/#tenebrous:agnos.is
[agnosis-git-issues]: https://git.agnos.is/projectmoon/tenebrous-dicebot/issues
[github-discussions]: https://github.com/ProjectMoon/matrix-dicebot/discussions
[github-issues]: https://github.com/ProjectMoon/matrix-dicebot/issues
[agpl]: https://www.gnu.org/licenses/agpl-3.0.en.html

View File

@ -1,57 +0,0 @@
[package]
name = "tenebrous-dicebot"
version = "0.13.2"
rust-version = "1.68"
authors = ["projectmoon <projectmoon@agnos.is>", "Taylor C. Richberger <taywee@gmx.com>"]
edition = "2018"
license = 'AGPL-3.0-or-later'
description = 'An async Matrix dice bot for role-playing games'
readme = '../README.md'
repository = 'https://git.agnos.is/projectmoon/matrix-dicebot'
keywords = ["games", "dice", "matrix", "bot"]
categories = ["games"]
[build-dependencies]
tonic-build = "0.4"
[dependencies]
# indexmap version locked fixes a dependency cycle.
# indexmap = "=1.6.2"
log = "0.4"
tracing-subscriber = "0.2"
toml = "0.5"
nom = "5"
rand = "0.8"
rust-argon2 = "0.8"
thiserror = "1.0"
itertools = "0.10"
async-trait = "0.1"
url = "2.1"
dirs = "3.0"
indoc = "1.0"
combine = "4.5"
futures = "0.3"
html2text = "0.2"
phf = { version = "0.8", features = ["macros"] }
matrix-sdk = { version = "0.6" }
refinery = { version = "0.8", features = ["rusqlite"]}
barrel = { version = "0.7", features = ["sqlite3"] }
strum = { version = "0.22", features = ["derive"] }
tempfile = "3"
substring = "1.4"
fuse-rust = "0.2"
tonic = "0.4"
prost = "0.7"
tenebrous-rpc = { path = "../rpc", version = "0.1.0" }
[dependencies.sqlx]
version = "0.6"
features = [ "offline", "sqlite", "runtime-tokio-native-tls" ]
[dependencies.serde]
version = "1"
features = ['derive']
[dependencies.tokio]
version = "1"
features = [ "full" ]

View File

@ -1,360 +0,0 @@
/**
* In addition to the terms of the AGPL, this file is governed by the
* terms of the MIT license, from the original axfive-matrix-dicebot
* project.
*/
use nom::bytes::complete::take_while;
use nom::error::ErrorKind as NomErrorKind;
use nom::Err as NomErr;
use nom::{
alt, bytes::complete::tag, character::complete::digit1, complete, many0, named,
sequence::tuple, tag, IResult,
};
use super::dice::*;
//******************************
//Legacy Code
//******************************
fn is_whitespace(input: char) -> bool {
input == ' ' || input == '\n' || input == '\t' || input == '\r'
}
/// Eat whitespace, returning it
pub fn eat_whitespace(input: &str) -> IResult<&str, &str> {
let (input, whitespace) = take_while(is_whitespace)(input)?;
Ok((input, whitespace))
}
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
enum Sign {
Plus,
Minus,
}
/// Intermediate parsed value for a keep-drop expression to indicate
/// which one it is.
enum ParsedKeepOrDrop<'a> {
Keep(&'a str),
Drop(&'a str),
NotPresent,
}
macro_rules! too_big {
($input: expr) => {
NomErr::Error(($input, NomErrorKind::TooLarge))
};
}
/// Parse a dice expression. Does not eat whitespace
fn parse_dice(input: &str) -> IResult<&str, Dice> {
let (input, (count, _, sides)) = tuple((digit1, tag("d"), digit1))(input)?;
let count: u32 = count.parse().map_err(|_| too_big!(count))?;
let sides = sides.parse().map_err(|_| too_big!(sides))?;
let (input, keep_drop) = parse_keep_or_drop(input, count)?;
Ok((input, Dice::new(count, sides, keep_drop)))
}
/// Extract keep/drop number as a string. Fails if the value is not a
/// string.
fn parse_keep_or_drop_text<'a>(
symbol: &'a str,
input: &'a str,
) -> IResult<&'a str, ParsedKeepOrDrop<'a>> {
let (parsed_kd, input) = match tuple::<&str, _, (_, _), _>((tag(symbol), digit1))(input) {
// if ok, one of the expressions is present
Ok((rest, (_, kd_expr))) => match symbol {
"k" => (ParsedKeepOrDrop::Keep(kd_expr), rest),
"dh" => (ParsedKeepOrDrop::Drop(kd_expr), rest),
_ => panic!("Unrecogized keep-drop symbol: {}", symbol),
},
// otherwise absent (attempt to keep all dice)
Err(_) => (ParsedKeepOrDrop::NotPresent, input),
};
Ok((input, parsed_kd))
}
/// Parse keep/drop expression, which consits of "k" or "dh" following
/// a dice expression. For example, "1d4h3" or "1d4dh2".
fn parse_keep_or_drop<'a>(input: &'a str, count: u32) -> IResult<&'a str, KeepOrDrop> {
let (input, keep) = parse_keep_or_drop_text("k", input)?;
let (input, drop) = parse_keep_or_drop_text("dh", input)?;
use ParsedKeepOrDrop::*;
let keep_drop: KeepOrDrop = match (keep, drop) {
//Potential valid Keep expression.
(Keep(keep), NotPresent) => match keep.parse().map_err(|_| too_big!(input))? {
_i if _i > count || _i == 0 => Ok(KeepOrDrop::None),
i => Ok(KeepOrDrop::Keep(i)),
},
//Potential valid Drop expression.
(NotPresent, Drop(drop)) => match drop.parse().map_err(|_| too_big!(input))? {
_i if _i >= count => Ok(KeepOrDrop::None),
i => Ok(KeepOrDrop::Drop(i)),
},
//No Keep or Drop specified; regular behavior.
(NotPresent, NotPresent) => Ok(KeepOrDrop::None),
//Anything else is an error.
_ => Err(NomErr::Error((input, NomErrorKind::Many1))),
}?;
Ok((input, keep_drop))
}
// Parse a single digit expression. Does not eat whitespace
fn parse_bonus(input: &str) -> IResult<&str, u32> {
let (input, bonus) = digit1(input)?;
Ok((input, bonus.parse().unwrap()))
}
// Parse a sign expression. Eats whitespace.
fn parse_sign(input: &str) -> IResult<&str, Sign> {
let (input, _) = eat_whitespace(input)?;
named!(sign(&str) -> Sign, alt!(
complete!(tag!("+")) => { |_| Sign::Plus } |
complete!(tag!("-")) => { |_| Sign::Minus }
));
let (input, sign) = sign(input)?;
Ok((input, sign))
}
// Parse an element expression. Eats whitespace.
fn parse_element(input: &str) -> IResult<&str, Element> {
let (input, _) = eat_whitespace(input)?;
named!(element(&str) -> Element, alt!(
parse_dice => { |d| Element::Dice(d) } |
parse_bonus => { |b| Element::Bonus(b) }
));
let (input, element) = element(input)?;
Ok((input, element))
}
// Parse a signed element expression. Eats whitespace.
fn parse_signed_element(input: &str) -> IResult<&str, SignedElement> {
let (input, _) = eat_whitespace(input)?;
let (input, sign) = parse_sign(input)?;
let (input, _) = eat_whitespace(input)?;
let (input, element) = parse_element(input)?;
let element = match sign {
Sign::Plus => SignedElement::Positive(element),
Sign::Minus => SignedElement::Negative(element),
};
Ok((input, element))
}
// Parse a full element expression. Eats whitespace.
pub fn parse_element_expression(input: &str) -> IResult<&str, ElementExpression> {
named!(first_element(&str) -> SignedElement, alt!(
parse_signed_element => { |e| e } |
parse_element => { |e| SignedElement::Positive(e) }
));
let (input, first) = first_element(input)?;
let (input, rest) = if input.trim().is_empty() {
(input, vec![first])
} else {
named!(rest_elements(&str) -> Vec<SignedElement>, many0!(parse_signed_element));
let (input, mut rest) = rest_elements(input)?;
rest.insert(0, first);
(input, rest)
};
Ok((input, ElementExpression(rest)))
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn dice_test() {
assert_eq!(
parse_dice("2d4"),
Ok(("", Dice::new(2, 4, KeepOrDrop::None)))
);
assert_eq!(
parse_dice("20d40"),
Ok(("", Dice::new(20, 40, KeepOrDrop::None)))
);
assert_eq!(
parse_dice("8d7"),
Ok(("", Dice::new(8, 7, KeepOrDrop::None)))
);
assert_eq!(
parse_dice("2d20k1"),
Ok(("", Dice::new(2, 20, KeepOrDrop::Keep(1))))
);
assert_eq!(
parse_dice("100d10k90"),
Ok(("", Dice::new(100, 10, KeepOrDrop::Keep(90))))
);
assert_eq!(
parse_dice("11d10k10"),
Ok(("", Dice::new(11, 10, KeepOrDrop::Keep(10))))
);
assert_eq!(
parse_dice("12d10k11"),
Ok(("", Dice::new(12, 10, KeepOrDrop::Keep(11))))
);
assert_eq!(
parse_dice("12d10k13"),
Ok(("", Dice::new(12, 10, KeepOrDrop::None)))
);
assert_eq!(
parse_dice("12d10k0"),
Ok(("", Dice::new(12, 10, KeepOrDrop::None)))
);
assert_eq!(
parse_dice("20d40dh5"),
Ok(("", Dice::new(20, 40, KeepOrDrop::Drop(5))))
);
assert_eq!(
parse_dice("8d7dh9"),
Ok(("", Dice::new(8, 7, KeepOrDrop::None)))
);
assert_eq!(
parse_dice("8d7dh8"),
Ok(("", Dice::new(8, 7, KeepOrDrop::None)))
);
}
#[test]
fn cant_have_both_keep_and_drop_test() {
let res = parse_dice("1d4k3dh2");
assert!(res.is_err());
match res {
Err(NomErr::Error((_, kind))) => {
assert_eq!(kind, NomErrorKind::Many1);
}
_ => panic!("Got success, expected error"),
}
}
#[test]
fn big_number_of_dice_doesnt_crash_test() {
let res = parse_dice("64378631476346123874527551481376547657868536d4");
assert!(res.is_err());
match res {
Err(NomErr::Error((input, kind))) => {
assert_eq!(kind, NomErrorKind::TooLarge);
assert_eq!(input, "64378631476346123874527551481376547657868536");
}
_ => panic!("Got success, expected error"),
}
}
#[test]
fn big_number_of_sides_doesnt_crash_test() {
let res = parse_dice("1d423562312587425472658956278456298376234876");
assert!(res.is_err());
match res {
Err(NomErr::Error((input, kind))) => {
assert_eq!(kind, NomErrorKind::TooLarge);
assert_eq!(input, "423562312587425472658956278456298376234876");
}
_ => panic!("Got success, expected error"),
}
}
#[test]
fn element_test() {
assert_eq!(
parse_element(" \t\n\r\n 8d7 \n"),
Ok((" \n", Element::Dice(Dice::new(8, 7, KeepOrDrop::None))))
);
assert_eq!(
parse_element(" \t\n\r\n 3d20k2 \n"),
Ok((" \n", Element::Dice(Dice::new(3, 20, KeepOrDrop::Keep(2)))))
);
assert_eq!(
parse_element(" \t\n\r\n 8 \n"),
Ok((" \n", Element::Bonus(8)))
);
}
#[test]
fn signed_element_test() {
assert_eq!(
parse_signed_element("+ 7"),
Ok(("", SignedElement::Positive(Element::Bonus(7))))
);
assert_eq!(
parse_signed_element(" \t\n\r\n- 8 \n"),
Ok((" \n", SignedElement::Negative(Element::Bonus(8))))
);
assert_eq!(
parse_signed_element(" \t\n\r\n- 8d4 \n"),
Ok((
" \n",
SignedElement::Negative(Element::Dice(Dice::new(8, 4, KeepOrDrop::None)))
))
);
assert_eq!(
parse_signed_element(" \t\n\r\n- 8d4k4 \n"),
Ok((
" \n",
SignedElement::Negative(Element::Dice(Dice::new(8, 4, KeepOrDrop::Keep(4))))
))
);
assert_eq!(
parse_signed_element(" \t\n\r\n+ 8d4 \n"),
Ok((
" \n",
SignedElement::Positive(Element::Dice(Dice::new(8, 4, KeepOrDrop::None)))
))
);
}
#[test]
fn element_expression_test() {
assert_eq!(
parse_element_expression("8d4"),
Ok((
"",
ElementExpression(vec![SignedElement::Positive(Element::Dice(Dice::new(
8,
4,
KeepOrDrop::None
)))])
))
);
assert_eq!(
parse_element_expression("\t2d20k1 + 5"),
Ok((
"",
ElementExpression(vec![
SignedElement::Positive(Element::Dice(Dice::new(2, 20, KeepOrDrop::Keep(1)))),
SignedElement::Positive(Element::Bonus(5)),
])
))
);
assert_eq!(
parse_element_expression(" - 8d4 \n "),
Ok((
" \n ",
ElementExpression(vec![SignedElement::Negative(Element::Dice(Dice::new(
8,
4,
KeepOrDrop::None
)))])
))
);
assert_eq!(
parse_element_expression("\t3d4k2 + 7 - 5 - 6d12dh3 + 1d1 + 53 1d5 "),
Ok((
" 1d5 ",
ElementExpression(vec![
SignedElement::Positive(Element::Dice(Dice::new(3, 4, KeepOrDrop::Keep(2)))),
SignedElement::Positive(Element::Bonus(7)),
SignedElement::Negative(Element::Bonus(5)),
SignedElement::Negative(Element::Dice(Dice::new(6, 12, KeepOrDrop::Drop(3)))),
SignedElement::Positive(Element::Dice(Dice::new(1, 1, KeepOrDrop::None))),
SignedElement::Positive(Element::Bonus(53)),
])
))
);
}
}

View File

@ -1,33 +0,0 @@
use tenebrous_rpc::protos::dicebot::UserIdRequest;
use tenebrous_rpc::protos::dicebot::{dicebot_client::DicebotClient};
use tonic::{metadata::MetadataValue, transport::Channel, Request};
async fn create_client(
shared_secret: &str,
) -> Result<DicebotClient<Channel>, Box<dyn std::error::Error>> {
let channel = Channel::from_static("http://0.0.0.0:9090")
.connect()
.await?;
let bearer = MetadataValue::from_str(&format!("Bearer {}", shared_secret))?;
let client = DicebotClient::with_interceptor(channel, move |mut req: Request<()>| {
req.metadata_mut().insert("authorization", bearer.clone());
Ok(req)
});
Ok(client)
}
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
let mut client = create_client("example-key").await?;
let request = tonic::Request::new(UserIdRequest {
user_id: "@projectmoon:agnos.is".into(),
});
let response = client.rooms_for_user(request).await?.into_inner();
println!("Rooms: {:?}", response.rooms);
Ok(())
}

View File

@ -1,163 +0,0 @@
use super::DiceBot;
use crate::db::sqlite::Database;
use crate::db::Rooms;
use crate::error::BotError;
use log::{debug, error, info, warn};
use matrix_sdk::ruma::events::room::member::RoomMemberEventContent;
use matrix_sdk::ruma::events::{StrippedStateEvent, SyncMessageLikeEvent};
use matrix_sdk::{self, room::Room, ruma::events::room::message::RoomMessageEventContent};
use matrix_sdk::{Client, DisplayName};
use std::ops::Sub;
use std::time::UNIX_EPOCH;
use std::time::{Duration, SystemTime};
/// Check if a message is recent enough to actually process. If the
/// message is within "oldest_message_age" seconds, this function
/// returns true. If it's older than that, it returns false and logs a
/// debug message.
fn check_message_age(
event: &SyncMessageLikeEvent<RoomMessageEventContent>,
oldest_message_age: u64,
) -> bool {
let sending_time = event
.origin_server_ts()
.to_system_time()
.unwrap_or(UNIX_EPOCH);
let oldest_timestamp = SystemTime::now().sub(Duration::from_secs(oldest_message_age));
if sending_time > oldest_timestamp {
true
} else {
let age = match oldest_timestamp.duration_since(sending_time) {
Ok(n) => format!("{} seconds too old", n.as_secs()),
Err(_) => "before the UNIX epoch".to_owned(),
};
debug!("Ignoring message because it is {}: {:?}", age, event);
false
}
}
/// Determine whether or not to process a received message. This check
/// is necessary in addition to the event processing check because we
/// may receive message events when entering a room for the first
/// time, and we don't want to respond to things before the bot was in
/// the channel, but we do want to respond to things that were sent if
/// the bot left and rejoined quickly.
async fn should_process_message<'a>(
bot: &DiceBot,
event: &SyncMessageLikeEvent<RoomMessageEventContent>,
) -> Result<(String, String), BotError> {
//Ignore messages that are older than configured duration.
if !check_message_age(event, bot.config.oldest_message_age()) {
let state_check = bot.state.read().unwrap();
if !((*state_check).logged_skipped_old_messages()) {
drop(state_check);
let mut state = bot.state.write().unwrap();
(*state).skipped_old_messages();
}
return Err(BotError::ShouldNotProcessError);
}
let msg_body: String = event
.as_original()
.map(|e| e.content.body())
.map(str::to_string)
.unwrap_or_else(|| String::new());
let sender_username: String = format!(
"@{}:{}",
event.sender().localpart(),
event.sender().server_name()
);
// Do not process messages from the bot itself. Otherwise it might
// try to execute its own commands.
let bot_username = bot
.client
.user_id()
.map(|u| format!("@{}:{}", u.localpart(), u.server_name()))
.unwrap_or_default();
if sender_username == bot_username {
return Err(BotError::ShouldNotProcessError);
}
Ok((msg_body, sender_username))
}
async fn should_process_event(db: &Database, room_id: &str, event_id: &str) -> bool {
db.should_process(room_id, event_id)
.await
.unwrap_or_else(|e| {
error!(
"Database error when checking if we should process an event: {}",
e.to_string()
);
false
})
}
pub(super) async fn on_stripped_state_member(
event: StrippedStateEvent<RoomMemberEventContent>,
client: Client,
room: Room,
) {
let room = match room {
Room::Invited(invited_room) => invited_room,
_ => return,
};
if room.own_user_id().as_str() != event.state_key {
return;
}
info!(
"Autojoining room {}",
room.display_name()
.await
.ok()
.unwrap_or_else(|| DisplayName::Named("[error]".to_string()))
);
if let Err(e) = client.join_room_by_id(&room.room_id()).await {
warn!("Could not join room: {}", e.to_string())
}
}
pub(super) async fn on_room_message(
event: SyncMessageLikeEvent<RoomMessageEventContent>,
room: Room,
bot: DiceBot,
) {
let room = match room {
Room::Joined(joined_room) => joined_room,
_ => return,
};
let room_id = room.room_id().as_str();
if !should_process_event(&bot.db, room_id, event.event_id().as_str()).await {
return;
}
let (msg_body, sender_username) =
if let Ok((msg_body, sender_username)) = should_process_message(&bot, &event).await {
(msg_body, sender_username)
} else {
return;
};
let results = bot
.execute_commands(&room, &sender_username, &msg_body)
.await;
bot.handle_results(
&room,
&sender_username,
event.event_id().to_owned(),
results,
)
.await;
}

View File

@ -1,22 +0,0 @@
use crate::systems::GameSystem;
use barrel::backend::Sqlite;
use barrel::{types, types::Type, Migration};
use itertools::Itertools;
use strum::IntoEnumIterator;
fn primary_id() -> Type {
types::text().unique(true).primary(true).nullable(false)
}
pub fn migration() -> String {
let mut m = Migration::new();
//Normally we would add a CHECK clause here, but types::custom requires a 'static string.
//Which means we can't automagically generate one from the enum.
m.create_table("room_info", move |t| {
t.add_column("room_id", primary_id());
t.add_column("game_system", types::text().nullable(false));
});
m.make::<Sqlite>()
}

View File

@ -1,361 +0,0 @@
use super::Database;
use crate::db::{errors::DataError, Users};
use crate::error::BotError;
use crate::models::User;
use async_trait::async_trait;
#[async_trait]
impl Users for Database {
async fn upsert_user(&self, user: &User) -> Result<(), DataError> {
let mut tx = self.conn.begin().await?;
sqlx::query!(
r#"INSERT INTO accounts (user_id, password, account_status)
VALUES (?, ?, ?)
ON CONFLICT(user_id) DO
UPDATE SET password = ?, account_status = ?"#,
user.username,
user.password,
user.account_status,
user.password,
user.account_status
)
.execute(&mut tx)
.await?;
sqlx::query!(
r#"INSERT INTO user_state (user_id, active_room)
VALUES (?, ?)
ON CONFLICT(user_id) DO
UPDATE SET active_room = ?"#,
user.username,
user.active_room,
user.active_room
)
.execute(&mut tx)
.await?;
tx.commit().await?;
Ok(())
}
async fn delete_user(&self, username: &str) -> Result<(), DataError> {
let mut tx = self.conn.begin().await?;
sqlx::query!(r#"DELETE FROM accounts WHERE user_id = ?"#, username)
.execute(&mut tx)
.await?;
sqlx::query!(r#"DELETE FROM user_state WHERE user_id = ?"#, username)
.execute(&mut tx)
.await?;
tx.commit().await?;
Ok(())
}
async fn get_user(&self, username: &str) -> Result<Option<User>, DataError> {
// Should be query_as! macro, but the left join breaks it with a
// non existing error message.
let user_row: Option<User> = sqlx::query_as(
r#"SELECT
a.user_id as "username",
a.password,
s.active_room,
COALESCE(a.account_status, 'not_registered') as "account_status"
FROM accounts a
LEFT JOIN user_state s on a.user_id = s.user_id
WHERE a.user_id = ?"#,
)
.bind(username)
.fetch_optional(&self.conn)
.await?;
Ok(user_row)
}
async fn authenticate_user(
&self,
username: &str,
raw_password: &str,
) -> Result<Option<User>, BotError> {
let user = self.get_user(username).await?;
Ok(user.filter(|u| u.verify_password(raw_password)))
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::db::sqlite::Database;
use crate::db::Users;
use crate::models::AccountStatus;
use std::future::Future;
async fn with_db<Fut>(f: impl FnOnce(Database) -> Fut)
where
Fut: Future<Output = ()>,
{
let db_path = tempfile::NamedTempFile::new_in(".").unwrap();
crate::db::sqlite::migrator::migrate(db_path.path().to_str().unwrap())
.await
.unwrap();
let db = Database::new(db_path.path().to_str().unwrap())
.await
.unwrap();
f(db).await;
}
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn create_and_get_full_user_test() {
with_db(|db| async move {
let insert_result = db
.upsert_user(&User {
username: "myuser".to_string(),
password: Some("abc".to_string()),
account_status: AccountStatus::Registered,
active_room: Some("myroom".to_string()),
})
.await;
assert!(insert_result.is_ok());
let user = db
.get_user("myuser")
.await
.expect("User retrieval query failed");
assert!(user.is_some());
let user = user.unwrap();
assert_eq!(user.username, "myuser");
assert_eq!(user.password, Some("abc".to_string()));
assert_eq!(user.account_status, AccountStatus::Registered);
assert_eq!(user.active_room, Some("myroom".to_string()));
})
.await;
}
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn can_get_user_with_no_state_record() {
with_db(|db| async move {
let insert_result = db
.upsert_user(&User {
username: "myuser".to_string(),
password: Some("abc".to_string()),
account_status: AccountStatus::AwaitingActivation,
active_room: Some("myroom".to_string()),
})
.await;
assert!(insert_result.is_ok());
sqlx::query("DELETE FROM user_state")
.execute(&db.conn)
.await
.expect("Could not delete from user_state table.");
let user = db
.get_user("myuser")
.await
.expect("User retrieval query failed");
assert!(user.is_some());
let user = user.unwrap();
assert_eq!(user.username, "myuser");
assert_eq!(user.password, Some("abc".to_string()));
assert_eq!(user.account_status, AccountStatus::AwaitingActivation);
//These should be default values because the state record is missing.
assert_eq!(user.active_room, None);
})
.await;
}
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn can_insert_without_password() {
with_db(|db| async move {
let insert_result = db
.upsert_user(&User {
username: "myuser".to_string(),
password: None,
..Default::default()
})
.await;
assert!(insert_result.is_ok());
let user = db
.get_user("myuser")
.await
.expect("User retrieval query failed");
assert!(user.is_some());
let user = user.unwrap();
assert_eq!(user.username, "myuser");
assert_eq!(user.password, None);
})
.await;
}
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn can_insert_without_active_room() {
with_db(|db| async move {
let insert_result = db
.upsert_user(&User {
username: "myuser".to_string(),
active_room: None,
..Default::default()
})
.await;
assert!(insert_result.is_ok());
let user = db
.get_user("myuser")
.await
.expect("User retrieval query failed");
assert!(user.is_some());
let user = user.unwrap();
assert_eq!(user.username, "myuser");
assert_eq!(user.active_room, None);
})
.await;
}
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn can_update_user() {
with_db(|db| async move {
let insert_result1 = db
.upsert_user(&User {
username: "myuser".to_string(),
password: Some("abc".to_string()),
..Default::default()
})
.await;
assert!(insert_result1.is_ok());
let insert_result2 = db
.upsert_user(&User {
username: "myuser".to_string(),
password: Some("123".to_string()),
active_room: Some("room".to_string()),
account_status: AccountStatus::AwaitingActivation,
})
.await;
assert!(insert_result2.is_ok());
let user = db
.get_user("myuser")
.await
.expect("User retrieval query failed");
assert!(user.is_some());
let user = user.unwrap();
assert_eq!(user.username, "myuser");
//From second upsert
assert_eq!(user.password, Some("123".to_string()));
assert_eq!(user.active_room, Some("room".to_string()));
assert_eq!(user.account_status, AccountStatus::AwaitingActivation);
})
.await;
}
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn can_delete_user() {
with_db(|db| async move {
let insert_result = db
.upsert_user(&User {
username: "myuser".to_string(),
password: Some("abc".to_string()),
..Default::default()
})
.await;
assert!(insert_result.is_ok());
db.delete_user("myuser")
.await
.expect("User deletion query failed");
let user = db
.get_user("myuser")
.await
.expect("User retrieval query failed");
assert!(user.is_none());
})
.await;
}
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn username_not_in_db_returns_none() {
with_db(|db| async move {
let user = db
.get_user("does not exist")
.await
.expect("Get user query failure");
assert!(user.is_none());
})
.await;
}
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn authenticate_user_is_some_with_valid_password() {
with_db(|db| async move {
let insert_result = db
.upsert_user(&User {
username: "myuser".to_string(),
password: Some(
crate::logic::hash_password("abc").expect("password hash error!"),
),
..Default::default()
})
.await;
assert!(insert_result.is_ok());
let user = db
.authenticate_user("myuser", "abc")
.await
.expect("User retrieval query failed");
assert!(user.is_some());
let user = user.unwrap();
assert_eq!(user.username, "myuser");
})
.await;
}
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn authenticate_user_is_none_with_wrong_password() {
with_db(|db| async move {
let insert_result = db
.upsert_user(&User {
username: "myuser".to_string(),
password: Some(
crate::logic::hash_password("abc").expect("password hash error!"),
),
..Default::default()
})
.await;
assert!(insert_result.is_ok());
let user = db
.authenticate_user("myuser", "wrong-password")
.await
.expect("User retrieval query failed");
assert!(user.is_none());
})
.await;
}
}

View File

@ -1,50 +0,0 @@
use crate::error::BotError;
use crate::{config::Config, db::sqlite::Database};
use log::{info, warn};
use matrix_sdk::Client;
use service::DicebotRpcService;
use std::sync::Arc;
use tenebrous_rpc::protos::dicebot::dicebot_server::DicebotServer;
use tonic::{metadata::MetadataValue, transport::Server, Request, Status};
pub(crate) mod service;
pub async fn serve_grpc(
config: &Arc<Config>,
db: &Database,
client: &Client,
) -> Result<(), BotError> {
match config.rpc_addr().zip(config.rpc_key()) {
Some((addr, rpc_key)) => {
let expected_bearer = MetadataValue::from_str(&format!("Bearer {}", rpc_key))?;
let addr = addr.parse()?;
let rpc_service = DicebotRpcService {
db: db.clone(),
config: config.clone(),
client: client.clone(),
};
info!("Serving Dicebot gRPC service on {}", addr);
let interceptor = move |req: Request<()>| match req.metadata().get("authorization") {
Some(bearer) if bearer == expected_bearer => Ok(req),
_ => Err(Status::unauthenticated("No valid auth token")),
};
let server = DicebotServer::with_interceptor(rpc_service, interceptor);
Server::builder()
.add_service(server)
.serve(addr)
.await
.map_err(|e| e.into())
}
_ => noop().await,
}
}
pub async fn noop() -> Result<(), BotError> {
warn!("RPC address or shared secret not specified. Not enabling gRPC.");
Ok(())
}

View File

@ -1,117 +0,0 @@
use crate::db::{errors::DataError, Variables};
use crate::error::BotError;
use crate::matrix;
use crate::{config::Config, db::sqlite::Database};
use futures::stream;
use futures::{StreamExt, TryFutureExt, TryStreamExt};
use matrix_sdk::ruma::OwnedUserId;
use matrix_sdk::{room::Joined, Client};
use std::convert::TryFrom;
use std::sync::Arc;
use tenebrous_rpc::protos::dicebot::{
dicebot_server::Dicebot, rooms_list_reply::Room, GetAllVariablesReply, GetAllVariablesRequest,
RoomsListReply, SetVariableReply, SetVariableRequest, UserIdRequest,
};
use tenebrous_rpc::protos::dicebot::{GetVariableReply, GetVariableRequest};
use tonic::{Code, Request, Response, Status};
impl From<BotError> for Status {
fn from(error: BotError) -> Status {
Status::new(Code::Internal, error.to_string())
}
}
impl From<DataError> for Status {
fn from(error: DataError) -> Status {
Status::new(Code::Internal, error.to_string())
}
}
#[derive(Clone)]
pub(super) struct DicebotRpcService {
pub(super) config: Arc<Config>,
pub(super) db: Database,
pub(super) client: Client,
}
#[tonic::async_trait]
impl Dicebot for DicebotRpcService {
async fn set_variable(
&self,
request: Request<SetVariableRequest>,
) -> Result<Response<SetVariableReply>, Status> {
let SetVariableRequest {
user_id,
room_id,
variable_name,
value,
} = request.into_inner();
self.db
.set_user_variable(&user_id, &room_id, &variable_name, value)
.await?;
Ok(Response::new(SetVariableReply { success: true }))
}
async fn get_variable(
&self,
request: Request<GetVariableRequest>,
) -> Result<Response<GetVariableReply>, Status> {
let request = request.into_inner();
let value = self
.db
.get_user_variable(&request.user_id, &request.room_id, &request.variable_name)
.await?;
Ok(Response::new(GetVariableReply { value }))
}
async fn get_all_variables(
&self,
request: Request<GetAllVariablesRequest>,
) -> Result<Response<GetAllVariablesReply>, Status> {
let request = request.into_inner();
let variables = self
.db
.get_user_variables(&request.user_id, &request.room_id)
.await?;
Ok(Response::new(GetAllVariablesReply { variables }))
}
async fn rooms_for_user(
&self,
request: Request<UserIdRequest>,
) -> Result<Response<RoomsListReply>, Status> {
let UserIdRequest { user_id } = request.into_inner();
let user_id = OwnedUserId::try_from(user_id).map_err(BotError::from)?;
let rooms_for_user = matrix::get_rooms_for_user(&self.client, &user_id)
.err_into::<BotError>()
.await?;
let mut rooms: Vec<Room> = stream::iter(rooms_for_user)
.filter_map(|room: Joined| async move {
let room: Result<Room, _> = room.display_name().await.map(|room_name| Room {
room_id: room.room_id().to_string(),
display_name: room_name.to_string(),
});
Some(room)
})
.err_into::<BotError>()
.try_collect()
.await?;
let sort = |r1: &Room, r2: &Room| {
r1.display_name
.to_lowercase()
.cmp(&r2.display_name.to_lowercase())
};
rooms.sort_by(sort);
Ok(Response::new(RoomsListReply { rooms }))
}
}

View File

@ -1,21 +0,0 @@
use strum::{AsRefStr, Display, EnumIter, EnumString};
#[derive(EnumString, EnumIter, AsRefStr, Display)]
pub(crate) enum GameSystem {
ChroniclesOfDarkness,
Changeling,
MageTheAwakening,
WerewolfTheForsaken,
DeviantTheRenegades,
MummyTheCurse,
PrometheanTheCreated,
CallOfCthulhu,
DungeonsAndDragons5e,
DungeonsAndDragons4e,
DungeonsAndDragons35e,
DungeonsAndDragons2e,
DungeonsAndDragons1e,
None,
}
impl GameSystem {}

View File

@ -1,18 +0,0 @@
[package]
name = "tenebrous-rpc"
version = "0.1.0"
authors = ["projectmoon <projectmoon@agnos.is>"]
edition = "2018"
description = "gRPC protobuf models for Tenebrous."
homepage = "https://git.agnos.is/projectmoon/tenebrous-dicebot"
repository = "https://git.agnos.is/projectmoon/tenebrous-dicebot"
license = "AGPL-3.0-or-later"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[build-dependencies]
tonic-build = "0.4"
[dependencies]
tonic = "0.4"
prost = "0.7"

View File

@ -1,4 +0,0 @@
fn main() -> Result<(), Box<dyn std::error::Error>> {
tonic_build::compile_protos("protos/dicebot.proto")?;
Ok(())
}

View File

@ -1,52 +0,0 @@
syntax = "proto3";
package dicebot;
service Dicebot {
rpc GetVariable(GetVariableRequest) returns (GetVariableReply);
rpc GetAllVariables(GetAllVariablesRequest) returns (GetAllVariablesReply);
rpc SetVariable(SetVariableRequest) returns (SetVariableReply);
rpc RoomsForUser(UserIdRequest) returns (RoomsListReply);
}
message GetVariableRequest {
string user_id = 1;
string room_id = 2;
string variable_name = 3;
}
message GetVariableReply {
int32 value = 1;
}
message GetAllVariablesRequest {
string user_id = 1;
string room_id = 2;
}
message GetAllVariablesReply {
map<string, int32> variables = 1;
}
message SetVariableRequest {
string user_id = 1;
string room_id = 2;
string variable_name = 3;
int32 value = 4;
}
message SetVariableReply {
bool success = 1;
}
message UserIdRequest {
string user_id = 1;
}
message RoomsListReply {
message Room {
string room_id = 1;
string display_name = 2;
}
repeated Room rooms = 1;
}

View File

@ -1,5 +0,0 @@
pub mod protos {
pub mod dicebot {
tonic::include_proto!("dicebot");
}
}

View File

@ -6,52 +6,23 @@
use std::fmt; use std::fmt;
use std::ops::{Deref, DerefMut}; use std::ops::{Deref, DerefMut};
/// A basic dice roll, in XdY notation, like "1d4" or "3d6". //Old stuff, for regular dice rolling. To be moved elsewhere.
/// Optionally supports D&D advantage/disadvantge keep-or-drop
/// functionality.
#[derive(Debug, PartialEq, Eq, Clone, Copy)] #[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub struct Dice { pub struct Dice {
pub(crate) count: u32, pub(crate) count: u32,
pub(crate) sides: u32, pub(crate) sides: u32,
pub(crate) keep_drop: KeepOrDrop,
}
/// Enum indicating how to handle bonuses or penalties using extra
/// dice. If set to Keep, the roll will keep the highest X number of
/// dice in the roll, and add those together. If set to Drop, the
/// opposite is performed, and the lowest X number of dice are added
/// instead. If set to None, then all dice in the roll are added up as
/// normal.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum KeepOrDrop {
/// Keep only the X highest dice for adding up to the total.
Keep(u32),
/// Keep only the X lowest dice (i.e. drop the highest) for adding
/// up to the total.
Drop(u32),
/// Add up all dice in the roll for the total.
None,
} }
impl fmt::Display for Dice { impl fmt::Display for Dice {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self.keep_drop { write!(f, "{}d{}", self.count, self.sides)
KeepOrDrop::Keep(keep) => write!(f, "{}d{}k{}", self.count, self.sides, keep),
KeepOrDrop::Drop(drop) => write!(f, "{}d{}dh{}", self.count, self.sides, drop),
KeepOrDrop::None => write!(f, "{}d{}", self.count, self.sides),
}
} }
} }
impl Dice { impl Dice {
pub fn new(count: u32, sides: u32, keep_drop: KeepOrDrop) -> Dice { pub fn new(count: u32, sides: u32) -> Dice {
Dice { Dice { count, sides }
count,
sides,
keep_drop,
}
} }
} }

189
src/basic/parser.rs Normal file
View File

@ -0,0 +1,189 @@
/**
* In addition to the terms of the AGPL, this file is governed by the
* terms of the MIT license, from the original axfive-matrix-dicebot
* project.
*/
use nom::bytes::complete::take_while;
use nom::{
alt, bytes::complete::tag, character::complete::digit1, complete, many0, named,
sequence::tuple, tag, IResult,
};
use super::dice::*;
//******************************
//Legacy Code
//******************************
fn is_whitespace(input: char) -> bool {
input == ' ' || input == '\n' || input == '\t' || input == '\r'
}
/// Eat whitespace, returning it
pub fn eat_whitespace(input: &str) -> IResult<&str, &str> {
let (input, whitespace) = take_while(is_whitespace)(input)?;
Ok((input, whitespace))
}
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
enum Sign {
Plus,
Minus,
}
// Parse a dice expression. Does not eat whitespace
fn parse_dice(input: &str) -> IResult<&str, Dice> {
let (input, (count, _, sides)) = tuple((digit1, tag("d"), digit1))(input)?;
Ok((
input,
Dice::new(count.parse().unwrap(), sides.parse().unwrap()),
))
}
// Parse a single digit expression. Does not eat whitespace
fn parse_bonus(input: &str) -> IResult<&str, u32> {
let (input, bonus) = digit1(input)?;
Ok((input, bonus.parse().unwrap()))
}
// Parse a sign expression. Eats whitespace.
fn parse_sign(input: &str) -> IResult<&str, Sign> {
let (input, _) = eat_whitespace(input)?;
named!(sign(&str) -> Sign, alt!(
complete!(tag!("+")) => { |_| Sign::Plus } |
complete!(tag!("-")) => { |_| Sign::Minus }
));
let (input, sign) = sign(input)?;
Ok((input, sign))
}
// Parse an element expression. Eats whitespace.
fn parse_element(input: &str) -> IResult<&str, Element> {
let (input, _) = eat_whitespace(input)?;
named!(element(&str) -> Element, alt!(
parse_dice => { |d| Element::Dice(d) } |
parse_bonus => { |b| Element::Bonus(b) }
));
let (input, element) = element(input)?;
Ok((input, element))
}
// Parse a signed element expression. Eats whitespace.
fn parse_signed_element(input: &str) -> IResult<&str, SignedElement> {
let (input, _) = eat_whitespace(input)?;
let (input, sign) = parse_sign(input)?;
let (input, _) = eat_whitespace(input)?;
let (input, element) = parse_element(input)?;
let element = match sign {
Sign::Plus => SignedElement::Positive(element),
Sign::Minus => SignedElement::Negative(element),
};
Ok((input, element))
}
// Parse a full element expression. Eats whitespace.
pub fn parse_element_expression(input: &str) -> IResult<&str, ElementExpression> {
named!(first_element(&str) -> SignedElement, alt!(
parse_signed_element => { |e| e } |
parse_element => { |e| SignedElement::Positive(e) }
));
let (input, first) = first_element(input)?;
let (input, rest) = if input.trim().is_empty() {
(input, vec![first])
} else {
named!(rest_elements(&str) -> Vec<SignedElement>, many0!(parse_signed_element));
let (input, mut rest) = rest_elements(input)?;
rest.insert(0, first);
(input, rest)
};
Ok((input, ElementExpression(rest)))
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn dice_test() {
assert_eq!(parse_dice("2d4"), Ok(("", Dice::new(2, 4))));
assert_eq!(parse_dice("20d40"), Ok(("", Dice::new(20, 40))));
assert_eq!(parse_dice("8d7"), Ok(("", Dice::new(8, 7))));
}
#[test]
fn element_test() {
assert_eq!(
parse_element(" \t\n\r\n 8d7 \n"),
Ok((" \n", Element::Dice(Dice::new(8, 7))))
);
assert_eq!(
parse_element(" \t\n\r\n 8 \n"),
Ok((" \n", Element::Bonus(8)))
);
}
#[test]
fn signed_element_test() {
assert_eq!(
parse_signed_element("+ 7"),
Ok(("", SignedElement::Positive(Element::Bonus(7))))
);
assert_eq!(
parse_signed_element(" \t\n\r\n- 8 \n"),
Ok((" \n", SignedElement::Negative(Element::Bonus(8))))
);
assert_eq!(
parse_signed_element(" \t\n\r\n- 8d4 \n"),
Ok((
" \n",
SignedElement::Negative(Element::Dice(Dice::new(8, 4)))
))
);
assert_eq!(
parse_signed_element(" \t\n\r\n+ 8d4 \n"),
Ok((
" \n",
SignedElement::Positive(Element::Dice(Dice::new(8, 4)))
))
);
}
#[test]
fn element_expression_test() {
assert_eq!(
parse_element_expression("8d4"),
Ok((
"",
ElementExpression(vec![SignedElement::Positive(Element::Dice(Dice::new(
8, 4
)))])
))
);
assert_eq!(
parse_element_expression(" - 8d4 \n "),
Ok((
" \n ",
ElementExpression(vec![SignedElement::Negative(Element::Dice(Dice::new(
8, 4
)))])
))
);
assert_eq!(
parse_element_expression("\t3d4 + 7 - 5 - 6d12 + 1d1 + 53 1d5 "),
Ok((
" 1d5 ",
ElementExpression(vec![
SignedElement::Positive(Element::Dice(Dice::new(3, 4))),
SignedElement::Positive(Element::Bonus(7)),
SignedElement::Negative(Element::Bonus(5)),
SignedElement::Negative(Element::Dice(Dice::new(6, 12))),
SignedElement::Positive(Element::Dice(Dice::new(1, 1))),
SignedElement::Positive(Element::Bonus(53)),
])
))
);
}
}

View File

@ -4,7 +4,6 @@
* project. * project.
*/ */
use crate::basic::dice; use crate::basic::dice;
use crate::basic::dice::KeepOrDrop;
use rand::prelude::*; use rand::prelude::*;
use std::fmt; use std::fmt;
use std::ops::{Deref, DerefMut}; use std::ops::{Deref, DerefMut};
@ -20,27 +19,15 @@ pub trait Rolled {
} }
#[derive(Debug, PartialEq, Eq, Clone)] #[derive(Debug, PartialEq, Eq, Clone)]
/// array of rolls in order, how many dice to keep, and how many to drop pub struct DiceRoll(pub Vec<u32>);
/// keep indicates how many of the highest dice to keep
/// drop indicates how many of the highest dice to drop
pub struct DiceRoll (pub Vec<u32>, usize, usize);
impl DiceRoll { impl DiceRoll {
pub fn rolls(&self) -> &[u32] { pub fn rolls(&self) -> &[u32] {
&self.0 &self.0
} }
pub fn keep(&self) -> usize {
self.1
}
pub fn drop(&self) -> usize {
self.2
}
// only count kept dice in total
pub fn total(&self) -> u32 { pub fn total(&self) -> u32 {
self.0[self.2..self.1].iter().sum() self.0.iter().sum()
} }
} }
@ -54,21 +41,11 @@ impl fmt::Display for DiceRoll {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", self.rolled_value())?; write!(f, "{}", self.rolled_value())?;
let rolls = self.rolls(); let rolls = self.rolls();
let keep = self.keep(); let mut iter = rolls.iter();
let drop = self.drop();
let mut iter = rolls.iter().enumerate();
if let Some(first) = iter.next() { if let Some(first) = iter.next() {
if drop != 0 { write!(f, " ({}", first)?;
write!(f, " ([{}]", first.1)?;
} else {
write!(f, " ({}", first.1)?;
}
for roll in iter { for roll in iter {
if roll.0 >= keep || roll.0 < drop { write!(f, " + {}", roll)?;
write!(f, " + [{}]", roll.1)?;
} else {
write!(f, " + {}", roll.1)?;
}
} }
write!(f, ")")?; write!(f, ")")?;
} }
@ -81,17 +58,11 @@ impl Roll for dice::Dice {
fn roll(&self) -> DiceRoll { fn roll(&self) -> DiceRoll {
let mut rng = rand::thread_rng(); let mut rng = rand::thread_rng();
let mut rolls: Vec<_> = (0..self.count) let rolls: Vec<_> = (0..self.count)
.map(|_| rng.gen_range(1..=self.sides)) .map(|_| rng.gen_range(1..=self.sides))
.collect(); .collect();
// sort rolls in descending order
rolls.sort_by(|a, b| b.cmp(a));
match self.keep_drop { DiceRoll(rolls)
KeepOrDrop::Keep(k) => DiceRoll(rolls,k as usize, 0),
KeepOrDrop::Drop(dh) => DiceRoll(rolls,self.count as usize, dh as usize),
KeepOrDrop::None => DiceRoll(rolls,self.count as usize, 0),
}
} }
} }
@ -227,26 +198,18 @@ mod tests {
use super::*; use super::*;
#[test] #[test]
fn dice_roll_display_test() { fn dice_roll_display_test() {
assert_eq!(DiceRoll(vec![1, 3, 4], 3, 0).to_string(), "8 (1 + 3 + 4)"); assert_eq!(DiceRoll(vec![1, 3, 4]).to_string(), "8 (1 + 3 + 4)");
assert_eq!(DiceRoll(vec![], 0, 0).to_string(), "0"); assert_eq!(DiceRoll(vec![]).to_string(), "0");
assert_eq!( assert_eq!(
DiceRoll(vec![4, 7, 2, 10], 4, 0).to_string(), DiceRoll(vec![4, 7, 2, 10]).to_string(),
"23 (4 + 7 + 2 + 10)" "23 (4 + 7 + 2 + 10)"
); );
assert_eq!(
DiceRoll(vec![20, 13, 11, 10], 3, 0).to_string(),
"44 (20 + 13 + 11 + [10])"
);
assert_eq!(
DiceRoll(vec![20, 13, 11, 10], 4, 1).to_string(),
"34 ([20] + 13 + 11 + 10)"
);
} }
#[test] #[test]
fn element_roll_display_test() { fn element_roll_display_test() {
assert_eq!( assert_eq!(
ElementRoll::Dice(DiceRoll(vec![1, 3, 4], 3, 0)).to_string(), ElementRoll::Dice(DiceRoll(vec![1, 3, 4])).to_string(),
"8 (1 + 3 + 4)" "8 (1 + 3 + 4)"
); );
assert_eq!(ElementRoll::Bonus(7).to_string(), "7"); assert_eq!(ElementRoll::Bonus(7).to_string(), "7");
@ -255,11 +218,11 @@ mod tests {
#[test] #[test]
fn signed_element_roll_display_test() { fn signed_element_roll_display_test() {
assert_eq!( assert_eq!(
SignedElementRoll::Positive(ElementRoll::Dice(DiceRoll(vec![1, 3, 4], 3, 0))).to_string(), SignedElementRoll::Positive(ElementRoll::Dice(DiceRoll(vec![1, 3, 4]))).to_string(),
"8 (1 + 3 + 4)" "8 (1 + 3 + 4)"
); );
assert_eq!( assert_eq!(
SignedElementRoll::Negative(ElementRoll::Dice(DiceRoll(vec![1, 3, 4], 3, 0))).to_string(), SignedElementRoll::Negative(ElementRoll::Dice(DiceRoll(vec![1, 3, 4]))).to_string(),
"-8 (1 + 3 + 4)" "-8 (1 + 3 + 4)"
); );
assert_eq!( assert_eq!(
@ -276,14 +239,14 @@ mod tests {
fn element_expression_roll_display_test() { fn element_expression_roll_display_test() {
assert_eq!( assert_eq!(
ElementExpressionRoll(vec![SignedElementRoll::Positive(ElementRoll::Dice( ElementExpressionRoll(vec![SignedElementRoll::Positive(ElementRoll::Dice(
DiceRoll(vec![1, 3, 4], 3, 0) DiceRoll(vec![1, 3, 4])
)),]) )),])
.to_string(), .to_string(),
"8 (1 + 3 + 4)" "8 (1 + 3 + 4)"
); );
assert_eq!( assert_eq!(
ElementExpressionRoll(vec![SignedElementRoll::Negative(ElementRoll::Dice( ElementExpressionRoll(vec![SignedElementRoll::Negative(ElementRoll::Dice(
DiceRoll(vec![1, 3, 4], 3, 0) DiceRoll(vec![1, 3, 4])
)),]) )),])
.to_string(), .to_string(),
"-8 (1 + 3 + 4)" "-8 (1 + 3 + 4)"
@ -300,8 +263,8 @@ mod tests {
); );
assert_eq!( assert_eq!(
ElementExpressionRoll(vec![ ElementExpressionRoll(vec![
SignedElementRoll::Positive(ElementRoll::Dice(DiceRoll(vec![1, 3, 4], 3, 0))), SignedElementRoll::Positive(ElementRoll::Dice(DiceRoll(vec![1, 3, 4]))),
SignedElementRoll::Negative(ElementRoll::Dice(DiceRoll(vec![1, 2], 2, 0))), SignedElementRoll::Negative(ElementRoll::Dice(DiceRoll(vec![1, 2]))),
SignedElementRoll::Positive(ElementRoll::Bonus(4)), SignedElementRoll::Positive(ElementRoll::Bonus(4)),
SignedElementRoll::Negative(ElementRoll::Bonus(7)), SignedElementRoll::Negative(ElementRoll::Bonus(7)),
]) ])
@ -310,33 +273,13 @@ mod tests {
); );
assert_eq!( assert_eq!(
ElementExpressionRoll(vec![ ElementExpressionRoll(vec![
SignedElementRoll::Negative(ElementRoll::Dice(DiceRoll(vec![1, 3, 4], 3, 0))), SignedElementRoll::Negative(ElementRoll::Dice(DiceRoll(vec![1, 3, 4]))),
SignedElementRoll::Positive(ElementRoll::Dice(DiceRoll(vec![1, 2], 2, 0))), SignedElementRoll::Positive(ElementRoll::Dice(DiceRoll(vec![1, 2]))),
SignedElementRoll::Negative(ElementRoll::Bonus(4)), SignedElementRoll::Negative(ElementRoll::Bonus(4)),
SignedElementRoll::Positive(ElementRoll::Bonus(7)), SignedElementRoll::Positive(ElementRoll::Bonus(7)),
]) ])
.to_string(), .to_string(),
"-2 (-8 (1 + 3 + 4) + 3 (1 + 2) - 4 + 7)" "-2 (-8 (1 + 3 + 4) + 3 (1 + 2) - 4 + 7)"
); );
assert_eq!(
ElementExpressionRoll(vec![
SignedElementRoll::Negative(ElementRoll::Dice(DiceRoll(vec![4, 3, 1], 3, 0))),
SignedElementRoll::Positive(ElementRoll::Dice(DiceRoll(vec![12, 2], 1, 0))),
SignedElementRoll::Negative(ElementRoll::Bonus(4)),
SignedElementRoll::Positive(ElementRoll::Bonus(7)),
])
.to_string(),
"7 (-8 (4 + 3 + 1) + 12 (12 + [2]) - 4 + 7)"
);
assert_eq!(
ElementExpressionRoll(vec![
SignedElementRoll::Negative(ElementRoll::Dice(DiceRoll(vec![4, 3, 1], 3, 1))),
SignedElementRoll::Positive(ElementRoll::Dice(DiceRoll(vec![12, 2], 2, 0))),
SignedElementRoll::Negative(ElementRoll::Bonus(4)),
SignedElementRoll::Positive(ElementRoll::Bonus(7)),
])
.to_string(),
"13 (-4 ([4] + 3 + 1) + 14 (12 + 2) - 4 + 7)"
);
} }
} }

View File

@ -1,5 +1,4 @@
use matrix_sdk::ruma::room_id; use matrix_sdk::identifiers::room_id;
use matrix_sdk::Client;
use tenebrous_dicebot::commands; use tenebrous_dicebot::commands;
use tenebrous_dicebot::commands::ResponseExtractor; use tenebrous_dicebot::commands::ResponseExtractor;
use tenebrous_dicebot::context::{Context, RoomContext}; use tenebrous_dicebot::context::{Context, RoomContext};
@ -27,15 +26,11 @@ async fn main() -> Result<(), BotError> {
.await?; .await?;
let context = Context { let context = Context {
db, db: db,
account: Account::default(), account: Account::default(),
matrix_client: Client::new(homeserver).await.expect("Could not create matrix client"), matrix_client: &matrix_sdk::Client::new(homeserver)
origin_room: RoomContext { .expect("Could not create matrix client"),
id: &room_id!("!fakeroomid:example.com"), room: RoomContext {
display_name: "fake room".to_owned(),
secure: false,
},
active_room: RoomContext {
id: &room_id!("!fakeroomid:example.com"), id: &room_id!("!fakeroomid:example.com"),
display_name: "fake room".to_owned(), display_name: "fake room".to_owned(),
secure: false, secure: false,

View File

@ -1,36 +1,21 @@
//Needed for nested Result handling from tokio. Probably can go away after 1.47.0. //Needed for nested Result handling from tokio. Probably can go away after 1.47.0.
#![type_length_limit = "7605144"] #![type_length_limit = "7605144"]
use futures::try_join;
use log::error; use log::error;
use matrix_sdk::Client;
use std::env; use std::env;
use std::sync::{Arc, RwLock}; use std::sync::{Arc, RwLock};
use tenebrous_dicebot::bot::DiceBot; use tenebrous_dicebot::bot::DiceBot;
use tenebrous_dicebot::config::*; use tenebrous_dicebot::config::*;
use tenebrous_dicebot::db::sqlite::Database; use tenebrous_dicebot::db::sqlite::Database;
use tenebrous_dicebot::error::BotError; use tenebrous_dicebot::error::BotError;
use tenebrous_dicebot::rpc;
use tenebrous_dicebot::state::DiceBotState; use tenebrous_dicebot::state::DiceBotState;
use tracing_subscriber::filter::EnvFilter; use tracing_subscriber::filter::EnvFilter;
/// Attempt to create config object and ddatabase connection pool from
/// the given config path. An error is returned if config creation or
/// database pool creation fails for some reason.
async fn init(config_path: &str) -> Result<(Arc<Config>, Database, Client), BotError> {
let cfg = read_config(config_path)?;
let cfg = Arc::new(cfg);
let sqlite_path = format!("{}/dicebot.sqlite", cfg.database_path());
let db = Database::new(&sqlite_path).await?;
let client = tenebrous_dicebot::matrix::create_client(&cfg).await?;
Ok((cfg, db, client))
}
#[tokio::main] #[tokio::main]
async fn main() -> Result<(), BotError> { async fn main() {
let filter = if env::var("RUST_LOG").is_ok() { let filter = if env::var("RUST_LOG").is_ok() {
EnvFilter::from_default_env() EnvFilter::from_default_env()
} else { } else {
EnvFilter::new("tonic=info,tenebrous_dicebot=info,dicebot=info,refinery=info") EnvFilter::new("tenebrous_dicebot=info,dicebot=info,refinery=info")
}; };
tracing_subscriber::fmt().with_env_filter(filter).init(); tracing_subscriber::fmt().with_env_filter(filter).init();
@ -38,9 +23,7 @@ async fn main() -> Result<(), BotError> {
match run().await { match run().await {
Ok(_) => (), Ok(_) => (),
Err(e) => error!("Error: {}", e), Err(e) => error!("Error: {}", e),
} };
Ok(())
} }
async fn run() -> Result<(), BotError> { async fn run() -> Result<(), BotError> {
@ -49,22 +32,12 @@ async fn run() -> Result<(), BotError> {
.next() .next()
.expect("Need a config as an argument"); .expect("Need a config as an argument");
let (cfg, db, client) = init(&config_path).await?; let cfg = Arc::new(read_config(config_path)?);
let grpc = rpc::serve_grpc(&cfg, &db, &client); let sqlite_path = format!("{}/dicebot.sqlite", cfg.database_path());
let bot = run_bot(&cfg, &db, &client); let db = Database::new(&sqlite_path).await?;
match try_join!(bot, grpc) {
Ok(_) => (),
Err(e) => error!("Error: {:?}", e),
};
Ok(())
}
async fn run_bot(cfg: &Arc<Config>, db: &Database, client: &Client) -> Result<(), BotError> {
let state = Arc::new(RwLock::new(DiceBotState::new(&cfg))); let state = Arc::new(RwLock::new(DiceBotState::new(&cfg)));
match DiceBot::new(cfg, &state, db, client) { match DiceBot::new(&cfg, &state, &db) {
Ok(bot) => bot.run().await?, Ok(bot) => bot.run().await?,
Err(e) => println!("Error connecting: {:?}", e), Err(e) => println!("Error connecting: {:?}", e),
}; };

View File

@ -1,17 +1,12 @@
use crate::commands::{execute_command, ExecutionResult, ResponseExtractor};
use crate::context::{Context, RoomContext}; use crate::context::{Context, RoomContext};
use crate::db::sqlite::Database; use crate::db::sqlite::Database;
use crate::error::BotError; use crate::error::BotError;
use crate::logic; use crate::logic;
use crate::matrix; use crate::matrix;
use crate::{
commands::{execute_command, ExecutionResult, ResponseExtractor},
models::Account,
};
use futures::stream::{self, StreamExt}; use futures::stream::{self, StreamExt};
use matrix_sdk::ruma::{OwnedEventId, RoomId}; use matrix_sdk::{self, identifiers::EventId, room::Joined, Client};
use matrix_sdk::{self, room::Joined, Client};
use std::clone::Clone; use std::clone::Clone;
use std::convert::TryFrom;
/// Handle responding to a single command being executed. Wil print /// Handle responding to a single command being executed. Wil print
/// out the full result of that command. /// out the full result of that command.
@ -20,7 +15,7 @@ pub(super) async fn handle_single_result(
cmd_result: &ExecutionResult, cmd_result: &ExecutionResult,
respond_to: &str, respond_to: &str,
room: &Joined, room: &Joined,
event_id: OwnedEventId, event_id: EventId,
) { ) {
let html = cmd_result.message_html(respond_to); let html = cmd_result.message_html(respond_to);
let plain = cmd_result.message_plain(respond_to); let plain = cmd_result.message_plain(respond_to);
@ -100,57 +95,24 @@ pub(super) async fn handle_multiple_results(
matrix::send_message(client, room.room_id(), (&message, &plain), None).await; matrix::send_message(client, room.room_id(), (&message, &plain), None).await;
} }
/// Map an account's active room value to an actual matrix room, if /// Create a context for command execution. Can fai if the room
/// the account has an active room. This only retrieves the /// context creation fails.
/// user-specified active room, and doesn't perform any further async fn create_context<'a>(
/// filtering. db: &'a Database,
fn get_account_active_room(client: &Client, account: &Account) -> Result<Option<Joined>, BotError> { client: &'a Client,
let active_room = account room: &'a Joined,
.registered_user() sender: &'a str,
.and_then(|u| u.active_room.as_deref()) command: &'a str,
.map(|room_id| <&RoomId>::try_from(room_id)) ) -> Result<Context<'a>, BotError> {
.transpose()? let room_ctx = RoomContext::new(room, sender).await?;
.and_then(|active_room_id| client.get_joined_room(active_room_id)); Ok(Context {
Ok(active_room)
}
/// Execute a single command in the list of commands. Can fail if the
/// Account value cannot be created/fetched from the database, or if
/// room display names cannot be calculated. Otherwise, the success or
/// error of command execution itself is returned.
async fn execute_single_command(
command: &str,
db: &Database,
client: &Client,
origin_room: &Joined,
sender: &str,
) -> ExecutionResult {
let origin_ctx = RoomContext::new(origin_room, sender).await?;
let account = logic::get_account(db, sender).await?;
let active_room = get_account_active_room(client, &account)?;
// Active room is used in secure command-issuing rooms. In
// "public" rooms, where other users are, treat origin as the
// active room.
let active_room = active_room
.as_ref()
.filter(|_| origin_ctx.secure)
.unwrap_or(origin_room);
let active_ctx = RoomContext::new(active_room, sender).await?;
let ctx = Context {
account,
db: db.clone(), db: db.clone(),
matrix_client: client.clone(), matrix_client: client,
origin_room: origin_ctx, room: room_ctx,
username: &sender, username: &sender,
active_room: active_ctx, account: logic::get_account(db, &sender).await?,
message_body: &command, message_body: &command,
}; })
execute_command(&ctx).await
} }
/// Attempt to execute all commands sent to the bot in a message. This /// Attempt to execute all commands sent to the bot in a message. This
@ -165,8 +127,13 @@ pub(super) async fn execute(
) -> Vec<(String, ExecutionResult)> { ) -> Vec<(String, ExecutionResult)> {
stream::iter(commands) stream::iter(commands)
.then(|command| async move { .then(|command| async move {
let result = execute_single_command(command, db, client, room, sender).await; match create_context(db, client, room, sender, command).await {
(command.to_owned(), result) Err(e) => (command.to_owned(), Err(e)),
Ok(ctx) => {
let cmd_result = execute_command(&ctx).await;
(command.to_owned(), cmd_result)
}
}
}) })
.collect() .collect()
.await .await

158
src/bot/event_handlers.rs Normal file
View File

@ -0,0 +1,158 @@
use super::DiceBot;
use crate::db::sqlite::Database;
use crate::db::Rooms;
use crate::error::BotError;
use async_trait::async_trait;
use log::{debug, error, info, warn};
use matrix_sdk::{
self,
events::{
room::member::MemberEventContent,
room::message::{MessageEventContent, MessageType, TextMessageEventContent},
StrippedStateEvent, SyncMessageEvent,
},
room::Room,
EventHandler,
};
use std::ops::Sub;
use std::time::{Duration, SystemTime};
use std::{clone::Clone, time::UNIX_EPOCH};
/// Check if a message is recent enough to actually process. If the
/// message is within "oldest_message_age" seconds, this function
/// returns true. If it's older than that, it returns false and logs a
/// debug message.
fn check_message_age(
event: &SyncMessageEvent<MessageEventContent>,
oldest_message_age: u64,
) -> bool {
let sending_time = event
.origin_server_ts
.to_system_time()
.unwrap_or(UNIX_EPOCH);
let oldest_timestamp = SystemTime::now().sub(Duration::from_secs(oldest_message_age));
if sending_time > oldest_timestamp {
true
} else {
let age = match oldest_timestamp.duration_since(sending_time) {
Ok(n) => format!("{} seconds too old", n.as_secs()),
Err(_) => "before the UNIX epoch".to_owned(),
};
debug!("Ignoring message because it is {}: {:?}", age, event);
false
}
}
/// Determine whether or not to process a received message. This check
/// is necessary in addition to the event processing check because we
/// may receive message events when entering a room for the first
/// time, and we don't want to respond to things before the bot was in
/// the channel, but we do want to respond to things that were sent if
/// the bot left and rejoined quickly.
async fn should_process_message<'a>(
bot: &DiceBot,
event: &SyncMessageEvent<MessageEventContent>,
) -> Result<(String, String), BotError> {
//Ignore messages that are older than configured duration.
if !check_message_age(event, bot.config.oldest_message_age()) {
let state_check = bot.state.read().unwrap();
if !((*state_check).logged_skipped_old_messages()) {
drop(state_check);
let mut state = bot.state.write().unwrap();
(*state).skipped_old_messages();
}
return Err(BotError::ShouldNotProcessError);
}
let (msg_body, sender_username) = if let SyncMessageEvent {
content:
MessageEventContent {
msgtype: MessageType::Text(TextMessageEventContent { body, .. }),
..
},
sender,
..
} = event
{
(
body.clone(),
format!("@{}:{}", sender.localpart(), sender.server_name()),
)
} else {
(String::new(), String::new())
};
Ok((msg_body, sender_username))
}
async fn should_process_event(db: &Database, room_id: &str, event_id: &str) -> bool {
db.should_process(room_id, event_id)
.await
.unwrap_or_else(|e| {
error!(
"Database error when checking if we should process an event: {}",
e.to_string()
);
false
})
}
/// This event emitter listens for messages with dice rolling commands.
/// Originally adapted from the matrix-rust-sdk examples.
#[async_trait]
impl EventHandler for DiceBot {
async fn on_stripped_state_member(
&self,
room: Room,
event: &StrippedStateEvent<MemberEventContent>,
_: Option<MemberEventContent>,
) {
let room = match room {
Room::Invited(invited_room) => invited_room,
_ => return,
};
if room.own_user_id().as_str() != event.state_key {
return;
}
info!(
"Autojoining room {}",
room.display_name().await.ok().unwrap_or_default()
);
if let Err(e) = self.client.join_room_by_id(&room.room_id()).await {
warn!("Could not join room: {}", e.to_string())
}
}
async fn on_room_message(&self, room: Room, event: &SyncMessageEvent<MessageEventContent>) {
let room = match room {
Room::Joined(joined_room) => joined_room,
_ => return,
};
let room_id = room.room_id().as_str();
if !should_process_event(&self.db, room_id, event.event_id.as_str()).await {
return;
}
let (msg_body, sender_username) =
if let Ok((msg_body, sender_username)) = should_process_message(self, &event).await {
(msg_body, sender_username)
} else {
return;
};
let results = self
.execute_commands(&room, &sender_username, &msg_body)
.await;
self.handle_results(&room, &sender_username, event.event_id.clone(), results)
.await;
}
}

View File

@ -4,15 +4,13 @@ use crate::db::sqlite::Database;
use crate::db::DbState; use crate::db::DbState;
use crate::error::BotError; use crate::error::BotError;
use crate::state::DiceBotState; use crate::state::DiceBotState;
use dirs;
use log::info; use log::info;
use matrix_sdk::room::Room; use matrix_sdk::{self, identifiers::EventId, room::Joined, Client, ClientConfig, SyncSettings};
use matrix_sdk::ruma::events::room::message::RoomMessageEventContent;
use matrix_sdk::ruma::events::SyncMessageLikeEvent;
use matrix_sdk::ruma::OwnedEventId;
use matrix_sdk::{self, room::Joined, Client};
use matrix_sdk::config::SyncSettings;
use std::clone::Clone; use std::clone::Clone;
use std::path::PathBuf;
use std::sync::{Arc, RwLock}; use std::sync::{Arc, RwLock};
use url::Url;
mod command_execution; mod command_execution;
pub mod event_handlers; pub mod event_handlers;
@ -23,7 +21,6 @@ const MAX_COMMANDS_PER_MESSAGE: usize = 50;
/// The DiceBot struct represents an active dice bot. The bot is not /// The DiceBot struct represents an active dice bot. The bot is not
/// connected to Matrix until its run() function is called. /// connected to Matrix until its run() function is called.
#[derive(Clone)]
pub struct DiceBot { pub struct DiceBot {
/// A reference to the configuration read in on application start. /// A reference to the configuration read in on application start.
config: Arc<Config>, config: Arc<Config>,
@ -38,6 +35,22 @@ pub struct DiceBot {
db: Database, db: Database,
} }
fn cache_dir() -> Result<PathBuf, BotError> {
let mut dir = dirs::cache_dir().ok_or(BotError::NoCacheDirectoryError)?;
dir.push("matrix-dicebot");
Ok(dir)
}
/// Creates the matrix client.
fn create_client(config: &Config) -> Result<Client, BotError> {
let cache_dir = cache_dir()?;
//let store = JsonStore::open(&cache_dir)?;
let client_config = ClientConfig::new().store_path(cache_dir);
let homeserver_url = Url::parse(&config.matrix_homeserver())?;
Ok(Client::new_with_config(homeserver_url, client_config)?)
}
impl DiceBot { impl DiceBot {
/// Create a new dicebot with the given configuration and state /// Create a new dicebot with the given configuration and state
/// actor. This function returns a Result because it is possible /// actor. This function returns a Result because it is possible
@ -47,10 +60,9 @@ impl DiceBot {
config: &Arc<Config>, config: &Arc<Config>,
state: &Arc<RwLock<DiceBotState>>, state: &Arc<RwLock<DiceBotState>>,
db: &Database, db: &Database,
client: &Client,
) -> Result<Self, BotError> { ) -> Result<Self, BotError> {
Ok(DiceBot { Ok(DiceBot {
client: client.clone(), client: create_client(&config)?,
config: config.clone(), config: config.clone(),
state: state.clone(), state: state.clone(),
db: db.clone(), db: db.clone(),
@ -69,14 +81,12 @@ impl DiceBot {
let device_id: Option<String> = self.db.get_device_id().await?; let device_id: Option<String> = self.db.get_device_id().await?;
let device_id: Option<&str> = device_id.as_deref(); let device_id: Option<&str> = device_id.as_deref();
let no_device_ld_login = || client.login_username(username, password); client
let device_id_login = |id| client.login_username(username, password).device_id(id); .login(username, password, device_id, Some("matrix dice bot"))
let login = device_id.map_or_else(no_device_ld_login, device_id_login); .await?;
login.send().await?;
if device_id.is_none() { if device_id.is_none() {
let device_id = client.device_id().ok_or(BotError::NoDeviceIdFound)?; let device_id = client.device_id().await.ok_or(BotError::NoDeviceIdFound)?;
self.db.set_device_id(device_id.as_str()).await?; self.db.set_device_id(device_id.as_str()).await?;
info!("Recorded new device ID: {}", device_id.as_str()); info!("Recorded new device ID: {}", device_id.as_str());
} else { } else {
@ -87,35 +97,19 @@ impl DiceBot {
Ok(()) Ok(())
} }
async fn bind_events(&self) {
//on room message: need closure to pass bot ref in.
self.client
.add_event_handler({
let bot: DiceBot = self.clone();
move |event: SyncMessageLikeEvent<RoomMessageEventContent>, room: Room| {
let bot = bot.clone();
async move { event_handlers::on_room_message(event, room, bot).await }
}
});
//auto-join handler
self.client
.add_event_handler(event_handlers::on_stripped_state_member);
}
/// Logs the bot in to Matrix and listens for events until program /// Logs the bot in to Matrix and listens for events until program
/// terminated, or a panic occurs. Originally adapted from the /// terminated, or a panic occurs. Originally adapted from the
/// matrix-rust-sdk command bot example. /// matrix-rust-sdk command bot example.
pub async fn run(self) -> Result<(), BotError> { pub async fn run(self) -> Result<(), BotError> {
let client = self.client.clone(); let client = self.client.clone();
self.login(&client).await?; self.login(&client).await?;
self.bind_events().await;
client.set_event_handler(Box::new(self)).await;
info!("Listening for commands"); info!("Listening for commands");
// TODO replace with sync_with_callback for cleaner shutdown // TODO replace with sync_with_callback for cleaner shutdown
// process. // process.
client.sync(SyncSettings::default()).await?; client.sync(SyncSettings::default()).await;
Ok(()) Ok(())
} }
@ -145,7 +139,7 @@ impl DiceBot {
&self, &self,
room: &Joined, room: &Joined,
sender_username: &str, sender_username: &str,
event_id: OwnedEventId, event_id: EventId,
results: Vec<(String, ExecutionResult)>, results: Vec<(String, ExecutionResult)>,
) { ) {
if results.len() >= 1 { if results.len() >= 1 {

View File

@ -332,7 +332,7 @@ mod tests {
macro_rules! dummy_room { macro_rules! dummy_room {
() => { () => {
crate::context::RoomContext { crate::context::RoomContext {
id: &matrix_sdk::ruma::room_id!("!fakeroomid:example.com"), id: &matrix_sdk::identifiers::room_id!("!fakeroomid:example.com"),
display_name: "displayname".to_owned(), display_name: "displayname".to_owned(),
secure: false, secure: false,
} }
@ -485,9 +485,8 @@ mod tests {
let ctx = Context { let ctx = Context {
account: crate::models::Account::default(), account: crate::models::Account::default(),
db: db, db: db,
matrix_client: matrix_sdk::Client::new(homeserver).await.unwrap(), matrix_client: &matrix_sdk::Client::new(homeserver).unwrap(),
origin_room: dummy_room!(), room: dummy_room!(),
active_room: dummy_room!(),
username: "username", username: "username",
message_body: "message", message_body: "message",
}; };
@ -527,9 +526,8 @@ mod tests {
let ctx = Context { let ctx = Context {
account: crate::models::Account::default(), account: crate::models::Account::default(),
db: db, db: db,
matrix_client: matrix_sdk::Client::new(homeserver).await.unwrap(), matrix_client: &matrix_sdk::Client::new(homeserver).unwrap(),
origin_room: dummy_room!(), room: dummy_room!(),
active_room: dummy_room!(),
username: "username", username: "username",
message_body: "message", message_body: "message",
}; };
@ -566,21 +564,15 @@ mod tests {
let ctx = Context { let ctx = Context {
account: crate::models::Account::default(), account: crate::models::Account::default(),
db: db.clone(), db: db.clone(),
matrix_client: matrix_sdk::Client::new(homeserver).await.unwrap(), matrix_client: &matrix_sdk::Client::new(homeserver).unwrap(),
origin_room: dummy_room!(), room: dummy_room!(),
active_room: dummy_room!(),
username: "username", username: "username",
message_body: "message", message_body: "message",
}; };
db.set_user_variable( db.set_user_variable(&ctx.username, &ctx.room.id.as_str(), "myvariable", 10)
&ctx.username, .await
&ctx.origin_room.id.as_str(), .expect("could not set myvariable to 10");
"myvariable",
10,
)
.await
.expect("could not set myvariable to 10");
let amounts = vec![Amount { let amounts = vec![Amount {
operator: Operator::Plus, operator: Operator::Plus,

View File

@ -45,13 +45,13 @@ pub fn parse_modifiers(input: &str) -> Result<DicePoolModifiers, DiceParsingErro
let (result, rest) = parser.parse(input)?; let (result, rest) = parser.parse(input)?;
if rest.len() == 0 { if rest.len() == 0 {
convert_to_modifiers(&result) convert_to_info(&result)
} else { } else {
Err(DiceParsingError::UnconsumedInput) Err(DiceParsingError::UnconsumedInput)
} }
} }
fn convert_to_modifiers(parsed: &Vec<ParsedInfo>) -> Result<DicePoolModifiers, DiceParsingError> { fn convert_to_info(parsed: &Vec<ParsedInfo>) -> Result<DicePoolModifiers, DiceParsingError> {
use ParsedInfo::*; use ParsedInfo::*;
if parsed.len() == 0 { if parsed.len() == 0 {
Ok(DicePoolModifiers::default()) Ok(DicePoolModifiers::default())
@ -79,8 +79,19 @@ fn convert_to_modifiers(parsed: &Vec<ParsedInfo>) -> Result<DicePoolModifiers, D
} }
pub fn parse_dice_pool(input: &str) -> Result<DicePool, BotError> { pub fn parse_dice_pool(input: &str) -> Result<DicePool, BotError> {
let (amounts, modifiers_str) = parse_amounts(input)?; //The "modifiers:" part is optional. Assume amounts if no modifier
//section found.
let split = input.split(":").collect::<Vec<_>>();
let (modifiers_str, amounts_str) = (match split[..] {
[amounts] => Ok(("", amounts)),
[modifiers, amounts] => Ok((modifiers, amounts)),
_ => Err(BotError::DiceParsingError(
DiceParsingError::UnconsumedInput,
)),
})?;
let modifiers = parse_modifiers(modifiers_str)?; let modifiers = parse_modifiers(modifiers_str)?;
let amounts = parse_amounts(&amounts_str)?;
Ok(DicePool::new(amounts, modifiers)) Ok(DicePool::new(amounts, modifiers))
} }
@ -164,7 +175,7 @@ mod tests {
#[test] #[test]
fn dice_pool_number_with_quality() { fn dice_pool_number_with_quality() {
let result = parse_dice_pool("8 n"); let result = parse_dice_pool("n:8");
assert!(result.is_ok()); assert!(result.is_ok());
assert_eq!( assert_eq!(
result.unwrap(), result.unwrap(),
@ -175,7 +186,7 @@ mod tests {
#[test] #[test]
fn dice_pool_number_with_success_change() { fn dice_pool_number_with_success_change() {
let modifiers = DicePoolModifiers::custom_exceptional_on(3); let modifiers = DicePoolModifiers::custom_exceptional_on(3);
let result = parse_dice_pool("8 s3"); let result = parse_dice_pool("s3:8");
assert!(result.is_ok()); assert!(result.is_ok());
assert_eq!(result.unwrap(), DicePool::easy_with_modifiers(8, modifiers)); assert_eq!(result.unwrap(), DicePool::easy_with_modifiers(8, modifiers));
} }
@ -183,7 +194,7 @@ mod tests {
#[test] #[test]
fn dice_pool_with_quality_and_success_change() { fn dice_pool_with_quality_and_success_change() {
let modifiers = DicePoolModifiers::custom(DicePoolQuality::Rote, 3); let modifiers = DicePoolModifiers::custom(DicePoolQuality::Rote, 3);
let result = parse_dice_pool("8 rs3"); let result = parse_dice_pool("rs3:8");
assert!(result.is_ok()); assert!(result.is_ok());
assert_eq!(result.unwrap(), DicePool::easy_with_modifiers(8, modifiers)); assert_eq!(result.unwrap(), DicePool::easy_with_modifiers(8, modifiers));
} }
@ -213,20 +224,20 @@ mod tests {
let expected = DicePool::new(amounts, modifiers); let expected = DicePool::new(amounts, modifiers);
let result = parse_dice_pool("8+10-2+varname rs3"); let result = parse_dice_pool("rs3:8+10-2+varname");
assert!(result.is_ok()); assert!(result.is_ok());
assert_eq!(result.unwrap(), expected); assert_eq!(result.unwrap(), expected);
let result = parse_dice_pool("8+10- 2 + varname rs3"); let result = parse_dice_pool("rs3:8+10- 2 + varname");
assert!(result.is_ok()); assert!(result.is_ok());
assert_eq!(result.unwrap(), expected); assert_eq!(result.unwrap(), expected);
let result = parse_dice_pool("8+ 10 -2 + varname rs3"); let result = parse_dice_pool("rs3 : 8+ 10 -2 + varname");
assert!(result.is_ok()); assert!(result.is_ok());
assert_eq!(result.unwrap(), expected); assert_eq!(result.unwrap(), expected);
//This one has tabs in it. //This one has tabs in it.
let result = parse_dice_pool(" 8 + 10 -2 + varname r s3"); let result = parse_dice_pool(" r s3 : 8 + 10 -2 + varname");
assert!(result.is_ok()); assert!(result.is_ok());
assert_eq!(result.unwrap(), expected); assert_eq!(result.unwrap(), expected);
} }

View File

@ -10,6 +10,12 @@ use std::convert::TryFrom;
pub struct RollCommand(pub ElementExpression); pub struct RollCommand(pub ElementExpression);
impl From<RollCommand> for Box<dyn Command> {
fn from(cmd: RollCommand) -> Self {
Box::new(cmd)
}
}
impl TryFrom<String> for RollCommand { impl TryFrom<String> for RollCommand {
type Error = BotError; type Error = BotError;

View File

@ -15,6 +15,12 @@ impl PoolRollCommand {
} }
} }
impl From<PoolRollCommand> for Box<dyn Command> {
fn from(cmd: PoolRollCommand) -> Self {
Box::new(cmd)
}
}
impl TryFrom<String> for PoolRollCommand { impl TryFrom<String> for PoolRollCommand {
type Error = BotError; type Error = BotError;

View File

@ -11,6 +11,12 @@ use std::convert::TryFrom;
pub struct CthRoll(pub DiceRoll); pub struct CthRoll(pub DiceRoll);
impl From<CthRoll> for Box<dyn Command> {
fn from(cmd: CthRoll) -> Self {
Box::new(cmd)
}
}
impl TryFrom<String> for CthRoll { impl TryFrom<String> for CthRoll {
type Error = BotError; type Error = BotError;
@ -45,6 +51,12 @@ impl Command for CthRoll {
pub struct CthAdvanceRoll(pub AdvancementRoll); pub struct CthAdvanceRoll(pub AdvancementRoll);
impl From<CthAdvanceRoll> for Box<dyn Command> {
fn from(cmd: CthAdvanceRoll) -> Self {
Box::new(cmd)
}
}
impl TryFrom<String> for CthAdvanceRoll { impl TryFrom<String> for CthAdvanceRoll {
type Error = BotError; type Error = BotError;

View File

@ -9,6 +9,12 @@ use std::convert::{Into, TryFrom};
pub struct RegisterCommand; pub struct RegisterCommand;
impl From<RegisterCommand> for Box<dyn Command> {
fn from(cmd: RegisterCommand) -> Self {
Box::new(cmd)
}
}
impl TryFrom<String> for RegisterCommand { impl TryFrom<String> for RegisterCommand {
type Error = BotError; type Error = BotError;
@ -50,6 +56,12 @@ impl Command for RegisterCommand {
pub struct UnlinkCommand(pub String); pub struct UnlinkCommand(pub String);
impl From<UnlinkCommand> for Box<dyn Command> {
fn from(cmd: UnlinkCommand) -> Self {
Box::new(cmd)
}
}
impl TryFrom<String> for UnlinkCommand { impl TryFrom<String> for UnlinkCommand {
type Error = BotError; type Error = BotError;
@ -87,6 +99,12 @@ impl Command for UnlinkCommand {
pub struct LinkCommand(pub String); pub struct LinkCommand(pub String);
impl From<LinkCommand> for Box<dyn Command> {
fn from(cmd: LinkCommand) -> Self {
Box::new(cmd)
}
}
impl TryFrom<String> for LinkCommand { impl TryFrom<String> for LinkCommand {
type Error = BotError; type Error = BotError;
@ -126,6 +144,12 @@ impl Command for LinkCommand {
pub struct CheckCommand; pub struct CheckCommand;
impl From<CheckCommand> for Box<dyn Command> {
fn from(cmd: CheckCommand) -> Self {
Box::new(cmd)
}
}
impl TryFrom<String> for CheckCommand { impl TryFrom<String> for CheckCommand {
type Error = BotError; type Error = BotError;
@ -150,17 +174,14 @@ impl Command for CheckCommand {
match user { match user {
Some(user) => match user.password { Some(user) => match user.password {
Some(_) => Execution::success( Some(_) => Execution::success(
"Account exists, and is available to external applications with a password. \ "Account exists, and is available to external applications with a password. If you forgot your password, change it with !link.".to_string(),
If you forgot your password, change it with !link."
.to_string(),
), ),
None => Execution::success( None => Execution::success(
"Account exists, but is not available to external applications.".to_string(), "Account exists, but is not available to external applications.".to_string(),
), ),
}, },
None => Execution::success( None => Execution::success(
"No account registered. Only simple commands in public rooms are available." "No account registered. Only simple commands in public rooms are available.".to_string(),
.to_string(),
), ),
} }
} }
@ -168,6 +189,12 @@ impl Command for CheckCommand {
pub struct UnregisterCommand; pub struct UnregisterCommand;
impl From<UnregisterCommand> for Box<dyn Command> {
fn from(cmd: UnregisterCommand) -> Self {
Box::new(cmd)
}
}
impl TryFrom<String> for UnregisterCommand { impl TryFrom<String> for UnregisterCommand {
type Error = BotError; type Error = BotError;

View File

@ -7,6 +7,12 @@ use std::convert::TryFrom;
pub struct HelpCommand(pub Option<HelpTopic>); pub struct HelpCommand(pub Option<HelpTopic>);
impl From<HelpCommand> for Box<dyn Command> {
fn from(cmd: HelpCommand) -> Self {
Box::new(cmd)
}
}
impl TryFrom<String> for HelpCommand { impl TryFrom<String> for HelpCommand {
type Error = BotError; type Error = BotError;

View File

@ -112,8 +112,9 @@ fn execution_allowed(cmd: &(impl Command + ?Sized), ctx: &Context<'_>) -> Result
} }
/// Attempt to execute a command, and return the content that should /// Attempt to execute a command, and return the content that should
/// go back to Matrix, if the command was executed, whether or not the /// go back to Matrix, if the command was executed (successfully or
/// command was successful. /// not). If a command is determined to be ignored, this function will
/// return None, signifying that we should not send a response.
pub async fn execute_command(ctx: &Context<'_>) -> ExecutionResult { pub async fn execute_command(ctx: &Context<'_>) -> ExecutionResult {
let cmd = parser::parse_command(&ctx.message_body)?; let cmd = parser::parse_command(&ctx.message_body)?;
@ -145,13 +146,13 @@ fn log_command(cmd: &(impl Command + ?Sized), ctx: &Context, result: &ExecutionR
Ok(_) => { Ok(_) => {
info!( info!(
"[{}] {} <{}{}> - success", "[{}] {} <{}{}> - success",
ctx.origin_room.display_name, ctx.username, command, dots ctx.room.display_name, ctx.username, command, dots
); );
} }
Err(e) => { Err(e) => {
error!( error!(
"[{}] {} <{}{}> - {}", "[{}] {} <{}{}> - {}",
ctx.origin_room.display_name, ctx.username, command, dots, e ctx.room.display_name, ctx.username, command, dots, e
); );
} }
}; };
@ -162,12 +163,11 @@ mod tests {
use super::*; use super::*;
use management::RegisterCommand; use management::RegisterCommand;
use url::Url; use url::Url;
use matrix_sdk::ruma::room_id;
macro_rules! dummy_room { macro_rules! dummy_room {
() => { () => {
crate::context::RoomContext { crate::context::RoomContext {
id: &room_id!("!fakeroomid:example.com"), id: &matrix_sdk::identifiers::room_id!("!fakeroomid:example.com"),
display_name: "displayname".to_owned(), display_name: "displayname".to_owned(),
secure: false, secure: false,
} }
@ -177,7 +177,7 @@ mod tests {
macro_rules! secure_room { macro_rules! secure_room {
() => { () => {
crate::context::RoomContext { crate::context::RoomContext {
id: &room_id!("!fakeroomid:example.com"), id: &matrix_sdk::identifiers::room_id!("!fakeroomid:example.com"),
display_name: "displayname".to_owned(), display_name: "displayname".to_owned(),
secure: true, secure: true,
} }
@ -196,9 +196,8 @@ mod tests {
let ctx = Context { let ctx = Context {
account: crate::models::Account::default(), account: crate::models::Account::default(),
db: db, db: db,
matrix_client: matrix_sdk::Client::new(homeserver).await.unwrap(), matrix_client: &matrix_sdk::Client::new(homeserver).unwrap(),
origin_room: secure_room!(), room: secure_room!(),
active_room: secure_room!(),
username: "myusername", username: "myusername",
message_body: "!notacommand", message_body: "!notacommand",
}; };
@ -219,9 +218,8 @@ mod tests {
let ctx = Context { let ctx = Context {
account: crate::models::Account::default(), account: crate::models::Account::default(),
db: db, db: db,
matrix_client: matrix_sdk::Client::new(homeserver).await.unwrap(), matrix_client: &matrix_sdk::Client::new(homeserver).unwrap(),
origin_room: secure_room!(), room: secure_room!(),
active_room: secure_room!(),
username: "myusername", username: "myusername",
message_body: "!notacommand", message_body: "!notacommand",
}; };
@ -242,9 +240,8 @@ mod tests {
let ctx = Context { let ctx = Context {
account: crate::models::Account::default(), account: crate::models::Account::default(),
db: db, db: db,
matrix_client: matrix_sdk::Client::new(homeserver).await.unwrap(), matrix_client: &matrix_sdk::Client::new(homeserver).unwrap(),
origin_room: dummy_room!(), room: dummy_room!(),
active_room: dummy_room!(),
username: "myusername", username: "myusername",
message_body: "!notacommand", message_body: "!notacommand",
}; };
@ -265,9 +262,8 @@ mod tests {
let ctx = Context { let ctx = Context {
account: crate::models::Account::default(), account: crate::models::Account::default(),
db: db, db: db,
matrix_client: matrix_sdk::Client::new(homeserver).await.unwrap(), matrix_client: &matrix_sdk::Client::new(homeserver).unwrap(),
origin_room: dummy_room!(), room: dummy_room!(),
active_room: dummy_room!(),
username: "myusername", username: "myusername",
message_body: "!notacommand", message_body: "!notacommand",
}; };
@ -288,9 +284,8 @@ mod tests {
let ctx = Context { let ctx = Context {
account: crate::models::Account::default(), account: crate::models::Account::default(),
db: db, db: db,
matrix_client: matrix_sdk::Client::new(homeserver).await.unwrap(), matrix_client: &matrix_sdk::Client::new(homeserver).unwrap(),
origin_room: dummy_room!(), room: dummy_room!(),
active_room: dummy_room!(),
username: "myusername", username: "myusername",
message_body: "!notacommand", message_body: "!notacommand",
}; };

View File

@ -65,7 +65,7 @@ fn split_command(input: &str) -> Result<(String, String), CommandParsingError> {
/// boilerplate. /// boilerplate.
macro_rules! convert_to { macro_rules! convert_to {
($type:ident, $input: expr) => { ($type:ident, $input: expr) => {
$type::try_from($input).map(|cmd| Box::new(cmd) as Box<dyn Command>) $type::try_from($input).map(Into::into)
}; };
} }
@ -81,7 +81,7 @@ pub fn parse_command(input: &str) -> Result<Box<dyn Command>, BotError> {
"del" => convert_to!(DeleteVariableCommand, cmd_input), "del" => convert_to!(DeleteVariableCommand, cmd_input),
"r" | "roll" => convert_to!(RollCommand, cmd_input), "r" | "roll" => convert_to!(RollCommand, cmd_input),
"rp" | "pool" => convert_to!(PoolRollCommand, cmd_input), "rp" | "pool" => convert_to!(PoolRollCommand, cmd_input),
"chance" => PoolRollCommand::chance_die().map(|cmd| Box::new(cmd) as Box<dyn Command>), "chance" => PoolRollCommand::chance_die().map(Into::into),
"cthroll" => convert_to!(CthRoll, cmd_input), "cthroll" => convert_to!(CthRoll, cmd_input),
"cthadv" | "ctharoll" => convert_to!(CthAdvanceRoll, cmd_input), "cthadv" | "ctharoll" => convert_to!(CthAdvanceRoll, cmd_input),
"help" => convert_to!(HelpCommand, cmd_input), "help" => convert_to!(HelpCommand, cmd_input),
@ -221,9 +221,9 @@ mod tests {
#[test] #[test]
fn pool_whitespace_test() { fn pool_whitespace_test() {
parse_command("!pool 8 ns3 ").expect("was error"); parse_command("!pool ns3:8 ").expect("was error");
parse_command(" !pool 8 ns3").expect("was error"); parse_command(" !pool ns3:8").expect("was error");
parse_command(" !pool 8 ns3 ").expect("was error"); parse_command(" !pool ns3:8 ").expect("was error");
} }
#[test] #[test]

View File

@ -1,12 +1,11 @@
use super::{Command, Execution, ExecutionResult}; use super::{Command, Execution, ExecutionResult};
use crate::context::Context; use crate::context::Context;
use crate::db::Users;
use crate::error::BotError; use crate::error::BotError;
use crate::matrix; use crate::matrix;
use async_trait::async_trait; use async_trait::async_trait;
use fuse_rust::{Fuse, FuseProperty, Fuseable}; use fuse_rust::{Fuse, FuseProperty, Fuseable};
use futures::stream::{self, StreamExt, TryStreamExt}; use futures::stream::{self, StreamExt, TryStreamExt};
use matrix_sdk::{ruma::OwnedUserId, Client}; use matrix_sdk::{identifiers::UserId, Client};
use std::convert::TryFrom; use std::convert::TryFrom;
/// Holds matrix room ID and display name as strings, for use with /// Holds matrix room ID and display name as strings, for use with
@ -21,17 +20,17 @@ struct RoomNameAndId {
/// searching room display names directly. /// searching room display names directly.
impl Fuseable for RoomNameAndId { impl Fuseable for RoomNameAndId {
fn properties(&self) -> Vec<FuseProperty> { fn properties(&self) -> Vec<FuseProperty> {
vec![FuseProperty { return vec![FuseProperty {
value: String::from("name"), value: String::from("name"),
weight: 1.0, weight: 1.0,
}] }];
} }
fn lookup(&self, key: &str) -> Option<&str> { fn lookup(&self, key: &str) -> Option<&str> {
match key { return match key {
"name" => Some(&self.name), "name" => Some(&self.name),
_ => None, _ => None,
} };
} }
} }
@ -62,29 +61,29 @@ async fn get_rooms_for_user(
client: &Client, client: &Client,
user_id: &str, user_id: &str,
) -> Result<Vec<RoomNameAndId>, BotError> { ) -> Result<Vec<RoomNameAndId>, BotError> {
let user_id = OwnedUserId::try_from(user_id)?; let user_id = UserId::try_from(user_id)?;
let rooms_for_user = matrix::get_rooms_for_user(client, &user_id).await?; let rooms_for_user = matrix::get_rooms_for_user(client, &user_id).await?;
let mut rooms_for_user: Vec<RoomNameAndId> = stream::iter(rooms_for_user) let rooms_for_user: Vec<RoomNameAndId> = stream::iter(rooms_for_user)
.filter_map(|room| async move { .filter_map(|room| async move {
Some(room.display_name().await.map(|room_name| RoomNameAndId { Some(room.display_name().await.map(|room_name| RoomNameAndId {
id: room.room_id().to_string(), id: room.room_id().to_string(),
name: room_name.to_string(), name: room_name,
})) }))
}) })
.try_collect() .try_collect()
.await?; .await?;
//Alphabetically descending, symbols first, ignore case.
let sort = |r1: &RoomNameAndId, r2: &RoomNameAndId| {
r1.name.to_lowercase().cmp(&r2.name.to_lowercase())
};
rooms_for_user.sort_by(sort);
Ok(rooms_for_user) Ok(rooms_for_user)
} }
pub struct ListRoomsCommand; pub struct ListRoomsCommand;
impl From<ListRoomsCommand> for Box<dyn Command> {
fn from(cmd: ListRoomsCommand) -> Self {
Box::new(cmd)
}
}
impl TryFrom<String> for ListRoomsCommand { impl TryFrom<String> for ListRoomsCommand {
type Error = BotError; type Error = BotError;
@ -104,7 +103,7 @@ impl Command for ListRoomsCommand {
} }
async fn execute(&self, ctx: &Context<'_>) -> ExecutionResult { async fn execute(&self, ctx: &Context<'_>) -> ExecutionResult {
let rooms_for_user: Vec<String> = get_rooms_for_user(&ctx.matrix_client, ctx.username) let rooms_for_user: Vec<String> = get_rooms_for_user(ctx.matrix_client, ctx.username)
.await .await
.map(|rooms| { .map(|rooms| {
rooms rooms
@ -120,6 +119,12 @@ impl Command for ListRoomsCommand {
pub struct SetRoomCommand(String); pub struct SetRoomCommand(String);
impl From<SetRoomCommand> for Box<dyn Command> {
fn from(cmd: SetRoomCommand) -> Self {
Box::new(cmd)
}
}
impl TryFrom<String> for SetRoomCommand { impl TryFrom<String> for SetRoomCommand {
type Error = BotError; type Error = BotError;
@ -139,22 +144,10 @@ impl Command for SetRoomCommand {
} }
async fn execute(&self, ctx: &Context<'_>) -> ExecutionResult { async fn execute(&self, ctx: &Context<'_>) -> ExecutionResult {
if !ctx.account.is_registered() { let rooms_for_user = get_rooms_for_user(ctx.matrix_client, ctx.username).await?;
return Err(BotError::AccountDoesNotExist);
}
let rooms_for_user = get_rooms_for_user(&ctx.matrix_client, ctx.username).await?;
let room = search_for_room(&rooms_for_user, &self.0); let room = search_for_room(&rooms_for_user, &self.0);
if let Some(room) = room { if let Some(room) = room {
let mut new_user = ctx
.account
.registered_user()
.cloned()
.ok_or(BotError::AccountDoesNotExist)?;
new_user.active_room = Some(room.id.clone());
ctx.db.upsert_user(&new_user).await?;
Execution::success(format!(r#"Active room set to "{}""#, room.name)) Execution::success(format!(r#"Active room set to "{}""#, room.name))
} else { } else {
Err(BotError::RoomDoesNotExist) Err(BotError::RoomDoesNotExist)

View File

@ -8,6 +8,12 @@ use std::convert::TryFrom;
pub struct GetAllVariablesCommand; pub struct GetAllVariablesCommand;
impl From<GetAllVariablesCommand> for Box<dyn Command> {
fn from(cmd: GetAllVariablesCommand) -> Self {
Box::new(cmd)
}
}
impl TryFrom<String> for GetAllVariablesCommand { impl TryFrom<String> for GetAllVariablesCommand {
type Error = BotError; type Error = BotError;
@ -29,7 +35,7 @@ impl Command for GetAllVariablesCommand {
async fn execute(&self, ctx: &Context<'_>) -> ExecutionResult { async fn execute(&self, ctx: &Context<'_>) -> ExecutionResult {
let variables = ctx let variables = ctx
.db .db
.get_user_variables(&ctx.username, ctx.active_room_id().as_str()) .get_user_variables(&ctx.username, ctx.room_id().as_str())
.await?; .await?;
let mut variable_list: Vec<String> = variables let mut variable_list: Vec<String> = variables
@ -51,6 +57,12 @@ impl Command for GetAllVariablesCommand {
pub struct GetVariableCommand(pub String); pub struct GetVariableCommand(pub String);
impl From<GetVariableCommand> for Box<dyn Command> {
fn from(cmd: GetVariableCommand) -> Self {
Box::new(cmd)
}
}
impl TryFrom<String> for GetVariableCommand { impl TryFrom<String> for GetVariableCommand {
type Error = BotError; type Error = BotError;
@ -73,7 +85,7 @@ impl Command for GetVariableCommand {
let name = &self.0; let name = &self.0;
let result = ctx let result = ctx
.db .db
.get_user_variable(&ctx.username, ctx.active_room_id().as_str(), name) .get_user_variable(&ctx.username, ctx.room_id().as_str(), name)
.await; .await;
let value = match result { let value = match result {
@ -89,6 +101,12 @@ impl Command for GetVariableCommand {
pub struct SetVariableCommand(pub String, pub i32); pub struct SetVariableCommand(pub String, pub i32);
impl From<SetVariableCommand> for Box<dyn Command> {
fn from(cmd: SetVariableCommand) -> Self {
Box::new(cmd)
}
}
impl TryFrom<String> for SetVariableCommand { impl TryFrom<String> for SetVariableCommand {
type Error = BotError; type Error = BotError;
@ -113,7 +131,7 @@ impl Command for SetVariableCommand {
let value = self.1; let value = self.1;
ctx.db ctx.db
.set_user_variable(&ctx.username, ctx.active_room_id().as_str(), name, value) .set_user_variable(&ctx.username, ctx.room_id().as_str(), name, value)
.await?; .await?;
let content = format!("{} = {}", name, value); let content = format!("{} = {}", name, value);
@ -124,6 +142,12 @@ impl Command for SetVariableCommand {
pub struct DeleteVariableCommand(pub String); pub struct DeleteVariableCommand(pub String);
impl From<DeleteVariableCommand> for Box<dyn Command> {
fn from(cmd: DeleteVariableCommand) -> Self {
Box::new(cmd)
}
}
impl TryFrom<String> for DeleteVariableCommand { impl TryFrom<String> for DeleteVariableCommand {
type Error = BotError; type Error = BotError;
@ -146,7 +170,7 @@ impl Command for DeleteVariableCommand {
let name = &self.0; let name = &self.0;
let result = ctx let result = ctx
.db .db
.delete_user_variable(&ctx.username, ctx.active_room_id().as_str(), name) .delete_user_variable(&ctx.username, ctx.room_id().as_str(), name)
.await; .await;
let value = match result { let value = match result {

View File

@ -4,6 +4,10 @@ use std::fs;
use std::path::PathBuf; use std::path::PathBuf;
use thiserror::Error; use thiserror::Error;
/// Shortcut to defining db migration versions. Will probably
/// eventually be moved to a config file.
const MIGRATION_VERSION: u32 = 5;
#[derive(Error, Debug)] #[derive(Error, Debug)]
pub enum ConfigError { pub enum ConfigError {
#[error("i/o error: {0}")] #[error("i/o error: {0}")]
@ -49,19 +53,10 @@ fn db_path_from_env() -> String {
} }
/// The "bot" section of the config file, for bot settings. /// The "bot" section of the config file, for bot settings.
#[derive(Serialize, Deserialize, Clone, Debug, Default)] #[derive(Serialize, Deserialize, Clone, Debug)]
struct BotConfig { struct BotConfig {
/// How far back from current time should we process a message? /// How far back from current time should we process a message?
oldest_message_age: Option<u64>, oldest_message_age: Option<u64>,
/// What address and port to run the RPC service on. If not
/// specified, RPC will not be enabled.
rpc_addr: Option<String>,
/// The shared secret key between the bot and any RPC clients that
/// want to connect to it. The RPC server will reject any clients
/// that don't present the shared key.
rpc_key: Option<String>,
} }
/// The "database" section of the config file. /// The "database" section of the config file.
@ -89,18 +84,6 @@ impl BotConfig {
self.oldest_message_age self.oldest_message_age
.unwrap_or(DEFAULT_OLDEST_MESSAGE_AGE) .unwrap_or(DEFAULT_OLDEST_MESSAGE_AGE)
} }
#[inline]
#[must_use]
fn rpc_addr(&self) -> Option<String> {
self.rpc_addr.clone()
}
#[inline]
#[must_use]
fn rpc_key(&self) -> Option<String> {
self.rpc_key.clone()
}
} }
/// Represents the toml config file for the dicebot. The sections of /// Represents the toml config file for the dicebot. The sections of
@ -145,6 +128,15 @@ impl Config {
.unwrap_or_else(|| db_path_from_env()) .unwrap_or_else(|| db_path_from_env())
} }
/// The current migration version we expect of the database. If
/// this number is higher than the one in the database, we will
/// execute migrations to update the data.
#[inline]
#[must_use]
pub fn migration_version(&self) -> u32 {
MIGRATION_VERSION
}
/// Figure out the allowed oldest message age, in seconds. This will /// Figure out the allowed oldest message age, in seconds. This will
/// be the defined oldest message age in the bot config, if the bot /// be the defined oldest message age in the bot config, if the bot
/// configuration and associated "oldest_message_age" setting are /// configuration and associated "oldest_message_age" setting are
@ -158,18 +150,6 @@ impl Config {
.map(|bc| bc.oldest_message_age()) .map(|bc| bc.oldest_message_age())
.unwrap_or(DEFAULT_OLDEST_MESSAGE_AGE) .unwrap_or(DEFAULT_OLDEST_MESSAGE_AGE)
} }
#[inline]
#[must_use]
pub fn rpc_addr(&self) -> Option<String> {
self.bot.as_ref().and_then(|bc| bc.rpc_addr())
}
#[inline]
#[must_use]
pub fn rpc_key(&self) -> Option<String> {
self.bot.as_ref().and_then(|bc| bc.rpc_key())
}
} }
#[cfg(test)] #[cfg(test)]
@ -189,7 +169,6 @@ mod tests {
}), }),
bot: Some(BotConfig { bot: Some(BotConfig {
oldest_message_age: None, oldest_message_age: None,
..Default::default()
}), }),
}; };

View File

@ -1,8 +1,8 @@
use crate::db::sqlite::Database; use crate::db::sqlite::Database;
use crate::error::BotError; use crate::error::BotError;
use crate::models::Account; use crate::models::Account;
use matrix_sdk::identifiers::{RoomId, UserId};
use matrix_sdk::room::Joined; use matrix_sdk::room::Joined;
use matrix_sdk::ruma::{RoomId, UserId};
use matrix_sdk::Client; use matrix_sdk::Client;
use std::convert::TryFrom; use std::convert::TryFrom;
@ -11,25 +11,20 @@ use std::convert::TryFrom;
#[derive(Clone)] #[derive(Clone)]
pub struct Context<'a> { pub struct Context<'a> {
pub db: Database, pub db: Database,
pub matrix_client: Client, pub matrix_client: &'a Client,
pub origin_room: RoomContext<'a>, pub room: RoomContext<'a>,
pub active_room: RoomContext<'a>,
pub username: &'a str, pub username: &'a str,
pub message_body: &'a str, pub message_body: &'a str,
pub account: Account, pub account: Account,
} }
impl Context<'_> { impl Context<'_> {
pub fn active_room_id(&self) -> &RoomId {
self.active_room.id
}
pub fn room_id(&self) -> &RoomId { pub fn room_id(&self) -> &RoomId {
self.origin_room.id self.room.id
} }
pub fn is_secure(&self) -> bool { pub fn is_secure(&self) -> bool {
self.origin_room.secure self.room.secure
} }
} }
@ -43,22 +38,15 @@ pub struct RoomContext<'a> {
impl RoomContext<'_> { impl RoomContext<'_> {
pub async fn new_with_name<'a>( pub async fn new_with_name<'a>(
room: &'a Joined, room: &'a Joined,
display_name: String,
sending_user: &str, sending_user: &str,
) -> Result<RoomContext<'a>, BotError> { ) -> Result<RoomContext<'a>, BotError> {
// TODO is_direct is a hack; the bot should set eligible rooms // TODO is_direct is a hack; should set rooms to Direct
// to Direct Message upon joining, if other contact has // Message upon joining, if other contact has requested it.
// requested it. Waiting on SDK support. // Waiting on SDK support.
let display_name = let sending_user = UserId::try_from(sending_user)?;
room let user_in_room = room.get_member(&sending_user).await.ok().is_some();
.display_name() let is_direct = room.joined_members().await?.len() == 2;
.await
.ok()
.map(|d| d.to_string())
.unwrap_or_default();
let sending_user = <&UserId>::try_from(sending_user)?;
let user_in_room = room.get_member(sending_user).await.ok().is_some();
let is_direct = room.active_members().await?.len() == 2;
Ok(RoomContext { Ok(RoomContext {
id: room.room_id(), id: room.room_id(),
@ -69,8 +57,17 @@ impl RoomContext<'_> {
pub async fn new<'a>( pub async fn new<'a>(
room: &'a Joined, room: &'a Joined,
sending_user: &'a str, sending_user: &str,
) -> Result<RoomContext<'a>, BotError> { ) -> Result<RoomContext<'a>, BotError> {
Self::new_with_name(room, sending_user).await Self::new_with_name(
&room,
room.display_name()
.await
.ok()
.unwrap_or_default()
.to_string(),
sending_user,
)
.await
} }
} }

View File

@ -270,7 +270,7 @@ macro_rules! is_variable {
element: Element::Variable(_), element: Element::Variable(_),
.. ..
} }
) );
}; };
} }
@ -380,12 +380,7 @@ async fn update_skill(ctx: &Context<'_>, variable: &str, value: u32) -> Result<(
use std::convert::TryInto; use std::convert::TryInto;
let value: i32 = value.try_into()?; let value: i32 = value.try_into()?;
ctx.db ctx.db
.set_user_variable( .set_user_variable(&ctx.username, &ctx.room_id().as_str(), variable, value)
&ctx.username,
&ctx.active_room_id().as_str(),
variable,
value,
)
.await?; .await?;
Ok(()) Ok(())
} }
@ -427,12 +422,11 @@ mod tests {
use crate::db::sqlite::Database; use crate::db::sqlite::Database;
use crate::parser::dice::{Amount, Element, Operator}; use crate::parser::dice::{Amount, Element, Operator};
use url::Url; use url::Url;
use matrix_sdk::ruma::room_id;
macro_rules! dummy_room { macro_rules! dummy_room {
() => { () => {
crate::context::RoomContext { crate::context::RoomContext {
id: &room_id!("!fakeroomid:example.com"), id: &matrix_sdk::identifiers::room_id!("!fakeroomid:example.com"),
display_name: "displayname".to_owned(), display_name: "displayname".to_owned(),
secure: false, secure: false,
} }
@ -512,9 +506,8 @@ mod tests {
let ctx = Context { let ctx = Context {
account: crate::models::Account::default(), account: crate::models::Account::default(),
db: db, db: db,
matrix_client: matrix_sdk::Client::new(homeserver).await.unwrap(), matrix_client: &matrix_sdk::Client::new(homeserver).unwrap(),
origin_room: dummy_room!(), room: dummy_room!(),
active_room: dummy_room!(),
username: "username", username: "username",
message_body: "message", message_body: "message",
}; };
@ -550,9 +543,8 @@ mod tests {
let ctx = Context { let ctx = Context {
account: crate::models::Account::default(), account: crate::models::Account::default(),
db: db, db: db,
matrix_client: matrix_sdk::Client::new(homeserver).await.unwrap(), matrix_client: &matrix_sdk::Client::new(homeserver).unwrap(),
origin_room: dummy_room!(), room: dummy_room!(),
active_room: dummy_room!(),
username: "username", username: "username",
message_body: "message", message_body: "message",
}; };
@ -588,9 +580,8 @@ mod tests {
let ctx = Context { let ctx = Context {
account: crate::models::Account::default(), account: crate::models::Account::default(),
db: db, db: db,
matrix_client: matrix_sdk::Client::new(homeserver).await.unwrap(), matrix_client: &matrix_sdk::Client::new(homeserver).unwrap(),
origin_room: dummy_room!(), room: dummy_room!(),
active_room: dummy_room!(),
username: "username", username: "username",
message_body: "message", message_body: "message",
}; };

View File

@ -4,13 +4,16 @@ use crate::parser::dice::DiceParsingError;
//TOOD convert these to use parse_amounts from the common dice code. //TOOD convert these to use parse_amounts from the common dice code.
fn parse_modifier(input: &str) -> Result<DiceRollModifier, DiceParsingError> { fn parse_modifier(input: &str) -> Result<DiceRollModifier, DiceParsingError> {
match input.trim() { if input.ends_with("bb") {
"bb" => Ok(DiceRollModifier::TwoBonus), Ok(DiceRollModifier::TwoBonus)
"b" => Ok(DiceRollModifier::OneBonus), } else if input.ends_with("b") {
"pp" => Ok(DiceRollModifier::TwoPenalty), Ok(DiceRollModifier::OneBonus)
"p" => Ok(DiceRollModifier::OnePenalty), } else if input.ends_with("pp") {
"" => Ok(DiceRollModifier::Normal), Ok(DiceRollModifier::TwoPenalty)
_ => Err(DiceParsingError::InvalidModifiers), } else if input.ends_with("p") {
Ok(DiceRollModifier::OnePenalty)
} else {
Ok(DiceRollModifier::Normal)
} }
} }
@ -18,70 +21,32 @@ fn parse_modifier(input: &str) -> Result<DiceRollModifier, DiceParsingError> {
//Split based on :, send first part to parse_modifier. //Split based on :, send first part to parse_modifier.
//Send second part to parse_amounts //Send second part to parse_amounts
pub fn parse_regular_roll(input: &str) -> Result<DiceRoll, DiceParsingError> { pub fn parse_regular_roll(input: &str) -> Result<DiceRoll, DiceParsingError> {
let (amount, modifiers_str) = crate::parser::dice::parse_single_amount(input)?; let input: Vec<&str> = input.trim().split(":").collect();
let (modifiers_str, amounts_str) = match input[..] {
[amounts] => Ok(("", amounts)),
[modifiers, amounts] => Ok((modifiers, amounts)),
_ => Err(DiceParsingError::UnconsumedInput),
}?;
let modifier = parse_modifier(modifiers_str)?; let modifier = parse_modifier(modifiers_str)?;
let amount = crate::parser::dice::parse_single_amount(amounts_str)?;
Ok(DiceRoll { modifier, amount }) Ok(DiceRoll { modifier, amount })
} }
pub fn parse_advancement_roll(input: &str) -> Result<AdvancementRoll, DiceParsingError> { pub fn parse_advancement_roll(input: &str) -> Result<AdvancementRoll, DiceParsingError> {
let input = input.trim(); let input = input.trim();
let (amounts, unconsumed_input) = crate::parser::dice::parse_single_amount(input)?; let amounts = crate::parser::dice::parse_single_amount(input)?;
if unconsumed_input.len() == 0 { Ok(AdvancementRoll {
Ok(AdvancementRoll { existing_skill: amounts,
existing_skill: amounts, })
})
} else {
Err(DiceParsingError::InvalidAmount)
}
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use crate::parser::dice::{Amount, DiceParsingError, Element, Operator}; use crate::parser::dice::{Amount, Element, Operator};
#[test]
fn parse_modifier_rejects_bad_value() {
let modifier = parse_modifier("qqq");
assert!(matches!(modifier, Err(DiceParsingError::InvalidModifiers)))
}
#[test]
fn parse_modifier_accepts_one_bonus() {
let modifier = parse_modifier("b");
assert!(matches!(modifier, Ok(DiceRollModifier::OneBonus)))
}
#[test]
fn parse_modifier_accepts_two_bonus() {
let modifier = parse_modifier("bb");
assert!(matches!(modifier, Ok(DiceRollModifier::TwoBonus)))
}
#[test]
fn parse_modifier_accepts_two_penalty() {
let modifier = parse_modifier("pp");
assert!(matches!(modifier, Ok(DiceRollModifier::TwoPenalty)))
}
#[test]
fn parse_modifier_accepts_one_penalty() {
let modifier = parse_modifier("p");
assert!(matches!(modifier, Ok(DiceRollModifier::OnePenalty)))
}
#[test]
fn parse_modifier_accepts_normal() {
let modifier = parse_modifier("");
assert!(matches!(modifier, Ok(DiceRollModifier::Normal)))
}
#[test]
fn parse_modifier_accepts_normal_unaffected_by_whitespace() {
let modifier = parse_modifier(" ");
assert!(matches!(modifier, Ok(DiceRollModifier::Normal)))
}
#[test] #[test]
fn regular_roll_accepts_single_number() { fn regular_roll_accepts_single_number() {
@ -107,7 +72,7 @@ mod tests {
#[test] #[test]
fn regular_roll_accepts_two_bonus() { fn regular_roll_accepts_two_bonus() {
let result = parse_regular_roll("60 bb"); let result = parse_regular_roll("bb:60");
assert!(result.is_ok()); assert!(result.is_ok());
assert_eq!( assert_eq!(
DiceRoll { DiceRoll {
@ -123,7 +88,7 @@ mod tests {
#[test] #[test]
fn regular_roll_accepts_one_bonus() { fn regular_roll_accepts_one_bonus() {
let result = parse_regular_roll("60 b"); let result = parse_regular_roll("b:60");
assert!(result.is_ok()); assert!(result.is_ok());
assert_eq!( assert_eq!(
DiceRoll { DiceRoll {
@ -139,7 +104,7 @@ mod tests {
#[test] #[test]
fn regular_roll_accepts_two_penalty() { fn regular_roll_accepts_two_penalty() {
let result = parse_regular_roll("60 pp"); let result = parse_regular_roll("pp:60");
assert!(result.is_ok()); assert!(result.is_ok());
assert_eq!( assert_eq!(
DiceRoll { DiceRoll {
@ -155,7 +120,7 @@ mod tests {
#[test] #[test]
fn regular_roll_accepts_one_penalty() { fn regular_roll_accepts_one_penalty() {
let result = parse_regular_roll("60 p"); let result = parse_regular_roll("p:60");
assert!(result.is_ok()); assert!(result.is_ok());
assert_eq!( assert_eq!(
DiceRoll { DiceRoll {
@ -175,21 +140,21 @@ mod tests {
assert!(parse_regular_roll(" 60").is_ok()); assert!(parse_regular_roll(" 60").is_ok());
assert!(parse_regular_roll(" 60 ").is_ok()); assert!(parse_regular_roll(" 60 ").is_ok());
assert!(parse_regular_roll("60bb ").is_ok()); assert!(parse_regular_roll("bb:60 ").is_ok());
assert!(parse_regular_roll(" 60 bb").is_ok()); assert!(parse_regular_roll(" bb:60").is_ok());
assert!(parse_regular_roll(" 60 bb ").is_ok()); assert!(parse_regular_roll(" bb:60 ").is_ok());
assert!(parse_regular_roll("60b ").is_ok()); assert!(parse_regular_roll("b:60 ").is_ok());
assert!(parse_regular_roll(" 60 b").is_ok()); assert!(parse_regular_roll(" b:60").is_ok());
assert!(parse_regular_roll(" 60 b ").is_ok()); assert!(parse_regular_roll(" b:60 ").is_ok());
assert!(parse_regular_roll("60pp ").is_ok()); assert!(parse_regular_roll("pp:60 ").is_ok());
assert!(parse_regular_roll(" 60 pp").is_ok()); assert!(parse_regular_roll(" pp:60").is_ok());
assert!(parse_regular_roll(" 60 pp ").is_ok()); assert!(parse_regular_roll(" pp:60 ").is_ok());
assert!(parse_regular_roll("60p ").is_ok()); assert!(parse_regular_roll("p:60 ").is_ok());
assert!(parse_regular_roll(" 60p ").is_ok()); assert!(parse_regular_roll(" p:60").is_ok());
assert!(parse_regular_roll(" 60 p ").is_ok()); assert!(parse_regular_roll(" p:60 ").is_ok());
} }
#[test] #[test]

View File

@ -0,0 +1,2 @@
use refinery::include_migration_mods;
include_migration_mods!("src/db/sqlite/migrator/migrations");

View File

@ -5,7 +5,7 @@ use sqlx::ConnectOptions;
use std::str::FromStr; use std::str::FromStr;
use thiserror::Error; use thiserror::Error;
//pub mod migrations; pub mod migrations;
#[derive(Error, Debug)] #[derive(Error, Debug)]
pub enum MigrationError { pub enum MigrationError {
@ -16,11 +16,6 @@ pub enum MigrationError {
RefineryError(#[from] refinery::Error), RefineryError(#[from] refinery::Error),
} }
mod embedded {
use refinery::embed_migrations;
embed_migrations!("src/db/sqlite/migrator/migrations");
}
/// Run database migrations against the sqlite database. /// Run database migrations against the sqlite database.
pub async fn migrate(db_path: &str) -> Result<(), MigrationError> { pub async fn migrate(db_path: &str) -> Result<(), MigrationError> {
//Create database if missing. //Create database if missing.
@ -33,6 +28,6 @@ pub async fn migrate(db_path: &str) -> Result<(), MigrationError> {
let mut conn = Config::new(ConfigDbType::Sqlite).set_db_path(db_path); let mut conn = Config::new(ConfigDbType::Sqlite).set_db_path(db_path);
info!("Running migrations"); info!("Running migrations");
embedded::migrations::runner().run(&mut conn)?; migrations::runner().run(&mut conn)?;
Ok(()) Ok(())
} }

View File

@ -53,41 +53,34 @@ impl Rooms for Database {
mod tests { mod tests {
use crate::db::sqlite::Database; use crate::db::sqlite::Database;
use crate::db::Rooms; use crate::db::Rooms;
use std::future::Future;
async fn with_db<Fut>(f: impl FnOnce(Database) -> Fut) async fn create_db() -> Database {
where
Fut: Future<Output = ()>,
{
let db_path = tempfile::NamedTempFile::new_in(".").unwrap(); let db_path = tempfile::NamedTempFile::new_in(".").unwrap();
crate::db::sqlite::migrator::migrate(db_path.path().to_str().unwrap()) crate::db::sqlite::migrator::migrate(db_path.path().to_str().unwrap())
.await .await
.unwrap(); .unwrap();
let db = Database::new(db_path.path().to_str().unwrap()) Database::new(db_path.path().to_str().unwrap())
.await .await
.unwrap(); .unwrap()
f(db).await;
} }
#[tokio::test(flavor = "multi_thread", worker_threads = 1)] #[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn should_process_test() { async fn should_process_test() {
with_db(|db| async move { let db = create_db().await;
let first_check = db
.should_process("myroom", "myeventid")
.await
.expect("should_process failed in first insert");
assert_eq!(first_check, true); let first_check = db
.should_process("myroom", "myeventid")
.await
.expect("should_process failed in first insert");
let second_check = db assert_eq!(first_check, true);
.should_process("myroom", "myeventid")
.await
.expect("should_process failed in first insert");
assert_eq!(second_check, false); let second_check = db
}) .should_process("myroom", "myeventid")
.await; .await
.expect("should_process failed in first insert");
assert_eq!(second_check, false);
} }
} }

View File

@ -37,64 +37,54 @@ impl DbState for Database {
mod tests { mod tests {
use crate::db::sqlite::Database; use crate::db::sqlite::Database;
use crate::db::DbState; use crate::db::DbState;
use std::future::Future;
async fn with_db<Fut>(f: impl FnOnce(Database) -> Fut) async fn create_db() -> Database {
where
Fut: Future<Output = ()>,
{
let db_path = tempfile::NamedTempFile::new_in(".").unwrap(); let db_path = tempfile::NamedTempFile::new_in(".").unwrap();
crate::db::sqlite::migrator::migrate(db_path.path().to_str().unwrap()) crate::db::sqlite::migrator::migrate(db_path.path().to_str().unwrap())
.await .await
.unwrap(); .unwrap();
let db = Database::new(db_path.path().to_str().unwrap()) Database::new(db_path.path().to_str().unwrap())
.await .await
.unwrap(); .unwrap()
f(db).await;
} }
#[tokio::test(flavor = "multi_thread", worker_threads = 1)] #[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn set_and_get_device_id() { async fn set_and_get_device_id() {
with_db(|db| async move { let db = create_db().await;
db.set_device_id("device_id")
.await
.expect("Could not set device ID");
let device_id = db.get_device_id().await.expect("Could not get device ID"); db.set_device_id("device_id")
.await
.expect("Could not set device ID");
assert!(device_id.is_some()); let device_id = db.get_device_id().await.expect("Could not get device ID");
assert_eq!(device_id.unwrap(), "device_id");
}) assert!(device_id.is_some());
.await; assert_eq!(device_id.unwrap(), "device_id");
} }
#[tokio::test(flavor = "multi_thread", worker_threads = 1)] #[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn no_device_id_set_returns_none() { async fn no_device_id_set_returns_none() {
with_db(|db| async move { let db = create_db().await;
let device_id = db.get_device_id().await.expect("Could not get device ID"); let device_id = db.get_device_id().await.expect("Could not get device ID");
assert!(device_id.is_none()); assert!(device_id.is_none());
})
.await;
} }
#[tokio::test(flavor = "multi_thread", worker_threads = 1)] #[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn can_update_device_id() { async fn can_update_device_id() {
with_db(|db| async move { let db = create_db().await;
db.set_device_id("device_id")
.await
.expect("Could not set device ID");
db.set_device_id("device_id2") db.set_device_id("device_id")
.await .await
.expect("Could not set device ID"); .expect("Could not set device ID");
let device_id = db.get_device_id().await.expect("Could not get device ID"); db.set_device_id("device_id2")
.await
.expect("Could not set device ID");
assert!(device_id.is_some()); let device_id = db.get_device_id().await.expect("Could not get device ID");
assert_eq!(device_id.unwrap(), "device_id2");
}) assert!(device_id.is_some());
.await; assert_eq!(device_id.unwrap(), "device_id2");
} }
} }

341
src/db/sqlite/users.rs Normal file
View File

@ -0,0 +1,341 @@
use super::Database;
use crate::db::{errors::DataError, Users};
use crate::error::BotError;
use crate::models::User;
use async_trait::async_trait;
#[async_trait]
impl Users for Database {
async fn upsert_user(&self, user: &User) -> Result<(), DataError> {
let mut tx = self.conn.begin().await?;
sqlx::query!(
r#"INSERT INTO accounts (user_id, password, account_status)
VALUES (?, ?, ?)
ON CONFLICT(user_id) DO
UPDATE SET password = ?, account_status = ?"#,
user.username,
user.password,
user.account_status,
user.password,
user.account_status
)
.execute(&mut tx)
.await?;
sqlx::query!(
r#"INSERT INTO user_state (user_id, active_room)
VALUES (?, ?)
ON CONFLICT(user_id) DO
UPDATE SET active_room = ?"#,
user.username,
user.active_room,
user.active_room
)
.execute(&mut tx)
.await?;
tx.commit().await?;
Ok(())
}
async fn delete_user(&self, username: &str) -> Result<(), DataError> {
let mut tx = self.conn.begin().await?;
sqlx::query!(r#"DELETE FROM accounts WHERE user_id = ?"#, username)
.execute(&mut tx)
.await?;
sqlx::query!(r#"DELETE FROM user_state WHERE user_id = ?"#, username)
.execute(&mut tx)
.await?;
tx.commit().await?;
Ok(())
}
async fn get_user(&self, username: &str) -> Result<Option<User>, DataError> {
// Should be query_as! macro, but the left join breaks it with a
// non existing error message.
let user_row: Option<User> = sqlx::query_as(
r#"SELECT
a.user_id as "username",
a.password,
s.active_room,
COALESCE(a.account_status, 'not_registered') as "account_status"
FROM accounts a
LEFT JOIN user_state s on a.user_id = s.user_id
WHERE a.user_id = ?"#,
)
.bind(username)
.fetch_optional(&self.conn)
.await?;
Ok(user_row)
}
async fn authenticate_user(
&self,
username: &str,
raw_password: &str,
) -> Result<Option<User>, BotError> {
let user = self.get_user(username).await?;
Ok(user.filter(|u| u.verify_password(raw_password)))
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::db::sqlite::Database;
use crate::db::Users;
use crate::models::AccountStatus;
async fn create_db() -> Database {
let db_path = tempfile::NamedTempFile::new_in(".").unwrap();
crate::db::sqlite::migrator::migrate(db_path.path().to_str().unwrap())
.await
.unwrap();
Database::new(db_path.path().to_str().unwrap())
.await
.unwrap()
}
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn create_and_get_full_user_test() {
let db = create_db().await;
let insert_result = db
.upsert_user(&User {
username: "myuser".to_string(),
password: Some("abc".to_string()),
account_status: AccountStatus::Registered,
active_room: Some("myroom".to_string()),
})
.await;
assert!(insert_result.is_ok());
let user = db
.get_user("myuser")
.await
.expect("User retrieval query failed");
assert!(user.is_some());
let user = user.unwrap();
assert_eq!(user.username, "myuser");
assert_eq!(user.password, Some("abc".to_string()));
assert_eq!(user.account_status, AccountStatus::Registered);
assert_eq!(user.active_room, Some("myroom".to_string()));
}
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn can_get_user_with_no_state_record() {
let db = create_db().await;
let insert_result = db
.upsert_user(&User {
username: "myuser".to_string(),
password: Some("abc".to_string()),
account_status: AccountStatus::AwaitingActivation,
active_room: Some("myroom".to_string()),
})
.await;
assert!(insert_result.is_ok());
sqlx::query("DELETE FROM user_state")
.execute(&db.conn)
.await
.expect("Could not delete from user_state table.");
let user = db
.get_user("myuser")
.await
.expect("User retrieval query failed");
assert!(user.is_some());
let user = user.unwrap();
assert_eq!(user.username, "myuser");
assert_eq!(user.password, Some("abc".to_string()));
assert_eq!(user.account_status, AccountStatus::AwaitingActivation);
//These should be default values because the state record is missing.
assert_eq!(user.active_room, None);
}
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn can_insert_without_password() {
let db = create_db().await;
let insert_result = db
.upsert_user(&User {
username: "myuser".to_string(),
password: None,
..Default::default()
})
.await;
assert!(insert_result.is_ok());
let user = db
.get_user("myuser")
.await
.expect("User retrieval query failed");
assert!(user.is_some());
let user = user.unwrap();
assert_eq!(user.username, "myuser");
assert_eq!(user.password, None);
}
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn can_insert_without_active_room() {
let db = create_db().await;
let insert_result = db
.upsert_user(&User {
username: "myuser".to_string(),
active_room: None,
..Default::default()
})
.await;
assert!(insert_result.is_ok());
let user = db
.get_user("myuser")
.await
.expect("User retrieval query failed");
assert!(user.is_some());
let user = user.unwrap();
assert_eq!(user.username, "myuser");
assert_eq!(user.active_room, None);
}
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn can_update_user() {
let db = create_db().await;
let insert_result1 = db
.upsert_user(&User {
username: "myuser".to_string(),
password: Some("abc".to_string()),
..Default::default()
})
.await;
assert!(insert_result1.is_ok());
let insert_result2 = db
.upsert_user(&User {
username: "myuser".to_string(),
password: Some("123".to_string()),
active_room: Some("room".to_string()),
account_status: AccountStatus::AwaitingActivation,
})
.await;
assert!(insert_result2.is_ok());
let user = db
.get_user("myuser")
.await
.expect("User retrieval query failed");
assert!(user.is_some());
let user = user.unwrap();
assert_eq!(user.username, "myuser");
//From second upsert
assert_eq!(user.password, Some("123".to_string()));
assert_eq!(user.active_room, Some("room".to_string()));
assert_eq!(user.account_status, AccountStatus::AwaitingActivation);
}
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn can_delete_user() {
let db = create_db().await;
let insert_result = db
.upsert_user(&User {
username: "myuser".to_string(),
password: Some("abc".to_string()),
..Default::default()
})
.await;
assert!(insert_result.is_ok());
db.delete_user("myuser")
.await
.expect("User deletion query failed");
let user = db
.get_user("myuser")
.await
.expect("User retrieval query failed");
assert!(user.is_none());
}
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn username_not_in_db_returns_none() {
let db = create_db().await;
let user = db
.get_user("does not exist")
.await
.expect("Get user query failure");
assert!(user.is_none());
}
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn authenticate_user_is_some_with_valid_password() {
let db = create_db().await;
let insert_result = db
.upsert_user(&User {
username: "myuser".to_string(),
password: Some(crate::logic::hash_password("abc").expect("password hash error!")),
..Default::default()
})
.await;
assert!(insert_result.is_ok());
let user = db
.authenticate_user("myuser", "abc")
.await
.expect("User retrieval query failed");
assert!(user.is_some());
let user = user.unwrap();
assert_eq!(user.username, "myuser");
}
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn authenticate_user_is_none_with_wrong_password() {
let db = create_db().await;
let insert_result = db
.upsert_user(&User {
username: "myuser".to_string(),
password: Some(crate::logic::hash_password("abc").expect("password hash error!")),
..Default::default()
})
.await;
assert!(insert_result.is_ok());
let user = db
.authenticate_user("myuser", "wrong-password")
.await
.expect("User retrieval query failed");
assert!(user.is_none());
}
}

View File

@ -102,156 +102,143 @@ mod tests {
use super::*; use super::*;
use crate::db::sqlite::Database; use crate::db::sqlite::Database;
use crate::db::Variables; use crate::db::Variables;
use std::future::Future;
async fn with_db<Fut>(f: impl FnOnce(Database) -> Fut) async fn create_db() -> Database {
where
Fut: Future<Output = ()>,
{
let db_path = tempfile::NamedTempFile::new_in(".").unwrap(); let db_path = tempfile::NamedTempFile::new_in(".").unwrap();
crate::db::sqlite::migrator::migrate(db_path.path().to_str().unwrap()) crate::db::sqlite::migrator::migrate(db_path.path().to_str().unwrap())
.await .await
.unwrap(); .unwrap();
let db = Database::new(db_path.path().to_str().unwrap()) Database::new(db_path.path().to_str().unwrap())
.await .await
.unwrap(); .unwrap()
f(db).await;
} }
#[tokio::test(flavor = "multi_thread", worker_threads = 1)] #[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn set_and_get_variable_test() { async fn set_and_get_variable_test() {
with_db(|db| async move { let db = create_db().await;
db.set_user_variable("myuser", "myroom", "myvariable", 1)
.await
.expect("Could not set variable");
let value = db db.set_user_variable("myuser", "myroom", "myvariable", 1)
.get_user_variable("myuser", "myroom", "myvariable") .await
.await .expect("Could not set variable");
.expect("Could not get variable");
assert_eq!(value, 1); let value = db
}) .get_user_variable("myuser", "myroom", "myvariable")
.await; .await
.expect("Could not get variable");
assert_eq!(value, 1);
} }
#[tokio::test(flavor = "multi_thread", worker_threads = 1)] #[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn get_missing_variable_test() { async fn get_missing_variable_test() {
with_db(|db| async move { let db = create_db().await;
let value = db.get_user_variable("myuser", "myroom", "myvariable").await;
assert!(value.is_err()); let value = db.get_user_variable("myuser", "myroom", "myvariable").await;
assert!(matches!(
value.err().unwrap(), assert!(value.is_err());
DataError::KeyDoesNotExist(_) assert!(matches!(
)); value.err().unwrap(),
}) DataError::KeyDoesNotExist(_)
.await; ));
} }
#[tokio::test(flavor = "multi_thread", worker_threads = 1)] #[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn get_other_user_variable_test() { async fn get_other_user_variable_test() {
with_db(|db| async move { let db = create_db().await;
db.set_user_variable("myuser1", "myroom", "myvariable", 1)
.await
.expect("Could not set variable");
let value = db db.set_user_variable("myuser1", "myroom", "myvariable", 1)
.get_user_variable("myuser2", "myroom", "myvariable") .await
.await; .expect("Could not set variable");
assert!(value.is_err()); let value = db
assert!(matches!( .get_user_variable("myuser2", "myroom", "myvariable")
value.err().unwrap(), .await;
DataError::KeyDoesNotExist(_)
)); assert!(value.is_err());
}) assert!(matches!(
.await; value.err().unwrap(),
DataError::KeyDoesNotExist(_)
));
} }
#[tokio::test(flavor = "multi_thread", worker_threads = 1)] #[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn count_variables_test() { async fn count_variables_test() {
with_db(|db| async move { let db = create_db().await;
for variable_name in &["var1", "var2", "var3"] {
db.set_user_variable("myuser", "myroom", variable_name, 1)
.await
.expect("Could not set variable");
}
let count = db for variable_name in &["var1", "var2", "var3"] {
.get_variable_count("myuser", "myroom") db.set_user_variable("myuser", "myroom", variable_name, 1)
.await .await
.expect("Could not get count."); .expect("Could not set variable");
}
assert_eq!(count, 3); let count = db
}) .get_variable_count("myuser", "myroom")
.await; .await
.expect("Could not get count.");
assert_eq!(count, 3);
} }
#[tokio::test(flavor = "multi_thread", worker_threads = 1)] #[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn count_variables_respects_user_id() { async fn count_variables_respects_user_id() {
with_db(|db| async move { let db = create_db().await;
for variable_name in &["var1", "var2", "var3"] {
db.set_user_variable("different-user", "myroom", variable_name, 1)
.await
.expect("Could not set variable");
}
let count = db for variable_name in &["var1", "var2", "var3"] {
.get_variable_count("myuser", "myroom") db.set_user_variable("different-user", "myroom", variable_name, 1)
.await .await
.expect("Could not get count."); .expect("Could not set variable");
}
assert_eq!(count, 0); let count = db
}) .get_variable_count("myuser", "myroom")
.await; .await
.expect("Could not get count.");
assert_eq!(count, 0);
} }
#[tokio::test(flavor = "multi_thread", worker_threads = 1)] #[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn count_variables_respects_room_id() { async fn count_variables_respects_room_id() {
with_db(|db| async move { let db = create_db().await;
for variable_name in &["var1", "var2", "var3"] {
db.set_user_variable("myuser", "different-room", variable_name, 1)
.await
.expect("Could not set variable");
}
let count = db for variable_name in &["var1", "var2", "var3"] {
.get_variable_count("myuser", "myroom") db.set_user_variable("myuser", "different-room", variable_name, 1)
.await .await
.expect("Could not get count."); .expect("Could not set variable");
}
assert_eq!(count, 0); let count = db
}) .get_variable_count("myuser", "myroom")
.await; .await
.expect("Could not get count.");
assert_eq!(count, 0);
} }
#[tokio::test(flavor = "multi_thread", worker_threads = 1)] #[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn delete_variable_test() { async fn delete_variable_test() {
with_db(|db| async move { let db = create_db().await;
for variable_name in &["var1", "var2", "var3"] {
db.set_user_variable("myuser", "myroom", variable_name, 1)
.await
.expect("Could not set variable");
}
db.delete_user_variable("myuser", "myroom", "var1") for variable_name in &["var1", "var2", "var3"] {
db.set_user_variable("myuser", "myroom", variable_name, 1)
.await .await
.expect("Could not delete variable."); .expect("Could not set variable");
}
let count = db db.delete_user_variable("myuser", "myroom", "var1")
.get_variable_count("myuser", "myroom") .await
.await .expect("Could not delete variable.");
.expect("Could not get count");
assert_eq!(count, 2); let count = db
.get_variable_count("myuser", "myroom")
.await
.expect("Could not get count");
let var1 = db.get_user_variable("myuser", "myroom", "var1").await; assert_eq!(count, 2);
assert!(var1.is_err());
assert!(matches!(var1.err().unwrap(), DataError::KeyDoesNotExist(_))); let var1 = db.get_user_variable("myuser", "myroom", "var1").await;
}) assert!(var1.is_err());
.await; assert!(matches!(var1.err().unwrap(), DataError::KeyDoesNotExist(_)));
} }
} }

View File

@ -1,10 +1,7 @@
use std::net::AddrParseError;
use crate::commands::CommandError; use crate::commands::CommandError;
use crate::config::ConfigError; use crate::config::ConfigError;
use crate::db::errors::DataError; use crate::db::errors::DataError;
use thiserror::Error; use thiserror::Error;
use tonic::metadata::errors::InvalidMetadataValue;
#[derive(Error, Debug)] #[derive(Error, Debug)]
pub enum BotError { pub enum BotError {
@ -18,12 +15,6 @@ pub enum BotError {
#[error("could not retrieve device id")] #[error("could not retrieve device id")]
NoDeviceIdFound, NoDeviceIdFound,
#[error("could not build client: {0}")]
ClientBuildError(#[from] matrix_sdk::ClientBuildError),
#[error("could not open matrix store: {0}")]
OpenStoreError(#[from] matrix_sdk::store::OpenStoreError),
#[error("command error: {0}")] #[error("command error: {0}")]
CommandError(#[from] CommandError), CommandError(#[from] CommandError),
@ -39,15 +30,15 @@ pub enum BotError {
#[error("could not parse URL")] #[error("could not parse URL")]
UrlParseError(#[from] url::ParseError), UrlParseError(#[from] url::ParseError),
#[error("could not parse ID")]
IdParseError(#[from] matrix_sdk::ruma::IdParseError),
#[error("error in matrix state store: {0}")] #[error("error in matrix state store: {0}")]
MatrixStateStoreError(#[from] matrix_sdk::StoreError), MatrixStateStoreError(#[from] matrix_sdk::StoreError),
#[error("uncategorized matrix SDK error: {0}")] #[error("uncategorized matrix SDK error: {0}")]
MatrixError(#[from] matrix_sdk::Error), MatrixError(#[from] matrix_sdk::Error),
#[error("uncategorized matrix SDK base error: {0}")]
MatrixBaseError(#[from] matrix_sdk::BaseError),
#[error("future canceled")] #[error("future canceled")]
FutureCanceledError, FutureCanceledError,
@ -85,8 +76,8 @@ pub enum BotError {
#[error("could not convert to proper integer type")] #[error("could not convert to proper integer type")]
TryFromIntError(#[from] std::num::TryFromIntError), TryFromIntError(#[from] std::num::TryFromIntError),
// #[error("identifier error: {0}")] #[error("identifier error: {0}")]
// IdentifierError(#[from] matrix_sdk::ruma::Error), IdentifierError(#[from] matrix_sdk::identifiers::Error),
#[error("password creation error: {0}")] #[error("password creation error: {0}")]
PasswordCreationError(argon2::Error), PasswordCreationError(argon2::Error),
@ -102,15 +93,6 @@ pub enum BotError {
#[error("room name or id does not exist")] #[error("room name or id does not exist")]
RoomDoesNotExist, RoomDoesNotExist,
#[error("tonic transport error: {0}")]
TonicTransportError(#[from] tonic::transport::Error),
#[error("address parsing error: {0}")]
AddressParseError(#[from] AddrParseError),
#[error("invalid metadata value: {0}")]
TonicInvalidMetadata(#[from] InvalidMetadataValue),
} }
#[derive(Error, Debug)] #[derive(Error, Debug)]

View File

@ -6,9 +6,6 @@ pub fn parse_help_topic(input: &str) -> Option<HelpTopic> {
"dicepool" => Some(HelpTopic::DicePool), "dicepool" => Some(HelpTopic::DicePool),
"dice" => Some(HelpTopic::RollingDice), "dice" => Some(HelpTopic::RollingDice),
"cthulhu" => Some(HelpTopic::Cthulhu), "cthulhu" => Some(HelpTopic::Cthulhu),
"variables" => Some(HelpTopic::Variables),
"var" => Some(HelpTopic::Variables),
"variable" => Some(HelpTopic::Variables),
"" => Some(HelpTopic::General), "" => Some(HelpTopic::General),
_ => None, _ => None,
} }
@ -19,7 +16,6 @@ pub enum HelpTopic {
DicePool, DicePool,
Cthulhu, Cthulhu,
RollingDice, RollingDice,
Variables,
General, General,
} }
@ -105,34 +101,6 @@ Note: If !cthadv is given a variable, and the roll is successful, it will
update the variable with the new skill. update the variable with the new skill.
"}; "};
const VARIABLES_HELP: &'static str = indoc! {"
Variables
Commands: !get, !set, !variables
Manage variables that can be substituted into roll commands.
Examples: !get myvar, !set myvar 10
!get <variable> = show variable of the given name
!set <variable> <num> = set a variable to a number
The !variables command will list all variables for the room. The
variables command cna be used in a secure room to avoid spamming the
actual room that the variable is set in.
Variable names can be used in all types of dice rolls:
!pool myvar + 3
!roll myvar
There are some limitations on variables: they cannot themselves be
dice expressions (i.e. can only be numbers), and they must be uniquely
parseable in an expression (i.e 'myvard6' does not work for the !roll
command).
"};
const GENERAL_HELP: &'static str = indoc! {" const GENERAL_HELP: &'static str = indoc! {"
General Help General Help
@ -149,7 +117,6 @@ impl HelpTopic {
HelpTopic::DicePool => DICEPOOL_HELP, HelpTopic::DicePool => DICEPOOL_HELP,
HelpTopic::Cthulhu => CTHULHU_HELP, HelpTopic::Cthulhu => CTHULHU_HELP,
HelpTopic::RollingDice => DICE_HELP, HelpTopic::RollingDice => DICE_HELP,
HelpTopic::Variables => VARIABLES_HELP,
HelpTopic::General => GENERAL_HELP, HelpTopic::General => GENERAL_HELP,
} }
} }

View File

@ -12,6 +12,4 @@ pub mod logic;
pub mod matrix; pub mod matrix;
pub mod models; pub mod models;
mod parser; mod parser;
pub mod rpc;
pub mod state; pub mod state;
pub mod systems;

View File

@ -27,7 +27,7 @@ pub async fn calculate_dice_amount(amounts: &[Amount], ctx: &Context<'_>) -> Res
let stream = stream::iter(amounts); let stream = stream::iter(amounts);
let variables = &ctx let variables = &ctx
.db .db
.get_user_variables(&ctx.username, ctx.active_room_id().as_str()) .get_user_variables(&ctx.username, ctx.room_id().as_str())
.await?; .await?;
use DiceRollingError::VariableNotFound; use DiceRollingError::VariableNotFound;
@ -71,61 +71,53 @@ mod tests {
use super::*; use super::*;
use crate::db::Users; use crate::db::Users;
use crate::models::{AccountStatus, User}; use crate::models::{AccountStatus, User};
use std::future::Future;
async fn with_db<Fut>(f: impl FnOnce(Database) -> Fut) async fn create_db() -> Database {
where
Fut: Future<Output = ()>,
{
let db_path = tempfile::NamedTempFile::new_in(".").unwrap(); let db_path = tempfile::NamedTempFile::new_in(".").unwrap();
crate::db::sqlite::migrator::migrate(db_path.path().to_str().unwrap()) crate::db::sqlite::migrator::migrate(db_path.path().to_str().unwrap())
.await .await
.unwrap(); .unwrap();
let db = Database::new(db_path.path().to_str().unwrap()) Database::new(db_path.path().to_str().unwrap())
.await .await
.unwrap(); .unwrap()
f(db).await;
} }
#[tokio::test(flavor = "multi_thread", worker_threads = 1)] #[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn get_account_no_user_exists() { async fn get_account_no_user_exists() {
with_db(|db| async move { let db = create_db().await;
let account = get_account(&db, "@test:example.com")
.await
.expect("Account retrieval didn't work");
assert!(matches!(account, Account::Transient(_))); let account = get_account(&db, "@test:example.com")
.await
.expect("Account retrieval didn't work");
let user = account.transient_user().unwrap(); assert!(matches!(account, Account::Transient(_)));
assert_eq!(user.username, "@test:example.com");
}) let user = account.transient_user().unwrap();
.await; assert_eq!(user.username, "@test:example.com");
} }
#[tokio::test(flavor = "multi_thread", worker_threads = 1)] #[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn get_or_create_user_when_user_exists() { async fn get_or_create_user_when_user_exists() {
with_db(|db| async move { let db = create_db().await;
let user = User {
username: "myuser".to_string(),
password: Some("abc".to_string()),
account_status: AccountStatus::Registered,
active_room: Some("myroom".to_string()),
};
let insert_result = db.upsert_user(&user).await; let user = User {
assert!(insert_result.is_ok()); username: "myuser".to_string(),
password: Some("abc".to_string()),
account_status: AccountStatus::Registered,
active_room: Some("myroom".to_string()),
};
let account = get_account(&db, "myuser") let insert_result = db.upsert_user(&user).await;
.await assert!(insert_result.is_ok());
.expect("Account retrieval did not work");
assert!(matches!(account, Account::Registered(_))); let account = get_account(&db, "myuser")
.await
.expect("Account retrieval did not work");
let user_again = account.registered_user().unwrap(); assert!(matches!(account, Account::Registered(_)));
assert_eq!(user, *user_again);
}) let user_again = account.registered_user().unwrap();
.await; assert_eq!(user, user_again);
} }
} }

View File

@ -1,22 +1,14 @@
use std::path::PathBuf;
use futures::stream::{self, StreamExt, TryStreamExt}; use futures::stream::{self, StreamExt, TryStreamExt};
use log::error; use log::error;
use matrix_sdk::ruma::events::room::message::{InReplyTo, RoomMessageEventContent, Relation}; use matrix_sdk::{events::room::message::NoticeMessageEventContent, room::Joined};
use matrix_sdk::ruma::events::AnyMessageLikeEventContent; use matrix_sdk::{
use matrix_sdk::ruma::{RoomId, OwnedEventId, OwnedUserId}; events::room::message::{InReplyTo, Relation},
use matrix_sdk::Client; events::room::message::{MessageEventContent, MessageType},
use matrix_sdk::Error as MatrixError; events::AnyMessageEventContent,
use matrix_sdk::room::Joined; identifiers::EventId,
use url::Url; Error as MatrixError,
};
use crate::{config::Config, error::BotError}; use matrix_sdk::{identifiers::RoomId, identifiers::UserId, Client};
fn cache_dir() -> Result<PathBuf, BotError> {
let mut dir = dirs::cache_dir().ok_or(BotError::NoCacheDirectoryError)?;
dir.push("matrix-dicebot");
Ok(dir)
}
/// Extracts more detailed error messages out of a matrix SDK error. /// Extracts more detailed error messages out of a matrix SDK error.
fn extract_error_message(error: MatrixError) -> String { fn extract_error_message(error: MatrixError) -> String {
@ -28,19 +20,6 @@ fn extract_error_message(error: MatrixError) -> String {
} }
} }
/// Creates the matrix client.
pub async fn create_client(config: &Config) -> Result<Client, BotError> {
let cache_dir = cache_dir()?;
let homeserver_url = Url::parse(&config.matrix_homeserver())?;
let client = Client::builder()
.sled_store(cache_dir, None)?
.homeserver_url(homeserver_url).build()
.await?;
Ok(client)
}
/// Retrieve a list of users in a given room. /// Retrieve a list of users in a given room.
pub async fn get_users_in_room( pub async fn get_users_in_room(
client: &Client, client: &Client,
@ -60,7 +39,7 @@ pub async fn get_users_in_room(
pub async fn get_rooms_for_user( pub async fn get_rooms_for_user(
client: &Client, client: &Client,
user: &OwnedUserId, user: &UserId,
) -> Result<Vec<Joined>, MatrixError> { ) -> Result<Vec<Joined>, MatrixError> {
// Carries errors through, in case we cannot load joined user IDs // Carries errors through, in case we cannot load joined user IDs
// from the room for some reason. // from the room for some reason.
@ -88,7 +67,7 @@ pub async fn send_message(
client: &Client, client: &Client,
room_id: &RoomId, room_id: &RoomId,
message: (&str, &str), message: (&str, &str),
reply_to: Option<OwnedEventId>, reply_to: Option<EventId>,
) { ) {
let (html, plain) = message; let (html, plain) = message;
let room = match client.get_joined_room(room_id) { let room = match client.get_joined_room(room_id) {
@ -96,13 +75,15 @@ pub async fn send_message(
_ => return, _ => return,
}; };
let mut content = RoomMessageEventContent::notice_html(plain.trim(), html); let mut content = MessageEventContent::new(MessageType::Notice(
NoticeMessageEventContent::html(plain.trim(), html),
));
content.relates_to = reply_to.map(|event_id| Relation::Reply { content.relates_to = reply_to.map(|event_id| Relation::Reply {
in_reply_to: InReplyTo::new(event_id) in_reply_to: InReplyTo::new(event_id),
}); });
let content = AnyMessageLikeEventContent::RoomMessage(content); let content = AnyMessageEventContent::RoomMessage(content);
let result = room.send(content, None).await; let result = room.send(content, None).await;

View File

@ -60,9 +60,9 @@ impl Account {
/// Consume self into an Option<User> instance, which will be Some /// Consume self into an Option<User> instance, which will be Some
/// if this account has a registered user, and None otherwise. /// if this account has a registered user, and None otherwise.
pub fn registered_user(&self) -> Option<&User> { pub fn registered_user(self) -> Option<User> {
match self { match self {
Self::Registered(ref user) => Some(user), Self::Registered(user) => Some(user),
_ => None, _ => None,
} }
} }

View File

@ -151,9 +151,8 @@ where
/// should not have an operator, but every one after that should. /// should not have an operator, but every one after that should.
/// Accepts expressions like "8", "10 + variablename", "variablename - /// Accepts expressions like "8", "10 + variablename", "variablename -
/// 3", etc. This function is currently common to systems that don't /// 3", etc. This function is currently common to systems that don't
/// deal with XdY rolls. Support for that will be added later. Returns /// deal with XdY rolls. Support for that will be added later.
/// parsed amounts and unconsumed input (e.g. roll modifiers). pub fn parse_amounts(input: &str) -> ParseResult<Vec<Amount>> {
pub fn parse_amounts(input: &str) -> ParseResult<(Vec<Amount>, &str)> {
let input = input.trim(); let input = input.trim();
let remaining_amounts = many(amount_parser()).map(|amounts: Vec<ParseResult<Amount>>| amounts); let remaining_amounts = many(amount_parser()).map(|amounts: Vec<ParseResult<Amount>>| amounts);
@ -170,23 +169,31 @@ pub fn parse_amounts(input: &str) -> ParseResult<(Vec<Amount>, &str)> {
(amounts, results.1) (amounts, results.1)
})?; })?;
// Any ParseResult errors will short-circuit the collect. if rest.len() == 0 {
let results: Vec<Amount> = results.into_iter().collect::<ParseResult<_>>()?; // Any ParseResult errors will short-circuit the collect.
Ok((results, rest)) results.into_iter().collect()
} else {
Err(DiceParsingError::UnconsumedInput)
}
} }
/// Parse an expression that expects a single number or variable. No /// Parse an expression that expects a single number or variable. No
/// operators are allowed. This function is common to systems that /// operators are allowed. This function is common to systems that
/// don't deal with XdY rolls. Currently. this function does not /// don't deal with XdY rolls. Currently. this function does not
/// support parsing negative numbers. Returns the parsed amount and /// support parsing negative numbers.
/// any unconsumed input (useful for dice roll modifiers). pub fn parse_single_amount(input: &str) -> ParseResult<Amount> {
pub fn parse_single_amount(input: &str) -> ParseResult<(Amount, &str)> {
// TODO add support for negative numbers, as technically they // TODO add support for negative numbers, as technically they
// should be allowed. // should be allowed.
let input = input.trim(); let input = input.trim();
let mut parser = first_amount_parser().map(|amount: ParseResult<Amount>| amount); let mut parser = first_amount_parser().map(|amount: ParseResult<Amount>| amount);
let (result, rest) = parser.parse(input)?; let (result, rest) = parser.parse(input)?;
Ok((result?, rest))
if rest.len() == 0 {
result
} else {
Err(DiceParsingError::UnconsumedInput)
}
} }
#[cfg(test)] #[cfg(test)]
@ -199,13 +206,10 @@ mod parse_single_amount_tests {
assert!(result.is_ok()); assert!(result.is_ok());
assert_eq!( assert_eq!(
result.unwrap(), result.unwrap(),
( Amount {
Amount { operator: Operator::Plus,
operator: Operator::Plus, element: Element::Variable("abc".to_string())
element: Element::Variable("abc".to_string()) }
},
""
)
) )
} }
@ -229,15 +233,24 @@ mod parse_single_amount_tests {
assert!(result.is_ok()); assert!(result.is_ok());
assert_eq!( assert_eq!(
result.unwrap(), result.unwrap(),
( Amount {
Amount { operator: Operator::Plus,
operator: Operator::Plus, element: Element::Number(1)
element: Element::Number(1) }
},
""
)
) )
} }
#[test]
fn parse_multiple_elements_test() {
let result = parse_single_amount("1+abc");
assert!(result.is_err());
let result = parse_single_amount("abc+1");
assert!(result.is_err());
let result = parse_single_amount("-1-abc");
assert!(result.is_err());
}
} }
#[cfg(test)] #[cfg(test)]
@ -250,26 +263,20 @@ mod parse_many_amounts_tests {
assert!(result.is_ok()); assert!(result.is_ok());
assert_eq!( assert_eq!(
result.unwrap(), result.unwrap(),
( vec![Amount {
vec![Amount { operator: Operator::Plus,
operator: Operator::Plus, element: Element::Number(1)
element: Element::Number(1) }]
}],
""
)
); );
let result = parse_amounts("10"); let result = parse_amounts("10");
assert!(result.is_ok()); assert!(result.is_ok());
assert_eq!( assert_eq!(
result.unwrap(), result.unwrap(),
( vec![Amount {
vec![Amount { operator: Operator::Plus,
operator: Operator::Plus, element: Element::Number(10)
element: Element::Number(10) }]
}],
""
)
); );
} }
@ -288,26 +295,20 @@ mod parse_many_amounts_tests {
assert!(result.is_ok()); assert!(result.is_ok());
assert_eq!( assert_eq!(
result.unwrap(), result.unwrap(),
( vec![Amount {
vec![Amount { operator: Operator::Plus,
operator: Operator::Plus, element: Element::Variable("asdf".to_string())
element: Element::Variable("asdf".to_string()) }]
}],
""
)
); );
let result = parse_amounts("nosis"); let result = parse_amounts("nosis");
assert!(result.is_ok()); assert!(result.is_ok());
assert_eq!( assert_eq!(
result.unwrap(), result.unwrap(),
( vec![Amount {
vec![Amount { operator: Operator::Plus,
operator: Operator::Plus, element: Element::Variable("nosis".to_string())
element: Element::Variable("nosis".to_string()) }]
}],
""
)
); );
} }