Compare commits

...

35 commits

Author SHA1 Message Date
pyrite 485c2ea176
chore(torn-api): updated spec 2025-06-27 17:01:50 +02:00
pyrite cf98d24090
feat(codegen): various improvements to robustness 2025-06-27 16:59:38 +02:00
pyrite a90bcb00c4
chore(torn-api): update schema 2025-05-28 19:51:32 +02:00
pyrite 1c9b4123c3
fix(api): removed missing properties from personalstats 2025-05-28 11:13:27 +02:00
pyrite e5a766b893
fix(torn-api): fix popular personalstats 2025-05-28 11:01:45 +02:00
pyrite 45899430bb
chore(torn-api): changed personalstats order 2025-05-28 10:43:44 +02:00
pyrite bd27916aa5
chore: update schema and release versions 2025-05-28 10:20:55 +02:00
pyrite 98073a37bd
fix(torn-api): fix request parameter encoding 2025-05-28 10:19:43 +02:00
pyrite 40913bc89b
fix(codegen): hacky fix for colliding enum names 2025-05-28 10:19:05 +02:00
pyrite 14e6e81278
chore: versions 2025-05-27 19:57:17 +02:00
pyrite 3ad92fb8c8
feat(codegen): derive Eq and Hash for most enum types 2025-05-27 19:56:03 +02:00
pyrite 1af37bea89
chore: release versions 2025-05-27 19:31:50 +02:00
pyrite 39731f2f5d
feat(codegen): implemented oneOf unions for primitive types 2025-05-27 19:27:59 +02:00
pyrite 83dfdb27ac
chore(codegen): release version 2025-05-27 19:27:59 +02:00
pyrite 6aaa06f501
chore(codegen): implemented Eq for OpenApiSchema 2025-05-27 19:27:59 +02:00
pyrite c8bdcc81c4
chore(key-pool): expose inner storage and client 2025-05-27 19:27:58 +02:00
pyrite 8bfa1b8ac3
feat(torn-api): added optional strum feature 2025-05-27 19:27:58 +02:00
pyrite 56e64470de
chore: release versions 2025-05-27 19:27:58 +02:00
pyrite 6d57f275a2
chore: moved schema file to torn-api crate 2025-05-27 19:27:58 +02:00
pyrite 11c5d71bf6
chore(torn-api): release version 1.3.0 2025-05-27 19:27:57 +02:00
pyrite eb6e98f41b
chore(codegen): release version 2025-05-27 19:27:57 +02:00
pyrite 7a4f6462f5
feat(torn-api): add chrono for datetime support 2025-05-27 19:27:57 +02:00
pyrite 266122ea0e
chore(codegen): publish version 2025-05-27 19:27:57 +02:00
pyrite 47461b61b2
fix(codegen): fixed codegen for array path parameters 2025-05-27 19:27:56 +02:00
pyrite b245e3e712
chore(key-pool): release version 2025-05-27 19:27:56 +02:00
TotallyNot 73358b70cc
chore: updated schemas 2025-05-19 20:09:38 +02:00
TotallyNot b4ce0c764e
feat: allow arbitrary deserialisation from unions 2025-05-01 16:10:24 +02:00
TotallyNot 7bc61de1c2
feat: simplified lifetime bounds on bulk executor 2025-04-29 22:46:43 +02:00
TotallyNot c17f93f600
feat: implemented bulk requests 2025-04-29 18:27:42 +02:00
TotallyNot 4dd4fd37d4
feat(core): allow optionally disabling expensive codegen 2025-04-27 15:23:52 +02:00
TotallyNot 26043ac318
fix: changed rust edition to 2021 2025-04-27 13:47:22 +02:00
TotallyNot b069c7e493
fix(code-gen): hack to fix resolution of malformed number property 2025-04-27 11:11:30 +02:00
TotallyNot 40b784cf57
feat(key-pool): allow setting schema for key table 2025-04-27 11:02:33 +02:00
TotallyNot c5651efbb0
chore(codegen): updated spec 2025-04-27 10:59:54 +02:00
TotallyNot 4b52c37774
feat(key-pool): updated key pool to use v2 api 2025-04-25 17:23:43 +02:00
26 changed files with 15137 additions and 7839 deletions

1366
Cargo.lock generated

File diff suppressed because it is too large Load diff

View file

@ -1,11 +1,11 @@
[workspace] [workspace]
resolver = "2" resolver = "2"
members = ["torn-api", "torn-api-codegen"] members = ["torn-api", "torn-api-codegen", "torn-key-pool"]
[workspace.package] [workspace.package]
license-file = "./LICENSE" license-file = "./LICENSE"
repository = "https://github.com/TotallyNot/torn-api.rs.git" repository = "https://git.elimination.me/pyrite/torn-api.rs.git"
homepage = "https://github.com/TotallyNot/torn-api.rs.git" homepage = "https://git.elimination.me/pyrite/torn-api.rs"
[workspace.dependencies] [workspace.dependencies]
serde = { version = "1", features = ["derive"] } serde = { version = "1", features = ["derive"] }

View file

@ -1,8 +1,8 @@
[package] [package]
name = "torn-api-codegen" name = "torn-api-codegen"
authors = ["Pyrit [2111649]"] authors = ["Pyrit [2111649]"]
version = "0.1.1" version = "0.7.0"
edition = "2024" edition = "2021"
description = "Contains the v2 torn API model descriptions and codegen for the bindings" description = "Contains the v2 torn API model descriptions and codegen for the bindings"
license-file = { workspace = true } license-file = { workspace = true }
repository = { workspace = true } repository = { workspace = true }

View file

