refactor(key-pool): change error handler signature

This commit is contained in:
pyrite 2025-08-10 18:53:57 +02:00
parent 8a8b34506a
commit 44c5df9a7f
Signed by: pyrite
GPG key ID: 7F1BA9170CD35D15
4 changed files with 27 additions and 21 deletions

2
Cargo.lock generated
View file

@ -2329,7 +2329,7 @@ dependencies = [
[[package]] [[package]]
name = "torn-key-pool" name = "torn-key-pool"
version = "1.1.3" version = "1.2.0"
dependencies = [ dependencies = [
"chrono", "chrono",
"futures", "futures",

View file

@ -1,6 +1,6 @@
[package] [package]
name = "torn-key-pool" name = "torn-key-pool"
version = "1.1.3" version = "1.2.0"
edition = "2021" edition = "2021"
authors = ["Pyrit [2111649]"] authors = ["Pyrit [2111649]"]
license = { workspace = true } license = { workspace = true }

View file

@ -245,7 +245,11 @@ where
error_hooks: HashMap< error_hooks: HashMap<
u16, u16,
Box< Box<
dyn for<'a> Fn(&'a S, &'a S::Key) -> BoxFuture<'a, Result<bool, S::Error>> dyn for<'a> Fn(
&'a S,
&'a S::Key,
&'a ApiRequest,
) -> BoxFuture<'a, Result<bool, S::Error>>
+ Send + Send
+ Sync, + Sync,
>, >,
@ -287,27 +291,29 @@ where
self self
} }
pub fn error_hook<F>(mut self, code: u16, handler: F) -> Self pub fn error_hook<F>(mut self, error: ApiError, handler: F) -> Self
where where
F: for<'a> Fn(&'a S, &'a S::Key) -> BoxFuture<'a, Result<bool, S::Error>> F: for<'a> Fn(&'a S, &'a S::Key, &'a ApiRequest) -> BoxFuture<'a, Result<bool, S::Error>>
+ Send + Send
+ Sync + Sync
+ 'static, + 'static,
{ {
self.options.error_hooks.insert(code, Box::new(handler)); self.options
.error_hooks
.insert(error.code(), Box::new(handler));
self self
} }
pub fn use_default_hooks(self) -> Self { pub fn use_default_hooks(self) -> Self {
self.error_hook(2, |storage, key| { self.error_hook(ApiError::IncorrectKey, |storage, key, _| {
async move { async move {
storage.remove_key(KeySelector::Id(key.id())).await?; storage.remove_key(KeySelector::Id(key.id())).await?;
Ok(true) Ok(true)
} }
.boxed() .boxed()
}) })
.error_hook(5, |storage, key| { .error_hook(ApiError::TooManyRequest, |storage, key, _| {
async move { async move {
storage storage
.timeout_key(KeySelector::Id(key.id()), Duration::from_secs(60)) .timeout_key(KeySelector::Id(key.id()), Duration::from_secs(60))
@ -316,14 +322,14 @@ where
} }
.boxed() .boxed()
}) })
.error_hook(10, |storage, key| { .error_hook(ApiError::KeyOwnerInFederalJail, |storage, key, _| {
async move { async move {
storage.remove_key(KeySelector::Id(key.id())).await?; storage.remove_key(KeySelector::Id(key.id())).await?;
Ok(true) Ok(true)
} }
.boxed() .boxed()
}) })
.error_hook(13, |storage, key| { .error_hook(ApiError::TemporaryInactivity, |storage, key, _| {
async move { async move {
storage storage
.timeout_key(KeySelector::Id(key.id()), Duration::from_secs(24 * 3_600)) .timeout_key(KeySelector::Id(key.id()), Duration::from_secs(24 * 3_600))
@ -332,7 +338,7 @@ where
} }
.boxed() .boxed()
}) })
.error_hook(18, |storage, key| { .error_hook(ApiError::Paused, |storage, key, _| {
async move { async move {
storage storage
.timeout_key(KeySelector::Id(key.id()), Duration::from_secs(24 * 3_600)) .timeout_key(KeySelector::Id(key.id()), Duration::from_secs(24 * 3_600))
@ -391,7 +397,7 @@ where
if let Some(err) = decode_error(&bytes)? { if let Some(err) = decode_error(&bytes)? {
if let Some(handler) = self.options.error_hooks.get(&err.code()) { if let Some(handler) = self.options.error_hooks.get(&err.code()) {
let retry = (*handler)(&self.storage, key).await?; let retry = (*handler)(&self.storage, key, request).await?;
if retry { if retry {
return Ok(RequestResult::Retry); return Ok(RequestResult::Retry);
@ -492,7 +498,7 @@ impl<S> KeyPool<S>
where where
S: KeyPoolStorage + Send + Sync + 'static, S: KeyPoolStorage + Send + Sync + 'static,
{ {
pub fn torn_api<I>(&self, selector: I) -> KeyPoolExecutor<S> pub fn torn_api<I>(&self, selector: I) -> KeyPoolExecutor<'_, S>
where where
I: IntoSelector<S::Key, S::Domain>, I: IntoSelector<S::Key, S::Domain>,
{ {
@ -503,7 +509,7 @@ where
&self, &self,
selector: I, selector: I,
distance: Duration, distance: Duration,
) -> ThrottledKeyPoolExecutor<S> ) -> ThrottledKeyPoolExecutor<'_, S>
where where
I: IntoSelector<S::Key, S::Domain>, I: IntoSelector<S::Key, S::Domain>,
{ {

View file

@ -155,7 +155,7 @@ where
pub async fn initialise(&self) -> Result<(), PgKeyPoolError<D>> { pub async fn initialise(&self) -> Result<(), PgKeyPoolError<D>> {
if let Some(schema) = self.schema.as_ref() { if let Some(schema) = self.schema.as_ref() {
sqlx::query(&format!("create schema if not exists {}", schema)) sqlx::query(&format!("create schema if not exists {schema}"))
.execute(&self.pool) .execute(&self.pool)
.await?; .await?;
} }
@ -306,7 +306,7 @@ where
fn recurse<D>( fn recurse<D>(
storage: &PgKeyPoolStorage<D>, storage: &PgKeyPoolStorage<D>,
selector: KeySelector<PgKey<D>, D>, selector: KeySelector<PgKey<D>, D>,
) -> BoxFuture<Result<PgKey<D>, PgKeyPoolError<D>>> ) -> BoxFuture<'_, Result<PgKey<D>, PgKeyPoolError<D>>>
where where
D: PgKeyDomain, D: PgKeyDomain,
{ {
@ -445,7 +445,7 @@ where
storage: &PgKeyPoolStorage<D>, storage: &PgKeyPoolStorage<D>,
selector: KeySelector<PgKey<D>, D>, selector: KeySelector<PgKey<D>, D>,
number: i64, number: i64,
) -> BoxFuture<Result<Vec<PgKey<D>>, PgKeyPoolError<D>>> ) -> BoxFuture<'_, Result<Vec<PgKey<D>>, PgKeyPoolError<D>>>
where where
D: PgKeyDomain, D: PgKeyDomain,
{ {
@ -686,7 +686,7 @@ pub(crate) mod test {
let (storage, _) = setup(pool).await; let (storage, _) = setup(pool).await;
if let Err(e) = storage.initialise().await { if let Err(e) = storage.initialise().await {
panic!("Initialising key storage failed: {:?}", e); panic!("Initialising key storage failed: {e:?}");
} }
} }
@ -815,7 +815,7 @@ pub(crate) mod test {
let (storage, _) = setup(pool).await; let (storage, _) = setup(pool).await;
if let Err(e) = storage.acquire_key(Domain::All).await { if let Err(e) = storage.acquire_key(Domain::All).await {
panic!("Acquiring key failed: {:?}", e); panic!("Acquiring key failed: {e:?}");
} }
} }
@ -843,7 +843,7 @@ pub(crate) mod test {
let (storage, _) = setup(pool).await; let (storage, _) = setup(pool).await;
match storage.acquire_many_keys(Domain::All, 30).await { match storage.acquire_many_keys(Domain::All, 30).await {
Err(e) => panic!("Acquiring key failed: {:?}", e), Err(e) => panic!("Acquiring key failed: {e:?}"),
Ok(keys) => assert_eq!(keys.len(), 30), Ok(keys) => assert_eq!(keys.len(), 30),
} }
} }
@ -888,7 +888,7 @@ pub(crate) mod test {
for i in 0..24 { for i in 0..24 {
storage storage
.store_key(1, format!("{}", i), vec![Domain::All]) .store_key(1, format!("{i}"), vec![Domain::All])
.await .await
.unwrap(); .unwrap();
} }