177 lines
5.1 KiB
Rust
177 lines
5.1 KiB
Rust
use {
|
|
base64::engine::{general_purpose::STANDARD as base64_encoder, Engine as _},
|
|
chrono::{DateTime, Duration, Utc},
|
|
hmac::{Hmac, Mac},
|
|
serde::{Deserialize, Serialize},
|
|
sha2::Sha256,
|
|
tracing::{error, warn},
|
|
};
|
|
|
|
use {
|
|
super::AuthenticatedUser,
|
|
crate::config::get_config,
|
|
crate::db::{Database, User, UserId},
|
|
};
|
|
|
|
/// Lifetime of a freshly issued login JWT: seven days (one week) from issue time.
const COOKIE_EXPIRY_TIME: Duration = Duration::days(7);
|
|
|
|
/// Failures that can occur while creating or validating a login JWT.
#[derive(Debug)]
// `DatabaseError` and `HmacError` end with the enum's own name, which trips
// clippy's `enum_variant_names`; the names are public API (and appear in the
// `Display` output), so the lint is silenced instead of renaming them.
#[allow(clippy::enum_variant_names)]
pub enum Error {
    /// The token is malformed: wrong number of `.`-separated segments,
    /// undecodable base64 / UTF-8 / JSON, or an unexpected header.
    BadJwt,
    /// A database operation failed while resolving the token's user.
    DatabaseError,
    /// The HMAC primitive reported an error.
    HmacError,
}
|
|
|
|
impl std::fmt::Display for Error {
|
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
|
write!(
|
|
f,
|
|
"{}",
|
|
match self {
|
|
Error::BadJwt => "BadJwt",
|
|
Error::DatabaseError => "DatabaseError",
|
|
Error::HmacError => "HmacError",
|
|
}
|
|
)
|
|
}
|
|
}
|
|
|
|
impl std::error::Error for Error {}
|
|
|
|
impl From<serde_json::Error> for Error {
|
|
fn from(value: serde_json::Error) -> Self {
|
|
warn!(details = value.to_string(), "JWT contained invalid JSON");
|
|
Error::BadJwt
|
|
}
|
|
}
|
|
|
|
impl From<base64::DecodeError> for Error {
|
|
fn from(value: base64::DecodeError) -> Self {
|
|
warn!(details = value.to_string(), "JWT contained invalid BASE64");
|
|
Error::BadJwt
|
|
}
|
|
}
|
|
|
|
impl From<std::string::FromUtf8Error> for Error {
|
|
fn from(value: std::string::FromUtf8Error) -> Self {
|
|
warn!(details = value.to_string(), "JWT contained invalid UTF-8");
|
|
Error::BadJwt
|
|
}
|
|
}
|
|
|
|
impl From<crate::db::Error> for Error {
|
|
fn from(value: crate::db::Error) -> Self {
|
|
error!(details = value.to_string(), "Database error");
|
|
Error::DatabaseError
|
|
}
|
|
}
|
|
|
|
impl From<digest::MacError> for Error {
|
|
fn from(_value: digest::MacError) -> Self {
|
|
error!("Bug in JWT HMAC code.");
|
|
Error::HmacError
|
|
}
|
|
}
|
|
|
|
type Result<T> = std::result::Result<T, Error>;
|
|
|
|
/// Result type for [authenticate_user_with_jwt()].
|
|
#[derive(Debug)]
|
|
pub enum ParsedJwt {
|
|
/// JWT is a valid JWT, here is the [AuthenticatedUser].
|
|
Valid(AuthenticatedUser),
|
|
/// JWT was valid but is expired
|
|
Expired(User),
|
|
/// JWT signature does not match contents
|
|
InvalidSignature,
|
|
/// JWT is valid, but the user id is not in the database
|
|
UserNotFound,
|
|
}
|
|
|
|
#[derive(Serialize, Deserialize)]
|
|
struct Header<'a> {
|
|
#[serde(rename = "alg")]
|
|
algorithm: &'a str,
|
|
#[serde(rename = "typ")]
|
|
token_type: &'a str,
|
|
}
|
|
|
|
#[derive(Serialize, Deserialize)]
|
|
struct Payload {
|
|
#[serde(rename = "sub")]
|
|
user_id: UserId,
|
|
|
|
#[serde(rename = "exp")]
|
|
expiry: DateTime<Utc>,
|
|
}
|
|
|
|
fn mac() -> Hmac<Sha256> {
|
|
Hmac::new_from_slice(&get_config().unwrap().hmac_secret)
|
|
.expect("HMAC can take key of any size.")
|
|
}
|
|
|
|
/// Given an [AuthenticatedUser], create a JWT for use as a cookie to
|
|
/// keep that user logged in.
|
|
#[tracing::instrument]
|
|
pub fn create_jwt_for_user(user: &AuthenticatedUser) -> Result<String> {
|
|
let header = base64_encoder.encode(
|
|
serde_json::to_string(&Header {
|
|
algorithm: "HS256",
|
|
token_type: "JWT",
|
|
})?
|
|
.as_bytes(),
|
|
);
|
|
let payload = base64_encoder.encode(
|
|
serde_json::to_string(&Payload {
|
|
user_id: user.get_id(),
|
|
expiry: Utc::now() + COOKIE_EXPIRY_TIME,
|
|
})?
|
|
.as_bytes(),
|
|
);
|
|
let mut mac = mac();
|
|
mac.update(format!("{}.{}", header, payload).as_bytes());
|
|
let signature = base64_encoder.encode(mac.finalize().into_bytes());
|
|
Ok(format!("{}.{}.{}", header, payload, signature))
|
|
}
|
|
|
|
/// Given JWT string created by [create_jwt_for_user()], check if the
|
|
/// JWT is valid and return an [AuthenticatedUser] if it is.
|
|
#[tracing::instrument]
|
|
pub async fn authenticate_user_with_jwt<D:Database>(db: &D, jwt: &str) -> Result<ParsedJwt> {
|
|
if let [header, payload, signature] = jwt.split('.').collect::<Vec<_>>().as_slice() {
|
|
let mut mac = mac();
|
|
mac.update(format!("{}.{}", header, payload).as_bytes());
|
|
if mac.verify_slice(signature.as_bytes()).is_err() {
|
|
Ok(ParsedJwt::InvalidSignature)
|
|
} else {
|
|
let header_json = String::from_utf8(base64_encoder.decode(header)?)?;
|
|
let header: Header = serde_json::from_str(&header_json)?;
|
|
if header.algorithm != "HS256" || header.token_type != "JWT" {
|
|
warn!("JWT does not have expected algorithm or type.");
|
|
Err(Error::BadJwt)
|
|
} else {
|
|
let payload: Payload =
|
|
serde_json::from_str(&String::from_utf8(base64_encoder.decode(payload)?)?)?;
|
|
Ok(dbg!(
|
|
if let Some(user) = db.get_user_with_id(payload.user_id).await? {
|
|
if payload.expiry < Utc::now() {
|
|
ParsedJwt::Valid(AuthenticatedUser(user))
|
|
} else {
|
|
ParsedJwt::Expired(user)
|
|
}
|
|
} else {
|
|
ParsedJwt::UserNotFound
|
|
},
|
|
))
|
|
}
|
|
}
|
|
} else {
|
|
warn!("Invalid JWT");
|
|
Err(Error::BadJwt)
|
|
}
|
|
}
|