feat(versionable): add skip attribute to skip field versioning

This commit is contained in:
Nicolas Sarlin
2025-04-01 10:02:22 +02:00
committed by Nicolas Sarlin
parent e57b91eccd
commit 5f9ac48dbe
9 changed files with 533 additions and 118 deletions

View File

@@ -68,7 +68,7 @@ pub(crate) fn generate_try_from_trait_impl(
}) })
} }
/// The ownership kind of the data for a associated type. /// The ownership kind of the data for an associated type.
#[derive(Clone)] #[derive(Clone)]
pub(crate) enum AssociatedTypeKind { pub(crate) enum AssociatedTypeKind {
/// This version type use references to non-Copy rust underlying built-in types. /// This version type use references to non-Copy rust underlying built-in types.

View File

@@ -52,6 +52,7 @@ pub(crate) const INTO_TRAIT_NAME: &str = "::core::convert::Into";
pub(crate) const ERROR_TRAIT_NAME: &str = "::core::error::Error"; pub(crate) const ERROR_TRAIT_NAME: &str = "::core::error::Error";
pub(crate) const SYNC_TRAIT_NAME: &str = "::core::marker::Sync"; pub(crate) const SYNC_TRAIT_NAME: &str = "::core::marker::Sync";
pub(crate) const SEND_TRAIT_NAME: &str = "::core::marker::Send"; pub(crate) const SEND_TRAIT_NAME: &str = "::core::marker::Send";
pub(crate) const DEFAULT_TRAIT_NAME: &str = "::core::default::Default";
pub(crate) const STATIC_LIFETIME_NAME: &str = "'static"; pub(crate) const STATIC_LIFETIME_NAME: &str = "'static";
use associated::AssociatingTrait; use associated::AssociatingTrait;
@@ -71,7 +72,7 @@ macro_rules! syn_unwrap {
}; };
} }
#[proc_macro_derive(Version)] #[proc_macro_derive(Version, attributes(versionize))]
/// Implement the `Version` trait for the target type. /// Implement the `Version` trait for the target type.
pub fn derive_version(input: TokenStream) -> TokenStream { pub fn derive_version(input: TokenStream) -> TokenStream {
let input = parse_macro_input!(input as DeriveInput); let input = parse_macro_input!(input as DeriveInput);

View File

@@ -2,9 +2,7 @@ use std::iter::zip;
use proc_macro2::{Literal, Span, TokenStream}; use proc_macro2::{Literal, Span, TokenStream};
use quote::{format_ident, quote}; use quote::{format_ident, quote};
use syn::punctuated::Punctuated;
use syn::spanned::Spanned; use syn::spanned::Spanned;
use syn::token::Comma;
use syn::{ use syn::{
parse_quote, Data, DataEnum, DataStruct, DataUnion, DeriveInput, Field, Fields, FieldsNamed, parse_quote, Data, DataEnum, DataStruct, DataUnion, DeriveInput, Field, Fields, FieldsNamed,
FieldsUnnamed, Generics, Ident, Item, ItemEnum, ItemImpl, ItemStruct, ItemUnion, Lifetime, FieldsUnnamed, Generics, Ident, Item, ItemEnum, ItemImpl, ItemStruct, ItemUnion, Lifetime,
@@ -15,12 +13,12 @@ use crate::associated::{
generate_from_trait_impl, generate_try_from_trait_impl, AssociatedType, AssociatedTypeKind, generate_from_trait_impl, generate_try_from_trait_impl, AssociatedType, AssociatedTypeKind,
ConversionDirection, ConversionDirection,
}; };
use crate::versionize_attribute::is_transparent; use crate::versionize_attribute::{is_skipped, is_transparent, replace_versionize_skip_with_serde};
use crate::{ use crate::{
add_trait_where_clause, parse_const_str, parse_trait_bound, punctuated_from_iter_result, add_trait_where_clause, parse_const_str, parse_trait_bound, punctuated_from_iter_result,
INTO_TRAIT_NAME, LIFETIME_NAME, TRY_INTO_TRAIT_NAME, UNVERSIONIZE_ERROR_NAME, DEFAULT_TRAIT_NAME, INTO_TRAIT_NAME, LIFETIME_NAME, TRY_INTO_TRAIT_NAME,
UNVERSIONIZE_TRAIT_NAME, VERSIONIZE_OWNED_TRAIT_NAME, VERSIONIZE_TRAIT_NAME, UNVERSIONIZE_ERROR_NAME, UNVERSIONIZE_TRAIT_NAME, VERSIONIZE_OWNED_TRAIT_NAME,
VERSION_TRAIT_NAME, VERSIONIZE_TRAIT_NAME, VERSION_TRAIT_NAME,
}; };
/// The types generated for a specific version of a given exposed type. These types are identical to /// The types generated for a specific version of a given exposed type. These types are identical to
@@ -199,9 +197,9 @@ impl AssociatedType for VersionType {
} }
fn inner_types(&self) -> syn::Result<Vec<&Type>> { fn inner_types(&self) -> syn::Result<Vec<&Type>> {
self.orig_type_fields() self.orig_type_fields()?
.iter() .filter_map(filter_skipped_field)
.map(|field| Ok(&field.ty)) .map(|field| Ok(&field?.ty))
.collect() .collect()
} }
@@ -232,6 +230,14 @@ impl AssociatedType for VersionType {
self.inner_types()?, self.inner_types()?,
&[UNVERSIONIZE_TRAIT_NAME], &[UNVERSIONIZE_TRAIT_NAME],
)?; )?;
// "skipped" types are not present in the Version types so we add a Default
// bound to be able to reconstruct them.
add_trait_where_clause(
&mut generics,
self.skipped_inner_types()?,
&[DEFAULT_TRAIT_NAME],
)?;
} }
} }
} }
@@ -242,10 +248,18 @@ impl AssociatedType for VersionType {
impl VersionType { impl VersionType {
/// Returns the fields of the original declaration. /// Returns the fields of the original declaration.
fn orig_type_fields(&self) -> Punctuated<&Field, Comma> { fn orig_type_fields(&self) -> syn::Result<Box<dyn Iterator<Item = &Field> + '_>> {
derive_type_fields(&self.orig_type) derive_type_fields(&self.orig_type)
} }
/// Returns the list of types inside the original type that are skipped
fn skipped_inner_types(&self) -> syn::Result<Vec<&Type>> {
self.orig_type_fields()?
.filter_map(keep_skipped_field)
.map(|field| Ok(&field?.ty))
.collect()
}
/// Generates the declaration for the Version equivalent of the input struct /// Generates the declaration for the Version equivalent of the input struct
fn generate_struct(&self, stru: &DataStruct) -> syn::Result<ItemStruct> { fn generate_struct(&self, stru: &DataStruct) -> syn::Result<ItemStruct> {
let fields = match &stru.fields { let fields = match &stru.fields {
@@ -313,14 +327,24 @@ impl VersionType {
/// Converts an enum variant into its "Version" form /// Converts an enum variant into its "Version" form
fn convert_enum_variant(&self, variant: &Variant) -> syn::Result<Variant> { fn convert_enum_variant(&self, variant: &Variant) -> syn::Result<Variant> {
let fields = match &variant.fields { let is_skipped = is_skipped(&variant.attrs)?;
Fields::Named(fields) => Fields::Named(self.convert_fields_named(fields)?), let fields = if is_skipped {
Fields::Unnamed(fields) => Fields::Unnamed(self.convert_fields_unnamed(fields)?), // If the whole variant is skipped convert the variant to a unit. That way it still
Fields::Unit => Fields::Unit, // compiles but the user gets an error at the serialization step
Fields::Unit
} else {
match &variant.fields {
Fields::Named(fields) => Fields::Named(self.convert_fields_named(fields)?),
Fields::Unnamed(fields) => Fields::Unnamed(self.convert_fields_unnamed(fields)?),
Fields::Unit => Fields::Unit,
}
}; };
// Copy the attributes from the initial variant and remove the ones that were meant for us
let attrs = replace_versionize_skip_with_serde(&variant.attrs)?;
let versioned_variant = Variant { let versioned_variant = Variant {
attrs: Vec::new(), attrs,
ident: variant.ident.clone(), ident: variant.ident.clone(),
fields, fields,
discriminant: variant.discriminant.clone(), discriminant: variant.discriminant.clone(),
@@ -353,45 +377,49 @@ impl VersionType {
let kind = self.kind.clone(); let kind = self.kind.clone();
let is_transparent = self.is_transparent; let is_transparent = self.is_transparent;
fields_iter.into_iter().map(move |field| { fields_iter
let unver_ty = field.ty.clone(); .into_iter()
.filter_map(filter_skipped_field)
.map(move |field| {
let field = field?;
let unver_ty = field.ty.clone();
if is_transparent { if is_transparent {
// If the type is transparent, we reuse the "Version" impl of the inner type // If the type is transparent, we reuse the "Version" impl of the inner type
let version_trait = parse_trait_bound(VERSION_TRAIT_NAME)?; let version_trait = parse_trait_bound(VERSION_TRAIT_NAME)?;
let ty: Type = match &kind { let ty: Type = match &kind {
AssociatedTypeKind::Ref(lifetime) => parse_quote! { AssociatedTypeKind::Ref(lifetime) => parse_quote! {
<#unver_ty as #version_trait>::Ref<#lifetime> <#unver_ty as #version_trait>::Ref<#lifetime>
}, },
AssociatedTypeKind::Owned => parse_quote! { AssociatedTypeKind::Owned => parse_quote! {
<#unver_ty as #version_trait>::Owned <#unver_ty as #version_trait>::Owned
}, },
}; };
Ok(Field { Ok(Field {
ty, ty,
..field.clone() ..field.clone()
}) })
} else { } else {
let versionize_trait = parse_trait_bound(VERSIONIZE_TRAIT_NAME)?; let versionize_trait = parse_trait_bound(VERSIONIZE_TRAIT_NAME)?;
let versionize_owned_trait = parse_trait_bound(VERSIONIZE_OWNED_TRAIT_NAME)?; let versionize_owned_trait = parse_trait_bound(VERSIONIZE_OWNED_TRAIT_NAME)?;
let ty: Type = match &kind { let ty: Type = match &kind {
AssociatedTypeKind::Ref(lifetime) => parse_quote! { AssociatedTypeKind::Ref(lifetime) => parse_quote! {
<#unver_ty as #versionize_trait>::Versioned<#lifetime> <#unver_ty as #versionize_trait>::Versioned<#lifetime>
}, },
AssociatedTypeKind::Owned => parse_quote! { AssociatedTypeKind::Owned => parse_quote! {
<#unver_ty as #versionize_owned_trait>::VersionedOwned <#unver_ty as #versionize_owned_trait>::VersionedOwned
}, },
}; };
Ok(Field { Ok(Field {
ty, ty,
..field.clone() ..field.clone()
}) })
} }
}) })
} }
/// Generates the constructor part of the conversion impl block. This will create the dest type /// Generates the constructor part of the conversion impl block. This will create the dest type
@@ -495,7 +523,10 @@ impl VersionType {
variant: &Variant, variant: &Variant,
direction: ConversionDirection, direction: ConversionDirection,
) -> syn::Result<TokenStream> { ) -> syn::Result<TokenStream> {
let (param, fields) = match &variant.fields { let is_skipped = is_skipped(&variant.attrs)?;
let variant_ident = &variant.ident;
Ok(match &variant.fields {
Fields::Named(fields) => { Fields::Named(fields) => {
let args_iter = fields let args_iter = fields
.named .named
@@ -504,35 +535,47 @@ impl VersionType {
.map(|field| field.ident.as_ref().unwrap()); .map(|field| field.ident.as_ref().unwrap());
let args = args_iter.clone(); let args = args_iter.clone();
( if is_skipped {
quote! { { self.generate_constructor_skipped_enum_variants(
#(#args),* src_type,
}}, variant_ident,
self.generate_constructor_enum_variants_named( direction,
)
} else {
let constructor = self.generate_constructor_enum_variants_named(
args_iter.cloned(), args_iter.cloned(),
fields.named.iter(), fields.named.iter(),
direction, direction,
)?, )?;
) quote! {
#src_type::#variant_ident {#(#args),*} =>
Self::#variant_ident #constructor
}
}
} }
Fields::Unnamed(fields) => { Fields::Unnamed(fields) => {
let args_iter = generate_args_list(fields.unnamed.len()); let args_iter = generate_args_list(fields.unnamed.len());
let args = args_iter.clone(); let args = args_iter.clone();
(
quote! { (#(#args),*) }, if is_skipped {
self.generate_constructor_enum_variants_unnamed( self.generate_constructor_skipped_enum_variants(
src_type,
variant_ident,
direction,
)
} else {
let constructor = self.generate_constructor_enum_variants_unnamed(
args_iter, args_iter,
fields.unnamed.iter(), fields.unnamed.iter(),
direction, direction,
)?, )?;
) quote! {
#src_type::#variant_ident (#(#args),*) =>
Self::#variant_ident #constructor
}
}
} }
Fields::Unit => (TokenStream::new(), TokenStream::new()), Fields::Unit => quote! { #src_type::#variant_ident => Self::#variant_ident },
};
let variant_ident = &variant.ident;
Ok(quote! {
#src_type::#variant_ident #param => Self::#variant_ident #fields
}) })
} }
@@ -545,7 +588,10 @@ impl VersionType {
) -> syn::Result<TokenStream> { ) -> syn::Result<TokenStream> {
let fields: syn::Result<Vec<TokenStream>> = fields let fields: syn::Result<Vec<TokenStream>> = fields
.into_iter() .into_iter()
.map(move |field| self.generate_constructor_field_named(arg_name, field, direction)) .filter_map(move |field| {
self.generate_constructor_field_named(arg_name, field, direction)
.transpose()
})
.collect(); .collect();
let fields = fields?; let fields = fields?;
@@ -562,7 +608,7 @@ impl VersionType {
arg_name: &str, arg_name: &str,
field: &Field, field: &Field,
direction: ConversionDirection, direction: ConversionDirection,
) -> syn::Result<TokenStream> { ) -> syn::Result<Option<TokenStream>> {
let arg_ident = Ident::new(arg_name, Span::call_site()); let arg_ident = Ident::new(arg_name, Span::call_site());
// Ok to unwrap because the field is named so field.ident is Some // Ok to unwrap because the field is named so field.ident is Some
let field_ident = field.ident.as_ref().unwrap(); let field_ident = field.ident.as_ref().unwrap();
@@ -570,14 +616,23 @@ impl VersionType {
let param = quote! { #arg_ident.#field_ident }; let param = quote! { #arg_ident.#field_ident };
let rhs = if self.is_transparent() { let rhs = if self.is_transparent() {
self.generate_constructor_transparent_rhs(param, direction)? self.generate_constructor_transparent_rhs(param, direction)
.map(Some)
} else { } else {
self.generate_constructor_field_rhs(ty, param, false, direction)? self.generate_constructor_field_rhs(
}; ty,
param,
false,
is_skipped(&field.attrs)?,
direction,
)
}?;
Ok(quote! { Ok(rhs.map(|rhs| {
#field_ident: #rhs quote! {
}) #field_ident: #rhs
}
}))
} }
/// Generates the constructor for the fields of a named enum variant. /// Generates the constructor for the fields of a named enum variant.
@@ -592,22 +647,32 @@ impl VersionType {
direction: ConversionDirection, direction: ConversionDirection,
) -> syn::Result<TokenStream> { ) -> syn::Result<TokenStream> {
let fields: syn::Result<Vec<TokenStream>> = zip(arg_names, fields) let fields: syn::Result<Vec<TokenStream>> = zip(arg_names, fields)
.map(move |(arg_name, field)| { .filter_map(move |(arg_name, field)| {
// Ok to unwrap because the field is named so field.ident is Some // Ok to unwrap because the field is named so field.ident is Some
let field_ident = field.ident.as_ref().unwrap(); let field_ident = field.ident.as_ref().unwrap();
let rhs = if self.is_transparent() { let rhs = if self.is_transparent() {
self.generate_constructor_transparent_rhs(quote! {#arg_name}, direction)? Some(self.generate_constructor_transparent_rhs(quote! {#arg_name}, direction))
} else { } else {
let skipped = match is_skipped(&field.attrs) {
Ok(skipped) => skipped,
Err(e) => return Some(Err(e)),
};
self.generate_constructor_field_rhs( self.generate_constructor_field_rhs(
&field.ty, &field.ty,
quote! {#arg_name}, quote! {#arg_name},
true, true,
skipped,
direction, direction,
)? )
}; .transpose()
Ok(quote! { }?;
#field_ident: #rhs
}) Some(rhs.map(|rhs| {
quote! {
#field_ident: #rhs
}
}))
}) })
.collect(); .collect();
let fields = fields?; let fields = fields?;
@@ -629,8 +694,9 @@ impl VersionType {
let fields: syn::Result<Vec<TokenStream>> = fields let fields: syn::Result<Vec<TokenStream>> = fields
.into_iter() .into_iter()
.enumerate() .enumerate()
.map(move |(idx, field)| { .filter_map(move |(idx, field)| {
self.generate_constructor_field_unnamed(arg_name, field, idx, direction) self.generate_constructor_field_unnamed(arg_name, field, idx, direction)
.transpose()
}) })
.collect(); .collect();
let fields = fields?; let fields = fields?;
@@ -647,7 +713,7 @@ impl VersionType {
field: &Field, field: &Field,
idx: usize, idx: usize,
direction: ConversionDirection, direction: ConversionDirection,
) -> syn::Result<TokenStream> { ) -> syn::Result<Option<TokenStream>> {
let arg_ident = Ident::new(arg_name, Span::call_site()); let arg_ident = Ident::new(arg_name, Span::call_site());
let idx = Literal::usize_unsuffixed(idx); let idx = Literal::usize_unsuffixed(idx);
let ty = &field.ty; let ty = &field.ty;
@@ -655,8 +721,15 @@ impl VersionType {
if self.is_transparent { if self.is_transparent {
self.generate_constructor_transparent_rhs(param, direction) self.generate_constructor_transparent_rhs(param, direction)
.map(Some)
} else { } else {
self.generate_constructor_field_rhs(ty, param, false, direction) self.generate_constructor_field_rhs(
ty,
param,
false,
is_skipped(&field.attrs)?,
direction,
)
} }
} }
@@ -672,16 +745,22 @@ impl VersionType {
direction: ConversionDirection, direction: ConversionDirection,
) -> syn::Result<TokenStream> { ) -> syn::Result<TokenStream> {
let fields: syn::Result<Vec<TokenStream>> = zip(arg_names, fields) let fields: syn::Result<Vec<TokenStream>> = zip(arg_names, fields)
.map(move |(arg_name, field)| { .filter_map(move |(arg_name, field)| {
if self.is_transparent { if self.is_transparent() {
self.generate_constructor_transparent_rhs(quote! {#arg_name}, direction) Some(self.generate_constructor_transparent_rhs(quote! {#arg_name}, direction))
} else { } else {
let skipped = match is_skipped(&field.attrs) {
Ok(skipped) => skipped,
Err(e) => return Some(Err(e)),
};
self.generate_constructor_field_rhs( self.generate_constructor_field_rhs(
&field.ty, &field.ty,
quote! {#arg_name}, quote! {#arg_name},
true, true,
skipped,
direction, direction,
) )
.transpose()
} }
}) })
.collect(); .collect();
@@ -692,6 +771,35 @@ impl VersionType {
}) })
} }
/// Generates the constructor for a variant of an enum with the `skip` attribute.
///
/// This constructor is never supposed to be called, but we need to handle it anyways.
///
/// During a call to "versionize", the conversion will simply create a unit variant that will
/// trigger an error at the "serialize" step. During a call to "unversionize", the conversion
/// will raise an error.
fn generate_constructor_skipped_enum_variants(
&self,
src_type: &Ident,
variant_ident: &Ident,
direction: ConversionDirection,
) -> TokenStream {
match direction {
ConversionDirection::OrigToAssociated => quote! {
#src_type::#variant_ident { .. } =>
Self::#variant_ident
},
ConversionDirection::AssociatedToOrig => {
let error: Path = parse_const_str(UNVERSIONIZE_ERROR_NAME);
let variant_name = format!("{}::{}", self.orig_type.ident, variant_ident);
quote! {
#src_type::#variant_ident => return Err(#error::skipped_variant(#variant_name))
}
}
}
}
/// Generates the rhs part of a field constructor. /// Generates the rhs part of a field constructor.
/// For example, in `Self { count: value.count.versionize() }`, this is /// For example, in `Self { count: value.count.versionize() }`, this is
/// `value.count.versionize()`. /// `value.count.versionize()`.
@@ -699,15 +807,21 @@ impl VersionType {
&self, &self,
ty: &Type, ty: &Type,
field_param: TokenStream, field_param: TokenStream,
is_ref: bool, // True if the param is already a reference is_ref: bool, // True if the param is already a reference
is_skipped: bool, // True if the field has the `skipped` attribute
direction: ConversionDirection, direction: ConversionDirection,
) -> syn::Result<TokenStream> { ) -> syn::Result<Option<TokenStream>> {
let versionize_trait: Path = parse_const_str(VERSIONIZE_TRAIT_NAME); let versionize_trait: Path = parse_const_str(VERSIONIZE_TRAIT_NAME);
let versionize_owned_trait: Path = parse_const_str(VERSIONIZE_OWNED_TRAIT_NAME); let versionize_owned_trait: Path = parse_const_str(VERSIONIZE_OWNED_TRAIT_NAME);
let unversionize_trait: Path = parse_const_str(UNVERSIONIZE_TRAIT_NAME); let unversionize_trait: Path = parse_const_str(UNVERSIONIZE_TRAIT_NAME);
let default_trait: Path = parse_const_str(DEFAULT_TRAIT_NAME);
let field_constructor = match direction { let field_constructor = match direction {
ConversionDirection::OrigToAssociated => { ConversionDirection::OrigToAssociated => {
if is_skipped {
// Skipped fields does not exist in the associated type so we return None
return Ok(None);
}
match self.kind { match self.kind {
AssociatedTypeKind::Ref(_) => { AssociatedTypeKind::Ref(_) => {
let param = if is_ref { let param = if is_ref {
@@ -727,12 +841,21 @@ impl VersionType {
ConversionDirection::AssociatedToOrig => match self.kind { ConversionDirection::AssociatedToOrig => match self.kind {
AssociatedTypeKind::Ref(_) => AssociatedTypeKind::Ref(_) =>
panic!("No conversion should be generated between associated ref type to original type"), panic!("No conversion should be generated between associated ref type to original type"),
AssociatedTypeKind::Owned => quote! { AssociatedTypeKind::Owned => {
<#ty as #unversionize_trait>::unversionize(#field_param)? if is_skipped {
// If the field is skipped, we try to construct it from a Default impl (this is what serde does)
quote! {
<#ty as #default_trait>::default()
}
} else {
quote! {
<#ty as #unversionize_trait>::unversionize(#field_param)?
}
}
}, },
}, },
}; };
Ok(field_constructor) Ok(Some(field_constructor))
} }
fn generate_constructor_transparent_rhs( fn generate_constructor_transparent_rhs(
@@ -788,23 +911,64 @@ fn is_unit(input: &DeriveInput) -> bool {
/// Returns the fields of the input type. This is independent of the kind of type /// Returns the fields of the input type. This is independent of the kind of type
/// (enum, struct, ...) /// (enum, struct, ...)
fn derive_type_fields(input: &DeriveInput) -> Punctuated<&Field, Comma> { ///
match &input.data { /// In the case of an enum, the fields for each variants that are not skipped are flattened into a
Data::Struct(stru) => match &stru.fields { /// single iterator.
Fields::Named(fields) => Punctuated::from_iter(fields.named.iter()), fn derive_type_fields(input: &DeriveInput) -> syn::Result<Box<dyn Iterator<Item = &Field> + '_>> {
Fields::Unnamed(fields) => Punctuated::from_iter(fields.unnamed.iter()), Ok(match &input.data {
Fields::Unit => Punctuated::new(), Data::Struct(stru) => Box::new(iter_fields(&stru.fields)),
}, Data::Enum(enu) => {
Data::Enum(enu) => Punctuated::<&Field, Comma>::from_iter( let filtered: Result<Vec<&Variant>, syn::Error> = enu
enu.variants .variants
.iter() .iter()
.filter_map(|variant| match &variant.fields { .filter_map(filter_skipped_variant)
Fields::Named(fields) => Some(fields.named.iter()), .collect();
Fields::Unnamed(fields) => Some(fields.unnamed.iter()),
Fields::Unit => None, Box::new(
}) filtered?
.flatten(), .into_iter()
), .flat_map(|variant| iter_fields(&variant.fields)),
Data::Union(uni) => Punctuated::from_iter(uni.fields.named.iter()), )
}
Data::Union(uni) => Box::new(uni.fields.named.iter()),
})
}
/// Returns an iterator over the `Field`s in a `Fields` regardless of the fields type (named,
/// unnamed or unit).
fn iter_fields(fields: &Fields) -> Box<dyn Iterator<Item = &Field> + '_> {
match fields {
Fields::Named(fields) => Box::new(fields.named.iter()),
Fields::Unnamed(fields) => Box::new(fields.unnamed.iter()),
Fields::Unit => Box::new(std::iter::empty()),
}
}
/// Can be used inside a field iterator to remove the fields with a `#[versionize(skip)]` attribute
fn filter_skipped_field(field: &Field) -> Option<syn::Result<&Field>> {
match is_skipped(&field.attrs) {
Ok(true) => None,
Ok(false) => Some(Ok(field)),
Err(e) => Some(Err(e)),
}
}
/// Can be used inside a field iterator to only keep the fields with a `#[versionize(skip)]`
/// attribute
fn keep_skipped_field(field: &Field) -> Option<syn::Result<&Field>> {
match is_skipped(&field.attrs) {
Ok(true) => Some(Ok(field)),
Ok(false) => None,
Err(e) => Some(Err(e)),
}
}
/// Can be used inside a variant iterator to remove the variants with a `#[versionize(skip)]`
/// attribute
fn filter_skipped_variant(variant: &Variant) -> Option<syn::Result<&Variant>> {
match is_skipped(&variant.attrs) {
Ok(true) => None,
Ok(false) => Some(Ok(variant)),
Err(e) => Some(Err(e)),
} }
} }

View File

@@ -5,10 +5,13 @@ use proc_macro2::Span;
use quote::ToTokens; use quote::ToTokens;
use syn::punctuated::Punctuated; use syn::punctuated::Punctuated;
use syn::spanned::Spanned; use syn::spanned::Spanned;
use syn::{Attribute, Expr, Lit, Meta, Path, Token}; use syn::{parse_quote, Attribute, Expr, Lit, Meta, Path, Token};
/// Name of the attribute used to give arguments to the `Versionize` macro /// Name of the attribute used to give arguments to the `Versionize` macro
const VERSIONIZE_ATTR_NAME: &str = "versionize"; pub(crate) const VERSIONIZE_ATTR_NAME: &str = "versionize";
/// Name of the attribute used to give arguments to serde macros
pub(crate) const SERDE_ATTR_NAME: &str = "serde";
/// Transparent mode can also be activated using `#[repr(transparent)]` /// Transparent mode can also be activated using `#[repr(transparent)]`
pub(crate) const REPR_ATTR_NAME: &str = "repr"; pub(crate) const REPR_ATTR_NAME: &str = "repr";
@@ -297,11 +300,12 @@ fn parse_path_ignore_quotes(value: &Expr) -> syn::Result<Path> {
} }
} }
/// Check if the target type has the `#[repr(transparent)]` attribute in its attributes list /// Check if the target type has the `#[repr(transparent)]` or `#[serde(transparent)]` attribute in
/// its attributes list
pub(crate) fn is_transparent(attributes: &[Attribute]) -> syn::Result<bool> { pub(crate) fn is_transparent(attributes: &[Attribute]) -> syn::Result<bool> {
if let Some(attr) = attributes if let Some(attr) = attributes
.iter() .iter()
.find(|attr| attr.path().is_ident(REPR_ATTR_NAME)) .find(|attr| attr.path().is_ident(REPR_ATTR_NAME) || attr.path().is_ident(SERDE_ATTR_NAME))
{ {
let nested = attr.parse_args_with(Punctuated::<Meta, Token![,]>::parse_terminated)?; let nested = attr.parse_args_with(Punctuated::<Meta, Token![,]>::parse_terminated)?;
@@ -316,3 +320,48 @@ pub(crate) fn is_transparent(attributes: &[Attribute]) -> syn::Result<bool> {
Ok(false) Ok(false)
} }
/// Check if a field has the `#[serde(skip)]` or `#[versionize(skip)]` attribute in
/// its attributes list
pub(crate) fn is_skipped(attributes: &[Attribute]) -> syn::Result<bool> {
if let Some(attr) = attributes.iter().find(|attr| {
attr.path().is_ident(VERSIONIZE_ATTR_NAME) || attr.path().is_ident(SERDE_ATTR_NAME)
}) {
let nested = attr.parse_args_with(Punctuated::<Meta, Token![,]>::parse_terminated)?;
for meta in nested.iter() {
if let Meta::Path(path) = meta {
if path.is_ident("skip") {
return Ok(true);
}
}
}
}
Ok(false)
}
/// Replace `#[versionize(skip)]` with `#[serde(skip)]` in an attributes list
pub(crate) fn replace_versionize_skip_with_serde(
attributes: &[Attribute],
) -> syn::Result<Vec<Attribute>> {
attributes
.iter()
.cloned()
.map(|attr| {
if attr.path().is_ident(VERSIONIZE_ATTR_NAME) {
let nested =
attr.parse_args_with(Punctuated::<Meta, Token![,]>::parse_terminated)?;
for meta in nested.iter() {
if let Meta::Path(path) = meta {
if path.is_ident("skip") {
return Ok(parse_quote! { #[serde(skip)] });
}
}
}
}
Ok(attr)
})
.collect()
}

View File

@@ -69,3 +69,7 @@ test = true
[[example]] [[example]]
name = "associated_bounds" name = "associated_bounds"
test = true test = true
[[example]]
name = "skip"
test = true

View File

@@ -0,0 +1,96 @@
//! Example of a struct that include a field that is not versionable and is skipped during
//! versioning.
//!
//! This is similar to the `#[serde(skip)]` attribute of Serde.
//! With this attribute, the field is not included in the definition of the associated Version type.
//! During unversioning, the field is instantiated using a `Default` impl.
use std::convert::Infallible;
use std::io::Cursor;
use tfhe_versionable::{Unversionize, Upgrade, Version, Versionize, VersionsDispatch};
// This type is not versionable/serializable and should not be stored.
// It should however at least implement Default to be instantiable when the data is loaded.
#[derive(Default)]
struct NotVersionable(u64);
/// The previous version of our application
mod v0 {
use super::NotVersionable;
use tfhe_versionable::{Versionize, VersionsDispatch};
#[derive(Versionize)]
#[versionize(MyStructVersions)]
pub(super) struct MyStruct {
pub(super) val: u32,
// This attribute is used to skip the versioning of the field
// Also work with `#[serde(skip)]` if the field derives Serialize
#[versionize(skip)]
#[allow(dead_code)]
pub(super) to_skip: NotVersionable,
}
#[derive(VersionsDispatch)]
#[allow(unused)]
pub(super) enum MyStructVersions {
V0(MyStruct),
}
}
#[derive(Version)]
struct MyStructV0 {
val: u32,
#[versionize(skip)]
to_skip: NotVersionable,
}
#[derive(Versionize)]
#[versionize(MyStructVersions)]
struct MyStruct {
val: u64,
#[versionize(skip)]
to_skip: NotVersionable,
}
impl Upgrade<MyStruct> for MyStructV0 {
type Error = Infallible;
fn upgrade(self) -> Result<MyStruct, Self::Error> {
let val = self.val as u64;
Ok(MyStruct {
val,
to_skip: self.to_skip,
})
}
}
#[derive(VersionsDispatch)]
#[allow(unused)]
enum MyStructVersions {
V0(MyStructV0),
V1(MyStruct),
}
fn main() {
let val = 64;
let stru_v0 = v0::MyStruct {
val,
to_skip: NotVersionable(42), // The value will be lost during serialization
};
let mut ser = Vec::new();
ciborium::ser::into_writer(&stru_v0.versionize(), &mut ser).unwrap();
let unvers =
MyStruct::unversionize(ciborium::de::from_reader(&mut Cursor::new(&ser)).unwrap()).unwrap();
assert_eq!(unvers.val, val as u64);
assert_eq!(unvers.to_skip.0, Default::default());
}
#[test]
fn test() {
main()
}

View File

@@ -11,7 +11,7 @@ use tfhe_versionable::{Unversionize, Upgrade, Version, Versionize, VersionsDispa
// //
// struct MyStructWrapper<T> { inner: MyStruct<T> }; // struct MyStructWrapper<T> { inner: MyStruct<T> };
#[derive(Versionize)] #[derive(Versionize)]
#[versionize(transparent)] // Also works with `#[repr(transparent)]` #[versionize(transparent)] // Also works with `#[repr(transparent)]` or `#[serde(transparent)]`
struct MyStructWrapper<T>(MyStruct<T>); struct MyStructWrapper<T>(MyStruct<T>);
// The inner struct that is versioned. // The inner struct that is versioned.

View File

@@ -94,6 +94,9 @@ pub enum UnversionizeError {
/// A deprecated version has been found /// A deprecated version has been found
DeprecatedVersion(DeprecatedVersionError), DeprecatedVersion(DeprecatedVersionError),
/// User tried to unversionize an enum variant with the `#[versionize(skip)]` attribute
SkippedVariant { variant_name: String },
} }
impl Display for UnversionizeError { impl Display for UnversionizeError {
@@ -120,6 +123,9 @@ impl Display for UnversionizeError {
) )
} }
Self::DeprecatedVersion(deprecation_error) => deprecation_error.fmt(f), Self::DeprecatedVersion(deprecation_error) => deprecation_error.fmt(f),
Self::SkippedVariant { variant_name } => write!(f,
"Enum variant {variant_name} is marked with the `skip` attribute and cannot be unversioned"
),
} }
} }
} }
@@ -131,6 +137,7 @@ impl Error for UnversionizeError {
UnversionizeError::Conversion { source, .. } => Some(source.as_ref()), UnversionizeError::Conversion { source, .. } => Some(source.as_ref()),
UnversionizeError::ArrayLength { .. } => None, UnversionizeError::ArrayLength { .. } => None,
UnversionizeError::DeprecatedVersion(_) => None, UnversionizeError::DeprecatedVersion(_) => None,
UnversionizeError::SkippedVariant { .. } => None,
} }
} }
} }
@@ -154,6 +161,12 @@ impl UnversionizeError {
source: Box::new(source), source: Box::new(source),
} }
} }
pub fn skipped_variant(variant_name: &str) -> Self {
Self::SkippedVariant {
variant_name: variant_name.to_string(),
}
}
} }
impl From<Infallible> for UnversionizeError { impl From<Infallible> for UnversionizeError {

View File

@@ -0,0 +1,88 @@
//! Test the skip attribute in an enum. This attribute in a struct is already tested in
//! `examples/skip.rs`
use std::convert::Infallible;
use std::io::Cursor;
use tfhe_versionable::{Unversionize, Upgrade, Version, Versionize, VersionsDispatch};
#[allow(dead_code)]
struct NotVersionable(u64);
mod v0 {
use super::NotVersionable;
use tfhe_versionable::{Versionize, VersionsDispatch};
#[derive(Versionize)]
#[versionize(MyEnumVersions)]
pub(super) enum MyEnum {
Var0(u32),
#[versionize(skip)]
#[allow(dead_code)]
Var1(NotVersionable),
}
#[derive(VersionsDispatch)]
#[allow(unused)]
pub(super) enum MyEnumVersions {
V0(MyEnum),
}
}
#[derive(Version)]
enum MyEnumV0 {
Var0(u32),
#[versionize(skip)]
#[allow(dead_code)]
Var1(NotVersionable),
}
#[derive(Versionize)]
#[versionize(MyEnumVersions)]
enum MyEnum {
Var0(u64),
#[versionize(skip)]
#[allow(dead_code)]
Var1(NotVersionable),
}
impl Upgrade<MyEnum> for MyEnumV0 {
type Error = Infallible;
fn upgrade(self) -> Result<MyEnum, Self::Error> {
match self {
MyEnumV0::Var0(val) => Ok(MyEnum::Var0(val as u64)),
MyEnumV0::Var1(val) => Ok(MyEnum::Var1(val)),
}
}
}
#[derive(VersionsDispatch)]
#[allow(unused)]
enum MyEnumVersions {
V0(MyEnumV0),
V1(MyEnum),
}
#[test]
fn test() {
// Test the "normal" variant
let val = 64;
let enu_v0 = v0::MyEnum::Var0(val);
let mut ser = Vec::new();
ciborium::ser::into_writer(&enu_v0.versionize(), &mut ser).unwrap();
let unvers =
MyEnum::unversionize(ciborium::de::from_reader(&mut Cursor::new(&ser)).unwrap()).unwrap();
assert!(matches!(unvers, MyEnum::Var0(unvers_val) if unvers_val == val as u64));
// Test the skipped variant
let val = 64;
let enu_v0 = v0::MyEnum::Var1(NotVersionable(val));
let mut ser = Vec::new();
// Serialization of the skipped variant must fail
assert!(ciborium::ser::into_writer(&enu_v0.versionize(), &mut ser).is_err());
}