diff --git a/tfhe/src/core_crypto/backward_compatibility/entities/lwe_multi_bit_bootstrap_key.rs b/tfhe/src/core_crypto/backward_compatibility/entities/lwe_multi_bit_bootstrap_key.rs index ce6dd0ab0..32429125a 100644 --- a/tfhe/src/core_crypto/backward_compatibility/entities/lwe_multi_bit_bootstrap_key.rs +++ b/tfhe/src/core_crypto/backward_compatibility/entities/lwe_multi_bit_bootstrap_key.rs @@ -34,10 +34,10 @@ pub enum FourierLweMultiBitBootstrapKeyVersionedOwned { V0(FourierLweMultiBitBootstrapKeyVersionOwned), } -impl> From<&FourierLweMultiBitBootstrapKey> +impl> From> for FourierLweMultiBitBootstrapKeyVersionedOwned { - fn from(value: &FourierLweMultiBitBootstrapKey) -> Self { + fn from(value: FourierLweMultiBitBootstrapKey) -> Self { Self::V0(value.into()) } } diff --git a/tfhe/src/core_crypto/backward_compatibility/fft_impl/mod.rs b/tfhe/src/core_crypto/backward_compatibility/fft_impl/mod.rs index f3a95efbd..76ed030f6 100644 --- a/tfhe/src/core_crypto/backward_compatibility/fft_impl/mod.rs +++ b/tfhe/src/core_crypto/backward_compatibility/fft_impl/mod.rs @@ -33,10 +33,10 @@ pub enum FourierPolynomialListVersionedOwned { V0(FourierPolynomialList>), } -impl> From<&FourierPolynomialList> +impl> From> for FourierPolynomialListVersionedOwned { - fn from(value: &FourierPolynomialList) -> Self { + fn from(value: FourierPolynomialList) -> Self { let owned_poly = FourierPolynomialList { data: ABox::collect(value.data.as_ref().iter().copied()), polynomial_size: value.polynomial_size, @@ -76,10 +76,10 @@ pub enum FourierLweBootstrapKeyVersionedOwned { V0(FourierLweBootstrapKeyVersionOwned), } -impl> From<&FourierLweBootstrapKey> +impl> From> for FourierLweBootstrapKeyVersionedOwned { - fn from(value: &FourierLweBootstrapKey) -> Self { + fn from(value: FourierLweBootstrapKey) -> Self { Self::V0(value.into()) } } diff --git a/tfhe/src/core_crypto/commons/math/random/generator.rs b/tfhe/src/core_crypto/commons/math/random/generator.rs index d714cca01..e75e71b92 100644 --- a/tfhe/src/core_crypto/commons/math/random/generator.rs +++ b/tfhe/src/core_crypto/commons/math/random/generator.rs @@ -31,7 +31,7 @@ pub mod serialization_proxy { } pub(crate) use serialization_proxy::*; -use tfhe_versionable::{Unversionize, Versionize}; +use tfhe_versionable::{Unversionize, Versionize, VersionizeOwned}; #[derive(PartialEq, Eq, Debug, Clone, Copy, Serialize, Deserialize)] /// New type to manage seeds used for compressed/seeded types. @@ -46,11 +46,13 @@ impl Versionize for CompressionSeed { fn versionize(&self) -> Self::Versioned<'_> { self.into() } +} +impl VersionizeOwned for CompressionSeed { type VersionedOwned = CompressionSeedVersionedOwned; - fn versionize_owned(&self) -> Self::VersionedOwned { - (*self).into() + fn versionize_owned(self) -> Self::VersionedOwned { + self.into() } } diff --git a/tfhe/src/core_crypto/entities/lwe_multi_bit_bootstrap_key.rs b/tfhe/src/core_crypto/entities/lwe_multi_bit_bootstrap_key.rs index 26439ffc5..2a8883948 100644 --- a/tfhe/src/core_crypto/entities/lwe_multi_bit_bootstrap_key.rs +++ b/tfhe/src/core_crypto/entities/lwe_multi_bit_bootstrap_key.rs @@ -15,7 +15,7 @@ use crate::core_crypto::entities::*; use crate::core_crypto::fft_impl::fft64::math::fft::FourierPolynomialList; use aligned_vec::{avec, ABox}; use concrete_fft::c64; -use tfhe_versionable::{Unversionize, UnversionizeError, Versionize}; +use tfhe_versionable::{Unversionize, UnversionizeError, Versionize, VersionizeOwned}; #[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize, Versionize)] #[versionize(LweMultiBitBootstrapKeyVersions)] @@ -415,11 +415,11 @@ pub struct FourierLweMultiBitBootstrapKeyVersion<'vers> { #[derive(serde::Serialize, serde::Deserialize)] pub struct FourierLweMultiBitBootstrapKeyVersionOwned { fourier: FourierPolynomialListVersionedOwned, - input_lwe_dimension: ::VersionedOwned, - glwe_size: ::VersionedOwned, - decomposition_base_log: ::VersionedOwned, - decomposition_level_count: ::VersionedOwned, - grouping_factor: ::VersionedOwned, + input_lwe_dimension: ::VersionedOwned, + glwe_size: ::VersionedOwned, + decomposition_base_log: ::VersionedOwned, + decomposition_level_count: ::VersionedOwned, + grouping_factor: ::VersionedOwned, } impl<'vers, C: Container> From<&'vers FourierLweMultiBitBootstrapKey> @@ -437,10 +437,10 @@ impl<'vers, C: Container> From<&'vers FourierLweMultiBitBootstrap } } -impl> From<&FourierLweMultiBitBootstrapKey> +impl> From> for FourierLweMultiBitBootstrapKeyVersionOwned { - fn from(value: &FourierLweMultiBitBootstrapKey) -> Self { + fn from(value: FourierLweMultiBitBootstrapKey) -> Self { Self { fourier: value.fourier.versionize_owned(), input_lwe_dimension: value.input_lwe_dimension.versionize_owned(), @@ -478,10 +478,12 @@ impl> Versionize for FourierLweMultiBitBootstrapKey< fn versionize(&self) -> Self::Versioned<'_> { self.into() } +} +impl> VersionizeOwned for FourierLweMultiBitBootstrapKey { type VersionedOwned = FourierLweMultiBitBootstrapKeyVersionedOwned; - fn versionize_owned(&self) -> Self::VersionedOwned { + fn versionize_owned(self) -> Self::VersionedOwned { self.into() } } diff --git a/tfhe/src/core_crypto/fft_impl/fft64/crypto/bootstrap.rs b/tfhe/src/core_crypto/fft_impl/fft64/crypto/bootstrap.rs index cdd556fa7..a556a7d72 100644 --- a/tfhe/src/core_crypto/fft_impl/fft64/crypto/bootstrap.rs +++ b/tfhe/src/core_crypto/fft_impl/fft64/crypto/bootstrap.rs @@ -24,7 +24,7 @@ use crate::core_crypto::prelude::ContainerMut; use aligned_vec::{avec, ABox, CACHELINE_ALIGN}; use concrete_fft::c64; use dyn_stack::{PodStack, ReborrowMut, SizeOverflow, StackReq}; -use tfhe_versionable::{Unversionize, UnversionizeError, Versionize}; +use tfhe_versionable::{Unversionize, UnversionizeError, Versionize, VersionizeOwned}; #[derive(Clone, Copy, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)] #[serde(bound(deserialize = "C: IntoContainerOwned"))] @@ -48,10 +48,10 @@ pub struct FourierLweBootstrapKeyVersion<'vers> { #[derive(serde::Serialize, serde::Deserialize)] pub struct FourierLweBootstrapKeyVersionOwned { fourier: FourierPolynomialListVersionedOwned, - input_lwe_dimension: ::VersionedOwned, - glwe_size: ::VersionedOwned, - decomposition_base_log: ::VersionedOwned, - decomposition_level_count: ::VersionedOwned, + input_lwe_dimension: ::VersionedOwned, + glwe_size: ::VersionedOwned, + decomposition_base_log: ::VersionedOwned, + decomposition_level_count: ::VersionedOwned, } impl<'vers, C: Container> From<&'vers FourierLweBootstrapKey> @@ -68,10 +68,10 @@ impl<'vers, C: Container> From<&'vers FourierLweBootstrapKey> } } -impl> From<&FourierLweBootstrapKey> +impl> From> for FourierLweBootstrapKeyVersionOwned { - fn from(value: &FourierLweBootstrapKey) -> Self { + fn from(value: FourierLweBootstrapKey) -> Self { Self { fourier: value.fourier.versionize_owned(), input_lwe_dimension: value.input_lwe_dimension.versionize_owned(), @@ -107,10 +107,12 @@ impl> Versionize for FourierLweBootstrapKey { fn versionize(&self) -> Self::Versioned<'_> { self.into() } +} +impl> VersionizeOwned for FourierLweBootstrapKey { type VersionedOwned = FourierLweBootstrapKeyVersionedOwned; - fn versionize_owned(&self) -> Self::VersionedOwned { + fn versionize_owned(self) -> Self::VersionedOwned { self.into() } } diff --git a/tfhe/src/core_crypto/fft_impl/fft64/math/fft/mod.rs b/tfhe/src/core_crypto/fft_impl/fft64/math/fft/mod.rs index cb8e4660c..7cdbaac02 100644 --- a/tfhe/src/core_crypto/fft_impl/fft64/math/fft/mod.rs +++ b/tfhe/src/core_crypto/fft_impl/fft64/math/fft/mod.rs @@ -20,7 +20,7 @@ use std::mem::{align_of, size_of}; use std::sync::{Arc, OnceLock, RwLock}; #[cfg(not(feature = "experimental-force_fft_algo_dif4"))] use std::time::Duration; -use tfhe_versionable::{Unversionize, UnversionizeError, Versionize}; +use tfhe_versionable::{Unversionize, UnversionizeError, Versionize, VersionizeOwned}; #[cfg(any(target_arch = "x86_64", target_arch = "x86"))] mod x86; @@ -595,10 +595,12 @@ impl> Versionize for FourierPolynomialList { fn versionize(&self) -> Self::Versioned<'_> { self.into() } +} +impl> VersionizeOwned for FourierPolynomialList { type VersionedOwned = FourierPolynomialListVersionedOwned; - fn versionize_owned(&self) -> Self::VersionedOwned { + fn versionize_owned(self) -> Self::VersionedOwned { self.into() } } diff --git a/tfhe/src/high_level_api/booleans/inner.rs b/tfhe/src/high_level_api/booleans/inner.rs index 67d557e9d..4bd92e1e2 100644 --- a/tfhe/src/high_level_api/booleans/inner.rs +++ b/tfhe/src/high_level_api/booleans/inner.rs @@ -6,7 +6,7 @@ use crate::high_level_api::global_state::with_thread_local_cuda_streams; use crate::integer::BooleanBlock; use crate::Device; use serde::{Deserializer, Serializer}; -use tfhe_versionable::{Unversionize, UnversionizeError, Versionize}; +use tfhe_versionable::{Unversionize, UnversionizeError, Versionize, VersionizeOwned}; /// Enum that manages the current inner representation of a boolean. pub(in crate::high_level_api) enum InnerBoolean { @@ -53,7 +53,7 @@ impl<'de> serde::Deserialize<'de> for InnerBoolean { // Only CPU data are serialized so we only versionize the CPU type. #[derive(serde::Serialize, serde::Deserialize)] pub(crate) struct InnerBooleanVersionOwned( - ::VersionedOwned, + ::VersionedOwned, ); impl Versionize for InnerBoolean { @@ -61,15 +61,18 @@ impl Versionize for InnerBoolean { fn versionize(&self) -> Self::Versioned<'_> { let data = self.on_cpu(); - let versioned = data.versionize_owned(); + let versioned = data.into_owned().versionize_owned(); InnerBooleanVersionedOwned::V0(InnerBooleanVersionOwned(versioned)) } - +} +impl VersionizeOwned for InnerBoolean { type VersionedOwned = InnerBooleanVersionedOwned; - fn versionize_owned(&self) -> Self::VersionedOwned { + fn versionize_owned(self) -> Self::VersionedOwned { let cpu_data = self.on_cpu(); - InnerBooleanVersionedOwned::V0(InnerBooleanVersionOwned(cpu_data.versionize_owned())) + InnerBooleanVersionedOwned::V0(InnerBooleanVersionOwned( + cpu_data.into_owned().versionize_owned(), + )) } } diff --git a/tfhe/src/high_level_api/integers/signed/inner.rs b/tfhe/src/high_level_api/integers/signed/inner.rs index d883fcdd0..78c177619 100644 --- a/tfhe/src/high_level_api/integers/signed/inner.rs +++ b/tfhe/src/high_level_api/integers/signed/inner.rs @@ -9,7 +9,7 @@ use crate::integer::gpu::ciphertext::CudaIntegerRadixCiphertext; use crate::integer::gpu::ciphertext::CudaSignedRadixCiphertext; use crate::Device; use serde::{Deserializer, Serializer}; -use tfhe_versionable::{Unversionize, UnversionizeError, Versionize}; +use tfhe_versionable::{Unversionize, UnversionizeError, Versionize, VersionizeOwned}; pub(crate) enum RadixCiphertext { Cpu(crate::integer::SignedRadixCiphertext), @@ -68,7 +68,7 @@ impl<'de> serde::Deserialize<'de> for RadixCiphertext { // Only CPU data are serialized so we only versionize the CPU type. #[derive(serde::Serialize, serde::Deserialize)] pub(crate) struct RadixCiphertextVersionOwned( - ::VersionedOwned, + ::VersionedOwned, ); impl Versionize for RadixCiphertext { @@ -76,16 +76,18 @@ impl Versionize for RadixCiphertext { fn versionize(&self) -> Self::Versioned<'_> { let data = self.on_cpu(); - let versioned = data.versionize_owned(); + let versioned = data.into_owned().versionize_owned(); SignedRadixCiphertextVersionedOwned::V0(RadixCiphertextVersionOwned(versioned)) } +} +impl VersionizeOwned for RadixCiphertext { type VersionedOwned = SignedRadixCiphertextVersionedOwned; - fn versionize_owned(&self) -> Self::VersionedOwned { + fn versionize_owned(self) -> Self::VersionedOwned { let cpu_data = self.on_cpu(); SignedRadixCiphertextVersionedOwned::V0(RadixCiphertextVersionOwned( - cpu_data.versionize_owned(), + cpu_data.into_owned().versionize_owned(), )) } } diff --git a/tfhe/src/high_level_api/integers/unsigned/inner.rs b/tfhe/src/high_level_api/integers/unsigned/inner.rs index 7926e3613..e932c9d2c 100644 --- a/tfhe/src/high_level_api/integers/unsigned/inner.rs +++ b/tfhe/src/high_level_api/integers/unsigned/inner.rs @@ -7,7 +7,7 @@ use crate::high_level_api::global_state::with_thread_local_cuda_streams; use crate::integer::gpu::ciphertext::CudaIntegerRadixCiphertext; use crate::Device; use serde::{Deserializer, Serializer}; -use tfhe_versionable::{Unversionize, UnversionizeError, Versionize}; +use tfhe_versionable::{Unversionize, UnversionizeError, Versionize, VersionizeOwned}; pub(crate) enum RadixCiphertext { Cpu(crate::integer::RadixCiphertext), @@ -64,7 +64,7 @@ impl<'de> serde::Deserialize<'de> for RadixCiphertext { // Only CPU data are serialized so we only version the CPU type. #[derive(serde::Serialize, serde::Deserialize)] pub(crate) struct RadixCiphertextVersionOwned( - ::VersionedOwned, + ::VersionedOwned, ); impl Versionize for RadixCiphertext { @@ -72,16 +72,18 @@ impl Versionize for RadixCiphertext { fn versionize(&self) -> Self::Versioned<'_> { let data = self.on_cpu(); - let versioned = data.versionize_owned(); + let versioned = data.into_owned().versionize_owned(); UnsignedRadixCiphertextVersionedOwned::V0(RadixCiphertextVersionOwned(versioned)) } +} +impl VersionizeOwned for RadixCiphertext { type VersionedOwned = UnsignedRadixCiphertextVersionedOwned; - fn versionize_owned(&self) -> Self::VersionedOwned { + fn versionize_owned(self) -> Self::VersionedOwned { let cpu_data = self.on_cpu(); UnsignedRadixCiphertextVersionedOwned::V0(RadixCiphertextVersionOwned( - cpu_data.versionize_owned(), + cpu_data.into_owned().versionize_owned(), )) } } diff --git a/tfhe/src/high_level_api/keys/server.rs b/tfhe/src/high_level_api/keys/server.rs index 40c05583d..f5db7070f 100644 --- a/tfhe/src/high_level_api/keys/server.rs +++ b/tfhe/src/high_level_api/keys/server.rs @@ -1,4 +1,4 @@ -use tfhe_versionable::{Unversionize, Versionize}; +use tfhe_versionable::{Unversionize, Versionize, VersionizeOwned}; use crate::backward_compatibility::keys::{ CompressedServerKeyVersions, ServerKeyVersioned, ServerKeyVersionedOwned, @@ -141,7 +141,7 @@ pub struct ServerKeyVersion<'vers> { #[derive(serde::Serialize, serde::Deserialize)] pub struct ServerKeyVersionOwned { - pub(crate) integer_key: ::VersionedOwned, + pub(crate) integer_key: ::VersionedOwned, } impl Versionize for ServerKey { @@ -152,12 +152,14 @@ impl Versionize for ServerKey { integer_key: self.key.versionize(), }) } +} +impl VersionizeOwned for ServerKey { type VersionedOwned = ServerKeyVersionedOwned; - fn versionize_owned(&self) -> Self::VersionedOwned { + fn versionize_owned(self) -> Self::VersionedOwned { ServerKeyVersionedOwned::V0(ServerKeyVersionOwned { - integer_key: self.key.versionize_owned(), + integer_key: (*self.key).clone().versionize_owned(), }) } } diff --git a/tfhe/src/shortint/backward_compatibility/server_key/mod.rs b/tfhe/src/shortint/backward_compatibility/server_key/mod.rs index 315aa3abb..21cbcff4b 100644 --- a/tfhe/src/shortint/backward_compatibility/server_key/mod.rs +++ b/tfhe/src/shortint/backward_compatibility/server_key/mod.rs @@ -23,10 +23,10 @@ pub enum SerializableShortintBootstrappingKeyVersionedOwned { V0(SerializableShortintBootstrappingKeyVersionOwned), } -impl> From<&SerializableShortintBootstrappingKey> +impl> From> for SerializableShortintBootstrappingKeyVersionedOwned { - fn from(value: &SerializableShortintBootstrappingKey) -> Self { + fn from(value: SerializableShortintBootstrappingKey) -> Self { Self::V0(value.into()) } } diff --git a/tfhe/src/shortint/server_key/mod.rs b/tfhe/src/shortint/server_key/mod.rs index 0d4cc8607..d6898ef15 100644 --- a/tfhe/src/shortint/server_key/mod.rs +++ b/tfhe/src/shortint/server_key/mod.rs @@ -19,14 +19,13 @@ mod shift; mod sub; pub mod compressed; -use ::tfhe_versionable::{Unversionize, UnversionizeError, Versionize}; +use ::tfhe_versionable::{Unversionize, UnversionizeError, Versionize, VersionizeOwned}; use aligned_vec::ABox; pub use bivariate_pbs::{ BivariateLookupTableMutView, BivariateLookupTableOwned, BivariateLookupTableView, }; pub use compressed::{CompressedServerKey, ShortintCompressedBootstrappingKey}; pub(crate) use scalar_mul::unchecked_scalar_mul_assign; -use serde::de::DeserializeOwned; #[cfg(test)] pub(crate) mod tests; @@ -192,10 +191,10 @@ impl<'vers, C: Container> } } -impl> From<&SerializableShortintBootstrappingKey> +impl> From> for SerializableShortintBootstrappingKeyVersionOwned { - fn from(value: &SerializableShortintBootstrappingKey) -> Self { + fn from(value: SerializableShortintBootstrappingKey) -> Self { match value { SerializableShortintBootstrappingKey::Classic(bsk) => { Self::Classic(bsk.versionize_owned()) @@ -205,7 +204,7 @@ impl> From<&SerializableShortintBootst deterministic_execution, } => Self::MultiBit { fourier_bsk: fourier_bsk.versionize_owned(), - deterministic_execution: *deterministic_execution, + deterministic_execution, }, } } @@ -243,15 +242,19 @@ impl> Versionize fn versionize(&self) -> Self::Versioned<'_> { self.into() } +} +impl> VersionizeOwned + for SerializableShortintBootstrappingKey +{ type VersionedOwned = SerializableShortintBootstrappingKeyVersionedOwned; - fn versionize_owned(&self) -> Self::VersionedOwned { + fn versionize_owned(self) -> Self::VersionedOwned { self.into() } } -impl + Serialize + DeserializeOwned> Unversionize +impl> Unversionize for SerializableShortintBootstrappingKey { fn unversionize(versioned: Self::VersionedOwned) -> Result { @@ -320,12 +323,13 @@ impl Versionize for ShortintBootstrappingKey { fn versionize(&self) -> Self::Versioned<'_> { SerializableShortintBootstrappingKey::from(self).versionize_owned() } +} - type VersionedOwned = > as Versionize>::VersionedOwned; +impl VersionizeOwned for ShortintBootstrappingKey { + type VersionedOwned = > as VersionizeOwned>::VersionedOwned; - fn versionize_owned(&self) -> Self::VersionedOwned { - todo!() - //SerializableShortintBootstrappingKey::from(self).versionize_owned() + fn versionize_owned(self) -> Self::VersionedOwned { + SerializableShortintBootstrappingKey::from(self).versionize_owned() } } diff --git a/utils/tfhe-versionable-derive/src/dispatch_type.rs b/utils/tfhe-versionable-derive/src/dispatch_type.rs index 15da57eba..e6e77f370 100644 --- a/utils/tfhe-versionable-derive/src/dispatch_type.rs +++ b/utils/tfhe-versionable-derive/src/dispatch_type.rs @@ -145,7 +145,7 @@ impl AssociatedType for DispatchType { // Wraps the highest version into the dispatch enum let src_type = self.latest_version_type()?; - let src = parse_quote! { &#src_type }; + let src = parse_quote! { #src_type }; let dest_ident = self.ident(); let dest = parse_quote! { #dest_ident #ty_generics }; let constructor = self.generate_conversion_constructor_ref("value")?; diff --git a/utils/tfhe-versionable-derive/src/lib.rs b/utils/tfhe-versionable-derive/src/lib.rs index e63d9bdc2..1b9dead2f 100644 --- a/utils/tfhe-versionable-derive/src/lib.rs +++ b/utils/tfhe-versionable-derive/src/lib.rs @@ -34,8 +34,10 @@ pub(crate) const LIFETIME_NAME: &str = "'vers"; pub(crate) const VERSION_TRAIT_NAME: &str = crate_full_path!("Version"); pub(crate) const DISPATCH_TRAIT_NAME: &str = crate_full_path!("VersionsDispatch"); pub(crate) const VERSIONIZE_TRAIT_NAME: &str = crate_full_path!("Versionize"); -pub(crate) const UNVERSIONIZE_TRAIT_NAME: &str = crate_full_path!("Unversionize"); +pub(crate) const VERSIONIZE_OWNED_TRAIT_NAME: &str = crate_full_path!("VersionizeOwned"); +pub(crate) const VERSIONIZE_SLICE_TRAIT_NAME: &str = crate_full_path!("VersionizeSlice"); pub(crate) const VERSIONIZE_VEC_TRAIT_NAME: &str = crate_full_path!("VersionizeVec"); +pub(crate) const UNVERSIONIZE_TRAIT_NAME: &str = crate_full_path!("Unversionize"); pub(crate) const UNVERSIONIZE_VEC_TRAIT_NAME: &str = crate_full_path!("UnversionizeVec"); pub(crate) const UNVERSIONIZE_ERROR_NAME: &str = crate_full_path!("UnversionizeError"); @@ -103,6 +105,7 @@ pub fn derive_versions_dispatch(input: TokenStream) -> TokenStream { &[ VERSIONIZE_TRAIT_NAME, VERSIONIZE_VEC_TRAIT_NAME, + VERSIONIZE_SLICE_TRAIT_NAME, UNVERSIONIZE_TRAIT_NAME, UNVERSIONIZE_VEC_TRAIT_NAME, SERIALIZE_TRAIT_NAME, @@ -187,8 +190,10 @@ pub fn derive_versionize(input: TokenStream) -> TokenStream { )); let versionize_trait: Path = parse_const_str(VERSIONIZE_TRAIT_NAME); + let versionize_owned_trait: Path = parse_const_str(VERSIONIZE_OWNED_TRAIT_NAME); let unversionize_trait: Path = parse_const_str(UNVERSIONIZE_TRAIT_NAME); let versionize_vec_trait: Path = parse_const_str(VERSIONIZE_VEC_TRAIT_NAME); + let versionize_slice_trait: Path = parse_const_str(VERSIONIZE_SLICE_TRAIT_NAME); let unversionize_vec_trait: Path = parse_const_str(UNVERSIONIZE_VEC_TRAIT_NAME); let mut versionize_generics = trait_generics.clone(); @@ -203,10 +208,16 @@ pub fn derive_versionize(input: TokenStream) -> TokenStream { } // Add Generics for the `VersionizeVec` and `UnversionizeVec` traits + let mut versionize_slice_generics = versionize_generics.clone(); + syn_unwrap!(add_trait_bound( + &mut versionize_slice_generics, + VERSIONIZE_TRAIT_NAME + )); + let mut versionize_vec_generics = versionize_generics.clone(); syn_unwrap!(add_trait_bound( &mut versionize_vec_generics, - VERSIONIZE_TRAIT_NAME + VERSIONIZE_OWNED_TRAIT_NAME )); let mut unversionize_vec_generics = unversionize_generics.clone(); syn_unwrap!(add_trait_bound( @@ -221,6 +232,8 @@ pub fn derive_versionize(input: TokenStream) -> TokenStream { let (unversionize_impl_generics, _, unversionize_where_clause) = unversionize_generics.split_for_impl(); + let (versionize_slice_impl_generics, _, versionize_slice_where_clause) = + versionize_slice_generics.split_for_impl(); let (versionize_vec_impl_generics, _, versionize_vec_where_clause) = versionize_vec_generics.split_for_impl(); let (unversionize_vec_impl_generics, _, unversionize_vec_where_clause) = @@ -253,15 +266,19 @@ pub fn derive_versionize(input: TokenStream) -> TokenStream { fn versionize(&self) -> Self::Versioned<'_> { #versionize_body } + } - fn versionize_owned(&self) -> Self::VersionedOwned { - #versionize_body - } - + #[automatically_derived] + impl #versionize_impl_generics #versionize_owned_trait for #input_ident #ty_generics + #versionize_where_clause + { type VersionedOwned = <#dispatch_enum_path #dispatch_generics as #dispatch_trait<#dispatch_target>>::Owned #owned_where_clause; + fn versionize_owned(self) -> Self::VersionedOwned { + #versionize_body + } } #[automatically_derived] @@ -274,19 +291,24 @@ pub fn derive_versionize(input: TokenStream) -> TokenStream { } #[automatically_derived] - impl #versionize_vec_impl_generics #versionize_vec_trait for #input_ident #ty_generics - #versionize_vec_where_clause + impl #versionize_slice_impl_generics #versionize_slice_trait for #input_ident #ty_generics + #versionize_slice_where_clause { type VersionedSlice<#lifetime> = Vec<::Versioned<#lifetime>> #ref_where_clause; fn versionize_slice(slice: &[Self]) -> Self::VersionedSlice<'_> { - slice.iter().map(|val| val.versionize()).collect() + slice.iter().map(|val| #versionize_trait::versionize(val)).collect() } + } - type VersionedVec = Vec<::VersionedOwned> #owned_where_clause; + impl #versionize_vec_impl_generics #versionize_vec_trait for #input_ident #ty_generics + #versionize_vec_where_clause + { - fn versionize_vec(slice: &[Self]) -> Self::VersionedVec { - slice.iter().map(|val| val.versionize_owned()).collect() + type VersionedVec = Vec<::VersionedOwned> #owned_where_clause; + + fn versionize_vec(vec: Vec) -> Self::VersionedVec { + vec.into_iter().map(|val| #versionize_owned_trait::versionize_owned(val)).collect() } } @@ -321,6 +343,7 @@ pub fn derive_not_versioned(input: TokenStream) -> TokenStream { let input_ident = &input.ident; let versionize_trait: Path = parse_const_str(VERSIONIZE_TRAIT_NAME); + let versionize_owned_trait: Path = parse_const_str(VERSIONIZE_OWNED_TRAIT_NAME); let unversionize_trait: Path = parse_const_str(UNVERSIONIZE_TRAIT_NAME); let unversionize_error: Path = parse_const_str(UNVERSIONIZE_ERROR_NAME); let lifetime = Lifetime::new(LIFETIME_NAME, Span::call_site()); @@ -329,14 +352,18 @@ pub fn derive_not_versioned(input: TokenStream) -> TokenStream { #[automatically_derived] impl #impl_generics #versionize_trait for #input_ident #ty_generics #where_clause { type Versioned<#lifetime> = &#lifetime Self; - type VersionedOwned = Self; fn versionize(&self) -> Self::Versioned<'_> { self } + } - fn versionize_owned(&self) -> Self::VersionedOwned { - self.clone() + #[automatically_derived] + impl #impl_generics #versionize_owned_trait for #input_ident #ty_generics #where_clause { + type VersionedOwned = Self; + + fn versionize_owned(self) -> Self::VersionedOwned { + self } } diff --git a/utils/tfhe-versionable-derive/src/version_type.rs b/utils/tfhe-versionable-derive/src/version_type.rs index edff7fb3e..1c31877ff 100644 --- a/utils/tfhe-versionable-derive/src/version_type.rs +++ b/utils/tfhe-versionable-derive/src/version_type.rs @@ -6,9 +6,9 @@ use syn::punctuated::Punctuated; use syn::spanned::Spanned; use syn::token::Comma; use syn::{ - parse_quote, Attribute, Data, DataEnum, DataStruct, DataUnion, DeriveInput, Field, Fields, - FieldsNamed, FieldsUnnamed, Generics, Ident, Item, ItemEnum, ItemImpl, ItemStruct, ItemUnion, - Lifetime, Path, Type, Variant, + parse_quote, Data, DataEnum, DataStruct, DataUnion, DeriveInput, Field, Fields, FieldsNamed, + FieldsUnnamed, Generics, Ident, Item, ItemEnum, ItemImpl, ItemStruct, ItemUnion, Lifetime, + Path, Type, Variant, }; use crate::associated::{ @@ -18,7 +18,7 @@ use crate::associated::{ use crate::{ add_lifetime_bound, add_trait_where_clause, parse_const_str, parse_trait_bound, punctuated_from_iter_result, LIFETIME_NAME, UNVERSIONIZE_ERROR_NAME, UNVERSIONIZE_TRAIT_NAME, - VERSIONIZE_TRAIT_NAME, + VERSIONIZE_OWNED_TRAIT_NAME, VERSIONIZE_TRAIT_NAME, }; /// The types generated for a specific version of a given exposed type. These types are identical to @@ -132,7 +132,7 @@ impl AssociatedType for VersionType { let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); let src_ident = self.orig_type.ident.clone(); - let src = parse_quote! { &#src_ident #orig_generics }; + let src = parse_quote! { #src_ident #orig_generics }; let dest_ident = self.ident(); let dest = parse_quote! { #dest_ident #ty_generics }; let constructor = self.generate_conversion_constructor( @@ -193,12 +193,19 @@ impl VersionType { fn type_generics(&self) -> syn::Result { let mut generics = self.orig_type.generics.clone(); - if let AssociatedTypeKind::Ref(Some(lifetime)) = &self.kind { - add_lifetime_bound(&mut generics, lifetime); + if let AssociatedTypeKind::Ref(opt_lifetime) = &self.kind { + if let Some(lifetime) = opt_lifetime { + add_lifetime_bound(&mut generics, lifetime); + } + add_trait_where_clause(&mut generics, self.inner_types()?, &[VERSIONIZE_TRAIT_NAME])?; + } else { + add_trait_where_clause( + &mut generics, + self.inner_types()?, + &[VERSIONIZE_OWNED_TRAIT_NAME], + )?; } - add_trait_where_clause(&mut generics, self.inner_types()?, &[VERSIONIZE_TRAIT_NAME])?; - Ok(generics) } @@ -327,13 +334,14 @@ impl VersionType { let unver_ty = field.ty.clone(); let versionize_trait = parse_trait_bound(VERSIONIZE_TRAIT_NAME)?; + let versionize_owned_trait = parse_trait_bound(VERSIONIZE_OWNED_TRAIT_NAME)?; let ty: Type = match &kind { AssociatedTypeKind::Ref(lifetime) => parse_quote! { <#unver_ty as #versionize_trait>::Versioned<#lifetime> }, AssociatedTypeKind::Owned => parse_quote! { - <#unver_ty as #versionize_trait>::VersionedOwned + <#unver_ty as #versionize_owned_trait>::VersionedOwned }, }; @@ -632,22 +640,24 @@ impl VersionType { direction: ConversionDirection, ) -> syn::Result { let versionize_trait: Path = parse_const_str(VERSIONIZE_TRAIT_NAME); + let versionize_owned_trait: Path = parse_const_str(VERSIONIZE_OWNED_TRAIT_NAME); let unversionize_trait: Path = parse_const_str(UNVERSIONIZE_TRAIT_NAME); let field_constructor = match direction { ConversionDirection::OrigToAssociated => { - let param = if is_ref { - field_param - } else { - quote! {&#field_param} - }; - match self.kind { - AssociatedTypeKind::Ref(_) => quote! { - #versionize_trait::versionize(#param) - }, + AssociatedTypeKind::Ref(_) => { + let param = if is_ref { + field_param + } else { + quote! {&#field_param} + }; + + quote! { + #versionize_trait::versionize(#param) + }}, AssociatedTypeKind::Owned => quote! { - #versionize_trait::versionize_owned(#param) + #versionize_owned_trait::versionize_owned(#field_param) }, } } diff --git a/utils/tfhe-versionable-derive/src/versionize_attribute.rs b/utils/tfhe-versionable-derive/src/versionize_attribute.rs index 63c26620b..5accc750c 100644 --- a/utils/tfhe-versionable-derive/src/versionize_attribute.rs +++ b/utils/tfhe-versionable-derive/src/versionize_attribute.rs @@ -3,10 +3,11 @@ use quote::{quote, ToTokens}; use syn::punctuated::Punctuated; use syn::spanned::Spanned; use syn::{ - Attribute, Expr, ExprLit, Ident, Lit, Meta, MetaNameValue, Path, Token, Type, TypeParam, + Attribute, Expr, ExprLit, Ident, Lit, Meta, MetaNameValue, Path, Token, TraitBound, Type, + TypeParam, }; -use crate::{parse_const_str, UNVERSIONIZE_ERROR_NAME}; +use crate::{parse_const_str, UNVERSIONIZE_ERROR_NAME, VERSIONIZE_OWNED_TRAIT_NAME}; /// Name of the attribute used to give arguments to our macros const VERSIONIZE_ATTR_NAME: &str = "versionize"; @@ -184,11 +185,12 @@ impl VersionizeAttribute { } pub(crate) fn versionize_method_body(&self) -> proc_macro2::TokenStream { + let versionize_owned_trait: TraitBound = parse_const_str(VERSIONIZE_OWNED_TRAIT_NAME); self.into .as_ref() .map(|target| { quote! { - #target::from(self.to_owned()).versionize_owned() + #versionize_owned_trait::versionize_owned(#target::from(self.to_owned())) } }) .unwrap_or_else(|| { diff --git a/utils/tfhe-versionable/examples/bounds.rs b/utils/tfhe-versionable/examples/bounds.rs index 18e853b04..dddfdfa08 100644 --- a/utils/tfhe-versionable/examples/bounds.rs +++ b/utils/tfhe-versionable/examples/bounds.rs @@ -3,23 +3,27 @@ use serde::de::DeserializeOwned; use serde::Serialize; -use tfhe_versionable::{Unversionize, UnversionizeError, Versionize, VersionsDispatch}; +use tfhe_versionable::{ + Unversionize, UnversionizeError, Versionize, VersionizeOwned, VersionsDispatch, +}; // Example of a simple struct with a manual Versionize impl that requires a specific bound struct MyStruct { val: T, } -impl> Versionize for MyStruct { +impl Versionize for MyStruct { type Versioned<'vers> = &'vers T where T: 'vers; fn versionize(&self) -> Self::Versioned<'_> { &self.val } +} +impl> VersionizeOwned for MyStruct { type VersionedOwned = T; - fn versionize_owned(&self) -> Self::VersionedOwned { + fn versionize_owned(self) -> Self::VersionedOwned { self.val.to_owned() } } diff --git a/utils/tfhe-versionable/examples/manual_impl.rs b/utils/tfhe-versionable/examples/manual_impl.rs index e80683b2a..4d31e6a94 100644 --- a/utils/tfhe-versionable/examples/manual_impl.rs +++ b/utils/tfhe-versionable/examples/manual_impl.rs @@ -2,7 +2,7 @@ use serde::de::DeserializeOwned; use serde::{Deserialize, Serialize}; -use tfhe_versionable::{Unversionize, UnversionizeError, Upgrade, Versionize}; +use tfhe_versionable::{Unversionize, UnversionizeError, Upgrade, Versionize, VersionizeOwned}; struct MyStruct { attr: T, @@ -30,7 +30,7 @@ struct MyStructVersion<'vers, T: 'vers + Default + Versionize> { } #[derive(Serialize, Deserialize)] -struct MyStructVersionOwned { +struct MyStructVersionOwned { attr: T::VersionedOwned, builtin: u32, } @@ -47,10 +47,12 @@ impl Versionize for MySt }; MyStructVersionsDispatch::V1(ver) } +} +impl VersionizeOwned for MyStruct { type VersionedOwned = MyStructVersionsDispatchOwned; - fn versionize_owned(&self) -> Self::VersionedOwned { + fn versionize_owned(self) -> Self::VersionedOwned { let ver = MyStructVersionOwned { attr: self.attr.versionize_owned(), builtin: self.builtin, @@ -59,7 +61,7 @@ impl Versionize for MySt } } -impl Unversionize +impl Unversionize for MyStruct { fn unversionize(versioned: Self::VersionedOwned) -> Result { @@ -83,7 +85,7 @@ enum MyStructVersionsDispatch<'vers, T: 'vers + Default + Versionize> { } #[derive(Serialize, Deserialize)] -enum MyStructVersionsDispatchOwned { +enum MyStructVersionsDispatchOwned { V0(MyStructV0), V1(MyStructVersionOwned), } diff --git a/utils/tfhe-versionable/src/derived_traits.rs b/utils/tfhe-versionable/src/derived_traits.rs index dedd37b25..84372a712 100644 --- a/utils/tfhe-versionable/src/derived_traits.rs +++ b/utils/tfhe-versionable/src/derived_traits.rs @@ -11,10 +11,7 @@ pub trait Version: Sized { type Ref<'vers>: From<&'vers Self> + Serialize where Self: 'vers; - type Owned: for<'vers> From<&'vers Self> - + TryInto - + DeserializeOwned - + Serialize; + type Owned: From + TryInto + DeserializeOwned + Serialize; } /// This trait is implemented on the dispatch enum for a given type. The dispatch enum @@ -24,7 +21,7 @@ pub trait VersionsDispatch: Sized { type Ref<'vers>: From<&'vers Unversioned> + Serialize where Unversioned: 'vers; - type Owned: for<'vers> From<&'vers Unversioned> + type Owned: From + TryInto + DeserializeOwned + Serialize; diff --git a/utils/tfhe-versionable/src/lib.rs b/utils/tfhe-versionable/src/lib.rs index f7101d09d..db81cf1c5 100644 --- a/utils/tfhe-versionable/src/lib.rs +++ b/utils/tfhe-versionable/src/lib.rs @@ -34,28 +34,32 @@ pub trait Versionize { /// Wraps the object into a versioned enum with a variant for each version. This will /// use references on the underlying types if possible. fn versionize(&self) -> Self::Versioned<'_>; +} +pub trait VersionizeOwned { type VersionedOwned: Serialize + DeserializeOwned; /// Wraps the object into a versioned enum with a variant for each version. This will /// clone the underlying types. - fn versionize_owned(&self) -> Self::VersionedOwned; + fn versionize_owned(self) -> Self::VersionedOwned; } /// This trait is used as a proxy to be more felxible when deriving Versionize for Vec. /// This way, we can chose to skip versioning Vec if T is a native types but still versionize in /// a loop if T is a custom type. /// This is used as a workaround for feature(specialization) and to bypass the orphan rule. -pub trait VersionizeVec: Sized { +pub trait VersionizeSlice: Sized { type VersionedSlice<'vers>: Serialize where Self: 'vers; fn versionize_slice(slice: &[Self]) -> Self::VersionedSlice<'_>; +} +pub trait VersionizeVec: Sized { type VersionedVec: Serialize + DeserializeOwned; - fn versionize_vec(slice: &[Self]) -> Self::VersionedVec; + fn versionize_vec(vec: Vec) -> Self::VersionedVec; } #[derive(Debug)] @@ -117,7 +121,7 @@ impl From for UnversionizeError { /// This trait means that we can convert from a versioned enum into the target type. This trait /// can only be implemented on Owned/static types, whereas `Versionize` can also be implemented /// on reference types. -pub trait Unversionize: Versionize + Sized { +pub trait Unversionize: VersionizeOwned + Sized { /// Creates an object from a versioned enum, and eventually upgrades from previous /// variants. fn unversionize(versioned: Self::VersionedOwned) -> Result; @@ -131,17 +135,19 @@ pub trait UnversionizeVec: VersionizeVec { /// Self or &Self. pub trait NotVersioned: Versionize {} -impl VersionizeVec for T { +impl VersionizeSlice for T { type VersionedSlice<'vers> = &'vers [T] where T: 'vers; fn versionize_slice(slice: &[Self]) -> Self::VersionedSlice<'_> { slice } +} +impl VersionizeVec for T { type VersionedVec = Vec; - fn versionize_vec(slice: &[Self]) -> Self::VersionedVec { - slice.to_vec() + fn versionize_vec(vec: Vec) -> Self::VersionedVec { + vec } } @@ -159,14 +165,15 @@ macro_rules! impl_scalar_versionize { impl Versionize for $t { type Versioned<'vers> = $t; - type VersionedOwned = $t; - fn versionize(&self) -> Self::Versioned<'_> { *self } + } - fn versionize_owned(&self) -> Self::VersionedOwned { - *self + impl VersionizeOwned for $t { + type VersionedOwned = $t; + fn versionize_owned(self) -> Self::VersionedOwned { + self } } @@ -208,11 +215,13 @@ impl Versionize for Box { fn versionize(&self) -> Self::Versioned<'_> { self.as_ref().versionize() } +} +impl VersionizeOwned for Box { type VersionedOwned = Box; - fn versionize_owned(&self) -> Self::VersionedOwned { - Box::new(T::versionize_owned(self)) + fn versionize_owned(self) -> Self::VersionedOwned { + Box::new(T::versionize_owned(*self)) } } @@ -222,31 +231,35 @@ impl Unversionize for Box { } } -impl Versionize for Vec { +impl Versionize for Vec { type Versioned<'vers> = T::VersionedSlice<'vers> where T: 'vers; fn versionize(&self) -> Self::Versioned<'_> { T::versionize_slice(self) } +} +impl VersionizeOwned for Vec { type VersionedOwned = T::VersionedVec; - fn versionize_owned(&self) -> Self::VersionedOwned { + fn versionize_owned(self) -> Self::VersionedOwned { T::versionize_vec(self) } } -impl Versionize for [T] { +impl Versionize for [T] { type Versioned<'vers> = T::VersionedSlice<'vers> where T: 'vers; fn versionize(&self) -> Self::Versioned<'_> { T::versionize_slice(self) } +} +impl VersionizeOwned for &[T] { type VersionedOwned = T::VersionedVec; - fn versionize_owned(&self) -> Self::VersionedOwned { - T::versionize_vec(self) + fn versionize_owned(self) -> Self::VersionedOwned { + T::versionize_vec(self.to_vec()) } } @@ -262,11 +275,13 @@ impl Versionize for String { fn versionize(&self) -> Self::Versioned<'_> { self.as_ref() } +} +impl VersionizeOwned for String { type VersionedOwned = Self; - fn versionize_owned(&self) -> Self::VersionedOwned { - self.clone() + fn versionize_owned(self) -> Self::VersionedOwned { + self } } @@ -284,10 +299,12 @@ impl Versionize for str { fn versionize(&self) -> Self::Versioned<'_> { self } +} +impl VersionizeOwned for &str { type VersionedOwned = String; - fn versionize_owned(&self) -> Self::VersionedOwned { + fn versionize_owned(self) -> Self::VersionedOwned { self.to_string() } } @@ -300,11 +317,13 @@ impl Versionize for Option { fn versionize(&self) -> Self::Versioned<'_> { self.as_ref().map(|val| val.versionize()) } +} +impl VersionizeOwned for Option { type VersionedOwned = Option; - fn versionize_owned(&self) -> Self::VersionedOwned { - self.as_ref().map(|val| val.versionize_owned()) + fn versionize_owned(self) -> Self::VersionedOwned { + self.map(|val| val.versionize_owned()) } } @@ -324,11 +343,13 @@ impl Versionize for PhantomData { fn versionize(&self) -> Self::Versioned<'_> { *self } +} +impl VersionizeOwned for PhantomData { type VersionedOwned = Self; - fn versionize_owned(&self) -> Self::VersionedOwned { - *self + fn versionize_owned(self) -> Self::VersionedOwned { + self } } @@ -349,10 +370,12 @@ impl Versionize for Complex { im: self.im.versionize(), } } +} +impl VersionizeOwned for Complex { type VersionedOwned = Complex; - fn versionize_owned(&self) -> Self::VersionedOwned { + fn versionize_owned(self) -> Self::VersionedOwned { Complex { re: self.re.versionize_owned(), im: self.im.versionize_owned(), @@ -377,16 +400,18 @@ impl Versionize for ABox { fn versionize(&self) -> Self::Versioned<'_> { self.as_ref().versionize() } +} +impl VersionizeOwned for ABox { // Alignment doesn't matter for versioned types type VersionedOwned = Box; - fn versionize_owned(&self) -> Self::VersionedOwned { - Box::new(T::versionize_owned(self)) + fn versionize_owned(self) -> Self::VersionedOwned { + Box::new(T::versionize_owned(*self)) } } -impl Unversionize for ABox +impl Unversionize for ABox where T::VersionedOwned: Clone, { @@ -395,22 +420,24 @@ where } } -impl Versionize for AVec { +impl Versionize for AVec { type Versioned<'vers> = T::VersionedSlice<'vers> where T: 'vers; fn versionize(&self) -> Self::Versioned<'_> { T::versionize_slice(self) } +} - // Alignment doesn't matter for versioned types +// Alignment doesn't matter for versioned types +impl VersionizeOwned for AVec { type VersionedOwned = T::VersionedVec; - fn versionize_owned(&self) -> Self::VersionedOwned { - T::versionize_vec(self) + fn versionize_owned(self) -> Self::VersionedOwned { + T::versionize_vec(self.to_vec()) } } -impl Unversionize for AVec { +impl Unversionize for AVec { fn unversionize(versioned: Self::VersionedOwned) -> Result { T::unversionize_vec(versioned).map(|unver| AVec::from_iter(0, unver)) } @@ -424,10 +451,12 @@ impl Versionize for (T, U) { fn versionize(&self) -> Self::Versioned<'_> { (self.0.versionize(), self.1.versionize()) } +} +impl VersionizeOwned for (T, U) { type VersionedOwned = (T::VersionedOwned, U::VersionedOwned); - fn versionize_owned(&self) -> Self::VersionedOwned { + fn versionize_owned(self) -> Self::VersionedOwned { (self.0.versionize_owned(), self.1.versionize_owned()) } }