changed key storage interface
This commit is contained in:
parent
cff93d4c3d
commit
91bfb08652
|
@ -1 +1,2 @@
|
||||||
edition = "2021"
|
edition = "2021"
|
||||||
|
format_strings = true
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
[package]
|
[package]
|
||||||
name = "torn-key-pool"
|
name = "torn-key-pool"
|
||||||
version = "0.5.2"
|
version = "0.5.3"
|
||||||
edition = "2021"
|
edition = "2021"
|
||||||
authors = ["Pyrit [2111649]"]
|
authors = ["Pyrit [2111649]"]
|
||||||
license = "MIT"
|
license = "MIT"
|
||||||
|
|
|
@ -30,7 +30,11 @@ where
|
||||||
}
|
}
|
||||||
|
|
||||||
pub trait ApiKey: Sync + Send {
|
pub trait ApiKey: Sync + Send {
|
||||||
|
type IdType: PartialEq + Eq + std::hash::Hash;
|
||||||
|
|
||||||
fn value(&self) -> &str;
|
fn value(&self) -> &str;
|
||||||
|
|
||||||
|
fn id(&self) -> Self::IdType;
|
||||||
}
|
}
|
||||||
|
|
||||||
pub trait KeyDomain: Clone + std::fmt::Debug + Send + Sync {
|
pub trait KeyDomain: Clone + std::fmt::Debug + Send + Sync {
|
||||||
|
@ -39,6 +43,15 @@ pub trait KeyDomain: Clone + std::fmt::Debug + Send + Sync {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub enum KeySelector<K>
|
||||||
|
where
|
||||||
|
K: ApiKey,
|
||||||
|
{
|
||||||
|
Key(String),
|
||||||
|
Id(K::IdType),
|
||||||
|
}
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
pub trait KeyPoolStorage {
|
pub trait KeyPoolStorage {
|
||||||
type Key: ApiKey;
|
type Key: ApiKey;
|
||||||
|
@ -62,27 +75,27 @@ pub trait KeyPoolStorage {
|
||||||
domains: Vec<Self::Domain>,
|
domains: Vec<Self::Domain>,
|
||||||
) -> Result<Self::Key, Self::Error>;
|
) -> Result<Self::Key, Self::Error>;
|
||||||
|
|
||||||
async fn read_key(&self, key: String) -> Result<Self::Key, Self::Error>;
|
async fn read_key(&self, key: KeySelector<Self::Key>) -> Result<Self::Key, Self::Error>;
|
||||||
|
|
||||||
async fn read_user_keys(&self, user_id: i32) -> Result<Vec<Self::Key>, Self::Error>;
|
async fn read_user_keys(&self, user_id: i32) -> Result<Vec<Self::Key>, Self::Error>;
|
||||||
|
|
||||||
async fn remove_key(&self, key: String) -> Result<Self::Key, Self::Error>;
|
async fn remove_key(&self, key: KeySelector<Self::Key>) -> Result<Self::Key, Self::Error>;
|
||||||
|
|
||||||
async fn add_domain_to_key(
|
async fn add_domain_to_key(
|
||||||
&self,
|
&self,
|
||||||
key: String,
|
key: KeySelector<Self::Key>,
|
||||||
domain: Self::Domain,
|
domain: Self::Domain,
|
||||||
) -> Result<Self::Key, Self::Error>;
|
) -> Result<Self::Key, Self::Error>;
|
||||||
|
|
||||||
async fn remove_domain_from_key(
|
async fn remove_domain_from_key(
|
||||||
&self,
|
&self,
|
||||||
key: String,
|
key: KeySelector<Self::Key>,
|
||||||
domain: Self::Domain,
|
domain: Self::Domain,
|
||||||
) -> Result<Self::Key, Self::Error>;
|
) -> Result<Self::Key, Self::Error>;
|
||||||
|
|
||||||
async fn set_domains_for_key(
|
async fn set_domains_for_key(
|
||||||
&self,
|
&self,
|
||||||
key: String,
|
key: KeySelector<Self::Key>,
|
||||||
domains: Vec<Self::Domain>,
|
domains: Vec<Self::Domain>,
|
||||||
) -> Result<Self::Key, Self::Error>;
|
) -> Result<Self::Key, Self::Error>;
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,7 +3,7 @@ use indoc::indoc;
|
||||||
use sqlx::{FromRow, PgPool};
|
use sqlx::{FromRow, PgPool};
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
|
|
||||||
use crate::{ApiKey, KeyDomain, KeyPoolStorage};
|
use crate::{ApiKey, KeyDomain, KeyPoolStorage, KeySelector};
|
||||||
|
|
||||||
pub trait PgKeyDomain:
|
pub trait PgKeyDomain:
|
||||||
KeyDomain + serde::Serialize + serde::de::DeserializeOwned + Eq + Unpin
|
KeyDomain + serde::Serialize + serde::de::DeserializeOwned + Eq + Unpin
|
||||||
|
@ -18,7 +18,7 @@ impl<T> PgKeyDomain for T where
|
||||||
#[derive(Debug, Error)]
|
#[derive(Debug, Error)]
|
||||||
pub enum PgStorageError<D>
|
pub enum PgStorageError<D>
|
||||||
where
|
where
|
||||||
D: std::fmt::Debug,
|
D: PgKeyDomain,
|
||||||
{
|
{
|
||||||
#[error(transparent)]
|
#[error(transparent)]
|
||||||
Pg(#[from] sqlx::Error),
|
Pg(#[from] sqlx::Error),
|
||||||
|
@ -26,14 +26,8 @@ where
|
||||||
#[error("No key avalaible for domain {0:?}")]
|
#[error("No key avalaible for domain {0:?}")]
|
||||||
Unavailable(D),
|
Unavailable(D),
|
||||||
|
|
||||||
#[error("Duplicate key '{0}'")]
|
#[error("Key not found: '{0:?}'")]
|
||||||
DuplicateKey(String),
|
KeyNotFound(KeySelector<PgKey<D>>),
|
||||||
|
|
||||||
#[error("Duplicate domain '{0:?}'")]
|
|
||||||
DuplicateDomain(D),
|
|
||||||
|
|
||||||
#[error("Key not found: '{0}'")]
|
|
||||||
KeyNotFound(String),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, FromRow)]
|
#[derive(Debug, Clone, FromRow)]
|
||||||
|
@ -62,9 +56,17 @@ impl<D> ApiKey for PgKey<D>
|
||||||
where
|
where
|
||||||
D: PgKeyDomain,
|
D: PgKeyDomain,
|
||||||
{
|
{
|
||||||
|
type IdType = i32;
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
fn value(&self) -> &str {
|
fn value(&self) -> &str {
|
||||||
&self.key
|
&self.key
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
fn id(&self) -> Self::IdType {
|
||||||
|
self.id
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<D> PgKeyPoolStorage<D>
|
impl<D> PgKeyPoolStorage<D>
|
||||||
|
@ -90,7 +92,7 @@ where
|
||||||
last_used timestamptz not null default now(),
|
last_used timestamptz not null default now(),
|
||||||
flag int2,
|
flag int2,
|
||||||
cooldown timestamptz,
|
cooldown timestamptz,
|
||||||
constraint "uq:api_keys.key+user_id" UNIQUE(user_id, key)
|
constraint "uq:api_keys.key" UNIQUE(key)
|
||||||
)"#
|
)"#
|
||||||
})
|
})
|
||||||
.execute(&self.pool)
|
.execute(&self.pool)
|
||||||
|
@ -108,6 +110,28 @@ where
|
||||||
.execute(&self.pool)
|
.execute(&self.pool)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
|
sqlx::query(indoc! {r#"
|
||||||
|
create or replace function __unique_jsonb_array(jsonb) returns jsonb
|
||||||
|
AS $$
|
||||||
|
select jsonb_agg(d::jsonb) from (
|
||||||
|
select distinct jsonb_array_elements_text($1) as d
|
||||||
|
) t
|
||||||
|
$$ language sql;
|
||||||
|
"#})
|
||||||
|
.execute(&self.pool)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
sqlx::query(indoc! {r#"
|
||||||
|
create or replace function __filter_jsonb_array(jsonb, jsonb) returns jsonb
|
||||||
|
AS $$
|
||||||
|
select jsonb_agg(d::jsonb) from (
|
||||||
|
select distinct jsonb_array_elements_text($1) as d
|
||||||
|
) t where d<>$2::text
|
||||||
|
$$ language sql;
|
||||||
|
"#})
|
||||||
|
.execute(&self.pool)
|
||||||
|
.await?;
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -351,10 +375,13 @@ where
|
||||||
}
|
}
|
||||||
5 => {
|
5 => {
|
||||||
// too many requests
|
// too many requests
|
||||||
sqlx::query("update api_keys set cooldown=date_trunc('min', now()) + interval '1 min', flag=5 where id=$1")
|
sqlx::query(
|
||||||
.bind(key.id)
|
"update api_keys set cooldown=date_trunc('min', now()) + interval '1 min', \
|
||||||
.execute(&self.pool)
|
flag=5 where id=$1",
|
||||||
.await?;
|
)
|
||||||
|
.bind(key.id)
|
||||||
|
.execute(&self.pool)
|
||||||
|
.await?;
|
||||||
Ok(true)
|
Ok(true)
|
||||||
}
|
}
|
||||||
8 => {
|
8 => {
|
||||||
|
@ -373,10 +400,13 @@ where
|
||||||
}
|
}
|
||||||
14 => {
|
14 => {
|
||||||
// daily read limit reached
|
// daily read limit reached
|
||||||
sqlx::query("update api_keys set cooldown=date_trunc('day', now()) + interval '1 day', flag=14 where id=$1")
|
sqlx::query(
|
||||||
.bind(key.id)
|
"update api_keys set cooldown=date_trunc('day', now()) + interval '1 day', \
|
||||||
.execute(&self.pool)
|
flag=14 where id=$1",
|
||||||
.await?;
|
)
|
||||||
|
.bind(key.id)
|
||||||
|
.execute(&self.pool)
|
||||||
|
.await?;
|
||||||
Ok(true)
|
Ok(true)
|
||||||
}
|
}
|
||||||
_ => Ok(false),
|
_ => Ok(false),
|
||||||
|
@ -390,30 +420,31 @@ where
|
||||||
domains: Vec<D>,
|
domains: Vec<D>,
|
||||||
) -> Result<Self::Key, Self::Error> {
|
) -> Result<Self::Key, Self::Error> {
|
||||||
sqlx::query_as(
|
sqlx::query_as(
|
||||||
"insert into api_keys(user_id, key, domains) values ($1, $2, $3) returning *",
|
"insert into api_keys(user_id, key, domains) values ($1, $2, $3) on conflict on \
|
||||||
|
constraint \"uq:api_keys.key\" do update set domains = \
|
||||||
|
__unique_jsonb_array(excluded.domains || api_keys.domains) returning *",
|
||||||
)
|
)
|
||||||
.bind(user_id)
|
.bind(user_id)
|
||||||
.bind(&key)
|
.bind(&key)
|
||||||
.bind(sqlx::types::Json(domains))
|
.bind(sqlx::types::Json(domains))
|
||||||
.fetch_one(&self.pool)
|
.fetch_one(&self.pool)
|
||||||
.await
|
.await
|
||||||
.map_err(|why| {
|
.map_err(Into::into)
|
||||||
if let Some(error) = why.as_database_error() {
|
|
||||||
let pg_error: &sqlx::postgres::PgDatabaseError = error.downcast_ref();
|
|
||||||
if pg_error.code() == "23505" {
|
|
||||||
return PgStorageError::DuplicateKey(key);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
PgStorageError::Pg(why)
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn read_key(&self, key: String) -> Result<Self::Key, Self::Error> {
|
async fn read_key(&self, selector: KeySelector<Self::Key>) -> Result<Self::Key, Self::Error> {
|
||||||
sqlx::query_as("select * from api_keys where key=$1")
|
match &selector {
|
||||||
.bind(&key)
|
KeySelector::Key(key) => sqlx::query_as("select * from api_keys where key=$1")
|
||||||
.fetch_optional(&self.pool)
|
.bind(key)
|
||||||
.await?
|
.fetch_optional(&self.pool)
|
||||||
.ok_or_else(|| PgStorageError::KeyNotFound(key))
|
.await?
|
||||||
|
.ok_or_else(|| PgStorageError::KeyNotFound(selector)),
|
||||||
|
KeySelector::Id(id) => sqlx::query_as("select * from api_keys where id=$1")
|
||||||
|
.bind(id)
|
||||||
|
.fetch_optional(&self.pool)
|
||||||
|
.await?
|
||||||
|
.ok_or_else(|| PgStorageError::KeyNotFound(selector)),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn read_user_keys(&self, user_id: i32) -> Result<Vec<Self::Key>, Self::Error> {
|
async fn read_user_keys(&self, user_id: i32) -> Result<Vec<Self::Key>, Self::Error> {
|
||||||
|
@ -424,66 +455,101 @@ where
|
||||||
.map_err(Into::into)
|
.map_err(Into::into)
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn remove_key(&self, key: String) -> Result<Self::Key, Self::Error> {
|
async fn remove_key(&self, selector: KeySelector<Self::Key>) -> Result<Self::Key, Self::Error> {
|
||||||
sqlx::query_as("delete from api_keys where key=$1 returning *")
|
match &selector {
|
||||||
.bind(&key)
|
KeySelector::Key(key) => {
|
||||||
.fetch_optional(&self.pool)
|
sqlx::query_as("delete from api_keys where key=$1 returning *")
|
||||||
.await?
|
.bind(key)
|
||||||
.ok_or_else(|| PgStorageError::KeyNotFound(key))
|
.fetch_optional(&self.pool)
|
||||||
|
.await?
|
||||||
|
.ok_or_else(|| PgStorageError::KeyNotFound(selector))
|
||||||
|
}
|
||||||
|
KeySelector::Id(id) => sqlx::query_as("delete from api_keys where id=$1 returning *")
|
||||||
|
.bind(id)
|
||||||
|
.fetch_optional(&self.pool)
|
||||||
|
.await?
|
||||||
|
.ok_or_else(|| PgStorageError::KeyNotFound(selector)),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn add_domain_to_key(&self, key: String, domain: D) -> Result<Self::Key, Self::Error> {
|
async fn add_domain_to_key(
|
||||||
let mut tx = self.pool.begin().await?;
|
&self,
|
||||||
match sqlx::query_as::<sqlx::Postgres, PgKey<D>>(
|
selector: KeySelector<Self::Key>,
|
||||||
"update api_keys set domains = domains || jsonb_build_array($1) where key=$2 returning *",
|
domain: D,
|
||||||
)
|
) -> Result<Self::Key, Self::Error> {
|
||||||
.bind(sqlx::types::Json(domain.clone()))
|
match &selector {
|
||||||
.bind(&key)
|
KeySelector::Key(key) => sqlx::query_as::<sqlx::Postgres, PgKey<D>>(
|
||||||
.fetch_optional(&mut tx)
|
"update api_keys set domains = __unique_jsonb_array(domains || \
|
||||||
.await?
|
jsonb_build_array($1)) where key=$2 returning *",
|
||||||
{
|
)
|
||||||
None => Err(PgStorageError::KeyNotFound(key)),
|
.bind(sqlx::types::Json(domain))
|
||||||
Some(key) => {
|
.bind(key)
|
||||||
if key.domains.0.iter().filter(|d| **d == domain).count() > 1 {
|
.fetch_optional(&self.pool)
|
||||||
tx.rollback().await?;
|
.await?
|
||||||
return Err(PgStorageError::DuplicateDomain(domain));
|
.ok_or_else(|| PgStorageError::KeyNotFound(selector)),
|
||||||
}
|
KeySelector::Id(id) => sqlx::query_as::<sqlx::Postgres, PgKey<D>>(
|
||||||
tx.commit().await?;
|
"update api_keys set domains = __unique_jsonb_array(domains || \
|
||||||
Ok(key)
|
jsonb_build_array($1)) where id=$2 returning *",
|
||||||
}
|
)
|
||||||
|
.bind(sqlx::types::Json(domain))
|
||||||
|
.bind(id)
|
||||||
|
.fetch_optional(&self.pool)
|
||||||
|
.await?
|
||||||
|
.ok_or_else(|| PgStorageError::KeyNotFound(selector)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn remove_domain_from_key(
|
async fn remove_domain_from_key(
|
||||||
&self,
|
&self,
|
||||||
key: String,
|
selector: KeySelector<Self::Key>,
|
||||||
domain: D,
|
domain: D,
|
||||||
) -> Result<Self::Key, Self::Error> {
|
) -> Result<Self::Key, Self::Error> {
|
||||||
// FIX: potential race condition
|
match &selector {
|
||||||
let api_key = self.read_key(key.clone()).await?;
|
KeySelector::Key(key) => sqlx::query_as(
|
||||||
let domains = api_key
|
"update api_keys set domains = coalesce(__filter_jsonb_array(domains, $1), \
|
||||||
.domains
|
'[]'::jsonb) where key=$2 returning *",
|
||||||
.0
|
)
|
||||||
.into_iter()
|
.bind(sqlx::types::Json(domain))
|
||||||
.filter(|d| *d != domain)
|
.bind(key)
|
||||||
.collect();
|
.fetch_optional(&self.pool)
|
||||||
|
.await?
|
||||||
self.set_domains_for_key(key, domains).await
|
.ok_or_else(|| PgStorageError::KeyNotFound(selector)),
|
||||||
|
KeySelector::Id(id) => sqlx::query_as(
|
||||||
|
"update api_keys set domains = coalesce(__filter_jsonb_array(domains, $1), \
|
||||||
|
'[]'::jsonb) where id=$2 returning *",
|
||||||
|
)
|
||||||
|
.bind(sqlx::types::Json(domain))
|
||||||
|
.bind(id)
|
||||||
|
.fetch_optional(&self.pool)
|
||||||
|
.await?
|
||||||
|
.ok_or_else(|| PgStorageError::KeyNotFound(selector)),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn set_domains_for_key(
|
async fn set_domains_for_key(
|
||||||
&self,
|
&self,
|
||||||
key: String,
|
selector: KeySelector<Self::Key>,
|
||||||
domains: Vec<D>,
|
domains: Vec<D>,
|
||||||
) -> Result<Self::Key, Self::Error> {
|
) -> Result<Self::Key, Self::Error> {
|
||||||
sqlx::query_as::<sqlx::Postgres, PgKey<D>>(
|
match &selector {
|
||||||
"update api_keys set domains = $1 where key=$2 returning *",
|
KeySelector::Key(key) => sqlx::query_as::<sqlx::Postgres, PgKey<D>>(
|
||||||
)
|
"update api_keys set domains = $1 where key=$2 returning *",
|
||||||
.bind(sqlx::types::Json(domains))
|
)
|
||||||
.bind(&key)
|
.bind(sqlx::types::Json(domains))
|
||||||
.fetch_optional(&self.pool)
|
.bind(key)
|
||||||
.await?
|
.fetch_optional(&self.pool)
|
||||||
.ok_or_else(|| PgStorageError::KeyNotFound(key))
|
.await?
|
||||||
|
.ok_or_else(|| PgStorageError::KeyNotFound(selector)),
|
||||||
|
|
||||||
|
KeySelector::Id(id) => sqlx::query_as::<sqlx::Postgres, PgKey<D>>(
|
||||||
|
"update api_keys set domains = $1 where id=$2 returning *",
|
||||||
|
)
|
||||||
|
.bind(sqlx::types::Json(domains))
|
||||||
|
.bind(id)
|
||||||
|
.fetch_optional(&self.pool)
|
||||||
|
.await?
|
||||||
|
.ok_or_else(|| PgStorageError::KeyNotFound(selector)),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -516,7 +582,7 @@ pub(crate) mod test {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) async fn setup() -> PgKeyPoolStorage<Domain> {
|
pub(crate) async fn setup() -> (PgKeyPoolStorage<Domain>, PgKey<Domain>) {
|
||||||
INIT.call_once(|| {
|
INIT.call_once(|| {
|
||||||
dotenv::dotenv().ok();
|
dotenv::dotenv().ok();
|
||||||
});
|
});
|
||||||
|
@ -533,17 +599,17 @@ pub(crate) mod test {
|
||||||
let storage = PgKeyPoolStorage::new(pool.clone(), 1000);
|
let storage = PgKeyPoolStorage::new(pool.clone(), 1000);
|
||||||
storage.initialise().await.unwrap();
|
storage.initialise().await.unwrap();
|
||||||
|
|
||||||
storage
|
let key = storage
|
||||||
.store_key(1, std::env::var("APIKEY").unwrap(), vec![Domain::All])
|
.store_key(1, std::env::var("APIKEY").unwrap(), vec![Domain::All])
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
storage
|
(storage, key)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
async fn test_initialise() {
|
async fn test_initialise() {
|
||||||
let storage = setup().await;
|
let (storage, _) = setup().await;
|
||||||
|
|
||||||
if let Err(e) = storage.initialise().await {
|
if let Err(e) = storage.initialise().await {
|
||||||
panic!("Initialising key storage failed: {:?}", e);
|
panic!("Initialising key storage failed: {:?}", e);
|
||||||
|
@ -551,25 +617,43 @@ pub(crate) mod test {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
async fn test_store_duplicate() {
|
async fn test_store_duplicate_key() {
|
||||||
let storage = setup().await;
|
let (storage, key) = setup().await;
|
||||||
match storage
|
let key = storage
|
||||||
.store_key(1, std::env::var("APIKEY").unwrap(), vec![])
|
.store_key(1, key.key, vec![Domain::User { id: 1 }])
|
||||||
.await
|
.await
|
||||||
.unwrap_err()
|
.unwrap();
|
||||||
{
|
|
||||||
PgStorageError::DuplicateKey(key) => {
|
assert_eq!(key.domains.0.len(), 2);
|
||||||
assert_eq!(key, std::env::var("APIKEY").unwrap())
|
}
|
||||||
}
|
|
||||||
why => panic!("Expected duplicate key error but found '{why}'"),
|
#[test]
|
||||||
};
|
async fn test_store_duplicate_key_duplicate_domain() {
|
||||||
|
let (storage, key) = setup().await;
|
||||||
|
let key = storage
|
||||||
|
.store_key(1, key.key, vec![Domain::All])
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
assert_eq!(key.domains.0.len(), 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
async fn test_add_domain() {
|
async fn test_add_domain() {
|
||||||
let storage = setup().await;
|
let (storage, key) = setup().await;
|
||||||
let key = storage
|
let key = storage
|
||||||
.add_domain_to_key(std::env::var("APIKEY").unwrap(), Domain::User { id: 12345 })
|
.add_domain_to_key(KeySelector::Key(key.key), Domain::User { id: 12345 })
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
assert!(key.domains.0.contains(&Domain::User { id: 12345 }));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
async fn test_add_domain_id() {
|
||||||
|
let (storage, key) = setup().await;
|
||||||
|
let key = storage
|
||||||
|
.add_domain_to_key(KeySelector::Id(key.id), Domain::User { id: 12345 })
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
|
@ -578,22 +662,56 @@ pub(crate) mod test {
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
async fn test_add_duplicate_domain() {
|
async fn test_add_duplicate_domain() {
|
||||||
let storage = setup().await;
|
let (storage, key) = setup().await;
|
||||||
match storage
|
let key = storage
|
||||||
.add_domain_to_key(std::env::var("APIKEY").unwrap(), Domain::All)
|
.add_domain_to_key(KeySelector::Key(key.key), Domain::All)
|
||||||
.await
|
.await
|
||||||
.unwrap_err()
|
.unwrap();
|
||||||
{
|
assert_eq!(
|
||||||
PgStorageError::DuplicateDomain(d) => assert_eq!(d, Domain::All),
|
key.domains
|
||||||
why => panic!("Expected duplicate domain error but found '{why}'"),
|
.0
|
||||||
};
|
.into_iter()
|
||||||
|
.filter(|d| *d == Domain::All)
|
||||||
|
.count(),
|
||||||
|
1
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
async fn test_remove_domain() {
|
async fn test_remove_domain() {
|
||||||
let storage = setup().await;
|
let (storage, key) = setup().await;
|
||||||
|
storage
|
||||||
|
.add_domain_to_key(KeySelector::Key(key.key.clone()), Domain::User { id: 1 })
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
let key = storage
|
let key = storage
|
||||||
.remove_domain_from_key(std::env::var("APIKEY").unwrap(), Domain::All)
|
.remove_domain_from_key(KeySelector::Key(key.key.clone()), Domain::User { id: 1 })
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
assert_eq!(key.domains.0, vec![Domain::All]);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
async fn test_remove_domain_id() {
|
||||||
|
let (storage, key) = setup().await;
|
||||||
|
storage
|
||||||
|
.add_domain_to_key(KeySelector::Id(key.id), Domain::User { id: 1 })
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
let key = storage
|
||||||
|
.remove_domain_from_key(KeySelector::Id(key.id), Domain::User { id: 1 })
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
assert_eq!(key.domains.0, vec![Domain::All]);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
async fn test_remove_last_domain() {
|
||||||
|
let (storage, key) = setup().await;
|
||||||
|
let key = storage
|
||||||
|
.remove_domain_from_key(KeySelector::Key(key.key), Domain::All)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
|
@ -602,7 +720,7 @@ pub(crate) mod test {
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
async fn test_store_key() {
|
async fn test_store_key() {
|
||||||
let storage = setup().await;
|
let (storage, _) = setup().await;
|
||||||
let key = storage
|
let key = storage
|
||||||
.store_key(1, "ABCDABCDABCDABCD".to_owned(), vec![])
|
.store_key(1, "ABCDABCDABCDABCD".to_owned(), vec![])
|
||||||
.await
|
.await
|
||||||
|
@ -612,7 +730,7 @@ pub(crate) mod test {
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
async fn test_read_user_keys() {
|
async fn test_read_user_keys() {
|
||||||
let storage = setup().await;
|
let (storage, _) = setup().await;
|
||||||
|
|
||||||
let keys = storage.read_user_keys(1).await.unwrap();
|
let keys = storage.read_user_keys(1).await.unwrap();
|
||||||
assert_eq!(keys.len(), 1);
|
assert_eq!(keys.len(), 1);
|
||||||
|
@ -620,7 +738,7 @@ pub(crate) mod test {
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
async fn acquire_one() {
|
async fn acquire_one() {
|
||||||
let storage = setup().await;
|
let (storage, _) = setup().await;
|
||||||
|
|
||||||
if let Err(e) = storage.acquire_key(Domain::All).await {
|
if let Err(e) = storage.acquire_key(Domain::All).await {
|
||||||
panic!("Acquiring key failed: {:?}", e);
|
panic!("Acquiring key failed: {:?}", e);
|
||||||
|
@ -629,11 +747,7 @@ pub(crate) mod test {
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
async fn test_flag_key_one() {
|
async fn test_flag_key_one() {
|
||||||
let storage = setup().await;
|
let (storage, key) = setup().await;
|
||||||
let key = storage
|
|
||||||
.read_key(std::env::var("APIKEY").unwrap())
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
assert!(storage.flag_key(key, 2).await.unwrap());
|
assert!(storage.flag_key(key, 2).await.unwrap());
|
||||||
|
|
||||||
|
@ -645,11 +759,7 @@ pub(crate) mod test {
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
async fn test_flag_key_many() {
|
async fn test_flag_key_many() {
|
||||||
let storage = setup().await;
|
let (storage, key) = setup().await;
|
||||||
let key = storage
|
|
||||||
.read_key(std::env::var("APIKEY").unwrap())
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
assert!(storage.flag_key(key, 2).await.unwrap());
|
assert!(storage.flag_key(key, 2).await.unwrap());
|
||||||
|
|
||||||
|
@ -661,7 +771,7 @@ pub(crate) mod test {
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
async fn acquire_many() {
|
async fn acquire_many() {
|
||||||
let storage = setup().await;
|
let (storage, _) = setup().await;
|
||||||
|
|
||||||
match storage.acquire_many_keys(Domain::All, 30).await {
|
match storage.acquire_many_keys(Domain::All, 30).await {
|
||||||
Err(e) => panic!("Acquiring key failed: {:?}", e),
|
Err(e) => panic!("Acquiring key failed: {:?}", e),
|
||||||
|
@ -669,9 +779,10 @@ pub(crate) mod test {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// HACK: this test is time sensitive and will fail if runs at the top of the minute
|
||||||
#[test]
|
#[test]
|
||||||
async fn test_concurrent() {
|
async fn test_concurrent() {
|
||||||
let storage = Arc::new(setup().await);
|
let storage = Arc::new(setup().await.0);
|
||||||
|
|
||||||
for _ in 0..10 {
|
for _ in 0..10 {
|
||||||
let mut set = tokio::task::JoinSet::new();
|
let mut set = tokio::task::JoinSet::new();
|
||||||
|
@ -702,9 +813,10 @@ pub(crate) mod test {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// HACK: this test is time sensitive and will fail if runs at the top of the minute
|
||||||
#[test]
|
#[test]
|
||||||
async fn test_concurrent_many() {
|
async fn test_concurrent_many() {
|
||||||
let storage = Arc::new(setup().await);
|
let storage = Arc::new(setup().await.0);
|
||||||
for _ in 0..10 {
|
for _ in 0..10 {
|
||||||
let mut set = tokio::task::JoinSet::new();
|
let mut set = tokio::task::JoinSet::new();
|
||||||
|
|
||||||
|
|
|
@ -179,7 +179,7 @@ mod test {
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
async fn test_pool_request() {
|
async fn test_pool_request() {
|
||||||
let storage = setup().await;
|
let (storage, _) = setup().await;
|
||||||
let pool = KeyPool::new(
|
let pool = KeyPool::new(
|
||||||
reqwest::Client::default(),
|
reqwest::Client::default(),
|
||||||
storage,
|
storage,
|
||||||
|
@ -192,7 +192,7 @@ mod test {
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
async fn test_with_storage_request() {
|
async fn test_with_storage_request() {
|
||||||
let storage = setup().await;
|
let (storage, _) = setup().await;
|
||||||
|
|
||||||
let response = reqwest::Client::new()
|
let response = reqwest::Client::new()
|
||||||
.with_storage(&storage, Domain::All)
|
.with_storage(&storage, Domain::All)
|
||||||
|
|
Loading…
Reference in a new issue