diff --git a/src/app.rs b/src/app.rs index 3e91995..648ce70 100644 --- a/src/app.rs +++ b/src/app.rs @@ -1,7 +1,7 @@ use { crate::db::Database, askama::Template, - axum::{extract::State, routing::get, Router}, + axum::{routing::get, Router}, }; pub fn routes() -> Router { @@ -14,6 +14,6 @@ struct IndexTemplate<'a> { title: &'a str, } -async fn root<'a>(State(database): State) -> IndexTemplate<'a> { +async fn root<'a>() -> IndexTemplate<'a> { IndexTemplate { title: "LocalHub" } } diff --git a/src/db.rs b/src/db.rs deleted file mode 100644 index f03e09e..0000000 --- a/src/db.rs +++ /dev/null @@ -1,56 +0,0 @@ -use { - deadpool_postgres::{CreatePoolError, Pool, Runtime}, - thiserror::Error, - tokio_postgres::NoTls, -}; - -#[derive(Error, Debug)] -pub enum InitialisationError { - #[error("Could not initialize DB connection pool.")] - ConnectionPoolError(#[from] CreatePoolError), -} - -#[derive(Error, Debug)] -pub enum Error { - #[error("")] - PoolError(#[from] deadpool_postgres::PoolError), - #[error("")] - DbError(#[from] tokio_postgres::Error) -} - -#[derive(Clone)] -pub struct Database { - connection_pool: Pool, -} - -impl Database { - pub fn create_pool(connection_url: &str) -> Result { - let mut config = deadpool_postgres::Config::new(); - config.url = Some(connection_url.to_string()); - let pg_pool = config.create_pool(Some(Runtime::Tokio1), NoTls)?; - Ok(Database { connection_pool: pg_pool }) - } - - pub async fn get_db_version(&self) -> Result { - let client = self.connection_pool.get().await?; - client.execute(r#" - DO $$BEGIN - IF NOT EXISTS - ( SELECT 1 - FROM information_schema.tables - WHERE table_schema = 'public' - AND table_name = 'migration_info' - ) - THEN - CREATE TABLE migration_info ( - onerow_id bool PRIMARY KEY DEFAULT true, - migration_version INTEGER, - CONSTRAINT onerow_unique CHECK (onerow_id)); - INSERT - INTO migration_info (migration_version) - VALUES (-1); - END IF; - END$$;"#, &[]).await?; - Ok(client.query_one(r#"SELECT migration_version FROM migration_info;"#, &[]).await?.get(0)) - } -} diff --git a/src/db/migrations.rs b/src/db/migrations.rs new file mode 100644 index 0000000..2425bfb --- /dev/null +++ b/src/db/migrations.rs @@ -0,0 +1,226 @@ +//! Database schema migrations and the code to manage them. +//! +//! - [MIGRATIONS] stores a list of schema migrations defined in SQL. +//! +//! - [CURRENT_VERSION] is the schema version used by the current app +//! version. +//! +//! - [get_db_version()] returns the schema version of the database +//! itself. + +//! - [migrate_to_current_version()] Applies whatever migrations are +//! needed to that the database schema version (as returned by +//! [get_db_version()]) matches [CURRENT_VERSION]. + +use super::{Database, Error}; + +/// Defines a database schema migration +#[derive(Debug)] +struct Migration { + /// The schema version that `up` will migrate to. + version: i32, + /// SQL to migrate to [version](Migration::version) from the + /// previous schema version. + /// + /// May contain multiple SQL statements: they will be wrapped in a + /// `DO $$BEGIN`...`END$$;` block when executed. + up: &'static str, + /// SQL to migrate from [version](Migration::version) to the + /// previous schema version. + /// + /// May contain multiple SQL statements: they will be wrapped in a + /// `DO $$BEGIN`...`END$$;` block when executed. + down: &'static str, +} + +/// A list of all the schema migrations. +/// +/// The [version](Migration::version) field should always correspond to the +/// [Migration]'s position in the list, starting at 0 and working up +/// sequentially. +/// +/// New versions should normally only ever add to the end of this +/// list, never change or add things in the middle or at the +/// beginning. +static MIGRATIONS: &[Migration] = &[Migration { + version: 0, + up: r#" + CREATE TABLE users ( + id SERIAL PRIMARY KEY, + real_name TEXT NOT NULL, + email TEXT UNIQUE NOT NULL, + -- 43 characters is enought to base64-encode 32 bits + password_salt CHARACTER(43) NOT NULL, + password_hash CHARACTER(43) NOT NULL);"#, + down: r#"DROP TABLE users;"#, +}]; + +/// The current schema version. Normally this will be the +/// [version](Migration::version) of the list item in [MIGRATIONS]. +/// +/// This is the the current version *as specified by the +/// application*. It may not be the actual schema of the database. If +/// it is not the current database schema, then running +/// [migrate_to_current_version()] will migrate the database to this +/// version. +static CURRENT_VERSION: i32 = 0; + +/// If the database is not already using the most recent schema, apply +/// one migration to bring it to the next newest version. +/// +/// E.g. If [CURRENT_VERSION] is 10 but the database is on version 4, +/// then running this function will apply [Migration] 5 to bring the +/// database up to schema version 5. +async fn migrate_up(db: &Database) -> Result { + let current_version = dbg!(get_db_version(db).await?); + if let Some(migration) = if current_version < 1 { + MIGRATIONS.first() + } else { + MIGRATIONS.iter().find(|m| m.version == current_version) + } { + let client = db.connection_pool.get().await?; + client + .execute(&format!("DO $$BEGIN\n{}\nEND$$;", migration.up), &[]) + .await?; + client + .execute( + "UPDATE migration_info SET migration_version = $1;", + &[&migration.version], + ) + .await?; + Ok(dbg!(migration.version)) + } else { + eprintln!("ERROR: Attempted to migrate up past last migration."); + Ok(current_version) + } +} + +/// Revert back to the previous schema version by running the +/// [down](Migration::down) SQL of the current `Migration` +async fn migrate_down(db: &Database) -> Result { + let current_version = get_db_version(db).await?; + let mut migration_iter = MIGRATIONS + .iter() + .rev() + .skip_while(|m| m.version != current_version); + if let Some(migration) = migration_iter.next() { + let client = db.connection_pool.get().await?; + client + .execute(&format!("DO $$BEGIN\n{}\nEND$$;", migration.down), &[]) + .await?; + let version = migration_iter.next().map_or(-1, |m| m.version); + client + .execute( + "UPDATE migration_info SET migration_version = $1;", + &[&version], + ) + .await?; + Ok(version) + } else { + eprintln!("ERROR: Attempted to migrate down past first migration."); + Ok(current_version) + } +} + +/// Apply whatever migrations are necessary to bring the database +/// schema to the same version is [CURRENT_VERSION]. +pub async fn migrate_to_current_version(db: &Database) -> Result<(), Error> { + migrate_to_version(db, CURRENT_VERSION).await +} + +/// Apply whatever migrations are necessary to bring the database +/// schema to the same version as `target_version`. +/// +/// This may migrate up or down as required. +async fn migrate_to_version(db: &Database, target_version: i32) -> Result<(), Error> { + let mut version = get_db_version(db).await?; + while version != target_version { + if version < target_version { + version = migrate_up(db).await?; + } else { + version = migrate_down(db).await?; + } + } + Ok(()) +} + +/// Get the current schema version of the database. +pub async fn get_db_version(db: &Database) -> Result { + let client = db.connection_pool.get().await?; + client + .execute( + r#" + DO $$BEGIN + IF NOT EXISTS + ( SELECT 1 + FROM information_schema.tables + WHERE table_schema = 'public' + AND table_name = 'migration_info' + ) + THEN + CREATE TABLE migration_info ( + onerow_id BOOL PRIMARY KEY DEFAULT true, + migration_version INTEGER, + CONSTRAINT onerow_unique CHECK (onerow_id)); + INSERT + INTO migration_info (migration_version) + VALUES (-1); + END IF; + END$$;"#, + &[], + ) + .await?; + Ok(client + .query_one(r#"SELECT migration_version FROM migration_info;"#, &[]) + .await? + .get(0)) +} + +#[cfg(test)] +mod tests { + use super::super::Database; + use super::*; + + async fn test_db() -> Database { + let url = std::env::var("LOCALHUB_TEST_DATABASE_URL").unwrap(); + Database::new(&url) + .unwrap() + .connection_pool + .get() + .await + .unwrap() + .execute( + r#" + DO $$BEGIN + DROP TABlE IF EXISTS users; + DROP TABLE IF EXISTS migration_info; + END$$;"#, + &[], + ) + .await + .unwrap(); + Database::new(&url).unwrap() + } + + #[test] + fn migrations_have_sequential_versions() { + for i in 0..MIGRATIONS.len() { + assert_eq!(i as i32, MIGRATIONS[i].version); + } + } + + #[tokio::test] + async fn migrate_up_and_down_all() { + let db = test_db().await; + assert_eq!(-1, get_db_version(&db).await.unwrap()); + migrate_to_version(&db, MIGRATIONS.last().unwrap().version) + .await + .unwrap(); + assert_eq!( + MIGRATIONS.last().unwrap().version, + get_db_version(&db).await.unwrap() + ); + migrate_to_version(&db, -1).await.unwrap(); + assert_eq!(-1, get_db_version(&db).await.unwrap()); + } +} diff --git a/src/db/mod.rs b/src/db/mod.rs new file mode 100644 index 0000000..a2c056f --- /dev/null +++ b/src/db/mod.rs @@ -0,0 +1,54 @@ +//! Code for managing the main database. +//! +//! The database schema is defined via a series of migrations in +//! [migrations]. + +mod migrations; + +use { + deadpool_postgres::{CreatePoolError, Pool, Runtime}, + thiserror::Error, + tokio_postgres::NoTls, +}; + +/// Errors that may occur during module initialization +#[derive(Error, Debug)] +pub enum InitialisationError { + #[error("Could not initialize DB connection pool: {}", .0.to_string())] + ConnectionPoolError(#[from] CreatePoolError), +} + +/// Errors that may occur during normal app operation +#[derive(Error, Debug)] +pub enum Error { + #[error("{}", .0.to_string())] + PoolError(#[from] deadpool_postgres::PoolError), + #[error("{}", .0.to_string())] + DbError(#[from] tokio_postgres::Error), +} + +/// Object that manages the database. +/// +/// All database access happens through this struct. +#[derive(Clone, Debug)] +pub struct Database { + connection_pool: Pool, +} + +impl Database { + /// Create a connection pool and return the [Database]. + pub fn new(connection_url: &str) -> Result { + let mut config = deadpool_postgres::Config::new(); + config.url = Some(connection_url.to_string()); + let pg_pool = config.create_pool(Some(Runtime::Tokio1), NoTls)?; + Ok(Database { + connection_pool: pg_pool, + }) + } + + /// Run migrations as needed to ensure the database schema version + /// match the one used by the current version of the application. + pub async fn migrate_to_current_version(&self) -> Result<(), Error> { + migrations::migrate_to_current_version(self).await + } +} diff --git a/src/main.rs b/src/main.rs index 55145c8..0a326e7 100644 --- a/src/main.rs +++ b/src/main.rs @@ -33,9 +33,9 @@ fn main() { async fn localhub_main() -> Result<(), Error> { let config = get_config()?; - let db_pool = Database::create_pool(&config.database_url)?; + let db_pool = Database::new(&config.database_url)?; - dbg!(db_pool.get_db_version().await.unwrap()); + db_pool.migrate_to_current_version().await.unwrap(); let app = app::routes() .with_state(db_pool)