mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
feat(compiler): use engine concrete C API
remove ConcreteToConcreteCAPI and ConcreteUnparametrize passes
This commit is contained in:
@@ -171,7 +171,7 @@ test-end-to-end-jit-dfr: build-end-to-end-jit-dfr
|
||||
test-end-to-end-jit-auto-parallelization: build-end-to-end-jit-auto-parallelization
|
||||
$(BUILD_DIR)/bin/end_to_end_jit_auto_parallelization
|
||||
|
||||
test-end-to-end-jit: test-end-to-end-jit-test test-end-to-end-jit-clear-tensor test-end-to-end-jit-encrypted-tensor test-end-to-end-jit-fhelinalg
|
||||
test-end-to-end-jit: test-end-to-end-jit-test test-end-to-end-jit-clear-tensor test-end-to-end-jit-encrypted-tensor test-end-to-end-jit-fhelinalg test-end-to-end-jit-fhe
|
||||
|
||||
show-stress-tests-summary:
|
||||
@echo '------ Stress tests summary ------'
|
||||
|
||||
@@ -27,6 +27,7 @@ using RuntimeContext = mlir::concretelang::RuntimeContext;
|
||||
|
||||
class KeySet {
|
||||
public:
|
||||
KeySet();
|
||||
~KeySet();
|
||||
|
||||
// allocate a KeySet according the ClientParameters.
|
||||
@@ -81,8 +82,7 @@ public:
|
||||
|
||||
void setRuntimeContext(RuntimeContext &context) {
|
||||
context.ksk = std::get<1>(this->keyswitchKeys["ksk_v0"]);
|
||||
context.bsk[RuntimeContext::BASE_CONTEXT_BSK] =
|
||||
std::get<1>(this->bootstrapKeys.at("bsk_v0"));
|
||||
context.bsk = std::get<1>(this->bootstrapKeys.at("bsk_v0"));
|
||||
}
|
||||
|
||||
RuntimeContext runtimeContext() {
|
||||
@@ -105,14 +105,13 @@ public:
|
||||
|
||||
protected:
|
||||
outcome::checked<void, StringError>
|
||||
generateSecretKey(LweSecretKeyID id, LweSecretKeyParam param,
|
||||
SecretRandomGenerator *generator);
|
||||
generateSecretKey(LweSecretKeyID id, LweSecretKeyParam param);
|
||||
|
||||
outcome::checked<void, StringError>
|
||||
generateBootstrapKey(BootstrapKeyID id, BootstrapKeyParam param,
|
||||
EncryptionRandomGenerator *generator);
|
||||
generateBootstrapKey(BootstrapKeyID id, BootstrapKeyParam param);
|
||||
|
||||
outcome::checked<void, StringError>
|
||||
generateKeyswitchKey(KeyswitchKeyID id, KeyswitchKeyParam param,
|
||||
EncryptionRandomGenerator *generator);
|
||||
generateKeyswitchKey(KeyswitchKeyID id, KeyswitchKeyParam param);
|
||||
|
||||
outcome::checked<void, StringError>
|
||||
generateKeysFromParams(ClientParameters ¶ms, uint64_t seed_msb,
|
||||
@@ -125,7 +124,7 @@ protected:
|
||||
friend class KeySetCache;
|
||||
|
||||
private:
|
||||
EncryptionRandomGenerator *encryptionRandomGenerator;
|
||||
Engine *engine;
|
||||
std::map<LweSecretKeyID, std::pair<LweSecretKeyParam, LweSecretKey_u64 *>>
|
||||
secretKeys;
|
||||
std::map<LweSecretKeyID, std::pair<BootstrapKeyParam, LweBootstrapKey_u64 *>>
|
||||
|
||||
@@ -1,22 +0,0 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_CONVERSION_CONCRETETOCONCRETECAPI_PASS_H_
|
||||
#define CONCRETELANG_CONVERSION_CONCRETETOCONCRETECAPI_PASS_H_
|
||||
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
#include "concretelang/Conversion/Utils/GlobalFHEContext.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
/// Create a pass to convert `Concrete` operators to function call to the
|
||||
/// `ConcreteCAPI`
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
createConvertConcreteToConcreteCAPIPass();
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
||||
@@ -1,18 +0,0 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_CONVERSION_CONCRETEUNPARAMETRIZE_PASS_H_
|
||||
#define CONCRETELANG_CONVERSION_CONCRETEUNPARAMETRIZE_PASS_H_
|
||||
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
createConvertConcreteUnparametrizePass();
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
||||
@@ -13,8 +13,6 @@
|
||||
|
||||
#include "concretelang/Conversion/BConcreteToBConcreteCAPI/Pass.h"
|
||||
#include "concretelang/Conversion/ConcreteToBConcrete/Pass.h"
|
||||
#include "concretelang/Conversion/ConcreteToConcreteCAPI/Pass.h"
|
||||
#include "concretelang/Conversion/ConcreteUnparametrize/Pass.h"
|
||||
#include "concretelang/Conversion/FHETensorOpsToLinalg/Pass.h"
|
||||
#include "concretelang/Conversion/FHEToTFHE/Pass.h"
|
||||
#include "concretelang/Conversion/MLIRLowerableDialectsToLLVM/Pass.h"
|
||||
|
||||
@@ -40,24 +40,12 @@ def ConcreteToBConcrete : Pass<"concrete-to-bconcrete", "mlir::ModuleOp"> {
|
||||
let dependentDialects = ["mlir::linalg::LinalgDialect", "mlir::concretelang::Concrete::ConcreteDialect", "mlir::concretelang::BConcrete::BConcreteDialect"];
|
||||
}
|
||||
|
||||
def ConcreteToConcreteCAPI : Pass<"concrete-to-concrete-c-api", "mlir::ModuleOp"> {
|
||||
let summary = "Lower operations from the Concrete dialect to std with function call to the Concrete C API";
|
||||
let constructor = "mlir::concretelang::createConvertConcreteToConcreteCAPIPass()";
|
||||
let dependentDialects = ["mlir::concretelang::Concrete::ConcreteDialect", "mlir::StandardOpsDialect", "mlir::memref::MemRefDialect"];
|
||||
}
|
||||
|
||||
def BConcreteToBConcreteCAPI : Pass<"bconcrete-to-bconcrete-c-api", "mlir::ModuleOp"> {
|
||||
let summary = "Lower operations from the Bufferized Concrete dialect to std with function call to the Bufferized Concrete C API";
|
||||
let constructor = "mlir::concretelang::createConvertBConcreteToBConcreteCAPIPass()";
|
||||
let dependentDialects = ["mlir::concretelang::BConcrete::BConcreteDialect", "mlir::StandardOpsDialect", "mlir::memref::MemRefDialect"];
|
||||
}
|
||||
|
||||
def ConcreteUnparametrize : Pass<"concrete-unparametrize", "mlir::ModuleOp"> {
|
||||
let summary = "Unparametrize Concrete types and remove unrealized_conversion_cast";
|
||||
let constructor = "mlir::concretelang::createConvertConcreteToConcreteCAPIPass()";
|
||||
let dependentDialects = ["mlir::concretelang::Concrete::ConcreteDialect", "mlir::StandardOpsDialect", "mlir::memref::MemRefDialect"];
|
||||
}
|
||||
|
||||
def MLIRLowerableDialectsToLLVM : Pass<"mlir-lowerable-dialects-to-llvm", "mlir::ModuleOp"> {
|
||||
let summary = "Lowers operations from MLIR lowerable dialects to LLVM";
|
||||
let constructor = "mlir::concretelang::createConvertMLIRLowerableDialectsToLLVMPass()";
|
||||
|
||||
@@ -210,7 +210,7 @@ mlir::Value createMulClearLweCiphertext(mlir::PatternRewriter &rewriter,
|
||||
// % [[TABLE]]){glweDimension = 1 : i32, p = 4 : i32, polynomialSize =
|
||||
// 2048 : i32}
|
||||
// : (tensor<16xi4>)
|
||||
// ->!Concrete.glwe_ciphertext
|
||||
// ->!Concrete.glwe_ciphertext<2048, 1, 4>
|
||||
// % keyswitched = "Concrete.keyswitch_lwe"(% arg0){
|
||||
// baseLog = 2 : i32,
|
||||
// level = 3 : i32
|
||||
@@ -221,7 +221,7 @@ mlir::Value createMulClearLweCiphertext(mlir::PatternRewriter &rewriter,
|
||||
// glweDimension = 1 : i32,
|
||||
// level = 5 : i32,
|
||||
// polynomialSize = 2048 : i32
|
||||
// } : (!Concrete.lwe_ciphertext<600, 4>, !Concrete.glwe_ciphertext)
|
||||
// } : (!Concrete.lwe_ciphertext<600, 4>, !Concrete.glwe_ciphertext<2048, 1, 4>)
|
||||
// ->!Concrete.lwe_ciphertext<2048, 4>
|
||||
// ```
|
||||
mlir::Value createPBS(mlir::PatternRewriter &rewriter, mlir::Location loc,
|
||||
@@ -240,8 +240,11 @@ mlir::Value createPBS(mlir::PatternRewriter &rewriter, mlir::Location loc,
|
||||
mlir::Value accumulator =
|
||||
rewriter
|
||||
.create<mlir::concretelang::Concrete::GlweFromTable>(
|
||||
loc, Concrete::GlweCiphertextType::get(rewriter.getContext()),
|
||||
table, polynomialSize, glweDimension, precision)
|
||||
loc,
|
||||
Concrete::GlweCiphertextType::get(
|
||||
rewriter.getContext(), polynomialSize.getInt(),
|
||||
glweDimension.getInt(), lwe_type.getP()),
|
||||
table)
|
||||
.result();
|
||||
|
||||
// keyswitch
|
||||
|
||||
@@ -36,6 +36,17 @@ def NegateLweBufferOp : BConcrete_Op<"negate_lwe_buffer"> {
|
||||
let results = (outs);
|
||||
}
|
||||
|
||||
def FillGlweFromTable : BConcrete_Op<"fill_glwe_from_table"> {
|
||||
let arguments = (ins
|
||||
1DTensorOf<[I64]>:$glwe,
|
||||
I32Attr:$polynomialSize,
|
||||
I32Attr:$glweDimension,
|
||||
I32Attr:$outPrecision,
|
||||
1DTensorOf<[I64]>:$table
|
||||
);
|
||||
let results = (outs);
|
||||
}
|
||||
|
||||
def KeySwitchLweBufferOp : BConcrete_Op<"keyswitch_lwe_buffer"> {
|
||||
let arguments = (ins
|
||||
1DTensorOf<[I64]>:$result,
|
||||
@@ -52,7 +63,7 @@ def BootstrapLweBufferOp : BConcrete_Op<"bootstrap_lwe_buffer"> {
|
||||
1DTensorOf<[I64]>:$result,
|
||||
// LweBootstrapKeyType:$bootstrap_key,
|
||||
1DTensorOf<[I64]>:$input_ciphertext,
|
||||
GlweCiphertextType:$accumulator,
|
||||
1DTensorOf<[I64]>:$accumulator,
|
||||
I32Attr:$glweDimension,
|
||||
I32Attr:$polynomialSize,
|
||||
I32Attr:$level,
|
||||
|
||||
@@ -55,7 +55,7 @@ def NegateLweCiphertextOp : Concrete_Op<"negate_lwe_ciphertext"> {
|
||||
def GlweFromTable : Concrete_Op<"glwe_from_table"> {
|
||||
let summary = "Creates a GLWE ciphertext which is the trivial encrytion of a the input table interpreted as a polynomial (to use later in a bootstrap)";
|
||||
|
||||
let arguments = (ins TensorOf<[AnyInteger]>:$table, I32Attr:$polynomialSize, I32Attr:$glweDimension, I32Attr:$p);
|
||||
let arguments = (ins 1DTensorOf<[I64]>:$table);
|
||||
let results = (outs GlweCiphertextType:$result);
|
||||
}
|
||||
|
||||
|
||||
@@ -16,12 +16,48 @@ def GlweCiphertextType : Concrete_Type<"GlweCiphertext"> {
|
||||
GLWE ciphertext.
|
||||
}];
|
||||
|
||||
let parameters = (ins
|
||||
"signed":$polynomialSize,
|
||||
"signed":$glweDimension,
|
||||
// Precision of the lwe ciphertext
|
||||
"signed":$p
|
||||
);
|
||||
|
||||
let printer = [{
|
||||
$_printer << "glwe_ciphertext";
|
||||
$_printer << "glwe_ciphertext<";
|
||||
if (getImpl()->polynomialSize == -1) $_printer << "_";
|
||||
else $_printer << getImpl()->polynomialSize;
|
||||
$_printer << ",";
|
||||
if (getImpl()->glweDimension == -1) $_printer << "_";
|
||||
else $_printer << getImpl()->glweDimension;
|
||||
$_printer << ",";
|
||||
if (getImpl()->p == -1) $_printer << "_";
|
||||
else $_printer << getImpl()->p;
|
||||
$_printer << ">";
|
||||
}];
|
||||
|
||||
|
||||
let parser = [{
|
||||
return get($_ctxt);
|
||||
if ($_parser.parseLess())
|
||||
return Type();
|
||||
int polynomialSize = -1;
|
||||
if ($_parser.parseOptionalKeyword("_") && $_parser.parseInteger(polynomialSize))
|
||||
return Type();
|
||||
if ($_parser.parseComma())
|
||||
return Type();
|
||||
int glweDimension = -1;
|
||||
if ($_parser.parseOptionalKeyword("_") && $_parser.parseInteger(glweDimension))
|
||||
return Type();
|
||||
if ($_parser.parseComma())
|
||||
return Type();
|
||||
|
||||
int p = -1;
|
||||
if ($_parser.parseOptionalKeyword("_") && $_parser.parseInteger(p))
|
||||
return Type();
|
||||
if ($_parser.parseGreater())
|
||||
return Type();
|
||||
Location loc = $_parser.getEncodedSourceLoc($_parser.getNameLoc());
|
||||
return getChecked(loc, loc.getContext(), polynomialSize, glweDimension, p);
|
||||
}];
|
||||
}
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
#define CONCRETELANG_RUNTIME_CONTEXT_H
|
||||
|
||||
#include <map>
|
||||
#include <mutex>
|
||||
#include <string>
|
||||
|
||||
extern "C" {
|
||||
@@ -18,14 +19,29 @@ namespace concretelang {
|
||||
|
||||
typedef struct RuntimeContext {
|
||||
LweKeyswitchKey_u64 *ksk;
|
||||
std::map<std::string, LweBootstrapKey_u64 *> bsk;
|
||||
LweBootstrapKey_u64 *bsk;
|
||||
#ifdef CONCRETELANG_PARALLEL_EXECUTION_ENABLED
|
||||
std::map<std::string, Engine *> engines;
|
||||
std::mutex engines_map_guard;
|
||||
#else
|
||||
Engine *engine;
|
||||
#endif
|
||||
|
||||
static std::string BASE_CONTEXT_BSK;
|
||||
RuntimeContext()
|
||||
#ifndef CONCRETELANG_PARALLEL_EXECUTION_ENABLED
|
||||
: engine(nullptr)
|
||||
#endif
|
||||
{
|
||||
}
|
||||
~RuntimeContext() {
|
||||
for (const auto &key : bsk) {
|
||||
if (key.first != BASE_CONTEXT_BSK)
|
||||
free_lwe_bootstrap_key_u64(key.second);
|
||||
#ifdef CONCRETELANG_PARALLEL_EXECUTION_ENABLED
|
||||
for (const auto &key : engines) {
|
||||
free_engine(key.second);
|
||||
}
|
||||
#else
|
||||
if (engine != nullptr)
|
||||
free_engine(engine);
|
||||
#endif
|
||||
}
|
||||
} RuntimeContext;
|
||||
|
||||
@@ -34,9 +50,11 @@ typedef struct RuntimeContext {
|
||||
|
||||
extern "C" {
|
||||
LweKeyswitchKey_u64 *
|
||||
get_keyswitch_key(mlir::concretelang::RuntimeContext *context);
|
||||
get_keyswitch_key_u64(mlir::concretelang::RuntimeContext *context);
|
||||
|
||||
LweBootstrapKey_u64 *
|
||||
get_bootstrap_key(mlir::concretelang::RuntimeContext *context);
|
||||
get_bootstrap_key_u64(mlir::concretelang::RuntimeContext *context);
|
||||
|
||||
Engine *get_engine(mlir::concretelang::RuntimeContext *context);
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -6,11 +6,17 @@
|
||||
#ifndef CONCRETELANG_RUNTIME_WRAPPERS_H
|
||||
#define CONCRETELANG_RUNTIME_WRAPPERS_H
|
||||
|
||||
#include "concretelang/Runtime/context.h"
|
||||
|
||||
extern "C" {
|
||||
#include "concrete-ffi.h"
|
||||
|
||||
struct ForeignPlaintextList_u64 *memref_runtime_foreign_plaintext_list_u64(
|
||||
uint64_t *allocated, uint64_t *aligned, uint64_t offset, uint64_t size,
|
||||
uint64_t stride, uint32_t precision);
|
||||
void memref_expand_lut_in_trivial_glwe_ct_u64(
|
||||
uint64_t *glwe_ct_allocated, uint64_t *glwe_ct_aligned,
|
||||
uint64_t glwe_ct_offset, uint64_t glwe_ct_size, uint64_t glwe_ct_stride,
|
||||
uint32_t poly_size, uint32_t glwe_dimension, uint32_t out_precision,
|
||||
uint64_t *lut_allocated, uint64_t *lut_aligned, uint64_t lut_offset,
|
||||
uint64_t lut_size, uint64_t lut_stride);
|
||||
|
||||
void memref_add_lwe_ciphertexts_u64(
|
||||
uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
|
||||
@@ -37,19 +43,20 @@ void memref_negate_lwe_ciphertext_u64(
|
||||
uint64_t *ct0_aligned, uint64_t ct0_offset, uint64_t ct0_size,
|
||||
uint64_t ct0_stride);
|
||||
|
||||
void memref_keyswitch_lwe_u64(struct LweKeyswitchKey_u64 *keyswitch_key,
|
||||
uint64_t *out_allocated, uint64_t *out_aligned,
|
||||
uint64_t out_offset, uint64_t out_size,
|
||||
uint64_t out_stride, uint64_t *ct0_allocated,
|
||||
uint64_t *ct0_aligned, uint64_t ct0_offset,
|
||||
uint64_t ct0_size, uint64_t ct0_stride);
|
||||
|
||||
void memref_bootstrap_lwe_u64(struct LweBootstrapKey_u64 *bootstrap_key,
|
||||
uint64_t *out_allocated, uint64_t *out_aligned,
|
||||
void memref_keyswitch_lwe_u64(uint64_t *out_allocated, uint64_t *out_aligned,
|
||||
uint64_t out_offset, uint64_t out_size,
|
||||
uint64_t out_stride, uint64_t *ct0_allocated,
|
||||
uint64_t *ct0_aligned, uint64_t ct0_offset,
|
||||
uint64_t ct0_size, uint64_t ct0_stride,
|
||||
struct GlweCiphertext_u64 *accumulator);
|
||||
mlir::concretelang::RuntimeContext *context);
|
||||
|
||||
void memref_bootstrap_lwe_u64(
|
||||
uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
|
||||
uint64_t out_size, uint64_t out_stride, uint64_t *ct0_allocated,
|
||||
uint64_t *ct0_aligned, uint64_t ct0_offset, uint64_t ct0_size,
|
||||
uint64_t ct0_stride, uint64_t *glwe_ct_allocated, uint64_t *glwe_ct_aligned,
|
||||
uint64_t glwe_ct_offset, uint64_t glwe_ct_size, uint64_t glwe_ct_stride,
|
||||
mlir::concretelang::RuntimeContext *context);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
@@ -18,6 +18,8 @@
|
||||
namespace concretelang {
|
||||
namespace clientlib {
|
||||
|
||||
KeySet::KeySet() : engine(new_engine()) {}
|
||||
|
||||
KeySet::~KeySet() {
|
||||
for (auto it : secretKeys) {
|
||||
free_lwe_secret_key_u64(it.second.second);
|
||||
@@ -28,7 +30,7 @@ KeySet::~KeySet() {
|
||||
for (auto it : keyswitchKeys) {
|
||||
free_lwe_keyswitch_key_u64(it.second.second);
|
||||
}
|
||||
free_encryption_generator(encryptionRandomGenerator);
|
||||
free_engine(engine);
|
||||
}
|
||||
|
||||
outcome::checked<std::unique_ptr<KeySet>, StringError>
|
||||
@@ -81,9 +83,6 @@ KeySet::setupEncryptionMaterial(ClientParameters ¶ms, uint64_t seed_msb,
|
||||
}
|
||||
}
|
||||
|
||||
this->encryptionRandomGenerator =
|
||||
allocate_encryption_generator(seed_msb, seed_lsb);
|
||||
|
||||
return outcome::success();
|
||||
}
|
||||
|
||||
@@ -93,29 +92,20 @@ KeySet::generateKeysFromParams(ClientParameters ¶ms, uint64_t seed_msb,
|
||||
|
||||
{
|
||||
// Generate LWE secret keys
|
||||
SecretRandomGenerator *generator;
|
||||
|
||||
generator = allocate_secret_generator(seed_msb, seed_lsb);
|
||||
for (auto secretKeyParam : params.secretKeys) {
|
||||
OUTCOME_TRYV(this->generateSecretKey(secretKeyParam.first,
|
||||
secretKeyParam.second, generator));
|
||||
OUTCOME_TRYV(
|
||||
this->generateSecretKey(secretKeyParam.first, secretKeyParam.second));
|
||||
}
|
||||
free_secret_generator(generator);
|
||||
}
|
||||
// Allocate the encryption random generator
|
||||
this->encryptionRandomGenerator =
|
||||
allocate_encryption_generator(seed_msb, seed_lsb);
|
||||
// Generate bootstrap and keyswitch keys
|
||||
{
|
||||
for (auto bootstrapKeyParam : params.bootstrapKeys) {
|
||||
OUTCOME_TRYV(this->generateBootstrapKey(bootstrapKeyParam.first,
|
||||
bootstrapKeyParam.second,
|
||||
this->encryptionRandomGenerator));
|
||||
bootstrapKeyParam.second));
|
||||
}
|
||||
for (auto keyswitchParam : params.keyswitchKeys) {
|
||||
OUTCOME_TRYV(this->generateKeyswitchKey(keyswitchParam.first,
|
||||
keyswitchParam.second,
|
||||
this->encryptionRandomGenerator));
|
||||
keyswitchParam.second));
|
||||
}
|
||||
}
|
||||
return outcome::success();
|
||||
@@ -136,12 +126,9 @@ void KeySet::setKeys(
|
||||
}
|
||||
|
||||
outcome::checked<void, StringError>
|
||||
KeySet::generateSecretKey(LweSecretKeyID id, LweSecretKeyParam param,
|
||||
SecretRandomGenerator *generator) {
|
||||
KeySet::generateSecretKey(LweSecretKeyID id, LweSecretKeyParam param) {
|
||||
LweSecretKey_u64 *sk;
|
||||
sk = allocate_lwe_secret_key_u64({param.dimension});
|
||||
|
||||
fill_lwe_secret_key_u64(sk, generator);
|
||||
sk = generate_lwe_secret_key_u64(engine, param.dimension);
|
||||
|
||||
secretKeys[id] = {param, sk};
|
||||
|
||||
@@ -149,8 +136,7 @@ KeySet::generateSecretKey(LweSecretKeyID id, LweSecretKeyParam param,
|
||||
}
|
||||
|
||||
outcome::checked<void, StringError>
|
||||
KeySet::generateBootstrapKey(BootstrapKeyID id, BootstrapKeyParam param,
|
||||
EncryptionRandomGenerator *generator) {
|
||||
KeySet::generateBootstrapKey(BootstrapKeyID id, BootstrapKeyParam param) {
|
||||
// Finding input and output secretKeys
|
||||
auto inputSk = secretKeys.find(param.inputSecretKeyID);
|
||||
if (inputSk == secretKeys.end()) {
|
||||
@@ -169,32 +155,18 @@ KeySet::generateBootstrapKey(BootstrapKeyID id, BootstrapKeyParam param,
|
||||
|
||||
uint64_t polynomialSize = total_dimension / param.glweDimension;
|
||||
|
||||
bsk = allocate_lwe_bootstrap_key_u64(
|
||||
{param.level}, {param.baseLog}, {param.glweDimension},
|
||||
{inputSk->second.first.dimension}, {polynomialSize});
|
||||
bsk = generate_lwe_bootstrap_key_u64(
|
||||
engine, inputSk->second.second, outputSk->second.second, param.baseLog,
|
||||
param.level, param.variance, param.glweDimension, polynomialSize);
|
||||
|
||||
// Store the bootstrap key
|
||||
bootstrapKeys[id] = {param, bsk};
|
||||
|
||||
// Convert the output lwe key to glwe key
|
||||
GlweSecretKey_u64 *glwe_sk;
|
||||
|
||||
glwe_sk =
|
||||
allocate_glwe_secret_key_u64({param.glweDimension}, {polynomialSize});
|
||||
|
||||
fill_glwe_secret_key_with_lwe_secret_key_u64(glwe_sk,
|
||||
outputSk->second.second);
|
||||
|
||||
// Initialize the bootstrap key
|
||||
fill_lwe_bootstrap_key_u64(bsk, inputSk->second.second, glwe_sk, generator,
|
||||
{param.variance});
|
||||
free_glwe_secret_key_u64(glwe_sk);
|
||||
return outcome::success();
|
||||
}
|
||||
|
||||
outcome::checked<void, StringError>
|
||||
KeySet::generateKeyswitchKey(KeyswitchKeyID id, KeyswitchKeyParam param,
|
||||
EncryptionRandomGenerator *generator) {
|
||||
KeySet::generateKeyswitchKey(KeyswitchKeyID id, KeyswitchKeyParam param) {
|
||||
// Finding input and output secretKeys
|
||||
auto inputSk = secretKeys.find(param.inputSecretKeyID);
|
||||
if (inputSk == secretKeys.end()) {
|
||||
@@ -207,17 +179,13 @@ KeySet::generateKeyswitchKey(KeyswitchKeyID id, KeyswitchKeyParam param,
|
||||
// Allocate the keyswitch key
|
||||
LweKeyswitchKey_u64 *ksk;
|
||||
|
||||
ksk = allocate_lwe_keyswitch_key_u64({param.level}, {param.baseLog},
|
||||
{inputSk->second.first.dimension},
|
||||
{outputSk->second.first.dimension});
|
||||
ksk = generate_lwe_keyswitch_key_u64(engine, inputSk->second.second,
|
||||
outputSk->second.second, param.level,
|
||||
param.baseLog, param.variance);
|
||||
|
||||
// Store the keyswitch key
|
||||
keyswitchKeys[id] = {param, ksk};
|
||||
// Initialize the keyswitch key
|
||||
|
||||
fill_lwe_keyswitch_key_u64(ksk, inputSk->second.second,
|
||||
outputSk->second.second, generator,
|
||||
{param.variance});
|
||||
return outcome::success();
|
||||
}
|
||||
|
||||
@@ -255,9 +223,8 @@ KeySet::encrypt_lwe(size_t argPos, uint64_t *ciphertext, uint64_t input) {
|
||||
// Encode - TODO we could check if the input value is in the right range
|
||||
uint64_t plaintext =
|
||||
input << (64 - (std::get<0>(inputSk).encryption->encoding.precision + 1));
|
||||
encrypt_lwe_u64(std::get<2>(inputSk), ciphertext, plaintext,
|
||||
encryptionRandomGenerator,
|
||||
{std::get<0>(inputSk).encryption->variance});
|
||||
::encrypt_lwe_u64(engine, std::get<2>(inputSk), ciphertext, plaintext,
|
||||
std::get<0>(inputSk).encryption->variance);
|
||||
return outcome::success();
|
||||
}
|
||||
|
||||
@@ -271,7 +238,8 @@ KeySet::decrypt_lwe(size_t argPos, uint64_t *ciphertext, uint64_t &output) {
|
||||
if (!std::get<0>(outputSk).encryption.hasValue()) {
|
||||
return StringError("decrypt_lwe: the positional argument is not encrypted");
|
||||
}
|
||||
uint64_t plaintext = decrypt_lwe_u64(std::get<2>(outputSk), ciphertext);
|
||||
uint64_t plaintext =
|
||||
::decrypt_lwe_u64(engine, std::get<2>(outputSk), ciphertext);
|
||||
// Decode
|
||||
size_t precision = std::get<0>(outputSk).encryption->encoding.precision;
|
||||
output = plaintext >> (64 - precision - 2);
|
||||
|
||||
@@ -45,10 +45,9 @@ PublicArguments::~PublicArguments() {
|
||||
if (!clearRuntimeContext) {
|
||||
return;
|
||||
}
|
||||
for (auto bsk_entry : runtimeContext.bsk) {
|
||||
free_lwe_bootstrap_key_u64(bsk_entry.second);
|
||||
if (runtimeContext.bsk != nullptr) {
|
||||
free_lwe_bootstrap_key_u64(runtimeContext.bsk);
|
||||
}
|
||||
runtimeContext.bsk.clear();
|
||||
if (runtimeContext.ksk != nullptr) {
|
||||
free_lwe_keyswitch_key_u64(runtimeContext.ksk);
|
||||
runtimeContext.ksk = nullptr;
|
||||
|
||||
@@ -99,7 +99,7 @@ std::istream &operator>>(std::istream &istream, ClientParameters ¶ms) {
|
||||
std::istream &operator>>(std::istream &istream,
|
||||
RuntimeContext &runtimeContext) {
|
||||
istream >> runtimeContext.ksk;
|
||||
istream >> runtimeContext.bsk[RuntimeContext::BASE_CONTEXT_BSK];
|
||||
istream >> runtimeContext.bsk;
|
||||
assert(istream.good());
|
||||
return istream;
|
||||
}
|
||||
@@ -107,7 +107,7 @@ std::istream &operator>>(std::istream &istream,
|
||||
std::ostream &operator<<(std::ostream &ostream,
|
||||
const RuntimeContext &runtimeContext) {
|
||||
ostream << runtimeContext.ksk;
|
||||
ostream << runtimeContext.bsk.at(RuntimeContext::BASE_CONTEXT_BSK);
|
||||
ostream << runtimeContext.bsk;
|
||||
assert(ostream.good());
|
||||
return ostream;
|
||||
}
|
||||
|
||||
@@ -17,6 +17,7 @@
|
||||
#include "concretelang/Dialect/Concrete/IR/ConcreteDialect.h"
|
||||
#include "concretelang/Dialect/Concrete/IR/ConcreteOps.h"
|
||||
#include "concretelang/Dialect/Concrete/IR/ConcreteTypes.h"
|
||||
#include "concretelang/Support/Constants.h"
|
||||
|
||||
namespace {
|
||||
class BConcreteToBConcreteCAPITypeConverter : public mlir::TypeConverter {
|
||||
@@ -72,9 +73,8 @@ inline mlir::Type getGenericLweBufferType(mlir::MLIRContext *context) {
|
||||
return mlir::RankedTensorType::get({-1}, mlir::IntegerType::get(context, 64));
|
||||
}
|
||||
|
||||
inline mlir::concretelang::Concrete::GlweCiphertextType
|
||||
getGenericGlweCiphertextType(mlir::MLIRContext *context) {
|
||||
return mlir::concretelang::Concrete::GlweCiphertextType::get(context);
|
||||
inline mlir::Type getGenericGlweBufferType(mlir::MLIRContext *context) {
|
||||
return mlir::RankedTensorType::get({-1}, mlir::IntegerType::get(context, 64));
|
||||
}
|
||||
|
||||
inline mlir::Type getGenericPlaintextType(mlir::MLIRContext *context) {
|
||||
@@ -114,10 +114,6 @@ mlir::LogicalResult insertForwardDeclarations(mlir::Operation *op,
|
||||
auto lweBufferType = getGenericLweBufferType(rewriter.getContext());
|
||||
auto plaintextType = getGenericPlaintextType(rewriter.getContext());
|
||||
auto cleartextType = getGenericCleartextType(rewriter.getContext());
|
||||
auto glweCiphertextType = getGenericGlweCiphertextType(rewriter.getContext());
|
||||
auto plaintextListType = getGenericPlaintextListType(rewriter.getContext());
|
||||
auto foreignPlaintextList =
|
||||
getGenericForeignPlaintextListType(rewriter.getContext());
|
||||
auto keySwitchKeyType = getGenericLweKeySwitchKeyType(rewriter.getContext());
|
||||
auto bootstrapKeyType = getGenericLweBootstrapKeyType(rewriter.getContext());
|
||||
auto contextType =
|
||||
@@ -134,7 +130,7 @@ mlir::LogicalResult insertForwardDeclarations(mlir::Operation *op,
|
||||
return mlir::failure();
|
||||
}
|
||||
}
|
||||
// Insert forward declaration of the add_plaintext_lwe_ciphertext_u64 function
|
||||
// Insert forward declaration of the add_plaintext_lwe_ciphertext function
|
||||
{
|
||||
auto funcType = mlir::FunctionType::get(
|
||||
rewriter.getContext(), {lweBufferType, lweBufferType, plaintextType},
|
||||
@@ -145,7 +141,7 @@ mlir::LogicalResult insertForwardDeclarations(mlir::Operation *op,
|
||||
return mlir::failure();
|
||||
}
|
||||
}
|
||||
// Insert forward declaration of the mul_cleartext_lwe_ciphertext_u64 function
|
||||
// Insert forward declaration of the mul_cleartext_lwe_ciphertext function
|
||||
{
|
||||
auto funcType = mlir::FunctionType::get(
|
||||
rewriter.getContext(), {lweBufferType, lweBufferType, cleartextType},
|
||||
@@ -156,7 +152,7 @@ mlir::LogicalResult insertForwardDeclarations(mlir::Operation *op,
|
||||
return mlir::failure();
|
||||
}
|
||||
}
|
||||
// Insert forward declaration of the negate_lwe_ciphertext_u64 function
|
||||
// Insert forward declaration of the negate_lwe_ciphertext function
|
||||
{
|
||||
auto funcType = mlir::FunctionType::get(rewriter.getContext(),
|
||||
{lweBufferType, lweBufferType}, {});
|
||||
@@ -169,8 +165,7 @@ mlir::LogicalResult insertForwardDeclarations(mlir::Operation *op,
|
||||
// Insert forward declaration of the memref_keyswitch_lwe_u64 function
|
||||
{
|
||||
auto funcType = mlir::FunctionType::get(
|
||||
rewriter.getContext(), {keySwitchKeyType, lweBufferType, lweBufferType},
|
||||
{});
|
||||
rewriter.getContext(), {lweBufferType, lweBufferType, contextType}, {});
|
||||
if (insertForwardDeclaration(op, rewriter, "memref_keyswitch_lwe_u64",
|
||||
funcType)
|
||||
.failed()) {
|
||||
@@ -181,40 +176,40 @@ mlir::LogicalResult insertForwardDeclarations(mlir::Operation *op,
|
||||
{
|
||||
auto funcType = mlir::FunctionType::get(
|
||||
rewriter.getContext(),
|
||||
{bootstrapKeyType, lweBufferType, lweBufferType, glweCiphertextType},
|
||||
{});
|
||||
{lweBufferType, lweBufferType, lweBufferType, contextType}, {});
|
||||
if (insertForwardDeclaration(op, rewriter, "memref_bootstrap_lwe_u64",
|
||||
funcType)
|
||||
.failed()) {
|
||||
return mlir::failure();
|
||||
}
|
||||
}
|
||||
// Insert forward declaration of the fill_plaintext_list function
|
||||
{
|
||||
auto funcType = mlir::FunctionType::get(
|
||||
rewriter.getContext(), {plaintextListType, foreignPlaintextList}, {});
|
||||
if (insertForwardDeclaration(
|
||||
op, rewriter, "fill_plaintext_list_with_expansion_u64", funcType)
|
||||
.failed()) {
|
||||
return mlir::failure();
|
||||
}
|
||||
}
|
||||
// Insert forward declaration of the add_plaintext_list_glwe function
|
||||
|
||||
// Insert forward declaration of the expand_lut_in_trivial_glwe_ct function
|
||||
{
|
||||
auto funcType = mlir::FunctionType::get(
|
||||
rewriter.getContext(),
|
||||
{glweCiphertextType, glweCiphertextType, plaintextListType}, {});
|
||||
{
|
||||
getGenericGlweBufferType(rewriter.getContext()),
|
||||
rewriter.getI32Type(),
|
||||
rewriter.getI32Type(),
|
||||
rewriter.getI32Type(),
|
||||
mlir::RankedTensorType::get(
|
||||
{-1}, mlir::IntegerType::get(rewriter.getContext(), 64)),
|
||||
},
|
||||
{});
|
||||
if (insertForwardDeclaration(
|
||||
op, rewriter, "add_plaintext_list_glwe_ciphertext_u64", funcType)
|
||||
op, rewriter, "memref_expand_lut_in_trivial_glwe_ct_u64", funcType)
|
||||
.failed()) {
|
||||
return mlir::failure();
|
||||
}
|
||||
}
|
||||
|
||||
// Insert forward declaration of the getGlobalKeyswitchKey function
|
||||
{
|
||||
auto funcType = mlir::FunctionType::get(rewriter.getContext(),
|
||||
{contextType}, {keySwitchKeyType});
|
||||
if (insertForwardDeclaration(op, rewriter, "get_keyswitch_key", funcType)
|
||||
if (insertForwardDeclaration(op, rewriter, "get_keyswitch_key_u64",
|
||||
funcType)
|
||||
.failed()) {
|
||||
return mlir::failure();
|
||||
}
|
||||
@@ -223,7 +218,8 @@ mlir::LogicalResult insertForwardDeclarations(mlir::Operation *op,
|
||||
{
|
||||
auto funcType = mlir::FunctionType::get(rewriter.getContext(),
|
||||
{contextType}, {bootstrapKeyType});
|
||||
if (insertForwardDeclaration(op, rewriter, "get_bootstrap_key", funcType)
|
||||
if (insertForwardDeclaration(op, rewriter, "get_bootstrap_key_u64",
|
||||
funcType)
|
||||
.failed()) {
|
||||
return mlir::failure();
|
||||
}
|
||||
@@ -233,15 +229,15 @@ mlir::LogicalResult insertForwardDeclarations(mlir::Operation *op,
|
||||
|
||||
// For all operands `tensor<Axi64>` replace with
|
||||
// `%casted = tensor.cast %op : tensor<Axi64> to tensor<?xui64>`
|
||||
template <typename Op>
|
||||
mlir::SmallVector<mlir::Value>
|
||||
getCastedTensorOperands(Op op, mlir::PatternRewriter &rewriter) {
|
||||
getCastedTensor(mlir::Location loc, mlir::Operation::operand_range operands,
|
||||
mlir::PatternRewriter &rewriter) {
|
||||
mlir::SmallVector<mlir::Value, 4> newOperands{};
|
||||
for (mlir::Value operand : op->getOperands()) {
|
||||
for (mlir::Value operand : operands) {
|
||||
mlir::Type operandType = operand.getType();
|
||||
if (operandType.isa<mlir::RankedTensorType>()) {
|
||||
mlir::Value castedOp = rewriter.create<mlir::tensor::CastOp>(
|
||||
op.getLoc(), getGenericLweBufferType(rewriter.getContext()), operand);
|
||||
loc, getGenericLweBufferType(rewriter.getContext()), operand);
|
||||
newOperands.push_back(castedOp);
|
||||
} else {
|
||||
newOperands.push_back(operand);
|
||||
@@ -250,6 +246,14 @@ getCastedTensorOperands(Op op, mlir::PatternRewriter &rewriter) {
|
||||
return std::move(newOperands);
|
||||
}
|
||||
|
||||
// For all operands `tensor<Axi64>` replace with
|
||||
// `%casted = tensor.cast %op : tensor<Axi64> to tensor<?xui64>`
|
||||
template <typename Op>
|
||||
mlir::SmallVector<mlir::Value>
|
||||
getCastedTensorOperands(Op op, mlir::PatternRewriter &rewriter) {
|
||||
return getCastedTensor(op->getLoc(), op->getOperands(), rewriter);
|
||||
}
|
||||
|
||||
/// BConcreteOpToConcreteCAPICallPattern<Op> match the `BConcreteOp`
|
||||
/// Operation and replace with a call to `funcName`, the funcName should be an
|
||||
/// external function that was linked later. It insert the forward declaration
|
||||
@@ -379,15 +383,12 @@ struct BConcreteKeySwitchLweOpPattern
|
||||
matchAndRewrite(mlir::concretelang::BConcrete::KeySwitchLweBufferOp op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
|
||||
mlir::CallOp kskOp = rewriter.create<mlir::CallOp>(
|
||||
op.getLoc(), "get_keyswitch_key",
|
||||
getGenericLweKeySwitchKeyType(rewriter.getContext()),
|
||||
mlir::SmallVector<mlir::Value>{getContextArgument(op)});
|
||||
mlir::SmallVector<mlir::Value, 3> operands{kskOp.getResult(0)};
|
||||
|
||||
mlir::SmallVector<mlir::Value, 3> operands{};
|
||||
operands.append(
|
||||
getCastedTensorOperands<
|
||||
mlir::concretelang::BConcrete::KeySwitchLweBufferOp>(op, rewriter));
|
||||
operands.push_back(getContextArgument(op));
|
||||
|
||||
rewriter.replaceOpWithNewOp<mlir::CallOp>(op, "memref_keyswitch_lwe_u64",
|
||||
mlir::TypeRange({}), operands);
|
||||
return mlir::success();
|
||||
@@ -422,22 +423,83 @@ struct BConcreteBootstrapLweOpPattern
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(mlir::concretelang::BConcrete::BootstrapLweBufferOp op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
|
||||
mlir::SmallVector<mlir::Value> getkskOperands{};
|
||||
mlir::CallOp bskOp = rewriter.create<mlir::CallOp>(
|
||||
op.getLoc(), "get_bootstrap_key",
|
||||
getGenericLweBootstrapKeyType(rewriter.getContext()),
|
||||
mlir::SmallVector<mlir::Value>{getContextArgument(op)});
|
||||
mlir::SmallVector<mlir::Value, 4> operands{bskOp.getResult(0)};
|
||||
mlir::SmallVector<mlir::Value, 4> operands{};
|
||||
operands.append(
|
||||
getCastedTensorOperands<
|
||||
mlir::concretelang::BConcrete::BootstrapLweBufferOp>(op, rewriter));
|
||||
operands.push_back(getContextArgument(op));
|
||||
rewriter.replaceOpWithNewOp<mlir::CallOp>(op, "memref_bootstrap_lwe_u64",
|
||||
mlir::TypeRange({}), operands);
|
||||
return mlir::success();
|
||||
};
|
||||
};
|
||||
|
||||
// Rewrite pattern that rewrite every
|
||||
// ```
|
||||
// "BConcrete.fill_glwe_table"(%glwe, %lut) {glweDimension=1,
|
||||
// polynomialSize=2048, outPrecision=3} :
|
||||
// (tensor<4096xi64>, tensor<32xi64>) -> ()
|
||||
// ```
|
||||
//
|
||||
// to
|
||||
//
|
||||
// ```
|
||||
// %glweDim = arith.constant 1 : i32
|
||||
// %polySize = arith.constant 2048 : i32
|
||||
// %outPrecision = arith.constant 3 : i32
|
||||
// %glwe_ = tensor.cast %glwe : tensor<4096xi64> to tensor<?xi64>
|
||||
// %lut_ = tensor.cast %lut : tensor<32xi64> to tensor<?xi64>
|
||||
// call @expand_lut_in_trivial_glwe_ct(%glwe, %polySize, %glweDim,
|
||||
// %outPrecision, %lut_) :
|
||||
// (tensor<?xi64>, i32, i32, tensor<?xi64>) -> ()
|
||||
// ```
|
||||
struct BConcreteGlweFromTableOpPattern
|
||||
: public mlir::OpRewritePattern<
|
||||
mlir::concretelang::BConcrete::FillGlweFromTable> {
|
||||
BConcreteGlweFromTableOpPattern(
|
||||
mlir::MLIRContext *context,
|
||||
mlir::PatternBenefit benefit =
|
||||
mlir::concretelang::DEFAULT_PATTERN_BENEFIT)
|
||||
: mlir::OpRewritePattern<
|
||||
mlir::concretelang::BConcrete::FillGlweFromTable>(context,
|
||||
benefit) {}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(mlir::concretelang::BConcrete::FillGlweFromTable op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
BConcreteToBConcreteCAPITypeConverter typeConverter;
|
||||
// %glweDim = arith.constant 1 : i32
|
||||
// %polySize = arith.constant 2048 : i32
|
||||
// %outPrecision = arith.constant 3 : i32
|
||||
|
||||
auto castedOp = getCastedTensorOperands(op, rewriter);
|
||||
|
||||
auto polySizeOp = rewriter.create<mlir::arith::ConstantOp>(
|
||||
op.getLoc(), rewriter.getI32IntegerAttr(op.polynomialSize()));
|
||||
auto glweDimensionOp = rewriter.create<mlir::arith::ConstantOp>(
|
||||
op.getLoc(), rewriter.getI32IntegerAttr(op.glweDimension()));
|
||||
auto outPrecisionOp = rewriter.create<mlir::arith::ConstantOp>(
|
||||
op.getLoc(), rewriter.getI32IntegerAttr(op.outPrecision()));
|
||||
|
||||
mlir::SmallVector<mlir::Value> newOperands{
|
||||
castedOp[0], polySizeOp, glweDimensionOp, outPrecisionOp, castedOp[1]};
|
||||
|
||||
// getCastedTensor(op.getLoc(), newOperands, rewriter);
|
||||
// perform operands conversion
|
||||
// %glwe_ = tensor.cast %glwe : tensor<4096xi64> to tensor<?xi64>
|
||||
// %lut_ = tensor.cast %lut : tensor<32xi64> to tensor<?xi64>
|
||||
|
||||
// call @expand_lut_in_trivial_glwe_ct(%glwe, %polySize, %glweDim,
|
||||
// %lut_) :
|
||||
// (tensor<?xi64>, i32, i32, tensor<?xi64>) -> ()
|
||||
|
||||
rewriter.replaceOpWithNewOp<mlir::CallOp>(
|
||||
op, "memref_expand_lut_in_trivial_glwe_ct_u64",
|
||||
mlir::SmallVector<mlir::Type>{}, newOperands);
|
||||
return mlir::success();
|
||||
};
|
||||
};
|
||||
|
||||
/// Populate the RewritePatternSet with all patterns that rewrite Concrete
|
||||
/// operators to the corresponding function call to the `Concrete C API`.
|
||||
void populateBConcreteToBConcreteCAPICall(mlir::RewritePatternSet &patterns) {
|
||||
@@ -455,9 +517,9 @@ void populateBConcreteToBConcreteCAPICall(mlir::RewritePatternSet &patterns) {
|
||||
patterns.getContext(), "memref_negate_lwe_ciphertext_u64");
|
||||
patterns.add<ConcreteEncodeIntOpPattern>(patterns.getContext());
|
||||
patterns.add<ConcreteIntToCleartextOpPattern>(patterns.getContext());
|
||||
// patterns.add<ConcreteZeroOpPattern>(patterns.getContext());
|
||||
patterns.add<BConcreteKeySwitchLweOpPattern>(patterns.getContext());
|
||||
patterns.add<BConcreteBootstrapLweOpPattern>(patterns.getContext());
|
||||
patterns.add<BConcreteGlweFromTableOpPattern>(patterns.getContext());
|
||||
}
|
||||
|
||||
struct AddRuntimeContextToFuncOpPattern
|
||||
|
||||
@@ -3,7 +3,5 @@ add_subdirectory(TFHEGlobalParametrization)
|
||||
add_subdirectory(TFHEToConcrete)
|
||||
add_subdirectory(FHETensorOpsToLinalg)
|
||||
add_subdirectory(ConcreteToBConcrete)
|
||||
add_subdirectory(ConcreteToConcreteCAPI)
|
||||
add_subdirectory(BConcreteToBConcreteCAPI)
|
||||
add_subdirectory(MLIRLowerableDialectsToLLVM)
|
||||
add_subdirectory(ConcreteUnparametrize)
|
||||
|
||||
@@ -37,10 +37,19 @@ public:
|
||||
ConcreteToBConcreteTypeConverter() {
|
||||
addConversion([](mlir::Type type) { return type; });
|
||||
addConversion([&](mlir::concretelang::Concrete::LweCiphertextType type) {
|
||||
assert(type.getDimension() != -1);
|
||||
return mlir::RankedTensorType::get(
|
||||
{type.getDimension() + 1},
|
||||
mlir::IntegerType::get(type.getContext(), 64));
|
||||
});
|
||||
addConversion([&](mlir::concretelang::Concrete::GlweCiphertextType type) {
|
||||
assert(type.getGlweDimension() != -1);
|
||||
assert(type.getPolynomialSize() != -1);
|
||||
|
||||
return mlir::RankedTensorType::get(
|
||||
{type.getPolynomialSize() * (type.getGlweDimension() + 1)},
|
||||
mlir::IntegerType::get(type.getContext(), 64));
|
||||
});
|
||||
addConversion([&](mlir::RankedTensorType type) {
|
||||
auto lwe = type.getElementType()
|
||||
.dyn_cast_or_null<
|
||||
@@ -48,6 +57,7 @@ public:
|
||||
if (lwe == nullptr) {
|
||||
return (mlir::Type)(type);
|
||||
}
|
||||
assert(lwe.getDimension() != -1);
|
||||
mlir::SmallVector<int64_t> newShape;
|
||||
newShape.reserve(type.getShape().size() + 1);
|
||||
newShape.append(type.getShape().begin(), type.getShape().end());
|
||||
@@ -63,6 +73,7 @@ public:
|
||||
if (lwe == nullptr) {
|
||||
return (mlir::Type)(type);
|
||||
}
|
||||
assert(lwe.getDimension() != -1);
|
||||
mlir::SmallVector<int64_t> newShape;
|
||||
newShape.reserve(type.getShape().size() + 1);
|
||||
newShape.append(type.getShape().begin(), type.getShape().end());
|
||||
@@ -177,6 +188,65 @@ struct LowToBConcrete : public mlir::OpRewritePattern<ConcreteOp> {
|
||||
};
|
||||
};
|
||||
|
||||
// This rewrite pattern transforms any instance of
|
||||
// `Concrete.glwe_from_table` operators.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// ```mlir
|
||||
// %0 = "Concrete.glwe_from_table"(%tlu)
|
||||
// : (tensor<$Dxi64>) ->
|
||||
// !Concrete.glwe_ciphertext<$polynomialSize,$glweDimension,$p>
|
||||
// ```
|
||||
//
|
||||
// with $D = 2^$p
|
||||
//
|
||||
// becomes:
|
||||
//
|
||||
// ```mlir
|
||||
// %0 = linalg.init_tensor [polynomialSize*(glweDimension+1)]
|
||||
// : tensor<polynomialSize*(glweDimension+1), i64>
|
||||
// "BConcrete.fill_glwe_from_table" : (%0, polynomialSize, glweDimension, %tlu)
|
||||
// : tensor<polynomialSize*(glweDimension+1), i64>, i64, i64, tensor<$Dxi64>
|
||||
// ```
|
||||
struct GlweFromTablePattern : public mlir::OpRewritePattern<
|
||||
mlir::concretelang::Concrete::GlweFromTable> {
|
||||
GlweFromTablePattern(::mlir::MLIRContext *context,
|
||||
mlir::PatternBenefit benefit = 1)
|
||||
: ::mlir::OpRewritePattern<mlir::concretelang::Concrete::GlweFromTable>(
|
||||
context, benefit) {}
|
||||
|
||||
::mlir::LogicalResult
|
||||
matchAndRewrite(mlir::concretelang::Concrete::GlweFromTable op,
|
||||
::mlir::PatternRewriter &rewriter) const override {
|
||||
ConcreteToBConcreteTypeConverter converter;
|
||||
auto resultTy =
|
||||
op.result()
|
||||
.getType()
|
||||
.cast<mlir::concretelang::Concrete::GlweCiphertextType>();
|
||||
|
||||
auto newResultTy =
|
||||
converter.convertType(resultTy).cast<mlir::RankedTensorType>();
|
||||
// %0 = linalg.init_tensor [polynomialSize*(glweDimension+1)]
|
||||
// : tensor<polynomialSize*(glweDimension+1), i64>
|
||||
mlir::Value init = rewriter.replaceOpWithNewOp<mlir::linalg::InitTensorOp>(
|
||||
op, newResultTy.getShape(), newResultTy.getElementType());
|
||||
|
||||
// "BConcrete.fill_glwe_from_table" : (%0, polynomialSize, glweDimension,
|
||||
// %tlu)
|
||||
|
||||
// polynomialSize*(glweDimension+1)
|
||||
auto polySize = resultTy.getPolynomialSize();
|
||||
auto glweDimension = resultTy.getGlweDimension();
|
||||
auto outPrecision = resultTy.getP();
|
||||
|
||||
rewriter.create<mlir::concretelang::BConcrete::FillGlweFromTable>(
|
||||
op.getLoc(), init, polySize, glweDimension, outPrecision, op.table());
|
||||
|
||||
return ::mlir::success();
|
||||
};
|
||||
};
|
||||
|
||||
// This rewrite pattern transforms any instance of
|
||||
// `tensor.extract_slice` operators that operates on tensor of lwe ciphertext.
|
||||
//
|
||||
@@ -827,7 +897,6 @@ void ConcreteToBConcretePass::runOnOperation() {
|
||||
// ciphertexts)
|
||||
target.addIllegalDialect<mlir::concretelang::Concrete::ConcreteDialect>();
|
||||
target.addLegalOp<mlir::concretelang::Concrete::EncodeIntOp>();
|
||||
target.addLegalOp<mlir::concretelang::Concrete::GlweFromTable>();
|
||||
target.addLegalOp<mlir::concretelang::Concrete::IntToCleartextOp>();
|
||||
|
||||
// Add patterns to convert the zero ops to tensor.generate
|
||||
@@ -860,7 +929,10 @@ void ConcreteToBConcretePass::runOnOperation() {
|
||||
mlir::concretelang::BConcrete::BootstrapLweBufferOp>>(
|
||||
&getContext());
|
||||
|
||||
// Add patterns to rewrite tensor operators that works on encrypted tensors
|
||||
patterns.insert<GlweFromTablePattern>(&getContext());
|
||||
|
||||
// Add patterns to rewrite tensor operators that works on encrypted
|
||||
// tensors
|
||||
patterns.insert<ExtractSliceOpPattern, ExtractOpPattern,
|
||||
InsertSliceOpPattern, FromElementsOpPattern>(&getContext());
|
||||
target.addDynamicallyLegalOp<
|
||||
|
||||
@@ -1,16 +0,0 @@
|
||||
add_mlir_dialect_library(ConcreteToConcreteCAPI
|
||||
ConcreteToConcreteCAPI.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${PROJECT_SOURCE_DIR}/include/concretelang/Dialect/FHE
|
||||
|
||||
DEPENDS
|
||||
ConcreteDialect
|
||||
ConcretelangConversionPassIncGen
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRIR
|
||||
MLIRTransforms
|
||||
)
|
||||
|
||||
target_link_libraries(ConcreteToConcreteCAPI PUBLIC MLIRIR)
|
||||
@@ -1,859 +0,0 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#include "mlir//IR/BuiltinTypes.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/SymbolTable.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
#include "concretelang/Conversion/Passes.h"
|
||||
#include "concretelang/Dialect/Concrete/IR/ConcreteDialect.h"
|
||||
#include "concretelang/Dialect/Concrete/IR/ConcreteOps.h"
|
||||
#include "concretelang/Dialect/Concrete/IR/ConcreteTypes.h"
|
||||
#include "concretelang/Support/Constants.h"
|
||||
|
||||
class ConcreteToConcreteCAPITypeConverter : public mlir::TypeConverter {
|
||||
|
||||
public:
|
||||
ConcreteToConcreteCAPITypeConverter() {
|
||||
addConversion([](mlir::Type type) { return type; });
|
||||
addConversion([&](mlir::concretelang::Concrete::PlaintextType type) {
|
||||
return mlir::IntegerType::get(type.getContext(), 64);
|
||||
});
|
||||
addConversion([&](mlir::concretelang::Concrete::CleartextType type) {
|
||||
return mlir::IntegerType::get(type.getContext(), 64);
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
mlir::LogicalResult insertForwardDeclaration(mlir::Operation *op,
|
||||
mlir::RewriterBase &rewriter,
|
||||
llvm::StringRef funcName,
|
||||
mlir::FunctionType funcType) {
|
||||
// Looking for the `funcName` Operation
|
||||
auto module = mlir::SymbolTable::getNearestSymbolTable(op);
|
||||
auto opFunc = mlir::dyn_cast_or_null<mlir::SymbolOpInterface>(
|
||||
mlir::SymbolTable::lookupSymbolIn(module, funcName));
|
||||
if (!opFunc) {
|
||||
// Insert the forward declaration of the funcName
|
||||
mlir::OpBuilder::InsertionGuard guard(rewriter);
|
||||
rewriter.setInsertionPointToStart(&module->getRegion(0).front());
|
||||
|
||||
opFunc = rewriter.create<mlir::FuncOp>(rewriter.getUnknownLoc(), funcName,
|
||||
funcType);
|
||||
opFunc.setPrivate();
|
||||
} else {
|
||||
// Check if the `funcName` is well a private function
|
||||
if (!opFunc.isPrivate()) {
|
||||
op->emitError() << "the function \"" << funcName
|
||||
<< "\" conflicts with the concrete C API, please rename";
|
||||
return mlir::failure();
|
||||
}
|
||||
}
|
||||
assert(mlir::SymbolTable::lookupSymbolIn(module, funcName)
|
||||
->template hasTrait<mlir::OpTrait::FunctionLike>());
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
// Set of functions to generate generic types.
|
||||
// Generic types are used to add forward declarations without a specific type.
|
||||
// For example, we may need to add LWE ciphertext of different dimensions, or
|
||||
// allocate them. All the calls to the C API should be done using this generic
|
||||
// types, and casting should then be performed back to the appropriate type.
|
||||
|
||||
inline mlir::concretelang::Concrete::LweCiphertextType
|
||||
getGenericLweCiphertextType(mlir::MLIRContext *context) {
|
||||
return mlir::concretelang::Concrete::LweCiphertextType::get(context, -1, -1);
|
||||
}
|
||||
|
||||
inline mlir::concretelang::Concrete::GlweCiphertextType
|
||||
getGenericGlweCiphertextType(mlir::MLIRContext *context) {
|
||||
return mlir::concretelang::Concrete::GlweCiphertextType::get(context);
|
||||
}
|
||||
|
||||
inline mlir::concretelang::Concrete::PlaintextType
|
||||
getGenericPlaintextType(mlir::MLIRContext *context) {
|
||||
return mlir::concretelang::Concrete::PlaintextType::get(context, -1);
|
||||
}
|
||||
|
||||
inline mlir::concretelang::Concrete::PlaintextListType
|
||||
getGenericPlaintextListType(mlir::MLIRContext *context) {
|
||||
return mlir::concretelang::Concrete::PlaintextListType::get(context);
|
||||
}
|
||||
|
||||
inline mlir::concretelang::Concrete::ForeignPlaintextListType
|
||||
getGenericForeignPlaintextListType(mlir::MLIRContext *context) {
|
||||
return mlir::concretelang::Concrete::ForeignPlaintextListType::get(context);
|
||||
}
|
||||
|
||||
inline mlir::concretelang::Concrete::CleartextType
|
||||
getGenericCleartextType(mlir::MLIRContext *context) {
|
||||
return mlir::concretelang::Concrete::CleartextType::get(context, -1);
|
||||
}
|
||||
|
||||
inline mlir::concretelang::Concrete::LweBootstrapKeyType
|
||||
getGenericLweBootstrapKeyType(mlir::MLIRContext *context) {
|
||||
return mlir::concretelang::Concrete::LweBootstrapKeyType::get(context);
|
||||
}
|
||||
|
||||
inline mlir::concretelang::Concrete::LweKeySwitchKeyType
|
||||
getGenericLweKeySwitchKeyType(mlir::MLIRContext *context) {
|
||||
return mlir::concretelang::Concrete::LweKeySwitchKeyType::get(context);
|
||||
}
|
||||
|
||||
// Get the generic version of the type.
|
||||
// Useful when iterating over a set of types.
|
||||
mlir::Type getGenericType(mlir::Type baseType) {
|
||||
if (baseType.isa<mlir::concretelang::Concrete::LweCiphertextType>()) {
|
||||
return getGenericLweCiphertextType(baseType.getContext());
|
||||
}
|
||||
if (baseType.isa<mlir::concretelang::Concrete::PlaintextType>()) {
|
||||
return getGenericPlaintextType(baseType.getContext());
|
||||
}
|
||||
if (baseType.isa<mlir::concretelang::Concrete::CleartextType>()) {
|
||||
return getGenericCleartextType(baseType.getContext());
|
||||
}
|
||||
return baseType;
|
||||
}
|
||||
|
||||
// Insert all forward declarations needed for the pass.
|
||||
// Should generalize input and output types for all decalarations, and the
|
||||
// pattern using them would be resposible for casting them to the appropriate
|
||||
// type.
|
||||
mlir::LogicalResult insertForwardDeclarations(mlir::Operation *op,
|
||||
mlir::IRRewriter &rewriter) {
|
||||
auto genericLweCiphertextType =
|
||||
getGenericLweCiphertextType(rewriter.getContext());
|
||||
auto genericGlweCiphertextType =
|
||||
getGenericGlweCiphertextType(rewriter.getContext());
|
||||
auto genericPlaintextType = getGenericPlaintextType(rewriter.getContext());
|
||||
auto genericPlaintextListType =
|
||||
getGenericPlaintextListType(rewriter.getContext());
|
||||
auto genericForeignPlaintextList =
|
||||
getGenericForeignPlaintextListType(rewriter.getContext());
|
||||
auto genericCleartextType = getGenericCleartextType(rewriter.getContext());
|
||||
auto genericBSKType = getGenericLweBootstrapKeyType(rewriter.getContext());
|
||||
auto genericKSKType = getGenericLweKeySwitchKeyType(rewriter.getContext());
|
||||
auto contextType =
|
||||
mlir::concretelang::Concrete::ContextType::get(rewriter.getContext());
|
||||
|
||||
// Insert forward declaration of allocate lwe ciphertext
|
||||
{
|
||||
auto funcType = mlir::FunctionType::get(rewriter.getContext(),
|
||||
{
|
||||
rewriter.getIndexType(),
|
||||
},
|
||||
|
||||
{genericLweCiphertextType});
|
||||
if (insertForwardDeclaration(op, rewriter, "allocate_lwe_ciphertext_u64",
|
||||
funcType)
|
||||
.failed()) {
|
||||
return mlir::failure();
|
||||
}
|
||||
}
|
||||
// Insert forward declaration of the add_lwe_ciphertexts function
|
||||
{
|
||||
auto funcType = mlir::FunctionType::get(rewriter.getContext(),
|
||||
{
|
||||
genericLweCiphertextType,
|
||||
genericLweCiphertextType,
|
||||
genericLweCiphertextType,
|
||||
},
|
||||
{});
|
||||
if (insertForwardDeclaration(op, rewriter, "add_lwe_ciphertexts_u64",
|
||||
funcType)
|
||||
.failed()) {
|
||||
return mlir::failure();
|
||||
}
|
||||
}
|
||||
// Insert forward declaration of the add_plaintext_lwe_ciphertext_u64 function
|
||||
{
|
||||
auto funcType = mlir::FunctionType::get(rewriter.getContext(),
|
||||
{
|
||||
genericLweCiphertextType,
|
||||
genericLweCiphertextType,
|
||||
genericPlaintextType,
|
||||
},
|
||||
{});
|
||||
if (insertForwardDeclaration(op, rewriter,
|
||||
"add_plaintext_lwe_ciphertext_u64", funcType)
|
||||
.failed()) {
|
||||
return mlir::failure();
|
||||
}
|
||||
}
|
||||
// Insert forward declaration of the mul_cleartext_lwe_ciphertext_u64 function
|
||||
{
|
||||
auto funcType = mlir::FunctionType::get(rewriter.getContext(),
|
||||
{
|
||||
genericLweCiphertextType,
|
||||
genericLweCiphertextType,
|
||||
genericCleartextType,
|
||||
},
|
||||
{});
|
||||
if (insertForwardDeclaration(op, rewriter,
|
||||
"mul_cleartext_lwe_ciphertext_u64", funcType)
|
||||
.failed()) {
|
||||
return mlir::failure();
|
||||
}
|
||||
}
|
||||
// Insert forward declaration of the negate_lwe_ciphertext_u64 function
|
||||
{
|
||||
auto funcType = mlir::FunctionType::get(
|
||||
rewriter.getContext(),
|
||||
{genericLweCiphertextType, genericLweCiphertextType}, {});
|
||||
if (insertForwardDeclaration(op, rewriter, "negate_lwe_ciphertext_u64",
|
||||
funcType)
|
||||
.failed()) {
|
||||
return mlir::failure();
|
||||
}
|
||||
}
|
||||
// Insert forward declaration of the getBsk function
|
||||
{
|
||||
auto funcType = mlir::FunctionType::get(rewriter.getContext(),
|
||||
{contextType}, {genericBSKType});
|
||||
if (insertForwardDeclaration(op, rewriter, "get_bootstrap_key", funcType)
|
||||
.failed()) {
|
||||
return mlir::failure();
|
||||
}
|
||||
}
|
||||
// Insert forward declaration of the bootstrap function
|
||||
{
|
||||
auto funcType = mlir::FunctionType::get(rewriter.getContext(),
|
||||
{
|
||||
genericBSKType,
|
||||
genericLweCiphertextType,
|
||||
genericLweCiphertextType,
|
||||
genericGlweCiphertextType,
|
||||
},
|
||||
{});
|
||||
if (insertForwardDeclaration(op, rewriter, "bootstrap_lwe_u64", funcType)
|
||||
.failed()) {
|
||||
return mlir::failure();
|
||||
}
|
||||
}
|
||||
// Insert forward declaration of the getKsk function
|
||||
{
|
||||
auto funcType = mlir::FunctionType::get(rewriter.getContext(),
|
||||
{contextType}, {genericKSKType});
|
||||
if (insertForwardDeclaration(op, rewriter, "get_keyswitch_key", funcType)
|
||||
.failed()) {
|
||||
return mlir::failure();
|
||||
}
|
||||
}
|
||||
// Insert forward declaration of the keyswitch function
|
||||
{
|
||||
auto funcType = mlir::FunctionType::get(rewriter.getContext(),
|
||||
{
|
||||
// ksk
|
||||
genericKSKType,
|
||||
// output ct
|
||||
genericLweCiphertextType,
|
||||
// input ct
|
||||
genericLweCiphertextType,
|
||||
},
|
||||
{});
|
||||
if (insertForwardDeclaration(op, rewriter, "keyswitch_lwe_u64", funcType)
|
||||
.failed()) {
|
||||
return mlir::failure();
|
||||
}
|
||||
}
|
||||
// Insert forward declaration of the alloc_glwe function
|
||||
{
|
||||
auto funcType = mlir::FunctionType::get(rewriter.getContext(),
|
||||
{
|
||||
rewriter.getI32Type(),
|
||||
rewriter.getI32Type(),
|
||||
},
|
||||
{genericGlweCiphertextType});
|
||||
if (insertForwardDeclaration(op, rewriter, "allocate_glwe_ciphertext_u64",
|
||||
funcType)
|
||||
.failed()) {
|
||||
return mlir::failure();
|
||||
}
|
||||
}
|
||||
// Insert forward declaration of the alloc_plaintext_list function
|
||||
{
|
||||
auto funcType =
|
||||
mlir::FunctionType::get(rewriter.getContext(), {rewriter.getI32Type()},
|
||||
{genericPlaintextListType});
|
||||
if (insertForwardDeclaration(op, rewriter, "allocate_plaintext_list_u64",
|
||||
funcType)
|
||||
.failed()) {
|
||||
return mlir::failure();
|
||||
}
|
||||
}
|
||||
// Insert forward declaration of the fill_plaintext_list function
|
||||
{
|
||||
auto funcType = mlir::FunctionType::get(
|
||||
rewriter.getContext(),
|
||||
{genericPlaintextListType, genericForeignPlaintextList}, {});
|
||||
if (insertForwardDeclaration(
|
||||
op, rewriter, "fill_plaintext_list_with_expansion_u64", funcType)
|
||||
.failed()) {
|
||||
return mlir::failure();
|
||||
}
|
||||
}
|
||||
// Insert forward declaration of the add_plaintext_list_glwe function
|
||||
{
|
||||
auto funcType = mlir::FunctionType::get(rewriter.getContext(),
|
||||
{genericGlweCiphertextType,
|
||||
genericGlweCiphertextType,
|
||||
genericPlaintextListType},
|
||||
{});
|
||||
if (insertForwardDeclaration(
|
||||
op, rewriter, "add_plaintext_list_glwe_ciphertext_u64", funcType)
|
||||
.failed()) {
|
||||
return mlir::failure();
|
||||
}
|
||||
}
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
/// ConcreteOpToConcreteCAPICallPattern<Op> match the `Op` Operation and
|
||||
/// replace with a call to `funcName`, the funcName should be an external
|
||||
/// function that was linked later. It insert the forward declaration of the
|
||||
/// private `funcName` if it not already in the symbol table.
|
||||
/// The C signature of the function should be `void funcName(int *err, out,
|
||||
/// arg0, arg1)`, the pattern rewrite:
|
||||
/// ```
|
||||
/// out = op(arg0, arg1)
|
||||
/// ```
|
||||
/// to
|
||||
/// ```
|
||||
/// err = arith.constant 0 : i64
|
||||
/// call_op(err, out, arg0, arg1);
|
||||
/// ```
|
||||
template <typename Op>
|
||||
struct ConcreteOpToConcreteCAPICallPattern : public mlir::OpRewritePattern<Op> {
|
||||
ConcreteOpToConcreteCAPICallPattern(
|
||||
mlir::MLIRContext *context, mlir::StringRef funcName,
|
||||
mlir::StringRef allocName,
|
||||
mlir::PatternBenefit benefit =
|
||||
mlir::concretelang::DEFAULT_PATTERN_BENEFIT)
|
||||
: mlir::OpRewritePattern<Op>(context, benefit), funcName(funcName),
|
||||
allocName(allocName) {}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(Op op, mlir::PatternRewriter &rewriter) const override {
|
||||
ConcreteToConcreteCAPITypeConverter typeConverter;
|
||||
|
||||
mlir::Type resultType = op->getResultTypes().front();
|
||||
auto lweResultType =
|
||||
resultType.cast<mlir::concretelang::Concrete::LweCiphertextType>();
|
||||
// Replace the operation with a call to the `funcName`
|
||||
{
|
||||
// Get the size from the dimension
|
||||
int64_t lweDimension = lweResultType.getDimension();
|
||||
|
||||
mlir::Value lweDimensionOp = rewriter.create<mlir::arith::ConstantOp>(
|
||||
op.getLoc(), rewriter.getIndexAttr(lweDimension));
|
||||
// Add the call to the allocation
|
||||
mlir::SmallVector<mlir::Value> allocOperands{lweDimensionOp};
|
||||
auto allocGeneric = rewriter.create<mlir::CallOp>(
|
||||
op.getLoc(), allocName,
|
||||
getGenericLweCiphertextType(rewriter.getContext()), allocOperands);
|
||||
// Construct operands for the operation.
|
||||
// errOp doesn't need to be casted to something generic, allocGeneric
|
||||
// already is. All the rest will be converted if needed
|
||||
mlir::SmallVector<mlir::Value, 4> newOperands{allocGeneric.getResult(0)};
|
||||
for (mlir::Value operand : op->getOperands()) {
|
||||
mlir::Type operandType = operand.getType();
|
||||
mlir::Type castedType = getGenericType(operandType);
|
||||
if (castedType == operandType) {
|
||||
// Type didn't change, no need for cast
|
||||
newOperands.push_back(operand);
|
||||
} else {
|
||||
// Type changed, need to cast to the generic one
|
||||
auto castedOperand = rewriter
|
||||
.create<mlir::UnrealizedConversionCastOp>(
|
||||
op.getLoc(), castedType, operand)
|
||||
.getResult(0);
|
||||
newOperands.push_back(castedOperand);
|
||||
}
|
||||
}
|
||||
// The operations called here are known to be inplace, and no need for a
|
||||
// return type.
|
||||
rewriter.create<mlir::CallOp>(op.getLoc(), funcName, mlir::TypeRange{},
|
||||
newOperands);
|
||||
// cast result value to the appropriate type
|
||||
rewriter.replaceOpWithNewOp<mlir::UnrealizedConversionCastOp>(
|
||||
op, op.getType(), allocGeneric.getResult(0));
|
||||
}
|
||||
return mlir::success();
|
||||
};
|
||||
|
||||
private:
|
||||
std::string funcName;
|
||||
std::string allocName;
|
||||
};
|
||||
|
||||
struct ConcreteZeroOpPattern
|
||||
: public mlir::OpRewritePattern<mlir::concretelang::Concrete::ZeroLWEOp> {
|
||||
ConcreteZeroOpPattern(mlir::MLIRContext *context,
|
||||
mlir::PatternBenefit benefit =
|
||||
mlir::concretelang::DEFAULT_PATTERN_BENEFIT)
|
||||
: mlir::OpRewritePattern<mlir::concretelang::Concrete::ZeroLWEOp>(
|
||||
context, benefit) {}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(mlir::concretelang::Concrete::ZeroLWEOp op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
|
||||
mlir::Type resultType = op->getResultTypes().front();
|
||||
auto lweResultType =
|
||||
resultType.cast<mlir::concretelang::Concrete::LweCiphertextType>();
|
||||
// Get the size from the dimension
|
||||
int64_t lweDimension = lweResultType.getDimension();
|
||||
|
||||
mlir::Value lweDimensionOp = rewriter.create<mlir::arith::ConstantOp>(
|
||||
op.getLoc(), rewriter.getIndexAttr(lweDimension));
|
||||
// Allocate a fresh new ciphertext
|
||||
mlir::SmallVector<mlir::Value> allocOperands{lweDimensionOp};
|
||||
auto allocGeneric = rewriter.create<mlir::CallOp>(
|
||||
op.getLoc(), "allocate_lwe_ciphertext_u64",
|
||||
getGenericLweCiphertextType(rewriter.getContext()), allocOperands);
|
||||
// Cast the result to the appropriate type
|
||||
rewriter.replaceOpWithNewOp<mlir::UnrealizedConversionCastOp>(
|
||||
op, op.getType(), allocGeneric.getResult(0));
|
||||
|
||||
return mlir::success();
|
||||
};
|
||||
};
|
||||
|
||||
struct ConcreteEncodeIntOpPattern
|
||||
: public mlir::OpRewritePattern<mlir::concretelang::Concrete::EncodeIntOp> {
|
||||
ConcreteEncodeIntOpPattern(mlir::MLIRContext *context,
|
||||
mlir::PatternBenefit benefit =
|
||||
mlir::concretelang::DEFAULT_PATTERN_BENEFIT)
|
||||
: mlir::OpRewritePattern<mlir::concretelang::Concrete::EncodeIntOp>(
|
||||
context, benefit) {}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(mlir::concretelang::Concrete::EncodeIntOp op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
{
|
||||
mlir::Value castedInt = rewriter.create<mlir::arith::ExtUIOp>(
|
||||
op.getLoc(), rewriter.getIntegerType(64), op->getOperands().front());
|
||||
mlir::Value constantShiftOp = rewriter.create<mlir::arith::ConstantOp>(
|
||||
op.getLoc(), rewriter.getI64IntegerAttr(64 - op.getType().getP()));
|
||||
|
||||
mlir::Type resultType = rewriter.getIntegerType(64);
|
||||
rewriter.replaceOpWithNewOp<mlir::arith::ShLIOp>(
|
||||
op, resultType, castedInt, constantShiftOp);
|
||||
}
|
||||
return mlir::success();
|
||||
};
|
||||
};
|
||||
|
||||
struct ConcreteIntToCleartextOpPattern
|
||||
: public mlir::OpRewritePattern<
|
||||
mlir::concretelang::Concrete::IntToCleartextOp> {
|
||||
ConcreteIntToCleartextOpPattern(
|
||||
mlir::MLIRContext *context,
|
||||
mlir::PatternBenefit benefit =
|
||||
mlir::concretelang::DEFAULT_PATTERN_BENEFIT)
|
||||
: mlir::OpRewritePattern<mlir::concretelang::Concrete::IntToCleartextOp>(
|
||||
context, benefit) {}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(mlir::concretelang::Concrete::IntToCleartextOp op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
rewriter.replaceOpWithNewOp<mlir::arith::ExtUIOp>(
|
||||
op, rewriter.getIntegerType(64), op->getOperands().front());
|
||||
return mlir::success();
|
||||
};
|
||||
};
|
||||
|
||||
// Rewrite the GlweFromTable operation to a series of ops:
|
||||
// - allocation of two GLWE, one for the addition, and one for storing the
|
||||
// result
|
||||
// - allocation of plaintext_list to build the GLWE accumulator
|
||||
// - build the foreign_plaintext_list using the input table
|
||||
// - fill the plaintext_list with the foreign_plaintext_list
|
||||
// - construct the GLWE accumulator by adding the plaintext_list to a freshly
|
||||
// allocated GLWE
|
||||
struct GlweFromTableOpPattern
|
||||
: public mlir::OpRewritePattern<
|
||||
mlir::concretelang::Concrete::GlweFromTable> {
|
||||
GlweFromTableOpPattern(mlir::MLIRContext *context,
|
||||
mlir::PatternBenefit benefit =
|
||||
mlir::concretelang::DEFAULT_PATTERN_BENEFIT)
|
||||
: mlir::OpRewritePattern<mlir::concretelang::Concrete::GlweFromTable>(
|
||||
context, benefit) {}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(mlir::concretelang::Concrete::GlweFromTable op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
ConcreteToConcreteCAPITypeConverter typeConverter;
|
||||
|
||||
// TODO: move this to insertForwardDeclarations
|
||||
// issue: can't define function with tensor<*xtype> that accept ranked
|
||||
// tensors
|
||||
|
||||
// Insert forward declaration of the foregin_pt_list function
|
||||
{
|
||||
auto funcType = mlir::FunctionType::get(
|
||||
rewriter.getContext(),
|
||||
{op->getOperandTypes().front(), rewriter.getI32Type()},
|
||||
{getGenericForeignPlaintextListType(rewriter.getContext())});
|
||||
if (insertForwardDeclaration(op, rewriter,
|
||||
"memref_runtime_foreign_plaintext_list_u64",
|
||||
funcType)
|
||||
.failed()) {
|
||||
return mlir::failure();
|
||||
}
|
||||
}
|
||||
|
||||
// allocate two glwe to build accumulator
|
||||
auto polySizeOp = rewriter.create<mlir::arith::ConstantOp>(
|
||||
op.getLoc(), op->getAttr("polynomialSize"));
|
||||
auto glweDimensionOp = rewriter.create<mlir::arith::ConstantOp>(
|
||||
op.getLoc(), op->getAttr("glweDimension"));
|
||||
mlir::SmallVector<mlir::Value> allocGlweOperands{glweDimensionOp,
|
||||
polySizeOp};
|
||||
// first accumulator would replace the op since it's the returned value
|
||||
auto accumulatorOp = rewriter.replaceOpWithNewOp<mlir::CallOp>(
|
||||
op, "allocate_glwe_ciphertext_u64",
|
||||
getGenericGlweCiphertextType(rewriter.getContext()), allocGlweOperands);
|
||||
// second accumulator is just needed to build the actual accumulator
|
||||
auto _accumulatorOp = rewriter.create<mlir::CallOp>(
|
||||
op.getLoc(), "allocate_glwe_ciphertext_u64",
|
||||
getGenericGlweCiphertextType(rewriter.getContext()), allocGlweOperands);
|
||||
// allocate plaintext list
|
||||
mlir::SmallVector<mlir::Value> allocPlaintextListOperands{polySizeOp};
|
||||
auto plaintextListOp = rewriter.create<mlir::CallOp>(
|
||||
op.getLoc(), "allocate_plaintext_list_u64",
|
||||
getGenericPlaintextListType(rewriter.getContext()),
|
||||
allocPlaintextListOperands);
|
||||
// create foreign plaintext
|
||||
auto rankedTensorType =
|
||||
op->getOperandTypes().front().cast<mlir::RankedTensorType>();
|
||||
assert(rankedTensorType.getRank() == 1 &&
|
||||
"table lookup must be of a single dimension");
|
||||
auto precisionOp =
|
||||
rewriter.create<mlir::arith::ConstantOp>(op.getLoc(), op->getAttr("p"));
|
||||
mlir::SmallVector<mlir::Value> ForeignPlaintextListOperands{
|
||||
op->getOperand(0), precisionOp};
|
||||
auto foreignPlaintextListOp = rewriter.create<mlir::CallOp>(
|
||||
op.getLoc(), "memref_runtime_foreign_plaintext_list_u64",
|
||||
getGenericForeignPlaintextListType(rewriter.getContext()),
|
||||
ForeignPlaintextListOperands);
|
||||
// fill plaintext list
|
||||
mlir::SmallVector<mlir::Value> FillPlaintextListOperands{
|
||||
plaintextListOp.getResult(0), foreignPlaintextListOp.getResult(0)};
|
||||
rewriter.create<mlir::CallOp>(
|
||||
op.getLoc(), "fill_plaintext_list_with_expansion_u64",
|
||||
mlir::TypeRange({}), FillPlaintextListOperands);
|
||||
// add plaintext list and glwe to build final accumulator for pbs
|
||||
mlir::SmallVector<mlir::Value> AddPlaintextListGlweOperands{
|
||||
accumulatorOp.getResult(0), _accumulatorOp.getResult(0),
|
||||
plaintextListOp.getResult(0)};
|
||||
rewriter.create<mlir::CallOp>(
|
||||
op.getLoc(), "add_plaintext_list_glwe_ciphertext_u64",
|
||||
mlir::TypeRange({}), AddPlaintextListGlweOperands);
|
||||
return mlir::success();
|
||||
};
|
||||
};
|
||||
|
||||
mlir::Value getContextArgument(mlir::Operation *op) {
|
||||
mlir::Block *block = op->getBlock();
|
||||
while (block != nullptr) {
|
||||
if (llvm::isa<mlir::FuncOp>(block->getParentOp())) {
|
||||
|
||||
mlir::Value context = block->getArguments().back();
|
||||
|
||||
assert(
|
||||
context.getType().isa<mlir::concretelang::Concrete::ContextType>() &&
|
||||
"the Concrete.context should be the last argument of the enclosing "
|
||||
"function of the op");
|
||||
|
||||
return context;
|
||||
}
|
||||
block = block->getParentOp()->getBlock();
|
||||
}
|
||||
assert("can't find a function that enclose the op");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Rewrite a BootstrapLweOp with a series of ops:
|
||||
// - allocate the result LWE ciphertext
|
||||
// - get the global bootstrapping key
|
||||
// - use the key and the input accumulator (GLWE) to bootstrap the input
|
||||
// ciphertext
|
||||
struct ConcreteBootstrapLweOpPattern
|
||||
: public mlir::OpRewritePattern<
|
||||
mlir::concretelang::Concrete::BootstrapLweOp> {
|
||||
ConcreteBootstrapLweOpPattern(mlir::MLIRContext *context,
|
||||
mlir::PatternBenefit benefit =
|
||||
mlir::concretelang::DEFAULT_PATTERN_BENEFIT)
|
||||
: mlir::OpRewritePattern<mlir::concretelang::Concrete::BootstrapLweOp>(
|
||||
context, benefit) {}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(mlir::concretelang::Concrete::BootstrapLweOp op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
auto resultType = op->getResultTypes().front();
|
||||
// Get the size from the dimension
|
||||
int64_t outputLweDimension =
|
||||
resultType.cast<mlir::concretelang::Concrete::LweCiphertextType>()
|
||||
.getDimension();
|
||||
mlir::Value lweSizeOp = rewriter.create<mlir::arith::ConstantOp>(
|
||||
op.getLoc(), rewriter.getIndexAttr(outputLweDimension));
|
||||
// allocate the result lwe ciphertext, should be of a generic type, to cast
|
||||
// before return
|
||||
mlir::SmallVector<mlir::Value> allocLweCtOperands{lweSizeOp};
|
||||
auto allocateGenericLweCtOp = rewriter.create<mlir::CallOp>(
|
||||
op.getLoc(), "allocate_lwe_ciphertext_u64",
|
||||
getGenericLweCiphertextType(rewriter.getContext()), allocLweCtOperands);
|
||||
// get bsk
|
||||
auto getBskOp = rewriter.create<mlir::CallOp>(
|
||||
op.getLoc(), "get_bootstrap_key",
|
||||
getGenericLweBootstrapKeyType(rewriter.getContext()),
|
||||
mlir::SmallVector<mlir::Value>{getContextArgument(op)});
|
||||
// bootstrap
|
||||
// cast input ciphertext to a generic type
|
||||
mlir::Value lweToBootstrap =
|
||||
rewriter
|
||||
.create<mlir::UnrealizedConversionCastOp>(
|
||||
op.getLoc(), getGenericType(op.getOperand(0).getType()),
|
||||
op.getOperand(0))
|
||||
.getResult(0);
|
||||
// cast input accumulator to a generic type
|
||||
mlir::Value accumulator =
|
||||
rewriter
|
||||
.create<mlir::UnrealizedConversionCastOp>(
|
||||
op.getLoc(), getGenericType(op.getOperand(1).getType()),
|
||||
op.getOperand(1))
|
||||
.getResult(0);
|
||||
mlir::SmallVector<mlir::Value> bootstrapOperands{
|
||||
getBskOp.getResult(0), allocateGenericLweCtOp.getResult(0),
|
||||
lweToBootstrap, accumulator};
|
||||
rewriter.create<mlir::CallOp>(op.getLoc(), "bootstrap_lwe_u64",
|
||||
mlir::TypeRange({}), bootstrapOperands);
|
||||
// Cast result to the appropriate type
|
||||
rewriter.replaceOpWithNewOp<mlir::UnrealizedConversionCastOp>(
|
||||
op, resultType, allocateGenericLweCtOp.getResult(0));
|
||||
|
||||
return mlir::success();
|
||||
};
|
||||
};
|
||||
|
||||
// Rewrite a KeySwitchLweOp with a series of ops:
|
||||
// - allocate the result LWE ciphertext
|
||||
// - get the global keyswitch key
|
||||
// - use the key to keyswitch the input ciphertext
|
||||
struct ConcreteKeySwitchLweOpPattern
|
||||
: public mlir::OpRewritePattern<
|
||||
mlir::concretelang::Concrete::KeySwitchLweOp> {
|
||||
ConcreteKeySwitchLweOpPattern(mlir::MLIRContext *context,
|
||||
mlir::PatternBenefit benefit =
|
||||
mlir::concretelang::DEFAULT_PATTERN_BENEFIT)
|
||||
: mlir::OpRewritePattern<mlir::concretelang::Concrete::KeySwitchLweOp>(
|
||||
context, benefit) {}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(mlir::concretelang::Concrete::KeySwitchLweOp op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
// Get the size from the dimension
|
||||
int64_t lweDimension =
|
||||
op.getResult()
|
||||
.getType()
|
||||
.cast<mlir::concretelang::Concrete::LweCiphertextType>()
|
||||
.getDimension();
|
||||
mlir::Value lweDimensionOp = rewriter.create<mlir::arith::ConstantOp>(
|
||||
op.getLoc(), rewriter.getIndexAttr(lweDimension));
|
||||
// allocate the result lwe ciphertext, should be of a generic type, to cast
|
||||
// before return
|
||||
mlir::SmallVector<mlir::Value> allocLweCtOperands{lweDimensionOp};
|
||||
auto allocateGenericLweCtOp = rewriter.create<mlir::CallOp>(
|
||||
op.getLoc(), "allocate_lwe_ciphertext_u64",
|
||||
getGenericLweCiphertextType(rewriter.getContext()), allocLweCtOperands);
|
||||
// get ksk
|
||||
auto getKskOp = rewriter.create<mlir::CallOp>(
|
||||
op.getLoc(), "get_keyswitch_key",
|
||||
getGenericLweKeySwitchKeyType(rewriter.getContext()),
|
||||
mlir::SmallVector<mlir::Value>{getContextArgument(op)});
|
||||
// keyswitch
|
||||
// cast input ciphertext to a generic type
|
||||
mlir::Value lweToKeyswitch =
|
||||
rewriter
|
||||
.create<mlir::UnrealizedConversionCastOp>(
|
||||
op.getLoc(), getGenericType(op.getOperand().getType()),
|
||||
op.getOperand())
|
||||
.getResult(0);
|
||||
mlir::SmallVector<mlir::Value> keyswitchOperands{
|
||||
getKskOp.getResult(0), allocateGenericLweCtOp.getResult(0),
|
||||
lweToKeyswitch};
|
||||
rewriter.create<mlir::CallOp>(op.getLoc(), "keyswitch_lwe_u64",
|
||||
mlir::TypeRange({}), keyswitchOperands);
|
||||
// Cast result to the appropriate type
|
||||
auto lweOutputType = op->getResultTypes().front();
|
||||
rewriter.replaceOpWithNewOp<mlir::UnrealizedConversionCastOp>(
|
||||
op, lweOutputType, allocateGenericLweCtOp.getResult(0));
|
||||
return mlir::success();
|
||||
};
|
||||
};
|
||||
|
||||
/// Populate the RewritePatternSet with all patterns that rewrite Concrete
|
||||
/// operators to the corresponding function call to the `Concrete C API`.
|
||||
void populateConcreteToConcreteCAPICall(mlir::RewritePatternSet &patterns) {
|
||||
patterns.add<ConcreteOpToConcreteCAPICallPattern<
|
||||
mlir::concretelang::Concrete::AddLweCiphertextsOp>>(
|
||||
patterns.getContext(), "add_lwe_ciphertexts_u64",
|
||||
"allocate_lwe_ciphertext_u64");
|
||||
patterns.add<ConcreteOpToConcreteCAPICallPattern<
|
||||
mlir::concretelang::Concrete::AddPlaintextLweCiphertextOp>>(
|
||||
patterns.getContext(), "add_plaintext_lwe_ciphertext_u64",
|
||||
"allocate_lwe_ciphertext_u64");
|
||||
patterns.add<ConcreteOpToConcreteCAPICallPattern<
|
||||
mlir::concretelang::Concrete::MulCleartextLweCiphertextOp>>(
|
||||
patterns.getContext(), "mul_cleartext_lwe_ciphertext_u64",
|
||||
"allocate_lwe_ciphertext_u64");
|
||||
patterns.add<ConcreteOpToConcreteCAPICallPattern<
|
||||
mlir::concretelang::Concrete::NegateLweCiphertextOp>>(
|
||||
patterns.getContext(), "negate_lwe_ciphertext_u64",
|
||||
"allocate_lwe_ciphertext_u64");
|
||||
patterns.add<ConcreteEncodeIntOpPattern>(patterns.getContext());
|
||||
patterns.add<ConcreteIntToCleartextOpPattern>(patterns.getContext());
|
||||
patterns.add<ConcreteZeroOpPattern>(patterns.getContext());
|
||||
patterns.add<GlweFromTableOpPattern>(patterns.getContext());
|
||||
patterns.add<ConcreteKeySwitchLweOpPattern>(patterns.getContext());
|
||||
patterns.add<ConcreteBootstrapLweOpPattern>(patterns.getContext());
|
||||
}
|
||||
|
||||
struct AddRuntimeContextToFuncOpPattern
|
||||
: public mlir::OpRewritePattern<mlir::FuncOp> {
|
||||
AddRuntimeContextToFuncOpPattern(
|
||||
mlir::MLIRContext *context,
|
||||
mlir::PatternBenefit benefit =
|
||||
mlir::concretelang::DEFAULT_PATTERN_BENEFIT)
|
||||
: mlir::OpRewritePattern<mlir::FuncOp>(context, benefit) {}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(mlir::FuncOp oldFuncOp,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
mlir::OpBuilder::InsertionGuard guard(rewriter);
|
||||
mlir::FunctionType oldFuncType = oldFuncOp.getType();
|
||||
|
||||
// Add a Concrete.context to the function signature
|
||||
mlir::SmallVector<mlir::Type> newInputs(oldFuncType.getInputs().begin(),
|
||||
oldFuncType.getInputs().end());
|
||||
newInputs.push_back(
|
||||
rewriter.getType<mlir::concretelang::Concrete::ContextType>());
|
||||
mlir::FunctionType newFuncTy = rewriter.getType<mlir::FunctionType>(
|
||||
newInputs, oldFuncType.getResults());
|
||||
// Create the new func
|
||||
mlir::FuncOp newFuncOp = rewriter.create<mlir::FuncOp>(
|
||||
oldFuncOp.getLoc(), oldFuncOp.getName(), newFuncTy);
|
||||
|
||||
// Create the arguments of the new func
|
||||
mlir::Region &newFuncBody = newFuncOp.body();
|
||||
mlir::Block *newFuncEntryBlock = new mlir::Block();
|
||||
newFuncEntryBlock->addArguments(newFuncTy.getInputs());
|
||||
newFuncBody.push_back(newFuncEntryBlock);
|
||||
|
||||
// Clone the old body to the new one
|
||||
mlir::BlockAndValueMapping map;
|
||||
for (auto arg : llvm::enumerate(oldFuncOp.getArguments())) {
|
||||
map.map(arg.value(), newFuncEntryBlock->getArgument(arg.index()));
|
||||
}
|
||||
for (auto &op : oldFuncOp.body().front()) {
|
||||
newFuncEntryBlock->push_back(op.clone(map));
|
||||
}
|
||||
rewriter.eraseOp(oldFuncOp);
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
// Legal function are one that are private or has a Concrete.context as last
|
||||
// arguments.
|
||||
static bool isLegal(mlir::FuncOp funcOp) {
|
||||
if (!funcOp.isPublic()) {
|
||||
return true;
|
||||
}
|
||||
// TODO : Don't need to add a runtime context for function that doesn't
|
||||
// manipulates concrete types.
|
||||
//
|
||||
// if (!llvm::any_of(funcOp.getType().getInputs(), [](mlir::Type t) {
|
||||
// if (auto tensorTy = t.dyn_cast_or_null<mlir::TensorType>()) {
|
||||
// t = tensorTy.getElementType();
|
||||
// }
|
||||
// return llvm::isa<mlir::concretelang::Concrete::ConcreteDialect>(
|
||||
// t.getDialect());
|
||||
// })) {
|
||||
// return true;
|
||||
// }
|
||||
return funcOp.getType().getNumInputs() >= 1 &&
|
||||
funcOp.getType()
|
||||
.getInputs()
|
||||
.back()
|
||||
.isa<mlir::concretelang::Concrete::ContextType>();
|
||||
}
|
||||
};
|
||||
|
||||
namespace {
|
||||
struct ConcreteToConcreteCAPIPass
|
||||
: public ConcreteToConcreteCAPIBase<ConcreteToConcreteCAPIPass> {
|
||||
void runOnOperation() final;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
void ConcreteToConcreteCAPIPass::runOnOperation() {
|
||||
mlir::ModuleOp op = getOperation();
|
||||
|
||||
// First of all add the Concrete.context to the block arguments of function
|
||||
// that manipulates ciphertexts.
|
||||
{
|
||||
mlir::ConversionTarget target(getContext());
|
||||
mlir::RewritePatternSet patterns(&getContext());
|
||||
|
||||
target.addDynamicallyLegalOp<mlir::FuncOp>([&](mlir::FuncOp funcOp) {
|
||||
return AddRuntimeContextToFuncOpPattern::isLegal(funcOp);
|
||||
});
|
||||
|
||||
patterns.add<AddRuntimeContextToFuncOpPattern>(patterns.getContext());
|
||||
|
||||
// Apply the conversion
|
||||
if (mlir::applyPartialConversion(op, target, std::move(patterns))
|
||||
.failed()) {
|
||||
this->signalPassFailure();
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// Insert forward declaration
|
||||
mlir::IRRewriter rewriter(&getContext());
|
||||
if (insertForwardDeclarations(op, rewriter).failed()) {
|
||||
this->signalPassFailure();
|
||||
}
|
||||
// Rewrite Concrete ops to CallOp to the Concrete C API
|
||||
{
|
||||
mlir::ConversionTarget target(getContext());
|
||||
mlir::RewritePatternSet patterns(&getContext());
|
||||
|
||||
target.addIllegalDialect<mlir::concretelang::Concrete::ConcreteDialect>();
|
||||
target.addLegalDialect<mlir::BuiltinDialect, mlir::StandardOpsDialect,
|
||||
mlir::memref::MemRefDialect,
|
||||
mlir::arith::ArithmeticDialect>();
|
||||
|
||||
populateConcreteToConcreteCAPICall(patterns);
|
||||
|
||||
if (mlir::applyPartialConversion(op, target, std::move(patterns))
|
||||
.failed()) {
|
||||
this->signalPassFailure();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
createConvertConcreteToConcreteCAPIPass() {
|
||||
return std::make_unique<ConcreteToConcreteCAPIPass>();
|
||||
}
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
@@ -1,16 +0,0 @@
|
||||
add_mlir_dialect_library(ConcreteUnparametrize
|
||||
ConcreteUnparametrize.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${PROJECT_SOURCE_DIR}/include/concretelang/Dialect/FHE
|
||||
|
||||
DEPENDS
|
||||
ConcreteDialect
|
||||
ConcretelangConversionPassIncGen
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRIR
|
||||
MLIRTransforms
|
||||
)
|
||||
|
||||
target_link_libraries(ConcreteUnparametrize PUBLIC MLIRIR)
|
||||
@@ -1,154 +0,0 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
#include "concretelang/Conversion/Passes.h"
|
||||
#include "concretelang/Conversion/Utils/RegionOpTypeConverterPattern.h"
|
||||
#include "concretelang/Conversion/Utils/TensorOpTypeConversion.h"
|
||||
#include "concretelang/Dialect/Concrete/IR/ConcreteDialect.h"
|
||||
#include "concretelang/Dialect/Concrete/IR/ConcreteOps.h"
|
||||
#include "concretelang/Dialect/Concrete/IR/ConcreteTypes.h"
|
||||
#include "concretelang/Dialect/RT/IR/RTOps.h"
|
||||
#include "concretelang/Support/Constants.h"
|
||||
|
||||
/// ConcreteUnparametrizeTypeConverter is a type converter that unparametrize
|
||||
/// Concrete types
|
||||
class ConcreteUnparametrizeTypeConverter : public mlir::TypeConverter {
|
||||
|
||||
public:
|
||||
static mlir::Type unparematrizeConcreteType(mlir::Type type) {
|
||||
if (type.isa<mlir::concretelang::Concrete::PlaintextType>()) {
|
||||
return mlir::IntegerType::get(type.getContext(), 64);
|
||||
}
|
||||
if (type.isa<mlir::concretelang::Concrete::CleartextType>()) {
|
||||
return mlir::IntegerType::get(type.getContext(), 64);
|
||||
}
|
||||
if (type.isa<mlir::concretelang::Concrete::LweCiphertextType>()) {
|
||||
return mlir::concretelang::Concrete::LweCiphertextType::get(
|
||||
type.getContext(), -1, -1);
|
||||
}
|
||||
auto tensorType = type.dyn_cast_or_null<mlir::RankedTensorType>();
|
||||
if (tensorType != nullptr) {
|
||||
auto eltTy0 = tensorType.getElementType();
|
||||
auto eltTy1 = unparematrizeConcreteType(eltTy0);
|
||||
if (eltTy0 == eltTy1) {
|
||||
return type;
|
||||
}
|
||||
return mlir::RankedTensorType::get(tensorType.getShape(), eltTy1);
|
||||
}
|
||||
return type;
|
||||
}
|
||||
|
||||
ConcreteUnparametrizeTypeConverter() {
|
||||
addConversion(
|
||||
[](mlir::Type type) { return unparematrizeConcreteType(type); });
|
||||
}
|
||||
};
|
||||
|
||||
/// Replace `%1 = unrealized_conversion_cast %0 : t0 to t1` to `%0` where t0 or
|
||||
/// t1 are a Concrete type.
|
||||
struct ConcreteUnrealizedCastReplacementPattern
|
||||
: public mlir::OpRewritePattern<mlir::UnrealizedConversionCastOp> {
|
||||
ConcreteUnrealizedCastReplacementPattern(
|
||||
mlir::MLIRContext *context,
|
||||
mlir::PatternBenefit benefit =
|
||||
mlir::concretelang::DEFAULT_PATTERN_BENEFIT)
|
||||
: mlir::OpRewritePattern<mlir::UnrealizedConversionCastOp>(context,
|
||||
benefit) {}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(mlir::UnrealizedConversionCastOp op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
if (mlir::isa<mlir::concretelang::Concrete::ConcreteDialect>(
|
||||
op.getOperandTypes()[0].getDialect()) ||
|
||||
mlir::isa<mlir::concretelang::Concrete::ConcreteDialect>(
|
||||
op.getType(0).getDialect())) {
|
||||
rewriter.replaceOp(op, op.getOperands());
|
||||
return mlir::success();
|
||||
}
|
||||
return mlir::failure();
|
||||
};
|
||||
};
|
||||
|
||||
/// ConcreteUnparametrizePass remove all parameters of Concrete types and remove
|
||||
/// the unrealized_conversion_cast operation that operates on parametrized
|
||||
/// Concrete types.
|
||||
struct ConcreteUnparametrizePass
|
||||
: public ConcreteUnparametrizeBase<ConcreteUnparametrizePass> {
|
||||
void runOnOperation() final;
|
||||
};
|
||||
|
||||
void ConcreteUnparametrizePass::runOnOperation() {
|
||||
auto op = this->getOperation();
|
||||
|
||||
mlir::ConversionTarget target(getContext());
|
||||
mlir::OwningRewritePatternList patterns(&getContext());
|
||||
|
||||
ConcreteUnparametrizeTypeConverter converter;
|
||||
|
||||
// Conversion of linalg.generic operation
|
||||
target
|
||||
.addDynamicallyLegalOp<mlir::linalg::GenericOp, mlir::tensor::GenerateOp>(
|
||||
[&](mlir::Operation *op) {
|
||||
return (
|
||||
converter.isLegal(op->getOperandTypes()) &&
|
||||
converter.isLegal(op->getResultTypes()) &&
|
||||
converter.isLegal(op->getRegion(0).front().getArgumentTypes()));
|
||||
});
|
||||
patterns.add<RegionOpTypeConverterPattern<
|
||||
mlir::linalg::GenericOp, ConcreteUnparametrizeTypeConverter>>(
|
||||
&getContext(), converter);
|
||||
patterns.add<RegionOpTypeConverterPattern<
|
||||
mlir::tensor::GenerateOp, ConcreteUnparametrizeTypeConverter>>(
|
||||
&getContext(), converter);
|
||||
patterns.add<RegionOpTypeConverterPattern<
|
||||
mlir::scf::ForOp, ConcreteUnparametrizeTypeConverter>>(&getContext(),
|
||||
converter);
|
||||
|
||||
// Conversion of function signature and arguments
|
||||
target.addDynamicallyLegalOp<mlir::FuncOp>([&](mlir::FuncOp funcOp) {
|
||||
return converter.isSignatureLegal(funcOp.getType()) &&
|
||||
converter.isLegal(&funcOp.getBody());
|
||||
});
|
||||
mlir::populateFuncOpTypeConversionPattern(patterns, converter);
|
||||
|
||||
// Replacement of unrealized_conversion_cast
|
||||
mlir::concretelang::addDynamicallyLegalTypeOp<
|
||||
mlir::UnrealizedConversionCastOp>(target, converter);
|
||||
patterns.add<ConcreteUnrealizedCastReplacementPattern>(patterns.getContext());
|
||||
|
||||
// Conversion of tensor operators
|
||||
mlir::concretelang::populateWithTensorTypeConverterPatterns(patterns, target,
|
||||
converter);
|
||||
|
||||
// Conversion of CallOp
|
||||
patterns.add<mlir::concretelang::GenericTypeConverterPattern<mlir::CallOp>>(
|
||||
patterns.getContext(), converter);
|
||||
mlir::concretelang::addDynamicallyLegalTypeOp<mlir::CallOp>(target,
|
||||
converter);
|
||||
|
||||
// Conversion of RT Dialect Ops
|
||||
patterns.add<mlir::concretelang::GenericTypeConverterPattern<
|
||||
mlir::concretelang::RT::DataflowTaskOp>>(patterns.getContext(),
|
||||
converter);
|
||||
mlir::concretelang::addDynamicallyLegalTypeOp<
|
||||
mlir::concretelang::RT::DataflowTaskOp>(target, converter);
|
||||
|
||||
// Apply conversion
|
||||
if (mlir::applyPartialConversion(op, target, std::move(patterns)).failed()) {
|
||||
this->signalPassFailure();
|
||||
}
|
||||
}
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
createConvertConcreteUnparametrizePass() {
|
||||
return std::make_unique<ConcreteUnparametrizePass>();
|
||||
}
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
@@ -4,7 +4,7 @@ endif()
|
||||
|
||||
add_library(ConcretelangRuntime SHARED
|
||||
context.cpp
|
||||
wrappers.c
|
||||
wrappers.cpp
|
||||
)
|
||||
|
||||
if(CONCRETELANG_PARALLEL_EXECUTION_ENABLED)
|
||||
|
||||
@@ -11,39 +11,32 @@
|
||||
#include <hpx/include/runtime.hpp>
|
||||
#endif
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
|
||||
std::string RuntimeContext::BASE_CONTEXT_BSK = "_concretelang_base_context_bsk";
|
||||
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
LweKeyswitchKey_u64 *
|
||||
get_keyswitch_key(mlir::concretelang::RuntimeContext *context) {
|
||||
get_keyswitch_key_u64(mlir::concretelang::RuntimeContext *context) {
|
||||
return context->ksk;
|
||||
}
|
||||
|
||||
LweBootstrapKey_u64 *
|
||||
get_bootstrap_key(mlir::concretelang::RuntimeContext *context) {
|
||||
using RuntimeContext = mlir::concretelang::RuntimeContext;
|
||||
get_bootstrap_key_u64(mlir::concretelang::RuntimeContext *context) {
|
||||
return context->bsk;
|
||||
}
|
||||
|
||||
// Instantiate one engine per thread on demand
|
||||
Engine *get_engine(mlir::concretelang::RuntimeContext *context) {
|
||||
#ifdef CONCRETELANG_PARALLEL_EXECUTION_ENABLED
|
||||
std::string threadName = hpx::get_thread_name();
|
||||
auto bskIt = context->bsk.find(threadName);
|
||||
if (bskIt == context->bsk.end()) {
|
||||
assert((bskIt = context->bsk.find(RuntimeContext::BASE_CONTEXT_BSK)) !=
|
||||
context->bsk.end() &&
|
||||
bskIt->second && "No BASE_CONTEXT_BSK registered in context.");
|
||||
bskIt = context->bsk
|
||||
.insert(std::pair<std::string, LweBootstrapKey_u64 *>(
|
||||
threadName,
|
||||
clone_lwe_bootstrap_key_u64(
|
||||
context->bsk[RuntimeContext::BASE_CONTEXT_BSK])))
|
||||
.first;
|
||||
std::lock_guard<std::mutex> guard(context->engines_map_guard);
|
||||
auto engineIt = context->engines.find(threadName);
|
||||
if (engineIt == context->engines.end()) {
|
||||
engineIt =
|
||||
context->engines
|
||||
.insert(std::pair<std::string, Engine *>(threadName, new_engine()))
|
||||
.first;
|
||||
}
|
||||
assert(engineIt->second && "No engine available in context");
|
||||
return engineIt->second;
|
||||
#else
|
||||
auto bskIt = context->bsk.find(RuntimeContext::BASE_CONTEXT_BSK);
|
||||
return (context->engine == nullptr) ? context->engine = new_engine()
|
||||
: context->engine;
|
||||
#endif
|
||||
assert(bskIt->second && "No bootstrap key available in context");
|
||||
return bskIt->second;
|
||||
}
|
||||
|
||||
@@ -1,21 +1,29 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#include "concretelang/Runtime/wrappers.h"
|
||||
#include <assert.h>
|
||||
#include <stdio.h>
|
||||
|
||||
struct ForeignPlaintextList_u64 *memref_runtime_foreign_plaintext_list_u64(
|
||||
uint64_t *allocated, uint64_t *aligned, uint64_t offset, uint64_t size,
|
||||
uint64_t stride, uint32_t precision) {
|
||||
void memref_expand_lut_in_trivial_glwe_ct_u64(
|
||||
uint64_t *glwe_ct_allocated, uint64_t *glwe_ct_aligned,
|
||||
uint64_t glwe_ct_offset, uint64_t glwe_ct_size, uint64_t glwe_ct_stride,
|
||||
uint32_t poly_size, uint32_t glwe_dimension, uint32_t out_precision,
|
||||
uint64_t *lut_allocated, uint64_t *lut_aligned, uint64_t lut_offset,
|
||||
uint64_t lut_size, uint64_t lut_stride) {
|
||||
|
||||
assert(stride == 1 && "Runtime: stride not equal to 1, check "
|
||||
"runtime_foreign_plaintext_list_u64");
|
||||
assert(lut_stride == 1 && "Runtime: stride not equal to 1, check "
|
||||
"memref_expand_lut_in_trivial_glwe_ct_u64");
|
||||
|
||||
// Encode table values in u64
|
||||
uint64_t *encoded_table = malloc(size * sizeof(uint64_t));
|
||||
for (uint64_t i = 0; i < size; i++) {
|
||||
encoded_table[i] = (aligned + offset)[i] << (64 - precision - 1);
|
||||
}
|
||||
return foreign_plaintext_list_u64(encoded_table, size);
|
||||
// TODO: is it safe to free after creating plaintext_list?
|
||||
assert(glwe_ct_stride == 1 && "Runtime: stride not equal to 1, check "
|
||||
"memref_expand_lut_in_trivial_glwe_ct_u64");
|
||||
|
||||
expand_lut_in_trivial_glwe_ct_u64(glwe_ct_aligned, poly_size, glwe_dimension,
|
||||
out_precision, lut_aligned, lut_size);
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
void memref_add_lwe_ciphertexts_u64(
|
||||
@@ -26,7 +34,7 @@ void memref_add_lwe_ciphertexts_u64(
|
||||
uint64_t ct1_offset, uint64_t ct1_size, uint64_t ct1_stride) {
|
||||
assert(out_size == ct0_size && out_size == ct1_size &&
|
||||
"size of lwe buffer are incompatible");
|
||||
LweDimension lwe_dimension = {out_size - 1};
|
||||
size_t lwe_dimension = {out_size - 1};
|
||||
add_two_lwe_ciphertexts_u64(out_aligned + out_offset,
|
||||
ct0_aligned + ct0_offset,
|
||||
ct1_aligned + ct1_offset, lwe_dimension);
|
||||
@@ -38,7 +46,7 @@ void memref_add_plaintext_lwe_ciphertext_u64(
|
||||
uint64_t *ct0_aligned, uint64_t ct0_offset, uint64_t ct0_size,
|
||||
uint64_t ct0_stride, uint64_t plaintext) {
|
||||
assert(out_size == ct0_size && "size of lwe buffer are incompatible");
|
||||
LweDimension lwe_dimension = {out_size - 1};
|
||||
size_t lwe_dimension = {out_size - 1};
|
||||
add_plaintext_to_lwe_ciphertext_u64(out_aligned + out_offset,
|
||||
ct0_aligned + ct0_offset, plaintext,
|
||||
lwe_dimension);
|
||||
@@ -50,7 +58,7 @@ void memref_mul_cleartext_lwe_ciphertext_u64(
|
||||
uint64_t *ct0_aligned, uint64_t ct0_offset, uint64_t ct0_size,
|
||||
uint64_t ct0_stride, uint64_t cleartext) {
|
||||
assert(out_size == ct0_size && "size of lwe buffer are incompatible");
|
||||
LweDimension lwe_dimension = {out_size - 1};
|
||||
size_t lwe_dimension = {out_size - 1};
|
||||
mul_cleartext_lwe_ciphertext_u64(out_aligned + out_offset,
|
||||
ct0_aligned + ct0_offset, cleartext,
|
||||
lwe_dimension);
|
||||
@@ -62,28 +70,29 @@ void memref_negate_lwe_ciphertext_u64(
|
||||
uint64_t *ct0_aligned, uint64_t ct0_offset, uint64_t ct0_size,
|
||||
uint64_t ct0_stride) {
|
||||
assert(out_size == ct0_size && "size of lwe buffer are incompatible");
|
||||
LweDimension lwe_dimension = {out_size - 1};
|
||||
size_t lwe_dimension = {out_size - 1};
|
||||
neg_lwe_ciphertext_u64(out_aligned + out_offset, ct0_aligned + ct0_offset,
|
||||
lwe_dimension);
|
||||
}
|
||||
|
||||
void memref_keyswitch_lwe_u64(struct LweKeyswitchKey_u64 *keyswitch_key,
|
||||
uint64_t *out_allocated, uint64_t *out_aligned,
|
||||
uint64_t out_offset, uint64_t out_size,
|
||||
uint64_t out_stride, uint64_t *ct0_allocated,
|
||||
uint64_t *ct0_aligned, uint64_t ct0_offset,
|
||||
uint64_t ct0_size, uint64_t ct0_stride) {
|
||||
bufferized_keyswitch_lwe_u64(keyswitch_key, out_aligned + out_offset,
|
||||
ct0_aligned + ct0_offset);
|
||||
}
|
||||
|
||||
void memref_bootstrap_lwe_u64(struct LweBootstrapKey_u64 *bootstrap_key,
|
||||
uint64_t *out_allocated, uint64_t *out_aligned,
|
||||
void memref_keyswitch_lwe_u64(uint64_t *out_allocated, uint64_t *out_aligned,
|
||||
uint64_t out_offset, uint64_t out_size,
|
||||
uint64_t out_stride, uint64_t *ct0_allocated,
|
||||
uint64_t *ct0_aligned, uint64_t ct0_offset,
|
||||
uint64_t ct0_size, uint64_t ct0_stride,
|
||||
struct GlweCiphertext_u64 *accumulator) {
|
||||
bufferized_bootstrap_lwe_u64(bootstrap_key, out_aligned + out_offset,
|
||||
ct0_aligned + ct0_offset, accumulator);
|
||||
mlir::concretelang::RuntimeContext *context) {
|
||||
keyswitch_lwe_u64(get_engine(context), get_keyswitch_key_u64(context),
|
||||
out_aligned + out_offset, ct0_aligned + ct0_offset);
|
||||
}
|
||||
|
||||
void memref_bootstrap_lwe_u64(
|
||||
uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
|
||||
uint64_t out_size, uint64_t out_stride, uint64_t *ct0_allocated,
|
||||
uint64_t *ct0_aligned, uint64_t ct0_offset, uint64_t ct0_size,
|
||||
uint64_t ct0_stride, uint64_t *glwe_ct_allocated, uint64_t *glwe_ct_aligned,
|
||||
uint64_t glwe_ct_offset, uint64_t glwe_ct_size, uint64_t glwe_ct_stride,
|
||||
mlir::concretelang::RuntimeContext *context) {
|
||||
bootstrap_lwe_u64(get_engine(context), get_bootstrap_key_u64(context),
|
||||
out_aligned + out_offset, ct0_aligned + ct0_offset,
|
||||
glwe_ct_aligned + glwe_ct_offset);
|
||||
}
|
||||
@@ -22,7 +22,6 @@ add_mlir_library(ConcretelangSupport
|
||||
FHELinalgDialectTransforms
|
||||
FHETensorOpsToLinalg
|
||||
FHEToTFHE
|
||||
ConcreteUnparametrize
|
||||
MLIRLowerableDialectsToLLVM
|
||||
FHEDialectAnalysis
|
||||
RTDialectAnalysis
|
||||
|
||||
@@ -205,9 +205,6 @@ lowerBConcreteToStd(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
addPotentiallyNestedPass(
|
||||
pm, mlir::concretelang::createConvertBConcreteToBConcreteCAPIPass(),
|
||||
enablePass);
|
||||
addPotentiallyNestedPass(
|
||||
pm, mlir::concretelang::createConvertConcreteToConcreteCAPIPass(),
|
||||
enablePass);
|
||||
return pm.run(module.getOperation());
|
||||
}
|
||||
|
||||
@@ -218,11 +215,6 @@ lowerStdToLLVMDialect(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
mlir::PassManager pm(&context);
|
||||
pipelinePrinting("StdToLLVM", pm, context);
|
||||
|
||||
// Unparametrize Concrete
|
||||
addPotentiallyNestedPass(
|
||||
pm, mlir::concretelang::createConvertConcreteUnparametrizePass(),
|
||||
enablePass);
|
||||
|
||||
// Bufferize
|
||||
addPotentiallyNestedPass(pm, mlir::createTensorConstantBufferizePass(),
|
||||
enablePass);
|
||||
|
||||
@@ -1,15 +1,15 @@
|
||||
// RUN: concretecompiler --passes bconcrete-to-bconcrete-c-api --action=dump-std %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK: func @bootstrap_lwe(%arg0: tensor<1025xi64>, %arg1: !Concrete.glwe_ciphertext, %arg2: !Concrete.context) -> tensor<1025xi64> {
|
||||
// CHECK-NEXT: %0 = linalg.init_tensor [1025] : tensor<1025xi64>
|
||||
// CHECK-NEXT: %1 = call @get_bootstrap_key(%arg2) : (!Concrete.context) -> !Concrete.lwe_bootstrap_key
|
||||
// CHECK-NEXT: %2 = tensor.cast %0 : tensor<1025xi64> to tensor<?xi64>
|
||||
// CHECK-NEXT: %3 = tensor.cast %arg0 : tensor<1025xi64> to tensor<?xi64>
|
||||
// CHECK-NEXT: call @memref_bootstrap_lwe_u64(%1, %2, %3, %arg1) : (!Concrete.lwe_bootstrap_key, tensor<?xi64>, tensor<?xi64>, !Concrete.glwe_ciphertext) -> ()
|
||||
// CHECK-NEXT: return %0 : tensor<1025xi64>
|
||||
// CHECK: func @apply_lookup_table(%arg0: tensor<601xi64>, %arg1: tensor<2048xi64>, %arg2: !Concrete.context) -> tensor<1025xi64> {
|
||||
// CHECK-NEXT: %0 = linalg.init_tensor [1025] : tensor<1025xi64>
|
||||
// CHECK-NEXT: %1 = tensor.cast %0 : tensor<1025xi64> to tensor<?xi64>
|
||||
// CHECK-NEXT: %2 = tensor.cast %arg0 : tensor<601xi64> to tensor<?xi64>
|
||||
// CHECK-NEXT: %3 = tensor.cast %arg1 : tensor<2048xi64> to tensor<?xi64>
|
||||
// CHECK-NEXT: call @memref_bootstrap_lwe_u64(%1, %2, %3, %arg2) : (tensor<?xi64>, tensor<?xi64>, tensor<?xi64>, !Concrete.context) -> ()
|
||||
// CHECK-NEXT: return %0 : tensor<1025xi64>
|
||||
// CHECK-NEXT: }
|
||||
func @bootstrap_lwe(%arg0: tensor<1025xi64>, %arg1: !Concrete.glwe_ciphertext) -> tensor<1025xi64> {
|
||||
%0 = linalg.init_tensor [1025] : tensor<1025xi64>
|
||||
"BConcrete.bootstrap_lwe_buffer"(%0, %arg0, %arg1) {baseLog = 2 : i32, glweDimension = 1 : i32, level = 3 : i32, polynomialSize = 1024 : i32} : (tensor<1025xi64>, tensor<1025xi64>, !Concrete.glwe_ciphertext) -> ()
|
||||
return %0 : tensor<1025xi64>
|
||||
}
|
||||
func @apply_lookup_table(%arg0: tensor<601xi64>, %arg1: tensor<2048xi64>) -> tensor<1025xi64> {
|
||||
%0 = linalg.init_tensor [1025] : tensor<1025xi64>
|
||||
"BConcrete.bootstrap_lwe_buffer"(%0, %arg0, %arg1) {baseLog = 4 : i32, glweDimension = 1 : i32, level = 5 : i32, polynomialSize = 1024 : i32} : (tensor<1025xi64>, tensor<601xi64>, tensor<2048xi64>) -> ()
|
||||
return %0 : tensor<1025xi64>
|
||||
}
|
||||
@@ -2,10 +2,9 @@
|
||||
|
||||
//CHECK: func @keyswitch_lwe(%arg0: tensor<1025xi64>, %arg1: !Concrete.context) -> tensor<1025xi64> {
|
||||
//CHECK-NEXT: %0 = linalg.init_tensor [1025] : tensor<1025xi64>
|
||||
//CHECK-NEXT: %1 = call @get_keyswitch_key(%arg1) : (!Concrete.context) -> !Concrete.lwe_key_switch_key
|
||||
//CHECK-NEXT: %2 = tensor.cast %0 : tensor<1025xi64> to tensor<?xi64>
|
||||
//CHECK-NEXT: %3 = tensor.cast %arg0 : tensor<1025xi64> to tensor<?xi64>
|
||||
//CHECK-NEXT: call @memref_keyswitch_lwe_u64(%1, %2, %3) : (!Concrete.lwe_key_switch_key, tensor<?xi64>, tensor<?xi64>) -> ()
|
||||
//CHECK-NEXT: %1 = tensor.cast %0 : tensor<1025xi64> to tensor<?xi64>
|
||||
//CHECK-NEXT: %2 = tensor.cast %arg0 : tensor<1025xi64> to tensor<?xi64>
|
||||
//CHECK-NEXT: call @memref_keyswitch_lwe_u64(%1, %2, %arg1) : (tensor<?xi64>, tensor<?xi64>, !Concrete.context) -> ()
|
||||
//CHECK-NEXT: return %0 : tensor<1025xi64>
|
||||
//CHECK-NEXT: }
|
||||
func @keyswitch_lwe(%arg0: tensor<1025xi64>) -> tensor<1025xi64> {
|
||||
|
||||
@@ -2,14 +2,15 @@
|
||||
|
||||
// CHECK-LABEL: func @apply_lookup_table(%arg0: tensor<1025xi64>, %arg1: tensor<16xi64>) -> tensor<1025xi64>
|
||||
func @apply_lookup_table(%arg0: !Concrete.lwe_ciphertext<1024,4>, %arg1: tensor<16xi64>) -> !Concrete.lwe_ciphertext<1024,4> {
|
||||
// CHECK-NEXT: %[[V1:.*]] = "Concrete.glwe_from_table"(%arg1) {glweDimension = 1 : i32, p = 4 : i32, polynomialSize = 1024 : i32} : (tensor<16xi64>) -> !Concrete.glwe_ciphertext
|
||||
// CHECK-NEXT: %[[V1:.*]] = linalg.init_tensor [2048] : tensor<2048xi64>
|
||||
// CHECK-NEXT:"BConcrete.fill_glwe_from_table"(%[[V1]], %arg1) {glweDimension = 1 : i32, outPrecision = 4 : i32, polynomialSize = 1024 : i32} : (tensor<2048xi64>, tensor<16xi64>) -> ()
|
||||
// CHECK-NEXT: %[[V2:.*]] = linalg.init_tensor [601] : tensor<601xi64>
|
||||
// CHECK-NEXT: "BConcrete.keyswitch_lwe_buffer"(%[[V2]], %arg0) {baseLog = 2 : i32, inputLweDimension = 1 : i32, level = 3 : i32, outputLweDimension = 600 : i32} : (tensor<601xi64>, tensor<1025xi64>) -> ()
|
||||
// CHECK-NEXT: %[[V3:.*]] = linalg.init_tensor [1025] : tensor<1025xi64>
|
||||
// CHECK-NEXT: "BConcrete.bootstrap_lwe_buffer"(%[[V3]], %[[V2]], %[[V1]]) {baseLog = 4 : i32, glweDimension = 1 : i32, level = 5 : i32, polynomialSize = 1024 : i32} : (tensor<1025xi64>, tensor<601xi64>, !Concrete.glwe_ciphertext) -> ()
|
||||
// CHECK-NEXT: "BConcrete.bootstrap_lwe_buffer"(%[[V3]], %[[V2]], %[[V1]]) {baseLog = 4 : i32, glweDimension = 1 : i32, level = 5 : i32, polynomialSize = 1024 : i32} : (tensor<1025xi64>, tensor<601xi64>, tensor<2048xi64>) -> ()
|
||||
// CHECK-NEXT: return %[[V3]] : tensor<1025xi64>
|
||||
%0 = "Concrete.glwe_from_table"(%arg1) {glweDimension = 1 : i32, p = 4 : i32, polynomialSize = 1024 : i32} : (tensor<16xi64>) -> !Concrete.glwe_ciphertext
|
||||
%0 = "Concrete.glwe_from_table"(%arg1) {glweDimension = 1 : i32, p = 4 : i32, polynomialSize = 1024 : i32} : (tensor<16xi64>) -> !Concrete.glwe_ciphertext<1024,1,4>
|
||||
%1 = "Concrete.keyswitch_lwe"(%arg0) {baseLog = 2 : i32, inputLweDimension = 1 : i32, level = 3 : i32, outputLweDimension = 600 : i32} : (!Concrete.lwe_ciphertext<1024,4>) -> !Concrete.lwe_ciphertext<600,4>
|
||||
%2 = "Concrete.bootstrap_lwe"(%1, %0) {baseLog = 4 : i32, glweDimension = 1 : i32, level = 5 : i32, polynomialSize = 1024 : i32} : (!Concrete.lwe_ciphertext<600,4>, !Concrete.glwe_ciphertext) -> !Concrete.lwe_ciphertext<1024,4>
|
||||
%2 = "Concrete.bootstrap_lwe"(%1, %0) {baseLog = 4 : i32, glweDimension = 1 : i32, level = 5 : i32, polynomialSize = 1024 : i32} : (!Concrete.lwe_ciphertext<600,4>, !Concrete.glwe_ciphertext<1024,1,4>) -> !Concrete.lwe_ciphertext<1024,4>
|
||||
return %2 : !Concrete.lwe_ciphertext<1024,4>
|
||||
}
|
||||
|
||||
@@ -3,15 +3,16 @@
|
||||
// CHECK-LABEL: func @apply_lookup_table_cst(%arg0: tensor<2049xi64>) -> tensor<2049xi64>
|
||||
func @apply_lookup_table_cst(%arg0: !Concrete.lwe_ciphertext<2048,4>) -> !Concrete.lwe_ciphertext<2048,4> {
|
||||
// CHECK-NEXT: %[[TABLE:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : tensor<16xi64>
|
||||
// CHECK-NEXT: %[[V1:.*]] = "Concrete.glwe_from_table"(%[[TABLE:.*]]) {glweDimension = 1 : i32, p = 4 : i32, polynomialSize = 2048 : i32} : (tensor<16xi64>) -> !Concrete.glwe_ciphertext
|
||||
// CHECK-NEXT: %[[V2:.*]] = linalg.init_tensor [601] : tensor<601xi64>
|
||||
// CHECK-NEXT: "BConcrete.keyswitch_lwe_buffer"([[V2:.*]], %arg0) {baseLog = 2 : i32, inputLweDimension = 1 : i32, level = 3 : i32, outputLweDimension = 600 : i32} : (tensor<601xi64>, tensor<2049xi64>) -> ()
|
||||
// CHECK-NEXT: %[[V1:.*]] = linalg.init_tensor [4096] : tensor<4096xi64>
|
||||
// CHECK-NEXT: "BConcrete.fill_glwe_from_table"(%[[V1]], %cst) {glweDimension = 1 : i32, outPrecision = 4 : i32, polynomialSize = 2048 : i32} : (tensor<4096xi64>, tensor<16xi64>) -> ()
|
||||
// CHECK-NEXT: %[[V2:.*]] = linalg.init_tensor [601] : tensor<601xi64>
|
||||
// CHECK-NEXT: "BConcrete.keyswitch_lwe_buffer"(%[[V2]], %arg0) {baseLog = 2 : i32, inputLweDimension = 1 : i32, level = 3 : i32, outputLweDimension = 600 : i32} : (tensor<601xi64>, tensor<2049xi64>) -> ()
|
||||
// CHECK-NEXT: %[[V3:.*]] = linalg.init_tensor [2049] : tensor<2049xi64>
|
||||
// CHECK-NEXT: "BConcrete.bootstrap_lwe_buffer"(%[[V3:.*]], %[[V2:.*]], %[[V1:.*]]) {baseLog = 4 : i32, glweDimension = 1 : i32, level = 5 : i32, polynomialSize = 2048 : i32} : (tensor<2049xi64>, tensor<601xi64>, !Concrete.glwe_ciphertext) -> ()
|
||||
// CHECK-NEXT: "BConcrete.bootstrap_lwe_buffer"(%[[V3]], %[[V2]], %[[V1]]) {baseLog = 4 : i32, glweDimension = 1 : i32, level = 5 : i32, polynomialSize = 2048 : i32} : (tensor<2049xi64>, tensor<601xi64>, tensor<4096xi64>) -> ()
|
||||
// CHECK-NEXT: return %[[V3]] : tensor<2049xi64>
|
||||
%tlu = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : tensor<16xi64>
|
||||
%0 = "Concrete.glwe_from_table"(%tlu) {glweDimension = 1 : i32, p = 4 : i32, polynomialSize = 2048 : i32} : (tensor<16xi64>) -> !Concrete.glwe_ciphertext
|
||||
%0 = "Concrete.glwe_from_table"(%tlu) {glweDimension = 1 : i32, p = 4 : i32, polynomialSize = 2048 : i32} : (tensor<16xi64>) -> !Concrete.glwe_ciphertext<2048,1,4>
|
||||
%1 = "Concrete.keyswitch_lwe"(%arg0) {baseLog = 2 : i32, inputLweDimension = 1 : i32, level = 3 : i32, outputLweDimension = 600 : i32} : (!Concrete.lwe_ciphertext<2048,4>) -> !Concrete.lwe_ciphertext<600,4>
|
||||
%2 = "Concrete.bootstrap_lwe"(%1, %0) {baseLog = 4 : i32, glweDimension = 1 : i32, level = 5 : i32, polynomialSize = 2048 : i32} : (!Concrete.lwe_ciphertext<600,4>, !Concrete.glwe_ciphertext) -> !Concrete.lwe_ciphertext<2048,4>
|
||||
%2 = "Concrete.bootstrap_lwe"(%1, %0) {baseLog = 4 : i32, glweDimension = 1 : i32, level = 5 : i32, polynomialSize = 2048 : i32} : (!Concrete.lwe_ciphertext<600,4>, !Concrete.glwe_ciphertext<2048,1,4>) -> !Concrete.lwe_ciphertext<2048,4>
|
||||
return %2 : !Concrete.lwe_ciphertext<2048,4>
|
||||
}
|
||||
|
||||
@@ -1,7 +0,0 @@
|
||||
// RUN: concretecompiler --passes concrete-unparametrize --action=dump-llvm-dialect %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @main(%arg0: !Concrete.lwe_ciphertext<_,_>) -> !Concrete.lwe_ciphertext<_,_>
|
||||
func @main(%arg0: !Concrete.lwe_ciphertext<1024,4>) -> !Concrete.lwe_ciphertext<1024,4> {
|
||||
// CHECK-NEXT: return %arg0 : !Concrete.lwe_ciphertext<_,_>
|
||||
return %arg0: !Concrete.lwe_ciphertext<1024,4>
|
||||
}
|
||||
@@ -1,8 +0,0 @@
|
||||
// RUN: concretecompiler --passes concrete-unparametrize --action=dump-llvm-dialect %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @main(%arg0: !Concrete.lwe_ciphertext<_,_>) -> !Concrete.lwe_ciphertext<_,_>
|
||||
func @main(%arg0: !Concrete.lwe_ciphertext<1024,4>) -> !Concrete.lwe_ciphertext<_,_> {
|
||||
// CHECK-NEXT: return %arg0 : !Concrete.lwe_ciphertext<_,_>
|
||||
%0 = builtin.unrealized_conversion_cast %arg0 : !Concrete.lwe_ciphertext<1024,4> to !Concrete.lwe_ciphertext<_,_>
|
||||
return %0: !Concrete.lwe_ciphertext<_,_>
|
||||
}
|
||||
@@ -2,9 +2,9 @@
|
||||
|
||||
// CHECK-LABEL: func @apply_lookup_table(%arg0: !Concrete.lwe_ciphertext<1024,4>, %arg1: tensor<16xi64>) -> !Concrete.lwe_ciphertext<1024,4>
|
||||
func @apply_lookup_table(%arg0: !TFHE.glwe<{1024,1,64}{4}>, %arg1: tensor<16xi64>) -> !TFHE.glwe<{1024,1,64}{4}> {
|
||||
// CHECK-NEXT: %[[V1:.*]] = "Concrete.glwe_from_table"(%arg1) {glweDimension = 1 : i32, p = 4 : i32, polynomialSize = 1024 : i32} : (tensor<16xi64>) -> !Concrete.glwe_ciphertext
|
||||
// CHECK-NEXT: %[[V1:.*]] = "Concrete.glwe_from_table"(%arg1) : (tensor<16xi64>) -> !Concrete.glwe_ciphertext<1024,1,4>
|
||||
// CHECK-NEXT: %[[V2:.*]] = "Concrete.keyswitch_lwe"(%arg0) {baseLog = 2 : i32, level = 3 : i32} : (!Concrete.lwe_ciphertext<1024,4>) -> !Concrete.lwe_ciphertext<600,4>
|
||||
// CHECK-NEXT: %[[V3:.*]] = "Concrete.bootstrap_lwe"(%[[V2]], %[[V1]]) {baseLog = 4 : i32, glweDimension = 1 : i32, level = 5 : i32, polynomialSize = 1024 : i32} : (!Concrete.lwe_ciphertext<600,4>, !Concrete.glwe_ciphertext) -> !Concrete.lwe_ciphertext<1024,4>
|
||||
// CHECK-NEXT: %[[V3:.*]] = "Concrete.bootstrap_lwe"(%[[V2]], %[[V1]]) {baseLog = 4 : i32, glweDimension = 1 : i32, level = 5 : i32, polynomialSize = 1024 : i32} : (!Concrete.lwe_ciphertext<600,4>, !Concrete.glwe_ciphertext<1024,1,4>) -> !Concrete.lwe_ciphertext<1024,4>
|
||||
// CHECK-NEXT: return %[[V3]] : !Concrete.lwe_ciphertext<1024,4>
|
||||
%1 = "TFHE.apply_lookup_table"(%arg0, %arg1){glweDimension=1:i32, polynomialSize=1024:i32, levelKS=3:i32, baseLogKS=2:i32, levelBS=5:i32, baseLogBS=4:i32, outputSizeKS=600:i32}: (!TFHE.glwe<{1024,1,64}{4}>, tensor<16xi64>) -> (!TFHE.glwe<{1024,1,64}{4}>)
|
||||
return %1: !TFHE.glwe<{1024,1,64}{4}>
|
||||
|
||||
@@ -3,9 +3,9 @@
|
||||
// CHECK-LABEL: func @apply_lookup_table_cst(%arg0: !Concrete.lwe_ciphertext<2048,4>) -> !Concrete.lwe_ciphertext<2048,4>
|
||||
func @apply_lookup_table_cst(%arg0: !TFHE.glwe<{2048,1,64}{4}>) -> !TFHE.glwe<{2048,1,64}{4}> {
|
||||
// CHECK-NEXT: %[[TABLE:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : tensor<16xi64>
|
||||
// CHECK-NEXT: %[[V1:.*]] = "Concrete.glwe_from_table"(%[[TABLE]]) {glweDimension = 1 : i32, p = 4 : i32, polynomialSize = 2048 : i32} : (tensor<16xi64>) -> !Concrete.glwe_ciphertext
|
||||
// CHECK-NEXT: %[[V1:.*]] = "Concrete.glwe_from_table"(%[[TABLE]]) : (tensor<16xi64>) -> !Concrete.glwe_ciphertext<2048,1,4>
|
||||
// CHECK-NEXT: %[[V2:.*]] = "Concrete.keyswitch_lwe"(%arg0) {baseLog = 2 : i32, level = 3 : i32} : (!Concrete.lwe_ciphertext<2048,4>) -> !Concrete.lwe_ciphertext<600,4>
|
||||
// CHECK-NEXT: %[[V3:.*]] = "Concrete.bootstrap_lwe"(%[[V2]], %[[V1]]) {baseLog = 4 : i32, glweDimension = 1 : i32, level = 5 : i32, polynomialSize = 2048 : i32} : (!Concrete.lwe_ciphertext<600,4>, !Concrete.glwe_ciphertext) -> !Concrete.lwe_ciphertext<2048,4>
|
||||
// CHECK-NEXT: %[[V3:.*]] = "Concrete.bootstrap_lwe"(%[[V2]], %[[V1]]) {baseLog = 4 : i32, glweDimension = 1 : i32, level = 5 : i32, polynomialSize = 2048 : i32} : (!Concrete.lwe_ciphertext<600,4>, !Concrete.glwe_ciphertext<2048,1,4>) -> !Concrete.lwe_ciphertext<2048,4>
|
||||
// CHECK-NEXT: return %[[V3]] : !Concrete.lwe_ciphertext<2048,4>
|
||||
%tlu = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : tensor<16xi64>
|
||||
%1 = "TFHE.apply_lookup_table"(%arg0, %tlu){glweDimension=1:i32, polynomialSize=2048:i32, levelKS=3:i32, baseLogKS=2:i32, levelBS=5:i32, baseLogBS=4:i32, outputSizeKS=600:i32}: (!TFHE.glwe<{2048,1,64}{4}>, tensor<16xi64>) -> (!TFHE.glwe<{2048,1,64}{4}>)
|
||||
|
||||
@@ -40,13 +40,13 @@ func @negate_lwe_ciphertext(%arg0: tensor<2049xi64>) -> tensor<2049xi64> {
|
||||
return %0 : tensor<2049xi64>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @bootstrap_lwe(%arg0: tensor<2049xi64>, %arg1: !Concrete.glwe_ciphertext) -> tensor<2049xi64>
|
||||
func @bootstrap_lwe(%arg0: tensor<2049xi64>, %arg1: !Concrete.glwe_ciphertext) -> tensor<2049xi64> {
|
||||
// CHECK-LABEL: func @bootstrap_lwe(%arg0: tensor<2049xi64>, %arg1: tensor<4096xi64>) -> tensor<2049xi64>
|
||||
func @bootstrap_lwe(%arg0: tensor<2049xi64>, %arg1: tensor<4096xi64>) -> tensor<2049xi64> {
|
||||
// CHECK-NEXT: %[[V0:.*]] = linalg.init_tensor [2049] : tensor<2049xi64>
|
||||
// CHECK-NEXT: "BConcrete.bootstrap_lwe_buffer"(%[[V0]], %arg0, %arg1) {baseLog = -1 : i32, glweDimension = 1 : i32, level = -1 : i32, polynomialSize = 1024 : i32} : (tensor<2049xi64>, tensor<2049xi64>, !Concrete.glwe_ciphertext) -> ()
|
||||
// CHECK-NEXT: "BConcrete.bootstrap_lwe_buffer"(%[[V0]], %arg0, %arg1) {baseLog = -1 : i32, glweDimension = 1 : i32, level = -1 : i32, polynomialSize = 1024 : i32} : (tensor<2049xi64>, tensor<2049xi64>, tensor<4096xi64>) -> ()
|
||||
// CHECK-NEXT: return %[[V0]] : tensor<2049xi64>
|
||||
%0 = linalg.init_tensor [2049] : tensor<2049xi64>
|
||||
"BConcrete.bootstrap_lwe_buffer"(%0, %arg0, %arg1) {baseLog = -1 : i32, glweDimension = 1 : i32, level = -1 : i32, polynomialSize = 1024 : i32} : (tensor<2049xi64>, tensor<2049xi64>, !Concrete.glwe_ciphertext) -> ()
|
||||
"BConcrete.bootstrap_lwe_buffer"(%0, %arg0, %arg1) {baseLog = -1 : i32, glweDimension = 1 : i32, level = -1 : i32, polynomialSize = 1024 : i32} : (tensor<2049xi64>, tensor<2049xi64>, tensor<4096xi64>) -> ()
|
||||
return %0 : tensor<2049xi64>
|
||||
}
|
||||
|
||||
|
||||
@@ -36,12 +36,11 @@ func @negate_lwe_ciphertext(%arg0: !Concrete.lwe_ciphertext<2048,7>) -> !Concret
|
||||
return %1: !Concrete.lwe_ciphertext<2048,7>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @bootstrap_lwe(%arg0: !Concrete.lwe_ciphertext<2048,7>, %arg1: !Concrete.glwe_ciphertext) -> !Concrete.lwe_ciphertext<2048,7>
|
||||
func @bootstrap_lwe(%arg0: !Concrete.lwe_ciphertext<2048,7>, %arg1: !Concrete.glwe_ciphertext) -> !Concrete.lwe_ciphertext<2048,7> {
|
||||
// CHECK-NEXT: %[[V1:.*]] = "Concrete.bootstrap_lwe"(%arg0, %arg1) {baseLog = -1 : i32, glweDimension = 1 : i32, level = -1 : i32, polynomialSize = 1024 : i32} : (!Concrete.lwe_ciphertext<2048,7>, !Concrete.glwe_ciphertext) -> !Concrete.lwe_ciphertext<2048,7>
|
||||
// CHECK-LABEL: func @bootstrap_lwe(%arg0: !Concrete.lwe_ciphertext<2048,7>, %arg1: !Concrete.glwe_ciphertext<2048,1,7>) -> !Concrete.lwe_ciphertext<2048,7>
|
||||
func @bootstrap_lwe(%arg0: !Concrete.lwe_ciphertext<2048,7>, %arg1: !Concrete.glwe_ciphertext<2048,1,7>) -> !Concrete.lwe_ciphertext<2048,7> {
|
||||
// CHECK-NEXT: %[[V1:.*]] = "Concrete.bootstrap_lwe"(%arg0, %arg1) {baseLog = -1 : i32, glweDimension = 1 : i32, level = -1 : i32, polynomialSize = 2048 : i32} : (!Concrete.lwe_ciphertext<2048,7>, !Concrete.glwe_ciphertext<2048,1,7>) -> !Concrete.lwe_ciphertext<2048,7>
|
||||
// CHECK-NEXT: return %[[V1]] : !Concrete.lwe_ciphertext<2048,7>
|
||||
|
||||
%1 = "Concrete.bootstrap_lwe"(%arg0, %arg1) {baseLog = -1 : i32, glweDimension = 1 : i32, level = -1 : i32, polynomialSize = 1024 : i32} : (!Concrete.lwe_ciphertext<2048,7>, !Concrete.glwe_ciphertext) -> (!Concrete.lwe_ciphertext<2048,7>)
|
||||
%1 = "Concrete.bootstrap_lwe"(%arg0, %arg1) {baseLog = -1 : i32, glweDimension = 1 : i32, level = -1 : i32, polynomialSize = 2048 : i32} : (!Concrete.lwe_ciphertext<2048,7>, !Concrete.glwe_ciphertext<2048,1,7>) -> !Concrete.lwe_ciphertext<2048,7>
|
||||
return %1: !Concrete.lwe_ciphertext<2048,7>
|
||||
}
|
||||
|
||||
@@ -49,7 +48,6 @@ func @bootstrap_lwe(%arg0: !Concrete.lwe_ciphertext<2048,7>, %arg1: !Concrete.gl
|
||||
func @keyswitch_lwe(%arg0: !Concrete.lwe_ciphertext<2048,7>) -> !Concrete.lwe_ciphertext<2048,7> {
|
||||
// CHECK-NEXT: %[[V1:.*]] = "Concrete.keyswitch_lwe"(%arg0) {baseLog = 2 : i32, level = 3 : i32} : (!Concrete.lwe_ciphertext<2048,7>) -> !Concrete.lwe_ciphertext<2048,7>
|
||||
// CHECK-NEXT: return %[[V1]] : !Concrete.lwe_ciphertext<2048,7>
|
||||
|
||||
%1 = "Concrete.keyswitch_lwe"(%arg0){baseLog = 2 : i32, level = 3 : i32}: (!Concrete.lwe_ciphertext<2048,7>) -> (!Concrete.lwe_ciphertext<2048,7>)
|
||||
return %1: !Concrete.lwe_ciphertext<2048,7>
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user