conditional Sync + Send if client supports it

This commit is contained in:
TotallyNot 2022-09-04 20:32:40 +02:00
parent da9c1b1563
commit 54345fef19
6 changed files with 37 additions and 27 deletions

View file

@ -1,2 +1,3 @@
[workspace] [workspace]
resolver = "2"
members = [ "macros", "torn-api", "torn-key-pool" ] members = [ "macros", "torn-api", "torn-key-pool" ]

View file

@ -1,6 +1,6 @@
[package] [package]
name = "torn-api" name = "torn-api"
version = "0.2.0" version = "0.2.1"
edition = "2021" edition = "2021"
[features] [features]

View file

@ -11,7 +11,7 @@ use serde::de::{DeserializeOwned, Error as DeError};
use thiserror::Error; use thiserror::Error;
#[derive(Error, Debug)] #[derive(Error, Debug)]
pub enum Error { pub enum ClientError {
#[error("api returned error '{reason}', code = '{code}'")] #[error("api returned error '{reason}', code = '{code}'")]
Api { code: u8, reason: String }, Api { code: u8, reason: String },
@ -36,7 +36,7 @@ pub struct ApiResponse {
} }
impl ApiResponse { impl ApiResponse {
fn from_value(mut value: serde_json::Value) -> Result<Self, Error> { fn from_value(mut value: serde_json::Value) -> Result<Self, ClientError> {
#[derive(serde::Deserialize)] #[derive(serde::Deserialize)]
struct ApiErrorDto { struct ApiErrorDto {
code: u8, code: u8,
@ -46,7 +46,7 @@ impl ApiResponse {
match value.get_mut("error") { match value.get_mut("error") {
Some(error) => { Some(error) => {
let dto: ApiErrorDto = serde_json::from_value(error.take())?; let dto: ApiErrorDto = serde_json::from_value(error.take())?;
Err(Error::Api { Err(ClientError::Api {
code: dto.code, code: dto.code,
reason: dto.reason, reason: dto.reason,
}) })
@ -82,15 +82,22 @@ pub trait ApiSelection {
fn category() -> &'static str; fn category() -> &'static str;
} }
pub trait ApiCategoryResponse { pub trait ApiCategoryResponse: Send + Sync {
type Selection: ApiSelection; type Selection: ApiSelection;
fn from_response(response: ApiResponse) -> Self; fn from_response(response: ApiResponse) -> Self;
} }
#[cfg(feature = "awc")]
#[async_trait(?Send)] #[async_trait(?Send)]
pub trait ApiClient { pub trait ApiClient {
async fn request(&self, url: String) -> Result<ApiResponse, Error>; async fn request(&self, url: String) -> Result<ApiResponse, ClientError>;
}
#[cfg(not(feature = "awc"))]
#[async_trait]
pub trait ApiClient: Send + Sync {
async fn request(&self, url: String) -> Result<ApiResponse, ClientError>;
} }
pub trait DirectApiClient: ApiClient { pub trait DirectApiClient: ApiClient {
@ -105,32 +112,32 @@ pub trait DirectApiClient: ApiClient {
pub trait BackedApiClient: ApiClient {} pub trait BackedApiClient: ApiClient {}
#[cfg(feature = "reqwest")] #[cfg(feature = "reqwest")]
#[async_trait(?Send)] #[cfg_attr(feature = "awc", async_trait(?Send))]
#[cfg_attr(not(feature = "awc"), async_trait)]
impl crate::ApiClient for reqwest::Client { impl crate::ApiClient for reqwest::Client {
async fn request(&self, url: String) -> Result<ApiResponse, crate::Error> { async fn request(&self, url: String) -> Result<ApiResponse, crate::ClientError> {
let value: serde_json::Value = self.get(url).send().await?.json().await?; let value: serde_json::Value = self.get(url).send().await?.json().await?;
Ok(ApiResponse::from_value(value)?) Ok(ApiResponse::from_value(value)?)
} }
} }
#[cfg(feature = "reqwest")] #[cfg(feature = "reqwest")]
#[async_trait(?Send)]
impl crate::DirectApiClient for reqwest::Client {} impl crate::DirectApiClient for reqwest::Client {}
#[cfg(feature = "awc")] #[cfg(feature = "awc")]
#[async_trait(?Send)] #[async_trait(?Send)]
impl crate::ApiClient for awc::Client { impl crate::ApiClient for awc::Client {
async fn request(&self, url: String) -> Result<ApiResponse, crate::Error> { async fn request(&self, url: String) -> Result<ApiResponse, crate::ClientError> {
let value: serde_json::Value = self.get(url).send().await?.json().await?; let value: serde_json::Value = self.get(url).send().await?.json().await?;
Ok(ApiResponse::from_value(value)?) Ok(ApiResponse::from_value(value)?)
} }
} }
#[cfg(feature = "awc")] #[cfg(feature = "awc")]
#[async_trait(?Send)]
impl crate::DirectApiClient for awc::Client {} impl crate::DirectApiClient for awc::Client {}
#[async_trait(?Send)] #[cfg_attr(feature = "awc", async_trait(?Send))]
#[cfg_attr(not(feature = "awc"), async_trait)]
pub trait ApiRequestExecutor<'client> { pub trait ApiRequestExecutor<'client> {
type Err: std::error::Error; type Err: std::error::Error;
@ -171,12 +178,13 @@ where
} }
} }
#[async_trait(?Send)] #[cfg_attr(feature = "awc", async_trait(?Send))]
#[cfg_attr(not(feature = "awc"), async_trait)]
impl<'client, C> ApiRequestExecutor<'client> for DirectExecutor<'client, C> impl<'client, C> ApiRequestExecutor<'client> for DirectExecutor<'client, C>
where where
C: ApiClient, C: ApiClient,
{ {
type Err = Error; type Err = ClientError;
async fn excute<A>(&self, request: ApiRequest<A>) -> Result<A, Self::Err> async fn excute<A>(&self, request: ApiRequest<A>) -> Result<A, Self::Err>
where where
@ -315,7 +323,7 @@ where
/// # Examples /// # Examples
/// ///
/// ```no_run /// ```no_run
/// use torn_api::{prelude::*, Error}; /// use torn_api::{prelude::*, ClientError};
/// use reqwest::Client; /// use reqwest::Client;
/// # async { /// # async {
/// ///
@ -327,7 +335,7 @@ where
/// .await; /// .await;
/// ///
/// // invalid key /// // invalid key
/// assert!(matches!(response, Err(Error::Api { code: 2, .. }))); /// assert!(matches!(response, Err(ClientError::Api { code: 2, .. })));
/// # }; /// # };
/// ``` /// ```
/// ///
@ -394,7 +402,7 @@ pub(crate) mod tests {
awc::Client::default() awc::Client::default()
.torn_api(key) .torn_api(key)
.user(None) .user()
.send() .send()
.await .await
.unwrap(); .unwrap();

View file

@ -1,6 +1,6 @@
[package] [package]
name = "torn-key-pool" name = "torn-key-pool"
version = "0.1.2" version = "0.1.3"
edition = "2021" edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

View file

@ -11,13 +11,13 @@ use torn_api::prelude::*;
#[derive(Debug, Error)] #[derive(Debug, Error)]
pub enum KeyPoolError<S> pub enum KeyPoolError<S>
where where
S: std::error::Error + std::fmt::Debug, S: Sync + Send + std::error::Error,
{ {
#[error("Key pool storage driver error: {0:?}")] #[error("Key pool storage driver error: {0:?}")]
Storage(#[source] S), Storage(#[source] S),
#[error(transparent)] #[error(transparent)]
Client(#[from] torn_api::Error), Client(#[from] torn_api::ClientError),
} }
#[derive(Debug, Clone, Copy)] #[derive(Debug, Clone, Copy)]
@ -27,14 +27,14 @@ pub enum KeyDomain {
Faction(i32), Faction(i32),
} }
pub trait ApiKey { pub trait ApiKey: Sync + Send {
fn value(&self) -> &str; fn value(&self) -> &str;
} }
#[async_trait(?Send)] #[async_trait]
pub trait KeyPoolStorage { pub trait KeyPoolStorage {
type Key: ApiKey; type Key: ApiKey;
type Err: std::error::Error; type Err: Sync + Send + std::error::Error;
async fn acquire_key(&self, domain: KeyDomain) -> Result<Self::Key, Self::Err>; async fn acquire_key(&self, domain: KeyDomain) -> Result<Self::Key, Self::Err>;
@ -66,11 +66,12 @@ where
} }
} }
#[async_trait(?Send)] #[cfg_attr(feature = "awc", async_trait(?Send))]
#[cfg_attr(not(feature = "awc"), async_trait)]
impl<'client, C, S> ApiRequestExecutor<'client> for KeyPoolExecutor<'client, C, S> impl<'client, C, S> ApiRequestExecutor<'client> for KeyPoolExecutor<'client, C, S>
where where
C: ApiClient, C: ApiClient,
S: KeyPoolStorage + 'static, S: KeyPoolStorage + Send + Sync + 'static,
{ {
type Err = KeyPoolError<S::Err>; type Err = KeyPoolError<S::Err>;
@ -88,7 +89,7 @@ where
let res = self.client.request(url).await; let res = self.client.request(url).await;
match res { match res {
Err(torn_api::Error::Api { code, .. }) => { Err(torn_api::ClientError::Api { code, .. }) => {
if !self if !self
.storage .storage
.flag_key(key, code) .flag_key(key, code)

View file

@ -63,7 +63,7 @@ impl PgKeyPoolStorage {
} }
} }
#[async_trait(?Send)] #[async_trait]
impl KeyPoolStorage for PgKeyPoolStorage { impl KeyPoolStorage for PgKeyPoolStorage {
type Key = PgKey; type Key = PgKey;