feat: implemented bulk requests
This commit is contained in:
parent
4dd4fd37d4
commit
c17f93f600
10 changed files with 767 additions and 176 deletions
|
|
@ -5,11 +5,12 @@ pub mod postgres;
|
|||
|
||||
use std::{collections::HashMap, future::Future, sync::Arc, time::Duration};
|
||||
|
||||
use futures::{future::BoxFuture, FutureExt};
|
||||
use futures::{future::BoxFuture, FutureExt, Stream, StreamExt};
|
||||
use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION};
|
||||
use serde::Deserialize;
|
||||
use tokio_stream::StreamExt as TokioStreamExt;
|
||||
use torn_api::{
|
||||
executor::Executor,
|
||||
executor::{BulkExecutor, Executor},
|
||||
request::{ApiRequest, ApiResponse},
|
||||
ApiError,
|
||||
};
|
||||
|
|
@ -80,6 +81,46 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
impl<K, D> From<&str> for KeySelector<K, D>
|
||||
where
|
||||
K: ApiKey,
|
||||
D: KeyDomain,
|
||||
{
|
||||
fn from(value: &str) -> Self {
|
||||
Self::Key(value.to_owned())
|
||||
}
|
||||
}
|
||||
|
||||
impl<K, D> From<D> for KeySelector<K, D>
|
||||
where
|
||||
K: ApiKey,
|
||||
D: KeyDomain,
|
||||
{
|
||||
fn from(value: D) -> Self {
|
||||
Self::Has(vec![value])
|
||||
}
|
||||
}
|
||||
|
||||
impl<K, D> From<&[D]> for KeySelector<K, D>
|
||||
where
|
||||
K: ApiKey,
|
||||
D: KeyDomain,
|
||||
{
|
||||
fn from(value: &[D]) -> Self {
|
||||
Self::Has(value.to_vec())
|
||||
}
|
||||
}
|
||||
|
||||
impl<K, D> From<Vec<D>> for KeySelector<K, D>
|
||||
where
|
||||
K: ApiKey,
|
||||
D: KeyDomain,
|
||||
{
|
||||
fn from(value: Vec<D>) -> Self {
|
||||
Self::Has(value)
|
||||
}
|
||||
}
|
||||
|
||||
pub trait IntoSelector<K, D>: Send
|
||||
where
|
||||
K: ApiKey,
|
||||
|
|
@ -88,30 +129,35 @@ where
|
|||
fn into_selector(self) -> KeySelector<K, D>;
|
||||
}
|
||||
|
||||
impl<K, D> IntoSelector<K, D> for D
|
||||
impl<K, D, T> IntoSelector<K, D> for T
|
||||
where
|
||||
K: ApiKey,
|
||||
D: KeyDomain,
|
||||
T: Into<KeySelector<K, D>> + Send,
|
||||
{
|
||||
fn into_selector(self) -> KeySelector<K, D> {
|
||||
KeySelector::Has(vec![self])
|
||||
self.into()
|
||||
}
|
||||
}
|
||||
|
||||
impl<K, D> IntoSelector<K, D> for KeySelector<K, D>
|
||||
where
|
||||
K: ApiKey,
|
||||
D: KeyDomain,
|
||||
pub trait KeyPoolError:
|
||||
From<reqwest::Error> + From<serde_json::Error> + From<torn_api::ApiError> + From<Arc<Self>> + Send
|
||||
{
|
||||
}
|
||||
|
||||
impl<T> KeyPoolError for T where
|
||||
T: From<reqwest::Error>
|
||||
+ From<serde_json::Error>
|
||||
+ From<torn_api::ApiError>
|
||||
+ From<Arc<Self>>
|
||||
+ Send
|
||||
{
|
||||
fn into_selector(self) -> KeySelector<K, D> {
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
pub trait KeyPoolStorage: Send + Sync {
|
||||
type Key: ApiKey;
|
||||
type Domain: KeyDomain;
|
||||
type Error: From<reqwest::Error> + From<serde_json::Error> + From<torn_api::ApiError> + Send;
|
||||
type Error: KeyPoolError;
|
||||
|
||||
fn acquire_key<S>(
|
||||
&self,
|
||||
|
|
@ -206,65 +252,6 @@ where
|
|||
>,
|
||||
}
|
||||
|
||||
pub struct KeyPoolExecutor<'p, S>
|
||||
where
|
||||
S: KeyPoolStorage,
|
||||
{
|
||||
pool: &'p KeyPool<S>,
|
||||
selector: KeySelector<S::Key, S::Domain>,
|
||||
}
|
||||
|
||||
impl<'p, S> KeyPoolExecutor<'p, S>
|
||||
where
|
||||
S: KeyPoolStorage,
|
||||
{
|
||||
pub fn new(pool: &'p KeyPool<S>, selector: KeySelector<S::Key, S::Domain>) -> Self {
|
||||
Self { pool, selector }
|
||||
}
|
||||
|
||||
async fn execute_request<D>(&self, request: ApiRequest<D>) -> Result<ApiResponse<D>, S::Error>
|
||||
where
|
||||
D: Send,
|
||||
{
|
||||
let key = self.pool.storage.acquire_key(self.selector.clone()).await?;
|
||||
|
||||
let mut headers = HeaderMap::with_capacity(1);
|
||||
headers.insert(
|
||||
AUTHORIZATION,
|
||||
HeaderValue::from_str(&format!("ApiKey {}", key.value())).unwrap(),
|
||||
);
|
||||
|
||||
let resp = self
|
||||
.pool
|
||||
.client
|
||||
.get(request.url())
|
||||
.headers(headers)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
let status = resp.status();
|
||||
|
||||
let bytes = resp.bytes().await?;
|
||||
|
||||
if let Some(err) = decode_error(&bytes)? {
|
||||
if let Some(handler) = self.pool.options.error_hooks.get(&err.code()) {
|
||||
let retry = (*handler)(&self.pool.storage, &key).await?;
|
||||
|
||||
if retry {
|
||||
return Box::pin(self.execute_request(request)).await;
|
||||
}
|
||||
}
|
||||
Err(err.into())
|
||||
} else {
|
||||
Ok(ApiResponse {
|
||||
discriminant: request.disriminant,
|
||||
body: Some(bytes),
|
||||
status,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct PoolBuilder<S>
|
||||
where
|
||||
S: KeyPoolStorage,
|
||||
|
|
@ -358,20 +345,137 @@ where
|
|||
|
||||
pub fn build(self) -> KeyPool<S> {
|
||||
KeyPool {
|
||||
client: self.client,
|
||||
storage: self.storage,
|
||||
options: Arc::new(self.options),
|
||||
inner: Arc::new(KeyPoolInner {
|
||||
client: self.client,
|
||||
storage: self.storage,
|
||||
options: self.options,
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct KeyPoolInner<S>
|
||||
where
|
||||
S: KeyPoolStorage,
|
||||
{
|
||||
pub client: reqwest::Client,
|
||||
pub storage: S,
|
||||
pub options: PoolOptions<S>,
|
||||
}
|
||||
|
||||
impl<S> KeyPoolInner<S>
|
||||
where
|
||||
S: KeyPoolStorage,
|
||||
{
|
||||
async fn execute_with_key(
|
||||
&self,
|
||||
key: &S::Key,
|
||||
request: &ApiRequest,
|
||||
) -> Result<RequestResult, S::Error> {
|
||||
let mut headers = HeaderMap::with_capacity(1);
|
||||
headers.insert(
|
||||
AUTHORIZATION,
|
||||
HeaderValue::from_str(&format!("ApiKey {}", key.value())).unwrap(),
|
||||
);
|
||||
|
||||
let resp = self
|
||||
.client
|
||||
.get(request.url())
|
||||
.headers(headers)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
let status = resp.status();
|
||||
|
||||
let bytes = resp.bytes().await?;
|
||||
|
||||
if let Some(err) = decode_error(&bytes)? {
|
||||
if let Some(handler) = self.options.error_hooks.get(&err.code()) {
|
||||
let retry = (*handler)(&self.storage, key).await?;
|
||||
|
||||
if retry {
|
||||
return Ok(RequestResult::Retry);
|
||||
}
|
||||
}
|
||||
Err(err.into())
|
||||
} else {
|
||||
Ok(RequestResult::Response(ApiResponse {
|
||||
body: Some(bytes),
|
||||
status,
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
async fn execute_request(
|
||||
&self,
|
||||
selector: KeySelector<S::Key, S::Domain>,
|
||||
request: ApiRequest,
|
||||
) -> Result<ApiResponse, S::Error> {
|
||||
loop {
|
||||
let key = self.storage.acquire_key(selector.clone()).await?;
|
||||
match self.execute_with_key(&key, &request).await {
|
||||
Ok(RequestResult::Response(resp)) => return Ok(resp),
|
||||
Ok(RequestResult::Retry) => (),
|
||||
Err(why) => return Err(why),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn execute_bulk_requests<D, T: IntoIterator<Item = (D, ApiRequest)>>(
|
||||
&self,
|
||||
selector: KeySelector<S::Key, S::Domain>,
|
||||
requests: T,
|
||||
) -> impl Stream<Item = (D, Result<ApiResponse, S::Error>)> + use<'_, D, S, T> {
|
||||
let requests: Vec<_> = requests.into_iter().collect();
|
||||
|
||||
let keys: Vec<_> = match self
|
||||
.storage
|
||||
.acquire_many_keys(selector.clone(), requests.len() as i64)
|
||||
.await
|
||||
{
|
||||
Ok(keys) => keys.into_iter().map(Ok).collect(),
|
||||
Err(why) => {
|
||||
let why = Arc::new(why);
|
||||
std::iter::repeat_n(why, requests.len())
|
||||
.map(|e| Err(S::Error::from(e)))
|
||||
.collect()
|
||||
}
|
||||
};
|
||||
|
||||
StreamExt::map(
|
||||
futures::stream::iter(std::iter::zip(requests, keys)),
|
||||
move |((discriminant, request), mut maybe_key)| {
|
||||
let selector = selector.clone();
|
||||
async move {
|
||||
loop {
|
||||
let key = match maybe_key {
|
||||
Ok(key) => key,
|
||||
Err(why) => return (discriminant, Err(why)),
|
||||
};
|
||||
match self.execute_with_key(&key, &request).await {
|
||||
Ok(RequestResult::Response(resp)) => return (discriminant, Ok(resp)),
|
||||
Ok(RequestResult::Retry) => (),
|
||||
Err(why) => return (discriminant, Err(why)),
|
||||
}
|
||||
maybe_key = self.storage.acquire_key(selector.clone()).await;
|
||||
}
|
||||
}
|
||||
},
|
||||
)
|
||||
.buffer_unordered(25)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct KeyPool<S>
|
||||
where
|
||||
S: KeyPoolStorage,
|
||||
{
|
||||
pub client: reqwest::Client,
|
||||
pub storage: S,
|
||||
pub options: Arc<PoolOptions<S>>,
|
||||
inner: Arc<KeyPoolInner<S>>,
|
||||
}
|
||||
|
||||
enum RequestResult {
|
||||
Response(ApiResponse),
|
||||
Retry,
|
||||
}
|
||||
|
||||
impl<S> KeyPool<S>
|
||||
|
|
@ -384,6 +488,17 @@ where
|
|||
{
|
||||
KeyPoolExecutor::new(self, selector.into_selector())
|
||||
}
|
||||
|
||||
pub fn throttled_torn_api<I>(
|
||||
&self,
|
||||
selector: I,
|
||||
distance: Duration,
|
||||
) -> ThrottledKeyPoolExecutor<S>
|
||||
where
|
||||
I: IntoSelector<S::Key, S::Domain>,
|
||||
{
|
||||
ThrottledKeyPoolExecutor::new(self, selector.into_selector(), distance)
|
||||
}
|
||||
}
|
||||
|
||||
fn decode_error(buf: &[u8]) -> Result<Option<ApiError>, serde_json::Error> {
|
||||
|
|
@ -409,28 +524,145 @@ fn decode_error(buf: &[u8]) -> Result<Option<ApiError>, serde_json::Error> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<S> Executor for KeyPoolExecutor<'_, S>
|
||||
pub struct KeyPoolExecutor<'p, S>
|
||||
where
|
||||
S: KeyPoolStorage,
|
||||
{
|
||||
pool: &'p KeyPoolInner<S>,
|
||||
selector: KeySelector<S::Key, S::Domain>,
|
||||
}
|
||||
|
||||
impl<'p, S> KeyPoolExecutor<'p, S>
|
||||
where
|
||||
S: KeyPoolStorage,
|
||||
{
|
||||
pub fn new(pool: &'p KeyPool<S>, selector: KeySelector<S::Key, S::Domain>) -> Self {
|
||||
Self {
|
||||
pool: &pool.inner,
|
||||
selector,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> Executor for KeyPoolExecutor<'_, S>
|
||||
where
|
||||
S: KeyPoolStorage + 'static,
|
||||
{
|
||||
type Error = S::Error;
|
||||
|
||||
async fn execute<R>(
|
||||
&self,
|
||||
request: R,
|
||||
) -> Result<torn_api::request::ApiResponse<R::Discriminant>, Self::Error>
|
||||
async fn execute<R>(self, request: R) -> (R::Discriminant, Result<ApiResponse, Self::Error>)
|
||||
where
|
||||
R: torn_api::request::IntoRequest,
|
||||
{
|
||||
let request = request.into_request();
|
||||
let (d, request) = request.into_request();
|
||||
|
||||
self.execute_request(request).await
|
||||
(d, self.pool.execute_request(self.selector, request).await)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'p, S> BulkExecutor<'p> for KeyPoolExecutor<'p, S>
|
||||
where
|
||||
S: KeyPoolStorage + 'static,
|
||||
{
|
||||
type Error = S::Error;
|
||||
|
||||
fn execute<R>(
|
||||
self,
|
||||
requests: impl IntoIterator<Item = R>,
|
||||
) -> impl futures::Stream<Item = (R::Discriminant, Result<ApiResponse, Self::Error>)>
|
||||
where
|
||||
R: torn_api::request::IntoRequest,
|
||||
{
|
||||
self.pool
|
||||
.execute_bulk_requests(
|
||||
self.selector.clone(),
|
||||
requests.into_iter().map(|r| r.into_request()),
|
||||
)
|
||||
.into_stream()
|
||||
.flatten()
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ThrottledKeyPoolExecutor<'p, S>
|
||||
where
|
||||
S: KeyPoolStorage,
|
||||
{
|
||||
pool: &'p KeyPoolInner<S>,
|
||||
selector: KeySelector<S::Key, S::Domain>,
|
||||
distance: Duration,
|
||||
}
|
||||
|
||||
impl<S> Clone for ThrottledKeyPoolExecutor<'_, S>
|
||||
where
|
||||
S: KeyPoolStorage,
|
||||
{
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
pool: self.pool,
|
||||
selector: self.selector.clone(),
|
||||
distance: self.distance,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> ThrottledKeyPoolExecutor<'_, S>
|
||||
where
|
||||
S: KeyPoolStorage,
|
||||
{
|
||||
async fn execute_request(self, request: ApiRequest) -> Result<ApiResponse, S::Error> {
|
||||
self.pool.execute_request(self.selector, request).await
|
||||
}
|
||||
}
|
||||
|
||||
impl<'p, S> ThrottledKeyPoolExecutor<'p, S>
|
||||
where
|
||||
S: KeyPoolStorage,
|
||||
{
|
||||
pub fn new(
|
||||
pool: &'p KeyPool<S>,
|
||||
selector: KeySelector<S::Key, S::Domain>,
|
||||
distance: Duration,
|
||||
) -> Self {
|
||||
Self {
|
||||
pool: &pool.inner,
|
||||
selector,
|
||||
distance,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'p, S> BulkExecutor<'p> for ThrottledKeyPoolExecutor<'p, S>
|
||||
where
|
||||
S: KeyPoolStorage + 'static,
|
||||
{
|
||||
type Error = S::Error;
|
||||
|
||||
fn execute<R>(
|
||||
self,
|
||||
requests: impl IntoIterator<Item = R>,
|
||||
) -> impl futures::Stream<Item = (R::Discriminant, Result<ApiResponse, Self::Error>)>
|
||||
where
|
||||
R: torn_api::request::IntoRequest,
|
||||
{
|
||||
StreamExt::map(
|
||||
futures::stream::iter(requests).throttle(self.distance),
|
||||
move |r| {
|
||||
let this = self.clone();
|
||||
async move {
|
||||
let (d, request) = r.into_request();
|
||||
let result = this.execute_request(request).await;
|
||||
(d, result)
|
||||
}
|
||||
},
|
||||
)
|
||||
.buffer_unordered(25)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[cfg(feature = "postgres")]
|
||||
mod test {
|
||||
use torn_api::executor::ExecutorExt;
|
||||
use torn_api::executor::{BulkExecutorExt, ExecutorExt};
|
||||
|
||||
use crate::postgres;
|
||||
|
||||
|
|
@ -451,4 +683,48 @@ mod test {
|
|||
.await
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
#[sqlx::test]
|
||||
fn bulk(pool: sqlx::PgPool) {
|
||||
let (storage, _) = postgres::test::setup(pool).await;
|
||||
|
||||
let pool = PoolBuilder::new(storage)
|
||||
.use_default_hooks()
|
||||
.comment("test_runner")
|
||||
.build();
|
||||
|
||||
let responses = pool
|
||||
.torn_api(postgres::test::Domain::All)
|
||||
.faction_bulk()
|
||||
.basic_for_id(vec![19.into(), 89.into()], |b| b);
|
||||
let mut responses: Vec<_> = StreamExt::collect(responses).await;
|
||||
|
||||
let (_id1, basic1) = responses.pop().unwrap();
|
||||
basic1.unwrap();
|
||||
|
||||
let (_id2, basic2) = responses.pop().unwrap();
|
||||
basic2.unwrap();
|
||||
}
|
||||
|
||||
#[sqlx::test]
|
||||
fn bulk_trottled(pool: sqlx::PgPool) {
|
||||
let (storage, _) = postgres::test::setup(pool).await;
|
||||
|
||||
let pool = PoolBuilder::new(storage)
|
||||
.use_default_hooks()
|
||||
.comment("test_runner")
|
||||
.build();
|
||||
|
||||
let responses = pool
|
||||
.throttled_torn_api(postgres::test::Domain::All, Duration::from_millis(500))
|
||||
.faction_bulk()
|
||||
.basic_for_id(vec![19.into(), 89.into()], |b| b);
|
||||
let mut responses: Vec<_> = StreamExt::collect(responses).await;
|
||||
|
||||
let (_id1, basic1) = responses.pop().unwrap();
|
||||
basic1.unwrap();
|
||||
|
||||
let (_id2, basic2) = responses.pop().unwrap();
|
||||
basic2.unwrap();
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use futures::future::BoxFuture;
|
||||
use indoc::formatdoc;
|
||||
use sqlx::{FromRow, PgPool, Postgres, QueryBuilder};
|
||||
|
|
@ -37,6 +39,9 @@ where
|
|||
|
||||
#[error("Key not found: '{0:?}'")]
|
||||
KeyNotFound(KeySelector<PgKey<D>, D>),
|
||||
|
||||
#[error("Failed to acquire keys in bulk: {0}")]
|
||||
Bulk(#[from] Arc<Self>),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, FromRow)]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue