major refactoring

This commit is contained in:
TotallyNot 2024-04-04 15:59:10 +02:00
parent e8a8b5976b
commit 8aaf61efb1
10 changed files with 404 additions and 222 deletions

View file

@ -1,6 +1,6 @@
[package] [package]
name = "torn-api-macros" name = "torn-api-macros"
version = "0.2.0" version = "0.3.0"
edition = "2021" edition = "2021"
authors = ["Pyrit [2111649]"] authors = ["Pyrit [2111649]"]
license = "MIT" license = "MIT"

View file

@ -147,15 +147,15 @@ fn impl_api_category(ast: &syn::DeriveInput) -> TokenStream {
#(#accessors)* #(#accessors)*
} }
impl crate::ApiCategoryResponse for Response { impl From<crate::ApiResponse> for Response {
type Selection = #name; fn from(value: crate::ApiResponse) -> Self {
Self(value)
fn from_response(response: crate::ApiResponse) -> Self {
Self(response)
} }
} }
impl crate::ApiSelection for #name { impl crate::ApiSelection for #name {
type Response = Response;
fn raw_value(self) -> &'static str { fn raw_value(self) -> &'static str {
match self { match self {
#(#raw_values,)* #(#raw_values,)*

View file

@ -1,6 +1,6 @@
[package] [package]
name = "torn-api" name = "torn-api"
version = "0.6.7" version = "0.7.0"
edition = "2021" edition = "2021"
rust-version = "1.75.0" rust-version = "1.75.0"
authors = ["Pyrit [2111649]"] authors = ["Pyrit [2111649]"]
@ -39,7 +39,7 @@ reqwest = { version = "0.11", default-features = false, features = [ "json" ], o
awc = { version = "3", default-features = false, optional = true } awc = { version = "3", default-features = false, optional = true }
rust_decimal = { version = "1", default-features = false, optional = true, features = [ "serde" ] } rust_decimal = { version = "1", default-features = false, optional = true, features = [ "serde" ] }
torn-api-macros = { path = "../torn-api-macros", version = "0.2" } torn-api-macros = { path = "../torn-api-macros", version = "0.3" }
[dev-dependencies] [dev-dependencies]
actix-rt = { version = "2.7.0" } actix-rt = { version = "2.7.0" }

View file

@ -111,18 +111,14 @@ impl ApiResponse {
} }
} }
pub trait ApiSelection: Send + Sync { pub trait ApiSelection: Send + Sync + 'static {
type Response: From<ApiResponse> + Send + Sync;
fn raw_value(self) -> &'static str; fn raw_value(self) -> &'static str;
fn category() -> &'static str; fn category() -> &'static str;
} }
pub trait ApiCategoryResponse: Send + Sync {
type Selection: ApiSelection;
fn from_response(response: ApiResponse) -> Self;
}
pub struct DirectExecutor<C> { pub struct DirectExecutor<C> {
key: String, key: String,
_marker: std::marker::PhantomData<C>, _marker: std::marker::PhantomData<C>,

View file

@ -2,9 +2,7 @@ use std::collections::HashMap;
use async_trait::async_trait; use async_trait::async_trait;
use crate::{ use crate::{ApiClientError, ApiRequest, ApiResponse, ApiSelection, DirectExecutor};
ApiCategoryResponse, ApiClientError, ApiRequest, ApiResponse, ApiSelection, DirectExecutor,
};
pub struct ApiProvider<'a, C, E> pub struct ApiProvider<'a, C, E>
where where
@ -39,7 +37,6 @@ where
self.executor self.executor
.execute(self.client, builder.request, builder.id) .execute(self.client, builder.request, builder.id)
.await .await
.map(crate::user::Response::from_response)
} }
#[cfg(feature = "user")] #[cfg(feature = "user")]
@ -61,9 +58,6 @@ where
self.executor self.executor
.execute_many(self.client, builder.request, Vec::from_iter(ids)) .execute_many(self.client, builder.request, Vec::from_iter(ids))
.await .await
.into_iter()
.map(|(k, v)| (k, v.map(crate::user::Response::from_response)))
.collect()
} }
#[cfg(feature = "faction")] #[cfg(feature = "faction")]
@ -79,7 +73,6 @@ where
self.executor self.executor
.execute(self.client, builder.request, builder.id) .execute(self.client, builder.request, builder.id)
.await .await
.map(crate::faction::Response::from_response)
} }
#[cfg(feature = "faction")] #[cfg(feature = "faction")]
@ -101,9 +94,6 @@ where
self.executor self.executor
.execute_many(self.client, builder.request, Vec::from_iter(ids)) .execute_many(self.client, builder.request, Vec::from_iter(ids))
.await .await
.into_iter()
.map(|(k, v)| (k, v.map(crate::faction::Response::from_response)))
.collect()
} }
#[cfg(feature = "market")] #[cfg(feature = "market")]
@ -119,7 +109,6 @@ where
self.executor self.executor
.execute(self.client, builder.request, builder.id) .execute(self.client, builder.request, builder.id)
.await .await
.map(crate::market::Response::from_response)
} }
#[cfg(feature = "market")] #[cfg(feature = "market")]
@ -141,9 +130,6 @@ where
self.executor self.executor
.execute_many(self.client, builder.request, Vec::from_iter(ids)) .execute_many(self.client, builder.request, Vec::from_iter(ids))
.await .await
.into_iter()
.map(|(k, v)| (k, v.map(crate::market::Response::from_response)))
.collect()
} }
#[cfg(feature = "torn")] #[cfg(feature = "torn")]
@ -159,7 +145,6 @@ where
self.executor self.executor
.execute(self.client, builder.request, builder.id) .execute(self.client, builder.request, builder.id)
.await .await
.map(crate::torn::Response::from_response)
} }
#[cfg(feature = "torn")] #[cfg(feature = "torn")]
@ -181,9 +166,6 @@ where
self.executor self.executor
.execute_many(self.client, builder.request, Vec::from_iter(ids)) .execute_many(self.client, builder.request, Vec::from_iter(ids))
.await .await
.into_iter()
.map(|(k, v)| (k, v.map(crate::torn::Response::from_response)))
.collect()
} }
#[cfg(feature = "key")] #[cfg(feature = "key")]
@ -199,7 +181,6 @@ where
self.executor self.executor
.execute(self.client, builder.request, builder.id) .execute(self.client, builder.request, builder.id)
.await .await
.map(crate::key::Response::from_response)
} }
} }
@ -215,7 +196,7 @@ where
client: &C, client: &C,
request: ApiRequest<A>, request: ApiRequest<A>,
id: Option<String>, id: Option<String>,
) -> Result<ApiResponse, Self::Error> ) -> Result<A::Response, Self::Error>
where where
A: ApiSelection; A: ApiSelection;
@ -224,7 +205,7 @@ where
client: &C, client: &C,
request: ApiRequest<A>, request: ApiRequest<A>,
ids: Vec<I>, ids: Vec<I>,
) -> HashMap<I, Result<ApiResponse, Self::Error>> ) -> HashMap<I, Result<A::Response, Self::Error>>
where where
A: ApiSelection, A: ApiSelection,
I: ToString + std::hash::Hash + std::cmp::Eq; I: ToString + std::hash::Hash + std::cmp::Eq;
@ -242,7 +223,7 @@ where
client: &C, client: &C,
request: ApiRequest<A>, request: ApiRequest<A>,
id: Option<String>, id: Option<String>,
) -> Result<ApiResponse, Self::Error> ) -> Result<A::Response, Self::Error>
where where
A: ApiSelection, A: ApiSelection,
{ {
@ -250,7 +231,7 @@ where
let value = client.request(url).await.map_err(ApiClientError::Client)?; let value = client.request(url).await.map_err(ApiClientError::Client)?;
Ok(ApiResponse::from_value(value)?) Ok(ApiResponse::from_value(value)?.into())
} }
async fn execute_many<A, I>( async fn execute_many<A, I>(
@ -258,7 +239,7 @@ where
client: &C, client: &C,
request: ApiRequest<A>, request: ApiRequest<A>,
ids: Vec<I>, ids: Vec<I>,
) -> HashMap<I, Result<ApiResponse, Self::Error>> ) -> HashMap<I, Result<A::Response, Self::Error>>
where where
A: ApiSelection, A: ApiSelection,
I: ToString + std::hash::Hash + std::cmp::Eq, I: ToString + std::hash::Hash + std::cmp::Eq,
@ -272,7 +253,11 @@ where
( (
i, i,
value.and_then(|v| ApiResponse::from_value(v).map_err(Into::into)), value.and_then(|v| {
ApiResponse::from_value(v)
.map(Into::into)
.map_err(Into::into)
}),
) )
})) }))
.await; .await;

