refactored and expanded postgres keypool
This commit is contained in:
parent
78a5ea37b9
commit
0993f56489
|
@ -1,28 +1,29 @@
|
||||||
[package]
|
[package]
|
||||||
name = "torn-key-pool"
|
name = "torn-key-pool"
|
||||||
version = "0.4.2"
|
version = "0.5.0"
|
||||||
edition = "2021"
|
edition = "2021"
|
||||||
license = "MIT"
|
license = "MIT"
|
||||||
repository = "https://github.com/TotallyNot/torn-api.rs.git"
|
repository = "https://github.com/TotallyNot/torn-api.rs.git"
|
||||||
homepage = "https://github.com/TotallyNot/torn-api.rs.git"
|
homepage = "https://github.com/TotallyNot/torn-api.rs.git"
|
||||||
description = "A generalizes API key pool for torn-api"
|
description = "A generalised API key pool for torn-api"
|
||||||
|
|
||||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||||
|
|
||||||
[features]
|
[features]
|
||||||
default = [ "postgres", "tokio-runtime" ]
|
default = [ "postgres", "tokio-runtime" ]
|
||||||
postgres = [ "dep:sqlx", "dep:chrono", "dep:indoc" ]
|
postgres = [ "dep:sqlx", "dep:chrono", "dep:indoc", "dep:serde" ]
|
||||||
reqwest = [ "dep:reqwest", "torn-api/reqwest" ]
|
reqwest = [ "dep:reqwest", "torn-api/reqwest" ]
|
||||||
awc = [ "dep:awc", "torn-api/awc" ]
|
awc = [ "dep:awc", "torn-api/awc" ]
|
||||||
tokio-runtime = [ "dep:tokio", "dep:rand" ]
|
tokio-runtime = [ "dep:tokio", "dep:rand" ]
|
||||||
actix-runtime = [ "dep:actix-rt", "dep:rand" ]
|
actix-runtime = [ "dep:actix-rt", "dep:rand" ]
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
torn-api = { path = "../torn-api", default-features = false, version = "0.5" }
|
torn-api = { path = "../torn-api", default-features = false, version = "0.5.5" }
|
||||||
async-trait = "0.1"
|
async-trait = "0.1"
|
||||||
thiserror = "1"
|
thiserror = "1"
|
||||||
|
|
||||||
sqlx = { version = "0.6", features = [ "postgres", "chrono" ], optional = true }
|
sqlx = { version = "0.6", features = [ "postgres", "chrono", "json" ], optional = true }
|
||||||
|
serde = { version = "1.0", optional = true }
|
||||||
chrono = { version = "0.4", optional = true }
|
chrono = { version = "0.4", optional = true }
|
||||||
indoc = { version = "1", optional = true }
|
indoc = { version = "1", optional = true }
|
||||||
tokio = { version = "1", optional = true, default-features = false, features = ["time"] }
|
tokio = { version = "1", optional = true, default-features = false, features = ["time"] }
|
||||||
|
@ -37,7 +38,7 @@ awc = { version = "3", default-features = false, optional = true }
|
||||||
torn-api = { path = "../torn-api", features = [ "reqwest" ] }
|
torn-api = { path = "../torn-api", features = [ "reqwest" ] }
|
||||||
sqlx = { version = "0.6", features = [ "runtime-tokio-rustls" ] }
|
sqlx = { version = "0.6", features = [ "runtime-tokio-rustls" ] }
|
||||||
dotenv = "0.15.0"
|
dotenv = "0.15.0"
|
||||||
tokio = { version = "1.20.1", features = ["test-util", "rt", "macros"] }
|
tokio = { version = "1.24.2", features = ["test-util", "rt", "macros"] }
|
||||||
tokio-test = "0.4.2"
|
tokio-test = "0.4.2"
|
||||||
reqwest = { version = "0.11", default-features = true }
|
reqwest = { version = "0.11", default-features = true }
|
||||||
awc = { version = "3", features = [ "rustls" ] }
|
awc = { version = "3", features = [ "rustls" ] }
|
||||||
|
|
|
@ -29,31 +29,57 @@ where
|
||||||
Response(ResponseError),
|
Response(ResponseError),
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy)]
|
|
||||||
pub enum KeyDomain {
|
|
||||||
Public,
|
|
||||||
User(i32),
|
|
||||||
Faction(i32),
|
|
||||||
}
|
|
||||||
|
|
||||||
pub trait ApiKey: Sync + Send {
|
pub trait ApiKey: Sync + Send {
|
||||||
fn value(&self) -> &str;
|
fn value(&self) -> &str;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub trait KeyDomain: Clone + std::fmt::Debug + Send + Sync {}
|
||||||
|
|
||||||
|
impl<T> KeyDomain for T where T: Clone + std::fmt::Debug + Send + Sync {}
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
pub trait KeyPoolStorage {
|
pub trait KeyPoolStorage {
|
||||||
type Key: ApiKey;
|
type Key: ApiKey;
|
||||||
|
type Domain: KeyDomain;
|
||||||
type Error: std::error::Error + Sync + Send;
|
type Error: std::error::Error + Sync + Send;
|
||||||
|
|
||||||
async fn acquire_key(&self, domain: KeyDomain) -> Result<Self::Key, Self::Error>;
|
async fn acquire_key(&self, domain: Self::Domain) -> Result<Self::Key, Self::Error>;
|
||||||
|
|
||||||
async fn acquire_many_keys(
|
async fn acquire_many_keys(
|
||||||
&self,
|
&self,
|
||||||
domain: KeyDomain,
|
domain: Self::Domain,
|
||||||
number: i64,
|
number: i64,
|
||||||
) -> Result<Vec<Self::Key>, Self::Error>;
|
) -> Result<Vec<Self::Key>, Self::Error>;
|
||||||
|
|
||||||
async fn flag_key(&self, key: Self::Key, code: u8) -> Result<bool, Self::Error>;
|
async fn flag_key(&self, key: Self::Key, code: u8) -> Result<bool, Self::Error>;
|
||||||
|
|
||||||
|
async fn store_key(
|
||||||
|
&self,
|
||||||
|
key: String,
|
||||||
|
domains: Vec<Self::Domain>,
|
||||||
|
) -> Result<Self::Key, Self::Error>;
|
||||||
|
|
||||||
|
async fn read_key(&self, key: String) -> Result<Self::Key, Self::Error>;
|
||||||
|
|
||||||
|
async fn remove_key(&self, key: String) -> Result<Self::Key, Self::Error>;
|
||||||
|
|
||||||
|
async fn add_domain_to_key(
|
||||||
|
&self,
|
||||||
|
key: String,
|
||||||
|
domain: Self::Domain,
|
||||||
|
) -> Result<Self::Key, Self::Error>;
|
||||||
|
|
||||||
|
async fn remove_domain_from_key(
|
||||||
|
&self,
|
||||||
|
key: String,
|
||||||
|
domain: Self::Domain,
|
||||||
|
) -> Result<Self::Key, Self::Error>;
|
||||||
|
|
||||||
|
async fn set_domains_for_key(
|
||||||
|
&self,
|
||||||
|
key: String,
|
||||||
|
domains: Vec<Self::Domain>,
|
||||||
|
) -> Result<Self::Key, Self::Error>;
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
|
@ -62,7 +88,8 @@ where
|
||||||
S: KeyPoolStorage,
|
S: KeyPoolStorage,
|
||||||
{
|
{
|
||||||
storage: &'a S,
|
storage: &'a S,
|
||||||
domain: KeyDomain,
|
comment: Option<&'a str>,
|
||||||
|
domain: S::Domain,
|
||||||
_marker: std::marker::PhantomData<C>,
|
_marker: std::marker::PhantomData<C>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -70,52 +97,15 @@ impl<'a, C, S> KeyPoolExecutor<'a, C, S>
|
||||||
where
|
where
|
||||||
S: KeyPoolStorage,
|
S: KeyPoolStorage,
|
||||||
{
|
{
|
||||||
pub fn new(storage: &'a S, domain: KeyDomain) -> Self {
|
pub fn new(storage: &'a S, domain: S::Domain, comment: Option<&'a str>) -> Self {
|
||||||
Self {
|
Self {
|
||||||
storage,
|
storage,
|
||||||
domain,
|
domain,
|
||||||
|
comment,
|
||||||
_marker: std::marker::PhantomData,
|
_marker: std::marker::PhantomData,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(all(test, feature = "postgres"))]
|
#[cfg(all(test, feature = "postgres"))]
|
||||||
mod test {
|
mod test {}
|
||||||
use std::sync::Once;
|
|
||||||
|
|
||||||
use tokio::test;
|
|
||||||
|
|
||||||
use super::*;
|
|
||||||
|
|
||||||
static INIT: Once = Once::new();
|
|
||||||
|
|
||||||
pub(crate) async fn setup() -> postgres::PgKeyPoolStorage {
|
|
||||||
INIT.call_once(|| {
|
|
||||||
dotenv::dotenv().ok();
|
|
||||||
});
|
|
||||||
|
|
||||||
let pool = sqlx::PgPool::connect(&std::env::var("DATABASE_URL").unwrap())
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
sqlx::query("update api_keys set uses=0")
|
|
||||||
.execute(&pool)
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
postgres::PgKeyPoolStorage::new(pool, 50)
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
async fn key_pool_bulk() {
|
|
||||||
let storage = setup().await;
|
|
||||||
|
|
||||||
if let Err(e) = storage.initialise().await {
|
|
||||||
panic!("Initialising key storage failed: {:?}", e);
|
|
||||||
}
|
|
||||||
|
|
||||||
let pool = send::KeyPool::new(reqwest::Client::default(), storage);
|
|
||||||
|
|
||||||
pool.torn_api(KeyDomain::Public).users([1], |b| b).await;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
|
@ -7,7 +7,7 @@ use torn_api::{
|
||||||
ApiCategoryResponse, ApiRequest, ApiResponse, ResponseError,
|
ApiCategoryResponse, ApiRequest, ApiResponse, ResponseError,
|
||||||
};
|
};
|
||||||
|
|
||||||
use crate::{ApiKey, KeyDomain, KeyPoolError, KeyPoolExecutor, KeyPoolStorage};
|
use crate::{ApiKey, KeyPoolError, KeyPoolExecutor, KeyPoolStorage};
|
||||||
|
|
||||||
#[async_trait(?Send)]
|
#[async_trait(?Send)]
|
||||||
impl<'client, C, S> RequestExecutor<C> for KeyPoolExecutor<'client, C, S>
|
impl<'client, C, S> RequestExecutor<C> for KeyPoolExecutor<'client, C, S>
|
||||||
|
@ -20,16 +20,17 @@ where
|
||||||
async fn execute<A>(
|
async fn execute<A>(
|
||||||
&self,
|
&self,
|
||||||
client: &C,
|
client: &C,
|
||||||
request: ApiRequest<A>,
|
mut request: ApiRequest<A>,
|
||||||
id: Option<i64>,
|
id: Option<i64>,
|
||||||
) -> Result<A, Self::Error>
|
) -> Result<A, Self::Error>
|
||||||
where
|
where
|
||||||
A: ApiCategoryResponse,
|
A: ApiCategoryResponse,
|
||||||
{
|
{
|
||||||
|
request.comment = self.comment.map(ToOwned::to_owned);
|
||||||
loop {
|
loop {
|
||||||
let key = self
|
let key = self
|
||||||
.storage
|
.storage
|
||||||
.acquire_key(self.domain)
|
.acquire_key(self.domain.clone())
|
||||||
.await
|
.await
|
||||||
.map_err(|e| KeyPoolError::Storage(Arc::new(e)))?;
|
.map_err(|e| KeyPoolError::Storage(Arc::new(e)))?;
|
||||||
let url = request.url(key.value(), id);
|
let url = request.url(key.value(), id);
|
||||||
|
@ -56,7 +57,7 @@ where
|
||||||
async fn execute_many<A>(
|
async fn execute_many<A>(
|
||||||
&self,
|
&self,
|
||||||
client: &C,
|
client: &C,
|
||||||
request: ApiRequest<A>,
|
mut request: ApiRequest<A>,
|
||||||
ids: Vec<i64>,
|
ids: Vec<i64>,
|
||||||
) -> HashMap<i64, Result<A, Self::Error>>
|
) -> HashMap<i64, Result<A, Self::Error>>
|
||||||
where
|
where
|
||||||
|
@ -64,7 +65,7 @@ where
|
||||||
{
|
{
|
||||||
let keys = match self
|
let keys = match self
|
||||||
.storage
|
.storage
|
||||||
.acquire_many_keys(self.domain, ids.len() as i64)
|
.acquire_many_keys(self.domain.clone(), ids.len() as i64)
|
||||||
.await
|
.await
|
||||||
{
|
{
|
||||||
Ok(keys) => keys,
|
Ok(keys) => keys,
|
||||||
|
@ -77,6 +78,7 @@ where
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
request.comment = self.comment.map(ToOwned::to_owned);
|
||||||
let request_ref = &request;
|
let request_ref = &request;
|
||||||
|
|
||||||
futures::future::join_all(std::iter::zip(ids, keys).map(|(id, mut key)| async move {
|
futures::future::join_all(std::iter::zip(ids, keys).map(|(id, mut key)| async move {
|
||||||
|
@ -107,7 +109,7 @@ where
|
||||||
Ok(res) => return (id, Ok(A::from_response(res))),
|
Ok(res) => return (id, Ok(A::from_response(res))),
|
||||||
};
|
};
|
||||||
|
|
||||||
key = match self.storage.acquire_key(self.domain).await {
|
key = match self.storage.acquire_key(self.domain.clone()).await {
|
||||||
Ok(k) => k,
|
Ok(k) => k,
|
||||||
Err(why) => return (id, Err(Self::Error::Storage(Arc::new(why)))),
|
Err(why) => return (id, Err(Self::Error::Storage(Arc::new(why)))),
|
||||||
};
|
};
|
||||||
|
@ -127,6 +129,7 @@ where
|
||||||
{
|
{
|
||||||
client: C,
|
client: C,
|
||||||
storage: S,
|
storage: S,
|
||||||
|
comment: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<C, S> KeyPool<C, S>
|
impl<C, S> KeyPool<C, S>
|
||||||
|
@ -134,12 +137,19 @@ where
|
||||||
C: ApiClient,
|
C: ApiClient,
|
||||||
S: KeyPoolStorage + 'static,
|
S: KeyPoolStorage + 'static,
|
||||||
{
|
{
|
||||||
pub fn new(client: C, storage: S) -> Self {
|
pub fn new(client: C, storage: S, comment: Option<String>) -> Self {
|
||||||
Self { client, storage }
|
Self {
|
||||||
|
client,
|
||||||
|
storage,
|
||||||
|
comment,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn torn_api(&self, domain: KeyDomain) -> ApiProvider<C, KeyPoolExecutor<C, S>> {
|
pub fn torn_api(&self, domain: S::Domain) -> ApiProvider<C, KeyPoolExecutor<C, S>> {
|
||||||
ApiProvider::new(&self.client, KeyPoolExecutor::new(&self.storage, domain))
|
ApiProvider::new(
|
||||||
|
&self.client,
|
||||||
|
KeyPoolExecutor::new(&self.storage, domain, self.comment.as_deref()),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -147,15 +157,44 @@ pub trait WithStorage {
|
||||||
fn with_storage<'a, S>(
|
fn with_storage<'a, S>(
|
||||||
&'a self,
|
&'a self,
|
||||||
storage: &'a S,
|
storage: &'a S,
|
||||||
domain: KeyDomain,
|
domain: S::Domain,
|
||||||
) -> ApiProvider<Self, KeyPoolExecutor<Self, S>>
|
) -> ApiProvider<Self, KeyPoolExecutor<Self, S>>
|
||||||
where
|
where
|
||||||
Self: ApiClient + Sized,
|
Self: ApiClient + Sized,
|
||||||
S: KeyPoolStorage + 'static,
|
S: KeyPoolStorage + 'static,
|
||||||
{
|
{
|
||||||
ApiProvider::new(self, KeyPoolExecutor::new(storage, domain))
|
ApiProvider::new(self, KeyPoolExecutor::new(storage, domain, None))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(feature = "awc")]
|
#[cfg(feature = "awc")]
|
||||||
impl WithStorage for awc::Client {}
|
impl WithStorage for awc::Client {}
|
||||||
|
|
||||||
|
#[cfg(all(test, feature = "postgres", feature = "awc"))]
|
||||||
|
mod test {
|
||||||
|
use tokio::test;
|
||||||
|
|
||||||
|
use super::*;
|
||||||
|
use crate::postgres::test::{setup, Domain};
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
async fn test_pool_request() {
|
||||||
|
let storage = setup().await;
|
||||||
|
let pool = KeyPool::new(awc::Client::default(), storage);
|
||||||
|
|
||||||
|
let response = pool.torn_api(Domain::All).user(|b| b).await.unwrap();
|
||||||
|
_ = response.profile().unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
async fn test_with_storage_request() {
|
||||||
|
let storage = setup().await;
|
||||||
|
|
||||||
|
let response = awc::Client::new()
|
||||||
|
.with_storage(&storage, Domain::All)
|
||||||
|
.user(|b| b)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
_ = response.profile().unwrap();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -5,51 +5,98 @@ use thiserror::Error;
|
||||||
|
|
||||||
use crate::{ApiKey, KeyDomain, KeyPoolStorage};
|
use crate::{ApiKey, KeyDomain, KeyPoolStorage};
|
||||||
|
|
||||||
|
pub trait PgKeyDomain:
|
||||||
|
KeyDomain + serde::Serialize + serde::de::DeserializeOwned + Eq + Unpin
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T> PgKeyDomain for T where
|
||||||
|
T: KeyDomain + serde::Serialize + serde::de::DeserializeOwned + Eq + Unpin
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Error)]
|
#[derive(Debug, Error)]
|
||||||
pub enum PgStorageError {
|
pub enum PgStorageError<D>
|
||||||
|
where
|
||||||
|
D: std::fmt::Debug,
|
||||||
|
{
|
||||||
#[error(transparent)]
|
#[error(transparent)]
|
||||||
Pg(#[from] sqlx::Error),
|
Pg(#[from] sqlx::Error),
|
||||||
|
|
||||||
#[error("No key avalaible for domain {0:?}")]
|
#[error("No key avalaible for domain {0:?}")]
|
||||||
Unavailable(KeyDomain),
|
Unavailable(D),
|
||||||
|
|
||||||
|
#[error("Duplicate key '{0}'")]
|
||||||
|
DuplicateKey(String),
|
||||||
|
|
||||||
|
#[error("Duplicate domain '{0:?}'")]
|
||||||
|
DuplicateDomain(D),
|
||||||
|
|
||||||
|
#[error("Key not found: '{0}'")]
|
||||||
|
KeyNotFound(String),
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, FromRow)]
|
#[derive(Debug, Clone, FromRow)]
|
||||||
pub struct PgKey {
|
pub struct PgKey<D>
|
||||||
|
where
|
||||||
|
D: PgKeyDomain,
|
||||||
|
{
|
||||||
pub id: i32,
|
pub id: i32,
|
||||||
pub key: String,
|
pub key: String,
|
||||||
pub uses: i16,
|
pub uses: i16,
|
||||||
|
pub domains: sqlx::types::Json<Vec<D>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, FromRow)]
|
#[derive(Debug, Clone, FromRow)]
|
||||||
pub struct PgKeyPoolStorage {
|
pub struct PgKeyPoolStorage<D>
|
||||||
|
where
|
||||||
|
D: serde::Serialize + serde::de::DeserializeOwned + Send + Sync + 'static,
|
||||||
|
{
|
||||||
pool: PgPool,
|
pool: PgPool,
|
||||||
limit: i16,
|
limit: i16,
|
||||||
|
_phantom: std::marker::PhantomData<D>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ApiKey for PgKey {
|
impl<D> ApiKey for PgKey<D>
|
||||||
|
where
|
||||||
|
D: PgKeyDomain,
|
||||||
|
{
|
||||||
fn value(&self) -> &str {
|
fn value(&self) -> &str {
|
||||||
&self.key
|
&self.key
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl PgKeyPoolStorage {
|
impl<D> PgKeyPoolStorage<D>
|
||||||
|
where
|
||||||
|
D: PgKeyDomain,
|
||||||
|
{
|
||||||
pub fn new(pool: PgPool, limit: i16) -> Self {
|
pub fn new(pool: PgPool, limit: i16) -> Self {
|
||||||
Self { pool, limit }
|
Self {
|
||||||
|
pool,
|
||||||
|
limit,
|
||||||
|
_phantom: Default::default(),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn initialise(&self) -> Result<(), PgStorageError> {
|
pub async fn initialise(&self) -> Result<(), PgStorageError<D>> {
|
||||||
sqlx::query(indoc! {r#"
|
sqlx::query(indoc! {r#"
|
||||||
CREATE TABLE IF NOT EXISTS api_keys (
|
CREATE TABLE IF NOT EXISTS api_keys (
|
||||||
id serial primary key,
|
id serial primary key,
|
||||||
user_id int4 not null,
|
|
||||||
faction_id int4,
|
|
||||||
key char(16) not null,
|
key char(16) not null,
|
||||||
uses int2 not null default 0,
|
uses int2 not null default 0,
|
||||||
"user" bool not null,
|
domains jsonb not null default '{}'::jsonb,
|
||||||
faction bool not null,
|
last_used timestamptz not null default now(),
|
||||||
last_used timestamptz not null default now()
|
flag int2,
|
||||||
)"#})
|
cooldown timestamptz,
|
||||||
|
constraint "uq:api_keys.key" UNIQUE(key)
|
||||||
|
)"#
|
||||||
|
})
|
||||||
|
.execute(&self.pool)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
sqlx::query(indoc! {r#"
|
||||||
|
CREATE INDEX IF NOT EXISTS "idx:api_keys.domains" ON api_keys USING GIN(domains jsonb_path_ops)
|
||||||
|
"#})
|
||||||
.execute(&self.pool)
|
.execute(&self.pool)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
|
@ -72,63 +119,68 @@ async fn random_sleep() {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl KeyPoolStorage for PgKeyPoolStorage {
|
impl<D> KeyPoolStorage for PgKeyPoolStorage<D>
|
||||||
type Key = PgKey;
|
where
|
||||||
|
D: PgKeyDomain,
|
||||||
|
{
|
||||||
|
type Key = PgKey<D>;
|
||||||
|
type Domain = D;
|
||||||
|
|
||||||
type Error = PgStorageError;
|
type Error = PgStorageError<D>;
|
||||||
|
|
||||||
async fn acquire_key(&self, domain: KeyDomain) -> Result<Self::Key, Self::Error> {
|
|
||||||
let predicate = match domain {
|
|
||||||
KeyDomain::Public => "".to_owned(),
|
|
||||||
KeyDomain::User(id) => format!(" and user_id={} and user", id),
|
|
||||||
KeyDomain::Faction(id) => format!(" and faction_id={} and faction", id),
|
|
||||||
};
|
|
||||||
|
|
||||||
|
async fn acquire_key(&self, domain: D) -> Result<Self::Key, Self::Error> {
|
||||||
loop {
|
loop {
|
||||||
let attempt = async {
|
let attempt = async {
|
||||||
let mut tx = self.pool.begin().await?;
|
let mut tx = self.pool.begin().await?;
|
||||||
|
|
||||||
sqlx::query("set transaction isolation level serializable")
|
sqlx::query("set transaction isolation level repeatable read")
|
||||||
.execute(&mut tx)
|
.execute(&mut tx)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
let key: Option<PgKey> = sqlx::query_as(&indoc::formatdoc!(
|
let key = sqlx::query_as(&indoc::formatdoc!(
|
||||||
r#"
|
r#"
|
||||||
with key as (
|
with key as (
|
||||||
select
|
select
|
||||||
id,
|
id,
|
||||||
0::int2 as uses
|
0::int2 as uses
|
||||||
from api_keys where last_used < date_trunc('minute', now()){predicate}
|
from api_keys where last_used < date_trunc('minute', now()) and (cooldown is null or now() >= cooldown) and domains @> $1
|
||||||
union (
|
union (
|
||||||
select id, uses from api_keys where last_used >= date_trunc('minute', now()){predicate} order by uses asc
|
select id, uses from api_keys
|
||||||
|
where last_used >= date_trunc('minute', now()) and (cooldown is null or now() >= cooldown) and domains @> $1
|
||||||
|
order by uses asc
|
||||||
)
|
)
|
||||||
limit 1
|
limit 1
|
||||||
)
|
)
|
||||||
update api_keys set
|
update api_keys set
|
||||||
uses = key.uses + 1,
|
uses = key.uses + 1,
|
||||||
|
cooldown = null,
|
||||||
|
flag = null,
|
||||||
last_used = now()
|
last_used = now()
|
||||||
from key where
|
from key where
|
||||||
api_keys.id=key.id and key.uses < $1
|
api_keys.id=key.id and key.uses < $2
|
||||||
returning
|
returning
|
||||||
api_keys.id,
|
api_keys.id,
|
||||||
api_keys.key,
|
api_keys.key,
|
||||||
api_keys.uses
|
api_keys.uses,
|
||||||
|
api_keys.domains
|
||||||
"#,
|
"#,
|
||||||
))
|
))
|
||||||
|
.bind(sqlx::types::Json(vec![&domain]))
|
||||||
.bind(self.limit)
|
.bind(self.limit)
|
||||||
.fetch_optional(&mut tx)
|
.fetch_optional(&mut tx)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
tx.commit().await?;
|
tx.commit().await?;
|
||||||
|
|
||||||
Result::<Result<Self::Key, Self::Error>, sqlx::Error>::Ok(
|
Result::<Option<Self::Key>, sqlx::Error>::Ok(
|
||||||
key.ok_or(PgStorageError::Unavailable(domain)),
|
key
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
match attempt {
|
match attempt {
|
||||||
Ok(result) => return result,
|
Ok(Some(result)) => return Ok(result),
|
||||||
|
Ok(None) => return Err(PgStorageError::Unavailable(domain)),
|
||||||
Err(error) => {
|
Err(error) => {
|
||||||
if let Some(db_error) = error.as_database_error() {
|
if let Some(db_error) = error.as_database_error() {
|
||||||
let pg_error: &sqlx::postgres::PgDatabaseError = db_error.downcast_ref();
|
let pg_error: &sqlx::postgres::PgDatabaseError = db_error.downcast_ref();
|
||||||
|
@ -147,45 +199,42 @@ impl KeyPoolStorage for PgKeyPoolStorage {
|
||||||
|
|
||||||
async fn acquire_many_keys(
|
async fn acquire_many_keys(
|
||||||
&self,
|
&self,
|
||||||
domain: KeyDomain,
|
domain: D,
|
||||||
number: i64,
|
number: i64,
|
||||||
) -> Result<Vec<Self::Key>, Self::Error> {
|
) -> Result<Vec<Self::Key>, Self::Error> {
|
||||||
let predicate = match domain {
|
|
||||||
KeyDomain::Public => "".to_owned(),
|
|
||||||
KeyDomain::User(id) => format!(" and user_id={} and user", id),
|
|
||||||
KeyDomain::Faction(id) => format!(" and faction_id={} and faction", id),
|
|
||||||
};
|
|
||||||
|
|
||||||
loop {
|
loop {
|
||||||
let attempt = async {
|
let attempt = async {
|
||||||
let mut tx = self.pool.begin().await?;
|
let mut tx = self.pool.begin().await?;
|
||||||
|
|
||||||
sqlx::query("set transaction isolation level serializable")
|
sqlx::query("set transaction isolation level repeatable read")
|
||||||
.execute(&mut tx)
|
.execute(&mut tx)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
let mut keys: Vec<PgKey> = sqlx::query_as(&indoc::formatdoc!(
|
let mut keys: Vec<Self::Key> = sqlx::query_as(&indoc::formatdoc!(
|
||||||
r#"select
|
r#"select
|
||||||
id,
|
id,
|
||||||
key,
|
key,
|
||||||
0::int2 as uses
|
0::int2 as uses,
|
||||||
from api_keys where last_used < date_trunc('minute', now()){predicate}
|
domains
|
||||||
|
from api_keys where last_used < date_trunc('minute', now()) and (cooldown is null or now() >= cooldown) and domains @> $1
|
||||||
union
|
union
|
||||||
select
|
select
|
||||||
id,
|
id,
|
||||||
key,
|
key,
|
||||||
uses
|
uses,
|
||||||
from api_keys where last_used >= date_trunc('minute', now()){predicate}
|
domains
|
||||||
order by uses limit $1
|
from api_keys where last_used >= date_trunc('minute', now()) and (cooldown is null or now() >= cooldown) and domains @> $1
|
||||||
|
order by uses limit $2
|
||||||
"#,
|
"#,
|
||||||
))
|
))
|
||||||
|
.bind(sqlx::types::Json(vec![&domain]))
|
||||||
.bind(number)
|
.bind(number)
|
||||||
.fetch_all(&mut tx)
|
.fetch_all(&mut tx)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
if keys.is_empty() {
|
if keys.is_empty() {
|
||||||
tx.commit().await?;
|
tx.commit().await?;
|
||||||
return Ok(Err(PgStorageError::Unavailable(domain)));
|
return Ok(None);
|
||||||
}
|
}
|
||||||
|
|
||||||
keys.sort_unstable_by(|k1, k2| k1.uses.cmp(&k2.uses));
|
keys.sort_unstable_by(|k1, k2| k1.uses.cmp(&k2.uses));
|
||||||
|
@ -217,6 +266,8 @@ impl KeyPoolStorage for PgKeyPoolStorage {
|
||||||
sqlx::query(indoc! {r#"
|
sqlx::query(indoc! {r#"
|
||||||
update api_keys set
|
update api_keys set
|
||||||
uses = tmp.uses,
|
uses = tmp.uses,
|
||||||
|
cooldown = null,
|
||||||
|
flag = null,
|
||||||
last_used = now()
|
last_used = now()
|
||||||
from (select unnest($1::int4[]) as id, unnest($2::int2[]) as uses) as tmp
|
from (select unnest($1::int4[]) as id, unnest($2::int2[]) as uses) as tmp
|
||||||
where api_keys.id = tmp.id
|
where api_keys.id = tmp.id
|
||||||
|
@ -228,12 +279,13 @@ impl KeyPoolStorage for PgKeyPoolStorage {
|
||||||
|
|
||||||
tx.commit().await?;
|
tx.commit().await?;
|
||||||
|
|
||||||
Result::<Result<Vec<Self::Key>, Self::Error>, sqlx::Error>::Ok(Ok(result))
|
Result::<Option<Vec<Self::Key>>, sqlx::Error>::Ok(Some(result))
|
||||||
}
|
}
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
match attempt {
|
match attempt {
|
||||||
Ok(result) => return result,
|
Ok(Some(result)) => return Ok(result),
|
||||||
|
Ok(None) => return Err(Self::Error::Unavailable(domain)),
|
||||||
Err(error) => {
|
Err(error) => {
|
||||||
if let Some(db_error) = error.as_database_error() {
|
if let Some(db_error) = error.as_database_error() {
|
||||||
let pg_error: &sqlx::postgres::PgDatabaseError = db_error.downcast_ref();
|
let pg_error: &sqlx::postgres::PgDatabaseError = db_error.downcast_ref();
|
||||||
|
@ -254,7 +306,41 @@ impl KeyPoolStorage for PgKeyPoolStorage {
|
||||||
// TODO: put keys in cooldown when appropriate
|
// TODO: put keys in cooldown when appropriate
|
||||||
match code {
|
match code {
|
||||||
2 | 10 | 13 => {
|
2 | 10 | 13 => {
|
||||||
sqlx::query("delete from api_keys where id=$1")
|
// invalid key, owner fedded or owner inactive
|
||||||
|
sqlx::query(
|
||||||
|
"update api_keys set cooldown='infinity'::timestamptz, flag=$1 where id=$2",
|
||||||
|
)
|
||||||
|
.bind(code as i16)
|
||||||
|
.bind(key.id)
|
||||||
|
.execute(&self.pool)
|
||||||
|
.await?;
|
||||||
|
Ok(true)
|
||||||
|
}
|
||||||
|
5 => {
|
||||||
|
// too many requests
|
||||||
|
sqlx::query("update api_keys set cooldown=date_trunc('min', now()) + interval '1 min', flag=5 where id=$1")
|
||||||
|
.bind(key.id)
|
||||||
|
.execute(&self.pool)
|
||||||
|
.await?;
|
||||||
|
Ok(true)
|
||||||
|
}
|
||||||
|
8 => {
|
||||||
|
// IP block
|
||||||
|
sqlx::query("update api_keys set cooldown=now() + interval '5 min', flag=8")
|
||||||
|
.execute(&self.pool)
|
||||||
|
.await?;
|
||||||
|
Ok(false)
|
||||||
|
}
|
||||||
|
9 => {
|
||||||
|
// API disabled
|
||||||
|
sqlx::query("update api_keys set cooldown=now() + interval '1 min', flag=9")
|
||||||
|
.execute(&self.pool)
|
||||||
|
.await?;
|
||||||
|
Ok(false)
|
||||||
|
}
|
||||||
|
14 => {
|
||||||
|
// daily read limit reached
|
||||||
|
sqlx::query("update api_keys set cooldown=date_trunc('day', now()) + interval '1 day', flag=14 where id=$1")
|
||||||
.bind(key.id)
|
.bind(key.id)
|
||||||
.execute(&self.pool)
|
.execute(&self.pool)
|
||||||
.await?;
|
.await?;
|
||||||
|
@ -263,19 +349,115 @@ impl KeyPoolStorage for PgKeyPoolStorage {
|
||||||
_ => Ok(false),
|
_ => Ok(false),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn store_key(&self, key: String, domains: Vec<D>) -> Result<Self::Key, Self::Error> {
|
||||||
|
sqlx::query_as("insert into api_keys(key, domains) values ($1, $2) returning *")
|
||||||
|
.bind(&key)
|
||||||
|
.bind(sqlx::types::Json(domains))
|
||||||
|
.fetch_one(&self.pool)
|
||||||
|
.await
|
||||||
|
.map_err(|why| {
|
||||||
|
if let Some(error) = why.as_database_error() {
|
||||||
|
let pg_error: &sqlx::postgres::PgDatabaseError = error.downcast_ref();
|
||||||
|
if pg_error.code() == "23505" {
|
||||||
|
return PgStorageError::DuplicateKey(key);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
PgStorageError::Pg(why)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn read_key(&self, key: String) -> Result<Self::Key, Self::Error> {
|
||||||
|
sqlx::query_as("select * from api_keys where key=$1")
|
||||||
|
.bind(&key)
|
||||||
|
.fetch_optional(&self.pool)
|
||||||
|
.await?
|
||||||
|
.ok_or_else(|| PgStorageError::KeyNotFound(key))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn remove_key(&self, key: String) -> Result<Self::Key, Self::Error> {
|
||||||
|
sqlx::query_as("delete from api_keys where key=$1 returning *")
|
||||||
|
.bind(&key)
|
||||||
|
.fetch_optional(&self.pool)
|
||||||
|
.await?
|
||||||
|
.ok_or_else(|| PgStorageError::KeyNotFound(key))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn add_domain_to_key(&self, key: String, domain: D) -> Result<Self::Key, Self::Error> {
|
||||||
|
let mut tx = self.pool.begin().await?;
|
||||||
|
match sqlx::query_as::<sqlx::Postgres, PgKey<D>>(
|
||||||
|
"update api_keys set domains = domains || jsonb_build_array($1) where key=$2 returning *",
|
||||||
|
)
|
||||||
|
.bind(sqlx::types::Json(domain.clone()))
|
||||||
|
.bind(&key)
|
||||||
|
.fetch_optional(&mut tx)
|
||||||
|
.await?
|
||||||
|
{
|
||||||
|
None => Err(PgStorageError::KeyNotFound(key)),
|
||||||
|
Some(key) => {
|
||||||
|
if key.domains.0.iter().filter(|d| **d == domain).count() > 1 {
|
||||||
|
tx.rollback().await?;
|
||||||
|
return Err(PgStorageError::DuplicateDomain(domain));
|
||||||
|
}
|
||||||
|
tx.commit().await?;
|
||||||
|
Ok(key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn remove_domain_from_key(
|
||||||
|
&self,
|
||||||
|
key: String,
|
||||||
|
domain: D,
|
||||||
|
) -> Result<Self::Key, Self::Error> {
|
||||||
|
// FIX: potential race condition
|
||||||
|
let api_key = self.read_key(key.clone()).await?;
|
||||||
|
let domains = api_key
|
||||||
|
.domains
|
||||||
|
.0
|
||||||
|
.into_iter()
|
||||||
|
.filter(|d| *d != domain)
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
self.set_domains_for_key(key, domains).await
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn set_domains_for_key(
|
||||||
|
&self,
|
||||||
|
key: String,
|
||||||
|
domains: Vec<D>,
|
||||||
|
) -> Result<Self::Key, Self::Error> {
|
||||||
|
sqlx::query_as::<sqlx::Postgres, PgKey<D>>(
|
||||||
|
"update api_keys set domains = $1 where key=$2 returning *",
|
||||||
|
)
|
||||||
|
.bind(sqlx::types::Json(domains))
|
||||||
|
.bind(&key)
|
||||||
|
.fetch_optional(&self.pool)
|
||||||
|
.await?
|
||||||
|
.ok_or_else(|| PgStorageError::KeyNotFound(key))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod test {
|
pub(crate) mod test {
|
||||||
use std::sync::{Arc, Once};
|
use std::sync::{Arc, Once};
|
||||||
|
|
||||||
|
use sqlx::Row;
|
||||||
use tokio::test;
|
use tokio::test;
|
||||||
|
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
static INIT: Once = Once::new();
|
static INIT: Once = Once::new();
|
||||||
|
|
||||||
pub(crate) async fn setup() -> PgKeyPoolStorage {
|
#[derive(Debug, PartialEq, Eq, Clone, serde::Serialize, serde::Deserialize)]
|
||||||
|
#[serde(tag = "type", rename_all = "snake_case")]
|
||||||
|
pub(crate) enum Domain {
|
||||||
|
All,
|
||||||
|
User { id: i32 },
|
||||||
|
Faction { id: i32 },
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) async fn setup() -> PgKeyPoolStorage<Domain> {
|
||||||
INIT.call_once(|| {
|
INIT.call_once(|| {
|
||||||
dotenv::dotenv().ok();
|
dotenv::dotenv().ok();
|
||||||
});
|
});
|
||||||
|
@ -284,12 +466,20 @@ mod test {
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
sqlx::query("update api_keys set uses=id")
|
sqlx::query("DROP TABLE IF EXISTS api_keys")
|
||||||
.execute(&pool)
|
.execute(&pool)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
PgKeyPoolStorage::new(pool, 50)
|
let storage = PgKeyPoolStorage::new(pool.clone(), 1000);
|
||||||
|
storage.initialise().await.unwrap();
|
||||||
|
|
||||||
|
storage
|
||||||
|
.store_key(std::env::var("APIKEY").unwrap(), vec![Domain::All])
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
storage
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
@ -301,24 +491,179 @@ mod test {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
async fn test_store_duplicate() {
|
||||||
|
let storage = setup().await;
|
||||||
|
match storage
|
||||||
|
.store_key(std::env::var("APIKEY").unwrap(), vec![])
|
||||||
|
.await
|
||||||
|
.unwrap_err()
|
||||||
|
{
|
||||||
|
PgStorageError::DuplicateKey(key) => {
|
||||||
|
assert_eq!(key, std::env::var("APIKEY").unwrap())
|
||||||
|
}
|
||||||
|
why => panic!("Expected duplicate key error but found '{why}'"),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
async fn test_add_domain() {
|
||||||
|
let storage = setup().await;
|
||||||
|
let key = storage
|
||||||
|
.add_domain_to_key(std::env::var("APIKEY").unwrap(), Domain::User { id: 12345 })
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
assert!(key.domains.0.contains(&Domain::User { id: 12345 }));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
async fn test_add_duplicate_domain() {
|
||||||
|
let storage = setup().await;
|
||||||
|
match storage
|
||||||
|
.add_domain_to_key(std::env::var("APIKEY").unwrap(), Domain::All)
|
||||||
|
.await
|
||||||
|
.unwrap_err()
|
||||||
|
{
|
||||||
|
PgStorageError::DuplicateDomain(d) => assert_eq!(d, Domain::All),
|
||||||
|
why => panic!("Expected duplicate domain error but found '{why}'"),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
async fn test_remove_domain() {
|
||||||
|
let storage = setup().await;
|
||||||
|
let key = storage
|
||||||
|
.remove_domain_from_key(std::env::var("APIKEY").unwrap(), Domain::All)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
assert!(key.domains.0.is_empty());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
async fn test_store_key() {
|
||||||
|
let storage = setup().await;
|
||||||
|
let key = storage
|
||||||
|
.store_key("ABCDABCDABCDABCD".to_owned(), vec![])
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(key.value(), "ABCDABCDABCDABCD");
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
async fn acquire_one() {
|
async fn acquire_one() {
|
||||||
let storage = setup().await;
|
let storage = setup().await;
|
||||||
|
|
||||||
if let Err(e) = storage.acquire_key(KeyDomain::Public).await {
|
if let Err(e) = storage.acquire_key(Domain::All).await {
|
||||||
panic!("Acquiring key failed: {:?}", e);
|
panic!("Acquiring key failed: {:?}", e);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
async fn test_flag_key_one() {
|
||||||
|
let storage = setup().await;
|
||||||
|
let key = storage
|
||||||
|
.read_key(std::env::var("APIKEY").unwrap())
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
assert!(storage.flag_key(key, 2).await.unwrap());
|
||||||
|
|
||||||
|
match storage.acquire_key(Domain::All).await.unwrap_err() {
|
||||||
|
PgStorageError::Unavailable(d) => assert_eq!(d, Domain::All),
|
||||||
|
why => panic!("Expected domain unavailable error but found '{why}'"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
async fn test_flag_key_many() {
|
||||||
|
let storage = setup().await;
|
||||||
|
let key = storage
|
||||||
|
.read_key(std::env::var("APIKEY").unwrap())
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
assert!(storage.flag_key(key, 2).await.unwrap());
|
||||||
|
|
||||||
|
match storage.acquire_many_keys(Domain::All, 5).await.unwrap_err() {
|
||||||
|
PgStorageError::Unavailable(d) => assert_eq!(d, Domain::All),
|
||||||
|
why => panic!("Expected domain unavailable error but found '{why}'"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
async fn acquire_many() {
|
||||||
|
let storage = setup().await;
|
||||||
|
|
||||||
|
match storage.acquire_many_keys(Domain::All, 30).await {
|
||||||
|
Err(e) => panic!("Acquiring key failed: {:?}", e),
|
||||||
|
Ok(keys) => assert_eq!(keys.len(), 30),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
async fn test_concurrent() {
|
async fn test_concurrent() {
|
||||||
let storage = Arc::new(setup().await);
|
let storage = Arc::new(setup().await);
|
||||||
|
|
||||||
let keys = storage
|
for _ in 0..10 {
|
||||||
.acquire_many_keys(KeyDomain::Public, 30)
|
let mut set = tokio::task::JoinSet::new();
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
assert_eq!(keys.len(), 30);
|
for _ in 0..100 {
|
||||||
|
let storage = storage.clone();
|
||||||
|
set.spawn(async move {
|
||||||
|
storage.acquire_key(Domain::All).await.unwrap();
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
for _ in 0..100 {
|
||||||
|
set.join_next().await.unwrap().unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
let uses: i16 = sqlx::query("select uses from api_keys")
|
||||||
|
.fetch_one(&storage.pool)
|
||||||
|
.await
|
||||||
|
.unwrap()
|
||||||
|
.get("uses");
|
||||||
|
|
||||||
|
assert_eq!(uses, 100);
|
||||||
|
|
||||||
|
sqlx::query("update api_keys set uses=0")
|
||||||
|
.execute(&storage.pool)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
async fn test_concurrent_many() {
|
||||||
|
let storage = Arc::new(setup().await);
|
||||||
|
for _ in 0..10 {
|
||||||
|
let mut set = tokio::task::JoinSet::new();
|
||||||
|
|
||||||
|
for _ in 0..100 {
|
||||||
|
let storage = storage.clone();
|
||||||
|
set.spawn(async move {
|
||||||
|
storage.acquire_many_keys(Domain::All, 5).await.unwrap();
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
for _ in 0..100 {
|
||||||
|
set.join_next().await.unwrap().unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
let uses: i16 = sqlx::query("select uses from api_keys")
|
||||||
|
.fetch_one(&storage.pool)
|
||||||
|
.await
|
||||||
|
.unwrap()
|
||||||
|
.get("uses");
|
||||||
|
|
||||||
|
assert_eq!(uses, 500);
|
||||||
|
|
||||||
|
sqlx::query("update api_keys set uses=0")
|
||||||
|
.execute(&storage.pool)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -7,7 +7,7 @@ use torn_api::{
|
||||||
ApiCategoryResponse, ApiRequest, ApiResponse, ResponseError,
|
ApiCategoryResponse, ApiRequest, ApiResponse, ResponseError,
|
||||||
};
|
};
|
||||||
|
|
||||||
use crate::{ApiKey, KeyDomain, KeyPoolError, KeyPoolExecutor, KeyPoolStorage};
|
use crate::{ApiKey, KeyPoolError, KeyPoolExecutor, KeyPoolStorage};
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl<'client, C, S> RequestExecutor<C> for KeyPoolExecutor<'client, C, S>
|
impl<'client, C, S> RequestExecutor<C> for KeyPoolExecutor<'client, C, S>
|
||||||
|
@ -20,16 +20,17 @@ where
|
||||||
async fn execute<A>(
|
async fn execute<A>(
|
||||||
&self,
|
&self,
|
||||||
client: &C,
|
client: &C,
|
||||||
request: ApiRequest<A>,
|
mut request: ApiRequest<A>,
|
||||||
id: Option<i64>,
|
id: Option<i64>,
|
||||||
) -> Result<A, Self::Error>
|
) -> Result<A, Self::Error>
|
||||||
where
|
where
|
||||||
A: ApiCategoryResponse,
|
A: ApiCategoryResponse,
|
||||||
{
|
{
|
||||||
|
request.comment = self.comment.map(ToOwned::to_owned);
|
||||||
loop {
|
loop {
|
||||||
let key = self
|
let key = self
|
||||||
.storage
|
.storage
|
||||||
.acquire_key(self.domain)
|
.acquire_key(self.domain.clone())
|
||||||
.await
|
.await
|
||||||
.map_err(|e| KeyPoolError::Storage(Arc::new(e)))?;
|
.map_err(|e| KeyPoolError::Storage(Arc::new(e)))?;
|
||||||
let url = request.url(key.value(), id);
|
let url = request.url(key.value(), id);
|
||||||
|
@ -56,7 +57,7 @@ where
|
||||||
async fn execute_many<A>(
|
async fn execute_many<A>(
|
||||||
&self,
|
&self,
|
||||||
client: &C,
|
client: &C,
|
||||||
request: ApiRequest<A>,
|
mut request: ApiRequest<A>,
|
||||||
ids: Vec<i64>,
|
ids: Vec<i64>,
|
||||||
) -> HashMap<i64, Result<A, Self::Error>>
|
) -> HashMap<i64, Result<A, Self::Error>>
|
||||||
where
|
where
|
||||||
|
@ -64,7 +65,7 @@ where
|
||||||
{
|
{
|
||||||
let keys = match self
|
let keys = match self
|
||||||
.storage
|
.storage
|
||||||
.acquire_many_keys(self.domain, ids.len() as i64)
|
.acquire_many_keys(self.domain.clone(), ids.len() as i64)
|
||||||
.await
|
.await
|
||||||
{
|
{
|
||||||
Ok(keys) => keys,
|
Ok(keys) => keys,
|
||||||
|
@ -77,6 +78,7 @@ where
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
request.comment = self.comment.map(ToOwned::to_owned);
|
||||||
let request_ref = &request;
|
let request_ref = &request;
|
||||||
|
|
||||||
futures::future::join_all(std::iter::zip(ids, keys).map(|(id, mut key)| async move {
|
futures::future::join_all(std::iter::zip(ids, keys).map(|(id, mut key)| async move {
|
||||||
|
@ -107,7 +109,7 @@ where
|
||||||
Ok(res) => return (id, Ok(A::from_response(res))),
|
Ok(res) => return (id, Ok(A::from_response(res))),
|
||||||
};
|
};
|
||||||
|
|
||||||
key = match self.storage.acquire_key(self.domain).await {
|
key = match self.storage.acquire_key(self.domain.clone()).await {
|
||||||
Ok(k) => k,
|
Ok(k) => k,
|
||||||
Err(why) => return (id, Err(Self::Error::Storage(Arc::new(why)))),
|
Err(why) => return (id, Err(Self::Error::Storage(Arc::new(why)))),
|
||||||
};
|
};
|
||||||
|
@ -127,6 +129,7 @@ where
|
||||||
{
|
{
|
||||||
client: C,
|
client: C,
|
||||||
storage: S,
|
storage: S,
|
||||||
|
comment: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<C, S> KeyPool<C, S>
|
impl<C, S> KeyPool<C, S>
|
||||||
|
@ -134,12 +137,19 @@ where
|
||||||
C: ApiClient,
|
C: ApiClient,
|
||||||
S: KeyPoolStorage + Send + Sync + 'static,
|
S: KeyPoolStorage + Send + Sync + 'static,
|
||||||
{
|
{
|
||||||
pub fn new(client: C, storage: S) -> Self {
|
pub fn new(client: C, storage: S, comment: Option<String>) -> Self {
|
||||||
Self { client, storage }
|
Self {
|
||||||
|
client,
|
||||||
|
storage,
|
||||||
|
comment,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn torn_api(&self, domain: KeyDomain) -> ApiProvider<C, KeyPoolExecutor<C, S>> {
|
pub fn torn_api(&self, domain: S::Domain) -> ApiProvider<C, KeyPoolExecutor<C, S>> {
|
||||||
ApiProvider::new(&self.client, KeyPoolExecutor::new(&self.storage, domain))
|
ApiProvider::new(
|
||||||
|
&self.client,
|
||||||
|
KeyPoolExecutor::new(&self.storage, domain, self.comment.as_deref()),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -147,15 +157,48 @@ pub trait WithStorage {
|
||||||
fn with_storage<'a, S>(
|
fn with_storage<'a, S>(
|
||||||
&'a self,
|
&'a self,
|
||||||
storage: &'a S,
|
storage: &'a S,
|
||||||
domain: KeyDomain,
|
domain: S::Domain,
|
||||||
) -> ApiProvider<Self, KeyPoolExecutor<Self, S>>
|
) -> ApiProvider<Self, KeyPoolExecutor<Self, S>>
|
||||||
where
|
where
|
||||||
Self: ApiClient + Sized,
|
Self: ApiClient + Sized,
|
||||||
S: KeyPoolStorage + Send + Sync + 'static,
|
S: KeyPoolStorage + Send + Sync + 'static,
|
||||||
{
|
{
|
||||||
ApiProvider::new(self, KeyPoolExecutor::new(storage, domain))
|
ApiProvider::new(self, KeyPoolExecutor::new(storage, domain, None))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(feature = "reqwest")]
|
#[cfg(feature = "reqwest")]
|
||||||
impl WithStorage for reqwest::Client {}
|
impl WithStorage for reqwest::Client {}
|
||||||
|
|
||||||
|
#[cfg(all(test, feature = "postgres", feature = "reqwest"))]
|
||||||
|
mod test {
|
||||||
|
use tokio::test;
|
||||||
|
|
||||||
|
use super::*;
|
||||||
|
use crate::postgres::test::{setup, Domain};
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
async fn test_pool_request() {
|
||||||
|
let storage = setup().await;
|
||||||
|
let pool = KeyPool::new(
|
||||||
|
reqwest::Client::default(),
|
||||||
|
storage,
|
||||||
|
Some("api.rs".to_owned()),
|
||||||
|
);
|
||||||
|
|
||||||
|
let response = pool.torn_api(Domain::All).user(|b| b).await.unwrap();
|
||||||
|
_ = response.profile().unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
async fn test_with_storage_request() {
|
||||||
|
let storage = setup().await;
|
||||||
|
|
||||||
|
let response = reqwest::Client::new()
|
||||||
|
.with_storage(&storage, Domain::All)
|
||||||
|
.user(|b| b)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
_ = response.profile().unwrap();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in a new issue