major refactoring

This commit is contained in:
TotallyNot 2024-04-04 15:59:10 +02:00
parent 01bbe37876
commit 75fc19d0f7
10 changed files with 404 additions and 222 deletions

View file

@ -1,6 +1,6 @@
[package]
name = "torn-api-macros"
version = "0.2.0"
version = "0.3.0"
edition = "2021"
authors = ["Pyrit [2111649]"]
license = "MIT"

View file

@ -147,15 +147,15 @@ fn impl_api_category(ast: &syn::DeriveInput) -> TokenStream {
#(#accessors)*
}
impl crate::ApiCategoryResponse for Response {
type Selection = #name;
fn from_response(response: crate::ApiResponse) -> Self {
Self(response)
impl From<crate::ApiResponse> for Response {
fn from(value: crate::ApiResponse) -> Self {
Self(value)
}
}
impl crate::ApiSelection for #name {
type Response = Response;
fn raw_value(self) -> &'static str {
match self {
#(#raw_values,)*

View file

@ -1,6 +1,6 @@
[package]
name = "torn-api"
version = "0.6.7"
version = "0.7.0"
edition = "2021"
rust-version = "1.75.0"
authors = ["Pyrit [2111649]"]
@ -39,7 +39,7 @@ reqwest = { version = "0.11", default-features = false, features = [ "json" ], o
awc = { version = "3", default-features = false, optional = true }
rust_decimal = { version = "1", default-features = false, optional = true, features = [ "serde" ] }
torn-api-macros = { path = "../torn-api-macros", version = "0.2" }
torn-api-macros = { path = "../torn-api-macros", version = "0.3" }
[dev-dependencies]
actix-rt = { version = "2.7.0" }

View file

@ -111,18 +111,14 @@ impl ApiResponse {
}
}
pub trait ApiSelection: Send + Sync {
pub trait ApiSelection: Send + Sync + 'static {
type Response: From<ApiResponse> + Send + Sync;
fn raw_value(self) -> &'static str;
fn category() -> &'static str;
}
pub trait ApiCategoryResponse: Send + Sync {
type Selection: ApiSelection;
fn from_response(response: ApiResponse) -> Self;
}
pub struct DirectExecutor<C> {
key: String,
_marker: std::marker::PhantomData<C>,

View file

@ -2,9 +2,7 @@ use std::collections::HashMap;
use async_trait::async_trait;
use crate::{
ApiCategoryResponse, ApiClientError, ApiRequest, ApiResponse, ApiSelection, DirectExecutor,
};
use crate::{ApiClientError, ApiRequest, ApiResponse, ApiSelection, DirectExecutor};
pub struct ApiProvider<'a, C, E>
where
@ -39,7 +37,6 @@ where
self.executor
.execute(self.client, builder.request, builder.id)
.await
.map(crate::user::Response::from_response)
}
#[cfg(feature = "user")]
@ -61,9 +58,6 @@ where
self.executor
.execute_many(self.client, builder.request, Vec::from_iter(ids))
.await
.into_iter()
.map(|(k, v)| (k, v.map(crate::user::Response::from_response)))
.collect()
}
#[cfg(feature = "faction")]
@ -79,7 +73,6 @@ where
self.executor
.execute(self.client, builder.request, builder.id)
.await
.map(crate::faction::Response::from_response)
}
#[cfg(feature = "faction")]
@ -101,9 +94,6 @@ where
self.executor
.execute_many(self.client, builder.request, Vec::from_iter(ids))
.await
.into_iter()
.map(|(k, v)| (k, v.map(crate::faction::Response::from_response)))
.collect()
}
#[cfg(feature = "market")]
@ -119,7 +109,6 @@ where
self.executor
.execute(self.client, builder.request, builder.id)
.await
.map(crate::market::Response::from_response)
}
#[cfg(feature = "market")]
@ -141,9 +130,6 @@ where
self.executor
.execute_many(self.client, builder.request, Vec::from_iter(ids))
.await
.into_iter()
.map(|(k, v)| (k, v.map(crate::market::Response::from_response)))
.collect()
}
#[cfg(feature = "torn")]
@ -159,7 +145,6 @@ where
self.executor
.execute(self.client, builder.request, builder.id)
.await
.map(crate::torn::Response::from_response)
}
#[cfg(feature = "torn")]
@ -181,9 +166,6 @@ where
self.executor
.execute_many(self.client, builder.request, Vec::from_iter(ids))
.await
.into_iter()
.map(|(k, v)| (k, v.map(crate::torn::Response::from_response)))
.collect()
}
#[cfg(feature = "key")]
@ -199,7 +181,6 @@ where
self.executor
.execute(self.client, builder.request, builder.id)
.await
.map(crate::key::Response::from_response)
}
}
@ -215,7 +196,7 @@ where
client: &C,
request: ApiRequest<A>,
id: Option<String>,
) -> Result<ApiResponse, Self::Error>
) -> Result<A::Response, Self::Error>
where
A: ApiSelection;
@ -224,7 +205,7 @@ where
client: &C,
request: ApiRequest<A>,
ids: Vec<I>,
) -> HashMap<I, Result<ApiResponse, Self::Error>>
) -> HashMap<I, Result<A::Response, Self::Error>>
where
A: ApiSelection,
I: ToString + std::hash::Hash + std::cmp::Eq;
@ -242,7 +223,7 @@ where
client: &C,
request: ApiRequest<A>,
id: Option<String>,
) -> Result<ApiResponse, Self::Error>
) -> Result<A::Response, Self::Error>
where
A: ApiSelection,
{
@ -250,7 +231,7 @@ where
let value = client.request(url).await.map_err(ApiClientError::Client)?;
Ok(ApiResponse::from_value(value)?)
Ok(ApiResponse::from_value(value)?.into())
}
async fn execute_many<A, I>(
@ -258,7 +239,7 @@ where
client: &C,
request: ApiRequest<A>,
ids: Vec<I>,
) -> HashMap<I, Result<ApiResponse, Self::Error>>
) -> HashMap<I, Result<A::Response, Self::Error>>
where
A: ApiSelection,
I: ToString + std::hash::Hash + std::cmp::Eq,
@ -272,7 +253,11 @@ where
(
i,
value.and_then(|v| ApiResponse::from_value(v).map_err(Into::into)),
value.and_then(|v| {
ApiResponse::from_value(v)
.map(Into::into)
.map_err(Into::into)
}),
)
}))
.await;

