diff --git a/torn-key-pool/Cargo.toml b/torn-key-pool/Cargo.toml index 4048028..462f966 100644 --- a/torn-key-pool/Cargo.toml +++ b/torn-key-pool/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "torn-key-pool" -version = "0.3.0" +version = "0.3.1" edition = "2021" license = "MIT" repository = "https://github.com/TotallyNot/torn-api.rs.git" @@ -10,10 +10,12 @@ description = "A generalizes API key pool for torn-api" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [features] -default = [ "postgres" ] +default = [ "postgres", "tokio-runtime" ] postgres = [ "dep:sqlx", "dep:chrono", "dep:indoc" ] reqwest = [ "dep:reqwest", "torn-api/reqwest" ] awc = [ "dep:awc", "torn-api/awc" ] +tokio-runtime = [ "dep:tokio", "dep:rand" ] +actix-runtime = [ "dep:actix-rt", "dep:rand" ] [dependencies] torn-api = { path = "../torn-api", default-features = false, version = "0.4" } @@ -23,6 +25,9 @@ thiserror = "1" sqlx = { version = "0.6", features = [ "postgres", "chrono" ], optional = true } chrono = { version = "0.4", optional = true } indoc = { version = "1", optional = true } +tokio = { version = "1", optional = true, default-features = false, features = ["time"] } +actix-rt = { version = "2", optional = true, default-features = false } +rand = { version = "0.8", optional = true } reqwest = { version = "0.11", default-features = false, features = [ "json" ], optional = true } awc = { version = "3", default-features = false, optional = true } @@ -35,3 +40,4 @@ tokio = { version = "1.20.1", features = ["test-util", "rt", "macros"] } tokio-test = "0.4.2" reqwest = { version = "0.11", default-features = true } awc = { version = "3", features = [ "rustls" ] } +futures = "0.3.24" diff --git a/torn-key-pool/src/postgres.rs b/torn-key-pool/src/postgres.rs index 9d56909..36d80d5 100644 --- a/torn-key-pool/src/postgres.rs +++ b/torn-key-pool/src/postgres.rs @@ -63,6 +63,20 @@ impl PgKeyPoolStorage { } } +#[cfg(feature = "tokio-runtime")] +async fn random_sleep() { + use rand::{thread_rng, Rng}; + let dur = tokio::time::Duration::from_millis(thread_rng().gen_range(1..50)); + tokio::time::sleep(dur).await; +} + +#[cfg(all(not(feature = "tokio-runtime"), feature = "actix-runtime"))] +async fn random_sleep() { + use rand::{thread_rng, Rng}; + let dur = std::time::Duration::from_millis(thread_rng().gen_range(1..50)); + actix_rt::time::sleep(dur).await; +} + #[async_trait] impl KeyPoolStorage for PgKeyPoolStorage { type Key = PgKey; @@ -75,46 +89,77 @@ impl KeyPoolStorage for PgKeyPoolStorage { KeyDomain::User(id) => format!("where and user_id={} and user", id), KeyDomain::Faction(id) => format!("where and faction_id={} and faction", id), }; - let key: Option = sqlx::query_as(&indoc::formatdoc!( - r#" - with key as ( - select - id, - user_id, - faction_id, - key, - case - when extract(minute from last_used)=extract(minute from now()) then uses - else 0::smallint - end as uses, - user, - faction, - last_used - from api_keys {} - order by last_used asc limit 1 FOR UPDATE - ) - update api_keys set - uses = key.uses + 1, - last_used = now() - from key where - api_keys.id=key.id and key.uses < $1 - returning - api_keys.id, - api_keys.user_id, - api_keys.faction_id, - api_keys.key, - api_keys.uses, - api_keys.user, - api_keys.faction, - api_keys.last_used - "#, - predicate - )) - .bind(self.limit) - .fetch_optional(&self.pool) - .await?; - key.ok_or(PgStorageError::Unavailable(domain)) + loop { + let attempt = async { + let mut tx = self.pool.begin().await?; + + sqlx::query("set transaction isolation level serializable") + .execute(&mut tx) + .await?; + + let key: Option = sqlx::query_as(&indoc::formatdoc!(r#" + with key as ( + select + id, + user_id, + faction_id, + key, + case + when extract(minute from last_used)=extract(minute from now()) then uses + else 0::smallint + end as uses, + user, + faction, + last_used + from api_keys {} + order by last_used asc limit 1 FOR UPDATE + ) + update api_keys set + uses = key.uses + 1, + last_used = now() + from key where + api_keys.id=key.id and key.uses < $1 + returning + api_keys.id, + api_keys.user_id, + api_keys.faction_id, + api_keys.key, + api_keys.uses, + api_keys.user, + api_keys.faction, + api_keys.last_used + "#, + predicate + )) + .bind(self.limit) + .fetch_optional(&mut tx) + .await?; + + tx.commit().await?; + + Result::, sqlx::Error>::Ok( + key.ok_or(PgStorageError::Unavailable(domain)), + ) + } + .await; + + match attempt { + Ok(result) => return result, + Err(error) => { + if let Some(db_error) = error.as_database_error() { + let pg_error: &sqlx::postgres::PgDatabaseError = db_error.downcast_ref(); + if pg_error.code() == "40001" { + random_sleep().await; + } else { + return Err(error.into()); + } + } else { + return Err(error.into()); + } + } + } + } } async fn flag_key(&self, key: Self::Key, code: u8) -> Result { @@ -155,8 +200,9 @@ where #[cfg(test)] mod test { - use std::sync::Once; + use std::sync::{Arc, Once}; + use sqlx::Row; use tokio::test; use super::*; @@ -172,7 +218,12 @@ mod test { .await .unwrap(); - PgKeyPoolStorage::new(pool, 3) + sqlx::query("update api_keys set uses=0") + .execute(&pool) + .await + .unwrap(); + + PgKeyPoolStorage::new(pool, 50) } #[test] @@ -192,4 +243,30 @@ mod test { panic!("Acquiring key failed: {:?}", e); } } + + #[test] + async fn test_concurrent() { + let storage = Arc::new(setup().await); + let before: i16 = sqlx::query("select uses from api_keys") + .fetch_one(&storage.pool) + .await + .unwrap() + .get("uses"); + + let futures = (0..30).into_iter().map(|_| { + let storage = storage.clone(); + async move { + storage.acquire_key(KeyDomain::Public).await.unwrap(); + } + }); + futures::future::join_all(futures).await; + + let after: i16 = sqlx::query("select uses from api_keys") + .fetch_one(&storage.pool) + .await + .unwrap() + .get("uses"); + + assert_eq!(after, before + 30); + } }