@ -1,12 +1,15 @@
use heck::ToUpperCamelCase; use heck::{ToSnakeCase, ToUpperCamelCase};
use proc_macro2::TokenStream; use proc_macro2::TokenStream;
use quote::{format_ident, quote}; use quote::{format_ident, quote};
use syn::Ident;
use crate::openapi::{ use crate::openapi::{
parameter::OpenApiParameterSchema, parameter::OpenApiParameterSchema,
r#type::{OpenApiType, OpenApiVariants}, r#type::{OpenApiType, OpenApiVariants},
}; };
use super::{object::PrimitiveType, Model, ResolvedSchema};
#[derive(Debug, Clone, Copy, PartialEq, Eq)] #[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EnumRepr { pub enum EnumRepr {
U8, U8,
@ -17,10 +20,12 @@ pub enum EnumRepr {
pub enum EnumVariantTupleValue { pub enum EnumVariantTupleValue {
Ref { ty_name: String }, Ref { ty_name: String },
ArrayOfRefs { ty_name: String }, ArrayOfRefs { ty_name: String },
Primitive(PrimitiveType),
Enum { name: String, inner: Enum },
} }
impl EnumVariantTupleValue { impl EnumVariantTupleValue {
pub fn from_schema(schema: &OpenApiType) -> Option<Self> { pub fn from_schema(name: &str, schema: &OpenApiType) -> Option<Self> {
match schema { match schema {
OpenApiType { OpenApiType {
ref_path: Some(path), ref_path: Some(path),
@ -44,14 +49,66 @@ impl EnumVariantTupleValue {
ty_name: path.strip_prefix("#/components/schemas/")?.to_owned(), ty_name: path.strip_prefix("#/components/schemas/")?.to_owned(),
}) })
} }
OpenApiType {
r#type: Some("string"),
format: None,
r#enum: None,
..
} => Some(Self::Primitive(PrimitiveType::String)),
OpenApiType {
r#type: Some("string"),
format: None,
r#enum: Some(_),
..
} => {
let name = format!("{name}Variant");
Some(Self::Enum {
inner: Enum::from_schema(&name, schema)?,
name,
})
}
OpenApiType {
r#type: Some("integer"),
format: Some("int64"),
..
} => Some(Self::Primitive(PrimitiveType::I64)),
OpenApiType {
r#type: Some("integer"),
format: Some("int32"),
..
} => Some(Self::Primitive(PrimitiveType::I32)),
OpenApiType {
r#type: Some("number"),
format: Some("float") | None,
..
} => Some(Self::Primitive(PrimitiveType::Float)),
_ => None, _ => None,
} }
} }
pub fn type_name(&self) -> &str { pub fn type_name(&self, ns: &mut EnumNamespace) -> TokenStream {
match self { match self {
Self::Ref { ty_name } => ty_name, Self::Ref { ty_name } => {
Self::ArrayOfRefs { ty_name } => ty_name, let ty = format_ident!("{ty_name}");
quote! { crate::models::#ty }
}
Self::ArrayOfRefs { ty_name } => {
let ty = format_ident!("{ty_name}");
quote! { Vec<crate::models::#ty> }
}
Self::Primitive(PrimitiveType::I64) => quote! { i64 },
Self::Primitive(PrimitiveType::I32) => quote! { i32 },
Self::Primitive(PrimitiveType::Float) => quote! { f32 },
Self::Primitive(PrimitiveType::String) => quote! { String },
Self::Primitive(PrimitiveType::DateTime) => quote! { chrono::DateTime<chrono::Utc> },
Self::Primitive(PrimitiveType::Bool) => quote! { bool },
Self::Enum { name, .. } => {
let path = ns.get_ident();
let ty_name = format_ident!("{name}");
quote! {
#path::#ty_name,
}
}
} }
} }
@ -59,6 +116,49 @@ impl EnumVariantTupleValue {
match self { match self {
Self::Ref { ty_name } => ty_name.clone(), Self::Ref { ty_name } => ty_name.clone(),
Self::ArrayOfRefs { ty_name } => format!("{ty_name}s"), Self::ArrayOfRefs { ty_name } => format!("{ty_name}s"),
Self::Primitive(PrimitiveType::I64) => "I64".to_owned(),
Self::Primitive(PrimitiveType::I32) => "I32".to_owned(),
Self::Primitive(PrimitiveType::Float) => "Float".to_owned(),
Self::Primitive(PrimitiveType::String) => "String".to_owned(),
Self::Primitive(PrimitiveType::DateTime) => "DateTime".to_owned(),
Self::Primitive(PrimitiveType::Bool) => "Bool".to_owned(),
Self::Enum { .. } => "Variant".to_owned(),
}
}
pub fn is_display(&self, resolved: &ResolvedSchema) -> bool {
match self {
Self::Primitive(_) => true,
Self::Ref { ty_name } | Self::ArrayOfRefs { ty_name } => resolved
.models
.get(ty_name)
.map(|f| f.is_display(resolved))
.unwrap_or_default(),
Self::Enum { inner, .. } => inner.is_display(resolved),
}
}
pub fn codegen_display(&self) -> TokenStream {
match self {
Self::ArrayOfRefs { .. } => quote! {
write!(f, "{}", value.iter().map(ToString::to_string).collect::<Vec<_>>().join(","))
},
_ => quote! {
write!(f, "{}", value)
},
}
}
pub fn is_comparable(&self, resolved: &ResolvedSchema) -> bool {
match self {
Self::Primitive(PrimitiveType::Float) => false,
Self::Primitive(_) => true,
Self::Enum { inner, .. } => inner.is_comparable(resolved),
Self::Ref { ty_name } | Self::ArrayOfRefs { ty_name } => resolved
.models
.get(ty_name)
.map(|m| matches!(m, Model::Newtype(_)))
.unwrap_or_default(),
} }
} }
} }
@ -77,12 +177,39 @@ impl Default for EnumVariantValue {
} }
impl EnumVariantValue { impl EnumVariantValue {
pub fn codegen_display(&self, name: &str) -> Option<TokenStream> { pub fn is_display(&self, resolved: &ResolvedSchema) -> bool {
match self { match self {
Self::Repr(i) => Some(quote! { write!(f, "{}", #i) }), Self::Repr(_) | Self::String { .. } => true,
Self::Tuple(val) => {
val.len() == 1
&& val
.iter()
.next()
.map(|v| v.is_display(resolved))
.unwrap_or_default()
}
}
}
pub fn is_comparable(&self, resolved: &ResolvedSchema) -> bool {
match self {
Self::Repr(_) | Self::String { .. } => true,
Self::Tuple(values) => values.iter().all(|v| v.is_comparable(resolved)),
}
}
pub fn codegen_display(&self, name: &str) -> Option<TokenStream> {
let variant = format_ident!("{name}");
match self {
Self::Repr(i) => Some(quote! { Self::#variant => write!(f, "{}", #i) }),
Self::String { rename } => { Self::String { rename } => {
let name = rename.as_deref().unwrap_or(name); let name = rename.as_deref().unwrap_or(name);
Some(quote! { write!(f, #name) }) Some(quote! { Self::#variant => write!(f, #name) })
}
Self::Tuple(values) if values.len() == 1 => {
let rhs = values.first().unwrap().codegen_display();
Some(quote! { Self::#variant(value) => #rhs })
} }
_ => None, _ => None,
} }
@ -96,8 +223,61 @@ pub struct EnumVariant {
pub value: EnumVariantValue, pub value: EnumVariantValue,
} }
pub struct EnumNamespace<'e> {
r#enum: &'e Enum,
ident: Option<Ident>,
elements: Vec<TokenStream>,
top_level_elements: Vec<TokenStream>,
}
impl EnumNamespace<'_> {
pub fn get_ident(&mut self) -> Ident {
self.ident
.get_or_insert_with(|| {
let name = self.r#enum.name.to_snake_case();
format_ident!("{name}")
})
.clone()
}
pub fn push_element(&mut self, el: TokenStream) {
self.elements.push(el);
}
pub fn push_top_level(&mut self, el: TokenStream) {
self.top_level_elements.push(el);
}
pub fn codegen(mut self) -> Option<TokenStream> {
if self.elements.is_empty() && self.top_level_elements.is_empty() {
None
} else {
let top_level = &self.top_level_elements;
let mut output = quote! {
#(#top_level)*
};
if !self.elements.is_empty() {
let ident = self.get_ident();
let elements = self.elements;
output.extend(quote! {
pub mod #ident {
#(#elements)*
}
});
}
Some(output)
}
}
}
impl EnumVariant { impl EnumVariant {
pub fn codegen(&self) -> Option<TokenStream> { pub fn codegen(
&self,
ns: &mut EnumNamespace,
resolved: &ResolvedSchema,
) -> Option<TokenStream> {
let doc = self.description.as_ref().map(|d| { let doc = self.description.as_ref().map(|d| {
quote! { quote! {
#[doc = #d] #[doc = #d]
@ -127,15 +307,29 @@ impl EnumVariant {
EnumVariantValue::Tuple(values) => { EnumVariantValue::Tuple(values) => {
let mut val_tys = Vec::with_capacity(values.len()); let mut val_tys = Vec::with_capacity(values.len());
for value in values { if let [value] = values.as_slice() {
let ty_name = value.type_name(); let enum_name = format_ident!("{}", ns.r#enum.name);
let ty_name = format_ident!("{ty_name}"); let ty_name = value.type_name(ns);
val_tys.push(quote! { ns.push_top_level(quote! {
crate::models::#ty_name impl From<#ty_name> for #enum_name {
fn from(value: #ty_name) -> Self {
Self::#name(value)
}
}
}); });
} }
for value in values {
let ty_name = value.type_name(ns);
if let EnumVariantTupleValue::Enum { inner, .. } = &value {
ns.push_element(inner.codegen(resolved)?);
}
val_tys.push(ty_name);
}
Some(quote! { Some(quote! {
#name(#(#val_tys),*) #name(#(#val_tys),*)
}) })
@ -144,12 +338,7 @@ impl EnumVariant {
} }
pub fn codegen_display(&self) -> Option<TokenStream> { pub fn codegen_display(&self) -> Option<TokenStream> {
let rhs = self.value.codegen_display(&self.name)?; self.value.codegen_display(&self.name)
let name = format_ident!("{}", self.name);
Some(quote! {
Self::#name => #rhs
})
} }
} }
@ -159,7 +348,6 @@ pub struct Enum {
pub description: Option<String>, pub description: Option<String>,
pub repr: Option<EnumRepr>, pub repr: Option<EnumRepr>,
pub copy: bool, pub copy: bool,
pub display: bool,
pub untagged: bool, pub untagged: bool,
pub variants: Vec<EnumVariant>, pub variants: Vec<EnumVariant>,
} }
@ -176,7 +364,6 @@ impl Enum {
match &schema.r#enum { match &schema.r#enum {
Some(OpenApiVariants::Int(int_variants)) => { Some(OpenApiVariants::Int(int_variants)) => {
result.repr = Some(EnumRepr::U32); result.repr = Some(EnumRepr::U32);
result.display = true;
result.variants = int_variants result.variants = int_variants
.iter() .iter()
.copied() .copied()
@ -188,7 +375,6 @@ impl Enum {
.collect(); .collect();
} }
Some(OpenApiVariants::Str(str_variants)) => { Some(OpenApiVariants::Str(str_variants)) => {
result.display = true;
result.variants = str_variants result.variants = str_variants
.iter() .iter()
.copied() .copied()
@ -214,7 +400,6 @@ impl Enum {
let mut result = Self { let mut result = Self {
name: name.to_owned(), name: name.to_owned(),
copy: true, copy: true,
display: true,
..Default::default() ..Default::default()
}; };
@ -240,7 +425,7 @@ impl Enum {
}; };
for schema in schemas { for schema in schemas {
let value = EnumVariantTupleValue::from_schema(schema)?; let value = EnumVariantTupleValue::from_schema(name, schema)?;
let name = value.name(); let name = value.name();
result.variants.push(EnumVariant { result.variants.push(EnumVariant {
@ -250,10 +435,39 @@ impl Enum {
}); });
} }
// HACK: idk
let shared: Vec<_> = result
.variants
.iter_mut()
.filter(|v| v.name == "Variant")
.collect();
if shared.len() >= 2 {
for (idx, variant) in shared.into_iter().enumerate() {
let label = idx + 1;
variant.name = format!("Variant{}", label);
if let EnumVariantValue::Tuple(values) = &mut variant.value {
if let [EnumVariantTupleValue::Enum { name, inner, .. }] = values.as_mut_slice()
{
inner.name.push_str(&label.to_string());
name.push_str(&label.to_string());
}
}
}
}
Some(result) Some(result)
} }
pub fn codegen(&self) -> Option<TokenStream> { pub fn is_display(&self, resolved: &ResolvedSchema) -> bool {
self.variants.iter().all(|v| v.value.is_display(resolved))
}
pub fn is_comparable(&self, resolved: &ResolvedSchema) -> bool {
self.variants
.iter()
.all(|v| v.value.is_comparable(resolved))
}
pub fn codegen(&self, resolved: &ResolvedSchema) -> Option<TokenStream> {
let repr = self.repr.map(|r| match r { let repr = self.repr.map(|r| match r {
EnumRepr::U8 => quote! { #[repr(u8)] }, EnumRepr::U8 => quote! { #[repr(u8)] },
EnumRepr::U32 => quote! { #[repr(u32)] }, EnumRepr::U32 => quote! { #[repr(u32)] },
@ -266,12 +480,21 @@ impl Enum {
} }
}); });
let mut ns = EnumNamespace {
r#enum: self,
ident: None,
elements: Default::default(),
top_level_elements: Default::default(),
};
let is_display = self.is_display(resolved);
let mut display = Vec::with_capacity(self.variants.len()); let mut display = Vec::with_capacity(self.variants.len());
let mut variants = Vec::with_capacity(self.variants.len()); let mut variants = Vec::with_capacity(self.variants.len());
for variant in &self.variants { for variant in &self.variants {
variants.push(variant.codegen()?); variants.push(variant.codegen(&mut ns, resolved)?);
if self.display { if is_display {
display.push(variant.codegen_display()?); display.push(variant.codegen_display()?);
} }
} }
@ -285,7 +508,11 @@ impl Enum {
} }
if self.copy { if self.copy {
derives.push(quote! { Copy, Hash }); derives.push(quote! { Copy });
}
if self.is_comparable(resolved) {
derives.push(quote! { Eq, Hash });
} }
let serde_attr = self.untagged.then(|| { let serde_attr = self.untagged.then(|| {
@ -294,7 +521,7 @@ impl Enum {
} }
}); });
let display = self.display.then(|| { let display = is_display.then(|| {
quote! { quote! {
impl std::fmt::Display for #name { impl std::fmt::Display for #name {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
@ -306,34 +533,35 @@ impl Enum {
} }
}); });
let module = ns.codegen();
Some(quote! { Some(quote! {
#desc #desc
#[derive(Debug, Clone, PartialEq, #(#derives),*)] #[derive(Debug, Clone, PartialEq, #(#derives),*)]
#[cfg_attr(feature = "strum", derive(strum::EnumIs, strum::EnumTryAs))]
#serde_attr #serde_attr
pub enum #name { pub enum #name {
#(#variants),* #(#variants),*
} }
#display #display
#module
}) })
} }
} }
#[cfg(test)] #[cfg(test)]
mod test { mod test {
use crate::openapi::schema::OpenApiSchema;
use super::*; use super::*;
use crate::openapi::schema::test::get_schema;
#[test] #[test]
fn codegen() { fn is_display() {
let schema = OpenApiSchema::read().unwrap(); let schema = get_schema();
let resolved = ResolvedSchema::from_open_api(&schema);
let revive_setting = schema.components.schemas.get("ReviveSetting").unwrap(); let torn_selection_name = resolved.models.get("TornSelectionName").unwrap();
assert!(torn_selection_name.is_display(&resolved));
let r#enum = Enum::from_schema("ReviveSetting", revive_setting).unwrap();
let code = r#enum.codegen().unwrap();
panic!("{code}");
} }
} }

View file

@ -1,10 +1,15 @@
use r#enum::Enum; use std::{cell::RefCell, rc::Rc};
use indexmap::IndexMap; use indexmap::IndexMap;
use newtype::Newtype; use newtype::Newtype;
use object::Object; use object::Object;
use parameter::Parameter;
use path::{Path, PrettySegments};
use proc_macro2::TokenStream; use proc_macro2::TokenStream;
use r#enum::Enum;
use scope::Scope;
use crate::openapi::r#type::OpenApiType; use crate::openapi::{r#type::OpenApiType, schema::OpenApiSchema};
pub mod r#enum; pub mod r#enum;
pub mod newtype; pub mod newtype;
@ -22,7 +27,169 @@ pub enum Model {
Unresolved, Unresolved,
} }
pub fn resolve(r#type: &OpenApiType, name: &str, schemas: &IndexMap<&str, OpenApiType>) -> Model { impl Model {
pub fn is_display(&self, resolved: &ResolvedSchema) -> bool {
match self {
Self::Enum(r#enum) => r#enum.is_display(resolved),
Self::Newtype(_) => true,
_ => false,
}
}
}
#[derive(Default)]
pub struct ResolvedSchema {
pub models: IndexMap<String, Model>,
pub paths: IndexMap<String, Path>,
pub parameters: Vec<Parameter>,
pub warnings: WarningReporter,
}
#[derive(Clone)]
pub struct Warning {
pub message: String,
pub path: Vec<String>,
}
impl std::fmt::Display for Warning {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "in {}: {}", self.path.join("."), self.message)
}
}
#[derive(Default)]
struct WarningReporterState {
warnings: Vec<Warning>,
path: Vec<String>,
}
#[derive(Clone, Default)]
pub struct WarningReporter {
state: Rc<RefCell<WarningReporterState>>,
}
impl WarningReporter {
pub fn new() -> Self {
Self::default()
}
fn push(&self, message: impl ToString) {
let mut state = self.state.borrow_mut();
let path = state.path.iter().map(ToString::to_string).collect();
state.warnings.push(Warning {
message: message.to_string(),
path,
});
}
fn child(&self, name: impl ToString) -> WarningReporter {
self.state.borrow_mut().path.push(name.to_string());
Self {
state: self.state.clone(),
}
}
pub fn is_empty(&self) -> bool {
self.state.borrow().warnings.is_empty()
}
pub fn get_warnings(&self) -> Vec<Warning> {
self.state.borrow().warnings.clone()
}
}
impl Drop for WarningReporter {
fn drop(&mut self) {
self.state.borrow_mut().path.pop();
}
}
impl ResolvedSchema {
pub fn from_open_api(schema: &OpenApiSchema) -> Self {
let mut result = Self::default();
for (name, r#type) in &schema.components.schemas {
result.models.insert(
name.to_string(),
resolve(r#type, name, &schema.components.schemas, &result.warnings),
);
}
for (path, body) in &schema.paths {
result.paths.insert(
path.to_string(),
Path::from_schema(
path,
body,
&schema.components.parameters,
result.warnings.child(path),
)
.unwrap(),
);
}
for (name, param) in &schema.components.parameters {
result
.parameters
.push(Parameter::from_schema(name, param).unwrap());
}
result
}
pub fn codegen_models(&self) -> TokenStream {
let mut output = TokenStream::default();
for model in self.models.values() {
output.extend(model.codegen(self));
}
output
}
pub fn codegen_requests(&self) -> TokenStream {
let mut output = TokenStream::default();
for path in self.paths.values() {
output.extend(
path.codegen_request(self, self.warnings.child(PrettySegments(&path.segments))),
);
}
output
}
pub fn codegen_parameters(&self) -> TokenStream {
let mut output = TokenStream::default();
for param in &self.parameters {
output.extend(param.codegen(self));
}
output
}
pub fn codegen_scopes(&self) -> TokenStream {
let mut output = TokenStream::default();
let scopes = Scope::from_paths(self.paths.values().cloned().collect());
for scope in scopes {
output.extend(scope.codegen());
}
output
}
}
pub fn resolve(
r#type: &OpenApiType,
name: &str,
schemas: &IndexMap<&str, OpenApiType>,
warnings: &WarningReporter,
) -> Model {
match r#type { match r#type {
OpenApiType { OpenApiType {
r#enum: Some(_), .. r#enum: Some(_), ..
@ -30,8 +197,12 @@ pub fn resolve(r#type: &OpenApiType, name: &str, schemas: &IndexMap<&str, OpenAp
OpenApiType { OpenApiType {
r#type: Some("object"), r#type: Some("object"),
.. ..
} => Object::from_schema_object(name, r#type, schemas) } => Model::Object(Object::from_schema_object(
.map_or(Model::Unresolved, Model::Object), name,
r#type,
schemas,
warnings.child(name),
)),
OpenApiType { OpenApiType {
r#type: Some(_), .. r#type: Some(_), ..
} => Newtype::from_schema(name, r#type).map_or(Model::Unresolved, Model::Newtype), } => Newtype::from_schema(name, r#type).map_or(Model::Unresolved, Model::Newtype),
@ -42,17 +213,22 @@ pub fn resolve(r#type: &OpenApiType, name: &str, schemas: &IndexMap<&str, OpenAp
OpenApiType { OpenApiType {
all_of: Some(types), all_of: Some(types),
.. ..
} => Object::from_all_of(name, types, schemas).map_or(Model::Unresolved, Model::Object), } => Model::Object(Object::from_all_of(
name,
types,
schemas,
warnings.child(name),
)),
_ => Model::Unresolved, _ => Model::Unresolved,
} }
} }
impl Model { impl Model {
pub fn codegen(&self) -> Option<TokenStream> { pub fn codegen(&self, resolved: &ResolvedSchema) -> Option<TokenStream> {
match self { match self {
Self::Newtype(newtype) => newtype.codegen(), Self::Newtype(newtype) => newtype.codegen(),
Self::Enum(r#enum) => r#enum.codegen(), Self::Enum(r#enum) => r#enum.codegen(resolved),
Self::Object(object) => object.codegen(), Self::Object(object) => object.codegen(resolved),
Self::Unresolved => None, Self::Unresolved => None,
} }
} }
@ -61,18 +237,22 @@ impl Model {
#[cfg(test)] #[cfg(test)]
mod test { mod test {
use super::*; use super::*;
use crate::{ use crate::openapi::schema::test::get_schema;
model::r#enum::{EnumRepr, EnumVariant},
openapi::schema::OpenApiSchema,
};
#[test] #[test]
fn resolve_newtypes() { fn resolve_newtypes() {
let schema = OpenApiSchema::read().unwrap(); let schema = get_schema();
let user_id_schema = schema.components.schemas.get("UserId").unwrap(); let user_id_schema = schema.components.schemas.get("UserId").unwrap();
let user_id = resolve(user_id_schema, "UserId", &schema.components.schemas); let reporter = WarningReporter::new();
let user_id = resolve(
user_id_schema,
"UserId",
&schema.components.schemas,
&reporter,
);
assert!(reporter.is_empty());
assert_eq!( assert_eq!(
user_id, user_id,
@ -87,7 +267,13 @@ mod test {
let attack_code_schema = schema.components.schemas.get("AttackCode").unwrap(); let attack_code_schema = schema.components.schemas.get("AttackCode").unwrap();
let attack_code = resolve(attack_code_schema, "AttackCode", &schema.components.schemas); let attack_code = resolve(
attack_code_schema,
"AttackCode",
&schema.components.schemas,
&reporter,
);
assert!(reporter.is_empty());
assert_eq!( assert_eq!(
attack_code, attack_code,
@ -101,74 +287,18 @@ mod test {
); );
} }
#[test]
fn resolve_enums() {
let schema = OpenApiSchema::read().unwrap();
let forum_feed_type_schema = schema.components.schemas.get("ForumFeedTypeEnum").unwrap();
let forum_feed_type = resolve(
forum_feed_type_schema,
"ForumFeedTypeEnum",
&schema.components.schemas,
);
assert_eq!(forum_feed_type, Model::Enum(Enum {
name: "ForumFeedType".to_owned(),
description: Some("This represents the type of the activity. Values range from 1 to 8 where:\n * 1 = 'X posted on a thread',\n * 2 = 'X created a thread',\n * 3 = 'X liked your thread',\n * 4 = 'X disliked your thread',\n * 5 = 'X liked your post',\n * 6 = 'X disliked your post',\n * 7 = 'X quoted your post'.".to_owned()),
repr: Some(EnumRepr::U32),
copy: true,
untagged: false,
display: true,
variants: vec![
EnumVariant {
name: "Variant1".to_owned(),
value: r#enum::EnumVariantValue::Repr(1),
..Default::default()
},
EnumVariant {
name: "Variant2".to_owned(),
value: r#enum::EnumVariantValue::Repr(2),
..Default::default()
},
EnumVariant {
name: "Variant3".to_owned(),
value: r#enum::EnumVariantValue::Repr(3),
..Default::default()
},
EnumVariant {
name: "Variant4".to_owned(),
value: r#enum::EnumVariantValue::Repr(4),
..Default::default()
},
EnumVariant {
name: "Variant5".to_owned(),
value: r#enum::EnumVariantValue::Repr(5),
..Default::default()
},
EnumVariant {
name: "Variant6".to_owned(),
value: r#enum::EnumVariantValue::Repr(6),
..Default::default()
},
EnumVariant {
name: "Variant7".to_owned(),
value: r#enum::EnumVariantValue::Repr(7),
..Default::default()
},
]
}))
}
#[test] #[test]
fn resolve_all() { fn resolve_all() {
let schema = OpenApiSchema::read().unwrap(); let schema = get_schema();
let mut unresolved = vec![]; let mut unresolved = vec![];
let total = schema.components.schemas.len(); let total = schema.components.schemas.len();
for (name, desc) in &schema.components.schemas { for (name, desc) in &schema.components.schemas {
if resolve(desc, name, &schema.components.schemas) == Model::Unresolved { let reporter = WarningReporter::new();
if resolve(desc, name, &schema.components.schemas, &reporter) == Model::Unresolved
|| !reporter.is_empty()
{
unresolved.push(name); unresolved.push(name);
} }
} }

View file

@ -121,24 +121,3 @@ impl Newtype {
Some(body) Some(body)
} }
} }
#[cfg(test)]
mod test {
use super::*;
use crate::openapi::schema::OpenApiSchema;
#[test]
fn codegen() {
let schema = OpenApiSchema::read().unwrap();
let user_id = schema.components.schemas.get("UserId").unwrap();
let mut newtype = Newtype::from_schema("UserId", user_id).unwrap();
newtype.description = Some("Description goes here".to_owned());
let code = newtype.codegen().unwrap().to_string();
panic!("{code}");
}
}

View file

@ -1,12 +1,12 @@
use heck::{ToSnakeCase, ToUpperCamelCase}; use heck::{ToSnakeCase, ToUpperCamelCase};
use indexmap::IndexMap; use indexmap::{map::Entry, IndexMap};
use proc_macro2::TokenStream; use proc_macro2::TokenStream;
use quote::{ToTokens, format_ident, quote}; use quote::{format_ident, quote, ToTokens};
use syn::Ident; use syn::Ident;
use crate::openapi::r#type::OpenApiType; use crate::openapi::r#type::OpenApiType;
use super::r#enum::Enum; use super::{r#enum::Enum, ResolvedSchema, WarningReporter};
#[derive(Debug, Clone, Copy, PartialEq, Eq)] #[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PrimitiveType { pub enum PrimitiveType {
@ -15,6 +15,7 @@ pub enum PrimitiveType {
I64, I64,
String, String,
Float, Float,
DateTime,
} }
#[derive(Debug, Clone, PartialEq, Eq)] #[derive(Debug, Clone, PartialEq, Eq)]
@ -27,7 +28,11 @@ pub enum PropertyType {
} }
impl PropertyType { impl PropertyType {
pub fn codegen(&self, namespace: &mut ObjectNamespace) -> Option<TokenStream> { pub fn codegen(
&self,
namespace: &mut ObjectNamespace,
resolved: &ResolvedSchema,
) -> Option<TokenStream> {
match self { match self {
Self::Primitive(PrimitiveType::Bool) => Some(format_ident!("bool").into_token_stream()), Self::Primitive(PrimitiveType::Bool) => Some(format_ident!("bool").into_token_stream()),
Self::Primitive(PrimitiveType::I32) => Some(format_ident!("i32").into_token_stream()), Self::Primitive(PrimitiveType::I32) => Some(format_ident!("i32").into_token_stream()),
@ -35,6 +40,9 @@ impl PropertyType {
Self::Primitive(PrimitiveType::String) => { Self::Primitive(PrimitiveType::String) => {
Some(format_ident!("String").into_token_stream()) Some(format_ident!("String").into_token_stream())
} }
Self::Primitive(PrimitiveType::DateTime) => {
Some(quote! { chrono::DateTime<chrono::Utc> })
}
Self::Primitive(PrimitiveType::Float) => Some(format_ident!("f64").into_token_stream()), Self::Primitive(PrimitiveType::Float) => Some(format_ident!("f64").into_token_stream()),
Self::Ref(path) => { Self::Ref(path) => {
let name = path.strip_prefix("#/components/schemas/")?; let name = path.strip_prefix("#/components/schemas/")?;
@ -43,7 +51,7 @@ impl PropertyType {
Some(quote! { crate::models::#name }) Some(quote! { crate::models::#name })
} }
Self::Enum(r#enum) => { Self::Enum(r#enum) => {
let code = r#enum.codegen()?; let code = r#enum.codegen(resolved)?;
namespace.push_element(code); namespace.push_element(code);
let ns = namespace.get_ident(); let ns = namespace.get_ident();
@ -54,14 +62,14 @@ impl PropertyType {
}) })
} }
Self::Array(array) => { Self::Array(array) => {
let inner_ty = array.codegen(namespace)?; let inner_ty = array.codegen(namespace, resolved)?;
Some(quote! { Some(quote! {
Vec<#inner_ty> Vec<#inner_ty>
}) })
} }
Self::Nested(nested) => { Self::Nested(nested) => {
let code = nested.codegen()?; let code = nested.codegen(resolved)?;
namespace.push_element(code); namespace.push_element(code);
let ns = namespace.get_ident(); let ns = namespace.get_ident();
@ -77,11 +85,13 @@ impl PropertyType {
#[derive(Debug, Clone, PartialEq, Eq)] #[derive(Debug, Clone, PartialEq, Eq)]
pub struct Property { pub struct Property {
pub field_name: String,
pub name: String, pub name: String,
pub description: Option<String>, pub description: Option<String>,
pub required: bool, pub required: bool,
pub nullable: bool, pub nullable: bool,
pub r#type: PropertyType, pub r#type: PropertyType,
pub deprecated: bool,
} }
impl Property { impl Property {
@ -90,60 +100,68 @@ impl Property {
required: bool, required: bool,
schema: &OpenApiType, schema: &OpenApiType,
schemas: &IndexMap<&str, OpenApiType>, schemas: &IndexMap<&str, OpenApiType>,
warnings: WarningReporter,
) -> Option<Self> { ) -> Option<Self> {
let name = name.to_owned(); let name = name.to_owned();
let field_name = name.to_snake_case();
let description = schema.description.as_deref().map(ToOwned::to_owned); let description = schema.description.as_deref().map(ToOwned::to_owned);
match schema { match schema {
OpenApiType { OpenApiType {
r#enum: Some(_), .. r#enum: Some(_), ..
} => Some(Self { } => {
r#type: PropertyType::Enum(Enum::from_schema( let Some(r#enum) = Enum::from_schema(&name.clone().to_upper_camel_case(), schema)
&name.clone().to_upper_camel_case(), else {
schema, warnings.push("Failed to create enum");
)?), return None;
};
Some(Self {
r#type: PropertyType::Enum(r#enum),
name, name,
field_name,
description, description,
required, required,
deprecated: schema.deprecated,
nullable: false, nullable: false,
}), })
}
OpenApiType { OpenApiType {
one_of: Some(types), one_of: Some(types),
.. ..
} => match types.as_slice() { } => match types.as_slice() {
[ [left, OpenApiType {
left,
OpenApiType {
r#type: Some("null"), r#type: Some("null"),
.. ..
}, }] => {
] => { let mut inner = Self::from_schema(&name, required, left, schemas, warnings)?;
let mut inner = Self::from_schema(&name, required, left, schemas)?;
inner.nullable = true; inner.nullable = true;
Some(inner) Some(inner)
} }
[ [left @ .., OpenApiType {
left @ ..,
OpenApiType {
r#type: Some("null"), r#type: Some("null"),
.. ..
}, }] => {
] => {
let rest = OpenApiType { let rest = OpenApiType {
one_of: Some(left.to_owned()), one_of: Some(left.to_owned()),
..schema.clone() ..schema.clone()
}; };
let mut inner = Self::from_schema(&name, required, &rest, schemas)?; let mut inner = Self::from_schema(&name, required, &rest, schemas, warnings)?;
inner.nullable = true; inner.nullable = true;
Some(inner) Some(inner)
} }
cases => { cases => {
let r#enum = Enum::from_one_of(&name.to_upper_camel_case(), cases)?; let Some(r#enum) = Enum::from_one_of(&name.to_upper_camel_case(), cases) else {
warnings.push("Failed to create oneOf enum");
return None;
};
Some(Self { Some(Self {
name, name,
description: None, field_name,
description,
required, required,
nullable: false, nullable: false,
deprecated: schema.deprecated,
r#type: PropertyType::Enum(r#enum), r#type: PropertyType::Enum(r#enum),
}) })
} }
@ -152,37 +170,49 @@ impl Property {
all_of: Some(types), all_of: Some(types),
.. ..
} => { } => {
let composite = Object::from_all_of(&name.to_upper_camel_case(), types, schemas)?; let obj_name = name.to_upper_camel_case();
let composite =
Object::from_all_of(&obj_name, types, schemas, warnings.child(&obj_name));
Some(Self { Some(Self {
name, name,
description: None, field_name,
description,
required, required,
nullable: false, nullable: false,
deprecated: schema.deprecated,
r#type: PropertyType::Nested(Box::new(composite)), r#type: PropertyType::Nested(Box::new(composite)),
}) })
} }
OpenApiType { OpenApiType {
r#type: Some("object"), r#type: Some("object"),
.. ..
} => Some(Self { } => {
let obj_name = name.to_upper_camel_case();
Some(Self {
r#type: PropertyType::Nested(Box::new(Object::from_schema_object( r#type: PropertyType::Nested(Box::new(Object::from_schema_object(
&name.clone().to_upper_camel_case(), &obj_name,
schema, schema,
schemas, schemas,
)?)), warnings.child(&obj_name),
))),
name, name,
field_name,
description, description,
required, required,
deprecated: schema.deprecated,
nullable: false, nullable: false,
}), })
}
OpenApiType { OpenApiType {
ref_path: Some(path), ref_path: Some(path),
.. ..
} => Some(Self { } => Some(Self {
name, name,
field_name,
description, description,
r#type: PropertyType::Ref((*path).to_owned()), r#type: PropertyType::Ref((*path).to_owned()),
required, required,
deprecated: schema.deprecated,
nullable: false, nullable: false,
}), }),
OpenApiType { OpenApiType {
@ -190,13 +220,15 @@ impl Property {
items: Some(items), items: Some(items),
.. ..
} => { } => {
let inner = Self::from_schema(&name, required, items, schemas)?; let inner = Self::from_schema(&name, required, items, schemas, warnings)?;
Some(Self { Some(Self {
name, name,
field_name,
description, description,
required, required,
nullable: false, nullable: false,
deprecated: schema.deprecated,
r#type: PropertyType::Array(Box::new(inner.r#type)), r#type: PropertyType::Array(Box::new(inner.r#type)),
}) })
} }
@ -206,38 +238,50 @@ impl Property {
let prim = match (schema.r#type, schema.format) { let prim = match (schema.r#type, schema.format) {
(Some("integer"), Some("int32")) => PrimitiveType::I32, (Some("integer"), Some("int32")) => PrimitiveType::I32,
(Some("integer"), Some("int64")) => PrimitiveType::I64, (Some("integer"), Some("int64")) => PrimitiveType::I64,
(Some("number"), Some("float")) => PrimitiveType::Float, (Some("number"), /* Some("float") */ _) | (_, Some("float")) => {
PrimitiveType::Float
}
(Some("string"), None) => PrimitiveType::String, (Some("string"), None) => PrimitiveType::String,
(Some("string"), Some("date")) => PrimitiveType::DateTime,
(Some("boolean"), None) => PrimitiveType::Bool, (Some("boolean"), None) => PrimitiveType::Bool,
_ => return None, _ => return None,
}; };
Some(Self { Some(Self {
name, name,
field_name,
description, description,
required, required,
nullable: false, nullable: false,
deprecated: schema.deprecated,
r#type: PropertyType::Primitive(prim), r#type: PropertyType::Primitive(prim),
}) })
} }
_ => None, _ => {
warnings.push("Could not resolve property type");
None
}
} }
} }
pub fn codegen(&self, namespace: &mut ObjectNamespace) -> Option<TokenStream> { pub fn codegen(
&self,
namespace: &mut ObjectNamespace,
resolved: &ResolvedSchema,
) -> Option<TokenStream> {
let desc = self.description.as_ref().map(|d| quote! { #[doc = #d]}); let desc = self.description.as_ref().map(|d| quote! { #[doc = #d]});
let name = &self.name; let name = &self.name;
let (name, serde_attr) = match name.as_str() { let (name, serde_attr) = match name.as_str() {
"type" => (format_ident!("r#type"), None), "type" => (format_ident!("r#type"), None),
name if name != name.to_snake_case() => ( name if name != self.field_name => (
format_ident!("{}", name.to_snake_case()), format_ident!("{}", self.field_name),
Some(quote! { #[serde(rename = #name)]}), Some(quote! { #[serde(rename = #name)]}),
), ),
_ => (format_ident!("{name}"), None), _ => (format_ident!("{}", self.field_name), None),
}; };
let ty_inner = self.r#type.codegen(namespace)?; let ty_inner = self.r#type.codegen(namespace, resolved)?;
let ty = if !self.required || self.nullable { let ty = if !self.required || self.nullable {
quote! { Option<#ty_inner> } quote! { Option<#ty_inner> }
@ -245,8 +289,17 @@ impl Property {
ty_inner ty_inner
}; };
let deprecated = self.deprecated.then(|| {
let note = self.description.as_ref().map(|d| quote! { note = #d });
quote! {
#[deprecated(#note)]
}
});
Some(quote! { Some(quote! {
#desc #desc
#deprecated
#serde_attr #serde_attr
pub #name: #ty pub #name: #ty
}) })
@ -257,7 +310,7 @@ impl Property {
pub struct Object { pub struct Object {
pub name: String, pub name: String,
pub description: Option<String>, pub description: Option<String>,
pub properties: Vec<Property>, pub properties: IndexMap<String, Property>,
} }
impl Object { impl Object {
@ -265,7 +318,8 @@ impl Object {
name: &str, name: &str,
schema: &OpenApiType, schema: &OpenApiType,
schemas: &IndexMap<&str, OpenApiType>, schemas: &IndexMap<&str, OpenApiType>,
) -> Option<Self> { warnings: WarningReporter,
) -> Self {
let mut result = Object { let mut result = Object {
name: name.to_owned(), name: name.to_owned(),
description: schema.description.as_deref().map(ToOwned::to_owned), description: schema.description.as_deref().map(ToOwned::to_owned),
@ -273,38 +327,54 @@ impl Object {
}; };
let Some(props) = &schema.properties else { let Some(props) = &schema.properties else {
return None; warnings.push("Missing properties");
return result;
}; };
let required = schema.required.clone().unwrap_or_default(); let required = schema.required.clone().unwrap_or_default();
for (prop_name, prop) in props { for (prop_name, prop) in props {
// HACK: This will cause a duplicate key otherwise let Some(prop) = Property::from_schema(
if ["itemDetails", "sci-fi", "non-attackers", "co-leader_id"].contains(prop_name) {
continue;
}
// TODO: implement custom enum for this (depends on overrides being added)
if *prop_name == "value" && name == "TornHof" {
continue;
}
result.properties.push(Property::from_schema(
prop_name, prop_name,
required.contains(prop_name), required.contains(prop_name),
prop, prop,
schemas, schemas,
)?); warnings.child(prop_name),
) else {
continue;
};
let field_name = prop.field_name.clone();
let entry = result.properties.entry(field_name.clone());
if let Entry::Occupied(mut entry) = entry {
let other_name = entry.get().name.clone();
warnings.push(format!(
"Property name collision: {other_name} and {field_name}"
));
// deprioritise kebab and camelcase
if other_name.contains('-')
|| other_name
.chars()
.filter(|c| c.is_alphabetic())
.all(|c| c.is_ascii_lowercase())
{
entry.insert(prop);
}
} else {
entry.insert_entry(prop);
}
} }
Some(result) result
} }
pub fn from_all_of( pub fn from_all_of(
name: &str, name: &str,
types: &[OpenApiType], types: &[OpenApiType],
schemas: &IndexMap<&str, OpenApiType>, schemas: &IndexMap<&str, OpenApiType>,
) -> Option<Self> { warnings: WarningReporter,
) -> Self {
let mut result = Self { let mut result = Self {
name: name.to_owned(), name: name.to_owned(),
..Default::default() ..Default::default()
@ -312,25 +382,32 @@ impl Object {
for r#type in types { for r#type in types {
let r#type = if let OpenApiType { let r#type = if let OpenApiType {
ref_path: Some(path), ref_path: Some(ref_path),
.. ..
} = r#type } = r#type
{ {
let name = path.strip_prefix("#/components/schemas/")?; let Some(name) = ref_path.strip_prefix("#/components/schemas/") else {
schemas.get(name)? warnings.push(format!("Malformed ref {ref_path}"));
continue;
};
let Some(schema) = schemas.get(name) else {
warnings.push(format!("Missing schema for ref {name}"));
continue;
};
schema
} else { } else {
r#type r#type
}; };
let obj = Self::from_schema_object(name, r#type, schemas)?; let obj = Self::from_schema_object(name, r#type, schemas, warnings.child("variant"));
result.description = result.description.or(obj.description); result.description = result.description.or(obj.description);
result.properties.extend(obj.properties); result.properties.extend(obj.properties);
} }
Some(result) result
} }
pub fn codegen(&self) -> Option<TokenStream> { pub fn codegen(&self, resolved: &ResolvedSchema) -> Option<TokenStream> {
let doc = self.description.as_ref().map(|d| { let doc = self.description.as_ref().map(|d| {
quote! { quote! {
#[doc = #d] #[doc = #d]
@ -344,8 +421,8 @@ impl Object {
}; };
let mut props = Vec::with_capacity(self.properties.len()); let mut props = Vec::with_capacity(self.properties.len());
for prop in &self.properties { for (_, prop) in &self.properties {
props.push(prop.codegen(&mut namespace)?); props.push(prop.codegen(&mut namespace, resolved)?);
} }
let name = format_ident!("{}", self.name); let name = format_ident!("{}", self.name);
@ -402,23 +479,11 @@ impl ObjectNamespace<'_> {
mod test { mod test {
use super::*; use super::*;
use crate::openapi::schema::OpenApiSchema; use crate::openapi::schema::test::get_schema;
#[test]
fn resolve_object() {
let schema = OpenApiSchema::read().unwrap();
let attack = schema.components.schemas.get("FactionUpgrades").unwrap();
let resolved =
Object::from_schema_object("FactionUpgrades", attack, &schema.components.schemas)
.unwrap();
let _code = resolved.codegen().unwrap();
}
#[test] #[test]
fn resolve_objects() { fn resolve_objects() {
let schema = OpenApiSchema::read().unwrap(); let schema = get_schema();
let mut objects = 0; let mut objects = 0;
let mut unresolved = vec![]; let mut unresolved = vec![];
@ -426,7 +491,14 @@ mod test {
for (name, desc) in &schema.components.schemas { for (name, desc) in &schema.components.schemas {
if desc.r#type == Some("object") { if desc.r#type == Some("object") {
objects += 1; objects += 1;
if Object::from_schema_object(name, desc, &schema.components.schemas).is_none() { let reporter = WarningReporter::new();
Object::from_schema_object(
name,
desc,
&schema.components.schemas,
reporter.clone(),
);
if !reporter.is_empty() {
unresolved.push(name); unresolved.push(name);
} }
} }

View file

@ -2,14 +2,14 @@ use std::fmt::Write;
use heck::ToUpperCamelCase; use heck::ToUpperCamelCase;
use proc_macro2::TokenStream; use proc_macro2::TokenStream;
use quote::{ToTokens, format_ident, quote}; use quote::{format_ident, quote, ToTokens};
use crate::openapi::parameter::{ use crate::openapi::parameter::{
OpenApiParameter, OpenApiParameterDefault, OpenApiParameterSchema, OpenApiParameter, OpenApiParameterDefault, OpenApiParameterSchema,
ParameterLocation as SchemaLocation, ParameterLocation as SchemaLocation,
}; };
use super::r#enum::Enum; use super::{r#enum::Enum, ResolvedSchema};
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct ParameterOptions<P> { pub struct ParameterOptions<P> {
@ -42,9 +42,7 @@ impl ParameterType {
match schema { match schema {
OpenApiParameterSchema { OpenApiParameterSchema {
r#type: Some("integer"), r#type: Some("integer"),
// BUG: missing for some types in the spec format: Some("int32"),
// format: Some("int32"),
.. ..
} => { } => {
let default = match schema.default { let default = match schema.default {
@ -90,6 +88,17 @@ impl ParameterType {
r#type: Enum::from_parameter_schema(name, schema)?, r#type: Enum::from_parameter_schema(name, schema)?,
}) })
} }
OpenApiParameterSchema {
one_of: Some(schemas),
..
} => Some(ParameterType::Enum {
options: ParameterOptions {
default: None,
minimum: None,
maximum: None,
},
r#type: Enum::from_one_of(name, schemas)?,
}),
OpenApiParameterSchema { OpenApiParameterSchema {
r#type: Some("string"), r#type: Some("string"),
.. ..
@ -170,7 +179,7 @@ impl Parameter {
}) })
} }
pub fn codegen(&self) -> Option<TokenStream> { pub fn codegen(&self, resolved: &ResolvedSchema) -> Option<TokenStream> {
match &self.r#type { match &self.r#type {
ParameterType::I32 { options } => { ParameterType::I32 { options } => {
let name = format_ident!("{}", self.name); let name = format_ident!("{}", self.name);
@ -274,7 +283,7 @@ The default value [Self::{}](self::{}#variant.{})"#,
} }
let doc = quote! { #[doc = #desc]}; let doc = quote! { #[doc = #desc]};
let inner = r#type.codegen()?; let inner = r#type.codegen(resolved)?;
Some(quote! { Some(quote! {
#doc #doc
@ -300,13 +309,13 @@ The default value [Self::{}](self::{}#variant.{})"#,
..self.clone() ..self.clone()
}; };
let mut code = inner.codegen().unwrap_or_default(); let mut code = inner.codegen(resolved).unwrap_or_default();
let name = format_ident!("{}", outer_name); let name = format_ident!("{}", outer_name);
let inner_ty = items.codegen_type_name(&inner_name); let inner_ty = items.codegen_type_name(&inner_name);
code.extend(quote! { code.extend(quote! {
#[derive(Debug, Clone)] #[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct #name(pub Vec<#inner_ty>); pub struct #name(pub Vec<#inner_ty>);
impl std::fmt::Display for #name { impl std::fmt::Display for #name {
@ -324,9 +333,9 @@ The default value [Self::{}](self::{}#variant.{})"#,
} }
} }
impl<T> From<T> for #name where T: IntoIterator<Item = #inner_ty> { impl<T> From<T> for #name where T: IntoIterator, T::Item: Into<#inner_ty> {
fn from(value: T) -> #name { fn from(value: T) -> #name {
let items = value.into_iter().collect(); let items = value.into_iter().map(Into::into).collect();
Self(items) Self(items)
} }
@ -342,13 +351,13 @@ The default value [Self::{}](self::{}#variant.{})"#,
#[cfg(test)] #[cfg(test)]
mod test { mod test {
use crate::openapi::{path::OpenApiPathParameter, schema::OpenApiSchema}; use crate::openapi::{path::OpenApiPathParameter, schema::test::get_schema};
use super::*; use super::*;
#[test] #[test]
fn resolve_components() { fn resolve_components() {
let schema = OpenApiSchema::read().unwrap(); let schema = get_schema();
let mut parameters = 0; let mut parameters = 0;
let mut unresolved = vec![]; let mut unresolved = vec![];
@ -376,7 +385,7 @@ mod test {
#[test] #[test]
fn resolve_inline() { fn resolve_inline() {
let schema = OpenApiSchema::read().unwrap(); let schema = get_schema();
let mut params = 0; let mut params = 0;
let mut unresolved = Vec::new(); let mut unresolved = Vec::new();
@ -404,7 +413,8 @@ mod test {
#[test] #[test]
fn codegen_inline() { fn codegen_inline() {
let schema = OpenApiSchema::read().unwrap(); let schema = get_schema();
let resolved = ResolvedSchema::from_open_api(&schema);
let mut params = 0; let mut params = 0;
let mut unresolved = Vec::new(); let mut unresolved = Vec::new();
@ -425,7 +435,7 @@ mod test {
continue; continue;
} }
params += 1; params += 1;
if param.codegen().is_none() { if param.codegen(&resolved).is_none() {
unresolved.push(format!("`{}.{}`", path, inline.name)); unresolved.push(format!("`{}.{}`", path, inline.name));
} }
} }

View file

@ -1,4 +1,4 @@
use std::{fmt::Write, ops::Deref}; use std::fmt::Write;
use heck::{ToSnakeCase, ToUpperCamelCase}; use heck::{ToSnakeCase, ToUpperCamelCase};
use indexmap::IndexMap; use indexmap::IndexMap;
@ -14,6 +14,7 @@ use crate::openapi::{
use super::{ use super::{
parameter::{Parameter, ParameterLocation, ParameterType}, parameter::{Parameter, ParameterLocation, ParameterType},
union::Union, union::Union,
ResolvedSchema, WarningReporter,
}; };
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
@ -22,6 +23,21 @@ pub enum PathSegment {
Parameter { name: String }, Parameter { name: String },
} }
pub struct PrettySegments<'a>(pub &'a [PathSegment]);
impl std::fmt::Display for PrettySegments<'_> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
for segment in self.0 {
match segment {
PathSegment::Constant(c) => write!(f, "/{c}")?,
PathSegment::Parameter { name } => write!(f, "/{{{name}}}")?,
}
}
Ok(())
}
}
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub enum PathParameter { pub enum PathParameter {
Inline(Parameter), Inline(Parameter),
@ -40,7 +56,7 @@ pub struct Path {
pub segments: Vec<PathSegment>, pub segments: Vec<PathSegment>,
pub name: String, pub name: String,
pub summary: Option<String>, pub summary: Option<String>,
pub description: String, pub description: Option<String>,
pub parameters: Vec<PathParameter>, pub parameters: Vec<PathParameter>,
pub response: PathResponse, pub response: PathResponse,
} }
@ -50,6 +66,7 @@ impl Path {
path: &str, path: &str,
schema: &OpenApiPath, schema: &OpenApiPath,
parameters: &IndexMap<&str, OpenApiParameter>, parameters: &IndexMap<&str, OpenApiParameter>,
warnings: WarningReporter,
) -> Option<Self> { ) -> Option<Self> {
let mut segments = Vec::new(); let mut segments = Vec::new();
for segment in path.strip_prefix('/')?.split('/') { for segment in path.strip_prefix('/')?.split('/') {
@ -63,7 +80,7 @@ impl Path {
} }
let summary = schema.get.summary.as_deref().map(ToOwned::to_owned); let summary = schema.get.summary.as_deref().map(ToOwned::to_owned);
let description = schema.get.description.deref().to_owned(); let description = schema.get.description.as_deref().map(ToOwned::to_owned);
let mut params = Vec::with_capacity(schema.get.parameters.len()); let mut params = Vec::with_capacity(schema.get.parameters.len());
for parameter in &schema.get.parameters { for parameter in &schema.get.parameters {
@ -110,9 +127,13 @@ impl Path {
.strip_prefix("#/components/schemas/")? .strip_prefix("#/components/schemas/")?
.to_owned(), .to_owned(),
}, },
OpenApiResponseBody::Union { any_of: _ } => PathResponse::ArbitraryUnion( OpenApiResponseBody::Union { any_of: _ } => {
Union::from_schema("Response", &schema.get.response_content)?, PathResponse::ArbitraryUnion(Union::from_schema(
), "Response",
&schema.get.response_content,
warnings.child("response"),
)?)
}
}; };
Some(Self { Some(Self {
@ -125,7 +146,11 @@ impl Path {
}) })
} }
pub fn codegen_request(&self) -> Option<TokenStream> { pub fn codegen_request(
&self,
resolved: &ResolvedSchema,
warnings: WarningReporter,
) -> Option<TokenStream> {
let name = if self.segments.len() == 1 { let name = if self.segments.len() == 1 {
let Some(PathSegment::Constant(first)) = self.segments.first() else { let Some(PathSegment::Constant(first)) = self.segments.first() else {
return None; return None;
@ -159,21 +184,21 @@ impl Path {
let ty_name = format_ident!("{}", param.name); let ty_name = format_ident!("{}", param.name);
if is_inline { if is_inline {
ns.push_element(param.codegen()?); ns.push_element(param.codegen(resolved)?);
let path = ns.get_ident(); let path = ns.get_ident();
( (
quote! { quote! {
crate::request::models::#path::#ty_name crate::request::models::#path::#ty_name
}, },
Some(quote! { #[builder(into)] }), Some(quote! { #[cfg_attr(feature = "builder", builder(into))] }),
) )
} else { } else {
( (
quote! { quote! {
crate::parameters::#ty_name crate::parameters::#ty_name
}, },
Some(quote! { #[builder(into)]}), Some(quote! { #[cfg_attr(feature = "builder", builder(into))]}),
) )
} }
} }
@ -190,14 +215,14 @@ impl Path {
) )
} }
ParameterType::Array { .. } => { ParameterType::Array { .. } => {
ns.push_element(param.codegen()?); ns.push_element(param.codegen(resolved)?);
let ty_name = param.r#type.codegen_type_name(&param.name); let ty_name = param.r#type.codegen_type_name(&param.name);
let path = ns.get_ident(); let path = ns.get_ident();
( (
quote! { quote! {
crate::request::models::#path::#ty_name crate::request::models::#path::#ty_name
}, },
Some(quote! { #[builder(into)] }), Some(quote! { #[cfg_attr(feature = "builder", builder(into))] }),
) )
} }
}; };
@ -206,17 +231,30 @@ impl Path {
let query_val = &param.value; let query_val = &param.value;
if param.location == ParameterLocation::Path { if param.location == ParameterLocation::Path {
if self.segments.iter().any(|s| {
if let PathSegment::Parameter { name } = s {
name == &param.value
} else {
false
}
}) {
discriminant.push(ty.clone()); discriminant.push(ty.clone());
discriminant_val.push(quote! { self.#name }); discriminant_val.push(quote! { self.#name });
let path_name = format_ident!("{}", param.value); let path_name = format_ident!("{}", param.value);
start_fields.push(quote! { start_fields.push(quote! {
#[builder(start_fn)] #[cfg_attr(feature = "builder", builder(start_fn))]
#builder_param #builder_param
pub #name: #ty pub #name: #ty
}); });
fmt_val.push(quote! { fmt_val.push(quote! {
#path_name=self.#name #path_name=self.#name
}); });
} else {
warnings.push(format!(
"Provided path parameter is not present in the url: {}",
param.value
));
}
} else { } else {
let ty = if param.required { let ty = if param.required {
convert_field.push(quote! { convert_field.push(quote! {
@ -273,8 +311,9 @@ impl Path {
Some(quote! { Some(quote! {
#ns #ns
#[derive(Debug, Clone, bon::Builder)] #[cfg_attr(feature = "builder", derive(bon::Builder))]
#[builder(state_mod(vis = "pub(crate)"), on(String, into))] #[derive(Debug, Clone)]
#[cfg_attr(feature = "builder", builder(state_mod(vis = "pub(crate)"), on(String, into)))]
pub struct #name { pub struct #name {
#(#start_fields),* #(#start_fields),*
} }
@ -283,15 +322,18 @@ impl Path {
#[allow(unused_parens)] #[allow(unused_parens)]
type Discriminant = (#(#discriminant),*); type Discriminant = (#(#discriminant),*);
type Response = #response_ty; type Response = #response_ty;
fn into_request(self) -> crate::request::ApiRequest<Self::Discriminant> { fn into_request(self) -> (Self::Discriminant, crate::request::ApiRequest) {
let path = format!(#path_fmt_str, #(#fmt_val),*);
#[allow(unused_parens)] #[allow(unused_parens)]
(
(#(#discriminant_val),*),
crate::request::ApiRequest { crate::request::ApiRequest {
path: format!(#path_fmt_str, #(#fmt_val),*), path,
parameters: std::iter::empty() parameters: std::iter::empty()
#(#convert_field)* #(#convert_field)*
.collect(), .collect(),
disriminant: (#(#discriminant_val),*),
} }
)
} }
} }
}) })
@ -324,7 +366,15 @@ impl Path {
PathParameter::Component(param) => (param, false), PathParameter::Component(param) => (param, false),
}; };
if param.location == ParameterLocation::Path { if param.location == ParameterLocation::Path
&& self.segments.iter().any(|s| {
if let PathSegment::Parameter { name } = s {
name == &param.value
} else {
false
}
})
{
let ty = match &param.r#type { let ty = match &param.r#type {
ParameterType::I32 { .. } | ParameterType::Enum { .. } => { ParameterType::I32 { .. } | ParameterType::Enum { .. } => {
let ty_name = format_ident!("{}", param.name); let ty_name = format_ident!("{}", param.name);
@ -348,7 +398,13 @@ impl Path {
crate::models::#ty_name crate::models::#ty_name
} }
} }
ParameterType::Array { .. } => param.r#type.codegen_type_name(&param.name), ParameterType::Array { .. } => {
let ty_name = param.r#type.codegen_type_name(&param.name);
quote! {
crate::request::models::#request_mod_name::#ty_name
}
}
}; };
let arg_name = format_ident!("{}", param.value.to_snake_case()); let arg_name = format_ident!("{}", param.value.to_snake_case());
@ -373,9 +429,25 @@ impl Path {
} }
}; };
let doc = match (&self.summary, &self.description) {
(Some(summary), Some(description)) => {
Some(format!("{summary}\n\n# Description\n{description}"))
}
(Some(summary), None) => Some(summary.clone()),
(None, Some(description)) => Some(format!("# Description\n{description}")),
(None, None) => None,
};
let doc = doc.map(|d| {
quote! {
#[doc = #d]
}
});
Some(quote! { Some(quote! {
#doc
pub async fn #fn_name<S>( pub async fn #fn_name<S>(
&self, self,
#(#extra_args)* #(#extra_args)*
builder: impl FnOnce( builder: impl FnOnce(
#builder_path<#builder_mod_path::Empty> #builder_path<#builder_mod_path::Empty>
@ -390,6 +462,148 @@ impl Path {
} }
}) })
} }
pub fn codegen_bulk_scope_call(&self) -> Option<TokenStream> {
let mut disc = Vec::new();
let mut disc_ty = Vec::new();
let snake_name = self.name.to_snake_case();
let request_name = format_ident!("{}Request", self.name);
let builder_name = format_ident!("{}RequestBuilder", self.name);
let builder_mod_name = format_ident!("{}_request_builder", snake_name);
let request_mod_name = format_ident!("{snake_name}");
let request_path = quote! { crate::request::models::#request_name };
let builder_path = quote! { crate::request::models::#builder_name };
let builder_mod_path = quote! { crate::request::models::#builder_mod_name };
let tail = snake_name
.split_once('_')
.map_or_else(|| "for_selections".to_owned(), |(_, tail)| tail.to_owned());
let fn_name = format_ident!("{tail}");
for param in &self.parameters {
let (param, is_inline) = match param {
PathParameter::Inline(param) => (param, true),
PathParameter::Component(param) => (param, false),
};
if param.location == ParameterLocation::Path
&& self.segments.iter().any(|s| {
if let PathSegment::Parameter { name } = s {
name == &param.value
} else {
false
}
})
{
let ty = match &param.r#type {
ParameterType::I32 { .. } | ParameterType::Enum { .. } => {
let ty_name = format_ident!("{}", param.name);
if is_inline {
quote! {
crate::request::models::#request_mod_name::#ty_name
}
} else {
quote! {
crate::parameters::#ty_name
}
}
}
ParameterType::String => quote! { String },
ParameterType::Boolean => quote! { bool },
ParameterType::Schema { type_name } => {
let ty_name = format_ident!("{}", type_name);
quote! {
crate::models::#ty_name
}
}
ParameterType::Array { .. } => {
let name = param.r#type.codegen_type_name(&param.name);
quote! {
crate::request::models::#request_mod_name::#name
}
}
};
let arg_name = format_ident!("{}", param.value.to_snake_case());
disc_ty.push(ty);
disc.push(arg_name);
}
}
if disc.is_empty() {
return None;
}
let response_ty = match &self.response {
PathResponse::Component { name } => {
let name = format_ident!("{name}");
quote! {
crate::models::#name
}
}
PathResponse::ArbitraryUnion(union) => {
let name = format_ident!("{}", union.name);
quote! {
crate::request::models::#request_mod_name::#name
}
}
};
let disc = if disc.len() > 1 {
quote! { (#(#disc),*) }
} else {
quote! { #(#disc),* }
};
let disc_ty = if disc_ty.len() > 1 {
quote! { (#(#disc_ty),*) }
} else {
quote! { #(#disc_ty),* }
};
let doc = match (&self.summary, &self.description) {
(Some(summary), Some(description)) => {
Some(format!("{summary}\n\n# Description\n{description}"))
}
(Some(summary), None) => Some(summary.clone()),
(None, Some(description)) => Some(format!("# Description\n{description}")),
(None, None) => None,
};
let doc = doc.map(|d| {
quote! {
#[doc = #d]
}
});
Some(quote! {
#doc
pub fn #fn_name<S, I, B>(
self,
ids: I,
builder: B
) -> impl futures::Stream<Item = (#disc_ty, Result<#response_ty, E::Error>)>
where
I: IntoIterator<Item = #disc_ty>,
S: #builder_mod_path::IsComplete,
B: Fn(
#builder_path<#builder_mod_path::Empty>
) -> #builder_path<S>,
{
let requests = ids.into_iter()
.map(move |#disc| builder(#request_path::builder(#disc)).build());
let executor = self.executor;
executor.fetch_many(requests)
}
})
}
} }
pub struct PathNamespace<'r> { pub struct PathNamespace<'r> {
@ -431,18 +645,25 @@ impl PathNamespace<'_> {
mod test { mod test {
use super::*; use super::*;
use crate::openapi::schema::OpenApiSchema; use crate::openapi::schema::test::get_schema;
#[test] #[test]
fn resolve_paths() { fn resolve_paths() {
let schema = OpenApiSchema::read().unwrap(); let schema = get_schema();
let mut paths = 0; let mut paths = 0;
let mut unresolved = vec![]; let mut unresolved = vec![];
for (name, desc) in &schema.paths { for (name, desc) in &schema.paths {
paths += 1; paths += 1;
if Path::from_schema(name, desc, &schema.components.parameters).is_none() { if Path::from_schema(
name,
desc,
&schema.components.parameters,
WarningReporter::new(),
)
.is_none()
{
unresolved.push(name); unresolved.push(name);
} }
} }
@ -463,19 +684,25 @@ mod test {
#[test] #[test]
fn codegen_paths() { fn codegen_paths() {
let schema = OpenApiSchema::read().unwrap(); let schema = get_schema();
let resolved = ResolvedSchema::from_open_api(&schema);
let reporter = WarningReporter::new();
let mut paths = 0; let mut paths = 0;
let mut unresolved = vec![]; let mut unresolved = vec![];
for (name, desc) in &schema.paths { for (name, desc) in &schema.paths {
paths += 1; paths += 1;
let Some(path) = Path::from_schema(name, desc, &schema.components.parameters) else { let Some(path) =
Path::from_schema(name, desc, &schema.components.parameters, reporter.clone())
else {
unresolved.push(name); unresolved.push(name);
continue; continue;
}; };
if path.codegen_scope_call().is_none() || path.codegen_request().is_none() { if path.codegen_scope_call().is_none()
|| path.codegen_request(&resolved, reporter.clone()).is_none()
{
unresolved.push(name); unresolved.push(name);
} }
} }

View file

@ -35,30 +35,56 @@ impl Scope {
pub fn codegen(&self) -> Option<TokenStream> { pub fn codegen(&self) -> Option<TokenStream> {
let name = format_ident!("{}", self.name); let name = format_ident!("{}", self.name);
let bulk_name = format_ident!("Bulk{}", self.name);
let mut functions = Vec::with_capacity(self.members.len()); let mut functions = Vec::with_capacity(self.members.len());
let mut bulk_functions = Vec::with_capacity(self.members.len());
for member in &self.members { for member in &self.members {
if let Some(code) = member.codegen_scope_call() { if let Some(code) = member.codegen_scope_call() {
functions.push(code); functions.push(code);
} }
if let Some(code) = member.codegen_bulk_scope_call() {
bulk_functions.push(code);
}
} }
Some(quote! { Some(quote! {
pub struct #name<'e, E>(&'e E) #[allow(dead_code)]
pub struct #name<E>(E)
where where
E: crate::executor::Executor; E: crate::executor::Executor;
impl<'e, E> #name<'e, E> impl<E> #name<E>
where where
E: crate::executor::Executor E: crate::executor::Executor
{ {
pub fn new(executor: &'e E) -> Self { pub fn new(executor: E) -> Self {
Self(executor) Self(executor)
} }
#(#functions)* #(#functions)*
} }
#[allow(dead_code)]
pub struct #bulk_name<E> where
E: crate::executor::BulkExecutor,
{
executor: E,
}
impl<E> #bulk_name<E>
where
E: crate::executor::BulkExecutor
{
pub fn new(executor: E) -> Self {
Self {
executor,
}
}
#(#bulk_functions)*
}
}) })
} }
} }

View file

@ -4,6 +4,8 @@ use quote::{format_ident, quote};
use crate::openapi::path::OpenApiResponseBody; use crate::openapi::path::OpenApiResponseBody;
use super::WarningReporter;
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct Union { pub struct Union {
pub name: String, pub name: String,
@ -11,10 +13,23 @@ pub struct Union {
} }
impl Union { impl Union {
pub fn from_schema(name: &str, schema: &OpenApiResponseBody) -> Option<Self> { pub fn from_schema(
name: &str,
schema: &OpenApiResponseBody,
warnings: WarningReporter,
) -> Option<Self> {
let members = match schema { let members = match schema {
OpenApiResponseBody::Union { any_of } => { OpenApiResponseBody::Union { any_of } => {
any_of.iter().map(|l| l.ref_path.to_owned()).collect() let mut members = Vec::with_capacity(any_of.len());
for l in any_of {
let path = l.ref_path.to_owned();
if members.contains(&path) {
warnings.push(format!("Duplicate member: {path}"));
} else {
members.push(path);
}
}
members
} }
_ => return None, _ => return None,
}; };
@ -33,7 +48,7 @@ impl Union {
let ty_name = format_ident!("{}", variant_name); let ty_name = format_ident!("{}", variant_name);
variants.push(quote! { variants.push(quote! {
pub fn #accessor_name(&self) -> Result<crate::models::#ty_name, serde_json::Error> { pub fn #accessor_name(&self) -> Result<crate::models::#ty_name, serde_json::Error> {
<crate::models::#ty_name as serde::Deserialize>::deserialize(&self.0) self.deserialize()
} }
}); });
} }
@ -43,6 +58,13 @@ impl Union {
pub struct #name(serde_json::Value); pub struct #name(serde_json::Value);
impl #name { impl #name {
pub fn deserialize<'de, T>(&'de self) -> Result<T, serde_json::Error>
where
T: serde::Deserialize<'de>,
{
T::deserialize(&self.0)
}
#(#variants)* #(#variants)*
} }
}) })

View file

@ -2,6 +2,8 @@ use std::borrow::Cow;
use serde::Deserialize; use serde::Deserialize;
use super::r#type::OpenApiType;
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize)] #[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize)]
#[serde(rename_all = "lowercase")] #[serde(rename_all = "lowercase")]
pub enum ParameterLocation { pub enum ParameterLocation {
@ -9,14 +11,15 @@ pub enum ParameterLocation {
Path, Path,
} }
#[derive(Debug, Clone, Deserialize)] #[derive(Debug, Clone, Deserialize, PartialEq, Eq)]
#[serde(untagged)] #[serde(untagged)]
pub enum OpenApiParameterDefault<'a> { pub enum OpenApiParameterDefault<'a> {
Int(i32), Int(i32),
Str(&'a str), Str(&'a str),
} }
#[derive(Debug, Clone, Deserialize)] #[derive(Debug, Clone, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "camelCase")]
pub struct OpenApiParameterSchema<'a> { pub struct OpenApiParameterSchema<'a> {
#[serde(rename = "$ref")] #[serde(rename = "$ref")]
pub ref_path: Option<&'a str>, pub ref_path: Option<&'a str>,
@ -27,9 +30,10 @@ pub struct OpenApiParameterSchema<'a> {
pub maximum: Option<i32>, pub maximum: Option<i32>,
pub minimum: Option<i32>, pub minimum: Option<i32>,
pub items: Option<Box<OpenApiParameterSchema<'a>>>, pub items: Option<Box<OpenApiParameterSchema<'a>>>,
pub one_of: Option<Vec<OpenApiType<'a>>>,
} }
#[derive(Debug, Clone, Deserialize)] #[derive(Debug, Clone, Deserialize, PartialEq, Eq)]
pub struct OpenApiParameter<'a> { pub struct OpenApiParameter<'a> {
pub name: &'a str, pub name: &'a str,
pub description: Option<Cow<'a, str>>, pub description: Option<Cow<'a, str>>,

View file

@ -1,10 +1,10 @@
use std::borrow::Cow; use std::borrow::Cow;
use serde::{Deserialize, Deserializer}; use serde::{Deserialize, Deserializer, Serialize};
use super::parameter::OpenApiParameter; use super::parameter::OpenApiParameter;
#[derive(Debug, Clone, Deserialize)] #[derive(Debug, Clone, Deserialize, PartialEq, Eq)]
#[serde(untagged)] #[serde(untagged)]
pub enum OpenApiPathParameter<'a> { pub enum OpenApiPathParameter<'a> {
Link { Link {
@ -14,13 +14,13 @@ pub enum OpenApiPathParameter<'a> {
Inline(OpenApiParameter<'a>), Inline(OpenApiParameter<'a>),
} }
#[derive(Debug, Clone, Deserialize)] #[derive(Debug, Clone, Deserialize, PartialEq, Eq)]
pub struct SchemaLink<'a> { pub struct SchemaLink<'a> {
#[serde(rename = "$ref")] #[serde(rename = "$ref")]
pub ref_path: &'a str, pub ref_path: &'a str,
} }
#[derive(Debug, Clone, Deserialize)] #[derive(Debug, Clone, Deserialize, PartialEq, Eq)]
#[serde(untagged)] #[serde(untagged)]
pub enum OpenApiResponseBody<'a> { pub enum OpenApiResponseBody<'a> {
Schema(SchemaLink<'a>), Schema(SchemaLink<'a>),
@ -30,6 +30,9 @@ pub enum OpenApiResponseBody<'a> {
}, },
} }
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
pub struct OperationId(pub String);
fn deserialize_response_body<'de, D>(deserializer: D) -> Result<OpenApiResponseBody<'de>, D::Error> fn deserialize_response_body<'de, D>(deserializer: D) -> Result<OpenApiResponseBody<'de>, D::Error>
where where
D: Deserializer<'de>, D: Deserializer<'de>,
@ -60,10 +63,11 @@ where
Ok(responses.ok.content.json.schema) Ok(responses.ok.content.json.schema)
} }
#[derive(Debug, Clone, Deserialize)] #[derive(Debug, Clone, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub struct OpenApiPathBody<'a> { pub struct OpenApiPathBody<'a> {
pub summary: Option<Cow<'a, str>>, pub summary: Option<Cow<'a, str>>,
pub description: Cow<'a, str>, pub description: Option<Cow<'a, str>>,
#[serde(borrow, default)] #[serde(borrow, default)]
pub parameters: Vec<OpenApiPathParameter<'a>>, pub parameters: Vec<OpenApiPathParameter<'a>>,
#[serde( #[serde(
@ -72,9 +76,10 @@ pub struct OpenApiPathBody<'a> {
deserialize_with = "deserialize_response_body" deserialize_with = "deserialize_response_body"
)] )]
pub response_content: OpenApiResponseBody<'a>, pub response_content: OpenApiResponseBody<'a>,
pub operation_id: Option<OperationId>,
} }
#[derive(Debug, Clone, Deserialize)] #[derive(Debug, Clone, Deserialize, PartialEq, Eq)]
pub struct OpenApiPath<'a> { pub struct OpenApiPath<'a> {
#[serde(borrow)] #[serde(borrow)]
pub get: OpenApiPathBody<'a>, pub get: OpenApiPathBody<'a>,

View file

@ -3,7 +3,7 @@ use serde::Deserialize;
use super::{parameter::OpenApiParameter, path::OpenApiPath, r#type::OpenApiType}; use super::{parameter::OpenApiParameter, path::OpenApiPath, r#type::OpenApiType};
#[derive(Debug, Clone, Deserialize)] #[derive(Debug, Clone, Deserialize, PartialEq, Eq)]
pub struct Components<'a> { pub struct Components<'a> {
#[serde(borrow)] #[serde(borrow)]
pub schemas: IndexMap<&'a str, OpenApiType<'a>>, pub schemas: IndexMap<&'a str, OpenApiType<'a>>,
@ -11,7 +11,7 @@ pub struct Components<'a> {
pub parameters: IndexMap<&'a str, OpenApiParameter<'a>>, pub parameters: IndexMap<&'a str, OpenApiParameter<'a>>,
} }
#[derive(Debug, Clone, Deserialize)] #[derive(Debug, Clone, Deserialize, PartialEq, Eq)]
pub struct OpenApiSchema<'a> { pub struct OpenApiSchema<'a> {
#[serde(borrow)] #[serde(borrow)]
pub paths: IndexMap<&'a str, OpenApiPath<'a>>, pub paths: IndexMap<&'a str, OpenApiPath<'a>>,
@ -19,20 +19,12 @@ pub struct OpenApiSchema<'a> {
pub components: Components<'a>, pub components: Components<'a>,
} }
impl OpenApiSchema<'_> {
pub fn read() -> Result<Self, serde_json::Error> {
let s = include_str!("../../openapi.json");
serde_json::from_str(s)
}
}
#[cfg(test)] #[cfg(test)]
mod test { pub(crate) mod test {
use super::*; use super::*;
#[test] pub(crate) fn get_schema() -> OpenApiSchema<'static> {
fn read() { let s = include_str!("../../../torn-api/openapi.json");
OpenApiSchema::read().unwrap(); serde_json::from_str(s).unwrap()
} }
} }

View file

@ -1,17 +1,25 @@
[package] [package]
name = "torn-api" name = "torn-api"
version = "1.0.1" version = "1.7.0"
edition = "2024" edition = "2021"
description = "Auto-generated bindings for the v2 torn api" description = "Auto-generated bindings for the v2 torn api"
license-file = { workspace = true } license-file = { workspace = true }
repository = { workspace = true } repository = { workspace = true }
homepage = { workspace = true } homepage = { workspace = true }
[features]
default = ["scopes", "requests", "builder", "models"]
scopes = ["builder"]
builder = ["requests", "dep:bon"]
requests = ["models"]
models = ["dep:serde_repr"]
strum = ["dep:strum"]
[dependencies] [dependencies]
serde = { workspace = true, features = ["derive"] } serde = { workspace = true, features = ["derive"] }
serde_repr = "0.1" serde_repr = { version = "0.1", optional = true }
serde_json = { workspace = true } serde_json = { workspace = true }
bon = "3.6" bon = { version = "3.6", optional = true }
bytes = "1" bytes = "1"
http = "1" http = "1"
reqwest = { version = "0.12", default-features = false, features = [ reqwest = { version = "0.12", default-features = false, features = [
@ -20,12 +28,19 @@ reqwest = { version = "0.12", default-features = false, features = [
"brotli", "brotli",
] } ] }
thiserror = "2" thiserror = "2"
futures = { version = "0.3", default-features = false, features = [
"std",
"async-await",
] }
chrono = { version = "0.4.41", features = ["serde"] }
strum = { version = "0.27.1", features = ["derive"], optional = true }
[dev-dependencies] [dev-dependencies]
tokio = { version = "1", features = ["full"] } tokio = { version = "1", features = ["full"] }
[build-dependencies] [build-dependencies]
torn-api-codegen = { path = "../torn-api-codegen", version = "0.1.1" } torn-api-codegen = { path = "../torn-api-codegen", version = "0.7.0" }
syn = { workspace = true, features = ["parsing"] } syn = { workspace = true, features = ["parsing"] }
proc-macro2 = { workspace = true } proc-macro2 = { workspace = true }
prettyplease = "0.2" prettyplease = "0.2"
serde_json = { workspace = true }

View file

@ -1,12 +1,6 @@
use std::{env, fs, path::Path}; use std::{env, fs, path::Path};
use proc_macro2::TokenStream; use torn_api_codegen::{model::ResolvedSchema, openapi::schema::OpenApiSchema};
use torn_api_codegen::{
model::{parameter::Parameter, path::Path as ApiPath, resolve, scope::Scope},
openapi::schema::OpenApiSchema,
};
const DENY_LIST: &[&str] = &[];
fn main() { fn main() {
let out_dir = env::var_os("OUT_DIR").unwrap(); let out_dir = env::var_os("OUT_DIR").unwrap();
@ -15,61 +9,27 @@ fn main() {
let requests_dest = Path::new(&out_dir).join("requests.rs"); let requests_dest = Path::new(&out_dir).join("requests.rs");
let scopes_dest = Path::new(&out_dir).join("scopes.rs"); let scopes_dest = Path::new(&out_dir).join("scopes.rs");
let schema = OpenApiSchema::read().unwrap(); let s = include_str!("./openapi.json");
let schema: OpenApiSchema = serde_json::from_str(s).unwrap();
let resolved = ResolvedSchema::from_open_api(&schema);
let mut models_code = TokenStream::new(); let models_file = syn::parse2(resolved.codegen_models()).unwrap();
for (name, model) in &schema.components.schemas {
if DENY_LIST.contains(name) {
continue;
}
let model = resolve(model, name, &schema.components.schemas);
if let Some(new_code) = model.codegen() {
models_code.extend(new_code);
}
}
let models_file = syn::parse2(models_code).unwrap();
let models_pretty = prettyplease::unparse(&models_file); let models_pretty = prettyplease::unparse(&models_file);
fs::write(&model_dest, models_pretty).unwrap(); fs::write(&model_dest, models_pretty).unwrap();
let mut params_code = TokenStream::new(); let params_file = syn::parse2(resolved.codegen_parameters()).unwrap();
for (name, param) in &schema.components.parameters {
if let Some(code) = Parameter::from_schema(name, param).unwrap().codegen() {
params_code.extend(code);
}
}
let params_file = syn::parse2(params_code).unwrap();
let params_pretty = prettyplease::unparse(&params_file); let params_pretty = prettyplease::unparse(&params_file);
fs::write(&params_dest, params_pretty).unwrap(); fs::write(&params_dest, params_pretty).unwrap();
let mut requests_code = TokenStream::new(); let requests_file = syn::parse2(resolved.codegen_requests()).unwrap();
let mut paths = Vec::new();
for (name, path) in &schema.paths {
let Some(path) = ApiPath::from_schema(name, path, &schema.components.parameters) else {
continue;
};
if let Some(code) = path.codegen_request() {
requests_code.extend(code);
}
paths.push(path);
}
let requests_file = syn::parse2(requests_code).unwrap();
let requests_pretty = prettyplease::unparse(&requests_file); let requests_pretty = prettyplease::unparse(&requests_file);
fs::write(&requests_dest, requests_pretty).unwrap(); fs::write(&requests_dest, requests_pretty).unwrap();
let mut scope_code = TokenStream::new(); let scopes_file = syn::parse2(resolved.codegen_scopes()).unwrap();
let scopes = Scope::from_paths(paths);
for scope in scopes {
if let Some(code) = scope.codegen() {
scope_code.extend(code);
}
}
let scopes_file = syn::parse2(scope_code).unwrap();
let scopes_pretty = prettyplease::unparse(&scopes_file); let scopes_pretty = prettyplease::unparse(&scopes_file);
fs::write(&scopes_dest, scopes_pretty).unwrap(); fs::write(&scopes_dest, scopes_pretty).unwrap();
for warning in resolved.warnings.get_warnings() {
println!("cargo:warning={}", warning);
}
} }

File diff suppressed because it is too large Load diff

View file

@ -1,22 +1,30 @@
use http::{HeaderMap, HeaderValue, header::AUTHORIZATION}; use std::future::Future;
use futures::{Stream, StreamExt};
use http::{header::AUTHORIZATION, HeaderMap, HeaderValue};
use serde::Deserialize; use serde::Deserialize;
#[cfg(feature = "scopes")]
use crate::scopes::{
BulkFactionScope, BulkForumScope, BulkMarketScope, BulkRacingScope, BulkTornScope,
BulkUserScope, FactionScope, ForumScope, MarketScope, RacingScope, TornScope, UserScope,
};
use crate::{ use crate::{
request::{ApiResponse, IntoRequest}, request::{ApiRequest, ApiResponse, IntoRequest},
scopes::{FactionScope, ForumScope, MarketScope, RacingScope, TornScope, UserScope}, scopes::{BulkKeyScope, KeyScope},
}; };
pub trait Executor { pub trait Executor: Sized {
type Error: From<serde_json::Error> + From<crate::ApiError> + Send; type Error: From<serde_json::Error> + From<crate::ApiError> + Send;
fn execute<R>( fn execute<R>(
&self, self,
request: R, request: R,
) -> impl Future<Output = Result<ApiResponse<R::Discriminant>, Self::Error>> + Send ) -> impl Future<Output = (R::Discriminant, Result<ApiResponse, Self::Error>)> + Send
where where
R: IntoRequest; R: IntoRequest;
fn fetch<R>(&self, request: R) -> impl Future<Output = Result<R::Response, Self::Error>> + Send fn fetch<R>(self, request: R) -> impl Future<Output = Result<R::Response, Self::Error>> + Send
where where
R: IntoRequest, R: IntoRequest,
{ {
@ -24,7 +32,7 @@ pub trait Executor {
// The future is `Send` but `&self` might not be. // The future is `Send` but `&self` might not be.
let fut = self.execute(request); let fut = self.execute(request);
async { async {
let resp = fut.await?; let resp = fut.await.1?;
let bytes = resp.body.unwrap(); let bytes = resp.body.unwrap();
@ -51,6 +59,164 @@ pub trait Executor {
} }
} }
pub trait BulkExecutor: Sized {
type Error: From<serde_json::Error> + From<crate::ApiError> + Send;
fn execute<R>(
self,
requests: impl IntoIterator<Item = R>,
) -> impl Stream<Item = (R::Discriminant, Result<ApiResponse, Self::Error>)> + Unpin
where
R: IntoRequest;
fn fetch_many<R>(
self,
requests: impl IntoIterator<Item = R>,
) -> impl Stream<Item = (R::Discriminant, Result<R::Response, Self::Error>)> + Unpin
where
R: IntoRequest,
{
self.execute(requests).map(|(d, r)| {
let r = match r {
Ok(r) => r,
Err(why) => return (d, Err(why)),
};
let bytes = r.body.unwrap();
if bytes.starts_with(br#"{"error":{"#) {
#[derive(Deserialize)]
struct ErrorBody<'a> {
code: u16,
error: &'a str,
}
#[derive(Deserialize)]
struct ErrorContainer<'a> {
#[serde(borrow)]
error: ErrorBody<'a>,
}
let error: ErrorContainer = match serde_json::from_slice(&bytes) {
Ok(error) => error,
Err(why) => return (d, Err(why.into())),
};
return (
d,
Err(crate::ApiError::new(error.error.code, error.error.error).into()),
);
}
let resp = match serde_json::from_slice(&bytes) {
Ok(resp) => resp,
Err(why) => return (d, Err(why.into())),
};
(d, Ok(resp))
})
}
}
#[cfg(feature = "scopes")]
pub trait ExecutorExt: Executor + Sized {
fn user(self) -> UserScope<Self>;
fn faction(self) -> FactionScope<Self>;
fn torn(self) -> TornScope<Self>;
fn market(self) -> MarketScope<Self>;
fn racing(self) -> RacingScope<Self>;
fn forum(self) -> ForumScope<Self>;
fn key(self) -> KeyScope<Self>;
}
#[cfg(feature = "scopes")]
impl<T> ExecutorExt for T
where
T: Executor + Sized,
{
fn user(self) -> UserScope<Self> {
UserScope::new(self)
}
fn faction(self) -> FactionScope<Self> {
FactionScope::new(self)
}
fn torn(self) -> TornScope<Self> {
TornScope::new(self)
}
fn market(self) -> MarketScope<Self> {
MarketScope::new(self)
}
fn racing(self) -> RacingScope<Self> {
RacingScope::new(self)
}
fn forum(self) -> ForumScope<Self> {
ForumScope::new(self)
}
fn key(self) -> KeyScope<Self> {
KeyScope::new(self)
}
}
#[cfg(feature = "scopes")]
pub trait BulkExecutorExt: BulkExecutor + Sized {
fn user_bulk(self) -> BulkUserScope<Self>;
fn faction_bulk(self) -> BulkFactionScope<Self>;
fn torn_bulk(self) -> BulkTornScope<Self>;
fn market_bulk(self) -> BulkMarketScope<Self>;
fn racing_bulk(self) -> BulkRacingScope<Self>;
fn forum_bulk(self) -> BulkForumScope<Self>;
fn key_bulk(self) -> BulkKeyScope<Self>;
}
#[cfg(feature = "scopes")]
impl<T> BulkExecutorExt for T
where
T: BulkExecutor + Sized,
{
fn user_bulk(self) -> BulkUserScope<Self> {
BulkUserScope::new(self)
}
fn faction_bulk(self) -> BulkFactionScope<Self> {
BulkFactionScope::new(self)
}
fn torn_bulk(self) -> BulkTornScope<Self> {
BulkTornScope::new(self)
}
fn market_bulk(self) -> BulkMarketScope<Self> {
BulkMarketScope::new(self)
}
fn racing_bulk(self) -> BulkRacingScope<Self> {
BulkRacingScope::new(self)
}
fn forum_bulk(self) -> BulkForumScope<Self> {
BulkForumScope::new(self)
}
fn key_bulk(self) -> BulkKeyScope<Self> {
BulkKeyScope::new(self)
}
}
pub struct ReqwestClient(reqwest::Client); pub struct ReqwestClient(reqwest::Client);
impl ReqwestClient { impl ReqwestClient {
@ -71,77 +237,53 @@ impl ReqwestClient {
} }
} }
pub trait ExecutorExt: Executor + Sized { impl ReqwestClient {
fn user(&self) -> UserScope<'_, Self>; async fn execute_api_request(&self, request: ApiRequest) -> Result<ApiResponse, crate::Error> {
fn faction(&self) -> FactionScope<'_, Self>;
fn torn(&self) -> TornScope<'_, Self>;
fn market(&self) -> MarketScope<'_, Self>;
fn racing(&self) -> RacingScope<'_, Self>;
fn forum(&self) -> ForumScope<'_, Self>;
}
impl<T> ExecutorExt for T
where
T: Executor + Sized,
{
fn user(&self) -> UserScope<'_, Self> {
UserScope::new(self)
}
fn faction(&self) -> FactionScope<'_, Self> {
FactionScope::new(self)
}
fn torn(&self) -> TornScope<'_, Self> {
TornScope::new(self)
}
fn market(&self) -> MarketScope<'_, Self> {
MarketScope::new(self)
}
fn racing(&self) -> RacingScope<'_, Self> {
RacingScope::new(self)
}
fn forum(&self) -> ForumScope<'_, Self> {
ForumScope::new(self)
}
}
impl Executor for ReqwestClient {
type Error = crate::Error;
async fn execute<R>(&self, request: R) -> Result<ApiResponse<R::Discriminant>, Self::Error>
where
R: IntoRequest,
{
let request = request.into_request();
let url = request.url(); let url = request.url();
let response = self.0.get(url).send().await?; let response = self.0.get(url).send().await?;
let status = response.status(); let status = response.status();
let body = response.bytes().await.ok(); let body = response.bytes().await.ok();
Ok(ApiResponse { Ok(ApiResponse { status, body })
discriminant: request.disriminant, }
status, }
body,
}) impl Executor for &ReqwestClient {
type Error = crate::Error;
async fn execute<R>(self, request: R) -> (R::Discriminant, Result<ApiResponse, Self::Error>)
where
R: IntoRequest,
{
let (d, request) = request.into_request();
(d, self.execute_api_request(request).await)
}
}
impl BulkExecutor for &ReqwestClient {
type Error = crate::Error;
fn execute<R>(
self,
requests: impl IntoIterator<Item = R>,
) -> impl Stream<Item = (R::Discriminant, Result<ApiResponse, Self::Error>)>
where
R: IntoRequest,
{
futures::stream::iter(requests)
.map(move |r| <Self as Executor>::execute(self, r))
.buffer_unordered(25)
} }
} }
#[cfg(test)] #[cfg(test)]
mod test { mod test {
use crate::{ApiError, Error, scopes::test::test_client}; use crate::{scopes::test::test_client, ApiError, Error};
use super::*; use super::*;
#[cfg(feature = "scopes")]
#[tokio::test] #[tokio::test]
async fn api_error() { async fn api_error() {
let client = test_client().await; let client = test_client().await;
@ -153,4 +295,22 @@ mod test {
other => panic!("Expected incorrect id entity relation error, got {other:?}"), other => panic!("Expected incorrect id entity relation error, got {other:?}"),
} }
} }
#[cfg(feature = "scopes")]
#[tokio::test]
async fn bulk_request() {
let client = test_client().await;
let stream = client
.faction_bulk()
.basic_for_id(vec![19.into(), 89.into()], |b| b);
let mut responses: Vec<_> = stream.collect().await;
let (_id1, basic1) = responses.pop().unwrap();
basic1.unwrap();
let (_id2, basic2) = responses.pop().unwrap();
basic2.unwrap();
}
} }

View file

@ -1,9 +1,12 @@
use thiserror::Error; use thiserror::Error;
pub mod executor; pub mod executor;
#[cfg(feature = "models")]
pub mod models; pub mod models;
#[cfg(feature = "requests")]
pub mod parameters; pub mod parameters;
pub mod request; pub mod request;
#[cfg(feature = "scopes")]
pub mod scopes; pub mod scopes;
#[derive(Debug, Error, Clone, PartialEq, Eq)] #[derive(Debug, Error, Clone, PartialEq, Eq)]

View file

@ -1,26 +1,26 @@
use bon::Builder;
use bytes::Bytes; use bytes::Bytes;
use http::StatusCode; use http::StatusCode;
use crate::{ #[cfg(feature = "requests")]
executor::Executor,
models::{FactionChainsResponse, FactionId},
};
pub mod models; pub mod models;
#[derive(Default)] #[derive(Default)]
pub struct ApiRequest<D = ()> { pub struct ApiRequest {
pub disriminant: D,
pub path: String, pub path: String,
pub parameters: Vec<(&'static str, String)>, pub parameters: Vec<(&'static str, String)>,
} }
impl<D> ApiRequest<D> { impl ApiRequest {
pub fn url(&self) -> String { pub fn url(&self) -> String {
let mut url = format!("https://api.torn.com/v2{}?", self.path); let mut url = format!("https://api.torn.com/v2{}?", self.path);
let mut first = true;
for (name, value) in &self.parameters { for (name, value) in &self.parameters {
if first {
first = false;
} else {
url.push('&');
}
url.push_str(&format!("{name}={value}")); url.push_str(&format!("{name}={value}"));
} }
@ -28,77 +28,35 @@ impl<D> ApiRequest<D> {
} }
} }
pub struct ApiResponse<D = ()> { pub struct ApiResponse {
pub discriminant: D,
pub body: Option<Bytes>, pub body: Option<Bytes>,
pub status: StatusCode, pub status: StatusCode,
} }
pub trait IntoRequest: Send { pub trait IntoRequest: Send {
type Discriminant: Send; type Discriminant: Send + 'static;
type Response: for<'de> serde::Deserialize<'de> + Send; type Response: for<'de> serde::Deserialize<'de> + Send;
fn into_request(self) -> ApiRequest<Self::Discriminant>; fn into_request(self) -> (Self::Discriminant, ApiRequest);
} }
pub struct FactionScope<'e, E>(&'e E) pub(crate) struct WrappedApiRequest<R>
where where
E: Executor; R: IntoRequest,
impl<E> FactionScope<'_, E>
where
E: Executor,
{ {
pub async fn chains_for_id<S>( discriminant: R::Discriminant,
&self, request: ApiRequest,
id: FactionId, }
builder: impl FnOnce(
FactionChainsRequestBuilder<faction_chains_request_builder::Empty>, impl<R> IntoRequest for WrappedApiRequest<R>
) -> FactionChainsRequestBuilder<S>,
) -> Result<FactionChainsResponse, E::Error>
where where
S: faction_chains_request_builder::IsComplete, R: IntoRequest,
{ {
let r = builder(FactionChainsRequest::with_id(id)).build(); type Discriminant = R::Discriminant;
type Response = R::Response;
self.0.fetch(r).await fn into_request(self) -> (Self::Discriminant, ApiRequest) {
} (self.discriminant, self.request)
}
#[derive(Builder)]
#[builder(start_fn = with_id)]
pub struct FactionChainsRequest {
#[builder(start_fn)]
pub id: FactionId,
pub limit: Option<usize>,
}
impl IntoRequest for FactionChainsRequest {
type Discriminant = FactionId;
type Response = FactionChainsResponse;
fn into_request(self) -> ApiRequest<Self::Discriminant> {
ApiRequest {
disriminant: self.id,
path: format!("/faction/{}/chains", self.id),
parameters: self
.limit
.into_iter()
.map(|l| ("limit", l.to_string()))
.collect(),
}
} }
} }
#[cfg(test)] #[cfg(test)]
mod test { mod test {}
use crate::executor::ReqwestClient;
use super::*;
#[tokio::test]
async fn test_request() {
let client = ReqwestClient::new("nAYRXaoqzBAGalWt");
let r = models::TornItemsForIdsRequest::builder("1".to_owned()).build();
client.fetch(r).await.unwrap();
}
}

View file

@ -9,8 +9,9 @@ pub(super) mod test {
use crate::{ use crate::{
executor::{ExecutorExt, ReqwestClient}, executor::{ExecutorExt, ReqwestClient},
models::{ models::{
AttackCode, FactionSelectionName, PersonalStatsCategoryEnum, PersonalStatsStatName, faction_selection_name::FactionSelectionNameVariant,
UserListEnum, user_selection_name::UserSelectionNameVariant, AttackCode, PersonalStatsCategoryEnum,
PersonalStatsStatName, UserListEnum,
}, },
}; };
@ -67,7 +68,10 @@ pub(super) mod test {
let r = client let r = client
.faction() .faction()
.for_selections(|b| { .for_selections(|b| {
b.selections([FactionSelectionName::Basic, FactionSelectionName::Balance]) b.selections([
FactionSelectionNameVariant::Basic,
FactionSelectionNameVariant::Balance,
])
}) })
.await .await
.unwrap(); .unwrap();
@ -366,6 +370,15 @@ pub(super) mod test {
faction_scope.lookup(|b| b).await.unwrap(); faction_scope.lookup(|b| b).await.unwrap();
} }
#[tokio::test]
async fn faction_reports() {
let client = test_client().await;
let faction_scope = FactionScope(&client);
faction_scope.reports(|b| b).await.unwrap();
}
#[tokio::test] #[tokio::test]
async fn forum_categories() { async fn forum_categories() {
let client = test_client().await; let client = test_client().await;
@ -415,7 +428,7 @@ pub(super) mod test {
let forum_scope = ForumScope(&client); let forum_scope = ForumScope(&client);
forum_scope forum_scope
.threads_for_category_ids("2".to_owned(), |b| b) .threads_for_category_ids([2].into(), |b| b)
.await .await
.unwrap(); .unwrap();
} }
@ -486,14 +499,14 @@ pub(super) mod test {
racing_scope.carupgrades(|b| b).await.unwrap(); racing_scope.carupgrades(|b| b).await.unwrap();
} }
#[tokio::test] /* #[tokio::test]
async fn racing_races() { async fn racing_races() {
let client = test_client().await; let client = test_client().await;
let racing_scope = RacingScope(&client); let racing_scope = RacingScope(&client);
racing_scope.races(|b| b).await.unwrap(); racing_scope.races(|b| b).await.unwrap();
} } */
#[tokio::test] #[tokio::test]
async fn racing_race_for_race_id() { async fn racing_race_for_race_id() {
@ -639,10 +652,7 @@ pub(super) mod test {
let torn_scope = TornScope(&client); let torn_scope = TornScope(&client);
torn_scope torn_scope.items_for_ids([1].into(), |b| b).await.unwrap();
.items_for_ids("1".to_owned(), |b| b)
.await
.unwrap();
} }
#[tokio::test] #[tokio::test]
@ -909,6 +919,161 @@ pub(super) mod test {
.unwrap(); .unwrap();
} }
#[cfg(feature = "strum")]
#[tokio::test]
async fn user_personalstats_popular() {
let client = test_client().await;
let resp = client
.user()
.for_selections(|b| {
b.selections([UserSelectionNameVariant::Personalstats])
.cat(PersonalStatsCategoryEnum::Popular)
})
.await
.unwrap();
assert!(resp
.user_personal_stats_response()
.unwrap()
.is_user_personal_stats_popular());
}
#[cfg(feature = "strum")]
#[tokio::test]
async fn user_personalstats_all() {
let client = test_client().await;
let resp = client
.user()
.for_selections(|b| {
b.selections([UserSelectionNameVariant::Personalstats])
.cat(PersonalStatsCategoryEnum::All)
})
.await
.unwrap();
assert!(resp
.user_personal_stats_response()
.unwrap()
.is_user_personal_stats_full());
}
#[cfg(feature = "strum")]
#[tokio::test]
async fn user_personalstats_cat_attacking() {
let client = test_client().await;
let resp = client
.user()
.for_selections(|b| {
b.selections([UserSelectionNameVariant::Personalstats])
.cat(PersonalStatsCategoryEnum::Attacking)
})
.await
.unwrap();
assert!(resp
.user_personal_stats_response()
.unwrap()
.try_as_user_personal_stats_category()
.unwrap()
.personalstats
.is_personal_stats_attacking_public());
}
#[cfg(feature = "strum")]
#[tokio::test]
async fn user_personalstats_cat_jobs() {
let client = test_client().await;
let resp = client
.user()
.for_selections(|b| {
b.selections([UserSelectionNameVariant::Personalstats])
.cat(PersonalStatsCategoryEnum::Jobs)
})
.await
.unwrap();
assert!(resp
.user_personal_stats_response()
.unwrap()
.try_as_user_personal_stats_category()
.unwrap()
.personalstats
.is_personal_stats_jobs_public());
}
#[cfg(feature = "strum")]
#[tokio::test]
async fn user_personalstats_cat_trading() {
let client = test_client().await;
let resp = client
.user()
.for_selections(|b| {
b.selections([UserSelectionNameVariant::Personalstats])
.cat(PersonalStatsCategoryEnum::Trading)
})
.await
.unwrap();
assert!(resp
.user_personal_stats_response()
.unwrap()
.try_as_user_personal_stats_category()
.unwrap()
.personalstats
.is_personal_stats_trading());
}
#[cfg(feature = "strum")]
#[tokio::test]
async fn user_personalstats_cat_jail() {
let client = test_client().await;
let resp = client
.user()
.for_selections(|b| {
b.selections([UserSelectionNameVariant::Personalstats])
.cat(PersonalStatsCategoryEnum::Jail)
})
.await
.unwrap();
assert!(resp
.user_personal_stats_response()
.unwrap()
.try_as_user_personal_stats_category()
.unwrap()
.personalstats
.is_personal_stats_jail());
}
#[cfg(feature = "strum")]
#[tokio::test]
async fn user_personalstats_cat_hospital() {
let client = test_client().await;
let resp = client
.user()
.for_selections(|b| {
b.selections([UserSelectionNameVariant::Personalstats])
.cat(PersonalStatsCategoryEnum::Hospital)
})
.await
.unwrap();
assert!(resp
.user_personal_stats_response()
.unwrap()
.try_as_user_personal_stats_category()
.unwrap()
.personalstats
.is_personal_stats_hospital());
}
#[tokio::test] #[tokio::test]
async fn user_personalstats_for_id() { async fn user_personalstats_for_id() {
let client = test_client().await; let client = test_client().await;
@ -954,4 +1119,25 @@ pub(super) mod test {
client.user().attacks(|b| b).await.unwrap(); client.user().attacks(|b| b).await.unwrap();
} }
#[tokio::test]
async fn user_reports() {
let client = test_client().await;
client.user().reports(|b| b).await.unwrap();
}
#[tokio::test]
async fn key_info() {
let client = test_client().await;
client.key().info(|b| b).await.unwrap();
}
#[tokio::test]
async fn key_log() {
let client = test_client().await;
client.key().log(|b| b).await.unwrap();
}
} }

View file

@ -1,43 +1,48 @@
[package] [package]
name = "torn-key-pool" name = "torn-key-pool"
version = "0.9.0" version = "1.1.3"
edition = "2021" edition = "2021"
authors = ["Pyrit [2111649]"] authors = ["Pyrit [2111649]"]
license = "MIT" license-file = { workspace = true }
repository = "https://github.com/TotallyNot/torn-api.rs.git" repository = { workspace = true }
homepage = "https://github.com/TotallyNot/torn-api.rs.git" homepage = { workspace = true }
description = "A generalised API key pool for torn-api" description = "A generalised API key pool for torn-api"
[features] [features]
default = ["postgres", "tokio-runtime"] default = ["postgres", "tokio-runtime"]
postgres = [ "dep:sqlx", "dep:chrono", "dep:indoc", "dep:serde" ] postgres = ["dep:sqlx", "dep:chrono", "dep:indoc"]
reqwest = [ "dep:reqwest", "torn-api/reqwest" ] tokio-runtime = ["dep:tokio", "dep:rand", "dep:tokio-stream"]
awc = [ "dep:awc", "torn-api/awc" ]
tokio-runtime = [ "dep:tokio", "dep:rand" ]
actix-runtime = [ "dep:actix-rt", "dep:rand" ]
[dependencies] [dependencies]
torn-api = { path = "../torn-api", default-features = false, version = "0.7" } torn-api = { path = "../torn-api", default-features = false, version = "1.1.1" }
async-trait = "0.1"
thiserror = "2" thiserror = "2"
sqlx = { version = "0.8", features = [ "postgres", "chrono", "json", "derive" ], optional = true, default-features = false } sqlx = { version = "0.8", features = [
serde = { version = "1.0", optional = true } "postgres",
"chrono",
"json",
"derive",
], optional = true, default-features = false }
serde = { workspace = true }
serde_json = { workspace = true }
chrono = { version = "0.4", optional = true } chrono = { version = "0.4", optional = true }
indoc = { version = "2", optional = true } indoc = { version = "2", optional = true }
tokio = { version = "1", optional = true, default-features = false, features = ["time"] } tokio = { version = "1", optional = true, default-features = false, features = [
actix-rt = { version = "2", optional = true, default-features = false } "time",
rand = { version = "0.8", optional = true } ] }
tokio-stream = { version = "0.1", optional = true, default-features = false, features = [
"time",
] }
rand = { version = "0.9", optional = true }
futures = "0.3" futures = "0.3"
reqwest = { version = "0.12", default-features = false, features = [
reqwest = { version = "0.12", default-features = false, features = [ "json" ], optional = true } "brotli",
awc = { version = "3", default-features = false, optional = true } "http2",
"rustls-tls-webpki-roots",
] }
[dev-dependencies] [dev-dependencies]
torn-api = { path = "../torn-api", features = [ "reqwest" ] } torn-api = { path = "../torn-api" }
sqlx = { version = "0.8", features = ["runtime-tokio-rustls"] } sqlx = { version = "0.8", features = ["runtime-tokio-rustls"] }
dotenvy = "0.15"
tokio = { version = "1.42", features = ["rt"] } tokio = { version = "1.42", features = ["rt"] }
tokio-test = "0.4"
reqwest = { version = "0.12", default-features = true } reqwest = { version = "0.12", default-features = true }
awc = { version = "3", features = [ "rustls" ] }

View file

@ -3,48 +3,24 @@
#[cfg(feature = "postgres")] #[cfg(feature = "postgres")]
pub mod postgres; pub mod postgres;
// pub mod local; use std::{collections::HashMap, future::Future, ops::Deref, sync::Arc, time::Duration};
pub mod send;
use std::sync::Arc; use futures::{future::BoxFuture, FutureExt, Stream, StreamExt};
use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION};
use serde::Deserialize;
use tokio_stream::StreamExt as TokioStreamExt;
use torn_api::{
executor::{BulkExecutor, Executor},
request::{ApiRequest, ApiResponse},
ApiError,
};
use async_trait::async_trait; pub trait ApiKeyId: Clone + PartialEq + Eq + std::hash::Hash + Send + Sync {}
use thiserror::Error;
use torn_api::ResponseError; impl<T> ApiKeyId for T where T: Clone + PartialEq + Eq + std::hash::Hash + Send + Sync {}
#[derive(Debug, Error)] pub trait ApiKey: Send + Sync + Clone + 'static {
pub enum KeyPoolError<S, C> type IdType: ApiKeyId;
where
S: std::error::Error + Clone,
C: std::error::Error,
{
#[error("Key pool storage driver error: {0:?}")]
Storage(#[source] S),
#[error(transparent)]
Client(#[from] C),
#[error(transparent)]
Response(ResponseError),
}
impl<S, C> KeyPoolError<S, C>
where
S: std::error::Error + Clone,
C: std::error::Error,
{
#[inline(always)]
pub fn api_code(&self) -> Option<u8> {
match self {
Self::Response(why) => why.api_code(),
_ => None,
}
}
}
pub trait ApiKey: Sync + Send + std::fmt::Debug + Clone + 'static {
type IdType: PartialEq + Eq + std::hash::Hash + Send + Sync + std::fmt::Debug + Clone;
fn value(&self) -> &str; fn value(&self) -> &str;
@ -105,7 +81,47 @@ where
} }
} }
pub trait IntoSelector<K, D>: Send + Sync impl<K, D> From<&str> for KeySelector<K, D>
where
K: ApiKey,
D: KeyDomain,
{
fn from(value: &str) -> Self {
Self::Key(value.to_owned())
}
}
impl<K, D> From<D> for KeySelector<K, D>
where
K: ApiKey,
D: KeyDomain,
{
fn from(value: D) -> Self {
Self::Has(vec![value])
}
}
impl<K, D> From<&[D]> for KeySelector<K, D>
where
K: ApiKey,
D: KeyDomain,
{
fn from(value: &[D]) -> Self {
Self::Has(value.to_vec())
}
}
impl<K, D> From<Vec<D>> for KeySelector<K, D>
where
K: ApiKey,
D: KeyDomain,
{
fn from(value: Vec<D>) -> Self {
Self::Has(value)
}
}
pub trait IntoSelector<K, D>: Send
where where
K: ApiKey, K: ApiKey,
D: KeyDomain, D: KeyDomain,
@ -113,134 +129,612 @@ where
fn into_selector(self) -> KeySelector<K, D>; fn into_selector(self) -> KeySelector<K, D>;
} }
impl<K, D> IntoSelector<K, D> for D impl<K, D, T> IntoSelector<K, D> for T
where where
K: ApiKey, K: ApiKey,
D: KeyDomain, D: KeyDomain,
T: Into<KeySelector<K, D>> + Send,
{ {
fn into_selector(self) -> KeySelector<K, D> { fn into_selector(self) -> KeySelector<K, D> {
KeySelector::Has(vec![self]) self.into()
} }
} }
impl<K, D> IntoSelector<K, D> for KeySelector<K, D> pub trait KeyPoolError:
where From<reqwest::Error> + From<serde_json::Error> + From<torn_api::ApiError> + From<Arc<Self>> + Send
K: ApiKey,
D: KeyDomain,
{ {
fn into_selector(self) -> KeySelector<K, D> {
self
}
} }
pub enum KeyAction<D> impl<T> KeyPoolError for T where
where T: From<reqwest::Error>
D: KeyDomain, + From<serde_json::Error>
+ From<torn_api::ApiError>
+ From<Arc<Self>>
+ Send
{ {
Delete,
RemoveDomain(D),
Timeout(chrono::Duration),
} }
#[async_trait] pub trait KeyPoolStorage: Send + Sync {
pub trait KeyPoolStorage {
type Key: ApiKey; type Key: ApiKey;
type Domain: KeyDomain; type Domain: KeyDomain;
type Error: std::error::Error + Sync + Send + Clone; type Error: KeyPoolError;
async fn acquire_key<S>(&self, selector: S) -> Result<Self::Key, Self::Error> fn acquire_key<S>(
&self,
selector: S,
) -> impl Future<Output = Result<Self::Key, Self::Error>> + Send
where where
S: IntoSelector<Self::Key, Self::Domain>; S: IntoSelector<Self::Key, Self::Domain>;
async fn acquire_many_keys<S>( fn acquire_many_keys<S>(
&self, &self,
selector: S, selector: S,
number: i64, number: i64,
) -> Result<Vec<Self::Key>, Self::Error> ) -> impl Future<Output = Result<Vec<Self::Key>, Self::Error>> + Send
where where
S: IntoSelector<Self::Key, Self::Domain>; S: IntoSelector<Self::Key, Self::Domain>;
async fn flag_key(&self, key: Self::Key, code: u8) -> Result<bool, Self::Error>; fn store_key(
async fn store_key(
&self, &self,
user_id: i32, user_id: i32,
key: String, key: String,
domains: Vec<Self::Domain>, domains: Vec<Self::Domain>,
) -> Result<Self::Key, Self::Error>; ) -> impl Future<Output = Result<Self::Key, Self::Error>> + Send;
async fn read_key<S>(&self, selector: S) -> Result<Option<Self::Key>, Self::Error> fn read_key<S>(
&self,
selector: S,
) -> impl Future<Output = Result<Option<Self::Key>, Self::Error>> + Send
where where
S: IntoSelector<Self::Key, Self::Domain>; S: IntoSelector<Self::Key, Self::Domain>;
async fn read_keys<S>(&self, selector: S) -> Result<Vec<Self::Key>, Self::Error> fn read_keys<S>(
&self,
selector: S,
) -> impl Future<Output = Result<Vec<Self::Key>, Self::Error>> + Send
where where
S: IntoSelector<Self::Key, Self::Domain>; S: IntoSelector<Self::Key, Self::Domain>;
async fn remove_key<S>(&self, selector: S) -> Result<Self::Key, Self::Error> fn remove_key<S>(
&self,
selector: S,
) -> impl Future<Output = Result<Self::Key, Self::Error>> + Send
where where
S: IntoSelector<Self::Key, Self::Domain>; S: IntoSelector<Self::Key, Self::Domain>;
async fn add_domain_to_key<S>( fn add_domain_to_key<S>(
&self, &self,
selector: S, selector: S,
domain: Self::Domain, domain: Self::Domain,
) -> Result<Self::Key, Self::Error> ) -> impl Future<Output = Result<Self::Key, Self::Error>> + Send
where where
S: IntoSelector<Self::Key, Self::Domain>; S: IntoSelector<Self::Key, Self::Domain>;
async fn remove_domain_from_key<S>( fn remove_domain_from_key<S>(
&self, &self,
selector: S, selector: S,
domain: Self::Domain, domain: Self::Domain,
) -> Result<Self::Key, Self::Error> ) -> impl Future<Output = Result<Self::Key, Self::Error>> + Send
where where
S: IntoSelector<Self::Key, Self::Domain>; S: IntoSelector<Self::Key, Self::Domain>;
async fn set_domains_for_key<S>( fn set_domains_for_key<S>(
&self, &self,
selector: S, selector: S,
domains: Vec<Self::Domain>, domains: Vec<Self::Domain>,
) -> Result<Self::Key, Self::Error> ) -> impl Future<Output = Result<Self::Key, Self::Error>> + Send
where
S: IntoSelector<Self::Key, Self::Domain>;
fn timeout_key<S>(
&self,
selector: S,
duration: Duration,
) -> impl Future<Output = Result<(), Self::Error>> + Send
where where
S: IntoSelector<Self::Key, Self::Domain>; S: IntoSelector<Self::Key, Self::Domain>;
} }
#[derive(Debug, Default)] #[derive(Default)]
pub struct PoolOptions { pub struct PoolOptions<S>
comment: Option<String>,
hooks_before: std::collections::HashMap<std::any::TypeId, Box<dyn std::any::Any + Send + Sync>>,
hooks_after: std::collections::HashMap<std::any::TypeId, Box<dyn std::any::Any + Send + Sync>>,
}
#[derive(Debug, Clone)]
pub struct KeyPoolExecutor<'a, C, S>
where where
S: KeyPoolStorage, S: KeyPoolStorage,
{ {
storage: &'a S, comment: Option<String>,
options: Arc<PoolOptions>, #[allow(clippy::type_complexity)]
selector: KeySelector<S::Key, S::Domain>, error_hooks: HashMap<
_marker: std::marker::PhantomData<C>, u16,
Box<
dyn for<'a> Fn(&'a S, &'a S::Key) -> BoxFuture<'a, Result<bool, S::Error>>
+ Send
+ Sync,
>,
>,
} }
impl<'a, C, S> KeyPoolExecutor<'a, C, S> pub struct PoolBuilder<S>
where
S: KeyPoolStorage,
{
client: reqwest::Client,
storage: S,
options: crate::PoolOptions<S>,
}
impl<S> PoolBuilder<S>
where
S: KeyPoolStorage,
{
pub fn new(storage: S) -> Self {
Self {
client: reqwest::Client::builder()
.brotli(true)
.http2_keep_alive_timeout(Duration::from_secs(60))
.http2_keep_alive_interval(Duration::from_secs(5))
.https_only(true)
.build()
.unwrap(),
storage,
options: PoolOptions {
comment: None,
error_hooks: Default::default(),
},
}
}
pub fn comment(mut self, c: impl ToString) -> Self {
self.options.comment = Some(c.to_string());
self
}
pub fn error_hook<F>(mut self, code: u16, handler: F) -> Self
where
F: for<'a> Fn(&'a S, &'a S::Key) -> BoxFuture<'a, Result<bool, S::Error>>
+ Send
+ Sync
+ 'static,
{
self.options.error_hooks.insert(code, Box::new(handler));
self
}
pub fn use_default_hooks(self) -> Self {
self.error_hook(2, |storage, key| {
async move {
storage.remove_key(KeySelector::Id(key.id())).await?;
Ok(true)
}
.boxed()
})
.error_hook(5, |storage, key| {
async move {
storage
.timeout_key(KeySelector::Id(key.id()), Duration::from_secs(60))
.await?;
Ok(true)
}
.boxed()
})
.error_hook(10, |storage, key| {
async move {
storage.remove_key(KeySelector::Id(key.id())).await?;
Ok(true)
}
.boxed()
})
.error_hook(13, |storage, key| {
async move {
storage
.timeout_key(KeySelector::Id(key.id()), Duration::from_secs(24 * 3_600))
.await?;
Ok(true)
}
.boxed()
})
.error_hook(18, |storage, key| {
async move {
storage
.timeout_key(KeySelector::Id(key.id()), Duration::from_secs(24 * 3_600))
.await?;
Ok(true)
}
.boxed()
})
}
pub fn build(self) -> KeyPool<S> {
KeyPool {
inner: Arc::new(KeyPoolInner {
client: self.client,
storage: self.storage,
options: self.options,
}),
}
}
}
pub struct KeyPoolInner<S>
where
S: KeyPoolStorage,
{
pub client: reqwest::Client,
pub storage: S,
pub options: PoolOptions<S>,
}
impl<S> KeyPoolInner<S>
where
S: KeyPoolStorage,
{
async fn execute_with_key(
&self,
key: &S::Key,
request: &ApiRequest,
) -> Result<RequestResult, S::Error> {
let mut headers = HeaderMap::with_capacity(1);
headers.insert(
AUTHORIZATION,
HeaderValue::from_str(&format!("ApiKey {}", key.value())).unwrap(),
);
let resp = self
.client
.get(request.url())
.headers(headers)
.send()
.await?;
let status = resp.status();
let bytes = resp.bytes().await?;
if let Some(err) = decode_error(&bytes)? {
if let Some(handler) = self.options.error_hooks.get(&err.code()) {
let retry = (*handler)(&self.storage, key).await?;
if retry {
return Ok(RequestResult::Retry);
}
}
Err(err.into())
} else {
Ok(RequestResult::Response(ApiResponse {
body: Some(bytes),
status,
}))
}
}
async fn execute_request(
&self,
selector: KeySelector<S::Key, S::Domain>,
request: ApiRequest,
) -> Result<ApiResponse, S::Error> {
loop {
let key = self.storage.acquire_key(selector.clone()).await?;
match self.execute_with_key(&key, &request).await {
Ok(RequestResult::Response(resp)) => return Ok(resp),
Ok(RequestResult::Retry) => (),
Err(why) => return Err(why),
}
}
}
async fn execute_bulk_requests<D, T: IntoIterator<Item = (D, ApiRequest)>>(
&self,
selector: KeySelector<S::Key, S::Domain>,
requests: T,
) -> impl Stream<Item = (D, Result<ApiResponse, S::Error>)> + use<'_, D, S, T> {
let requests: Vec<_> = requests.into_iter().collect();
let keys: Vec<_> = match self
.storage
.acquire_many_keys(selector.clone(), requests.len() as i64)
.await
{
Ok(keys) => keys.into_iter().map(Ok).collect(),
Err(why) => {
let why = Arc::new(why);
std::iter::repeat_n(why, requests.len())
.map(|e| Err(S::Error::from(e)))
.collect()
}
};
StreamExt::map(
futures::stream::iter(std::iter::zip(requests, keys)),
move |((discriminant, request), mut maybe_key)| {
let selector = selector.clone();
async move {
loop {
let key = match maybe_key {
Ok(key) => key,
Err(why) => return (discriminant, Err(why)),
};
match self.execute_with_key(&key, &request).await {
Ok(RequestResult::Response(resp)) => return (discriminant, Ok(resp)),
Ok(RequestResult::Retry) => (),
Err(why) => return (discriminant, Err(why)),
}
maybe_key = self.storage.acquire_key(selector.clone()).await;
}
}
},
)
.buffer_unordered(25)
}
}
pub struct KeyPool<S>
where
S: KeyPoolStorage,
{
inner: Arc<KeyPoolInner<S>>,
}
impl<S> Deref for KeyPool<S>
where
S: KeyPoolStorage,
{
type Target = KeyPoolInner<S>;
fn deref(&self) -> &Self::Target {
&self.inner
}
}
enum RequestResult {
Response(ApiResponse),
Retry,
}
impl<S> KeyPool<S>
where
S: KeyPoolStorage + Send + Sync + 'static,
{
pub fn torn_api<I>(&self, selector: I) -> KeyPoolExecutor<S>
where
I: IntoSelector<S::Key, S::Domain>,
{
KeyPoolExecutor::new(self, selector.into_selector())
}
pub fn throttled_torn_api<I>(
&self,
selector: I,
distance: Duration,
) -> ThrottledKeyPoolExecutor<S>
where
I: IntoSelector<S::Key, S::Domain>,
{
ThrottledKeyPoolExecutor::new(self, selector.into_selector(), distance)
}
}
fn decode_error(buf: &[u8]) -> Result<Option<ApiError>, serde_json::Error> {
if buf.starts_with(br#"{"error":{"#) {
#[derive(Deserialize)]
struct ErrorBody<'a> {
code: u16,
error: &'a str,
}
#[derive(Deserialize)]
struct ErrorContainer<'a> {
#[serde(borrow)]
error: ErrorBody<'a>,
}
let error: ErrorContainer = serde_json::from_slice(buf)?;
Ok(Some(crate::ApiError::new(
error.error.code,
error.error.error,
)))
} else {
Ok(None)
}
}
pub struct KeyPoolExecutor<'p, S>
where
S: KeyPoolStorage,
{
pool: &'p KeyPoolInner<S>,
selector: KeySelector<S::Key, S::Domain>,
}
impl<'p, S> KeyPoolExecutor<'p, S>
where
S: KeyPoolStorage,
{
pub fn new(pool: &'p KeyPool<S>, selector: KeySelector<S::Key, S::Domain>) -> Self {
Self {
pool: &pool.inner,
selector,
}
}
}
impl<S> Executor for KeyPoolExecutor<'_, S>
where
S: KeyPoolStorage + 'static,
{
type Error = S::Error;
async fn execute<R>(self, request: R) -> (R::Discriminant, Result<ApiResponse, Self::Error>)
where
R: torn_api::request::IntoRequest,
{
let (d, request) = request.into_request();
(d, self.pool.execute_request(self.selector, request).await)
}
}
impl<S> BulkExecutor for KeyPoolExecutor<'_, S>
where
S: KeyPoolStorage + 'static,
{
type Error = S::Error;
fn execute<R>(
self,
requests: impl IntoIterator<Item = R>,
) -> impl futures::Stream<Item = (R::Discriminant, Result<ApiResponse, Self::Error>)> + Unpin
where
R: torn_api::request::IntoRequest,
{
let requests: Vec<_> = requests.into_iter().map(|r| r.into_request()).collect();
self.pool
.execute_bulk_requests(self.selector.clone(), requests)
.into_stream()
.flatten()
.boxed()
}
}
pub struct ThrottledKeyPoolExecutor<'p, S>
where
S: KeyPoolStorage,
{
pool: &'p KeyPoolInner<S>,
selector: KeySelector<S::Key, S::Domain>,
distance: Duration,
}
impl<S> Clone for ThrottledKeyPoolExecutor<'_, S>
where
S: KeyPoolStorage,
{
fn clone(&self) -> Self {
Self {
pool: self.pool,
selector: self.selector.clone(),
distance: self.distance,
}
}
}
impl<S> ThrottledKeyPoolExecutor<'_, S>
where
S: KeyPoolStorage,
{
async fn execute_request(self, request: ApiRequest) -> Result<ApiResponse, S::Error> {
self.pool.execute_request(self.selector, request).await
}
}
impl<'p, S> ThrottledKeyPoolExecutor<'p, S>
where where
S: KeyPoolStorage, S: KeyPoolStorage,
{ {
pub fn new( pub fn new(
storage: &'a S, pool: &'p KeyPool<S>,
selector: KeySelector<S::Key, S::Domain>, selector: KeySelector<S::Key, S::Domain>,
options: Arc<PoolOptions>, distance: Duration,
) -> Self { ) -> Self {
Self { Self {
storage, pool: &pool.inner,
selector, selector,
options, distance,
_marker: std::marker::PhantomData,
} }
} }
} }
#[cfg(all(test, feature = "postgres"))] impl<S> BulkExecutor for ThrottledKeyPoolExecutor<'_, S>
mod test {} where
S: KeyPoolStorage + 'static,
{
type Error = S::Error;
fn execute<R>(
self,
requests: impl IntoIterator<Item = R>,
) -> impl futures::Stream<Item = (R::Discriminant, Result<ApiResponse, Self::Error>)> + Unpin
where
R: torn_api::request::IntoRequest,
{
let requests: Vec<_> = requests.into_iter().map(|r| r.into_request()).collect();
StreamExt::map(
futures::stream::iter(requests).throttle(self.distance),
move |(d, request)| {
let this = self.clone();
async move {
let result = this.execute_request(request).await;
(d, result)
}
},
)
.buffer_unordered(25)
.boxed()
}
}
#[cfg(test)]
#[cfg(feature = "postgres")]
mod test {
use torn_api::executor::{BulkExecutorExt, ExecutorExt};
use crate::postgres;
use super::*;
#[sqlx::test]
fn name(pool: sqlx::PgPool) {
let (storage, _) = postgres::test::setup(pool).await;
let pool = PoolBuilder::new(storage)
.use_default_hooks()
.comment("test_runner")
.build();
pool.torn_api(postgres::test::Domain::All)
.faction()
.basic(|b| b)
.await
.unwrap();
}
#[sqlx::test]
fn bulk(pool: sqlx::PgPool) {
let (storage, _) = postgres::test::setup(pool).await;
let pool = PoolBuilder::new(storage)
.use_default_hooks()
.comment("test_runner")
.build();
let responses = pool
.torn_api(postgres::test::Domain::All)
.faction_bulk()
.basic_for_id(vec![19.into(), 89.into()], |b| b);
let mut responses: Vec<_> = StreamExt::collect(responses).await;
let (_id1, basic1) = responses.pop().unwrap();
basic1.unwrap();
let (_id2, basic2) = responses.pop().unwrap();
basic2.unwrap();
}
#[sqlx::test]
fn bulk_trottled(pool: sqlx::PgPool) {
let (storage, _) = postgres::test::setup(pool).await;
let pool = PoolBuilder::new(storage)
.use_default_hooks()
.comment("test_runner")
.build();
let responses = pool
.throttled_torn_api(postgres::test::Domain::All, Duration::from_millis(500))
.faction_bulk()
.basic_for_id(vec![19.into(), 89.into()], |b| b);
let mut responses: Vec<_> = StreamExt::collect(responses).await;
let (_id1, basic1) = responses.pop().unwrap();
basic1.unwrap();
let (_id2, basic2) = responses.pop().unwrap();
basic2.unwrap();
}
}

View file

@ -1,206 +0,0 @@
use std::{collections::HashMap, sync::Arc};
use async_trait::async_trait;
use torn_api::{
local::{ApiClient, ApiProvider, RequestExecutor},
ApiRequest, ApiResponse, ApiSelection, ResponseError,
};
use crate::{ApiKey, KeyPoolError, KeyPoolExecutor, KeyPoolStorage, IntoSelector};
#[async_trait(?Send)]
impl<'client, C, S> RequestExecutor<C> for KeyPoolExecutor<'client, C, S>
where
C: ApiClient,
S: KeyPoolStorage + 'static,
{
type Error = KeyPoolError<S::Error, C::Error>;
async fn execute<A>(
&self,
client: &C,
mut request: ApiRequest<A>,
id: Option<String>,
) -> Result<ApiResponse, Self::Error>
where
A: ApiSelection,
{
request.comment = self.comment.map(ToOwned::to_owned);
loop {
let key = self
.storage
.acquire_key(self.selector.clone())
.await
.map_err(|e| KeyPoolError::Storage(Arc::new(e)))?;
let url = request.url(key.value(), id.as_deref());
let value = client.request(url).await?;
match ApiResponse::from_value(value) {
Err(ResponseError::Api { code, reason }) => {
if !self
.storage
.flag_key(key, code)
.await
.map_err(Arc::new)
.map_err(KeyPoolError::Storage)?
{
return Err(KeyPoolError::Response(ResponseError::Api { code, reason }));
}
}
Err(parsing_error) => return Err(KeyPoolError::Response(parsing_error)),
Ok(res) => return Ok(res),
};
}
}
async fn execute_many<A, I>(
&self,
client: &C,
mut request: ApiRequest<A>,
ids: Vec<I>,
) -> HashMap<I, Result<ApiResponse, Self::Error>>
where
A: ApiSelection,
I: ToString + std::hash::Hash + std::cmp::Eq,
{
let keys = match self
.storage
.acquire_many_keys(self.selector.clone(), ids.len() as i64)
.await
{
Ok(keys) => keys,
Err(why) => {
let shared = Arc::new(why);
return ids
.into_iter()
.map(|i| (i, Err(Self::Error::Storage(shared.clone()))))
.collect();
}
};
request.comment = self.comment.map(ToOwned::to_owned);
let request_ref = &request;
let tuples =
futures::future::join_all(std::iter::zip(ids, keys).map(|(id, mut key)| async move {
let id_string = id.to_string();
loop {
let url = request_ref.url(key.value(), Some(&id_string));
let value = match client.request(url).await {
Ok(v) => v,
Err(why) => return (id, Err(Self::Error::Client(why))),
};
match ApiResponse::from_value(value) {
Err(ResponseError::Api { code, reason }) => {
match self.storage.flag_key(key, code).await {
Ok(false) => {
return (
id,
Err(KeyPoolError::Response(ResponseError::Api {
code,
reason,
})),
)
}
Ok(true) => (),
Err(why) => return (id, Err(KeyPoolError::Storage(Arc::new(why)))),
}
}
Err(parsing_error) => {
return (id, Err(KeyPoolError::Response(parsing_error)))
}
Ok(res) => return (id, Ok(res)),
};
key = match self.storage.acquire_key(self.selector.clone()).await {
Ok(k) => k,
Err(why) => return (id, Err(Self::Error::Storage(Arc::new(why)))),
};
}
}))
.await;
HashMap::from_iter(tuples)
}
}
#[derive(Clone, Debug)]
pub struct KeyPool<C, S>
where
C: ApiClient,
S: KeyPoolStorage,
{
client: C,
pub storage: S,
comment: Option<String>,
}
impl<C, S> KeyPool<C, S>
where
C: ApiClient,
S: KeyPoolStorage + 'static,
{
pub fn new(client: C, storage: S, comment: Option<String>) -> Self {
Self {
client,
storage,
comment,
}
}
pub fn torn_api<I>(&self, selector: I) -> ApiProvider<C, KeyPoolExecutor<C, S>> where I: IntoSelector<S::Key, S::Domain> {
ApiProvider::new(
&self.client,
KeyPoolExecutor::new(&self.storage, selector.into_selector(), self.comment.as_deref()),
)
}
}
pub trait WithStorage {
fn with_storage<'a, S, I>(
&'a self,
storage: &'a S,
selector: I
) -> ApiProvider<Self, KeyPoolExecutor<Self, S>>
where
Self: ApiClient + Sized,
S: KeyPoolStorage + 'static,
I: IntoSelector<S::Key, S::Domain>
{
ApiProvider::new(self, KeyPoolExecutor::new(storage, selector.into_selector(), None))
}
}
#[cfg(feature = "awc")]
impl WithStorage for awc::Client {}
#[cfg(all(test, feature = "postgres", feature = "awc"))]
mod test {
use tokio::test;
use super::*;
use crate::postgres::test::{setup, Domain};
#[test]
async fn test_pool_request() {
let storage = setup().await;
let pool = KeyPool::new(awc::Client::default(), storage);
let response = pool.torn_api(Domain::All).user(|b| b).await.unwrap();
_ = response.profile().unwrap();
}
#[test]
async fn test_with_storage_request() {
let storage = setup().await;
let response = awc::Client::new()
.with_storage(&storage, Domain::All)
.user(|b| b)
.await
.unwrap();
_ = response.profile().unwrap();
}
}

View file

@ -1,7 +1,7 @@
use std::sync::Arc; use std::sync::Arc;
use async_trait::async_trait; use futures::future::BoxFuture;
use indoc::indoc; use indoc::formatdoc;
use sqlx::{FromRow, PgPool, Postgres, QueryBuilder}; use sqlx::{FromRow, PgPool, Postgres, QueryBuilder};
use thiserror::Error; use thiserror::Error;
@ -17,28 +17,31 @@ impl<T> PgKeyDomain for T where
{ {
} }
#[derive(Debug, Error, Clone)] #[derive(Debug, Error)]
pub enum PgStorageError<D> pub enum PgKeyPoolError<D>
where where
D: PgKeyDomain, D: PgKeyDomain,
{ {
#[error(transparent)] #[error("Databank: {0}")]
Pg(Arc<sqlx::Error>), Pg(#[from] sqlx::Error),
#[error("Network: {0}")]
Network(#[from] reqwest::Error),
#[error("Parsing: {0}")]
Parsing(#[from] serde_json::Error),
#[error("Api: {0}")]
Api(#[from] torn_api::ApiError),
#[error("No key avalaible for domain {0:?}")] #[error("No key avalaible for domain {0:?}")]
Unavailable(KeySelector<PgKey<D>, D>), Unavailable(KeySelector<PgKey<D>, D>),
#[error("Key not found: '{0:?}'")] #[error("Key not found: '{0:?}'")]
KeyNotFound(KeySelector<PgKey<D>, D>), KeyNotFound(KeySelector<PgKey<D>, D>),
}
impl<D> From<sqlx::Error> for PgStorageError<D> #[error("Failed to acquire keys in bulk: {0}")]
where Bulk(#[from] Arc<Self>),
D: PgKeyDomain,
{
fn from(value: sqlx::Error) -> Self {
Self::Pg(Arc::new(value))
}
} }
#[derive(Debug, Clone, FromRow)] #[derive(Debug, Clone, FromRow)]
@ -95,6 +98,7 @@ where
{ {
pool: PgPool, pool: PgPool,
limit: i16, limit: i16,
schema: Option<String>,
_phantom: std::marker::PhantomData<D>, _phantom: std::marker::PhantomData<D>,
} }
@ -119,62 +123,91 @@ impl<D> PgKeyPoolStorage<D>
where where
D: PgKeyDomain, D: PgKeyDomain,
{ {
pub fn new(pool: PgPool, limit: i16) -> Self { pub fn new(pool: PgPool, limit: i16, schema: Option<String>) -> Self {
Self { Self {
pool, pool,
limit, limit,
schema,
_phantom: Default::default(), _phantom: Default::default(),
} }
} }
pub async fn initialise(&self) -> Result<(), PgStorageError<D>> { fn table_name(&self) -> String {
sqlx::query(indoc! {r#" match self.schema.as_ref() {
CREATE TABLE IF NOT EXISTS api_keys ( Some(schema) => format!("{schema}.api_keys"),
None => "api_keys".to_owned(),
}
}
fn unique_array_fn(&self) -> String {
match self.schema.as_ref() {
Some(schema) => format!("{schema}.__unique_jsonb_array"),
None => "__unique_jsonb_array".to_owned(),
}
}
fn filter_array_fn(&self) -> String {
match self.schema.as_ref() {
Some(schema) => format!("{schema}.__filter_jsonb_array"),
None => "__filter_jsonb_array".to_owned(),
}
}
pub async fn initialise(&self) -> Result<(), PgKeyPoolError<D>> {
if let Some(schema) = self.schema.as_ref() {
sqlx::query(&format!("create schema if not exists {}", schema))
.execute(&self.pool)
.await?;
}
sqlx::query(&formatdoc! {r#"
CREATE TABLE IF NOT EXISTS {} (
id serial primary key, id serial primary key,
user_id int4 not null, user_id int4 not null,
key char(16) not null, key char(16) not null,
uses int2 not null default 0, uses int2 not null default 0,
domains jsonb not null default '{}'::jsonb, domains jsonb not null default '{{}}'::jsonb,
last_used timestamptz not null default now(), last_used timestamptz not null default now(),
flag int2, flag int2,
cooldown timestamptz, cooldown timestamptz,
constraint "uq:api_keys.key" UNIQUE(key) constraint "uq:api_keys.key" UNIQUE(key)
)"# )"#,
self.table_name()
}) })
.execute(&self.pool) .execute(&self.pool)
.await?; .await?;
sqlx::query(indoc! {r#" sqlx::query(&formatdoc! {r#"
CREATE INDEX IF NOT EXISTS "idx:api_keys.domains" ON api_keys USING GIN(domains jsonb_path_ops) CREATE INDEX IF NOT EXISTS "idx:api_keys.domains" ON {} USING GIN(domains jsonb_path_ops)
"#}) "#, self.table_name()})
.execute(&self.pool) .execute(&self.pool)
.await?; .await?;
sqlx::query(indoc! {r#" sqlx::query(&formatdoc! {r#"
CREATE INDEX IF NOT EXISTS "idx:api_keys.user_id" ON api_keys USING BTREE(user_id) CREATE INDEX IF NOT EXISTS "idx:api_keys.user_id" ON {} USING BTREE(user_id)
"#}) "#, self.table_name()})
.execute(&self.pool) .execute(&self.pool)
.await?; .await?;
sqlx::query(indoc! {r#" sqlx::query(&formatdoc! {r#"
create or replace function __unique_jsonb_array(jsonb) returns jsonb create or replace function {}(jsonb) returns jsonb
AS $$ AS $$
select jsonb_agg(d::jsonb) from ( select jsonb_agg(d::jsonb) from (
select distinct jsonb_array_elements_text($1) as d select distinct jsonb_array_elements_text($1) as d
) t ) t
$$ language sql; $$ language sql;
"#}) "#, self.unique_array_fn()})
.execute(&self.pool) .execute(&self.pool)
.await?; .await?;
sqlx::query(indoc! {r#" sqlx::query(&formatdoc! {r#"
create or replace function __filter_jsonb_array(jsonb, jsonb) returns jsonb create or replace function {}(jsonb, jsonb) returns jsonb
AS $$ AS $$
select jsonb_agg(d::jsonb) from ( select jsonb_agg(d::jsonb) from (
select distinct jsonb_array_elements_text($1) as d select distinct jsonb_array_elements_text($1) as d
) t where d<>$2::text ) t where d<>$2::text
$$ language sql; $$ language sql;
"#}) "#, self.filter_array_fn()})
.execute(&self.pool) .execute(&self.pool)
.await?; .await?;
@ -184,19 +217,11 @@ where
#[cfg(feature = "tokio-runtime")] #[cfg(feature = "tokio-runtime")]
async fn random_sleep() { async fn random_sleep() {
use rand::{thread_rng, Rng}; use rand::{rng, Rng};
let dur = tokio::time::Duration::from_millis(thread_rng().gen_range(1..50)); let dur = tokio::time::Duration::from_millis(rng().random_range(1..50));
tokio::time::sleep(dur).await; tokio::time::sleep(dur).await;
} }
#[cfg(all(not(feature = "tokio-runtime"), feature = "actix-runtime"))]
async fn random_sleep() {
use rand::{thread_rng, Rng};
let dur = std::time::Duration::from_millis(thread_rng().gen_range(1..50));
actix_rt::time::sleep(dur).await;
}
#[async_trait]
impl<D> KeyPoolStorage for PgKeyPoolStorage<D> impl<D> KeyPoolStorage for PgKeyPoolStorage<D>
where where
D: PgKeyDomain, D: PgKeyDomain,
@ -204,7 +229,7 @@ where
type Key = PgKey<D>; type Key = PgKey<D>;
type Domain = D; type Domain = D;
type Error = PgStorageError<D>; type Error = PgKeyPoolError<D>;
async fn acquire_key<S>(&self, selector: S) -> Result<Self::Key, Self::Error> async fn acquire_key<S>(&self, selector: S) -> Result<Self::Key, Self::Error>
where where
@ -219,54 +244,52 @@ where
.execute(&mut *tx) .execute(&mut *tx)
.await?; .await?;
let mut qb = QueryBuilder::new(indoc::indoc! { let mut qb = QueryBuilder::new(&formatdoc! {
r#" r#"
with key as ( with key as (
select select
id, id,
0::int2 as uses 0::int2 as uses
from api_keys where last_used < date_trunc('minute', now()) from {} where last_used < date_trunc('minute', now())
and (cooldown is null or now() >= cooldown) and (cooldown is null or now() >= cooldown)
and "# and "#,
self.table_name()
}); });
build_predicate(&mut qb, &selector); build_predicate(&mut qb, &selector);
qb.push(indoc::indoc! { qb.push(formatdoc! {
" "
\n union ( \n union (
select id, uses from api_keys select id, uses from {}
where last_used >= date_trunc('minute', now()) where last_used >= date_trunc('minute', now())
and (cooldown is null or now() >= cooldown) and (cooldown is null or now() >= cooldown)
and " and ",
self.table_name()
}); });
build_predicate(&mut qb, &selector); build_predicate(&mut qb, &selector);
qb.push(indoc::indoc! { qb.push(formatdoc! {
" "
\n order by uses asc limit 1 \n order by uses asc limit 1
) )
order by uses asc limit 1 order by uses asc limit 1
) )
update api_keys set update {} as keys set
uses = key.uses + 1, uses = key.uses + 1,
cooldown = null, cooldown = null,
flag = null, flag = null,
last_used = now() last_used = now()
from key where from key where
api_keys.id=key.id and key.uses < " keys.id=key.id and key.uses < ",
self.table_name()
}); });
qb.push_bind(self.limit); qb.push_bind(self.limit);
qb.push(indoc::indoc! { " qb.push(indoc::indoc! { "
\nreturning \nreturning keys.id, keys.user_id, keys.key, keys.uses, keys.domains"
api_keys.id,
api_keys.user_id,
api_keys.key,
api_keys.uses,
api_keys.domains"
}); });
let key = qb.build_query_as().fetch_optional(&mut *tx).await?; let key = qb.build_query_as().fetch_optional(&mut *tx).await?;
@ -280,13 +303,23 @@ where
match attempt { match attempt {
Ok(Some(result)) => return Ok(result), Ok(Some(result)) => return Ok(result),
Ok(None) => { Ok(None) => {
return self fn recurse<D>(
.acquire_key( storage: &PgKeyPoolStorage<D>,
selector: KeySelector<PgKey<D>, D>,
) -> BoxFuture<Result<PgKey<D>, PgKeyPoolError<D>>>
where
D: PgKeyDomain,
{
Box::pin(storage.acquire_key(selector))
}
return recurse(
self,
selector selector
.fallback() .fallback()
.ok_or_else(|| PgStorageError::Unavailable(selector))?, .ok_or_else(|| PgKeyPoolError::Unavailable(selector))?,
) )
.await .await;
} }
Err(error) => { Err(error) => {
if let Some(db_error) = error.as_database_error() { if let Some(db_error) = error.as_database_error() {
@ -321,19 +354,20 @@ where
.execute(&mut *tx) .execute(&mut *tx)
.await?; .await?;
let mut qb = QueryBuilder::new(indoc::indoc! { let mut qb = QueryBuilder::new(&formatdoc! {
r#"select r#"select
id, id,
user_id, user_id,
key, key,
0::int2 as uses, 0::int2 as uses,
domains domains
from api_keys where last_used < date_trunc('minute', now()) from {} where last_used < date_trunc('minute', now())
and (cooldown is null or now() >= cooldown) and (cooldown is null or now() >= cooldown)
and "# and "#,
self.table_name()
}); });
build_predicate(&mut qb, &selector); build_predicate(&mut qb, &selector);
qb.push(indoc::indoc! { qb.push(formatdoc! {
" "
\nunion \nunion
select select
@ -342,9 +376,10 @@ where
key, key,
uses, uses,
domains domains
from api_keys where last_used >= date_trunc('minute', now()) from {} where last_used >= date_trunc('minute', now())
and (cooldown is null or now() >= cooldown) and (cooldown is null or now() >= cooldown)
and " and ",
self.table_name()
}); });
build_predicate(&mut qb, &selector); build_predicate(&mut qb, &selector);
qb.push("\norder by uses limit "); qb.push("\norder by uses limit ");
@ -365,7 +400,7 @@ where
let available = max.uses - key.uses; let available = max.uses - key.uses;
let using = std::cmp::min(available, (number as i16) - (result.len() as i16)); let using = std::cmp::min(available, (number as i16) - (result.len() as i16));
key.uses += using; key.uses += using;
result.extend(std::iter::repeat(key.clone()).take(using as usize)); result.extend(std::iter::repeat_n(key.clone(), using as usize));
if result.len() == number as usize { if result.len() == number as usize {
break; break;
@ -383,15 +418,15 @@ where
result.extend_from_slice(slice); result.extend_from_slice(slice);
} }
sqlx::query(indoc! {r#" sqlx::query(&formatdoc! {r#"
update api_keys set update {} keys set
uses = tmp.uses, uses = tmp.uses,
cooldown = null, cooldown = null,
flag = null, flag = null,
last_used = now() last_used = now()
from (select unnest($1::int4[]) as id, unnest($2::int2[]) as uses) as tmp from (select unnest($1::int4[]) as id, unnest($2::int2[]) as uses) as tmp
where api_keys.id = tmp.id where keys.id = tmp.id
"#}) "#, self.table_name()})
.bind(keys.iter().map(|k| k.id).collect::<Vec<_>>()) .bind(keys.iter().map(|k| k.id).collect::<Vec<_>>())
.bind(keys.iter().map(|k| k.uses).collect::<Vec<_>>()) .bind(keys.iter().map(|k| k.uses).collect::<Vec<_>>())
.execute(&mut *tx) .execute(&mut *tx)
@ -406,14 +441,25 @@ where
match attempt { match attempt {
Ok(Some(result)) => return Ok(result), Ok(Some(result)) => return Ok(result),
Ok(None) => { Ok(None) => {
return self fn recurse<D>(
.acquire_many_keys( storage: &PgKeyPoolStorage<D>,
selector: KeySelector<PgKey<D>, D>,
number: i64,
) -> BoxFuture<Result<Vec<PgKey<D>>, PgKeyPoolError<D>>>
where
D: PgKeyDomain,
{
Box::pin(storage.acquire_many_keys(selector, number))
}
return recurse(
self,
selector selector
.fallback() .fallback()
.ok_or_else(|| Self::Error::Unavailable(selector))?, .ok_or_else(|| Self::Error::Unavailable(selector))?,
number, number,
) )
.await .await;
} }
Err(error) => { Err(error) => {
if let Some(db_error) = error.as_database_error() { if let Some(db_error) = error.as_database_error() {
@ -431,57 +477,27 @@ where
} }
} }
async fn flag_key(&self, key: Self::Key, code: u8) -> Result<bool, Self::Error> { async fn timeout_key<S>(
match code { &self,
2 | 10 | 13 => { selector: S,
// invalid key, owner fedded or owner inactive duration: std::time::Duration,
sqlx::query( ) -> Result<(), Self::Error>
"update api_keys set cooldown='infinity'::timestamptz, flag=$1 where id=$2", where
) S: IntoSelector<Self::Key, Self::Domain>,
.bind(code as i16) {
.bind(key.id) let selector = selector.into_selector();
.execute(&self.pool)
.await?; let mut qb = QueryBuilder::new(format!(
Ok(true) "update {} set cooldown=now() + ",
} self.table_name()
5 => { ));
// too many requests qb.push_bind(duration);
sqlx::query( qb.push(" where ");
"update api_keys set cooldown=date_trunc('min', now()) + interval '1 min', \ build_predicate(&mut qb, &selector);
flag=5 where id=$1",
) qb.build().fetch_optional(&self.pool).await?;
.bind(key.id)
.execute(&self.pool) Ok(())
.await?;
Ok(true)
}
8 => {
// IP block
sqlx::query("update api_keys set cooldown=now() + interval '5 min', flag=8")
.execute(&self.pool)
.await?;
Ok(false)
}
9 => {
// API disabled
sqlx::query("update api_keys set cooldown=now() + interval '1 min', flag=9")
.execute(&self.pool)
.await?;
Ok(false)
}
14 => {
// daily read limit reached
sqlx::query(
"update api_keys set cooldown=date_trunc('day', now()) + interval '1 day', \
flag=14 where id=$1",
)
.bind(key.id)
.execute(&self.pool)
.await?;
Ok(true)
}
_ => Ok(false),
}
} }
async fn store_key( async fn store_key(
@ -490,11 +506,13 @@ where
key: String, key: String,
domains: Vec<D>, domains: Vec<D>,
) -> Result<Self::Key, Self::Error> { ) -> Result<Self::Key, Self::Error> {
sqlx::query_as( sqlx::query_as(&dbg!(formatdoc!(
"insert into api_keys(user_id, key, domains) values ($1, $2, $3) on conflict on \ "insert into {} as api_keys(user_id, key, domains) values ($1, $2, $3)
constraint \"uq:api_keys.key\" do update set domains = \ on conflict(key) do update
__unique_jsonb_array(excluded.domains || api_keys.domains) returning *", set domains = {}(excluded.domains || api_keys.domains) returning *",
) self.table_name(),
self.unique_array_fn()
)))
.bind(user_id) .bind(user_id)
.bind(&key) .bind(&key)
.bind(sqlx::types::Json(domains)) .bind(sqlx::types::Json(domains))
@ -509,7 +527,7 @@ where
{ {
let selector = selector.into_selector(); let selector = selector.into_selector();
let mut qb = QueryBuilder::new("select * from api_keys where "); let mut qb = QueryBuilder::new(format!("select * from {} where ", self.table_name()));
build_predicate(&mut qb, &selector); build_predicate(&mut qb, &selector);
qb.build_query_as() qb.build_query_as()
@ -524,7 +542,7 @@ where
{ {
let selector = selector.into_selector(); let selector = selector.into_selector();
let mut qb = QueryBuilder::new("select * from api_keys where "); let mut qb = QueryBuilder::new(format!("select * from {} where ", self.table_name()));
build_predicate(&mut qb, &selector); build_predicate(&mut qb, &selector);
qb.build_query_as() qb.build_query_as()
@ -539,14 +557,14 @@ where
{ {
let selector = selector.into_selector(); let selector = selector.into_selector();
let mut qb = QueryBuilder::new("delete from api_keys where "); let mut qb = QueryBuilder::new(format!("delete from {} where ", self.table_name()));
build_predicate(&mut qb, &selector); build_predicate(&mut qb, &selector);
qb.push(" returning *"); qb.push(" returning *");
qb.build_query_as() qb.build_query_as()
.fetch_optional(&self.pool) .fetch_optional(&self.pool)
.await? .await?
.ok_or_else(|| PgStorageError::KeyNotFound(selector)) .ok_or_else(|| PgKeyPoolError::KeyNotFound(selector))
} }
async fn add_domain_to_key<S>(&self, selector: S, domain: D) -> Result<Self::Key, Self::Error> async fn add_domain_to_key<S>(&self, selector: S, domain: D) -> Result<Self::Key, Self::Error>
@ -555,9 +573,11 @@ where
{ {
let selector = selector.into_selector(); let selector = selector.into_selector();
let mut qb = QueryBuilder::new( let mut qb = QueryBuilder::new(format!(
"update api_keys set domains = __unique_jsonb_array(domains || jsonb_build_array(", "update {} set domains = {}(domains || jsonb_build_array(",
); self.table_name(),
self.unique_array_fn()
));
qb.push_bind(sqlx::types::Json(domain)); qb.push_bind(sqlx::types::Json(domain));
qb.push(")) where "); qb.push(")) where ");
build_predicate(&mut qb, &selector); build_predicate(&mut qb, &selector);
@ -566,7 +586,7 @@ where
qb.build_query_as() qb.build_query_as()
.fetch_optional(&self.pool) .fetch_optional(&self.pool)
.await? .await?
.ok_or_else(|| PgStorageError::KeyNotFound(selector)) .ok_or_else(|| PgKeyPoolError::KeyNotFound(selector))
} }
async fn remove_domain_from_key<S>( async fn remove_domain_from_key<S>(
@ -579,9 +599,11 @@ where
{ {
let selector = selector.into_selector(); let selector = selector.into_selector();
let mut qb = QueryBuilder::new( let mut qb = QueryBuilder::new(format!(
"update api_keys set domains = coalesce(__filter_jsonb_array(domains, ", "update {} set domains = coalesce({}(domains, ",
); self.table_name(),
self.filter_array_fn()
));
qb.push_bind(sqlx::types::Json(domain)); qb.push_bind(sqlx::types::Json(domain));
qb.push("), '[]'::jsonb) where "); qb.push("), '[]'::jsonb) where ");
build_predicate(&mut qb, &selector); build_predicate(&mut qb, &selector);
@ -590,7 +612,7 @@ where
qb.build_query_as() qb.build_query_as()
.fetch_optional(&self.pool) .fetch_optional(&self.pool)
.await? .await?
.ok_or_else(|| PgStorageError::KeyNotFound(selector)) .ok_or_else(|| PgKeyPoolError::KeyNotFound(selector))
} }
async fn set_domains_for_key<S>( async fn set_domains_for_key<S>(
@ -612,13 +634,13 @@ where
qb.build_query_as() qb.build_query_as()
.fetch_optional(&self.pool) .fetch_optional(&self.pool)
.await? .await?
.ok_or_else(|| PgStorageError::KeyNotFound(selector)) .ok_or_else(|| PgKeyPoolError::KeyNotFound(selector))
} }
} }
#[cfg(test)] #[cfg(test)]
pub(crate) mod test { pub(crate) mod test {
use std::sync::Arc; use std::{sync::Arc, time::Duration};
use sqlx::Row; use sqlx::Row;
@ -648,11 +670,11 @@ pub(crate) mod test {
.await .await
.unwrap(); .unwrap();
let storage = PgKeyPoolStorage::new(pool.clone(), 1000); let storage = PgKeyPoolStorage::new(pool.clone(), 1000, Some("test".to_owned()));
storage.initialise().await.unwrap(); storage.initialise().await.unwrap();
let key = storage let key = storage
.store_key(1, std::env::var("APIKEY").unwrap(), vec![Domain::All]) .store_key(1, std::env::var("API_KEY").unwrap(), vec![Domain::All])
.await .await
.unwrap(); .unwrap();
@ -816,34 +838,6 @@ pub(crate) mod test {
} }
} }
#[sqlx::test]
async fn test_flag_key_one(pool: PgPool) {
let (storage, key) = setup(pool).await;
assert!(storage.flag_key(key, 2).await.unwrap());
match storage.acquire_key(Domain::All).await.unwrap_err() {
PgStorageError::Unavailable(KeySelector::Has(domains)) => {
assert_eq!(domains, vec![Domain::All])
}
why => panic!("Expected domain unavailable error but found '{why}'"),
}
}
#[sqlx::test]
async fn test_flag_key_many(pool: PgPool) {
let (storage, key) = setup(pool).await;
assert!(storage.flag_key(key, 2).await.unwrap());
match storage.acquire_many_keys(Domain::All, 5).await.unwrap_err() {
PgStorageError::Unavailable(KeySelector::Has(domains)) => {
assert_eq!(domains, vec![Domain::All])
}
why => panic!("Expected domain unavailable error but found '{why}'"),
}
}
#[sqlx::test] #[sqlx::test]
async fn acquire_many(pool: PgPool) { async fn acquire_many(pool: PgPool) {
let (storage, _) = setup(pool).await; let (storage, _) = setup(pool).await;
@ -873,7 +867,7 @@ pub(crate) mod test {
set.join_next().await.unwrap().unwrap(); set.join_next().await.unwrap().unwrap();
} }
let uses: i16 = sqlx::query("select uses from api_keys") let uses: i16 = sqlx::query(&format!("select uses from {}", storage.table_name()))
.fetch_one(&storage.pool) .fetch_one(&storage.pool)
.await .await
.unwrap() .unwrap()
@ -881,7 +875,7 @@ pub(crate) mod test {
assert_eq!(uses, 100); assert_eq!(uses, 100);
sqlx::query("update api_keys set uses=0") sqlx::query(&format!("update {} set uses=0", storage.table_name()))
.execute(&storage.pool) .execute(&storage.pool)
.await .await
.unwrap(); .unwrap();
@ -921,7 +915,7 @@ pub(crate) mod test {
assert_eq!(key.uses, 2); assert_eq!(key.uses, 2);
} }
sqlx::query("update api_keys set uses=0") sqlx::query(&format!("update {} set uses=0", storage.table_name()))
.execute(&storage.pool) .execute(&storage.pool)
.await .await
.unwrap(); .unwrap();
@ -946,7 +940,7 @@ pub(crate) mod test {
set.join_next().await.unwrap().unwrap(); set.join_next().await.unwrap().unwrap();
} }
let uses: i16 = sqlx::query("select uses from api_keys") let uses: i16 = sqlx::query(&format!("select uses from {}", storage.table_name()))
.fetch_one(&storage.pool) .fetch_one(&storage.pool)
.await .await
.unwrap() .unwrap()
@ -954,7 +948,7 @@ pub(crate) mod test {
assert_eq!(uses, 500); assert_eq!(uses, 500);
sqlx::query("update api_keys set uses=0") sqlx::query(&format!("update {} set uses=0", storage.table_name()))
.execute(&storage.pool) .execute(&storage.pool)
.await .await
.unwrap(); .unwrap();
@ -1025,6 +1019,16 @@ pub(crate) mod test {
assert!(key.is_some()); assert!(key.is_some());
} }
#[sqlx::test]
async fn timeout(pool: PgPool) {
let (storage, key) = setup(pool).await;
storage
.timeout_key(KeySelector::Id(key.id()), Duration::from_secs(60))
.await
.unwrap();
}
#[sqlx::test] #[sqlx::test]
async fn query_by_set(pool: PgPool) { async fn query_by_set(pool: PgPool) {
let (storage, _key) = setup(pool).await; let (storage, _key) = setup(pool).await;

View file

@ -1,380 +0,0 @@
use std::{collections::HashMap, sync::Arc};
use async_trait::async_trait;
use torn_api::{
send::{ApiClient, ApiProvider, RequestExecutor},
ApiRequest, ApiResponse, ApiSelection, ResponseError,
};
use crate::{
ApiKey, IntoSelector, KeyAction, KeyDomain, KeyPoolError, KeyPoolExecutor, KeyPoolStorage,
KeySelector, PoolOptions,
};
#[async_trait]
impl<'client, C, S> RequestExecutor<C> for KeyPoolExecutor<'client, C, S>
where
C: ApiClient,
S: KeyPoolStorage + Send + Sync + 'static,
{
type Error = KeyPoolError<S::Error, C::Error>;
async fn execute<A>(
&self,
client: &C,
mut request: ApiRequest<A>,
id: Option<String>,
) -> Result<A::Response, Self::Error>
where
A: ApiSelection,
{
if request.comment.is_none() {
request.comment = self.options.comment.clone();
}
if let Some(hook) = self.options.hooks_before.get(&std::any::TypeId::of::<A>()) {
let concrete = hook
.downcast_ref::<BeforeHook<A, S::Key, S::Domain>>()
.unwrap();
(concrete.body)(&mut request, &self.selector);
}
loop {
let key = self
.storage
.acquire_key(self.selector.clone())
.await
.map_err(KeyPoolError::Storage)?;
let url = request.url(key.value(), id.as_deref());
let value = client.request(url).await?;
match ApiResponse::from_value(value) {
Err(ResponseError::Api { code, reason }) => {
if !self
.storage
.flag_key(key, code)
.await
.map_err(KeyPoolError::Storage)?
{
return Err(KeyPoolError::Response(ResponseError::Api { code, reason }));
}
}
Err(parsing_error) => return Err(KeyPoolError::Response(parsing_error)),
Ok(res) => {
let res = res.into();
if let Some(hook) = self.options.hooks_after.get(&std::any::TypeId::of::<A>()) {
let concrete = hook
.downcast_ref::<AfterHook<A, S::Key, S::Domain>>()
.unwrap();
match (concrete.body)(&res, &self.selector) {
Err(KeyAction::Delete) => {
self.storage
.remove_key(key.selector())
.await
.map_err(KeyPoolError::Storage)?;
continue;
}
Err(KeyAction::RemoveDomain(domain)) => {
self.storage
.remove_domain_from_key(key.selector(), domain)
.await
.map_err(KeyPoolError::Storage)?;
continue;
}
_ => (),
};
}
return Ok(res);
}
};
}
}
async fn execute_many<A, I>(
&self,
client: &C,
mut request: ApiRequest<A>,
ids: Vec<I>,
) -> HashMap<I, Result<A::Response, Self::Error>>
where
A: ApiSelection,
I: ToString + std::hash::Hash + std::cmp::Eq + Send + Sync,
{
let keys = match self
.storage
.acquire_many_keys(self.selector.clone(), ids.len() as i64)
.await
{
Ok(keys) => keys,
Err(why) => {
return ids
.into_iter()
.map(|i| (i, Err(Self::Error::Storage(why.clone()))))
.collect();
}
};
if request.comment.is_none() {
request.comment = self.options.comment.clone();
}
let request_ref = &request;
let tuples =
futures::future::join_all(std::iter::zip(ids, keys).map(|(id, mut key)| async move {
let id_string = id.to_string();
loop {
let url = request_ref.url(key.value(), Some(&id_string));
let value = match client.request(url).await {
Ok(v) => v,
Err(why) => return (id, Err(Self::Error::Client(why))),
};
match ApiResponse::from_value(value) {
Err(ResponseError::Api { code, reason }) => {
match self.storage.flag_key(key, code).await {
Ok(false) => {
return (
id,
Err(KeyPoolError::Response(ResponseError::Api {
code,
reason,
})),
)
}
Ok(true) => (),
Err(why) => return (id, Err(KeyPoolError::Storage(why))),
}
}
Err(parsing_error) => {
return (id, Err(KeyPoolError::Response(parsing_error)))
}
Ok(res) => return (id, Ok(res.into())),
};
key = match self.storage.acquire_key(self.selector.clone()).await {
Ok(k) => k,
Err(why) => return (id, Err(Self::Error::Storage(why))),
};
}
}))
.await;
HashMap::from_iter(tuples)
}
}
#[allow(clippy::type_complexity)]
pub struct BeforeHook<A, K, D>
where
A: ApiSelection,
K: ApiKey,
D: KeyDomain,
{
body: Box<dyn Fn(&mut ApiRequest<A>, &KeySelector<K, D>) + Send + Sync + 'static>,
}
#[allow(clippy::type_complexity)]
pub struct AfterHook<A, K, D>
where
A: ApiSelection,
K: ApiKey,
D: KeyDomain,
{
body: Box<
dyn Fn(&A::Response, &KeySelector<K, D>) -> Result<(), crate::KeyAction<D>>
+ Send
+ Sync
+ 'static,
>,
}
pub struct PoolBuilder<C, S>
where
C: ApiClient,
S: KeyPoolStorage,
{
client: C,
storage: S,
options: crate::PoolOptions,
}
impl<C, S> PoolBuilder<C, S>
where
C: ApiClient,
S: KeyPoolStorage,
{
pub fn new(client: C, storage: S) -> Self {
Self {
client,
storage,
options: Default::default(),
}
}
pub fn comment(mut self, c: impl ToString) -> Self {
self.options.comment = Some(c.to_string());
self
}
pub fn hook_before<A>(
mut self,
hook: impl Fn(&mut ApiRequest<A>, &KeySelector<S::Key, S::Domain>) + Send + Sync + 'static,
) -> Self
where
A: ApiSelection + 'static,
{
self.options.hooks_before.insert(
std::any::TypeId::of::<A>(),
Box::new(BeforeHook {
body: Box::new(hook),
}),
);
self
}
pub fn hook_after<A>(
mut self,
hook: impl Fn(&A::Response, &KeySelector<S::Key, S::Domain>) -> Result<(), KeyAction<S::Domain>>
+ Send
+ Sync
+ 'static,
) -> Self
where
A: ApiSelection + 'static,
{
self.options.hooks_after.insert(
std::any::TypeId::of::<A>(),
Box::new(AfterHook::<A, S::Key, S::Domain> {
body: Box::new(hook),
}),
);
self
}
pub fn build(self) -> KeyPool<C, S> {
KeyPool {
client: self.client,
storage: self.storage,
options: Arc::new(self.options),
}
}
}
#[derive(Clone, Debug)]
pub struct KeyPool<C, S>
where
C: ApiClient,
S: KeyPoolStorage,
{
pub client: C,
pub storage: S,
pub options: Arc<PoolOptions>,
}
impl<C, S> KeyPool<C, S>
where
C: ApiClient,
S: KeyPoolStorage + Send + Sync + 'static,
{
pub fn torn_api<I>(&self, selector: I) -> ApiProvider<C, KeyPoolExecutor<C, S>>
where
I: IntoSelector<S::Key, S::Domain>,
{
ApiProvider::new(
&self.client,
KeyPoolExecutor::new(
&self.storage,
selector.into_selector(),
self.options.clone(),
),
)
}
}
pub trait WithStorage {
fn with_storage<'a, S, I>(
&'a self,
storage: &'a S,
selector: I,
) -> ApiProvider<Self, KeyPoolExecutor<Self, S>>
where
Self: ApiClient + Sized,
S: KeyPoolStorage + Send + Sync + 'static,
I: IntoSelector<S::Key, S::Domain>,
{
ApiProvider::new(
self,
KeyPoolExecutor::new(storage, selector.into_selector(), Default::default()),
)
}
}
#[cfg(feature = "reqwest")]
impl WithStorage for reqwest::Client {}
#[cfg(all(test, feature = "postgres", feature = "reqwest"))]
mod test {
use sqlx::PgPool;
use super::*;
use crate::{
postgres::test::{setup, Domain},
KeySelector,
};
#[sqlx::test]
async fn test_pool_request(pool: PgPool) {
let (storage, _) = setup(pool).await;
let pool = PoolBuilder::new(reqwest::Client::default(), storage)
.comment("api.rs")
.build();
let response = pool.torn_api(Domain::All).user(|b| b).await.unwrap();
_ = response.profile().unwrap();
}
#[sqlx::test]
async fn test_with_storage_request(pool: PgPool) {
let (storage, _) = setup(pool).await;
let response = reqwest::Client::new()
.with_storage(&storage, Domain::All)
.user(|b| b)
.await
.unwrap();
_ = response.profile().unwrap();
}
#[sqlx::test]
async fn before_hook(pool: PgPool) {
let (storage, _) = setup(pool).await;
let pool = PoolBuilder::new(reqwest::Client::default(), storage)
.hook_before::<torn_api::user::UserSelection>(|req, _s| {
req.selections.push("crimes");
})
.build();
let response = pool.torn_api(Domain::All).user(|b| b).await.unwrap();
_ = response.crimes().unwrap();
}
#[sqlx::test]
async fn after_hook(pool: PgPool) {
let (storage, _) = setup(pool).await;
let pool = PoolBuilder::new(reqwest::Client::default(), storage)
.hook_after::<torn_api::user::UserSelection>(|_res, _s| Err(KeyAction::Delete))
.build();
let key = pool.storage.read_key(KeySelector::Id(1)).await.unwrap();
assert!(key.is_some());
let response = pool.torn_api(Domain::All).user(|b| b).await;
assert!(matches!(response, Err(KeyPoolError::Storage(_))));
let key = pool.storage.read_key(KeySelector::Id(1)).await.unwrap();
assert!(key.is_none());
}
}