From 0799d6d4754e9ebce0a2e85407e9fe549930bb5e Mon Sep 17 00:00:00 2001 From: TotallyNot <44345987+TotallyNot@users.noreply.github.com> Date: Sun, 11 Jun 2023 15:29:15 +0200 Subject: [PATCH] use new `IntoSelector` trait to identify keys --- torn-key-pool/Cargo.toml | 2 +- torn-key-pool/src/lib.rs | 110 ++++++++-- torn-key-pool/src/postgres.rs | 387 ++++++++++++++++++++-------------- 3 files changed, 313 insertions(+), 186 deletions(-) diff --git a/torn-key-pool/Cargo.toml b/torn-key-pool/Cargo.toml index 2bd1e79..2f76d7d 100644 --- a/torn-key-pool/Cargo.toml +++ b/torn-key-pool/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "torn-key-pool" -version = "0.5.7" +version = "0.6.0" edition = "2021" authors = ["Pyrit [2111649]"] license = "MIT" diff --git a/torn-key-pool/src/lib.rs b/torn-key-pool/src/lib.rs index c24be24..748122b 100644 --- a/torn-key-pool/src/lib.rs +++ b/torn-key-pool/src/lib.rs @@ -30,7 +30,7 @@ where } pub trait ApiKey: Sync + Send { - type IdType: PartialEq + Eq + std::hash::Hash; + type IdType: PartialEq + Eq + std::hash::Hash + Send + Sync; fn value(&self) -> &str; @@ -44,12 +44,65 @@ pub trait KeyDomain: Clone + std::fmt::Debug + Send + Sync { } #[derive(Debug, Clone)] -pub enum KeySelector +pub enum KeySelector where K: ApiKey, + D: KeyDomain, { Key(String), Id(K::IdType), + UserId(i32), + Has(D), + OneOf(Vec), +} + +impl KeySelector +where + K: ApiKey, + D: KeyDomain, +{ + pub(crate) fn fallback(&self) -> Option { + match self { + Self::Key(_) | Self::UserId(_) | Self::Id(_) => None, + Self::Has(domain) => domain.fallback().map(Self::Has), + Self::OneOf(domains) => { + let fallbacks: Vec<_> = domains.into_iter().filter_map(|d| d.fallback()).collect(); + if fallbacks.is_empty() { + None + } else { + Some(Self::OneOf(fallbacks)) + } + } + } + } +} + +pub trait IntoSelector: Send + Sync +where + K: ApiKey, + D: KeyDomain, +{ + fn into_selector(self) -> KeySelector; +} + +impl IntoSelector for D +where + K: ApiKey, + D: KeyDomain, +{ + fn into_selector(self) -> KeySelector { + KeySelector::Has(self) + } +} + +impl IntoSelector for KeySelector +where + K: ApiKey, + D: KeyDomain, +{ + fn into_selector(self) -> KeySelector { + self + } } #[async_trait] @@ -58,13 +111,17 @@ pub trait KeyPoolStorage { type Domain: KeyDomain; type Error: std::error::Error + Sync + Send; - async fn acquire_key(&self, domain: Self::Domain) -> Result; + async fn acquire_key(&self, selector: S) -> Result + where + S: IntoSelector; - async fn acquire_many_keys( + async fn acquire_many_keys( &self, - domain: Self::Domain, + selector: S, number: i64, - ) -> Result, Self::Error>; + ) -> Result, Self::Error> + where + S: IntoSelector; async fn flag_key(&self, key: Self::Key, code: u8) -> Result; @@ -75,34 +132,41 @@ pub trait KeyPoolStorage { domains: Vec, ) -> Result; - async fn read_key(&self, key: KeySelector) - -> Result, Self::Error>; + async fn read_key(&self, selector: S) -> Result, Self::Error> + where + S: IntoSelector; - async fn read_user_keys(&self, user_id: i32) -> Result, Self::Error>; + async fn read_keys(&self, selector: S) -> Result, Self::Error> + where + S: IntoSelector; - async fn remove_key(&self, key: KeySelector) -> Result; + async fn remove_key(&self, selector: S) -> Result + where + S: IntoSelector; - async fn query_key(&self, domain: Self::Domain) -> Result, Self::Error>; - - async fn query_all(&self, domain: Self::Domain) -> Result, Self::Error>; - - async fn add_domain_to_key( + async fn add_domain_to_key( &self, - key: KeySelector, + selector: S, domain: Self::Domain, - ) -> Result; + ) -> Result + where + S: IntoSelector; - async fn remove_domain_from_key( + async fn remove_domain_from_key( &self, - key: KeySelector, + selector: S, domain: Self::Domain, - ) -> Result; + ) -> Result + where + S: IntoSelector; - async fn set_domains_for_key( + async fn set_domains_for_key( &self, - key: KeySelector, + selector: S, domains: Vec, - ) -> Result; + ) -> Result + where + S: IntoSelector; } #[derive(Debug, Clone)] diff --git a/torn-key-pool/src/postgres.rs b/torn-key-pool/src/postgres.rs index 561ad3f..c6c461f 100644 --- a/torn-key-pool/src/postgres.rs +++ b/torn-key-pool/src/postgres.rs @@ -1,9 +1,9 @@ use async_trait::async_trait; use indoc::indoc; -use sqlx::{FromRow, PgPool}; +use sqlx::{FromRow, PgPool, Postgres, QueryBuilder}; use thiserror::Error; -use crate::{ApiKey, KeyDomain, KeyPoolStorage, KeySelector}; +use crate::{ApiKey, IntoSelector, KeyDomain, KeyPoolStorage, KeySelector}; pub trait PgKeyDomain: KeyDomain + serde::Serialize + serde::de::DeserializeOwned + Eq + Unpin @@ -24,10 +24,10 @@ where Pg(#[from] sqlx::Error), #[error("No key avalaible for domain {0:?}")] - Unavailable(D), + Unavailable(KeySelector, D>), #[error("Key not found: '{0:?}'")] - KeyNotFound(KeySelector>), + KeyNotFound(KeySelector, D>), } #[derive(Debug, Clone, FromRow)] @@ -42,6 +42,41 @@ where pub domains: sqlx::types::Json>, } +#[inline(always)] +fn build_predicate<'b, D>( + builder: &mut QueryBuilder<'b, Postgres>, + selector: &'b KeySelector, D>, +) where + D: PgKeyDomain, +{ + match selector { + KeySelector::Id(id) => builder.push("id=").push_bind(id), + KeySelector::UserId(user_id) => builder.push("user_id=").push_bind(user_id), + KeySelector::Key(key) => builder.push("key=").push_bind(key), + KeySelector::Has(domain) => builder + .push("domains @> ") + .push_bind(sqlx::types::Json(vec![domain])), + KeySelector::OneOf(domains) => { + if domains.is_empty() { + builder.push("false"); + return; + } + + for (idx, domain) in domains.iter().enumerate() { + if idx == 0 { + builder.push("("); + } else { + builder.push(" or "); + } + builder + .push("domains @> ") + .push_bind(sqlx::types::Json(vec![domain])); + } + builder.push(")") + } + }; +} + #[derive(Debug, Clone, FromRow)] pub struct PgKeyPoolStorage where @@ -160,7 +195,11 @@ where type Error = PgStorageError; - async fn acquire_key(&self, domain: D) -> Result { + async fn acquire_key(&self, selector: S) -> Result + where + S: IntoSelector, + { + let selector = selector.into_selector(); loop { let attempt = async { let mut tx = self.pool.begin().await?; @@ -169,22 +208,33 @@ where .execute(&mut tx) .await?; - // TODO: improve query - let key = sqlx::query_as(&indoc::formatdoc!( + let mut qb = QueryBuilder::new(indoc::indoc! { r#" with key as ( select id, 0::int2 as uses from api_keys where last_used < date_trunc('minute', now()) - and (cooldown is null or now() >= cooldown) - and domains @> $1 - union ( + and (cooldown is null or now() >= cooldown) + and "# + }); + + build_predicate(&mut qb, &selector); + + qb.push(indoc::indoc! { + " + \n union ( select id, uses from api_keys where last_used >= date_trunc('minute', now()) and (cooldown is null or now() >= cooldown) - and domains @> $1 - order by uses asc limit 1 + and " + }); + + build_predicate(&mut qb, &selector); + + qb.push(indoc::indoc! { + " + \n order by uses asc limit 1 ) order by uses asc limit 1 ) @@ -194,19 +244,21 @@ where flag = null, last_used = now() from key where - api_keys.id=key.id and key.uses < $2 - returning + api_keys.id=key.id and key.uses < " + }); + + qb.push_bind(self.limit); + + qb.push(indoc::indoc! { " + \nreturning api_keys.id, api_keys.user_id, api_keys.key, api_keys.uses, - api_keys.domains - "#, - )) - .bind(sqlx::types::Json(vec![&domain])) - .bind(self.limit) - .fetch_optional(&mut tx) - .await?; + api_keys.domains" + }); + + let key = qb.build_query_as().fetch_optional(&mut tx).await?; tx.commit().await?; @@ -219,9 +271,9 @@ where Ok(None) => { return self .acquire_key( - domain + selector .fallback() - .ok_or_else(|| PgStorageError::Unavailable(domain))?, + .ok_or_else(|| PgStorageError::Unavailable(selector))?, ) .await } @@ -241,11 +293,15 @@ where } } - async fn acquire_many_keys( + async fn acquire_many_keys( &self, - domain: D, + selector: S, number: i64, - ) -> Result, Self::Error> { + ) -> Result, Self::Error> + where + S: IntoSelector, + { + let selector = selector.into_selector(); loop { let attempt = async { let mut tx = self.pool.begin().await?; @@ -254,33 +310,36 @@ where .execute(&mut tx) .await?; - let mut keys: Vec = sqlx::query_as(&indoc::formatdoc!( + let mut qb = QueryBuilder::new(indoc::indoc! { r#"select id, user_id, key, 0::int2 as uses, domains - from api_keys where last_used < date_trunc('minute', now()) - and (cooldown is null or now() >= cooldown) - and domains @> $1 - union + from api_keys where last_used < date_trunc('minute', now()) + and (cooldown is null or now() >= cooldown) + and "# + }); + build_predicate(&mut qb, &selector); + qb.push(indoc::indoc! { + " + \nunion select id, user_id, key, uses, domains - from api_keys where last_used >= date_trunc('minute', now()) - and (cooldown is null or now() >= cooldown) - and domains @> $1 - order by uses limit $2 - "#, - )) - .bind(sqlx::types::Json(vec![&domain])) - .bind(number) - .fetch_all(&mut tx) - .await?; + from api_keys where last_used >= date_trunc('minute', now()) + and (cooldown is null or now() >= cooldown) + and " + }); + build_predicate(&mut qb, &selector); + qb.push("\norder by uses limit "); + qb.push_bind(self.limit); + + let mut keys: Vec = qb.build_query_as().fetch_all(&mut tx).await?; if keys.is_empty() { tx.commit().await?; @@ -338,9 +397,9 @@ where Ok(None) => { return self .acquire_many_keys( - domain + selector .fallback() - .ok_or_else(|| Self::Error::Unavailable(domain))?, + .ok_or_else(|| Self::Error::Unavailable(selector))?, number, ) .await @@ -433,143 +492,116 @@ where .map_err(Into::into) } - async fn read_key( - &self, - selector: KeySelector, - ) -> Result, Self::Error> { - match &selector { - KeySelector::Key(key) => sqlx::query_as("select * from api_keys where key=$1") - .bind(key) - .fetch_optional(&self.pool) - .await - .map_err(Into::into), - KeySelector::Id(id) => sqlx::query_as("select * from api_keys where id=$1") - .bind(id) - .fetch_optional(&self.pool) - .await - .map_err(Into::into), - } - } + async fn read_key(&self, selector: S) -> Result, Self::Error> + where + S: IntoSelector, + { + let selector = selector.into_selector(); - async fn query_key(&self, domain: D) -> Result, Self::Error> { - sqlx::query_as("select * from api_keys where domains @> $1 limit 1") - .bind(sqlx::types::Json(vec![domain])) + let mut qb = QueryBuilder::new("select * from api_keys where "); + build_predicate(&mut qb, &selector); + + qb.build_query_as() .fetch_optional(&self.pool) .await .map_err(Into::into) } - async fn query_all(&self, domain: D) -> Result, Self::Error> { - sqlx::query_as("select * from api_keys where domains @> $1") - .bind(sqlx::types::Json(vec![domain])) + async fn read_keys(&self, selector: S) -> Result, Self::Error> + where + S: IntoSelector, + { + let selector = selector.into_selector(); + + let mut qb = QueryBuilder::new("select * from api_keys where "); + build_predicate(&mut qb, &selector); + + qb.build_query_as() .fetch_all(&self.pool) .await .map_err(Into::into) } - async fn read_user_keys(&self, user_id: i32) -> Result, Self::Error> { - sqlx::query_as("select * from api_keys where user_id=$1") - .bind(user_id) - .fetch_all(&self.pool) - .await - .map_err(Into::into) + async fn remove_key(&self, selector: S) -> Result + where + S: IntoSelector, + { + let selector = selector.into_selector(); + + let mut qb = QueryBuilder::new("delete from api_keys where "); + build_predicate(&mut qb, &selector); + qb.push(" returning *"); + + qb.build_query_as() + .fetch_optional(&self.pool) + .await? + .ok_or_else(|| PgStorageError::KeyNotFound(selector)) } - async fn remove_key(&self, selector: KeySelector) -> Result { - match &selector { - KeySelector::Key(key) => { - sqlx::query_as("delete from api_keys where key=$1 returning *") - .bind(key) - .fetch_optional(&self.pool) - .await? - .ok_or_else(|| PgStorageError::KeyNotFound(selector)) - } - KeySelector::Id(id) => sqlx::query_as("delete from api_keys where id=$1 returning *") - .bind(id) - .fetch_optional(&self.pool) - .await? - .ok_or_else(|| PgStorageError::KeyNotFound(selector)), - } + async fn add_domain_to_key(&self, selector: S, domain: D) -> Result + where + S: IntoSelector, + { + let selector = selector.into_selector(); + + let mut qb = QueryBuilder::new( + "update api_keys set domains = __unique_jsonb_array(domains || jsonb_build_array(", + ); + qb.push_bind(sqlx::types::Json(domain)); + qb.push(")) where "); + build_predicate(&mut qb, &selector); + qb.push(" returning *"); + + qb.build_query_as() + .fetch_optional(&self.pool) + .await? + .ok_or_else(|| PgStorageError::KeyNotFound(selector)) } - async fn add_domain_to_key( + async fn remove_domain_from_key( &self, - selector: KeySelector, + selector: S, domain: D, - ) -> Result { - match &selector { - KeySelector::Key(key) => sqlx::query_as::>( - "update api_keys set domains = __unique_jsonb_array(domains || \ - jsonb_build_array($1)) where key=$2 returning *", - ) - .bind(sqlx::types::Json(domain)) - .bind(key) + ) -> Result + where + S: IntoSelector, + { + let selector = selector.into_selector(); + + let mut qb = QueryBuilder::new( + "update api_keys set domains = coalesce(__filter_jsonb_array(domains, ", + ); + qb.push_bind(sqlx::types::Json(domain)); + qb.push("), '[]'::jsonb) where "); + build_predicate(&mut qb, &selector); + qb.push(" returning *"); + + qb.build_query_as() .fetch_optional(&self.pool) .await? - .ok_or_else(|| PgStorageError::KeyNotFound(selector)), - KeySelector::Id(id) => sqlx::query_as::>( - "update api_keys set domains = __unique_jsonb_array(domains || \ - jsonb_build_array($1)) where id=$2 returning *", - ) - .bind(sqlx::types::Json(domain)) - .bind(id) - .fetch_optional(&self.pool) - .await? - .ok_or_else(|| PgStorageError::KeyNotFound(selector)), - } + .ok_or_else(|| PgStorageError::KeyNotFound(selector)) } - async fn remove_domain_from_key( + async fn set_domains_for_key( &self, - selector: KeySelector, - domain: D, - ) -> Result { - match &selector { - KeySelector::Key(key) => sqlx::query_as( - "update api_keys set domains = coalesce(__filter_jsonb_array(domains, $1), \ - '[]'::jsonb) where key=$2 returning *", - ) - .bind(sqlx::types::Json(domain)) - .bind(key) - .fetch_optional(&self.pool) - .await? - .ok_or_else(|| PgStorageError::KeyNotFound(selector)), - KeySelector::Id(id) => sqlx::query_as( - "update api_keys set domains = coalesce(__filter_jsonb_array(domains, $1), \ - '[]'::jsonb) where id=$2 returning *", - ) - .bind(sqlx::types::Json(domain)) - .bind(id) - .fetch_optional(&self.pool) - .await? - .ok_or_else(|| PgStorageError::KeyNotFound(selector)), - } - } - - async fn set_domains_for_key( - &self, - selector: KeySelector, + selector: S, domains: Vec, - ) -> Result { - match &selector { - KeySelector::Key(key) => sqlx::query_as::>( - "update api_keys set domains = $1 where key=$2 returning *", - ) - .bind(sqlx::types::Json(domains)) - .bind(key) - .fetch_optional(&self.pool) - .await? - .ok_or_else(|| PgStorageError::KeyNotFound(selector)), + ) -> Result + where + S: IntoSelector, + { + let selector = selector.into_selector(); - KeySelector::Id(id) => sqlx::query_as::>( - "update api_keys set domains = $1 where id=$2 returning *", - ) - .bind(sqlx::types::Json(domains)) - .bind(id) + let mut qb = QueryBuilder::new("update api_keys set domains = "); + qb.push_bind(sqlx::types::Json(domains)); + qb.push(" where "); + build_predicate(&mut qb, &selector); + qb.push(" returning *"); + + qb.build_query_as() .fetch_optional(&self.pool) .await? - .ok_or_else(|| PgStorageError::KeyNotFound(selector)), - } + .ok_or_else(|| PgStorageError::KeyNotFound(selector)) } } @@ -752,7 +784,7 @@ pub(crate) mod test { async fn test_read_user_keys() { let (storage, _) = setup().await; - let keys = storage.read_user_keys(1).await.unwrap(); + let keys = storage.read_keys(KeySelector::UserId(1)).await.unwrap(); assert_eq!(keys.len(), 1); } @@ -777,7 +809,7 @@ pub(crate) mod test { _ = storage.acquire_key(Domain::All).await.unwrap(); } - let keys = storage.read_user_keys(1).await.unwrap(); + let keys = storage.read_keys(KeySelector::UserId(1)).await.unwrap(); assert_eq!(keys.len(), 2); for key in keys { assert_eq!(key.uses, 5); @@ -791,7 +823,7 @@ pub(crate) mod test { assert!(storage.flag_key(key, 2).await.unwrap()); match storage.acquire_key(Domain::All).await.unwrap_err() { - PgStorageError::Unavailable(d) => assert_eq!(d, Domain::All), + PgStorageError::Unavailable(d) => assert!(matches!(d, KeySelector::Has(Domain::All))), why => panic!("Expected domain unavailable error but found '{why}'"), } } @@ -803,7 +835,7 @@ pub(crate) mod test { assert!(storage.flag_key(key, 2).await.unwrap()); match storage.acquire_many_keys(Domain::All, 5).await.unwrap_err() { - PgStorageError::Unavailable(d) => assert_eq!(d, Domain::All), + PgStorageError::Unavailable(d) => assert!(matches!(d, KeySelector::Has(Domain::All))), why => panic!("Expected domain unavailable error but found '{why}'"), } } @@ -877,7 +909,7 @@ pub(crate) mod test { set.join_next().await.unwrap().unwrap(); } - let keys = storage.read_user_keys(1).await.unwrap(); + let keys = storage.read_keys(KeySelector::UserId(1)).await.unwrap(); assert_eq!(keys.len(), 25); @@ -952,7 +984,7 @@ pub(crate) mod test { async fn query_key() { let (storage, _) = setup().await; - let key = storage.query_key(Domain::All).await.unwrap(); + let key = storage.read_key(Domain::All).await.unwrap(); assert!(key.is_some()); } @@ -960,7 +992,7 @@ pub(crate) mod test { async fn query_nonexistent_key() { let (storage, _) = setup().await; - let key = storage.query_key(Domain::Guild { id: 0 }).await.unwrap(); + let key = storage.read_key(Domain::Guild { id: 0 }).await.unwrap(); assert!(key.is_none()); } @@ -968,7 +1000,38 @@ pub(crate) mod test { async fn query_all() { let (storage, _) = setup().await; - let keys = storage.query_all(Domain::All).await.unwrap(); + let keys = storage.read_keys(Domain::All).await.unwrap(); assert!(keys.len() == 1); } + + #[test] + async fn query_by_id() { + let (storage, _) = setup().await; + let key = storage.read_key(KeySelector::Id(1)).await.unwrap(); + + assert!(key.is_some()); + } + + #[test] + async fn query_by_key() { + let (storage, key) = setup().await; + let key = storage.read_key(KeySelector::Key(key.key)).await.unwrap(); + + assert!(key.is_some()); + } + + #[test] + async fn query_by_set() { + let (storage, _key) = setup().await; + let key = storage + .read_key(KeySelector::OneOf(vec![ + Domain::All, + Domain::Guild { id: 0 }, + Domain::Faction { id: 0 }, + ])) + .await + .unwrap(); + + assert!(key.is_some()); + } }