View file

@ -2,9 +2,7 @@ use std::collections::HashMap;
use async_trait::async_trait;
use crate::{
ApiCategoryResponse, ApiClientError, ApiRequest, ApiResponse, ApiSelection, DirectExecutor,
};
use crate::{ApiClientError, ApiRequest, ApiResponse, ApiSelection, DirectExecutor};
pub struct ApiProvider<'a, C, E>
where
@ -37,7 +35,6 @@ where
self.executor
.execute(self.client, builder.request, builder.id)
.await
.map(crate::user::Response::from_response)
}
#[cfg(feature = "user")]
@ -59,9 +56,6 @@ where
self.executor
.execute_many(self.client, builder.request, Vec::from_iter(ids))
.await
.into_iter()
.map(|(k, v)| (k, v.map(crate::user::Response::from_response)))
.collect()
}
#[cfg(feature = "faction")]
@ -77,7 +71,6 @@ where
self.executor
.execute(self.client, builder.request, builder.id)
.await
.map(crate::faction::Response::from_response)
}
#[cfg(feature = "faction")]
@ -99,9 +92,6 @@ where
self.executor
.execute_many(self.client, builder.request, Vec::from_iter(ids))
.await
.into_iter()
.map(|(k, v)| (k, v.map(crate::faction::Response::from_response)))
.collect()
}
#[cfg(feature = "market")]
@ -117,7 +107,6 @@ where
self.executor
.execute(self.client, builder.request, builder.id)
.await
.map(crate::market::Response::from_response)
}
#[cfg(feature = "market")]
@ -139,9 +128,6 @@ where
self.executor
.execute_many(self.client, builder.request, Vec::from_iter(ids))
.await
.into_iter()
.map(|(k, v)| (k, v.map(crate::market::Response::from_response)))
.collect()
}
#[cfg(feature = "torn")]
@ -157,7 +143,6 @@ where
self.executor
.execute(self.client, builder.request, builder.id)
.await
.map(crate::torn::Response::from_response)
}
#[cfg(feature = "torn")]
@ -179,9 +164,6 @@ where
self.executor
.execute_many(self.client, builder.request, Vec::from_iter(ids))
.await
.into_iter()
.map(|(k, v)| (k, v.map(crate::torn::Response::from_response)))
.collect()
}
#[cfg(feature = "key")]
@ -197,7 +179,6 @@ where
self.executor
.execute(self.client, builder.request, builder.id)
.await
.map(crate::key::Response::from_response)
}
}
@ -213,7 +194,7 @@ where
client: &C,
request: ApiRequest<A>,
id: Option<String>,
) -> Result<ApiResponse, Self::Error>
) -> Result<A::Response, Self::Error>
where
A: ApiSelection;
@ -222,7 +203,7 @@ where
client: &C,
request: ApiRequest<A>,
ids: Vec<I>,
) -> HashMap<I, Result<ApiResponse, Self::Error>>
) -> HashMap<I, Result<A::Response, Self::Error>>
where
A: ApiSelection,
I: ToString + std::hash::Hash + std::cmp::Eq + Send + Sync;
@ -240,7 +221,7 @@ where
client: &C,
request: ApiRequest<A>,
id: Option<String>,
) -> Result<ApiResponse, Self::Error>
) -> Result<A::Response, Self::Error>
where
A: ApiSelection,
{
@ -248,7 +229,7 @@ where
let value = client.request(url).await.map_err(ApiClientError::Client)?;
Ok(ApiResponse::from_value(value)?)
Ok(ApiResponse::from_value(value)?.into())
}
async fn execute_many<A, I>(
@ -256,7 +237,7 @@ where
client: &C,
request: ApiRequest<A>,
ids: Vec<I>,
) -> HashMap<I, Result<ApiResponse, Self::Error>>
) -> HashMap<I, Result<A::Response, Self::Error>>
where
A: ApiSelection,
I: ToString + std::hash::Hash + std::cmp::Eq + Send + Sync,
@ -270,7 +251,11 @@ where
(
i,
value.and_then(|v| ApiResponse::from_value(v).map_err(Into::into)),
value.and_then(|v| {
ApiResponse::from_value(v)
.map(Into::into)
.map_err(Into::into)
}),
)
}))
.await;

