simplified traits

This commit is contained in:
TotallyNot 2022-09-05 00:08:06 +02:00
parent df40047ec3
commit 758ab39a1d
9 changed files with 357 additions and 210 deletions

View file

@ -6,18 +6,25 @@ pub mod postgres;
use async_trait::async_trait;
use thiserror::Error;
use torn_api::prelude::*;
use torn_api::{
ApiCategoryResponse, ApiClient, ApiProvider, ApiRequest, ApiResponse, RequestExecutor,
ResponseError, ThreadSafeApiClient, ThreadSafeApiProvider, ThreadSafeRequestExecutor,
};
#[derive(Debug, Error)]
pub enum KeyPoolError<S>
pub enum KeyPoolError<S, C>
where
S: Sync + Send + std::error::Error,
S: std::error::Error,
C: std::error::Error,
{
#[error("Key pool storage driver error: {0:?}")]
Storage(#[source] S),
#[error(transparent)]
Client(#[from] torn_api::ClientError),
Client(#[from] C),
#[error(transparent)]
Response(ResponseError),
}
#[derive(Debug, Clone, Copy)]
@ -34,50 +41,47 @@ pub trait ApiKey: Sync + Send {
#[async_trait]
pub trait KeyPoolStorage {
type Key: ApiKey;
type Err: Sync + Send + std::error::Error;
type Error: std::error::Error + Sync + Send;
async fn acquire_key(&self, domain: KeyDomain) -> Result<Self::Key, Self::Err>;
async fn acquire_key(&self, domain: KeyDomain) -> Result<Self::Key, Self::Error>;
async fn flag_key(&self, key: Self::Key, code: u8) -> Result<bool, Self::Err>;
async fn flag_key(&self, key: Self::Key, code: u8) -> Result<bool, Self::Error>;
}
#[derive(Debug, Clone)]
pub struct KeyPoolExecutor<'client, C, S>
pub struct KeyPoolExecutor<'a, C, S>
where
C: ApiClient,
S: KeyPoolStorage,
{
client: &'client C,
storage: &'client S,
storage: &'a S,
domain: KeyDomain,
_marker: std::marker::PhantomData<C>,
}
impl<'client, C, S> KeyPoolExecutor<'client, C, S>
impl<'a, C, S> KeyPoolExecutor<'a, C, S>
where
C: ApiClient,
S: KeyPoolStorage,
{
pub fn new(client: &'client C, storage: &'client S, domain: KeyDomain) -> Self {
pub fn new(storage: &'a S, domain: KeyDomain) -> Self {
Self {
client,
storage,
domain,
_marker: std::marker::PhantomData,
}
}
}
#[cfg_attr(feature = "awc", async_trait(?Send))]
#[cfg_attr(not(feature = "awc"), async_trait)]
impl<'client, C, S> ApiRequestExecutor<'client> for KeyPoolExecutor<'client, C, S>
#[async_trait(?Send)]
impl<'client, C, S> RequestExecutor<C> for KeyPoolExecutor<'client, C, S>
where
C: ApiClient,
S: KeyPoolStorage + Send + Sync + 'static,
S: KeyPoolStorage + 'static,
{
type Err = KeyPoolError<S::Err>;
type Error = KeyPoolError<S::Error, C::Error>;
async fn excute<A>(&self, request: torn_api::ApiRequest<A>) -> Result<A, Self::Err>
async fn execute<A>(&self, client: &C, request: ApiRequest<A>) -> Result<A, Self::Error>
where
A: torn_api::ApiCategoryResponse,
A: ApiCategoryResponse,
{
loop {
let key = self
@ -86,20 +90,60 @@ where
.await
.map_err(KeyPoolError::Storage)?;
let url = request.url(key.value());
let res = self.client.request(url).await;
let value = client.request(url).await?;
match res {
Err(torn_api::ClientError::Api { code, .. }) => {
match ApiResponse::from_value(value) {
Err(ResponseError::Api { code, reason }) => {
if !self
.storage
.flag_key(key, code)
.await
.map_err(KeyPoolError::Storage)?
{
panic!();
return Err(KeyPoolError::Response(ResponseError::Api { code, reason }));
}
}
_ => return res.map(A::from_response).map_err(KeyPoolError::Client),
Err(parsing_error) => return Err(KeyPoolError::Response(parsing_error)),
Ok(res) => return Ok(A::from_response(res)),
};
}
}
}
#[async_trait]
impl<'client, C, S> ThreadSafeRequestExecutor<C> for KeyPoolExecutor<'client, C, S>
where
C: ThreadSafeApiClient,
S: KeyPoolStorage + Send + Sync + 'static,
{
type Error = KeyPoolError<S::Error, C::Error>;
async fn execute<A>(&self, client: &C, request: ApiRequest<A>) -> Result<A, Self::Error>
where
A: ApiCategoryResponse,
{
loop {
let key = self
.storage
.acquire_key(self.domain)
.await
.map_err(KeyPoolError::Storage)?;
let url = request.url(key.value());
let value = client.request(url).await?;
match ApiResponse::from_value(value) {
Err(ResponseError::Api { code, reason }) => {
if !self
.storage
.flag_key(key, code)
.await
.map_err(KeyPoolError::Storage)?
{
return Err(KeyPoolError::Response(ResponseError::Api { code, reason }));
}
}
Err(parsing_error) => return Err(KeyPoolError::Response(parsing_error)),
Ok(res) => return Ok(A::from_response(res)),
};
}
}
@ -118,29 +162,69 @@ where
impl<C, S> KeyPool<C, S>
where
C: ApiClient,
S: KeyPoolStorage,
S: KeyPoolStorage + 'static,
{
pub fn new(client: C, storage: S) -> Self {
Self { client, storage }
}
pub fn torn_api(&self, domain: KeyDomain) -> KeyPoolExecutor<C, S> {
KeyPoolExecutor::new(&self.client, &self.storage, domain)
pub fn torn_api(&self, domain: KeyDomain) -> ApiProvider<C, KeyPoolExecutor<C, S>> {
ApiProvider::new(&self.client, KeyPoolExecutor::new(&self.storage, domain))
}
}
pub trait KeyPoolClient: ApiClient {
fn with_pool<'a, S>(&'a self, domain: KeyDomain, storage: &'a S) -> KeyPoolExecutor<Self, S>
#[derive(Clone, Debug)]
pub struct ThreadSafeKeyPool<C, S>
where
C: ThreadSafeApiClient,
S: KeyPoolStorage + Send + Sync + 'static,
{
client: C,
storage: S,
}
impl<C, S> ThreadSafeKeyPool<C, S>
where
C: ThreadSafeApiClient,
S: KeyPoolStorage + Send + Sync + 'static,
{
pub fn new(client: C, storage: S) -> Self {
Self { client, storage }
}
pub fn torn_api(&self, domain: KeyDomain) -> ThreadSafeApiProvider<C, KeyPoolExecutor<C, S>> {
ThreadSafeApiProvider::new(&self.client, KeyPoolExecutor::new(&self.storage, domain))
}
}
pub trait WithStorage {
fn with_storage<'a, S>(
&'a self,
storage: &'a S,
domain: KeyDomain,
) -> ApiProvider<Self, KeyPoolExecutor<Self, S>>
where
Self: Sized,
S: KeyPoolStorage,
Self: ApiClient + Sized,
S: KeyPoolStorage + 'static,
{
KeyPoolExecutor::new(self, storage, domain)
ApiProvider::new(self, KeyPoolExecutor::new(storage, domain))
}
fn with_storage_sync<'a, S>(
&'a self,
storage: &'a S,
domain: KeyDomain,
) -> ThreadSafeApiProvider<Self, KeyPoolExecutor<Self, S>>
where
Self: ThreadSafeApiClient + Sized,
S: KeyPoolStorage + Send + Sync + 'static,
{
ThreadSafeApiProvider::new(self, KeyPoolExecutor::new(storage, domain))
}
}
#[cfg(feature = "reqwest")]
impl KeyPoolClient for reqwest::Client {}
impl WithStorage for reqwest::Client {}
#[cfg(feature = "awc")]
impl KeyPoolClient for awc::Client {}
impl WithStorage for awc::Client {}

View file

@ -67,9 +67,9 @@ impl PgKeyPoolStorage {
impl KeyPoolStorage for PgKeyPoolStorage {
type Key = PgKey;
type Err = PgStorageError;
type Error = PgStorageError;
async fn acquire_key(&self, domain: KeyDomain) -> Result<Self::Key, Self::Err> {
async fn acquire_key(&self, domain: KeyDomain) -> Result<Self::Key, Self::Error> {
let predicate = match domain {
KeyDomain::Public => "".to_owned(),
KeyDomain::User(id) => format!("where and user_id={} and user", id),
@ -117,7 +117,7 @@ impl KeyPoolStorage for PgKeyPoolStorage {
key.ok_or(PgStorageError::Unavailable(domain))
}
async fn flag_key(&self, key: Self::Key, code: u8) -> Result<bool, Self::Err> {
async fn flag_key(&self, key: Self::Key, code: u8) -> Result<bool, Self::Error> {
// TODO: put keys in cooldown when appropriate
match code {
2 | 10 | 13 => {