diff --git a/Cargo.toml b/Cargo.toml index 4f63f6a..d98ade1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,6 +23,10 @@ sha2 = "0.10" thiserror = "1.0" tokio = { version = "1", features = ["rt-multi-thread"]} tokio-postgres = { version = "0.7", features = ["with-chrono-0_4"] } +tower = { version = "0.4", features = ["util"] } tower-http = { version = "0.5", features = ["fs", "trace"] } tracing = "0.1" tracing-subscriber = { version = "0.3", default_features = false, features = ["std", "fmt", "ansi"] } + +[dev-dependencies] +scraper = "0.18" \ No newline at end of file diff --git a/src/app/admin.rs b/src/app/admin.rs index 0984ac0..9246aa9 100644 --- a/src/app/admin.rs +++ b/src/app/admin.rs @@ -10,7 +10,7 @@ use { askama::Template, askama_axum::{IntoResponse, Response}, axum::{ - extract::State, + extract::{NestedPath, State}, response::Redirect, routing::{get, post}, Form, Router, @@ -74,9 +74,10 @@ async fn check_jwt( async fn root( cookie_jar: CookieJar, State(AppState { db, .. }): State>, + path: NestedPath, ) -> Result { Ok(if !db.has_admin_users().await? { - Redirect::temporary("admin/create_first_admin_user").into_response() + Redirect::temporary(&format!("{}/create_first_admin_user", path.as_str())).into_response() } else { let admin_user = check_jwt(&db, &cookie_jar).await?; IndexTemplate { @@ -122,3 +123,116 @@ async fn post_create_first_admin_user( FirstLoginTemplate {}, )) } + +#[cfg(test)] +mod tests { + use super::*; + use crate::{app::AppState, db::fake::FakeDatabase}; + use { + axum::{ + body, + body::Body, + http::{Request, StatusCode}, + }, + scraper::{Html, Selector}, + tower::{Service, ServiceExt}, + }; + + #[tokio::test] + async fn root_redirects_when_no_admin_users() { + let app = Router::new() + .nest("/test_admin", routes()) + .with_state(AppState { + db: FakeDatabase::new_empty(), + }); + + let response = app + .oneshot( + Request::builder() + .method(http::Method::GET) + .uri("/test_admin") + .body(Body::empty()) + .unwrap(), + ) + .await + .unwrap(); + + assert_eq!(response.status(), StatusCode::TEMPORARY_REDIRECT); + assert!(response.headers().contains_key("location")); + assert_eq!( + "/test_admin/create_first_admin_user", + response.headers()["location"] + ); + } + + #[tokio::test] + async fn create_first_admin_user() { + let mut app = Router::new() + .nest("/test_admin", routes()) + .with_state(AppState { + db: FakeDatabase::new_empty(), + }) + .into_service(); + + let request = Request::get("/test_admin/create_first_admin_user") + .body(Body::empty()) + .unwrap(); + let response = ServiceExt::>::ready(&mut app) + .await + .unwrap() + .call(request) + .await + .unwrap(); + assert_eq!(response.status(), StatusCode::OK); + let body = body::to_bytes(response.into_body(), 10000).await.unwrap(); + let html = Html::parse_document(&String::from_utf8(body.into()).unwrap()); + let form_selector = Selector::parse("form").unwrap(); + let mut form_elements = html.select(&form_selector); + let form_element = form_elements.next().unwrap(); + assert_eq!(0, form_elements.count()); + assert_eq!(Some("create_first_admin_user"), form_element.attr("action")); + assert_eq!(Some("post"), form_element.attr("method")); + let input_selector = Selector::parse("input").unwrap(); + let inputs: Vec<_> = form_element.select(&input_selector).collect(); + assert_eq!( + 1, + inputs + .iter() + .filter(|elem| elem.attr("name") == Some("real_name")) + .count() + ); + assert_eq!( + 1, + inputs + .iter() + .filter(|elem| elem.attr("name") == Some("email")) + .count() + ); + assert_eq!( + 1, + inputs + .iter() + .filter(|elem| elem.attr("name") == Some("password")) + .filter(|elem| elem.attr("type") == Some("password")) + .count() + ); + + let request = Request::post("/test_admin/create_first_admin_user") + .header( + http::header::CONTENT_TYPE, + "application/x-www-form-urlencoded", + ) + .body(Body::from( + "real_name=Joe%20User&email=joe%40user.com&password=abc123", + )) + .unwrap(); + let response = ServiceExt::>::ready(&mut app) + .await + .unwrap() + .call(request) + .await + .unwrap(); + + assert_eq!(response.status(), StatusCode::OK); + } +} diff --git a/src/authentication/jwt.rs b/src/authentication/jwt.rs index a631dda..3a47910 100644 --- a/src/authentication/jwt.rs +++ b/src/authentication/jwt.rs @@ -141,7 +141,7 @@ pub fn create_jwt_for_user(user: &AuthenticatedUser) -> Result { /// Given JWT string created by [create_jwt_for_user()], check if the /// JWT is valid and return an [AuthenticatedUser] if it is. #[tracing::instrument] -pub async fn authenticate_user_with_jwt(db: &D, jwt: &str) -> Result { +pub async fn authenticate_user_with_jwt(db: &D, jwt: &str) -> Result { if let [header, payload, signature] = jwt.split('.').collect::>().as_slice() { let mut mac = mac(); mac.update(format!("{}.{}", header, payload).as_bytes()); diff --git a/src/authentication/mod.rs b/src/authentication/mod.rs index 6f9155d..b806161 100644 --- a/src/authentication/mod.rs +++ b/src/authentication/mod.rs @@ -7,19 +7,51 @@ use { Argon2, }, std::ops::Deref, + tracing::{error, warn}, }; mod jwt; pub use jwt::{authenticate_user_with_jwt, create_jwt_for_user, Error as JwtError, ParsedJwt}; -#[derive(thiserror::Error, Debug)] +#[derive(Debug)] pub enum AuthenticationError { - #[error("Could not get password hash from database: {}", .0.to_string())] - DatabaseError(#[from] db::Error), + DatabaseError(db::Error), + HashError(argon2::password_hash::Error), +} - #[error("{}", .0.to_string())] - HashError(#[from] argon2::password_hash::Error), +impl std::fmt::Display for AuthenticationError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + AuthenticationError::HashError(e) => write!(f, "{}", e), + AuthenticationError::DatabaseError(e) => { + write!(f, "Could not get password hash from database: {}", e) + } + } + } +} + +impl std::error::Error for AuthenticationError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + AuthenticationError::HashError(_) => None, + AuthenticationError::DatabaseError(e) => Some(e), + } + } +} + +impl From for AuthenticationError { + fn from(value: db::Error) -> Self { + warn!(details = value.to_string(), "Database error"); + AuthenticationError::DatabaseError(value) + } +} + +impl From for AuthenticationError { + fn from(value: argon2::password_hash::Error) -> Self { + error!(details = value.to_string(), "Error hashing password."); + AuthenticationError::HashError(value) + } } #[derive(Debug, Clone)] @@ -81,6 +113,7 @@ impl From for db::PasswordHash { } } +#[tracing::instrument] pub async fn authenticate_user_with_password( db: &D, user: db::User, @@ -94,6 +127,7 @@ pub async fn authenticate_user_with_password( }) } +#[tracing::instrument] pub async fn check_if_user_is_admin( db: &D, user: &AuthenticatedUser, diff --git a/src/db/fake.rs b/src/db/fake.rs new file mode 100644 index 0000000..6b52ac8 --- /dev/null +++ b/src/db/fake.rs @@ -0,0 +1,128 @@ +use super::*; +use { + std::collections::HashSet, + std::sync::{Arc, Mutex}, +}; + +#[derive(Debug)] +struct UserRow { + real_name: String, + email: String, + password_hash: String, +} + +#[derive(Debug, Clone)] +pub struct FakeDatabase { + users: Arc>>, + admin_users: Arc>>, +} + +impl FakeDatabase { + pub fn new_empty() -> Self { + FakeDatabase { + users: Arc::new(Mutex::new(Vec::new())), + admin_users: Arc::new(Mutex::new(HashSet::new())), + } + } +} + +impl Database for FakeDatabase { + async fn migrate_to_current_version(&self) -> Result<()> { + Ok(()) + } + + async fn has_admin_users(&self) -> Result { + Ok(self.admin_users.lock().unwrap().len() > 0) + } + + async fn create_user( + &self, + real_name: &str, + email: &str, + password: &PasswordHash, + ) -> Result { + let mut users = self.users.lock().unwrap(); + users.push(UserRow { + real_name: real_name.to_string(), + email: email.to_string(), + password_hash: password.to_string(), + }); + Ok(User { + id: UserId((users.len() - 1) as i32), + real_name: real_name.to_string(), + }) + } + + async fn get_password_for_user(&self, user: &User) -> Result { + let users = self.users.lock().unwrap(); + if let Some(UserRow { password_hash, .. }) = users.get(user.id.0 as usize) { + Ok(PasswordHash(password_hash.clone())) + } else { + Err(Error::Database) + } + } + + async fn create_first_admin_user( + &self, + real_name: &str, + email: &str, + password: &PasswordHash, + ) -> Result { + let user = self.create_user(real_name, email, password).await?; + let mut admin_users = self.admin_users.lock().unwrap(); + admin_users.insert(user.id.0 as usize); + Ok(user) + } + + async fn is_user_admin(&self, user: &User) -> Result { + let admin_users = self.admin_users.lock().unwrap(); + Ok(admin_users.contains(&(user.id.0 as usize))) + } + + async fn get_user_with_id(&self, user_id: UserId) -> Result> { + let users = self.users.lock().unwrap(); + Ok(users + .get(user_id.0 as usize) + .map(|UserRow { real_name, .. }| User { + id: user_id, + real_name: real_name.clone(), + })) + } +} + +mod tests { + use super::*; + + #[tokio::test] + async fn store_user() { + let target = FakeDatabase::new_empty(); + + let user = target + .create_user( + "Jane Doe", + "jane.doe@example.com", + &PasswordHash("iamjane!".to_string()), + ) + .await + .unwrap(); + let saved_user = target.get_user_with_id(user.id).await.unwrap().unwrap(); + assert_eq!(user.id, saved_user.id); + assert_eq!(user.real_name, saved_user.real_name); + } + + #[tokio::test] + async fn store_user_password() { + let target = FakeDatabase::new_empty(); + + let user = target + .create_user( + "Jane Doe", + "jane.doe@example.com", + &PasswordHash("iamjane!".to_string()), + ) + .await + .unwrap(); + let saved_password = target.get_password_for_user(&user).await.unwrap(); + assert_eq!("iamjane!", saved_password.0); + } +} diff --git a/src/db/mod.rs b/src/db/mod.rs index 6266a64..394e91e 100644 --- a/src/db/mod.rs +++ b/src/db/mod.rs @@ -5,6 +5,9 @@ mod migrations; +#[cfg(test)] +pub mod fake; + use { deadpool_postgres::{CreatePoolError, Pool, Runtime}, serde::{Deserialize, Serialize}, @@ -43,7 +46,7 @@ pub type InitialisationResult = std::result::Result; #[derive(Debug)] pub enum Error { Pool(deadpool_postgres::PoolError), - Postgres(tokio_postgres::Error), + Database, NotAllowed, } @@ -51,7 +54,7 @@ impl std::fmt::Display for Error { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { Error::Pool(e) => e.fmt(f), - Error::Postgres(e) => e.fmt(f), + Error::Database => write!(f, "Database Error"), Error::NotAllowed => write!(f, "Not Allowed"), } } @@ -78,7 +81,7 @@ impl From for Error { .unwrap_or(&value.to_string()), "PostgreSQL error" ); - Error::Postgres(value) + Error::Database } } @@ -88,7 +91,6 @@ pub trait Database: std::fmt::Debug + Clone + Send + Sync + 'static { /// Run migrations as needed to ensure the database schema version /// match the one used by the current version of the application. fn migrate_to_current_version(&self) -> impl Future> + Send; - fn get_client(&self) -> impl Future> + Send; fn has_admin_users(&self) -> impl Future> + Send; fn create_user( &self, @@ -121,7 +123,7 @@ pub struct PostgresDatabase { connection_pool: Pool, } -#[derive(Debug, Clone, Copy, Serialize, Deserialize)] +#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)] pub struct UserId(i32); #[derive(Debug, Clone)] @@ -157,6 +159,10 @@ impl PostgresDatabase { connection_pool: pg_pool, }) } + + async fn get_client(&self) -> Result { + Ok(self.connection_pool.get().await?) + } } impl Database for PostgresDatabase { @@ -166,10 +172,6 @@ impl Database for PostgresDatabase { migrations::migrate_to_current_version(self).await } - async fn get_client(&self) -> Result { - Ok(self.connection_pool.get().await?) - } - #[tracing::instrument] async fn has_admin_users(&self) -> Result { let client = self.get_client().await?; diff --git a/src/error.rs b/src/error.rs index 4d840b4..00161b6 100644 --- a/src/error.rs +++ b/src/error.rs @@ -46,14 +46,17 @@ impl IntoResponse for Error { .into_response() } Error::AuthenticationError(_) => { - todo!() + error!("Uncaught authentication error producing HTTP 500."); + ( + StatusCode::INTERNAL_SERVER_ERROR, + ErrorTemplate { title: "Error" }, + ) + .into_response() } Error::Unexpected(_) => { todo!() } - Error::Forbidden => { - (StatusCode::UNAUTHORIZED, "User not authorized.").into_response() - } + Error::Forbidden => (StatusCode::UNAUTHORIZED, "User not authorized.").into_response(), Error::JwtExpired(_) => { todo!() }