feat: implemented bulk requests

TotallyNot 2025-04-29 18:26:00 +02:00
parent 4dd4fd37d4
commit c17f93f600
Signed by: pyrite
GPG key ID: 7F1BA9170CD35D15
10 changed files with 767 additions and 176 deletions
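The key pool gains bulk execution: a selector can now drive many requests at once, each resolved with its own pooled key, optionally throttled. A rough sketch of the call pattern, lifted from the new tests at the bottom of this diff (the storage value and the Domain::All selector are the postgres test fixtures, standing in for a real backend):

    let pool = PoolBuilder::new(storage)
        .use_default_hooks()
        .comment("test_runner")
        .build();

    // Bulk: one stream item per requested faction id, each request paired with a pooled key.
    let responses = pool
        .torn_api(Domain::All)
        .faction_bulk()
        .basic_for_id(vec![19.into(), 89.into()], |b| b);
    let responses: Vec<_> = StreamExt::collect(responses).await;

    // Throttled variant: request starts are spaced at least 500ms apart.
    let throttled = pool
        .throttled_torn_api(Domain::All, Duration::from_millis(500))
        .faction_bulk()
        .basic_for_id(vec![19.into(), 89.into()], |b| b);

The throttled variant spaces request starts by the given Duration while still buffering up to 25 requests in flight.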

@@ -5,11 +5,12 @@ pub mod postgres;
use std::{collections::HashMap, future::Future, sync::Arc, time::Duration};
use futures::{future::BoxFuture, FutureExt};
use futures::{future::BoxFuture, FutureExt, Stream, StreamExt};
use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION};
use serde::Deserialize;
use tokio_stream::StreamExt as TokioStreamExt;
use torn_api::{
executor::Executor,
executor::{BulkExecutor, Executor},
request::{ApiRequest, ApiResponse},
ApiError,
};
@@ -80,6 +81,46 @@ where
}
}
impl<K, D> From<&str> for KeySelector<K, D>
where
K: ApiKey,
D: KeyDomain,
{
fn from(value: &str) -> Self {
Self::Key(value.to_owned())
}
}
impl<K, D> From<D> for KeySelector<K, D>
where
K: ApiKey,
D: KeyDomain,
{
fn from(value: D) -> Self {
Self::Has(vec![value])
}
}
impl<K, D> From<&[D]> for KeySelector<K, D>
where
K: ApiKey,
D: KeyDomain,
{
fn from(value: &[D]) -> Self {
Self::Has(value.to_vec())
}
}
impl<K, D> From<Vec<D>> for KeySelector<K, D>
where
K: ApiKey,
D: KeyDomain,
{
fn from(value: Vec<D>) -> Self {
Self::Has(value)
}
}
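// These From impls let a raw key string, a single domain, a slice of domains, or a Vec of
// domains be passed anywhere a selector is accepted; the blanket IntoSelector impl below
// simply forwards any Into<KeySelector<K, D>> type through Into.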
pub trait IntoSelector<K, D>: Send
where
K: ApiKey,
@@ -88,30 +129,35 @@ where
fn into_selector(self) -> KeySelector<K, D>;
}
impl<K, D> IntoSelector<K, D> for D
impl<K, D, T> IntoSelector<K, D> for T
where
K: ApiKey,
D: KeyDomain,
T: Into<KeySelector<K, D>> + Send,
{
fn into_selector(self) -> KeySelector<K, D> {
KeySelector::Has(vec![self])
self.into()
}
}
impl<K, D> IntoSelector<K, D> for KeySelector<K, D>
where
K: ApiKey,
D: KeyDomain,
{
fn into_selector(self) -> KeySelector<K, D> {
self
}
}
pub trait KeyPoolError:
From<reqwest::Error> + From<serde_json::Error> + From<torn_api::ApiError> + From<Arc<Self>> + Send
{
}
impl<T> KeyPoolError for T where
T: From<reqwest::Error>
+ From<serde_json::Error>
+ From<torn_api::ApiError>
+ From<Arc<Self>>
+ Send
{
}
pub trait KeyPoolStorage: Send + Sync {
type Key: ApiKey;
type Domain: KeyDomain;
type Error: From<reqwest::Error> + From<serde_json::Error> + From<torn_api::ApiError> + Send;
type Error: KeyPoolError;
fn acquire_key<S>(
&self,
@@ -206,65 +252,6 @@ where
>,
}
pub struct KeyPoolExecutor<'p, S>
where
S: KeyPoolStorage,
{
pool: &'p KeyPool<S>,
selector: KeySelector<S::Key, S::Domain>,
}
impl<'p, S> KeyPoolExecutor<'p, S>
where
S: KeyPoolStorage,
{
pub fn new(pool: &'p KeyPool<S>, selector: KeySelector<S::Key, S::Domain>) -> Self {
Self { pool, selector }
}
async fn execute_request<D>(&self, request: ApiRequest<D>) -> Result<ApiResponse<D>, S::Error>
where
D: Send,
{
let key = self.pool.storage.acquire_key(self.selector.clone()).await?;
let mut headers = HeaderMap::with_capacity(1);
headers.insert(
AUTHORIZATION,
HeaderValue::from_str(&format!("ApiKey {}", key.value())).unwrap(),
);
let resp = self
.pool
.client
.get(request.url())
.headers(headers)
.send()
.await?;
let status = resp.status();
let bytes = resp.bytes().await?;
if let Some(err) = decode_error(&bytes)? {
if let Some(handler) = self.pool.options.error_hooks.get(&err.code()) {
let retry = (*handler)(&self.pool.storage, &key).await?;
if retry {
return Box::pin(self.execute_request(request)).await;
}
}
Err(err.into())
} else {
Ok(ApiResponse {
discriminant: request.disriminant,
body: Some(bytes),
status,
})
}
}
}
pub struct PoolBuilder<S>
where
S: KeyPoolStorage,
@@ -358,20 +345,137 @@ where
pub fn build(self) -> KeyPool<S> {
KeyPool {
client: self.client,
storage: self.storage,
options: Arc::new(self.options),
inner: Arc::new(KeyPoolInner {
client: self.client,
storage: self.storage,
options: self.options,
}),
}
}
}
struct KeyPoolInner<S>
where
S: KeyPoolStorage,
{
pub client: reqwest::Client,
pub storage: S,
pub options: PoolOptions<S>,
}
impl<S> KeyPoolInner<S>
where
S: KeyPoolStorage,
{
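// Single attempt with an already-acquired key: build the auth header, send the request, and
// decode any API error. If an error hook is registered for the error code, it decides whether
// the caller should retry with another key; otherwise the error is returned as-is.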
async fn execute_with_key(
&self,
key: &S::Key,
request: &ApiRequest,
) -> Result<RequestResult, S::Error> {
let mut headers = HeaderMap::with_capacity(1);
headers.insert(
AUTHORIZATION,
HeaderValue::from_str(&format!("ApiKey {}", key.value())).unwrap(),
);
let resp = self
.client
.get(request.url())
.headers(headers)
.send()
.await?;
let status = resp.status();
let bytes = resp.bytes().await?;
if let Some(err) = decode_error(&bytes)? {
if let Some(handler) = self.options.error_hooks.get(&err.code()) {
let retry = (*handler)(&self.storage, key).await?;
if retry {
return Ok(RequestResult::Retry);
}
}
Err(err.into())
} else {
Ok(RequestResult::Response(ApiResponse {
body: Some(bytes),
status,
}))
}
}
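// Acquire a key matching the selector and attempt the request, looping until we get a
// response or a non-retryable error; each retry acquires a fresh key.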
async fn execute_request(
&self,
selector: KeySelector<S::Key, S::Domain>,
request: ApiRequest,
) -> Result<ApiResponse, S::Error> {
loop {
let key = self.storage.acquire_key(selector.clone()).await?;
match self.execute_with_key(&key, &request).await {
Ok(RequestResult::Response(resp)) => return Ok(resp),
Ok(RequestResult::Retry) => (),
Err(why) => return Err(why),
}
}
}
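// Bulk path: acquire one key per request up front. If that fails, the single error is shared
// (via Arc) across every pending request. Each (request, key) pair then runs as its own
// future, re-acquiring a key on retry, with at most 25 requests in flight at once.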
async fn execute_bulk_requests<D, T: IntoIterator<Item = (D, ApiRequest)>>(
&self,
selector: KeySelector<S::Key, S::Domain>,
requests: T,
) -> impl Stream<Item = (D, Result<ApiResponse, S::Error>)> + use<'_, D, S, T> {
let requests: Vec<_> = requests.into_iter().collect();
let keys: Vec<_> = match self
.storage
.acquire_many_keys(selector.clone(), requests.len() as i64)
.await
{
Ok(keys) => keys.into_iter().map(Ok).collect(),
Err(why) => {
let why = Arc::new(why);
std::iter::repeat_n(why, requests.len())
.map(|e| Err(S::Error::from(e)))
.collect()
}
};
StreamExt::map(
futures::stream::iter(std::iter::zip(requests, keys)),
move |((discriminant, request), mut maybe_key)| {
let selector = selector.clone();
async move {
loop {
let key = match maybe_key {
Ok(key) => key,
Err(why) => return (discriminant, Err(why)),
};
match self.execute_with_key(&key, &request).await {
Ok(RequestResult::Response(resp)) => return (discriminant, Ok(resp)),
Ok(RequestResult::Retry) => (),
Err(why) => return (discriminant, Err(why)),
}
maybe_key = self.storage.acquire_key(selector.clone()).await;
}
}
},
)
.buffer_unordered(25)
}
}
pub struct KeyPool<S>
where
S: KeyPoolStorage,
{
pub client: reqwest::Client,
pub storage: S,
pub options: Arc<PoolOptions<S>>,
inner: Arc<KeyPoolInner<S>>,
}
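// Outcome of a single attempt: a decoded response, or a signal to retry with another key.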
enum RequestResult {
Response(ApiResponse),
Retry,
}
impl<S> KeyPool<S>
@@ -384,6 +488,17 @@ where
{
KeyPoolExecutor::new(self, selector.into_selector())
}
pub fn throttled_torn_api<I>(
&self,
selector: I,
distance: Duration,
) -> ThrottledKeyPoolExecutor<S>
where
I: IntoSelector<S::Key, S::Domain>,
{
ThrottledKeyPoolExecutor::new(self, selector.into_selector(), distance)
}
}
fn decode_error(buf: &[u8]) -> Result<Option<ApiError>, serde_json::Error> {
@@ -409,28 +524,145 @@ fn decode_error(buf: &[u8]) -> Result<Option<ApiError>, serde_json::Error> {
}
}
impl<S> Executor for KeyPoolExecutor<'_, S>
pub struct KeyPoolExecutor<'p, S>
where
S: KeyPoolStorage,
{
pool: &'p KeyPoolInner<S>,
selector: KeySelector<S::Key, S::Domain>,
}
impl<'p, S> KeyPoolExecutor<'p, S>
where
S: KeyPoolStorage,
{
pub fn new(pool: &'p KeyPool<S>, selector: KeySelector<S::Key, S::Domain>) -> Self {
Self {
pool: &pool.inner,
selector,
}
}
}
impl<S> Executor for KeyPoolExecutor<'_, S>
where
S: KeyPoolStorage + 'static,
{
type Error = S::Error;
async fn execute<R>(
&self,
request: R,
) -> Result<torn_api::request::ApiResponse<R::Discriminant>, Self::Error>
async fn execute<R>(self, request: R) -> (R::Discriminant, Result<ApiResponse, Self::Error>)
where
R: torn_api::request::IntoRequest,
{
let request = request.into_request();
let (d, request) = request.into_request();
self.execute_request(request).await
(d, self.pool.execute_request(self.selector, request).await)
}
}
impl<'p, S> BulkExecutor<'p> for KeyPoolExecutor<'p, S>
where
S: KeyPoolStorage + 'static,
{
type Error = S::Error;
fn execute<R>(
self,
requests: impl IntoIterator<Item = R>,
) -> impl futures::Stream<Item = (R::Discriminant, Result<ApiResponse, Self::Error>)>
where
R: torn_api::request::IntoRequest,
{
self.pool
.execute_bulk_requests(
self.selector.clone(),
requests.into_iter().map(|r| r.into_request()),
)
.into_stream()
.flatten()
}
}
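// Like KeyPoolExecutor, but for rate-limited bulk use: request starts are spaced `distance`
// apart via tokio_stream's throttle, while responses still resolve concurrently (up to 25
// requests in flight).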
pub struct ThrottledKeyPoolExecutor<'p, S>
where
S: KeyPoolStorage,
{
pool: &'p KeyPoolInner<S>,
selector: KeySelector<S::Key, S::Domain>,
distance: Duration,
}
impl<S> Clone for ThrottledKeyPoolExecutor<'_, S>
where
S: KeyPoolStorage,
{
fn clone(&self) -> Self {
Self {
pool: self.pool,
selector: self.selector.clone(),
distance: self.distance,
}
}
}
impl<S> ThrottledKeyPoolExecutor<'_, S>
where
S: KeyPoolStorage,
{
async fn execute_request(self, request: ApiRequest) -> Result<ApiResponse, S::Error> {
self.pool.execute_request(self.selector, request).await
}
}
impl<'p, S> ThrottledKeyPoolExecutor<'p, S>
where
S: KeyPoolStorage,
{
pub fn new(
pool: &'p KeyPool<S>,
selector: KeySelector<S::Key, S::Domain>,
distance: Duration,
) -> Self {
Self {
pool: &pool.inner,
selector,
distance,
}
}
}
impl<'p, S> BulkExecutor<'p> for ThrottledKeyPoolExecutor<'p, S>
where
S: KeyPoolStorage + 'static,
{
type Error = S::Error;
fn execute<R>(
self,
requests: impl IntoIterator<Item = R>,
) -> impl futures::Stream<Item = (R::Discriminant, Result<ApiResponse, Self::Error>)>
where
R: torn_api::request::IntoRequest,
{
StreamExt::map(
futures::stream::iter(requests).throttle(self.distance),
move |r| {
let this = self.clone();
async move {
let (d, request) = r.into_request();
let result = this.execute_request(request).await;
(d, result)
}
},
)
.buffer_unordered(25)
}
}
#[cfg(test)]
#[cfg(feature = "postgres")]
mod test {
use torn_api::executor::ExecutorExt;
use torn_api::executor::{BulkExecutorExt, ExecutorExt};
use crate::postgres;
@@ -451,4 +683,48 @@ mod test {
.await
.unwrap();
}
#[sqlx::test]
fn bulk(pool: sqlx::PgPool) {
let (storage, _) = postgres::test::setup(pool).await;
let pool = PoolBuilder::new(storage)
.use_default_hooks()
.comment("test_runner")
.build();
let responses = pool
.torn_api(postgres::test::Domain::All)
.faction_bulk()
.basic_for_id(vec![19.into(), 89.into()], |b| b);
let mut responses: Vec<_> = StreamExt::collect(responses).await;
let (_id1, basic1) = responses.pop().unwrap();
basic1.unwrap();
let (_id2, basic2) = responses.pop().unwrap();
basic2.unwrap();
}
#[sqlx::test]
fn bulk_throttled(pool: sqlx::PgPool) {
let (storage, _) = postgres::test::setup(pool).await;
let pool = PoolBuilder::new(storage)
.use_default_hooks()
.comment("test_runner")
.build();
let responses = pool
.throttled_torn_api(postgres::test::Domain::All, Duration::from_millis(500))
.faction_bulk()
.basic_for_id(vec![19.into(), 89.into()], |b| b);
let mut responses: Vec<_> = StreamExt::collect(responses).await;
let (_id1, basic1) = responses.pop().unwrap();
basic1.unwrap();
let (_id2, basic2) = responses.pop().unwrap();
basic2.unwrap();
}
}