diff --git a/torn-api/Cargo.toml b/torn-api/Cargo.toml index a979bcc..c71e2cf 100644 --- a/torn-api/Cargo.toml +++ b/torn-api/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "torn-api" -version = "0.3.2" +version = "0.4.0" edition = "2021" authors = ["Pyrit [2111649]"] license = "MIT" @@ -8,6 +8,10 @@ repository = "https://github.com/TotallyNot/torn-api.rs.git" homepage = "https://github.com/TotallyNot/torn-api.rs.git" description = "Torn API bindings for rust" +[[bench]] +name = "deserialisation_benchmark" +harness = false + [features] default = [ "reqwest" ] reqwest = [ "dep:reqwest" ] @@ -33,3 +37,4 @@ tokio = { version = "1.20.1", features = ["test-util", "rt", "macros"] } tokio-test = "0.4.2" reqwest = { version = "0.11", default-features = true } awc = { version = "3", features = [ "rustls" ] } +criterion = "0.3" diff --git a/torn-api/benches/deserialisation_benchmark.rs b/torn-api/benches/deserialisation_benchmark.rs new file mode 100644 index 0000000..359f087 --- /dev/null +++ b/torn-api/benches/deserialisation_benchmark.rs @@ -0,0 +1,65 @@ +use criterion::{criterion_group, criterion_main, Criterion}; +use torn_api::{faction, user, ThreadSafeApiClient}; + +pub fn user_benchmark(c: &mut Criterion) { + dotenv::dotenv().unwrap(); + let rt = tokio::runtime::Builder::new_current_thread() + .enable_io() + .enable_time() + .build() + .unwrap(); + let response = rt.block_on(async { + let key = std::env::var("APIKEY").expect("api key"); + let client = reqwest::Client::default(); + + client + .torn_api(key) + .user(|b| { + b.selections(&[ + user::Selection::Basic, + user::Selection::Discord, + user::Selection::Profile, + user::Selection::PersonalStats, + ]) + }) + .await + .unwrap() + }); + + c.bench_function("user deserialize", |b| { + b.iter(|| { + response.basic().unwrap(); + response.discord().unwrap(); + response.profile().unwrap(); + response.personal_stats().unwrap(); + }) + }); +} + +pub fn faction_benchmark(c: &mut Criterion) { + dotenv::dotenv().unwrap(); + let rt = tokio::runtime::Builder::new_current_thread() + .enable_io() + .enable_time() + .build() + .unwrap(); + let response = rt.block_on(async { + let key = std::env::var("APIKEY").expect("api key"); + let client = reqwest::Client::default(); + + client + .torn_api(key) + .faction(|b| b.selections(&[faction::Selection::Basic])) + .await + .unwrap() + }); + + c.bench_function("user deserialize", |b| { + b.iter(|| { + response.basic().unwrap(); + }) + }); +} + +criterion_group!(benches, user_benchmark, faction_benchmark); +criterion_main!(benches); diff --git a/torn-api/src/de_util.rs b/torn-api/src/de_util.rs index 6a6056d..01802ec 100644 --- a/torn-api/src/de_util.rs +++ b/torn-api/src/de_util.rs @@ -1,5 +1,4 @@ use chrono::{DateTime, NaiveDateTime, Utc}; -use num_traits::{PrimInt, Zero}; use serde::de::{Deserialize, Deserializer, Error, Unexpected}; pub fn empty_string_is_none<'de, D>(deserializer: D) -> Result, D::Error> @@ -14,13 +13,18 @@ where } } -pub fn string_is_long<'de, D>(deserializer: D) -> Result +pub fn string_is_long<'de, D>(deserializer: D) -> Result, D::Error> where D: Deserializer<'de>, { let s = String::deserialize(deserializer)?; - s.parse() - .map_err(|_e| Error::invalid_type(Unexpected::Str(&s), &"i64")) + if s.is_empty() { + Ok(None) + } else { + s.parse() + .map(Some) + .map_err(|_e| Error::invalid_type(Unexpected::Str(&s), &"i64")) + } } pub fn zero_date_is_none<'de, D>(deserializer: D) -> Result>, D::Error> @@ -35,28 +39,3 @@ where Ok(Some(DateTime::from_utc(naive, Utc))) } } - -pub fn zero_is_none<'de, D, I>(deserializer: D) -> Result, D::Error> -where - D: Deserializer<'de>, - I: PrimInt + Zero + Deserialize<'de>, -{ - let i = I::deserialize(deserializer)?; - if i == I::zero() { - Ok(None) - } else { - Ok(Some(i)) - } -} - -pub fn none_is_none<'de, D>(deserializer: D) -> Result, D::Error> -where - D: Deserializer<'de>, -{ - let s = String::deserialize(deserializer)?; - if s == "None" { - Ok(None) - } else { - Ok(Some(s)) - } -} diff --git a/torn-api/src/lib.rs b/torn-api/src/lib.rs index 8d129bd..9521e76 100644 --- a/torn-api/src/lib.rs +++ b/torn-api/src/lib.rs @@ -13,6 +13,7 @@ mod de_util; use async_trait::async_trait; use chrono::{DateTime, Utc}; +use num_traits::{AsPrimitive, PrimInt}; use serde::de::{DeserializeOwned, Error as DeError}; use thiserror::Error; @@ -363,8 +364,11 @@ where } #[must_use] - pub fn id(mut self, id: u64) -> Self { - self.request.id = Some(id); + pub fn id(mut self, id: I) -> Self + where + I: PrimInt + AsPrimitive, + { + self.request.id = Some(id.as_()); self } diff --git a/torn-api/src/user.rs b/torn-api/src/user.rs index d719184..88b7ec5 100644 --- a/torn-api/src/user.rs +++ b/torn-api/src/user.rs @@ -34,19 +34,98 @@ pub struct LastAction { pub timestamp: DateTime, } -#[derive(Debug, Clone, Deserialize)] +#[derive(Debug, Clone)] pub struct Faction { - #[serde(deserialize_with = "de_util::zero_is_none")] - pub faction_id: Option, - #[serde(deserialize_with = "de_util::none_is_none")] - pub faction_name: Option, - #[serde(deserialize_with = "de_util::zero_is_none")] - pub days_in_faction: Option, - #[serde(deserialize_with = "de_util::none_is_none")] - pub position: Option, + pub faction_id: i32, + pub faction_name: String, + pub days_in_faction: i16, + pub position: String, pub faction_tag: Option, } +fn deserialize_faction<'de, D>(deserializer: D) -> Result, D::Error> +where + D: Deserializer<'de>, +{ + #[derive(Deserialize)] + #[serde(rename_all = "snake_case")] + enum Field { + FactionId, + FactionName, + DaysInFaction, + Position, + FactionTag, + } + + struct FactionVisitor; + + impl<'de> Visitor<'de> for FactionVisitor { + type Value = Option; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("struct Faction") + } + + fn visit_map(self, mut map: V) -> Result + where + V: MapAccess<'de>, + { + let mut faction_id = None; + let mut faction_name = None; + let mut days_in_faction = None; + let mut position = None; + let mut faction_tag = None; + + while let Some(key) = map.next_key()? { + match key { + Field::FactionId => { + faction_id = Some(map.next_value()?); + } + Field::FactionName => { + faction_name = Some(map.next_value()?); + } + Field::DaysInFaction => { + days_in_faction = Some(map.next_value()?); + } + Field::Position => { + position = Some(map.next_value()?); + } + Field::FactionTag => { + faction_tag = map.next_value()?; + } + } + } + let faction_id = faction_id.ok_or_else(|| de::Error::missing_field("faction_id"))?; + let faction_name = + faction_name.ok_or_else(|| de::Error::missing_field("faction_name"))?; + let days_in_faction = + days_in_faction.ok_or_else(|| de::Error::missing_field("days_in_faction"))?; + let position = position.ok_or_else(|| de::Error::missing_field("position"))?; + + if faction_id == 0 { + Ok(None) + } else { + Ok(Some(Faction { + faction_id, + faction_name, + days_in_faction, + position, + faction_tag, + })) + } + } + } + + const FIELDS: &[&str] = &[ + "faction_id", + "faction_name", + "days_in_faction", + "position", + "faction_tag", + ]; + deserializer.deserialize_struct("Faction", FIELDS, FactionVisitor) +} + #[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize)] pub enum State { Okay, @@ -92,7 +171,7 @@ pub struct Discord { #[serde(rename = "userID")] pub user_id: i32, #[serde(rename = "discordID", deserialize_with = "de_util::string_is_long")] - pub discord_id: i64, + pub discord_id: Option, } #[derive(Debug, Clone, Deserialize)] @@ -160,10 +239,10 @@ where where V: MapAccess<'de>, { - let mut team: Option = None; + let mut team = None; let mut score = None; let mut attacks = None; - let mut name: Option = None; + let mut name = None; while let Some(key) = map.next_key()? { match key { @@ -217,21 +296,27 @@ where } } - match (name, team, score, attacks) { - (Some(CompetitionName::Elimination), Some(team), Some(score), Some(attacks)) => { - Ok(Some(Competition::Elimination { - team, - score, - attacks, - })) + let name = name.ok_or_else(|| de::Error::missing_field("name"))?; + + match name { + CompetitionName::Elimination => { + if let Some(team) = team { + let score = score.ok_or_else(|| de::Error::missing_field("score"))?; + let attacks = attacks.ok_or_else(|| de::Error::missing_field("attacks"))?; + Ok(Some(Competition::Elimination { + team, + score, + attacks, + })) + } else { + Ok(None) + } } - _ => Ok(None), } } } - const FIELDS: &[&str] = &["name", "score", "team", "attacks"]; - deserializer.deserialize_struct("Competition", FIELDS, CompetitionVisitor) + deserializer.deserialize_map(CompetitionVisitor) } #[derive(Debug, Clone, Deserialize)] @@ -245,7 +330,8 @@ pub struct Profile { pub life: LifeBar, pub last_action: LastAction, - pub faction: Faction, + #[serde(deserialize_with = "deserialize_faction")] + pub faction: Option, pub status: Status, #[serde(deserialize_with = "deserialize_comp")] @@ -325,11 +411,7 @@ mod tests { let faction = response.profile().unwrap().faction; - assert!(faction.faction_id.is_none()); - assert!(faction.faction_name.is_none()); - assert!(faction.faction_tag.is_none()); - assert!(faction.days_in_faction.is_none()); - assert!(faction.position.is_none()); + assert!(faction.is_none()); } #[async_test] diff --git a/torn-key-pool/Cargo.toml b/torn-key-pool/Cargo.toml index 66b45d2..4048028 100644 --- a/torn-key-pool/Cargo.toml +++ b/torn-key-pool/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "torn-key-pool" -version = "0.2.1" +version = "0.3.0" edition = "2021" license = "MIT" repository = "https://github.com/TotallyNot/torn-api.rs.git" @@ -16,7 +16,7 @@ reqwest = [ "dep:reqwest", "torn-api/reqwest" ] awc = [ "dep:awc", "torn-api/awc" ] [dependencies] -torn-api = { path = "../torn-api", default-features = false, version = "0.3" } +torn-api = { path = "../torn-api", default-features = false, version = "0.4" } async-trait = "0.1" thiserror = "1"