fix key selection logic
This commit is contained in:
parent
3e6bfa8c34
commit
7837a64950
|
@ -1,6 +1,6 @@
|
||||||
[package]
|
[package]
|
||||||
name = "torn-key-pool"
|
name = "torn-key-pool"
|
||||||
version = "0.4.0"
|
version = "0.4.1"
|
||||||
edition = "2021"
|
edition = "2021"
|
||||||
license = "MIT"
|
license = "MIT"
|
||||||
repository = "https://github.com/TotallyNot/torn-api.rs.git"
|
repository = "https://github.com/TotallyNot/torn-api.rs.git"
|
||||||
|
|
|
@ -81,9 +81,8 @@ where
|
||||||
|
|
||||||
#[cfg(all(test, feature = "postgres"))]
|
#[cfg(all(test, feature = "postgres"))]
|
||||||
mod test {
|
mod test {
|
||||||
use std::sync::{Arc, Once};
|
use std::sync::Once;
|
||||||
|
|
||||||
use sqlx::Row;
|
|
||||||
use tokio::test;
|
use tokio::test;
|
||||||
|
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
|
@ -1,5 +1,4 @@
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use chrono::{DateTime, Utc};
|
|
||||||
use indoc::indoc;
|
use indoc::indoc;
|
||||||
use sqlx::{FromRow, PgPool};
|
use sqlx::{FromRow, PgPool};
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
|
@ -15,16 +14,29 @@ pub enum PgStorageError {
|
||||||
Unavailable(KeyDomain),
|
Unavailable(KeyDomain),
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, FromRow)]
|
#[derive(Debug, Clone, FromRow, Eq)]
|
||||||
pub struct PgKey {
|
pub struct PgKey {
|
||||||
pub id: i32,
|
pub id: i32,
|
||||||
pub user_id: i32,
|
|
||||||
pub faction_id: Option<i32>,
|
|
||||||
pub key: String,
|
pub key: String,
|
||||||
pub uses: i16,
|
pub uses: i16,
|
||||||
pub user: bool,
|
}
|
||||||
pub faction: bool,
|
|
||||||
pub last_used: DateTime<Utc>,
|
impl Ord for PgKey {
|
||||||
|
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
|
||||||
|
other.uses.cmp(&self.uses)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PartialOrd for PgKey {
|
||||||
|
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
|
||||||
|
Some(self.cmp(other))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PartialEq for PgKey {
|
||||||
|
fn eq(&self, other: &Self) -> bool {
|
||||||
|
self.uses == other.uses
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ApiKey for PgKey {
|
impl ApiKey for PgKey {
|
||||||
|
@ -86,8 +98,8 @@ impl KeyPoolStorage for PgKeyPoolStorage {
|
||||||
async fn acquire_key(&self, domain: KeyDomain) -> Result<Self::Key, Self::Error> {
|
async fn acquire_key(&self, domain: KeyDomain) -> Result<Self::Key, Self::Error> {
|
||||||
let predicate = match domain {
|
let predicate = match domain {
|
||||||
KeyDomain::Public => "".to_owned(),
|
KeyDomain::Public => "".to_owned(),
|
||||||
KeyDomain::User(id) => format!("where and user_id={} and user", id),
|
KeyDomain::User(id) => format!(" and user_id={} and user", id),
|
||||||
KeyDomain::Faction(id) => format!("where and faction_id={} and faction", id),
|
KeyDomain::Faction(id) => format!(" and faction_id={} and faction", id),
|
||||||
};
|
};
|
||||||
|
|
||||||
loop {
|
loop {
|
||||||
|
@ -98,16 +110,17 @@ impl KeyPoolStorage for PgKeyPoolStorage {
|
||||||
.execute(&mut tx)
|
.execute(&mut tx)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
let key: Option<PgKey> = sqlx::query_as(&indoc::formatdoc!(r#"
|
let key: Option<PgKey> = sqlx::query_as(&indoc::formatdoc!(
|
||||||
|
r#"
|
||||||
with key as (
|
with key as (
|
||||||
select
|
select
|
||||||
id,
|
id,
|
||||||
case
|
0::int2 as uses
|
||||||
when extract(minute from last_used)=extract(minute from now()) then uses
|
from api_keys where last_used < date_trunc('minute', now()){predicate}
|
||||||
else 0::smallint
|
union (
|
||||||
end as uses
|
select id, uses from api_keys where last_used >= date_trunc('minute', now()){predicate} order by uses asc
|
||||||
from api_keys {}
|
)
|
||||||
order by last_used asc limit 1
|
limit 1
|
||||||
)
|
)
|
||||||
update api_keys set
|
update api_keys set
|
||||||
uses = key.uses + 1,
|
uses = key.uses + 1,
|
||||||
|
@ -116,15 +129,9 @@ impl KeyPoolStorage for PgKeyPoolStorage {
|
||||||
api_keys.id=key.id and key.uses < $1
|
api_keys.id=key.id and key.uses < $1
|
||||||
returning
|
returning
|
||||||
api_keys.id,
|
api_keys.id,
|
||||||
api_keys.user_id,
|
|
||||||
api_keys.faction_id,
|
|
||||||
api_keys.key,
|
api_keys.key,
|
||||||
api_keys.uses,
|
api_keys.uses
|
||||||
api_keys.user,
|
|
||||||
api_keys.faction,
|
|
||||||
api_keys.last_used
|
|
||||||
"#,
|
"#,
|
||||||
predicate
|
|
||||||
))
|
))
|
||||||
.bind(self.limit)
|
.bind(self.limit)
|
||||||
.fetch_optional(&mut tx)
|
.fetch_optional(&mut tx)
|
||||||
|
@ -163,44 +170,66 @@ impl KeyPoolStorage for PgKeyPoolStorage {
|
||||||
) -> Result<Vec<Self::Key>, Self::Error> {
|
) -> Result<Vec<Self::Key>, Self::Error> {
|
||||||
let predicate = match domain {
|
let predicate = match domain {
|
||||||
KeyDomain::Public => "".to_owned(),
|
KeyDomain::Public => "".to_owned(),
|
||||||
KeyDomain::User(id) => format!("where and user_id={} and user", id),
|
KeyDomain::User(id) => format!(" and user_id={} and user", id),
|
||||||
KeyDomain::Faction(id) => format!("where and faction_id={} and faction", id),
|
KeyDomain::Faction(id) => format!(" and faction_id={} and faction", id),
|
||||||
};
|
};
|
||||||
|
|
||||||
|
loop {
|
||||||
|
let attempt = async {
|
||||||
let mut tx = self.pool.begin().await?;
|
let mut tx = self.pool.begin().await?;
|
||||||
|
|
||||||
|
sqlx::query("set transaction isolation level serializable")
|
||||||
|
.execute(&mut tx)
|
||||||
|
.await?;
|
||||||
|
|
||||||
let mut keys: Vec<PgKey> = sqlx::query_as(&indoc::formatdoc!(
|
let mut keys: Vec<PgKey> = sqlx::query_as(&indoc::formatdoc!(
|
||||||
r#"
|
r#"select
|
||||||
|
id,
|
||||||
|
key,
|
||||||
|
0::int2 as uses
|
||||||
|
from api_keys where last_used < date_trunc('minute', now()){predicate}
|
||||||
|
union
|
||||||
select
|
select
|
||||||
id,
|
id,
|
||||||
user_id,
|
|
||||||
faction_id,
|
|
||||||
key,
|
key,
|
||||||
case
|
uses
|
||||||
when extract(minute from last_used)=extract(minute from now()) then uses
|
from api_keys where last_used >= date_trunc('minute', now()){predicate}
|
||||||
else 0::smallint
|
order by uses limit $1
|
||||||
end as uses,
|
|
||||||
"user",
|
|
||||||
faction,
|
|
||||||
last_used
|
|
||||||
from api_keys {} order by last_used limit $1 for update
|
|
||||||
"#,
|
"#,
|
||||||
predicate
|
|
||||||
))
|
))
|
||||||
.bind(number)
|
.bind(number)
|
||||||
.fetch_all(&mut tx)
|
.fetch_all(&mut tx)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
|
if keys.is_empty() {
|
||||||
|
tx.commit().await?;
|
||||||
|
return Ok(Err(PgStorageError::Unavailable(domain)));
|
||||||
|
}
|
||||||
|
|
||||||
|
keys.sort_unstable();
|
||||||
|
|
||||||
let mut result = Vec::with_capacity(number as usize);
|
let mut result = Vec::with_capacity(number as usize);
|
||||||
'outer: for _ in 0..(((number as usize) / keys.len()) + 1) {
|
let (max, rest) = keys.split_last_mut().unwrap();
|
||||||
for key in &mut keys {
|
for key in rest {
|
||||||
if key.uses == self.limit || result.len() == (number as usize) {
|
let available = max.uses - key.uses;
|
||||||
break 'outer;
|
let using = std::cmp::min(available, (number as i16) - (result.len() as i16));
|
||||||
} else {
|
key.uses += using;
|
||||||
key.uses += 1;
|
result.extend(std::iter::repeat(key.clone()).take(using as usize));
|
||||||
result.push(key.clone());
|
|
||||||
|
if result.len() == number as usize {
|
||||||
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
while result.len() < (number as usize) {
|
||||||
|
if keys[0].uses == self.limit {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
let take = std::cmp::min(keys.len(), (number as usize) - result.len());
|
||||||
|
let slice = &mut keys[0..take];
|
||||||
|
slice.iter_mut().for_each(|k| k.uses += 1);
|
||||||
|
result.extend_from_slice(slice);
|
||||||
}
|
}
|
||||||
|
|
||||||
sqlx::query(indoc! {r#"
|
sqlx::query(indoc! {r#"
|
||||||
|
@ -217,7 +246,26 @@ impl KeyPoolStorage for PgKeyPoolStorage {
|
||||||
|
|
||||||
tx.commit().await?;
|
tx.commit().await?;
|
||||||
|
|
||||||
Ok(result)
|
Result::<Result<Vec<Self::Key>, Self::Error>, sqlx::Error>::Ok(Ok(result))
|
||||||
|
}
|
||||||
|
.await;
|
||||||
|
|
||||||
|
match attempt {
|
||||||
|
Ok(result) => return result,
|
||||||
|
Err(error) => {
|
||||||
|
if let Some(db_error) = error.as_database_error() {
|
||||||
|
let pg_error: &sqlx::postgres::PgDatabaseError = db_error.downcast_ref();
|
||||||
|
if pg_error.code() == "40001" {
|
||||||
|
random_sleep().await;
|
||||||
|
} else {
|
||||||
|
return Err(error.into());
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return Err(error.into());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn flag_key(&self, key: Self::Key, code: u8) -> Result<bool, Self::Error> {
|
async fn flag_key(&self, key: Self::Key, code: u8) -> Result<bool, Self::Error> {
|
||||||
|
@ -284,7 +332,7 @@ mod test {
|
||||||
#[test]
|
#[test]
|
||||||
async fn test_concurrent() {
|
async fn test_concurrent() {
|
||||||
let storage = Arc::new(setup().await);
|
let storage = Arc::new(setup().await);
|
||||||
let before: i16 = sqlx::query("select uses from api_keys")
|
let before: i64 = sqlx::query("select sum(uses) as uses from api_keys")
|
||||||
.fetch_one(&storage.pool)
|
.fetch_one(&storage.pool)
|
||||||
.await
|
.await
|
||||||
.unwrap()
|
.unwrap()
|
||||||
|
@ -297,7 +345,7 @@ mod test {
|
||||||
|
|
||||||
assert_eq!(keys.len(), 30);
|
assert_eq!(keys.len(), 30);
|
||||||
|
|
||||||
let after: i16 = sqlx::query("select uses from api_keys")
|
let after: i64 = sqlx::query("select sum(uses) as uses from api_keys")
|
||||||
.fetch_one(&storage.pool)
|
.fetch_one(&storage.pool)
|
||||||
.await
|
.await
|
||||||
.unwrap()
|
.unwrap()
|
||||||
|
|
Loading…
Reference in a new issue