diff --git a/compilers/concrete-compiler/compiler/lib/Runtime/wrappers.cpp b/compilers/concrete-compiler/compiler/lib/Runtime/wrappers.cpp index 4f3dd8729..514045d02 100644 --- a/compilers/concrete-compiler/compiler/lib/Runtime/wrappers.cpp +++ b/compilers/concrete-compiler/compiler/lib/Runtime/wrappers.cpp @@ -606,13 +606,78 @@ void memref_keyswitch_lwe_u64(uint64_t *out_allocated, uint64_t *out_aligned, uint32_t output_dimension, uint32_t ksk_index, mlir::concretelang::RuntimeContext *context) { assert(out_stride == 1 && ct0_stride == 1); - // Get keyswitch key - const uint64_t *keyswitch_key = context->keyswitch_key_buffer(ksk_index); - // Get stack parameter - concrete_cpu_keyswitch_lwe_ciphertext_u64( - out_aligned + out_offset, ct0_aligned + ct0_offset, keyswitch_key, - decomposition_level_count, decomposition_base_log, input_dimension, - output_dimension); + + // The input ciphertext is a vector of u64 of size input_dimension + 1. It has the following + // memory repr: + // [a1, a2, ..., an, b] with ai the mask elements, and b the body. + uint64_t* input_ct = ct0_aligned + ct0_offset; + uint64_t input_size = input_dimension + 1; + + // The output ciphertext is a vector of u64 of size output_dimension + 1. It has the same kind of + // as the input ct. + uint64_t* output_ct = out_aligned + out_offset; + uint64_t output_size = output_dimension + 1; + + // We retrieve the ksk from the context. It is a three dimensional contiguous tensor with the + // following repr: + // [I1, I2, ..., Iv] with v the lwe dimension of the input key. + // + // Each Ii contains the encryptions of the decomposition of one bit of the input key. It has the + // following repr: + // [D1, D2, ..., Dl] with l the number of decomposition levels used for the ksk. + // + // Each Di is the encryption under the output key, of one level of the decomposition of the key + // bit. It has the same repr as athe : + // [D1, D2, ..., Dl] with l the number of decomposition levels used for the ksk. + const uint64_t *ksk = context->keyswitch_key_buffer(ksk_index); + + // We begin by zeroing out the output ct + for (size_t i = 0; i < output_size; i++){ + output_ct[i] = 0; + } + + // We copy the body of the input ct to the body of the output ct + output_ct[output_size-1] = input_ct[input_size-1]; + + // We loop through the mask elements of the input ct and the corresponding key bits in the ksk. + for (size_t i = 0; i < input_dimension; i ++){ + // We retrieve the block for the ith input key bit. + const uint64_t *ksk_block = ksk + i * decomposition_level_count * output_size; + // We retrieve the mask element + uint64_t mask_elm = input_ct[i]; + + // We compute the closest representable with the decomposition + uint64_t non_rep_bit_count = 64 - decomposition_level_count * decomposition_base_log; + uint64_t shift = non_rep_bit_count - 1; + uint64_t closest = mask_elm >> shift; + closest += (uint64_t) 1; + closest &= (uint64_t) - 2; + closest <<= shift; + + // We initialize the decomposition + uint64_t dec_state = closest >> non_rep_bit_count; + uint64_t dec_mod_b_mask = ((uint64_t) 1 << decomposition_base_log) - (uint64_t) 1; + + // We loop through the levels of the decomposition + for (size_t j = 0; j < decomposition_level_count; j++){ + + // We retrieve the encryption of the jth decomposition of ith bit of the input key. + const uint64_t *ksk_ct = ksk_block + j * output_size; + + // We get the decomposed iterate + uint64_t decomposed = dec_state & dec_mod_b_mask; + dec_state >>= decomposition_base_log; + uint64_t carry = ((decomposed - (uint64_t) 1) | dec_state) & decomposed; + carry >>= decomposition_base_log - 1; + dec_state += carry; + decomposed -= carry << decomposition_base_log; + + // We accumulate in the output ct + for (size_t k = 0; k < output_size; k++){ + output_ct[k] -= ksk_ct[k] * decomposed; + } + } + } } void memref_batched_add_lwe_ciphertexts_u64(