// Key-pool backed request execution for `torn_api`.
//
// This file contains:
// * the `RequestExecutor` implementation for `KeyPoolExecutor`, which picks API
//   keys from a `KeyPoolStorage` and retries with another key when the storage
//   asks for it,
// * the type-erased `BeforeHook` / `AfterHook` wrappers,
// * `PoolBuilder` / `KeyPool`, which bundle a client, a storage and shared options,
// * the `WithStorage` extension trait for using a client with ad-hoc storage.
//
// NOTE(review): several generic parameter/argument lists look like they are
// missing in this copy of the file (e.g. `TypeId::of::()`, `downcast_ref::>()`,
// bare `Option` / `Result` / `Vec`, `where A: ApiSelection` without an `A` being
// declared). Confirm against the repository before building from this text.

use std::{collections::HashMap, sync::Arc};

use async_trait::async_trait;

use torn_api::{
    send::{ApiClient, ApiProvider, RequestExecutor},
    ApiRequest, ApiResponse, ApiSelection, ResponseError,
};

use crate::{
    ApiKey, IntoSelector, KeyAction, KeyDomain, KeyPoolError, KeyPoolExecutor, KeyPoolStorage,
    KeySelector, PoolOptions,
};

/// Executes `torn_api` requests using keys drawn from the executor's storage.
///
/// Both methods fill in `request.comment` from `self.options.comment` when the
/// request does not carry one, and both ask the storage (via `flag_key`) whether
/// an API-level error should be retried with a different key.
#[async_trait]
impl<'client, C, S> RequestExecutor for KeyPoolExecutor<'client, C, S>
where
    C: ApiClient,
    S: KeyPoolStorage + Send + Sync + 'static,
{
    type Error = KeyPoolError;

    /// Sends a single request, optionally for `id`, and returns the parsed response.
    ///
    /// Flow:
    /// 1. Apply the default comment and, if one is registered for this selection
    ///    type, the before-hook (called once, before any key is acquired).
    /// 2. Loop: acquire a key, build the URL, send it through `client`, parse.
    /// 3. On an API error, `flag_key` decides: `false` returns the error, `true`
    ///    starts another iteration with a newly acquired key.
    /// 4. On success, the after-hook (if any) may request that the key be deleted
    ///    or lose a domain; in both cases the request is repeated.
    ///
    /// # Errors
    ///
    /// * `KeyPoolError::Storage` when any storage call fails.
    /// * `KeyPoolError::Response` for API errors that are not retried and for
    ///   other response errors.
    /// * Whatever `client.request` fails with, propagated through `?`.
    ///
    /// # Panics
    ///
    /// Panics if a stored hook cannot be downcast to the expected hook type
    /// (the `.unwrap()` calls below).
    ///
    /// NOTE(review): the retry loop has no attempt bound; it ends only when a
    /// call returns an error or a response is accepted.
    async fn execute(
        &self,
        client: &C,
        mut request: ApiRequest,
        id: Option,
    ) -> Result
    where
        A: ApiSelection,
    {
        // Only fill in the pool-wide comment when the caller did not set one.
        if request.comment.is_none() {
            request.comment = self.options.comment.clone();
        }
        // Hooks are stored type-erased and keyed by `TypeId`; recover the concrete
        // hook before calling it.
        if let Some(hook) = self.options.hooks_before.get(&std::any::TypeId::of::()) {
            let concrete = hook
                .downcast_ref::>()
                .unwrap();

            // The hook may mutate the request before it is sent.
            (concrete.body)(&mut request, &self.selector);
        }
        loop {
            // A key is acquired on every iteration, so each retry uses whatever
            // the storage hands out next.
            let key = self
                .storage
                .acquire_key(self.selector.clone())
                .await
                .map_err(KeyPoolError::Storage)?;
            let url = request.url(key.value(), id.as_deref());
            let value = client.request(url).await?;

            match ApiResponse::from_value(value) {
                Err(ResponseError::Api { code, reason }) => {
                    // `flag_key` returning `false` means: do not retry, surface
                    // the API error. `true` falls out of the match and loops.
                    if !self
                        .storage
                        .flag_key(key, code)
                        .await
                        .map_err(KeyPoolError::Storage)?
                    {
                        return Err(KeyPoolError::Response(ResponseError::Api { code, reason }));
                    }
                }
                // Any other response error is returned without consulting storage.
                Err(parsing_error) => return Err(KeyPoolError::Response(parsing_error)),
                Ok(res) => {
                    let res = res.into();
                    if let Some(hook) = self.options.hooks_after.get(&std::any::TypeId::of::()) {
                        let concrete = hook
                            .downcast_ref::>()
                            .unwrap();

                        // The after-hook inspects the response and may ask for the
                        // key that produced it to be changed.
                        match (concrete.body)(&res, &self.selector) {
                            Err(KeyAction::Delete) => {
                                // Remove the key, then repeat the request.
                                self.storage
                                    .remove_key(key.selector())
                                    .await
                                    .map_err(KeyPoolError::Storage)?;
                                continue;
                            }
                            Err(KeyAction::RemoveDomain(domain)) => {
                                // Detach the domain from the key, then repeat the request.
                                self.storage
                                    .remove_domain_from_key(key.selector(), domain)
                                    .await
                                    .map_err(KeyPoolError::Storage)?;
                                continue;
                            }
                            // Every other hook result accepts the response.
                            _ => (),
                        };
                    }
                    return Ok(res);
                }
            };
        }
    }

    /// Sends the same request once per entry in `ids` and returns one result per id.
    ///
    /// Keys for the whole batch are requested up front with `acquire_many_keys`;
    /// if that fails, every id is mapped to a clone of the storage error. The
    /// per-id requests are then driven together with `futures::future::join_all`.
    /// Each per-id future retries on its own: after an API error that `flag_key`
    /// marks as retryable it acquires a single replacement key and tries again.
    ///
    /// Failures never abort the batch; they are stored as the `Err` value of the
    /// affected id.
    ///
    /// NOTE(review): unlike `execute`, this method never looks at
    /// `self.options.hooks_before` / `hooks_after` — confirm that is intended.
    ///
    /// NOTE(review): ids and keys are paired with `std::iter::zip`, which stops
    /// at the shorter input. If storage could return fewer keys than `ids.len()`,
    /// the surplus ids would have no entry in the result — verify the
    /// `acquire_many_keys` contract.
    async fn execute_many(
        &self,
        client: &C,
        mut request: ApiRequest,
        ids: Vec,
    ) -> HashMap>
    where
        A: ApiSelection,
        I: ToString + std::hash::Hash + std::cmp::Eq + Send + Sync,
    {
        // The id count is passed to storage as an `i64` via an `as` cast.
        let keys = match self
            .storage
            .acquire_many_keys(self.selector.clone(), ids.len() as i64)
            .await
        {
            Ok(keys) => keys,
            Err(why) => {
                // No keys at all: report the same storage error for every id.
                return ids
                    .into_iter()
                    .map(|i| (i, Err(Self::Error::Storage(why.clone()))))
                    .collect();
            }
        };

        if request.comment.is_none() {
            request.comment = self.options.comment.clone();
        }
        // Shared borrow so every `async move` block below can use the same request.
        let request_ref = &request;

        let tuples =
            futures::future::join_all(std::iter::zip(ids, keys).map(|(id, mut key)| async move {
                let id_string = id.to_string();
                loop {
                    let url = request_ref.url(key.value(), Some(&id_string));
                    let value = match client.request(url).await {
                        Ok(v) => v,
                        Err(why) => return (id, Err(Self::Error::Client(why))),
                    };

                    match ApiResponse::from_value(value) {
                        Err(ResponseError::Api { code, reason }) => {
                            match self.storage.flag_key(key, code).await {
                                // Not retryable: hand the API error back for this id.
                                Ok(false) => {
                                    return (
                                        id,
                                        Err(KeyPoolError::Response(ResponseError::Api {
                                            code,
                                            reason,
                                        })),
                                    )
                                }
                                // Retryable: fall through to acquire a replacement key.
                                Ok(true) => (),
                                Err(why) => return (id, Err(KeyPoolError::Storage(why))),
                            }
                        }
                        Err(parsing_error) => {
                            return (id, Err(KeyPoolError::Response(parsing_error)))
                        }
                        Ok(res) => return (id, Ok(res.into())),
                    };

                    // Only reached after a retryable API error; `key` was moved
                    // into `flag_key`, so a new one is needed for the next attempt.
                    key = match self.storage.acquire_key(self.selector.clone()).await {
                        Ok(k) => k,
                        Err(why) => return (id, Err(Self::Error::Storage(why))),
                    };
                }
            }))
            .await;

        HashMap::from_iter(tuples)
    }
}

