From 75fc19d0f78b5a1701f7ee6c67d5ba9dd932262e Mon Sep 17 00:00:00 2001 From: TotallyNot <44345987+TotallyNot@users.noreply.github.com> Date: Thu, 4 Apr 2024 15:59:10 +0200 Subject: [PATCH] major refactoring --- torn-api-macros/Cargo.toml | 2 +- torn-api-macros/src/lib.rs | 10 +- torn-api/Cargo.toml | 4 +- torn-api/src/lib.rs | 10 +- torn-api/src/local.rs | 37 ++--- torn-api/src/send.rs | 37 ++--- torn-key-pool/Cargo.toml | 4 +- torn-key-pool/src/lib.rs | 56 ++++++-- torn-key-pool/src/postgres.rs | 252 ++++++++++++++++++++-------------- torn-key-pool/src/send.rs | 214 ++++++++++++++++++++++++----- 10 files changed, 404 insertions(+), 222 deletions(-) diff --git a/torn-api-macros/Cargo.toml b/torn-api-macros/Cargo.toml index e860b4b..f05c5ac 100644 --- a/torn-api-macros/Cargo.toml +++ b/torn-api-macros/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "torn-api-macros" -version = "0.2.0" +version = "0.3.0" edition = "2021" authors = ["Pyrit [2111649]"] license = "MIT" diff --git a/torn-api-macros/src/lib.rs b/torn-api-macros/src/lib.rs index 320d941..c2d70a1 100644 --- a/torn-api-macros/src/lib.rs +++ b/torn-api-macros/src/lib.rs @@ -147,15 +147,15 @@ fn impl_api_category(ast: &syn::DeriveInput) -> TokenStream { #(#accessors)* } - impl crate::ApiCategoryResponse for Response { - type Selection = #name; - - fn from_response(response: crate::ApiResponse) -> Self { - Self(response) + impl From for Response { + fn from(value: crate::ApiResponse) -> Self { + Self(value) } } impl crate::ApiSelection for #name { + type Response = Response; + fn raw_value(self) -> &'static str { match self { #(#raw_values,)* diff --git a/torn-api/Cargo.toml b/torn-api/Cargo.toml index 822ed35..95dc93f 100644 --- a/torn-api/Cargo.toml +++ b/torn-api/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "torn-api" -version = "0.6.7" +version = "0.7.0" edition = "2021" rust-version = "1.75.0" authors = ["Pyrit [2111649]"] @@ -39,7 +39,7 @@ reqwest = { version = "0.11", default-features = false, features = [ "json" ], o awc = { version = "3", default-features = false, optional = true } rust_decimal = { version = "1", default-features = false, optional = true, features = [ "serde" ] } -torn-api-macros = { path = "../torn-api-macros", version = "0.2" } +torn-api-macros = { path = "../torn-api-macros", version = "0.3" } [dev-dependencies] actix-rt = { version = "2.7.0" } diff --git a/torn-api/src/lib.rs b/torn-api/src/lib.rs index 3690feb..35bceba 100644 --- a/torn-api/src/lib.rs +++ b/torn-api/src/lib.rs @@ -111,18 +111,14 @@ impl ApiResponse { } } -pub trait ApiSelection: Send + Sync { +pub trait ApiSelection: Send + Sync + 'static { + type Response: From + Send + Sync; + fn raw_value(self) -> &'static str; fn category() -> &'static str; } -pub trait ApiCategoryResponse: Send + Sync { - type Selection: ApiSelection; - - fn from_response(response: ApiResponse) -> Self; -} - pub struct DirectExecutor { key: String, _marker: std::marker::PhantomData, diff --git a/torn-api/src/local.rs b/torn-api/src/local.rs index 80e63a7..772e8fc 100644 --- a/torn-api/src/local.rs +++ b/torn-api/src/local.rs @@ -2,9 +2,7 @@ use std::collections::HashMap; use async_trait::async_trait; -use crate::{ - ApiCategoryResponse, ApiClientError, ApiRequest, ApiResponse, ApiSelection, DirectExecutor, -}; +use crate::{ApiClientError, ApiRequest, ApiResponse, ApiSelection, DirectExecutor}; pub struct ApiProvider<'a, C, E> where @@ -39,7 +37,6 @@ where self.executor .execute(self.client, builder.request, builder.id) .await - .map(crate::user::Response::from_response) } #[cfg(feature = "user")] @@ -61,9 +58,6 @@ where self.executor .execute_many(self.client, builder.request, Vec::from_iter(ids)) .await - .into_iter() - .map(|(k, v)| (k, v.map(crate::user::Response::from_response))) - .collect() } #[cfg(feature = "faction")] @@ -79,7 +73,6 @@ where self.executor .execute(self.client, builder.request, builder.id) .await - .map(crate::faction::Response::from_response) } #[cfg(feature = "faction")] @@ -101,9 +94,6 @@ where self.executor .execute_many(self.client, builder.request, Vec::from_iter(ids)) .await - .into_iter() - .map(|(k, v)| (k, v.map(crate::faction::Response::from_response))) - .collect() } #[cfg(feature = "market")] @@ -119,7 +109,6 @@ where self.executor .execute(self.client, builder.request, builder.id) .await - .map(crate::market::Response::from_response) } #[cfg(feature = "market")] @@ -141,9 +130,6 @@ where self.executor .execute_many(self.client, builder.request, Vec::from_iter(ids)) .await - .into_iter() - .map(|(k, v)| (k, v.map(crate::market::Response::from_response))) - .collect() } #[cfg(feature = "torn")] @@ -159,7 +145,6 @@ where self.executor .execute(self.client, builder.request, builder.id) .await - .map(crate::torn::Response::from_response) } #[cfg(feature = "torn")] @@ -181,9 +166,6 @@ where self.executor .execute_many(self.client, builder.request, Vec::from_iter(ids)) .await - .into_iter() - .map(|(k, v)| (k, v.map(crate::torn::Response::from_response))) - .collect() } #[cfg(feature = "key")] @@ -199,7 +181,6 @@ where self.executor .execute(self.client, builder.request, builder.id) .await - .map(crate::key::Response::from_response) } } @@ -215,7 +196,7 @@ where client: &C, request: ApiRequest, id: Option, - ) -> Result + ) -> Result where A: ApiSelection; @@ -224,7 +205,7 @@ where client: &C, request: ApiRequest, ids: Vec, - ) -> HashMap> + ) -> HashMap> where A: ApiSelection, I: ToString + std::hash::Hash + std::cmp::Eq; @@ -242,7 +223,7 @@ where client: &C, request: ApiRequest, id: Option, - ) -> Result + ) -> Result where A: ApiSelection, { @@ -250,7 +231,7 @@ where let value = client.request(url).await.map_err(ApiClientError::Client)?; - Ok(ApiResponse::from_value(value)?) + Ok(ApiResponse::from_value(value)?.into()) } async fn execute_many( @@ -258,7 +239,7 @@ where client: &C, request: ApiRequest, ids: Vec, - ) -> HashMap> + ) -> HashMap> where A: ApiSelection, I: ToString + std::hash::Hash + std::cmp::Eq, @@ -272,7 +253,11 @@ where ( i, - value.and_then(|v| ApiResponse::from_value(v).map_err(Into::into)), + value.and_then(|v| { + ApiResponse::from_value(v) + .map(Into::into) + .map_err(Into::into) + }), ) })) .await; diff --git a/torn-api/src/send.rs b/torn-api/src/send.rs index 706f8b6..ca50eb1 100644 --- a/torn-api/src/send.rs +++ b/torn-api/src/send.rs @@ -2,9 +2,7 @@ use std::collections::HashMap; use async_trait::async_trait; -use crate::{ - ApiCategoryResponse, ApiClientError, ApiRequest, ApiResponse, ApiSelection, DirectExecutor, -}; +use crate::{ApiClientError, ApiRequest, ApiResponse, ApiSelection, DirectExecutor}; pub struct ApiProvider<'a, C, E> where @@ -37,7 +35,6 @@ where self.executor .execute(self.client, builder.request, builder.id) .await - .map(crate::user::Response::from_response) } #[cfg(feature = "user")] @@ -59,9 +56,6 @@ where self.executor .execute_many(self.client, builder.request, Vec::from_iter(ids)) .await - .into_iter() - .map(|(k, v)| (k, v.map(crate::user::Response::from_response))) - .collect() } #[cfg(feature = "faction")] @@ -77,7 +71,6 @@ where self.executor .execute(self.client, builder.request, builder.id) .await - .map(crate::faction::Response::from_response) } #[cfg(feature = "faction")] @@ -99,9 +92,6 @@ where self.executor .execute_many(self.client, builder.request, Vec::from_iter(ids)) .await - .into_iter() - .map(|(k, v)| (k, v.map(crate::faction::Response::from_response))) - .collect() } #[cfg(feature = "market")] @@ -117,7 +107,6 @@ where self.executor .execute(self.client, builder.request, builder.id) .await - .map(crate::market::Response::from_response) } #[cfg(feature = "market")] @@ -139,9 +128,6 @@ where self.executor .execute_many(self.client, builder.request, Vec::from_iter(ids)) .await - .into_iter() - .map(|(k, v)| (k, v.map(crate::market::Response::from_response))) - .collect() } #[cfg(feature = "torn")] @@ -157,7 +143,6 @@ where self.executor .execute(self.client, builder.request, builder.id) .await - .map(crate::torn::Response::from_response) } #[cfg(feature = "torn")] @@ -179,9 +164,6 @@ where self.executor .execute_many(self.client, builder.request, Vec::from_iter(ids)) .await - .into_iter() - .map(|(k, v)| (k, v.map(crate::torn::Response::from_response))) - .collect() } #[cfg(feature = "key")] @@ -197,7 +179,6 @@ where self.executor .execute(self.client, builder.request, builder.id) .await - .map(crate::key::Response::from_response) } } @@ -213,7 +194,7 @@ where client: &C, request: ApiRequest, id: Option, - ) -> Result + ) -> Result where A: ApiSelection; @@ -222,7 +203,7 @@ where client: &C, request: ApiRequest, ids: Vec, - ) -> HashMap> + ) -> HashMap> where A: ApiSelection, I: ToString + std::hash::Hash + std::cmp::Eq + Send + Sync; @@ -240,7 +221,7 @@ where client: &C, request: ApiRequest, id: Option, - ) -> Result + ) -> Result where A: ApiSelection, { @@ -248,7 +229,7 @@ where let value = client.request(url).await.map_err(ApiClientError::Client)?; - Ok(ApiResponse::from_value(value)?) + Ok(ApiResponse::from_value(value)?.into()) } async fn execute_many( @@ -256,7 +237,7 @@ where client: &C, request: ApiRequest, ids: Vec, - ) -> HashMap> + ) -> HashMap> where A: ApiSelection, I: ToString + std::hash::Hash + std::cmp::Eq + Send + Sync, @@ -270,7 +251,11 @@ where ( i, - value.and_then(|v| ApiResponse::from_value(v).map_err(Into::into)), + value.and_then(|v| { + ApiResponse::from_value(v) + .map(Into::into) + .map_err(Into::into) + }), ) })) .await; diff --git a/torn-key-pool/Cargo.toml b/torn-key-pool/Cargo.toml index 41ac733..1612f1d 100644 --- a/torn-key-pool/Cargo.toml +++ b/torn-key-pool/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "torn-key-pool" -version = "0.7.0" +version = "0.8.0" edition = "2021" authors = ["Pyrit [2111649]"] license = "MIT" @@ -17,7 +17,7 @@ tokio-runtime = [ "dep:tokio", "dep:rand" ] actix-runtime = [ "dep:actix-rt", "dep:rand" ] [dependencies] -torn-api = { path = "../torn-api", default-features = false, version = "0.6" } +torn-api = { path = "../torn-api", default-features = false, version = "0.7" } async-trait = "0.1" thiserror = "1" diff --git a/torn-key-pool/src/lib.rs b/torn-key-pool/src/lib.rs index 8a8d09b..7b9c961 100644 --- a/torn-key-pool/src/lib.rs +++ b/torn-key-pool/src/lib.rs @@ -3,7 +3,7 @@ #[cfg(feature = "postgres")] pub mod postgres; -pub mod local; +// pub mod local; pub mod send; use std::sync::Arc; @@ -16,11 +16,11 @@ use torn_api::ResponseError; #[derive(Debug, Error)] pub enum KeyPoolError where - S: std::error::Error, + S: std::error::Error + Clone, C: std::error::Error, { #[error("Key pool storage driver error: {0:?}")] - Storage(#[source] Arc), + Storage(#[source] S), #[error(transparent)] Client(#[from] C), @@ -31,7 +31,7 @@ where impl KeyPoolError where - S: std::error::Error, + S: std::error::Error + Clone, C: std::error::Error, { #[inline(always)] @@ -49,9 +49,16 @@ pub trait ApiKey: Sync + Send + std::fmt::Debug + Clone { fn value(&self) -> &str; fn id(&self) -> Self::IdType; + + fn selector(&self) -> KeySelector + where + D: KeyDomain, + { + KeySelector::Id(self.id()) + } } -pub trait KeyDomain: Clone + std::fmt::Debug + Send + Sync { +pub trait KeyDomain: Clone + std::fmt::Debug + Send + Sync + 'static { fn fallback(&self) -> Option { None } @@ -66,7 +73,7 @@ where Key(String), Id(K::IdType), UserId(i32), - Has(D), + Has(Vec), OneOf(Vec), } @@ -78,7 +85,14 @@ where pub(crate) fn fallback(&self) -> Option { match self { Self::Key(_) | Self::UserId(_) | Self::Id(_) => None, - Self::Has(domain) => domain.fallback().map(Self::Has), + Self::Has(domains) => { + let fallbacks: Vec<_> = domains.iter().filter_map(|d| d.fallback()).collect(); + if fallbacks.is_empty() { + None + } else { + Some(Self::Has(fallbacks)) + } + } Self::OneOf(domains) => { let fallbacks: Vec<_> = domains.iter().filter_map(|d| d.fallback()).collect(); if fallbacks.is_empty() { @@ -105,7 +119,7 @@ where D: KeyDomain, { fn into_selector(self) -> KeySelector { - KeySelector::Has(self) + KeySelector::Has(vec![self]) } } @@ -119,11 +133,20 @@ where } } +pub enum KeyAction +where + D: KeyDomain, +{ + Delete, + RemoveDomain(D), + Timeout(chrono::Duration), +} + #[async_trait] pub trait KeyPoolStorage { type Key: ApiKey; type Domain: KeyDomain; - type Error: std::error::Error + Sync + Send; + type Error: std::error::Error + Sync + Send + Clone; async fn acquire_key(&self, selector: S) -> Result where @@ -183,13 +206,20 @@ pub trait KeyPoolStorage { S: IntoSelector; } +#[derive(Debug, Default)] +struct PoolOptions { + comment: Option, + hooks_before: std::collections::HashMap>, + hooks_after: std::collections::HashMap>, +} + #[derive(Debug, Clone)] pub struct KeyPoolExecutor<'a, C, S> where S: KeyPoolStorage, { storage: &'a S, - comment: Option<&'a str>, + options: Arc, selector: KeySelector, _marker: std::marker::PhantomData, } @@ -198,15 +228,15 @@ impl<'a, C, S> KeyPoolExecutor<'a, C, S> where S: KeyPoolStorage, { - pub fn new( + fn new( storage: &'a S, selector: KeySelector, - comment: Option<&'a str>, + options: Arc, ) -> Self { Self { storage, selector, - comment, + options, _marker: std::marker::PhantomData, } } diff --git a/torn-key-pool/src/postgres.rs b/torn-key-pool/src/postgres.rs index 8899259..99ebe01 100644 --- a/torn-key-pool/src/postgres.rs +++ b/torn-key-pool/src/postgres.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use async_trait::async_trait; use indoc::indoc; use sqlx::{FromRow, PgPool, Postgres, QueryBuilder}; @@ -15,13 +17,13 @@ impl PgKeyDomain for T where { } -#[derive(Debug, Error)] +#[derive(Debug, Error, Clone)] pub enum PgStorageError where D: PgKeyDomain, { #[error(transparent)] - Pg(#[from] sqlx::Error), + Pg(Arc), #[error("No key avalaible for domain {0:?}")] Unavailable(KeySelector, D>), @@ -30,6 +32,15 @@ where KeyNotFound(KeySelector, D>), } +impl From for PgStorageError +where + D: PgKeyDomain, +{ + fn from(value: sqlx::Error) -> Self { + Self::Pg(Arc::new(value)) + } +} + #[derive(Debug, Clone, FromRow)] pub struct PgKey where @@ -53,9 +64,9 @@ fn build_predicate<'b, D>( KeySelector::Id(id) => builder.push("id=").push_bind(id), KeySelector::UserId(user_id) => builder.push("user_id=").push_bind(user_id), KeySelector::Key(key) => builder.push("key=").push_bind(key), - KeySelector::Has(domain) => builder + KeySelector::Has(domains) => builder .push("domains @> ") - .push_bind(sqlx::types::Json(vec![domain])), + .push_bind(sqlx::types::Json(domains)), KeySelector::OneOf(domains) => { if domains.is_empty() { builder.push("false"); @@ -607,15 +618,12 @@ where #[cfg(test)] pub(crate) mod test { - use std::sync::{Arc, Once}; + use std::sync::Arc; use sqlx::Row; - use tokio::test; use super::*; - static INIT: Once = Once::new(); - #[derive(Debug, PartialEq, Eq, Clone, serde::Serialize, serde::Deserialize)] #[serde(tag = "type", rename_all = "snake_case")] pub(crate) enum Domain { @@ -634,15 +642,7 @@ pub(crate) mod test { } } - pub(crate) async fn setup() -> (PgKeyPoolStorage, PgKey) { - INIT.call_once(|| { - dotenv::dotenv().ok(); - }); - - let pool = PgPool::connect(&std::env::var("DATABASE_URL").unwrap()) - .await - .unwrap(); - + pub(crate) async fn setup(pool: PgPool) -> (PgKeyPoolStorage, PgKey) { sqlx::query("DROP TABLE IF EXISTS api_keys") .execute(&pool) .await @@ -659,18 +659,18 @@ pub(crate) mod test { (storage, key) } - #[test] - async fn test_initialise() { - let (storage, _) = setup().await; + #[sqlx::test] + async fn test_initialise(pool: PgPool) { + let (storage, _) = setup(pool).await; if let Err(e) = storage.initialise().await { panic!("Initialising key storage failed: {:?}", e); } } - #[test] - async fn test_store_duplicate_key() { - let (storage, key) = setup().await; + #[sqlx::test] + async fn test_store_duplicate_key(pool: PgPool) { + let (storage, key) = setup(pool).await; let key = storage .store_key(1, key.key, vec![Domain::User { id: 1 }]) .await @@ -679,9 +679,9 @@ pub(crate) mod test { assert_eq!(key.domains.0.len(), 2); } - #[test] - async fn test_store_duplicate_key_duplicate_domain() { - let (storage, key) = setup().await; + #[sqlx::test] + async fn test_store_duplicate_key_duplicate_domain(pool: PgPool) { + let (storage, key) = setup(pool).await; let key = storage .store_key(1, key.key, vec![Domain::All]) .await @@ -690,9 +690,9 @@ pub(crate) mod test { assert_eq!(key.domains.0.len(), 1); } - #[test] - async fn test_add_domain() { - let (storage, key) = setup().await; + #[sqlx::test] + async fn test_add_domain(pool: PgPool) { + let (storage, key) = setup(pool).await; let key = storage .add_domain_to_key(KeySelector::Key(key.key), Domain::User { id: 12345 }) .await @@ -701,9 +701,9 @@ pub(crate) mod test { assert!(key.domains.0.contains(&Domain::User { id: 12345 })); } - #[test] - async fn test_add_domain_id() { - let (storage, key) = setup().await; + #[sqlx::test] + async fn test_add_domain_id(pool: PgPool) { + let (storage, key) = setup(pool).await; let key = storage .add_domain_to_key(KeySelector::Id(key.id), Domain::User { id: 12345 }) .await @@ -712,9 +712,9 @@ pub(crate) mod test { assert!(key.domains.0.contains(&Domain::User { id: 12345 })); } - #[test] - async fn test_add_duplicate_domain() { - let (storage, key) = setup().await; + #[sqlx::test] + async fn test_add_duplicate_domain(pool: PgPool) { + let (storage, key) = setup(pool).await; let key = storage .add_domain_to_key(KeySelector::Key(key.key), Domain::All) .await @@ -729,9 +729,9 @@ pub(crate) mod test { ); } - #[test] - async fn test_remove_domain() { - let (storage, key) = setup().await; + #[sqlx::test] + async fn test_remove_domain(pool: PgPool) { + let (storage, key) = setup(pool).await; storage .add_domain_to_key(KeySelector::Key(key.key.clone()), Domain::User { id: 1 }) .await @@ -744,9 +744,9 @@ pub(crate) mod test { assert_eq!(key.domains.0, vec![Domain::All]); } - #[test] - async fn test_remove_domain_id() { - let (storage, key) = setup().await; + #[sqlx::test] + async fn test_remove_domain_id(pool: PgPool) { + let (storage, key) = setup(pool).await; storage .add_domain_to_key(KeySelector::Id(key.id), Domain::User { id: 1 }) .await @@ -759,9 +759,9 @@ pub(crate) mod test { assert_eq!(key.domains.0, vec![Domain::All]); } - #[test] - async fn test_remove_last_domain() { - let (storage, key) = setup().await; + #[sqlx::test] + async fn test_remove_last_domain(pool: PgPool) { + let (storage, key) = setup(pool).await; let key = storage .remove_domain_from_key(KeySelector::Key(key.key), Domain::All) .await @@ -770,9 +770,9 @@ pub(crate) mod test { assert!(key.domains.0.is_empty()); } - #[test] - async fn test_store_key() { - let (storage, _) = setup().await; + #[sqlx::test] + async fn test_store_key(pool: PgPool) { + let (storage, _) = setup(pool).await; let key = storage .store_key(1, "ABCDABCDABCDABCD".to_owned(), vec![]) .await @@ -780,26 +780,26 @@ pub(crate) mod test { assert_eq!(key.value(), "ABCDABCDABCDABCD"); } - #[test] - async fn test_read_user_keys() { - let (storage, _) = setup().await; + #[sqlx::test] + async fn test_read_user_keys(pool: PgPool) { + let (storage, _) = setup(pool).await; let keys = storage.read_keys(KeySelector::UserId(1)).await.unwrap(); assert_eq!(keys.len(), 1); } - #[test] - async fn acquire_one() { - let (storage, _) = setup().await; + #[sqlx::test] + async fn acquire_one(pool: PgPool) { + let (storage, _) = setup(pool).await; if let Err(e) = storage.acquire_key(Domain::All).await { panic!("Acquiring key failed: {:?}", e); } } - #[test] - async fn uses_spread() { - let (storage, _) = setup().await; + #[sqlx::test] + async fn uses_spread(pool: PgPool) { + let (storage, _) = setup(pool).await; storage .store_key(1, "ABC".to_owned(), vec![Domain::All]) .await @@ -816,33 +816,37 @@ pub(crate) mod test { } } - #[test] - async fn test_flag_key_one() { - let (storage, key) = setup().await; + #[sqlx::test] + async fn test_flag_key_one(pool: PgPool) { + let (storage, key) = setup(pool).await; assert!(storage.flag_key(key, 2).await.unwrap()); match storage.acquire_key(Domain::All).await.unwrap_err() { - PgStorageError::Unavailable(d) => assert!(matches!(d, KeySelector::Has(Domain::All))), + PgStorageError::Unavailable(KeySelector::Has(domains)) => { + assert_eq!(domains, vec![Domain::All]) + } why => panic!("Expected domain unavailable error but found '{why}'"), } } - #[test] - async fn test_flag_key_many() { - let (storage, key) = setup().await; + #[sqlx::test] + async fn test_flag_key_many(pool: PgPool) { + let (storage, key) = setup(pool).await; assert!(storage.flag_key(key, 2).await.unwrap()); match storage.acquire_many_keys(Domain::All, 5).await.unwrap_err() { - PgStorageError::Unavailable(d) => assert!(matches!(d, KeySelector::Has(Domain::All))), + PgStorageError::Unavailable(KeySelector::Has(domains)) => { + assert_eq!(domains, vec![Domain::All]) + } why => panic!("Expected domain unavailable error but found '{why}'"), } } - #[test] - async fn acquire_many() { - let (storage, _) = setup().await; + #[sqlx::test] + async fn acquire_many(pool: PgPool) { + let (storage, _) = setup(pool).await; match storage.acquire_many_keys(Domain::All, 30).await { Err(e) => panic!("Acquiring key failed: {:?}", e), @@ -851,9 +855,9 @@ pub(crate) mod test { } // HACK: this test is time sensitive and will fail if runs at the top of the minute - #[test] - async fn test_concurrent() { - let storage = Arc::new(setup().await.0); + #[sqlx::test] + async fn test_concurrent(pool: PgPool) { + let storage = Arc::new(setup(pool).await.0); for _ in 0..10 { let mut set = tokio::task::JoinSet::new(); @@ -884,9 +888,9 @@ pub(crate) mod test { } } - #[test] - async fn test_concurrent_spread() { - let storage = Arc::new(setup().await.0); + #[sqlx::test] + async fn test_concurrent_spread(pool: PgPool) { + let storage = Arc::new(setup(pool).await.0); for i in 0..24 { storage @@ -923,10 +927,11 @@ pub(crate) mod test { .unwrap(); } } + // HACK: this test is time sensitive and will fail if runs at the top of the minute - #[test] - async fn test_concurrent_many() { - let storage = Arc::new(setup().await.0); + #[sqlx::test] + async fn test_concurrent_many(pool: PgPool) { + let storage = Arc::new(setup(pool).await.0); for _ in 0..10 { let mut set = tokio::task::JoinSet::new(); @@ -956,73 +961,73 @@ pub(crate) mod test { } } - #[test] - async fn read_key() { - let (storage, key) = setup().await; + #[sqlx::test] + async fn read_key(pool: PgPool) { + let (storage, key) = setup(pool).await; let key = storage.read_key(KeySelector::Key(key.key)).await.unwrap(); assert!(key.is_some()); } - #[test] - async fn read_key_id() { - let (storage, key) = setup().await; + #[sqlx::test] + async fn read_key_id(pool: PgPool) { + let (storage, key) = setup(pool).await; let key = storage.read_key(KeySelector::Id(key.id)).await.unwrap(); assert!(key.is_some()); } - #[test] - async fn read_nonexistent_key() { - let (storage, _) = setup().await; + #[sqlx::test] + async fn read_nonexistent_key(pool: PgPool) { + let (storage, _) = setup(pool).await; let key = storage.read_key(KeySelector::Id(-1)).await.unwrap(); assert!(key.is_none()); } - #[test] - async fn query_key() { - let (storage, _) = setup().await; + #[sqlx::test] + async fn query_key(pool: PgPool) { + let (storage, _) = setup(pool).await; let key = storage.read_key(Domain::All).await.unwrap(); assert!(key.is_some()); } - #[test] - async fn query_nonexistent_key() { - let (storage, _) = setup().await; + #[sqlx::test] + async fn query_nonexistent_key(pool: PgPool) { + let (storage, _) = setup(pool).await; let key = storage.read_key(Domain::Guild { id: 0 }).await.unwrap(); assert!(key.is_none()); } - #[test] - async fn query_all() { - let (storage, _) = setup().await; + #[sqlx::test] + async fn query_all(pool: PgPool) { + let (storage, _) = setup(pool).await; let keys = storage.read_keys(Domain::All).await.unwrap(); assert!(keys.len() == 1); } - #[test] - async fn query_by_id() { - let (storage, _) = setup().await; + #[sqlx::test] + async fn query_by_id(pool: PgPool) { + let (storage, _) = setup(pool).await; let key = storage.read_key(KeySelector::Id(1)).await.unwrap(); assert!(key.is_some()); } - #[test] - async fn query_by_key() { - let (storage, key) = setup().await; + #[sqlx::test] + async fn query_by_key(pool: PgPool) { + let (storage, key) = setup(pool).await; let key = storage.read_key(KeySelector::Key(key.key)).await.unwrap(); assert!(key.is_some()); } - #[test] - async fn query_by_set() { - let (storage, _key) = setup().await; + #[sqlx::test] + async fn query_by_set(pool: PgPool) { + let (storage, _key) = setup(pool).await; let key = storage .read_key(KeySelector::OneOf(vec![ Domain::All, @@ -1034,4 +1039,45 @@ pub(crate) mod test { assert!(key.is_some()); } + + #[sqlx::test] + async fn all_selector(pool: PgPool) { + let (storage, key) = setup(pool).await; + + storage + .add_domain_to_key(key.selector(), Domain::Faction { id: 1 }) + .await + .unwrap(); + + let key = storage + .read_key(KeySelector::Has(vec![ + Domain::Faction { id: 1 }, + Domain::All, + ])) + .await + .unwrap(); + + assert!(key.is_some()); + + let key = storage + .read_key(KeySelector::Has(vec![ + Domain::All, + Domain::Faction { id: 1 }, + ])) + .await + .unwrap(); + + assert!(key.is_some()); + + let key = storage + .read_key(KeySelector::Has(vec![ + Domain::All, + Domain::Faction { id: 2 }, + Domain::Faction { id: 1 }, + ])) + .await + .unwrap(); + + assert!(key.is_none()); + } } diff --git a/torn-key-pool/src/send.rs b/torn-key-pool/src/send.rs index d1f95b5..5b528d8 100644 --- a/torn-key-pool/src/send.rs +++ b/torn-key-pool/src/send.rs @@ -7,7 +7,9 @@ use torn_api::{ ApiRequest, ApiResponse, ApiSelection, ResponseError, }; -use crate::{ApiKey, IntoSelector, KeyPoolError, KeyPoolExecutor, KeyPoolStorage}; +use crate::{ + ApiKey, IntoSelector, KeyAction, KeyPoolError, KeyPoolExecutor, KeyPoolStorage, PoolOptions, +}; #[async_trait] impl<'client, C, S> RequestExecutor for KeyPoolExecutor<'client, C, S> @@ -22,17 +24,22 @@ where client: &C, mut request: ApiRequest, id: Option, - ) -> Result + ) -> Result where A: ApiSelection, { - request.comment = self.comment.map(ToOwned::to_owned); + request.comment = self.options.comment.clone(); + if let Some(hook) = self.options.hooks_before.get(&std::any::TypeId::of::()) { + let concrete = hook.downcast_ref::>().unwrap(); + + (concrete.body)(&mut request); + } loop { let key = self .storage .acquire_key(self.selector.clone()) .await - .map_err(|e| KeyPoolError::Storage(Arc::new(e)))?; + .map_err(KeyPoolError::Storage)?; let url = request.url(key.value(), id.as_deref()); let value = client.request(url).await?; @@ -42,14 +49,37 @@ where .storage .flag_key(key, code) .await - .map_err(Arc::new) .map_err(KeyPoolError::Storage)? { return Err(KeyPoolError::Response(ResponseError::Api { code, reason })); } } Err(parsing_error) => return Err(KeyPoolError::Response(parsing_error)), - Ok(res) => return Ok(res), + Ok(res) => { + let res = res.into(); + if let Some(hook) = self.options.hooks_after.get(&std::any::TypeId::of::()) { + let concrete = hook.downcast_ref::>().unwrap(); + + match (concrete.body)(&res) { + Err(KeyAction::Delete) => { + self.storage + .remove_key(key.selector()) + .await + .map_err(KeyPoolError::Storage)?; + continue; + } + Err(KeyAction::RemoveDomain(domain)) => { + self.storage + .remove_domain_from_key(key.selector(), domain) + .await + .map_err(KeyPoolError::Storage)?; + continue; + } + _ => (), + }; + } + return Ok(res); + } }; } } @@ -59,7 +89,7 @@ where client: &C, mut request: ApiRequest, ids: Vec, - ) -> HashMap> + ) -> HashMap> where A: ApiSelection, I: ToString + std::hash::Hash + std::cmp::Eq + Send + Sync, @@ -71,15 +101,14 @@ where { Ok(keys) => keys, Err(why) => { - let shared = Arc::new(why); return ids .into_iter() - .map(|i| (i, Err(Self::Error::Storage(shared.clone())))) + .map(|i| (i, Err(Self::Error::Storage(why.clone())))) .collect(); } }; - request.comment = self.comment.map(ToOwned::to_owned); + request.comment = self.options.comment.clone(); let request_ref = &request; let tuples = @@ -105,18 +134,18 @@ where ) } Ok(true) => (), - Err(why) => return (id, Err(KeyPoolError::Storage(Arc::new(why)))), + Err(why) => return (id, Err(KeyPoolError::Storage(why))), } } Err(parsing_error) => { return (id, Err(KeyPoolError::Response(parsing_error))) } - Ok(res) => return (id, Ok(res)), + Ok(res) => return (id, Ok(res.into())), }; key = match self.storage.acquire_key(self.selector.clone()).await { Ok(k) => k, - Err(why) => return (id, Err(Self::Error::Storage(Arc::new(why)))), + Err(why) => return (id, Err(Self::Error::Storage(why))), }; } })) @@ -126,6 +155,92 @@ where } } +#[allow(clippy::type_complexity)] +pub struct BeforeHook +where + A: ApiSelection, +{ + body: Box) + Send + Sync + 'static>, +} + +#[allow(clippy::type_complexity)] +pub struct AfterHook +where + A: ApiSelection, + D: crate::KeyDomain, +{ + body: Box Result<(), crate::KeyAction> + Send + Sync + 'static>, +} + +pub struct PoolBuilder +where + C: ApiClient, + S: KeyPoolStorage, +{ + client: C, + storage: S, + options: crate::PoolOptions, +} + +impl PoolBuilder +where + C: ApiClient, + S: KeyPoolStorage, +{ + pub fn new(client: C, storage: S) -> Self { + Self { + client, + storage, + options: Default::default(), + } + } + + pub fn comment(mut self, c: impl ToString) -> Self { + self.options.comment = Some(c.to_string()); + self + } + + pub fn hook_before( + mut self, + hook: impl Fn(&mut ApiRequest) + Send + Sync + 'static, + ) -> Self + where + A: ApiSelection + 'static, + { + self.options.hooks_before.insert( + std::any::TypeId::of::(), + Box::new(BeforeHook { + body: Box::new(hook), + }), + ); + self + } + + pub fn hook_after( + mut self, + hook: impl Fn(&A::Response) -> Result<(), KeyAction> + Send + Sync + 'static, + ) -> Self + where + A: ApiSelection + 'static, + { + self.options.hooks_after.insert( + std::any::TypeId::of::(), + Box::new(AfterHook:: { + body: Box::new(hook), + }), + ); + self + } + + pub fn build(self) -> KeyPool { + KeyPool { + client: self.client, + storage: self.storage, + options: Arc::new(self.options), + } + } +} + #[derive(Clone, Debug)] pub struct KeyPool where @@ -134,7 +249,7 @@ where { client: C, pub storage: S, - comment: Option, + options: Arc, } impl KeyPool @@ -142,14 +257,6 @@ where C: ApiClient, S: KeyPoolStorage + Send + Sync + 'static, { - pub fn new(client: C, storage: S, comment: Option) -> Self { - Self { - client, - storage, - comment, - } - } - pub fn torn_api(&self, selector: I) -> ApiProvider> where I: IntoSelector, @@ -159,7 +266,7 @@ where KeyPoolExecutor::new( &self.storage, selector.into_selector(), - self.comment.as_deref(), + self.options.clone(), ), ) } @@ -178,7 +285,7 @@ pub trait WithStorage { { ApiProvider::new( self, - KeyPoolExecutor::new(storage, selector.into_selector(), None), + KeyPoolExecutor::new(storage, selector.into_selector(), Default::default()), ) } } @@ -188,27 +295,28 @@ impl WithStorage for reqwest::Client {} #[cfg(all(test, feature = "postgres", feature = "reqwest"))] mod test { - use tokio::test; + use sqlx::PgPool; use super::*; - use crate::postgres::test::{setup, Domain}; + use crate::{ + postgres::test::{setup, Domain}, + KeySelector, + }; - #[test] - async fn test_pool_request() { - let (storage, _) = setup().await; - let pool = KeyPool::new( - reqwest::Client::default(), - storage, - Some("api.rs".to_owned()), - ); + #[sqlx::test] + async fn test_pool_request(pool: PgPool) { + let (storage, _) = setup(pool).await; + let pool = PoolBuilder::new(reqwest::Client::default(), storage) + .comment("api.rs") + .build(); let response = pool.torn_api(Domain::All).user(|b| b).await.unwrap(); _ = response.profile().unwrap(); } - #[test] - async fn test_with_storage_request() { - let (storage, _) = setup().await; + #[sqlx::test] + async fn test_with_storage_request(pool: PgPool) { + let (storage, _) = setup(pool).await; let response = reqwest::Client::new() .with_storage(&storage, Domain::All) @@ -217,4 +325,36 @@ mod test { .unwrap(); _ = response.profile().unwrap(); } + + #[sqlx::test] + async fn before_hook(pool: PgPool) { + let (storage, _) = setup(pool).await; + + let pool = PoolBuilder::new(reqwest::Client::default(), storage) + .hook_before::(|req| { + req.selections.push("crimes"); + }) + .build(); + + let response = pool.torn_api(Domain::All).user(|b| b).await.unwrap(); + _ = response.crimes().unwrap(); + } + + #[sqlx::test] + async fn after_hook(pool: PgPool) { + let (storage, _) = setup(pool).await; + + let pool = PoolBuilder::new(reqwest::Client::default(), storage) + .hook_after::(|_res| Err(KeyAction::Delete)) + .build(); + + let key = pool.storage.read_key(KeySelector::Id(1)).await.unwrap(); + assert!(key.is_some()); + + let response = pool.torn_api(Domain::All).user(|b| b).await; + assert!(matches!(response, Err(KeyPoolError::Storage(_)))); + + let key = pool.storage.read_key(KeySelector::Id(1)).await.unwrap(); + assert!(key.is_none()); + } }