changed key storage interface

TotallyNot 2023-01-27 17:25:07 +01:00
parent cff93d4c3d
commit 91bfb08652
5 changed files with 257 additions and 131 deletions

View file

@@ -1 +1,2 @@
 edition = "2021"
+format_strings = true

View file

@@ -1,6 +1,6 @@
 [package]
 name = "torn-key-pool"
-version = "0.5.2"
+version = "0.5.3"
 edition = "2021"
 authors = ["Pyrit [2111649]"]
 license = "MIT"

View file

@@ -30,7 +30,11 @@ where
 }
 
 pub trait ApiKey: Sync + Send {
+    type IdType: PartialEq + Eq + std::hash::Hash;
+
     fn value(&self) -> &str;
+
+    fn id(&self) -> Self::IdType;
 }
 
 pub trait KeyDomain: Clone + std::fmt::Debug + Send + Sync {
@@ -39,6 +43,15 @@ pub trait KeyDomain: Clone + std::fmt::Debug + Send + Sync {
     }
 }
 
+#[derive(Debug, Clone)]
+pub enum KeySelector<K>
+where
+    K: ApiKey,
+{
+    Key(String),
+    Id(K::IdType),
+}
+
 #[async_trait]
 pub trait KeyPoolStorage {
     type Key: ApiKey;
@@ -62,27 +75,27 @@ pub trait KeyPoolStorage {
         domains: Vec<Self::Domain>,
     ) -> Result<Self::Key, Self::Error>;
 
-    async fn read_key(&self, key: String) -> Result<Self::Key, Self::Error>;
+    async fn read_key(&self, key: KeySelector<Self::Key>) -> Result<Self::Key, Self::Error>;
 
     async fn read_user_keys(&self, user_id: i32) -> Result<Vec<Self::Key>, Self::Error>;
 
-    async fn remove_key(&self, key: String) -> Result<Self::Key, Self::Error>;
+    async fn remove_key(&self, key: KeySelector<Self::Key>) -> Result<Self::Key, Self::Error>;
 
     async fn add_domain_to_key(
         &self,
-        key: String,
+        key: KeySelector<Self::Key>,
         domain: Self::Domain,
     ) -> Result<Self::Key, Self::Error>;
 
     async fn remove_domain_from_key(
         &self,
-        key: String,
+        key: KeySelector<Self::Key>,
        domain: Self::Domain,
     ) -> Result<Self::Key, Self::Error>;
 
     async fn set_domains_for_key(
         &self,
-        key: String,
+        key: KeySelector<Self::Key>,
         domains: Vec<Self::Domain>,
     ) -> Result<Self::Key, Self::Error>;
 }
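Note: a minimal sketch (not part of this commit) of how a caller might drive the reworked interface, selecting a key either by its literal value or by the id that ApiKey::id() now exposes. The helper function and its name are illustrative assumptions only.

// Hypothetical helper; `storage` is any KeyPoolStorage implementation.
async fn lookup_both_ways<S>(storage: &S, value: String) -> Result<(), S::Error>
where
    S: KeyPoolStorage,
{
    // Select by the key's literal value...
    let by_value = storage.read_key(KeySelector::Key(value)).await?;
    // ...or by the id exposed through the new ApiKey::id() accessor.
    let by_id = storage.read_key(KeySelector::Id(by_value.id())).await?;
    assert_eq!(by_value.value(), by_id.value());
    Ok(())
}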

View file

@@ -3,7 +3,7 @@ use indoc::indoc;
 use sqlx::{FromRow, PgPool};
 use thiserror::Error;
 
-use crate::{ApiKey, KeyDomain, KeyPoolStorage};
+use crate::{ApiKey, KeyDomain, KeyPoolStorage, KeySelector};
 
 pub trait PgKeyDomain:
     KeyDomain + serde::Serialize + serde::de::DeserializeOwned + Eq + Unpin
@@ -18,7 +18,7 @@ impl<T> PgKeyDomain for T where
 #[derive(Debug, Error)]
 pub enum PgStorageError<D>
 where
-    D: std::fmt::Debug,
+    D: PgKeyDomain,
 {
     #[error(transparent)]
     Pg(#[from] sqlx::Error),
@@ -26,14 +26,8 @@ where
     #[error("No key avalaible for domain {0:?}")]
     Unavailable(D),
 
-    #[error("Duplicate key '{0}'")]
-    DuplicateKey(String),
-
-    #[error("Duplicate domain '{0:?}'")]
-    DuplicateDomain(D),
-
-    #[error("Key not found: '{0}'")]
-    KeyNotFound(String),
+    #[error("Key not found: '{0:?}'")]
+    KeyNotFound(KeySelector<PgKey<D>>),
 }
 
 #[derive(Debug, Clone, FromRow)]
@@ -62,9 +56,17 @@ impl<D> ApiKey for PgKey<D>
 where
     D: PgKeyDomain,
 {
+    type IdType = i32;
+
+    #[inline(always)]
     fn value(&self) -> &str {
         &self.key
     }
+
+    #[inline(always)]
+    fn id(&self) -> Self::IdType {
+        self.id
+    }
 }
 
 impl<D> PgKeyPoolStorage<D>
@@ -90,7 +92,7 @@ where
                 last_used timestamptz not null default now(),
                 flag int2,
                 cooldown timestamptz,
-                constraint "uq:api_keys.key+user_id" UNIQUE(user_id, key)
+                constraint "uq:api_keys.key" UNIQUE(key)
             )"#
         })
         .execute(&self.pool)
@@ -108,6 +110,28 @@ where
         .execute(&self.pool)
         .await?;
 
+        sqlx::query(indoc! {r#"
+            create or replace function __unique_jsonb_array(jsonb) returns jsonb
+            AS $$
+                select jsonb_agg(d::jsonb) from (
+                    select distinct jsonb_array_elements_text($1) as d
+                ) t
+            $$ language sql;
+        "#})
+        .execute(&self.pool)
+        .await?;
+
+        sqlx::query(indoc! {r#"
+            create or replace function __filter_jsonb_array(jsonb, jsonb) returns jsonb
+            AS $$
+                select jsonb_agg(d::jsonb) from (
+                    select distinct jsonb_array_elements_text($1) as d
+                ) t where d<>$2::text
+            $$ language sql;
+        "#})
+        .execute(&self.pool)
+        .await?;
+
         Ok(())
     }
 }
@@ -351,10 +375,13 @@ where
             }
             5 => {
                 // too many requests
-                sqlx::query("update api_keys set cooldown=date_trunc('min', now()) + interval '1 min', flag=5 where id=$1")
-                    .bind(key.id)
-                    .execute(&self.pool)
-                    .await?;
+                sqlx::query(
+                    "update api_keys set cooldown=date_trunc('min', now()) + interval '1 min', \
+                     flag=5 where id=$1",
+                )
+                .bind(key.id)
+                .execute(&self.pool)
+                .await?;
                 Ok(true)
             }
             8 => {
@@ -373,10 +400,13 @@ where
             }
             14 => {
                 // daily read limit reached
-                sqlx::query("update api_keys set cooldown=date_trunc('day', now()) + interval '1 day', flag=14 where id=$1")
-                    .bind(key.id)
-                    .execute(&self.pool)
-                    .await?;
+                sqlx::query(
+                    "update api_keys set cooldown=date_trunc('day', now()) + interval '1 day', \
+                     flag=14 where id=$1",
+                )
+                .bind(key.id)
+                .execute(&self.pool)
+                .await?;
                 Ok(true)
             }
             _ => Ok(false),
@@ -390,30 +420,31 @@ where
         domains: Vec<D>,
     ) -> Result<Self::Key, Self::Error> {
         sqlx::query_as(
-            "insert into api_keys(user_id, key, domains) values ($1, $2, $3) returning *",
+            "insert into api_keys(user_id, key, domains) values ($1, $2, $3) on conflict on \
+             constraint \"uq:api_keys.key\" do update set domains = \
+             __unique_jsonb_array(excluded.domains || api_keys.domains) returning *",
         )
         .bind(user_id)
         .bind(&key)
         .bind(sqlx::types::Json(domains))
         .fetch_one(&self.pool)
         .await
-        .map_err(|why| {
-            if let Some(error) = why.as_database_error() {
-                let pg_error: &sqlx::postgres::PgDatabaseError = error.downcast_ref();
-                if pg_error.code() == "23505" {
-                    return PgStorageError::DuplicateKey(key);
-                }
-            }
-            PgStorageError::Pg(why)
-        })
+        .map_err(Into::into)
     }
 
-    async fn read_key(&self, key: String) -> Result<Self::Key, Self::Error> {
-        sqlx::query_as("select * from api_keys where key=$1")
-            .bind(&key)
-            .fetch_optional(&self.pool)
-            .await?
-            .ok_or_else(|| PgStorageError::KeyNotFound(key))
+    async fn read_key(&self, selector: KeySelector<Self::Key>) -> Result<Self::Key, Self::Error> {
+        match &selector {
+            KeySelector::Key(key) => sqlx::query_as("select * from api_keys where key=$1")
+                .bind(key)
+                .fetch_optional(&self.pool)
+                .await?
+                .ok_or_else(|| PgStorageError::KeyNotFound(selector)),
+            KeySelector::Id(id) => sqlx::query_as("select * from api_keys where id=$1")
+                .bind(id)
+                .fetch_optional(&self.pool)
+                .await?
+                .ok_or_else(|| PgStorageError::KeyNotFound(selector)),
+        }
     }
 
     async fn read_user_keys(&self, user_id: i32) -> Result<Vec<Self::Key>, Self::Error> {
@@ -424,66 +455,101 @@ where
         .map_err(Into::into)
     }
 
-    async fn remove_key(&self, key: String) -> Result<Self::Key, Self::Error> {
-        sqlx::query_as("delete from api_keys where key=$1 returning *")
-            .bind(&key)
-            .fetch_optional(&self.pool)
-            .await?
-            .ok_or_else(|| PgStorageError::KeyNotFound(key))
+    async fn remove_key(&self, selector: KeySelector<Self::Key>) -> Result<Self::Key, Self::Error> {
+        match &selector {
+            KeySelector::Key(key) => {
+                sqlx::query_as("delete from api_keys where key=$1 returning *")
+                    .bind(key)
+                    .fetch_optional(&self.pool)
+                    .await?
+                    .ok_or_else(|| PgStorageError::KeyNotFound(selector))
+            }
+            KeySelector::Id(id) => sqlx::query_as("delete from api_keys where id=$1 returning *")
+                .bind(id)
+                .fetch_optional(&self.pool)
+                .await?
+                .ok_or_else(|| PgStorageError::KeyNotFound(selector)),
+        }
     }
 
-    async fn add_domain_to_key(&self, key: String, domain: D) -> Result<Self::Key, Self::Error> {
-        let mut tx = self.pool.begin().await?;
-        match sqlx::query_as::<sqlx::Postgres, PgKey<D>>(
-            "update api_keys set domains = domains || jsonb_build_array($1) where key=$2 returning *",
-        )
-        .bind(sqlx::types::Json(domain.clone()))
-        .bind(&key)
-        .fetch_optional(&mut tx)
-        .await?
-        {
-            None => Err(PgStorageError::KeyNotFound(key)),
-            Some(key) => {
-                if key.domains.0.iter().filter(|d| **d == domain).count() > 1 {
-                    tx.rollback().await?;
-                    return Err(PgStorageError::DuplicateDomain(domain));
-                }
-
-                tx.commit().await?;
-                Ok(key)
-            }
-        }
+    async fn add_domain_to_key(
+        &self,
+        selector: KeySelector<Self::Key>,
+        domain: D,
+    ) -> Result<Self::Key, Self::Error> {
+        match &selector {
+            KeySelector::Key(key) => sqlx::query_as::<sqlx::Postgres, PgKey<D>>(
+                "update api_keys set domains = __unique_jsonb_array(domains || \
+                 jsonb_build_array($1)) where key=$2 returning *",
+            )
+            .bind(sqlx::types::Json(domain))
+            .bind(key)
+            .fetch_optional(&self.pool)
+            .await?
+            .ok_or_else(|| PgStorageError::KeyNotFound(selector)),
+            KeySelector::Id(id) => sqlx::query_as::<sqlx::Postgres, PgKey<D>>(
+                "update api_keys set domains = __unique_jsonb_array(domains || \
+                 jsonb_build_array($1)) where id=$2 returning *",
+            )
+            .bind(sqlx::types::Json(domain))
+            .bind(id)
+            .fetch_optional(&self.pool)
+            .await?
+            .ok_or_else(|| PgStorageError::KeyNotFound(selector)),
+        }
     }
 
     async fn remove_domain_from_key(
         &self,
-        key: String,
+        selector: KeySelector<Self::Key>,
         domain: D,
     ) -> Result<Self::Key, Self::Error> {
-        // FIX: potential race condition
-        let api_key = self.read_key(key.clone()).await?;
-        let domains = api_key
-            .domains
-            .0
-            .into_iter()
-            .filter(|d| *d != domain)
-            .collect();
-
-        self.set_domains_for_key(key, domains).await
+        match &selector {
+            KeySelector::Key(key) => sqlx::query_as(
+                "update api_keys set domains = coalesce(__filter_jsonb_array(domains, $1), \
+                 '[]'::jsonb) where key=$2 returning *",
+            )
+            .bind(sqlx::types::Json(domain))
+            .bind(key)
+            .fetch_optional(&self.pool)
+            .await?
+            .ok_or_else(|| PgStorageError::KeyNotFound(selector)),
+            KeySelector::Id(id) => sqlx::query_as(
+                "update api_keys set domains = coalesce(__filter_jsonb_array(domains, $1), \
+                 '[]'::jsonb) where id=$2 returning *",
+            )
+            .bind(sqlx::types::Json(domain))
+            .bind(id)
+            .fetch_optional(&self.pool)
+            .await?
+            .ok_or_else(|| PgStorageError::KeyNotFound(selector)),
+        }
     }
 
     async fn set_domains_for_key(
         &self,
-        key: String,
+        selector: KeySelector<Self::Key>,
         domains: Vec<D>,
     ) -> Result<Self::Key, Self::Error> {
-        sqlx::query_as::<sqlx::Postgres, PgKey<D>>(
-            "update api_keys set domains = $1 where key=$2 returning *",
-        )
-        .bind(sqlx::types::Json(domains))
-        .bind(&key)
-        .fetch_optional(&self.pool)
-        .await?
-        .ok_or_else(|| PgStorageError::KeyNotFound(key))
+        match &selector {
+            KeySelector::Key(key) => sqlx::query_as::<sqlx::Postgres, PgKey<D>>(
+                "update api_keys set domains = $1 where key=$2 returning *",
+            )
+            .bind(sqlx::types::Json(domains))
+            .bind(key)
+            .fetch_optional(&self.pool)
+            .await?
+            .ok_or_else(|| PgStorageError::KeyNotFound(selector)),
+            KeySelector::Id(id) => sqlx::query_as::<sqlx::Postgres, PgKey<D>>(
+                "update api_keys set domains = $1 where id=$2 returning *",
+            )
+            .bind(sqlx::types::Json(domains))
+            .bind(id)
+            .fetch_optional(&self.pool)
+            .await?
+            .ok_or_else(|| PgStorageError::KeyNotFound(selector)),
+        }
     }
 }
@@ -516,7 +582,7 @@ pub(crate) mod test {
         }
    }
 
-    pub(crate) async fn setup() -> PgKeyPoolStorage<Domain> {
+    pub(crate) async fn setup() -> (PgKeyPoolStorage<Domain>, PgKey<Domain>) {
         INIT.call_once(|| {
             dotenv::dotenv().ok();
         });
@@ -533,17 +599,17 @@ pub(crate) mod test {
         let storage = PgKeyPoolStorage::new(pool.clone(), 1000);
         storage.initialise().await.unwrap();
 
-        storage
+        let key = storage
             .store_key(1, std::env::var("APIKEY").unwrap(), vec![Domain::All])
             .await
             .unwrap();
 
-        storage
+        (storage, key)
     }
 
     #[test]
     async fn test_initialise() {
-        let storage = setup().await;
+        let (storage, _) = setup().await;
 
         if let Err(e) = storage.initialise().await {
             panic!("Initialising key storage failed: {:?}", e);
@@ -551,25 +617,43 @@ pub(crate) mod test {
     }
 
     #[test]
-    async fn test_store_duplicate() {
-        let storage = setup().await;
-
-        match storage
-            .store_key(1, std::env::var("APIKEY").unwrap(), vec![])
-            .await
-            .unwrap_err()
-        {
-            PgStorageError::DuplicateKey(key) => {
-                assert_eq!(key, std::env::var("APIKEY").unwrap())
-            }
-            why => panic!("Expected duplicate key error but found '{why}'"),
-        };
+    async fn test_store_duplicate_key() {
+        let (storage, key) = setup().await;
+
+        let key = storage
+            .store_key(1, key.key, vec![Domain::User { id: 1 }])
+            .await
+            .unwrap();
+
+        assert_eq!(key.domains.0.len(), 2);
+    }
+
+    #[test]
+    async fn test_store_duplicate_key_duplicate_domain() {
+        let (storage, key) = setup().await;
+
+        let key = storage
+            .store_key(1, key.key, vec![Domain::All])
+            .await
+            .unwrap();
+
+        assert_eq!(key.domains.0.len(), 1);
     }
 
     #[test]
     async fn test_add_domain() {
-        let storage = setup().await;
+        let (storage, key) = setup().await;
 
         let key = storage
-            .add_domain_to_key(std::env::var("APIKEY").unwrap(), Domain::User { id: 12345 })
+            .add_domain_to_key(KeySelector::Key(key.key), Domain::User { id: 12345 })
+            .await
+            .unwrap();
+
+        assert!(key.domains.0.contains(&Domain::User { id: 12345 }));
+    }
+
+    #[test]
+    async fn test_add_domain_id() {
+        let (storage, key) = setup().await;
+
+        let key = storage
+            .add_domain_to_key(KeySelector::Id(key.id), Domain::User { id: 12345 })
             .await
             .unwrap();
@@ -578,22 +662,56 @@ pub(crate) mod test {
     #[test]
     async fn test_add_duplicate_domain() {
-        let storage = setup().await;
+        let (storage, key) = setup().await;
 
-        match storage
-            .add_domain_to_key(std::env::var("APIKEY").unwrap(), Domain::All)
+        let key = storage
+            .add_domain_to_key(KeySelector::Key(key.key), Domain::All)
             .await
-            .unwrap_err()
-        {
-            PgStorageError::DuplicateDomain(d) => assert_eq!(d, Domain::All),
-            why => panic!("Expected duplicate domain error but found '{why}'"),
-        };
+            .unwrap();
+
+        assert_eq!(
+            key.domains
+                .0
+                .into_iter()
+                .filter(|d| *d == Domain::All)
+                .count(),
+            1
+        );
     }
 
     #[test]
     async fn test_remove_domain() {
-        let storage = setup().await;
+        let (storage, key) = setup().await;
+
+        storage
+            .add_domain_to_key(KeySelector::Key(key.key.clone()), Domain::User { id: 1 })
+            .await
+            .unwrap();
 
         let key = storage
-            .remove_domain_from_key(std::env::var("APIKEY").unwrap(), Domain::All)
+            .remove_domain_from_key(KeySelector::Key(key.key.clone()), Domain::User { id: 1 })
+            .await
+            .unwrap();
+
+        assert_eq!(key.domains.0, vec![Domain::All]);
+    }
+
+    #[test]
+    async fn test_remove_domain_id() {
+        let (storage, key) = setup().await;
+
+        storage
+            .add_domain_to_key(KeySelector::Id(key.id), Domain::User { id: 1 })
+            .await
+            .unwrap();
+
+        let key = storage
+            .remove_domain_from_key(KeySelector::Id(key.id), Domain::User { id: 1 })
+            .await
+            .unwrap();
+
+        assert_eq!(key.domains.0, vec![Domain::All]);
+    }
+
+    #[test]
+    async fn test_remove_last_domain() {
+        let (storage, key) = setup().await;
+
+        let key = storage
+            .remove_domain_from_key(KeySelector::Key(key.key), Domain::All)
             .await
             .unwrap();
@@ -602,7 +720,7 @@ pub(crate) mod test {
     #[test]
     async fn test_store_key() {
-        let storage = setup().await;
+        let (storage, _) = setup().await;
 
         let key = storage
             .store_key(1, "ABCDABCDABCDABCD".to_owned(), vec![])
             .await
@@ -612,7 +730,7 @@ pub(crate) mod test {
     #[test]
     async fn test_read_user_keys() {
-        let storage = setup().await;
+        let (storage, _) = setup().await;
 
         let keys = storage.read_user_keys(1).await.unwrap();
 
         assert_eq!(keys.len(), 1);
@@ -620,7 +738,7 @@ pub(crate) mod test {
     #[test]
     async fn acquire_one() {
-        let storage = setup().await;
+        let (storage, _) = setup().await;
 
         if let Err(e) = storage.acquire_key(Domain::All).await {
             panic!("Acquiring key failed: {:?}", e);
@@ -629,11 +747,7 @@ pub(crate) mod test {
     #[test]
     async fn test_flag_key_one() {
-        let storage = setup().await;
-
-        let key = storage
-            .read_key(std::env::var("APIKEY").unwrap())
-            .await
-            .unwrap();
+        let (storage, key) = setup().await;
 
         assert!(storage.flag_key(key, 2).await.unwrap());
@@ -645,11 +759,7 @@ pub(crate) mod test {
     #[test]
     async fn test_flag_key_many() {
-        let storage = setup().await;
-
-        let key = storage
-            .read_key(std::env::var("APIKEY").unwrap())
-            .await
-            .unwrap();
+        let (storage, key) = setup().await;
 
         assert!(storage.flag_key(key, 2).await.unwrap());
@@ -661,7 +771,7 @@ pub(crate) mod test {
     #[test]
     async fn acquire_many() {
-        let storage = setup().await;
+        let (storage, _) = setup().await;
 
         match storage.acquire_many_keys(Domain::All, 30).await {
             Err(e) => panic!("Acquiring key failed: {:?}", e),
@@ -669,9 +779,10 @@ pub(crate) mod test {
         }
     }
 
+    // HACK: this test is time sensitive and will fail if runs at the top of the minute
     #[test]
     async fn test_concurrent() {
-        let storage = Arc::new(setup().await);
+        let storage = Arc::new(setup().await.0);
 
         for _ in 0..10 {
             let mut set = tokio::task::JoinSet::new();
@@ -702,9 +813,10 @@ pub(crate) mod test {
         }
     }
 
+    // HACK: this test is time sensitive and will fail if runs at the top of the minute
     #[test]
     async fn test_concurrent_many() {
-        let storage = Arc::new(setup().await);
+        let storage = Arc::new(setup().await.0);
 
         for _ in 0..10 {
             let mut set = tokio::task::JoinSet::new();
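Note: with the single-column key constraint and the jsonb helpers introduced above, duplicate keys and duplicate domains no longer surface as DuplicateKey/DuplicateDomain errors; store_key upserts and merges domains instead. A hypothetical test in the style of the ones above (not part of this commit) sketching the assumed semantics:

    // Hypothetical follow-up test; names and assertions are illustrative.
    #[test]
    async fn demo_duplicate_semantics() {
        let (storage, key) = setup().await;

        // Re-storing an existing key now upserts: domains are merged and
        // deduplicated by __unique_jsonb_array rather than failing.
        let merged = storage
            .store_key(1, key.key.clone(), vec![Domain::All, Domain::User { id: 1 }])
            .await
            .unwrap();
        assert_eq!(merged.domains.0.len(), 2);

        // Adding a domain that is already present is a no-op, not an error.
        let same = storage
            .add_domain_to_key(KeySelector::Id(merged.id), Domain::All)
            .await
            .unwrap();
        assert_eq!(same.domains.0.len(), 2);
    }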

View file

@@ -179,7 +179,7 @@ mod test {
     #[test]
     async fn test_pool_request() {
-        let storage = setup().await;
+        let (storage, _) = setup().await;
 
         let pool = KeyPool::new(
             reqwest::Client::default(),
             storage,
@@ -192,7 +192,7 @@ mod test {
     #[test]
     async fn test_with_storage_request() {
-        let storage = setup().await;
+        let (storage, _) = setup().await;
 
         let response = reqwest::Client::new()
             .with_storage(&storage, Domain::All)