Merge remote-tracking branch 'origin/master'

This commit is contained in:
TotallyNot 2023-07-29 17:40:10 +02:00
commit 4f3d62da95
14 changed files with 627 additions and 325 deletions

1
.gitignore vendored
View file

@ -1,3 +1,4 @@
/target /target
/Cargo.lock /Cargo.lock
.env .env
.DS_Store

View file

@ -1,6 +1,6 @@
[package] [package]
name = "torn-api" name = "torn-api"
version = "0.5.13" version = "0.5.19"
edition = "2021" edition = "2021"
authors = ["Pyrit [2111649]"] authors = ["Pyrit [2111649]"]
license = "MIT" license = "MIT"

View file

@ -56,3 +56,108 @@ pub struct Territory {
#[serde(deserialize_with = "de_util::string_or_decimal")] #[serde(deserialize_with = "de_util::string_or_decimal")]
pub coordinate_y: rust_decimal::Decimal, pub coordinate_y: rust_decimal::Decimal,
} }
#[derive(Debug, Clone, Copy, Deserialize)]
pub enum AttackResult {
Attacked,
Mugged,
Hospitalized,
Lost,
Arrested,
Escape,
Interrupted,
Assist,
Timeout,
Stalemate,
Special,
Looted,
}
#[derive(Debug, Clone, Deserialize)]
pub struct Attack<'a> {
pub code: &'a str,
#[serde(with = "ts_seconds")]
pub timestamp_started: DateTime<Utc>,
#[serde(with = "ts_seconds")]
pub timestamp_ended: DateTime<Utc>,
#[serde(deserialize_with = "de_util::empty_string_int_option")]
pub attacker_id: Option<i32>,
#[serde(deserialize_with = "de_util::empty_string_int_option")]
pub attacker_faction: Option<i32>,
pub defender_id: i32,
#[serde(deserialize_with = "de_util::empty_string_int_option")]
pub defender_faction: Option<i32>,
pub result: AttackResult,
#[serde(deserialize_with = "de_util::int_is_bool")]
pub stealthed: bool,
#[cfg(feature = "decimal")]
pub respect: rust_decimal::Decimal,
#[cfg(not(feature = "decimal"))]
pub respect: f32,
}
#[derive(Debug, Clone, Deserialize)]
pub struct RespectModifiers {
pub fair_fight: f32,
pub war: f32,
pub retaliation: f32,
pub group_attack: f32,
pub overseas: f32,
pub chain_bonus: f32,
}
#[derive(Debug, Clone, Deserialize)]
pub struct AttackFull<'a> {
pub code: &'a str,
#[serde(with = "ts_seconds")]
pub timestamp_started: DateTime<Utc>,
#[serde(with = "ts_seconds")]
pub timestamp_ended: DateTime<Utc>,
#[serde(deserialize_with = "de_util::empty_string_int_option")]
pub attacker_id: Option<i32>,
#[serde(deserialize_with = "de_util::empty_string_is_none")]
pub attacker_name: Option<&'a str>,
#[serde(deserialize_with = "de_util::empty_string_int_option")]
pub attacker_faction: Option<i32>,
#[serde(
deserialize_with = "de_util::empty_string_is_none",
rename = "attacker_factionname"
)]
pub attacker_faction_name: Option<&'a str>,
pub defender_id: i32,
pub defender_name: &'a str,
#[serde(deserialize_with = "de_util::empty_string_int_option")]
pub defender_faction: Option<i32>,
#[serde(
deserialize_with = "de_util::empty_string_is_none",
rename = "defender_factionname"
)]
pub defender_faction_name: Option<&'a str>,
pub result: AttackResult,
#[serde(deserialize_with = "de_util::int_is_bool")]
pub stealthed: bool,
#[serde(deserialize_with = "de_util::int_is_bool")]
pub raid: bool,
#[serde(deserialize_with = "de_util::int_is_bool")]
pub ranked_war: bool,
#[cfg(feature = "decimal")]
pub respect: rust_decimal::Decimal,
#[cfg(feature = "decimal")]
pub respect_loss: rust_decimal::Decimal,
#[cfg(not(feature = "decimal"))]
pub respect: f32,
#[cfg(not(feature = "decimal"))]
pub respect_loss: f32,
pub modifiers: RespectModifiers,
}

View file

@ -1,8 +1,8 @@
#![allow(unused)] #![allow(unused)]
use std::collections::BTreeMap; use std::collections::{BTreeMap, HashMap};
use chrono::{DateTime, NaiveDateTime, Utc}; use chrono::{serde::ts_nanoseconds::deserialize, DateTime, NaiveDateTime, Utc};
use serde::de::{Deserialize, Deserializer, Error, Unexpected, Visitor}; use serde::de::{Deserialize, Deserializer, Error, Unexpected, Visitor};
pub(crate) fn empty_string_is_none<'de, D>(deserializer: D) -> Result<Option<&'de str>, D::Error> pub(crate) fn empty_string_is_none<'de, D>(deserializer: D) -> Result<Option<&'de str>, D::Error>
@ -135,6 +135,62 @@ where
deserializer.deserialize_map(MapVisitor) deserializer.deserialize_map(MapVisitor)
} }
pub(crate) fn empty_dict_is_empty_array<'de, D, T>(deserializer: D) -> Result<Vec<T>, D::Error>
where
D: Deserializer<'de>,
T: Deserialize<'de>,
{
struct ArrayVisitor<T>(std::marker::PhantomData<T>);
impl<'de, T> Visitor<'de> for ArrayVisitor<T>
where
T: Deserialize<'de>,
{
type Value = Vec<T>;
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(formatter, "vec or empty object")
}
fn visit_map<A>(self, map: A) -> Result<Self::Value, A::Error>
where
A: serde::de::MapAccess<'de>,
{
match map.size_hint() {
Some(0) | None => Ok(Vec::default()),
Some(len) => Err(A::Error::invalid_length(len, &"empty dict")),
}
}
fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
where
A: serde::de::SeqAccess<'de>,
{
let mut result = match seq.size_hint() {
Some(len) => Vec::with_capacity(len),
None => Vec::default(),
};
while let Some(element) = seq.next_element()? {
result.push(element);
}
Ok(result)
}
}
deserializer.deserialize_any(ArrayVisitor(std::marker::PhantomData))
}
pub(crate) fn null_is_empty_dict<'de, D, K, V>(deserializer: D) -> Result<HashMap<K, V>, D::Error>
where
D: Deserializer<'de>,
K: std::hash::Hash + std::cmp::Eq + Deserialize<'de>,
V: Deserialize<'de>,
{
Ok(Option::deserialize(deserializer)?.unwrap_or_default())
}
#[cfg(feature = "decimal")] #[cfg(feature = "decimal")]
pub(crate) fn string_or_decimal<'de, D>(deserializer: D) -> Result<rust_decimal::Decimal, D::Error> pub(crate) fn string_or_decimal<'de, D>(deserializer: D) -> Result<rust_decimal::Decimal, D::Error>
where where

