torn-api.rs/torn-api-codegen/src/model/enum.rs

543 lines
16 KiB
Rust

use heck::{ToSnakeCase, ToUpperCamelCase};
use proc_macro2::TokenStream;
use quote::{format_ident, quote};
use syn::Ident;
use crate::openapi::{
parameter::OpenApiParameterSchema,
r#type::{OpenApiType, OpenApiVariants},
};
use super::{object::PrimitiveType, Model, ResolvedSchema};
/// Integer representation for an enum whose variants carry explicit
/// numeric discriminants.
///
/// Each variant selects the matching `#[repr(...)]` attribute in the
/// generated code.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EnumRepr {
    /// Generates `#[repr(u8)]`.
    U8,
    /// Generates `#[repr(u32)]`.
    U32,
}
/// A single payload type carried by a tuple variant of a generated enum.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum EnumVariantTupleValue {
    /// Reference to a named schema component; `ty_name` is the component
    /// name with the `#/components/schemas/` prefix stripped.
    Ref { ty_name: String },
    /// Array whose items reference a named schema component.
    ArrayOfRefs { ty_name: String },
    /// A primitive scalar payload.
    Primitive(PrimitiveType),
    /// An inline string enum that is generated as its own type, named `name`.
    Enum { name: String, inner: Enum },
}
impl EnumVariantTupleValue {
    /// Builds a tuple payload description from an OpenAPI schema.
    ///
    /// `name` is the name of the enclosing enum. It is used to derive the name
    /// (`{name}Variant`) of an inline string enum.
    ///
    /// Returns `None` for unsupported schema shapes. This includes `$ref`s that do not
    /// point into `#/components/schemas/` and arrays whose items are not `$ref`s.
    pub fn from_schema(name: &str, schema: &OpenApiType) -> Option<Self> {
        match schema {
            OpenApiType {
                ref_path: Some(path),
                ..
            } => Some(Self::Ref {
                ty_name: path.strip_prefix("#/components/schemas/")?.to_owned(),
            }),
            OpenApiType {
                r#type: Some("array"),
                items: Some(items),
                ..
            } => {
                // Only arrays of `$ref`s are supported; inline item schemas are rejected.
                let OpenApiType {
                    ref_path: Some(path),
                    ..
                } = items.as_ref()
                else {
                    return None;
                };
                Some(Self::ArrayOfRefs {
                    ty_name: path.strip_prefix("#/components/schemas/")?.to_owned(),
                })
            }
            OpenApiType {
                r#type: Some("string"),
                format: None,
                r#enum: None,
                ..
            } => Some(Self::Primitive(PrimitiveType::String)),
            OpenApiType {
                r#type: Some("string"),
                format: None,
                r#enum: Some(_),
                ..
            } => {
                let name = format!("{name}Variant");
                Some(Self::Enum {
                    inner: Enum::from_schema(&name, schema)?,
                    name,
                })
            }
            OpenApiType {
                r#type: Some("integer"),
                format: Some("int64"),
                ..
            } => Some(Self::Primitive(PrimitiveType::I64)),
            OpenApiType {
                r#type: Some("integer"),
                format: Some("int32"),
                ..
            } => Some(Self::Primitive(PrimitiveType::I32)),
            _ => None,
        }
    }

