diff --git a/Cargo.toml b/Cargo.toml index d98ade1..495935c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,7 +20,6 @@ http = "1.0" serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" sha2 = "0.10" -thiserror = "1.0" tokio = { version = "1", features = ["rt-multi-thread"]} tokio-postgres = { version = "0.7", features = ["with-chrono-0_4"] } tower = { version = "0.4", features = ["util"] } diff --git a/src/app/admin.rs b/src/app/admin.rs index 9246aa9..c14fd38 100644 --- a/src/app/admin.rs +++ b/src/app/admin.rs @@ -48,6 +48,7 @@ struct IndexTemplate<'a> { admin_user_name: &'a str, } +#[tracing::instrument] async fn check_jwt( db: &D, cookie_jar: &CookieJar, @@ -57,16 +58,16 @@ async fn check_jwt( cookie_jar .get("jwt") .map(|cookie| cookie.value_trimmed()) - .ok_or(Error::Forbidden)?, + .ok_or(Error::new_unauthorized())?, ) .await? { ParsedJwt::Valid(user) => check_if_user_is_admin(db, &user) .await? - .ok_or(Error::Forbidden), - ParsedJwt::InvalidSignature => Err(Error::Forbidden), - ParsedJwt::UserNotFound => Err(Error::Forbidden), - ParsedJwt::Expired(user) => Err(Error::JwtExpired(user)), + .ok_or(Error::new_unauthorized()), + ParsedJwt::InvalidSignature => Err(Error::new_unauthorized()), + ParsedJwt::UserNotFound => Err(Error::new_unauthorized()), + ParsedJwt::Expired(user) => Err(Error::new_jwt_expired(user)), } } @@ -114,8 +115,8 @@ async fn post_create_first_admin_user( .await?; let user = authenticate_user_with_password(&db, user, ¶ms.password) .await? - .ok_or(Error::Unexpected( - "Could not authenticate newly-created user.".to_string(), + .ok_or(Error::new_unexpected( + "Could not authenticate newly-created user.", ))?; Ok(( cookie_jar diff --git a/src/config.rs b/src/config.rs index 322bd0b..23c2234 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,13 +1,24 @@ -use {std::env, thiserror::Error}; +use std::env; const ENV_VAR_PREFIX: &str = "LOCALITY_"; -#[derive(Error, Debug)] +#[derive(Debug)] pub enum Error { - #[error("The environment variable \"{0}\" must be set.")] MissingEnvironmentVariable(String), } +impl std::fmt::Display for Error { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::MissingEnvironmentVariable(v) => { + write!(f, "The environment variable \"{}\" must be set.", v) + } + } + } +} + +impl std::error::Error for Error {} + pub struct Config { pub database_url: String, pub static_file_path: String, diff --git a/src/error.rs b/src/error.rs index 00161b6..577cefe 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,31 +1,131 @@ +//! Handle high-level errors that should be converted into HTTP +//! responses. +//! +//! The [Error] type defined in this module is used by route handlers +//! to convert errors into HTTP responses. Some error types (such as +//! [db::Error]) are automatcially converted into [Error]. [Error] +//! also supplies some constructors, such as [Error::new_unexpected()] +//! and [Error::new_unauthorized] that can be used directly. +//! +//! The intended use is that error types which have [FromError] +//! defined might sometimes be handled at a higher level. If an error +//! should always be handled here then other modules should call the +//! appropriate [Error] constructor directly. + use { crate::{authentication, db}, askama::Template, askama_axum::{IntoResponse, Response}, http::status::StatusCode, - tracing::{error, warn}, + tracing::{debug, error, info}, }; -#[allow(clippy::enum_variant_names)] -#[derive(thiserror::Error, Debug)] -pub enum Error { - #[error("Database Error: {}", 0.to_string())] - DatabaseError(#[from] db::Error), - - #[error("Authentication error")] - AuthenticationError(#[from] authentication::AuthenticationError), - - #[error("Unexpected error: {}", .0)] - Unexpected(String), - - #[error("Forbidden")] - Forbidden, - - #[error("JWT Expired")] +#[derive(Debug)] +enum ErrorType { + InternalServerError, + Unauthorized, JwtExpired(db::User), +} - #[error("JWT Error")] - JwtError(#[from] authentication::JwtError), +impl std::fmt::Display for ErrorType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + ErrorType::InternalServerError => write!(f, "InternalServerError"), + ErrorType::Unauthorized => write!(f, "Unauthorized"), + ErrorType::JwtExpired(user) => write!(f, "JWT Expired for {:?}", user), + } + } +} + +/// Error which should be converted into a HTTP response +#[derive(Debug)] +pub struct Error { + error_type: ErrorType, + inner: Option>, +} + +impl Error { + /// An unexpected error has occurred which prevented the request + /// from being handled properly.. + pub fn new_unexpected(message: &str) -> Error { + error!("Unexpected error: {}", message); + Error { + error_type: ErrorType::InternalServerError, + inner: None, + } + } + + /// Either authorization failed or the current user is not + /// authorized to perform the requested action + pub fn new_unauthorized() -> Error { + info!("Unauthorized user"); + Error { + error_type: ErrorType::Unauthorized, + inner: None, + } + } + + /// The user's JWT has expired and the they must log in again + pub fn new_jwt_expired(user: db::User) -> Error { + debug!("Jwt Expired"); + Error { + error_type: ErrorType::JwtExpired(user), + inner: None, + } + } +} + +impl From for Error { + fn from(value: db::Error) -> Self { + Error { + error_type: ErrorType::InternalServerError, + inner: Some(Box::new(value)), + } + } +} + +impl From for Error { + fn from(value: authentication::JwtError) -> Self { + error!(inner = value.to_string(), "Unhandled JWT error"); + Error { + error_type: ErrorType::InternalServerError, + inner: Some(Box::new(value)), + } + } +} + +impl From for Error { + fn from(value: authentication::AuthenticationError) -> Self { + error!(inner = value.to_string(), "Error during authentication"); + Error { + error_type: ErrorType::InternalServerError, + inner: Some(Box::new(value)), + } + } +} + +impl std::fmt::Display for Error { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Error { + error_type, + inner: Some(inner_err), + } => write!(f, "{}: {}", error_type, inner_err), + Error { + error_type, + inner: None, + } => write!(f, "{}", error_type), + } + } +} + +impl std::error::Error for Error { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match &self.inner { + Some(inner) => Some(inner.as_ref()), + None => None, + } + } } #[derive(Template)] @@ -36,34 +136,21 @@ struct ErrorTemplate<'a> { impl IntoResponse for Error { fn into_response(self) -> Response { - match self { - Error::DatabaseError(_) => { - error!("Uncaught database error producing HTTP 500."); + match self.error_type { + ErrorType::InternalServerError => { + error!(inner = self.inner, "Uncaught error producing HTTP 500."); ( StatusCode::INTERNAL_SERVER_ERROR, ErrorTemplate { title: "Error" }, ) .into_response() } - Error::AuthenticationError(_) => { - error!("Uncaught authentication error producing HTTP 500."); - ( - StatusCode::INTERNAL_SERVER_ERROR, - ErrorTemplate { title: "Error" }, - ) - .into_response() + ErrorType::Unauthorized => { + (StatusCode::UNAUTHORIZED, "User not authorized.").into_response() } - Error::Unexpected(_) => { + ErrorType::JwtExpired(_) => { todo!() } - Error::Forbidden => (StatusCode::UNAUTHORIZED, "User not authorized.").into_response(), - Error::JwtExpired(_) => { - todo!() - } - Error::JwtError(jwt_error) => { - warn!(detail = jwt_error.to_string(), "Checking JWT"); - (StatusCode::UNAUTHORIZED, ErrorTemplate { title: "Error" }).into_response() - } } } } diff --git a/src/main.rs b/src/main.rs index 60af8b1..e72b579 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,5 +1,4 @@ use { - thiserror::Error, tower_http::{services::ServeDir, trace::TraceLayer}, tracing::Level, }; @@ -13,19 +12,63 @@ mod error; use config::get_config; use db::{Database, PostgresDatabase}; -#[derive(Error, Debug)] -pub enum Error { - #[error("Loading configuration: \"{0}\"")] - MissingConfigError(#[from] config::Error), - #[error("Database error: {0}")] - DatabaseError(#[from] db::InitialisationError), +/// An unrecoverable error which requires the server to shut down +#[derive(Debug)] +struct FatalError { + message: &'static str, + inner: Option>, +} - #[error("{0}")] - IOError(#[from] std::io::Error), +impl std::fmt::Display for FatalError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match &self.inner { + Some(inner_err) => write!(f, "{}: {}", self.message, inner_err), + None => write!(f, "{}", self.message), + } + } +} - #[error("{0}")] - TracingError(#[from] tracing::subscriber::SetGlobalDefaultError), +impl std::error::Error for FatalError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + self.inner.as_deref() + } +} + +impl From for FatalError { + fn from(value: config::Error) -> Self { + FatalError { + message: "Loading config", + inner: Some(Box::new(value)), + } + } +} + +impl From for FatalError { + fn from(value: tracing::subscriber::SetGlobalDefaultError) -> Self { + FatalError { + message: "Loading config", + inner: Some(Box::new(value)), + } + } +} + +impl From for FatalError { + fn from(value: db::InitialisationError) -> Self { + FatalError { + message: "initialising database connection", + inner: Some(Box::new(value)), + } + } +} + +impl From for FatalError { + fn from(value: std::io::Error) -> Self { + FatalError { + message: "Initialising", + inner: Some(Box::new(value)), + } + } } fn main() { @@ -39,7 +82,7 @@ fn main() { }) } -async fn locality_main() -> Result<(), Error> { +async fn locality_main() -> Result<(), FatalError> { let config = get_config()?; let subscriber = tracing_subscriber::FmtSubscriber::builder()