From 4a73b7bb4b2a3e4209f8210c64521a96f0f0b0c1 Mon Sep 17 00:00:00 2001 From: Nicolas Sarlin Date: Fri, 19 Sep 2025 11:48:43 +0200 Subject: [PATCH] fix(versionable): use full type path in proc macro This avoids name clashes if user re-defines the type --- .../tfhe-versionable-derive/src/associated.rs | 13 ++++++++---- utils/tfhe-versionable-derive/src/lib.rs | 20 +++++++++++++------ 2 files changed, 23 insertions(+), 10 deletions(-) diff --git a/utils/tfhe-versionable-derive/src/associated.rs b/utils/tfhe-versionable-derive/src/associated.rs index f31d95f5b..9de062b19 100644 --- a/utils/tfhe-versionable-derive/src/associated.rs +++ b/utils/tfhe-versionable-derive/src/associated.rs @@ -8,7 +8,7 @@ use syn::{ use crate::{ add_lifetime_param, add_trait_where_clause, add_where_lifetime_bound_to_generics, extend_where_clause, filter_unsized_bounds, parse_const_str, DESERIALIZE_TRAIT_NAME, - LIFETIME_NAME, SERIALIZE_TRAIT_NAME, + FROM_TRAIT_NAME, LIFETIME_NAME, RESULT_TYPE_NAME, SERIALIZE_TRAIT_NAME, TRY_FROM_TRAIT_NAME, }; /// Generates an impl block for the From trait. This will be: @@ -28,9 +28,11 @@ pub(crate) fn generate_from_trait_impl( from_variable_name: &str, ) -> syn::Result { let from_variable = Ident::new(from_variable_name, Span::call_site()); + let from_trait: Path = parse_const_str(FROM_TRAIT_NAME); + Ok(parse_quote! { #[automatically_derived] - impl #impl_generics From<#src> for #dest #where_clause { + impl #impl_generics #from_trait<#src> for #dest #where_clause { fn from(#from_variable: #src) -> Self { #constructor } @@ -57,11 +59,14 @@ pub(crate) fn generate_try_from_trait_impl( from_variable_name: &str, ) -> syn::Result { let from_variable = Ident::new(from_variable_name, Span::call_site()); + let result_type: Path = parse_const_str(RESULT_TYPE_NAME); + let try_from_trait: Path = parse_const_str(TRY_FROM_TRAIT_NAME); + Ok(parse_quote! { #[automatically_derived] - impl #impl_generics TryFrom<#src> for #dest #where_clause { + impl #impl_generics #try_from_trait<#src> for #dest #where_clause { type Error = #error; - fn try_from(#from_variable: #src) -> Result { + fn try_from(#from_variable: #src) -> #result_type { #constructor } } diff --git a/utils/tfhe-versionable-derive/src/lib.rs b/utils/tfhe-versionable-derive/src/lib.rs index 1c585ecca..b608c1d96 100644 --- a/utils/tfhe-versionable-derive/src/lib.rs +++ b/utils/tfhe-versionable-derive/src/lib.rs @@ -46,6 +46,7 @@ pub(crate) const UNVERSIONIZE_ERROR_NAME: &str = crate_full_path!("UnversionizeE pub(crate) const SERIALIZE_TRAIT_NAME: &str = "::serde::Serialize"; pub(crate) const DESERIALIZE_TRAIT_NAME: &str = "::serde::Deserialize"; pub(crate) const DESERIALIZE_OWNED_TRAIT_NAME: &str = "::serde::de::DeserializeOwned"; +pub(crate) const TRY_FROM_TRAIT_NAME: &str = "::core::convert::TryFrom"; pub(crate) const FROM_TRAIT_NAME: &str = "::core::convert::From"; pub(crate) const TRY_INTO_TRAIT_NAME: &str = "::core::convert::TryInto"; pub(crate) const INTO_TRAIT_NAME: &str = "::core::convert::Into"; @@ -53,6 +54,8 @@ pub(crate) const ERROR_TRAIT_NAME: &str = "::core::error::Error"; pub(crate) const SYNC_TRAIT_NAME: &str = "::core::marker::Sync"; pub(crate) const SEND_TRAIT_NAME: &str = "::core::marker::Send"; pub(crate) const DEFAULT_TRAIT_NAME: &str = "::core::default::Default"; +pub(crate) const RESULT_TYPE_NAME: &str = "::core::result::Result"; +pub(crate) const VEC_TYPE_NAME: &str = "::std::vec::Vec"; pub(crate) const STATIC_LIFETIME_NAME: &str = "'static"; use associated::AssociatingTrait; @@ -240,6 +243,9 @@ pub fn derive_versionize(input: TokenStream) -> TokenStream { let unversionize_body = implementor.unversionize_method_body(&unversionize_arg_name); let unversionize_error: Path = parse_const_str(UNVERSIONIZE_ERROR_NAME); + let result_type: Path = parse_const_str(RESULT_TYPE_NAME); + let vec_type: Path = parse_const_str(VEC_TYPE_NAME); + quote! { #version_trait_impl @@ -269,7 +275,7 @@ pub fn derive_versionize(input: TokenStream) -> TokenStream { impl #trait_impl_generics #unversionize_trait for #input_ident #ty_generics #unversionize_trait_where_clause { - fn unversionize(#unversionize_arg_name: Self::VersionedOwned) -> Result { + fn unversionize(#unversionize_arg_name: Self::VersionedOwned) -> #result_type { #unversionize_body } } @@ -278,7 +284,7 @@ pub fn derive_versionize(input: TokenStream) -> TokenStream { impl #trait_impl_generics #versionize_slice_trait for #input_ident #ty_generics #versionize_trait_where_clause { - type VersionedSlice<#lifetime> = Vec<::Versioned<#lifetime>> #versioned_type_where_clause; + type VersionedSlice<#lifetime> = #vec_type<::Versioned<#lifetime>> #versioned_type_where_clause; fn versionize_slice(slice: &[Self]) -> Self::VersionedSlice<'_> { slice.iter().map(|val| #versionize_trait::versionize(val)).collect() @@ -290,9 +296,9 @@ pub fn derive_versionize(input: TokenStream) -> TokenStream { #versionize_owned_trait_where_clause { - type VersionedVec = Vec<::VersionedOwned> #versioned_owned_type_where_clause; + type VersionedVec = #vec_type<::VersionedOwned> #versioned_owned_type_where_clause; - fn versionize_vec(vec: Vec) -> Self::VersionedVec { + fn versionize_vec(vec: #vec_type) -> Self::VersionedVec { vec.into_iter().map(|val| #versionize_owned_trait::versionize_owned(val)).collect() } } @@ -301,7 +307,7 @@ pub fn derive_versionize(input: TokenStream) -> TokenStream { impl #trait_impl_generics #unversionize_vec_trait for #input_ident #ty_generics #unversionize_trait_where_clause { - fn unversionize_vec(versioned: Self::VersionedVec) -> Result, #unversionize_error> { + fn unversionize_vec(versioned: Self::VersionedVec) -> #result_type<#vec_type, #unversionize_error> { versioned .into_iter() .map(|versioned| ::unversionize(versioned)) @@ -346,6 +352,8 @@ pub fn derive_not_versioned(input: TokenStream) -> TokenStream { let unversionize_error: Path = parse_const_str(UNVERSIONIZE_ERROR_NAME); let lifetime = Lifetime::new(LIFETIME_NAME, Span::call_site()); + let result_type: Path = parse_const_str(RESULT_TYPE_NAME); + quote! { #[automatically_derived] impl #impl_generics #versionize_trait for #input_ident #ty_generics #versionize_where_clause { @@ -367,7 +375,7 @@ pub fn derive_not_versioned(input: TokenStream) -> TokenStream { #[automatically_derived] impl #impl_generics #unversionize_trait for #input_ident #ty_generics #versionize_owned_where_clause { - fn unversionize(versioned: Self::VersionedOwned) -> Result { + fn unversionize(versioned: Self::VersionedOwned) -> #result_type { Ok(versioned) } }