Do not automatically create accounts; use enum to show this instead. #75

Manually merged
projectmoon merged 2 commits from user-refactor into master 2021-05-26 16:03:10 +00:00
10 changed files with 149 additions and 81 deletions
Showing only changes of commit 495df13fe6 - Show all commits

View File

@ -4,7 +4,7 @@ use tenebrous_dicebot::commands::ResponseExtractor;
use tenebrous_dicebot::context::{Context, RoomContext};
use tenebrous_dicebot::db::sqlite::Database;
use tenebrous_dicebot::error::BotError;
use tenebrous_dicebot::models::User;
use tenebrous_dicebot::models::Account;
use url::Url;
#[tokio::main]
@ -27,7 +27,7 @@ async fn main() -> Result<(), BotError> {
let context = Context {
db: db,
user: User::default(),
account: Account::default(),
matrix_client: &matrix_sdk::Client::new(homeserver)
.expect("Could not create matrix client"),
room: RoomContext {

View File

@ -1,8 +1,8 @@
use crate::commands::{execute_command, ExecutionError, ExecutionResult, ResponseExtractor};
use crate::context::{Context, RoomContext};
use crate::db::sqlite::Database;
use crate::db::Users;
use crate::error::BotError;
use crate::logic;
use crate::matrix;
use futures::stream::{self, StreamExt};
use matrix_sdk::{self, identifiers::EventId, room::Joined, Client};
@ -78,7 +78,7 @@ async fn create_context<'a>(
matrix_client: client,
room: room_ctx,
username: &sender,
user: db.get_or_create_user(&sender).await?,
account: logic::get_account(db, &sender).await?,
message_body: &command,
})
}

View File

