feat(versionable): "Version" macro now handles transparent attribute

This commit is contained in:
Nicolas Sarlin
2024-12-04 10:27:13 +01:00
committed by Nicolas Sarlin
parent e9c901b3a9
commit 3dcb982a0b
6 changed files with 394 additions and 57 deletions

View File

@@ -94,9 +94,9 @@ pub(crate) enum AssociatedTypeKind {
/// [`VersionType`]: crate::dispatch_type::VersionType /// [`VersionType`]: crate::dispatch_type::VersionType
pub(crate) trait AssociatedType: Sized { pub(crate) trait AssociatedType: Sized {
/// Bounds that will be added on the fields of the ref type definition /// Bounds that will be added on the fields of the ref type definition
const REF_BOUNDS: &'static [&'static str]; fn ref_bounds(&self) -> &'static [&'static str];
/// Bounds that will be added on the fields of the owned type definition /// Bounds that will be added on the fields of the owned type definition
const OWNED_BOUNDS: &'static [&'static str]; fn owned_bounds(&self) -> &'static [&'static str];
/// This will create the alternative of the type that holds a reference to the underlying data /// This will create the alternative of the type that holds a reference to the underlying data
fn new_ref(orig_type: &DeriveInput) -> syn::Result<Self>; fn new_ref(orig_type: &DeriveInput) -> syn::Result<Self>;
@@ -109,6 +109,10 @@ pub(crate) trait AssociatedType: Sized {
/// Returns the kind of associated type, a ref or an owned type /// Returns the kind of associated type, a ref or an owned type
fn kind(&self) -> &AssociatedTypeKind; fn kind(&self) -> &AssociatedTypeKind;
/// Returns true if the type is transparent and trait implementation is actually deferred to the
/// inner type
fn is_transparent(&self) -> bool;
/// Returns the generics found in the original type definition /// Returns the generics found in the original type definition
fn orig_type_generics(&self) -> &Generics; fn orig_type_generics(&self) -> &Generics;
@@ -119,9 +123,9 @@ pub(crate) trait AssociatedType: Sized {
if let Some(lifetime) = opt_lifetime { if let Some(lifetime) = opt_lifetime {
add_lifetime_param(&mut generics, lifetime); add_lifetime_param(&mut generics, lifetime);
} }
add_trait_where_clause(&mut generics, self.inner_types()?, Self::REF_BOUNDS)?; add_trait_where_clause(&mut generics, self.inner_types()?, self.ref_bounds())?;
} else { } else {
add_trait_where_clause(&mut generics, self.inner_types()?, Self::OWNED_BOUNDS)?; add_trait_where_clause(&mut generics, self.inner_types()?, self.owned_bounds())?;
} }
Ok(generics) Ok(generics)
@@ -254,14 +258,27 @@ impl<T: AssociatedType> AssociatingTrait<T> {
) )
]}; ]};
let owned_attributes = if self.owned_type.is_transparent() {
quote! {
#[derive(#serialize_trait, #deserialize_trait)]
#[repr(transparent)]
#[serde(bound = "")]
#ignored_lints
}
} else {
quote! {
#[derive(#serialize_trait, #deserialize_trait)]
#[serde(bound = "")]
#ignored_lints
}
};
// Creates the type declaration. These types are the output of the versioning process, so // Creates the type declaration. These types are the output of the versioning process, so
// they should be serializable. Serde might try to add automatic bounds on the type generics // they should be serializable. Serde might try to add automatic bounds on the type generics
// even if we don't need them, so we use `#[serde(bound = "")]` to disable this. The bounds // even if we don't need them, so we use `#[serde(bound = "")]` to disable this. The bounds
// on the generated types should be sufficient. // on the generated types should be sufficient.
let owned_tokens = quote! { let owned_tokens = quote! {
#[derive(#serialize_trait, #deserialize_trait)] #owned_attributes
#[serde(bound = "")]
#ignored_lints
#owned_decla #owned_decla
#(#owned_conversion)* #(#owned_conversion)*
@@ -271,10 +288,23 @@ impl<T: AssociatedType> AssociatingTrait<T> {
let ref_conversion = self.ref_type.generate_conversion()?; let ref_conversion = self.ref_type.generate_conversion()?;
let ref_attributes = if self.ref_type.is_transparent() {
quote! {
#[derive(#serialize_trait)]
#[repr(transparent)]
#[serde(bound = "")]
#ignored_lints
}
} else {
quote! {
#[derive(#serialize_trait)]
#[serde(bound = "")]
#ignored_lints
}
};
let ref_tokens = quote! { let ref_tokens = quote! {
#[derive(#serialize_trait)] #ref_attributes
#[serde(bound = "")]
#ignored_lints
#ref_decla #ref_decla
#(#ref_conversion)* #(#ref_conversion)*

View File

@@ -47,9 +47,13 @@ fn derive_input_to_enum(input: &DeriveInput) -> syn::Result<ItemEnum> {
} }
impl AssociatedType for DispatchType { impl AssociatedType for DispatchType {
const REF_BOUNDS: &'static [&'static str] = &[VERSION_TRAIT_NAME]; fn ref_bounds(&self) -> &'static [&'static str] {
&[VERSION_TRAIT_NAME]
}
const OWNED_BOUNDS: &'static [&'static str] = &[VERSION_TRAIT_NAME]; fn owned_bounds(&self) -> &'static [&'static str] {
&[VERSION_TRAIT_NAME]
}
fn new_ref(orig_type: &DeriveInput) -> syn::Result<Self> { fn new_ref(orig_type: &DeriveInput) -> syn::Result<Self> {
for lt in orig_type.generics.lifetimes() { for lt in orig_type.generics.lifetimes() {
@@ -109,6 +113,10 @@ impl AssociatedType for DispatchType {
&self.kind &self.kind
} }
fn is_transparent(&self) -> bool {
false
}
fn orig_type_generics(&self) -> &Generics { fn orig_type_generics(&self) -> &Generics {
&self.orig_type.generics &self.orig_type.generics
} }

View File

@@ -15,10 +15,12 @@ use crate::associated::{
generate_from_trait_impl, generate_try_from_trait_impl, AssociatedType, AssociatedTypeKind, generate_from_trait_impl, generate_try_from_trait_impl, AssociatedType, AssociatedTypeKind,
ConversionDirection, ConversionDirection,
}; };
use crate::versionize_attribute::is_transparent;
use crate::{ use crate::{
add_trait_where_clause, parse_const_str, parse_trait_bound, punctuated_from_iter_result, add_trait_where_clause, parse_const_str, parse_trait_bound, punctuated_from_iter_result,
LIFETIME_NAME, UNVERSIONIZE_ERROR_NAME, UNVERSIONIZE_TRAIT_NAME, VERSIONIZE_OWNED_TRAIT_NAME, INTO_TRAIT_NAME, LIFETIME_NAME, TRY_INTO_TRAIT_NAME, UNVERSIONIZE_ERROR_NAME,
VERSIONIZE_TRAIT_NAME, UNVERSIONIZE_TRAIT_NAME, VERSIONIZE_OWNED_TRAIT_NAME, VERSIONIZE_TRAIT_NAME,
VERSION_TRAIT_NAME,
}; };
/// The types generated for a specific version of a given exposed type. These types are identical to /// The types generated for a specific version of a given exposed type. These types are identical to
@@ -27,13 +29,29 @@ use crate::{
pub(crate) struct VersionType { pub(crate) struct VersionType {
orig_type: DeriveInput, orig_type: DeriveInput,
kind: AssociatedTypeKind, kind: AssociatedTypeKind,
is_transparent: bool,
} }
impl AssociatedType for VersionType { impl AssociatedType for VersionType {
const REF_BOUNDS: &'static [&'static str] = &[VERSIONIZE_TRAIT_NAME]; fn ref_bounds(&self) -> &'static [&'static str] {
const OWNED_BOUNDS: &'static [&'static str] = &[VERSIONIZE_OWNED_TRAIT_NAME]; if self.is_transparent {
&[VERSION_TRAIT_NAME]
} else {
&[VERSIONIZE_TRAIT_NAME]
}
}
fn owned_bounds(&self) -> &'static [&'static str] {
if self.is_transparent {
&[VERSION_TRAIT_NAME]
} else {
&[VERSIONIZE_OWNED_TRAIT_NAME]
}
}
fn new_ref(orig_type: &DeriveInput) -> syn::Result<VersionType> { fn new_ref(orig_type: &DeriveInput) -> syn::Result<VersionType> {
let is_transparent = is_transparent(&orig_type.attrs)?;
let lifetime = if is_unit(orig_type) { let lifetime = if is_unit(orig_type) {
None None
} else { } else {
@@ -54,13 +72,17 @@ impl AssociatedType for VersionType {
Ok(Self { Ok(Self {
orig_type: orig_type.clone(), orig_type: orig_type.clone(),
kind: AssociatedTypeKind::Ref(lifetime), kind: AssociatedTypeKind::Ref(lifetime),
is_transparent,
}) })
} }
fn new_owned(orig_type: &DeriveInput) -> syn::Result<Self> { fn new_owned(orig_type: &DeriveInput) -> syn::Result<Self> {
let is_transparent = is_transparent(&orig_type.attrs)?;
Ok(Self { Ok(Self {
orig_type: orig_type.clone(), orig_type: orig_type.clone(),
kind: AssociatedTypeKind::Owned, kind: AssociatedTypeKind::Owned,
is_transparent,
}) })
} }
@@ -191,6 +213,10 @@ impl AssociatedType for VersionType {
&self.kind &self.kind
} }
fn is_transparent(&self) -> bool {
self.is_transparent
}
fn orig_type_generics(&self) -> &Generics { fn orig_type_generics(&self) -> &Generics {
&self.orig_type.generics &self.orig_type.generics
} }
@@ -198,13 +224,15 @@ impl AssociatedType for VersionType {
fn conversion_generics(&self, direction: ConversionDirection) -> syn::Result<Generics> { fn conversion_generics(&self, direction: ConversionDirection) -> syn::Result<Generics> {
let mut generics = self.type_generics()?; let mut generics = self.type_generics()?;
if let ConversionDirection::AssociatedToOrig = direction { if !self.is_transparent {
if let AssociatedTypeKind::Owned = &self.kind { if let ConversionDirection::AssociatedToOrig = direction {
add_trait_where_clause( if let AssociatedTypeKind::Owned = &self.kind {
&mut generics, add_trait_where_clause(
self.inner_types()?, &mut generics,
&[UNVERSIONIZE_TRAIT_NAME], self.inner_types()?,
)?; &[UNVERSIONIZE_TRAIT_NAME],
)?;
}
} }
} }
@@ -323,25 +351,46 @@ impl VersionType {
fields_iter: I, fields_iter: I,
) -> impl IntoIterator<Item = syn::Result<Field>> + 'a { ) -> impl IntoIterator<Item = syn::Result<Field>> + 'a {
let kind = self.kind.clone(); let kind = self.kind.clone();
let is_transparent = self.is_transparent;
fields_iter.into_iter().map(move |field| { fields_iter.into_iter().map(move |field| {
let unver_ty = field.ty.clone(); let unver_ty = field.ty.clone();
let versionize_trait = parse_trait_bound(VERSIONIZE_TRAIT_NAME)?; if is_transparent {
let versionize_owned_trait = parse_trait_bound(VERSIONIZE_OWNED_TRAIT_NAME)?; // If the type is transparent, we reuse the "Version" impl of the inner type
let version_trait = parse_trait_bound(VERSION_TRAIT_NAME)?;
let ty: Type = match &kind { let ty: Type = match &kind {
AssociatedTypeKind::Ref(lifetime) => parse_quote! { AssociatedTypeKind::Ref(lifetime) => parse_quote! {
<#unver_ty as #versionize_trait>::Versioned<#lifetime> <#unver_ty as #version_trait>::Ref<#lifetime>
}, },
AssociatedTypeKind::Owned => parse_quote! { AssociatedTypeKind::Owned => parse_quote! {
<#unver_ty as #versionize_owned_trait>::VersionedOwned <#unver_ty as #version_trait>::Owned
}, },
}; };
Ok(Field { Ok(Field {
ty, ty,
..field.clone() ..field.clone()
}) })
} else {
let versionize_trait = parse_trait_bound(VERSIONIZE_TRAIT_NAME)?;
let versionize_owned_trait = parse_trait_bound(VERSIONIZE_OWNED_TRAIT_NAME)?;
let ty: Type = match &kind {
AssociatedTypeKind::Ref(lifetime) => parse_quote! {
<#unver_ty as #versionize_trait>::Versioned<#lifetime>
},
AssociatedTypeKind::Owned => parse_quote! {
<#unver_ty as #versionize_owned_trait>::VersionedOwned
},
};
Ok(Field {
ty,
..field.clone()
})
}
}) })
} }
@@ -520,7 +569,11 @@ impl VersionType {
let ty = &field.ty; let ty = &field.ty;
let param = quote! { #arg_ident.#field_ident }; let param = quote! { #arg_ident.#field_ident };
let rhs = self.generate_constructor_field_rhs(ty, param, false, direction)?; let rhs = if self.is_transparent() {
self.generate_constructor_transparent_rhs(param, direction)?
} else {
self.generate_constructor_field_rhs(ty, param, false, direction)?
};
Ok(quote! { Ok(quote! {
#field_ident: #rhs #field_ident: #rhs
@@ -542,12 +595,16 @@ impl VersionType {
.map(move |(arg_name, field)| { .map(move |(arg_name, field)| {
// Ok to unwrap because the field is named so field.ident is Some // Ok to unwrap because the field is named so field.ident is Some
let field_ident = field.ident.as_ref().unwrap(); let field_ident = field.ident.as_ref().unwrap();
let rhs = self.generate_constructor_field_rhs( let rhs = if self.is_transparent() {
&field.ty, self.generate_constructor_transparent_rhs(quote! {#arg_name}, direction)?
quote! {#arg_name}, } else {
true, self.generate_constructor_field_rhs(
direction, &field.ty,
)?; quote! {#arg_name},
true,
direction,
)?
};
Ok(quote! { Ok(quote! {
#field_ident: #rhs #field_ident: #rhs
}) })
@@ -596,7 +653,11 @@ impl VersionType {
let ty = &field.ty; let ty = &field.ty;
let param = quote! { #arg_ident.#idx }; let param = quote! { #arg_ident.#idx };
self.generate_constructor_field_rhs(ty, param, false, direction) if self.is_transparent {
self.generate_constructor_transparent_rhs(param, direction)
} else {
self.generate_constructor_field_rhs(ty, param, false, direction)
}
} }
/// Generates the constructor for the fields of an unnamed enum variant. /// Generates the constructor for the fields of an unnamed enum variant.
@@ -612,7 +673,16 @@ impl VersionType {
) -> syn::Result<TokenStream> { ) -> syn::Result<TokenStream> {
let fields: syn::Result<Vec<TokenStream>> = zip(arg_names, fields) let fields: syn::Result<Vec<TokenStream>> = zip(arg_names, fields)
.map(move |(arg_name, field)| { .map(move |(arg_name, field)| {
self.generate_constructor_field_rhs(&field.ty, quote! {#arg_name}, true, direction) if self.is_transparent {
self.generate_constructor_transparent_rhs(quote! {#arg_name}, direction)
} else {
self.generate_constructor_field_rhs(
&field.ty,
quote! {#arg_name},
true,
direction,
)
}
}) })
.collect(); .collect();
let fields = fields?; let fields = fields?;
@@ -664,6 +734,41 @@ panic!("No conversion should be generated between associated ref type to origina
}; };
Ok(field_constructor) Ok(field_constructor)
} }
fn generate_constructor_transparent_rhs(
&self,
field_param: TokenStream,
direction: ConversionDirection,
) -> syn::Result<TokenStream> {
let into_trait: Path = parse_const_str(INTO_TRAIT_NAME);
let try_into_trait: Path = parse_const_str(TRY_INTO_TRAIT_NAME);
let field_constructor = match direction {
ConversionDirection::OrigToAssociated => match self.kind {
AssociatedTypeKind::Ref(_) => {
quote! {
#into_trait::into(&#field_param)
}
}
AssociatedTypeKind::Owned => {
quote! {
#into_trait::into(#field_param)
}
}
},
ConversionDirection::AssociatedToOrig => match self.kind {
AssociatedTypeKind::Ref(_) => {
panic!("No conversion should be generated between associated ref type to original type");
}
AssociatedTypeKind::Owned => {
quote! {
#try_into_trait::try_into(#field_param)?
}
}
},
};
Ok(field_constructor)
}
} }
/// Generates a list of argument names. This is used to create a pattern matching of a /// Generates a list of argument names. This is used to create a pattern matching of a

View File

@@ -11,7 +11,7 @@ use syn::{Attribute, Expr, Lit, Meta, Path, Token};
const VERSIONIZE_ATTR_NAME: &str = "versionize"; const VERSIONIZE_ATTR_NAME: &str = "versionize";
/// Transparent mode can also be activated using `#[repr(transparent)]` /// Transparent mode can also be activated using `#[repr(transparent)]`
const REPR_ATTR_NAME: &str = "repr"; pub(crate) const REPR_ATTR_NAME: &str = "repr";
/// Represent the parsed `#[versionize(...)]` attribute /// Represent the parsed `#[versionize(...)]` attribute
pub(crate) enum VersionizeAttribute { pub(crate) enum VersionizeAttribute {
@@ -167,16 +167,14 @@ impl VersionizeAttribute {
.filter(|attr| attr.path().is_ident(VERSIONIZE_ATTR_NAME)) .filter(|attr| attr.path().is_ident(VERSIONIZE_ATTR_NAME))
.collect(); .collect();
let repr_attributes: Vec<&Attribute> = attributes // Check if transparent mode is enabled via repr(transparent). It can also be enabled with
.iter() // the versionize attribute.
.filter(|attr| attr.path().is_ident(REPR_ATTR_NAME)) let type_is_transparent = is_transparent(attributes)?;
.collect();
match version_attributes.as_slice() { match version_attributes.as_slice() {
[] => { [] => {
// transparent mode can also be enabled via `#[repr(transparent)]` if type_is_transparent {
if let Some(attr) = repr_attributes.first() { Ok(Self::Transparent)
Self::parse_from_attribute(attr)
} else { } else {
Err(syn::Error::new( Err(syn::Error::new(
Span::call_site(), Span::call_site(),
@@ -298,3 +296,26 @@ fn parse_path_ignore_quotes(value: &Expr) -> syn::Result<Path> {
)), )),
} }
} }
/// Check if the target type has the `#[repr(transparent)]` attribute in its attributes list
pub(crate) fn is_transparent(attributes: &[Attribute]) -> syn::Result<bool> {
if let Some(attr) = attributes
.iter()
.find(|attr| attr.path().is_ident(REPR_ATTR_NAME))
{
let nested = attr.parse_args_with(Punctuated::<Meta, Token![,]>::parse_terminated)?;
for meta in nested.iter() {
match meta {
Meta::Path(path) => {
if path.is_ident("transparent") {
return Ok(true);
}
}
_ => {}
}
}
}
Ok(false)
}

View File

@@ -48,9 +48,9 @@ enum MyStructVersions<T> {
mod v0 { mod v0 {
use tfhe_versionable::{Versionize, VersionsDispatch}; use tfhe_versionable::{Versionize, VersionsDispatch};
// This struct cannot change as it is not itself versioned. If you ever make a change that // If you ever change the layout of this struct to make it "not transparent", you should create
// should impact the serialized layout of the data, you need to update all the types that use // a MyStructWrapperVersions enum where the first versions are the same than the ones of
// it. // MyStructVersions. See `transparent_then_not.rs` for a full example.
#[derive(Versionize)] #[derive(Versionize)]
#[versionize(transparent)] #[versionize(transparent)]
pub(super) struct MyStructWrapper(pub(super) MyStruct); pub(super) struct MyStructWrapper(pub(super) MyStruct);

View File

@@ -0,0 +1,173 @@
//! This example is similar to the "transparent" one, except that the wrapper type is transparent at
//! a point in time, then converted into its own type that is not transparent.
//!
//! Here we have a type, `MyStructWrapper`, that was a transparent wrapper for `MyStruct` in the v0
//! and v1 of the application. `MyStruct` has been upgraded between v0 and v1. In v2,
//! `MyStructWrapper` was transformed into an enum. Since it was transparent before, it has no
//! history (dispatch enum) before v2.
//!
//! To make this work, we consider that the inner and the wrapper type share the same history up to
//! the version where the transparent attribute has been removed.
use std::convert::Infallible;
use tfhe_versionable::{Unversionize, Upgrade, Version, Versionize, VersionsDispatch};
// This type was transparent before, but it has now been transformed to a full type, for example by
// adding a new kind of metadata.
#[derive(Versionize)]
#[versionize(MyStructWrapperVersions)]
struct MyStructWrapper<T> {
inner: MyStruct<T>,
count: u64,
}
// We need to create a dispatch enum that has the same history as the inner type until the point
// where the wrapper is not transparent anymore.
#[derive(VersionsDispatch)]
#[allow(unused)]
enum MyStructWrapperVersions<T> {
V0(MyStructWrapperV0),
V1(MyStructWrapperV1<T>),
V2(MyStructWrapper<T>),
}
// We copy the upgrade path of the internal struct for the wrapper for the first 2 versions. To do
// that, we recreate the "transparent" `MyStructWrapper` from v0 and v1 and upgrade them by calling
// the upgrade method of the inner type.
#[derive(Version)]
#[repr(transparent)]
struct MyStructWrapperV0(MyStructV0);
impl<T: Default> Upgrade<MyStructWrapperV1<T>> for MyStructWrapperV0 {
type Error = Infallible;
fn upgrade(self) -> Result<MyStructWrapperV1<T>, Self::Error> {
Ok(MyStructWrapperV1(self.0.upgrade()?))
}
}
// Then we define the upgrade from the last transparent version to the first "full" version
#[derive(Version)]
#[repr(transparent)]
struct MyStructWrapperV1<T>(MyStruct<T>);
impl<T> Upgrade<MyStructWrapper<T>> for MyStructWrapperV1<T> {
type Error = Infallible;
fn upgrade(self) -> Result<MyStructWrapper<T>, Self::Error> {
Ok(MyStructWrapper {
inner: self.0,
count: 0,
})
}
}
#[derive(Versionize)]
#[versionize(MyStructVersions)]
struct MyStruct<T> {
attr: T,
builtin: u32,
}
#[derive(Version)]
struct MyStructV0 {
builtin: u32,
}
impl<T: Default> Upgrade<MyStruct<T>> for MyStructV0 {
type Error = Infallible;
fn upgrade(self) -> Result<MyStruct<T>, Self::Error> {
Ok(MyStruct {
attr: T::default(),
builtin: self.builtin,
})
}
}
#[derive(VersionsDispatch)]
#[allow(unused)]
enum MyStructVersions<T> {
V0(MyStructV0),
V1(MyStruct<T>),
}
// v0 of the app defined the type as a transparent wrapper
mod v0 {
use tfhe_versionable::{Versionize, VersionsDispatch};
#[derive(Versionize)]
#[versionize(transparent)]
pub(super) struct MyStructWrapper(pub(super) MyStruct);
#[derive(Versionize)]
#[versionize(MyStructVersions)]
pub(super) struct MyStruct {
pub(super) builtin: u32,
}
#[derive(VersionsDispatch)]
#[allow(unused)]
pub(super) enum MyStructVersions {
V0(MyStruct),
}
}
// In v1, MyStructWrapper is still transparent but MyStruct got an upgrade compared to v0.
mod v1 {
use std::convert::Infallible;
use tfhe_versionable::{Upgrade, Version, Versionize, VersionsDispatch};
#[derive(Versionize)]
#[repr(transparent)]
struct MyStructWrapper<T>(MyStruct<T>);
#[derive(Versionize)]
#[versionize(MyStructVersions)]
struct MyStruct<T> {
attr: T,
builtin: u32,
}
#[derive(Version)]
struct MyStructV0 {
builtin: u32,
}
impl<T: Default> Upgrade<MyStruct<T>> for MyStructV0 {
type Error = Infallible;
fn upgrade(self) -> Result<MyStruct<T>, Self::Error> {
Ok(MyStruct {
attr: T::default(),
builtin: self.builtin,
})
}
}
#[derive(VersionsDispatch)]
#[allow(unused)]
enum MyStructVersions<T> {
V0(MyStructV0),
V1(MyStruct<T>),
}
}
fn main() {
let value = 1234;
let ms = v0::MyStructWrapper(v0::MyStruct { builtin: value });
let serialized = bincode::serialize(&ms.versionize()).unwrap();
let unserialized =
MyStructWrapper::<u64>::unversionize(bincode::deserialize(&serialized).unwrap()).unwrap();
assert_eq!(unserialized.inner.builtin, value)
}
#[test]
fn test() {
main()
}