From ddfbc0f7e85d3ffad7210a41dcbe455956035179 Mon Sep 17 00:00:00 2001 From: TotallyNot <44345987+TotallyNot@users.noreply.github.com> Date: Wed, 22 Feb 2023 18:54:55 +0100 Subject: [PATCH] added methods to query keys --- torn-key-pool/Cargo.toml | 2 +- torn-key-pool/src/lib.rs | 7 +++- torn-key-pool/src/postgres.rs | 77 ++++++++++++++++++++++++++++++++--- 3 files changed, 79 insertions(+), 7 deletions(-) diff --git a/torn-key-pool/Cargo.toml b/torn-key-pool/Cargo.toml index b1d4cf3..4f61abf 100644 --- a/torn-key-pool/Cargo.toml +++ b/torn-key-pool/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "torn-key-pool" -version = "0.5.4" +version = "0.5.5" edition = "2021" authors = ["Pyrit [2111649]"] license = "MIT" diff --git a/torn-key-pool/src/lib.rs b/torn-key-pool/src/lib.rs index 72ac07c..c24be24 100644 --- a/torn-key-pool/src/lib.rs +++ b/torn-key-pool/src/lib.rs @@ -75,12 +75,17 @@ pub trait KeyPoolStorage { domains: Vec, ) -> Result; - async fn read_key(&self, key: KeySelector) -> Result; + async fn read_key(&self, key: KeySelector) + -> Result, Self::Error>; async fn read_user_keys(&self, user_id: i32) -> Result, Self::Error>; async fn remove_key(&self, key: KeySelector) -> Result; + async fn query_key(&self, domain: Self::Domain) -> Result, Self::Error>; + + async fn query_all(&self, domain: Self::Domain) -> Result, Self::Error>; + async fn add_domain_to_key( &self, key: KeySelector, diff --git a/torn-key-pool/src/postgres.rs b/torn-key-pool/src/postgres.rs index 61229d5..967aa96 100644 --- a/torn-key-pool/src/postgres.rs +++ b/torn-key-pool/src/postgres.rs @@ -432,21 +432,40 @@ where .map_err(Into::into) } - async fn read_key(&self, selector: KeySelector) -> Result { + async fn read_key( + &self, + selector: KeySelector, + ) -> Result, Self::Error> { match &selector { KeySelector::Key(key) => sqlx::query_as("select * from api_keys where key=$1") .bind(key) .fetch_optional(&self.pool) - .await? - .ok_or_else(|| PgStorageError::KeyNotFound(selector)), + .await + .map_err(Into::into), KeySelector::Id(id) => sqlx::query_as("select * from api_keys where id=$1") .bind(id) .fetch_optional(&self.pool) - .await? - .ok_or_else(|| PgStorageError::KeyNotFound(selector)), + .await + .map_err(Into::into), } } + async fn query_key(&self, domain: D) -> Result, Self::Error> { + sqlx::query_as("select * from api_keys where domains @> $1 limit 1") + .bind(sqlx::types::Json(vec![domain])) + .fetch_optional(&self.pool) + .await + .map_err(Into::into) + } + + async fn query_all(&self, domain: D) -> Result, Self::Error> { + sqlx::query_as("select * from api_keys where domains @> $1") + .bind(sqlx::types::Json(vec![domain])) + .fetch_all(&self.pool) + .await + .map_err(Into::into) + } + async fn read_user_keys(&self, user_id: i32) -> Result, Self::Error> { sqlx::query_as("select * from api_keys where user_id=$1") .bind(user_id) @@ -845,4 +864,52 @@ pub(crate) mod test { .unwrap(); } } + + #[test] + async fn read_key() { + let (storage, key) = setup().await; + + let key = storage.read_key(KeySelector::Key(key.key)).await.unwrap(); + assert!(key.is_some()); + } + + #[test] + async fn read_key_id() { + let (storage, key) = setup().await; + + let key = storage.read_key(KeySelector::Id(key.id)).await.unwrap(); + assert!(key.is_some()); + } + + #[test] + async fn read_nonexistent_key() { + let (storage, _) = setup().await; + + let key = storage.read_key(KeySelector::Id(-1)).await.unwrap(); + assert!(key.is_none()); + } + + #[test] + async fn query_key() { + let (storage, _) = setup().await; + + let key = storage.query_key(Domain::All).await.unwrap(); + assert!(key.is_some()); + } + + #[test] + async fn query_nonexistent_key() { + let (storage, _) = setup().await; + + let key = storage.query_key(Domain::Guild { id: 0 }).await.unwrap(); + assert!(key.is_none()); + } + + #[test] + async fn query_all() { + let (storage, _) = setup().await; + + let keys = storage.query_all(Domain::All).await.unwrap(); + assert!(keys.len() == 1); + } }