bulk updates

This commit is contained in:
TotallyNot 2022-09-18 23:48:36 +02:00
parent 0115b6e615
commit a7c640511c
16 changed files with 1091 additions and 443 deletions

View file

@ -1,6 +1,6 @@
[package]
name = "torn-key-pool"
version = "0.3.1"
version = "0.4.0"
edition = "2021"
license = "MIT"
repository = "https://github.com/TotallyNot/torn-api.rs.git"
@ -18,7 +18,7 @@ tokio-runtime = [ "dep:tokio", "dep:rand" ]
actix-runtime = [ "dep:actix-rt", "dep:rand" ]
[dependencies]
torn-api = { path = "../torn-api", default-features = false, version = "0.4" }
torn-api = { path = "../torn-api", default-features = false, version = "0.5" }
async-trait = "0.1"
thiserror = "1"
@ -28,6 +28,7 @@ indoc = { version = "1", optional = true }
tokio = { version = "1", optional = true, default-features = false, features = ["time"] }
actix-rt = { version = "2", optional = true, default-features = false }
rand = { version = "0.8", optional = true }
futures = "0.3"
reqwest = { version = "0.11", default-features = false, features = [ "json" ], optional = true }
awc = { version = "3", default-features = false, optional = true }
@ -40,4 +41,3 @@ tokio = { version = "1.20.1", features = ["test-util", "rt", "macros"] }
tokio-test = "0.4.2"
reqwest = { version = "0.11", default-features = true }
awc = { version = "3", features = [ "rustls" ] }
futures = "0.3.24"

View file

