Add migrations and other foundational database stuff

This commit is contained in:
Matthew Gordon 2024-02-19 21:55:15 -04:00
parent 8663a271dc
commit 5ad4f80fc1
5 changed files with 284 additions and 60 deletions

View File

@ -1,7 +1,7 @@
use { use {
crate::db::Database, crate::db::Database,
askama::Template, askama::Template,
axum::{extract::State, routing::get, Router}, axum::{routing::get, Router},
}; };
pub fn routes() -> Router<Database> { pub fn routes() -> Router<Database> {
@ -14,6 +14,6 @@ struct IndexTemplate<'a> {
title: &'a str, title: &'a str,
} }
async fn root<'a>(State(database): State<Database>) -> IndexTemplate<'a> { async fn root<'a>() -> IndexTemplate<'a> {
IndexTemplate { title: "LocalHub" } IndexTemplate { title: "LocalHub" }
} }

View File

@ -1,56 +0,0 @@
use {
deadpool_postgres::{CreatePoolError, Pool, Runtime},
thiserror::Error,
tokio_postgres::NoTls,
};
#[derive(Error, Debug)]
pub enum InitialisationError {
#[error("Could not initialize DB connection pool.")]
ConnectionPoolError(#[from] CreatePoolError),
}
#[derive(Error, Debug)]
pub enum Error {
#[error("")]
PoolError(#[from] deadpool_postgres::PoolError),
#[error("")]
DbError(#[from] tokio_postgres::Error)
}
#[derive(Clone)]
pub struct Database {
connection_pool: Pool,
}
impl Database {
pub fn create_pool(connection_url: &str) -> Result<Database, InitialisationError> {
let mut config = deadpool_postgres::Config::new();
config.url = Some(connection_url.to_string());
let pg_pool = config.create_pool(Some(Runtime::Tokio1), NoTls)?;
Ok(Database { connection_pool: pg_pool })
}
pub async fn get_db_version(&self) -> Result<i32, Error> {
let client = self.connection_pool.get().await?;
client.execute(r#"
DO $$BEGIN
IF NOT EXISTS
( SELECT 1
FROM information_schema.tables
WHERE table_schema = 'public'
AND table_name = 'migration_info'
)
THEN
CREATE TABLE migration_info (
onerow_id bool PRIMARY KEY DEFAULT true,
migration_version INTEGER,
CONSTRAINT onerow_unique CHECK (onerow_id));
INSERT
INTO migration_info (migration_version)
VALUES (-1);
END IF;
END$$;"#, &[]).await?;
Ok(client.query_one(r#"SELECT migration_version FROM migration_info;"#, &[]).await?.get(0))
}
}

226
src/db/migrations.rs Normal file
View File

@ -0,0 +1,226 @@
//! Database schema migrations and the code to manage them.
//!
//! - [MIGRATIONS] stores a list of schema migrations defined in SQL.
//!
//! - [CURRENT_VERSION] is the schema version used by the current app
//! version.
//!
//! - [get_db_version()] returns the schema version of the database
//! itself.
//! - [migrate_to_current_version()] Applies whatever migrations are
//! needed to that the database schema version (as returned by
//! [get_db_version()]) matches [CURRENT_VERSION].
use super::{Database, Error};
/// Defines a database schema migration
#[derive(Debug)]
struct Migration {
/// The schema version that `up` will migrate to.
version: i32,
/// SQL to migrate to [version](Migration::version) from the
/// previous schema version.
///
/// May contain multiple SQL statements: they will be wrapped in a
/// `DO $$BEGIN`...`END$$;` block when executed.
up: &'static str,
/// SQL to migrate from [version](Migration::version) to the
/// previous schema version.
///
/// May contain multiple SQL statements: they will be wrapped in a
/// `DO $$BEGIN`...`END$$;` block when executed.
down: &'static str,
}
/// A list of all the schema migrations.
///
/// The [version](Migration::version) field should always correspond to the
/// [Migration]'s position in the list, starting at 0 and working up
/// sequentially.
///
/// New versions should normally only ever add to the end of this
/// list, never change or add things in the middle or at the
/// beginning.
static MIGRATIONS: &[Migration] = &[Migration {
version: 0,
up: r#"
CREATE TABLE users (
id SERIAL PRIMARY KEY,
real_name TEXT NOT NULL,
email TEXT UNIQUE NOT NULL,
-- 43 characters is enought to base64-encode 32 bits
password_salt CHARACTER(43) NOT NULL,
password_hash CHARACTER(43) NOT NULL);"#,
down: r#"DROP TABLE users;"#,
}];
/// The current schema version. Normally this will be the
/// [version](Migration::version) of the list item in [MIGRATIONS].
///
/// This is the the current version *as specified by the
/// application*. It may not be the actual schema of the database. If
/// it is not the current database schema, then running
/// [migrate_to_current_version()] will migrate the database to this
/// version.
static CURRENT_VERSION: i32 = 0;
/// If the database is not already using the most recent schema, apply
/// one migration to bring it to the next newest version.
///
/// E.g. If [CURRENT_VERSION] is 10 but the database is on version 4,
/// then running this function will apply [Migration] 5 to bring the
/// database up to schema version 5.
async fn migrate_up(db: &Database) -> Result<i32, Error> {
let current_version = dbg!(get_db_version(db).await?);
if let Some(migration) = if current_version < 1 {
MIGRATIONS.first()
} else {
MIGRATIONS.iter().find(|m| m.version == current_version)
} {
let client = db.connection_pool.get().await?;
client
.execute(&format!("DO $$BEGIN\n{}\nEND$$;", migration.up), &[])
.await?;
client
.execute(
"UPDATE migration_info SET migration_version = $1;",
&[&migration.version],
)
.await?;
Ok(dbg!(migration.version))
} else {
eprintln!("ERROR: Attempted to migrate up past last migration.");
Ok(current_version)
}
}
/// Revert back to the previous schema version by running the
/// [down](Migration::down) SQL of the current `Migration`
async fn migrate_down(db: &Database) -> Result<i32, Error> {
let current_version = get_db_version(db).await?;
let mut migration_iter = MIGRATIONS
.iter()
.rev()
.skip_while(|m| m.version != current_version);
if let Some(migration) = migration_iter.next() {
let client = db.connection_pool.get().await?;
client
.execute(&format!("DO $$BEGIN\n{}\nEND$$;", migration.down), &[])
.await?;
let version = migration_iter.next().map_or(-1, |m| m.version);
client
.execute(
"UPDATE migration_info SET migration_version = $1;",
&[&version],
)
.await?;
Ok(version)
} else {
eprintln!("ERROR: Attempted to migrate down past first migration.");
Ok(current_version)
}
}
/// Apply whatever migrations are necessary to bring the database
/// schema to the same version is [CURRENT_VERSION].
pub async fn migrate_to_current_version(db: &Database) -> Result<(), Error> {
migrate_to_version(db, CURRENT_VERSION).await
}
/// Apply whatever migrations are necessary to bring the database
/// schema to the same version as `target_version`.
///
/// This may migrate up or down as required.
async fn migrate_to_version(db: &Database, target_version: i32) -> Result<(), Error> {
let mut version = get_db_version(db).await?;
while version != target_version {
if version < target_version {
version = migrate_up(db).await?;
} else {
version = migrate_down(db).await?;
}
}
Ok(())
}
/// Get the current schema version of the database.
pub async fn get_db_version(db: &Database) -> Result<i32, Error> {
let client = db.connection_pool.get().await?;
client
.execute(
r#"
DO $$BEGIN
IF NOT EXISTS
( SELECT 1
FROM information_schema.tables
WHERE table_schema = 'public'
AND table_name = 'migration_info'
)
THEN
CREATE TABLE migration_info (
onerow_id BOOL PRIMARY KEY DEFAULT true,
migration_version INTEGER,
CONSTRAINT onerow_unique CHECK (onerow_id));
INSERT
INTO migration_info (migration_version)
VALUES (-1);
END IF;
END$$;"#,
&[],
)
.await?;
Ok(client
.query_one(r#"SELECT migration_version FROM migration_info;"#, &[])
.await?
.get(0))
}
#[cfg(test)]
mod tests {
use super::super::Database;
use super::*;
async fn test_db() -> Database {
let url = std::env::var("LOCALHUB_TEST_DATABASE_URL").unwrap();
Database::new(&url)
.unwrap()
.connection_pool
.get()
.await
.unwrap()
.execute(
r#"
DO $$BEGIN
DROP TABlE IF EXISTS users;
DROP TABLE IF EXISTS migration_info;
END$$;"#,
&[],
)
.await
.unwrap();
Database::new(&url).unwrap()
}
#[test]
fn migrations_have_sequential_versions() {
for i in 0..MIGRATIONS.len() {
assert_eq!(i as i32, MIGRATIONS[i].version);
}
}
#[tokio::test]
async fn migrate_up_and_down_all() {
let db = test_db().await;
assert_eq!(-1, get_db_version(&db).await.unwrap());
migrate_to_version(&db, MIGRATIONS.last().unwrap().version)
.await
.unwrap();
assert_eq!(
MIGRATIONS.last().unwrap().version,
get_db_version(&db).await.unwrap()
);
migrate_to_version(&db, -1).await.unwrap();
assert_eq!(-1, get_db_version(&db).await.unwrap());
}
}

54
src/db/mod.rs Normal file
View File

@ -0,0 +1,54 @@
//! Code for managing the main database.
//!
//! The database schema is defined via a series of migrations in
//! [migrations].
mod migrations;
use {
deadpool_postgres::{CreatePoolError, Pool, Runtime},
thiserror::Error,
tokio_postgres::NoTls,
};
/// Errors that may occur during module initialization
#[derive(Error, Debug)]
pub enum InitialisationError {
#[error("Could not initialize DB connection pool: {}", .0.to_string())]
ConnectionPoolError(#[from] CreatePoolError),
}
/// Errors that may occur during normal app operation
#[derive(Error, Debug)]
pub enum Error {
#[error("{}", .0.to_string())]
PoolError(#[from] deadpool_postgres::PoolError),
#[error("{}", .0.to_string())]
DbError(#[from] tokio_postgres::Error),
}
/// Object that manages the database.
///
/// All database access happens through this struct.
#[derive(Clone, Debug)]
pub struct Database {
connection_pool: Pool,
}
impl Database {
/// Create a connection pool and return the [Database].
pub fn new(connection_url: &str) -> Result<Database, InitialisationError> {
let mut config = deadpool_postgres::Config::new();
config.url = Some(connection_url.to_string());
let pg_pool = config.create_pool(Some(Runtime::Tokio1), NoTls)?;
Ok(Database {
connection_pool: pg_pool,
})
}
/// Run migrations as needed to ensure the database schema version
/// match the one used by the current version of the application.
pub async fn migrate_to_current_version(&self) -> Result<(), Error> {
migrations::migrate_to_current_version(self).await
}
}

View File

@ -33,9 +33,9 @@ fn main() {
async fn localhub_main() -> Result<(), Error> { async fn localhub_main() -> Result<(), Error> {
let config = get_config()?; let config = get_config()?;
let db_pool = Database::create_pool(&config.database_url)?; let db_pool = Database::new(&config.database_url)?;
dbg!(db_pool.get_db_version().await.unwrap()); db_pool.migrate_to_current_version().await.unwrap();
let app = app::routes() let app = app::routes()
.with_state(db_pool) .with_state(db_pool)