concrete/compiler/lib/Runtime/wrappers.cpp

// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#include "concretelang/Runtime/wrappers.h"
#include <assert.h>
#include <stdio.h>
#include <string.h>
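
// Each memref argument below is passed using the MLIR ranked-memref calling
// convention: (allocated pointer, aligned pointer, offset, size, stride) for
// every rank-1 buffer.

// Expands a lookup table into a trivial GLWE ciphertext. Only unit strides
// are supported for the lut and glwe_ct buffers (asserted below).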
void memref_expand_lut_in_trivial_glwe_ct_u64(
    uint64_t *glwe_ct_allocated, uint64_t *glwe_ct_aligned,
    uint64_t glwe_ct_offset, uint64_t glwe_ct_size, uint64_t glwe_ct_stride,
    uint32_t poly_size, uint32_t glwe_dimension, uint32_t out_precision,
    uint64_t *lut_allocated, uint64_t *lut_aligned, uint64_t lut_offset,
    uint64_t lut_size, uint64_t lut_stride) {
  assert(lut_stride == 1 && "Runtime: stride not equal to 1, check "
                            "memref_expand_lut_in_trivial_glwe_ct_u64");
  assert(glwe_ct_stride == 1 && "Runtime: stride not equal to 1, check "
                                "memref_expand_lut_in_trivial_glwe_ct_u64");
  expand_lut_in_trivial_glwe_ct_u64(glwe_ct_aligned, poly_size, glwe_dimension,
                                    out_precision, lut_aligned, lut_size);
  return;
}
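
// Adds two LWE ciphertexts: out = ct0 + ct1. All three buffers must have the
// same size; the LWE dimension is the buffer size minus one.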
void memref_add_lwe_ciphertexts_u64(
    uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
    uint64_t out_size, uint64_t out_stride, uint64_t *ct0_allocated,
    uint64_t *ct0_aligned, uint64_t ct0_offset, uint64_t ct0_size,
    uint64_t ct0_stride, uint64_t *ct1_allocated, uint64_t *ct1_aligned,
    uint64_t ct1_offset, uint64_t ct1_size, uint64_t ct1_stride) {
  assert(out_size == ct0_size && out_size == ct1_size &&
         "size of lwe buffer are incompatible");
  size_t lwe_dimension = {out_size - 1};
  add_two_lwe_ciphertexts_u64(out_aligned + out_offset,
                              ct0_aligned + ct0_offset,
                              ct1_aligned + ct1_offset, lwe_dimension);
}
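
// Adds a plaintext constant to an LWE ciphertext: out = ct0 + plaintext.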
void memref_add_plaintext_lwe_ciphertext_u64(
    uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
    uint64_t out_size, uint64_t out_stride, uint64_t *ct0_allocated,
    uint64_t *ct0_aligned, uint64_t ct0_offset, uint64_t ct0_size,
    uint64_t ct0_stride, uint64_t plaintext) {
  assert(out_size == ct0_size && "size of lwe buffer are incompatible");
  size_t lwe_dimension = {out_size - 1};
  add_plaintext_to_lwe_ciphertext_u64(out_aligned + out_offset,
                                      ct0_aligned + ct0_offset, plaintext,
                                      lwe_dimension);
}
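
// Multiplies an LWE ciphertext by a cleartext constant: out = ct0 * cleartext.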
void memref_mul_cleartext_lwe_ciphertext_u64(
    uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
    uint64_t out_size, uint64_t out_stride, uint64_t *ct0_allocated,
    uint64_t *ct0_aligned, uint64_t ct0_offset, uint64_t ct0_size,
    uint64_t ct0_stride, uint64_t cleartext) {
  assert(out_size == ct0_size && "size of lwe buffer are incompatible");
  size_t lwe_dimension = {out_size - 1};
  mul_cleartext_lwe_ciphertext_u64(out_aligned + out_offset,
                                   ct0_aligned + ct0_offset, cleartext,
                                   lwe_dimension);
}
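
// Negates an LWE ciphertext: out = -ct0.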
void memref_negate_lwe_ciphertext_u64(
    uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
    uint64_t out_size, uint64_t out_stride, uint64_t *ct0_allocated,
    uint64_t *ct0_aligned, uint64_t ct0_offset, uint64_t ct0_size,
    uint64_t ct0_stride) {
  assert(out_size == ct0_size && "size of lwe buffer are incompatible");
  size_t lwe_dimension = {out_size - 1};
  neg_lwe_ciphertext_u64(out_aligned + out_offset, ct0_aligned + ct0_offset,
                         lwe_dimension);
}
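
// Key switches an LWE ciphertext, using the keyswitch key held by the
// runtime context.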
void memref_keyswitch_lwe_u64(uint64_t *out_allocated, uint64_t *out_aligned,
                              uint64_t out_offset, uint64_t out_size,
                              uint64_t out_stride, uint64_t *ct0_allocated,
                              uint64_t *ct0_aligned, uint64_t ct0_offset,
                              uint64_t ct0_size, uint64_t ct0_stride,
                              mlir::concretelang::RuntimeContext *context) {
  keyswitch_lwe_u64(get_engine(context), get_keyswitch_key_u64(context),
                    out_aligned + out_offset, ct0_aligned + ct0_offset);
}
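
// Bootstraps an LWE ciphertext against the lookup table stored in the GLWE
// ciphertext, using the bootstrap key held by the runtime context.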
void memref_bootstrap_lwe_u64(
    uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
    uint64_t out_size, uint64_t out_stride, uint64_t *ct0_allocated,
    uint64_t *ct0_aligned, uint64_t ct0_offset, uint64_t ct0_size,
    uint64_t ct0_stride, uint64_t *glwe_ct_allocated, uint64_t *glwe_ct_aligned,
    uint64_t glwe_ct_offset, uint64_t glwe_ct_size, uint64_t glwe_ct_stride,
    mlir::concretelang::RuntimeContext *context) {
  bootstrap_lwe_u64(get_engine(context), get_bootstrap_key_u64(context),
                    out_aligned + out_offset, ct0_aligned + ct0_offset,
                    glwe_ct_aligned + glwe_ct_offset);
}
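
// Copies a rank-1 memref from src to dst; both views must have the same size.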
void memref_copy_one_rank(uint64_t *src_allocated, uint64_t *src_aligned,
                          uint64_t src_offset, uint64_t src_size,
                          uint64_t src_stride, uint64_t *dst_allocated,
                          uint64_t *dst_aligned, uint64_t dst_offset,
                          uint64_t dst_size, uint64_t dst_stride) {
  assert(src_size == dst_size && "memref_copy_one_rank size differs");
  // Fast path: memcpy is only valid when both views are contiguous.
  if (src_stride == 1 && dst_stride == 1) {
    memcpy(dst_aligned + dst_offset, src_aligned + src_offset,
           src_size * sizeof(uint64_t));
    return;
  }
  // General case: element-wise copy honoring both strides.
  for (size_t i = 0; i < src_size; i++) {
    dst_aligned[dst_offset + i * dst_stride] =
        src_aligned[src_offset + i * src_stride];
  }
}