Files
tfhe-rs/tfhe/src/integer/gpu/server_key/radix/aes256.rs
2025-12-12 22:01:49 +01:00

379 lines
15 KiB
Rust

use crate::core_crypto::gpu::{
check_valid_cuda_malloc, check_valid_cuda_malloc_assert_oom, CudaStreams,
};
use crate::integer::gpu::ciphertext::{CudaIntegerRadixCiphertext, CudaUnsignedRadixCiphertext};
use crate::integer::gpu::server_key::{
CudaBootstrappingKey, CudaDynamicKeyswitchingKey, CudaServerKey,
};
use crate::core_crypto::prelude::LweBskGroupingFactor;
use crate::integer::gpu::{
cuda_backend_aes_key_expansion_256, cuda_backend_get_aes_key_expansion_256_size_on_gpu,
cuda_backend_unchecked_aes_ctr_256_encrypt, PBSType,
};
use crate::integer::{RadixCiphertext, RadixClientKey};
const NUM_BITS: usize = 128;
impl RadixClientKey {
pub fn encrypt_2u128_for_aes_ctr_256(&self, key_hi: u128, key_lo: u128) -> RadixCiphertext {
let ctxt_hi = self.encrypt_u128_for_aes_ctr(key_hi);
let ctxt_lo = self.encrypt_u128_for_aes_ctr(key_lo);
let mut combined_blocks = ctxt_hi.blocks;
combined_blocks.extend(ctxt_lo.blocks);
RadixCiphertext::from(combined_blocks)
}
}
impl CudaServerKey {
/// Computes homomorphically AES-256 encryption in CTR mode.
///
/// This function performs AES-256 encryption on an encrypted 128-bit IV
/// using an encrypted 256-bit key. It operates in Counter (CTR) mode, generating
/// `num_aes_inputs` encrypted ciphertexts starting from the `start_counter` value
/// (which is typically added to the IV).
///
/// The 256-bit key must be prepared using `encrypt_2u128_for_aes_ctr_256` and
/// the 128-bit IV using `encrypt_u128_for_aes_ctr`.
///
/// # Example
///
/// ```rust
/// use tfhe::core_crypto::gpu::CudaStreams;
/// use tfhe::GpuIndex;
/// use tfhe::integer::gpu::ciphertext::CudaUnsignedRadixCiphertext;
/// use tfhe::integer::gpu::gen_keys_radix_gpu;
/// use tfhe::shortint::parameters::PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128;
///
/// let gpu_index = 0;
/// let streams = CudaStreams::new_single_gpu(GpuIndex::new(gpu_index));
///
/// // Generate the client key and the server key:
/// // AES bit-wise operations require 1-block ciphertexts (for encrypting single bits).
/// let num_blocks = 1;
/// let (cks, sks) = gen_keys_radix_gpu(
/// PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128,
/// num_blocks,
/// &streams,
/// );
///
/// let key_hi: u128 = 0x603deb1015ca71be2b73aef0857d7781;
/// let key_lo: u128 = 0x1f352c073b6108d72d9810a30914dff4;
/// let iv: u128 = 0xf0f1f2f3f4f5f6f7f8f9fafbfcfdfeff;
/// let num_aes_inputs = 2; // Produce 2 128-bits ciphertexts
/// let start_counter = 0u128;
///
/// // Encrypt the 256-bit key and 128-bit IV bit by bit
/// let ct_key = cks.encrypt_2u128_for_aes_ctr_256(key_hi, key_lo);
/// let ct_iv = cks.encrypt_u128_for_aes_ctr(iv);
///
/// let d_key = CudaUnsignedRadixCiphertext::from_radix_ciphertext(&ct_key, &streams);
/// let d_iv = CudaUnsignedRadixCiphertext::from_radix_ciphertext(&ct_iv, &streams);
///
/// let d_ct_res = sks.aes_ctr_256(&d_key, &d_iv, start_counter, num_aes_inputs, &streams);
///
/// let ct_res = d_ct_res.to_radix_ciphertext(&streams);
///
/// let fhe_results = cks.decrypt_u128_from_aes_ctr(&ct_res, num_aes_inputs);
///
/// // Verify:
/// let expected_results: Vec<u128> = vec![
/// 0xbdf7df1591716335e9a8b15c860c502,
/// 0x5a6e699d536119065433863c8f657b94,
/// ];
/// assert_eq!(fhe_results, expected_results);
/// ```
pub fn aes_ctr_256(
&self,
key: &CudaUnsignedRadixCiphertext,
iv: &CudaUnsignedRadixCiphertext,
start_counter: u128,
num_aes_inputs: usize,
streams: &CudaStreams,
) -> CudaUnsignedRadixCiphertext {
let gpu_index = streams.gpu_indexes[0];
let key_expansion_size = self.get_key_expansion_256_size_on_gpu(streams);
check_valid_cuda_malloc_assert_oom(key_expansion_size, gpu_index);
// `parallelism` refers to level of parallelization of the S-box.
// S-box should process 16 bytes of data: sequentially, or in groups of 2,
// or in groups of 4, or in groups of 8, or all 16 at the same time.
// More parallelization leads to higher memory usage. Therefore, we must find a way
// to maximize parallelization while ensuring that there is still enough memory remaining on
// the GPU.
//
let mut parallelism = 16;
while parallelism > 0 {
// `num_aes_inputs` refers to the number of 128-bit ciphertexts that AES will produce.
//
let aes_encrypt_size =
self.get_aes_encrypt_size_on_gpu(num_aes_inputs, parallelism, streams);
if check_valid_cuda_malloc(aes_encrypt_size, streams.gpu_indexes[0]) {
let round_keys = self.key_expansion_256(key, streams);
let res = self.aes_256_encrypt(
iv,
&round_keys,
start_counter,
num_aes_inputs,
parallelism,
streams,
);
return res;
}
parallelism /= 2;
}
panic!("Failed to allocate GPU memory for AES, even with the lowest parallelism setting.");
}
pub fn aes_ctr_256_with_fixed_parallelism(
&self,
key: &CudaUnsignedRadixCiphertext,
iv: &CudaUnsignedRadixCiphertext,
start_counter: u128,
num_aes_inputs: usize,
sbox_parallelism: usize,
streams: &CudaStreams,
) -> CudaUnsignedRadixCiphertext {
assert!(
[1, 2, 4, 8, 16].contains(&sbox_parallelism),
"Invalid S-Box parallelism: must be one of [1, 2, 4, 8, 16], got {sbox_parallelism}"
);
let gpu_index = streams.gpu_indexes[0];
let key_expansion_size = self.get_key_expansion_256_size_on_gpu(streams);
check_valid_cuda_malloc_assert_oom(key_expansion_size, gpu_index);
let aes_encrypt_size =
self.get_aes_encrypt_size_on_gpu(num_aes_inputs, sbox_parallelism, streams);
check_valid_cuda_malloc_assert_oom(aes_encrypt_size, gpu_index);
let round_keys = self.key_expansion_256(key, streams);
self.aes_256_encrypt(
iv,
&round_keys,
start_counter,
num_aes_inputs,
sbox_parallelism,
streams,
)
}
pub fn aes_256_encrypt(
&self,
iv: &CudaUnsignedRadixCiphertext,
round_keys: &CudaUnsignedRadixCiphertext,
start_counter: u128,
num_aes_inputs: usize,
sbox_parallelism: usize,
streams: &CudaStreams,
) -> CudaUnsignedRadixCiphertext {
let mut result: CudaUnsignedRadixCiphertext =
self.create_trivial_zero_radix(num_aes_inputs * 128, streams);
let num_round_key_blocks = 15 * NUM_BITS;
assert_eq!(
iv.as_ref().d_blocks.lwe_ciphertext_count().0,
NUM_BITS,
"AES IV must contain {NUM_BITS} encrypted bits, but contains {}",
iv.as_ref().d_blocks.lwe_ciphertext_count().0
);
assert_eq!(
round_keys.as_ref().d_blocks.lwe_ciphertext_count().0,
num_round_key_blocks,
"AES round_keys must contain {num_round_key_blocks} encrypted bits, but contains {}",
round_keys.as_ref().d_blocks.lwe_ciphertext_count().0
);
assert_eq!(
result.as_ref().d_blocks.lwe_ciphertext_count().0,
num_aes_inputs * 128,
"AES result must contain {} encrypted bits for {num_aes_inputs} blocks, but contains {}",
num_aes_inputs * 128,
result.as_ref().d_blocks.lwe_ciphertext_count().0
);
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
panic!("Only the standard atomic pattern is supported on GPU")
};
unsafe {
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
cuda_backend_unchecked_aes_ctr_256_encrypt(
streams,
result.as_mut(),
iv.as_ref(),
round_keys.as_ref(),
start_counter,
num_aes_inputs as u32,
sbox_parallelism as u32,
&d_bsk.d_vec,
&computing_ks_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
d_bsk.input_lwe_dimension,
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
LweBskGroupingFactor(0),
PBSType::Classical,
d_bsk.ms_noise_reduction_configuration.as_ref(),
);
}
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
cuda_backend_unchecked_aes_ctr_256_encrypt(
streams,
result.as_mut(),
iv.as_ref(),
round_keys.as_ref(),
start_counter,
num_aes_inputs as u32,
sbox_parallelism as u32,
&d_multibit_bsk.d_vec,
&computing_ks_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
d_multibit_bsk.input_lwe_dimension,
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
d_multibit_bsk.grouping_factor,
PBSType::MultiBit,
None,
);
}
}
}
result
}
pub fn key_expansion_256(
&self,
key: &CudaUnsignedRadixCiphertext,
streams: &CudaStreams,
) -> CudaUnsignedRadixCiphertext {
let num_round_keys = 15;
let input_key_bits = 256;
let round_key_bits = 128;
let mut expanded_keys: CudaUnsignedRadixCiphertext =
self.create_trivial_zero_radix(num_round_keys * round_key_bits, streams);
assert_eq!(
key.as_ref().d_blocks.lwe_ciphertext_count().0,
input_key_bits,
"Input key must contain {} encrypted bits, but contains {}",
input_key_bits,
key.as_ref().d_blocks.lwe_ciphertext_count().0
);
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
panic!("Only the standard atomic pattern is supported on GPU")
};
unsafe {
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
cuda_backend_aes_key_expansion_256(
streams,
expanded_keys.as_mut(),
key.as_ref(),
&d_bsk.d_vec,
&computing_ks_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
d_bsk.input_lwe_dimension,
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
LweBskGroupingFactor(0),
PBSType::Classical,
d_bsk.ms_noise_reduction_configuration.as_ref(),
);
}
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
cuda_backend_aes_key_expansion_256(
streams,
expanded_keys.as_mut(),
key.as_ref(),
&d_multibit_bsk.d_vec,
&computing_ks_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
d_multibit_bsk.input_lwe_dimension,
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
d_multibit_bsk.grouping_factor,
PBSType::MultiBit,
None,
);
}
}
}
expanded_keys
}
pub fn get_key_expansion_256_size_on_gpu(&self, streams: &CudaStreams) -> u64 {
let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else {
panic!("Only the standard atomic pattern is supported on GPU")
};
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
cuda_backend_get_aes_key_expansion_256_size_on_gpu(
streams,
self.message_modulus,
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
d_bsk.input_lwe_dimension,
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
LweBskGroupingFactor(0),
PBSType::Classical,
d_bsk.ms_noise_reduction_configuration.as_ref(),
)
}
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
cuda_backend_get_aes_key_expansion_256_size_on_gpu(
streams,
self.message_modulus,
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
d_multibit_bsk.input_lwe_dimension,
computing_ks_key.decomposition_level_count(),
computing_ks_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
d_multibit_bsk.grouping_factor,
PBSType::MultiBit,
None,
)
}
}
}
}