feat(key-pool): updated key pool to use v2 api
This commit is contained in:
parent
5ae490c756
commit
254ab9c509
1267
Cargo.lock
generated
1267
Cargo.lock
generated
File diff suppressed because it is too large
Load diff
|
@ -1,6 +1,6 @@
|
|||
[workspace]
|
||||
resolver = "2"
|
||||
members = ["torn-api", "torn-api-codegen"]
|
||||
members = ["torn-api", "torn-api-codegen", "torn-key-pool"]
|
||||
|
||||
[workspace.package]
|
||||
license-file = "./LICENSE"
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
"info": {
|
||||
"title": "Torn API",
|
||||
"description": "\n * The development of Torn's API v2 is still ongoing.\n * If selections remain unaltered, they will default to the API v1 version.\n * Unlike API v1, API v2 accepts both selections and IDs as path and query parameters.\n * If any discrepancies or errors are found, please submit a [bug report](https://www.torn.com/forums.php#/p=forums&f=19&b=0&a=0) on the Torn Forums.\n * In case you're using bots to check for changes on openapi.json file, make sure to specificy a custom user-agent header - CloudFlare sometimes prevents requests from default user-agents.",
|
||||
"version": "1.3.1"
|
||||
"version": "1.3.2"
|
||||
},
|
||||
"servers": [
|
||||
{
|
||||
|
@ -8773,7 +8773,8 @@
|
|||
"attacksdamaging",
|
||||
"attacksrunaway",
|
||||
"highestterritories",
|
||||
"territoryrespect"
|
||||
"territoryrespect",
|
||||
"membersamount"
|
||||
]
|
||||
},
|
||||
"FactionBranchStateEnum": {
|
||||
|
@ -13864,6 +13865,8 @@
|
|||
"type": "object"
|
||||
},
|
||||
"RacingRaceDetailsResponse": {
|
||||
"properties": {
|
||||
"race": {
|
||||
"allOf": [
|
||||
{
|
||||
"required": [
|
||||
|
@ -13883,6 +13886,9 @@
|
|||
"$ref": "#/components/schemas/Race"
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"type": "object"
|
||||
},
|
||||
"RacingSelectionName": {
|
||||
"type": "string",
|
||||
|
@ -15352,7 +15358,14 @@
|
|||
"searchforcash",
|
||||
"shoplifting",
|
||||
"stats",
|
||||
"stocks"
|
||||
"stocks",
|
||||
"chainreport",
|
||||
"rackets",
|
||||
"rankedwarreport",
|
||||
"rankedwars",
|
||||
"territorynames",
|
||||
"territorywarreport",
|
||||
"territorywars"
|
||||
]
|
||||
},
|
||||
"TornLookupResponse": {
|
||||
|
@ -19062,11 +19075,8 @@
|
|||
"format": "int64"
|
||||
},
|
||||
"rewards": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/components/schemas/UserCrimeUniquesReward"
|
||||
}
|
||||
}
|
||||
},
|
||||
"type": "object"
|
||||
},
|
||||
|
@ -19139,6 +19149,9 @@
|
|||
},
|
||||
{
|
||||
"$ref": "#/components/schemas/UserCrimeDetailsScamming"
|
||||
},
|
||||
{
|
||||
"type": "null"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
@ -19181,8 +19194,15 @@
|
|||
"$ref": "#/components/schemas/RaceCarId"
|
||||
},
|
||||
"name": {
|
||||
"oneOf": [
|
||||
{
|
||||
"type": "string"
|
||||
},
|
||||
{
|
||||
"type": "null"
|
||||
}
|
||||
]
|
||||
},
|
||||
"worth": {
|
||||
"type": "integer",
|
||||
"format": "int64"
|
||||
|
@ -19421,11 +19441,8 @@
|
|||
],
|
||||
"properties": {
|
||||
"hof": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/components/schemas/UserHofStats"
|
||||
}
|
||||
}
|
||||
},
|
||||
"type": "object"
|
||||
},
|
||||
|
|
|
@ -3,41 +3,43 @@ name = "torn-key-pool"
|
|||
version = "0.9.0"
|
||||
edition = "2021"
|
||||
authors = ["Pyrit [2111649]"]
|
||||
license = "MIT"
|
||||
repository = "https://github.com/TotallyNot/torn-api.rs.git"
|
||||
homepage = "https://github.com/TotallyNot/torn-api.rs.git"
|
||||
license-file = { workspace = true }
|
||||
repository = { workspace = true }
|
||||
homepage = { workspace = true }
|
||||
description = "A generalised API key pool for torn-api"
|
||||
|
||||
[features]
|
||||
default = ["postgres", "tokio-runtime"]
|
||||
postgres = [ "dep:sqlx", "dep:chrono", "dep:indoc", "dep:serde" ]
|
||||
reqwest = [ "dep:reqwest", "torn-api/reqwest" ]
|
||||
awc = [ "dep:awc", "torn-api/awc" ]
|
||||
postgres = ["dep:sqlx", "dep:chrono", "dep:indoc"]
|
||||
tokio-runtime = ["dep:tokio", "dep:rand"]
|
||||
actix-runtime = [ "dep:actix-rt", "dep:rand" ]
|
||||
|
||||
[dependencies]
|
||||
torn-api = { path = "../torn-api", default-features = false, version = "0.7" }
|
||||
async-trait = "0.1"
|
||||
torn-api = { path = "../torn-api", default-features = false, version = "1.0.1" }
|
||||
thiserror = "2"
|
||||
|
||||
sqlx = { version = "0.8", features = [ "postgres", "chrono", "json", "derive" ], optional = true, default-features = false }
|
||||
serde = { version = "1.0", optional = true }
|
||||
sqlx = { version = "0.8", features = [
|
||||
"postgres",
|
||||
"chrono",
|
||||
"json",
|
||||
"derive",
|
||||
], optional = true, default-features = false }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
chrono = { version = "0.4", optional = true }
|
||||
indoc = { version = "2", optional = true }
|
||||
tokio = { version = "1", optional = true, default-features = false, features = ["time"] }
|
||||
actix-rt = { version = "2", optional = true, default-features = false }
|
||||
rand = { version = "0.8", optional = true }
|
||||
tokio = { version = "1", optional = true, default-features = false, features = [
|
||||
"time",
|
||||
] }
|
||||
rand = { version = "0.9", optional = true }
|
||||
futures = "0.3"
|
||||
|
||||
reqwest = { version = "0.12", default-features = false, features = [ "json" ], optional = true }
|
||||
awc = { version = "3", default-features = false, optional = true }
|
||||
reqwest = { version = "0.12", default-features = false, features = [
|
||||
"brotli",
|
||||
"http2",
|
||||
"rustls-tls-webpki-roots",
|
||||
] }
|
||||
|
||||
[dev-dependencies]
|
||||
torn-api = { path = "../torn-api", features = [ "reqwest" ] }
|
||||
torn-api = { path = "../torn-api" }
|
||||
sqlx = { version = "0.8", features = ["runtime-tokio-rustls"] }
|
||||
dotenvy = "0.15"
|
||||
tokio = { version = "1.42", features = ["rt"] }
|
||||
tokio-test = "0.4"
|
||||
reqwest = { version = "0.12", default-features = true }
|
||||
awc = { version = "3", features = [ "rustls" ] }
|
||||
|
|
|
@ -3,48 +3,23 @@
|
|||
#[cfg(feature = "postgres")]
|
||||
pub mod postgres;
|
||||
|
||||
// pub mod local;
|
||||
pub mod send;
|
||||
use std::{collections::HashMap, future::Future, sync::Arc, time::Duration};
|
||||
|
||||
use std::sync::Arc;
|
||||
use futures::{future::BoxFuture, FutureExt};
|
||||
use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION};
|
||||
use serde::Deserialize;
|
||||
use torn_api::{
|
||||
executor::Executor,
|
||||
request::{ApiRequest, ApiResponse},
|
||||
ApiError,
|
||||
};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use thiserror::Error;
|
||||
pub trait ApiKeyId: Clone + PartialEq + Eq + std::hash::Hash + Send + Sync {}
|
||||
|
||||
use torn_api::ResponseError;
|
||||
impl<T> ApiKeyId for T where T: Clone + PartialEq + Eq + std::hash::Hash + Send + Sync {}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum KeyPoolError<S, C>
|
||||
where
|
||||
S: std::error::Error + Clone,
|
||||
C: std::error::Error,
|
||||
{
|
||||
#[error("Key pool storage driver error: {0:?}")]
|
||||
Storage(#[source] S),
|
||||
|
||||
#[error(transparent)]
|
||||
Client(#[from] C),
|
||||
|
||||
#[error(transparent)]
|
||||
Response(ResponseError),
|
||||
}
|
||||
|
||||
impl<S, C> KeyPoolError<S, C>
|
||||
where
|
||||
S: std::error::Error + Clone,
|
||||
C: std::error::Error,
|
||||
{
|
||||
#[inline(always)]
|
||||
pub fn api_code(&self) -> Option<u8> {
|
||||
match self {
|
||||
Self::Response(why) => why.api_code(),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub trait ApiKey: Sync + Send + std::fmt::Debug + Clone + 'static {
|
||||
type IdType: PartialEq + Eq + std::hash::Hash + Send + Sync + std::fmt::Debug + Clone;
|
||||
pub trait ApiKey: Send + Sync + Clone + 'static {
|
||||
type IdType: ApiKeyId;
|
||||
|
||||
fn value(&self) -> &str;
|
||||
|
||||
|
@ -105,7 +80,7 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
pub trait IntoSelector<K, D>: Send + Sync
|
||||
pub trait IntoSelector<K, D>: Send
|
||||
where
|
||||
K: ApiKey,
|
||||
D: KeyDomain,
|
||||
|
@ -133,114 +108,347 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
pub enum KeyAction<D>
|
||||
where
|
||||
D: KeyDomain,
|
||||
{
|
||||
Delete,
|
||||
RemoveDomain(D),
|
||||
Timeout(chrono::Duration),
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
pub trait KeyPoolStorage {
|
||||
pub trait KeyPoolStorage: Send + Sync {
|
||||
type Key: ApiKey;
|
||||
type Domain: KeyDomain;
|
||||
type Error: std::error::Error + Sync + Send + Clone;
|
||||
type Error: From<reqwest::Error> + From<serde_json::Error> + From<torn_api::ApiError> + Send;
|
||||
|
||||
async fn acquire_key<S>(&self, selector: S) -> Result<Self::Key, Self::Error>
|
||||
fn acquire_key<S>(
|
||||
&self,
|
||||
selector: S,
|
||||
) -> impl Future<Output = Result<Self::Key, Self::Error>> + Send
|
||||
where
|
||||
S: IntoSelector<Self::Key, Self::Domain>;
|
||||
|
||||
async fn acquire_many_keys<S>(
|
||||
fn acquire_many_keys<S>(
|
||||
&self,
|
||||
selector: S,
|
||||
number: i64,
|
||||
) -> Result<Vec<Self::Key>, Self::Error>
|
||||
) -> impl Future<Output = Result<Vec<Self::Key>, Self::Error>> + Send
|
||||
where
|
||||
S: IntoSelector<Self::Key, Self::Domain>;
|
||||
|
||||
async fn flag_key(&self, key: Self::Key, code: u8) -> Result<bool, Self::Error>;
|
||||
|
||||
async fn store_key(
|
||||
fn store_key(
|
||||
&self,
|
||||
user_id: i32,
|
||||
key: String,
|
||||
domains: Vec<Self::Domain>,
|
||||
) -> Result<Self::Key, Self::Error>;
|
||||
) -> impl Future<Output = Result<Self::Key, Self::Error>> + Send;
|
||||
|
||||
async fn read_key<S>(&self, selector: S) -> Result<Option<Self::Key>, Self::Error>
|
||||
fn read_key<S>(
|
||||
&self,
|
||||
selector: S,
|
||||
) -> impl Future<Output = Result<Option<Self::Key>, Self::Error>> + Send
|
||||
where
|
||||
S: IntoSelector<Self::Key, Self::Domain>;
|
||||
|
||||
async fn read_keys<S>(&self, selector: S) -> Result<Vec<Self::Key>, Self::Error>
|
||||
fn read_keys<S>(
|
||||
&self,
|
||||
selector: S,
|
||||
) -> impl Future<Output = Result<Vec<Self::Key>, Self::Error>> + Send
|
||||
where
|
||||
S: IntoSelector<Self::Key, Self::Domain>;
|
||||
|
||||
async fn remove_key<S>(&self, selector: S) -> Result<Self::Key, Self::Error>
|
||||
fn remove_key<S>(
|
||||
&self,
|
||||
selector: S,
|
||||
) -> impl Future<Output = Result<Self::Key, Self::Error>> + Send
|
||||
where
|
||||
S: IntoSelector<Self::Key, Self::Domain>;
|
||||
|
||||
async fn add_domain_to_key<S>(
|
||||
fn add_domain_to_key<S>(
|
||||
&self,
|
||||
selector: S,
|
||||
domain: Self::Domain,
|
||||
) -> Result<Self::Key, Self::Error>
|
||||
) -> impl Future<Output = Result<Self::Key, Self::Error>> + Send
|
||||
where
|
||||
S: IntoSelector<Self::Key, Self::Domain>;
|
||||
|
||||
async fn remove_domain_from_key<S>(
|
||||
fn remove_domain_from_key<S>(
|
||||
&self,
|
||||
selector: S,
|
||||
domain: Self::Domain,
|
||||
) -> Result<Self::Key, Self::Error>
|
||||
) -> impl Future<Output = Result<Self::Key, Self::Error>> + Send
|
||||
where
|
||||
S: IntoSelector<Self::Key, Self::Domain>;
|
||||
|
||||
async fn set_domains_for_key<S>(
|
||||
fn set_domains_for_key<S>(
|
||||
&self,
|
||||
selector: S,
|
||||
domains: Vec<Self::Domain>,
|
||||
) -> Result<Self::Key, Self::Error>
|
||||
) -> impl Future<Output = Result<Self::Key, Self::Error>> + Send
|
||||
where
|
||||
S: IntoSelector<Self::Key, Self::Domain>;
|
||||
|
||||
fn timeout_key<S>(
|
||||
&self,
|
||||
selector: S,
|
||||
duration: Duration,
|
||||
) -> impl Future<Output = Result<(), Self::Error>> + Send
|
||||
where
|
||||
S: IntoSelector<Self::Key, Self::Domain>;
|
||||
}
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
pub struct PoolOptions {
|
||||
#[derive(Default)]
|
||||
pub struct PoolOptions<S>
|
||||
where
|
||||
S: KeyPoolStorage,
|
||||
{
|
||||
comment: Option<String>,
|
||||
hooks_before: std::collections::HashMap<std::any::TypeId, Box<dyn std::any::Any + Send + Sync>>,
|
||||
hooks_after: std::collections::HashMap<std::any::TypeId, Box<dyn std::any::Any + Send + Sync>>,
|
||||
#[allow(clippy::type_complexity)]
|
||||
error_hooks: HashMap<
|
||||
u16,
|
||||
Box<
|
||||
dyn for<'a> Fn(&'a S, &'a S::Key) -> BoxFuture<'a, Result<bool, S::Error>>
|
||||
+ Send
|
||||
+ Sync,
|
||||
>,
|
||||
>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct KeyPoolExecutor<'a, C, S>
|
||||
pub struct KeyPoolExecutor<'p, S>
|
||||
where
|
||||
S: KeyPoolStorage,
|
||||
{
|
||||
storage: &'a S,
|
||||
options: Arc<PoolOptions>,
|
||||
pool: &'p KeyPool<S>,
|
||||
selector: KeySelector<S::Key, S::Domain>,
|
||||
_marker: std::marker::PhantomData<C>,
|
||||
}
|
||||
|
||||
impl<'a, C, S> KeyPoolExecutor<'a, C, S>
|
||||
impl<'p, S> KeyPoolExecutor<'p, S>
|
||||
where
|
||||
S: KeyPoolStorage,
|
||||
{
|
||||
pub fn new(
|
||||
storage: &'a S,
|
||||
selector: KeySelector<S::Key, S::Domain>,
|
||||
options: Arc<PoolOptions>,
|
||||
) -> Self {
|
||||
pub fn new(pool: &'p KeyPool<S>, selector: KeySelector<S::Key, S::Domain>) -> Self {
|
||||
Self { pool, selector }
|
||||
}
|
||||
|
||||
async fn execute_request<D>(&self, request: ApiRequest<D>) -> Result<ApiResponse<D>, S::Error>
|
||||
where
|
||||
D: Send,
|
||||
{
|
||||
let key = self.pool.storage.acquire_key(self.selector.clone()).await?;
|
||||
|
||||
let mut headers = HeaderMap::with_capacity(1);
|
||||
headers.insert(
|
||||
AUTHORIZATION,
|
||||
HeaderValue::from_str(&format!("ApiKey {}", key.value())).unwrap(),
|
||||
);
|
||||
|
||||
let resp = self
|
||||
.pool
|
||||
.client
|
||||
.get(request.url())
|
||||
.headers(headers)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
let status = resp.status();
|
||||
|
||||
let bytes = resp.bytes().await?;
|
||||
|
||||
if let Some(err) = decode_error(&bytes)? {
|
||||
if let Some(handler) = self.pool.options.error_hooks.get(&err.code()) {
|
||||
let retry = (*handler)(&self.pool.storage, &key).await?;
|
||||
|
||||
if retry {
|
||||
return Box::pin(self.execute_request(request)).await;
|
||||
}
|
||||
}
|
||||
Err(err.into())
|
||||
} else {
|
||||
Ok(ApiResponse {
|
||||
discriminant: request.disriminant,
|
||||
body: Some(bytes),
|
||||
status,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct PoolBuilder<S>
|
||||
where
|
||||
S: KeyPoolStorage,
|
||||
{
|
||||
client: reqwest::Client,
|
||||
storage: S,
|
||||
options: crate::PoolOptions<S>,
|
||||
}
|
||||
|
||||
impl<S> PoolBuilder<S>
|
||||
where
|
||||
S: KeyPoolStorage,
|
||||
{
|
||||
pub fn new(storage: S) -> Self {
|
||||
Self {
|
||||
client: reqwest::Client::builder()
|
||||
.brotli(true)
|
||||
.http2_keep_alive_timeout(Duration::from_secs(60))
|
||||
.http2_keep_alive_interval(Duration::from_secs(5))
|
||||
.https_only(true)
|
||||
.build()
|
||||
.unwrap(),
|
||||
storage,
|
||||
selector,
|
||||
options,
|
||||
_marker: std::marker::PhantomData,
|
||||
options: PoolOptions {
|
||||
comment: None,
|
||||
error_hooks: Default::default(),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
pub fn comment(mut self, c: impl ToString) -> Self {
|
||||
self.options.comment = Some(c.to_string());
|
||||
self
|
||||
}
|
||||
|
||||
pub fn error_hook<F>(mut self, code: u16, handler: F) -> Self
|
||||
where
|
||||
F: for<'a> Fn(&'a S, &'a S::Key) -> BoxFuture<'a, Result<bool, S::Error>>
|
||||
+ Send
|
||||
+ Sync
|
||||
+ 'static,
|
||||
{
|
||||
self.options.error_hooks.insert(code, Box::new(handler));
|
||||
|
||||
self
|
||||
}
|
||||
|
||||
pub fn use_default_hooks(self) -> Self {
|
||||
self.error_hook(2, |storage, key| {
|
||||
async move {
|
||||
storage.remove_key(KeySelector::Id(key.id())).await?;
|
||||
Ok(true)
|
||||
}
|
||||
.boxed()
|
||||
})
|
||||
.error_hook(5, |storage, key| {
|
||||
async move {
|
||||
storage
|
||||
.timeout_key(KeySelector::Id(key.id()), Duration::from_secs(60))
|
||||
.await?;
|
||||
Ok(true)
|
||||
}
|
||||
.boxed()
|
||||
})
|
||||
.error_hook(10, |storage, key| {
|
||||
async move {
|
||||
storage.remove_key(KeySelector::Id(key.id())).await?;
|
||||
Ok(true)
|
||||
}
|
||||
.boxed()
|
||||
})
|
||||
.error_hook(13, |storage, key| {
|
||||
async move {
|
||||
storage
|
||||
.timeout_key(KeySelector::Id(key.id()), Duration::from_secs(24 * 3_600))
|
||||
.await?;
|
||||
Ok(true)
|
||||
}
|
||||
.boxed()
|
||||
})
|
||||
.error_hook(18, |storage, key| {
|
||||
async move {
|
||||
storage
|
||||
.timeout_key(KeySelector::Id(key.id()), Duration::from_secs(24 * 3_600))
|
||||
.await?;
|
||||
Ok(true)
|
||||
}
|
||||
.boxed()
|
||||
})
|
||||
}
|
||||
|
||||
pub fn build(self) -> KeyPool<S> {
|
||||
KeyPool {
|
||||
client: self.client,
|
||||
storage: self.storage,
|
||||
options: Arc::new(self.options),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(all(test, feature = "postgres"))]
|
||||
mod test {}
|
||||
pub struct KeyPool<S>
|
||||
where
|
||||
S: KeyPoolStorage,
|
||||
{
|
||||
pub client: reqwest::Client,
|
||||
pub storage: S,
|
||||
pub options: Arc<PoolOptions<S>>,
|
||||
}
|
||||
|
||||
impl<S> KeyPool<S>
|
||||
where
|
||||
S: KeyPoolStorage + Send + Sync + 'static,
|
||||
{
|
||||
pub fn torn_api<I>(&self, selector: I) -> KeyPoolExecutor<S>
|
||||
where
|
||||
I: IntoSelector<S::Key, S::Domain>,
|
||||
{
|
||||
KeyPoolExecutor::new(self, selector.into_selector())
|
||||
}
|
||||
}
|
||||
|
||||
fn decode_error(buf: &[u8]) -> Result<Option<ApiError>, serde_json::Error> {
|
||||
if buf.starts_with(br#"{"error":{"#) {
|
||||
#[derive(Deserialize)]
|
||||
struct ErrorBody<'a> {
|
||||
code: u16,
|
||||
error: &'a str,
|
||||
}
|
||||
#[derive(Deserialize)]
|
||||
struct ErrorContainer<'a> {
|
||||
#[serde(borrow)]
|
||||
error: ErrorBody<'a>,
|
||||
}
|
||||
|
||||
let error: ErrorContainer = serde_json::from_slice(buf)?;
|
||||
Ok(Some(crate::ApiError::new(
|
||||
error.error.code,
|
||||
error.error.error,
|
||||
)))
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> Executor for KeyPoolExecutor<'_, S>
|
||||
where
|
||||
S: KeyPoolStorage,
|
||||
{
|
||||
type Error = S::Error;
|
||||
|
||||
async fn execute<R>(
|
||||
&self,
|
||||
request: R,
|
||||
) -> Result<torn_api::request::ApiResponse<R::Discriminant>, Self::Error>
|
||||
where
|
||||
R: torn_api::request::IntoRequest,
|
||||
{
|
||||
let request = request.into_request();
|
||||
|
||||
self.execute_request(request).await
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use torn_api::executor::ExecutorExt;
|
||||
|
||||
use crate::postgres;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[sqlx::test]
|
||||
fn name(pool: sqlx::PgPool) {
|
||||
let (storage, _) = postgres::test::setup(pool).await;
|
||||
|
||||
let pool = PoolBuilder::new(storage)
|
||||
.use_default_hooks()
|
||||
.comment("test_runner")
|
||||
.build();
|
||||
|
||||
pool.torn_api(postgres::test::Domain::All)
|
||||
.faction()
|
||||
.basic(|b| b)
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,206 +0,0 @@
|
|||
use std::{collections::HashMap, sync::Arc};
|
||||
|
||||
use async_trait::async_trait;
|
||||
|
||||
use torn_api::{
|
||||
local::{ApiClient, ApiProvider, RequestExecutor},
|
||||
ApiRequest, ApiResponse, ApiSelection, ResponseError,
|
||||
};
|
||||
|
||||
use crate::{ApiKey, KeyPoolError, KeyPoolExecutor, KeyPoolStorage, IntoSelector};
|
||||
|
||||
#[async_trait(?Send)]
|
||||
impl<'client, C, S> RequestExecutor<C> for KeyPoolExecutor<'client, C, S>
|
||||
where
|
||||
C: ApiClient,
|
||||
S: KeyPoolStorage + 'static,
|
||||
{
|
||||
type Error = KeyPoolError<S::Error, C::Error>;
|
||||
|
||||
async fn execute<A>(
|
||||
&self,
|
||||
client: &C,
|
||||
mut request: ApiRequest<A>,
|
||||
id: Option<String>,
|
||||
) -> Result<ApiResponse, Self::Error>
|
||||
where
|
||||
A: ApiSelection,
|
||||
{
|
||||
request.comment = self.comment.map(ToOwned::to_owned);
|
||||
loop {
|
||||
let key = self
|
||||
.storage
|
||||
.acquire_key(self.selector.clone())
|
||||
.await
|
||||
.map_err(|e| KeyPoolError::Storage(Arc::new(e)))?;
|
||||
let url = request.url(key.value(), id.as_deref());
|
||||
let value = client.request(url).await?;
|
||||
|
||||
match ApiResponse::from_value(value) {
|
||||
Err(ResponseError::Api { code, reason }) => {
|
||||
if !self
|
||||
.storage
|
||||
.flag_key(key, code)
|
||||
.await
|
||||
.map_err(Arc::new)
|
||||
.map_err(KeyPoolError::Storage)?
|
||||
{
|
||||
return Err(KeyPoolError::Response(ResponseError::Api { code, reason }));
|
||||
}
|
||||
}
|
||||
Err(parsing_error) => return Err(KeyPoolError::Response(parsing_error)),
|
||||
Ok(res) => return Ok(res),
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
async fn execute_many<A, I>(
|
||||
&self,
|
||||
client: &C,
|
||||
mut request: ApiRequest<A>,
|
||||
ids: Vec<I>,
|
||||
) -> HashMap<I, Result<ApiResponse, Self::Error>>
|
||||
where
|
||||
A: ApiSelection,
|
||||
I: ToString + std::hash::Hash + std::cmp::Eq,
|
||||
{
|
||||
let keys = match self
|
||||
.storage
|
||||
.acquire_many_keys(self.selector.clone(), ids.len() as i64)
|
||||
.await
|
||||
{
|
||||
Ok(keys) => keys,
|
||||
Err(why) => {
|
||||
let shared = Arc::new(why);
|
||||
return ids
|
||||
.into_iter()
|
||||
.map(|i| (i, Err(Self::Error::Storage(shared.clone()))))
|
||||
.collect();
|
||||
}
|
||||
};
|
||||
|
||||
request.comment = self.comment.map(ToOwned::to_owned);
|
||||
let request_ref = &request;
|
||||
|
||||
let tuples =
|
||||
futures::future::join_all(std::iter::zip(ids, keys).map(|(id, mut key)| async move {
|
||||
let id_string = id.to_string();
|
||||
loop {
|
||||
let url = request_ref.url(key.value(), Some(&id_string));
|
||||
let value = match client.request(url).await {
|
||||
Ok(v) => v,
|
||||
Err(why) => return (id, Err(Self::Error::Client(why))),
|
||||
};
|
||||
|
||||
match ApiResponse::from_value(value) {
|
||||
Err(ResponseError::Api { code, reason }) => {
|
||||
match self.storage.flag_key(key, code).await {
|
||||
Ok(false) => {
|
||||
return (
|
||||
id,
|
||||
Err(KeyPoolError::Response(ResponseError::Api {
|
||||
code,
|
||||
reason,
|
||||
})),
|
||||
)
|
||||
}
|
||||
Ok(true) => (),
|
||||
Err(why) => return (id, Err(KeyPoolError::Storage(Arc::new(why)))),
|
||||
}
|
||||
}
|
||||
Err(parsing_error) => {
|
||||
return (id, Err(KeyPoolError::Response(parsing_error)))
|
||||
}
|
||||
Ok(res) => return (id, Ok(res)),
|
||||
};
|
||||
|
||||
key = match self.storage.acquire_key(self.selector.clone()).await {
|
||||
Ok(k) => k,
|
||||
Err(why) => return (id, Err(Self::Error::Storage(Arc::new(why)))),
|
||||
};
|
||||
}
|
||||
}))
|
||||
.await;
|
||||
|
||||
HashMap::from_iter(tuples)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct KeyPool<C, S>
|
||||
where
|
||||
C: ApiClient,
|
||||
S: KeyPoolStorage,
|
||||
{
|
||||
client: C,
|
||||
pub storage: S,
|
||||
comment: Option<String>,
|
||||
}
|
||||
|
||||
impl<C, S> KeyPool<C, S>
|
||||
where
|
||||
C: ApiClient,
|
||||
S: KeyPoolStorage + 'static,
|
||||
{
|
||||
pub fn new(client: C, storage: S, comment: Option<String>) -> Self {
|
||||
Self {
|
||||
client,
|
||||
storage,
|
||||
comment,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn torn_api<I>(&self, selector: I) -> ApiProvider<C, KeyPoolExecutor<C, S>> where I: IntoSelector<S::Key, S::Domain> {
|
||||
ApiProvider::new(
|
||||
&self.client,
|
||||
KeyPoolExecutor::new(&self.storage, selector.into_selector(), self.comment.as_deref()),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
pub trait WithStorage {
|
||||
fn with_storage<'a, S, I>(
|
||||
&'a self,
|
||||
storage: &'a S,
|
||||
selector: I
|
||||
) -> ApiProvider<Self, KeyPoolExecutor<Self, S>>
|
||||
where
|
||||
Self: ApiClient + Sized,
|
||||
S: KeyPoolStorage + 'static,
|
||||
I: IntoSelector<S::Key, S::Domain>
|
||||
{
|
||||
ApiProvider::new(self, KeyPoolExecutor::new(storage, selector.into_selector(), None))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "awc")]
|
||||
impl WithStorage for awc::Client {}
|
||||
|
||||
#[cfg(all(test, feature = "postgres", feature = "awc"))]
|
||||
mod test {
|
||||
use tokio::test;
|
||||
|
||||
use super::*;
|
||||
use crate::postgres::test::{setup, Domain};
|
||||
|
||||
#[test]
|
||||
async fn test_pool_request() {
|
||||
let storage = setup().await;
|
||||
let pool = KeyPool::new(awc::Client::default(), storage);
|
||||
|
||||
let response = pool.torn_api(Domain::All).user(|b| b).await.unwrap();
|
||||
_ = response.profile().unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
async fn test_with_storage_request() {
|
||||
let storage = setup().await;
|
||||
|
||||
let response = awc::Client::new()
|
||||
.with_storage(&storage, Domain::All)
|
||||
.user(|b| b)
|
||||
.await
|
||||
.unwrap();
|
||||
_ = response.profile().unwrap();
|
||||
}
|
||||
}
|
|
@ -1,6 +1,4 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use futures::future::BoxFuture;
|
||||
use indoc::indoc;
|
||||
use sqlx::{FromRow, PgPool, Postgres, QueryBuilder};
|
||||
use thiserror::Error;
|
||||
|
@ -17,13 +15,22 @@ impl<T> PgKeyDomain for T where
|
|||
{
|
||||
}
|
||||
|
||||
#[derive(Debug, Error, Clone)]
|
||||
pub enum PgStorageError<D>
|
||||
#[derive(Debug, Error)]
|
||||
pub enum PgKeyPoolError<D>
|
||||
where
|
||||
D: PgKeyDomain,
|
||||
{
|
||||
#[error(transparent)]
|
||||
Pg(Arc<sqlx::Error>),
|
||||
#[error("Databank: {0}")]
|
||||
Pg(#[from] sqlx::Error),
|
||||
|
||||
#[error("Network: {0}")]
|
||||
Network(#[from] reqwest::Error),
|
||||
|
||||
#[error("Parsing: {0}")]
|
||||
Parsing(#[from] serde_json::Error),
|
||||
|
||||
#[error("Api: {0}")]
|
||||
Api(#[from] torn_api::ApiError),
|
||||
|
||||
#[error("No key avalaible for domain {0:?}")]
|
||||
Unavailable(KeySelector<PgKey<D>, D>),
|
||||
|
@ -32,15 +39,6 @@ where
|
|||
KeyNotFound(KeySelector<PgKey<D>, D>),
|
||||
}
|
||||
|
||||
impl<D> From<sqlx::Error> for PgStorageError<D>
|
||||
where
|
||||
D: PgKeyDomain,
|
||||
{
|
||||
fn from(value: sqlx::Error) -> Self {
|
||||
Self::Pg(Arc::new(value))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, FromRow)]
|
||||
pub struct PgKey<D>
|
||||
where
|
||||
|
@ -127,7 +125,7 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
pub async fn initialise(&self) -> Result<(), PgStorageError<D>> {
|
||||
pub async fn initialise(&self) -> Result<(), PgKeyPoolError<D>> {
|
||||
sqlx::query(indoc! {r#"
|
||||
CREATE TABLE IF NOT EXISTS api_keys (
|
||||
id serial primary key,
|
||||
|
@ -184,19 +182,11 @@ where
|
|||
|
||||
#[cfg(feature = "tokio-runtime")]
|
||||
async fn random_sleep() {
|
||||
use rand::{thread_rng, Rng};
|
||||
let dur = tokio::time::Duration::from_millis(thread_rng().gen_range(1..50));
|
||||
use rand::{rng, Rng};
|
||||
let dur = tokio::time::Duration::from_millis(rng().random_range(1..50));
|
||||
tokio::time::sleep(dur).await;
|
||||
}
|
||||
|
||||
#[cfg(all(not(feature = "tokio-runtime"), feature = "actix-runtime"))]
|
||||
async fn random_sleep() {
|
||||
use rand::{thread_rng, Rng};
|
||||
let dur = std::time::Duration::from_millis(thread_rng().gen_range(1..50));
|
||||
actix_rt::time::sleep(dur).await;
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<D> KeyPoolStorage for PgKeyPoolStorage<D>
|
||||
where
|
||||
D: PgKeyDomain,
|
||||
|
@ -204,7 +194,7 @@ where
|
|||
type Key = PgKey<D>;
|
||||
type Domain = D;
|
||||
|
||||
type Error = PgStorageError<D>;
|
||||
type Error = PgKeyPoolError<D>;
|
||||
|
||||
async fn acquire_key<S>(&self, selector: S) -> Result<Self::Key, Self::Error>
|
||||
where
|
||||
|
@ -280,13 +270,23 @@ where
|
|||
match attempt {
|
||||
Ok(Some(result)) => return Ok(result),
|
||||
Ok(None) => {
|
||||
return self
|
||||
.acquire_key(
|
||||
fn recurse<D>(
|
||||
storage: &PgKeyPoolStorage<D>,
|
||||
selector: KeySelector<PgKey<D>, D>,
|
||||
) -> BoxFuture<Result<PgKey<D>, PgKeyPoolError<D>>>
|
||||
where
|
||||
D: PgKeyDomain,
|
||||
{
|
||||
Box::pin(storage.acquire_key(selector))
|
||||
}
|
||||
|
||||
return recurse(
|
||||
self,
|
||||
selector
|
||||
.fallback()
|
||||
.ok_or_else(|| PgStorageError::Unavailable(selector))?,
|
||||
.ok_or_else(|| PgKeyPoolError::Unavailable(selector))?,
|
||||
)
|
||||
.await
|
||||
.await;
|
||||
}
|
||||
Err(error) => {
|
||||
if let Some(db_error) = error.as_database_error() {
|
||||
|
@ -365,7 +365,7 @@ where
|
|||
let available = max.uses - key.uses;
|
||||
let using = std::cmp::min(available, (number as i16) - (result.len() as i16));
|
||||
key.uses += using;
|
||||
result.extend(std::iter::repeat(key.clone()).take(using as usize));
|
||||
result.extend(std::iter::repeat_n(key.clone(), using as usize));
|
||||
|
||||
if result.len() == number as usize {
|
||||
break;
|
||||
|
@ -406,14 +406,25 @@ where
|
|||
match attempt {
|
||||
Ok(Some(result)) => return Ok(result),
|
||||
Ok(None) => {
|
||||
return self
|
||||
.acquire_many_keys(
|
||||
fn recurse<D>(
|
||||
storage: &PgKeyPoolStorage<D>,
|
||||
selector: KeySelector<PgKey<D>, D>,
|
||||
number: i64,
|
||||
) -> BoxFuture<Result<Vec<PgKey<D>>, PgKeyPoolError<D>>>
|
||||
where
|
||||
D: PgKeyDomain,
|
||||
{
|
||||
Box::pin(storage.acquire_many_keys(selector, number))
|
||||
}
|
||||
|
||||
return recurse(
|
||||
self,
|
||||
selector
|
||||
.fallback()
|
||||
.ok_or_else(|| Self::Error::Unavailable(selector))?,
|
||||
number,
|
||||
)
|
||||
.await
|
||||
.await;
|
||||
}
|
||||
Err(error) => {
|
||||
if let Some(db_error) = error.as_database_error() {
|
||||
|
@ -431,57 +442,24 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
async fn flag_key(&self, key: Self::Key, code: u8) -> Result<bool, Self::Error> {
|
||||
match code {
|
||||
2 | 10 | 13 => {
|
||||
// invalid key, owner fedded or owner inactive
|
||||
sqlx::query(
|
||||
"update api_keys set cooldown='infinity'::timestamptz, flag=$1 where id=$2",
|
||||
)
|
||||
.bind(code as i16)
|
||||
.bind(key.id)
|
||||
.execute(&self.pool)
|
||||
.await?;
|
||||
Ok(true)
|
||||
}
|
||||
5 => {
|
||||
// too many requests
|
||||
sqlx::query(
|
||||
"update api_keys set cooldown=date_trunc('min', now()) + interval '1 min', \
|
||||
flag=5 where id=$1",
|
||||
)
|
||||
.bind(key.id)
|
||||
.execute(&self.pool)
|
||||
.await?;
|
||||
Ok(true)
|
||||
}
|
||||
8 => {
|
||||
// IP block
|
||||
sqlx::query("update api_keys set cooldown=now() + interval '5 min', flag=8")
|
||||
.execute(&self.pool)
|
||||
.await?;
|
||||
Ok(false)
|
||||
}
|
||||
9 => {
|
||||
// API disabled
|
||||
sqlx::query("update api_keys set cooldown=now() + interval '1 min', flag=9")
|
||||
.execute(&self.pool)
|
||||
.await?;
|
||||
Ok(false)
|
||||
}
|
||||
14 => {
|
||||
// daily read limit reached
|
||||
sqlx::query(
|
||||
"update api_keys set cooldown=date_trunc('day', now()) + interval '1 day', \
|
||||
flag=14 where id=$1",
|
||||
)
|
||||
.bind(key.id)
|
||||
.execute(&self.pool)
|
||||
.await?;
|
||||
Ok(true)
|
||||
}
|
||||
_ => Ok(false),
|
||||
}
|
||||
async fn timeout_key<S>(
|
||||
&self,
|
||||
selector: S,
|
||||
duration: std::time::Duration,
|
||||
) -> Result<(), Self::Error>
|
||||
where
|
||||
S: IntoSelector<Self::Key, Self::Domain>,
|
||||
{
|
||||
let selector = selector.into_selector();
|
||||
|
||||
let mut qb = QueryBuilder::new("update api_keys set cooldown=now() + ");
|
||||
qb.push_bind(duration);
|
||||
qb.push(" where ");
|
||||
build_predicate(&mut qb, &selector);
|
||||
|
||||
qb.build().fetch_optional(&self.pool).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn store_key(
|
||||
|
@ -546,7 +524,7 @@ where
|
|||
qb.build_query_as()
|
||||
.fetch_optional(&self.pool)
|
||||
.await?
|
||||
.ok_or_else(|| PgStorageError::KeyNotFound(selector))
|
||||
.ok_or_else(|| PgKeyPoolError::KeyNotFound(selector))
|
||||
}
|
||||
|
||||
async fn add_domain_to_key<S>(&self, selector: S, domain: D) -> Result<Self::Key, Self::Error>
|
||||
|
@ -566,7 +544,7 @@ where
|
|||
qb.build_query_as()
|
||||
.fetch_optional(&self.pool)
|
||||
.await?
|
||||
.ok_or_else(|| PgStorageError::KeyNotFound(selector))
|
||||
.ok_or_else(|| PgKeyPoolError::KeyNotFound(selector))
|
||||
}
|
||||
|
||||
async fn remove_domain_from_key<S>(
|
||||
|
@ -590,7 +568,7 @@ where
|
|||
qb.build_query_as()
|
||||
.fetch_optional(&self.pool)
|
||||
.await?
|
||||
.ok_or_else(|| PgStorageError::KeyNotFound(selector))
|
||||
.ok_or_else(|| PgKeyPoolError::KeyNotFound(selector))
|
||||
}
|
||||
|
||||
async fn set_domains_for_key<S>(
|
||||
|
@ -612,13 +590,13 @@ where
|
|||
qb.build_query_as()
|
||||
.fetch_optional(&self.pool)
|
||||
.await?
|
||||
.ok_or_else(|| PgStorageError::KeyNotFound(selector))
|
||||
.ok_or_else(|| PgKeyPoolError::KeyNotFound(selector))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) mod test {
|
||||
use std::sync::Arc;
|
||||
use std::{sync::Arc, time::Duration};
|
||||
|
||||
use sqlx::Row;
|
||||
|
||||
|
@ -652,7 +630,7 @@ pub(crate) mod test {
|
|||
storage.initialise().await.unwrap();
|
||||
|
||||
let key = storage
|
||||
.store_key(1, std::env::var("APIKEY").unwrap(), vec![Domain::All])
|
||||
.store_key(1, std::env::var("API_KEY").unwrap(), vec![Domain::All])
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
|
@ -816,34 +794,6 @@ pub(crate) mod test {
|
|||
}
|
||||
}
|
||||
|
||||
#[sqlx::test]
|
||||
async fn test_flag_key_one(pool: PgPool) {
|
||||
let (storage, key) = setup(pool).await;
|
||||
|
||||
assert!(storage.flag_key(key, 2).await.unwrap());
|
||||
|
||||
match storage.acquire_key(Domain::All).await.unwrap_err() {
|
||||
PgStorageError::Unavailable(KeySelector::Has(domains)) => {
|
||||
assert_eq!(domains, vec![Domain::All])
|
||||
}
|
||||
why => panic!("Expected domain unavailable error but found '{why}'"),
|
||||
}
|
||||
}
|
||||
|
||||
#[sqlx::test]
|
||||
async fn test_flag_key_many(pool: PgPool) {
|
||||
let (storage, key) = setup(pool).await;
|
||||
|
||||
assert!(storage.flag_key(key, 2).await.unwrap());
|
||||
|
||||
match storage.acquire_many_keys(Domain::All, 5).await.unwrap_err() {
|
||||
PgStorageError::Unavailable(KeySelector::Has(domains)) => {
|
||||
assert_eq!(domains, vec![Domain::All])
|
||||
}
|
||||
why => panic!("Expected domain unavailable error but found '{why}'"),
|
||||
}
|
||||
}
|
||||
|
||||
#[sqlx::test]
|
||||
async fn acquire_many(pool: PgPool) {
|
||||
let (storage, _) = setup(pool).await;
|
||||
|
@ -1025,6 +975,16 @@ pub(crate) mod test {
|
|||
assert!(key.is_some());
|
||||
}
|
||||
|
||||
#[sqlx::test]
|
||||
async fn timeout(pool: PgPool) {
|
||||
let (storage, key) = setup(pool).await;
|
||||
|
||||
storage
|
||||
.timeout_key(KeySelector::Id(key.id()), Duration::from_secs(60))
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
#[sqlx::test]
|
||||
async fn query_by_set(pool: PgPool) {
|
||||
let (storage, _key) = setup(pool).await;
|
||||
|
|
|
@ -1,380 +0,0 @@
|
|||
use std::{collections::HashMap, sync::Arc};
|
||||
|
||||
use async_trait::async_trait;
|
||||
|
||||
use torn_api::{
|
||||
send::{ApiClient, ApiProvider, RequestExecutor},
|
||||
ApiRequest, ApiResponse, ApiSelection, ResponseError,
|
||||
};
|
||||
|
||||
use crate::{
|
||||
ApiKey, IntoSelector, KeyAction, KeyDomain, KeyPoolError, KeyPoolExecutor, KeyPoolStorage,
|
||||
KeySelector, PoolOptions,
|
||||
};
|
||||
|
||||
#[async_trait]
|
||||
impl<'client, C, S> RequestExecutor<C> for KeyPoolExecutor<'client, C, S>
|
||||
where
|
||||
C: ApiClient,
|
||||
S: KeyPoolStorage + Send + Sync + 'static,
|
||||
{
|
||||
type Error = KeyPoolError<S::Error, C::Error>;
|
||||
|
||||
async fn execute<A>(
|
||||
&self,
|
||||
client: &C,
|
||||
mut request: ApiRequest<A>,
|
||||
id: Option<String>,
|
||||
) -> Result<A::Response, Self::Error>
|
||||
where
|
||||
A: ApiSelection,
|
||||
{
|
||||
if request.comment.is_none() {
|
||||
request.comment = self.options.comment.clone();
|
||||
}
|
||||
if let Some(hook) = self.options.hooks_before.get(&std::any::TypeId::of::<A>()) {
|
||||
let concrete = hook
|
||||
.downcast_ref::<BeforeHook<A, S::Key, S::Domain>>()
|
||||
.unwrap();
|
||||
|
||||
(concrete.body)(&mut request, &self.selector);
|
||||
}
|
||||
loop {
|
||||
let key = self
|
||||
.storage
|
||||
.acquire_key(self.selector.clone())
|
||||
.await
|
||||
.map_err(KeyPoolError::Storage)?;
|
||||
let url = request.url(key.value(), id.as_deref());
|
||||
let value = client.request(url).await?;
|
||||
|
||||
match ApiResponse::from_value(value) {
|
||||
Err(ResponseError::Api { code, reason }) => {
|
||||
if !self
|
||||
.storage
|
||||
.flag_key(key, code)
|
||||
.await
|
||||
.map_err(KeyPoolError::Storage)?
|
||||
{
|
||||
return Err(KeyPoolError::Response(ResponseError::Api { code, reason }));
|
||||
}
|
||||
}
|
||||
Err(parsing_error) => return Err(KeyPoolError::Response(parsing_error)),
|
||||
Ok(res) => {
|
||||
let res = res.into();
|
||||
if let Some(hook) = self.options.hooks_after.get(&std::any::TypeId::of::<A>()) {
|
||||
let concrete = hook
|
||||
.downcast_ref::<AfterHook<A, S::Key, S::Domain>>()
|
||||
.unwrap();
|
||||
|
||||
match (concrete.body)(&res, &self.selector) {
|
||||
Err(KeyAction::Delete) => {
|
||||
self.storage
|
||||
.remove_key(key.selector())
|
||||
.await
|
||||
.map_err(KeyPoolError::Storage)?;
|
||||
continue;
|
||||
}
|
||||
Err(KeyAction::RemoveDomain(domain)) => {
|
||||
self.storage
|
||||
.remove_domain_from_key(key.selector(), domain)
|
||||
.await
|
||||
.map_err(KeyPoolError::Storage)?;
|
||||
continue;
|
||||
}
|
||||
_ => (),
|
||||
};
|
||||
}
|
||||
return Ok(res);
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
async fn execute_many<A, I>(
|
||||
&self,
|
||||
client: &C,
|
||||
mut request: ApiRequest<A>,
|
||||
ids: Vec<I>,
|
||||
) -> HashMap<I, Result<A::Response, Self::Error>>
|
||||
where
|
||||
A: ApiSelection,
|
||||
I: ToString + std::hash::Hash + std::cmp::Eq + Send + Sync,
|
||||
{
|
||||
let keys = match self
|
||||
.storage
|
||||
.acquire_many_keys(self.selector.clone(), ids.len() as i64)
|
||||
.await
|
||||
{
|
||||
Ok(keys) => keys,
|
||||
Err(why) => {
|
||||
return ids
|
||||
.into_iter()
|
||||
.map(|i| (i, Err(Self::Error::Storage(why.clone()))))
|
||||
.collect();
|
||||
}
|
||||
};
|
||||
|
||||
if request.comment.is_none() {
|
||||
request.comment = self.options.comment.clone();
|
||||
}
|
||||
let request_ref = &request;
|
||||
|
||||
let tuples =
|
||||
futures::future::join_all(std::iter::zip(ids, keys).map(|(id, mut key)| async move {
|
||||
let id_string = id.to_string();
|
||||
loop {
|
||||
let url = request_ref.url(key.value(), Some(&id_string));
|
||||
let value = match client.request(url).await {
|
||||
Ok(v) => v,
|
||||
Err(why) => return (id, Err(Self::Error::Client(why))),
|
||||
};
|
||||
|
||||
match ApiResponse::from_value(value) {
|
||||
Err(ResponseError::Api { code, reason }) => {
|
||||
match self.storage.flag_key(key, code).await {
|
||||
Ok(false) => {
|
||||
return (
|
||||
id,
|
||||
Err(KeyPoolError::Response(ResponseError::Api {
|
||||
code,
|
||||
reason,
|
||||
})),
|
||||
)
|
||||
}
|
||||
Ok(true) => (),
|
||||
Err(why) => return (id, Err(KeyPoolError::Storage(why))),
|
||||
}
|
||||
}
|
||||
Err(parsing_error) => {
|
||||
return (id, Err(KeyPoolError::Response(parsing_error)))
|
||||
}
|
||||
Ok(res) => return (id, Ok(res.into())),
|
||||
};
|
||||
|
||||
key = match self.storage.acquire_key(self.selector.clone()).await {
|
||||
Ok(k) => k,
|
||||
Err(why) => return (id, Err(Self::Error::Storage(why))),
|
||||
};
|
||||
}
|
||||
}))
|
||||
.await;
|
||||
|
||||
HashMap::from_iter(tuples)
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::type_complexity)]
|
||||
pub struct BeforeHook<A, K, D>
|
||||
where
|
||||
A: ApiSelection,
|
||||
K: ApiKey,
|
||||
D: KeyDomain,
|
||||
{
|
||||
body: Box<dyn Fn(&mut ApiRequest<A>, &KeySelector<K, D>) + Send + Sync + 'static>,
|
||||
}
|
||||
|
||||
#[allow(clippy::type_complexity)]
|
||||
pub struct AfterHook<A, K, D>
|
||||
where
|
||||
A: ApiSelection,
|
||||
K: ApiKey,
|
||||
D: KeyDomain,
|
||||
{
|
||||
body: Box<
|
||||
dyn Fn(&A::Response, &KeySelector<K, D>) -> Result<(), crate::KeyAction<D>>
|
||||
+ Send
|
||||
+ Sync
|
||||
+ 'static,
|
||||
>,
|
||||
}
|
||||
|
||||
pub struct PoolBuilder<C, S>
|
||||
where
|
||||
C: ApiClient,
|
||||
S: KeyPoolStorage,
|
||||
{
|
||||
client: C,
|
||||
storage: S,
|
||||
options: crate::PoolOptions,
|
||||
}
|
||||
|
||||
impl<C, S> PoolBuilder<C, S>
|
||||
where
|
||||
C: ApiClient,
|
||||
S: KeyPoolStorage,
|
||||
{
|
||||
pub fn new(client: C, storage: S) -> Self {
|
||||
Self {
|
||||
client,
|
||||
storage,
|
||||
options: Default::default(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn comment(mut self, c: impl ToString) -> Self {
|
||||
self.options.comment = Some(c.to_string());
|
||||
self
|
||||
}
|
||||
|
||||
pub fn hook_before<A>(
|
||||
mut self,
|
||||
hook: impl Fn(&mut ApiRequest<A>, &KeySelector<S::Key, S::Domain>) + Send + Sync + 'static,
|
||||
) -> Self
|
||||
where
|
||||
A: ApiSelection + 'static,
|
||||
{
|
||||
self.options.hooks_before.insert(
|
||||
std::any::TypeId::of::<A>(),
|
||||
Box::new(BeforeHook {
|
||||
body: Box::new(hook),
|
||||
}),
|
||||
);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn hook_after<A>(
|
||||
mut self,
|
||||
hook: impl Fn(&A::Response, &KeySelector<S::Key, S::Domain>) -> Result<(), KeyAction<S::Domain>>
|
||||
+ Send
|
||||
+ Sync
|
||||
+ 'static,
|
||||
) -> Self
|
||||
where
|
||||
A: ApiSelection + 'static,
|
||||
{
|
||||
self.options.hooks_after.insert(
|
||||
std::any::TypeId::of::<A>(),
|
||||
Box::new(AfterHook::<A, S::Key, S::Domain> {
|
||||
body: Box::new(hook),
|
||||
}),
|
||||
);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn build(self) -> KeyPool<C, S> {
|
||||
KeyPool {
|
||||
client: self.client,
|
||||
storage: self.storage,
|
||||
options: Arc::new(self.options),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct KeyPool<C, S>
|
||||
where
|
||||
C: ApiClient,
|
||||
S: KeyPoolStorage,
|
||||
{
|
||||
pub client: C,
|
||||
pub storage: S,
|
||||
pub options: Arc<PoolOptions>,
|
||||
}
|
||||
|
||||
impl<C, S> KeyPool<C, S>
|
||||
where
|
||||
C: ApiClient,
|
||||
S: KeyPoolStorage + Send + Sync + 'static,
|
||||
{
|
||||
pub fn torn_api<I>(&self, selector: I) -> ApiProvider<C, KeyPoolExecutor<C, S>>
|
||||
where
|
||||
I: IntoSelector<S::Key, S::Domain>,
|
||||
{
|
||||
ApiProvider::new(
|
||||
&self.client,
|
||||
KeyPoolExecutor::new(
|
||||
&self.storage,
|
||||
selector.into_selector(),
|
||||
self.options.clone(),
|
||||
),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
pub trait WithStorage {
|
||||
fn with_storage<'a, S, I>(
|
||||
&'a self,
|
||||
storage: &'a S,
|
||||
selector: I,
|
||||
) -> ApiProvider<Self, KeyPoolExecutor<Self, S>>
|
||||
where
|
||||
Self: ApiClient + Sized,
|
||||
S: KeyPoolStorage + Send + Sync + 'static,
|
||||
I: IntoSelector<S::Key, S::Domain>,
|
||||
{
|
||||
ApiProvider::new(
|
||||
self,
|
||||
KeyPoolExecutor::new(storage, selector.into_selector(), Default::default()),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "reqwest")]
|
||||
impl WithStorage for reqwest::Client {}
|
||||
|
||||
#[cfg(all(test, feature = "postgres", feature = "reqwest"))]
|
||||
mod test {
|
||||
use sqlx::PgPool;
|
||||
|
||||
use super::*;
|
||||
use crate::{
|
||||
postgres::test::{setup, Domain},
|
||||
KeySelector,
|
||||
};
|
||||
|
||||
#[sqlx::test]
|
||||
async fn test_pool_request(pool: PgPool) {
|
||||
let (storage, _) = setup(pool).await;
|
||||
let pool = PoolBuilder::new(reqwest::Client::default(), storage)
|
||||
.comment("api.rs")
|
||||
.build();
|
||||
|
||||
let response = pool.torn_api(Domain::All).user(|b| b).await.unwrap();
|
||||
_ = response.profile().unwrap();
|
||||
}
|
||||
|
||||
#[sqlx::test]
|
||||
async fn test_with_storage_request(pool: PgPool) {
|
||||
let (storage, _) = setup(pool).await;
|
||||
|
||||
let response = reqwest::Client::new()
|
||||
.with_storage(&storage, Domain::All)
|
||||
.user(|b| b)
|
||||
.await
|
||||
.unwrap();
|
||||
_ = response.profile().unwrap();
|
||||
}
|
||||
|
||||
#[sqlx::test]
|
||||
async fn before_hook(pool: PgPool) {
|
||||
let (storage, _) = setup(pool).await;
|
||||
|
||||
let pool = PoolBuilder::new(reqwest::Client::default(), storage)
|
||||
.hook_before::<torn_api::user::UserSelection>(|req, _s| {
|
||||
req.selections.push("crimes");
|
||||
})
|
||||
.build();
|
||||
|
||||
let response = pool.torn_api(Domain::All).user(|b| b).await.unwrap();
|
||||
_ = response.crimes().unwrap();
|
||||
}
|
||||
|
||||
#[sqlx::test]
|
||||
async fn after_hook(pool: PgPool) {
|
||||
let (storage, _) = setup(pool).await;
|
||||
|
||||
let pool = PoolBuilder::new(reqwest::Client::default(), storage)
|
||||
.hook_after::<torn_api::user::UserSelection>(|_res, _s| Err(KeyAction::Delete))
|
||||
.build();
|
||||
|
||||
let key = pool.storage.read_key(KeySelector::Id(1)).await.unwrap();
|
||||
assert!(key.is_some());
|
||||
|
||||
let response = pool.torn_api(Domain::All).user(|b| b).await;
|
||||
assert!(matches!(response, Err(KeyPoolError::Storage(_))));
|
||||
|
||||
let key = pool.storage.read_key(KeySelector::Id(1)).await.unwrap();
|
||||
assert!(key.is_none());
|
||||
}
|
||||
}
|
Loading…
Reference in a new issue