resolve data races
This commit is contained in:
parent
d888530d24
commit
0115b6e615
|
@ -1,6 +1,6 @@
|
||||||
[package]
|
[package]
|
||||||
name = "torn-key-pool"
|
name = "torn-key-pool"
|
||||||
version = "0.3.0"
|
version = "0.3.1"
|
||||||
edition = "2021"
|
edition = "2021"
|
||||||
license = "MIT"
|
license = "MIT"
|
||||||
repository = "https://github.com/TotallyNot/torn-api.rs.git"
|
repository = "https://github.com/TotallyNot/torn-api.rs.git"
|
||||||
|
@ -10,10 +10,12 @@ description = "A generalizes API key pool for torn-api"
|
||||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||||
|
|
||||||
[features]
|
[features]
|
||||||
default = [ "postgres" ]
|
default = [ "postgres", "tokio-runtime" ]
|
||||||
postgres = [ "dep:sqlx", "dep:chrono", "dep:indoc" ]
|
postgres = [ "dep:sqlx", "dep:chrono", "dep:indoc" ]
|
||||||
reqwest = [ "dep:reqwest", "torn-api/reqwest" ]
|
reqwest = [ "dep:reqwest", "torn-api/reqwest" ]
|
||||||
awc = [ "dep:awc", "torn-api/awc" ]
|
awc = [ "dep:awc", "torn-api/awc" ]
|
||||||
|
tokio-runtime = [ "dep:tokio", "dep:rand" ]
|
||||||
|
actix-runtime = [ "dep:actix-rt", "dep:rand" ]
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
torn-api = { path = "../torn-api", default-features = false, version = "0.4" }
|
torn-api = { path = "../torn-api", default-features = false, version = "0.4" }
|
||||||
|
@ -23,6 +25,9 @@ thiserror = "1"
|
||||||
sqlx = { version = "0.6", features = [ "postgres", "chrono" ], optional = true }
|
sqlx = { version = "0.6", features = [ "postgres", "chrono" ], optional = true }
|
||||||
chrono = { version = "0.4", optional = true }
|
chrono = { version = "0.4", optional = true }
|
||||||
indoc = { version = "1", optional = true }
|
indoc = { version = "1", optional = true }
|
||||||
|
tokio = { version = "1", optional = true, default-features = false, features = ["time"] }
|
||||||
|
actix-rt = { version = "2", optional = true, default-features = false }
|
||||||
|
rand = { version = "0.8", optional = true }
|
||||||
|
|
||||||
reqwest = { version = "0.11", default-features = false, features = [ "json" ], optional = true }
|
reqwest = { version = "0.11", default-features = false, features = [ "json" ], optional = true }
|
||||||
awc = { version = "3", default-features = false, optional = true }
|
awc = { version = "3", default-features = false, optional = true }
|
||||||
|
@ -35,3 +40,4 @@ tokio = { version = "1.20.1", features = ["test-util", "rt", "macros"] }
|
||||||
tokio-test = "0.4.2"
|
tokio-test = "0.4.2"
|
||||||
reqwest = { version = "0.11", default-features = true }
|
reqwest = { version = "0.11", default-features = true }
|
||||||
awc = { version = "3", features = [ "rustls" ] }
|
awc = { version = "3", features = [ "rustls" ] }
|
||||||
|
futures = "0.3.24"
|
||||||
|
|
|
@ -63,6 +63,20 @@ impl PgKeyPoolStorage {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "tokio-runtime")]
|
||||||
|
async fn random_sleep() {
|
||||||
|
use rand::{thread_rng, Rng};
|
||||||
|
let dur = tokio::time::Duration::from_millis(thread_rng().gen_range(1..50));
|
||||||
|
tokio::time::sleep(dur).await;
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(all(not(feature = "tokio-runtime"), feature = "actix-runtime"))]
|
||||||
|
async fn random_sleep() {
|
||||||
|
use rand::{thread_rng, Rng};
|
||||||
|
let dur = std::time::Duration::from_millis(thread_rng().gen_range(1..50));
|
||||||
|
actix_rt::time::sleep(dur).await;
|
||||||
|
}
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl KeyPoolStorage for PgKeyPoolStorage {
|
impl KeyPoolStorage for PgKeyPoolStorage {
|
||||||
type Key = PgKey;
|
type Key = PgKey;
|
||||||
|
@ -75,46 +89,77 @@ impl KeyPoolStorage for PgKeyPoolStorage {
|
||||||
KeyDomain::User(id) => format!("where and user_id={} and user", id),
|
KeyDomain::User(id) => format!("where and user_id={} and user", id),
|
||||||
KeyDomain::Faction(id) => format!("where and faction_id={} and faction", id),
|
KeyDomain::Faction(id) => format!("where and faction_id={} and faction", id),
|
||||||
};
|
};
|
||||||
let key: Option<PgKey> = sqlx::query_as(&indoc::formatdoc!(
|
|
||||||
r#"
|
|
||||||
with key as (
|
|
||||||
select
|
|
||||||
id,
|
|
||||||
user_id,
|
|
||||||
faction_id,
|
|
||||||
key,
|
|
||||||
case
|
|
||||||
when extract(minute from last_used)=extract(minute from now()) then uses
|
|
||||||
else 0::smallint
|
|
||||||
end as uses,
|
|
||||||
user,
|
|
||||||
faction,
|
|
||||||
last_used
|
|
||||||
from api_keys {}
|
|
||||||
order by last_used asc limit 1 FOR UPDATE
|
|
||||||
)
|
|
||||||
update api_keys set
|
|
||||||
uses = key.uses + 1,
|
|
||||||
last_used = now()
|
|
||||||
from key where
|
|
||||||
api_keys.id=key.id and key.uses < $1
|
|
||||||
returning
|
|
||||||
api_keys.id,
|
|
||||||
api_keys.user_id,
|
|
||||||
api_keys.faction_id,
|
|
||||||
api_keys.key,
|
|
||||||
api_keys.uses,
|
|
||||||
api_keys.user,
|
|
||||||
api_keys.faction,
|
|
||||||
api_keys.last_used
|
|
||||||
"#,
|
|
||||||
predicate
|
|
||||||
))
|
|
||||||
.bind(self.limit)
|
|
||||||
.fetch_optional(&self.pool)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
key.ok_or(PgStorageError::Unavailable(domain))
|
loop {
|
||||||
|
let attempt = async {
|
||||||
|
let mut tx = self.pool.begin().await?;
|
||||||
|
|
||||||
|
sqlx::query("set transaction isolation level serializable")
|
||||||
|
.execute(&mut tx)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let key: Option<PgKey> = sqlx::query_as(&indoc::formatdoc!(r#"
|
||||||
|
with key as (
|
||||||
|
select
|
||||||
|
id,
|
||||||
|
user_id,
|
||||||
|
faction_id,
|
||||||
|
key,
|
||||||
|
case
|
||||||
|
when extract(minute from last_used)=extract(minute from now()) then uses
|
||||||
|
else 0::smallint
|
||||||
|
end as uses,
|
||||||
|
user,
|
||||||
|
faction,
|
||||||
|
last_used
|
||||||
|
from api_keys {}
|
||||||
|
order by last_used asc limit 1 FOR UPDATE
|
||||||
|
)
|
||||||
|
update api_keys set
|
||||||
|
uses = key.uses + 1,
|
||||||
|
last_used = now()
|
||||||
|
from key where
|
||||||
|
api_keys.id=key.id and key.uses < $1
|
||||||
|
returning
|
||||||
|
api_keys.id,
|
||||||
|
api_keys.user_id,
|
||||||
|
api_keys.faction_id,
|
||||||
|
api_keys.key,
|
||||||
|
api_keys.uses,
|
||||||
|
api_keys.user,
|
||||||
|
api_keys.faction,
|
||||||
|
api_keys.last_used
|
||||||
|
"#,
|
||||||
|
predicate
|
||||||
|
))
|
||||||
|
.bind(self.limit)
|
||||||
|
.fetch_optional(&mut tx)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
tx.commit().await?;
|
||||||
|
|
||||||
|
Result::<Result<Self::Key, Self::Error>, sqlx::Error>::Ok(
|
||||||
|
key.ok_or(PgStorageError::Unavailable(domain)),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
.await;
|
||||||
|
|
||||||
|
match attempt {
|
||||||
|
Ok(result) => return result,
|
||||||
|
Err(error) => {
|
||||||
|
if let Some(db_error) = error.as_database_error() {
|
||||||
|
let pg_error: &sqlx::postgres::PgDatabaseError = db_error.downcast_ref();
|
||||||
|
if pg_error.code() == "40001" {
|
||||||
|
random_sleep().await;
|
||||||
|
} else {
|
||||||
|
return Err(error.into());
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return Err(error.into());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn flag_key(&self, key: Self::Key, code: u8) -> Result<bool, Self::Error> {
|
async fn flag_key(&self, key: Self::Key, code: u8) -> Result<bool, Self::Error> {
|
||||||
|
@ -155,8 +200,9 @@ where
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod test {
|
mod test {
|
||||||
use std::sync::Once;
|
use std::sync::{Arc, Once};
|
||||||
|
|
||||||
|
use sqlx::Row;
|
||||||
use tokio::test;
|
use tokio::test;
|
||||||
|
|
||||||
use super::*;
|
use super::*;
|
||||||
|
@ -172,7 +218,12 @@ mod test {
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
PgKeyPoolStorage::new(pool, 3)
|
sqlx::query("update api_keys set uses=0")
|
||||||
|
.execute(&pool)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
PgKeyPoolStorage::new(pool, 50)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
@ -192,4 +243,30 @@ mod test {
|
||||||
panic!("Acquiring key failed: {:?}", e);
|
panic!("Acquiring key failed: {:?}", e);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
async fn test_concurrent() {
|
||||||
|
let storage = Arc::new(setup().await);
|
||||||
|
let before: i16 = sqlx::query("select uses from api_keys")
|
||||||
|
.fetch_one(&storage.pool)
|
||||||
|
.await
|
||||||
|
.unwrap()
|
||||||
|
.get("uses");
|
||||||
|
|
||||||
|
let futures = (0..30).into_iter().map(|_| {
|
||||||
|
let storage = storage.clone();
|
||||||
|
async move {
|
||||||
|
storage.acquire_key(KeyDomain::Public).await.unwrap();
|
||||||
|
}
|
||||||
|
});
|
||||||
|
futures::future::join_all(futures).await;
|
||||||
|
|
||||||
|
let after: i16 = sqlx::query("select uses from api_keys")
|
||||||
|
.fetch_one(&storage.pool)
|
||||||
|
.await
|
||||||
|
.unwrap()
|
||||||
|
.get("uses");
|
||||||
|
|
||||||
|
assert_eq!(after, before + 30);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue