From 91bfb08652d37d1748a71490cca2b338c62caadb Mon Sep 17 00:00:00 2001 From: TotallyNot <44345987+TotallyNot@users.noreply.github.com> Date: Fri, 27 Jan 2023 17:25:07 +0100 Subject: [PATCH] changed key storage interface --- rustfmt.toml | 1 + torn-key-pool/Cargo.toml | 2 +- torn-key-pool/src/lib.rs | 23 ++- torn-key-pool/src/postgres.rs | 358 ++++++++++++++++++++++------------ torn-key-pool/src/send.rs | 4 +- 5 files changed, 257 insertions(+), 131 deletions(-) diff --git a/rustfmt.toml b/rustfmt.toml index 3a26366..0cb02d8 100644 --- a/rustfmt.toml +++ b/rustfmt.toml @@ -1 +1,2 @@ edition = "2021" +format_strings = true diff --git a/torn-key-pool/Cargo.toml b/torn-key-pool/Cargo.toml index bfc726e..4245a6c 100644 --- a/torn-key-pool/Cargo.toml +++ b/torn-key-pool/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "torn-key-pool" -version = "0.5.2" +version = "0.5.3" edition = "2021" authors = ["Pyrit [2111649]"] license = "MIT" diff --git a/torn-key-pool/src/lib.rs b/torn-key-pool/src/lib.rs index 379ec72..72ac07c 100644 --- a/torn-key-pool/src/lib.rs +++ b/torn-key-pool/src/lib.rs @@ -30,7 +30,11 @@ where } pub trait ApiKey: Sync + Send { + type IdType: PartialEq + Eq + std::hash::Hash; + fn value(&self) -> &str; + + fn id(&self) -> Self::IdType; } pub trait KeyDomain: Clone + std::fmt::Debug + Send + Sync { @@ -39,6 +43,15 @@ pub trait KeyDomain: Clone + std::fmt::Debug + Send + Sync { } } +#[derive(Debug, Clone)] +pub enum KeySelector +where + K: ApiKey, +{ + Key(String), + Id(K::IdType), +} + #[async_trait] pub trait KeyPoolStorage { type Key: ApiKey; @@ -62,27 +75,27 @@ pub trait KeyPoolStorage { domains: Vec, ) -> Result; - async fn read_key(&self, key: String) -> Result; + async fn read_key(&self, key: KeySelector) -> Result; async fn read_user_keys(&self, user_id: i32) -> Result, Self::Error>; - async fn remove_key(&self, key: String) -> Result; + async fn remove_key(&self, key: KeySelector) -> Result; async fn add_domain_to_key( &self, - key: String, + key: KeySelector, domain: Self::Domain, ) -> Result; async fn remove_domain_from_key( &self, - key: String, + key: KeySelector, domain: Self::Domain, ) -> Result; async fn set_domains_for_key( &self, - key: String, + key: KeySelector, domains: Vec, ) -> Result; } diff --git a/torn-key-pool/src/postgres.rs b/torn-key-pool/src/postgres.rs index 4f7b09f..61229d5 100644 --- a/torn-key-pool/src/postgres.rs +++ b/torn-key-pool/src/postgres.rs @@ -3,7 +3,7 @@ use indoc::indoc; use sqlx::{FromRow, PgPool}; use thiserror::Error; -use crate::{ApiKey, KeyDomain, KeyPoolStorage}; +use crate::{ApiKey, KeyDomain, KeyPoolStorage, KeySelector}; pub trait PgKeyDomain: KeyDomain + serde::Serialize + serde::de::DeserializeOwned + Eq + Unpin @@ -18,7 +18,7 @@ impl PgKeyDomain for T where #[derive(Debug, Error)] pub enum PgStorageError where - D: std::fmt::Debug, + D: PgKeyDomain, { #[error(transparent)] Pg(#[from] sqlx::Error), @@ -26,14 +26,8 @@ where #[error("No key avalaible for domain {0:?}")] Unavailable(D), - #[error("Duplicate key '{0}'")] - DuplicateKey(String), - - #[error("Duplicate domain '{0:?}'")] - DuplicateDomain(D), - - #[error("Key not found: '{0}'")] - KeyNotFound(String), + #[error("Key not found: '{0:?}'")] + KeyNotFound(KeySelector>), } #[derive(Debug, Clone, FromRow)] @@ -62,9 +56,17 @@ impl ApiKey for PgKey where D: PgKeyDomain, { + type IdType = i32; + + #[inline(always)] fn value(&self) -> &str { &self.key } + + #[inline(always)] + fn id(&self) -> Self::IdType { + self.id + } } impl PgKeyPoolStorage @@ -90,7 +92,7 @@ where last_used timestamptz not null default now(), flag int2, cooldown timestamptz, - constraint "uq:api_keys.key+user_id" UNIQUE(user_id, key) + constraint "uq:api_keys.key" UNIQUE(key) )"# }) .execute(&self.pool) @@ -108,6 +110,28 @@ where .execute(&self.pool) .await?; + sqlx::query(indoc! {r#" + create or replace function __unique_jsonb_array(jsonb) returns jsonb + AS $$ + select jsonb_agg(d::jsonb) from ( + select distinct jsonb_array_elements_text($1) as d + ) t + $$ language sql; + "#}) + .execute(&self.pool) + .await?; + + sqlx::query(indoc! {r#" + create or replace function __filter_jsonb_array(jsonb, jsonb) returns jsonb + AS $$ + select jsonb_agg(d::jsonb) from ( + select distinct jsonb_array_elements_text($1) as d + ) t where d<>$2::text + $$ language sql; + "#}) + .execute(&self.pool) + .await?; + Ok(()) } } @@ -351,10 +375,13 @@ where } 5 => { // too many requests - sqlx::query("update api_keys set cooldown=date_trunc('min', now()) + interval '1 min', flag=5 where id=$1") - .bind(key.id) - .execute(&self.pool) - .await?; + sqlx::query( + "update api_keys set cooldown=date_trunc('min', now()) + interval '1 min', \ + flag=5 where id=$1", + ) + .bind(key.id) + .execute(&self.pool) + .await?; Ok(true) } 8 => { @@ -373,10 +400,13 @@ where } 14 => { // daily read limit reached - sqlx::query("update api_keys set cooldown=date_trunc('day', now()) + interval '1 day', flag=14 where id=$1") - .bind(key.id) - .execute(&self.pool) - .await?; + sqlx::query( + "update api_keys set cooldown=date_trunc('day', now()) + interval '1 day', \ + flag=14 where id=$1", + ) + .bind(key.id) + .execute(&self.pool) + .await?; Ok(true) } _ => Ok(false), @@ -390,30 +420,31 @@ where domains: Vec, ) -> Result { sqlx::query_as( - "insert into api_keys(user_id, key, domains) values ($1, $2, $3) returning *", + "insert into api_keys(user_id, key, domains) values ($1, $2, $3) on conflict on \ + constraint \"uq:api_keys.key\" do update set domains = \ + __unique_jsonb_array(excluded.domains || api_keys.domains) returning *", ) .bind(user_id) .bind(&key) .bind(sqlx::types::Json(domains)) .fetch_one(&self.pool) .await - .map_err(|why| { - if let Some(error) = why.as_database_error() { - let pg_error: &sqlx::postgres::PgDatabaseError = error.downcast_ref(); - if pg_error.code() == "23505" { - return PgStorageError::DuplicateKey(key); - } - } - PgStorageError::Pg(why) - }) + .map_err(Into::into) } - async fn read_key(&self, key: String) -> Result { - sqlx::query_as("select * from api_keys where key=$1") - .bind(&key) - .fetch_optional(&self.pool) - .await? - .ok_or_else(|| PgStorageError::KeyNotFound(key)) + async fn read_key(&self, selector: KeySelector) -> Result { + match &selector { + KeySelector::Key(key) => sqlx::query_as("select * from api_keys where key=$1") + .bind(key) + .fetch_optional(&self.pool) + .await? + .ok_or_else(|| PgStorageError::KeyNotFound(selector)), + KeySelector::Id(id) => sqlx::query_as("select * from api_keys where id=$1") + .bind(id) + .fetch_optional(&self.pool) + .await? + .ok_or_else(|| PgStorageError::KeyNotFound(selector)), + } } async fn read_user_keys(&self, user_id: i32) -> Result, Self::Error> { @@ -424,66 +455,101 @@ where .map_err(Into::into) } - async fn remove_key(&self, key: String) -> Result { - sqlx::query_as("delete from api_keys where key=$1 returning *") - .bind(&key) - .fetch_optional(&self.pool) - .await? - .ok_or_else(|| PgStorageError::KeyNotFound(key)) + async fn remove_key(&self, selector: KeySelector) -> Result { + match &selector { + KeySelector::Key(key) => { + sqlx::query_as("delete from api_keys where key=$1 returning *") + .bind(key) + .fetch_optional(&self.pool) + .await? + .ok_or_else(|| PgStorageError::KeyNotFound(selector)) + } + KeySelector::Id(id) => sqlx::query_as("delete from api_keys where id=$1 returning *") + .bind(id) + .fetch_optional(&self.pool) + .await? + .ok_or_else(|| PgStorageError::KeyNotFound(selector)), + } } - async fn add_domain_to_key(&self, key: String, domain: D) -> Result { - let mut tx = self.pool.begin().await?; - match sqlx::query_as::>( - "update api_keys set domains = domains || jsonb_build_array($1) where key=$2 returning *", - ) - .bind(sqlx::types::Json(domain.clone())) - .bind(&key) - .fetch_optional(&mut tx) - .await? - { - None => Err(PgStorageError::KeyNotFound(key)), - Some(key) => { - if key.domains.0.iter().filter(|d| **d == domain).count() > 1 { - tx.rollback().await?; - return Err(PgStorageError::DuplicateDomain(domain)); - } - tx.commit().await?; - Ok(key) - } + async fn add_domain_to_key( + &self, + selector: KeySelector, + domain: D, + ) -> Result { + match &selector { + KeySelector::Key(key) => sqlx::query_as::>( + "update api_keys set domains = __unique_jsonb_array(domains || \ + jsonb_build_array($1)) where key=$2 returning *", + ) + .bind(sqlx::types::Json(domain)) + .bind(key) + .fetch_optional(&self.pool) + .await? + .ok_or_else(|| PgStorageError::KeyNotFound(selector)), + KeySelector::Id(id) => sqlx::query_as::>( + "update api_keys set domains = __unique_jsonb_array(domains || \ + jsonb_build_array($1)) where id=$2 returning *", + ) + .bind(sqlx::types::Json(domain)) + .bind(id) + .fetch_optional(&self.pool) + .await? + .ok_or_else(|| PgStorageError::KeyNotFound(selector)), } } async fn remove_domain_from_key( &self, - key: String, + selector: KeySelector, domain: D, ) -> Result { - // FIX: potential race condition - let api_key = self.read_key(key.clone()).await?; - let domains = api_key - .domains - .0 - .into_iter() - .filter(|d| *d != domain) - .collect(); - - self.set_domains_for_key(key, domains).await + match &selector { + KeySelector::Key(key) => sqlx::query_as( + "update api_keys set domains = coalesce(__filter_jsonb_array(domains, $1), \ + '[]'::jsonb) where key=$2 returning *", + ) + .bind(sqlx::types::Json(domain)) + .bind(key) + .fetch_optional(&self.pool) + .await? + .ok_or_else(|| PgStorageError::KeyNotFound(selector)), + KeySelector::Id(id) => sqlx::query_as( + "update api_keys set domains = coalesce(__filter_jsonb_array(domains, $1), \ + '[]'::jsonb) where id=$2 returning *", + ) + .bind(sqlx::types::Json(domain)) + .bind(id) + .fetch_optional(&self.pool) + .await? + .ok_or_else(|| PgStorageError::KeyNotFound(selector)), + } } async fn set_domains_for_key( &self, - key: String, + selector: KeySelector, domains: Vec, ) -> Result { - sqlx::query_as::>( - "update api_keys set domains = $1 where key=$2 returning *", - ) - .bind(sqlx::types::Json(domains)) - .bind(&key) - .fetch_optional(&self.pool) - .await? - .ok_or_else(|| PgStorageError::KeyNotFound(key)) + match &selector { + KeySelector::Key(key) => sqlx::query_as::>( + "update api_keys set domains = $1 where key=$2 returning *", + ) + .bind(sqlx::types::Json(domains)) + .bind(key) + .fetch_optional(&self.pool) + .await? + .ok_or_else(|| PgStorageError::KeyNotFound(selector)), + + KeySelector::Id(id) => sqlx::query_as::>( + "update api_keys set domains = $1 where id=$2 returning *", + ) + .bind(sqlx::types::Json(domains)) + .bind(id) + .fetch_optional(&self.pool) + .await? + .ok_or_else(|| PgStorageError::KeyNotFound(selector)), + } } } @@ -516,7 +582,7 @@ pub(crate) mod test { } } - pub(crate) async fn setup() -> PgKeyPoolStorage { + pub(crate) async fn setup() -> (PgKeyPoolStorage, PgKey) { INIT.call_once(|| { dotenv::dotenv().ok(); }); @@ -533,17 +599,17 @@ pub(crate) mod test { let storage = PgKeyPoolStorage::new(pool.clone(), 1000); storage.initialise().await.unwrap(); - storage + let key = storage .store_key(1, std::env::var("APIKEY").unwrap(), vec![Domain::All]) .await .unwrap(); - storage + (storage, key) } #[test] async fn test_initialise() { - let storage = setup().await; + let (storage, _) = setup().await; if let Err(e) = storage.initialise().await { panic!("Initialising key storage failed: {:?}", e); @@ -551,25 +617,43 @@ pub(crate) mod test { } #[test] - async fn test_store_duplicate() { - let storage = setup().await; - match storage - .store_key(1, std::env::var("APIKEY").unwrap(), vec![]) + async fn test_store_duplicate_key() { + let (storage, key) = setup().await; + let key = storage + .store_key(1, key.key, vec![Domain::User { id: 1 }]) .await - .unwrap_err() - { - PgStorageError::DuplicateKey(key) => { - assert_eq!(key, std::env::var("APIKEY").unwrap()) - } - why => panic!("Expected duplicate key error but found '{why}'"), - }; + .unwrap(); + + assert_eq!(key.domains.0.len(), 2); + } + + #[test] + async fn test_store_duplicate_key_duplicate_domain() { + let (storage, key) = setup().await; + let key = storage + .store_key(1, key.key, vec![Domain::All]) + .await + .unwrap(); + + assert_eq!(key.domains.0.len(), 1); } #[test] async fn test_add_domain() { - let storage = setup().await; + let (storage, key) = setup().await; let key = storage - .add_domain_to_key(std::env::var("APIKEY").unwrap(), Domain::User { id: 12345 }) + .add_domain_to_key(KeySelector::Key(key.key), Domain::User { id: 12345 }) + .await + .unwrap(); + + assert!(key.domains.0.contains(&Domain::User { id: 12345 })); + } + + #[test] + async fn test_add_domain_id() { + let (storage, key) = setup().await; + let key = storage + .add_domain_to_key(KeySelector::Id(key.id), Domain::User { id: 12345 }) .await .unwrap(); @@ -578,22 +662,56 @@ pub(crate) mod test { #[test] async fn test_add_duplicate_domain() { - let storage = setup().await; - match storage - .add_domain_to_key(std::env::var("APIKEY").unwrap(), Domain::All) + let (storage, key) = setup().await; + let key = storage + .add_domain_to_key(KeySelector::Key(key.key), Domain::All) .await - .unwrap_err() - { - PgStorageError::DuplicateDomain(d) => assert_eq!(d, Domain::All), - why => panic!("Expected duplicate domain error but found '{why}'"), - }; + .unwrap(); + assert_eq!( + key.domains + .0 + .into_iter() + .filter(|d| *d == Domain::All) + .count(), + 1 + ); } #[test] async fn test_remove_domain() { - let storage = setup().await; + let (storage, key) = setup().await; + storage + .add_domain_to_key(KeySelector::Key(key.key.clone()), Domain::User { id: 1 }) + .await + .unwrap(); let key = storage - .remove_domain_from_key(std::env::var("APIKEY").unwrap(), Domain::All) + .remove_domain_from_key(KeySelector::Key(key.key.clone()), Domain::User { id: 1 }) + .await + .unwrap(); + + assert_eq!(key.domains.0, vec![Domain::All]); + } + + #[test] + async fn test_remove_domain_id() { + let (storage, key) = setup().await; + storage + .add_domain_to_key(KeySelector::Id(key.id), Domain::User { id: 1 }) + .await + .unwrap(); + let key = storage + .remove_domain_from_key(KeySelector::Id(key.id), Domain::User { id: 1 }) + .await + .unwrap(); + + assert_eq!(key.domains.0, vec![Domain::All]); + } + + #[test] + async fn test_remove_last_domain() { + let (storage, key) = setup().await; + let key = storage + .remove_domain_from_key(KeySelector::Key(key.key), Domain::All) .await .unwrap(); @@ -602,7 +720,7 @@ pub(crate) mod test { #[test] async fn test_store_key() { - let storage = setup().await; + let (storage, _) = setup().await; let key = storage .store_key(1, "ABCDABCDABCDABCD".to_owned(), vec![]) .await @@ -612,7 +730,7 @@ pub(crate) mod test { #[test] async fn test_read_user_keys() { - let storage = setup().await; + let (storage, _) = setup().await; let keys = storage.read_user_keys(1).await.unwrap(); assert_eq!(keys.len(), 1); @@ -620,7 +738,7 @@ pub(crate) mod test { #[test] async fn acquire_one() { - let storage = setup().await; + let (storage, _) = setup().await; if let Err(e) = storage.acquire_key(Domain::All).await { panic!("Acquiring key failed: {:?}", e); @@ -629,11 +747,7 @@ pub(crate) mod test { #[test] async fn test_flag_key_one() { - let storage = setup().await; - let key = storage - .read_key(std::env::var("APIKEY").unwrap()) - .await - .unwrap(); + let (storage, key) = setup().await; assert!(storage.flag_key(key, 2).await.unwrap()); @@ -645,11 +759,7 @@ pub(crate) mod test { #[test] async fn test_flag_key_many() { - let storage = setup().await; - let key = storage - .read_key(std::env::var("APIKEY").unwrap()) - .await - .unwrap(); + let (storage, key) = setup().await; assert!(storage.flag_key(key, 2).await.unwrap()); @@ -661,7 +771,7 @@ pub(crate) mod test { #[test] async fn acquire_many() { - let storage = setup().await; + let (storage, _) = setup().await; match storage.acquire_many_keys(Domain::All, 30).await { Err(e) => panic!("Acquiring key failed: {:?}", e), @@ -669,9 +779,10 @@ pub(crate) mod test { } } + // HACK: this test is time sensitive and will fail if runs at the top of the minute #[test] async fn test_concurrent() { - let storage = Arc::new(setup().await); + let storage = Arc::new(setup().await.0); for _ in 0..10 { let mut set = tokio::task::JoinSet::new(); @@ -702,9 +813,10 @@ pub(crate) mod test { } } + // HACK: this test is time sensitive and will fail if runs at the top of the minute #[test] async fn test_concurrent_many() { - let storage = Arc::new(setup().await); + let storage = Arc::new(setup().await.0); for _ in 0..10 { let mut set = tokio::task::JoinSet::new(); diff --git a/torn-key-pool/src/send.rs b/torn-key-pool/src/send.rs index aa89a6a..e2cb563 100644 --- a/torn-key-pool/src/send.rs +++ b/torn-key-pool/src/send.rs @@ -179,7 +179,7 @@ mod test { #[test] async fn test_pool_request() { - let storage = setup().await; + let (storage, _) = setup().await; let pool = KeyPool::new( reqwest::Client::default(), storage, @@ -192,7 +192,7 @@ mod test { #[test] async fn test_with_storage_request() { - let storage = setup().await; + let (storage, _) = setup().await; let response = reqwest::Client::new() .with_storage(&storage, Domain::All)