#![warn(clippy::all, clippy::perf, clippy::style, clippy::suspicious)] #[cfg(feature = "postgres")] pub mod postgres; use std::{collections::HashMap, future::Future, ops::Deref, sync::Arc, time::Duration}; use futures::{future::BoxFuture, FutureExt, Stream, StreamExt}; use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION}; use serde::Deserialize; use tokio_stream::StreamExt as TokioStreamExt; use torn_api::{ executor::{BulkExecutor, Executor}, request::{ApiRequest, ApiResponse}, ApiError, }; pub trait ApiKeyId: Clone + PartialEq + Eq + std::hash::Hash + Send + Sync {} impl ApiKeyId for T where T: Clone + PartialEq + Eq + std::hash::Hash + Send + Sync {} pub trait ApiKey: Send + Sync + Clone + 'static { type IdType: ApiKeyId; fn value(&self) -> &str; fn id(&self) -> Self::IdType; fn selector(&self) -> KeySelector where D: KeyDomain, { KeySelector::Id(self.id()) } } pub trait KeyDomain: Clone + std::fmt::Debug + Send + Sync + 'static { fn fallback(&self) -> Option { None } } #[derive(Debug, Clone)] pub enum KeySelector where K: ApiKey, D: KeyDomain, { Key(String), Id(K::IdType), UserId(i32), Has(Vec), OneOf(Vec), } impl KeySelector where K: ApiKey, D: KeyDomain, { pub(crate) fn fallback(&self) -> Option { match self { Self::Key(_) | Self::UserId(_) | Self::Id(_) => None, Self::Has(domains) => { let fallbacks: Vec<_> = domains.iter().filter_map(|d| d.fallback()).collect(); if fallbacks.is_empty() { None } else { Some(Self::Has(fallbacks)) } } Self::OneOf(domains) => { let fallbacks: Vec<_> = domains.iter().filter_map(|d| d.fallback()).collect(); if fallbacks.is_empty() { None } else { Some(Self::OneOf(fallbacks)) } } } } } impl From<&str> for KeySelector where K: ApiKey, D: KeyDomain, { fn from(value: &str) -> Self { Self::Key(value.to_owned()) } } impl From for KeySelector where K: ApiKey, D: KeyDomain, { fn from(value: D) -> Self { Self::Has(vec![value]) } } impl From<&[D]> for KeySelector where K: ApiKey, D: KeyDomain, { fn from(value: &[D]) -> Self { Self::Has(value.to_vec()) } } impl From> for KeySelector where K: ApiKey, D: KeyDomain, { fn from(value: Vec) -> Self { Self::Has(value) } } pub trait IntoSelector: Send where K: ApiKey, D: KeyDomain, { fn into_selector(self) -> KeySelector; } impl IntoSelector for T where K: ApiKey, D: KeyDomain, T: Into> + Send, { fn into_selector(self) -> KeySelector { self.into() } } pub trait KeyPoolError: From + From + From + From> + Send { } impl KeyPoolError for T where T: From + From + From + From> + Send { } pub trait KeyPoolStorage: Send + Sync { type Key: ApiKey; type Domain: KeyDomain; type Error: KeyPoolError; fn acquire_key( &self, selector: S, ) -> impl Future> + Send where S: IntoSelector; fn acquire_many_keys( &self, selector: S, number: i64, ) -> impl Future, Self::Error>> + Send where S: IntoSelector; fn store_key( &self, user_id: i32, key: String, domains: Vec, ) -> impl Future> + Send; fn read_key( &self, selector: S, ) -> impl Future, Self::Error>> + Send where S: IntoSelector; fn read_keys( &self, selector: S, ) -> impl Future, Self::Error>> + Send where S: IntoSelector; fn remove_key( &self, selector: S, ) -> impl Future> + Send where S: IntoSelector; fn add_domain_to_key( &self, selector: S, domain: Self::Domain, ) -> impl Future> + Send where S: IntoSelector; fn remove_domain_from_key( &self, selector: S, domain: Self::Domain, ) -> impl Future> + Send where S: IntoSelector; fn set_domains_for_key( &self, selector: S, domains: Vec, ) -> impl Future> + Send where S: IntoSelector; fn timeout_key( &self, selector: S, duration: Duration, ) -> impl Future> + Send where S: IntoSelector; } #[derive(Default)] pub struct PoolOptions where S: KeyPoolStorage, { comment: Option, #[allow(clippy::type_complexity)] error_hooks: HashMap< u16, Box< dyn for<'a> Fn(&'a S, &'a S::Key) -> BoxFuture<'a, Result> + Send + Sync, >, >, } pub struct PoolBuilder where S: KeyPoolStorage, { client: reqwest::Client, storage: S, options: crate::PoolOptions, } impl PoolBuilder where S: KeyPoolStorage, { pub fn new(storage: S) -> Self { Self { client: reqwest::Client::builder() .brotli(true) .http2_keep_alive_timeout(Duration::from_secs(60)) .http2_keep_alive_interval(Duration::from_secs(5)) .https_only(true) .build() .unwrap(), storage, options: PoolOptions { comment: None, error_hooks: Default::default(), }, } } pub fn comment(mut self, c: impl ToString) -> Self { self.options.comment = Some(c.to_string()); self } pub fn error_hook(mut self, code: u16, handler: F) -> Self where F: for<'a> Fn(&'a S, &'a S::Key) -> BoxFuture<'a, Result> + Send + Sync + 'static, { self.options.error_hooks.insert(code, Box::new(handler)); self } pub fn use_default_hooks(self) -> Self { self.error_hook(2, |storage, key| { async move { storage.remove_key(KeySelector::Id(key.id())).await?; Ok(true) } .boxed() }) .error_hook(5, |storage, key| { async move { storage .timeout_key(KeySelector::Id(key.id()), Duration::from_secs(60)) .await?; Ok(true) } .boxed() }) .error_hook(10, |storage, key| { async move { storage.remove_key(KeySelector::Id(key.id())).await?; Ok(true) } .boxed() }) .error_hook(13, |storage, key| { async move { storage .timeout_key(KeySelector::Id(key.id()), Duration::from_secs(24 * 3_600)) .await?; Ok(true) } .boxed() }) .error_hook(18, |storage, key| { async move { storage .timeout_key(KeySelector::Id(key.id()), Duration::from_secs(24 * 3_600)) .await?; Ok(true) } .boxed() }) } pub fn build(self) -> KeyPool { KeyPool { inner: Arc::new(KeyPoolInner { client: self.client, storage: self.storage, options: self.options, }), } } } pub struct KeyPoolInner where S: KeyPoolStorage, { pub client: reqwest::Client, pub storage: S, pub options: PoolOptions, } impl KeyPoolInner where S: KeyPoolStorage, { async fn execute_with_key( &self, key: &S::Key, request: &ApiRequest, ) -> Result { let mut headers = HeaderMap::with_capacity(1); headers.insert( AUTHORIZATION, HeaderValue::from_str(&format!("ApiKey {}", key.value())).unwrap(), ); let resp = self .client .get(request.url()) .headers(headers) .send() .await?; let status = resp.status(); let bytes = resp.bytes().await?; if let Some(err) = decode_error(&bytes)? { if let Some(handler) = self.options.error_hooks.get(&err.code()) { let retry = (*handler)(&self.storage, key).await?; if retry { return Ok(RequestResult::Retry); } } Err(err.into()) } else { Ok(RequestResult::Response(ApiResponse { body: Some(bytes), status, })) } } async fn execute_request( &self, selector: KeySelector, request: ApiRequest, ) -> Result { loop { let key = self.storage.acquire_key(selector.clone()).await?; match self.execute_with_key(&key, &request).await { Ok(RequestResult::Response(resp)) => return Ok(resp), Ok(RequestResult::Retry) => (), Err(why) => return Err(why), } } } async fn execute_bulk_requests>( &self, selector: KeySelector, requests: T, ) -> impl Stream)> + use<'_, D, S, T> { let requests: Vec<_> = requests.into_iter().collect(); let keys: Vec<_> = match self .storage .acquire_many_keys(selector.clone(), requests.len() as i64) .await { Ok(keys) => keys.into_iter().map(Ok).collect(), Err(why) => { let why = Arc::new(why); std::iter::repeat_n(why, requests.len()) .map(|e| Err(S::Error::from(e))) .collect() } }; StreamExt::map( futures::stream::iter(std::iter::zip(requests, keys)), move |((discriminant, request), mut maybe_key)| { let selector = selector.clone(); async move { loop { let key = match maybe_key { Ok(key) => key, Err(why) => return (discriminant, Err(why)), }; match self.execute_with_key(&key, &request).await { Ok(RequestResult::Response(resp)) => return (discriminant, Ok(resp)), Ok(RequestResult::Retry) => (), Err(why) => return (discriminant, Err(why)), } maybe_key = self.storage.acquire_key(selector.clone()).await; } } }, ) .buffer_unordered(25) } } pub struct KeyPool where S: KeyPoolStorage, { inner: Arc>, } impl Deref for KeyPool where S: KeyPoolStorage, { type Target = KeyPoolInner; fn deref(&self) -> &Self::Target { &self.inner } } enum RequestResult { Response(ApiResponse), Retry, } impl KeyPool where S: KeyPoolStorage + Send + Sync + 'static, { pub fn torn_api(&self, selector: I) -> KeyPoolExecutor where I: IntoSelector, { KeyPoolExecutor::new(self, selector.into_selector()) } pub fn throttled_torn_api( &self, selector: I, distance: Duration, ) -> ThrottledKeyPoolExecutor where I: IntoSelector, { ThrottledKeyPoolExecutor::new(self, selector.into_selector(), distance) } } fn decode_error(buf: &[u8]) -> Result, serde_json::Error> { if buf.starts_with(br#"{"error":{"#) { #[derive(Deserialize)] struct ErrorBody<'a> { code: u16, error: &'a str, } #[derive(Deserialize)] struct ErrorContainer<'a> { #[serde(borrow)] error: ErrorBody<'a>, } let error: ErrorContainer = serde_json::from_slice(buf)?; Ok(Some(crate::ApiError::new( error.error.code, error.error.error, ))) } else { Ok(None) } } pub struct KeyPoolExecutor<'p, S> where S: KeyPoolStorage, { pool: &'p KeyPoolInner, selector: KeySelector, } impl<'p, S> KeyPoolExecutor<'p, S> where S: KeyPoolStorage, { pub fn new(pool: &'p KeyPool, selector: KeySelector) -> Self { Self { pool: &pool.inner, selector, } } } impl Executor for KeyPoolExecutor<'_, S> where S: KeyPoolStorage + 'static, { type Error = S::Error; async fn execute(self, request: R) -> (R::Discriminant, Result) where R: torn_api::request::IntoRequest, { let (d, request) = request.into_request(); (d, self.pool.execute_request(self.selector, request).await) } } impl BulkExecutor for KeyPoolExecutor<'_, S> where S: KeyPoolStorage + 'static, { type Error = S::Error; fn execute( self, requests: impl IntoIterator, ) -> impl futures::Stream)> + Unpin where R: torn_api::request::IntoRequest, { let requests: Vec<_> = requests.into_iter().map(|r| r.into_request()).collect(); self.pool .execute_bulk_requests(self.selector.clone(), requests) .into_stream() .flatten() .boxed() } } pub struct ThrottledKeyPoolExecutor<'p, S> where S: KeyPoolStorage, { pool: &'p KeyPoolInner, selector: KeySelector, distance: Duration, } impl Clone for ThrottledKeyPoolExecutor<'_, S> where S: KeyPoolStorage, { fn clone(&self) -> Self { Self { pool: self.pool, selector: self.selector.clone(), distance: self.distance, } } } impl ThrottledKeyPoolExecutor<'_, S> where S: KeyPoolStorage, { async fn execute_request(self, request: ApiRequest) -> Result { self.pool.execute_request(self.selector, request).await } } impl<'p, S> ThrottledKeyPoolExecutor<'p, S> where S: KeyPoolStorage, { pub fn new( pool: &'p KeyPool, selector: KeySelector, distance: Duration, ) -> Self { Self { pool: &pool.inner, selector, distance, } } } impl BulkExecutor for ThrottledKeyPoolExecutor<'_, S> where S: KeyPoolStorage + 'static, { type Error = S::Error; fn execute( self, requests: impl IntoIterator, ) -> impl futures::Stream)> + Unpin where R: torn_api::request::IntoRequest, { let requests: Vec<_> = requests.into_iter().map(|r| r.into_request()).collect(); StreamExt::map( futures::stream::iter(requests).throttle(self.distance), move |(d, request)| { let this = self.clone(); async move { let result = this.execute_request(request).await; (d, result) } }, ) .buffer_unordered(25) .boxed() } } #[cfg(test)] #[cfg(feature = "postgres")] mod test { use torn_api::executor::{BulkExecutorExt, ExecutorExt}; use crate::postgres; use super::*; #[sqlx::test] fn name(pool: sqlx::PgPool) { let (storage, _) = postgres::test::setup(pool).await; let pool = PoolBuilder::new(storage) .use_default_hooks() .comment("test_runner") .build(); pool.torn_api(postgres::test::Domain::All) .faction() .basic(|b| b) .await .unwrap(); } #[sqlx::test] fn bulk(pool: sqlx::PgPool) { let (storage, _) = postgres::test::setup(pool).await; let pool = PoolBuilder::new(storage) .use_default_hooks() .comment("test_runner") .build(); let responses = pool .torn_api(postgres::test::Domain::All) .faction_bulk() .basic_for_id(vec![19.into(), 89.into()], |b| b); let mut responses: Vec<_> = StreamExt::collect(responses).await; let (_id1, basic1) = responses.pop().unwrap(); basic1.unwrap(); let (_id2, basic2) = responses.pop().unwrap(); basic2.unwrap(); } #[sqlx::test] fn bulk_trottled(pool: sqlx::PgPool) { let (storage, _) = postgres::test::setup(pool).await; let pool = PoolBuilder::new(storage) .use_default_hooks() .comment("test_runner") .build(); let responses = pool .throttled_torn_api(postgres::test::Domain::All, Duration::from_millis(500)) .faction_bulk() .basic_for_id(vec![19.into(), 89.into()], |b| b); let mut responses: Vec<_> = StreamExt::collect(responses).await; let (_id1, basic1) = responses.pop().unwrap(); basic1.unwrap(); let (_id2, basic2) = responses.pop().unwrap(); basic2.unwrap(); } }