//! Code for managing the main database. //! //! The database schema is defined via a series of migrations in //! [migrations]. mod migrations; #[cfg(test)] pub mod fake; use { deadpool_postgres::{CreatePoolError, Pool, Runtime}, serde::{Deserialize, Serialize}, std::{future::Future, ops::Deref}, tokio_postgres::NoTls, tracing::error, }; /// Errors that may occur during module initialization #[derive(Debug)] pub enum InitialisationError { ConnectionPoolError(CreatePoolError), } impl std::fmt::Display for InitialisationError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { InitialisationError::ConnectionPoolError(e) => { write!(f, "Could not initialise DB connection pool: {}", e) } } } } impl std::error::Error for InitialisationError {} impl From for InitialisationError { fn from(value: CreatePoolError) -> Self { InitialisationError::ConnectionPoolError(value) } } pub type InitialisationResult = std::result::Result; /// Errors that may occur during normal app operation #[derive(Debug)] pub enum Error { Pool(deadpool_postgres::PoolError), Database, NotAllowed, } impl std::fmt::Display for Error { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { Error::Pool(e) => e.fmt(f), Error::Database => write!(f, "Database Error"), Error::NotAllowed => write!(f, "Not Allowed"), } } } impl std::error::Error for Error {} impl From for Error { fn from(value: deadpool_postgres::PoolError) -> Self { error!( details = value.to_string(), "Error with deadpool_postgress connection pool" ); Self::Pool(value) } } impl From for Error { fn from(value: tokio_postgres::Error) -> Self { error!( details = value .as_db_error() .and_then(|db_error| db_error.detail()) .unwrap_or(&value.to_string()), "PostgreSQL error" ); Error::Database } } pub type Result = std::result::Result; pub trait Database: std::fmt::Debug + Clone + Send + Sync + 'static { /// Run migrations as needed to ensure the database schema version /// match the one used by the current version of the application. fn migrate_to_current_version(&self) -> impl Future> + Send; fn has_admin_users(&self) -> impl Future> + Send; fn create_user( &self, real_name: &str, email: &str, password: &PasswordHash, ) -> impl Future> + Send; fn get_password_for_user( &self, user: &User, ) -> impl Future> + Send; fn create_first_admin_user( &self, real_name: &str, email: &str, password: &PasswordHash, ) -> impl Future> + Send; fn is_user_admin(&self, user: &User) -> impl Future> + Send; fn get_user_with_id( &self, user_id: UserId, ) -> impl Future>> + Send; } /// Object that manages the database. /// /// All database access happens through this struct. #[derive(Clone, Debug)] pub struct PostgresDatabase { connection_pool: Pool, } #[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)] pub struct UserId(i32); #[derive(Debug, Clone)] pub struct User { id: UserId, pub real_name: String, } impl User { pub fn get_id(&self) -> UserId { self.id } } #[derive(Debug)] pub struct PasswordHash(pub String); impl Deref for PasswordHash { type Target = str; fn deref(&self) -> &str { &self.0 } } impl PostgresDatabase { /// Create a connection pool and return the [Database]. pub fn new(connection_url: &str) -> InitialisationResult { let mut config = deadpool_postgres::Config::new(); config.url = Some(connection_url.to_string()); let pg_pool = config.create_pool(Some(Runtime::Tokio1), NoTls)?; Ok(PostgresDatabase { connection_pool: pg_pool, }) } async fn get_client(&self) -> Result { Ok(self.connection_pool.get().await?) } } impl Database for PostgresDatabase { /// Run migrations as needed to ensure the database schema version /// match the one used by the current version of the application. async fn migrate_to_current_version(&self) -> Result<()> { migrations::migrate_to_current_version(self).await } #[tracing::instrument] async fn has_admin_users(&self) -> Result { let client = self.get_client().await?; Ok(client .query_one("SELECT EXISTS(SELECT 1 FROM admin_users);", &[]) .await? .get(0)) } #[tracing::instrument] async fn create_user( &self, real_name: &str, email: &str, password: &PasswordHash, ) -> Result { let client = self.get_client().await?; let id = client .query_one( r#" INSERT INTO users (real_name, email, password) VALUES ($1, $2, $3) RETURNING id;"#, &[&real_name, &email, &password.0], ) .await? .get(0); Ok(User { id: UserId(id), real_name: real_name.to_string(), }) } #[tracing::instrument] async fn get_password_for_user(&self, user: &User) -> Result { let client = self.get_client().await?; let row = client .query_one("SELECT password FROM users WHERE id = $1;", &[&user.id.0]) .await?; Ok(PasswordHash(row.get(0))) } #[tracing::instrument] async fn create_first_admin_user( &self, real_name: &str, email: &str, password: &PasswordHash, ) -> Result { if self.has_admin_users().await? { return Err(Error::NotAllowed); } let user = self.create_user(real_name, email, password).await?; let client = self.get_client().await?; client .execute("INSERT INTO admin_users (id) VALUES ($1)", &[&user.id.0]) .await?; Ok(User { id: user.id, real_name: user.real_name, }) } #[tracing::instrument] async fn is_user_admin(&self, user: &User) -> Result { Ok(self .get_client() .await? .query_opt("SELECT 1 FROM admin_users WHERE id = $1;", &[&user.id.0]) .await? .is_some()) } #[tracing::instrument] async fn get_user_with_id(&self, user_id: UserId) -> Result> { Ok(self .get_client() .await? .query_opt("SELECT real_name FROM users WHERE id = $1;", &[&user_id.0]) .await? .map(|row| User { id: user_id, real_name: row.get(0), })) } }