/// Wrapper around a closure that runs before a request is sent.
///
/// Instances are created by `PoolBuilder::hook_before`, stored boxed in the pool
/// options, and recovered with `downcast_ref` in `execute`, where `body` is
/// called with the mutable request and the executor's selector.
#[allow(clippy::type_complexity)]
pub struct BeforeHook
where
    A: ApiSelection,
    K: ApiKey,
    D: KeyDomain,
{
    /// The user-supplied closure.
    body: Box, &KeySelector) + Send + Sync + 'static>,
}

/// Wrapper around a closure that inspects a successfully parsed response.
///
/// Instances are created by `PoolBuilder::hook_after`. In `execute`, a returned
/// `Err(KeyAction::Delete)` or `Err(KeyAction::RemoveDomain(..))` makes the
/// executor update the storage and repeat the request; any other result lets the
/// response through.
#[allow(clippy::type_complexity)]
pub struct AfterHook
where
    A: ApiSelection,
    K: ApiKey,
    D: KeyDomain,
{
    /// The user-supplied closure.
    body: Box<
        dyn Fn(&A::Response, &KeySelector) -> Result<(), crate::KeyAction>
            + Send
            + Sync
            + 'static,
    >,
}

/// Builder that collects a client, a storage and pool options into a `KeyPool`.
pub struct PoolBuilder
where
    C: ApiClient,
    S: KeyPoolStorage,
{
    client: C,
    storage: S,
    /// Comment and hooks accumulated by the builder methods.
    options: crate::PoolOptions,
}