View file

@ -1,13 +1,13 @@
use std::collections::{BTreeMap, HashMap}; use std::collections::{BTreeMap, HashMap};
use chrono::{serde::ts_seconds, DateTime, Utc}; use chrono::{DateTime, Utc};
use serde::Deserialize; use serde::Deserialize;
use torn_api_macros::ApiCategory; use torn_api_macros::ApiCategory;
use crate::de_util; use crate::de_util::{self, null_is_empty_dict};
pub use crate::common::{LastAction, Status, Territory}; pub use crate::common::{Attack, AttackFull, LastAction, Status, Territory};
#[derive(Debug, Clone, Copy, ApiCategory)] #[derive(Debug, Clone, Copy, ApiCategory)]
#[api(category = "faction")] #[api(category = "faction")]
@ -21,7 +21,11 @@ pub enum Selection {
#[api(type = "BTreeMap<i32, AttackFull>", field = "attacks")] #[api(type = "BTreeMap<i32, AttackFull>", field = "attacks")]
Attacks, Attacks,
#[api(type = "HashMap<String, Territory>", field = "territory")] #[api(
type = "HashMap<String, Territory>",
field = "territory",
with = "null_is_empty_dict"
)]
Territory, Territory,
} }
@ -68,115 +72,10 @@ pub struct Basic<'a> {
#[serde(deserialize_with = "de_util::datetime_map")] #[serde(deserialize_with = "de_util::datetime_map")]
pub peace: BTreeMap<i32, DateTime<Utc>>, pub peace: BTreeMap<i32, DateTime<Utc>>,
#[serde(borrow)] #[serde(borrow, deserialize_with = "de_util::empty_dict_is_empty_array")]
pub territory_wars: Vec<FactionTerritoryWar<'a>>, pub territory_wars: Vec<FactionTerritoryWar<'a>>,
} }
#[derive(Debug, Clone, Copy, Deserialize)]
pub enum AttackResult {
Attacked,
Mugged,
Hospitalized,
Lost,
Arrested,
Escape,
Interrupted,
Assist,
Timeout,
Stalemate,
Special,
Looted,
}
#[derive(Debug, Clone, Deserialize)]
pub struct Attack<'a> {
pub code: &'a str,
#[serde(with = "ts_seconds")]
pub timestamp_started: DateTime<Utc>,
#[serde(with = "ts_seconds")]
pub timestamp_ended: DateTime<Utc>,
#[serde(deserialize_with = "de_util::empty_string_int_option")]
pub attacker_id: Option<i32>,
#[serde(deserialize_with = "de_util::empty_string_int_option")]
pub attacker_faction: Option<i32>,
pub defender_id: i32,
#[serde(deserialize_with = "de_util::empty_string_int_option")]
pub defender_faction: Option<i32>,
pub result: AttackResult,
#[serde(deserialize_with = "de_util::int_is_bool")]
pub stealthed: bool,
#[cfg(feature = "decimal")]
pub respect: rust_decimal::Decimal,
#[cfg(not(feature = "decimal"))]
pub respect: f32,
}
#[derive(Debug, Clone, Deserialize)]
pub struct RespectModifiers {
pub fair_fight: f32,
pub war: f32,
pub retaliation: f32,
pub group_attack: f32,
pub overseas: f32,
pub chain_bonus: f32,
}
#[derive(Debug, Clone, Deserialize)]
pub struct AttackFull<'a> {
pub code: &'a str,
#[serde(with = "ts_seconds")]
pub timestamp_started: DateTime<Utc>,
#[serde(with = "ts_seconds")]
pub timestamp_ended: DateTime<Utc>,
#[serde(deserialize_with = "de_util::empty_string_int_option")]
pub attacker_id: Option<i32>,
#[serde(deserialize_with = "de_util::empty_string_is_none")]
pub attacker_name: Option<&'a str>,
#[serde(deserialize_with = "de_util::empty_string_int_option")]
pub attacker_faction: Option<i32>,
#[serde(
deserialize_with = "de_util::empty_string_is_none",
rename = "attacker_factionname"
)]
pub attacker_faction_name: Option<&'a str>,
pub defender_id: i32,
pub defender_name: &'a str,
#[serde(deserialize_with = "de_util::empty_string_int_option")]
pub defender_faction: Option<i32>,
#[serde(
deserialize_with = "de_util::empty_string_is_none",
rename = "defender_factionname"
)]
pub defender_faction_name: Option<&'a str>,
pub result: AttackResult,
#[serde(deserialize_with = "de_util::int_is_bool")]
pub stealthed: bool,
#[serde(deserialize_with = "de_util::int_is_bool")]
pub raid: bool,
#[serde(deserialize_with = "de_util::int_is_bool")]
pub ranked_war: bool,
#[cfg(feature = "decimal")]
pub respect: rust_decimal::Decimal,
#[cfg(feature = "decimal")]
pub respect_loss: rust_decimal::Decimal,
#[cfg(not(feature = "decimal"))]
pub respect: f32,
#[cfg(not(feature = "decimal"))]
pub respect_loss: f32,
pub modifiers: RespectModifiers,
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
@ -199,4 +98,21 @@ mod tests {
response.attacks_full().unwrap(); response.attacks_full().unwrap();
response.territory().unwrap(); response.territory().unwrap();
} }
#[async_test]
async fn destroyed_faction() {
let key = setup();
let response = Client::default()
.torn_api(key)
.faction(|b| {
b.id(8981)
.selections(&[Selection::Basic, Selection::Territory])
})
.await
.unwrap();
response.basic().unwrap();
response.territory().unwrap();
}
} }

View file

@ -122,6 +122,8 @@ pub enum FactionSelection {
Upgrades, Upgrades,
Weapons, Weapons,
Lookup, Lookup,
Caches,
CrimeExp,
} }
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)] #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
@ -171,6 +173,8 @@ pub enum TornSelection {
Timestamp, Timestamp,
Lookup, Lookup,
CityShops, CityShops,
ItemDetails,
TerritoryNames,
} }
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)] #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]

View file

@ -22,7 +22,7 @@ pub mod awc;
pub mod reqwest; pub mod reqwest;
#[cfg(feature = "__common")] #[cfg(feature = "__common")]
mod common; pub mod common;
mod de_util; mod de_util;

View file

