changed before/after hook signatures

This commit is contained in:
TotallyNot 2024-04-04 16:19:44 +02:00
parent 8aaf61efb1
commit 1daee85581
3 changed files with 33 additions and 17 deletions

View file

@ -1,6 +1,6 @@
[package] [package]
name = "torn-key-pool" name = "torn-key-pool"
version = "0.8.0" version = "0.8.1"
edition = "2021" edition = "2021"
authors = ["Pyrit [2111649]"] authors = ["Pyrit [2111649]"]
license = "MIT" license = "MIT"

View file

@ -43,7 +43,7 @@ where
} }
} }
pub trait ApiKey: Sync + Send + std::fmt::Debug + Clone { pub trait ApiKey: Sync + Send + std::fmt::Debug + Clone + 'static {
type IdType: PartialEq + Eq + std::hash::Hash + Send + Sync + std::fmt::Debug + Clone; type IdType: PartialEq + Eq + std::hash::Hash + Send + Sync + std::fmt::Debug + Clone;
fn value(&self) -> &str; fn value(&self) -> &str;

View file

@ -8,7 +8,8 @@ use torn_api::{
}; };
use crate::{ use crate::{
ApiKey, IntoSelector, KeyAction, KeyPoolError, KeyPoolExecutor, KeyPoolStorage, PoolOptions, ApiKey, IntoSelector, KeyAction, KeyDomain, KeyPoolError, KeyPoolExecutor, KeyPoolStorage,
KeySelector, PoolOptions,
}; };
#[async_trait] #[async_trait]
@ -30,9 +31,11 @@ where
{ {
request.comment = self.options.comment.clone(); request.comment = self.options.comment.clone();
if let Some(hook) = self.options.hooks_before.get(&std::any::TypeId::of::<A>()) { if let Some(hook) = self.options.hooks_before.get(&std::any::TypeId::of::<A>()) {
let concrete = hook.downcast_ref::<BeforeHook<A>>().unwrap(); let concrete = hook
.downcast_ref::<BeforeHook<A, S::Key, S::Domain>>()
.unwrap();
(concrete.body)(&mut request); (concrete.body)(&mut request, &self.selector);
} }
loop { loop {
let key = self let key = self
@ -58,9 +61,11 @@ where
Ok(res) => { Ok(res) => {
let res = res.into(); let res = res.into();
if let Some(hook) = self.options.hooks_after.get(&std::any::TypeId::of::<A>()) { if let Some(hook) = self.options.hooks_after.get(&std::any::TypeId::of::<A>()) {
let concrete = hook.downcast_ref::<AfterHook<A, S::Domain>>().unwrap(); let concrete = hook
.downcast_ref::<AfterHook<A, S::Key, S::Domain>>()
.unwrap();
match (concrete.body)(&res) { match (concrete.body)(&res, &self.selector) {
Err(KeyAction::Delete) => { Err(KeyAction::Delete) => {
self.storage self.storage
.remove_key(key.selector()) .remove_key(key.selector())
@ -156,20 +161,28 @@ where
} }
#[allow(clippy::type_complexity)] #[allow(clippy::type_complexity)]
pub struct BeforeHook<A> pub struct BeforeHook<A, K, D>
where where
A: ApiSelection, A: ApiSelection,
K: ApiKey,
D: KeyDomain,
{ {
body: Box<dyn Fn(&mut ApiRequest<A>) + Send + Sync + 'static>, body: Box<dyn Fn(&mut ApiRequest<A>, &KeySelector<K, D>) + Send + Sync + 'static>,
} }
#[allow(clippy::type_complexity)] #[allow(clippy::type_complexity)]
pub struct AfterHook<A, D> pub struct AfterHook<A, K, D>
where where
A: ApiSelection, A: ApiSelection,
D: crate::KeyDomain, K: ApiKey,
D: KeyDomain,
{ {
body: Box<dyn Fn(&A::Response) -> Result<(), crate::KeyAction<D>> + Send + Sync + 'static>, body: Box<
dyn Fn(&A::Response, &KeySelector<K, D>) -> Result<(), crate::KeyAction<D>>
+ Send
+ Sync
+ 'static,
>,
} }
pub struct PoolBuilder<C, S> pub struct PoolBuilder<C, S>
@ -202,7 +215,7 @@ where
pub fn hook_before<A>( pub fn hook_before<A>(
mut self, mut self,
hook: impl Fn(&mut ApiRequest<A>) + Send + Sync + 'static, hook: impl Fn(&mut ApiRequest<A>, &KeySelector<S::Key, S::Domain>) + Send + Sync + 'static,
) -> Self ) -> Self
where where
A: ApiSelection + 'static, A: ApiSelection + 'static,
@ -218,14 +231,17 @@ where
pub fn hook_after<A>( pub fn hook_after<A>(
mut self, mut self,
hook: impl Fn(&A::Response) -> Result<(), KeyAction<S::Domain>> + Send + Sync + 'static, hook: impl Fn(&A::Response, &KeySelector<S::Key, S::Domain>) -> Result<(), KeyAction<S::Domain>>
+ Send
+ Sync
+ 'static,
) -> Self ) -> Self
where where
A: ApiSelection + 'static, A: ApiSelection + 'static,
{ {
self.options.hooks_after.insert( self.options.hooks_after.insert(
std::any::TypeId::of::<A>(), std::any::TypeId::of::<A>(),
Box::new(AfterHook::<A, S::Domain> { Box::new(AfterHook::<A, S::Key, S::Domain> {
body: Box::new(hook), body: Box::new(hook),
}), }),
); );
@ -331,7 +347,7 @@ mod test {
let (storage, _) = setup(pool).await; let (storage, _) = setup(pool).await;
let pool = PoolBuilder::new(reqwest::Client::default(), storage) let pool = PoolBuilder::new(reqwest::Client::default(), storage)
.hook_before::<torn_api::user::UserSelection>(|req| { .hook_before::<torn_api::user::UserSelection>(|req, _s| {
req.selections.push("crimes"); req.selections.push("crimes");
}) })
.build(); .build();
@ -345,7 +361,7 @@ mod test {
let (storage, _) = setup(pool).await; let (storage, _) = setup(pool).await;
let pool = PoolBuilder::new(reqwest::Client::default(), storage) let pool = PoolBuilder::new(reqwest::Client::default(), storage)
.hook_after::<torn_api::user::UserSelection>(|_res| Err(KeyAction::Delete)) .hook_after::<torn_api::user::UserSelection>(|_res, _s| Err(KeyAction::Delete))
.build(); .build();
let key = pool.storage.read_key(KeySelector::Id(1)).await.unwrap(); let key = pool.storage.read_key(KeySelector::Id(1)).await.unwrap();