feat(key-pool): updated key pool to use v2 api

This commit is contained in:
TotallyNot 2025-04-25 17:21:50 +02:00
parent 5ae490c756
commit 4b52c37774
Signed by: pyrite
GPG key ID: 7F1BA9170CD35D15
8 changed files with 1728 additions and 870 deletions

1267
Cargo.lock generated

File diff suppressed because it is too large Load diff

View file

@ -1,6 +1,6 @@
[workspace]
resolver = "2"
members = ["torn-api", "torn-api-codegen"]
members = ["torn-api", "torn-api-codegen", "torn-key-pool"]
[workspace.package]
license-file = "./LICENSE"

View file

@ -3,7 +3,7 @@
"info": {
"title": "Torn API",
"description": "\n * The development of Torn's API v2 is still ongoing.\n * If selections remain unaltered, they will default to the API v1 version.\n * Unlike API v1, API v2 accepts both selections and IDs as path and query parameters.\n * If any discrepancies or errors are found, please submit a [bug report](https://www.torn.com/forums.php#/p=forums&f=19&b=0&a=0) on the Torn Forums.\n * In case you're using bots to check for changes on openapi.json file, make sure to specificy a custom user-agent header - CloudFlare sometimes prevents requests from default user-agents.",
"version": "1.3.1"
"version": "1.3.2"
},
"servers": [
{
@ -8773,7 +8773,8 @@
"attacksdamaging",
"attacksrunaway",
"highestterritories",
"territoryrespect"
"territoryrespect",
"membersamount"
]
},
"FactionBranchStateEnum": {
@ -13864,25 +13865,30 @@
"type": "object"
},
"RacingRaceDetailsResponse": {
"allOf": [
{
"required": [
"results"
],
"properties": {
"results": {
"type": "array",
"items": {
"$ref": "#/components/schemas/RacerDetails"
}
"properties": {
"race": {
"allOf": [
{
"required": [
"results"
],
"properties": {
"results": {
"type": "array",
"items": {
"$ref": "#/components/schemas/RacerDetails"
}
}
},
"type": "object"
},
{
"$ref": "#/components/schemas/Race"
}
},
"type": "object"
},
{
"$ref": "#/components/schemas/Race"
]
}
]
},
"type": "object"
},
"RacingSelectionName": {
"type": "string",
@ -15352,7 +15358,14 @@
"searchforcash",
"shoplifting",
"stats",
"stocks"
"stocks",
"chainreport",
"rackets",
"rankedwarreport",
"rankedwars",
"territorynames",
"territorywarreport",
"territorywars"
]
},
"TornLookupResponse": {
@ -19062,10 +19075,7 @@
"format": "int64"
},
"rewards": {
"type": "array",
"items": {
"$ref": "#/components/schemas/UserCrimeUniquesReward"
}
"$ref": "#/components/schemas/UserCrimeUniquesReward"
}
},
"type": "object"
@ -19139,6 +19149,9 @@
},
{
"$ref": "#/components/schemas/UserCrimeDetailsScamming"
},
{
"type": "null"
}
]
}
@ -19181,7 +19194,14 @@
"$ref": "#/components/schemas/RaceCarId"
},
"name": {
"type": "string"
"oneOf": [
{
"type": "string"
},
{
"type": "null"
}
]
},
"worth": {
"type": "integer",
@ -19421,10 +19441,7 @@
],
"properties": {
"hof": {
"type": "array",
"items": {
"$ref": "#/components/schemas/UserHofStats"
}
"$ref": "#/components/schemas/UserHofStats"
}
},
"type": "object"

View file

@ -1,43 +1,45 @@
[package]
name = "torn-key-pool"
version = "0.9.0"
version = "1.0.0"
edition = "2021"
authors = ["Pyrit [2111649]"]
license = "MIT"
repository = "https://github.com/TotallyNot/torn-api.rs.git"
homepage = "https://github.com/TotallyNot/torn-api.rs.git"
license-file = { workspace = true }
repository = { workspace = true }
homepage = { workspace = true }
description = "A generalised API key pool for torn-api"
[features]
default = [ "postgres", "tokio-runtime" ]
postgres = [ "dep:sqlx", "dep:chrono", "dep:indoc", "dep:serde" ]
reqwest = [ "dep:reqwest", "torn-api/reqwest" ]
awc = [ "dep:awc", "torn-api/awc" ]
tokio-runtime = [ "dep:tokio", "dep:rand" ]
actix-runtime = [ "dep:actix-rt", "dep:rand" ]
default = ["postgres", "tokio-runtime"]
postgres = ["dep:sqlx", "dep:chrono", "dep:indoc"]
tokio-runtime = ["dep:tokio", "dep:rand"]
[dependencies]
torn-api = { path = "../torn-api", default-features = false, version = "0.7" }
async-trait = "0.1"
torn-api = { path = "../torn-api", default-features = false, version = "1.0.1" }
thiserror = "2"
sqlx = { version = "0.8", features = [ "postgres", "chrono", "json", "derive" ], optional = true, default-features = false }
serde = { version = "1.0", optional = true }
sqlx = { version = "0.8", features = [
"postgres",
"chrono",
"json",
"derive",
], optional = true, default-features = false }
serde = { workspace = true }
serde_json = { workspace = true }
chrono = { version = "0.4", optional = true }
indoc = { version = "2", optional = true }
tokio = { version = "1", optional = true, default-features = false, features = ["time"] }
actix-rt = { version = "2", optional = true, default-features = false }
rand = { version = "0.8", optional = true }
tokio = { version = "1", optional = true, default-features = false, features = [
"time",
] }
rand = { version = "0.9", optional = true }
futures = "0.3"
reqwest = { version = "0.12", default-features = false, features = [ "json" ], optional = true }
awc = { version = "3", default-features = false, optional = true }
reqwest = { version = "0.12", default-features = false, features = [
"brotli",
"http2",
"rustls-tls-webpki-roots",
] }
[dev-dependencies]
torn-api = { path = "../torn-api", features = [ "reqwest" ] }
sqlx = { version = "0.8", features = [ "runtime-tokio-rustls" ] }
dotenvy = "0.15"
torn-api = { path = "../torn-api" }
sqlx = { version = "0.8", features = ["runtime-tokio-rustls"] }
tokio = { version = "1.42", features = ["rt"] }
tokio-test = "0.4"
reqwest = { version = "0.12", default-features = true }
awc = { version = "3", features = [ "rustls" ] }