@ -22,6 +22,12 @@ pub enum Selection {
#[api(type = "HashMap<String, TerritoryWar>", field = "territorywars")] #[api(type = "HashMap<String, TerritoryWar>", field = "territorywars")]
TerritoryWars, TerritoryWars,
#[api(type = "HashMap<String, Racket>", field = "rackets")]
Rackets,
#[api(type = "HashMap<String, Territory>", field = "territory")]
Territory,
} }
#[derive(Deserialize)] #[derive(Deserialize)]
@ -111,6 +117,31 @@ pub struct TerritoryWar {
pub ends: DateTime<Utc>, pub ends: DateTime<Utc>,
} }
#[derive(Debug, Clone, Deserialize)]
pub struct Racket {
pub name: String,
pub level: i16,
pub reward: String,
#[serde(with = "chrono::serde::ts_seconds")]
pub created: DateTime<Utc>,
#[serde(with = "chrono::serde::ts_seconds")]
pub changed: DateTime<Utc>,
}
#[derive(Debug, Clone, Deserialize)]
pub struct Territory {
pub sector: i16,
pub size: i16,
pub slots: i16,
pub daily_respect: i16,
pub faction: i32,
pub neighbors: Vec<String>,
pub war: Option<TerritoryWar>,
pub racket: Option<Racket>,
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
@ -122,11 +153,32 @@ mod tests {
let response = Client::default() let response = Client::default()
.torn_api(key) .torn_api(key)
.torn(|b| b.selections(&[Selection::Competition, Selection::TerritoryWars])) .torn(|b| {
b.selections(&[
Selection::Competition,
Selection::TerritoryWars,
Selection::Rackets,
])
})
.await .await
.unwrap(); .unwrap();
response.competition().unwrap(); response.competition().unwrap();
response.territory_wars().unwrap(); response.territory_wars().unwrap();
response.rackets().unwrap();
}
#[async_test]
async fn territory() {
let key = setup();
let response = Client::default()
.torn_api(key)
.torn(|b| b.selections(&[Selection::Territory]).id("NSC"))
.await
.unwrap();
let territory = response.territory().unwrap();
assert!(territory.contains_key("NSC"));
} }
} }

View file

@ -2,12 +2,13 @@ use serde::{
de::{self, MapAccess, Visitor}, de::{self, MapAccess, Visitor},
Deserialize, Deserializer, Deserialize, Deserializer,
}; };
use std::collections::BTreeMap;
use torn_api_macros::ApiCategory; use torn_api_macros::ApiCategory;
use crate::de_util; use crate::de_util;
pub use crate::common::{LastAction, Status}; pub use crate::common::{Attack, AttackFull, LastAction, Status};
#[derive(Debug, Clone, Copy, ApiCategory)] #[derive(Debug, Clone, Copy, ApiCategory)]
#[api(category = "user")] #[api(category = "user")]
@ -22,6 +23,10 @@ pub enum Selection {
PersonalStats, PersonalStats,
#[api(type = "CriminalRecord", field = "criminalrecord")] #[api(type = "CriminalRecord", field = "criminalrecord")]
Crimes, Crimes,
#[api(type = "BTreeMap<i32, Attack>", field = "attacks")]
AttacksFull,
#[api(type = "BTreeMap<i32, AttackFull>", field = "attacks")]
Attacks,
} }
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize)] #[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize)]
@ -168,10 +173,14 @@ pub enum EliminationTeam {
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub enum Competition { pub enum Competition {
Elimination { Elimination {
score: i16, score: i32,
attacks: i16, attacks: i16,
team: EliminationTeam, team: EliminationTeam,
}, },
DogTags {
score: i32,
position: Option<i32>,
},
Unknown, Unknown,
} }
@ -187,6 +196,7 @@ where
Team, Team,
Attacks, Attacks,
TeamName, TeamName,
Position,
#[serde(other)] #[serde(other)]
Ignore, Ignore,
} }
@ -194,6 +204,8 @@ where
#[derive(Deserialize)] #[derive(Deserialize)]
enum CompetitionName { enum CompetitionName {
Elimination, Elimination,
#[serde(rename = "Dog Tags")]
DogTags,
#[serde(other)] #[serde(other)]
Unknown, Unknown,
} }
@ -229,6 +241,7 @@ where
let mut score = None; let mut score = None;
let mut attacks = None; let mut attacks = None;
let mut name = None; let mut name = None;
let mut position = None;
while let Some(key) = map.next_key()? { while let Some(key) = map.next_key()? {
match key { match key {
@ -241,6 +254,9 @@ where
Field::Attacks => { Field::Attacks => {
attacks = Some(map.next_value()?); attacks = Some(map.next_value()?);
} }
Field::Position => {
position = Some(map.next_value()?);
}
Field::Team => { Field::Team => {
let team_raw: &str = map.next_value()?; let team_raw: &str = map.next_value()?;
team = if team_raw.is_empty() { team = if team_raw.is_empty() {
@ -299,6 +315,12 @@ where
Ok(None) Ok(None)
} }
} }
CompetitionName::DogTags => {
let score = score.ok_or_else(|| de::Error::missing_field("score"))?;
let position = position.ok_or_else(|| de::Error::missing_field("position"))?;
Ok(Some(Competition::DogTags { score, position }))
}
CompetitionName::Unknown => Ok(Some(Competition::Unknown)), CompetitionName::Unknown => Ok(Some(Competition::Unknown)),
} }
} }
@ -416,6 +438,7 @@ mod tests {
Selection::Profile, Selection::Profile,
Selection::PersonalStats, Selection::PersonalStats,
Selection::Crimes, Selection::Crimes,
Selection::Attacks,
]) ])
}) })
.await .await
@ -426,6 +449,8 @@ mod tests {
response.profile().unwrap(); response.profile().unwrap();
response.personal_stats().unwrap(); response.personal_stats().unwrap();
response.crimes().unwrap(); response.crimes().unwrap();
response.attacks().unwrap();
response.attacks_full().unwrap();
} }
#[async_test] #[async_test]

View file

@ -1,6 +1,6 @@
[package] [package]
name = "torn-key-pool" name = "torn-key-pool"
version = "0.5.7" version = "0.6.1"
edition = "2021" edition = "2021"
authors = ["Pyrit [2111649]"] authors = ["Pyrit [2111649]"]
license = "MIT" license = "MIT"

View file

