enhance(runtime): Fix the runtime tools to handle the ciphertext bufferization and the new compiler concrete bufferized API

This commit is contained in:
Quentin Bourgerie
2022-02-11 14:29:28 +01:00
committed by Quentin Bourgerie
parent 8a9cce64e3
commit 4e8a9d1077
12 changed files with 312 additions and 223 deletions

View File

@@ -12,9 +12,9 @@ pip install pybind11
Build concrete library:
```sh
git clone https://github.com/zama-ai/concrete
cd concrete
git checkout feature/core_c_api
git clone https://github.com/zama-ai/concrete_internal
cd concrete_internal
git checkout compiler_c_api
cd concrete-ffi
RUSTFLAGS="-C target-cpu=native" cargo build --release
```
@@ -23,7 +23,7 @@ Generate the compiler build system, in the `build` directory
```sh
export LLVM_PROJECT="PATH_TO_LLVM_PROJECT"
export CONCRETE_PROJECT="PATH_TO_CONCRETE_PROJECT"
export CONCRETE_PROJECT="PATH_TO_CONCRETE_INTERNAL_PROJECT"
make build-initialized
```

View File

@@ -37,16 +37,36 @@ public:
// isInputEncrypted return true if the input at the given pos is encrypted.
bool isInputEncrypted(size_t pos);
// allocate a lwe ciphertext for the argument at argPos.
llvm::Error allocate_lwe(size_t argPos, LweCiphertext_u64 **ciphertext);
// getInputLweSecretKeyParam returns the parameters of the lwe secret key for
// the input at the given `pos`.
// The input must be encrupted
LweSecretKeyParam getInputLweSecretKeyParam(size_t pos) {
auto gate = inputGate(pos);
auto inputSk = this->secretKeys.find(gate.encryption->secretKeyID);
return inputSk->second.first;
}
// getOutputLweSecretKeyParam returns the parameters of the lwe secret key for
// the given output.
LweSecretKeyParam getOutputLweSecretKeyParam(size_t pos) {
auto gate = outputGate(pos);
auto outputSk = this->secretKeys.find(gate.encryption->secretKeyID);
return outputSk->second.first;
}
// allocate a lwe ciphertext buffer for the argument at argPos, set the size
// of the allocated buffer.
llvm::Error allocate_lwe(size_t argPos, uint64_t **ciphertext,
uint64_t &size);
// encrypt the input to the ciphertext for the argument at argPos.
llvm::Error encrypt_lwe(size_t argPos, LweCiphertext_u64 *ciphertext,
uint64_t input);
llvm::Error encrypt_lwe(size_t argPos, uint64_t *ciphertext, uint64_t input);
// isOuputEncrypted return true if the output at the given pos is encrypted.
bool isOutputEncrypted(size_t pos);
// decrypt the ciphertext to the output for the argument at argPos.
llvm::Error decrypt_lwe(size_t argPos, LweCiphertext_u64 *ciphertext,
llvm::Error decrypt_lwe(size_t argPos, uint64_t *ciphertext,
uint64_t &output);
size_t numInputs() { return inputs.size(); }

View File

@@ -20,10 +20,9 @@ typedef struct RuntimeContext {
std::map<std::string, LweBootstrapKey_u64 *> bsk;
~RuntimeContext() {
int err;
for (const auto &key : bsk) {
if (key.first != "_concretelang_base_context_bsk")
free_lwe_bootstrap_key_u64(&err, key.second);
free_lwe_bootstrap_key_u64(key.second);
}
}
} RuntimeContext;

View File

@@ -8,10 +8,48 @@
#include "concrete-ffi.h"
ForeignPlaintextList_u64 *
runtime_foreign_plaintext_list_u64(int *err, uint64_t *allocated,
uint64_t *aligned, uint64_t offset,
uint64_t size_dim0, uint64_t stride_dim0,
uint64_t size, uint32_t precision);
struct ForeignPlaintextList_u64 *memref_runtime_foreign_plaintext_list_u64(
uint64_t *allocated, uint64_t *aligned, uint64_t offset, uint64_t size,
uint64_t stride, uint32_t precision);
void memref_add_lwe_ciphertexts_u64(
uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
uint64_t out_size, uint64_t out_stride, uint64_t *ct0_allocated,
uint64_t *ct0_aligned, uint64_t ct0_offset, uint64_t ct0_size,
uint64_t ct0_stride, uint64_t *ct1_allocated, uint64_t *ct1_aligned,
uint64_t ct1_offset, uint64_t ct1_size, uint64_t ct1_stride);
void memref_add_plaintext_lwe_ciphertext_u64(
uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
uint64_t out_size, uint64_t out_stride, uint64_t *ct0_allocated,
uint64_t *ct0_aligned, uint64_t ct0_offset, uint64_t ct0_size,
uint64_t ct0_stride, uint64_t plaintext);
void memref_mul_cleartext_lwe_ciphertext_u64(
uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
uint64_t out_size, uint64_t out_stride, uint64_t *ct0_allocated,
uint64_t *ct0_aligned, uint64_t ct0_offset, uint64_t ct0_size,
uint64_t ct0_stride, uint64_t cleartext);
void memref_negate_lwe_ciphertext_u64(
uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
uint64_t out_size, uint64_t out_stride, uint64_t *ct0_allocated,
uint64_t *ct0_aligned, uint64_t ct0_offset, uint64_t ct0_size,
uint64_t ct0_stride);
void memref_keyswitch_lwe_u64(struct LweKeyswitchKey_u64 *keyswitch_key,
uint64_t *out_allocated, uint64_t *out_aligned,
uint64_t out_offset, uint64_t out_size,
uint64_t out_stride, uint64_t *ct0_allocated,
uint64_t *ct0_aligned, uint64_t ct0_offset,
uint64_t ct0_size, uint64_t ct0_stride);
void memref_bootstrap_lwe_u64(struct LweBootstrapKey_u64 *bootstrap_key,
uint64_t *out_allocated, uint64_t *out_aligned,
uint64_t out_offset, uint64_t out_size,
uint64_t out_stride, uint64_t *ct0_allocated,
uint64_t *ct0_aligned, uint64_t ct0_offset,
uint64_t ct0_size, uint64_t ct0_stride,
struct GlweCiphertext_u64 *accumulator);
#endif

View File

@@ -96,13 +96,13 @@ public:
// Store the values of outputs
std::vector<void *> outputs;
// Store the input gates description and the offset of the argument.
std::vector<std::tuple<CircuitGate, size_t /*offet*/>> inputGates;
std::vector<std::tuple<CircuitGate, size_t /*offset*/>> inputGates;
// Store the outputs gates description and the offset of the argument.
std::vector<std::tuple<CircuitGate, size_t /*offet*/>> outputGates;
std::vector<std::tuple<CircuitGate, size_t /*offset*/>> outputGates;
// Store allocated lwe ciphertexts (for free)
std::vector<LweCiphertext_u64 *> allocatedCiphertexts;
std::vector<uint64_t *> allocatedCiphertexts;
// Store buffers of ciphertexts
std::vector<LweCiphertext_u64 **> ciphertextBuffers;
std::vector<uint64_t *> ciphertextBuffers;
KeySet &keySet;
RuntimeContext context;

View File

@@ -6,31 +6,20 @@
#include "concretelang/ClientLib/KeySet.h"
#include "concretelang/Support/Error.h"
#define CAPI_ERR_TO_LLVM_ERROR(s, msg) \
{ \
int err; \
s; \
if (err != 0) { \
return llvm::make_error<llvm::StringError>( \
msg, llvm::inconvertibleErrorCode()); \
} \
}
namespace mlir {
namespace concretelang {
KeySet::~KeySet() {
int err;
for (auto it : secretKeys) {
free_lwe_secret_key_u64(&err, it.second.second);
free_lwe_secret_key_u64(it.second.second);
}
for (auto it : bootstrapKeys) {
free_lwe_bootstrap_key_u64(&err, it.second.second);
free_lwe_bootstrap_key_u64(it.second.second);
}
for (auto it : keyswitchKeys) {
free_lwe_keyswitch_key_u64(&err, it.second.second);
free_lwe_keyswitch_key_u64(it.second.second);
}
free_encryption_generator(&err, encryptionRandomGenerator);
free_encryption_generator(encryptionRandomGenerator);
}
llvm::Expected<std::unique_ptr<KeySet>>
@@ -97,10 +86,8 @@ llvm::Error KeySet::setupEncryptionMaterial(ClientParameters &params,
}
}
CAPI_ERR_TO_LLVM_ERROR(
this->encryptionRandomGenerator =
allocate_encryption_generator(&err, seed_msb, seed_lsb),
"cannot allocate encryption generator");
this->encryptionRandomGenerator =
allocate_encryption_generator(seed_msb, seed_lsb);
return llvm::Error::success();
}
@@ -108,13 +95,11 @@ llvm::Error KeySet::setupEncryptionMaterial(ClientParameters &params,
llvm::Error KeySet::generateKeysFromParams(ClientParameters &params,
uint64_t seed_msb,
uint64_t seed_lsb) {
{
// Generate LWE secret keys
SecretRandomGenerator *generator;
CAPI_ERR_TO_LLVM_ERROR(
generator = allocate_secret_generator(&err, seed_msb, seed_lsb),
"cannot allocate random generator");
generator = allocate_secret_generator(seed_msb, seed_lsb);
for (auto secretKeyParam : params.secretKeys) {
auto e = this->generateSecretKey(secretKeyParam.first,
secretKeyParam.second, generator);
@@ -122,14 +107,12 @@ llvm::Error KeySet::generateKeysFromParams(ClientParameters &params,
return std::move(e);
}
}
CAPI_ERR_TO_LLVM_ERROR(free_secret_generator(&err, generator),
"cannot free random generator");
free_secret_generator(generator);
}
// Allocate the encryption random generator
CAPI_ERR_TO_LLVM_ERROR(
this->encryptionRandomGenerator =
allocate_encryption_generator(&err, seed_msb, seed_lsb),
"cannot allocate encryption generator");
this->encryptionRandomGenerator =
allocate_encryption_generator(seed_msb, seed_lsb);
// Generate bootstrap and keyswitch keys
{
for (auto bootstrapKeyParam : params.bootstrapKeys) {
@@ -170,12 +153,9 @@ llvm::Error KeySet::generateSecretKey(LweSecretKeyID id,
LweSecretKeyParam param,
SecretRandomGenerator *generator) {
LweSecretKey_u64 *sk;
CAPI_ERR_TO_LLVM_ERROR(
sk = allocate_lwe_secret_key_u64(&err, {param.size + 1}),
"cannot allocate secret key");
sk = allocate_lwe_secret_key_u64({param.size});
CAPI_ERR_TO_LLVM_ERROR(fill_lwe_secret_key_u64(&err, sk, generator),
"cannot fill secret key with random generator");
fill_lwe_secret_key_u64(sk, generator);
secretKeys[id] = {param, sk};
@@ -207,11 +187,9 @@ llvm::Error KeySet::generateBootstrapKey(BootstrapKeyID id,
uint64_t polynomialSize = total_dimension / param.glweDimension;
CAPI_ERR_TO_LLVM_ERROR(
bsk = allocate_lwe_bootstrap_key_u64(
&err, {param.level}, {param.baseLog}, {param.glweDimension + 1},
{inputSk->second.first.size + 1}, {polynomialSize}),
"cannot allocate bootstrap key");
bsk = allocate_lwe_bootstrap_key_u64(
{param.level}, {param.baseLog}, {param.glweDimension},
{inputSk->second.first.size}, {polynomialSize});
// Store the bootstrap key
bootstrapKeys[id] = {param, bsk};
@@ -219,23 +197,16 @@ llvm::Error KeySet::generateBootstrapKey(BootstrapKeyID id,
// Convert the output lwe key to glwe key
GlweSecretKey_u64 *glwe_sk;
CAPI_ERR_TO_LLVM_ERROR(
glwe_sk = allocate_glwe_secret_key_u64(&err, {param.glweDimension + 1},
{polynomialSize}),
"cannot allocate glwe key for initiliazation of bootstrap key");
glwe_sk =
allocate_glwe_secret_key_u64({param.glweDimension}, {polynomialSize});
CAPI_ERR_TO_LLVM_ERROR(fill_glwe_secret_key_with_lwe_secret_key_u64(
&err, glwe_sk, outputSk->second.second),
"cannot fill glwe key with big key");
fill_glwe_secret_key_with_lwe_secret_key_u64(glwe_sk,
outputSk->second.second);
// Initialize the bootstrap key
CAPI_ERR_TO_LLVM_ERROR(
fill_lwe_bootstrap_key_u64(&err, bsk, inputSk->second.second, glwe_sk,
generator, {param.variance}),
"cannot fill bootstrap key");
CAPI_ERR_TO_LLVM_ERROR(
free_glwe_secret_key_u64(&err, glwe_sk),
"cannot free glwe key for initiliazation of bootstrap key")
fill_lwe_bootstrap_key_u64(bsk, inputSk->second.second, glwe_sk, generator,
{param.variance});
free_glwe_secret_key_u64(glwe_sk);
return llvm::Error::success();
}
@@ -257,33 +228,32 @@ llvm::Error KeySet::generateKeyswitchKey(KeyswitchKeyID id,
}
// Allocate the keyswitch key
LweKeyswitchKey_u64 *ksk;
CAPI_ERR_TO_LLVM_ERROR(
ksk = allocate_lwe_keyswitch_key_u64(&err, {param.level}, {param.baseLog},
{inputSk->second.first.size + 1},
{outputSk->second.first.size + 1}),
"cannot allocate keyswitch key");
ksk = allocate_lwe_keyswitch_key_u64({param.level}, {param.baseLog},
{inputSk->second.first.size},
{outputSk->second.first.size});
// Store the keyswitch key
keyswitchKeys[id] = {param, ksk};
// Initialize the keyswitch key
CAPI_ERR_TO_LLVM_ERROR(
fill_lwe_keyswitch_key_u64(&err, ksk, inputSk->second.second,
outputSk->second.second, generator,
{param.variance}),
"cannot fill bootsrap key");
fill_lwe_keyswitch_key_u64(ksk, inputSk->second.second,
outputSk->second.second, generator,
{param.variance});
return llvm::Error::success();
}
llvm::Error KeySet::allocate_lwe(size_t argPos,
LweCiphertext_u64 **ciphertext) {
llvm::Error KeySet::allocate_lwe(size_t argPos, uint64_t **ciphertext,
uint64_t &size) {
if (argPos >= inputs.size()) {
return llvm::make_error<llvm::StringError>(
"allocate_lwe position of argument is too high",
llvm::inconvertibleErrorCode());
}
auto inputSk = inputs[argPos];
CAPI_ERR_TO_LLVM_ERROR(*ciphertext = allocate_lwe_ciphertext_u64(
&err, {std::get<1>(inputSk).size + 1}),
"cannot allocate ciphertext");
size = std::get<1>(inputSk).size + 1;
*ciphertext = (uint64_t *)malloc(sizeof(uint64_t) * size);
return llvm::Error::success();
}
@@ -297,7 +267,7 @@ bool KeySet::isOutputEncrypted(size_t argPos) {
std::get<0>(outputs[argPos]).encryption.hasValue();
}
llvm::Error KeySet::encrypt_lwe(size_t argPos, LweCiphertext_u64 *ciphertext,
llvm::Error KeySet::encrypt_lwe(size_t argPos, uint64_t *ciphertext,
uint64_t input) {
if (argPos >= inputs.size()) {
return llvm::make_error<llvm::StringError>(
@@ -311,19 +281,15 @@ llvm::Error KeySet::encrypt_lwe(size_t argPos, LweCiphertext_u64 *ciphertext,
llvm::inconvertibleErrorCode());
}
// Encode - TODO we could check if the input value is in the right range
Plaintext_u64 plaintext = {
input << (64 -
(std::get<0>(inputSk).encryption->encoding.precision + 1))};
// Encrypt
CAPI_ERR_TO_LLVM_ERROR(
encrypt_lwe_u64(&err, std::get<2>(inputSk), ciphertext, plaintext,
encryptionRandomGenerator,
{std::get<0>(inputSk).encryption->variance}),
"cannot encrypt");
uint64_t plaintext =
input << (64 - (std::get<0>(inputSk).encryption->encoding.precision + 1));
encrypt_lwe_u64(std::get<2>(inputSk), ciphertext, plaintext,
encryptionRandomGenerator,
{std::get<0>(inputSk).encryption->variance});
return llvm::Error::success();
}
llvm::Error KeySet::decrypt_lwe(size_t argPos, LweCiphertext_u64 *ciphertext,
llvm::Error KeySet::decrypt_lwe(size_t argPos, uint64_t *ciphertext,
uint64_t &output) {
if (argPos >= outputs.size()) {
@@ -337,14 +303,10 @@ llvm::Error KeySet::decrypt_lwe(size_t argPos, LweCiphertext_u64 *ciphertext,
"decrypt_lwe: the positional argument is not encrypted",
llvm::inconvertibleErrorCode());
}
// Decrypt
Plaintext_u64 plaintext = {0};
CAPI_ERR_TO_LLVM_ERROR(
decrypt_lwe_u64(&err, std::get<2>(outputSk), ciphertext, &plaintext),
"cannot decrypt");
uint64_t plaintext = decrypt_lwe_u64(std::get<2>(outputSk), ciphertext);
// Decode
size_t precision = std::get<0>(outputSk).encryption->encoding.precision;
output = plaintext._0 >> (64 - precision - 2);
output = plaintext >> (64 - precision - 2);
size_t carry = output % 2;
output = ((output >> 1) + carry) % (1 << (precision + 1));
return llvm::Error::success();

View File

@@ -142,13 +142,10 @@ mlir::LogicalResult insertForwardDeclarations(mlir::Operation *op,
auto contextType =
mlir::concretelang::Concrete::ContextType::get(rewriter.getContext());
auto errType = mlir::IndexType::get(rewriter.getContext());
// Insert forward declaration of allocate lwe ciphertext
{
auto funcType = mlir::FunctionType::get(rewriter.getContext(),
{
errType,
rewriter.getIndexType(),
},
@@ -163,7 +160,6 @@ mlir::LogicalResult insertForwardDeclarations(mlir::Operation *op,
{
auto funcType = mlir::FunctionType::get(rewriter.getContext(),
{
errType,
genericLweCiphertextType,
genericLweCiphertextType,
genericLweCiphertextType,
@@ -179,7 +175,6 @@ mlir::LogicalResult insertForwardDeclarations(mlir::Operation *op,
{
auto funcType = mlir::FunctionType::get(rewriter.getContext(),
{
errType,
genericLweCiphertextType,
genericLweCiphertextType,
genericPlaintextType,
@@ -195,7 +190,6 @@ mlir::LogicalResult insertForwardDeclarations(mlir::Operation *op,
{
auto funcType = mlir::FunctionType::get(rewriter.getContext(),
{
errType,
genericLweCiphertextType,
genericLweCiphertextType,
genericCleartextType,
@@ -211,7 +205,7 @@ mlir::LogicalResult insertForwardDeclarations(mlir::Operation *op,
{
auto funcType = mlir::FunctionType::get(
rewriter.getContext(),
{errType, genericLweCiphertextType, genericLweCiphertextType}, {});
{genericLweCiphertextType, genericLweCiphertextType}, {});
if (insertForwardDeclaration(op, rewriter, "negate_lwe_ciphertext_u64",
funcType)
.failed()) {
@@ -231,7 +225,6 @@ mlir::LogicalResult insertForwardDeclarations(mlir::Operation *op,
{
auto funcType = mlir::FunctionType::get(rewriter.getContext(),
{
errType,
genericBSKType,
genericLweCiphertextType,
genericLweCiphertextType,
@@ -256,7 +249,6 @@ mlir::LogicalResult insertForwardDeclarations(mlir::Operation *op,
{
auto funcType = mlir::FunctionType::get(rewriter.getContext(),
{
errType,
// ksk
genericKSKType,
// output ct
@@ -274,7 +266,6 @@ mlir::LogicalResult insertForwardDeclarations(mlir::Operation *op,
{
auto funcType = mlir::FunctionType::get(rewriter.getContext(),
{
errType,
rewriter.getI32Type(),
rewriter.getI32Type(),
},
@@ -287,9 +278,9 @@ mlir::LogicalResult insertForwardDeclarations(mlir::Operation *op,
}
// Insert forward declaration of the alloc_plaintext_list function
{
auto funcType = mlir::FunctionType::get(rewriter.getContext(),
{errType, rewriter.getI32Type()},
{genericPlaintextListType});
auto funcType =
mlir::FunctionType::get(rewriter.getContext(), {rewriter.getI32Type()},
{genericPlaintextListType});
if (insertForwardDeclaration(op, rewriter, "allocate_plaintext_list_u64",
funcType)
.failed()) {
@@ -300,7 +291,7 @@ mlir::LogicalResult insertForwardDeclarations(mlir::Operation *op,
{
auto funcType = mlir::FunctionType::get(
rewriter.getContext(),
{errType, genericPlaintextListType, genericForeignPlaintextList}, {});
{genericPlaintextListType, genericForeignPlaintextList}, {});
if (insertForwardDeclaration(
op, rewriter, "fill_plaintext_list_with_expansion_u64", funcType)
.failed()) {
@@ -310,7 +301,7 @@ mlir::LogicalResult insertForwardDeclarations(mlir::Operation *op,
// Insert forward declaration of the add_plaintext_list_glwe function
{
auto funcType = mlir::FunctionType::get(rewriter.getContext(),
{errType, genericGlweCiphertextType,
{genericGlweCiphertextType,
genericGlweCiphertextType,
genericPlaintextListType},
{});
@@ -356,24 +347,20 @@ struct ConcreteOpToConcreteCAPICallPattern : public mlir::OpRewritePattern<Op> {
resultType.cast<mlir::concretelang::Concrete::LweCiphertextType>();
// Replace the operation with a call to the `funcName`
{
// Create the err value
auto errOp = rewriter.create<mlir::arith::ConstantOp>(
op.getLoc(), rewriter.getIndexAttr(0));
// Get the size from the dimension
int64_t lweDimension = lweResultType.getDimension();
int64_t lweSize = lweDimension + 1;
mlir::Value lweSizeOp = rewriter.create<mlir::arith::ConstantOp>(
op.getLoc(), rewriter.getIndexAttr(lweSize));
mlir::Value lweDimensionOp = rewriter.create<mlir::arith::ConstantOp>(
op.getLoc(), rewriter.getIndexAttr(lweDimension));
// Add the call to the allocation
mlir::SmallVector<mlir::Value> allocOperands{errOp, lweSizeOp};
mlir::SmallVector<mlir::Value> allocOperands{lweDimensionOp};
auto allocGeneric = rewriter.create<mlir::CallOp>(
op.getLoc(), allocName,
getGenericLweCiphertextType(rewriter.getContext()), allocOperands);
// Construct operands for the operation.
// errOp doesn't need to be casted to something generic, allocGeneric
// already is. All the rest will be converted if needed
mlir::SmallVector<mlir::Value, 4> newOperands{errOp,
allocGeneric.getResult(0)};
mlir::SmallVector<mlir::Value, 4> newOperands{allocGeneric.getResult(0)};
for (mlir::Value operand : op->getOperands()) {
mlir::Type operandType = operand.getType();
mlir::Type castedType = getGenericType(operandType);
@@ -420,16 +407,13 @@ struct ConcreteZeroOpPattern
mlir::Type resultType = op->getResultTypes().front();
auto lweResultType =
resultType.cast<mlir::concretelang::Concrete::LweCiphertextType>();
// Create the err value
auto errOp = rewriter.create<mlir::arith::ConstantOp>(
op.getLoc(), rewriter.getIndexAttr(0));
// Get the size from the dimension
int64_t lweDimension = lweResultType.getDimension();
int64_t lweSize = lweDimension + 1;
mlir::Value lweSizeOp = rewriter.create<mlir::arith::ConstantOp>(
op.getLoc(), rewriter.getIndexAttr(lweSize));
mlir::Value lweDimensionOp = rewriter.create<mlir::arith::ConstantOp>(
op.getLoc(), rewriter.getIndexAttr(lweDimension));
// Allocate a fresh new ciphertext
mlir::SmallVector<mlir::Value> allocOperands{errOp, lweSizeOp};
mlir::SmallVector<mlir::Value> allocOperands{lweDimensionOp};
auto allocGeneric = rewriter.create<mlir::CallOp>(
op.getLoc(), "allocate_lwe_ciphertext_u64",
getGenericLweCiphertextType(rewriter.getContext()), allocOperands);
@@ -506,7 +490,6 @@ struct GlweFromTableOpPattern
matchAndRewrite(mlir::concretelang::Concrete::GlweFromTable op,
mlir::PatternRewriter &rewriter) const override {
ConcreteToConcreteCAPITypeConverter typeConverter;
auto errType = mlir::IndexType::get(rewriter.getContext());
// TODO: move this to insertForwardDeclarations
// issue: can't define function with tensor<*xtype> that accept ranked
@@ -516,28 +499,22 @@ struct GlweFromTableOpPattern
{
auto funcType = mlir::FunctionType::get(
rewriter.getContext(),
{errType, op->getOperandTypes().front(), rewriter.getI64Type(),
rewriter.getI32Type()},
{op->getOperandTypes().front(), rewriter.getI32Type()},
{getGenericForeignPlaintextListType(rewriter.getContext())});
if (insertForwardDeclaration(
op, rewriter, "runtime_foreign_plaintext_list_u64", funcType)
if (insertForwardDeclaration(op, rewriter,
"memref_runtime_foreign_plaintext_list_u64",
funcType)
.failed()) {
return mlir::failure();
}
}
auto errOp = rewriter.create<mlir::arith::ConstantOp>(
op.getLoc(), rewriter.getIndexAttr(0));
// Get the size from the dimension
int64_t glweDimension =
op->getAttr("glweDimension").cast<mlir::IntegerAttr>().getInt();
int64_t glweSize = glweDimension + 1;
mlir::Value glweSizeOp = rewriter.create<mlir::arith::ConstantOp>(
op.getLoc(), rewriter.getI32IntegerAttr(glweSize));
// allocate two glwe to build accumulator
auto polySizeOp = rewriter.create<mlir::arith::ConstantOp>(
op.getLoc(), op->getAttr("polynomialSize"));
mlir::SmallVector<mlir::Value> allocGlweOperands{errOp, glweSizeOp,
auto glweDimensionOp = rewriter.create<mlir::arith::ConstantOp>(
op.getLoc(), op->getAttr("glweDimension"));
mlir::SmallVector<mlir::Value> allocGlweOperands{glweDimensionOp,
polySizeOp};
// first accumulator would replace the op since it's the returned value
auto accumulatorOp = rewriter.replaceOpWithNewOp<mlir::CallOp>(
@@ -548,8 +525,7 @@ struct GlweFromTableOpPattern
op.getLoc(), "allocate_glwe_ciphertext_u64",
getGenericGlweCiphertextType(rewriter.getContext()), allocGlweOperands);
// allocate plaintext list
mlir::SmallVector<mlir::Value> allocPlaintextListOperands{errOp,
polySizeOp};
mlir::SmallVector<mlir::Value> allocPlaintextListOperands{polySizeOp};
auto plaintextListOp = rewriter.create<mlir::CallOp>(
op.getLoc(), "allocate_plaintext_list_u64",
getGenericPlaintextListType(rewriter.getContext()),
@@ -559,27 +535,23 @@ struct GlweFromTableOpPattern
op->getOperandTypes().front().cast<mlir::RankedTensorType>();
assert(rankedTensorType.getRank() == 1 &&
"table lookup must be of a single dimension");
auto sizeOp = rewriter.create<mlir::arith::ConstantOp>(
op.getLoc(),
rewriter.getI64IntegerAttr(rankedTensorType.getDimSize(0)));
auto precisionOp =
rewriter.create<mlir::arith::ConstantOp>(op.getLoc(), op->getAttr("p"));
mlir::SmallVector<mlir::Value> ForeignPlaintextListOperands{
errOp, op->getOperand(0), sizeOp, precisionOp};
op->getOperand(0), precisionOp};
auto foreignPlaintextListOp = rewriter.create<mlir::CallOp>(
op.getLoc(), "runtime_foreign_plaintext_list_u64",
op.getLoc(), "memref_runtime_foreign_plaintext_list_u64",
getGenericForeignPlaintextListType(rewriter.getContext()),
ForeignPlaintextListOperands);
// fill plaintext list
mlir::SmallVector<mlir::Value> FillPlaintextListOperands{
errOp, plaintextListOp.getResult(0),
foreignPlaintextListOp.getResult(0)};
plaintextListOp.getResult(0), foreignPlaintextListOp.getResult(0)};
rewriter.create<mlir::CallOp>(
op.getLoc(), "fill_plaintext_list_with_expansion_u64",
mlir::TypeRange({}), FillPlaintextListOperands);
// add plaintext list and glwe to build final accumulator for pbs
mlir::SmallVector<mlir::Value> AddPlaintextListGlweOperands{
errOp, accumulatorOp.getResult(0), _accumulatorOp.getResult(0),
accumulatorOp.getResult(0), _accumulatorOp.getResult(0),
plaintextListOp.getResult(0)};
rewriter.create<mlir::CallOp>(
op.getLoc(), "add_plaintext_list_glwe_ciphertext_u64",
@@ -626,18 +598,15 @@ struct ConcreteBootstrapLweOpPattern
matchAndRewrite(mlir::concretelang::Concrete::BootstrapLweOp op,
mlir::PatternRewriter &rewriter) const override {
auto resultType = op->getResultTypes().front();
auto errOp = rewriter.create<mlir::arith::ConstantOp>(
op.getLoc(), rewriter.getIndexAttr(0));
// Get the size from the dimension
int64_t outputLweDimension =
resultType.cast<mlir::concretelang::Concrete::LweCiphertextType>()
.getDimension();
int64_t outputLweSize = outputLweDimension + 1;
mlir::Value lweSizeOp = rewriter.create<mlir::arith::ConstantOp>(
op.getLoc(), rewriter.getIndexAttr(outputLweSize));
op.getLoc(), rewriter.getIndexAttr(outputLweDimension));
// allocate the result lwe ciphertext, should be of a generic type, to cast
// before return
mlir::SmallVector<mlir::Value> allocLweCtOperands{errOp, lweSizeOp};
mlir::SmallVector<mlir::Value> allocLweCtOperands{lweSizeOp};
auto allocateGenericLweCtOp = rewriter.create<mlir::CallOp>(
op.getLoc(), "allocate_lwe_ciphertext_u64",
getGenericLweCiphertextType(rewriter.getContext()), allocLweCtOperands);
@@ -662,7 +631,7 @@ struct ConcreteBootstrapLweOpPattern
op.getOperand(1))
.getResult(0);
mlir::SmallVector<mlir::Value> bootstrapOperands{
errOp, getBskOp.getResult(0), allocateGenericLweCtOp.getResult(0),
getBskOp.getResult(0), allocateGenericLweCtOp.getResult(0),
lweToBootstrap, accumulator};
rewriter.create<mlir::CallOp>(op.getLoc(), "bootstrap_lwe_u64",
mlir::TypeRange({}), bootstrapOperands);
@@ -690,20 +659,17 @@ struct ConcreteKeySwitchLweOpPattern
mlir::LogicalResult
matchAndRewrite(mlir::concretelang::Concrete::KeySwitchLweOp op,
mlir::PatternRewriter &rewriter) const override {
auto errOp = rewriter.create<mlir::arith::ConstantOp>(
op.getLoc(), rewriter.getIndexAttr(0));
// Get the size from the dimension
int64_t lweDimension =
op.getResult()
.getType()
.cast<mlir::concretelang::Concrete::LweCiphertextType>()
.getDimension();
int64_t lweSize = lweDimension + 1;
mlir::Value lweSizeOp = rewriter.create<mlir::arith::ConstantOp>(
op.getLoc(), rewriter.getIndexAttr(lweSize));
mlir::Value lweDimensionOp = rewriter.create<mlir::arith::ConstantOp>(
op.getLoc(), rewriter.getIndexAttr(lweDimension));
// allocate the result lwe ciphertext, should be of a generic type, to cast
// before return
mlir::SmallVector<mlir::Value> allocLweCtOperands{errOp, lweSizeOp};
mlir::SmallVector<mlir::Value> allocLweCtOperands{lweDimensionOp};
auto allocateGenericLweCtOp = rewriter.create<mlir::CallOp>(
op.getLoc(), "allocate_lwe_ciphertext_u64",
getGenericLweCiphertextType(rewriter.getContext()), allocLweCtOperands);
@@ -721,7 +687,7 @@ struct ConcreteKeySwitchLweOpPattern
op.getOperand())
.getResult(0);
mlir::SmallVector<mlir::Value> keyswitchOperands{
errOp, getKskOp.getResult(0), allocateGenericLweCtOp.getResult(0),
getKskOp.getResult(0), allocateGenericLweCtOp.getResult(0),
lweToKeyswitch};
rewriter.create<mlir::CallOp>(op.getLoc(), "keyswitch_lwe_u64",
mlir::TypeRange({}), keyswitchOperands);

View File

@@ -14,9 +14,9 @@ if(CONCRETELANG_PARALLEL_EXECUTION_ENABLED)
install(TARGETS DFRuntime EXPORT DFRuntime)
install(EXPORT DFRuntime DESTINATION "./")
target_link_libraries(ConcretelangRuntime Concrete pthread m dl HPX::hpx)
target_link_libraries(ConcretelangRuntime Concrete pthread m dl HPX::hpx $<TARGET_OBJECTS:mlir_c_runner_utils>)
else()
target_link_libraries(ConcretelangRuntime Concrete pthread m dl)
target_link_libraries(ConcretelangRuntime Concrete pthread m dl $<TARGET_OBJECTS:mlir_c_runner_utils>)
endif()
install(TARGETS ConcretelangRuntime EXPORT ConcretelangRuntime)

View File

@@ -19,7 +19,6 @@ get_keyswitch_key(mlir::concretelang::RuntimeContext *context) {
LweBootstrapKey_u64 *
get_bootstrap_key(mlir::concretelang::RuntimeContext *context) {
#ifdef CONCRETELANG_PARALLEL_EXECUTION_ENABLED
int err;
std::string threadName = hpx::get_thread_name();
auto bskIt = context->bsk.find(threadName);
if (bskIt == context->bsk.end()) {
@@ -27,10 +26,8 @@ get_bootstrap_key(mlir::concretelang::RuntimeContext *context) {
.insert(std::pair<std::string, LweBootstrapKey_u64 *>(
threadName,
clone_lwe_bootstrap_key_u64(
&err, context->bsk["_concretelang_base_context_bsk"])))
context->bsk["_concretelang_base_context_bsk"])))
.first;
if (err != 0)
fprintf(stderr, "Runtime: cloning bootstrap key failed.\n");
}
#else
std::string baseName = "_concretelang_base_context_bsk";

View File

@@ -1,20 +1,89 @@
#include "concretelang/Runtime/wrappers.h"
#include <assert.h>
#include <stdio.h>
ForeignPlaintextList_u64 *
runtime_foreign_plaintext_list_u64(int *err, uint64_t *allocated,
uint64_t *aligned, uint64_t offset,
uint64_t size_dim0, uint64_t stride_dim0,
uint64_t size, uint32_t precision) {
if (stride_dim0 != 1) {
fprintf(stderr, "Runtime: stride not equal to 1, check "
"runtime_foreign_plaintext_list_u64");
}
struct ForeignPlaintextList_u64 *memref_runtime_foreign_plaintext_list_u64(
uint64_t *allocated, uint64_t *aligned, uint64_t offset, uint64_t size,
uint64_t stride, uint32_t precision) {
assert(stride == 1 && "Runtime: stride not equal to 1, check "
"runtime_foreign_plaintext_list_u64");
// Encode table values in u64
uint64_t *encoded_table = malloc(size * sizeof(uint64_t));
for (uint64_t i = 0; i < size; i++) {
encoded_table[i] = (aligned + offset)[i] << (64 - precision - 1);
}
return foreign_plaintext_list_u64(err, encoded_table, size);
return foreign_plaintext_list_u64(encoded_table, size);
// TODO: is it safe to free after creating plaintext_list?
}
void memref_add_lwe_ciphertexts_u64(
uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
uint64_t out_size, uint64_t out_stride, uint64_t *ct0_allocated,
uint64_t *ct0_aligned, uint64_t ct0_offset, uint64_t ct0_size,
uint64_t ct0_stride, uint64_t *ct1_allocated, uint64_t *ct1_aligned,
uint64_t ct1_offset, uint64_t ct1_size, uint64_t ct1_stride) {
assert(out_size == ct0_size && out_size == ct1_size &&
"size of lwe buffer are incompatible");
LweDimension lwe_dimension = {out_size - 1};
add_two_lwe_ciphertexts_u64(out_aligned + out_offset,
ct0_aligned + ct0_offset,
ct1_aligned + ct1_offset, lwe_dimension);
}
void memref_add_plaintext_lwe_ciphertext_u64(
uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
uint64_t out_size, uint64_t out_stride, uint64_t *ct0_allocated,
uint64_t *ct0_aligned, uint64_t ct0_offset, uint64_t ct0_size,
uint64_t ct0_stride, uint64_t plaintext) {
assert(out_size == ct0_size && "size of lwe buffer are incompatible");
LweDimension lwe_dimension = {out_size - 1};
add_plaintext_to_lwe_ciphertext_u64(out_aligned + out_offset,
ct0_aligned + ct0_offset, plaintext,
lwe_dimension);
}
void memref_mul_cleartext_lwe_ciphertext_u64(
uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
uint64_t out_size, uint64_t out_stride, uint64_t *ct0_allocated,
uint64_t *ct0_aligned, uint64_t ct0_offset, uint64_t ct0_size,
uint64_t ct0_stride, uint64_t cleartext) {
assert(out_size == ct0_size && "size of lwe buffer are incompatible");
LweDimension lwe_dimension = {out_size - 1};
mul_cleartext_lwe_ciphertext_u64(out_aligned + out_offset,
ct0_aligned + ct0_offset, cleartext,
lwe_dimension);
}
void memref_negate_lwe_ciphertext_u64(
uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
uint64_t out_size, uint64_t out_stride, uint64_t *ct0_allocated,
uint64_t *ct0_aligned, uint64_t ct0_offset, uint64_t ct0_size,
uint64_t ct0_stride) {
assert(out_size == ct0_size && "size of lwe buffer are incompatible");
LweDimension lwe_dimension = {out_size - 1};
neg_lwe_ciphertext_u64(out_aligned + out_offset, ct0_aligned + ct0_offset,
lwe_dimension);
}
void memref_keyswitch_lwe_u64(struct LweKeyswitchKey_u64 *keyswitch_key,
uint64_t *out_allocated, uint64_t *out_aligned,
uint64_t out_offset, uint64_t out_size,
uint64_t out_stride, uint64_t *ct0_allocated,
uint64_t *ct0_aligned, uint64_t ct0_offset,
uint64_t ct0_size, uint64_t ct0_stride) {
bufferized_keyswitch_lwe_u64(keyswitch_key, out_aligned + out_offset,
ct0_aligned + ct0_offset);
}
void memref_bootstrap_lwe_u64(struct LweBootstrapKey_u64 *bootstrap_key,
uint64_t *out_allocated, uint64_t *out_aligned,
uint64_t out_offset, uint64_t out_size,
uint64_t out_stride, uint64_t *ct0_allocated,
uint64_t *ct0_aligned, uint64_t ct0_offset,
uint64_t ct0_size, uint64_t ct0_stride,
struct GlweCiphertext_u64 *accumulator) {
bufferized_bootstrap_lwe_u64(bootstrap_key, out_aligned + out_offset,
ct0_aligned + ct0_offset, accumulator);
}

View File

@@ -78,6 +78,12 @@ llvm::Error JITLambda::invoke(Argument &args) {
<< actualInputs << "arguments instead of " << expectedInputs;
}
// memref is a struct which is flattened aligned, allocated pointers, offset,
// and two array of rank size for sizes and strides.
uint64_t numArgOfRankedMemrefCallingConvention(uint64_t rank) {
return 3 + 2 * rank;
}
JITLambda::Argument::Argument(KeySet &keySet) : keySet(keySet) {
// Setting the inputs
auto numInputs = 0;
@@ -86,16 +92,20 @@ JITLambda::Argument::Argument(KeySet &keySet) : keySet(keySet) {
auto offset = numInputs;
auto gate = keySet.inputGate(i);
inputGates.push_back({gate, offset});
if (keySet.inputGate(i).shape.dimensions.empty()) {
if (gate.shape.dimensions.empty()) {
// scalar gate
numInputs = numInputs + 1;
if (gate.encryption.hasValue()) {
// encrypted is a memref<lweSizexi64>
numInputs = numInputs + numArgOfRankedMemrefCallingConvention(1);
} else {
numInputs = numInputs + 1;
}
continue;
}
// memref gate, as we follow the standard calling convention
numInputs = numInputs + 3;
// Offsets and strides are array of size N where N is the number of
// dimension of the tensor.
numInputs = numInputs + 2 * keySet.inputGate(i).shape.dimensions.size();
auto rank = keySet.inputGate(i).shape.dimensions.size() +
(keySet.isInputEncrypted(i) ? 1 /* for lwe size */ : 0);
numInputs = numInputs + numArgOfRankedMemrefCallingConvention(rank);
}
// Reserve for the context argument
numInputs = numInputs + 1;
@@ -111,19 +121,21 @@ JITLambda::Argument::Argument(KeySet &keySet) : keySet(keySet) {
outputGates.push_back({gate, offset});
if (gate.shape.dimensions.empty()) {
// scalar gate
numOutputs = numOutputs + 1;
if (gate.encryption.hasValue()) {
// encrypted is a memref<lweSizexi64>
numOutputs = numOutputs + numArgOfRankedMemrefCallingConvention(1);
} else {
numOutputs = numOutputs + 1;
}
continue;
}
// memref gate, as we follow the standard calling convention
numOutputs = numOutputs + 3;
// Offsets and strides are array of size N where N is the number of
// dimension of the tensor.
numOutputs =
numOutputs + 2 * keySet.outputGate(i).shape.dimensions.size();
auto rank = keySet.outputGate(i).shape.dimensions.size() +
(keySet.isOutputEncrypted(i) ? 1 /* for lwe size */ : 0);
numOutputs = numOutputs + numArgOfRankedMemrefCallingConvention(rank);
}
outputs = std::vector<void *>(numOutputs);
}
// The raw argument contains pointers to inputs and pointers to store the
// results
rawArg = std::vector<void *>(inputs.size() + outputs.size(), nullptr);
@@ -139,9 +151,8 @@ JITLambda::Argument::Argument(KeySet &keySet) : keySet(keySet) {
}
JITLambda::Argument::~Argument() {
int err;
for (auto ct : allocatedCiphertexts) {
free_lwe_ciphertext_u64(&err, ct);
free(ct);
}
for (auto buffer : ciphertextBuffers) {
free(buffer);
@@ -185,16 +196,31 @@ llvm::Error JITLambda::Argument::setArg(size_t pos, uint64_t arg) {
return llvm::Error::success();
}
// Else if is encryted, allocate ciphertext and encrypt.
LweCiphertext_u64 *ctArg;
if (auto err = this->keySet.allocate_lwe(pos, &ctArg)) {
uint64_t *ctArg;
uint64_t ctSize;
if (auto err = this->keySet.allocate_lwe(pos, &ctArg, ctSize)) {
return std::move(err);
}
allocatedCiphertexts.push_back(ctArg);
if (auto err = this->keySet.encrypt_lwe(pos, ctArg, arg)) {
return std::move(err);
}
inputs[offset] = ctArg;
// memref calling convention
// allocated
inputs[offset] = nullptr;
// aligned
inputs[offset + 1] = ctArg;
// offset
inputs[offset + 2] = (void *)0;
// size
inputs[offset + 3] = (void *)ctSize;
// stride
inputs[offset + 4] = (void *)1;
rawArg[offset] = &inputs[offset];
rawArg[offset + 1] = &inputs[offset + 1];
rawArg[offset + 2] = &inputs[offset + 2];
rawArg[offset + 3] = &inputs[offset + 3];
rawArg[offset + 4] = &inputs[offset + 4];
return llvm::Error::success();
}
@@ -279,17 +305,18 @@ llvm::Error JITLambda::Argument::setArg(size_t pos, size_t width,
llvm::inconvertibleErrorCode());
}
// Allocate a buffer for ciphertexts.
auto ctBuffer = (LweCiphertext_u64 **)malloc(info.shape.size *
sizeof(LweCiphertext_u64 *));
// Allocate a buffer for ciphertexts, the size of the buffer is the number
// of elements of the tensor * the size of the lwe ciphertext
auto lweSize = keySet.getInputLweSecretKeyParam(pos).size + 1;
uint64_t *ctBuffer =
(uint64_t *)malloc(info.shape.size * lweSize * sizeof(uint64_t));
ciphertextBuffers.push_back(ctBuffer);
// Allocate ciphertexts and encrypt
for (size_t i = 0; i < info.shape.size; i++) {
if (auto err = this->keySet.allocate_lwe(pos, &ctBuffer[i])) {
return std::move(err);
}
allocatedCiphertexts.push_back(ctBuffer[i]);
if (auto err = this->keySet.encrypt_lwe(pos, ctBuffer[i], data8[i])) {
// Encrypt ciphertexts
for (size_t i = 0, offset = 0; i < info.shape.size;
i++, offset += lweSize) {
if (auto err =
this->keySet.encrypt_lwe(pos, ctBuffer + offset, data8[i])) {
return std::move(err);
}
}
@@ -316,17 +343,27 @@ llvm::Error JITLambda::Argument::setArg(size_t pos, size_t width,
rawArg[offset] = &inputs[offset];
offset++;
}
// If encrypted +1 for the lwe size rank
if (keySet.isInputEncrypted(pos)) {
inputs[offset] = (void *)(keySet.getInputLweSecretKeyParam(pos).size + 1);
rawArg[offset] = &inputs[offset];
offset++;
}
// Set the stride for each dimension, equal to the product of the
// following dimensions.
int64_t stride = 1;
// If encrypted +1 set the stride for the lwe size rank
if (keySet.isInputEncrypted(pos)) {
inputs[offset + shape.size()] = (void *)stride;
rawArg[offset + shape.size()] = &inputs[offset];
stride *= keySet.getInputLweSecretKeyParam(pos).size + 1;
}
for (ssize_t i = shape.size() - 1; i >= 0; i--) {
inputs[offset + i] = (void *)stride;
rawArg[offset + i] = &inputs[offset + i];
stride *= shape[i];
}
offset += shape.size();
return llvm::Error::success();
@@ -349,7 +386,7 @@ llvm::Error JITLambda::Argument::getResult(size_t pos, uint64_t &res) {
return llvm::Error::success();
}
// Else if is encryted, decrypt
LweCiphertext_u64 *ct = (LweCiphertext_u64 *)(outputs[offset]);
uint64_t *ct = (uint64_t *)(outputs[offset + 1]);
if (auto err = this->keySet.decrypt_lwe(pos, ct, res)) {
return std::move(err);
}
@@ -463,8 +500,10 @@ llvm::Error JITLambda::Argument::getResult(size_t pos, void *res,
}
} else {
// decrypt and fill the result buffer
for (size_t i = 0; i < numElements; i++) {
LweCiphertext_u64 *ct = ((LweCiphertext_u64 **)alignedBytes)[i];
auto lweSize = keySet.getOutputLweSecretKeyParam(pos).size + 1;
for (size_t i = 0, o = 0; i < numElements; i++, o += lweSize) {
uint64_t *ct = ((uint64_t *)alignedBytes) + o;
if (auto err = this->keySet.decrypt_lwe(pos, ct, ((uint64_t *)res)[i])) {
return std::move(err);
}

View File

@@ -58,7 +58,6 @@ JitCompilerEngine::buildLambda(llvm::StringRef s, llvm::StringRef funcName,
std::unique_ptr<llvm::MemoryBuffer> mb = llvm::MemoryBuffer::getMemBuffer(s);
llvm::Expected<JitCompilerEngine::Lambda> res =
this->buildLambda(std::move(mb), funcName, cache, runtimeLibPath);
return std::move(res);
}