simplified traits

This commit is contained in:
TotallyNot 2022-09-05 00:08:06 +02:00
parent 54345fef19
commit 27d7b4e9d9
9 changed files with 357 additions and 210 deletions

View file

@ -1,6 +1,6 @@
[package] [package]
name = "torn-api" name = "torn-api"
version = "0.2.1" version = "0.3.0"
edition = "2021" edition = "2021"
[features] [features]

22
torn-api/src/awc.rs Normal file
View file

@ -0,0 +1,22 @@
use async_trait::async_trait;
use thiserror::Error;
use crate::ApiClient;
#[derive(Error, Debug)]
pub enum AwcApiClientError {
#[error(transparent)]
Client(#[from] awc::error::SendRequestError),
#[error(transparent)]
Payload(#[from] awc::error::JsonPayloadError),
}
#[async_trait(?Send)]
impl ApiClient for awc::Client {
type Error = AwcApiClientError;
async fn request(&self, url: String) -> Result<serde_json::Value, Self::Error> {
self.get(url).send().await?.json().await.map_err(Into::into)
}
}

View file

@ -39,10 +39,7 @@ pub struct Basic {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use crate::{ use crate::tests::{async_test, setup, Client, ClientTrait};
prelude::*,
tests::{async_test, setup, Client},
};
#[async_test] #[async_test]
async fn faction() { async fn faction() {
@ -50,9 +47,7 @@ mod tests {
let response = Client::default() let response = Client::default()
.torn_api(key) .torn_api(key)
.faction() .faction(|b| b.selections(&[Selection::Basic]))
.selections(&[Selection::Basic])
.send()
.await .await
.unwrap(); .unwrap();

View file

@ -3,6 +3,12 @@
pub mod faction; pub mod faction;
pub mod user; pub mod user;
#[cfg(feature = "awc")]
pub mod awc;
#[cfg(feature = "reqwest")]
pub mod reqwest;
mod de_util; mod de_util;
use async_trait::async_trait; use async_trait::async_trait;
@ -10,33 +16,21 @@ use chrono::{DateTime, Utc};
use serde::de::{DeserializeOwned, Error as DeError}; use serde::de::{DeserializeOwned, Error as DeError};
use thiserror::Error; use thiserror::Error;
#[derive(Error, Debug)]
pub enum ClientError {
#[error("api returned error '{reason}', code = '{code}'")]
Api { code: u8, reason: String },
#[cfg(feature = "reqwest")]
#[error("api request failed with network error")]
Reqwest(#[from] reqwest::Error),
#[cfg(feature = "awc")]
#[error("api request failed with network error")]
AwcSend(#[from] awc::error::SendRequestError),
#[cfg(feature = "awc")]
#[error("api request failed to read payload")]
AwcPayload(#[from] awc::error::JsonPayloadError),
#[error("api response couldn't be deserialized")]
Deserialize(#[from] serde_json::Error),
}
pub struct ApiResponse { pub struct ApiResponse {
value: serde_json::Value, value: serde_json::Value,
} }
#[derive(Error, Debug)]
pub enum ResponseError {
#[error("API: {reason}")]
Api { code: u8, reason: String },
#[error(transparent)]
Parsing(#[from] serde_json::Error),
}
impl ApiResponse { impl ApiResponse {
fn from_value(mut value: serde_json::Value) -> Result<Self, ClientError> { pub fn from_value(mut value: serde_json::Value) -> Result<Self, ResponseError> {
#[derive(serde::Deserialize)] #[derive(serde::Deserialize)]
struct ApiErrorDto { struct ApiErrorDto {
code: u8, code: u8,
@ -46,7 +40,7 @@ impl ApiResponse {
match value.get_mut("error") { match value.get_mut("error") {
Some(error) => { Some(error) => {
let dto: ApiErrorDto = serde_json::from_value(error.take())?; let dto: ApiErrorDto = serde_json::from_value(error.take())?;
Err(ClientError::Api { Err(ResponseError::Api {
code: dto.code, code: dto.code,
reason: dto.reason, reason: dto.reason,
}) })
@ -88,111 +82,199 @@ pub trait ApiCategoryResponse: Send + Sync {
fn from_response(response: ApiResponse) -> Self; fn from_response(response: ApiResponse) -> Self;
} }
#[cfg(feature = "awc")]
#[async_trait(?Send)]
pub trait ApiClient {
async fn request(&self, url: String) -> Result<ApiResponse, ClientError>;
}
#[cfg(not(feature = "awc"))]
#[async_trait] #[async_trait]
pub trait ApiClient: Send + Sync { pub trait ThreadSafeApiClient: Send + Sync {
async fn request(&self, url: String) -> Result<ApiResponse, ClientError>; type Error: std::error::Error + Sync + Send;
}
pub trait DirectApiClient: ApiClient { async fn request(&self, url: String) -> Result<serde_json::Value, Self::Error>;
fn torn_api(&self, key: String) -> DirectExecutor<Self>
fn torn_api<S>(&self, key: S) -> ThreadSafeApiProvider<Self, DirectExecutor<Self>>
where where
Self: Sized, Self: Sized,
S: ToString,
{ {
DirectExecutor::from_client(self, key) ThreadSafeApiProvider::new(self, DirectExecutor::new(key.to_string()))
} }
} }
pub trait BackedApiClient: ApiClient {}
#[cfg(feature = "reqwest")]
#[cfg_attr(feature = "awc", async_trait(?Send))]
#[cfg_attr(not(feature = "awc"), async_trait)]
impl crate::ApiClient for reqwest::Client {
async fn request(&self, url: String) -> Result<ApiResponse, crate::ClientError> {
let value: serde_json::Value = self.get(url).send().await?.json().await?;
Ok(ApiResponse::from_value(value)?)
}
}
#[cfg(feature = "reqwest")]
impl crate::DirectApiClient for reqwest::Client {}
#[cfg(feature = "awc")]
#[async_trait(?Send)] #[async_trait(?Send)]
impl crate::ApiClient for awc::Client { pub trait ApiClient {
async fn request(&self, url: String) -> Result<ApiResponse, crate::ClientError> { type Error: std::error::Error;
let value: serde_json::Value = self.get(url).send().await?.json().await?;
Ok(ApiResponse::from_value(value)?) async fn request(&self, url: String) -> Result<serde_json::Value, Self::Error>;
fn torn_api<S>(&self, key: S) -> ApiProvider<Self, DirectExecutor<Self>>
where
Self: Sized,
S: ToString,
{
ApiProvider::new(self, DirectExecutor::new(key.to_string()))
} }
} }
#[cfg(feature = "awc")] #[async_trait(?Send)]
impl crate::DirectApiClient for awc::Client {} pub trait RequestExecutor<C>
where
C: ApiClient,
{
type Error: std::error::Error;
#[cfg_attr(feature = "awc", async_trait(?Send))] async fn execute<A>(&self, client: &C, request: ApiRequest<A>) -> Result<A, Self::Error>
#[cfg_attr(not(feature = "awc"), async_trait)]
pub trait ApiRequestExecutor<'client> {
type Err: std::error::Error;
async fn excute<A>(&self, request: ApiRequest<A>) -> Result<A, Self::Err>
where where
A: ApiCategoryResponse; A: ApiCategoryResponse;
}
#[must_use] #[async_trait]
fn user<'executor>( pub trait ThreadSafeRequestExecutor<C>
&'executor self, where
) -> ApiRequestBuilder<'client, 'executor, Self, user::Response> { C: ThreadSafeApiClient,
ApiRequestBuilder::new(self) {
type Error: std::error::Error + Send + Sync;
async fn execute<A>(&self, client: &C, request: ApiRequest<A>) -> Result<A, Self::Error>
where
A: ApiCategoryResponse;
}
pub struct ApiProvider<'a, C, E>
where
C: ApiClient,
E: RequestExecutor<C>,
{
client: &'a C,
executor: E,
}
impl<'a, C, E> ApiProvider<'a, C, E>
where
C: ApiClient,
E: RequestExecutor<C>,
{
pub fn new(client: &'a C, executor: E) -> ApiProvider<'a, C, E> {
Self { client, executor }
} }
#[must_use] pub async fn user<F>(&self, build: F) -> Result<user::Response, E::Error>
fn faction<'executor>( where
&'executor self, F: FnOnce(ApiRequestBuilder<user::Response>) -> ApiRequestBuilder<user::Response>,
) -> ApiRequestBuilder<'client, 'executor, Self, faction::Response> { {
ApiRequestBuilder::new(self) let mut builder = ApiRequestBuilder::<user::Response>::new();
builder = build(builder);
self.executor.execute(self.client, builder.request).await
}
pub async fn faction<F>(&self, build: F) -> Result<faction::Response, E::Error>
where
F: FnOnce(ApiRequestBuilder<faction::Response>) -> ApiRequestBuilder<faction::Response>,
{
let mut builder = ApiRequestBuilder::<faction::Response>::new();
builder = build(builder);
self.executor.execute(self.client, builder.request).await
} }
} }
pub struct DirectExecutor<'client, C> pub struct ThreadSafeApiProvider<'a, C, E>
where where
C: ApiClient, C: ThreadSafeApiClient,
E: ThreadSafeRequestExecutor<C>,
{ {
client: &'client C, client: &'a C,
executor: E,
}
impl<'a, C, E> ThreadSafeApiProvider<'a, C, E>
where
C: ThreadSafeApiClient,
E: ThreadSafeRequestExecutor<C>,
{
pub fn new(client: &'a C, executor: E) -> ThreadSafeApiProvider<'a, C, E> {
Self { client, executor }
}
pub async fn user<F>(&self, build: F) -> Result<user::Response, E::Error>
where
F: FnOnce(ApiRequestBuilder<user::Response>) -> ApiRequestBuilder<user::Response>,
{
let mut builder = ApiRequestBuilder::<user::Response>::new();
builder = build(builder);
self.executor.execute(self.client, builder.request).await
}
pub async fn faction<F>(&self, build: F) -> Result<faction::Response, E::Error>
where
F: FnOnce(ApiRequestBuilder<faction::Response>) -> ApiRequestBuilder<faction::Response>,
{
let mut builder = ApiRequestBuilder::<faction::Response>::new();
builder = build(builder);
self.executor.execute(self.client, builder.request).await
}
}
pub struct DirectExecutor<C> {
key: String, key: String,
_marker: std::marker::PhantomData<C>,
} }
impl<'client, C> DirectExecutor<'client, C> impl<C> DirectExecutor<C> {
where fn new(key: String) -> Self {
C: ApiClient, Self {
{ key,
#[allow(dead_code)] _marker: std::marker::PhantomData,
pub(crate) fn from_client(client: &'client C, key: String) -> Self { }
Self { client, key }
} }
} }
#[cfg_attr(feature = "awc", async_trait(?Send))] #[derive(Error, Debug)]
#[cfg_attr(not(feature = "awc"), async_trait)] pub enum ApiClientError<C>
impl<'client, C> ApiRequestExecutor<'client> for DirectExecutor<'client, C> where
C: std::error::Error,
{
#[error(transparent)]
Client(C),
#[error(transparent)]
Response(#[from] ResponseError),
}
#[async_trait(?Send)]
impl<C> RequestExecutor<C> for DirectExecutor<C>
where where
C: ApiClient, C: ApiClient,
{ {
type Err = ClientError; type Error = ApiClientError<C::Error>;
async fn excute<A>(&self, request: ApiRequest<A>) -> Result<A, Self::Err> async fn execute<A>(&self, client: &C, request: ApiRequest<A>) -> Result<A, Self::Error>
where where
A: ApiCategoryResponse, A: ApiCategoryResponse,
{ {
let url = request.url(&self.key); let url = request.url(&self.key);
self.client.request(url).await.map(A::from_response) let value = client.request(url).await.map_err(ApiClientError::Client)?;
Ok(A::from_response(ApiResponse::from_value(value)?))
}
}
#[async_trait]
impl<C> ThreadSafeRequestExecutor<C> for DirectExecutor<C>
where
C: ThreadSafeApiClient,
{
type Error = ApiClientError<C::Error>;
async fn execute<A>(&self, client: &C, request: ApiRequest<A>) -> Result<A, Self::Error>
where
A: ApiCategoryResponse,
{
let url = request.url(&self.key);
let value = client.request(url).await.map_err(ApiClientError::Client)?;
Ok(A::from_response(ApiResponse::from_value(value)?))
} }
} }
@ -263,26 +345,20 @@ where
} }
} }
pub struct ApiRequestBuilder<'client, 'executor, E, A> pub struct ApiRequestBuilder<A>
where where
E: ApiRequestExecutor<'client> + ?Sized,
A: ApiCategoryResponse, A: ApiCategoryResponse,
{ {
executor: &'executor E,
request: ApiRequest<A>, request: ApiRequest<A>,
_phantom: std::marker::PhantomData<&'client E>,
} }
impl<'client, 'executor, E, A> ApiRequestBuilder<'client, 'executor, E, A> impl<A> ApiRequestBuilder<A>
where where
E: ApiRequestExecutor<'client> + ?Sized,
A: ApiCategoryResponse, A: ApiCategoryResponse,
{ {
pub(crate) fn new(executor: &'executor E) -> Self { pub(crate) fn new() -> Self {
Self { Self {
executor,
request: ApiRequest::default(), request: ApiRequest::default(),
_phantom: std::marker::PhantomData::default(),
} }
} }
@ -317,49 +393,23 @@ where
self.request.comment = Some(comment); self.request.comment = Some(comment);
self self
} }
/// Executes the api request.
///
/// # Examples
///
/// ```no_run
/// use torn_api::{prelude::*, ClientError};
/// use reqwest::Client;
/// # async {
///
/// let key = "XXXXXXXXX".to_owned();
/// let response = Client::new()
/// .torn_api(key)
/// .user()
/// .send()
/// .await;
///
/// // invalid key
/// assert!(matches!(response, Err(ClientError::Api { code: 2, .. })));
/// # };
/// ```
///
/// # Errors
///
/// Will return an `Err` if the API returns an API error, the request fails due to a network
/// error, or if the response body doesn't contain valid json.
pub async fn send(self) -> Result<A, <E as ApiRequestExecutor<'client>>::Err> {
self.executor.excute(self.request).await
}
} }
pub mod prelude { pub mod prelude {}
pub use super::{ApiClient, ApiRequestExecutor, DirectApiClient};
}
#[cfg(test)] #[cfg(test)]
pub(crate) mod tests { pub(crate) mod tests {
use std::sync::Once; use std::sync::Once;
#[cfg(all(not(feature = "reqwest"), feature = "awc"))] #[cfg(all(not(feature = "reqwest"), feature = "awc"))]
pub use awc::Client; pub use ::awc::Client;
#[cfg(feature = "reqwest")] #[cfg(feature = "reqwest")]
pub use reqwest::Client; pub use ::reqwest::Client;
#[cfg(all(not(feature = "reqwest"), feature = "awc"))]
pub use crate::ApiClient as ClientTrait;
#[cfg(feature = "reqwest")]
pub use crate::ThreadSafeApiClient as ClientTrait;
#[cfg(all(not(feature = "reqwest"), feature = "awc"))] #[cfg(all(not(feature = "reqwest"), feature = "awc"))]
pub use actix_rt::test as async_test; pub use actix_rt::test as async_test;
@ -387,12 +437,7 @@ pub(crate) mod tests {
async fn reqwest() { async fn reqwest() {
let key = setup(); let key = setup();
reqwest::Client::default() Client::default().torn_api(key).user(|b| b).await.unwrap();
.torn_api(key)
.user()
.send()
.await
.unwrap();
} }
#[cfg(feature = "awc")] #[cfg(feature = "awc")]
@ -400,11 +445,6 @@ pub(crate) mod tests {
async fn awc() { async fn awc() {
let key = setup(); let key = setup();
awc::Client::default() Client::default().torn_api(key).user(|b| b).await.unwrap();
.torn_api(key)
.user()
.send()
.await
.unwrap();
} }
} }

12
torn-api/src/reqwest.rs Normal file
View file

@ -0,0 +1,12 @@
use async_trait::async_trait;
use crate::ThreadSafeApiClient;
#[async_trait]
impl ThreadSafeApiClient for reqwest::Client {
type Error = reqwest::Error;
async fn request(&self, url: String) -> Result<serde_json::Value, Self::Error> {
self.get(url).send().await?.json().await
}
}

View file

@ -150,10 +150,7 @@ pub struct PersonalStats {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use crate::{ use crate::tests::{async_test, setup, Client, ClientTrait};
prelude::*,
tests::{async_test, setup, Client},
};
#[async_test] #[async_test]
async fn user() { async fn user() {
@ -161,14 +158,14 @@ mod tests {
let response = Client::default() let response = Client::default()
.torn_api(key) .torn_api(key)
.user() .user(|b| {
.selections(&[ b.selections(&[
Selection::Basic, Selection::Basic,
Selection::Discord, Selection::Discord,
Selection::Profile, Selection::Profile,
Selection::PersonalStats, Selection::PersonalStats,
]) ])
.send() })
.await .await
.unwrap(); .unwrap();
@ -184,10 +181,7 @@ mod tests {
let response = Client::default() let response = Client::default()
.torn_api(key) .torn_api(key)
.user() .user(|b| b.id(28).selections(&[Selection::Profile]))
.id(28)
.selections(&[Selection::Profile])
.send()
.await .await
.unwrap(); .unwrap();

View file

@ -1,6 +1,6 @@
[package] [package]
name = "torn-key-pool" name = "torn-key-pool"
version = "0.1.3" version = "0.2.0"
edition = "2021" edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

View file

@ -6,18 +6,25 @@ pub mod postgres;
use async_trait::async_trait; use async_trait::async_trait;
use thiserror::Error; use thiserror::Error;
use torn_api::prelude::*; use torn_api::{
ApiCategoryResponse, ApiClient, ApiProvider, ApiRequest, ApiResponse, RequestExecutor,
ResponseError, ThreadSafeApiClient, ThreadSafeApiProvider, ThreadSafeRequestExecutor,
};
#[derive(Debug, Error)] #[derive(Debug, Error)]
pub enum KeyPoolError<S> pub enum KeyPoolError<S, C>
where where
S: Sync + Send + std::error::Error, S: std::error::Error,
C: std::error::Error,
{ {
#[error("Key pool storage driver error: {0:?}")] #[error("Key pool storage driver error: {0:?}")]
Storage(#[source] S), Storage(#[source] S),
#[error(transparent)] #[error(transparent)]
Client(#[from] torn_api::ClientError), Client(#[from] C),
#[error(transparent)]
Response(ResponseError),
} }
#[derive(Debug, Clone, Copy)] #[derive(Debug, Clone, Copy)]
@ -34,50 +41,47 @@ pub trait ApiKey: Sync + Send {
#[async_trait] #[async_trait]
pub trait KeyPoolStorage { pub trait KeyPoolStorage {
type Key: ApiKey; type Key: ApiKey;
type Err: Sync + Send + std::error::Error; type Error: std::error::Error + Sync + Send;
async fn acquire_key(&self, domain: KeyDomain) -> Result<Self::Key, Self::Err>; async fn acquire_key(&self, domain: KeyDomain) -> Result<Self::Key, Self::Error>;
async fn flag_key(&self, key: Self::Key, code: u8) -> Result<bool, Self::Err>; async fn flag_key(&self, key: Self::Key, code: u8) -> Result<bool, Self::Error>;
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct KeyPoolExecutor<'client, C, S> pub struct KeyPoolExecutor<'a, C, S>
where where
C: ApiClient,
S: KeyPoolStorage, S: KeyPoolStorage,
{ {
client: &'client C, storage: &'a S,
storage: &'client S,
domain: KeyDomain, domain: KeyDomain,
_marker: std::marker::PhantomData<C>,
} }
impl<'client, C, S> KeyPoolExecutor<'client, C, S> impl<'a, C, S> KeyPoolExecutor<'a, C, S>
where where
C: ApiClient,
S: KeyPoolStorage, S: KeyPoolStorage,
{ {
pub fn new(client: &'client C, storage: &'client S, domain: KeyDomain) -> Self { pub fn new(storage: &'a S, domain: KeyDomain) -> Self {
Self { Self {
client,
storage, storage,
domain, domain,
_marker: std::marker::PhantomData,
} }
} }
} }
#[cfg_attr(feature = "awc", async_trait(?Send))] #[async_trait(?Send)]
#[cfg_attr(not(feature = "awc"), async_trait)] impl<'client, C, S> RequestExecutor<C> for KeyPoolExecutor<'client, C, S>
impl<'client, C, S> ApiRequestExecutor<'client> for KeyPoolExecutor<'client, C, S>
where where
C: ApiClient, C: ApiClient,
S: KeyPoolStorage + Send + Sync + 'static, S: KeyPoolStorage + 'static,
{ {
type Err = KeyPoolError<S::Err>; type Error = KeyPoolError<S::Error, C::Error>;
async fn excute<A>(&self, request: torn_api::ApiRequest<A>) -> Result<A, Self::Err> async fn execute<A>(&self, client: &C, request: ApiRequest<A>) -> Result<A, Self::Error>
where where
A: torn_api::ApiCategoryResponse, A: ApiCategoryResponse,
{ {
loop { loop {
let key = self let key = self
@ -86,20 +90,60 @@ where
.await .await
.map_err(KeyPoolError::Storage)?; .map_err(KeyPoolError::Storage)?;
let url = request.url(key.value()); let url = request.url(key.value());
let res = self.client.request(url).await; let value = client.request(url).await?;
match res { match ApiResponse::from_value(value) {
Err(torn_api::ClientError::Api { code, .. }) => { Err(ResponseError::Api { code, reason }) => {
if !self if !self
.storage .storage
.flag_key(key, code) .flag_key(key, code)
.await .await
.map_err(KeyPoolError::Storage)? .map_err(KeyPoolError::Storage)?
{ {
panic!(); return Err(KeyPoolError::Response(ResponseError::Api { code, reason }));
} }
} }
_ => return res.map(A::from_response).map_err(KeyPoolError::Client), Err(parsing_error) => return Err(KeyPoolError::Response(parsing_error)),
Ok(res) => return Ok(A::from_response(res)),
};
}
}
}
#[async_trait]
impl<'client, C, S> ThreadSafeRequestExecutor<C> for KeyPoolExecutor<'client, C, S>
where
C: ThreadSafeApiClient,
S: KeyPoolStorage + Send + Sync + 'static,
{
type Error = KeyPoolError<S::Error, C::Error>;
async fn execute<A>(&self, client: &C, request: ApiRequest<A>) -> Result<A, Self::Error>
where
A: ApiCategoryResponse,
{
loop {
let key = self
.storage
.acquire_key(self.domain)
.await
.map_err(KeyPoolError::Storage)?;
let url = request.url(key.value());
let value = client.request(url).await?;
match ApiResponse::from_value(value) {
Err(ResponseError::Api { code, reason }) => {
if !self
.storage
.flag_key(key, code)
.await
.map_err(KeyPoolError::Storage)?
{
return Err(KeyPoolError::Response(ResponseError::Api { code, reason }));
}
}
Err(parsing_error) => return Err(KeyPoolError::Response(parsing_error)),
Ok(res) => return Ok(A::from_response(res)),
}; };
} }
} }
@ -118,29 +162,69 @@ where
impl<C, S> KeyPool<C, S> impl<C, S> KeyPool<C, S>
where where
C: ApiClient, C: ApiClient,
S: KeyPoolStorage, S: KeyPoolStorage + 'static,
{ {
pub fn new(client: C, storage: S) -> Self { pub fn new(client: C, storage: S) -> Self {
Self { client, storage } Self { client, storage }
} }
pub fn torn_api(&self, domain: KeyDomain) -> KeyPoolExecutor<C, S> { pub fn torn_api(&self, domain: KeyDomain) -> ApiProvider<C, KeyPoolExecutor<C, S>> {
KeyPoolExecutor::new(&self.client, &self.storage, domain) ApiProvider::new(&self.client, KeyPoolExecutor::new(&self.storage, domain))
} }
} }
pub trait KeyPoolClient: ApiClient { #[derive(Clone, Debug)]
fn with_pool<'a, S>(&'a self, domain: KeyDomain, storage: &'a S) -> KeyPoolExecutor<Self, S> pub struct ThreadSafeKeyPool<C, S>
where
C: ThreadSafeApiClient,
S: KeyPoolStorage + Send + Sync + 'static,
{
client: C,
storage: S,
}
impl<C, S> ThreadSafeKeyPool<C, S>
where
C: ThreadSafeApiClient,
S: KeyPoolStorage + Send + Sync + 'static,
{
pub fn new(client: C, storage: S) -> Self {
Self { client, storage }
}
pub fn torn_api(&self, domain: KeyDomain) -> ThreadSafeApiProvider<C, KeyPoolExecutor<C, S>> {
ThreadSafeApiProvider::new(&self.client, KeyPoolExecutor::new(&self.storage, domain))
}
}
pub trait WithStorage {
fn with_storage<'a, S>(
&'a self,
storage: &'a S,
domain: KeyDomain,
) -> ApiProvider<Self, KeyPoolExecutor<Self, S>>
where where
Self: Sized, Self: ApiClient + Sized,
S: KeyPoolStorage, S: KeyPoolStorage + 'static,
{ {
KeyPoolExecutor::new(self, storage, domain) ApiProvider::new(self, KeyPoolExecutor::new(storage, domain))
}
fn with_storage_sync<'a, S>(
&'a self,
storage: &'a S,
domain: KeyDomain,
) -> ThreadSafeApiProvider<Self, KeyPoolExecutor<Self, S>>
where
Self: ThreadSafeApiClient + Sized,
S: KeyPoolStorage + Send + Sync + 'static,
{
ThreadSafeApiProvider::new(self, KeyPoolExecutor::new(storage, domain))
} }
} }
#[cfg(feature = "reqwest")] #[cfg(feature = "reqwest")]
impl KeyPoolClient for reqwest::Client {} impl WithStorage for reqwest::Client {}
#[cfg(feature = "awc")] #[cfg(feature = "awc")]
impl KeyPoolClient for awc::Client {} impl WithStorage for awc::Client {}

View file

@ -67,9 +67,9 @@ impl PgKeyPoolStorage {
impl KeyPoolStorage for PgKeyPoolStorage { impl KeyPoolStorage for PgKeyPoolStorage {
type Key = PgKey; type Key = PgKey;
type Err = PgStorageError; type Error = PgStorageError;
async fn acquire_key(&self, domain: KeyDomain) -> Result<Self::Key, Self::Err> { async fn acquire_key(&self, domain: KeyDomain) -> Result<Self::Key, Self::Error> {
let predicate = match domain { let predicate = match domain {
KeyDomain::Public => "".to_owned(), KeyDomain::Public => "".to_owned(),
KeyDomain::User(id) => format!("where and user_id={} and user", id), KeyDomain::User(id) => format!("where and user_id={} and user", id),
@ -117,7 +117,7 @@ impl KeyPoolStorage for PgKeyPoolStorage {
key.ok_or(PgStorageError::Unavailable(domain)) key.ok_or(PgStorageError::Unavailable(domain))
} }
async fn flag_key(&self, key: Self::Key, code: u8) -> Result<bool, Self::Err> { async fn flag_key(&self, key: Self::Key, code: u8) -> Result<bool, Self::Error> {
// TODO: put keys in cooldown when appropriate // TODO: put keys in cooldown when appropriate
match code { match code {
2 | 10 | 13 => { 2 | 10 | 13 => {