use new IntoSelector
trait to identify keys
This commit is contained in:
parent
1ac79e3b4f
commit
6da09226a6
|
@ -1,6 +1,6 @@
|
|||
[package]
|
||||
name = "torn-key-pool"
|
||||
version = "0.5.7"
|
||||
version = "0.6.0"
|
||||
edition = "2021"
|
||||
authors = ["Pyrit [2111649]"]
|
||||
license = "MIT"
|
||||
|
|
|
@ -30,7 +30,7 @@ where
|
|||
}
|
||||
|
||||
pub trait ApiKey: Sync + Send {
|
||||
type IdType: PartialEq + Eq + std::hash::Hash;
|
||||
type IdType: PartialEq + Eq + std::hash::Hash + Send + Sync;
|
||||
|
||||
fn value(&self) -> &str;
|
||||
|
||||
|
@ -44,12 +44,65 @@ pub trait KeyDomain: Clone + std::fmt::Debug + Send + Sync {
|
|||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum KeySelector<K>
|
||||
pub enum KeySelector<K, D>
|
||||
where
|
||||
K: ApiKey,
|
||||
D: KeyDomain,
|
||||
{
|
||||
Key(String),
|
||||
Id(K::IdType),
|
||||
UserId(i32),
|
||||
Has(D),
|
||||
OneOf(Vec<D>),
|
||||
}
|
||||
|
||||
impl<K, D> KeySelector<K, D>
|
||||
where
|
||||
K: ApiKey,
|
||||
D: KeyDomain,
|
||||
{
|
||||
pub(crate) fn fallback(&self) -> Option<Self> {
|
||||
match self {
|
||||
Self::Key(_) | Self::UserId(_) | Self::Id(_) => None,
|
||||
Self::Has(domain) => domain.fallback().map(Self::Has),
|
||||
Self::OneOf(domains) => {
|
||||
let fallbacks: Vec<_> = domains.into_iter().filter_map(|d| d.fallback()).collect();
|
||||
if fallbacks.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(Self::OneOf(fallbacks))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub trait IntoSelector<K, D>: Send + Sync
|
||||
where
|
||||
K: ApiKey,
|
||||
D: KeyDomain,
|
||||
{
|
||||
fn into_selector(self) -> KeySelector<K, D>;
|
||||
}
|
||||
|
||||
impl<K, D> IntoSelector<K, D> for D
|
||||
where
|
||||
K: ApiKey,
|
||||
D: KeyDomain,
|
||||
{
|
||||
fn into_selector(self) -> KeySelector<K, D> {
|
||||
KeySelector::Has(self)
|
||||
}
|
||||
}
|
||||
|
||||
impl<K, D> IntoSelector<K, D> for KeySelector<K, D>
|
||||
where
|
||||
K: ApiKey,
|
||||
D: KeyDomain,
|
||||
{
|
||||
fn into_selector(self) -> KeySelector<K, D> {
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
|
@ -58,13 +111,17 @@ pub trait KeyPoolStorage {
|
|||
type Domain: KeyDomain;
|
||||
type Error: std::error::Error + Sync + Send;
|
||||
|
||||
async fn acquire_key(&self, domain: Self::Domain) -> Result<Self::Key, Self::Error>;
|
||||
async fn acquire_key<S>(&self, selector: S) -> Result<Self::Key, Self::Error>
|
||||
where
|
||||
S: IntoSelector<Self::Key, Self::Domain>;
|
||||
|
||||
async fn acquire_many_keys(
|
||||
async fn acquire_many_keys<S>(
|
||||
&self,
|
||||
domain: Self::Domain,
|
||||
selector: S,
|
||||
number: i64,
|
||||
) -> Result<Vec<Self::Key>, Self::Error>;
|
||||
) -> Result<Vec<Self::Key>, Self::Error>
|
||||
where
|
||||
S: IntoSelector<Self::Key, Self::Domain>;
|
||||
|
||||
async fn flag_key(&self, key: Self::Key, code: u8) -> Result<bool, Self::Error>;
|
||||
|
||||
|
@ -75,34 +132,41 @@ pub trait KeyPoolStorage {
|
|||
domains: Vec<Self::Domain>,
|
||||
) -> Result<Self::Key, Self::Error>;
|
||||
|
||||
async fn read_key(&self, key: KeySelector<Self::Key>)
|
||||
-> Result<Option<Self::Key>, Self::Error>;
|
||||
async fn read_key<S>(&self, selector: S) -> Result<Option<Self::Key>, Self::Error>
|
||||
where
|
||||
S: IntoSelector<Self::Key, Self::Domain>;
|
||||
|
||||
async fn read_user_keys(&self, user_id: i32) -> Result<Vec<Self::Key>, Self::Error>;
|
||||
async fn read_keys<S>(&self, selector: S) -> Result<Vec<Self::Key>, Self::Error>
|
||||
where
|
||||
S: IntoSelector<Self::Key, Self::Domain>;
|
||||
|
||||
async fn remove_key(&self, key: KeySelector<Self::Key>) -> Result<Self::Key, Self::Error>;
|
||||
async fn remove_key<S>(&self, selector: S) -> Result<Self::Key, Self::Error>
|
||||
where
|
||||
S: IntoSelector<Self::Key, Self::Domain>;
|
||||
|
||||
async fn query_key(&self, domain: Self::Domain) -> Result<Option<Self::Key>, Self::Error>;
|
||||
|
||||
async fn query_all(&self, domain: Self::Domain) -> Result<Vec<Self::Key>, Self::Error>;
|
||||
|
||||
async fn add_domain_to_key(
|
||||
async fn add_domain_to_key<S>(
|
||||
&self,
|
||||
key: KeySelector<Self::Key>,
|
||||
selector: S,
|
||||
domain: Self::Domain,
|
||||
) -> Result<Self::Key, Self::Error>;
|
||||
) -> Result<Self::Key, Self::Error>
|
||||
where
|
||||
S: IntoSelector<Self::Key, Self::Domain>;
|
||||
|
||||
async fn remove_domain_from_key(
|
||||
async fn remove_domain_from_key<S>(
|
||||
&self,
|
||||
key: KeySelector<Self::Key>,
|
||||
selector: S,
|
||||
domain: Self::Domain,
|
||||
) -> Result<Self::Key, Self::Error>;
|
||||
) -> Result<Self::Key, Self::Error>
|
||||
where
|
||||
S: IntoSelector<Self::Key, Self::Domain>;
|
||||
|
||||
async fn set_domains_for_key(
|
||||
async fn set_domains_for_key<S>(
|
||||
&self,
|
||||
key: KeySelector<Self::Key>,
|
||||
selector: S,
|
||||
domains: Vec<Self::Domain>,
|
||||
) -> Result<Self::Key, Self::Error>;
|
||||
) -> Result<Self::Key, Self::Error>
|
||||
where
|
||||
S: IntoSelector<Self::Key, Self::Domain>;
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
|
|
|
@ -1,9 +1,9 @@
|
|||
use async_trait::async_trait;
|
||||
use indoc::indoc;
|
||||
use sqlx::{FromRow, PgPool};
|
||||
use sqlx::{FromRow, PgPool, Postgres, QueryBuilder};
|
||||
use thiserror::Error;
|
||||
|
||||
use crate::{ApiKey, KeyDomain, KeyPoolStorage, KeySelector};
|
||||
use crate::{ApiKey, IntoSelector, KeyDomain, KeyPoolStorage, KeySelector};
|
||||
|
||||
pub trait PgKeyDomain:
|
||||
KeyDomain + serde::Serialize + serde::de::DeserializeOwned + Eq + Unpin
|
||||
|
@ -24,10 +24,10 @@ where
|
|||
Pg(#[from] sqlx::Error),
|
||||
|
||||
#[error("No key avalaible for domain {0:?}")]
|
||||
Unavailable(D),
|
||||
Unavailable(KeySelector<PgKey<D>, D>),
|
||||
|
||||
#[error("Key not found: '{0:?}'")]
|
||||
KeyNotFound(KeySelector<PgKey<D>>),
|
||||
KeyNotFound(KeySelector<PgKey<D>, D>),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, FromRow)]
|
||||
|
@ -42,6 +42,41 @@ where
|
|||
pub domains: sqlx::types::Json<Vec<D>>,
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn build_predicate<'b, D>(
|
||||
builder: &mut QueryBuilder<'b, Postgres>,
|
||||
selector: &'b KeySelector<PgKey<D>, D>,
|
||||
) where
|
||||
D: PgKeyDomain,
|
||||
{
|
||||
match selector {
|
||||
KeySelector::Id(id) => builder.push("id=").push_bind(id),
|
||||
KeySelector::UserId(user_id) => builder.push("user_id=").push_bind(user_id),
|
||||
KeySelector::Key(key) => builder.push("key=").push_bind(key),
|
||||
KeySelector::Has(domain) => builder
|
||||
.push("domains @> ")
|
||||
.push_bind(sqlx::types::Json(vec![domain])),
|
||||
KeySelector::OneOf(domains) => {
|
||||
if domains.is_empty() {
|
||||
builder.push("false");
|
||||
return;
|
||||
}
|
||||
|
||||
for (idx, domain) in domains.iter().enumerate() {
|
||||
if idx == 0 {
|
||||
builder.push("(");
|
||||
} else {
|
||||
builder.push(" or ");
|
||||
}
|
||||
builder
|
||||
.push("domains @> ")
|
||||
.push_bind(sqlx::types::Json(vec![domain]));
|
||||
}
|
||||
builder.push(")")
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, FromRow)]
|
||||
pub struct PgKeyPoolStorage<D>
|
||||
where
|
||||
|
@ -160,7 +195,11 @@ where
|
|||
|
||||
type Error = PgStorageError<D>;
|
||||
|
||||
async fn acquire_key(&self, domain: D) -> Result<Self::Key, Self::Error> {
|
||||
async fn acquire_key<S>(&self, selector: S) -> Result<Self::Key, Self::Error>
|
||||
where
|
||||
S: IntoSelector<Self::Key, Self::Domain>,
|
||||
{
|
||||
let selector = selector.into_selector();
|
||||
loop {
|
||||
let attempt = async {
|
||||
let mut tx = self.pool.begin().await?;
|
||||
|
@ -169,8 +208,7 @@ where
|
|||
.execute(&mut tx)
|
||||
.await?;
|
||||
|
||||
// TODO: improve query
|
||||
let key = sqlx::query_as(&indoc::formatdoc!(
|
||||
let mut qb = QueryBuilder::new(indoc::indoc! {
|
||||
r#"
|
||||
with key as (
|
||||
select
|
||||
|
@ -178,13 +216,25 @@ where
|
|||
0::int2 as uses
|
||||
from api_keys where last_used < date_trunc('minute', now())
|
||||
and (cooldown is null or now() >= cooldown)
|
||||
and domains @> $1
|
||||
union (
|
||||
and "#
|
||||
});
|
||||
|
||||
build_predicate(&mut qb, &selector);
|
||||
|
||||
qb.push(indoc::indoc! {
|
||||
"
|
||||
\n union (
|
||||
select id, uses from api_keys
|
||||
where last_used >= date_trunc('minute', now())
|
||||
and (cooldown is null or now() >= cooldown)
|
||||
and domains @> $1
|
||||
order by uses asc limit 1
|
||||
and "
|
||||
});
|
||||
|
||||
build_predicate(&mut qb, &selector);
|
||||
|
||||
qb.push(indoc::indoc! {
|
||||
"
|
||||
\n order by uses asc limit 1
|
||||
)
|
||||
order by uses asc limit 1
|
||||
)
|
||||
|
@ -194,19 +244,21 @@ where
|
|||
flag = null,
|
||||
last_used = now()
|
||||
from key where
|
||||
api_keys.id=key.id and key.uses < $2
|
||||
returning
|
||||
api_keys.id=key.id and key.uses < "
|
||||
});
|
||||
|
||||
qb.push_bind(self.limit);
|
||||
|
||||
qb.push(indoc::indoc! { "
|
||||
\nreturning
|
||||
api_keys.id,
|
||||
api_keys.user_id,
|
||||
api_keys.key,
|
||||
api_keys.uses,
|
||||
api_keys.domains
|
||||
"#,
|
||||
))
|
||||
.bind(sqlx::types::Json(vec![&domain]))
|
||||
.bind(self.limit)
|
||||
.fetch_optional(&mut tx)
|
||||
.await?;
|
||||
api_keys.domains"
|
||||
});
|
||||
|
||||
let key = qb.build_query_as().fetch_optional(&mut tx).await?;
|
||||
|
||||
tx.commit().await?;
|
||||
|
||||
|
@ -219,9 +271,9 @@ where
|
|||
Ok(None) => {
|
||||
return self
|
||||
.acquire_key(
|
||||
domain
|
||||
selector
|
||||
.fallback()
|
||||
.ok_or_else(|| PgStorageError::Unavailable(domain))?,
|
||||
.ok_or_else(|| PgStorageError::Unavailable(selector))?,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
@ -241,11 +293,15 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
async fn acquire_many_keys(
|
||||
async fn acquire_many_keys<S>(
|
||||
&self,
|
||||
domain: D,
|
||||
selector: S,
|
||||
number: i64,
|
||||
) -> Result<Vec<Self::Key>, Self::Error> {
|
||||
) -> Result<Vec<Self::Key>, Self::Error>
|
||||
where
|
||||
S: IntoSelector<Self::Key, Self::Domain>,
|
||||
{
|
||||
let selector = selector.into_selector();
|
||||
loop {
|
||||
let attempt = async {
|
||||
let mut tx = self.pool.begin().await?;
|
||||
|
@ -254,7 +310,7 @@ where
|
|||
.execute(&mut tx)
|
||||
.await?;
|
||||
|
||||
let mut keys: Vec<Self::Key> = sqlx::query_as(&indoc::formatdoc!(
|
||||
let mut qb = QueryBuilder::new(indoc::indoc! {
|
||||
r#"select
|
||||
id,
|
||||
user_id,
|
||||
|
@ -263,8 +319,12 @@ where
|
|||
domains
|
||||
from api_keys where last_used < date_trunc('minute', now())
|
||||
and (cooldown is null or now() >= cooldown)
|
||||
and domains @> $1
|
||||
union
|
||||
and "#
|
||||
});
|
||||
build_predicate(&mut qb, &selector);
|
||||
qb.push(indoc::indoc! {
|
||||
"
|
||||
\nunion
|
||||
select
|
||||
id,
|
||||
user_id,
|
||||
|
@ -273,14 +333,13 @@ where
|
|||
domains
|
||||
from api_keys where last_used >= date_trunc('minute', now())
|
||||
and (cooldown is null or now() >= cooldown)
|
||||
and domains @> $1
|
||||
order by uses limit $2
|
||||
"#,
|
||||
))
|
||||
.bind(sqlx::types::Json(vec![&domain]))
|
||||
.bind(number)
|
||||
.fetch_all(&mut tx)
|
||||
.await?;
|
||||
and "
|
||||
});
|
||||
build_predicate(&mut qb, &selector);
|
||||
qb.push("\norder by uses limit ");
|
||||
qb.push_bind(self.limit);
|
||||
|
||||
let mut keys: Vec<Self::Key> = qb.build_query_as().fetch_all(&mut tx).await?;
|
||||
|
||||
if keys.is_empty() {
|
||||
tx.commit().await?;
|
||||
|
@ -338,9 +397,9 @@ where
|
|||
Ok(None) => {
|
||||
return self
|
||||
.acquire_many_keys(
|
||||
domain
|
||||
selector
|
||||
.fallback()
|
||||
.ok_or_else(|| Self::Error::Unavailable(domain))?,
|
||||
.ok_or_else(|| Self::Error::Unavailable(selector))?,
|
||||
number,
|
||||
)
|
||||
.await
|
||||
|
@ -433,143 +492,116 @@ where
|
|||
.map_err(Into::into)
|
||||
}
|
||||
|
||||
async fn read_key(
|
||||
&self,
|
||||
selector: KeySelector<Self::Key>,
|
||||
) -> Result<Option<Self::Key>, Self::Error> {
|
||||
match &selector {
|
||||
KeySelector::Key(key) => sqlx::query_as("select * from api_keys where key=$1")
|
||||
.bind(key)
|
||||
.fetch_optional(&self.pool)
|
||||
.await
|
||||
.map_err(Into::into),
|
||||
KeySelector::Id(id) => sqlx::query_as("select * from api_keys where id=$1")
|
||||
.bind(id)
|
||||
.fetch_optional(&self.pool)
|
||||
.await
|
||||
.map_err(Into::into),
|
||||
}
|
||||
}
|
||||
async fn read_key<S>(&self, selector: S) -> Result<Option<Self::Key>, Self::Error>
|
||||
where
|
||||
S: IntoSelector<Self::Key, Self::Domain>,
|
||||
{
|
||||
let selector = selector.into_selector();
|
||||
|
||||
async fn query_key(&self, domain: D) -> Result<Option<Self::Key>, Self::Error> {
|
||||
sqlx::query_as("select * from api_keys where domains @> $1 limit 1")
|
||||
.bind(sqlx::types::Json(vec![domain]))
|
||||
let mut qb = QueryBuilder::new("select * from api_keys where ");
|
||||
build_predicate(&mut qb, &selector);
|
||||
|
||||
qb.build_query_as()
|
||||
.fetch_optional(&self.pool)
|
||||
.await
|
||||
.map_err(Into::into)
|
||||
}
|
||||
|
||||
async fn query_all(&self, domain: D) -> Result<Vec<Self::Key>, Self::Error> {
|
||||
sqlx::query_as("select * from api_keys where domains @> $1")
|
||||
.bind(sqlx::types::Json(vec![domain]))
|
||||
async fn read_keys<S>(&self, selector: S) -> Result<Vec<Self::Key>, Self::Error>
|
||||
where
|
||||
S: IntoSelector<Self::Key, Self::Domain>,
|
||||
{
|
||||
let selector = selector.into_selector();
|
||||
|
||||
let mut qb = QueryBuilder::new("select * from api_keys where ");
|
||||
build_predicate(&mut qb, &selector);
|
||||
|
||||
qb.build_query_as()
|
||||
.fetch_all(&self.pool)
|
||||
.await
|
||||
.map_err(Into::into)
|
||||
}
|
||||
|
||||
async fn read_user_keys(&self, user_id: i32) -> Result<Vec<Self::Key>, Self::Error> {
|
||||
sqlx::query_as("select * from api_keys where user_id=$1")
|
||||
.bind(user_id)
|
||||
.fetch_all(&self.pool)
|
||||
.await
|
||||
.map_err(Into::into)
|
||||
}
|
||||
async fn remove_key<S>(&self, selector: S) -> Result<Self::Key, Self::Error>
|
||||
where
|
||||
S: IntoSelector<Self::Key, Self::Domain>,
|
||||
{
|
||||
let selector = selector.into_selector();
|
||||
|
||||
async fn remove_key(&self, selector: KeySelector<Self::Key>) -> Result<Self::Key, Self::Error> {
|
||||
match &selector {
|
||||
KeySelector::Key(key) => {
|
||||
sqlx::query_as("delete from api_keys where key=$1 returning *")
|
||||
.bind(key)
|
||||
let mut qb = QueryBuilder::new("delete from api_keys where ");
|
||||
build_predicate(&mut qb, &selector);
|
||||
qb.push(" returning *");
|
||||
|
||||
qb.build_query_as()
|
||||
.fetch_optional(&self.pool)
|
||||
.await?
|
||||
.ok_or_else(|| PgStorageError::KeyNotFound(selector))
|
||||
}
|
||||
KeySelector::Id(id) => sqlx::query_as("delete from api_keys where id=$1 returning *")
|
||||
.bind(id)
|
||||
|
||||
async fn add_domain_to_key<S>(&self, selector: S, domain: D) -> Result<Self::Key, Self::Error>
|
||||
where
|
||||
S: IntoSelector<Self::Key, Self::Domain>,
|
||||
{
|
||||
let selector = selector.into_selector();
|
||||
|
||||
let mut qb = QueryBuilder::new(
|
||||
"update api_keys set domains = __unique_jsonb_array(domains || jsonb_build_array(",
|
||||
);
|
||||
qb.push_bind(sqlx::types::Json(domain));
|
||||
qb.push(")) where ");
|
||||
build_predicate(&mut qb, &selector);
|
||||
qb.push(" returning *");
|
||||
|
||||
qb.build_query_as()
|
||||
.fetch_optional(&self.pool)
|
||||
.await?
|
||||
.ok_or_else(|| PgStorageError::KeyNotFound(selector)),
|
||||
}
|
||||
.ok_or_else(|| PgStorageError::KeyNotFound(selector))
|
||||
}
|
||||
|
||||
async fn add_domain_to_key(
|
||||
async fn remove_domain_from_key<S>(
|
||||
&self,
|
||||
selector: KeySelector<Self::Key>,
|
||||
selector: S,
|
||||
domain: D,
|
||||
) -> Result<Self::Key, Self::Error> {
|
||||
match &selector {
|
||||
KeySelector::Key(key) => sqlx::query_as::<sqlx::Postgres, PgKey<D>>(
|
||||
"update api_keys set domains = __unique_jsonb_array(domains || \
|
||||
jsonb_build_array($1)) where key=$2 returning *",
|
||||
)
|
||||
.bind(sqlx::types::Json(domain))
|
||||
.bind(key)
|
||||
) -> Result<Self::Key, Self::Error>
|
||||
where
|
||||
S: IntoSelector<Self::Key, Self::Domain>,
|
||||
{
|
||||
let selector = selector.into_selector();
|
||||
|
||||
let mut qb = QueryBuilder::new(
|
||||
"update api_keys set domains = coalesce(__filter_jsonb_array(domains, ",
|
||||
);
|
||||
qb.push_bind(sqlx::types::Json(domain));
|
||||
qb.push("), '[]'::jsonb) where ");
|
||||
build_predicate(&mut qb, &selector);
|
||||
qb.push(" returning *");
|
||||
|
||||
qb.build_query_as()
|
||||
.fetch_optional(&self.pool)
|
||||
.await?
|
||||
.ok_or_else(|| PgStorageError::KeyNotFound(selector)),
|
||||
KeySelector::Id(id) => sqlx::query_as::<sqlx::Postgres, PgKey<D>>(
|
||||
"update api_keys set domains = __unique_jsonb_array(domains || \
|
||||
jsonb_build_array($1)) where id=$2 returning *",
|
||||
)
|
||||
.bind(sqlx::types::Json(domain))
|
||||
.bind(id)
|
||||
.fetch_optional(&self.pool)
|
||||
.await?
|
||||
.ok_or_else(|| PgStorageError::KeyNotFound(selector)),
|
||||
}
|
||||
.ok_or_else(|| PgStorageError::KeyNotFound(selector))
|
||||
}
|
||||
|
||||
async fn remove_domain_from_key(
|
||||
async fn set_domains_for_key<S>(
|
||||
&self,
|
||||
selector: KeySelector<Self::Key>,
|
||||
domain: D,
|
||||
) -> Result<Self::Key, Self::Error> {
|
||||
match &selector {
|
||||
KeySelector::Key(key) => sqlx::query_as(
|
||||
"update api_keys set domains = coalesce(__filter_jsonb_array(domains, $1), \
|
||||
'[]'::jsonb) where key=$2 returning *",
|
||||
)
|
||||
.bind(sqlx::types::Json(domain))
|
||||
.bind(key)
|
||||
.fetch_optional(&self.pool)
|
||||
.await?
|
||||
.ok_or_else(|| PgStorageError::KeyNotFound(selector)),
|
||||
KeySelector::Id(id) => sqlx::query_as(
|
||||
"update api_keys set domains = coalesce(__filter_jsonb_array(domains, $1), \
|
||||
'[]'::jsonb) where id=$2 returning *",
|
||||
)
|
||||
.bind(sqlx::types::Json(domain))
|
||||
.bind(id)
|
||||
.fetch_optional(&self.pool)
|
||||
.await?
|
||||
.ok_or_else(|| PgStorageError::KeyNotFound(selector)),
|
||||
}
|
||||
}
|
||||
|
||||
async fn set_domains_for_key(
|
||||
&self,
|
||||
selector: KeySelector<Self::Key>,
|
||||
selector: S,
|
||||
domains: Vec<D>,
|
||||
) -> Result<Self::Key, Self::Error> {
|
||||
match &selector {
|
||||
KeySelector::Key(key) => sqlx::query_as::<sqlx::Postgres, PgKey<D>>(
|
||||
"update api_keys set domains = $1 where key=$2 returning *",
|
||||
)
|
||||
.bind(sqlx::types::Json(domains))
|
||||
.bind(key)
|
||||
.fetch_optional(&self.pool)
|
||||
.await?
|
||||
.ok_or_else(|| PgStorageError::KeyNotFound(selector)),
|
||||
) -> Result<Self::Key, Self::Error>
|
||||
where
|
||||
S: IntoSelector<Self::Key, Self::Domain>,
|
||||
{
|
||||
let selector = selector.into_selector();
|
||||
|
||||
KeySelector::Id(id) => sqlx::query_as::<sqlx::Postgres, PgKey<D>>(
|
||||
"update api_keys set domains = $1 where id=$2 returning *",
|
||||
)
|
||||
.bind(sqlx::types::Json(domains))
|
||||
.bind(id)
|
||||
let mut qb = QueryBuilder::new("update api_keys set domains = ");
|
||||
qb.push_bind(sqlx::types::Json(domains));
|
||||
qb.push(" where ");
|
||||
build_predicate(&mut qb, &selector);
|
||||
qb.push(" returning *");
|
||||
|
||||
qb.build_query_as()
|
||||
.fetch_optional(&self.pool)
|
||||
.await?
|
||||
.ok_or_else(|| PgStorageError::KeyNotFound(selector)),
|
||||
}
|
||||
.ok_or_else(|| PgStorageError::KeyNotFound(selector))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -752,7 +784,7 @@ pub(crate) mod test {
|
|||
async fn test_read_user_keys() {
|
||||
let (storage, _) = setup().await;
|
||||
|
||||
let keys = storage.read_user_keys(1).await.unwrap();
|
||||
let keys = storage.read_keys(KeySelector::UserId(1)).await.unwrap();
|
||||
assert_eq!(keys.len(), 1);
|
||||
}
|
||||
|
||||
|
@ -777,7 +809,7 @@ pub(crate) mod test {
|
|||
_ = storage.acquire_key(Domain::All).await.unwrap();
|
||||
}
|
||||
|
||||
let keys = storage.read_user_keys(1).await.unwrap();
|
||||
let keys = storage.read_keys(KeySelector::UserId(1)).await.unwrap();
|
||||
assert_eq!(keys.len(), 2);
|
||||
for key in keys {
|
||||
assert_eq!(key.uses, 5);
|
||||
|
@ -791,7 +823,7 @@ pub(crate) mod test {
|
|||
assert!(storage.flag_key(key, 2).await.unwrap());
|
||||
|
||||
match storage.acquire_key(Domain::All).await.unwrap_err() {
|
||||
PgStorageError::Unavailable(d) => assert_eq!(d, Domain::All),
|
||||
PgStorageError::Unavailable(d) => assert!(matches!(d, KeySelector::Has(Domain::All))),
|
||||
why => panic!("Expected domain unavailable error but found '{why}'"),
|
||||
}
|
||||
}
|
||||
|
@ -803,7 +835,7 @@ pub(crate) mod test {
|
|||
assert!(storage.flag_key(key, 2).await.unwrap());
|
||||
|
||||
match storage.acquire_many_keys(Domain::All, 5).await.unwrap_err() {
|
||||
PgStorageError::Unavailable(d) => assert_eq!(d, Domain::All),
|
||||
PgStorageError::Unavailable(d) => assert!(matches!(d, KeySelector::Has(Domain::All))),
|
||||
why => panic!("Expected domain unavailable error but found '{why}'"),
|
||||
}
|
||||
}
|
||||
|
@ -877,7 +909,7 @@ pub(crate) mod test {
|
|||
set.join_next().await.unwrap().unwrap();
|
||||
}
|
||||
|
||||
let keys = storage.read_user_keys(1).await.unwrap();
|
||||
let keys = storage.read_keys(KeySelector::UserId(1)).await.unwrap();
|
||||
|
||||
assert_eq!(keys.len(), 25);
|
||||
|
||||
|
@ -952,7 +984,7 @@ pub(crate) mod test {
|
|||
async fn query_key() {
|
||||
let (storage, _) = setup().await;
|
||||
|
||||
let key = storage.query_key(Domain::All).await.unwrap();
|
||||
let key = storage.read_key(Domain::All).await.unwrap();
|
||||
assert!(key.is_some());
|
||||
}
|
||||
|
||||
|
@ -960,7 +992,7 @@ pub(crate) mod test {
|
|||
async fn query_nonexistent_key() {
|
||||
let (storage, _) = setup().await;
|
||||
|
||||
let key = storage.query_key(Domain::Guild { id: 0 }).await.unwrap();
|
||||
let key = storage.read_key(Domain::Guild { id: 0 }).await.unwrap();
|
||||
assert!(key.is_none());
|
||||
}
|
||||
|
||||
|
@ -968,7 +1000,38 @@ pub(crate) mod test {
|
|||
async fn query_all() {
|
||||
let (storage, _) = setup().await;
|
||||
|
||||
let keys = storage.query_all(Domain::All).await.unwrap();
|
||||
let keys = storage.read_keys(Domain::All).await.unwrap();
|
||||
assert!(keys.len() == 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
async fn query_by_id() {
|
||||
let (storage, _) = setup().await;
|
||||
let key = storage.read_key(KeySelector::Id(1)).await.unwrap();
|
||||
|
||||
assert!(key.is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
async fn query_by_key() {
|
||||
let (storage, key) = setup().await;
|
||||
let key = storage.read_key(KeySelector::Key(key.key)).await.unwrap();
|
||||
|
||||
assert!(key.is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
async fn query_by_set() {
|
||||
let (storage, _key) = setup().await;
|
||||
let key = storage
|
||||
.read_key(KeySelector::OneOf(vec![
|
||||
Domain::All,
|
||||
Domain::Guild { id: 0 },
|
||||
Domain::Faction { id: 0 },
|
||||
]))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(key.is_some());
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue