feat: implemented bulk requests

This commit is contained in:
TotallyNot 2025-04-29 18:26:00 +02:00
parent 4dd4fd37d4
commit c17f93f600
Signed by: pyrite
GPG key ID: 7F1BA9170CD35D15
10 changed files with 767 additions and 176 deletions

8
Cargo.lock generated
View file

@ -2271,10 +2271,11 @@ dependencies = [
[[package]] [[package]]
name = "torn-api" name = "torn-api"
version = "1.0.3" version = "1.1.0"
dependencies = [ dependencies = [
"bon", "bon",
"bytes", "bytes",
"futures",
"http", "http",
"prettyplease", "prettyplease",
"proc-macro2", "proc-macro2",
@ -2290,7 +2291,7 @@ dependencies = [
[[package]] [[package]]
name = "torn-api-codegen" name = "torn-api-codegen"
version = "0.1.5" version = "0.2.0"
dependencies = [ dependencies = [
"heck", "heck",
"indexmap", "indexmap",
@ -2303,7 +2304,7 @@ dependencies = [
[[package]] [[package]]
name = "torn-key-pool" name = "torn-key-pool"
version = "1.0.1" version = "1.1.0"
dependencies = [ dependencies = [
"chrono", "chrono",
"futures", "futures",
@ -2315,6 +2316,7 @@ dependencies = [
"sqlx", "sqlx",
"thiserror", "thiserror",
"tokio", "tokio",
"tokio-stream",
"torn-api", "torn-api",
] ]

View file

@ -1,7 +1,7 @@
[package] [package]
name = "torn-api-codegen" name = "torn-api-codegen"
authors = ["Pyrit [2111649]"] authors = ["Pyrit [2111649]"]
version = "0.1.5" version = "0.2.0"
edition = "2021" edition = "2021"
description = "Contains the v2 torn API model descriptions and codegen for the bindings" description = "Contains the v2 torn API model descriptions and codegen for the bindings"
license-file = { workspace = true } license-file = { workspace = true }

View file

@ -284,15 +284,18 @@ impl Path {
#[allow(unused_parens)] #[allow(unused_parens)]
type Discriminant = (#(#discriminant),*); type Discriminant = (#(#discriminant),*);
type Response = #response_ty; type Response = #response_ty;
fn into_request(self) -> crate::request::ApiRequest<Self::Discriminant> { fn into_request(self) -> (Self::Discriminant, crate::request::ApiRequest) {
let path = format!(#path_fmt_str, #(#fmt_val),*);
#[allow(unused_parens)] #[allow(unused_parens)]
(
(#(#discriminant_val),*),
crate::request::ApiRequest { crate::request::ApiRequest {
path: format!(#path_fmt_str, #(#fmt_val),*), path,
parameters: std::iter::empty() parameters: std::iter::empty()
#(#convert_field)* #(#convert_field)*
.collect(), .collect(),
disriminant: (#(#discriminant_val),*),
} }
)
} }
} }
}) })
@ -376,7 +379,7 @@ impl Path {
Some(quote! { Some(quote! {
pub async fn #fn_name<S>( pub async fn #fn_name<S>(
&self, self,
#(#extra_args)* #(#extra_args)*
builder: impl FnOnce( builder: impl FnOnce(
#builder_path<#builder_mod_path::Empty> #builder_path<#builder_mod_path::Empty>
@ -391,6 +394,120 @@ impl Path {
} }
}) })
} }
pub fn codegen_bulk_scope_call(&self) -> Option<TokenStream> {
let mut disc = Vec::new();
let mut disc_ty = Vec::new();
let snake_name = self.name.to_snake_case();
let request_name = format_ident!("{}Request", self.name);
let builder_name = format_ident!("{}RequestBuilder", self.name);
let builder_mod_name = format_ident!("{}_request_builder", snake_name);
let request_mod_name = format_ident!("{snake_name}");
let request_path = quote! { crate::request::models::#request_name };
let builder_path = quote! { crate::request::models::#builder_name };
let builder_mod_path = quote! { crate::request::models::#builder_mod_name };
let tail = snake_name
.split_once('_')
.map_or_else(|| "for_selections".to_owned(), |(_, tail)| tail.to_owned());
let fn_name = format_ident!("{tail}");
for param in &self.parameters {
let (param, is_inline) = match param {
PathParameter::Inline(param) => (param, true),
PathParameter::Component(param) => (param, false),
};
if param.location == ParameterLocation::Path {
let ty = match &param.r#type {
ParameterType::I32 { .. } | ParameterType::Enum { .. } => {
let ty_name = format_ident!("{}", param.name);
if is_inline {
quote! {
crate::request::models::#request_mod_name::#ty_name
}
} else {
quote! {
crate::parameters::#ty_name
}
}
}
ParameterType::String => quote! { String },
ParameterType::Boolean => quote! { bool },
ParameterType::Schema { type_name } => {
let ty_name = format_ident!("{}", type_name);
quote! {
crate::models::#ty_name
}
}
ParameterType::Array { .. } => param.r#type.codegen_type_name(&param.name),
};
let arg_name = format_ident!("{}", param.value.to_snake_case());
disc_ty.push(ty);
disc.push(arg_name);
}
}
if disc.is_empty() {
return None;
}
let response_ty = match &self.response {
PathResponse::Component { name } => {
let name = format_ident!("{name}");
quote! {
crate::models::#name
}
}
PathResponse::ArbitraryUnion(union) => {
let name = format_ident!("{}", union.name);
quote! {
crate::request::models::#request_mod_name::#name
}
}
};
let disc = if disc.len() > 1 {
quote! { (#(#disc),*) }
} else {
quote! { #(#disc),* }
};
let disc_ty = if disc_ty.len() > 1 {
quote! { (#(#disc_ty),*) }
} else {
quote! { #(#disc_ty),* }
};
Some(quote! {
pub fn #fn_name<S, I, B>(
self,
ids: I,
builder: B
) -> impl futures::Stream<Item = (#disc_ty, Result<#response_ty, E::Error>)> + use<'e, E, S, I, B>
where
I: IntoIterator<Item = #disc_ty>,
S: #builder_mod_path::IsComplete,
B: Fn(
#builder_path<#builder_mod_path::Empty>
) -> #builder_path<S>,
{
let requests = ids.into_iter()
.map(move |#disc| builder(#request_path::builder(#disc)).build());
let executor = self.executor;
executor.fetch_many(requests)
}
})
}
} }
pub struct PathNamespace<'r> { pub struct PathNamespace<'r> {

View file

@ -35,30 +35,56 @@ impl Scope {
pub fn codegen(&self) -> Option<TokenStream> { pub fn codegen(&self) -> Option<TokenStream> {
let name = format_ident!("{}", self.name); let name = format_ident!("{}", self.name);
let bulk_name = format_ident!("Bulk{}", self.name);
let mut functions = Vec::with_capacity(self.members.len()); let mut functions = Vec::with_capacity(self.members.len());
let mut bulk_functions = Vec::with_capacity(self.members.len());
for member in &self.members { for member in &self.members {
if let Some(code) = member.codegen_scope_call() { if let Some(code) = member.codegen_scope_call() {
functions.push(code); functions.push(code);
} }
if let Some(code) = member.codegen_bulk_scope_call() {
bulk_functions.push(code);
}
} }
Some(quote! { Some(quote! {
pub struct #name<'e, E>(&'e E) pub struct #name<E>(E)
where where
E: crate::executor::Executor; E: crate::executor::Executor;
impl<'e, E> #name<'e, E> impl<E> #name<E>
where where
E: crate::executor::Executor E: crate::executor::Executor
{ {
pub fn new(executor: &'e E) -> Self { pub fn new(executor: E) -> Self {
Self(executor) Self(executor)
} }
#(#functions)* #(#functions)*
} }
pub struct #bulk_name<'e, E> where
E: crate::executor::BulkExecutor<'e>,
{
executor: E,
marker: std::marker::PhantomData<&'e E>,
}
impl<'e, E> #bulk_name<'e, E>
where
E: crate::executor::BulkExecutor<'e>
{
pub fn new(executor: E) -> Self {
Self {
executor,
marker: std::marker::PhantomData,
}
}
#(#bulk_functions)*
}
}) })
} }
} }

View file

@ -1,6 +1,6 @@
[package] [package]
name = "torn-api" name = "torn-api"
version = "1.0.3" version = "1.1.0"
edition = "2021" edition = "2021"
description = "Auto-generated bindings for the v2 torn api" description = "Auto-generated bindings for the v2 torn api"
license-file = { workspace = true } license-file = { workspace = true }
@ -27,12 +27,16 @@ reqwest = { version = "0.12", default-features = false, features = [
"brotli", "brotli",
] } ] }
thiserror = "2" thiserror = "2"
futures = { version = "0.3", default-features = false, features = [
"std",
"async-await",
] }
[dev-dependencies] [dev-dependencies]
tokio = { version = "1", features = ["full"] } tokio = { version = "1", features = ["full"] }
[build-dependencies] [build-dependencies]
torn-api-codegen = { path = "../torn-api-codegen", version = "0.1.5" } torn-api-codegen = { path = "../torn-api-codegen", version = "0.2" }
syn = { workspace = true, features = ["parsing"] } syn = { workspace = true, features = ["parsing"] }
proc-macro2 = { workspace = true } proc-macro2 = { workspace = true }
prettyplease = "0.2" prettyplease = "0.2"

View file

@ -1,23 +1,27 @@
use std::future::Future; use std::future::Future;
use futures::{Stream, StreamExt};
use http::{header::AUTHORIZATION, HeaderMap, HeaderValue}; use http::{header::AUTHORIZATION, HeaderMap, HeaderValue};
use serde::Deserialize; use serde::Deserialize;
use crate::request::{ApiResponse, IntoRequest}; use crate::request::{ApiRequest, ApiResponse, IntoRequest};
#[cfg(feature = "scopes")] #[cfg(feature = "scopes")]
use crate::scopes::{FactionScope, ForumScope, MarketScope, RacingScope, TornScope, UserScope}; use crate::scopes::{
BulkFactionScope, BulkForumScope, BulkMarketScope, BulkRacingScope, BulkTornScope,
BulkUserScope, FactionScope, ForumScope, MarketScope, RacingScope, TornScope, UserScope,
};
pub trait Executor { pub trait Executor: Sized {
type Error: From<serde_json::Error> + From<crate::ApiError> + Send; type Error: From<serde_json::Error> + From<crate::ApiError> + Send;
fn execute<R>( fn execute<R>(
&self, self,
request: R, request: R,
) -> impl Future<Output = Result<ApiResponse<R::Discriminant>, Self::Error>> + Send ) -> impl Future<Output = (R::Discriminant, Result<ApiResponse, Self::Error>)> + Send
where where
R: IntoRequest; R: IntoRequest;
fn fetch<R>(&self, request: R) -> impl Future<Output = Result<R::Response, Self::Error>> + Send fn fetch<R>(self, request: R) -> impl Future<Output = Result<R::Response, Self::Error>> + Send
where where
R: IntoRequest, R: IntoRequest,
{ {
@ -25,7 +29,7 @@ pub trait Executor {
// The future is `Send` but `&self` might not be. // The future is `Send` but `&self` might not be.
let fut = self.execute(request); let fut = self.execute(request);
async { async {
let resp = fut.await?; let resp = fut.await.1?;
let bytes = resp.body.unwrap(); let bytes = resp.body.unwrap();
@ -52,6 +56,152 @@ pub trait Executor {
} }
} }
pub trait BulkExecutor<'e>: 'e + Sized {
type Error: From<serde_json::Error> + From<crate::ApiError> + Send;
fn execute<R>(
self,
requests: impl IntoIterator<Item = R>,
) -> impl Stream<Item = (R::Discriminant, Result<ApiResponse, Self::Error>)>
where
R: IntoRequest;
fn fetch_many<R>(
self,
requests: impl IntoIterator<Item = R>,
) -> impl Stream<Item = (R::Discriminant, Result<R::Response, Self::Error>)>
where
R: IntoRequest,
{
self.execute(requests).map(|(d, r)| {
let r = match r {
Ok(r) => r,
Err(why) => return (d, Err(why)),
};
let bytes = r.body.unwrap();
if bytes.starts_with(br#"{"error":{"#) {
#[derive(Deserialize)]
struct ErrorBody<'a> {
code: u16,
error: &'a str,
}
#[derive(Deserialize)]
struct ErrorContainer<'a> {
#[serde(borrow)]
error: ErrorBody<'a>,
}
let error: ErrorContainer = match serde_json::from_slice(&bytes) {
Ok(error) => error,
Err(why) => return (d, Err(why.into())),
};
return (
d,
Err(crate::ApiError::new(error.error.code, error.error.error).into()),
);
}
let resp = match serde_json::from_slice(&bytes) {
Ok(resp) => resp,
Err(why) => return (d, Err(why.into())),
};
(d, Ok(resp))
})
}
}
#[cfg(feature = "scopes")]
pub trait ExecutorExt: Executor + Sized {
fn user(self) -> UserScope<Self>;
fn faction(self) -> FactionScope<Self>;
fn torn(self) -> TornScope<Self>;
fn market(self) -> MarketScope<Self>;
fn racing(self) -> RacingScope<Self>;
fn forum(self) -> ForumScope<Self>;
}
#[cfg(feature = "scopes")]
impl<T> ExecutorExt for T
where
T: Executor + Sized,
{
fn user(self) -> UserScope<Self> {
UserScope::new(self)
}
fn faction(self) -> FactionScope<Self> {
FactionScope::new(self)
}
fn torn(self) -> TornScope<Self> {
TornScope::new(self)
}
fn market(self) -> MarketScope<Self> {
MarketScope::new(self)
}
fn racing(self) -> RacingScope<Self> {
RacingScope::new(self)
}
fn forum(self) -> ForumScope<Self> {
ForumScope::new(self)
}
}
#[cfg(feature = "scopes")]
pub trait BulkExecutorExt<'e>: BulkExecutor<'e> + Sized {
fn user_bulk(self) -> BulkUserScope<'e, Self>;
fn faction_bulk(self) -> BulkFactionScope<'e, Self>;
fn torn_bulk(self) -> BulkTornScope<'e, Self>;
fn market_bulk(self) -> BulkMarketScope<'e, Self>;
fn racing_bulk(self) -> BulkRacingScope<'e, Self>;
fn forum_bulk(self) -> BulkForumScope<'e, Self>;
}
#[cfg(feature = "scopes")]
impl<'e, T> BulkExecutorExt<'e> for T
where
T: BulkExecutor<'e> + Sized,
{
fn user_bulk(self) -> BulkUserScope<'e, Self> {
BulkUserScope::new(self)
}
fn faction_bulk(self) -> BulkFactionScope<'e, Self> {
BulkFactionScope::new(self)
}
fn torn_bulk(self) -> BulkTornScope<'e, Self> {
BulkTornScope::new(self)
}
fn market_bulk(self) -> BulkMarketScope<'e, Self> {
BulkMarketScope::new(self)
}
fn racing_bulk(self) -> BulkRacingScope<'e, Self> {
BulkRacingScope::new(self)
}
fn forum_bulk(self) -> BulkForumScope<'e, Self> {
BulkForumScope::new(self)
}
}
pub struct ReqwestClient(reqwest::Client); pub struct ReqwestClient(reqwest::Client);
impl ReqwestClient { impl ReqwestClient {
@ -72,70 +222,43 @@ impl ReqwestClient {
} }
} }
#[cfg(feature = "scopes")] impl ReqwestClient {
pub trait ExecutorExt: Executor + Sized { async fn execute_api_request(&self, request: ApiRequest) -> Result<ApiResponse, crate::Error> {
fn user(&self) -> UserScope<'_, Self>;
fn faction(&self) -> FactionScope<'_, Self>;
fn torn(&self) -> TornScope<'_, Self>;
fn market(&self) -> MarketScope<'_, Self>;
fn racing(&self) -> RacingScope<'_, Self>;
fn forum(&self) -> ForumScope<'_, Self>;
}
#[cfg(feature = "scopes")]
impl<T> ExecutorExt for T
where
T: Executor + Sized,
{
fn user(&self) -> UserScope<'_, Self> {
UserScope::new(self)
}
fn faction(&self) -> FactionScope<'_, Self> {
FactionScope::new(self)
}
fn torn(&self) -> TornScope<'_, Self> {
TornScope::new(self)
}
fn market(&self) -> MarketScope<'_, Self> {
MarketScope::new(self)
}
fn racing(&self) -> RacingScope<'_, Self> {
RacingScope::new(self)
}
fn forum(&self) -> ForumScope<'_, Self> {
ForumScope::new(self)
}
}
impl Executor for ReqwestClient {
type Error = crate::Error;
async fn execute<R>(&self, request: R) -> Result<ApiResponse<R::Discriminant>, Self::Error>
where
R: IntoRequest,
{
let request = request.into_request();
let url = request.url(); let url = request.url();
let response = self.0.get(url).send().await?; let response = self.0.get(url).send().await?;
let status = response.status(); let status = response.status();
let body = response.bytes().await.ok(); let body = response.bytes().await.ok();
Ok(ApiResponse { Ok(ApiResponse { status, body })
discriminant: request.disriminant, }
status, }
body,
}) impl Executor for &ReqwestClient {
type Error = crate::Error;
async fn execute<R>(self, request: R) -> (R::Discriminant, Result<ApiResponse, Self::Error>)
where
R: IntoRequest,
{
let (d, request) = request.into_request();
(d, self.execute_api_request(request).await)
}
}
impl<'e> BulkExecutor<'e> for &'e ReqwestClient {
type Error = crate::Error;
fn execute<R>(
self,
requests: impl IntoIterator<Item = R>,
) -> impl Stream<Item = (R::Discriminant, Result<ApiResponse, Self::Error>)>
where
R: IntoRequest,
{
futures::stream::iter(requests)
.map(move |r| <Self as Executor>::execute(self, r))
.buffer_unordered(25)
} }
} }
@ -157,4 +280,22 @@ mod test {
other => panic!("Expected incorrect id entity relation error, got {other:?}"), other => panic!("Expected incorrect id entity relation error, got {other:?}"),
} }
} }
#[cfg(feature = "scopes")]
#[tokio::test]
async fn bulk_request() {
let client = test_client().await;
let stream = client
.faction_bulk()
.basic_for_id(vec![19.into(), 89.into()], |b| b);
let mut responses: Vec<_> = stream.collect().await;
let (_id1, basic1) = responses.pop().unwrap();
basic1.unwrap();
let (_id2, basic2) = responses.pop().unwrap();
basic2.unwrap();
}
} }

