feat(key-pool): allow setting schema for key table
This commit is contained in:
parent
c5651efbb0
commit
40b784cf57
2
Cargo.lock
generated
2
Cargo.lock
generated
|
@ -2303,7 +2303,7 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "torn-key-pool"
|
name = "torn-key-pool"
|
||||||
version = "1.0.0"
|
version = "1.0.1"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"chrono",
|
"chrono",
|
||||||
"futures",
|
"futures",
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
[package]
|
[package]
|
||||||
name = "torn-key-pool"
|
name = "torn-key-pool"
|
||||||
version = "1.0.0"
|
version = "1.0.1"
|
||||||
edition = "2021"
|
edition = "2021"
|
||||||
authors = ["Pyrit [2111649]"]
|
authors = ["Pyrit [2111649]"]
|
||||||
license-file = { workspace = true }
|
license-file = { workspace = true }
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
use futures::future::BoxFuture;
|
use futures::future::BoxFuture;
|
||||||
use indoc::indoc;
|
use indoc::formatdoc;
|
||||||
use sqlx::{FromRow, PgPool, Postgres, QueryBuilder};
|
use sqlx::{FromRow, PgPool, Postgres, QueryBuilder};
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
|
|
||||||
|
@ -93,6 +93,7 @@ where
|
||||||
{
|
{
|
||||||
pool: PgPool,
|
pool: PgPool,
|
||||||
limit: i16,
|
limit: i16,
|
||||||
|
schema: Option<String>,
|
||||||
_phantom: std::marker::PhantomData<D>,
|
_phantom: std::marker::PhantomData<D>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -117,62 +118,91 @@ impl<D> PgKeyPoolStorage<D>
|
||||||
where
|
where
|
||||||
D: PgKeyDomain,
|
D: PgKeyDomain,
|
||||||
{
|
{
|
||||||
pub fn new(pool: PgPool, limit: i16) -> Self {
|
pub fn new(pool: PgPool, limit: i16, schema: Option<String>) -> Self {
|
||||||
Self {
|
Self {
|
||||||
pool,
|
pool,
|
||||||
limit,
|
limit,
|
||||||
|
schema,
|
||||||
_phantom: Default::default(),
|
_phantom: Default::default(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn table_name(&self) -> String {
|
||||||
|
match self.schema.as_ref() {
|
||||||
|
Some(schema) => format!("{schema}.api_keys"),
|
||||||
|
None => "api_keys".to_owned(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn unique_array_fn(&self) -> String {
|
||||||
|
match self.schema.as_ref() {
|
||||||
|
Some(schema) => format!("{schema}.__unique_jsonb_array"),
|
||||||
|
None => "__unique_jsonb_array".to_owned(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn filter_array_fn(&self) -> String {
|
||||||
|
match self.schema.as_ref() {
|
||||||
|
Some(schema) => format!("{schema}.__filter_jsonb_array"),
|
||||||
|
None => "__filter_jsonb_array".to_owned(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub async fn initialise(&self) -> Result<(), PgKeyPoolError<D>> {
|
pub async fn initialise(&self) -> Result<(), PgKeyPoolError<D>> {
|
||||||
sqlx::query(indoc! {r#"
|
if let Some(schema) = self.schema.as_ref() {
|
||||||
CREATE TABLE IF NOT EXISTS api_keys (
|
sqlx::query(&format!("create schema if not exists {}", schema))
|
||||||
|
.execute(&self.pool)
|
||||||
|
.await?;
|
||||||
|
}
|
||||||
|
|
||||||
|
sqlx::query(&formatdoc! {r#"
|
||||||
|
CREATE TABLE IF NOT EXISTS {} (
|
||||||
id serial primary key,
|
id serial primary key,
|
||||||
user_id int4 not null,
|
user_id int4 not null,
|
||||||
key char(16) not null,
|
key char(16) not null,
|
||||||
uses int2 not null default 0,
|
uses int2 not null default 0,
|
||||||
domains jsonb not null default '{}'::jsonb,
|
domains jsonb not null default '{{}}'::jsonb,
|
||||||
last_used timestamptz not null default now(),
|
last_used timestamptz not null default now(),
|
||||||
flag int2,
|
flag int2,
|
||||||
cooldown timestamptz,
|
cooldown timestamptz,
|
||||||
constraint "uq:api_keys.key" UNIQUE(key)
|
constraint "uq:api_keys.key" UNIQUE(key)
|
||||||
)"#
|
)"#,
|
||||||
|
self.table_name()
|
||||||
})
|
})
|
||||||
.execute(&self.pool)
|
.execute(&self.pool)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
sqlx::query(indoc! {r#"
|
sqlx::query(&formatdoc! {r#"
|
||||||
CREATE INDEX IF NOT EXISTS "idx:api_keys.domains" ON api_keys USING GIN(domains jsonb_path_ops)
|
CREATE INDEX IF NOT EXISTS "idx:api_keys.domains" ON {} USING GIN(domains jsonb_path_ops)
|
||||||
"#})
|
"#, self.table_name()})
|
||||||
.execute(&self.pool)
|
.execute(&self.pool)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
sqlx::query(indoc! {r#"
|
sqlx::query(&formatdoc! {r#"
|
||||||
CREATE INDEX IF NOT EXISTS "idx:api_keys.user_id" ON api_keys USING BTREE(user_id)
|
CREATE INDEX IF NOT EXISTS "idx:api_keys.user_id" ON {} USING BTREE(user_id)
|
||||||
"#})
|
"#, self.table_name()})
|
||||||
.execute(&self.pool)
|
.execute(&self.pool)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
sqlx::query(indoc! {r#"
|
sqlx::query(&formatdoc! {r#"
|
||||||
create or replace function __unique_jsonb_array(jsonb) returns jsonb
|
create or replace function {}(jsonb) returns jsonb
|
||||||
AS $$
|
AS $$
|
||||||
select jsonb_agg(d::jsonb) from (
|
select jsonb_agg(d::jsonb) from (
|
||||||
select distinct jsonb_array_elements_text($1) as d
|
select distinct jsonb_array_elements_text($1) as d
|
||||||
) t
|
) t
|
||||||
$$ language sql;
|
$$ language sql;
|
||||||
"#})
|
"#, self.unique_array_fn()})
|
||||||
.execute(&self.pool)
|
.execute(&self.pool)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
sqlx::query(indoc! {r#"
|
sqlx::query(&formatdoc! {r#"
|
||||||
create or replace function __filter_jsonb_array(jsonb, jsonb) returns jsonb
|
create or replace function {}(jsonb, jsonb) returns jsonb
|
||||||
AS $$
|
AS $$
|
||||||
select jsonb_agg(d::jsonb) from (
|
select jsonb_agg(d::jsonb) from (
|
||||||
select distinct jsonb_array_elements_text($1) as d
|
select distinct jsonb_array_elements_text($1) as d
|
||||||
) t where d<>$2::text
|
) t where d<>$2::text
|
||||||
$$ language sql;
|
$$ language sql;
|
||||||
"#})
|
"#, self.filter_array_fn()})
|
||||||
.execute(&self.pool)
|
.execute(&self.pool)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
|
@ -209,54 +239,52 @@ where
|
||||||
.execute(&mut *tx)
|
.execute(&mut *tx)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
let mut qb = QueryBuilder::new(indoc::indoc! {
|
let mut qb = QueryBuilder::new(&formatdoc! {
|
||||||
r#"
|
r#"
|
||||||
with key as (
|
with key as (
|
||||||
select
|
select
|
||||||
id,
|
id,
|
||||||
0::int2 as uses
|
0::int2 as uses
|
||||||
from api_keys where last_used < date_trunc('minute', now())
|
from {} where last_used < date_trunc('minute', now())
|
||||||
and (cooldown is null or now() >= cooldown)
|
and (cooldown is null or now() >= cooldown)
|
||||||
and "#
|
and "#,
|
||||||
|
self.table_name()
|
||||||
});
|
});
|
||||||
|
|
||||||
build_predicate(&mut qb, &selector);
|
build_predicate(&mut qb, &selector);
|
||||||
|
|
||||||
qb.push(indoc::indoc! {
|
qb.push(formatdoc! {
|
||||||
"
|
"
|
||||||
\n union (
|
\n union (
|
||||||
select id, uses from api_keys
|
select id, uses from {}
|
||||||
where last_used >= date_trunc('minute', now())
|
where last_used >= date_trunc('minute', now())
|
||||||
and (cooldown is null or now() >= cooldown)
|
and (cooldown is null or now() >= cooldown)
|
||||||
and "
|
and ",
|
||||||
|
self.table_name()
|
||||||
});
|
});
|
||||||
|
|
||||||
build_predicate(&mut qb, &selector);
|
build_predicate(&mut qb, &selector);
|
||||||
|
|
||||||
qb.push(indoc::indoc! {
|
qb.push(formatdoc! {
|
||||||
"
|
"
|
||||||
\n order by uses asc limit 1
|
\n order by uses asc limit 1
|
||||||
)
|
)
|
||||||
order by uses asc limit 1
|
order by uses asc limit 1
|
||||||
)
|
)
|
||||||
update api_keys set
|
update {} as keys set
|
||||||
uses = key.uses + 1,
|
uses = key.uses + 1,
|
||||||
cooldown = null,
|
cooldown = null,
|
||||||
flag = null,
|
flag = null,
|
||||||
last_used = now()
|
last_used = now()
|
||||||
from key where
|
from key where
|
||||||
api_keys.id=key.id and key.uses < "
|
keys.id=key.id and key.uses < ",
|
||||||
|
self.table_name()
|
||||||
});
|
});
|
||||||
|
|
||||||
qb.push_bind(self.limit);
|
qb.push_bind(self.limit);
|
||||||
|
|
||||||
qb.push(indoc::indoc! { "
|
qb.push(indoc::indoc! { "
|
||||||
\nreturning
|
\nreturning keys.id, keys.user_id, keys.key, keys.uses, keys.domains"
|
||||||
api_keys.id,
|
|
||||||
api_keys.user_id,
|
|
||||||
api_keys.key,
|
|
||||||
api_keys.uses,
|
|
||||||
api_keys.domains"
|
|
||||||
});
|
});
|
||||||
|
|
||||||
let key = qb.build_query_as().fetch_optional(&mut *tx).await?;
|
let key = qb.build_query_as().fetch_optional(&mut *tx).await?;
|
||||||
|
@ -321,19 +349,20 @@ where
|
||||||
.execute(&mut *tx)
|
.execute(&mut *tx)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
let mut qb = QueryBuilder::new(indoc::indoc! {
|
let mut qb = QueryBuilder::new(&formatdoc! {
|
||||||
r#"select
|
r#"select
|
||||||
id,
|
id,
|
||||||
user_id,
|
user_id,
|
||||||
key,
|
key,
|
||||||
0::int2 as uses,
|
0::int2 as uses,
|
||||||
domains
|
domains
|
||||||
from api_keys where last_used < date_trunc('minute', now())
|
from {} where last_used < date_trunc('minute', now())
|
||||||
and (cooldown is null or now() >= cooldown)
|
and (cooldown is null or now() >= cooldown)
|
||||||
and "#
|
and "#,
|
||||||
|
self.table_name()
|
||||||
});
|
});
|
||||||
build_predicate(&mut qb, &selector);
|
build_predicate(&mut qb, &selector);
|
||||||
qb.push(indoc::indoc! {
|
qb.push(formatdoc! {
|
||||||
"
|
"
|
||||||
\nunion
|
\nunion
|
||||||
select
|
select
|
||||||
|
@ -342,9 +371,10 @@ where
|
||||||
key,
|
key,
|
||||||
uses,
|
uses,
|
||||||
domains
|
domains
|
||||||
from api_keys where last_used >= date_trunc('minute', now())
|
from {} where last_used >= date_trunc('minute', now())
|
||||||
and (cooldown is null or now() >= cooldown)
|
and (cooldown is null or now() >= cooldown)
|
||||||
and "
|
and ",
|
||||||
|
self.table_name()
|
||||||
});
|
});
|
||||||
build_predicate(&mut qb, &selector);
|
build_predicate(&mut qb, &selector);
|
||||||
qb.push("\norder by uses limit ");
|
qb.push("\norder by uses limit ");
|
||||||
|
@ -383,15 +413,15 @@ where
|
||||||
result.extend_from_slice(slice);
|
result.extend_from_slice(slice);
|
||||||
}
|
}
|
||||||
|
|
||||||
sqlx::query(indoc! {r#"
|
sqlx::query(&formatdoc! {r#"
|
||||||
update api_keys set
|
update {} keys set
|
||||||
uses = tmp.uses,
|
uses = tmp.uses,
|
||||||
cooldown = null,
|
cooldown = null,
|
||||||
flag = null,
|
flag = null,
|
||||||
last_used = now()
|
last_used = now()
|
||||||
from (select unnest($1::int4[]) as id, unnest($2::int2[]) as uses) as tmp
|
from (select unnest($1::int4[]) as id, unnest($2::int2[]) as uses) as tmp
|
||||||
where api_keys.id = tmp.id
|
where keys.id = tmp.id
|
||||||
"#})
|
"#, self.table_name()})
|
||||||
.bind(keys.iter().map(|k| k.id).collect::<Vec<_>>())
|
.bind(keys.iter().map(|k| k.id).collect::<Vec<_>>())
|
||||||
.bind(keys.iter().map(|k| k.uses).collect::<Vec<_>>())
|
.bind(keys.iter().map(|k| k.uses).collect::<Vec<_>>())
|
||||||
.execute(&mut *tx)
|
.execute(&mut *tx)
|
||||||
|
@ -452,7 +482,10 @@ where
|
||||||
{
|
{
|
||||||
let selector = selector.into_selector();
|
let selector = selector.into_selector();
|
||||||
|
|
||||||
let mut qb = QueryBuilder::new("update api_keys set cooldown=now() + ");
|
let mut qb = QueryBuilder::new(format!(
|
||||||
|
"update {} set cooldown=now() + ",
|
||||||
|
self.table_name()
|
||||||
|
));
|
||||||
qb.push_bind(duration);
|
qb.push_bind(duration);
|
||||||
qb.push(" where ");
|
qb.push(" where ");
|
||||||
build_predicate(&mut qb, &selector);
|
build_predicate(&mut qb, &selector);
|
||||||
|
@ -468,11 +501,13 @@ where
|
||||||
key: String,
|
key: String,
|
||||||
domains: Vec<D>,
|
domains: Vec<D>,
|
||||||
) -> Result<Self::Key, Self::Error> {
|
) -> Result<Self::Key, Self::Error> {
|
||||||
sqlx::query_as(
|
sqlx::query_as(&dbg!(formatdoc!(
|
||||||
"insert into api_keys(user_id, key, domains) values ($1, $2, $3) on conflict on \
|
"insert into {} as api_keys(user_id, key, domains) values ($1, $2, $3)
|
||||||
constraint \"uq:api_keys.key\" do update set domains = \
|
on conflict(key) do update
|
||||||
__unique_jsonb_array(excluded.domains || api_keys.domains) returning *",
|
set domains = {}(excluded.domains || api_keys.domains) returning *",
|
||||||
)
|
self.table_name(),
|
||||||
|
self.unique_array_fn()
|
||||||
|
)))
|
||||||
.bind(user_id)
|
.bind(user_id)
|
||||||
.bind(&key)
|
.bind(&key)
|
||||||
.bind(sqlx::types::Json(domains))
|
.bind(sqlx::types::Json(domains))
|
||||||
|
@ -487,7 +522,7 @@ where
|
||||||
{
|
{
|
||||||
let selector = selector.into_selector();
|
let selector = selector.into_selector();
|
||||||
|
|
||||||
let mut qb = QueryBuilder::new("select * from api_keys where ");
|
let mut qb = QueryBuilder::new(format!("select * from {} where ", self.table_name()));
|
||||||
build_predicate(&mut qb, &selector);
|
build_predicate(&mut qb, &selector);
|
||||||
|
|
||||||
qb.build_query_as()
|
qb.build_query_as()
|
||||||
|
@ -502,7 +537,7 @@ where
|
||||||
{
|
{
|
||||||
let selector = selector.into_selector();
|
let selector = selector.into_selector();
|
||||||
|
|
||||||
let mut qb = QueryBuilder::new("select * from api_keys where ");
|
let mut qb = QueryBuilder::new(format!("select * from {} where ", self.table_name()));
|
||||||
build_predicate(&mut qb, &selector);
|
build_predicate(&mut qb, &selector);
|
||||||
|
|
||||||
qb.build_query_as()
|
qb.build_query_as()
|
||||||
|
@ -517,7 +552,7 @@ where
|
||||||
{
|
{
|
||||||
let selector = selector.into_selector();
|
let selector = selector.into_selector();
|
||||||
|
|
||||||
let mut qb = QueryBuilder::new("delete from api_keys where ");
|
let mut qb = QueryBuilder::new(format!("delete from {} where ", self.table_name()));
|
||||||
build_predicate(&mut qb, &selector);
|
build_predicate(&mut qb, &selector);
|
||||||
qb.push(" returning *");
|
qb.push(" returning *");
|
||||||
|
|
||||||
|
@ -533,9 +568,11 @@ where
|
||||||
{
|
{
|
||||||
let selector = selector.into_selector();
|
let selector = selector.into_selector();
|
||||||
|
|
||||||
let mut qb = QueryBuilder::new(
|
let mut qb = QueryBuilder::new(format!(
|
||||||
"update api_keys set domains = __unique_jsonb_array(domains || jsonb_build_array(",
|
"update {} set domains = {}(domains || jsonb_build_array(",
|
||||||
);
|
self.table_name(),
|
||||||
|
self.unique_array_fn()
|
||||||
|
));
|
||||||
qb.push_bind(sqlx::types::Json(domain));
|
qb.push_bind(sqlx::types::Json(domain));
|
||||||
qb.push(")) where ");
|
qb.push(")) where ");
|
||||||
build_predicate(&mut qb, &selector);
|
build_predicate(&mut qb, &selector);
|
||||||
|
@ -557,9 +594,11 @@ where
|
||||||
{
|
{
|
||||||
let selector = selector.into_selector();
|
let selector = selector.into_selector();
|
||||||
|
|
||||||
let mut qb = QueryBuilder::new(
|
let mut qb = QueryBuilder::new(format!(
|
||||||
"update api_keys set domains = coalesce(__filter_jsonb_array(domains, ",
|
"update {} set domains = coalesce({}(domains, ",
|
||||||
);
|
self.table_name(),
|
||||||
|
self.filter_array_fn()
|
||||||
|
));
|
||||||
qb.push_bind(sqlx::types::Json(domain));
|
qb.push_bind(sqlx::types::Json(domain));
|
||||||
qb.push("), '[]'::jsonb) where ");
|
qb.push("), '[]'::jsonb) where ");
|
||||||
build_predicate(&mut qb, &selector);
|
build_predicate(&mut qb, &selector);
|
||||||
|
@ -626,7 +665,7 @@ pub(crate) mod test {
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
let storage = PgKeyPoolStorage::new(pool.clone(), 1000);
|
let storage = PgKeyPoolStorage::new(pool.clone(), 1000, Some("test".to_owned()));
|
||||||
storage.initialise().await.unwrap();
|
storage.initialise().await.unwrap();
|
||||||
|
|
||||||
let key = storage
|
let key = storage
|
||||||
|
@ -823,7 +862,7 @@ pub(crate) mod test {
|
||||||
set.join_next().await.unwrap().unwrap();
|
set.join_next().await.unwrap().unwrap();
|
||||||
}
|
}
|
||||||
|
|
||||||
let uses: i16 = sqlx::query("select uses from api_keys")
|
let uses: i16 = sqlx::query(&format!("select uses from {}", storage.table_name()))
|
||||||
.fetch_one(&storage.pool)
|
.fetch_one(&storage.pool)
|
||||||
.await
|
.await
|
||||||
.unwrap()
|
.unwrap()
|
||||||
|
@ -831,7 +870,7 @@ pub(crate) mod test {
|
||||||
|
|
||||||
assert_eq!(uses, 100);
|
assert_eq!(uses, 100);
|
||||||
|
|
||||||
sqlx::query("update api_keys set uses=0")
|
sqlx::query(&format!("update {} set uses=0", storage.table_name()))
|
||||||
.execute(&storage.pool)
|
.execute(&storage.pool)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
@ -871,7 +910,7 @@ pub(crate) mod test {
|
||||||
assert_eq!(key.uses, 2);
|
assert_eq!(key.uses, 2);
|
||||||
}
|
}
|
||||||
|
|
||||||
sqlx::query("update api_keys set uses=0")
|
sqlx::query(&format!("update {} set uses=0", storage.table_name()))
|
||||||
.execute(&storage.pool)
|
.execute(&storage.pool)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
@ -896,7 +935,7 @@ pub(crate) mod test {
|
||||||
set.join_next().await.unwrap().unwrap();
|
set.join_next().await.unwrap().unwrap();
|
||||||
}
|
}
|
||||||
|
|
||||||
let uses: i16 = sqlx::query("select uses from api_keys")
|
let uses: i16 = sqlx::query(&format!("select uses from {}", storage.table_name()))
|
||||||
.fetch_one(&storage.pool)
|
.fetch_one(&storage.pool)
|
||||||
.await
|
.await
|
||||||
.unwrap()
|
.unwrap()
|
||||||
|
@ -904,7 +943,7 @@ pub(crate) mod test {
|
||||||
|
|
||||||
assert_eq!(uses, 500);
|
assert_eq!(uses, 500);
|
||||||
|
|
||||||
sqlx::query("update api_keys set uses=0")
|
sqlx::query(&format!("update {} set uses=0", storage.table_name()))
|
||||||
.execute(&storage.pool)
|
.execute(&storage.pool)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
Loading…
Reference in a new issue