View file

@ -2,9 +2,7 @@ use std::collections::HashMap;
use async_trait::async_trait; use async_trait::async_trait;
use crate::{ use crate::{ApiClientError, ApiRequest, ApiResponse, ApiSelection, DirectExecutor};
ApiCategoryResponse, ApiClientError, ApiRequest, ApiResponse, ApiSelection, DirectExecutor,
};
pub struct ApiProvider<'a, C, E> pub struct ApiProvider<'a, C, E>
where where
@ -37,7 +35,6 @@ where
self.executor self.executor
.execute(self.client, builder.request, builder.id) .execute(self.client, builder.request, builder.id)
.await .await
.map(crate::user::Response::from_response)
} }
#[cfg(feature = "user")] #[cfg(feature = "user")]
@ -59,9 +56,6 @@ where
self.executor self.executor
.execute_many(self.client, builder.request, Vec::from_iter(ids)) .execute_many(self.client, builder.request, Vec::from_iter(ids))
.await .await
.into_iter()
.map(|(k, v)| (k, v.map(crate::user::Response::from_response)))
.collect()
} }
#[cfg(feature = "faction")] #[cfg(feature = "faction")]
@ -77,7 +71,6 @@ where
self.executor self.executor
.execute(self.client, builder.request, builder.id) .execute(self.client, builder.request, builder.id)
.await .await
.map(crate::faction::Response::from_response)
} }
#[cfg(feature = "faction")] #[cfg(feature = "faction")]
@ -99,9 +92,6 @@ where
self.executor self.executor
.execute_many(self.client, builder.request, Vec::from_iter(ids)) .execute_many(self.client, builder.request, Vec::from_iter(ids))
.await .await
.into_iter()
.map(|(k, v)| (k, v.map(crate::faction::Response::from_response)))
.collect()
} }
#[cfg(feature = "market")] #[cfg(feature = "market")]
@ -117,7 +107,6 @@ where
self.executor self.executor
.execute(self.client, builder.request, builder.id) .execute(self.client, builder.request, builder.id)
.await .await
.map(crate::market::Response::from_response)
} }
#[cfg(feature = "market")] #[cfg(feature = "market")]
@ -139,9 +128,6 @@ where
self.executor self.executor
.execute_many(self.client, builder.request, Vec::from_iter(ids)) .execute_many(self.client, builder.request, Vec::from_iter(ids))
.await .await
.into_iter()
.map(|(k, v)| (k, v.map(crate::market::Response::from_response)))
.collect()
} }
#[cfg(feature = "torn")] #[cfg(feature = "torn")]
@ -157,7 +143,6 @@ where
self.executor self.executor
.execute(self.client, builder.request, builder.id) .execute(self.client, builder.request, builder.id)
.await .await
.map(crate::torn::Response::from_response)
} }
#[cfg(feature = "torn")] #[cfg(feature = "torn")]
@ -179,9 +164,6 @@ where
self.executor self.executor
.execute_many(self.client, builder.request, Vec::from_iter(ids)) .execute_many(self.client, builder.request, Vec::from_iter(ids))
.await .await
.into_iter()
.map(|(k, v)| (k, v.map(crate::torn::Response::from_response)))
.collect()
} }
#[cfg(feature = "key")] #[cfg(feature = "key")]
@ -197,7 +179,6 @@ where
self.executor self.executor
.execute(self.client, builder.request, builder.id) .execute(self.client, builder.request, builder.id)
.await .await
.map(crate::key::Response::from_response)
} }
} }
@ -213,7 +194,7 @@ where
client: &C, client: &C,
request: ApiRequest<A>, request: ApiRequest<A>,
id: Option<String>, id: Option<String>,
) -> Result<ApiResponse, Self::Error> ) -> Result<A::Response, Self::Error>
where where
A: ApiSelection; A: ApiSelection;
@ -222,7 +203,7 @@ where
client: &C, client: &C,
request: ApiRequest<A>, request: ApiRequest<A>,
ids: Vec<I>, ids: Vec<I>,
) -> HashMap<I, Result<ApiResponse, Self::Error>> ) -> HashMap<I, Result<A::Response, Self::Error>>
where where
A: ApiSelection, A: ApiSelection,
I: ToString + std::hash::Hash + std::cmp::Eq + Send + Sync; I: ToString + std::hash::Hash + std::cmp::Eq + Send + Sync;
@ -240,7 +221,7 @@ where
client: &C, client: &C,
request: ApiRequest<A>, request: ApiRequest<A>,
id: Option<String>, id: Option<String>,
) -> Result<ApiResponse, Self::Error> ) -> Result<A::Response, Self::Error>
where where
A: ApiSelection, A: ApiSelection,
{ {
@ -248,7 +229,7 @@ where
let value = client.request(url).await.map_err(ApiClientError::Client)?; let value = client.request(url).await.map_err(ApiClientError::Client)?;
Ok(ApiResponse::from_value(value)?) Ok(ApiResponse::from_value(value)?.into())
} }
async fn execute_many<A, I>( async fn execute_many<A, I>(
@ -256,7 +237,7 @@ where
client: &C, client: &C,
request: ApiRequest<A>, request: ApiRequest<A>,
ids: Vec<I>, ids: Vec<I>,
) -> HashMap<I, Result<ApiResponse, Self::Error>> ) -> HashMap<I, Result<A::Response, Self::Error>>
where where
A: ApiSelection, A: ApiSelection,
I: ToString + std::hash::Hash + std::cmp::Eq + Send + Sync, I: ToString + std::hash::Hash + std::cmp::Eq + Send + Sync,
@ -270,7 +251,11 @@ where
( (
i, i,
value.and_then(|v| ApiResponse::from_value(v).map_err(Into::into)), value.and_then(|v| {
ApiResponse::from_value(v)
.map(Into::into)
.map_err(Into::into)
}),
) )
})) }))
.await; .await;

