feat: implemented bulk requests
This commit is contained in:
parent
4dd4fd37d4
commit
c17f93f600
8
Cargo.lock
generated
8
Cargo.lock
generated
|
@ -2271,10 +2271,11 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "torn-api"
|
||||
version = "1.0.3"
|
||||
version = "1.1.0"
|
||||
dependencies = [
|
||||
"bon",
|
||||
"bytes",
|
||||
"futures",
|
||||
"http",
|
||||
"prettyplease",
|
||||
"proc-macro2",
|
||||
|
@ -2290,7 +2291,7 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "torn-api-codegen"
|
||||
version = "0.1.5"
|
||||
version = "0.2.0"
|
||||
dependencies = [
|
||||
"heck",
|
||||
"indexmap",
|
||||
|
@ -2303,7 +2304,7 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "torn-key-pool"
|
||||
version = "1.0.1"
|
||||
version = "1.1.0"
|
||||
dependencies = [
|
||||
"chrono",
|
||||
"futures",
|
||||
|
@ -2315,6 +2316,7 @@ dependencies = [
|
|||
"sqlx",
|
||||
"thiserror",
|
||||
"tokio",
|
||||
"tokio-stream",
|
||||
"torn-api",
|
||||
]
|
||||
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
[package]
|
||||
name = "torn-api-codegen"
|
||||
authors = ["Pyrit [2111649]"]
|
||||
version = "0.1.5"
|
||||
version = "0.2.0"
|
||||
edition = "2021"
|
||||
description = "Contains the v2 torn API model descriptions and codegen for the bindings"
|
||||
license-file = { workspace = true }
|
||||
|
|
|
@ -284,15 +284,18 @@ impl Path {
|
|||
#[allow(unused_parens)]
|
||||
type Discriminant = (#(#discriminant),*);
|
||||
type Response = #response_ty;
|
||||
fn into_request(self) -> crate::request::ApiRequest<Self::Discriminant> {
|
||||
fn into_request(self) -> (Self::Discriminant, crate::request::ApiRequest) {
|
||||
let path = format!(#path_fmt_str, #(#fmt_val),*);
|
||||
#[allow(unused_parens)]
|
||||
crate::request::ApiRequest {
|
||||
path: format!(#path_fmt_str, #(#fmt_val),*),
|
||||
parameters: std::iter::empty()
|
||||
#(#convert_field)*
|
||||
.collect(),
|
||||
disriminant: (#(#discriminant_val),*),
|
||||
}
|
||||
(
|
||||
(#(#discriminant_val),*),
|
||||
crate::request::ApiRequest {
|
||||
path,
|
||||
parameters: std::iter::empty()
|
||||
#(#convert_field)*
|
||||
.collect(),
|
||||
}
|
||||
)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
@ -376,7 +379,7 @@ impl Path {
|
|||
|
||||
Some(quote! {
|
||||
pub async fn #fn_name<S>(
|
||||
&self,
|
||||
self,
|
||||
#(#extra_args)*
|
||||
builder: impl FnOnce(
|
||||
#builder_path<#builder_mod_path::Empty>
|
||||
|
@ -391,6 +394,120 @@ impl Path {
|
|||
}
|
||||
})
|
||||
}
|
||||
|
||||
pub fn codegen_bulk_scope_call(&self) -> Option<TokenStream> {
|
||||
let mut disc = Vec::new();
|
||||
let mut disc_ty = Vec::new();
|
||||
|
||||
let snake_name = self.name.to_snake_case();
|
||||
|
||||
let request_name = format_ident!("{}Request", self.name);
|
||||
let builder_name = format_ident!("{}RequestBuilder", self.name);
|
||||
let builder_mod_name = format_ident!("{}_request_builder", snake_name);
|
||||
let request_mod_name = format_ident!("{snake_name}");
|
||||
|
||||
let request_path = quote! { crate::request::models::#request_name };
|
||||
let builder_path = quote! { crate::request::models::#builder_name };
|
||||
let builder_mod_path = quote! { crate::request::models::#builder_mod_name };
|
||||
|
||||
let tail = snake_name
|
||||
.split_once('_')
|
||||
.map_or_else(|| "for_selections".to_owned(), |(_, tail)| tail.to_owned());
|
||||
|
||||
let fn_name = format_ident!("{tail}");
|
||||
|
||||
for param in &self.parameters {
|
||||
let (param, is_inline) = match param {
|
||||
PathParameter::Inline(param) => (param, true),
|
||||
PathParameter::Component(param) => (param, false),
|
||||
};
|
||||
|
||||
if param.location == ParameterLocation::Path {
|
||||
let ty = match ¶m.r#type {
|
||||
ParameterType::I32 { .. } | ParameterType::Enum { .. } => {
|
||||
let ty_name = format_ident!("{}", param.name);
|
||||
|
||||
if is_inline {
|
||||
quote! {
|
||||
crate::request::models::#request_mod_name::#ty_name
|
||||
}
|
||||
} else {
|
||||
quote! {
|
||||
crate::parameters::#ty_name
|
||||
}
|
||||
}
|
||||
}
|
||||
ParameterType::String => quote! { String },
|
||||
ParameterType::Boolean => quote! { bool },
|
||||
ParameterType::Schema { type_name } => {
|
||||
let ty_name = format_ident!("{}", type_name);
|
||||
|
||||
quote! {
|
||||
crate::models::#ty_name
|
||||
}
|
||||
}
|
||||
ParameterType::Array { .. } => param.r#type.codegen_type_name(¶m.name),
|
||||
};
|
||||
|
||||
let arg_name = format_ident!("{}", param.value.to_snake_case());
|
||||
|
||||
disc_ty.push(ty);
|
||||
disc.push(arg_name);
|
||||
}
|
||||
}
|
||||
|
||||
if disc.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let response_ty = match &self.response {
|
||||
PathResponse::Component { name } => {
|
||||
let name = format_ident!("{name}");
|
||||
quote! {
|
||||
crate::models::#name
|
||||
}
|
||||
}
|
||||
PathResponse::ArbitraryUnion(union) => {
|
||||
let name = format_ident!("{}", union.name);
|
||||
quote! {
|
||||
crate::request::models::#request_mod_name::#name
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let disc = if disc.len() > 1 {
|
||||
quote! { (#(#disc),*) }
|
||||
} else {
|
||||
quote! { #(#disc),* }
|
||||
};
|
||||
|
||||
let disc_ty = if disc_ty.len() > 1 {
|
||||
quote! { (#(#disc_ty),*) }
|
||||
} else {
|
||||
quote! { #(#disc_ty),* }
|
||||
};
|
||||
|
||||
Some(quote! {
|
||||
pub fn #fn_name<S, I, B>(
|
||||
self,
|
||||
ids: I,
|
||||
builder: B
|
||||
) -> impl futures::Stream<Item = (#disc_ty, Result<#response_ty, E::Error>)> + use<'e, E, S, I, B>
|
||||
where
|
||||
I: IntoIterator<Item = #disc_ty>,
|
||||
S: #builder_mod_path::IsComplete,
|
||||
B: Fn(
|
||||
#builder_path<#builder_mod_path::Empty>
|
||||
) -> #builder_path<S>,
|
||||
{
|
||||
let requests = ids.into_iter()
|
||||
.map(move |#disc| builder(#request_path::builder(#disc)).build());
|
||||
|
||||
let executor = self.executor;
|
||||
executor.fetch_many(requests)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub struct PathNamespace<'r> {
|
||||
|
|
|
@ -35,30 +35,56 @@ impl Scope {
|
|||
|
||||
pub fn codegen(&self) -> Option<TokenStream> {
|
||||
let name = format_ident!("{}", self.name);
|
||||
let bulk_name = format_ident!("Bulk{}", self.name);
|
||||
|
||||
let mut functions = Vec::with_capacity(self.members.len());
|
||||
let mut bulk_functions = Vec::with_capacity(self.members.len());
|
||||
|
||||
for member in &self.members {
|
||||
if let Some(code) = member.codegen_scope_call() {
|
||||
functions.push(code);
|
||||
}
|
||||
if let Some(code) = member.codegen_bulk_scope_call() {
|
||||
bulk_functions.push(code);
|
||||
}
|
||||
}
|
||||
|
||||
Some(quote! {
|
||||
pub struct #name<'e, E>(&'e E)
|
||||
pub struct #name<E>(E)
|
||||
where
|
||||
E: crate::executor::Executor;
|
||||
|
||||
impl<'e, E> #name<'e, E>
|
||||
impl<E> #name<E>
|
||||
where
|
||||
E: crate::executor::Executor
|
||||
{
|
||||
pub fn new(executor: &'e E) -> Self {
|
||||
pub fn new(executor: E) -> Self {
|
||||
Self(executor)
|
||||
}
|
||||
|
||||
#(#functions)*
|
||||
}
|
||||
|
||||
pub struct #bulk_name<'e, E> where
|
||||
E: crate::executor::BulkExecutor<'e>,
|
||||
{
|
||||
executor: E,
|
||||
marker: std::marker::PhantomData<&'e E>,
|
||||
}
|
||||
|
||||
impl<'e, E> #bulk_name<'e, E>
|
||||
where
|
||||
E: crate::executor::BulkExecutor<'e>
|
||||
{
|
||||
pub fn new(executor: E) -> Self {
|
||||
Self {
|
||||
executor,
|
||||
marker: std::marker::PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
#(#bulk_functions)*
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
[package]
|
||||
name = "torn-api"
|
||||
version = "1.0.3"
|
||||
version = "1.1.0"
|
||||
edition = "2021"
|
||||
description = "Auto-generated bindings for the v2 torn api"
|
||||
license-file = { workspace = true }
|
||||
|
@ -27,12 +27,16 @@ reqwest = { version = "0.12", default-features = false, features = [
|
|||
"brotli",
|
||||
] }
|
||||
thiserror = "2"
|
||||
futures = { version = "0.3", default-features = false, features = [
|
||||
"std",
|
||||
"async-await",
|
||||
] }
|
||||
|
||||
[dev-dependencies]
|
||||
tokio = { version = "1", features = ["full"] }
|
||||
|
||||
[build-dependencies]
|
||||
torn-api-codegen = { path = "../torn-api-codegen", version = "0.1.5" }
|
||||
torn-api-codegen = { path = "../torn-api-codegen", version = "0.2" }
|
||||
syn = { workspace = true, features = ["parsing"] }
|
||||
proc-macro2 = { workspace = true }
|
||||
prettyplease = "0.2"
|
||||
|
|
|
@ -1,23 +1,27 @@
|
|||
use std::future::Future;
|
||||
|
||||
use futures::{Stream, StreamExt};
|
||||
use http::{header::AUTHORIZATION, HeaderMap, HeaderValue};
|
||||
use serde::Deserialize;
|
||||
|
||||
use crate::request::{ApiResponse, IntoRequest};
|
||||
use crate::request::{ApiRequest, ApiResponse, IntoRequest};
|
||||
#[cfg(feature = "scopes")]
|
||||
use crate::scopes::{FactionScope, ForumScope, MarketScope, RacingScope, TornScope, UserScope};
|
||||
use crate::scopes::{
|
||||
BulkFactionScope, BulkForumScope, BulkMarketScope, BulkRacingScope, BulkTornScope,
|
||||
BulkUserScope, FactionScope, ForumScope, MarketScope, RacingScope, TornScope, UserScope,
|
||||
};
|
||||
|
||||
pub trait Executor {
|
||||
pub trait Executor: Sized {
|
||||
type Error: From<serde_json::Error> + From<crate::ApiError> + Send;
|
||||
|
||||
fn execute<R>(
|
||||
&self,
|
||||
self,
|
||||
request: R,
|
||||
) -> impl Future<Output = Result<ApiResponse<R::Discriminant>, Self::Error>> + Send
|
||||
) -> impl Future<Output = (R::Discriminant, Result<ApiResponse, Self::Error>)> + Send
|
||||
where
|
||||
R: IntoRequest;
|
||||
|
||||
fn fetch<R>(&self, request: R) -> impl Future<Output = Result<R::Response, Self::Error>> + Send
|
||||
fn fetch<R>(self, request: R) -> impl Future<Output = Result<R::Response, Self::Error>> + Send
|
||||
where
|
||||
R: IntoRequest,
|
||||
{
|
||||
|
@ -25,7 +29,7 @@ pub trait Executor {
|
|||
// The future is `Send` but `&self` might not be.
|
||||
let fut = self.execute(request);
|
||||
async {
|
||||
let resp = fut.await?;
|
||||
let resp = fut.await.1?;
|
||||
|
||||
let bytes = resp.body.unwrap();
|
||||
|
||||
|
@ -52,6 +56,152 @@ pub trait Executor {
|
|||
}
|
||||
}
|
||||
|
||||
pub trait BulkExecutor<'e>: 'e + Sized {
|
||||
type Error: From<serde_json::Error> + From<crate::ApiError> + Send;
|
||||
|
||||
fn execute<R>(
|
||||
self,
|
||||
requests: impl IntoIterator<Item = R>,
|
||||
) -> impl Stream<Item = (R::Discriminant, Result<ApiResponse, Self::Error>)>
|
||||
where
|
||||
R: IntoRequest;
|
||||
|
||||
fn fetch_many<R>(
|
||||
self,
|
||||
requests: impl IntoIterator<Item = R>,
|
||||
) -> impl Stream<Item = (R::Discriminant, Result<R::Response, Self::Error>)>
|
||||
where
|
||||
R: IntoRequest,
|
||||
{
|
||||
self.execute(requests).map(|(d, r)| {
|
||||
let r = match r {
|
||||
Ok(r) => r,
|
||||
Err(why) => return (d, Err(why)),
|
||||
};
|
||||
let bytes = r.body.unwrap();
|
||||
|
||||
if bytes.starts_with(br#"{"error":{"#) {
|
||||
#[derive(Deserialize)]
|
||||
struct ErrorBody<'a> {
|
||||
code: u16,
|
||||
error: &'a str,
|
||||
}
|
||||
#[derive(Deserialize)]
|
||||
struct ErrorContainer<'a> {
|
||||
#[serde(borrow)]
|
||||
error: ErrorBody<'a>,
|
||||
}
|
||||
|
||||
let error: ErrorContainer = match serde_json::from_slice(&bytes) {
|
||||
Ok(error) => error,
|
||||
Err(why) => return (d, Err(why.into())),
|
||||
};
|
||||
return (
|
||||
d,
|
||||
Err(crate::ApiError::new(error.error.code, error.error.error).into()),
|
||||
);
|
||||
}
|
||||
|
||||
let resp = match serde_json::from_slice(&bytes) {
|
||||
Ok(resp) => resp,
|
||||
Err(why) => return (d, Err(why.into())),
|
||||
};
|
||||
|
||||
(d, Ok(resp))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "scopes")]
|
||||
pub trait ExecutorExt: Executor + Sized {
|
||||
fn user(self) -> UserScope<Self>;
|
||||
|
||||
fn faction(self) -> FactionScope<Self>;
|
||||
|
||||
fn torn(self) -> TornScope<Self>;
|
||||
|
||||
fn market(self) -> MarketScope<Self>;
|
||||
|
||||
fn racing(self) -> RacingScope<Self>;
|
||||
|
||||
fn forum(self) -> ForumScope<Self>;
|
||||
}
|
||||
|
||||
#[cfg(feature = "scopes")]
|
||||
impl<T> ExecutorExt for T
|
||||
where
|
||||
T: Executor + Sized,
|
||||
{
|
||||
fn user(self) -> UserScope<Self> {
|
||||
UserScope::new(self)
|
||||
}
|
||||
|
||||
fn faction(self) -> FactionScope<Self> {
|
||||
FactionScope::new(self)
|
||||
}
|
||||
|
||||
fn torn(self) -> TornScope<Self> {
|
||||
TornScope::new(self)
|
||||
}
|
||||
|
||||
fn market(self) -> MarketScope<Self> {
|
||||
MarketScope::new(self)
|
||||
}
|
||||
|
||||
fn racing(self) -> RacingScope<Self> {
|
||||
RacingScope::new(self)
|
||||
}
|
||||
|
||||
fn forum(self) -> ForumScope<Self> {
|
||||
ForumScope::new(self)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "scopes")]
|
||||
pub trait BulkExecutorExt<'e>: BulkExecutor<'e> + Sized {
|
||||
fn user_bulk(self) -> BulkUserScope<'e, Self>;
|
||||
|
||||
fn faction_bulk(self) -> BulkFactionScope<'e, Self>;
|
||||
|
||||
fn torn_bulk(self) -> BulkTornScope<'e, Self>;
|
||||
|
||||
fn market_bulk(self) -> BulkMarketScope<'e, Self>;
|
||||
|
||||
fn racing_bulk(self) -> BulkRacingScope<'e, Self>;
|
||||
|
||||
fn forum_bulk(self) -> BulkForumScope<'e, Self>;
|
||||
}
|
||||
|
||||
#[cfg(feature = "scopes")]
|
||||
impl<'e, T> BulkExecutorExt<'e> for T
|
||||
where
|
||||
T: BulkExecutor<'e> + Sized,
|
||||
{
|
||||
fn user_bulk(self) -> BulkUserScope<'e, Self> {
|
||||
BulkUserScope::new(self)
|
||||
}
|
||||
|
||||
fn faction_bulk(self) -> BulkFactionScope<'e, Self> {
|
||||
BulkFactionScope::new(self)
|
||||
}
|
||||
|
||||
fn torn_bulk(self) -> BulkTornScope<'e, Self> {
|
||||
BulkTornScope::new(self)
|
||||
}
|
||||
|
||||
fn market_bulk(self) -> BulkMarketScope<'e, Self> {
|
||||
BulkMarketScope::new(self)
|
||||
}
|
||||
|
||||
fn racing_bulk(self) -> BulkRacingScope<'e, Self> {
|
||||
BulkRacingScope::new(self)
|
||||
}
|
||||
|
||||
fn forum_bulk(self) -> BulkForumScope<'e, Self> {
|
||||
BulkForumScope::new(self)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ReqwestClient(reqwest::Client);
|
||||
|
||||
impl ReqwestClient {
|
||||
|
@ -72,70 +222,43 @@ impl ReqwestClient {
|
|||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "scopes")]
|
||||
pub trait ExecutorExt: Executor + Sized {
|
||||
fn user(&self) -> UserScope<'_, Self>;
|
||||
|
||||
fn faction(&self) -> FactionScope<'_, Self>;
|
||||
|
||||
fn torn(&self) -> TornScope<'_, Self>;
|
||||
|
||||
fn market(&self) -> MarketScope<'_, Self>;
|
||||
|
||||
fn racing(&self) -> RacingScope<'_, Self>;
|
||||
|
||||
fn forum(&self) -> ForumScope<'_, Self>;
|
||||
}
|
||||
|
||||
#[cfg(feature = "scopes")]
|
||||
impl<T> ExecutorExt for T
|
||||
where
|
||||
T: Executor + Sized,
|
||||
{
|
||||
fn user(&self) -> UserScope<'_, Self> {
|
||||
UserScope::new(self)
|
||||
}
|
||||
|
||||
fn faction(&self) -> FactionScope<'_, Self> {
|
||||
FactionScope::new(self)
|
||||
}
|
||||
|
||||
fn torn(&self) -> TornScope<'_, Self> {
|
||||
TornScope::new(self)
|
||||
}
|
||||
|
||||
fn market(&self) -> MarketScope<'_, Self> {
|
||||
MarketScope::new(self)
|
||||
}
|
||||
|
||||
fn racing(&self) -> RacingScope<'_, Self> {
|
||||
RacingScope::new(self)
|
||||
}
|
||||
|
||||
fn forum(&self) -> ForumScope<'_, Self> {
|
||||
ForumScope::new(self)
|
||||
}
|
||||
}
|
||||
|
||||
impl Executor for ReqwestClient {
|
||||
type Error = crate::Error;
|
||||
|
||||
async fn execute<R>(&self, request: R) -> Result<ApiResponse<R::Discriminant>, Self::Error>
|
||||
where
|
||||
R: IntoRequest,
|
||||
{
|
||||
let request = request.into_request();
|
||||
impl ReqwestClient {
|
||||
async fn execute_api_request(&self, request: ApiRequest) -> Result<ApiResponse, crate::Error> {
|
||||
let url = request.url();
|
||||
|
||||
let response = self.0.get(url).send().await?;
|
||||
let status = response.status();
|
||||
let body = response.bytes().await.ok();
|
||||
|
||||
Ok(ApiResponse {
|
||||
discriminant: request.disriminant,
|
||||
status,
|
||||
body,
|
||||
})
|
||||
Ok(ApiResponse { status, body })
|
||||
}
|
||||
}
|
||||
|
||||
impl Executor for &ReqwestClient {
|
||||
type Error = crate::Error;
|
||||
|
||||
async fn execute<R>(self, request: R) -> (R::Discriminant, Result<ApiResponse, Self::Error>)
|
||||
where
|
||||
R: IntoRequest,
|
||||
{
|
||||
let (d, request) = request.into_request();
|
||||
(d, self.execute_api_request(request).await)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'e> BulkExecutor<'e> for &'e ReqwestClient {
|
||||
type Error = crate::Error;
|
||||
|
||||
fn execute<R>(
|
||||
self,
|
||||
requests: impl IntoIterator<Item = R>,
|
||||
) -> impl Stream<Item = (R::Discriminant, Result<ApiResponse, Self::Error>)>
|
||||
where
|
||||
R: IntoRequest,
|
||||
{
|
||||
futures::stream::iter(requests)
|
||||
.map(move |r| <Self as Executor>::execute(self, r))
|
||||
.buffer_unordered(25)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -157,4 +280,22 @@ mod test {
|
|||
other => panic!("Expected incorrect id entity relation error, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "scopes")]
|
||||
#[tokio::test]
|
||||
async fn bulk_request() {
|
||||
let client = test_client().await;
|
||||
|
||||
let stream = client
|
||||
.faction_bulk()
|
||||
.basic_for_id(vec![19.into(), 89.into()], |b| b);
|
||||
|
||||
let mut responses: Vec<_> = stream.collect().await;
|
||||
|
||||
let (_id1, basic1) = responses.pop().unwrap();
|
||||
basic1.unwrap();
|
||||
|
||||
let (_id2, basic2) = responses.pop().unwrap();
|
||||
basic2.unwrap();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -5,13 +5,12 @@ use http::StatusCode;
|
|||
pub mod models;
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct ApiRequest<D = ()> {
|
||||
pub disriminant: D,
|
||||
pub struct ApiRequest {
|
||||
pub path: String,
|
||||
pub parameters: Vec<(&'static str, String)>,
|
||||
}
|
||||
|
||||
impl<D> ApiRequest<D> {
|
||||
impl ApiRequest {
|
||||
pub fn url(&self) -> String {
|
||||
let mut url = format!("https://api.torn.com/v2{}?", self.path);
|
||||
|
||||
|
@ -23,8 +22,7 @@ impl<D> ApiRequest<D> {
|
|||
}
|
||||
}
|
||||
|
||||
pub struct ApiResponse<D = ()> {
|
||||
pub discriminant: D,
|
||||
pub struct ApiResponse {
|
||||
pub body: Option<Bytes>,
|
||||
pub status: StatusCode,
|
||||
}
|
||||
|
@ -32,7 +30,26 @@ pub struct ApiResponse<D = ()> {
|
|||
pub trait IntoRequest: Send {
|
||||
type Discriminant: Send;
|
||||
type Response: for<'de> serde::Deserialize<'de> + Send;
|
||||
fn into_request(self) -> ApiRequest<Self::Discriminant>;
|
||||
fn into_request(self) -> (Self::Discriminant, ApiRequest);
|
||||
}
|
||||
|
||||
pub(crate) struct WrappedApiRequest<R>
|
||||
where
|
||||
R: IntoRequest,
|
||||
{
|
||||
discriminant: R::Discriminant,
|
||||
request: ApiRequest,
|
||||
}
|
||||
|
||||
impl<R> IntoRequest for WrappedApiRequest<R>
|
||||
where
|
||||
R: IntoRequest,
|
||||
{
|
||||
type Discriminant = R::Discriminant;
|
||||
type Response = R::Response;
|
||||
fn into_request(self) -> (Self::Discriminant, ApiRequest) {
|
||||
(self.discriminant, self.request)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
[package]
|
||||
name = "torn-key-pool"
|
||||
version = "1.0.1"
|
||||
version = "1.1.0"
|
||||
edition = "2021"
|
||||
authors = ["Pyrit [2111649]"]
|
||||
license-file = { workspace = true }
|
||||
|
@ -11,7 +11,7 @@ description = "A generalised API key pool for torn-api"
|
|||
[features]
|
||||
default = ["postgres", "tokio-runtime"]
|
||||
postgres = ["dep:sqlx", "dep:chrono", "dep:indoc"]
|
||||
tokio-runtime = ["dep:tokio", "dep:rand"]
|
||||
tokio-runtime = ["dep:tokio", "dep:rand", "dep:tokio-stream"]
|
||||
|
||||
[dependencies]
|
||||
torn-api = { path = "../torn-api", default-features = false, version = "1.0.1" }
|
||||
|
@ -30,6 +30,9 @@ indoc = { version = "2", optional = true }
|
|||
tokio = { version = "1", optional = true, default-features = false, features = [
|
||||
"time",
|
||||
] }
|
||||
tokio-stream = { version = "0.1", optional = true, default-features = false, features = [
|
||||
"time",
|
||||
] }
|
||||
rand = { version = "0.9", optional = true }
|
||||
futures = "0.3"
|
||||
reqwest = { version = "0.12", default-features = false, features = [
|
||||
|
|
|
@ -5,11 +5,12 @@ pub mod postgres;
|
|||
|
||||
use std::{collections::HashMap, future::Future, sync::Arc, time::Duration};
|
||||
|
||||
use futures::{future::BoxFuture, FutureExt};
|
||||
use futures::{future::BoxFuture, FutureExt, Stream, StreamExt};
|
||||
use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION};
|
||||
use serde::Deserialize;
|
||||
use tokio_stream::StreamExt as TokioStreamExt;
|
||||
use torn_api::{
|
||||
executor::Executor,
|
||||
executor::{BulkExecutor, Executor},
|
||||
request::{ApiRequest, ApiResponse},
|
||||
ApiError,
|
||||
};
|
||||
|
@ -80,6 +81,46 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
impl<K, D> From<&str> for KeySelector<K, D>
|
||||
where
|
||||
K: ApiKey,
|
||||
D: KeyDomain,
|
||||
{
|
||||
fn from(value: &str) -> Self {
|
||||
Self::Key(value.to_owned())
|
||||
}
|
||||
}
|
||||
|
||||
impl<K, D> From<D> for KeySelector<K, D>
|
||||
where
|
||||
K: ApiKey,
|
||||
D: KeyDomain,
|
||||
{
|
||||
fn from(value: D) -> Self {
|
||||
Self::Has(vec![value])
|
||||
}
|
||||
}
|
||||
|
||||
impl<K, D> From<&[D]> for KeySelector<K, D>
|
||||
where
|
||||
K: ApiKey,
|
||||
D: KeyDomain,
|
||||
{
|
||||
fn from(value: &[D]) -> Self {
|
||||
Self::Has(value.to_vec())
|
||||
}
|
||||
}
|
||||
|
||||
impl<K, D> From<Vec<D>> for KeySelector<K, D>
|
||||
where
|
||||
K: ApiKey,
|
||||
D: KeyDomain,
|
||||
{
|
||||
fn from(value: Vec<D>) -> Self {
|
||||
Self::Has(value)
|
||||
}
|
||||
}
|
||||
|
||||
pub trait IntoSelector<K, D>: Send
|
||||
where
|
||||
K: ApiKey,
|
||||
|
@ -88,30 +129,35 @@ where
|
|||
fn into_selector(self) -> KeySelector<K, D>;
|
||||
}
|
||||
|
||||
impl<K, D> IntoSelector<K, D> for D
|
||||
impl<K, D, T> IntoSelector<K, D> for T
|
||||
where
|
||||
K: ApiKey,
|
||||
D: KeyDomain,
|
||||
T: Into<KeySelector<K, D>> + Send,
|
||||
{
|
||||
fn into_selector(self) -> KeySelector<K, D> {
|
||||
KeySelector::Has(vec![self])
|
||||
self.into()
|
||||
}
|
||||
}
|
||||
|
||||
impl<K, D> IntoSelector<K, D> for KeySelector<K, D>
|
||||
where
|
||||
K: ApiKey,
|
||||
D: KeyDomain,
|
||||
pub trait KeyPoolError:
|
||||
From<reqwest::Error> + From<serde_json::Error> + From<torn_api::ApiError> + From<Arc<Self>> + Send
|
||||
{
|
||||
}
|
||||
|
||||
impl<T> KeyPoolError for T where
|
||||
T: From<reqwest::Error>
|
||||
+ From<serde_json::Error>
|
||||
+ From<torn_api::ApiError>
|
||||
+ From<Arc<Self>>
|
||||
+ Send
|
||||
{
|
||||
fn into_selector(self) -> KeySelector<K, D> {
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
pub trait KeyPoolStorage: Send + Sync {
|
||||
type Key: ApiKey;
|
||||
type Domain: KeyDomain;
|
||||
type Error: From<reqwest::Error> + From<serde_json::Error> + From<torn_api::ApiError> + Send;
|
||||
type Error: KeyPoolError;
|
||||
|
||||
fn acquire_key<S>(
|
||||
&self,
|
||||
|
@ -206,65 +252,6 @@ where
|
|||
>,
|
||||
}
|
||||
|
||||
pub struct KeyPoolExecutor<'p, S>
|
||||
where
|
||||
S: KeyPoolStorage,
|
||||
{
|
||||
pool: &'p KeyPool<S>,
|
||||
selector: KeySelector<S::Key, S::Domain>,
|
||||
}
|
||||
|
||||
impl<'p, S> KeyPoolExecutor<'p, S>
|
||||
where
|
||||
S: KeyPoolStorage,
|
||||
{
|
||||
pub fn new(pool: &'p KeyPool<S>, selector: KeySelector<S::Key, S::Domain>) -> Self {
|
||||
Self { pool, selector }
|
||||
}
|
||||
|
||||
async fn execute_request<D>(&self, request: ApiRequest<D>) -> Result<ApiResponse<D>, S::Error>
|
||||
where
|
||||
D: Send,
|
||||
{
|
||||
let key = self.pool.storage.acquire_key(self.selector.clone()).await?;
|
||||
|
||||
let mut headers = HeaderMap::with_capacity(1);
|
||||
headers.insert(
|
||||
AUTHORIZATION,
|
||||
HeaderValue::from_str(&format!("ApiKey {}", key.value())).unwrap(),
|
||||
);
|
||||
|
||||
let resp = self
|
||||
.pool
|
||||
.client
|
||||
.get(request.url())
|
||||
.headers(headers)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
let status = resp.status();
|
||||
|
||||
let bytes = resp.bytes().await?;
|
||||
|
||||
if let Some(err) = decode_error(&bytes)? {
|
||||
if let Some(handler) = self.pool.options.error_hooks.get(&err.code()) {
|
||||
let retry = (*handler)(&self.pool.storage, &key).await?;
|
||||
|
||||
if retry {
|
||||
return Box::pin(self.execute_request(request)).await;
|
||||
}
|
||||
}
|
||||
Err(err.into())
|
||||
} else {
|
||||
Ok(ApiResponse {
|
||||
discriminant: request.disriminant,
|
||||
body: Some(bytes),
|
||||
status,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct PoolBuilder<S>
|
||||
where
|
||||
S: KeyPoolStorage,
|
||||
|
@ -358,20 +345,137 @@ where
|
|||
|
||||
pub fn build(self) -> KeyPool<S> {
|
||||
KeyPool {
|
||||
client: self.client,
|
||||
storage: self.storage,
|
||||
options: Arc::new(self.options),
|
||||
inner: Arc::new(KeyPoolInner {
|
||||
client: self.client,
|
||||
storage: self.storage,
|
||||
options: self.options,
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct KeyPoolInner<S>
|
||||
where
|
||||
S: KeyPoolStorage,
|
||||
{
|
||||
pub client: reqwest::Client,
|
||||
pub storage: S,
|
||||
pub options: PoolOptions<S>,
|
||||
}
|
||||
|
||||
impl<S> KeyPoolInner<S>
|
||||
where
|
||||
S: KeyPoolStorage,
|
||||
{
|
||||
async fn execute_with_key(
|
||||
&self,
|
||||
key: &S::Key,
|
||||
request: &ApiRequest,
|
||||
) -> Result<RequestResult, S::Error> {
|
||||
let mut headers = HeaderMap::with_capacity(1);
|
||||
headers.insert(
|
||||
AUTHORIZATION,
|
||||
HeaderValue::from_str(&format!("ApiKey {}", key.value())).unwrap(),
|
||||
);
|
||||
|
||||
let resp = self
|
||||
.client
|
||||
.get(request.url())
|
||||
.headers(headers)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
let status = resp.status();
|
||||
|
||||
let bytes = resp.bytes().await?;
|
||||
|
||||
if let Some(err) = decode_error(&bytes)? {
|
||||
if let Some(handler) = self.options.error_hooks.get(&err.code()) {
|
||||
let retry = (*handler)(&self.storage, key).await?;
|
||||
|
||||
if retry {
|
||||
return Ok(RequestResult::Retry);
|
||||
}
|
||||
}
|
||||
Err(err.into())
|
||||
} else {
|
||||
Ok(RequestResult::Response(ApiResponse {
|
||||
body: Some(bytes),
|
||||
status,
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
async fn execute_request(
|
||||
&self,
|
||||
selector: KeySelector<S::Key, S::Domain>,
|
||||
request: ApiRequest,
|
||||
) -> Result<ApiResponse, S::Error> {
|
||||
loop {
|
||||
let key = self.storage.acquire_key(selector.clone()).await?;
|
||||
match self.execute_with_key(&key, &request).await {
|
||||
Ok(RequestResult::Response(resp)) => return Ok(resp),
|
||||
Ok(RequestResult::Retry) => (),
|
||||
Err(why) => return Err(why),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn execute_bulk_requests<D, T: IntoIterator<Item = (D, ApiRequest)>>(
|
||||
&self,
|
||||
selector: KeySelector<S::Key, S::Domain>,
|
||||
requests: T,
|
||||
) -> impl Stream<Item = (D, Result<ApiResponse, S::Error>)> + use<'_, D, S, T> {
|
||||
let requests: Vec<_> = requests.into_iter().collect();
|
||||
|
||||
let keys: Vec<_> = match self
|
||||
.storage
|
||||
.acquire_many_keys(selector.clone(), requests.len() as i64)
|
||||
.await
|
||||
{
|
||||
Ok(keys) => keys.into_iter().map(Ok).collect(),
|
||||
Err(why) => {
|
||||
let why = Arc::new(why);
|
||||
std::iter::repeat_n(why, requests.len())
|
||||
.map(|e| Err(S::Error::from(e)))
|
||||
.collect()
|
||||
}
|
||||
};
|
||||
|
||||
StreamExt::map(
|
||||
futures::stream::iter(std::iter::zip(requests, keys)),
|
||||
move |((discriminant, request), mut maybe_key)| {
|
||||
let selector = selector.clone();
|
||||
async move {
|
||||
loop {
|
||||
let key = match maybe_key {
|
||||
Ok(key) => key,
|
||||
Err(why) => return (discriminant, Err(why)),
|
||||
};
|
||||
match self.execute_with_key(&key, &request).await {
|
||||
Ok(RequestResult::Response(resp)) => return (discriminant, Ok(resp)),
|
||||
Ok(RequestResult::Retry) => (),
|
||||
Err(why) => return (discriminant, Err(why)),
|
||||
}
|
||||
maybe_key = self.storage.acquire_key(selector.clone()).await;
|
||||
}
|
||||
}
|
||||
},
|
||||
)
|
||||
.buffer_unordered(25)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct KeyPool<S>
|
||||
where
|
||||
S: KeyPoolStorage,
|
||||
{
|
||||
pub client: reqwest::Client,
|
||||
pub storage: S,
|
||||
pub options: Arc<PoolOptions<S>>,
|
||||
inner: Arc<KeyPoolInner<S>>,
|
||||
}
|
||||
|
||||
enum RequestResult {
|
||||
Response(ApiResponse),
|
||||
Retry,
|
||||
}
|
||||
|
||||
impl<S> KeyPool<S>
|
||||
|
@ -384,6 +488,17 @@ where
|
|||
{
|
||||
KeyPoolExecutor::new(self, selector.into_selector())
|
||||
}
|
||||
|
||||
pub fn throttled_torn_api<I>(
|
||||
&self,
|
||||
selector: I,
|
||||
distance: Duration,
|
||||
) -> ThrottledKeyPoolExecutor<S>
|
||||
where
|
||||
I: IntoSelector<S::Key, S::Domain>,
|
||||
{
|
||||
ThrottledKeyPoolExecutor::new(self, selector.into_selector(), distance)
|
||||
}
|
||||
}
|
||||
|
||||
fn decode_error(buf: &[u8]) -> Result<Option<ApiError>, serde_json::Error> {
|
||||
|
@ -409,28 +524,145 @@ fn decode_error(buf: &[u8]) -> Result<Option<ApiError>, serde_json::Error> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<S> Executor for KeyPoolExecutor<'_, S>
|
||||
pub struct KeyPoolExecutor<'p, S>
|
||||
where
|
||||
S: KeyPoolStorage,
|
||||
{
|
||||
pool: &'p KeyPoolInner<S>,
|
||||
selector: KeySelector<S::Key, S::Domain>,
|
||||
}
|
||||
|
||||
impl<'p, S> KeyPoolExecutor<'p, S>
|
||||
where
|
||||
S: KeyPoolStorage,
|
||||
{
|
||||
pub fn new(pool: &'p KeyPool<S>, selector: KeySelector<S::Key, S::Domain>) -> Self {
|
||||
Self {
|
||||
pool: &pool.inner,
|
||||
selector,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> Executor for KeyPoolExecutor<'_, S>
|
||||
where
|
||||
S: KeyPoolStorage + 'static,
|
||||
{
|
||||
type Error = S::Error;
|
||||
|
||||
async fn execute<R>(
|
||||
&self,
|
||||
request: R,
|
||||
) -> Result<torn_api::request::ApiResponse<R::Discriminant>, Self::Error>
|
||||
async fn execute<R>(self, request: R) -> (R::Discriminant, Result<ApiResponse, Self::Error>)
|
||||
where
|
||||
R: torn_api::request::IntoRequest,
|
||||
{
|
||||
let request = request.into_request();
|
||||
let (d, request) = request.into_request();
|
||||
|
||||
self.execute_request(request).await
|
||||
(d, self.pool.execute_request(self.selector, request).await)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'p, S> BulkExecutor<'p> for KeyPoolExecutor<'p, S>
|
||||
where
|
||||
S: KeyPoolStorage + 'static,
|
||||
{
|
||||
type Error = S::Error;
|
||||
|
||||
fn execute<R>(
|
||||
self,
|
||||
requests: impl IntoIterator<Item = R>,
|
||||
) -> impl futures::Stream<Item = (R::Discriminant, Result<ApiResponse, Self::Error>)>
|
||||
where
|
||||
R: torn_api::request::IntoRequest,
|
||||
{
|
||||
self.pool
|
||||
.execute_bulk_requests(
|
||||
self.selector.clone(),
|
||||
requests.into_iter().map(|r| r.into_request()),
|
||||
)
|
||||
.into_stream()
|
||||
.flatten()
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ThrottledKeyPoolExecutor<'p, S>
|
||||
where
|
||||
S: KeyPoolStorage,
|
||||
{
|
||||
pool: &'p KeyPoolInner<S>,
|
||||
selector: KeySelector<S::Key, S::Domain>,
|
||||
distance: Duration,
|
||||
}
|
||||
|
||||
impl<S> Clone for ThrottledKeyPoolExecutor<'_, S>
|
||||
where
|
||||
S: KeyPoolStorage,
|
||||
{
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
pool: self.pool,
|
||||
selector: self.selector.clone(),
|
||||
distance: self.distance,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> ThrottledKeyPoolExecutor<'_, S>
|
||||
where
|
||||
S: KeyPoolStorage,
|
||||
{
|
||||
async fn execute_request(self, request: ApiRequest) -> Result<ApiResponse, S::Error> {
|
||||
self.pool.execute_request(self.selector, request).await
|
||||
}
|
||||
}
|
||||
|
||||
impl<'p, S> ThrottledKeyPoolExecutor<'p, S>
|
||||
where
|
||||
S: KeyPoolStorage,
|
||||
{
|
||||
pub fn new(
|
||||
pool: &'p KeyPool<S>,
|
||||
selector: KeySelector<S::Key, S::Domain>,
|
||||
distance: Duration,
|
||||
) -> Self {
|
||||
Self {
|
||||
pool: &pool.inner,
|
||||
selector,
|
||||
distance,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'p, S> BulkExecutor<'p> for ThrottledKeyPoolExecutor<'p, S>
|
||||
where
|
||||
S: KeyPoolStorage + 'static,
|
||||
{
|
||||
type Error = S::Error;
|
||||
|
||||
fn execute<R>(
|
||||
self,
|
||||
requests: impl IntoIterator<Item = R>,
|
||||
) -> impl futures::Stream<Item = (R::Discriminant, Result<ApiResponse, Self::Error>)>
|
||||
where
|
||||
R: torn_api::request::IntoRequest,
|
||||
{
|
||||
StreamExt::map(
|
||||
futures::stream::iter(requests).throttle(self.distance),
|
||||
move |r| {
|
||||
let this = self.clone();
|
||||
async move {
|
||||
let (d, request) = r.into_request();
|
||||
let result = this.execute_request(request).await;
|
||||
(d, result)
|
||||
}
|
||||
},
|
||||
)
|
||||
.buffer_unordered(25)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[cfg(feature = "postgres")]
|
||||
mod test {
|
||||
use torn_api::executor::ExecutorExt;
|
||||
use torn_api::executor::{BulkExecutorExt, ExecutorExt};
|
||||
|
||||
use crate::postgres;
|
||||
|
||||
|
@ -451,4 +683,48 @@ mod test {
|
|||
.await
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
#[sqlx::test]
|
||||
fn bulk(pool: sqlx::PgPool) {
|
||||
let (storage, _) = postgres::test::setup(pool).await;
|
||||
|
||||
let pool = PoolBuilder::new(storage)
|
||||
.use_default_hooks()
|
||||
.comment("test_runner")
|
||||
.build();
|
||||
|
||||
let responses = pool
|
||||
.torn_api(postgres::test::Domain::All)
|
||||
.faction_bulk()
|
||||
.basic_for_id(vec![19.into(), 89.into()], |b| b);
|
||||
let mut responses: Vec<_> = StreamExt::collect(responses).await;
|
||||
|
||||
let (_id1, basic1) = responses.pop().unwrap();
|
||||
basic1.unwrap();
|
||||
|
||||
let (_id2, basic2) = responses.pop().unwrap();
|
||||
basic2.unwrap();
|
||||
}
|
||||
|
||||
#[sqlx::test]
|
||||
fn bulk_trottled(pool: sqlx::PgPool) {
|
||||
let (storage, _) = postgres::test::setup(pool).await;
|
||||
|
||||
let pool = PoolBuilder::new(storage)
|
||||
.use_default_hooks()
|
||||
.comment("test_runner")
|
||||
.build();
|
||||
|
||||
let responses = pool
|
||||
.throttled_torn_api(postgres::test::Domain::All, Duration::from_millis(500))
|
||||
.faction_bulk()
|
||||
.basic_for_id(vec![19.into(), 89.into()], |b| b);
|
||||
let mut responses: Vec<_> = StreamExt::collect(responses).await;
|
||||
|
||||
let (_id1, basic1) = responses.pop().unwrap();
|
||||
basic1.unwrap();
|
||||
|
||||
let (_id2, basic2) = responses.pop().unwrap();
|
||||
basic2.unwrap();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use futures::future::BoxFuture;
|
||||
use indoc::formatdoc;
|
||||
use sqlx::{FromRow, PgPool, Postgres, QueryBuilder};
|
||||
|
@ -37,6 +39,9 @@ where
|
|||
|
||||
#[error("Key not found: '{0:?}'")]
|
||||
KeyNotFound(KeySelector<PgKey<D>, D>),
|
||||
|
||||
#[error("Failed to acquire keys in bulk: {0}")]
|
||||
Bulk(#[from] Arc<Self>),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, FromRow)]
|
||||
|
|
Loading…
Reference in a new issue