major refactoring
This commit is contained in:
parent
01bbe37876
commit
75fc19d0f7
|
@ -1,6 +1,6 @@
|
|||
[package]
|
||||
name = "torn-api-macros"
|
||||
version = "0.2.0"
|
||||
version = "0.3.0"
|
||||
edition = "2021"
|
||||
authors = ["Pyrit [2111649]"]
|
||||
license = "MIT"
|
||||
|
|
|
@ -147,15 +147,15 @@ fn impl_api_category(ast: &syn::DeriveInput) -> TokenStream {
|
|||
#(#accessors)*
|
||||
}
|
||||
|
||||
impl crate::ApiCategoryResponse for Response {
|
||||
type Selection = #name;
|
||||
|
||||
fn from_response(response: crate::ApiResponse) -> Self {
|
||||
Self(response)
|
||||
impl From<crate::ApiResponse> for Response {
|
||||
fn from(value: crate::ApiResponse) -> Self {
|
||||
Self(value)
|
||||
}
|
||||
}
|
||||
|
||||
impl crate::ApiSelection for #name {
|
||||
type Response = Response;
|
||||
|
||||
fn raw_value(self) -> &'static str {
|
||||
match self {
|
||||
#(#raw_values,)*
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
[package]
|
||||
name = "torn-api"
|
||||
version = "0.6.7"
|
||||
version = "0.7.0"
|
||||
edition = "2021"
|
||||
rust-version = "1.75.0"
|
||||
authors = ["Pyrit [2111649]"]
|
||||
|
@ -39,7 +39,7 @@ reqwest = { version = "0.11", default-features = false, features = [ "json" ], o
|
|||
awc = { version = "3", default-features = false, optional = true }
|
||||
rust_decimal = { version = "1", default-features = false, optional = true, features = [ "serde" ] }
|
||||
|
||||
torn-api-macros = { path = "../torn-api-macros", version = "0.2" }
|
||||
torn-api-macros = { path = "../torn-api-macros", version = "0.3" }
|
||||
|
||||
[dev-dependencies]
|
||||
actix-rt = { version = "2.7.0" }
|
||||
|
|
|
@ -111,18 +111,14 @@ impl ApiResponse {
|
|||
}
|
||||
}
|
||||
|
||||
pub trait ApiSelection: Send + Sync {
|
||||
pub trait ApiSelection: Send + Sync + 'static {
|
||||
type Response: From<ApiResponse> + Send + Sync;
|
||||
|
||||
fn raw_value(self) -> &'static str;
|
||||
|
||||
fn category() -> &'static str;
|
||||
}
|
||||
|
||||
pub trait ApiCategoryResponse: Send + Sync {
|
||||
type Selection: ApiSelection;
|
||||
|
||||
fn from_response(response: ApiResponse) -> Self;
|
||||
}
|
||||
|
||||
pub struct DirectExecutor<C> {
|
||||
key: String,
|
||||
_marker: std::marker::PhantomData<C>,
|
||||
|
|
|
@ -2,9 +2,7 @@ use std::collections::HashMap;
|
|||
|
||||
use async_trait::async_trait;
|
||||
|
||||
use crate::{
|
||||
ApiCategoryResponse, ApiClientError, ApiRequest, ApiResponse, ApiSelection, DirectExecutor,
|
||||
};
|
||||
use crate::{ApiClientError, ApiRequest, ApiResponse, ApiSelection, DirectExecutor};
|
||||
|
||||
pub struct ApiProvider<'a, C, E>
|
||||
where
|
||||
|
@ -39,7 +37,6 @@ where
|
|||
self.executor
|
||||
.execute(self.client, builder.request, builder.id)
|
||||
.await
|
||||
.map(crate::user::Response::from_response)
|
||||
}
|
||||
|
||||
#[cfg(feature = "user")]
|
||||
|
@ -61,9 +58,6 @@ where
|
|||
self.executor
|
||||
.execute_many(self.client, builder.request, Vec::from_iter(ids))
|
||||
.await
|
||||
.into_iter()
|
||||
.map(|(k, v)| (k, v.map(crate::user::Response::from_response)))
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[cfg(feature = "faction")]
|
||||
|
@ -79,7 +73,6 @@ where
|
|||
self.executor
|
||||
.execute(self.client, builder.request, builder.id)
|
||||
.await
|
||||
.map(crate::faction::Response::from_response)
|
||||
}
|
||||
|
||||
#[cfg(feature = "faction")]
|
||||
|
@ -101,9 +94,6 @@ where
|
|||
self.executor
|
||||
.execute_many(self.client, builder.request, Vec::from_iter(ids))
|
||||
.await
|
||||
.into_iter()
|
||||
.map(|(k, v)| (k, v.map(crate::faction::Response::from_response)))
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[cfg(feature = "market")]
|
||||
|
@ -119,7 +109,6 @@ where
|
|||
self.executor
|
||||
.execute(self.client, builder.request, builder.id)
|
||||
.await
|
||||
.map(crate::market::Response::from_response)
|
||||
}
|
||||
|
||||
#[cfg(feature = "market")]
|
||||
|
@ -141,9 +130,6 @@ where
|
|||
self.executor
|
||||
.execute_many(self.client, builder.request, Vec::from_iter(ids))
|
||||
.await
|
||||
.into_iter()
|
||||
.map(|(k, v)| (k, v.map(crate::market::Response::from_response)))
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[cfg(feature = "torn")]
|
||||
|
@ -159,7 +145,6 @@ where
|
|||
self.executor
|
||||
.execute(self.client, builder.request, builder.id)
|
||||
.await
|
||||
.map(crate::torn::Response::from_response)
|
||||
}
|
||||
|
||||
#[cfg(feature = "torn")]
|
||||
|
@ -181,9 +166,6 @@ where
|
|||
self.executor
|
||||
.execute_many(self.client, builder.request, Vec::from_iter(ids))
|
||||
.await
|
||||
.into_iter()
|
||||
.map(|(k, v)| (k, v.map(crate::torn::Response::from_response)))
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[cfg(feature = "key")]
|
||||
|
@ -199,7 +181,6 @@ where
|
|||
self.executor
|
||||
.execute(self.client, builder.request, builder.id)
|
||||
.await
|
||||
.map(crate::key::Response::from_response)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -215,7 +196,7 @@ where
|
|||
client: &C,
|
||||
request: ApiRequest<A>,
|
||||
id: Option<String>,
|
||||
) -> Result<ApiResponse, Self::Error>
|
||||
) -> Result<A::Response, Self::Error>
|
||||
where
|
||||
A: ApiSelection;
|
||||
|
||||
|
@ -224,7 +205,7 @@ where
|
|||
client: &C,
|
||||
request: ApiRequest<A>,
|
||||
ids: Vec<I>,
|
||||
) -> HashMap<I, Result<ApiResponse, Self::Error>>
|
||||
) -> HashMap<I, Result<A::Response, Self::Error>>
|
||||
where
|
||||
A: ApiSelection,
|
||||
I: ToString + std::hash::Hash + std::cmp::Eq;
|
||||
|
@ -242,7 +223,7 @@ where
|
|||
client: &C,
|
||||
request: ApiRequest<A>,
|
||||
id: Option<String>,
|
||||
) -> Result<ApiResponse, Self::Error>
|
||||
) -> Result<A::Response, Self::Error>
|
||||
where
|
||||
A: ApiSelection,
|
||||
{
|
||||
|
@ -250,7 +231,7 @@ where
|
|||
|
||||
let value = client.request(url).await.map_err(ApiClientError::Client)?;
|
||||
|
||||
Ok(ApiResponse::from_value(value)?)
|
||||
Ok(ApiResponse::from_value(value)?.into())
|
||||
}
|
||||
|
||||
async fn execute_many<A, I>(
|
||||
|
@ -258,7 +239,7 @@ where
|
|||
client: &C,
|
||||
request: ApiRequest<A>,
|
||||
ids: Vec<I>,
|
||||
) -> HashMap<I, Result<ApiResponse, Self::Error>>
|
||||
) -> HashMap<I, Result<A::Response, Self::Error>>
|
||||
where
|
||||
A: ApiSelection,
|
||||
I: ToString + std::hash::Hash + std::cmp::Eq,
|
||||
|
@ -272,7 +253,11 @@ where
|
|||
|
||||
(
|
||||
i,
|
||||
value.and_then(|v| ApiResponse::from_value(v).map_err(Into::into)),
|
||||
value.and_then(|v| {
|
||||
ApiResponse::from_value(v)
|
||||
.map(Into::into)
|
||||
.map_err(Into::into)
|
||||
}),
|
||||
)
|
||||
}))
|
||||
.await;
|
||||
|
|
|
@ -2,9 +2,7 @@ use std::collections::HashMap;
|
|||
|
||||
use async_trait::async_trait;
|
||||
|
||||
use crate::{
|
||||
ApiCategoryResponse, ApiClientError, ApiRequest, ApiResponse, ApiSelection, DirectExecutor,
|
||||
};
|
||||
use crate::{ApiClientError, ApiRequest, ApiResponse, ApiSelection, DirectExecutor};
|
||||
|
||||
pub struct ApiProvider<'a, C, E>
|
||||
where
|
||||
|
@ -37,7 +35,6 @@ where
|
|||
self.executor
|
||||
.execute(self.client, builder.request, builder.id)
|
||||
.await
|
||||
.map(crate::user::Response::from_response)
|
||||
}
|
||||
|
||||
#[cfg(feature = "user")]
|
||||
|
@ -59,9 +56,6 @@ where
|
|||
self.executor
|
||||
.execute_many(self.client, builder.request, Vec::from_iter(ids))
|
||||
.await
|
||||
.into_iter()
|
||||
.map(|(k, v)| (k, v.map(crate::user::Response::from_response)))
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[cfg(feature = "faction")]
|
||||
|
@ -77,7 +71,6 @@ where
|
|||
self.executor
|
||||
.execute(self.client, builder.request, builder.id)
|
||||
.await
|
||||
.map(crate::faction::Response::from_response)
|
||||
}
|
||||
|
||||
#[cfg(feature = "faction")]
|
||||
|
@ -99,9 +92,6 @@ where
|
|||
self.executor
|
||||
.execute_many(self.client, builder.request, Vec::from_iter(ids))
|
||||
.await
|
||||
.into_iter()
|
||||
.map(|(k, v)| (k, v.map(crate::faction::Response::from_response)))
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[cfg(feature = "market")]
|
||||
|
@ -117,7 +107,6 @@ where
|
|||
self.executor
|
||||
.execute(self.client, builder.request, builder.id)
|
||||
.await
|
||||
.map(crate::market::Response::from_response)
|
||||
}
|
||||
|
||||
#[cfg(feature = "market")]
|
||||
|
@ -139,9 +128,6 @@ where
|
|||
self.executor
|
||||
.execute_many(self.client, builder.request, Vec::from_iter(ids))
|
||||
.await
|
||||
.into_iter()
|
||||
.map(|(k, v)| (k, v.map(crate::market::Response::from_response)))
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[cfg(feature = "torn")]
|
||||
|
@ -157,7 +143,6 @@ where
|
|||
self.executor
|
||||
.execute(self.client, builder.request, builder.id)
|
||||
.await
|
||||
.map(crate::torn::Response::from_response)
|
||||
}
|
||||
|
||||
#[cfg(feature = "torn")]
|
||||
|
@ -179,9 +164,6 @@ where
|
|||
self.executor
|
||||
.execute_many(self.client, builder.request, Vec::from_iter(ids))
|
||||
.await
|
||||
.into_iter()
|
||||
.map(|(k, v)| (k, v.map(crate::torn::Response::from_response)))
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[cfg(feature = "key")]
|
||||
|
@ -197,7 +179,6 @@ where
|
|||
self.executor
|
||||
.execute(self.client, builder.request, builder.id)
|
||||
.await
|
||||
.map(crate::key::Response::from_response)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -213,7 +194,7 @@ where
|
|||
client: &C,
|
||||
request: ApiRequest<A>,
|
||||
id: Option<String>,
|
||||
) -> Result<ApiResponse, Self::Error>
|
||||
) -> Result<A::Response, Self::Error>
|
||||
where
|
||||
A: ApiSelection;
|
||||
|
||||
|
@ -222,7 +203,7 @@ where
|
|||
client: &C,
|
||||
request: ApiRequest<A>,
|
||||
ids: Vec<I>,
|
||||
) -> HashMap<I, Result<ApiResponse, Self::Error>>
|
||||
) -> HashMap<I, Result<A::Response, Self::Error>>
|
||||
where
|
||||
A: ApiSelection,
|
||||
I: ToString + std::hash::Hash + std::cmp::Eq + Send + Sync;
|
||||
|
@ -240,7 +221,7 @@ where
|
|||
client: &C,
|
||||
request: ApiRequest<A>,
|
||||
id: Option<String>,
|
||||
) -> Result<ApiResponse, Self::Error>
|
||||
) -> Result<A::Response, Self::Error>
|
||||
where
|
||||
A: ApiSelection,
|
||||
{
|
||||
|
@ -248,7 +229,7 @@ where
|
|||
|
||||
let value = client.request(url).await.map_err(ApiClientError::Client)?;
|
||||
|
||||
Ok(ApiResponse::from_value(value)?)
|
||||
Ok(ApiResponse::from_value(value)?.into())
|
||||
}
|
||||
|
||||
async fn execute_many<A, I>(
|
||||
|
@ -256,7 +237,7 @@ where
|
|||
client: &C,
|
||||
request: ApiRequest<A>,
|
||||
ids: Vec<I>,
|
||||
) -> HashMap<I, Result<ApiResponse, Self::Error>>
|
||||
) -> HashMap<I, Result<A::Response, Self::Error>>
|
||||
where
|
||||
A: ApiSelection,
|
||||
I: ToString + std::hash::Hash + std::cmp::Eq + Send + Sync,
|
||||
|
@ -270,7 +251,11 @@ where
|
|||
|
||||
(
|
||||
i,
|
||||
value.and_then(|v| ApiResponse::from_value(v).map_err(Into::into)),
|
||||
value.and_then(|v| {
|
||||
ApiResponse::from_value(v)
|
||||
.map(Into::into)
|
||||
.map_err(Into::into)
|
||||
}),
|
||||
)
|
||||
}))
|
||||
.await;
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
[package]
|
||||
name = "torn-key-pool"
|
||||
version = "0.7.0"
|
||||
version = "0.8.0"
|
||||
edition = "2021"
|
||||
authors = ["Pyrit [2111649]"]
|
||||
license = "MIT"
|
||||
|
@ -17,7 +17,7 @@ tokio-runtime = [ "dep:tokio", "dep:rand" ]
|
|||
actix-runtime = [ "dep:actix-rt", "dep:rand" ]
|
||||
|
||||
[dependencies]
|
||||
torn-api = { path = "../torn-api", default-features = false, version = "0.6" }
|
||||
torn-api = { path = "../torn-api", default-features = false, version = "0.7" }
|
||||
async-trait = "0.1"
|
||||
thiserror = "1"
|
||||
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
#[cfg(feature = "postgres")]
|
||||
pub mod postgres;
|
||||
|
||||
pub mod local;
|
||||
// pub mod local;
|
||||
pub mod send;
|
||||
|
||||
use std::sync::Arc;
|
||||
|
@ -16,11 +16,11 @@ use torn_api::ResponseError;
|
|||
#[derive(Debug, Error)]
|
||||
pub enum KeyPoolError<S, C>
|
||||
where
|
||||
S: std::error::Error,
|
||||
S: std::error::Error + Clone,
|
||||
C: std::error::Error,
|
||||
{
|
||||
#[error("Key pool storage driver error: {0:?}")]
|
||||
Storage(#[source] Arc<S>),
|
||||
Storage(#[source] S),
|
||||
|
||||
#[error(transparent)]
|
||||
Client(#[from] C),
|
||||
|
@ -31,7 +31,7 @@ where
|
|||
|
||||
impl<S, C> KeyPoolError<S, C>
|
||||
where
|
||||
S: std::error::Error,
|
||||
S: std::error::Error + Clone,
|
||||
C: std::error::Error,
|
||||
{
|
||||
#[inline(always)]
|
||||
|
@ -49,9 +49,16 @@ pub trait ApiKey: Sync + Send + std::fmt::Debug + Clone {
|
|||
fn value(&self) -> &str;
|
||||
|
||||
fn id(&self) -> Self::IdType;
|
||||
|
||||
fn selector<D>(&self) -> KeySelector<Self, D>
|
||||
where
|
||||
D: KeyDomain,
|
||||
{
|
||||
KeySelector::Id(self.id())
|
||||
}
|
||||
}
|
||||
|
||||
pub trait KeyDomain: Clone + std::fmt::Debug + Send + Sync {
|
||||
pub trait KeyDomain: Clone + std::fmt::Debug + Send + Sync + 'static {
|
||||
fn fallback(&self) -> Option<Self> {
|
||||
None
|
||||
}
|
||||
|
@ -66,7 +73,7 @@ where
|
|||
Key(String),
|
||||
Id(K::IdType),
|
||||
UserId(i32),
|
||||
Has(D),
|
||||
Has(Vec<D>),
|
||||
OneOf(Vec<D>),
|
||||
}
|
||||
|
||||
|
@ -78,7 +85,14 @@ where
|
|||
pub(crate) fn fallback(&self) -> Option<Self> {
|
||||
match self {
|
||||
Self::Key(_) | Self::UserId(_) | Self::Id(_) => None,
|
||||
Self::Has(domain) => domain.fallback().map(Self::Has),
|
||||
Self::Has(domains) => {
|
||||
let fallbacks: Vec<_> = domains.iter().filter_map(|d| d.fallback()).collect();
|
||||
if fallbacks.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(Self::Has(fallbacks))
|
||||
}
|
||||
}
|
||||
Self::OneOf(domains) => {
|
||||
let fallbacks: Vec<_> = domains.iter().filter_map(|d| d.fallback()).collect();
|
||||
if fallbacks.is_empty() {
|
||||
|
@ -105,7 +119,7 @@ where
|
|||
D: KeyDomain,
|
||||
{
|
||||
fn into_selector(self) -> KeySelector<K, D> {
|
||||
KeySelector::Has(self)
|
||||
KeySelector::Has(vec![self])
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -119,11 +133,20 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
pub enum KeyAction<D>
|
||||
where
|
||||
D: KeyDomain,
|
||||
{
|
||||
Delete,
|
||||
RemoveDomain(D),
|
||||
Timeout(chrono::Duration),
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
pub trait KeyPoolStorage {
|
||||
type Key: ApiKey;
|
||||
type Domain: KeyDomain;
|
||||
type Error: std::error::Error + Sync + Send;
|
||||
type Error: std::error::Error + Sync + Send + Clone;
|
||||
|
||||
async fn acquire_key<S>(&self, selector: S) -> Result<Self::Key, Self::Error>
|
||||
where
|
||||
|
@ -183,13 +206,20 @@ pub trait KeyPoolStorage {
|
|||
S: IntoSelector<Self::Key, Self::Domain>;
|
||||
}
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
struct PoolOptions {
|
||||
comment: Option<String>,
|
||||
hooks_before: std::collections::HashMap<std::any::TypeId, Box<dyn std::any::Any + Send + Sync>>,
|
||||
hooks_after: std::collections::HashMap<std::any::TypeId, Box<dyn std::any::Any + Send + Sync>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct KeyPoolExecutor<'a, C, S>
|
||||
where
|
||||
S: KeyPoolStorage,
|
||||
{
|
||||
storage: &'a S,
|
||||
comment: Option<&'a str>,
|
||||
options: Arc<PoolOptions>,
|
||||
selector: KeySelector<S::Key, S::Domain>,
|
||||
_marker: std::marker::PhantomData<C>,
|
||||
}
|
||||
|
@ -198,15 +228,15 @@ impl<'a, C, S> KeyPoolExecutor<'a, C, S>
|
|||
where
|
||||
S: KeyPoolStorage,
|
||||
{
|
||||
pub fn new(
|
||||
fn new(
|
||||
storage: &'a S,
|
||||
selector: KeySelector<S::Key, S::Domain>,
|
||||
comment: Option<&'a str>,
|
||||
options: Arc<PoolOptions>,
|
||||
) -> Self {
|
||||
Self {
|
||||
storage,
|
||||
selector,
|
||||
comment,
|
||||
options,
|
||||
_marker: std::marker::PhantomData,
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use indoc::indoc;
|
||||
use sqlx::{FromRow, PgPool, Postgres, QueryBuilder};
|
||||
|
@ -15,13 +17,13 @@ impl<T> PgKeyDomain for T where
|
|||
{
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
#[derive(Debug, Error, Clone)]
|
||||
pub enum PgStorageError<D>
|
||||
where
|
||||
D: PgKeyDomain,
|
||||
{
|
||||
#[error(transparent)]
|
||||
Pg(#[from] sqlx::Error),
|
||||
Pg(Arc<sqlx::Error>),
|
||||
|
||||
#[error("No key avalaible for domain {0:?}")]
|
||||
Unavailable(KeySelector<PgKey<D>, D>),
|
||||
|
@ -30,6 +32,15 @@ where
|
|||
KeyNotFound(KeySelector<PgKey<D>, D>),
|
||||
}
|
||||
|
||||
impl<D> From<sqlx::Error> for PgStorageError<D>
|
||||
where
|
||||
D: PgKeyDomain,
|
||||
{
|
||||
fn from(value: sqlx::Error) -> Self {
|
||||
Self::Pg(Arc::new(value))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, FromRow)]
|
||||
pub struct PgKey<D>
|
||||
where
|
||||
|
@ -53,9 +64,9 @@ fn build_predicate<'b, D>(
|
|||
KeySelector::Id(id) => builder.push("id=").push_bind(id),
|
||||
KeySelector::UserId(user_id) => builder.push("user_id=").push_bind(user_id),
|
||||
KeySelector::Key(key) => builder.push("key=").push_bind(key),
|
||||
KeySelector::Has(domain) => builder
|
||||
KeySelector::Has(domains) => builder
|
||||
.push("domains @> ")
|
||||
.push_bind(sqlx::types::Json(vec![domain])),
|
||||
.push_bind(sqlx::types::Json(domains)),
|
||||
KeySelector::OneOf(domains) => {
|
||||
if domains.is_empty() {
|
||||
builder.push("false");
|
||||
|
@ -607,15 +618,12 @@ where
|
|||
|
||||
#[cfg(test)]
|
||||
pub(crate) mod test {
|
||||
use std::sync::{Arc, Once};
|
||||
use std::sync::Arc;
|
||||
|
||||
use sqlx::Row;
|
||||
use tokio::test;
|
||||
|
||||
use super::*;
|
||||
|
||||
static INIT: Once = Once::new();
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Clone, serde::Serialize, serde::Deserialize)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub(crate) enum Domain {
|
||||
|
@ -634,15 +642,7 @@ pub(crate) mod test {
|
|||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn setup() -> (PgKeyPoolStorage<Domain>, PgKey<Domain>) {
|
||||
INIT.call_once(|| {
|
||||
dotenv::dotenv().ok();
|
||||
});
|
||||
|
||||
let pool = PgPool::connect(&std::env::var("DATABASE_URL").unwrap())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
pub(crate) async fn setup(pool: PgPool) -> (PgKeyPoolStorage<Domain>, PgKey<Domain>) {
|
||||
sqlx::query("DROP TABLE IF EXISTS api_keys")
|
||||
.execute(&pool)
|
||||
.await
|
||||
|
@ -659,18 +659,18 @@ pub(crate) mod test {
|
|||
(storage, key)
|
||||
}
|
||||
|
||||
#[test]
|
||||
async fn test_initialise() {
|
||||
let (storage, _) = setup().await;
|
||||
#[sqlx::test]
|
||||
async fn test_initialise(pool: PgPool) {
|
||||
let (storage, _) = setup(pool).await;
|
||||
|
||||
if let Err(e) = storage.initialise().await {
|
||||
panic!("Initialising key storage failed: {:?}", e);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
async fn test_store_duplicate_key() {
|
||||
let (storage, key) = setup().await;
|
||||
#[sqlx::test]
|
||||
async fn test_store_duplicate_key(pool: PgPool) {
|
||||
let (storage, key) = setup(pool).await;
|
||||
let key = storage
|
||||
.store_key(1, key.key, vec![Domain::User { id: 1 }])
|
||||
.await
|
||||
|
@ -679,9 +679,9 @@ pub(crate) mod test {
|
|||
assert_eq!(key.domains.0.len(), 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
async fn test_store_duplicate_key_duplicate_domain() {
|
||||
let (storage, key) = setup().await;
|
||||
#[sqlx::test]
|
||||
async fn test_store_duplicate_key_duplicate_domain(pool: PgPool) {
|
||||
let (storage, key) = setup(pool).await;
|
||||
let key = storage
|
||||
.store_key(1, key.key, vec![Domain::All])
|
||||
.await
|
||||
|
@ -690,9 +690,9 @@ pub(crate) mod test {
|
|||
assert_eq!(key.domains.0.len(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
async fn test_add_domain() {
|
||||
let (storage, key) = setup().await;
|
||||
#[sqlx::test]
|
||||
async fn test_add_domain(pool: PgPool) {
|
||||
let (storage, key) = setup(pool).await;
|
||||
let key = storage
|
||||
.add_domain_to_key(KeySelector::Key(key.key), Domain::User { id: 12345 })
|
||||
.await
|
||||
|
@ -701,9 +701,9 @@ pub(crate) mod test {
|
|||
assert!(key.domains.0.contains(&Domain::User { id: 12345 }));
|
||||
}
|
||||
|
||||
#[test]
|
||||
async fn test_add_domain_id() {
|
||||
let (storage, key) = setup().await;
|
||||
#[sqlx::test]
|
||||
async fn test_add_domain_id(pool: PgPool) {
|
||||
let (storage, key) = setup(pool).await;
|
||||
let key = storage
|
||||
.add_domain_to_key(KeySelector::Id(key.id), Domain::User { id: 12345 })
|
||||
.await
|
||||
|
@ -712,9 +712,9 @@ pub(crate) mod test {
|
|||
assert!(key.domains.0.contains(&Domain::User { id: 12345 }));
|
||||
}
|
||||
|
||||
#[test]
|
||||
async fn test_add_duplicate_domain() {
|
||||
let (storage, key) = setup().await;
|
||||
#[sqlx::test]
|
||||
async fn test_add_duplicate_domain(pool: PgPool) {
|
||||
let (storage, key) = setup(pool).await;
|
||||
let key = storage
|
||||
.add_domain_to_key(KeySelector::Key(key.key), Domain::All)
|
||||
.await
|
||||
|
@ -729,9 +729,9 @@ pub(crate) mod test {
|
|||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
async fn test_remove_domain() {
|
||||
let (storage, key) = setup().await;
|
||||
#[sqlx::test]
|
||||
async fn test_remove_domain(pool: PgPool) {
|
||||
let (storage, key) = setup(pool).await;
|
||||
storage
|
||||
.add_domain_to_key(KeySelector::Key(key.key.clone()), Domain::User { id: 1 })
|
||||
.await
|
||||
|
@ -744,9 +744,9 @@ pub(crate) mod test {
|
|||
assert_eq!(key.domains.0, vec![Domain::All]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
async fn test_remove_domain_id() {
|
||||
let (storage, key) = setup().await;
|
||||
#[sqlx::test]
|
||||
async fn test_remove_domain_id(pool: PgPool) {
|
||||
let (storage, key) = setup(pool).await;
|
||||
storage
|
||||
.add_domain_to_key(KeySelector::Id(key.id), Domain::User { id: 1 })
|
||||
.await
|
||||
|
@ -759,9 +759,9 @@ pub(crate) mod test {
|
|||
assert_eq!(key.domains.0, vec![Domain::All]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
async fn test_remove_last_domain() {
|
||||
let (storage, key) = setup().await;
|
||||
#[sqlx::test]
|
||||
async fn test_remove_last_domain(pool: PgPool) {
|
||||
let (storage, key) = setup(pool).await;
|
||||
let key = storage
|
||||
.remove_domain_from_key(KeySelector::Key(key.key), Domain::All)
|
||||
.await
|
||||
|
@ -770,9 +770,9 @@ pub(crate) mod test {
|
|||
assert!(key.domains.0.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
async fn test_store_key() {
|
||||
let (storage, _) = setup().await;
|
||||
#[sqlx::test]
|
||||
async fn test_store_key(pool: PgPool) {
|
||||
let (storage, _) = setup(pool).await;
|
||||
let key = storage
|
||||
.store_key(1, "ABCDABCDABCDABCD".to_owned(), vec![])
|
||||
.await
|
||||
|
@ -780,26 +780,26 @@ pub(crate) mod test {
|
|||
assert_eq!(key.value(), "ABCDABCDABCDABCD");
|
||||
}
|
||||
|
||||
#[test]
|
||||
async fn test_read_user_keys() {
|
||||
let (storage, _) = setup().await;
|
||||
#[sqlx::test]
|
||||
async fn test_read_user_keys(pool: PgPool) {
|
||||
let (storage, _) = setup(pool).await;
|
||||
|
||||
let keys = storage.read_keys(KeySelector::UserId(1)).await.unwrap();
|
||||
assert_eq!(keys.len(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
async fn acquire_one() {
|
||||
let (storage, _) = setup().await;
|
||||
#[sqlx::test]
|
||||
async fn acquire_one(pool: PgPool) {
|
||||
let (storage, _) = setup(pool).await;
|
||||
|
||||
if let Err(e) = storage.acquire_key(Domain::All).await {
|
||||
panic!("Acquiring key failed: {:?}", e);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
async fn uses_spread() {
|
||||
let (storage, _) = setup().await;
|
||||
#[sqlx::test]
|
||||
async fn uses_spread(pool: PgPool) {
|
||||
let (storage, _) = setup(pool).await;
|
||||
storage
|
||||
.store_key(1, "ABC".to_owned(), vec![Domain::All])
|
||||
.await
|
||||
|
@ -816,33 +816,37 @@ pub(crate) mod test {
|
|||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
async fn test_flag_key_one() {
|
||||
let (storage, key) = setup().await;
|
||||
#[sqlx::test]
|
||||
async fn test_flag_key_one(pool: PgPool) {
|
||||
let (storage, key) = setup(pool).await;
|
||||
|
||||
assert!(storage.flag_key(key, 2).await.unwrap());
|
||||
|
||||
match storage.acquire_key(Domain::All).await.unwrap_err() {
|
||||
PgStorageError::Unavailable(d) => assert!(matches!(d, KeySelector::Has(Domain::All))),
|
||||
PgStorageError::Unavailable(KeySelector::Has(domains)) => {
|
||||
assert_eq!(domains, vec![Domain::All])
|
||||
}
|
||||
why => panic!("Expected domain unavailable error but found '{why}'"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
async fn test_flag_key_many() {
|
||||
let (storage, key) = setup().await;
|
||||
#[sqlx::test]
|
||||
async fn test_flag_key_many(pool: PgPool) {
|
||||
let (storage, key) = setup(pool).await;
|
||||
|
||||
assert!(storage.flag_key(key, 2).await.unwrap());
|
||||
|
||||
match storage.acquire_many_keys(Domain::All, 5).await.unwrap_err() {
|
||||
PgStorageError::Unavailable(d) => assert!(matches!(d, KeySelector::Has(Domain::All))),
|
||||
PgStorageError::Unavailable(KeySelector::Has(domains)) => {
|
||||
assert_eq!(domains, vec![Domain::All])
|
||||
}
|
||||
why => panic!("Expected domain unavailable error but found '{why}'"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
async fn acquire_many() {
|
||||
let (storage, _) = setup().await;
|
||||
#[sqlx::test]
|
||||
async fn acquire_many(pool: PgPool) {
|
||||
let (storage, _) = setup(pool).await;
|
||||
|
||||
match storage.acquire_many_keys(Domain::All, 30).await {
|
||||
Err(e) => panic!("Acquiring key failed: {:?}", e),
|
||||
|
@ -851,9 +855,9 @@ pub(crate) mod test {
|
|||
}
|
||||
|
||||
// HACK: this test is time sensitive and will fail if runs at the top of the minute
|
||||
#[test]
|
||||
async fn test_concurrent() {
|
||||
let storage = Arc::new(setup().await.0);
|
||||
#[sqlx::test]
|
||||
async fn test_concurrent(pool: PgPool) {
|
||||
let storage = Arc::new(setup(pool).await.0);
|
||||
|
||||
for _ in 0..10 {
|
||||
let mut set = tokio::task::JoinSet::new();
|
||||
|
@ -884,9 +888,9 @@ pub(crate) mod test {
|
|||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
async fn test_concurrent_spread() {
|
||||
let storage = Arc::new(setup().await.0);
|
||||
#[sqlx::test]
|
||||
async fn test_concurrent_spread(pool: PgPool) {
|
||||
let storage = Arc::new(setup(pool).await.0);
|
||||
|
||||
for i in 0..24 {
|
||||
storage
|
||||
|
@ -923,10 +927,11 @@ pub(crate) mod test {
|
|||
.unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
// HACK: this test is time sensitive and will fail if runs at the top of the minute
|
||||
#[test]
|
||||
async fn test_concurrent_many() {
|
||||
let storage = Arc::new(setup().await.0);
|
||||
#[sqlx::test]
|
||||
async fn test_concurrent_many(pool: PgPool) {
|
||||
let storage = Arc::new(setup(pool).await.0);
|
||||
for _ in 0..10 {
|
||||
let mut set = tokio::task::JoinSet::new();
|
||||
|
||||
|
@ -956,73 +961,73 @@ pub(crate) mod test {
|
|||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
async fn read_key() {
|
||||
let (storage, key) = setup().await;
|
||||
#[sqlx::test]
|
||||
async fn read_key(pool: PgPool) {
|
||||
let (storage, key) = setup(pool).await;
|
||||
|
||||
let key = storage.read_key(KeySelector::Key(key.key)).await.unwrap();
|
||||
assert!(key.is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
async fn read_key_id() {
|
||||
let (storage, key) = setup().await;
|
||||
#[sqlx::test]
|
||||
async fn read_key_id(pool: PgPool) {
|
||||
let (storage, key) = setup(pool).await;
|
||||
|
||||
let key = storage.read_key(KeySelector::Id(key.id)).await.unwrap();
|
||||
assert!(key.is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
async fn read_nonexistent_key() {
|
||||
let (storage, _) = setup().await;
|
||||
#[sqlx::test]
|
||||
async fn read_nonexistent_key(pool: PgPool) {
|
||||
let (storage, _) = setup(pool).await;
|
||||
|
||||
let key = storage.read_key(KeySelector::Id(-1)).await.unwrap();
|
||||
assert!(key.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
async fn query_key() {
|
||||
let (storage, _) = setup().await;
|
||||
#[sqlx::test]
|
||||
async fn query_key(pool: PgPool) {
|
||||
let (storage, _) = setup(pool).await;
|
||||
|
||||
let key = storage.read_key(Domain::All).await.unwrap();
|
||||
assert!(key.is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
async fn query_nonexistent_key() {
|
||||
let (storage, _) = setup().await;
|
||||
#[sqlx::test]
|
||||
async fn query_nonexistent_key(pool: PgPool) {
|
||||
let (storage, _) = setup(pool).await;
|
||||
|
||||
let key = storage.read_key(Domain::Guild { id: 0 }).await.unwrap();
|
||||
assert!(key.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
async fn query_all() {
|
||||
let (storage, _) = setup().await;
|
||||
#[sqlx::test]
|
||||
async fn query_all(pool: PgPool) {
|
||||
let (storage, _) = setup(pool).await;
|
||||
|
||||
let keys = storage.read_keys(Domain::All).await.unwrap();
|
||||
assert!(keys.len() == 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
async fn query_by_id() {
|
||||
let (storage, _) = setup().await;
|
||||
#[sqlx::test]
|
||||
async fn query_by_id(pool: PgPool) {
|
||||
let (storage, _) = setup(pool).await;
|
||||
let key = storage.read_key(KeySelector::Id(1)).await.unwrap();
|
||||
|
||||
assert!(key.is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
async fn query_by_key() {
|
||||
let (storage, key) = setup().await;
|
||||
#[sqlx::test]
|
||||
async fn query_by_key(pool: PgPool) {
|
||||
let (storage, key) = setup(pool).await;
|
||||
let key = storage.read_key(KeySelector::Key(key.key)).await.unwrap();
|
||||
|
||||
assert!(key.is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
async fn query_by_set() {
|
||||
let (storage, _key) = setup().await;
|
||||
#[sqlx::test]
|
||||
async fn query_by_set(pool: PgPool) {
|
||||
let (storage, _key) = setup(pool).await;
|
||||
let key = storage
|
||||
.read_key(KeySelector::OneOf(vec![
|
||||
Domain::All,
|
||||
|
@ -1034,4 +1039,45 @@ pub(crate) mod test {
|
|||
|
||||
assert!(key.is_some());
|
||||
}
|
||||
|
||||
#[sqlx::test]
|
||||
async fn all_selector(pool: PgPool) {
|
||||
let (storage, key) = setup(pool).await;
|
||||
|
||||
storage
|
||||
.add_domain_to_key(key.selector(), Domain::Faction { id: 1 })
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let key = storage
|
||||
.read_key(KeySelector::Has(vec![
|
||||
Domain::Faction { id: 1 },
|
||||
Domain::All,
|
||||
]))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(key.is_some());
|
||||
|
||||
let key = storage
|
||||
.read_key(KeySelector::Has(vec![
|
||||
Domain::All,
|
||||
Domain::Faction { id: 1 },
|
||||
]))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(key.is_some());
|
||||
|
||||
let key = storage
|
||||
.read_key(KeySelector::Has(vec![
|
||||
Domain::All,
|
||||
Domain::Faction { id: 2 },
|
||||
Domain::Faction { id: 1 },
|
||||
]))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(key.is_none());
|
||||
}
|
||||
}
|
||||
|
|
|
@ -7,7 +7,9 @@ use torn_api::{
|
|||
ApiRequest, ApiResponse, ApiSelection, ResponseError,
|
||||
};
|
||||
|
||||
use crate::{ApiKey, IntoSelector, KeyPoolError, KeyPoolExecutor, KeyPoolStorage};
|
||||
use crate::{
|
||||
ApiKey, IntoSelector, KeyAction, KeyPoolError, KeyPoolExecutor, KeyPoolStorage, PoolOptions,
|
||||
};
|
||||
|
||||
#[async_trait]
|
||||
impl<'client, C, S> RequestExecutor<C> for KeyPoolExecutor<'client, C, S>
|
||||
|
@ -22,17 +24,22 @@ where
|
|||
client: &C,
|
||||
mut request: ApiRequest<A>,
|
||||
id: Option<String>,
|
||||
) -> Result<ApiResponse, Self::Error>
|
||||
) -> Result<A::Response, Self::Error>
|
||||
where
|
||||
A: ApiSelection,
|
||||
{
|
||||
request.comment = self.comment.map(ToOwned::to_owned);
|
||||
request.comment = self.options.comment.clone();
|
||||
if let Some(hook) = self.options.hooks_before.get(&std::any::TypeId::of::<A>()) {
|
||||
let concrete = hook.downcast_ref::<BeforeHook<A>>().unwrap();
|
||||
|
||||
(concrete.body)(&mut request);
|
||||
}
|
||||
loop {
|
||||
let key = self
|
||||
.storage
|
||||
.acquire_key(self.selector.clone())
|
||||
.await
|
||||
.map_err(|e| KeyPoolError::Storage(Arc::new(e)))?;
|
||||
.map_err(KeyPoolError::Storage)?;
|
||||
let url = request.url(key.value(), id.as_deref());
|
||||
let value = client.request(url).await?;
|
||||
|
||||
|
@ -42,14 +49,37 @@ where
|
|||
.storage
|
||||
.flag_key(key, code)
|
||||
.await
|
||||
.map_err(Arc::new)
|
||||
.map_err(KeyPoolError::Storage)?
|
||||
{
|
||||
return Err(KeyPoolError::Response(ResponseError::Api { code, reason }));
|
||||
}
|
||||
}
|
||||
Err(parsing_error) => return Err(KeyPoolError::Response(parsing_error)),
|
||||
Ok(res) => return Ok(res),
|
||||
Ok(res) => {
|
||||
let res = res.into();
|
||||
if let Some(hook) = self.options.hooks_after.get(&std::any::TypeId::of::<A>()) {
|
||||
let concrete = hook.downcast_ref::<AfterHook<A, S::Domain>>().unwrap();
|
||||
|
||||
match (concrete.body)(&res) {
|
||||
Err(KeyAction::Delete) => {
|
||||
self.storage
|
||||
.remove_key(key.selector())
|
||||
.await
|
||||
.map_err(KeyPoolError::Storage)?;
|
||||
continue;
|
||||
}
|
||||
Err(KeyAction::RemoveDomain(domain)) => {
|
||||
self.storage
|
||||
.remove_domain_from_key(key.selector(), domain)
|
||||
.await
|
||||
.map_err(KeyPoolError::Storage)?;
|
||||
continue;
|
||||
}
|
||||
_ => (),
|
||||
};
|
||||
}
|
||||
return Ok(res);
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
|
@ -59,7 +89,7 @@ where
|
|||
client: &C,
|
||||
mut request: ApiRequest<A>,
|
||||
ids: Vec<I>,
|
||||
) -> HashMap<I, Result<ApiResponse, Self::Error>>
|
||||
) -> HashMap<I, Result<A::Response, Self::Error>>
|
||||
where
|
||||
A: ApiSelection,
|
||||
I: ToString + std::hash::Hash + std::cmp::Eq + Send + Sync,
|
||||
|
@ -71,15 +101,14 @@ where
|
|||
{
|
||||
Ok(keys) => keys,
|
||||
Err(why) => {
|
||||
let shared = Arc::new(why);
|
||||
return ids
|
||||
.into_iter()
|
||||
.map(|i| (i, Err(Self::Error::Storage(shared.clone()))))
|
||||
.map(|i| (i, Err(Self::Error::Storage(why.clone()))))
|
||||
.collect();
|
||||
}
|
||||
};
|
||||
|
||||
request.comment = self.comment.map(ToOwned::to_owned);
|
||||
request.comment = self.options.comment.clone();
|
||||
let request_ref = &request;
|
||||
|
||||
let tuples =
|
||||
|
@ -105,18 +134,18 @@ where
|
|||
)
|
||||
}
|
||||
Ok(true) => (),
|
||||
Err(why) => return (id, Err(KeyPoolError::Storage(Arc::new(why)))),
|
||||
Err(why) => return (id, Err(KeyPoolError::Storage(why))),
|
||||
}
|
||||
}
|
||||
Err(parsing_error) => {
|
||||
return (id, Err(KeyPoolError::Response(parsing_error)))
|
||||
}
|
||||
Ok(res) => return (id, Ok(res)),
|
||||
Ok(res) => return (id, Ok(res.into())),
|
||||
};
|
||||
|
||||
key = match self.storage.acquire_key(self.selector.clone()).await {
|
||||
Ok(k) => k,
|
||||
Err(why) => return (id, Err(Self::Error::Storage(Arc::new(why)))),
|
||||
Err(why) => return (id, Err(Self::Error::Storage(why))),
|
||||
};
|
||||
}
|
||||
}))
|
||||
|
@ -126,6 +155,92 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::type_complexity)]
|
||||
pub struct BeforeHook<A>
|
||||
where
|
||||
A: ApiSelection,
|
||||
{
|
||||
body: Box<dyn Fn(&mut ApiRequest<A>) + Send + Sync + 'static>,
|
||||
}
|
||||
|
||||
#[allow(clippy::type_complexity)]
|
||||
pub struct AfterHook<A, D>
|
||||
where
|
||||
A: ApiSelection,
|
||||
D: crate::KeyDomain,
|
||||
{
|
||||
body: Box<dyn Fn(&A::Response) -> Result<(), crate::KeyAction<D>> + Send + Sync + 'static>,
|
||||
}
|
||||
|
||||
pub struct PoolBuilder<C, S>
|
||||
where
|
||||
C: ApiClient,
|
||||
S: KeyPoolStorage,
|
||||
{
|
||||
client: C,
|
||||
storage: S,
|
||||
options: crate::PoolOptions,
|
||||
}
|
||||
|
||||
impl<C, S> PoolBuilder<C, S>
|
||||
where
|
||||
C: ApiClient,
|
||||
S: KeyPoolStorage,
|
||||
{
|
||||
pub fn new(client: C, storage: S) -> Self {
|
||||
Self {
|
||||
client,
|
||||
storage,
|
||||
options: Default::default(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn comment(mut self, c: impl ToString) -> Self {
|
||||
self.options.comment = Some(c.to_string());
|
||||
self
|
||||
}
|
||||
|
||||
pub fn hook_before<A>(
|
||||
mut self,
|
||||
hook: impl Fn(&mut ApiRequest<A>) + Send + Sync + 'static,
|
||||
) -> Self
|
||||
where
|
||||
A: ApiSelection + 'static,
|
||||
{
|
||||
self.options.hooks_before.insert(
|
||||
std::any::TypeId::of::<A>(),
|
||||
Box::new(BeforeHook {
|
||||
body: Box::new(hook),
|
||||
}),
|
||||
);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn hook_after<A>(
|
||||
mut self,
|
||||
hook: impl Fn(&A::Response) -> Result<(), KeyAction<S::Domain>> + Send + Sync + 'static,
|
||||
) -> Self
|
||||
where
|
||||
A: ApiSelection + 'static,
|
||||
{
|
||||
self.options.hooks_after.insert(
|
||||
std::any::TypeId::of::<A>(),
|
||||
Box::new(AfterHook::<A, S::Domain> {
|
||||
body: Box::new(hook),
|
||||
}),
|
||||
);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn build(self) -> KeyPool<C, S> {
|
||||
KeyPool {
|
||||
client: self.client,
|
||||
storage: self.storage,
|
||||
options: Arc::new(self.options),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct KeyPool<C, S>
|
||||
where
|
||||
|
@ -134,7 +249,7 @@ where
|
|||
{
|
||||
client: C,
|
||||
pub storage: S,
|
||||
comment: Option<String>,
|
||||
options: Arc<PoolOptions>,
|
||||
}
|
||||
|
||||
impl<C, S> KeyPool<C, S>
|
||||
|
@ -142,14 +257,6 @@ where
|
|||
C: ApiClient,
|
||||
S: KeyPoolStorage + Send + Sync + 'static,
|
||||
{
|
||||
pub fn new(client: C, storage: S, comment: Option<String>) -> Self {
|
||||
Self {
|
||||
client,
|
||||
storage,
|
||||
comment,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn torn_api<I>(&self, selector: I) -> ApiProvider<C, KeyPoolExecutor<C, S>>
|
||||
where
|
||||
I: IntoSelector<S::Key, S::Domain>,
|
||||
|
@ -159,7 +266,7 @@ where
|
|||
KeyPoolExecutor::new(
|
||||
&self.storage,
|
||||
selector.into_selector(),
|
||||
self.comment.as_deref(),
|
||||
self.options.clone(),
|
||||
),
|
||||
)
|
||||
}
|
||||
|
@ -178,7 +285,7 @@ pub trait WithStorage {
|
|||
{
|
||||
ApiProvider::new(
|
||||
self,
|
||||
KeyPoolExecutor::new(storage, selector.into_selector(), None),
|
||||
KeyPoolExecutor::new(storage, selector.into_selector(), Default::default()),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
@ -188,27 +295,28 @@ impl WithStorage for reqwest::Client {}
|
|||
|
||||
#[cfg(all(test, feature = "postgres", feature = "reqwest"))]
|
||||
mod test {
|
||||
use tokio::test;
|
||||
use sqlx::PgPool;
|
||||
|
||||
use super::*;
|
||||
use crate::postgres::test::{setup, Domain};
|
||||
use crate::{
|
||||
postgres::test::{setup, Domain},
|
||||
KeySelector,
|
||||
};
|
||||
|
||||
#[test]
|
||||
async fn test_pool_request() {
|
||||
let (storage, _) = setup().await;
|
||||
let pool = KeyPool::new(
|
||||
reqwest::Client::default(),
|
||||
storage,
|
||||
Some("api.rs".to_owned()),
|
||||
);
|
||||
#[sqlx::test]
|
||||
async fn test_pool_request(pool: PgPool) {
|
||||
let (storage, _) = setup(pool).await;
|
||||
let pool = PoolBuilder::new(reqwest::Client::default(), storage)
|
||||
.comment("api.rs")
|
||||
.build();
|
||||
|
||||
let response = pool.torn_api(Domain::All).user(|b| b).await.unwrap();
|
||||
_ = response.profile().unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
async fn test_with_storage_request() {
|
||||
let (storage, _) = setup().await;
|
||||
#[sqlx::test]
|
||||
async fn test_with_storage_request(pool: PgPool) {
|
||||
let (storage, _) = setup(pool).await;
|
||||
|
||||
let response = reqwest::Client::new()
|
||||
.with_storage(&storage, Domain::All)
|
||||
|
@ -217,4 +325,36 @@ mod test {
|
|||
.unwrap();
|
||||
_ = response.profile().unwrap();
|
||||
}
|
||||
|
||||
#[sqlx::test]
|
||||
async fn before_hook(pool: PgPool) {
|
||||
let (storage, _) = setup(pool).await;
|
||||
|
||||
let pool = PoolBuilder::new(reqwest::Client::default(), storage)
|
||||
.hook_before::<torn_api::user::UserSelection>(|req| {
|
||||
req.selections.push("crimes");
|
||||
})
|
||||
.build();
|
||||
|
||||
let response = pool.torn_api(Domain::All).user(|b| b).await.unwrap();
|
||||
_ = response.crimes().unwrap();
|
||||
}
|
||||
|
||||
#[sqlx::test]
|
||||
async fn after_hook(pool: PgPool) {
|
||||
let (storage, _) = setup(pool).await;
|
||||
|
||||
let pool = PoolBuilder::new(reqwest::Client::default(), storage)
|
||||
.hook_after::<torn_api::user::UserSelection>(|_res| Err(KeyAction::Delete))
|
||||
.build();
|
||||
|
||||
let key = pool.storage.read_key(KeySelector::Id(1)).await.unwrap();
|
||||
assert!(key.is_some());
|
||||
|
||||
let response = pool.torn_api(Domain::All).user(|b| b).await;
|
||||
assert!(matches!(response, Err(KeyPoolError::Storage(_))));
|
||||
|
||||
let key = pool.storage.read_key(KeySelector::Id(1)).await.unwrap();
|
||||
assert!(key.is_none());
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue