Remove thiserror and continue error handling improvements

This commit is contained in:
Matthew Gordon 2024-02-28 12:07:12 -04:00
parent fbb320507a
commit 242965450a
5 changed files with 202 additions and 61 deletions

View File

@ -20,7 +20,6 @@ http = "1.0"
serde = { version = "1.0", features = ["derive"] } serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0" serde_json = "1.0"
sha2 = "0.10" sha2 = "0.10"
thiserror = "1.0"
tokio = { version = "1", features = ["rt-multi-thread"]} tokio = { version = "1", features = ["rt-multi-thread"]}
tokio-postgres = { version = "0.7", features = ["with-chrono-0_4"] } tokio-postgres = { version = "0.7", features = ["with-chrono-0_4"] }
tower = { version = "0.4", features = ["util"] } tower = { version = "0.4", features = ["util"] }

View File

@ -48,6 +48,7 @@ struct IndexTemplate<'a> {
admin_user_name: &'a str, admin_user_name: &'a str,
} }
#[tracing::instrument]
async fn check_jwt<D: Database>( async fn check_jwt<D: Database>(
db: &D, db: &D,
cookie_jar: &CookieJar, cookie_jar: &CookieJar,
@ -57,16 +58,16 @@ async fn check_jwt<D: Database>(
cookie_jar cookie_jar
.get("jwt") .get("jwt")
.map(|cookie| cookie.value_trimmed()) .map(|cookie| cookie.value_trimmed())
.ok_or(Error::Forbidden)?, .ok_or(Error::new_unauthorized())?,
) )
.await? .await?
{ {
ParsedJwt::Valid(user) => check_if_user_is_admin(db, &user) ParsedJwt::Valid(user) => check_if_user_is_admin(db, &user)
.await? .await?
.ok_or(Error::Forbidden), .ok_or(Error::new_unauthorized()),
ParsedJwt::InvalidSignature => Err(Error::Forbidden), ParsedJwt::InvalidSignature => Err(Error::new_unauthorized()),
ParsedJwt::UserNotFound => Err(Error::Forbidden), ParsedJwt::UserNotFound => Err(Error::new_unauthorized()),
ParsedJwt::Expired(user) => Err(Error::JwtExpired(user)), ParsedJwt::Expired(user) => Err(Error::new_jwt_expired(user)),
} }
} }
@ -114,8 +115,8 @@ async fn post_create_first_admin_user<D: Database>(
.await?; .await?;
let user = authenticate_user_with_password(&db, user, &params.password) let user = authenticate_user_with_password(&db, user, &params.password)
.await? .await?
.ok_or(Error::Unexpected( .ok_or(Error::new_unexpected(
"Could not authenticate newly-created user.".to_string(), "Could not authenticate newly-created user.",
))?; ))?;
Ok(( Ok((
cookie_jar cookie_jar

View File

@ -1,13 +1,24 @@
use {std::env, thiserror::Error}; use std::env;
const ENV_VAR_PREFIX: &str = "LOCALITY_"; const ENV_VAR_PREFIX: &str = "LOCALITY_";
#[derive(Error, Debug)] #[derive(Debug)]
pub enum Error { pub enum Error {
#[error("The environment variable \"{0}\" must be set.")]
MissingEnvironmentVariable(String), MissingEnvironmentVariable(String),
} }
impl std::fmt::Display for Error {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::MissingEnvironmentVariable(v) => {
write!(f, "The environment variable \"{}\" must be set.", v)
}
}
}
}
impl std::error::Error for Error {}
pub struct Config { pub struct Config {
pub database_url: String, pub database_url: String,
pub static_file_path: String, pub static_file_path: String,

View File

@ -1,31 +1,131 @@
//! Handle high-level errors that should be converted into HTTP
//! responses.
//!
//! The [Error] type defined in this module is used by route handlers
//! to convert errors into HTTP responses. Some error types (such as
//! [db::Error]) are automatcially converted into [Error]. [Error]
//! also supplies some constructors, such as [Error::new_unexpected()]
//! and [Error::new_unauthorized] that can be used directly.
//!
//! The intended use is that error types which have [FromError]
//! defined might sometimes be handled at a higher level. If an error
//! should always be handled here then other modules should call the
//! appropriate [Error] constructor directly.
use { use {
crate::{authentication, db}, crate::{authentication, db},
askama::Template, askama::Template,
askama_axum::{IntoResponse, Response}, askama_axum::{IntoResponse, Response},
http::status::StatusCode, http::status::StatusCode,
tracing::{error, warn}, tracing::{debug, error, info},
}; };
#[allow(clippy::enum_variant_names)] #[derive(Debug)]
#[derive(thiserror::Error, Debug)] enum ErrorType {
pub enum Error { InternalServerError,
#[error("Database Error: {}", 0.to_string())] Unauthorized,
DatabaseError(#[from] db::Error),
#[error("Authentication error")]
AuthenticationError(#[from] authentication::AuthenticationError),
#[error("Unexpected error: {}", .0)]
Unexpected(String),
#[error("Forbidden")]
Forbidden,
#[error("JWT Expired")]
JwtExpired(db::User), JwtExpired(db::User),
}
#[error("JWT Error")] impl std::fmt::Display for ErrorType {
JwtError(#[from] authentication::JwtError), fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ErrorType::InternalServerError => write!(f, "InternalServerError"),
ErrorType::Unauthorized => write!(f, "Unauthorized"),
ErrorType::JwtExpired(user) => write!(f, "JWT Expired for {:?}", user),
}
}
}
/// Error which should be converted into a HTTP response
#[derive(Debug)]
pub struct Error {
error_type: ErrorType,
inner: Option<Box<dyn std::error::Error + Send + Sync>>,
}
impl Error {
/// An unexpected error has occurred which prevented the request
/// from being handled properly..
pub fn new_unexpected(message: &str) -> Error {
error!("Unexpected error: {}", message);
Error {
error_type: ErrorType::InternalServerError,
inner: None,
}
}
/// Either authorization failed or the current user is not
/// authorized to perform the requested action
pub fn new_unauthorized() -> Error {
info!("Unauthorized user");
Error {
error_type: ErrorType::Unauthorized,
inner: None,
}
}
/// The user's JWT has expired and the they must log in again
pub fn new_jwt_expired(user: db::User) -> Error {
debug!("Jwt Expired");
Error {
error_type: ErrorType::JwtExpired(user),
inner: None,
}
}
}
impl From<db::Error> for Error {
fn from(value: db::Error) -> Self {
Error {
error_type: ErrorType::InternalServerError,
inner: Some(Box::new(value)),
}
}
}
impl From<authentication::JwtError> for Error {
fn from(value: authentication::JwtError) -> Self {
error!(inner = value.to_string(), "Unhandled JWT error");
Error {
error_type: ErrorType::InternalServerError,
inner: Some(Box::new(value)),
}
}
}
impl From<authentication::AuthenticationError> for Error {
fn from(value: authentication::AuthenticationError) -> Self {
error!(inner = value.to_string(), "Error during authentication");
Error {
error_type: ErrorType::InternalServerError,
inner: Some(Box::new(value)),
}
}
}
impl std::fmt::Display for Error {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Error {
error_type,
inner: Some(inner_err),
} => write!(f, "{}: {}", error_type, inner_err),
Error {
error_type,
inner: None,
} => write!(f, "{}", error_type),
}
}
}
impl std::error::Error for Error {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
match &self.inner {
Some(inner) => Some(inner.as_ref()),
None => None,
}
}
} }
#[derive(Template)] #[derive(Template)]
@ -36,34 +136,21 @@ struct ErrorTemplate<'a> {
impl IntoResponse for Error { impl IntoResponse for Error {
fn into_response(self) -> Response { fn into_response(self) -> Response {
match self { match self.error_type {
Error::DatabaseError(_) => { ErrorType::InternalServerError => {
error!("Uncaught database error producing HTTP 500."); error!(inner = self.inner, "Uncaught error producing HTTP 500.");
( (
StatusCode::INTERNAL_SERVER_ERROR, StatusCode::INTERNAL_SERVER_ERROR,
ErrorTemplate { title: "Error" }, ErrorTemplate { title: "Error" },
) )
.into_response() .into_response()
} }
Error::AuthenticationError(_) => { ErrorType::Unauthorized => {
error!("Uncaught authentication error producing HTTP 500."); (StatusCode::UNAUTHORIZED, "User not authorized.").into_response()
(
StatusCode::INTERNAL_SERVER_ERROR,
ErrorTemplate { title: "Error" },
)
.into_response()
} }
Error::Unexpected(_) => { ErrorType::JwtExpired(_) => {
todo!() todo!()
} }
Error::Forbidden => (StatusCode::UNAUTHORIZED, "User not authorized.").into_response(),
Error::JwtExpired(_) => {
todo!()
}
Error::JwtError(jwt_error) => {
warn!(detail = jwt_error.to_string(), "Checking JWT");
(StatusCode::UNAUTHORIZED, ErrorTemplate { title: "Error" }).into_response()
}
} }
} }
} }

View File

@ -1,5 +1,4 @@
use { use {
thiserror::Error,
tower_http::{services::ServeDir, trace::TraceLayer}, tower_http::{services::ServeDir, trace::TraceLayer},
tracing::Level, tracing::Level,
}; };
@ -13,19 +12,63 @@ mod error;
use config::get_config; use config::get_config;
use db::{Database, PostgresDatabase}; use db::{Database, PostgresDatabase};
#[derive(Error, Debug)]
pub enum Error {
#[error("Loading configuration: \"{0}\"")]
MissingConfigError(#[from] config::Error),
#[error("Database error: {0}")] /// An unrecoverable error which requires the server to shut down
DatabaseError(#[from] db::InitialisationError), #[derive(Debug)]
struct FatalError {
message: &'static str,
inner: Option<Box<dyn std::error::Error>>,
}
#[error("{0}")] impl std::fmt::Display for FatalError {
IOError(#[from] std::io::Error), fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match &self.inner {
Some(inner_err) => write!(f, "{}: {}", self.message, inner_err),
None => write!(f, "{}", self.message),
}
}
}
#[error("{0}")] impl std::error::Error for FatalError {
TracingError(#[from] tracing::subscriber::SetGlobalDefaultError), fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
self.inner.as_deref()
}
}
impl From<config::Error> for FatalError {
fn from(value: config::Error) -> Self {
FatalError {
message: "Loading config",
inner: Some(Box::new(value)),
}
}
}
impl From<tracing::subscriber::SetGlobalDefaultError> for FatalError {
fn from(value: tracing::subscriber::SetGlobalDefaultError) -> Self {
FatalError {
message: "Loading config",
inner: Some(Box::new(value)),
}
}
}
impl From<db::InitialisationError> for FatalError {
fn from(value: db::InitialisationError) -> Self {
FatalError {
message: "initialising database connection",
inner: Some(Box::new(value)),
}
}
}
impl From<std::io::Error> for FatalError {
fn from(value: std::io::Error) -> Self {
FatalError {
message: "Initialising",
inner: Some(Box::new(value)),
}
}
} }
fn main() { fn main() {
@ -39,7 +82,7 @@ fn main() {
}) })
} }
async fn locality_main() -> Result<(), Error> { async fn locality_main() -> Result<(), FatalError> {
let config = get_config()?; let config = get_config()?;
let subscriber = tracing_subscriber::FmtSubscriber::builder() let subscriber = tracing_subscriber::FmtSubscriber::builder()