diff --git a/torn-api/Cargo.toml b/torn-api/Cargo.toml index 358b02d..2065c1d 100644 --- a/torn-api/Cargo.toml +++ b/torn-api/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "torn-api" -version = "0.2.1" +version = "0.3.0" edition = "2021" [features] diff --git a/torn-api/src/awc.rs b/torn-api/src/awc.rs new file mode 100644 index 0000000..40f435b --- /dev/null +++ b/torn-api/src/awc.rs @@ -0,0 +1,22 @@ +use async_trait::async_trait; +use thiserror::Error; + +use crate::ApiClient; + +#[derive(Error, Debug)] +pub enum AwcApiClientError { + #[error(transparent)] + Client(#[from] awc::error::SendRequestError), + + #[error(transparent)] + Payload(#[from] awc::error::JsonPayloadError), +} + +#[async_trait(?Send)] +impl ApiClient for awc::Client { + type Error = AwcApiClientError; + + async fn request(&self, url: String) -> Result { + self.get(url).send().await?.json().await.map_err(Into::into) + } +} diff --git a/torn-api/src/faction.rs b/torn-api/src/faction.rs index c36a329..ff5753f 100644 --- a/torn-api/src/faction.rs +++ b/torn-api/src/faction.rs @@ -39,10 +39,7 @@ pub struct Basic { #[cfg(test)] mod tests { use super::*; - use crate::{ - prelude::*, - tests::{async_test, setup, Client}, - }; + use crate::tests::{async_test, setup, Client, ClientTrait}; #[async_test] async fn faction() { @@ -50,9 +47,7 @@ mod tests { let response = Client::default() .torn_api(key) - .faction() - .selections(&[Selection::Basic]) - .send() + .faction(|b| b.selections(&[Selection::Basic])) .await .unwrap(); diff --git a/torn-api/src/lib.rs b/torn-api/src/lib.rs index cb77ddd..8d129bd 100644 --- a/torn-api/src/lib.rs +++ b/torn-api/src/lib.rs @@ -3,6 +3,12 @@ pub mod faction; pub mod user; +#[cfg(feature = "awc")] +pub mod awc; + +#[cfg(feature = "reqwest")] +pub mod reqwest; + mod de_util; use async_trait::async_trait; @@ -10,33 +16,21 @@ use chrono::{DateTime, Utc}; use serde::de::{DeserializeOwned, Error as DeError}; use thiserror::Error; -#[derive(Error, Debug)] -pub enum ClientError { - #[error("api returned error '{reason}', code = '{code}'")] - Api { code: u8, reason: String }, - - #[cfg(feature = "reqwest")] - #[error("api request failed with network error")] - Reqwest(#[from] reqwest::Error), - - #[cfg(feature = "awc")] - #[error("api request failed with network error")] - AwcSend(#[from] awc::error::SendRequestError), - - #[cfg(feature = "awc")] - #[error("api request failed to read payload")] - AwcPayload(#[from] awc::error::JsonPayloadError), - - #[error("api response couldn't be deserialized")] - Deserialize(#[from] serde_json::Error), -} - pub struct ApiResponse { value: serde_json::Value, } +#[derive(Error, Debug)] +pub enum ResponseError { + #[error("API: {reason}")] + Api { code: u8, reason: String }, + + #[error(transparent)] + Parsing(#[from] serde_json::Error), +} + impl ApiResponse { - fn from_value(mut value: serde_json::Value) -> Result { + pub fn from_value(mut value: serde_json::Value) -> Result { #[derive(serde::Deserialize)] struct ApiErrorDto { code: u8, @@ -46,7 +40,7 @@ impl ApiResponse { match value.get_mut("error") { Some(error) => { let dto: ApiErrorDto = serde_json::from_value(error.take())?; - Err(ClientError::Api { + Err(ResponseError::Api { code: dto.code, reason: dto.reason, }) @@ -88,111 +82,199 @@ pub trait ApiCategoryResponse: Send + Sync { fn from_response(response: ApiResponse) -> Self; } -#[cfg(feature = "awc")] -#[async_trait(?Send)] -pub trait ApiClient { - async fn request(&self, url: String) -> Result; -} - -#[cfg(not(feature = "awc"))] #[async_trait] -pub trait ApiClient: Send + Sync { - async fn request(&self, url: String) -> Result; -} +pub trait ThreadSafeApiClient: Send + Sync { + type Error: std::error::Error + Sync + Send; -pub trait DirectApiClient: ApiClient { - fn torn_api(&self, key: String) -> DirectExecutor + async fn request(&self, url: String) -> Result; + + fn torn_api(&self, key: S) -> ThreadSafeApiProvider> where Self: Sized, + S: ToString, { - DirectExecutor::from_client(self, key) + ThreadSafeApiProvider::new(self, DirectExecutor::new(key.to_string())) } } -pub trait BackedApiClient: ApiClient {} - -#[cfg(feature = "reqwest")] -#[cfg_attr(feature = "awc", async_trait(?Send))] -#[cfg_attr(not(feature = "awc"), async_trait)] -impl crate::ApiClient for reqwest::Client { - async fn request(&self, url: String) -> Result { - let value: serde_json::Value = self.get(url).send().await?.json().await?; - Ok(ApiResponse::from_value(value)?) - } -} - -#[cfg(feature = "reqwest")] -impl crate::DirectApiClient for reqwest::Client {} - -#[cfg(feature = "awc")] #[async_trait(?Send)] -impl crate::ApiClient for awc::Client { - async fn request(&self, url: String) -> Result { - let value: serde_json::Value = self.get(url).send().await?.json().await?; - Ok(ApiResponse::from_value(value)?) +pub trait ApiClient { + type Error: std::error::Error; + + async fn request(&self, url: String) -> Result; + + fn torn_api(&self, key: S) -> ApiProvider> + where + Self: Sized, + S: ToString, + { + ApiProvider::new(self, DirectExecutor::new(key.to_string())) } } -#[cfg(feature = "awc")] -impl crate::DirectApiClient for awc::Client {} +#[async_trait(?Send)] +pub trait RequestExecutor +where + C: ApiClient, +{ + type Error: std::error::Error; -#[cfg_attr(feature = "awc", async_trait(?Send))] -#[cfg_attr(not(feature = "awc"), async_trait)] -pub trait ApiRequestExecutor<'client> { - type Err: std::error::Error; - - async fn excute(&self, request: ApiRequest) -> Result + async fn execute(&self, client: &C, request: ApiRequest) -> Result where A: ApiCategoryResponse; +} - #[must_use] - fn user<'executor>( - &'executor self, - ) -> ApiRequestBuilder<'client, 'executor, Self, user::Response> { - ApiRequestBuilder::new(self) +#[async_trait] +pub trait ThreadSafeRequestExecutor +where + C: ThreadSafeApiClient, +{ + type Error: std::error::Error + Send + Sync; + + async fn execute(&self, client: &C, request: ApiRequest) -> Result + where + A: ApiCategoryResponse; +} + +pub struct ApiProvider<'a, C, E> +where + C: ApiClient, + E: RequestExecutor, +{ + client: &'a C, + executor: E, +} + +impl<'a, C, E> ApiProvider<'a, C, E> +where + C: ApiClient, + E: RequestExecutor, +{ + pub fn new(client: &'a C, executor: E) -> ApiProvider<'a, C, E> { + Self { client, executor } } - #[must_use] - fn faction<'executor>( - &'executor self, - ) -> ApiRequestBuilder<'client, 'executor, Self, faction::Response> { - ApiRequestBuilder::new(self) + pub async fn user(&self, build: F) -> Result + where + F: FnOnce(ApiRequestBuilder) -> ApiRequestBuilder, + { + let mut builder = ApiRequestBuilder::::new(); + builder = build(builder); + + self.executor.execute(self.client, builder.request).await + } + + pub async fn faction(&self, build: F) -> Result + where + F: FnOnce(ApiRequestBuilder) -> ApiRequestBuilder, + { + let mut builder = ApiRequestBuilder::::new(); + builder = build(builder); + + self.executor.execute(self.client, builder.request).await } } -pub struct DirectExecutor<'client, C> +pub struct ThreadSafeApiProvider<'a, C, E> where - C: ApiClient, + C: ThreadSafeApiClient, + E: ThreadSafeRequestExecutor, { - client: &'client C, + client: &'a C, + executor: E, +} + +impl<'a, C, E> ThreadSafeApiProvider<'a, C, E> +where + C: ThreadSafeApiClient, + E: ThreadSafeRequestExecutor, +{ + pub fn new(client: &'a C, executor: E) -> ThreadSafeApiProvider<'a, C, E> { + Self { client, executor } + } + + pub async fn user(&self, build: F) -> Result + where + F: FnOnce(ApiRequestBuilder) -> ApiRequestBuilder, + { + let mut builder = ApiRequestBuilder::::new(); + builder = build(builder); + + self.executor.execute(self.client, builder.request).await + } + + pub async fn faction(&self, build: F) -> Result + where + F: FnOnce(ApiRequestBuilder) -> ApiRequestBuilder, + { + let mut builder = ApiRequestBuilder::::new(); + builder = build(builder); + + self.executor.execute(self.client, builder.request).await + } +} + +pub struct DirectExecutor { key: String, + _marker: std::marker::PhantomData, } -impl<'client, C> DirectExecutor<'client, C> -where - C: ApiClient, -{ - #[allow(dead_code)] - pub(crate) fn from_client(client: &'client C, key: String) -> Self { - Self { client, key } +impl DirectExecutor { + fn new(key: String) -> Self { + Self { + key, + _marker: std::marker::PhantomData, + } } } -#[cfg_attr(feature = "awc", async_trait(?Send))] -#[cfg_attr(not(feature = "awc"), async_trait)] -impl<'client, C> ApiRequestExecutor<'client> for DirectExecutor<'client, C> +#[derive(Error, Debug)] +pub enum ApiClientError +where + C: std::error::Error, +{ + #[error(transparent)] + Client(C), + + #[error(transparent)] + Response(#[from] ResponseError), +} + +#[async_trait(?Send)] +impl RequestExecutor for DirectExecutor where C: ApiClient, { - type Err = ClientError; + type Error = ApiClientError; - async fn excute(&self, request: ApiRequest) -> Result + async fn execute(&self, client: &C, request: ApiRequest) -> Result where A: ApiCategoryResponse, { let url = request.url(&self.key); - self.client.request(url).await.map(A::from_response) + let value = client.request(url).await.map_err(ApiClientError::Client)?; + + Ok(A::from_response(ApiResponse::from_value(value)?)) + } +} + +#[async_trait] +impl ThreadSafeRequestExecutor for DirectExecutor +where + C: ThreadSafeApiClient, +{ + type Error = ApiClientError; + + async fn execute(&self, client: &C, request: ApiRequest) -> Result + where + A: ApiCategoryResponse, + { + let url = request.url(&self.key); + + let value = client.request(url).await.map_err(ApiClientError::Client)?; + + Ok(A::from_response(ApiResponse::from_value(value)?)) } } @@ -263,26 +345,20 @@ where } } -pub struct ApiRequestBuilder<'client, 'executor, E, A> +pub struct ApiRequestBuilder where - E: ApiRequestExecutor<'client> + ?Sized, A: ApiCategoryResponse, { - executor: &'executor E, request: ApiRequest, - _phantom: std::marker::PhantomData<&'client E>, } -impl<'client, 'executor, E, A> ApiRequestBuilder<'client, 'executor, E, A> +impl ApiRequestBuilder where - E: ApiRequestExecutor<'client> + ?Sized, A: ApiCategoryResponse, { - pub(crate) fn new(executor: &'executor E) -> Self { + pub(crate) fn new() -> Self { Self { - executor, request: ApiRequest::default(), - _phantom: std::marker::PhantomData::default(), } } @@ -317,49 +393,23 @@ where self.request.comment = Some(comment); self } - - /// Executes the api request. - /// - /// # Examples - /// - /// ```no_run - /// use torn_api::{prelude::*, ClientError}; - /// use reqwest::Client; - /// # async { - /// - /// let key = "XXXXXXXXX".to_owned(); - /// let response = Client::new() - /// .torn_api(key) - /// .user() - /// .send() - /// .await; - /// - /// // invalid key - /// assert!(matches!(response, Err(ClientError::Api { code: 2, .. }))); - /// # }; - /// ``` - /// - /// # Errors - /// - /// Will return an `Err` if the API returns an API error, the request fails due to a network - /// error, or if the response body doesn't contain valid json. - pub async fn send(self) -> Result>::Err> { - self.executor.excute(self.request).await - } } -pub mod prelude { - pub use super::{ApiClient, ApiRequestExecutor, DirectApiClient}; -} +pub mod prelude {} #[cfg(test)] pub(crate) mod tests { use std::sync::Once; #[cfg(all(not(feature = "reqwest"), feature = "awc"))] - pub use awc::Client; + pub use ::awc::Client; #[cfg(feature = "reqwest")] - pub use reqwest::Client; + pub use ::reqwest::Client; + + #[cfg(all(not(feature = "reqwest"), feature = "awc"))] + pub use crate::ApiClient as ClientTrait; + #[cfg(feature = "reqwest")] + pub use crate::ThreadSafeApiClient as ClientTrait; #[cfg(all(not(feature = "reqwest"), feature = "awc"))] pub use actix_rt::test as async_test; @@ -387,12 +437,7 @@ pub(crate) mod tests { async fn reqwest() { let key = setup(); - reqwest::Client::default() - .torn_api(key) - .user() - .send() - .await - .unwrap(); + Client::default().torn_api(key).user(|b| b).await.unwrap(); } #[cfg(feature = "awc")] @@ -400,11 +445,6 @@ pub(crate) mod tests { async fn awc() { let key = setup(); - awc::Client::default() - .torn_api(key) - .user() - .send() - .await - .unwrap(); + Client::default().torn_api(key).user(|b| b).await.unwrap(); } } diff --git a/torn-api/src/reqwest.rs b/torn-api/src/reqwest.rs new file mode 100644 index 0000000..ff84edd --- /dev/null +++ b/torn-api/src/reqwest.rs @@ -0,0 +1,12 @@ +use async_trait::async_trait; + +use crate::ThreadSafeApiClient; + +#[async_trait] +impl ThreadSafeApiClient for reqwest::Client { + type Error = reqwest::Error; + + async fn request(&self, url: String) -> Result { + self.get(url).send().await?.json().await + } +} diff --git a/torn-api/src/user.rs b/torn-api/src/user.rs index 53744b3..903fdab 100644 --- a/torn-api/src/user.rs +++ b/torn-api/src/user.rs @@ -150,10 +150,7 @@ pub struct PersonalStats { #[cfg(test)] mod tests { use super::*; - use crate::{ - prelude::*, - tests::{async_test, setup, Client}, - }; + use crate::tests::{async_test, setup, Client, ClientTrait}; #[async_test] async fn user() { @@ -161,14 +158,14 @@ mod tests { let response = Client::default() .torn_api(key) - .user() - .selections(&[ - Selection::Basic, - Selection::Discord, - Selection::Profile, - Selection::PersonalStats, - ]) - .send() + .user(|b| { + b.selections(&[ + Selection::Basic, + Selection::Discord, + Selection::Profile, + Selection::PersonalStats, + ]) + }) .await .unwrap(); @@ -184,10 +181,7 @@ mod tests { let response = Client::default() .torn_api(key) - .user() - .id(28) - .selections(&[Selection::Profile]) - .send() + .user(|b| b.id(28).selections(&[Selection::Profile])) .await .unwrap(); diff --git a/torn-key-pool/Cargo.toml b/torn-key-pool/Cargo.toml index a12488e..7a30c47 100644 --- a/torn-key-pool/Cargo.toml +++ b/torn-key-pool/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "torn-key-pool" -version = "0.1.3" +version = "0.2.0" edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html diff --git a/torn-key-pool/src/lib.rs b/torn-key-pool/src/lib.rs index 5ab2f83..8eab607 100644 --- a/torn-key-pool/src/lib.rs +++ b/torn-key-pool/src/lib.rs @@ -6,18 +6,25 @@ pub mod postgres; use async_trait::async_trait; use thiserror::Error; -use torn_api::prelude::*; +use torn_api::{ + ApiCategoryResponse, ApiClient, ApiProvider, ApiRequest, ApiResponse, RequestExecutor, + ResponseError, ThreadSafeApiClient, ThreadSafeApiProvider, ThreadSafeRequestExecutor, +}; #[derive(Debug, Error)] -pub enum KeyPoolError +pub enum KeyPoolError where - S: Sync + Send + std::error::Error, + S: std::error::Error, + C: std::error::Error, { #[error("Key pool storage driver error: {0:?}")] Storage(#[source] S), #[error(transparent)] - Client(#[from] torn_api::ClientError), + Client(#[from] C), + + #[error(transparent)] + Response(ResponseError), } #[derive(Debug, Clone, Copy)] @@ -34,50 +41,47 @@ pub trait ApiKey: Sync + Send { #[async_trait] pub trait KeyPoolStorage { type Key: ApiKey; - type Err: Sync + Send + std::error::Error; + type Error: std::error::Error + Sync + Send; - async fn acquire_key(&self, domain: KeyDomain) -> Result; + async fn acquire_key(&self, domain: KeyDomain) -> Result; - async fn flag_key(&self, key: Self::Key, code: u8) -> Result; + async fn flag_key(&self, key: Self::Key, code: u8) -> Result; } #[derive(Debug, Clone)] -pub struct KeyPoolExecutor<'client, C, S> +pub struct KeyPoolExecutor<'a, C, S> where - C: ApiClient, S: KeyPoolStorage, { - client: &'client C, - storage: &'client S, + storage: &'a S, domain: KeyDomain, + _marker: std::marker::PhantomData, } -impl<'client, C, S> KeyPoolExecutor<'client, C, S> +impl<'a, C, S> KeyPoolExecutor<'a, C, S> where - C: ApiClient, S: KeyPoolStorage, { - pub fn new(client: &'client C, storage: &'client S, domain: KeyDomain) -> Self { + pub fn new(storage: &'a S, domain: KeyDomain) -> Self { Self { - client, storage, domain, + _marker: std::marker::PhantomData, } } } -#[cfg_attr(feature = "awc", async_trait(?Send))] -#[cfg_attr(not(feature = "awc"), async_trait)] -impl<'client, C, S> ApiRequestExecutor<'client> for KeyPoolExecutor<'client, C, S> +#[async_trait(?Send)] +impl<'client, C, S> RequestExecutor for KeyPoolExecutor<'client, C, S> where C: ApiClient, - S: KeyPoolStorage + Send + Sync + 'static, + S: KeyPoolStorage + 'static, { - type Err = KeyPoolError; + type Error = KeyPoolError; - async fn excute(&self, request: torn_api::ApiRequest) -> Result + async fn execute(&self, client: &C, request: ApiRequest) -> Result where - A: torn_api::ApiCategoryResponse, + A: ApiCategoryResponse, { loop { let key = self @@ -86,20 +90,60 @@ where .await .map_err(KeyPoolError::Storage)?; let url = request.url(key.value()); - let res = self.client.request(url).await; + let value = client.request(url).await?; - match res { - Err(torn_api::ClientError::Api { code, .. }) => { + match ApiResponse::from_value(value) { + Err(ResponseError::Api { code, reason }) => { if !self .storage .flag_key(key, code) .await .map_err(KeyPoolError::Storage)? { - panic!(); + return Err(KeyPoolError::Response(ResponseError::Api { code, reason })); } } - _ => return res.map(A::from_response).map_err(KeyPoolError::Client), + Err(parsing_error) => return Err(KeyPoolError::Response(parsing_error)), + Ok(res) => return Ok(A::from_response(res)), + }; + } + } +} + +#[async_trait] +impl<'client, C, S> ThreadSafeRequestExecutor for KeyPoolExecutor<'client, C, S> +where + C: ThreadSafeApiClient, + S: KeyPoolStorage + Send + Sync + 'static, +{ + type Error = KeyPoolError; + + async fn execute(&self, client: &C, request: ApiRequest) -> Result + where + A: ApiCategoryResponse, + { + loop { + let key = self + .storage + .acquire_key(self.domain) + .await + .map_err(KeyPoolError::Storage)?; + let url = request.url(key.value()); + let value = client.request(url).await?; + + match ApiResponse::from_value(value) { + Err(ResponseError::Api { code, reason }) => { + if !self + .storage + .flag_key(key, code) + .await + .map_err(KeyPoolError::Storage)? + { + return Err(KeyPoolError::Response(ResponseError::Api { code, reason })); + } + } + Err(parsing_error) => return Err(KeyPoolError::Response(parsing_error)), + Ok(res) => return Ok(A::from_response(res)), }; } } @@ -118,29 +162,69 @@ where impl KeyPool where C: ApiClient, - S: KeyPoolStorage, + S: KeyPoolStorage + 'static, { pub fn new(client: C, storage: S) -> Self { Self { client, storage } } - pub fn torn_api(&self, domain: KeyDomain) -> KeyPoolExecutor { - KeyPoolExecutor::new(&self.client, &self.storage, domain) + pub fn torn_api(&self, domain: KeyDomain) -> ApiProvider> { + ApiProvider::new(&self.client, KeyPoolExecutor::new(&self.storage, domain)) } } -pub trait KeyPoolClient: ApiClient { - fn with_pool<'a, S>(&'a self, domain: KeyDomain, storage: &'a S) -> KeyPoolExecutor +#[derive(Clone, Debug)] +pub struct ThreadSafeKeyPool +where + C: ThreadSafeApiClient, + S: KeyPoolStorage + Send + Sync + 'static, +{ + client: C, + storage: S, +} + +impl ThreadSafeKeyPool +where + C: ThreadSafeApiClient, + S: KeyPoolStorage + Send + Sync + 'static, +{ + pub fn new(client: C, storage: S) -> Self { + Self { client, storage } + } + + pub fn torn_api(&self, domain: KeyDomain) -> ThreadSafeApiProvider> { + ThreadSafeApiProvider::new(&self.client, KeyPoolExecutor::new(&self.storage, domain)) + } +} + +pub trait WithStorage { + fn with_storage<'a, S>( + &'a self, + storage: &'a S, + domain: KeyDomain, + ) -> ApiProvider> where - Self: Sized, - S: KeyPoolStorage, + Self: ApiClient + Sized, + S: KeyPoolStorage + 'static, { - KeyPoolExecutor::new(self, storage, domain) + ApiProvider::new(self, KeyPoolExecutor::new(storage, domain)) + } + + fn with_storage_sync<'a, S>( + &'a self, + storage: &'a S, + domain: KeyDomain, + ) -> ThreadSafeApiProvider> + where + Self: ThreadSafeApiClient + Sized, + S: KeyPoolStorage + Send + Sync + 'static, + { + ThreadSafeApiProvider::new(self, KeyPoolExecutor::new(storage, domain)) } } #[cfg(feature = "reqwest")] -impl KeyPoolClient for reqwest::Client {} +impl WithStorage for reqwest::Client {} #[cfg(feature = "awc")] -impl KeyPoolClient for awc::Client {} +impl WithStorage for awc::Client {} diff --git a/torn-key-pool/src/postgres.rs b/torn-key-pool/src/postgres.rs index 24caae0..90be3c3 100644 --- a/torn-key-pool/src/postgres.rs +++ b/torn-key-pool/src/postgres.rs @@ -67,9 +67,9 @@ impl PgKeyPoolStorage { impl KeyPoolStorage for PgKeyPoolStorage { type Key = PgKey; - type Err = PgStorageError; + type Error = PgStorageError; - async fn acquire_key(&self, domain: KeyDomain) -> Result { + async fn acquire_key(&self, domain: KeyDomain) -> Result { let predicate = match domain { KeyDomain::Public => "".to_owned(), KeyDomain::User(id) => format!("where and user_id={} and user", id), @@ -117,7 +117,7 @@ impl KeyPoolStorage for PgKeyPoolStorage { key.ok_or(PgStorageError::Unavailable(domain)) } - async fn flag_key(&self, key: Self::Key, code: u8) -> Result { + async fn flag_key(&self, key: Self::Key, code: u8) -> Result { // TODO: put keys in cooldown when appropriate match code { 2 | 10 | 13 => {