diff --git a/.gitignore b/.gitignore index ffc3118..4bcb2ce 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ /target /Cargo.lock .env +.DS_Store diff --git a/torn-api/Cargo.toml b/torn-api/Cargo.toml index 539328e..69cf58b 100644 --- a/torn-api/Cargo.toml +++ b/torn-api/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "torn-api" -version = "0.5.13" +version = "0.5.19" edition = "2021" authors = ["Pyrit [2111649]"] license = "MIT" diff --git a/torn-api/src/common.rs b/torn-api/src/common.rs index 3cc948a..91789ed 100644 --- a/torn-api/src/common.rs +++ b/torn-api/src/common.rs @@ -56,3 +56,108 @@ pub struct Territory { #[serde(deserialize_with = "de_util::string_or_decimal")] pub coordinate_y: rust_decimal::Decimal, } + +#[derive(Debug, Clone, Copy, Deserialize)] +pub enum AttackResult { + Attacked, + Mugged, + Hospitalized, + Lost, + Arrested, + Escape, + Interrupted, + Assist, + Timeout, + Stalemate, + Special, + Looted, +} + +#[derive(Debug, Clone, Deserialize)] +pub struct Attack<'a> { + pub code: &'a str, + #[serde(with = "ts_seconds")] + pub timestamp_started: DateTime, + #[serde(with = "ts_seconds")] + pub timestamp_ended: DateTime, + + #[serde(deserialize_with = "de_util::empty_string_int_option")] + pub attacker_id: Option, + #[serde(deserialize_with = "de_util::empty_string_int_option")] + pub attacker_faction: Option, + pub defender_id: i32, + #[serde(deserialize_with = "de_util::empty_string_int_option")] + pub defender_faction: Option, + pub result: AttackResult, + + #[serde(deserialize_with = "de_util::int_is_bool")] + pub stealthed: bool, + + #[cfg(feature = "decimal")] + pub respect: rust_decimal::Decimal, + + #[cfg(not(feature = "decimal"))] + pub respect: f32, +} + +#[derive(Debug, Clone, Deserialize)] +pub struct RespectModifiers { + pub fair_fight: f32, + pub war: f32, + pub retaliation: f32, + pub group_attack: f32, + pub overseas: f32, + pub chain_bonus: f32, +} + +#[derive(Debug, Clone, Deserialize)] +pub struct AttackFull<'a> { + pub code: &'a str, + #[serde(with = "ts_seconds")] + pub timestamp_started: DateTime, + #[serde(with = "ts_seconds")] + pub timestamp_ended: DateTime, + + #[serde(deserialize_with = "de_util::empty_string_int_option")] + pub attacker_id: Option, + #[serde(deserialize_with = "de_util::empty_string_is_none")] + pub attacker_name: Option<&'a str>, + #[serde(deserialize_with = "de_util::empty_string_int_option")] + pub attacker_faction: Option, + #[serde( + deserialize_with = "de_util::empty_string_is_none", + rename = "attacker_factionname" + )] + pub attacker_faction_name: Option<&'a str>, + + pub defender_id: i32, + pub defender_name: &'a str, + #[serde(deserialize_with = "de_util::empty_string_int_option")] + pub defender_faction: Option, + #[serde( + deserialize_with = "de_util::empty_string_is_none", + rename = "defender_factionname" + )] + pub defender_faction_name: Option<&'a str>, + + pub result: AttackResult, + + #[serde(deserialize_with = "de_util::int_is_bool")] + pub stealthed: bool, + #[serde(deserialize_with = "de_util::int_is_bool")] + pub raid: bool, + #[serde(deserialize_with = "de_util::int_is_bool")] + pub ranked_war: bool, + + #[cfg(feature = "decimal")] + pub respect: rust_decimal::Decimal, + #[cfg(feature = "decimal")] + pub respect_loss: rust_decimal::Decimal, + + #[cfg(not(feature = "decimal"))] + pub respect: f32, + #[cfg(not(feature = "decimal"))] + pub respect_loss: f32, + + pub modifiers: RespectModifiers, +} diff --git a/torn-api/src/de_util.rs b/torn-api/src/de_util.rs index 3f3983b..0bf658c 100644 --- a/torn-api/src/de_util.rs +++ b/torn-api/src/de_util.rs @@ -1,8 +1,8 @@ #![allow(unused)] -use std::collections::BTreeMap; +use std::collections::{BTreeMap, HashMap}; -use chrono::{DateTime, NaiveDateTime, Utc}; +use chrono::{serde::ts_nanoseconds::deserialize, DateTime, NaiveDateTime, Utc}; use serde::de::{Deserialize, Deserializer, Error, Unexpected, Visitor}; pub(crate) fn empty_string_is_none<'de, D>(deserializer: D) -> Result, D::Error> @@ -135,6 +135,62 @@ where deserializer.deserialize_map(MapVisitor) } +pub(crate) fn empty_dict_is_empty_array<'de, D, T>(deserializer: D) -> Result, D::Error> +where + D: Deserializer<'de>, + T: Deserialize<'de>, +{ + struct ArrayVisitor(std::marker::PhantomData); + + impl<'de, T> Visitor<'de> for ArrayVisitor + where + T: Deserialize<'de>, + { + type Value = Vec; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(formatter, "vec or empty object") + } + + fn visit_map(self, map: A) -> Result + where + A: serde::de::MapAccess<'de>, + { + match map.size_hint() { + Some(0) | None => Ok(Vec::default()), + Some(len) => Err(A::Error::invalid_length(len, &"empty dict")), + } + } + + fn visit_seq(self, mut seq: A) -> Result + where + A: serde::de::SeqAccess<'de>, + { + let mut result = match seq.size_hint() { + Some(len) => Vec::with_capacity(len), + None => Vec::default(), + }; + + while let Some(element) = seq.next_element()? { + result.push(element); + } + + Ok(result) + } + } + + deserializer.deserialize_any(ArrayVisitor(std::marker::PhantomData)) +} + +pub(crate) fn null_is_empty_dict<'de, D, K, V>(deserializer: D) -> Result, D::Error> +where + D: Deserializer<'de>, + K: std::hash::Hash + std::cmp::Eq + Deserialize<'de>, + V: Deserialize<'de>, +{ + Ok(Option::deserialize(deserializer)?.unwrap_or_default()) +} + #[cfg(feature = "decimal")] pub(crate) fn string_or_decimal<'de, D>(deserializer: D) -> Result where diff --git a/torn-api/src/faction.rs b/torn-api/src/faction.rs index 9ae8b47..298b00e 100644 --- a/torn-api/src/faction.rs +++ b/torn-api/src/faction.rs @@ -1,13 +1,13 @@ use std::collections::{BTreeMap, HashMap}; -use chrono::{serde::ts_seconds, DateTime, Utc}; +use chrono::{DateTime, Utc}; use serde::Deserialize; use torn_api_macros::ApiCategory; -use crate::de_util; +use crate::de_util::{self, null_is_empty_dict}; -pub use crate::common::{LastAction, Status, Territory}; +pub use crate::common::{Attack, AttackFull, LastAction, Status, Territory}; #[derive(Debug, Clone, Copy, ApiCategory)] #[api(category = "faction")] @@ -21,7 +21,11 @@ pub enum Selection { #[api(type = "BTreeMap", field = "attacks")] Attacks, - #[api(type = "HashMap", field = "territory")] + #[api( + type = "HashMap", + field = "territory", + with = "null_is_empty_dict" + )] Territory, } @@ -68,115 +72,10 @@ pub struct Basic<'a> { #[serde(deserialize_with = "de_util::datetime_map")] pub peace: BTreeMap>, - #[serde(borrow)] + #[serde(borrow, deserialize_with = "de_util::empty_dict_is_empty_array")] pub territory_wars: Vec>, } -#[derive(Debug, Clone, Copy, Deserialize)] -pub enum AttackResult { - Attacked, - Mugged, - Hospitalized, - Lost, - Arrested, - Escape, - Interrupted, - Assist, - Timeout, - Stalemate, - Special, - Looted, -} - -#[derive(Debug, Clone, Deserialize)] -pub struct Attack<'a> { - pub code: &'a str, - #[serde(with = "ts_seconds")] - pub timestamp_started: DateTime, - #[serde(with = "ts_seconds")] - pub timestamp_ended: DateTime, - - #[serde(deserialize_with = "de_util::empty_string_int_option")] - pub attacker_id: Option, - #[serde(deserialize_with = "de_util::empty_string_int_option")] - pub attacker_faction: Option, - pub defender_id: i32, - #[serde(deserialize_with = "de_util::empty_string_int_option")] - pub defender_faction: Option, - pub result: AttackResult, - - #[serde(deserialize_with = "de_util::int_is_bool")] - pub stealthed: bool, - - #[cfg(feature = "decimal")] - pub respect: rust_decimal::Decimal, - - #[cfg(not(feature = "decimal"))] - pub respect: f32, -} - -#[derive(Debug, Clone, Deserialize)] -pub struct RespectModifiers { - pub fair_fight: f32, - pub war: f32, - pub retaliation: f32, - pub group_attack: f32, - pub overseas: f32, - pub chain_bonus: f32, -} - -#[derive(Debug, Clone, Deserialize)] -pub struct AttackFull<'a> { - pub code: &'a str, - #[serde(with = "ts_seconds")] - pub timestamp_started: DateTime, - #[serde(with = "ts_seconds")] - pub timestamp_ended: DateTime, - - #[serde(deserialize_with = "de_util::empty_string_int_option")] - pub attacker_id: Option, - #[serde(deserialize_with = "de_util::empty_string_is_none")] - pub attacker_name: Option<&'a str>, - #[serde(deserialize_with = "de_util::empty_string_int_option")] - pub attacker_faction: Option, - #[serde( - deserialize_with = "de_util::empty_string_is_none", - rename = "attacker_factionname" - )] - pub attacker_faction_name: Option<&'a str>, - - pub defender_id: i32, - pub defender_name: &'a str, - #[serde(deserialize_with = "de_util::empty_string_int_option")] - pub defender_faction: Option, - #[serde( - deserialize_with = "de_util::empty_string_is_none", - rename = "defender_factionname" - )] - pub defender_faction_name: Option<&'a str>, - - pub result: AttackResult, - - #[serde(deserialize_with = "de_util::int_is_bool")] - pub stealthed: bool, - #[serde(deserialize_with = "de_util::int_is_bool")] - pub raid: bool, - #[serde(deserialize_with = "de_util::int_is_bool")] - pub ranked_war: bool, - - #[cfg(feature = "decimal")] - pub respect: rust_decimal::Decimal, - #[cfg(feature = "decimal")] - pub respect_loss: rust_decimal::Decimal, - - #[cfg(not(feature = "decimal"))] - pub respect: f32, - #[cfg(not(feature = "decimal"))] - pub respect_loss: f32, - - pub modifiers: RespectModifiers, -} - #[cfg(test)] mod tests { use super::*; @@ -199,4 +98,21 @@ mod tests { response.attacks_full().unwrap(); response.territory().unwrap(); } + + #[async_test] + async fn destroyed_faction() { + let key = setup(); + + let response = Client::default() + .torn_api(key) + .faction(|b| { + b.id(8981) + .selections(&[Selection::Basic, Selection::Territory]) + }) + .await + .unwrap(); + + response.basic().unwrap(); + response.territory().unwrap(); + } } diff --git a/torn-api/src/key.rs b/torn-api/src/key.rs index 0e970d2..55d22d1 100644 --- a/torn-api/src/key.rs +++ b/torn-api/src/key.rs @@ -122,6 +122,8 @@ pub enum FactionSelection { Upgrades, Weapons, Lookup, + Caches, + CrimeExp, } #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)] @@ -171,6 +173,8 @@ pub enum TornSelection { Timestamp, Lookup, CityShops, + ItemDetails, + TerritoryNames, } #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)] diff --git a/torn-api/src/lib.rs b/torn-api/src/lib.rs index b167d38..25e7b2f 100644 --- a/torn-api/src/lib.rs +++ b/torn-api/src/lib.rs @@ -22,7 +22,7 @@ pub mod awc; pub mod reqwest; #[cfg(feature = "__common")] -mod common; +pub mod common; mod de_util; diff --git a/torn-api/src/torn.rs b/torn-api/src/torn.rs index f118b63..20364ce 100644 --- a/torn-api/src/torn.rs +++ b/torn-api/src/torn.rs @@ -22,6 +22,12 @@ pub enum Selection { #[api(type = "HashMap", field = "territorywars")] TerritoryWars, + + #[api(type = "HashMap", field = "rackets")] + Rackets, + + #[api(type = "HashMap", field = "territory")] + Territory, } #[derive(Deserialize)] @@ -111,6 +117,31 @@ pub struct TerritoryWar { pub ends: DateTime, } +#[derive(Debug, Clone, Deserialize)] +pub struct Racket { + pub name: String, + pub level: i16, + pub reward: String, + + #[serde(with = "chrono::serde::ts_seconds")] + pub created: DateTime, + #[serde(with = "chrono::serde::ts_seconds")] + pub changed: DateTime, +} + +#[derive(Debug, Clone, Deserialize)] +pub struct Territory { + pub sector: i16, + pub size: i16, + pub slots: i16, + pub daily_respect: i16, + pub faction: i32, + + pub neighbors: Vec, + pub war: Option, + pub racket: Option, +} + #[cfg(test)] mod tests { use super::*; @@ -122,11 +153,32 @@ mod tests { let response = Client::default() .torn_api(key) - .torn(|b| b.selections(&[Selection::Competition, Selection::TerritoryWars])) + .torn(|b| { + b.selections(&[ + Selection::Competition, + Selection::TerritoryWars, + Selection::Rackets, + ]) + }) .await .unwrap(); response.competition().unwrap(); response.territory_wars().unwrap(); + response.rackets().unwrap(); + } + + #[async_test] + async fn territory() { + let key = setup(); + + let response = Client::default() + .torn_api(key) + .torn(|b| b.selections(&[Selection::Territory]).id("NSC")) + .await + .unwrap(); + + let territory = response.territory().unwrap(); + assert!(territory.contains_key("NSC")); } } diff --git a/torn-api/src/user.rs b/torn-api/src/user.rs index 7d7094a..eb46acc 100644 --- a/torn-api/src/user.rs +++ b/torn-api/src/user.rs @@ -2,12 +2,13 @@ use serde::{ de::{self, MapAccess, Visitor}, Deserialize, Deserializer, }; +use std::collections::BTreeMap; use torn_api_macros::ApiCategory; use crate::de_util; -pub use crate::common::{LastAction, Status}; +pub use crate::common::{Attack, AttackFull, LastAction, Status}; #[derive(Debug, Clone, Copy, ApiCategory)] #[api(category = "user")] @@ -22,6 +23,10 @@ pub enum Selection { PersonalStats, #[api(type = "CriminalRecord", field = "criminalrecord")] Crimes, + #[api(type = "BTreeMap", field = "attacks")] + AttacksFull, + #[api(type = "BTreeMap", field = "attacks")] + Attacks, } #[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize)] @@ -168,10 +173,14 @@ pub enum EliminationTeam { #[derive(Debug, Clone)] pub enum Competition { Elimination { - score: i16, + score: i32, attacks: i16, team: EliminationTeam, }, + DogTags { + score: i32, + position: Option, + }, Unknown, } @@ -187,6 +196,7 @@ where Team, Attacks, TeamName, + Position, #[serde(other)] Ignore, } @@ -194,6 +204,8 @@ where #[derive(Deserialize)] enum CompetitionName { Elimination, + #[serde(rename = "Dog Tags")] + DogTags, #[serde(other)] Unknown, } @@ -229,6 +241,7 @@ where let mut score = None; let mut attacks = None; let mut name = None; + let mut position = None; while let Some(key) = map.next_key()? { match key { @@ -241,6 +254,9 @@ where Field::Attacks => { attacks = Some(map.next_value()?); } + Field::Position => { + position = Some(map.next_value()?); + } Field::Team => { let team_raw: &str = map.next_value()?; team = if team_raw.is_empty() { @@ -299,6 +315,12 @@ where Ok(None) } } + CompetitionName::DogTags => { + let score = score.ok_or_else(|| de::Error::missing_field("score"))?; + let position = position.ok_or_else(|| de::Error::missing_field("position"))?; + + Ok(Some(Competition::DogTags { score, position })) + } CompetitionName::Unknown => Ok(Some(Competition::Unknown)), } } @@ -416,6 +438,7 @@ mod tests { Selection::Profile, Selection::PersonalStats, Selection::Crimes, + Selection::Attacks, ]) }) .await @@ -426,6 +449,8 @@ mod tests { response.profile().unwrap(); response.personal_stats().unwrap(); response.crimes().unwrap(); + response.attacks().unwrap(); + response.attacks_full().unwrap(); } #[async_test] diff --git a/torn-key-pool/Cargo.toml b/torn-key-pool/Cargo.toml index 2bd1e79..d2df050 100644 --- a/torn-key-pool/Cargo.toml +++ b/torn-key-pool/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "torn-key-pool" -version = "0.5.7" +version = "0.6.1" edition = "2021" authors = ["Pyrit [2111649]"] license = "MIT" diff --git a/torn-key-pool/src/lib.rs b/torn-key-pool/src/lib.rs index c24be24..e942a3c 100644 --- a/torn-key-pool/src/lib.rs +++ b/torn-key-pool/src/lib.rs @@ -29,8 +29,8 @@ where Response(ResponseError), } -pub trait ApiKey: Sync + Send { - type IdType: PartialEq + Eq + std::hash::Hash; +pub trait ApiKey: Sync + Send + std::fmt::Debug + Clone { + type IdType: PartialEq + Eq + std::hash::Hash + Send + Sync + std::fmt::Debug + Clone; fn value(&self) -> &str; @@ -44,12 +44,65 @@ pub trait KeyDomain: Clone + std::fmt::Debug + Send + Sync { } #[derive(Debug, Clone)] -pub enum KeySelector +pub enum KeySelector where K: ApiKey, + D: KeyDomain, { Key(String), Id(K::IdType), + UserId(i32), + Has(D), + OneOf(Vec), +} + +impl KeySelector +where + K: ApiKey, + D: KeyDomain, +{ + pub(crate) fn fallback(&self) -> Option { + match self { + Self::Key(_) | Self::UserId(_) | Self::Id(_) => None, + Self::Has(domain) => domain.fallback().map(Self::Has), + Self::OneOf(domains) => { + let fallbacks: Vec<_> = domains.iter().filter_map(|d| d.fallback()).collect(); + if fallbacks.is_empty() { + None + } else { + Some(Self::OneOf(fallbacks)) + } + } + } + } +} + +pub trait IntoSelector: Send + Sync +where + K: ApiKey, + D: KeyDomain, +{ + fn into_selector(self) -> KeySelector; +} + +impl IntoSelector for D +where + K: ApiKey, + D: KeyDomain, +{ + fn into_selector(self) -> KeySelector { + KeySelector::Has(self) + } +} + +impl IntoSelector for KeySelector +where + K: ApiKey, + D: KeyDomain, +{ + fn into_selector(self) -> KeySelector { + self + } } #[async_trait] @@ -58,13 +111,17 @@ pub trait KeyPoolStorage { type Domain: KeyDomain; type Error: std::error::Error + Sync + Send; - async fn acquire_key(&self, domain: Self::Domain) -> Result; + async fn acquire_key(&self, selector: S) -> Result + where + S: IntoSelector; - async fn acquire_many_keys( + async fn acquire_many_keys( &self, - domain: Self::Domain, + selector: S, number: i64, - ) -> Result, Self::Error>; + ) -> Result, Self::Error> + where + S: IntoSelector; async fn flag_key(&self, key: Self::Key, code: u8) -> Result; @@ -75,34 +132,41 @@ pub trait KeyPoolStorage { domains: Vec, ) -> Result; - async fn read_key(&self, key: KeySelector) - -> Result, Self::Error>; + async fn read_key(&self, selector: S) -> Result, Self::Error> + where + S: IntoSelector; - async fn read_user_keys(&self, user_id: i32) -> Result, Self::Error>; + async fn read_keys(&self, selector: S) -> Result, Self::Error> + where + S: IntoSelector; - async fn remove_key(&self, key: KeySelector) -> Result; + async fn remove_key(&self, selector: S) -> Result + where + S: IntoSelector; - async fn query_key(&self, domain: Self::Domain) -> Result, Self::Error>; - - async fn query_all(&self, domain: Self::Domain) -> Result, Self::Error>; - - async fn add_domain_to_key( + async fn add_domain_to_key( &self, - key: KeySelector, + selector: S, domain: Self::Domain, - ) -> Result; + ) -> Result + where + S: IntoSelector; - async fn remove_domain_from_key( + async fn remove_domain_from_key( &self, - key: KeySelector, + selector: S, domain: Self::Domain, - ) -> Result; + ) -> Result + where + S: IntoSelector; - async fn set_domains_for_key( + async fn set_domains_for_key( &self, - key: KeySelector, + selector: S, domains: Vec, - ) -> Result; + ) -> Result + where + S: IntoSelector; } #[derive(Debug, Clone)] @@ -112,7 +176,7 @@ where { storage: &'a S, comment: Option<&'a str>, - domain: S::Domain, + selector: KeySelector, _marker: std::marker::PhantomData, } @@ -120,10 +184,14 @@ impl<'a, C, S> KeyPoolExecutor<'a, C, S> where S: KeyPoolStorage, { - pub fn new(storage: &'a S, domain: S::Domain, comment: Option<&'a str>) -> Self { + pub fn new( + storage: &'a S, + selector: KeySelector, + comment: Option<&'a str>, + ) -> Self { Self { storage, - domain, + selector, comment, _marker: std::marker::PhantomData, } diff --git a/torn-key-pool/src/local.rs b/torn-key-pool/src/local.rs index 8eaabf3..1a23544 100644 --- a/torn-key-pool/src/local.rs +++ b/torn-key-pool/src/local.rs @@ -7,7 +7,7 @@ use torn_api::{ ApiRequest, ApiResponse, ApiSelection, ResponseError, }; -use crate::{ApiKey, KeyPoolError, KeyPoolExecutor, KeyPoolStorage}; +use crate::{ApiKey, KeyPoolError, KeyPoolExecutor, KeyPoolStorage, IntoSelector}; #[async_trait(?Send)] impl<'client, C, S> RequestExecutor for KeyPoolExecutor<'client, C, S> @@ -30,7 +30,7 @@ where loop { let key = self .storage - .acquire_key(self.domain.clone()) + .acquire_key(self.selector.clone()) .await .map_err(|e| KeyPoolError::Storage(Arc::new(e)))?; let url = request.url(key.value(), id.as_deref()); @@ -66,7 +66,7 @@ where { let keys = match self .storage - .acquire_many_keys(self.domain.clone(), ids.len() as i64) + .acquire_many_keys(self.selector.clone(), ids.len() as i64) .await { Ok(keys) => keys, @@ -114,7 +114,7 @@ where Ok(res) => return (id, Ok(res)), }; - key = match self.storage.acquire_key(self.domain.clone()).await { + key = match self.storage.acquire_key(self.selector.clone()).await { Ok(k) => k, Err(why) => return (id, Err(Self::Error::Storage(Arc::new(why)))), }; @@ -150,25 +150,26 @@ where } } - pub fn torn_api(&self, domain: S::Domain) -> ApiProvider> { + pub fn torn_api(&self, selector: I) -> ApiProvider> where I: IntoSelector { ApiProvider::new( &self.client, - KeyPoolExecutor::new(&self.storage, domain, self.comment.as_deref()), + KeyPoolExecutor::new(&self.storage, selector.into_selector(), self.comment.as_deref()), ) } } pub trait WithStorage { - fn with_storage<'a, S>( + fn with_storage<'a, S, I>( &'a self, storage: &'a S, - domain: S::Domain, + selector: I ) -> ApiProvider> where Self: ApiClient + Sized, S: KeyPoolStorage + 'static, + I: IntoSelector { - ApiProvider::new(self, KeyPoolExecutor::new(storage, domain, None)) + ApiProvider::new(self, KeyPoolExecutor::new(storage, selector.into_selector(), None)) } } diff --git a/torn-key-pool/src/postgres.rs b/torn-key-pool/src/postgres.rs index 561ad3f..c6c461f 100644 --- a/torn-key-pool/src/postgres.rs +++ b/torn-key-pool/src/postgres.rs @@ -1,9 +1,9 @@ use async_trait::async_trait; use indoc::indoc; -use sqlx::{FromRow, PgPool}; +use sqlx::{FromRow, PgPool, Postgres, QueryBuilder}; use thiserror::Error; -use crate::{ApiKey, KeyDomain, KeyPoolStorage, KeySelector}; +use crate::{ApiKey, IntoSelector, KeyDomain, KeyPoolStorage, KeySelector}; pub trait PgKeyDomain: KeyDomain + serde::Serialize + serde::de::DeserializeOwned + Eq + Unpin @@ -24,10 +24,10 @@ where Pg(#[from] sqlx::Error), #[error("No key avalaible for domain {0:?}")] - Unavailable(D), + Unavailable(KeySelector, D>), #[error("Key not found: '{0:?}'")] - KeyNotFound(KeySelector>), + KeyNotFound(KeySelector, D>), } #[derive(Debug, Clone, FromRow)] @@ -42,6 +42,41 @@ where pub domains: sqlx::types::Json>, } +#[inline(always)] +fn build_predicate<'b, D>( + builder: &mut QueryBuilder<'b, Postgres>, + selector: &'b KeySelector, D>, +) where + D: PgKeyDomain, +{ + match selector { + KeySelector::Id(id) => builder.push("id=").push_bind(id), + KeySelector::UserId(user_id) => builder.push("user_id=").push_bind(user_id), + KeySelector::Key(key) => builder.push("key=").push_bind(key), + KeySelector::Has(domain) => builder + .push("domains @> ") + .push_bind(sqlx::types::Json(vec![domain])), + KeySelector::OneOf(domains) => { + if domains.is_empty() { + builder.push("false"); + return; + } + + for (idx, domain) in domains.iter().enumerate() { + if idx == 0 { + builder.push("("); + } else { + builder.push(" or "); + } + builder + .push("domains @> ") + .push_bind(sqlx::types::Json(vec![domain])); + } + builder.push(")") + } + }; +} + #[derive(Debug, Clone, FromRow)] pub struct PgKeyPoolStorage where @@ -160,7 +195,11 @@ where type Error = PgStorageError; - async fn acquire_key(&self, domain: D) -> Result { + async fn acquire_key(&self, selector: S) -> Result + where + S: IntoSelector, + { + let selector = selector.into_selector(); loop { let attempt = async { let mut tx = self.pool.begin().await?; @@ -169,22 +208,33 @@ where .execute(&mut tx) .await?; - // TODO: improve query - let key = sqlx::query_as(&indoc::formatdoc!( + let mut qb = QueryBuilder::new(indoc::indoc! { r#" with key as ( select id, 0::int2 as uses from api_keys where last_used < date_trunc('minute', now()) - and (cooldown is null or now() >= cooldown) - and domains @> $1 - union ( + and (cooldown is null or now() >= cooldown) + and "# + }); + + build_predicate(&mut qb, &selector); + + qb.push(indoc::indoc! { + " + \n union ( select id, uses from api_keys where last_used >= date_trunc('minute', now()) and (cooldown is null or now() >= cooldown) - and domains @> $1 - order by uses asc limit 1 + and " + }); + + build_predicate(&mut qb, &selector); + + qb.push(indoc::indoc! { + " + \n order by uses asc limit 1 ) order by uses asc limit 1 ) @@ -194,19 +244,21 @@ where flag = null, last_used = now() from key where - api_keys.id=key.id and key.uses < $2 - returning + api_keys.id=key.id and key.uses < " + }); + + qb.push_bind(self.limit); + + qb.push(indoc::indoc! { " + \nreturning api_keys.id, api_keys.user_id, api_keys.key, api_keys.uses, - api_keys.domains - "#, - )) - .bind(sqlx::types::Json(vec![&domain])) - .bind(self.limit) - .fetch_optional(&mut tx) - .await?; + api_keys.domains" + }); + + let key = qb.build_query_as().fetch_optional(&mut tx).await?; tx.commit().await?; @@ -219,9 +271,9 @@ where Ok(None) => { return self .acquire_key( - domain + selector .fallback() - .ok_or_else(|| PgStorageError::Unavailable(domain))?, + .ok_or_else(|| PgStorageError::Unavailable(selector))?, ) .await } @@ -241,11 +293,15 @@ where } } - async fn acquire_many_keys( + async fn acquire_many_keys( &self, - domain: D, + selector: S, number: i64, - ) -> Result, Self::Error> { + ) -> Result, Self::Error> + where + S: IntoSelector, + { + let selector = selector.into_selector(); loop { let attempt = async { let mut tx = self.pool.begin().await?; @@ -254,33 +310,36 @@ where .execute(&mut tx) .await?; - let mut keys: Vec = sqlx::query_as(&indoc::formatdoc!( + let mut qb = QueryBuilder::new(indoc::indoc! { r#"select id, user_id, key, 0::int2 as uses, domains - from api_keys where last_used < date_trunc('minute', now()) - and (cooldown is null or now() >= cooldown) - and domains @> $1 - union + from api_keys where last_used < date_trunc('minute', now()) + and (cooldown is null or now() >= cooldown) + and "# + }); + build_predicate(&mut qb, &selector); + qb.push(indoc::indoc! { + " + \nunion select id, user_id, key, uses, domains - from api_keys where last_used >= date_trunc('minute', now()) - and (cooldown is null or now() >= cooldown) - and domains @> $1 - order by uses limit $2 - "#, - )) - .bind(sqlx::types::Json(vec![&domain])) - .bind(number) - .fetch_all(&mut tx) - .await?; + from api_keys where last_used >= date_trunc('minute', now()) + and (cooldown is null or now() >= cooldown) + and " + }); + build_predicate(&mut qb, &selector); + qb.push("\norder by uses limit "); + qb.push_bind(self.limit); + + let mut keys: Vec = qb.build_query_as().fetch_all(&mut tx).await?; if keys.is_empty() { tx.commit().await?; @@ -338,9 +397,9 @@ where Ok(None) => { return self .acquire_many_keys( - domain + selector .fallback() - .ok_or_else(|| Self::Error::Unavailable(domain))?, + .ok_or_else(|| Self::Error::Unavailable(selector))?, number, ) .await @@ -433,143 +492,116 @@ where .map_err(Into::into) } - async fn read_key( - &self, - selector: KeySelector, - ) -> Result, Self::Error> { - match &selector { - KeySelector::Key(key) => sqlx::query_as("select * from api_keys where key=$1") - .bind(key) - .fetch_optional(&self.pool) - .await - .map_err(Into::into), - KeySelector::Id(id) => sqlx::query_as("select * from api_keys where id=$1") - .bind(id) - .fetch_optional(&self.pool) - .await - .map_err(Into::into), - } - } + async fn read_key(&self, selector: S) -> Result, Self::Error> + where + S: IntoSelector, + { + let selector = selector.into_selector(); - async fn query_key(&self, domain: D) -> Result, Self::Error> { - sqlx::query_as("select * from api_keys where domains @> $1 limit 1") - .bind(sqlx::types::Json(vec![domain])) + let mut qb = QueryBuilder::new("select * from api_keys where "); + build_predicate(&mut qb, &selector); + + qb.build_query_as() .fetch_optional(&self.pool) .await .map_err(Into::into) } - async fn query_all(&self, domain: D) -> Result, Self::Error> { - sqlx::query_as("select * from api_keys where domains @> $1") - .bind(sqlx::types::Json(vec![domain])) + async fn read_keys(&self, selector: S) -> Result, Self::Error> + where + S: IntoSelector, + { + let selector = selector.into_selector(); + + let mut qb = QueryBuilder::new("select * from api_keys where "); + build_predicate(&mut qb, &selector); + + qb.build_query_as() .fetch_all(&self.pool) .await .map_err(Into::into) } - async fn read_user_keys(&self, user_id: i32) -> Result, Self::Error> { - sqlx::query_as("select * from api_keys where user_id=$1") - .bind(user_id) - .fetch_all(&self.pool) - .await - .map_err(Into::into) + async fn remove_key(&self, selector: S) -> Result + where + S: IntoSelector, + { + let selector = selector.into_selector(); + + let mut qb = QueryBuilder::new("delete from api_keys where "); + build_predicate(&mut qb, &selector); + qb.push(" returning *"); + + qb.build_query_as() + .fetch_optional(&self.pool) + .await? + .ok_or_else(|| PgStorageError::KeyNotFound(selector)) } - async fn remove_key(&self, selector: KeySelector) -> Result { - match &selector { - KeySelector::Key(key) => { - sqlx::query_as("delete from api_keys where key=$1 returning *") - .bind(key) - .fetch_optional(&self.pool) - .await? - .ok_or_else(|| PgStorageError::KeyNotFound(selector)) - } - KeySelector::Id(id) => sqlx::query_as("delete from api_keys where id=$1 returning *") - .bind(id) - .fetch_optional(&self.pool) - .await? - .ok_or_else(|| PgStorageError::KeyNotFound(selector)), - } + async fn add_domain_to_key(&self, selector: S, domain: D) -> Result + where + S: IntoSelector, + { + let selector = selector.into_selector(); + + let mut qb = QueryBuilder::new( + "update api_keys set domains = __unique_jsonb_array(domains || jsonb_build_array(", + ); + qb.push_bind(sqlx::types::Json(domain)); + qb.push(")) where "); + build_predicate(&mut qb, &selector); + qb.push(" returning *"); + + qb.build_query_as() + .fetch_optional(&self.pool) + .await? + .ok_or_else(|| PgStorageError::KeyNotFound(selector)) } - async fn add_domain_to_key( + async fn remove_domain_from_key( &self, - selector: KeySelector, + selector: S, domain: D, - ) -> Result { - match &selector { - KeySelector::Key(key) => sqlx::query_as::>( - "update api_keys set domains = __unique_jsonb_array(domains || \ - jsonb_build_array($1)) where key=$2 returning *", - ) - .bind(sqlx::types::Json(domain)) - .bind(key) + ) -> Result + where + S: IntoSelector, + { + let selector = selector.into_selector(); + + let mut qb = QueryBuilder::new( + "update api_keys set domains = coalesce(__filter_jsonb_array(domains, ", + ); + qb.push_bind(sqlx::types::Json(domain)); + qb.push("), '[]'::jsonb) where "); + build_predicate(&mut qb, &selector); + qb.push(" returning *"); + + qb.build_query_as() .fetch_optional(&self.pool) .await? - .ok_or_else(|| PgStorageError::KeyNotFound(selector)), - KeySelector::Id(id) => sqlx::query_as::>( - "update api_keys set domains = __unique_jsonb_array(domains || \ - jsonb_build_array($1)) where id=$2 returning *", - ) - .bind(sqlx::types::Json(domain)) - .bind(id) - .fetch_optional(&self.pool) - .await? - .ok_or_else(|| PgStorageError::KeyNotFound(selector)), - } + .ok_or_else(|| PgStorageError::KeyNotFound(selector)) } - async fn remove_domain_from_key( + async fn set_domains_for_key( &self, - selector: KeySelector, - domain: D, - ) -> Result { - match &selector { - KeySelector::Key(key) => sqlx::query_as( - "update api_keys set domains = coalesce(__filter_jsonb_array(domains, $1), \ - '[]'::jsonb) where key=$2 returning *", - ) - .bind(sqlx::types::Json(domain)) - .bind(key) - .fetch_optional(&self.pool) - .await? - .ok_or_else(|| PgStorageError::KeyNotFound(selector)), - KeySelector::Id(id) => sqlx::query_as( - "update api_keys set domains = coalesce(__filter_jsonb_array(domains, $1), \ - '[]'::jsonb) where id=$2 returning *", - ) - .bind(sqlx::types::Json(domain)) - .bind(id) - .fetch_optional(&self.pool) - .await? - .ok_or_else(|| PgStorageError::KeyNotFound(selector)), - } - } - - async fn set_domains_for_key( - &self, - selector: KeySelector, + selector: S, domains: Vec, - ) -> Result { - match &selector { - KeySelector::Key(key) => sqlx::query_as::>( - "update api_keys set domains = $1 where key=$2 returning *", - ) - .bind(sqlx::types::Json(domains)) - .bind(key) - .fetch_optional(&self.pool) - .await? - .ok_or_else(|| PgStorageError::KeyNotFound(selector)), + ) -> Result + where + S: IntoSelector, + { + let selector = selector.into_selector(); - KeySelector::Id(id) => sqlx::query_as::>( - "update api_keys set domains = $1 where id=$2 returning *", - ) - .bind(sqlx::types::Json(domains)) - .bind(id) + let mut qb = QueryBuilder::new("update api_keys set domains = "); + qb.push_bind(sqlx::types::Json(domains)); + qb.push(" where "); + build_predicate(&mut qb, &selector); + qb.push(" returning *"); + + qb.build_query_as() .fetch_optional(&self.pool) .await? - .ok_or_else(|| PgStorageError::KeyNotFound(selector)), - } + .ok_or_else(|| PgStorageError::KeyNotFound(selector)) } } @@ -752,7 +784,7 @@ pub(crate) mod test { async fn test_read_user_keys() { let (storage, _) = setup().await; - let keys = storage.read_user_keys(1).await.unwrap(); + let keys = storage.read_keys(KeySelector::UserId(1)).await.unwrap(); assert_eq!(keys.len(), 1); } @@ -777,7 +809,7 @@ pub(crate) mod test { _ = storage.acquire_key(Domain::All).await.unwrap(); } - let keys = storage.read_user_keys(1).await.unwrap(); + let keys = storage.read_keys(KeySelector::UserId(1)).await.unwrap(); assert_eq!(keys.len(), 2); for key in keys { assert_eq!(key.uses, 5); @@ -791,7 +823,7 @@ pub(crate) mod test { assert!(storage.flag_key(key, 2).await.unwrap()); match storage.acquire_key(Domain::All).await.unwrap_err() { - PgStorageError::Unavailable(d) => assert_eq!(d, Domain::All), + PgStorageError::Unavailable(d) => assert!(matches!(d, KeySelector::Has(Domain::All))), why => panic!("Expected domain unavailable error but found '{why}'"), } } @@ -803,7 +835,7 @@ pub(crate) mod test { assert!(storage.flag_key(key, 2).await.unwrap()); match storage.acquire_many_keys(Domain::All, 5).await.unwrap_err() { - PgStorageError::Unavailable(d) => assert_eq!(d, Domain::All), + PgStorageError::Unavailable(d) => assert!(matches!(d, KeySelector::Has(Domain::All))), why => panic!("Expected domain unavailable error but found '{why}'"), } } @@ -877,7 +909,7 @@ pub(crate) mod test { set.join_next().await.unwrap().unwrap(); } - let keys = storage.read_user_keys(1).await.unwrap(); + let keys = storage.read_keys(KeySelector::UserId(1)).await.unwrap(); assert_eq!(keys.len(), 25); @@ -952,7 +984,7 @@ pub(crate) mod test { async fn query_key() { let (storage, _) = setup().await; - let key = storage.query_key(Domain::All).await.unwrap(); + let key = storage.read_key(Domain::All).await.unwrap(); assert!(key.is_some()); } @@ -960,7 +992,7 @@ pub(crate) mod test { async fn query_nonexistent_key() { let (storage, _) = setup().await; - let key = storage.query_key(Domain::Guild { id: 0 }).await.unwrap(); + let key = storage.read_key(Domain::Guild { id: 0 }).await.unwrap(); assert!(key.is_none()); } @@ -968,7 +1000,38 @@ pub(crate) mod test { async fn query_all() { let (storage, _) = setup().await; - let keys = storage.query_all(Domain::All).await.unwrap(); + let keys = storage.read_keys(Domain::All).await.unwrap(); assert!(keys.len() == 1); } + + #[test] + async fn query_by_id() { + let (storage, _) = setup().await; + let key = storage.read_key(KeySelector::Id(1)).await.unwrap(); + + assert!(key.is_some()); + } + + #[test] + async fn query_by_key() { + let (storage, key) = setup().await; + let key = storage.read_key(KeySelector::Key(key.key)).await.unwrap(); + + assert!(key.is_some()); + } + + #[test] + async fn query_by_set() { + let (storage, _key) = setup().await; + let key = storage + .read_key(KeySelector::OneOf(vec![ + Domain::All, + Domain::Guild { id: 0 }, + Domain::Faction { id: 0 }, + ])) + .await + .unwrap(); + + assert!(key.is_some()); + } } diff --git a/torn-key-pool/src/send.rs b/torn-key-pool/src/send.rs index d4d9bba..d1f95b5 100644 --- a/torn-key-pool/src/send.rs +++ b/torn-key-pool/src/send.rs @@ -7,7 +7,7 @@ use torn_api::{ ApiRequest, ApiResponse, ApiSelection, ResponseError, }; -use crate::{ApiKey, KeyPoolError, KeyPoolExecutor, KeyPoolStorage}; +use crate::{ApiKey, IntoSelector, KeyPoolError, KeyPoolExecutor, KeyPoolStorage}; #[async_trait] impl<'client, C, S> RequestExecutor for KeyPoolExecutor<'client, C, S> @@ -30,7 +30,7 @@ where loop { let key = self .storage - .acquire_key(self.domain.clone()) + .acquire_key(self.selector.clone()) .await .map_err(|e| KeyPoolError::Storage(Arc::new(e)))?; let url = request.url(key.value(), id.as_deref()); @@ -66,7 +66,7 @@ where { let keys = match self .storage - .acquire_many_keys(self.domain.clone(), ids.len() as i64) + .acquire_many_keys(self.selector.clone(), ids.len() as i64) .await { Ok(keys) => keys, @@ -114,7 +114,7 @@ where Ok(res) => return (id, Ok(res)), }; - key = match self.storage.acquire_key(self.domain.clone()).await { + key = match self.storage.acquire_key(self.selector.clone()).await { Ok(k) => k, Err(why) => return (id, Err(Self::Error::Storage(Arc::new(why)))), }; @@ -150,25 +150,36 @@ where } } - pub fn torn_api(&self, domain: S::Domain) -> ApiProvider> { + pub fn torn_api(&self, selector: I) -> ApiProvider> + where + I: IntoSelector, + { ApiProvider::new( &self.client, - KeyPoolExecutor::new(&self.storage, domain, self.comment.as_deref()), + KeyPoolExecutor::new( + &self.storage, + selector.into_selector(), + self.comment.as_deref(), + ), ) } } pub trait WithStorage { - fn with_storage<'a, S>( + fn with_storage<'a, S, I>( &'a self, storage: &'a S, - domain: S::Domain, + selector: I, ) -> ApiProvider> where Self: ApiClient + Sized, S: KeyPoolStorage + Send + Sync + 'static, + I: IntoSelector, { - ApiProvider::new(self, KeyPoolExecutor::new(storage, domain, None)) + ApiProvider::new( + self, + KeyPoolExecutor::new(storage, selector.into_selector(), None), + ) } }