torn-api.rs/torn-key-pool/src/send.rs

381 lines
11 KiB
Rust

use std::{collections::HashMap, sync::Arc};
use async_trait::async_trait;
use torn_api::{
send::{ApiClient, ApiProvider, RequestExecutor},
ApiRequest, ApiResponse, ApiSelection, ResponseError,
};
use crate::{
ApiKey, IntoSelector, KeyAction, KeyDomain, KeyPoolError, KeyPoolExecutor, KeyPoolStorage,
KeySelector, PoolOptions,
};
/// [`RequestExecutor`] implementation that draws API keys from the pool's storage.
///
/// Single requests ([`execute`](RequestExecutor::execute)) and batched requests
/// ([`execute_many`](RequestExecutor::execute_many)) follow the same policy:
///
/// * the pool-wide comment is attached to requests that don't carry their own;
/// * the "before" hook registered for the selection type `A` (if any) may mutate the request
///   once, before any request is sent;
/// * when the API reports an error, the key is flagged with the storage; if `flag_key`
///   returns `true` the request is retried with a freshly acquired key, otherwise the API
///   error is returned;
/// * on success, the "after" hook registered for `A` (if any) inspects the response; if it
///   returns [`KeyAction::Delete`] or [`KeyAction::RemoveDomain`], the key (or one of its
///   domains) is removed from storage and the request is retried with another key.
#[async_trait]
impl<'client, C, S> RequestExecutor<C> for KeyPoolExecutor<'client, C, S>
where
    C: ApiClient,
    S: KeyPoolStorage + Send + Sync + 'static,
{
    type Error = KeyPoolError<S::Error, C::Error>;

    /// Sends `request` (optionally for entity `id`), rotating keys until a response is accepted.
    ///
    /// # Errors
    /// * [`KeyPoolError::Storage`] if acquiring, flagging or removing a key fails;
    /// * a client error (converted via `?`) if the HTTP request itself fails;
    /// * [`KeyPoolError::Response`] for a parsing error, or for an API error once the storage
    ///   declines to retry (`flag_key` returned `false`).
    async fn execute<A>(
        &self,
        client: &C,
        mut request: ApiRequest<A>,
        id: Option<String>,
    ) -> Result<A::Response, Self::Error>
    where
        A: ApiSelection,
    {
        if request.comment.is_none() {
            request.comment = self.options.comment.clone();
        }
        // Hooks are keyed only by the selection's `TypeId`; `PoolBuilder` stores them with the
        // storage's key/domain types, so the downcast can only fail if the options were built
        // for a different storage type.
        if let Some(hook) = self.options.hooks_before.get(&std::any::TypeId::of::<A>()) {
            let concrete = hook
                .downcast_ref::<BeforeHook<A, S::Key, S::Domain>>()
                .expect("before hook registered for this selection has mismatched key/domain types");
            (concrete.body)(&mut request, &self.selector);
        }
        let after_hook = self
            .options
            .hooks_after
            .get(&std::any::TypeId::of::<A>())
            .map(|hook| {
                hook.downcast_ref::<AfterHook<A, S::Key, S::Domain>>()
                    .expect("after hook registered for this selection has mismatched key/domain types")
            });

        loop {
            let key = self
                .storage
                .acquire_key(self.selector.clone())
                .await
                .map_err(KeyPoolError::Storage)?;
            let url = request.url(key.value(), id.as_deref());
            let value = client.request(url).await?;
            match ApiResponse::from_value(value) {
                Err(ResponseError::Api { code, reason }) => {
                    // `flag_key` decides whether this error is worth retrying with another key.
                    if !self
                        .storage
                        .flag_key(key, code)
                        .await
                        .map_err(KeyPoolError::Storage)?
                    {
                        return Err(KeyPoolError::Response(ResponseError::Api { code, reason }));
                    }
                }
                Err(parsing_error) => return Err(KeyPoolError::Response(parsing_error)),
                Ok(res) => {
                    let res: A::Response = res.into();
                    if let Some(hook) = after_hook {
                        match (hook.body)(&res, &self.selector) {
                            Err(KeyAction::Delete) => {
                                self.storage
                                    .remove_key(key.selector())
                                    .await
                                    .map_err(KeyPoolError::Storage)?;
                                continue;
                            }
                            Err(KeyAction::RemoveDomain(domain)) => {
                                self.storage
                                    .remove_domain_from_key(key.selector(), domain)
                                    .await
                                    .map_err(KeyPoolError::Storage)?;
                                continue;
                            }
                            _ => (),
                        }
                    }
                    return Ok(res);
                }
            }
        }
    }

    /// Sends `request` once per id concurrently, applying the same key-rotation and hook
    /// policy as `execute` to each id independently.
    ///
    /// Every id gets its own entry in the returned map. If the initial bulk key acquisition
    /// fails, every id maps to a clone of that storage error.
    async fn execute_many<A, I>(
        &self,
        client: &C,
        mut request: ApiRequest<A>,
        ids: Vec<I>,
    ) -> HashMap<I, Result<A::Response, Self::Error>>
    where
        A: ApiSelection,
        I: ToString + std::hash::Hash + std::cmp::Eq + Send + Sync,
    {
        // One key per id up front so that the first round of requests can run concurrently.
        let keys = match self
            .storage
            .acquire_many_keys(self.selector.clone(), ids.len() as i64)
            .await
        {
            Ok(keys) => keys,
            Err(why) => {
                return ids
                    .into_iter()
                    .map(|i| (i, Err(Self::Error::Storage(why.clone()))))
                    .collect();
            }
        };

        if request.comment.is_none() {
            request.comment = self.options.comment.clone();
        }
        // Batched requests honour the same hooks as single requests (see `execute`).
        if let Some(hook) = self.options.hooks_before.get(&std::any::TypeId::of::<A>()) {
            let concrete = hook
                .downcast_ref::<BeforeHook<A, S::Key, S::Domain>>()
                .expect("before hook registered for this selection has mismatched key/domain types");
            (concrete.body)(&mut request, &self.selector);
        }
        let after_hook = self
            .options
            .hooks_after
            .get(&std::any::TypeId::of::<A>())
            .map(|hook| {
                hook.downcast_ref::<AfterHook<A, S::Key, S::Domain>>()
                    .expect("after hook registered for this selection has mismatched key/domain types")
            });

        let request_ref = &request;
        // NOTE(review): `zip` silently drops ids beyond the number of keys returned by
        // `acquire_many_keys` — confirm the storage always returns one key per requested id.
        let tuples =
            futures::future::join_all(std::iter::zip(ids, keys).map(|(id, mut key)| async move {
                let id_string = id.to_string();
                loop {
                    let url = request_ref.url(key.value(), Some(&id_string));
                    let value = match client.request(url).await {
                        Ok(v) => v,
                        Err(why) => return (id, Err(Self::Error::Client(why))),
                    };
                    match ApiResponse::from_value(value) {
                        Err(ResponseError::Api { code, reason }) => {
                            match self.storage.flag_key(key, code).await {
                                Ok(false) => {
                                    return (
                                        id,
                                        Err(KeyPoolError::Response(ResponseError::Api {
                                            code,
                                            reason,
                                        })),
                                    )
                                }
                                Ok(true) => (),
                                Err(why) => return (id, Err(KeyPoolError::Storage(why))),
                            }
                        }
                        Err(parsing_error) => {
                            return (id, Err(KeyPoolError::Response(parsing_error)))
                        }
                        Ok(res) => {
                            let res: A::Response = res.into();
                            let action = match after_hook {
                                Some(hook) => (hook.body)(&res, &self.selector),
                                None => Ok(()),
                            };
                            match action {
                                Err(KeyAction::Delete) => {
                                    if let Err(why) =
                                        self.storage.remove_key(key.selector()).await
                                    {
                                        return (id, Err(KeyPoolError::Storage(why)));
                                    }
                                }
                                Err(KeyAction::RemoveDomain(domain)) => {
                                    if let Err(why) = self
                                        .storage
                                        .remove_domain_from_key(key.selector(), domain)
                                        .await
                                    {
                                        return (id, Err(KeyPoolError::Storage(why)));
                                    }
                                }
                                _ => return (id, Ok(res)),
                            }
                        }
                    }
                    // The previous key was flagged or removed: retry this id with a fresh one.
                    key = match self.storage.acquire_key(self.selector.clone()).await {
                        Ok(k) => k,
                        Err(why) => return (id, Err(Self::Error::Storage(why))),
                    };
                }
            }))
            .await;
        HashMap::from_iter(tuples)
    }
}
/// Closure type stored in a [`BeforeHook`]: receives the outgoing request mutably together
/// with the key selector of the executor running it.
type BeforeHookFn<A, K, D> =
    dyn Fn(&mut ApiRequest<A>, &KeySelector<K, D>) + Send + Sync + 'static;

