resolve data races

This commit is contained in:
TotallyNot 2022-09-11 00:47:25 +02:00
parent d888530d24
commit 0115b6e615
2 changed files with 126 additions and 43 deletions

View file

@ -1,6 +1,6 @@
[package] [package]
name = "torn-key-pool" name = "torn-key-pool"
version = "0.3.0" version = "0.3.1"
edition = "2021" edition = "2021"
license = "MIT" license = "MIT"
repository = "https://github.com/TotallyNot/torn-api.rs.git" repository = "https://github.com/TotallyNot/torn-api.rs.git"
@ -10,10 +10,12 @@ description = "A generalizes API key pool for torn-api"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[features] [features]
default = [ "postgres" ] default = [ "postgres", "tokio-runtime" ]
postgres = [ "dep:sqlx", "dep:chrono", "dep:indoc" ] postgres = [ "dep:sqlx", "dep:chrono", "dep:indoc" ]
reqwest = [ "dep:reqwest", "torn-api/reqwest" ] reqwest = [ "dep:reqwest", "torn-api/reqwest" ]
awc = [ "dep:awc", "torn-api/awc" ] awc = [ "dep:awc", "torn-api/awc" ]
tokio-runtime = [ "dep:tokio", "dep:rand" ]
actix-runtime = [ "dep:actix-rt", "dep:rand" ]
[dependencies] [dependencies]
torn-api = { path = "../torn-api", default-features = false, version = "0.4" } torn-api = { path = "../torn-api", default-features = false, version = "0.4" }
@ -23,6 +25,9 @@ thiserror = "1"
sqlx = { version = "0.6", features = [ "postgres", "chrono" ], optional = true } sqlx = { version = "0.6", features = [ "postgres", "chrono" ], optional = true }
chrono = { version = "0.4", optional = true } chrono = { version = "0.4", optional = true }
indoc = { version = "1", optional = true } indoc = { version = "1", optional = true }
tokio = { version = "1", optional = true, default-features = false, features = ["time"] }
actix-rt = { version = "2", optional = true, default-features = false }
rand = { version = "0.8", optional = true }
reqwest = { version = "0.11", default-features = false, features = [ "json" ], optional = true } reqwest = { version = "0.11", default-features = false, features = [ "json" ], optional = true }
awc = { version = "3", default-features = false, optional = true } awc = { version = "3", default-features = false, optional = true }
@ -35,3 +40,4 @@ tokio = { version = "1.20.1", features = ["test-util", "rt", "macros"] }
tokio-test = "0.4.2" tokio-test = "0.4.2"
reqwest = { version = "0.11", default-features = true } reqwest = { version = "0.11", default-features = true }
awc = { version = "3", features = [ "rustls" ] } awc = { version = "3", features = [ "rustls" ] }
futures = "0.3.24"

View file

