Make a trait to allow dependence injection

This commit is contained in:
Matthew Gordon 2024-02-26 13:12:31 -04:00
parent af2fb99ad9
commit 1fab78ff96
7 changed files with 84 additions and 41 deletions

View File

@ -22,7 +22,9 @@ use {
serde::Deserialize,
};
pub fn routes() -> Router<Database> {
use super::app::AppState;
pub fn routes<D: Database>() -> Router<AppState<D>> {
Router::new()
.route("/", get(root))
.route("/create_first_admin_user", get(get_create_first_admin_user))
@ -46,7 +48,10 @@ struct IndexTemplate<'a> {
admin_user_name: &'a str,
}
async fn check_jwt(db: &Database, cookie_jar: &CookieJar) -> Result<AuthenticatedAdminUser, Error> {
async fn check_jwt<D: Database>(
db: &D,
cookie_jar: &CookieJar,
) -> Result<AuthenticatedAdminUser, Error> {
match authenticate_user_with_jwt(
db,
cookie_jar
@ -66,7 +71,10 @@ async fn check_jwt(db: &Database, cookie_jar: &CookieJar) -> Result<Authenticate
}
#[tracing::instrument]
async fn root(cookie_jar: CookieJar, State(db): State<Database>) -> Result<Response, Error> {
async fn root<D: Database>(
cookie_jar: CookieJar,
State(AppState { db, .. }): State<AppState<D>>,
) -> Result<Response, Error> {
Ok(if !db.has_admin_users().await? {
Redirect::temporary("admin/create_first_admin_user").into_response()
} else {
@ -91,9 +99,9 @@ struct CreateFirstUserParameters {
}
#[tracing::instrument]
async fn post_create_first_admin_user(
async fn post_create_first_admin_user<D: Database>(
cookie_jar: CookieJar,
State(db): State<Database>,
State(AppState::<D> { db, .. }): State<AppState<D>>,
Form(params): Form<CreateFirstUserParameters>,
) -> Result<(CookieJar, FirstLoginTemplate), Error> {
let user = db
@ -109,9 +117,8 @@ async fn post_create_first_admin_user(
"Could not authenticate newly-created user.".to_string(),
))?;
Ok((
cookie_jar.add(
Cookie::build(("jwt", create_jwt_for_user(&user)?)).same_site(SameSite::Strict),
),
cookie_jar
.add(Cookie::build(("jwt", create_jwt_for_user(&user)?)).same_site(SameSite::Strict)),
FirstLoginTemplate {},
))
}

View File

@ -4,7 +4,12 @@ use {
axum::{routing::get, Router},
};
pub fn routes() -> Router<Database> {
#[derive(Clone)]
pub struct AppState<D: Database> {
pub db: D,
}
pub fn routes<D:Database>() -> Router<AppState<D>> {
Router::new().route("/", get(root))
}

View File

@ -141,7 +141,7 @@ pub fn create_jwt_for_user(user: &AuthenticatedUser) -> Result<String> {
/// Given JWT string created by [create_jwt_for_user()], check if the
/// JWT is valid and return an [AuthenticatedUser] if it is.
#[tracing::instrument]
pub async fn authenticate_user_with_jwt(db: &Database, jwt: &str) -> Result<ParsedJwt> {
pub async fn authenticate_user_with_jwt<D:Database>(db: &D, jwt: &str) -> Result<ParsedJwt> {
if let [header, payload, signature] = jwt.split('.').collect::<Vec<_>>().as_slice() {
let mut mac = mac();
mac.update(format!("{}.{}", header, payload).as_bytes());

View File

@ -81,8 +81,8 @@ impl From<Password> for db::PasswordHash {
}
}
pub async fn authenticate_user_with_password(
db: &Database,
pub async fn authenticate_user_with_password<D: Database>(
db: &D,
user: db::User,
supplied_password: &str,
) -> Result<Option<AuthenticatedUser>, AuthenticationError> {
@ -94,8 +94,8 @@ pub async fn authenticate_user_with_password(
})
}
pub async fn check_if_user_is_admin(
db: &Database,
pub async fn check_if_user_is_admin<D: Database>(
db: &D,
user: &AuthenticatedUser,
) -> Result<Option<AuthenticatedAdminUser>, db::Error> {
if db.is_user_admin(user).await? {

View File

@ -12,7 +12,7 @@
//! needed to that the database schema version (as returned by
//! [get_db_version()]) matches [CURRENT_VERSION].
use super::{Database, Error};
use super::{Error, PostgresDatabase};
/// Defines a database schema migration
#[derive(Debug)]
@ -72,8 +72,8 @@ static MIGRATIONS: &[Migration] = &[
ALTER TABLE users DROP COLUMN IF EXISTS password_salt;
ALTER TABLE users DROP COLUMN IF EXISTS password_hash;
ALTER TABLE users ADD COLUMN password TEXT NOT NULL;"#,
down: "ALTER TABLE users DROP COLUMN password;"
}
down: "ALTER TABLE users DROP COLUMN password;",
},
];
/// The current schema version. Normally this will be the
@ -92,7 +92,7 @@ static CURRENT_VERSION: i32 = 2;
/// E.g. If [CURRENT_VERSION] is 10 but the database is on version 4,
/// then running this function will apply [Migration] 5 to bring the
/// database up to schema version 5.
async fn migrate_up(db: &Database) -> Result<i32, Error> {
async fn migrate_up(db: &PostgresDatabase) -> Result<i32, Error> {
let current_version = get_db_version(db).await?;
if let Some(migration) = MIGRATIONS.iter().find(|m| m.version > current_version) {
let client = db.connection_pool.get().await?;
@ -114,7 +114,7 @@ async fn migrate_up(db: &Database) -> Result<i32, Error> {
/// Revert back to the previous schema version by running the
/// [down](Migration::down) SQL of the current `Migration`
async fn migrate_down(db: &Database) -> Result<i32, Error> {
async fn migrate_down(db: &PostgresDatabase) -> Result<i32, Error> {
let current_version = get_db_version(db).await?;
let mut migration_iter = MIGRATIONS
.iter()
@ -141,7 +141,7 @@ async fn migrate_down(db: &Database) -> Result<i32, Error> {
/// Apply whatever migrations are necessary to bring the database
/// schema to the same version is [CURRENT_VERSION].
pub async fn migrate_to_current_version(db: &Database) -> Result<(), Error> {
pub async fn migrate_to_current_version(db: &PostgresDatabase) -> Result<(), Error> {
migrate_to_version(db, CURRENT_VERSION).await
}
@ -149,7 +149,7 @@ pub async fn migrate_to_current_version(db: &Database) -> Result<(), Error> {
/// schema to the same version as `target_version`.
///
/// This may migrate up or down as required.
async fn migrate_to_version(db: &Database, target_version: i32) -> Result<(), Error> {
async fn migrate_to_version(db: &PostgresDatabase, target_version: i32) -> Result<(), Error> {
let mut version = get_db_version(db).await?;
while version != target_version {
if version < target_version {
@ -162,7 +162,7 @@ async fn migrate_to_version(db: &Database, target_version: i32) -> Result<(), Er
}
/// Get the current schema version of the database.
pub async fn get_db_version(db: &Database) -> Result<i32, Error> {
pub async fn get_db_version(db: &PostgresDatabase) -> Result<i32, Error> {
let client = db.connection_pool.get().await?;
client
.execute(
@ -195,12 +195,12 @@ pub async fn get_db_version(db: &Database) -> Result<i32, Error> {
#[cfg(test)]
mod tests {
use super::super::Database;
use super::super::PostgresDatabase;
use super::*;
async fn test_db() -> Database {
async fn test_db() -> PostgresDatabase {
let url = std::env::var("LOCALITY_TEST_DATABASE_URL").unwrap();
Database::new(&url)
PostgresDatabase::new(&url)
.unwrap()
.connection_pool
.get()
@ -217,7 +217,7 @@ mod tests {
)
.await
.unwrap();
Database::new(&url).unwrap()
PostgresDatabase::new(&url).unwrap()
}
#[test]

View File

@ -8,7 +8,7 @@ mod migrations;
use {
deadpool_postgres::{CreatePoolError, Pool, Runtime},
serde::{Deserialize, Serialize},
std::ops::Deref,
std::{future::Future, ops::Deref},
tokio_postgres::NoTls,
tracing::error,
};
@ -84,11 +84,40 @@ impl From<tokio_postgres::Error> for Error {
pub type Result<T> = std::result::Result<T, Error>;
pub trait Database: std::fmt::Debug + Clone + Send + Sync + 'static {
/// Run migrations as needed to ensure the database schema version
/// match the one used by the current version of the application.
fn migrate_to_current_version(&self) -> impl Future<Output = Result<()>> + Send;
fn get_client(&self) -> impl Future<Output = Result<deadpool_postgres::Client>> + Send;
fn has_admin_users(&self) -> impl Future<Output = Result<bool>> + Send;
fn create_user(
&self,
real_name: &str,
email: &str,
password: &PasswordHash,
) -> impl Future<Output = Result<User>> + Send;
fn get_password_for_user(
&self,
user: &User,
) -> impl Future<Output = Result<PasswordHash>> + Send;
fn create_first_admin_user(
&self,
real_name: &str,
email: &str,
password: &PasswordHash,
) -> impl Future<Output = Result<User>> + Send;
fn is_user_admin(&self, user: &User) -> impl Future<Output = Result<bool>> + Send;
fn get_user_with_id(
&self,
user_id: UserId,
) -> impl Future<Output = Result<Option<User>>> + Send;
}
/// Object that manages the database.
///
/// All database access happens through this struct.
#[derive(Clone, Debug)]
pub struct Database {
pub struct PostgresDatabase {
connection_pool: Pool,
}
@ -118,20 +147,22 @@ impl Deref for PasswordHash {
}
}
impl Database {
impl PostgresDatabase {
/// Create a connection pool and return the [Database].
pub fn new(connection_url: &str) -> InitialisationResult<Database> {
pub fn new(connection_url: &str) -> InitialisationResult<PostgresDatabase> {
let mut config = deadpool_postgres::Config::new();
config.url = Some(connection_url.to_string());
let pg_pool = config.create_pool(Some(Runtime::Tokio1), NoTls)?;
Ok(Database {
Ok(PostgresDatabase {
connection_pool: pg_pool,
})
}
}
impl Database for PostgresDatabase {
/// Run migrations as needed to ensure the database schema version
/// match the one used by the current version of the application.
pub async fn migrate_to_current_version(&self) -> Result<()> {
async fn migrate_to_current_version(&self) -> Result<()> {
migrations::migrate_to_current_version(self).await
}
@ -140,7 +171,7 @@ impl Database {
}
#[tracing::instrument]
pub async fn has_admin_users(&self) -> Result<bool> {
async fn has_admin_users(&self) -> Result<bool> {
let client = self.get_client().await?;
Ok(client
.query_one("SELECT EXISTS(SELECT 1 FROM admin_users);", &[])
@ -149,7 +180,7 @@ impl Database {
}
#[tracing::instrument]
pub async fn create_user(
async fn create_user(
&self,
real_name: &str,
email: &str,
@ -174,7 +205,7 @@ impl Database {
}
#[tracing::instrument]
pub async fn get_password_for_user(&self, user: &User) -> Result<PasswordHash> {
async fn get_password_for_user(&self, user: &User) -> Result<PasswordHash> {
let client = self.get_client().await?;
let row = client
.query_one("SELECT password FROM users WHERE id = $1;", &[&user.id.0])
@ -183,7 +214,7 @@ impl Database {
}
#[tracing::instrument]
pub async fn create_first_admin_user(
async fn create_first_admin_user(
&self,
real_name: &str,
email: &str,
@ -204,7 +235,7 @@ impl Database {
}
#[tracing::instrument]
pub async fn is_user_admin(&self, user: &User) -> Result<bool> {
async fn is_user_admin(&self, user: &User) -> Result<bool> {
Ok(self
.get_client()
.await?
@ -214,7 +245,7 @@ impl Database {
}
#[tracing::instrument]
pub async fn get_user_with_id(&self, user_id: UserId) -> Result<Option<User>> {
async fn get_user_with_id(&self, user_id: UserId) -> Result<Option<User>> {
Ok(self
.get_client()
.await?

View File

@ -12,7 +12,7 @@ mod db;
mod error;
use config::get_config;
use db::Database;
use db::{Database, PostgresDatabase};
#[derive(Error, Debug)]
pub enum Error {
@ -49,13 +49,13 @@ async fn locality_main() -> Result<(), Error> {
.finish();
tracing::subscriber::set_global_default(subscriber)?;
let db_pool = Database::new(&config.database_url)?;
let db_pool = PostgresDatabase::new(&config.database_url)?;
db_pool.migrate_to_current_version().await.unwrap();
let app = app::routes()
.nest("/admin", admin::routes())
.with_state(db_pool)
.with_state(app::AppState { db: db_pool })
.nest_service("/static", ServeDir::new(&config.static_file_path))
.layer(TraceLayer::new_for_http());