From 40b784cf57bae0969481e56d8dcfd48216adb90a Mon Sep 17 00:00:00 2001 From: TotallyNot <44345987+TotallyNot@users.noreply.github.com> Date: Sun, 27 Apr 2025 11:02:33 +0200 Subject: [PATCH] feat(key-pool): allow setting schema for key table --- Cargo.lock | 2 +- torn-key-pool/Cargo.toml | 2 +- torn-key-pool/src/postgres.rs | 167 +++++++++++++++++++++------------- 3 files changed, 105 insertions(+), 66 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index f41bfb4..a3f54f6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2303,7 +2303,7 @@ dependencies = [ [[package]] name = "torn-key-pool" -version = "1.0.0" +version = "1.0.1" dependencies = [ "chrono", "futures", diff --git a/torn-key-pool/Cargo.toml b/torn-key-pool/Cargo.toml index d31b7b7..88c3724 100644 --- a/torn-key-pool/Cargo.toml +++ b/torn-key-pool/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "torn-key-pool" -version = "1.0.0" +version = "1.0.1" edition = "2021" authors = ["Pyrit [2111649]"] license-file = { workspace = true } diff --git a/torn-key-pool/src/postgres.rs b/torn-key-pool/src/postgres.rs index 844e260..21651b0 100644 --- a/torn-key-pool/src/postgres.rs +++ b/torn-key-pool/src/postgres.rs @@ -1,5 +1,5 @@ use futures::future::BoxFuture; -use indoc::indoc; +use indoc::formatdoc; use sqlx::{FromRow, PgPool, Postgres, QueryBuilder}; use thiserror::Error; @@ -93,6 +93,7 @@ where { pool: PgPool, limit: i16, + schema: Option, _phantom: std::marker::PhantomData, } @@ -117,62 +118,91 @@ impl PgKeyPoolStorage where D: PgKeyDomain, { - pub fn new(pool: PgPool, limit: i16) -> Self { + pub fn new(pool: PgPool, limit: i16, schema: Option) -> Self { Self { pool, limit, + schema, _phantom: Default::default(), } } + fn table_name(&self) -> String { + match self.schema.as_ref() { + Some(schema) => format!("{schema}.api_keys"), + None => "api_keys".to_owned(), + } + } + + fn unique_array_fn(&self) -> String { + match self.schema.as_ref() { + Some(schema) => format!("{schema}.__unique_jsonb_array"), + None => "__unique_jsonb_array".to_owned(), + } + } + + fn filter_array_fn(&self) -> String { + match self.schema.as_ref() { + Some(schema) => format!("{schema}.__filter_jsonb_array"), + None => "__filter_jsonb_array".to_owned(), + } + } + pub async fn initialise(&self) -> Result<(), PgKeyPoolError> { - sqlx::query(indoc! {r#" - CREATE TABLE IF NOT EXISTS api_keys ( + if let Some(schema) = self.schema.as_ref() { + sqlx::query(&format!("create schema if not exists {}", schema)) + .execute(&self.pool) + .await?; + } + + sqlx::query(&formatdoc! {r#" + CREATE TABLE IF NOT EXISTS {} ( id serial primary key, user_id int4 not null, key char(16) not null, uses int2 not null default 0, - domains jsonb not null default '{}'::jsonb, + domains jsonb not null default '{{}}'::jsonb, last_used timestamptz not null default now(), flag int2, cooldown timestamptz, constraint "uq:api_keys.key" UNIQUE(key) - )"# + )"#, + self.table_name() }) .execute(&self.pool) .await?; - sqlx::query(indoc! {r#" - CREATE INDEX IF NOT EXISTS "idx:api_keys.domains" ON api_keys USING GIN(domains jsonb_path_ops) - "#}) + sqlx::query(&formatdoc! {r#" + CREATE INDEX IF NOT EXISTS "idx:api_keys.domains" ON {} USING GIN(domains jsonb_path_ops) + "#, self.table_name()}) .execute(&self.pool) .await?; - sqlx::query(indoc! {r#" - CREATE INDEX IF NOT EXISTS "idx:api_keys.user_id" ON api_keys USING BTREE(user_id) - "#}) + sqlx::query(&formatdoc! {r#" + CREATE INDEX IF NOT EXISTS "idx:api_keys.user_id" ON {} USING BTREE(user_id) + "#, self.table_name()}) .execute(&self.pool) .await?; - sqlx::query(indoc! {r#" - create or replace function __unique_jsonb_array(jsonb) returns jsonb + sqlx::query(&formatdoc! {r#" + create or replace function {}(jsonb) returns jsonb AS $$ select jsonb_agg(d::jsonb) from ( select distinct jsonb_array_elements_text($1) as d ) t $$ language sql; - "#}) + "#, self.unique_array_fn()}) .execute(&self.pool) .await?; - sqlx::query(indoc! {r#" - create or replace function __filter_jsonb_array(jsonb, jsonb) returns jsonb + sqlx::query(&formatdoc! {r#" + create or replace function {}(jsonb, jsonb) returns jsonb AS $$ select jsonb_agg(d::jsonb) from ( select distinct jsonb_array_elements_text($1) as d ) t where d<>$2::text $$ language sql; - "#}) + "#, self.filter_array_fn()}) .execute(&self.pool) .await?; @@ -209,54 +239,52 @@ where .execute(&mut *tx) .await?; - let mut qb = QueryBuilder::new(indoc::indoc! { + let mut qb = QueryBuilder::new(&formatdoc! { r#" with key as ( select id, 0::int2 as uses - from api_keys where last_used < date_trunc('minute', now()) + from {} where last_used < date_trunc('minute', now()) and (cooldown is null or now() >= cooldown) - and "# + and "#, + self.table_name() }); build_predicate(&mut qb, &selector); - qb.push(indoc::indoc! { + qb.push(formatdoc! { " \n union ( - select id, uses from api_keys + select id, uses from {} where last_used >= date_trunc('minute', now()) and (cooldown is null or now() >= cooldown) - and " + and ", + self.table_name() }); build_predicate(&mut qb, &selector); - qb.push(indoc::indoc! { + qb.push(formatdoc! { " \n order by uses asc limit 1 ) order by uses asc limit 1 ) - update api_keys set + update {} as keys set uses = key.uses + 1, cooldown = null, flag = null, last_used = now() from key where - api_keys.id=key.id and key.uses < " + keys.id=key.id and key.uses < ", + self.table_name() }); qb.push_bind(self.limit); qb.push(indoc::indoc! { " - \nreturning - api_keys.id, - api_keys.user_id, - api_keys.key, - api_keys.uses, - api_keys.domains" + \nreturning keys.id, keys.user_id, keys.key, keys.uses, keys.domains" }); let key = qb.build_query_as().fetch_optional(&mut *tx).await?; @@ -321,19 +349,20 @@ where .execute(&mut *tx) .await?; - let mut qb = QueryBuilder::new(indoc::indoc! { + let mut qb = QueryBuilder::new(&formatdoc! { r#"select id, user_id, key, 0::int2 as uses, domains - from api_keys where last_used < date_trunc('minute', now()) + from {} where last_used < date_trunc('minute', now()) and (cooldown is null or now() >= cooldown) - and "# + and "#, + self.table_name() }); build_predicate(&mut qb, &selector); - qb.push(indoc::indoc! { + qb.push(formatdoc! { " \nunion select @@ -342,9 +371,10 @@ where key, uses, domains - from api_keys where last_used >= date_trunc('minute', now()) + from {} where last_used >= date_trunc('minute', now()) and (cooldown is null or now() >= cooldown) - and " + and ", + self.table_name() }); build_predicate(&mut qb, &selector); qb.push("\norder by uses limit "); @@ -383,15 +413,15 @@ where result.extend_from_slice(slice); } - sqlx::query(indoc! {r#" - update api_keys set + sqlx::query(&formatdoc! {r#" + update {} keys set uses = tmp.uses, cooldown = null, flag = null, last_used = now() from (select unnest($1::int4[]) as id, unnest($2::int2[]) as uses) as tmp - where api_keys.id = tmp.id - "#}) + where keys.id = tmp.id + "#, self.table_name()}) .bind(keys.iter().map(|k| k.id).collect::>()) .bind(keys.iter().map(|k| k.uses).collect::>()) .execute(&mut *tx) @@ -452,7 +482,10 @@ where { let selector = selector.into_selector(); - let mut qb = QueryBuilder::new("update api_keys set cooldown=now() + "); + let mut qb = QueryBuilder::new(format!( + "update {} set cooldown=now() + ", + self.table_name() + )); qb.push_bind(duration); qb.push(" where "); build_predicate(&mut qb, &selector); @@ -468,11 +501,13 @@ where key: String, domains: Vec, ) -> Result { - sqlx::query_as( - "insert into api_keys(user_id, key, domains) values ($1, $2, $3) on conflict on \ - constraint \"uq:api_keys.key\" do update set domains = \ - __unique_jsonb_array(excluded.domains || api_keys.domains) returning *", - ) + sqlx::query_as(&dbg!(formatdoc!( + "insert into {} as api_keys(user_id, key, domains) values ($1, $2, $3) + on conflict(key) do update + set domains = {}(excluded.domains || api_keys.domains) returning *", + self.table_name(), + self.unique_array_fn() + ))) .bind(user_id) .bind(&key) .bind(sqlx::types::Json(domains)) @@ -487,7 +522,7 @@ where { let selector = selector.into_selector(); - let mut qb = QueryBuilder::new("select * from api_keys where "); + let mut qb = QueryBuilder::new(format!("select * from {} where ", self.table_name())); build_predicate(&mut qb, &selector); qb.build_query_as() @@ -502,7 +537,7 @@ where { let selector = selector.into_selector(); - let mut qb = QueryBuilder::new("select * from api_keys where "); + let mut qb = QueryBuilder::new(format!("select * from {} where ", self.table_name())); build_predicate(&mut qb, &selector); qb.build_query_as() @@ -517,7 +552,7 @@ where { let selector = selector.into_selector(); - let mut qb = QueryBuilder::new("delete from api_keys where "); + let mut qb = QueryBuilder::new(format!("delete from {} where ", self.table_name())); build_predicate(&mut qb, &selector); qb.push(" returning *"); @@ -533,9 +568,11 @@ where { let selector = selector.into_selector(); - let mut qb = QueryBuilder::new( - "update api_keys set domains = __unique_jsonb_array(domains || jsonb_build_array(", - ); + let mut qb = QueryBuilder::new(format!( + "update {} set domains = {}(domains || jsonb_build_array(", + self.table_name(), + self.unique_array_fn() + )); qb.push_bind(sqlx::types::Json(domain)); qb.push(")) where "); build_predicate(&mut qb, &selector); @@ -557,9 +594,11 @@ where { let selector = selector.into_selector(); - let mut qb = QueryBuilder::new( - "update api_keys set domains = coalesce(__filter_jsonb_array(domains, ", - ); + let mut qb = QueryBuilder::new(format!( + "update {} set domains = coalesce({}(domains, ", + self.table_name(), + self.filter_array_fn() + )); qb.push_bind(sqlx::types::Json(domain)); qb.push("), '[]'::jsonb) where "); build_predicate(&mut qb, &selector); @@ -626,7 +665,7 @@ pub(crate) mod test { .await .unwrap(); - let storage = PgKeyPoolStorage::new(pool.clone(), 1000); + let storage = PgKeyPoolStorage::new(pool.clone(), 1000, Some("test".to_owned())); storage.initialise().await.unwrap(); let key = storage @@ -823,7 +862,7 @@ pub(crate) mod test { set.join_next().await.unwrap().unwrap(); } - let uses: i16 = sqlx::query("select uses from api_keys") + let uses: i16 = sqlx::query(&format!("select uses from {}", storage.table_name())) .fetch_one(&storage.pool) .await .unwrap() @@ -831,7 +870,7 @@ pub(crate) mod test { assert_eq!(uses, 100); - sqlx::query("update api_keys set uses=0") + sqlx::query(&format!("update {} set uses=0", storage.table_name())) .execute(&storage.pool) .await .unwrap(); @@ -871,7 +910,7 @@ pub(crate) mod test { assert_eq!(key.uses, 2); } - sqlx::query("update api_keys set uses=0") + sqlx::query(&format!("update {} set uses=0", storage.table_name())) .execute(&storage.pool) .await .unwrap(); @@ -896,7 +935,7 @@ pub(crate) mod test { set.join_next().await.unwrap().unwrap(); } - let uses: i16 = sqlx::query("select uses from api_keys") + let uses: i16 = sqlx::query(&format!("select uses from {}", storage.table_name())) .fetch_one(&storage.pool) .await .unwrap() @@ -904,7 +943,7 @@ pub(crate) mod test { assert_eq!(uses, 500); - sqlx::query("update api_keys set uses=0") + sqlx::query(&format!("update {} set uses=0", storage.table_name())) .execute(&storage.pool) .await .unwrap();