mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-01-09 14:47:56 -05:00
feat(versionable): impl Versionize for Arc
This commit is contained in:
committed by
Nicolas Sarlin
parent
8f72677fa6
commit
8ea647dc26
@@ -17,7 +17,7 @@ exclude = [
|
|||||||
"/js_on_wasm_tests/",
|
"/js_on_wasm_tests/",
|
||||||
"/web_wasm_parallel_tests/",
|
"/web_wasm_parallel_tests/",
|
||||||
]
|
]
|
||||||
rust-version = "1.75"
|
rust-version = "1.76"
|
||||||
|
|
||||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,6 @@
|
|||||||
use serde::{Deserialize, Serialize};
|
use std::convert::Infallible;
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
use tfhe_versionable::{Upgrade, Version, VersionsDispatch};
|
use tfhe_versionable::{Upgrade, Version, VersionsDispatch};
|
||||||
|
|
||||||
use crate::high_level_api::keys::*;
|
use crate::high_level_api::keys::*;
|
||||||
@@ -8,16 +10,27 @@ pub enum ClientKeyVersions {
|
|||||||
V0(ClientKey),
|
V0(ClientKey),
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Serialize)]
|
// This type was previously versioned using a manual implementation with a conversion
|
||||||
#[cfg_attr(tfhe_lints, allow(tfhe_lints::serialize_without_versionize))]
|
// to a type where the inner key was name `integer_key`
|
||||||
pub enum ServerKeyVersioned<'vers> {
|
#[derive(Version)]
|
||||||
V0(ServerKeyVersion<'vers>),
|
pub struct ServerKeyV0 {
|
||||||
|
pub(crate) integer_key: Arc<IntegerServerKey>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize)]
|
impl Upgrade<ServerKey> for ServerKeyV0 {
|
||||||
#[cfg_attr(tfhe_lints, allow(tfhe_lints::serialize_without_versionize))]
|
type Error = Infallible;
|
||||||
pub enum ServerKeyVersionedOwned {
|
|
||||||
V0(ServerKeyVersionOwned),
|
fn upgrade(self) -> Result<ServerKey, Self::Error> {
|
||||||
|
Ok(ServerKey {
|
||||||
|
key: self.integer_key,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(VersionsDispatch)]
|
||||||
|
pub enum ServerKeyVersions {
|
||||||
|
V0(ServerKeyV0),
|
||||||
|
V1(ServerKey),
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(VersionsDispatch)]
|
#[derive(VersionsDispatch)]
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ pub use public::{CompactPublicKey, CompressedCompactPublicKey, CompressedPublicK
|
|||||||
#[cfg(feature = "gpu")]
|
#[cfg(feature = "gpu")]
|
||||||
pub use server::CudaServerKey;
|
pub use server::CudaServerKey;
|
||||||
pub use server::{CompressedServerKey, ServerKey};
|
pub use server::{CompressedServerKey, ServerKey};
|
||||||
pub(crate) use server::{InternalServerKey, ServerKeyVersion, ServerKeyVersionOwned};
|
pub(crate) use server::InternalServerKey;
|
||||||
|
|
||||||
pub(in crate::high_level_api) use inner::{
|
pub(in crate::high_level_api) use inner::{
|
||||||
IntegerClientKey, IntegerCompactPublicKey, IntegerCompressedCompactPublicKey,
|
IntegerClientKey, IntegerCompactPublicKey, IntegerCompressedCompactPublicKey,
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
use tfhe_versionable::{Unversionize, Versionize, VersionizeOwned};
|
use tfhe_versionable::Versionize;
|
||||||
|
|
||||||
use crate::backward_compatibility::keys::{
|
use crate::backward_compatibility::keys::{
|
||||||
CompressedServerKeyVersions, ServerKeyVersioned, ServerKeyVersionedOwned,
|
CompressedServerKeyVersions, ServerKeyVersions,
|
||||||
};
|
};
|
||||||
#[cfg(feature = "gpu")]
|
#[cfg(feature = "gpu")]
|
||||||
use crate::core_crypto::gpu::{synchronize_devices, CudaStreams};
|
use crate::core_crypto::gpu::{synchronize_devices, CudaStreams};
|
||||||
@@ -23,7 +23,8 @@ use super::ClientKey;
|
|||||||
// Keys are stored in an Arc, so that cloning them is cheap
|
// Keys are stored in an Arc, so that cloning them is cheap
|
||||||
// (compared to an actual clone hundreds of MB / GB), and cheap cloning is needed for
|
// (compared to an actual clone hundreds of MB / GB), and cheap cloning is needed for
|
||||||
// multithreading with less overhead)
|
// multithreading with less overhead)
|
||||||
#[derive(Clone)]
|
#[derive(Clone, Versionize)]
|
||||||
|
#[versionize(ServerKeyVersions)]
|
||||||
pub struct ServerKey {
|
pub struct ServerKey {
|
||||||
pub(crate) key: Arc<IntegerServerKey>,
|
pub(crate) key: Arc<IntegerServerKey>,
|
||||||
}
|
}
|
||||||
@@ -136,52 +137,6 @@ impl<'de> serde::Deserialize<'de> for ServerKey {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(serde::Serialize)]
|
|
||||||
#[cfg_attr(tfhe_lints, allow(tfhe_lints::serialize_without_versionize))]
|
|
||||||
pub struct ServerKeyVersion<'vers> {
|
|
||||||
pub(crate) integer_key: <IntegerServerKey as Versionize>::Versioned<'vers>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(serde::Serialize, serde::Deserialize)]
|
|
||||||
#[cfg_attr(tfhe_lints, allow(tfhe_lints::serialize_without_versionize))]
|
|
||||||
pub struct ServerKeyVersionOwned {
|
|
||||||
pub(crate) integer_key: <IntegerServerKey as VersionizeOwned>::VersionedOwned,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Versionize for ServerKey {
|
|
||||||
type Versioned<'vers> = ServerKeyVersioned<'vers>;
|
|
||||||
|
|
||||||
fn versionize(&self) -> Self::Versioned<'_> {
|
|
||||||
ServerKeyVersioned::V0(ServerKeyVersion {
|
|
||||||
integer_key: self.key.versionize(),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl VersionizeOwned for ServerKey {
|
|
||||||
type VersionedOwned = ServerKeyVersionedOwned;
|
|
||||||
|
|
||||||
fn versionize_owned(self) -> Self::VersionedOwned {
|
|
||||||
ServerKeyVersionedOwned::V0(ServerKeyVersionOwned {
|
|
||||||
integer_key: (*self.key).clone().versionize_owned(),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Unversionize for ServerKey {
|
|
||||||
fn unversionize(
|
|
||||||
versioned: Self::VersionedOwned,
|
|
||||||
) -> Result<Self, tfhe_versionable::UnversionizeError> {
|
|
||||||
match versioned {
|
|
||||||
ServerKeyVersionedOwned::V0(v0) => {
|
|
||||||
IntegerServerKey::unversionize(v0.integer_key).map(|unversioned| Self {
|
|
||||||
key: Arc::new(unversioned),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Compressed ServerKey
|
/// Compressed ServerKey
|
||||||
///
|
///
|
||||||
/// A CompressedServerKey takes much less disk space / memory space than a
|
/// A CompressedServerKey takes much less disk space / memory space than a
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ documentation = "https://docs.rs/tfhe_versionable"
|
|||||||
repository = "https://github.com/zama-ai/tfhe-rs"
|
repository = "https://github.com/zama-ai/tfhe-rs"
|
||||||
license = "BSD-3-Clause-Clear"
|
license = "BSD-3-Clause-Clear"
|
||||||
description = "tfhe-versionable: Add versioning informations/backward compatibility on rust types used for serialization"
|
description = "tfhe-versionable: Add versioning informations/backward compatibility on rust types used for serialization"
|
||||||
|
rust-version = "1.76"
|
||||||
|
|
||||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||||
|
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ use num_complex::Complex;
|
|||||||
use std::convert::Infallible;
|
use std::convert::Infallible;
|
||||||
use std::fmt::Display;
|
use std::fmt::Display;
|
||||||
use std::marker::PhantomData;
|
use std::marker::PhantomData;
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
pub use derived_traits::{Version, VersionsDispatch};
|
pub use derived_traits::{Version, VersionsDispatch};
|
||||||
pub use upgrade::Upgrade;
|
pub use upgrade::Upgrade;
|
||||||
@@ -361,6 +362,32 @@ impl<T> Unversionize for PhantomData<T> {
|
|||||||
|
|
||||||
impl<T> NotVersioned for PhantomData<T> {}
|
impl<T> NotVersioned for PhantomData<T> {}
|
||||||
|
|
||||||
|
impl<T: Versionize> Versionize for Arc<T> {
|
||||||
|
type Versioned<'vers> = T::Versioned<'vers>
|
||||||
|
where
|
||||||
|
T: 'vers;
|
||||||
|
|
||||||
|
fn versionize(&self) -> Self::Versioned<'_> {
|
||||||
|
self.as_ref().versionize()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: VersionizeOwned + Clone> VersionizeOwned for Arc<T> {
|
||||||
|
type VersionedOwned = T::VersionedOwned;
|
||||||
|
|
||||||
|
fn versionize_owned(self) -> Self::VersionedOwned {
|
||||||
|
Arc::unwrap_or_clone(self).versionize_owned()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: Unversionize + Clone> Unversionize for Arc<T> {
|
||||||
|
fn unversionize(versioned: Self::VersionedOwned) -> Result<Self, UnversionizeError> {
|
||||||
|
Ok(Arc::new(T::unversionize(versioned)?))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: NotVersioned> NotVersioned for Arc<T> {}
|
||||||
|
|
||||||
impl<T: Versionize> Versionize for Complex<T> {
|
impl<T: Versionize> Versionize for Complex<T> {
|
||||||
type Versioned<'vers> = Complex<T::Versioned<'vers>> where T: 'vers;
|
type Versioned<'vers> = Complex<T::Versioned<'vers>> where T: 'vers;
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user