// locality/src/authentication/jwt.rs
//
// 177 lines
// 5.1 KiB
// Rust

use {
base64::engine::{general_purpose::STANDARD as base64_encoder, Engine as _},
chrono::{DateTime, Duration, Utc},
hmac::{Hmac, Mac},
serde::{Deserialize, Serialize},
sha2::Sha256,
tracing::{error, warn},
};
use {
super::AuthenticatedUser,
crate::config::get_config,
crate::db::{Database, User, UserId},
};
/// How long a JWT issued by [create_jwt_for_user()] stays valid, measured
/// from the moment it is created.
const COOKIE_EXPIRY_TIME: Duration = Duration::weeks(1);
/// Errors produced while creating or checking a JWT.
///
/// The `Display` text of every variant is exactly the variant's name.
#[derive(Debug)]
pub enum Error {
    /// The token was malformed: wrong number of segments, a segment that is
    /// not valid base64 / UTF-8 / JSON, or an unexpected header.
    BadJwt,
    /// A database operation failed.
    // The name repeats the enum's own suffix, which clippy flags; kept for readability.
    #[allow(clippy::enum_variant_names)]
    DatabaseError,
    /// Converted from a `digest::MacError`.
    #[allow(clippy::enum_variant_names)]
    HmacError,
}

impl std::fmt::Display for Error {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            Self::BadJwt => "BadJwt",
            Self::DatabaseError => "DatabaseError",
            Self::HmacError => "HmacError",
        };
        f.write_str(name)
    }
}

