diff --git a/src/authentication/jwt.rs b/src/authentication/jwt.rs index 3a47910..4711bf4 100644 --- a/src/authentication/jwt.rs +++ b/src/authentication/jwt.rs @@ -15,7 +15,7 @@ use { const COOKIE_EXPIRY_TIME: Duration = Duration::weeks(1); -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq)] pub enum Error { BadJwt, @@ -80,7 +80,7 @@ impl From for Error { type Result = std::result::Result; /// Result type for [authenticate_user_with_jwt()]. -#[derive(Debug)] +#[derive(Debug, Eq, PartialEq)] pub enum ParsedJwt { /// JWT is a valid JWT, here is the [AuthenticatedUser]. Valid(AuthenticatedUser), @@ -92,7 +92,7 @@ pub enum ParsedJwt { UserNotFound, } -#[derive(Serialize, Deserialize)] +#[derive(Debug, Serialize, Deserialize)] struct Header<'a> { #[serde(rename = "alg")] algorithm: &'a str, @@ -100,7 +100,16 @@ struct Header<'a> { token_type: &'a str, } -#[derive(Serialize, Deserialize)] +impl Header<'static> { + fn new() -> Self { + Header { + algorithm: "HS256", + token_type: "JWT", + } + } +} + +#[derive(Debug, Serialize, Deserialize)] struct Payload { #[serde(rename = "sub")] user_id: UserId, @@ -109,68 +118,287 @@ struct Payload { expiry: DateTime, } -fn mac() -> Hmac { - Hmac::new_from_slice(&get_config().unwrap().hmac_secret) - .expect("HMAC can take key of any size.") +impl Payload { + fn new(user_id: UserId) -> Self { + Self { + user_id, + expiry: Utc::now() + COOKIE_EXPIRY_TIME, + } + } +} + +fn create_mac(header: &str, payload: &str) -> Hmac { + let mut mac = Hmac::new_from_slice(&get_config().unwrap().hmac_secret) + .expect("HMAC can take key of any size."); + mac.update(format!("{}.{}", header, payload).as_bytes()); + mac +} + +fn create_signature(header: &str, payload: &str) -> String { + let mac = create_mac(header, payload); + base64_encoder.encode(mac.finalize().into_bytes()) +} + +fn check_signature(header: &str, payload: &str, signature: &str) -> Result { + let mac = create_mac(header, payload); + Ok(mac + .verify_slice(&base64_encoder.decode(signature.as_bytes())?) + .is_ok()) +} + +fn sign_jwt(header: &str, payload: &str) -> String { + let signature = create_signature(header, payload); + format!("{}.{}.{}", header, payload, signature) +} + +fn dissassemble_jwt(jwt: &str) -> Result<(&str, &str, &str)> { + if let [header, payload, signature] = jwt.split('.').collect::>().as_slice() { + Ok((header, payload, signature)) + } else { + warn!("Invalid JWT"); + Err(Error::BadJwt) + } +} + +fn encode_jwt_component(value: &T) -> Result { + Ok(base64_encoder.encode(serde_json::to_string(value)?.as_bytes())) } /// Given an [AuthenticatedUser], create a JWT for use as a cookie to /// keep that user logged in. #[tracing::instrument] pub fn create_jwt_for_user(user: &AuthenticatedUser) -> Result { - let header = base64_encoder.encode( - serde_json::to_string(&Header { - algorithm: "HS256", - token_type: "JWT", - })? - .as_bytes(), - ); - let payload = base64_encoder.encode( - serde_json::to_string(&Payload { - user_id: user.get_id(), - expiry: Utc::now() + COOKIE_EXPIRY_TIME, - })? - .as_bytes(), - ); - let mut mac = mac(); - mac.update(format!("{}.{}", header, payload).as_bytes()); - let signature = base64_encoder.encode(mac.finalize().into_bytes()); - Ok(format!("{}.{}.{}", header, payload, signature)) + let header = encode_jwt_component(&Header::new())?; + let payload = encode_jwt_component(&Payload::new(user.get_id()))?; + Ok(sign_jwt(&header, &payload)) } /// Given JWT string created by [create_jwt_for_user()], check if the /// JWT is valid and return an [AuthenticatedUser] if it is. #[tracing::instrument] pub async fn authenticate_user_with_jwt(db: &D, jwt: &str) -> Result { - if let [header, payload, signature] = jwt.split('.').collect::>().as_slice() { - let mut mac = mac(); - mac.update(format!("{}.{}", header, payload).as_bytes()); - if mac.verify_slice(signature.as_bytes()).is_err() { - Ok(ParsedJwt::InvalidSignature) - } else { - let header_json = String::from_utf8(base64_encoder.decode(header)?)?; - let header: Header = serde_json::from_str(&header_json)?; - if header.algorithm != "HS256" || header.token_type != "JWT" { - warn!("JWT does not have expected algorithm or type."); - Err(Error::BadJwt) - } else { - let payload: Payload = - serde_json::from_str(&String::from_utf8(base64_encoder.decode(payload)?)?)?; - Ok(dbg!( - if let Some(user) = db.get_user_with_id(payload.user_id).await? { - if payload.expiry < Utc::now() { - ParsedJwt::Valid(AuthenticatedUser(user)) - } else { - ParsedJwt::Expired(user) - } - } else { - ParsedJwt::UserNotFound - }, - )) - } - } + let (header, payload, signature) = dissassemble_jwt(jwt)?; + if !check_signature(header, payload, signature)? { + Ok(ParsedJwt::InvalidSignature) } else { - warn!("Invalid JWT"); - Err(Error::BadJwt) + let header_json = String::from_utf8(base64_encoder.decode(header)?)?; + let header: Header = serde_json::from_str(&header_json)?; + if header.algorithm != "HS256" || header.token_type != "JWT" { + warn!("JWT does not have expected algorithm or type."); + Err(Error::BadJwt) + } else { + let payload: Payload = + serde_json::from_str(&String::from_utf8(base64_encoder.decode(payload)?)?)?; + Ok( + if let Some(user) = db.get_user_with_id(payload.user_id).await? { + if payload.expiry > Utc::now() { + ParsedJwt::Valid(AuthenticatedUser(user)) + } else { + ParsedJwt::Expired(user) + } + } else { + ParsedJwt::UserNotFound + }, + ) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{db, db::fake::FakeDatabase}; + + #[tokio::test] + async fn created_jwt_authenticates() { + let db = FakeDatabase::new_empty(); + let user = AuthenticatedUser( + db.create_user( + "John Smith", + "john.smith@example.com", + &db::PasswordHash("abc123".to_string()), + ) + .await + .unwrap(), + ); + let jwt = create_jwt_for_user(&user).unwrap(); + assert_eq!( + ParsedJwt::Valid(user), + authenticate_user_with_jwt(&db, &jwt).await.unwrap() + ); + } + + #[tokio::test] + async fn modified_jwt_does_not_authenticate() { + let db = FakeDatabase::new_empty(); + let user1 = AuthenticatedUser( + db.create_user( + "John Smith", + "john.smith@example.com", + &db::PasswordHash("abc123".to_string()), + ) + .await + .unwrap(), + ); + let jwt1 = create_jwt_for_user(&user1).unwrap(); + // Create another user with same name, etc. but different user id + let user2 = AuthenticatedUser( + db.create_user( + "John Smith", + "john.smith@example.com", + &db::PasswordHash("abc123".to_string()), + ) + .await + .unwrap(), + ); + let jwt2 = create_jwt_for_user(&user2).unwrap(); + // Attach signature 2 to jwt 2 + let parts1: Vec<_> = jwt1.split('.').collect(); + assert_eq!(3, parts1.len()); + let parts2: Vec<_> = jwt2.split('.').collect(); + assert_eq!(3, parts2.len()); + let parts3 = vec![parts2[0], parts2[1], parts1[2]]; + + //Rebuild all JWTs, so the orignial ones can act as a control + let jwt1 = parts1.join("."); + let jwt2 = parts2.join("."); + let jwt3 = parts3.join("."); + assert_eq!( + ParsedJwt::Valid(user1), + authenticate_user_with_jwt(&db, &jwt1).await.unwrap() + ); + assert_eq!( + ParsedJwt::Valid(user2), + authenticate_user_with_jwt(&db, &jwt2).await.unwrap() + ); + assert_eq!( + ParsedJwt::InvalidSignature, + authenticate_user_with_jwt(&db, &jwt3).await.unwrap() + ); + } + + #[tokio::test] + async fn deleted_user_cannot_authenticate() { + let db = FakeDatabase::new_empty(); + let user = AuthenticatedUser( + db.create_user( + "John Smith", + "john.smith@example.com", + &db::PasswordHash("abc123".to_string()), + ) + .await + .unwrap(), + ); + let jwt = create_jwt_for_user(&user).unwrap(); + assert_eq!( + ParsedJwt::Valid(user.clone()), + authenticate_user_with_jwt(&db, &jwt).await.unwrap() + ); + + db.delete_user(user.get_id()).await.unwrap(); + assert_eq!( + ParsedJwt::UserNotFound, + authenticate_user_with_jwt(&db, &jwt).await.unwrap() + ); + } + + #[tokio::test] + async fn unsigned_jwt_is_rejected() { + let db = FakeDatabase::new_empty(); + let user = AuthenticatedUser( + db.create_user( + "John Smith", + "john.smith@example.com", + &db::PasswordHash("abc123".to_string()), + ) + .await + .unwrap(), + ); + let jwt = format!( + "{},{}", + encode_jwt_component(&Header::new()).unwrap(), + encode_jwt_component(&Payload::new(user.get_id())).unwrap() + ); + assert_eq!( + Err(Error::BadJwt), + authenticate_user_with_jwt(&db, &jwt).await + ); + } + + #[tokio::test] + async fn jwt_with_unrecognized_algorithm_is_rejected() { + let db = FakeDatabase::new_empty(); + let user = AuthenticatedUser( + db.create_user( + "John Smith", + "john.smith@example.com", + &db::PasswordHash("abc123".to_string()), + ) + .await + .unwrap(), + ); + let jwt = sign_jwt( + &encode_jwt_component(&Header { + algorithm: "RS256", + token_type: "JWT", + }) + .unwrap(), + &encode_jwt_component(&Payload::new(user.get_id())).unwrap(), + ); + assert_eq!( + Err(Error::BadJwt), + authenticate_user_with_jwt(&db, &jwt).await + ); + } + + #[tokio::test] + async fn jwt_with_unrecognized_token_type_is_rejected() { + let db = FakeDatabase::new_empty(); + let user = AuthenticatedUser( + db.create_user( + "John Smith", + "john.smith@example.com", + &db::PasswordHash("abc123".to_string()), + ) + .await + .unwrap(), + ); + let jwt = sign_jwt( + &encode_jwt_component(&Header { + algorithm: "HS256", + token_type: "RWT", + }) + .unwrap(), + &encode_jwt_component(&Payload::new(user.get_id())).unwrap(), + ); + assert_eq!( + Err(Error::BadJwt), + authenticate_user_with_jwt(&db, &jwt).await + ); + } + + #[tokio::test] + async fn expired_jwt_is_rejected() { + let db = FakeDatabase::new_empty(); + let user = db + .create_user( + "John Smith", + "john.smith@example.com", + &db::PasswordHash("abc123".to_string()), + ) + .await + .unwrap(); + let jwt = sign_jwt( + &encode_jwt_component(&Header::new()).unwrap(), + &encode_jwt_component(&Payload { + user_id: user.get_id(), + expiry: Utc::now() - COOKIE_EXPIRY_TIME - Duration::seconds(1), + }) + .unwrap(), + ); + assert_eq!( + Ok(ParsedJwt::Expired(user)), + authenticate_user_with_jwt(&db, &jwt).await + ); } } diff --git a/src/authentication/mod.rs b/src/authentication/mod.rs index a078640..7ce9b3d 100644 --- a/src/authentication/mod.rs +++ b/src/authentication/mod.rs @@ -54,7 +54,7 @@ impl From for AuthenticationError { } } -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Eq, PartialEq)] pub struct AuthenticatedUser(db::User); impl Deref for AuthenticatedUser { diff --git a/src/db/fake.rs b/src/db/fake.rs index 6b52ac8..31b807f 100644 --- a/src/db/fake.rs +++ b/src/db/fake.rs @@ -1,7 +1,10 @@ use super::*; use { - std::collections::HashSet, - std::sync::{Arc, Mutex}, + std::collections::{HashMap, HashSet}, + std::{ + hash::Hash, + sync::{Arc, Mutex}, + }, }; #[derive(Debug)] @@ -11,16 +14,63 @@ struct UserRow { password_hash: String, } +trait Id: Copy + Eq + Hash { + fn first() -> Self; + fn increment(&mut self); +} + +impl Id for UserId { + fn first() -> Self { + UserId(0) + } + + fn increment(&mut self) { + self.0 += 1; + } +} + +#[derive(Debug)] +struct Table { + data: HashMap, + next_id: Id, +} + +impl Table +where + I: Id, +{ + fn new() -> Self { + Self { + data: HashMap::new(), + next_id: I::first(), + } + } + fn insert(&mut self, row: R) -> I { + let id = self.next_id; + self.data.insert(id, row); + self.next_id.increment(); + id + } + + fn delete(&mut self, id: I) { + self.data.remove(&id); + } + + fn get(&self, id: &I) -> Option<&R> { + self.data.get(id) + } +} + #[derive(Debug, Clone)] pub struct FakeDatabase { - users: Arc>>, + users: Arc>>, admin_users: Arc>>, } impl FakeDatabase { pub fn new_empty() -> Self { FakeDatabase { - users: Arc::new(Mutex::new(Vec::new())), + users: Arc::new(Mutex::new(Table::new())), admin_users: Arc::new(Mutex::new(HashSet::new())), } } @@ -42,20 +92,26 @@ impl Database for FakeDatabase { password: &PasswordHash, ) -> Result { let mut users = self.users.lock().unwrap(); - users.push(UserRow { + let new_id = users.insert(UserRow { real_name: real_name.to_string(), email: email.to_string(), password_hash: password.to_string(), }); Ok(User { - id: UserId((users.len() - 1) as i32), + id: new_id, real_name: real_name.to_string(), }) } + async fn delete_user(&self, user_id: UserId) -> Result<()> { + let mut users = self.users.lock().unwrap(); + users.delete(user_id); + Ok(()) + } + async fn get_password_for_user(&self, user: &User) -> Result { let users = self.users.lock().unwrap(); - if let Some(UserRow { password_hash, .. }) = users.get(user.id.0 as usize) { + if let Some(UserRow { password_hash, .. }) = users.get(&user.id) { Ok(PasswordHash(password_hash.clone())) } else { Err(Error::Database) @@ -81,12 +137,10 @@ impl Database for FakeDatabase { async fn get_user_with_id(&self, user_id: UserId) -> Result> { let users = self.users.lock().unwrap(); - Ok(users - .get(user_id.0 as usize) - .map(|UserRow { real_name, .. }| User { - id: user_id, - real_name: real_name.clone(), - })) + Ok(users.get(&user_id).map(|UserRow { real_name, .. }| User { + id: user_id, + real_name: real_name.clone(), + })) } } diff --git a/src/db/mod.rs b/src/db/mod.rs index 394e91e..ba71781 100644 --- a/src/db/mod.rs +++ b/src/db/mod.rs @@ -98,6 +98,7 @@ pub trait Database: std::fmt::Debug + Clone + Send + Sync + 'static { email: &str, password: &PasswordHash, ) -> impl Future> + Send; + fn delete_user(&self, user: UserId) -> impl Future> + Send; fn get_password_for_user( &self, user: &User, @@ -123,10 +124,10 @@ pub struct PostgresDatabase { connection_pool: Pool, } -#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)] +#[derive(Clone, Copy, Debug, Deserialize, Eq, Hash, PartialEq, Serialize)] pub struct UserId(i32); -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Eq)] pub struct User { id: UserId, pub real_name: String, @@ -206,6 +207,11 @@ impl Database for PostgresDatabase { }) } + #[tracing::instrument] + async fn delete_user(&self, user: UserId) -> Result<()> { + todo!() + } + #[tracing::instrument] async fn get_password_for_user(&self, user: &User) -> Result { let client = self.get_client().await?;