Add migrations and other foundational database stuff
This commit is contained in:
parent
8663a271dc
commit
5ad4f80fc1
|
|
@ -1,7 +1,7 @@
|
|||
use {
|
||||
crate::db::Database,
|
||||
askama::Template,
|
||||
axum::{extract::State, routing::get, Router},
|
||||
axum::{routing::get, Router},
|
||||
};
|
||||
|
||||
pub fn routes() -> Router<Database> {
|
||||
|
|
@ -14,6 +14,6 @@ struct IndexTemplate<'a> {
|
|||
title: &'a str,
|
||||
}
|
||||
|
||||
async fn root<'a>(State(database): State<Database>) -> IndexTemplate<'a> {
|
||||
async fn root<'a>() -> IndexTemplate<'a> {
|
||||
IndexTemplate { title: "LocalHub" }
|
||||
}
|
||||
|
|
|
|||
56
src/db.rs
56
src/db.rs
|
|
@ -1,56 +0,0 @@
|
|||
use {
|
||||
deadpool_postgres::{CreatePoolError, Pool, Runtime},
|
||||
thiserror::Error,
|
||||
tokio_postgres::NoTls,
|
||||
};
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
pub enum InitialisationError {
|
||||
#[error("Could not initialize DB connection pool.")]
|
||||
ConnectionPoolError(#[from] CreatePoolError),
|
||||
}
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
pub enum Error {
|
||||
#[error("")]
|
||||
PoolError(#[from] deadpool_postgres::PoolError),
|
||||
#[error("")]
|
||||
DbError(#[from] tokio_postgres::Error)
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Database {
|
||||
connection_pool: Pool,
|
||||
}
|
||||
|
||||
impl Database {
|
||||
pub fn create_pool(connection_url: &str) -> Result<Database, InitialisationError> {
|
||||
let mut config = deadpool_postgres::Config::new();
|
||||
config.url = Some(connection_url.to_string());
|
||||
let pg_pool = config.create_pool(Some(Runtime::Tokio1), NoTls)?;
|
||||
Ok(Database { connection_pool: pg_pool })
|
||||
}
|
||||
|
||||
pub async fn get_db_version(&self) -> Result<i32, Error> {
|
||||
let client = self.connection_pool.get().await?;
|
||||
client.execute(r#"
|
||||
DO $$BEGIN
|
||||
IF NOT EXISTS
|
||||
( SELECT 1
|
||||
FROM information_schema.tables
|
||||
WHERE table_schema = 'public'
|
||||
AND table_name = 'migration_info'
|
||||
)
|
||||
THEN
|
||||
CREATE TABLE migration_info (
|
||||
onerow_id bool PRIMARY KEY DEFAULT true,
|
||||
migration_version INTEGER,
|
||||
CONSTRAINT onerow_unique CHECK (onerow_id));
|
||||
INSERT
|
||||
INTO migration_info (migration_version)
|
||||
VALUES (-1);
|
||||
END IF;
|
||||
END$$;"#, &[]).await?;
|
||||
Ok(client.query_one(r#"SELECT migration_version FROM migration_info;"#, &[]).await?.get(0))
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,226 @@
|
|||
//! Database schema migrations and the code to manage them.
|
||||
//!
|
||||
//! - [MIGRATIONS] stores a list of schema migrations defined in SQL.
|
||||
//!
|
||||
//! - [CURRENT_VERSION] is the schema version used by the current app
|
||||
//! version.
|
||||
//!
|
||||
//! - [get_db_version()] returns the schema version of the database
|
||||
//! itself.
|
||||
|
||||
//! - [migrate_to_current_version()] Applies whatever migrations are
|
||||
//! needed to that the database schema version (as returned by
|
||||
//! [get_db_version()]) matches [CURRENT_VERSION].
|
||||
|
||||
use super::{Database, Error};
|
||||
|
||||
/// Defines a database schema migration
|
||||
#[derive(Debug)]
|
||||
struct Migration {
|
||||
/// The schema version that `up` will migrate to.
|
||||
version: i32,
|
||||
/// SQL to migrate to [version](Migration::version) from the
|
||||
/// previous schema version.
|
||||
///
|
||||
/// May contain multiple SQL statements: they will be wrapped in a
|
||||
/// `DO $$BEGIN`...`END$$;` block when executed.
|
||||
up: &'static str,
|
||||
/// SQL to migrate from [version](Migration::version) to the
|
||||
/// previous schema version.
|
||||
///
|
||||
/// May contain multiple SQL statements: they will be wrapped in a
|
||||
/// `DO $$BEGIN`...`END$$;` block when executed.
|
||||
down: &'static str,
|
||||
}
|
||||
|
||||
/// A list of all the schema migrations.
|
||||
///
|
||||
/// The [version](Migration::version) field should always correspond to the
|
||||
/// [Migration]'s position in the list, starting at 0 and working up
|
||||
/// sequentially.
|
||||
///
|
||||
/// New versions should normally only ever add to the end of this
|
||||
/// list, never change or add things in the middle or at the
|
||||
/// beginning.
|
||||
static MIGRATIONS: &[Migration] = &[Migration {
|
||||
version: 0,
|
||||
up: r#"
|
||||
CREATE TABLE users (
|
||||
id SERIAL PRIMARY KEY,
|
||||
real_name TEXT NOT NULL,
|
||||
email TEXT UNIQUE NOT NULL,
|
||||
-- 43 characters is enought to base64-encode 32 bits
|
||||
password_salt CHARACTER(43) NOT NULL,
|
||||
password_hash CHARACTER(43) NOT NULL);"#,
|
||||
down: r#"DROP TABLE users;"#,
|
||||
}];
|
||||
|
||||
/// The current schema version. Normally this will be the
|
||||
/// [version](Migration::version) of the list item in [MIGRATIONS].
|
||||
///
|
||||
/// This is the the current version *as specified by the
|
||||
/// application*. It may not be the actual schema of the database. If
|
||||
/// it is not the current database schema, then running
|
||||
/// [migrate_to_current_version()] will migrate the database to this
|
||||
/// version.
|
||||
static CURRENT_VERSION: i32 = 0;
|
||||
|
||||
/// If the database is not already using the most recent schema, apply
|
||||
/// one migration to bring it to the next newest version.
|
||||
///
|
||||
/// E.g. If [CURRENT_VERSION] is 10 but the database is on version 4,
|
||||
/// then running this function will apply [Migration] 5 to bring the
|
||||
/// database up to schema version 5.
|
||||
async fn migrate_up(db: &Database) -> Result<i32, Error> {
|
||||
let current_version = dbg!(get_db_version(db).await?);
|
||||
if let Some(migration) = if current_version < 1 {
|
||||
MIGRATIONS.first()
|
||||
} else {
|
||||
MIGRATIONS.iter().find(|m| m.version == current_version)
|
||||
} {
|
||||
let client = db.connection_pool.get().await?;
|
||||
client
|
||||
.execute(&format!("DO $$BEGIN\n{}\nEND$$;", migration.up), &[])
|
||||
.await?;
|
||||
client
|
||||
.execute(
|
||||
"UPDATE migration_info SET migration_version = $1;",
|
||||
&[&migration.version],
|
||||
)
|
||||
.await?;
|
||||
Ok(dbg!(migration.version))
|
||||
} else {
|
||||
eprintln!("ERROR: Attempted to migrate up past last migration.");
|
||||
Ok(current_version)
|
||||
}
|
||||
}
|
||||
|
||||
/// Revert back to the previous schema version by running the
|
||||
/// [down](Migration::down) SQL of the current `Migration`
|
||||
async fn migrate_down(db: &Database) -> Result<i32, Error> {
|
||||
let current_version = get_db_version(db).await?;
|
||||
let mut migration_iter = MIGRATIONS
|
||||
.iter()
|
||||
.rev()
|
||||
.skip_while(|m| m.version != current_version);
|
||||
if let Some(migration) = migration_iter.next() {
|
||||
let client = db.connection_pool.get().await?;
|
||||
client
|
||||
.execute(&format!("DO $$BEGIN\n{}\nEND$$;", migration.down), &[])
|
||||
.await?;
|
||||
let version = migration_iter.next().map_or(-1, |m| m.version);
|
||||
client
|
||||
.execute(
|
||||
"UPDATE migration_info SET migration_version = $1;",
|
||||
&[&version],
|
||||
)
|
||||
.await?;
|
||||
Ok(version)
|
||||
} else {
|
||||
eprintln!("ERROR: Attempted to migrate down past first migration.");
|
||||
Ok(current_version)
|
||||
}
|
||||
}
|
||||
|
||||
/// Apply whatever migrations are necessary to bring the database
|
||||
/// schema to the same version is [CURRENT_VERSION].
|
||||
pub async fn migrate_to_current_version(db: &Database) -> Result<(), Error> {
|
||||
migrate_to_version(db, CURRENT_VERSION).await
|
||||
}
|
||||
|
||||
/// Apply whatever migrations are necessary to bring the database
|
||||
/// schema to the same version as `target_version`.
|
||||
///
|
||||
/// This may migrate up or down as required.
|
||||
async fn migrate_to_version(db: &Database, target_version: i32) -> Result<(), Error> {
|
||||
let mut version = get_db_version(db).await?;
|
||||
while version != target_version {
|
||||
if version < target_version {
|
||||
version = migrate_up(db).await?;
|
||||
} else {
|
||||
version = migrate_down(db).await?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get the current schema version of the database.
|
||||
pub async fn get_db_version(db: &Database) -> Result<i32, Error> {
|
||||
let client = db.connection_pool.get().await?;
|
||||
client
|
||||
.execute(
|
||||
r#"
|
||||
DO $$BEGIN
|
||||
IF NOT EXISTS
|
||||
( SELECT 1
|
||||
FROM information_schema.tables
|
||||
WHERE table_schema = 'public'
|
||||
AND table_name = 'migration_info'
|
||||
)
|
||||
THEN
|
||||
CREATE TABLE migration_info (
|
||||
onerow_id BOOL PRIMARY KEY DEFAULT true,
|
||||
migration_version INTEGER,
|
||||
CONSTRAINT onerow_unique CHECK (onerow_id));
|
||||
INSERT
|
||||
INTO migration_info (migration_version)
|
||||
VALUES (-1);
|
||||
END IF;
|
||||
END$$;"#,
|
||||
&[],
|
||||
)
|
||||
.await?;
|
||||
Ok(client
|
||||
.query_one(r#"SELECT migration_version FROM migration_info;"#, &[])
|
||||
.await?
|
||||
.get(0))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::super::Database;
|
||||
use super::*;
|
||||
|
||||
async fn test_db() -> Database {
|
||||
let url = std::env::var("LOCALHUB_TEST_DATABASE_URL").unwrap();
|
||||
Database::new(&url)
|
||||
.unwrap()
|
||||
.connection_pool
|
||||
.get()
|
||||
.await
|
||||
.unwrap()
|
||||
.execute(
|
||||
r#"
|
||||
DO $$BEGIN
|
||||
DROP TABlE IF EXISTS users;
|
||||
DROP TABLE IF EXISTS migration_info;
|
||||
END$$;"#,
|
||||
&[],
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
Database::new(&url).unwrap()
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn migrations_have_sequential_versions() {
|
||||
for i in 0..MIGRATIONS.len() {
|
||||
assert_eq!(i as i32, MIGRATIONS[i].version);
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn migrate_up_and_down_all() {
|
||||
let db = test_db().await;
|
||||
assert_eq!(-1, get_db_version(&db).await.unwrap());
|
||||
migrate_to_version(&db, MIGRATIONS.last().unwrap().version)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(
|
||||
MIGRATIONS.last().unwrap().version,
|
||||
get_db_version(&db).await.unwrap()
|
||||
);
|
||||
migrate_to_version(&db, -1).await.unwrap();
|
||||
assert_eq!(-1, get_db_version(&db).await.unwrap());
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,54 @@
|
|||
//! Code for managing the main database.
|
||||
//!
|
||||
//! The database schema is defined via a series of migrations in
|
||||
//! [migrations].
|
||||
|
||||
mod migrations;
|
||||
|
||||
use {
|
||||
deadpool_postgres::{CreatePoolError, Pool, Runtime},
|
||||
thiserror::Error,
|
||||
tokio_postgres::NoTls,
|
||||
};
|
||||
|
||||
/// Errors that may occur during module initialization
|
||||
#[derive(Error, Debug)]
|
||||
pub enum InitialisationError {
|
||||
#[error("Could not initialize DB connection pool: {}", .0.to_string())]
|
||||
ConnectionPoolError(#[from] CreatePoolError),
|
||||
}
|
||||
|
||||
/// Errors that may occur during normal app operation
|
||||
#[derive(Error, Debug)]
|
||||
pub enum Error {
|
||||
#[error("{}", .0.to_string())]
|
||||
PoolError(#[from] deadpool_postgres::PoolError),
|
||||
#[error("{}", .0.to_string())]
|
||||
DbError(#[from] tokio_postgres::Error),
|
||||
}
|
||||
|
||||
/// Object that manages the database.
|
||||
///
|
||||
/// All database access happens through this struct.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct Database {
|
||||
connection_pool: Pool,
|
||||
}
|
||||
|
||||
impl Database {
|
||||
/// Create a connection pool and return the [Database].
|
||||
pub fn new(connection_url: &str) -> Result<Database, InitialisationError> {
|
||||
let mut config = deadpool_postgres::Config::new();
|
||||
config.url = Some(connection_url.to_string());
|
||||
let pg_pool = config.create_pool(Some(Runtime::Tokio1), NoTls)?;
|
||||
Ok(Database {
|
||||
connection_pool: pg_pool,
|
||||
})
|
||||
}
|
||||
|
||||
/// Run migrations as needed to ensure the database schema version
|
||||
/// match the one used by the current version of the application.
|
||||
pub async fn migrate_to_current_version(&self) -> Result<(), Error> {
|
||||
migrations::migrate_to_current_version(self).await
|
||||
}
|
||||
}
|
||||
|
|
@ -33,9 +33,9 @@ fn main() {
|
|||
async fn localhub_main() -> Result<(), Error> {
|
||||
let config = get_config()?;
|
||||
|
||||
let db_pool = Database::create_pool(&config.database_url)?;
|
||||
let db_pool = Database::new(&config.database_url)?;
|
||||
|
||||
dbg!(db_pool.get_db_version().await.unwrap());
|
||||
db_pool.migrate_to_current_version().await.unwrap();
|
||||
|
||||
let app = app::routes()
|
||||
.with_state(db_pool)
|
||||
|
|
|
|||
Loading…
Reference in New Issue