impl std::error::Error for Error {}
impl From<serde_json::Error> for Error {
fn from(value: serde_json::Error) -> Self {
warn!(details = value.to_string(), "JWT contained invalid JSON");
Error::BadJwt
}
}
/// A segment that is not valid base64 makes the token bad; the decoder's
/// error text is logged as a warning and then discarded.
impl From<base64::DecodeError> for Error {
    fn from(err: base64::DecodeError) -> Self {
        warn!(details = err.to_string(), "JWT contained invalid BASE64");
        Self::BadJwt
    }
}
impl From<std::string::FromUtf8Error> for Error {
fn from(value: std::string::FromUtf8Error) -> Self {
warn!(details = value.to_string(), "JWT contained invalid UTF-8");
Error::BadJwt
}
}
impl From<crate::db::Error> for Error {
fn from(value: crate::db::Error) -> Self {
error!(details = value.to_string(), "Database error");
Error::DatabaseError
}
}
/// A `MacError` reaching `?` is treated as a programming error: it is
/// logged at error level and becomes [Error::HmacError].
impl From<digest::MacError> for Error {
    fn from(_err: digest::MacError) -> Self {
        error!("Bug in JWT HMAC code.");
        Self::HmacError
    }
}
/// Shorthand for results whose error type is this module's [Error].
type Result<T> = std::result::Result<T, Error>;
/// Result type for [authenticate_user_with_jwt()].
///
/// Each variant is a successful (`Ok`) verdict on a token. Malformed tokens
/// and database failures are reported through [Error] instead.
#[derive(Debug)]
pub enum ParsedJwt {
    /// JWT is a valid JWT, here is the [AuthenticatedUser].
    Valid(AuthenticatedUser),
    /// JWT was valid but is expired; carries the [User] the token names.
    Expired(User),
    /// JWT signature does not match contents
    InvalidSignature,
    /// JWT is valid, but the user id is not in the database
    UserNotFound,
}
/// JSON header segment of a JWT, using the registered `alg` / `typ` names.
// NOTE(review): `&'a str` fields can only borrow unescaped JSON strings, so a
// header whose strings contain escape sequences presumably fails to parse.
// Harmless for headers this module writes; confirm if foreign tokens matter.
#[derive(Serialize, Deserialize)]
struct Header<'a> {
    /// Signing algorithm; this module writes and expects `"HS256"`.
    #[serde(rename = "alg")]
    algorithm: &'a str,
    /// Token type; this module writes and expects `"JWT"`.
    #[serde(rename = "typ")]
    token_type: &'a str,
}
/// JSON payload segment of a JWT: which user it identifies and when it
/// stops being valid.
#[derive(Serialize, Deserialize)]
struct Payload {
    /// The user this token identifies, stored under the `sub` claim name.
    #[serde(rename = "sub")]
    user_id: UserId,
    /// Instant after which the token is expired, stored under the `exp` claim name.
    // NOTE(review): RFC 7519 defines `exp` as a NumericDate (seconds since the
    // epoch). chrono's default serde form for `DateTime` is presumably an
    // RFC 3339 string, so these tokens may not interoperate with other JWT
    // libraries. Confirm whether that matters.
    #[serde(rename = "exp")]
    expiry: DateTime<Utc>,
}
/// Build a fresh HMAC-SHA256 instance keyed with the configured `hmac_secret`.
///
/// A new instance is needed for every signature, because finalizing or
/// verifying consumes the MAC.
///
/// # Panics
///
/// Panics if [get_config()] fails. No token can be signed or checked without
/// the configured secret, so this is treated as a broken invariant.
fn mac() -> Hmac<Sha256> {
    let config =
        get_config().expect("configuration must load before JWTs can be signed or verified");
    Hmac::new_from_slice(&config.hmac_secret).expect("HMAC can take key of any size.")
}
/// Build a signed JWT identifying `user`, for use as a cookie that keeps that
/// user logged in.
///
/// The token has the form `base64(header).base64(payload).base64(signature)`.
/// The signature is HMAC-SHA256 over the first two segments joined by `.`.
/// The payload's expiry is [COOKIE_EXPIRY_TIME] after the moment of creation.
///
/// # Errors
///
/// Returns [Error::BadJwt] if the header or payload cannot be serialized to JSON.
#[tracing::instrument]
pub fn create_jwt_for_user(user: &AuthenticatedUser) -> Result<String> {
    let header_json = serde_json::to_string(&Header {
        algorithm: "HS256",
        token_type: "JWT",
    })?;
    let payload_json = serde_json::to_string(&Payload {
        user_id: user.get_id(),
        expiry: Utc::now() + COOKIE_EXPIRY_TIME,
    })?;

    // `header.payload` is both the signed message and the token's prefix.
    let signing_input = format!(
        "{}.{}",
        base64_encoder.encode(header_json.as_bytes()),
        base64_encoder.encode(payload_json.as_bytes()),
    );

    let mut signer = mac();
    signer.update(signing_input.as_bytes());
    let tag = signer.finalize().into_bytes();

    Ok(format!("{}.{}", signing_input, base64_encoder.encode(tag)))
}
/// Given JWT string created by [create_jwt_for_user()], check if the
/// JWT is valid and return an [AuthenticatedUser] if it is.
#[tracing::instrument]
pub async fn authenticate_user_with_jwt<D:Database>(db: &D, jwt: &str) -> Result<ParsedJwt> {
if let [header, payload, signature] = jwt.split('.').collect::<Vec<_>>().as_slice() {
let mut mac = mac();
mac.update(format!("{}.{}", header, payload).as_bytes());
if mac.verify_slice(signature.as_bytes()).is_err() {
Ok(ParsedJwt::InvalidSignature)
} else {
let header_json = String::from_utf8(base64_encoder.decode(header)?)?;
let header: Header = serde_json::from_str(&header_json)?;
if header.algorithm != "HS256" || header.token_type != "JWT" {
warn!("JWT does not have expected algorithm or type.");
Err(Error::BadJwt)
} else {
let payload: Payload =
serde_json::from_str(&String::from_utf8(base64_encoder.decode(payload)?)?)?;
Ok(dbg!(
if let Some(user) = db.get_user_with_id(payload.user_id).await? {
if payload.expiry < Utc::now() {
ParsedJwt::Valid(AuthenticatedUser(user))
} else {
ParsedJwt::Expired(user)
}
} else {
ParsedJwt::UserNotFound
},
))
}
}
} else {
warn!("Invalid JWT");
Err(Error::BadJwt)
}
}