diff --git a/Cargo.lock b/Cargo.lock index 07fb519..fe612d7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -50,23 +50,6 @@ dependencies = [ "subtle", ] -[[package]] -name = "ahash" -version = "0.4.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "739f4a8db6605981345c5654f3a85b056ce52f37a39d34da03f25bf2151ea16e" - -[[package]] -name = "ahash" -version = "0.6.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "796540673305a66d127804eef19ad696f1f204b8c1025aaca4958c17eab32877" -dependencies = [ - "getrandom 0.2.3", - "once_cell", - "version_check", -] - [[package]] name = "ahash" version = "0.7.4" @@ -234,12 +217,6 @@ dependencies = [ "generic-array", ] -[[package]] -name = "build_const" -version = "0.2.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b4ae4235e6dac0694637c763029ecea1a2ec9e4e06ec2729bd21ba4d9c863eb7" - [[package]] name = "bumpalo" version = "3.7.0" @@ -408,13 +385,19 @@ dependencies = [ [[package]] name = "crc" -version = "1.8.1" +version = "2.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d663548de7f5cca343f1e0a48d14dcfb0e9eb4e079ec58883b7251539fa10aeb" +checksum = "10c2722795460108a7872e1cd933a85d6ec38abc4baecad51028f702da28889f" dependencies = [ - "build_const", + "crc-catalog", ] +[[package]] +name = "crc-catalog" +version = "1.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ccaeedb56da03b09f598226e25e80088cb4cd25f316e6e4df7d695f0feeb1403" + [[package]] name = "crc32fast" version = "1.2.1" @@ -741,6 +724,17 @@ dependencies = [ "futures-util", ] +[[package]] +name = "futures-intrusive" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "62007592ac46aa7c2b6416f7deb9a8a8f63a01e0f1d6e1787d5630170db2b63e" +dependencies = [ + "futures-core", + "lock_api", + "parking_lot", +] + [[package]] name = "futures-io" version = "0.3.17" @@ -940,9 +934,6 @@ name = "hashbrown" version = "0.9.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d7afe4a420e3fe79967a00898cc1f4db7c8a49a9333a29f8a4bd76a253d5cd04" -dependencies = [ - "ahash 0.4.7", -] [[package]] name = "hashbrown" @@ -950,16 +941,16 @@ version = "0.11.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ab5ef0d4909ef3724cc8cce6ccc8572c5c817592e9285f5464f8e86f8bd3726e" dependencies = [ - "ahash 0.7.4", + "ahash", ] [[package]] name = "hashlink" -version = "0.6.0" +version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d99cf782f0dc4372d26846bec3de7804ceb5df083c2d4462c0b8d2330e894fa8" +checksum = "7249a3129cbc1ffccd74857f81464a323a152173cdb134e0fd81bc803b29facf" dependencies = [ - "hashbrown 0.9.1", + "hashbrown 0.11.2", ] [[package]] @@ -1239,9 +1230,9 @@ checksum = "3cb00336871be5ed2c8ed44b60ae9959dc5b9f08539422ed43f09e34ecaeba21" [[package]] name = "libsqlite3-sys" -version = "0.20.1" +version = "0.22.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "64d31059f22935e6c31830db5249ba2b7ecd54fd73a9909286f0a67aa55c2fbd" +checksum = "290b64917f8b0cb885d9de0f9959fe1f775d7fa12f1da2db9001c1c8ab60f89d" dependencies = [ "cc", "pkg-config", @@ -2150,9 +2141,9 @@ dependencies = [ [[package]] name = "refinery" -version = "0.5.0" +version = "0.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e29bd9c881127d714f4b5b9fdd9ea7651f3dd254922e959a10f6ada620e841da" +checksum = "9f3a3d4976479c5e9a50352cf9117896b581c939e81f4169295cb9b353d30bc8" dependencies = [ "refinery-core", "refinery-macros", @@ -2160,9 +2151,9 @@ dependencies = [ [[package]] name = "refinery-core" -version = "0.5.1" +version = "0.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "53260bc01535ea10c553ce0fc410609ba2dc0a9f4c9b4503e0af842dd4a6f89d" +checksum = "1eb7989daaee90c44763f644bfa2c3ce46b0a241b9966c1c2be9e65b4918815b" dependencies = [ "async-trait", "cfg-if", @@ -2181,9 +2172,9 @@ dependencies = [ [[package]] name = "refinery-macros" -version = "0.5.0" +version = "0.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7a79ff62c9b674b62c06a09cc8becf06cbafba9952afa1d8174e7e15f2c4ed43" +checksum = "1e2d142a0c173f7e096ae1297d677cb4bb056fc80f25f31d6b0d2a82da07beee" dependencies = [ "proc-macro2 1.0.29", "quote 1.0.9", @@ -2498,9 +2489,9 @@ dependencies = [ [[package]] name = "rusqlite" -version = "0.24.2" +version = "0.25.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d5f38ee71cbab2c827ec0ac24e76f82eca723cee92c509a65f67dee393c25112" +checksum = "57adcf67c8faaf96f3248c2a7b419a0dbc52ebe36ba83dd57fe83827c1ea4eb3" dependencies = [ "bitflags", "fallible-iterator", @@ -2768,9 +2759,9 @@ dependencies = [ [[package]] name = "sqlx" -version = "0.5.1" +version = "0.5.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c2739d54a2ae9fdd0f545cb4e4b5574efb95e2ec71b7f921678e246fb20dcaaf" +checksum = "0e4b94ab0f8c21ee4899b93b06451ef5d965f1a355982ee73684338228498440" dependencies = [ "sqlx-core", "sqlx-macros", @@ -2778,11 +2769,11 @@ dependencies = [ [[package]] name = "sqlx-core" -version = "0.5.1" +version = "0.5.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b1cad9cae4ca8947eba1a90e8ec7d3c59e7a768e2f120dc9013b669c34a90711" +checksum = "ec28b91a01e1fe286d6ba66f68289a2286df023fc97444e1fd86c2fd6d5dc026" dependencies = [ - "ahash 0.6.3", + "ahash", "atoi", "bitflags", "byteorder", @@ -2794,6 +2785,7 @@ dependencies = [ "either", "futures-channel", "futures-core", + "futures-intrusive", "futures-util", "hashlink", "hex", @@ -2819,9 +2811,9 @@ dependencies = [ [[package]] name = "sqlx-macros" -version = "0.5.1" +version = "0.5.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "01caee2b3935b4efe152f3262afbe51546ce3b1fc27ad61014e1b3cf5f55366e" +checksum = "4dc33c35d54774eed73d54568d47a6ac099aed8af5e1556a017c131be88217d5" dependencies = [ "dotenv", "either", @@ -2842,9 +2834,9 @@ dependencies = [ [[package]] name = "sqlx-rt" -version = "0.3.0" +version = "0.5.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4ce2e16b6774c671cc183e1d202386fdf9cde1e8468c1894a7f2a63eb671c4f4" +checksum = "14302b678d9c76b28f2e60115211e25e0aabc938269991745a169753dc00e35c" dependencies = [ "native-tls", "once_cell", diff --git a/dicebot/Cargo.toml b/dicebot/Cargo.toml index ba446b7..5112afb 100644 --- a/dicebot/Cargo.toml +++ b/dicebot/Cargo.toml @@ -33,7 +33,7 @@ futures = "0.3" html2text = "0.2" phf = { version = "0.8", features = ["macros"] } matrix-sdk = { version = "0.3" } -refinery = { version = "0.5", features = ["rusqlite"]} +refinery = { version = "0.6", features = ["rusqlite"]} barrel = { version = "0.6", features = ["sqlite3"] } tempfile = "3" substring = "1.4" diff --git a/dicebot/src/db/sqlite/rooms.rs b/dicebot/src/db/sqlite/rooms.rs index 2b09a76..9d2636a 100644 --- a/dicebot/src/db/sqlite/rooms.rs +++ b/dicebot/src/db/sqlite/rooms.rs @@ -53,34 +53,41 @@ impl Rooms for Database { mod tests { use crate::db::sqlite::Database; use crate::db::Rooms; + use std::future::Future; - async fn create_db() -> Database { + async fn with_db(f: impl FnOnce(Database) -> Fut) + where + Fut: Future, + { let db_path = tempfile::NamedTempFile::new_in(".").unwrap(); crate::db::sqlite::migrator::migrate(db_path.path().to_str().unwrap()) .await .unwrap(); - Database::new(db_path.path().to_str().unwrap()) + let db = Database::new(db_path.path().to_str().unwrap()) .await - .unwrap() + .unwrap(); + + f(db).await; } #[tokio::test(flavor = "multi_thread", worker_threads = 1)] async fn should_process_test() { - let db = create_db().await; + with_db(|db| async move { + let first_check = db + .should_process("myroom", "myeventid") + .await + .expect("should_process failed in first insert"); - let first_check = db - .should_process("myroom", "myeventid") - .await - .expect("should_process failed in first insert"); + assert_eq!(first_check, true); - assert_eq!(first_check, true); + let second_check = db + .should_process("myroom", "myeventid") + .await + .expect("should_process failed in first insert"); - let second_check = db - .should_process("myroom", "myeventid") - .await - .expect("should_process failed in first insert"); - - assert_eq!(second_check, false); + assert_eq!(second_check, false); + }) + .await; } } diff --git a/dicebot/src/db/sqlite/state.rs b/dicebot/src/db/sqlite/state.rs index 31a922d..cc6bed3 100644 --- a/dicebot/src/db/sqlite/state.rs +++ b/dicebot/src/db/sqlite/state.rs @@ -37,54 +37,64 @@ impl DbState for Database { mod tests { use crate::db::sqlite::Database; use crate::db::DbState; + use std::future::Future; - async fn create_db() -> Database { + async fn with_db(f: impl FnOnce(Database) -> Fut) + where + Fut: Future, + { let db_path = tempfile::NamedTempFile::new_in(".").unwrap(); crate::db::sqlite::migrator::migrate(db_path.path().to_str().unwrap()) .await .unwrap(); - Database::new(db_path.path().to_str().unwrap()) + let db = Database::new(db_path.path().to_str().unwrap()) .await - .unwrap() + .unwrap(); + + f(db).await; } #[tokio::test(flavor = "multi_thread", worker_threads = 1)] async fn set_and_get_device_id() { - let db = create_db().await; + with_db(|db| async move { + db.set_device_id("device_id") + .await + .expect("Could not set device ID"); - db.set_device_id("device_id") - .await - .expect("Could not set device ID"); + let device_id = db.get_device_id().await.expect("Could not get device ID"); - let device_id = db.get_device_id().await.expect("Could not get device ID"); - - assert!(device_id.is_some()); - assert_eq!(device_id.unwrap(), "device_id"); + assert!(device_id.is_some()); + assert_eq!(device_id.unwrap(), "device_id"); + }) + .await; } #[tokio::test(flavor = "multi_thread", worker_threads = 1)] async fn no_device_id_set_returns_none() { - let db = create_db().await; - let device_id = db.get_device_id().await.expect("Could not get device ID"); - assert!(device_id.is_none()); + with_db(|db| async move { + let device_id = db.get_device_id().await.expect("Could not get device ID"); + assert!(device_id.is_none()); + }) + .await; } #[tokio::test(flavor = "multi_thread", worker_threads = 1)] async fn can_update_device_id() { - let db = create_db().await; + with_db(|db| async move { + db.set_device_id("device_id") + .await + .expect("Could not set device ID"); - db.set_device_id("device_id") - .await - .expect("Could not set device ID"); + db.set_device_id("device_id2") + .await + .expect("Could not set device ID"); - db.set_device_id("device_id2") - .await - .expect("Could not set device ID"); + let device_id = db.get_device_id().await.expect("Could not get device ID"); - let device_id = db.get_device_id().await.expect("Could not get device ID"); - - assert!(device_id.is_some()); - assert_eq!(device_id.unwrap(), "device_id2"); + assert!(device_id.is_some()); + assert_eq!(device_id.unwrap(), "device_id2"); + }) + .await; } } diff --git a/dicebot/src/db/sqlite/users.rs b/dicebot/src/db/sqlite/users.rs index 71c07f3..082784b 100644 --- a/dicebot/src/db/sqlite/users.rs +++ b/dicebot/src/db/sqlite/users.rs @@ -91,251 +91,271 @@ mod tests { use crate::db::sqlite::Database; use crate::db::Users; use crate::models::AccountStatus; + use std::future::Future; - async fn create_db() -> Database { + async fn with_db(f: impl FnOnce(Database) -> Fut) + where + Fut: Future, + { let db_path = tempfile::NamedTempFile::new_in(".").unwrap(); crate::db::sqlite::migrator::migrate(db_path.path().to_str().unwrap()) .await .unwrap(); - Database::new(db_path.path().to_str().unwrap()) + let db = Database::new(db_path.path().to_str().unwrap()) .await - .unwrap() + .unwrap(); + + f(db).await; } #[tokio::test(flavor = "multi_thread", worker_threads = 1)] async fn create_and_get_full_user_test() { - let db = create_db().await; + with_db(|db| async move { + let insert_result = db + .upsert_user(&User { + username: "myuser".to_string(), + password: Some("abc".to_string()), + account_status: AccountStatus::Registered, + active_room: Some("myroom".to_string()), + }) + .await; - let insert_result = db - .upsert_user(&User { - username: "myuser".to_string(), - password: Some("abc".to_string()), - account_status: AccountStatus::Registered, - active_room: Some("myroom".to_string()), - }) - .await; + assert!(insert_result.is_ok()); - assert!(insert_result.is_ok()); + let user = db + .get_user("myuser") + .await + .expect("User retrieval query failed"); - let user = db - .get_user("myuser") - .await - .expect("User retrieval query failed"); - - assert!(user.is_some()); - let user = user.unwrap(); - assert_eq!(user.username, "myuser"); - assert_eq!(user.password, Some("abc".to_string())); - assert_eq!(user.account_status, AccountStatus::Registered); - assert_eq!(user.active_room, Some("myroom".to_string())); + assert!(user.is_some()); + let user = user.unwrap(); + assert_eq!(user.username, "myuser"); + assert_eq!(user.password, Some("abc".to_string())); + assert_eq!(user.account_status, AccountStatus::Registered); + assert_eq!(user.active_room, Some("myroom".to_string())); + }) + .await; } #[tokio::test(flavor = "multi_thread", worker_threads = 1)] async fn can_get_user_with_no_state_record() { - let db = create_db().await; + with_db(|db| async move { + let insert_result = db + .upsert_user(&User { + username: "myuser".to_string(), + password: Some("abc".to_string()), + account_status: AccountStatus::AwaitingActivation, + active_room: Some("myroom".to_string()), + }) + .await; - let insert_result = db - .upsert_user(&User { - username: "myuser".to_string(), - password: Some("abc".to_string()), - account_status: AccountStatus::AwaitingActivation, - active_room: Some("myroom".to_string()), - }) - .await; + assert!(insert_result.is_ok()); - assert!(insert_result.is_ok()); + sqlx::query("DELETE FROM user_state") + .execute(&db.conn) + .await + .expect("Could not delete from user_state table."); - sqlx::query("DELETE FROM user_state") - .execute(&db.conn) - .await - .expect("Could not delete from user_state table."); + let user = db + .get_user("myuser") + .await + .expect("User retrieval query failed"); - let user = db - .get_user("myuser") - .await - .expect("User retrieval query failed"); + assert!(user.is_some()); + let user = user.unwrap(); + assert_eq!(user.username, "myuser"); + assert_eq!(user.password, Some("abc".to_string())); + assert_eq!(user.account_status, AccountStatus::AwaitingActivation); - assert!(user.is_some()); - let user = user.unwrap(); - assert_eq!(user.username, "myuser"); - assert_eq!(user.password, Some("abc".to_string())); - assert_eq!(user.account_status, AccountStatus::AwaitingActivation); - - //These should be default values because the state record is missing. - assert_eq!(user.active_room, None); + //These should be default values because the state record is missing. + assert_eq!(user.active_room, None); + }) + .await; } #[tokio::test(flavor = "multi_thread", worker_threads = 1)] async fn can_insert_without_password() { - let db = create_db().await; + with_db(|db| async move { + let insert_result = db + .upsert_user(&User { + username: "myuser".to_string(), + password: None, + ..Default::default() + }) + .await; - let insert_result = db - .upsert_user(&User { - username: "myuser".to_string(), - password: None, - ..Default::default() - }) - .await; + assert!(insert_result.is_ok()); - assert!(insert_result.is_ok()); + let user = db + .get_user("myuser") + .await + .expect("User retrieval query failed"); - let user = db - .get_user("myuser") - .await - .expect("User retrieval query failed"); - - assert!(user.is_some()); - let user = user.unwrap(); - assert_eq!(user.username, "myuser"); - assert_eq!(user.password, None); + assert!(user.is_some()); + let user = user.unwrap(); + assert_eq!(user.username, "myuser"); + assert_eq!(user.password, None); + }) + .await; } #[tokio::test(flavor = "multi_thread", worker_threads = 1)] async fn can_insert_without_active_room() { - let db = create_db().await; + with_db(|db| async move { + let insert_result = db + .upsert_user(&User { + username: "myuser".to_string(), + active_room: None, + ..Default::default() + }) + .await; - let insert_result = db - .upsert_user(&User { - username: "myuser".to_string(), - active_room: None, - ..Default::default() - }) - .await; + assert!(insert_result.is_ok()); - assert!(insert_result.is_ok()); + let user = db + .get_user("myuser") + .await + .expect("User retrieval query failed"); - let user = db - .get_user("myuser") - .await - .expect("User retrieval query failed"); - - assert!(user.is_some()); - let user = user.unwrap(); - assert_eq!(user.username, "myuser"); - assert_eq!(user.active_room, None); + assert!(user.is_some()); + let user = user.unwrap(); + assert_eq!(user.username, "myuser"); + assert_eq!(user.active_room, None); + }) + .await; } #[tokio::test(flavor = "multi_thread", worker_threads = 1)] async fn can_update_user() { - let db = create_db().await; + with_db(|db| async move { + let insert_result1 = db + .upsert_user(&User { + username: "myuser".to_string(), + password: Some("abc".to_string()), + ..Default::default() + }) + .await; - let insert_result1 = db - .upsert_user(&User { - username: "myuser".to_string(), - password: Some("abc".to_string()), - ..Default::default() - }) - .await; + assert!(insert_result1.is_ok()); - assert!(insert_result1.is_ok()); + let insert_result2 = db + .upsert_user(&User { + username: "myuser".to_string(), + password: Some("123".to_string()), + active_room: Some("room".to_string()), + account_status: AccountStatus::AwaitingActivation, + }) + .await; - let insert_result2 = db - .upsert_user(&User { - username: "myuser".to_string(), - password: Some("123".to_string()), - active_room: Some("room".to_string()), - account_status: AccountStatus::AwaitingActivation, - }) - .await; + assert!(insert_result2.is_ok()); - assert!(insert_result2.is_ok()); + let user = db + .get_user("myuser") + .await + .expect("User retrieval query failed"); - let user = db - .get_user("myuser") - .await - .expect("User retrieval query failed"); + assert!(user.is_some()); + let user = user.unwrap(); + assert_eq!(user.username, "myuser"); - assert!(user.is_some()); - let user = user.unwrap(); - assert_eq!(user.username, "myuser"); - - //From second upsert - assert_eq!(user.password, Some("123".to_string())); - assert_eq!(user.active_room, Some("room".to_string())); - assert_eq!(user.account_status, AccountStatus::AwaitingActivation); + //From second upsert + assert_eq!(user.password, Some("123".to_string())); + assert_eq!(user.active_room, Some("room".to_string())); + assert_eq!(user.account_status, AccountStatus::AwaitingActivation); + }) + .await; } #[tokio::test(flavor = "multi_thread", worker_threads = 1)] async fn can_delete_user() { - let db = create_db().await; + with_db(|db| async move { + let insert_result = db + .upsert_user(&User { + username: "myuser".to_string(), + password: Some("abc".to_string()), + ..Default::default() + }) + .await; - let insert_result = db - .upsert_user(&User { - username: "myuser".to_string(), - password: Some("abc".to_string()), - ..Default::default() - }) - .await; + assert!(insert_result.is_ok()); - assert!(insert_result.is_ok()); + db.delete_user("myuser") + .await + .expect("User deletion query failed"); - db.delete_user("myuser") - .await - .expect("User deletion query failed"); + let user = db + .get_user("myuser") + .await + .expect("User retrieval query failed"); - let user = db - .get_user("myuser") - .await - .expect("User retrieval query failed"); - - assert!(user.is_none()); + assert!(user.is_none()); + }) + .await; } #[tokio::test(flavor = "multi_thread", worker_threads = 1)] async fn username_not_in_db_returns_none() { - let db = create_db().await; - let user = db - .get_user("does not exist") - .await - .expect("Get user query failure"); + with_db(|db| async move { + let user = db + .get_user("does not exist") + .await + .expect("Get user query failure"); - assert!(user.is_none()); + assert!(user.is_none()); + }) + .await; } #[tokio::test(flavor = "multi_thread", worker_threads = 1)] async fn authenticate_user_is_some_with_valid_password() { - let db = create_db().await; + with_db(|db| async move { + let insert_result = db + .upsert_user(&User { + username: "myuser".to_string(), + password: Some( + crate::logic::hash_password("abc").expect("password hash error!"), + ), + ..Default::default() + }) + .await; - let insert_result = db - .upsert_user(&User { - username: "myuser".to_string(), - password: Some(crate::logic::hash_password("abc").expect("password hash error!")), - ..Default::default() - }) - .await; + assert!(insert_result.is_ok()); - assert!(insert_result.is_ok()); + let user = db + .authenticate_user("myuser", "abc") + .await + .expect("User retrieval query failed"); - let user = db - .authenticate_user("myuser", "abc") - .await - .expect("User retrieval query failed"); - - assert!(user.is_some()); - let user = user.unwrap(); - assert_eq!(user.username, "myuser"); + assert!(user.is_some()); + let user = user.unwrap(); + assert_eq!(user.username, "myuser"); + }) + .await; } #[tokio::test(flavor = "multi_thread", worker_threads = 1)] async fn authenticate_user_is_none_with_wrong_password() { - let db = create_db().await; + with_db(|db| async move { + let insert_result = db + .upsert_user(&User { + username: "myuser".to_string(), + password: Some( + crate::logic::hash_password("abc").expect("password hash error!"), + ), + ..Default::default() + }) + .await; - let insert_result = db - .upsert_user(&User { - username: "myuser".to_string(), - password: Some(crate::logic::hash_password("abc").expect("password hash error!")), - ..Default::default() - }) - .await; + assert!(insert_result.is_ok()); - assert!(insert_result.is_ok()); + let user = db + .authenticate_user("myuser", "wrong-password") + .await + .expect("User retrieval query failed"); - let user = db - .authenticate_user("myuser", "wrong-password") - .await - .expect("User retrieval query failed"); - - assert!(user.is_none()); + assert!(user.is_none()); + }) + .await; } } diff --git a/dicebot/src/db/sqlite/variables.rs b/dicebot/src/db/sqlite/variables.rs index 2898bba..43f8ecb 100644 --- a/dicebot/src/db/sqlite/variables.rs +++ b/dicebot/src/db/sqlite/variables.rs @@ -102,143 +102,156 @@ mod tests { use super::*; use crate::db::sqlite::Database; use crate::db::Variables; + use std::future::Future; - async fn create_db() -> Database { + async fn with_db(f: impl FnOnce(Database) -> Fut) + where + Fut: Future, + { let db_path = tempfile::NamedTempFile::new_in(".").unwrap(); crate::db::sqlite::migrator::migrate(db_path.path().to_str().unwrap()) .await .unwrap(); - Database::new(db_path.path().to_str().unwrap()) + let db = Database::new(db_path.path().to_str().unwrap()) .await - .unwrap() + .unwrap(); + + f(db).await; } #[tokio::test(flavor = "multi_thread", worker_threads = 1)] async fn set_and_get_variable_test() { - let db = create_db().await; + with_db(|db| async move { + db.set_user_variable("myuser", "myroom", "myvariable", 1) + .await + .expect("Could not set variable"); - db.set_user_variable("myuser", "myroom", "myvariable", 1) - .await - .expect("Could not set variable"); + let value = db + .get_user_variable("myuser", "myroom", "myvariable") + .await + .expect("Could not get variable"); - let value = db - .get_user_variable("myuser", "myroom", "myvariable") - .await - .expect("Could not get variable"); - - assert_eq!(value, 1); + assert_eq!(value, 1); + }) + .await; } #[tokio::test(flavor = "multi_thread", worker_threads = 1)] async fn get_missing_variable_test() { - let db = create_db().await; + with_db(|db| async move { + let value = db.get_user_variable("myuser", "myroom", "myvariable").await; - let value = db.get_user_variable("myuser", "myroom", "myvariable").await; - - assert!(value.is_err()); - assert!(matches!( - value.err().unwrap(), - DataError::KeyDoesNotExist(_) - )); + assert!(value.is_err()); + assert!(matches!( + value.err().unwrap(), + DataError::KeyDoesNotExist(_) + )); + }) + .await; } #[tokio::test(flavor = "multi_thread", worker_threads = 1)] async fn get_other_user_variable_test() { - let db = create_db().await; + with_db(|db| async move { + db.set_user_variable("myuser1", "myroom", "myvariable", 1) + .await + .expect("Could not set variable"); - db.set_user_variable("myuser1", "myroom", "myvariable", 1) - .await - .expect("Could not set variable"); + let value = db + .get_user_variable("myuser2", "myroom", "myvariable") + .await; - let value = db - .get_user_variable("myuser2", "myroom", "myvariable") - .await; - - assert!(value.is_err()); - assert!(matches!( - value.err().unwrap(), - DataError::KeyDoesNotExist(_) - )); + assert!(value.is_err()); + assert!(matches!( + value.err().unwrap(), + DataError::KeyDoesNotExist(_) + )); + }) + .await; } #[tokio::test(flavor = "multi_thread", worker_threads = 1)] async fn count_variables_test() { - let db = create_db().await; + with_db(|db| async move { + for variable_name in &["var1", "var2", "var3"] { + db.set_user_variable("myuser", "myroom", variable_name, 1) + .await + .expect("Could not set variable"); + } - for variable_name in &["var1", "var2", "var3"] { - db.set_user_variable("myuser", "myroom", variable_name, 1) + let count = db + .get_variable_count("myuser", "myroom") .await - .expect("Could not set variable"); - } + .expect("Could not get count."); - let count = db - .get_variable_count("myuser", "myroom") - .await - .expect("Could not get count."); - - assert_eq!(count, 3); + assert_eq!(count, 3); + }) + .await; } #[tokio::test(flavor = "multi_thread", worker_threads = 1)] async fn count_variables_respects_user_id() { - let db = create_db().await; + with_db(|db| async move { + for variable_name in &["var1", "var2", "var3"] { + db.set_user_variable("different-user", "myroom", variable_name, 1) + .await + .expect("Could not set variable"); + } - for variable_name in &["var1", "var2", "var3"] { - db.set_user_variable("different-user", "myroom", variable_name, 1) + let count = db + .get_variable_count("myuser", "myroom") .await - .expect("Could not set variable"); - } + .expect("Could not get count."); - let count = db - .get_variable_count("myuser", "myroom") - .await - .expect("Could not get count."); - - assert_eq!(count, 0); + assert_eq!(count, 0); + }) + .await; } #[tokio::test(flavor = "multi_thread", worker_threads = 1)] async fn count_variables_respects_room_id() { - let db = create_db().await; + with_db(|db| async move { + for variable_name in &["var1", "var2", "var3"] { + db.set_user_variable("myuser", "different-room", variable_name, 1) + .await + .expect("Could not set variable"); + } - for variable_name in &["var1", "var2", "var3"] { - db.set_user_variable("myuser", "different-room", variable_name, 1) + let count = db + .get_variable_count("myuser", "myroom") .await - .expect("Could not set variable"); - } + .expect("Could not get count."); - let count = db - .get_variable_count("myuser", "myroom") - .await - .expect("Could not get count."); - - assert_eq!(count, 0); + assert_eq!(count, 0); + }) + .await; } #[tokio::test(flavor = "multi_thread", worker_threads = 1)] async fn delete_variable_test() { - let db = create_db().await; + with_db(|db| async move { + for variable_name in &["var1", "var2", "var3"] { + db.set_user_variable("myuser", "myroom", variable_name, 1) + .await + .expect("Could not set variable"); + } - for variable_name in &["var1", "var2", "var3"] { - db.set_user_variable("myuser", "myroom", variable_name, 1) + db.delete_user_variable("myuser", "myroom", "var1") .await - .expect("Could not set variable"); - } + .expect("Could not delete variable."); - db.delete_user_variable("myuser", "myroom", "var1") - .await - .expect("Could not delete variable."); + let count = db + .get_variable_count("myuser", "myroom") + .await + .expect("Could not get count"); - let count = db - .get_variable_count("myuser", "myroom") - .await - .expect("Could not get count"); + assert_eq!(count, 2); - assert_eq!(count, 2); - - let var1 = db.get_user_variable("myuser", "myroom", "var1").await; - assert!(var1.is_err()); - assert!(matches!(var1.err().unwrap(), DataError::KeyDoesNotExist(_))); + let var1 = db.get_user_variable("myuser", "myroom", "var1").await; + assert!(var1.is_err()); + assert!(matches!(var1.err().unwrap(), DataError::KeyDoesNotExist(_))); + }) + .await; } } diff --git a/dicebot/src/logic.rs b/dicebot/src/logic.rs index 70fa796..3ec8e5b 100644 --- a/dicebot/src/logic.rs +++ b/dicebot/src/logic.rs @@ -71,53 +71,61 @@ mod tests { use super::*; use crate::db::Users; use crate::models::{AccountStatus, User}; + use std::future::Future; - async fn create_db() -> Database { + async fn with_db(f: impl FnOnce(Database) -> Fut) + where + Fut: Future, + { let db_path = tempfile::NamedTempFile::new_in(".").unwrap(); crate::db::sqlite::migrator::migrate(db_path.path().to_str().unwrap()) .await .unwrap(); - Database::new(db_path.path().to_str().unwrap()) + let db = Database::new(db_path.path().to_str().unwrap()) .await - .unwrap() + .unwrap(); + + f(db).await; } #[tokio::test(flavor = "multi_thread", worker_threads = 1)] async fn get_account_no_user_exists() { - let db = create_db().await; + with_db(|db| async move { + let account = get_account(&db, "@test:example.com") + .await + .expect("Account retrieval didn't work"); - let account = get_account(&db, "@test:example.com") - .await - .expect("Account retrieval didn't work"); + assert!(matches!(account, Account::Transient(_))); - assert!(matches!(account, Account::Transient(_))); - - let user = account.transient_user().unwrap(); - assert_eq!(user.username, "@test:example.com"); + let user = account.transient_user().unwrap(); + assert_eq!(user.username, "@test:example.com"); + }) + .await; } #[tokio::test(flavor = "multi_thread", worker_threads = 1)] async fn get_or_create_user_when_user_exists() { - let db = create_db().await; + with_db(|db| async move { + let user = User { + username: "myuser".to_string(), + password: Some("abc".to_string()), + account_status: AccountStatus::Registered, + active_room: Some("myroom".to_string()), + }; - let user = User { - username: "myuser".to_string(), - password: Some("abc".to_string()), - account_status: AccountStatus::Registered, - active_room: Some("myroom".to_string()), - }; + let insert_result = db.upsert_user(&user).await; + assert!(insert_result.is_ok()); - let insert_result = db.upsert_user(&user).await; - assert!(insert_result.is_ok()); + let account = get_account(&db, "myuser") + .await + .expect("Account retrieval did not work"); - let account = get_account(&db, "myuser") - .await - .expect("Account retrieval did not work"); + assert!(matches!(account, Account::Registered(_))); - assert!(matches!(account, Account::Registered(_))); - - let user_again = account.registered_user().unwrap(); - assert_eq!(user, *user_again); + let user_again = account.registered_user().unwrap(); + assert_eq!(user, *user_again); + }) + .await; } }