feat(compiler): use engine concrete C API

remove ConcreteToConcreteCAPI and ConcreteUnparametrize passes
Mayeul@Zama
2022-02-24 11:16:34 +01:00
committed by mayeul-zama
parent cee07d2440
commit ca8d4fb110
37 changed files with 418 additions and 1367 deletions

View File

@@ -171,7 +171,7 @@ test-end-to-end-jit-dfr: build-end-to-end-jit-dfr
test-end-to-end-jit-auto-parallelization: build-end-to-end-jit-auto-parallelization
$(BUILD_DIR)/bin/end_to_end_jit_auto_parallelization
test-end-to-end-jit: test-end-to-end-jit-test test-end-to-end-jit-clear-tensor test-end-to-end-jit-encrypted-tensor test-end-to-end-jit-fhelinalg
test-end-to-end-jit: test-end-to-end-jit-test test-end-to-end-jit-clear-tensor test-end-to-end-jit-encrypted-tensor test-end-to-end-jit-fhelinalg test-end-to-end-jit-fhe
show-stress-tests-summary:
@echo '------ Stress tests summary ------'

View File

@@ -27,6 +27,7 @@ using RuntimeContext = mlir::concretelang::RuntimeContext;
class KeySet {
public:
KeySet();
~KeySet();
// allocate a KeySet according to the ClientParameters.
@@ -81,8 +82,7 @@ public:
void setRuntimeContext(RuntimeContext &context) {
context.ksk = std::get<1>(this->keyswitchKeys["ksk_v0"]);
context.bsk[RuntimeContext::BASE_CONTEXT_BSK] =
std::get<1>(this->bootstrapKeys.at("bsk_v0"));
context.bsk = std::get<1>(this->bootstrapKeys.at("bsk_v0"));
}
RuntimeContext runtimeContext() {
@@ -105,14 +105,13 @@ public:
protected:
outcome::checked<void, StringError>
generateSecretKey(LweSecretKeyID id, LweSecretKeyParam param,
SecretRandomGenerator *generator);
generateSecretKey(LweSecretKeyID id, LweSecretKeyParam param);
outcome::checked<void, StringError>
generateBootstrapKey(BootstrapKeyID id, BootstrapKeyParam param,
EncryptionRandomGenerator *generator);
generateBootstrapKey(BootstrapKeyID id, BootstrapKeyParam param);
outcome::checked<void, StringError>
generateKeyswitchKey(KeyswitchKeyID id, KeyswitchKeyParam param,
EncryptionRandomGenerator *generator);
generateKeyswitchKey(KeyswitchKeyID id, KeyswitchKeyParam param);
outcome::checked<void, StringError>
generateKeysFromParams(ClientParameters &params, uint64_t seed_msb,
@@ -125,7 +124,7 @@ protected:
friend class KeySetCache;
private:
EncryptionRandomGenerator *encryptionRandomGenerator;
Engine *engine;
std::map<LweSecretKeyID, std::pair<LweSecretKeyParam, LweSecretKey_u64 *>>
secretKeys;
std::map<LweSecretKeyID, std::pair<BootstrapKeyParam, LweBootstrapKey_u64 *>>

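With the move to the engine-based C API, the KeySet owns a single concrete-ffi `Engine *` and no longer carries separate secret/encryption random generators; the three `generate*` helpers drop their generator parameter accordingly. A minimal client-side wiring sketch (illustrative only; `keySet` is assumed to hold the `ksk_v0`/`bsk_v0` entries referenced by `setRuntimeContext` above):

```cpp
// Illustrative sketch, not part of the diff.
mlir::concretelang::RuntimeContext context;
keySet.setRuntimeContext(context); // context.ksk / context.bsk now point at the generated keys
// Engines used by the runtime wrappers are obtained on demand via
// get_engine(&context), so nothing generator-related has to be copied here.
```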
View File

@@ -1,22 +0,0 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt
// for license information.
#ifndef CONCRETELANG_CONVERSION_CONCRETETOCONCRETECAPI_PASS_H_
#define CONCRETELANG_CONVERSION_CONCRETETOCONCRETECAPI_PASS_H_
#include "mlir/Pass/Pass.h"
#include "concretelang/Conversion/Utils/GlobalFHEContext.h"
namespace mlir {
namespace concretelang {
/// Create a pass to convert `Concrete` operators to function call to the
/// `ConcreteCAPI`
std::unique_ptr<OperationPass<ModuleOp>>
createConvertConcreteToConcreteCAPIPass();
} // namespace concretelang
} // namespace mlir
#endif

View File

@@ -1,18 +0,0 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt
// for license information.
#ifndef CONCRETELANG_CONVERSION_CONCRETEUNPARAMETRIZE_PASS_H_
#define CONCRETELANG_CONVERSION_CONCRETEUNPARAMETRIZE_PASS_H_
#include "mlir/Pass/Pass.h"
namespace mlir {
namespace concretelang {
std::unique_ptr<OperationPass<ModuleOp>>
createConvertConcreteUnparametrizePass();
} // namespace concretelang
} // namespace mlir
#endif

View File

@@ -13,8 +13,6 @@
#include "concretelang/Conversion/BConcreteToBConcreteCAPI/Pass.h"
#include "concretelang/Conversion/ConcreteToBConcrete/Pass.h"
#include "concretelang/Conversion/ConcreteToConcreteCAPI/Pass.h"
#include "concretelang/Conversion/ConcreteUnparametrize/Pass.h"
#include "concretelang/Conversion/FHETensorOpsToLinalg/Pass.h"
#include "concretelang/Conversion/FHEToTFHE/Pass.h"
#include "concretelang/Conversion/MLIRLowerableDialectsToLLVM/Pass.h"

View File

@@ -40,24 +40,12 @@ def ConcreteToBConcrete : Pass<"concrete-to-bconcrete", "mlir::ModuleOp"> {
let dependentDialects = ["mlir::linalg::LinalgDialect", "mlir::concretelang::Concrete::ConcreteDialect", "mlir::concretelang::BConcrete::BConcreteDialect"];
}
def ConcreteToConcreteCAPI : Pass<"concrete-to-concrete-c-api", "mlir::ModuleOp"> {
let summary = "Lower operations from the Concrete dialect to std with function call to the Concrete C API";
let constructor = "mlir::concretelang::createConvertConcreteToConcreteCAPIPass()";
let dependentDialects = ["mlir::concretelang::Concrete::ConcreteDialect", "mlir::StandardOpsDialect", "mlir::memref::MemRefDialect"];
}
def BConcreteToBConcreteCAPI : Pass<"bconcrete-to-bconcrete-c-api", "mlir::ModuleOp"> {
let summary = "Lower operations from the Bufferized Concrete dialect to std with function call to the Bufferized Concrete C API";
let constructor = "mlir::concretelang::createConvertBConcreteToBConcreteCAPIPass()";
let dependentDialects = ["mlir::concretelang::BConcrete::BConcreteDialect", "mlir::StandardOpsDialect", "mlir::memref::MemRefDialect"];
}
def ConcreteUnparametrize : Pass<"concrete-unparametrize", "mlir::ModuleOp"> {
let summary = "Unparametrize Concrete types and remove unrealized_conversion_cast";
let constructor = "mlir::concretelang::createConvertConcreteToConcreteCAPIPass()";
let dependentDialects = ["mlir::concretelang::Concrete::ConcreteDialect", "mlir::StandardOpsDialect", "mlir::memref::MemRefDialect"];
}
def MLIRLowerableDialectsToLLVM : Pass<"mlir-lowerable-dialects-to-llvm", "mlir::ModuleOp"> {
let summary = "Lowers operations from MLIR lowerable dialects to LLVM";
let constructor = "mlir::concretelang::createConvertMLIRLowerableDialectsToLLVMPass()";

View File

@@ -210,7 +210,7 @@ mlir::Value createMulClearLweCiphertext(mlir::PatternRewriter &rewriter,
// % [[TABLE]]){glweDimension = 1 : i32, p = 4 : i32, polynomialSize =
// 2048 : i32}
// : (tensor<16xi4>)
// ->!Concrete.glwe_ciphertext
// ->!Concrete.glwe_ciphertext<2048, 1, 4>
// % keyswitched = "Concrete.keyswitch_lwe"(% arg0){
// baseLog = 2 : i32,
// level = 3 : i32
@@ -221,7 +221,7 @@ mlir::Value createMulClearLweCiphertext(mlir::PatternRewriter &rewriter,
// glweDimension = 1 : i32,
// level = 5 : i32,
// polynomialSize = 2048 : i32
// } : (!Concrete.lwe_ciphertext<600, 4>, !Concrete.glwe_ciphertext)
// } : (!Concrete.lwe_ciphertext<600, 4>, !Concrete.glwe_ciphertext<2048, 1, 4>)
// ->!Concrete.lwe_ciphertext<2048, 4>
// ```
mlir::Value createPBS(mlir::PatternRewriter &rewriter, mlir::Location loc,
@@ -240,8 +240,11 @@ mlir::Value createPBS(mlir::PatternRewriter &rewriter, mlir::Location loc,
mlir::Value accumulator =
rewriter
.create<mlir::concretelang::Concrete::GlweFromTable>(
loc, Concrete::GlweCiphertextType::get(rewriter.getContext()),
table, polynomialSize, glweDimension, precision)
loc,
Concrete::GlweCiphertextType::get(
rewriter.getContext(), polynomialSize.getInt(),
glweDimension.getInt(), lwe_type.getP()),
table)
.result();
// keyswitch

View File

@@ -36,6 +36,17 @@ def NegateLweBufferOp : BConcrete_Op<"negate_lwe_buffer"> {
let results = (outs);
}
def FillGlweFromTable : BConcrete_Op<"fill_glwe_from_table"> {
let arguments = (ins
1DTensorOf<[I64]>:$glwe,
I32Attr:$polynomialSize,
I32Attr:$glweDimension,
I32Attr:$outPrecision,
1DTensorOf<[I64]>:$table
);
let results = (outs);
}
def KeySwitchLweBufferOp : BConcrete_Op<"keyswitch_lwe_buffer"> {
let arguments = (ins
1DTensorOf<[I64]>:$result,
@@ -52,7 +63,7 @@ def BootstrapLweBufferOp : BConcrete_Op<"bootstrap_lwe_buffer"> {
1DTensorOf<[I64]>:$result,
// LweBootstrapKeyType:$bootstrap_key,
1DTensorOf<[I64]>:$input_ciphertext,
GlweCiphertextType:$accumulator,
1DTensorOf<[I64]>:$accumulator,
I32Attr:$glweDimension,
I32Attr:$polynomialSize,
I32Attr:$level,

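`fill_glwe_from_table` is the buffer-level counterpart of `Concrete.glwe_from_table`: it expands a lookup table into a pre-allocated flat GLWE buffer of `polynomialSize * (glweDimension + 1)` 64-bit words. A builder-side sketch, mirroring the create call in ConcreteToBConcrete.cpp further down (names here are illustrative):

```cpp
// Sketch only; `rewriter`, `loc` and `lut` are assumed to be in scope.
int64_t polySize = 2048, glweDim = 1, outPrecision = 3;
auto bufferTy = mlir::RankedTensorType::get({polySize * (glweDim + 1)},
                                            rewriter.getI64Type());
mlir::Value glweBuffer = rewriter.create<mlir::linalg::InitTensorOp>(
    loc, bufferTy.getShape(), bufferTy.getElementType());
rewriter.create<mlir::concretelang::BConcrete::FillGlweFromTable>(
    loc, glweBuffer, polySize, glweDim, outPrecision, /*table=*/lut);
```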
View File

@@ -55,7 +55,7 @@ def NegateLweCiphertextOp : Concrete_Op<"negate_lwe_ciphertext"> {
def GlweFromTable : Concrete_Op<"glwe_from_table"> {
let summary = "Creates a GLWE ciphertext which is the trivial encrytion of a the input table interpreted as a polynomial (to use later in a bootstrap)";
let arguments = (ins TensorOf<[AnyInteger]>:$table, I32Attr:$polynomialSize, I32Attr:$glweDimension, I32Attr:$p);
let arguments = (ins 1DTensorOf<[I64]>:$table);
let results = (outs GlweCiphertextType:$result);
}

View File

@@ -16,12 +16,48 @@ def GlweCiphertextType : Concrete_Type<"GlweCiphertext"> {
GLWE ciphertext.
}];
let parameters = (ins
"signed":$polynomialSize,
"signed":$glweDimension,
// Precision of the lwe ciphertext
"signed":$p
);
let printer = [{
$_printer << "glwe_ciphertext";
$_printer << "glwe_ciphertext<";
if (getImpl()->polynomialSize == -1) $_printer << "_";
else $_printer << getImpl()->polynomialSize;
$_printer << ",";
if (getImpl()->glweDimension == -1) $_printer << "_";
else $_printer << getImpl()->glweDimension;
$_printer << ",";
if (getImpl()->p == -1) $_printer << "_";
else $_printer << getImpl()->p;
$_printer << ">";
}];
let parser = [{
return get($_ctxt);
if ($_parser.parseLess())
return Type();
int polynomialSize = -1;
if ($_parser.parseOptionalKeyword("_") && $_parser.parseInteger(polynomialSize))
return Type();
if ($_parser.parseComma())
return Type();
int glweDimension = -1;
if ($_parser.parseOptionalKeyword("_") && $_parser.parseInteger(glweDimension))
return Type();
if ($_parser.parseComma())
return Type();
int p = -1;
if ($_parser.parseOptionalKeyword("_") && $_parser.parseInteger(p))
return Type();
if ($_parser.parseGreater())
return Type();
Location loc = $_parser.getEncodedSourceLoc($_parser.getNameLoc());
return getChecked(loc, loc.getContext(), polynomialSize, glweDimension, p);
}];
}
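The GLWE type now carries its parameters, with `-1` acting as the "unparametrized" placeholder that prints as `_`. A small sketch of building the parametrized type and reading the parameters back (accessor names follow their uses elsewhere in this commit):

```cpp
// Sketch only; `ctx` is an MLIRContext* with the Concrete dialect loaded.
auto glweTy = mlir::concretelang::Concrete::GlweCiphertextType::get(
    ctx, /*polynomialSize=*/2048, /*glweDimension=*/1, /*p=*/4);
assert(glweTy.getPolynomialSize() == 2048);
assert(glweTy.getGlweDimension() == 1);
assert(glweTy.getP() == 4);
// Printed form: !Concrete.glwe_ciphertext<2048,1,4>; a -1 parameter prints as "_",
// e.g. !Concrete.glwe_ciphertext<_,_,_>.
```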

View File

@@ -7,6 +7,7 @@
#define CONCRETELANG_RUNTIME_CONTEXT_H
#include <map>
#include <mutex>
#include <string>
extern "C" {
@@ -18,14 +19,29 @@ namespace concretelang {
typedef struct RuntimeContext {
LweKeyswitchKey_u64 *ksk;
std::map<std::string, LweBootstrapKey_u64 *> bsk;
LweBootstrapKey_u64 *bsk;
#ifdef CONCRETELANG_PARALLEL_EXECUTION_ENABLED
std::map<std::string, Engine *> engines;
std::mutex engines_map_guard;
#else
Engine *engine;
#endif
static std::string BASE_CONTEXT_BSK;
RuntimeContext()
#ifndef CONCRETELANG_PARALLEL_EXECUTION_ENABLED
: engine(nullptr)
#endif
{
}
~RuntimeContext() {
for (const auto &key : bsk) {
if (key.first != BASE_CONTEXT_BSK)
free_lwe_bootstrap_key_u64(key.second);
#ifdef CONCRETELANG_PARALLEL_EXECUTION_ENABLED
for (const auto &key : engines) {
free_engine(key.second);
}
#else
if (engine != nullptr)
free_engine(engine);
#endif
}
} RuntimeContext;
@@ -34,9 +50,11 @@ typedef struct RuntimeContext {
extern "C" {
LweKeyswitchKey_u64 *
get_keyswitch_key(mlir::concretelang::RuntimeContext *context);
get_keyswitch_key_u64(mlir::concretelang::RuntimeContext *context);
LweBootstrapKey_u64 *
get_bootstrap_key(mlir::concretelang::RuntimeContext *context);
get_bootstrap_key_u64(mlir::concretelang::RuntimeContext *context);
Engine *get_engine(mlir::concretelang::RuntimeContext *context);
}
#endif
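Only the declaration of `get_engine` appears in this header; given the `engines` map and mutex added above, one plausible shape for it is a per-thread engine cache. The sketch below is an illustration of that idea, not the actual definition (which lives in the runtime sources, outside this excerpt):

```cpp
// Illustrative sketch only (needs <sstream> and <thread>).
Engine *get_engine(mlir::concretelang::RuntimeContext *context) {
#ifdef CONCRETELANG_PARALLEL_EXECUTION_ENABLED
  std::ostringstream id;
  id << std::this_thread::get_id();
  std::lock_guard<std::mutex> guard(context->engines_map_guard);
  auto it = context->engines.find(id.str());
  if (it == context->engines.end())
    it = context->engines.emplace(id.str(), new_engine()).first;
  return it->second;
#else
  if (context->engine == nullptr)
    context->engine = new_engine();
  return context->engine;
#endif
}
```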

View File

@@ -6,11 +6,17 @@
#ifndef CONCRETELANG_RUNTIME_WRAPPERS_H
#define CONCRETELANG_RUNTIME_WRAPPERS_H
#include "concretelang/Runtime/context.h"
extern "C" {
#include "concrete-ffi.h"
struct ForeignPlaintextList_u64 *memref_runtime_foreign_plaintext_list_u64(
uint64_t *allocated, uint64_t *aligned, uint64_t offset, uint64_t size,
uint64_t stride, uint32_t precision);
void memref_expand_lut_in_trivial_glwe_ct_u64(
uint64_t *glwe_ct_allocated, uint64_t *glwe_ct_aligned,
uint64_t glwe_ct_offset, uint64_t glwe_ct_size, uint64_t glwe_ct_stride,
uint32_t poly_size, uint32_t glwe_dimension, uint32_t out_precision,
uint64_t *lut_allocated, uint64_t *lut_aligned, uint64_t lut_offset,
uint64_t lut_size, uint64_t lut_stride);
void memref_add_lwe_ciphertexts_u64(
uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
@@ -37,19 +43,20 @@ void memref_negate_lwe_ciphertext_u64(
uint64_t *ct0_aligned, uint64_t ct0_offset, uint64_t ct0_size,
uint64_t ct0_stride);
void memref_keyswitch_lwe_u64(struct LweKeyswitchKey_u64 *keyswitch_key,
uint64_t *out_allocated, uint64_t *out_aligned,
uint64_t out_offset, uint64_t out_size,
uint64_t out_stride, uint64_t *ct0_allocated,
uint64_t *ct0_aligned, uint64_t ct0_offset,
uint64_t ct0_size, uint64_t ct0_stride);
void memref_bootstrap_lwe_u64(struct LweBootstrapKey_u64 *bootstrap_key,
uint64_t *out_allocated, uint64_t *out_aligned,
void memref_keyswitch_lwe_u64(uint64_t *out_allocated, uint64_t *out_aligned,
uint64_t out_offset, uint64_t out_size,
uint64_t out_stride, uint64_t *ct0_allocated,
uint64_t *ct0_aligned, uint64_t ct0_offset,
uint64_t ct0_size, uint64_t ct0_stride,
struct GlweCiphertext_u64 *accumulator);
mlir::concretelang::RuntimeContext *context);
void memref_bootstrap_lwe_u64(
uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
uint64_t out_size, uint64_t out_stride, uint64_t *ct0_allocated,
uint64_t *ct0_aligned, uint64_t ct0_offset, uint64_t ct0_size,
uint64_t ct0_stride, uint64_t *glwe_ct_allocated, uint64_t *glwe_ct_aligned,
uint64_t glwe_ct_offset, uint64_t glwe_ct_size, uint64_t glwe_ct_stride,
mlir::concretelang::RuntimeContext *context);
}
#endif
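The wrappers now take the runtime context as their last argument instead of an explicit key, and every `tensor<?xi64>` operand is passed using the standard memref ABI (allocated pointer, aligned pointer, offset, size, stride). A caller-side sketch against the `memref_keyswitch_lwe_u64` declaration above (buffer sizes are placeholders):

```cpp
// Sketch only; assumes <vector> and a populated mlir::concretelang::RuntimeContext `context`.
std::vector<uint64_t> out(601), in(2049);
memref_keyswitch_lwe_u64(
    /*out*/ out.data(), out.data(), /*offset=*/0, out.size(), /*stride=*/1,
    /*ct0*/ in.data(), in.data(), /*offset=*/0, in.size(), /*stride=*/1,
    &context);
```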

View File

@@ -18,6 +18,8 @@
namespace concretelang {
namespace clientlib {
KeySet::KeySet() : engine(new_engine()) {}
KeySet::~KeySet() {
for (auto it : secretKeys) {
free_lwe_secret_key_u64(it.second.second);
@@ -28,7 +30,7 @@ KeySet::~KeySet() {
for (auto it : keyswitchKeys) {
free_lwe_keyswitch_key_u64(it.second.second);
}
free_encryption_generator(encryptionRandomGenerator);
free_engine(engine);
}
outcome::checked<std::unique_ptr<KeySet>, StringError>
@@ -81,9 +83,6 @@ KeySet::setupEncryptionMaterial(ClientParameters &params, uint64_t seed_msb,
}
}
this->encryptionRandomGenerator =
allocate_encryption_generator(seed_msb, seed_lsb);
return outcome::success();
}
@@ -93,29 +92,20 @@ KeySet::generateKeysFromParams(ClientParameters &params, uint64_t seed_msb,
{
// Generate LWE secret keys
SecretRandomGenerator *generator;
generator = allocate_secret_generator(seed_msb, seed_lsb);
for (auto secretKeyParam : params.secretKeys) {
OUTCOME_TRYV(this->generateSecretKey(secretKeyParam.first,
secretKeyParam.second, generator));
OUTCOME_TRYV(
this->generateSecretKey(secretKeyParam.first, secretKeyParam.second));
}
free_secret_generator(generator);
}
// Allocate the encryption random generator
this->encryptionRandomGenerator =
allocate_encryption_generator(seed_msb, seed_lsb);
// Generate bootstrap and keyswitch keys
{
for (auto bootstrapKeyParam : params.bootstrapKeys) {
OUTCOME_TRYV(this->generateBootstrapKey(bootstrapKeyParam.first,
bootstrapKeyParam.second,
this->encryptionRandomGenerator));
bootstrapKeyParam.second));
}
for (auto keyswitchParam : params.keyswitchKeys) {
OUTCOME_TRYV(this->generateKeyswitchKey(keyswitchParam.first,
keyswitchParam.second,
this->encryptionRandomGenerator));
keyswitchParam.second));
}
}
return outcome::success();
@@ -136,12 +126,9 @@ void KeySet::setKeys(
}
outcome::checked<void, StringError>
KeySet::generateSecretKey(LweSecretKeyID id, LweSecretKeyParam param,
SecretRandomGenerator *generator) {
KeySet::generateSecretKey(LweSecretKeyID id, LweSecretKeyParam param) {
LweSecretKey_u64 *sk;
sk = allocate_lwe_secret_key_u64({param.dimension});
fill_lwe_secret_key_u64(sk, generator);
sk = generate_lwe_secret_key_u64(engine, param.dimension);
secretKeys[id] = {param, sk};
@@ -149,8 +136,7 @@ KeySet::generateSecretKey(LweSecretKeyID id, LweSecretKeyParam param,
}
outcome::checked<void, StringError>
KeySet::generateBootstrapKey(BootstrapKeyID id, BootstrapKeyParam param,
EncryptionRandomGenerator *generator) {
KeySet::generateBootstrapKey(BootstrapKeyID id, BootstrapKeyParam param) {
// Finding input and output secretKeys
auto inputSk = secretKeys.find(param.inputSecretKeyID);
if (inputSk == secretKeys.end()) {
@@ -169,32 +155,18 @@ KeySet::generateBootstrapKey(BootstrapKeyID id, BootstrapKeyParam param,
uint64_t polynomialSize = total_dimension / param.glweDimension;
bsk = allocate_lwe_bootstrap_key_u64(
{param.level}, {param.baseLog}, {param.glweDimension},
{inputSk->second.first.dimension}, {polynomialSize});
bsk = generate_lwe_bootstrap_key_u64(
engine, inputSk->second.second, outputSk->second.second, param.baseLog,
param.level, param.variance, param.glweDimension, polynomialSize);
// Store the bootstrap key
bootstrapKeys[id] = {param, bsk};
// Convert the output lwe key to glwe key
GlweSecretKey_u64 *glwe_sk;
glwe_sk =
allocate_glwe_secret_key_u64({param.glweDimension}, {polynomialSize});
fill_glwe_secret_key_with_lwe_secret_key_u64(glwe_sk,
outputSk->second.second);
// Initialize the bootstrap key
fill_lwe_bootstrap_key_u64(bsk, inputSk->second.second, glwe_sk, generator,
{param.variance});
free_glwe_secret_key_u64(glwe_sk);
return outcome::success();
}
outcome::checked<void, StringError>
KeySet::generateKeyswitchKey(KeyswitchKeyID id, KeyswitchKeyParam param,
EncryptionRandomGenerator *generator) {
KeySet::generateKeyswitchKey(KeyswitchKeyID id, KeyswitchKeyParam param) {
// Finding input and output secretKeys
auto inputSk = secretKeys.find(param.inputSecretKeyID);
if (inputSk == secretKeys.end()) {
@@ -207,17 +179,13 @@ KeySet::generateKeyswitchKey(KeyswitchKeyID id, KeyswitchKeyParam param,
// Allocate the keyswitch key
LweKeyswitchKey_u64 *ksk;
ksk = allocate_lwe_keyswitch_key_u64({param.level}, {param.baseLog},
{inputSk->second.first.dimension},
{outputSk->second.first.dimension});
ksk = generate_lwe_keyswitch_key_u64(engine, inputSk->second.second,
outputSk->second.second, param.level,
param.baseLog, param.variance);
// Store the keyswitch key
keyswitchKeys[id] = {param, ksk};
// Initialize the keyswitch key
fill_lwe_keyswitch_key_u64(ksk, inputSk->second.second,
outputSk->second.second, generator,
{param.variance});
return outcome::success();
}
@@ -255,9 +223,8 @@ KeySet::encrypt_lwe(size_t argPos, uint64_t *ciphertext, uint64_t input) {
// Encode - TODO we could check if the input value is in the right range
uint64_t plaintext =
input << (64 - (std::get<0>(inputSk).encryption->encoding.precision + 1));
encrypt_lwe_u64(std::get<2>(inputSk), ciphertext, plaintext,
encryptionRandomGenerator,
{std::get<0>(inputSk).encryption->variance});
::encrypt_lwe_u64(engine, std::get<2>(inputSk), ciphertext, plaintext,
std::get<0>(inputSk).encryption->variance);
return outcome::success();
}
@@ -271,7 +238,8 @@ KeySet::decrypt_lwe(size_t argPos, uint64_t *ciphertext, uint64_t &output) {
if (!std::get<0>(outputSk).encryption.hasValue()) {
return StringError("decrypt_lwe: the positional argument is not encrypted");
}
uint64_t plaintext = decrypt_lwe_u64(std::get<2>(outputSk), ciphertext);
uint64_t plaintext =
::decrypt_lwe_u64(engine, std::get<2>(outputSk), ciphertext);
// Decode
size_t precision = std::get<0>(outputSk).encryption->encoding.precision;
output = plaintext >> (64 - precision - 2);

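The encode/decode arithmetic around `encrypt_lwe_u64`/`decrypt_lwe_u64` is unchanged by this commit but worth spelling out: with precision p, the message is shifted just below the MSB with one padding bit, and decoding keeps one extra low bit so the caller can round away the noise. A tiny worked example:

```cpp
// Standalone illustration of the shifts used above.
#include <cassert>
#include <cstdint>

int main() {
  uint64_t input = 5;                            // message, must fit in p bits
  size_t p = 4;                                  // encoding precision
  uint64_t plaintext = input << (64 - (p + 1));  // 5 * 2^59, one padding bit kept
  uint64_t decoded = plaintext >> (64 - p - 2);  // == 2 * input when there is no noise
  assert((decoded >> 1) == input);
  return 0;
}
```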
View File

@@ -45,10 +45,9 @@ PublicArguments::~PublicArguments() {
if (!clearRuntimeContext) {
return;
}
for (auto bsk_entry : runtimeContext.bsk) {
free_lwe_bootstrap_key_u64(bsk_entry.second);
if (runtimeContext.bsk != nullptr) {
free_lwe_bootstrap_key_u64(runtimeContext.bsk);
}
runtimeContext.bsk.clear();
if (runtimeContext.ksk != nullptr) {
free_lwe_keyswitch_key_u64(runtimeContext.ksk);
runtimeContext.ksk = nullptr;

View File

@@ -99,7 +99,7 @@ std::istream &operator>>(std::istream &istream, ClientParameters &params) {
std::istream &operator>>(std::istream &istream,
RuntimeContext &runtimeContext) {
istream >> runtimeContext.ksk;
istream >> runtimeContext.bsk[RuntimeContext::BASE_CONTEXT_BSK];
istream >> runtimeContext.bsk;
assert(istream.good());
return istream;
}
@@ -107,7 +107,7 @@ std::istream &operator>>(std::istream &istream,
std::ostream &operator<<(std::ostream &ostream,
const RuntimeContext &runtimeContext) {
ostream << runtimeContext.ksk;
ostream << runtimeContext.bsk.at(RuntimeContext::BASE_CONTEXT_BSK);
ostream << runtimeContext.bsk;
assert(ostream.good());
return ostream;
}
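With a single bootstrap key, (de)serializing a RuntimeContext reduces to streaming `ksk` then `bsk` in order. A round-trip sketch, assuming the stream operators defined above are visible at the call site:

```cpp
// Sketch only.
std::stringstream buffer;
buffer << serverContext;                        // writes ksk, then bsk
mlir::concretelang::RuntimeContext restored;
buffer >> restored;                             // reads them back in the same order
```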

View File

@@ -17,6 +17,7 @@
#include "concretelang/Dialect/Concrete/IR/ConcreteDialect.h"
#include "concretelang/Dialect/Concrete/IR/ConcreteOps.h"
#include "concretelang/Dialect/Concrete/IR/ConcreteTypes.h"
#include "concretelang/Support/Constants.h"
namespace {
class BConcreteToBConcreteCAPITypeConverter : public mlir::TypeConverter {
@@ -72,9 +73,8 @@ inline mlir::Type getGenericLweBufferType(mlir::MLIRContext *context) {
return mlir::RankedTensorType::get({-1}, mlir::IntegerType::get(context, 64));
}
inline mlir::concretelang::Concrete::GlweCiphertextType
getGenericGlweCiphertextType(mlir::MLIRContext *context) {
return mlir::concretelang::Concrete::GlweCiphertextType::get(context);
inline mlir::Type getGenericGlweBufferType(mlir::MLIRContext *context) {
return mlir::RankedTensorType::get({-1}, mlir::IntegerType::get(context, 64));
}
inline mlir::Type getGenericPlaintextType(mlir::MLIRContext *context) {
@@ -114,10 +114,6 @@ mlir::LogicalResult insertForwardDeclarations(mlir::Operation *op,
auto lweBufferType = getGenericLweBufferType(rewriter.getContext());
auto plaintextType = getGenericPlaintextType(rewriter.getContext());
auto cleartextType = getGenericCleartextType(rewriter.getContext());
auto glweCiphertextType = getGenericGlweCiphertextType(rewriter.getContext());
auto plaintextListType = getGenericPlaintextListType(rewriter.getContext());
auto foreignPlaintextList =
getGenericForeignPlaintextListType(rewriter.getContext());
auto keySwitchKeyType = getGenericLweKeySwitchKeyType(rewriter.getContext());
auto bootstrapKeyType = getGenericLweBootstrapKeyType(rewriter.getContext());
auto contextType =
@@ -134,7 +130,7 @@ mlir::LogicalResult insertForwardDeclarations(mlir::Operation *op,
return mlir::failure();
}
}
// Insert forward declaration of the add_plaintext_lwe_ciphertext_u64 function
// Insert forward declaration of the add_plaintext_lwe_ciphertext function
{
auto funcType = mlir::FunctionType::get(
rewriter.getContext(), {lweBufferType, lweBufferType, plaintextType},
@@ -145,7 +141,7 @@ mlir::LogicalResult insertForwardDeclarations(mlir::Operation *op,
return mlir::failure();
}
}
// Insert forward declaration of the mul_cleartext_lwe_ciphertext_u64 function
// Insert forward declaration of the mul_cleartext_lwe_ciphertext function
{
auto funcType = mlir::FunctionType::get(
rewriter.getContext(), {lweBufferType, lweBufferType, cleartextType},
@@ -156,7 +152,7 @@ mlir::LogicalResult insertForwardDeclarations(mlir::Operation *op,
return mlir::failure();
}
}
// Insert forward declaration of the negate_lwe_ciphertext_u64 function
// Insert forward declaration of the negate_lwe_ciphertext function
{
auto funcType = mlir::FunctionType::get(rewriter.getContext(),
{lweBufferType, lweBufferType}, {});
@@ -169,8 +165,7 @@ mlir::LogicalResult insertForwardDeclarations(mlir::Operation *op,
// Insert forward declaration of the memref_keyswitch_lwe_u64 function
{
auto funcType = mlir::FunctionType::get(
rewriter.getContext(), {keySwitchKeyType, lweBufferType, lweBufferType},
{});
rewriter.getContext(), {lweBufferType, lweBufferType, contextType}, {});
if (insertForwardDeclaration(op, rewriter, "memref_keyswitch_lwe_u64",
funcType)
.failed()) {
@@ -181,40 +176,40 @@ mlir::LogicalResult insertForwardDeclarations(mlir::Operation *op,
{
auto funcType = mlir::FunctionType::get(
rewriter.getContext(),
{bootstrapKeyType, lweBufferType, lweBufferType, glweCiphertextType},
{});
{lweBufferType, lweBufferType, lweBufferType, contextType}, {});
if (insertForwardDeclaration(op, rewriter, "memref_bootstrap_lwe_u64",
funcType)
.failed()) {
return mlir::failure();
}
}
// Insert forward declaration of the fill_plaintext_list function
{
auto funcType = mlir::FunctionType::get(
rewriter.getContext(), {plaintextListType, foreignPlaintextList}, {});
if (insertForwardDeclaration(
op, rewriter, "fill_plaintext_list_with_expansion_u64", funcType)
.failed()) {
return mlir::failure();
}
}
// Insert forward declaration of the add_plaintext_list_glwe function
// Insert forward declaration of the expand_lut_in_trivial_glwe_ct function
{
auto funcType = mlir::FunctionType::get(
rewriter.getContext(),
{glweCiphertextType, glweCiphertextType, plaintextListType}, {});
{
getGenericGlweBufferType(rewriter.getContext()),
rewriter.getI32Type(),
rewriter.getI32Type(),
rewriter.getI32Type(),
mlir::RankedTensorType::get(
{-1}, mlir::IntegerType::get(rewriter.getContext(), 64)),
},
{});
if (insertForwardDeclaration(
op, rewriter, "add_plaintext_list_glwe_ciphertext_u64", funcType)
op, rewriter, "memref_expand_lut_in_trivial_glwe_ct_u64", funcType)
.failed()) {
return mlir::failure();
}
}
// Insert forward declaration of the getGlobalKeyswitchKey function
{
auto funcType = mlir::FunctionType::get(rewriter.getContext(),
{contextType}, {keySwitchKeyType});
if (insertForwardDeclaration(op, rewriter, "get_keyswitch_key", funcType)
if (insertForwardDeclaration(op, rewriter, "get_keyswitch_key_u64",
funcType)
.failed()) {
return mlir::failure();
}
@@ -223,7 +218,8 @@ mlir::LogicalResult insertForwardDeclarations(mlir::Operation *op,
{
auto funcType = mlir::FunctionType::get(rewriter.getContext(),
{contextType}, {bootstrapKeyType});
if (insertForwardDeclaration(op, rewriter, "get_bootstrap_key", funcType)
if (insertForwardDeclaration(op, rewriter, "get_bootstrap_key_u64",
funcType)
.failed()) {
return mlir::failure();
}
@@ -233,15 +229,15 @@ mlir::LogicalResult insertForwardDeclarations(mlir::Operation *op,
// For all operands `tensor<Axi64>` replace with
// `%casted = tensor.cast %op : tensor<Axi64> to tensor<?xi64>`
template <typename Op>
mlir::SmallVector<mlir::Value>
getCastedTensorOperands(Op op, mlir::PatternRewriter &rewriter) {
getCastedTensor(mlir::Location loc, mlir::Operation::operand_range operands,
mlir::PatternRewriter &rewriter) {
mlir::SmallVector<mlir::Value, 4> newOperands{};
for (mlir::Value operand : op->getOperands()) {
for (mlir::Value operand : operands) {
mlir::Type operandType = operand.getType();
if (operandType.isa<mlir::RankedTensorType>()) {
mlir::Value castedOp = rewriter.create<mlir::tensor::CastOp>(
op.getLoc(), getGenericLweBufferType(rewriter.getContext()), operand);
loc, getGenericLweBufferType(rewriter.getContext()), operand);
newOperands.push_back(castedOp);
} else {
newOperands.push_back(operand);
@@ -250,6 +246,14 @@ getCastedTensorOperands(Op op, mlir::PatternRewriter &rewriter) {
return std::move(newOperands);
}
// For all operands `tensor<Axi64>` replace with
// `%casted = tensor.cast %op : tensor<Axi64> to tensor<?xi64>`
template <typename Op>
mlir::SmallVector<mlir::Value>
getCastedTensorOperands(Op op, mlir::PatternRewriter &rewriter) {
return getCastedTensor(op->getLoc(), op->getOperands(), rewriter);
}
/// BConcreteOpToConcreteCAPICallPattern<Op> matches the `BConcreteOp`
/// operation and replaces it with a call to `funcName`; `funcName` should be an
/// external function that is linked in later. It inserts the forward declaration
@@ -379,15 +383,12 @@ struct BConcreteKeySwitchLweOpPattern
matchAndRewrite(mlir::concretelang::BConcrete::KeySwitchLweBufferOp op,
mlir::PatternRewriter &rewriter) const override {
mlir::CallOp kskOp = rewriter.create<mlir::CallOp>(
op.getLoc(), "get_keyswitch_key",
getGenericLweKeySwitchKeyType(rewriter.getContext()),
mlir::SmallVector<mlir::Value>{getContextArgument(op)});
mlir::SmallVector<mlir::Value, 3> operands{kskOp.getResult(0)};
mlir::SmallVector<mlir::Value, 3> operands{};
operands.append(
getCastedTensorOperands<
mlir::concretelang::BConcrete::KeySwitchLweBufferOp>(op, rewriter));
operands.push_back(getContextArgument(op));
rewriter.replaceOpWithNewOp<mlir::CallOp>(op, "memref_keyswitch_lwe_u64",
mlir::TypeRange({}), operands);
return mlir::success();
@@ -422,22 +423,83 @@ struct BConcreteBootstrapLweOpPattern
mlir::LogicalResult
matchAndRewrite(mlir::concretelang::BConcrete::BootstrapLweBufferOp op,
mlir::PatternRewriter &rewriter) const override {
mlir::SmallVector<mlir::Value> getkskOperands{};
mlir::CallOp bskOp = rewriter.create<mlir::CallOp>(
op.getLoc(), "get_bootstrap_key",
getGenericLweBootstrapKeyType(rewriter.getContext()),
mlir::SmallVector<mlir::Value>{getContextArgument(op)});
mlir::SmallVector<mlir::Value, 4> operands{bskOp.getResult(0)};
mlir::SmallVector<mlir::Value, 4> operands{};
operands.append(
getCastedTensorOperands<
mlir::concretelang::BConcrete::BootstrapLweBufferOp>(op, rewriter));
operands.push_back(getContextArgument(op));
rewriter.replaceOpWithNewOp<mlir::CallOp>(op, "memref_bootstrap_lwe_u64",
mlir::TypeRange({}), operands);
return mlir::success();
};
};
// Rewrite pattern that rewrites every
// ```
// "BConcrete.fill_glwe_from_table"(%glwe, %lut) {glweDimension=1,
// polynomialSize=2048, outPrecision=3} :
// (tensor<4096xi64>, tensor<32xi64>) -> ()
// ```
//
// to
//
// ```
// %glweDim = arith.constant 1 : i32
// %polySize = arith.constant 2048 : i32
// %outPrecision = arith.constant 3 : i32
// %glwe_ = tensor.cast %glwe : tensor<4096xi64> to tensor<?xi64>
// %lut_ = tensor.cast %lut : tensor<32xi64> to tensor<?xi64>
// call @memref_expand_lut_in_trivial_glwe_ct_u64(%glwe_, %polySize, %glweDim,
// %outPrecision, %lut_) :
// (tensor<?xi64>, i32, i32, i32, tensor<?xi64>) -> ()
// ```
struct BConcreteGlweFromTableOpPattern
: public mlir::OpRewritePattern<
mlir::concretelang::BConcrete::FillGlweFromTable> {
BConcreteGlweFromTableOpPattern(
mlir::MLIRContext *context,
mlir::PatternBenefit benefit =
mlir::concretelang::DEFAULT_PATTERN_BENEFIT)
: mlir::OpRewritePattern<
mlir::concretelang::BConcrete::FillGlweFromTable>(context,
benefit) {}
mlir::LogicalResult
matchAndRewrite(mlir::concretelang::BConcrete::FillGlweFromTable op,
mlir::PatternRewriter &rewriter) const override {
BConcreteToBConcreteCAPITypeConverter typeConverter;
// %glweDim = arith.constant 1 : i32
// %polySize = arith.constant 2048 : i32
// %outPrecision = arith.constant 3 : i32
auto castedOp = getCastedTensorOperands(op, rewriter);
auto polySizeOp = rewriter.create<mlir::arith::ConstantOp>(
op.getLoc(), rewriter.getI32IntegerAttr(op.polynomialSize()));
auto glweDimensionOp = rewriter.create<mlir::arith::ConstantOp>(
op.getLoc(), rewriter.getI32IntegerAttr(op.glweDimension()));
auto outPrecisionOp = rewriter.create<mlir::arith::ConstantOp>(
op.getLoc(), rewriter.getI32IntegerAttr(op.outPrecision()));
mlir::SmallVector<mlir::Value> newOperands{
castedOp[0], polySizeOp, glweDimensionOp, outPrecisionOp, castedOp[1]};
// getCastedTensor(op.getLoc(), newOperands, rewriter);
// perform operands conversion
// %glwe_ = tensor.cast %glwe : tensor<4096xi64> to tensor<?xi64>
// %lut_ = tensor.cast %lut : tensor<32xi64> to tensor<?xi64>
// call @memref_expand_lut_in_trivial_glwe_ct_u64(%glwe_, %polySize, %glweDim,
// %outPrecision, %lut_) :
// (tensor<?xi64>, i32, i32, i32, tensor<?xi64>) -> ()
rewriter.replaceOpWithNewOp<mlir::CallOp>(
op, "memref_expand_lut_in_trivial_glwe_ct_u64",
mlir::SmallVector<mlir::Type>{}, newOperands);
return mlir::success();
};
};
/// Populate the RewritePatternSet with all patterns that rewrite Concrete
/// operators to the corresponding function call to the `Concrete C API`.
void populateBConcreteToBConcreteCAPICall(mlir::RewritePatternSet &patterns) {
@@ -455,9 +517,9 @@ void populateBConcreteToBConcreteCAPICall(mlir::RewritePatternSet &patterns) {
patterns.getContext(), "memref_negate_lwe_ciphertext_u64");
patterns.add<ConcreteEncodeIntOpPattern>(patterns.getContext());
patterns.add<ConcreteIntToCleartextOpPattern>(patterns.getContext());
// patterns.add<ConcreteZeroOpPattern>(patterns.getContext());
patterns.add<BConcreteKeySwitchLweOpPattern>(patterns.getContext());
patterns.add<BConcreteBootstrapLweOpPattern>(patterns.getContext());
patterns.add<BConcreteGlweFromTableOpPattern>(patterns.getContext());
}
struct AddRuntimeContextToFuncOpPattern

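The net effect on the lowered IR is that keys are no longer materialized as SSA values: the key-fetch calls disappear, the buffer operands are cast to dynamic tensors, and the context is appended as the last call operand so the C wrapper fetches the key (and engine) itself. A rough before/after for the keyswitch case, written as comments in the style this file already uses for IR examples (dimensions and type spellings are illustrative):

```cpp
// Before this commit (key materialized in IR, types elided):
//   %ksk = call @get_keyswitch_key(%ctx)
//   call @memref_keyswitch_lwe_u64(%ksk, %out_, %in_)
//
// After this commit (context passed through, key fetched inside the wrapper):
//   %out_ = tensor.cast %out : tensor<601xi64> to tensor<?xi64>
//   %in_  = tensor.cast %in  : tensor<2049xi64> to tensor<?xi64>
//   call @memref_keyswitch_lwe_u64(%out_, %in_, %ctx)
```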
View File

@@ -3,7 +3,5 @@ add_subdirectory(TFHEGlobalParametrization)
add_subdirectory(TFHEToConcrete)
add_subdirectory(FHETensorOpsToLinalg)
add_subdirectory(ConcreteToBConcrete)
add_subdirectory(ConcreteToConcreteCAPI)
add_subdirectory(BConcreteToBConcreteCAPI)
add_subdirectory(MLIRLowerableDialectsToLLVM)
add_subdirectory(ConcreteUnparametrize)

View File

@@ -37,10 +37,19 @@ public:
ConcreteToBConcreteTypeConverter() {
addConversion([](mlir::Type type) { return type; });
addConversion([&](mlir::concretelang::Concrete::LweCiphertextType type) {
assert(type.getDimension() != -1);
return mlir::RankedTensorType::get(
{type.getDimension() + 1},
mlir::IntegerType::get(type.getContext(), 64));
});
addConversion([&](mlir::concretelang::Concrete::GlweCiphertextType type) {
assert(type.getGlweDimension() != -1);
assert(type.getPolynomialSize() != -1);
return mlir::RankedTensorType::get(
{type.getPolynomialSize() * (type.getGlweDimension() + 1)},
mlir::IntegerType::get(type.getContext(), 64));
});
addConversion([&](mlir::RankedTensorType type) {
auto lwe = type.getElementType()
.dyn_cast_or_null<
@@ -48,6 +57,7 @@ public:
if (lwe == nullptr) {
return (mlir::Type)(type);
}
assert(lwe.getDimension() != -1);
mlir::SmallVector<int64_t> newShape;
newShape.reserve(type.getShape().size() + 1);
newShape.append(type.getShape().begin(), type.getShape().end());
@@ -63,6 +73,7 @@ public:
if (lwe == nullptr) {
return (mlir::Type)(type);
}
assert(lwe.getDimension() != -1);
mlir::SmallVector<int64_t> newShape;
newShape.reserve(type.getShape().size() + 1);
newShape.append(type.getShape().begin(), type.getShape().end());
@@ -177,6 +188,65 @@ struct LowToBConcrete : public mlir::OpRewritePattern<ConcreteOp> {
};
};
// This rewrite pattern transforms any instance of
// `Concrete.glwe_from_table` operators.
//
// Example:
//
// ```mlir
// %0 = "Concrete.glwe_from_table"(%tlu)
// : (tensor<$Dxi64>) ->
// !Concrete.glwe_ciphertext<$polynomialSize,$glweDimension,$p>
// ```
//
// with $D = 2^$p
//
// becomes:
//
// ```mlir
// %0 = linalg.init_tensor [polynomialSize*(glweDimension+1)]
// : tensor<polynomialSize*(glweDimension+1) x i64>
// "BConcrete.fill_glwe_from_table"(%0, %tlu) {polynomialSize, glweDimension, outPrecision}
// : (tensor<polynomialSize*(glweDimension+1) x i64>, tensor<$Dxi64>) -> ()
// ```
struct GlweFromTablePattern : public mlir::OpRewritePattern<
mlir::concretelang::Concrete::GlweFromTable> {
GlweFromTablePattern(::mlir::MLIRContext *context,
mlir::PatternBenefit benefit = 1)
: ::mlir::OpRewritePattern<mlir::concretelang::Concrete::GlweFromTable>(
context, benefit) {}
::mlir::LogicalResult
matchAndRewrite(mlir::concretelang::Concrete::GlweFromTable op,
::mlir::PatternRewriter &rewriter) const override {
ConcreteToBConcreteTypeConverter converter;
auto resultTy =
op.result()
.getType()
.cast<mlir::concretelang::Concrete::GlweCiphertextType>();
auto newResultTy =
converter.convertType(resultTy).cast<mlir::RankedTensorType>();
// %0 = linalg.init_tensor [polynomialSize*(glweDimension+1)]
// : tensor<polynomialSize*(glweDimension+1), i64>
mlir::Value init = rewriter.replaceOpWithNewOp<mlir::linalg::InitTensorOp>(
op, newResultTy.getShape(), newResultTy.getElementType());
// "BConcrete.fill_glwe_from_table" : (%0, polynomialSize, glweDimension,
// %tlu)
// polynomialSize*(glweDimension+1)
auto polySize = resultTy.getPolynomialSize();
auto glweDimension = resultTy.getGlweDimension();
auto outPrecision = resultTy.getP();
rewriter.create<mlir::concretelang::BConcrete::FillGlweFromTable>(
op.getLoc(), init, polySize, glweDimension, outPrecision, op.table());
return ::mlir::success();
};
};
// This rewrite pattern transforms any instance of
// `tensor.extract_slice` operators that operate on tensors of lwe ciphertexts.
//
@@ -827,7 +897,6 @@ void ConcreteToBConcretePass::runOnOperation() {
// ciphertexts)
target.addIllegalDialect<mlir::concretelang::Concrete::ConcreteDialect>();
target.addLegalOp<mlir::concretelang::Concrete::EncodeIntOp>();
target.addLegalOp<mlir::concretelang::Concrete::GlweFromTable>();
target.addLegalOp<mlir::concretelang::Concrete::IntToCleartextOp>();
// Add patterns to convert the zero ops to tensor.generate
@@ -860,7 +929,10 @@ void ConcreteToBConcretePass::runOnOperation() {
mlir::concretelang::BConcrete::BootstrapLweBufferOp>>(
&getContext());
// Add patterns to rewrite tensor operators that work on encrypted tensors
patterns.insert<GlweFromTablePattern>(&getContext());
// Add patterns to rewrite tensor operators that work on encrypted
// tensors
patterns.insert<ExtractSliceOpPattern, ExtractOpPattern,
InsertSliceOpPattern, FromElementsOpPattern>(&getContext());
target.addDynamicallyLegalOp<

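For reference, the updated type converter flattens the now-parametrized ciphertext types into plain i64 buffers; the GLWE case is what makes `GlweFromTablePattern` above possible. A small sketch of the mapping, with the dimensions used in this file's examples (illustration only):

```cpp
// Sketch only; usable inside this file, where the converter class is defined.
ConcreteToBConcreteTypeConverter converter;
auto lweTy = mlir::concretelang::Concrete::LweCiphertextType::get(
    ctx, /*dimension=*/600, /*p=*/4);
auto glweTy = mlir::concretelang::Concrete::GlweCiphertextType::get(
    ctx, /*polynomialSize=*/2048, /*glweDimension=*/1, /*p=*/4);
// lwe<600,4>     -> tensor<601xi64>     (dimension + 1)
// glwe<2048,1,4> -> tensor<4096xi64>    (polynomialSize * (glweDimension + 1))
auto lweBuf = converter.convertType(lweTy).cast<mlir::RankedTensorType>();
auto glweBuf = converter.convertType(glweTy).cast<mlir::RankedTensorType>();
assert(lweBuf.getShape()[0] == 601);
assert(glweBuf.getShape()[0] == 4096);
```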
View File

@@ -1,16 +0,0 @@
add_mlir_dialect_library(ConcreteToConcreteCAPI
ConcreteToConcreteCAPI.cpp
ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/concretelang/Dialect/FHE
DEPENDS
ConcreteDialect
ConcretelangConversionPassIncGen
LINK_LIBS PUBLIC
MLIRIR
MLIRTransforms
)
target_link_libraries(ConcreteToConcreteCAPI PUBLIC MLIRIR)

View File

@@ -1,859 +0,0 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt
// for license information.
#include "mlir//IR/BuiltinTypes.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Transforms/DialectConversion.h"
#include "concretelang/Conversion/Passes.h"
#include "concretelang/Dialect/Concrete/IR/ConcreteDialect.h"
#include "concretelang/Dialect/Concrete/IR/ConcreteOps.h"
#include "concretelang/Dialect/Concrete/IR/ConcreteTypes.h"
#include "concretelang/Support/Constants.h"
class ConcreteToConcreteCAPITypeConverter : public mlir::TypeConverter {
public:
ConcreteToConcreteCAPITypeConverter() {
addConversion([](mlir::Type type) { return type; });
addConversion([&](mlir::concretelang::Concrete::PlaintextType type) {
return mlir::IntegerType::get(type.getContext(), 64);
});
addConversion([&](mlir::concretelang::Concrete::CleartextType type) {
return mlir::IntegerType::get(type.getContext(), 64);
});
}
};
mlir::LogicalResult insertForwardDeclaration(mlir::Operation *op,
mlir::RewriterBase &rewriter,
llvm::StringRef funcName,
mlir::FunctionType funcType) {
// Looking for the `funcName` Operation
auto module = mlir::SymbolTable::getNearestSymbolTable(op);
auto opFunc = mlir::dyn_cast_or_null<mlir::SymbolOpInterface>(
mlir::SymbolTable::lookupSymbolIn(module, funcName));
if (!opFunc) {
// Insert the forward declaration of the funcName
mlir::OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(&module->getRegion(0).front());
opFunc = rewriter.create<mlir::FuncOp>(rewriter.getUnknownLoc(), funcName,
funcType);
opFunc.setPrivate();
} else {
// Check if the `funcName` is well a private function
if (!opFunc.isPrivate()) {
op->emitError() << "the function \"" << funcName
<< "\" conflicts with the concrete C API, please rename";
return mlir::failure();
}
}
assert(mlir::SymbolTable::lookupSymbolIn(module, funcName)
->template hasTrait<mlir::OpTrait::FunctionLike>());
return mlir::success();
}
// Set of functions to generate generic types.
// Generic types are used to add forward declarations without a specific type.
// For example, we may need to add LWE ciphertext of different dimensions, or
// allocate them. All the calls to the C API should be done using this generic
// types, and casting should then be performed back to the appropriate type.
inline mlir::concretelang::Concrete::LweCiphertextType
getGenericLweCiphertextType(mlir::MLIRContext *context) {
return mlir::concretelang::Concrete::LweCiphertextType::get(context, -1, -1);
}
inline mlir::concretelang::Concrete::GlweCiphertextType
getGenericGlweCiphertextType(mlir::MLIRContext *context) {
return mlir::concretelang::Concrete::GlweCiphertextType::get(context);
}
inline mlir::concretelang::Concrete::PlaintextType
getGenericPlaintextType(mlir::MLIRContext *context) {
return mlir::concretelang::Concrete::PlaintextType::get(context, -1);
}
inline mlir::concretelang::Concrete::PlaintextListType
getGenericPlaintextListType(mlir::MLIRContext *context) {
return mlir::concretelang::Concrete::PlaintextListType::get(context);
}
inline mlir::concretelang::Concrete::ForeignPlaintextListType
getGenericForeignPlaintextListType(mlir::MLIRContext *context) {
return mlir::concretelang::Concrete::ForeignPlaintextListType::get(context);
}
inline mlir::concretelang::Concrete::CleartextType
getGenericCleartextType(mlir::MLIRContext *context) {
return mlir::concretelang::Concrete::CleartextType::get(context, -1);
}
inline mlir::concretelang::Concrete::LweBootstrapKeyType
getGenericLweBootstrapKeyType(mlir::MLIRContext *context) {
return mlir::concretelang::Concrete::LweBootstrapKeyType::get(context);
}
inline mlir::concretelang::Concrete::LweKeySwitchKeyType
getGenericLweKeySwitchKeyType(mlir::MLIRContext *context) {
return mlir::concretelang::Concrete::LweKeySwitchKeyType::get(context);
}
// Get the generic version of the type.
// Useful when iterating over a set of types.
mlir::Type getGenericType(mlir::Type baseType) {
if (baseType.isa<mlir::concretelang::Concrete::LweCiphertextType>()) {
return getGenericLweCiphertextType(baseType.getContext());
}
if (baseType.isa<mlir::concretelang::Concrete::PlaintextType>()) {
return getGenericPlaintextType(baseType.getContext());
}
if (baseType.isa<mlir::concretelang::Concrete::CleartextType>()) {
return getGenericCleartextType(baseType.getContext());
}
return baseType;
}
// Insert all forward declarations needed for the pass.
// Should generalize input and output types for all declarations, and the
// pattern using them would be responsible for casting them to the appropriate
// type.
mlir::LogicalResult insertForwardDeclarations(mlir::Operation *op,
mlir::IRRewriter &rewriter) {
auto genericLweCiphertextType =
getGenericLweCiphertextType(rewriter.getContext());
auto genericGlweCiphertextType =
getGenericGlweCiphertextType(rewriter.getContext());
auto genericPlaintextType = getGenericPlaintextType(rewriter.getContext());
auto genericPlaintextListType =
getGenericPlaintextListType(rewriter.getContext());
auto genericForeignPlaintextList =
getGenericForeignPlaintextListType(rewriter.getContext());
auto genericCleartextType = getGenericCleartextType(rewriter.getContext());
auto genericBSKType = getGenericLweBootstrapKeyType(rewriter.getContext());
auto genericKSKType = getGenericLweKeySwitchKeyType(rewriter.getContext());
auto contextType =
mlir::concretelang::Concrete::ContextType::get(rewriter.getContext());
// Insert forward declaration of allocate lwe ciphertext
{
auto funcType = mlir::FunctionType::get(rewriter.getContext(),
{
rewriter.getIndexType(),
},
{genericLweCiphertextType});
if (insertForwardDeclaration(op, rewriter, "allocate_lwe_ciphertext_u64",
funcType)
.failed()) {
return mlir::failure();
}
}
// Insert forward declaration of the add_lwe_ciphertexts function
{
auto funcType = mlir::FunctionType::get(rewriter.getContext(),
{
genericLweCiphertextType,
genericLweCiphertextType,
genericLweCiphertextType,
},
{});
if (insertForwardDeclaration(op, rewriter, "add_lwe_ciphertexts_u64",
funcType)
.failed()) {
return mlir::failure();
}
}
// Insert forward declaration of the add_plaintext_lwe_ciphertext_u64 function
{
auto funcType = mlir::FunctionType::get(rewriter.getContext(),
{
genericLweCiphertextType,
genericLweCiphertextType,
genericPlaintextType,
},
{});
if (insertForwardDeclaration(op, rewriter,
"add_plaintext_lwe_ciphertext_u64", funcType)
.failed()) {
return mlir::failure();
}
}
// Insert forward declaration of the mul_cleartext_lwe_ciphertext_u64 function
{
auto funcType = mlir::FunctionType::get(rewriter.getContext(),
{
genericLweCiphertextType,
genericLweCiphertextType,
genericCleartextType,
},
{});
if (insertForwardDeclaration(op, rewriter,
"mul_cleartext_lwe_ciphertext_u64", funcType)
.failed()) {
return mlir::failure();
}
}
// Insert forward declaration of the negate_lwe_ciphertext_u64 function
{
auto funcType = mlir::FunctionType::get(
rewriter.getContext(),
{genericLweCiphertextType, genericLweCiphertextType}, {});
if (insertForwardDeclaration(op, rewriter, "negate_lwe_ciphertext_u64",
funcType)
.failed()) {
return mlir::failure();
}
}
// Insert forward declaration of the getBsk function
{
auto funcType = mlir::FunctionType::get(rewriter.getContext(),
{contextType}, {genericBSKType});
if (insertForwardDeclaration(op, rewriter, "get_bootstrap_key", funcType)
.failed()) {
return mlir::failure();
}
}
// Insert forward declaration of the bootstrap function
{
auto funcType = mlir::FunctionType::get(rewriter.getContext(),
{
genericBSKType,
genericLweCiphertextType,
genericLweCiphertextType,
genericGlweCiphertextType,
},
{});
if (insertForwardDeclaration(op, rewriter, "bootstrap_lwe_u64", funcType)
.failed()) {
return mlir::failure();
}
}
// Insert forward declaration of the getKsk function
{
auto funcType = mlir::FunctionType::get(rewriter.getContext(),
{contextType}, {genericKSKType});
if (insertForwardDeclaration(op, rewriter, "get_keyswitch_key", funcType)
.failed()) {
return mlir::failure();
}
}
// Insert forward declaration of the keyswitch function
{
auto funcType = mlir::FunctionType::get(rewriter.getContext(),
{
// ksk
genericKSKType,
// output ct
genericLweCiphertextType,
// input ct
genericLweCiphertextType,
},
{});
if (insertForwardDeclaration(op, rewriter, "keyswitch_lwe_u64", funcType)
.failed()) {
return mlir::failure();
}
}
// Insert forward declaration of the alloc_glwe function
{
auto funcType = mlir::FunctionType::get(rewriter.getContext(),
{
rewriter.getI32Type(),
rewriter.getI32Type(),
},
{genericGlweCiphertextType});
if (insertForwardDeclaration(op, rewriter, "allocate_glwe_ciphertext_u64",
funcType)
.failed()) {
return mlir::failure();
}
}
// Insert forward declaration of the alloc_plaintext_list function
{
auto funcType =
mlir::FunctionType::get(rewriter.getContext(), {rewriter.getI32Type()},
{genericPlaintextListType});
if (insertForwardDeclaration(op, rewriter, "allocate_plaintext_list_u64",
funcType)
.failed()) {
return mlir::failure();
}
}
// Insert forward declaration of the fill_plaintext_list function
{
auto funcType = mlir::FunctionType::get(
rewriter.getContext(),
{genericPlaintextListType, genericForeignPlaintextList}, {});
if (insertForwardDeclaration(
op, rewriter, "fill_plaintext_list_with_expansion_u64", funcType)
.failed()) {
return mlir::failure();
}
}
// Insert forward declaration of the add_plaintext_list_glwe function
{
auto funcType = mlir::FunctionType::get(rewriter.getContext(),
{genericGlweCiphertextType,
genericGlweCiphertextType,
genericPlaintextListType},
{});
if (insertForwardDeclaration(
op, rewriter, "add_plaintext_list_glwe_ciphertext_u64", funcType)
.failed()) {
return mlir::failure();
}
}
return mlir::success();
}
/// ConcreteOpToConcreteCAPICallPattern<Op> matches the `Op` operation and
/// replaces it with a call to `funcName`; `funcName` should be an external
/// function that is linked in later. It inserts the forward declaration of the
/// private `funcName` if it is not already in the symbol table.
/// The C signature of the function should be `void funcName(int *err, out,
/// arg0, arg1)`, the pattern rewrite:
/// ```
/// out = op(arg0, arg1)
/// ```
/// to
/// ```
/// err = arith.constant 0 : i64
/// call_op(err, out, arg0, arg1);
/// ```
template <typename Op>
struct ConcreteOpToConcreteCAPICallPattern : public mlir::OpRewritePattern<Op> {
ConcreteOpToConcreteCAPICallPattern(
mlir::MLIRContext *context, mlir::StringRef funcName,
mlir::StringRef allocName,
mlir::PatternBenefit benefit =
mlir::concretelang::DEFAULT_PATTERN_BENEFIT)
: mlir::OpRewritePattern<Op>(context, benefit), funcName(funcName),
allocName(allocName) {}
mlir::LogicalResult
matchAndRewrite(Op op, mlir::PatternRewriter &rewriter) const override {
ConcreteToConcreteCAPITypeConverter typeConverter;
mlir::Type resultType = op->getResultTypes().front();
auto lweResultType =
resultType.cast<mlir::concretelang::Concrete::LweCiphertextType>();
// Replace the operation with a call to the `funcName`
{
// Get the size from the dimension
int64_t lweDimension = lweResultType.getDimension();
mlir::Value lweDimensionOp = rewriter.create<mlir::arith::ConstantOp>(
op.getLoc(), rewriter.getIndexAttr(lweDimension));
// Add the call to the allocation
mlir::SmallVector<mlir::Value> allocOperands{lweDimensionOp};
auto allocGeneric = rewriter.create<mlir::CallOp>(
op.getLoc(), allocName,
getGenericLweCiphertextType(rewriter.getContext()), allocOperands);
// Construct operands for the operation.
// errOp doesn't need to be casted to something generic, allocGeneric
// already is. All the rest will be converted if needed
mlir::SmallVector<mlir::Value, 4> newOperands{allocGeneric.getResult(0)};
for (mlir::Value operand : op->getOperands()) {
mlir::Type operandType = operand.getType();
mlir::Type castedType = getGenericType(operandType);
if (castedType == operandType) {
// Type didn't change, no need for cast
newOperands.push_back(operand);
} else {
// Type changed, need to cast to the generic one
auto castedOperand = rewriter
.create<mlir::UnrealizedConversionCastOp>(
op.getLoc(), castedType, operand)
.getResult(0);
newOperands.push_back(castedOperand);
}
}
// The operations called here are known to be inplace, and no need for a
// return type.
rewriter.create<mlir::CallOp>(op.getLoc(), funcName, mlir::TypeRange{},
newOperands);
// cast result value to the appropriate type
rewriter.replaceOpWithNewOp<mlir::UnrealizedConversionCastOp>(
op, op.getType(), allocGeneric.getResult(0));
}
return mlir::success();
};
private:
std::string funcName;
std::string allocName;
};
struct ConcreteZeroOpPattern
: public mlir::OpRewritePattern<mlir::concretelang::Concrete::ZeroLWEOp> {
ConcreteZeroOpPattern(mlir::MLIRContext *context,
mlir::PatternBenefit benefit =
mlir::concretelang::DEFAULT_PATTERN_BENEFIT)
: mlir::OpRewritePattern<mlir::concretelang::Concrete::ZeroLWEOp>(
context, benefit) {}
mlir::LogicalResult
matchAndRewrite(mlir::concretelang::Concrete::ZeroLWEOp op,
mlir::PatternRewriter &rewriter) const override {
mlir::Type resultType = op->getResultTypes().front();
auto lweResultType =
resultType.cast<mlir::concretelang::Concrete::LweCiphertextType>();
// Get the size from the dimension
int64_t lweDimension = lweResultType.getDimension();
mlir::Value lweDimensionOp = rewriter.create<mlir::arith::ConstantOp>(
op.getLoc(), rewriter.getIndexAttr(lweDimension));
// Allocate a fresh new ciphertext
mlir::SmallVector<mlir::Value> allocOperands{lweDimensionOp};
auto allocGeneric = rewriter.create<mlir::CallOp>(
op.getLoc(), "allocate_lwe_ciphertext_u64",
getGenericLweCiphertextType(rewriter.getContext()), allocOperands);
// Cast the result to the appropriate type
rewriter.replaceOpWithNewOp<mlir::UnrealizedConversionCastOp>(
op, op.getType(), allocGeneric.getResult(0));
return mlir::success();
};
};
struct ConcreteEncodeIntOpPattern
: public mlir::OpRewritePattern<mlir::concretelang::Concrete::EncodeIntOp> {
ConcreteEncodeIntOpPattern(mlir::MLIRContext *context,
mlir::PatternBenefit benefit =
mlir::concretelang::DEFAULT_PATTERN_BENEFIT)
: mlir::OpRewritePattern<mlir::concretelang::Concrete::EncodeIntOp>(
context, benefit) {}
mlir::LogicalResult
matchAndRewrite(mlir::concretelang::Concrete::EncodeIntOp op,
mlir::PatternRewriter &rewriter) const override {
{
mlir::Value castedInt = rewriter.create<mlir::arith::ExtUIOp>(
op.getLoc(), rewriter.getIntegerType(64), op->getOperands().front());
mlir::Value constantShiftOp = rewriter.create<mlir::arith::ConstantOp>(
op.getLoc(), rewriter.getI64IntegerAttr(64 - op.getType().getP()));
mlir::Type resultType = rewriter.getIntegerType(64);
rewriter.replaceOpWithNewOp<mlir::arith::ShLIOp>(
op, resultType, castedInt, constantShiftOp);
}
return mlir::success();
};
};
struct ConcreteIntToCleartextOpPattern
: public mlir::OpRewritePattern<
mlir::concretelang::Concrete::IntToCleartextOp> {
ConcreteIntToCleartextOpPattern(
mlir::MLIRContext *context,
mlir::PatternBenefit benefit =
mlir::concretelang::DEFAULT_PATTERN_BENEFIT)
: mlir::OpRewritePattern<mlir::concretelang::Concrete::IntToCleartextOp>(
context, benefit) {}
mlir::LogicalResult
matchAndRewrite(mlir::concretelang::Concrete::IntToCleartextOp op,
mlir::PatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<mlir::arith::ExtUIOp>(
op, rewriter.getIntegerType(64), op->getOperands().front());
return mlir::success();
};
};
// Rewrite the GlweFromTable operation to a series of ops:
// - allocation of two GLWE, one for the addition, and one for storing the
// result
// - allocation of plaintext_list to build the GLWE accumulator
// - build the foreign_plaintext_list using the input table
// - fill the plaintext_list with the foreign_plaintext_list
// - construct the GLWE accumulator by adding the plaintext_list to a freshly
// allocated GLWE
struct GlweFromTableOpPattern
: public mlir::OpRewritePattern<
mlir::concretelang::Concrete::GlweFromTable> {
GlweFromTableOpPattern(mlir::MLIRContext *context,
mlir::PatternBenefit benefit =
mlir::concretelang::DEFAULT_PATTERN_BENEFIT)
: mlir::OpRewritePattern<mlir::concretelang::Concrete::GlweFromTable>(
context, benefit) {}
mlir::LogicalResult
matchAndRewrite(mlir::concretelang::Concrete::GlweFromTable op,
mlir::PatternRewriter &rewriter) const override {
ConcreteToConcreteCAPITypeConverter typeConverter;
// TODO: move this to insertForwardDeclarations
// issue: can't define function with tensor<*xtype> that accept ranked
// tensors
// Insert forward declaration of the foreign_pt_list function
{
auto funcType = mlir::FunctionType::get(
rewriter.getContext(),
{op->getOperandTypes().front(), rewriter.getI32Type()},
{getGenericForeignPlaintextListType(rewriter.getContext())});
if (insertForwardDeclaration(op, rewriter,
"memref_runtime_foreign_plaintext_list_u64",
funcType)
.failed()) {
return mlir::failure();
}
}
// allocate two glwe to build accumulator
auto polySizeOp = rewriter.create<mlir::arith::ConstantOp>(
op.getLoc(), op->getAttr("polynomialSize"));
auto glweDimensionOp = rewriter.create<mlir::arith::ConstantOp>(
op.getLoc(), op->getAttr("glweDimension"));
mlir::SmallVector<mlir::Value> allocGlweOperands{glweDimensionOp,
polySizeOp};
// first accumulator would replace the op since it's the returned value
auto accumulatorOp = rewriter.replaceOpWithNewOp<mlir::CallOp>(
op, "allocate_glwe_ciphertext_u64",
getGenericGlweCiphertextType(rewriter.getContext()), allocGlweOperands);
// second accumulator is just needed to build the actual accumulator
auto _accumulatorOp = rewriter.create<mlir::CallOp>(
op.getLoc(), "allocate_glwe_ciphertext_u64",
getGenericGlweCiphertextType(rewriter.getContext()), allocGlweOperands);
// allocate plaintext list
mlir::SmallVector<mlir::Value> allocPlaintextListOperands{polySizeOp};
auto plaintextListOp = rewriter.create<mlir::CallOp>(
op.getLoc(), "allocate_plaintext_list_u64",
getGenericPlaintextListType(rewriter.getContext()),
allocPlaintextListOperands);
// create foreign plaintext
auto rankedTensorType =
op->getOperandTypes().front().cast<mlir::RankedTensorType>();
assert(rankedTensorType.getRank() == 1 &&
"table lookup must be of a single dimension");
auto precisionOp =
rewriter.create<mlir::arith::ConstantOp>(op.getLoc(), op->getAttr("p"));
mlir::SmallVector<mlir::Value> ForeignPlaintextListOperands{
op->getOperand(0), precisionOp};
auto foreignPlaintextListOp = rewriter.create<mlir::CallOp>(
op.getLoc(), "memref_runtime_foreign_plaintext_list_u64",
getGenericForeignPlaintextListType(rewriter.getContext()),
ForeignPlaintextListOperands);
// fill plaintext list
mlir::SmallVector<mlir::Value> FillPlaintextListOperands{
plaintextListOp.getResult(0), foreignPlaintextListOp.getResult(0)};
rewriter.create<mlir::CallOp>(
op.getLoc(), "fill_plaintext_list_with_expansion_u64",
mlir::TypeRange({}), FillPlaintextListOperands);
// add plaintext list and glwe to build final accumulator for pbs
mlir::SmallVector<mlir::Value> AddPlaintextListGlweOperands{
accumulatorOp.getResult(0), _accumulatorOp.getResult(0),
plaintextListOp.getResult(0)};
rewriter.create<mlir::CallOp>(
op.getLoc(), "add_plaintext_list_glwe_ciphertext_u64",
mlir::TypeRange({}), AddPlaintextListGlweOperands);
return mlir::success();
};
};
mlir::Value getContextArgument(mlir::Operation *op) {
mlir::Block *block = op->getBlock();
while (block != nullptr) {
if (llvm::isa<mlir::FuncOp>(block->getParentOp())) {
mlir::Value context = block->getArguments().back();
assert(
context.getType().isa<mlir::concretelang::Concrete::ContextType>() &&
"the Concrete.context should be the last argument of the enclosing "
"function of the op");
return context;
}
block = block->getParentOp()->getBlock();
}
assert("can't find a function that enclose the op");
return nullptr;
}
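// For example, once AddRuntimeContextToFuncOpPattern (defined below) has run,
// an enclosing function looks roughly like (illustrative sketch):
//
//   func @main(%arg0: !Concrete.lwe_ciphertext<1024,4>,
//              %ctx: !Concrete.context) -> !Concrete.lwe_ciphertext<1024,4>
//
// and getContextArgument returns %ctx for any op nested inside its body.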
// Rewrite a BootstrapLweOp with a series of ops:
// - allocate the result LWE ciphertext
// - get the global bootstrapping key
// - use the key and the input accumulator (GLWE) to bootstrap the input
// ciphertext
struct ConcreteBootstrapLweOpPattern
: public mlir::OpRewritePattern<
mlir::concretelang::Concrete::BootstrapLweOp> {
ConcreteBootstrapLweOpPattern(mlir::MLIRContext *context,
mlir::PatternBenefit benefit =
mlir::concretelang::DEFAULT_PATTERN_BENEFIT)
: mlir::OpRewritePattern<mlir::concretelang::Concrete::BootstrapLweOp>(
context, benefit) {}
mlir::LogicalResult
matchAndRewrite(mlir::concretelang::Concrete::BootstrapLweOp op,
mlir::PatternRewriter &rewriter) const override {
auto resultType = op->getResultTypes().front();
// Get the size from the dimension
int64_t outputLweDimension =
resultType.cast<mlir::concretelang::Concrete::LweCiphertextType>()
.getDimension();
mlir::Value lweSizeOp = rewriter.create<mlir::arith::ConstantOp>(
op.getLoc(), rewriter.getIndexAttr(outputLweDimension));
    // allocate the result LWE ciphertext with the generic type; it is cast
    // back to the parametrized type before returning
mlir::SmallVector<mlir::Value> allocLweCtOperands{lweSizeOp};
auto allocateGenericLweCtOp = rewriter.create<mlir::CallOp>(
op.getLoc(), "allocate_lwe_ciphertext_u64",
getGenericLweCiphertextType(rewriter.getContext()), allocLweCtOperands);
// get bsk
auto getBskOp = rewriter.create<mlir::CallOp>(
op.getLoc(), "get_bootstrap_key",
getGenericLweBootstrapKeyType(rewriter.getContext()),
mlir::SmallVector<mlir::Value>{getContextArgument(op)});
// bootstrap
// cast input ciphertext to a generic type
mlir::Value lweToBootstrap =
rewriter
.create<mlir::UnrealizedConversionCastOp>(
op.getLoc(), getGenericType(op.getOperand(0).getType()),
op.getOperand(0))
.getResult(0);
// cast input accumulator to a generic type
mlir::Value accumulator =
rewriter
.create<mlir::UnrealizedConversionCastOp>(
op.getLoc(), getGenericType(op.getOperand(1).getType()),
op.getOperand(1))
.getResult(0);
mlir::SmallVector<mlir::Value> bootstrapOperands{
getBskOp.getResult(0), allocateGenericLweCtOp.getResult(0),
lweToBootstrap, accumulator};
rewriter.create<mlir::CallOp>(op.getLoc(), "bootstrap_lwe_u64",
mlir::TypeRange({}), bootstrapOperands);
// Cast result to the appropriate type
rewriter.replaceOpWithNewOp<mlir::UnrealizedConversionCastOp>(
op, resultType, allocateGenericLweCtOp.getResult(0));
return mlir::success();
};
};
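// As a rough, hand-written sketch (names, sizes and generic types are
// illustrative, call type annotations omitted), the pattern above turns
//
//   %r = "Concrete.bootstrap_lwe"(%ct, %acc) {...}
//
// into
//
//   %size = arith.constant 1024 : index
//   %out = call @allocate_lwe_ciphertext_u64(%size)
//   %bsk = call @get_bootstrap_key(%ctx)
//   call @bootstrap_lwe_u64(%bsk, %out, %ct_generic, %acc_generic)
//
// where %ct_generic and %acc_generic are unrealized_conversion_casts of the
// operands to their generic types, and %out is cast back to the parametrized
// result type.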
// Rewrite a KeySwitchLweOp with a series of ops:
// - allocate the result LWE ciphertext
// - get the global keyswitch key
// - use the key to keyswitch the input ciphertext
struct ConcreteKeySwitchLweOpPattern
: public mlir::OpRewritePattern<
mlir::concretelang::Concrete::KeySwitchLweOp> {
ConcreteKeySwitchLweOpPattern(mlir::MLIRContext *context,
mlir::PatternBenefit benefit =
mlir::concretelang::DEFAULT_PATTERN_BENEFIT)
: mlir::OpRewritePattern<mlir::concretelang::Concrete::KeySwitchLweOp>(
context, benefit) {}
mlir::LogicalResult
matchAndRewrite(mlir::concretelang::Concrete::KeySwitchLweOp op,
mlir::PatternRewriter &rewriter) const override {
// Get the size from the dimension
int64_t lweDimension =
op.getResult()
.getType()
.cast<mlir::concretelang::Concrete::LweCiphertextType>()
.getDimension();
mlir::Value lweDimensionOp = rewriter.create<mlir::arith::ConstantOp>(
op.getLoc(), rewriter.getIndexAttr(lweDimension));
    // allocate the result LWE ciphertext with the generic type; it is cast
    // back to the parametrized type before returning
mlir::SmallVector<mlir::Value> allocLweCtOperands{lweDimensionOp};
auto allocateGenericLweCtOp = rewriter.create<mlir::CallOp>(
op.getLoc(), "allocate_lwe_ciphertext_u64",
getGenericLweCiphertextType(rewriter.getContext()), allocLweCtOperands);
// get ksk
auto getKskOp = rewriter.create<mlir::CallOp>(
op.getLoc(), "get_keyswitch_key",
getGenericLweKeySwitchKeyType(rewriter.getContext()),
mlir::SmallVector<mlir::Value>{getContextArgument(op)});
// keyswitch
// cast input ciphertext to a generic type
mlir::Value lweToKeyswitch =
rewriter
.create<mlir::UnrealizedConversionCastOp>(
op.getLoc(), getGenericType(op.getOperand().getType()),
op.getOperand())
.getResult(0);
mlir::SmallVector<mlir::Value> keyswitchOperands{
getKskOp.getResult(0), allocateGenericLweCtOp.getResult(0),
lweToKeyswitch};
rewriter.create<mlir::CallOp>(op.getLoc(), "keyswitch_lwe_u64",
mlir::TypeRange({}), keyswitchOperands);
// Cast result to the appropriate type
auto lweOutputType = op->getResultTypes().front();
rewriter.replaceOpWithNewOp<mlir::UnrealizedConversionCastOp>(
op, lweOutputType, allocateGenericLweCtOp.getResult(0));
return mlir::success();
};
};
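// The keyswitch lowering follows the same shape (illustrative sketch, type
// annotations omitted):
//
//   %size = arith.constant 600 : index
//   %out = call @allocate_lwe_ciphertext_u64(%size)
//   %ksk = call @get_keyswitch_key(%ctx)
//   call @keyswitch_lwe_u64(%ksk, %out, %ct_generic)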
/// Populate the RewritePatternSet with all patterns that rewrite Concrete
/// operators to the corresponding function call to the `Concrete C API`.
void populateConcreteToConcreteCAPICall(mlir::RewritePatternSet &patterns) {
patterns.add<ConcreteOpToConcreteCAPICallPattern<
mlir::concretelang::Concrete::AddLweCiphertextsOp>>(
patterns.getContext(), "add_lwe_ciphertexts_u64",
"allocate_lwe_ciphertext_u64");
patterns.add<ConcreteOpToConcreteCAPICallPattern<
mlir::concretelang::Concrete::AddPlaintextLweCiphertextOp>>(
patterns.getContext(), "add_plaintext_lwe_ciphertext_u64",
"allocate_lwe_ciphertext_u64");
patterns.add<ConcreteOpToConcreteCAPICallPattern<
mlir::concretelang::Concrete::MulCleartextLweCiphertextOp>>(
patterns.getContext(), "mul_cleartext_lwe_ciphertext_u64",
"allocate_lwe_ciphertext_u64");
patterns.add<ConcreteOpToConcreteCAPICallPattern<
mlir::concretelang::Concrete::NegateLweCiphertextOp>>(
patterns.getContext(), "negate_lwe_ciphertext_u64",
"allocate_lwe_ciphertext_u64");
patterns.add<ConcreteEncodeIntOpPattern>(patterns.getContext());
patterns.add<ConcreteIntToCleartextOpPattern>(patterns.getContext());
patterns.add<ConcreteZeroOpPattern>(patterns.getContext());
patterns.add<GlweFromTableOpPattern>(patterns.getContext());
patterns.add<ConcreteKeySwitchLweOpPattern>(patterns.getContext());
patterns.add<ConcreteBootstrapLweOpPattern>(patterns.getContext());
}
struct AddRuntimeContextToFuncOpPattern
: public mlir::OpRewritePattern<mlir::FuncOp> {
AddRuntimeContextToFuncOpPattern(
mlir::MLIRContext *context,
mlir::PatternBenefit benefit =
mlir::concretelang::DEFAULT_PATTERN_BENEFIT)
: mlir::OpRewritePattern<mlir::FuncOp>(context, benefit) {}
mlir::LogicalResult
matchAndRewrite(mlir::FuncOp oldFuncOp,
mlir::PatternRewriter &rewriter) const override {
mlir::OpBuilder::InsertionGuard guard(rewriter);
mlir::FunctionType oldFuncType = oldFuncOp.getType();
// Add a Concrete.context to the function signature
mlir::SmallVector<mlir::Type> newInputs(oldFuncType.getInputs().begin(),
oldFuncType.getInputs().end());
newInputs.push_back(
rewriter.getType<mlir::concretelang::Concrete::ContextType>());
mlir::FunctionType newFuncTy = rewriter.getType<mlir::FunctionType>(
newInputs, oldFuncType.getResults());
// Create the new func
mlir::FuncOp newFuncOp = rewriter.create<mlir::FuncOp>(
oldFuncOp.getLoc(), oldFuncOp.getName(), newFuncTy);
// Create the arguments of the new func
mlir::Region &newFuncBody = newFuncOp.body();
mlir::Block *newFuncEntryBlock = new mlir::Block();
newFuncEntryBlock->addArguments(newFuncTy.getInputs());
newFuncBody.push_back(newFuncEntryBlock);
// Clone the old body to the new one
mlir::BlockAndValueMapping map;
for (auto arg : llvm::enumerate(oldFuncOp.getArguments())) {
map.map(arg.value(), newFuncEntryBlock->getArgument(arg.index()));
}
for (auto &op : oldFuncOp.body().front()) {
newFuncEntryBlock->push_back(op.clone(map));
}
rewriter.eraseOp(oldFuncOp);
return mlir::success();
}
  // Legal functions are those that are private or take a Concrete.context as
  // their last argument.
static bool isLegal(mlir::FuncOp funcOp) {
if (!funcOp.isPublic()) {
return true;
}
    // TODO: No need to add a runtime context to functions that don't
    // manipulate Concrete types.
//
// if (!llvm::any_of(funcOp.getType().getInputs(), [](mlir::Type t) {
// if (auto tensorTy = t.dyn_cast_or_null<mlir::TensorType>()) {
// t = tensorTy.getElementType();
// }
// return llvm::isa<mlir::concretelang::Concrete::ConcreteDialect>(
// t.getDialect());
// })) {
// return true;
// }
return funcOp.getType().getNumInputs() >= 1 &&
funcOp.getType()
.getInputs()
.back()
.isa<mlir::concretelang::Concrete::ContextType>();
}
};
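// For illustration, this pattern rewrites a public entry point such as
//
//   func @main(%arg0: !Concrete.lwe_ciphertext<1024,4>)
//       -> !Concrete.lwe_ciphertext<1024,4>
//
// into
//
//   func @main(%arg0: !Concrete.lwe_ciphertext<1024,4>,
//              %ctx: !Concrete.context) -> !Concrete.lwe_ciphertext<1024,4>
//
// (sketch only; the body is cloned unchanged, the signature merely gains the
// trailing context argument).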
namespace {
struct ConcreteToConcreteCAPIPass
: public ConcreteToConcreteCAPIBase<ConcreteToConcreteCAPIPass> {
void runOnOperation() final;
};
} // namespace
void ConcreteToConcreteCAPIPass::runOnOperation() {
mlir::ModuleOp op = getOperation();
  // First, add the Concrete.context to the block arguments of functions that
  // manipulate ciphertexts.
{
mlir::ConversionTarget target(getContext());
mlir::RewritePatternSet patterns(&getContext());
target.addDynamicallyLegalOp<mlir::FuncOp>([&](mlir::FuncOp funcOp) {
return AddRuntimeContextToFuncOpPattern::isLegal(funcOp);
});
patterns.add<AddRuntimeContextToFuncOpPattern>(patterns.getContext());
// Apply the conversion
if (mlir::applyPartialConversion(op, target, std::move(patterns))
.failed()) {
this->signalPassFailure();
return;
}
}
  // Insert forward declarations
mlir::IRRewriter rewriter(&getContext());
if (insertForwardDeclarations(op, rewriter).failed()) {
this->signalPassFailure();
}
// Rewrite Concrete ops to CallOp to the Concrete C API
{
mlir::ConversionTarget target(getContext());
mlir::RewritePatternSet patterns(&getContext());
target.addIllegalDialect<mlir::concretelang::Concrete::ConcreteDialect>();
target.addLegalDialect<mlir::BuiltinDialect, mlir::StandardOpsDialect,
mlir::memref::MemRefDialect,
mlir::arith::ArithmeticDialect>();
populateConcreteToConcreteCAPICall(patterns);
if (mlir::applyPartialConversion(op, target, std::move(patterns))
.failed()) {
this->signalPassFailure();
}
}
}
namespace mlir {
namespace concretelang {
std::unique_ptr<OperationPass<ModuleOp>>
createConvertConcreteToConcreteCAPIPass() {
return std::make_unique<ConcreteToConcreteCAPIPass>();
}
} // namespace concretelang
} // namespace mlir

View File

@@ -1,16 +0,0 @@
add_mlir_dialect_library(ConcreteUnparametrize
ConcreteUnparametrize.cpp
ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/concretelang/Dialect/FHE
DEPENDS
ConcreteDialect
ConcretelangConversionPassIncGen
LINK_LIBS PUBLIC
MLIRIR
MLIRTransforms
)
target_link_libraries(ConcreteUnparametrize PUBLIC MLIRIR)

View File

@@ -1,154 +0,0 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt
// for license information.
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "concretelang/Conversion/Passes.h"
#include "concretelang/Conversion/Utils/RegionOpTypeConverterPattern.h"
#include "concretelang/Conversion/Utils/TensorOpTypeConversion.h"
#include "concretelang/Dialect/Concrete/IR/ConcreteDialect.h"
#include "concretelang/Dialect/Concrete/IR/ConcreteOps.h"
#include "concretelang/Dialect/Concrete/IR/ConcreteTypes.h"
#include "concretelang/Dialect/RT/IR/RTOps.h"
#include "concretelang/Support/Constants.h"
/// ConcreteUnparametrizeTypeConverter is a type converter that unparametrizes
/// Concrete types
class ConcreteUnparametrizeTypeConverter : public mlir::TypeConverter {
public:
static mlir::Type unparematrizeConcreteType(mlir::Type type) {
if (type.isa<mlir::concretelang::Concrete::PlaintextType>()) {
return mlir::IntegerType::get(type.getContext(), 64);
}
if (type.isa<mlir::concretelang::Concrete::CleartextType>()) {
return mlir::IntegerType::get(type.getContext(), 64);
}
if (type.isa<mlir::concretelang::Concrete::LweCiphertextType>()) {
return mlir::concretelang::Concrete::LweCiphertextType::get(
type.getContext(), -1, -1);
}
auto tensorType = type.dyn_cast_or_null<mlir::RankedTensorType>();
if (tensorType != nullptr) {
auto eltTy0 = tensorType.getElementType();
auto eltTy1 = unparematrizeConcreteType(eltTy0);
if (eltTy0 == eltTy1) {
return type;
}
return mlir::RankedTensorType::get(tensorType.getShape(), eltTy1);
}
return type;
}
ConcreteUnparametrizeTypeConverter() {
addConversion(
[](mlir::Type type) { return unparematrizeConcreteType(type); });
}
};
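// As a sketch of the conversion (the textual forms are illustrative; the
// unparametrized LWE type prints as `<_,_>` in the tests below):
//
//   !Concrete.plaintext<...>                   -> i64
//   !Concrete.cleartext<...>                   -> i64
//   !Concrete.lwe_ciphertext<1024,4>           -> !Concrete.lwe_ciphertext<_,_>
//   tensor<4x!Concrete.lwe_ciphertext<1024,4>> -> tensor<4x!Concrete.lwe_ciphertext<_,_>>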
/// Replace `%1 = unrealized_conversion_cast %0 : t0 to t1` with `%0` when t0
/// or t1 is a Concrete type.
struct ConcreteUnrealizedCastReplacementPattern
: public mlir::OpRewritePattern<mlir::UnrealizedConversionCastOp> {
ConcreteUnrealizedCastReplacementPattern(
mlir::MLIRContext *context,
mlir::PatternBenefit benefit =
mlir::concretelang::DEFAULT_PATTERN_BENEFIT)
: mlir::OpRewritePattern<mlir::UnrealizedConversionCastOp>(context,
benefit) {}
mlir::LogicalResult
matchAndRewrite(mlir::UnrealizedConversionCastOp op,
mlir::PatternRewriter &rewriter) const override {
if (mlir::isa<mlir::concretelang::Concrete::ConcreteDialect>(
op.getOperandTypes()[0].getDialect()) ||
mlir::isa<mlir::concretelang::Concrete::ConcreteDialect>(
op.getType(0).getDialect())) {
rewriter.replaceOp(op, op.getOperands());
return mlir::success();
}
return mlir::failure();
};
};
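// E.g. (sketch):
//
//   %1 = builtin.unrealized_conversion_cast %0
//        : !Concrete.lwe_ciphertext<1024,4> to !Concrete.lwe_ciphertext<_,_>
//
// is erased and all uses of %1 are rewired to %0.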
/// ConcreteUnparametrizePass removes all parameters of Concrete types and
/// removes the unrealized_conversion_cast operations that operate on
/// parametrized Concrete types.
struct ConcreteUnparametrizePass
: public ConcreteUnparametrizeBase<ConcreteUnparametrizePass> {
void runOnOperation() final;
};
void ConcreteUnparametrizePass::runOnOperation() {
auto op = this->getOperation();
mlir::ConversionTarget target(getContext());
mlir::OwningRewritePatternList patterns(&getContext());
ConcreteUnparametrizeTypeConverter converter;
// Conversion of linalg.generic operation
target
.addDynamicallyLegalOp<mlir::linalg::GenericOp, mlir::tensor::GenerateOp>(
[&](mlir::Operation *op) {
return (
converter.isLegal(op->getOperandTypes()) &&
converter.isLegal(op->getResultTypes()) &&
converter.isLegal(op->getRegion(0).front().getArgumentTypes()));
});
patterns.add<RegionOpTypeConverterPattern<
mlir::linalg::GenericOp, ConcreteUnparametrizeTypeConverter>>(
&getContext(), converter);
patterns.add<RegionOpTypeConverterPattern<
mlir::tensor::GenerateOp, ConcreteUnparametrizeTypeConverter>>(
&getContext(), converter);
patterns.add<RegionOpTypeConverterPattern<
mlir::scf::ForOp, ConcreteUnparametrizeTypeConverter>>(&getContext(),
converter);
// Conversion of function signature and arguments
target.addDynamicallyLegalOp<mlir::FuncOp>([&](mlir::FuncOp funcOp) {
return converter.isSignatureLegal(funcOp.getType()) &&
converter.isLegal(&funcOp.getBody());
});
mlir::populateFuncOpTypeConversionPattern(patterns, converter);
// Replacement of unrealized_conversion_cast
mlir::concretelang::addDynamicallyLegalTypeOp<
mlir::UnrealizedConversionCastOp>(target, converter);
patterns.add<ConcreteUnrealizedCastReplacementPattern>(patterns.getContext());
// Conversion of tensor operators
mlir::concretelang::populateWithTensorTypeConverterPatterns(patterns, target,
converter);
// Conversion of CallOp
patterns.add<mlir::concretelang::GenericTypeConverterPattern<mlir::CallOp>>(
patterns.getContext(), converter);
mlir::concretelang::addDynamicallyLegalTypeOp<mlir::CallOp>(target,
converter);
// Conversion of RT Dialect Ops
patterns.add<mlir::concretelang::GenericTypeConverterPattern<
mlir::concretelang::RT::DataflowTaskOp>>(patterns.getContext(),
converter);
mlir::concretelang::addDynamicallyLegalTypeOp<
mlir::concretelang::RT::DataflowTaskOp>(target, converter);
// Apply conversion
if (mlir::applyPartialConversion(op, target, std::move(patterns)).failed()) {
this->signalPassFailure();
}
}
namespace mlir {
namespace concretelang {
std::unique_ptr<OperationPass<ModuleOp>>
createConvertConcreteUnparametrizePass() {
return std::make_unique<ConcreteUnparametrizePass>();
}
} // namespace concretelang
} // namespace mlir

View File

@@ -4,7 +4,7 @@ endif()
add_library(ConcretelangRuntime SHARED
context.cpp
wrappers.c
wrappers.cpp
)
if(CONCRETELANG_PARALLEL_EXECUTION_ENABLED)

View File

@@ -11,39 +11,32 @@
#include <hpx/include/runtime.hpp>
#endif
namespace mlir {
namespace concretelang {
std::string RuntimeContext::BASE_CONTEXT_BSK = "_concretelang_base_context_bsk";
} // namespace concretelang
} // namespace mlir
LweKeyswitchKey_u64 *
get_keyswitch_key(mlir::concretelang::RuntimeContext *context) {
get_keyswitch_key_u64(mlir::concretelang::RuntimeContext *context) {
return context->ksk;
}
LweBootstrapKey_u64 *
get_bootstrap_key(mlir::concretelang::RuntimeContext *context) {
using RuntimeContext = mlir::concretelang::RuntimeContext;
get_bootstrap_key_u64(mlir::concretelang::RuntimeContext *context) {
return context->bsk;
}
// Instantiate one engine per thread on demand
Engine *get_engine(mlir::concretelang::RuntimeContext *context) {
#ifdef CONCRETELANG_PARALLEL_EXECUTION_ENABLED
std::string threadName = hpx::get_thread_name();
auto bskIt = context->bsk.find(threadName);
if (bskIt == context->bsk.end()) {
assert((bskIt = context->bsk.find(RuntimeContext::BASE_CONTEXT_BSK)) !=
context->bsk.end() &&
bskIt->second && "No BASE_CONTEXT_BSK registered in context.");
bskIt = context->bsk
.insert(std::pair<std::string, LweBootstrapKey_u64 *>(
threadName,
clone_lwe_bootstrap_key_u64(
context->bsk[RuntimeContext::BASE_CONTEXT_BSK])))
.first;
std::lock_guard<std::mutex> guard(context->engines_map_guard);
auto engineIt = context->engines.find(threadName);
if (engineIt == context->engines.end()) {
engineIt =
context->engines
.insert(std::pair<std::string, Engine *>(threadName, new_engine()))
.first;
}
assert(engineIt->second && "No engine available in context");
return engineIt->second;
#else
auto bskIt = context->bsk.find(RuntimeContext::BASE_CONTEXT_BSK);
return (context->engine == nullptr) ? context->engine = new_engine()
: context->engine;
#endif
assert(bskIt->second && "No bootstrap key available in context");
return bskIt->second;
}

View File

@@ -1,21 +1,29 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt
// for license information.
#include "concretelang/Runtime/wrappers.h"
#include <assert.h>
#include <stdio.h>
struct ForeignPlaintextList_u64 *memref_runtime_foreign_plaintext_list_u64(
uint64_t *allocated, uint64_t *aligned, uint64_t offset, uint64_t size,
uint64_t stride, uint32_t precision) {
void memref_expand_lut_in_trivial_glwe_ct_u64(
uint64_t *glwe_ct_allocated, uint64_t *glwe_ct_aligned,
uint64_t glwe_ct_offset, uint64_t glwe_ct_size, uint64_t glwe_ct_stride,
uint32_t poly_size, uint32_t glwe_dimension, uint32_t out_precision,
uint64_t *lut_allocated, uint64_t *lut_aligned, uint64_t lut_offset,
uint64_t lut_size, uint64_t lut_stride) {
assert(stride == 1 && "Runtime: stride not equal to 1, check "
"runtime_foreign_plaintext_list_u64");
assert(lut_stride == 1 && "Runtime: stride not equal to 1, check "
"memref_expand_lut_in_trivial_glwe_ct_u64");
// Encode table values in u64
uint64_t *encoded_table = malloc(size * sizeof(uint64_t));
for (uint64_t i = 0; i < size; i++) {
encoded_table[i] = (aligned + offset)[i] << (64 - precision - 1);
}
return foreign_plaintext_list_u64(encoded_table, size);
// TODO: is it safe to free after creating plaintext_list?
assert(glwe_ct_stride == 1 && "Runtime: stride not equal to 1, check "
"memref_expand_lut_in_trivial_glwe_ct_u64");
expand_lut_in_trivial_glwe_ct_u64(glwe_ct_aligned, poly_size, glwe_dimension,
out_precision, lut_aligned, lut_size);
return;
}
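/* Worked example for the encoding in the removed helper above: with
 * precision p = 4, a table value v = 3 is encoded as v << (64 - p - 1)
 * = 3 << 59 = 0x1800000000000000, i.e. the 4-bit message field occupies
 * bits 59..62 of the u64 plaintext, keeping bit 63 as a padding bit.
 * (Illustrative; the new wrapper forwards out_precision to
 * expand_lut_in_trivial_glwe_ct_u64, which is assumed to perform the
 * equivalent encoding internally.) */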
void memref_add_lwe_ciphertexts_u64(
@@ -26,7 +34,7 @@ void memref_add_lwe_ciphertexts_u64(
uint64_t ct1_offset, uint64_t ct1_size, uint64_t ct1_stride) {
assert(out_size == ct0_size && out_size == ct1_size &&
"size of lwe buffer are incompatible");
LweDimension lwe_dimension = {out_size - 1};
size_t lwe_dimension = {out_size - 1};
add_two_lwe_ciphertexts_u64(out_aligned + out_offset,
ct0_aligned + ct0_offset,
ct1_aligned + ct1_offset, lwe_dimension);
@@ -38,7 +46,7 @@ void memref_add_plaintext_lwe_ciphertext_u64(
uint64_t *ct0_aligned, uint64_t ct0_offset, uint64_t ct0_size,
uint64_t ct0_stride, uint64_t plaintext) {
assert(out_size == ct0_size && "size of lwe buffer are incompatible");
LweDimension lwe_dimension = {out_size - 1};
size_t lwe_dimension = {out_size - 1};
add_plaintext_to_lwe_ciphertext_u64(out_aligned + out_offset,
ct0_aligned + ct0_offset, plaintext,
lwe_dimension);
@@ -50,7 +58,7 @@ void memref_mul_cleartext_lwe_ciphertext_u64(
uint64_t *ct0_aligned, uint64_t ct0_offset, uint64_t ct0_size,
uint64_t ct0_stride, uint64_t cleartext) {
assert(out_size == ct0_size && "size of lwe buffer are incompatible");
LweDimension lwe_dimension = {out_size - 1};
size_t lwe_dimension = {out_size - 1};
mul_cleartext_lwe_ciphertext_u64(out_aligned + out_offset,
ct0_aligned + ct0_offset, cleartext,
lwe_dimension);
@@ -62,28 +70,29 @@ void memref_negate_lwe_ciphertext_u64(
uint64_t *ct0_aligned, uint64_t ct0_offset, uint64_t ct0_size,
uint64_t ct0_stride) {
assert(out_size == ct0_size && "size of lwe buffer are incompatible");
LweDimension lwe_dimension = {out_size - 1};
size_t lwe_dimension = {out_size - 1};
neg_lwe_ciphertext_u64(out_aligned + out_offset, ct0_aligned + ct0_offset,
lwe_dimension);
}
void memref_keyswitch_lwe_u64(struct LweKeyswitchKey_u64 *keyswitch_key,
uint64_t *out_allocated, uint64_t *out_aligned,
uint64_t out_offset, uint64_t out_size,
uint64_t out_stride, uint64_t *ct0_allocated,
uint64_t *ct0_aligned, uint64_t ct0_offset,
uint64_t ct0_size, uint64_t ct0_stride) {
bufferized_keyswitch_lwe_u64(keyswitch_key, out_aligned + out_offset,
ct0_aligned + ct0_offset);
}
void memref_bootstrap_lwe_u64(struct LweBootstrapKey_u64 *bootstrap_key,
uint64_t *out_allocated, uint64_t *out_aligned,
void memref_keyswitch_lwe_u64(uint64_t *out_allocated, uint64_t *out_aligned,
uint64_t out_offset, uint64_t out_size,
uint64_t out_stride, uint64_t *ct0_allocated,
uint64_t *ct0_aligned, uint64_t ct0_offset,
uint64_t ct0_size, uint64_t ct0_stride,
struct GlweCiphertext_u64 *accumulator) {
bufferized_bootstrap_lwe_u64(bootstrap_key, out_aligned + out_offset,
ct0_aligned + ct0_offset, accumulator);
mlir::concretelang::RuntimeContext *context) {
keyswitch_lwe_u64(get_engine(context), get_keyswitch_key_u64(context),
out_aligned + out_offset, ct0_aligned + ct0_offset);
}
void memref_bootstrap_lwe_u64(
uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
uint64_t out_size, uint64_t out_stride, uint64_t *ct0_allocated,
uint64_t *ct0_aligned, uint64_t ct0_offset, uint64_t ct0_size,
uint64_t ct0_stride, uint64_t *glwe_ct_allocated, uint64_t *glwe_ct_aligned,
uint64_t glwe_ct_offset, uint64_t glwe_ct_size, uint64_t glwe_ct_stride,
mlir::concretelang::RuntimeContext *context) {
bootstrap_lwe_u64(get_engine(context), get_bootstrap_key_u64(context),
out_aligned + out_offset, ct0_aligned + ct0_offset,
glwe_ct_aligned + glwe_ct_offset);
}

View File

@@ -22,7 +22,6 @@ add_mlir_library(ConcretelangSupport
FHELinalgDialectTransforms
FHETensorOpsToLinalg
FHEToTFHE
ConcreteUnparametrize
MLIRLowerableDialectsToLLVM
FHEDialectAnalysis
RTDialectAnalysis

View File

@@ -205,9 +205,6 @@ lowerBConcreteToStd(mlir::MLIRContext &context, mlir::ModuleOp &module,
addPotentiallyNestedPass(
pm, mlir::concretelang::createConvertBConcreteToBConcreteCAPIPass(),
enablePass);
addPotentiallyNestedPass(
pm, mlir::concretelang::createConvertConcreteToConcreteCAPIPass(),
enablePass);
return pm.run(module.getOperation());
}
@@ -218,11 +215,6 @@ lowerStdToLLVMDialect(mlir::MLIRContext &context, mlir::ModuleOp &module,
mlir::PassManager pm(&context);
pipelinePrinting("StdToLLVM", pm, context);
// Unparametrize Concrete
addPotentiallyNestedPass(
pm, mlir::concretelang::createConvertConcreteUnparametrizePass(),
enablePass);
// Bufferize
addPotentiallyNestedPass(pm, mlir::createTensorConstantBufferizePass(),
enablePass);

View File

@@ -1,15 +1,15 @@
// RUN: concretecompiler --passes bconcrete-to-bconcrete-c-api --action=dump-std %s 2>&1| FileCheck %s
// CHECK: func @bootstrap_lwe(%arg0: tensor<1025xi64>, %arg1: !Concrete.glwe_ciphertext, %arg2: !Concrete.context) -> tensor<1025xi64> {
// CHECK-NEXT: %0 = linalg.init_tensor [1025] : tensor<1025xi64>
// CHECK-NEXT: %1 = call @get_bootstrap_key(%arg2) : (!Concrete.context) -> !Concrete.lwe_bootstrap_key
// CHECK-NEXT: %2 = tensor.cast %0 : tensor<1025xi64> to tensor<?xi64>
// CHECK-NEXT: %3 = tensor.cast %arg0 : tensor<1025xi64> to tensor<?xi64>
// CHECK-NEXT: call @memref_bootstrap_lwe_u64(%1, %2, %3, %arg1) : (!Concrete.lwe_bootstrap_key, tensor<?xi64>, tensor<?xi64>, !Concrete.glwe_ciphertext) -> ()
// CHECK-NEXT: return %0 : tensor<1025xi64>
// CHECK: func @apply_lookup_table(%arg0: tensor<601xi64>, %arg1: tensor<2048xi64>, %arg2: !Concrete.context) -> tensor<1025xi64> {
// CHECK-NEXT: %0 = linalg.init_tensor [1025] : tensor<1025xi64>
// CHECK-NEXT: %1 = tensor.cast %0 : tensor<1025xi64> to tensor<?xi64>
// CHECK-NEXT: %2 = tensor.cast %arg0 : tensor<601xi64> to tensor<?xi64>
// CHECK-NEXT: %3 = tensor.cast %arg1 : tensor<2048xi64> to tensor<?xi64>
// CHECK-NEXT: call @memref_bootstrap_lwe_u64(%1, %2, %3, %arg2) : (tensor<?xi64>, tensor<?xi64>, tensor<?xi64>, !Concrete.context) -> ()
// CHECK-NEXT: return %0 : tensor<1025xi64>
// CHECK-NEXT: }
func @bootstrap_lwe(%arg0: tensor<1025xi64>, %arg1: !Concrete.glwe_ciphertext) -> tensor<1025xi64> {
%0 = linalg.init_tensor [1025] : tensor<1025xi64>
"BConcrete.bootstrap_lwe_buffer"(%0, %arg0, %arg1) {baseLog = 2 : i32, glweDimension = 1 : i32, level = 3 : i32, polynomialSize = 1024 : i32} : (tensor<1025xi64>, tensor<1025xi64>, !Concrete.glwe_ciphertext) -> ()
return %0 : tensor<1025xi64>
}
func @apply_lookup_table(%arg0: tensor<601xi64>, %arg1: tensor<2048xi64>) -> tensor<1025xi64> {
%0 = linalg.init_tensor [1025] : tensor<1025xi64>
"BConcrete.bootstrap_lwe_buffer"(%0, %arg0, %arg1) {baseLog = 4 : i32, glweDimension = 1 : i32, level = 5 : i32, polynomialSize = 1024 : i32} : (tensor<1025xi64>, tensor<601xi64>, tensor<2048xi64>) -> ()
return %0 : tensor<1025xi64>
}

View File

@@ -2,10 +2,9 @@
//CHECK: func @keyswitch_lwe(%arg0: tensor<1025xi64>, %arg1: !Concrete.context) -> tensor<1025xi64> {
//CHECK-NEXT: %0 = linalg.init_tensor [1025] : tensor<1025xi64>
//CHECK-NEXT: %1 = call @get_keyswitch_key(%arg1) : (!Concrete.context) -> !Concrete.lwe_key_switch_key
//CHECK-NEXT: %2 = tensor.cast %0 : tensor<1025xi64> to tensor<?xi64>
//CHECK-NEXT: %3 = tensor.cast %arg0 : tensor<1025xi64> to tensor<?xi64>
//CHECK-NEXT: call @memref_keyswitch_lwe_u64(%1, %2, %3) : (!Concrete.lwe_key_switch_key, tensor<?xi64>, tensor<?xi64>) -> ()
//CHECK-NEXT: %1 = tensor.cast %0 : tensor<1025xi64> to tensor<?xi64>
//CHECK-NEXT: %2 = tensor.cast %arg0 : tensor<1025xi64> to tensor<?xi64>
//CHECK-NEXT: call @memref_keyswitch_lwe_u64(%1, %2, %arg1) : (tensor<?xi64>, tensor<?xi64>, !Concrete.context) -> ()
//CHECK-NEXT: return %0 : tensor<1025xi64>
//CHECK-NEXT: }
func @keyswitch_lwe(%arg0: tensor<1025xi64>) -> tensor<1025xi64> {

View File

@@ -2,14 +2,15 @@
// CHECK-LABEL: func @apply_lookup_table(%arg0: tensor<1025xi64>, %arg1: tensor<16xi64>) -> tensor<1025xi64>
func @apply_lookup_table(%arg0: !Concrete.lwe_ciphertext<1024,4>, %arg1: tensor<16xi64>) -> !Concrete.lwe_ciphertext<1024,4> {
// CHECK-NEXT: %[[V1:.*]] = "Concrete.glwe_from_table"(%arg1) {glweDimension = 1 : i32, p = 4 : i32, polynomialSize = 1024 : i32} : (tensor<16xi64>) -> !Concrete.glwe_ciphertext
// CHECK-NEXT: %[[V1:.*]] = linalg.init_tensor [2048] : tensor<2048xi64>
// CHECK-NEXT:"BConcrete.fill_glwe_from_table"(%[[V1]], %arg1) {glweDimension = 1 : i32, outPrecision = 4 : i32, polynomialSize = 1024 : i32} : (tensor<2048xi64>, tensor<16xi64>) -> ()
// CHECK-NEXT: %[[V2:.*]] = linalg.init_tensor [601] : tensor<601xi64>
// CHECK-NEXT: "BConcrete.keyswitch_lwe_buffer"(%[[V2]], %arg0) {baseLog = 2 : i32, inputLweDimension = 1 : i32, level = 3 : i32, outputLweDimension = 600 : i32} : (tensor<601xi64>, tensor<1025xi64>) -> ()
// CHECK-NEXT: %[[V3:.*]] = linalg.init_tensor [1025] : tensor<1025xi64>
// CHECK-NEXT: "BConcrete.bootstrap_lwe_buffer"(%[[V3]], %[[V2]], %[[V1]]) {baseLog = 4 : i32, glweDimension = 1 : i32, level = 5 : i32, polynomialSize = 1024 : i32} : (tensor<1025xi64>, tensor<601xi64>, !Concrete.glwe_ciphertext) -> ()
// CHECK-NEXT: "BConcrete.bootstrap_lwe_buffer"(%[[V3]], %[[V2]], %[[V1]]) {baseLog = 4 : i32, glweDimension = 1 : i32, level = 5 : i32, polynomialSize = 1024 : i32} : (tensor<1025xi64>, tensor<601xi64>, tensor<2048xi64>) -> ()
// CHECK-NEXT: return %[[V3]] : tensor<1025xi64>
%0 = "Concrete.glwe_from_table"(%arg1) {glweDimension = 1 : i32, p = 4 : i32, polynomialSize = 1024 : i32} : (tensor<16xi64>) -> !Concrete.glwe_ciphertext
%0 = "Concrete.glwe_from_table"(%arg1) {glweDimension = 1 : i32, p = 4 : i32, polynomialSize = 1024 : i32} : (tensor<16xi64>) -> !Concrete.glwe_ciphertext<1024,1,4>
%1 = "Concrete.keyswitch_lwe"(%arg0) {baseLog = 2 : i32, inputLweDimension = 1 : i32, level = 3 : i32, outputLweDimension = 600 : i32} : (!Concrete.lwe_ciphertext<1024,4>) -> !Concrete.lwe_ciphertext<600,4>
%2 = "Concrete.bootstrap_lwe"(%1, %0) {baseLog = 4 : i32, glweDimension = 1 : i32, level = 5 : i32, polynomialSize = 1024 : i32} : (!Concrete.lwe_ciphertext<600,4>, !Concrete.glwe_ciphertext) -> !Concrete.lwe_ciphertext<1024,4>
%2 = "Concrete.bootstrap_lwe"(%1, %0) {baseLog = 4 : i32, glweDimension = 1 : i32, level = 5 : i32, polynomialSize = 1024 : i32} : (!Concrete.lwe_ciphertext<600,4>, !Concrete.glwe_ciphertext<1024,1,4>) -> !Concrete.lwe_ciphertext<1024,4>
return %2 : !Concrete.lwe_ciphertext<1024,4>
}

View File

@@ -3,15 +3,16 @@
// CHECK-LABEL: func @apply_lookup_table_cst(%arg0: tensor<2049xi64>) -> tensor<2049xi64>
func @apply_lookup_table_cst(%arg0: !Concrete.lwe_ciphertext<2048,4>) -> !Concrete.lwe_ciphertext<2048,4> {
// CHECK-NEXT: %[[TABLE:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : tensor<16xi64>
// CHECK-NEXT: %[[V1:.*]] = "Concrete.glwe_from_table"(%[[TABLE:.*]]) {glweDimension = 1 : i32, p = 4 : i32, polynomialSize = 2048 : i32} : (tensor<16xi64>) -> !Concrete.glwe_ciphertext
// CHECK-NEXT: %[[V2:.*]] = linalg.init_tensor [601] : tensor<601xi64>
// CHECK-NEXT: "BConcrete.keyswitch_lwe_buffer"([[V2:.*]], %arg0) {baseLog = 2 : i32, inputLweDimension = 1 : i32, level = 3 : i32, outputLweDimension = 600 : i32} : (tensor<601xi64>, tensor<2049xi64>) -> ()
// CHECK-NEXT: %[[V1:.*]] = linalg.init_tensor [4096] : tensor<4096xi64>
// CHECK-NEXT: "BConcrete.fill_glwe_from_table"(%[[V1]], %cst) {glweDimension = 1 : i32, outPrecision = 4 : i32, polynomialSize = 2048 : i32} : (tensor<4096xi64>, tensor<16xi64>) -> ()
// CHECK-NEXT: %[[V2:.*]] = linalg.init_tensor [601] : tensor<601xi64>
// CHECK-NEXT: "BConcrete.keyswitch_lwe_buffer"(%[[V2]], %arg0) {baseLog = 2 : i32, inputLweDimension = 1 : i32, level = 3 : i32, outputLweDimension = 600 : i32} : (tensor<601xi64>, tensor<2049xi64>) -> ()
// CHECK-NEXT: %[[V3:.*]] = linalg.init_tensor [2049] : tensor<2049xi64>
// CHECK-NEXT: "BConcrete.bootstrap_lwe_buffer"(%[[V3:.*]], %[[V2:.*]], %[[V1:.*]]) {baseLog = 4 : i32, glweDimension = 1 : i32, level = 5 : i32, polynomialSize = 2048 : i32} : (tensor<2049xi64>, tensor<601xi64>, !Concrete.glwe_ciphertext) -> ()
// CHECK-NEXT: "BConcrete.bootstrap_lwe_buffer"(%[[V3]], %[[V2]], %[[V1]]) {baseLog = 4 : i32, glweDimension = 1 : i32, level = 5 : i32, polynomialSize = 2048 : i32} : (tensor<2049xi64>, tensor<601xi64>, tensor<4096xi64>) -> ()
// CHECK-NEXT: return %[[V3]] : tensor<2049xi64>
%tlu = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : tensor<16xi64>
%0 = "Concrete.glwe_from_table"(%tlu) {glweDimension = 1 : i32, p = 4 : i32, polynomialSize = 2048 : i32} : (tensor<16xi64>) -> !Concrete.glwe_ciphertext
%0 = "Concrete.glwe_from_table"(%tlu) {glweDimension = 1 : i32, p = 4 : i32, polynomialSize = 2048 : i32} : (tensor<16xi64>) -> !Concrete.glwe_ciphertext<2048,1,4>
%1 = "Concrete.keyswitch_lwe"(%arg0) {baseLog = 2 : i32, inputLweDimension = 1 : i32, level = 3 : i32, outputLweDimension = 600 : i32} : (!Concrete.lwe_ciphertext<2048,4>) -> !Concrete.lwe_ciphertext<600,4>
%2 = "Concrete.bootstrap_lwe"(%1, %0) {baseLog = 4 : i32, glweDimension = 1 : i32, level = 5 : i32, polynomialSize = 2048 : i32} : (!Concrete.lwe_ciphertext<600,4>, !Concrete.glwe_ciphertext) -> !Concrete.lwe_ciphertext<2048,4>
%2 = "Concrete.bootstrap_lwe"(%1, %0) {baseLog = 4 : i32, glweDimension = 1 : i32, level = 5 : i32, polynomialSize = 2048 : i32} : (!Concrete.lwe_ciphertext<600,4>, !Concrete.glwe_ciphertext<2048,1,4>) -> !Concrete.lwe_ciphertext<2048,4>
return %2 : !Concrete.lwe_ciphertext<2048,4>
}

View File

@@ -1,7 +0,0 @@
// RUN: concretecompiler --passes concrete-unparametrize --action=dump-llvm-dialect %s 2>&1| FileCheck %s
// CHECK-LABEL: func @main(%arg0: !Concrete.lwe_ciphertext<_,_>) -> !Concrete.lwe_ciphertext<_,_>
func @main(%arg0: !Concrete.lwe_ciphertext<1024,4>) -> !Concrete.lwe_ciphertext<1024,4> {
// CHECK-NEXT: return %arg0 : !Concrete.lwe_ciphertext<_,_>
return %arg0: !Concrete.lwe_ciphertext<1024,4>
}

View File

@@ -1,8 +0,0 @@
// RUN: concretecompiler --passes concrete-unparametrize --action=dump-llvm-dialect %s 2>&1| FileCheck %s
// CHECK-LABEL: func @main(%arg0: !Concrete.lwe_ciphertext<_,_>) -> !Concrete.lwe_ciphertext<_,_>
func @main(%arg0: !Concrete.lwe_ciphertext<1024,4>) -> !Concrete.lwe_ciphertext<_,_> {
// CHECK-NEXT: return %arg0 : !Concrete.lwe_ciphertext<_,_>
%0 = builtin.unrealized_conversion_cast %arg0 : !Concrete.lwe_ciphertext<1024,4> to !Concrete.lwe_ciphertext<_,_>
return %0: !Concrete.lwe_ciphertext<_,_>
}

View File

@@ -2,9 +2,9 @@
// CHECK-LABEL: func @apply_lookup_table(%arg0: !Concrete.lwe_ciphertext<1024,4>, %arg1: tensor<16xi64>) -> !Concrete.lwe_ciphertext<1024,4>
func @apply_lookup_table(%arg0: !TFHE.glwe<{1024,1,64}{4}>, %arg1: tensor<16xi64>) -> !TFHE.glwe<{1024,1,64}{4}> {
// CHECK-NEXT: %[[V1:.*]] = "Concrete.glwe_from_table"(%arg1) {glweDimension = 1 : i32, p = 4 : i32, polynomialSize = 1024 : i32} : (tensor<16xi64>) -> !Concrete.glwe_ciphertext
// CHECK-NEXT: %[[V1:.*]] = "Concrete.glwe_from_table"(%arg1) : (tensor<16xi64>) -> !Concrete.glwe_ciphertext<1024,1,4>
// CHECK-NEXT: %[[V2:.*]] = "Concrete.keyswitch_lwe"(%arg0) {baseLog = 2 : i32, level = 3 : i32} : (!Concrete.lwe_ciphertext<1024,4>) -> !Concrete.lwe_ciphertext<600,4>
// CHECK-NEXT: %[[V3:.*]] = "Concrete.bootstrap_lwe"(%[[V2]], %[[V1]]) {baseLog = 4 : i32, glweDimension = 1 : i32, level = 5 : i32, polynomialSize = 1024 : i32} : (!Concrete.lwe_ciphertext<600,4>, !Concrete.glwe_ciphertext) -> !Concrete.lwe_ciphertext<1024,4>
// CHECK-NEXT: %[[V3:.*]] = "Concrete.bootstrap_lwe"(%[[V2]], %[[V1]]) {baseLog = 4 : i32, glweDimension = 1 : i32, level = 5 : i32, polynomialSize = 1024 : i32} : (!Concrete.lwe_ciphertext<600,4>, !Concrete.glwe_ciphertext<1024,1,4>) -> !Concrete.lwe_ciphertext<1024,4>
// CHECK-NEXT: return %[[V3]] : !Concrete.lwe_ciphertext<1024,4>
%1 = "TFHE.apply_lookup_table"(%arg0, %arg1){glweDimension=1:i32, polynomialSize=1024:i32, levelKS=3:i32, baseLogKS=2:i32, levelBS=5:i32, baseLogBS=4:i32, outputSizeKS=600:i32}: (!TFHE.glwe<{1024,1,64}{4}>, tensor<16xi64>) -> (!TFHE.glwe<{1024,1,64}{4}>)
return %1: !TFHE.glwe<{1024,1,64}{4}>

View File

@@ -3,9 +3,9 @@
// CHECK-LABEL: func @apply_lookup_table_cst(%arg0: !Concrete.lwe_ciphertext<2048,4>) -> !Concrete.lwe_ciphertext<2048,4>
func @apply_lookup_table_cst(%arg0: !TFHE.glwe<{2048,1,64}{4}>) -> !TFHE.glwe<{2048,1,64}{4}> {
// CHECK-NEXT: %[[TABLE:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : tensor<16xi64>
// CHECK-NEXT: %[[V1:.*]] = "Concrete.glwe_from_table"(%[[TABLE]]) {glweDimension = 1 : i32, p = 4 : i32, polynomialSize = 2048 : i32} : (tensor<16xi64>) -> !Concrete.glwe_ciphertext
// CHECK-NEXT: %[[V1:.*]] = "Concrete.glwe_from_table"(%[[TABLE]]) : (tensor<16xi64>) -> !Concrete.glwe_ciphertext<2048,1,4>
// CHECK-NEXT: %[[V2:.*]] = "Concrete.keyswitch_lwe"(%arg0) {baseLog = 2 : i32, level = 3 : i32} : (!Concrete.lwe_ciphertext<2048,4>) -> !Concrete.lwe_ciphertext<600,4>
// CHECK-NEXT: %[[V3:.*]] = "Concrete.bootstrap_lwe"(%[[V2]], %[[V1]]) {baseLog = 4 : i32, glweDimension = 1 : i32, level = 5 : i32, polynomialSize = 2048 : i32} : (!Concrete.lwe_ciphertext<600,4>, !Concrete.glwe_ciphertext) -> !Concrete.lwe_ciphertext<2048,4>
// CHECK-NEXT: %[[V3:.*]] = "Concrete.bootstrap_lwe"(%[[V2]], %[[V1]]) {baseLog = 4 : i32, glweDimension = 1 : i32, level = 5 : i32, polynomialSize = 2048 : i32} : (!Concrete.lwe_ciphertext<600,4>, !Concrete.glwe_ciphertext<2048,1,4>) -> !Concrete.lwe_ciphertext<2048,4>
// CHECK-NEXT: return %[[V3]] : !Concrete.lwe_ciphertext<2048,4>
%tlu = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : tensor<16xi64>
%1 = "TFHE.apply_lookup_table"(%arg0, %tlu){glweDimension=1:i32, polynomialSize=2048:i32, levelKS=3:i32, baseLogKS=2:i32, levelBS=5:i32, baseLogBS=4:i32, outputSizeKS=600:i32}: (!TFHE.glwe<{2048,1,64}{4}>, tensor<16xi64>) -> (!TFHE.glwe<{2048,1,64}{4}>)

View File

@@ -40,13 +40,13 @@ func @negate_lwe_ciphertext(%arg0: tensor<2049xi64>) -> tensor<2049xi64> {
return %0 : tensor<2049xi64>
}
// CHECK-LABEL: func @bootstrap_lwe(%arg0: tensor<2049xi64>, %arg1: !Concrete.glwe_ciphertext) -> tensor<2049xi64>
func @bootstrap_lwe(%arg0: tensor<2049xi64>, %arg1: !Concrete.glwe_ciphertext) -> tensor<2049xi64> {
// CHECK-LABEL: func @bootstrap_lwe(%arg0: tensor<2049xi64>, %arg1: tensor<4096xi64>) -> tensor<2049xi64>
func @bootstrap_lwe(%arg0: tensor<2049xi64>, %arg1: tensor<4096xi64>) -> tensor<2049xi64> {
// CHECK-NEXT: %[[V0:.*]] = linalg.init_tensor [2049] : tensor<2049xi64>
// CHECK-NEXT: "BConcrete.bootstrap_lwe_buffer"(%[[V0]], %arg0, %arg1) {baseLog = -1 : i32, glweDimension = 1 : i32, level = -1 : i32, polynomialSize = 1024 : i32} : (tensor<2049xi64>, tensor<2049xi64>, !Concrete.glwe_ciphertext) -> ()
// CHECK-NEXT: "BConcrete.bootstrap_lwe_buffer"(%[[V0]], %arg0, %arg1) {baseLog = -1 : i32, glweDimension = 1 : i32, level = -1 : i32, polynomialSize = 1024 : i32} : (tensor<2049xi64>, tensor<2049xi64>, tensor<4096xi64>) -> ()
// CHECK-NEXT: return %[[V0]] : tensor<2049xi64>
%0 = linalg.init_tensor [2049] : tensor<2049xi64>
"BConcrete.bootstrap_lwe_buffer"(%0, %arg0, %arg1) {baseLog = -1 : i32, glweDimension = 1 : i32, level = -1 : i32, polynomialSize = 1024 : i32} : (tensor<2049xi64>, tensor<2049xi64>, !Concrete.glwe_ciphertext) -> ()
"BConcrete.bootstrap_lwe_buffer"(%0, %arg0, %arg1) {baseLog = -1 : i32, glweDimension = 1 : i32, level = -1 : i32, polynomialSize = 1024 : i32} : (tensor<2049xi64>, tensor<2049xi64>, tensor<4096xi64>) -> ()
return %0 : tensor<2049xi64>
}

View File

@@ -36,12 +36,11 @@ func @negate_lwe_ciphertext(%arg0: !Concrete.lwe_ciphertext<2048,7>) -> !Concret
return %1: !Concrete.lwe_ciphertext<2048,7>
}
// CHECK-LABEL: func @bootstrap_lwe(%arg0: !Concrete.lwe_ciphertext<2048,7>, %arg1: !Concrete.glwe_ciphertext) -> !Concrete.lwe_ciphertext<2048,7>
func @bootstrap_lwe(%arg0: !Concrete.lwe_ciphertext<2048,7>, %arg1: !Concrete.glwe_ciphertext) -> !Concrete.lwe_ciphertext<2048,7> {
// CHECK-NEXT: %[[V1:.*]] = "Concrete.bootstrap_lwe"(%arg0, %arg1) {baseLog = -1 : i32, glweDimension = 1 : i32, level = -1 : i32, polynomialSize = 1024 : i32} : (!Concrete.lwe_ciphertext<2048,7>, !Concrete.glwe_ciphertext) -> !Concrete.lwe_ciphertext<2048,7>
// CHECK-LABEL: func @bootstrap_lwe(%arg0: !Concrete.lwe_ciphertext<2048,7>, %arg1: !Concrete.glwe_ciphertext<2048,1,7>) -> !Concrete.lwe_ciphertext<2048,7>
func @bootstrap_lwe(%arg0: !Concrete.lwe_ciphertext<2048,7>, %arg1: !Concrete.glwe_ciphertext<2048,1,7>) -> !Concrete.lwe_ciphertext<2048,7> {
// CHECK-NEXT: %[[V1:.*]] = "Concrete.bootstrap_lwe"(%arg0, %arg1) {baseLog = -1 : i32, glweDimension = 1 : i32, level = -1 : i32, polynomialSize = 2048 : i32} : (!Concrete.lwe_ciphertext<2048,7>, !Concrete.glwe_ciphertext<2048,1,7>) -> !Concrete.lwe_ciphertext<2048,7>
// CHECK-NEXT: return %[[V1]] : !Concrete.lwe_ciphertext<2048,7>
%1 = "Concrete.bootstrap_lwe"(%arg0, %arg1) {baseLog = -1 : i32, glweDimension = 1 : i32, level = -1 : i32, polynomialSize = 1024 : i32} : (!Concrete.lwe_ciphertext<2048,7>, !Concrete.glwe_ciphertext) -> (!Concrete.lwe_ciphertext<2048,7>)
%1 = "Concrete.bootstrap_lwe"(%arg0, %arg1) {baseLog = -1 : i32, glweDimension = 1 : i32, level = -1 : i32, polynomialSize = 2048 : i32} : (!Concrete.lwe_ciphertext<2048,7>, !Concrete.glwe_ciphertext<2048,1,7>) -> !Concrete.lwe_ciphertext<2048,7>
return %1: !Concrete.lwe_ciphertext<2048,7>
}
@@ -49,7 +48,6 @@ func @bootstrap_lwe(%arg0: !Concrete.lwe_ciphertext<2048,7>, %arg1: !Concrete.gl
func @keyswitch_lwe(%arg0: !Concrete.lwe_ciphertext<2048,7>) -> !Concrete.lwe_ciphertext<2048,7> {
// CHECK-NEXT: %[[V1:.*]] = "Concrete.keyswitch_lwe"(%arg0) {baseLog = 2 : i32, level = 3 : i32} : (!Concrete.lwe_ciphertext<2048,7>) -> !Concrete.lwe_ciphertext<2048,7>
// CHECK-NEXT: return %[[V1]] : !Concrete.lwe_ciphertext<2048,7>
%1 = "Concrete.keyswitch_lwe"(%arg0){baseLog = 2 : i32, level = 3 : i32}: (!Concrete.lwe_ciphertext<2048,7>) -> (!Concrete.lwe_ciphertext<2048,7>)
return %1: !Concrete.lwe_ciphertext<2048,7>
}