refactored and expanded postgres keypool

This commit is contained in:
TotallyNot 2023-01-22 19:29:11 +01:00
parent 0fd74e7006
commit 9ae436c694
5 changed files with 560 additions and 142 deletions

View file

@ -1,28 +1,29 @@
[package] [package]
name = "torn-key-pool" name = "torn-key-pool"
version = "0.4.2" version = "0.5.0"
edition = "2021" edition = "2021"
license = "MIT" license = "MIT"
repository = "https://github.com/TotallyNot/torn-api.rs.git" repository = "https://github.com/TotallyNot/torn-api.rs.git"
homepage = "https://github.com/TotallyNot/torn-api.rs.git" homepage = "https://github.com/TotallyNot/torn-api.rs.git"
description = "A generalizes API key pool for torn-api" description = "A generalised API key pool for torn-api"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[features] [features]
default = [ "postgres", "tokio-runtime" ] default = [ "postgres", "tokio-runtime" ]
postgres = [ "dep:sqlx", "dep:chrono", "dep:indoc" ] postgres = [ "dep:sqlx", "dep:chrono", "dep:indoc", "dep:serde" ]
reqwest = [ "dep:reqwest", "torn-api/reqwest" ] reqwest = [ "dep:reqwest", "torn-api/reqwest" ]
awc = [ "dep:awc", "torn-api/awc" ] awc = [ "dep:awc", "torn-api/awc" ]
tokio-runtime = [ "dep:tokio", "dep:rand" ] tokio-runtime = [ "dep:tokio", "dep:rand" ]
actix-runtime = [ "dep:actix-rt", "dep:rand" ] actix-runtime = [ "dep:actix-rt", "dep:rand" ]
[dependencies] [dependencies]
torn-api = { path = "../torn-api", default-features = false, version = "0.5" } torn-api = { path = "../torn-api", default-features = false, version = "0.5.5" }
async-trait = "0.1" async-trait = "0.1"
thiserror = "1" thiserror = "1"
sqlx = { version = "0.6", features = [ "postgres", "chrono" ], optional = true } sqlx = { version = "0.6", features = [ "postgres", "chrono", "json" ], optional = true }
serde = { version = "1.0", optional = true }
chrono = { version = "0.4", optional = true } chrono = { version = "0.4", optional = true }
indoc = { version = "1", optional = true } indoc = { version = "1", optional = true }
tokio = { version = "1", optional = true, default-features = false, features = ["time"] } tokio = { version = "1", optional = true, default-features = false, features = ["time"] }
@ -37,7 +38,7 @@ awc = { version = "3", default-features = false, optional = true }
torn-api = { path = "../torn-api", features = [ "reqwest" ] } torn-api = { path = "../torn-api", features = [ "reqwest" ] }
sqlx = { version = "0.6", features = [ "runtime-tokio-rustls" ] } sqlx = { version = "0.6", features = [ "runtime-tokio-rustls" ] }
dotenv = "0.15.0" dotenv = "0.15.0"
tokio = { version = "1.20.1", features = ["test-util", "rt", "macros"] } tokio = { version = "1.24.2", features = ["test-util", "rt", "macros"] }
tokio-test = "0.4.2" tokio-test = "0.4.2"
reqwest = { version = "0.11", default-features = true } reqwest = { version = "0.11", default-features = true }
awc = { version = "3", features = [ "rustls" ] } awc = { version = "3", features = [ "rustls" ] }

View file

@ -29,31 +29,57 @@ where
Response(ResponseError), Response(ResponseError),
} }
#[derive(Debug, Clone, Copy)]
pub enum KeyDomain {
Public,
User(i32),
Faction(i32),
}
pub trait ApiKey: Sync + Send { pub trait ApiKey: Sync + Send {
fn value(&self) -> &str; fn value(&self) -> &str;
} }
pub trait KeyDomain: Clone + std::fmt::Debug + Send + Sync {}
impl<T> KeyDomain for T where T: Clone + std::fmt::Debug + Send + Sync {}
#[async_trait] #[async_trait]
pub trait KeyPoolStorage { pub trait KeyPoolStorage {
type Key: ApiKey; type Key: ApiKey;
type Domain: KeyDomain;
type Error: std::error::Error + Sync + Send; type Error: std::error::Error + Sync + Send;
async fn acquire_key(&self, domain: KeyDomain) -> Result<Self::Key, Self::Error>; async fn acquire_key(&self, domain: Self::Domain) -> Result<Self::Key, Self::Error>;
async fn acquire_many_keys( async fn acquire_many_keys(
&self, &self,
domain: KeyDomain, domain: Self::Domain,
number: i64, number: i64,
) -> Result<Vec<Self::Key>, Self::Error>; ) -> Result<Vec<Self::Key>, Self::Error>;
async fn flag_key(&self, key: Self::Key, code: u8) -> Result<bool, Self::Error>; async fn flag_key(&self, key: Self::Key, code: u8) -> Result<bool, Self::Error>;
async fn store_key(
&self,
key: String,
domains: Vec<Self::Domain>,
) -> Result<Self::Key, Self::Error>;
async fn read_key(&self, key: String) -> Result<Self::Key, Self::Error>;
async fn remove_key(&self, key: String) -> Result<Self::Key, Self::Error>;
async fn add_domain_to_key(
&self,
key: String,
domain: Self::Domain,
) -> Result<Self::Key, Self::Error>;
async fn remove_domain_from_key(
&self,
key: String,
domain: Self::Domain,
) -> Result<Self::Key, Self::Error>;
async fn set_domains_for_key(
&self,
key: String,
domains: Vec<Self::Domain>,
) -> Result<Self::Key, Self::Error>;
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
@ -62,7 +88,8 @@ where
S: KeyPoolStorage, S: KeyPoolStorage,
{ {
storage: &'a S, storage: &'a S,
domain: KeyDomain, comment: Option<&'a str>,
domain: S::Domain,
_marker: std::marker::PhantomData<C>, _marker: std::marker::PhantomData<C>,
} }
@ -70,52 +97,15 @@ impl<'a, C, S> KeyPoolExecutor<'a, C, S>
where where
S: KeyPoolStorage, S: KeyPoolStorage,
{ {
pub fn new(storage: &'a S, domain: KeyDomain) -> Self { pub fn new(storage: &'a S, domain: S::Domain, comment: Option<&'a str>) -> Self {
Self { Self {
storage, storage,
domain, domain,
comment,
_marker: std::marker::PhantomData, _marker: std::marker::PhantomData,
} }
} }
} }
#[cfg(all(test, feature = "postgres"))] #[cfg(all(test, feature = "postgres"))]
mod test { mod test {}
use std::sync::Once;
use tokio::test;
use super::*;
static INIT: Once = Once::new();
pub(crate) async fn setup() -> postgres::PgKeyPoolStorage {
INIT.call_once(|| {
dotenv::dotenv().ok();
});
let pool = sqlx::PgPool::connect(&std::env::var("DATABASE_URL").unwrap())
.await
.unwrap();
sqlx::query("update api_keys set uses=0")
.execute(&pool)
.await
.unwrap();
postgres::PgKeyPoolStorage::new(pool, 50)
}
#[test]
async fn key_pool_bulk() {
let storage = setup().await;
if let Err(e) = storage.initialise().await {
panic!("Initialising key storage failed: {:?}", e);
}
let pool = send::KeyPool::new(reqwest::Client::default(), storage);
pool.torn_api(KeyDomain::Public).users([1], |b| b).await;
}
}

View file

@ -7,7 +7,7 @@ use torn_api::{
ApiCategoryResponse, ApiRequest, ApiResponse, ResponseError, ApiCategoryResponse, ApiRequest, ApiResponse, ResponseError,
}; };
use crate::{ApiKey, KeyDomain, KeyPoolError, KeyPoolExecutor, KeyPoolStorage}; use crate::{ApiKey, KeyPoolError, KeyPoolExecutor, KeyPoolStorage};
#[async_trait(?Send)] #[async_trait(?Send)]
impl<'client, C, S> RequestExecutor<C> for KeyPoolExecutor<'client, C, S> impl<'client, C, S> RequestExecutor<C> for KeyPoolExecutor<'client, C, S>
@ -20,16 +20,17 @@ where
async fn execute<A>( async fn execute<A>(
&self, &self,
client: &C, client: &C,
request: ApiRequest<A>, mut request: ApiRequest<A>,
id: Option<i64>, id: Option<i64>,
) -> Result<A, Self::Error> ) -> Result<A, Self::Error>
where where
A: ApiCategoryResponse, A: ApiCategoryResponse,
{ {
request.comment = self.comment.map(ToOwned::to_owned);
loop { loop {
let key = self let key = self
.storage .storage
.acquire_key(self.domain) .acquire_key(self.domain.clone())
.await .await
.map_err(|e| KeyPoolError::Storage(Arc::new(e)))?; .map_err(|e| KeyPoolError::Storage(Arc::new(e)))?;
let url = request.url(key.value(), id); let url = request.url(key.value(), id);
@ -56,7 +57,7 @@ where
async fn execute_many<A>( async fn execute_many<A>(
&self, &self,
client: &C, client: &C,
request: ApiRequest<A>, mut request: ApiRequest<A>,
ids: Vec<i64>, ids: Vec<i64>,
) -> HashMap<i64, Result<A, Self::Error>> ) -> HashMap<i64, Result<A, Self::Error>>
where where
@ -64,7 +65,7 @@ where
{ {
let keys = match self let keys = match self
.storage .storage
.acquire_many_keys(self.domain, ids.len() as i64) .acquire_many_keys(self.domain.clone(), ids.len() as i64)
.await .await
{ {
Ok(keys) => keys, Ok(keys) => keys,
@ -77,6 +78,7 @@ where
} }
}; };
request.comment = self.comment.map(ToOwned::to_owned);
let request_ref = &request; let request_ref = &request;
futures::future::join_all(std::iter::zip(ids, keys).map(|(id, mut key)| async move { futures::future::join_all(std::iter::zip(ids, keys).map(|(id, mut key)| async move {
@ -107,7 +109,7 @@ where
Ok(res) => return (id, Ok(A::from_response(res))), Ok(res) => return (id, Ok(A::from_response(res))),
}; };
key = match self.storage.acquire_key(self.domain).await { key = match self.storage.acquire_key(self.domain.clone()).await {
Ok(k) => k, Ok(k) => k,
Err(why) => return (id, Err(Self::Error::Storage(Arc::new(why)))), Err(why) => return (id, Err(Self::Error::Storage(Arc::new(why)))),
}; };
@ -127,6 +129,7 @@ where
{ {
client: C, client: C,
storage: S, storage: S,
comment: Option<String>,
} }
impl<C, S> KeyPool<C, S> impl<C, S> KeyPool<C, S>
@ -134,12 +137,19 @@ where
C: ApiClient, C: ApiClient,
S: KeyPoolStorage + 'static, S: KeyPoolStorage + 'static,
{ {
pub fn new(client: C, storage: S) -> Self { pub fn new(client: C, storage: S, comment: Option<String>) -> Self {
Self { client, storage } Self {
client,
storage,
comment,
}
} }
pub fn torn_api(&self, domain: KeyDomain) -> ApiProvider<C, KeyPoolExecutor<C, S>> { pub fn torn_api(&self, domain: S::Domain) -> ApiProvider<C, KeyPoolExecutor<C, S>> {
ApiProvider::new(&self.client, KeyPoolExecutor::new(&self.storage, domain)) ApiProvider::new(
&self.client,
KeyPoolExecutor::new(&self.storage, domain, self.comment.as_deref()),
)
} }
} }
@ -147,15 +157,44 @@ pub trait WithStorage {
fn with_storage<'a, S>( fn with_storage<'a, S>(
&'a self, &'a self,
storage: &'a S, storage: &'a S,
domain: KeyDomain, domain: S::Domain,
) -> ApiProvider<Self, KeyPoolExecutor<Self, S>> ) -> ApiProvider<Self, KeyPoolExecutor<Self, S>>
where where
Self: ApiClient + Sized, Self: ApiClient + Sized,
S: KeyPoolStorage + 'static, S: KeyPoolStorage + 'static,
{ {
ApiProvider::new(self, KeyPoolExecutor::new(storage, domain)) ApiProvider::new(self, KeyPoolExecutor::new(storage, domain, None))
} }
} }
#[cfg(feature = "awc")] #[cfg(feature = "awc")]
impl WithStorage for awc::Client {} impl WithStorage for awc::Client {}
#[cfg(all(test, feature = "postgres", feature = "awc"))]
mod test {
use tokio::test;
use super::*;
use crate::postgres::test::{setup, Domain};
#[test]
async fn test_pool_request() {
let storage = setup().await;
let pool = KeyPool::new(awc::Client::default(), storage);
let response = pool.torn_api(Domain::All).user(|b| b).await.unwrap();
_ = response.profile().unwrap();
}
#[test]
async fn test_with_storage_request() {
let storage = setup().await;
let response = awc::Client::new()
.with_storage(&storage, Domain::All)
.user(|b| b)
.await
.unwrap();
_ = response.profile().unwrap();
}
}

View file

@ -5,51 +5,98 @@ use thiserror::Error;
use crate::{ApiKey, KeyDomain, KeyPoolStorage}; use crate::{ApiKey, KeyDomain, KeyPoolStorage};
pub trait PgKeyDomain:
KeyDomain + serde::Serialize + serde::de::DeserializeOwned + Eq + Unpin
{
}
impl<T> PgKeyDomain for T where
T: KeyDomain + serde::Serialize + serde::de::DeserializeOwned + Eq + Unpin
{
}
#[derive(Debug, Error)] #[derive(Debug, Error)]
pub enum PgStorageError { pub enum PgStorageError<D>
where
D: std::fmt::Debug,
{
#[error(transparent)] #[error(transparent)]
Pg(#[from] sqlx::Error), Pg(#[from] sqlx::Error),
#[error("No key avalaible for domain {0:?}")] #[error("No key avalaible for domain {0:?}")]
Unavailable(KeyDomain), Unavailable(D),
#[error("Duplicate key '{0}'")]
DuplicateKey(String),
#[error("Duplicate domain '{0:?}'")]
DuplicateDomain(D),
#[error("Key not found: '{0}'")]
KeyNotFound(String),
} }
#[derive(Debug, Clone, FromRow)] #[derive(Debug, Clone, FromRow)]
pub struct PgKey { pub struct PgKey<D>
where
D: PgKeyDomain,
{
pub id: i32, pub id: i32,
pub key: String, pub key: String,
pub uses: i16, pub uses: i16,
pub domains: sqlx::types::Json<Vec<D>>,
} }
#[derive(Debug, Clone, FromRow)] #[derive(Debug, Clone, FromRow)]
pub struct PgKeyPoolStorage { pub struct PgKeyPoolStorage<D>
where
D: serde::Serialize + serde::de::DeserializeOwned + Send + Sync + 'static,
{
pool: PgPool, pool: PgPool,
limit: i16, limit: i16,
_phantom: std::marker::PhantomData<D>,
} }
impl ApiKey for PgKey { impl<D> ApiKey for PgKey<D>
where
D: PgKeyDomain,
{
fn value(&self) -> &str { fn value(&self) -> &str {
&self.key &self.key
} }
} }
impl PgKeyPoolStorage { impl<D> PgKeyPoolStorage<D>
where
D: PgKeyDomain,
{
pub fn new(pool: PgPool, limit: i16) -> Self { pub fn new(pool: PgPool, limit: i16) -> Self {
Self { pool, limit } Self {
pool,
limit,
_phantom: Default::default(),
}
} }
pub async fn initialise(&self) -> Result<(), PgStorageError> { pub async fn initialise(&self) -> Result<(), PgStorageError<D>> {
sqlx::query(indoc! {r#" sqlx::query(indoc! {r#"
CREATE TABLE IF NOT EXISTS api_keys ( CREATE TABLE IF NOT EXISTS api_keys (
id serial primary key, id serial primary key,
user_id int4 not null,
faction_id int4,
key char(16) not null, key char(16) not null,
uses int2 not null default 0, uses int2 not null default 0,
"user" bool not null, domains jsonb not null default '{}'::jsonb,
faction bool not null, last_used timestamptz not null default now(),
last_used timestamptz not null default now() flag int2,
)"#}) cooldown timestamptz,
constraint "uq:api_keys.key" UNIQUE(key)
)"#
})
.execute(&self.pool)
.await?;
sqlx::query(indoc! {r#"
CREATE INDEX IF NOT EXISTS "idx:api_keys.domains" ON api_keys USING GIN(domains jsonb_path_ops)
"#})
.execute(&self.pool) .execute(&self.pool)
.await?; .await?;
@ -72,63 +119,68 @@ async fn random_sleep() {
} }
#[async_trait] #[async_trait]
impl KeyPoolStorage for PgKeyPoolStorage { impl<D> KeyPoolStorage for PgKeyPoolStorage<D>
type Key = PgKey; where
D: PgKeyDomain,
{
type Key = PgKey<D>;
type Domain = D;
type Error = PgStorageError; type Error = PgStorageError<D>;
async fn acquire_key(&self, domain: KeyDomain) -> Result<Self::Key, Self::Error> {
let predicate = match domain {
KeyDomain::Public => "".to_owned(),
KeyDomain::User(id) => format!(" and user_id={} and user", id),
KeyDomain::Faction(id) => format!(" and faction_id={} and faction", id),
};
async fn acquire_key(&self, domain: D) -> Result<Self::Key, Self::Error> {
loop { loop {
let attempt = async { let attempt = async {
let mut tx = self.pool.begin().await?; let mut tx = self.pool.begin().await?;
sqlx::query("set transaction isolation level serializable") sqlx::query("set transaction isolation level repeatable read")
.execute(&mut tx) .execute(&mut tx)
.await?; .await?;
let key: Option<PgKey> = sqlx::query_as(&indoc::formatdoc!( let key = sqlx::query_as(&indoc::formatdoc!(
r#" r#"
with key as ( with key as (
select select
id, id,
0::int2 as uses 0::int2 as uses
from api_keys where last_used < date_trunc('minute', now()){predicate} from api_keys where last_used < date_trunc('minute', now()) and (cooldown is null or now() >= cooldown) and domains @> $1
union ( union (
select id, uses from api_keys where last_used >= date_trunc('minute', now()){predicate} order by uses asc select id, uses from api_keys
where last_used >= date_trunc('minute', now()) and (cooldown is null or now() >= cooldown) and domains @> $1
order by uses asc
) )
limit 1 limit 1
) )
update api_keys set update api_keys set
uses = key.uses + 1, uses = key.uses + 1,
cooldown = null,
flag = null,
last_used = now() last_used = now()
from key where from key where
api_keys.id=key.id and key.uses < $1 api_keys.id=key.id and key.uses < $2
returning returning
api_keys.id, api_keys.id,
api_keys.key, api_keys.key,
api_keys.uses api_keys.uses,
api_keys.domains
"#, "#,
)) ))
.bind(sqlx::types::Json(vec![&domain]))
.bind(self.limit) .bind(self.limit)
.fetch_optional(&mut tx) .fetch_optional(&mut tx)
.await?; .await?;
tx.commit().await?; tx.commit().await?;
Result::<Result<Self::Key, Self::Error>, sqlx::Error>::Ok( Result::<Option<Self::Key>, sqlx::Error>::Ok(
key.ok_or(PgStorageError::Unavailable(domain)), key
) )
} }
.await; .await;
match attempt { match attempt {
Ok(result) => return result, Ok(Some(result)) => return Ok(result),
Ok(None) => return Err(PgStorageError::Unavailable(domain)),
Err(error) => { Err(error) => {
if let Some(db_error) = error.as_database_error() { if let Some(db_error) = error.as_database_error() {
let pg_error: &sqlx::postgres::PgDatabaseError = db_error.downcast_ref(); let pg_error: &sqlx::postgres::PgDatabaseError = db_error.downcast_ref();
@ -147,45 +199,42 @@ impl KeyPoolStorage for PgKeyPoolStorage {
async fn acquire_many_keys( async fn acquire_many_keys(
&self, &self,
domain: KeyDomain, domain: D,
number: i64, number: i64,
) -> Result<Vec<Self::Key>, Self::Error> { ) -> Result<Vec<Self::Key>, Self::Error> {
let predicate = match domain {
KeyDomain::Public => "".to_owned(),
KeyDomain::User(id) => format!(" and user_id={} and user", id),
KeyDomain::Faction(id) => format!(" and faction_id={} and faction", id),
};
loop { loop {
let attempt = async { let attempt = async {
let mut tx = self.pool.begin().await?; let mut tx = self.pool.begin().await?;
sqlx::query("set transaction isolation level serializable") sqlx::query("set transaction isolation level repeatable read")
.execute(&mut tx) .execute(&mut tx)
.await?; .await?;
let mut keys: Vec<PgKey> = sqlx::query_as(&indoc::formatdoc!( let mut keys: Vec<Self::Key> = sqlx::query_as(&indoc::formatdoc!(
r#"select r#"select
id, id,
key, key,
0::int2 as uses 0::int2 as uses,
from api_keys where last_used < date_trunc('minute', now()){predicate} domains
from api_keys where last_used < date_trunc('minute', now()) and (cooldown is null or now() >= cooldown) and domains @> $1
union union
select select
id, id,
key, key,
uses uses,
from api_keys where last_used >= date_trunc('minute', now()){predicate} domains
order by uses limit $1 from api_keys where last_used >= date_trunc('minute', now()) and (cooldown is null or now() >= cooldown) and domains @> $1
order by uses limit $2
"#, "#,
)) ))
.bind(sqlx::types::Json(vec![&domain]))
.bind(number) .bind(number)
.fetch_all(&mut tx) .fetch_all(&mut tx)
.await?; .await?;
if keys.is_empty() { if keys.is_empty() {
tx.commit().await?; tx.commit().await?;
return Ok(Err(PgStorageError::Unavailable(domain))); return Ok(None);
} }
keys.sort_unstable_by(|k1, k2| k1.uses.cmp(&k2.uses)); keys.sort_unstable_by(|k1, k2| k1.uses.cmp(&k2.uses));
@ -217,6 +266,8 @@ impl KeyPoolStorage for PgKeyPoolStorage {
sqlx::query(indoc! {r#" sqlx::query(indoc! {r#"
update api_keys set update api_keys set
uses = tmp.uses, uses = tmp.uses,
cooldown = null,
flag = null,
last_used = now() last_used = now()
from (select unnest($1::int4[]) as id, unnest($2::int2[]) as uses) as tmp from (select unnest($1::int4[]) as id, unnest($2::int2[]) as uses) as tmp
where api_keys.id = tmp.id where api_keys.id = tmp.id
@ -228,12 +279,13 @@ impl KeyPoolStorage for PgKeyPoolStorage {
tx.commit().await?; tx.commit().await?;
Result::<Result<Vec<Self::Key>, Self::Error>, sqlx::Error>::Ok(Ok(result)) Result::<Option<Vec<Self::Key>>, sqlx::Error>::Ok(Some(result))
} }
.await; .await;
match attempt { match attempt {
Ok(result) => return result, Ok(Some(result)) => return Ok(result),
Ok(None) => return Err(Self::Error::Unavailable(domain)),
Err(error) => { Err(error) => {
if let Some(db_error) = error.as_database_error() { if let Some(db_error) = error.as_database_error() {
let pg_error: &sqlx::postgres::PgDatabaseError = db_error.downcast_ref(); let pg_error: &sqlx::postgres::PgDatabaseError = db_error.downcast_ref();
@ -254,7 +306,41 @@ impl KeyPoolStorage for PgKeyPoolStorage {
// TODO: put keys in cooldown when appropriate // TODO: put keys in cooldown when appropriate
match code { match code {
2 | 10 | 13 => { 2 | 10 | 13 => {
sqlx::query("delete from api_keys where id=$1") // invalid key, owner fedded or owner inactive
sqlx::query(
"update api_keys set cooldown='infinity'::timestamptz, flag=$1 where id=$2",
)
.bind(code as i16)
.bind(key.id)
.execute(&self.pool)
.await?;
Ok(true)
}
5 => {
// too many requests
sqlx::query("update api_keys set cooldown=date_trunc('min', now()) + interval '1 min', flag=5 where id=$1")
.bind(key.id)
.execute(&self.pool)
.await?;
Ok(true)
}
8 => {
// IP block
sqlx::query("update api_keys set cooldown=now() + interval '5 min', flag=8")
.execute(&self.pool)
.await?;
Ok(false)
}
9 => {
// API disabled
sqlx::query("update api_keys set cooldown=now() + interval '1 min', flag=9")
.execute(&self.pool)
.await?;
Ok(false)
}
14 => {
// daily read limit reached
sqlx::query("update api_keys set cooldown=date_trunc('day', now()) + interval '1 day', flag=14 where id=$1")
.bind(key.id) .bind(key.id)
.execute(&self.pool) .execute(&self.pool)
.await?; .await?;
@ -263,19 +349,115 @@ impl KeyPoolStorage for PgKeyPoolStorage {
_ => Ok(false), _ => Ok(false),
} }
} }
async fn store_key(&self, key: String, domains: Vec<D>) -> Result<Self::Key, Self::Error> {
sqlx::query_as("insert into api_keys(key, domains) values ($1, $2) returning *")
.bind(&key)
.bind(sqlx::types::Json(domains))
.fetch_one(&self.pool)
.await
.map_err(|why| {
if let Some(error) = why.as_database_error() {
let pg_error: &sqlx::postgres::PgDatabaseError = error.downcast_ref();
if pg_error.code() == "23505" {
return PgStorageError::DuplicateKey(key);
}
}
PgStorageError::Pg(why)
})
}
async fn read_key(&self, key: String) -> Result<Self::Key, Self::Error> {
sqlx::query_as("select * from api_keys where key=$1")
.bind(&key)
.fetch_optional(&self.pool)
.await?
.ok_or_else(|| PgStorageError::KeyNotFound(key))
}
async fn remove_key(&self, key: String) -> Result<Self::Key, Self::Error> {
sqlx::query_as("delete from api_keys where key=$1 returning *")
.bind(&key)
.fetch_optional(&self.pool)
.await?
.ok_or_else(|| PgStorageError::KeyNotFound(key))
}
async fn add_domain_to_key(&self, key: String, domain: D) -> Result<Self::Key, Self::Error> {
let mut tx = self.pool.begin().await?;
match sqlx::query_as::<sqlx::Postgres, PgKey<D>>(
"update api_keys set domains = domains || jsonb_build_array($1) where key=$2 returning *",
)
.bind(sqlx::types::Json(domain.clone()))
.bind(&key)
.fetch_optional(&mut tx)
.await?
{
None => Err(PgStorageError::KeyNotFound(key)),
Some(key) => {
if key.domains.0.iter().filter(|d| **d == domain).count() > 1 {
tx.rollback().await?;
return Err(PgStorageError::DuplicateDomain(domain));
}
tx.commit().await?;
Ok(key)
}
}
}
async fn remove_domain_from_key(
&self,
key: String,
domain: D,
) -> Result<Self::Key, Self::Error> {
// FIX: potential race condition
let api_key = self.read_key(key.clone()).await?;
let domains = api_key
.domains
.0
.into_iter()
.filter(|d| *d != domain)
.collect();
self.set_domains_for_key(key, domains).await
}
async fn set_domains_for_key(
&self,
key: String,
domains: Vec<D>,
) -> Result<Self::Key, Self::Error> {
sqlx::query_as::<sqlx::Postgres, PgKey<D>>(
"update api_keys set domains = $1 where key=$2 returning *",
)
.bind(sqlx::types::Json(domains))
.bind(&key)
.fetch_optional(&self.pool)
.await?
.ok_or_else(|| PgStorageError::KeyNotFound(key))
}
} }
#[cfg(test)] #[cfg(test)]
mod test { pub(crate) mod test {
use std::sync::{Arc, Once}; use std::sync::{Arc, Once};
use sqlx::Row;
use tokio::test; use tokio::test;
use super::*; use super::*;
static INIT: Once = Once::new(); static INIT: Once = Once::new();
pub(crate) async fn setup() -> PgKeyPoolStorage { #[derive(Debug, PartialEq, Eq, Clone, serde::Serialize, serde::Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub(crate) enum Domain {
All,
User { id: i32 },
Faction { id: i32 },
}
pub(crate) async fn setup() -> PgKeyPoolStorage<Domain> {
INIT.call_once(|| { INIT.call_once(|| {
dotenv::dotenv().ok(); dotenv::dotenv().ok();
}); });
@ -284,12 +466,20 @@ mod test {
.await .await
.unwrap(); .unwrap();
sqlx::query("update api_keys set uses=id") sqlx::query("DROP TABLE IF EXISTS api_keys")
.execute(&pool) .execute(&pool)
.await .await
.unwrap(); .unwrap();
PgKeyPoolStorage::new(pool, 50) let storage = PgKeyPoolStorage::new(pool.clone(), 1000);
storage.initialise().await.unwrap();
storage
.store_key(std::env::var("APIKEY").unwrap(), vec![Domain::All])
.await
.unwrap();
storage
} }
#[test] #[test]
@ -301,24 +491,179 @@ mod test {
} }
} }
#[test]
async fn test_store_duplicate() {
let storage = setup().await;
match storage
.store_key(std::env::var("APIKEY").unwrap(), vec![])
.await
.unwrap_err()
{
PgStorageError::DuplicateKey(key) => {
assert_eq!(key, std::env::var("APIKEY").unwrap())
}
why => panic!("Expected duplicate key error but found '{why}'"),
};
}
#[test]
async fn test_add_domain() {
let storage = setup().await;
let key = storage
.add_domain_to_key(std::env::var("APIKEY").unwrap(), Domain::User { id: 12345 })
.await
.unwrap();
assert!(key.domains.0.contains(&Domain::User { id: 12345 }));
}
#[test]
async fn test_add_duplicate_domain() {
let storage = setup().await;
match storage
.add_domain_to_key(std::env::var("APIKEY").unwrap(), Domain::All)
.await
.unwrap_err()
{
PgStorageError::DuplicateDomain(d) => assert_eq!(d, Domain::All),
why => panic!("Expected duplicate domain error but found '{why}'"),
};
}
#[test]
async fn test_remove_domain() {
let storage = setup().await;
let key = storage
.remove_domain_from_key(std::env::var("APIKEY").unwrap(), Domain::All)
.await
.unwrap();
assert!(key.domains.0.is_empty());
}
#[test]
async fn test_store_key() {
let storage = setup().await;
let key = storage
.store_key("ABCDABCDABCDABCD".to_owned(), vec![])
.await
.unwrap();
assert_eq!(key.value(), "ABCDABCDABCDABCD");
}
#[test] #[test]
async fn acquire_one() { async fn acquire_one() {
let storage = setup().await; let storage = setup().await;
if let Err(e) = storage.acquire_key(KeyDomain::Public).await { if let Err(e) = storage.acquire_key(Domain::All).await {
panic!("Acquiring key failed: {:?}", e); panic!("Acquiring key failed: {:?}", e);
} }
} }
#[test]
async fn test_flag_key_one() {
let storage = setup().await;
let key = storage
.read_key(std::env::var("APIKEY").unwrap())
.await
.unwrap();
assert!(storage.flag_key(key, 2).await.unwrap());
match storage.acquire_key(Domain::All).await.unwrap_err() {
PgStorageError::Unavailable(d) => assert_eq!(d, Domain::All),
why => panic!("Expected domain unavailable error but found '{why}'"),
}
}
#[test]
async fn test_flag_key_many() {
let storage = setup().await;
let key = storage
.read_key(std::env::var("APIKEY").unwrap())
.await
.unwrap();
assert!(storage.flag_key(key, 2).await.unwrap());
match storage.acquire_many_keys(Domain::All, 5).await.unwrap_err() {
PgStorageError::Unavailable(d) => assert_eq!(d, Domain::All),
why => panic!("Expected domain unavailable error but found '{why}'"),
}
}
#[test]
async fn acquire_many() {
let storage = setup().await;
match storage.acquire_many_keys(Domain::All, 30).await {
Err(e) => panic!("Acquiring key failed: {:?}", e),
Ok(keys) => assert_eq!(keys.len(), 30),
}
}
#[test] #[test]
async fn test_concurrent() { async fn test_concurrent() {
let storage = Arc::new(setup().await); let storage = Arc::new(setup().await);
let keys = storage for _ in 0..10 {
.acquire_many_keys(KeyDomain::Public, 30) let mut set = tokio::task::JoinSet::new();
.await
.unwrap();
assert_eq!(keys.len(), 30); for _ in 0..100 {
let storage = storage.clone();
set.spawn(async move {
storage.acquire_key(Domain::All).await.unwrap();
});
}
for _ in 0..100 {
set.join_next().await.unwrap().unwrap();
}
let uses: i16 = sqlx::query("select uses from api_keys")
.fetch_one(&storage.pool)
.await
.unwrap()
.get("uses");
assert_eq!(uses, 100);
sqlx::query("update api_keys set uses=0")
.execute(&storage.pool)
.await
.unwrap();
}
}
#[test]
async fn test_concurrent_many() {
let storage = Arc::new(setup().await);
for _ in 0..10 {
let mut set = tokio::task::JoinSet::new();
for _ in 0..100 {
let storage = storage.clone();
set.spawn(async move {
storage.acquire_many_keys(Domain::All, 5).await.unwrap();
});
}
for _ in 0..100 {
set.join_next().await.unwrap().unwrap();
}
let uses: i16 = sqlx::query("select uses from api_keys")
.fetch_one(&storage.pool)
.await
.unwrap()
.get("uses");
assert_eq!(uses, 500);
sqlx::query("update api_keys set uses=0")
.execute(&storage.pool)
.await
.unwrap();
}
} }
} }

View file

@ -7,7 +7,7 @@ use torn_api::{
ApiCategoryResponse, ApiRequest, ApiResponse, ResponseError, ApiCategoryResponse, ApiRequest, ApiResponse, ResponseError,
}; };
use crate::{ApiKey, KeyDomain, KeyPoolError, KeyPoolExecutor, KeyPoolStorage}; use crate::{ApiKey, KeyPoolError, KeyPoolExecutor, KeyPoolStorage};
#[async_trait] #[async_trait]
impl<'client, C, S> RequestExecutor<C> for KeyPoolExecutor<'client, C, S> impl<'client, C, S> RequestExecutor<C> for KeyPoolExecutor<'client, C, S>
@ -20,16 +20,17 @@ where
async fn execute<A>( async fn execute<A>(
&self, &self,
client: &C, client: &C,
request: ApiRequest<A>, mut request: ApiRequest<A>,
id: Option<i64>, id: Option<i64>,
) -> Result<A, Self::Error> ) -> Result<A, Self::Error>
where where
A: ApiCategoryResponse, A: ApiCategoryResponse,
{ {
request.comment = self.comment.map(ToOwned::to_owned);
loop { loop {
let key = self let key = self
.storage .storage
.acquire_key(self.domain) .acquire_key(self.domain.clone())
.await .await
.map_err(|e| KeyPoolError::Storage(Arc::new(e)))?; .map_err(|e| KeyPoolError::Storage(Arc::new(e)))?;
let url = request.url(key.value(), id); let url = request.url(key.value(), id);
@ -56,7 +57,7 @@ where
async fn execute_many<A>( async fn execute_many<A>(
&self, &self,
client: &C, client: &C,
request: ApiRequest<A>, mut request: ApiRequest<A>,
ids: Vec<i64>, ids: Vec<i64>,
) -> HashMap<i64, Result<A, Self::Error>> ) -> HashMap<i64, Result<A, Self::Error>>
where where
@ -64,7 +65,7 @@ where
{ {
let keys = match self let keys = match self
.storage .storage
.acquire_many_keys(self.domain, ids.len() as i64) .acquire_many_keys(self.domain.clone(), ids.len() as i64)
.await .await
{ {
Ok(keys) => keys, Ok(keys) => keys,
@ -77,6 +78,7 @@ where
} }
}; };
request.comment = self.comment.map(ToOwned::to_owned);
let request_ref = &request; let request_ref = &request;
futures::future::join_all(std::iter::zip(ids, keys).map(|(id, mut key)| async move { futures::future::join_all(std::iter::zip(ids, keys).map(|(id, mut key)| async move {
@ -107,7 +109,7 @@ where
Ok(res) => return (id, Ok(A::from_response(res))), Ok(res) => return (id, Ok(A::from_response(res))),
}; };
key = match self.storage.acquire_key(self.domain).await { key = match self.storage.acquire_key(self.domain.clone()).await {
Ok(k) => k, Ok(k) => k,
Err(why) => return (id, Err(Self::Error::Storage(Arc::new(why)))), Err(why) => return (id, Err(Self::Error::Storage(Arc::new(why)))),
}; };
@ -127,6 +129,7 @@ where
{ {
client: C, client: C,
storage: S, storage: S,
comment: Option<String>,
} }
impl<C, S> KeyPool<C, S> impl<C, S> KeyPool<C, S>
@ -134,12 +137,19 @@ where
C: ApiClient, C: ApiClient,
S: KeyPoolStorage + Send + Sync + 'static, S: KeyPoolStorage + Send + Sync + 'static,
{ {
pub fn new(client: C, storage: S) -> Self { pub fn new(client: C, storage: S, comment: Option<String>) -> Self {
Self { client, storage } Self {
client,
storage,
comment,
}
} }
pub fn torn_api(&self, domain: KeyDomain) -> ApiProvider<C, KeyPoolExecutor<C, S>> { pub fn torn_api(&self, domain: S::Domain) -> ApiProvider<C, KeyPoolExecutor<C, S>> {
ApiProvider::new(&self.client, KeyPoolExecutor::new(&self.storage, domain)) ApiProvider::new(
&self.client,
KeyPoolExecutor::new(&self.storage, domain, self.comment.as_deref()),
)
} }
} }
@ -147,15 +157,48 @@ pub trait WithStorage {
fn with_storage<'a, S>( fn with_storage<'a, S>(
&'a self, &'a self,
storage: &'a S, storage: &'a S,
domain: KeyDomain, domain: S::Domain,
) -> ApiProvider<Self, KeyPoolExecutor<Self, S>> ) -> ApiProvider<Self, KeyPoolExecutor<Self, S>>
where where
Self: ApiClient + Sized, Self: ApiClient + Sized,
S: KeyPoolStorage + Send + Sync + 'static, S: KeyPoolStorage + Send + Sync + 'static,
{ {
ApiProvider::new(self, KeyPoolExecutor::new(storage, domain)) ApiProvider::new(self, KeyPoolExecutor::new(storage, domain, None))
} }
} }
#[cfg(feature = "reqwest")] #[cfg(feature = "reqwest")]
impl WithStorage for reqwest::Client {} impl WithStorage for reqwest::Client {}
#[cfg(all(test, feature = "postgres", feature = "reqwest"))]
mod test {
use tokio::test;
use super::*;
use crate::postgres::test::{setup, Domain};
#[test]
async fn test_pool_request() {
let storage = setup().await;
let pool = KeyPool::new(
reqwest::Client::default(),
storage,
Some("api.rs".to_owned()),
);
let response = pool.torn_api(Domain::All).user(|b| b).await.unwrap();
_ = response.profile().unwrap();
}
#[test]
async fn test_with_storage_request() {
let storage = setup().await;
let response = reqwest::Client::new()
.with_storage(&storage, Domain::All)
.user(|b| b)
.await
.unwrap();
_ = response.profile().unwrap();
}
}