expanded key storage api

This commit is contained in:
TotallyNot 2023-01-26 19:32:47 +01:00
parent 0993f56489
commit 1f43b186a8
5 changed files with 98 additions and 35 deletions

View file

@ -1,14 +1,13 @@
[package] [package]
name = "torn-key-pool" name = "torn-key-pool"
version = "0.5.0" version = "0.5.1"
edition = "2021" edition = "2021"
authors = ["Pyrit [2111649]"]
license = "MIT" license = "MIT"
repository = "https://github.com/TotallyNot/torn-api.rs.git" repository = "https://github.com/TotallyNot/torn-api.rs.git"
homepage = "https://github.com/TotallyNot/torn-api.rs.git" homepage = "https://github.com/TotallyNot/torn-api.rs.git"
description = "A generalised API key pool for torn-api" description = "A generalised API key pool for torn-api"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[features] [features]
default = [ "postgres", "tokio-runtime" ] default = [ "postgres", "tokio-runtime" ]
postgres = [ "dep:sqlx", "dep:chrono", "dep:indoc", "dep:serde" ] postgres = [ "dep:sqlx", "dep:chrono", "dep:indoc", "dep:serde" ]

View file

@ -33,7 +33,11 @@ pub trait ApiKey: Sync + Send {
fn value(&self) -> &str; fn value(&self) -> &str;
} }
pub trait KeyDomain: Clone + std::fmt::Debug + Send + Sync {} pub trait KeyDomain: Clone + std::fmt::Debug + Send + Sync {
fn fallback(&self) -> Option<Self> {
None
}
}
impl<T> KeyDomain for T where T: Clone + std::fmt::Debug + Send + Sync {} impl<T> KeyDomain for T where T: Clone + std::fmt::Debug + Send + Sync {}
@ -55,12 +59,15 @@ pub trait KeyPoolStorage {
async fn store_key( async fn store_key(
&self, &self,
user_id: i32,
key: String, key: String,
domains: Vec<Self::Domain>, domains: Vec<Self::Domain>,
) -> Result<Self::Key, Self::Error>; ) -> Result<Self::Key, Self::Error>;
async fn read_key(&self, key: String) -> Result<Self::Key, Self::Error>; async fn read_key(&self, key: String) -> Result<Self::Key, Self::Error>;
async fn read_user_keys(&self, user_id: i32) -> Result<Vec<Self::Key>, Self::Error>;
async fn remove_key(&self, key: String) -> Result<Self::Key, Self::Error>; async fn remove_key(&self, key: String) -> Result<Self::Key, Self::Error>;
async fn add_domain_to_key( async fn add_domain_to_key(

View file

@ -128,7 +128,7 @@ where
S: KeyPoolStorage, S: KeyPoolStorage,
{ {
client: C, client: C,
storage: S, pub storage: S,
comment: Option<String>, comment: Option<String>,
} }

View file

@ -42,6 +42,7 @@ where
D: PgKeyDomain, D: PgKeyDomain,
{ {
pub id: i32, pub id: i32,
pub user_id: i32,
pub key: String, pub key: String,
pub uses: i16, pub uses: i16,
pub domains: sqlx::types::Json<Vec<D>>, pub domains: sqlx::types::Json<Vec<D>>,
@ -82,13 +83,14 @@ where
sqlx::query(indoc! {r#" sqlx::query(indoc! {r#"
CREATE TABLE IF NOT EXISTS api_keys ( CREATE TABLE IF NOT EXISTS api_keys (
id serial primary key, id serial primary key,
user_id int4 not null,
key char(16) not null, key char(16) not null,
uses int2 not null default 0, uses int2 not null default 0,
domains jsonb not null default '{}'::jsonb, domains jsonb not null default '{}'::jsonb,
last_used timestamptz not null default now(), last_used timestamptz not null default now(),
flag int2, flag int2,
cooldown timestamptz, cooldown timestamptz,
constraint "uq:api_keys.key" UNIQUE(key) constraint "uq:api_keys.key+user_id" UNIQUE(user_id, key)
)"# )"#
}) })
.execute(&self.pool) .execute(&self.pool)
@ -100,6 +102,12 @@ where
.execute(&self.pool) .execute(&self.pool)
.await?; .await?;
sqlx::query(indoc! {r#"
CREATE INDEX IF NOT EXISTS "idx:api_keys.user_id" ON api_keys USING BTREE(user_id)
"#})
.execute(&self.pool)
.await?;
Ok(()) Ok(())
} }
} }
@ -143,10 +151,14 @@ where
select select
id, id,
0::int2 as uses 0::int2 as uses
from api_keys where last_used < date_trunc('minute', now()) and (cooldown is null or now() >= cooldown) and domains @> $1 from api_keys where last_used < date_trunc('minute', now())
and (cooldown is null or now() >= cooldown)
and domains @> $1
union ( union (
select id, uses from api_keys select id, uses from api_keys
where last_used >= date_trunc('minute', now()) and (cooldown is null or now() >= cooldown) and domains @> $1 where last_used >= date_trunc('minute', now())
and (cooldown is null or now() >= cooldown)
and domains @> $1
order by uses asc order by uses asc
) )
limit 1 limit 1
@ -160,6 +172,7 @@ where
api_keys.id=key.id and key.uses < $2 api_keys.id=key.id and key.uses < $2
returning returning
api_keys.id, api_keys.id,
api_keys.user_id,
api_keys.key, api_keys.key,
api_keys.uses, api_keys.uses,
api_keys.domains api_keys.domains
@ -170,17 +183,23 @@ where
.fetch_optional(&mut tx) .fetch_optional(&mut tx)
.await?; .await?;
tx.commit().await?; tx.commit().await?;
Result::<Option<Self::Key>, sqlx::Error>::Ok( Result::<Option<Self::Key>, sqlx::Error>::Ok(key)
key
)
} }
.await; .await;
match attempt { match attempt {
Ok(Some(result)) => return Ok(result), Ok(Some(result)) => return Ok(result),
Ok(None) => return Err(PgStorageError::Unavailable(domain)), Ok(None) => {
return self
.acquire_key(
domain
.fallback()
.ok_or_else(|| PgStorageError::Unavailable(domain))?,
)
.await
}
Err(error) => { Err(error) => {
if let Some(db_error) = error.as_database_error() { if let Some(db_error) = error.as_database_error() {
let pg_error: &sqlx::postgres::PgDatabaseError = db_error.downcast_ref(); let pg_error: &sqlx::postgres::PgDatabaseError = db_error.downcast_ref();
@ -213,17 +232,23 @@ where
let mut keys: Vec<Self::Key> = sqlx::query_as(&indoc::formatdoc!( let mut keys: Vec<Self::Key> = sqlx::query_as(&indoc::formatdoc!(
r#"select r#"select
id, id,
user_id,
key, key,
0::int2 as uses, 0::int2 as uses,
domains domains
from api_keys where last_used < date_trunc('minute', now()) and (cooldown is null or now() >= cooldown) and domains @> $1 from api_keys where last_used < date_trunc('minute', now())
and (cooldown is null or now() >= cooldown)
and domains @> $1
union union
select select
id, id,
user_id,
key, key,
uses, uses,
domains domains
from api_keys where last_used >= date_trunc('minute', now()) and (cooldown is null or now() >= cooldown) and domains @> $1 from api_keys where last_used >= date_trunc('minute', now())
and (cooldown is null or now() >= cooldown)
and domains @> $1
order by uses limit $2 order by uses limit $2
"#, "#,
)) ))
@ -285,7 +310,16 @@ where
match attempt { match attempt {
Ok(Some(result)) => return Ok(result), Ok(Some(result)) => return Ok(result),
Ok(None) => return Err(Self::Error::Unavailable(domain)), Ok(None) => {
return self
.acquire_many_keys(
domain
.fallback()
.ok_or_else(|| Self::Error::Unavailable(domain))?,
number,
)
.await
}
Err(error) => { Err(error) => {
if let Some(db_error) = error.as_database_error() { if let Some(db_error) = error.as_database_error() {
let pg_error: &sqlx::postgres::PgDatabaseError = db_error.downcast_ref(); let pg_error: &sqlx::postgres::PgDatabaseError = db_error.downcast_ref();
@ -303,7 +337,6 @@ where
} }
async fn flag_key(&self, key: Self::Key, code: u8) -> Result<bool, Self::Error> { async fn flag_key(&self, key: Self::Key, code: u8) -> Result<bool, Self::Error> {
// TODO: put keys in cooldown when appropriate
match code { match code {
2 | 10 | 13 => { 2 | 10 | 13 => {
// invalid key, owner fedded or owner inactive // invalid key, owner fedded or owner inactive
@ -350,21 +383,29 @@ where
} }
} }
async fn store_key(&self, key: String, domains: Vec<D>) -> Result<Self::Key, Self::Error> { async fn store_key(
sqlx::query_as("insert into api_keys(key, domains) values ($1, $2) returning *") &self,
.bind(&key) user_id: i32,
.bind(sqlx::types::Json(domains)) key: String,
.fetch_one(&self.pool) domains: Vec<D>,
.await ) -> Result<Self::Key, Self::Error> {
.map_err(|why| { sqlx::query_as(
if let Some(error) = why.as_database_error() { "insert into api_keys(user_id, key, domains) values ($1, $2, $3) returning *",
let pg_error: &sqlx::postgres::PgDatabaseError = error.downcast_ref(); )
if pg_error.code() == "23505" { .bind(user_id)
return PgStorageError::DuplicateKey(key); .bind(&key)
} .bind(sqlx::types::Json(domains))
.fetch_one(&self.pool)
.await
.map_err(|why| {
if let Some(error) = why.as_database_error() {
let pg_error: &sqlx::postgres::PgDatabaseError = error.downcast_ref();
if pg_error.code() == "23505" {
return PgStorageError::DuplicateKey(key);
} }
PgStorageError::Pg(why) }
}) PgStorageError::Pg(why)
})
} }
async fn read_key(&self, key: String) -> Result<Self::Key, Self::Error> { async fn read_key(&self, key: String) -> Result<Self::Key, Self::Error> {
@ -375,6 +416,14 @@ where
.ok_or_else(|| PgStorageError::KeyNotFound(key)) .ok_or_else(|| PgStorageError::KeyNotFound(key))
} }
async fn read_user_keys(&self, user_id: i32) -> Result<Vec<Self::Key>, Self::Error> {
sqlx::query_as("select * from api_keys where user_id=$1")
.bind(user_id)
.fetch_all(&self.pool)
.await
.map_err(Into::into)
}
async fn remove_key(&self, key: String) -> Result<Self::Key, Self::Error> { async fn remove_key(&self, key: String) -> Result<Self::Key, Self::Error> {
sqlx::query_as("delete from api_keys where key=$1 returning *") sqlx::query_as("delete from api_keys where key=$1 returning *")
.bind(&key) .bind(&key)
@ -475,7 +524,7 @@ pub(crate) mod test {
storage.initialise().await.unwrap(); storage.initialise().await.unwrap();
storage storage
.store_key(std::env::var("APIKEY").unwrap(), vec![Domain::All]) .store_key(1, std::env::var("APIKEY").unwrap(), vec![Domain::All])
.await .await
.unwrap(); .unwrap();
@ -495,7 +544,7 @@ pub(crate) mod test {
async fn test_store_duplicate() { async fn test_store_duplicate() {
let storage = setup().await; let storage = setup().await;
match storage match storage
.store_key(std::env::var("APIKEY").unwrap(), vec![]) .store_key(1, std::env::var("APIKEY").unwrap(), vec![])
.await .await
.unwrap_err() .unwrap_err()
{ {
@ -545,12 +594,20 @@ pub(crate) mod test {
async fn test_store_key() { async fn test_store_key() {
let storage = setup().await; let storage = setup().await;
let key = storage let key = storage
.store_key("ABCDABCDABCDABCD".to_owned(), vec![]) .store_key(1, "ABCDABCDABCDABCD".to_owned(), vec![])
.await .await
.unwrap(); .unwrap();
assert_eq!(key.value(), "ABCDABCDABCDABCD"); assert_eq!(key.value(), "ABCDABCDABCDABCD");
} }
#[test]
async fn test_read_user_keys() {
let storage = setup().await;
let keys = storage.read_user_keys(1).await.unwrap();
assert_eq!(keys.len(), 1);
}
#[test] #[test]
async fn acquire_one() { async fn acquire_one() {
let storage = setup().await; let storage = setup().await;

View file

@ -128,7 +128,7 @@ where
S: KeyPoolStorage, S: KeyPoolStorage,
{ {
client: C, client: C,
storage: S, pub storage: S,
comment: Option<String>, comment: Option<String>,
} }