@ -63,6 +63,20 @@ impl PgKeyPoolStorage {
} }
} }
#[cfg(feature = "tokio-runtime")]
async fn random_sleep() {
use rand::{thread_rng, Rng};
let dur = tokio::time::Duration::from_millis(thread_rng().gen_range(1..50));
tokio::time::sleep(dur).await;
}
#[cfg(all(not(feature = "tokio-runtime"), feature = "actix-runtime"))]
async fn random_sleep() {
use rand::{thread_rng, Rng};
let dur = std::time::Duration::from_millis(thread_rng().gen_range(1..50));
actix_rt::time::sleep(dur).await;
}
#[async_trait] #[async_trait]
impl KeyPoolStorage for PgKeyPoolStorage { impl KeyPoolStorage for PgKeyPoolStorage {
type Key = PgKey; type Key = PgKey;
@ -75,46 +89,77 @@ impl KeyPoolStorage for PgKeyPoolStorage {
KeyDomain::User(id) => format!("where and user_id={} and user", id), KeyDomain::User(id) => format!("where and user_id={} and user", id),
KeyDomain::Faction(id) => format!("where and faction_id={} and faction", id), KeyDomain::Faction(id) => format!("where and faction_id={} and faction", id),
}; };
let key: Option<PgKey> = sqlx::query_as(&indoc::formatdoc!(
r#"
with key as (
select
id,
user_id,
faction_id,
key,
case
when extract(minute from last_used)=extract(minute from now()) then uses
else 0::smallint
end as uses,
user,
faction,
last_used
from api_keys {}
order by last_used asc limit 1 FOR UPDATE
)
update api_keys set
uses = key.uses + 1,
last_used = now()
from key where
api_keys.id=key.id and key.uses < $1
returning
api_keys.id,
api_keys.user_id,
api_keys.faction_id,
api_keys.key,
api_keys.uses,
api_keys.user,
api_keys.faction,
api_keys.last_used
"#,
predicate
))
.bind(self.limit)
.fetch_optional(&self.pool)
.await?;
key.ok_or(PgStorageError::Unavailable(domain)) loop {
let attempt = async {
let mut tx = self.pool.begin().await?;
sqlx::query("set transaction isolation level serializable")
.execute(&mut tx)
.await?;
let key: Option<PgKey> = sqlx::query_as(&indoc::formatdoc!(r#"
with key as (
select
id,
user_id,
faction_id,
key,
case
when extract(minute from last_used)=extract(minute from now()) then uses
else 0::smallint
end as uses,
user,
faction,
last_used
from api_keys {}
order by last_used asc limit 1 FOR UPDATE
)
update api_keys set
uses = key.uses + 1,
last_used = now()
from key where
api_keys.id=key.id and key.uses < $1
returning
api_keys.id,
api_keys.user_id,
api_keys.faction_id,
api_keys.key,
api_keys.uses,
api_keys.user,
api_keys.faction,
api_keys.last_used
"#,
predicate
))
.bind(self.limit)
.fetch_optional(&mut tx)
.await?;
tx.commit().await?;
Result::<Result<Self::Key, Self::Error>, sqlx::Error>::Ok(
key.ok_or(PgStorageError::Unavailable(domain)),
)
}
.await;
match attempt {
Ok(result) => return result,
Err(error) => {
if let Some(db_error) = error.as_database_error() {
let pg_error: &sqlx::postgres::PgDatabaseError = db_error.downcast_ref();
if pg_error.code() == "40001" {
random_sleep().await;
} else {
return Err(error.into());
}
} else {
return Err(error.into());
}
}
}
}
} }
async fn flag_key(&self, key: Self::Key, code: u8) -> Result<bool, Self::Error> { async fn flag_key(&self, key: Self::Key, code: u8) -> Result<bool, Self::Error> {
@ -155,8 +200,9 @@ where
#[cfg(test)] #[cfg(test)]
mod test { mod test {
use std::sync::Once; use std::sync::{Arc, Once};
use sqlx::Row;
use tokio::test; use tokio::test;
use super::*; use super::*;
@ -172,7 +218,12 @@ mod test {
.await .await
.unwrap(); .unwrap();
PgKeyPoolStorage::new(pool, 3) sqlx::query("update api_keys set uses=0")
.execute(&pool)
.await
.unwrap();
PgKeyPoolStorage::new(pool, 50)
} }
#[test] #[test]
@ -192,4 +243,30 @@ mod test {
panic!("Acquiring key failed: {:?}", e); panic!("Acquiring key failed: {:?}", e);
} }
} }
#[test]
async fn test_concurrent() {
let storage = Arc::new(setup().await);
let before: i16 = sqlx::query("select uses from api_keys")
.fetch_one(&storage.pool)
.await
.unwrap()
.get("uses");
let futures = (0..30).into_iter().map(|_| {
let storage = storage.clone();
async move {
storage.acquire_key(KeyDomain::Public).await.unwrap();
}
});
futures::future::join_all(futures).await;
let after: i16 = sqlx::query("select uses from api_keys")
.fetch_one(&storage.pool)
.await
.unwrap()
.get("uses");
assert_eq!(after, before + 30);
}
} }