Merge remote-tracking branch 'origin/master'
This commit is contained in:
commit
4f3d62da95
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -1,3 +1,4 @@
|
||||||
/target
|
/target
|
||||||
/Cargo.lock
|
/Cargo.lock
|
||||||
.env
|
.env
|
||||||
|
.DS_Store
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
[package]
|
[package]
|
||||||
name = "torn-api"
|
name = "torn-api"
|
||||||
version = "0.5.13"
|
version = "0.5.19"
|
||||||
edition = "2021"
|
edition = "2021"
|
||||||
authors = ["Pyrit [2111649]"]
|
authors = ["Pyrit [2111649]"]
|
||||||
license = "MIT"
|
license = "MIT"
|
||||||
|
|
|
@ -56,3 +56,108 @@ pub struct Territory {
|
||||||
#[serde(deserialize_with = "de_util::string_or_decimal")]
|
#[serde(deserialize_with = "de_util::string_or_decimal")]
|
||||||
pub coordinate_y: rust_decimal::Decimal,
|
pub coordinate_y: rust_decimal::Decimal,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy, Deserialize)]
|
||||||
|
pub enum AttackResult {
|
||||||
|
Attacked,
|
||||||
|
Mugged,
|
||||||
|
Hospitalized,
|
||||||
|
Lost,
|
||||||
|
Arrested,
|
||||||
|
Escape,
|
||||||
|
Interrupted,
|
||||||
|
Assist,
|
||||||
|
Timeout,
|
||||||
|
Stalemate,
|
||||||
|
Special,
|
||||||
|
Looted,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize)]
|
||||||
|
pub struct Attack<'a> {
|
||||||
|
pub code: &'a str,
|
||||||
|
#[serde(with = "ts_seconds")]
|
||||||
|
pub timestamp_started: DateTime<Utc>,
|
||||||
|
#[serde(with = "ts_seconds")]
|
||||||
|
pub timestamp_ended: DateTime<Utc>,
|
||||||
|
|
||||||
|
#[serde(deserialize_with = "de_util::empty_string_int_option")]
|
||||||
|
pub attacker_id: Option<i32>,
|
||||||
|
#[serde(deserialize_with = "de_util::empty_string_int_option")]
|
||||||
|
pub attacker_faction: Option<i32>,
|
||||||
|
pub defender_id: i32,
|
||||||
|
#[serde(deserialize_with = "de_util::empty_string_int_option")]
|
||||||
|
pub defender_faction: Option<i32>,
|
||||||
|
pub result: AttackResult,
|
||||||
|
|
||||||
|
#[serde(deserialize_with = "de_util::int_is_bool")]
|
||||||
|
pub stealthed: bool,
|
||||||
|
|
||||||
|
#[cfg(feature = "decimal")]
|
||||||
|
pub respect: rust_decimal::Decimal,
|
||||||
|
|
||||||
|
#[cfg(not(feature = "decimal"))]
|
||||||
|
pub respect: f32,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize)]
|
||||||
|
pub struct RespectModifiers {
|
||||||
|
pub fair_fight: f32,
|
||||||
|
pub war: f32,
|
||||||
|
pub retaliation: f32,
|
||||||
|
pub group_attack: f32,
|
||||||
|
pub overseas: f32,
|
||||||
|
pub chain_bonus: f32,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize)]
|
||||||
|
pub struct AttackFull<'a> {
|
||||||
|
pub code: &'a str,
|
||||||
|
#[serde(with = "ts_seconds")]
|
||||||
|
pub timestamp_started: DateTime<Utc>,
|
||||||
|
#[serde(with = "ts_seconds")]
|
||||||
|
pub timestamp_ended: DateTime<Utc>,
|
||||||
|
|
||||||
|
#[serde(deserialize_with = "de_util::empty_string_int_option")]
|
||||||
|
pub attacker_id: Option<i32>,
|
||||||
|
#[serde(deserialize_with = "de_util::empty_string_is_none")]
|
||||||
|
pub attacker_name: Option<&'a str>,
|
||||||
|
#[serde(deserialize_with = "de_util::empty_string_int_option")]
|
||||||
|
pub attacker_faction: Option<i32>,
|
||||||
|
#[serde(
|
||||||
|
deserialize_with = "de_util::empty_string_is_none",
|
||||||
|
rename = "attacker_factionname"
|
||||||
|
)]
|
||||||
|
pub attacker_faction_name: Option<&'a str>,
|
||||||
|
|
||||||
|
pub defender_id: i32,
|
||||||
|
pub defender_name: &'a str,
|
||||||
|
#[serde(deserialize_with = "de_util::empty_string_int_option")]
|
||||||
|
pub defender_faction: Option<i32>,
|
||||||
|
#[serde(
|
||||||
|
deserialize_with = "de_util::empty_string_is_none",
|
||||||
|
rename = "defender_factionname"
|
||||||
|
)]
|
||||||
|
pub defender_faction_name: Option<&'a str>,
|
||||||
|
|
||||||
|
pub result: AttackResult,
|
||||||
|
|
||||||
|
#[serde(deserialize_with = "de_util::int_is_bool")]
|
||||||
|
pub stealthed: bool,
|
||||||
|
#[serde(deserialize_with = "de_util::int_is_bool")]
|
||||||
|
pub raid: bool,
|
||||||
|
#[serde(deserialize_with = "de_util::int_is_bool")]
|
||||||
|
pub ranked_war: bool,
|
||||||
|
|
||||||
|
#[cfg(feature = "decimal")]
|
||||||
|
pub respect: rust_decimal::Decimal,
|
||||||
|
#[cfg(feature = "decimal")]
|
||||||
|
pub respect_loss: rust_decimal::Decimal,
|
||||||
|
|
||||||
|
#[cfg(not(feature = "decimal"))]
|
||||||
|
pub respect: f32,
|
||||||
|
#[cfg(not(feature = "decimal"))]
|
||||||
|
pub respect_loss: f32,
|
||||||
|
|
||||||
|
pub modifiers: RespectModifiers,
|
||||||
|
}
|
||||||
|
|
|
@ -1,8 +1,8 @@
|
||||||
#![allow(unused)]
|
#![allow(unused)]
|
||||||
|
|
||||||
use std::collections::BTreeMap;
|
use std::collections::{BTreeMap, HashMap};
|
||||||
|
|
||||||
use chrono::{DateTime, NaiveDateTime, Utc};
|
use chrono::{serde::ts_nanoseconds::deserialize, DateTime, NaiveDateTime, Utc};
|
||||||
use serde::de::{Deserialize, Deserializer, Error, Unexpected, Visitor};
|
use serde::de::{Deserialize, Deserializer, Error, Unexpected, Visitor};
|
||||||
|
|
||||||
pub(crate) fn empty_string_is_none<'de, D>(deserializer: D) -> Result<Option<&'de str>, D::Error>
|
pub(crate) fn empty_string_is_none<'de, D>(deserializer: D) -> Result<Option<&'de str>, D::Error>
|
||||||
|
@ -135,6 +135,62 @@ where
|
||||||
deserializer.deserialize_map(MapVisitor)
|
deserializer.deserialize_map(MapVisitor)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(crate) fn empty_dict_is_empty_array<'de, D, T>(deserializer: D) -> Result<Vec<T>, D::Error>
|
||||||
|
where
|
||||||
|
D: Deserializer<'de>,
|
||||||
|
T: Deserialize<'de>,
|
||||||
|
{
|
||||||
|
struct ArrayVisitor<T>(std::marker::PhantomData<T>);
|
||||||
|
|
||||||
|
impl<'de, T> Visitor<'de> for ArrayVisitor<T>
|
||||||
|
where
|
||||||
|
T: Deserialize<'de>,
|
||||||
|
{
|
||||||
|
type Value = Vec<T>;
|
||||||
|
|
||||||
|
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||||
|
write!(formatter, "vec or empty object")
|
||||||
|
}
|
||||||
|
|
||||||
|
fn visit_map<A>(self, map: A) -> Result<Self::Value, A::Error>
|
||||||
|
where
|
||||||
|
A: serde::de::MapAccess<'de>,
|
||||||
|
{
|
||||||
|
match map.size_hint() {
|
||||||
|
Some(0) | None => Ok(Vec::default()),
|
||||||
|
Some(len) => Err(A::Error::invalid_length(len, &"empty dict")),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
|
||||||
|
where
|
||||||
|
A: serde::de::SeqAccess<'de>,
|
||||||
|
{
|
||||||
|
let mut result = match seq.size_hint() {
|
||||||
|
Some(len) => Vec::with_capacity(len),
|
||||||
|
None => Vec::default(),
|
||||||
|
};
|
||||||
|
|
||||||
|
while let Some(element) = seq.next_element()? {
|
||||||
|
result.push(element);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
deserializer.deserialize_any(ArrayVisitor(std::marker::PhantomData))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn null_is_empty_dict<'de, D, K, V>(deserializer: D) -> Result<HashMap<K, V>, D::Error>
|
||||||
|
where
|
||||||
|
D: Deserializer<'de>,
|
||||||
|
K: std::hash::Hash + std::cmp::Eq + Deserialize<'de>,
|
||||||
|
V: Deserialize<'de>,
|
||||||
|
{
|
||||||
|
Ok(Option::deserialize(deserializer)?.unwrap_or_default())
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg(feature = "decimal")]
|
#[cfg(feature = "decimal")]
|
||||||
pub(crate) fn string_or_decimal<'de, D>(deserializer: D) -> Result<rust_decimal::Decimal, D::Error>
|
pub(crate) fn string_or_decimal<'de, D>(deserializer: D) -> Result<rust_decimal::Decimal, D::Error>
|
||||||
where
|
where
|
||||||
|
|
|
@ -1,13 +1,13 @@
|
||||||
use std::collections::{BTreeMap, HashMap};
|
use std::collections::{BTreeMap, HashMap};
|
||||||
|
|
||||||
use chrono::{serde::ts_seconds, DateTime, Utc};
|
use chrono::{DateTime, Utc};
|
||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
|
|
||||||
use torn_api_macros::ApiCategory;
|
use torn_api_macros::ApiCategory;
|
||||||
|
|
||||||
use crate::de_util;
|
use crate::de_util::{self, null_is_empty_dict};
|
||||||
|
|
||||||
pub use crate::common::{LastAction, Status, Territory};
|
pub use crate::common::{Attack, AttackFull, LastAction, Status, Territory};
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy, ApiCategory)]
|
#[derive(Debug, Clone, Copy, ApiCategory)]
|
||||||
#[api(category = "faction")]
|
#[api(category = "faction")]
|
||||||
|
@ -21,7 +21,11 @@ pub enum Selection {
|
||||||
#[api(type = "BTreeMap<i32, AttackFull>", field = "attacks")]
|
#[api(type = "BTreeMap<i32, AttackFull>", field = "attacks")]
|
||||||
Attacks,
|
Attacks,
|
||||||
|
|
||||||
#[api(type = "HashMap<String, Territory>", field = "territory")]
|
#[api(
|
||||||
|
type = "HashMap<String, Territory>",
|
||||||
|
field = "territory",
|
||||||
|
with = "null_is_empty_dict"
|
||||||
|
)]
|
||||||
Territory,
|
Territory,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -68,115 +72,10 @@ pub struct Basic<'a> {
|
||||||
#[serde(deserialize_with = "de_util::datetime_map")]
|
#[serde(deserialize_with = "de_util::datetime_map")]
|
||||||
pub peace: BTreeMap<i32, DateTime<Utc>>,
|
pub peace: BTreeMap<i32, DateTime<Utc>>,
|
||||||
|
|
||||||
#[serde(borrow)]
|
#[serde(borrow, deserialize_with = "de_util::empty_dict_is_empty_array")]
|
||||||
pub territory_wars: Vec<FactionTerritoryWar<'a>>,
|
pub territory_wars: Vec<FactionTerritoryWar<'a>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy, Deserialize)]
|
|
||||||
pub enum AttackResult {
|
|
||||||
Attacked,
|
|
||||||
Mugged,
|
|
||||||
Hospitalized,
|
|
||||||
Lost,
|
|
||||||
Arrested,
|
|
||||||
Escape,
|
|
||||||
Interrupted,
|
|
||||||
Assist,
|
|
||||||
Timeout,
|
|
||||||
Stalemate,
|
|
||||||
Special,
|
|
||||||
Looted,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize)]
|
|
||||||
pub struct Attack<'a> {
|
|
||||||
pub code: &'a str,
|
|
||||||
#[serde(with = "ts_seconds")]
|
|
||||||
pub timestamp_started: DateTime<Utc>,
|
|
||||||
#[serde(with = "ts_seconds")]
|
|
||||||
pub timestamp_ended: DateTime<Utc>,
|
|
||||||
|
|
||||||
#[serde(deserialize_with = "de_util::empty_string_int_option")]
|
|
||||||
pub attacker_id: Option<i32>,
|
|
||||||
#[serde(deserialize_with = "de_util::empty_string_int_option")]
|
|
||||||
pub attacker_faction: Option<i32>,
|
|
||||||
pub defender_id: i32,
|
|
||||||
#[serde(deserialize_with = "de_util::empty_string_int_option")]
|
|
||||||
pub defender_faction: Option<i32>,
|
|
||||||
pub result: AttackResult,
|
|
||||||
|
|
||||||
#[serde(deserialize_with = "de_util::int_is_bool")]
|
|
||||||
pub stealthed: bool,
|
|
||||||
|
|
||||||
#[cfg(feature = "decimal")]
|
|
||||||
pub respect: rust_decimal::Decimal,
|
|
||||||
|
|
||||||
#[cfg(not(feature = "decimal"))]
|
|
||||||
pub respect: f32,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize)]
|
|
||||||
pub struct RespectModifiers {
|
|
||||||
pub fair_fight: f32,
|
|
||||||
pub war: f32,
|
|
||||||
pub retaliation: f32,
|
|
||||||
pub group_attack: f32,
|
|
||||||
pub overseas: f32,
|
|
||||||
pub chain_bonus: f32,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize)]
|
|
||||||
pub struct AttackFull<'a> {
|
|
||||||
pub code: &'a str,
|
|
||||||
#[serde(with = "ts_seconds")]
|
|
||||||
pub timestamp_started: DateTime<Utc>,
|
|
||||||
#[serde(with = "ts_seconds")]
|
|
||||||
pub timestamp_ended: DateTime<Utc>,
|
|
||||||
|
|
||||||
#[serde(deserialize_with = "de_util::empty_string_int_option")]
|
|
||||||
pub attacker_id: Option<i32>,
|
|
||||||
#[serde(deserialize_with = "de_util::empty_string_is_none")]
|
|
||||||
pub attacker_name: Option<&'a str>,
|
|
||||||
#[serde(deserialize_with = "de_util::empty_string_int_option")]
|
|
||||||
pub attacker_faction: Option<i32>,
|
|
||||||
#[serde(
|
|
||||||
deserialize_with = "de_util::empty_string_is_none",
|
|
||||||
rename = "attacker_factionname"
|
|
||||||
)]
|
|
||||||
pub attacker_faction_name: Option<&'a str>,
|
|
||||||
|
|
||||||
pub defender_id: i32,
|
|
||||||
pub defender_name: &'a str,
|
|
||||||
#[serde(deserialize_with = "de_util::empty_string_int_option")]
|
|
||||||
pub defender_faction: Option<i32>,
|
|
||||||
#[serde(
|
|
||||||
deserialize_with = "de_util::empty_string_is_none",
|
|
||||||
rename = "defender_factionname"
|
|
||||||
)]
|
|
||||||
pub defender_faction_name: Option<&'a str>,
|
|
||||||
|
|
||||||
pub result: AttackResult,
|
|
||||||
|
|
||||||
#[serde(deserialize_with = "de_util::int_is_bool")]
|
|
||||||
pub stealthed: bool,
|
|
||||||
#[serde(deserialize_with = "de_util::int_is_bool")]
|
|
||||||
pub raid: bool,
|
|
||||||
#[serde(deserialize_with = "de_util::int_is_bool")]
|
|
||||||
pub ranked_war: bool,
|
|
||||||
|
|
||||||
#[cfg(feature = "decimal")]
|
|
||||||
pub respect: rust_decimal::Decimal,
|
|
||||||
#[cfg(feature = "decimal")]
|
|
||||||
pub respect_loss: rust_decimal::Decimal,
|
|
||||||
|
|
||||||
#[cfg(not(feature = "decimal"))]
|
|
||||||
pub respect: f32,
|
|
||||||
#[cfg(not(feature = "decimal"))]
|
|
||||||
pub respect_loss: f32,
|
|
||||||
|
|
||||||
pub modifiers: RespectModifiers,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
@ -199,4 +98,21 @@ mod tests {
|
||||||
response.attacks_full().unwrap();
|
response.attacks_full().unwrap();
|
||||||
response.territory().unwrap();
|
response.territory().unwrap();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[async_test]
|
||||||
|
async fn destroyed_faction() {
|
||||||
|
let key = setup();
|
||||||
|
|
||||||
|
let response = Client::default()
|
||||||
|
.torn_api(key)
|
||||||
|
.faction(|b| {
|
||||||
|
b.id(8981)
|
||||||
|
.selections(&[Selection::Basic, Selection::Territory])
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
response.basic().unwrap();
|
||||||
|
response.territory().unwrap();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -122,6 +122,8 @@ pub enum FactionSelection {
|
||||||
Upgrades,
|
Upgrades,
|
||||||
Weapons,
|
Weapons,
|
||||||
Lookup,
|
Lookup,
|
||||||
|
Caches,
|
||||||
|
CrimeExp,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
|
||||||
|
@ -171,6 +173,8 @@ pub enum TornSelection {
|
||||||
Timestamp,
|
Timestamp,
|
||||||
Lookup,
|
Lookup,
|
||||||
CityShops,
|
CityShops,
|
||||||
|
ItemDetails,
|
||||||
|
TerritoryNames,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
|
||||||
|
|
|
@ -22,7 +22,7 @@ pub mod awc;
|
||||||
pub mod reqwest;
|
pub mod reqwest;
|
||||||
|
|
||||||
#[cfg(feature = "__common")]
|
#[cfg(feature = "__common")]
|
||||||
mod common;
|
pub mod common;
|
||||||
|
|
||||||
mod de_util;
|
mod de_util;
|
||||||
|
|
||||||
|
|
|
@ -22,6 +22,12 @@ pub enum Selection {
|
||||||
|
|
||||||
#[api(type = "HashMap<String, TerritoryWar>", field = "territorywars")]
|
#[api(type = "HashMap<String, TerritoryWar>", field = "territorywars")]
|
||||||
TerritoryWars,
|
TerritoryWars,
|
||||||
|
|
||||||
|
#[api(type = "HashMap<String, Racket>", field = "rackets")]
|
||||||
|
Rackets,
|
||||||
|
|
||||||
|
#[api(type = "HashMap<String, Territory>", field = "territory")]
|
||||||
|
Territory,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
#[derive(Deserialize)]
|
||||||
|
@ -111,6 +117,31 @@ pub struct TerritoryWar {
|
||||||
pub ends: DateTime<Utc>,
|
pub ends: DateTime<Utc>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize)]
|
||||||
|
pub struct Racket {
|
||||||
|
pub name: String,
|
||||||
|
pub level: i16,
|
||||||
|
pub reward: String,
|
||||||
|
|
||||||
|
#[serde(with = "chrono::serde::ts_seconds")]
|
||||||
|
pub created: DateTime<Utc>,
|
||||||
|
#[serde(with = "chrono::serde::ts_seconds")]
|
||||||
|
pub changed: DateTime<Utc>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize)]
|
||||||
|
pub struct Territory {
|
||||||
|
pub sector: i16,
|
||||||
|
pub size: i16,
|
||||||
|
pub slots: i16,
|
||||||
|
pub daily_respect: i16,
|
||||||
|
pub faction: i32,
|
||||||
|
|
||||||
|
pub neighbors: Vec<String>,
|
||||||
|
pub war: Option<TerritoryWar>,
|
||||||
|
pub racket: Option<Racket>,
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
@ -122,11 +153,32 @@ mod tests {
|
||||||
|
|
||||||
let response = Client::default()
|
let response = Client::default()
|
||||||
.torn_api(key)
|
.torn_api(key)
|
||||||
.torn(|b| b.selections(&[Selection::Competition, Selection::TerritoryWars]))
|
.torn(|b| {
|
||||||
|
b.selections(&[
|
||||||
|
Selection::Competition,
|
||||||
|
Selection::TerritoryWars,
|
||||||
|
Selection::Rackets,
|
||||||
|
])
|
||||||
|
})
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
response.competition().unwrap();
|
response.competition().unwrap();
|
||||||
response.territory_wars().unwrap();
|
response.territory_wars().unwrap();
|
||||||
|
response.rackets().unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_test]
|
||||||
|
async fn territory() {
|
||||||
|
let key = setup();
|
||||||
|
|
||||||
|
let response = Client::default()
|
||||||
|
.torn_api(key)
|
||||||
|
.torn(|b| b.selections(&[Selection::Territory]).id("NSC"))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let territory = response.territory().unwrap();
|
||||||
|
assert!(territory.contains_key("NSC"));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,12 +2,13 @@ use serde::{
|
||||||
de::{self, MapAccess, Visitor},
|
de::{self, MapAccess, Visitor},
|
||||||
Deserialize, Deserializer,
|
Deserialize, Deserializer,
|
||||||
};
|
};
|
||||||
|
use std::collections::BTreeMap;
|
||||||
|
|
||||||
use torn_api_macros::ApiCategory;
|
use torn_api_macros::ApiCategory;
|
||||||
|
|
||||||
use crate::de_util;
|
use crate::de_util;
|
||||||
|
|
||||||
pub use crate::common::{LastAction, Status};
|
pub use crate::common::{Attack, AttackFull, LastAction, Status};
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy, ApiCategory)]
|
#[derive(Debug, Clone, Copy, ApiCategory)]
|
||||||
#[api(category = "user")]
|
#[api(category = "user")]
|
||||||
|
@ -22,6 +23,10 @@ pub enum Selection {
|
||||||
PersonalStats,
|
PersonalStats,
|
||||||
#[api(type = "CriminalRecord", field = "criminalrecord")]
|
#[api(type = "CriminalRecord", field = "criminalrecord")]
|
||||||
Crimes,
|
Crimes,
|
||||||
|
#[api(type = "BTreeMap<i32, Attack>", field = "attacks")]
|
||||||
|
AttacksFull,
|
||||||
|
#[api(type = "BTreeMap<i32, AttackFull>", field = "attacks")]
|
||||||
|
Attacks,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize)]
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize)]
|
||||||
|
@ -168,10 +173,14 @@ pub enum EliminationTeam {
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub enum Competition {
|
pub enum Competition {
|
||||||
Elimination {
|
Elimination {
|
||||||
score: i16,
|
score: i32,
|
||||||
attacks: i16,
|
attacks: i16,
|
||||||
team: EliminationTeam,
|
team: EliminationTeam,
|
||||||
},
|
},
|
||||||
|
DogTags {
|
||||||
|
score: i32,
|
||||||
|
position: Option<i32>,
|
||||||
|
},
|
||||||
Unknown,
|
Unknown,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -187,6 +196,7 @@ where
|
||||||
Team,
|
Team,
|
||||||
Attacks,
|
Attacks,
|
||||||
TeamName,
|
TeamName,
|
||||||
|
Position,
|
||||||
#[serde(other)]
|
#[serde(other)]
|
||||||
Ignore,
|
Ignore,
|
||||||
}
|
}
|
||||||
|
@ -194,6 +204,8 @@ where
|
||||||
#[derive(Deserialize)]
|
#[derive(Deserialize)]
|
||||||
enum CompetitionName {
|
enum CompetitionName {
|
||||||
Elimination,
|
Elimination,
|
||||||
|
#[serde(rename = "Dog Tags")]
|
||||||
|
DogTags,
|
||||||
#[serde(other)]
|
#[serde(other)]
|
||||||
Unknown,
|
Unknown,
|
||||||
}
|
}
|
||||||
|
@ -229,6 +241,7 @@ where
|
||||||
let mut score = None;
|
let mut score = None;
|
||||||
let mut attacks = None;
|
let mut attacks = None;
|
||||||
let mut name = None;
|
let mut name = None;
|
||||||
|
let mut position = None;
|
||||||
|
|
||||||
while let Some(key) = map.next_key()? {
|
while let Some(key) = map.next_key()? {
|
||||||
match key {
|
match key {
|
||||||
|
@ -241,6 +254,9 @@ where
|
||||||
Field::Attacks => {
|
Field::Attacks => {
|
||||||
attacks = Some(map.next_value()?);
|
attacks = Some(map.next_value()?);
|
||||||
}
|
}
|
||||||
|
Field::Position => {
|
||||||
|
position = Some(map.next_value()?);
|
||||||
|
}
|
||||||
Field::Team => {
|
Field::Team => {
|
||||||
let team_raw: &str = map.next_value()?;
|
let team_raw: &str = map.next_value()?;
|
||||||
team = if team_raw.is_empty() {
|
team = if team_raw.is_empty() {
|
||||||
|
@ -299,6 +315,12 @@ where
|
||||||
Ok(None)
|
Ok(None)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
CompetitionName::DogTags => {
|
||||||
|
let score = score.ok_or_else(|| de::Error::missing_field("score"))?;
|
||||||
|
let position = position.ok_or_else(|| de::Error::missing_field("position"))?;
|
||||||
|
|
||||||
|
Ok(Some(Competition::DogTags { score, position }))
|
||||||
|
}
|
||||||
CompetitionName::Unknown => Ok(Some(Competition::Unknown)),
|
CompetitionName::Unknown => Ok(Some(Competition::Unknown)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -416,6 +438,7 @@ mod tests {
|
||||||
Selection::Profile,
|
Selection::Profile,
|
||||||
Selection::PersonalStats,
|
Selection::PersonalStats,
|
||||||
Selection::Crimes,
|
Selection::Crimes,
|
||||||
|
Selection::Attacks,
|
||||||
])
|
])
|
||||||
})
|
})
|
||||||
.await
|
.await
|
||||||
|
@ -426,6 +449,8 @@ mod tests {
|
||||||
response.profile().unwrap();
|
response.profile().unwrap();
|
||||||
response.personal_stats().unwrap();
|
response.personal_stats().unwrap();
|
||||||
response.crimes().unwrap();
|
response.crimes().unwrap();
|
||||||
|
response.attacks().unwrap();
|
||||||
|
response.attacks_full().unwrap();
|
||||||
}
|
}
|
||||||
|
|
||||||
#[async_test]
|
#[async_test]
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
[package]
|
[package]
|
||||||
name = "torn-key-pool"
|
name = "torn-key-pool"
|
||||||
version = "0.5.7"
|
version = "0.6.1"
|
||||||
edition = "2021"
|
edition = "2021"
|
||||||
authors = ["Pyrit [2111649]"]
|
authors = ["Pyrit [2111649]"]
|
||||||
license = "MIT"
|
license = "MIT"
|
||||||
|
|
|
@ -29,8 +29,8 @@ where
|
||||||
Response(ResponseError),
|
Response(ResponseError),
|
||||||
}
|
}
|
||||||
|
|
||||||
pub trait ApiKey: Sync + Send {
|
pub trait ApiKey: Sync + Send + std::fmt::Debug + Clone {
|
||||||
type IdType: PartialEq + Eq + std::hash::Hash;
|
type IdType: PartialEq + Eq + std::hash::Hash + Send + Sync + std::fmt::Debug + Clone;
|
||||||
|
|
||||||
fn value(&self) -> &str;
|
fn value(&self) -> &str;
|
||||||
|
|
||||||
|
@ -44,12 +44,65 @@ pub trait KeyDomain: Clone + std::fmt::Debug + Send + Sync {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub enum KeySelector<K>
|
pub enum KeySelector<K, D>
|
||||||
where
|
where
|
||||||
K: ApiKey,
|
K: ApiKey,
|
||||||
|
D: KeyDomain,
|
||||||
{
|
{
|
||||||
Key(String),
|
Key(String),
|
||||||
Id(K::IdType),
|
Id(K::IdType),
|
||||||
|
UserId(i32),
|
||||||
|
Has(D),
|
||||||
|
OneOf(Vec<D>),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<K, D> KeySelector<K, D>
|
||||||
|
where
|
||||||
|
K: ApiKey,
|
||||||
|
D: KeyDomain,
|
||||||
|
{
|
||||||
|
pub(crate) fn fallback(&self) -> Option<Self> {
|
||||||
|
match self {
|
||||||
|
Self::Key(_) | Self::UserId(_) | Self::Id(_) => None,
|
||||||
|
Self::Has(domain) => domain.fallback().map(Self::Has),
|
||||||
|
Self::OneOf(domains) => {
|
||||||
|
let fallbacks: Vec<_> = domains.iter().filter_map(|d| d.fallback()).collect();
|
||||||
|
if fallbacks.is_empty() {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
Some(Self::OneOf(fallbacks))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub trait IntoSelector<K, D>: Send + Sync
|
||||||
|
where
|
||||||
|
K: ApiKey,
|
||||||
|
D: KeyDomain,
|
||||||
|
{
|
||||||
|
fn into_selector(self) -> KeySelector<K, D>;
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<K, D> IntoSelector<K, D> for D
|
||||||
|
where
|
||||||
|
K: ApiKey,
|
||||||
|
D: KeyDomain,
|
||||||
|
{
|
||||||
|
fn into_selector(self) -> KeySelector<K, D> {
|
||||||
|
KeySelector::Has(self)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<K, D> IntoSelector<K, D> for KeySelector<K, D>
|
||||||
|
where
|
||||||
|
K: ApiKey,
|
||||||
|
D: KeyDomain,
|
||||||
|
{
|
||||||
|
fn into_selector(self) -> KeySelector<K, D> {
|
||||||
|
self
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
|
@ -58,13 +111,17 @@ pub trait KeyPoolStorage {
|
||||||
type Domain: KeyDomain;
|
type Domain: KeyDomain;
|
||||||
type Error: std::error::Error + Sync + Send;
|
type Error: std::error::Error + Sync + Send;
|
||||||
|
|
||||||
async fn acquire_key(&self, domain: Self::Domain) -> Result<Self::Key, Self::Error>;
|
async fn acquire_key<S>(&self, selector: S) -> Result<Self::Key, Self::Error>
|
||||||
|
where
|
||||||
|
S: IntoSelector<Self::Key, Self::Domain>;
|
||||||
|
|
||||||
async fn acquire_many_keys(
|
async fn acquire_many_keys<S>(
|
||||||
&self,
|
&self,
|
||||||
domain: Self::Domain,
|
selector: S,
|
||||||
number: i64,
|
number: i64,
|
||||||
) -> Result<Vec<Self::Key>, Self::Error>;
|
) -> Result<Vec<Self::Key>, Self::Error>
|
||||||
|
where
|
||||||
|
S: IntoSelector<Self::Key, Self::Domain>;
|
||||||
|
|
||||||
async fn flag_key(&self, key: Self::Key, code: u8) -> Result<bool, Self::Error>;
|
async fn flag_key(&self, key: Self::Key, code: u8) -> Result<bool, Self::Error>;
|
||||||
|
|
||||||
|
@ -75,34 +132,41 @@ pub trait KeyPoolStorage {
|
||||||
domains: Vec<Self::Domain>,
|
domains: Vec<Self::Domain>,
|
||||||
) -> Result<Self::Key, Self::Error>;
|
) -> Result<Self::Key, Self::Error>;
|
||||||
|
|
||||||
async fn read_key(&self, key: KeySelector<Self::Key>)
|
async fn read_key<S>(&self, selector: S) -> Result<Option<Self::Key>, Self::Error>
|
||||||
-> Result<Option<Self::Key>, Self::Error>;
|
where
|
||||||
|
S: IntoSelector<Self::Key, Self::Domain>;
|
||||||
|
|
||||||
async fn read_user_keys(&self, user_id: i32) -> Result<Vec<Self::Key>, Self::Error>;
|
async fn read_keys<S>(&self, selector: S) -> Result<Vec<Self::Key>, Self::Error>
|
||||||
|
where
|
||||||
|
S: IntoSelector<Self::Key, Self::Domain>;
|
||||||
|
|
||||||
async fn remove_key(&self, key: KeySelector<Self::Key>) -> Result<Self::Key, Self::Error>;
|
async fn remove_key<S>(&self, selector: S) -> Result<Self::Key, Self::Error>
|
||||||
|
where
|
||||||
|
S: IntoSelector<Self::Key, Self::Domain>;
|
||||||
|
|
||||||
async fn query_key(&self, domain: Self::Domain) -> Result<Option<Self::Key>, Self::Error>;
|
async fn add_domain_to_key<S>(
|
||||||
|
|
||||||
async fn query_all(&self, domain: Self::Domain) -> Result<Vec<Self::Key>, Self::Error>;
|
|
||||||
|
|
||||||
async fn add_domain_to_key(
|
|
||||||
&self,
|
&self,
|
||||||
key: KeySelector<Self::Key>,
|
selector: S,
|
||||||
domain: Self::Domain,
|
domain: Self::Domain,
|
||||||
) -> Result<Self::Key, Self::Error>;
|
) -> Result<Self::Key, Self::Error>
|
||||||
|
where
|
||||||
|
S: IntoSelector<Self::Key, Self::Domain>;
|
||||||
|
|
||||||
async fn remove_domain_from_key(
|
async fn remove_domain_from_key<S>(
|
||||||
&self,
|
&self,
|
||||||
key: KeySelector<Self::Key>,
|
selector: S,
|
||||||
domain: Self::Domain,
|
domain: Self::Domain,
|
||||||
) -> Result<Self::Key, Self::Error>;
|
) -> Result<Self::Key, Self::Error>
|
||||||
|
where
|
||||||
|
S: IntoSelector<Self::Key, Self::Domain>;
|
||||||
|
|
||||||
async fn set_domains_for_key(
|
async fn set_domains_for_key<S>(
|
||||||
&self,
|
&self,
|
||||||
key: KeySelector<Self::Key>,
|
selector: S,
|
||||||
domains: Vec<Self::Domain>,
|
domains: Vec<Self::Domain>,
|
||||||
) -> Result<Self::Key, Self::Error>;
|
) -> Result<Self::Key, Self::Error>
|
||||||
|
where
|
||||||
|
S: IntoSelector<Self::Key, Self::Domain>;
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
|
@ -112,7 +176,7 @@ where
|
||||||
{
|
{
|
||||||
storage: &'a S,
|
storage: &'a S,
|
||||||
comment: Option<&'a str>,
|
comment: Option<&'a str>,
|
||||||
domain: S::Domain,
|
selector: KeySelector<S::Key, S::Domain>,
|
||||||
_marker: std::marker::PhantomData<C>,
|
_marker: std::marker::PhantomData<C>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -120,10 +184,14 @@ impl<'a, C, S> KeyPoolExecutor<'a, C, S>
|
||||||
where
|
where
|
||||||
S: KeyPoolStorage,
|
S: KeyPoolStorage,
|
||||||
{
|
{
|
||||||
pub fn new(storage: &'a S, domain: S::Domain, comment: Option<&'a str>) -> Self {
|
pub fn new(
|
||||||
|
storage: &'a S,
|
||||||
|
selector: KeySelector<S::Key, S::Domain>,
|
||||||
|
comment: Option<&'a str>,
|
||||||
|
) -> Self {
|
||||||
Self {
|
Self {
|
||||||
storage,
|
storage,
|
||||||
domain,
|
selector,
|
||||||
comment,
|
comment,
|
||||||
_marker: std::marker::PhantomData,
|
_marker: std::marker::PhantomData,
|
||||||
}
|
}
|
||||||
|
|
|
@ -7,7 +7,7 @@ use torn_api::{
|
||||||
ApiRequest, ApiResponse, ApiSelection, ResponseError,
|
ApiRequest, ApiResponse, ApiSelection, ResponseError,
|
||||||
};
|
};
|
||||||
|
|
||||||
use crate::{ApiKey, KeyPoolError, KeyPoolExecutor, KeyPoolStorage};
|
use crate::{ApiKey, KeyPoolError, KeyPoolExecutor, KeyPoolStorage, IntoSelector};
|
||||||
|
|
||||||
#[async_trait(?Send)]
|
#[async_trait(?Send)]
|
||||||
impl<'client, C, S> RequestExecutor<C> for KeyPoolExecutor<'client, C, S>
|
impl<'client, C, S> RequestExecutor<C> for KeyPoolExecutor<'client, C, S>
|
||||||
|
@ -30,7 +30,7 @@ where
|
||||||
loop {
|
loop {
|
||||||
let key = self
|
let key = self
|
||||||
.storage
|
.storage
|
||||||
.acquire_key(self.domain.clone())
|
.acquire_key(self.selector.clone())
|
||||||
.await
|
.await
|
||||||
.map_err(|e| KeyPoolError::Storage(Arc::new(e)))?;
|
.map_err(|e| KeyPoolError::Storage(Arc::new(e)))?;
|
||||||
let url = request.url(key.value(), id.as_deref());
|
let url = request.url(key.value(), id.as_deref());
|
||||||
|
@ -66,7 +66,7 @@ where
|
||||||
{
|
{
|
||||||
let keys = match self
|
let keys = match self
|
||||||
.storage
|
.storage
|
||||||
.acquire_many_keys(self.domain.clone(), ids.len() as i64)
|
.acquire_many_keys(self.selector.clone(), ids.len() as i64)
|
||||||
.await
|
.await
|
||||||
{
|
{
|
||||||
Ok(keys) => keys,
|
Ok(keys) => keys,
|
||||||
|
@ -114,7 +114,7 @@ where
|
||||||
Ok(res) => return (id, Ok(res)),
|
Ok(res) => return (id, Ok(res)),
|
||||||
};
|
};
|
||||||
|
|
||||||
key = match self.storage.acquire_key(self.domain.clone()).await {
|
key = match self.storage.acquire_key(self.selector.clone()).await {
|
||||||
Ok(k) => k,
|
Ok(k) => k,
|
||||||
Err(why) => return (id, Err(Self::Error::Storage(Arc::new(why)))),
|
Err(why) => return (id, Err(Self::Error::Storage(Arc::new(why)))),
|
||||||
};
|
};
|
||||||
|
@ -150,25 +150,26 @@ where
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn torn_api(&self, domain: S::Domain) -> ApiProvider<C, KeyPoolExecutor<C, S>> {
|
pub fn torn_api<I>(&self, selector: I) -> ApiProvider<C, KeyPoolExecutor<C, S>> where I: IntoSelector<S::Key, S::Domain> {
|
||||||
ApiProvider::new(
|
ApiProvider::new(
|
||||||
&self.client,
|
&self.client,
|
||||||
KeyPoolExecutor::new(&self.storage, domain, self.comment.as_deref()),
|
KeyPoolExecutor::new(&self.storage, selector.into_selector(), self.comment.as_deref()),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub trait WithStorage {
|
pub trait WithStorage {
|
||||||
fn with_storage<'a, S>(
|
fn with_storage<'a, S, I>(
|
||||||
&'a self,
|
&'a self,
|
||||||
storage: &'a S,
|
storage: &'a S,
|
||||||
domain: S::Domain,
|
selector: I
|
||||||
) -> ApiProvider<Self, KeyPoolExecutor<Self, S>>
|
) -> ApiProvider<Self, KeyPoolExecutor<Self, S>>
|
||||||
where
|
where
|
||||||
Self: ApiClient + Sized,
|
Self: ApiClient + Sized,
|
||||||
S: KeyPoolStorage + 'static,
|
S: KeyPoolStorage + 'static,
|
||||||
|
I: IntoSelector<S::Key, S::Domain>
|
||||||
{
|
{
|
||||||
ApiProvider::new(self, KeyPoolExecutor::new(storage, domain, None))
|
ApiProvider::new(self, KeyPoolExecutor::new(storage, selector.into_selector(), None))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,9 +1,9 @@
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use indoc::indoc;
|
use indoc::indoc;
|
||||||
use sqlx::{FromRow, PgPool};
|
use sqlx::{FromRow, PgPool, Postgres, QueryBuilder};
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
|
|
||||||
use crate::{ApiKey, KeyDomain, KeyPoolStorage, KeySelector};
|
use crate::{ApiKey, IntoSelector, KeyDomain, KeyPoolStorage, KeySelector};
|
||||||
|
|
||||||
pub trait PgKeyDomain:
|
pub trait PgKeyDomain:
|
||||||
KeyDomain + serde::Serialize + serde::de::DeserializeOwned + Eq + Unpin
|
KeyDomain + serde::Serialize + serde::de::DeserializeOwned + Eq + Unpin
|
||||||
|
@ -24,10 +24,10 @@ where
|
||||||
Pg(#[from] sqlx::Error),
|
Pg(#[from] sqlx::Error),
|
||||||
|
|
||||||
#[error("No key avalaible for domain {0:?}")]
|
#[error("No key avalaible for domain {0:?}")]
|
||||||
Unavailable(D),
|
Unavailable(KeySelector<PgKey<D>, D>),
|
||||||
|
|
||||||
#[error("Key not found: '{0:?}'")]
|
#[error("Key not found: '{0:?}'")]
|
||||||
KeyNotFound(KeySelector<PgKey<D>>),
|
KeyNotFound(KeySelector<PgKey<D>, D>),
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, FromRow)]
|
#[derive(Debug, Clone, FromRow)]
|
||||||
|
@ -42,6 +42,41 @@ where
|
||||||
pub domains: sqlx::types::Json<Vec<D>>,
|
pub domains: sqlx::types::Json<Vec<D>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
fn build_predicate<'b, D>(
|
||||||
|
builder: &mut QueryBuilder<'b, Postgres>,
|
||||||
|
selector: &'b KeySelector<PgKey<D>, D>,
|
||||||
|
) where
|
||||||
|
D: PgKeyDomain,
|
||||||
|
{
|
||||||
|
match selector {
|
||||||
|
KeySelector::Id(id) => builder.push("id=").push_bind(id),
|
||||||
|
KeySelector::UserId(user_id) => builder.push("user_id=").push_bind(user_id),
|
||||||
|
KeySelector::Key(key) => builder.push("key=").push_bind(key),
|
||||||
|
KeySelector::Has(domain) => builder
|
||||||
|
.push("domains @> ")
|
||||||
|
.push_bind(sqlx::types::Json(vec![domain])),
|
||||||
|
KeySelector::OneOf(domains) => {
|
||||||
|
if domains.is_empty() {
|
||||||
|
builder.push("false");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (idx, domain) in domains.iter().enumerate() {
|
||||||
|
if idx == 0 {
|
||||||
|
builder.push("(");
|
||||||
|
} else {
|
||||||
|
builder.push(" or ");
|
||||||
|
}
|
||||||
|
builder
|
||||||
|
.push("domains @> ")
|
||||||
|
.push_bind(sqlx::types::Json(vec![domain]));
|
||||||
|
}
|
||||||
|
builder.push(")")
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, FromRow)]
|
#[derive(Debug, Clone, FromRow)]
|
||||||
pub struct PgKeyPoolStorage<D>
|
pub struct PgKeyPoolStorage<D>
|
||||||
where
|
where
|
||||||
|
@ -160,7 +195,11 @@ where
|
||||||
|
|
||||||
type Error = PgStorageError<D>;
|
type Error = PgStorageError<D>;
|
||||||
|
|
||||||
async fn acquire_key(&self, domain: D) -> Result<Self::Key, Self::Error> {
|
async fn acquire_key<S>(&self, selector: S) -> Result<Self::Key, Self::Error>
|
||||||
|
where
|
||||||
|
S: IntoSelector<Self::Key, Self::Domain>,
|
||||||
|
{
|
||||||
|
let selector = selector.into_selector();
|
||||||
loop {
|
loop {
|
||||||
let attempt = async {
|
let attempt = async {
|
||||||
let mut tx = self.pool.begin().await?;
|
let mut tx = self.pool.begin().await?;
|
||||||
|
@ -169,22 +208,33 @@ where
|
||||||
.execute(&mut tx)
|
.execute(&mut tx)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
// TODO: improve query
|
let mut qb = QueryBuilder::new(indoc::indoc! {
|
||||||
let key = sqlx::query_as(&indoc::formatdoc!(
|
|
||||||
r#"
|
r#"
|
||||||
with key as (
|
with key as (
|
||||||
select
|
select
|
||||||
id,
|
id,
|
||||||
0::int2 as uses
|
0::int2 as uses
|
||||||
from api_keys where last_used < date_trunc('minute', now())
|
from api_keys where last_used < date_trunc('minute', now())
|
||||||
and (cooldown is null or now() >= cooldown)
|
and (cooldown is null or now() >= cooldown)
|
||||||
and domains @> $1
|
and "#
|
||||||
union (
|
});
|
||||||
|
|
||||||
|
build_predicate(&mut qb, &selector);
|
||||||
|
|
||||||
|
qb.push(indoc::indoc! {
|
||||||
|
"
|
||||||
|
\n union (
|
||||||
select id, uses from api_keys
|
select id, uses from api_keys
|
||||||
where last_used >= date_trunc('minute', now())
|
where last_used >= date_trunc('minute', now())
|
||||||
and (cooldown is null or now() >= cooldown)
|
and (cooldown is null or now() >= cooldown)
|
||||||
and domains @> $1
|
and "
|
||||||
order by uses asc limit 1
|
});
|
||||||
|
|
||||||
|
build_predicate(&mut qb, &selector);
|
||||||
|
|
||||||
|
qb.push(indoc::indoc! {
|
||||||
|
"
|
||||||
|
\n order by uses asc limit 1
|
||||||
)
|
)
|
||||||
order by uses asc limit 1
|
order by uses asc limit 1
|
||||||
)
|
)
|
||||||
|
@ -194,19 +244,21 @@ where
|
||||||
flag = null,
|
flag = null,
|
||||||
last_used = now()
|
last_used = now()
|
||||||
from key where
|
from key where
|
||||||
api_keys.id=key.id and key.uses < $2
|
api_keys.id=key.id and key.uses < "
|
||||||
returning
|
});
|
||||||
|
|
||||||
|
qb.push_bind(self.limit);
|
||||||
|
|
||||||
|
qb.push(indoc::indoc! { "
|
||||||
|
\nreturning
|
||||||
api_keys.id,
|
api_keys.id,
|
||||||
api_keys.user_id,
|
api_keys.user_id,
|
||||||
api_keys.key,
|
api_keys.key,
|
||||||
api_keys.uses,
|
api_keys.uses,
|
||||||
api_keys.domains
|
api_keys.domains"
|
||||||
"#,
|
});
|
||||||
))
|
|
||||||
.bind(sqlx::types::Json(vec![&domain]))
|
let key = qb.build_query_as().fetch_optional(&mut tx).await?;
|
||||||
.bind(self.limit)
|
|
||||||
.fetch_optional(&mut tx)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
tx.commit().await?;
|
tx.commit().await?;
|
||||||
|
|
||||||
|
@ -219,9 +271,9 @@ where
|
||||||
Ok(None) => {
|
Ok(None) => {
|
||||||
return self
|
return self
|
||||||
.acquire_key(
|
.acquire_key(
|
||||||
domain
|
selector
|
||||||
.fallback()
|
.fallback()
|
||||||
.ok_or_else(|| PgStorageError::Unavailable(domain))?,
|
.ok_or_else(|| PgStorageError::Unavailable(selector))?,
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
}
|
}
|
||||||
|
@ -241,11 +293,15 @@ where
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn acquire_many_keys(
|
async fn acquire_many_keys<S>(
|
||||||
&self,
|
&self,
|
||||||
domain: D,
|
selector: S,
|
||||||
number: i64,
|
number: i64,
|
||||||
) -> Result<Vec<Self::Key>, Self::Error> {
|
) -> Result<Vec<Self::Key>, Self::Error>
|
||||||
|
where
|
||||||
|
S: IntoSelector<Self::Key, Self::Domain>,
|
||||||
|
{
|
||||||
|
let selector = selector.into_selector();
|
||||||
loop {
|
loop {
|
||||||
let attempt = async {
|
let attempt = async {
|
||||||
let mut tx = self.pool.begin().await?;
|
let mut tx = self.pool.begin().await?;
|
||||||
|
@ -254,33 +310,36 @@ where
|
||||||
.execute(&mut tx)
|
.execute(&mut tx)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
let mut keys: Vec<Self::Key> = sqlx::query_as(&indoc::formatdoc!(
|
let mut qb = QueryBuilder::new(indoc::indoc! {
|
||||||
r#"select
|
r#"select
|
||||||
id,
|
id,
|
||||||
user_id,
|
user_id,
|
||||||
key,
|
key,
|
||||||
0::int2 as uses,
|
0::int2 as uses,
|
||||||
domains
|
domains
|
||||||
from api_keys where last_used < date_trunc('minute', now())
|
from api_keys where last_used < date_trunc('minute', now())
|
||||||
and (cooldown is null or now() >= cooldown)
|
and (cooldown is null or now() >= cooldown)
|
||||||
and domains @> $1
|
and "#
|
||||||
union
|
});
|
||||||
|
build_predicate(&mut qb, &selector);
|
||||||
|
qb.push(indoc::indoc! {
|
||||||
|
"
|
||||||
|
\nunion
|
||||||
select
|
select
|
||||||
id,
|
id,
|
||||||
user_id,
|
user_id,
|
||||||
key,
|
key,
|
||||||
uses,
|
uses,
|
||||||
domains
|
domains
|
||||||
from api_keys where last_used >= date_trunc('minute', now())
|
from api_keys where last_used >= date_trunc('minute', now())
|
||||||
and (cooldown is null or now() >= cooldown)
|
and (cooldown is null or now() >= cooldown)
|
||||||
and domains @> $1
|
and "
|
||||||
order by uses limit $2
|
});
|
||||||
"#,
|
build_predicate(&mut qb, &selector);
|
||||||
))
|
qb.push("\norder by uses limit ");
|
||||||
.bind(sqlx::types::Json(vec![&domain]))
|
qb.push_bind(self.limit);
|
||||||
.bind(number)
|
|
||||||
.fetch_all(&mut tx)
|
let mut keys: Vec<Self::Key> = qb.build_query_as().fetch_all(&mut tx).await?;
|
||||||
.await?;
|
|
||||||
|
|
||||||
if keys.is_empty() {
|
if keys.is_empty() {
|
||||||
tx.commit().await?;
|
tx.commit().await?;
|
||||||
|
@ -338,9 +397,9 @@ where
|
||||||
Ok(None) => {
|
Ok(None) => {
|
||||||
return self
|
return self
|
||||||
.acquire_many_keys(
|
.acquire_many_keys(
|
||||||
domain
|
selector
|
||||||
.fallback()
|
.fallback()
|
||||||
.ok_or_else(|| Self::Error::Unavailable(domain))?,
|
.ok_or_else(|| Self::Error::Unavailable(selector))?,
|
||||||
number,
|
number,
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
|
@ -433,143 +492,116 @@ where
|
||||||
.map_err(Into::into)
|
.map_err(Into::into)
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn read_key(
|
async fn read_key<S>(&self, selector: S) -> Result<Option<Self::Key>, Self::Error>
|
||||||
&self,
|
where
|
||||||
selector: KeySelector<Self::Key>,
|
S: IntoSelector<Self::Key, Self::Domain>,
|
||||||
) -> Result<Option<Self::Key>, Self::Error> {
|
{
|
||||||
match &selector {
|
let selector = selector.into_selector();
|
||||||
KeySelector::Key(key) => sqlx::query_as("select * from api_keys where key=$1")
|
|
||||||
.bind(key)
|
|
||||||
.fetch_optional(&self.pool)
|
|
||||||
.await
|
|
||||||
.map_err(Into::into),
|
|
||||||
KeySelector::Id(id) => sqlx::query_as("select * from api_keys where id=$1")
|
|
||||||
.bind(id)
|
|
||||||
.fetch_optional(&self.pool)
|
|
||||||
.await
|
|
||||||
.map_err(Into::into),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn query_key(&self, domain: D) -> Result<Option<Self::Key>, Self::Error> {
|
let mut qb = QueryBuilder::new("select * from api_keys where ");
|
||||||
sqlx::query_as("select * from api_keys where domains @> $1 limit 1")
|
build_predicate(&mut qb, &selector);
|
||||||
.bind(sqlx::types::Json(vec![domain]))
|
|
||||||
|
qb.build_query_as()
|
||||||
.fetch_optional(&self.pool)
|
.fetch_optional(&self.pool)
|
||||||
.await
|
.await
|
||||||
.map_err(Into::into)
|
.map_err(Into::into)
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn query_all(&self, domain: D) -> Result<Vec<Self::Key>, Self::Error> {
|
async fn read_keys<S>(&self, selector: S) -> Result<Vec<Self::Key>, Self::Error>
|
||||||
sqlx::query_as("select * from api_keys where domains @> $1")
|
where
|
||||||
.bind(sqlx::types::Json(vec![domain]))
|
S: IntoSelector<Self::Key, Self::Domain>,
|
||||||
|
{
|
||||||
|
let selector = selector.into_selector();
|
||||||
|
|
||||||
|
let mut qb = QueryBuilder::new("select * from api_keys where ");
|
||||||
|
build_predicate(&mut qb, &selector);
|
||||||
|
|
||||||
|
qb.build_query_as()
|
||||||
.fetch_all(&self.pool)
|
.fetch_all(&self.pool)
|
||||||
.await
|
.await
|
||||||
.map_err(Into::into)
|
.map_err(Into::into)
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn read_user_keys(&self, user_id: i32) -> Result<Vec<Self::Key>, Self::Error> {
|
async fn remove_key<S>(&self, selector: S) -> Result<Self::Key, Self::Error>
|
||||||
sqlx::query_as("select * from api_keys where user_id=$1")
|
where
|
||||||
.bind(user_id)
|
S: IntoSelector<Self::Key, Self::Domain>,
|
||||||
.fetch_all(&self.pool)
|
{
|
||||||
.await
|
let selector = selector.into_selector();
|
||||||
.map_err(Into::into)
|
|
||||||
|
let mut qb = QueryBuilder::new("delete from api_keys where ");
|
||||||
|
build_predicate(&mut qb, &selector);
|
||||||
|
qb.push(" returning *");
|
||||||
|
|
||||||
|
qb.build_query_as()
|
||||||
|
.fetch_optional(&self.pool)
|
||||||
|
.await?
|
||||||
|
.ok_or_else(|| PgStorageError::KeyNotFound(selector))
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn remove_key(&self, selector: KeySelector<Self::Key>) -> Result<Self::Key, Self::Error> {
|
async fn add_domain_to_key<S>(&self, selector: S, domain: D) -> Result<Self::Key, Self::Error>
|
||||||
match &selector {
|
where
|
||||||
KeySelector::Key(key) => {
|
S: IntoSelector<Self::Key, Self::Domain>,
|
||||||
sqlx::query_as("delete from api_keys where key=$1 returning *")
|
{
|
||||||
.bind(key)
|
let selector = selector.into_selector();
|
||||||
.fetch_optional(&self.pool)
|
|
||||||
.await?
|
let mut qb = QueryBuilder::new(
|
||||||
.ok_or_else(|| PgStorageError::KeyNotFound(selector))
|
"update api_keys set domains = __unique_jsonb_array(domains || jsonb_build_array(",
|
||||||
}
|
);
|
||||||
KeySelector::Id(id) => sqlx::query_as("delete from api_keys where id=$1 returning *")
|
qb.push_bind(sqlx::types::Json(domain));
|
||||||
.bind(id)
|
qb.push(")) where ");
|
||||||
.fetch_optional(&self.pool)
|
build_predicate(&mut qb, &selector);
|
||||||
.await?
|
qb.push(" returning *");
|
||||||
.ok_or_else(|| PgStorageError::KeyNotFound(selector)),
|
|
||||||
}
|
qb.build_query_as()
|
||||||
|
.fetch_optional(&self.pool)
|
||||||
|
.await?
|
||||||
|
.ok_or_else(|| PgStorageError::KeyNotFound(selector))
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn add_domain_to_key(
|
async fn remove_domain_from_key<S>(
|
||||||
&self,
|
&self,
|
||||||
selector: KeySelector<Self::Key>,
|
selector: S,
|
||||||
domain: D,
|
domain: D,
|
||||||
) -> Result<Self::Key, Self::Error> {
|
) -> Result<Self::Key, Self::Error>
|
||||||
match &selector {
|
where
|
||||||
KeySelector::Key(key) => sqlx::query_as::<sqlx::Postgres, PgKey<D>>(
|
S: IntoSelector<Self::Key, Self::Domain>,
|
||||||
"update api_keys set domains = __unique_jsonb_array(domains || \
|
{
|
||||||
jsonb_build_array($1)) where key=$2 returning *",
|
let selector = selector.into_selector();
|
||||||
)
|
|
||||||
.bind(sqlx::types::Json(domain))
|
let mut qb = QueryBuilder::new(
|
||||||
.bind(key)
|
"update api_keys set domains = coalesce(__filter_jsonb_array(domains, ",
|
||||||
|
);
|
||||||
|
qb.push_bind(sqlx::types::Json(domain));
|
||||||
|
qb.push("), '[]'::jsonb) where ");
|
||||||
|
build_predicate(&mut qb, &selector);
|
||||||
|
qb.push(" returning *");
|
||||||
|
|
||||||
|
qb.build_query_as()
|
||||||
.fetch_optional(&self.pool)
|
.fetch_optional(&self.pool)
|
||||||
.await?
|
.await?
|
||||||
.ok_or_else(|| PgStorageError::KeyNotFound(selector)),
|
.ok_or_else(|| PgStorageError::KeyNotFound(selector))
|
||||||
KeySelector::Id(id) => sqlx::query_as::<sqlx::Postgres, PgKey<D>>(
|
|
||||||
"update api_keys set domains = __unique_jsonb_array(domains || \
|
|
||||||
jsonb_build_array($1)) where id=$2 returning *",
|
|
||||||
)
|
|
||||||
.bind(sqlx::types::Json(domain))
|
|
||||||
.bind(id)
|
|
||||||
.fetch_optional(&self.pool)
|
|
||||||
.await?
|
|
||||||
.ok_or_else(|| PgStorageError::KeyNotFound(selector)),
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn remove_domain_from_key(
|
async fn set_domains_for_key<S>(
|
||||||
&self,
|
&self,
|
||||||
selector: KeySelector<Self::Key>,
|
selector: S,
|
||||||
domain: D,
|
|
||||||
) -> Result<Self::Key, Self::Error> {
|
|
||||||
match &selector {
|
|
||||||
KeySelector::Key(key) => sqlx::query_as(
|
|
||||||
"update api_keys set domains = coalesce(__filter_jsonb_array(domains, $1), \
|
|
||||||
'[]'::jsonb) where key=$2 returning *",
|
|
||||||
)
|
|
||||||
.bind(sqlx::types::Json(domain))
|
|
||||||
.bind(key)
|
|
||||||
.fetch_optional(&self.pool)
|
|
||||||
.await?
|
|
||||||
.ok_or_else(|| PgStorageError::KeyNotFound(selector)),
|
|
||||||
KeySelector::Id(id) => sqlx::query_as(
|
|
||||||
"update api_keys set domains = coalesce(__filter_jsonb_array(domains, $1), \
|
|
||||||
'[]'::jsonb) where id=$2 returning *",
|
|
||||||
)
|
|
||||||
.bind(sqlx::types::Json(domain))
|
|
||||||
.bind(id)
|
|
||||||
.fetch_optional(&self.pool)
|
|
||||||
.await?
|
|
||||||
.ok_or_else(|| PgStorageError::KeyNotFound(selector)),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn set_domains_for_key(
|
|
||||||
&self,
|
|
||||||
selector: KeySelector<Self::Key>,
|
|
||||||
domains: Vec<D>,
|
domains: Vec<D>,
|
||||||
) -> Result<Self::Key, Self::Error> {
|
) -> Result<Self::Key, Self::Error>
|
||||||
match &selector {
|
where
|
||||||
KeySelector::Key(key) => sqlx::query_as::<sqlx::Postgres, PgKey<D>>(
|
S: IntoSelector<Self::Key, Self::Domain>,
|
||||||
"update api_keys set domains = $1 where key=$2 returning *",
|
{
|
||||||
)
|
let selector = selector.into_selector();
|
||||||
.bind(sqlx::types::Json(domains))
|
|
||||||
.bind(key)
|
|
||||||
.fetch_optional(&self.pool)
|
|
||||||
.await?
|
|
||||||
.ok_or_else(|| PgStorageError::KeyNotFound(selector)),
|
|
||||||
|
|
||||||
KeySelector::Id(id) => sqlx::query_as::<sqlx::Postgres, PgKey<D>>(
|
let mut qb = QueryBuilder::new("update api_keys set domains = ");
|
||||||
"update api_keys set domains = $1 where id=$2 returning *",
|
qb.push_bind(sqlx::types::Json(domains));
|
||||||
)
|
qb.push(" where ");
|
||||||
.bind(sqlx::types::Json(domains))
|
build_predicate(&mut qb, &selector);
|
||||||
.bind(id)
|
qb.push(" returning *");
|
||||||
|
|
||||||
|
qb.build_query_as()
|
||||||
.fetch_optional(&self.pool)
|
.fetch_optional(&self.pool)
|
||||||
.await?
|
.await?
|
||||||
.ok_or_else(|| PgStorageError::KeyNotFound(selector)),
|
.ok_or_else(|| PgStorageError::KeyNotFound(selector))
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -752,7 +784,7 @@ pub(crate) mod test {
|
||||||
async fn test_read_user_keys() {
|
async fn test_read_user_keys() {
|
||||||
let (storage, _) = setup().await;
|
let (storage, _) = setup().await;
|
||||||
|
|
||||||
let keys = storage.read_user_keys(1).await.unwrap();
|
let keys = storage.read_keys(KeySelector::UserId(1)).await.unwrap();
|
||||||
assert_eq!(keys.len(), 1);
|
assert_eq!(keys.len(), 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -777,7 +809,7 @@ pub(crate) mod test {
|
||||||
_ = storage.acquire_key(Domain::All).await.unwrap();
|
_ = storage.acquire_key(Domain::All).await.unwrap();
|
||||||
}
|
}
|
||||||
|
|
||||||
let keys = storage.read_user_keys(1).await.unwrap();
|
let keys = storage.read_keys(KeySelector::UserId(1)).await.unwrap();
|
||||||
assert_eq!(keys.len(), 2);
|
assert_eq!(keys.len(), 2);
|
||||||
for key in keys {
|
for key in keys {
|
||||||
assert_eq!(key.uses, 5);
|
assert_eq!(key.uses, 5);
|
||||||
|
@ -791,7 +823,7 @@ pub(crate) mod test {
|
||||||
assert!(storage.flag_key(key, 2).await.unwrap());
|
assert!(storage.flag_key(key, 2).await.unwrap());
|
||||||
|
|
||||||
match storage.acquire_key(Domain::All).await.unwrap_err() {
|
match storage.acquire_key(Domain::All).await.unwrap_err() {
|
||||||
PgStorageError::Unavailable(d) => assert_eq!(d, Domain::All),
|
PgStorageError::Unavailable(d) => assert!(matches!(d, KeySelector::Has(Domain::All))),
|
||||||
why => panic!("Expected domain unavailable error but found '{why}'"),
|
why => panic!("Expected domain unavailable error but found '{why}'"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -803,7 +835,7 @@ pub(crate) mod test {
|
||||||
assert!(storage.flag_key(key, 2).await.unwrap());
|
assert!(storage.flag_key(key, 2).await.unwrap());
|
||||||
|
|
||||||
match storage.acquire_many_keys(Domain::All, 5).await.unwrap_err() {
|
match storage.acquire_many_keys(Domain::All, 5).await.unwrap_err() {
|
||||||
PgStorageError::Unavailable(d) => assert_eq!(d, Domain::All),
|
PgStorageError::Unavailable(d) => assert!(matches!(d, KeySelector::Has(Domain::All))),
|
||||||
why => panic!("Expected domain unavailable error but found '{why}'"),
|
why => panic!("Expected domain unavailable error but found '{why}'"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -877,7 +909,7 @@ pub(crate) mod test {
|
||||||
set.join_next().await.unwrap().unwrap();
|
set.join_next().await.unwrap().unwrap();
|
||||||
}
|
}
|
||||||
|
|
||||||
let keys = storage.read_user_keys(1).await.unwrap();
|
let keys = storage.read_keys(KeySelector::UserId(1)).await.unwrap();
|
||||||
|
|
||||||
assert_eq!(keys.len(), 25);
|
assert_eq!(keys.len(), 25);
|
||||||
|
|
||||||
|
@ -952,7 +984,7 @@ pub(crate) mod test {
|
||||||
async fn query_key() {
|
async fn query_key() {
|
||||||
let (storage, _) = setup().await;
|
let (storage, _) = setup().await;
|
||||||
|
|
||||||
let key = storage.query_key(Domain::All).await.unwrap();
|
let key = storage.read_key(Domain::All).await.unwrap();
|
||||||
assert!(key.is_some());
|
assert!(key.is_some());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -960,7 +992,7 @@ pub(crate) mod test {
|
||||||
async fn query_nonexistent_key() {
|
async fn query_nonexistent_key() {
|
||||||
let (storage, _) = setup().await;
|
let (storage, _) = setup().await;
|
||||||
|
|
||||||
let key = storage.query_key(Domain::Guild { id: 0 }).await.unwrap();
|
let key = storage.read_key(Domain::Guild { id: 0 }).await.unwrap();
|
||||||
assert!(key.is_none());
|
assert!(key.is_none());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -968,7 +1000,38 @@ pub(crate) mod test {
|
||||||
async fn query_all() {
|
async fn query_all() {
|
||||||
let (storage, _) = setup().await;
|
let (storage, _) = setup().await;
|
||||||
|
|
||||||
let keys = storage.query_all(Domain::All).await.unwrap();
|
let keys = storage.read_keys(Domain::All).await.unwrap();
|
||||||
assert!(keys.len() == 1);
|
assert!(keys.len() == 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
async fn query_by_id() {
|
||||||
|
let (storage, _) = setup().await;
|
||||||
|
let key = storage.read_key(KeySelector::Id(1)).await.unwrap();
|
||||||
|
|
||||||
|
assert!(key.is_some());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
async fn query_by_key() {
|
||||||
|
let (storage, key) = setup().await;
|
||||||
|
let key = storage.read_key(KeySelector::Key(key.key)).await.unwrap();
|
||||||
|
|
||||||
|
assert!(key.is_some());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
async fn query_by_set() {
|
||||||
|
let (storage, _key) = setup().await;
|
||||||
|
let key = storage
|
||||||
|
.read_key(KeySelector::OneOf(vec![
|
||||||
|
Domain::All,
|
||||||
|
Domain::Guild { id: 0 },
|
||||||
|
Domain::Faction { id: 0 },
|
||||||
|
]))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
assert!(key.is_some());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -7,7 +7,7 @@ use torn_api::{
|
||||||
ApiRequest, ApiResponse, ApiSelection, ResponseError,
|
ApiRequest, ApiResponse, ApiSelection, ResponseError,
|
||||||
};
|
};
|
||||||
|
|
||||||
use crate::{ApiKey, KeyPoolError, KeyPoolExecutor, KeyPoolStorage};
|
use crate::{ApiKey, IntoSelector, KeyPoolError, KeyPoolExecutor, KeyPoolStorage};
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl<'client, C, S> RequestExecutor<C> for KeyPoolExecutor<'client, C, S>
|
impl<'client, C, S> RequestExecutor<C> for KeyPoolExecutor<'client, C, S>
|
||||||
|
@ -30,7 +30,7 @@ where
|
||||||
loop {
|
loop {
|
||||||
let key = self
|
let key = self
|
||||||
.storage
|
.storage
|
||||||
.acquire_key(self.domain.clone())
|
.acquire_key(self.selector.clone())
|
||||||
.await
|
.await
|
||||||
.map_err(|e| KeyPoolError::Storage(Arc::new(e)))?;
|
.map_err(|e| KeyPoolError::Storage(Arc::new(e)))?;
|
||||||
let url = request.url(key.value(), id.as_deref());
|
let url = request.url(key.value(), id.as_deref());
|
||||||
|
@ -66,7 +66,7 @@ where
|
||||||
{
|
{
|
||||||
let keys = match self
|
let keys = match self
|
||||||
.storage
|
.storage
|
||||||
.acquire_many_keys(self.domain.clone(), ids.len() as i64)
|
.acquire_many_keys(self.selector.clone(), ids.len() as i64)
|
||||||
.await
|
.await
|
||||||
{
|
{
|
||||||
Ok(keys) => keys,
|
Ok(keys) => keys,
|
||||||
|
@ -114,7 +114,7 @@ where
|
||||||
Ok(res) => return (id, Ok(res)),
|
Ok(res) => return (id, Ok(res)),
|
||||||
};
|
};
|
||||||
|
|
||||||
key = match self.storage.acquire_key(self.domain.clone()).await {
|
key = match self.storage.acquire_key(self.selector.clone()).await {
|
||||||
Ok(k) => k,
|
Ok(k) => k,
|
||||||
Err(why) => return (id, Err(Self::Error::Storage(Arc::new(why)))),
|
Err(why) => return (id, Err(Self::Error::Storage(Arc::new(why)))),
|
||||||
};
|
};
|
||||||
|
@ -150,25 +150,36 @@ where
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn torn_api(&self, domain: S::Domain) -> ApiProvider<C, KeyPoolExecutor<C, S>> {
|
pub fn torn_api<I>(&self, selector: I) -> ApiProvider<C, KeyPoolExecutor<C, S>>
|
||||||
|
where
|
||||||
|
I: IntoSelector<S::Key, S::Domain>,
|
||||||
|
{
|
||||||
ApiProvider::new(
|
ApiProvider::new(
|
||||||
&self.client,
|
&self.client,
|
||||||
KeyPoolExecutor::new(&self.storage, domain, self.comment.as_deref()),
|
KeyPoolExecutor::new(
|
||||||
|
&self.storage,
|
||||||
|
selector.into_selector(),
|
||||||
|
self.comment.as_deref(),
|
||||||
|
),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub trait WithStorage {
|
pub trait WithStorage {
|
||||||
fn with_storage<'a, S>(
|
fn with_storage<'a, S, I>(
|
||||||
&'a self,
|
&'a self,
|
||||||
storage: &'a S,
|
storage: &'a S,
|
||||||
domain: S::Domain,
|
selector: I,
|
||||||
) -> ApiProvider<Self, KeyPoolExecutor<Self, S>>
|
) -> ApiProvider<Self, KeyPoolExecutor<Self, S>>
|
||||||
where
|
where
|
||||||
Self: ApiClient + Sized,
|
Self: ApiClient + Sized,
|
||||||
S: KeyPoolStorage + Send + Sync + 'static,
|
S: KeyPoolStorage + Send + Sync + 'static,
|
||||||
|
I: IntoSelector<S::Key, S::Domain>,
|
||||||
{
|
{
|
||||||
ApiProvider::new(self, KeyPoolExecutor::new(storage, domain, None))
|
ApiProvider::new(
|
||||||
|
self,
|
||||||
|
KeyPoolExecutor::new(storage, selector.into_selector(), None),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue