Compare commits

..

2 Commits

Author SHA1 Message Date
projectmoon 4ada1697ee Fix broken get user query, because it's broken with query_as! macro
continuous-integration/drone/push Build is failing Details
2021-05-25 16:10:55 +00:00
projectmoon 5362488645 Tack state onto user accounts. Make password optional. Everybody is a user! 2021-05-25 15:06:50 +00:00
11 changed files with 56 additions and 238 deletions

View File

@ -4,7 +4,6 @@ use tenebrous_dicebot::commands::ResponseExtractor;
use tenebrous_dicebot::context::{Context, RoomContext}; use tenebrous_dicebot::context::{Context, RoomContext};
use tenebrous_dicebot::db::sqlite::Database; use tenebrous_dicebot::db::sqlite::Database;
use tenebrous_dicebot::error::BotError; use tenebrous_dicebot::error::BotError;
use tenebrous_dicebot::models::User;
use url::Url; use url::Url;
#[tokio::main] #[tokio::main]
@ -27,7 +26,6 @@ async fn main() -> Result<(), BotError> {
let context = Context { let context = Context {
db: db, db: db,
user: User::default(),
matrix_client: &matrix_sdk::Client::new(homeserver) matrix_client: &matrix_sdk::Client::new(homeserver)
.expect("Could not create matrix client"), .expect("Could not create matrix client"),
room: RoomContext { room: RoomContext {

View File

@ -1,7 +1,6 @@
use crate::commands::{execute_command, ExecutionError, ExecutionResult, ResponseExtractor}; use crate::commands::{execute_command, ExecutionError, ExecutionResult, ResponseExtractor};
use crate::context::{Context, RoomContext}; use crate::context::{Context, RoomContext};
use crate::db::sqlite::Database; use crate::db::sqlite::Database;
use crate::db::Users;
use crate::error::BotError; use crate::error::BotError;
use crate::matrix; use crate::matrix;
use futures::stream::{self, StreamExt}; use futures::stream::{self, StreamExt};
@ -78,7 +77,6 @@ async fn create_context<'a>(
matrix_client: client, matrix_client: client,
room: room_ctx, room: room_ctx,
username: &sender, username: &sender,
user: db.get_or_create_user(&sender).await?,
message_body: &command, message_body: &command,
}) })
} }

View File

