feat(codegen): derive Eq and Hash for most enum types

This commit is contained in:
pyrite 2025-05-27 19:56:03 +02:00
parent 1af37bea89
commit 3ad92fb8c8
Signed by: pyrite
GPG key ID: 7F1BA9170CD35D15
4 changed files with 583 additions and 277 deletions

View file

@ -8,7 +8,7 @@ use crate::openapi::{
r#type::{OpenApiType, OpenApiVariants},
};
use super::{object::PrimitiveType, ResolvedSchema};
use super::{object::PrimitiveType, Model, ResolvedSchema};
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EnumRepr {
@ -143,6 +143,18 @@ impl EnumVariantTupleValue {
},
}
}
pub fn is_comparable(&self, resolved: &ResolvedSchema) -> bool {
match self {
Self::Primitive(_) => true,
Self::Enum { inner, .. } => inner.is_comparable(resolved),
Self::Ref { ty_name } | Self::ArrayOfRefs { ty_name } => resolved
.models
.get(ty_name)
.map(|m| matches!(m, Model::Newtype(_)))
.unwrap_or_default(),
}
}
}
#[derive(Debug, Clone, PartialEq, Eq)]
@ -173,6 +185,13 @@ impl EnumVariantValue {
}
}
pub fn is_comparable(&self, resolved: &ResolvedSchema) -> bool {
match self {
Self::Repr(_) | Self::String { .. } => true,
Self::Tuple(values) => values.iter().all(|v| v.is_comparable(resolved)),
}
}
pub fn codegen_display(&self, name: &str) -> Option<TokenStream> {
let variant = format_ident!("{name}");
@ -417,6 +436,12 @@ impl Enum {
self.variants.iter().all(|v| v.value.is_display(resolved))
}
pub fn is_comparable(&self, resolved: &ResolvedSchema) -> bool {
self.variants
.iter()
.all(|v| v.value.is_comparable(resolved))
}
pub fn codegen(&self, resolved: &ResolvedSchema) -> Option<TokenStream> {
let repr = self.repr.map(|r| match r {
EnumRepr::U8 => quote! { #[repr(u8)]},
@ -458,7 +483,11 @@ impl Enum {
}
if self.copy {
derives.push(quote! { Copy, Hash });
derives.push(quote! { Copy });
}
if self.is_comparable(resolved) {
derives.push(quote! { Eq, Hash });
}
let serde_attr = self.untagged.then(|| {

View file

@ -315,7 +315,7 @@ The default value [Self::{}](self::{}#variant.{})"#,
let inner_ty = items.codegen_type_name(&inner_name);
code.extend(quote! {
#[derive(Debug, Clone)]
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct #name(pub Vec<#inner_ty>);
impl std::fmt::Display for #name {