diff --git a/compiler/Makefile b/compiler/Makefile index 627060b64..353b8ebb8 100644 --- a/compiler/Makefile +++ b/compiler/Makefile @@ -171,7 +171,7 @@ test-end-to-end-jit-dfr: build-end-to-end-jit-dfr test-end-to-end-jit-auto-parallelization: build-end-to-end-jit-auto-parallelization $(BUILD_DIR)/bin/end_to_end_jit_auto_parallelization -test-end-to-end-jit: test-end-to-end-jit-test test-end-to-end-jit-clear-tensor test-end-to-end-jit-encrypted-tensor test-end-to-end-jit-fhelinalg +test-end-to-end-jit: test-end-to-end-jit-test test-end-to-end-jit-clear-tensor test-end-to-end-jit-encrypted-tensor test-end-to-end-jit-fhelinalg test-end-to-end-jit-fhe show-stress-tests-summary: @echo '------ Stress tests summary ------' diff --git a/compiler/include/concretelang/ClientLib/KeySet.h b/compiler/include/concretelang/ClientLib/KeySet.h index 0c432f289..3039dc1ca 100644 --- a/compiler/include/concretelang/ClientLib/KeySet.h +++ b/compiler/include/concretelang/ClientLib/KeySet.h @@ -27,6 +27,7 @@ using RuntimeContext = mlir::concretelang::RuntimeContext; class KeySet { public: + KeySet(); ~KeySet(); // allocate a KeySet according the ClientParameters. @@ -81,8 +82,7 @@ public: void setRuntimeContext(RuntimeContext &context) { context.ksk = std::get<1>(this->keyswitchKeys["ksk_v0"]); - context.bsk[RuntimeContext::BASE_CONTEXT_BSK] = - std::get<1>(this->bootstrapKeys.at("bsk_v0")); + context.bsk = std::get<1>(this->bootstrapKeys.at("bsk_v0")); } RuntimeContext runtimeContext() { @@ -105,14 +105,13 @@ public: protected: outcome::checked - generateSecretKey(LweSecretKeyID id, LweSecretKeyParam param, - SecretRandomGenerator *generator); + generateSecretKey(LweSecretKeyID id, LweSecretKeyParam param); + outcome::checked - generateBootstrapKey(BootstrapKeyID id, BootstrapKeyParam param, - EncryptionRandomGenerator *generator); + generateBootstrapKey(BootstrapKeyID id, BootstrapKeyParam param); + outcome::checked - generateKeyswitchKey(KeyswitchKeyID id, KeyswitchKeyParam param, - EncryptionRandomGenerator *generator); + generateKeyswitchKey(KeyswitchKeyID id, KeyswitchKeyParam param); outcome::checked generateKeysFromParams(ClientParameters ¶ms, uint64_t seed_msb, @@ -125,7 +124,7 @@ protected: friend class KeySetCache; private: - EncryptionRandomGenerator *encryptionRandomGenerator; + Engine *engine; std::map> secretKeys; std::map> diff --git a/compiler/include/concretelang/Conversion/ConcreteToConcreteCAPI/Pass.h b/compiler/include/concretelang/Conversion/ConcreteToConcreteCAPI/Pass.h deleted file mode 100644 index 68eaa1c5d..000000000 --- a/compiler/include/concretelang/Conversion/ConcreteToConcreteCAPI/Pass.h +++ /dev/null @@ -1,22 +0,0 @@ -// Part of the Concrete Compiler Project, under the BSD3 License with Zama -// Exceptions. See -// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt -// for license information. - -#ifndef CONCRETELANG_CONVERSION_CONCRETETOCONCRETECAPI_PASS_H_ -#define CONCRETELANG_CONVERSION_CONCRETETOCONCRETECAPI_PASS_H_ - -#include "mlir/Pass/Pass.h" - -#include "concretelang/Conversion/Utils/GlobalFHEContext.h" - -namespace mlir { -namespace concretelang { -/// Create a pass to convert `Concrete` operators to function call to the -/// `ConcreteCAPI` -std::unique_ptr> -createConvertConcreteToConcreteCAPIPass(); -} // namespace concretelang -} // namespace mlir - -#endif \ No newline at end of file diff --git a/compiler/include/concretelang/Conversion/ConcreteUnparametrize/Pass.h b/compiler/include/concretelang/Conversion/ConcreteUnparametrize/Pass.h deleted file mode 100644 index 254492634..000000000 --- a/compiler/include/concretelang/Conversion/ConcreteUnparametrize/Pass.h +++ /dev/null @@ -1,18 +0,0 @@ -// Part of the Concrete Compiler Project, under the BSD3 License with Zama -// Exceptions. See -// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt -// for license information. - -#ifndef CONCRETELANG_CONVERSION_CONCRETEUNPARAMETRIZE_PASS_H_ -#define CONCRETELANG_CONVERSION_CONCRETEUNPARAMETRIZE_PASS_H_ - -#include "mlir/Pass/Pass.h" - -namespace mlir { -namespace concretelang { -std::unique_ptr> -createConvertConcreteUnparametrizePass(); -} // namespace concretelang -} // namespace mlir - -#endif \ No newline at end of file diff --git a/compiler/include/concretelang/Conversion/Passes.h b/compiler/include/concretelang/Conversion/Passes.h index 82ec18c3f..5a97ea49f 100644 --- a/compiler/include/concretelang/Conversion/Passes.h +++ b/compiler/include/concretelang/Conversion/Passes.h @@ -13,8 +13,6 @@ #include "concretelang/Conversion/BConcreteToBConcreteCAPI/Pass.h" #include "concretelang/Conversion/ConcreteToBConcrete/Pass.h" -#include "concretelang/Conversion/ConcreteToConcreteCAPI/Pass.h" -#include "concretelang/Conversion/ConcreteUnparametrize/Pass.h" #include "concretelang/Conversion/FHETensorOpsToLinalg/Pass.h" #include "concretelang/Conversion/FHEToTFHE/Pass.h" #include "concretelang/Conversion/MLIRLowerableDialectsToLLVM/Pass.h" diff --git a/compiler/include/concretelang/Conversion/Passes.td b/compiler/include/concretelang/Conversion/Passes.td index d8e5b0587..015258683 100644 --- a/compiler/include/concretelang/Conversion/Passes.td +++ b/compiler/include/concretelang/Conversion/Passes.td @@ -40,24 +40,12 @@ def ConcreteToBConcrete : Pass<"concrete-to-bconcrete", "mlir::ModuleOp"> { let dependentDialects = ["mlir::linalg::LinalgDialect", "mlir::concretelang::Concrete::ConcreteDialect", "mlir::concretelang::BConcrete::BConcreteDialect"]; } -def ConcreteToConcreteCAPI : Pass<"concrete-to-concrete-c-api", "mlir::ModuleOp"> { - let summary = "Lower operations from the Concrete dialect to std with function call to the Concrete C API"; - let constructor = "mlir::concretelang::createConvertConcreteToConcreteCAPIPass()"; - let dependentDialects = ["mlir::concretelang::Concrete::ConcreteDialect", "mlir::StandardOpsDialect", "mlir::memref::MemRefDialect"]; -} - def BConcreteToBConcreteCAPI : Pass<"bconcrete-to-bconcrete-c-api", "mlir::ModuleOp"> { let summary = "Lower operations from the Bufferized Concrete dialect to std with function call to the Bufferized Concrete C API"; let constructor = "mlir::concretelang::createConvertBConcreteToBConcreteCAPIPass()"; let dependentDialects = ["mlir::concretelang::BConcrete::BConcreteDialect", "mlir::StandardOpsDialect", "mlir::memref::MemRefDialect"]; } -def ConcreteUnparametrize : Pass<"concrete-unparametrize", "mlir::ModuleOp"> { - let summary = "Unparametrize Concrete types and remove unrealized_conversion_cast"; - let constructor = "mlir::concretelang::createConvertConcreteToConcreteCAPIPass()"; - let dependentDialects = ["mlir::concretelang::Concrete::ConcreteDialect", "mlir::StandardOpsDialect", "mlir::memref::MemRefDialect"]; -} - def MLIRLowerableDialectsToLLVM : Pass<"mlir-lowerable-dialects-to-llvm", "mlir::ModuleOp"> { let summary = "Lowers operations from MLIR lowerable dialects to LLVM"; let constructor = "mlir::concretelang::createConvertMLIRLowerableDialectsToLLVMPass()"; diff --git a/compiler/include/concretelang/Conversion/TFHEToConcrete/Patterns.h b/compiler/include/concretelang/Conversion/TFHEToConcrete/Patterns.h index dcc1d460b..aa3eb4bbd 100644 --- a/compiler/include/concretelang/Conversion/TFHEToConcrete/Patterns.h +++ b/compiler/include/concretelang/Conversion/TFHEToConcrete/Patterns.h @@ -210,7 +210,7 @@ mlir::Value createMulClearLweCiphertext(mlir::PatternRewriter &rewriter, // % [[TABLE]]){glweDimension = 1 : i32, p = 4 : i32, polynomialSize = // 2048 : i32} // : (tensor<16xi4>) -// ->!Concrete.glwe_ciphertext +// ->!Concrete.glwe_ciphertext<2048, 1, 4> // % keyswitched = "Concrete.keyswitch_lwe"(% arg0){ // baseLog = 2 : i32, // level = 3 : i32 @@ -221,7 +221,7 @@ mlir::Value createMulClearLweCiphertext(mlir::PatternRewriter &rewriter, // glweDimension = 1 : i32, // level = 5 : i32, // polynomialSize = 2048 : i32 -// } : (!Concrete.lwe_ciphertext<600, 4>, !Concrete.glwe_ciphertext) +// } : (!Concrete.lwe_ciphertext<600, 4>, !Concrete.glwe_ciphertext<2048, 1, 4>) // ->!Concrete.lwe_ciphertext<2048, 4> // ``` mlir::Value createPBS(mlir::PatternRewriter &rewriter, mlir::Location loc, @@ -240,8 +240,11 @@ mlir::Value createPBS(mlir::PatternRewriter &rewriter, mlir::Location loc, mlir::Value accumulator = rewriter .create( - loc, Concrete::GlweCiphertextType::get(rewriter.getContext()), - table, polynomialSize, glweDimension, precision) + loc, + Concrete::GlweCiphertextType::get( + rewriter.getContext(), polynomialSize.getInt(), + glweDimension.getInt(), lwe_type.getP()), + table) .result(); // keyswitch diff --git a/compiler/include/concretelang/Dialect/BConcrete/IR/BConcreteOps.td b/compiler/include/concretelang/Dialect/BConcrete/IR/BConcreteOps.td index 44ff5ca92..6c1e157f6 100644 --- a/compiler/include/concretelang/Dialect/BConcrete/IR/BConcreteOps.td +++ b/compiler/include/concretelang/Dialect/BConcrete/IR/BConcreteOps.td @@ -36,6 +36,17 @@ def NegateLweBufferOp : BConcrete_Op<"negate_lwe_buffer"> { let results = (outs); } +def FillGlweFromTable : BConcrete_Op<"fill_glwe_from_table"> { + let arguments = (ins + 1DTensorOf<[I64]>:$glwe, + I32Attr:$polynomialSize, + I32Attr:$glweDimension, + I32Attr:$outPrecision, + 1DTensorOf<[I64]>:$table + ); + let results = (outs); +} + def KeySwitchLweBufferOp : BConcrete_Op<"keyswitch_lwe_buffer"> { let arguments = (ins 1DTensorOf<[I64]>:$result, @@ -52,7 +63,7 @@ def BootstrapLweBufferOp : BConcrete_Op<"bootstrap_lwe_buffer"> { 1DTensorOf<[I64]>:$result, // LweBootstrapKeyType:$bootstrap_key, 1DTensorOf<[I64]>:$input_ciphertext, - GlweCiphertextType:$accumulator, + 1DTensorOf<[I64]>:$accumulator, I32Attr:$glweDimension, I32Attr:$polynomialSize, I32Attr:$level, diff --git a/compiler/include/concretelang/Dialect/Concrete/IR/ConcreteOps.td b/compiler/include/concretelang/Dialect/Concrete/IR/ConcreteOps.td index 6a36396a3..75665c7fd 100644 --- a/compiler/include/concretelang/Dialect/Concrete/IR/ConcreteOps.td +++ b/compiler/include/concretelang/Dialect/Concrete/IR/ConcreteOps.td @@ -55,7 +55,7 @@ def NegateLweCiphertextOp : Concrete_Op<"negate_lwe_ciphertext"> { def GlweFromTable : Concrete_Op<"glwe_from_table"> { let summary = "Creates a GLWE ciphertext which is the trivial encrytion of a the input table interpreted as a polynomial (to use later in a bootstrap)"; - let arguments = (ins TensorOf<[AnyInteger]>:$table, I32Attr:$polynomialSize, I32Attr:$glweDimension, I32Attr:$p); + let arguments = (ins 1DTensorOf<[I64]>:$table); let results = (outs GlweCiphertextType:$result); } diff --git a/compiler/include/concretelang/Dialect/Concrete/IR/ConcreteTypes.td b/compiler/include/concretelang/Dialect/Concrete/IR/ConcreteTypes.td index 6b0db0083..b8ef6f303 100644 --- a/compiler/include/concretelang/Dialect/Concrete/IR/ConcreteTypes.td +++ b/compiler/include/concretelang/Dialect/Concrete/IR/ConcreteTypes.td @@ -16,12 +16,48 @@ def GlweCiphertextType : Concrete_Type<"GlweCiphertext"> { GLWE ciphertext. }]; + let parameters = (ins + "signed":$polynomialSize, + "signed":$glweDimension, + // Precision of the lwe ciphertext + "signed":$p + ); + let printer = [{ - $_printer << "glwe_ciphertext"; + $_printer << "glwe_ciphertext<"; + if (getImpl()->polynomialSize == -1) $_printer << "_"; + else $_printer << getImpl()->polynomialSize; + $_printer << ","; + if (getImpl()->glweDimension == -1) $_printer << "_"; + else $_printer << getImpl()->glweDimension; + $_printer << ","; + if (getImpl()->p == -1) $_printer << "_"; + else $_printer << getImpl()->p; + $_printer << ">"; }]; + let parser = [{ - return get($_ctxt); + if ($_parser.parseLess()) + return Type(); + int polynomialSize = -1; + if ($_parser.parseOptionalKeyword("_") && $_parser.parseInteger(polynomialSize)) + return Type(); + if ($_parser.parseComma()) + return Type(); + int glweDimension = -1; + if ($_parser.parseOptionalKeyword("_") && $_parser.parseInteger(glweDimension)) + return Type(); + if ($_parser.parseComma()) + return Type(); + + int p = -1; + if ($_parser.parseOptionalKeyword("_") && $_parser.parseInteger(p)) + return Type(); + if ($_parser.parseGreater()) + return Type(); + Location loc = $_parser.getEncodedSourceLoc($_parser.getNameLoc()); + return getChecked(loc, loc.getContext(), polynomialSize, glweDimension, p); }]; } diff --git a/compiler/include/concretelang/Runtime/context.h b/compiler/include/concretelang/Runtime/context.h index 074a5d4df..d484ccbc0 100644 --- a/compiler/include/concretelang/Runtime/context.h +++ b/compiler/include/concretelang/Runtime/context.h @@ -7,6 +7,7 @@ #define CONCRETELANG_RUNTIME_CONTEXT_H #include +#include #include extern "C" { @@ -18,14 +19,29 @@ namespace concretelang { typedef struct RuntimeContext { LweKeyswitchKey_u64 *ksk; - std::map bsk; + LweBootstrapKey_u64 *bsk; +#ifdef CONCRETELANG_PARALLEL_EXECUTION_ENABLED + std::map engines; + std::mutex engines_map_guard; +#else + Engine *engine; +#endif - static std::string BASE_CONTEXT_BSK; + RuntimeContext() +#ifndef CONCRETELANG_PARALLEL_EXECUTION_ENABLED + : engine(nullptr) +#endif + { + } ~RuntimeContext() { - for (const auto &key : bsk) { - if (key.first != BASE_CONTEXT_BSK) - free_lwe_bootstrap_key_u64(key.second); +#ifdef CONCRETELANG_PARALLEL_EXECUTION_ENABLED + for (const auto &key : engines) { + free_engine(key.second); } +#else + if (engine != nullptr) + free_engine(engine); +#endif } } RuntimeContext; @@ -34,9 +50,11 @@ typedef struct RuntimeContext { extern "C" { LweKeyswitchKey_u64 * -get_keyswitch_key(mlir::concretelang::RuntimeContext *context); +get_keyswitch_key_u64(mlir::concretelang::RuntimeContext *context); LweBootstrapKey_u64 * -get_bootstrap_key(mlir::concretelang::RuntimeContext *context); +get_bootstrap_key_u64(mlir::concretelang::RuntimeContext *context); + +Engine *get_engine(mlir::concretelang::RuntimeContext *context); } #endif diff --git a/compiler/include/concretelang/Runtime/wrappers.h b/compiler/include/concretelang/Runtime/wrappers.h index 28907f13f..f4f85a0ec 100644 --- a/compiler/include/concretelang/Runtime/wrappers.h +++ b/compiler/include/concretelang/Runtime/wrappers.h @@ -6,11 +6,17 @@ #ifndef CONCRETELANG_RUNTIME_WRAPPERS_H #define CONCRETELANG_RUNTIME_WRAPPERS_H +#include "concretelang/Runtime/context.h" + +extern "C" { #include "concrete-ffi.h" -struct ForeignPlaintextList_u64 *memref_runtime_foreign_plaintext_list_u64( - uint64_t *allocated, uint64_t *aligned, uint64_t offset, uint64_t size, - uint64_t stride, uint32_t precision); +void memref_expand_lut_in_trivial_glwe_ct_u64( + uint64_t *glwe_ct_allocated, uint64_t *glwe_ct_aligned, + uint64_t glwe_ct_offset, uint64_t glwe_ct_size, uint64_t glwe_ct_stride, + uint32_t poly_size, uint32_t glwe_dimension, uint32_t out_precision, + uint64_t *lut_allocated, uint64_t *lut_aligned, uint64_t lut_offset, + uint64_t lut_size, uint64_t lut_stride); void memref_add_lwe_ciphertexts_u64( uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset, @@ -37,19 +43,20 @@ void memref_negate_lwe_ciphertext_u64( uint64_t *ct0_aligned, uint64_t ct0_offset, uint64_t ct0_size, uint64_t ct0_stride); -void memref_keyswitch_lwe_u64(struct LweKeyswitchKey_u64 *keyswitch_key, - uint64_t *out_allocated, uint64_t *out_aligned, - uint64_t out_offset, uint64_t out_size, - uint64_t out_stride, uint64_t *ct0_allocated, - uint64_t *ct0_aligned, uint64_t ct0_offset, - uint64_t ct0_size, uint64_t ct0_stride); - -void memref_bootstrap_lwe_u64(struct LweBootstrapKey_u64 *bootstrap_key, - uint64_t *out_allocated, uint64_t *out_aligned, +void memref_keyswitch_lwe_u64(uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset, uint64_t out_size, uint64_t out_stride, uint64_t *ct0_allocated, uint64_t *ct0_aligned, uint64_t ct0_offset, uint64_t ct0_size, uint64_t ct0_stride, - struct GlweCiphertext_u64 *accumulator); + mlir::concretelang::RuntimeContext *context); + +void memref_bootstrap_lwe_u64( + uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset, + uint64_t out_size, uint64_t out_stride, uint64_t *ct0_allocated, + uint64_t *ct0_aligned, uint64_t ct0_offset, uint64_t ct0_size, + uint64_t ct0_stride, uint64_t *glwe_ct_allocated, uint64_t *glwe_ct_aligned, + uint64_t glwe_ct_offset, uint64_t glwe_ct_size, uint64_t glwe_ct_stride, + mlir::concretelang::RuntimeContext *context); +} #endif diff --git a/compiler/lib/ClientLib/KeySet.cpp b/compiler/lib/ClientLib/KeySet.cpp index 23d4c0674..c1e53fb6e 100644 --- a/compiler/lib/ClientLib/KeySet.cpp +++ b/compiler/lib/ClientLib/KeySet.cpp @@ -18,6 +18,8 @@ namespace concretelang { namespace clientlib { +KeySet::KeySet() : engine(new_engine()) {} + KeySet::~KeySet() { for (auto it : secretKeys) { free_lwe_secret_key_u64(it.second.second); @@ -28,7 +30,7 @@ KeySet::~KeySet() { for (auto it : keyswitchKeys) { free_lwe_keyswitch_key_u64(it.second.second); } - free_encryption_generator(encryptionRandomGenerator); + free_engine(engine); } outcome::checked, StringError> @@ -81,9 +83,6 @@ KeySet::setupEncryptionMaterial(ClientParameters ¶ms, uint64_t seed_msb, } } - this->encryptionRandomGenerator = - allocate_encryption_generator(seed_msb, seed_lsb); - return outcome::success(); } @@ -93,29 +92,20 @@ KeySet::generateKeysFromParams(ClientParameters ¶ms, uint64_t seed_msb, { // Generate LWE secret keys - SecretRandomGenerator *generator; - - generator = allocate_secret_generator(seed_msb, seed_lsb); for (auto secretKeyParam : params.secretKeys) { - OUTCOME_TRYV(this->generateSecretKey(secretKeyParam.first, - secretKeyParam.second, generator)); + OUTCOME_TRYV( + this->generateSecretKey(secretKeyParam.first, secretKeyParam.second)); } - free_secret_generator(generator); } - // Allocate the encryption random generator - this->encryptionRandomGenerator = - allocate_encryption_generator(seed_msb, seed_lsb); // Generate bootstrap and keyswitch keys { for (auto bootstrapKeyParam : params.bootstrapKeys) { OUTCOME_TRYV(this->generateBootstrapKey(bootstrapKeyParam.first, - bootstrapKeyParam.second, - this->encryptionRandomGenerator)); + bootstrapKeyParam.second)); } for (auto keyswitchParam : params.keyswitchKeys) { OUTCOME_TRYV(this->generateKeyswitchKey(keyswitchParam.first, - keyswitchParam.second, - this->encryptionRandomGenerator)); + keyswitchParam.second)); } } return outcome::success(); @@ -136,12 +126,9 @@ void KeySet::setKeys( } outcome::checked -KeySet::generateSecretKey(LweSecretKeyID id, LweSecretKeyParam param, - SecretRandomGenerator *generator) { +KeySet::generateSecretKey(LweSecretKeyID id, LweSecretKeyParam param) { LweSecretKey_u64 *sk; - sk = allocate_lwe_secret_key_u64({param.dimension}); - - fill_lwe_secret_key_u64(sk, generator); + sk = generate_lwe_secret_key_u64(engine, param.dimension); secretKeys[id] = {param, sk}; @@ -149,8 +136,7 @@ KeySet::generateSecretKey(LweSecretKeyID id, LweSecretKeyParam param, } outcome::checked -KeySet::generateBootstrapKey(BootstrapKeyID id, BootstrapKeyParam param, - EncryptionRandomGenerator *generator) { +KeySet::generateBootstrapKey(BootstrapKeyID id, BootstrapKeyParam param) { // Finding input and output secretKeys auto inputSk = secretKeys.find(param.inputSecretKeyID); if (inputSk == secretKeys.end()) { @@ -169,32 +155,18 @@ KeySet::generateBootstrapKey(BootstrapKeyID id, BootstrapKeyParam param, uint64_t polynomialSize = total_dimension / param.glweDimension; - bsk = allocate_lwe_bootstrap_key_u64( - {param.level}, {param.baseLog}, {param.glweDimension}, - {inputSk->second.first.dimension}, {polynomialSize}); + bsk = generate_lwe_bootstrap_key_u64( + engine, inputSk->second.second, outputSk->second.second, param.baseLog, + param.level, param.variance, param.glweDimension, polynomialSize); // Store the bootstrap key bootstrapKeys[id] = {param, bsk}; - // Convert the output lwe key to glwe key - GlweSecretKey_u64 *glwe_sk; - - glwe_sk = - allocate_glwe_secret_key_u64({param.glweDimension}, {polynomialSize}); - - fill_glwe_secret_key_with_lwe_secret_key_u64(glwe_sk, - outputSk->second.second); - - // Initialize the bootstrap key - fill_lwe_bootstrap_key_u64(bsk, inputSk->second.second, glwe_sk, generator, - {param.variance}); - free_glwe_secret_key_u64(glwe_sk); return outcome::success(); } outcome::checked -KeySet::generateKeyswitchKey(KeyswitchKeyID id, KeyswitchKeyParam param, - EncryptionRandomGenerator *generator) { +KeySet::generateKeyswitchKey(KeyswitchKeyID id, KeyswitchKeyParam param) { // Finding input and output secretKeys auto inputSk = secretKeys.find(param.inputSecretKeyID); if (inputSk == secretKeys.end()) { @@ -207,17 +179,13 @@ KeySet::generateKeyswitchKey(KeyswitchKeyID id, KeyswitchKeyParam param, // Allocate the keyswitch key LweKeyswitchKey_u64 *ksk; - ksk = allocate_lwe_keyswitch_key_u64({param.level}, {param.baseLog}, - {inputSk->second.first.dimension}, - {outputSk->second.first.dimension}); + ksk = generate_lwe_keyswitch_key_u64(engine, inputSk->second.second, + outputSk->second.second, param.level, + param.baseLog, param.variance); // Store the keyswitch key keyswitchKeys[id] = {param, ksk}; - // Initialize the keyswitch key - fill_lwe_keyswitch_key_u64(ksk, inputSk->second.second, - outputSk->second.second, generator, - {param.variance}); return outcome::success(); } @@ -255,9 +223,8 @@ KeySet::encrypt_lwe(size_t argPos, uint64_t *ciphertext, uint64_t input) { // Encode - TODO we could check if the input value is in the right range uint64_t plaintext = input << (64 - (std::get<0>(inputSk).encryption->encoding.precision + 1)); - encrypt_lwe_u64(std::get<2>(inputSk), ciphertext, plaintext, - encryptionRandomGenerator, - {std::get<0>(inputSk).encryption->variance}); + ::encrypt_lwe_u64(engine, std::get<2>(inputSk), ciphertext, plaintext, + std::get<0>(inputSk).encryption->variance); return outcome::success(); } @@ -271,7 +238,8 @@ KeySet::decrypt_lwe(size_t argPos, uint64_t *ciphertext, uint64_t &output) { if (!std::get<0>(outputSk).encryption.hasValue()) { return StringError("decrypt_lwe: the positional argument is not encrypted"); } - uint64_t plaintext = decrypt_lwe_u64(std::get<2>(outputSk), ciphertext); + uint64_t plaintext = + ::decrypt_lwe_u64(engine, std::get<2>(outputSk), ciphertext); // Decode size_t precision = std::get<0>(outputSk).encryption->encoding.precision; output = plaintext >> (64 - precision - 2); diff --git a/compiler/lib/ClientLib/PublicArguments.cpp b/compiler/lib/ClientLib/PublicArguments.cpp index 1472c8561..fde5fdf89 100644 --- a/compiler/lib/ClientLib/PublicArguments.cpp +++ b/compiler/lib/ClientLib/PublicArguments.cpp @@ -45,10 +45,9 @@ PublicArguments::~PublicArguments() { if (!clearRuntimeContext) { return; } - for (auto bsk_entry : runtimeContext.bsk) { - free_lwe_bootstrap_key_u64(bsk_entry.second); + if (runtimeContext.bsk != nullptr) { + free_lwe_bootstrap_key_u64(runtimeContext.bsk); } - runtimeContext.bsk.clear(); if (runtimeContext.ksk != nullptr) { free_lwe_keyswitch_key_u64(runtimeContext.ksk); runtimeContext.ksk = nullptr; diff --git a/compiler/lib/ClientLib/Serializers.cpp b/compiler/lib/ClientLib/Serializers.cpp index 15c1709ab..fde990f7a 100644 --- a/compiler/lib/ClientLib/Serializers.cpp +++ b/compiler/lib/ClientLib/Serializers.cpp @@ -99,7 +99,7 @@ std::istream &operator>>(std::istream &istream, ClientParameters ¶ms) { std::istream &operator>>(std::istream &istream, RuntimeContext &runtimeContext) { istream >> runtimeContext.ksk; - istream >> runtimeContext.bsk[RuntimeContext::BASE_CONTEXT_BSK]; + istream >> runtimeContext.bsk; assert(istream.good()); return istream; } @@ -107,7 +107,7 @@ std::istream &operator>>(std::istream &istream, std::ostream &operator<<(std::ostream &ostream, const RuntimeContext &runtimeContext) { ostream << runtimeContext.ksk; - ostream << runtimeContext.bsk.at(RuntimeContext::BASE_CONTEXT_BSK); + ostream << runtimeContext.bsk; assert(ostream.good()); return ostream; } diff --git a/compiler/lib/Conversion/BConcreteToBConcreteCAPI/BConcreteToBConcreteCAPI.cpp b/compiler/lib/Conversion/BConcreteToBConcreteCAPI/BConcreteToBConcreteCAPI.cpp index cf8a812a4..34852f948 100644 --- a/compiler/lib/Conversion/BConcreteToBConcreteCAPI/BConcreteToBConcreteCAPI.cpp +++ b/compiler/lib/Conversion/BConcreteToBConcreteCAPI/BConcreteToBConcreteCAPI.cpp @@ -17,6 +17,7 @@ #include "concretelang/Dialect/Concrete/IR/ConcreteDialect.h" #include "concretelang/Dialect/Concrete/IR/ConcreteOps.h" #include "concretelang/Dialect/Concrete/IR/ConcreteTypes.h" +#include "concretelang/Support/Constants.h" namespace { class BConcreteToBConcreteCAPITypeConverter : public mlir::TypeConverter { @@ -72,9 +73,8 @@ inline mlir::Type getGenericLweBufferType(mlir::MLIRContext *context) { return mlir::RankedTensorType::get({-1}, mlir::IntegerType::get(context, 64)); } -inline mlir::concretelang::Concrete::GlweCiphertextType -getGenericGlweCiphertextType(mlir::MLIRContext *context) { - return mlir::concretelang::Concrete::GlweCiphertextType::get(context); +inline mlir::Type getGenericGlweBufferType(mlir::MLIRContext *context) { + return mlir::RankedTensorType::get({-1}, mlir::IntegerType::get(context, 64)); } inline mlir::Type getGenericPlaintextType(mlir::MLIRContext *context) { @@ -114,10 +114,6 @@ mlir::LogicalResult insertForwardDeclarations(mlir::Operation *op, auto lweBufferType = getGenericLweBufferType(rewriter.getContext()); auto plaintextType = getGenericPlaintextType(rewriter.getContext()); auto cleartextType = getGenericCleartextType(rewriter.getContext()); - auto glweCiphertextType = getGenericGlweCiphertextType(rewriter.getContext()); - auto plaintextListType = getGenericPlaintextListType(rewriter.getContext()); - auto foreignPlaintextList = - getGenericForeignPlaintextListType(rewriter.getContext()); auto keySwitchKeyType = getGenericLweKeySwitchKeyType(rewriter.getContext()); auto bootstrapKeyType = getGenericLweBootstrapKeyType(rewriter.getContext()); auto contextType = @@ -134,7 +130,7 @@ mlir::LogicalResult insertForwardDeclarations(mlir::Operation *op, return mlir::failure(); } } - // Insert forward declaration of the add_plaintext_lwe_ciphertext_u64 function + // Insert forward declaration of the add_plaintext_lwe_ciphertext function { auto funcType = mlir::FunctionType::get( rewriter.getContext(), {lweBufferType, lweBufferType, plaintextType}, @@ -145,7 +141,7 @@ mlir::LogicalResult insertForwardDeclarations(mlir::Operation *op, return mlir::failure(); } } - // Insert forward declaration of the mul_cleartext_lwe_ciphertext_u64 function + // Insert forward declaration of the mul_cleartext_lwe_ciphertext function { auto funcType = mlir::FunctionType::get( rewriter.getContext(), {lweBufferType, lweBufferType, cleartextType}, @@ -156,7 +152,7 @@ mlir::LogicalResult insertForwardDeclarations(mlir::Operation *op, return mlir::failure(); } } - // Insert forward declaration of the negate_lwe_ciphertext_u64 function + // Insert forward declaration of the negate_lwe_ciphertext function { auto funcType = mlir::FunctionType::get(rewriter.getContext(), {lweBufferType, lweBufferType}, {}); @@ -169,8 +165,7 @@ mlir::LogicalResult insertForwardDeclarations(mlir::Operation *op, // Insert forward declaration of the memref_keyswitch_lwe_u64 function { auto funcType = mlir::FunctionType::get( - rewriter.getContext(), {keySwitchKeyType, lweBufferType, lweBufferType}, - {}); + rewriter.getContext(), {lweBufferType, lweBufferType, contextType}, {}); if (insertForwardDeclaration(op, rewriter, "memref_keyswitch_lwe_u64", funcType) .failed()) { @@ -181,40 +176,40 @@ mlir::LogicalResult insertForwardDeclarations(mlir::Operation *op, { auto funcType = mlir::FunctionType::get( rewriter.getContext(), - {bootstrapKeyType, lweBufferType, lweBufferType, glweCiphertextType}, - {}); + {lweBufferType, lweBufferType, lweBufferType, contextType}, {}); if (insertForwardDeclaration(op, rewriter, "memref_bootstrap_lwe_u64", funcType) .failed()) { return mlir::failure(); } } - // Insert forward declaration of the fill_plaintext_list function - { - auto funcType = mlir::FunctionType::get( - rewriter.getContext(), {plaintextListType, foreignPlaintextList}, {}); - if (insertForwardDeclaration( - op, rewriter, "fill_plaintext_list_with_expansion_u64", funcType) - .failed()) { - return mlir::failure(); - } - } - // Insert forward declaration of the add_plaintext_list_glwe function + + // Insert forward declaration of the expand_lut_in_trivial_glwe_ct function { auto funcType = mlir::FunctionType::get( rewriter.getContext(), - {glweCiphertextType, glweCiphertextType, plaintextListType}, {}); + { + getGenericGlweBufferType(rewriter.getContext()), + rewriter.getI32Type(), + rewriter.getI32Type(), + rewriter.getI32Type(), + mlir::RankedTensorType::get( + {-1}, mlir::IntegerType::get(rewriter.getContext(), 64)), + }, + {}); if (insertForwardDeclaration( - op, rewriter, "add_plaintext_list_glwe_ciphertext_u64", funcType) + op, rewriter, "memref_expand_lut_in_trivial_glwe_ct_u64", funcType) .failed()) { return mlir::failure(); } } + // Insert forward declaration of the getGlobalKeyswitchKey function { auto funcType = mlir::FunctionType::get(rewriter.getContext(), {contextType}, {keySwitchKeyType}); - if (insertForwardDeclaration(op, rewriter, "get_keyswitch_key", funcType) + if (insertForwardDeclaration(op, rewriter, "get_keyswitch_key_u64", + funcType) .failed()) { return mlir::failure(); } @@ -223,7 +218,8 @@ mlir::LogicalResult insertForwardDeclarations(mlir::Operation *op, { auto funcType = mlir::FunctionType::get(rewriter.getContext(), {contextType}, {bootstrapKeyType}); - if (insertForwardDeclaration(op, rewriter, "get_bootstrap_key", funcType) + if (insertForwardDeclaration(op, rewriter, "get_bootstrap_key_u64", + funcType) .failed()) { return mlir::failure(); } @@ -233,15 +229,15 @@ mlir::LogicalResult insertForwardDeclarations(mlir::Operation *op, // For all operands `tensor` replace with // `%casted = tensor.cast %op : tensor to tensor` -template mlir::SmallVector -getCastedTensorOperands(Op op, mlir::PatternRewriter &rewriter) { +getCastedTensor(mlir::Location loc, mlir::Operation::operand_range operands, + mlir::PatternRewriter &rewriter) { mlir::SmallVector newOperands{}; - for (mlir::Value operand : op->getOperands()) { + for (mlir::Value operand : operands) { mlir::Type operandType = operand.getType(); if (operandType.isa()) { mlir::Value castedOp = rewriter.create( - op.getLoc(), getGenericLweBufferType(rewriter.getContext()), operand); + loc, getGenericLweBufferType(rewriter.getContext()), operand); newOperands.push_back(castedOp); } else { newOperands.push_back(operand); @@ -250,6 +246,14 @@ getCastedTensorOperands(Op op, mlir::PatternRewriter &rewriter) { return std::move(newOperands); } +// For all operands `tensor` replace with +// `%casted = tensor.cast %op : tensor to tensor` +template +mlir::SmallVector +getCastedTensorOperands(Op op, mlir::PatternRewriter &rewriter) { + return getCastedTensor(op->getLoc(), op->getOperands(), rewriter); +} + /// BConcreteOpToConcreteCAPICallPattern match the `BConcreteOp` /// Operation and replace with a call to `funcName`, the funcName should be an /// external function that was linked later. It insert the forward declaration @@ -379,15 +383,12 @@ struct BConcreteKeySwitchLweOpPattern matchAndRewrite(mlir::concretelang::BConcrete::KeySwitchLweBufferOp op, mlir::PatternRewriter &rewriter) const override { - mlir::CallOp kskOp = rewriter.create( - op.getLoc(), "get_keyswitch_key", - getGenericLweKeySwitchKeyType(rewriter.getContext()), - mlir::SmallVector{getContextArgument(op)}); - mlir::SmallVector operands{kskOp.getResult(0)}; - + mlir::SmallVector operands{}; operands.append( getCastedTensorOperands< mlir::concretelang::BConcrete::KeySwitchLweBufferOp>(op, rewriter)); + operands.push_back(getContextArgument(op)); + rewriter.replaceOpWithNewOp(op, "memref_keyswitch_lwe_u64", mlir::TypeRange({}), operands); return mlir::success(); @@ -422,22 +423,83 @@ struct BConcreteBootstrapLweOpPattern mlir::LogicalResult matchAndRewrite(mlir::concretelang::BConcrete::BootstrapLweBufferOp op, mlir::PatternRewriter &rewriter) const override { - - mlir::SmallVector getkskOperands{}; - mlir::CallOp bskOp = rewriter.create( - op.getLoc(), "get_bootstrap_key", - getGenericLweBootstrapKeyType(rewriter.getContext()), - mlir::SmallVector{getContextArgument(op)}); - mlir::SmallVector operands{bskOp.getResult(0)}; + mlir::SmallVector operands{}; operands.append( getCastedTensorOperands< mlir::concretelang::BConcrete::BootstrapLweBufferOp>(op, rewriter)); + operands.push_back(getContextArgument(op)); rewriter.replaceOpWithNewOp(op, "memref_bootstrap_lwe_u64", mlir::TypeRange({}), operands); return mlir::success(); }; }; +// Rewrite pattern that rewrite every +// ``` +// "BConcrete.fill_glwe_table"(%glwe, %lut) {glweDimension=1, +// polynomialSize=2048, outPrecision=3} : +// (tensor<4096xi64>, tensor<32xi64>) -> () +// ``` +// +// to +// +// ``` +// %glweDim = arith.constant 1 : i32 +// %polySize = arith.constant 2048 : i32 +// %outPrecision = arith.constant 3 : i32 +// %glwe_ = tensor.cast %glwe : tensor<4096xi64> to tensor +// %lut_ = tensor.cast %lut : tensor<32xi64> to tensor +// call @expand_lut_in_trivial_glwe_ct(%glwe, %polySize, %glweDim, +// %outPrecision, %lut_) : +// (tensor, i32, i32, tensor) -> () +// ``` +struct BConcreteGlweFromTableOpPattern + : public mlir::OpRewritePattern< + mlir::concretelang::BConcrete::FillGlweFromTable> { + BConcreteGlweFromTableOpPattern( + mlir::MLIRContext *context, + mlir::PatternBenefit benefit = + mlir::concretelang::DEFAULT_PATTERN_BENEFIT) + : mlir::OpRewritePattern< + mlir::concretelang::BConcrete::FillGlweFromTable>(context, + benefit) {} + + mlir::LogicalResult + matchAndRewrite(mlir::concretelang::BConcrete::FillGlweFromTable op, + mlir::PatternRewriter &rewriter) const override { + BConcreteToBConcreteCAPITypeConverter typeConverter; + // %glweDim = arith.constant 1 : i32 + // %polySize = arith.constant 2048 : i32 + // %outPrecision = arith.constant 3 : i32 + + auto castedOp = getCastedTensorOperands(op, rewriter); + + auto polySizeOp = rewriter.create( + op.getLoc(), rewriter.getI32IntegerAttr(op.polynomialSize())); + auto glweDimensionOp = rewriter.create( + op.getLoc(), rewriter.getI32IntegerAttr(op.glweDimension())); + auto outPrecisionOp = rewriter.create( + op.getLoc(), rewriter.getI32IntegerAttr(op.outPrecision())); + + mlir::SmallVector newOperands{ + castedOp[0], polySizeOp, glweDimensionOp, outPrecisionOp, castedOp[1]}; + + // getCastedTensor(op.getLoc(), newOperands, rewriter); + // perform operands conversion + // %glwe_ = tensor.cast %glwe : tensor<4096xi64> to tensor + // %lut_ = tensor.cast %lut : tensor<32xi64> to tensor + + // call @expand_lut_in_trivial_glwe_ct(%glwe, %polySize, %glweDim, + // %lut_) : + // (tensor, i32, i32, tensor) -> () + + rewriter.replaceOpWithNewOp( + op, "memref_expand_lut_in_trivial_glwe_ct_u64", + mlir::SmallVector{}, newOperands); + return mlir::success(); + }; +}; + /// Populate the RewritePatternSet with all patterns that rewrite Concrete /// operators to the corresponding function call to the `Concrete C API`. void populateBConcreteToBConcreteCAPICall(mlir::RewritePatternSet &patterns) { @@ -455,9 +517,9 @@ void populateBConcreteToBConcreteCAPICall(mlir::RewritePatternSet &patterns) { patterns.getContext(), "memref_negate_lwe_ciphertext_u64"); patterns.add(patterns.getContext()); patterns.add(patterns.getContext()); - // patterns.add(patterns.getContext()); patterns.add(patterns.getContext()); patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); } struct AddRuntimeContextToFuncOpPattern diff --git a/compiler/lib/Conversion/CMakeLists.txt b/compiler/lib/Conversion/CMakeLists.txt index 4ad1f761d..7933368b4 100644 --- a/compiler/lib/Conversion/CMakeLists.txt +++ b/compiler/lib/Conversion/CMakeLists.txt @@ -3,7 +3,5 @@ add_subdirectory(TFHEGlobalParametrization) add_subdirectory(TFHEToConcrete) add_subdirectory(FHETensorOpsToLinalg) add_subdirectory(ConcreteToBConcrete) -add_subdirectory(ConcreteToConcreteCAPI) add_subdirectory(BConcreteToBConcreteCAPI) add_subdirectory(MLIRLowerableDialectsToLLVM) -add_subdirectory(ConcreteUnparametrize) diff --git a/compiler/lib/Conversion/ConcreteToBConcrete/ConcreteToBConcrete.cpp b/compiler/lib/Conversion/ConcreteToBConcrete/ConcreteToBConcrete.cpp index 87db50620..071c09bcb 100644 --- a/compiler/lib/Conversion/ConcreteToBConcrete/ConcreteToBConcrete.cpp +++ b/compiler/lib/Conversion/ConcreteToBConcrete/ConcreteToBConcrete.cpp @@ -37,10 +37,19 @@ public: ConcreteToBConcreteTypeConverter() { addConversion([](mlir::Type type) { return type; }); addConversion([&](mlir::concretelang::Concrete::LweCiphertextType type) { + assert(type.getDimension() != -1); return mlir::RankedTensorType::get( {type.getDimension() + 1}, mlir::IntegerType::get(type.getContext(), 64)); }); + addConversion([&](mlir::concretelang::Concrete::GlweCiphertextType type) { + assert(type.getGlweDimension() != -1); + assert(type.getPolynomialSize() != -1); + + return mlir::RankedTensorType::get( + {type.getPolynomialSize() * (type.getGlweDimension() + 1)}, + mlir::IntegerType::get(type.getContext(), 64)); + }); addConversion([&](mlir::RankedTensorType type) { auto lwe = type.getElementType() .dyn_cast_or_null< @@ -48,6 +57,7 @@ public: if (lwe == nullptr) { return (mlir::Type)(type); } + assert(lwe.getDimension() != -1); mlir::SmallVector newShape; newShape.reserve(type.getShape().size() + 1); newShape.append(type.getShape().begin(), type.getShape().end()); @@ -63,6 +73,7 @@ public: if (lwe == nullptr) { return (mlir::Type)(type); } + assert(lwe.getDimension() != -1); mlir::SmallVector newShape; newShape.reserve(type.getShape().size() + 1); newShape.append(type.getShape().begin(), type.getShape().end()); @@ -177,6 +188,65 @@ struct LowToBConcrete : public mlir::OpRewritePattern { }; }; +// This rewrite pattern transforms any instance of +// `Concrete.glwe_from_table` operators. +// +// Example: +// +// ```mlir +// %0 = "Concrete.glwe_from_table"(%tlu) +// : (tensor<$Dxi64>) -> +// !Concrete.glwe_ciphertext<$polynomialSize,$glweDimension,$p> +// ``` +// +// with $D = 2^$p +// +// becomes: +// +// ```mlir +// %0 = linalg.init_tensor [polynomialSize*(glweDimension+1)] +// : tensor +// "BConcrete.fill_glwe_from_table" : (%0, polynomialSize, glweDimension, %tlu) +// : tensor, i64, i64, tensor<$Dxi64> +// ``` +struct GlweFromTablePattern : public mlir::OpRewritePattern< + mlir::concretelang::Concrete::GlweFromTable> { + GlweFromTablePattern(::mlir::MLIRContext *context, + mlir::PatternBenefit benefit = 1) + : ::mlir::OpRewritePattern( + context, benefit) {} + + ::mlir::LogicalResult + matchAndRewrite(mlir::concretelang::Concrete::GlweFromTable op, + ::mlir::PatternRewriter &rewriter) const override { + ConcreteToBConcreteTypeConverter converter; + auto resultTy = + op.result() + .getType() + .cast(); + + auto newResultTy = + converter.convertType(resultTy).cast(); + // %0 = linalg.init_tensor [polynomialSize*(glweDimension+1)] + // : tensor + mlir::Value init = rewriter.replaceOpWithNewOp( + op, newResultTy.getShape(), newResultTy.getElementType()); + + // "BConcrete.fill_glwe_from_table" : (%0, polynomialSize, glweDimension, + // %tlu) + + // polynomialSize*(glweDimension+1) + auto polySize = resultTy.getPolynomialSize(); + auto glweDimension = resultTy.getGlweDimension(); + auto outPrecision = resultTy.getP(); + + rewriter.create( + op.getLoc(), init, polySize, glweDimension, outPrecision, op.table()); + + return ::mlir::success(); + }; +}; + // This rewrite pattern transforms any instance of // `tensor.extract_slice` operators that operates on tensor of lwe ciphertext. // @@ -827,7 +897,6 @@ void ConcreteToBConcretePass::runOnOperation() { // ciphertexts) target.addIllegalDialect(); target.addLegalOp(); - target.addLegalOp(); target.addLegalOp(); // Add patterns to convert the zero ops to tensor.generate @@ -860,7 +929,10 @@ void ConcreteToBConcretePass::runOnOperation() { mlir::concretelang::BConcrete::BootstrapLweBufferOp>>( &getContext()); - // Add patterns to rewrite tensor operators that works on encrypted tensors + patterns.insert(&getContext()); + + // Add patterns to rewrite tensor operators that works on encrypted + // tensors patterns.insert(&getContext()); target.addDynamicallyLegalOp< diff --git a/compiler/lib/Conversion/ConcreteToConcreteCAPI/CMakeLists.txt b/compiler/lib/Conversion/ConcreteToConcreteCAPI/CMakeLists.txt deleted file mode 100644 index 5ab3fa1da..000000000 --- a/compiler/lib/Conversion/ConcreteToConcreteCAPI/CMakeLists.txt +++ /dev/null @@ -1,16 +0,0 @@ -add_mlir_dialect_library(ConcreteToConcreteCAPI - ConcreteToConcreteCAPI.cpp - - ADDITIONAL_HEADER_DIRS - ${PROJECT_SOURCE_DIR}/include/concretelang/Dialect/FHE - - DEPENDS - ConcreteDialect - ConcretelangConversionPassIncGen - - LINK_LIBS PUBLIC - MLIRIR - MLIRTransforms -) - -target_link_libraries(ConcreteToConcreteCAPI PUBLIC MLIRIR) diff --git a/compiler/lib/Conversion/ConcreteToConcreteCAPI/ConcreteToConcreteCAPI.cpp b/compiler/lib/Conversion/ConcreteToConcreteCAPI/ConcreteToConcreteCAPI.cpp deleted file mode 100644 index ba242bcf1..000000000 --- a/compiler/lib/Conversion/ConcreteToConcreteCAPI/ConcreteToConcreteCAPI.cpp +++ /dev/null @@ -1,859 +0,0 @@ -// Part of the Concrete Compiler Project, under the BSD3 License with Zama -// Exceptions. See -// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt -// for license information. - -#include "mlir//IR/BuiltinTypes.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Dialect/StandardOps/IR/Ops.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/IR/SymbolTable.h" -#include "mlir/Transforms/DialectConversion.h" - -#include "concretelang/Conversion/Passes.h" -#include "concretelang/Dialect/Concrete/IR/ConcreteDialect.h" -#include "concretelang/Dialect/Concrete/IR/ConcreteOps.h" -#include "concretelang/Dialect/Concrete/IR/ConcreteTypes.h" -#include "concretelang/Support/Constants.h" - -class ConcreteToConcreteCAPITypeConverter : public mlir::TypeConverter { - -public: - ConcreteToConcreteCAPITypeConverter() { - addConversion([](mlir::Type type) { return type; }); - addConversion([&](mlir::concretelang::Concrete::PlaintextType type) { - return mlir::IntegerType::get(type.getContext(), 64); - }); - addConversion([&](mlir::concretelang::Concrete::CleartextType type) { - return mlir::IntegerType::get(type.getContext(), 64); - }); - } -}; - -mlir::LogicalResult insertForwardDeclaration(mlir::Operation *op, - mlir::RewriterBase &rewriter, - llvm::StringRef funcName, - mlir::FunctionType funcType) { - // Looking for the `funcName` Operation - auto module = mlir::SymbolTable::getNearestSymbolTable(op); - auto opFunc = mlir::dyn_cast_or_null( - mlir::SymbolTable::lookupSymbolIn(module, funcName)); - if (!opFunc) { - // Insert the forward declaration of the funcName - mlir::OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPointToStart(&module->getRegion(0).front()); - - opFunc = rewriter.create(rewriter.getUnknownLoc(), funcName, - funcType); - opFunc.setPrivate(); - } else { - // Check if the `funcName` is well a private function - if (!opFunc.isPrivate()) { - op->emitError() << "the function \"" << funcName - << "\" conflicts with the concrete C API, please rename"; - return mlir::failure(); - } - } - assert(mlir::SymbolTable::lookupSymbolIn(module, funcName) - ->template hasTrait()); - return mlir::success(); -} - -// Set of functions to generate generic types. -// Generic types are used to add forward declarations without a specific type. -// For example, we may need to add LWE ciphertext of different dimensions, or -// allocate them. All the calls to the C API should be done using this generic -// types, and casting should then be performed back to the appropriate type. - -inline mlir::concretelang::Concrete::LweCiphertextType -getGenericLweCiphertextType(mlir::MLIRContext *context) { - return mlir::concretelang::Concrete::LweCiphertextType::get(context, -1, -1); -} - -inline mlir::concretelang::Concrete::GlweCiphertextType -getGenericGlweCiphertextType(mlir::MLIRContext *context) { - return mlir::concretelang::Concrete::GlweCiphertextType::get(context); -} - -inline mlir::concretelang::Concrete::PlaintextType -getGenericPlaintextType(mlir::MLIRContext *context) { - return mlir::concretelang::Concrete::PlaintextType::get(context, -1); -} - -inline mlir::concretelang::Concrete::PlaintextListType -getGenericPlaintextListType(mlir::MLIRContext *context) { - return mlir::concretelang::Concrete::PlaintextListType::get(context); -} - -inline mlir::concretelang::Concrete::ForeignPlaintextListType -getGenericForeignPlaintextListType(mlir::MLIRContext *context) { - return mlir::concretelang::Concrete::ForeignPlaintextListType::get(context); -} - -inline mlir::concretelang::Concrete::CleartextType -getGenericCleartextType(mlir::MLIRContext *context) { - return mlir::concretelang::Concrete::CleartextType::get(context, -1); -} - -inline mlir::concretelang::Concrete::LweBootstrapKeyType -getGenericLweBootstrapKeyType(mlir::MLIRContext *context) { - return mlir::concretelang::Concrete::LweBootstrapKeyType::get(context); -} - -inline mlir::concretelang::Concrete::LweKeySwitchKeyType -getGenericLweKeySwitchKeyType(mlir::MLIRContext *context) { - return mlir::concretelang::Concrete::LweKeySwitchKeyType::get(context); -} - -// Get the generic version of the type. -// Useful when iterating over a set of types. -mlir::Type getGenericType(mlir::Type baseType) { - if (baseType.isa()) { - return getGenericLweCiphertextType(baseType.getContext()); - } - if (baseType.isa()) { - return getGenericPlaintextType(baseType.getContext()); - } - if (baseType.isa()) { - return getGenericCleartextType(baseType.getContext()); - } - return baseType; -} - -// Insert all forward declarations needed for the pass. -// Should generalize input and output types for all decalarations, and the -// pattern using them would be resposible for casting them to the appropriate -// type. -mlir::LogicalResult insertForwardDeclarations(mlir::Operation *op, - mlir::IRRewriter &rewriter) { - auto genericLweCiphertextType = - getGenericLweCiphertextType(rewriter.getContext()); - auto genericGlweCiphertextType = - getGenericGlweCiphertextType(rewriter.getContext()); - auto genericPlaintextType = getGenericPlaintextType(rewriter.getContext()); - auto genericPlaintextListType = - getGenericPlaintextListType(rewriter.getContext()); - auto genericForeignPlaintextList = - getGenericForeignPlaintextListType(rewriter.getContext()); - auto genericCleartextType = getGenericCleartextType(rewriter.getContext()); - auto genericBSKType = getGenericLweBootstrapKeyType(rewriter.getContext()); - auto genericKSKType = getGenericLweKeySwitchKeyType(rewriter.getContext()); - auto contextType = - mlir::concretelang::Concrete::ContextType::get(rewriter.getContext()); - - // Insert forward declaration of allocate lwe ciphertext - { - auto funcType = mlir::FunctionType::get(rewriter.getContext(), - { - rewriter.getIndexType(), - }, - - {genericLweCiphertextType}); - if (insertForwardDeclaration(op, rewriter, "allocate_lwe_ciphertext_u64", - funcType) - .failed()) { - return mlir::failure(); - } - } - // Insert forward declaration of the add_lwe_ciphertexts function - { - auto funcType = mlir::FunctionType::get(rewriter.getContext(), - { - genericLweCiphertextType, - genericLweCiphertextType, - genericLweCiphertextType, - }, - {}); - if (insertForwardDeclaration(op, rewriter, "add_lwe_ciphertexts_u64", - funcType) - .failed()) { - return mlir::failure(); - } - } - // Insert forward declaration of the add_plaintext_lwe_ciphertext_u64 function - { - auto funcType = mlir::FunctionType::get(rewriter.getContext(), - { - genericLweCiphertextType, - genericLweCiphertextType, - genericPlaintextType, - }, - {}); - if (insertForwardDeclaration(op, rewriter, - "add_plaintext_lwe_ciphertext_u64", funcType) - .failed()) { - return mlir::failure(); - } - } - // Insert forward declaration of the mul_cleartext_lwe_ciphertext_u64 function - { - auto funcType = mlir::FunctionType::get(rewriter.getContext(), - { - genericLweCiphertextType, - genericLweCiphertextType, - genericCleartextType, - }, - {}); - if (insertForwardDeclaration(op, rewriter, - "mul_cleartext_lwe_ciphertext_u64", funcType) - .failed()) { - return mlir::failure(); - } - } - // Insert forward declaration of the negate_lwe_ciphertext_u64 function - { - auto funcType = mlir::FunctionType::get( - rewriter.getContext(), - {genericLweCiphertextType, genericLweCiphertextType}, {}); - if (insertForwardDeclaration(op, rewriter, "negate_lwe_ciphertext_u64", - funcType) - .failed()) { - return mlir::failure(); - } - } - // Insert forward declaration of the getBsk function - { - auto funcType = mlir::FunctionType::get(rewriter.getContext(), - {contextType}, {genericBSKType}); - if (insertForwardDeclaration(op, rewriter, "get_bootstrap_key", funcType) - .failed()) { - return mlir::failure(); - } - } - // Insert forward declaration of the bootstrap function - { - auto funcType = mlir::FunctionType::get(rewriter.getContext(), - { - genericBSKType, - genericLweCiphertextType, - genericLweCiphertextType, - genericGlweCiphertextType, - }, - {}); - if (insertForwardDeclaration(op, rewriter, "bootstrap_lwe_u64", funcType) - .failed()) { - return mlir::failure(); - } - } - // Insert forward declaration of the getKsk function - { - auto funcType = mlir::FunctionType::get(rewriter.getContext(), - {contextType}, {genericKSKType}); - if (insertForwardDeclaration(op, rewriter, "get_keyswitch_key", funcType) - .failed()) { - return mlir::failure(); - } - } - // Insert forward declaration of the keyswitch function - { - auto funcType = mlir::FunctionType::get(rewriter.getContext(), - { - // ksk - genericKSKType, - // output ct - genericLweCiphertextType, - // input ct - genericLweCiphertextType, - }, - {}); - if (insertForwardDeclaration(op, rewriter, "keyswitch_lwe_u64", funcType) - .failed()) { - return mlir::failure(); - } - } - // Insert forward declaration of the alloc_glwe function - { - auto funcType = mlir::FunctionType::get(rewriter.getContext(), - { - rewriter.getI32Type(), - rewriter.getI32Type(), - }, - {genericGlweCiphertextType}); - if (insertForwardDeclaration(op, rewriter, "allocate_glwe_ciphertext_u64", - funcType) - .failed()) { - return mlir::failure(); - } - } - // Insert forward declaration of the alloc_plaintext_list function - { - auto funcType = - mlir::FunctionType::get(rewriter.getContext(), {rewriter.getI32Type()}, - {genericPlaintextListType}); - if (insertForwardDeclaration(op, rewriter, "allocate_plaintext_list_u64", - funcType) - .failed()) { - return mlir::failure(); - } - } - // Insert forward declaration of the fill_plaintext_list function - { - auto funcType = mlir::FunctionType::get( - rewriter.getContext(), - {genericPlaintextListType, genericForeignPlaintextList}, {}); - if (insertForwardDeclaration( - op, rewriter, "fill_plaintext_list_with_expansion_u64", funcType) - .failed()) { - return mlir::failure(); - } - } - // Insert forward declaration of the add_plaintext_list_glwe function - { - auto funcType = mlir::FunctionType::get(rewriter.getContext(), - {genericGlweCiphertextType, - genericGlweCiphertextType, - genericPlaintextListType}, - {}); - if (insertForwardDeclaration( - op, rewriter, "add_plaintext_list_glwe_ciphertext_u64", funcType) - .failed()) { - return mlir::failure(); - } - } - return mlir::success(); -} - -/// ConcreteOpToConcreteCAPICallPattern match the `Op` Operation and -/// replace with a call to `funcName`, the funcName should be an external -/// function that was linked later. It insert the forward declaration of the -/// private `funcName` if it not already in the symbol table. -/// The C signature of the function should be `void funcName(int *err, out, -/// arg0, arg1)`, the pattern rewrite: -/// ``` -/// out = op(arg0, arg1) -/// ``` -/// to -/// ``` -/// err = arith.constant 0 : i64 -/// call_op(err, out, arg0, arg1); -/// ``` -template -struct ConcreteOpToConcreteCAPICallPattern : public mlir::OpRewritePattern { - ConcreteOpToConcreteCAPICallPattern( - mlir::MLIRContext *context, mlir::StringRef funcName, - mlir::StringRef allocName, - mlir::PatternBenefit benefit = - mlir::concretelang::DEFAULT_PATTERN_BENEFIT) - : mlir::OpRewritePattern(context, benefit), funcName(funcName), - allocName(allocName) {} - - mlir::LogicalResult - matchAndRewrite(Op op, mlir::PatternRewriter &rewriter) const override { - ConcreteToConcreteCAPITypeConverter typeConverter; - - mlir::Type resultType = op->getResultTypes().front(); - auto lweResultType = - resultType.cast(); - // Replace the operation with a call to the `funcName` - { - // Get the size from the dimension - int64_t lweDimension = lweResultType.getDimension(); - - mlir::Value lweDimensionOp = rewriter.create( - op.getLoc(), rewriter.getIndexAttr(lweDimension)); - // Add the call to the allocation - mlir::SmallVector allocOperands{lweDimensionOp}; - auto allocGeneric = rewriter.create( - op.getLoc(), allocName, - getGenericLweCiphertextType(rewriter.getContext()), allocOperands); - // Construct operands for the operation. - // errOp doesn't need to be casted to something generic, allocGeneric - // already is. All the rest will be converted if needed - mlir::SmallVector newOperands{allocGeneric.getResult(0)}; - for (mlir::Value operand : op->getOperands()) { - mlir::Type operandType = operand.getType(); - mlir::Type castedType = getGenericType(operandType); - if (castedType == operandType) { - // Type didn't change, no need for cast - newOperands.push_back(operand); - } else { - // Type changed, need to cast to the generic one - auto castedOperand = rewriter - .create( - op.getLoc(), castedType, operand) - .getResult(0); - newOperands.push_back(castedOperand); - } - } - // The operations called here are known to be inplace, and no need for a - // return type. - rewriter.create(op.getLoc(), funcName, mlir::TypeRange{}, - newOperands); - // cast result value to the appropriate type - rewriter.replaceOpWithNewOp( - op, op.getType(), allocGeneric.getResult(0)); - } - return mlir::success(); - }; - -private: - std::string funcName; - std::string allocName; -}; - -struct ConcreteZeroOpPattern - : public mlir::OpRewritePattern { - ConcreteZeroOpPattern(mlir::MLIRContext *context, - mlir::PatternBenefit benefit = - mlir::concretelang::DEFAULT_PATTERN_BENEFIT) - : mlir::OpRewritePattern( - context, benefit) {} - - mlir::LogicalResult - matchAndRewrite(mlir::concretelang::Concrete::ZeroLWEOp op, - mlir::PatternRewriter &rewriter) const override { - - mlir::Type resultType = op->getResultTypes().front(); - auto lweResultType = - resultType.cast(); - // Get the size from the dimension - int64_t lweDimension = lweResultType.getDimension(); - - mlir::Value lweDimensionOp = rewriter.create( - op.getLoc(), rewriter.getIndexAttr(lweDimension)); - // Allocate a fresh new ciphertext - mlir::SmallVector allocOperands{lweDimensionOp}; - auto allocGeneric = rewriter.create( - op.getLoc(), "allocate_lwe_ciphertext_u64", - getGenericLweCiphertextType(rewriter.getContext()), allocOperands); - // Cast the result to the appropriate type - rewriter.replaceOpWithNewOp( - op, op.getType(), allocGeneric.getResult(0)); - - return mlir::success(); - }; -}; - -struct ConcreteEncodeIntOpPattern - : public mlir::OpRewritePattern { - ConcreteEncodeIntOpPattern(mlir::MLIRContext *context, - mlir::PatternBenefit benefit = - mlir::concretelang::DEFAULT_PATTERN_BENEFIT) - : mlir::OpRewritePattern( - context, benefit) {} - - mlir::LogicalResult - matchAndRewrite(mlir::concretelang::Concrete::EncodeIntOp op, - mlir::PatternRewriter &rewriter) const override { - { - mlir::Value castedInt = rewriter.create( - op.getLoc(), rewriter.getIntegerType(64), op->getOperands().front()); - mlir::Value constantShiftOp = rewriter.create( - op.getLoc(), rewriter.getI64IntegerAttr(64 - op.getType().getP())); - - mlir::Type resultType = rewriter.getIntegerType(64); - rewriter.replaceOpWithNewOp( - op, resultType, castedInt, constantShiftOp); - } - return mlir::success(); - }; -}; - -struct ConcreteIntToCleartextOpPattern - : public mlir::OpRewritePattern< - mlir::concretelang::Concrete::IntToCleartextOp> { - ConcreteIntToCleartextOpPattern( - mlir::MLIRContext *context, - mlir::PatternBenefit benefit = - mlir::concretelang::DEFAULT_PATTERN_BENEFIT) - : mlir::OpRewritePattern( - context, benefit) {} - - mlir::LogicalResult - matchAndRewrite(mlir::concretelang::Concrete::IntToCleartextOp op, - mlir::PatternRewriter &rewriter) const override { - rewriter.replaceOpWithNewOp( - op, rewriter.getIntegerType(64), op->getOperands().front()); - return mlir::success(); - }; -}; - -// Rewrite the GlweFromTable operation to a series of ops: -// - allocation of two GLWE, one for the addition, and one for storing the -// result -// - allocation of plaintext_list to build the GLWE accumulator -// - build the foreign_plaintext_list using the input table -// - fill the plaintext_list with the foreign_plaintext_list -// - construct the GLWE accumulator by adding the plaintext_list to a freshly -// allocated GLWE -struct GlweFromTableOpPattern - : public mlir::OpRewritePattern< - mlir::concretelang::Concrete::GlweFromTable> { - GlweFromTableOpPattern(mlir::MLIRContext *context, - mlir::PatternBenefit benefit = - mlir::concretelang::DEFAULT_PATTERN_BENEFIT) - : mlir::OpRewritePattern( - context, benefit) {} - - mlir::LogicalResult - matchAndRewrite(mlir::concretelang::Concrete::GlweFromTable op, - mlir::PatternRewriter &rewriter) const override { - ConcreteToConcreteCAPITypeConverter typeConverter; - - // TODO: move this to insertForwardDeclarations - // issue: can't define function with tensor<*xtype> that accept ranked - // tensors - - // Insert forward declaration of the foregin_pt_list function - { - auto funcType = mlir::FunctionType::get( - rewriter.getContext(), - {op->getOperandTypes().front(), rewriter.getI32Type()}, - {getGenericForeignPlaintextListType(rewriter.getContext())}); - if (insertForwardDeclaration(op, rewriter, - "memref_runtime_foreign_plaintext_list_u64", - funcType) - .failed()) { - return mlir::failure(); - } - } - - // allocate two glwe to build accumulator - auto polySizeOp = rewriter.create( - op.getLoc(), op->getAttr("polynomialSize")); - auto glweDimensionOp = rewriter.create( - op.getLoc(), op->getAttr("glweDimension")); - mlir::SmallVector allocGlweOperands{glweDimensionOp, - polySizeOp}; - // first accumulator would replace the op since it's the returned value - auto accumulatorOp = rewriter.replaceOpWithNewOp( - op, "allocate_glwe_ciphertext_u64", - getGenericGlweCiphertextType(rewriter.getContext()), allocGlweOperands); - // second accumulator is just needed to build the actual accumulator - auto _accumulatorOp = rewriter.create( - op.getLoc(), "allocate_glwe_ciphertext_u64", - getGenericGlweCiphertextType(rewriter.getContext()), allocGlweOperands); - // allocate plaintext list - mlir::SmallVector allocPlaintextListOperands{polySizeOp}; - auto plaintextListOp = rewriter.create( - op.getLoc(), "allocate_plaintext_list_u64", - getGenericPlaintextListType(rewriter.getContext()), - allocPlaintextListOperands); - // create foreign plaintext - auto rankedTensorType = - op->getOperandTypes().front().cast(); - assert(rankedTensorType.getRank() == 1 && - "table lookup must be of a single dimension"); - auto precisionOp = - rewriter.create(op.getLoc(), op->getAttr("p")); - mlir::SmallVector ForeignPlaintextListOperands{ - op->getOperand(0), precisionOp}; - auto foreignPlaintextListOp = rewriter.create( - op.getLoc(), "memref_runtime_foreign_plaintext_list_u64", - getGenericForeignPlaintextListType(rewriter.getContext()), - ForeignPlaintextListOperands); - // fill plaintext list - mlir::SmallVector FillPlaintextListOperands{ - plaintextListOp.getResult(0), foreignPlaintextListOp.getResult(0)}; - rewriter.create( - op.getLoc(), "fill_plaintext_list_with_expansion_u64", - mlir::TypeRange({}), FillPlaintextListOperands); - // add plaintext list and glwe to build final accumulator for pbs - mlir::SmallVector AddPlaintextListGlweOperands{ - accumulatorOp.getResult(0), _accumulatorOp.getResult(0), - plaintextListOp.getResult(0)}; - rewriter.create( - op.getLoc(), "add_plaintext_list_glwe_ciphertext_u64", - mlir::TypeRange({}), AddPlaintextListGlweOperands); - return mlir::success(); - }; -}; - -mlir::Value getContextArgument(mlir::Operation *op) { - mlir::Block *block = op->getBlock(); - while (block != nullptr) { - if (llvm::isa(block->getParentOp())) { - - mlir::Value context = block->getArguments().back(); - - assert( - context.getType().isa() && - "the Concrete.context should be the last argument of the enclosing " - "function of the op"); - - return context; - } - block = block->getParentOp()->getBlock(); - } - assert("can't find a function that enclose the op"); - return nullptr; -} - -// Rewrite a BootstrapLweOp with a series of ops: -// - allocate the result LWE ciphertext -// - get the global bootstrapping key -// - use the key and the input accumulator (GLWE) to bootstrap the input -// ciphertext -struct ConcreteBootstrapLweOpPattern - : public mlir::OpRewritePattern< - mlir::concretelang::Concrete::BootstrapLweOp> { - ConcreteBootstrapLweOpPattern(mlir::MLIRContext *context, - mlir::PatternBenefit benefit = - mlir::concretelang::DEFAULT_PATTERN_BENEFIT) - : mlir::OpRewritePattern( - context, benefit) {} - - mlir::LogicalResult - matchAndRewrite(mlir::concretelang::Concrete::BootstrapLweOp op, - mlir::PatternRewriter &rewriter) const override { - auto resultType = op->getResultTypes().front(); - // Get the size from the dimension - int64_t outputLweDimension = - resultType.cast() - .getDimension(); - mlir::Value lweSizeOp = rewriter.create( - op.getLoc(), rewriter.getIndexAttr(outputLweDimension)); - // allocate the result lwe ciphertext, should be of a generic type, to cast - // before return - mlir::SmallVector allocLweCtOperands{lweSizeOp}; - auto allocateGenericLweCtOp = rewriter.create( - op.getLoc(), "allocate_lwe_ciphertext_u64", - getGenericLweCiphertextType(rewriter.getContext()), allocLweCtOperands); - // get bsk - auto getBskOp = rewriter.create( - op.getLoc(), "get_bootstrap_key", - getGenericLweBootstrapKeyType(rewriter.getContext()), - mlir::SmallVector{getContextArgument(op)}); - // bootstrap - // cast input ciphertext to a generic type - mlir::Value lweToBootstrap = - rewriter - .create( - op.getLoc(), getGenericType(op.getOperand(0).getType()), - op.getOperand(0)) - .getResult(0); - // cast input accumulator to a generic type - mlir::Value accumulator = - rewriter - .create( - op.getLoc(), getGenericType(op.getOperand(1).getType()), - op.getOperand(1)) - .getResult(0); - mlir::SmallVector bootstrapOperands{ - getBskOp.getResult(0), allocateGenericLweCtOp.getResult(0), - lweToBootstrap, accumulator}; - rewriter.create(op.getLoc(), "bootstrap_lwe_u64", - mlir::TypeRange({}), bootstrapOperands); - // Cast result to the appropriate type - rewriter.replaceOpWithNewOp( - op, resultType, allocateGenericLweCtOp.getResult(0)); - - return mlir::success(); - }; -}; - -// Rewrite a KeySwitchLweOp with a series of ops: -// - allocate the result LWE ciphertext -// - get the global keyswitch key -// - use the key to keyswitch the input ciphertext -struct ConcreteKeySwitchLweOpPattern - : public mlir::OpRewritePattern< - mlir::concretelang::Concrete::KeySwitchLweOp> { - ConcreteKeySwitchLweOpPattern(mlir::MLIRContext *context, - mlir::PatternBenefit benefit = - mlir::concretelang::DEFAULT_PATTERN_BENEFIT) - : mlir::OpRewritePattern( - context, benefit) {} - - mlir::LogicalResult - matchAndRewrite(mlir::concretelang::Concrete::KeySwitchLweOp op, - mlir::PatternRewriter &rewriter) const override { - // Get the size from the dimension - int64_t lweDimension = - op.getResult() - .getType() - .cast() - .getDimension(); - mlir::Value lweDimensionOp = rewriter.create( - op.getLoc(), rewriter.getIndexAttr(lweDimension)); - // allocate the result lwe ciphertext, should be of a generic type, to cast - // before return - mlir::SmallVector allocLweCtOperands{lweDimensionOp}; - auto allocateGenericLweCtOp = rewriter.create( - op.getLoc(), "allocate_lwe_ciphertext_u64", - getGenericLweCiphertextType(rewriter.getContext()), allocLweCtOperands); - // get ksk - auto getKskOp = rewriter.create( - op.getLoc(), "get_keyswitch_key", - getGenericLweKeySwitchKeyType(rewriter.getContext()), - mlir::SmallVector{getContextArgument(op)}); - // keyswitch - // cast input ciphertext to a generic type - mlir::Value lweToKeyswitch = - rewriter - .create( - op.getLoc(), getGenericType(op.getOperand().getType()), - op.getOperand()) - .getResult(0); - mlir::SmallVector keyswitchOperands{ - getKskOp.getResult(0), allocateGenericLweCtOp.getResult(0), - lweToKeyswitch}; - rewriter.create(op.getLoc(), "keyswitch_lwe_u64", - mlir::TypeRange({}), keyswitchOperands); - // Cast result to the appropriate type - auto lweOutputType = op->getResultTypes().front(); - rewriter.replaceOpWithNewOp( - op, lweOutputType, allocateGenericLweCtOp.getResult(0)); - return mlir::success(); - }; -}; - -/// Populate the RewritePatternSet with all patterns that rewrite Concrete -/// operators to the corresponding function call to the `Concrete C API`. -void populateConcreteToConcreteCAPICall(mlir::RewritePatternSet &patterns) { - patterns.add>( - patterns.getContext(), "add_lwe_ciphertexts_u64", - "allocate_lwe_ciphertext_u64"); - patterns.add>( - patterns.getContext(), "add_plaintext_lwe_ciphertext_u64", - "allocate_lwe_ciphertext_u64"); - patterns.add>( - patterns.getContext(), "mul_cleartext_lwe_ciphertext_u64", - "allocate_lwe_ciphertext_u64"); - patterns.add>( - patterns.getContext(), "negate_lwe_ciphertext_u64", - "allocate_lwe_ciphertext_u64"); - patterns.add(patterns.getContext()); - patterns.add(patterns.getContext()); - patterns.add(patterns.getContext()); - patterns.add(patterns.getContext()); - patterns.add(patterns.getContext()); - patterns.add(patterns.getContext()); -} - -struct AddRuntimeContextToFuncOpPattern - : public mlir::OpRewritePattern { - AddRuntimeContextToFuncOpPattern( - mlir::MLIRContext *context, - mlir::PatternBenefit benefit = - mlir::concretelang::DEFAULT_PATTERN_BENEFIT) - : mlir::OpRewritePattern(context, benefit) {} - - mlir::LogicalResult - matchAndRewrite(mlir::FuncOp oldFuncOp, - mlir::PatternRewriter &rewriter) const override { - mlir::OpBuilder::InsertionGuard guard(rewriter); - mlir::FunctionType oldFuncType = oldFuncOp.getType(); - - // Add a Concrete.context to the function signature - mlir::SmallVector newInputs(oldFuncType.getInputs().begin(), - oldFuncType.getInputs().end()); - newInputs.push_back( - rewriter.getType()); - mlir::FunctionType newFuncTy = rewriter.getType( - newInputs, oldFuncType.getResults()); - // Create the new func - mlir::FuncOp newFuncOp = rewriter.create( - oldFuncOp.getLoc(), oldFuncOp.getName(), newFuncTy); - - // Create the arguments of the new func - mlir::Region &newFuncBody = newFuncOp.body(); - mlir::Block *newFuncEntryBlock = new mlir::Block(); - newFuncEntryBlock->addArguments(newFuncTy.getInputs()); - newFuncBody.push_back(newFuncEntryBlock); - - // Clone the old body to the new one - mlir::BlockAndValueMapping map; - for (auto arg : llvm::enumerate(oldFuncOp.getArguments())) { - map.map(arg.value(), newFuncEntryBlock->getArgument(arg.index())); - } - for (auto &op : oldFuncOp.body().front()) { - newFuncEntryBlock->push_back(op.clone(map)); - } - rewriter.eraseOp(oldFuncOp); - return mlir::success(); - } - - // Legal function are one that are private or has a Concrete.context as last - // arguments. - static bool isLegal(mlir::FuncOp funcOp) { - if (!funcOp.isPublic()) { - return true; - } - // TODO : Don't need to add a runtime context for function that doesn't - // manipulates concrete types. - // - // if (!llvm::any_of(funcOp.getType().getInputs(), [](mlir::Type t) { - // if (auto tensorTy = t.dyn_cast_or_null()) { - // t = tensorTy.getElementType(); - // } - // return llvm::isa( - // t.getDialect()); - // })) { - // return true; - // } - return funcOp.getType().getNumInputs() >= 1 && - funcOp.getType() - .getInputs() - .back() - .isa(); - } -}; - -namespace { -struct ConcreteToConcreteCAPIPass - : public ConcreteToConcreteCAPIBase { - void runOnOperation() final; -}; -} // namespace - -void ConcreteToConcreteCAPIPass::runOnOperation() { - mlir::ModuleOp op = getOperation(); - - // First of all add the Concrete.context to the block arguments of function - // that manipulates ciphertexts. - { - mlir::ConversionTarget target(getContext()); - mlir::RewritePatternSet patterns(&getContext()); - - target.addDynamicallyLegalOp([&](mlir::FuncOp funcOp) { - return AddRuntimeContextToFuncOpPattern::isLegal(funcOp); - }); - - patterns.add(patterns.getContext()); - - // Apply the conversion - if (mlir::applyPartialConversion(op, target, std::move(patterns)) - .failed()) { - this->signalPassFailure(); - return; - } - } - - // Insert forward declaration - mlir::IRRewriter rewriter(&getContext()); - if (insertForwardDeclarations(op, rewriter).failed()) { - this->signalPassFailure(); - } - // Rewrite Concrete ops to CallOp to the Concrete C API - { - mlir::ConversionTarget target(getContext()); - mlir::RewritePatternSet patterns(&getContext()); - - target.addIllegalDialect(); - target.addLegalDialect(); - - populateConcreteToConcreteCAPICall(patterns); - - if (mlir::applyPartialConversion(op, target, std::move(patterns)) - .failed()) { - this->signalPassFailure(); - } - } -} - -namespace mlir { -namespace concretelang { -std::unique_ptr> -createConvertConcreteToConcreteCAPIPass() { - return std::make_unique(); -} -} // namespace concretelang -} // namespace mlir diff --git a/compiler/lib/Conversion/ConcreteUnparametrize/CMakeLists.txt b/compiler/lib/Conversion/ConcreteUnparametrize/CMakeLists.txt deleted file mode 100644 index 8da969120..000000000 --- a/compiler/lib/Conversion/ConcreteUnparametrize/CMakeLists.txt +++ /dev/null @@ -1,16 +0,0 @@ -add_mlir_dialect_library(ConcreteUnparametrize - ConcreteUnparametrize.cpp - - ADDITIONAL_HEADER_DIRS - ${PROJECT_SOURCE_DIR}/include/concretelang/Dialect/FHE - - DEPENDS - ConcreteDialect - ConcretelangConversionPassIncGen - - LINK_LIBS PUBLIC - MLIRIR - MLIRTransforms -) - -target_link_libraries(ConcreteUnparametrize PUBLIC MLIRIR) diff --git a/compiler/lib/Conversion/ConcreteUnparametrize/ConcreteUnparametrize.cpp b/compiler/lib/Conversion/ConcreteUnparametrize/ConcreteUnparametrize.cpp deleted file mode 100644 index cacc79b53..000000000 --- a/compiler/lib/Conversion/ConcreteUnparametrize/ConcreteUnparametrize.cpp +++ /dev/null @@ -1,154 +0,0 @@ -// Part of the Concrete Compiler Project, under the BSD3 License with Zama -// Exceptions. See -// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt -// for license information. - -#include "mlir/Pass/Pass.h" -#include "mlir/Transforms/DialectConversion.h" - -#include "concretelang/Conversion/Passes.h" -#include "concretelang/Conversion/Utils/RegionOpTypeConverterPattern.h" -#include "concretelang/Conversion/Utils/TensorOpTypeConversion.h" -#include "concretelang/Dialect/Concrete/IR/ConcreteDialect.h" -#include "concretelang/Dialect/Concrete/IR/ConcreteOps.h" -#include "concretelang/Dialect/Concrete/IR/ConcreteTypes.h" -#include "concretelang/Dialect/RT/IR/RTOps.h" -#include "concretelang/Support/Constants.h" - -/// ConcreteUnparametrizeTypeConverter is a type converter that unparametrize -/// Concrete types -class ConcreteUnparametrizeTypeConverter : public mlir::TypeConverter { - -public: - static mlir::Type unparematrizeConcreteType(mlir::Type type) { - if (type.isa()) { - return mlir::IntegerType::get(type.getContext(), 64); - } - if (type.isa()) { - return mlir::IntegerType::get(type.getContext(), 64); - } - if (type.isa()) { - return mlir::concretelang::Concrete::LweCiphertextType::get( - type.getContext(), -1, -1); - } - auto tensorType = type.dyn_cast_or_null(); - if (tensorType != nullptr) { - auto eltTy0 = tensorType.getElementType(); - auto eltTy1 = unparematrizeConcreteType(eltTy0); - if (eltTy0 == eltTy1) { - return type; - } - return mlir::RankedTensorType::get(tensorType.getShape(), eltTy1); - } - return type; - } - - ConcreteUnparametrizeTypeConverter() { - addConversion( - [](mlir::Type type) { return unparematrizeConcreteType(type); }); - } -}; - -/// Replace `%1 = unrealized_conversion_cast %0 : t0 to t1` to `%0` where t0 or -/// t1 are a Concrete type. -struct ConcreteUnrealizedCastReplacementPattern - : public mlir::OpRewritePattern { - ConcreteUnrealizedCastReplacementPattern( - mlir::MLIRContext *context, - mlir::PatternBenefit benefit = - mlir::concretelang::DEFAULT_PATTERN_BENEFIT) - : mlir::OpRewritePattern(context, - benefit) {} - - mlir::LogicalResult - matchAndRewrite(mlir::UnrealizedConversionCastOp op, - mlir::PatternRewriter &rewriter) const override { - if (mlir::isa( - op.getOperandTypes()[0].getDialect()) || - mlir::isa( - op.getType(0).getDialect())) { - rewriter.replaceOp(op, op.getOperands()); - return mlir::success(); - } - return mlir::failure(); - }; -}; - -/// ConcreteUnparametrizePass remove all parameters of Concrete types and remove -/// the unrealized_conversion_cast operation that operates on parametrized -/// Concrete types. -struct ConcreteUnparametrizePass - : public ConcreteUnparametrizeBase { - void runOnOperation() final; -}; - -void ConcreteUnparametrizePass::runOnOperation() { - auto op = this->getOperation(); - - mlir::ConversionTarget target(getContext()); - mlir::OwningRewritePatternList patterns(&getContext()); - - ConcreteUnparametrizeTypeConverter converter; - - // Conversion of linalg.generic operation - target - .addDynamicallyLegalOp( - [&](mlir::Operation *op) { - return ( - converter.isLegal(op->getOperandTypes()) && - converter.isLegal(op->getResultTypes()) && - converter.isLegal(op->getRegion(0).front().getArgumentTypes())); - }); - patterns.add>( - &getContext(), converter); - patterns.add>( - &getContext(), converter); - patterns.add>(&getContext(), - converter); - - // Conversion of function signature and arguments - target.addDynamicallyLegalOp([&](mlir::FuncOp funcOp) { - return converter.isSignatureLegal(funcOp.getType()) && - converter.isLegal(&funcOp.getBody()); - }); - mlir::populateFuncOpTypeConversionPattern(patterns, converter); - - // Replacement of unrealized_conversion_cast - mlir::concretelang::addDynamicallyLegalTypeOp< - mlir::UnrealizedConversionCastOp>(target, converter); - patterns.add(patterns.getContext()); - - // Conversion of tensor operators - mlir::concretelang::populateWithTensorTypeConverterPatterns(patterns, target, - converter); - - // Conversion of CallOp - patterns.add>( - patterns.getContext(), converter); - mlir::concretelang::addDynamicallyLegalTypeOp(target, - converter); - - // Conversion of RT Dialect Ops - patterns.add>(patterns.getContext(), - converter); - mlir::concretelang::addDynamicallyLegalTypeOp< - mlir::concretelang::RT::DataflowTaskOp>(target, converter); - - // Apply conversion - if (mlir::applyPartialConversion(op, target, std::move(patterns)).failed()) { - this->signalPassFailure(); - } -} - -namespace mlir { -namespace concretelang { -std::unique_ptr> -createConvertConcreteUnparametrizePass() { - return std::make_unique(); -} -} // namespace concretelang -} // namespace mlir diff --git a/compiler/lib/Runtime/CMakeLists.txt b/compiler/lib/Runtime/CMakeLists.txt index 2a86ab8aa..705336c42 100644 --- a/compiler/lib/Runtime/CMakeLists.txt +++ b/compiler/lib/Runtime/CMakeLists.txt @@ -4,7 +4,7 @@ endif() add_library(ConcretelangRuntime SHARED context.cpp - wrappers.c + wrappers.cpp ) if(CONCRETELANG_PARALLEL_EXECUTION_ENABLED) diff --git a/compiler/lib/Runtime/context.cpp b/compiler/lib/Runtime/context.cpp index bab957771..e51df477b 100644 --- a/compiler/lib/Runtime/context.cpp +++ b/compiler/lib/Runtime/context.cpp @@ -11,39 +11,32 @@ #include #endif -namespace mlir { -namespace concretelang { - -std::string RuntimeContext::BASE_CONTEXT_BSK = "_concretelang_base_context_bsk"; - -} // namespace concretelang -} // namespace mlir - LweKeyswitchKey_u64 * -get_keyswitch_key(mlir::concretelang::RuntimeContext *context) { +get_keyswitch_key_u64(mlir::concretelang::RuntimeContext *context) { return context->ksk; } LweBootstrapKey_u64 * -get_bootstrap_key(mlir::concretelang::RuntimeContext *context) { - using RuntimeContext = mlir::concretelang::RuntimeContext; +get_bootstrap_key_u64(mlir::concretelang::RuntimeContext *context) { + return context->bsk; +} + +// Instantiate one engine per thread on demand +Engine *get_engine(mlir::concretelang::RuntimeContext *context) { #ifdef CONCRETELANG_PARALLEL_EXECUTION_ENABLED std::string threadName = hpx::get_thread_name(); - auto bskIt = context->bsk.find(threadName); - if (bskIt == context->bsk.end()) { - assert((bskIt = context->bsk.find(RuntimeContext::BASE_CONTEXT_BSK)) != - context->bsk.end() && - bskIt->second && "No BASE_CONTEXT_BSK registered in context."); - bskIt = context->bsk - .insert(std::pair( - threadName, - clone_lwe_bootstrap_key_u64( - context->bsk[RuntimeContext::BASE_CONTEXT_BSK]))) - .first; + std::lock_guard guard(context->engines_map_guard); + auto engineIt = context->engines.find(threadName); + if (engineIt == context->engines.end()) { + engineIt = + context->engines + .insert(std::pair(threadName, new_engine())) + .first; } + assert(engineIt->second && "No engine available in context"); + return engineIt->second; #else - auto bskIt = context->bsk.find(RuntimeContext::BASE_CONTEXT_BSK); + return (context->engine == nullptr) ? context->engine = new_engine() + : context->engine; #endif - assert(bskIt->second && "No bootstrap key available in context"); - return bskIt->second; } diff --git a/compiler/lib/Runtime/wrappers.c b/compiler/lib/Runtime/wrappers.cpp similarity index 55% rename from compiler/lib/Runtime/wrappers.c rename to compiler/lib/Runtime/wrappers.cpp index 6e2cb9afb..8944f59ef 100644 --- a/compiler/lib/Runtime/wrappers.c +++ b/compiler/lib/Runtime/wrappers.cpp @@ -1,21 +1,29 @@ +// Part of the Concrete Compiler Project, under the BSD3 License with Zama +// Exceptions. See +// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt +// for license information. + #include "concretelang/Runtime/wrappers.h" #include #include -struct ForeignPlaintextList_u64 *memref_runtime_foreign_plaintext_list_u64( - uint64_t *allocated, uint64_t *aligned, uint64_t offset, uint64_t size, - uint64_t stride, uint32_t precision) { +void memref_expand_lut_in_trivial_glwe_ct_u64( + uint64_t *glwe_ct_allocated, uint64_t *glwe_ct_aligned, + uint64_t glwe_ct_offset, uint64_t glwe_ct_size, uint64_t glwe_ct_stride, + uint32_t poly_size, uint32_t glwe_dimension, uint32_t out_precision, + uint64_t *lut_allocated, uint64_t *lut_aligned, uint64_t lut_offset, + uint64_t lut_size, uint64_t lut_stride) { - assert(stride == 1 && "Runtime: stride not equal to 1, check " - "runtime_foreign_plaintext_list_u64"); + assert(lut_stride == 1 && "Runtime: stride not equal to 1, check " + "memref_expand_lut_in_trivial_glwe_ct_u64"); - // Encode table values in u64 - uint64_t *encoded_table = malloc(size * sizeof(uint64_t)); - for (uint64_t i = 0; i < size; i++) { - encoded_table[i] = (aligned + offset)[i] << (64 - precision - 1); - } - return foreign_plaintext_list_u64(encoded_table, size); - // TODO: is it safe to free after creating plaintext_list? + assert(glwe_ct_stride == 1 && "Runtime: stride not equal to 1, check " + "memref_expand_lut_in_trivial_glwe_ct_u64"); + + expand_lut_in_trivial_glwe_ct_u64(glwe_ct_aligned, poly_size, glwe_dimension, + out_precision, lut_aligned, lut_size); + + return; } void memref_add_lwe_ciphertexts_u64( @@ -26,7 +34,7 @@ void memref_add_lwe_ciphertexts_u64( uint64_t ct1_offset, uint64_t ct1_size, uint64_t ct1_stride) { assert(out_size == ct0_size && out_size == ct1_size && "size of lwe buffer are incompatible"); - LweDimension lwe_dimension = {out_size - 1}; + size_t lwe_dimension = {out_size - 1}; add_two_lwe_ciphertexts_u64(out_aligned + out_offset, ct0_aligned + ct0_offset, ct1_aligned + ct1_offset, lwe_dimension); @@ -38,7 +46,7 @@ void memref_add_plaintext_lwe_ciphertext_u64( uint64_t *ct0_aligned, uint64_t ct0_offset, uint64_t ct0_size, uint64_t ct0_stride, uint64_t plaintext) { assert(out_size == ct0_size && "size of lwe buffer are incompatible"); - LweDimension lwe_dimension = {out_size - 1}; + size_t lwe_dimension = {out_size - 1}; add_plaintext_to_lwe_ciphertext_u64(out_aligned + out_offset, ct0_aligned + ct0_offset, plaintext, lwe_dimension); @@ -50,7 +58,7 @@ void memref_mul_cleartext_lwe_ciphertext_u64( uint64_t *ct0_aligned, uint64_t ct0_offset, uint64_t ct0_size, uint64_t ct0_stride, uint64_t cleartext) { assert(out_size == ct0_size && "size of lwe buffer are incompatible"); - LweDimension lwe_dimension = {out_size - 1}; + size_t lwe_dimension = {out_size - 1}; mul_cleartext_lwe_ciphertext_u64(out_aligned + out_offset, ct0_aligned + ct0_offset, cleartext, lwe_dimension); @@ -62,28 +70,29 @@ void memref_negate_lwe_ciphertext_u64( uint64_t *ct0_aligned, uint64_t ct0_offset, uint64_t ct0_size, uint64_t ct0_stride) { assert(out_size == ct0_size && "size of lwe buffer are incompatible"); - LweDimension lwe_dimension = {out_size - 1}; + size_t lwe_dimension = {out_size - 1}; neg_lwe_ciphertext_u64(out_aligned + out_offset, ct0_aligned + ct0_offset, lwe_dimension); } -void memref_keyswitch_lwe_u64(struct LweKeyswitchKey_u64 *keyswitch_key, - uint64_t *out_allocated, uint64_t *out_aligned, - uint64_t out_offset, uint64_t out_size, - uint64_t out_stride, uint64_t *ct0_allocated, - uint64_t *ct0_aligned, uint64_t ct0_offset, - uint64_t ct0_size, uint64_t ct0_stride) { - bufferized_keyswitch_lwe_u64(keyswitch_key, out_aligned + out_offset, - ct0_aligned + ct0_offset); -} - -void memref_bootstrap_lwe_u64(struct LweBootstrapKey_u64 *bootstrap_key, - uint64_t *out_allocated, uint64_t *out_aligned, +void memref_keyswitch_lwe_u64(uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset, uint64_t out_size, uint64_t out_stride, uint64_t *ct0_allocated, uint64_t *ct0_aligned, uint64_t ct0_offset, uint64_t ct0_size, uint64_t ct0_stride, - struct GlweCiphertext_u64 *accumulator) { - bufferized_bootstrap_lwe_u64(bootstrap_key, out_aligned + out_offset, - ct0_aligned + ct0_offset, accumulator); + mlir::concretelang::RuntimeContext *context) { + keyswitch_lwe_u64(get_engine(context), get_keyswitch_key_u64(context), + out_aligned + out_offset, ct0_aligned + ct0_offset); +} + +void memref_bootstrap_lwe_u64( + uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset, + uint64_t out_size, uint64_t out_stride, uint64_t *ct0_allocated, + uint64_t *ct0_aligned, uint64_t ct0_offset, uint64_t ct0_size, + uint64_t ct0_stride, uint64_t *glwe_ct_allocated, uint64_t *glwe_ct_aligned, + uint64_t glwe_ct_offset, uint64_t glwe_ct_size, uint64_t glwe_ct_stride, + mlir::concretelang::RuntimeContext *context) { + bootstrap_lwe_u64(get_engine(context), get_bootstrap_key_u64(context), + out_aligned + out_offset, ct0_aligned + ct0_offset, + glwe_ct_aligned + glwe_ct_offset); } diff --git a/compiler/lib/Support/CMakeLists.txt b/compiler/lib/Support/CMakeLists.txt index 02d7ebe1f..dc1d45444 100644 --- a/compiler/lib/Support/CMakeLists.txt +++ b/compiler/lib/Support/CMakeLists.txt @@ -22,7 +22,6 @@ add_mlir_library(ConcretelangSupport FHELinalgDialectTransforms FHETensorOpsToLinalg FHEToTFHE - ConcreteUnparametrize MLIRLowerableDialectsToLLVM FHEDialectAnalysis RTDialectAnalysis diff --git a/compiler/lib/Support/Pipeline.cpp b/compiler/lib/Support/Pipeline.cpp index 8961ecf24..dea3b3ea5 100644 --- a/compiler/lib/Support/Pipeline.cpp +++ b/compiler/lib/Support/Pipeline.cpp @@ -205,9 +205,6 @@ lowerBConcreteToStd(mlir::MLIRContext &context, mlir::ModuleOp &module, addPotentiallyNestedPass( pm, mlir::concretelang::createConvertBConcreteToBConcreteCAPIPass(), enablePass); - addPotentiallyNestedPass( - pm, mlir::concretelang::createConvertConcreteToConcreteCAPIPass(), - enablePass); return pm.run(module.getOperation()); } @@ -218,11 +215,6 @@ lowerStdToLLVMDialect(mlir::MLIRContext &context, mlir::ModuleOp &module, mlir::PassManager pm(&context); pipelinePrinting("StdToLLVM", pm, context); - // Unparametrize Concrete - addPotentiallyNestedPass( - pm, mlir::concretelang::createConvertConcreteUnparametrizePass(), - enablePass); - // Bufferize addPotentiallyNestedPass(pm, mlir::createTensorConstantBufferizePass(), enablePass); diff --git a/compiler/tests/Conversion/BConcreteToBConcreteCAPI/bootstrap_lwe.mlir b/compiler/tests/Conversion/BConcreteToBConcreteCAPI/bootstrap_lwe.mlir index a86287f15..6b3755165 100644 --- a/compiler/tests/Conversion/BConcreteToBConcreteCAPI/bootstrap_lwe.mlir +++ b/compiler/tests/Conversion/BConcreteToBConcreteCAPI/bootstrap_lwe.mlir @@ -1,15 +1,15 @@ // RUN: concretecompiler --passes bconcrete-to-bconcrete-c-api --action=dump-std %s 2>&1| FileCheck %s -// CHECK: func @bootstrap_lwe(%arg0: tensor<1025xi64>, %arg1: !Concrete.glwe_ciphertext, %arg2: !Concrete.context) -> tensor<1025xi64> { -// CHECK-NEXT: %0 = linalg.init_tensor [1025] : tensor<1025xi64> -// CHECK-NEXT: %1 = call @get_bootstrap_key(%arg2) : (!Concrete.context) -> !Concrete.lwe_bootstrap_key -// CHECK-NEXT: %2 = tensor.cast %0 : tensor<1025xi64> to tensor -// CHECK-NEXT: %3 = tensor.cast %arg0 : tensor<1025xi64> to tensor -// CHECK-NEXT: call @memref_bootstrap_lwe_u64(%1, %2, %3, %arg1) : (!Concrete.lwe_bootstrap_key, tensor, tensor, !Concrete.glwe_ciphertext) -> () -// CHECK-NEXT: return %0 : tensor<1025xi64> +// CHECK: func @apply_lookup_table(%arg0: tensor<601xi64>, %arg1: tensor<2048xi64>, %arg2: !Concrete.context) -> tensor<1025xi64> { +// CHECK-NEXT: %0 = linalg.init_tensor [1025] : tensor<1025xi64> +// CHECK-NEXT: %1 = tensor.cast %0 : tensor<1025xi64> to tensor +// CHECK-NEXT: %2 = tensor.cast %arg0 : tensor<601xi64> to tensor +// CHECK-NEXT: %3 = tensor.cast %arg1 : tensor<2048xi64> to tensor +// CHECK-NEXT: call @memref_bootstrap_lwe_u64(%1, %2, %3, %arg2) : (tensor, tensor, tensor, !Concrete.context) -> () +// CHECK-NEXT: return %0 : tensor<1025xi64> // CHECK-NEXT: } -func @bootstrap_lwe(%arg0: tensor<1025xi64>, %arg1: !Concrete.glwe_ciphertext) -> tensor<1025xi64> { - %0 = linalg.init_tensor [1025] : tensor<1025xi64> - "BConcrete.bootstrap_lwe_buffer"(%0, %arg0, %arg1) {baseLog = 2 : i32, glweDimension = 1 : i32, level = 3 : i32, polynomialSize = 1024 : i32} : (tensor<1025xi64>, tensor<1025xi64>, !Concrete.glwe_ciphertext) -> () - return %0 : tensor<1025xi64> -} \ No newline at end of file +func @apply_lookup_table(%arg0: tensor<601xi64>, %arg1: tensor<2048xi64>) -> tensor<1025xi64> { + %0 = linalg.init_tensor [1025] : tensor<1025xi64> + "BConcrete.bootstrap_lwe_buffer"(%0, %arg0, %arg1) {baseLog = 4 : i32, glweDimension = 1 : i32, level = 5 : i32, polynomialSize = 1024 : i32} : (tensor<1025xi64>, tensor<601xi64>, tensor<2048xi64>) -> () + return %0 : tensor<1025xi64> + } \ No newline at end of file diff --git a/compiler/tests/Conversion/BConcreteToBConcreteCAPI/keyswitch_lwe.mlir b/compiler/tests/Conversion/BConcreteToBConcreteCAPI/keyswitch_lwe.mlir index 054ee8f03..7ccc2905e 100644 --- a/compiler/tests/Conversion/BConcreteToBConcreteCAPI/keyswitch_lwe.mlir +++ b/compiler/tests/Conversion/BConcreteToBConcreteCAPI/keyswitch_lwe.mlir @@ -2,10 +2,9 @@ //CHECK: func @keyswitch_lwe(%arg0: tensor<1025xi64>, %arg1: !Concrete.context) -> tensor<1025xi64> { //CHECK-NEXT: %0 = linalg.init_tensor [1025] : tensor<1025xi64> -//CHECK-NEXT: %1 = call @get_keyswitch_key(%arg1) : (!Concrete.context) -> !Concrete.lwe_key_switch_key -//CHECK-NEXT: %2 = tensor.cast %0 : tensor<1025xi64> to tensor -//CHECK-NEXT: %3 = tensor.cast %arg0 : tensor<1025xi64> to tensor -//CHECK-NEXT: call @memref_keyswitch_lwe_u64(%1, %2, %3) : (!Concrete.lwe_key_switch_key, tensor, tensor) -> () +//CHECK-NEXT: %1 = tensor.cast %0 : tensor<1025xi64> to tensor +//CHECK-NEXT: %2 = tensor.cast %arg0 : tensor<1025xi64> to tensor +//CHECK-NEXT: call @memref_keyswitch_lwe_u64(%1, %2, %arg1) : (tensor, tensor, !Concrete.context) -> () //CHECK-NEXT: return %0 : tensor<1025xi64> //CHECK-NEXT: } func @keyswitch_lwe(%arg0: tensor<1025xi64>) -> tensor<1025xi64> { diff --git a/compiler/tests/Conversion/ConcreteToBConcrete/apply_lookup_table.mlir b/compiler/tests/Conversion/ConcreteToBConcrete/apply_lookup_table.mlir index 457dbf8cc..921092791 100644 --- a/compiler/tests/Conversion/ConcreteToBConcrete/apply_lookup_table.mlir +++ b/compiler/tests/Conversion/ConcreteToBConcrete/apply_lookup_table.mlir @@ -2,14 +2,15 @@ // CHECK-LABEL: func @apply_lookup_table(%arg0: tensor<1025xi64>, %arg1: tensor<16xi64>) -> tensor<1025xi64> func @apply_lookup_table(%arg0: !Concrete.lwe_ciphertext<1024,4>, %arg1: tensor<16xi64>) -> !Concrete.lwe_ciphertext<1024,4> { - // CHECK-NEXT: %[[V1:.*]] = "Concrete.glwe_from_table"(%arg1) {glweDimension = 1 : i32, p = 4 : i32, polynomialSize = 1024 : i32} : (tensor<16xi64>) -> !Concrete.glwe_ciphertext + // CHECK-NEXT: %[[V1:.*]] = linalg.init_tensor [2048] : tensor<2048xi64> + // CHECK-NEXT:"BConcrete.fill_glwe_from_table"(%[[V1]], %arg1) {glweDimension = 1 : i32, outPrecision = 4 : i32, polynomialSize = 1024 : i32} : (tensor<2048xi64>, tensor<16xi64>) -> () // CHECK-NEXT: %[[V2:.*]] = linalg.init_tensor [601] : tensor<601xi64> // CHECK-NEXT: "BConcrete.keyswitch_lwe_buffer"(%[[V2]], %arg0) {baseLog = 2 : i32, inputLweDimension = 1 : i32, level = 3 : i32, outputLweDimension = 600 : i32} : (tensor<601xi64>, tensor<1025xi64>) -> () // CHECK-NEXT: %[[V3:.*]] = linalg.init_tensor [1025] : tensor<1025xi64> - // CHECK-NEXT: "BConcrete.bootstrap_lwe_buffer"(%[[V3]], %[[V2]], %[[V1]]) {baseLog = 4 : i32, glweDimension = 1 : i32, level = 5 : i32, polynomialSize = 1024 : i32} : (tensor<1025xi64>, tensor<601xi64>, !Concrete.glwe_ciphertext) -> () + // CHECK-NEXT: "BConcrete.bootstrap_lwe_buffer"(%[[V3]], %[[V2]], %[[V1]]) {baseLog = 4 : i32, glweDimension = 1 : i32, level = 5 : i32, polynomialSize = 1024 : i32} : (tensor<1025xi64>, tensor<601xi64>, tensor<2048xi64>) -> () // CHECK-NEXT: return %[[V3]] : tensor<1025xi64> - %0 = "Concrete.glwe_from_table"(%arg1) {glweDimension = 1 : i32, p = 4 : i32, polynomialSize = 1024 : i32} : (tensor<16xi64>) -> !Concrete.glwe_ciphertext + %0 = "Concrete.glwe_from_table"(%arg1) {glweDimension = 1 : i32, p = 4 : i32, polynomialSize = 1024 : i32} : (tensor<16xi64>) -> !Concrete.glwe_ciphertext<1024,1,4> %1 = "Concrete.keyswitch_lwe"(%arg0) {baseLog = 2 : i32, inputLweDimension = 1 : i32, level = 3 : i32, outputLweDimension = 600 : i32} : (!Concrete.lwe_ciphertext<1024,4>) -> !Concrete.lwe_ciphertext<600,4> - %2 = "Concrete.bootstrap_lwe"(%1, %0) {baseLog = 4 : i32, glweDimension = 1 : i32, level = 5 : i32, polynomialSize = 1024 : i32} : (!Concrete.lwe_ciphertext<600,4>, !Concrete.glwe_ciphertext) -> !Concrete.lwe_ciphertext<1024,4> + %2 = "Concrete.bootstrap_lwe"(%1, %0) {baseLog = 4 : i32, glweDimension = 1 : i32, level = 5 : i32, polynomialSize = 1024 : i32} : (!Concrete.lwe_ciphertext<600,4>, !Concrete.glwe_ciphertext<1024,1,4>) -> !Concrete.lwe_ciphertext<1024,4> return %2 : !Concrete.lwe_ciphertext<1024,4> } diff --git a/compiler/tests/Conversion/ConcreteToBConcrete/apply_lookup_table_cst.mlir b/compiler/tests/Conversion/ConcreteToBConcrete/apply_lookup_table_cst.mlir index 92f2a49d6..0abdc4d41 100644 --- a/compiler/tests/Conversion/ConcreteToBConcrete/apply_lookup_table_cst.mlir +++ b/compiler/tests/Conversion/ConcreteToBConcrete/apply_lookup_table_cst.mlir @@ -3,15 +3,16 @@ // CHECK-LABEL: func @apply_lookup_table_cst(%arg0: tensor<2049xi64>) -> tensor<2049xi64> func @apply_lookup_table_cst(%arg0: !Concrete.lwe_ciphertext<2048,4>) -> !Concrete.lwe_ciphertext<2048,4> { // CHECK-NEXT: %[[TABLE:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : tensor<16xi64> - // CHECK-NEXT: %[[V1:.*]] = "Concrete.glwe_from_table"(%[[TABLE:.*]]) {glweDimension = 1 : i32, p = 4 : i32, polynomialSize = 2048 : i32} : (tensor<16xi64>) -> !Concrete.glwe_ciphertext - // CHECK-NEXT: %[[V2:.*]] = linalg.init_tensor [601] : tensor<601xi64> - // CHECK-NEXT: "BConcrete.keyswitch_lwe_buffer"([[V2:.*]], %arg0) {baseLog = 2 : i32, inputLweDimension = 1 : i32, level = 3 : i32, outputLweDimension = 600 : i32} : (tensor<601xi64>, tensor<2049xi64>) -> () + // CHECK-NEXT: %[[V1:.*]] = linalg.init_tensor [4096] : tensor<4096xi64> + // CHECK-NEXT: "BConcrete.fill_glwe_from_table"(%[[V1]], %cst) {glweDimension = 1 : i32, outPrecision = 4 : i32, polynomialSize = 2048 : i32} : (tensor<4096xi64>, tensor<16xi64>) -> () + // CHECK-NEXT: %[[V2:.*]] = linalg.init_tensor [601] : tensor<601xi64> + // CHECK-NEXT: "BConcrete.keyswitch_lwe_buffer"(%[[V2]], %arg0) {baseLog = 2 : i32, inputLweDimension = 1 : i32, level = 3 : i32, outputLweDimension = 600 : i32} : (tensor<601xi64>, tensor<2049xi64>) -> () // CHECK-NEXT: %[[V3:.*]] = linalg.init_tensor [2049] : tensor<2049xi64> - // CHECK-NEXT: "BConcrete.bootstrap_lwe_buffer"(%[[V3:.*]], %[[V2:.*]], %[[V1:.*]]) {baseLog = 4 : i32, glweDimension = 1 : i32, level = 5 : i32, polynomialSize = 2048 : i32} : (tensor<2049xi64>, tensor<601xi64>, !Concrete.glwe_ciphertext) -> () + // CHECK-NEXT: "BConcrete.bootstrap_lwe_buffer"(%[[V3]], %[[V2]], %[[V1]]) {baseLog = 4 : i32, glweDimension = 1 : i32, level = 5 : i32, polynomialSize = 2048 : i32} : (tensor<2049xi64>, tensor<601xi64>, tensor<4096xi64>) -> () // CHECK-NEXT: return %[[V3]] : tensor<2049xi64> %tlu = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : tensor<16xi64> - %0 = "Concrete.glwe_from_table"(%tlu) {glweDimension = 1 : i32, p = 4 : i32, polynomialSize = 2048 : i32} : (tensor<16xi64>) -> !Concrete.glwe_ciphertext + %0 = "Concrete.glwe_from_table"(%tlu) {glweDimension = 1 : i32, p = 4 : i32, polynomialSize = 2048 : i32} : (tensor<16xi64>) -> !Concrete.glwe_ciphertext<2048,1,4> %1 = "Concrete.keyswitch_lwe"(%arg0) {baseLog = 2 : i32, inputLweDimension = 1 : i32, level = 3 : i32, outputLweDimension = 600 : i32} : (!Concrete.lwe_ciphertext<2048,4>) -> !Concrete.lwe_ciphertext<600,4> - %2 = "Concrete.bootstrap_lwe"(%1, %0) {baseLog = 4 : i32, glweDimension = 1 : i32, level = 5 : i32, polynomialSize = 2048 : i32} : (!Concrete.lwe_ciphertext<600,4>, !Concrete.glwe_ciphertext) -> !Concrete.lwe_ciphertext<2048,4> + %2 = "Concrete.bootstrap_lwe"(%1, %0) {baseLog = 4 : i32, glweDimension = 1 : i32, level = 5 : i32, polynomialSize = 2048 : i32} : (!Concrete.lwe_ciphertext<600,4>, !Concrete.glwe_ciphertext<2048,1,4>) -> !Concrete.lwe_ciphertext<2048,4> return %2 : !Concrete.lwe_ciphertext<2048,4> } diff --git a/compiler/tests/Conversion/ConcreteUnparametrize/ConcreteUnparametrize/func.mlir b/compiler/tests/Conversion/ConcreteUnparametrize/ConcreteUnparametrize/func.mlir deleted file mode 100644 index 40c67ef41..000000000 --- a/compiler/tests/Conversion/ConcreteUnparametrize/ConcreteUnparametrize/func.mlir +++ /dev/null @@ -1,7 +0,0 @@ -// RUN: concretecompiler --passes concrete-unparametrize --action=dump-llvm-dialect %s 2>&1| FileCheck %s - -// CHECK-LABEL: func @main(%arg0: !Concrete.lwe_ciphertext<_,_>) -> !Concrete.lwe_ciphertext<_,_> -func @main(%arg0: !Concrete.lwe_ciphertext<1024,4>) -> !Concrete.lwe_ciphertext<1024,4> { - // CHECK-NEXT: return %arg0 : !Concrete.lwe_ciphertext<_,_> - return %arg0: !Concrete.lwe_ciphertext<1024,4> -} \ No newline at end of file diff --git a/compiler/tests/Conversion/ConcreteUnparametrize/ConcreteUnparametrize/unrealized_conversion_cast.mlir b/compiler/tests/Conversion/ConcreteUnparametrize/ConcreteUnparametrize/unrealized_conversion_cast.mlir deleted file mode 100644 index a38b249ec..000000000 --- a/compiler/tests/Conversion/ConcreteUnparametrize/ConcreteUnparametrize/unrealized_conversion_cast.mlir +++ /dev/null @@ -1,8 +0,0 @@ -// RUN: concretecompiler --passes concrete-unparametrize --action=dump-llvm-dialect %s 2>&1| FileCheck %s - -// CHECK-LABEL: func @main(%arg0: !Concrete.lwe_ciphertext<_,_>) -> !Concrete.lwe_ciphertext<_,_> -func @main(%arg0: !Concrete.lwe_ciphertext<1024,4>) -> !Concrete.lwe_ciphertext<_,_> { - // CHECK-NEXT: return %arg0 : !Concrete.lwe_ciphertext<_,_> - %0 = builtin.unrealized_conversion_cast %arg0 : !Concrete.lwe_ciphertext<1024,4> to !Concrete.lwe_ciphertext<_,_> - return %0: !Concrete.lwe_ciphertext<_,_> -} \ No newline at end of file diff --git a/compiler/tests/Conversion/TFHEToConcrete/TFHEToConcrete/apply_lookup_table.mlir b/compiler/tests/Conversion/TFHEToConcrete/TFHEToConcrete/apply_lookup_table.mlir index 1e4d597c6..86fbe18f4 100644 --- a/compiler/tests/Conversion/TFHEToConcrete/TFHEToConcrete/apply_lookup_table.mlir +++ b/compiler/tests/Conversion/TFHEToConcrete/TFHEToConcrete/apply_lookup_table.mlir @@ -2,9 +2,9 @@ // CHECK-LABEL: func @apply_lookup_table(%arg0: !Concrete.lwe_ciphertext<1024,4>, %arg1: tensor<16xi64>) -> !Concrete.lwe_ciphertext<1024,4> func @apply_lookup_table(%arg0: !TFHE.glwe<{1024,1,64}{4}>, %arg1: tensor<16xi64>) -> !TFHE.glwe<{1024,1,64}{4}> { - // CHECK-NEXT: %[[V1:.*]] = "Concrete.glwe_from_table"(%arg1) {glweDimension = 1 : i32, p = 4 : i32, polynomialSize = 1024 : i32} : (tensor<16xi64>) -> !Concrete.glwe_ciphertext + // CHECK-NEXT: %[[V1:.*]] = "Concrete.glwe_from_table"(%arg1) : (tensor<16xi64>) -> !Concrete.glwe_ciphertext<1024,1,4> // CHECK-NEXT: %[[V2:.*]] = "Concrete.keyswitch_lwe"(%arg0) {baseLog = 2 : i32, level = 3 : i32} : (!Concrete.lwe_ciphertext<1024,4>) -> !Concrete.lwe_ciphertext<600,4> - // CHECK-NEXT: %[[V3:.*]] = "Concrete.bootstrap_lwe"(%[[V2]], %[[V1]]) {baseLog = 4 : i32, glweDimension = 1 : i32, level = 5 : i32, polynomialSize = 1024 : i32} : (!Concrete.lwe_ciphertext<600,4>, !Concrete.glwe_ciphertext) -> !Concrete.lwe_ciphertext<1024,4> + // CHECK-NEXT: %[[V3:.*]] = "Concrete.bootstrap_lwe"(%[[V2]], %[[V1]]) {baseLog = 4 : i32, glweDimension = 1 : i32, level = 5 : i32, polynomialSize = 1024 : i32} : (!Concrete.lwe_ciphertext<600,4>, !Concrete.glwe_ciphertext<1024,1,4>) -> !Concrete.lwe_ciphertext<1024,4> // CHECK-NEXT: return %[[V3]] : !Concrete.lwe_ciphertext<1024,4> %1 = "TFHE.apply_lookup_table"(%arg0, %arg1){glweDimension=1:i32, polynomialSize=1024:i32, levelKS=3:i32, baseLogKS=2:i32, levelBS=5:i32, baseLogBS=4:i32, outputSizeKS=600:i32}: (!TFHE.glwe<{1024,1,64}{4}>, tensor<16xi64>) -> (!TFHE.glwe<{1024,1,64}{4}>) return %1: !TFHE.glwe<{1024,1,64}{4}> diff --git a/compiler/tests/Conversion/TFHEToConcrete/TFHEToConcrete/apply_lookup_table_cst.mlir b/compiler/tests/Conversion/TFHEToConcrete/TFHEToConcrete/apply_lookup_table_cst.mlir index de447fc64..e61cf92d8 100644 --- a/compiler/tests/Conversion/TFHEToConcrete/TFHEToConcrete/apply_lookup_table_cst.mlir +++ b/compiler/tests/Conversion/TFHEToConcrete/TFHEToConcrete/apply_lookup_table_cst.mlir @@ -3,9 +3,9 @@ // CHECK-LABEL: func @apply_lookup_table_cst(%arg0: !Concrete.lwe_ciphertext<2048,4>) -> !Concrete.lwe_ciphertext<2048,4> func @apply_lookup_table_cst(%arg0: !TFHE.glwe<{2048,1,64}{4}>) -> !TFHE.glwe<{2048,1,64}{4}> { // CHECK-NEXT: %[[TABLE:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : tensor<16xi64> - // CHECK-NEXT: %[[V1:.*]] = "Concrete.glwe_from_table"(%[[TABLE]]) {glweDimension = 1 : i32, p = 4 : i32, polynomialSize = 2048 : i32} : (tensor<16xi64>) -> !Concrete.glwe_ciphertext + // CHECK-NEXT: %[[V1:.*]] = "Concrete.glwe_from_table"(%[[TABLE]]) : (tensor<16xi64>) -> !Concrete.glwe_ciphertext<2048,1,4> // CHECK-NEXT: %[[V2:.*]] = "Concrete.keyswitch_lwe"(%arg0) {baseLog = 2 : i32, level = 3 : i32} : (!Concrete.lwe_ciphertext<2048,4>) -> !Concrete.lwe_ciphertext<600,4> - // CHECK-NEXT: %[[V3:.*]] = "Concrete.bootstrap_lwe"(%[[V2]], %[[V1]]) {baseLog = 4 : i32, glweDimension = 1 : i32, level = 5 : i32, polynomialSize = 2048 : i32} : (!Concrete.lwe_ciphertext<600,4>, !Concrete.glwe_ciphertext) -> !Concrete.lwe_ciphertext<2048,4> + // CHECK-NEXT: %[[V3:.*]] = "Concrete.bootstrap_lwe"(%[[V2]], %[[V1]]) {baseLog = 4 : i32, glweDimension = 1 : i32, level = 5 : i32, polynomialSize = 2048 : i32} : (!Concrete.lwe_ciphertext<600,4>, !Concrete.glwe_ciphertext<2048,1,4>) -> !Concrete.lwe_ciphertext<2048,4> // CHECK-NEXT: return %[[V3]] : !Concrete.lwe_ciphertext<2048,4> %tlu = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : tensor<16xi64> %1 = "TFHE.apply_lookup_table"(%arg0, %tlu){glweDimension=1:i32, polynomialSize=2048:i32, levelKS=3:i32, baseLogKS=2:i32, levelBS=5:i32, baseLogBS=4:i32, outputSizeKS=600:i32}: (!TFHE.glwe<{2048,1,64}{4}>, tensor<16xi64>) -> (!TFHE.glwe<{2048,1,64}{4}>) diff --git a/compiler/tests/Dialect/BConcrete/ops.mlir b/compiler/tests/Dialect/BConcrete/ops.mlir index e522007d1..74a11a012 100644 --- a/compiler/tests/Dialect/BConcrete/ops.mlir +++ b/compiler/tests/Dialect/BConcrete/ops.mlir @@ -40,13 +40,13 @@ func @negate_lwe_ciphertext(%arg0: tensor<2049xi64>) -> tensor<2049xi64> { return %0 : tensor<2049xi64> } -// CHECK-LABEL: func @bootstrap_lwe(%arg0: tensor<2049xi64>, %arg1: !Concrete.glwe_ciphertext) -> tensor<2049xi64> -func @bootstrap_lwe(%arg0: tensor<2049xi64>, %arg1: !Concrete.glwe_ciphertext) -> tensor<2049xi64> { +// CHECK-LABEL: func @bootstrap_lwe(%arg0: tensor<2049xi64>, %arg1: tensor<4096xi64>) -> tensor<2049xi64> +func @bootstrap_lwe(%arg0: tensor<2049xi64>, %arg1: tensor<4096xi64>) -> tensor<2049xi64> { // CHECK-NEXT: %[[V0:.*]] = linalg.init_tensor [2049] : tensor<2049xi64> - // CHECK-NEXT: "BConcrete.bootstrap_lwe_buffer"(%[[V0]], %arg0, %arg1) {baseLog = -1 : i32, glweDimension = 1 : i32, level = -1 : i32, polynomialSize = 1024 : i32} : (tensor<2049xi64>, tensor<2049xi64>, !Concrete.glwe_ciphertext) -> () + // CHECK-NEXT: "BConcrete.bootstrap_lwe_buffer"(%[[V0]], %arg0, %arg1) {baseLog = -1 : i32, glweDimension = 1 : i32, level = -1 : i32, polynomialSize = 1024 : i32} : (tensor<2049xi64>, tensor<2049xi64>, tensor<4096xi64>) -> () // CHECK-NEXT: return %[[V0]] : tensor<2049xi64> %0 = linalg.init_tensor [2049] : tensor<2049xi64> - "BConcrete.bootstrap_lwe_buffer"(%0, %arg0, %arg1) {baseLog = -1 : i32, glweDimension = 1 : i32, level = -1 : i32, polynomialSize = 1024 : i32} : (tensor<2049xi64>, tensor<2049xi64>, !Concrete.glwe_ciphertext) -> () + "BConcrete.bootstrap_lwe_buffer"(%0, %arg0, %arg1) {baseLog = -1 : i32, glweDimension = 1 : i32, level = -1 : i32, polynomialSize = 1024 : i32} : (tensor<2049xi64>, tensor<2049xi64>, tensor<4096xi64>) -> () return %0 : tensor<2049xi64> } diff --git a/compiler/tests/Dialect/Concrete/Concrete/ops.mlir b/compiler/tests/Dialect/Concrete/Concrete/ops.mlir index 4a3e54fb9..2b3e1455e 100644 --- a/compiler/tests/Dialect/Concrete/Concrete/ops.mlir +++ b/compiler/tests/Dialect/Concrete/Concrete/ops.mlir @@ -36,12 +36,11 @@ func @negate_lwe_ciphertext(%arg0: !Concrete.lwe_ciphertext<2048,7>) -> !Concret return %1: !Concrete.lwe_ciphertext<2048,7> } -// CHECK-LABEL: func @bootstrap_lwe(%arg0: !Concrete.lwe_ciphertext<2048,7>, %arg1: !Concrete.glwe_ciphertext) -> !Concrete.lwe_ciphertext<2048,7> -func @bootstrap_lwe(%arg0: !Concrete.lwe_ciphertext<2048,7>, %arg1: !Concrete.glwe_ciphertext) -> !Concrete.lwe_ciphertext<2048,7> { - // CHECK-NEXT: %[[V1:.*]] = "Concrete.bootstrap_lwe"(%arg0, %arg1) {baseLog = -1 : i32, glweDimension = 1 : i32, level = -1 : i32, polynomialSize = 1024 : i32} : (!Concrete.lwe_ciphertext<2048,7>, !Concrete.glwe_ciphertext) -> !Concrete.lwe_ciphertext<2048,7> +// CHECK-LABEL: func @bootstrap_lwe(%arg0: !Concrete.lwe_ciphertext<2048,7>, %arg1: !Concrete.glwe_ciphertext<2048,1,7>) -> !Concrete.lwe_ciphertext<2048,7> +func @bootstrap_lwe(%arg0: !Concrete.lwe_ciphertext<2048,7>, %arg1: !Concrete.glwe_ciphertext<2048,1,7>) -> !Concrete.lwe_ciphertext<2048,7> { + // CHECK-NEXT: %[[V1:.*]] = "Concrete.bootstrap_lwe"(%arg0, %arg1) {baseLog = -1 : i32, glweDimension = 1 : i32, level = -1 : i32, polynomialSize = 2048 : i32} : (!Concrete.lwe_ciphertext<2048,7>, !Concrete.glwe_ciphertext<2048,1,7>) -> !Concrete.lwe_ciphertext<2048,7> // CHECK-NEXT: return %[[V1]] : !Concrete.lwe_ciphertext<2048,7> - - %1 = "Concrete.bootstrap_lwe"(%arg0, %arg1) {baseLog = -1 : i32, glweDimension = 1 : i32, level = -1 : i32, polynomialSize = 1024 : i32} : (!Concrete.lwe_ciphertext<2048,7>, !Concrete.glwe_ciphertext) -> (!Concrete.lwe_ciphertext<2048,7>) + %1 = "Concrete.bootstrap_lwe"(%arg0, %arg1) {baseLog = -1 : i32, glweDimension = 1 : i32, level = -1 : i32, polynomialSize = 2048 : i32} : (!Concrete.lwe_ciphertext<2048,7>, !Concrete.glwe_ciphertext<2048,1,7>) -> !Concrete.lwe_ciphertext<2048,7> return %1: !Concrete.lwe_ciphertext<2048,7> } @@ -49,7 +48,6 @@ func @bootstrap_lwe(%arg0: !Concrete.lwe_ciphertext<2048,7>, %arg1: !Concrete.gl func @keyswitch_lwe(%arg0: !Concrete.lwe_ciphertext<2048,7>) -> !Concrete.lwe_ciphertext<2048,7> { // CHECK-NEXT: %[[V1:.*]] = "Concrete.keyswitch_lwe"(%arg0) {baseLog = 2 : i32, level = 3 : i32} : (!Concrete.lwe_ciphertext<2048,7>) -> !Concrete.lwe_ciphertext<2048,7> // CHECK-NEXT: return %[[V1]] : !Concrete.lwe_ciphertext<2048,7> - %1 = "Concrete.keyswitch_lwe"(%arg0){baseLog = 2 : i32, level = 3 : i32}: (!Concrete.lwe_ciphertext<2048,7>) -> (!Concrete.lwe_ciphertext<2048,7>) return %1: !Concrete.lwe_ciphertext<2048,7> }