Refactor jwt.rs, add tests, fix bugs
Includes small changed to db module.
This commit is contained in:
parent
f9ac8b1e29
commit
10551a9bc9
|
|
@ -15,7 +15,7 @@ use {
|
||||||
|
|
||||||
const COOKIE_EXPIRY_TIME: Duration = Duration::weeks(1);
|
const COOKIE_EXPIRY_TIME: Duration = Duration::weeks(1);
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug, PartialEq, Eq)]
|
||||||
pub enum Error {
|
pub enum Error {
|
||||||
BadJwt,
|
BadJwt,
|
||||||
|
|
||||||
|
|
@ -80,7 +80,7 @@ impl From<digest::MacError> for Error {
|
||||||
type Result<T> = std::result::Result<T, Error>;
|
type Result<T> = std::result::Result<T, Error>;
|
||||||
|
|
||||||
/// Result type for [authenticate_user_with_jwt()].
|
/// Result type for [authenticate_user_with_jwt()].
|
||||||
#[derive(Debug)]
|
#[derive(Debug, Eq, PartialEq)]
|
||||||
pub enum ParsedJwt {
|
pub enum ParsedJwt {
|
||||||
/// JWT is a valid JWT, here is the [AuthenticatedUser].
|
/// JWT is a valid JWT, here is the [AuthenticatedUser].
|
||||||
Valid(AuthenticatedUser),
|
Valid(AuthenticatedUser),
|
||||||
|
|
@ -92,7 +92,7 @@ pub enum ParsedJwt {
|
||||||
UserNotFound,
|
UserNotFound,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize)]
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
struct Header<'a> {
|
struct Header<'a> {
|
||||||
#[serde(rename = "alg")]
|
#[serde(rename = "alg")]
|
||||||
algorithm: &'a str,
|
algorithm: &'a str,
|
||||||
|
|
@ -100,7 +100,16 @@ struct Header<'a> {
|
||||||
token_type: &'a str,
|
token_type: &'a str,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize)]
|
impl Header<'static> {
|
||||||
|
fn new() -> Self {
|
||||||
|
Header {
|
||||||
|
algorithm: "HS256",
|
||||||
|
token_type: "JWT",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
struct Payload {
|
struct Payload {
|
||||||
#[serde(rename = "sub")]
|
#[serde(rename = "sub")]
|
||||||
user_id: UserId,
|
user_id: UserId,
|
||||||
|
|
@ -109,68 +118,287 @@ struct Payload {
|
||||||
expiry: DateTime<Utc>,
|
expiry: DateTime<Utc>,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn mac() -> Hmac<Sha256> {
|
impl Payload {
|
||||||
Hmac::new_from_slice(&get_config().unwrap().hmac_secret)
|
fn new(user_id: UserId) -> Self {
|
||||||
.expect("HMAC can take key of any size.")
|
Self {
|
||||||
|
user_id,
|
||||||
|
expiry: Utc::now() + COOKIE_EXPIRY_TIME,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn create_mac(header: &str, payload: &str) -> Hmac<Sha256> {
|
||||||
|
let mut mac = Hmac::new_from_slice(&get_config().unwrap().hmac_secret)
|
||||||
|
.expect("HMAC can take key of any size.");
|
||||||
|
mac.update(format!("{}.{}", header, payload).as_bytes());
|
||||||
|
mac
|
||||||
|
}
|
||||||
|
|
||||||
|
fn create_signature(header: &str, payload: &str) -> String {
|
||||||
|
let mac = create_mac(header, payload);
|
||||||
|
base64_encoder.encode(mac.finalize().into_bytes())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn check_signature(header: &str, payload: &str, signature: &str) -> Result<bool> {
|
||||||
|
let mac = create_mac(header, payload);
|
||||||
|
Ok(mac
|
||||||
|
.verify_slice(&base64_encoder.decode(signature.as_bytes())?)
|
||||||
|
.is_ok())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn sign_jwt(header: &str, payload: &str) -> String {
|
||||||
|
let signature = create_signature(header, payload);
|
||||||
|
format!("{}.{}.{}", header, payload, signature)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn dissassemble_jwt(jwt: &str) -> Result<(&str, &str, &str)> {
|
||||||
|
if let [header, payload, signature] = jwt.split('.').collect::<Vec<_>>().as_slice() {
|
||||||
|
Ok((header, payload, signature))
|
||||||
|
} else {
|
||||||
|
warn!("Invalid JWT");
|
||||||
|
Err(Error::BadJwt)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn encode_jwt_component<T: Serialize>(value: &T) -> Result<String> {
|
||||||
|
Ok(base64_encoder.encode(serde_json::to_string(value)?.as_bytes()))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Given an [AuthenticatedUser], create a JWT for use as a cookie to
|
/// Given an [AuthenticatedUser], create a JWT for use as a cookie to
|
||||||
/// keep that user logged in.
|
/// keep that user logged in.
|
||||||
#[tracing::instrument]
|
#[tracing::instrument]
|
||||||
pub fn create_jwt_for_user(user: &AuthenticatedUser) -> Result<String> {
|
pub fn create_jwt_for_user(user: &AuthenticatedUser) -> Result<String> {
|
||||||
let header = base64_encoder.encode(
|
let header = encode_jwt_component(&Header::new())?;
|
||||||
serde_json::to_string(&Header {
|
let payload = encode_jwt_component(&Payload::new(user.get_id()))?;
|
||||||
algorithm: "HS256",
|
Ok(sign_jwt(&header, &payload))
|
||||||
token_type: "JWT",
|
|
||||||
})?
|
|
||||||
.as_bytes(),
|
|
||||||
);
|
|
||||||
let payload = base64_encoder.encode(
|
|
||||||
serde_json::to_string(&Payload {
|
|
||||||
user_id: user.get_id(),
|
|
||||||
expiry: Utc::now() + COOKIE_EXPIRY_TIME,
|
|
||||||
})?
|
|
||||||
.as_bytes(),
|
|
||||||
);
|
|
||||||
let mut mac = mac();
|
|
||||||
mac.update(format!("{}.{}", header, payload).as_bytes());
|
|
||||||
let signature = base64_encoder.encode(mac.finalize().into_bytes());
|
|
||||||
Ok(format!("{}.{}.{}", header, payload, signature))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Given JWT string created by [create_jwt_for_user()], check if the
|
/// Given JWT string created by [create_jwt_for_user()], check if the
|
||||||
/// JWT is valid and return an [AuthenticatedUser] if it is.
|
/// JWT is valid and return an [AuthenticatedUser] if it is.
|
||||||
#[tracing::instrument]
|
#[tracing::instrument]
|
||||||
pub async fn authenticate_user_with_jwt<D: Database>(db: &D, jwt: &str) -> Result<ParsedJwt> {
|
pub async fn authenticate_user_with_jwt<D: Database>(db: &D, jwt: &str) -> Result<ParsedJwt> {
|
||||||
if let [header, payload, signature] = jwt.split('.').collect::<Vec<_>>().as_slice() {
|
let (header, payload, signature) = dissassemble_jwt(jwt)?;
|
||||||
let mut mac = mac();
|
if !check_signature(header, payload, signature)? {
|
||||||
mac.update(format!("{}.{}", header, payload).as_bytes());
|
Ok(ParsedJwt::InvalidSignature)
|
||||||
if mac.verify_slice(signature.as_bytes()).is_err() {
|
|
||||||
Ok(ParsedJwt::InvalidSignature)
|
|
||||||
} else {
|
|
||||||
let header_json = String::from_utf8(base64_encoder.decode(header)?)?;
|
|
||||||
let header: Header = serde_json::from_str(&header_json)?;
|
|
||||||
if header.algorithm != "HS256" || header.token_type != "JWT" {
|
|
||||||
warn!("JWT does not have expected algorithm or type.");
|
|
||||||
Err(Error::BadJwt)
|
|
||||||
} else {
|
|
||||||
let payload: Payload =
|
|
||||||
serde_json::from_str(&String::from_utf8(base64_encoder.decode(payload)?)?)?;
|
|
||||||
Ok(dbg!(
|
|
||||||
if let Some(user) = db.get_user_with_id(payload.user_id).await? {
|
|
||||||
if payload.expiry < Utc::now() {
|
|
||||||
ParsedJwt::Valid(AuthenticatedUser(user))
|
|
||||||
} else {
|
|
||||||
ParsedJwt::Expired(user)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
ParsedJwt::UserNotFound
|
|
||||||
},
|
|
||||||
))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
warn!("Invalid JWT");
|
let header_json = String::from_utf8(base64_encoder.decode(header)?)?;
|
||||||
Err(Error::BadJwt)
|
let header: Header = serde_json::from_str(&header_json)?;
|
||||||
|
if header.algorithm != "HS256" || header.token_type != "JWT" {
|
||||||
|
warn!("JWT does not have expected algorithm or type.");
|
||||||
|
Err(Error::BadJwt)
|
||||||
|
} else {
|
||||||
|
let payload: Payload =
|
||||||
|
serde_json::from_str(&String::from_utf8(base64_encoder.decode(payload)?)?)?;
|
||||||
|
Ok(
|
||||||
|
if let Some(user) = db.get_user_with_id(payload.user_id).await? {
|
||||||
|
if payload.expiry > Utc::now() {
|
||||||
|
ParsedJwt::Valid(AuthenticatedUser(user))
|
||||||
|
} else {
|
||||||
|
ParsedJwt::Expired(user)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
ParsedJwt::UserNotFound
|
||||||
|
},
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use crate::{db, db::fake::FakeDatabase};
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn created_jwt_authenticates() {
|
||||||
|
let db = FakeDatabase::new_empty();
|
||||||
|
let user = AuthenticatedUser(
|
||||||
|
db.create_user(
|
||||||
|
"John Smith",
|
||||||
|
"john.smith@example.com",
|
||||||
|
&db::PasswordHash("abc123".to_string()),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.unwrap(),
|
||||||
|
);
|
||||||
|
let jwt = create_jwt_for_user(&user).unwrap();
|
||||||
|
assert_eq!(
|
||||||
|
ParsedJwt::Valid(user),
|
||||||
|
authenticate_user_with_jwt(&db, &jwt).await.unwrap()
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn modified_jwt_does_not_authenticate() {
|
||||||
|
let db = FakeDatabase::new_empty();
|
||||||
|
let user1 = AuthenticatedUser(
|
||||||
|
db.create_user(
|
||||||
|
"John Smith",
|
||||||
|
"john.smith@example.com",
|
||||||
|
&db::PasswordHash("abc123".to_string()),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.unwrap(),
|
||||||
|
);
|
||||||
|
let jwt1 = create_jwt_for_user(&user1).unwrap();
|
||||||
|
// Create another user with same name, etc. but different user id
|
||||||
|
let user2 = AuthenticatedUser(
|
||||||
|
db.create_user(
|
||||||
|
"John Smith",
|
||||||
|
"john.smith@example.com",
|
||||||
|
&db::PasswordHash("abc123".to_string()),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.unwrap(),
|
||||||
|
);
|
||||||
|
let jwt2 = create_jwt_for_user(&user2).unwrap();
|
||||||
|
// Attach signature 2 to jwt 2
|
||||||
|
let parts1: Vec<_> = jwt1.split('.').collect();
|
||||||
|
assert_eq!(3, parts1.len());
|
||||||
|
let parts2: Vec<_> = jwt2.split('.').collect();
|
||||||
|
assert_eq!(3, parts2.len());
|
||||||
|
let parts3 = vec![parts2[0], parts2[1], parts1[2]];
|
||||||
|
|
||||||
|
//Rebuild all JWTs, so the orignial ones can act as a control
|
||||||
|
let jwt1 = parts1.join(".");
|
||||||
|
let jwt2 = parts2.join(".");
|
||||||
|
let jwt3 = parts3.join(".");
|
||||||
|
assert_eq!(
|
||||||
|
ParsedJwt::Valid(user1),
|
||||||
|
authenticate_user_with_jwt(&db, &jwt1).await.unwrap()
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
ParsedJwt::Valid(user2),
|
||||||
|
authenticate_user_with_jwt(&db, &jwt2).await.unwrap()
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
ParsedJwt::InvalidSignature,
|
||||||
|
authenticate_user_with_jwt(&db, &jwt3).await.unwrap()
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn deleted_user_cannot_authenticate() {
|
||||||
|
let db = FakeDatabase::new_empty();
|
||||||
|
let user = AuthenticatedUser(
|
||||||
|
db.create_user(
|
||||||
|
"John Smith",
|
||||||
|
"john.smith@example.com",
|
||||||
|
&db::PasswordHash("abc123".to_string()),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.unwrap(),
|
||||||
|
);
|
||||||
|
let jwt = create_jwt_for_user(&user).unwrap();
|
||||||
|
assert_eq!(
|
||||||
|
ParsedJwt::Valid(user.clone()),
|
||||||
|
authenticate_user_with_jwt(&db, &jwt).await.unwrap()
|
||||||
|
);
|
||||||
|
|
||||||
|
db.delete_user(user.get_id()).await.unwrap();
|
||||||
|
assert_eq!(
|
||||||
|
ParsedJwt::UserNotFound,
|
||||||
|
authenticate_user_with_jwt(&db, &jwt).await.unwrap()
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn unsigned_jwt_is_rejected() {
|
||||||
|
let db = FakeDatabase::new_empty();
|
||||||
|
let user = AuthenticatedUser(
|
||||||
|
db.create_user(
|
||||||
|
"John Smith",
|
||||||
|
"john.smith@example.com",
|
||||||
|
&db::PasswordHash("abc123".to_string()),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.unwrap(),
|
||||||
|
);
|
||||||
|
let jwt = format!(
|
||||||
|
"{},{}",
|
||||||
|
encode_jwt_component(&Header::new()).unwrap(),
|
||||||
|
encode_jwt_component(&Payload::new(user.get_id())).unwrap()
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
Err(Error::BadJwt),
|
||||||
|
authenticate_user_with_jwt(&db, &jwt).await
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn jwt_with_unrecognized_algorithm_is_rejected() {
|
||||||
|
let db = FakeDatabase::new_empty();
|
||||||
|
let user = AuthenticatedUser(
|
||||||
|
db.create_user(
|
||||||
|
"John Smith",
|
||||||
|
"john.smith@example.com",
|
||||||
|
&db::PasswordHash("abc123".to_string()),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.unwrap(),
|
||||||
|
);
|
||||||
|
let jwt = sign_jwt(
|
||||||
|
&encode_jwt_component(&Header {
|
||||||
|
algorithm: "RS256",
|
||||||
|
token_type: "JWT",
|
||||||
|
})
|
||||||
|
.unwrap(),
|
||||||
|
&encode_jwt_component(&Payload::new(user.get_id())).unwrap(),
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
Err(Error::BadJwt),
|
||||||
|
authenticate_user_with_jwt(&db, &jwt).await
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn jwt_with_unrecognized_token_type_is_rejected() {
|
||||||
|
let db = FakeDatabase::new_empty();
|
||||||
|
let user = AuthenticatedUser(
|
||||||
|
db.create_user(
|
||||||
|
"John Smith",
|
||||||
|
"john.smith@example.com",
|
||||||
|
&db::PasswordHash("abc123".to_string()),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.unwrap(),
|
||||||
|
);
|
||||||
|
let jwt = sign_jwt(
|
||||||
|
&encode_jwt_component(&Header {
|
||||||
|
algorithm: "HS256",
|
||||||
|
token_type: "RWT",
|
||||||
|
})
|
||||||
|
.unwrap(),
|
||||||
|
&encode_jwt_component(&Payload::new(user.get_id())).unwrap(),
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
Err(Error::BadJwt),
|
||||||
|
authenticate_user_with_jwt(&db, &jwt).await
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn expired_jwt_is_rejected() {
|
||||||
|
let db = FakeDatabase::new_empty();
|
||||||
|
let user = db
|
||||||
|
.create_user(
|
||||||
|
"John Smith",
|
||||||
|
"john.smith@example.com",
|
||||||
|
&db::PasswordHash("abc123".to_string()),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
let jwt = sign_jwt(
|
||||||
|
&encode_jwt_component(&Header::new()).unwrap(),
|
||||||
|
&encode_jwt_component(&Payload {
|
||||||
|
user_id: user.get_id(),
|
||||||
|
expiry: Utc::now() - COOKIE_EXPIRY_TIME - Duration::seconds(1),
|
||||||
|
})
|
||||||
|
.unwrap(),
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
Ok(ParsedJwt::Expired(user)),
|
||||||
|
authenticate_user_with_jwt(&db, &jwt).await
|
||||||
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -54,7 +54,7 @@ impl From<argon2::password_hash::Error> for AuthenticationError {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone, Eq, PartialEq)]
|
||||||
pub struct AuthenticatedUser(db::User);
|
pub struct AuthenticatedUser(db::User);
|
||||||
|
|
||||||
impl Deref for AuthenticatedUser {
|
impl Deref for AuthenticatedUser {
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,10 @@
|
||||||
use super::*;
|
use super::*;
|
||||||
use {
|
use {
|
||||||
std::collections::HashSet,
|
std::collections::{HashMap, HashSet},
|
||||||
std::sync::{Arc, Mutex},
|
std::{
|
||||||
|
hash::Hash,
|
||||||
|
sync::{Arc, Mutex},
|
||||||
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
|
|
@ -11,16 +14,63 @@ struct UserRow {
|
||||||
password_hash: String,
|
password_hash: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
trait Id: Copy + Eq + Hash {
|
||||||
|
fn first() -> Self;
|
||||||
|
fn increment(&mut self);
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Id for UserId {
|
||||||
|
fn first() -> Self {
|
||||||
|
UserId(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn increment(&mut self) {
|
||||||
|
self.0 += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
struct Table<Id, Row> {
|
||||||
|
data: HashMap<Id, Row>,
|
||||||
|
next_id: Id,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<I, R> Table<I, R>
|
||||||
|
where
|
||||||
|
I: Id,
|
||||||
|
{
|
||||||
|
fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
data: HashMap::new(),
|
||||||
|
next_id: I::first(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
fn insert(&mut self, row: R) -> I {
|
||||||
|
let id = self.next_id;
|
||||||
|
self.data.insert(id, row);
|
||||||
|
self.next_id.increment();
|
||||||
|
id
|
||||||
|
}
|
||||||
|
|
||||||
|
fn delete(&mut self, id: I) {
|
||||||
|
self.data.remove(&id);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get(&self, id: &I) -> Option<&R> {
|
||||||
|
self.data.get(id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct FakeDatabase {
|
pub struct FakeDatabase {
|
||||||
users: Arc<Mutex<Vec<UserRow>>>,
|
users: Arc<Mutex<Table<UserId, UserRow>>>,
|
||||||
admin_users: Arc<Mutex<std::collections::HashSet<usize>>>,
|
admin_users: Arc<Mutex<std::collections::HashSet<usize>>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl FakeDatabase {
|
impl FakeDatabase {
|
||||||
pub fn new_empty() -> Self {
|
pub fn new_empty() -> Self {
|
||||||
FakeDatabase {
|
FakeDatabase {
|
||||||
users: Arc::new(Mutex::new(Vec::new())),
|
users: Arc::new(Mutex::new(Table::new())),
|
||||||
admin_users: Arc::new(Mutex::new(HashSet::new())),
|
admin_users: Arc::new(Mutex::new(HashSet::new())),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -42,20 +92,26 @@ impl Database for FakeDatabase {
|
||||||
password: &PasswordHash,
|
password: &PasswordHash,
|
||||||
) -> Result<User> {
|
) -> Result<User> {
|
||||||
let mut users = self.users.lock().unwrap();
|
let mut users = self.users.lock().unwrap();
|
||||||
users.push(UserRow {
|
let new_id = users.insert(UserRow {
|
||||||
real_name: real_name.to_string(),
|
real_name: real_name.to_string(),
|
||||||
email: email.to_string(),
|
email: email.to_string(),
|
||||||
password_hash: password.to_string(),
|
password_hash: password.to_string(),
|
||||||
});
|
});
|
||||||
Ok(User {
|
Ok(User {
|
||||||
id: UserId((users.len() - 1) as i32),
|
id: new_id,
|
||||||
real_name: real_name.to_string(),
|
real_name: real_name.to_string(),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn delete_user(&self, user_id: UserId) -> Result<()> {
|
||||||
|
let mut users = self.users.lock().unwrap();
|
||||||
|
users.delete(user_id);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
async fn get_password_for_user(&self, user: &User) -> Result<PasswordHash> {
|
async fn get_password_for_user(&self, user: &User) -> Result<PasswordHash> {
|
||||||
let users = self.users.lock().unwrap();
|
let users = self.users.lock().unwrap();
|
||||||
if let Some(UserRow { password_hash, .. }) = users.get(user.id.0 as usize) {
|
if let Some(UserRow { password_hash, .. }) = users.get(&user.id) {
|
||||||
Ok(PasswordHash(password_hash.clone()))
|
Ok(PasswordHash(password_hash.clone()))
|
||||||
} else {
|
} else {
|
||||||
Err(Error::Database)
|
Err(Error::Database)
|
||||||
|
|
@ -81,12 +137,10 @@ impl Database for FakeDatabase {
|
||||||
|
|
||||||
async fn get_user_with_id(&self, user_id: UserId) -> Result<Option<User>> {
|
async fn get_user_with_id(&self, user_id: UserId) -> Result<Option<User>> {
|
||||||
let users = self.users.lock().unwrap();
|
let users = self.users.lock().unwrap();
|
||||||
Ok(users
|
Ok(users.get(&user_id).map(|UserRow { real_name, .. }| User {
|
||||||
.get(user_id.0 as usize)
|
id: user_id,
|
||||||
.map(|UserRow { real_name, .. }| User {
|
real_name: real_name.clone(),
|
||||||
id: user_id,
|
}))
|
||||||
real_name: real_name.clone(),
|
|
||||||
}))
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -98,6 +98,7 @@ pub trait Database: std::fmt::Debug + Clone + Send + Sync + 'static {
|
||||||
email: &str,
|
email: &str,
|
||||||
password: &PasswordHash,
|
password: &PasswordHash,
|
||||||
) -> impl Future<Output = Result<User>> + Send;
|
) -> impl Future<Output = Result<User>> + Send;
|
||||||
|
fn delete_user(&self, user: UserId) -> impl Future<Output = Result<()>> + Send;
|
||||||
fn get_password_for_user(
|
fn get_password_for_user(
|
||||||
&self,
|
&self,
|
||||||
user: &User,
|
user: &User,
|
||||||
|
|
@ -123,10 +124,10 @@ pub struct PostgresDatabase {
|
||||||
connection_pool: Pool,
|
connection_pool: Pool,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
|
#[derive(Clone, Copy, Debug, Deserialize, Eq, Hash, PartialEq, Serialize)]
|
||||||
pub struct UserId(i32);
|
pub struct UserId(i32);
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
pub struct User {
|
pub struct User {
|
||||||
id: UserId,
|
id: UserId,
|
||||||
pub real_name: String,
|
pub real_name: String,
|
||||||
|
|
@ -206,6 +207,11 @@ impl Database for PostgresDatabase {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tracing::instrument]
|
||||||
|
async fn delete_user(&self, user: UserId) -> Result<()> {
|
||||||
|
todo!()
|
||||||
|
}
|
||||||
|
|
||||||
#[tracing::instrument]
|
#[tracing::instrument]
|
||||||
async fn get_password_for_user(&self, user: &User) -> Result<PasswordHash> {
|
async fn get_password_for_user(&self, user: &User) -> Result<PasswordHash> {
|
||||||
let client = self.get_client().await?;
|
let client = self.get_client().await?;
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue