added methods to query keys

This commit is contained in:
TotallyNot 2023-02-22 18:54:55 +01:00
parent 71eef676d3
commit ddfbc0f7e8
3 changed files with 79 additions and 7 deletions

View file

@ -1,6 +1,6 @@
[package] [package]
name = "torn-key-pool" name = "torn-key-pool"
version = "0.5.4" version = "0.5.5"
edition = "2021" edition = "2021"
authors = ["Pyrit [2111649]"] authors = ["Pyrit [2111649]"]
license = "MIT" license = "MIT"

View file

@ -75,12 +75,17 @@ pub trait KeyPoolStorage {
domains: Vec<Self::Domain>, domains: Vec<Self::Domain>,
) -> Result<Self::Key, Self::Error>; ) -> Result<Self::Key, Self::Error>;
async fn read_key(&self, key: KeySelector<Self::Key>) -> Result<Self::Key, Self::Error>; async fn read_key(&self, key: KeySelector<Self::Key>)
-> Result<Option<Self::Key>, Self::Error>;
async fn read_user_keys(&self, user_id: i32) -> Result<Vec<Self::Key>, Self::Error>; async fn read_user_keys(&self, user_id: i32) -> Result<Vec<Self::Key>, Self::Error>;
async fn remove_key(&self, key: KeySelector<Self::Key>) -> Result<Self::Key, Self::Error>; async fn remove_key(&self, key: KeySelector<Self::Key>) -> Result<Self::Key, Self::Error>;
async fn query_key(&self, domain: Self::Domain) -> Result<Option<Self::Key>, Self::Error>;
async fn query_all(&self, domain: Self::Domain) -> Result<Vec<Self::Key>, Self::Error>;
async fn add_domain_to_key( async fn add_domain_to_key(
&self, &self,
key: KeySelector<Self::Key>, key: KeySelector<Self::Key>,

View file

@ -432,21 +432,40 @@ where
.map_err(Into::into) .map_err(Into::into)
} }
async fn read_key(&self, selector: KeySelector<Self::Key>) -> Result<Self::Key, Self::Error> { async fn read_key(
&self,
selector: KeySelector<Self::Key>,
) -> Result<Option<Self::Key>, Self::Error> {
match &selector { match &selector {
KeySelector::Key(key) => sqlx::query_as("select * from api_keys where key=$1") KeySelector::Key(key) => sqlx::query_as("select * from api_keys where key=$1")
.bind(key) .bind(key)
.fetch_optional(&self.pool) .fetch_optional(&self.pool)
.await? .await
.ok_or_else(|| PgStorageError::KeyNotFound(selector)), .map_err(Into::into),
KeySelector::Id(id) => sqlx::query_as("select * from api_keys where id=$1") KeySelector::Id(id) => sqlx::query_as("select * from api_keys where id=$1")
.bind(id) .bind(id)
.fetch_optional(&self.pool) .fetch_optional(&self.pool)
.await? .await
.ok_or_else(|| PgStorageError::KeyNotFound(selector)), .map_err(Into::into),
} }
} }
async fn query_key(&self, domain: D) -> Result<Option<Self::Key>, Self::Error> {
sqlx::query_as("select * from api_keys where domains @> $1 limit 1")
.bind(sqlx::types::Json(vec![domain]))
.fetch_optional(&self.pool)
.await
.map_err(Into::into)
}
async fn query_all(&self, domain: D) -> Result<Vec<Self::Key>, Self::Error> {
sqlx::query_as("select * from api_keys where domains @> $1")
.bind(sqlx::types::Json(vec![domain]))
.fetch_all(&self.pool)
.await
.map_err(Into::into)
}
async fn read_user_keys(&self, user_id: i32) -> Result<Vec<Self::Key>, Self::Error> { async fn read_user_keys(&self, user_id: i32) -> Result<Vec<Self::Key>, Self::Error> {
sqlx::query_as("select * from api_keys where user_id=$1") sqlx::query_as("select * from api_keys where user_id=$1")
.bind(user_id) .bind(user_id)
@ -845,4 +864,52 @@ pub(crate) mod test {
.unwrap(); .unwrap();
} }
} }
#[test]
async fn read_key() {
let (storage, key) = setup().await;
let key = storage.read_key(KeySelector::Key(key.key)).await.unwrap();
assert!(key.is_some());
}
#[test]
async fn read_key_id() {
let (storage, key) = setup().await;
let key = storage.read_key(KeySelector::Id(key.id)).await.unwrap();
assert!(key.is_some());
}
#[test]
async fn read_nonexistent_key() {
let (storage, _) = setup().await;
let key = storage.read_key(KeySelector::Id(-1)).await.unwrap();
assert!(key.is_none());
}
#[test]
async fn query_key() {
let (storage, _) = setup().await;
let key = storage.query_key(Domain::All).await.unwrap();
assert!(key.is_some());
}
#[test]
async fn query_nonexistent_key() {
let (storage, _) = setup().await;
let key = storage.query_key(Domain::Guild { id: 0 }).await.unwrap();
assert!(key.is_none());
}
#[test]
async fn query_all() {
let (storage, _) = setup().await;
let keys = storage.query_all(Domain::All).await.unwrap();
assert!(keys.len() == 1);
}
} }