diff --git a/compiler/README.md b/compiler/README.md index 68c3d4d9e..e24aecffb 100644 --- a/compiler/README.md +++ b/compiler/README.md @@ -12,9 +12,9 @@ pip install pybind11 Build concrete library: ```sh -git clone https://github.com/zama-ai/concrete -cd concrete -git checkout feature/core_c_api +git clone https://github.com/zama-ai/concrete_internal +cd concrete_internal +git checkout compiler_c_api cd concrete-ffi RUSTFLAGS="-C target-cpu=native" cargo build --release ``` @@ -23,7 +23,7 @@ Generate the compiler build system, in the `build` directory ```sh export LLVM_PROJECT="PATH_TO_LLVM_PROJECT" -export CONCRETE_PROJECT="PATH_TO_CONCRETE_PROJECT" +export CONCRETE_PROJECT="PATH_TO_CONCRETE_INTERNAL_PROJECT" make build-initialized ``` diff --git a/compiler/include/concretelang/ClientLib/KeySet.h b/compiler/include/concretelang/ClientLib/KeySet.h index d45aca4e3..354a7c749 100644 --- a/compiler/include/concretelang/ClientLib/KeySet.h +++ b/compiler/include/concretelang/ClientLib/KeySet.h @@ -37,16 +37,36 @@ public: // isInputEncrypted return true if the input at the given pos is encrypted. bool isInputEncrypted(size_t pos); - // allocate a lwe ciphertext for the argument at argPos. - llvm::Error allocate_lwe(size_t argPos, LweCiphertext_u64 **ciphertext); + + // getInputLweSecretKeyParam returns the parameters of the lwe secret key for + // the input at the given `pos`. + // The input must be encrupted + LweSecretKeyParam getInputLweSecretKeyParam(size_t pos) { + auto gate = inputGate(pos); + auto inputSk = this->secretKeys.find(gate.encryption->secretKeyID); + return inputSk->second.first; + } + + // getOutputLweSecretKeyParam returns the parameters of the lwe secret key for + // the given output. + LweSecretKeyParam getOutputLweSecretKeyParam(size_t pos) { + auto gate = outputGate(pos); + auto outputSk = this->secretKeys.find(gate.encryption->secretKeyID); + return outputSk->second.first; + } + + // allocate a lwe ciphertext buffer for the argument at argPos, set the size + // of the allocated buffer. + llvm::Error allocate_lwe(size_t argPos, uint64_t **ciphertext, + uint64_t &size); + // encrypt the input to the ciphertext for the argument at argPos. - llvm::Error encrypt_lwe(size_t argPos, LweCiphertext_u64 *ciphertext, - uint64_t input); + llvm::Error encrypt_lwe(size_t argPos, uint64_t *ciphertext, uint64_t input); // isOuputEncrypted return true if the output at the given pos is encrypted. bool isOutputEncrypted(size_t pos); // decrypt the ciphertext to the output for the argument at argPos. - llvm::Error decrypt_lwe(size_t argPos, LweCiphertext_u64 *ciphertext, + llvm::Error decrypt_lwe(size_t argPos, uint64_t *ciphertext, uint64_t &output); size_t numInputs() { return inputs.size(); } diff --git a/compiler/include/concretelang/Runtime/context.h b/compiler/include/concretelang/Runtime/context.h index 24a6a1391..0cc9410fc 100644 --- a/compiler/include/concretelang/Runtime/context.h +++ b/compiler/include/concretelang/Runtime/context.h @@ -20,10 +20,9 @@ typedef struct RuntimeContext { std::map bsk; ~RuntimeContext() { - int err; for (const auto &key : bsk) { if (key.first != "_concretelang_base_context_bsk") - free_lwe_bootstrap_key_u64(&err, key.second); + free_lwe_bootstrap_key_u64(key.second); } } } RuntimeContext; diff --git a/compiler/include/concretelang/Runtime/wrappers.h b/compiler/include/concretelang/Runtime/wrappers.h index d85be42b0..28907f13f 100644 --- a/compiler/include/concretelang/Runtime/wrappers.h +++ b/compiler/include/concretelang/Runtime/wrappers.h @@ -8,10 +8,48 @@ #include "concrete-ffi.h" -ForeignPlaintextList_u64 * -runtime_foreign_plaintext_list_u64(int *err, uint64_t *allocated, - uint64_t *aligned, uint64_t offset, - uint64_t size_dim0, uint64_t stride_dim0, - uint64_t size, uint32_t precision); +struct ForeignPlaintextList_u64 *memref_runtime_foreign_plaintext_list_u64( + uint64_t *allocated, uint64_t *aligned, uint64_t offset, uint64_t size, + uint64_t stride, uint32_t precision); + +void memref_add_lwe_ciphertexts_u64( + uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset, + uint64_t out_size, uint64_t out_stride, uint64_t *ct0_allocated, + uint64_t *ct0_aligned, uint64_t ct0_offset, uint64_t ct0_size, + uint64_t ct0_stride, uint64_t *ct1_allocated, uint64_t *ct1_aligned, + uint64_t ct1_offset, uint64_t ct1_size, uint64_t ct1_stride); + +void memref_add_plaintext_lwe_ciphertext_u64( + uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset, + uint64_t out_size, uint64_t out_stride, uint64_t *ct0_allocated, + uint64_t *ct0_aligned, uint64_t ct0_offset, uint64_t ct0_size, + uint64_t ct0_stride, uint64_t plaintext); + +void memref_mul_cleartext_lwe_ciphertext_u64( + uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset, + uint64_t out_size, uint64_t out_stride, uint64_t *ct0_allocated, + uint64_t *ct0_aligned, uint64_t ct0_offset, uint64_t ct0_size, + uint64_t ct0_stride, uint64_t cleartext); + +void memref_negate_lwe_ciphertext_u64( + uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset, + uint64_t out_size, uint64_t out_stride, uint64_t *ct0_allocated, + uint64_t *ct0_aligned, uint64_t ct0_offset, uint64_t ct0_size, + uint64_t ct0_stride); + +void memref_keyswitch_lwe_u64(struct LweKeyswitchKey_u64 *keyswitch_key, + uint64_t *out_allocated, uint64_t *out_aligned, + uint64_t out_offset, uint64_t out_size, + uint64_t out_stride, uint64_t *ct0_allocated, + uint64_t *ct0_aligned, uint64_t ct0_offset, + uint64_t ct0_size, uint64_t ct0_stride); + +void memref_bootstrap_lwe_u64(struct LweBootstrapKey_u64 *bootstrap_key, + uint64_t *out_allocated, uint64_t *out_aligned, + uint64_t out_offset, uint64_t out_size, + uint64_t out_stride, uint64_t *ct0_allocated, + uint64_t *ct0_aligned, uint64_t ct0_offset, + uint64_t ct0_size, uint64_t ct0_stride, + struct GlweCiphertext_u64 *accumulator); #endif diff --git a/compiler/include/concretelang/Support/Jit.h b/compiler/include/concretelang/Support/Jit.h index d6d889078..5bff085b3 100644 --- a/compiler/include/concretelang/Support/Jit.h +++ b/compiler/include/concretelang/Support/Jit.h @@ -96,13 +96,13 @@ public: // Store the values of outputs std::vector outputs; // Store the input gates description and the offset of the argument. - std::vector> inputGates; + std::vector> inputGates; // Store the outputs gates description and the offset of the argument. - std::vector> outputGates; + std::vector> outputGates; // Store allocated lwe ciphertexts (for free) - std::vector allocatedCiphertexts; + std::vector allocatedCiphertexts; // Store buffers of ciphertexts - std::vector ciphertextBuffers; + std::vector ciphertextBuffers; KeySet &keySet; RuntimeContext context; diff --git a/compiler/lib/ClientLib/KeySet.cpp b/compiler/lib/ClientLib/KeySet.cpp index 1b11f6ea7..fd942d4f7 100644 --- a/compiler/lib/ClientLib/KeySet.cpp +++ b/compiler/lib/ClientLib/KeySet.cpp @@ -6,31 +6,20 @@ #include "concretelang/ClientLib/KeySet.h" #include "concretelang/Support/Error.h" -#define CAPI_ERR_TO_LLVM_ERROR(s, msg) \ - { \ - int err; \ - s; \ - if (err != 0) { \ - return llvm::make_error( \ - msg, llvm::inconvertibleErrorCode()); \ - } \ - } - namespace mlir { namespace concretelang { KeySet::~KeySet() { - int err; for (auto it : secretKeys) { - free_lwe_secret_key_u64(&err, it.second.second); + free_lwe_secret_key_u64(it.second.second); } for (auto it : bootstrapKeys) { - free_lwe_bootstrap_key_u64(&err, it.second.second); + free_lwe_bootstrap_key_u64(it.second.second); } for (auto it : keyswitchKeys) { - free_lwe_keyswitch_key_u64(&err, it.second.second); + free_lwe_keyswitch_key_u64(it.second.second); } - free_encryption_generator(&err, encryptionRandomGenerator); + free_encryption_generator(encryptionRandomGenerator); } llvm::Expected> @@ -97,10 +86,8 @@ llvm::Error KeySet::setupEncryptionMaterial(ClientParameters ¶ms, } } - CAPI_ERR_TO_LLVM_ERROR( - this->encryptionRandomGenerator = - allocate_encryption_generator(&err, seed_msb, seed_lsb), - "cannot allocate encryption generator"); + this->encryptionRandomGenerator = + allocate_encryption_generator(seed_msb, seed_lsb); return llvm::Error::success(); } @@ -108,13 +95,11 @@ llvm::Error KeySet::setupEncryptionMaterial(ClientParameters ¶ms, llvm::Error KeySet::generateKeysFromParams(ClientParameters ¶ms, uint64_t seed_msb, uint64_t seed_lsb) { - { // Generate LWE secret keys SecretRandomGenerator *generator; - CAPI_ERR_TO_LLVM_ERROR( - generator = allocate_secret_generator(&err, seed_msb, seed_lsb), - "cannot allocate random generator"); + + generator = allocate_secret_generator(seed_msb, seed_lsb); for (auto secretKeyParam : params.secretKeys) { auto e = this->generateSecretKey(secretKeyParam.first, secretKeyParam.second, generator); @@ -122,14 +107,12 @@ llvm::Error KeySet::generateKeysFromParams(ClientParameters ¶ms, return std::move(e); } } - CAPI_ERR_TO_LLVM_ERROR(free_secret_generator(&err, generator), - "cannot free random generator"); + free_secret_generator(generator); } // Allocate the encryption random generator - CAPI_ERR_TO_LLVM_ERROR( - this->encryptionRandomGenerator = - allocate_encryption_generator(&err, seed_msb, seed_lsb), - "cannot allocate encryption generator"); + + this->encryptionRandomGenerator = + allocate_encryption_generator(seed_msb, seed_lsb); // Generate bootstrap and keyswitch keys { for (auto bootstrapKeyParam : params.bootstrapKeys) { @@ -170,12 +153,9 @@ llvm::Error KeySet::generateSecretKey(LweSecretKeyID id, LweSecretKeyParam param, SecretRandomGenerator *generator) { LweSecretKey_u64 *sk; - CAPI_ERR_TO_LLVM_ERROR( - sk = allocate_lwe_secret_key_u64(&err, {param.size + 1}), - "cannot allocate secret key"); + sk = allocate_lwe_secret_key_u64({param.size}); - CAPI_ERR_TO_LLVM_ERROR(fill_lwe_secret_key_u64(&err, sk, generator), - "cannot fill secret key with random generator"); + fill_lwe_secret_key_u64(sk, generator); secretKeys[id] = {param, sk}; @@ -207,11 +187,9 @@ llvm::Error KeySet::generateBootstrapKey(BootstrapKeyID id, uint64_t polynomialSize = total_dimension / param.glweDimension; - CAPI_ERR_TO_LLVM_ERROR( - bsk = allocate_lwe_bootstrap_key_u64( - &err, {param.level}, {param.baseLog}, {param.glweDimension + 1}, - {inputSk->second.first.size + 1}, {polynomialSize}), - "cannot allocate bootstrap key"); + bsk = allocate_lwe_bootstrap_key_u64( + {param.level}, {param.baseLog}, {param.glweDimension}, + {inputSk->second.first.size}, {polynomialSize}); // Store the bootstrap key bootstrapKeys[id] = {param, bsk}; @@ -219,23 +197,16 @@ llvm::Error KeySet::generateBootstrapKey(BootstrapKeyID id, // Convert the output lwe key to glwe key GlweSecretKey_u64 *glwe_sk; - CAPI_ERR_TO_LLVM_ERROR( - glwe_sk = allocate_glwe_secret_key_u64(&err, {param.glweDimension + 1}, - {polynomialSize}), - "cannot allocate glwe key for initiliazation of bootstrap key"); + glwe_sk = + allocate_glwe_secret_key_u64({param.glweDimension}, {polynomialSize}); - CAPI_ERR_TO_LLVM_ERROR(fill_glwe_secret_key_with_lwe_secret_key_u64( - &err, glwe_sk, outputSk->second.second), - "cannot fill glwe key with big key"); + fill_glwe_secret_key_with_lwe_secret_key_u64(glwe_sk, + outputSk->second.second); // Initialize the bootstrap key - CAPI_ERR_TO_LLVM_ERROR( - fill_lwe_bootstrap_key_u64(&err, bsk, inputSk->second.second, glwe_sk, - generator, {param.variance}), - "cannot fill bootstrap key"); - CAPI_ERR_TO_LLVM_ERROR( - free_glwe_secret_key_u64(&err, glwe_sk), - "cannot free glwe key for initiliazation of bootstrap key") + fill_lwe_bootstrap_key_u64(bsk, inputSk->second.second, glwe_sk, generator, + {param.variance}); + free_glwe_secret_key_u64(glwe_sk); return llvm::Error::success(); } @@ -257,33 +228,32 @@ llvm::Error KeySet::generateKeyswitchKey(KeyswitchKeyID id, } // Allocate the keyswitch key LweKeyswitchKey_u64 *ksk; - CAPI_ERR_TO_LLVM_ERROR( - ksk = allocate_lwe_keyswitch_key_u64(&err, {param.level}, {param.baseLog}, - {inputSk->second.first.size + 1}, - {outputSk->second.first.size + 1}), - "cannot allocate keyswitch key"); + + ksk = allocate_lwe_keyswitch_key_u64({param.level}, {param.baseLog}, + {inputSk->second.first.size}, + {outputSk->second.first.size}); + // Store the keyswitch key keyswitchKeys[id] = {param, ksk}; // Initialize the keyswitch key - CAPI_ERR_TO_LLVM_ERROR( - fill_lwe_keyswitch_key_u64(&err, ksk, inputSk->second.second, - outputSk->second.second, generator, - {param.variance}), - "cannot fill bootsrap key"); + + fill_lwe_keyswitch_key_u64(ksk, inputSk->second.second, + outputSk->second.second, generator, + {param.variance}); return llvm::Error::success(); } -llvm::Error KeySet::allocate_lwe(size_t argPos, - LweCiphertext_u64 **ciphertext) { +llvm::Error KeySet::allocate_lwe(size_t argPos, uint64_t **ciphertext, + uint64_t &size) { if (argPos >= inputs.size()) { return llvm::make_error( "allocate_lwe position of argument is too high", llvm::inconvertibleErrorCode()); } auto inputSk = inputs[argPos]; - CAPI_ERR_TO_LLVM_ERROR(*ciphertext = allocate_lwe_ciphertext_u64( - &err, {std::get<1>(inputSk).size + 1}), - "cannot allocate ciphertext"); + + size = std::get<1>(inputSk).size + 1; + *ciphertext = (uint64_t *)malloc(sizeof(uint64_t) * size); return llvm::Error::success(); } @@ -297,7 +267,7 @@ bool KeySet::isOutputEncrypted(size_t argPos) { std::get<0>(outputs[argPos]).encryption.hasValue(); } -llvm::Error KeySet::encrypt_lwe(size_t argPos, LweCiphertext_u64 *ciphertext, +llvm::Error KeySet::encrypt_lwe(size_t argPos, uint64_t *ciphertext, uint64_t input) { if (argPos >= inputs.size()) { return llvm::make_error( @@ -311,19 +281,15 @@ llvm::Error KeySet::encrypt_lwe(size_t argPos, LweCiphertext_u64 *ciphertext, llvm::inconvertibleErrorCode()); } // Encode - TODO we could check if the input value is in the right range - Plaintext_u64 plaintext = { - input << (64 - - (std::get<0>(inputSk).encryption->encoding.precision + 1))}; - // Encrypt - CAPI_ERR_TO_LLVM_ERROR( - encrypt_lwe_u64(&err, std::get<2>(inputSk), ciphertext, plaintext, - encryptionRandomGenerator, - {std::get<0>(inputSk).encryption->variance}), - "cannot encrypt"); + uint64_t plaintext = + input << (64 - (std::get<0>(inputSk).encryption->encoding.precision + 1)); + encrypt_lwe_u64(std::get<2>(inputSk), ciphertext, plaintext, + encryptionRandomGenerator, + {std::get<0>(inputSk).encryption->variance}); return llvm::Error::success(); } -llvm::Error KeySet::decrypt_lwe(size_t argPos, LweCiphertext_u64 *ciphertext, +llvm::Error KeySet::decrypt_lwe(size_t argPos, uint64_t *ciphertext, uint64_t &output) { if (argPos >= outputs.size()) { @@ -337,14 +303,10 @@ llvm::Error KeySet::decrypt_lwe(size_t argPos, LweCiphertext_u64 *ciphertext, "decrypt_lwe: the positional argument is not encrypted", llvm::inconvertibleErrorCode()); } - // Decrypt - Plaintext_u64 plaintext = {0}; - CAPI_ERR_TO_LLVM_ERROR( - decrypt_lwe_u64(&err, std::get<2>(outputSk), ciphertext, &plaintext), - "cannot decrypt"); + uint64_t plaintext = decrypt_lwe_u64(std::get<2>(outputSk), ciphertext); // Decode size_t precision = std::get<0>(outputSk).encryption->encoding.precision; - output = plaintext._0 >> (64 - precision - 2); + output = plaintext >> (64 - precision - 2); size_t carry = output % 2; output = ((output >> 1) + carry) % (1 << (precision + 1)); return llvm::Error::success(); diff --git a/compiler/lib/Conversion/ConcreteToConcreteCAPI/ConcreteToConcreteCAPI.cpp b/compiler/lib/Conversion/ConcreteToConcreteCAPI/ConcreteToConcreteCAPI.cpp index 1e163d541..ba242bcf1 100644 --- a/compiler/lib/Conversion/ConcreteToConcreteCAPI/ConcreteToConcreteCAPI.cpp +++ b/compiler/lib/Conversion/ConcreteToConcreteCAPI/ConcreteToConcreteCAPI.cpp @@ -142,13 +142,10 @@ mlir::LogicalResult insertForwardDeclarations(mlir::Operation *op, auto contextType = mlir::concretelang::Concrete::ContextType::get(rewriter.getContext()); - auto errType = mlir::IndexType::get(rewriter.getContext()); - // Insert forward declaration of allocate lwe ciphertext { auto funcType = mlir::FunctionType::get(rewriter.getContext(), { - errType, rewriter.getIndexType(), }, @@ -163,7 +160,6 @@ mlir::LogicalResult insertForwardDeclarations(mlir::Operation *op, { auto funcType = mlir::FunctionType::get(rewriter.getContext(), { - errType, genericLweCiphertextType, genericLweCiphertextType, genericLweCiphertextType, @@ -179,7 +175,6 @@ mlir::LogicalResult insertForwardDeclarations(mlir::Operation *op, { auto funcType = mlir::FunctionType::get(rewriter.getContext(), { - errType, genericLweCiphertextType, genericLweCiphertextType, genericPlaintextType, @@ -195,7 +190,6 @@ mlir::LogicalResult insertForwardDeclarations(mlir::Operation *op, { auto funcType = mlir::FunctionType::get(rewriter.getContext(), { - errType, genericLweCiphertextType, genericLweCiphertextType, genericCleartextType, @@ -211,7 +205,7 @@ mlir::LogicalResult insertForwardDeclarations(mlir::Operation *op, { auto funcType = mlir::FunctionType::get( rewriter.getContext(), - {errType, genericLweCiphertextType, genericLweCiphertextType}, {}); + {genericLweCiphertextType, genericLweCiphertextType}, {}); if (insertForwardDeclaration(op, rewriter, "negate_lwe_ciphertext_u64", funcType) .failed()) { @@ -231,7 +225,6 @@ mlir::LogicalResult insertForwardDeclarations(mlir::Operation *op, { auto funcType = mlir::FunctionType::get(rewriter.getContext(), { - errType, genericBSKType, genericLweCiphertextType, genericLweCiphertextType, @@ -256,7 +249,6 @@ mlir::LogicalResult insertForwardDeclarations(mlir::Operation *op, { auto funcType = mlir::FunctionType::get(rewriter.getContext(), { - errType, // ksk genericKSKType, // output ct @@ -274,7 +266,6 @@ mlir::LogicalResult insertForwardDeclarations(mlir::Operation *op, { auto funcType = mlir::FunctionType::get(rewriter.getContext(), { - errType, rewriter.getI32Type(), rewriter.getI32Type(), }, @@ -287,9 +278,9 @@ mlir::LogicalResult insertForwardDeclarations(mlir::Operation *op, } // Insert forward declaration of the alloc_plaintext_list function { - auto funcType = mlir::FunctionType::get(rewriter.getContext(), - {errType, rewriter.getI32Type()}, - {genericPlaintextListType}); + auto funcType = + mlir::FunctionType::get(rewriter.getContext(), {rewriter.getI32Type()}, + {genericPlaintextListType}); if (insertForwardDeclaration(op, rewriter, "allocate_plaintext_list_u64", funcType) .failed()) { @@ -300,7 +291,7 @@ mlir::LogicalResult insertForwardDeclarations(mlir::Operation *op, { auto funcType = mlir::FunctionType::get( rewriter.getContext(), - {errType, genericPlaintextListType, genericForeignPlaintextList}, {}); + {genericPlaintextListType, genericForeignPlaintextList}, {}); if (insertForwardDeclaration( op, rewriter, "fill_plaintext_list_with_expansion_u64", funcType) .failed()) { @@ -310,7 +301,7 @@ mlir::LogicalResult insertForwardDeclarations(mlir::Operation *op, // Insert forward declaration of the add_plaintext_list_glwe function { auto funcType = mlir::FunctionType::get(rewriter.getContext(), - {errType, genericGlweCiphertextType, + {genericGlweCiphertextType, genericGlweCiphertextType, genericPlaintextListType}, {}); @@ -356,24 +347,20 @@ struct ConcreteOpToConcreteCAPICallPattern : public mlir::OpRewritePattern { resultType.cast(); // Replace the operation with a call to the `funcName` { - // Create the err value - auto errOp = rewriter.create( - op.getLoc(), rewriter.getIndexAttr(0)); // Get the size from the dimension int64_t lweDimension = lweResultType.getDimension(); - int64_t lweSize = lweDimension + 1; - mlir::Value lweSizeOp = rewriter.create( - op.getLoc(), rewriter.getIndexAttr(lweSize)); + + mlir::Value lweDimensionOp = rewriter.create( + op.getLoc(), rewriter.getIndexAttr(lweDimension)); // Add the call to the allocation - mlir::SmallVector allocOperands{errOp, lweSizeOp}; + mlir::SmallVector allocOperands{lweDimensionOp}; auto allocGeneric = rewriter.create( op.getLoc(), allocName, getGenericLweCiphertextType(rewriter.getContext()), allocOperands); // Construct operands for the operation. // errOp doesn't need to be casted to something generic, allocGeneric // already is. All the rest will be converted if needed - mlir::SmallVector newOperands{errOp, - allocGeneric.getResult(0)}; + mlir::SmallVector newOperands{allocGeneric.getResult(0)}; for (mlir::Value operand : op->getOperands()) { mlir::Type operandType = operand.getType(); mlir::Type castedType = getGenericType(operandType); @@ -420,16 +407,13 @@ struct ConcreteZeroOpPattern mlir::Type resultType = op->getResultTypes().front(); auto lweResultType = resultType.cast(); - // Create the err value - auto errOp = rewriter.create( - op.getLoc(), rewriter.getIndexAttr(0)); // Get the size from the dimension int64_t lweDimension = lweResultType.getDimension(); - int64_t lweSize = lweDimension + 1; - mlir::Value lweSizeOp = rewriter.create( - op.getLoc(), rewriter.getIndexAttr(lweSize)); + + mlir::Value lweDimensionOp = rewriter.create( + op.getLoc(), rewriter.getIndexAttr(lweDimension)); // Allocate a fresh new ciphertext - mlir::SmallVector allocOperands{errOp, lweSizeOp}; + mlir::SmallVector allocOperands{lweDimensionOp}; auto allocGeneric = rewriter.create( op.getLoc(), "allocate_lwe_ciphertext_u64", getGenericLweCiphertextType(rewriter.getContext()), allocOperands); @@ -506,7 +490,6 @@ struct GlweFromTableOpPattern matchAndRewrite(mlir::concretelang::Concrete::GlweFromTable op, mlir::PatternRewriter &rewriter) const override { ConcreteToConcreteCAPITypeConverter typeConverter; - auto errType = mlir::IndexType::get(rewriter.getContext()); // TODO: move this to insertForwardDeclarations // issue: can't define function with tensor<*xtype> that accept ranked @@ -516,28 +499,22 @@ struct GlweFromTableOpPattern { auto funcType = mlir::FunctionType::get( rewriter.getContext(), - {errType, op->getOperandTypes().front(), rewriter.getI64Type(), - rewriter.getI32Type()}, + {op->getOperandTypes().front(), rewriter.getI32Type()}, {getGenericForeignPlaintextListType(rewriter.getContext())}); - if (insertForwardDeclaration( - op, rewriter, "runtime_foreign_plaintext_list_u64", funcType) + if (insertForwardDeclaration(op, rewriter, + "memref_runtime_foreign_plaintext_list_u64", + funcType) .failed()) { return mlir::failure(); } } - auto errOp = rewriter.create( - op.getLoc(), rewriter.getIndexAttr(0)); - // Get the size from the dimension - int64_t glweDimension = - op->getAttr("glweDimension").cast().getInt(); - int64_t glweSize = glweDimension + 1; - mlir::Value glweSizeOp = rewriter.create( - op.getLoc(), rewriter.getI32IntegerAttr(glweSize)); // allocate two glwe to build accumulator auto polySizeOp = rewriter.create( op.getLoc(), op->getAttr("polynomialSize")); - mlir::SmallVector allocGlweOperands{errOp, glweSizeOp, + auto glweDimensionOp = rewriter.create( + op.getLoc(), op->getAttr("glweDimension")); + mlir::SmallVector allocGlweOperands{glweDimensionOp, polySizeOp}; // first accumulator would replace the op since it's the returned value auto accumulatorOp = rewriter.replaceOpWithNewOp( @@ -548,8 +525,7 @@ struct GlweFromTableOpPattern op.getLoc(), "allocate_glwe_ciphertext_u64", getGenericGlweCiphertextType(rewriter.getContext()), allocGlweOperands); // allocate plaintext list - mlir::SmallVector allocPlaintextListOperands{errOp, - polySizeOp}; + mlir::SmallVector allocPlaintextListOperands{polySizeOp}; auto plaintextListOp = rewriter.create( op.getLoc(), "allocate_plaintext_list_u64", getGenericPlaintextListType(rewriter.getContext()), @@ -559,27 +535,23 @@ struct GlweFromTableOpPattern op->getOperandTypes().front().cast(); assert(rankedTensorType.getRank() == 1 && "table lookup must be of a single dimension"); - auto sizeOp = rewriter.create( - op.getLoc(), - rewriter.getI64IntegerAttr(rankedTensorType.getDimSize(0))); auto precisionOp = rewriter.create(op.getLoc(), op->getAttr("p")); mlir::SmallVector ForeignPlaintextListOperands{ - errOp, op->getOperand(0), sizeOp, precisionOp}; + op->getOperand(0), precisionOp}; auto foreignPlaintextListOp = rewriter.create( - op.getLoc(), "runtime_foreign_plaintext_list_u64", + op.getLoc(), "memref_runtime_foreign_plaintext_list_u64", getGenericForeignPlaintextListType(rewriter.getContext()), ForeignPlaintextListOperands); // fill plaintext list mlir::SmallVector FillPlaintextListOperands{ - errOp, plaintextListOp.getResult(0), - foreignPlaintextListOp.getResult(0)}; + plaintextListOp.getResult(0), foreignPlaintextListOp.getResult(0)}; rewriter.create( op.getLoc(), "fill_plaintext_list_with_expansion_u64", mlir::TypeRange({}), FillPlaintextListOperands); // add plaintext list and glwe to build final accumulator for pbs mlir::SmallVector AddPlaintextListGlweOperands{ - errOp, accumulatorOp.getResult(0), _accumulatorOp.getResult(0), + accumulatorOp.getResult(0), _accumulatorOp.getResult(0), plaintextListOp.getResult(0)}; rewriter.create( op.getLoc(), "add_plaintext_list_glwe_ciphertext_u64", @@ -626,18 +598,15 @@ struct ConcreteBootstrapLweOpPattern matchAndRewrite(mlir::concretelang::Concrete::BootstrapLweOp op, mlir::PatternRewriter &rewriter) const override { auto resultType = op->getResultTypes().front(); - auto errOp = rewriter.create( - op.getLoc(), rewriter.getIndexAttr(0)); // Get the size from the dimension int64_t outputLweDimension = resultType.cast() .getDimension(); - int64_t outputLweSize = outputLweDimension + 1; mlir::Value lweSizeOp = rewriter.create( - op.getLoc(), rewriter.getIndexAttr(outputLweSize)); + op.getLoc(), rewriter.getIndexAttr(outputLweDimension)); // allocate the result lwe ciphertext, should be of a generic type, to cast // before return - mlir::SmallVector allocLweCtOperands{errOp, lweSizeOp}; + mlir::SmallVector allocLweCtOperands{lweSizeOp}; auto allocateGenericLweCtOp = rewriter.create( op.getLoc(), "allocate_lwe_ciphertext_u64", getGenericLweCiphertextType(rewriter.getContext()), allocLweCtOperands); @@ -662,7 +631,7 @@ struct ConcreteBootstrapLweOpPattern op.getOperand(1)) .getResult(0); mlir::SmallVector bootstrapOperands{ - errOp, getBskOp.getResult(0), allocateGenericLweCtOp.getResult(0), + getBskOp.getResult(0), allocateGenericLweCtOp.getResult(0), lweToBootstrap, accumulator}; rewriter.create(op.getLoc(), "bootstrap_lwe_u64", mlir::TypeRange({}), bootstrapOperands); @@ -690,20 +659,17 @@ struct ConcreteKeySwitchLweOpPattern mlir::LogicalResult matchAndRewrite(mlir::concretelang::Concrete::KeySwitchLweOp op, mlir::PatternRewriter &rewriter) const override { - auto errOp = rewriter.create( - op.getLoc(), rewriter.getIndexAttr(0)); // Get the size from the dimension int64_t lweDimension = op.getResult() .getType() .cast() .getDimension(); - int64_t lweSize = lweDimension + 1; - mlir::Value lweSizeOp = rewriter.create( - op.getLoc(), rewriter.getIndexAttr(lweSize)); + mlir::Value lweDimensionOp = rewriter.create( + op.getLoc(), rewriter.getIndexAttr(lweDimension)); // allocate the result lwe ciphertext, should be of a generic type, to cast // before return - mlir::SmallVector allocLweCtOperands{errOp, lweSizeOp}; + mlir::SmallVector allocLweCtOperands{lweDimensionOp}; auto allocateGenericLweCtOp = rewriter.create( op.getLoc(), "allocate_lwe_ciphertext_u64", getGenericLweCiphertextType(rewriter.getContext()), allocLweCtOperands); @@ -721,7 +687,7 @@ struct ConcreteKeySwitchLweOpPattern op.getOperand()) .getResult(0); mlir::SmallVector keyswitchOperands{ - errOp, getKskOp.getResult(0), allocateGenericLweCtOp.getResult(0), + getKskOp.getResult(0), allocateGenericLweCtOp.getResult(0), lweToKeyswitch}; rewriter.create(op.getLoc(), "keyswitch_lwe_u64", mlir::TypeRange({}), keyswitchOperands); diff --git a/compiler/lib/Runtime/CMakeLists.txt b/compiler/lib/Runtime/CMakeLists.txt index 1fb9dda44..2a86ab8aa 100644 --- a/compiler/lib/Runtime/CMakeLists.txt +++ b/compiler/lib/Runtime/CMakeLists.txt @@ -14,9 +14,9 @@ if(CONCRETELANG_PARALLEL_EXECUTION_ENABLED) install(TARGETS DFRuntime EXPORT DFRuntime) install(EXPORT DFRuntime DESTINATION "./") - target_link_libraries(ConcretelangRuntime Concrete pthread m dl HPX::hpx) + target_link_libraries(ConcretelangRuntime Concrete pthread m dl HPX::hpx $) else() - target_link_libraries(ConcretelangRuntime Concrete pthread m dl) + target_link_libraries(ConcretelangRuntime Concrete pthread m dl $) endif() install(TARGETS ConcretelangRuntime EXPORT ConcretelangRuntime) diff --git a/compiler/lib/Runtime/context.cpp b/compiler/lib/Runtime/context.cpp index 07daec465..7a0dbcc7d 100644 --- a/compiler/lib/Runtime/context.cpp +++ b/compiler/lib/Runtime/context.cpp @@ -19,7 +19,6 @@ get_keyswitch_key(mlir::concretelang::RuntimeContext *context) { LweBootstrapKey_u64 * get_bootstrap_key(mlir::concretelang::RuntimeContext *context) { #ifdef CONCRETELANG_PARALLEL_EXECUTION_ENABLED - int err; std::string threadName = hpx::get_thread_name(); auto bskIt = context->bsk.find(threadName); if (bskIt == context->bsk.end()) { @@ -27,10 +26,8 @@ get_bootstrap_key(mlir::concretelang::RuntimeContext *context) { .insert(std::pair( threadName, clone_lwe_bootstrap_key_u64( - &err, context->bsk["_concretelang_base_context_bsk"]))) + context->bsk["_concretelang_base_context_bsk"]))) .first; - if (err != 0) - fprintf(stderr, "Runtime: cloning bootstrap key failed.\n"); } #else std::string baseName = "_concretelang_base_context_bsk"; diff --git a/compiler/lib/Runtime/wrappers.c b/compiler/lib/Runtime/wrappers.c index 5ba888a9c..6e2cb9afb 100644 --- a/compiler/lib/Runtime/wrappers.c +++ b/compiler/lib/Runtime/wrappers.c @@ -1,20 +1,89 @@ #include "concretelang/Runtime/wrappers.h" +#include #include -ForeignPlaintextList_u64 * -runtime_foreign_plaintext_list_u64(int *err, uint64_t *allocated, - uint64_t *aligned, uint64_t offset, - uint64_t size_dim0, uint64_t stride_dim0, - uint64_t size, uint32_t precision) { - if (stride_dim0 != 1) { - fprintf(stderr, "Runtime: stride not equal to 1, check " - "runtime_foreign_plaintext_list_u64"); - } +struct ForeignPlaintextList_u64 *memref_runtime_foreign_plaintext_list_u64( + uint64_t *allocated, uint64_t *aligned, uint64_t offset, uint64_t size, + uint64_t stride, uint32_t precision) { + + assert(stride == 1 && "Runtime: stride not equal to 1, check " + "runtime_foreign_plaintext_list_u64"); + // Encode table values in u64 uint64_t *encoded_table = malloc(size * sizeof(uint64_t)); for (uint64_t i = 0; i < size; i++) { encoded_table[i] = (aligned + offset)[i] << (64 - precision - 1); } - return foreign_plaintext_list_u64(err, encoded_table, size); + return foreign_plaintext_list_u64(encoded_table, size); // TODO: is it safe to free after creating plaintext_list? } + +void memref_add_lwe_ciphertexts_u64( + uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset, + uint64_t out_size, uint64_t out_stride, uint64_t *ct0_allocated, + uint64_t *ct0_aligned, uint64_t ct0_offset, uint64_t ct0_size, + uint64_t ct0_stride, uint64_t *ct1_allocated, uint64_t *ct1_aligned, + uint64_t ct1_offset, uint64_t ct1_size, uint64_t ct1_stride) { + assert(out_size == ct0_size && out_size == ct1_size && + "size of lwe buffer are incompatible"); + LweDimension lwe_dimension = {out_size - 1}; + add_two_lwe_ciphertexts_u64(out_aligned + out_offset, + ct0_aligned + ct0_offset, + ct1_aligned + ct1_offset, lwe_dimension); +} + +void memref_add_plaintext_lwe_ciphertext_u64( + uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset, + uint64_t out_size, uint64_t out_stride, uint64_t *ct0_allocated, + uint64_t *ct0_aligned, uint64_t ct0_offset, uint64_t ct0_size, + uint64_t ct0_stride, uint64_t plaintext) { + assert(out_size == ct0_size && "size of lwe buffer are incompatible"); + LweDimension lwe_dimension = {out_size - 1}; + add_plaintext_to_lwe_ciphertext_u64(out_aligned + out_offset, + ct0_aligned + ct0_offset, plaintext, + lwe_dimension); +} + +void memref_mul_cleartext_lwe_ciphertext_u64( + uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset, + uint64_t out_size, uint64_t out_stride, uint64_t *ct0_allocated, + uint64_t *ct0_aligned, uint64_t ct0_offset, uint64_t ct0_size, + uint64_t ct0_stride, uint64_t cleartext) { + assert(out_size == ct0_size && "size of lwe buffer are incompatible"); + LweDimension lwe_dimension = {out_size - 1}; + mul_cleartext_lwe_ciphertext_u64(out_aligned + out_offset, + ct0_aligned + ct0_offset, cleartext, + lwe_dimension); +} + +void memref_negate_lwe_ciphertext_u64( + uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset, + uint64_t out_size, uint64_t out_stride, uint64_t *ct0_allocated, + uint64_t *ct0_aligned, uint64_t ct0_offset, uint64_t ct0_size, + uint64_t ct0_stride) { + assert(out_size == ct0_size && "size of lwe buffer are incompatible"); + LweDimension lwe_dimension = {out_size - 1}; + neg_lwe_ciphertext_u64(out_aligned + out_offset, ct0_aligned + ct0_offset, + lwe_dimension); +} + +void memref_keyswitch_lwe_u64(struct LweKeyswitchKey_u64 *keyswitch_key, + uint64_t *out_allocated, uint64_t *out_aligned, + uint64_t out_offset, uint64_t out_size, + uint64_t out_stride, uint64_t *ct0_allocated, + uint64_t *ct0_aligned, uint64_t ct0_offset, + uint64_t ct0_size, uint64_t ct0_stride) { + bufferized_keyswitch_lwe_u64(keyswitch_key, out_aligned + out_offset, + ct0_aligned + ct0_offset); +} + +void memref_bootstrap_lwe_u64(struct LweBootstrapKey_u64 *bootstrap_key, + uint64_t *out_allocated, uint64_t *out_aligned, + uint64_t out_offset, uint64_t out_size, + uint64_t out_stride, uint64_t *ct0_allocated, + uint64_t *ct0_aligned, uint64_t ct0_offset, + uint64_t ct0_size, uint64_t ct0_stride, + struct GlweCiphertext_u64 *accumulator) { + bufferized_bootstrap_lwe_u64(bootstrap_key, out_aligned + out_offset, + ct0_aligned + ct0_offset, accumulator); +} diff --git a/compiler/lib/Support/Jit.cpp b/compiler/lib/Support/Jit.cpp index c79a38eeb..0df828f1b 100644 --- a/compiler/lib/Support/Jit.cpp +++ b/compiler/lib/Support/Jit.cpp @@ -78,6 +78,12 @@ llvm::Error JITLambda::invoke(Argument &args) { << actualInputs << "arguments instead of " << expectedInputs; } +// memref is a struct which is flattened aligned, allocated pointers, offset, +// and two array of rank size for sizes and strides. +uint64_t numArgOfRankedMemrefCallingConvention(uint64_t rank) { + return 3 + 2 * rank; +} + JITLambda::Argument::Argument(KeySet &keySet) : keySet(keySet) { // Setting the inputs auto numInputs = 0; @@ -86,16 +92,20 @@ JITLambda::Argument::Argument(KeySet &keySet) : keySet(keySet) { auto offset = numInputs; auto gate = keySet.inputGate(i); inputGates.push_back({gate, offset}); - if (keySet.inputGate(i).shape.dimensions.empty()) { + if (gate.shape.dimensions.empty()) { // scalar gate - numInputs = numInputs + 1; + if (gate.encryption.hasValue()) { + // encrypted is a memref + numInputs = numInputs + numArgOfRankedMemrefCallingConvention(1); + } else { + numInputs = numInputs + 1; + } continue; } // memref gate, as we follow the standard calling convention - numInputs = numInputs + 3; - // Offsets and strides are array of size N where N is the number of - // dimension of the tensor. - numInputs = numInputs + 2 * keySet.inputGate(i).shape.dimensions.size(); + auto rank = keySet.inputGate(i).shape.dimensions.size() + + (keySet.isInputEncrypted(i) ? 1 /* for lwe size */ : 0); + numInputs = numInputs + numArgOfRankedMemrefCallingConvention(rank); } // Reserve for the context argument numInputs = numInputs + 1; @@ -111,19 +121,21 @@ JITLambda::Argument::Argument(KeySet &keySet) : keySet(keySet) { outputGates.push_back({gate, offset}); if (gate.shape.dimensions.empty()) { // scalar gate - numOutputs = numOutputs + 1; + if (gate.encryption.hasValue()) { + // encrypted is a memref + numOutputs = numOutputs + numArgOfRankedMemrefCallingConvention(1); + } else { + numOutputs = numOutputs + 1; + } continue; } // memref gate, as we follow the standard calling convention - numOutputs = numOutputs + 3; - // Offsets and strides are array of size N where N is the number of - // dimension of the tensor. - numOutputs = - numOutputs + 2 * keySet.outputGate(i).shape.dimensions.size(); + auto rank = keySet.outputGate(i).shape.dimensions.size() + + (keySet.isOutputEncrypted(i) ? 1 /* for lwe size */ : 0); + numOutputs = numOutputs + numArgOfRankedMemrefCallingConvention(rank); } outputs = std::vector(numOutputs); } - // The raw argument contains pointers to inputs and pointers to store the // results rawArg = std::vector(inputs.size() + outputs.size(), nullptr); @@ -139,9 +151,8 @@ JITLambda::Argument::Argument(KeySet &keySet) : keySet(keySet) { } JITLambda::Argument::~Argument() { - int err; for (auto ct : allocatedCiphertexts) { - free_lwe_ciphertext_u64(&err, ct); + free(ct); } for (auto buffer : ciphertextBuffers) { free(buffer); @@ -185,16 +196,31 @@ llvm::Error JITLambda::Argument::setArg(size_t pos, uint64_t arg) { return llvm::Error::success(); } // Else if is encryted, allocate ciphertext and encrypt. - LweCiphertext_u64 *ctArg; - if (auto err = this->keySet.allocate_lwe(pos, &ctArg)) { + uint64_t *ctArg; + uint64_t ctSize; + if (auto err = this->keySet.allocate_lwe(pos, &ctArg, ctSize)) { return std::move(err); } allocatedCiphertexts.push_back(ctArg); if (auto err = this->keySet.encrypt_lwe(pos, ctArg, arg)) { return std::move(err); } - inputs[offset] = ctArg; + // memref calling convention + // allocated + inputs[offset] = nullptr; + // aligned + inputs[offset + 1] = ctArg; + // offset + inputs[offset + 2] = (void *)0; + // size + inputs[offset + 3] = (void *)ctSize; + // stride + inputs[offset + 4] = (void *)1; rawArg[offset] = &inputs[offset]; + rawArg[offset + 1] = &inputs[offset + 1]; + rawArg[offset + 2] = &inputs[offset + 2]; + rawArg[offset + 3] = &inputs[offset + 3]; + rawArg[offset + 4] = &inputs[offset + 4]; return llvm::Error::success(); } @@ -279,17 +305,18 @@ llvm::Error JITLambda::Argument::setArg(size_t pos, size_t width, llvm::inconvertibleErrorCode()); } - // Allocate a buffer for ciphertexts. - auto ctBuffer = (LweCiphertext_u64 **)malloc(info.shape.size * - sizeof(LweCiphertext_u64 *)); + // Allocate a buffer for ciphertexts, the size of the buffer is the number + // of elements of the tensor * the size of the lwe ciphertext + auto lweSize = keySet.getInputLweSecretKeyParam(pos).size + 1; + uint64_t *ctBuffer = + (uint64_t *)malloc(info.shape.size * lweSize * sizeof(uint64_t)); ciphertextBuffers.push_back(ctBuffer); - // Allocate ciphertexts and encrypt - for (size_t i = 0; i < info.shape.size; i++) { - if (auto err = this->keySet.allocate_lwe(pos, &ctBuffer[i])) { - return std::move(err); - } - allocatedCiphertexts.push_back(ctBuffer[i]); - if (auto err = this->keySet.encrypt_lwe(pos, ctBuffer[i], data8[i])) { + // Encrypt ciphertexts + for (size_t i = 0, offset = 0; i < info.shape.size; + i++, offset += lweSize) { + + if (auto err = + this->keySet.encrypt_lwe(pos, ctBuffer + offset, data8[i])) { return std::move(err); } } @@ -316,17 +343,27 @@ llvm::Error JITLambda::Argument::setArg(size_t pos, size_t width, rawArg[offset] = &inputs[offset]; offset++; } + // If encrypted +1 for the lwe size rank + if (keySet.isInputEncrypted(pos)) { + inputs[offset] = (void *)(keySet.getInputLweSecretKeyParam(pos).size + 1); + rawArg[offset] = &inputs[offset]; + offset++; + } // Set the stride for each dimension, equal to the product of the // following dimensions. int64_t stride = 1; - + // If encrypted +1 set the stride for the lwe size rank + if (keySet.isInputEncrypted(pos)) { + inputs[offset + shape.size()] = (void *)stride; + rawArg[offset + shape.size()] = &inputs[offset]; + stride *= keySet.getInputLweSecretKeyParam(pos).size + 1; + } for (ssize_t i = shape.size() - 1; i >= 0; i--) { inputs[offset + i] = (void *)stride; rawArg[offset + i] = &inputs[offset + i]; stride *= shape[i]; } - offset += shape.size(); return llvm::Error::success(); @@ -349,7 +386,7 @@ llvm::Error JITLambda::Argument::getResult(size_t pos, uint64_t &res) { return llvm::Error::success(); } // Else if is encryted, decrypt - LweCiphertext_u64 *ct = (LweCiphertext_u64 *)(outputs[offset]); + uint64_t *ct = (uint64_t *)(outputs[offset + 1]); if (auto err = this->keySet.decrypt_lwe(pos, ct, res)) { return std::move(err); } @@ -463,8 +500,10 @@ llvm::Error JITLambda::Argument::getResult(size_t pos, void *res, } } else { // decrypt and fill the result buffer - for (size_t i = 0; i < numElements; i++) { - LweCiphertext_u64 *ct = ((LweCiphertext_u64 **)alignedBytes)[i]; + auto lweSize = keySet.getOutputLweSecretKeyParam(pos).size + 1; + + for (size_t i = 0, o = 0; i < numElements; i++, o += lweSize) { + uint64_t *ct = ((uint64_t *)alignedBytes) + o; if (auto err = this->keySet.decrypt_lwe(pos, ct, ((uint64_t *)res)[i])) { return std::move(err); } diff --git a/compiler/lib/Support/JitCompilerEngine.cpp b/compiler/lib/Support/JitCompilerEngine.cpp index a73dc465b..6c8f7e743 100644 --- a/compiler/lib/Support/JitCompilerEngine.cpp +++ b/compiler/lib/Support/JitCompilerEngine.cpp @@ -58,7 +58,6 @@ JitCompilerEngine::buildLambda(llvm::StringRef s, llvm::StringRef funcName, std::unique_ptr mb = llvm::MemoryBuffer::getMemBuffer(s); llvm::Expected res = this->buildLambda(std::move(mb), funcName, cache, runtimeLibPath); - return std::move(res); }