Files
tfhe-rs/tfhe/src/integer/gpu/server_key/mod.rs
2025-02-18 13:19:28 +01:00

283 lines
11 KiB
Rust

use crate::core_crypto::gpu::lwe_bootstrap_key::CudaLweBootstrapKey;
use crate::core_crypto::gpu::lwe_keyswitch_key::CudaLweKeyswitchKey;
use crate::core_crypto::gpu::lwe_multi_bit_bootstrap_key::CudaLweMultiBitBootstrapKey;
use crate::core_crypto::gpu::CudaStreams;
use crate::core_crypto::prelude::{
allocate_and_generate_new_lwe_keyswitch_key, par_allocate_and_generate_new_lwe_bootstrap_key,
par_allocate_and_generate_new_lwe_multi_bit_bootstrap_key, LweBootstrapKeyOwned,
LweMultiBitBootstrapKeyOwned,
};
use crate::integer::gpu::UnsignedInteger;
use crate::integer::ClientKey;
use crate::shortint::ciphertext::{MaxDegree, MaxNoiseLevel};
use crate::shortint::engine::ShortintEngine;
use crate::shortint::{CarryModulus, CiphertextModulus, MessageModulus, PBSOrder};
mod radix;
pub enum CudaBootstrappingKey {
Classic(CudaLweBootstrapKey),
MultiBit(CudaLweMultiBitBootstrapKey),
}
/// A structure containing the server public key.
///
/// The server key is generated by the client and is meant to be published: the client
/// sends it to the server so it can compute homomorphic circuits.
// #[derive(PartialEq, Serialize, Deserialize)]
pub struct CudaServerKey {
pub key_switching_key: CudaLweKeyswitchKey<u64>,
pub bootstrapping_key: CudaBootstrappingKey,
// Size of the message buffer
pub message_modulus: MessageModulus,
// Size of the carry buffer
pub carry_modulus: CarryModulus,
// Maximum number of operations that can be done before emptying the operation buffer
pub max_degree: MaxDegree,
pub max_noise_level: MaxNoiseLevel,
// Modulus use for computations on the ciphertext
pub ciphertext_modulus: CiphertextModulus,
pub pbs_order: PBSOrder,
}
impl CudaServerKey {
/// Generates a server key that stores keys in the device memory.
///
/// # Example
///
/// ```rust
/// use tfhe::core_crypto::gpu::CudaStreams;
/// use tfhe::core_crypto::gpu::vec::GpuIndex;
/// use tfhe::integer::gpu::CudaServerKey;
/// use tfhe::integer::ClientKey;
/// # // TODO GPU DRIFT UPDATE
/// use tfhe::shortint::parameters::PARAM_GPU_MULTI_BIT_GROUP_3_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64;
///
/// let gpu_index = 0;
/// let streams = CudaStreams::new_single_gpu(GpuIndex::new(gpu_index));
///
/// # // TODO GPU DRIFT UPDATE
/// // Generate the client key:
/// let cks = ClientKey::new(PARAM_GPU_MULTI_BIT_GROUP_3_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64);
///
/// // Generate the server key:
/// let sks = CudaServerKey::new(&cks, &streams);
/// ```
pub fn new<C>(cks: C, streams: &CudaStreams) -> Self
where
C: AsRef<ClientKey>,
{
// It should remain just enough space to add a carry
let client_key = cks.as_ref();
let max_degree = MaxDegree::integer_radix_server_key(
client_key.key.parameters.message_modulus(),
client_key.key.parameters.carry_modulus(),
);
Self::new_server_key_with_max_degree(client_key, max_degree, streams)
}
pub(crate) fn new_server_key_with_max_degree(
cks: &ClientKey,
max_degree: MaxDegree,
streams: &CudaStreams,
) -> Self {
let mut engine = ShortintEngine::new();
// Generate a regular keyset and convert to the GPU
let pbs_params_base = &cks.parameters();
let d_bootstrapping_key = match pbs_params_base {
crate::shortint::PBSParameters::PBS(pbs_params) => {
let h_bootstrap_key: LweBootstrapKeyOwned<u64> =
par_allocate_and_generate_new_lwe_bootstrap_key(
&cks.key.small_lwe_secret_key(),
&cks.key.glwe_secret_key,
pbs_params.pbs_base_log,
pbs_params.pbs_level,
pbs_params.glwe_noise_distribution,
pbs_params.ciphertext_modulus,
&mut engine.encryption_generator,
);
let d_bootstrap_key =
CudaLweBootstrapKey::from_lwe_bootstrap_key(&h_bootstrap_key, streams);
CudaBootstrappingKey::Classic(d_bootstrap_key)
}
crate::shortint::PBSParameters::MultiBitPBS(pbs_params) => {
let h_bootstrap_key: LweMultiBitBootstrapKeyOwned<u64> =
par_allocate_and_generate_new_lwe_multi_bit_bootstrap_key(
&cks.key.small_lwe_secret_key(),
&cks.key.glwe_secret_key,
pbs_params.pbs_base_log,
pbs_params.pbs_level,
pbs_params.grouping_factor,
pbs_params.glwe_noise_distribution,
pbs_params.ciphertext_modulus,
&mut engine.encryption_generator,
);
let d_bootstrap_key = CudaLweMultiBitBootstrapKey::from_lwe_multi_bit_bootstrap_key(
&h_bootstrap_key,
streams,
);
CudaBootstrappingKey::MultiBit(d_bootstrap_key)
}
};
// Creation of the key switching key
let h_key_switching_key = allocate_and_generate_new_lwe_keyswitch_key(
&cks.key.large_lwe_secret_key(),
&cks.key.small_lwe_secret_key(),
cks.parameters().ks_base_log(),
cks.parameters().ks_level(),
cks.parameters().lwe_noise_distribution(),
cks.parameters().ciphertext_modulus(),
&mut engine.encryption_generator,
);
let d_key_switching_key =
CudaLweKeyswitchKey::from_lwe_keyswitch_key(&h_key_switching_key, streams);
assert!(matches!(
cks.parameters().encryption_key_choice().into(),
PBSOrder::KeyswitchBootstrap
));
// Pack the keys in the server key set:
Self {
key_switching_key: d_key_switching_key,
bootstrapping_key: d_bootstrapping_key,
message_modulus: cks.parameters().message_modulus(),
carry_modulus: cks.parameters().carry_modulus(),
max_degree,
max_noise_level: cks.parameters().max_noise_level(),
ciphertext_modulus: cks.parameters().ciphertext_modulus(),
pbs_order: cks.parameters().encryption_key_choice().into(),
}
}
/// Decompress a CompressedServerKey to a CudaServerKey
///
/// This is useful in particular for debugging purposes, as it allows to compare the result of
/// CPU & GPU computations. When using trivial encryption it is then possible to track
/// intermediate and final result values easily between CPU and GPU.
///
/// # Example
///
/// ```rust
/// use tfhe::core_crypto::gpu::CudaStreams;
/// use tfhe::core_crypto::gpu::vec::GpuIndex;
/// use tfhe::integer::gpu::ciphertext::CudaUnsignedRadixCiphertext;
/// use tfhe::integer::gpu::CudaServerKey;
/// use tfhe::integer::{ClientKey, CompressedServerKey, ServerKey};
/// # // TODO GPU DRIFT UPDATE
/// use tfhe::shortint::parameters::PARAM_GPU_MULTI_BIT_GROUP_3_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64;
///
/// let gpu_index = 0;
/// let streams = CudaStreams::new_single_gpu(GpuIndex::new(gpu_index));
/// let size = 4;
/// # // TODO GPU DRIFT UPDATE
/// let cks = ClientKey::new(PARAM_GPU_MULTI_BIT_GROUP_3_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64);
/// let compressed_sks = CompressedServerKey::new_radix_compressed_server_key(&cks);
/// let cuda_sks = CudaServerKey::decompress_from_cpu(&compressed_sks, &streams);
/// let cpu_sks = compressed_sks.decompress();
/// let msg = 1;
/// let scalar = 3;
/// let ct = cpu_sks.create_trivial_radix(msg, size);
/// let d_ct = CudaUnsignedRadixCiphertext::from_radix_ciphertext(&ct, &streams);
/// // Compute homomorphically a scalar multiplication:
/// let d_ct_res = cuda_sks.unchecked_scalar_add(&d_ct, scalar, &streams);
/// let ct_res = d_ct_res.to_radix_ciphertext(&streams);
/// let ct_res_cpu = cpu_sks.unchecked_scalar_add(&ct, scalar);
/// let clear: u64 = cks.decrypt_radix(&ct_res);
/// let clear_cpu: u64 = cks.decrypt_radix(&ct_res_cpu);
/// assert_eq!((scalar + msg) % (4_u64.pow(size as u32)), clear_cpu);
/// assert_eq!((scalar + msg) % (4_u64.pow(size as u32)), clear);
/// ```
pub fn decompress_from_cpu(
cpu_key: &crate::integer::CompressedServerKey,
streams: &CudaStreams,
) -> Self {
let crate::shortint::CompressedServerKey {
key_switching_key,
bootstrapping_key,
message_modulus,
carry_modulus,
max_degree,
max_noise_level,
ciphertext_modulus,
pbs_order,
} = cpu_key.key.clone();
let h_key_switching_key = key_switching_key.par_decompress_into_lwe_keyswitch_key();
let key_switching_key =
CudaLweKeyswitchKey::from_lwe_keyswitch_key(&h_key_switching_key, streams);
let bootstrapping_key = match bootstrapping_key {
crate::shortint::server_key::compressed::ShortintCompressedBootstrappingKey::Classic{ bsk: h_bootstrap_key, modulus_switch_noise_reduction_key } => {
assert!(modulus_switch_noise_reduction_key.is_none(), "Modulus Switch Noise Reduction is not yet support on GPU");
let standard_bootstrapping_key =
h_bootstrap_key.par_decompress_into_lwe_bootstrap_key();
let d_bootstrap_key =
CudaLweBootstrapKey::from_lwe_bootstrap_key(&standard_bootstrapping_key, streams);
CudaBootstrappingKey::Classic(d_bootstrap_key)
}
crate::shortint::server_key::compressed::ShortintCompressedBootstrappingKey::MultiBit {
seeded_bsk: bootstrapping_key,
deterministic_execution: _,
} => {
let standard_bootstrapping_key =
bootstrapping_key.par_decompress_into_lwe_multi_bit_bootstrap_key();
let d_bootstrap_key =
CudaLweMultiBitBootstrapKey::from_lwe_multi_bit_bootstrap_key(
&standard_bootstrapping_key, streams);
CudaBootstrappingKey::MultiBit(d_bootstrap_key)
}
};
Self {
key_switching_key,
bootstrapping_key,
message_modulus,
carry_modulus,
max_degree,
max_noise_level,
ciphertext_modulus,
pbs_order,
}
}
#[allow(clippy::unused_self)]
pub(crate) fn num_bits_to_represent_unsigned_value<Clear>(&self, clear: Clear) -> usize
where
Clear: UnsignedInteger,
{
if clear == Clear::MAX {
Clear::BITS
} else {
let bits = (clear + Clear::ONE).ceil_ilog2() as usize;
if bits == 0 {
1
} else {
bits
}
}
}
/// Returns how many blocks a radix ciphertext should have to
/// be able to represent the given unsigned integer
pub(crate) fn num_blocks_to_represent_unsigned_value<Clear>(&self, clear: Clear) -> usize
where
Clear: UnsignedInteger,
{
let num_bits_to_represent_output_value = self.num_bits_to_represent_unsigned_value(clear);
let num_bits_in_message = self.message_modulus.0.ilog2();
num_bits_to_represent_output_value.div_ceil(num_bits_in_message as usize)
}
}