Refactor jwt.rs, add tests, fix bugs
Includes small changed to db module.
This commit is contained in:
parent
f9ac8b1e29
commit
10551a9bc9
|
|
@ -15,7 +15,7 @@ use {
|
|||
|
||||
const COOKIE_EXPIRY_TIME: Duration = Duration::weeks(1);
|
||||
|
||||
#[derive(Debug)]
|
||||
#[derive(Debug, PartialEq, Eq)]
|
||||
pub enum Error {
|
||||
BadJwt,
|
||||
|
||||
|
|
@ -80,7 +80,7 @@ impl From<digest::MacError> for Error {
|
|||
type Result<T> = std::result::Result<T, Error>;
|
||||
|
||||
/// Result type for [authenticate_user_with_jwt()].
|
||||
#[derive(Debug)]
|
||||
#[derive(Debug, Eq, PartialEq)]
|
||||
pub enum ParsedJwt {
|
||||
/// JWT is a valid JWT, here is the [AuthenticatedUser].
|
||||
Valid(AuthenticatedUser),
|
||||
|
|
@ -92,7 +92,7 @@ pub enum ParsedJwt {
|
|||
UserNotFound,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct Header<'a> {
|
||||
#[serde(rename = "alg")]
|
||||
algorithm: &'a str,
|
||||
|
|
@ -100,7 +100,16 @@ struct Header<'a> {
|
|||
token_type: &'a str,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
impl Header<'static> {
|
||||
fn new() -> Self {
|
||||
Header {
|
||||
algorithm: "HS256",
|
||||
token_type: "JWT",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct Payload {
|
||||
#[serde(rename = "sub")]
|
||||
user_id: UserId,
|
||||
|
|
@ -109,68 +118,287 @@ struct Payload {
|
|||
expiry: DateTime<Utc>,
|
||||
}
|
||||
|
||||
fn mac() -> Hmac<Sha256> {
|
||||
Hmac::new_from_slice(&get_config().unwrap().hmac_secret)
|
||||
.expect("HMAC can take key of any size.")
|
||||
impl Payload {
|
||||
fn new(user_id: UserId) -> Self {
|
||||
Self {
|
||||
user_id,
|
||||
expiry: Utc::now() + COOKIE_EXPIRY_TIME,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn create_mac(header: &str, payload: &str) -> Hmac<Sha256> {
|
||||
let mut mac = Hmac::new_from_slice(&get_config().unwrap().hmac_secret)
|
||||
.expect("HMAC can take key of any size.");
|
||||
mac.update(format!("{}.{}", header, payload).as_bytes());
|
||||
mac
|
||||
}
|
||||
|
||||
fn create_signature(header: &str, payload: &str) -> String {
|
||||
let mac = create_mac(header, payload);
|
||||
base64_encoder.encode(mac.finalize().into_bytes())
|
||||
}
|
||||
|
||||
fn check_signature(header: &str, payload: &str, signature: &str) -> Result<bool> {
|
||||
let mac = create_mac(header, payload);
|
||||
Ok(mac
|
||||
.verify_slice(&base64_encoder.decode(signature.as_bytes())?)
|
||||
.is_ok())
|
||||
}
|
||||
|
||||
fn sign_jwt(header: &str, payload: &str) -> String {
|
||||
let signature = create_signature(header, payload);
|
||||
format!("{}.{}.{}", header, payload, signature)
|
||||
}
|
||||
|
||||
fn dissassemble_jwt(jwt: &str) -> Result<(&str, &str, &str)> {
|
||||
if let [header, payload, signature] = jwt.split('.').collect::<Vec<_>>().as_slice() {
|
||||
Ok((header, payload, signature))
|
||||
} else {
|
||||
warn!("Invalid JWT");
|
||||
Err(Error::BadJwt)
|
||||
}
|
||||
}
|
||||
|
||||
fn encode_jwt_component<T: Serialize>(value: &T) -> Result<String> {
|
||||
Ok(base64_encoder.encode(serde_json::to_string(value)?.as_bytes()))
|
||||
}
|
||||
|
||||
/// Given an [AuthenticatedUser], create a JWT for use as a cookie to
|
||||
/// keep that user logged in.
|
||||
#[tracing::instrument]
|
||||
pub fn create_jwt_for_user(user: &AuthenticatedUser) -> Result<String> {
|
||||
let header = base64_encoder.encode(
|
||||
serde_json::to_string(&Header {
|
||||
algorithm: "HS256",
|
||||
token_type: "JWT",
|
||||
})?
|
||||
.as_bytes(),
|
||||
);
|
||||
let payload = base64_encoder.encode(
|
||||
serde_json::to_string(&Payload {
|
||||
user_id: user.get_id(),
|
||||
expiry: Utc::now() + COOKIE_EXPIRY_TIME,
|
||||
})?
|
||||
.as_bytes(),
|
||||
);
|
||||
let mut mac = mac();
|
||||
mac.update(format!("{}.{}", header, payload).as_bytes());
|
||||
let signature = base64_encoder.encode(mac.finalize().into_bytes());
|
||||
Ok(format!("{}.{}.{}", header, payload, signature))
|
||||
let header = encode_jwt_component(&Header::new())?;
|
||||
let payload = encode_jwt_component(&Payload::new(user.get_id()))?;
|
||||
Ok(sign_jwt(&header, &payload))
|
||||
}
|
||||
|
||||
/// Given JWT string created by [create_jwt_for_user()], check if the
|
||||
/// JWT is valid and return an [AuthenticatedUser] if it is.
|
||||
#[tracing::instrument]
|
||||
pub async fn authenticate_user_with_jwt<D: Database>(db: &D, jwt: &str) -> Result<ParsedJwt> {
|
||||
if let [header, payload, signature] = jwt.split('.').collect::<Vec<_>>().as_slice() {
|
||||
let mut mac = mac();
|
||||
mac.update(format!("{}.{}", header, payload).as_bytes());
|
||||
if mac.verify_slice(signature.as_bytes()).is_err() {
|
||||
Ok(ParsedJwt::InvalidSignature)
|
||||
} else {
|
||||
let header_json = String::from_utf8(base64_encoder.decode(header)?)?;
|
||||
let header: Header = serde_json::from_str(&header_json)?;
|
||||
if header.algorithm != "HS256" || header.token_type != "JWT" {
|
||||
warn!("JWT does not have expected algorithm or type.");
|
||||
Err(Error::BadJwt)
|
||||
} else {
|
||||
let payload: Payload =
|
||||
serde_json::from_str(&String::from_utf8(base64_encoder.decode(payload)?)?)?;
|
||||
Ok(dbg!(
|
||||
if let Some(user) = db.get_user_with_id(payload.user_id).await? {
|
||||
if payload.expiry < Utc::now() {
|
||||
ParsedJwt::Valid(AuthenticatedUser(user))
|
||||
} else {
|
||||
ParsedJwt::Expired(user)
|
||||
}
|
||||
} else {
|
||||
ParsedJwt::UserNotFound
|
||||
},
|
||||
))
|
||||
}
|
||||
}
|
||||
let (header, payload, signature) = dissassemble_jwt(jwt)?;
|
||||
if !check_signature(header, payload, signature)? {
|
||||
Ok(ParsedJwt::InvalidSignature)
|
||||
} else {
|
||||
warn!("Invalid JWT");
|
||||
Err(Error::BadJwt)
|
||||
let header_json = String::from_utf8(base64_encoder.decode(header)?)?;
|
||||
let header: Header = serde_json::from_str(&header_json)?;
|
||||
if header.algorithm != "HS256" || header.token_type != "JWT" {
|
||||
warn!("JWT does not have expected algorithm or type.");
|
||||
Err(Error::BadJwt)
|
||||
} else {
|
||||
let payload: Payload =
|
||||
serde_json::from_str(&String::from_utf8(base64_encoder.decode(payload)?)?)?;
|
||||
Ok(
|
||||
if let Some(user) = db.get_user_with_id(payload.user_id).await? {
|
||||
if payload.expiry > Utc::now() {
|
||||
ParsedJwt::Valid(AuthenticatedUser(user))
|
||||
} else {
|
||||
ParsedJwt::Expired(user)
|
||||
}
|
||||
} else {
|
||||
ParsedJwt::UserNotFound
|
||||
},
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::{db, db::fake::FakeDatabase};
|
||||
|
||||
#[tokio::test]
|
||||
async fn created_jwt_authenticates() {
|
||||
let db = FakeDatabase::new_empty();
|
||||
let user = AuthenticatedUser(
|
||||
db.create_user(
|
||||
"John Smith",
|
||||
"john.smith@example.com",
|
||||
&db::PasswordHash("abc123".to_string()),
|
||||
)
|
||||
.await
|
||||
.unwrap(),
|
||||
);
|
||||
let jwt = create_jwt_for_user(&user).unwrap();
|
||||
assert_eq!(
|
||||
ParsedJwt::Valid(user),
|
||||
authenticate_user_with_jwt(&db, &jwt).await.unwrap()
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn modified_jwt_does_not_authenticate() {
|
||||
let db = FakeDatabase::new_empty();
|
||||
let user1 = AuthenticatedUser(
|
||||
db.create_user(
|
||||
"John Smith",
|
||||
"john.smith@example.com",
|
||||
&db::PasswordHash("abc123".to_string()),
|
||||
)
|
||||
.await
|
||||
.unwrap(),
|
||||
);
|
||||
let jwt1 = create_jwt_for_user(&user1).unwrap();
|
||||
// Create another user with same name, etc. but different user id
|
||||
let user2 = AuthenticatedUser(
|
||||
db.create_user(
|
||||
"John Smith",
|
||||
"john.smith@example.com",
|
||||
&db::PasswordHash("abc123".to_string()),
|
||||
)
|
||||
.await
|
||||
.unwrap(),
|
||||
);
|
||||
let jwt2 = create_jwt_for_user(&user2).unwrap();
|
||||
// Attach signature 2 to jwt 2
|
||||
let parts1: Vec<_> = jwt1.split('.').collect();
|
||||
assert_eq!(3, parts1.len());
|
||||
let parts2: Vec<_> = jwt2.split('.').collect();
|
||||
assert_eq!(3, parts2.len());
|
||||
let parts3 = vec![parts2[0], parts2[1], parts1[2]];
|
||||
|
||||
//Rebuild all JWTs, so the orignial ones can act as a control
|
||||
let jwt1 = parts1.join(".");
|
||||
let jwt2 = parts2.join(".");
|
||||
let jwt3 = parts3.join(".");
|
||||
assert_eq!(
|
||||
ParsedJwt::Valid(user1),
|
||||
authenticate_user_with_jwt(&db, &jwt1).await.unwrap()
|
||||
);
|
||||
assert_eq!(
|
||||
ParsedJwt::Valid(user2),
|
||||
authenticate_user_with_jwt(&db, &jwt2).await.unwrap()
|
||||
);
|
||||
assert_eq!(
|
||||
ParsedJwt::InvalidSignature,
|
||||
authenticate_user_with_jwt(&db, &jwt3).await.unwrap()
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn deleted_user_cannot_authenticate() {
|
||||
let db = FakeDatabase::new_empty();
|
||||
let user = AuthenticatedUser(
|
||||
db.create_user(
|
||||
"John Smith",
|
||||
"john.smith@example.com",
|
||||
&db::PasswordHash("abc123".to_string()),
|
||||
)
|
||||
.await
|
||||
.unwrap(),
|
||||
);
|
||||
let jwt = create_jwt_for_user(&user).unwrap();
|
||||
assert_eq!(
|
||||
ParsedJwt::Valid(user.clone()),
|
||||
authenticate_user_with_jwt(&db, &jwt).await.unwrap()
|
||||
);
|
||||
|
||||
db.delete_user(user.get_id()).await.unwrap();
|
||||
assert_eq!(
|
||||
ParsedJwt::UserNotFound,
|
||||
authenticate_user_with_jwt(&db, &jwt).await.unwrap()
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn unsigned_jwt_is_rejected() {
|
||||
let db = FakeDatabase::new_empty();
|
||||
let user = AuthenticatedUser(
|
||||
db.create_user(
|
||||
"John Smith",
|
||||
"john.smith@example.com",
|
||||
&db::PasswordHash("abc123".to_string()),
|
||||
)
|
||||
.await
|
||||
.unwrap(),
|
||||
);
|
||||
let jwt = format!(
|
||||
"{},{}",
|
||||
encode_jwt_component(&Header::new()).unwrap(),
|
||||
encode_jwt_component(&Payload::new(user.get_id())).unwrap()
|
||||
);
|
||||
assert_eq!(
|
||||
Err(Error::BadJwt),
|
||||
authenticate_user_with_jwt(&db, &jwt).await
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn jwt_with_unrecognized_algorithm_is_rejected() {
|
||||
let db = FakeDatabase::new_empty();
|
||||
let user = AuthenticatedUser(
|
||||
db.create_user(
|
||||
"John Smith",
|
||||
"john.smith@example.com",
|
||||
&db::PasswordHash("abc123".to_string()),
|
||||
)
|
||||
.await
|
||||
.unwrap(),
|
||||
);
|
||||
let jwt = sign_jwt(
|
||||
&encode_jwt_component(&Header {
|
||||
algorithm: "RS256",
|
||||
token_type: "JWT",
|
||||
})
|
||||
.unwrap(),
|
||||
&encode_jwt_component(&Payload::new(user.get_id())).unwrap(),
|
||||
);
|
||||
assert_eq!(
|
||||
Err(Error::BadJwt),
|
||||
authenticate_user_with_jwt(&db, &jwt).await
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn jwt_with_unrecognized_token_type_is_rejected() {
|
||||
let db = FakeDatabase::new_empty();
|
||||
let user = AuthenticatedUser(
|
||||
db.create_user(
|
||||
"John Smith",
|
||||
"john.smith@example.com",
|
||||
&db::PasswordHash("abc123".to_string()),
|
||||
)
|
||||
.await
|
||||
.unwrap(),
|
||||
);
|
||||
let jwt = sign_jwt(
|
||||
&encode_jwt_component(&Header {
|
||||
algorithm: "HS256",
|
||||
token_type: "RWT",
|
||||
})
|
||||
.unwrap(),
|
||||
&encode_jwt_component(&Payload::new(user.get_id())).unwrap(),
|
||||
);
|
||||
assert_eq!(
|
||||
Err(Error::BadJwt),
|
||||
authenticate_user_with_jwt(&db, &jwt).await
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn expired_jwt_is_rejected() {
|
||||
let db = FakeDatabase::new_empty();
|
||||
let user = db
|
||||
.create_user(
|
||||
"John Smith",
|
||||
"john.smith@example.com",
|
||||
&db::PasswordHash("abc123".to_string()),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
let jwt = sign_jwt(
|
||||
&encode_jwt_component(&Header::new()).unwrap(),
|
||||
&encode_jwt_component(&Payload {
|
||||
user_id: user.get_id(),
|
||||
expiry: Utc::now() - COOKIE_EXPIRY_TIME - Duration::seconds(1),
|
||||
})
|
||||
.unwrap(),
|
||||
);
|
||||
assert_eq!(
|
||||
Ok(ParsedJwt::Expired(user)),
|
||||
authenticate_user_with_jwt(&db, &jwt).await
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -54,7 +54,7 @@ impl From<argon2::password_hash::Error> for AuthenticationError {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
#[derive(Debug, Clone, Eq, PartialEq)]
|
||||
pub struct AuthenticatedUser(db::User);
|
||||
|
||||
impl Deref for AuthenticatedUser {
|
||||
|
|
|
|||
|
|
@ -1,7 +1,10 @@
|
|||
use super::*;
|
||||
use {
|
||||
std::collections::HashSet,
|
||||
std::sync::{Arc, Mutex},
|
||||
std::collections::{HashMap, HashSet},
|
||||
std::{
|
||||
hash::Hash,
|
||||
sync::{Arc, Mutex},
|
||||
},
|
||||
};
|
||||
|
||||
#[derive(Debug)]
|
||||
|
|
@ -11,16 +14,63 @@ struct UserRow {
|
|||
password_hash: String,
|
||||
}
|
||||
|
||||
trait Id: Copy + Eq + Hash {
|
||||
fn first() -> Self;
|
||||
fn increment(&mut self);
|
||||
}
|
||||
|
||||
impl Id for UserId {
|
||||
fn first() -> Self {
|
||||
UserId(0)
|
||||
}
|
||||
|
||||
fn increment(&mut self) {
|
||||
self.0 += 1;
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct Table<Id, Row> {
|
||||
data: HashMap<Id, Row>,
|
||||
next_id: Id,
|
||||
}
|
||||
|
||||
impl<I, R> Table<I, R>
|
||||
where
|
||||
I: Id,
|
||||
{
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
data: HashMap::new(),
|
||||
next_id: I::first(),
|
||||
}
|
||||
}
|
||||
fn insert(&mut self, row: R) -> I {
|
||||
let id = self.next_id;
|
||||
self.data.insert(id, row);
|
||||
self.next_id.increment();
|
||||
id
|
||||
}
|
||||
|
||||
fn delete(&mut self, id: I) {
|
||||
self.data.remove(&id);
|
||||
}
|
||||
|
||||
fn get(&self, id: &I) -> Option<&R> {
|
||||
self.data.get(id)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct FakeDatabase {
|
||||
users: Arc<Mutex<Vec<UserRow>>>,
|
||||
users: Arc<Mutex<Table<UserId, UserRow>>>,
|
||||
admin_users: Arc<Mutex<std::collections::HashSet<usize>>>,
|
||||
}
|
||||
|
||||
impl FakeDatabase {
|
||||
pub fn new_empty() -> Self {
|
||||
FakeDatabase {
|
||||
users: Arc::new(Mutex::new(Vec::new())),
|
||||
users: Arc::new(Mutex::new(Table::new())),
|
||||
admin_users: Arc::new(Mutex::new(HashSet::new())),
|
||||
}
|
||||
}
|
||||
|
|
@ -42,20 +92,26 @@ impl Database for FakeDatabase {
|
|||
password: &PasswordHash,
|
||||
) -> Result<User> {
|
||||
let mut users = self.users.lock().unwrap();
|
||||
users.push(UserRow {
|
||||
let new_id = users.insert(UserRow {
|
||||
real_name: real_name.to_string(),
|
||||
email: email.to_string(),
|
||||
password_hash: password.to_string(),
|
||||
});
|
||||
Ok(User {
|
||||
id: UserId((users.len() - 1) as i32),
|
||||
id: new_id,
|
||||
real_name: real_name.to_string(),
|
||||
})
|
||||
}
|
||||
|
||||
async fn delete_user(&self, user_id: UserId) -> Result<()> {
|
||||
let mut users = self.users.lock().unwrap();
|
||||
users.delete(user_id);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn get_password_for_user(&self, user: &User) -> Result<PasswordHash> {
|
||||
let users = self.users.lock().unwrap();
|
||||
if let Some(UserRow { password_hash, .. }) = users.get(user.id.0 as usize) {
|
||||
if let Some(UserRow { password_hash, .. }) = users.get(&user.id) {
|
||||
Ok(PasswordHash(password_hash.clone()))
|
||||
} else {
|
||||
Err(Error::Database)
|
||||
|
|
@ -81,12 +137,10 @@ impl Database for FakeDatabase {
|
|||
|
||||
async fn get_user_with_id(&self, user_id: UserId) -> Result<Option<User>> {
|
||||
let users = self.users.lock().unwrap();
|
||||
Ok(users
|
||||
.get(user_id.0 as usize)
|
||||
.map(|UserRow { real_name, .. }| User {
|
||||
id: user_id,
|
||||
real_name: real_name.clone(),
|
||||
}))
|
||||
Ok(users.get(&user_id).map(|UserRow { real_name, .. }| User {
|
||||
id: user_id,
|
||||
real_name: real_name.clone(),
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -98,6 +98,7 @@ pub trait Database: std::fmt::Debug + Clone + Send + Sync + 'static {
|
|||
email: &str,
|
||||
password: &PasswordHash,
|
||||
) -> impl Future<Output = Result<User>> + Send;
|
||||
fn delete_user(&self, user: UserId) -> impl Future<Output = Result<()>> + Send;
|
||||
fn get_password_for_user(
|
||||
&self,
|
||||
user: &User,
|
||||
|
|
@ -123,10 +124,10 @@ pub struct PostgresDatabase {
|
|||
connection_pool: Pool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
|
||||
#[derive(Clone, Copy, Debug, Deserialize, Eq, Hash, PartialEq, Serialize)]
|
||||
pub struct UserId(i32);
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct User {
|
||||
id: UserId,
|
||||
pub real_name: String,
|
||||
|
|
@ -206,6 +207,11 @@ impl Database for PostgresDatabase {
|
|||
})
|
||||
}
|
||||
|
||||
#[tracing::instrument]
|
||||
async fn delete_user(&self, user: UserId) -> Result<()> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
#[tracing::instrument]
|
||||
async fn get_password_for_user(&self, user: &User) -> Result<PasswordHash> {
|
||||
let client = self.get_client().await?;
|
||||
|
|
|
|||
Loading…
Reference in New Issue