Make a trait to allow dependence injection

This commit is contained in:
Matthew Gordon 2024-02-26 13:12:31 -04:00
parent af2fb99ad9
commit 1fab78ff96
7 changed files with 84 additions and 41 deletions

View File

@ -22,7 +22,9 @@ use {
serde::Deserialize, serde::Deserialize,
}; };
pub fn routes() -> Router<Database> { use super::app::AppState;
pub fn routes<D: Database>() -> Router<AppState<D>> {
Router::new() Router::new()
.route("/", get(root)) .route("/", get(root))
.route("/create_first_admin_user", get(get_create_first_admin_user)) .route("/create_first_admin_user", get(get_create_first_admin_user))
@ -46,7 +48,10 @@ struct IndexTemplate<'a> {
admin_user_name: &'a str, admin_user_name: &'a str,
} }
async fn check_jwt(db: &Database, cookie_jar: &CookieJar) -> Result<AuthenticatedAdminUser, Error> { async fn check_jwt<D: Database>(
db: &D,
cookie_jar: &CookieJar,
) -> Result<AuthenticatedAdminUser, Error> {
match authenticate_user_with_jwt( match authenticate_user_with_jwt(
db, db,
cookie_jar cookie_jar
@ -66,7 +71,10 @@ async fn check_jwt(db: &Database, cookie_jar: &CookieJar) -> Result<Authenticate
} }
#[tracing::instrument] #[tracing::instrument]
async fn root(cookie_jar: CookieJar, State(db): State<Database>) -> Result<Response, Error> { async fn root<D: Database>(
cookie_jar: CookieJar,
State(AppState { db, .. }): State<AppState<D>>,
) -> Result<Response, Error> {
Ok(if !db.has_admin_users().await? { Ok(if !db.has_admin_users().await? {
Redirect::temporary("admin/create_first_admin_user").into_response() Redirect::temporary("admin/create_first_admin_user").into_response()
} else { } else {
@ -91,9 +99,9 @@ struct CreateFirstUserParameters {
} }
#[tracing::instrument] #[tracing::instrument]
async fn post_create_first_admin_user( async fn post_create_first_admin_user<D: Database>(
cookie_jar: CookieJar, cookie_jar: CookieJar,
State(db): State<Database>, State(AppState::<D> { db, .. }): State<AppState<D>>,
Form(params): Form<CreateFirstUserParameters>, Form(params): Form<CreateFirstUserParameters>,
) -> Result<(CookieJar, FirstLoginTemplate), Error> { ) -> Result<(CookieJar, FirstLoginTemplate), Error> {
let user = db let user = db
@ -109,9 +117,8 @@ async fn post_create_first_admin_user(
"Could not authenticate newly-created user.".to_string(), "Could not authenticate newly-created user.".to_string(),
))?; ))?;
Ok(( Ok((
cookie_jar.add( cookie_jar
Cookie::build(("jwt", create_jwt_for_user(&user)?)).same_site(SameSite::Strict), .add(Cookie::build(("jwt", create_jwt_for_user(&user)?)).same_site(SameSite::Strict)),
),
FirstLoginTemplate {}, FirstLoginTemplate {},
)) ))
} }

View File

@ -4,7 +4,12 @@ use {
axum::{routing::get, Router}, axum::{routing::get, Router},
}; };
pub fn routes() -> Router<Database> { #[derive(Clone)]
pub struct AppState<D: Database> {
pub db: D,
}
pub fn routes<D:Database>() -> Router<AppState<D>> {
Router::new().route("/", get(root)) Router::new().route("/", get(root))
} }

View File

@ -141,7 +141,7 @@ pub fn create_jwt_for_user(user: &AuthenticatedUser) -> Result<String> {
/// Given JWT string created by [create_jwt_for_user()], check if the /// Given JWT string created by [create_jwt_for_user()], check if the
/// JWT is valid and return an [AuthenticatedUser] if it is. /// JWT is valid and return an [AuthenticatedUser] if it is.
#[tracing::instrument] #[tracing::instrument]
pub async fn authenticate_user_with_jwt(db: &Database, jwt: &str) -> Result<ParsedJwt> { pub async fn authenticate_user_with_jwt<D:Database>(db: &D, jwt: &str) -> Result<ParsedJwt> {
if let [header, payload, signature] = jwt.split('.').collect::<Vec<_>>().as_slice() { if let [header, payload, signature] = jwt.split('.').collect::<Vec<_>>().as_slice() {
let mut mac = mac(); let mut mac = mac();
mac.update(format!("{}.{}", header, payload).as_bytes()); mac.update(format!("{}.{}", header, payload).as_bytes());

View File

@ -81,8 +81,8 @@ impl From<Password> for db::PasswordHash {
} }
} }
pub async fn authenticate_user_with_password( pub async fn authenticate_user_with_password<D: Database>(
db: &Database, db: &D,
user: db::User, user: db::User,
supplied_password: &str, supplied_password: &str,
) -> Result<Option<AuthenticatedUser>, AuthenticationError> { ) -> Result<Option<AuthenticatedUser>, AuthenticationError> {
@ -94,8 +94,8 @@ pub async fn authenticate_user_with_password(
}) })
} }
pub async fn check_if_user_is_admin( pub async fn check_if_user_is_admin<D: Database>(
db: &Database, db: &D,
user: &AuthenticatedUser, user: &AuthenticatedUser,
) -> Result<Option<AuthenticatedAdminUser>, db::Error> { ) -> Result<Option<AuthenticatedAdminUser>, db::Error> {
if db.is_user_admin(user).await? { if db.is_user_admin(user).await? {

View File

@ -12,7 +12,7 @@
//! needed to that the database schema version (as returned by //! needed to that the database schema version (as returned by
//! [get_db_version()]) matches [CURRENT_VERSION]. //! [get_db_version()]) matches [CURRENT_VERSION].
use super::{Database, Error}; use super::{Error, PostgresDatabase};
/// Defines a database schema migration /// Defines a database schema migration
#[derive(Debug)] #[derive(Debug)]
@ -72,8 +72,8 @@ static MIGRATIONS: &[Migration] = &[
ALTER TABLE users DROP COLUMN IF EXISTS password_salt; ALTER TABLE users DROP COLUMN IF EXISTS password_salt;
ALTER TABLE users DROP COLUMN IF EXISTS password_hash; ALTER TABLE users DROP COLUMN IF EXISTS password_hash;
ALTER TABLE users ADD COLUMN password TEXT NOT NULL;"#, ALTER TABLE users ADD COLUMN password TEXT NOT NULL;"#,
down: "ALTER TABLE users DROP COLUMN password;" down: "ALTER TABLE users DROP COLUMN password;",
} },
]; ];
/// The current schema version. Normally this will be the /// The current schema version. Normally this will be the
@ -92,7 +92,7 @@ static CURRENT_VERSION: i32 = 2;
/// E.g. If [CURRENT_VERSION] is 10 but the database is on version 4, /// E.g. If [CURRENT_VERSION] is 10 but the database is on version 4,
/// then running this function will apply [Migration] 5 to bring the /// then running this function will apply [Migration] 5 to bring the
/// database up to schema version 5. /// database up to schema version 5.
async fn migrate_up(db: &Database) -> Result<i32, Error> { async fn migrate_up(db: &PostgresDatabase) -> Result<i32, Error> {
let current_version = get_db_version(db).await?; let current_version = get_db_version(db).await?;
if let Some(migration) = MIGRATIONS.iter().find(|m| m.version > current_version) { if let Some(migration) = MIGRATIONS.iter().find(|m| m.version > current_version) {
let client = db.connection_pool.get().await?; let client = db.connection_pool.get().await?;
@ -114,7 +114,7 @@ async fn migrate_up(db: &Database) -> Result<i32, Error> {
/// Revert back to the previous schema version by running the /// Revert back to the previous schema version by running the
/// [down](Migration::down) SQL of the current `Migration` /// [down](Migration::down) SQL of the current `Migration`
async fn migrate_down(db: &Database) -> Result<i32, Error> { async fn migrate_down(db: &PostgresDatabase) -> Result<i32, Error> {
let current_version = get_db_version(db).await?; let current_version = get_db_version(db).await?;
let mut migration_iter = MIGRATIONS let mut migration_iter = MIGRATIONS
.iter() .iter()
@ -141,7 +141,7 @@ async fn migrate_down(db: &Database) -> Result<i32, Error> {
/// Apply whatever migrations are necessary to bring the database /// Apply whatever migrations are necessary to bring the database
/// schema to the same version is [CURRENT_VERSION]. /// schema to the same version is [CURRENT_VERSION].
pub async fn migrate_to_current_version(db: &Database) -> Result<(), Error> { pub async fn migrate_to_current_version(db: &PostgresDatabase) -> Result<(), Error> {
migrate_to_version(db, CURRENT_VERSION).await migrate_to_version(db, CURRENT_VERSION).await
} }
@ -149,7 +149,7 @@ pub async fn migrate_to_current_version(db: &Database) -> Result<(), Error> {
/// schema to the same version as `target_version`. /// schema to the same version as `target_version`.
/// ///
/// This may migrate up or down as required. /// This may migrate up or down as required.
async fn migrate_to_version(db: &Database, target_version: i32) -> Result<(), Error> { async fn migrate_to_version(db: &PostgresDatabase, target_version: i32) -> Result<(), Error> {
let mut version = get_db_version(db).await?; let mut version = get_db_version(db).await?;
while version != target_version { while version != target_version {
if version < target_version { if version < target_version {
@ -162,7 +162,7 @@ async fn migrate_to_version(db: &Database, target_version: i32) -> Result<(), Er
} }
/// Get the current schema version of the database. /// Get the current schema version of the database.
pub async fn get_db_version(db: &Database) -> Result<i32, Error> { pub async fn get_db_version(db: &PostgresDatabase) -> Result<i32, Error> {
let client = db.connection_pool.get().await?; let client = db.connection_pool.get().await?;
client client
.execute( .execute(
@ -195,12 +195,12 @@ pub async fn get_db_version(db: &Database) -> Result<i32, Error> {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::super::Database; use super::super::PostgresDatabase;
use super::*; use super::*;
async fn test_db() -> Database { async fn test_db() -> PostgresDatabase {
let url = std::env::var("LOCALITY_TEST_DATABASE_URL").unwrap(); let url = std::env::var("LOCALITY_TEST_DATABASE_URL").unwrap();
Database::new(&url) PostgresDatabase::new(&url)
.unwrap() .unwrap()
.connection_pool .connection_pool
.get() .get()
@ -217,7 +217,7 @@ mod tests {
) )
.await .await
.unwrap(); .unwrap();
Database::new(&url).unwrap() PostgresDatabase::new(&url).unwrap()
} }
#[test] #[test]

View File

@ -8,7 +8,7 @@ mod migrations;
use { use {
deadpool_postgres::{CreatePoolError, Pool, Runtime}, deadpool_postgres::{CreatePoolError, Pool, Runtime},
serde::{Deserialize, Serialize}, serde::{Deserialize, Serialize},
std::ops::Deref, std::{future::Future, ops::Deref},
tokio_postgres::NoTls, tokio_postgres::NoTls,
tracing::error, tracing::error,
}; };
@ -84,11 +84,40 @@ impl From<tokio_postgres::Error> for Error {
pub type Result<T> = std::result::Result<T, Error>; pub type Result<T> = std::result::Result<T, Error>;
pub trait Database: std::fmt::Debug + Clone + Send + Sync + 'static {
/// Run migrations as needed to ensure the database schema version
/// match the one used by the current version of the application.
fn migrate_to_current_version(&self) -> impl Future<Output = Result<()>> + Send;
fn get_client(&self) -> impl Future<Output = Result<deadpool_postgres::Client>> + Send;
fn has_admin_users(&self) -> impl Future<Output = Result<bool>> + Send;
fn create_user(
&self,
real_name: &str,
email: &str,
password: &PasswordHash,
) -> impl Future<Output = Result<User>> + Send;
fn get_password_for_user(
&self,
user: &User,
) -> impl Future<Output = Result<PasswordHash>> + Send;
fn create_first_admin_user(
&self,
real_name: &str,
email: &str,
password: &PasswordHash,
) -> impl Future<Output = Result<User>> + Send;
fn is_user_admin(&self, user: &User) -> impl Future<Output = Result<bool>> + Send;
fn get_user_with_id(
&self,
user_id: UserId,
) -> impl Future<Output = Result<Option<User>>> + Send;
}
/// Object that manages the database. /// Object that manages the database.
/// ///
/// All database access happens through this struct. /// All database access happens through this struct.
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct Database { pub struct PostgresDatabase {
connection_pool: Pool, connection_pool: Pool,
} }
@ -118,20 +147,22 @@ impl Deref for PasswordHash {
} }
} }
impl Database { impl PostgresDatabase {
/// Create a connection pool and return the [Database]. /// Create a connection pool and return the [Database].
pub fn new(connection_url: &str) -> InitialisationResult<Database> { pub fn new(connection_url: &str) -> InitialisationResult<PostgresDatabase> {
let mut config = deadpool_postgres::Config::new(); let mut config = deadpool_postgres::Config::new();
config.url = Some(connection_url.to_string()); config.url = Some(connection_url.to_string());
let pg_pool = config.create_pool(Some(Runtime::Tokio1), NoTls)?; let pg_pool = config.create_pool(Some(Runtime::Tokio1), NoTls)?;
Ok(Database { Ok(PostgresDatabase {
connection_pool: pg_pool, connection_pool: pg_pool,
}) })
} }
}
impl Database for PostgresDatabase {
/// Run migrations as needed to ensure the database schema version /// Run migrations as needed to ensure the database schema version
/// match the one used by the current version of the application. /// match the one used by the current version of the application.
pub async fn migrate_to_current_version(&self) -> Result<()> { async fn migrate_to_current_version(&self) -> Result<()> {
migrations::migrate_to_current_version(self).await migrations::migrate_to_current_version(self).await
} }
@ -140,7 +171,7 @@ impl Database {
} }
#[tracing::instrument] #[tracing::instrument]
pub async fn has_admin_users(&self) -> Result<bool> { async fn has_admin_users(&self) -> Result<bool> {
let client = self.get_client().await?; let client = self.get_client().await?;
Ok(client Ok(client
.query_one("SELECT EXISTS(SELECT 1 FROM admin_users);", &[]) .query_one("SELECT EXISTS(SELECT 1 FROM admin_users);", &[])
@ -149,7 +180,7 @@ impl Database {
} }
#[tracing::instrument] #[tracing::instrument]
pub async fn create_user( async fn create_user(
&self, &self,
real_name: &str, real_name: &str,
email: &str, email: &str,
@ -174,7 +205,7 @@ impl Database {
} }
#[tracing::instrument] #[tracing::instrument]
pub async fn get_password_for_user(&self, user: &User) -> Result<PasswordHash> { async fn get_password_for_user(&self, user: &User) -> Result<PasswordHash> {
let client = self.get_client().await?; let client = self.get_client().await?;
let row = client let row = client
.query_one("SELECT password FROM users WHERE id = $1;", &[&user.id.0]) .query_one("SELECT password FROM users WHERE id = $1;", &[&user.id.0])
@ -183,7 +214,7 @@ impl Database {
} }
#[tracing::instrument] #[tracing::instrument]
pub async fn create_first_admin_user( async fn create_first_admin_user(
&self, &self,
real_name: &str, real_name: &str,
email: &str, email: &str,
@ -204,7 +235,7 @@ impl Database {
} }
#[tracing::instrument] #[tracing::instrument]
pub async fn is_user_admin(&self, user: &User) -> Result<bool> { async fn is_user_admin(&self, user: &User) -> Result<bool> {
Ok(self Ok(self
.get_client() .get_client()
.await? .await?
@ -214,7 +245,7 @@ impl Database {
} }
#[tracing::instrument] #[tracing::instrument]
pub async fn get_user_with_id(&self, user_id: UserId) -> Result<Option<User>> { async fn get_user_with_id(&self, user_id: UserId) -> Result<Option<User>> {
Ok(self Ok(self
.get_client() .get_client()
.await? .await?

View File

@ -12,7 +12,7 @@ mod db;
mod error; mod error;
use config::get_config; use config::get_config;
use db::Database; use db::{Database, PostgresDatabase};
#[derive(Error, Debug)] #[derive(Error, Debug)]
pub enum Error { pub enum Error {
@ -49,13 +49,13 @@ async fn locality_main() -> Result<(), Error> {
.finish(); .finish();
tracing::subscriber::set_global_default(subscriber)?; tracing::subscriber::set_global_default(subscriber)?;
let db_pool = Database::new(&config.database_url)?; let db_pool = PostgresDatabase::new(&config.database_url)?;
db_pool.migrate_to_current_version().await.unwrap(); db_pool.migrate_to_current_version().await.unwrap();
let app = app::routes() let app = app::routes()
.nest("/admin", admin::routes()) .nest("/admin", admin::routes())
.with_state(db_pool) .with_state(app::AppState { db: db_pool })
.nest_service("/static", ServeDir::new(&config.static_file_path)) .nest_service("/static", ServeDir::new(&config.static_file_path))
.layer(TraceLayer::new_for_http()); .layer(TraceLayer::new_for_http());