@ -3,13 +3,15 @@
#[cfg(feature = "postgres")]
pub mod postgres;
pub mod local;
pub mod send;
use std::sync::Arc;
use async_trait::async_trait;
use thiserror::Error;
use torn_api::{
ApiCategoryResponse, ApiClient, ApiProvider, ApiRequest, ApiResponse, RequestExecutor,
ResponseError, ThreadSafeApiClient, ThreadSafeApiProvider, ThreadSafeRequestExecutor,
};
use torn_api::ResponseError;
#[derive(Debug, Error)]
pub enum KeyPoolError<S, C>
@ -18,7 +20,7 @@ where
C: std::error::Error,
{
#[error("Key pool storage driver error: {0:?}")]
Storage(#[source] S),
Storage(#[source] Arc<S>),
#[error(transparent)]
Client(#[from] C),
@ -45,6 +47,12 @@ pub trait KeyPoolStorage {
async fn acquire_key(&self, domain: KeyDomain) -> Result<Self::Key, Self::Error>;
async fn acquire_many_keys(
&self,
domain: KeyDomain,
number: i64,
) -> Result<Vec<Self::Key>, Self::Error>;
async fn flag_key(&self, key: Self::Key, code: u8) -> Result<bool, Self::Error>;
}
@ -70,161 +78,3 @@ where
}
}
}
#[async_trait(?Send)]
impl<'client, C, S> RequestExecutor<C> for KeyPoolExecutor<'client, C, S>
where
C: ApiClient,
S: KeyPoolStorage + 'static,
{
type Error = KeyPoolError<S::Error, C::Error>;
async fn execute<A>(&self, client: &C, request: ApiRequest<A>) -> Result<A, Self::Error>
where
A: ApiCategoryResponse,
{
loop {
let key = self
.storage
.acquire_key(self.domain)
.await
.map_err(KeyPoolError::Storage)?;
let url = request.url(key.value());
let value = client.request(url).await?;
match ApiResponse::from_value(value) {
Err(ResponseError::Api { code, reason }) => {
if !self
.storage
.flag_key(key, code)
.await
.map_err(KeyPoolError::Storage)?
{
return Err(KeyPoolError::Response(ResponseError::Api { code, reason }));
}
}
Err(parsing_error) => return Err(KeyPoolError::Response(parsing_error)),
Ok(res) => return Ok(A::from_response(res)),
};
}
}
}
#[async_trait]
impl<'client, C, S> ThreadSafeRequestExecutor<C> for KeyPoolExecutor<'client, C, S>
where
C: ThreadSafeApiClient,
S: KeyPoolStorage + Send + Sync + 'static,
{
type Error = KeyPoolError<S::Error, C::Error>;
async fn execute<A>(&self, client: &C, request: ApiRequest<A>) -> Result<A, Self::Error>
where
A: ApiCategoryResponse,
{
loop {
let key = self
.storage
.acquire_key(self.domain)
.await
.map_err(KeyPoolError::Storage)?;
let url = request.url(key.value());
let value = client.request(url).await?;
match ApiResponse::from_value(value) {
Err(ResponseError::Api { code, reason }) => {
if !self
.storage
.flag_key(key, code)
.await
.map_err(KeyPoolError::Storage)?
{
return Err(KeyPoolError::Response(ResponseError::Api { code, reason }));
}
}
Err(parsing_error) => return Err(KeyPoolError::Response(parsing_error)),
Ok(res) => return Ok(A::from_response(res)),
};
}
}
}
#[derive(Clone, Debug)]
pub struct KeyPool<C, S>
where
C: ApiClient,
S: KeyPoolStorage,
{
client: C,
storage: S,
}
impl<C, S> KeyPool<C, S>
where
C: ApiClient,
S: KeyPoolStorage + 'static,
{
pub fn new(client: C, storage: S) -> Self {
Self { client, storage }
}
pub fn torn_api(&self, domain: KeyDomain) -> ApiProvider<C, KeyPoolExecutor<C, S>> {
ApiProvider::new(&self.client, KeyPoolExecutor::new(&self.storage, domain))
}
}
#[derive(Clone, Debug)]
pub struct ThreadSafeKeyPool<C, S>
where
C: ThreadSafeApiClient,
S: KeyPoolStorage + Send + Sync + 'static,
{
client: C,
storage: S,
}
impl<C, S> ThreadSafeKeyPool<C, S>
where
C: ThreadSafeApiClient,
S: KeyPoolStorage + Send + Sync + 'static,
{
pub fn new(client: C, storage: S) -> Self {
Self { client, storage }
}
pub fn torn_api(&self, domain: KeyDomain) -> ThreadSafeApiProvider<C, KeyPoolExecutor<C, S>> {
ThreadSafeApiProvider::new(&self.client, KeyPoolExecutor::new(&self.storage, domain))
}
}
pub trait WithStorage {
fn with_storage<'a, S>(
&'a self,
storage: &'a S,
domain: KeyDomain,
) -> ApiProvider<Self, KeyPoolExecutor<Self, S>>
where
Self: ApiClient + Sized,
S: KeyPoolStorage + 'static,
{
ApiProvider::new(self, KeyPoolExecutor::new(storage, domain))
}
fn with_storage_sync<'a, S>(
&'a self,
storage: &'a S,
domain: KeyDomain,
) -> ThreadSafeApiProvider<Self, KeyPoolExecutor<Self, S>>
where
Self: ThreadSafeApiClient + Sized,
S: KeyPoolStorage + Send + Sync + 'static,
{
ThreadSafeApiProvider::new(self, KeyPoolExecutor::new(storage, domain))
}
}
#[cfg(feature = "reqwest")]
impl WithStorage for reqwest::Client {}
#[cfg(feature = "awc")]
impl WithStorage for awc::Client {}

161
torn-key-pool/src/local.rs Normal file
View file

