feat(versionable): impl Versionize for Arc

This commit is contained in:
Nicolas Sarlin
2024-07-12 16:27:06 +02:00
committed by Nicolas Sarlin
parent 8f72677fa6
commit 8ea647dc26
6 changed files with 56 additions and 60 deletions

View File

@@ -17,7 +17,7 @@ exclude = [
"/js_on_wasm_tests/",
"/web_wasm_parallel_tests/",
]
rust-version = "1.75"
rust-version = "1.76"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

View File

@@ -1,4 +1,6 @@
use serde::{Deserialize, Serialize};
use std::convert::Infallible;
use std::sync::Arc;
use tfhe_versionable::{Upgrade, Version, VersionsDispatch};
use crate::high_level_api::keys::*;
@@ -8,16 +10,27 @@ pub enum ClientKeyVersions {
V0(ClientKey),
}
#[derive(Serialize)]
#[cfg_attr(tfhe_lints, allow(tfhe_lints::serialize_without_versionize))]
pub enum ServerKeyVersioned<'vers> {
V0(ServerKeyVersion<'vers>),
// This type was previously versioned using a manual implementation with a conversion
// to a type where the inner key was name `integer_key`
#[derive(Version)]
pub struct ServerKeyV0 {
pub(crate) integer_key: Arc<IntegerServerKey>,
}
#[derive(Serialize, Deserialize)]
#[cfg_attr(tfhe_lints, allow(tfhe_lints::serialize_without_versionize))]
pub enum ServerKeyVersionedOwned {
V0(ServerKeyVersionOwned),
impl Upgrade<ServerKey> for ServerKeyV0 {
type Error = Infallible;
fn upgrade(self) -> Result<ServerKey, Self::Error> {
Ok(ServerKey {
key: self.integer_key,
})
}
}
#[derive(VersionsDispatch)]
pub enum ServerKeyVersions {
V0(ServerKeyV0),
V1(ServerKey),
}
#[derive(VersionsDispatch)]

View File

@@ -13,7 +13,7 @@ pub use public::{CompactPublicKey, CompressedCompactPublicKey, CompressedPublicK
#[cfg(feature = "gpu")]
pub use server::CudaServerKey;
pub use server::{CompressedServerKey, ServerKey};
pub(crate) use server::{InternalServerKey, ServerKeyVersion, ServerKeyVersionOwned};
pub(crate) use server::InternalServerKey;
pub(in crate::high_level_api) use inner::{
IntegerClientKey, IntegerCompactPublicKey, IntegerCompressedCompactPublicKey,

View File

@@ -1,7 +1,7 @@
use tfhe_versionable::{Unversionize, Versionize, VersionizeOwned};
use tfhe_versionable::Versionize;
use crate::backward_compatibility::keys::{
CompressedServerKeyVersions, ServerKeyVersioned, ServerKeyVersionedOwned,
CompressedServerKeyVersions, ServerKeyVersions,
};
#[cfg(feature = "gpu")]
use crate::core_crypto::gpu::{synchronize_devices, CudaStreams};
@@ -23,7 +23,8 @@ use super::ClientKey;
// Keys are stored in an Arc, so that cloning them is cheap
// (compared to an actual clone hundreds of MB / GB), and cheap cloning is needed for
// multithreading with less overhead)
#[derive(Clone)]
#[derive(Clone, Versionize)]
#[versionize(ServerKeyVersions)]
pub struct ServerKey {
pub(crate) key: Arc<IntegerServerKey>,
}
@@ -136,52 +137,6 @@ impl<'de> serde::Deserialize<'de> for ServerKey {
}
}
#[derive(serde::Serialize)]
#[cfg_attr(tfhe_lints, allow(tfhe_lints::serialize_without_versionize))]
pub struct ServerKeyVersion<'vers> {
pub(crate) integer_key: <IntegerServerKey as Versionize>::Versioned<'vers>,
}
#[derive(serde::Serialize, serde::Deserialize)]
#[cfg_attr(tfhe_lints, allow(tfhe_lints::serialize_without_versionize))]
pub struct ServerKeyVersionOwned {
pub(crate) integer_key: <IntegerServerKey as VersionizeOwned>::VersionedOwned,
}
impl Versionize for ServerKey {
type Versioned<'vers> = ServerKeyVersioned<'vers>;
fn versionize(&self) -> Self::Versioned<'_> {
ServerKeyVersioned::V0(ServerKeyVersion {
integer_key: self.key.versionize(),
})
}
}
impl VersionizeOwned for ServerKey {
type VersionedOwned = ServerKeyVersionedOwned;
fn versionize_owned(self) -> Self::VersionedOwned {
ServerKeyVersionedOwned::V0(ServerKeyVersionOwned {
integer_key: (*self.key).clone().versionize_owned(),
})
}
}
impl Unversionize for ServerKey {
fn unversionize(
versioned: Self::VersionedOwned,
) -> Result<Self, tfhe_versionable::UnversionizeError> {
match versioned {
ServerKeyVersionedOwned::V0(v0) => {
IntegerServerKey::unversionize(v0.integer_key).map(|unversioned| Self {
key: Arc::new(unversioned),
})
}
}
}
}
/// Compressed ServerKey
///
/// A CompressedServerKey takes much less disk space / memory space than a

View File

@@ -8,6 +8,7 @@ documentation = "https://docs.rs/tfhe_versionable"
repository = "https://github.com/zama-ai/tfhe-rs"
license = "BSD-3-Clause-Clear"
description = "tfhe-versionable: Add versioning informations/backward compatibility on rust types used for serialization"
rust-version = "1.76"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

View File

@@ -14,6 +14,7 @@ use num_complex::Complex;
use std::convert::Infallible;
use std::fmt::Display;
use std::marker::PhantomData;
use std::sync::Arc;
pub use derived_traits::{Version, VersionsDispatch};
pub use upgrade::Upgrade;
@@ -361,6 +362,32 @@ impl<T> Unversionize for PhantomData<T> {
impl<T> NotVersioned for PhantomData<T> {}
impl<T: Versionize> Versionize for Arc<T> {
type Versioned<'vers> = T::Versioned<'vers>
where
T: 'vers;
fn versionize(&self) -> Self::Versioned<'_> {
self.as_ref().versionize()
}
}
impl<T: VersionizeOwned + Clone> VersionizeOwned for Arc<T> {
type VersionedOwned = T::VersionedOwned;
fn versionize_owned(self) -> Self::VersionedOwned {
Arc::unwrap_or_clone(self).versionize_owned()
}
}
impl<T: Unversionize + Clone> Unversionize for Arc<T> {
fn unversionize(versioned: Self::VersionedOwned) -> Result<Self, UnversionizeError> {
Ok(Arc::new(T::unversionize(versioned)?))
}
}
impl<T: NotVersioned> NotVersioned for Arc<T> {}
impl<T: Versionize> Versionize for Complex<T> {
type Versioned<'vers> = Complex<T::Versioned<'vers>> where T: 'vers;