feat(versionable): "Version" macro now handles transparent attribute

This commit is contained in:
Nicolas Sarlin
2024-12-04 10:27:13 +01:00
committed by Nicolas Sarlin
parent e9c901b3a9
commit 3dcb982a0b
6 changed files with 394 additions and 57 deletions

View File

@@ -94,9 +94,9 @@ pub(crate) enum AssociatedTypeKind {
/// [`VersionType`]: crate::dispatch_type::VersionType
pub(crate) trait AssociatedType: Sized {
/// Bounds that will be added on the fields of the ref type definition
const REF_BOUNDS: &'static [&'static str];
fn ref_bounds(&self) -> &'static [&'static str];
/// Bounds that will be added on the fields of the owned type definition
const OWNED_BOUNDS: &'static [&'static str];
fn owned_bounds(&self) -> &'static [&'static str];
/// This will create the alternative of the type that holds a reference to the underlying data
fn new_ref(orig_type: &DeriveInput) -> syn::Result<Self>;
@@ -109,6 +109,10 @@ pub(crate) trait AssociatedType: Sized {
/// Returns the kind of associated type, a ref or an owned type
fn kind(&self) -> &AssociatedTypeKind;
/// Returns true if the type is transparent and trait implementation is actually deferred to the
/// inner type
fn is_transparent(&self) -> bool;
/// Returns the generics found in the original type definition
fn orig_type_generics(&self) -> &Generics;
@@ -119,9 +123,9 @@ pub(crate) trait AssociatedType: Sized {
if let Some(lifetime) = opt_lifetime {
add_lifetime_param(&mut generics, lifetime);
}
add_trait_where_clause(&mut generics, self.inner_types()?, Self::REF_BOUNDS)?;
add_trait_where_clause(&mut generics, self.inner_types()?, self.ref_bounds())?;
} else {
add_trait_where_clause(&mut generics, self.inner_types()?, Self::OWNED_BOUNDS)?;
add_trait_where_clause(&mut generics, self.inner_types()?, self.owned_bounds())?;
}
Ok(generics)
@@ -254,14 +258,27 @@ impl<T: AssociatedType> AssociatingTrait<T> {
)
]};
let owned_attributes = if self.owned_type.is_transparent() {
quote! {
#[derive(#serialize_trait, #deserialize_trait)]
#[repr(transparent)]
#[serde(bound = "")]
#ignored_lints
}
} else {
quote! {
#[derive(#serialize_trait, #deserialize_trait)]
#[serde(bound = "")]
#ignored_lints
}
};
// Creates the type declaration. These types are the output of the versioning process, so
// they should be serializable. Serde might try to add automatic bounds on the type generics
// even if we don't need them, so we use `#[serde(bound = "")]` to disable this. The bounds
// on the generated types should be sufficient.
let owned_tokens = quote! {
#[derive(#serialize_trait, #deserialize_trait)]
#[serde(bound = "")]
#ignored_lints
#owned_attributes
#owned_decla
#(#owned_conversion)*
@@ -271,10 +288,23 @@ impl<T: AssociatedType> AssociatingTrait<T> {
let ref_conversion = self.ref_type.generate_conversion()?;
let ref_attributes = if self.ref_type.is_transparent() {
quote! {
#[derive(#serialize_trait)]
#[repr(transparent)]
#[serde(bound = "")]
#ignored_lints
}
} else {
quote! {
#[derive(#serialize_trait)]
#[serde(bound = "")]
#ignored_lints
}
};
let ref_tokens = quote! {
#[derive(#serialize_trait)]
#[serde(bound = "")]
#ignored_lints
#ref_attributes
#ref_decla
#(#ref_conversion)*

View File

@@ -47,9 +47,13 @@ fn derive_input_to_enum(input: &DeriveInput) -> syn::Result<ItemEnum> {
}
impl AssociatedType for DispatchType {
const REF_BOUNDS: &'static [&'static str] = &[VERSION_TRAIT_NAME];
fn ref_bounds(&self) -> &'static [&'static str] {
&[VERSION_TRAIT_NAME]
}
const OWNED_BOUNDS: &'static [&'static str] = &[VERSION_TRAIT_NAME];
fn owned_bounds(&self) -> &'static [&'static str] {
&[VERSION_TRAIT_NAME]
}
fn new_ref(orig_type: &DeriveInput) -> syn::Result<Self> {
for lt in orig_type.generics.lifetimes() {
@@ -109,6 +113,10 @@ impl AssociatedType for DispatchType {
&self.kind
}
fn is_transparent(&self) -> bool {
false
}
fn orig_type_generics(&self) -> &Generics {
&self.orig_type.generics
}

View File

@@ -15,10 +15,12 @@ use crate::associated::{
generate_from_trait_impl, generate_try_from_trait_impl, AssociatedType, AssociatedTypeKind,
ConversionDirection,
};
use crate::versionize_attribute::is_transparent;
use crate::{
add_trait_where_clause, parse_const_str, parse_trait_bound, punctuated_from_iter_result,
LIFETIME_NAME, UNVERSIONIZE_ERROR_NAME, UNVERSIONIZE_TRAIT_NAME, VERSIONIZE_OWNED_TRAIT_NAME,
VERSIONIZE_TRAIT_NAME,
INTO_TRAIT_NAME, LIFETIME_NAME, TRY_INTO_TRAIT_NAME, UNVERSIONIZE_ERROR_NAME,
UNVERSIONIZE_TRAIT_NAME, VERSIONIZE_OWNED_TRAIT_NAME, VERSIONIZE_TRAIT_NAME,
VERSION_TRAIT_NAME,
};
/// The types generated for a specific version of a given exposed type. These types are identical to
@@ -27,13 +29,29 @@ use crate::{
pub(crate) struct VersionType {
orig_type: DeriveInput,
kind: AssociatedTypeKind,
is_transparent: bool,
}
impl AssociatedType for VersionType {
const REF_BOUNDS: &'static [&'static str] = &[VERSIONIZE_TRAIT_NAME];
const OWNED_BOUNDS: &'static [&'static str] = &[VERSIONIZE_OWNED_TRAIT_NAME];
fn ref_bounds(&self) -> &'static [&'static str] {
if self.is_transparent {
&[VERSION_TRAIT_NAME]
} else {
&[VERSIONIZE_TRAIT_NAME]
}
}
fn owned_bounds(&self) -> &'static [&'static str] {
if self.is_transparent {
&[VERSION_TRAIT_NAME]
} else {
&[VERSIONIZE_OWNED_TRAIT_NAME]
}
}
fn new_ref(orig_type: &DeriveInput) -> syn::Result<VersionType> {
let is_transparent = is_transparent(&orig_type.attrs)?;
let lifetime = if is_unit(orig_type) {
None
} else {
@@ -54,13 +72,17 @@ impl AssociatedType for VersionType {
Ok(Self {
orig_type: orig_type.clone(),
kind: AssociatedTypeKind::Ref(lifetime),
is_transparent,
})
}
fn new_owned(orig_type: &DeriveInput) -> syn::Result<Self> {
let is_transparent = is_transparent(&orig_type.attrs)?;
Ok(Self {
orig_type: orig_type.clone(),
kind: AssociatedTypeKind::Owned,
is_transparent,
})
}
@@ -191,6 +213,10 @@ impl AssociatedType for VersionType {
&self.kind
}
fn is_transparent(&self) -> bool {
self.is_transparent
}
fn orig_type_generics(&self) -> &Generics {
&self.orig_type.generics
}
@@ -198,13 +224,15 @@ impl AssociatedType for VersionType {
fn conversion_generics(&self, direction: ConversionDirection) -> syn::Result<Generics> {
let mut generics = self.type_generics()?;
if let ConversionDirection::AssociatedToOrig = direction {
if let AssociatedTypeKind::Owned = &self.kind {
add_trait_where_clause(
&mut generics,
self.inner_types()?,
&[UNVERSIONIZE_TRAIT_NAME],
)?;
if !self.is_transparent {
if let ConversionDirection::AssociatedToOrig = direction {
if let AssociatedTypeKind::Owned = &self.kind {
add_trait_where_clause(
&mut generics,
self.inner_types()?,
&[UNVERSIONIZE_TRAIT_NAME],
)?;
}
}
}
@@ -323,25 +351,46 @@ impl VersionType {
fields_iter: I,
) -> impl IntoIterator<Item = syn::Result<Field>> + 'a {
let kind = self.kind.clone();
let is_transparent = self.is_transparent;
fields_iter.into_iter().map(move |field| {
let unver_ty = field.ty.clone();
let versionize_trait = parse_trait_bound(VERSIONIZE_TRAIT_NAME)?;
let versionize_owned_trait = parse_trait_bound(VERSIONIZE_OWNED_TRAIT_NAME)?;
if is_transparent {
// If the type is transparent, we reuse the "Version" impl of the inner type
let version_trait = parse_trait_bound(VERSION_TRAIT_NAME)?;
let ty: Type = match &kind {
AssociatedTypeKind::Ref(lifetime) => parse_quote! {
<#unver_ty as #versionize_trait>::Versioned<#lifetime>
},
AssociatedTypeKind::Owned => parse_quote! {
<#unver_ty as #versionize_owned_trait>::VersionedOwned
},
};
let ty: Type = match &kind {
AssociatedTypeKind::Ref(lifetime) => parse_quote! {
<#unver_ty as #version_trait>::Ref<#lifetime>
},
AssociatedTypeKind::Owned => parse_quote! {
<#unver_ty as #version_trait>::Owned
},
};
Ok(Field {
ty,
..field.clone()
})
Ok(Field {
ty,
..field.clone()
})
} else {
let versionize_trait = parse_trait_bound(VERSIONIZE_TRAIT_NAME)?;
let versionize_owned_trait = parse_trait_bound(VERSIONIZE_OWNED_TRAIT_NAME)?;
let ty: Type = match &kind {
AssociatedTypeKind::Ref(lifetime) => parse_quote! {
<#unver_ty as #versionize_trait>::Versioned<#lifetime>
},
AssociatedTypeKind::Owned => parse_quote! {
<#unver_ty as #versionize_owned_trait>::VersionedOwned
},
};
Ok(Field {
ty,
..field.clone()
})
}
})
}
@@ -520,7 +569,11 @@ impl VersionType {
let ty = &field.ty;
let param = quote! { #arg_ident.#field_ident };
let rhs = self.generate_constructor_field_rhs(ty, param, false, direction)?;
let rhs = if self.is_transparent() {
self.generate_constructor_transparent_rhs(param, direction)?
} else {
self.generate_constructor_field_rhs(ty, param, false, direction)?
};
Ok(quote! {
#field_ident: #rhs
@@ -542,12 +595,16 @@ impl VersionType {
.map(move |(arg_name, field)| {
// Ok to unwrap because the field is named so field.ident is Some
let field_ident = field.ident.as_ref().unwrap();
let rhs = self.generate_constructor_field_rhs(
&field.ty,
quote! {#arg_name},
true,
direction,
)?;
let rhs = if self.is_transparent() {
self.generate_constructor_transparent_rhs(quote! {#arg_name}, direction)?
} else {
self.generate_constructor_field_rhs(
&field.ty,
quote! {#arg_name},
true,
direction,
)?
};
Ok(quote! {
#field_ident: #rhs
})
@@ -596,7 +653,11 @@ impl VersionType {
let ty = &field.ty;
let param = quote! { #arg_ident.#idx };
self.generate_constructor_field_rhs(ty, param, false, direction)
if self.is_transparent {
self.generate_constructor_transparent_rhs(param, direction)
} else {
self.generate_constructor_field_rhs(ty, param, false, direction)
}
}
/// Generates the constructor for the fields of an unnamed enum variant.
@@ -612,7 +673,16 @@ impl VersionType {
) -> syn::Result<TokenStream> {
let fields: syn::Result<Vec<TokenStream>> = zip(arg_names, fields)
.map(move |(arg_name, field)| {
self.generate_constructor_field_rhs(&field.ty, quote! {#arg_name}, true, direction)
if self.is_transparent {
self.generate_constructor_transparent_rhs(quote! {#arg_name}, direction)
} else {
self.generate_constructor_field_rhs(
&field.ty,
quote! {#arg_name},
true,
direction,
)
}
})
.collect();
let fields = fields?;
@@ -664,6 +734,41 @@ panic!("No conversion should be generated between associated ref type to origina
};
Ok(field_constructor)
}
fn generate_constructor_transparent_rhs(
&self,
field_param: TokenStream,
direction: ConversionDirection,
) -> syn::Result<TokenStream> {
let into_trait: Path = parse_const_str(INTO_TRAIT_NAME);
let try_into_trait: Path = parse_const_str(TRY_INTO_TRAIT_NAME);
let field_constructor = match direction {
ConversionDirection::OrigToAssociated => match self.kind {
AssociatedTypeKind::Ref(_) => {
quote! {
#into_trait::into(&#field_param)
}
}
AssociatedTypeKind::Owned => {
quote! {
#into_trait::into(#field_param)
}
}
},
ConversionDirection::AssociatedToOrig => match self.kind {
AssociatedTypeKind::Ref(_) => {
panic!("No conversion should be generated between associated ref type to original type");
}
AssociatedTypeKind::Owned => {
quote! {
#try_into_trait::try_into(#field_param)?
}
}
},
};
Ok(field_constructor)
}
}
/// Generates a list of argument names. This is used to create a pattern matching of a

View File

@@ -11,7 +11,7 @@ use syn::{Attribute, Expr, Lit, Meta, Path, Token};
const VERSIONIZE_ATTR_NAME: &str = "versionize";
/// Transparent mode can also be activated using `#[repr(transparent)]`
const REPR_ATTR_NAME: &str = "repr";
pub(crate) const REPR_ATTR_NAME: &str = "repr";
/// Represent the parsed `#[versionize(...)]` attribute
pub(crate) enum VersionizeAttribute {
@@ -167,16 +167,14 @@ impl VersionizeAttribute {
.filter(|attr| attr.path().is_ident(VERSIONIZE_ATTR_NAME))
.collect();
let repr_attributes: Vec<&Attribute> = attributes
.iter()
.filter(|attr| attr.path().is_ident(REPR_ATTR_NAME))
.collect();
// Check if transparent mode is enabled via repr(transparent). It can also be enabled with
// the versionize attribute.
let type_is_transparent = is_transparent(attributes)?;
match version_attributes.as_slice() {
[] => {
// transparent mode can also be enabled via `#[repr(transparent)]`
if let Some(attr) = repr_attributes.first() {
Self::parse_from_attribute(attr)
if type_is_transparent {
Ok(Self::Transparent)
} else {
Err(syn::Error::new(
Span::call_site(),
@@ -298,3 +296,26 @@ fn parse_path_ignore_quotes(value: &Expr) -> syn::Result<Path> {
)),
}
}
/// Check if the target type has the `#[repr(transparent)]` attribute in its attributes list
pub(crate) fn is_transparent(attributes: &[Attribute]) -> syn::Result<bool> {
if let Some(attr) = attributes
.iter()
.find(|attr| attr.path().is_ident(REPR_ATTR_NAME))
{
let nested = attr.parse_args_with(Punctuated::<Meta, Token![,]>::parse_terminated)?;
for meta in nested.iter() {
match meta {
Meta::Path(path) => {
if path.is_ident("transparent") {
return Ok(true);
}
}
_ => {}
}
}
}
Ok(false)
}

View File

@@ -48,9 +48,9 @@ enum MyStructVersions<T> {
mod v0 {
use tfhe_versionable::{Versionize, VersionsDispatch};
// This struct cannot change as it is not itself versioned. If you ever make a change that
// should impact the serialized layout of the data, you need to update all the types that use
// it.
// If you ever change the layout of this struct to make it "not transparent", you should create
// a MyStructWrapperVersions enum where the first versions are the same than the ones of
// MyStructVersions. See `transparent_then_not.rs` for a full example.
#[derive(Versionize)]
#[versionize(transparent)]
pub(super) struct MyStructWrapper(pub(super) MyStruct);

View File

@@ -0,0 +1,173 @@
//! This example is similar to the "transparent" one, except that the wrapper type is transparent at
//! a point in time, then converted into its own type that is not transparent.
//!
//! Here we have a type, `MyStructWrapper`, that was a transparent wrapper for `MyStruct` in the v0
//! and v1 of the application. `MyStruct` has been upgraded between v0 and v1. In v2,
//! `MyStructWrapper` was transformed into an enum. Since it was transparent before, it has no
//! history (dispatch enum) before v2.
//!
//! To make this work, we consider that the inner and the wrapper type share the same history up to
//! the version where the transparent attribute has been removed.
use std::convert::Infallible;
use tfhe_versionable::{Unversionize, Upgrade, Version, Versionize, VersionsDispatch};
// This type was transparent before, but it has now been transformed to a full type, for example by
// adding a new kind of metadata.
#[derive(Versionize)]
#[versionize(MyStructWrapperVersions)]
struct MyStructWrapper<T> {
inner: MyStruct<T>,
count: u64,
}
// We need to create a dispatch enum that has the same history as the inner type until the point
// where the wrapper is not transparent anymore.
#[derive(VersionsDispatch)]
#[allow(unused)]
enum MyStructWrapperVersions<T> {
V0(MyStructWrapperV0),
V1(MyStructWrapperV1<T>),
V2(MyStructWrapper<T>),
}
// We copy the upgrade path of the internal struct for the wrapper for the first 2 versions. To do
// that, we recreate the "transparent" `MyStructWrapper` from v0 and v1 and upgrade them by calling
// the upgrade method of the inner type.
#[derive(Version)]
#[repr(transparent)]
struct MyStructWrapperV0(MyStructV0);
impl<T: Default> Upgrade<MyStructWrapperV1<T>> for MyStructWrapperV0 {
type Error = Infallible;
fn upgrade(self) -> Result<MyStructWrapperV1<T>, Self::Error> {
Ok(MyStructWrapperV1(self.0.upgrade()?))
}
}
// Then we define the upgrade from the last transparent version to the first "full" version
#[derive(Version)]
#[repr(transparent)]
struct MyStructWrapperV1<T>(MyStruct<T>);
impl<T> Upgrade<MyStructWrapper<T>> for MyStructWrapperV1<T> {
type Error = Infallible;
fn upgrade(self) -> Result<MyStructWrapper<T>, Self::Error> {
Ok(MyStructWrapper {
inner: self.0,
count: 0,
})
}
}
#[derive(Versionize)]
#[versionize(MyStructVersions)]
struct MyStruct<T> {
attr: T,
builtin: u32,
}
#[derive(Version)]
struct MyStructV0 {
builtin: u32,
}
impl<T: Default> Upgrade<MyStruct<T>> for MyStructV0 {
type Error = Infallible;
fn upgrade(self) -> Result<MyStruct<T>, Self::Error> {
Ok(MyStruct {
attr: T::default(),
builtin: self.builtin,
})
}
}
#[derive(VersionsDispatch)]
#[allow(unused)]
enum MyStructVersions<T> {
V0(MyStructV0),
V1(MyStruct<T>),
}
// v0 of the app defined the type as a transparent wrapper
mod v0 {
use tfhe_versionable::{Versionize, VersionsDispatch};
#[derive(Versionize)]
#[versionize(transparent)]
pub(super) struct MyStructWrapper(pub(super) MyStruct);
#[derive(Versionize)]
#[versionize(MyStructVersions)]
pub(super) struct MyStruct {
pub(super) builtin: u32,
}
#[derive(VersionsDispatch)]
#[allow(unused)]
pub(super) enum MyStructVersions {
V0(MyStruct),
}
}
// In v1, MyStructWrapper is still transparent but MyStruct got an upgrade compared to v0.
mod v1 {
use std::convert::Infallible;
use tfhe_versionable::{Upgrade, Version, Versionize, VersionsDispatch};
#[derive(Versionize)]
#[repr(transparent)]
struct MyStructWrapper<T>(MyStruct<T>);
#[derive(Versionize)]
#[versionize(MyStructVersions)]
struct MyStruct<T> {
attr: T,
builtin: u32,
}
#[derive(Version)]
struct MyStructV0 {
builtin: u32,
}
impl<T: Default> Upgrade<MyStruct<T>> for MyStructV0 {
type Error = Infallible;
fn upgrade(self) -> Result<MyStruct<T>, Self::Error> {
Ok(MyStruct {
attr: T::default(),
builtin: self.builtin,
})
}
}
#[derive(VersionsDispatch)]
#[allow(unused)]
enum MyStructVersions<T> {
V0(MyStructV0),
V1(MyStruct<T>),
}
}
fn main() {
let value = 1234;
let ms = v0::MyStructWrapper(v0::MyStruct { builtin: value });
let serialized = bincode::serialize(&ms.versionize()).unwrap();
let unserialized =
MyStructWrapper::<u64>::unversionize(bincode::deserialize(&serialized).unwrap()).unwrap();
assert_eq!(unserialized.inner.builtin, value)
}
#[test]
fn test() {
main()
}