View file

@ -3,48 +3,23 @@
#[cfg(feature = "postgres")]
pub mod postgres;
// pub mod local;
pub mod send;
use std::{collections::HashMap, future::Future, sync::Arc, time::Duration};
use std::sync::Arc;
use futures::{future::BoxFuture, FutureExt};
use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION};
use serde::Deserialize;
use torn_api::{
executor::Executor,
request::{ApiRequest, ApiResponse},
ApiError,
};
use async_trait::async_trait;
use thiserror::Error;
pub trait ApiKeyId: Clone + PartialEq + Eq + std::hash::Hash + Send + Sync {}
use torn_api::ResponseError;
impl<T> ApiKeyId for T where T: Clone + PartialEq + Eq + std::hash::Hash + Send + Sync {}
#[derive(Debug, Error)]
pub enum KeyPoolError<S, C>
where
S: std::error::Error + Clone,
C: std::error::Error,
{
#[error("Key pool storage driver error: {0:?}")]
Storage(#[source] S),
#[error(transparent)]
Client(#[from] C),
#[error(transparent)]
Response(ResponseError),
}
impl<S, C> KeyPoolError<S, C>
where
S: std::error::Error + Clone,
C: std::error::Error,
{
#[inline(always)]
pub fn api_code(&self) -> Option<u8> {
match self {
Self::Response(why) => why.api_code(),
_ => None,
}
}
}
pub trait ApiKey: Sync + Send + std::fmt::Debug + Clone + 'static {
type IdType: PartialEq + Eq + std::hash::Hash + Send + Sync + std::fmt::Debug + Clone;
pub trait ApiKey: Send + Sync + Clone + 'static {
type IdType: ApiKeyId;
fn value(&self) -> &str;
@ -105,7 +80,7 @@ where
}
}
pub trait IntoSelector<K, D>: Send + Sync
pub trait IntoSelector<K, D>: Send
where
K: ApiKey,
D: KeyDomain,
@ -133,114 +108,347 @@ where
}
}
pub enum KeyAction<D>
where
D: KeyDomain,
{
Delete,
RemoveDomain(D),
Timeout(chrono::Duration),
}
#[async_trait]
pub trait KeyPoolStorage {
pub trait KeyPoolStorage: Send + Sync {
type Key: ApiKey;
type Domain: KeyDomain;
type Error: std::error::Error + Sync + Send + Clone;
type Error: From<reqwest::Error> + From<serde_json::Error> + From<torn_api::ApiError> + Send;
async fn acquire_key<S>(&self, selector: S) -> Result<Self::Key, Self::Error>
fn acquire_key<S>(
&self,
selector: S,
) -> impl Future<Output = Result<Self::Key, Self::Error>> + Send
where
S: IntoSelector<Self::Key, Self::Domain>;
async fn acquire_many_keys<S>(
fn acquire_many_keys<S>(
&self,
selector: S,
number: i64,
) -> Result<Vec<Self::Key>, Self::Error>
) -> impl Future<Output = Result<Vec<Self::Key>, Self::Error>> + Send
where
S: IntoSelector<Self::Key, Self::Domain>;
async fn flag_key(&self, key: Self::Key, code: u8) -> Result<bool, Self::Error>;
async fn store_key(
fn store_key(
&self,
user_id: i32,
key: String,
domains: Vec<Self::Domain>,
) -> Result<Self::Key, Self::Error>;
) -> impl Future<Output = Result<Self::Key, Self::Error>> + Send;
async fn read_key<S>(&self, selector: S) -> Result<Option<Self::Key>, Self::Error>
fn read_key<S>(
&self,
selector: S,
) -> impl Future<Output = Result<Option<Self::Key>, Self::Error>> + Send
where
S: IntoSelector<Self::Key, Self::Domain>;
async fn read_keys<S>(&self, selector: S) -> Result<Vec<Self::Key>, Self::Error>
fn read_keys<S>(
&self,
selector: S,
) -> impl Future<Output = Result<Vec<Self::Key>, Self::Error>> + Send
where
S: IntoSelector<Self::Key, Self::Domain>;
async fn remove_key<S>(&self, selector: S) -> Result<Self::Key, Self::Error>
fn remove_key<S>(
&self,
selector: S,
) -> impl Future<Output = Result<Self::Key, Self::Error>> + Send
where
S: IntoSelector<Self::Key, Self::Domain>;
async fn add_domain_to_key<S>(
fn add_domain_to_key<S>(
&self,
selector: S,
domain: Self::Domain,
) -> Result<Self::Key, Self::Error>
) -> impl Future<Output = Result<Self::Key, Self::Error>> + Send
where
S: IntoSelector<Self::Key, Self::Domain>;
async fn remove_domain_from_key<S>(
fn remove_domain_from_key<S>(
&self,
selector: S,
domain: Self::Domain,
) -> Result<Self::Key, Self::Error>
) -> impl Future<Output = Result<Self::Key, Self::Error>> + Send
where
S: IntoSelector<Self::Key, Self::Domain>;
async fn set_domains_for_key<S>(
fn set_domains_for_key<S>(
&self,
selector: S,
domains: Vec<Self::Domain>,
) -> Result<Self::Key, Self::Error>
) -> impl Future<Output = Result<Self::Key, Self::Error>> + Send
where
S: IntoSelector<Self::Key, Self::Domain>;
fn timeout_key<S>(
&self,
selector: S,
duration: Duration,
) -> impl Future<Output = Result<(), Self::Error>> + Send
where
S: IntoSelector<Self::Key, Self::Domain>;
}
#[derive(Debug, Default)]
pub struct PoolOptions {
#[derive(Default)]
pub struct PoolOptions<S>
where
S: KeyPoolStorage,
{
comment: Option<String>,
hooks_before: std::collections::HashMap<std::any::TypeId, Box<dyn std::any::Any + Send + Sync>>,
hooks_after: std::collections::HashMap<std::any::TypeId, Box<dyn std::any::Any + Send + Sync>>,
#[allow(clippy::type_complexity)]
error_hooks: HashMap<
u16,
Box<
dyn for<'a> Fn(&'a S, &'a S::Key) -> BoxFuture<'a, Result<bool, S::Error>>
+ Send
+ Sync,
>,
>,
}
#[derive(Debug, Clone)]
pub struct KeyPoolExecutor<'a, C, S>
pub struct KeyPoolExecutor<'p, S>
where
S: KeyPoolStorage,
{
storage: &'a S,
options: Arc<PoolOptions>,
pool: &'p KeyPool<S>,
selector: KeySelector<S::Key, S::Domain>,
_marker: std::marker::PhantomData<C>,
}
impl<'a, C, S> KeyPoolExecutor<'a, C, S>
impl<'p, S> KeyPoolExecutor<'p, S>
where
S: KeyPoolStorage,
{
pub fn new(
storage: &'a S,
selector: KeySelector<S::Key, S::Domain>,
options: Arc<PoolOptions>,
) -> Self {
Self {
storage,
selector,
options,
_marker: std::marker::PhantomData,
pub fn new(pool: &'p KeyPool<S>, selector: KeySelector<S::Key, S::Domain>) -> Self {
Self { pool, selector }
}
async fn execute_request<D>(&self, request: ApiRequest<D>) -> Result<ApiResponse<D>, S::Error>
where
D: Send,
{
let key = self.pool.storage.acquire_key(self.selector.clone()).await?;
let mut headers = HeaderMap::with_capacity(1);
headers.insert(
AUTHORIZATION,
HeaderValue::from_str(&format!("ApiKey {}", key.value())).unwrap(),
);
let resp = self
.pool
.client
.get(request.url())
.headers(headers)
.send()
.await?;
let status = resp.status();
let bytes = resp.bytes().await?;
if let Some(err) = decode_error(&bytes)? {
if let Some(handler) = self.pool.options.error_hooks.get(&err.code()) {
let retry = (*handler)(&self.pool.storage, &key).await?;
if retry {
return Box::pin(self.execute_request(request)).await;
}
}
Err(err.into())
} else {
Ok(ApiResponse {
discriminant: request.disriminant,
body: Some(bytes),
status,
})
}
}
}
#[cfg(all(test, feature = "postgres"))]
mod test {}
pub struct PoolBuilder<S>
where
S: KeyPoolStorage,
{
client: reqwest::Client,
storage: S,
options: crate::PoolOptions<S>,
}
impl<S> PoolBuilder<S>
where
S: KeyPoolStorage,
{
pub fn new(storage: S) -> Self {
Self {
client: reqwest::Client::builder()
.brotli(true)
.http2_keep_alive_timeout(Duration::from_secs(60))
.http2_keep_alive_interval(Duration::from_secs(5))
.https_only(true)
.build()
.unwrap(),
storage,
options: PoolOptions {
comment: None,
error_hooks: Default::default(),
},
}
}
pub fn comment(mut self, c: impl ToString) -> Self {
self.options.comment = Some(c.to_string());
self
}
pub fn error_hook<F>(mut self, code: u16, handler: F) -> Self
where
F: for<'a> Fn(&'a S, &'a S::Key) -> BoxFuture<'a, Result<bool, S::Error>>
+ Send
+ Sync
+ 'static,
{
self.options.error_hooks.insert(code, Box::new(handler));
self
}
pub fn use_default_hooks(self) -> Self {
self.error_hook(2, |storage, key| {
async move {
storage.remove_key(KeySelector::Id(key.id())).await?;
Ok(true)
}
.boxed()
})
.error_hook(5, |storage, key| {
async move {
storage
.timeout_key(KeySelector::Id(key.id()), Duration::from_secs(60))
.await?;
Ok(true)
}
.boxed()
})
.error_hook(10, |storage, key| {
async move {
storage.remove_key(KeySelector::Id(key.id())).await?;
Ok(true)
}
.boxed()
})
.error_hook(13, |storage, key| {
async move {
storage
.timeout_key(KeySelector::Id(key.id()), Duration::from_secs(24 * 3_600))
.await?;
Ok(true)
}
.boxed()
})
.error_hook(18, |storage, key| {
async move {
storage
.timeout_key(KeySelector::Id(key.id()), Duration::from_secs(24 * 3_600))
.await?;
Ok(true)
}
.boxed()
})
}
pub fn build(self) -> KeyPool<S> {
KeyPool {
client: self.client,
storage: self.storage,
options: Arc::new(self.options),
}
}
}
pub struct KeyPool<S>
where
S: KeyPoolStorage,
{
pub client: reqwest::Client,
pub storage: S,
pub options: Arc<PoolOptions<S>>,
}
impl<S> KeyPool<S>
where
S: KeyPoolStorage + Send + Sync + 'static,
{
pub fn torn_api<I>(&self, selector: I) -> KeyPoolExecutor<S>
where
I: IntoSelector<S::Key, S::Domain>,
{
KeyPoolExecutor::new(self, selector.into_selector())
}
}
fn decode_error(buf: &[u8]) -> Result<Option<ApiError>, serde_json::Error> {
if buf.starts_with(br#"{"error":{"#) {
#[derive(Deserialize)]
struct ErrorBody<'a> {
code: u16,
error: &'a str,
}
#[derive(Deserialize)]
struct ErrorContainer<'a> {
#[serde(borrow)]
error: ErrorBody<'a>,
}
let error: ErrorContainer = serde_json::from_slice(buf)?;
Ok(Some(crate::ApiError::new(
error.error.code,
error.error.error,
)))
} else {
Ok(None)
}
}
impl<S> Executor for KeyPoolExecutor<'_, S>
where
S: KeyPoolStorage,
{
type Error = S::Error;
async fn execute<R>(
&self,
request: R,
) -> Result<torn_api::request::ApiResponse<R::Discriminant>, Self::Error>
where
R: torn_api::request::IntoRequest,
{
let request = request.into_request();
self.execute_request(request).await
}
}
#[cfg(test)]
mod test {
use torn_api::executor::ExecutorExt;
use crate::postgres;
use super::*;
#[sqlx::test]
fn name(pool: sqlx::PgPool) {
let (storage, _) = postgres::test::setup(pool).await;
let pool = PoolBuilder::new(storage)
.use_default_hooks()
.comment("test_runner")
.build();
pool.torn_api(postgres::test::Domain::All)
.faction()
.basic(|b| b)
.await
.unwrap();
}
}

View file

@ -1,206 +0,0 @@
use std::{collections::HashMap, sync::Arc};
use async_trait::async_trait;
use torn_api::{
local::{ApiClient, ApiProvider, RequestExecutor},
ApiRequest, ApiResponse, ApiSelection, ResponseError,
};
use crate::{ApiKey, KeyPoolError, KeyPoolExecutor, KeyPoolStorage, IntoSelector};
#[async_trait(?Send)]
impl<'client, C, S> RequestExecutor<C> for KeyPoolExecutor<'client, C, S>
where
C: ApiClient,
S: KeyPoolStorage + 'static,
{
type Error = KeyPoolError<S::Error, C::Error>;
async fn execute<A>(
&self,
client: &C,
mut request: ApiRequest<A>,
id: Option<String>,
) -> Result<ApiResponse, Self::Error>
where
A: ApiSelection,
{
request.comment = self.comment.map(ToOwned::to_owned);
loop {
let key = self
.storage
.acquire_key(self.selector.clone())
.await
.map_err(|e| KeyPoolError::Storage(Arc::new(e)))?;
let url = request.url(key.value(), id.as_deref());
let value = client.request(url).await?;
match ApiResponse::from_value(value) {
Err(ResponseError::Api { code, reason }) => {
if !self
.storage
.flag_key(key, code)
.await
.map_err(Arc::new)
.map_err(KeyPoolError::Storage)?
{
return Err(KeyPoolError::Response(ResponseError::Api { code, reason }));
}
}
Err(parsing_error) => return Err(KeyPoolError::Response(parsing_error)),
Ok(res) => return Ok(res),
};
}
}
async fn execute_many<A, I>(
&self,
client: &C,
mut request: ApiRequest<A>,
ids: Vec<I>,
) -> HashMap<I, Result<ApiResponse, Self::Error>>
where
A: ApiSelection,
I: ToString + std::hash::Hash + std::cmp::Eq,
{
let keys = match self
.storage
.acquire_many_keys(self.selector.clone(), ids.len() as i64)
.await
{
Ok(keys) => keys,
Err(why) => {
let shared = Arc::new(why);
return ids
.into_iter()
.map(|i| (i, Err(Self::Error::Storage(shared.clone()))))
.collect();
}
};
request.comment = self.comment.map(ToOwned::to_owned);
let request_ref = &request;
let tuples =
futures::future::join_all(std::iter::zip(ids, keys).map(|(id, mut key)| async move {
let id_string = id.to_string();
loop {
let url = request_ref.url(key.value(), Some(&id_string));
let value = match client.request(url).await {
Ok(v) => v,
Err(why) => return (id, Err(Self::Error::Client(why))),
};
match ApiResponse::from_value(value) {
Err(ResponseError::Api { code, reason }) => {
match self.storage.flag_key(key, code).await {
Ok(false) => {
return (
id,
Err(KeyPoolError::Response(ResponseError::Api {
code,
reason,
})),
)
}
Ok(true) => (),
Err(why) => return (id, Err(KeyPoolError::Storage(Arc::new(why)))),
}
}
Err(parsing_error) => {
return (id, Err(KeyPoolError::Response(parsing_error)))
}
Ok(res) => return (id, Ok(res)),
};
key = match self.storage.acquire_key(self.selector.clone()).await {
Ok(k) => k,
Err(why) => return (id, Err(Self::Error::Storage(Arc::new(why)))),
};
}
}))
.await;
HashMap::from_iter(tuples)
}
}
#[derive(Clone, Debug)]
pub struct KeyPool<C, S>
where
C: ApiClient,
S: KeyPoolStorage,
{
client: C,
pub storage: S,
comment: Option<String>,
}
impl<C, S> KeyPool<C, S>
where
C: ApiClient,
S: KeyPoolStorage + 'static,
{
pub fn new(client: C, storage: S, comment: Option<String>) -> Self {
Self {
client,
storage,
comment,
}
}
pub fn torn_api<I>(&self, selector: I) -> ApiProvider<C, KeyPoolExecutor<C, S>> where I: IntoSelector<S::Key, S::Domain> {
ApiProvider::new(
&self.client,
KeyPoolExecutor::new(&self.storage, selector.into_selector(), self.comment.as_deref()),
)
}
}
pub trait WithStorage {
fn with_storage<'a, S, I>(
&'a self,
storage: &'a S,
selector: I
) -> ApiProvider<Self, KeyPoolExecutor<Self, S>>
where
Self: ApiClient + Sized,
S: KeyPoolStorage + 'static,
I: IntoSelector<S::Key, S::Domain>
{
ApiProvider::new(self, KeyPoolExecutor::new(storage, selector.into_selector(), None))
}
}
#[cfg(feature = "awc")]
impl WithStorage for awc::Client {}
#[cfg(all(test, feature = "postgres", feature = "awc"))]
mod test {
use tokio::test;
use super::*;
use crate::postgres::test::{setup, Domain};
#[test]
async fn test_pool_request() {
let storage = setup().await;
let pool = KeyPool::new(awc::Client::default(), storage);
let response = pool.torn_api(Domain::All).user(|b| b).await.unwrap();
_ = response.profile().unwrap();
}
#[test]
async fn test_with_storage_request() {
let storage = setup().await;
let response = awc::Client::new()
.with_storage(&storage, Domain::All)
.user(|b| b)
.await
.unwrap();
_ = response.profile().unwrap();
}
}

View file

@ -1,6 +1,4 @@
use std::sync::Arc;
use async_trait::async_trait;
use futures::future::BoxFuture;
use indoc::indoc;
use sqlx::{FromRow, PgPool, Postgres, QueryBuilder};
use thiserror::Error;
@ -17,13 +15,22 @@ impl<T> PgKeyDomain for T where
{
}
#[derive(Debug, Error, Clone)]
pub enum PgStorageError<D>
#[derive(Debug, Error)]
pub enum PgKeyPoolError<D>
where
D: PgKeyDomain,
{
#[error(transparent)]
Pg(Arc<sqlx::Error>),
#[error("Databank: {0}")]
Pg(#[from] sqlx::Error),
#[error("Network: {0}")]
Network(#[from] reqwest::Error),
#[error("Parsing: {0}")]
Parsing(#[from] serde_json::Error),
#[error("Api: {0}")]
Api(#[from] torn_api::ApiError),
#[error("No key avalaible for domain {0:?}")]
Unavailable(KeySelector<PgKey<D>, D>),
@ -32,15 +39,6 @@ where
KeyNotFound(KeySelector<PgKey<D>, D>),
}
impl<D> From<sqlx::Error> for PgStorageError<D>
where
D: PgKeyDomain,
{
fn from(value: sqlx::Error) -> Self {
Self::Pg(Arc::new(value))
}
}
#[derive(Debug, Clone, FromRow)]
pub struct PgKey<D>
where
@ -127,7 +125,7 @@ where
}
}
pub async fn initialise(&self) -> Result<(), PgStorageError<D>> {
pub async fn initialise(&self) -> Result<(), PgKeyPoolError<D>> {
sqlx::query(indoc! {r#"
CREATE TABLE IF NOT EXISTS api_keys (
id serial primary key,
@ -184,19 +182,11 @@ where
#[cfg(feature = "tokio-runtime")]
async fn random_sleep() {
use rand::{thread_rng, Rng};
let dur = tokio::time::Duration::from_millis(thread_rng().gen_range(1..50));
use rand::{rng, Rng};
let dur = tokio::time::Duration::from_millis(rng().random_range(1..50));
tokio::time::sleep(dur).await;
}
#[cfg(all(not(feature = "tokio-runtime"), feature = "actix-runtime"))]
async fn random_sleep() {
use rand::{thread_rng, Rng};
let dur = std::time::Duration::from_millis(thread_rng().gen_range(1..50));
actix_rt::time::sleep(dur).await;
}
#[async_trait]
impl<D> KeyPoolStorage for PgKeyPoolStorage<D>
where
D: PgKeyDomain,
@ -204,7 +194,7 @@ where
type Key = PgKey<D>;
type Domain = D;
type Error = PgStorageError<D>;
type Error = PgKeyPoolError<D>;
async fn acquire_key<S>(&self, selector: S) -> Result<Self::Key, Self::Error>
where
@ -280,13 +270,23 @@ where
match attempt {
Ok(Some(result)) => return Ok(result),
Ok(None) => {
return self
.acquire_key(
selector
.fallback()
.ok_or_else(|| PgStorageError::Unavailable(selector))?,
)
.await
fn recurse<D>(
storage: &PgKeyPoolStorage<D>,
selector: KeySelector<PgKey<D>, D>,
) -> BoxFuture<Result<PgKey<D>, PgKeyPoolError<D>>>
where
D: PgKeyDomain,
{
Box::pin(storage.acquire_key(selector))
}
return recurse(
self,
selector
.fallback()
.ok_or_else(|| PgKeyPoolError::Unavailable(selector))?,
)
.await;
}
Err(error) => {
if let Some(db_error) = error.as_database_error() {
@ -365,7 +365,7 @@ where
let available = max.uses - key.uses;
let using = std::cmp::min(available, (number as i16) - (result.len() as i16));
key.uses += using;
result.extend(std::iter::repeat(key.clone()).take(using as usize));
result.extend(std::iter::repeat_n(key.clone(), using as usize));
if result.len() == number as usize {
break;
@ -406,14 +406,25 @@ where
match attempt {
Ok(Some(result)) => return Ok(result),
Ok(None) => {
return self
.acquire_many_keys(
selector
.fallback()
.ok_or_else(|| Self::Error::Unavailable(selector))?,
number,
)
.await
fn recurse<D>(
storage: &PgKeyPoolStorage<D>,
selector: KeySelector<PgKey<D>, D>,
number: i64,
) -> BoxFuture<Result<Vec<PgKey<D>>, PgKeyPoolError<D>>>
where
D: PgKeyDomain,
{
Box::pin(storage.acquire_many_keys(selector, number))
}
return recurse(
self,
selector
.fallback()
.ok_or_else(|| Self::Error::Unavailable(selector))?,
number,
)
.await;
}
Err(error) => {
if let Some(db_error) = error.as_database_error() {
@ -431,57 +442,24 @@ where
}
}
async fn flag_key(&self, key: Self::Key, code: u8) -> Result<bool, Self::Error> {
match code {
2 | 10 | 13 => {
// invalid key, owner fedded or owner inactive
sqlx::query(
"update api_keys set cooldown='infinity'::timestamptz, flag=$1 where id=$2",
)
.bind(code as i16)
.bind(key.id)
.execute(&self.pool)
.await?;
Ok(true)
}
5 => {
// too many requests
sqlx::query(
"update api_keys set cooldown=date_trunc('min', now()) + interval '1 min', \
flag=5 where id=$1",
)
.bind(key.id)
.execute(&self.pool)
.await?;
Ok(true)
}
8 => {
// IP block
sqlx::query("update api_keys set cooldown=now() + interval '5 min', flag=8")
.execute(&self.pool)
.await?;
Ok(false)
}
9 => {
// API disabled
sqlx::query("update api_keys set cooldown=now() + interval '1 min', flag=9")
.execute(&self.pool)
.await?;
Ok(false)
}
14 => {
// daily read limit reached
sqlx::query(
"update api_keys set cooldown=date_trunc('day', now()) + interval '1 day', \
flag=14 where id=$1",
)
.bind(key.id)
.execute(&self.pool)
.await?;
Ok(true)
}
_ => Ok(false),
}
async fn timeout_key<S>(
&self,
selector: S,
duration: std::time::Duration,
) -> Result<(), Self::Error>
where
S: IntoSelector<Self::Key, Self::Domain>,
{
let selector = selector.into_selector();
let mut qb = QueryBuilder::new("update api_keys set cooldown=now() + ");
qb.push_bind(duration);
qb.push(" where ");
build_predicate(&mut qb, &selector);
qb.build().fetch_optional(&self.pool).await?;
Ok(())
}
async fn store_key(
@ -546,7 +524,7 @@ where
qb.build_query_as()
.fetch_optional(&self.pool)
.await?
.ok_or_else(|| PgStorageError::KeyNotFound(selector))
.ok_or_else(|| PgKeyPoolError::KeyNotFound(selector))
}
async fn add_domain_to_key<S>(&self, selector: S, domain: D) -> Result<Self::Key, Self::Error>
@ -566,7 +544,7 @@ where
qb.build_query_as()
.fetch_optional(&self.pool)
.await?
.ok_or_else(|| PgStorageError::KeyNotFound(selector))
.ok_or_else(|| PgKeyPoolError::KeyNotFound(selector))
}
async fn remove_domain_from_key<S>(
@ -590,7 +568,7 @@ where
qb.build_query_as()
.fetch_optional(&self.pool)
.await?
.ok_or_else(|| PgStorageError::KeyNotFound(selector))
.ok_or_else(|| PgKeyPoolError::KeyNotFound(selector))
}
async fn set_domains_for_key<S>(
@ -612,13 +590,13 @@ where
qb.build_query_as()
.fetch_optional(&self.pool)
.await?
.ok_or_else(|| PgStorageError::KeyNotFound(selector))
.ok_or_else(|| PgKeyPoolError::KeyNotFound(selector))
}
}
#[cfg(test)]
pub(crate) mod test {
use std::sync::Arc;
use std::{sync::Arc, time::Duration};
use sqlx::Row;
@ -652,7 +630,7 @@ pub(crate) mod test {
storage.initialise().await.unwrap();
let key = storage
.store_key(1, std::env::var("APIKEY").unwrap(), vec![Domain::All])
.store_key(1, std::env::var("API_KEY").unwrap(), vec![Domain::All])
.await
.unwrap();
@ -816,34 +794,6 @@ pub(crate) mod test {
}
}
#[sqlx::test]
async fn test_flag_key_one(pool: PgPool) {
let (storage, key) = setup(pool).await;
assert!(storage.flag_key(key, 2).await.unwrap());
match storage.acquire_key(Domain::All).await.unwrap_err() {
PgStorageError::Unavailable(KeySelector::Has(domains)) => {
assert_eq!(domains, vec![Domain::All])
}
why => panic!("Expected domain unavailable error but found '{why}'"),
}
}
#[sqlx::test]
async fn test_flag_key_many(pool: PgPool) {
let (storage, key) = setup(pool).await;
assert!(storage.flag_key(key, 2).await.unwrap());
match storage.acquire_many_keys(Domain::All, 5).await.unwrap_err() {
PgStorageError::Unavailable(KeySelector::Has(domains)) => {
assert_eq!(domains, vec![Domain::All])
}
why => panic!("Expected domain unavailable error but found '{why}'"),
}
}
#[sqlx::test]
async fn acquire_many(pool: PgPool) {
let (storage, _) = setup(pool).await;
@ -1025,6 +975,16 @@ pub(crate) mod test {
assert!(key.is_some());
}
#[sqlx::test]
async fn timeout(pool: PgPool) {
let (storage, key) = setup(pool).await;
storage
.timeout_key(KeySelector::Id(key.id()), Duration::from_secs(60))
.await
.unwrap();
}
#[sqlx::test]
async fn query_by_set(pool: PgPool) {
let (storage, _key) = setup(pool).await;

View file

@ -1,380 +0,0 @@
use std::{collections::HashMap, sync::Arc};
use async_trait::async_trait;
use torn_api::{
send::{ApiClient, ApiProvider, RequestExecutor},
ApiRequest, ApiResponse, ApiSelection, ResponseError,
};
use crate::{
ApiKey, IntoSelector, KeyAction, KeyDomain, KeyPoolError, KeyPoolExecutor, KeyPoolStorage,
KeySelector, PoolOptions,
};
#[async_trait]
impl<'client, C, S> RequestExecutor<C> for KeyPoolExecutor<'client, C, S>
where
C: ApiClient,
S: KeyPoolStorage + Send + Sync + 'static,
{
type Error = KeyPoolError<S::Error, C::Error>;
async fn execute<A>(
&self,
client: &C,
mut request: ApiRequest<A>,
id: Option<String>,
) -> Result<A::Response, Self::Error>
where
A: ApiSelection,
{
if request.comment.is_none() {
request.comment = self.options.comment.clone();
}
if let Some(hook) = self.options.hooks_before.get(&std::any::TypeId::of::<A>()) {
let concrete = hook
.downcast_ref::<BeforeHook<A, S::Key, S::Domain>>()
.unwrap();
(concrete.body)(&mut request, &self.selector);
}
loop {
let key = self
.storage
.acquire_key(self.selector.clone())
.await
.map_err(KeyPoolError::Storage)?;
let url = request.url(key.value(), id.as_deref());
let value = client.request(url).await?;
match ApiResponse::from_value(value) {
Err(ResponseError::Api { code, reason }) => {
if !self
.storage
.flag_key(key, code)
.await
.map_err(KeyPoolError::Storage)?
{
return Err(KeyPoolError::Response(ResponseError::Api { code, reason }));
}
}
Err(parsing_error) => return Err(KeyPoolError::Response(parsing_error)),
Ok(res) => {
let res = res.into();
if let Some(hook) = self.options.hooks_after.get(&std::any::TypeId::of::<A>()) {
let concrete = hook
.downcast_ref::<AfterHook<A, S::Key, S::Domain>>()
.unwrap();
match (concrete.body)(&res, &self.selector) {
Err(KeyAction::Delete) => {
self.storage
.remove_key(key.selector())
.await
.map_err(KeyPoolError::Storage)?;
continue;
}
Err(KeyAction::RemoveDomain(domain)) => {
self.storage
.remove_domain_from_key(key.selector(), domain)
.await
.map_err(KeyPoolError::Storage)?;
continue;
}
_ => (),
};
}
return Ok(res);
}
};
}
}
async fn execute_many<A, I>(
&self,
client: &C,
mut request: ApiRequest<A>,
ids: Vec<I>,
) -> HashMap<I, Result<A::Response, Self::Error>>
where
A: ApiSelection,
I: ToString + std::hash::Hash + std::cmp::Eq + Send + Sync,
{
let keys = match self
.storage
.acquire_many_keys(self.selector.clone(), ids.len() as i64)
.await
{
Ok(keys) => keys,
Err(why) => {
return ids
.into_iter()
.map(|i| (i, Err(Self::Error::Storage(why.clone()))))
.collect();
}
};
if request.comment.is_none() {
request.comment = self.options.comment.clone();
}
let request_ref = &request;
let tuples =
futures::future::join_all(std::iter::zip(ids, keys).map(|(id, mut key)| async move {
let id_string = id.to_string();
loop {
let url = request_ref.url(key.value(), Some(&id_string));
let value = match client.request(url).await {
Ok(v) => v,
Err(why) => return (id, Err(Self::Error::Client(why))),
};
match ApiResponse::from_value(value) {
Err(ResponseError::Api { code, reason }) => {
match self.storage.flag_key(key, code).await {
Ok(false) => {
return (
id,
Err(KeyPoolError::Response(ResponseError::Api {
code,
reason,
})),
)
}
Ok(true) => (),
Err(why) => return (id, Err(KeyPoolError::Storage(why))),
}
}
Err(parsing_error) => {
return (id, Err(KeyPoolError::Response(parsing_error)))
}
Ok(res) => return (id, Ok(res.into())),
};
key = match self.storage.acquire_key(self.selector.clone()).await {
Ok(k) => k,
Err(why) => return (id, Err(Self::Error::Storage(why))),
};
}
}))
.await;
HashMap::from_iter(tuples)
}
}
#[allow(clippy::type_complexity)]
pub struct BeforeHook<A, K, D>
where
A: ApiSelection,
K: ApiKey,
D: KeyDomain,
{
body: Box<dyn Fn(&mut ApiRequest<A>, &KeySelector<K, D>) + Send + Sync + 'static>,
}
#[allow(clippy::type_complexity)]
pub struct AfterHook<A, K, D>
where
A: ApiSelection,
K: ApiKey,
D: KeyDomain,
{
body: Box<
dyn Fn(&A::Response, &KeySelector<K, D>) -> Result<(), crate::KeyAction<D>>
+ Send
+ Sync
+ 'static,
>,
}
pub struct PoolBuilder<C, S>
where
C: ApiClient,
S: KeyPoolStorage,
{
client: C,
storage: S,
options: crate::PoolOptions,
}
impl<C, S> PoolBuilder<C, S>
where
C: ApiClient,
S: KeyPoolStorage,
{
pub fn new(client: C, storage: S) -> Self {
Self {
client,
storage,
options: Default::default(),
}
}
pub fn comment(mut self, c: impl ToString) -> Self {
self.options.comment = Some(c.to_string());
self
}
pub fn hook_before<A>(
mut self,
hook: impl Fn(&mut ApiRequest<A>, &KeySelector<S::Key, S::Domain>) + Send + Sync + 'static,
) -> Self
where
A: ApiSelection + 'static,
{
self.options.hooks_before.insert(
std::any::TypeId::of::<A>(),
Box::new(BeforeHook {
body: Box::new(hook),
}),
);
self
}
pub fn hook_after<A>(
mut self,
hook: impl Fn(&A::Response, &KeySelector<S::Key, S::Domain>) -> Result<(), KeyAction<S::Domain>>
+ Send
+ Sync
+ 'static,
) -> Self
where
A: ApiSelection + 'static,
{
self.options.hooks_after.insert(
std::any::TypeId::of::<A>(),
Box::new(AfterHook::<A, S::Key, S::Domain> {
body: Box::new(hook),
}),
);
self
}
pub fn build(self) -> KeyPool<C, S> {
KeyPool {
client: self.client,
storage: self.storage,
options: Arc::new(self.options),
}
}
}
#[derive(Clone, Debug)]
pub struct KeyPool<C, S>
where
C: ApiClient,
S: KeyPoolStorage,
{
pub client: C,
pub storage: S,
pub options: Arc<PoolOptions>,
}
impl<C, S> KeyPool<C, S>
where
C: ApiClient,
S: KeyPoolStorage + Send + Sync + 'static,
{
pub fn torn_api<I>(&self, selector: I) -> ApiProvider<C, KeyPoolExecutor<C, S>>
where
I: IntoSelector<S::Key, S::Domain>,
{
ApiProvider::new(
&self.client,
KeyPoolExecutor::new(
&self.storage,
selector.into_selector(),
self.options.clone(),
),
)
}
}
pub trait WithStorage {
fn with_storage<'a, S, I>(
&'a self,
storage: &'a S,
selector: I,
) -> ApiProvider<Self, KeyPoolExecutor<Self, S>>
where
Self: ApiClient + Sized,
S: KeyPoolStorage + Send + Sync + 'static,
I: IntoSelector<S::Key, S::Domain>,
{
ApiProvider::new(
self,
KeyPoolExecutor::new(storage, selector.into_selector(), Default::default()),
)
}
}
#[cfg(feature = "reqwest")]
impl WithStorage for reqwest::Client {}
#[cfg(all(test, feature = "postgres", feature = "reqwest"))]
mod test {
use sqlx::PgPool;
use super::*;
use crate::{
postgres::test::{setup, Domain},
KeySelector,
};
#[sqlx::test]
async fn test_pool_request(pool: PgPool) {
let (storage, _) = setup(pool).await;
let pool = PoolBuilder::new(reqwest::Client::default(), storage)
.comment("api.rs")
.build();
let response = pool.torn_api(Domain::All).user(|b| b).await.unwrap();
_ = response.profile().unwrap();
}
#[sqlx::test]
async fn test_with_storage_request(pool: PgPool) {
let (storage, _) = setup(pool).await;
let response = reqwest::Client::new()
.with_storage(&storage, Domain::All)
.user(|b| b)
.await
.unwrap();
_ = response.profile().unwrap();
}
#[sqlx::test]
async fn before_hook(pool: PgPool) {
let (storage, _) = setup(pool).await;
let pool = PoolBuilder::new(reqwest::Client::default(), storage)
.hook_before::<torn_api::user::UserSelection>(|req, _s| {
req.selections.push("crimes");
})
.build();
let response = pool.torn_api(Domain::All).user(|b| b).await.unwrap();
_ = response.crimes().unwrap();
}
#[sqlx::test]
async fn after_hook(pool: PgPool) {
let (storage, _) = setup(pool).await;
let pool = PoolBuilder::new(reqwest::Client::default(), storage)
.hook_after::<torn_api::user::UserSelection>(|_res, _s| Err(KeyAction::Delete))
.build();
let key = pool.storage.read_key(KeySelector::Id(1)).await.unwrap();
assert!(key.is_some());
let response = pool.torn_api(Domain::All).user(|b| b).await;
assert!(matches!(response, Err(KeyPoolError::Storage(_))));
let key = pool.storage.read_key(KeySelector::Id(1)).await.unwrap();
assert!(key.is_none());
}
}