Compare commits
5 Commits
3d4a34611b
...
10551a9bc9
| Author | SHA1 | Date |
|---|---|---|
|
|
10551a9bc9 | |
|
|
f9ac8b1e29 | |
|
|
04e4bc1f55 | |
|
|
fd770124ae | |
|
|
7f7277da2d |
8
dev.py
8
dev.py
|
|
@ -32,12 +32,20 @@ def unit_tests(args):
|
||||||
import_run().unit_tests(args)
|
import_run().unit_tests(args)
|
||||||
|
|
||||||
|
|
||||||
|
def database_tests(args):
|
||||||
|
import_run().database_tests(args)
|
||||||
|
|
||||||
|
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
subparsers = parser.add_subparsers(required=True)
|
subparsers = parser.add_subparsers(required=True)
|
||||||
run_parser = subparsers.add_parser(
|
run_parser = subparsers.add_parser(
|
||||||
'run', help='Run a test instance of locality'
|
'run', help='Run a test instance of locality'
|
||||||
)
|
)
|
||||||
run_parser.set_defaults(func=run)
|
run_parser.set_defaults(func=run)
|
||||||
|
run_parser = subparsers.add_parser(
|
||||||
|
'dbtest', help='Run database tests'
|
||||||
|
)
|
||||||
|
run_parser.set_defaults(func=database_tests)
|
||||||
run_parser = subparsers.add_parser(
|
run_parser = subparsers.add_parser(
|
||||||
'unittest', help='Run unit tests'
|
'unittest', help='Run unit tests'
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -8,24 +8,22 @@ from .postgres_container import PostgresContainer
|
||||||
ROOT_DIR = None
|
ROOT_DIR = None
|
||||||
|
|
||||||
|
|
||||||
def cargo(*args):
|
def cargo(*args, env=None):
|
||||||
global ROOT_DIR
|
global ROOT_DIR
|
||||||
with PostgresContainer() as postgres:
|
if env is None:
|
||||||
locality_env = {
|
env = {
|
||||||
'LOCALITY_DATABASE_URL': postgres.get_url(),
|
'LOCALITY_DATABASE_URL': "",
|
||||||
'LOCALITY_TEST_DATABASE_URL': postgres.get_url(),
|
'LOCALITY_TEST_DATABASE_URL': "",
|
||||||
'LOCALITY_STATIC_FILE_PATH': os.path.join(
|
'LOCALITY_STATIC_FILE_PATH': os.path.join(
|
||||||
ROOT_DIR,
|
ROOT_DIR,
|
||||||
'static'),
|
'static'),
|
||||||
'LOCALITY_HMAC_SECRET': 'iknf4390-8guvmr3'
|
'LOCALITY_HMAC_SECRET': 'iknf4390-8guvmr3'
|
||||||
}
|
}
|
||||||
locality_env = os.environ.copy() | locality_env
|
env = os.environ.copy() | env
|
||||||
|
|
||||||
cargo_bin = shutil.which('cargo')
|
cargo_bin = shutil.which('cargo')
|
||||||
locality_process = subprocess.Popen(
|
locality_process = subprocess.Popen(
|
||||||
[cargo_bin, *args], env=locality_env, cwd=ROOT_DIR
|
[cargo_bin] + list(args), env=env, cwd=ROOT_DIR
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
while locality_process.poll() is None:
|
while locality_process.poll() is None:
|
||||||
time.sleep(0.5)
|
time.sleep(0.5)
|
||||||
|
|
@ -36,9 +34,27 @@ def cargo(*args):
|
||||||
locality_process.terminate()
|
locality_process.terminate()
|
||||||
|
|
||||||
|
|
||||||
def run(args):
|
def cargo_with_db(*args):
|
||||||
cargo('run')
|
global ROOT_DIR
|
||||||
|
with PostgresContainer() as postgres:
|
||||||
|
locality_env = {
|
||||||
|
'LOCALITY_DATABASE_URL': postgres.get_url(),
|
||||||
|
'LOCALITY_TEST_DATABASE_URL': postgres.get_url(),
|
||||||
|
'LOCALITY_STATIC_FILE_PATH': os.path.join(
|
||||||
|
ROOT_DIR,
|
||||||
|
'static'),
|
||||||
|
'LOCALITY_HMAC_SECRET': 'iknf4390-8guvmr3'
|
||||||
|
}
|
||||||
|
cargo(env=locality_env, *args)
|
||||||
|
|
||||||
|
|
||||||
def unit_tests(args):
|
def run(*args):
|
||||||
|
cargo_with_db('run')
|
||||||
|
|
||||||
|
|
||||||
|
def database_tests(*args):
|
||||||
|
cargo_with_db("test", "db::migrations::test", "--", "--include-ignored")
|
||||||
|
|
||||||
|
|
||||||
|
def unit_tests(*args):
|
||||||
cargo("test")
|
cargo("test")
|
||||||
|
|
|
||||||
|
|
@ -115,9 +115,7 @@ async fn post_create_first_admin_user<D: Database>(
|
||||||
.await?;
|
.await?;
|
||||||
let user = authenticate_user_with_password(&db, user, ¶ms.password)
|
let user = authenticate_user_with_password(&db, user, ¶ms.password)
|
||||||
.await?
|
.await?
|
||||||
.ok_or(Error::new_unexpected(
|
.ok_or_else(|| Error::new_unexpected("Could not authenticate newly-created user."))?;
|
||||||
"Could not authenticate newly-created user.",
|
|
||||||
))?;
|
|
||||||
Ok((
|
Ok((
|
||||||
cookie_jar
|
cookie_jar
|
||||||
.add(Cookie::build(("jwt", create_jwt_for_user(&user)?)).same_site(SameSite::Strict)),
|
.add(Cookie::build(("jwt", create_jwt_for_user(&user)?)).same_site(SameSite::Strict)),
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,20 @@
|
||||||
use {
|
use {
|
||||||
crate::db::Database,
|
crate::{
|
||||||
|
authentication::{authenticate_user_with_password, create_jwt_for_user, Password},
|
||||||
|
db::Database,
|
||||||
|
error::Error,
|
||||||
|
},
|
||||||
askama::Template,
|
askama::Template,
|
||||||
axum::{routing::get, Router},
|
axum::{
|
||||||
|
extract::State,
|
||||||
|
routing::{get, post},
|
||||||
|
Form, Router,
|
||||||
|
},
|
||||||
|
axum_extra::extract::{
|
||||||
|
cookie::{Cookie, SameSite},
|
||||||
|
CookieJar,
|
||||||
|
},
|
||||||
|
serde::Deserialize,
|
||||||
};
|
};
|
||||||
|
|
||||||
pub mod admin;
|
pub mod admin;
|
||||||
|
|
@ -14,15 +27,59 @@ pub struct AppState<D: Database> {
|
||||||
pub fn routes<D: Database>() -> Router<AppState<D>> {
|
pub fn routes<D: Database>() -> Router<AppState<D>> {
|
||||||
Router::new()
|
Router::new()
|
||||||
.route("/", get(root))
|
.route("/", get(root))
|
||||||
|
.route("/sign_up", get(sign_up))
|
||||||
|
.route("/create_new_user", post(create_new_user))
|
||||||
.nest("/admin", admin::routes())
|
.nest("/admin", admin::routes())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Template)]
|
#[derive(Template)]
|
||||||
#[template(path = "index.html")]
|
#[template(path = "index.html")]
|
||||||
struct IndexTemplate<'a> {
|
struct IndexTemplate {}
|
||||||
title: &'a str,
|
|
||||||
|
#[tracing::instrument]
|
||||||
|
async fn root() -> IndexTemplate {
|
||||||
|
IndexTemplate {}
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn root<'a>() -> IndexTemplate<'a> {
|
#[derive(Template)]
|
||||||
IndexTemplate { title: "Locality" }
|
#[template(path = "sign-up.html")]
|
||||||
|
struct SignUpTemplate {}
|
||||||
|
|
||||||
|
#[tracing::instrument]
|
||||||
|
async fn sign_up() -> SignUpTemplate {
|
||||||
|
SignUpTemplate {}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Template)]
|
||||||
|
#[template(path = "new-user.html")]
|
||||||
|
struct NewUserTemplate {}
|
||||||
|
|
||||||
|
#[derive(Deserialize, Debug)]
|
||||||
|
struct CreateNewUserParameters {
|
||||||
|
real_name: String,
|
||||||
|
email: String,
|
||||||
|
password: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tracing::instrument]
|
||||||
|
async fn create_new_user<D: Database>(
|
||||||
|
cookie_jar: CookieJar,
|
||||||
|
State(AppState::<D> { db, .. }): State<AppState<D>>,
|
||||||
|
Form(params): Form<CreateNewUserParameters>,
|
||||||
|
) -> Result<(CookieJar, NewUserTemplate), Error> {
|
||||||
|
let user = db
|
||||||
|
.create_user(
|
||||||
|
¶ms.real_name,
|
||||||
|
¶ms.email,
|
||||||
|
&Password::new(¶ms.password)?.into(),
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
let user = authenticate_user_with_password(&db, user, ¶ms.password)
|
||||||
|
.await?
|
||||||
|
.ok_or_else(|| Error::new_unexpected("Could not authenticate newly-created user."))?;
|
||||||
|
Ok((
|
||||||
|
cookie_jar
|
||||||
|
.add(Cookie::build(("jwt", create_jwt_for_user(&user)?)).same_site(SameSite::Strict)),
|
||||||
|
NewUserTemplate {},
|
||||||
|
))
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -15,7 +15,7 @@ use {
|
||||||
|
|
||||||
const COOKIE_EXPIRY_TIME: Duration = Duration::weeks(1);
|
const COOKIE_EXPIRY_TIME: Duration = Duration::weeks(1);
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug, PartialEq, Eq)]
|
||||||
pub enum Error {
|
pub enum Error {
|
||||||
BadJwt,
|
BadJwt,
|
||||||
|
|
||||||
|
|
@ -80,7 +80,7 @@ impl From<digest::MacError> for Error {
|
||||||
type Result<T> = std::result::Result<T, Error>;
|
type Result<T> = std::result::Result<T, Error>;
|
||||||
|
|
||||||
/// Result type for [authenticate_user_with_jwt()].
|
/// Result type for [authenticate_user_with_jwt()].
|
||||||
#[derive(Debug)]
|
#[derive(Debug, Eq, PartialEq)]
|
||||||
pub enum ParsedJwt {
|
pub enum ParsedJwt {
|
||||||
/// JWT is a valid JWT, here is the [AuthenticatedUser].
|
/// JWT is a valid JWT, here is the [AuthenticatedUser].
|
||||||
Valid(AuthenticatedUser),
|
Valid(AuthenticatedUser),
|
||||||
|
|
@ -92,7 +92,7 @@ pub enum ParsedJwt {
|
||||||
UserNotFound,
|
UserNotFound,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize)]
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
struct Header<'a> {
|
struct Header<'a> {
|
||||||
#[serde(rename = "alg")]
|
#[serde(rename = "alg")]
|
||||||
algorithm: &'a str,
|
algorithm: &'a str,
|
||||||
|
|
@ -100,7 +100,16 @@ struct Header<'a> {
|
||||||
token_type: &'a str,
|
token_type: &'a str,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize)]
|
impl Header<'static> {
|
||||||
|
fn new() -> Self {
|
||||||
|
Header {
|
||||||
|
algorithm: "HS256",
|
||||||
|
token_type: "JWT",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
struct Payload {
|
struct Payload {
|
||||||
#[serde(rename = "sub")]
|
#[serde(rename = "sub")]
|
||||||
user_id: UserId,
|
user_id: UserId,
|
||||||
|
|
@ -109,43 +118,67 @@ struct Payload {
|
||||||
expiry: DateTime<Utc>,
|
expiry: DateTime<Utc>,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn mac() -> Hmac<Sha256> {
|
impl Payload {
|
||||||
Hmac::new_from_slice(&get_config().unwrap().hmac_secret)
|
fn new(user_id: UserId) -> Self {
|
||||||
.expect("HMAC can take key of any size.")
|
Self {
|
||||||
|
user_id,
|
||||||
|
expiry: Utc::now() + COOKIE_EXPIRY_TIME,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn create_mac(header: &str, payload: &str) -> Hmac<Sha256> {
|
||||||
|
let mut mac = Hmac::new_from_slice(&get_config().unwrap().hmac_secret)
|
||||||
|
.expect("HMAC can take key of any size.");
|
||||||
|
mac.update(format!("{}.{}", header, payload).as_bytes());
|
||||||
|
mac
|
||||||
|
}
|
||||||
|
|
||||||
|
fn create_signature(header: &str, payload: &str) -> String {
|
||||||
|
let mac = create_mac(header, payload);
|
||||||
|
base64_encoder.encode(mac.finalize().into_bytes())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn check_signature(header: &str, payload: &str, signature: &str) -> Result<bool> {
|
||||||
|
let mac = create_mac(header, payload);
|
||||||
|
Ok(mac
|
||||||
|
.verify_slice(&base64_encoder.decode(signature.as_bytes())?)
|
||||||
|
.is_ok())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn sign_jwt(header: &str, payload: &str) -> String {
|
||||||
|
let signature = create_signature(header, payload);
|
||||||
|
format!("{}.{}.{}", header, payload, signature)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn dissassemble_jwt(jwt: &str) -> Result<(&str, &str, &str)> {
|
||||||
|
if let [header, payload, signature] = jwt.split('.').collect::<Vec<_>>().as_slice() {
|
||||||
|
Ok((header, payload, signature))
|
||||||
|
} else {
|
||||||
|
warn!("Invalid JWT");
|
||||||
|
Err(Error::BadJwt)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn encode_jwt_component<T: Serialize>(value: &T) -> Result<String> {
|
||||||
|
Ok(base64_encoder.encode(serde_json::to_string(value)?.as_bytes()))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Given an [AuthenticatedUser], create a JWT for use as a cookie to
|
/// Given an [AuthenticatedUser], create a JWT for use as a cookie to
|
||||||
/// keep that user logged in.
|
/// keep that user logged in.
|
||||||
#[tracing::instrument]
|
#[tracing::instrument]
|
||||||
pub fn create_jwt_for_user(user: &AuthenticatedUser) -> Result<String> {
|
pub fn create_jwt_for_user(user: &AuthenticatedUser) -> Result<String> {
|
||||||
let header = base64_encoder.encode(
|
let header = encode_jwt_component(&Header::new())?;
|
||||||
serde_json::to_string(&Header {
|
let payload = encode_jwt_component(&Payload::new(user.get_id()))?;
|
||||||
algorithm: "HS256",
|
Ok(sign_jwt(&header, &payload))
|
||||||
token_type: "JWT",
|
|
||||||
})?
|
|
||||||
.as_bytes(),
|
|
||||||
);
|
|
||||||
let payload = base64_encoder.encode(
|
|
||||||
serde_json::to_string(&Payload {
|
|
||||||
user_id: user.get_id(),
|
|
||||||
expiry: Utc::now() + COOKIE_EXPIRY_TIME,
|
|
||||||
})?
|
|
||||||
.as_bytes(),
|
|
||||||
);
|
|
||||||
let mut mac = mac();
|
|
||||||
mac.update(format!("{}.{}", header, payload).as_bytes());
|
|
||||||
let signature = base64_encoder.encode(mac.finalize().into_bytes());
|
|
||||||
Ok(format!("{}.{}.{}", header, payload, signature))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Given JWT string created by [create_jwt_for_user()], check if the
|
/// Given JWT string created by [create_jwt_for_user()], check if the
|
||||||
/// JWT is valid and return an [AuthenticatedUser] if it is.
|
/// JWT is valid and return an [AuthenticatedUser] if it is.
|
||||||
#[tracing::instrument]
|
#[tracing::instrument]
|
||||||
pub async fn authenticate_user_with_jwt<D: Database>(db: &D, jwt: &str) -> Result<ParsedJwt> {
|
pub async fn authenticate_user_with_jwt<D: Database>(db: &D, jwt: &str) -> Result<ParsedJwt> {
|
||||||
if let [header, payload, signature] = jwt.split('.').collect::<Vec<_>>().as_slice() {
|
let (header, payload, signature) = dissassemble_jwt(jwt)?;
|
||||||
let mut mac = mac();
|
if !check_signature(header, payload, signature)? {
|
||||||
mac.update(format!("{}.{}", header, payload).as_bytes());
|
|
||||||
if mac.verify_slice(signature.as_bytes()).is_err() {
|
|
||||||
Ok(ParsedJwt::InvalidSignature)
|
Ok(ParsedJwt::InvalidSignature)
|
||||||
} else {
|
} else {
|
||||||
let header_json = String::from_utf8(base64_encoder.decode(header)?)?;
|
let header_json = String::from_utf8(base64_encoder.decode(header)?)?;
|
||||||
|
|
@ -156,9 +189,9 @@ pub async fn authenticate_user_with_jwt<D: Database>(db: &D, jwt: &str) -> Resul
|
||||||
} else {
|
} else {
|
||||||
let payload: Payload =
|
let payload: Payload =
|
||||||
serde_json::from_str(&String::from_utf8(base64_encoder.decode(payload)?)?)?;
|
serde_json::from_str(&String::from_utf8(base64_encoder.decode(payload)?)?)?;
|
||||||
Ok(dbg!(
|
Ok(
|
||||||
if let Some(user) = db.get_user_with_id(payload.user_id).await? {
|
if let Some(user) = db.get_user_with_id(payload.user_id).await? {
|
||||||
if payload.expiry < Utc::now() {
|
if payload.expiry > Utc::now() {
|
||||||
ParsedJwt::Valid(AuthenticatedUser(user))
|
ParsedJwt::Valid(AuthenticatedUser(user))
|
||||||
} else {
|
} else {
|
||||||
ParsedJwt::Expired(user)
|
ParsedJwt::Expired(user)
|
||||||
|
|
@ -166,11 +199,206 @@ pub async fn authenticate_user_with_jwt<D: Database>(db: &D, jwt: &str) -> Resul
|
||||||
} else {
|
} else {
|
||||||
ParsedJwt::UserNotFound
|
ParsedJwt::UserNotFound
|
||||||
},
|
},
|
||||||
))
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
}
|
||||||
warn!("Invalid JWT");
|
|
||||||
Err(Error::BadJwt)
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use crate::{db, db::fake::FakeDatabase};
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn created_jwt_authenticates() {
|
||||||
|
let db = FakeDatabase::new_empty();
|
||||||
|
let user = AuthenticatedUser(
|
||||||
|
db.create_user(
|
||||||
|
"John Smith",
|
||||||
|
"john.smith@example.com",
|
||||||
|
&db::PasswordHash("abc123".to_string()),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.unwrap(),
|
||||||
|
);
|
||||||
|
let jwt = create_jwt_for_user(&user).unwrap();
|
||||||
|
assert_eq!(
|
||||||
|
ParsedJwt::Valid(user),
|
||||||
|
authenticate_user_with_jwt(&db, &jwt).await.unwrap()
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn modified_jwt_does_not_authenticate() {
|
||||||
|
let db = FakeDatabase::new_empty();
|
||||||
|
let user1 = AuthenticatedUser(
|
||||||
|
db.create_user(
|
||||||
|
"John Smith",
|
||||||
|
"john.smith@example.com",
|
||||||
|
&db::PasswordHash("abc123".to_string()),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.unwrap(),
|
||||||
|
);
|
||||||
|
let jwt1 = create_jwt_for_user(&user1).unwrap();
|
||||||
|
// Create another user with same name, etc. but different user id
|
||||||
|
let user2 = AuthenticatedUser(
|
||||||
|
db.create_user(
|
||||||
|
"John Smith",
|
||||||
|
"john.smith@example.com",
|
||||||
|
&db::PasswordHash("abc123".to_string()),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.unwrap(),
|
||||||
|
);
|
||||||
|
let jwt2 = create_jwt_for_user(&user2).unwrap();
|
||||||
|
// Attach signature 2 to jwt 2
|
||||||
|
let parts1: Vec<_> = jwt1.split('.').collect();
|
||||||
|
assert_eq!(3, parts1.len());
|
||||||
|
let parts2: Vec<_> = jwt2.split('.').collect();
|
||||||
|
assert_eq!(3, parts2.len());
|
||||||
|
let parts3 = vec![parts2[0], parts2[1], parts1[2]];
|
||||||
|
|
||||||
|
//Rebuild all JWTs, so the orignial ones can act as a control
|
||||||
|
let jwt1 = parts1.join(".");
|
||||||
|
let jwt2 = parts2.join(".");
|
||||||
|
let jwt3 = parts3.join(".");
|
||||||
|
assert_eq!(
|
||||||
|
ParsedJwt::Valid(user1),
|
||||||
|
authenticate_user_with_jwt(&db, &jwt1).await.unwrap()
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
ParsedJwt::Valid(user2),
|
||||||
|
authenticate_user_with_jwt(&db, &jwt2).await.unwrap()
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
ParsedJwt::InvalidSignature,
|
||||||
|
authenticate_user_with_jwt(&db, &jwt3).await.unwrap()
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn deleted_user_cannot_authenticate() {
|
||||||
|
let db = FakeDatabase::new_empty();
|
||||||
|
let user = AuthenticatedUser(
|
||||||
|
db.create_user(
|
||||||
|
"John Smith",
|
||||||
|
"john.smith@example.com",
|
||||||
|
&db::PasswordHash("abc123".to_string()),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.unwrap(),
|
||||||
|
);
|
||||||
|
let jwt = create_jwt_for_user(&user).unwrap();
|
||||||
|
assert_eq!(
|
||||||
|
ParsedJwt::Valid(user.clone()),
|
||||||
|
authenticate_user_with_jwt(&db, &jwt).await.unwrap()
|
||||||
|
);
|
||||||
|
|
||||||
|
db.delete_user(user.get_id()).await.unwrap();
|
||||||
|
assert_eq!(
|
||||||
|
ParsedJwt::UserNotFound,
|
||||||
|
authenticate_user_with_jwt(&db, &jwt).await.unwrap()
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn unsigned_jwt_is_rejected() {
|
||||||
|
let db = FakeDatabase::new_empty();
|
||||||
|
let user = AuthenticatedUser(
|
||||||
|
db.create_user(
|
||||||
|
"John Smith",
|
||||||
|
"john.smith@example.com",
|
||||||
|
&db::PasswordHash("abc123".to_string()),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.unwrap(),
|
||||||
|
);
|
||||||
|
let jwt = format!(
|
||||||
|
"{},{}",
|
||||||
|
encode_jwt_component(&Header::new()).unwrap(),
|
||||||
|
encode_jwt_component(&Payload::new(user.get_id())).unwrap()
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
Err(Error::BadJwt),
|
||||||
|
authenticate_user_with_jwt(&db, &jwt).await
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn jwt_with_unrecognized_algorithm_is_rejected() {
|
||||||
|
let db = FakeDatabase::new_empty();
|
||||||
|
let user = AuthenticatedUser(
|
||||||
|
db.create_user(
|
||||||
|
"John Smith",
|
||||||
|
"john.smith@example.com",
|
||||||
|
&db::PasswordHash("abc123".to_string()),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.unwrap(),
|
||||||
|
);
|
||||||
|
let jwt = sign_jwt(
|
||||||
|
&encode_jwt_component(&Header {
|
||||||
|
algorithm: "RS256",
|
||||||
|
token_type: "JWT",
|
||||||
|
})
|
||||||
|
.unwrap(),
|
||||||
|
&encode_jwt_component(&Payload::new(user.get_id())).unwrap(),
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
Err(Error::BadJwt),
|
||||||
|
authenticate_user_with_jwt(&db, &jwt).await
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn jwt_with_unrecognized_token_type_is_rejected() {
|
||||||
|
let db = FakeDatabase::new_empty();
|
||||||
|
let user = AuthenticatedUser(
|
||||||
|
db.create_user(
|
||||||
|
"John Smith",
|
||||||
|
"john.smith@example.com",
|
||||||
|
&db::PasswordHash("abc123".to_string()),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.unwrap(),
|
||||||
|
);
|
||||||
|
let jwt = sign_jwt(
|
||||||
|
&encode_jwt_component(&Header {
|
||||||
|
algorithm: "HS256",
|
||||||
|
token_type: "RWT",
|
||||||
|
})
|
||||||
|
.unwrap(),
|
||||||
|
&encode_jwt_component(&Payload::new(user.get_id())).unwrap(),
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
Err(Error::BadJwt),
|
||||||
|
authenticate_user_with_jwt(&db, &jwt).await
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn expired_jwt_is_rejected() {
|
||||||
|
let db = FakeDatabase::new_empty();
|
||||||
|
let user = db
|
||||||
|
.create_user(
|
||||||
|
"John Smith",
|
||||||
|
"john.smith@example.com",
|
||||||
|
&db::PasswordHash("abc123".to_string()),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
let jwt = sign_jwt(
|
||||||
|
&encode_jwt_component(&Header::new()).unwrap(),
|
||||||
|
&encode_jwt_component(&Payload {
|
||||||
|
user_id: user.get_id(),
|
||||||
|
expiry: Utc::now() - COOKIE_EXPIRY_TIME - Duration::seconds(1),
|
||||||
|
})
|
||||||
|
.unwrap(),
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
Ok(ParsedJwt::Expired(user)),
|
||||||
|
authenticate_user_with_jwt(&db, &jwt).await
|
||||||
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -54,7 +54,7 @@ impl From<argon2::password_hash::Error> for AuthenticationError {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone, Eq, PartialEq)]
|
||||||
pub struct AuthenticatedUser(db::User);
|
pub struct AuthenticatedUser(db::User);
|
||||||
|
|
||||||
impl Deref for AuthenticatedUser {
|
impl Deref for AuthenticatedUser {
|
||||||
|
|
@ -81,6 +81,7 @@ pub struct Password {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Password {
|
impl Password {
|
||||||
|
#[tracing::instrument]
|
||||||
pub fn new(password: &str) -> Result<Password, AuthenticationError> {
|
pub fn new(password: &str) -> Result<Password, AuthenticationError> {
|
||||||
let salt = SaltString::generate(&mut OsRng);
|
let salt = SaltString::generate(&mut OsRng);
|
||||||
let argon2 = Argon2::default();
|
let argon2 = Argon2::default();
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,10 @@
|
||||||
use super::*;
|
use super::*;
|
||||||
use {
|
use {
|
||||||
std::collections::HashSet,
|
std::collections::{HashMap, HashSet},
|
||||||
std::sync::{Arc, Mutex},
|
std::{
|
||||||
|
hash::Hash,
|
||||||
|
sync::{Arc, Mutex},
|
||||||
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
|
|
@ -11,16 +14,63 @@ struct UserRow {
|
||||||
password_hash: String,
|
password_hash: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
trait Id: Copy + Eq + Hash {
|
||||||
|
fn first() -> Self;
|
||||||
|
fn increment(&mut self);
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Id for UserId {
|
||||||
|
fn first() -> Self {
|
||||||
|
UserId(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn increment(&mut self) {
|
||||||
|
self.0 += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
struct Table<Id, Row> {
|
||||||
|
data: HashMap<Id, Row>,
|
||||||
|
next_id: Id,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<I, R> Table<I, R>
|
||||||
|
where
|
||||||
|
I: Id,
|
||||||
|
{
|
||||||
|
fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
data: HashMap::new(),
|
||||||
|
next_id: I::first(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
fn insert(&mut self, row: R) -> I {
|
||||||
|
let id = self.next_id;
|
||||||
|
self.data.insert(id, row);
|
||||||
|
self.next_id.increment();
|
||||||
|
id
|
||||||
|
}
|
||||||
|
|
||||||
|
fn delete(&mut self, id: I) {
|
||||||
|
self.data.remove(&id);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get(&self, id: &I) -> Option<&R> {
|
||||||
|
self.data.get(id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct FakeDatabase {
|
pub struct FakeDatabase {
|
||||||
users: Arc<Mutex<Vec<UserRow>>>,
|
users: Arc<Mutex<Table<UserId, UserRow>>>,
|
||||||
admin_users: Arc<Mutex<std::collections::HashSet<usize>>>,
|
admin_users: Arc<Mutex<std::collections::HashSet<usize>>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl FakeDatabase {
|
impl FakeDatabase {
|
||||||
pub fn new_empty() -> Self {
|
pub fn new_empty() -> Self {
|
||||||
FakeDatabase {
|
FakeDatabase {
|
||||||
users: Arc::new(Mutex::new(Vec::new())),
|
users: Arc::new(Mutex::new(Table::new())),
|
||||||
admin_users: Arc::new(Mutex::new(HashSet::new())),
|
admin_users: Arc::new(Mutex::new(HashSet::new())),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -42,20 +92,26 @@ impl Database for FakeDatabase {
|
||||||
password: &PasswordHash,
|
password: &PasswordHash,
|
||||||
) -> Result<User> {
|
) -> Result<User> {
|
||||||
let mut users = self.users.lock().unwrap();
|
let mut users = self.users.lock().unwrap();
|
||||||
users.push(UserRow {
|
let new_id = users.insert(UserRow {
|
||||||
real_name: real_name.to_string(),
|
real_name: real_name.to_string(),
|
||||||
email: email.to_string(),
|
email: email.to_string(),
|
||||||
password_hash: password.to_string(),
|
password_hash: password.to_string(),
|
||||||
});
|
});
|
||||||
Ok(User {
|
Ok(User {
|
||||||
id: UserId((users.len() - 1) as i32),
|
id: new_id,
|
||||||
real_name: real_name.to_string(),
|
real_name: real_name.to_string(),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn delete_user(&self, user_id: UserId) -> Result<()> {
|
||||||
|
let mut users = self.users.lock().unwrap();
|
||||||
|
users.delete(user_id);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
async fn get_password_for_user(&self, user: &User) -> Result<PasswordHash> {
|
async fn get_password_for_user(&self, user: &User) -> Result<PasswordHash> {
|
||||||
let users = self.users.lock().unwrap();
|
let users = self.users.lock().unwrap();
|
||||||
if let Some(UserRow { password_hash, .. }) = users.get(user.id.0 as usize) {
|
if let Some(UserRow { password_hash, .. }) = users.get(&user.id) {
|
||||||
Ok(PasswordHash(password_hash.clone()))
|
Ok(PasswordHash(password_hash.clone()))
|
||||||
} else {
|
} else {
|
||||||
Err(Error::Database)
|
Err(Error::Database)
|
||||||
|
|
@ -81,9 +137,7 @@ impl Database for FakeDatabase {
|
||||||
|
|
||||||
async fn get_user_with_id(&self, user_id: UserId) -> Result<Option<User>> {
|
async fn get_user_with_id(&self, user_id: UserId) -> Result<Option<User>> {
|
||||||
let users = self.users.lock().unwrap();
|
let users = self.users.lock().unwrap();
|
||||||
Ok(users
|
Ok(users.get(&user_id).map(|UserRow { real_name, .. }| User {
|
||||||
.get(user_id.0 as usize)
|
|
||||||
.map(|UserRow { real_name, .. }| User {
|
|
||||||
id: user_id,
|
id: user_id,
|
||||||
real_name: real_name.clone(),
|
real_name: real_name.clone(),
|
||||||
}))
|
}))
|
||||||
|
|
|
||||||
|
|
@ -220,6 +220,7 @@ mod tests {
|
||||||
PostgresDatabase::new(&url).unwrap()
|
PostgresDatabase::new(&url).unwrap()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[ignore]
|
||||||
#[test]
|
#[test]
|
||||||
fn migrations_have_sequential_versions() {
|
fn migrations_have_sequential_versions() {
|
||||||
for i in 0..MIGRATIONS.len() {
|
for i in 0..MIGRATIONS.len() {
|
||||||
|
|
@ -227,6 +228,7 @@ mod tests {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[ignore]
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn migrate_up_and_down_all() {
|
async fn migrate_up_and_down_all() {
|
||||||
let db = test_db().await;
|
let db = test_db().await;
|
||||||
|
|
|
||||||
|
|
@ -98,6 +98,7 @@ pub trait Database: std::fmt::Debug + Clone + Send + Sync + 'static {
|
||||||
email: &str,
|
email: &str,
|
||||||
password: &PasswordHash,
|
password: &PasswordHash,
|
||||||
) -> impl Future<Output = Result<User>> + Send;
|
) -> impl Future<Output = Result<User>> + Send;
|
||||||
|
fn delete_user(&self, user: UserId) -> impl Future<Output = Result<()>> + Send;
|
||||||
fn get_password_for_user(
|
fn get_password_for_user(
|
||||||
&self,
|
&self,
|
||||||
user: &User,
|
user: &User,
|
||||||
|
|
@ -123,10 +124,10 @@ pub struct PostgresDatabase {
|
||||||
connection_pool: Pool,
|
connection_pool: Pool,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
|
#[derive(Clone, Copy, Debug, Deserialize, Eq, Hash, PartialEq, Serialize)]
|
||||||
pub struct UserId(i32);
|
pub struct UserId(i32);
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
pub struct User {
|
pub struct User {
|
||||||
id: UserId,
|
id: UserId,
|
||||||
pub real_name: String,
|
pub real_name: String,
|
||||||
|
|
@ -206,6 +207,11 @@ impl Database for PostgresDatabase {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tracing::instrument]
|
||||||
|
async fn delete_user(&self, user: UserId) -> Result<()> {
|
||||||
|
todo!()
|
||||||
|
}
|
||||||
|
|
||||||
#[tracing::instrument]
|
#[tracing::instrument]
|
||||||
async fn get_password_for_user(&self, user: &User) -> Result<PasswordHash> {
|
async fn get_password_for_user(&self, user: &User) -> Result<PasswordHash> {
|
||||||
let client = self.get_client().await?;
|
let client = self.get_client().await?;
|
||||||
|
|
|
||||||
10
src/error.rs
10
src/error.rs
|
|
@ -130,20 +130,14 @@ impl std::error::Error for Error {
|
||||||
|
|
||||||
#[derive(Template)]
|
#[derive(Template)]
|
||||||
#[template(path = "error.html")]
|
#[template(path = "error.html")]
|
||||||
struct ErrorTemplate<'a> {
|
struct ErrorTemplate {}
|
||||||
title: &'a str,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl IntoResponse for Error {
|
impl IntoResponse for Error {
|
||||||
fn into_response(self) -> Response {
|
fn into_response(self) -> Response {
|
||||||
match self.error_type {
|
match self.error_type {
|
||||||
ErrorType::InternalServerError => {
|
ErrorType::InternalServerError => {
|
||||||
error!(inner = self.inner, "Uncaught error producing HTTP 500.");
|
error!(inner = self.inner, "Uncaught error producing HTTP 500.");
|
||||||
(
|
(StatusCode::INTERNAL_SERVER_ERROR, ErrorTemplate {}).into_response()
|
||||||
StatusCode::INTERNAL_SERVER_ERROR,
|
|
||||||
ErrorTemplate { title: "Error" },
|
|
||||||
)
|
|
||||||
.into_response()
|
|
||||||
}
|
}
|
||||||
ErrorType::Unauthorized => {
|
ErrorType::Unauthorized => {
|
||||||
(StatusCode::UNAUTHORIZED, "User not authorized.").into_response()
|
(StatusCode::UNAUTHORIZED, "User not authorized.").into_response()
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,7 @@
|
||||||
<head>
|
<head>
|
||||||
<meta charset="utf-8" />
|
<meta charset="utf-8" />
|
||||||
<meta name="viewport" content="width=device-width" />
|
<meta name="viewport" content="width=device-width" />
|
||||||
<title>{% block title %}{{title}} - Locality{% endblock %}</title>
|
<title>{% block title %}Locality{% endblock %}</title>
|
||||||
|
|
||||||
<link rel="icon" type="image/png" sizes="32x32" href="/static/favicon-32.png">
|
<link rel="icon" type="image/png" sizes="32x32" href="/static/favicon-32.png">
|
||||||
<link rel="icon" type="image/png" sizes="16x16" href="/static/favicon-16.png">
|
<link rel="icon" type="image/png" sizes="16x16" href="/static/favicon-16.png">
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,7 @@
|
||||||
|
{% extends "base.html" %}
|
||||||
|
|
||||||
|
{% block title %}New User Welcome - Locality{% endblock %}
|
||||||
|
|
||||||
|
{% block content %}
|
||||||
|
TBD
|
||||||
|
{% endblock %}
|
||||||
|
|
@ -0,0 +1,25 @@
|
||||||
|
{% extends "base.html" %}
|
||||||
|
|
||||||
|
{% block title %}New User Sign-Up - Locality{% endblock %}
|
||||||
|
|
||||||
|
{% block content %}
|
||||||
|
<form action="create_new_user" method="post">>
|
||||||
|
<ul>
|
||||||
|
<li>
|
||||||
|
<label for="real_name">Name:</label>
|
||||||
|
<input type="text" id="real_name" name="real_name" />
|
||||||
|
</li>
|
||||||
|
<li>
|
||||||
|
<label for="email">Email:</label>
|
||||||
|
<input type="email" id="email" name="email" />
|
||||||
|
</li>
|
||||||
|
<li>
|
||||||
|
<label for="password">Password:</label>
|
||||||
|
<input type="password" id="password" name="password" />
|
||||||
|
</li>
|
||||||
|
<li>
|
||||||
|
<button type="submit">Create Account</button>
|
||||||
|
</li>
|
||||||
|
</ul>
|
||||||
|
</form>
|
||||||
|
{% endblock %}
|
||||||
Loading…
Reference in New Issue