mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-01-08 22:28:01 -05:00
feat(versionable): impl Versionize for Arc
This commit is contained in:
committed by
Nicolas Sarlin
parent
8f72677fa6
commit
8ea647dc26
@@ -17,7 +17,7 @@ exclude = [
|
||||
"/js_on_wasm_tests/",
|
||||
"/web_wasm_parallel_tests/",
|
||||
]
|
||||
rust-version = "1.75"
|
||||
rust-version = "1.76"
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::convert::Infallible;
|
||||
use std::sync::Arc;
|
||||
|
||||
use tfhe_versionable::{Upgrade, Version, VersionsDispatch};
|
||||
|
||||
use crate::high_level_api::keys::*;
|
||||
@@ -8,16 +10,27 @@ pub enum ClientKeyVersions {
|
||||
V0(ClientKey),
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
#[cfg_attr(tfhe_lints, allow(tfhe_lints::serialize_without_versionize))]
|
||||
pub enum ServerKeyVersioned<'vers> {
|
||||
V0(ServerKeyVersion<'vers>),
|
||||
// This type was previously versioned using a manual implementation with a conversion
|
||||
// to a type where the inner key was name `integer_key`
|
||||
#[derive(Version)]
|
||||
pub struct ServerKeyV0 {
|
||||
pub(crate) integer_key: Arc<IntegerServerKey>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
#[cfg_attr(tfhe_lints, allow(tfhe_lints::serialize_without_versionize))]
|
||||
pub enum ServerKeyVersionedOwned {
|
||||
V0(ServerKeyVersionOwned),
|
||||
impl Upgrade<ServerKey> for ServerKeyV0 {
|
||||
type Error = Infallible;
|
||||
|
||||
fn upgrade(self) -> Result<ServerKey, Self::Error> {
|
||||
Ok(ServerKey {
|
||||
key: self.integer_key,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(VersionsDispatch)]
|
||||
pub enum ServerKeyVersions {
|
||||
V0(ServerKeyV0),
|
||||
V1(ServerKey),
|
||||
}
|
||||
|
||||
#[derive(VersionsDispatch)]
|
||||
|
||||
@@ -13,7 +13,7 @@ pub use public::{CompactPublicKey, CompressedCompactPublicKey, CompressedPublicK
|
||||
#[cfg(feature = "gpu")]
|
||||
pub use server::CudaServerKey;
|
||||
pub use server::{CompressedServerKey, ServerKey};
|
||||
pub(crate) use server::{InternalServerKey, ServerKeyVersion, ServerKeyVersionOwned};
|
||||
pub(crate) use server::InternalServerKey;
|
||||
|
||||
pub(in crate::high_level_api) use inner::{
|
||||
IntegerClientKey, IntegerCompactPublicKey, IntegerCompressedCompactPublicKey,
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use tfhe_versionable::{Unversionize, Versionize, VersionizeOwned};
|
||||
use tfhe_versionable::Versionize;
|
||||
|
||||
use crate::backward_compatibility::keys::{
|
||||
CompressedServerKeyVersions, ServerKeyVersioned, ServerKeyVersionedOwned,
|
||||
CompressedServerKeyVersions, ServerKeyVersions,
|
||||
};
|
||||
#[cfg(feature = "gpu")]
|
||||
use crate::core_crypto::gpu::{synchronize_devices, CudaStreams};
|
||||
@@ -23,7 +23,8 @@ use super::ClientKey;
|
||||
// Keys are stored in an Arc, so that cloning them is cheap
|
||||
// (compared to an actual clone hundreds of MB / GB), and cheap cloning is needed for
|
||||
// multithreading with less overhead)
|
||||
#[derive(Clone)]
|
||||
#[derive(Clone, Versionize)]
|
||||
#[versionize(ServerKeyVersions)]
|
||||
pub struct ServerKey {
|
||||
pub(crate) key: Arc<IntegerServerKey>,
|
||||
}
|
||||
@@ -136,52 +137,6 @@ impl<'de> serde::Deserialize<'de> for ServerKey {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(serde::Serialize)]
|
||||
#[cfg_attr(tfhe_lints, allow(tfhe_lints::serialize_without_versionize))]
|
||||
pub struct ServerKeyVersion<'vers> {
|
||||
pub(crate) integer_key: <IntegerServerKey as Versionize>::Versioned<'vers>,
|
||||
}
|
||||
|
||||
#[derive(serde::Serialize, serde::Deserialize)]
|
||||
#[cfg_attr(tfhe_lints, allow(tfhe_lints::serialize_without_versionize))]
|
||||
pub struct ServerKeyVersionOwned {
|
||||
pub(crate) integer_key: <IntegerServerKey as VersionizeOwned>::VersionedOwned,
|
||||
}
|
||||
|
||||
impl Versionize for ServerKey {
|
||||
type Versioned<'vers> = ServerKeyVersioned<'vers>;
|
||||
|
||||
fn versionize(&self) -> Self::Versioned<'_> {
|
||||
ServerKeyVersioned::V0(ServerKeyVersion {
|
||||
integer_key: self.key.versionize(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl VersionizeOwned for ServerKey {
|
||||
type VersionedOwned = ServerKeyVersionedOwned;
|
||||
|
||||
fn versionize_owned(self) -> Self::VersionedOwned {
|
||||
ServerKeyVersionedOwned::V0(ServerKeyVersionOwned {
|
||||
integer_key: (*self.key).clone().versionize_owned(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Unversionize for ServerKey {
|
||||
fn unversionize(
|
||||
versioned: Self::VersionedOwned,
|
||||
) -> Result<Self, tfhe_versionable::UnversionizeError> {
|
||||
match versioned {
|
||||
ServerKeyVersionedOwned::V0(v0) => {
|
||||
IntegerServerKey::unversionize(v0.integer_key).map(|unversioned| Self {
|
||||
key: Arc::new(unversioned),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Compressed ServerKey
|
||||
///
|
||||
/// A CompressedServerKey takes much less disk space / memory space than a
|
||||
|
||||
@@ -8,6 +8,7 @@ documentation = "https://docs.rs/tfhe_versionable"
|
||||
repository = "https://github.com/zama-ai/tfhe-rs"
|
||||
license = "BSD-3-Clause-Clear"
|
||||
description = "tfhe-versionable: Add versioning informations/backward compatibility on rust types used for serialization"
|
||||
rust-version = "1.76"
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@ use num_complex::Complex;
|
||||
use std::convert::Infallible;
|
||||
use std::fmt::Display;
|
||||
use std::marker::PhantomData;
|
||||
use std::sync::Arc;
|
||||
|
||||
pub use derived_traits::{Version, VersionsDispatch};
|
||||
pub use upgrade::Upgrade;
|
||||
@@ -361,6 +362,32 @@ impl<T> Unversionize for PhantomData<T> {
|
||||
|
||||
impl<T> NotVersioned for PhantomData<T> {}
|
||||
|
||||
impl<T: Versionize> Versionize for Arc<T> {
|
||||
type Versioned<'vers> = T::Versioned<'vers>
|
||||
where
|
||||
T: 'vers;
|
||||
|
||||
fn versionize(&self) -> Self::Versioned<'_> {
|
||||
self.as_ref().versionize()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: VersionizeOwned + Clone> VersionizeOwned for Arc<T> {
|
||||
type VersionedOwned = T::VersionedOwned;
|
||||
|
||||
fn versionize_owned(self) -> Self::VersionedOwned {
|
||||
Arc::unwrap_or_clone(self).versionize_owned()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Unversionize + Clone> Unversionize for Arc<T> {
|
||||
fn unversionize(versioned: Self::VersionedOwned) -> Result<Self, UnversionizeError> {
|
||||
Ok(Arc::new(T::unversionize(versioned)?))
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: NotVersioned> NotVersioned for Arc<T> {}
|
||||
|
||||
impl<T: Versionize> Versionize for Complex<T> {
|
||||
type Versioned<'vers> = Complex<T::Versioned<'vers>> where T: 'vers;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user