Make a trait to allow dependence injection
This commit is contained in:
parent
af2fb99ad9
commit
1fab78ff96
23
src/admin.rs
23
src/admin.rs
|
|
@ -22,7 +22,9 @@ use {
|
|||
serde::Deserialize,
|
||||
};
|
||||
|
||||
pub fn routes() -> Router<Database> {
|
||||
use super::app::AppState;
|
||||
|
||||
pub fn routes<D: Database>() -> Router<AppState<D>> {
|
||||
Router::new()
|
||||
.route("/", get(root))
|
||||
.route("/create_first_admin_user", get(get_create_first_admin_user))
|
||||
|
|
@ -46,7 +48,10 @@ struct IndexTemplate<'a> {
|
|||
admin_user_name: &'a str,
|
||||
}
|
||||
|
||||
async fn check_jwt(db: &Database, cookie_jar: &CookieJar) -> Result<AuthenticatedAdminUser, Error> {
|
||||
async fn check_jwt<D: Database>(
|
||||
db: &D,
|
||||
cookie_jar: &CookieJar,
|
||||
) -> Result<AuthenticatedAdminUser, Error> {
|
||||
match authenticate_user_with_jwt(
|
||||
db,
|
||||
cookie_jar
|
||||
|
|
@ -66,7 +71,10 @@ async fn check_jwt(db: &Database, cookie_jar: &CookieJar) -> Result<Authenticate
|
|||
}
|
||||
|
||||
#[tracing::instrument]
|
||||
async fn root(cookie_jar: CookieJar, State(db): State<Database>) -> Result<Response, Error> {
|
||||
async fn root<D: Database>(
|
||||
cookie_jar: CookieJar,
|
||||
State(AppState { db, .. }): State<AppState<D>>,
|
||||
) -> Result<Response, Error> {
|
||||
Ok(if !db.has_admin_users().await? {
|
||||
Redirect::temporary("admin/create_first_admin_user").into_response()
|
||||
} else {
|
||||
|
|
@ -91,9 +99,9 @@ struct CreateFirstUserParameters {
|
|||
}
|
||||
|
||||
#[tracing::instrument]
|
||||
async fn post_create_first_admin_user(
|
||||
async fn post_create_first_admin_user<D: Database>(
|
||||
cookie_jar: CookieJar,
|
||||
State(db): State<Database>,
|
||||
State(AppState::<D> { db, .. }): State<AppState<D>>,
|
||||
Form(params): Form<CreateFirstUserParameters>,
|
||||
) -> Result<(CookieJar, FirstLoginTemplate), Error> {
|
||||
let user = db
|
||||
|
|
@ -109,9 +117,8 @@ async fn post_create_first_admin_user(
|
|||
"Could not authenticate newly-created user.".to_string(),
|
||||
))?;
|
||||
Ok((
|
||||
cookie_jar.add(
|
||||
Cookie::build(("jwt", create_jwt_for_user(&user)?)).same_site(SameSite::Strict),
|
||||
),
|
||||
cookie_jar
|
||||
.add(Cookie::build(("jwt", create_jwt_for_user(&user)?)).same_site(SameSite::Strict)),
|
||||
FirstLoginTemplate {},
|
||||
))
|
||||
}
|
||||
|
|
|
|||
|
|
@ -4,7 +4,12 @@ use {
|
|||
axum::{routing::get, Router},
|
||||
};
|
||||
|
||||
pub fn routes() -> Router<Database> {
|
||||
#[derive(Clone)]
|
||||
pub struct AppState<D: Database> {
|
||||
pub db: D,
|
||||
}
|
||||
|
||||
pub fn routes<D:Database>() -> Router<AppState<D>> {
|
||||
Router::new().route("/", get(root))
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -141,7 +141,7 @@ pub fn create_jwt_for_user(user: &AuthenticatedUser) -> Result<String> {
|
|||
/// Given JWT string created by [create_jwt_for_user()], check if the
|
||||
/// JWT is valid and return an [AuthenticatedUser] if it is.
|
||||
#[tracing::instrument]
|
||||
pub async fn authenticate_user_with_jwt(db: &Database, jwt: &str) -> Result<ParsedJwt> {
|
||||
pub async fn authenticate_user_with_jwt<D:Database>(db: &D, jwt: &str) -> Result<ParsedJwt> {
|
||||
if let [header, payload, signature] = jwt.split('.').collect::<Vec<_>>().as_slice() {
|
||||
let mut mac = mac();
|
||||
mac.update(format!("{}.{}", header, payload).as_bytes());
|
||||
|
|
|
|||
|
|
@ -81,8 +81,8 @@ impl From<Password> for db::PasswordHash {
|
|||
}
|
||||
}
|
||||
|
||||
pub async fn authenticate_user_with_password(
|
||||
db: &Database,
|
||||
pub async fn authenticate_user_with_password<D: Database>(
|
||||
db: &D,
|
||||
user: db::User,
|
||||
supplied_password: &str,
|
||||
) -> Result<Option<AuthenticatedUser>, AuthenticationError> {
|
||||
|
|
@ -94,8 +94,8 @@ pub async fn authenticate_user_with_password(
|
|||
})
|
||||
}
|
||||
|
||||
pub async fn check_if_user_is_admin(
|
||||
db: &Database,
|
||||
pub async fn check_if_user_is_admin<D: Database>(
|
||||
db: &D,
|
||||
user: &AuthenticatedUser,
|
||||
) -> Result<Option<AuthenticatedAdminUser>, db::Error> {
|
||||
if db.is_user_admin(user).await? {
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@
|
|||
//! needed to that the database schema version (as returned by
|
||||
//! [get_db_version()]) matches [CURRENT_VERSION].
|
||||
|
||||
use super::{Database, Error};
|
||||
use super::{Error, PostgresDatabase};
|
||||
|
||||
/// Defines a database schema migration
|
||||
#[derive(Debug)]
|
||||
|
|
@ -72,8 +72,8 @@ static MIGRATIONS: &[Migration] = &[
|
|||
ALTER TABLE users DROP COLUMN IF EXISTS password_salt;
|
||||
ALTER TABLE users DROP COLUMN IF EXISTS password_hash;
|
||||
ALTER TABLE users ADD COLUMN password TEXT NOT NULL;"#,
|
||||
down: "ALTER TABLE users DROP COLUMN password;"
|
||||
}
|
||||
down: "ALTER TABLE users DROP COLUMN password;",
|
||||
},
|
||||
];
|
||||
|
||||
/// The current schema version. Normally this will be the
|
||||
|
|
@ -92,7 +92,7 @@ static CURRENT_VERSION: i32 = 2;
|
|||
/// E.g. If [CURRENT_VERSION] is 10 but the database is on version 4,
|
||||
/// then running this function will apply [Migration] 5 to bring the
|
||||
/// database up to schema version 5.
|
||||
async fn migrate_up(db: &Database) -> Result<i32, Error> {
|
||||
async fn migrate_up(db: &PostgresDatabase) -> Result<i32, Error> {
|
||||
let current_version = get_db_version(db).await?;
|
||||
if let Some(migration) = MIGRATIONS.iter().find(|m| m.version > current_version) {
|
||||
let client = db.connection_pool.get().await?;
|
||||
|
|
@ -114,7 +114,7 @@ async fn migrate_up(db: &Database) -> Result<i32, Error> {
|
|||
|
||||
/// Revert back to the previous schema version by running the
|
||||
/// [down](Migration::down) SQL of the current `Migration`
|
||||
async fn migrate_down(db: &Database) -> Result<i32, Error> {
|
||||
async fn migrate_down(db: &PostgresDatabase) -> Result<i32, Error> {
|
||||
let current_version = get_db_version(db).await?;
|
||||
let mut migration_iter = MIGRATIONS
|
||||
.iter()
|
||||
|
|
@ -141,7 +141,7 @@ async fn migrate_down(db: &Database) -> Result<i32, Error> {
|
|||
|
||||
/// Apply whatever migrations are necessary to bring the database
|
||||
/// schema to the same version is [CURRENT_VERSION].
|
||||
pub async fn migrate_to_current_version(db: &Database) -> Result<(), Error> {
|
||||
pub async fn migrate_to_current_version(db: &PostgresDatabase) -> Result<(), Error> {
|
||||
migrate_to_version(db, CURRENT_VERSION).await
|
||||
}
|
||||
|
||||
|
|
@ -149,7 +149,7 @@ pub async fn migrate_to_current_version(db: &Database) -> Result<(), Error> {
|
|||
/// schema to the same version as `target_version`.
|
||||
///
|
||||
/// This may migrate up or down as required.
|
||||
async fn migrate_to_version(db: &Database, target_version: i32) -> Result<(), Error> {
|
||||
async fn migrate_to_version(db: &PostgresDatabase, target_version: i32) -> Result<(), Error> {
|
||||
let mut version = get_db_version(db).await?;
|
||||
while version != target_version {
|
||||
if version < target_version {
|
||||
|
|
@ -162,7 +162,7 @@ async fn migrate_to_version(db: &Database, target_version: i32) -> Result<(), Er
|
|||
}
|
||||
|
||||
/// Get the current schema version of the database.
|
||||
pub async fn get_db_version(db: &Database) -> Result<i32, Error> {
|
||||
pub async fn get_db_version(db: &PostgresDatabase) -> Result<i32, Error> {
|
||||
let client = db.connection_pool.get().await?;
|
||||
client
|
||||
.execute(
|
||||
|
|
@ -195,12 +195,12 @@ pub async fn get_db_version(db: &Database) -> Result<i32, Error> {
|
|||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::super::Database;
|
||||
use super::super::PostgresDatabase;
|
||||
use super::*;
|
||||
|
||||
async fn test_db() -> Database {
|
||||
async fn test_db() -> PostgresDatabase {
|
||||
let url = std::env::var("LOCALITY_TEST_DATABASE_URL").unwrap();
|
||||
Database::new(&url)
|
||||
PostgresDatabase::new(&url)
|
||||
.unwrap()
|
||||
.connection_pool
|
||||
.get()
|
||||
|
|
@ -217,7 +217,7 @@ mod tests {
|
|||
)
|
||||
.await
|
||||
.unwrap();
|
||||
Database::new(&url).unwrap()
|
||||
PostgresDatabase::new(&url).unwrap()
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ mod migrations;
|
|||
use {
|
||||
deadpool_postgres::{CreatePoolError, Pool, Runtime},
|
||||
serde::{Deserialize, Serialize},
|
||||
std::ops::Deref,
|
||||
std::{future::Future, ops::Deref},
|
||||
tokio_postgres::NoTls,
|
||||
tracing::error,
|
||||
};
|
||||
|
|
@ -84,11 +84,40 @@ impl From<tokio_postgres::Error> for Error {
|
|||
|
||||
pub type Result<T> = std::result::Result<T, Error>;
|
||||
|
||||
pub trait Database: std::fmt::Debug + Clone + Send + Sync + 'static {
|
||||
/// Run migrations as needed to ensure the database schema version
|
||||
/// match the one used by the current version of the application.
|
||||
fn migrate_to_current_version(&self) -> impl Future<Output = Result<()>> + Send;
|
||||
fn get_client(&self) -> impl Future<Output = Result<deadpool_postgres::Client>> + Send;
|
||||
fn has_admin_users(&self) -> impl Future<Output = Result<bool>> + Send;
|
||||
fn create_user(
|
||||
&self,
|
||||
real_name: &str,
|
||||
email: &str,
|
||||
password: &PasswordHash,
|
||||
) -> impl Future<Output = Result<User>> + Send;
|
||||
fn get_password_for_user(
|
||||
&self,
|
||||
user: &User,
|
||||
) -> impl Future<Output = Result<PasswordHash>> + Send;
|
||||
fn create_first_admin_user(
|
||||
&self,
|
||||
real_name: &str,
|
||||
email: &str,
|
||||
password: &PasswordHash,
|
||||
) -> impl Future<Output = Result<User>> + Send;
|
||||
fn is_user_admin(&self, user: &User) -> impl Future<Output = Result<bool>> + Send;
|
||||
fn get_user_with_id(
|
||||
&self,
|
||||
user_id: UserId,
|
||||
) -> impl Future<Output = Result<Option<User>>> + Send;
|
||||
}
|
||||
|
||||
/// Object that manages the database.
|
||||
///
|
||||
/// All database access happens through this struct.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct Database {
|
||||
pub struct PostgresDatabase {
|
||||
connection_pool: Pool,
|
||||
}
|
||||
|
||||
|
|
@ -118,20 +147,22 @@ impl Deref for PasswordHash {
|
|||
}
|
||||
}
|
||||
|
||||
impl Database {
|
||||
impl PostgresDatabase {
|
||||
/// Create a connection pool and return the [Database].
|
||||
pub fn new(connection_url: &str) -> InitialisationResult<Database> {
|
||||
pub fn new(connection_url: &str) -> InitialisationResult<PostgresDatabase> {
|
||||
let mut config = deadpool_postgres::Config::new();
|
||||
config.url = Some(connection_url.to_string());
|
||||
let pg_pool = config.create_pool(Some(Runtime::Tokio1), NoTls)?;
|
||||
Ok(Database {
|
||||
Ok(PostgresDatabase {
|
||||
connection_pool: pg_pool,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Database for PostgresDatabase {
|
||||
/// Run migrations as needed to ensure the database schema version
|
||||
/// match the one used by the current version of the application.
|
||||
pub async fn migrate_to_current_version(&self) -> Result<()> {
|
||||
async fn migrate_to_current_version(&self) -> Result<()> {
|
||||
migrations::migrate_to_current_version(self).await
|
||||
}
|
||||
|
||||
|
|
@ -140,7 +171,7 @@ impl Database {
|
|||
}
|
||||
|
||||
#[tracing::instrument]
|
||||
pub async fn has_admin_users(&self) -> Result<bool> {
|
||||
async fn has_admin_users(&self) -> Result<bool> {
|
||||
let client = self.get_client().await?;
|
||||
Ok(client
|
||||
.query_one("SELECT EXISTS(SELECT 1 FROM admin_users);", &[])
|
||||
|
|
@ -149,7 +180,7 @@ impl Database {
|
|||
}
|
||||
|
||||
#[tracing::instrument]
|
||||
pub async fn create_user(
|
||||
async fn create_user(
|
||||
&self,
|
||||
real_name: &str,
|
||||
email: &str,
|
||||
|
|
@ -174,7 +205,7 @@ impl Database {
|
|||
}
|
||||
|
||||
#[tracing::instrument]
|
||||
pub async fn get_password_for_user(&self, user: &User) -> Result<PasswordHash> {
|
||||
async fn get_password_for_user(&self, user: &User) -> Result<PasswordHash> {
|
||||
let client = self.get_client().await?;
|
||||
let row = client
|
||||
.query_one("SELECT password FROM users WHERE id = $1;", &[&user.id.0])
|
||||
|
|
@ -183,7 +214,7 @@ impl Database {
|
|||
}
|
||||
|
||||
#[tracing::instrument]
|
||||
pub async fn create_first_admin_user(
|
||||
async fn create_first_admin_user(
|
||||
&self,
|
||||
real_name: &str,
|
||||
email: &str,
|
||||
|
|
@ -204,7 +235,7 @@ impl Database {
|
|||
}
|
||||
|
||||
#[tracing::instrument]
|
||||
pub async fn is_user_admin(&self, user: &User) -> Result<bool> {
|
||||
async fn is_user_admin(&self, user: &User) -> Result<bool> {
|
||||
Ok(self
|
||||
.get_client()
|
||||
.await?
|
||||
|
|
@ -214,7 +245,7 @@ impl Database {
|
|||
}
|
||||
|
||||
#[tracing::instrument]
|
||||
pub async fn get_user_with_id(&self, user_id: UserId) -> Result<Option<User>> {
|
||||
async fn get_user_with_id(&self, user_id: UserId) -> Result<Option<User>> {
|
||||
Ok(self
|
||||
.get_client()
|
||||
.await?
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ mod db;
|
|||
mod error;
|
||||
|
||||
use config::get_config;
|
||||
use db::Database;
|
||||
use db::{Database, PostgresDatabase};
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
pub enum Error {
|
||||
|
|
@ -49,13 +49,13 @@ async fn locality_main() -> Result<(), Error> {
|
|||
.finish();
|
||||
tracing::subscriber::set_global_default(subscriber)?;
|
||||
|
||||
let db_pool = Database::new(&config.database_url)?;
|
||||
let db_pool = PostgresDatabase::new(&config.database_url)?;
|
||||
|
||||
db_pool.migrate_to_current_version().await.unwrap();
|
||||
|
||||
let app = app::routes()
|
||||
.nest("/admin", admin::routes())
|
||||
.with_state(db_pool)
|
||||
.with_state(app::AppState { db: db_pool })
|
||||
.nest_service("/static", ServeDir::new(&config.static_file_path))
|
||||
.layer(TraceLayer::new_for_http());
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue