From 9ae436c694983e8e17106e3dbf039a5c76700971 Mon Sep 17 00:00:00 2001 From: TotallyNot <44345987+TotallyNot@users.noreply.github.com> Date: Sun, 22 Jan 2023 19:29:11 +0100 Subject: [PATCH] refactored and expanded postgres keypool --- torn-key-pool/Cargo.toml | 13 +- torn-key-pool/src/lib.rs | 90 +++---- torn-key-pool/src/local.rs | 63 ++++- torn-key-pool/src/postgres.rs | 469 +++++++++++++++++++++++++++++----- torn-key-pool/src/send.rs | 67 ++++- 5 files changed, 560 insertions(+), 142 deletions(-) diff --git a/torn-key-pool/Cargo.toml b/torn-key-pool/Cargo.toml index bb03fc9..4453ee1 100644 --- a/torn-key-pool/Cargo.toml +++ b/torn-key-pool/Cargo.toml @@ -1,28 +1,29 @@ [package] name = "torn-key-pool" -version = "0.4.2" +version = "0.5.0" edition = "2021" license = "MIT" repository = "https://github.com/TotallyNot/torn-api.rs.git" homepage = "https://github.com/TotallyNot/torn-api.rs.git" -description = "A generalizes API key pool for torn-api" +description = "A generalised API key pool for torn-api" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [features] default = [ "postgres", "tokio-runtime" ] -postgres = [ "dep:sqlx", "dep:chrono", "dep:indoc" ] +postgres = [ "dep:sqlx", "dep:chrono", "dep:indoc", "dep:serde" ] reqwest = [ "dep:reqwest", "torn-api/reqwest" ] awc = [ "dep:awc", "torn-api/awc" ] tokio-runtime = [ "dep:tokio", "dep:rand" ] actix-runtime = [ "dep:actix-rt", "dep:rand" ] [dependencies] -torn-api = { path = "../torn-api", default-features = false, version = "0.5" } +torn-api = { path = "../torn-api", default-features = false, version = "0.5.5" } async-trait = "0.1" thiserror = "1" -sqlx = { version = "0.6", features = [ "postgres", "chrono" ], optional = true } +sqlx = { version = "0.6", features = [ "postgres", "chrono", "json" ], optional = true } +serde = { version = "1.0", optional = true } chrono = { version = "0.4", optional = true } indoc = { version = "1", optional = true } tokio = { version = "1", optional = true, default-features = false, features = ["time"] } @@ -37,7 +38,7 @@ awc = { version = "3", default-features = false, optional = true } torn-api = { path = "../torn-api", features = [ "reqwest" ] } sqlx = { version = "0.6", features = [ "runtime-tokio-rustls" ] } dotenv = "0.15.0" -tokio = { version = "1.20.1", features = ["test-util", "rt", "macros"] } +tokio = { version = "1.24.2", features = ["test-util", "rt", "macros"] } tokio-test = "0.4.2" reqwest = { version = "0.11", default-features = true } awc = { version = "3", features = [ "rustls" ] } diff --git a/torn-key-pool/src/lib.rs b/torn-key-pool/src/lib.rs index 624937c..d99c0cf 100644 --- a/torn-key-pool/src/lib.rs +++ b/torn-key-pool/src/lib.rs @@ -29,31 +29,57 @@ where Response(ResponseError), } -#[derive(Debug, Clone, Copy)] -pub enum KeyDomain { - Public, - User(i32), - Faction(i32), -} - pub trait ApiKey: Sync + Send { fn value(&self) -> &str; } +pub trait KeyDomain: Clone + std::fmt::Debug + Send + Sync {} + +impl KeyDomain for T where T: Clone + std::fmt::Debug + Send + Sync {} + #[async_trait] pub trait KeyPoolStorage { type Key: ApiKey; + type Domain: KeyDomain; type Error: std::error::Error + Sync + Send; - async fn acquire_key(&self, domain: KeyDomain) -> Result; + async fn acquire_key(&self, domain: Self::Domain) -> Result; async fn acquire_many_keys( &self, - domain: KeyDomain, + domain: Self::Domain, number: i64, ) -> Result, Self::Error>; async fn flag_key(&self, key: Self::Key, code: u8) -> Result; + + async fn store_key( + &self, + key: String, + domains: Vec, + ) -> Result; + + async fn read_key(&self, key: String) -> Result; + + async fn remove_key(&self, key: String) -> Result; + + async fn add_domain_to_key( + &self, + key: String, + domain: Self::Domain, + ) -> Result; + + async fn remove_domain_from_key( + &self, + key: String, + domain: Self::Domain, + ) -> Result; + + async fn set_domains_for_key( + &self, + key: String, + domains: Vec, + ) -> Result; } #[derive(Debug, Clone)] @@ -62,7 +88,8 @@ where S: KeyPoolStorage, { storage: &'a S, - domain: KeyDomain, + comment: Option<&'a str>, + domain: S::Domain, _marker: std::marker::PhantomData, } @@ -70,52 +97,15 @@ impl<'a, C, S> KeyPoolExecutor<'a, C, S> where S: KeyPoolStorage, { - pub fn new(storage: &'a S, domain: KeyDomain) -> Self { + pub fn new(storage: &'a S, domain: S::Domain, comment: Option<&'a str>) -> Self { Self { storage, domain, + comment, _marker: std::marker::PhantomData, } } } #[cfg(all(test, feature = "postgres"))] -mod test { - use std::sync::Once; - - use tokio::test; - - use super::*; - - static INIT: Once = Once::new(); - - pub(crate) async fn setup() -> postgres::PgKeyPoolStorage { - INIT.call_once(|| { - dotenv::dotenv().ok(); - }); - - let pool = sqlx::PgPool::connect(&std::env::var("DATABASE_URL").unwrap()) - .await - .unwrap(); - - sqlx::query("update api_keys set uses=0") - .execute(&pool) - .await - .unwrap(); - - postgres::PgKeyPoolStorage::new(pool, 50) - } - - #[test] - async fn key_pool_bulk() { - let storage = setup().await; - - if let Err(e) = storage.initialise().await { - panic!("Initialising key storage failed: {:?}", e); - } - - let pool = send::KeyPool::new(reqwest::Client::default(), storage); - - pool.torn_api(KeyDomain::Public).users([1], |b| b).await; - } -} +mod test {} diff --git a/torn-key-pool/src/local.rs b/torn-key-pool/src/local.rs index a6e0173..5a78b48 100644 --- a/torn-key-pool/src/local.rs +++ b/torn-key-pool/src/local.rs @@ -7,7 +7,7 @@ use torn_api::{ ApiCategoryResponse, ApiRequest, ApiResponse, ResponseError, }; -use crate::{ApiKey, KeyDomain, KeyPoolError, KeyPoolExecutor, KeyPoolStorage}; +use crate::{ApiKey, KeyPoolError, KeyPoolExecutor, KeyPoolStorage}; #[async_trait(?Send)] impl<'client, C, S> RequestExecutor for KeyPoolExecutor<'client, C, S> @@ -20,16 +20,17 @@ where async fn execute( &self, client: &C, - request: ApiRequest, + mut request: ApiRequest, id: Option, ) -> Result where A: ApiCategoryResponse, { + request.comment = self.comment.map(ToOwned::to_owned); loop { let key = self .storage - .acquire_key(self.domain) + .acquire_key(self.domain.clone()) .await .map_err(|e| KeyPoolError::Storage(Arc::new(e)))?; let url = request.url(key.value(), id); @@ -56,7 +57,7 @@ where async fn execute_many( &self, client: &C, - request: ApiRequest, + mut request: ApiRequest, ids: Vec, ) -> HashMap> where @@ -64,7 +65,7 @@ where { let keys = match self .storage - .acquire_many_keys(self.domain, ids.len() as i64) + .acquire_many_keys(self.domain.clone(), ids.len() as i64) .await { Ok(keys) => keys, @@ -77,6 +78,7 @@ where } }; + request.comment = self.comment.map(ToOwned::to_owned); let request_ref = &request; futures::future::join_all(std::iter::zip(ids, keys).map(|(id, mut key)| async move { @@ -107,7 +109,7 @@ where Ok(res) => return (id, Ok(A::from_response(res))), }; - key = match self.storage.acquire_key(self.domain).await { + key = match self.storage.acquire_key(self.domain.clone()).await { Ok(k) => k, Err(why) => return (id, Err(Self::Error::Storage(Arc::new(why)))), }; @@ -127,6 +129,7 @@ where { client: C, storage: S, + comment: Option, } impl KeyPool @@ -134,12 +137,19 @@ where C: ApiClient, S: KeyPoolStorage + 'static, { - pub fn new(client: C, storage: S) -> Self { - Self { client, storage } + pub fn new(client: C, storage: S, comment: Option) -> Self { + Self { + client, + storage, + comment, + } } - pub fn torn_api(&self, domain: KeyDomain) -> ApiProvider> { - ApiProvider::new(&self.client, KeyPoolExecutor::new(&self.storage, domain)) + pub fn torn_api(&self, domain: S::Domain) -> ApiProvider> { + ApiProvider::new( + &self.client, + KeyPoolExecutor::new(&self.storage, domain, self.comment.as_deref()), + ) } } @@ -147,15 +157,44 @@ pub trait WithStorage { fn with_storage<'a, S>( &'a self, storage: &'a S, - domain: KeyDomain, + domain: S::Domain, ) -> ApiProvider> where Self: ApiClient + Sized, S: KeyPoolStorage + 'static, { - ApiProvider::new(self, KeyPoolExecutor::new(storage, domain)) + ApiProvider::new(self, KeyPoolExecutor::new(storage, domain, None)) } } #[cfg(feature = "awc")] impl WithStorage for awc::Client {} + +#[cfg(all(test, feature = "postgres", feature = "awc"))] +mod test { + use tokio::test; + + use super::*; + use crate::postgres::test::{setup, Domain}; + + #[test] + async fn test_pool_request() { + let storage = setup().await; + let pool = KeyPool::new(awc::Client::default(), storage); + + let response = pool.torn_api(Domain::All).user(|b| b).await.unwrap(); + _ = response.profile().unwrap(); + } + + #[test] + async fn test_with_storage_request() { + let storage = setup().await; + + let response = awc::Client::new() + .with_storage(&storage, Domain::All) + .user(|b| b) + .await + .unwrap(); + _ = response.profile().unwrap(); + } +} diff --git a/torn-key-pool/src/postgres.rs b/torn-key-pool/src/postgres.rs index ec84bc3..f0868c6 100644 --- a/torn-key-pool/src/postgres.rs +++ b/torn-key-pool/src/postgres.rs @@ -5,51 +5,98 @@ use thiserror::Error; use crate::{ApiKey, KeyDomain, KeyPoolStorage}; +pub trait PgKeyDomain: + KeyDomain + serde::Serialize + serde::de::DeserializeOwned + Eq + Unpin +{ +} + +impl PgKeyDomain for T where + T: KeyDomain + serde::Serialize + serde::de::DeserializeOwned + Eq + Unpin +{ +} + #[derive(Debug, Error)] -pub enum PgStorageError { +pub enum PgStorageError +where + D: std::fmt::Debug, +{ #[error(transparent)] Pg(#[from] sqlx::Error), #[error("No key avalaible for domain {0:?}")] - Unavailable(KeyDomain), + Unavailable(D), + + #[error("Duplicate key '{0}'")] + DuplicateKey(String), + + #[error("Duplicate domain '{0:?}'")] + DuplicateDomain(D), + + #[error("Key not found: '{0}'")] + KeyNotFound(String), } #[derive(Debug, Clone, FromRow)] -pub struct PgKey { +pub struct PgKey +where + D: PgKeyDomain, +{ pub id: i32, pub key: String, pub uses: i16, + pub domains: sqlx::types::Json>, } #[derive(Debug, Clone, FromRow)] -pub struct PgKeyPoolStorage { +pub struct PgKeyPoolStorage +where + D: serde::Serialize + serde::de::DeserializeOwned + Send + Sync + 'static, +{ pool: PgPool, limit: i16, + _phantom: std::marker::PhantomData, } -impl ApiKey for PgKey { +impl ApiKey for PgKey +where + D: PgKeyDomain, +{ fn value(&self) -> &str { &self.key } } -impl PgKeyPoolStorage { +impl PgKeyPoolStorage +where + D: PgKeyDomain, +{ pub fn new(pool: PgPool, limit: i16) -> Self { - Self { pool, limit } + Self { + pool, + limit, + _phantom: Default::default(), + } } - pub async fn initialise(&self) -> Result<(), PgStorageError> { + pub async fn initialise(&self) -> Result<(), PgStorageError> { sqlx::query(indoc! {r#" CREATE TABLE IF NOT EXISTS api_keys ( id serial primary key, - user_id int4 not null, - faction_id int4, key char(16) not null, uses int2 not null default 0, - "user" bool not null, - faction bool not null, - last_used timestamptz not null default now() - )"#}) + domains jsonb not null default '{}'::jsonb, + last_used timestamptz not null default now(), + flag int2, + cooldown timestamptz, + constraint "uq:api_keys.key" UNIQUE(key) + )"# + }) + .execute(&self.pool) + .await?; + + sqlx::query(indoc! {r#" + CREATE INDEX IF NOT EXISTS "idx:api_keys.domains" ON api_keys USING GIN(domains jsonb_path_ops) + "#}) .execute(&self.pool) .await?; @@ -72,63 +119,68 @@ async fn random_sleep() { } #[async_trait] -impl KeyPoolStorage for PgKeyPoolStorage { - type Key = PgKey; +impl KeyPoolStorage for PgKeyPoolStorage +where + D: PgKeyDomain, +{ + type Key = PgKey; + type Domain = D; - type Error = PgStorageError; - - async fn acquire_key(&self, domain: KeyDomain) -> Result { - let predicate = match domain { - KeyDomain::Public => "".to_owned(), - KeyDomain::User(id) => format!(" and user_id={} and user", id), - KeyDomain::Faction(id) => format!(" and faction_id={} and faction", id), - }; + type Error = PgStorageError; + async fn acquire_key(&self, domain: D) -> Result { loop { let attempt = async { let mut tx = self.pool.begin().await?; - sqlx::query("set transaction isolation level serializable") + sqlx::query("set transaction isolation level repeatable read") .execute(&mut tx) .await?; - let key: Option = sqlx::query_as(&indoc::formatdoc!( + let key = sqlx::query_as(&indoc::formatdoc!( r#" with key as ( select id, 0::int2 as uses - from api_keys where last_used < date_trunc('minute', now()){predicate} + from api_keys where last_used < date_trunc('minute', now()) and (cooldown is null or now() >= cooldown) and domains @> $1 union ( - select id, uses from api_keys where last_used >= date_trunc('minute', now()){predicate} order by uses asc + select id, uses from api_keys + where last_used >= date_trunc('minute', now()) and (cooldown is null or now() >= cooldown) and domains @> $1 + order by uses asc ) limit 1 ) update api_keys set uses = key.uses + 1, + cooldown = null, + flag = null, last_used = now() from key where - api_keys.id=key.id and key.uses < $1 + api_keys.id=key.id and key.uses < $2 returning api_keys.id, api_keys.key, - api_keys.uses + api_keys.uses, + api_keys.domains "#, )) + .bind(sqlx::types::Json(vec![&domain])) .bind(self.limit) .fetch_optional(&mut tx) .await?; - tx.commit().await?; + tx.commit().await?; - Result::, sqlx::Error>::Ok( - key.ok_or(PgStorageError::Unavailable(domain)), + Result::, sqlx::Error>::Ok( + key ) } .await; match attempt { - Ok(result) => return result, + Ok(Some(result)) => return Ok(result), + Ok(None) => return Err(PgStorageError::Unavailable(domain)), Err(error) => { if let Some(db_error) = error.as_database_error() { let pg_error: &sqlx::postgres::PgDatabaseError = db_error.downcast_ref(); @@ -147,45 +199,42 @@ impl KeyPoolStorage for PgKeyPoolStorage { async fn acquire_many_keys( &self, - domain: KeyDomain, + domain: D, number: i64, ) -> Result, Self::Error> { - let predicate = match domain { - KeyDomain::Public => "".to_owned(), - KeyDomain::User(id) => format!(" and user_id={} and user", id), - KeyDomain::Faction(id) => format!(" and faction_id={} and faction", id), - }; - loop { let attempt = async { let mut tx = self.pool.begin().await?; - sqlx::query("set transaction isolation level serializable") + sqlx::query("set transaction isolation level repeatable read") .execute(&mut tx) .await?; - let mut keys: Vec = sqlx::query_as(&indoc::formatdoc!( + let mut keys: Vec = sqlx::query_as(&indoc::formatdoc!( r#"select id, key, - 0::int2 as uses - from api_keys where last_used < date_trunc('minute', now()){predicate} + 0::int2 as uses, + domains + from api_keys where last_used < date_trunc('minute', now()) and (cooldown is null or now() >= cooldown) and domains @> $1 union select id, key, - uses - from api_keys where last_used >= date_trunc('minute', now()){predicate} - order by uses limit $1 + uses, + domains + from api_keys where last_used >= date_trunc('minute', now()) and (cooldown is null or now() >= cooldown) and domains @> $1 + order by uses limit $2 "#, )) + .bind(sqlx::types::Json(vec![&domain])) .bind(number) .fetch_all(&mut tx) .await?; if keys.is_empty() { tx.commit().await?; - return Ok(Err(PgStorageError::Unavailable(domain))); + return Ok(None); } keys.sort_unstable_by(|k1, k2| k1.uses.cmp(&k2.uses)); @@ -217,6 +266,8 @@ impl KeyPoolStorage for PgKeyPoolStorage { sqlx::query(indoc! {r#" update api_keys set uses = tmp.uses, + cooldown = null, + flag = null, last_used = now() from (select unnest($1::int4[]) as id, unnest($2::int2[]) as uses) as tmp where api_keys.id = tmp.id @@ -228,12 +279,13 @@ impl KeyPoolStorage for PgKeyPoolStorage { tx.commit().await?; - Result::, Self::Error>, sqlx::Error>::Ok(Ok(result)) + Result::>, sqlx::Error>::Ok(Some(result)) } .await; match attempt { - Ok(result) => return result, + Ok(Some(result)) => return Ok(result), + Ok(None) => return Err(Self::Error::Unavailable(domain)), Err(error) => { if let Some(db_error) = error.as_database_error() { let pg_error: &sqlx::postgres::PgDatabaseError = db_error.downcast_ref(); @@ -254,7 +306,41 @@ impl KeyPoolStorage for PgKeyPoolStorage { // TODO: put keys in cooldown when appropriate match code { 2 | 10 | 13 => { - sqlx::query("delete from api_keys where id=$1") + // invalid key, owner fedded or owner inactive + sqlx::query( + "update api_keys set cooldown='infinity'::timestamptz, flag=$1 where id=$2", + ) + .bind(code as i16) + .bind(key.id) + .execute(&self.pool) + .await?; + Ok(true) + } + 5 => { + // too many requests + sqlx::query("update api_keys set cooldown=date_trunc('min', now()) + interval '1 min', flag=5 where id=$1") + .bind(key.id) + .execute(&self.pool) + .await?; + Ok(true) + } + 8 => { + // IP block + sqlx::query("update api_keys set cooldown=now() + interval '5 min', flag=8") + .execute(&self.pool) + .await?; + Ok(false) + } + 9 => { + // API disabled + sqlx::query("update api_keys set cooldown=now() + interval '1 min', flag=9") + .execute(&self.pool) + .await?; + Ok(false) + } + 14 => { + // daily read limit reached + sqlx::query("update api_keys set cooldown=date_trunc('day', now()) + interval '1 day', flag=14 where id=$1") .bind(key.id) .execute(&self.pool) .await?; @@ -263,19 +349,115 @@ impl KeyPoolStorage for PgKeyPoolStorage { _ => Ok(false), } } + + async fn store_key(&self, key: String, domains: Vec) -> Result { + sqlx::query_as("insert into api_keys(key, domains) values ($1, $2) returning *") + .bind(&key) + .bind(sqlx::types::Json(domains)) + .fetch_one(&self.pool) + .await + .map_err(|why| { + if let Some(error) = why.as_database_error() { + let pg_error: &sqlx::postgres::PgDatabaseError = error.downcast_ref(); + if pg_error.code() == "23505" { + return PgStorageError::DuplicateKey(key); + } + } + PgStorageError::Pg(why) + }) + } + + async fn read_key(&self, key: String) -> Result { + sqlx::query_as("select * from api_keys where key=$1") + .bind(&key) + .fetch_optional(&self.pool) + .await? + .ok_or_else(|| PgStorageError::KeyNotFound(key)) + } + + async fn remove_key(&self, key: String) -> Result { + sqlx::query_as("delete from api_keys where key=$1 returning *") + .bind(&key) + .fetch_optional(&self.pool) + .await? + .ok_or_else(|| PgStorageError::KeyNotFound(key)) + } + + async fn add_domain_to_key(&self, key: String, domain: D) -> Result { + let mut tx = self.pool.begin().await?; + match sqlx::query_as::>( + "update api_keys set domains = domains || jsonb_build_array($1) where key=$2 returning *", + ) + .bind(sqlx::types::Json(domain.clone())) + .bind(&key) + .fetch_optional(&mut tx) + .await? + { + None => Err(PgStorageError::KeyNotFound(key)), + Some(key) => { + if key.domains.0.iter().filter(|d| **d == domain).count() > 1 { + tx.rollback().await?; + return Err(PgStorageError::DuplicateDomain(domain)); + } + tx.commit().await?; + Ok(key) + } + } + } + + async fn remove_domain_from_key( + &self, + key: String, + domain: D, + ) -> Result { + // FIX: potential race condition + let api_key = self.read_key(key.clone()).await?; + let domains = api_key + .domains + .0 + .into_iter() + .filter(|d| *d != domain) + .collect(); + + self.set_domains_for_key(key, domains).await + } + + async fn set_domains_for_key( + &self, + key: String, + domains: Vec, + ) -> Result { + sqlx::query_as::>( + "update api_keys set domains = $1 where key=$2 returning *", + ) + .bind(sqlx::types::Json(domains)) + .bind(&key) + .fetch_optional(&self.pool) + .await? + .ok_or_else(|| PgStorageError::KeyNotFound(key)) + } } #[cfg(test)] -mod test { +pub(crate) mod test { use std::sync::{Arc, Once}; + use sqlx::Row; use tokio::test; use super::*; static INIT: Once = Once::new(); - pub(crate) async fn setup() -> PgKeyPoolStorage { + #[derive(Debug, PartialEq, Eq, Clone, serde::Serialize, serde::Deserialize)] + #[serde(tag = "type", rename_all = "snake_case")] + pub(crate) enum Domain { + All, + User { id: i32 }, + Faction { id: i32 }, + } + + pub(crate) async fn setup() -> PgKeyPoolStorage { INIT.call_once(|| { dotenv::dotenv().ok(); }); @@ -284,12 +466,20 @@ mod test { .await .unwrap(); - sqlx::query("update api_keys set uses=id") + sqlx::query("DROP TABLE IF EXISTS api_keys") .execute(&pool) .await .unwrap(); - PgKeyPoolStorage::new(pool, 50) + let storage = PgKeyPoolStorage::new(pool.clone(), 1000); + storage.initialise().await.unwrap(); + + storage + .store_key(std::env::var("APIKEY").unwrap(), vec![Domain::All]) + .await + .unwrap(); + + storage } #[test] @@ -301,24 +491,179 @@ mod test { } } + #[test] + async fn test_store_duplicate() { + let storage = setup().await; + match storage + .store_key(std::env::var("APIKEY").unwrap(), vec![]) + .await + .unwrap_err() + { + PgStorageError::DuplicateKey(key) => { + assert_eq!(key, std::env::var("APIKEY").unwrap()) + } + why => panic!("Expected duplicate key error but found '{why}'"), + }; + } + + #[test] + async fn test_add_domain() { + let storage = setup().await; + let key = storage + .add_domain_to_key(std::env::var("APIKEY").unwrap(), Domain::User { id: 12345 }) + .await + .unwrap(); + + assert!(key.domains.0.contains(&Domain::User { id: 12345 })); + } + + #[test] + async fn test_add_duplicate_domain() { + let storage = setup().await; + match storage + .add_domain_to_key(std::env::var("APIKEY").unwrap(), Domain::All) + .await + .unwrap_err() + { + PgStorageError::DuplicateDomain(d) => assert_eq!(d, Domain::All), + why => panic!("Expected duplicate domain error but found '{why}'"), + }; + } + + #[test] + async fn test_remove_domain() { + let storage = setup().await; + let key = storage + .remove_domain_from_key(std::env::var("APIKEY").unwrap(), Domain::All) + .await + .unwrap(); + + assert!(key.domains.0.is_empty()); + } + + #[test] + async fn test_store_key() { + let storage = setup().await; + let key = storage + .store_key("ABCDABCDABCDABCD".to_owned(), vec![]) + .await + .unwrap(); + assert_eq!(key.value(), "ABCDABCDABCDABCD"); + } + #[test] async fn acquire_one() { let storage = setup().await; - if let Err(e) = storage.acquire_key(KeyDomain::Public).await { + if let Err(e) = storage.acquire_key(Domain::All).await { panic!("Acquiring key failed: {:?}", e); } } + #[test] + async fn test_flag_key_one() { + let storage = setup().await; + let key = storage + .read_key(std::env::var("APIKEY").unwrap()) + .await + .unwrap(); + + assert!(storage.flag_key(key, 2).await.unwrap()); + + match storage.acquire_key(Domain::All).await.unwrap_err() { + PgStorageError::Unavailable(d) => assert_eq!(d, Domain::All), + why => panic!("Expected domain unavailable error but found '{why}'"), + } + } + + #[test] + async fn test_flag_key_many() { + let storage = setup().await; + let key = storage + .read_key(std::env::var("APIKEY").unwrap()) + .await + .unwrap(); + + assert!(storage.flag_key(key, 2).await.unwrap()); + + match storage.acquire_many_keys(Domain::All, 5).await.unwrap_err() { + PgStorageError::Unavailable(d) => assert_eq!(d, Domain::All), + why => panic!("Expected domain unavailable error but found '{why}'"), + } + } + + #[test] + async fn acquire_many() { + let storage = setup().await; + + match storage.acquire_many_keys(Domain::All, 30).await { + Err(e) => panic!("Acquiring key failed: {:?}", e), + Ok(keys) => assert_eq!(keys.len(), 30), + } + } + #[test] async fn test_concurrent() { let storage = Arc::new(setup().await); - let keys = storage - .acquire_many_keys(KeyDomain::Public, 30) - .await - .unwrap(); + for _ in 0..10 { + let mut set = tokio::task::JoinSet::new(); - assert_eq!(keys.len(), 30); + for _ in 0..100 { + let storage = storage.clone(); + set.spawn(async move { + storage.acquire_key(Domain::All).await.unwrap(); + }); + } + + for _ in 0..100 { + set.join_next().await.unwrap().unwrap(); + } + + let uses: i16 = sqlx::query("select uses from api_keys") + .fetch_one(&storage.pool) + .await + .unwrap() + .get("uses"); + + assert_eq!(uses, 100); + + sqlx::query("update api_keys set uses=0") + .execute(&storage.pool) + .await + .unwrap(); + } + } + + #[test] + async fn test_concurrent_many() { + let storage = Arc::new(setup().await); + for _ in 0..10 { + let mut set = tokio::task::JoinSet::new(); + + for _ in 0..100 { + let storage = storage.clone(); + set.spawn(async move { + storage.acquire_many_keys(Domain::All, 5).await.unwrap(); + }); + } + + for _ in 0..100 { + set.join_next().await.unwrap().unwrap(); + } + + let uses: i16 = sqlx::query("select uses from api_keys") + .fetch_one(&storage.pool) + .await + .unwrap() + .get("uses"); + + assert_eq!(uses, 500); + + sqlx::query("update api_keys set uses=0") + .execute(&storage.pool) + .await + .unwrap(); + } } } diff --git a/torn-key-pool/src/send.rs b/torn-key-pool/src/send.rs index 9581f6b..2409443 100644 --- a/torn-key-pool/src/send.rs +++ b/torn-key-pool/src/send.rs @@ -7,7 +7,7 @@ use torn_api::{ ApiCategoryResponse, ApiRequest, ApiResponse, ResponseError, }; -use crate::{ApiKey, KeyDomain, KeyPoolError, KeyPoolExecutor, KeyPoolStorage}; +use crate::{ApiKey, KeyPoolError, KeyPoolExecutor, KeyPoolStorage}; #[async_trait] impl<'client, C, S> RequestExecutor for KeyPoolExecutor<'client, C, S> @@ -20,16 +20,17 @@ where async fn execute( &self, client: &C, - request: ApiRequest, + mut request: ApiRequest, id: Option, ) -> Result where A: ApiCategoryResponse, { + request.comment = self.comment.map(ToOwned::to_owned); loop { let key = self .storage - .acquire_key(self.domain) + .acquire_key(self.domain.clone()) .await .map_err(|e| KeyPoolError::Storage(Arc::new(e)))?; let url = request.url(key.value(), id); @@ -56,7 +57,7 @@ where async fn execute_many( &self, client: &C, - request: ApiRequest, + mut request: ApiRequest, ids: Vec, ) -> HashMap> where @@ -64,7 +65,7 @@ where { let keys = match self .storage - .acquire_many_keys(self.domain, ids.len() as i64) + .acquire_many_keys(self.domain.clone(), ids.len() as i64) .await { Ok(keys) => keys, @@ -77,6 +78,7 @@ where } }; + request.comment = self.comment.map(ToOwned::to_owned); let request_ref = &request; futures::future::join_all(std::iter::zip(ids, keys).map(|(id, mut key)| async move { @@ -107,7 +109,7 @@ where Ok(res) => return (id, Ok(A::from_response(res))), }; - key = match self.storage.acquire_key(self.domain).await { + key = match self.storage.acquire_key(self.domain.clone()).await { Ok(k) => k, Err(why) => return (id, Err(Self::Error::Storage(Arc::new(why)))), }; @@ -127,6 +129,7 @@ where { client: C, storage: S, + comment: Option, } impl KeyPool @@ -134,12 +137,19 @@ where C: ApiClient, S: KeyPoolStorage + Send + Sync + 'static, { - pub fn new(client: C, storage: S) -> Self { - Self { client, storage } + pub fn new(client: C, storage: S, comment: Option) -> Self { + Self { + client, + storage, + comment, + } } - pub fn torn_api(&self, domain: KeyDomain) -> ApiProvider> { - ApiProvider::new(&self.client, KeyPoolExecutor::new(&self.storage, domain)) + pub fn torn_api(&self, domain: S::Domain) -> ApiProvider> { + ApiProvider::new( + &self.client, + KeyPoolExecutor::new(&self.storage, domain, self.comment.as_deref()), + ) } } @@ -147,15 +157,48 @@ pub trait WithStorage { fn with_storage<'a, S>( &'a self, storage: &'a S, - domain: KeyDomain, + domain: S::Domain, ) -> ApiProvider> where Self: ApiClient + Sized, S: KeyPoolStorage + Send + Sync + 'static, { - ApiProvider::new(self, KeyPoolExecutor::new(storage, domain)) + ApiProvider::new(self, KeyPoolExecutor::new(storage, domain, None)) } } #[cfg(feature = "reqwest")] impl WithStorage for reqwest::Client {} + +#[cfg(all(test, feature = "postgres", feature = "reqwest"))] +mod test { + use tokio::test; + + use super::*; + use crate::postgres::test::{setup, Domain}; + + #[test] + async fn test_pool_request() { + let storage = setup().await; + let pool = KeyPool::new( + reqwest::Client::default(), + storage, + Some("api.rs".to_owned()), + ); + + let response = pool.torn_api(Domain::All).user(|b| b).await.unwrap(); + _ = response.profile().unwrap(); + } + + #[test] + async fn test_with_storage_request() { + let storage = setup().await; + + let response = reqwest::Client::new() + .with_storage(&storage, Domain::All) + .user(|b| b) + .await + .unwrap(); + _ = response.profile().unwrap(); + } +}