/// Wrapper around a closure registered through [`PoolBuilder::hook_before`].
///
/// Instances are boxed into the pool options under the `TypeId` of the selection type `A`
/// and recovered by downcasting when a request for `A` is executed.
pub struct BeforeHook<A, K, D>
where
    A: ApiSelection,
    K: ApiKey,
    D: KeyDomain,
{
    body: Box<BeforeHookFn<A, K, D>>,
}
#[allow(clippy::type_complexity)]
pub struct AfterHook<A, K, D>
where
A: ApiSelection,
K: ApiKey,
D: KeyDomain,
{
body: Box<
dyn Fn(&A::Response, &KeySelector<K, D>) -> Result<(), crate::KeyAction<D>>
+ Send
+ Sync
+ 'static,
>,
}
pub struct PoolBuilder<C, S>
where
C: ApiClient,
S: KeyPoolStorage,
{
client: C,
storage: S,
options: crate::PoolOptions,
}
impl<C, S> PoolBuilder<C, S>
where
C: ApiClient,
S: KeyPoolStorage,
{
pub fn new(client: C, storage: S) -> Self {
Self {
client,
storage,
options: Default::default(),
}
}
pub fn comment(mut self, c: impl ToString) -> Self {
self.options.comment = Some(c.to_string());
self
}
pub fn hook_before<A>(
mut self,
hook: impl Fn(&mut ApiRequest<A>, &KeySelector<S::Key, S::Domain>) + Send + Sync + 'static,
) -> Self
where
A: ApiSelection + 'static,
{
self.options.hooks_before.insert(
std::any::TypeId::of::<A>(),
Box::new(BeforeHook {
body: Box::new(hook),
}),
);
self
}
pub fn hook_after<A>(
mut self,
hook: impl Fn(&A::Response, &KeySelector<S::Key, S::Domain>) -> Result<(), KeyAction<S::Domain>>
+ Send
+ Sync
+ 'static,
) -> Self
where
A: ApiSelection + 'static,
{
self.options.hooks_after.insert(
std::any::TypeId::of::<A>(),
Box::new(AfterHook::<A, S::Key, S::Domain> {
body: Box::new(hook),
}),
);
self
}
pub fn build(self) -> KeyPool<C, S> {
KeyPool {
client: self.client,
storage: self.storage,
options: Arc::new(self.options),
}
}
}
/// Ties an API client to a key storage and a set of shared pool options.
///
/// Build one with [`PoolBuilder`] and obtain request providers through [`KeyPool::torn_api`].
#[derive(Clone, Debug)]
pub struct KeyPool<C, S>
where
    C: ApiClient,
    S: KeyPoolStorage,
{
    /// Client lent to every provider created by [`KeyPool::torn_api`].
    pub client: C,
    /// Storage the executors acquire, flag and remove keys through.
    pub storage: S,
    /// Options (comment and hooks) shared by every executor created from this pool.
    pub options: Arc<PoolOptions>,
}
impl<C, S> KeyPool<C, S>
where
C: ApiClient,
S: KeyPoolStorage + Send + Sync + 'static,
{
pub fn torn_api<I>(&self, selector: I) -> ApiProvider<C, KeyPoolExecutor<C, S>>
where
I: IntoSelector<S::Key, S::Domain>,
{
ApiProvider::new(
&self.client,
KeyPoolExecutor::new(
&self.storage,
selector.into_selector(),
self.options.clone(),
),
)
}
}
pub trait WithStorage {
fn with_storage<'a, S, I>(
&'a self,
storage: &'a S,
selector: I,
) -> ApiProvider<Self, KeyPoolExecutor<Self, S>>
where
Self: ApiClient + Sized,
S: KeyPoolStorage + Send + Sync + 'static,
I: IntoSelector<S::Key, S::Domain>,
{
ApiProvider::new(
self,
KeyPoolExecutor::new(storage, selector.into_selector(), Default::default()),
)
}
}
/// Lets a `reqwest::Client` create key-storage-backed providers via
/// [`WithStorage::with_storage`]. Only compiled with the `reqwest` feature.
#[cfg(feature = "reqwest")]
impl WithStorage for reqwest::Client {}
/// Integration tests for the key pool.
///
/// Each test receives a Postgres pool from `#[sqlx::test]`, seeds it through
/// `postgres::test::setup`, and sends requests through `reqwest`.
/// NOTE(review): these presumably hit the live API, so they also need network access and a
/// valid key seeded by `setup` — confirm.
#[cfg(all(test, feature = "postgres", feature = "reqwest"))]
mod test {
    use sqlx::PgPool;

