diff --git a/torn-key-pool/Cargo.toml b/torn-key-pool/Cargo.toml index 4453ee1..8057755 100644 --- a/torn-key-pool/Cargo.toml +++ b/torn-key-pool/Cargo.toml @@ -1,14 +1,13 @@ [package] name = "torn-key-pool" -version = "0.5.0" +version = "0.5.1" edition = "2021" +authors = ["Pyrit [2111649]"] license = "MIT" repository = "https://github.com/TotallyNot/torn-api.rs.git" homepage = "https://github.com/TotallyNot/torn-api.rs.git" description = "A generalised API key pool for torn-api" -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html - [features] default = [ "postgres", "tokio-runtime" ] postgres = [ "dep:sqlx", "dep:chrono", "dep:indoc", "dep:serde" ] diff --git a/torn-key-pool/src/lib.rs b/torn-key-pool/src/lib.rs index d99c0cf..8e7ed66 100644 --- a/torn-key-pool/src/lib.rs +++ b/torn-key-pool/src/lib.rs @@ -33,7 +33,11 @@ pub trait ApiKey: Sync + Send { fn value(&self) -> &str; } -pub trait KeyDomain: Clone + std::fmt::Debug + Send + Sync {} +pub trait KeyDomain: Clone + std::fmt::Debug + Send + Sync { + fn fallback(&self) -> Option { + None + } +} impl KeyDomain for T where T: Clone + std::fmt::Debug + Send + Sync {} @@ -55,12 +59,15 @@ pub trait KeyPoolStorage { async fn store_key( &self, + user_id: i32, key: String, domains: Vec, ) -> Result; async fn read_key(&self, key: String) -> Result; + async fn read_user_keys(&self, user_id: i32) -> Result, Self::Error>; + async fn remove_key(&self, key: String) -> Result; async fn add_domain_to_key( diff --git a/torn-key-pool/src/local.rs b/torn-key-pool/src/local.rs index 5a78b48..1d73568 100644 --- a/torn-key-pool/src/local.rs +++ b/torn-key-pool/src/local.rs @@ -128,7 +128,7 @@ where S: KeyPoolStorage, { client: C, - storage: S, + pub storage: S, comment: Option, } diff --git a/torn-key-pool/src/postgres.rs b/torn-key-pool/src/postgres.rs index f0868c6..f92085d 100644 --- a/torn-key-pool/src/postgres.rs +++ b/torn-key-pool/src/postgres.rs @@ -42,6 +42,7 @@ where D: PgKeyDomain, { pub id: i32, + pub user_id: i32, pub key: String, pub uses: i16, pub domains: sqlx::types::Json>, @@ -82,13 +83,14 @@ where sqlx::query(indoc! {r#" CREATE TABLE IF NOT EXISTS api_keys ( id serial primary key, + user_id int4 not null, key char(16) not null, uses int2 not null default 0, domains jsonb not null default '{}'::jsonb, last_used timestamptz not null default now(), flag int2, cooldown timestamptz, - constraint "uq:api_keys.key" UNIQUE(key) + constraint "uq:api_keys.key+user_id" UNIQUE(user_id, key) )"# }) .execute(&self.pool) @@ -100,6 +102,12 @@ where .execute(&self.pool) .await?; + sqlx::query(indoc! {r#" + CREATE INDEX IF NOT EXISTS "idx:api_keys.user_id" ON api_keys USING BTREE(user_id) + "#}) + .execute(&self.pool) + .await?; + Ok(()) } } @@ -143,10 +151,14 @@ where select id, 0::int2 as uses - from api_keys where last_used < date_trunc('minute', now()) and (cooldown is null or now() >= cooldown) and domains @> $1 + from api_keys where last_used < date_trunc('minute', now()) + and (cooldown is null or now() >= cooldown) + and domains @> $1 union ( select id, uses from api_keys - where last_used >= date_trunc('minute', now()) and (cooldown is null or now() >= cooldown) and domains @> $1 + where last_used >= date_trunc('minute', now()) + and (cooldown is null or now() >= cooldown) + and domains @> $1 order by uses asc ) limit 1 @@ -160,6 +172,7 @@ where api_keys.id=key.id and key.uses < $2 returning api_keys.id, + api_keys.user_id, api_keys.key, api_keys.uses, api_keys.domains @@ -170,17 +183,23 @@ where .fetch_optional(&mut tx) .await?; - tx.commit().await?; + tx.commit().await?; - Result::, sqlx::Error>::Ok( - key - ) + Result::, sqlx::Error>::Ok(key) } .await; match attempt { Ok(Some(result)) => return Ok(result), - Ok(None) => return Err(PgStorageError::Unavailable(domain)), + Ok(None) => { + return self + .acquire_key( + domain + .fallback() + .ok_or_else(|| PgStorageError::Unavailable(domain))?, + ) + .await + } Err(error) => { if let Some(db_error) = error.as_database_error() { let pg_error: &sqlx::postgres::PgDatabaseError = db_error.downcast_ref(); @@ -213,17 +232,23 @@ where let mut keys: Vec = sqlx::query_as(&indoc::formatdoc!( r#"select id, + user_id, key, 0::int2 as uses, domains - from api_keys where last_used < date_trunc('minute', now()) and (cooldown is null or now() >= cooldown) and domains @> $1 + from api_keys where last_used < date_trunc('minute', now()) + and (cooldown is null or now() >= cooldown) + and domains @> $1 union select id, + user_id, key, uses, domains - from api_keys where last_used >= date_trunc('minute', now()) and (cooldown is null or now() >= cooldown) and domains @> $1 + from api_keys where last_used >= date_trunc('minute', now()) + and (cooldown is null or now() >= cooldown) + and domains @> $1 order by uses limit $2 "#, )) @@ -285,7 +310,16 @@ where match attempt { Ok(Some(result)) => return Ok(result), - Ok(None) => return Err(Self::Error::Unavailable(domain)), + Ok(None) => { + return self + .acquire_many_keys( + domain + .fallback() + .ok_or_else(|| Self::Error::Unavailable(domain))?, + number, + ) + .await + } Err(error) => { if let Some(db_error) = error.as_database_error() { let pg_error: &sqlx::postgres::PgDatabaseError = db_error.downcast_ref(); @@ -303,7 +337,6 @@ where } async fn flag_key(&self, key: Self::Key, code: u8) -> Result { - // TODO: put keys in cooldown when appropriate match code { 2 | 10 | 13 => { // invalid key, owner fedded or owner inactive @@ -350,21 +383,29 @@ where } } - async fn store_key(&self, key: String, domains: Vec) -> Result { - sqlx::query_as("insert into api_keys(key, domains) values ($1, $2) returning *") - .bind(&key) - .bind(sqlx::types::Json(domains)) - .fetch_one(&self.pool) - .await - .map_err(|why| { - if let Some(error) = why.as_database_error() { - let pg_error: &sqlx::postgres::PgDatabaseError = error.downcast_ref(); - if pg_error.code() == "23505" { - return PgStorageError::DuplicateKey(key); - } + async fn store_key( + &self, + user_id: i32, + key: String, + domains: Vec, + ) -> Result { + sqlx::query_as( + "insert into api_keys(user_id, key, domains) values ($1, $2, $3) returning *", + ) + .bind(user_id) + .bind(&key) + .bind(sqlx::types::Json(domains)) + .fetch_one(&self.pool) + .await + .map_err(|why| { + if let Some(error) = why.as_database_error() { + let pg_error: &sqlx::postgres::PgDatabaseError = error.downcast_ref(); + if pg_error.code() == "23505" { + return PgStorageError::DuplicateKey(key); } - PgStorageError::Pg(why) - }) + } + PgStorageError::Pg(why) + }) } async fn read_key(&self, key: String) -> Result { @@ -375,6 +416,14 @@ where .ok_or_else(|| PgStorageError::KeyNotFound(key)) } + async fn read_user_keys(&self, user_id: i32) -> Result, Self::Error> { + sqlx::query_as("select * from api_keys where user_id=$1") + .bind(user_id) + .fetch_all(&self.pool) + .await + .map_err(Into::into) + } + async fn remove_key(&self, key: String) -> Result { sqlx::query_as("delete from api_keys where key=$1 returning *") .bind(&key) @@ -475,7 +524,7 @@ pub(crate) mod test { storage.initialise().await.unwrap(); storage - .store_key(std::env::var("APIKEY").unwrap(), vec![Domain::All]) + .store_key(1, std::env::var("APIKEY").unwrap(), vec![Domain::All]) .await .unwrap(); @@ -495,7 +544,7 @@ pub(crate) mod test { async fn test_store_duplicate() { let storage = setup().await; match storage - .store_key(std::env::var("APIKEY").unwrap(), vec![]) + .store_key(1, std::env::var("APIKEY").unwrap(), vec![]) .await .unwrap_err() { @@ -545,12 +594,20 @@ pub(crate) mod test { async fn test_store_key() { let storage = setup().await; let key = storage - .store_key("ABCDABCDABCDABCD".to_owned(), vec![]) + .store_key(1, "ABCDABCDABCDABCD".to_owned(), vec![]) .await .unwrap(); assert_eq!(key.value(), "ABCDABCDABCDABCD"); } + #[test] + async fn test_read_user_keys() { + let storage = setup().await; + + let keys = storage.read_user_keys(1).await.unwrap(); + assert_eq!(keys.len(), 1); + } + #[test] async fn acquire_one() { let storage = setup().await; diff --git a/torn-key-pool/src/send.rs b/torn-key-pool/src/send.rs index 2409443..aa89a6a 100644 --- a/torn-key-pool/src/send.rs +++ b/torn-key-pool/src/send.rs @@ -128,7 +128,7 @@ where S: KeyPoolStorage, { client: C, - storage: S, + pub storage: S, comment: Option, }