Compare commits
No commits in common. "10551a9bc9b23de10a20ed0a8cf1f5296746253b" and "3d4a34611b8b077321f8df53e08831c351370372" have entirely different histories.
10551a9bc9
...
3d4a34611b
8
dev.py
8
dev.py
|
|
@ -32,20 +32,12 @@ def unit_tests(args):
|
||||||
import_run().unit_tests(args)
|
import_run().unit_tests(args)
|
||||||
|
|
||||||
|
|
||||||
def database_tests(args):
|
|
||||||
import_run().database_tests(args)
|
|
||||||
|
|
||||||
|
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
subparsers = parser.add_subparsers(required=True)
|
subparsers = parser.add_subparsers(required=True)
|
||||||
run_parser = subparsers.add_parser(
|
run_parser = subparsers.add_parser(
|
||||||
'run', help='Run a test instance of locality'
|
'run', help='Run a test instance of locality'
|
||||||
)
|
)
|
||||||
run_parser.set_defaults(func=run)
|
run_parser.set_defaults(func=run)
|
||||||
run_parser = subparsers.add_parser(
|
|
||||||
'dbtest', help='Run database tests'
|
|
||||||
)
|
|
||||||
run_parser.set_defaults(func=database_tests)
|
|
||||||
run_parser = subparsers.add_parser(
|
run_parser = subparsers.add_parser(
|
||||||
'unittest', help='Run unit tests'
|
'unittest', help='Run unit tests'
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -8,33 +8,7 @@ from .postgres_container import PostgresContainer
|
||||||
ROOT_DIR = None
|
ROOT_DIR = None
|
||||||
|
|
||||||
|
|
||||||
def cargo(*args, env=None):
|
def cargo(*args):
|
||||||
global ROOT_DIR
|
|
||||||
if env is None:
|
|
||||||
env = {
|
|
||||||
'LOCALITY_DATABASE_URL': "",
|
|
||||||
'LOCALITY_TEST_DATABASE_URL': "",
|
|
||||||
'LOCALITY_STATIC_FILE_PATH': os.path.join(
|
|
||||||
ROOT_DIR,
|
|
||||||
'static'),
|
|
||||||
'LOCALITY_HMAC_SECRET': 'iknf4390-8guvmr3'
|
|
||||||
}
|
|
||||||
env = os.environ.copy() | env
|
|
||||||
cargo_bin = shutil.which('cargo')
|
|
||||||
locality_process = subprocess.Popen(
|
|
||||||
[cargo_bin] + list(args), env=env, cwd=ROOT_DIR
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
while locality_process.poll() is None:
|
|
||||||
time.sleep(0.5)
|
|
||||||
except KeyboardInterrupt:
|
|
||||||
pass
|
|
||||||
finally:
|
|
||||||
if locality_process.poll() is None:
|
|
||||||
locality_process.terminate()
|
|
||||||
|
|
||||||
|
|
||||||
def cargo_with_db(*args):
|
|
||||||
global ROOT_DIR
|
global ROOT_DIR
|
||||||
with PostgresContainer() as postgres:
|
with PostgresContainer() as postgres:
|
||||||
locality_env = {
|
locality_env = {
|
||||||
|
|
@ -45,16 +19,26 @@ def cargo_with_db(*args):
|
||||||
'static'),
|
'static'),
|
||||||
'LOCALITY_HMAC_SECRET': 'iknf4390-8guvmr3'
|
'LOCALITY_HMAC_SECRET': 'iknf4390-8guvmr3'
|
||||||
}
|
}
|
||||||
cargo(env=locality_env, *args)
|
locality_env = os.environ.copy() | locality_env
|
||||||
|
|
||||||
|
cargo_bin = shutil.which('cargo')
|
||||||
|
locality_process = subprocess.Popen(
|
||||||
|
[cargo_bin, *args], env=locality_env, cwd=ROOT_DIR
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
while locality_process.poll() is None:
|
||||||
|
time.sleep(0.5)
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
pass
|
||||||
|
finally:
|
||||||
|
if locality_process.poll() is None:
|
||||||
|
locality_process.terminate()
|
||||||
|
|
||||||
|
|
||||||
def run(*args):
|
def run(args):
|
||||||
cargo_with_db('run')
|
cargo('run')
|
||||||
|
|
||||||
|
|
||||||
def database_tests(*args):
|
def unit_tests(args):
|
||||||
cargo_with_db("test", "db::migrations::test", "--", "--include-ignored")
|
|
||||||
|
|
||||||
|
|
||||||
def unit_tests(*args):
|
|
||||||
cargo("test")
|
cargo("test")
|
||||||
|
|
|
||||||
|
|
@ -115,7 +115,9 @@ async fn post_create_first_admin_user<D: Database>(
|
||||||
.await?;
|
.await?;
|
||||||
let user = authenticate_user_with_password(&db, user, ¶ms.password)
|
let user = authenticate_user_with_password(&db, user, ¶ms.password)
|
||||||
.await?
|
.await?
|
||||||
.ok_or_else(|| Error::new_unexpected("Could not authenticate newly-created user."))?;
|
.ok_or(Error::new_unexpected(
|
||||||
|
"Could not authenticate newly-created user.",
|
||||||
|
))?;
|
||||||
Ok((
|
Ok((
|
||||||
cookie_jar
|
cookie_jar
|
||||||
.add(Cookie::build(("jwt", create_jwt_for_user(&user)?)).same_site(SameSite::Strict)),
|
.add(Cookie::build(("jwt", create_jwt_for_user(&user)?)).same_site(SameSite::Strict)),
|
||||||
|
|
|
||||||
|
|
@ -1,20 +1,7 @@
|
||||||
use {
|
use {
|
||||||
crate::{
|
crate::db::Database,
|
||||||
authentication::{authenticate_user_with_password, create_jwt_for_user, Password},
|
|
||||||
db::Database,
|
|
||||||
error::Error,
|
|
||||||
},
|
|
||||||
askama::Template,
|
askama::Template,
|
||||||
axum::{
|
axum::{routing::get, Router},
|
||||||
extract::State,
|
|
||||||
routing::{get, post},
|
|
||||||
Form, Router,
|
|
||||||
},
|
|
||||||
axum_extra::extract::{
|
|
||||||
cookie::{Cookie, SameSite},
|
|
||||||
CookieJar,
|
|
||||||
},
|
|
||||||
serde::Deserialize,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
pub mod admin;
|
pub mod admin;
|
||||||
|
|
@ -27,59 +14,15 @@ pub struct AppState<D: Database> {
|
||||||
pub fn routes<D: Database>() -> Router<AppState<D>> {
|
pub fn routes<D: Database>() -> Router<AppState<D>> {
|
||||||
Router::new()
|
Router::new()
|
||||||
.route("/", get(root))
|
.route("/", get(root))
|
||||||
.route("/sign_up", get(sign_up))
|
|
||||||
.route("/create_new_user", post(create_new_user))
|
|
||||||
.nest("/admin", admin::routes())
|
.nest("/admin", admin::routes())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Template)]
|
#[derive(Template)]
|
||||||
#[template(path = "index.html")]
|
#[template(path = "index.html")]
|
||||||
struct IndexTemplate {}
|
struct IndexTemplate<'a> {
|
||||||
|
title: &'a str,
|
||||||
#[tracing::instrument]
|
|
||||||
async fn root() -> IndexTemplate {
|
|
||||||
IndexTemplate {}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Template)]
|
async fn root<'a>() -> IndexTemplate<'a> {
|
||||||
#[template(path = "sign-up.html")]
|
IndexTemplate { title: "Locality" }
|
||||||
struct SignUpTemplate {}
|
|
||||||
|
|
||||||
#[tracing::instrument]
|
|
||||||
async fn sign_up() -> SignUpTemplate {
|
|
||||||
SignUpTemplate {}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Template)]
|
|
||||||
#[template(path = "new-user.html")]
|
|
||||||
struct NewUserTemplate {}
|
|
||||||
|
|
||||||
#[derive(Deserialize, Debug)]
|
|
||||||
struct CreateNewUserParameters {
|
|
||||||
real_name: String,
|
|
||||||
email: String,
|
|
||||||
password: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tracing::instrument]
|
|
||||||
async fn create_new_user<D: Database>(
|
|
||||||
cookie_jar: CookieJar,
|
|
||||||
State(AppState::<D> { db, .. }): State<AppState<D>>,
|
|
||||||
Form(params): Form<CreateNewUserParameters>,
|
|
||||||
) -> Result<(CookieJar, NewUserTemplate), Error> {
|
|
||||||
let user = db
|
|
||||||
.create_user(
|
|
||||||
¶ms.real_name,
|
|
||||||
¶ms.email,
|
|
||||||
&Password::new(¶ms.password)?.into(),
|
|
||||||
)
|
|
||||||
.await?;
|
|
||||||
let user = authenticate_user_with_password(&db, user, ¶ms.password)
|
|
||||||
.await?
|
|
||||||
.ok_or_else(|| Error::new_unexpected("Could not authenticate newly-created user."))?;
|
|
||||||
Ok((
|
|
||||||
cookie_jar
|
|
||||||
.add(Cookie::build(("jwt", create_jwt_for_user(&user)?)).same_site(SameSite::Strict)),
|
|
||||||
NewUserTemplate {},
|
|
||||||
))
|
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -15,7 +15,7 @@ use {
|
||||||
|
|
||||||
const COOKIE_EXPIRY_TIME: Duration = Duration::weeks(1);
|
const COOKIE_EXPIRY_TIME: Duration = Duration::weeks(1);
|
||||||
|
|
||||||
#[derive(Debug, PartialEq, Eq)]
|
#[derive(Debug)]
|
||||||
pub enum Error {
|
pub enum Error {
|
||||||
BadJwt,
|
BadJwt,
|
||||||
|
|
||||||
|
|
@ -80,7 +80,7 @@ impl From<digest::MacError> for Error {
|
||||||
type Result<T> = std::result::Result<T, Error>;
|
type Result<T> = std::result::Result<T, Error>;
|
||||||
|
|
||||||
/// Result type for [authenticate_user_with_jwt()].
|
/// Result type for [authenticate_user_with_jwt()].
|
||||||
#[derive(Debug, Eq, PartialEq)]
|
#[derive(Debug)]
|
||||||
pub enum ParsedJwt {
|
pub enum ParsedJwt {
|
||||||
/// JWT is a valid JWT, here is the [AuthenticatedUser].
|
/// JWT is a valid JWT, here is the [AuthenticatedUser].
|
||||||
Valid(AuthenticatedUser),
|
Valid(AuthenticatedUser),
|
||||||
|
|
@ -92,7 +92,7 @@ pub enum ParsedJwt {
|
||||||
UserNotFound,
|
UserNotFound,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
#[derive(Serialize, Deserialize)]
|
||||||
struct Header<'a> {
|
struct Header<'a> {
|
||||||
#[serde(rename = "alg")]
|
#[serde(rename = "alg")]
|
||||||
algorithm: &'a str,
|
algorithm: &'a str,
|
||||||
|
|
@ -100,16 +100,7 @@ struct Header<'a> {
|
||||||
token_type: &'a str,
|
token_type: &'a str,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Header<'static> {
|
#[derive(Serialize, Deserialize)]
|
||||||
fn new() -> Self {
|
|
||||||
Header {
|
|
||||||
algorithm: "HS256",
|
|
||||||
token_type: "JWT",
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
|
||||||
struct Payload {
|
struct Payload {
|
||||||
#[serde(rename = "sub")]
|
#[serde(rename = "sub")]
|
||||||
user_id: UserId,
|
user_id: UserId,
|
||||||
|
|
@ -118,287 +109,68 @@ struct Payload {
|
||||||
expiry: DateTime<Utc>,
|
expiry: DateTime<Utc>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Payload {
|
fn mac() -> Hmac<Sha256> {
|
||||||
fn new(user_id: UserId) -> Self {
|
Hmac::new_from_slice(&get_config().unwrap().hmac_secret)
|
||||||
Self {
|
.expect("HMAC can take key of any size.")
|
||||||
user_id,
|
|
||||||
expiry: Utc::now() + COOKIE_EXPIRY_TIME,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn create_mac(header: &str, payload: &str) -> Hmac<Sha256> {
|
|
||||||
let mut mac = Hmac::new_from_slice(&get_config().unwrap().hmac_secret)
|
|
||||||
.expect("HMAC can take key of any size.");
|
|
||||||
mac.update(format!("{}.{}", header, payload).as_bytes());
|
|
||||||
mac
|
|
||||||
}
|
|
||||||
|
|
||||||
fn create_signature(header: &str, payload: &str) -> String {
|
|
||||||
let mac = create_mac(header, payload);
|
|
||||||
base64_encoder.encode(mac.finalize().into_bytes())
|
|
||||||
}
|
|
||||||
|
|
||||||
fn check_signature(header: &str, payload: &str, signature: &str) -> Result<bool> {
|
|
||||||
let mac = create_mac(header, payload);
|
|
||||||
Ok(mac
|
|
||||||
.verify_slice(&base64_encoder.decode(signature.as_bytes())?)
|
|
||||||
.is_ok())
|
|
||||||
}
|
|
||||||
|
|
||||||
fn sign_jwt(header: &str, payload: &str) -> String {
|
|
||||||
let signature = create_signature(header, payload);
|
|
||||||
format!("{}.{}.{}", header, payload, signature)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn dissassemble_jwt(jwt: &str) -> Result<(&str, &str, &str)> {
|
|
||||||
if let [header, payload, signature] = jwt.split('.').collect::<Vec<_>>().as_slice() {
|
|
||||||
Ok((header, payload, signature))
|
|
||||||
} else {
|
|
||||||
warn!("Invalid JWT");
|
|
||||||
Err(Error::BadJwt)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn encode_jwt_component<T: Serialize>(value: &T) -> Result<String> {
|
|
||||||
Ok(base64_encoder.encode(serde_json::to_string(value)?.as_bytes()))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Given an [AuthenticatedUser], create a JWT for use as a cookie to
|
/// Given an [AuthenticatedUser], create a JWT for use as a cookie to
|
||||||
/// keep that user logged in.
|
/// keep that user logged in.
|
||||||
#[tracing::instrument]
|
#[tracing::instrument]
|
||||||
pub fn create_jwt_for_user(user: &AuthenticatedUser) -> Result<String> {
|
pub fn create_jwt_for_user(user: &AuthenticatedUser) -> Result<String> {
|
||||||
let header = encode_jwt_component(&Header::new())?;
|
let header = base64_encoder.encode(
|
||||||
let payload = encode_jwt_component(&Payload::new(user.get_id()))?;
|
serde_json::to_string(&Header {
|
||||||
Ok(sign_jwt(&header, &payload))
|
algorithm: "HS256",
|
||||||
|
token_type: "JWT",
|
||||||
|
})?
|
||||||
|
.as_bytes(),
|
||||||
|
);
|
||||||
|
let payload = base64_encoder.encode(
|
||||||
|
serde_json::to_string(&Payload {
|
||||||
|
user_id: user.get_id(),
|
||||||
|
expiry: Utc::now() + COOKIE_EXPIRY_TIME,
|
||||||
|
})?
|
||||||
|
.as_bytes(),
|
||||||
|
);
|
||||||
|
let mut mac = mac();
|
||||||
|
mac.update(format!("{}.{}", header, payload).as_bytes());
|
||||||
|
let signature = base64_encoder.encode(mac.finalize().into_bytes());
|
||||||
|
Ok(format!("{}.{}.{}", header, payload, signature))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Given JWT string created by [create_jwt_for_user()], check if the
|
/// Given JWT string created by [create_jwt_for_user()], check if the
|
||||||
/// JWT is valid and return an [AuthenticatedUser] if it is.
|
/// JWT is valid and return an [AuthenticatedUser] if it is.
|
||||||
#[tracing::instrument]
|
#[tracing::instrument]
|
||||||
pub async fn authenticate_user_with_jwt<D: Database>(db: &D, jwt: &str) -> Result<ParsedJwt> {
|
pub async fn authenticate_user_with_jwt<D: Database>(db: &D, jwt: &str) -> Result<ParsedJwt> {
|
||||||
let (header, payload, signature) = dissassemble_jwt(jwt)?;
|
if let [header, payload, signature] = jwt.split('.').collect::<Vec<_>>().as_slice() {
|
||||||
if !check_signature(header, payload, signature)? {
|
let mut mac = mac();
|
||||||
Ok(ParsedJwt::InvalidSignature)
|
mac.update(format!("{}.{}", header, payload).as_bytes());
|
||||||
} else {
|
if mac.verify_slice(signature.as_bytes()).is_err() {
|
||||||
let header_json = String::from_utf8(base64_encoder.decode(header)?)?;
|
Ok(ParsedJwt::InvalidSignature)
|
||||||
let header: Header = serde_json::from_str(&header_json)?;
|
|
||||||
if header.algorithm != "HS256" || header.token_type != "JWT" {
|
|
||||||
warn!("JWT does not have expected algorithm or type.");
|
|
||||||
Err(Error::BadJwt)
|
|
||||||
} else {
|
} else {
|
||||||
let payload: Payload =
|
let header_json = String::from_utf8(base64_encoder.decode(header)?)?;
|
||||||
serde_json::from_str(&String::from_utf8(base64_encoder.decode(payload)?)?)?;
|
let header: Header = serde_json::from_str(&header_json)?;
|
||||||
Ok(
|
if header.algorithm != "HS256" || header.token_type != "JWT" {
|
||||||
if let Some(user) = db.get_user_with_id(payload.user_id).await? {
|
warn!("JWT does not have expected algorithm or type.");
|
||||||
if payload.expiry > Utc::now() {
|
Err(Error::BadJwt)
|
||||||
ParsedJwt::Valid(AuthenticatedUser(user))
|
} else {
|
||||||
|
let payload: Payload =
|
||||||
|
serde_json::from_str(&String::from_utf8(base64_encoder.decode(payload)?)?)?;
|
||||||
|
Ok(dbg!(
|
||||||
|
if let Some(user) = db.get_user_with_id(payload.user_id).await? {
|
||||||
|
if payload.expiry < Utc::now() {
|
||||||
|
ParsedJwt::Valid(AuthenticatedUser(user))
|
||||||
|
} else {
|
||||||
|
ParsedJwt::Expired(user)
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
ParsedJwt::Expired(user)
|
ParsedJwt::UserNotFound
|
||||||
}
|
},
|
||||||
} else {
|
))
|
||||||
ParsedJwt::UserNotFound
|
}
|
||||||
},
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
}
|
} else {
|
||||||
}
|
warn!("Invalid JWT");
|
||||||
|
Err(Error::BadJwt)
|
||||||
#[cfg(test)]
|
|
||||||
mod tests {
|
|
||||||
use super::*;
|
|
||||||
use crate::{db, db::fake::FakeDatabase};
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn created_jwt_authenticates() {
|
|
||||||
let db = FakeDatabase::new_empty();
|
|
||||||
let user = AuthenticatedUser(
|
|
||||||
db.create_user(
|
|
||||||
"John Smith",
|
|
||||||
"john.smith@example.com",
|
|
||||||
&db::PasswordHash("abc123".to_string()),
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
.unwrap(),
|
|
||||||
);
|
|
||||||
let jwt = create_jwt_for_user(&user).unwrap();
|
|
||||||
assert_eq!(
|
|
||||||
ParsedJwt::Valid(user),
|
|
||||||
authenticate_user_with_jwt(&db, &jwt).await.unwrap()
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn modified_jwt_does_not_authenticate() {
|
|
||||||
let db = FakeDatabase::new_empty();
|
|
||||||
let user1 = AuthenticatedUser(
|
|
||||||
db.create_user(
|
|
||||||
"John Smith",
|
|
||||||
"john.smith@example.com",
|
|
||||||
&db::PasswordHash("abc123".to_string()),
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
.unwrap(),
|
|
||||||
);
|
|
||||||
let jwt1 = create_jwt_for_user(&user1).unwrap();
|
|
||||||
// Create another user with same name, etc. but different user id
|
|
||||||
let user2 = AuthenticatedUser(
|
|
||||||
db.create_user(
|
|
||||||
"John Smith",
|
|
||||||
"john.smith@example.com",
|
|
||||||
&db::PasswordHash("abc123".to_string()),
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
.unwrap(),
|
|
||||||
);
|
|
||||||
let jwt2 = create_jwt_for_user(&user2).unwrap();
|
|
||||||
// Attach signature 2 to jwt 2
|
|
||||||
let parts1: Vec<_> = jwt1.split('.').collect();
|
|
||||||
assert_eq!(3, parts1.len());
|
|
||||||
let parts2: Vec<_> = jwt2.split('.').collect();
|
|
||||||
assert_eq!(3, parts2.len());
|
|
||||||
let parts3 = vec![parts2[0], parts2[1], parts1[2]];
|
|
||||||
|
|
||||||
//Rebuild all JWTs, so the orignial ones can act as a control
|
|
||||||
let jwt1 = parts1.join(".");
|
|
||||||
let jwt2 = parts2.join(".");
|
|
||||||
let jwt3 = parts3.join(".");
|
|
||||||
assert_eq!(
|
|
||||||
ParsedJwt::Valid(user1),
|
|
||||||
authenticate_user_with_jwt(&db, &jwt1).await.unwrap()
|
|
||||||
);
|
|
||||||
assert_eq!(
|
|
||||||
ParsedJwt::Valid(user2),
|
|
||||||
authenticate_user_with_jwt(&db, &jwt2).await.unwrap()
|
|
||||||
);
|
|
||||||
assert_eq!(
|
|
||||||
ParsedJwt::InvalidSignature,
|
|
||||||
authenticate_user_with_jwt(&db, &jwt3).await.unwrap()
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn deleted_user_cannot_authenticate() {
|
|
||||||
let db = FakeDatabase::new_empty();
|
|
||||||
let user = AuthenticatedUser(
|
|
||||||
db.create_user(
|
|
||||||
"John Smith",
|
|
||||||
"john.smith@example.com",
|
|
||||||
&db::PasswordHash("abc123".to_string()),
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
.unwrap(),
|
|
||||||
);
|
|
||||||
let jwt = create_jwt_for_user(&user).unwrap();
|
|
||||||
assert_eq!(
|
|
||||||
ParsedJwt::Valid(user.clone()),
|
|
||||||
authenticate_user_with_jwt(&db, &jwt).await.unwrap()
|
|
||||||
);
|
|
||||||
|
|
||||||
db.delete_user(user.get_id()).await.unwrap();
|
|
||||||
assert_eq!(
|
|
||||||
ParsedJwt::UserNotFound,
|
|
||||||
authenticate_user_with_jwt(&db, &jwt).await.unwrap()
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn unsigned_jwt_is_rejected() {
|
|
||||||
let db = FakeDatabase::new_empty();
|
|
||||||
let user = AuthenticatedUser(
|
|
||||||
db.create_user(
|
|
||||||
"John Smith",
|
|
||||||
"john.smith@example.com",
|
|
||||||
&db::PasswordHash("abc123".to_string()),
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
.unwrap(),
|
|
||||||
);
|
|
||||||
let jwt = format!(
|
|
||||||
"{},{}",
|
|
||||||
encode_jwt_component(&Header::new()).unwrap(),
|
|
||||||
encode_jwt_component(&Payload::new(user.get_id())).unwrap()
|
|
||||||
);
|
|
||||||
assert_eq!(
|
|
||||||
Err(Error::BadJwt),
|
|
||||||
authenticate_user_with_jwt(&db, &jwt).await
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn jwt_with_unrecognized_algorithm_is_rejected() {
|
|
||||||
let db = FakeDatabase::new_empty();
|
|
||||||
let user = AuthenticatedUser(
|
|
||||||
db.create_user(
|
|
||||||
"John Smith",
|
|
||||||
"john.smith@example.com",
|
|
||||||
&db::PasswordHash("abc123".to_string()),
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
.unwrap(),
|
|
||||||
);
|
|
||||||
let jwt = sign_jwt(
|
|
||||||
&encode_jwt_component(&Header {
|
|
||||||
algorithm: "RS256",
|
|
||||||
token_type: "JWT",
|
|
||||||
})
|
|
||||||
.unwrap(),
|
|
||||||
&encode_jwt_component(&Payload::new(user.get_id())).unwrap(),
|
|
||||||
);
|
|
||||||
assert_eq!(
|
|
||||||
Err(Error::BadJwt),
|
|
||||||
authenticate_user_with_jwt(&db, &jwt).await
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn jwt_with_unrecognized_token_type_is_rejected() {
|
|
||||||
let db = FakeDatabase::new_empty();
|
|
||||||
let user = AuthenticatedUser(
|
|
||||||
db.create_user(
|
|
||||||
"John Smith",
|
|
||||||
"john.smith@example.com",
|
|
||||||
&db::PasswordHash("abc123".to_string()),
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
.unwrap(),
|
|
||||||
);
|
|
||||||
let jwt = sign_jwt(
|
|
||||||
&encode_jwt_component(&Header {
|
|
||||||
algorithm: "HS256",
|
|
||||||
token_type: "RWT",
|
|
||||||
})
|
|
||||||
.unwrap(),
|
|
||||||
&encode_jwt_component(&Payload::new(user.get_id())).unwrap(),
|
|
||||||
);
|
|
||||||
assert_eq!(
|
|
||||||
Err(Error::BadJwt),
|
|
||||||
authenticate_user_with_jwt(&db, &jwt).await
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn expired_jwt_is_rejected() {
|
|
||||||
let db = FakeDatabase::new_empty();
|
|
||||||
let user = db
|
|
||||||
.create_user(
|
|
||||||
"John Smith",
|
|
||||||
"john.smith@example.com",
|
|
||||||
&db::PasswordHash("abc123".to_string()),
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
let jwt = sign_jwt(
|
|
||||||
&encode_jwt_component(&Header::new()).unwrap(),
|
|
||||||
&encode_jwt_component(&Payload {
|
|
||||||
user_id: user.get_id(),
|
|
||||||
expiry: Utc::now() - COOKIE_EXPIRY_TIME - Duration::seconds(1),
|
|
||||||
})
|
|
||||||
.unwrap(),
|
|
||||||
);
|
|
||||||
assert_eq!(
|
|
||||||
Ok(ParsedJwt::Expired(user)),
|
|
||||||
authenticate_user_with_jwt(&db, &jwt).await
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -54,7 +54,7 @@ impl From<argon2::password_hash::Error> for AuthenticationError {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Eq, PartialEq)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct AuthenticatedUser(db::User);
|
pub struct AuthenticatedUser(db::User);
|
||||||
|
|
||||||
impl Deref for AuthenticatedUser {
|
impl Deref for AuthenticatedUser {
|
||||||
|
|
@ -81,7 +81,6 @@ pub struct Password {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Password {
|
impl Password {
|
||||||
#[tracing::instrument]
|
|
||||||
pub fn new(password: &str) -> Result<Password, AuthenticationError> {
|
pub fn new(password: &str) -> Result<Password, AuthenticationError> {
|
||||||
let salt = SaltString::generate(&mut OsRng);
|
let salt = SaltString::generate(&mut OsRng);
|
||||||
let argon2 = Argon2::default();
|
let argon2 = Argon2::default();
|
||||||
|
|
|
||||||
|
|
@ -1,10 +1,7 @@
|
||||||
use super::*;
|
use super::*;
|
||||||
use {
|
use {
|
||||||
std::collections::{HashMap, HashSet},
|
std::collections::HashSet,
|
||||||
std::{
|
std::sync::{Arc, Mutex},
|
||||||
hash::Hash,
|
|
||||||
sync::{Arc, Mutex},
|
|
||||||
},
|
|
||||||
};
|
};
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
|
|
@ -14,63 +11,16 @@ struct UserRow {
|
||||||
password_hash: String,
|
password_hash: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
trait Id: Copy + Eq + Hash {
|
|
||||||
fn first() -> Self;
|
|
||||||
fn increment(&mut self);
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Id for UserId {
|
|
||||||
fn first() -> Self {
|
|
||||||
UserId(0)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn increment(&mut self) {
|
|
||||||
self.0 += 1;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug)]
|
|
||||||
struct Table<Id, Row> {
|
|
||||||
data: HashMap<Id, Row>,
|
|
||||||
next_id: Id,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<I, R> Table<I, R>
|
|
||||||
where
|
|
||||||
I: Id,
|
|
||||||
{
|
|
||||||
fn new() -> Self {
|
|
||||||
Self {
|
|
||||||
data: HashMap::new(),
|
|
||||||
next_id: I::first(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
fn insert(&mut self, row: R) -> I {
|
|
||||||
let id = self.next_id;
|
|
||||||
self.data.insert(id, row);
|
|
||||||
self.next_id.increment();
|
|
||||||
id
|
|
||||||
}
|
|
||||||
|
|
||||||
fn delete(&mut self, id: I) {
|
|
||||||
self.data.remove(&id);
|
|
||||||
}
|
|
||||||
|
|
||||||
fn get(&self, id: &I) -> Option<&R> {
|
|
||||||
self.data.get(id)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct FakeDatabase {
|
pub struct FakeDatabase {
|
||||||
users: Arc<Mutex<Table<UserId, UserRow>>>,
|
users: Arc<Mutex<Vec<UserRow>>>,
|
||||||
admin_users: Arc<Mutex<std::collections::HashSet<usize>>>,
|
admin_users: Arc<Mutex<std::collections::HashSet<usize>>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl FakeDatabase {
|
impl FakeDatabase {
|
||||||
pub fn new_empty() -> Self {
|
pub fn new_empty() -> Self {
|
||||||
FakeDatabase {
|
FakeDatabase {
|
||||||
users: Arc::new(Mutex::new(Table::new())),
|
users: Arc::new(Mutex::new(Vec::new())),
|
||||||
admin_users: Arc::new(Mutex::new(HashSet::new())),
|
admin_users: Arc::new(Mutex::new(HashSet::new())),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -92,26 +42,20 @@ impl Database for FakeDatabase {
|
||||||
password: &PasswordHash,
|
password: &PasswordHash,
|
||||||
) -> Result<User> {
|
) -> Result<User> {
|
||||||
let mut users = self.users.lock().unwrap();
|
let mut users = self.users.lock().unwrap();
|
||||||
let new_id = users.insert(UserRow {
|
users.push(UserRow {
|
||||||
real_name: real_name.to_string(),
|
real_name: real_name.to_string(),
|
||||||
email: email.to_string(),
|
email: email.to_string(),
|
||||||
password_hash: password.to_string(),
|
password_hash: password.to_string(),
|
||||||
});
|
});
|
||||||
Ok(User {
|
Ok(User {
|
||||||
id: new_id,
|
id: UserId((users.len() - 1) as i32),
|
||||||
real_name: real_name.to_string(),
|
real_name: real_name.to_string(),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn delete_user(&self, user_id: UserId) -> Result<()> {
|
|
||||||
let mut users = self.users.lock().unwrap();
|
|
||||||
users.delete(user_id);
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn get_password_for_user(&self, user: &User) -> Result<PasswordHash> {
|
async fn get_password_for_user(&self, user: &User) -> Result<PasswordHash> {
|
||||||
let users = self.users.lock().unwrap();
|
let users = self.users.lock().unwrap();
|
||||||
if let Some(UserRow { password_hash, .. }) = users.get(&user.id) {
|
if let Some(UserRow { password_hash, .. }) = users.get(user.id.0 as usize) {
|
||||||
Ok(PasswordHash(password_hash.clone()))
|
Ok(PasswordHash(password_hash.clone()))
|
||||||
} else {
|
} else {
|
||||||
Err(Error::Database)
|
Err(Error::Database)
|
||||||
|
|
@ -137,10 +81,12 @@ impl Database for FakeDatabase {
|
||||||
|
|
||||||
async fn get_user_with_id(&self, user_id: UserId) -> Result<Option<User>> {
|
async fn get_user_with_id(&self, user_id: UserId) -> Result<Option<User>> {
|
||||||
let users = self.users.lock().unwrap();
|
let users = self.users.lock().unwrap();
|
||||||
Ok(users.get(&user_id).map(|UserRow { real_name, .. }| User {
|
Ok(users
|
||||||
id: user_id,
|
.get(user_id.0 as usize)
|
||||||
real_name: real_name.clone(),
|
.map(|UserRow { real_name, .. }| User {
|
||||||
}))
|
id: user_id,
|
||||||
|
real_name: real_name.clone(),
|
||||||
|
}))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -220,7 +220,6 @@ mod tests {
|
||||||
PostgresDatabase::new(&url).unwrap()
|
PostgresDatabase::new(&url).unwrap()
|
||||||
}
|
}
|
||||||
|
|
||||||
#[ignore]
|
|
||||||
#[test]
|
#[test]
|
||||||
fn migrations_have_sequential_versions() {
|
fn migrations_have_sequential_versions() {
|
||||||
for i in 0..MIGRATIONS.len() {
|
for i in 0..MIGRATIONS.len() {
|
||||||
|
|
@ -228,7 +227,6 @@ mod tests {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[ignore]
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn migrate_up_and_down_all() {
|
async fn migrate_up_and_down_all() {
|
||||||
let db = test_db().await;
|
let db = test_db().await;
|
||||||
|
|
|
||||||
|
|
@ -98,7 +98,6 @@ pub trait Database: std::fmt::Debug + Clone + Send + Sync + 'static {
|
||||||
email: &str,
|
email: &str,
|
||||||
password: &PasswordHash,
|
password: &PasswordHash,
|
||||||
) -> impl Future<Output = Result<User>> + Send;
|
) -> impl Future<Output = Result<User>> + Send;
|
||||||
fn delete_user(&self, user: UserId) -> impl Future<Output = Result<()>> + Send;
|
|
||||||
fn get_password_for_user(
|
fn get_password_for_user(
|
||||||
&self,
|
&self,
|
||||||
user: &User,
|
user: &User,
|
||||||
|
|
@ -124,10 +123,10 @@ pub struct PostgresDatabase {
|
||||||
connection_pool: Pool,
|
connection_pool: Pool,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Copy, Debug, Deserialize, Eq, Hash, PartialEq, Serialize)]
|
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
|
||||||
pub struct UserId(i32);
|
pub struct UserId(i32);
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct User {
|
pub struct User {
|
||||||
id: UserId,
|
id: UserId,
|
||||||
pub real_name: String,
|
pub real_name: String,
|
||||||
|
|
@ -207,11 +206,6 @@ impl Database for PostgresDatabase {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tracing::instrument]
|
|
||||||
async fn delete_user(&self, user: UserId) -> Result<()> {
|
|
||||||
todo!()
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tracing::instrument]
|
#[tracing::instrument]
|
||||||
async fn get_password_for_user(&self, user: &User) -> Result<PasswordHash> {
|
async fn get_password_for_user(&self, user: &User) -> Result<PasswordHash> {
|
||||||
let client = self.get_client().await?;
|
let client = self.get_client().await?;
|
||||||
|
|
|
||||||
10
src/error.rs
10
src/error.rs
|
|
@ -130,14 +130,20 @@ impl std::error::Error for Error {
|
||||||
|
|
||||||
#[derive(Template)]
|
#[derive(Template)]
|
||||||
#[template(path = "error.html")]
|
#[template(path = "error.html")]
|
||||||
struct ErrorTemplate {}
|
struct ErrorTemplate<'a> {
|
||||||
|
title: &'a str,
|
||||||
|
}
|
||||||
|
|
||||||
impl IntoResponse for Error {
|
impl IntoResponse for Error {
|
||||||
fn into_response(self) -> Response {
|
fn into_response(self) -> Response {
|
||||||
match self.error_type {
|
match self.error_type {
|
||||||
ErrorType::InternalServerError => {
|
ErrorType::InternalServerError => {
|
||||||
error!(inner = self.inner, "Uncaught error producing HTTP 500.");
|
error!(inner = self.inner, "Uncaught error producing HTTP 500.");
|
||||||
(StatusCode::INTERNAL_SERVER_ERROR, ErrorTemplate {}).into_response()
|
(
|
||||||
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
|
ErrorTemplate { title: "Error" },
|
||||||
|
)
|
||||||
|
.into_response()
|
||||||
}
|
}
|
||||||
ErrorType::Unauthorized => {
|
ErrorType::Unauthorized => {
|
||||||
(StatusCode::UNAUTHORIZED, "User not authorized.").into_response()
|
(StatusCode::UNAUTHORIZED, "User not authorized.").into_response()
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,7 @@
|
||||||
<head>
|
<head>
|
||||||
<meta charset="utf-8" />
|
<meta charset="utf-8" />
|
||||||
<meta name="viewport" content="width=device-width" />
|
<meta name="viewport" content="width=device-width" />
|
||||||
<title>{% block title %}Locality{% endblock %}</title>
|
<title>{% block title %}{{title}} - Locality{% endblock %}</title>
|
||||||
|
|
||||||
<link rel="icon" type="image/png" sizes="32x32" href="/static/favicon-32.png">
|
<link rel="icon" type="image/png" sizes="32x32" href="/static/favicon-32.png">
|
||||||
<link rel="icon" type="image/png" sizes="16x16" href="/static/favicon-16.png">
|
<link rel="icon" type="image/png" sizes="16x16" href="/static/favicon-16.png">
|
||||||
|
|
|
||||||
|
|
@ -1,7 +0,0 @@
|
||||||
{% extends "base.html" %}
|
|
||||||
|
|
||||||
{% block title %}New User Welcome - Locality{% endblock %}
|
|
||||||
|
|
||||||
{% block content %}
|
|
||||||
TBD
|
|
||||||
{% endblock %}
|
|
||||||
|
|
@ -1,25 +0,0 @@
|
||||||
{% extends "base.html" %}
|
|
||||||
|
|
||||||
{% block title %}New User Sign-Up - Locality{% endblock %}
|
|
||||||
|
|
||||||
{% block content %}
|
|
||||||
<form action="create_new_user" method="post">>
|
|
||||||
<ul>
|
|
||||||
<li>
|
|
||||||
<label for="real_name">Name:</label>
|
|
||||||
<input type="text" id="real_name" name="real_name" />
|
|
||||||
</li>
|
|
||||||
<li>
|
|
||||||
<label for="email">Email:</label>
|
|
||||||
<input type="email" id="email" name="email" />
|
|
||||||
</li>
|
|
||||||
<li>
|
|
||||||
<label for="password">Password:</label>
|
|
||||||
<input type="password" id="password" name="password" />
|
|
||||||
</li>
|
|
||||||
<li>
|
|
||||||
<button type="submit">Create Account</button>
|
|
||||||
</li>
|
|
||||||
</ul>
|
|
||||||
</form>
|
|
||||||
{% endblock %}
|
|
||||||
Loading…
Reference in New Issue