From 639dbd221c272352a100e3084e48da0e310bd5cf Mon Sep 17 00:00:00 2001 From: andrepd Date: Mon, 17 Mar 2025 15:45:46 +0000 Subject: [PATCH] speedy-derive: support derive for structs with const generics --- speedy-derive/src/lib.rs | 20 ++++++++++++-------- tests/serialization_tests.rs | 28 ++++++++++++++++++++++++++++ 2 files changed, 40 insertions(+), 8 deletions(-) diff --git a/speedy-derive/src/lib.rs b/speedy-derive/src/lib.rs index 6dd2f61..e7fac16 100644 --- a/speedy-derive/src/lib.rs +++ b/speedy-derive/src/lib.rs @@ -16,7 +16,7 @@ use proc_macro2::{Span, TokenStream}; use quote::ToTokens; use syn::punctuated::Punctuated; use syn::spanned::Spanned; -use syn::TypeParam; +use syn::{ConstParam, GenericParam, TypeParam}; trait IterExt: Iterator + Sized { fn collect_vec( self ) -> Vec< Self::Item > { @@ -285,19 +285,23 @@ fn parse_primitive_ty( ty: &syn::Type ) -> Option< PrimitiveTy > { fn common_tokens( ast: &syn::DeriveInput, types: &[syn::Type], trait_variant: Trait ) -> (TokenStream, TokenStream, TokenStream) { let impl_params = { - let lifetime_params = ast.generics.lifetimes().map( |alpha| quote! { #alpha } ); - let type_params = ast.generics.type_params().map( |ty| { let ty_without_default = TypeParam { default: None, ..ty.clone() }; quote! { #ty_without_default } }); - let params = lifetime_params.chain( type_params ).collect_vec(); + let params = ast.generics.params.iter().map(|x| match x { + GenericParam::Lifetime(a) => quote! { #a }, + GenericParam::Type(t) => { let t_without_default = TypeParam { default: None, ..t.clone() }; quote! { #t_without_default } }, + GenericParam::Const(c) => { let c_without_default = ConstParam { default: None, ..c.clone() }; quote! { #c_without_default } }, + }); quote! { #(#params,)* } }; let ty_params = { - let lifetime_params = ast.generics.lifetimes().map( |alpha| quote! { #alpha } ); - let type_params = ast.generics.type_params().map( |ty| { let ident = &ty.ident; quote! { #ident } } ); - let params = lifetime_params.chain( type_params ).collect_vec(); - if params.is_empty() { + let params = ast.generics.params.iter().map(|x| match x { + GenericParam::Lifetime(a) => quote! { #a }, + GenericParam::Type(t) => { let ident = t.ident.clone(); quote! { #ident } }, + GenericParam::Const(c) => { let ident = c.ident.clone(); quote! { #ident } }, + }); + if ast.generics.params.is_empty() { quote! {} } else { quote! { < #(#params),* > } diff --git a/tests/serialization_tests.rs b/tests/serialization_tests.rs index 720fad4..78b070e 100644 --- a/tests/serialization_tests.rs +++ b/tests/serialization_tests.rs @@ -516,6 +516,16 @@ struct DerivedStructWithGenericDefault< T = u8 > { inner: T } +#[derive(PartialEq, Debug, Readable, Writable)] +struct DerivedStructWithGenericAndConst< const N: usize, T > { + inner: [T; N] +} + +#[derive(PartialEq, Debug, Readable, Writable)] +struct DerivedStructWithGenericDefaultAndConst< const N: usize, T = u8 > { + inner: [T; N] +} + #[derive(PartialEq, Debug, Readable, Writable)] struct DerivedStructWithDefaultOnEof { a: u8, @@ -1674,6 +1684,24 @@ symmetric_tests! { be = [0, 1], minimum_bytes = 2 } + derived_struct_with_generic_and_const for DerivedStructWithGenericAndConst<4, i32> { + in = DerivedStructWithGenericAndConst { inner: [1_i32, 2_i32, 3_i32, 4_i32] }, + le = [1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0, 4, 0, 0, 0], + be = [0, 0, 0, 1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0, 4], + minimum_bytes = 16 + } + derived_struct_with_generic_default_and_const for DerivedStructWithGenericDefaultAndConst<4> { + in = DerivedStructWithGenericDefaultAndConst { inner: [0x0a_u8, 0x0b_u8, 0x0c_u8, 0x0d_u8] }, + le = [0x0a, 0x0b, 0x0c, 0x0d], + be = [0x0a, 0x0b, 0x0c, 0x0d], + minimum_bytes = 4 + } + derived_struct_with_generic_default_and_const_provided for DerivedStructWithGenericDefaultAndConst<2, u16> { + in = DerivedStructWithGenericDefaultAndConst { inner: [0x0a0b_u16, 0x0c0d_u16] }, + le = [0x0b, 0x0a, 0x0d, 0x0c], + be = [0x0a, 0x0b, 0x0c, 0x0d], + minimum_bytes = 4 + } derived_recursive_struct_empty for DerivedRecursiveStruct { in = DerivedRecursiveStruct { inner: Vec::new() }, le = [0, 0, 0, 0],