feat(key-pool): updated key pool to use v2 api

This commit is contained in:
TotallyNot 2025-04-25 17:21:50 +02:00
parent 5ae490c756
commit 4b52c37774
Signed by: pyrite
GPG key ID: 7F1BA9170CD35D15
8 changed files with 1728 additions and 870 deletions

View file

@ -3,48 +3,23 @@
#[cfg(feature = "postgres")]
pub mod postgres;
// pub mod local;
pub mod send;
use std::{collections::HashMap, future::Future, sync::Arc, time::Duration};
use std::sync::Arc;
use futures::{future::BoxFuture, FutureExt};
use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION};
use serde::Deserialize;
use torn_api::{
executor::Executor,
request::{ApiRequest, ApiResponse},
ApiError,
};
use async_trait::async_trait;
use thiserror::Error;
pub trait ApiKeyId: Clone + PartialEq + Eq + std::hash::Hash + Send + Sync {}
use torn_api::ResponseError;
impl<T> ApiKeyId for T where T: Clone + PartialEq + Eq + std::hash::Hash + Send + Sync {}
#[derive(Debug, Error)]
pub enum KeyPoolError<S, C>
where
S: std::error::Error + Clone,
C: std::error::Error,
{
#[error("Key pool storage driver error: {0:?}")]
Storage(#[source] S),
#[error(transparent)]
Client(#[from] C),
#[error(transparent)]
Response(ResponseError),
}
impl<S, C> KeyPoolError<S, C>
where
S: std::error::Error + Clone,
C: std::error::Error,
{
#[inline(always)]
pub fn api_code(&self) -> Option<u8> {
match self {
Self::Response(why) => why.api_code(),
_ => None,
}
}
}
pub trait ApiKey: Sync + Send + std::fmt::Debug + Clone + 'static {
type IdType: PartialEq + Eq + std::hash::Hash + Send + Sync + std::fmt::Debug + Clone;
pub trait ApiKey: Send + Sync + Clone + 'static {
type IdType: ApiKeyId;
fn value(&self) -> &str;
@ -105,7 +80,7 @@ where
}
}
pub trait IntoSelector<K, D>: Send + Sync
pub trait IntoSelector<K, D>: Send
where
K: ApiKey,
D: KeyDomain,
@ -133,114 +108,347 @@ where
}
}
pub enum KeyAction<D>
where
D: KeyDomain,
{
Delete,
RemoveDomain(D),
Timeout(chrono::Duration),
}
#[async_trait]
pub trait KeyPoolStorage {
pub trait KeyPoolStorage: Send + Sync {
type Key: ApiKey;
type Domain: KeyDomain;
type Error: std::error::Error + Sync + Send + Clone;
type Error: From<reqwest::Error> + From<serde_json::Error> + From<torn_api::ApiError> + Send;
async fn acquire_key<S>(&self, selector: S) -> Result<Self::Key, Self::Error>
fn acquire_key<S>(
&self,
selector: S,
) -> impl Future<Output = Result<Self::Key, Self::Error>> + Send
where
S: IntoSelector<Self::Key, Self::Domain>;
async fn acquire_many_keys<S>(
fn acquire_many_keys<S>(
&self,
selector: S,
number: i64,
) -> Result<Vec<Self::Key>, Self::Error>
) -> impl Future<Output = Result<Vec<Self::Key>, Self::Error>> + Send
where
S: IntoSelector<Self::Key, Self::Domain>;
async fn flag_key(&self, key: Self::Key, code: u8) -> Result<bool, Self::Error>;
async fn store_key(
fn store_key(
&self,
user_id: i32,
key: String,
domains: Vec<Self::Domain>,
) -> Result<Self::Key, Self::Error>;
) -> impl Future<Output = Result<Self::Key, Self::Error>> + Send;
async fn read_key<S>(&self, selector: S) -> Result<Option<Self::Key>, Self::Error>
fn read_key<S>(
&self,
selector: S,
) -> impl Future<Output = Result<Option<Self::Key>, Self::Error>> + Send
where
S: IntoSelector<Self::Key, Self::Domain>;
async fn read_keys<S>(&self, selector: S) -> Result<Vec<Self::Key>, Self::Error>
fn read_keys<S>(
&self,
selector: S,
) -> impl Future<Output = Result<Vec<Self::Key>, Self::Error>> + Send
where
S: IntoSelector<Self::Key, Self::Domain>;
async fn remove_key<S>(&self, selector: S) -> Result<Self::Key, Self::Error>
fn remove_key<S>(
&self,
selector: S,
) -> impl Future<Output = Result<Self::Key, Self::Error>> + Send
where
S: IntoSelector<Self::Key, Self::Domain>;
async fn add_domain_to_key<S>(
fn add_domain_to_key<S>(
&self,
selector: S,
domain: Self::Domain,
) -> Result<Self::Key, Self::Error>
) -> impl Future<Output = Result<Self::Key, Self::Error>> + Send
where
S: IntoSelector<Self::Key, Self::Domain>;
async fn remove_domain_from_key<S>(
fn remove_domain_from_key<S>(
&self,
selector: S,
domain: Self::Domain,
) -> Result<Self::Key, Self::Error>
) -> impl Future<Output = Result<Self::Key, Self::Error>> + Send
where
S: IntoSelector<Self::Key, Self::Domain>;
async fn set_domains_for_key<S>(
fn set_domains_for_key<S>(
&self,
selector: S,
domains: Vec<Self::Domain>,
) -> Result<Self::Key, Self::Error>
) -> impl Future<Output = Result<Self::Key, Self::Error>> + Send
where
S: IntoSelector<Self::Key, Self::Domain>;
fn timeout_key<S>(
&self,
selector: S,
duration: Duration,
) -> impl Future<Output = Result<(), Self::Error>> + Send
where
S: IntoSelector<Self::Key, Self::Domain>;
}
#[derive(Debug, Default)]
pub struct PoolOptions {
#[derive(Default)]
pub struct PoolOptions<S>
where
S: KeyPoolStorage,
{
comment: Option<String>,
hooks_before: std::collections::HashMap<std::any::TypeId, Box<dyn std::any::Any + Send + Sync>>,
hooks_after: std::collections::HashMap<std::any::TypeId, Box<dyn std::any::Any + Send + Sync>>,
#[allow(clippy::type_complexity)]
error_hooks: HashMap<
u16,
Box<
dyn for<'a> Fn(&'a S, &'a S::Key) -> BoxFuture<'a, Result<bool, S::Error>>
+ Send
+ Sync,
>,
>,
}
#[derive(Debug, Clone)]
pub struct KeyPoolExecutor<'a, C, S>
pub struct KeyPoolExecutor<'p, S>
where
S: KeyPoolStorage,
{
storage: &'a S,
options: Arc<PoolOptions>,
pool: &'p KeyPool<S>,
selector: KeySelector<S::Key, S::Domain>,
_marker: std::marker::PhantomData<C>,
}
impl<'a, C, S> KeyPoolExecutor<'a, C, S>
impl<'p, S> KeyPoolExecutor<'p, S>
where
S: KeyPoolStorage,
{
pub fn new(
storage: &'a S,
selector: KeySelector<S::Key, S::Domain>,
options: Arc<PoolOptions>,
) -> Self {
Self {
storage,
selector,
options,
_marker: std::marker::PhantomData,
pub fn new(pool: &'p KeyPool<S>, selector: KeySelector<S::Key, S::Domain>) -> Self {
Self { pool, selector }
}
async fn execute_request<D>(&self, request: ApiRequest<D>) -> Result<ApiResponse<D>, S::Error>
where
D: Send,
{
let key = self.pool.storage.acquire_key(self.selector.clone()).await?;
let mut headers = HeaderMap::with_capacity(1);
headers.insert(
AUTHORIZATION,
HeaderValue::from_str(&format!("ApiKey {}", key.value())).unwrap(),
);
let resp = self
.pool
.client
.get(request.url())
.headers(headers)
.send()
.await?;
let status = resp.status();
let bytes = resp.bytes().await?;
if let Some(err) = decode_error(&bytes)? {
if let Some(handler) = self.pool.options.error_hooks.get(&err.code()) {
let retry = (*handler)(&self.pool.storage, &key).await?;
if retry {
return Box::pin(self.execute_request(request)).await;
}
}
Err(err.into())
} else {
Ok(ApiResponse {
discriminant: request.disriminant,
body: Some(bytes),
status,
})
}
}
}
#[cfg(all(test, feature = "postgres"))]
mod test {}
pub struct PoolBuilder<S>
where
S: KeyPoolStorage,
{
client: reqwest::Client,
storage: S,
options: crate::PoolOptions<S>,
}
impl<S> PoolBuilder<S>
where
S: KeyPoolStorage,
{
pub fn new(storage: S) -> Self {
Self {
client: reqwest::Client::builder()
.brotli(true)
.http2_keep_alive_timeout(Duration::from_secs(60))
.http2_keep_alive_interval(Duration::from_secs(5))
.https_only(true)
.build()
.unwrap(),
storage,
options: PoolOptions {
comment: None,
error_hooks: Default::default(),
},
}
}
pub fn comment(mut self, c: impl ToString) -> Self {
self.options.comment = Some(c.to_string());
self
}
pub fn error_hook<F>(mut self, code: u16, handler: F) -> Self
where
F: for<'a> Fn(&'a S, &'a S::Key) -> BoxFuture<'a, Result<bool, S::Error>>
+ Send
+ Sync
+ 'static,
{
self.options.error_hooks.insert(code, Box::new(handler));
self
}
pub fn use_default_hooks(self) -> Self {
self.error_hook(2, |storage, key| {
async move {
storage.remove_key(KeySelector::Id(key.id())).await?;
Ok(true)
}
.boxed()
})
.error_hook(5, |storage, key| {
async move {
storage
.timeout_key(KeySelector::Id(key.id()), Duration::from_secs(60))
.await?;
Ok(true)
}
.boxed()
})
.error_hook(10, |storage, key| {
async move {
storage.remove_key(KeySelector::Id(key.id())).await?;
Ok(true)
}
.boxed()
})
.error_hook(13, |storage, key| {
async move {
storage
.timeout_key(KeySelector::Id(key.id()), Duration::from_secs(24 * 3_600))
.await?;
Ok(true)
}
.boxed()
})
.error_hook(18, |storage, key| {
async move {
storage
.timeout_key(KeySelector::Id(key.id()), Duration::from_secs(24 * 3_600))
.await?;
Ok(true)
}
.boxed()
})
}
pub fn build(self) -> KeyPool<S> {
KeyPool {
client: self.client,
storage: self.storage,
options: Arc::new(self.options),
}
}
}
pub struct KeyPool<S>
where
S: KeyPoolStorage,
{
pub client: reqwest::Client,
pub storage: S,
pub options: Arc<PoolOptions<S>>,
}
impl<S> KeyPool<S>
where
S: KeyPoolStorage + Send + Sync + 'static,
{
pub fn torn_api<I>(&self, selector: I) -> KeyPoolExecutor<S>
where
I: IntoSelector<S::Key, S::Domain>,
{
KeyPoolExecutor::new(self, selector.into_selector())
}
}
fn decode_error(buf: &[u8]) -> Result<Option<ApiError>, serde_json::Error> {
if buf.starts_with(br#"{"error":{"#) {
#[derive(Deserialize)]
struct ErrorBody<'a> {
code: u16,
error: &'a str,
}
#[derive(Deserialize)]
struct ErrorContainer<'a> {
#[serde(borrow)]
error: ErrorBody<'a>,
}
let error: ErrorContainer = serde_json::from_slice(buf)?;
Ok(Some(crate::ApiError::new(
error.error.code,
error.error.error,
)))
} else {
Ok(None)
}
}
impl<S> Executor for KeyPoolExecutor<'_, S>
where
S: KeyPoolStorage,
{
type Error = S::Error;
async fn execute<R>(
&self,
request: R,
) -> Result<torn_api::request::ApiResponse<R::Discriminant>, Self::Error>
where
R: torn_api::request::IntoRequest,
{
let request = request.into_request();
self.execute_request(request).await
}
}
#[cfg(test)]
mod test {
use torn_api::executor::ExecutorExt;
use crate::postgres;
use super::*;
#[sqlx::test]
fn name(pool: sqlx::PgPool) {
let (storage, _) = postgres::test::setup(pool).await;
let pool = PoolBuilder::new(storage)
.use_default_hooks()
.comment("test_runner")
.build();
pool.torn_api(postgres::test::Domain::All)
.faction()
.basic(|b| b)
.await
.unwrap();
}
}