@ -29,8 +29,8 @@ where
Response(ResponseError), Response(ResponseError),
} }
pub trait ApiKey: Sync + Send { pub trait ApiKey: Sync + Send + std::fmt::Debug + Clone {
type IdType: PartialEq + Eq + std::hash::Hash; type IdType: PartialEq + Eq + std::hash::Hash + Send + Sync + std::fmt::Debug + Clone;
fn value(&self) -> &str; fn value(&self) -> &str;
@ -44,12 +44,65 @@ pub trait KeyDomain: Clone + std::fmt::Debug + Send + Sync {
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub enum KeySelector<K> pub enum KeySelector<K, D>
where where
K: ApiKey, K: ApiKey,
D: KeyDomain,
{ {
Key(String), Key(String),
Id(K::IdType), Id(K::IdType),
UserId(i32),
Has(D),
OneOf(Vec<D>),
}
impl<K, D> KeySelector<K, D>
where
K: ApiKey,
D: KeyDomain,
{
pub(crate) fn fallback(&self) -> Option<Self> {
match self {
Self::Key(_) | Self::UserId(_) | Self::Id(_) => None,
Self::Has(domain) => domain.fallback().map(Self::Has),
Self::OneOf(domains) => {
let fallbacks: Vec<_> = domains.iter().filter_map(|d| d.fallback()).collect();
if fallbacks.is_empty() {
None
} else {
Some(Self::OneOf(fallbacks))
}
}
}
}
}
pub trait IntoSelector<K, D>: Send + Sync
where
K: ApiKey,
D: KeyDomain,
{
fn into_selector(self) -> KeySelector<K, D>;
}
impl<K, D> IntoSelector<K, D> for D
where
K: ApiKey,
D: KeyDomain,
{
fn into_selector(self) -> KeySelector<K, D> {
KeySelector::Has(self)
}
}
impl<K, D> IntoSelector<K, D> for KeySelector<K, D>
where
K: ApiKey,
D: KeyDomain,
{
fn into_selector(self) -> KeySelector<K, D> {
self
}
} }
#[async_trait] #[async_trait]
@ -58,13 +111,17 @@ pub trait KeyPoolStorage {
type Domain: KeyDomain; type Domain: KeyDomain;
type Error: std::error::Error + Sync + Send; type Error: std::error::Error + Sync + Send;
async fn acquire_key(&self, domain: Self::Domain) -> Result<Self::Key, Self::Error>; async fn acquire_key<S>(&self, selector: S) -> Result<Self::Key, Self::Error>
where
S: IntoSelector<Self::Key, Self::Domain>;
async fn acquire_many_keys( async fn acquire_many_keys<S>(
&self, &self,
domain: Self::Domain, selector: S,
number: i64, number: i64,
) -> Result<Vec<Self::Key>, Self::Error>; ) -> Result<Vec<Self::Key>, Self::Error>
where
S: IntoSelector<Self::Key, Self::Domain>;
async fn flag_key(&self, key: Self::Key, code: u8) -> Result<bool, Self::Error>; async fn flag_key(&self, key: Self::Key, code: u8) -> Result<bool, Self::Error>;
@ -75,34 +132,41 @@ pub trait KeyPoolStorage {
domains: Vec<Self::Domain>, domains: Vec<Self::Domain>,
) -> Result<Self::Key, Self::Error>; ) -> Result<Self::Key, Self::Error>;
async fn read_key(&self, key: KeySelector<Self::Key>) async fn read_key<S>(&self, selector: S) -> Result<Option<Self::Key>, Self::Error>
-> Result<Option<Self::Key>, Self::Error>; where
S: IntoSelector<Self::Key, Self::Domain>;
async fn read_user_keys(&self, user_id: i32) -> Result<Vec<Self::Key>, Self::Error>; async fn read_keys<S>(&self, selector: S) -> Result<Vec<Self::Key>, Self::Error>
where
S: IntoSelector<Self::Key, Self::Domain>;
async fn remove_key(&self, key: KeySelector<Self::Key>) -> Result<Self::Key, Self::Error>; async fn remove_key<S>(&self, selector: S) -> Result<Self::Key, Self::Error>
where
S: IntoSelector<Self::Key, Self::Domain>;
async fn query_key(&self, domain: Self::Domain) -> Result<Option<Self::Key>, Self::Error>; async fn add_domain_to_key<S>(
async fn query_all(&self, domain: Self::Domain) -> Result<Vec<Self::Key>, Self::Error>;
async fn add_domain_to_key(
&self, &self,
key: KeySelector<Self::Key>, selector: S,
domain: Self::Domain, domain: Self::Domain,
) -> Result<Self::Key, Self::Error>; ) -> Result<Self::Key, Self::Error>
where
S: IntoSelector<Self::Key, Self::Domain>;
async fn remove_domain_from_key( async fn remove_domain_from_key<S>(
&self, &self,
key: KeySelector<Self::Key>, selector: S,
domain: Self::Domain, domain: Self::Domain,
) -> Result<Self::Key, Self::Error>; ) -> Result<Self::Key, Self::Error>
where
S: IntoSelector<Self::Key, Self::Domain>;
async fn set_domains_for_key( async fn set_domains_for_key<S>(
&self, &self,
key: KeySelector<Self::Key>, selector: S,
domains: Vec<Self::Domain>, domains: Vec<Self::Domain>,
) -> Result<Self::Key, Self::Error>; ) -> Result<Self::Key, Self::Error>
where
S: IntoSelector<Self::Key, Self::Domain>;
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
@ -112,7 +176,7 @@ where
{ {
storage: &'a S, storage: &'a S,
comment: Option<&'a str>, comment: Option<&'a str>,
domain: S::Domain, selector: KeySelector<S::Key, S::Domain>,
_marker: std::marker::PhantomData<C>, _marker: std::marker::PhantomData<C>,
} }
@ -120,10 +184,14 @@ impl<'a, C, S> KeyPoolExecutor<'a, C, S>
where where
S: KeyPoolStorage, S: KeyPoolStorage,
{ {
pub fn new(storage: &'a S, domain: S::Domain, comment: Option<&'a str>) -> Self { pub fn new(
storage: &'a S,
selector: KeySelector<S::Key, S::Domain>,
comment: Option<&'a str>,
) -> Self {
Self { Self {
storage, storage,
domain, selector,
comment, comment,
_marker: std::marker::PhantomData, _marker: std::marker::PhantomData,
} }

View file

@ -7,7 +7,7 @@ use torn_api::{
ApiRequest, ApiResponse, ApiSelection, ResponseError, ApiRequest, ApiResponse, ApiSelection, ResponseError,
}; };
use crate::{ApiKey, KeyPoolError, KeyPoolExecutor, KeyPoolStorage}; use crate::{ApiKey, KeyPoolError, KeyPoolExecutor, KeyPoolStorage, IntoSelector};
#[async_trait(?Send)] #[async_trait(?Send)]
impl<'client, C, S> RequestExecutor<C> for KeyPoolExecutor<'client, C, S> impl<'client, C, S> RequestExecutor<C> for KeyPoolExecutor<'client, C, S>
@ -30,7 +30,7 @@ where
loop { loop {
let key = self let key = self
.storage .storage
.acquire_key(self.domain.clone()) .acquire_key(self.selector.clone())
.await .await
.map_err(|e| KeyPoolError::Storage(Arc::new(e)))?; .map_err(|e| KeyPoolError::Storage(Arc::new(e)))?;
let url = request.url(key.value(), id.as_deref()); let url = request.url(key.value(), id.as_deref());
@ -66,7 +66,7 @@ where
{ {
let keys = match self let keys = match self
.storage .storage
.acquire_many_keys(self.domain.clone(), ids.len() as i64) .acquire_many_keys(self.selector.clone(), ids.len() as i64)
.await .await
{ {
Ok(keys) => keys, Ok(keys) => keys,
@ -114,7 +114,7 @@ where
Ok(res) => return (id, Ok(res)), Ok(res) => return (id, Ok(res)),
}; };
key = match self.storage.acquire_key(self.domain.clone()).await { key = match self.storage.acquire_key(self.selector.clone()).await {
Ok(k) => k, Ok(k) => k,
Err(why) => return (id, Err(Self::Error::Storage(Arc::new(why)))), Err(why) => return (id, Err(Self::Error::Storage(Arc::new(why)))),
}; };
@ -150,25 +150,26 @@ where
} }
} }
pub fn torn_api(&self, domain: S::Domain) -> ApiProvider<C, KeyPoolExecutor<C, S>> { pub fn torn_api<I>(&self, selector: I) -> ApiProvider<C, KeyPoolExecutor<C, S>> where I: IntoSelector<S::Key, S::Domain> {
ApiProvider::new( ApiProvider::new(
&self.client, &self.client,
KeyPoolExecutor::new(&self.storage, domain, self.comment.as_deref()), KeyPoolExecutor::new(&self.storage, selector.into_selector(), self.comment.as_deref()),
) )
} }
} }
pub trait WithStorage { pub trait WithStorage {
fn with_storage<'a, S>( fn with_storage<'a, S, I>(
&'a self, &'a self,
storage: &'a S, storage: &'a S,
domain: S::Domain, selector: I
) -> ApiProvider<Self, KeyPoolExecutor<Self, S>> ) -> ApiProvider<Self, KeyPoolExecutor<Self, S>>
where where
Self: ApiClient + Sized, Self: ApiClient + Sized,
S: KeyPoolStorage + 'static, S: KeyPoolStorage + 'static,
I: IntoSelector<S::Key, S::Domain>
{ {
ApiProvider::new(self, KeyPoolExecutor::new(storage, domain, None)) ApiProvider::new(self, KeyPoolExecutor::new(storage, selector.into_selector(), None))
} }
} }

View file

@ -1,9 +1,9 @@
use async_trait::async_trait; use async_trait::async_trait;
use indoc::indoc; use indoc::indoc;
use sqlx::{FromRow, PgPool}; use sqlx::{FromRow, PgPool, Postgres, QueryBuilder};
use thiserror::Error; use thiserror::Error;
use crate::{ApiKey, KeyDomain, KeyPoolStorage, KeySelector}; use crate::{ApiKey, IntoSelector, KeyDomain, KeyPoolStorage, KeySelector};
pub trait PgKeyDomain: pub trait PgKeyDomain:
KeyDomain + serde::Serialize + serde::de::DeserializeOwned + Eq + Unpin KeyDomain + serde::Serialize + serde::de::DeserializeOwned + Eq + Unpin
@ -24,10 +24,10 @@ where
Pg(#[from] sqlx::Error), Pg(#[from] sqlx::Error),
#[error("No key avalaible for domain {0:?}")] #[error("No key avalaible for domain {0:?}")]
Unavailable(D), Unavailable(KeySelector<PgKey<D>, D>),
#[error("Key not found: '{0:?}'")] #[error("Key not found: '{0:?}'")]
KeyNotFound(KeySelector<PgKey<D>>), KeyNotFound(KeySelector<PgKey<D>, D>),
} }
#[derive(Debug, Clone, FromRow)] #[derive(Debug, Clone, FromRow)]
@ -42,6 +42,41 @@ where
pub domains: sqlx::types::Json<Vec<D>>, pub domains: sqlx::types::Json<Vec<D>>,
} }
#[inline(always)]
fn build_predicate<'b, D>(
builder: &mut QueryBuilder<'b, Postgres>,
selector: &'b KeySelector<PgKey<D>, D>,
) where
D: PgKeyDomain,
{
match selector {
KeySelector::Id(id) => builder.push("id=").push_bind(id),
KeySelector::UserId(user_id) => builder.push("user_id=").push_bind(user_id),
KeySelector::Key(key) => builder.push("key=").push_bind(key),
KeySelector::Has(domain) => builder
.push("domains @> ")
.push_bind(sqlx::types::Json(vec![domain])),
KeySelector::OneOf(domains) => {
if domains.is_empty() {
builder.push("false");
return;
}
for (idx, domain) in domains.iter().enumerate() {
if idx == 0 {
builder.push("(");
} else {
builder.push(" or ");
}
builder
.push("domains @> ")
.push_bind(sqlx::types::Json(vec![domain]));
}
builder.push(")")
}
};
}
#[derive(Debug, Clone, FromRow)] #[derive(Debug, Clone, FromRow)]
pub struct PgKeyPoolStorage<D> pub struct PgKeyPoolStorage<D>
where where
@ -160,7 +195,11 @@ where
type Error = PgStorageError<D>; type Error = PgStorageError<D>;
async fn acquire_key(&self, domain: D) -> Result<Self::Key, Self::Error> { async fn acquire_key<S>(&self, selector: S) -> Result<Self::Key, Self::Error>
where
S: IntoSelector<Self::Key, Self::Domain>,
{
let selector = selector.into_selector();
loop { loop {
let attempt = async { let attempt = async {
let mut tx = self.pool.begin().await?; let mut tx = self.pool.begin().await?;
@ -169,22 +208,33 @@ where
.execute(&mut tx) .execute(&mut tx)
.await?; .await?;
// TODO: improve query let mut qb = QueryBuilder::new(indoc::indoc! {
let key = sqlx::query_as(&indoc::formatdoc!(
r#" r#"
with key as ( with key as (
select select
id, id,
0::int2 as uses 0::int2 as uses
from api_keys where last_used < date_trunc('minute', now()) from api_keys where last_used < date_trunc('minute', now())
and (cooldown is null or now() >= cooldown) and (cooldown is null or now() >= cooldown)
and domains @> $1 and "#
union ( });
build_predicate(&mut qb, &selector);
qb.push(indoc::indoc! {
"
\n union (
select id, uses from api_keys select id, uses from api_keys
where last_used >= date_trunc('minute', now()) where last_used >= date_trunc('minute', now())
and (cooldown is null or now() >= cooldown) and (cooldown is null or now() >= cooldown)
and domains @> $1 and "
order by uses asc limit 1 });
build_predicate(&mut qb, &selector);
qb.push(indoc::indoc! {
"
\n order by uses asc limit 1
) )
order by uses asc limit 1 order by uses asc limit 1
) )
@ -194,19 +244,21 @@ where
flag = null, flag = null,
last_used = now() last_used = now()
from key where from key where
api_keys.id=key.id and key.uses < $2 api_keys.id=key.id and key.uses < "
returning });
qb.push_bind(self.limit);
qb.push(indoc::indoc! { "
\nreturning
api_keys.id, api_keys.id,
api_keys.user_id, api_keys.user_id,
api_keys.key, api_keys.key,
api_keys.uses, api_keys.uses,
api_keys.domains api_keys.domains"
"#, });
))
.bind(sqlx::types::Json(vec![&domain])) let key = qb.build_query_as().fetch_optional(&mut tx).await?;
.bind(self.limit)
.fetch_optional(&mut tx)
.await?;
tx.commit().await?; tx.commit().await?;
@ -219,9 +271,9 @@ where
Ok(None) => { Ok(None) => {
return self return self
.acquire_key( .acquire_key(
domain selector
.fallback() .fallback()
.ok_or_else(|| PgStorageError::Unavailable(domain))?, .ok_or_else(|| PgStorageError::Unavailable(selector))?,
) )
.await .await
} }
@ -241,11 +293,15 @@ where
} }
} }
async fn acquire_many_keys( async fn acquire_many_keys<S>(
&self, &self,
domain: D, selector: S,
number: i64, number: i64,
) -> Result<Vec<Self::Key>, Self::Error> { ) -> Result<Vec<Self::Key>, Self::Error>
where
S: IntoSelector<Self::Key, Self::Domain>,
{
let selector = selector.into_selector();
loop { loop {
let attempt = async { let attempt = async {
let mut tx = self.pool.begin().await?; let mut tx = self.pool.begin().await?;
@ -254,33 +310,36 @@ where
.execute(&mut tx) .execute(&mut tx)
.await?; .await?;
let mut keys: Vec<Self::Key> = sqlx::query_as(&indoc::formatdoc!( let mut qb = QueryBuilder::new(indoc::indoc! {
r#"select r#"select
id, id,
user_id, user_id,
key, key,
0::int2 as uses, 0::int2 as uses,
domains domains
from api_keys where last_used < date_trunc('minute', now()) from api_keys where last_used < date_trunc('minute', now())
and (cooldown is null or now() >= cooldown) and (cooldown is null or now() >= cooldown)
and domains @> $1 and "#
union });
build_predicate(&mut qb, &selector);
qb.push(indoc::indoc! {
"
\nunion
select select
id, id,
user_id, user_id,
key, key,
uses, uses,
domains domains
from api_keys where last_used >= date_trunc('minute', now()) from api_keys where last_used >= date_trunc('minute', now())
and (cooldown is null or now() >= cooldown) and (cooldown is null or now() >= cooldown)
and domains @> $1 and "
order by uses limit $2 });
"#, build_predicate(&mut qb, &selector);
)) qb.push("\norder by uses limit ");
.bind(sqlx::types::Json(vec![&domain])) qb.push_bind(self.limit);
.bind(number)
.fetch_all(&mut tx) let mut keys: Vec<Self::Key> = qb.build_query_as().fetch_all(&mut tx).await?;
.await?;
if keys.is_empty() { if keys.is_empty() {
tx.commit().await?; tx.commit().await?;
@ -338,9 +397,9 @@ where
Ok(None) => { Ok(None) => {
return self return self
.acquire_many_keys( .acquire_many_keys(
domain selector
.fallback() .fallback()
.ok_or_else(|| Self::Error::Unavailable(domain))?, .ok_or_else(|| Self::Error::Unavailable(selector))?,
number, number,
) )
.await .await
@ -433,143 +492,116 @@ where
.map_err(Into::into) .map_err(Into::into)
} }
async fn read_key( async fn read_key<S>(&self, selector: S) -> Result<Option<Self::Key>, Self::Error>
&self, where
selector: KeySelector<Self::Key>, S: IntoSelector<Self::Key, Self::Domain>,
) -> Result<Option<Self::Key>, Self::Error> { {
match &selector { let selector = selector.into_selector();
KeySelector::Key(key) => sqlx::query_as("select * from api_keys where key=$1")
.bind(key)
.fetch_optional(&self.pool)
.await
.map_err(Into::into),
KeySelector::Id(id) => sqlx::query_as("select * from api_keys where id=$1")
.bind(id)
.fetch_optional(&self.pool)
.await
.map_err(Into::into),
}
}
async fn query_key(&self, domain: D) -> Result<Option<Self::Key>, Self::Error> { let mut qb = QueryBuilder::new("select * from api_keys where ");
sqlx::query_as("select * from api_keys where domains @> $1 limit 1") build_predicate(&mut qb, &selector);
.bind(sqlx::types::Json(vec![domain]))
qb.build_query_as()
.fetch_optional(&self.pool) .fetch_optional(&self.pool)
.await .await
.map_err(Into::into) .map_err(Into::into)
} }
async fn query_all(&self, domain: D) -> Result<Vec<Self::Key>, Self::Error> { async fn read_keys<S>(&self, selector: S) -> Result<Vec<Self::Key>, Self::Error>
sqlx::query_as("select * from api_keys where domains @> $1") where
.bind(sqlx::types::Json(vec![domain])) S: IntoSelector<Self::Key, Self::Domain>,
{
let selector = selector.into_selector();
let mut qb = QueryBuilder::new("select * from api_keys where ");
build_predicate(&mut qb, &selector);
qb.build_query_as()
.fetch_all(&self.pool) .fetch_all(&self.pool)
.await .await
.map_err(Into::into) .map_err(Into::into)
} }
async fn read_user_keys(&self, user_id: i32) -> Result<Vec<Self::Key>, Self::Error> { async fn remove_key<S>(&self, selector: S) -> Result<Self::Key, Self::Error>
sqlx::query_as("select * from api_keys where user_id=$1") where
.bind(user_id) S: IntoSelector<Self::Key, Self::Domain>,
.fetch_all(&self.pool) {
.await let selector = selector.into_selector();
.map_err(Into::into)
let mut qb = QueryBuilder::new("delete from api_keys where ");
build_predicate(&mut qb, &selector);
qb.push(" returning *");
qb.build_query_as()
.fetch_optional(&self.pool)
.await?
.ok_or_else(|| PgStorageError::KeyNotFound(selector))
} }
async fn remove_key(&self, selector: KeySelector<Self::Key>) -> Result<Self::Key, Self::Error> { async fn add_domain_to_key<S>(&self, selector: S, domain: D) -> Result<Self::Key, Self::Error>
match &selector { where
KeySelector::Key(key) => { S: IntoSelector<Self::Key, Self::Domain>,
sqlx::query_as("delete from api_keys where key=$1 returning *") {
.bind(key) let selector = selector.into_selector();
.fetch_optional(&self.pool)
.await? let mut qb = QueryBuilder::new(
.ok_or_else(|| PgStorageError::KeyNotFound(selector)) "update api_keys set domains = __unique_jsonb_array(domains || jsonb_build_array(",
} );
KeySelector::Id(id) => sqlx::query_as("delete from api_keys where id=$1 returning *") qb.push_bind(sqlx::types::Json(domain));
.bind(id) qb.push(")) where ");
.fetch_optional(&self.pool) build_predicate(&mut qb, &selector);
.await? qb.push(" returning *");
.ok_or_else(|| PgStorageError::KeyNotFound(selector)),
} qb.build_query_as()
.fetch_optional(&self.pool)
.await?
.ok_or_else(|| PgStorageError::KeyNotFound(selector))
} }
async fn add_domain_to_key( async fn remove_domain_from_key<S>(
&self, &self,
selector: KeySelector<Self::Key>, selector: S,
domain: D, domain: D,
) -> Result<Self::Key, Self::Error> { ) -> Result<Self::Key, Self::Error>
match &selector { where
KeySelector::Key(key) => sqlx::query_as::<sqlx::Postgres, PgKey<D>>( S: IntoSelector<Self::Key, Self::Domain>,
"update api_keys set domains = __unique_jsonb_array(domains || \ {
jsonb_build_array($1)) where key=$2 returning *", let selector = selector.into_selector();
)
.bind(sqlx::types::Json(domain)) let mut qb = QueryBuilder::new(
.bind(key) "update api_keys set domains = coalesce(__filter_jsonb_array(domains, ",
);
qb.push_bind(sqlx::types::Json(domain));
qb.push("), '[]'::jsonb) where ");
build_predicate(&mut qb, &selector);
qb.push(" returning *");
qb.build_query_as()
.fetch_optional(&self.pool) .fetch_optional(&self.pool)
.await? .await?
.ok_or_else(|| PgStorageError::KeyNotFound(selector)), .ok_or_else(|| PgStorageError::KeyNotFound(selector))
KeySelector::Id(id) => sqlx::query_as::<sqlx::Postgres, PgKey<D>>(
"update api_keys set domains = __unique_jsonb_array(domains || \
jsonb_build_array($1)) where id=$2 returning *",
)
.bind(sqlx::types::Json(domain))
.bind(id)
.fetch_optional(&self.pool)
.await?
.ok_or_else(|| PgStorageError::KeyNotFound(selector)),
}
} }
async fn remove_domain_from_key( async fn set_domains_for_key<S>(
&self, &self,
selector: KeySelector<Self::Key>, selector: S,
domain: D,
) -> Result<Self::Key, Self::Error> {
match &selector {
KeySelector::Key(key) => sqlx::query_as(
"update api_keys set domains = coalesce(__filter_jsonb_array(domains, $1), \
'[]'::jsonb) where key=$2 returning *",
)
.bind(sqlx::types::Json(domain))
.bind(key)
.fetch_optional(&self.pool)
.await?
.ok_or_else(|| PgStorageError::KeyNotFound(selector)),
KeySelector::Id(id) => sqlx::query_as(
"update api_keys set domains = coalesce(__filter_jsonb_array(domains, $1), \
'[]'::jsonb) where id=$2 returning *",
)
.bind(sqlx::types::Json(domain))
.bind(id)
.fetch_optional(&self.pool)
.await?
.ok_or_else(|| PgStorageError::KeyNotFound(selector)),
}
}
async fn set_domains_for_key(
&self,
selector: KeySelector<Self::Key>,
domains: Vec<D>, domains: Vec<D>,
) -> Result<Self::Key, Self::Error> { ) -> Result<Self::Key, Self::Error>
match &selector { where
KeySelector::Key(key) => sqlx::query_as::<sqlx::Postgres, PgKey<D>>( S: IntoSelector<Self::Key, Self::Domain>,
"update api_keys set domains = $1 where key=$2 returning *", {
) let selector = selector.into_selector();
.bind(sqlx::types::Json(domains))
.bind(key)
.fetch_optional(&self.pool)
.await?
.ok_or_else(|| PgStorageError::KeyNotFound(selector)),
KeySelector::Id(id) => sqlx::query_as::<sqlx::Postgres, PgKey<D>>( let mut qb = QueryBuilder::new("update api_keys set domains = ");
"update api_keys set domains = $1 where id=$2 returning *", qb.push_bind(sqlx::types::Json(domains));
) qb.push(" where ");
.bind(sqlx::types::Json(domains)) build_predicate(&mut qb, &selector);
.bind(id) qb.push(" returning *");
qb.build_query_as()
.fetch_optional(&self.pool) .fetch_optional(&self.pool)
.await? .await?
.ok_or_else(|| PgStorageError::KeyNotFound(selector)), .ok_or_else(|| PgStorageError::KeyNotFound(selector))
}
} }
} }
@ -752,7 +784,7 @@ pub(crate) mod test {
async fn test_read_user_keys() { async fn test_read_user_keys() {
let (storage, _) = setup().await; let (storage, _) = setup().await;
let keys = storage.read_user_keys(1).await.unwrap(); let keys = storage.read_keys(KeySelector::UserId(1)).await.unwrap();
assert_eq!(keys.len(), 1); assert_eq!(keys.len(), 1);
} }
@ -777,7 +809,7 @@ pub(crate) mod test {
_ = storage.acquire_key(Domain::All).await.unwrap(); _ = storage.acquire_key(Domain::All).await.unwrap();
} }
let keys = storage.read_user_keys(1).await.unwrap(); let keys = storage.read_keys(KeySelector::UserId(1)).await.unwrap();
assert_eq!(keys.len(), 2); assert_eq!(keys.len(), 2);
for key in keys { for key in keys {
assert_eq!(key.uses, 5); assert_eq!(key.uses, 5);
@ -791,7 +823,7 @@ pub(crate) mod test {
assert!(storage.flag_key(key, 2).await.unwrap()); assert!(storage.flag_key(key, 2).await.unwrap());
match storage.acquire_key(Domain::All).await.unwrap_err() { match storage.acquire_key(Domain::All).await.unwrap_err() {
PgStorageError::Unavailable(d) => assert_eq!(d, Domain::All), PgStorageError::Unavailable(d) => assert!(matches!(d, KeySelector::Has(Domain::All))),
why => panic!("Expected domain unavailable error but found '{why}'"), why => panic!("Expected domain unavailable error but found '{why}'"),
} }
} }
@ -803,7 +835,7 @@ pub(crate) mod test {
assert!(storage.flag_key(key, 2).await.unwrap()); assert!(storage.flag_key(key, 2).await.unwrap());
match storage.acquire_many_keys(Domain::All, 5).await.unwrap_err() { match storage.acquire_many_keys(Domain::All, 5).await.unwrap_err() {
PgStorageError::Unavailable(d) => assert_eq!(d, Domain::All), PgStorageError::Unavailable(d) => assert!(matches!(d, KeySelector::Has(Domain::All))),
why => panic!("Expected domain unavailable error but found '{why}'"), why => panic!("Expected domain unavailable error but found '{why}'"),
} }
} }
@ -877,7 +909,7 @@ pub(crate) mod test {
set.join_next().await.unwrap().unwrap(); set.join_next().await.unwrap().unwrap();
} }
let keys = storage.read_user_keys(1).await.unwrap(); let keys = storage.read_keys(KeySelector::UserId(1)).await.unwrap();
assert_eq!(keys.len(), 25); assert_eq!(keys.len(), 25);
@ -952,7 +984,7 @@ pub(crate) mod test {
async fn query_key() { async fn query_key() {
let (storage, _) = setup().await; let (storage, _) = setup().await;
let key = storage.query_key(Domain::All).await.unwrap(); let key = storage.read_key(Domain::All).await.unwrap();
assert!(key.is_some()); assert!(key.is_some());
} }
@ -960,7 +992,7 @@ pub(crate) mod test {
async fn query_nonexistent_key() { async fn query_nonexistent_key() {
let (storage, _) = setup().await; let (storage, _) = setup().await;
let key = storage.query_key(Domain::Guild { id: 0 }).await.unwrap(); let key = storage.read_key(Domain::Guild { id: 0 }).await.unwrap();
assert!(key.is_none()); assert!(key.is_none());
} }
@ -968,7 +1000,38 @@ pub(crate) mod test {
async fn query_all() { async fn query_all() {
let (storage, _) = setup().await; let (storage, _) = setup().await;
let keys = storage.query_all(Domain::All).await.unwrap(); let keys = storage.read_keys(Domain::All).await.unwrap();
assert!(keys.len() == 1); assert!(keys.len() == 1);
} }
#[test]
async fn query_by_id() {
let (storage, _) = setup().await;
let key = storage.read_key(KeySelector::Id(1)).await.unwrap();
assert!(key.is_some());
}
#[test]
async fn query_by_key() {
let (storage, key) = setup().await;
let key = storage.read_key(KeySelector::Key(key.key)).await.unwrap();
assert!(key.is_some());
}
#[test]
async fn query_by_set() {
let (storage, _key) = setup().await;
let key = storage
.read_key(KeySelector::OneOf(vec![
Domain::All,
Domain::Guild { id: 0 },
Domain::Faction { id: 0 },
]))
.await
.unwrap();
assert!(key.is_some());
}
} }

View file

@ -7,7 +7,7 @@ use torn_api::{
ApiRequest, ApiResponse, ApiSelection, ResponseError, ApiRequest, ApiResponse, ApiSelection, ResponseError,
}; };
use crate::{ApiKey, KeyPoolError, KeyPoolExecutor, KeyPoolStorage}; use crate::{ApiKey, IntoSelector, KeyPoolError, KeyPoolExecutor, KeyPoolStorage};
#[async_trait] #[async_trait]
impl<'client, C, S> RequestExecutor<C> for KeyPoolExecutor<'client, C, S> impl<'client, C, S> RequestExecutor<C> for KeyPoolExecutor<'client, C, S>
@ -30,7 +30,7 @@ where
loop { loop {
let key = self let key = self
.storage .storage
.acquire_key(self.domain.clone()) .acquire_key(self.selector.clone())
.await .await
.map_err(|e| KeyPoolError::Storage(Arc::new(e)))?; .map_err(|e| KeyPoolError::Storage(Arc::new(e)))?;
let url = request.url(key.value(), id.as_deref()); let url = request.url(key.value(), id.as_deref());
@ -66,7 +66,7 @@ where
{ {
let keys = match self let keys = match self
.storage .storage
.acquire_many_keys(self.domain.clone(), ids.len() as i64) .acquire_many_keys(self.selector.clone(), ids.len() as i64)
.await .await
{ {
Ok(keys) => keys, Ok(keys) => keys,
@ -114,7 +114,7 @@ where
Ok(res) => return (id, Ok(res)), Ok(res) => return (id, Ok(res)),
}; };
key = match self.storage.acquire_key(self.domain.clone()).await { key = match self.storage.acquire_key(self.selector.clone()).await {
Ok(k) => k, Ok(k) => k,
Err(why) => return (id, Err(Self::Error::Storage(Arc::new(why)))), Err(why) => return (id, Err(Self::Error::Storage(Arc::new(why)))),
}; };
@ -150,25 +150,36 @@ where
} }
} }
pub fn torn_api(&self, domain: S::Domain) -> ApiProvider<C, KeyPoolExecutor<C, S>> { pub fn torn_api<I>(&self, selector: I) -> ApiProvider<C, KeyPoolExecutor<C, S>>
where
I: IntoSelector<S::Key, S::Domain>,
{
ApiProvider::new( ApiProvider::new(
&self.client, &self.client,
KeyPoolExecutor::new(&self.storage, domain, self.comment.as_deref()), KeyPoolExecutor::new(
&self.storage,
selector.into_selector(),
self.comment.as_deref(),
),
) )
} }
} }
pub trait WithStorage { pub trait WithStorage {
fn with_storage<'a, S>( fn with_storage<'a, S, I>(
&'a self, &'a self,
storage: &'a S, storage: &'a S,
domain: S::Domain, selector: I,
) -> ApiProvider<Self, KeyPoolExecutor<Self, S>> ) -> ApiProvider<Self, KeyPoolExecutor<Self, S>>
where where
Self: ApiClient + Sized, Self: ApiClient + Sized,
S: KeyPoolStorage + Send + Sync + 'static, S: KeyPoolStorage + Send + Sync + 'static,
I: IntoSelector<S::Key, S::Domain>,
{ {
ApiProvider::new(self, KeyPoolExecutor::new(storage, domain, None)) ApiProvider::new(
self,
KeyPoolExecutor::new(storage, selector.into_selector(), None),
)
} }
} }