    use super::*;
    use crate::{
        postgres::test::{setup, Domain},
        KeySelector,
    };

    // A pool built through `PoolBuilder` (with a comment set) can serve a basic user request.
    #[sqlx::test]
    async fn test_pool_request(pool: PgPool) {
        let (storage, _) = setup(pool).await;
        let pool = PoolBuilder::new(reqwest::Client::default(), storage)
            .comment("api.rs")
            .build();
        let response = pool.torn_api(Domain::All).user(|b| b).await.unwrap();
        _ = response.profile().unwrap();
    }

    // `WithStorage` lets a bare client borrow the storage without building a pool.
    #[sqlx::test]
    async fn test_with_storage_request(pool: PgPool) {
        let (storage, _) = setup(pool).await;
        let response = reqwest::Client::new()
            .with_storage(&storage, Domain::All)
            .user(|b| b)
            .await
            .unwrap();
        _ = response.profile().unwrap();
    }

    // A "before" hook that adds the `crimes` selection must make that section parseable from
    // a request that did not ask for it explicitly.
    #[sqlx::test]
    async fn before_hook(pool: PgPool) {
        let (storage, _) = setup(pool).await;
        let pool = PoolBuilder::new(reqwest::Client::default(), storage)
            .hook_before::<torn_api::user::UserSelection>(|req, _s| {
                req.selections.push("crimes");
            })
            .build();
        let response = pool.torn_api(Domain::All).user(|b| b).await.unwrap();
        _ = response.crimes().unwrap();
    }

    // An "after" hook that always answers `KeyAction::Delete` removes key 1 from storage.
    // The request then ends with a storage error, presumably because `setup` seeds no other
    // key to retry with — NOTE(review): confirm against `setup`.
    #[sqlx::test]
    async fn after_hook(pool: PgPool) {
        let (storage, _) = setup(pool).await;
        let pool = PoolBuilder::new(reqwest::Client::default(), storage)
            .hook_after::<torn_api::user::UserSelection>(|_res, _s| Err(KeyAction::Delete))
            .build();
        let key = pool.storage.read_key(KeySelector::Id(1)).await.unwrap();
        assert!(key.is_some());
        let response = pool.torn_api(Domain::All).user(|b| b).await;
        assert!(matches!(response, Err(KeyPoolError::Storage(_))));
        let key = pool.storage.read_key(KeySelector::Id(1)).await.unwrap();
        assert!(key.is_none());
    }
}