impl PoolBuilder
where
    C: ApiClient,
    S: KeyPoolStorage,
{
    /// Starts a builder with the given client and storage and default options.
    pub fn new(client: C, storage: S) -> Self {
        Self {
            client,
            storage,
            options: Default::default(),
        }
    }

    /// Sets the comment that the executor applies to requests which do not
    /// already carry one.
    pub fn comment(mut self, c: impl ToString) -> Self {
        self.options.comment = Some(c.to_string());
        self
    }

    /// Registers `hook` to run before requests of the given selection type.
    ///
    /// The hook is stored under that type's `TypeId`, which is the same key
    /// `execute` uses to look it up.
    pub fn hook_before(
        mut self,
        hook: impl Fn(&mut ApiRequest, &KeySelector) + Send + Sync + 'static,
    ) -> Self
    where
        A: ApiSelection + 'static,
    {
        self.options.hooks_before.insert(
            std::any::TypeId::of::(),
            Box::new(BeforeHook {
                body: Box::new(hook),
            }),
        );
        self
    }

    /// Registers `hook` to run on successfully parsed responses of the given
    /// selection type.
    ///
    /// Returning `Err(KeyAction::..)` from the hook lets it request a change to
    /// the key that was used; see `AfterHook`.
    pub fn hook_after(
        mut self,
        hook: impl Fn(&A::Response, &KeySelector) -> Result<(), KeyAction>
            + Send
            + Sync
            + 'static,
    ) -> Self
    where
        A: ApiSelection + 'static,
    {
        self.options.hooks_after.insert(
            std::any::TypeId::of::(),
            Box::new(AfterHook:: {
                body: Box::new(hook),
            }),
        );
        self
    }

    /// Finishes the builder; the options are moved into an `Arc` so executors
    /// can share them.
    pub fn build(self) -> KeyPool {
        KeyPool {
            client: self.client,
            storage: self.storage,
            options: Arc::new(self.options),
        }
    }
}

/// A client, a key storage and shared options, from which API providers are created.
#[derive(Clone, Debug)]
pub struct KeyPool
where
    C: ApiClient,
    S: KeyPoolStorage,
{
    pub client: C,
    pub storage: S,
    pub options: Arc,
}

