From 5f9ac48dbea70ce30dcc9453c0ab2acf1ee39e65 Mon Sep 17 00:00:00 2001 From: Nicolas Sarlin Date: Tue, 1 Apr 2025 10:02:22 +0200 Subject: [PATCH] feat(versionable): add `skip` attribute to skip field versioning --- .../tfhe-versionable-derive/src/associated.rs | 2 +- utils/tfhe-versionable-derive/src/lib.rs | 3 +- .../src/version_type.rs | 386 +++++++++++++----- .../src/versionize_attribute.rs | 57 ++- utils/tfhe-versionable/Cargo.toml | 4 + utils/tfhe-versionable/examples/skip.rs | 96 +++++ .../tfhe-versionable/examples/transparent.rs | 2 +- utils/tfhe-versionable/src/lib.rs | 13 + utils/tfhe-versionable/tests/skip_enum.rs | 88 ++++ 9 files changed, 533 insertions(+), 118 deletions(-) create mode 100644 utils/tfhe-versionable/examples/skip.rs create mode 100644 utils/tfhe-versionable/tests/skip_enum.rs diff --git a/utils/tfhe-versionable-derive/src/associated.rs b/utils/tfhe-versionable-derive/src/associated.rs index a07842e00..f31d95f5b 100644 --- a/utils/tfhe-versionable-derive/src/associated.rs +++ b/utils/tfhe-versionable-derive/src/associated.rs @@ -68,7 +68,7 @@ pub(crate) fn generate_try_from_trait_impl( }) } -/// The ownership kind of the data for a associated type. +/// The ownership kind of the data for an associated type. #[derive(Clone)] pub(crate) enum AssociatedTypeKind { /// This version type use references to non-Copy rust underlying built-in types. diff --git a/utils/tfhe-versionable-derive/src/lib.rs b/utils/tfhe-versionable-derive/src/lib.rs index e3dfbd92a..1c585ecca 100644 --- a/utils/tfhe-versionable-derive/src/lib.rs +++ b/utils/tfhe-versionable-derive/src/lib.rs @@ -52,6 +52,7 @@ pub(crate) const INTO_TRAIT_NAME: &str = "::core::convert::Into"; pub(crate) const ERROR_TRAIT_NAME: &str = "::core::error::Error"; pub(crate) const SYNC_TRAIT_NAME: &str = "::core::marker::Sync"; pub(crate) const SEND_TRAIT_NAME: &str = "::core::marker::Send"; +pub(crate) const DEFAULT_TRAIT_NAME: &str = "::core::default::Default"; pub(crate) const STATIC_LIFETIME_NAME: &str = "'static"; use associated::AssociatingTrait; @@ -71,7 +72,7 @@ macro_rules! syn_unwrap { }; } -#[proc_macro_derive(Version)] +#[proc_macro_derive(Version, attributes(versionize))] /// Implement the `Version` trait for the target type. pub fn derive_version(input: TokenStream) -> TokenStream { let input = parse_macro_input!(input as DeriveInput); diff --git a/utils/tfhe-versionable-derive/src/version_type.rs b/utils/tfhe-versionable-derive/src/version_type.rs index dc4d745b9..42104a082 100644 --- a/utils/tfhe-versionable-derive/src/version_type.rs +++ b/utils/tfhe-versionable-derive/src/version_type.rs @@ -2,9 +2,7 @@ use std::iter::zip; use proc_macro2::{Literal, Span, TokenStream}; use quote::{format_ident, quote}; -use syn::punctuated::Punctuated; use syn::spanned::Spanned; -use syn::token::Comma; use syn::{ parse_quote, Data, DataEnum, DataStruct, DataUnion, DeriveInput, Field, Fields, FieldsNamed, FieldsUnnamed, Generics, Ident, Item, ItemEnum, ItemImpl, ItemStruct, ItemUnion, Lifetime, @@ -15,12 +13,12 @@ use crate::associated::{ generate_from_trait_impl, generate_try_from_trait_impl, AssociatedType, AssociatedTypeKind, ConversionDirection, }; -use crate::versionize_attribute::is_transparent; +use crate::versionize_attribute::{is_skipped, is_transparent, replace_versionize_skip_with_serde}; use crate::{ add_trait_where_clause, parse_const_str, parse_trait_bound, punctuated_from_iter_result, - INTO_TRAIT_NAME, LIFETIME_NAME, TRY_INTO_TRAIT_NAME, UNVERSIONIZE_ERROR_NAME, - UNVERSIONIZE_TRAIT_NAME, VERSIONIZE_OWNED_TRAIT_NAME, VERSIONIZE_TRAIT_NAME, - VERSION_TRAIT_NAME, + DEFAULT_TRAIT_NAME, INTO_TRAIT_NAME, LIFETIME_NAME, TRY_INTO_TRAIT_NAME, + UNVERSIONIZE_ERROR_NAME, UNVERSIONIZE_TRAIT_NAME, VERSIONIZE_OWNED_TRAIT_NAME, + VERSIONIZE_TRAIT_NAME, VERSION_TRAIT_NAME, }; /// The types generated for a specific version of a given exposed type. These types are identical to @@ -199,9 +197,9 @@ impl AssociatedType for VersionType { } fn inner_types(&self) -> syn::Result> { - self.orig_type_fields() - .iter() - .map(|field| Ok(&field.ty)) + self.orig_type_fields()? + .filter_map(filter_skipped_field) + .map(|field| Ok(&field?.ty)) .collect() } @@ -232,6 +230,14 @@ impl AssociatedType for VersionType { self.inner_types()?, &[UNVERSIONIZE_TRAIT_NAME], )?; + + // "skipped" types are not present in the Version types so we add a Default + // bound to be able to reconstruct them. + add_trait_where_clause( + &mut generics, + self.skipped_inner_types()?, + &[DEFAULT_TRAIT_NAME], + )?; } } } @@ -242,10 +248,18 @@ impl AssociatedType for VersionType { impl VersionType { /// Returns the fields of the original declaration. - fn orig_type_fields(&self) -> Punctuated<&Field, Comma> { + fn orig_type_fields(&self) -> syn::Result + '_>> { derive_type_fields(&self.orig_type) } + /// Returns the list of types inside the original type that are skipped + fn skipped_inner_types(&self) -> syn::Result> { + self.orig_type_fields()? + .filter_map(keep_skipped_field) + .map(|field| Ok(&field?.ty)) + .collect() + } + /// Generates the declaration for the Version equivalent of the input struct fn generate_struct(&self, stru: &DataStruct) -> syn::Result { let fields = match &stru.fields { @@ -313,14 +327,24 @@ impl VersionType { /// Converts an enum variant into its "Version" form fn convert_enum_variant(&self, variant: &Variant) -> syn::Result { - let fields = match &variant.fields { - Fields::Named(fields) => Fields::Named(self.convert_fields_named(fields)?), - Fields::Unnamed(fields) => Fields::Unnamed(self.convert_fields_unnamed(fields)?), - Fields::Unit => Fields::Unit, + let is_skipped = is_skipped(&variant.attrs)?; + let fields = if is_skipped { + // If the whole variant is skipped convert the variant to a unit. That way it still + // compiles but the user gets an error at the serialization step + Fields::Unit + } else { + match &variant.fields { + Fields::Named(fields) => Fields::Named(self.convert_fields_named(fields)?), + Fields::Unnamed(fields) => Fields::Unnamed(self.convert_fields_unnamed(fields)?), + Fields::Unit => Fields::Unit, + } }; + // Copy the attributes from the initial variant and remove the ones that were meant for us + let attrs = replace_versionize_skip_with_serde(&variant.attrs)?; + let versioned_variant = Variant { - attrs: Vec::new(), + attrs, ident: variant.ident.clone(), fields, discriminant: variant.discriminant.clone(), @@ -353,45 +377,49 @@ impl VersionType { let kind = self.kind.clone(); let is_transparent = self.is_transparent; - fields_iter.into_iter().map(move |field| { - let unver_ty = field.ty.clone(); + fields_iter + .into_iter() + .filter_map(filter_skipped_field) + .map(move |field| { + let field = field?; + let unver_ty = field.ty.clone(); - if is_transparent { - // If the type is transparent, we reuse the "Version" impl of the inner type - let version_trait = parse_trait_bound(VERSION_TRAIT_NAME)?; + if is_transparent { + // If the type is transparent, we reuse the "Version" impl of the inner type + let version_trait = parse_trait_bound(VERSION_TRAIT_NAME)?; - let ty: Type = match &kind { - AssociatedTypeKind::Ref(lifetime) => parse_quote! { - <#unver_ty as #version_trait>::Ref<#lifetime> - }, - AssociatedTypeKind::Owned => parse_quote! { - <#unver_ty as #version_trait>::Owned - }, - }; + let ty: Type = match &kind { + AssociatedTypeKind::Ref(lifetime) => parse_quote! { + <#unver_ty as #version_trait>::Ref<#lifetime> + }, + AssociatedTypeKind::Owned => parse_quote! { + <#unver_ty as #version_trait>::Owned + }, + }; - Ok(Field { - ty, - ..field.clone() - }) - } else { - let versionize_trait = parse_trait_bound(VERSIONIZE_TRAIT_NAME)?; - let versionize_owned_trait = parse_trait_bound(VERSIONIZE_OWNED_TRAIT_NAME)?; + Ok(Field { + ty, + ..field.clone() + }) + } else { + let versionize_trait = parse_trait_bound(VERSIONIZE_TRAIT_NAME)?; + let versionize_owned_trait = parse_trait_bound(VERSIONIZE_OWNED_TRAIT_NAME)?; - let ty: Type = match &kind { - AssociatedTypeKind::Ref(lifetime) => parse_quote! { - <#unver_ty as #versionize_trait>::Versioned<#lifetime> - }, - AssociatedTypeKind::Owned => parse_quote! { - <#unver_ty as #versionize_owned_trait>::VersionedOwned - }, - }; + let ty: Type = match &kind { + AssociatedTypeKind::Ref(lifetime) => parse_quote! { + <#unver_ty as #versionize_trait>::Versioned<#lifetime> + }, + AssociatedTypeKind::Owned => parse_quote! { + <#unver_ty as #versionize_owned_trait>::VersionedOwned + }, + }; - Ok(Field { - ty, - ..field.clone() - }) - } - }) + Ok(Field { + ty, + ..field.clone() + }) + } + }) } /// Generates the constructor part of the conversion impl block. This will create the dest type @@ -495,7 +523,10 @@ impl VersionType { variant: &Variant, direction: ConversionDirection, ) -> syn::Result { - let (param, fields) = match &variant.fields { + let is_skipped = is_skipped(&variant.attrs)?; + let variant_ident = &variant.ident; + + Ok(match &variant.fields { Fields::Named(fields) => { let args_iter = fields .named @@ -504,35 +535,47 @@ impl VersionType { .map(|field| field.ident.as_ref().unwrap()); let args = args_iter.clone(); - ( - quote! { { - #(#args),* - }}, - self.generate_constructor_enum_variants_named( + if is_skipped { + self.generate_constructor_skipped_enum_variants( + src_type, + variant_ident, + direction, + ) + } else { + let constructor = self.generate_constructor_enum_variants_named( args_iter.cloned(), fields.named.iter(), direction, - )?, - ) + )?; + quote! { + #src_type::#variant_ident {#(#args),*} => + Self::#variant_ident #constructor + } + } } Fields::Unnamed(fields) => { let args_iter = generate_args_list(fields.unnamed.len()); let args = args_iter.clone(); - ( - quote! { (#(#args),*) }, - self.generate_constructor_enum_variants_unnamed( + + if is_skipped { + self.generate_constructor_skipped_enum_variants( + src_type, + variant_ident, + direction, + ) + } else { + let constructor = self.generate_constructor_enum_variants_unnamed( args_iter, fields.unnamed.iter(), direction, - )?, - ) + )?; + quote! { + #src_type::#variant_ident (#(#args),*) => + Self::#variant_ident #constructor + } + } } - Fields::Unit => (TokenStream::new(), TokenStream::new()), - }; - let variant_ident = &variant.ident; - - Ok(quote! { - #src_type::#variant_ident #param => Self::#variant_ident #fields + Fields::Unit => quote! { #src_type::#variant_ident => Self::#variant_ident }, }) } @@ -545,7 +588,10 @@ impl VersionType { ) -> syn::Result { let fields: syn::Result> = fields .into_iter() - .map(move |field| self.generate_constructor_field_named(arg_name, field, direction)) + .filter_map(move |field| { + self.generate_constructor_field_named(arg_name, field, direction) + .transpose() + }) .collect(); let fields = fields?; @@ -562,7 +608,7 @@ impl VersionType { arg_name: &str, field: &Field, direction: ConversionDirection, - ) -> syn::Result { + ) -> syn::Result> { let arg_ident = Ident::new(arg_name, Span::call_site()); // Ok to unwrap because the field is named so field.ident is Some let field_ident = field.ident.as_ref().unwrap(); @@ -570,14 +616,23 @@ impl VersionType { let param = quote! { #arg_ident.#field_ident }; let rhs = if self.is_transparent() { - self.generate_constructor_transparent_rhs(param, direction)? + self.generate_constructor_transparent_rhs(param, direction) + .map(Some) } else { - self.generate_constructor_field_rhs(ty, param, false, direction)? - }; + self.generate_constructor_field_rhs( + ty, + param, + false, + is_skipped(&field.attrs)?, + direction, + ) + }?; - Ok(quote! { - #field_ident: #rhs - }) + Ok(rhs.map(|rhs| { + quote! { + #field_ident: #rhs + } + })) } /// Generates the constructor for the fields of a named enum variant. @@ -592,22 +647,32 @@ impl VersionType { direction: ConversionDirection, ) -> syn::Result { let fields: syn::Result> = zip(arg_names, fields) - .map(move |(arg_name, field)| { + .filter_map(move |(arg_name, field)| { // Ok to unwrap because the field is named so field.ident is Some let field_ident = field.ident.as_ref().unwrap(); + let rhs = if self.is_transparent() { - self.generate_constructor_transparent_rhs(quote! {#arg_name}, direction)? + Some(self.generate_constructor_transparent_rhs(quote! {#arg_name}, direction)) } else { + let skipped = match is_skipped(&field.attrs) { + Ok(skipped) => skipped, + Err(e) => return Some(Err(e)), + }; self.generate_constructor_field_rhs( &field.ty, quote! {#arg_name}, true, + skipped, direction, - )? - }; - Ok(quote! { - #field_ident: #rhs - }) + ) + .transpose() + }?; + + Some(rhs.map(|rhs| { + quote! { + #field_ident: #rhs + } + })) }) .collect(); let fields = fields?; @@ -629,8 +694,9 @@ impl VersionType { let fields: syn::Result> = fields .into_iter() .enumerate() - .map(move |(idx, field)| { + .filter_map(move |(idx, field)| { self.generate_constructor_field_unnamed(arg_name, field, idx, direction) + .transpose() }) .collect(); let fields = fields?; @@ -647,7 +713,7 @@ impl VersionType { field: &Field, idx: usize, direction: ConversionDirection, - ) -> syn::Result { + ) -> syn::Result> { let arg_ident = Ident::new(arg_name, Span::call_site()); let idx = Literal::usize_unsuffixed(idx); let ty = &field.ty; @@ -655,8 +721,15 @@ impl VersionType { if self.is_transparent { self.generate_constructor_transparent_rhs(param, direction) + .map(Some) } else { - self.generate_constructor_field_rhs(ty, param, false, direction) + self.generate_constructor_field_rhs( + ty, + param, + false, + is_skipped(&field.attrs)?, + direction, + ) } } @@ -672,16 +745,22 @@ impl VersionType { direction: ConversionDirection, ) -> syn::Result { let fields: syn::Result> = zip(arg_names, fields) - .map(move |(arg_name, field)| { - if self.is_transparent { - self.generate_constructor_transparent_rhs(quote! {#arg_name}, direction) + .filter_map(move |(arg_name, field)| { + if self.is_transparent() { + Some(self.generate_constructor_transparent_rhs(quote! {#arg_name}, direction)) } else { + let skipped = match is_skipped(&field.attrs) { + Ok(skipped) => skipped, + Err(e) => return Some(Err(e)), + }; self.generate_constructor_field_rhs( &field.ty, quote! {#arg_name}, true, + skipped, direction, ) + .transpose() } }) .collect(); @@ -692,6 +771,35 @@ impl VersionType { }) } + /// Generates the constructor for a variant of an enum with the `skip` attribute. + /// + /// This constructor is never supposed to be called, but we need to handle it anyways. + /// + /// During a call to "versionize", the conversion will simply create a unit variant that will + /// trigger an error at the "serialize" step. During a call to "unversionize", the conversion + /// will raise an error. + fn generate_constructor_skipped_enum_variants( + &self, + src_type: &Ident, + variant_ident: &Ident, + direction: ConversionDirection, + ) -> TokenStream { + match direction { + ConversionDirection::OrigToAssociated => quote! { + #src_type::#variant_ident { .. } => + Self::#variant_ident + }, + ConversionDirection::AssociatedToOrig => { + let error: Path = parse_const_str(UNVERSIONIZE_ERROR_NAME); + let variant_name = format!("{}::{}", self.orig_type.ident, variant_ident); + + quote! { + #src_type::#variant_ident => return Err(#error::skipped_variant(#variant_name)) + } + } + } + } + /// Generates the rhs part of a field constructor. /// For example, in `Self { count: value.count.versionize() }`, this is /// `value.count.versionize()`. @@ -699,15 +807,21 @@ impl VersionType { &self, ty: &Type, field_param: TokenStream, - is_ref: bool, // True if the param is already a reference + is_ref: bool, // True if the param is already a reference + is_skipped: bool, // True if the field has the `skipped` attribute direction: ConversionDirection, - ) -> syn::Result { + ) -> syn::Result> { let versionize_trait: Path = parse_const_str(VERSIONIZE_TRAIT_NAME); let versionize_owned_trait: Path = parse_const_str(VERSIONIZE_OWNED_TRAIT_NAME); let unversionize_trait: Path = parse_const_str(UNVERSIONIZE_TRAIT_NAME); + let default_trait: Path = parse_const_str(DEFAULT_TRAIT_NAME); let field_constructor = match direction { ConversionDirection::OrigToAssociated => { + if is_skipped { + // Skipped fields does not exist in the associated type so we return None + return Ok(None); + } match self.kind { AssociatedTypeKind::Ref(_) => { let param = if is_ref { @@ -727,12 +841,21 @@ impl VersionType { ConversionDirection::AssociatedToOrig => match self.kind { AssociatedTypeKind::Ref(_) => panic!("No conversion should be generated between associated ref type to original type"), - AssociatedTypeKind::Owned => quote! { - <#ty as #unversionize_trait>::unversionize(#field_param)? + AssociatedTypeKind::Owned => { + if is_skipped { + // If the field is skipped, we try to construct it from a Default impl (this is what serde does) + quote! { + <#ty as #default_trait>::default() + } + } else { + quote! { + <#ty as #unversionize_trait>::unversionize(#field_param)? + } + } }, }, }; - Ok(field_constructor) + Ok(Some(field_constructor)) } fn generate_constructor_transparent_rhs( @@ -788,23 +911,64 @@ fn is_unit(input: &DeriveInput) -> bool { /// Returns the fields of the input type. This is independent of the kind of type /// (enum, struct, ...) -fn derive_type_fields(input: &DeriveInput) -> Punctuated<&Field, Comma> { - match &input.data { - Data::Struct(stru) => match &stru.fields { - Fields::Named(fields) => Punctuated::from_iter(fields.named.iter()), - Fields::Unnamed(fields) => Punctuated::from_iter(fields.unnamed.iter()), - Fields::Unit => Punctuated::new(), - }, - Data::Enum(enu) => Punctuated::<&Field, Comma>::from_iter( - enu.variants +/// +/// In the case of an enum, the fields for each variants that are not skipped are flattened into a +/// single iterator. +fn derive_type_fields(input: &DeriveInput) -> syn::Result + '_>> { + Ok(match &input.data { + Data::Struct(stru) => Box::new(iter_fields(&stru.fields)), + Data::Enum(enu) => { + let filtered: Result, syn::Error> = enu + .variants .iter() - .filter_map(|variant| match &variant.fields { - Fields::Named(fields) => Some(fields.named.iter()), - Fields::Unnamed(fields) => Some(fields.unnamed.iter()), - Fields::Unit => None, - }) - .flatten(), - ), - Data::Union(uni) => Punctuated::from_iter(uni.fields.named.iter()), + .filter_map(filter_skipped_variant) + .collect(); + + Box::new( + filtered? + .into_iter() + .flat_map(|variant| iter_fields(&variant.fields)), + ) + } + Data::Union(uni) => Box::new(uni.fields.named.iter()), + }) +} + +/// Returns an iterator over the `Field`s in a `Fields` regardless of the fields type (named, +/// unnamed or unit). +fn iter_fields(fields: &Fields) -> Box + '_> { + match fields { + Fields::Named(fields) => Box::new(fields.named.iter()), + Fields::Unnamed(fields) => Box::new(fields.unnamed.iter()), + Fields::Unit => Box::new(std::iter::empty()), + } +} + +/// Can be used inside a field iterator to remove the fields with a `#[versionize(skip)]` attribute +fn filter_skipped_field(field: &Field) -> Option> { + match is_skipped(&field.attrs) { + Ok(true) => None, + Ok(false) => Some(Ok(field)), + Err(e) => Some(Err(e)), + } +} + +/// Can be used inside a field iterator to only keep the fields with a `#[versionize(skip)]` +/// attribute +fn keep_skipped_field(field: &Field) -> Option> { + match is_skipped(&field.attrs) { + Ok(true) => Some(Ok(field)), + Ok(false) => None, + Err(e) => Some(Err(e)), + } +} + +/// Can be used inside a variant iterator to remove the variants with a `#[versionize(skip)]` +/// attribute +fn filter_skipped_variant(variant: &Variant) -> Option> { + match is_skipped(&variant.attrs) { + Ok(true) => None, + Ok(false) => Some(Ok(variant)), + Err(e) => Some(Err(e)), } } diff --git a/utils/tfhe-versionable-derive/src/versionize_attribute.rs b/utils/tfhe-versionable-derive/src/versionize_attribute.rs index e89f72da5..e2e0d5e03 100644 --- a/utils/tfhe-versionable-derive/src/versionize_attribute.rs +++ b/utils/tfhe-versionable-derive/src/versionize_attribute.rs @@ -5,10 +5,13 @@ use proc_macro2::Span; use quote::ToTokens; use syn::punctuated::Punctuated; use syn::spanned::Spanned; -use syn::{Attribute, Expr, Lit, Meta, Path, Token}; +use syn::{parse_quote, Attribute, Expr, Lit, Meta, Path, Token}; /// Name of the attribute used to give arguments to the `Versionize` macro -const VERSIONIZE_ATTR_NAME: &str = "versionize"; +pub(crate) const VERSIONIZE_ATTR_NAME: &str = "versionize"; + +/// Name of the attribute used to give arguments to serde macros +pub(crate) const SERDE_ATTR_NAME: &str = "serde"; /// Transparent mode can also be activated using `#[repr(transparent)]` pub(crate) const REPR_ATTR_NAME: &str = "repr"; @@ -297,11 +300,12 @@ fn parse_path_ignore_quotes(value: &Expr) -> syn::Result { } } -/// Check if the target type has the `#[repr(transparent)]` attribute in its attributes list +/// Check if the target type has the `#[repr(transparent)]` or `#[serde(transparent)]` attribute in +/// its attributes list pub(crate) fn is_transparent(attributes: &[Attribute]) -> syn::Result { if let Some(attr) = attributes .iter() - .find(|attr| attr.path().is_ident(REPR_ATTR_NAME)) + .find(|attr| attr.path().is_ident(REPR_ATTR_NAME) || attr.path().is_ident(SERDE_ATTR_NAME)) { let nested = attr.parse_args_with(Punctuated::::parse_terminated)?; @@ -316,3 +320,48 @@ pub(crate) fn is_transparent(attributes: &[Attribute]) -> syn::Result { Ok(false) } + +/// Check if a field has the `#[serde(skip)]` or `#[versionize(skip)]` attribute in +/// its attributes list +pub(crate) fn is_skipped(attributes: &[Attribute]) -> syn::Result { + if let Some(attr) = attributes.iter().find(|attr| { + attr.path().is_ident(VERSIONIZE_ATTR_NAME) || attr.path().is_ident(SERDE_ATTR_NAME) + }) { + let nested = attr.parse_args_with(Punctuated::::parse_terminated)?; + + for meta in nested.iter() { + if let Meta::Path(path) = meta { + if path.is_ident("skip") { + return Ok(true); + } + } + } + } + + Ok(false) +} + +/// Replace `#[versionize(skip)]` with `#[serde(skip)]` in an attributes list +pub(crate) fn replace_versionize_skip_with_serde( + attributes: &[Attribute], +) -> syn::Result> { + attributes + .iter() + .cloned() + .map(|attr| { + if attr.path().is_ident(VERSIONIZE_ATTR_NAME) { + let nested = + attr.parse_args_with(Punctuated::::parse_terminated)?; + + for meta in nested.iter() { + if let Meta::Path(path) = meta { + if path.is_ident("skip") { + return Ok(parse_quote! { #[serde(skip)] }); + } + } + } + } + Ok(attr) + }) + .collect() +} diff --git a/utils/tfhe-versionable/Cargo.toml b/utils/tfhe-versionable/Cargo.toml index fca83857d..a877ac1cc 100644 --- a/utils/tfhe-versionable/Cargo.toml +++ b/utils/tfhe-versionable/Cargo.toml @@ -69,3 +69,7 @@ test = true [[example]] name = "associated_bounds" test = true + +[[example]] +name = "skip" +test = true diff --git a/utils/tfhe-versionable/examples/skip.rs b/utils/tfhe-versionable/examples/skip.rs new file mode 100644 index 000000000..ac6c44204 --- /dev/null +++ b/utils/tfhe-versionable/examples/skip.rs @@ -0,0 +1,96 @@ +//! Example of a struct that include a field that is not versionable and is skipped during +//! versioning. +//! +//! This is similar to the `#[serde(skip)]` attribute of Serde. +//! With this attribute, the field is not included in the definition of the associated Version type. +//! During unversioning, the field is instantiated using a `Default` impl. + +use std::convert::Infallible; +use std::io::Cursor; + +use tfhe_versionable::{Unversionize, Upgrade, Version, Versionize, VersionsDispatch}; + +// This type is not versionable/serializable and should not be stored. +// It should however at least implement Default to be instantiable when the data is loaded. +#[derive(Default)] +struct NotVersionable(u64); + +/// The previous version of our application +mod v0 { + use super::NotVersionable; + use tfhe_versionable::{Versionize, VersionsDispatch}; + + #[derive(Versionize)] + #[versionize(MyStructVersions)] + pub(super) struct MyStruct { + pub(super) val: u32, + // This attribute is used to skip the versioning of the field + // Also work with `#[serde(skip)]` if the field derives Serialize + #[versionize(skip)] + #[allow(dead_code)] + pub(super) to_skip: NotVersionable, + } + + #[derive(VersionsDispatch)] + #[allow(unused)] + pub(super) enum MyStructVersions { + V0(MyStruct), + } +} + +#[derive(Version)] +struct MyStructV0 { + val: u32, + #[versionize(skip)] + to_skip: NotVersionable, +} + +#[derive(Versionize)] +#[versionize(MyStructVersions)] +struct MyStruct { + val: u64, + #[versionize(skip)] + to_skip: NotVersionable, +} + +impl Upgrade for MyStructV0 { + type Error = Infallible; + + fn upgrade(self) -> Result { + let val = self.val as u64; + + Ok(MyStruct { + val, + to_skip: self.to_skip, + }) + } +} + +#[derive(VersionsDispatch)] +#[allow(unused)] +enum MyStructVersions { + V0(MyStructV0), + V1(MyStruct), +} + +fn main() { + let val = 64; + let stru_v0 = v0::MyStruct { + val, + to_skip: NotVersionable(42), // The value will be lost during serialization + }; + + let mut ser = Vec::new(); + ciborium::ser::into_writer(&stru_v0.versionize(), &mut ser).unwrap(); + + let unvers = + MyStruct::unversionize(ciborium::de::from_reader(&mut Cursor::new(&ser)).unwrap()).unwrap(); + + assert_eq!(unvers.val, val as u64); + assert_eq!(unvers.to_skip.0, Default::default()); +} + +#[test] +fn test() { + main() +} diff --git a/utils/tfhe-versionable/examples/transparent.rs b/utils/tfhe-versionable/examples/transparent.rs index f2869dc9c..2e88e1198 100644 --- a/utils/tfhe-versionable/examples/transparent.rs +++ b/utils/tfhe-versionable/examples/transparent.rs @@ -11,7 +11,7 @@ use tfhe_versionable::{Unversionize, Upgrade, Version, Versionize, VersionsDispa // // struct MyStructWrapper { inner: MyStruct }; #[derive(Versionize)] -#[versionize(transparent)] // Also works with `#[repr(transparent)]` +#[versionize(transparent)] // Also works with `#[repr(transparent)]` or `#[serde(transparent)]` struct MyStructWrapper(MyStruct); // The inner struct that is versioned. diff --git a/utils/tfhe-versionable/src/lib.rs b/utils/tfhe-versionable/src/lib.rs index e7ee2d3a3..26473c4de 100644 --- a/utils/tfhe-versionable/src/lib.rs +++ b/utils/tfhe-versionable/src/lib.rs @@ -94,6 +94,9 @@ pub enum UnversionizeError { /// A deprecated version has been found DeprecatedVersion(DeprecatedVersionError), + + /// User tried to unversionize an enum variant with the `#[versionize(skip)]` attribute + SkippedVariant { variant_name: String }, } impl Display for UnversionizeError { @@ -120,6 +123,9 @@ impl Display for UnversionizeError { ) } Self::DeprecatedVersion(deprecation_error) => deprecation_error.fmt(f), + Self::SkippedVariant { variant_name } => write!(f, + "Enum variant {variant_name} is marked with the `skip` attribute and cannot be unversioned" + ), } } } @@ -131,6 +137,7 @@ impl Error for UnversionizeError { UnversionizeError::Conversion { source, .. } => Some(source.as_ref()), UnversionizeError::ArrayLength { .. } => None, UnversionizeError::DeprecatedVersion(_) => None, + UnversionizeError::SkippedVariant { .. } => None, } } } @@ -154,6 +161,12 @@ impl UnversionizeError { source: Box::new(source), } } + + pub fn skipped_variant(variant_name: &str) -> Self { + Self::SkippedVariant { + variant_name: variant_name.to_string(), + } + } } impl From for UnversionizeError { diff --git a/utils/tfhe-versionable/tests/skip_enum.rs b/utils/tfhe-versionable/tests/skip_enum.rs new file mode 100644 index 000000000..d04c68cca --- /dev/null +++ b/utils/tfhe-versionable/tests/skip_enum.rs @@ -0,0 +1,88 @@ +//! Test the skip attribute in an enum. This attribute in a struct is already tested in +//! `examples/skip.rs` + +use std::convert::Infallible; +use std::io::Cursor; + +use tfhe_versionable::{Unversionize, Upgrade, Version, Versionize, VersionsDispatch}; + +#[allow(dead_code)] +struct NotVersionable(u64); + +mod v0 { + use super::NotVersionable; + use tfhe_versionable::{Versionize, VersionsDispatch}; + + #[derive(Versionize)] + #[versionize(MyEnumVersions)] + pub(super) enum MyEnum { + Var0(u32), + #[versionize(skip)] + #[allow(dead_code)] + Var1(NotVersionable), + } + + #[derive(VersionsDispatch)] + #[allow(unused)] + pub(super) enum MyEnumVersions { + V0(MyEnum), + } +} + +#[derive(Version)] +enum MyEnumV0 { + Var0(u32), + #[versionize(skip)] + #[allow(dead_code)] + Var1(NotVersionable), +} + +#[derive(Versionize)] +#[versionize(MyEnumVersions)] +enum MyEnum { + Var0(u64), + #[versionize(skip)] + #[allow(dead_code)] + Var1(NotVersionable), +} + +impl Upgrade for MyEnumV0 { + type Error = Infallible; + + fn upgrade(self) -> Result { + match self { + MyEnumV0::Var0(val) => Ok(MyEnum::Var0(val as u64)), + MyEnumV0::Var1(val) => Ok(MyEnum::Var1(val)), + } + } +} + +#[derive(VersionsDispatch)] +#[allow(unused)] +enum MyEnumVersions { + V0(MyEnumV0), + V1(MyEnum), +} + +#[test] +fn test() { + // Test the "normal" variant + let val = 64; + let enu_v0 = v0::MyEnum::Var0(val); + + let mut ser = Vec::new(); + ciborium::ser::into_writer(&enu_v0.versionize(), &mut ser).unwrap(); + + let unvers = + MyEnum::unversionize(ciborium::de::from_reader(&mut Cursor::new(&ser)).unwrap()).unwrap(); + + assert!(matches!(unvers, MyEnum::Var0(unvers_val) if unvers_val == val as u64)); + + // Test the skipped variant + let val = 64; + let enu_v0 = v0::MyEnum::Var1(NotVersionable(val)); + + let mut ser = Vec::new(); + // Serialization of the skipped variant must fail + assert!(ciborium::ser::into_writer(&enu_v0.versionize(), &mut ser).is_err()); +}