From c320036cd0d5b177eef1174e59d7029581f544a3 Mon Sep 17 00:00:00 2001 From: TotallyNot <44345987+TotallyNot@users.noreply.github.com> Date: Fri, 17 Nov 2023 11:34:38 +0100 Subject: [PATCH] added faction->chain, changed `selections()` signature --- torn-api-macros/Cargo.toml | 2 +- torn-api-macros/src/lib.rs | 46 +++++++---- torn-api/Cargo.toml | 4 +- torn-api/src/faction.rs | 146 ++++++++++++++++++++++++++++++++-- torn-api/src/key.rs | 2 +- torn-api/src/lib.rs | 36 +++++++-- torn-api/src/torn.rs | 14 ++-- torn-api/src/user.rs | 116 +++++++++++++++++++++++++-- torn-key-pool/src/postgres.rs | 2 +- 9 files changed, 323 insertions(+), 45 deletions(-) diff --git a/torn-api-macros/Cargo.toml b/torn-api-macros/Cargo.toml index 6a7b6c4..e860b4b 100644 --- a/torn-api-macros/Cargo.toml +++ b/torn-api-macros/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "torn-api-macros" -version = "0.1.2" +version = "0.2.0" edition = "2021" authors = ["Pyrit [2111649]"] license = "MIT" diff --git a/torn-api-macros/src/lib.rs b/torn-api-macros/src/lib.rs index 1a0086c..320d941 100644 --- a/torn-api-macros/src/lib.rs +++ b/torn-api-macros/src/lib.rs @@ -44,7 +44,8 @@ fn impl_api_category(ast: &syn::DeriveInput) -> TokenStream { } else { Err(meta.error("unknown attribute")) } - }).unwrap(); + }) + .unwrap(); } } @@ -78,17 +79,18 @@ fn impl_api_category(ast: &syn::DeriveInput) -> TokenStream { } else { Err(meta.error("unsupported attribute")) } - }).unwrap(); + }) + .unwrap(); let name = format_ident!("{}", variant.ident.to_string().to_case(Case::Snake)); let raw_value = variant.ident.to_string().to_lowercase(); - return Some(ApiAttribute { + return Some(ApiAttribute { field: field.expect("field or flatten attribute must be specified"), raw_value, variant: variant.ident.clone(), type_name: r#type.expect("type must be specified").parse().unwrap(), name, - with - }) + with, + }); } } None @@ -154,7 +156,7 @@ fn impl_api_category(ast: &syn::DeriveInput) -> TokenStream { } impl crate::ApiSelection for #name { - fn raw_value(&self) -> &'static str { + fn raw_value(self) -> &'static str { match self { #(#raw_values,)* } @@ -180,19 +182,25 @@ fn to_static_lt(ty: &mut syn::Type) -> bool { let mut res = false; match ty { syn::Type::Path(path) => { - if let Some(syn::PathArguments::AngleBracketed(ab)) = path.path.segments.last_mut().map(|s| &mut s.arguments).as_mut() { + if let Some(syn::PathArguments::AngleBracketed(ab)) = path + .path + .segments + .last_mut() + .map(|s| &mut s.arguments) + .as_mut() + { for mut arg in &mut ab.args { match &mut arg { syn::GenericArgument::Type(ty) => { if to_static_lt(ty) { res = true; } - }, + } syn::GenericArgument::Lifetime(lt) => { lt.ident = syn::Ident::new("static", proc_macro2::Span::call_site()); res = true; } - _ => () + _ => (), } } } @@ -204,7 +212,7 @@ fn to_static_lt(ty: &mut syn::Type) -> bool { } to_static_lt(&mut r.elem); } - _ => () + _ => (), }; res } @@ -223,7 +231,8 @@ fn impl_into_owned(ast: &syn::DeriveInput) -> TokenStream { } else { Err(meta.error("unknown attribute")) } - }).unwrap(); + }) + .unwrap(); } } @@ -235,7 +244,8 @@ fn impl_into_owned(ast: &syn::DeriveInput) -> TokenStream { self } } - }.into() + } + .into(); } let syn::Data::Struct(r#struct) = &ast.data else { @@ -263,15 +273,21 @@ fn impl_into_owned(ast: &syn::DeriveInput) -> TokenStream { let vis = &field.vis; if to_static_lt(&mut ty) { - owned_fields.push(quote! { #vis #field_name: <#ty as crate::into_owned::IntoOwned>::Owned }); - fields.push(quote! { #field_name: crate::into_owned::IntoOwned::into_owned(self.#field_name) }); + owned_fields + .push(quote! { #vis #field_name: <#ty as crate::into_owned::IntoOwned>::Owned }); + fields.push( + quote! { #field_name: crate::into_owned::IntoOwned::into_owned(self.#field_name) }, + ); } else { owned_fields.push(quote! { #vis #field_name: #ty }); fields.push(quote! { #field_name: self.#field_name }); }; } - let owned_name = syn::Ident::new(&format!("{}Owned", ast.ident), proc_macro2::Span::call_site()); + let owned_name = syn::Ident::new( + &format!("{}Owned", ast.ident), + proc_macro2::Span::call_site(), + ); let gen = quote! { #[derive(Debug, Clone)] diff --git a/torn-api/Cargo.toml b/torn-api/Cargo.toml index ba77a93..20308b5 100644 --- a/torn-api/Cargo.toml +++ b/torn-api/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "torn-api" -version = "0.5.28" +version = "0.6.0" edition = "2021" authors = ["Pyrit [2111649]"] license = "MIT" @@ -37,7 +37,7 @@ reqwest = { version = "0.11", default-features = false, features = [ "json" ], o awc = { version = "3", default-features = false, optional = true } rust_decimal = { version = "1", default-features = false, optional = true, features = [ "serde" ] } -torn-api-macros = { path = "../torn-api-macros", version = "0.1.2" } +torn-api-macros = { path = "../torn-api-macros", version = "0.2" } [dev-dependencies] actix-rt = { version = "2.7.0" } diff --git a/torn-api/src/faction.rs b/torn-api/src/faction.rs index fdebbd3..2756b5f 100644 --- a/torn-api/src/faction.rs +++ b/torn-api/src/faction.rs @@ -1,7 +1,10 @@ use std::collections::{BTreeMap, HashMap}; -use chrono::{DateTime, Utc}; -use serde::Deserialize; +use chrono::{DateTime, TimeZone, Utc}; +use serde::{ + de::{Error, Unexpected, Visitor}, + Deserialize, Deserializer, +}; use torn_api_macros::{ApiCategory, IntoOwned}; @@ -28,6 +31,9 @@ pub enum FactionSelection { with = "null_is_empty_dict" )] Territory, + + #[api(type = "Option", field = "chain", with = "deserialize_chain")] + Chain, } pub type Selection = FactionSelection; @@ -80,6 +86,128 @@ pub struct Basic<'a> { pub territory_wars: Vec>, } +#[derive(Debug)] +pub struct Chain { + pub current: i32, + pub max: i32, + #[cfg(feature = "decimal")] + pub modifier: rust_decimal::Decimal, + pub timeout: Option, + pub cooldown: Option, + pub start: DateTime, + pub end: DateTime, +} + +fn deserialize_chain<'de, D>(deserializer: D) -> Result, D::Error> +where + D: Deserializer<'de>, +{ + struct ChainVisitor; + + impl<'de> Visitor<'de> for ChainVisitor { + type Value = Option; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("struct Chain") + } + + fn visit_map(self, mut map: A) -> Result + where + A: serde::de::MapAccess<'de>, + { + #[derive(Deserialize)] + #[serde(rename_all = "snake_case")] + enum Fields { + Current, + Max, + Modifier, + Timeout, + Cooldown, + Start, + End, + #[serde(other)] + Ignore, + } + + let mut current = None; + let mut max = None; + #[cfg(feature = "decimal")] + let mut modifier = None; + let mut timeout = None; + let mut cooldown = None; + let mut start = None; + let mut end = None; + + while let Some(key) = map.next_key()? { + match key { + Fields::Current => { + let value = map.next_value()?; + if value != 0 { + current = Some(value); + } + } + Fields::Max => { + max = Some(map.next_value()?); + } + Fields::Modifier => { + #[cfg(feature = "decimal")] + { + modifier = Some(map.next_value()?); + } + } + Fields::Timeout => { + match map.next_value()? { + 0 => timeout = Some(None), + val => timeout = Some(Some(val)), + }; + } + Fields::Cooldown => { + match map.next_value()? { + 0 => cooldown = Some(None), + val => cooldown = Some(Some(val)), + }; + } + Fields::Start => { + let ts: i64 = map.next_value()?; + start = Some(Utc.timestamp_opt(ts, 0).single().ok_or_else(|| { + A::Error::invalid_value(Unexpected::Signed(ts), &"Epoch timestamp") + })?); + } + Fields::End => { + let ts: i64 = map.next_value()?; + end = Some(Utc.timestamp_opt(ts, 0).single().ok_or_else(|| { + A::Error::invalid_value(Unexpected::Signed(ts), &"Epoch timestamp") + })?); + } + Fields::Ignore => (), + } + } + + let Some(current) = current else { + return Ok(None); + }; + let max = max.ok_or_else(|| A::Error::missing_field("max"))?; + let timeout = timeout.ok_or_else(|| A::Error::missing_field("timeout"))?; + let cooldown = cooldown.ok_or_else(|| A::Error::missing_field("cooldown"))?; + let start = start.ok_or_else(|| A::Error::missing_field("start"))?; + let end = end.ok_or_else(|| A::Error::missing_field("end"))?; + + Ok(Some(Chain { + current, + max, + #[cfg(feature = "decimal")] + modifier: modifier.ok_or_else(|| A::Error::missing_field("modifier"))?, + timeout, + cooldown, + start, + end, + })) + } + } + + deserializer.deserialize_map(ChainVisitor) +} + #[cfg(test)] mod tests { use super::*; @@ -92,7 +220,12 @@ mod tests { let response = Client::default() .torn_api(key) .faction(|b| { - b.selections(&[Selection::Basic, Selection::Attacks, Selection::Territory]) + b.selections([ + Selection::Basic, + Selection::Attacks, + Selection::Territory, + Selection::Chain, + ]) }) .await .unwrap(); @@ -101,6 +234,7 @@ mod tests { response.attacks().unwrap(); response.attacks_full().unwrap(); response.territory().unwrap(); + response.chain().unwrap(); } #[async_test] @@ -111,13 +245,14 @@ mod tests { .torn_api(key) .faction(|b| { b.id(7049) - .selections(&[Selection::Basic, Selection::Territory]) + .selections([Selection::Basic, Selection::Territory, Selection::Chain]) }) .await .unwrap(); response.basic().unwrap(); response.territory().unwrap(); + response.chain().unwrap(); } #[async_test] @@ -128,12 +263,13 @@ mod tests { .torn_api(key) .faction(|b| { b.id(8981) - .selections(&[Selection::Basic, Selection::Territory]) + .selections([Selection::Basic, Selection::Territory, Selection::Chain]) }) .await .unwrap(); response.basic().unwrap(); response.territory().unwrap(); + assert!(response.chain().unwrap().is_none()); } } diff --git a/torn-api/src/key.rs b/torn-api/src/key.rs index a39a943..08776ac 100644 --- a/torn-api/src/key.rs +++ b/torn-api/src/key.rs @@ -247,7 +247,7 @@ mod tests { let response = Client::default() .torn_api(key) - .key(|b| b.selections(&[Selection::Info])) + .key(|b| b.selections([Selection::Info])) .await .unwrap(); diff --git a/torn-api/src/lib.rs b/torn-api/src/lib.rs index 835a93f..9f16c63 100644 --- a/torn-api/src/lib.rs +++ b/torn-api/src/lib.rs @@ -45,7 +45,16 @@ pub enum ResponseError { Api { code: u8, reason: String }, #[error(transparent)] - Parsing(#[from] serde_json::Error), + MalformedResponse(#[from] serde_json::Error), +} + +impl ResponseError { + pub fn api_code(&self) -> Option { + match self { + Self::Api { code, .. } => Some(*code), + _ => None, + } + } } impl ApiResponse { @@ -100,7 +109,7 @@ impl ApiResponse { } pub trait ApiSelection: Send + Sync { - fn raw_value(&self) -> &'static str; + fn raw_value(self) -> &'static str; fn category() -> &'static str; } @@ -137,6 +146,18 @@ where Response(#[from] ResponseError), } +impl ApiClientError +where + C: std::error::Error, +{ + pub fn api_code(&self) -> Option { + match self { + Self::Response(err) => err.api_code(), + _ => None, + } + } +} + #[derive(Debug)] pub struct ApiRequest where @@ -218,10 +239,13 @@ where A: ApiSelection, { #[must_use] - pub fn selections(mut self, selections: &[A]) -> Self { - self.request - .selections - .append(&mut selections.iter().map(ApiSelection::raw_value).collect()); + pub fn selections(mut self, selections: impl IntoIterator) -> Self { + self.request.selections.append( + &mut selections + .into_iter() + .map(ApiSelection::raw_value) + .collect(), + ); self } diff --git a/torn-api/src/torn.rs b/torn-api/src/torn.rs index 8b5b3b9..952d825 100644 --- a/torn-api/src/torn.rs +++ b/torn-api/src/torn.rs @@ -222,7 +222,7 @@ mod tests { let response = Client::default() .torn_api(key) .torn(|b| { - b.selections(&[ + b.selections([ TornSelection::Competition, TornSelection::TerritoryWars, TornSelection::Rackets, @@ -242,7 +242,7 @@ mod tests { let response = Client::default() .torn_api(key) - .torn(|b| b.selections(&[Selection::Territory]).id("NSC")) + .torn(|b| b.selections([Selection::Territory]).id("NSC")) .await .unwrap(); @@ -256,7 +256,7 @@ mod tests { let response = Client::default() .torn_api(key) - .torn(|b| b.selections(&[Selection::Territory]).id("AAA")) + .torn(|b| b.selections([Selection::Territory]).id("AAA")) .await .unwrap(); @@ -269,7 +269,7 @@ mod tests { let response = Client::default() .torn_api(&key) - .torn(|b| b.selections(&[Selection::TerritoryWarReport]).id(37403)) + .torn(|b| b.selections([Selection::TerritoryWarReport]).id(37403)) .await .unwrap(); @@ -280,7 +280,7 @@ mod tests { let response = Client::default() .torn_api(&key) - .torn(|b| b.selections(&[Selection::TerritoryWarReport]).id(37502)) + .torn(|b| b.selections([Selection::TerritoryWarReport]).id(37502)) .await .unwrap(); @@ -291,7 +291,7 @@ mod tests { let response = Client::default() .torn_api(&key) - .torn(|b| b.selections(&[Selection::TerritoryWarReport]).id(37860)) + .torn(|b| b.selections([Selection::TerritoryWarReport]).id(37860)) .await .unwrap(); @@ -302,7 +302,7 @@ mod tests { let response = Client::default() .torn_api(&key) - .torn(|b| b.selections(&[Selection::TerritoryWarReport]).id(23757)) + .torn(|b| b.selections([Selection::TerritoryWarReport]).id(23757)) .await .unwrap(); diff --git a/torn-api/src/user.rs b/torn-api/src/user.rs index bfb04bf..bcc2185 100644 --- a/torn-api/src/user.rs +++ b/torn-api/src/user.rs @@ -145,8 +145,11 @@ pub struct Basic<'a> { #[derive(Debug, Clone, IntoOwned, PartialEq, Eq, Deserialize)] #[into_owned(identity)] pub struct Discord { - #[serde(rename = "userID")] - pub user_id: i32, + #[serde( + rename = "userID", + deserialize_with = "de_util::empty_string_int_option" + )] + pub user_id: Option, #[serde(rename = "discordID", deserialize_with = "de_util::string_is_long")] pub discord_id: Option, } @@ -342,6 +345,7 @@ pub struct Profile<'a> { pub last_action: LastAction, #[serde(deserialize_with = "deserialize_faction")] pub faction: Option>, + pub job: EmploymentStatus, pub status: Status<'a>, #[serde(deserialize_with = "deserialize_comp")] @@ -484,6 +488,104 @@ impl<'de> Deserialize<'de> for Icon { } } +#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize)] +#[non_exhaustive] +pub enum Job { + Director, + Employee, + Education, + Army, + Law, + Casino, + Medical, + Grocer, + #[serde(other)] + Other, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum Company { + PlayerRun { + name: String, + id: i32, + company_type: u8, + }, + CityJob, +} + +impl<'de> Deserialize<'de> for Company { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + struct CompanyVisitor; + + impl<'de> Visitor<'de> for CompanyVisitor { + type Value = Company; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("enum Company") + } + + fn visit_map(self, mut map: A) -> Result + where + A: MapAccess<'de>, + { + #[allow(clippy::enum_variant_names)] + #[derive(Deserialize)] + #[serde(rename_all = "snake_case")] + enum Field { + CompanyId, + CompanyName, + CompanyType, + #[serde(other)] + Other, + } + + let mut id = None; + let mut name = None; + let mut company_type = None; + + while let Some(key) = map.next_key()? { + match key { + Field::CompanyId => { + id = Some(map.next_value()?); + if id == Some(0) { + return Ok(Company::CityJob); + } + } + Field::CompanyType => company_type = Some(map.next_value()?), + Field::CompanyName => { + name = Some(map.next_value()?); + } + Field::Other => (), + } + } + + let id = id.ok_or_else(|| de::Error::missing_field("company_id"))?; + let name = name.ok_or_else(|| de::Error::missing_field("company_name"))?; + let company_type = + company_type.ok_or_else(|| de::Error::missing_field("company_type"))?; + + Ok(Company::PlayerRun { + name, + id, + company_type, + }) + } + } + + deserializer.deserialize_map(CompanyVisitor) + } +} + +#[derive(Debug, Clone, Deserialize)] +pub struct EmploymentStatus { + pub job: Job, + #[serde(flatten)] + pub company: Company, +} + #[cfg(test)] mod tests { use super::*; @@ -496,7 +598,7 @@ mod tests { let response = Client::default() .torn_api(key) .user(|b| { - b.selections(&[ + b.selections([ Selection::Basic, Selection::Discord, Selection::Profile, @@ -523,7 +625,7 @@ mod tests { let response = Client::default() .torn_api(key) - .user(|b| b.id(28).selections(&[Selection::Profile])) + .user(|b| b.id(28).selections([Selection::Profile])) .await .unwrap(); @@ -539,7 +641,7 @@ mod tests { let response = Client::default() .torn_api(key) .users([1, 2111649, 374272176892674048i64], |b| { - b.selections(&[Selection::Basic]) + b.selections([Selection::Basic]) }) .await; @@ -553,7 +655,7 @@ mod tests { let response = Client::default() .torn_api(key) - .user(|b| b.id(374272176892674048i64).selections(&[Selection::Basic])) + .user(|b| b.id(374272176892674048i64).selections([Selection::Basic])) .await .unwrap(); @@ -566,7 +668,7 @@ mod tests { let response = Client::default() .torn_api(key) - .user(|b| b.id(1900654).selections(&[Selection::Icons])) + .user(|b| b.id(1900654).selections([Selection::Icons])) .await .unwrap(); diff --git a/torn-key-pool/src/postgres.rs b/torn-key-pool/src/postgres.rs index 8899259..c05997a 100644 --- a/torn-key-pool/src/postgres.rs +++ b/torn-key-pool/src/postgres.rs @@ -1,4 +1,4 @@ -use async_trait::async_trait; + use indoc::indoc; use sqlx::{FromRow, PgPool, Postgres, QueryBuilder}; use thiserror::Error;