View file

@ -1,6 +1,6 @@
[package]
name = "torn-key-pool"
version = "0.7.0"
version = "0.8.0"
edition = "2021"
authors = ["Pyrit [2111649]"]
license = "MIT"
@ -17,7 +17,7 @@ tokio-runtime = [ "dep:tokio", "dep:rand" ]
actix-runtime = [ "dep:actix-rt", "dep:rand" ]
[dependencies]
torn-api = { path = "../torn-api", default-features = false, version = "0.6" }
torn-api = { path = "../torn-api", default-features = false, version = "0.7" }
async-trait = "0.1"
thiserror = "1"

View file

@ -3,7 +3,7 @@
#[cfg(feature = "postgres")]
pub mod postgres;
pub mod local;
// pub mod local;
pub mod send;
use std::sync::Arc;
@ -16,11 +16,11 @@ use torn_api::ResponseError;
#[derive(Debug, Error)]
pub enum KeyPoolError<S, C>
where
S: std::error::Error,
S: std::error::Error + Clone,
C: std::error::Error,
{
#[error("Key pool storage driver error: {0:?}")]
Storage(#[source] Arc<S>),
Storage(#[source] S),
#[error(transparent)]
Client(#[from] C),
@ -31,7 +31,7 @@ where
impl<S, C> KeyPoolError<S, C>
where
S: std::error::Error,
S: std::error::Error + Clone,
C: std::error::Error,
{
#[inline(always)]
@ -49,9 +49,16 @@ pub trait ApiKey: Sync + Send + std::fmt::Debug + Clone {
fn value(&self) -> &str;
fn id(&self) -> Self::IdType;
fn selector<D>(&self) -> KeySelector<Self, D>
where
D: KeyDomain,
{
KeySelector::Id(self.id())
}
}
pub trait KeyDomain: Clone + std::fmt::Debug + Send + Sync {
pub trait KeyDomain: Clone + std::fmt::Debug + Send + Sync + 'static {
fn fallback(&self) -> Option<Self> {
None
}
@ -66,7 +73,7 @@ where
Key(String),
Id(K::IdType),
UserId(i32),
Has(D),
Has(Vec<D>),
OneOf(Vec<D>),
}
@ -78,7 +85,14 @@ where
pub(crate) fn fallback(&self) -> Option<Self> {
match self {
Self::Key(_) | Self::UserId(_) | Self::Id(_) => None,
Self::Has(domain) => domain.fallback().map(Self::Has),
Self::Has(domains) => {
let fallbacks: Vec<_> = domains.iter().filter_map(|d| d.fallback()).collect();
if fallbacks.is_empty() {
None
} else {
Some(Self::Has(fallbacks))
}
}
Self::OneOf(domains) => {
let fallbacks: Vec<_> = domains.iter().filter_map(|d| d.fallback()).collect();
if fallbacks.is_empty() {
@ -105,7 +119,7 @@ where
D: KeyDomain,
{
fn into_selector(self) -> KeySelector<K, D> {
KeySelector::Has(self)
KeySelector::Has(vec![self])
}
}
@ -119,11 +133,20 @@ where
}
}
pub enum KeyAction<D>
where
D: KeyDomain,
{
Delete,
RemoveDomain(D),
Timeout(chrono::Duration),
}
#[async_trait]
pub trait KeyPoolStorage {
type Key: ApiKey;
type Domain: KeyDomain;
type Error: std::error::Error + Sync + Send;
type Error: std::error::Error + Sync + Send + Clone;
async fn acquire_key<S>(&self, selector: S) -> Result<Self::Key, Self::Error>
where
@ -183,13 +206,20 @@ pub trait KeyPoolStorage {
S: IntoSelector<Self::Key, Self::Domain>;
}
#[derive(Debug, Default)]
struct PoolOptions {
comment: Option<String>,
hooks_before: std::collections::HashMap<std::any::TypeId, Box<dyn std::any::Any + Send + Sync>>,
hooks_after: std::collections::HashMap<std::any::TypeId, Box<dyn std::any::Any + Send + Sync>>,
}
#[derive(Debug, Clone)]
pub struct KeyPoolExecutor<'a, C, S>
where
S: KeyPoolStorage,
{
storage: &'a S,
comment: Option<&'a str>,
options: Arc<PoolOptions>,
selector: KeySelector<S::Key, S::Domain>,
_marker: std::marker::PhantomData<C>,
}
@ -198,15 +228,15 @@ impl<'a, C, S> KeyPoolExecutor<'a, C, S>
where
S: KeyPoolStorage,
{
pub fn new(
fn new(
storage: &'a S,
selector: KeySelector<S::Key, S::Domain>,
comment: Option<&'a str>,
options: Arc<PoolOptions>,
) -> Self {
Self {
storage,
selector,
comment,
options,
_marker: std::marker::PhantomData,
}
}

View file

@ -1,3 +1,5 @@
use std::sync::Arc;
use async_trait::async_trait;
use indoc::indoc;
use sqlx::{FromRow, PgPool, Postgres, QueryBuilder};
@ -15,13 +17,13 @@ impl<T> PgKeyDomain for T where
{
}
#[derive(Debug, Error)]
#[derive(Debug, Error, Clone)]
pub enum PgStorageError<D>
where
D: PgKeyDomain,
{
#[error(transparent)]
Pg(#[from] sqlx::Error),
Pg(Arc<sqlx::Error>),
#[error("No key avalaible for domain {0:?}")]
Unavailable(KeySelector<PgKey<D>, D>),
@ -30,6 +32,15 @@ where
KeyNotFound(KeySelector<PgKey<D>, D>),
}
impl<D> From<sqlx::Error> for PgStorageError<D>
where
D: PgKeyDomain,
{
fn from(value: sqlx::Error) -> Self {
Self::Pg(Arc::new(value))
}
}
#[derive(Debug, Clone, FromRow)]
pub struct PgKey<D>
where
@ -53,9 +64,9 @@ fn build_predicate<'b, D>(
KeySelector::Id(id) => builder.push("id=").push_bind(id),
KeySelector::UserId(user_id) => builder.push("user_id=").push_bind(user_id),
KeySelector::Key(key) => builder.push("key=").push_bind(key),
KeySelector::Has(domain) => builder
KeySelector::Has(domains) => builder
.push("domains @> ")
.push_bind(sqlx::types::Json(vec![domain])),
.push_bind(sqlx::types::Json(domains)),
KeySelector::OneOf(domains) => {
if domains.is_empty() {
builder.push("false");
@ -607,15 +618,12 @@ where
#[cfg(test)]
pub(crate) mod test {
use std::sync::{Arc, Once};
use std::sync::Arc;
use sqlx::Row;
use tokio::test;
use super::*;
static INIT: Once = Once::new();
#[derive(Debug, PartialEq, Eq, Clone, serde::Serialize, serde::Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub(crate) enum Domain {
@ -634,15 +642,7 @@ pub(crate) mod test {
}
}
pub(crate) async fn setup() -> (PgKeyPoolStorage<Domain>, PgKey<Domain>) {
INIT.call_once(|| {
dotenv::dotenv().ok();
});
let pool = PgPool::connect(&std::env::var("DATABASE_URL").unwrap())
.await
.unwrap();
pub(crate) async fn setup(pool: PgPool) -> (PgKeyPoolStorage<Domain>, PgKey<Domain>) {
sqlx::query("DROP TABLE IF EXISTS api_keys")
.execute(&pool)
.await
@ -659,18 +659,18 @@ pub(crate) mod test {
(storage, key)
}
#[test]
async fn test_initialise() {
let (storage, _) = setup().await;
#[sqlx::test]
async fn test_initialise(pool: PgPool) {
let (storage, _) = setup(pool).await;
if let Err(e) = storage.initialise().await {
panic!("Initialising key storage failed: {:?}", e);
}
}
#[test]
async fn test_store_duplicate_key() {
let (storage, key) = setup().await;
#[sqlx::test]
async fn test_store_duplicate_key(pool: PgPool) {
let (storage, key) = setup(pool).await;
let key = storage
.store_key(1, key.key, vec![Domain::User { id: 1 }])
.await
@ -679,9 +679,9 @@ pub(crate) mod test {
assert_eq!(key.domains.0.len(), 2);
}
#[test]
async fn test_store_duplicate_key_duplicate_domain() {
let (storage, key) = setup().await;
#[sqlx::test]
async fn test_store_duplicate_key_duplicate_domain(pool: PgPool) {
let (storage, key) = setup(pool).await;
let key = storage
.store_key(1, key.key, vec![Domain::All])
.await
@ -690,9 +690,9 @@ pub(crate) mod test {
assert_eq!(key.domains.0.len(), 1);
}
#[test]
async fn test_add_domain() {
let (storage, key) = setup().await;
#[sqlx::test]
async fn test_add_domain(pool: PgPool) {
let (storage, key) = setup(pool).await;
let key = storage
.add_domain_to_key(KeySelector::Key(key.key), Domain::User { id: 12345 })
.await
@ -701,9 +701,9 @@ pub(crate) mod test {
assert!(key.domains.0.contains(&Domain::User { id: 12345 }));
}
#[test]
async fn test_add_domain_id() {
let (storage, key) = setup().await;
#[sqlx::test]
async fn test_add_domain_id(pool: PgPool) {
let (storage, key) = setup(pool).await;
let key = storage
.add_domain_to_key(KeySelector::Id(key.id), Domain::User { id: 12345 })
.await
@ -712,9 +712,9 @@ pub(crate) mod test {
assert!(key.domains.0.contains(&Domain::User { id: 12345 }));
}
#[test]
async fn test_add_duplicate_domain() {
let (storage, key) = setup().await;
#[sqlx::test]
async fn test_add_duplicate_domain(pool: PgPool) {
let (storage, key) = setup(pool).await;
let key = storage
.add_domain_to_key(KeySelector::Key(key.key), Domain::All)
.await
@ -729,9 +729,9 @@ pub(crate) mod test {
);
}
#[test]
async fn test_remove_domain() {
let (storage, key) = setup().await;
#[sqlx::test]
async fn test_remove_domain(pool: PgPool) {
let (storage, key) = setup(pool).await;
storage
.add_domain_to_key(KeySelector::Key(key.key.clone()), Domain::User { id: 1 })
.await
@ -744,9 +744,9 @@ pub(crate) mod test {
assert_eq!(key.domains.0, vec![Domain::All]);
}
#[test]
async fn test_remove_domain_id() {
let (storage, key) = setup().await;
#[sqlx::test]
async fn test_remove_domain_id(pool: PgPool) {
let (storage, key) = setup(pool).await;
storage
.add_domain_to_key(KeySelector::Id(key.id), Domain::User { id: 1 })
.await
@ -759,9 +759,9 @@ pub(crate) mod test {
assert_eq!(key.domains.0, vec![Domain::All]);
}
#[test]
async fn test_remove_last_domain() {
let (storage, key) = setup().await;
#[sqlx::test]
async fn test_remove_last_domain(pool: PgPool) {
let (storage, key) = setup(pool).await;
let key = storage
.remove_domain_from_key(KeySelector::Key(key.key), Domain::All)
.await
@ -770,9 +770,9 @@ pub(crate) mod test {
assert!(key.domains.0.is_empty());
}
#[test]
async fn test_store_key() {
let (storage, _) = setup().await;
#[sqlx::test]
async fn test_store_key(pool: PgPool) {
let (storage, _) = setup(pool).await;
let key = storage
.store_key(1, "ABCDABCDABCDABCD".to_owned(), vec![])
.await
@ -780,26 +780,26 @@ pub(crate) mod test {
assert_eq!(key.value(), "ABCDABCDABCDABCD");
}
#[test]
async fn test_read_user_keys() {
let (storage, _) = setup().await;
#[sqlx::test]
async fn test_read_user_keys(pool: PgPool) {
let (storage, _) = setup(pool).await;
let keys = storage.read_keys(KeySelector::UserId(1)).await.unwrap();
assert_eq!(keys.len(), 1);
}
#[test]
async fn acquire_one() {
let (storage, _) = setup().await;
#[sqlx::test]
async fn acquire_one(pool: PgPool) {
let (storage, _) = setup(pool).await;
if let Err(e) = storage.acquire_key(Domain::All).await {
panic!("Acquiring key failed: {:?}", e);
}
}
#[test]
async fn uses_spread() {
let (storage, _) = setup().await;
#[sqlx::test]
async fn uses_spread(pool: PgPool) {
let (storage, _) = setup(pool).await;
storage
.store_key(1, "ABC".to_owned(), vec![Domain::All])
.await
@ -816,33 +816,37 @@ pub(crate) mod test {
}
}
#[test]
async fn test_flag_key_one() {
let (storage, key) = setup().await;
#[sqlx::test]
async fn test_flag_key_one(pool: PgPool) {
let (storage, key) = setup(pool).await;
assert!(storage.flag_key(key, 2).await.unwrap());
match storage.acquire_key(Domain::All).await.unwrap_err() {
PgStorageError::Unavailable(d) => assert!(matches!(d, KeySelector::Has(Domain::All))),
PgStorageError::Unavailable(KeySelector::Has(domains)) => {
assert_eq!(domains, vec![Domain::All])
}
why => panic!("Expected domain unavailable error but found '{why}'"),
}
}
#[test]
async fn test_flag_key_many() {
let (storage, key) = setup().await;
#[sqlx::test]
async fn test_flag_key_many(pool: PgPool) {
let (storage, key) = setup(pool).await;
assert!(storage.flag_key(key, 2).await.unwrap());
match storage.acquire_many_keys(Domain::All, 5).await.unwrap_err() {
PgStorageError::Unavailable(d) => assert!(matches!(d, KeySelector::Has(Domain::All))),
PgStorageError::Unavailable(KeySelector::Has(domains)) => {
assert_eq!(domains, vec![Domain::All])
}
why => panic!("Expected domain unavailable error but found '{why}'"),
}
}
#[test]
async fn acquire_many() {
let (storage, _) = setup().await;
#[sqlx::test]
async fn acquire_many(pool: PgPool) {
let (storage, _) = setup(pool).await;
match storage.acquire_many_keys(Domain::All, 30).await {
Err(e) => panic!("Acquiring key failed: {:?}", e),
@ -851,9 +855,9 @@ pub(crate) mod test {
}
// HACK: this test is time sensitive and will fail if runs at the top of the minute
#[test]
async fn test_concurrent() {
let storage = Arc::new(setup().await.0);
#[sqlx::test]
async fn test_concurrent(pool: PgPool) {
let storage = Arc::new(setup(pool).await.0);
for _ in 0..10 {
let mut set = tokio::task::JoinSet::new();
@ -884,9 +888,9 @@ pub(crate) mod test {
}
}
#[test]
async fn test_concurrent_spread() {
let storage = Arc::new(setup().await.0);
#[sqlx::test]
async fn test_concurrent_spread(pool: PgPool) {
let storage = Arc::new(setup(pool).await.0);
for i in 0..24 {
storage
@ -923,10 +927,11 @@ pub(crate) mod test {
.unwrap();
}
}
// HACK: this test is time sensitive and will fail if runs at the top of the minute
#[test]
async fn test_concurrent_many() {
let storage = Arc::new(setup().await.0);
#[sqlx::test]
async fn test_concurrent_many(pool: PgPool) {
let storage = Arc::new(setup(pool).await.0);
for _ in 0..10 {
let mut set = tokio::task::JoinSet::new();
@ -956,73 +961,73 @@ pub(crate) mod test {
}
}
#[test]
async fn read_key() {
let (storage, key) = setup().await;
#[sqlx::test]
async fn read_key(pool: PgPool) {
let (storage, key) = setup(pool).await;
let key = storage.read_key(KeySelector::Key(key.key)).await.unwrap();
assert!(key.is_some());
}
#[test]
async fn read_key_id() {
let (storage, key) = setup().await;
#[sqlx::test]
async fn read_key_id(pool: PgPool) {
let (storage, key) = setup(pool).await;
let key = storage.read_key(KeySelector::Id(key.id)).await.unwrap();
assert!(key.is_some());
}
#[test]
async fn read_nonexistent_key() {
let (storage, _) = setup().await;
#[sqlx::test]
async fn read_nonexistent_key(pool: PgPool) {
let (storage, _) = setup(pool).await;
let key = storage.read_key(KeySelector::Id(-1)).await.unwrap();
assert!(key.is_none());
}
#[test]
async fn query_key() {
let (storage, _) = setup().await;
#[sqlx::test]
async fn query_key(pool: PgPool) {
let (storage, _) = setup(pool).await;
let key = storage.read_key(Domain::All).await.unwrap();
assert!(key.is_some());
}
#[test]
async fn query_nonexistent_key() {
let (storage, _) = setup().await;
#[sqlx::test]
async fn query_nonexistent_key(pool: PgPool) {
let (storage, _) = setup(pool).await;
let key = storage.read_key(Domain::Guild { id: 0 }).await.unwrap();
assert!(key.is_none());
}
#[test]
async fn query_all() {
let (storage, _) = setup().await;
#[sqlx::test]
async fn query_all(pool: PgPool) {
let (storage, _) = setup(pool).await;
let keys = storage.read_keys(Domain::All).await.unwrap();
assert!(keys.len() == 1);
}
#[test]
async fn query_by_id() {
let (storage, _) = setup().await;
#[sqlx::test]
async fn query_by_id(pool: PgPool) {
let (storage, _) = setup(pool).await;
let key = storage.read_key(KeySelector::Id(1)).await.unwrap();
assert!(key.is_some());
}
#[test]
async fn query_by_key() {
let (storage, key) = setup().await;
#[sqlx::test]
async fn query_by_key(pool: PgPool) {
let (storage, key) = setup(pool).await;
let key = storage.read_key(KeySelector::Key(key.key)).await.unwrap();
assert!(key.is_some());
}
#[test]
async fn query_by_set() {
let (storage, _key) = setup().await;
#[sqlx::test]
async fn query_by_set(pool: PgPool) {
let (storage, _key) = setup(pool).await;
let key = storage
.read_key(KeySelector::OneOf(vec![
Domain::All,
@ -1034,4 +1039,45 @@ pub(crate) mod test {
assert!(key.is_some());
}
#[sqlx::test]
async fn all_selector(pool: PgPool) {
let (storage, key) = setup(pool).await;
storage
.add_domain_to_key(key.selector(), Domain::Faction { id: 1 })
.await
.unwrap();
let key = storage
.read_key(KeySelector::Has(vec![
Domain::Faction { id: 1 },
Domain::All,
]))
.await
.unwrap();
assert!(key.is_some());
let key = storage
.read_key(KeySelector::Has(vec![
Domain::All,
Domain::Faction { id: 1 },
]))
.await
.unwrap();
assert!(key.is_some());
let key = storage
.read_key(KeySelector::Has(vec![
Domain::All,
Domain::Faction { id: 2 },
Domain::Faction { id: 1 },
]))
.await
.unwrap();
assert!(key.is_none());
}
}

View file

@ -7,7 +7,9 @@ use torn_api::{
ApiRequest, ApiResponse, ApiSelection, ResponseError,
};
use crate::{ApiKey, IntoSelector, KeyPoolError, KeyPoolExecutor, KeyPoolStorage};
use crate::{
ApiKey, IntoSelector, KeyAction, KeyPoolError, KeyPoolExecutor, KeyPoolStorage, PoolOptions,
};
#[async_trait]
impl<'client, C, S> RequestExecutor<C> for KeyPoolExecutor<'client, C, S>
@ -22,17 +24,22 @@ where
client: &C,
mut request: ApiRequest<A>,
id: Option<String>,
) -> Result<ApiResponse, Self::Error>
) -> Result<A::Response, Self::Error>
where
A: ApiSelection,
{
request.comment = self.comment.map(ToOwned::to_owned);
request.comment = self.options.comment.clone();
if let Some(hook) = self.options.hooks_before.get(&std::any::TypeId::of::<A>()) {
let concrete = hook.downcast_ref::<BeforeHook<A>>().unwrap();
(concrete.body)(&mut request);
}
loop {
let key = self
.storage
.acquire_key(self.selector.clone())
.await
.map_err(|e| KeyPoolError::Storage(Arc::new(e)))?;
.map_err(KeyPoolError::Storage)?;
let url = request.url(key.value(), id.as_deref());
let value = client.request(url).await?;
@ -42,14 +49,37 @@ where
.storage
.flag_key(key, code)
.await
.map_err(Arc::new)
.map_err(KeyPoolError::Storage)?
{
return Err(KeyPoolError::Response(ResponseError::Api { code, reason }));
}
}
Err(parsing_error) => return Err(KeyPoolError::Response(parsing_error)),
Ok(res) => return Ok(res),
Ok(res) => {
let res = res.into();
if let Some(hook) = self.options.hooks_after.get(&std::any::TypeId::of::<A>()) {
let concrete = hook.downcast_ref::<AfterHook<A, S::Domain>>().unwrap();
match (concrete.body)(&res) {
Err(KeyAction::Delete) => {
self.storage
.remove_key(key.selector())
.await
.map_err(KeyPoolError::Storage)?;
continue;
}
Err(KeyAction::RemoveDomain(domain)) => {
self.storage
.remove_domain_from_key(key.selector(), domain)
.await
.map_err(KeyPoolError::Storage)?;
continue;
}
_ => (),
};
}
return Ok(res);
}
};
}
}
@ -59,7 +89,7 @@ where
client: &C,
mut request: ApiRequest<A>,
ids: Vec<I>,
) -> HashMap<I, Result<ApiResponse, Self::Error>>
) -> HashMap<I, Result<A::Response, Self::Error>>
where
A: ApiSelection,
I: ToString + std::hash::Hash + std::cmp::Eq + Send + Sync,
@ -71,15 +101,14 @@ where
{
Ok(keys) => keys,
Err(why) => {
let shared = Arc::new(why);
return ids
.into_iter()
.map(|i| (i, Err(Self::Error::Storage(shared.clone()))))
.map(|i| (i, Err(Self::Error::Storage(why.clone()))))
.collect();
}
};
request.comment = self.comment.map(ToOwned::to_owned);
request.comment = self.options.comment.clone();
let request_ref = &request;
let tuples =
@ -105,18 +134,18 @@ where
)
}
Ok(true) => (),
Err(why) => return (id, Err(KeyPoolError::Storage(Arc::new(why)))),
Err(why) => return (id, Err(KeyPoolError::Storage(why))),
}
}
Err(parsing_error) => {
return (id, Err(KeyPoolError::Response(parsing_error)))
}
Ok(res) => return (id, Ok(res)),
Ok(res) => return (id, Ok(res.into())),
};
key = match self.storage.acquire_key(self.selector.clone()).await {
Ok(k) => k,
Err(why) => return (id, Err(Self::Error::Storage(Arc::new(why)))),
Err(why) => return (id, Err(Self::Error::Storage(why))),
};
}
}))
@ -126,6 +155,92 @@ where
}
}
#[allow(clippy::type_complexity)]
pub struct BeforeHook<A>
where
A: ApiSelection,
{
body: Box<dyn Fn(&mut ApiRequest<A>) + Send + Sync + 'static>,
}
#[allow(clippy::type_complexity)]
pub struct AfterHook<A, D>
where
A: ApiSelection,
D: crate::KeyDomain,
{
body: Box<dyn Fn(&A::Response) -> Result<(), crate::KeyAction<D>> + Send + Sync + 'static>,
}
pub struct PoolBuilder<C, S>
where
C: ApiClient,
S: KeyPoolStorage,
{
client: C,
storage: S,
options: crate::PoolOptions,
}
impl<C, S> PoolBuilder<C, S>
where
C: ApiClient,
S: KeyPoolStorage,
{
pub fn new(client: C, storage: S) -> Self {
Self {
client,
storage,
options: Default::default(),
}
}
pub fn comment(mut self, c: impl ToString) -> Self {
self.options.comment = Some(c.to_string());
self
}
pub fn hook_before<A>(
mut self,
hook: impl Fn(&mut ApiRequest<A>) + Send + Sync + 'static,
) -> Self
where
A: ApiSelection + 'static,
{
self.options.hooks_before.insert(
std::any::TypeId::of::<A>(),
Box::new(BeforeHook {
body: Box::new(hook),
}),
);
self
}
pub fn hook_after<A>(
mut self,
hook: impl Fn(&A::Response) -> Result<(), KeyAction<S::Domain>> + Send + Sync + 'static,
) -> Self
where
A: ApiSelection + 'static,
{
self.options.hooks_after.insert(
std::any::TypeId::of::<A>(),
Box::new(AfterHook::<A, S::Domain> {
body: Box::new(hook),
}),
);
self
}
pub fn build(self) -> KeyPool<C, S> {
KeyPool {
client: self.client,
storage: self.storage,
options: Arc::new(self.options),
}
}
}
#[derive(Clone, Debug)]
pub struct KeyPool<C, S>
where
@ -134,7 +249,7 @@ where
{
client: C,
pub storage: S,
comment: Option<String>,
options: Arc<PoolOptions>,
}
impl<C, S> KeyPool<C, S>
@ -142,14 +257,6 @@ where
C: ApiClient,
S: KeyPoolStorage + Send + Sync + 'static,
{
pub fn new(client: C, storage: S, comment: Option<String>) -> Self {
Self {
client,
storage,
comment,
}
}
pub fn torn_api<I>(&self, selector: I) -> ApiProvider<C, KeyPoolExecutor<C, S>>
where
I: IntoSelector<S::Key, S::Domain>,
@ -159,7 +266,7 @@ where
KeyPoolExecutor::new(
&self.storage,
selector.into_selector(),
self.comment.as_deref(),
self.options.clone(),
),
)
}
@ -178,7 +285,7 @@ pub trait WithStorage {
{
ApiProvider::new(
self,
KeyPoolExecutor::new(storage, selector.into_selector(), None),
KeyPoolExecutor::new(storage, selector.into_selector(), Default::default()),
)
}
}
@ -188,27 +295,28 @@ impl WithStorage for reqwest::Client {}
#[cfg(all(test, feature = "postgres", feature = "reqwest"))]
mod test {
use tokio::test;
use sqlx::PgPool;
use super::*;
use crate::postgres::test::{setup, Domain};
use crate::{
postgres::test::{setup, Domain},
KeySelector,
};
#[test]
async fn test_pool_request() {
let (storage, _) = setup().await;
let pool = KeyPool::new(
reqwest::Client::default(),
storage,
Some("api.rs".to_owned()),
);
#[sqlx::test]
async fn test_pool_request(pool: PgPool) {
let (storage, _) = setup(pool).await;
let pool = PoolBuilder::new(reqwest::Client::default(), storage)
.comment("api.rs")
.build();
let response = pool.torn_api(Domain::All).user(|b| b).await.unwrap();
_ = response.profile().unwrap();
}
#[test]
async fn test_with_storage_request() {
let (storage, _) = setup().await;
#[sqlx::test]
async fn test_with_storage_request(pool: PgPool) {
let (storage, _) = setup(pool).await;
let response = reqwest::Client::new()
.with_storage(&storage, Domain::All)
@ -217,4 +325,36 @@ mod test {
.unwrap();
_ = response.profile().unwrap();
}
#[sqlx::test]
async fn before_hook(pool: PgPool) {
let (storage, _) = setup(pool).await;
let pool = PoolBuilder::new(reqwest::Client::default(), storage)
.hook_before::<torn_api::user::UserSelection>(|req| {
req.selections.push("crimes");
})
.build();
let response = pool.torn_api(Domain::All).user(|b| b).await.unwrap();
_ = response.crimes().unwrap();
}
#[sqlx::test]
async fn after_hook(pool: PgPool) {
let (storage, _) = setup(pool).await;
let pool = PoolBuilder::new(reqwest::Client::default(), storage)
.hook_after::<torn_api::user::UserSelection>(|_res| Err(KeyAction::Delete))
.build();
let key = pool.storage.read_key(KeySelector::Id(1)).await.unwrap();
assert!(key.is_some());
let response = pool.torn_api(Domain::All).user(|b| b).await;
assert!(matches!(response, Err(KeyPoolError::Storage(_))));
let key = pool.storage.read_key(KeySelector::Id(1)).await.unwrap();
assert!(key.is_none());
}
}