@ -483,7 +483,6 @@ mod tests {
.unwrap(); .unwrap();
let ctx = Context { let ctx = Context {
user: crate::models::User::default(),
db: db, db: db,
matrix_client: &matrix_sdk::Client::new(homeserver).unwrap(), matrix_client: &matrix_sdk::Client::new(homeserver).unwrap(),
room: dummy_room!(), room: dummy_room!(),
@ -524,7 +523,6 @@ mod tests {
let homeserver = Url::parse("http://example.com").unwrap(); let homeserver = Url::parse("http://example.com").unwrap();
let ctx = Context { let ctx = Context {
user: crate::models::User::default(),
db: db, db: db,
matrix_client: &matrix_sdk::Client::new(homeserver).unwrap(), matrix_client: &matrix_sdk::Client::new(homeserver).unwrap(),
room: dummy_room!(), room: dummy_room!(),
@ -562,7 +560,6 @@ mod tests {
let homeserver = Url::parse("http://example.com").unwrap(); let homeserver = Url::parse("http://example.com").unwrap();
let ctx = Context { let ctx = Context {
user: crate::models::User::default(),
db: db.clone(), db: db.clone(),
matrix_client: &matrix_sdk::Client::new(homeserver).unwrap(), matrix_client: &matrix_sdk::Client::new(homeserver).unwrap(),
room: dummy_room!(), room: dummy_room!(),

View File

@ -201,7 +201,6 @@ mod tests {
let homeserver = Url::parse("http://example.com").unwrap(); let homeserver = Url::parse("http://example.com").unwrap();
let ctx = Context { let ctx = Context {
user: crate::models::User::default(),
db: db, db: db,
matrix_client: &matrix_sdk::Client::new(homeserver).unwrap(), matrix_client: &matrix_sdk::Client::new(homeserver).unwrap(),
room: secure_room!(), room: secure_room!(),
@ -223,7 +222,6 @@ mod tests {
let homeserver = Url::parse("http://example.com").unwrap(); let homeserver = Url::parse("http://example.com").unwrap();
let ctx = Context { let ctx = Context {
user: crate::models::User::default(),
db: db, db: db,
matrix_client: &matrix_sdk::Client::new(homeserver).unwrap(), matrix_client: &matrix_sdk::Client::new(homeserver).unwrap(),
room: secure_room!(), room: secure_room!(),
@ -245,7 +243,6 @@ mod tests {
let homeserver = Url::parse("http://example.com").unwrap(); let homeserver = Url::parse("http://example.com").unwrap();
let ctx = Context { let ctx = Context {
user: crate::models::User::default(),
db: db, db: db,
matrix_client: &matrix_sdk::Client::new(homeserver).unwrap(), matrix_client: &matrix_sdk::Client::new(homeserver).unwrap(),
room: dummy_room!(), room: dummy_room!(),
@ -267,7 +264,6 @@ mod tests {
let homeserver = Url::parse("http://example.com").unwrap(); let homeserver = Url::parse("http://example.com").unwrap();
let ctx = Context { let ctx = Context {
user: crate::models::User::default(),
db: db, db: db,
matrix_client: &matrix_sdk::Client::new(homeserver).unwrap(), matrix_client: &matrix_sdk::Client::new(homeserver).unwrap(),
room: dummy_room!(), room: dummy_room!(),
@ -298,7 +294,6 @@ mod tests {
let homeserver = Url::parse("http://example.com").unwrap(); let homeserver = Url::parse("http://example.com").unwrap();
let ctx = Context { let ctx = Context {
user: crate::models::User::default(),
db: db, db: db,
matrix_client: &matrix_sdk::Client::new(homeserver).unwrap(), matrix_client: &matrix_sdk::Client::new(homeserver).unwrap(),
room: dummy_room!(), room: dummy_room!(),

View File

@ -1,6 +1,5 @@
use crate::db::sqlite::Database; use crate::db::sqlite::Database;
use crate::error::BotError; use crate::error::BotError;
use crate::models::User;
use matrix_sdk::identifiers::{RoomId, UserId}; use matrix_sdk::identifiers::{RoomId, UserId};
use matrix_sdk::room::Joined; use matrix_sdk::room::Joined;
use matrix_sdk::Client; use matrix_sdk::Client;
@ -15,7 +14,6 @@ pub struct Context<'a> {
pub room: RoomContext<'a>, pub room: RoomContext<'a>,
pub username: &'a str, pub username: &'a str,
pub message_body: &'a str, pub message_body: &'a str,
pub user: User,
} }
impl Context<'_> { impl Context<'_> {

View File

@ -504,7 +504,6 @@ mod tests {
let homeserver = Url::parse("http://example.com").unwrap(); let homeserver = Url::parse("http://example.com").unwrap();
let ctx = Context { let ctx = Context {
user: crate::models::User::default(),
db: db, db: db,
matrix_client: &matrix_sdk::Client::new(homeserver).unwrap(), matrix_client: &matrix_sdk::Client::new(homeserver).unwrap(),
room: dummy_room!(), room: dummy_room!(),
@ -541,7 +540,6 @@ mod tests {
let homeserver = Url::parse("http://example.com").unwrap(); let homeserver = Url::parse("http://example.com").unwrap();
let ctx = Context { let ctx = Context {
user: crate::models::User::default(),
db: db, db: db,
matrix_client: &matrix_sdk::Client::new(homeserver).unwrap(), matrix_client: &matrix_sdk::Client::new(homeserver).unwrap(),
room: dummy_room!(), room: dummy_room!(),
@ -578,7 +576,6 @@ mod tests {
let homeserver = Url::parse("http://example.com").unwrap(); let homeserver = Url::parse("http://example.com").unwrap();
let ctx = Context { let ctx = Context {
user: crate::models::User::default(),
db: db, db: db,
matrix_client: &matrix_sdk::Client::new(homeserver).unwrap(), matrix_client: &matrix_sdk::Client::new(homeserver).unwrap(),
room: dummy_room!(), room: dummy_room!(),

View File

@ -16,8 +16,6 @@ pub(crate) trait DbState {
#[async_trait] #[async_trait]
pub(crate) trait Users { pub(crate) trait Users {
async fn get_or_create_user(&self, username: &str) -> Result<User, DataError>;
async fn upsert_user(&self, user: &User) -> Result<(), DataError>; async fn upsert_user(&self, user: &User) -> Result<(), DataError>;
async fn get_user(&self, username: &str) -> Result<Option<User>, DataError>; async fn get_user(&self, username: &str) -> Result<Option<User>, DataError>;

View File

@ -6,12 +6,15 @@ fn primary_uuid() -> Type {
} }
pub fn migration() -> String { pub fn migration() -> String {
let status_enum =
r#"CHECK(account_status IN ('not_registered', 'registered', 'awaiting_activation'))"#;
let mut m = Migration::new(); let mut m = Migration::new();
// Keep track of contextual user state. // Keep track of contextual user state.
m.create_table("user_state", move |t| { m.create_table("user_state", move |t| {
t.add_column("user_id", primary_uuid()); t.add_column("user_id", primary_uuid());
t.add_column("active_room", types::text().nullable(true)); t.add_column("active_room", types::text().nullable(true));
t.add_column("account_status", types::custom(status_enum).nullable(false));
}); });
m.make::<Sqlite>() m.make::<Sqlite>()

View File

@ -1,15 +1,19 @@
use barrel::backend::Sqlite;
use barrel::{types, types::Type, Migration};
fn primary_uuid() -> Type {
types::text().unique(true).primary(true).nullable(false)
}
pub fn migration() -> String { pub fn migration() -> String {
// sqlite does really support alter column, and barrel does not // sqlite does really support alter column, and barrel does not
// implement the required workaround, so we do it ourselves! // implement the required workaround, so we do it ourselves!
r#" r#"
CREATE TABLE IF NOT EXISTS "accounts2" ( CREATE TABLE IF NOT EXISTS "accounts2" (
"user_id" TEXT PRIMARY KEY NOT NULL UNIQUE, "user_id" TEXT PRIMARY KEY NOT NULL UNIQUE,
"password" TEXT NULL, "password" TEXT NULL
"account_status" TEXT NOT NULL CHECK(
account_status IN ('not_registered', 'registered', 'awaiting_activation'
))
); );
INSERT INTO accounts2 select *, 'registered' FROM accounts; INSERT INTO accounts2 select * from accounts;
DROP TABLE accounts; DROP TABLE accounts;
ALTER TABLE accounts2 RENAME TO accounts; ALTER TABLE accounts2 RENAME TO accounts;
"# "#

View File

@ -1,72 +1,63 @@
use super::Database; use super::Database;
use crate::db::{errors::DataError, Users}; use crate::db::{errors::DataError, Users};
use crate::error::BotError; use crate::error::BotError;
use crate::models::User; use crate::models::{AccountStatus, User};
use async_trait::async_trait; use async_trait::async_trait;
use log::info; use std::convert::From;
#[derive(Eq, PartialEq, Debug, Default, sqlx::FromRow)]
struct UserRow {
pub username: String,
pub password: Option<String>,
pub active_room: Option<String>,
pub account_status: Option<AccountStatus>,
}
impl From<UserRow> for User {
fn from(row: UserRow) -> Self {
User {
username: row.username,
password: row.password,
active_room: row.active_room,
account_status: row.account_status.unwrap_or_default(),
}
}
}
#[async_trait] #[async_trait]
impl Users for Database { impl Users for Database {
async fn upsert_user(&self, user: &User) -> Result<(), DataError> { async fn upsert_user(&self, user: &User) -> Result<(), DataError> {
let mut tx = self.conn.begin().await?;
sqlx::query( sqlx::query(
r#"INSERT INTO accounts (user_id, password, account_status) r#"INSERT INTO accounts (user_id, password) VALUES (?, ?)
VALUES (?, ?, ?) ON CONFLICT(user_id) DO UPDATE SET password = ?"#,
ON CONFLICT(user_id) DO
UPDATE SET password = ?, account_status = ?"#,
) )
.bind(&user.username) .bind(&user.username)
.bind(&user.password) .bind(&user.password)
.bind(&user.account_status)
.bind(&user.password) .bind(&user.password)
.bind(&user.account_status) .execute(&self.conn)
.execute(&mut tx)
.await?; .await?;
sqlx::query(
r#"INSERT INTO user_state (user_id, active_room)
VALUES (?, ?)
ON CONFLICT(user_id) DO
UPDATE SET active_room = ?"#,
)
.bind(&user.username)
.bind(&user.active_room)
.bind(&user.active_room)
.execute(&mut tx)
.await?;
tx.commit().await?;
Ok(()) Ok(())
} }
async fn delete_user(&self, username: &str) -> Result<(), DataError> { async fn delete_user(&self, username: &str) -> Result<(), DataError> {
let mut tx = self.conn.begin().await?;
sqlx::query(r#"DELETE FROM accounts WHERE user_id = ?"#) sqlx::query(r#"DELETE FROM accounts WHERE user_id = ?"#)
.bind(&username) .bind(&username)
.execute(&mut tx) .execute(&self.conn)
.await?; .await?;
sqlx::query(r#"DELETE FROM user_state WHERE user_id = ?"#)
.bind(&username)
.execute(&mut tx)
.await?;
tx.commit().await?;
Ok(()) Ok(())
} }
async fn get_user(&self, username: &str) -> Result<Option<User>, DataError> { async fn get_user(&self, username: &str) -> Result<Option<User>, DataError> {
// Should be query_as! macro, but the left join breaks it with a // Should be query_as! macro, but the left join breaks it with a
// non existing error message. // non existing error message.
let user_row: Option<User> = sqlx::query_as( let user_row: Option<UserRow> = sqlx::query_as(
r#"SELECT r#"SELECT
a.user_id as "username", a.user_id as "username",
a.password, a.password,
s.active_room, s.active_room,
COALESCE(a.account_status, 'not_registered') as "account_status" s.account_status
FROM accounts a FROM accounts a
LEFT JOIN user_state s on a.user_id = s.user_id LEFT JOIN user_state s on a.user_id = s.user_id
WHERE a.user_id = ?"#, WHERE a.user_id = ?"#,
@ -75,22 +66,7 @@ impl Users for Database {
.fetch_optional(&self.conn) .fetch_optional(&self.conn)
.await?; .await?;
Ok(user_row) Ok(user_row.map(|r| r.into()))
}
//TODO should this logic be moved further up into logic.rs maybe?
async fn get_or_create_user(&self, username: &str) -> Result<User, DataError> {
let maybe_user = self.get_user(username).await?;
match maybe_user {
Some(user) => Ok(user),
None => {
info!("Creating unregistered account for {}", username);
let user = User::unregistered(&username);
self.upsert_user(&user).await?;
Ok(user)
}
}
} }
async fn authenticate_user( async fn authenticate_user(
@ -99,6 +75,10 @@ impl Users for Database {
raw_password: &str, raw_password: &str,
) -> Result<Option<User>, BotError> { ) -> Result<Option<User>, BotError> {
let user = self.get_user(username).await?; let user = self.get_user(username).await?;
println!(
"user pw is {:?}",
user.as_ref().map(|u| u.password.as_ref())
);
Ok(user.filter(|u| u.verify_password(raw_password))) Ok(user.filter(|u| u.verify_password(raw_password)))
} }
} }
@ -108,8 +88,8 @@ mod tests {
use super::*; use super::*;
use crate::db::sqlite::Database; use crate::db::sqlite::Database;
use crate::db::Users; use crate::db::Users;
use crate::models::AccountStatus;
//TODO test selecting user when state doesn't exist.
async fn create_db() -> Database { async fn create_db() -> Database {
let db_path = tempfile::NamedTempFile::new_in(".").unwrap(); let db_path = tempfile::NamedTempFile::new_in(".").unwrap();
crate::db::sqlite::migrator::migrate(db_path.path().to_str().unwrap()) crate::db::sqlite::migrator::migrate(db_path.path().to_str().unwrap())
@ -122,57 +102,13 @@ mod tests {
} }
#[tokio::test(flavor = "multi_thread", worker_threads = 1)] #[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn get_or_create_user_no_user_exists() { async fn create_and_get_user_test() {
let db = create_db().await;
let user = db
.get_or_create_user("@test:example.com")
.await
.expect("User creation didn't work.");
assert_eq!(user.username, "@test:example.com");
let user_again = db
.get_user("@test:example.com")
.await
.expect("User retrieval didn't work.")
.expect("No user returned from option.");
assert_eq!(user, user_again);
}
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn get_or_create_user_when_user_exists() {
let db = create_db().await;
let user = User {
username: "myuser".to_string(),
password: Some("abc".to_string()),
account_status: AccountStatus::Registered,
active_room: Some("myroom".to_string()),
};
let insert_result = db.upsert_user(&user).await;
assert!(insert_result.is_ok());
let user_again = db
.get_or_create_user("myuser")
.await
.expect("User retrieval didn't work.");
assert_eq!(user, user_again);
}
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn create_and_get_full_user_test() {
let db = create_db().await; let db = create_db().await;
let insert_result = db let insert_result = db
.upsert_user(&User { .upsert_user(&User {
username: "myuser".to_string(), username: "myuser".to_string(),
password: Some("abc".to_string()), password: "abc".to_string(),
account_status: AccountStatus::Registered,
active_room: Some("myroom".to_string()),
}) })
.await; .await;
@ -186,94 +122,7 @@ mod tests {
assert!(user.is_some()); assert!(user.is_some());
let user = user.unwrap(); let user = user.unwrap();
assert_eq!(user.username, "myuser"); assert_eq!(user.username, "myuser");
assert_eq!(user.password, Some("abc".to_string())); assert_eq!(user.password, "abc");
assert_eq!(user.account_status, AccountStatus::Registered);
assert_eq!(user.active_room, Some("myroom".to_string()));
}
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn can_get_user_with_no_state_record() {
let db = create_db().await;
let insert_result = db
.upsert_user(&User {
username: "myuser".to_string(),
password: Some("abc".to_string()),
account_status: AccountStatus::Registered,
active_room: Some("myroom".to_string()),
})
.await;
assert!(insert_result.is_ok());
sqlx::query("DELETE FROM user_state")
.execute(&db.conn)
.await
.expect("Could not delete from user_state table.");
let user = db
.get_user("myuser")
.await
.expect("User retrieval query failed");
assert!(user.is_some());
let user = user.unwrap();
assert_eq!(user.username, "myuser");
assert_eq!(user.password, Some("abc".to_string()));
//These should be default values because the state record is missing.
assert_eq!(user.account_status, AccountStatus::NotRegistered);
assert_eq!(user.active_room, None);
}
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn can_insert_without_password() {
let db = create_db().await;
let insert_result = db
.upsert_user(&User {
username: "myuser".to_string(),
password: None,
..Default::default()
})
.await;
assert!(insert_result.is_ok());
let user = db
.get_user("myuser")
.await
.expect("User retrieval query failed");
assert!(user.is_some());
let user = user.unwrap();
assert_eq!(user.username, "myuser");
assert_eq!(user.password, None);
}
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn can_insert_without_active_room() {
let db = create_db().await;
let insert_result = db
.upsert_user(&User {
username: "myuser".to_string(),
active_room: None,
..Default::default()
})
.await;
assert!(insert_result.is_ok());
let user = db
.get_user("myuser")
.await
.expect("User retrieval query failed");
assert!(user.is_some());
let user = user.unwrap();
assert_eq!(user.username, "myuser");
assert_eq!(user.active_room, None);
} }
#[tokio::test(flavor = "multi_thread", worker_threads = 1)] #[tokio::test(flavor = "multi_thread", worker_threads = 1)]
@ -283,8 +132,7 @@ mod tests {
let insert_result1 = db let insert_result1 = db
.upsert_user(&User { .upsert_user(&User {
username: "myuser".to_string(), username: "myuser".to_string(),
password: Some("abc".to_string()), password: "abc".to_string(),
..Default::default()
}) })
.await; .await;
@ -293,9 +141,7 @@ mod tests {
let insert_result2 = db let insert_result2 = db
.upsert_user(&User { .upsert_user(&User {
username: "myuser".to_string(), username: "myuser".to_string(),
password: Some("123".to_string()), password: "123".to_string(),
active_room: Some("room".to_string()),
account_status: AccountStatus::AwaitingActivation,
}) })
.await; .await;
@ -309,11 +155,7 @@ mod tests {
assert!(user.is_some()); assert!(user.is_some());
let user = user.unwrap(); let user = user.unwrap();
assert_eq!(user.username, "myuser"); assert_eq!(user.username, "myuser");
assert_eq!(user.password, "123"); //From second upsert
//From second upsert
assert_eq!(user.password, Some("123".to_string()));
assert_eq!(user.active_room, Some("room".to_string()));
assert_eq!(user.account_status, AccountStatus::AwaitingActivation);
} }
#[tokio::test(flavor = "multi_thread", worker_threads = 1)] #[tokio::test(flavor = "multi_thread", worker_threads = 1)]
@ -323,8 +165,7 @@ mod tests {
let insert_result = db let insert_result = db
.upsert_user(&User { .upsert_user(&User {
username: "myuser".to_string(), username: "myuser".to_string(),
password: Some("abc".to_string()), password: "abc".to_string(),
..Default::default()
}) })
.await; .await;
@ -360,8 +201,7 @@ mod tests {
let insert_result = db let insert_result = db
.upsert_user(&User { .upsert_user(&User {
username: "myuser".to_string(), username: "myuser".to_string(),
password: Some(crate::logic::hash_password("abc").expect("password hash error!")), password: crate::logic::hash_password("abc").expect("password hash error!"),
..Default::default()
}) })
.await; .await;
@ -384,8 +224,7 @@ mod tests {
let insert_result = db let insert_result = db
.upsert_user(&User { .upsert_user(&User {
username: "myuser".to_string(), username: "myuser".to_string(),
password: Some(crate::logic::hash_password("abc").expect("password hash error!")), password: crate::logic::hash_password("abc").expect("password hash error!"),
..Default::default()
}) })
.await; .await;

View File

@ -7,7 +7,7 @@ pub struct RoomInfo {
pub room_name: String, pub room_name: String,
} }
#[derive(Eq, PartialEq, Clone, Copy, Debug, sqlx::Type)] #[derive(Eq, PartialEq, Debug, sqlx::Type)]
#[sqlx(rename_all = "snake_case")] #[sqlx(rename_all = "snake_case")]
pub enum AccountStatus { pub enum AccountStatus {
/// User is not registered, which means the "account" only exists /// User is not registered, which means the "account" only exists
@ -30,7 +30,7 @@ impl Default for AccountStatus {
} }
} }
#[derive(Eq, PartialEq, Clone, Debug, Default, sqlx::FromRow)] #[derive(Eq, PartialEq, Debug, Default)]
pub struct User { pub struct User {
pub username: String, pub username: String,
pub password: Option<String>, pub password: Option<String>,
@ -39,15 +39,6 @@ pub struct User {
} }
impl User { impl User {
/// Create a new unregistered skeleton marker account for a
/// username.
pub fn unregistered(username: &str) -> User {
User {
username: username.to_owned(),
..Default::default()
}
}
pub fn verify_password(&self, raw_password: &str) -> bool { pub fn verify_password(&self, raw_password: &str) -> bool {
self.password self.password
.as_ref() .as_ref()