@ -0,0 +1,161 @@
use std::{collections::HashMap, sync::Arc};
use async_trait::async_trait;
use torn_api::{
local::{ApiClient, ApiProvider, RequestExecutor},
ApiCategoryResponse, ApiRequest, ApiResponse, ResponseError,
};
use crate::{ApiKey, KeyDomain, KeyPoolError, KeyPoolExecutor, KeyPoolStorage};
#[async_trait(?Send)]
impl<'client, C, S> RequestExecutor<C> for KeyPoolExecutor<'client, C, S>
where
C: ApiClient,
S: KeyPoolStorage + 'static,
{
type Error = KeyPoolError<S::Error, C::Error>;
async fn execute<A>(
&self,
client: &C,
request: ApiRequest<A>,
id: Option<i64>,
) -> Result<A, Self::Error>
where
A: ApiCategoryResponse,
{
loop {
let key = self
.storage
.acquire_key(self.domain)
.await
.map_err(|e| KeyPoolError::Storage(Arc::new(e)))?;
let url = request.url(key.value(), id);
let value = client.request(url).await?;
match ApiResponse::from_value(value) {
Err(ResponseError::Api { code, reason }) => {
if !self
.storage
.flag_key(key, code)
.await
.map_err(Arc::new)
.map_err(KeyPoolError::Storage)?
{
return Err(KeyPoolError::Response(ResponseError::Api { code, reason }));
}
}
Err(parsing_error) => return Err(KeyPoolError::Response(parsing_error)),
Ok(res) => return Ok(A::from_response(res)),
};
}
}
async fn execute_many<A>(
&self,
client: &C,
request: ApiRequest<A>,
ids: Vec<i64>,
) -> HashMap<i64, Result<A, Self::Error>>
where
A: ApiCategoryResponse,
{
let keys = match self
.storage
.acquire_many_keys(self.domain, ids.len() as i64)
.await
{
Ok(keys) => keys,
Err(why) => {
let shared = Arc::new(why);
return ids
.into_iter()
.map(|i| (i, Err(Self::Error::Storage(shared.clone()))))
.collect();
}
};
let request_ref = &request;
futures::future::join_all(std::iter::zip(ids, keys).map(|(id, mut key)| async move {
loop {
let url = request_ref.url(key.value(), Some(id));
let value = match client.request(url).await {
Ok(v) => v,
Err(why) => return (id, Err(Self::Error::Client(why))),
};
match ApiResponse::from_value(value) {
Err(ResponseError::Api { code, reason }) => {
match self.storage.flag_key(key, code).await {
Ok(false) => {
return (
id,
Err(KeyPoolError::Response(ResponseError::Api {
code,
reason,
})),
)
}
Ok(true) => (),
Err(why) => return (id, Err(KeyPoolError::Storage(Arc::new(why)))),
}
}
Err(parsing_error) => return (id, Err(KeyPoolError::Response(parsing_error))),
Ok(res) => return (id, Ok(A::from_response(res))),
};
key = match self.storage.acquire_key(self.domain).await {
Ok(k) => k,
Err(why) => return (id, Err(Self::Error::Storage(Arc::new(why)))),
};
}
}))
.await
.into_iter()
.collect()
}
}
#[derive(Clone, Debug)]
pub struct KeyPool<C, S>
where
C: ApiClient,
S: KeyPoolStorage,
{
client: C,
storage: S,
}
impl<C, S> KeyPool<C, S>
where
C: ApiClient,
S: KeyPoolStorage + 'static,
{
pub fn new(client: C, storage: S) -> Self {
Self { client, storage }
}
pub fn torn_api(&self, domain: KeyDomain) -> ApiProvider<C, KeyPoolExecutor<C, S>> {
ApiProvider::new(&self.client, KeyPoolExecutor::new(&self.storage, domain))
}
}
pub trait WithStorage {
fn with_storage<'a, S>(
&'a self,
storage: &'a S,
domain: KeyDomain,
) -> ApiProvider<Self, KeyPoolExecutor<Self, S>>
where
Self: ApiClient + Sized,
S: KeyPoolStorage + 'static,
{
ApiProvider::new(self, KeyPoolExecutor::new(storage, domain))
}
}
#[cfg(feature = "awc")]
impl WithStorage for awc::Client {}

View file