impl KeyPool
where
    C: ApiClient,
    S: KeyPoolStorage + Send + Sync + 'static,
{
    /// Returns an `ApiProvider` whose requests use keys matching `selector`.
    ///
    /// The provider borrows the pool's client and storage; the options `Arc` is
    /// cloned into the new executor.
    pub fn torn_api(&self, selector: I) -> ApiProvider>
    where
        I: IntoSelector,
    {
        ApiProvider::new(
            &self.client,
            KeyPoolExecutor::new(
                &self.storage,
                selector.into_selector(),
                self.options.clone(),
            ),
        )
    }
}

/// Extension trait for using an API client directly with a key storage, without
/// building a `KeyPool` first.
pub trait WithStorage {
    /// Returns an `ApiProvider` that uses `self` as the client and takes keys
    /// matching `selector` from `storage`.
    ///
    /// The executor is created with `Default::default()` as its options argument,
    /// not with options configured through `PoolBuilder`.
    fn with_storage<'a, S, I>(
        &'a self,
        storage: &'a S,
        selector: I,
    ) -> ApiProvider>
    where
        Self: ApiClient + Sized,
        S: KeyPoolStorage + Send + Sync + 'static,
        I: IntoSelector,
    {
        ApiProvider::new(
            self,
            KeyPoolExecutor::new(storage, selector.into_selector(), Default::default()),
        )
    }
}

/// Makes `with_storage` available on `reqwest::Client` when the `reqwest`
/// feature is enabled; the default method body is used as is.
#[cfg(feature = "reqwest")]
impl WithStorage for reqwest::Client {}

/// Integration tests; compiled only for test builds with both the `postgres`
/// and `reqwest` features enabled.
///
/// NOTE(review): these use a real `reqwest::Client` and a database pool supplied
/// by `#[sqlx::test]`, so they presumably need network access and valid keys
/// from the `setup` fixture — confirm.
#[cfg(all(test, feature = "postgres", feature = "reqwest"))]
mod test {
    use sqlx::PgPool;

    use super::*;
    use crate::{
        postgres::test::{setup, Domain},
        KeySelector,
    };

    /// A pool built with a comment can perform a user request and read the profile.
    #[sqlx::test]
    async fn test_pool_request(pool: PgPool) {
        let (storage, _) = setup(pool).await;

        let pool = PoolBuilder::new(reqwest::Client::default(), storage)
            .comment("api.rs")
            .build();

        let response = pool.torn_api(Domain::All).user(|b| b).await.unwrap();
        _ = response.profile().unwrap();
    }

    /// The same request works through `WithStorage::with_storage` without a pool.
    #[sqlx::test]
    async fn test_with_storage_request(pool: PgPool) {
        let (storage, _) = setup(pool).await;

        let response = reqwest::Client::new()
            .with_storage(&storage, Domain::All)
            .user(|b| b)
            .await
            .unwrap();
        _ = response.profile().unwrap();
    }

    /// A before-hook that pushes the "crimes" selection onto the request makes
    /// that selection readable from the response.
    #[sqlx::test]
    async fn before_hook(pool: PgPool) {
        let (storage, _) = setup(pool).await;

        let pool = PoolBuilder::new(reqwest::Client::default(), storage)
            .hook_before::(|req, _s| {
                req.selections.push("crimes");
            })
            .build();

        let response = pool.torn_api(Domain::All).user(|b| b).await.unwrap();
        _ = response.crimes().unwrap();
    }

    /// An after-hook that always returns `KeyAction::Delete` removes key 1 and
    /// the request ends in a storage error.
    #[sqlx::test]
    async fn after_hook(pool: PgPool) {
        let (storage, _) = setup(pool).await;

        let pool = PoolBuilder::new(reqwest::Client::default(), storage)
            .hook_after::(|_res, _s| Err(KeyAction::Delete))
            .build();

        // The key exists before the request.
        let key = pool.storage.read_key(KeySelector::Id(1)).await.unwrap();
        assert!(key.is_some());

        // Presumably the retry after deletion finds no key left to acquire —
        // verify against the `setup` fixture.
        let response = pool.torn_api(Domain::All).user(|b| b).await;
        assert!(matches!(response, Err(KeyPoolError::Storage(_))));

        // The key is gone afterwards.
        let key = pool.storage.read_key(KeySelector::Id(1)).await.unwrap();
        assert!(key.is_none());
    }
}