From 7837a649500831ddaa39cb59a59c625bb1810ff6 Mon Sep 17 00:00:00 2001 From: TotallyNot <44345987+TotallyNot@users.noreply.github.com> Date: Mon, 19 Sep 2022 03:22:38 +0200 Subject: [PATCH] fix key selection logic --- torn-key-pool/Cargo.toml | 2 +- torn-key-pool/src/lib.rs | 3 +- torn-key-pool/src/postgres.rs | 194 +++++++++++++++++++++------------- 3 files changed, 123 insertions(+), 76 deletions(-) diff --git a/torn-key-pool/Cargo.toml b/torn-key-pool/Cargo.toml index 6129dff..00af3f0 100644 --- a/torn-key-pool/Cargo.toml +++ b/torn-key-pool/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "torn-key-pool" -version = "0.4.0" +version = "0.4.1" edition = "2021" license = "MIT" repository = "https://github.com/TotallyNot/torn-api.rs.git" diff --git a/torn-key-pool/src/lib.rs b/torn-key-pool/src/lib.rs index 9c81e13..624937c 100644 --- a/torn-key-pool/src/lib.rs +++ b/torn-key-pool/src/lib.rs @@ -81,9 +81,8 @@ where #[cfg(all(test, feature = "postgres"))] mod test { - use std::sync::{Arc, Once}; + use std::sync::Once; - use sqlx::Row; use tokio::test; use super::*; diff --git a/torn-key-pool/src/postgres.rs b/torn-key-pool/src/postgres.rs index e644c5d..e850f6b 100644 --- a/torn-key-pool/src/postgres.rs +++ b/torn-key-pool/src/postgres.rs @@ -1,5 +1,4 @@ use async_trait::async_trait; -use chrono::{DateTime, Utc}; use indoc::indoc; use sqlx::{FromRow, PgPool}; use thiserror::Error; @@ -15,16 +14,29 @@ pub enum PgStorageError { Unavailable(KeyDomain), } -#[derive(Debug, Clone, FromRow)] +#[derive(Debug, Clone, FromRow, Eq)] pub struct PgKey { pub id: i32, - pub user_id: i32, - pub faction_id: Option, pub key: String, pub uses: i16, - pub user: bool, - pub faction: bool, - pub last_used: DateTime, +} + +impl Ord for PgKey { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + other.uses.cmp(&self.uses) + } +} + +impl PartialOrd for PgKey { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl PartialEq for PgKey { + fn eq(&self, other: &Self) -> bool { + self.uses == other.uses + } } impl ApiKey for PgKey { @@ -86,8 +98,8 @@ impl KeyPoolStorage for PgKeyPoolStorage { async fn acquire_key(&self, domain: KeyDomain) -> Result { let predicate = match domain { KeyDomain::Public => "".to_owned(), - KeyDomain::User(id) => format!("where and user_id={} and user", id), - KeyDomain::Faction(id) => format!("where and faction_id={} and faction", id), + KeyDomain::User(id) => format!(" and user_id={} and user", id), + KeyDomain::Faction(id) => format!(" and faction_id={} and faction", id), }; loop { @@ -98,16 +110,17 @@ impl KeyPoolStorage for PgKeyPoolStorage { .execute(&mut tx) .await?; - let key: Option = sqlx::query_as(&indoc::formatdoc!(r#" + let key: Option = sqlx::query_as(&indoc::formatdoc!( + r#" with key as ( select id, - case - when extract(minute from last_used)=extract(minute from now()) then uses - else 0::smallint - end as uses - from api_keys {} - order by last_used asc limit 1 + 0::int2 as uses + from api_keys where last_used < date_trunc('minute', now()){predicate} + union ( + select id, uses from api_keys where last_used >= date_trunc('minute', now()){predicate} order by uses asc + ) + limit 1 ) update api_keys set uses = key.uses + 1, @@ -116,15 +129,9 @@ impl KeyPoolStorage for PgKeyPoolStorage { api_keys.id=key.id and key.uses < $1 returning api_keys.id, - api_keys.user_id, - api_keys.faction_id, api_keys.key, - api_keys.uses, - api_keys.user, - api_keys.faction, - api_keys.last_used + api_keys.uses "#, - predicate )) .bind(self.limit) .fetch_optional(&mut tx) @@ -163,61 +170,102 @@ impl KeyPoolStorage for PgKeyPoolStorage { ) -> Result, Self::Error> { let predicate = match domain { KeyDomain::Public => "".to_owned(), - KeyDomain::User(id) => format!("where and user_id={} and user", id), - KeyDomain::Faction(id) => format!("where and faction_id={} and faction", id), + KeyDomain::User(id) => format!(" and user_id={} and user", id), + KeyDomain::Faction(id) => format!(" and faction_id={} and faction", id), }; - let mut tx = self.pool.begin().await?; + loop { + let attempt = async { + let mut tx = self.pool.begin().await?; - let mut keys: Vec = sqlx::query_as(&indoc::formatdoc!( - r#" - select - id, - user_id, - faction_id, - key, - case - when extract(minute from last_used)=extract(minute from now()) then uses - else 0::smallint - end as uses, - "user", - faction, - last_used - from api_keys {} order by last_used limit $1 for update - "#, - predicate - )) - .bind(number) - .fetch_all(&mut tx) - .await?; + sqlx::query("set transaction isolation level serializable") + .execute(&mut tx) + .await?; - let mut result = Vec::with_capacity(number as usize); - 'outer: for _ in 0..(((number as usize) / keys.len()) + 1) { - for key in &mut keys { - if key.uses == self.limit || result.len() == (number as usize) { - break 'outer; - } else { - key.uses += 1; - result.push(key.clone()); + let mut keys: Vec = sqlx::query_as(&indoc::formatdoc!( + r#"select + id, + key, + 0::int2 as uses + from api_keys where last_used < date_trunc('minute', now()){predicate} + union + select + id, + key, + uses + from api_keys where last_used >= date_trunc('minute', now()){predicate} + order by uses limit $1 + "#, + )) + .bind(number) + .fetch_all(&mut tx) + .await?; + + if keys.is_empty() { + tx.commit().await?; + return Ok(Err(PgStorageError::Unavailable(domain))); + } + + keys.sort_unstable(); + + let mut result = Vec::with_capacity(number as usize); + let (max, rest) = keys.split_last_mut().unwrap(); + for key in rest { + let available = max.uses - key.uses; + let using = std::cmp::min(available, (number as i16) - (result.len() as i16)); + key.uses += using; + result.extend(std::iter::repeat(key.clone()).take(using as usize)); + + if result.len() == number as usize { + break; + } + } + + while result.len() < (number as usize) { + if keys[0].uses == self.limit { + break; + } + + let take = std::cmp::min(keys.len(), (number as usize) - result.len()); + let slice = &mut keys[0..take]; + slice.iter_mut().for_each(|k| k.uses += 1); + result.extend_from_slice(slice); + } + + sqlx::query(indoc! {r#" + update api_keys set + uses = tmp.uses, + last_used = now() + from (select unnest($1::int4[]) as id, unnest($2::int2[]) as uses) as tmp + where api_keys.id = tmp.id + "#}) + .bind(keys.iter().map(|k| k.id).collect::>()) + .bind(keys.iter().map(|k| k.uses).collect::>()) + .execute(&mut tx) + .await?; + + tx.commit().await?; + + Result::, Self::Error>, sqlx::Error>::Ok(Ok(result)) + } + .await; + + match attempt { + Ok(result) => return result, + Err(error) => { + if let Some(db_error) = error.as_database_error() { + let pg_error: &sqlx::postgres::PgDatabaseError = db_error.downcast_ref(); + if pg_error.code() == "40001" { + random_sleep().await; + } else { + return Err(error.into()); + } + } else { + return Err(error.into()); + } } } } - - sqlx::query(indoc! {r#" - update api_keys set - uses = tmp.uses, - last_used = now() - from (select unnest($1::int4[]) as id, unnest($2::int2[]) as uses) as tmp - where api_keys.id = tmp.id - "#}) - .bind(keys.iter().map(|k| k.id).collect::>()) - .bind(keys.iter().map(|k| k.uses).collect::>()) - .execute(&mut tx) - .await?; - - tx.commit().await?; - - Ok(result) } async fn flag_key(&self, key: Self::Key, code: u8) -> Result { @@ -284,7 +332,7 @@ mod test { #[test] async fn test_concurrent() { let storage = Arc::new(setup().await); - let before: i16 = sqlx::query("select uses from api_keys") + let before: i64 = sqlx::query("select sum(uses) as uses from api_keys") .fetch_one(&storage.pool) .await .unwrap() @@ -297,7 +345,7 @@ mod test { assert_eq!(keys.len(), 30); - let after: i16 = sqlx::query("select uses from api_keys") + let after: i64 = sqlx::query("select sum(uses) as uses from api_keys") .fetch_one(&storage.pool) .await .unwrap()