mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-01-07 22:04:10 -05:00
feat(versionable): add skip attribute to skip field versioning
This commit is contained in:
committed by
Nicolas Sarlin
parent
e57b91eccd
commit
5f9ac48dbe
@@ -68,7 +68,7 @@ pub(crate) fn generate_try_from_trait_impl(
|
||||
})
|
||||
}
|
||||
|
||||
/// The ownership kind of the data for a associated type.
|
||||
/// The ownership kind of the data for an associated type.
|
||||
#[derive(Clone)]
|
||||
pub(crate) enum AssociatedTypeKind {
|
||||
/// This version type use references to non-Copy rust underlying built-in types.
|
||||
|
||||
@@ -52,6 +52,7 @@ pub(crate) const INTO_TRAIT_NAME: &str = "::core::convert::Into";
|
||||
pub(crate) const ERROR_TRAIT_NAME: &str = "::core::error::Error";
|
||||
pub(crate) const SYNC_TRAIT_NAME: &str = "::core::marker::Sync";
|
||||
pub(crate) const SEND_TRAIT_NAME: &str = "::core::marker::Send";
|
||||
pub(crate) const DEFAULT_TRAIT_NAME: &str = "::core::default::Default";
|
||||
pub(crate) const STATIC_LIFETIME_NAME: &str = "'static";
|
||||
|
||||
use associated::AssociatingTrait;
|
||||
@@ -71,7 +72,7 @@ macro_rules! syn_unwrap {
|
||||
};
|
||||
}
|
||||
|
||||
#[proc_macro_derive(Version)]
|
||||
#[proc_macro_derive(Version, attributes(versionize))]
|
||||
/// Implement the `Version` trait for the target type.
|
||||
pub fn derive_version(input: TokenStream) -> TokenStream {
|
||||
let input = parse_macro_input!(input as DeriveInput);
|
||||
|
||||
@@ -2,9 +2,7 @@ use std::iter::zip;
|
||||
|
||||
use proc_macro2::{Literal, Span, TokenStream};
|
||||
use quote::{format_ident, quote};
|
||||
use syn::punctuated::Punctuated;
|
||||
use syn::spanned::Spanned;
|
||||
use syn::token::Comma;
|
||||
use syn::{
|
||||
parse_quote, Data, DataEnum, DataStruct, DataUnion, DeriveInput, Field, Fields, FieldsNamed,
|
||||
FieldsUnnamed, Generics, Ident, Item, ItemEnum, ItemImpl, ItemStruct, ItemUnion, Lifetime,
|
||||
@@ -15,12 +13,12 @@ use crate::associated::{
|
||||
generate_from_trait_impl, generate_try_from_trait_impl, AssociatedType, AssociatedTypeKind,
|
||||
ConversionDirection,
|
||||
};
|
||||
use crate::versionize_attribute::is_transparent;
|
||||
use crate::versionize_attribute::{is_skipped, is_transparent, replace_versionize_skip_with_serde};
|
||||
use crate::{
|
||||
add_trait_where_clause, parse_const_str, parse_trait_bound, punctuated_from_iter_result,
|
||||
INTO_TRAIT_NAME, LIFETIME_NAME, TRY_INTO_TRAIT_NAME, UNVERSIONIZE_ERROR_NAME,
|
||||
UNVERSIONIZE_TRAIT_NAME, VERSIONIZE_OWNED_TRAIT_NAME, VERSIONIZE_TRAIT_NAME,
|
||||
VERSION_TRAIT_NAME,
|
||||
DEFAULT_TRAIT_NAME, INTO_TRAIT_NAME, LIFETIME_NAME, TRY_INTO_TRAIT_NAME,
|
||||
UNVERSIONIZE_ERROR_NAME, UNVERSIONIZE_TRAIT_NAME, VERSIONIZE_OWNED_TRAIT_NAME,
|
||||
VERSIONIZE_TRAIT_NAME, VERSION_TRAIT_NAME,
|
||||
};
|
||||
|
||||
/// The types generated for a specific version of a given exposed type. These types are identical to
|
||||
@@ -199,9 +197,9 @@ impl AssociatedType for VersionType {
|
||||
}
|
||||
|
||||
fn inner_types(&self) -> syn::Result<Vec<&Type>> {
|
||||
self.orig_type_fields()
|
||||
.iter()
|
||||
.map(|field| Ok(&field.ty))
|
||||
self.orig_type_fields()?
|
||||
.filter_map(filter_skipped_field)
|
||||
.map(|field| Ok(&field?.ty))
|
||||
.collect()
|
||||
}
|
||||
|
||||
@@ -232,6 +230,14 @@ impl AssociatedType for VersionType {
|
||||
self.inner_types()?,
|
||||
&[UNVERSIONIZE_TRAIT_NAME],
|
||||
)?;
|
||||
|
||||
// "skipped" types are not present in the Version types so we add a Default
|
||||
// bound to be able to reconstruct them.
|
||||
add_trait_where_clause(
|
||||
&mut generics,
|
||||
self.skipped_inner_types()?,
|
||||
&[DEFAULT_TRAIT_NAME],
|
||||
)?;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -242,10 +248,18 @@ impl AssociatedType for VersionType {
|
||||
|
||||
impl VersionType {
|
||||
/// Returns the fields of the original declaration.
|
||||
fn orig_type_fields(&self) -> Punctuated<&Field, Comma> {
|
||||
fn orig_type_fields(&self) -> syn::Result<Box<dyn Iterator<Item = &Field> + '_>> {
|
||||
derive_type_fields(&self.orig_type)
|
||||
}
|
||||
|
||||
/// Returns the list of types inside the original type that are skipped
|
||||
fn skipped_inner_types(&self) -> syn::Result<Vec<&Type>> {
|
||||
self.orig_type_fields()?
|
||||
.filter_map(keep_skipped_field)
|
||||
.map(|field| Ok(&field?.ty))
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Generates the declaration for the Version equivalent of the input struct
|
||||
fn generate_struct(&self, stru: &DataStruct) -> syn::Result<ItemStruct> {
|
||||
let fields = match &stru.fields {
|
||||
@@ -313,14 +327,24 @@ impl VersionType {
|
||||
|
||||
/// Converts an enum variant into its "Version" form
|
||||
fn convert_enum_variant(&self, variant: &Variant) -> syn::Result<Variant> {
|
||||
let fields = match &variant.fields {
|
||||
Fields::Named(fields) => Fields::Named(self.convert_fields_named(fields)?),
|
||||
Fields::Unnamed(fields) => Fields::Unnamed(self.convert_fields_unnamed(fields)?),
|
||||
Fields::Unit => Fields::Unit,
|
||||
let is_skipped = is_skipped(&variant.attrs)?;
|
||||
let fields = if is_skipped {
|
||||
// If the whole variant is skipped convert the variant to a unit. That way it still
|
||||
// compiles but the user gets an error at the serialization step
|
||||
Fields::Unit
|
||||
} else {
|
||||
match &variant.fields {
|
||||
Fields::Named(fields) => Fields::Named(self.convert_fields_named(fields)?),
|
||||
Fields::Unnamed(fields) => Fields::Unnamed(self.convert_fields_unnamed(fields)?),
|
||||
Fields::Unit => Fields::Unit,
|
||||
}
|
||||
};
|
||||
|
||||
// Copy the attributes from the initial variant and remove the ones that were meant for us
|
||||
let attrs = replace_versionize_skip_with_serde(&variant.attrs)?;
|
||||
|
||||
let versioned_variant = Variant {
|
||||
attrs: Vec::new(),
|
||||
attrs,
|
||||
ident: variant.ident.clone(),
|
||||
fields,
|
||||
discriminant: variant.discriminant.clone(),
|
||||
@@ -353,45 +377,49 @@ impl VersionType {
|
||||
let kind = self.kind.clone();
|
||||
let is_transparent = self.is_transparent;
|
||||
|
||||
fields_iter.into_iter().map(move |field| {
|
||||
let unver_ty = field.ty.clone();
|
||||
fields_iter
|
||||
.into_iter()
|
||||
.filter_map(filter_skipped_field)
|
||||
.map(move |field| {
|
||||
let field = field?;
|
||||
let unver_ty = field.ty.clone();
|
||||
|
||||
if is_transparent {
|
||||
// If the type is transparent, we reuse the "Version" impl of the inner type
|
||||
let version_trait = parse_trait_bound(VERSION_TRAIT_NAME)?;
|
||||
if is_transparent {
|
||||
// If the type is transparent, we reuse the "Version" impl of the inner type
|
||||
let version_trait = parse_trait_bound(VERSION_TRAIT_NAME)?;
|
||||
|
||||
let ty: Type = match &kind {
|
||||
AssociatedTypeKind::Ref(lifetime) => parse_quote! {
|
||||
<#unver_ty as #version_trait>::Ref<#lifetime>
|
||||
},
|
||||
AssociatedTypeKind::Owned => parse_quote! {
|
||||
<#unver_ty as #version_trait>::Owned
|
||||
},
|
||||
};
|
||||
let ty: Type = match &kind {
|
||||
AssociatedTypeKind::Ref(lifetime) => parse_quote! {
|
||||
<#unver_ty as #version_trait>::Ref<#lifetime>
|
||||
},
|
||||
AssociatedTypeKind::Owned => parse_quote! {
|
||||
<#unver_ty as #version_trait>::Owned
|
||||
},
|
||||
};
|
||||
|
||||
Ok(Field {
|
||||
ty,
|
||||
..field.clone()
|
||||
})
|
||||
} else {
|
||||
let versionize_trait = parse_trait_bound(VERSIONIZE_TRAIT_NAME)?;
|
||||
let versionize_owned_trait = parse_trait_bound(VERSIONIZE_OWNED_TRAIT_NAME)?;
|
||||
Ok(Field {
|
||||
ty,
|
||||
..field.clone()
|
||||
})
|
||||
} else {
|
||||
let versionize_trait = parse_trait_bound(VERSIONIZE_TRAIT_NAME)?;
|
||||
let versionize_owned_trait = parse_trait_bound(VERSIONIZE_OWNED_TRAIT_NAME)?;
|
||||
|
||||
let ty: Type = match &kind {
|
||||
AssociatedTypeKind::Ref(lifetime) => parse_quote! {
|
||||
<#unver_ty as #versionize_trait>::Versioned<#lifetime>
|
||||
},
|
||||
AssociatedTypeKind::Owned => parse_quote! {
|
||||
<#unver_ty as #versionize_owned_trait>::VersionedOwned
|
||||
},
|
||||
};
|
||||
let ty: Type = match &kind {
|
||||
AssociatedTypeKind::Ref(lifetime) => parse_quote! {
|
||||
<#unver_ty as #versionize_trait>::Versioned<#lifetime>
|
||||
},
|
||||
AssociatedTypeKind::Owned => parse_quote! {
|
||||
<#unver_ty as #versionize_owned_trait>::VersionedOwned
|
||||
},
|
||||
};
|
||||
|
||||
Ok(Field {
|
||||
ty,
|
||||
..field.clone()
|
||||
})
|
||||
}
|
||||
})
|
||||
Ok(Field {
|
||||
ty,
|
||||
..field.clone()
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/// Generates the constructor part of the conversion impl block. This will create the dest type
|
||||
@@ -495,7 +523,10 @@ impl VersionType {
|
||||
variant: &Variant,
|
||||
direction: ConversionDirection,
|
||||
) -> syn::Result<TokenStream> {
|
||||
let (param, fields) = match &variant.fields {
|
||||
let is_skipped = is_skipped(&variant.attrs)?;
|
||||
let variant_ident = &variant.ident;
|
||||
|
||||
Ok(match &variant.fields {
|
||||
Fields::Named(fields) => {
|
||||
let args_iter = fields
|
||||
.named
|
||||
@@ -504,35 +535,47 @@ impl VersionType {
|
||||
.map(|field| field.ident.as_ref().unwrap());
|
||||
let args = args_iter.clone();
|
||||
|
||||
(
|
||||
quote! { {
|
||||
#(#args),*
|
||||
}},
|
||||
self.generate_constructor_enum_variants_named(
|
||||
if is_skipped {
|
||||
self.generate_constructor_skipped_enum_variants(
|
||||
src_type,
|
||||
variant_ident,
|
||||
direction,
|
||||
)
|
||||
} else {
|
||||
let constructor = self.generate_constructor_enum_variants_named(
|
||||
args_iter.cloned(),
|
||||
fields.named.iter(),
|
||||
direction,
|
||||
)?,
|
||||
)
|
||||
)?;
|
||||
quote! {
|
||||
#src_type::#variant_ident {#(#args),*} =>
|
||||
Self::#variant_ident #constructor
|
||||
}
|
||||
}
|
||||
}
|
||||
Fields::Unnamed(fields) => {
|
||||
let args_iter = generate_args_list(fields.unnamed.len());
|
||||
let args = args_iter.clone();
|
||||
(
|
||||
quote! { (#(#args),*) },
|
||||
self.generate_constructor_enum_variants_unnamed(
|
||||
|
||||
if is_skipped {
|
||||
self.generate_constructor_skipped_enum_variants(
|
||||
src_type,
|
||||
variant_ident,
|
||||
direction,
|
||||
)
|
||||
} else {
|
||||
let constructor = self.generate_constructor_enum_variants_unnamed(
|
||||
args_iter,
|
||||
fields.unnamed.iter(),
|
||||
direction,
|
||||
)?,
|
||||
)
|
||||
)?;
|
||||
quote! {
|
||||
#src_type::#variant_ident (#(#args),*) =>
|
||||
Self::#variant_ident #constructor
|
||||
}
|
||||
}
|
||||
}
|
||||
Fields::Unit => (TokenStream::new(), TokenStream::new()),
|
||||
};
|
||||
let variant_ident = &variant.ident;
|
||||
|
||||
Ok(quote! {
|
||||
#src_type::#variant_ident #param => Self::#variant_ident #fields
|
||||
Fields::Unit => quote! { #src_type::#variant_ident => Self::#variant_ident },
|
||||
})
|
||||
}
|
||||
|
||||
@@ -545,7 +588,10 @@ impl VersionType {
|
||||
) -> syn::Result<TokenStream> {
|
||||
let fields: syn::Result<Vec<TokenStream>> = fields
|
||||
.into_iter()
|
||||
.map(move |field| self.generate_constructor_field_named(arg_name, field, direction))
|
||||
.filter_map(move |field| {
|
||||
self.generate_constructor_field_named(arg_name, field, direction)
|
||||
.transpose()
|
||||
})
|
||||
.collect();
|
||||
let fields = fields?;
|
||||
|
||||
@@ -562,7 +608,7 @@ impl VersionType {
|
||||
arg_name: &str,
|
||||
field: &Field,
|
||||
direction: ConversionDirection,
|
||||
) -> syn::Result<TokenStream> {
|
||||
) -> syn::Result<Option<TokenStream>> {
|
||||
let arg_ident = Ident::new(arg_name, Span::call_site());
|
||||
// Ok to unwrap because the field is named so field.ident is Some
|
||||
let field_ident = field.ident.as_ref().unwrap();
|
||||
@@ -570,14 +616,23 @@ impl VersionType {
|
||||
let param = quote! { #arg_ident.#field_ident };
|
||||
|
||||
let rhs = if self.is_transparent() {
|
||||
self.generate_constructor_transparent_rhs(param, direction)?
|
||||
self.generate_constructor_transparent_rhs(param, direction)
|
||||
.map(Some)
|
||||
} else {
|
||||
self.generate_constructor_field_rhs(ty, param, false, direction)?
|
||||
};
|
||||
self.generate_constructor_field_rhs(
|
||||
ty,
|
||||
param,
|
||||
false,
|
||||
is_skipped(&field.attrs)?,
|
||||
direction,
|
||||
)
|
||||
}?;
|
||||
|
||||
Ok(quote! {
|
||||
#field_ident: #rhs
|
||||
})
|
||||
Ok(rhs.map(|rhs| {
|
||||
quote! {
|
||||
#field_ident: #rhs
|
||||
}
|
||||
}))
|
||||
}
|
||||
|
||||
/// Generates the constructor for the fields of a named enum variant.
|
||||
@@ -592,22 +647,32 @@ impl VersionType {
|
||||
direction: ConversionDirection,
|
||||
) -> syn::Result<TokenStream> {
|
||||
let fields: syn::Result<Vec<TokenStream>> = zip(arg_names, fields)
|
||||
.map(move |(arg_name, field)| {
|
||||
.filter_map(move |(arg_name, field)| {
|
||||
// Ok to unwrap because the field is named so field.ident is Some
|
||||
let field_ident = field.ident.as_ref().unwrap();
|
||||
|
||||
let rhs = if self.is_transparent() {
|
||||
self.generate_constructor_transparent_rhs(quote! {#arg_name}, direction)?
|
||||
Some(self.generate_constructor_transparent_rhs(quote! {#arg_name}, direction))
|
||||
} else {
|
||||
let skipped = match is_skipped(&field.attrs) {
|
||||
Ok(skipped) => skipped,
|
||||
Err(e) => return Some(Err(e)),
|
||||
};
|
||||
self.generate_constructor_field_rhs(
|
||||
&field.ty,
|
||||
quote! {#arg_name},
|
||||
true,
|
||||
skipped,
|
||||
direction,
|
||||
)?
|
||||
};
|
||||
Ok(quote! {
|
||||
#field_ident: #rhs
|
||||
})
|
||||
)
|
||||
.transpose()
|
||||
}?;
|
||||
|
||||
Some(rhs.map(|rhs| {
|
||||
quote! {
|
||||
#field_ident: #rhs
|
||||
}
|
||||
}))
|
||||
})
|
||||
.collect();
|
||||
let fields = fields?;
|
||||
@@ -629,8 +694,9 @@ impl VersionType {
|
||||
let fields: syn::Result<Vec<TokenStream>> = fields
|
||||
.into_iter()
|
||||
.enumerate()
|
||||
.map(move |(idx, field)| {
|
||||
.filter_map(move |(idx, field)| {
|
||||
self.generate_constructor_field_unnamed(arg_name, field, idx, direction)
|
||||
.transpose()
|
||||
})
|
||||
.collect();
|
||||
let fields = fields?;
|
||||
@@ -647,7 +713,7 @@ impl VersionType {
|
||||
field: &Field,
|
||||
idx: usize,
|
||||
direction: ConversionDirection,
|
||||
) -> syn::Result<TokenStream> {
|
||||
) -> syn::Result<Option<TokenStream>> {
|
||||
let arg_ident = Ident::new(arg_name, Span::call_site());
|
||||
let idx = Literal::usize_unsuffixed(idx);
|
||||
let ty = &field.ty;
|
||||
@@ -655,8 +721,15 @@ impl VersionType {
|
||||
|
||||
if self.is_transparent {
|
||||
self.generate_constructor_transparent_rhs(param, direction)
|
||||
.map(Some)
|
||||
} else {
|
||||
self.generate_constructor_field_rhs(ty, param, false, direction)
|
||||
self.generate_constructor_field_rhs(
|
||||
ty,
|
||||
param,
|
||||
false,
|
||||
is_skipped(&field.attrs)?,
|
||||
direction,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -672,16 +745,22 @@ impl VersionType {
|
||||
direction: ConversionDirection,
|
||||
) -> syn::Result<TokenStream> {
|
||||
let fields: syn::Result<Vec<TokenStream>> = zip(arg_names, fields)
|
||||
.map(move |(arg_name, field)| {
|
||||
if self.is_transparent {
|
||||
self.generate_constructor_transparent_rhs(quote! {#arg_name}, direction)
|
||||
.filter_map(move |(arg_name, field)| {
|
||||
if self.is_transparent() {
|
||||
Some(self.generate_constructor_transparent_rhs(quote! {#arg_name}, direction))
|
||||
} else {
|
||||
let skipped = match is_skipped(&field.attrs) {
|
||||
Ok(skipped) => skipped,
|
||||
Err(e) => return Some(Err(e)),
|
||||
};
|
||||
self.generate_constructor_field_rhs(
|
||||
&field.ty,
|
||||
quote! {#arg_name},
|
||||
true,
|
||||
skipped,
|
||||
direction,
|
||||
)
|
||||
.transpose()
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
@@ -692,6 +771,35 @@ impl VersionType {
|
||||
})
|
||||
}
|
||||
|
||||
/// Generates the constructor for a variant of an enum with the `skip` attribute.
|
||||
///
|
||||
/// This constructor is never supposed to be called, but we need to handle it anyways.
|
||||
///
|
||||
/// During a call to "versionize", the conversion will simply create a unit variant that will
|
||||
/// trigger an error at the "serialize" step. During a call to "unversionize", the conversion
|
||||
/// will raise an error.
|
||||
fn generate_constructor_skipped_enum_variants(
|
||||
&self,
|
||||
src_type: &Ident,
|
||||
variant_ident: &Ident,
|
||||
direction: ConversionDirection,
|
||||
) -> TokenStream {
|
||||
match direction {
|
||||
ConversionDirection::OrigToAssociated => quote! {
|
||||
#src_type::#variant_ident { .. } =>
|
||||
Self::#variant_ident
|
||||
},
|
||||
ConversionDirection::AssociatedToOrig => {
|
||||
let error: Path = parse_const_str(UNVERSIONIZE_ERROR_NAME);
|
||||
let variant_name = format!("{}::{}", self.orig_type.ident, variant_ident);
|
||||
|
||||
quote! {
|
||||
#src_type::#variant_ident => return Err(#error::skipped_variant(#variant_name))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Generates the rhs part of a field constructor.
|
||||
/// For example, in `Self { count: value.count.versionize() }`, this is
|
||||
/// `value.count.versionize()`.
|
||||
@@ -699,15 +807,21 @@ impl VersionType {
|
||||
&self,
|
||||
ty: &Type,
|
||||
field_param: TokenStream,
|
||||
is_ref: bool, // True if the param is already a reference
|
||||
is_ref: bool, // True if the param is already a reference
|
||||
is_skipped: bool, // True if the field has the `skipped` attribute
|
||||
direction: ConversionDirection,
|
||||
) -> syn::Result<TokenStream> {
|
||||
) -> syn::Result<Option<TokenStream>> {
|
||||
let versionize_trait: Path = parse_const_str(VERSIONIZE_TRAIT_NAME);
|
||||
let versionize_owned_trait: Path = parse_const_str(VERSIONIZE_OWNED_TRAIT_NAME);
|
||||
let unversionize_trait: Path = parse_const_str(UNVERSIONIZE_TRAIT_NAME);
|
||||
let default_trait: Path = parse_const_str(DEFAULT_TRAIT_NAME);
|
||||
|
||||
let field_constructor = match direction {
|
||||
ConversionDirection::OrigToAssociated => {
|
||||
if is_skipped {
|
||||
// Skipped fields does not exist in the associated type so we return None
|
||||
return Ok(None);
|
||||
}
|
||||
match self.kind {
|
||||
AssociatedTypeKind::Ref(_) => {
|
||||
let param = if is_ref {
|
||||
@@ -727,12 +841,21 @@ impl VersionType {
|
||||
ConversionDirection::AssociatedToOrig => match self.kind {
|
||||
AssociatedTypeKind::Ref(_) =>
|
||||
panic!("No conversion should be generated between associated ref type to original type"),
|
||||
AssociatedTypeKind::Owned => quote! {
|
||||
<#ty as #unversionize_trait>::unversionize(#field_param)?
|
||||
AssociatedTypeKind::Owned => {
|
||||
if is_skipped {
|
||||
// If the field is skipped, we try to construct it from a Default impl (this is what serde does)
|
||||
quote! {
|
||||
<#ty as #default_trait>::default()
|
||||
}
|
||||
} else {
|
||||
quote! {
|
||||
<#ty as #unversionize_trait>::unversionize(#field_param)?
|
||||
}
|
||||
}
|
||||
},
|
||||
},
|
||||
};
|
||||
Ok(field_constructor)
|
||||
Ok(Some(field_constructor))
|
||||
}
|
||||
|
||||
fn generate_constructor_transparent_rhs(
|
||||
@@ -788,23 +911,64 @@ fn is_unit(input: &DeriveInput) -> bool {
|
||||
|
||||
/// Returns the fields of the input type. This is independent of the kind of type
|
||||
/// (enum, struct, ...)
|
||||
fn derive_type_fields(input: &DeriveInput) -> Punctuated<&Field, Comma> {
|
||||
match &input.data {
|
||||
Data::Struct(stru) => match &stru.fields {
|
||||
Fields::Named(fields) => Punctuated::from_iter(fields.named.iter()),
|
||||
Fields::Unnamed(fields) => Punctuated::from_iter(fields.unnamed.iter()),
|
||||
Fields::Unit => Punctuated::new(),
|
||||
},
|
||||
Data::Enum(enu) => Punctuated::<&Field, Comma>::from_iter(
|
||||
enu.variants
|
||||
///
|
||||
/// In the case of an enum, the fields for each variants that are not skipped are flattened into a
|
||||
/// single iterator.
|
||||
fn derive_type_fields(input: &DeriveInput) -> syn::Result<Box<dyn Iterator<Item = &Field> + '_>> {
|
||||
Ok(match &input.data {
|
||||
Data::Struct(stru) => Box::new(iter_fields(&stru.fields)),
|
||||
Data::Enum(enu) => {
|
||||
let filtered: Result<Vec<&Variant>, syn::Error> = enu
|
||||
.variants
|
||||
.iter()
|
||||
.filter_map(|variant| match &variant.fields {
|
||||
Fields::Named(fields) => Some(fields.named.iter()),
|
||||
Fields::Unnamed(fields) => Some(fields.unnamed.iter()),
|
||||
Fields::Unit => None,
|
||||
})
|
||||
.flatten(),
|
||||
),
|
||||
Data::Union(uni) => Punctuated::from_iter(uni.fields.named.iter()),
|
||||
.filter_map(filter_skipped_variant)
|
||||
.collect();
|
||||
|
||||
Box::new(
|
||||
filtered?
|
||||
.into_iter()
|
||||
.flat_map(|variant| iter_fields(&variant.fields)),
|
||||
)
|
||||
}
|
||||
Data::Union(uni) => Box::new(uni.fields.named.iter()),
|
||||
})
|
||||
}
|
||||
|
||||
/// Returns an iterator over the `Field`s in a `Fields` regardless of the fields type (named,
|
||||
/// unnamed or unit).
|
||||
fn iter_fields(fields: &Fields) -> Box<dyn Iterator<Item = &Field> + '_> {
|
||||
match fields {
|
||||
Fields::Named(fields) => Box::new(fields.named.iter()),
|
||||
Fields::Unnamed(fields) => Box::new(fields.unnamed.iter()),
|
||||
Fields::Unit => Box::new(std::iter::empty()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Can be used inside a field iterator to remove the fields with a `#[versionize(skip)]` attribute
|
||||
fn filter_skipped_field(field: &Field) -> Option<syn::Result<&Field>> {
|
||||
match is_skipped(&field.attrs) {
|
||||
Ok(true) => None,
|
||||
Ok(false) => Some(Ok(field)),
|
||||
Err(e) => Some(Err(e)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Can be used inside a field iterator to only keep the fields with a `#[versionize(skip)]`
|
||||
/// attribute
|
||||
fn keep_skipped_field(field: &Field) -> Option<syn::Result<&Field>> {
|
||||
match is_skipped(&field.attrs) {
|
||||
Ok(true) => Some(Ok(field)),
|
||||
Ok(false) => None,
|
||||
Err(e) => Some(Err(e)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Can be used inside a variant iterator to remove the variants with a `#[versionize(skip)]`
|
||||
/// attribute
|
||||
fn filter_skipped_variant(variant: &Variant) -> Option<syn::Result<&Variant>> {
|
||||
match is_skipped(&variant.attrs) {
|
||||
Ok(true) => None,
|
||||
Ok(false) => Some(Ok(variant)),
|
||||
Err(e) => Some(Err(e)),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,10 +5,13 @@ use proc_macro2::Span;
|
||||
use quote::ToTokens;
|
||||
use syn::punctuated::Punctuated;
|
||||
use syn::spanned::Spanned;
|
||||
use syn::{Attribute, Expr, Lit, Meta, Path, Token};
|
||||
use syn::{parse_quote, Attribute, Expr, Lit, Meta, Path, Token};
|
||||
|
||||
/// Name of the attribute used to give arguments to the `Versionize` macro
|
||||
const VERSIONIZE_ATTR_NAME: &str = "versionize";
|
||||
pub(crate) const VERSIONIZE_ATTR_NAME: &str = "versionize";
|
||||
|
||||
/// Name of the attribute used to give arguments to serde macros
|
||||
pub(crate) const SERDE_ATTR_NAME: &str = "serde";
|
||||
|
||||
/// Transparent mode can also be activated using `#[repr(transparent)]`
|
||||
pub(crate) const REPR_ATTR_NAME: &str = "repr";
|
||||
@@ -297,11 +300,12 @@ fn parse_path_ignore_quotes(value: &Expr) -> syn::Result<Path> {
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if the target type has the `#[repr(transparent)]` attribute in its attributes list
|
||||
/// Check if the target type has the `#[repr(transparent)]` or `#[serde(transparent)]` attribute in
|
||||
/// its attributes list
|
||||
pub(crate) fn is_transparent(attributes: &[Attribute]) -> syn::Result<bool> {
|
||||
if let Some(attr) = attributes
|
||||
.iter()
|
||||
.find(|attr| attr.path().is_ident(REPR_ATTR_NAME))
|
||||
.find(|attr| attr.path().is_ident(REPR_ATTR_NAME) || attr.path().is_ident(SERDE_ATTR_NAME))
|
||||
{
|
||||
let nested = attr.parse_args_with(Punctuated::<Meta, Token![,]>::parse_terminated)?;
|
||||
|
||||
@@ -316,3 +320,48 @@ pub(crate) fn is_transparent(attributes: &[Attribute]) -> syn::Result<bool> {
|
||||
|
||||
Ok(false)
|
||||
}
|
||||
|
||||
/// Check if a field has the `#[serde(skip)]` or `#[versionize(skip)]` attribute in
|
||||
/// its attributes list
|
||||
pub(crate) fn is_skipped(attributes: &[Attribute]) -> syn::Result<bool> {
|
||||
if let Some(attr) = attributes.iter().find(|attr| {
|
||||
attr.path().is_ident(VERSIONIZE_ATTR_NAME) || attr.path().is_ident(SERDE_ATTR_NAME)
|
||||
}) {
|
||||
let nested = attr.parse_args_with(Punctuated::<Meta, Token![,]>::parse_terminated)?;
|
||||
|
||||
for meta in nested.iter() {
|
||||
if let Meta::Path(path) = meta {
|
||||
if path.is_ident("skip") {
|
||||
return Ok(true);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(false)
|
||||
}
|
||||
|
||||
/// Replace `#[versionize(skip)]` with `#[serde(skip)]` in an attributes list
|
||||
pub(crate) fn replace_versionize_skip_with_serde(
|
||||
attributes: &[Attribute],
|
||||
) -> syn::Result<Vec<Attribute>> {
|
||||
attributes
|
||||
.iter()
|
||||
.cloned()
|
||||
.map(|attr| {
|
||||
if attr.path().is_ident(VERSIONIZE_ATTR_NAME) {
|
||||
let nested =
|
||||
attr.parse_args_with(Punctuated::<Meta, Token![,]>::parse_terminated)?;
|
||||
|
||||
for meta in nested.iter() {
|
||||
if let Meta::Path(path) = meta {
|
||||
if path.is_ident("skip") {
|
||||
return Ok(parse_quote! { #[serde(skip)] });
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(attr)
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
@@ -69,3 +69,7 @@ test = true
|
||||
[[example]]
|
||||
name = "associated_bounds"
|
||||
test = true
|
||||
|
||||
[[example]]
|
||||
name = "skip"
|
||||
test = true
|
||||
|
||||
96
utils/tfhe-versionable/examples/skip.rs
Normal file
96
utils/tfhe-versionable/examples/skip.rs
Normal file
@@ -0,0 +1,96 @@
|
||||
//! Example of a struct that include a field that is not versionable and is skipped during
|
||||
//! versioning.
|
||||
//!
|
||||
//! This is similar to the `#[serde(skip)]` attribute of Serde.
|
||||
//! With this attribute, the field is not included in the definition of the associated Version type.
|
||||
//! During unversioning, the field is instantiated using a `Default` impl.
|
||||
|
||||
use std::convert::Infallible;
|
||||
use std::io::Cursor;
|
||||
|
||||
use tfhe_versionable::{Unversionize, Upgrade, Version, Versionize, VersionsDispatch};
|
||||
|
||||
// This type is not versionable/serializable and should not be stored.
|
||||
// It should however at least implement Default to be instantiable when the data is loaded.
|
||||
#[derive(Default)]
|
||||
struct NotVersionable(u64);
|
||||
|
||||
/// The previous version of our application
|
||||
mod v0 {
|
||||
use super::NotVersionable;
|
||||
use tfhe_versionable::{Versionize, VersionsDispatch};
|
||||
|
||||
#[derive(Versionize)]
|
||||
#[versionize(MyStructVersions)]
|
||||
pub(super) struct MyStruct {
|
||||
pub(super) val: u32,
|
||||
// This attribute is used to skip the versioning of the field
|
||||
// Also work with `#[serde(skip)]` if the field derives Serialize
|
||||
#[versionize(skip)]
|
||||
#[allow(dead_code)]
|
||||
pub(super) to_skip: NotVersionable,
|
||||
}
|
||||
|
||||
#[derive(VersionsDispatch)]
|
||||
#[allow(unused)]
|
||||
pub(super) enum MyStructVersions {
|
||||
V0(MyStruct),
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Version)]
|
||||
struct MyStructV0 {
|
||||
val: u32,
|
||||
#[versionize(skip)]
|
||||
to_skip: NotVersionable,
|
||||
}
|
||||
|
||||
#[derive(Versionize)]
|
||||
#[versionize(MyStructVersions)]
|
||||
struct MyStruct {
|
||||
val: u64,
|
||||
#[versionize(skip)]
|
||||
to_skip: NotVersionable,
|
||||
}
|
||||
|
||||
impl Upgrade<MyStruct> for MyStructV0 {
|
||||
type Error = Infallible;
|
||||
|
||||
fn upgrade(self) -> Result<MyStruct, Self::Error> {
|
||||
let val = self.val as u64;
|
||||
|
||||
Ok(MyStruct {
|
||||
val,
|
||||
to_skip: self.to_skip,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(VersionsDispatch)]
|
||||
#[allow(unused)]
|
||||
enum MyStructVersions {
|
||||
V0(MyStructV0),
|
||||
V1(MyStruct),
|
||||
}
|
||||
|
||||
fn main() {
|
||||
let val = 64;
|
||||
let stru_v0 = v0::MyStruct {
|
||||
val,
|
||||
to_skip: NotVersionable(42), // The value will be lost during serialization
|
||||
};
|
||||
|
||||
let mut ser = Vec::new();
|
||||
ciborium::ser::into_writer(&stru_v0.versionize(), &mut ser).unwrap();
|
||||
|
||||
let unvers =
|
||||
MyStruct::unversionize(ciborium::de::from_reader(&mut Cursor::new(&ser)).unwrap()).unwrap();
|
||||
|
||||
assert_eq!(unvers.val, val as u64);
|
||||
assert_eq!(unvers.to_skip.0, Default::default());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test() {
|
||||
main()
|
||||
}
|
||||
@@ -11,7 +11,7 @@ use tfhe_versionable::{Unversionize, Upgrade, Version, Versionize, VersionsDispa
|
||||
//
|
||||
// struct MyStructWrapper<T> { inner: MyStruct<T> };
|
||||
#[derive(Versionize)]
|
||||
#[versionize(transparent)] // Also works with `#[repr(transparent)]`
|
||||
#[versionize(transparent)] // Also works with `#[repr(transparent)]` or `#[serde(transparent)]`
|
||||
struct MyStructWrapper<T>(MyStruct<T>);
|
||||
|
||||
// The inner struct that is versioned.
|
||||
|
||||
@@ -94,6 +94,9 @@ pub enum UnversionizeError {
|
||||
|
||||
/// A deprecated version has been found
|
||||
DeprecatedVersion(DeprecatedVersionError),
|
||||
|
||||
/// User tried to unversionize an enum variant with the `#[versionize(skip)]` attribute
|
||||
SkippedVariant { variant_name: String },
|
||||
}
|
||||
|
||||
impl Display for UnversionizeError {
|
||||
@@ -120,6 +123,9 @@ impl Display for UnversionizeError {
|
||||
)
|
||||
}
|
||||
Self::DeprecatedVersion(deprecation_error) => deprecation_error.fmt(f),
|
||||
Self::SkippedVariant { variant_name } => write!(f,
|
||||
"Enum variant {variant_name} is marked with the `skip` attribute and cannot be unversioned"
|
||||
),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -131,6 +137,7 @@ impl Error for UnversionizeError {
|
||||
UnversionizeError::Conversion { source, .. } => Some(source.as_ref()),
|
||||
UnversionizeError::ArrayLength { .. } => None,
|
||||
UnversionizeError::DeprecatedVersion(_) => None,
|
||||
UnversionizeError::SkippedVariant { .. } => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -154,6 +161,12 @@ impl UnversionizeError {
|
||||
source: Box::new(source),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn skipped_variant(variant_name: &str) -> Self {
|
||||
Self::SkippedVariant {
|
||||
variant_name: variant_name.to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Infallible> for UnversionizeError {
|
||||
|
||||
88
utils/tfhe-versionable/tests/skip_enum.rs
Normal file
88
utils/tfhe-versionable/tests/skip_enum.rs
Normal file
@@ -0,0 +1,88 @@
|
||||
//! Test the skip attribute in an enum. This attribute in a struct is already tested in
|
||||
//! `examples/skip.rs`
|
||||
|
||||
use std::convert::Infallible;
|
||||
use std::io::Cursor;
|
||||
|
||||
use tfhe_versionable::{Unversionize, Upgrade, Version, Versionize, VersionsDispatch};
|
||||
|
||||
#[allow(dead_code)]
|
||||
struct NotVersionable(u64);
|
||||
|
||||
mod v0 {
|
||||
use super::NotVersionable;
|
||||
use tfhe_versionable::{Versionize, VersionsDispatch};
|
||||
|
||||
#[derive(Versionize)]
|
||||
#[versionize(MyEnumVersions)]
|
||||
pub(super) enum MyEnum {
|
||||
Var0(u32),
|
||||
#[versionize(skip)]
|
||||
#[allow(dead_code)]
|
||||
Var1(NotVersionable),
|
||||
}
|
||||
|
||||
#[derive(VersionsDispatch)]
|
||||
#[allow(unused)]
|
||||
pub(super) enum MyEnumVersions {
|
||||
V0(MyEnum),
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Version)]
|
||||
enum MyEnumV0 {
|
||||
Var0(u32),
|
||||
#[versionize(skip)]
|
||||
#[allow(dead_code)]
|
||||
Var1(NotVersionable),
|
||||
}
|
||||
|
||||
#[derive(Versionize)]
|
||||
#[versionize(MyEnumVersions)]
|
||||
enum MyEnum {
|
||||
Var0(u64),
|
||||
#[versionize(skip)]
|
||||
#[allow(dead_code)]
|
||||
Var1(NotVersionable),
|
||||
}
|
||||
|
||||
impl Upgrade<MyEnum> for MyEnumV0 {
|
||||
type Error = Infallible;
|
||||
|
||||
fn upgrade(self) -> Result<MyEnum, Self::Error> {
|
||||
match self {
|
||||
MyEnumV0::Var0(val) => Ok(MyEnum::Var0(val as u64)),
|
||||
MyEnumV0::Var1(val) => Ok(MyEnum::Var1(val)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(VersionsDispatch)]
|
||||
#[allow(unused)]
|
||||
enum MyEnumVersions {
|
||||
V0(MyEnumV0),
|
||||
V1(MyEnum),
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test() {
|
||||
// Test the "normal" variant
|
||||
let val = 64;
|
||||
let enu_v0 = v0::MyEnum::Var0(val);
|
||||
|
||||
let mut ser = Vec::new();
|
||||
ciborium::ser::into_writer(&enu_v0.versionize(), &mut ser).unwrap();
|
||||
|
||||
let unvers =
|
||||
MyEnum::unversionize(ciborium::de::from_reader(&mut Cursor::new(&ser)).unwrap()).unwrap();
|
||||
|
||||
assert!(matches!(unvers, MyEnum::Var0(unvers_val) if unvers_val == val as u64));
|
||||
|
||||
// Test the skipped variant
|
||||
let val = 64;
|
||||
let enu_v0 = v0::MyEnum::Var1(NotVersionable(val));
|
||||
|
||||
let mut ser = Vec::new();
|
||||
// Serialization of the skipped variant must fail
|
||||
assert!(ciborium::ser::into_writer(&enu_v0.versionize(), &mut ser).is_err());
|
||||
}
|
||||
Reference in New Issue
Block a user