keyswitch

This commit is contained in:
Alexandre Péré
2024-03-18 16:11:26 +01:00
parent 312f505063
commit 957b592cda

View File

@@ -606,13 +606,78 @@ void memref_keyswitch_lwe_u64(uint64_t *out_allocated, uint64_t *out_aligned,
uint32_t output_dimension, uint32_t ksk_index,
mlir::concretelang::RuntimeContext *context) {
assert(out_stride == 1 && ct0_stride == 1);
// Get keyswitch key
const uint64_t *keyswitch_key = context->keyswitch_key_buffer(ksk_index);
// Get stack parameter
concrete_cpu_keyswitch_lwe_ciphertext_u64(
out_aligned + out_offset, ct0_aligned + ct0_offset, keyswitch_key,
decomposition_level_count, decomposition_base_log, input_dimension,
output_dimension);
// The input ciphertext is a vector of u64 of size input_dimension + 1. It has the following
// memory repr:
// [a1, a2, ..., an, b] with ai the mask elements, and b the body.
uint64_t* input_ct = ct0_aligned + ct0_offset;
uint64_t input_size = input_dimension + 1;
// The output ciphertext is a vector of u64 of size output_dimension + 1. It has the same kind of
// as the input ct.
uint64_t* output_ct = out_aligned + out_offset;
uint64_t output_size = output_dimension + 1;
// We retrieve the ksk from the context. It is a three dimensional contiguous tensor with the
// following repr:
// [I1, I2, ..., Iv] with v the lwe dimension of the input key.
//
// Each Ii contains the encryptions of the decomposition of one bit of the input key. It has the
// following repr:
// [D1, D2, ..., Dl] with l the number of decomposition levels used for the ksk.
//
// Each Di is the encryption under the output key, of one level of the decomposition of the key
// bit. It has the same repr as athe :
// [D1, D2, ..., Dl] with l the number of decomposition levels used for the ksk.
const uint64_t *ksk = context->keyswitch_key_buffer(ksk_index);
// We begin by zeroing out the output ct
for (size_t i = 0; i < output_size; i++){
output_ct[i] = 0;
}
// We copy the body of the input ct to the body of the output ct
output_ct[output_size-1] = input_ct[input_size-1];
// We loop through the mask elements of the input ct and the corresponding key bits in the ksk.
for (size_t i = 0; i < input_dimension; i ++){
// We retrieve the block for the ith input key bit.
const uint64_t *ksk_block = ksk + i * decomposition_level_count * output_size;
// We retrieve the mask element
uint64_t mask_elm = input_ct[i];
// We compute the closest representable with the decomposition
uint64_t non_rep_bit_count = 64 - decomposition_level_count * decomposition_base_log;
uint64_t shift = non_rep_bit_count - 1;
uint64_t closest = mask_elm >> shift;
closest += (uint64_t) 1;
closest &= (uint64_t) - 2;
closest <<= shift;
// We initialize the decomposition
uint64_t dec_state = closest >> non_rep_bit_count;
uint64_t dec_mod_b_mask = ((uint64_t) 1 << decomposition_base_log) - (uint64_t) 1;
// We loop through the levels of the decomposition
for (size_t j = 0; j < decomposition_level_count; j++){
// We retrieve the encryption of the jth decomposition of ith bit of the input key.
const uint64_t *ksk_ct = ksk_block + j * output_size;
// We get the decomposed iterate
uint64_t decomposed = dec_state & dec_mod_b_mask;
dec_state >>= decomposition_base_log;
uint64_t carry = ((decomposed - (uint64_t) 1) | dec_state) & decomposed;
carry >>= decomposition_base_log - 1;
dec_state += carry;
decomposed -= carry << decomposition_base_log;
// We accumulate in the output ct
for (size_t k = 0; k < output_size; k++){
output_ct[k] -= ksk_ct[k] * decomposed;
}
}
}
}
void memref_batched_add_lwe_ciphertexts_u64(