From c17f93f60046c681dac62bb85a27319c17e3acff Mon Sep 17 00:00:00 2001 From: TotallyNot <44345987+TotallyNot@users.noreply.github.com> Date: Tue, 29 Apr 2025 18:26:00 +0200 Subject: [PATCH] feat: implemented bulk requests --- Cargo.lock | 8 +- torn-api-codegen/Cargo.toml | 2 +- torn-api-codegen/src/model/path.rs | 135 ++++++++- torn-api-codegen/src/model/scope.rs | 32 +- torn-api/Cargo.toml | 8 +- torn-api/src/executor.rs | 271 +++++++++++++---- torn-api/src/request/mod.rs | 29 +- torn-key-pool/Cargo.toml | 7 +- torn-key-pool/src/lib.rs | 446 ++++++++++++++++++++++------ torn-key-pool/src/postgres.rs | 5 + 10 files changed, 767 insertions(+), 176 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 1bd2c78..b8edac2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2271,10 +2271,11 @@ dependencies = [ [[package]] name = "torn-api" -version = "1.0.3" +version = "1.1.0" dependencies = [ "bon", "bytes", + "futures", "http", "prettyplease", "proc-macro2", @@ -2290,7 +2291,7 @@ dependencies = [ [[package]] name = "torn-api-codegen" -version = "0.1.5" +version = "0.2.0" dependencies = [ "heck", "indexmap", @@ -2303,7 +2304,7 @@ dependencies = [ [[package]] name = "torn-key-pool" -version = "1.0.1" +version = "1.1.0" dependencies = [ "chrono", "futures", @@ -2315,6 +2316,7 @@ dependencies = [ "sqlx", "thiserror", "tokio", + "tokio-stream", "torn-api", ] diff --git a/torn-api-codegen/Cargo.toml b/torn-api-codegen/Cargo.toml index a57803e..aca08b2 100644 --- a/torn-api-codegen/Cargo.toml +++ b/torn-api-codegen/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "torn-api-codegen" authors = ["Pyrit [2111649]"] -version = "0.1.5" +version = "0.2.0" edition = "2021" description = "Contains the v2 torn API model descriptions and codegen for the bindings" license-file = { workspace = true } diff --git a/torn-api-codegen/src/model/path.rs b/torn-api-codegen/src/model/path.rs index 922622d..0eef06d 100644 --- a/torn-api-codegen/src/model/path.rs +++ b/torn-api-codegen/src/model/path.rs @@ -284,15 +284,18 @@ impl Path { #[allow(unused_parens)] type Discriminant = (#(#discriminant),*); type Response = #response_ty; - fn into_request(self) -> crate::request::ApiRequest { + fn into_request(self) -> (Self::Discriminant, crate::request::ApiRequest) { + let path = format!(#path_fmt_str, #(#fmt_val),*); #[allow(unused_parens)] - crate::request::ApiRequest { - path: format!(#path_fmt_str, #(#fmt_val),*), - parameters: std::iter::empty() - #(#convert_field)* - .collect(), - disriminant: (#(#discriminant_val),*), - } + ( + (#(#discriminant_val),*), + crate::request::ApiRequest { + path, + parameters: std::iter::empty() + #(#convert_field)* + .collect(), + } + ) } } }) @@ -376,7 +379,7 @@ impl Path { Some(quote! { pub async fn #fn_name( - &self, + self, #(#extra_args)* builder: impl FnOnce( #builder_path<#builder_mod_path::Empty> @@ -391,6 +394,120 @@ impl Path { } }) } + + pub fn codegen_bulk_scope_call(&self) -> Option { + let mut disc = Vec::new(); + let mut disc_ty = Vec::new(); + + let snake_name = self.name.to_snake_case(); + + let request_name = format_ident!("{}Request", self.name); + let builder_name = format_ident!("{}RequestBuilder", self.name); + let builder_mod_name = format_ident!("{}_request_builder", snake_name); + let request_mod_name = format_ident!("{snake_name}"); + + let request_path = quote! { crate::request::models::#request_name }; + let builder_path = quote! { crate::request::models::#builder_name }; + let builder_mod_path = quote! { crate::request::models::#builder_mod_name }; + + let tail = snake_name + .split_once('_') + .map_or_else(|| "for_selections".to_owned(), |(_, tail)| tail.to_owned()); + + let fn_name = format_ident!("{tail}"); + + for param in &self.parameters { + let (param, is_inline) = match param { + PathParameter::Inline(param) => (param, true), + PathParameter::Component(param) => (param, false), + }; + + if param.location == ParameterLocation::Path { + let ty = match ¶m.r#type { + ParameterType::I32 { .. } | ParameterType::Enum { .. } => { + let ty_name = format_ident!("{}", param.name); + + if is_inline { + quote! { + crate::request::models::#request_mod_name::#ty_name + } + } else { + quote! { + crate::parameters::#ty_name + } + } + } + ParameterType::String => quote! { String }, + ParameterType::Boolean => quote! { bool }, + ParameterType::Schema { type_name } => { + let ty_name = format_ident!("{}", type_name); + + quote! { + crate::models::#ty_name + } + } + ParameterType::Array { .. } => param.r#type.codegen_type_name(¶m.name), + }; + + let arg_name = format_ident!("{}", param.value.to_snake_case()); + + disc_ty.push(ty); + disc.push(arg_name); + } + } + + if disc.is_empty() { + return None; + } + + let response_ty = match &self.response { + PathResponse::Component { name } => { + let name = format_ident!("{name}"); + quote! { + crate::models::#name + } + } + PathResponse::ArbitraryUnion(union) => { + let name = format_ident!("{}", union.name); + quote! { + crate::request::models::#request_mod_name::#name + } + } + }; + + let disc = if disc.len() > 1 { + quote! { (#(#disc),*) } + } else { + quote! { #(#disc),* } + }; + + let disc_ty = if disc_ty.len() > 1 { + quote! { (#(#disc_ty),*) } + } else { + quote! { #(#disc_ty),* } + }; + + Some(quote! { + pub fn #fn_name( + self, + ids: I, + builder: B + ) -> impl futures::Stream)> + use<'e, E, S, I, B> + where + I: IntoIterator, + S: #builder_mod_path::IsComplete, + B: Fn( + #builder_path<#builder_mod_path::Empty> + ) -> #builder_path, + { + let requests = ids.into_iter() + .map(move |#disc| builder(#request_path::builder(#disc)).build()); + + let executor = self.executor; + executor.fetch_many(requests) + } + }) + } } pub struct PathNamespace<'r> { diff --git a/torn-api-codegen/src/model/scope.rs b/torn-api-codegen/src/model/scope.rs index 2aa57bc..b8fa7b5 100644 --- a/torn-api-codegen/src/model/scope.rs +++ b/torn-api-codegen/src/model/scope.rs @@ -35,30 +35,56 @@ impl Scope { pub fn codegen(&self) -> Option { let name = format_ident!("{}", self.name); + let bulk_name = format_ident!("Bulk{}", self.name); let mut functions = Vec::with_capacity(self.members.len()); + let mut bulk_functions = Vec::with_capacity(self.members.len()); for member in &self.members { if let Some(code) = member.codegen_scope_call() { functions.push(code); } + if let Some(code) = member.codegen_bulk_scope_call() { + bulk_functions.push(code); + } } Some(quote! { - pub struct #name<'e, E>(&'e E) + pub struct #name(E) where E: crate::executor::Executor; - impl<'e, E> #name<'e, E> + impl #name where E: crate::executor::Executor { - pub fn new(executor: &'e E) -> Self { + pub fn new(executor: E) -> Self { Self(executor) } #(#functions)* } + + pub struct #bulk_name<'e, E> where + E: crate::executor::BulkExecutor<'e>, + { + executor: E, + marker: std::marker::PhantomData<&'e E>, + } + + impl<'e, E> #bulk_name<'e, E> + where + E: crate::executor::BulkExecutor<'e> + { + pub fn new(executor: E) -> Self { + Self { + executor, + marker: std::marker::PhantomData, + } + } + + #(#bulk_functions)* + } }) } } diff --git a/torn-api/Cargo.toml b/torn-api/Cargo.toml index 42c80f1..5811980 100644 --- a/torn-api/Cargo.toml +++ b/torn-api/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "torn-api" -version = "1.0.3" +version = "1.1.0" edition = "2021" description = "Auto-generated bindings for the v2 torn api" license-file = { workspace = true } @@ -27,12 +27,16 @@ reqwest = { version = "0.12", default-features = false, features = [ "brotli", ] } thiserror = "2" +futures = { version = "0.3", default-features = false, features = [ + "std", + "async-await", +] } [dev-dependencies] tokio = { version = "1", features = ["full"] } [build-dependencies] -torn-api-codegen = { path = "../torn-api-codegen", version = "0.1.5" } +torn-api-codegen = { path = "../torn-api-codegen", version = "0.2" } syn = { workspace = true, features = ["parsing"] } proc-macro2 = { workspace = true } prettyplease = "0.2" diff --git a/torn-api/src/executor.rs b/torn-api/src/executor.rs index dfd75dd..c3d2486 100644 --- a/torn-api/src/executor.rs +++ b/torn-api/src/executor.rs @@ -1,23 +1,27 @@ use std::future::Future; +use futures::{Stream, StreamExt}; use http::{header::AUTHORIZATION, HeaderMap, HeaderValue}; use serde::Deserialize; -use crate::request::{ApiResponse, IntoRequest}; +use crate::request::{ApiRequest, ApiResponse, IntoRequest}; #[cfg(feature = "scopes")] -use crate::scopes::{FactionScope, ForumScope, MarketScope, RacingScope, TornScope, UserScope}; +use crate::scopes::{ + BulkFactionScope, BulkForumScope, BulkMarketScope, BulkRacingScope, BulkTornScope, + BulkUserScope, FactionScope, ForumScope, MarketScope, RacingScope, TornScope, UserScope, +}; -pub trait Executor { +pub trait Executor: Sized { type Error: From + From + Send; fn execute( - &self, + self, request: R, - ) -> impl Future, Self::Error>> + Send + ) -> impl Future)> + Send where R: IntoRequest; - fn fetch(&self, request: R) -> impl Future> + Send + fn fetch(self, request: R) -> impl Future> + Send where R: IntoRequest, { @@ -25,7 +29,7 @@ pub trait Executor { // The future is `Send` but `&self` might not be. let fut = self.execute(request); async { - let resp = fut.await?; + let resp = fut.await.1?; let bytes = resp.body.unwrap(); @@ -52,6 +56,152 @@ pub trait Executor { } } +pub trait BulkExecutor<'e>: 'e + Sized { + type Error: From + From + Send; + + fn execute( + self, + requests: impl IntoIterator, + ) -> impl Stream)> + where + R: IntoRequest; + + fn fetch_many( + self, + requests: impl IntoIterator, + ) -> impl Stream)> + where + R: IntoRequest, + { + self.execute(requests).map(|(d, r)| { + let r = match r { + Ok(r) => r, + Err(why) => return (d, Err(why)), + }; + let bytes = r.body.unwrap(); + + if bytes.starts_with(br#"{"error":{"#) { + #[derive(Deserialize)] + struct ErrorBody<'a> { + code: u16, + error: &'a str, + } + #[derive(Deserialize)] + struct ErrorContainer<'a> { + #[serde(borrow)] + error: ErrorBody<'a>, + } + + let error: ErrorContainer = match serde_json::from_slice(&bytes) { + Ok(error) => error, + Err(why) => return (d, Err(why.into())), + }; + return ( + d, + Err(crate::ApiError::new(error.error.code, error.error.error).into()), + ); + } + + let resp = match serde_json::from_slice(&bytes) { + Ok(resp) => resp, + Err(why) => return (d, Err(why.into())), + }; + + (d, Ok(resp)) + }) + } +} + +#[cfg(feature = "scopes")] +pub trait ExecutorExt: Executor + Sized { + fn user(self) -> UserScope; + + fn faction(self) -> FactionScope; + + fn torn(self) -> TornScope; + + fn market(self) -> MarketScope; + + fn racing(self) -> RacingScope; + + fn forum(self) -> ForumScope; +} + +#[cfg(feature = "scopes")] +impl ExecutorExt for T +where + T: Executor + Sized, +{ + fn user(self) -> UserScope { + UserScope::new(self) + } + + fn faction(self) -> FactionScope { + FactionScope::new(self) + } + + fn torn(self) -> TornScope { + TornScope::new(self) + } + + fn market(self) -> MarketScope { + MarketScope::new(self) + } + + fn racing(self) -> RacingScope { + RacingScope::new(self) + } + + fn forum(self) -> ForumScope { + ForumScope::new(self) + } +} + +#[cfg(feature = "scopes")] +pub trait BulkExecutorExt<'e>: BulkExecutor<'e> + Sized { + fn user_bulk(self) -> BulkUserScope<'e, Self>; + + fn faction_bulk(self) -> BulkFactionScope<'e, Self>; + + fn torn_bulk(self) -> BulkTornScope<'e, Self>; + + fn market_bulk(self) -> BulkMarketScope<'e, Self>; + + fn racing_bulk(self) -> BulkRacingScope<'e, Self>; + + fn forum_bulk(self) -> BulkForumScope<'e, Self>; +} + +#[cfg(feature = "scopes")] +impl<'e, T> BulkExecutorExt<'e> for T +where + T: BulkExecutor<'e> + Sized, +{ + fn user_bulk(self) -> BulkUserScope<'e, Self> { + BulkUserScope::new(self) + } + + fn faction_bulk(self) -> BulkFactionScope<'e, Self> { + BulkFactionScope::new(self) + } + + fn torn_bulk(self) -> BulkTornScope<'e, Self> { + BulkTornScope::new(self) + } + + fn market_bulk(self) -> BulkMarketScope<'e, Self> { + BulkMarketScope::new(self) + } + + fn racing_bulk(self) -> BulkRacingScope<'e, Self> { + BulkRacingScope::new(self) + } + + fn forum_bulk(self) -> BulkForumScope<'e, Self> { + BulkForumScope::new(self) + } +} + pub struct ReqwestClient(reqwest::Client); impl ReqwestClient { @@ -72,70 +222,43 @@ impl ReqwestClient { } } -#[cfg(feature = "scopes")] -pub trait ExecutorExt: Executor + Sized { - fn user(&self) -> UserScope<'_, Self>; - - fn faction(&self) -> FactionScope<'_, Self>; - - fn torn(&self) -> TornScope<'_, Self>; - - fn market(&self) -> MarketScope<'_, Self>; - - fn racing(&self) -> RacingScope<'_, Self>; - - fn forum(&self) -> ForumScope<'_, Self>; -} - -#[cfg(feature = "scopes")] -impl ExecutorExt for T -where - T: Executor + Sized, -{ - fn user(&self) -> UserScope<'_, Self> { - UserScope::new(self) - } - - fn faction(&self) -> FactionScope<'_, Self> { - FactionScope::new(self) - } - - fn torn(&self) -> TornScope<'_, Self> { - TornScope::new(self) - } - - fn market(&self) -> MarketScope<'_, Self> { - MarketScope::new(self) - } - - fn racing(&self) -> RacingScope<'_, Self> { - RacingScope::new(self) - } - - fn forum(&self) -> ForumScope<'_, Self> { - ForumScope::new(self) - } -} - -impl Executor for ReqwestClient { - type Error = crate::Error; - - async fn execute(&self, request: R) -> Result, Self::Error> - where - R: IntoRequest, - { - let request = request.into_request(); +impl ReqwestClient { + async fn execute_api_request(&self, request: ApiRequest) -> Result { let url = request.url(); let response = self.0.get(url).send().await?; let status = response.status(); let body = response.bytes().await.ok(); - Ok(ApiResponse { - discriminant: request.disriminant, - status, - body, - }) + Ok(ApiResponse { status, body }) + } +} + +impl Executor for &ReqwestClient { + type Error = crate::Error; + + async fn execute(self, request: R) -> (R::Discriminant, Result) + where + R: IntoRequest, + { + let (d, request) = request.into_request(); + (d, self.execute_api_request(request).await) + } +} + +impl<'e> BulkExecutor<'e> for &'e ReqwestClient { + type Error = crate::Error; + + fn execute( + self, + requests: impl IntoIterator, + ) -> impl Stream)> + where + R: IntoRequest, + { + futures::stream::iter(requests) + .map(move |r| ::execute(self, r)) + .buffer_unordered(25) } } @@ -157,4 +280,22 @@ mod test { other => panic!("Expected incorrect id entity relation error, got {other:?}"), } } + + #[cfg(feature = "scopes")] + #[tokio::test] + async fn bulk_request() { + let client = test_client().await; + + let stream = client + .faction_bulk() + .basic_for_id(vec![19.into(), 89.into()], |b| b); + + let mut responses: Vec<_> = stream.collect().await; + + let (_id1, basic1) = responses.pop().unwrap(); + basic1.unwrap(); + + let (_id2, basic2) = responses.pop().unwrap(); + basic2.unwrap(); + } } diff --git a/torn-api/src/request/mod.rs b/torn-api/src/request/mod.rs index 3f7bfde..dae42be 100644 --- a/torn-api/src/request/mod.rs +++ b/torn-api/src/request/mod.rs @@ -5,13 +5,12 @@ use http::StatusCode; pub mod models; #[derive(Default)] -pub struct ApiRequest { - pub disriminant: D, +pub struct ApiRequest { pub path: String, pub parameters: Vec<(&'static str, String)>, } -impl ApiRequest { +impl ApiRequest { pub fn url(&self) -> String { let mut url = format!("https://api.torn.com/v2{}?", self.path); @@ -23,8 +22,7 @@ impl ApiRequest { } } -pub struct ApiResponse { - pub discriminant: D, +pub struct ApiResponse { pub body: Option, pub status: StatusCode, } @@ -32,7 +30,26 @@ pub struct ApiResponse { pub trait IntoRequest: Send { type Discriminant: Send; type Response: for<'de> serde::Deserialize<'de> + Send; - fn into_request(self) -> ApiRequest; + fn into_request(self) -> (Self::Discriminant, ApiRequest); +} + +pub(crate) struct WrappedApiRequest +where + R: IntoRequest, +{ + discriminant: R::Discriminant, + request: ApiRequest, +} + +impl IntoRequest for WrappedApiRequest +where + R: IntoRequest, +{ + type Discriminant = R::Discriminant; + type Response = R::Response; + fn into_request(self) -> (Self::Discriminant, ApiRequest) { + (self.discriminant, self.request) + } } #[cfg(test)] diff --git a/torn-key-pool/Cargo.toml b/torn-key-pool/Cargo.toml index 88c3724..94d87e6 100644 --- a/torn-key-pool/Cargo.toml +++ b/torn-key-pool/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "torn-key-pool" -version = "1.0.1" +version = "1.1.0" edition = "2021" authors = ["Pyrit [2111649]"] license-file = { workspace = true } @@ -11,7 +11,7 @@ description = "A generalised API key pool for torn-api" [features] default = ["postgres", "tokio-runtime"] postgres = ["dep:sqlx", "dep:chrono", "dep:indoc"] -tokio-runtime = ["dep:tokio", "dep:rand"] +tokio-runtime = ["dep:tokio", "dep:rand", "dep:tokio-stream"] [dependencies] torn-api = { path = "../torn-api", default-features = false, version = "1.0.1" } @@ -30,6 +30,9 @@ indoc = { version = "2", optional = true } tokio = { version = "1", optional = true, default-features = false, features = [ "time", ] } +tokio-stream = { version = "0.1", optional = true, default-features = false, features = [ + "time", +] } rand = { version = "0.9", optional = true } futures = "0.3" reqwest = { version = "0.12", default-features = false, features = [ diff --git a/torn-key-pool/src/lib.rs b/torn-key-pool/src/lib.rs index 161cd9a..991ad54 100644 --- a/torn-key-pool/src/lib.rs +++ b/torn-key-pool/src/lib.rs @@ -5,11 +5,12 @@ pub mod postgres; use std::{collections::HashMap, future::Future, sync::Arc, time::Duration}; -use futures::{future::BoxFuture, FutureExt}; +use futures::{future::BoxFuture, FutureExt, Stream, StreamExt}; use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION}; use serde::Deserialize; +use tokio_stream::StreamExt as TokioStreamExt; use torn_api::{ - executor::Executor, + executor::{BulkExecutor, Executor}, request::{ApiRequest, ApiResponse}, ApiError, }; @@ -80,6 +81,46 @@ where } } +impl From<&str> for KeySelector +where + K: ApiKey, + D: KeyDomain, +{ + fn from(value: &str) -> Self { + Self::Key(value.to_owned()) + } +} + +impl From for KeySelector +where + K: ApiKey, + D: KeyDomain, +{ + fn from(value: D) -> Self { + Self::Has(vec![value]) + } +} + +impl From<&[D]> for KeySelector +where + K: ApiKey, + D: KeyDomain, +{ + fn from(value: &[D]) -> Self { + Self::Has(value.to_vec()) + } +} + +impl From> for KeySelector +where + K: ApiKey, + D: KeyDomain, +{ + fn from(value: Vec) -> Self { + Self::Has(value) + } +} + pub trait IntoSelector: Send where K: ApiKey, @@ -88,30 +129,35 @@ where fn into_selector(self) -> KeySelector; } -impl IntoSelector for D +impl IntoSelector for T where K: ApiKey, D: KeyDomain, + T: Into> + Send, { fn into_selector(self) -> KeySelector { - KeySelector::Has(vec![self]) + self.into() } } -impl IntoSelector for KeySelector -where - K: ApiKey, - D: KeyDomain, +pub trait KeyPoolError: + From + From + From + From> + Send +{ +} + +impl KeyPoolError for T where + T: From + + From + + From + + From> + + Send { - fn into_selector(self) -> KeySelector { - self - } } pub trait KeyPoolStorage: Send + Sync { type Key: ApiKey; type Domain: KeyDomain; - type Error: From + From + From + Send; + type Error: KeyPoolError; fn acquire_key( &self, @@ -206,65 +252,6 @@ where >, } -pub struct KeyPoolExecutor<'p, S> -where - S: KeyPoolStorage, -{ - pool: &'p KeyPool, - selector: KeySelector, -} - -impl<'p, S> KeyPoolExecutor<'p, S> -where - S: KeyPoolStorage, -{ - pub fn new(pool: &'p KeyPool, selector: KeySelector) -> Self { - Self { pool, selector } - } - - async fn execute_request(&self, request: ApiRequest) -> Result, S::Error> - where - D: Send, - { - let key = self.pool.storage.acquire_key(self.selector.clone()).await?; - - let mut headers = HeaderMap::with_capacity(1); - headers.insert( - AUTHORIZATION, - HeaderValue::from_str(&format!("ApiKey {}", key.value())).unwrap(), - ); - - let resp = self - .pool - .client - .get(request.url()) - .headers(headers) - .send() - .await?; - - let status = resp.status(); - - let bytes = resp.bytes().await?; - - if let Some(err) = decode_error(&bytes)? { - if let Some(handler) = self.pool.options.error_hooks.get(&err.code()) { - let retry = (*handler)(&self.pool.storage, &key).await?; - - if retry { - return Box::pin(self.execute_request(request)).await; - } - } - Err(err.into()) - } else { - Ok(ApiResponse { - discriminant: request.disriminant, - body: Some(bytes), - status, - }) - } - } -} - pub struct PoolBuilder where S: KeyPoolStorage, @@ -358,20 +345,137 @@ where pub fn build(self) -> KeyPool { KeyPool { - client: self.client, - storage: self.storage, - options: Arc::new(self.options), + inner: Arc::new(KeyPoolInner { + client: self.client, + storage: self.storage, + options: self.options, + }), } } } +struct KeyPoolInner +where + S: KeyPoolStorage, +{ + pub client: reqwest::Client, + pub storage: S, + pub options: PoolOptions, +} + +impl KeyPoolInner +where + S: KeyPoolStorage, +{ + async fn execute_with_key( + &self, + key: &S::Key, + request: &ApiRequest, + ) -> Result { + let mut headers = HeaderMap::with_capacity(1); + headers.insert( + AUTHORIZATION, + HeaderValue::from_str(&format!("ApiKey {}", key.value())).unwrap(), + ); + + let resp = self + .client + .get(request.url()) + .headers(headers) + .send() + .await?; + + let status = resp.status(); + + let bytes = resp.bytes().await?; + + if let Some(err) = decode_error(&bytes)? { + if let Some(handler) = self.options.error_hooks.get(&err.code()) { + let retry = (*handler)(&self.storage, key).await?; + + if retry { + return Ok(RequestResult::Retry); + } + } + Err(err.into()) + } else { + Ok(RequestResult::Response(ApiResponse { + body: Some(bytes), + status, + })) + } + } + + async fn execute_request( + &self, + selector: KeySelector, + request: ApiRequest, + ) -> Result { + loop { + let key = self.storage.acquire_key(selector.clone()).await?; + match self.execute_with_key(&key, &request).await { + Ok(RequestResult::Response(resp)) => return Ok(resp), + Ok(RequestResult::Retry) => (), + Err(why) => return Err(why), + } + } + } + + async fn execute_bulk_requests>( + &self, + selector: KeySelector, + requests: T, + ) -> impl Stream)> + use<'_, D, S, T> { + let requests: Vec<_> = requests.into_iter().collect(); + + let keys: Vec<_> = match self + .storage + .acquire_many_keys(selector.clone(), requests.len() as i64) + .await + { + Ok(keys) => keys.into_iter().map(Ok).collect(), + Err(why) => { + let why = Arc::new(why); + std::iter::repeat_n(why, requests.len()) + .map(|e| Err(S::Error::from(e))) + .collect() + } + }; + + StreamExt::map( + futures::stream::iter(std::iter::zip(requests, keys)), + move |((discriminant, request), mut maybe_key)| { + let selector = selector.clone(); + async move { + loop { + let key = match maybe_key { + Ok(key) => key, + Err(why) => return (discriminant, Err(why)), + }; + match self.execute_with_key(&key, &request).await { + Ok(RequestResult::Response(resp)) => return (discriminant, Ok(resp)), + Ok(RequestResult::Retry) => (), + Err(why) => return (discriminant, Err(why)), + } + maybe_key = self.storage.acquire_key(selector.clone()).await; + } + } + }, + ) + .buffer_unordered(25) + } +} + pub struct KeyPool where S: KeyPoolStorage, { - pub client: reqwest::Client, - pub storage: S, - pub options: Arc>, + inner: Arc>, +} + +enum RequestResult { + Response(ApiResponse), + Retry, } impl KeyPool @@ -384,6 +488,17 @@ where { KeyPoolExecutor::new(self, selector.into_selector()) } + + pub fn throttled_torn_api( + &self, + selector: I, + distance: Duration, + ) -> ThrottledKeyPoolExecutor + where + I: IntoSelector, + { + ThrottledKeyPoolExecutor::new(self, selector.into_selector(), distance) + } } fn decode_error(buf: &[u8]) -> Result, serde_json::Error> { @@ -409,28 +524,145 @@ fn decode_error(buf: &[u8]) -> Result, serde_json::Error> { } } -impl Executor for KeyPoolExecutor<'_, S> +pub struct KeyPoolExecutor<'p, S> where S: KeyPoolStorage, +{ + pool: &'p KeyPoolInner, + selector: KeySelector, +} + +impl<'p, S> KeyPoolExecutor<'p, S> +where + S: KeyPoolStorage, +{ + pub fn new(pool: &'p KeyPool, selector: KeySelector) -> Self { + Self { + pool: &pool.inner, + selector, + } + } +} + +impl Executor for KeyPoolExecutor<'_, S> +where + S: KeyPoolStorage + 'static, { type Error = S::Error; - async fn execute( - &self, - request: R, - ) -> Result, Self::Error> + async fn execute(self, request: R) -> (R::Discriminant, Result) where R: torn_api::request::IntoRequest, { - let request = request.into_request(); + let (d, request) = request.into_request(); - self.execute_request(request).await + (d, self.pool.execute_request(self.selector, request).await) + } +} + +impl<'p, S> BulkExecutor<'p> for KeyPoolExecutor<'p, S> +where + S: KeyPoolStorage + 'static, +{ + type Error = S::Error; + + fn execute( + self, + requests: impl IntoIterator, + ) -> impl futures::Stream)> + where + R: torn_api::request::IntoRequest, + { + self.pool + .execute_bulk_requests( + self.selector.clone(), + requests.into_iter().map(|r| r.into_request()), + ) + .into_stream() + .flatten() + } +} + +pub struct ThrottledKeyPoolExecutor<'p, S> +where + S: KeyPoolStorage, +{ + pool: &'p KeyPoolInner, + selector: KeySelector, + distance: Duration, +} + +impl Clone for ThrottledKeyPoolExecutor<'_, S> +where + S: KeyPoolStorage, +{ + fn clone(&self) -> Self { + Self { + pool: self.pool, + selector: self.selector.clone(), + distance: self.distance, + } + } +} + +impl ThrottledKeyPoolExecutor<'_, S> +where + S: KeyPoolStorage, +{ + async fn execute_request(self, request: ApiRequest) -> Result { + self.pool.execute_request(self.selector, request).await + } +} + +impl<'p, S> ThrottledKeyPoolExecutor<'p, S> +where + S: KeyPoolStorage, +{ + pub fn new( + pool: &'p KeyPool, + selector: KeySelector, + distance: Duration, + ) -> Self { + Self { + pool: &pool.inner, + selector, + distance, + } + } +} + +impl<'p, S> BulkExecutor<'p> for ThrottledKeyPoolExecutor<'p, S> +where + S: KeyPoolStorage + 'static, +{ + type Error = S::Error; + + fn execute( + self, + requests: impl IntoIterator, + ) -> impl futures::Stream)> + where + R: torn_api::request::IntoRequest, + { + StreamExt::map( + futures::stream::iter(requests).throttle(self.distance), + move |r| { + let this = self.clone(); + async move { + let (d, request) = r.into_request(); + let result = this.execute_request(request).await; + (d, result) + } + }, + ) + .buffer_unordered(25) } } #[cfg(test)] +#[cfg(feature = "postgres")] mod test { - use torn_api::executor::ExecutorExt; + use torn_api::executor::{BulkExecutorExt, ExecutorExt}; use crate::postgres; @@ -451,4 +683,48 @@ mod test { .await .unwrap(); } + + #[sqlx::test] + fn bulk(pool: sqlx::PgPool) { + let (storage, _) = postgres::test::setup(pool).await; + + let pool = PoolBuilder::new(storage) + .use_default_hooks() + .comment("test_runner") + .build(); + + let responses = pool + .torn_api(postgres::test::Domain::All) + .faction_bulk() + .basic_for_id(vec![19.into(), 89.into()], |b| b); + let mut responses: Vec<_> = StreamExt::collect(responses).await; + + let (_id1, basic1) = responses.pop().unwrap(); + basic1.unwrap(); + + let (_id2, basic2) = responses.pop().unwrap(); + basic2.unwrap(); + } + + #[sqlx::test] + fn bulk_trottled(pool: sqlx::PgPool) { + let (storage, _) = postgres::test::setup(pool).await; + + let pool = PoolBuilder::new(storage) + .use_default_hooks() + .comment("test_runner") + .build(); + + let responses = pool + .throttled_torn_api(postgres::test::Domain::All, Duration::from_millis(500)) + .faction_bulk() + .basic_for_id(vec![19.into(), 89.into()], |b| b); + let mut responses: Vec<_> = StreamExt::collect(responses).await; + + let (_id1, basic1) = responses.pop().unwrap(); + basic1.unwrap(); + + let (_id2, basic2) = responses.pop().unwrap(); + basic2.unwrap(); + } } diff --git a/torn-key-pool/src/postgres.rs b/torn-key-pool/src/postgres.rs index 21651b0..aa142e4 100644 --- a/torn-key-pool/src/postgres.rs +++ b/torn-key-pool/src/postgres.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use futures::future::BoxFuture; use indoc::formatdoc; use sqlx::{FromRow, PgPool, Postgres, QueryBuilder}; @@ -37,6 +39,9 @@ where #[error("Key not found: '{0:?}'")] KeyNotFound(KeySelector, D>), + + #[error("Failed to acquire keys in bulk: {0}")] + Bulk(#[from] Arc), } #[derive(Debug, Clone, FromRow)]