    /// Returns the Rust type of this payload as tokens.
    ///
    /// For `Enum`, the type is addressed through the enclosing enum's namespace
    /// module, so this allocates the module identifier on `ns`.
    pub fn type_name(&self, ns: &mut EnumNamespace) -> TokenStream {
        match self {
            Self::Ref { ty_name } => {
                let ty = format_ident!("{ty_name}");
                quote! { crate::models::#ty }
            }
            Self::ArrayOfRefs { ty_name } => {
                let ty = format_ident!("{ty_name}");
                quote! { Vec<crate::models::#ty> }
            }
            Self::Primitive(PrimitiveType::I64) => quote! { i64 },
            Self::Primitive(PrimitiveType::I32) => quote! { i32 },
            Self::Primitive(PrimitiveType::Float) => quote! { f32 },
            Self::Primitive(PrimitiveType::String) => quote! { String },
            Self::Primitive(PrimitiveType::DateTime) => quote! { chrono::DateTime<chrono::Utc> },
            Self::Primitive(PrimitiveType::Bool) => quote! { bool },
            Self::Enum { name, .. } => {
                let path = ns.get_ident();
                let ty_name = format_ident!("{name}");
                // No trailing comma: these tokens are spliced into comma-separated
                // lists (tuple fields, generic arguments, fn parameters) by callers.
                quote! { #path::#ty_name }
            }
        }
    }

    /// Variant name used when this payload is one alternative of a `oneOf` enum.
    pub fn name(&self) -> String {
        match self {
            Self::Ref { ty_name } => ty_name.clone(),
            Self::ArrayOfRefs { ty_name } => format!("{ty_name}s"),
            Self::Primitive(PrimitiveType::I64) => "I64".to_owned(),
            Self::Primitive(PrimitiveType::I32) => "I32".to_owned(),
            Self::Primitive(PrimitiveType::Float) => "Float".to_owned(),
            Self::Primitive(PrimitiveType::String) => "String".to_owned(),
            Self::Primitive(PrimitiveType::DateTime) => "DateTime".to_owned(),
            Self::Primitive(PrimitiveType::Bool) => "Bool".to_owned(),
            Self::Enum { .. } => "Variant".to_owned(),
        }
    }

    /// Whether this payload can be rendered with `Display` in generated code.
    ///
    /// Primitives always can. Referenced models defer to the resolved model and
    /// count as not displayable when unknown. Inline enums defer to the inner enum.
    pub fn is_display(&self, resolved: &ResolvedSchema) -> bool {
        match self {
            Self::Primitive(_) => true,
            Self::Ref { ty_name } | Self::ArrayOfRefs { ty_name } => resolved
                .models
                .get(ty_name)
                .map(|f| f.is_display(resolved))
                .unwrap_or_default(),
            Self::Enum { inner, .. } => inner.is_display(resolved),
        }
    }

    /// Generated `Display` body for a match arm that binds the payload as `value`
    /// and the formatter as `f`.
    ///
    /// Arrays are rendered as their elements joined with `,`.
    pub fn codegen_display(&self) -> TokenStream {
        match self {
            Self::ArrayOfRefs { .. } => quote! {
                write!(f, "{}", value.iter().map(ToString::to_string).collect::<Vec<_>>().join(","))
            },
            _ => quote! {
                write!(f, "{}", value)
            },
        }
    }

    /// Whether this payload supports `Eq`/`Hash` in generated code.
    ///
    /// Referenced models qualify only when they resolve to a newtype.
    pub fn is_comparable(&self, resolved: &ResolvedSchema) -> bool {
        match self {
            Self::Primitive(_) => true,
            Self::Enum { inner, .. } => inner.is_comparable(resolved),
            Self::Ref { ty_name } | Self::ArrayOfRefs { ty_name } => resolved
                .models
                .get(ty_name)
                .map(|m| matches!(m, Model::Newtype(_)))
                .unwrap_or_default(),
        }
    }
}
/// How a single variant of a generated enum is represented.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum EnumVariantValue {
    /// Unit variant with an explicit integer discriminant.
    Repr(u32),
    /// Unit variant, optionally renamed for serde (the rename is also what
    /// `Display` prints).
    String { rename: Option<String> },
    /// Tuple variant wrapping the listed payload values.
    Tuple(Vec<EnumVariantTupleValue>),
}

impl Default for EnumVariantValue {
    /// A plain string variant without a rename.
    fn default() -> Self {
        EnumVariantValue::String { rename: Option::None }
    }
}
impl EnumVariantValue {
    /// Whether a `Display` arm can be generated for a variant with this value.
    ///
    /// Unit-like values (`Repr`, `String`) always qualify. Tuple values qualify only
    /// when they wrap exactly one displayable payload.
    pub fn is_display(&self, resolved: &ResolvedSchema) -> bool {
        match self {
            Self::Repr(_) | Self::String { .. } => true,
            Self::Tuple(values) => {
                matches!(values.as_slice(), [value] if value.is_display(resolved))
            }
        }
    }

    /// Whether a variant with this value supports `Eq`/`Hash`; tuples require
    /// every payload to be comparable.
    pub fn is_comparable(&self, resolved: &ResolvedSchema) -> bool {
        match self {
            Self::Repr(_) | Self::String { .. } => true,
            Self::Tuple(values) => values.iter().all(|v| v.is_comparable(resolved)),
        }
    }

    /// Generates the `Display` match arm for the variant named `name`.
    ///
    /// `Repr` prints the discriminant, and `String` prints the rename (or `name`).
    /// A single-value `Tuple` prints its payload. Returns `None` for tuples with
    /// zero or several payloads.
    pub fn codegen_display(&self, name: &str) -> Option<TokenStream> {
        let variant = format_ident!("{name}");
        match self {
            Self::Repr(i) => Some(quote! { Self::#variant => write!(f, "{}", #i) }),
            Self::String { rename } => {
                let text = rename.as_deref().unwrap_or(name);
                // `write_str`, not `write!(f, #text)`: the text comes from the schema
                // and must not be parsed as a format string (`{`/`}` would break the
                // generated code).
                Some(quote! { Self::#variant => f.write_str(#text) })
            }
            Self::Tuple(values) => match values.as_slice() {
                [value] => {
                    let rhs = value.codegen_display();
                    Some(quote! { Self::#variant(value) => #rhs })
                }
                _ => None,
            },
        }
    }
}
/// One variant of a generated enum.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct EnumVariant {
    /// Rust identifier of the variant.
    pub name: String,
    /// Emitted as the variant's `#[doc]` attribute when present.
    pub description: Option<String>,
    /// Discriminant, serde rename, or tuple payload of the variant.
    pub value: EnumVariantValue,
}
/// Collects auxiliary items produced while generating a single enum.
///
/// `elements` are wrapped in a `pub mod` named after the enum in snake case.
/// `top_level_elements` are emitted next to the enum itself.
pub struct EnumNamespace<'e> {
    /// The enum currently being generated.
    r#enum: &'e Enum,
    /// Lazily computed identifier of the namespace module.
    ident: Option<Ident>,
    /// Items placed inside the namespace module.
    elements: Vec<TokenStream>,
    /// Items placed alongside the enum, outside the module.
    top_level_elements: Vec<TokenStream>,
}
impl EnumNamespace<'_> {
    /// Identifier of the namespace module: the enum name in snake case.
    ///
    /// Computed on first call and cached for later calls.
    pub fn get_ident(&mut self) -> Ident {
        if let Some(cached) = &self.ident {
            return cached.clone();
        }
        let module_ident = format_ident!("{}", self.r#enum.name.to_snake_case());
        self.ident = Some(module_ident.clone());
        module_ident
    }

    /// Queues an item to be placed inside the namespace module.
    pub fn push_element(&mut self, el: TokenStream) {
        self.elements.push(el);
    }

    /// Queues an item to be placed next to the enum, outside the module.
    pub fn push_top_level(&mut self, el: TokenStream) {
        self.top_level_elements.push(el);
    }

    /// Renders the collected items.
    ///
    /// Top-level items come first, followed by `pub mod <ident> { ... }` when at
    /// least one module element was queued. Returns `None` when nothing was queued.
    pub fn codegen(mut self) -> Option<TokenStream> {
        if self.elements.is_empty() {
            if self.top_level_elements.is_empty() {
                return None;
            }
            let outside = self.top_level_elements;
            return Some(quote! {
                #(#outside)*
            });
        }

        let module_ident = self.get_ident();
        let outside = self.top_level_elements;
        let inside = self.elements;
        Some(quote! {
            #(#outside)*
            pub mod #module_ident {
                #(#inside)*
            }
        })
    }
}
impl EnumVariant {
    /// Generates this variant's declaration as it appears inside `pub enum .. { .. }`.
    ///
    /// Tuple variants may also register items on `ns`. A single-payload variant
    /// gets a top-level `From<Payload>` impl. Each inline enum payload has its
    /// generated code placed in the namespace module. Returns `None` if an inline
    /// enum fails to generate.
    pub fn codegen(
        &self,
        ns: &mut EnumNamespace,
        resolved: &ResolvedSchema,
    ) -> Option<TokenStream> {
        let doc = self.description.as_ref().map(|d| {
            quote! {
                #[doc = #d]
            }
        });
        let name = format_ident!("{}", self.name);
        match &self.value {
            EnumVariantValue::Repr(repr) => Some(quote! {
                #doc
                #name = #repr
            }),
            EnumVariantValue::String { rename } => {
                let serde_attr = rename.as_ref().map(|r| {
                    quote! {
                        #[serde(rename = #r)]
                    }
                });
                Some(quote! {
                    #doc
                    #serde_attr
                    #name
                })
            }
            EnumVariantValue::Tuple(values) => {
                let mut val_tys = Vec::with_capacity(values.len());
                for value in values {
                    val_tys.push(value.type_name(ns));
                    if let EnumVariantTupleValue::Enum { inner, .. } = value {
                        ns.push_element(inner.codegen(resolved)?);
                    }
                }
                // Single-payload variants get a `From` impl so the payload converts
                // into the enum via `.into()`. Reuses the type tokens computed above.
                if let [ty_name] = val_tys.as_slice() {
                    let enum_name = format_ident!("{}", ns.r#enum.name);
                    ns.push_top_level(quote! {
                        impl From<#ty_name> for #enum_name {
                            fn from(value: #ty_name) -> Self {
                                Self::#name(value)
                            }
                        }
                    });
                }
                // Keep the variant's doc comment, as the unit-variant arms do.
                Some(quote! {
                    #doc
                    #name(#(#val_tys),*)
                })
            }
        }
    }

    /// Generates this variant's `Display` match arm; see
    /// `EnumVariantValue::codegen_display`.
    pub fn codegen_display(&self) -> Option<TokenStream> {
        self.value.codegen_display(&self.name)
    }
}
/// Intermediate description of a Rust enum to be generated.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct Enum {
    /// Rust type name of the enum.
    pub name: String,
    /// Emitted as the enum's `#[doc]` attribute when present.
    pub description: Option<String>,
    /// Integer representation. When set, `serde_repr::Deserialize_repr` is derived
    /// instead of `serde::Deserialize`.
    pub repr: Option<EnumRepr>,
    /// Whether `Copy` is derived.
    pub copy: bool,
    /// Whether `#[serde(untagged)]` is applied.
    pub untagged: bool,
    /// Variants in declaration order.
    pub variants: Vec<EnumVariant>,
}
impl Enum {
    /// Builds an enum from a schema carrying an `enum` list.
    ///
    /// Integer lists produce a `u32`-repr enum with `Variant{n}` variants. String
    /// lists produce upper-camel-case variants, with `&` spelled `And`. A serde
    /// rename to the original value is added whenever the Rust name differs.
    /// Returns `None` if the schema has no `enum` list.
    pub fn from_schema(name: &str, schema: &OpenApiType) -> Option<Self> {
        let mut result = Enum {
            name: name.to_owned(),
            description: schema.description.as_deref().map(ToOwned::to_owned),
            copy: true,
            ..Default::default()
        };
        match &schema.r#enum {
            Some(OpenApiVariants::Int(int_variants)) => {
                result.repr = Some(EnumRepr::U32);
                result.variants = int_variants
                    .iter()
                    .copied()
                    .map(|i| EnumVariant {
                        name: format!("Variant{i}"),
                        // NOTE(review): `as` wraps values outside the u32 range (e.g.
                        // negatives); confirm the schema only lists non-negative ints.
                        value: EnumVariantValue::Repr(i as u32),
                        ..Default::default()
                    })
                    .collect();
            }
            Some(OpenApiVariants::Str(str_variants)) => {
                result.variants = str_variants
                    .iter()
                    .copied()
                    .map(|s| {
                        let transformed = s.replace('&', "And").to_upper_camel_case();
                        EnumVariant {
                            value: EnumVariantValue::String {
                                rename: (transformed != s).then(|| s.to_owned()),
                            },
                            name: transformed,
                            ..Default::default()
                        }
                    })
                    .collect();
            }
            None => return None,
        }
        Some(result)
    }

    /// Builds a string enum from a parameter schema's `enum` list.
    ///
    /// Variants are named in upper camel case. When that differs from the wire
    /// value, the original value becomes the serde rename, which `Display` also
    /// prints. Returns `None` if the schema has no `enum` list.
    pub fn from_parameter_schema(name: &str, schema: &OpenApiParameterSchema) -> Option<Self> {
        let mut result = Self {
            name: name.to_owned(),
            copy: true,
            ..Default::default()
        };
        for var in schema.r#enum.as_ref()? {
            let transformed = var.to_upper_camel_case();
            result.variants.push(EnumVariant {
                value: EnumVariantValue::String {
                    // Rename to the *original* value (as `from_schema` does). Renaming to
                    // `transformed` would be a no-op, so (de)serialization and `Display`
                    // would use the camel-cased name instead of the API's value.
                    rename: (transformed != *var).then(|| var.to_string()),
                },
                name: transformed,
                ..Default::default()
            });
        }
        Some(result)
    }

    /// Builds an untagged enum with one single-payload tuple variant per `oneOf`
    /// alternative.
    ///
    /// Returns `None` if any alternative is not supported by
    /// `EnumVariantTupleValue::from_schema`.
    // NOTE(review): alternatives that map to the same `name()` (e.g. two inline
    // string enums) would produce duplicate variant names; confirm this cannot occur.
    pub fn from_one_of(name: &str, schemas: &[OpenApiType]) -> Option<Self> {
        let mut result = Self {
            name: name.to_owned(),
            untagged: true,
            ..Default::default()
        };
        for schema in schemas {
            let value = EnumVariantTupleValue::from_schema(name, schema)?;
            let name = value.name();
            result.variants.push(EnumVariant {
                name,
                value: EnumVariantValue::Tuple(vec![value]),
                ..Default::default()
            });
        }
        Some(result)
    }

    /// Whether every variant can be rendered by a generated `Display` impl.
    pub fn is_display(&self, resolved: &ResolvedSchema) -> bool {
        self.variants.iter().all(|v| v.value.is_display(resolved))
    }

    /// Whether every variant supports `Eq`/`Hash`.
    pub fn is_comparable(&self, resolved: &ResolvedSchema) -> bool {
        self.variants
            .iter()
            .all(|v| v.value.is_comparable(resolved))
    }

    /// Generates the enum definition, plus `Display` when every variant supports
    /// it, plus any namespace items collected from its variants.
    ///
    /// Returns `None` if any variant (or nested inline enum) fails to generate.
    pub fn codegen(&self, resolved: &ResolvedSchema) -> Option<TokenStream> {
        // The repr attribute is emitted on its own, independent of the description.
        // `serde_repr::Deserialize_repr` requires it, and it must not disappear for
        // enums without a doc comment.
        let repr = self.repr.map(|r| match r {
            EnumRepr::U8 => quote! { #[repr(u8)] },
            EnumRepr::U32 => quote! { #[repr(u32)] },
        });
        let name = format_ident!("{}", self.name);
        let desc = self.description.as_ref().map(|d| {
            quote! {
                #[doc = #d]
            }
        });
        let mut ns = EnumNamespace {
            r#enum: self,
            ident: None,
            elements: Default::default(),
            top_level_elements: Default::default(),
        };
        let is_display = self.is_display(resolved);
        let mut display = Vec::with_capacity(self.variants.len());
        let mut variants = Vec::with_capacity(self.variants.len());
        for variant in &self.variants {
            variants.push(variant.codegen(&mut ns, resolved)?);
            if is_display {
                display.push(variant.codegen_display()?);
            }
        }
        let mut derives = vec![];
        if self.repr.is_some() {
            derives.push(quote! { serde_repr::Deserialize_repr });
        } else {
            derives.push(quote! { serde::Deserialize });
        }
        if self.copy {
            derives.push(quote! { Copy });
        }
        if self.is_comparable(resolved) {
            derives.push(quote! { Eq, Hash });
        }
        let serde_attr = self.untagged.then(|| {
            quote! {
                #[serde(untagged)]
            }
        });
        let display = is_display.then(|| {
            quote! {
                impl std::fmt::Display for #name {
                    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
                        match self {
                            #(#display),*
                        }
                    }
                }
            }
        });
        let module = ns.codegen();
        Some(quote! {
            #desc
            #repr
            #[derive(Debug, Clone, PartialEq, #(#derives),*)]
            #[cfg_attr(feature = "strum", derive(strum::EnumIs, strum::EnumTryAs))]
            #serde_attr
            pub enum #name {
                #(#variants),*
            }
            #display
            #module
        })
    }
}
#[cfg(test)]
mod test {
    use super::*;
    use crate::openapi::schema::test::get_schema;

    /// The `TornSelectionName` model resolved from the bundled schema must be
    /// displayable.
    #[test]
    fn is_display() {
        let spec = get_schema();
        let resolved_schema = ResolvedSchema::from_open_api(&spec);
        let selection_model = resolved_schema.models.get("TornSelectionName").unwrap();
        assert!(selection_model.is_display(&resolved_schema));
    }
}