refactor(versionable)!: fix signature of versionize_owned

BREAKING CHANGE: `versionize_owned` now takes its argument by value.
This commit is contained in:
Nicolas Sarlin
2024-06-12 13:35:33 +02:00
committed by Nicolas Sarlin
parent ac37c3883d
commit e9051419cd
20 changed files with 240 additions and 148 deletions

View File

@@ -34,10 +34,10 @@ pub enum FourierLweMultiBitBootstrapKeyVersionedOwned {
V0(FourierLweMultiBitBootstrapKeyVersionOwned),
}
impl<C: Container<Element = c64>> From<&FourierLweMultiBitBootstrapKey<C>>
impl<C: Container<Element = c64>> From<FourierLweMultiBitBootstrapKey<C>>
for FourierLweMultiBitBootstrapKeyVersionedOwned
{
fn from(value: &FourierLweMultiBitBootstrapKey<C>) -> Self {
fn from(value: FourierLweMultiBitBootstrapKey<C>) -> Self {
Self::V0(value.into())
}
}

View File

@@ -33,10 +33,10 @@ pub enum FourierPolynomialListVersionedOwned {
V0(FourierPolynomialList<ABox<[c64]>>),
}
impl<C: Container<Element = c64>> From<&FourierPolynomialList<C>>
impl<C: Container<Element = c64>> From<FourierPolynomialList<C>>
for FourierPolynomialListVersionedOwned
{
fn from(value: &FourierPolynomialList<C>) -> Self {
fn from(value: FourierPolynomialList<C>) -> Self {
let owned_poly = FourierPolynomialList {
data: ABox::collect(value.data.as_ref().iter().copied()),
polynomial_size: value.polynomial_size,
@@ -76,10 +76,10 @@ pub enum FourierLweBootstrapKeyVersionedOwned {
V0(FourierLweBootstrapKeyVersionOwned),
}
impl<C: Container<Element = c64>> From<&FourierLweBootstrapKey<C>>
impl<C: Container<Element = c64>> From<FourierLweBootstrapKey<C>>
for FourierLweBootstrapKeyVersionedOwned
{
fn from(value: &FourierLweBootstrapKey<C>) -> Self {
fn from(value: FourierLweBootstrapKey<C>) -> Self {
Self::V0(value.into())
}
}

View File

@@ -31,7 +31,7 @@ pub mod serialization_proxy {
}
pub(crate) use serialization_proxy::*;
use tfhe_versionable::{Unversionize, Versionize};
use tfhe_versionable::{Unversionize, Versionize, VersionizeOwned};
#[derive(PartialEq, Eq, Debug, Clone, Copy, Serialize, Deserialize)]
/// New type to manage seeds used for compressed/seeded types.
@@ -46,11 +46,13 @@ impl Versionize for CompressionSeed {
fn versionize(&self) -> Self::Versioned<'_> {
self.into()
}
}
impl VersionizeOwned for CompressionSeed {
type VersionedOwned = CompressionSeedVersionedOwned;
fn versionize_owned(&self) -> Self::VersionedOwned {
(*self).into()
fn versionize_owned(self) -> Self::VersionedOwned {
self.into()
}
}

View File

@@ -15,7 +15,7 @@ use crate::core_crypto::entities::*;
use crate::core_crypto::fft_impl::fft64::math::fft::FourierPolynomialList;
use aligned_vec::{avec, ABox};
use concrete_fft::c64;
use tfhe_versionable::{Unversionize, UnversionizeError, Versionize};
use tfhe_versionable::{Unversionize, UnversionizeError, Versionize, VersionizeOwned};
#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize, Versionize)]
#[versionize(LweMultiBitBootstrapKeyVersions)]
@@ -415,11 +415,11 @@ pub struct FourierLweMultiBitBootstrapKeyVersion<'vers> {
#[derive(serde::Serialize, serde::Deserialize)]
pub struct FourierLweMultiBitBootstrapKeyVersionOwned {
fourier: FourierPolynomialListVersionedOwned,
input_lwe_dimension: <LweDimension as Versionize>::VersionedOwned,
glwe_size: <GlweSize as Versionize>::VersionedOwned,
decomposition_base_log: <DecompositionBaseLog as Versionize>::VersionedOwned,
decomposition_level_count: <DecompositionLevelCount as Versionize>::VersionedOwned,
grouping_factor: <LweBskGroupingFactor as Versionize>::VersionedOwned,
input_lwe_dimension: <LweDimension as VersionizeOwned>::VersionedOwned,
glwe_size: <GlweSize as VersionizeOwned>::VersionedOwned,
decomposition_base_log: <DecompositionBaseLog as VersionizeOwned>::VersionedOwned,
decomposition_level_count: <DecompositionLevelCount as VersionizeOwned>::VersionedOwned,
grouping_factor: <LweBskGroupingFactor as VersionizeOwned>::VersionedOwned,
}
impl<'vers, C: Container<Element = c64>> From<&'vers FourierLweMultiBitBootstrapKey<C>>
@@ -437,10 +437,10 @@ impl<'vers, C: Container<Element = c64>> From<&'vers FourierLweMultiBitBootstrap
}
}
impl<C: Container<Element = c64>> From<&FourierLweMultiBitBootstrapKey<C>>
impl<C: Container<Element = c64>> From<FourierLweMultiBitBootstrapKey<C>>
for FourierLweMultiBitBootstrapKeyVersionOwned
{
fn from(value: &FourierLweMultiBitBootstrapKey<C>) -> Self {
fn from(value: FourierLweMultiBitBootstrapKey<C>) -> Self {
Self {
fourier: value.fourier.versionize_owned(),
input_lwe_dimension: value.input_lwe_dimension.versionize_owned(),
@@ -478,10 +478,12 @@ impl<C: Container<Element = c64>> Versionize for FourierLweMultiBitBootstrapKey<
fn versionize(&self) -> Self::Versioned<'_> {
self.into()
}
}
impl<C: Container<Element = c64>> VersionizeOwned for FourierLweMultiBitBootstrapKey<C> {
type VersionedOwned = FourierLweMultiBitBootstrapKeyVersionedOwned;
fn versionize_owned(&self) -> Self::VersionedOwned {
fn versionize_owned(self) -> Self::VersionedOwned {
self.into()
}
}

View File

@@ -24,7 +24,7 @@ use crate::core_crypto::prelude::ContainerMut;
use aligned_vec::{avec, ABox, CACHELINE_ALIGN};
use concrete_fft::c64;
use dyn_stack::{PodStack, ReborrowMut, SizeOverflow, StackReq};
use tfhe_versionable::{Unversionize, UnversionizeError, Versionize};
use tfhe_versionable::{Unversionize, UnversionizeError, Versionize, VersionizeOwned};
#[derive(Clone, Copy, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[serde(bound(deserialize = "C: IntoContainerOwned"))]
@@ -48,10 +48,10 @@ pub struct FourierLweBootstrapKeyVersion<'vers> {
#[derive(serde::Serialize, serde::Deserialize)]
pub struct FourierLweBootstrapKeyVersionOwned {
fourier: FourierPolynomialListVersionedOwned,
input_lwe_dimension: <LweDimension as Versionize>::VersionedOwned,
glwe_size: <GlweSize as Versionize>::VersionedOwned,
decomposition_base_log: <DecompositionBaseLog as Versionize>::VersionedOwned,
decomposition_level_count: <DecompositionLevelCount as Versionize>::VersionedOwned,
input_lwe_dimension: <LweDimension as VersionizeOwned>::VersionedOwned,
glwe_size: <GlweSize as VersionizeOwned>::VersionedOwned,
decomposition_base_log: <DecompositionBaseLog as VersionizeOwned>::VersionedOwned,
decomposition_level_count: <DecompositionLevelCount as VersionizeOwned>::VersionedOwned,
}
impl<'vers, C: Container<Element = c64>> From<&'vers FourierLweBootstrapKey<C>>
@@ -68,10 +68,10 @@ impl<'vers, C: Container<Element = c64>> From<&'vers FourierLweBootstrapKey<C>>
}
}
impl<C: Container<Element = c64>> From<&FourierLweBootstrapKey<C>>
impl<C: Container<Element = c64>> From<FourierLweBootstrapKey<C>>
for FourierLweBootstrapKeyVersionOwned
{
fn from(value: &FourierLweBootstrapKey<C>) -> Self {
fn from(value: FourierLweBootstrapKey<C>) -> Self {
Self {
fourier: value.fourier.versionize_owned(),
input_lwe_dimension: value.input_lwe_dimension.versionize_owned(),
@@ -107,10 +107,12 @@ impl<C: Container<Element = c64>> Versionize for FourierLweBootstrapKey<C> {
fn versionize(&self) -> Self::Versioned<'_> {
self.into()
}
}
impl<C: Container<Element = c64>> VersionizeOwned for FourierLweBootstrapKey<C> {
type VersionedOwned = FourierLweBootstrapKeyVersionedOwned;
fn versionize_owned(&self) -> Self::VersionedOwned {
fn versionize_owned(self) -> Self::VersionedOwned {
self.into()
}
}

View File

@@ -20,7 +20,7 @@ use std::mem::{align_of, size_of};
use std::sync::{Arc, OnceLock, RwLock};
#[cfg(not(feature = "experimental-force_fft_algo_dif4"))]
use std::time::Duration;
use tfhe_versionable::{Unversionize, UnversionizeError, Versionize};
use tfhe_versionable::{Unversionize, UnversionizeError, Versionize, VersionizeOwned};
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
mod x86;
@@ -595,10 +595,12 @@ impl<C: Container<Element = c64>> Versionize for FourierPolynomialList<C> {
fn versionize(&self) -> Self::Versioned<'_> {
self.into()
}
}
impl<C: Container<Element = c64>> VersionizeOwned for FourierPolynomialList<C> {
type VersionedOwned = FourierPolynomialListVersionedOwned;
fn versionize_owned(&self) -> Self::VersionedOwned {
fn versionize_owned(self) -> Self::VersionedOwned {
self.into()
}
}

View File

@@ -6,7 +6,7 @@ use crate::high_level_api::global_state::with_thread_local_cuda_streams;
use crate::integer::BooleanBlock;
use crate::Device;
use serde::{Deserializer, Serializer};
use tfhe_versionable::{Unversionize, UnversionizeError, Versionize};
use tfhe_versionable::{Unversionize, UnversionizeError, Versionize, VersionizeOwned};
/// Enum that manages the current inner representation of a boolean.
pub(in crate::high_level_api) enum InnerBoolean {
@@ -53,7 +53,7 @@ impl<'de> serde::Deserialize<'de> for InnerBoolean {
// Only CPU data are serialized so we only versionize the CPU type.
#[derive(serde::Serialize, serde::Deserialize)]
pub(crate) struct InnerBooleanVersionOwned(
<crate::integer::BooleanBlock as Versionize>::VersionedOwned,
<crate::integer::BooleanBlock as VersionizeOwned>::VersionedOwned,
);
impl Versionize for InnerBoolean {
@@ -61,15 +61,18 @@ impl Versionize for InnerBoolean {
fn versionize(&self) -> Self::Versioned<'_> {
let data = self.on_cpu();
let versioned = data.versionize_owned();
let versioned = data.into_owned().versionize_owned();
InnerBooleanVersionedOwned::V0(InnerBooleanVersionOwned(versioned))
}
}
impl VersionizeOwned for InnerBoolean {
type VersionedOwned = InnerBooleanVersionedOwned;
fn versionize_owned(&self) -> Self::VersionedOwned {
fn versionize_owned(self) -> Self::VersionedOwned {
let cpu_data = self.on_cpu();
InnerBooleanVersionedOwned::V0(InnerBooleanVersionOwned(cpu_data.versionize_owned()))
InnerBooleanVersionedOwned::V0(InnerBooleanVersionOwned(
cpu_data.into_owned().versionize_owned(),
))
}
}

View File

@@ -9,7 +9,7 @@ use crate::integer::gpu::ciphertext::CudaIntegerRadixCiphertext;
use crate::integer::gpu::ciphertext::CudaSignedRadixCiphertext;
use crate::Device;
use serde::{Deserializer, Serializer};
use tfhe_versionable::{Unversionize, UnversionizeError, Versionize};
use tfhe_versionable::{Unversionize, UnversionizeError, Versionize, VersionizeOwned};
pub(crate) enum RadixCiphertext {
Cpu(crate::integer::SignedRadixCiphertext),
@@ -68,7 +68,7 @@ impl<'de> serde::Deserialize<'de> for RadixCiphertext {
// Only CPU data are serialized so we only versionize the CPU type.
#[derive(serde::Serialize, serde::Deserialize)]
pub(crate) struct RadixCiphertextVersionOwned(
<crate::integer::SignedRadixCiphertext as Versionize>::VersionedOwned,
<crate::integer::SignedRadixCiphertext as VersionizeOwned>::VersionedOwned,
);
impl Versionize for RadixCiphertext {
@@ -76,16 +76,18 @@ impl Versionize for RadixCiphertext {
fn versionize(&self) -> Self::Versioned<'_> {
let data = self.on_cpu();
let versioned = data.versionize_owned();
let versioned = data.into_owned().versionize_owned();
SignedRadixCiphertextVersionedOwned::V0(RadixCiphertextVersionOwned(versioned))
}
}
impl VersionizeOwned for RadixCiphertext {
type VersionedOwned = SignedRadixCiphertextVersionedOwned;
fn versionize_owned(&self) -> Self::VersionedOwned {
fn versionize_owned(self) -> Self::VersionedOwned {
let cpu_data = self.on_cpu();
SignedRadixCiphertextVersionedOwned::V0(RadixCiphertextVersionOwned(
cpu_data.versionize_owned(),
cpu_data.into_owned().versionize_owned(),
))
}
}

View File

@@ -7,7 +7,7 @@ use crate::high_level_api::global_state::with_thread_local_cuda_streams;
use crate::integer::gpu::ciphertext::CudaIntegerRadixCiphertext;
use crate::Device;
use serde::{Deserializer, Serializer};
use tfhe_versionable::{Unversionize, UnversionizeError, Versionize};
use tfhe_versionable::{Unversionize, UnversionizeError, Versionize, VersionizeOwned};
pub(crate) enum RadixCiphertext {
Cpu(crate::integer::RadixCiphertext),
@@ -64,7 +64,7 @@ impl<'de> serde::Deserialize<'de> for RadixCiphertext {
// Only CPU data are serialized so we only version the CPU type.
#[derive(serde::Serialize, serde::Deserialize)]
pub(crate) struct RadixCiphertextVersionOwned(
<crate::integer::RadixCiphertext as Versionize>::VersionedOwned,
<crate::integer::RadixCiphertext as VersionizeOwned>::VersionedOwned,
);
impl Versionize for RadixCiphertext {
@@ -72,16 +72,18 @@ impl Versionize for RadixCiphertext {
fn versionize(&self) -> Self::Versioned<'_> {
let data = self.on_cpu();
let versioned = data.versionize_owned();
let versioned = data.into_owned().versionize_owned();
UnsignedRadixCiphertextVersionedOwned::V0(RadixCiphertextVersionOwned(versioned))
}
}
impl VersionizeOwned for RadixCiphertext {
type VersionedOwned = UnsignedRadixCiphertextVersionedOwned;
fn versionize_owned(&self) -> Self::VersionedOwned {
fn versionize_owned(self) -> Self::VersionedOwned {
let cpu_data = self.on_cpu();
UnsignedRadixCiphertextVersionedOwned::V0(RadixCiphertextVersionOwned(
cpu_data.versionize_owned(),
cpu_data.into_owned().versionize_owned(),
))
}
}

View File

@@ -1,4 +1,4 @@
use tfhe_versionable::{Unversionize, Versionize};
use tfhe_versionable::{Unversionize, Versionize, VersionizeOwned};
use crate::backward_compatibility::keys::{
CompressedServerKeyVersions, ServerKeyVersioned, ServerKeyVersionedOwned,
@@ -141,7 +141,7 @@ pub struct ServerKeyVersion<'vers> {
#[derive(serde::Serialize, serde::Deserialize)]
pub struct ServerKeyVersionOwned {
pub(crate) integer_key: <IntegerServerKey as Versionize>::VersionedOwned,
pub(crate) integer_key: <IntegerServerKey as VersionizeOwned>::VersionedOwned,
}
impl Versionize for ServerKey {
@@ -152,12 +152,14 @@ impl Versionize for ServerKey {
integer_key: self.key.versionize(),
})
}
}
impl VersionizeOwned for ServerKey {
type VersionedOwned = ServerKeyVersionedOwned;
fn versionize_owned(&self) -> Self::VersionedOwned {
fn versionize_owned(self) -> Self::VersionedOwned {
ServerKeyVersionedOwned::V0(ServerKeyVersionOwned {
integer_key: self.key.versionize_owned(),
integer_key: (*self.key).clone().versionize_owned(),
})
}
}

View File

@@ -23,10 +23,10 @@ pub enum SerializableShortintBootstrappingKeyVersionedOwned {
V0(SerializableShortintBootstrappingKeyVersionOwned),
}
impl<C: Container<Element = concrete_fft::c64>> From<&SerializableShortintBootstrappingKey<C>>
impl<C: Container<Element = concrete_fft::c64>> From<SerializableShortintBootstrappingKey<C>>
for SerializableShortintBootstrappingKeyVersionedOwned
{
fn from(value: &SerializableShortintBootstrappingKey<C>) -> Self {
fn from(value: SerializableShortintBootstrappingKey<C>) -> Self {
Self::V0(value.into())
}
}

View File

@@ -19,14 +19,13 @@ mod shift;
mod sub;
pub mod compressed;
use ::tfhe_versionable::{Unversionize, UnversionizeError, Versionize};
use ::tfhe_versionable::{Unversionize, UnversionizeError, Versionize, VersionizeOwned};
use aligned_vec::ABox;
pub use bivariate_pbs::{
BivariateLookupTableMutView, BivariateLookupTableOwned, BivariateLookupTableView,
};
pub use compressed::{CompressedServerKey, ShortintCompressedBootstrappingKey};
pub(crate) use scalar_mul::unchecked_scalar_mul_assign;
use serde::de::DeserializeOwned;
#[cfg(test)]
pub(crate) mod tests;
@@ -192,10 +191,10 @@ impl<'vers, C: Container<Element = concrete_fft::c64>>
}
}
impl<C: Container<Element = concrete_fft::c64>> From<&SerializableShortintBootstrappingKey<C>>
impl<C: Container<Element = concrete_fft::c64>> From<SerializableShortintBootstrappingKey<C>>
for SerializableShortintBootstrappingKeyVersionOwned
{
fn from(value: &SerializableShortintBootstrappingKey<C>) -> Self {
fn from(value: SerializableShortintBootstrappingKey<C>) -> Self {
match value {
SerializableShortintBootstrappingKey::Classic(bsk) => {
Self::Classic(bsk.versionize_owned())
@@ -205,7 +204,7 @@ impl<C: Container<Element = concrete_fft::c64>> From<&SerializableShortintBootst
deterministic_execution,
} => Self::MultiBit {
fourier_bsk: fourier_bsk.versionize_owned(),
deterministic_execution: *deterministic_execution,
deterministic_execution,
},
}
}
@@ -243,15 +242,19 @@ impl<C: Container<Element = concrete_fft::c64>> Versionize
fn versionize(&self) -> Self::Versioned<'_> {
self.into()
}
}
impl<C: Container<Element = concrete_fft::c64>> VersionizeOwned
for SerializableShortintBootstrappingKey<C>
{
type VersionedOwned = SerializableShortintBootstrappingKeyVersionedOwned;
fn versionize_owned(&self) -> Self::VersionedOwned {
fn versionize_owned(self) -> Self::VersionedOwned {
self.into()
}
}
impl<C: IntoContainerOwned<Element = concrete_fft::c64> + Serialize + DeserializeOwned> Unversionize
impl<C: IntoContainerOwned<Element = concrete_fft::c64>> Unversionize
for SerializableShortintBootstrappingKey<C>
{
fn unversionize(versioned: Self::VersionedOwned) -> Result<Self, UnversionizeError> {
@@ -320,12 +323,13 @@ impl Versionize for ShortintBootstrappingKey {
fn versionize(&self) -> Self::Versioned<'_> {
SerializableShortintBootstrappingKey::from(self).versionize_owned()
}
}
type VersionedOwned = <SerializableShortintBootstrappingKey<ABox<[concrete_fft::c64]>> as Versionize>::VersionedOwned;
impl VersionizeOwned for ShortintBootstrappingKey {
type VersionedOwned = <SerializableShortintBootstrappingKey<ABox<[concrete_fft::c64]>> as VersionizeOwned>::VersionedOwned;
fn versionize_owned(&self) -> Self::VersionedOwned {
todo!()
//SerializableShortintBootstrappingKey::from(self).versionize_owned()
fn versionize_owned(self) -> Self::VersionedOwned {
SerializableShortintBootstrappingKey::from(self).versionize_owned()
}
}

View File

@@ -145,7 +145,7 @@ impl AssociatedType for DispatchType {
// Wraps the highest version into the dispatch enum
let src_type = self.latest_version_type()?;
let src = parse_quote! { &#src_type };
let src = parse_quote! { #src_type };
let dest_ident = self.ident();
let dest = parse_quote! { #dest_ident #ty_generics };
let constructor = self.generate_conversion_constructor_ref("value")?;

View File

@@ -34,8 +34,10 @@ pub(crate) const LIFETIME_NAME: &str = "'vers";
pub(crate) const VERSION_TRAIT_NAME: &str = crate_full_path!("Version");
pub(crate) const DISPATCH_TRAIT_NAME: &str = crate_full_path!("VersionsDispatch");
pub(crate) const VERSIONIZE_TRAIT_NAME: &str = crate_full_path!("Versionize");
pub(crate) const UNVERSIONIZE_TRAIT_NAME: &str = crate_full_path!("Unversionize");
pub(crate) const VERSIONIZE_OWNED_TRAIT_NAME: &str = crate_full_path!("VersionizeOwned");
pub(crate) const VERSIONIZE_SLICE_TRAIT_NAME: &str = crate_full_path!("VersionizeSlice");
pub(crate) const VERSIONIZE_VEC_TRAIT_NAME: &str = crate_full_path!("VersionizeVec");
pub(crate) const UNVERSIONIZE_TRAIT_NAME: &str = crate_full_path!("Unversionize");
pub(crate) const UNVERSIONIZE_VEC_TRAIT_NAME: &str = crate_full_path!("UnversionizeVec");
pub(crate) const UNVERSIONIZE_ERROR_NAME: &str = crate_full_path!("UnversionizeError");
@@ -103,6 +105,7 @@ pub fn derive_versions_dispatch(input: TokenStream) -> TokenStream {
&[
VERSIONIZE_TRAIT_NAME,
VERSIONIZE_VEC_TRAIT_NAME,
VERSIONIZE_SLICE_TRAIT_NAME,
UNVERSIONIZE_TRAIT_NAME,
UNVERSIONIZE_VEC_TRAIT_NAME,
SERIALIZE_TRAIT_NAME,
@@ -187,8 +190,10 @@ pub fn derive_versionize(input: TokenStream) -> TokenStream {
));
let versionize_trait: Path = parse_const_str(VERSIONIZE_TRAIT_NAME);
let versionize_owned_trait: Path = parse_const_str(VERSIONIZE_OWNED_TRAIT_NAME);
let unversionize_trait: Path = parse_const_str(UNVERSIONIZE_TRAIT_NAME);
let versionize_vec_trait: Path = parse_const_str(VERSIONIZE_VEC_TRAIT_NAME);
let versionize_slice_trait: Path = parse_const_str(VERSIONIZE_SLICE_TRAIT_NAME);
let unversionize_vec_trait: Path = parse_const_str(UNVERSIONIZE_VEC_TRAIT_NAME);
let mut versionize_generics = trait_generics.clone();
@@ -203,10 +208,16 @@ pub fn derive_versionize(input: TokenStream) -> TokenStream {
}
// Add Generics for the `VersionizeVec` and `UnversionizeVec` traits
let mut versionize_slice_generics = versionize_generics.clone();
syn_unwrap!(add_trait_bound(
&mut versionize_slice_generics,
VERSIONIZE_TRAIT_NAME
));
let mut versionize_vec_generics = versionize_generics.clone();
syn_unwrap!(add_trait_bound(
&mut versionize_vec_generics,
VERSIONIZE_TRAIT_NAME
VERSIONIZE_OWNED_TRAIT_NAME
));
let mut unversionize_vec_generics = unversionize_generics.clone();
syn_unwrap!(add_trait_bound(
@@ -221,6 +232,8 @@ pub fn derive_versionize(input: TokenStream) -> TokenStream {
let (unversionize_impl_generics, _, unversionize_where_clause) =
unversionize_generics.split_for_impl();
let (versionize_slice_impl_generics, _, versionize_slice_where_clause) =
versionize_slice_generics.split_for_impl();
let (versionize_vec_impl_generics, _, versionize_vec_where_clause) =
versionize_vec_generics.split_for_impl();
let (unversionize_vec_impl_generics, _, unversionize_vec_where_clause) =
@@ -253,15 +266,19 @@ pub fn derive_versionize(input: TokenStream) -> TokenStream {
fn versionize(&self) -> Self::Versioned<'_> {
#versionize_body
}
}
fn versionize_owned(&self) -> Self::VersionedOwned {
#versionize_body
}
#[automatically_derived]
impl #versionize_impl_generics #versionize_owned_trait for #input_ident #ty_generics
#versionize_where_clause
{
type VersionedOwned =
<#dispatch_enum_path #dispatch_generics as
#dispatch_trait<#dispatch_target>>::Owned #owned_where_clause;
fn versionize_owned(self) -> Self::VersionedOwned {
#versionize_body
}
}
#[automatically_derived]
@@ -274,19 +291,24 @@ pub fn derive_versionize(input: TokenStream) -> TokenStream {
}
#[automatically_derived]
impl #versionize_vec_impl_generics #versionize_vec_trait for #input_ident #ty_generics
#versionize_vec_where_clause
impl #versionize_slice_impl_generics #versionize_slice_trait for #input_ident #ty_generics
#versionize_slice_where_clause
{
type VersionedSlice<#lifetime> = Vec<<Self as #versionize_trait>::Versioned<#lifetime>> #ref_where_clause;
fn versionize_slice(slice: &[Self]) -> Self::VersionedSlice<'_> {
slice.iter().map(|val| val.versionize()).collect()
slice.iter().map(|val| #versionize_trait::versionize(val)).collect()
}
}
type VersionedVec = Vec<<Self as #versionize_trait>::VersionedOwned> #owned_where_clause;
impl #versionize_vec_impl_generics #versionize_vec_trait for #input_ident #ty_generics
#versionize_vec_where_clause
{
fn versionize_vec(slice: &[Self]) -> Self::VersionedVec {
slice.iter().map(|val| val.versionize_owned()).collect()
type VersionedVec = Vec<<Self as #versionize_owned_trait>::VersionedOwned> #owned_where_clause;
fn versionize_vec(vec: Vec<Self>) -> Self::VersionedVec {
vec.into_iter().map(|val| #versionize_owned_trait::versionize_owned(val)).collect()
}
}
@@ -321,6 +343,7 @@ pub fn derive_not_versioned(input: TokenStream) -> TokenStream {
let input_ident = &input.ident;
let versionize_trait: Path = parse_const_str(VERSIONIZE_TRAIT_NAME);
let versionize_owned_trait: Path = parse_const_str(VERSIONIZE_OWNED_TRAIT_NAME);
let unversionize_trait: Path = parse_const_str(UNVERSIONIZE_TRAIT_NAME);
let unversionize_error: Path = parse_const_str(UNVERSIONIZE_ERROR_NAME);
let lifetime = Lifetime::new(LIFETIME_NAME, Span::call_site());
@@ -329,14 +352,18 @@ pub fn derive_not_versioned(input: TokenStream) -> TokenStream {
#[automatically_derived]
impl #impl_generics #versionize_trait for #input_ident #ty_generics #where_clause {
type Versioned<#lifetime> = &#lifetime Self;
type VersionedOwned = Self;
fn versionize(&self) -> Self::Versioned<'_> {
self
}
}
fn versionize_owned(&self) -> Self::VersionedOwned {
self.clone()
#[automatically_derived]
impl #impl_generics #versionize_owned_trait for #input_ident #ty_generics #where_clause {
type VersionedOwned = Self;
fn versionize_owned(self) -> Self::VersionedOwned {
self
}
}

View File

@@ -6,9 +6,9 @@ use syn::punctuated::Punctuated;
use syn::spanned::Spanned;
use syn::token::Comma;
use syn::{
parse_quote, Attribute, Data, DataEnum, DataStruct, DataUnion, DeriveInput, Field, Fields,
FieldsNamed, FieldsUnnamed, Generics, Ident, Item, ItemEnum, ItemImpl, ItemStruct, ItemUnion,
Lifetime, Path, Type, Variant,
parse_quote, Data, DataEnum, DataStruct, DataUnion, DeriveInput, Field, Fields, FieldsNamed,
FieldsUnnamed, Generics, Ident, Item, ItemEnum, ItemImpl, ItemStruct, ItemUnion, Lifetime,
Path, Type, Variant,
};
use crate::associated::{
@@ -18,7 +18,7 @@ use crate::associated::{
use crate::{
add_lifetime_bound, add_trait_where_clause, parse_const_str, parse_trait_bound,
punctuated_from_iter_result, LIFETIME_NAME, UNVERSIONIZE_ERROR_NAME, UNVERSIONIZE_TRAIT_NAME,
VERSIONIZE_TRAIT_NAME,
VERSIONIZE_OWNED_TRAIT_NAME, VERSIONIZE_TRAIT_NAME,
};
/// The types generated for a specific version of a given exposed type. These types are identical to
@@ -132,7 +132,7 @@ impl AssociatedType for VersionType {
let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
let src_ident = self.orig_type.ident.clone();
let src = parse_quote! { &#src_ident #orig_generics };
let src = parse_quote! { #src_ident #orig_generics };
let dest_ident = self.ident();
let dest = parse_quote! { #dest_ident #ty_generics };
let constructor = self.generate_conversion_constructor(
@@ -193,12 +193,19 @@ impl VersionType {
fn type_generics(&self) -> syn::Result<Generics> {
let mut generics = self.orig_type.generics.clone();
if let AssociatedTypeKind::Ref(Some(lifetime)) = &self.kind {
add_lifetime_bound(&mut generics, lifetime);
if let AssociatedTypeKind::Ref(opt_lifetime) = &self.kind {
if let Some(lifetime) = opt_lifetime {
add_lifetime_bound(&mut generics, lifetime);
}
add_trait_where_clause(&mut generics, self.inner_types()?, &[VERSIONIZE_TRAIT_NAME])?;
} else {
add_trait_where_clause(
&mut generics,
self.inner_types()?,
&[VERSIONIZE_OWNED_TRAIT_NAME],
)?;
}
add_trait_where_clause(&mut generics, self.inner_types()?, &[VERSIONIZE_TRAIT_NAME])?;
Ok(generics)
}
@@ -327,13 +334,14 @@ impl VersionType {
let unver_ty = field.ty.clone();
let versionize_trait = parse_trait_bound(VERSIONIZE_TRAIT_NAME)?;
let versionize_owned_trait = parse_trait_bound(VERSIONIZE_OWNED_TRAIT_NAME)?;
let ty: Type = match &kind {
AssociatedTypeKind::Ref(lifetime) => parse_quote! {
<#unver_ty as #versionize_trait>::Versioned<#lifetime>
},
AssociatedTypeKind::Owned => parse_quote! {
<#unver_ty as #versionize_trait>::VersionedOwned
<#unver_ty as #versionize_owned_trait>::VersionedOwned
},
};
@@ -632,22 +640,24 @@ impl VersionType {
direction: ConversionDirection,
) -> syn::Result<TokenStream> {
let versionize_trait: Path = parse_const_str(VERSIONIZE_TRAIT_NAME);
let versionize_owned_trait: Path = parse_const_str(VERSIONIZE_OWNED_TRAIT_NAME);
let unversionize_trait: Path = parse_const_str(UNVERSIONIZE_TRAIT_NAME);
let field_constructor = match direction {
ConversionDirection::OrigToAssociated => {
let param = if is_ref {
field_param
} else {
quote! {&#field_param}
};
match self.kind {
AssociatedTypeKind::Ref(_) => quote! {
#versionize_trait::versionize(#param)
},
AssociatedTypeKind::Ref(_) => {
let param = if is_ref {
field_param
} else {
quote! {&#field_param}
};
quote! {
#versionize_trait::versionize(#param)
}},
AssociatedTypeKind::Owned => quote! {
#versionize_trait::versionize_owned(#param)
#versionize_owned_trait::versionize_owned(#field_param)
},
}
}

View File

@@ -3,10 +3,11 @@ use quote::{quote, ToTokens};
use syn::punctuated::Punctuated;
use syn::spanned::Spanned;
use syn::{
Attribute, Expr, ExprLit, Ident, Lit, Meta, MetaNameValue, Path, Token, Type, TypeParam,
Attribute, Expr, ExprLit, Ident, Lit, Meta, MetaNameValue, Path, Token, TraitBound, Type,
TypeParam,
};
use crate::{parse_const_str, UNVERSIONIZE_ERROR_NAME};
use crate::{parse_const_str, UNVERSIONIZE_ERROR_NAME, VERSIONIZE_OWNED_TRAIT_NAME};
/// Name of the attribute used to give arguments to our macros
const VERSIONIZE_ATTR_NAME: &str = "versionize";
@@ -184,11 +185,12 @@ impl VersionizeAttribute {
}
pub(crate) fn versionize_method_body(&self) -> proc_macro2::TokenStream {
let versionize_owned_trait: TraitBound = parse_const_str(VERSIONIZE_OWNED_TRAIT_NAME);
self.into
.as_ref()
.map(|target| {
quote! {
#target::from(self.to_owned()).versionize_owned()
#versionize_owned_trait::versionize_owned(#target::from(self.to_owned()))
}
})
.unwrap_or_else(|| {

View File

@@ -3,23 +3,27 @@
use serde::de::DeserializeOwned;
use serde::Serialize;
use tfhe_versionable::{Unversionize, UnversionizeError, Versionize, VersionsDispatch};
use tfhe_versionable::{
Unversionize, UnversionizeError, Versionize, VersionizeOwned, VersionsDispatch,
};
// Example of a simple struct with a manual Versionize impl that requires a specific bound
struct MyStruct<T> {
val: T,
}
impl<T: Serialize + DeserializeOwned + ToOwned<Owned = T>> Versionize for MyStruct<T> {
impl<T: Serialize + DeserializeOwned> Versionize for MyStruct<T> {
type Versioned<'vers> = &'vers T where T: 'vers;
fn versionize(&self) -> Self::Versioned<'_> {
&self.val
}
}
impl<T: Serialize + DeserializeOwned + ToOwned<Owned = T>> VersionizeOwned for MyStruct<T> {
type VersionedOwned = T;
fn versionize_owned(&self) -> Self::VersionedOwned {
fn versionize_owned(self) -> Self::VersionedOwned {
self.val.to_owned()
}
}

View File

@@ -2,7 +2,7 @@
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use tfhe_versionable::{Unversionize, UnversionizeError, Upgrade, Versionize};
use tfhe_versionable::{Unversionize, UnversionizeError, Upgrade, Versionize, VersionizeOwned};
struct MyStruct<T: Default> {
attr: T,
@@ -30,7 +30,7 @@ struct MyStructVersion<'vers, T: 'vers + Default + Versionize> {
}
#[derive(Serialize, Deserialize)]
struct MyStructVersionOwned<T: Default + Versionize> {
struct MyStructVersionOwned<T: Default + VersionizeOwned> {
attr: T::VersionedOwned,
builtin: u32,
}
@@ -47,10 +47,12 @@ impl<T: Default + Versionize + Serialize + DeserializeOwned> Versionize for MySt
};
MyStructVersionsDispatch::V1(ver)
}
}
impl<T: Default + VersionizeOwned + Serialize + DeserializeOwned> VersionizeOwned for MyStruct<T> {
type VersionedOwned = MyStructVersionsDispatchOwned<T>;
fn versionize_owned(&self) -> Self::VersionedOwned {
fn versionize_owned(self) -> Self::VersionedOwned {
let ver = MyStructVersionOwned {
attr: self.attr.versionize_owned(),
builtin: self.builtin,
@@ -59,7 +61,7 @@ impl<T: Default + Versionize + Serialize + DeserializeOwned> Versionize for MySt
}
}
impl<T: Default + Versionize + Unversionize + Serialize + DeserializeOwned> Unversionize
impl<T: Default + VersionizeOwned + Unversionize + Serialize + DeserializeOwned> Unversionize
for MyStruct<T>
{
fn unversionize(versioned: Self::VersionedOwned) -> Result<Self, UnversionizeError> {
@@ -83,7 +85,7 @@ enum MyStructVersionsDispatch<'vers, T: 'vers + Default + Versionize> {
}
#[derive(Serialize, Deserialize)]
enum MyStructVersionsDispatchOwned<T: Default + Versionize> {
enum MyStructVersionsDispatchOwned<T: Default + VersionizeOwned> {
V0(MyStructV0),
V1(MyStructVersionOwned<T>),
}

View File

@@ -11,10 +11,7 @@ pub trait Version: Sized {
type Ref<'vers>: From<&'vers Self> + Serialize
where
Self: 'vers;
type Owned: for<'vers> From<&'vers Self>
+ TryInto<Self, Error = UnversionizeError>
+ DeserializeOwned
+ Serialize;
type Owned: From<Self> + TryInto<Self, Error = UnversionizeError> + DeserializeOwned + Serialize;
}
/// This trait is implemented on the dispatch enum for a given type. The dispatch enum
@@ -24,7 +21,7 @@ pub trait VersionsDispatch<Unversioned>: Sized {
type Ref<'vers>: From<&'vers Unversioned> + Serialize
where
Unversioned: 'vers;
type Owned: for<'vers> From<&'vers Unversioned>
type Owned: From<Unversioned>
+ TryInto<Unversioned, Error = UnversionizeError>
+ DeserializeOwned
+ Serialize;

View File

@@ -34,28 +34,32 @@ pub trait Versionize {
/// Wraps the object into a versioned enum with a variant for each version. This will
/// use references on the underlying types if possible.
fn versionize(&self) -> Self::Versioned<'_>;
}
pub trait VersionizeOwned {
type VersionedOwned: Serialize + DeserializeOwned;
/// Wraps the object into a versioned enum with a variant for each version. This will
/// clone the underlying types.
fn versionize_owned(&self) -> Self::VersionedOwned;
fn versionize_owned(self) -> Self::VersionedOwned;
}
/// This trait is used as a proxy to be more felxible when deriving Versionize for Vec<T>.
/// This way, we can chose to skip versioning Vec<T> if T is a native types but still versionize in
/// a loop if T is a custom type.
/// This is used as a workaround for feature(specialization) and to bypass the orphan rule.
pub trait VersionizeVec: Sized {
pub trait VersionizeSlice: Sized {
type VersionedSlice<'vers>: Serialize
where
Self: 'vers;
fn versionize_slice(slice: &[Self]) -> Self::VersionedSlice<'_>;
}
pub trait VersionizeVec: Sized {
type VersionedVec: Serialize + DeserializeOwned;
fn versionize_vec(slice: &[Self]) -> Self::VersionedVec;
fn versionize_vec(vec: Vec<Self>) -> Self::VersionedVec;
}
#[derive(Debug)]
@@ -117,7 +121,7 @@ impl From<Infallible> for UnversionizeError {
/// This trait means that we can convert from a versioned enum into the target type. This trait
/// can only be implemented on Owned/static types, whereas `Versionize` can also be implemented
/// on reference types.
pub trait Unversionize: Versionize + Sized {
pub trait Unversionize: VersionizeOwned + Sized {
/// Creates an object from a versioned enum, and eventually upgrades from previous
/// variants.
fn unversionize(versioned: Self::VersionedOwned) -> Result<Self, UnversionizeError>;
@@ -131,17 +135,19 @@ pub trait UnversionizeVec: VersionizeVec {
/// Self or &Self.
pub trait NotVersioned: Versionize {}
impl<T: NotVersioned + Serialize + DeserializeOwned + Clone> VersionizeVec for T {
impl<T: NotVersioned + Serialize + DeserializeOwned + Clone> VersionizeSlice for T {
type VersionedSlice<'vers> = &'vers [T] where T: 'vers;
fn versionize_slice(slice: &[Self]) -> Self::VersionedSlice<'_> {
slice
}
}
impl<T: NotVersioned + Serialize + DeserializeOwned + Clone> VersionizeVec for T {
type VersionedVec = Vec<T>;
fn versionize_vec(slice: &[Self]) -> Self::VersionedVec {
slice.to_vec()
fn versionize_vec(vec: Vec<Self>) -> Self::VersionedVec {
vec
}
}
@@ -159,14 +165,15 @@ macro_rules! impl_scalar_versionize {
impl Versionize for $t {
type Versioned<'vers> = $t;
type VersionedOwned = $t;
fn versionize(&self) -> Self::Versioned<'_> {
*self
}
}
fn versionize_owned(&self) -> Self::VersionedOwned {
*self
impl VersionizeOwned for $t {
type VersionedOwned = $t;
fn versionize_owned(self) -> Self::VersionedOwned {
self
}
}
@@ -208,11 +215,13 @@ impl<T: Versionize> Versionize for Box<T> {
fn versionize(&self) -> Self::Versioned<'_> {
self.as_ref().versionize()
}
}
impl<T: VersionizeOwned> VersionizeOwned for Box<T> {
type VersionedOwned = Box<T::VersionedOwned>;
fn versionize_owned(&self) -> Self::VersionedOwned {
Box::new(T::versionize_owned(self))
fn versionize_owned(self) -> Self::VersionedOwned {
Box::new(T::versionize_owned(*self))
}
}
@@ -222,31 +231,35 @@ impl<T: Unversionize> Unversionize for Box<T> {
}
}
impl<T: VersionizeVec> Versionize for Vec<T> {
impl<T: VersionizeSlice> Versionize for Vec<T> {
type Versioned<'vers> = T::VersionedSlice<'vers> where T: 'vers;
fn versionize(&self) -> Self::Versioned<'_> {
T::versionize_slice(self)
}
}
impl<T: VersionizeVec> VersionizeOwned for Vec<T> {
type VersionedOwned = T::VersionedVec;
fn versionize_owned(&self) -> Self::VersionedOwned {
fn versionize_owned(self) -> Self::VersionedOwned {
T::versionize_vec(self)
}
}
impl<T: VersionizeVec + Clone> Versionize for [T] {
impl<T: VersionizeSlice + Clone> Versionize for [T] {
type Versioned<'vers> = T::VersionedSlice<'vers> where T: 'vers;
fn versionize(&self) -> Self::Versioned<'_> {
T::versionize_slice(self)
}
}
impl<T: VersionizeVec + Clone> VersionizeOwned for &[T] {
type VersionedOwned = T::VersionedVec;
fn versionize_owned(&self) -> Self::VersionedOwned {
T::versionize_vec(self)
fn versionize_owned(self) -> Self::VersionedOwned {
T::versionize_vec(self.to_vec())
}
}
@@ -262,11 +275,13 @@ impl Versionize for String {
fn versionize(&self) -> Self::Versioned<'_> {
self.as_ref()
}
}
impl VersionizeOwned for String {
type VersionedOwned = Self;
fn versionize_owned(&self) -> Self::VersionedOwned {
self.clone()
fn versionize_owned(self) -> Self::VersionedOwned {
self
}
}
@@ -284,10 +299,12 @@ impl Versionize for str {
fn versionize(&self) -> Self::Versioned<'_> {
self
}
}
impl VersionizeOwned for &str {
type VersionedOwned = String;
fn versionize_owned(&self) -> Self::VersionedOwned {
fn versionize_owned(self) -> Self::VersionedOwned {
self.to_string()
}
}
@@ -300,11 +317,13 @@ impl<T: Versionize> Versionize for Option<T> {
fn versionize(&self) -> Self::Versioned<'_> {
self.as_ref().map(|val| val.versionize())
}
}
impl<T: VersionizeOwned> VersionizeOwned for Option<T> {
type VersionedOwned = Option<T::VersionedOwned>;
fn versionize_owned(&self) -> Self::VersionedOwned {
self.as_ref().map(|val| val.versionize_owned())
fn versionize_owned(self) -> Self::VersionedOwned {
self.map(|val| val.versionize_owned())
}
}
@@ -324,11 +343,13 @@ impl<T> Versionize for PhantomData<T> {
fn versionize(&self) -> Self::Versioned<'_> {
*self
}
}
impl<T> VersionizeOwned for PhantomData<T> {
type VersionedOwned = Self;
fn versionize_owned(&self) -> Self::VersionedOwned {
*self
fn versionize_owned(self) -> Self::VersionedOwned {
self
}
}
@@ -349,10 +370,12 @@ impl<T: Versionize> Versionize for Complex<T> {
im: self.im.versionize(),
}
}
}
impl<T: VersionizeOwned> VersionizeOwned for Complex<T> {
type VersionedOwned = Complex<T::VersionedOwned>;
fn versionize_owned(&self) -> Self::VersionedOwned {
fn versionize_owned(self) -> Self::VersionedOwned {
Complex {
re: self.re.versionize_owned(),
im: self.im.versionize_owned(),
@@ -377,16 +400,18 @@ impl<T: Versionize> Versionize for ABox<T> {
fn versionize(&self) -> Self::Versioned<'_> {
self.as_ref().versionize()
}
}
impl<T: VersionizeOwned + Copy> VersionizeOwned for ABox<T> {
// Alignment doesn't matter for versioned types
type VersionedOwned = Box<T::VersionedOwned>;
fn versionize_owned(&self) -> Self::VersionedOwned {
Box::new(T::versionize_owned(self))
fn versionize_owned(self) -> Self::VersionedOwned {
Box::new(T::versionize_owned(*self))
}
}
impl<T: Unversionize> Unversionize for ABox<T>
impl<T: Unversionize + Copy> Unversionize for ABox<T>
where
T::VersionedOwned: Clone,
{
@@ -395,22 +420,24 @@ where
}
}
impl<T: VersionizeVec> Versionize for AVec<T> {
impl<T: VersionizeSlice> Versionize for AVec<T> {
type Versioned<'vers> = T::VersionedSlice<'vers> where T: 'vers;
fn versionize(&self) -> Self::Versioned<'_> {
T::versionize_slice(self)
}
}
// Alignment doesn't matter for versioned types
// Alignment doesn't matter for versioned types
impl<T: VersionizeVec + Clone> VersionizeOwned for AVec<T> {
type VersionedOwned = T::VersionedVec;
fn versionize_owned(&self) -> Self::VersionedOwned {
T::versionize_vec(self)
fn versionize_owned(self) -> Self::VersionedOwned {
T::versionize_vec(self.to_vec())
}
}
impl<T: UnversionizeVec> Unversionize for AVec<T> {
impl<T: UnversionizeVec + Clone> Unversionize for AVec<T> {
fn unversionize(versioned: Self::VersionedOwned) -> Result<Self, UnversionizeError> {
T::unversionize_vec(versioned).map(|unver| AVec::from_iter(0, unver))
}
@@ -424,10 +451,12 @@ impl<T: Versionize, U: Versionize> Versionize for (T, U) {
fn versionize(&self) -> Self::Versioned<'_> {
(self.0.versionize(), self.1.versionize())
}
}
impl<T: VersionizeOwned, U: VersionizeOwned> VersionizeOwned for (T, U) {
type VersionedOwned = (T::VersionedOwned, U::VersionedOwned);
fn versionize_owned(&self) -> Self::VersionedOwned {
fn versionize_owned(self) -> Self::VersionedOwned {
(self.0.versionize_owned(), self.1.versionize_owned())
}
}