diff --git a/Cargo.toml b/Cargo.toml index 9d3371c..17effd8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,33 +1,2 @@ -[package] -name = "torn-api" -version = "0.1.0" -edition = "2021" - [workspace] -members = [ "macros" ] - -[features] -default = [ "reqwest" ] -reqwest = [ "dep:reqwest" ] -awc = [ "dep:awc" ] - -[dependencies] -serde = { version = "1", features = [ "derive" ] } -serde_json = "1" -chrono = { version = "0.4", features = [ "serde" ], default-features = false } -async-trait = "0.1" -thiserror = "1" -num-traits = "0.2" - -reqwest = { version = "0.11", default-features = false, features = [ "json" ], optional = true } -awc = { version = "3", default-features = false, optional = true } - -macros = { path = "macros" } - -[dev-dependencies] -actix-rt = { version = "2.7.0" } -dotenv = "0.15.0" -tokio = { version = "1.20.1", features = ["test-util", "rt", "macros"] } -tokio-test = "0.4.2" -reqwest = { version = "*", default-features = true } -awc = { version = "*", features = [ "rustls" ] } +members = [ "macros", "torn-api", "torn-key-pool" ] diff --git a/rustfmt.toml b/rustfmt.toml new file mode 100644 index 0000000..3a26366 --- /dev/null +++ b/rustfmt.toml @@ -0,0 +1 @@ +edition = "2021" diff --git a/torn-api/Cargo.toml b/torn-api/Cargo.toml new file mode 100644 index 0000000..beab9c6 --- /dev/null +++ b/torn-api/Cargo.toml @@ -0,0 +1,30 @@ +[package] +name = "torn-api" +version = "0.2.0" +edition = "2021" + +[features] +default = [ "reqwest" ] +reqwest = [ "dep:reqwest" ] +awc = [ "dep:awc" ] + +[dependencies] +serde = { version = "1", features = [ "derive" ] } +serde_json = "1" +chrono = { version = "0.4", features = [ "serde" ], default-features = false } +async-trait = "0.1" +thiserror = "1" +num-traits = "0.2" + +reqwest = { version = "0.11", default-features = false, features = [ "json" ], optional = true } +awc = { version = "3", default-features = false, optional = true } + +macros = { path = "../macros" } + +[dev-dependencies] +actix-rt = { version = "2.7.0" } +dotenv = "0.15.0" +tokio = { version = "1.20.1", features = ["test-util", "rt", "macros"] } +tokio-test = "0.4.2" +reqwest = { version = "*", default-features = true } +awc = { version = "*", features = [ "rustls" ] } diff --git a/src/de_util.rs b/torn-api/src/de_util.rs similarity index 100% rename from src/de_util.rs rename to torn-api/src/de_util.rs diff --git a/src/faction.rs b/torn-api/src/faction.rs similarity index 87% rename from src/faction.rs rename to torn-api/src/faction.rs index a675844..c36a329 100644 --- a/src/faction.rs +++ b/torn-api/src/faction.rs @@ -1,12 +1,9 @@ use std::collections::BTreeMap; -use chrono::{serde::ts_seconds, DateTime, Utc}; use serde::Deserialize; use macros::ApiCategory; -use super::de_util; - #[derive(Debug, Clone, Copy, ApiCategory)] #[api(category = "faction")] pub enum Selection { @@ -42,7 +39,10 @@ pub struct Basic { #[cfg(test)] mod tests { use super::*; - use crate::{tests::{setup, Client, async_test}, ApiClient}; + use crate::{ + prelude::*, + tests::{async_test, setup, Client}, + }; #[async_test] async fn faction() { @@ -50,7 +50,7 @@ mod tests { let response = Client::default() .torn_api(key) - .faction(None) + .faction() .selections(&[Selection::Basic]) .send() .await diff --git a/src/lib.rs b/torn-api/src/lib.rs similarity index 63% rename from src/lib.rs rename to torn-api/src/lib.rs index 319c8a1..dca381c 100644 --- a/src/lib.rs +++ b/torn-api/src/lib.rs @@ -1,7 +1,7 @@ -#![warn(clippy::all, clippy::perf, clippy::pedantic, clippy::suspicious)] +#![warn(clippy::all, clippy::perf, clippy::style, clippy::suspicious)] -pub mod user; pub mod faction; +pub mod user; mod de_util; @@ -10,7 +10,6 @@ use chrono::{DateTime, Utc}; use serde::de::{DeserializeOwned, Error as DeError}; use thiserror::Error; - #[derive(Error, Debug)] pub enum Error { #[error("api returned error '{reason}', code = '{code}'")] @@ -91,46 +90,70 @@ pub trait ApiCategoryResponse { #[async_trait(?Send)] pub trait ApiClient { - async fn request(&self, url: String) -> Result; + async fn request(&self, url: String) -> Result; +} - fn torn_api(&self, key: String) -> TornApi +pub trait DirectApiClient: ApiClient { + fn torn_api(&self, key: String) -> DirectExecutor where - Self: Sized; + Self: Sized, + { + DirectExecutor::from_client(self, key) + } +} + +pub trait BackedApiClient: ApiClient {} + +#[cfg(feature = "reqwest")] +#[async_trait(?Send)] +impl crate::ApiClient for reqwest::Client { + async fn request(&self, url: String) -> Result { + let value: serde_json::Value = self.get(url).send().await?.json().await?; + Ok(ApiResponse::from_value(value)?) + } } #[cfg(feature = "reqwest")] #[async_trait(?Send)] -impl crate::ApiClient for ::reqwest::Client { - async fn request(&self, url: String) -> Result { - let value = self.get(url).send().await?.json().await?; - Ok(value) - } +impl crate::DirectApiClient for reqwest::Client {} - fn torn_api(&self, key: String) -> crate::TornApi - where - Self: Sized, - { - crate::TornApi::from_client(self, key) +#[cfg(feature = "awc")] +#[async_trait(?Send)] +impl crate::ApiClient for awc::Client { + async fn request(&self, url: String) -> Result { + let value: serde_json::Value = self.get(url).send().await?.json().await?; + Ok(ApiResponse::from_value(value)?) } } #[cfg(feature = "awc")] #[async_trait(?Send)] -impl crate::ApiClient for awc::Client { - async fn request(&self, url: String) -> Result { - let value = self.get(url).send().await?.json().await?; - Ok(value) +impl crate::DirectApiClient for awc::Client {} + +#[async_trait(?Send)] +pub trait ApiRequestExecutor<'client> { + type Err: std::error::Error; + + async fn excute(&self, request: ApiRequest) -> Result + where + A: ApiCategoryResponse; + + #[must_use] + fn user<'executor>( + &'executor self, + ) -> ApiRequestBuilder<'client, 'executor, Self, user::Response> { + ApiRequestBuilder::new(self) } - fn torn_api(&self, key: String) -> crate::TornApi - where - Self: Sized, - { - crate::TornApi::from_client(self, key) + #[must_use] + fn faction<'executor>( + &'executor self, + ) -> ApiRequestBuilder<'client, 'executor, Self, faction::Response> { + ApiRequestBuilder::new(self) } } -pub struct TornApi<'client, C> +pub struct DirectExecutor<'client, C> where C: ApiClient, { @@ -138,7 +161,7 @@ where key: String, } -impl<'client, C> TornApi<'client, C> +impl<'client, C> DirectExecutor<'client, C> where C: ApiClient, { @@ -146,73 +169,144 @@ where pub(crate) fn from_client(client: &'client C, key: String) -> Self { Self { client, key } } +} - #[must_use] - pub fn user(self, id: Option) -> ApiRequestBuilder<'client, C, user::Response> { - ApiRequestBuilder::new(self.client, self.key, id) - } +#[async_trait(?Send)] +impl<'client, C> ApiRequestExecutor<'client> for DirectExecutor<'client, C> +where + C: ApiClient, +{ + type Err = Error; - #[must_use] - pub fn faction(self, id: Option) -> ApiRequestBuilder<'client, C, faction::Response> { - ApiRequestBuilder::new(self.client, self.key, id) + async fn excute(&self, request: ApiRequest) -> Result + where + A: ApiCategoryResponse, + { + let url = request.url(&self.key); + + self.client.request(url).await.map(A::from_response) } } -pub struct ApiRequestBuilder<'client, C, A> +#[derive(Debug)] +pub struct ApiRequest where - C: ApiClient, A: ApiCategoryResponse, { - client: &'client C, - key: String, - phantom: std::marker::PhantomData, selections: Vec<&'static str>, id: Option, from: Option>, to: Option>, comment: Option, + phantom: std::marker::PhantomData, } -impl<'client, C, A> ApiRequestBuilder<'client, C, A> +impl std::default::Default for ApiRequest where - C: ApiClient, A: ApiCategoryResponse, { - pub(crate) fn new(client: &'client C, key: String, id: Option) -> Self { + fn default() -> Self { Self { - client, - key, - phantom: std::marker::PhantomData, - selections: Vec::new(), - id, + selections: Vec::default(), + id: None, from: None, to: None, comment: None, + phantom: std::marker::PhantomData::default(), + } + } +} + +impl ApiRequest +where + A: ApiCategoryResponse, +{ + pub fn url(&self, key: &str) -> String { + let mut query_fragments = vec![ + format!("selections={}", self.selections.join(",")), + format!("key={}", key), + ]; + + if let Some(from) = self.from { + query_fragments.push(format!("from={}", from.timestamp())); + } + + if let Some(to) = self.to { + query_fragments.push(format!("to={}", to.timestamp())); + } + + if let Some(comment) = &self.comment { + query_fragments.push(format!("comment={}", comment)); + } + + let query = query_fragments.join("&"); + + let id_fragment = match self.id { + Some(id) => id.to_string(), + None => "".to_owned(), + }; + + format!( + "https://api.torn.com/{}/{}?{}", + A::Selection::category(), + id_fragment, + query + ) + } +} + +pub struct ApiRequestBuilder<'client, 'executor, E, A> +where + E: ApiRequestExecutor<'client> + ?Sized, + A: ApiCategoryResponse, +{ + executor: &'executor E, + request: ApiRequest, + _phantom: std::marker::PhantomData<&'client E>, +} + +impl<'client, 'executor, E, A> ApiRequestBuilder<'client, 'executor, E, A> +where + E: ApiRequestExecutor<'client> + ?Sized, + A: ApiCategoryResponse, +{ + pub(crate) fn new(executor: &'executor E) -> Self { + Self { + executor, + request: ApiRequest::default(), + _phantom: std::marker::PhantomData::default(), } } + #[must_use] + pub fn id(mut self, id: u64) -> Self { + self.request.id = Some(id); + self + } + #[must_use] pub fn selections(mut self, selections: &[A::Selection]) -> Self { - self.selections + self.request + .selections .append(&mut selections.iter().map(ApiSelection::raw_value).collect()); self } #[must_use] pub fn from(mut self, from: DateTime) -> Self { - self.from = Some(from); + self.request.from = Some(from); self } #[must_use] pub fn to(mut self, to: DateTime) -> Self { - self.to = Some(to); + self.request.to = Some(to); self } #[must_use] pub fn comment(mut self, comment: String) -> Self { - self.comment = Some(comment); + self.request.comment = Some(comment); self } @@ -221,14 +315,14 @@ where /// # Examples /// /// ```no_run - /// use torn_api::{ApiClient, Error}; + /// use torn_api::{prelude::*, Error}; /// use reqwest::Client; /// # async { /// /// let key = "XXXXXXXXX".to_owned(); /// let response = Client::new() /// .torn_api(key) - /// .user(None) + /// .user() /// .send() /// .await; /// @@ -241,57 +335,28 @@ where /// /// Will return an `Err` if the API returns an API error, the request fails due to a network /// error, or if the response body doesn't contain valid json. - pub async fn send(self) -> Result { - let mut query_fragments = vec![ - format!("selections={}", self.selections.join(",")), - format!("key={}", self.key), - ]; - - if let Some(from) = self.from { - query_fragments.push(format!("from={}", from.timestamp())); - } - - if let Some(to) = self.to { - query_fragments.push(format!("to={}", to.timestamp())); - } - - if let Some(comment) = self.comment { - query_fragments.push(format!("comment={}", comment)); - } - - let query = query_fragments.join("&"); - - let id_fragment = match self.id { - Some(id) => id.to_string(), - None => "".to_owned(), - }; - - let url = format!( - "https://api.torn.com/{}/{}?{}", - A::Selection::category(), - id_fragment, - query - ); - - let value = self.client.request(url).await?; - - ApiResponse::from_value(value).map(A::from_response) + pub async fn send(self) -> Result>::Err> { + self.executor.excute(self.request).await } } +pub mod prelude { + pub use super::{ApiClient, ApiRequestExecutor, DirectApiClient}; +} + #[cfg(test)] pub(crate) mod tests { use std::sync::Once; - #[cfg(feature = "reqwest")] - pub use reqwest::Client; #[cfg(all(not(feature = "reqwest"), feature = "awc"))] pub use awc::Client; - #[cfg(feature = "reqwest")] - pub use tokio::test as async_test; + pub use reqwest::Client; + #[cfg(all(not(feature = "reqwest"), feature = "awc"))] pub use actix_rt::test as async_test; + #[cfg(feature = "reqwest")] + pub use tokio::test as async_test; use super::*; @@ -316,7 +381,7 @@ pub(crate) mod tests { reqwest::Client::default() .torn_api(key) - .user(None) + .user() .send() .await .unwrap(); diff --git a/src/user.rs b/torn-api/src/user.rs similarity index 91% rename from src/user.rs rename to torn-api/src/user.rs index 56dac68..53744b3 100644 --- a/src/user.rs +++ b/torn-api/src/user.rs @@ -89,7 +89,7 @@ pub struct Discord { #[serde(rename = "userID")] pub user_id: i32, #[serde(rename = "discordID", deserialize_with = "de_util::string_is_long")] - pub discord_id: i64, + pub discord_id: i64, } #[derive(Debug, Clone, Deserialize)] @@ -150,7 +150,10 @@ pub struct PersonalStats { #[cfg(test)] mod tests { use super::*; - use crate::{tests::{setup, Client, async_test}, ApiClient}; + use crate::{ + prelude::*, + tests::{async_test, setup, Client}, + }; #[async_test] async fn user() { @@ -158,8 +161,13 @@ mod tests { let response = Client::default() .torn_api(key) - .user(None) - .selections(&[Selection::Basic, Selection::Discord, Selection::Profile, Selection::PersonalStats]) + .user() + .selections(&[ + Selection::Basic, + Selection::Discord, + Selection::Profile, + Selection::PersonalStats, + ]) .send() .await .unwrap(); @@ -176,8 +184,9 @@ mod tests { let response = Client::default() .torn_api(key) - .user(Some(28)) - .selections(&[ Selection::Profile]) + .user() + .id(28) + .selections(&[Selection::Profile]) .send() .await .unwrap(); diff --git a/torn-key-pool/Cargo.toml b/torn-key-pool/Cargo.toml new file mode 100644 index 0000000..2a7de04 --- /dev/null +++ b/torn-key-pool/Cargo.toml @@ -0,0 +1,26 @@ +[package] +name = "torn-key-pool" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[features] +default = [ "postgres" ] +postgres = [ "dep:sqlx", "dep:chrono", "dep:indoc" ] + +[dependencies] +torn-api = { path = "../torn-api", default-features = false } +sqlx = { version = "0.6", features = [ "postgres", "chrono" ], optional = true } +chrono = { version = "0.4", optional = true } +indoc = { version = "1", optional = true } +async-trait = "0.1" +thiserror = "1" + +[dev-dependencies] +torn-api = { path = "../torn-api", features = [ "reqwest" ] } +sqlx = { version = "*", features = [ "runtime-tokio-rustls" ] } +dotenv = "0.15.0" +tokio = { version = "1.20.1", features = ["test-util", "rt", "macros"] } +tokio-test = "0.4.2" +reqwest = { version = "0.11", features = [ "json" ] } diff --git a/torn-key-pool/src/lib.rs b/torn-key-pool/src/lib.rs new file mode 100644 index 0000000..d682735 --- /dev/null +++ b/torn-key-pool/src/lib.rs @@ -0,0 +1,129 @@ +#![warn(clippy::all, clippy::perf, clippy::style, clippy::suspicious)] + +#[cfg(feature = "postgres")] +pub mod postgres; + +use async_trait::async_trait; +use thiserror::Error; + +use torn_api::prelude::*; + +#[derive(Debug, Error)] +pub enum KeyPoolError +where + S: std::error::Error + std::fmt::Debug, +{ + #[error("Key pool storage driver error: {0:?}")] + Storage(#[source] S), + + #[error(transparent)] + Client(#[from] torn_api::Error), +} + +#[derive(Debug, Clone, Copy)] +pub enum KeyDomain { + Public, + User(i32), + Faction(i32), +} + +pub trait ApiKey { + fn value(&self) -> &str; +} + +#[async_trait(?Send)] +pub trait KeyPoolStorage { + type Key: ApiKey; + type Err: std::error::Error; + + async fn acquire_key(&self, domain: KeyDomain) -> Result; + + async fn flag_key(&self, key: Self::Key, code: u8) -> Result; +} + +#[derive(Debug, Clone)] +pub struct KeyPoolExecutor<'client, C, S> +where + C: ApiClient, + S: KeyPoolStorage, +{ + client: &'client C, + storage: &'client S, + domain: KeyDomain, +} + +impl<'client, C, S> KeyPoolExecutor<'client, C, S> +where + C: ApiClient, + S: KeyPoolStorage, +{ + pub fn new(client: &'client C, storage: &'client S, domain: KeyDomain) -> Self { + Self { + client, + storage, + domain, + } + } +} + +#[async_trait(?Send)] +impl<'client, C, S> ApiRequestExecutor<'client> for KeyPoolExecutor<'client, C, S> +where + C: ApiClient, + S: KeyPoolStorage + 'static, +{ + type Err = KeyPoolError; + + async fn excute(&self, request: torn_api::ApiRequest) -> Result + where + A: torn_api::ApiCategoryResponse, + { + loop { + let key = self + .storage + .acquire_key(self.domain) + .await + .map_err(KeyPoolError::Storage)?; + let url = request.url(key.value()); + let res = self.client.request(url).await; + + match res { + Err(torn_api::Error::Api { code, .. }) => { + if !self + .storage + .flag_key(key, code) + .await + .map_err(KeyPoolError::Storage)? + { + panic!(); + } + } + _ => return res.map(A::from_response).map_err(KeyPoolError::Client), + }; + } + } +} + +#[derive(Clone, Debug)] +pub struct KeyPool +where + C: ApiClient, + S: KeyPoolStorage, +{ + client: C, + storage: S, +} + +impl KeyPool +where + C: ApiClient, + S: KeyPoolStorage, +{ + pub fn new(client: C, storage: S) -> Self { + Self { client, storage } + } + + pub fn torn_api(&self, domain: KeyDomain) -> KeyPoolExecutor { + KeyPoolExecutor::new(&self.client, &self.storage, domain) + } +} diff --git a/torn-key-pool/src/postgres.rs b/torn-key-pool/src/postgres.rs new file mode 100644 index 0000000..8ab0c53 --- /dev/null +++ b/torn-key-pool/src/postgres.rs @@ -0,0 +1,196 @@ +use async_trait::async_trait; +use chrono::{DateTime, Utc}; +use indoc::indoc; +use sqlx::{FromRow, PgPool}; +use thiserror::Error; + +use crate::{ApiKey, KeyDomain, KeyPool, KeyPoolStorage}; + +#[derive(Debug, Error)] +pub enum PgStorageError { + #[error(transparent)] + Pg(#[from] sqlx::Error), + + #[error("No key avalaible for domain {0:?}")] + Unavailable(KeyDomain), +} + +#[derive(Debug, Clone, FromRow)] +pub struct PgKey { + pub id: i32, + pub user_id: i32, + pub faction_id: Option, + pub key: String, + pub uses: i16, + pub user: bool, + pub faction: bool, + pub last_used: DateTime, +} + +impl ApiKey for PgKey { + fn value(&self) -> &str { + &self.key + } +} + +#[derive(Debug, Clone, FromRow)] +pub struct PgKeyPoolStorage { + pool: PgPool, + limit: i16, +} + +impl PgKeyPoolStorage { + pub fn new(pool: PgPool, limit: i16) -> Self { + Self { pool, limit } + } + + pub async fn initialise(&self) -> Result<(), PgStorageError> { + sqlx::query(indoc! {r#" + CREATE TABLE IF NOT EXISTS api_keys ( + id serial primary key, + user_id int4 not null, + faction_id int4, + key char(16) not null, + uses int2 not null default 0, + "user" bool not null, + faction bool not null, + last_used timestamptz not null default now() + )"#}) + .execute(&self.pool) + .await?; + + Ok(()) + } +} + +#[async_trait(?Send)] +impl KeyPoolStorage for PgKeyPoolStorage { + type Key = PgKey; + + type Err = PgStorageError; + + async fn acquire_key(&self, domain: KeyDomain) -> Result { + let predicate = match domain { + KeyDomain::Public => "".to_owned(), + KeyDomain::User(id) => format!("where and user_id={} and user", id), + KeyDomain::Faction(id) => format!("where and faction_id={} and faction", id), + }; + let key: Option = sqlx::query_as(&indoc::formatdoc!( + r#" + with key as ( + select + id, + user_id, + faction_id, + key, + case + when extract(minute from last_used)=extract(minute from now()) then uses + else 0::smallint + end as uses, + user, + faction, + last_used + from api_keys {} + order by last_used asc limit 1 FOR UPDATE + ) + update api_keys set + uses = key.uses + 1, + last_used = now() + from key where + api_keys.id=key.id and key.uses < $1 + returning + api_keys.id, + api_keys.user_id, + api_keys.faction_id, + api_keys.key, + api_keys.uses, + api_keys.user, + api_keys.faction, + api_keys.last_used + "#, + predicate + )) + .bind(self.limit) + .fetch_optional(&self.pool) + .await?; + + key.ok_or(PgStorageError::Unavailable(domain)) + } + + async fn flag_key(&self, key: Self::Key, code: u8) -> Result { + // TODO: put keys in cooldown when appropriate + match code { + 2 | 10 | 13 => { + sqlx::query("delete from api_keys where id=$1") + .bind(key.id) + .execute(&self.pool) + .await?; + Ok(true) + } + 9 => Ok(false), + _ => Ok(true), + } + } +} + +pub type PgKeyPool = KeyPool; + +impl PgKeyPool +where + A: torn_api::ApiClient, +{ + pub async fn connect( + client: A, + database_url: &str, + limit: i16, + ) -> Result { + let db_pool = PgPool::connect(database_url).await?; + let storage = PgKeyPoolStorage::new(db_pool, limit); + storage.initialise().await?; + + let key_pool = Self::new(client, storage); + + Ok(key_pool) + } +} + +#[cfg(test)] +mod test { + use std::sync::Once; + + use tokio::test; + + use super::*; + + static INIT: Once = Once::new(); + + pub(crate) async fn setup() -> PgKeyPoolStorage { + INIT.call_once(|| { + dotenv::dotenv().ok(); + }); + + let pool = PgPool::connect(&std::env::var("DATABASE_URL").unwrap()) + .await + .unwrap(); + + PgKeyPoolStorage::new(pool, 3) + } + + #[test] + async fn test_initialise() { + let storage = setup().await; + + if let Err(e) = storage.initialise().await { + panic!("Initialising key storage failed: {:?}", e); + } + } + + #[test] + async fn acquire_one() { + let storage = setup().await; + + if let Err(e) = storage.acquire_key(KeyDomain::Public).await { + panic!("Acquiring key failed: {:?}", e); + } + } +}