From 1fab78ff96f0341e424d51ebda1121d6b6b75764 Mon Sep 17 00:00:00 2001 From: Matthew Gordon Date: Mon, 26 Feb 2024 13:12:31 -0400 Subject: [PATCH] Make a trait to allow dependence injection --- src/admin.rs | 23 ++++++++++------ src/app.rs | 7 ++++- src/authentication/jwt.rs | 2 +- src/authentication/mod.rs | 8 +++--- src/db/migrations.rs | 24 ++++++++--------- src/db/mod.rs | 55 ++++++++++++++++++++++++++++++--------- src/main.rs | 6 ++--- 7 files changed, 84 insertions(+), 41 deletions(-) diff --git a/src/admin.rs b/src/admin.rs index f14c3c7..cb81775 100644 --- a/src/admin.rs +++ b/src/admin.rs @@ -22,7 +22,9 @@ use { serde::Deserialize, }; -pub fn routes() -> Router { +use super::app::AppState; + +pub fn routes() -> Router> { Router::new() .route("/", get(root)) .route("/create_first_admin_user", get(get_create_first_admin_user)) @@ -46,7 +48,10 @@ struct IndexTemplate<'a> { admin_user_name: &'a str, } -async fn check_jwt(db: &Database, cookie_jar: &CookieJar) -> Result { +async fn check_jwt( + db: &D, + cookie_jar: &CookieJar, +) -> Result { match authenticate_user_with_jwt( db, cookie_jar @@ -66,7 +71,10 @@ async fn check_jwt(db: &Database, cookie_jar: &CookieJar) -> Result) -> Result { +async fn root( + cookie_jar: CookieJar, + State(AppState { db, .. }): State>, +) -> Result { Ok(if !db.has_admin_users().await? { Redirect::temporary("admin/create_first_admin_user").into_response() } else { @@ -91,9 +99,9 @@ struct CreateFirstUserParameters { } #[tracing::instrument] -async fn post_create_first_admin_user( +async fn post_create_first_admin_user( cookie_jar: CookieJar, - State(db): State, + State(AppState:: { db, .. }): State>, Form(params): Form, ) -> Result<(CookieJar, FirstLoginTemplate), Error> { let user = db @@ -109,9 +117,8 @@ async fn post_create_first_admin_user( "Could not authenticate newly-created user.".to_string(), ))?; Ok(( - cookie_jar.add( - Cookie::build(("jwt", create_jwt_for_user(&user)?)).same_site(SameSite::Strict), - ), + cookie_jar + .add(Cookie::build(("jwt", create_jwt_for_user(&user)?)).same_site(SameSite::Strict)), FirstLoginTemplate {}, )) } diff --git a/src/app.rs b/src/app.rs index b8fa6a0..724cc11 100644 --- a/src/app.rs +++ b/src/app.rs @@ -4,7 +4,12 @@ use { axum::{routing::get, Router}, }; -pub fn routes() -> Router { +#[derive(Clone)] +pub struct AppState { + pub db: D, +} + +pub fn routes() -> Router> { Router::new().route("/", get(root)) } diff --git a/src/authentication/jwt.rs b/src/authentication/jwt.rs index 82ced9b..a631dda 100644 --- a/src/authentication/jwt.rs +++ b/src/authentication/jwt.rs @@ -141,7 +141,7 @@ pub fn create_jwt_for_user(user: &AuthenticatedUser) -> Result { /// Given JWT string created by [create_jwt_for_user()], check if the /// JWT is valid and return an [AuthenticatedUser] if it is. #[tracing::instrument] -pub async fn authenticate_user_with_jwt(db: &Database, jwt: &str) -> Result { +pub async fn authenticate_user_with_jwt(db: &D, jwt: &str) -> Result { if let [header, payload, signature] = jwt.split('.').collect::>().as_slice() { let mut mac = mac(); mac.update(format!("{}.{}", header, payload).as_bytes()); diff --git a/src/authentication/mod.rs b/src/authentication/mod.rs index b124534..6f9155d 100644 --- a/src/authentication/mod.rs +++ b/src/authentication/mod.rs @@ -81,8 +81,8 @@ impl From for db::PasswordHash { } } -pub async fn authenticate_user_with_password( - db: &Database, +pub async fn authenticate_user_with_password( + db: &D, user: db::User, supplied_password: &str, ) -> Result, AuthenticationError> { @@ -94,8 +94,8 @@ pub async fn authenticate_user_with_password( }) } -pub async fn check_if_user_is_admin( - db: &Database, +pub async fn check_if_user_is_admin( + db: &D, user: &AuthenticatedUser, ) -> Result, db::Error> { if db.is_user_admin(user).await? { diff --git a/src/db/migrations.rs b/src/db/migrations.rs index 1b2a16b..1784614 100644 --- a/src/db/migrations.rs +++ b/src/db/migrations.rs @@ -12,7 +12,7 @@ //! needed to that the database schema version (as returned by //! [get_db_version()]) matches [CURRENT_VERSION]. -use super::{Database, Error}; +use super::{Error, PostgresDatabase}; /// Defines a database schema migration #[derive(Debug)] @@ -72,8 +72,8 @@ static MIGRATIONS: &[Migration] = &[ ALTER TABLE users DROP COLUMN IF EXISTS password_salt; ALTER TABLE users DROP COLUMN IF EXISTS password_hash; ALTER TABLE users ADD COLUMN password TEXT NOT NULL;"#, - down: "ALTER TABLE users DROP COLUMN password;" - } + down: "ALTER TABLE users DROP COLUMN password;", + }, ]; /// The current schema version. Normally this will be the @@ -92,7 +92,7 @@ static CURRENT_VERSION: i32 = 2; /// E.g. If [CURRENT_VERSION] is 10 but the database is on version 4, /// then running this function will apply [Migration] 5 to bring the /// database up to schema version 5. -async fn migrate_up(db: &Database) -> Result { +async fn migrate_up(db: &PostgresDatabase) -> Result { let current_version = get_db_version(db).await?; if let Some(migration) = MIGRATIONS.iter().find(|m| m.version > current_version) { let client = db.connection_pool.get().await?; @@ -114,7 +114,7 @@ async fn migrate_up(db: &Database) -> Result { /// Revert back to the previous schema version by running the /// [down](Migration::down) SQL of the current `Migration` -async fn migrate_down(db: &Database) -> Result { +async fn migrate_down(db: &PostgresDatabase) -> Result { let current_version = get_db_version(db).await?; let mut migration_iter = MIGRATIONS .iter() @@ -141,7 +141,7 @@ async fn migrate_down(db: &Database) -> Result { /// Apply whatever migrations are necessary to bring the database /// schema to the same version is [CURRENT_VERSION]. -pub async fn migrate_to_current_version(db: &Database) -> Result<(), Error> { +pub async fn migrate_to_current_version(db: &PostgresDatabase) -> Result<(), Error> { migrate_to_version(db, CURRENT_VERSION).await } @@ -149,7 +149,7 @@ pub async fn migrate_to_current_version(db: &Database) -> Result<(), Error> { /// schema to the same version as `target_version`. /// /// This may migrate up or down as required. -async fn migrate_to_version(db: &Database, target_version: i32) -> Result<(), Error> { +async fn migrate_to_version(db: &PostgresDatabase, target_version: i32) -> Result<(), Error> { let mut version = get_db_version(db).await?; while version != target_version { if version < target_version { @@ -162,7 +162,7 @@ async fn migrate_to_version(db: &Database, target_version: i32) -> Result<(), Er } /// Get the current schema version of the database. -pub async fn get_db_version(db: &Database) -> Result { +pub async fn get_db_version(db: &PostgresDatabase) -> Result { let client = db.connection_pool.get().await?; client .execute( @@ -195,12 +195,12 @@ pub async fn get_db_version(db: &Database) -> Result { #[cfg(test)] mod tests { - use super::super::Database; + use super::super::PostgresDatabase; use super::*; - async fn test_db() -> Database { + async fn test_db() -> PostgresDatabase { let url = std::env::var("LOCALITY_TEST_DATABASE_URL").unwrap(); - Database::new(&url) + PostgresDatabase::new(&url) .unwrap() .connection_pool .get() @@ -217,7 +217,7 @@ mod tests { ) .await .unwrap(); - Database::new(&url).unwrap() + PostgresDatabase::new(&url).unwrap() } #[test] diff --git a/src/db/mod.rs b/src/db/mod.rs index 0981534..6266a64 100644 --- a/src/db/mod.rs +++ b/src/db/mod.rs @@ -8,7 +8,7 @@ mod migrations; use { deadpool_postgres::{CreatePoolError, Pool, Runtime}, serde::{Deserialize, Serialize}, - std::ops::Deref, + std::{future::Future, ops::Deref}, tokio_postgres::NoTls, tracing::error, }; @@ -84,11 +84,40 @@ impl From for Error { pub type Result = std::result::Result; +pub trait Database: std::fmt::Debug + Clone + Send + Sync + 'static { + /// Run migrations as needed to ensure the database schema version + /// match the one used by the current version of the application. + fn migrate_to_current_version(&self) -> impl Future> + Send; + fn get_client(&self) -> impl Future> + Send; + fn has_admin_users(&self) -> impl Future> + Send; + fn create_user( + &self, + real_name: &str, + email: &str, + password: &PasswordHash, + ) -> impl Future> + Send; + fn get_password_for_user( + &self, + user: &User, + ) -> impl Future> + Send; + fn create_first_admin_user( + &self, + real_name: &str, + email: &str, + password: &PasswordHash, + ) -> impl Future> + Send; + fn is_user_admin(&self, user: &User) -> impl Future> + Send; + fn get_user_with_id( + &self, + user_id: UserId, + ) -> impl Future>> + Send; +} + /// Object that manages the database. /// /// All database access happens through this struct. #[derive(Clone, Debug)] -pub struct Database { +pub struct PostgresDatabase { connection_pool: Pool, } @@ -118,20 +147,22 @@ impl Deref for PasswordHash { } } -impl Database { +impl PostgresDatabase { /// Create a connection pool and return the [Database]. - pub fn new(connection_url: &str) -> InitialisationResult { + pub fn new(connection_url: &str) -> InitialisationResult { let mut config = deadpool_postgres::Config::new(); config.url = Some(connection_url.to_string()); let pg_pool = config.create_pool(Some(Runtime::Tokio1), NoTls)?; - Ok(Database { + Ok(PostgresDatabase { connection_pool: pg_pool, }) } +} +impl Database for PostgresDatabase { /// Run migrations as needed to ensure the database schema version /// match the one used by the current version of the application. - pub async fn migrate_to_current_version(&self) -> Result<()> { + async fn migrate_to_current_version(&self) -> Result<()> { migrations::migrate_to_current_version(self).await } @@ -140,7 +171,7 @@ impl Database { } #[tracing::instrument] - pub async fn has_admin_users(&self) -> Result { + async fn has_admin_users(&self) -> Result { let client = self.get_client().await?; Ok(client .query_one("SELECT EXISTS(SELECT 1 FROM admin_users);", &[]) @@ -149,7 +180,7 @@ impl Database { } #[tracing::instrument] - pub async fn create_user( + async fn create_user( &self, real_name: &str, email: &str, @@ -174,7 +205,7 @@ impl Database { } #[tracing::instrument] - pub async fn get_password_for_user(&self, user: &User) -> Result { + async fn get_password_for_user(&self, user: &User) -> Result { let client = self.get_client().await?; let row = client .query_one("SELECT password FROM users WHERE id = $1;", &[&user.id.0]) @@ -183,7 +214,7 @@ impl Database { } #[tracing::instrument] - pub async fn create_first_admin_user( + async fn create_first_admin_user( &self, real_name: &str, email: &str, @@ -204,7 +235,7 @@ impl Database { } #[tracing::instrument] - pub async fn is_user_admin(&self, user: &User) -> Result { + async fn is_user_admin(&self, user: &User) -> Result { Ok(self .get_client() .await? @@ -214,7 +245,7 @@ impl Database { } #[tracing::instrument] - pub async fn get_user_with_id(&self, user_id: UserId) -> Result> { + async fn get_user_with_id(&self, user_id: UserId) -> Result> { Ok(self .get_client() .await? diff --git a/src/main.rs b/src/main.rs index 74b1675..53fa68f 100644 --- a/src/main.rs +++ b/src/main.rs @@ -12,7 +12,7 @@ mod db; mod error; use config::get_config; -use db::Database; +use db::{Database, PostgresDatabase}; #[derive(Error, Debug)] pub enum Error { @@ -49,13 +49,13 @@ async fn locality_main() -> Result<(), Error> { .finish(); tracing::subscriber::set_global_default(subscriber)?; - let db_pool = Database::new(&config.database_url)?; + let db_pool = PostgresDatabase::new(&config.database_url)?; db_pool.migrate_to_current_version().await.unwrap(); let app = app::routes() .nest("/admin", admin::routes()) - .with_state(db_pool) + .with_state(app::AppState { db: db_pool }) .nest_service("/static", ServeDir::new(&config.static_file_path)) .layer(TraceLayer::new_for_http());