View file

@ -5,13 +5,12 @@ use http::StatusCode;
pub mod models; pub mod models;
#[derive(Default)] #[derive(Default)]
pub struct ApiRequest<D = ()> { pub struct ApiRequest {
pub disriminant: D,
pub path: String, pub path: String,
pub parameters: Vec<(&'static str, String)>, pub parameters: Vec<(&'static str, String)>,
} }
impl<D> ApiRequest<D> { impl ApiRequest {
pub fn url(&self) -> String { pub fn url(&self) -> String {
let mut url = format!("https://api.torn.com/v2{}?", self.path); let mut url = format!("https://api.torn.com/v2{}?", self.path);
@ -23,8 +22,7 @@ impl<D> ApiRequest<D> {
} }
} }
pub struct ApiResponse<D = ()> { pub struct ApiResponse {
pub discriminant: D,
pub body: Option<Bytes>, pub body: Option<Bytes>,
pub status: StatusCode, pub status: StatusCode,
} }
@ -32,7 +30,26 @@ pub struct ApiResponse<D = ()> {
pub trait IntoRequest: Send { pub trait IntoRequest: Send {
type Discriminant: Send; type Discriminant: Send;
type Response: for<'de> serde::Deserialize<'de> + Send; type Response: for<'de> serde::Deserialize<'de> + Send;
fn into_request(self) -> ApiRequest<Self::Discriminant>; fn into_request(self) -> (Self::Discriminant, ApiRequest);
}
pub(crate) struct WrappedApiRequest<R>
where
R: IntoRequest,
{
discriminant: R::Discriminant,
request: ApiRequest,
}
impl<R> IntoRequest for WrappedApiRequest<R>
where
R: IntoRequest,
{
type Discriminant = R::Discriminant;
type Response = R::Response;
fn into_request(self) -> (Self::Discriminant, ApiRequest) {
(self.discriminant, self.request)
}
} }
#[cfg(test)] #[cfg(test)]

View file

@ -1,6 +1,6 @@
[package] [package]
name = "torn-key-pool" name = "torn-key-pool"
version = "1.0.1" version = "1.1.0"
edition = "2021" edition = "2021"
authors = ["Pyrit [2111649]"] authors = ["Pyrit [2111649]"]
license-file = { workspace = true } license-file = { workspace = true }
@ -11,7 +11,7 @@ description = "A generalised API key pool for torn-api"
[features] [features]
default = ["postgres", "tokio-runtime"] default = ["postgres", "tokio-runtime"]
postgres = ["dep:sqlx", "dep:chrono", "dep:indoc"] postgres = ["dep:sqlx", "dep:chrono", "dep:indoc"]
tokio-runtime = ["dep:tokio", "dep:rand"] tokio-runtime = ["dep:tokio", "dep:rand", "dep:tokio-stream"]
[dependencies] [dependencies]
torn-api = { path = "../torn-api", default-features = false, version = "1.0.1" } torn-api = { path = "../torn-api", default-features = false, version = "1.0.1" }
@ -30,6 +30,9 @@ indoc = { version = "2", optional = true }
tokio = { version = "1", optional = true, default-features = false, features = [ tokio = { version = "1", optional = true, default-features = false, features = [
"time", "time",
] } ] }
tokio-stream = { version = "0.1", optional = true, default-features = false, features = [
"time",
] }
rand = { version = "0.9", optional = true } rand = { version = "0.9", optional = true }
futures = "0.3" futures = "0.3"
reqwest = { version = "0.12", default-features = false, features = [ reqwest = { version = "0.12", default-features = false, features = [

View file

@ -5,11 +5,12 @@ pub mod postgres;
use std::{collections::HashMap, future::Future, sync::Arc, time::Duration}; use std::{collections::HashMap, future::Future, sync::Arc, time::Duration};
use futures::{future::BoxFuture, FutureExt}; use futures::{future::BoxFuture, FutureExt, Stream, StreamExt};
use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION}; use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION};
use serde::Deserialize; use serde::Deserialize;
use tokio_stream::StreamExt as TokioStreamExt;
use torn_api::{ use torn_api::{
executor::Executor, executor::{BulkExecutor, Executor},
request::{ApiRequest, ApiResponse}, request::{ApiRequest, ApiResponse},
ApiError, ApiError,
}; };
@ -80,6 +81,46 @@ where
} }
} }
impl<K, D> From<&str> for KeySelector<K, D>
where
K: ApiKey,
D: KeyDomain,
{
fn from(value: &str) -> Self {
Self::Key(value.to_owned())
}
}
impl<K, D> From<D> for KeySelector<K, D>
where
K: ApiKey,
D: KeyDomain,
{
fn from(value: D) -> Self {
Self::Has(vec![value])
}
}
impl<K, D> From<&[D]> for KeySelector<K, D>
where
K: ApiKey,
D: KeyDomain,
{
fn from(value: &[D]) -> Self {
Self::Has(value.to_vec())
}
}
impl<K, D> From<Vec<D>> for KeySelector<K, D>
where
K: ApiKey,
D: KeyDomain,
{
fn from(value: Vec<D>) -> Self {
Self::Has(value)
}
}
pub trait IntoSelector<K, D>: Send pub trait IntoSelector<K, D>: Send
where where
K: ApiKey, K: ApiKey,
@ -88,30 +129,35 @@ where
fn into_selector(self) -> KeySelector<K, D>; fn into_selector(self) -> KeySelector<K, D>;
} }
impl<K, D> IntoSelector<K, D> for D impl<K, D, T> IntoSelector<K, D> for T
where where
K: ApiKey, K: ApiKey,
D: KeyDomain, D: KeyDomain,
T: Into<KeySelector<K, D>> + Send,
{ {
fn into_selector(self) -> KeySelector<K, D> { fn into_selector(self) -> KeySelector<K, D> {
KeySelector::Has(vec![self]) self.into()
} }
} }
impl<K, D> IntoSelector<K, D> for KeySelector<K, D> pub trait KeyPoolError:
where From<reqwest::Error> + From<serde_json::Error> + From<torn_api::ApiError> + From<Arc<Self>> + Send
K: ApiKey,
D: KeyDomain,
{ {
fn into_selector(self) -> KeySelector<K, D> {
self
} }
impl<T> KeyPoolError for T where
T: From<reqwest::Error>
+ From<serde_json::Error>
+ From<torn_api::ApiError>
+ From<Arc<Self>>
+ Send
{
} }
pub trait KeyPoolStorage: Send + Sync { pub trait KeyPoolStorage: Send + Sync {
type Key: ApiKey; type Key: ApiKey;
type Domain: KeyDomain; type Domain: KeyDomain;
type Error: From<reqwest::Error> + From<serde_json::Error> + From<torn_api::ApiError> + Send; type Error: KeyPoolError;
fn acquire_key<S>( fn acquire_key<S>(
&self, &self,
@ -206,65 +252,6 @@ where
>, >,
} }
pub struct KeyPoolExecutor<'p, S>
where
S: KeyPoolStorage,
{
pool: &'p KeyPool<S>,
selector: KeySelector<S::Key, S::Domain>,
}
impl<'p, S> KeyPoolExecutor<'p, S>
where
S: KeyPoolStorage,
{
pub fn new(pool: &'p KeyPool<S>, selector: KeySelector<S::Key, S::Domain>) -> Self {
Self { pool, selector }
}
async fn execute_request<D>(&self, request: ApiRequest<D>) -> Result<ApiResponse<D>, S::Error>
where
D: Send,
{
let key = self.pool.storage.acquire_key(self.selector.clone()).await?;
let mut headers = HeaderMap::with_capacity(1);
headers.insert(
AUTHORIZATION,
HeaderValue::from_str(&format!("ApiKey {}", key.value())).unwrap(),
);
let resp = self
.pool
.client
.get(request.url())
.headers(headers)
.send()
.await?;
let status = resp.status();
let bytes = resp.bytes().await?;
if let Some(err) = decode_error(&bytes)? {
if let Some(handler) = self.pool.options.error_hooks.get(&err.code()) {
let retry = (*handler)(&self.pool.storage, &key).await?;
if retry {
return Box::pin(self.execute_request(request)).await;
}
}
Err(err.into())
} else {
Ok(ApiResponse {
discriminant: request.disriminant,
body: Some(bytes),
status,
})
}
}
}
pub struct PoolBuilder<S> pub struct PoolBuilder<S>
where where
S: KeyPoolStorage, S: KeyPoolStorage,
@ -358,20 +345,137 @@ where
pub fn build(self) -> KeyPool<S> { pub fn build(self) -> KeyPool<S> {
KeyPool { KeyPool {
inner: Arc::new(KeyPoolInner {
client: self.client, client: self.client,
storage: self.storage, storage: self.storage,
options: Arc::new(self.options), options: self.options,
}),
} }
} }
} }
struct KeyPoolInner<S>
where
S: KeyPoolStorage,
{
pub client: reqwest::Client,
pub storage: S,
pub options: PoolOptions<S>,
}
impl<S> KeyPoolInner<S>
where
S: KeyPoolStorage,
{
async fn execute_with_key(
&self,
key: &S::Key,
request: &ApiRequest,
) -> Result<RequestResult, S::Error> {
let mut headers = HeaderMap::with_capacity(1);
headers.insert(
AUTHORIZATION,
HeaderValue::from_str(&format!("ApiKey {}", key.value())).unwrap(),
);
let resp = self
.client
.get(request.url())
.headers(headers)
.send()
.await?;
let status = resp.status();
let bytes = resp.bytes().await?;
if let Some(err) = decode_error(&bytes)? {
if let Some(handler) = self.options.error_hooks.get(&err.code()) {
let retry = (*handler)(&self.storage, key).await?;
if retry {
return Ok(RequestResult::Retry);
}
}
Err(err.into())
} else {
Ok(RequestResult::Response(ApiResponse {
body: Some(bytes),
status,
}))
}
}
async fn execute_request(
&self,
selector: KeySelector<S::Key, S::Domain>,
request: ApiRequest,
) -> Result<ApiResponse, S::Error> {
loop {
let key = self.storage.acquire_key(selector.clone()).await?;
match self.execute_with_key(&key, &request).await {
Ok(RequestResult::Response(resp)) => return Ok(resp),
Ok(RequestResult::Retry) => (),
Err(why) => return Err(why),
}
}
}
async fn execute_bulk_requests<D, T: IntoIterator<Item = (D, ApiRequest)>>(
&self,
selector: KeySelector<S::Key, S::Domain>,
requests: T,
) -> impl Stream<Item = (D, Result<ApiResponse, S::Error>)> + use<'_, D, S, T> {
let requests: Vec<_> = requests.into_iter().collect();
let keys: Vec<_> = match self
.storage
.acquire_many_keys(selector.clone(), requests.len() as i64)
.await
{
Ok(keys) => keys.into_iter().map(Ok).collect(),
Err(why) => {
let why = Arc::new(why);
std::iter::repeat_n(why, requests.len())
.map(|e| Err(S::Error::from(e)))
.collect()
}
};
StreamExt::map(
futures::stream::iter(std::iter::zip(requests, keys)),
move |((discriminant, request), mut maybe_key)| {
let selector = selector.clone();
async move {
loop {
let key = match maybe_key {
Ok(key) => key,
Err(why) => return (discriminant, Err(why)),
};
match self.execute_with_key(&key, &request).await {
Ok(RequestResult::Response(resp)) => return (discriminant, Ok(resp)),
Ok(RequestResult::Retry) => (),
Err(why) => return (discriminant, Err(why)),
}
maybe_key = self.storage.acquire_key(selector.clone()).await;
}
}
},
)
.buffer_unordered(25)
}
}
pub struct KeyPool<S> pub struct KeyPool<S>
where where
S: KeyPoolStorage, S: KeyPoolStorage,
{ {
pub client: reqwest::Client, inner: Arc<KeyPoolInner<S>>,
pub storage: S, }
pub options: Arc<PoolOptions<S>>,
enum RequestResult {
Response(ApiResponse),
Retry,
} }
impl<S> KeyPool<S> impl<S> KeyPool<S>
@ -384,6 +488,17 @@ where
{ {
KeyPoolExecutor::new(self, selector.into_selector()) KeyPoolExecutor::new(self, selector.into_selector())
} }
pub fn throttled_torn_api<I>(
&self,
selector: I,
distance: Duration,
) -> ThrottledKeyPoolExecutor<S>
where
I: IntoSelector<S::Key, S::Domain>,
{
ThrottledKeyPoolExecutor::new(self, selector.into_selector(), distance)
}
} }
fn decode_error(buf: &[u8]) -> Result<Option<ApiError>, serde_json::Error> { fn decode_error(buf: &[u8]) -> Result<Option<ApiError>, serde_json::Error> {
@ -409,28 +524,145 @@ fn decode_error(buf: &[u8]) -> Result<Option<ApiError>, serde_json::Error> {
} }
} }
impl<S> Executor for KeyPoolExecutor<'_, S> pub struct KeyPoolExecutor<'p, S>
where where
S: KeyPoolStorage, S: KeyPoolStorage,
{
pool: &'p KeyPoolInner<S>,
selector: KeySelector<S::Key, S::Domain>,
}
impl<'p, S> KeyPoolExecutor<'p, S>
where
S: KeyPoolStorage,
{
pub fn new(pool: &'p KeyPool<S>, selector: KeySelector<S::Key, S::Domain>) -> Self {
Self {
pool: &pool.inner,
selector,
}
}
}
impl<S> Executor for KeyPoolExecutor<'_, S>
where
S: KeyPoolStorage + 'static,
{ {
type Error = S::Error; type Error = S::Error;
async fn execute<R>( async fn execute<R>(self, request: R) -> (R::Discriminant, Result<ApiResponse, Self::Error>)
&self,
request: R,
) -> Result<torn_api::request::ApiResponse<R::Discriminant>, Self::Error>
where where
R: torn_api::request::IntoRequest, R: torn_api::request::IntoRequest,
{ {
let request = request.into_request(); let (d, request) = request.into_request();
self.execute_request(request).await (d, self.pool.execute_request(self.selector, request).await)
}
}
impl<'p, S> BulkExecutor<'p> for KeyPoolExecutor<'p, S>
where
S: KeyPoolStorage + 'static,
{
type Error = S::Error;
fn execute<R>(
self,
requests: impl IntoIterator<Item = R>,
) -> impl futures::Stream<Item = (R::Discriminant, Result<ApiResponse, Self::Error>)>
where
R: torn_api::request::IntoRequest,
{
self.pool
.execute_bulk_requests(
self.selector.clone(),
requests.into_iter().map(|r| r.into_request()),
)
.into_stream()
.flatten()
}
}
pub struct ThrottledKeyPoolExecutor<'p, S>
where
S: KeyPoolStorage,
{
pool: &'p KeyPoolInner<S>,
selector: KeySelector<S::Key, S::Domain>,
distance: Duration,
}
impl<S> Clone for ThrottledKeyPoolExecutor<'_, S>
where
S: KeyPoolStorage,
{
fn clone(&self) -> Self {
Self {
pool: self.pool,
selector: self.selector.clone(),
distance: self.distance,
}
}
}
impl<S> ThrottledKeyPoolExecutor<'_, S>
where
S: KeyPoolStorage,
{
async fn execute_request(self, request: ApiRequest) -> Result<ApiResponse, S::Error> {
self.pool.execute_request(self.selector, request).await
}
}
impl<'p, S> ThrottledKeyPoolExecutor<'p, S>
where
S: KeyPoolStorage,
{
pub fn new(
pool: &'p KeyPool<S>,
selector: KeySelector<S::Key, S::Domain>,
distance: Duration,
) -> Self {
Self {
pool: &pool.inner,
selector,
distance,
}
}
}
impl<'p, S> BulkExecutor<'p> for ThrottledKeyPoolExecutor<'p, S>
where
S: KeyPoolStorage + 'static,
{
type Error = S::Error;
fn execute<R>(
self,
requests: impl IntoIterator<Item = R>,
) -> impl futures::Stream<Item = (R::Discriminant, Result<ApiResponse, Self::Error>)>
where
R: torn_api::request::IntoRequest,
{
StreamExt::map(
futures::stream::iter(requests).throttle(self.distance),
move |r| {
let this = self.clone();
async move {
let (d, request) = r.into_request();
let result = this.execute_request(request).await;
(d, result)
}
},
)
.buffer_unordered(25)
} }
} }
#[cfg(test)] #[cfg(test)]
#[cfg(feature = "postgres")]
mod test { mod test {
use torn_api::executor::ExecutorExt; use torn_api::executor::{BulkExecutorExt, ExecutorExt};
use crate::postgres; use crate::postgres;
@ -451,4 +683,48 @@ mod test {
.await .await
.unwrap(); .unwrap();
} }
#[sqlx::test]
fn bulk(pool: sqlx::PgPool) {
let (storage, _) = postgres::test::setup(pool).await;
let pool = PoolBuilder::new(storage)
.use_default_hooks()
.comment("test_runner")
.build();
let responses = pool
.torn_api(postgres::test::Domain::All)
.faction_bulk()
.basic_for_id(vec![19.into(), 89.into()], |b| b);
let mut responses: Vec<_> = StreamExt::collect(responses).await;
let (_id1, basic1) = responses.pop().unwrap();
basic1.unwrap();
let (_id2, basic2) = responses.pop().unwrap();
basic2.unwrap();
}
#[sqlx::test]
fn bulk_trottled(pool: sqlx::PgPool) {
let (storage, _) = postgres::test::setup(pool).await;
let pool = PoolBuilder::new(storage)
.use_default_hooks()
.comment("test_runner")
.build();
let responses = pool
.throttled_torn_api(postgres::test::Domain::All, Duration::from_millis(500))
.faction_bulk()
.basic_for_id(vec![19.into(), 89.into()], |b| b);
let mut responses: Vec<_> = StreamExt::collect(responses).await;
let (_id1, basic1) = responses.pop().unwrap();
basic1.unwrap();
let (_id2, basic2) = responses.pop().unwrap();
basic2.unwrap();
}
} }

View file

@ -1,3 +1,5 @@
use std::sync::Arc;
use futures::future::BoxFuture; use futures::future::BoxFuture;
use indoc::formatdoc; use indoc::formatdoc;
use sqlx::{FromRow, PgPool, Postgres, QueryBuilder}; use sqlx::{FromRow, PgPool, Postgres, QueryBuilder};
@ -37,6 +39,9 @@ where
#[error("Key not found: '{0:?}'")] #[error("Key not found: '{0:?}'")]
KeyNotFound(KeySelector<PgKey<D>, D>), KeyNotFound(KeySelector<PgKey<D>, D>),
#[error("Failed to acquire keys in bulk: {0}")]
Bulk(#[from] Arc<Self>),
} }
#[derive(Debug, Clone, FromRow)] #[derive(Debug, Clone, FromRow)]