feat(key-pool): allow setting schema for key table

This commit is contained in:
TotallyNot 2025-04-27 11:02:33 +02:00
parent c5651efbb0
commit 40b784cf57
Signed by: pyrite
GPG key ID: 7F1BA9170CD35D15
3 changed files with 105 additions and 66 deletions

2
Cargo.lock generated
View file

@ -2303,7 +2303,7 @@ dependencies = [
[[package]] [[package]]
name = "torn-key-pool" name = "torn-key-pool"
version = "1.0.0" version = "1.0.1"
dependencies = [ dependencies = [
"chrono", "chrono",
"futures", "futures",

View file

@ -1,6 +1,6 @@
[package] [package]
name = "torn-key-pool" name = "torn-key-pool"
version = "1.0.0" version = "1.0.1"
edition = "2021" edition = "2021"
authors = ["Pyrit [2111649]"] authors = ["Pyrit [2111649]"]
license-file = { workspace = true } license-file = { workspace = true }

View file

@ -1,5 +1,5 @@
use futures::future::BoxFuture; use futures::future::BoxFuture;
use indoc::indoc; use indoc::formatdoc;
use sqlx::{FromRow, PgPool, Postgres, QueryBuilder}; use sqlx::{FromRow, PgPool, Postgres, QueryBuilder};
use thiserror::Error; use thiserror::Error;
@ -93,6 +93,7 @@ where
{ {
pool: PgPool, pool: PgPool,
limit: i16, limit: i16,
schema: Option<String>,
_phantom: std::marker::PhantomData<D>, _phantom: std::marker::PhantomData<D>,
} }
@ -117,62 +118,91 @@ impl<D> PgKeyPoolStorage<D>
where where
D: PgKeyDomain, D: PgKeyDomain,
{ {
pub fn new(pool: PgPool, limit: i16) -> Self { pub fn new(pool: PgPool, limit: i16, schema: Option<String>) -> Self {
Self { Self {
pool, pool,
limit, limit,
schema,
_phantom: Default::default(), _phantom: Default::default(),
} }
} }
fn table_name(&self) -> String {
match self.schema.as_ref() {
Some(schema) => format!("{schema}.api_keys"),
None => "api_keys".to_owned(),
}
}
fn unique_array_fn(&self) -> String {
match self.schema.as_ref() {
Some(schema) => format!("{schema}.__unique_jsonb_array"),
None => "__unique_jsonb_array".to_owned(),
}
}
fn filter_array_fn(&self) -> String {
match self.schema.as_ref() {
Some(schema) => format!("{schema}.__filter_jsonb_array"),
None => "__filter_jsonb_array".to_owned(),
}
}
pub async fn initialise(&self) -> Result<(), PgKeyPoolError<D>> { pub async fn initialise(&self) -> Result<(), PgKeyPoolError<D>> {
sqlx::query(indoc! {r#" if let Some(schema) = self.schema.as_ref() {
CREATE TABLE IF NOT EXISTS api_keys ( sqlx::query(&format!("create schema if not exists {}", schema))
.execute(&self.pool)
.await?;
}
sqlx::query(&formatdoc! {r#"
CREATE TABLE IF NOT EXISTS {} (
id serial primary key, id serial primary key,
user_id int4 not null, user_id int4 not null,
key char(16) not null, key char(16) not null,
uses int2 not null default 0, uses int2 not null default 0,
domains jsonb not null default '{}'::jsonb, domains jsonb not null default '{{}}'::jsonb,
last_used timestamptz not null default now(), last_used timestamptz not null default now(),
flag int2, flag int2,
cooldown timestamptz, cooldown timestamptz,
constraint "uq:api_keys.key" UNIQUE(key) constraint "uq:api_keys.key" UNIQUE(key)
)"# )"#,
self.table_name()
}) })
.execute(&self.pool) .execute(&self.pool)
.await?; .await?;
sqlx::query(indoc! {r#" sqlx::query(&formatdoc! {r#"
CREATE INDEX IF NOT EXISTS "idx:api_keys.domains" ON api_keys USING GIN(domains jsonb_path_ops) CREATE INDEX IF NOT EXISTS "idx:api_keys.domains" ON {} USING GIN(domains jsonb_path_ops)
"#}) "#, self.table_name()})
.execute(&self.pool) .execute(&self.pool)
.await?; .await?;
sqlx::query(indoc! {r#" sqlx::query(&formatdoc! {r#"
CREATE INDEX IF NOT EXISTS "idx:api_keys.user_id" ON api_keys USING BTREE(user_id) CREATE INDEX IF NOT EXISTS "idx:api_keys.user_id" ON {} USING BTREE(user_id)
"#}) "#, self.table_name()})
.execute(&self.pool) .execute(&self.pool)
.await?; .await?;
sqlx::query(indoc! {r#" sqlx::query(&formatdoc! {r#"
create or replace function __unique_jsonb_array(jsonb) returns jsonb create or replace function {}(jsonb) returns jsonb
AS $$ AS $$
select jsonb_agg(d::jsonb) from ( select jsonb_agg(d::jsonb) from (
select distinct jsonb_array_elements_text($1) as d select distinct jsonb_array_elements_text($1) as d
) t ) t
$$ language sql; $$ language sql;
"#}) "#, self.unique_array_fn()})
.execute(&self.pool) .execute(&self.pool)
.await?; .await?;
sqlx::query(indoc! {r#" sqlx::query(&formatdoc! {r#"
create or replace function __filter_jsonb_array(jsonb, jsonb) returns jsonb create or replace function {}(jsonb, jsonb) returns jsonb
AS $$ AS $$
select jsonb_agg(d::jsonb) from ( select jsonb_agg(d::jsonb) from (
select distinct jsonb_array_elements_text($1) as d select distinct jsonb_array_elements_text($1) as d
) t where d<>$2::text ) t where d<>$2::text
$$ language sql; $$ language sql;
"#}) "#, self.filter_array_fn()})
.execute(&self.pool) .execute(&self.pool)
.await?; .await?;
@ -209,54 +239,52 @@ where
.execute(&mut *tx) .execute(&mut *tx)
.await?; .await?;
let mut qb = QueryBuilder::new(indoc::indoc! { let mut qb = QueryBuilder::new(&formatdoc! {
r#" r#"
with key as ( with key as (
select select
id, id,
0::int2 as uses 0::int2 as uses
from api_keys where last_used < date_trunc('minute', now()) from {} where last_used < date_trunc('minute', now())
and (cooldown is null or now() >= cooldown) and (cooldown is null or now() >= cooldown)
and "# and "#,
self.table_name()
}); });
build_predicate(&mut qb, &selector); build_predicate(&mut qb, &selector);
qb.push(indoc::indoc! { qb.push(formatdoc! {
" "
\n union ( \n union (
select id, uses from api_keys select id, uses from {}
where last_used >= date_trunc('minute', now()) where last_used >= date_trunc('minute', now())
and (cooldown is null or now() >= cooldown) and (cooldown is null or now() >= cooldown)
and " and ",
self.table_name()
}); });
build_predicate(&mut qb, &selector); build_predicate(&mut qb, &selector);
qb.push(indoc::indoc! { qb.push(formatdoc! {
" "
\n order by uses asc limit 1 \n order by uses asc limit 1
) )
order by uses asc limit 1 order by uses asc limit 1
) )
update api_keys set update {} as keys set
uses = key.uses + 1, uses = key.uses + 1,
cooldown = null, cooldown = null,
flag = null, flag = null,
last_used = now() last_used = now()
from key where from key where
api_keys.id=key.id and key.uses < " keys.id=key.id and key.uses < ",
self.table_name()
}); });
qb.push_bind(self.limit); qb.push_bind(self.limit);
qb.push(indoc::indoc! { " qb.push(indoc::indoc! { "
\nreturning \nreturning keys.id, keys.user_id, keys.key, keys.uses, keys.domains"
api_keys.id,
api_keys.user_id,
api_keys.key,
api_keys.uses,
api_keys.domains"
}); });
let key = qb.build_query_as().fetch_optional(&mut *tx).await?; let key = qb.build_query_as().fetch_optional(&mut *tx).await?;
@ -321,19 +349,20 @@ where
.execute(&mut *tx) .execute(&mut *tx)
.await?; .await?;
let mut qb = QueryBuilder::new(indoc::indoc! { let mut qb = QueryBuilder::new(&formatdoc! {
r#"select r#"select
id, id,
user_id, user_id,
key, key,
0::int2 as uses, 0::int2 as uses,
domains domains
from api_keys where last_used < date_trunc('minute', now()) from {} where last_used < date_trunc('minute', now())
and (cooldown is null or now() >= cooldown) and (cooldown is null or now() >= cooldown)
and "# and "#,
self.table_name()
}); });
build_predicate(&mut qb, &selector); build_predicate(&mut qb, &selector);
qb.push(indoc::indoc! { qb.push(formatdoc! {
" "
\nunion \nunion
select select
@ -342,9 +371,10 @@ where
key, key,
uses, uses,
domains domains
from api_keys where last_used >= date_trunc('minute', now()) from {} where last_used >= date_trunc('minute', now())
and (cooldown is null or now() >= cooldown) and (cooldown is null or now() >= cooldown)
and " and ",
self.table_name()
}); });
build_predicate(&mut qb, &selector); build_predicate(&mut qb, &selector);
qb.push("\norder by uses limit "); qb.push("\norder by uses limit ");
@ -383,15 +413,15 @@ where
result.extend_from_slice(slice); result.extend_from_slice(slice);
} }
sqlx::query(indoc! {r#" sqlx::query(&formatdoc! {r#"
update api_keys set update {} keys set
uses = tmp.uses, uses = tmp.uses,
cooldown = null, cooldown = null,
flag = null, flag = null,
last_used = now() last_used = now()
from (select unnest($1::int4[]) as id, unnest($2::int2[]) as uses) as tmp from (select unnest($1::int4[]) as id, unnest($2::int2[]) as uses) as tmp
where api_keys.id = tmp.id where keys.id = tmp.id
"#}) "#, self.table_name()})
.bind(keys.iter().map(|k| k.id).collect::<Vec<_>>()) .bind(keys.iter().map(|k| k.id).collect::<Vec<_>>())
.bind(keys.iter().map(|k| k.uses).collect::<Vec<_>>()) .bind(keys.iter().map(|k| k.uses).collect::<Vec<_>>())
.execute(&mut *tx) .execute(&mut *tx)
@ -452,7 +482,10 @@ where
{ {
let selector = selector.into_selector(); let selector = selector.into_selector();
let mut qb = QueryBuilder::new("update api_keys set cooldown=now() + "); let mut qb = QueryBuilder::new(format!(
"update {} set cooldown=now() + ",
self.table_name()
));
qb.push_bind(duration); qb.push_bind(duration);
qb.push(" where "); qb.push(" where ");
build_predicate(&mut qb, &selector); build_predicate(&mut qb, &selector);
@ -468,11 +501,13 @@ where
key: String, key: String,
domains: Vec<D>, domains: Vec<D>,
) -> Result<Self::Key, Self::Error> { ) -> Result<Self::Key, Self::Error> {
sqlx::query_as( sqlx::query_as(&dbg!(formatdoc!(
"insert into api_keys(user_id, key, domains) values ($1, $2, $3) on conflict on \ "insert into {} as api_keys(user_id, key, domains) values ($1, $2, $3)
constraint \"uq:api_keys.key\" do update set domains = \ on conflict(key) do update
__unique_jsonb_array(excluded.domains || api_keys.domains) returning *", set domains = {}(excluded.domains || api_keys.domains) returning *",
) self.table_name(),
self.unique_array_fn()
)))
.bind(user_id) .bind(user_id)
.bind(&key) .bind(&key)
.bind(sqlx::types::Json(domains)) .bind(sqlx::types::Json(domains))
@ -487,7 +522,7 @@ where
{ {
let selector = selector.into_selector(); let selector = selector.into_selector();
let mut qb = QueryBuilder::new("select * from api_keys where "); let mut qb = QueryBuilder::new(format!("select * from {} where ", self.table_name()));
build_predicate(&mut qb, &selector); build_predicate(&mut qb, &selector);
qb.build_query_as() qb.build_query_as()
@ -502,7 +537,7 @@ where
{ {
let selector = selector.into_selector(); let selector = selector.into_selector();
let mut qb = QueryBuilder::new("select * from api_keys where "); let mut qb = QueryBuilder::new(format!("select * from {} where ", self.table_name()));
build_predicate(&mut qb, &selector); build_predicate(&mut qb, &selector);
qb.build_query_as() qb.build_query_as()
@ -517,7 +552,7 @@ where
{ {
let selector = selector.into_selector(); let selector = selector.into_selector();
let mut qb = QueryBuilder::new("delete from api_keys where "); let mut qb = QueryBuilder::new(format!("delete from {} where ", self.table_name()));
build_predicate(&mut qb, &selector); build_predicate(&mut qb, &selector);
qb.push(" returning *"); qb.push(" returning *");
@ -533,9 +568,11 @@ where
{ {
let selector = selector.into_selector(); let selector = selector.into_selector();
let mut qb = QueryBuilder::new( let mut qb = QueryBuilder::new(format!(
"update api_keys set domains = __unique_jsonb_array(domains || jsonb_build_array(", "update {} set domains = {}(domains || jsonb_build_array(",
); self.table_name(),
self.unique_array_fn()
));
qb.push_bind(sqlx::types::Json(domain)); qb.push_bind(sqlx::types::Json(domain));
qb.push(")) where "); qb.push(")) where ");
build_predicate(&mut qb, &selector); build_predicate(&mut qb, &selector);
@ -557,9 +594,11 @@ where
{ {
let selector = selector.into_selector(); let selector = selector.into_selector();
let mut qb = QueryBuilder::new( let mut qb = QueryBuilder::new(format!(
"update api_keys set domains = coalesce(__filter_jsonb_array(domains, ", "update {} set domains = coalesce({}(domains, ",
); self.table_name(),
self.filter_array_fn()
));
qb.push_bind(sqlx::types::Json(domain)); qb.push_bind(sqlx::types::Json(domain));
qb.push("), '[]'::jsonb) where "); qb.push("), '[]'::jsonb) where ");
build_predicate(&mut qb, &selector); build_predicate(&mut qb, &selector);
@ -626,7 +665,7 @@ pub(crate) mod test {
.await .await
.unwrap(); .unwrap();
let storage = PgKeyPoolStorage::new(pool.clone(), 1000); let storage = PgKeyPoolStorage::new(pool.clone(), 1000, Some("test".to_owned()));
storage.initialise().await.unwrap(); storage.initialise().await.unwrap();
let key = storage let key = storage
@ -823,7 +862,7 @@ pub(crate) mod test {
set.join_next().await.unwrap().unwrap(); set.join_next().await.unwrap().unwrap();
} }
let uses: i16 = sqlx::query("select uses from api_keys") let uses: i16 = sqlx::query(&format!("select uses from {}", storage.table_name()))
.fetch_one(&storage.pool) .fetch_one(&storage.pool)
.await .await
.unwrap() .unwrap()
@ -831,7 +870,7 @@ pub(crate) mod test {
assert_eq!(uses, 100); assert_eq!(uses, 100);
sqlx::query("update api_keys set uses=0") sqlx::query(&format!("update {} set uses=0", storage.table_name()))
.execute(&storage.pool) .execute(&storage.pool)
.await .await
.unwrap(); .unwrap();
@ -871,7 +910,7 @@ pub(crate) mod test {
assert_eq!(key.uses, 2); assert_eq!(key.uses, 2);
} }
sqlx::query("update api_keys set uses=0") sqlx::query(&format!("update {} set uses=0", storage.table_name()))
.execute(&storage.pool) .execute(&storage.pool)
.await .await
.unwrap(); .unwrap();
@ -896,7 +935,7 @@ pub(crate) mod test {
set.join_next().await.unwrap().unwrap(); set.join_next().await.unwrap().unwrap();
} }
let uses: i16 = sqlx::query("select uses from api_keys") let uses: i16 = sqlx::query(&format!("select uses from {}", storage.table_name()))
.fetch_one(&storage.pool) .fetch_one(&storage.pool)
.await .await
.unwrap() .unwrap()
@ -904,7 +943,7 @@ pub(crate) mod test {
assert_eq!(uses, 500); assert_eq!(uses, 500);
sqlx::query("update api_keys set uses=0") sqlx::query(&format!("update {} set uses=0", storage.table_name()))
.execute(&storage.pool) .execute(&storage.pool)
.await .await
.unwrap(); .unwrap();