diff --git a/utils/tfhe-versionable-derive/src/associated.rs b/utils/tfhe-versionable-derive/src/associated.rs index d9e86e323..a07842e00 100644 --- a/utils/tfhe-versionable-derive/src/associated.rs +++ b/utils/tfhe-versionable-derive/src/associated.rs @@ -94,9 +94,9 @@ pub(crate) enum AssociatedTypeKind { /// [`VersionType`]: crate::dispatch_type::VersionType pub(crate) trait AssociatedType: Sized { /// Bounds that will be added on the fields of the ref type definition - const REF_BOUNDS: &'static [&'static str]; + fn ref_bounds(&self) -> &'static [&'static str]; /// Bounds that will be added on the fields of the owned type definition - const OWNED_BOUNDS: &'static [&'static str]; + fn owned_bounds(&self) -> &'static [&'static str]; /// This will create the alternative of the type that holds a reference to the underlying data fn new_ref(orig_type: &DeriveInput) -> syn::Result; @@ -109,6 +109,10 @@ pub(crate) trait AssociatedType: Sized { /// Returns the kind of associated type, a ref or an owned type fn kind(&self) -> &AssociatedTypeKind; + /// Returns true if the type is transparent and trait implementation is actually deferred to the + /// inner type + fn is_transparent(&self) -> bool; + /// Returns the generics found in the original type definition fn orig_type_generics(&self) -> &Generics; @@ -119,9 +123,9 @@ pub(crate) trait AssociatedType: Sized { if let Some(lifetime) = opt_lifetime { add_lifetime_param(&mut generics, lifetime); } - add_trait_where_clause(&mut generics, self.inner_types()?, Self::REF_BOUNDS)?; + add_trait_where_clause(&mut generics, self.inner_types()?, self.ref_bounds())?; } else { - add_trait_where_clause(&mut generics, self.inner_types()?, Self::OWNED_BOUNDS)?; + add_trait_where_clause(&mut generics, self.inner_types()?, self.owned_bounds())?; } Ok(generics) @@ -254,14 +258,27 @@ impl AssociatingTrait { ) ]}; + let owned_attributes = if self.owned_type.is_transparent() { + quote! { + #[derive(#serialize_trait, #deserialize_trait)] + #[repr(transparent)] + #[serde(bound = "")] + #ignored_lints + } + } else { + quote! { + #[derive(#serialize_trait, #deserialize_trait)] + #[serde(bound = "")] + #ignored_lints + } + }; + // Creates the type declaration. These types are the output of the versioning process, so // they should be serializable. Serde might try to add automatic bounds on the type generics // even if we don't need them, so we use `#[serde(bound = "")]` to disable this. The bounds // on the generated types should be sufficient. let owned_tokens = quote! { - #[derive(#serialize_trait, #deserialize_trait)] - #[serde(bound = "")] - #ignored_lints + #owned_attributes #owned_decla #(#owned_conversion)* @@ -271,10 +288,23 @@ impl AssociatingTrait { let ref_conversion = self.ref_type.generate_conversion()?; + let ref_attributes = if self.ref_type.is_transparent() { + quote! { + #[derive(#serialize_trait)] + #[repr(transparent)] + #[serde(bound = "")] + #ignored_lints + } + } else { + quote! { + #[derive(#serialize_trait)] + #[serde(bound = "")] + #ignored_lints + } + }; + let ref_tokens = quote! { - #[derive(#serialize_trait)] - #[serde(bound = "")] - #ignored_lints + #ref_attributes #ref_decla #(#ref_conversion)* diff --git a/utils/tfhe-versionable-derive/src/dispatch_type.rs b/utils/tfhe-versionable-derive/src/dispatch_type.rs index b2f936362..c47572afb 100644 --- a/utils/tfhe-versionable-derive/src/dispatch_type.rs +++ b/utils/tfhe-versionable-derive/src/dispatch_type.rs @@ -47,9 +47,13 @@ fn derive_input_to_enum(input: &DeriveInput) -> syn::Result { } impl AssociatedType for DispatchType { - const REF_BOUNDS: &'static [&'static str] = &[VERSION_TRAIT_NAME]; + fn ref_bounds(&self) -> &'static [&'static str] { + &[VERSION_TRAIT_NAME] + } - const OWNED_BOUNDS: &'static [&'static str] = &[VERSION_TRAIT_NAME]; + fn owned_bounds(&self) -> &'static [&'static str] { + &[VERSION_TRAIT_NAME] + } fn new_ref(orig_type: &DeriveInput) -> syn::Result { for lt in orig_type.generics.lifetimes() { @@ -109,6 +113,10 @@ impl AssociatedType for DispatchType { &self.kind } + fn is_transparent(&self) -> bool { + false + } + fn orig_type_generics(&self) -> &Generics { &self.orig_type.generics } diff --git a/utils/tfhe-versionable-derive/src/version_type.rs b/utils/tfhe-versionable-derive/src/version_type.rs index 672587886..dc4d745b9 100644 --- a/utils/tfhe-versionable-derive/src/version_type.rs +++ b/utils/tfhe-versionable-derive/src/version_type.rs @@ -15,10 +15,12 @@ use crate::associated::{ generate_from_trait_impl, generate_try_from_trait_impl, AssociatedType, AssociatedTypeKind, ConversionDirection, }; +use crate::versionize_attribute::is_transparent; use crate::{ add_trait_where_clause, parse_const_str, parse_trait_bound, punctuated_from_iter_result, - LIFETIME_NAME, UNVERSIONIZE_ERROR_NAME, UNVERSIONIZE_TRAIT_NAME, VERSIONIZE_OWNED_TRAIT_NAME, - VERSIONIZE_TRAIT_NAME, + INTO_TRAIT_NAME, LIFETIME_NAME, TRY_INTO_TRAIT_NAME, UNVERSIONIZE_ERROR_NAME, + UNVERSIONIZE_TRAIT_NAME, VERSIONIZE_OWNED_TRAIT_NAME, VERSIONIZE_TRAIT_NAME, + VERSION_TRAIT_NAME, }; /// The types generated for a specific version of a given exposed type. These types are identical to @@ -27,13 +29,29 @@ use crate::{ pub(crate) struct VersionType { orig_type: DeriveInput, kind: AssociatedTypeKind, + is_transparent: bool, } impl AssociatedType for VersionType { - const REF_BOUNDS: &'static [&'static str] = &[VERSIONIZE_TRAIT_NAME]; - const OWNED_BOUNDS: &'static [&'static str] = &[VERSIONIZE_OWNED_TRAIT_NAME]; + fn ref_bounds(&self) -> &'static [&'static str] { + if self.is_transparent { + &[VERSION_TRAIT_NAME] + } else { + &[VERSIONIZE_TRAIT_NAME] + } + } + + fn owned_bounds(&self) -> &'static [&'static str] { + if self.is_transparent { + &[VERSION_TRAIT_NAME] + } else { + &[VERSIONIZE_OWNED_TRAIT_NAME] + } + } fn new_ref(orig_type: &DeriveInput) -> syn::Result { + let is_transparent = is_transparent(&orig_type.attrs)?; + let lifetime = if is_unit(orig_type) { None } else { @@ -54,13 +72,17 @@ impl AssociatedType for VersionType { Ok(Self { orig_type: orig_type.clone(), kind: AssociatedTypeKind::Ref(lifetime), + is_transparent, }) } fn new_owned(orig_type: &DeriveInput) -> syn::Result { + let is_transparent = is_transparent(&orig_type.attrs)?; + Ok(Self { orig_type: orig_type.clone(), kind: AssociatedTypeKind::Owned, + is_transparent, }) } @@ -191,6 +213,10 @@ impl AssociatedType for VersionType { &self.kind } + fn is_transparent(&self) -> bool { + self.is_transparent + } + fn orig_type_generics(&self) -> &Generics { &self.orig_type.generics } @@ -198,13 +224,15 @@ impl AssociatedType for VersionType { fn conversion_generics(&self, direction: ConversionDirection) -> syn::Result { let mut generics = self.type_generics()?; - if let ConversionDirection::AssociatedToOrig = direction { - if let AssociatedTypeKind::Owned = &self.kind { - add_trait_where_clause( - &mut generics, - self.inner_types()?, - &[UNVERSIONIZE_TRAIT_NAME], - )?; + if !self.is_transparent { + if let ConversionDirection::AssociatedToOrig = direction { + if let AssociatedTypeKind::Owned = &self.kind { + add_trait_where_clause( + &mut generics, + self.inner_types()?, + &[UNVERSIONIZE_TRAIT_NAME], + )?; + } } } @@ -323,25 +351,46 @@ impl VersionType { fields_iter: I, ) -> impl IntoIterator> + 'a { let kind = self.kind.clone(); + let is_transparent = self.is_transparent; + fields_iter.into_iter().map(move |field| { let unver_ty = field.ty.clone(); - let versionize_trait = parse_trait_bound(VERSIONIZE_TRAIT_NAME)?; - let versionize_owned_trait = parse_trait_bound(VERSIONIZE_OWNED_TRAIT_NAME)?; + if is_transparent { + // If the type is transparent, we reuse the "Version" impl of the inner type + let version_trait = parse_trait_bound(VERSION_TRAIT_NAME)?; - let ty: Type = match &kind { - AssociatedTypeKind::Ref(lifetime) => parse_quote! { - <#unver_ty as #versionize_trait>::Versioned<#lifetime> - }, - AssociatedTypeKind::Owned => parse_quote! { - <#unver_ty as #versionize_owned_trait>::VersionedOwned - }, - }; + let ty: Type = match &kind { + AssociatedTypeKind::Ref(lifetime) => parse_quote! { + <#unver_ty as #version_trait>::Ref<#lifetime> + }, + AssociatedTypeKind::Owned => parse_quote! { + <#unver_ty as #version_trait>::Owned + }, + }; - Ok(Field { - ty, - ..field.clone() - }) + Ok(Field { + ty, + ..field.clone() + }) + } else { + let versionize_trait = parse_trait_bound(VERSIONIZE_TRAIT_NAME)?; + let versionize_owned_trait = parse_trait_bound(VERSIONIZE_OWNED_TRAIT_NAME)?; + + let ty: Type = match &kind { + AssociatedTypeKind::Ref(lifetime) => parse_quote! { + <#unver_ty as #versionize_trait>::Versioned<#lifetime> + }, + AssociatedTypeKind::Owned => parse_quote! { + <#unver_ty as #versionize_owned_trait>::VersionedOwned + }, + }; + + Ok(Field { + ty, + ..field.clone() + }) + } }) } @@ -520,7 +569,11 @@ impl VersionType { let ty = &field.ty; let param = quote! { #arg_ident.#field_ident }; - let rhs = self.generate_constructor_field_rhs(ty, param, false, direction)?; + let rhs = if self.is_transparent() { + self.generate_constructor_transparent_rhs(param, direction)? + } else { + self.generate_constructor_field_rhs(ty, param, false, direction)? + }; Ok(quote! { #field_ident: #rhs @@ -542,12 +595,16 @@ impl VersionType { .map(move |(arg_name, field)| { // Ok to unwrap because the field is named so field.ident is Some let field_ident = field.ident.as_ref().unwrap(); - let rhs = self.generate_constructor_field_rhs( - &field.ty, - quote! {#arg_name}, - true, - direction, - )?; + let rhs = if self.is_transparent() { + self.generate_constructor_transparent_rhs(quote! {#arg_name}, direction)? + } else { + self.generate_constructor_field_rhs( + &field.ty, + quote! {#arg_name}, + true, + direction, + )? + }; Ok(quote! { #field_ident: #rhs }) @@ -596,7 +653,11 @@ impl VersionType { let ty = &field.ty; let param = quote! { #arg_ident.#idx }; - self.generate_constructor_field_rhs(ty, param, false, direction) + if self.is_transparent { + self.generate_constructor_transparent_rhs(param, direction) + } else { + self.generate_constructor_field_rhs(ty, param, false, direction) + } } /// Generates the constructor for the fields of an unnamed enum variant. @@ -612,7 +673,16 @@ impl VersionType { ) -> syn::Result { let fields: syn::Result> = zip(arg_names, fields) .map(move |(arg_name, field)| { - self.generate_constructor_field_rhs(&field.ty, quote! {#arg_name}, true, direction) + if self.is_transparent { + self.generate_constructor_transparent_rhs(quote! {#arg_name}, direction) + } else { + self.generate_constructor_field_rhs( + &field.ty, + quote! {#arg_name}, + true, + direction, + ) + } }) .collect(); let fields = fields?; @@ -664,6 +734,41 @@ panic!("No conversion should be generated between associated ref type to origina }; Ok(field_constructor) } + + fn generate_constructor_transparent_rhs( + &self, + field_param: TokenStream, + direction: ConversionDirection, + ) -> syn::Result { + let into_trait: Path = parse_const_str(INTO_TRAIT_NAME); + let try_into_trait: Path = parse_const_str(TRY_INTO_TRAIT_NAME); + + let field_constructor = match direction { + ConversionDirection::OrigToAssociated => match self.kind { + AssociatedTypeKind::Ref(_) => { + quote! { + #into_trait::into(&#field_param) + } + } + AssociatedTypeKind::Owned => { + quote! { + #into_trait::into(#field_param) + } + } + }, + ConversionDirection::AssociatedToOrig => match self.kind { + AssociatedTypeKind::Ref(_) => { + panic!("No conversion should be generated between associated ref type to original type"); + } + AssociatedTypeKind::Owned => { + quote! { + #try_into_trait::try_into(#field_param)? + } + } + }, + }; + Ok(field_constructor) + } } /// Generates a list of argument names. This is used to create a pattern matching of a diff --git a/utils/tfhe-versionable-derive/src/versionize_attribute.rs b/utils/tfhe-versionable-derive/src/versionize_attribute.rs index f077f3b27..4629a4a85 100644 --- a/utils/tfhe-versionable-derive/src/versionize_attribute.rs +++ b/utils/tfhe-versionable-derive/src/versionize_attribute.rs @@ -11,7 +11,7 @@ use syn::{Attribute, Expr, Lit, Meta, Path, Token}; const VERSIONIZE_ATTR_NAME: &str = "versionize"; /// Transparent mode can also be activated using `#[repr(transparent)]` -const REPR_ATTR_NAME: &str = "repr"; +pub(crate) const REPR_ATTR_NAME: &str = "repr"; /// Represent the parsed `#[versionize(...)]` attribute pub(crate) enum VersionizeAttribute { @@ -167,16 +167,14 @@ impl VersionizeAttribute { .filter(|attr| attr.path().is_ident(VERSIONIZE_ATTR_NAME)) .collect(); - let repr_attributes: Vec<&Attribute> = attributes - .iter() - .filter(|attr| attr.path().is_ident(REPR_ATTR_NAME)) - .collect(); + // Check if transparent mode is enabled via repr(transparent). It can also be enabled with + // the versionize attribute. + let type_is_transparent = is_transparent(attributes)?; match version_attributes.as_slice() { [] => { - // transparent mode can also be enabled via `#[repr(transparent)]` - if let Some(attr) = repr_attributes.first() { - Self::parse_from_attribute(attr) + if type_is_transparent { + Ok(Self::Transparent) } else { Err(syn::Error::new( Span::call_site(), @@ -298,3 +296,26 @@ fn parse_path_ignore_quotes(value: &Expr) -> syn::Result { )), } } + +/// Check if the target type has the `#[repr(transparent)]` attribute in its attributes list +pub(crate) fn is_transparent(attributes: &[Attribute]) -> syn::Result { + if let Some(attr) = attributes + .iter() + .find(|attr| attr.path().is_ident(REPR_ATTR_NAME)) + { + let nested = attr.parse_args_with(Punctuated::::parse_terminated)?; + + for meta in nested.iter() { + match meta { + Meta::Path(path) => { + if path.is_ident("transparent") { + return Ok(true); + } + } + _ => {} + } + } + } + + Ok(false) +} diff --git a/utils/tfhe-versionable/examples/transparent.rs b/utils/tfhe-versionable/examples/transparent.rs index 007b35539..f2869dc9c 100644 --- a/utils/tfhe-versionable/examples/transparent.rs +++ b/utils/tfhe-versionable/examples/transparent.rs @@ -48,9 +48,9 @@ enum MyStructVersions { mod v0 { use tfhe_versionable::{Versionize, VersionsDispatch}; - // This struct cannot change as it is not itself versioned. If you ever make a change that - // should impact the serialized layout of the data, you need to update all the types that use - // it. + // If you ever change the layout of this struct to make it "not transparent", you should create + // a MyStructWrapperVersions enum where the first versions are the same than the ones of + // MyStructVersions. See `transparent_then_not.rs` for a full example. #[derive(Versionize)] #[versionize(transparent)] pub(super) struct MyStructWrapper(pub(super) MyStruct); diff --git a/utils/tfhe-versionable/examples/transparent_then_not.rs b/utils/tfhe-versionable/examples/transparent_then_not.rs new file mode 100644 index 000000000..b01f4d052 --- /dev/null +++ b/utils/tfhe-versionable/examples/transparent_then_not.rs @@ -0,0 +1,173 @@ +//! This example is similar to the "transparent" one, except that the wrapper type is transparent at +//! a point in time, then converted into its own type that is not transparent. +//! +//! Here we have a type, `MyStructWrapper`, that was a transparent wrapper for `MyStruct` in the v0 +//! and v1 of the application. `MyStruct` has been upgraded between v0 and v1. In v2, +//! `MyStructWrapper` was transformed into an enum. Since it was transparent before, it has no +//! history (dispatch enum) before v2. +//! +//! To make this work, we consider that the inner and the wrapper type share the same history up to +//! the version where the transparent attribute has been removed. + +use std::convert::Infallible; + +use tfhe_versionable::{Unversionize, Upgrade, Version, Versionize, VersionsDispatch}; + +// This type was transparent before, but it has now been transformed to a full type, for example by +// adding a new kind of metadata. +#[derive(Versionize)] +#[versionize(MyStructWrapperVersions)] +struct MyStructWrapper { + inner: MyStruct, + count: u64, +} + +// We need to create a dispatch enum that has the same history as the inner type until the point +// where the wrapper is not transparent anymore. +#[derive(VersionsDispatch)] +#[allow(unused)] +enum MyStructWrapperVersions { + V0(MyStructWrapperV0), + V1(MyStructWrapperV1), + V2(MyStructWrapper), +} + +// We copy the upgrade path of the internal struct for the wrapper for the first 2 versions. To do +// that, we recreate the "transparent" `MyStructWrapper` from v0 and v1 and upgrade them by calling +// the upgrade method of the inner type. +#[derive(Version)] +#[repr(transparent)] +struct MyStructWrapperV0(MyStructV0); + +impl Upgrade> for MyStructWrapperV0 { + type Error = Infallible; + + fn upgrade(self) -> Result, Self::Error> { + Ok(MyStructWrapperV1(self.0.upgrade()?)) + } +} + +// Then we define the upgrade from the last transparent version to the first "full" version +#[derive(Version)] +#[repr(transparent)] +struct MyStructWrapperV1(MyStruct); + +impl Upgrade> for MyStructWrapperV1 { + type Error = Infallible; + + fn upgrade(self) -> Result, Self::Error> { + Ok(MyStructWrapper { + inner: self.0, + count: 0, + }) + } +} + +#[derive(Versionize)] +#[versionize(MyStructVersions)] +struct MyStruct { + attr: T, + builtin: u32, +} + +#[derive(Version)] +struct MyStructV0 { + builtin: u32, +} + +impl Upgrade> for MyStructV0 { + type Error = Infallible; + + fn upgrade(self) -> Result, Self::Error> { + Ok(MyStruct { + attr: T::default(), + builtin: self.builtin, + }) + } +} + +#[derive(VersionsDispatch)] +#[allow(unused)] +enum MyStructVersions { + V0(MyStructV0), + V1(MyStruct), +} + +// v0 of the app defined the type as a transparent wrapper +mod v0 { + use tfhe_versionable::{Versionize, VersionsDispatch}; + + #[derive(Versionize)] + #[versionize(transparent)] + pub(super) struct MyStructWrapper(pub(super) MyStruct); + + #[derive(Versionize)] + #[versionize(MyStructVersions)] + pub(super) struct MyStruct { + pub(super) builtin: u32, + } + + #[derive(VersionsDispatch)] + #[allow(unused)] + pub(super) enum MyStructVersions { + V0(MyStruct), + } +} + +// In v1, MyStructWrapper is still transparent but MyStruct got an upgrade compared to v0. +mod v1 { + use std::convert::Infallible; + + use tfhe_versionable::{Upgrade, Version, Versionize, VersionsDispatch}; + + #[derive(Versionize)] + #[repr(transparent)] + struct MyStructWrapper(MyStruct); + + #[derive(Versionize)] + #[versionize(MyStructVersions)] + struct MyStruct { + attr: T, + builtin: u32, + } + + #[derive(Version)] + struct MyStructV0 { + builtin: u32, + } + + impl Upgrade> for MyStructV0 { + type Error = Infallible; + + fn upgrade(self) -> Result, Self::Error> { + Ok(MyStruct { + attr: T::default(), + builtin: self.builtin, + }) + } + } + + #[derive(VersionsDispatch)] + #[allow(unused)] + enum MyStructVersions { + V0(MyStructV0), + V1(MyStruct), + } +} + +fn main() { + let value = 1234; + let ms = v0::MyStructWrapper(v0::MyStruct { builtin: value }); + + let serialized = bincode::serialize(&ms.versionize()).unwrap(); + + let unserialized = + MyStructWrapper::::unversionize(bincode::deserialize(&serialized).unwrap()).unwrap(); + + assert_eq!(unserialized.inner.builtin, value) +} + +#[test] +fn test() { + main() +}