@ -4,7 +4,7 @@ use indoc::indoc;
use sqlx::{FromRow, PgPool};
use thiserror::Error;
use crate::{ApiKey, KeyDomain, KeyPool, KeyPoolStorage};
use crate::{ApiKey, KeyDomain, KeyPoolStorage};
#[derive(Debug, Error)]
pub enum PgStorageError {
@ -102,18 +102,12 @@ impl KeyPoolStorage for PgKeyPoolStorage {
with key as (
select
id,
user_id,
faction_id,
key,
case
when extract(minute from last_used)=extract(minute from now()) then uses
else 0::smallint
end as uses,
user,
faction,
last_used
end as uses
from api_keys {}
order by last_used asc limit 1 FOR UPDATE
order by last_used asc limit 1
)
update api_keys set
uses = key.uses + 1,
@ -162,6 +156,70 @@ impl KeyPoolStorage for PgKeyPoolStorage {
}
}
async fn acquire_many_keys(
&self,
domain: KeyDomain,
number: i64,
) -> Result<Vec<Self::Key>, Self::Error> {
let predicate = match domain {
KeyDomain::Public => "".to_owned(),
KeyDomain::User(id) => format!("where and user_id={} and user", id),
KeyDomain::Faction(id) => format!("where and faction_id={} and faction", id),
};
let mut tx = self.pool.begin().await?;
let mut keys: Vec<PgKey> = sqlx::query_as(&indoc::formatdoc!(
r#"
select
id,
user_id,
faction_id,
key,
case
when extract(minute from last_used)=extract(minute from now()) then uses
else 0::smallint
end as uses,
"user",
faction,
last_used
from api_keys {} order by last_used limit $1 for update
"#,
predicate
))
.bind(number)
.fetch_all(&mut tx)
.await?;
let mut result = Vec::with_capacity(number as usize);
'outer: for _ in 0..(((number as usize) / keys.len()) + 1) {
for key in &mut keys {
if key.uses == self.limit || result.len() == (number as usize) {
break 'outer;
} else {
key.uses += 1;
result.push(key.clone());
}
}
}
sqlx::query(indoc! {r#"
update api_keys set
uses = tmp.uses,
last_used = now()
from (select unnest($1::int4[]) as id, unnest($2::int2[]) as uses) as tmp
where api_keys.id = tmp.id
"#})
.bind(keys.iter().map(|k| k.id).collect::<Vec<_>>())
.bind(keys.iter().map(|k| k.uses).collect::<Vec<_>>())
.execute(&mut tx)
.await?;
tx.commit().await?;
Ok(result)
}
async fn flag_key(&self, key: Self::Key, code: u8) -> Result<bool, Self::Error> {
// TODO: put keys in cooldown when appropriate
match code {
@ -177,27 +235,6 @@ impl KeyPoolStorage for PgKeyPoolStorage {
}
}
pub type PgKeyPool<A> = KeyPool<A, PgKeyPoolStorage>;
impl<A> PgKeyPool<A>
where
A: torn_api::ApiClient,
{
pub async fn connect(
client: A,
database_url: &str,
limit: i16,
) -> Result<Self, PgStorageError> {
let db_pool = PgPool::connect(database_url).await?;
let storage = PgKeyPoolStorage::new(db_pool, limit);
storage.initialise().await?;
let key_pool = Self::new(client, storage);
Ok(key_pool)
}
}
#[cfg(test)]
mod test {
use std::sync::{Arc, Once};
@ -253,13 +290,12 @@ mod test {
.unwrap()
.get("uses");
let futures = (0..30).into_iter().map(|_| {
let storage = storage.clone();
async move {
storage.acquire_key(KeyDomain::Public).await.unwrap();
}
});
futures::future::join_all(futures).await;
let keys = storage
.acquire_many_keys(KeyDomain::Public, 30)
.await
.unwrap();
assert_eq!(keys.len(), 30);
let after: i16 = sqlx::query("select uses from api_keys")
.fetch_one(&storage.pool)

161
torn-key-pool/src/send.rs Normal file
View file

@ -0,0 +1,161 @@
use std::{collections::HashMap, sync::Arc};
use async_trait::async_trait;
use torn_api::{
send::{ApiClient, ApiProvider, RequestExecutor},
ApiCategoryResponse, ApiRequest, ApiResponse, ResponseError,
};
use crate::{ApiKey, KeyDomain, KeyPoolError, KeyPoolExecutor, KeyPoolStorage};
#[async_trait]
impl<'client, C, S> RequestExecutor<C> for KeyPoolExecutor<'client, C, S>
where
C: ApiClient,
S: KeyPoolStorage + Send + Sync + 'static,
{
type Error = KeyPoolError<S::Error, C::Error>;
async fn execute<A>(
&self,
client: &C,
request: ApiRequest<A>,
id: Option<i64>,
) -> Result<A, Self::Error>
where
A: ApiCategoryResponse,
{
loop {
let key = self
.storage
.acquire_key(self.domain)
.await
.map_err(|e| KeyPoolError::Storage(Arc::new(e)))?;
let url = request.url(key.value(), id);
let value = client.request(url).await?;
match ApiResponse::from_value(value) {
Err(ResponseError::Api { code, reason }) => {
if !self
.storage
.flag_key(key, code)
.await
.map_err(Arc::new)
.map_err(KeyPoolError::Storage)?
{
return Err(KeyPoolError::Response(ResponseError::Api { code, reason }));
}
}
Err(parsing_error) => return Err(KeyPoolError::Response(parsing_error)),
Ok(res) => return Ok(A::from_response(res)),
};
}
}
async fn execute_many<A>(
&self,
client: &C,
request: ApiRequest<A>,
ids: Vec<i64>,
) -> HashMap<i64, Result<A, Self::Error>>
where
A: ApiCategoryResponse,
{
let keys = match self
.storage
.acquire_many_keys(self.domain, ids.len() as i64)
.await
{
Ok(keys) => keys,
Err(why) => {
let shared = Arc::new(why);
return ids
.into_iter()
.map(|i| (i, Err(Self::Error::Storage(shared.clone()))))
.collect();
}
};
let request_ref = &request;
futures::future::join_all(std::iter::zip(ids, keys).map(|(id, mut key)| async move {
loop {
let url = request_ref.url(key.value(), Some(id));
let value = match client.request(url).await {
Ok(v) => v,
Err(why) => return (id, Err(Self::Error::Client(why))),
};
match ApiResponse::from_value(value) {
Err(ResponseError::Api { code, reason }) => {
match self.storage.flag_key(key, code).await {
Ok(false) => {
return (
id,
Err(KeyPoolError::Response(ResponseError::Api {
code,
reason,
})),
)
}
Ok(true) => (),
Err(why) => return (id, Err(KeyPoolError::Storage(Arc::new(why)))),
}
}
Err(parsing_error) => return (id, Err(KeyPoolError::Response(parsing_error))),
Ok(res) => return (id, Ok(A::from_response(res))),
};
key = match self.storage.acquire_key(self.domain).await {
Ok(k) => k,
Err(why) => return (id, Err(Self::Error::Storage(Arc::new(why)))),
};
}
}))
.await
.into_iter()
.collect()
}
}
#[derive(Clone, Debug)]
pub struct KeyPool<C, S>
where
C: ApiClient,
S: KeyPoolStorage,
{
client: C,
storage: S,
}
impl<C, S> KeyPool<C, S>
where
C: ApiClient,
S: KeyPoolStorage + Send + Sync + 'static,
{
pub fn new(client: C, storage: S) -> Self {
Self { client, storage }
}
pub fn torn_api(&self, domain: KeyDomain) -> ApiProvider<C, KeyPoolExecutor<C, S>> {
ApiProvider::new(&self.client, KeyPoolExecutor::new(&self.storage, domain))
}
}
pub trait WithStorage {
fn with_storage<'a, S>(
&'a self,
storage: &'a S,
domain: KeyDomain,
) -> ApiProvider<Self, KeyPoolExecutor<Self, S>>
where
Self: ApiClient + Sized,
S: KeyPoolStorage + Send + Sync + 'static,
{
ApiProvider::new(self, KeyPoolExecutor::new(storage, domain))
}
}
#[cfg(feature = "reqwest")]
impl WithStorage for reqwest::Client {}