mirror of
https://github.com/zama-ai/concrete.git
synced 2026-04-17 03:00:54 -04:00
keyswitch
This commit is contained in:
@@ -606,13 +606,78 @@ void memref_keyswitch_lwe_u64(uint64_t *out_allocated, uint64_t *out_aligned,
|
||||
uint32_t output_dimension, uint32_t ksk_index,
|
||||
mlir::concretelang::RuntimeContext *context) {
|
||||
assert(out_stride == 1 && ct0_stride == 1);
|
||||
// Get keyswitch key
|
||||
const uint64_t *keyswitch_key = context->keyswitch_key_buffer(ksk_index);
|
||||
// Get stack parameter
|
||||
concrete_cpu_keyswitch_lwe_ciphertext_u64(
|
||||
out_aligned + out_offset, ct0_aligned + ct0_offset, keyswitch_key,
|
||||
decomposition_level_count, decomposition_base_log, input_dimension,
|
||||
output_dimension);
|
||||
|
||||
// The input ciphertext is a vector of u64 of size input_dimension + 1. It has the following
|
||||
// memory repr:
|
||||
// [a1, a2, ..., an, b] with ai the mask elements, and b the body.
|
||||
uint64_t* input_ct = ct0_aligned + ct0_offset;
|
||||
uint64_t input_size = input_dimension + 1;
|
||||
|
||||
// The output ciphertext is a vector of u64 of size output_dimension + 1. It has the same kind of
|
||||
// as the input ct.
|
||||
uint64_t* output_ct = out_aligned + out_offset;
|
||||
uint64_t output_size = output_dimension + 1;
|
||||
|
||||
// We retrieve the ksk from the context. It is a three dimensional contiguous tensor with the
|
||||
// following repr:
|
||||
// [I1, I2, ..., Iv] with v the lwe dimension of the input key.
|
||||
//
|
||||
// Each Ii contains the encryptions of the decomposition of one bit of the input key. It has the
|
||||
// following repr:
|
||||
// [D1, D2, ..., Dl] with l the number of decomposition levels used for the ksk.
|
||||
//
|
||||
// Each Di is the encryption under the output key, of one level of the decomposition of the key
|
||||
// bit. It has the same repr as athe :
|
||||
// [D1, D2, ..., Dl] with l the number of decomposition levels used for the ksk.
|
||||
const uint64_t *ksk = context->keyswitch_key_buffer(ksk_index);
|
||||
|
||||
// We begin by zeroing out the output ct
|
||||
for (size_t i = 0; i < output_size; i++){
|
||||
output_ct[i] = 0;
|
||||
}
|
||||
|
||||
// We copy the body of the input ct to the body of the output ct
|
||||
output_ct[output_size-1] = input_ct[input_size-1];
|
||||
|
||||
// We loop through the mask elements of the input ct and the corresponding key bits in the ksk.
|
||||
for (size_t i = 0; i < input_dimension; i ++){
|
||||
// We retrieve the block for the ith input key bit.
|
||||
const uint64_t *ksk_block = ksk + i * decomposition_level_count * output_size;
|
||||
// We retrieve the mask element
|
||||
uint64_t mask_elm = input_ct[i];
|
||||
|
||||
// We compute the closest representable with the decomposition
|
||||
uint64_t non_rep_bit_count = 64 - decomposition_level_count * decomposition_base_log;
|
||||
uint64_t shift = non_rep_bit_count - 1;
|
||||
uint64_t closest = mask_elm >> shift;
|
||||
closest += (uint64_t) 1;
|
||||
closest &= (uint64_t) - 2;
|
||||
closest <<= shift;
|
||||
|
||||
// We initialize the decomposition
|
||||
uint64_t dec_state = closest >> non_rep_bit_count;
|
||||
uint64_t dec_mod_b_mask = ((uint64_t) 1 << decomposition_base_log) - (uint64_t) 1;
|
||||
|
||||
// We loop through the levels of the decomposition
|
||||
for (size_t j = 0; j < decomposition_level_count; j++){
|
||||
|
||||
// We retrieve the encryption of the jth decomposition of ith bit of the input key.
|
||||
const uint64_t *ksk_ct = ksk_block + j * output_size;
|
||||
|
||||
// We get the decomposed iterate
|
||||
uint64_t decomposed = dec_state & dec_mod_b_mask;
|
||||
dec_state >>= decomposition_base_log;
|
||||
uint64_t carry = ((decomposed - (uint64_t) 1) | dec_state) & decomposed;
|
||||
carry >>= decomposition_base_log - 1;
|
||||
dec_state += carry;
|
||||
decomposed -= carry << decomposition_base_log;
|
||||
|
||||
// We accumulate in the output ct
|
||||
for (size_t k = 0; k < output_size; k++){
|
||||
output_ct[k] -= ksk_ct[k] * decomposed;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void memref_batched_add_lwe_ciphertexts_u64(
|
||||
|
||||
Reference in New Issue
Block a user