fix key selection logic

This commit is contained in:
TotallyNot 2022-09-19 03:22:38 +02:00
parent 3e6bfa8c34
commit 7837a64950
3 changed files with 123 additions and 76 deletions

View file

@ -1,6 +1,6 @@
[package] [package]
name = "torn-key-pool" name = "torn-key-pool"
version = "0.4.0" version = "0.4.1"
edition = "2021" edition = "2021"
license = "MIT" license = "MIT"
repository = "https://github.com/TotallyNot/torn-api.rs.git" repository = "https://github.com/TotallyNot/torn-api.rs.git"

View file

@ -81,9 +81,8 @@ where
#[cfg(all(test, feature = "postgres"))] #[cfg(all(test, feature = "postgres"))]
mod test { mod test {
use std::sync::{Arc, Once}; use std::sync::Once;
use sqlx::Row;
use tokio::test; use tokio::test;
use super::*; use super::*;

View file

@ -1,5 +1,4 @@
use async_trait::async_trait; use async_trait::async_trait;
use chrono::{DateTime, Utc};
use indoc::indoc; use indoc::indoc;
use sqlx::{FromRow, PgPool}; use sqlx::{FromRow, PgPool};
use thiserror::Error; use thiserror::Error;
@ -15,16 +14,29 @@ pub enum PgStorageError {
Unavailable(KeyDomain), Unavailable(KeyDomain),
} }
#[derive(Debug, Clone, FromRow)] #[derive(Debug, Clone, FromRow, Eq)]
pub struct PgKey { pub struct PgKey {
pub id: i32, pub id: i32,
pub user_id: i32,
pub faction_id: Option<i32>,
pub key: String, pub key: String,
pub uses: i16, pub uses: i16,
pub user: bool, }
pub faction: bool,
pub last_used: DateTime<Utc>, impl Ord for PgKey {
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
other.uses.cmp(&self.uses)
}
}
impl PartialOrd for PgKey {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
Some(self.cmp(other))
}
}
impl PartialEq for PgKey {
fn eq(&self, other: &Self) -> bool {
self.uses == other.uses
}
} }
impl ApiKey for PgKey { impl ApiKey for PgKey {
@ -86,8 +98,8 @@ impl KeyPoolStorage for PgKeyPoolStorage {
async fn acquire_key(&self, domain: KeyDomain) -> Result<Self::Key, Self::Error> { async fn acquire_key(&self, domain: KeyDomain) -> Result<Self::Key, Self::Error> {
let predicate = match domain { let predicate = match domain {
KeyDomain::Public => "".to_owned(), KeyDomain::Public => "".to_owned(),
KeyDomain::User(id) => format!("where and user_id={} and user", id), KeyDomain::User(id) => format!(" and user_id={} and user", id),
KeyDomain::Faction(id) => format!("where and faction_id={} and faction", id), KeyDomain::Faction(id) => format!(" and faction_id={} and faction", id),
}; };
loop { loop {
@ -98,16 +110,17 @@ impl KeyPoolStorage for PgKeyPoolStorage {
.execute(&mut tx) .execute(&mut tx)
.await?; .await?;
let key: Option<PgKey> = sqlx::query_as(&indoc::formatdoc!(r#" let key: Option<PgKey> = sqlx::query_as(&indoc::formatdoc!(
r#"
with key as ( with key as (
select select
id, id,
case 0::int2 as uses
when extract(minute from last_used)=extract(minute from now()) then uses from api_keys where last_used < date_trunc('minute', now()){predicate}
else 0::smallint union (
end as uses select id, uses from api_keys where last_used >= date_trunc('minute', now()){predicate} order by uses asc
from api_keys {} )
order by last_used asc limit 1 limit 1
) )
update api_keys set update api_keys set
uses = key.uses + 1, uses = key.uses + 1,
@ -116,15 +129,9 @@ impl KeyPoolStorage for PgKeyPoolStorage {
api_keys.id=key.id and key.uses < $1 api_keys.id=key.id and key.uses < $1
returning returning
api_keys.id, api_keys.id,
api_keys.user_id,
api_keys.faction_id,
api_keys.key, api_keys.key,
api_keys.uses, api_keys.uses
api_keys.user,
api_keys.faction,
api_keys.last_used
"#, "#,
predicate
)) ))
.bind(self.limit) .bind(self.limit)
.fetch_optional(&mut tx) .fetch_optional(&mut tx)
@ -163,61 +170,102 @@ impl KeyPoolStorage for PgKeyPoolStorage {
) -> Result<Vec<Self::Key>, Self::Error> { ) -> Result<Vec<Self::Key>, Self::Error> {
let predicate = match domain { let predicate = match domain {
KeyDomain::Public => "".to_owned(), KeyDomain::Public => "".to_owned(),
KeyDomain::User(id) => format!("where and user_id={} and user", id), KeyDomain::User(id) => format!(" and user_id={} and user", id),
KeyDomain::Faction(id) => format!("where and faction_id={} and faction", id), KeyDomain::Faction(id) => format!(" and faction_id={} and faction", id),
}; };
let mut tx = self.pool.begin().await?; loop {
let attempt = async {
let mut tx = self.pool.begin().await?;
let mut keys: Vec<PgKey> = sqlx::query_as(&indoc::formatdoc!( sqlx::query("set transaction isolation level serializable")
r#" .execute(&mut tx)
select .await?;
id,
user_id,
faction_id,
key,
case
when extract(minute from last_used)=extract(minute from now()) then uses
else 0::smallint
end as uses,
"user",
faction,
last_used
from api_keys {} order by last_used limit $1 for update
"#,
predicate
))
.bind(number)
.fetch_all(&mut tx)
.await?;
let mut result = Vec::with_capacity(number as usize); let mut keys: Vec<PgKey> = sqlx::query_as(&indoc::formatdoc!(
'outer: for _ in 0..(((number as usize) / keys.len()) + 1) { r#"select
for key in &mut keys { id,
if key.uses == self.limit || result.len() == (number as usize) { key,
break 'outer; 0::int2 as uses
} else { from api_keys where last_used < date_trunc('minute', now()){predicate}
key.uses += 1; union
result.push(key.clone()); select
id,
key,
uses
from api_keys where last_used >= date_trunc('minute', now()){predicate}
order by uses limit $1
"#,
))
.bind(number)
.fetch_all(&mut tx)
.await?;
if keys.is_empty() {
tx.commit().await?;
return Ok(Err(PgStorageError::Unavailable(domain)));
}
keys.sort_unstable();
let mut result = Vec::with_capacity(number as usize);
let (max, rest) = keys.split_last_mut().unwrap();
for key in rest {
let available = max.uses - key.uses;
let using = std::cmp::min(available, (number as i16) - (result.len() as i16));
key.uses += using;
result.extend(std::iter::repeat(key.clone()).take(using as usize));
if result.len() == number as usize {
break;
}
}
while result.len() < (number as usize) {
if keys[0].uses == self.limit {
break;
}
let take = std::cmp::min(keys.len(), (number as usize) - result.len());
let slice = &mut keys[0..take];
slice.iter_mut().for_each(|k| k.uses += 1);
result.extend_from_slice(slice);
}
sqlx::query(indoc! {r#"
update api_keys set
uses = tmp.uses,
last_used = now()
from (select unnest($1::int4[]) as id, unnest($2::int2[]) as uses) as tmp
where api_keys.id = tmp.id
"#})
.bind(keys.iter().map(|k| k.id).collect::<Vec<_>>())
.bind(keys.iter().map(|k| k.uses).collect::<Vec<_>>())
.execute(&mut tx)
.await?;
tx.commit().await?;
Result::<Result<Vec<Self::Key>, Self::Error>, sqlx::Error>::Ok(Ok(result))
}
.await;
match attempt {
Ok(result) => return result,
Err(error) => {
if let Some(db_error) = error.as_database_error() {
let pg_error: &sqlx::postgres::PgDatabaseError = db_error.downcast_ref();
if pg_error.code() == "40001" {
random_sleep().await;
} else {
return Err(error.into());
}
} else {
return Err(error.into());
}
} }
} }
} }
sqlx::query(indoc! {r#"
update api_keys set
uses = tmp.uses,
last_used = now()
from (select unnest($1::int4[]) as id, unnest($2::int2[]) as uses) as tmp
where api_keys.id = tmp.id
"#})
.bind(keys.iter().map(|k| k.id).collect::<Vec<_>>())
.bind(keys.iter().map(|k| k.uses).collect::<Vec<_>>())
.execute(&mut tx)
.await?;
tx.commit().await?;
Ok(result)
} }
async fn flag_key(&self, key: Self::Key, code: u8) -> Result<bool, Self::Error> { async fn flag_key(&self, key: Self::Key, code: u8) -> Result<bool, Self::Error> {
@ -284,7 +332,7 @@ mod test {
#[test] #[test]
async fn test_concurrent() { async fn test_concurrent() {
let storage = Arc::new(setup().await); let storage = Arc::new(setup().await);
let before: i16 = sqlx::query("select uses from api_keys") let before: i64 = sqlx::query("select sum(uses) as uses from api_keys")
.fetch_one(&storage.pool) .fetch_one(&storage.pool)
.await .await
.unwrap() .unwrap()
@ -297,7 +345,7 @@ mod test {
assert_eq!(keys.len(), 30); assert_eq!(keys.len(), 30);
let after: i16 = sqlx::query("select uses from api_keys") let after: i64 = sqlx::query("select sum(uses) as uses from api_keys")
.fetch_one(&storage.pool) .fetch_one(&storage.pool)
.await .await
.unwrap() .unwrap()