use { base64::engine::{general_purpose::STANDARD as base64_encoder, Engine as _}, chrono::{DateTime, Duration, Utc}, hmac::{Hmac, Mac}, serde::{Deserialize, Serialize}, sha2::Sha256, tracing::{error, warn}, }; use { super::AuthenticatedUser, crate::config::get_config, crate::db::{Database, User, UserId}, }; const COOKIE_EXPIRY_TIME: Duration = Duration::weeks(1); #[derive(Debug)] pub enum Error { BadJwt, #[allow(clippy::enum_variant_names)] DatabaseError, #[allow(clippy::enum_variant_names)] HmacError, } impl std::fmt::Display for Error { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!( f, "{}", match self { Error::BadJwt => "BadJwt", Error::DatabaseError => "DatabaseError", Error::HmacError => "HmacError", } ) } } impl std::error::Error for Error {} impl From for Error { fn from(value: serde_json::Error) -> Self { warn!(details = value.to_string(), "JWT contained invalid JSON"); Error::BadJwt } } impl From for Error { fn from(value: base64::DecodeError) -> Self { warn!(details = value.to_string(), "JWT contained invalid BASE64"); Error::BadJwt } } impl From for Error { fn from(value: std::string::FromUtf8Error) -> Self { warn!(details = value.to_string(), "JWT contained invalid UTF-8"); Error::BadJwt } } impl From for Error { fn from(value: crate::db::Error) -> Self { error!(details = value.to_string(), "Database error"); Error::DatabaseError } } impl From for Error { fn from(_value: digest::MacError) -> Self { error!("Bug in JWT HMAC code."); Error::HmacError } } type Result = std::result::Result; /// Result type for [authenticate_user_with_jwt()]. #[derive(Debug)] pub enum ParsedJwt { /// JWT is a valid JWT, here is the [AuthenticatedUser]. Valid(AuthenticatedUser), /// JWT was valid but is expired Expired(User), /// JWT signature does not match contents InvalidSignature, /// JWT is valid, but the user id is not in the database UserNotFound, } #[derive(Serialize, Deserialize)] struct Header<'a> { #[serde(rename = "alg")] algorithm: &'a str, #[serde(rename = "typ")] token_type: &'a str, } #[derive(Serialize, Deserialize)] struct Payload { #[serde(rename = "sub")] user_id: UserId, #[serde(rename = "exp")] expiry: DateTime, } fn mac() -> Hmac { Hmac::new_from_slice(&get_config().unwrap().hmac_secret) .expect("HMAC can take key of any size.") } /// Given an [AuthenticatedUser], create a JWT for use as a cookie to /// keep that user logged in. #[tracing::instrument] pub fn create_jwt_for_user(user: &AuthenticatedUser) -> Result { let header = base64_encoder.encode( serde_json::to_string(&Header { algorithm: "HS256", token_type: "JWT", })? .as_bytes(), ); let payload = base64_encoder.encode( serde_json::to_string(&Payload { user_id: user.get_id(), expiry: Utc::now() + COOKIE_EXPIRY_TIME, })? .as_bytes(), ); let mut mac = mac(); mac.update(format!("{}.{}", header, payload).as_bytes()); let signature = base64_encoder.encode(mac.finalize().into_bytes()); Ok(format!("{}.{}.{}", header, payload, signature)) } /// Given JWT string created by [create_jwt_for_user()], check if the /// JWT is valid and return an [AuthenticatedUser] if it is. #[tracing::instrument] pub async fn authenticate_user_with_jwt(db: &Database, jwt: &str) -> Result { if let [header, payload, signature] = jwt.split('.').collect::>().as_slice() { let mut mac = mac(); mac.update(format!("{}.{}", header, payload).as_bytes()); if mac.verify_slice(signature.as_bytes()).is_err() { Ok(ParsedJwt::InvalidSignature) } else { let header_json = String::from_utf8(base64_encoder.decode(header)?)?; let header: Header = serde_json::from_str(&header_json)?; if header.algorithm != "HS256" || header.token_type != "JWT" { warn!("JWT does not have expected algorithm or type."); Err(Error::BadJwt) } else { let payload: Payload = serde_json::from_str(&String::from_utf8(base64_encoder.decode(payload)?)?)?; Ok(dbg!( if let Some(user) = db.get_user_with_id(payload.user_id).await? { if payload.expiry < Utc::now() { ParsedJwt::Valid(AuthenticatedUser(user)) } else { ParsedJwt::Expired(user) } } else { ParsedJwt::UserNotFound }, )) } } } else { warn!("Invalid JWT"); Err(Error::BadJwt) } }