@ -483,7 +483,7 @@ mod tests {
.unwrap();
let ctx = Context {
user: crate::models::User::default(),
account: crate::models::Account::default(),
db: db,
matrix_client: &matrix_sdk::Client::new(homeserver).unwrap(),
room: dummy_room!(),
@ -524,7 +524,7 @@ mod tests {
let homeserver = Url::parse("http://example.com").unwrap();
let ctx = Context {
user: crate::models::User::default(),
account: crate::models::Account::default(),
db: db,
matrix_client: &matrix_sdk::Client::new(homeserver).unwrap(),
room: dummy_room!(),
@ -562,7 +562,7 @@ mod tests {
let homeserver = Url::parse("http://example.com").unwrap();
let ctx = Context {
user: crate::models::User::default(),
account: crate::models::Account::default(),
db: db.clone(),
matrix_client: &matrix_sdk::Client::new(homeserver).unwrap(),
room: dummy_room!(),

View File

@ -201,7 +201,7 @@ mod tests {
let homeserver = Url::parse("http://example.com").unwrap();
let ctx = Context {
user: crate::models::User::default(),
account: crate::models::Account::default(),
db: db,
matrix_client: &matrix_sdk::Client::new(homeserver).unwrap(),
room: secure_room!(),
@ -223,7 +223,7 @@ mod tests {
let homeserver = Url::parse("http://example.com").unwrap();
let ctx = Context {
user: crate::models::User::default(),
account: crate::models::Account::default(),
db: db,
matrix_client: &matrix_sdk::Client::new(homeserver).unwrap(),
room: secure_room!(),
@ -245,7 +245,7 @@ mod tests {
let homeserver = Url::parse("http://example.com").unwrap();
let ctx = Context {
user: crate::models::User::default(),
account: crate::models::Account::default(),
db: db,
matrix_client: &matrix_sdk::Client::new(homeserver).unwrap(),
room: dummy_room!(),
@ -267,7 +267,7 @@ mod tests {
let homeserver = Url::parse("http://example.com").unwrap();
let ctx = Context {
user: crate::models::User::default(),
account: crate::models::Account::default(),
db: db,
matrix_client: &matrix_sdk::Client::new(homeserver).unwrap(),
room: dummy_room!(),
@ -298,7 +298,7 @@ mod tests {
let homeserver = Url::parse("http://example.com").unwrap();
let ctx = Context {
user: crate::models::User::default(),
account: crate::models::Account::default(),
db: db,
matrix_client: &matrix_sdk::Client::new(homeserver).unwrap(),
room: dummy_room!(),

View File

@ -1,6 +1,6 @@
use crate::db::sqlite::Database;
use crate::error::BotError;
use crate::models::User;
use crate::models::{Account, User};
use matrix_sdk::identifiers::{RoomId, UserId};
use matrix_sdk::room::Joined;
use matrix_sdk::Client;
@ -15,7 +15,7 @@ pub struct Context<'a> {
pub room: RoomContext<'a>,
pub username: &'a str,
pub message_body: &'a str,
pub user: User,
pub account: Account,
}
impl Context<'_> {

View File

@ -504,7 +504,7 @@ mod tests {
let homeserver = Url::parse("http://example.com").unwrap();
let ctx = Context {
user: crate::models::User::default(),
account: crate::models::Account::default(),
db: db,
matrix_client: &matrix_sdk::Client::new(homeserver).unwrap(),
room: dummy_room!(),
@ -541,7 +541,7 @@ mod tests {
let homeserver = Url::parse("http://example.com").unwrap();
let ctx = Context {
user: crate::models::User::default(),
account: crate::models::Account::default(),
db: db,
matrix_client: &matrix_sdk::Client::new(homeserver).unwrap(),
room: dummy_room!(),
@ -578,7 +578,7 @@ mod tests {
let homeserver = Url::parse("http://example.com").unwrap();
let ctx = Context {
user: crate::models::User::default(),
account: crate::models::Account::default(),
db: db,
matrix_client: &matrix_sdk::Client::new(homeserver).unwrap(),
room: dummy_room!(),

View File

@ -16,8 +16,6 @@ pub(crate) trait DbState {
#[async_trait]
pub(crate) trait Users {
async fn get_or_create_user(&self, username: &str) -> Result<User, DataError>;
async fn upsert_user(&self, user: &User) -> Result<(), DataError>;
async fn get_user(&self, username: &str) -> Result<Option<User>, DataError>;

View File

@ -76,21 +76,6 @@ impl Users for Database {
Ok(user_row)
}
//TODO should this logic be moved further up into logic.rs maybe?
async fn get_or_create_user(&self, username: &str) -> Result<User, DataError> {
let maybe_user = self.get_user(username).await?;
match maybe_user {
Some(user) => Ok(user),
None => {
info!("Creating unregistered account for {}", username);
let user = User::unregistered(&username);
self.upsert_user(&user).await?;
Ok(user)
}
}
}
async fn authenticate_user(
&self,
username: &str,
@ -119,48 +104,6 @@ mod tests {
.unwrap()
}
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn get_or_create_user_no_user_exists() {
let db = create_db().await;
let user = db
.get_or_create_user("@test:example.com")
.await
.expect("User creation didn't work.");
assert_eq!(user.username, "@test:example.com");
let user_again = db
.get_user("@test:example.com")
.await
.expect("User retrieval didn't work.")
.expect("No user returned from option.");
assert_eq!(user, user_again);
}
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn get_or_create_user_when_user_exists() {
let db = create_db().await;
let user = User {
username: "myuser".to_string(),
password: Some("abc".to_string()),
account_status: AccountStatus::Registered,
active_room: Some("myroom".to_string()),
};
let insert_result = db.upsert_user(&user).await;
assert!(insert_result.is_ok());
let user_again = db
.get_or_create_user("myuser")
.await
.expect("User retrieval didn't work.");
assert_eq!(user, user_again);
}
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn create_and_get_full_user_test() {
let db = create_db().await;

View File

@ -1,7 +1,10 @@
use crate::context::Context;
use crate::db::Variables;
use crate::error::{BotError, DiceRollingError};
use crate::parser::dice::{Amount, Element};
use crate::{context::Context, models::Account};
use crate::{
db::{sqlite::Database, Users, Variables},
models::TransientUser,
};
use argon2::{self, Config, Error as ArgonError};
use futures::stream::{self, StreamExt, TryStreamExt};
use rand::Rng;
@ -50,3 +53,71 @@ pub(crate) fn hash_password(raw_password: &str) -> Result<String, ArgonError> {
let config = Config::default();
argon2::hash_encoded(raw_password.as_bytes(), &salt, &config)
}
pub(crate) async fn get_account(db: &Database, username: &str) -> Result<Account, BotError> {
Ok(db
.get_user(username)
.await?
.map(|user| Account::Registered(user))
.unwrap_or_else(|| {
Account::Transient(TransientUser {
username: username.to_owned(),
})
}))
}
#[cfg(test)]
mod tests {
use super::*;
use crate::db::Users;
use crate::models::{AccountStatus, User};
async fn create_db() -> Database {
let db_path = tempfile::NamedTempFile::new_in(".").unwrap();
crate::db::sqlite::migrator::migrate(db_path.path().to_str().unwrap())
.await
.unwrap();
Database::new(db_path.path().to_str().unwrap())
.await
.unwrap()
}
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn get_account_no_user_exists() {
let db = create_db().await;
let account = get_account(&db, "@test:example.com")
.await
.expect("Account retrieval didn't work");
assert!(matches!(account, Account::Transient(_)));
let user = account.transient_user().unwrap();
assert_eq!(user.username, "@test:example.com");
}
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn get_or_create_user_when_user_exists() {
let db = create_db().await;
let user = User {
username: "myuser".to_string(),
password: Some("abc".to_string()),
account_status: AccountStatus::Registered,
active_room: Some("myroom".to_string()),
};
let insert_result = db.upsert_user(&user).await;
assert!(insert_result.is_ok());
let account = get_account(&db, "myuser")
.await
.expect("Account retrieval did not work");
assert!(matches!(account, Account::Registered(_)));
let user_again = account.registered_user().unwrap();
assert_eq!(user, user_again);
}
}

View File

@ -10,9 +10,9 @@ pub struct RoomInfo {
#[derive(Eq, PartialEq, Clone, Copy, Debug, sqlx::Type)]
#[sqlx(rename_all = "snake_case")]
pub enum AccountStatus {
/// User is not registered, which means the "account" only exists
/// for state management in the bot. No privileged actions
/// possible.
/// Account is not registered, which means a transient "account"
/// with limited information exists only for the duration of the
/// command request.
NotRegistered,
/// User account is fully registered, either via Matrix directly,
@ -30,6 +30,62 @@ impl Default for AccountStatus {
}
}
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum Account {
/// A registered user account, stored in the database.
Registered(User),
/// A transient account. Not stored in the database. Represents a
/// user in a public channel that has not registered directly with
/// the bot yet.
Transient(TransientUser),
}
impl Account {
/// Gets the account status. For registered users, this is their
/// actual account status (fully registered or awaiting
/// activation). For transient users, this is
/// AccountStatus::NotRegistered.
pub fn account_status(&self) -> AccountStatus {
match self {
Self::Registered(user) => user.account_status,
Self::Transient(_) => AccountStatus::NotRegistered,
}
}
/// Consume self into an Option<User> instance, which will be Some
/// if this account has a registered user, and None otherwise.
pub fn registered_user(self) -> Option<User> {
match self {
Self::Registered(user) => Some(user),
_ => None,
}
}
/// Consume self into an Option<TransientUser> instance, which
/// will be Some if this account has a non-registered user, and
/// None otherwise.
pub fn transient_user(self) -> Option<TransientUser> {
match self {
Self::Transient(user) => Some(user),
_ => None,
}
}
}
impl Default for Account {
fn default() -> Self {
Account::Transient(TransientUser {
username: "".to_string(),
})
}
}
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct TransientUser {
pub username: String,
}
#[derive(Eq, PartialEq, Clone, Debug, Default, sqlx::FromRow)]
pub struct User {
pub username: String,