Compare commits

...

5 Commits

13 changed files with 503 additions and 107 deletions

8
dev.py
View File

@ -32,12 +32,20 @@ def unit_tests(args):
import_run().unit_tests(args)
def database_tests(args):
import_run().database_tests(args)
parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers(required=True)
run_parser = subparsers.add_parser(
'run', help='Run a test instance of locality'
)
run_parser.set_defaults(func=run)
run_parser = subparsers.add_parser(
'dbtest', help='Run database tests'
)
run_parser.set_defaults(func=database_tests)
run_parser = subparsers.add_parser(
'unittest', help='Run unit tests'
)

View File

@ -8,24 +8,22 @@ from .postgres_container import PostgresContainer
ROOT_DIR = None
def cargo(*args):
def cargo(*args, env=None):
global ROOT_DIR
with PostgresContainer() as postgres:
locality_env = {
'LOCALITY_DATABASE_URL': postgres.get_url(),
'LOCALITY_TEST_DATABASE_URL': postgres.get_url(),
if env is None:
env = {
'LOCALITY_DATABASE_URL': "",
'LOCALITY_TEST_DATABASE_URL': "",
'LOCALITY_STATIC_FILE_PATH': os.path.join(
ROOT_DIR,
'static'),
'LOCALITY_HMAC_SECRET': 'iknf4390-8guvmr3'
}
locality_env = os.environ.copy() | locality_env
env = os.environ.copy() | env
cargo_bin = shutil.which('cargo')
locality_process = subprocess.Popen(
[cargo_bin, *args], env=locality_env, cwd=ROOT_DIR
[cargo_bin] + list(args), env=env, cwd=ROOT_DIR
)
try:
while locality_process.poll() is None:
time.sleep(0.5)
@ -36,9 +34,27 @@ def cargo(*args):
locality_process.terminate()
def run(args):
cargo('run')
def cargo_with_db(*args):
global ROOT_DIR
with PostgresContainer() as postgres:
locality_env = {
'LOCALITY_DATABASE_URL': postgres.get_url(),
'LOCALITY_TEST_DATABASE_URL': postgres.get_url(),
'LOCALITY_STATIC_FILE_PATH': os.path.join(
ROOT_DIR,
'static'),
'LOCALITY_HMAC_SECRET': 'iknf4390-8guvmr3'
}
cargo(env=locality_env, *args)
def unit_tests(args):
def run(*args):
cargo_with_db('run')
def database_tests(*args):
cargo_with_db("test", "db::migrations::test", "--", "--include-ignored")
def unit_tests(*args):
cargo("test")

View File

@ -115,9 +115,7 @@ async fn post_create_first_admin_user<D: Database>(
.await?;
let user = authenticate_user_with_password(&db, user, &params.password)
.await?
.ok_or(Error::new_unexpected(
"Could not authenticate newly-created user.",
))?;
.ok_or_else(|| Error::new_unexpected("Could not authenticate newly-created user."))?;
Ok((
cookie_jar
.add(Cookie::build(("jwt", create_jwt_for_user(&user)?)).same_site(SameSite::Strict)),

View File

@ -1,7 +1,20 @@
use {
crate::db::Database,
crate::{
authentication::{authenticate_user_with_password, create_jwt_for_user, Password},
db::Database,
error::Error,
},
askama::Template,
axum::{routing::get, Router},
axum::{
extract::State,
routing::{get, post},
Form, Router,
},
axum_extra::extract::{
cookie::{Cookie, SameSite},
CookieJar,
},
serde::Deserialize,
};
pub mod admin;
@ -14,15 +27,59 @@ pub struct AppState<D: Database> {
pub fn routes<D: Database>() -> Router<AppState<D>> {
Router::new()
.route("/", get(root))
.route("/sign_up", get(sign_up))
.route("/create_new_user", post(create_new_user))
.nest("/admin", admin::routes())
}
#[derive(Template)]
#[template(path = "index.html")]
struct IndexTemplate<'a> {
title: &'a str,
struct IndexTemplate {}
#[tracing::instrument]
async fn root() -> IndexTemplate {
IndexTemplate {}
}
async fn root<'a>() -> IndexTemplate<'a> {
IndexTemplate { title: "Locality" }
#[derive(Template)]
#[template(path = "sign-up.html")]
struct SignUpTemplate {}
#[tracing::instrument]
async fn sign_up() -> SignUpTemplate {
SignUpTemplate {}
}
#[derive(Template)]
#[template(path = "new-user.html")]
struct NewUserTemplate {}
#[derive(Deserialize, Debug)]
struct CreateNewUserParameters {
real_name: String,
email: String,
password: String,
}
#[tracing::instrument]
async fn create_new_user<D: Database>(
cookie_jar: CookieJar,
State(AppState::<D> { db, .. }): State<AppState<D>>,
Form(params): Form<CreateNewUserParameters>,
) -> Result<(CookieJar, NewUserTemplate), Error> {
let user = db
.create_user(
&params.real_name,
&params.email,
&Password::new(&params.password)?.into(),
)
.await?;
let user = authenticate_user_with_password(&db, user, &params.password)
.await?
.ok_or_else(|| Error::new_unexpected("Could not authenticate newly-created user."))?;
Ok((
cookie_jar
.add(Cookie::build(("jwt", create_jwt_for_user(&user)?)).same_site(SameSite::Strict)),
NewUserTemplate {},
))
}

View File

@ -15,7 +15,7 @@ use {
const COOKIE_EXPIRY_TIME: Duration = Duration::weeks(1);
#[derive(Debug)]
#[derive(Debug, PartialEq, Eq)]
pub enum Error {
BadJwt,
@ -80,7 +80,7 @@ impl From<digest::MacError> for Error {
type Result<T> = std::result::Result<T, Error>;
/// Result type for [authenticate_user_with_jwt()].
#[derive(Debug)]
#[derive(Debug, Eq, PartialEq)]
pub enum ParsedJwt {
/// JWT is a valid JWT, here is the [AuthenticatedUser].
Valid(AuthenticatedUser),
@ -92,7 +92,7 @@ pub enum ParsedJwt {
UserNotFound,
}
#[derive(Serialize, Deserialize)]
#[derive(Debug, Serialize, Deserialize)]
struct Header<'a> {
#[serde(rename = "alg")]
algorithm: &'a str,
@ -100,7 +100,16 @@ struct Header<'a> {
token_type: &'a str,
}
#[derive(Serialize, Deserialize)]
impl Header<'static> {
fn new() -> Self {
Header {
algorithm: "HS256",
token_type: "JWT",
}
}
}
#[derive(Debug, Serialize, Deserialize)]
struct Payload {
#[serde(rename = "sub")]
user_id: UserId,
@ -109,43 +118,67 @@ struct Payload {
expiry: DateTime<Utc>,
}
fn mac() -> Hmac<Sha256> {
Hmac::new_from_slice(&get_config().unwrap().hmac_secret)
.expect("HMAC can take key of any size.")
impl Payload {
fn new(user_id: UserId) -> Self {
Self {
user_id,
expiry: Utc::now() + COOKIE_EXPIRY_TIME,
}
}
}
fn create_mac(header: &str, payload: &str) -> Hmac<Sha256> {
let mut mac = Hmac::new_from_slice(&get_config().unwrap().hmac_secret)
.expect("HMAC can take key of any size.");
mac.update(format!("{}.{}", header, payload).as_bytes());
mac
}
fn create_signature(header: &str, payload: &str) -> String {
let mac = create_mac(header, payload);
base64_encoder.encode(mac.finalize().into_bytes())
}
fn check_signature(header: &str, payload: &str, signature: &str) -> Result<bool> {
let mac = create_mac(header, payload);
Ok(mac
.verify_slice(&base64_encoder.decode(signature.as_bytes())?)
.is_ok())
}
fn sign_jwt(header: &str, payload: &str) -> String {
let signature = create_signature(header, payload);
format!("{}.{}.{}", header, payload, signature)
}
fn dissassemble_jwt(jwt: &str) -> Result<(&str, &str, &str)> {
if let [header, payload, signature] = jwt.split('.').collect::<Vec<_>>().as_slice() {
Ok((header, payload, signature))
} else {
warn!("Invalid JWT");
Err(Error::BadJwt)
}
}
fn encode_jwt_component<T: Serialize>(value: &T) -> Result<String> {
Ok(base64_encoder.encode(serde_json::to_string(value)?.as_bytes()))
}
/// Given an [AuthenticatedUser], create a JWT for use as a cookie to
/// keep that user logged in.
#[tracing::instrument]
pub fn create_jwt_for_user(user: &AuthenticatedUser) -> Result<String> {
let header = base64_encoder.encode(
serde_json::to_string(&Header {
algorithm: "HS256",
token_type: "JWT",
})?
.as_bytes(),
);
let payload = base64_encoder.encode(
serde_json::to_string(&Payload {
user_id: user.get_id(),
expiry: Utc::now() + COOKIE_EXPIRY_TIME,
})?
.as_bytes(),
);
let mut mac = mac();
mac.update(format!("{}.{}", header, payload).as_bytes());
let signature = base64_encoder.encode(mac.finalize().into_bytes());
Ok(format!("{}.{}.{}", header, payload, signature))
let header = encode_jwt_component(&Header::new())?;
let payload = encode_jwt_component(&Payload::new(user.get_id()))?;
Ok(sign_jwt(&header, &payload))
}
/// Given JWT string created by [create_jwt_for_user()], check if the
/// JWT is valid and return an [AuthenticatedUser] if it is.
#[tracing::instrument]
pub async fn authenticate_user_with_jwt<D: Database>(db: &D, jwt: &str) -> Result<ParsedJwt> {
if let [header, payload, signature] = jwt.split('.').collect::<Vec<_>>().as_slice() {
let mut mac = mac();
mac.update(format!("{}.{}", header, payload).as_bytes());
if mac.verify_slice(signature.as_bytes()).is_err() {
let (header, payload, signature) = dissassemble_jwt(jwt)?;
if !check_signature(header, payload, signature)? {
Ok(ParsedJwt::InvalidSignature)
} else {
let header_json = String::from_utf8(base64_encoder.decode(header)?)?;
@ -156,9 +189,9 @@ pub async fn authenticate_user_with_jwt<D: Database>(db: &D, jwt: &str) -> Resul
} else {
let payload: Payload =
serde_json::from_str(&String::from_utf8(base64_encoder.decode(payload)?)?)?;
Ok(dbg!(
Ok(
if let Some(user) = db.get_user_with_id(payload.user_id).await? {
if payload.expiry < Utc::now() {
if payload.expiry > Utc::now() {
ParsedJwt::Valid(AuthenticatedUser(user))
} else {
ParsedJwt::Expired(user)
@ -166,11 +199,206 @@ pub async fn authenticate_user_with_jwt<D: Database>(db: &D, jwt: &str) -> Resul
} else {
ParsedJwt::UserNotFound
},
))
)
}
}
} else {
warn!("Invalid JWT");
Err(Error::BadJwt)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{db, db::fake::FakeDatabase};
#[tokio::test]
async fn created_jwt_authenticates() {
let db = FakeDatabase::new_empty();
let user = AuthenticatedUser(
db.create_user(
"John Smith",
"john.smith@example.com",
&db::PasswordHash("abc123".to_string()),
)
.await
.unwrap(),
);
let jwt = create_jwt_for_user(&user).unwrap();
assert_eq!(
ParsedJwt::Valid(user),
authenticate_user_with_jwt(&db, &jwt).await.unwrap()
);
}
#[tokio::test]
async fn modified_jwt_does_not_authenticate() {
let db = FakeDatabase::new_empty();
let user1 = AuthenticatedUser(
db.create_user(
"John Smith",
"john.smith@example.com",
&db::PasswordHash("abc123".to_string()),
)
.await
.unwrap(),
);
let jwt1 = create_jwt_for_user(&user1).unwrap();
// Create another user with same name, etc. but different user id
let user2 = AuthenticatedUser(
db.create_user(
"John Smith",
"john.smith@example.com",
&db::PasswordHash("abc123".to_string()),
)
.await
.unwrap(),
);
let jwt2 = create_jwt_for_user(&user2).unwrap();
// Attach signature 2 to jwt 2
let parts1: Vec<_> = jwt1.split('.').collect();
assert_eq!(3, parts1.len());
let parts2: Vec<_> = jwt2.split('.').collect();
assert_eq!(3, parts2.len());
let parts3 = vec![parts2[0], parts2[1], parts1[2]];
//Rebuild all JWTs, so the orignial ones can act as a control
let jwt1 = parts1.join(".");
let jwt2 = parts2.join(".");
let jwt3 = parts3.join(".");
assert_eq!(
ParsedJwt::Valid(user1),
authenticate_user_with_jwt(&db, &jwt1).await.unwrap()
);
assert_eq!(
ParsedJwt::Valid(user2),
authenticate_user_with_jwt(&db, &jwt2).await.unwrap()
);
assert_eq!(
ParsedJwt::InvalidSignature,
authenticate_user_with_jwt(&db, &jwt3).await.unwrap()
);
}
#[tokio::test]
async fn deleted_user_cannot_authenticate() {
let db = FakeDatabase::new_empty();
let user = AuthenticatedUser(
db.create_user(
"John Smith",
"john.smith@example.com",
&db::PasswordHash("abc123".to_string()),
)
.await
.unwrap(),
);
let jwt = create_jwt_for_user(&user).unwrap();
assert_eq!(
ParsedJwt::Valid(user.clone()),
authenticate_user_with_jwt(&db, &jwt).await.unwrap()
);
db.delete_user(user.get_id()).await.unwrap();
assert_eq!(
ParsedJwt::UserNotFound,
authenticate_user_with_jwt(&db, &jwt).await.unwrap()
);
}
#[tokio::test]
async fn unsigned_jwt_is_rejected() {
let db = FakeDatabase::new_empty();
let user = AuthenticatedUser(
db.create_user(
"John Smith",
"john.smith@example.com",
&db::PasswordHash("abc123".to_string()),
)
.await
.unwrap(),
);
let jwt = format!(
"{},{}",
encode_jwt_component(&Header::new()).unwrap(),
encode_jwt_component(&Payload::new(user.get_id())).unwrap()
);
assert_eq!(
Err(Error::BadJwt),
authenticate_user_with_jwt(&db, &jwt).await
);
}
#[tokio::test]
async fn jwt_with_unrecognized_algorithm_is_rejected() {
let db = FakeDatabase::new_empty();
let user = AuthenticatedUser(
db.create_user(
"John Smith",
"john.smith@example.com",
&db::PasswordHash("abc123".to_string()),
)
.await
.unwrap(),
);
let jwt = sign_jwt(
&encode_jwt_component(&Header {
algorithm: "RS256",
token_type: "JWT",
})
.unwrap(),
&encode_jwt_component(&Payload::new(user.get_id())).unwrap(),
);
assert_eq!(
Err(Error::BadJwt),
authenticate_user_with_jwt(&db, &jwt).await
);
}
#[tokio::test]
async fn jwt_with_unrecognized_token_type_is_rejected() {
let db = FakeDatabase::new_empty();
let user = AuthenticatedUser(
db.create_user(
"John Smith",
"john.smith@example.com",
&db::PasswordHash("abc123".to_string()),
)
.await
.unwrap(),
);
let jwt = sign_jwt(
&encode_jwt_component(&Header {
algorithm: "HS256",
token_type: "RWT",
})
.unwrap(),
&encode_jwt_component(&Payload::new(user.get_id())).unwrap(),
);
assert_eq!(
Err(Error::BadJwt),
authenticate_user_with_jwt(&db, &jwt).await
);
}
#[tokio::test]
async fn expired_jwt_is_rejected() {
let db = FakeDatabase::new_empty();
let user = db
.create_user(
"John Smith",
"john.smith@example.com",
&db::PasswordHash("abc123".to_string()),
)
.await
.unwrap();
let jwt = sign_jwt(
&encode_jwt_component(&Header::new()).unwrap(),
&encode_jwt_component(&Payload {
user_id: user.get_id(),
expiry: Utc::now() - COOKIE_EXPIRY_TIME - Duration::seconds(1),
})
.unwrap(),
);
assert_eq!(
Ok(ParsedJwt::Expired(user)),
authenticate_user_with_jwt(&db, &jwt).await
);
}
}

View File

@ -54,7 +54,7 @@ impl From<argon2::password_hash::Error> for AuthenticationError {
}
}
#[derive(Debug, Clone)]
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct AuthenticatedUser(db::User);
impl Deref for AuthenticatedUser {
@ -81,6 +81,7 @@ pub struct Password {
}
impl Password {
#[tracing::instrument]
pub fn new(password: &str) -> Result<Password, AuthenticationError> {
let salt = SaltString::generate(&mut OsRng);
let argon2 = Argon2::default();

View File

@ -1,7 +1,10 @@
use super::*;
use {
std::collections::HashSet,
std::sync::{Arc, Mutex},
std::collections::{HashMap, HashSet},
std::{
hash::Hash,
sync::{Arc, Mutex},
},
};
#[derive(Debug)]
@ -11,16 +14,63 @@ struct UserRow {
password_hash: String,
}
trait Id: Copy + Eq + Hash {
fn first() -> Self;
fn increment(&mut self);
}
impl Id for UserId {
fn first() -> Self {
UserId(0)
}
fn increment(&mut self) {
self.0 += 1;
}
}
#[derive(Debug)]
struct Table<Id, Row> {
data: HashMap<Id, Row>,
next_id: Id,
}
impl<I, R> Table<I, R>
where
I: Id,
{
fn new() -> Self {
Self {
data: HashMap::new(),
next_id: I::first(),
}
}
fn insert(&mut self, row: R) -> I {
let id = self.next_id;
self.data.insert(id, row);
self.next_id.increment();
id
}
fn delete(&mut self, id: I) {
self.data.remove(&id);
}
fn get(&self, id: &I) -> Option<&R> {
self.data.get(id)
}
}
#[derive(Debug, Clone)]
pub struct FakeDatabase {
users: Arc<Mutex<Vec<UserRow>>>,
users: Arc<Mutex<Table<UserId, UserRow>>>,
admin_users: Arc<Mutex<std::collections::HashSet<usize>>>,
}
impl FakeDatabase {
pub fn new_empty() -> Self {
FakeDatabase {
users: Arc::new(Mutex::new(Vec::new())),
users: Arc::new(Mutex::new(Table::new())),
admin_users: Arc::new(Mutex::new(HashSet::new())),
}
}
@ -42,20 +92,26 @@ impl Database for FakeDatabase {
password: &PasswordHash,
) -> Result<User> {
let mut users = self.users.lock().unwrap();
users.push(UserRow {
let new_id = users.insert(UserRow {
real_name: real_name.to_string(),
email: email.to_string(),
password_hash: password.to_string(),
});
Ok(User {
id: UserId((users.len() - 1) as i32),
id: new_id,
real_name: real_name.to_string(),
})
}
async fn delete_user(&self, user_id: UserId) -> Result<()> {
let mut users = self.users.lock().unwrap();
users.delete(user_id);
Ok(())
}
async fn get_password_for_user(&self, user: &User) -> Result<PasswordHash> {
let users = self.users.lock().unwrap();
if let Some(UserRow { password_hash, .. }) = users.get(user.id.0 as usize) {
if let Some(UserRow { password_hash, .. }) = users.get(&user.id) {
Ok(PasswordHash(password_hash.clone()))
} else {
Err(Error::Database)
@ -81,9 +137,7 @@ impl Database for FakeDatabase {
async fn get_user_with_id(&self, user_id: UserId) -> Result<Option<User>> {
let users = self.users.lock().unwrap();
Ok(users
.get(user_id.0 as usize)
.map(|UserRow { real_name, .. }| User {
Ok(users.get(&user_id).map(|UserRow { real_name, .. }| User {
id: user_id,
real_name: real_name.clone(),
}))

View File

@ -220,6 +220,7 @@ mod tests {
PostgresDatabase::new(&url).unwrap()
}
#[ignore]
#[test]
fn migrations_have_sequential_versions() {
for i in 0..MIGRATIONS.len() {
@ -227,6 +228,7 @@ mod tests {
}
}
#[ignore]
#[tokio::test]
async fn migrate_up_and_down_all() {
let db = test_db().await;

View File

@ -98,6 +98,7 @@ pub trait Database: std::fmt::Debug + Clone + Send + Sync + 'static {
email: &str,
password: &PasswordHash,
) -> impl Future<Output = Result<User>> + Send;
fn delete_user(&self, user: UserId) -> impl Future<Output = Result<()>> + Send;
fn get_password_for_user(
&self,
user: &User,
@ -123,10 +124,10 @@ pub struct PostgresDatabase {
connection_pool: Pool,
}
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
#[derive(Clone, Copy, Debug, Deserialize, Eq, Hash, PartialEq, Serialize)]
pub struct UserId(i32);
#[derive(Debug, Clone)]
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct User {
id: UserId,
pub real_name: String,
@ -206,6 +207,11 @@ impl Database for PostgresDatabase {
})
}
#[tracing::instrument]
async fn delete_user(&self, user: UserId) -> Result<()> {
todo!()
}
#[tracing::instrument]
async fn get_password_for_user(&self, user: &User) -> Result<PasswordHash> {
let client = self.get_client().await?;

View File

@ -130,20 +130,14 @@ impl std::error::Error for Error {
#[derive(Template)]
#[template(path = "error.html")]
struct ErrorTemplate<'a> {
title: &'a str,
}
struct ErrorTemplate {}
impl IntoResponse for Error {
fn into_response(self) -> Response {
match self.error_type {
ErrorType::InternalServerError => {
error!(inner = self.inner, "Uncaught error producing HTTP 500.");
(
StatusCode::INTERNAL_SERVER_ERROR,
ErrorTemplate { title: "Error" },
)
.into_response()
(StatusCode::INTERNAL_SERVER_ERROR, ErrorTemplate {}).into_response()
}
ErrorType::Unauthorized => {
(StatusCode::UNAUTHORIZED, "User not authorized.").into_response()

View File

@ -3,7 +3,7 @@
<head>
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width" />
<title>{% block title %}{{title}} - Locality{% endblock %}</title>
<title>{% block title %}Locality{% endblock %}</title>
<link rel="icon" type="image/png" sizes="32x32" href="/static/favicon-32.png">
<link rel="icon" type="image/png" sizes="16x16" href="/static/favicon-16.png">

7
templates/new-user.html Normal file
View File

@ -0,0 +1,7 @@
{% extends "base.html" %}
{% block title %}New User Welcome - Locality{% endblock %}
{% block content %}
TBD
{% endblock %}

25
templates/sign-up.html Normal file
View File

@ -0,0 +1,25 @@
{% extends "base.html" %}
{% block title %}New User Sign-Up - Locality{% endblock %}
{% block content %}
<form action="create_new_user" method="post">>
<ul>
<li>
<label for="real_name">Name:</label>
<input type="text" id="real_name" name="real_name" />
</li>
<li>
<label for="email">Email:</label>
<input type="email" id="email" name="email" />
</li>
<li>
<label for="password">Password:</label>
<input type="password" id="password" name="password" />
</li>
<li>
<button type="submit">Create Account</button>
</li>
</ul>
</form>
{% endblock %}