View file

@ -1,6 +1,6 @@
[package] [package]
name = "torn-key-pool" name = "torn-key-pool"
version = "0.7.0" version = "0.8.0"
edition = "2021" edition = "2021"
authors = ["Pyrit [2111649]"] authors = ["Pyrit [2111649]"]
license = "MIT" license = "MIT"
@ -17,7 +17,7 @@ tokio-runtime = [ "dep:tokio", "dep:rand" ]
actix-runtime = [ "dep:actix-rt", "dep:rand" ] actix-runtime = [ "dep:actix-rt", "dep:rand" ]
[dependencies] [dependencies]
torn-api = { path = "../torn-api", default-features = false, version = "0.6" } torn-api = { path = "../torn-api", default-features = false, version = "0.7" }
async-trait = "0.1" async-trait = "0.1"
thiserror = "1" thiserror = "1"

View file

@ -3,7 +3,7 @@
#[cfg(feature = "postgres")] #[cfg(feature = "postgres")]
pub mod postgres; pub mod postgres;
pub mod local; // pub mod local;
pub mod send; pub mod send;
use std::sync::Arc; use std::sync::Arc;
@ -16,11 +16,11 @@ use torn_api::ResponseError;
#[derive(Debug, Error)] #[derive(Debug, Error)]
pub enum KeyPoolError<S, C> pub enum KeyPoolError<S, C>
where where
S: std::error::Error, S: std::error::Error + Clone,
C: std::error::Error, C: std::error::Error,
{ {
#[error("Key pool storage driver error: {0:?}")] #[error("Key pool storage driver error: {0:?}")]
Storage(#[source] Arc<S>), Storage(#[source] S),
#[error(transparent)] #[error(transparent)]
Client(#[from] C), Client(#[from] C),
@ -31,7 +31,7 @@ where
impl<S, C> KeyPoolError<S, C> impl<S, C> KeyPoolError<S, C>
where where
S: std::error::Error, S: std::error::Error + Clone,
C: std::error::Error, C: std::error::Error,
{ {
#[inline(always)] #[inline(always)]
@ -49,9 +49,16 @@ pub trait ApiKey: Sync + Send + std::fmt::Debug + Clone {
fn value(&self) -> &str; fn value(&self) -> &str;
fn id(&self) -> Self::IdType; fn id(&self) -> Self::IdType;
fn selector<D>(&self) -> KeySelector<Self, D>
where
D: KeyDomain,
{
KeySelector::Id(self.id())
}
} }
pub trait KeyDomain: Clone + std::fmt::Debug + Send + Sync { pub trait KeyDomain: Clone + std::fmt::Debug + Send + Sync + 'static {
fn fallback(&self) -> Option<Self> { fn fallback(&self) -> Option<Self> {
None None
} }
@ -66,7 +73,7 @@ where
Key(String), Key(String),
Id(K::IdType), Id(K::IdType),
UserId(i32), UserId(i32),
Has(D), Has(Vec<D>),
OneOf(Vec<D>), OneOf(Vec<D>),
} }
@ -78,7 +85,14 @@ where
pub(crate) fn fallback(&self) -> Option<Self> { pub(crate) fn fallback(&self) -> Option<Self> {
match self { match self {
Self::Key(_) | Self::UserId(_) | Self::Id(_) => None, Self::Key(_) | Self::UserId(_) | Self::Id(_) => None,
Self::Has(domain) => domain.fallback().map(Self::Has), Self::Has(domains) => {
let fallbacks: Vec<_> = domains.iter().filter_map(|d| d.fallback()).collect();
if fallbacks.is_empty() {
None
} else {
Some(Self::Has(fallbacks))
}
}
Self::OneOf(domains) => { Self::OneOf(domains) => {
let fallbacks: Vec<_> = domains.iter().filter_map(|d| d.fallback()).collect(); let fallbacks: Vec<_> = domains.iter().filter_map(|d| d.fallback()).collect();
if fallbacks.is_empty() { if fallbacks.is_empty() {
@ -105,7 +119,7 @@ where
D: KeyDomain, D: KeyDomain,
{ {
fn into_selector(self) -> KeySelector<K, D> { fn into_selector(self) -> KeySelector<K, D> {
KeySelector::Has(self) KeySelector::Has(vec![self])
} }
} }
@ -119,11 +133,20 @@ where
} }
} }
pub enum KeyAction<D>
where
D: KeyDomain,
{
Delete,
RemoveDomain(D),
Timeout(chrono::Duration),
}
#[async_trait] #[async_trait]
pub trait KeyPoolStorage { pub trait KeyPoolStorage {
type Key: ApiKey; type Key: ApiKey;
type Domain: KeyDomain; type Domain: KeyDomain;
type Error: std::error::Error + Sync + Send; type Error: std::error::Error + Sync + Send + Clone;
async fn acquire_key<S>(&self, selector: S) -> Result<Self::Key, Self::Error> async fn acquire_key<S>(&self, selector: S) -> Result<Self::Key, Self::Error>
where where
@ -183,13 +206,20 @@ pub trait KeyPoolStorage {
S: IntoSelector<Self::Key, Self::Domain>; S: IntoSelector<Self::Key, Self::Domain>;
} }
#[derive(Debug, Default)]
struct PoolOptions {
comment: Option<String>,
hooks_before: std::collections::HashMap<std::any::TypeId, Box<dyn std::any::Any + Send + Sync>>,
hooks_after: std::collections::HashMap<std::any::TypeId, Box<dyn std::any::Any + Send + Sync>>,
}
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct KeyPoolExecutor<'a, C, S> pub struct KeyPoolExecutor<'a, C, S>
where where
S: KeyPoolStorage, S: KeyPoolStorage,
{ {
storage: &'a S, storage: &'a S,
comment: Option<&'a str>, options: Arc<PoolOptions>,
selector: KeySelector<S::Key, S::Domain>, selector: KeySelector<S::Key, S::Domain>,
_marker: std::marker::PhantomData<C>, _marker: std::marker::PhantomData<C>,
} }
@ -198,15 +228,15 @@ impl<'a, C, S> KeyPoolExecutor<'a, C, S>
where where
S: KeyPoolStorage, S: KeyPoolStorage,
{ {
pub fn new( fn new(
storage: &'a S, storage: &'a S,
selector: KeySelector<S::Key, S::Domain>, selector: KeySelector<S::Key, S::Domain>,
comment: Option<&'a str>, options: Arc<PoolOptions>,
) -> Self { ) -> Self {
Self { Self {
storage, storage,
selector, selector,
comment, options,
_marker: std::marker::PhantomData, _marker: std::marker::PhantomData,
} }
} }

View file

@ -1,3 +1,5 @@
use std::sync::Arc;
use async_trait::async_trait; use async_trait::async_trait;
use indoc::indoc; use indoc::indoc;
use sqlx::{FromRow, PgPool, Postgres, QueryBuilder}; use sqlx::{FromRow, PgPool, Postgres, QueryBuilder};
@ -15,13 +17,13 @@ impl<T> PgKeyDomain for T where
{ {
} }
#[derive(Debug, Error)] #[derive(Debug, Error, Clone)]
pub enum PgStorageError<D> pub enum PgStorageError<D>
where where
D: PgKeyDomain, D: PgKeyDomain,
{ {
#[error(transparent)] #[error(transparent)]
Pg(#[from] sqlx::Error), Pg(Arc<sqlx::Error>),
#[error("No key avalaible for domain {0:?}")] #[error("No key avalaible for domain {0:?}")]
Unavailable(KeySelector<PgKey<D>, D>), Unavailable(KeySelector<PgKey<D>, D>),
@ -30,6 +32,15 @@ where
KeyNotFound(KeySelector<PgKey<D>, D>), KeyNotFound(KeySelector<PgKey<D>, D>),
} }
impl<D> From<sqlx::Error> for PgStorageError<D>
where
D: PgKeyDomain,
{
fn from(value: sqlx::Error) -> Self {
Self::Pg(Arc::new(value))
}
}
#[derive(Debug, Clone, FromRow)] #[derive(Debug, Clone, FromRow)]
pub struct PgKey<D> pub struct PgKey<D>
where where
@ -53,9 +64,9 @@ fn build_predicate<'b, D>(
KeySelector::Id(id) => builder.push("id=").push_bind(id), KeySelector::Id(id) => builder.push("id=").push_bind(id),
KeySelector::UserId(user_id) => builder.push("user_id=").push_bind(user_id), KeySelector::UserId(user_id) => builder.push("user_id=").push_bind(user_id),
KeySelector::Key(key) => builder.push("key=").push_bind(key), KeySelector::Key(key) => builder.push("key=").push_bind(key),
KeySelector::Has(domain) => builder KeySelector::Has(domains) => builder
.push("domains @> ") .push("domains @> ")
.push_bind(sqlx::types::Json(vec![domain])), .push_bind(sqlx::types::Json(domains)),
KeySelector::OneOf(domains) => { KeySelector::OneOf(domains) => {
if domains.is_empty() { if domains.is_empty() {
builder.push("false"); builder.push("false");
@ -607,15 +618,12 @@ where
#[cfg(test)] #[cfg(test)]
pub(crate) mod test { pub(crate) mod test {
use std::sync::{Arc, Once}; use std::sync::Arc;
use sqlx::Row; use sqlx::Row;
use tokio::test;
use super::*; use super::*;
static INIT: Once = Once::new();
#[derive(Debug, PartialEq, Eq, Clone, serde::Serialize, serde::Deserialize)] #[derive(Debug, PartialEq, Eq, Clone, serde::Serialize, serde::Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")] #[serde(tag = "type", rename_all = "snake_case")]
pub(crate) enum Domain { pub(crate) enum Domain {
@ -634,15 +642,7 @@ pub(crate) mod test {
} }
} }
pub(crate) async fn setup() -> (PgKeyPoolStorage<Domain>, PgKey<Domain>) { pub(crate) async fn setup(pool: PgPool) -> (PgKeyPoolStorage<Domain>, PgKey<Domain>) {
INIT.call_once(|| {
dotenv::dotenv().ok();
});
let pool = PgPool::connect(&std::env::var("DATABASE_URL").unwrap())
.await
.unwrap();
sqlx::query("DROP TABLE IF EXISTS api_keys") sqlx::query("DROP TABLE IF EXISTS api_keys")
.execute(&pool) .execute(&pool)
.await .await
@ -659,18 +659,18 @@ pub(crate) mod test {
(storage, key) (storage, key)
} }
#[test] #[sqlx::test]
async fn test_initialise() { async fn test_initialise(pool: PgPool) {
let (storage, _) = setup().await; let (storage, _) = setup(pool).await;
if let Err(e) = storage.initialise().await { if let Err(e) = storage.initialise().await {
panic!("Initialising key storage failed: {:?}", e); panic!("Initialising key storage failed: {:?}", e);
} }
} }
#[test] #[sqlx::test]
async fn test_store_duplicate_key() { async fn test_store_duplicate_key(pool: PgPool) {
let (storage, key) = setup().await; let (storage, key) = setup(pool).await;
let key = storage let key = storage
.store_key(1, key.key, vec![Domain::User { id: 1 }]) .store_key(1, key.key, vec![Domain::User { id: 1 }])
.await .await
@ -679,9 +679,9 @@ pub(crate) mod test {
assert_eq!(key.domains.0.len(), 2); assert_eq!(key.domains.0.len(), 2);
} }
#[test] #[sqlx::test]
async fn test_store_duplicate_key_duplicate_domain() { async fn test_store_duplicate_key_duplicate_domain(pool: PgPool) {
let (storage, key) = setup().await; let (storage, key) = setup(pool).await;
let key = storage let key = storage
.store_key(1, key.key, vec![Domain::All]) .store_key(1, key.key, vec![Domain::All])
.await .await
@ -690,9 +690,9 @@ pub(crate) mod test {
assert_eq!(key.domains.0.len(), 1); assert_eq!(key.domains.0.len(), 1);
} }
#[test] #[sqlx::test]
async fn test_add_domain() { async fn test_add_domain(pool: PgPool) {
let (storage, key) = setup().await; let (storage, key) = setup(pool).await;
let key = storage let key = storage
.add_domain_to_key(KeySelector::Key(key.key), Domain::User { id: 12345 }) .add_domain_to_key(KeySelector::Key(key.key), Domain::User { id: 12345 })
.await .await
@ -701,9 +701,9 @@ pub(crate) mod test {
assert!(key.domains.0.contains(&Domain::User { id: 12345 })); assert!(key.domains.0.contains(&Domain::User { id: 12345 }));
} }
#[test] #[sqlx::test]
async fn test_add_domain_id() { async fn test_add_domain_id(pool: PgPool) {
let (storage, key) = setup().await; let (storage, key) = setup(pool).await;
let key = storage let key = storage
.add_domain_to_key(KeySelector::Id(key.id), Domain::User { id: 12345 }) .add_domain_to_key(KeySelector::Id(key.id), Domain::User { id: 12345 })
.await .await
@ -712,9 +712,9 @@ pub(crate) mod test {
assert!(key.domains.0.contains(&Domain::User { id: 12345 })); assert!(key.domains.0.contains(&Domain::User { id: 12345 }));
} }
#[test] #[sqlx::test]
async fn test_add_duplicate_domain() { async fn test_add_duplicate_domain(pool: PgPool) {
let (storage, key) = setup().await; let (storage, key) = setup(pool).await;
let key = storage let key = storage
.add_domain_to_key(KeySelector::Key(key.key), Domain::All) .add_domain_to_key(KeySelector::Key(key.key), Domain::All)
.await .await
@ -729,9 +729,9 @@ pub(crate) mod test {
); );
} }
#[test] #[sqlx::test]
async fn test_remove_domain() { async fn test_remove_domain(pool: PgPool) {
let (storage, key) = setup().await; let (storage, key) = setup(pool).await;
storage storage
.add_domain_to_key(KeySelector::Key(key.key.clone()), Domain::User { id: 1 }) .add_domain_to_key(KeySelector::Key(key.key.clone()), Domain::User { id: 1 })
.await .await
@ -744,9 +744,9 @@ pub(crate) mod test {
assert_eq!(key.domains.0, vec![Domain::All]); assert_eq!(key.domains.0, vec![Domain::All]);
} }
#[test] #[sqlx::test]
async fn test_remove_domain_id() { async fn test_remove_domain_id(pool: PgPool) {
let (storage, key) = setup().await; let (storage, key) = setup(pool).await;
storage storage
.add_domain_to_key(KeySelector::Id(key.id), Domain::User { id: 1 }) .add_domain_to_key(KeySelector::Id(key.id), Domain::User { id: 1 })
.await .await
@ -759,9 +759,9 @@ pub(crate) mod test {
assert_eq!(key.domains.0, vec![Domain::All]); assert_eq!(key.domains.0, vec![Domain::All]);
} }
#[test] #[sqlx::test]
async fn test_remove_last_domain() { async fn test_remove_last_domain(pool: PgPool) {
let (storage, key) = setup().await; let (storage, key) = setup(pool).await;
let key = storage let key = storage
.remove_domain_from_key(KeySelector::Key(key.key), Domain::All) .remove_domain_from_key(KeySelector::Key(key.key), Domain::All)
.await .await
@ -770,9 +770,9 @@ pub(crate) mod test {
assert!(key.domains.0.is_empty()); assert!(key.domains.0.is_empty());
} }
#[test] #[sqlx::test]
async fn test_store_key() { async fn test_store_key(pool: PgPool) {
let (storage, _) = setup().await; let (storage, _) = setup(pool).await;
let key = storage let key = storage
.store_key(1, "ABCDABCDABCDABCD".to_owned(), vec![]) .store_key(1, "ABCDABCDABCDABCD".to_owned(), vec![])
.await .await
@ -780,26 +780,26 @@ pub(crate) mod test {
assert_eq!(key.value(), "ABCDABCDABCDABCD"); assert_eq!(key.value(), "ABCDABCDABCDABCD");
} }
#[test] #[sqlx::test]
async fn test_read_user_keys() { async fn test_read_user_keys(pool: PgPool) {
let (storage, _) = setup().await; let (storage, _) = setup(pool).await;
let keys = storage.read_keys(KeySelector::UserId(1)).await.unwrap(); let keys = storage.read_keys(KeySelector::UserId(1)).await.unwrap();
assert_eq!(keys.len(), 1); assert_eq!(keys.len(), 1);
} }
#[test] #[sqlx::test]
async fn acquire_one() { async fn acquire_one(pool: PgPool) {
let (storage, _) = setup().await; let (storage, _) = setup(pool).await;
if let Err(e) = storage.acquire_key(Domain::All).await { if let Err(e) = storage.acquire_key(Domain::All).await {
panic!("Acquiring key failed: {:?}", e); panic!("Acquiring key failed: {:?}", e);
} }
} }
#[test] #[sqlx::test]
async fn uses_spread() { async fn uses_spread(pool: PgPool) {
let (storage, _) = setup().await; let (storage, _) = setup(pool).await;
storage storage
.store_key(1, "ABC".to_owned(), vec![Domain::All]) .store_key(1, "ABC".to_owned(), vec![Domain::All])
.await .await
@ -816,33 +816,37 @@ pub(crate) mod test {
} }
} }
#[test] #[sqlx::test]
async fn test_flag_key_one() { async fn test_flag_key_one(pool: PgPool) {
let (storage, key) = setup().await; let (storage, key) = setup(pool).await;
assert!(storage.flag_key(key, 2).await.unwrap()); assert!(storage.flag_key(key, 2).await.unwrap());
match storage.acquire_key(Domain::All).await.unwrap_err() { match storage.acquire_key(Domain::All).await.unwrap_err() {
PgStorageError::Unavailable(d) => assert!(matches!(d, KeySelector::Has(Domain::All))), PgStorageError::Unavailable(KeySelector::Has(domains)) => {
assert_eq!(domains, vec![Domain::All])
}
why => panic!("Expected domain unavailable error but found '{why}'"), why => panic!("Expected domain unavailable error but found '{why}'"),
} }
} }
#[test] #[sqlx::test]
async fn test_flag_key_many() { async fn test_flag_key_many(pool: PgPool) {
let (storage, key) = setup().await; let (storage, key) = setup(pool).await;
assert!(storage.flag_key(key, 2).await.unwrap()); assert!(storage.flag_key(key, 2).await.unwrap());
match storage.acquire_many_keys(Domain::All, 5).await.unwrap_err() { match storage.acquire_many_keys(Domain::All, 5).await.unwrap_err() {
PgStorageError::Unavailable(d) => assert!(matches!(d, KeySelector::Has(Domain::All))), PgStorageError::Unavailable(KeySelector::Has(domains)) => {
assert_eq!(domains, vec![Domain::All])
}
why => panic!("Expected domain unavailable error but found '{why}'"), why => panic!("Expected domain unavailable error but found '{why}'"),
} }
} }
#[test] #[sqlx::test]
async fn acquire_many() { async fn acquire_many(pool: PgPool) {
let (storage, _) = setup().await; let (storage, _) = setup(pool).await;
match storage.acquire_many_keys(Domain::All, 30).await { match storage.acquire_many_keys(Domain::All, 30).await {
Err(e) => panic!("Acquiring key failed: {:?}", e), Err(e) => panic!("Acquiring key failed: {:?}", e),
@ -851,9 +855,9 @@ pub(crate) mod test {
} }
// HACK: this test is time sensitive and will fail if runs at the top of the minute // HACK: this test is time sensitive and will fail if runs at the top of the minute
#[test] #[sqlx::test]
async fn test_concurrent() { async fn test_concurrent(pool: PgPool) {
let storage = Arc::new(setup().await.0); let storage = Arc::new(setup(pool).await.0);
for _ in 0..10 { for _ in 0..10 {
let mut set = tokio::task::JoinSet::new(); let mut set = tokio::task::JoinSet::new();
@ -884,9 +888,9 @@ pub(crate) mod test {
} }
} }
#[test] #[sqlx::test]
async fn test_concurrent_spread() { async fn test_concurrent_spread(pool: PgPool) {
let storage = Arc::new(setup().await.0); let storage = Arc::new(setup(pool).await.0);
for i in 0..24 { for i in 0..24 {
storage storage
@ -923,10 +927,11 @@ pub(crate) mod test {
.unwrap(); .unwrap();
} }
} }
// HACK: this test is time sensitive and will fail if runs at the top of the minute // HACK: this test is time sensitive and will fail if runs at the top of the minute
#[test] #[sqlx::test]
async fn test_concurrent_many() { async fn test_concurrent_many(pool: PgPool) {
let storage = Arc::new(setup().await.0); let storage = Arc::new(setup(pool).await.0);
for _ in 0..10 { for _ in 0..10 {
let mut set = tokio::task::JoinSet::new(); let mut set = tokio::task::JoinSet::new();
@ -956,73 +961,73 @@ pub(crate) mod test {
} }
} }
#[test] #[sqlx::test]
async fn read_key() { async fn read_key(pool: PgPool) {
let (storage, key) = setup().await; let (storage, key) = setup(pool).await;
let key = storage.read_key(KeySelector::Key(key.key)).await.unwrap(); let key = storage.read_key(KeySelector::Key(key.key)).await.unwrap();
assert!(key.is_some()); assert!(key.is_some());
} }
#[test] #[sqlx::test]
async fn read_key_id() { async fn read_key_id(pool: PgPool) {
let (storage, key) = setup().await; let (storage, key) = setup(pool).await;
let key = storage.read_key(KeySelector::Id(key.id)).await.unwrap(); let key = storage.read_key(KeySelector::Id(key.id)).await.unwrap();
assert!(key.is_some()); assert!(key.is_some());
} }
#[test] #[sqlx::test]
async fn read_nonexistent_key() { async fn read_nonexistent_key(pool: PgPool) {
let (storage, _) = setup().await; let (storage, _) = setup(pool).await;
let key = storage.read_key(KeySelector::Id(-1)).await.unwrap(); let key = storage.read_key(KeySelector::Id(-1)).await.unwrap();
assert!(key.is_none()); assert!(key.is_none());
} }
#[test] #[sqlx::test]
async fn query_key() { async fn query_key(pool: PgPool) {
let (storage, _) = setup().await; let (storage, _) = setup(pool).await;
let key = storage.read_key(Domain::All).await.unwrap(); let key = storage.read_key(Domain::All).await.unwrap();
assert!(key.is_some()); assert!(key.is_some());
} }
#[test] #[sqlx::test]
async fn query_nonexistent_key() { async fn query_nonexistent_key(pool: PgPool) {
let (storage, _) = setup().await; let (storage, _) = setup(pool).await;
let key = storage.read_key(Domain::Guild { id: 0 }).await.unwrap(); let key = storage.read_key(Domain::Guild { id: 0 }).await.unwrap();
assert!(key.is_none()); assert!(key.is_none());
} }
#[test] #[sqlx::test]
async fn query_all() { async fn query_all(pool: PgPool) {
let (storage, _) = setup().await; let (storage, _) = setup(pool).await;
let keys = storage.read_keys(Domain::All).await.unwrap(); let keys = storage.read_keys(Domain::All).await.unwrap();
assert!(keys.len() == 1); assert!(keys.len() == 1);
} }
#[test] #[sqlx::test]
async fn query_by_id() { async fn query_by_id(pool: PgPool) {
let (storage, _) = setup().await; let (storage, _) = setup(pool).await;
let key = storage.read_key(KeySelector::Id(1)).await.unwrap(); let key = storage.read_key(KeySelector::Id(1)).await.unwrap();
assert!(key.is_some()); assert!(key.is_some());
} }
#[test] #[sqlx::test]
async fn query_by_key() { async fn query_by_key(pool: PgPool) {
let (storage, key) = setup().await; let (storage, key) = setup(pool).await;
let key = storage.read_key(KeySelector::Key(key.key)).await.unwrap(); let key = storage.read_key(KeySelector::Key(key.key)).await.unwrap();
assert!(key.is_some()); assert!(key.is_some());
} }
#[test] #[sqlx::test]
async fn query_by_set() { async fn query_by_set(pool: PgPool) {
let (storage, _key) = setup().await; let (storage, _key) = setup(pool).await;
let key = storage let key = storage
.read_key(KeySelector::OneOf(vec![ .read_key(KeySelector::OneOf(vec![
Domain::All, Domain::All,
@ -1034,4 +1039,45 @@ pub(crate) mod test {
assert!(key.is_some()); assert!(key.is_some());
} }
#[sqlx::test]
async fn all_selector(pool: PgPool) {
let (storage, key) = setup(pool).await;
storage
.add_domain_to_key(key.selector(), Domain::Faction { id: 1 })
.await
.unwrap();
let key = storage
.read_key(KeySelector::Has(vec![
Domain::Faction { id: 1 },
Domain::All,
]))
.await
.unwrap();
assert!(key.is_some());
let key = storage
.read_key(KeySelector::Has(vec![
Domain::All,
Domain::Faction { id: 1 },
]))
.await
.unwrap();
assert!(key.is_some());
let key = storage
.read_key(KeySelector::Has(vec![
Domain::All,
Domain::Faction { id: 2 },
Domain::Faction { id: 1 },
]))
.await
.unwrap();
assert!(key.is_none());
}
} }

View file

@ -7,7 +7,9 @@ use torn_api::{
ApiRequest, ApiResponse, ApiSelection, ResponseError, ApiRequest, ApiResponse, ApiSelection, ResponseError,
}; };
use crate::{ApiKey, IntoSelector, KeyPoolError, KeyPoolExecutor, KeyPoolStorage}; use crate::{
ApiKey, IntoSelector, KeyAction, KeyPoolError, KeyPoolExecutor, KeyPoolStorage, PoolOptions,
};
#[async_trait] #[async_trait]
impl<'client, C, S> RequestExecutor<C> for KeyPoolExecutor<'client, C, S> impl<'client, C, S> RequestExecutor<C> for KeyPoolExecutor<'client, C, S>
@ -22,17 +24,22 @@ where
client: &C, client: &C,
mut request: ApiRequest<A>, mut request: ApiRequest<A>,
id: Option<String>, id: Option<String>,
) -> Result<ApiResponse, Self::Error> ) -> Result<A::Response, Self::Error>
where where
A: ApiSelection, A: ApiSelection,
{ {
request.comment = self.comment.map(ToOwned::to_owned); request.comment = self.options.comment.clone();
if let Some(hook) = self.options.hooks_before.get(&std::any::TypeId::of::<A>()) {
let concrete = hook.downcast_ref::<BeforeHook<A>>().unwrap();
(concrete.body)(&mut request);
}
loop { loop {
let key = self let key = self
.storage .storage
.acquire_key(self.selector.clone()) .acquire_key(self.selector.clone())
.await .await
.map_err(|e| KeyPoolError::Storage(Arc::new(e)))?; .map_err(KeyPoolError::Storage)?;
let url = request.url(key.value(), id.as_deref()); let url = request.url(key.value(), id.as_deref());
let value = client.request(url).await?; let value = client.request(url).await?;
@ -42,14 +49,37 @@ where
.storage .storage
.flag_key(key, code) .flag_key(key, code)
.await .await
.map_err(Arc::new)
.map_err(KeyPoolError::Storage)? .map_err(KeyPoolError::Storage)?
{ {
return Err(KeyPoolError::Response(ResponseError::Api { code, reason })); return Err(KeyPoolError::Response(ResponseError::Api { code, reason }));
} }
} }
Err(parsing_error) => return Err(KeyPoolError::Response(parsing_error)), Err(parsing_error) => return Err(KeyPoolError::Response(parsing_error)),
Ok(res) => return Ok(res), Ok(res) => {
let res = res.into();
if let Some(hook) = self.options.hooks_after.get(&std::any::TypeId::of::<A>()) {
let concrete = hook.downcast_ref::<AfterHook<A, S::Domain>>().unwrap();
match (concrete.body)(&res) {
Err(KeyAction::Delete) => {
self.storage
.remove_key(key.selector())
.await
.map_err(KeyPoolError::Storage)?;
continue;
}
Err(KeyAction::RemoveDomain(domain)) => {
self.storage
.remove_domain_from_key(key.selector(), domain)
.await
.map_err(KeyPoolError::Storage)?;
continue;
}
_ => (),
};
}
return Ok(res);
}
}; };
} }
} }
@ -59,7 +89,7 @@ where
client: &C, client: &C,
mut request: ApiRequest<A>, mut request: ApiRequest<A>,
ids: Vec<I>, ids: Vec<I>,
) -> HashMap<I, Result<ApiResponse, Self::Error>> ) -> HashMap<I, Result<A::Response, Self::Error>>
where where
A: ApiSelection, A: ApiSelection,
I: ToString + std::hash::Hash + std::cmp::Eq + Send + Sync, I: ToString + std::hash::Hash + std::cmp::Eq + Send + Sync,
@ -71,15 +101,14 @@ where
{ {
Ok(keys) => keys, Ok(keys) => keys,
Err(why) => { Err(why) => {
let shared = Arc::new(why);
return ids return ids
.into_iter() .into_iter()
.map(|i| (i, Err(Self::Error::Storage(shared.clone())))) .map(|i| (i, Err(Self::Error::Storage(why.clone()))))
.collect(); .collect();
} }
}; };
request.comment = self.comment.map(ToOwned::to_owned); request.comment = self.options.comment.clone();
let request_ref = &request; let request_ref = &request;
let tuples = let tuples =
@ -105,18 +134,18 @@ where
) )
} }
Ok(true) => (), Ok(true) => (),
Err(why) => return (id, Err(KeyPoolError::Storage(Arc::new(why)))), Err(why) => return (id, Err(KeyPoolError::Storage(why))),
} }
} }
Err(parsing_error) => { Err(parsing_error) => {
return (id, Err(KeyPoolError::Response(parsing_error))) return (id, Err(KeyPoolError::Response(parsing_error)))
} }
Ok(res) => return (id, Ok(res)), Ok(res) => return (id, Ok(res.into())),
}; };
key = match self.storage.acquire_key(self.selector.clone()).await { key = match self.storage.acquire_key(self.selector.clone()).await {
Ok(k) => k, Ok(k) => k,
Err(why) => return (id, Err(Self::Error::Storage(Arc::new(why)))), Err(why) => return (id, Err(Self::Error::Storage(why))),
}; };
} }
})) }))
@ -126,6 +155,92 @@ where
} }
} }
#[allow(clippy::type_complexity)]
pub struct BeforeHook<A>
where
A: ApiSelection,
{
body: Box<dyn Fn(&mut ApiRequest<A>) + Send + Sync + 'static>,
}
#[allow(clippy::type_complexity)]
pub struct AfterHook<A, D>
where
A: ApiSelection,
D: crate::KeyDomain,
{
body: Box<dyn Fn(&A::Response) -> Result<(), crate::KeyAction<D>> + Send + Sync + 'static>,
}
pub struct PoolBuilder<C, S>
where
C: ApiClient,
S: KeyPoolStorage,
{
client: C,
storage: S,
options: crate::PoolOptions,
}
impl<C, S> PoolBuilder<C, S>
where
C: ApiClient,
S: KeyPoolStorage,
{
pub fn new(client: C, storage: S) -> Self {
Self {
client,
storage,
options: Default::default(),
}
}
pub fn comment(mut self, c: impl ToString) -> Self {
self.options.comment = Some(c.to_string());
self
}
pub fn hook_before<A>(
mut self,
hook: impl Fn(&mut ApiRequest<A>) + Send + Sync + 'static,
) -> Self
where
A: ApiSelection + 'static,
{
self.options.hooks_before.insert(
std::any::TypeId::of::<A>(),
Box::new(BeforeHook {
body: Box::new(hook),
}),
);
self
}
pub fn hook_after<A>(
mut self,
hook: impl Fn(&A::Response) -> Result<(), KeyAction<S::Domain>> + Send + Sync + 'static,
) -> Self
where
A: ApiSelection + 'static,
{
self.options.hooks_after.insert(
std::any::TypeId::of::<A>(),
Box::new(AfterHook::<A, S::Domain> {
body: Box::new(hook),
}),
);
self
}
pub fn build(self) -> KeyPool<C, S> {
KeyPool {
client: self.client,
storage: self.storage,
options: Arc::new(self.options),
}
}
}
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct KeyPool<C, S> pub struct KeyPool<C, S>
where where
@ -134,7 +249,7 @@ where
{ {
client: C, client: C,
pub storage: S, pub storage: S,
comment: Option<String>, options: Arc<PoolOptions>,
} }
impl<C, S> KeyPool<C, S> impl<C, S> KeyPool<C, S>
@ -142,14 +257,6 @@ where
C: ApiClient, C: ApiClient,
S: KeyPoolStorage + Send + Sync + 'static, S: KeyPoolStorage + Send + Sync + 'static,
{ {
pub fn new(client: C, storage: S, comment: Option<String>) -> Self {
Self {
client,
storage,
comment,
}
}
pub fn torn_api<I>(&self, selector: I) -> ApiProvider<C, KeyPoolExecutor<C, S>> pub fn torn_api<I>(&self, selector: I) -> ApiProvider<C, KeyPoolExecutor<C, S>>
where where
I: IntoSelector<S::Key, S::Domain>, I: IntoSelector<S::Key, S::Domain>,
@ -159,7 +266,7 @@ where
KeyPoolExecutor::new( KeyPoolExecutor::new(
&self.storage, &self.storage,
selector.into_selector(), selector.into_selector(),
self.comment.as_deref(), self.options.clone(),
), ),
) )
} }
@ -178,7 +285,7 @@ pub trait WithStorage {
{ {
ApiProvider::new( ApiProvider::new(
self, self,
KeyPoolExecutor::new(storage, selector.into_selector(), None), KeyPoolExecutor::new(storage, selector.into_selector(), Default::default()),
) )
} }
} }
@ -188,27 +295,28 @@ impl WithStorage for reqwest::Client {}
#[cfg(all(test, feature = "postgres", feature = "reqwest"))] #[cfg(all(test, feature = "postgres", feature = "reqwest"))]
mod test { mod test {
use tokio::test; use sqlx::PgPool;
use super::*; use super::*;
use crate::postgres::test::{setup, Domain}; use crate::{
postgres::test::{setup, Domain},
KeySelector,
};
#[test] #[sqlx::test]
async fn test_pool_request() { async fn test_pool_request(pool: PgPool) {
let (storage, _) = setup().await; let (storage, _) = setup(pool).await;
let pool = KeyPool::new( let pool = PoolBuilder::new(reqwest::Client::default(), storage)
reqwest::Client::default(), .comment("api.rs")
storage, .build();
Some("api.rs".to_owned()),
);
let response = pool.torn_api(Domain::All).user(|b| b).await.unwrap(); let response = pool.torn_api(Domain::All).user(|b| b).await.unwrap();
_ = response.profile().unwrap(); _ = response.profile().unwrap();
} }
#[test] #[sqlx::test]
async fn test_with_storage_request() { async fn test_with_storage_request(pool: PgPool) {
let (storage, _) = setup().await; let (storage, _) = setup(pool).await;
let response = reqwest::Client::new() let response = reqwest::Client::new()
.with_storage(&storage, Domain::All) .with_storage(&storage, Domain::All)
@ -217,4 +325,36 @@ mod test {
.unwrap(); .unwrap();
_ = response.profile().unwrap(); _ = response.profile().unwrap();
} }
#[sqlx::test]
async fn before_hook(pool: PgPool) {
let (storage, _) = setup(pool).await;
let pool = PoolBuilder::new(reqwest::Client::default(), storage)
.hook_before::<torn_api::user::UserSelection>(|req| {
req.selections.push("crimes");
})
.build();
let response = pool.torn_api(Domain::All).user(|b| b).await.unwrap();
_ = response.crimes().unwrap();
}
#[sqlx::test]
async fn after_hook(pool: PgPool) {
let (storage, _) = setup(pool).await;
let pool = PoolBuilder::new(reqwest::Client::default(), storage)
.hook_after::<torn_api::user::UserSelection>(|_res| Err(KeyAction::Delete))
.build();
let key = pool.storage.read_key(KeySelector::Id(1)).await.unwrap();
assert!(key.is_some());
let response = pool.torn_api(Domain::All).user(|b| b).await;
assert!(matches!(response, Err(KeyPoolError::Storage(_))));
let key = pool.storage.read_key(KeySelector::Id(1)).await.unwrap();
assert!(key.is_none());
}
} }