diff --git a/torn-api-macros/Cargo.toml b/torn-api-macros/Cargo.toml index 6278c55..239845d 100644 --- a/torn-api-macros/Cargo.toml +++ b/torn-api-macros/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "torn-api-macros" -version = "0.1.0" +version = "0.1.1" edition = "2021" authors = ["Pyrit [2111649]"] license = "MIT" @@ -12,6 +12,7 @@ description = "Macros implementation of #[derive(ApiCategory)]" proc-macro = true [dependencies] -syn = { version = "1.0", features = [ "extra-traits" ] } -quote = "1.0" +syn = { version = "1", features = [ "extra-traits" ] } +proc-macro2 = "1" +quote = "1" convert_case = "0.5" diff --git a/torn-api-macros/src/lib.rs b/torn-api-macros/src/lib.rs index 964b646..c31ace1 100644 --- a/torn-api-macros/src/lib.rs +++ b/torn-api-macros/src/lib.rs @@ -17,11 +17,12 @@ enum ApiField { #[derive(Debug)] struct ApiAttribute { - type_: syn::Ident, field: ApiField, name: syn::Ident, raw_value: String, variant: syn::Ident, + type_name: proc_macro2::TokenStream, + with: Option, } fn get_lit_string(lit: syn::Lit) -> String { @@ -71,6 +72,7 @@ fn impl_api_category(ast: &syn::DeriveInput) -> TokenStream { Ok(syn::Meta::List(l)) => { let mut type_: Option = None; let mut field: Option = None; + let mut with: Option = None; for nested in l.nested.into_iter() { match nested { syn::NestedMeta::Meta(syn::Meta::NameValue(m)) @@ -82,6 +84,15 @@ fn impl_api_category(ast: &syn::DeriveInput) -> TokenStream { panic!("type can only be specified once"); } } + syn::NestedMeta::Meta(syn::Meta::NameValue(m)) + if m.path.is_ident("with") => + { + if with.is_none() { + with = Some(get_lit_string(m.lit)); + } else { + panic!("with can only be specified once"); + } + } syn::NestedMeta::Meta(syn::Meta::NameValue(m)) if m.path.is_ident("field") => { @@ -109,12 +120,17 @@ fn impl_api_category(ast: &syn::DeriveInput) -> TokenStream { let name = format_ident!("{}", variant.ident.to_string().to_case(Case::Snake)); let raw_value = variant.ident.to_string().to_lowercase(); + return Some(ApiAttribute { - type_: quote::format_ident!("{}", type_.expect("type")), field: field.expect("one of field/flatten"), name, raw_value, variant: variant.ident.clone(), + type_name: type_ + .expect("Need to specify type name") + .parse() + .expect("failed to parse type name"), + with: with.map(|w| format_ident!("{}", w)), }); } _ => panic!("Couldn't parse api attribute"), @@ -127,21 +143,34 @@ fn impl_api_category(ast: &syn::DeriveInput) -> TokenStream { let accessors = fields.iter().map( |ApiAttribute { - type_, field, name, .. - }| match field { - ApiField::Property(prop) => { + field, + name, + type_name, + with, + .. + }| match (field, with) { + (ApiField::Property(prop), None) => { let prop_str = prop.to_string(); quote! { - pub fn #name(&self) -> serde_json::Result<#type_> { + pub fn #name(&self) -> serde_json::Result<#type_name> { self.0.decode_field(#prop_str) } } } - ApiField::Flattened => quote! { - pub fn #name(&self) -> serde_json::Result<#type_> { + (ApiField::Property(prop), Some(f)) => { + let prop_str = prop.to_string(); + quote! { + pub fn #name(&self) -> serde_json::Result<#type_name> { + self.0.decode_field_with(#prop_str, #f) + } + } + } + (ApiField::Flattened, None) => quote! { + pub fn #name(&self) -> serde_json::Result<#type_name> { self.0.decode() } }, + (ApiField::Flattened, Some(_)) => todo!(), }, ); diff --git a/torn-api/Cargo.toml b/torn-api/Cargo.toml index 725ba32..b0cbdc6 100644 --- a/torn-api/Cargo.toml +++ b/torn-api/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "torn-api" -version = "0.4.2" +version = "0.5.0" edition = "2021" authors = ["Pyrit [2111649]"] license = "MIT" @@ -24,11 +24,12 @@ chrono = { version = "0.4", features = [ "serde" ], default-features = false } async-trait = "0.1" thiserror = "1" num-traits = "0.2" +futures = "0.3" reqwest = { version = "0.11", default-features = false, features = [ "json" ], optional = true } awc = { version = "3", default-features = false, optional = true } -torn-api-macros = { path = "../torn-api-macros", version = "0.1" } +torn-api-macros = { path = "../torn-api-macros", version = "0.1.1" } [dev-dependencies] actix-rt = { version = "2.7.0" } diff --git a/torn-api/src/awc.rs b/torn-api/src/awc.rs index 40f435b..7e16904 100644 --- a/torn-api/src/awc.rs +++ b/torn-api/src/awc.rs @@ -1,7 +1,7 @@ use async_trait::async_trait; use thiserror::Error; -use crate::ApiClient; +use crate::local::ApiClient; #[derive(Error, Debug)] pub enum AwcApiClientError { diff --git a/torn-api/src/faction.rs b/torn-api/src/faction.rs index 028cab5..d6e46cd 100644 --- a/torn-api/src/faction.rs +++ b/torn-api/src/faction.rs @@ -47,7 +47,7 @@ mod tests { let response = Client::default() .torn_api(key) - .faction(|b| b.selections(&[Selection::Basic])) + .faction(None, |b| b.selections(&[Selection::Basic])) .await .unwrap(); diff --git a/torn-api/src/lib.rs b/torn-api/src/lib.rs index 9521e76..decef82 100644 --- a/torn-api/src/lib.rs +++ b/torn-api/src/lib.rs @@ -1,19 +1,22 @@ #![warn(clippy::all, clippy::perf, clippy::style, clippy::suspicious)] pub mod faction; +pub mod local; +pub mod send; +pub mod torn; pub mod user; #[cfg(feature = "awc")] -pub mod awc; +mod awc; #[cfg(feature = "reqwest")] -pub mod reqwest; +mod reqwest; mod de_util; -use async_trait::async_trait; +use std::fmt::Write; + use chrono::{DateTime, Utc}; -use num_traits::{AsPrimitive, PrimInt}; use serde::de::{DeserializeOwned, Error as DeError}; use thiserror::Error; @@ -54,20 +57,27 @@ impl ApiResponse { where D: DeserializeOwned, { - serde_json::from_value(self.value.clone()) + D::deserialize(&self.value) } fn decode_field(&self, field: &'static str) -> serde_json::Result where D: DeserializeOwned, { - let value = self - .value + self.value .get(field) - .ok_or_else(|| serde_json::Error::missing_field(field))? - .clone(); + .ok_or_else(|| serde_json::Error::missing_field(field)) + .and_then(D::deserialize) + } - serde_json::from_value(value) + fn decode_field_with<'de, V, F>(&'de self, field: &'static str, fun: F) -> serde_json::Result + where + F: FnOnce(&'de serde_json::Value) -> serde_json::Result, + { + self.value + .get(field) + .ok_or_else(|| serde_json::Error::missing_field(field)) + .and_then(fun) } } @@ -83,138 +93,6 @@ pub trait ApiCategoryResponse: Send + Sync { fn from_response(response: ApiResponse) -> Self; } -#[async_trait] -pub trait ThreadSafeApiClient: Send + Sync { - type Error: std::error::Error + Sync + Send; - - async fn request(&self, url: String) -> Result; - - fn torn_api(&self, key: S) -> ThreadSafeApiProvider> - where - Self: Sized, - S: ToString, - { - ThreadSafeApiProvider::new(self, DirectExecutor::new(key.to_string())) - } -} - -#[async_trait(?Send)] -pub trait ApiClient { - type Error: std::error::Error; - - async fn request(&self, url: String) -> Result; - - fn torn_api(&self, key: S) -> ApiProvider> - where - Self: Sized, - S: ToString, - { - ApiProvider::new(self, DirectExecutor::new(key.to_string())) - } -} - -#[async_trait(?Send)] -pub trait RequestExecutor -where - C: ApiClient, -{ - type Error: std::error::Error; - - async fn execute(&self, client: &C, request: ApiRequest) -> Result - where - A: ApiCategoryResponse; -} - -#[async_trait] -pub trait ThreadSafeRequestExecutor -where - C: ThreadSafeApiClient, -{ - type Error: std::error::Error + Send + Sync; - - async fn execute(&self, client: &C, request: ApiRequest) -> Result - where - A: ApiCategoryResponse; -} - -pub struct ApiProvider<'a, C, E> -where - C: ApiClient, - E: RequestExecutor, -{ - client: &'a C, - executor: E, -} - -impl<'a, C, E> ApiProvider<'a, C, E> -where - C: ApiClient, - E: RequestExecutor, -{ - pub fn new(client: &'a C, executor: E) -> ApiProvider<'a, C, E> { - Self { client, executor } - } - - pub async fn user(&self, build: F) -> Result - where - F: FnOnce(ApiRequestBuilder) -> ApiRequestBuilder, - { - let mut builder = ApiRequestBuilder::::new(); - builder = build(builder); - - self.executor.execute(self.client, builder.request).await - } - - pub async fn faction(&self, build: F) -> Result - where - F: FnOnce(ApiRequestBuilder) -> ApiRequestBuilder, - { - let mut builder = ApiRequestBuilder::::new(); - builder = build(builder); - - self.executor.execute(self.client, builder.request).await - } -} - -pub struct ThreadSafeApiProvider<'a, C, E> -where - C: ThreadSafeApiClient, - E: ThreadSafeRequestExecutor, -{ - client: &'a C, - executor: E, -} - -impl<'a, C, E> ThreadSafeApiProvider<'a, C, E> -where - C: ThreadSafeApiClient, - E: ThreadSafeRequestExecutor, -{ - pub fn new(client: &'a C, executor: E) -> ThreadSafeApiProvider<'a, C, E> { - Self { client, executor } - } - - pub async fn user(&self, build: F) -> Result - where - F: FnOnce(ApiRequestBuilder) -> ApiRequestBuilder, - { - let mut builder = ApiRequestBuilder::::new(); - builder = build(builder); - - self.executor.execute(self.client, builder.request).await - } - - pub async fn faction(&self, build: F) -> Result - where - F: FnOnce(ApiRequestBuilder) -> ApiRequestBuilder, - { - let mut builder = ApiRequestBuilder::::new(); - builder = build(builder); - - self.executor.execute(self.client, builder.request).await - } -} - pub struct DirectExecutor { key: String, _marker: std::marker::PhantomData, @@ -224,7 +102,7 @@ impl DirectExecutor { fn new(key: String) -> Self { Self { key, - _marker: std::marker::PhantomData, + _marker: Default::default(), } } } @@ -241,51 +119,12 @@ where Response(#[from] ResponseError), } -#[async_trait(?Send)] -impl RequestExecutor for DirectExecutor -where - C: ApiClient, -{ - type Error = ApiClientError; - - async fn execute(&self, client: &C, request: ApiRequest) -> Result - where - A: ApiCategoryResponse, - { - let url = request.url(&self.key); - - let value = client.request(url).await.map_err(ApiClientError::Client)?; - - Ok(A::from_response(ApiResponse::from_value(value)?)) - } -} - -#[async_trait] -impl ThreadSafeRequestExecutor for DirectExecutor -where - C: ThreadSafeApiClient, -{ - type Error = ApiClientError; - - async fn execute(&self, client: &C, request: ApiRequest) -> Result - where - A: ApiCategoryResponse, - { - let url = request.url(&self.key); - - let value = client.request(url).await.map_err(ApiClientError::Client)?; - - Ok(A::from_response(ApiResponse::from_value(value)?)) - } -} - #[derive(Debug)] pub struct ApiRequest where A: ApiCategoryResponse, { selections: Vec<&'static str>, - id: Option, from: Option>, to: Option>, comment: Option, @@ -299,7 +138,6 @@ where fn default() -> Self { Self { selections: Vec::default(), - id: None, from: None, to: None, comment: None, @@ -312,37 +150,28 @@ impl ApiRequest where A: ApiCategoryResponse, { - pub fn url(&self, key: &str) -> String { - let mut query_fragments = vec![ - format!("selections={}", self.selections.join(",")), - format!("key={}", key), - ]; + pub fn url(&self, key: &str, id: Option) -> String { + let mut url = format!("https://api.torn.com/{}/", A::Selection::category()); + + if let Some(id) = id { + write!(url, "{}", id).unwrap(); + } + + write!(url, "?selections={}&key={}", self.selections.join(","), key).unwrap(); if let Some(from) = self.from { - query_fragments.push(format!("from={}", from.timestamp())); + write!(url, "&from={}", from.timestamp()).unwrap(); } if let Some(to) = self.to { - query_fragments.push(format!("to={}", to.timestamp())); + write!(url, "&to={}", to.timestamp()).unwrap(); } if let Some(comment) = &self.comment { - query_fragments.push(format!("comment={}", comment)); + write!(url, "&comment={}", comment).unwrap(); } - let query = query_fragments.join("&"); - - let id_fragment = match self.id { - Some(id) => id.to_string(), - None => "".to_owned(), - }; - - format!( - "https://api.torn.com/{}/{}?{}", - A::Selection::category(), - id_fragment, - query - ) + url } } @@ -363,15 +192,6 @@ where } } - #[must_use] - pub fn id(mut self, id: I) -> Self - where - I: PrimInt + AsPrimitive, - { - self.request.id = Some(id.as_()); - self - } - #[must_use] pub fn selections(mut self, selections: &[A::Selection]) -> Self { self.request @@ -399,8 +219,6 @@ where } } -pub mod prelude {} - #[cfg(test)] pub(crate) mod tests { use std::sync::Once; @@ -411,9 +229,9 @@ pub(crate) mod tests { pub use ::reqwest::Client; #[cfg(all(not(feature = "reqwest"), feature = "awc"))] - pub use crate::ApiClient as ClientTrait; + pub use crate::local::ApiClient as ClientTrait; #[cfg(feature = "reqwest")] - pub use crate::ThreadSafeApiClient as ClientTrait; + pub use crate::send::ApiClient as ClientTrait; #[cfg(all(not(feature = "reqwest"), feature = "awc"))] pub use actix_rt::test as async_test; @@ -441,7 +259,11 @@ pub(crate) mod tests { async fn reqwest() { let key = setup(); - Client::default().torn_api(key).user(|b| b).await.unwrap(); + Client::default() + .torn_api(key) + .user(None, |b| b) + .await + .unwrap(); } #[cfg(feature = "awc")] @@ -449,6 +271,10 @@ pub(crate) mod tests { async fn awc() { let key = setup(); - Client::default().torn_api(key).user(|b| b).await.unwrap(); + Client::default() + .torn_api(key) + .user(None, |b| b) + .await + .unwrap(); } } diff --git a/torn-api/src/local.rs b/torn-api/src/local.rs new file mode 100644 index 0000000..3bcdf52 --- /dev/null +++ b/torn-api/src/local.rs @@ -0,0 +1,235 @@ +use std::collections::HashMap; + +use async_trait::async_trait; + +use crate::{ + faction, torn, user, ApiCategoryResponse, ApiClientError, ApiRequest, ApiRequestBuilder, + ApiResponse, DirectExecutor, +}; + +pub struct ApiProvider<'a, C, E, I = i32> +where + C: ApiClient, + E: RequestExecutor, + I: num_traits::AsPrimitive, +{ + client: &'a C, + executor: E, + _marker: std::marker::PhantomData, +} + +impl<'a, C, E, I> ApiProvider<'a, C, E, I> +where + C: ApiClient, + E: RequestExecutor, + I: num_traits::AsPrimitive + std::hash::Hash + std::cmp::Eq, + i64: num_traits::AsPrimitive, +{ + pub fn new(client: &'a C, executor: E) -> ApiProvider<'a, C, E, I> { + Self { + client, + executor, + _marker: Default::default(), + } + } + + pub async fn user(&self, id: Option, build: F) -> Result + where + F: FnOnce(ApiRequestBuilder) -> ApiRequestBuilder, + { + let mut builder = ApiRequestBuilder::new(); + builder = build(builder); + + self.executor + .execute(self.client, builder.request, id.map(|i| i.as_())) + .await + } + + pub async fn users( + &self, + ids: L, + build: F, + ) -> HashMap> + where + F: FnOnce(ApiRequestBuilder) -> ApiRequestBuilder, + L: IntoIterator, + { + let mut builder = ApiRequestBuilder::new(); + builder = build(builder); + + self.executor + .execute_many( + self.client, + builder.request, + ids.into_iter().map(|i| i.as_()).collect(), + ) + .await + .into_iter() + .map(|(i, r)| (num_traits::AsPrimitive::as_(i), r)) + .collect() + } + + pub async fn faction(&self, id: Option, build: F) -> Result + where + F: FnOnce(ApiRequestBuilder) -> ApiRequestBuilder, + { + let mut builder = ApiRequestBuilder::new(); + builder = build(builder); + + self.executor + .execute(self.client, builder.request, id.map(|i| i.as_())) + .await + } + + pub async fn factions( + &self, + ids: L, + build: F, + ) -> HashMap> + where + F: FnOnce(ApiRequestBuilder) -> ApiRequestBuilder, + L: IntoIterator, + { + let mut builder = ApiRequestBuilder::new(); + builder = build(builder); + + self.executor + .execute_many( + self.client, + builder.request, + ids.into_iter().map(|i| i.as_()).collect(), + ) + .await + .into_iter() + .map(|(i, r)| (num_traits::AsPrimitive::as_(i), r)) + .collect() + } + + pub async fn torn(&self, id: Option, build: F) -> Result + where + F: FnOnce(ApiRequestBuilder) -> ApiRequestBuilder, + { + let mut builder = ApiRequestBuilder::new(); + builder = build(builder); + + self.executor + .execute(self.client, builder.request, id.map(|i| i.as_())) + .await + } + + pub async fn torns( + &self, + ids: L, + build: F, + ) -> HashMap> + where + F: FnOnce(ApiRequestBuilder) -> ApiRequestBuilder, + L: IntoIterator, + { + let mut builder = ApiRequestBuilder::new(); + builder = build(builder); + + self.executor + .execute_many( + self.client, + builder.request, + ids.into_iter().map(|i| i.as_()).collect(), + ) + .await + .into_iter() + .map(|(i, r)| (num_traits::AsPrimitive::as_(i), r)) + .collect() + } +} + +#[async_trait(?Send)] +pub trait RequestExecutor +where + C: ApiClient, +{ + type Error: std::error::Error; + + async fn execute( + &self, + client: &C, + request: ApiRequest, + id: Option, + ) -> Result + where + A: ApiCategoryResponse; + + async fn execute_many( + &self, + client: &C, + request: ApiRequest, + ids: Vec, + ) -> HashMap> + where + A: ApiCategoryResponse; +} + +#[async_trait(?Send)] +impl RequestExecutor for DirectExecutor +where + C: ApiClient, +{ + type Error = ApiClientError; + + async fn execute( + &self, + client: &C, + request: ApiRequest, + id: Option, + ) -> Result + where + A: ApiCategoryResponse, + { + let url = request.url(&self.key, id); + + let value = client.request(url).await.map_err(ApiClientError::Client)?; + + Ok(A::from_response(ApiResponse::from_value(value)?)) + } + + async fn execute_many( + &self, + client: &C, + request: ApiRequest, + ids: Vec, + ) -> HashMap> + where + A: ApiCategoryResponse, + { + let request_ref = &request; + futures::future::join_all(ids.into_iter().map(|i| async move { + let url = request_ref.url(&self.key, Some(i)); + + let value = client.request(url).await.map_err(ApiClientError::Client); + + ( + i, + value + .and_then(|v| ApiResponse::from_value(v).map_err(Into::into)) + .map(A::from_response), + ) + })) + .await + .into_iter() + .collect() + } +} + +#[async_trait(?Send)] +pub trait ApiClient { + type Error: std::error::Error; + + async fn request(&self, url: String) -> Result; + + fn torn_api(&self, key: S) -> ApiProvider> + where + Self: Sized, + S: ToString, + { + ApiProvider::new(self, DirectExecutor::new(key.to_string())) + } +} diff --git a/torn-api/src/reqwest.rs b/torn-api/src/reqwest.rs index ff84edd..99a64c3 100644 --- a/torn-api/src/reqwest.rs +++ b/torn-api/src/reqwest.rs @@ -1,9 +1,9 @@ use async_trait::async_trait; -use crate::ThreadSafeApiClient; +use crate::send::ApiClient; #[async_trait] -impl ThreadSafeApiClient for reqwest::Client { +impl ApiClient for reqwest::Client { type Error = reqwest::Error; async fn request(&self, url: String) -> Result { diff --git a/torn-api/src/send.rs b/torn-api/src/send.rs new file mode 100644 index 0000000..49e57f0 --- /dev/null +++ b/torn-api/src/send.rs @@ -0,0 +1,235 @@ +use std::collections::HashMap; + +use async_trait::async_trait; + +use crate::{ + faction, torn, user, ApiCategoryResponse, ApiClientError, ApiRequest, ApiRequestBuilder, + ApiResponse, DirectExecutor, +}; + +pub struct ApiProvider<'a, C, E, I = i32> +where + C: ApiClient, + E: RequestExecutor, + I: num_traits::AsPrimitive, +{ + client: &'a C, + executor: E, + _marker: std::marker::PhantomData, +} + +impl<'a, C, E, I> ApiProvider<'a, C, E, I> +where + C: ApiClient, + E: RequestExecutor, + I: num_traits::AsPrimitive + std::hash::Hash + std::cmp::Eq, + i64: num_traits::AsPrimitive, +{ + pub fn new(client: &'a C, executor: E) -> ApiProvider<'a, C, E, I> { + Self { + client, + executor, + _marker: Default::default(), + } + } + + pub async fn user(&self, id: Option, build: F) -> Result + where + F: FnOnce(ApiRequestBuilder) -> ApiRequestBuilder, + { + let mut builder = ApiRequestBuilder::new(); + builder = build(builder); + + self.executor + .execute(self.client, builder.request, id.map(|i| i.as_())) + .await + } + + pub async fn users( + &self, + ids: L, + build: F, + ) -> HashMap> + where + F: FnOnce(ApiRequestBuilder) -> ApiRequestBuilder, + L: IntoIterator, + { + let mut builder = ApiRequestBuilder::new(); + builder = build(builder); + + self.executor + .execute_many( + self.client, + builder.request, + ids.into_iter().map(|i| i.as_()).collect(), + ) + .await + .into_iter() + .map(|(i, r)| (num_traits::AsPrimitive::as_(i), r)) + .collect() + } + + pub async fn faction(&self, id: Option, build: F) -> Result + where + F: FnOnce(ApiRequestBuilder) -> ApiRequestBuilder, + { + let mut builder = ApiRequestBuilder::new(); + builder = build(builder); + + self.executor + .execute(self.client, builder.request, id.map(|i| i.as_())) + .await + } + + pub async fn factions( + &self, + ids: L, + build: F, + ) -> HashMap> + where + F: FnOnce(ApiRequestBuilder) -> ApiRequestBuilder, + L: IntoIterator, + { + let mut builder = ApiRequestBuilder::new(); + builder = build(builder); + + self.executor + .execute_many( + self.client, + builder.request, + ids.into_iter().map(|i| i.as_()).collect(), + ) + .await + .into_iter() + .map(|(i, r)| (num_traits::AsPrimitive::as_(i), r)) + .collect() + } + + pub async fn torn(&self, id: Option, build: F) -> Result + where + F: FnOnce(ApiRequestBuilder) -> ApiRequestBuilder, + { + let mut builder = ApiRequestBuilder::new(); + builder = build(builder); + + self.executor + .execute(self.client, builder.request, id.map(|i| i.as_())) + .await + } + + pub async fn torns( + &self, + ids: L, + build: F, + ) -> HashMap> + where + F: FnOnce(ApiRequestBuilder) -> ApiRequestBuilder, + L: IntoIterator, + { + let mut builder = ApiRequestBuilder::new(); + builder = build(builder); + + self.executor + .execute_many( + self.client, + builder.request, + ids.into_iter().map(|i| i.as_()).collect(), + ) + .await + .into_iter() + .map(|(i, r)| (num_traits::AsPrimitive::as_(i), r)) + .collect() + } +} + +#[async_trait] +pub trait RequestExecutor +where + C: ApiClient, +{ + type Error: std::error::Error + Send + Sync; + + async fn execute( + &self, + client: &C, + request: ApiRequest, + id: Option, + ) -> Result + where + A: ApiCategoryResponse; + + async fn execute_many( + &self, + client: &C, + request: ApiRequest, + ids: Vec, + ) -> HashMap> + where + A: ApiCategoryResponse; +} + +#[async_trait] +impl RequestExecutor for DirectExecutor +where + C: ApiClient, +{ + type Error = ApiClientError; + + async fn execute( + &self, + client: &C, + request: ApiRequest, + id: Option, + ) -> Result + where + A: ApiCategoryResponse, + { + let url = request.url(&self.key, id); + + let value = client.request(url).await.map_err(ApiClientError::Client)?; + + Ok(A::from_response(ApiResponse::from_value(value)?)) + } + + async fn execute_many( + &self, + client: &C, + request: ApiRequest, + ids: Vec, + ) -> HashMap> + where + A: ApiCategoryResponse, + { + let request_ref = &request; + futures::future::join_all(ids.into_iter().map(|i| async move { + let url = request_ref.url(&self.key, Some(i)); + + let value = client.request(url).await.map_err(ApiClientError::Client); + + ( + i, + value + .and_then(|v| ApiResponse::from_value(v).map_err(Into::into)) + .map(A::from_response), + ) + })) + .await + .into_iter() + .collect() + } +} + +#[async_trait] +pub trait ApiClient: Send + Sync { + type Error: std::error::Error + Sync + Send; + + async fn request(&self, url: String) -> Result; + + fn torn_api(&self, key: S) -> ApiProvider> + where + Self: Sized, + S: ToString, + { + ApiProvider::new(self, DirectExecutor::new(key.to_string())) + } +} diff --git a/torn-api/src/torn.rs b/torn-api/src/torn.rs new file mode 100644 index 0000000..7c43202 --- /dev/null +++ b/torn-api/src/torn.rs @@ -0,0 +1,100 @@ +use serde::{ + de::{self, MapAccess, Visitor}, + Deserialize, +}; + +use torn_api_macros::ApiCategory; + +use crate::user; + +#[derive(Debug, Clone, Copy, ApiCategory)] +#[api(category = "torn")] +pub enum Selection { + #[api( + field = "competition", + with = "decode_competition", + type = "Option" + )] + Competition, +} + +#[derive(Deserialize)] +pub struct EliminationLeaderboard { + pub position: i16, + pub team: user::EliminationTeam, + pub score: i16, + pub lives: i16, + pub participants: i16, + pub wins: i32, + pub losses: i32, +} + +pub enum Competition { + Elimination { teams: Vec }, +} + +fn decode_competition<'de, D>(deserializer: D) -> Result, D::Error> +where + D: serde::Deserializer<'de>, +{ + struct CompetitionVisitor; + + impl<'de> Visitor<'de> for CompetitionVisitor { + type Value = Option; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("struct Competition") + } + + fn visit_map(self, mut map: V) -> Result + where + V: MapAccess<'de>, + { + let mut name = None; + let mut teams = None; + + while let Some(key) = map.next_key()? { + match key { + "name" => { + name = Some(map.next_value()?); + } + "teams" => { + teams = Some(map.next_value()?); + } + _ => (), + }; + } + + let name = name.ok_or_else(|| de::Error::missing_field("name"))?; + + match name { + "Elimination" => Ok(Some(Competition::Elimination { + teams: teams.ok_or_else(|| de::Error::missing_field("teams"))?, + })), + "" => Ok(None), + v => Err(de::Error::unknown_variant(v, &["Elimination", ""])), + } + } + } + + deserializer.deserialize_map(CompetitionVisitor) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::tests::{async_test, setup, Client, ClientTrait}; + + #[async_test] + async fn competition() { + let key = setup(); + + let response = Client::default() + .torn_api(key) + .torn(None, |b| b.selections(&[Selection::Competition])) + .await + .unwrap(); + + response.competition().unwrap(); + } +} diff --git a/torn-api/src/user.rs b/torn-api/src/user.rs index 3f2b840..0c52540 100644 --- a/torn-api/src/user.rs +++ b/torn-api/src/user.rs @@ -195,7 +195,7 @@ pub enum EliminationTeam { DirtyCops, LaughingStock, JeanTherapy, - #[serde(rename = "statants-soldiers")] + #[serde(rename = "satants-soldiers")] SatansSoldiers, WolfPack, Sleepyheads, @@ -399,7 +399,7 @@ mod tests { let response = Client::default() .torn_api(key) - .user(|b| { + .user(None, |b| { b.selections(&[ Selection::Basic, Selection::Discord, @@ -424,7 +424,7 @@ mod tests { let response = Client::default() .torn_api(key) - .user(|b| b.id(28).selections(&[Selection::Profile])) + .user(Some(28), |b| b.selections(&[Selection::Profile])) .await .unwrap(); @@ -439,11 +439,24 @@ mod tests { let response = Client::default() .torn_api(key) - .user(|b| b.selections(&[Selection::Profile])) + .user(None, |b| b.selections(&[Selection::Profile])) .await .unwrap(); let profile = response.profile().unwrap(); assert!(profile.competition.is_some()); } + + #[async_test] + async fn bulk() { + let key = setup(); + + let response = Client::default() + .torn_api(key) + .users([1, 2111649], |b| b.selections(&[Selection::Basic])) + .await; + + response.get(&1).as_ref().unwrap().as_ref().unwrap(); + response.get(&2111649).as_ref().unwrap().as_ref().unwrap(); + } } diff --git a/torn-key-pool/Cargo.toml b/torn-key-pool/Cargo.toml index 462f966..6129dff 100644 --- a/torn-key-pool/Cargo.toml +++ b/torn-key-pool/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "torn-key-pool" -version = "0.3.1" +version = "0.4.0" edition = "2021" license = "MIT" repository = "https://github.com/TotallyNot/torn-api.rs.git" @@ -18,7 +18,7 @@ tokio-runtime = [ "dep:tokio", "dep:rand" ] actix-runtime = [ "dep:actix-rt", "dep:rand" ] [dependencies] -torn-api = { path = "../torn-api", default-features = false, version = "0.4" } +torn-api = { path = "../torn-api", default-features = false, version = "0.5" } async-trait = "0.1" thiserror = "1" @@ -28,6 +28,7 @@ indoc = { version = "1", optional = true } tokio = { version = "1", optional = true, default-features = false, features = ["time"] } actix-rt = { version = "2", optional = true, default-features = false } rand = { version = "0.8", optional = true } +futures = "0.3" reqwest = { version = "0.11", default-features = false, features = [ "json" ], optional = true } awc = { version = "3", default-features = false, optional = true } @@ -40,4 +41,3 @@ tokio = { version = "1.20.1", features = ["test-util", "rt", "macros"] } tokio-test = "0.4.2" reqwest = { version = "0.11", default-features = true } awc = { version = "3", features = [ "rustls" ] } -futures = "0.3.24" diff --git a/torn-key-pool/src/lib.rs b/torn-key-pool/src/lib.rs index 8eab607..fff2ec2 100644 --- a/torn-key-pool/src/lib.rs +++ b/torn-key-pool/src/lib.rs @@ -3,13 +3,15 @@ #[cfg(feature = "postgres")] pub mod postgres; +pub mod local; +pub mod send; + +use std::sync::Arc; + use async_trait::async_trait; use thiserror::Error; -use torn_api::{ - ApiCategoryResponse, ApiClient, ApiProvider, ApiRequest, ApiResponse, RequestExecutor, - ResponseError, ThreadSafeApiClient, ThreadSafeApiProvider, ThreadSafeRequestExecutor, -}; +use torn_api::ResponseError; #[derive(Debug, Error)] pub enum KeyPoolError @@ -18,7 +20,7 @@ where C: std::error::Error, { #[error("Key pool storage driver error: {0:?}")] - Storage(#[source] S), + Storage(#[source] Arc), #[error(transparent)] Client(#[from] C), @@ -45,6 +47,12 @@ pub trait KeyPoolStorage { async fn acquire_key(&self, domain: KeyDomain) -> Result; + async fn acquire_many_keys( + &self, + domain: KeyDomain, + number: i64, + ) -> Result, Self::Error>; + async fn flag_key(&self, key: Self::Key, code: u8) -> Result; } @@ -70,161 +78,3 @@ where } } } - -#[async_trait(?Send)] -impl<'client, C, S> RequestExecutor for KeyPoolExecutor<'client, C, S> -where - C: ApiClient, - S: KeyPoolStorage + 'static, -{ - type Error = KeyPoolError; - - async fn execute(&self, client: &C, request: ApiRequest) -> Result - where - A: ApiCategoryResponse, - { - loop { - let key = self - .storage - .acquire_key(self.domain) - .await - .map_err(KeyPoolError::Storage)?; - let url = request.url(key.value()); - let value = client.request(url).await?; - - match ApiResponse::from_value(value) { - Err(ResponseError::Api { code, reason }) => { - if !self - .storage - .flag_key(key, code) - .await - .map_err(KeyPoolError::Storage)? - { - return Err(KeyPoolError::Response(ResponseError::Api { code, reason })); - } - } - Err(parsing_error) => return Err(KeyPoolError::Response(parsing_error)), - Ok(res) => return Ok(A::from_response(res)), - }; - } - } -} - -#[async_trait] -impl<'client, C, S> ThreadSafeRequestExecutor for KeyPoolExecutor<'client, C, S> -where - C: ThreadSafeApiClient, - S: KeyPoolStorage + Send + Sync + 'static, -{ - type Error = KeyPoolError; - - async fn execute(&self, client: &C, request: ApiRequest) -> Result - where - A: ApiCategoryResponse, - { - loop { - let key = self - .storage - .acquire_key(self.domain) - .await - .map_err(KeyPoolError::Storage)?; - let url = request.url(key.value()); - let value = client.request(url).await?; - - match ApiResponse::from_value(value) { - Err(ResponseError::Api { code, reason }) => { - if !self - .storage - .flag_key(key, code) - .await - .map_err(KeyPoolError::Storage)? - { - return Err(KeyPoolError::Response(ResponseError::Api { code, reason })); - } - } - Err(parsing_error) => return Err(KeyPoolError::Response(parsing_error)), - Ok(res) => return Ok(A::from_response(res)), - }; - } - } -} - -#[derive(Clone, Debug)] -pub struct KeyPool -where - C: ApiClient, - S: KeyPoolStorage, -{ - client: C, - storage: S, -} - -impl KeyPool -where - C: ApiClient, - S: KeyPoolStorage + 'static, -{ - pub fn new(client: C, storage: S) -> Self { - Self { client, storage } - } - - pub fn torn_api(&self, domain: KeyDomain) -> ApiProvider> { - ApiProvider::new(&self.client, KeyPoolExecutor::new(&self.storage, domain)) - } -} - -#[derive(Clone, Debug)] -pub struct ThreadSafeKeyPool -where - C: ThreadSafeApiClient, - S: KeyPoolStorage + Send + Sync + 'static, -{ - client: C, - storage: S, -} - -impl ThreadSafeKeyPool -where - C: ThreadSafeApiClient, - S: KeyPoolStorage + Send + Sync + 'static, -{ - pub fn new(client: C, storage: S) -> Self { - Self { client, storage } - } - - pub fn torn_api(&self, domain: KeyDomain) -> ThreadSafeApiProvider> { - ThreadSafeApiProvider::new(&self.client, KeyPoolExecutor::new(&self.storage, domain)) - } -} - -pub trait WithStorage { - fn with_storage<'a, S>( - &'a self, - storage: &'a S, - domain: KeyDomain, - ) -> ApiProvider> - where - Self: ApiClient + Sized, - S: KeyPoolStorage + 'static, - { - ApiProvider::new(self, KeyPoolExecutor::new(storage, domain)) - } - - fn with_storage_sync<'a, S>( - &'a self, - storage: &'a S, - domain: KeyDomain, - ) -> ThreadSafeApiProvider> - where - Self: ThreadSafeApiClient + Sized, - S: KeyPoolStorage + Send + Sync + 'static, - { - ThreadSafeApiProvider::new(self, KeyPoolExecutor::new(storage, domain)) - } -} - -#[cfg(feature = "reqwest")] -impl WithStorage for reqwest::Client {} - -#[cfg(feature = "awc")] -impl WithStorage for awc::Client {} diff --git a/torn-key-pool/src/local.rs b/torn-key-pool/src/local.rs new file mode 100644 index 0000000..a6e0173 --- /dev/null +++ b/torn-key-pool/src/local.rs @@ -0,0 +1,161 @@ +use std::{collections::HashMap, sync::Arc}; + +use async_trait::async_trait; + +use torn_api::{ + local::{ApiClient, ApiProvider, RequestExecutor}, + ApiCategoryResponse, ApiRequest, ApiResponse, ResponseError, +}; + +use crate::{ApiKey, KeyDomain, KeyPoolError, KeyPoolExecutor, KeyPoolStorage}; + +#[async_trait(?Send)] +impl<'client, C, S> RequestExecutor for KeyPoolExecutor<'client, C, S> +where + C: ApiClient, + S: KeyPoolStorage + 'static, +{ + type Error = KeyPoolError; + + async fn execute( + &self, + client: &C, + request: ApiRequest, + id: Option, + ) -> Result + where + A: ApiCategoryResponse, + { + loop { + let key = self + .storage + .acquire_key(self.domain) + .await + .map_err(|e| KeyPoolError::Storage(Arc::new(e)))?; + let url = request.url(key.value(), id); + let value = client.request(url).await?; + + match ApiResponse::from_value(value) { + Err(ResponseError::Api { code, reason }) => { + if !self + .storage + .flag_key(key, code) + .await + .map_err(Arc::new) + .map_err(KeyPoolError::Storage)? + { + return Err(KeyPoolError::Response(ResponseError::Api { code, reason })); + } + } + Err(parsing_error) => return Err(KeyPoolError::Response(parsing_error)), + Ok(res) => return Ok(A::from_response(res)), + }; + } + } + + async fn execute_many( + &self, + client: &C, + request: ApiRequest, + ids: Vec, + ) -> HashMap> + where + A: ApiCategoryResponse, + { + let keys = match self + .storage + .acquire_many_keys(self.domain, ids.len() as i64) + .await + { + Ok(keys) => keys, + Err(why) => { + let shared = Arc::new(why); + return ids + .into_iter() + .map(|i| (i, Err(Self::Error::Storage(shared.clone())))) + .collect(); + } + }; + + let request_ref = &request; + + futures::future::join_all(std::iter::zip(ids, keys).map(|(id, mut key)| async move { + loop { + let url = request_ref.url(key.value(), Some(id)); + let value = match client.request(url).await { + Ok(v) => v, + Err(why) => return (id, Err(Self::Error::Client(why))), + }; + + match ApiResponse::from_value(value) { + Err(ResponseError::Api { code, reason }) => { + match self.storage.flag_key(key, code).await { + Ok(false) => { + return ( + id, + Err(KeyPoolError::Response(ResponseError::Api { + code, + reason, + })), + ) + } + Ok(true) => (), + Err(why) => return (id, Err(KeyPoolError::Storage(Arc::new(why)))), + } + } + Err(parsing_error) => return (id, Err(KeyPoolError::Response(parsing_error))), + Ok(res) => return (id, Ok(A::from_response(res))), + }; + + key = match self.storage.acquire_key(self.domain).await { + Ok(k) => k, + Err(why) => return (id, Err(Self::Error::Storage(Arc::new(why)))), + }; + } + })) + .await + .into_iter() + .collect() + } +} + +#[derive(Clone, Debug)] +pub struct KeyPool +where + C: ApiClient, + S: KeyPoolStorage, +{ + client: C, + storage: S, +} + +impl KeyPool +where + C: ApiClient, + S: KeyPoolStorage + 'static, +{ + pub fn new(client: C, storage: S) -> Self { + Self { client, storage } + } + + pub fn torn_api(&self, domain: KeyDomain) -> ApiProvider> { + ApiProvider::new(&self.client, KeyPoolExecutor::new(&self.storage, domain)) + } +} + +pub trait WithStorage { + fn with_storage<'a, S>( + &'a self, + storage: &'a S, + domain: KeyDomain, + ) -> ApiProvider> + where + Self: ApiClient + Sized, + S: KeyPoolStorage + 'static, + { + ApiProvider::new(self, KeyPoolExecutor::new(storage, domain)) + } +} + +#[cfg(feature = "awc")] +impl WithStorage for awc::Client {} diff --git a/torn-key-pool/src/postgres.rs b/torn-key-pool/src/postgres.rs index 36d80d5..e644c5d 100644 --- a/torn-key-pool/src/postgres.rs +++ b/torn-key-pool/src/postgres.rs @@ -4,7 +4,7 @@ use indoc::indoc; use sqlx::{FromRow, PgPool}; use thiserror::Error; -use crate::{ApiKey, KeyDomain, KeyPool, KeyPoolStorage}; +use crate::{ApiKey, KeyDomain, KeyPoolStorage}; #[derive(Debug, Error)] pub enum PgStorageError { @@ -102,18 +102,12 @@ impl KeyPoolStorage for PgKeyPoolStorage { with key as ( select id, - user_id, - faction_id, - key, case when extract(minute from last_used)=extract(minute from now()) then uses else 0::smallint - end as uses, - user, - faction, - last_used + end as uses from api_keys {} - order by last_used asc limit 1 FOR UPDATE + order by last_used asc limit 1 ) update api_keys set uses = key.uses + 1, @@ -162,6 +156,70 @@ impl KeyPoolStorage for PgKeyPoolStorage { } } + async fn acquire_many_keys( + &self, + domain: KeyDomain, + number: i64, + ) -> Result, Self::Error> { + let predicate = match domain { + KeyDomain::Public => "".to_owned(), + KeyDomain::User(id) => format!("where and user_id={} and user", id), + KeyDomain::Faction(id) => format!("where and faction_id={} and faction", id), + }; + + let mut tx = self.pool.begin().await?; + + let mut keys: Vec = sqlx::query_as(&indoc::formatdoc!( + r#" + select + id, + user_id, + faction_id, + key, + case + when extract(minute from last_used)=extract(minute from now()) then uses + else 0::smallint + end as uses, + "user", + faction, + last_used + from api_keys {} order by last_used limit $1 for update + "#, + predicate + )) + .bind(number) + .fetch_all(&mut tx) + .await?; + + let mut result = Vec::with_capacity(number as usize); + 'outer: for _ in 0..(((number as usize) / keys.len()) + 1) { + for key in &mut keys { + if key.uses == self.limit || result.len() == (number as usize) { + break 'outer; + } else { + key.uses += 1; + result.push(key.clone()); + } + } + } + + sqlx::query(indoc! {r#" + update api_keys set + uses = tmp.uses, + last_used = now() + from (select unnest($1::int4[]) as id, unnest($2::int2[]) as uses) as tmp + where api_keys.id = tmp.id + "#}) + .bind(keys.iter().map(|k| k.id).collect::>()) + .bind(keys.iter().map(|k| k.uses).collect::>()) + .execute(&mut tx) + .await?; + + tx.commit().await?; + + Ok(result) + } + async fn flag_key(&self, key: Self::Key, code: u8) -> Result { // TODO: put keys in cooldown when appropriate match code { @@ -177,27 +235,6 @@ impl KeyPoolStorage for PgKeyPoolStorage { } } -pub type PgKeyPool = KeyPool; - -impl PgKeyPool -where - A: torn_api::ApiClient, -{ - pub async fn connect( - client: A, - database_url: &str, - limit: i16, - ) -> Result { - let db_pool = PgPool::connect(database_url).await?; - let storage = PgKeyPoolStorage::new(db_pool, limit); - storage.initialise().await?; - - let key_pool = Self::new(client, storage); - - Ok(key_pool) - } -} - #[cfg(test)] mod test { use std::sync::{Arc, Once}; @@ -253,13 +290,12 @@ mod test { .unwrap() .get("uses"); - let futures = (0..30).into_iter().map(|_| { - let storage = storage.clone(); - async move { - storage.acquire_key(KeyDomain::Public).await.unwrap(); - } - }); - futures::future::join_all(futures).await; + let keys = storage + .acquire_many_keys(KeyDomain::Public, 30) + .await + .unwrap(); + + assert_eq!(keys.len(), 30); let after: i16 = sqlx::query("select uses from api_keys") .fetch_one(&storage.pool) diff --git a/torn-key-pool/src/send.rs b/torn-key-pool/src/send.rs new file mode 100644 index 0000000..9581f6b --- /dev/null +++ b/torn-key-pool/src/send.rs @@ -0,0 +1,161 @@ +use std::{collections::HashMap, sync::Arc}; + +use async_trait::async_trait; + +use torn_api::{ + send::{ApiClient, ApiProvider, RequestExecutor}, + ApiCategoryResponse, ApiRequest, ApiResponse, ResponseError, +}; + +use crate::{ApiKey, KeyDomain, KeyPoolError, KeyPoolExecutor, KeyPoolStorage}; + +#[async_trait] +impl<'client, C, S> RequestExecutor for KeyPoolExecutor<'client, C, S> +where + C: ApiClient, + S: KeyPoolStorage + Send + Sync + 'static, +{ + type Error = KeyPoolError; + + async fn execute( + &self, + client: &C, + request: ApiRequest, + id: Option, + ) -> Result + where + A: ApiCategoryResponse, + { + loop { + let key = self + .storage + .acquire_key(self.domain) + .await + .map_err(|e| KeyPoolError::Storage(Arc::new(e)))?; + let url = request.url(key.value(), id); + let value = client.request(url).await?; + + match ApiResponse::from_value(value) { + Err(ResponseError::Api { code, reason }) => { + if !self + .storage + .flag_key(key, code) + .await + .map_err(Arc::new) + .map_err(KeyPoolError::Storage)? + { + return Err(KeyPoolError::Response(ResponseError::Api { code, reason })); + } + } + Err(parsing_error) => return Err(KeyPoolError::Response(parsing_error)), + Ok(res) => return Ok(A::from_response(res)), + }; + } + } + + async fn execute_many( + &self, + client: &C, + request: ApiRequest, + ids: Vec, + ) -> HashMap> + where + A: ApiCategoryResponse, + { + let keys = match self + .storage + .acquire_many_keys(self.domain, ids.len() as i64) + .await + { + Ok(keys) => keys, + Err(why) => { + let shared = Arc::new(why); + return ids + .into_iter() + .map(|i| (i, Err(Self::Error::Storage(shared.clone())))) + .collect(); + } + }; + + let request_ref = &request; + + futures::future::join_all(std::iter::zip(ids, keys).map(|(id, mut key)| async move { + loop { + let url = request_ref.url(key.value(), Some(id)); + let value = match client.request(url).await { + Ok(v) => v, + Err(why) => return (id, Err(Self::Error::Client(why))), + }; + + match ApiResponse::from_value(value) { + Err(ResponseError::Api { code, reason }) => { + match self.storage.flag_key(key, code).await { + Ok(false) => { + return ( + id, + Err(KeyPoolError::Response(ResponseError::Api { + code, + reason, + })), + ) + } + Ok(true) => (), + Err(why) => return (id, Err(KeyPoolError::Storage(Arc::new(why)))), + } + } + Err(parsing_error) => return (id, Err(KeyPoolError::Response(parsing_error))), + Ok(res) => return (id, Ok(A::from_response(res))), + }; + + key = match self.storage.acquire_key(self.domain).await { + Ok(k) => k, + Err(why) => return (id, Err(Self::Error::Storage(Arc::new(why)))), + }; + } + })) + .await + .into_iter() + .collect() + } +} + +#[derive(Clone, Debug)] +pub struct KeyPool +where + C: ApiClient, + S: KeyPoolStorage, +{ + client: C, + storage: S, +} + +impl KeyPool +where + C: ApiClient, + S: KeyPoolStorage + Send + Sync + 'static, +{ + pub fn new(client: C, storage: S) -> Self { + Self { client, storage } + } + + pub fn torn_api(&self, domain: KeyDomain) -> ApiProvider> { + ApiProvider::new(&self.client, KeyPoolExecutor::new(&self.storage, domain)) + } +} + +pub trait WithStorage { + fn with_storage<'a, S>( + &'a self, + storage: &'a S, + domain: KeyDomain, + ) -> ApiProvider> + where + Self: ApiClient + Sized, + S: KeyPoolStorage + Send + Sync + 'static, + { + ApiProvider::new(self, KeyPoolExecutor::new(storage, domain)) + } +} + +#[cfg(feature = "reqwest")] +impl WithStorage for reqwest::Client {}