feat(key-pool): allow setting schema for key table

This commit is contained in:
TotallyNot 2025-04-27 11:02:33 +02:00
parent c5651efbb0
commit 40b784cf57
Signed by: pyrite
GPG key ID: 7F1BA9170CD35D15
3 changed files with 105 additions and 66 deletions

2
Cargo.lock generated
View file

@ -2303,7 +2303,7 @@ dependencies = [
[[package]]
name = "torn-key-pool"
version = "1.0.0"
version = "1.0.1"
dependencies = [
"chrono",
"futures",

View file

@ -1,6 +1,6 @@
[package]
name = "torn-key-pool"
version = "1.0.0"
version = "1.0.1"
edition = "2021"
authors = ["Pyrit [2111649]"]
license-file = { workspace = true }

View file

@ -1,5 +1,5 @@
use futures::future::BoxFuture;
use indoc::indoc;
use indoc::formatdoc;
use sqlx::{FromRow, PgPool, Postgres, QueryBuilder};
use thiserror::Error;
@ -93,6 +93,7 @@ where
{
pool: PgPool,
limit: i16,
schema: Option<String>,
_phantom: std::marker::PhantomData<D>,
}
@ -117,62 +118,91 @@ impl<D> PgKeyPoolStorage<D>
where
D: PgKeyDomain,
{
pub fn new(pool: PgPool, limit: i16) -> Self {
pub fn new(pool: PgPool, limit: i16, schema: Option<String>) -> Self {
Self {
pool,
limit,
schema,
_phantom: Default::default(),
}
}
fn table_name(&self) -> String {
match self.schema.as_ref() {
Some(schema) => format!("{schema}.api_keys"),
None => "api_keys".to_owned(),
}
}
fn unique_array_fn(&self) -> String {
match self.schema.as_ref() {
Some(schema) => format!("{schema}.__unique_jsonb_array"),
None => "__unique_jsonb_array".to_owned(),
}
}
fn filter_array_fn(&self) -> String {
match self.schema.as_ref() {
Some(schema) => format!("{schema}.__filter_jsonb_array"),
None => "__filter_jsonb_array".to_owned(),
}
}
pub async fn initialise(&self) -> Result<(), PgKeyPoolError<D>> {
sqlx::query(indoc! {r#"
CREATE TABLE IF NOT EXISTS api_keys (
if let Some(schema) = self.schema.as_ref() {
sqlx::query(&format!("create schema if not exists {}", schema))
.execute(&self.pool)
.await?;
}
sqlx::query(&formatdoc! {r#"
CREATE TABLE IF NOT EXISTS {} (
id serial primary key,
user_id int4 not null,
key char(16) not null,
uses int2 not null default 0,
domains jsonb not null default '{}'::jsonb,
domains jsonb not null default '{{}}'::jsonb,
last_used timestamptz not null default now(),
flag int2,
cooldown timestamptz,
constraint "uq:api_keys.key" UNIQUE(key)
)"#
)"#,
self.table_name()
})
.execute(&self.pool)
.await?;
sqlx::query(indoc! {r#"
CREATE INDEX IF NOT EXISTS "idx:api_keys.domains" ON api_keys USING GIN(domains jsonb_path_ops)
"#})
sqlx::query(&formatdoc! {r#"
CREATE INDEX IF NOT EXISTS "idx:api_keys.domains" ON {} USING GIN(domains jsonb_path_ops)
"#, self.table_name()})
.execute(&self.pool)
.await?;
sqlx::query(indoc! {r#"
CREATE INDEX IF NOT EXISTS "idx:api_keys.user_id" ON api_keys USING BTREE(user_id)
"#})
sqlx::query(&formatdoc! {r#"
CREATE INDEX IF NOT EXISTS "idx:api_keys.user_id" ON {} USING BTREE(user_id)
"#, self.table_name()})
.execute(&self.pool)
.await?;
sqlx::query(indoc! {r#"
create or replace function __unique_jsonb_array(jsonb) returns jsonb
sqlx::query(&formatdoc! {r#"
create or replace function {}(jsonb) returns jsonb
AS $$
select jsonb_agg(d::jsonb) from (
select distinct jsonb_array_elements_text($1) as d
) t
$$ language sql;
"#})
"#, self.unique_array_fn()})
.execute(&self.pool)
.await?;
sqlx::query(indoc! {r#"
create or replace function __filter_jsonb_array(jsonb, jsonb) returns jsonb
sqlx::query(&formatdoc! {r#"
create or replace function {}(jsonb, jsonb) returns jsonb
AS $$
select jsonb_agg(d::jsonb) from (
select distinct jsonb_array_elements_text($1) as d
) t where d<>$2::text
$$ language sql;
"#})
"#, self.filter_array_fn()})
.execute(&self.pool)
.await?;
@ -209,54 +239,52 @@ where
.execute(&mut *tx)
.await?;
let mut qb = QueryBuilder::new(indoc::indoc! {
let mut qb = QueryBuilder::new(&formatdoc! {
r#"
with key as (
select
id,
0::int2 as uses
from api_keys where last_used < date_trunc('minute', now())
from {} where last_used < date_trunc('minute', now())
and (cooldown is null or now() >= cooldown)
and "#
and "#,
self.table_name()
});
build_predicate(&mut qb, &selector);
qb.push(indoc::indoc! {
qb.push(formatdoc! {
"
\n union (
select id, uses from api_keys
select id, uses from {}
where last_used >= date_trunc('minute', now())
and (cooldown is null or now() >= cooldown)
and "
and ",
self.table_name()
});
build_predicate(&mut qb, &selector);
qb.push(indoc::indoc! {
qb.push(formatdoc! {
"
\n order by uses asc limit 1
)
order by uses asc limit 1
)
update api_keys set
update {} as keys set
uses = key.uses + 1,
cooldown = null,
flag = null,
last_used = now()
from key where
api_keys.id=key.id and key.uses < "
keys.id=key.id and key.uses < ",
self.table_name()
});
qb.push_bind(self.limit);
qb.push(indoc::indoc! { "
\nreturning
api_keys.id,
api_keys.user_id,
api_keys.key,
api_keys.uses,
api_keys.domains"
\nreturning keys.id, keys.user_id, keys.key, keys.uses, keys.domains"
});
let key = qb.build_query_as().fetch_optional(&mut *tx).await?;
@ -321,19 +349,20 @@ where
.execute(&mut *tx)
.await?;
let mut qb = QueryBuilder::new(indoc::indoc! {
let mut qb = QueryBuilder::new(&formatdoc! {
r#"select
id,
user_id,
key,
0::int2 as uses,
domains
from api_keys where last_used < date_trunc('minute', now())
from {} where last_used < date_trunc('minute', now())
and (cooldown is null or now() >= cooldown)
and "#
and "#,
self.table_name()
});
build_predicate(&mut qb, &selector);
qb.push(indoc::indoc! {
qb.push(formatdoc! {
"
\nunion
select
@ -342,9 +371,10 @@ where
key,
uses,
domains
from api_keys where last_used >= date_trunc('minute', now())
from {} where last_used >= date_trunc('minute', now())
and (cooldown is null or now() >= cooldown)
and "
and ",
self.table_name()
});
build_predicate(&mut qb, &selector);
qb.push("\norder by uses limit ");
@ -383,15 +413,15 @@ where
result.extend_from_slice(slice);
}
sqlx::query(indoc! {r#"
update api_keys set
sqlx::query(&formatdoc! {r#"
update {} keys set
uses = tmp.uses,
cooldown = null,
flag = null,
last_used = now()
from (select unnest($1::int4[]) as id, unnest($2::int2[]) as uses) as tmp
where api_keys.id = tmp.id
"#})
where keys.id = tmp.id
"#, self.table_name()})
.bind(keys.iter().map(|k| k.id).collect::<Vec<_>>())
.bind(keys.iter().map(|k| k.uses).collect::<Vec<_>>())
.execute(&mut *tx)
@ -452,7 +482,10 @@ where
{
let selector = selector.into_selector();
let mut qb = QueryBuilder::new("update api_keys set cooldown=now() + ");
let mut qb = QueryBuilder::new(format!(
"update {} set cooldown=now() + ",
self.table_name()
));
qb.push_bind(duration);
qb.push(" where ");
build_predicate(&mut qb, &selector);
@ -468,11 +501,13 @@ where
key: String,
domains: Vec<D>,
) -> Result<Self::Key, Self::Error> {
sqlx::query_as(
"insert into api_keys(user_id, key, domains) values ($1, $2, $3) on conflict on \
constraint \"uq:api_keys.key\" do update set domains = \
__unique_jsonb_array(excluded.domains || api_keys.domains) returning *",
)
sqlx::query_as(&dbg!(formatdoc!(
"insert into {} as api_keys(user_id, key, domains) values ($1, $2, $3)
on conflict(key) do update
set domains = {}(excluded.domains || api_keys.domains) returning *",
self.table_name(),
self.unique_array_fn()
)))
.bind(user_id)
.bind(&key)
.bind(sqlx::types::Json(domains))
@ -487,7 +522,7 @@ where
{
let selector = selector.into_selector();
let mut qb = QueryBuilder::new("select * from api_keys where ");
let mut qb = QueryBuilder::new(format!("select * from {} where ", self.table_name()));
build_predicate(&mut qb, &selector);
qb.build_query_as()
@ -502,7 +537,7 @@ where
{
let selector = selector.into_selector();
let mut qb = QueryBuilder::new("select * from api_keys where ");
let mut qb = QueryBuilder::new(format!("select * from {} where ", self.table_name()));
build_predicate(&mut qb, &selector);
qb.build_query_as()
@ -517,7 +552,7 @@ where
{
let selector = selector.into_selector();
let mut qb = QueryBuilder::new("delete from api_keys where ");
let mut qb = QueryBuilder::new(format!("delete from {} where ", self.table_name()));
build_predicate(&mut qb, &selector);
qb.push(" returning *");
@ -533,9 +568,11 @@ where
{
let selector = selector.into_selector();
let mut qb = QueryBuilder::new(
"update api_keys set domains = __unique_jsonb_array(domains || jsonb_build_array(",
);
let mut qb = QueryBuilder::new(format!(
"update {} set domains = {}(domains || jsonb_build_array(",
self.table_name(),
self.unique_array_fn()
));
qb.push_bind(sqlx::types::Json(domain));
qb.push(")) where ");
build_predicate(&mut qb, &selector);
@ -557,9 +594,11 @@ where
{
let selector = selector.into_selector();
let mut qb = QueryBuilder::new(
"update api_keys set domains = coalesce(__filter_jsonb_array(domains, ",
);
let mut qb = QueryBuilder::new(format!(
"update {} set domains = coalesce({}(domains, ",
self.table_name(),
self.filter_array_fn()
));
qb.push_bind(sqlx::types::Json(domain));
qb.push("), '[]'::jsonb) where ");
build_predicate(&mut qb, &selector);
@ -626,7 +665,7 @@ pub(crate) mod test {
.await
.unwrap();
let storage = PgKeyPoolStorage::new(pool.clone(), 1000);
let storage = PgKeyPoolStorage::new(pool.clone(), 1000, Some("test".to_owned()));
storage.initialise().await.unwrap();
let key = storage
@ -823,7 +862,7 @@ pub(crate) mod test {
set.join_next().await.unwrap().unwrap();
}
let uses: i16 = sqlx::query("select uses from api_keys")
let uses: i16 = sqlx::query(&format!("select uses from {}", storage.table_name()))
.fetch_one(&storage.pool)
.await
.unwrap()
@ -831,7 +870,7 @@ pub(crate) mod test {
assert_eq!(uses, 100);
sqlx::query("update api_keys set uses=0")
sqlx::query(&format!("update {} set uses=0", storage.table_name()))
.execute(&storage.pool)
.await
.unwrap();
@ -871,7 +910,7 @@ pub(crate) mod test {
assert_eq!(key.uses, 2);
}
sqlx::query("update api_keys set uses=0")
sqlx::query(&format!("update {} set uses=0", storage.table_name()))
.execute(&storage.pool)
.await
.unwrap();
@ -896,7 +935,7 @@ pub(crate) mod test {
set.join_next().await.unwrap().unwrap();
}
let uses: i16 = sqlx::query("select uses from api_keys")
let uses: i16 = sqlx::query(&format!("select uses from {}", storage.table_name()))
.fetch_one(&storage.pool)
.await
.unwrap()
@ -904,7 +943,7 @@ pub(crate) mod test {
assert_eq!(uses, 500);
sqlx::query("update api_keys set uses=0")
sqlx::query(&format!("update {} set uses=0", storage.table_name()))
.execute(&storage.pool)
.await
.unwrap();