use heck::{ToSnakeCase, ToUpperCamelCase}; use proc_macro2::TokenStream; use quote::{format_ident, quote}; use syn::Ident; use crate::openapi::{ parameter::OpenApiParameterSchema, r#type::{OpenApiType, OpenApiVariants}, }; use super::{object::PrimitiveType, Model, ResolvedSchema}; #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum EnumRepr { U8, U32, } #[derive(Debug, Clone, PartialEq, Eq)] pub enum EnumVariantTupleValue { Ref { ty_name: String }, ArrayOfRefs { ty_name: String }, Primitive(PrimitiveType), Enum { name: String, inner: Enum }, } impl EnumVariantTupleValue { pub fn from_schema(name: &str, schema: &OpenApiType) -> Option { match schema { OpenApiType { ref_path: Some(path), .. } => Some(Self::Ref { ty_name: path.strip_prefix("#/components/schemas/")?.to_owned(), }), OpenApiType { r#type: Some("array"), items: Some(items), .. } => { let OpenApiType { ref_path: Some(path), .. } = items.as_ref() else { return None; }; Some(Self::ArrayOfRefs { ty_name: path.strip_prefix("#/components/schemas/")?.to_owned(), }) } OpenApiType { r#type: Some("string"), format: None, r#enum: None, .. } => Some(Self::Primitive(PrimitiveType::String)), OpenApiType { r#type: Some("string"), format: None, r#enum: Some(_), .. } => { let name = format!("{name}Variant"); Some(Self::Enum { inner: Enum::from_schema(&name, schema)?, name, }) } OpenApiType { r#type: Some("integer"), format: Some("int64"), .. } => Some(Self::Primitive(PrimitiveType::I64)), OpenApiType { r#type: Some("integer"), format: Some("int32"), .. } => Some(Self::Primitive(PrimitiveType::I32)), _ => None, } } pub fn type_name(&self, ns: &mut EnumNamespace) -> TokenStream { match self { Self::Ref { ty_name } => { let ty = format_ident!("{ty_name}"); quote! { crate::models::#ty } } Self::ArrayOfRefs { ty_name } => { let ty = format_ident!("{ty_name}"); quote! { Vec } } Self::Primitive(PrimitiveType::I64) => quote! { i64 }, Self::Primitive(PrimitiveType::I32) => quote! { i32 }, Self::Primitive(PrimitiveType::Float) => quote! { f32 }, Self::Primitive(PrimitiveType::String) => quote! { String }, Self::Primitive(PrimitiveType::DateTime) => quote! { chrono::DateTime }, Self::Primitive(PrimitiveType::Bool) => quote! { bool }, Self::Enum { name, .. } => { let path = ns.get_ident(); let ty_name = format_ident!("{name}"); quote! { #path::#ty_name, } } } } pub fn name(&self) -> String { match self { Self::Ref { ty_name } => ty_name.clone(), Self::ArrayOfRefs { ty_name } => format!("{ty_name}s"), Self::Primitive(PrimitiveType::I64) => "I64".to_owned(), Self::Primitive(PrimitiveType::I32) => "I32".to_owned(), Self::Primitive(PrimitiveType::Float) => "Float".to_owned(), Self::Primitive(PrimitiveType::String) => "String".to_owned(), Self::Primitive(PrimitiveType::DateTime) => "DateTime".to_owned(), Self::Primitive(PrimitiveType::Bool) => "Bool".to_owned(), Self::Enum { .. } => "Variant".to_owned(), } } pub fn is_display(&self, resolved: &ResolvedSchema) -> bool { match self { Self::Primitive(_) => true, Self::Ref { ty_name } | Self::ArrayOfRefs { ty_name } => resolved .models .get(ty_name) .map(|f| f.is_display(resolved)) .unwrap_or_default(), Self::Enum { inner, .. } => inner.is_display(resolved), } } pub fn codegen_display(&self) -> TokenStream { match self { Self::ArrayOfRefs { .. } => quote! { write!(f, "{}", value.iter().map(ToString::to_string).collect::>().join(",")) }, _ => quote! { write!(f, "{}", value) }, } } pub fn is_comparable(&self, resolved: &ResolvedSchema) -> bool { match self { Self::Primitive(_) => true, Self::Enum { inner, .. } => inner.is_comparable(resolved), Self::Ref { ty_name } | Self::ArrayOfRefs { ty_name } => resolved .models .get(ty_name) .map(|m| matches!(m, Model::Newtype(_))) .unwrap_or_default(), } } } #[derive(Debug, Clone, PartialEq, Eq)] pub enum EnumVariantValue { Repr(u32), String { rename: Option }, Tuple(Vec), } impl Default for EnumVariantValue { fn default() -> Self { Self::String { rename: None } } } impl EnumVariantValue { pub fn is_display(&self, resolved: &ResolvedSchema) -> bool { match self { Self::Repr(_) | Self::String { .. } => true, Self::Tuple(val) => { val.len() == 1 && val .iter() .next() .map(|v| v.is_display(resolved)) .unwrap_or_default() } } } pub fn is_comparable(&self, resolved: &ResolvedSchema) -> bool { match self { Self::Repr(_) | Self::String { .. } => true, Self::Tuple(values) => values.iter().all(|v| v.is_comparable(resolved)), } } pub fn codegen_display(&self, name: &str) -> Option { let variant = format_ident!("{name}"); match self { Self::Repr(i) => Some(quote! { Self::#variant => write!(f, "{}", #i) }), Self::String { rename } => { let name = rename.as_deref().unwrap_or(name); Some(quote! { Self::#variant => write!(f, #name) }) } Self::Tuple(values) if values.len() == 1 => { let rhs = values.first().unwrap().codegen_display(); Some(quote! { Self::#variant(value) => #rhs }) } _ => None, } } } #[derive(Debug, Clone, PartialEq, Eq, Default)] pub struct EnumVariant { pub name: String, pub description: Option, pub value: EnumVariantValue, } pub struct EnumNamespace<'e> { r#enum: &'e Enum, ident: Option, elements: Vec, top_level_elements: Vec, } impl EnumNamespace<'_> { pub fn get_ident(&mut self) -> Ident { self.ident .get_or_insert_with(|| { let name = self.r#enum.name.to_snake_case(); format_ident!("{name}") }) .clone() } pub fn push_element(&mut self, el: TokenStream) { self.elements.push(el); } pub fn push_top_level(&mut self, el: TokenStream) { self.top_level_elements.push(el); } pub fn codegen(mut self) -> Option { if self.elements.is_empty() && self.top_level_elements.is_empty() { None } else { let top_level = &self.top_level_elements; let mut output = quote! { #(#top_level)* }; if !self.elements.is_empty() { let ident = self.get_ident(); let elements = self.elements; output.extend(quote! { pub mod #ident { #(#elements)* } }); } Some(output) } } } impl EnumVariant { pub fn codegen( &self, ns: &mut EnumNamespace, resolved: &ResolvedSchema, ) -> Option { let doc = self.description.as_ref().map(|d| { quote! { #[doc = #d] } }); let name = format_ident!("{}", self.name); match &self.value { EnumVariantValue::Repr(repr) => Some(quote! { #doc #name = #repr }), EnumVariantValue::String { rename } => { let serde_attr = rename.as_ref().map(|r| { quote! { #[serde(rename = #r)] } }); Some(quote! { #doc #serde_attr #name }) } EnumVariantValue::Tuple(values) => { let mut val_tys = Vec::with_capacity(values.len()); if let [value] = values.as_slice() { let enum_name = format_ident!("{}", ns.r#enum.name); let ty_name = value.type_name(ns); ns.push_top_level(quote! { impl From<#ty_name> for #enum_name { fn from(value: #ty_name) -> Self { Self::#name(value) } } }); } for value in values { let ty_name = value.type_name(ns); if let EnumVariantTupleValue::Enum { inner, .. } = &value { ns.push_element(inner.codegen(resolved)?); } val_tys.push(ty_name); } Some(quote! { #name(#(#val_tys),*) }) } } } pub fn codegen_display(&self) -> Option { self.value.codegen_display(&self.name) } } #[derive(Debug, Clone, PartialEq, Eq, Default)] pub struct Enum { pub name: String, pub description: Option, pub repr: Option, pub copy: bool, pub untagged: bool, pub variants: Vec, } impl Enum { pub fn from_schema(name: &str, schema: &OpenApiType) -> Option { let mut result = Enum { name: name.to_owned(), description: schema.description.as_deref().map(ToOwned::to_owned), copy: true, ..Default::default() }; match &schema.r#enum { Some(OpenApiVariants::Int(int_variants)) => { result.repr = Some(EnumRepr::U32); result.variants = int_variants .iter() .copied() .map(|i| EnumVariant { name: format!("Variant{i}"), value: EnumVariantValue::Repr(i as u32), ..Default::default() }) .collect(); } Some(OpenApiVariants::Str(str_variants)) => { result.variants = str_variants .iter() .copied() .map(|s| { let transformed = s.replace('&', "And").to_upper_camel_case(); EnumVariant { value: EnumVariantValue::String { rename: (transformed != s).then(|| s.to_owned()), }, name: transformed, ..Default::default() } }) .collect(); } None => return None, } Some(result) } pub fn from_parameter_schema(name: &str, schema: &OpenApiParameterSchema) -> Option { let mut result = Self { name: name.to_owned(), copy: true, ..Default::default() }; for var in schema.r#enum.as_ref()? { let transformed = var.to_upper_camel_case(); result.variants.push(EnumVariant { value: EnumVariantValue::String { rename: (transformed != *var).then(|| transformed.clone()), }, name: transformed, ..Default::default() }); } Some(result) } pub fn from_one_of(name: &str, schemas: &[OpenApiType]) -> Option { let mut result = Self { name: name.to_owned(), untagged: true, ..Default::default() }; for schema in schemas { let value = EnumVariantTupleValue::from_schema(name, schema)?; let name = value.name(); result.variants.push(EnumVariant { name, value: EnumVariantValue::Tuple(vec![value]), ..Default::default() }); } // HACK: idk let shared: Vec<_> = result .variants .iter_mut() .filter(|v| v.name == "Variant") .collect(); if shared.len() >= 2 { for (idx, variant) in shared.into_iter().enumerate() { let label = idx + 1; variant.name = format!("Variant{}", label); if let EnumVariantValue::Tuple(values) = &mut variant.value { if let [EnumVariantTupleValue::Enum { name, inner, .. }] = values.as_mut_slice() { inner.name.push_str(&label.to_string()); name.push_str(&label.to_string()); } } } } Some(result) } pub fn is_display(&self, resolved: &ResolvedSchema) -> bool { self.variants.iter().all(|v| v.value.is_display(resolved)) } pub fn is_comparable(&self, resolved: &ResolvedSchema) -> bool { self.variants .iter() .all(|v| v.value.is_comparable(resolved)) } pub fn codegen(&self, resolved: &ResolvedSchema) -> Option { let repr = self.repr.map(|r| match r { EnumRepr::U8 => quote! { #[repr(u8)] }, EnumRepr::U32 => quote! { #[repr(u32)] }, }); let name = format_ident!("{}", self.name); let desc = self.description.as_ref().map(|d| { quote! { #repr #[doc = #d] } }); let mut ns = EnumNamespace { r#enum: self, ident: None, elements: Default::default(), top_level_elements: Default::default(), }; let is_display = self.is_display(resolved); let mut display = Vec::with_capacity(self.variants.len()); let mut variants = Vec::with_capacity(self.variants.len()); for variant in &self.variants { variants.push(variant.codegen(&mut ns, resolved)?); if is_display { display.push(variant.codegen_display()?); } } let mut derives = vec![]; if self.repr.is_some() { derives.push(quote! { serde_repr::Deserialize_repr }); } else { derives.push(quote! { serde::Deserialize }); } if self.copy { derives.push(quote! { Copy }); } if self.is_comparable(resolved) { derives.push(quote! { Eq, Hash }); } let serde_attr = self.untagged.then(|| { quote! { #[serde(untagged)] } }); let display = is_display.then(|| { quote! { impl std::fmt::Display for #name { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { match self { #(#display),* } } } } }); let module = ns.codegen(); Some(quote! { #desc #[derive(Debug, Clone, PartialEq, #(#derives),*)] #[cfg_attr(feature = "strum", derive(strum::EnumIs, strum::EnumTryAs))] #serde_attr pub enum #name { #(#variants),* } #display #module }) } } #[cfg(test)] mod test { use super::*; use crate::openapi::schema::test::get_schema; #[test] fn is_display() { let schema = get_schema(); let resolved = ResolvedSchema::from_open_api(&schema); let torn_selection_name = resolved.models.get("TornSelectionName").unwrap(); assert!(torn_selection_name.is_display(&resolved)); } }