enhance(compiler/lowlfhe): Give the runtime context as function argument instead of a global variable (close #195)

This commit is contained in:
Quentin Bourgerie
2021-11-26 18:02:12 +01:00
parent 99fe188e66
commit fb58dcc59d
14 changed files with 201 additions and 164 deletions

View File

@@ -210,7 +210,23 @@ def LweBootstrapKeyType : LowLFHE_Type<"LweBootstrapKey"> {
}];
}
def Context : LowLFHE_Type<"Context"> {
let mnemonic = "context";
let summary = "Runtime context";
let description = [{
An abstract runtime context to pass contextual value, like public keys, ...
}];
let printer = [{
$_printer << "context";
}];
let parser = [{
return get($_ctxt);
}];
}

View File

@@ -8,23 +8,8 @@ typedef struct RuntimeContext {
struct LweBootstrapKey_u64 *bsk;
} RuntimeContext;
extern RuntimeContext *globalRuntimeContext;
LweKeyswitchKey_u64 *get_keyswitch_key(RuntimeContext *context);
RuntimeContext *createRuntimeContext(LweKeyswitchKey_u64 *ksk,
LweBootstrapKey_u64 *bsk);
void setGlobalRuntimeContext(RuntimeContext *context);
RuntimeContext *getGlobalRuntimeContext();
LweKeyswitchKey_u64 *getGlobalKeyswitchKey();
LweBootstrapKey_u64 *getGlobalBootstrapKey();
LweKeyswitchKey_u64 *getKeyswitckKeyFromContext(RuntimeContext *context);
LweBootstrapKey_u64 *getBootstrapKeyFromContext(RuntimeContext *context);
bool checkError(int *err);
LweBootstrapKey_u64 *get_bootstrap_key(RuntimeContext *context);
#endif

View File

@@ -96,6 +96,7 @@ public:
std::vector<LweCiphertext_u64 **> ciphertextBuffers;
KeySet &keySet;
RuntimeContext context;
};
JITLambda(mlir::LLVM::LLVMFunctionType type, llvm::StringRef name)
: type(type), name(name){};

View File

@@ -41,10 +41,9 @@ public:
CircuitGate inputGate(size_t pos) { return std::get<0>(inputs[pos]); }
CircuitGate outputGate(size_t pos) { return std::get<0>(outputs[pos]); }
void initGlobalRuntimeContext() {
auto ksk = std::get<1>(this->keyswitchKeys["ksk_v0"]);
auto bsk = std::get<1>(this->bootstrapKeys["bsk_v0"]);
setGlobalRuntimeContext(createRuntimeContext(ksk, bsk));
void setRuntimeContext(RuntimeContext &context) {
context.ksk = std::get<1>(this->keyswitchKeys["ksk_v0"]);
context.bsk = std::get<1>(this->bootstrapKeys["bsk_v0"]);
}
protected:

View File

@@ -133,6 +133,9 @@ mlir::LogicalResult insertForwardDeclarations(mlir::Operation *op,
auto genericCleartextType = getGenericCleartextType(rewriter.getContext());
auto genericBSKType = getGenericLweBootstrapKeyType(rewriter.getContext());
auto genericKSKType = getGenericLweKeySwitchKeyType(rewriter.getContext());
auto contextType =
mlir::zamalang::LowLFHE::ContextType::get(rewriter.getContext());
auto errType = mlir::IndexType::get(rewriter.getContext());
// Insert forward declaration of allocate lwe ciphertext
@@ -211,10 +214,9 @@ mlir::LogicalResult insertForwardDeclarations(mlir::Operation *op,
}
// Insert forward declaration of the getBsk function
{
auto funcType =
mlir::FunctionType::get(rewriter.getContext(), {}, {genericBSKType});
if (insertForwardDeclaration(op, rewriter, "getGlobalBootstrapKey",
funcType)
auto funcType = mlir::FunctionType::get(rewriter.getContext(),
{contextType}, {genericBSKType});
if (insertForwardDeclaration(op, rewriter, "get_bootstrap_key", funcType)
.failed()) {
return mlir::failure();
}
@@ -237,10 +239,9 @@ mlir::LogicalResult insertForwardDeclarations(mlir::Operation *op,
}
// Insert forward declaration of the getKsk function
{
auto funcType =
mlir::FunctionType::get(rewriter.getContext(), {}, {genericKSKType});
if (insertForwardDeclaration(op, rewriter, "getGlobalKeyswitchKey",
funcType)
auto funcType = mlir::FunctionType::get(rewriter.getContext(),
{contextType}, {genericKSKType});
if (insertForwardDeclaration(op, rewriter, "get_keyswitch_key", funcType)
.failed()) {
return mlir::failure();
}
@@ -563,6 +564,25 @@ struct GlweFromTableOpPattern
};
};
mlir::Value getContextArgument(mlir::Operation *op) {
mlir::Block *block = op->getBlock();
while (block != nullptr) {
if (llvm::isa<mlir::FuncOp>(block->getParentOp())) {
mlir::Value context = block->getArguments().back();
assert(context.getType().isa<mlir::zamalang::LowLFHE::ContextType>() &&
"the LowLFHE.context should be the last argument of the enclosing "
"function of the op");
return context;
}
block = block->getParentOp()->getBlock();
}
assert("can't find a function that enclose the op");
return nullptr;
}
// Rewrite a BootstrapLweOp with a series of ops:
// - allocate the result LWE ciphertext
// - get the global bootstrapping key
@@ -592,10 +612,10 @@ struct LowLFHEBootstrapLweOpPattern
op.getLoc(), "allocate_lwe_ciphertext_u64",
getGenericLweCiphertextType(rewriter.getContext()), allocLweCtOperands);
// get bsk
mlir::SmallVector<mlir::Value> getBskOperands{};
auto getBskOp = rewriter.create<mlir::CallOp>(
op.getLoc(), "getGlobalBootstrapKey",
getGenericLweBootstrapKeyType(rewriter.getContext()), getBskOperands);
op.getLoc(), "get_bootstrap_key",
getGenericLweBootstrapKeyType(rewriter.getContext()),
mlir::SmallVector<mlir::Value>{getContextArgument(op)});
// bootstrap
// cast input ciphertext to a generic type
mlir::Value lweToBootstrap =
@@ -651,10 +671,10 @@ struct LowLFHEKeySwitchLweOpPattern
op.getLoc(), "allocate_lwe_ciphertext_u64",
getGenericLweCiphertextType(rewriter.getContext()), allocLweCtOperands);
// get ksk
mlir::SmallVector<mlir::Value> getkskOperands{};
auto getKskOp = rewriter.create<mlir::CallOp>(
op.getLoc(), "getGlobalKeyswitchKey",
getGenericLweKeySwitchKeyType(rewriter.getContext()), getkskOperands);
op.getLoc(), "get_keyswitch_key",
getGenericLweKeySwitchKeyType(rewriter.getContext()),
mlir::SmallVector<mlir::Value>{getContextArgument(op)});
// keyswitch
// cast input ciphertext to a generic type
mlir::Value lweToKeyswitch =
@@ -703,6 +723,73 @@ void populateLowLFHEToConcreteCAPICall(mlir::RewritePatternSet &patterns) {
patterns.add<LowLFHEBootstrapLweOpPattern>(patterns.getContext());
}
struct AddRuntimeContextToFuncOpPattern
: public mlir::OpRewritePattern<mlir::FuncOp> {
AddRuntimeContextToFuncOpPattern(mlir::MLIRContext *context,
mlir::PatternBenefit benefit = 1)
: mlir::OpRewritePattern<mlir::FuncOp>(context, benefit) {}
mlir::LogicalResult
matchAndRewrite(mlir::FuncOp oldFuncOp,
mlir::PatternRewriter &rewriter) const override {
mlir::OpBuilder::InsertionGuard guard(rewriter);
mlir::FunctionType oldFuncType = oldFuncOp.getType();
// Add a LowLFHE.context to the function signature
mlir::SmallVector<mlir::Type> newInputs(oldFuncType.getInputs().begin(),
oldFuncType.getInputs().end());
newInputs.push_back(
rewriter.getType<mlir::zamalang::LowLFHE::ContextType>());
mlir::FunctionType newFuncTy = rewriter.getType<mlir::FunctionType>(
newInputs, oldFuncType.getResults());
// Create the new func
mlir::FuncOp newFuncOp = rewriter.create<mlir::FuncOp>(
oldFuncOp.getLoc(), oldFuncOp.getName(), newFuncTy);
// Create the arguments of the new func
mlir::Region &newFuncBody = newFuncOp.body();
mlir::Block *newFuncEntryBlock = new mlir::Block();
newFuncEntryBlock->addArguments(newFuncTy.getInputs());
newFuncBody.push_back(newFuncEntryBlock);
// Clone the old body to the new one
mlir::BlockAndValueMapping map;
for (auto arg : llvm::enumerate(oldFuncOp.getArguments())) {
map.map(arg.value(), newFuncEntryBlock->getArgument(arg.index()));
}
for (auto &op : oldFuncOp.body().front()) {
newFuncEntryBlock->push_back(op.clone(map));
}
rewriter.eraseOp(oldFuncOp);
return mlir::success();
}
// Legal function are one that are private or has a LowLFHE.context as last
// arguments.
static bool isLegal(mlir::FuncOp funcOp) {
if (!funcOp.isPublic()) {
return true;
}
// TODO : Don't need to add a runtime context for function that doesn't
// manipulates lowlfhe types.
//
// if (!llvm::any_of(funcOp.getType().getInputs(), [](mlir::Type t) {
// if (auto tensorTy = t.dyn_cast_or_null<mlir::TensorType>()) {
// t = tensorTy.getElementType();
// }
// return llvm::isa<mlir::zamalang::LowLFHE::LowLFHEDialect>(
// t.getDialect());
// })) {
// return true;
// }
return funcOp.getType().getNumInputs() >= 1 &&
funcOp.getType()
.getInputs()
.back()
.isa<mlir::zamalang::LowLFHE::ContextType>();
}
};
namespace {
struct LowLFHEToConcreteCAPIPass
: public LowLFHEToConcreteCAPIBase<LowLFHEToConcreteCAPIPass> {
@@ -711,27 +798,49 @@ struct LowLFHEToConcreteCAPIPass
} // namespace
void LowLFHEToConcreteCAPIPass::runOnOperation() {
// Setup the conversion target.
mlir::ConversionTarget target(getContext());
target.addIllegalDialect<mlir::zamalang::LowLFHE::LowLFHEDialect>();
target.addLegalDialect<mlir::BuiltinDialect, mlir::StandardOpsDialect,
mlir::memref::MemRefDialect,
mlir::arith::ArithmeticDialect>();
// Setup rewrite patterns
mlir::RewritePatternSet patterns(&getContext());
populateLowLFHEToConcreteCAPICall(patterns);
// Insert forward declarations
mlir::ModuleOp op = getOperation();
// First of all add the LowLFHE.context to the block arguments of function
// that manipulates ciphertexts.
{
mlir::ConversionTarget target(getContext());
mlir::RewritePatternSet patterns(&getContext());
target.addDynamicallyLegalOp<mlir::FuncOp>([&](mlir::FuncOp funcOp) {
return AddRuntimeContextToFuncOpPattern::isLegal(funcOp);
});
patterns.add<AddRuntimeContextToFuncOpPattern>(patterns.getContext());
// Apply the conversion
if (mlir::applyPartialConversion(op, target, std::move(patterns))
.failed()) {
this->signalPassFailure();
return;
}
}
// Insert forward declaration
mlir::IRRewriter rewriter(&getContext());
if (insertForwardDeclarations(op, rewriter).failed()) {
this->signalPassFailure();
}
// Rewrite LowLFHE ops to CallOp to the Concrete C API
{
mlir::ConversionTarget target(getContext());
mlir::RewritePatternSet patterns(&getContext());
// Apply the conversion
if (mlir::applyPartialConversion(op, target, std::move(patterns)).failed()) {
this->signalPassFailure();
target.addIllegalDialect<mlir::zamalang::LowLFHE::LowLFHEDialect>();
target.addLegalDialect<mlir::BuiltinDialect, mlir::StandardOpsDialect,
mlir::memref::MemRefDialect,
mlir::arith::ArithmeticDialect>();
populateLowLFHEToConcreteCAPICall(patterns);
if (mlir::applyPartialConversion(op, target, std::move(patterns))
.failed()) {
this->signalPassFailure();
}
}
}

View File

@@ -70,6 +70,7 @@ MLIRLowerableDialectsToLLVMPass::convertTypes(mlir::Type type) {
type.isa<mlir::zamalang::LowLFHE::GlweCiphertextType>() ||
type.isa<mlir::zamalang::LowLFHE::LweKeySwitchKeyType>() ||
type.isa<mlir::zamalang::LowLFHE::LweBootstrapKeyType>() ||
type.isa<mlir::zamalang::LowLFHE::ContextType>() ||
type.isa<mlir::zamalang::LowLFHE::ForeignPlaintextListType>() ||
type.isa<mlir::zamalang::LowLFHE::PlaintextListType>()) {
return mlir::LLVM::LLVMPointerType::get(

View File

@@ -25,24 +25,9 @@ void LowLFHEDialect::initialize() {
mlir::Type type;
std::string types_str[] = {
"enc_rand_gen",
"secret_rand_gen",
"plaintext",
"plaintext_list",
"foreign_plaintext_list",
"lwe_ciphertext",
"lwe_key_switch_key",
"lwe_bootstrap_key",
"lwe_secret_key",
"lwe_size",
"glwe_ciphertext",
"glwe_secret_key",
"glwe_size",
"polynomial_size",
"decomp_level_count",
"decomp_base_log",
"variance",
"cleartext",
"plaintext", "plaintext_list", "foreign_plaintext_list",
"lwe_ciphertext", "lwe_key_switch_key", "lwe_bootstrap_key",
"glwe_ciphertext", "cleartext", "context",
};
for (const std::string &type_str : types_str) {
@@ -53,8 +38,7 @@ void LowLFHEDialect::initialize() {
}
parser.emitError(parser.getCurrentLocation(), "Unknown LowLFHE type");
// call default parser
parser.parseType(type);
return type;
}

View File

@@ -1,51 +1,10 @@
#include "zamalang/Runtime/context.h"
#include <stdio.h>
RuntimeContext *globalRuntimeContext;
RuntimeContext *createRuntimeContext(LweKeyswitchKey_u64 *ksk,
LweBootstrapKey_u64 *bsk) {
RuntimeContext *context = (RuntimeContext *)malloc(sizeof(RuntimeContext));
context->ksk = ksk;
context->bsk = bsk;
return context;
}
void setGlobalRuntimeContext(RuntimeContext *context) {
globalRuntimeContext = context;
}
RuntimeContext *getGlobalRuntimeContext() { return globalRuntimeContext; }
LweKeyswitchKey_u64 *getGlobalKeyswitchKey() {
return globalRuntimeContext->ksk;
}
LweBootstrapKey_u64 *getGlobalBootstrapKey() {
return globalRuntimeContext->bsk;
}
LweKeyswitchKey_u64 *getKeyswitckKeyFromContext(RuntimeContext *context) {
LweKeyswitchKey_u64 *get_keyswitch_key(RuntimeContext *context) {
return context->ksk;
}
LweBootstrapKey_u64 *getBootstrapKeyFromContext(RuntimeContext *context) {
LweBootstrapKey_u64 *get_bootstrap_key(RuntimeContext *context) {
return context->bsk;
}
bool checkError(int *err) {
switch (*err) {
case ERR_INDEX_OUT_OF_BOUND:
fprintf(stderr, "Runtime: index out of bound");
break;
case ERR_NULL_POINTER:
fprintf(stderr, "Runtime: null pointer");
break;
case ERR_SIZE_MISMATCH:
fprintf(stderr, "Runtime: size mismatch");
break;
default:
return false;
}
return true;
}

View File

@@ -125,8 +125,11 @@ createClientParametersForV0(V0FHEContext fheContext, llvm::StringRef name,
// Create input and output circuit gate parameters
auto funcType = (*funcOp).getType();
for (auto inType : funcType.getInputs()) {
auto gate = gateFromMLIRType("big", precision, encryptionVariance, inType);
bool hasContext =
funcType.getInputs().back().isa<mlir::zamalang::LowLFHE::ContextType>();
for (auto inType = funcType.getInputs().begin();
inType < funcType.getInputs().end() - hasContext; inType++) {
auto gate = gateFromMLIRType("big", precision, encryptionVariance, *inType);
if (auto err = gate.takeError()) {
return std::move(err);
}

View File

@@ -78,8 +78,8 @@ llvm::Error JITLambda::invoke(Argument &args) {
JITLambda::Argument::Argument(KeySet &keySet) : keySet(keySet) {
// Setting the inputs
auto numInputs = 0;
{
auto numInputs = 0;
for (size_t i = 0; i < keySet.numInputs(); i++) {
auto offset = numInputs;
auto gate = keySet.inputGate(i);
@@ -95,6 +95,8 @@ JITLambda::Argument::Argument(KeySet &keySet) : keySet(keySet) {
// dimension of the tensor.
numInputs = numInputs + 2 * keySet.inputGate(i).shape.dimensions.size();
}
// Reserve for the context argument
numInputs = numInputs + 1;
inputs = std::vector<const void *>(numInputs);
}
@@ -128,8 +130,10 @@ JITLambda::Argument::Argument(KeySet &keySet) : keySet(keySet) {
rawArg[i] = &outputs[i - inputs.size()];
}
// Setup runtime context with appropriate keys
keySet.initGlobalRuntimeContext();
// Set the context argument
keySet.setRuntimeContext(context);
inputs[numInputs - 1] = &context;
rawArg[numInputs - 1] = &inputs[numInputs - 1];
}
JITLambda::Argument::~Argument() {

View File

@@ -140,7 +140,8 @@ lowerLowLFHEToStd(mlir::MLIRContext &context, mlir::ModuleOp &module,
std::function<bool(mlir::Pass *)> enablePass) {
mlir::PassManager pm(&context);
pipelinePrinting("LowLFHEToStd", pm, context);
pm.addPass(mlir::zamalang::createConvertLowLFHEToConcreteCAPIPass());
addPotentiallyNestedPass(
pm, mlir::zamalang::createConvertLowLFHEToConcreteCAPIPass(), enablePass);
return pm.run(module.getOperation());
}

View File

@@ -1,25 +1,17 @@
// RUN: zamacompiler --passes lowlfhe-to-concrete-c-api --action=dump-std %s 2>&1| FileCheck %s
// CHECK-LABEL: module
// CHECK-NEXT: func private @add_plaintext_list_glwe_ciphertext_u64(index, !LowLFHE.glwe_ciphertext, !LowLFHE.glwe_ciphertext, !LowLFHE.plaintext_list)
// CHECK-NEXT: func private @fill_plaintext_list_with_expansion_u64(index, !LowLFHE.plaintext_list, !LowLFHE.foreign_plaintext_list)
// CHECK-NEXT: func private @allocate_plaintext_list_u64(index, i32) -> !LowLFHE.plaintext_list
// CHECK-NEXT: func private @allocate_glwe_ciphertext_u64(index, i32, i32) -> !LowLFHE.glwe_ciphertext
// CHECK-NEXT: func private @keyswitch_lwe_u64(index, !LowLFHE.lwe_key_switch_key, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>)
// CHECK-NEXT: func private @getGlobalKeyswitchKey() -> !LowLFHE.lwe_key_switch_key
// CHECK-NEXT: func private @bootstrap_lwe_u64(index, !LowLFHE.lwe_bootstrap_key, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.glwe_ciphertext)
// CHECK-NEXT: func private @getGlobalBootstrapKey() -> !LowLFHE.lwe_bootstrap_key
// CHECK-NEXT: func private @negate_lwe_ciphertext_u64(index, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>)
// CHECK-NEXT: func private @mul_cleartext_lwe_ciphertext_u64(index, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.cleartext<_>)
// CHECK-NEXT: func private @add_plaintext_lwe_ciphertext_u64(index, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.plaintext<_>)
// CHECK-NEXT: func private @add_lwe_ciphertexts_u64(index, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>)
// CHECK-NEXT: func private @allocate_lwe_ciphertext_u64(index, index) -> !LowLFHE.lwe_ciphertext<_,_>
// CHECK-LABEL: func @bootstrap_lwe(%arg0: !LowLFHE.lwe_ciphertext<1024,4>, %arg1: !LowLFHE.glwe_ciphertext) -> !LowLFHE.lwe_ciphertext<1024,4>
// CHECK: func private @keyswitch_lwe_u64(index, !LowLFHE.lwe_key_switch_key, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>)
// CHECK: func private @get_keyswitch_key(!LowLFHE.context) -> !LowLFHE.lwe_key_switch_key
// CHECK: func private @bootstrap_lwe_u64(index, !LowLFHE.lwe_bootstrap_key, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.glwe_ciphertext)
// CHECK: func private @get_bootstrap_key(!LowLFHE.context) -> !LowLFHE.lwe_bootstrap_key
// CHECK: func private @allocate_lwe_ciphertext_u64(index, index) -> !LowLFHE.lwe_ciphertext<_,_>
// CHECK-LABEL: func @bootstrap_lwe(%arg0: !LowLFHE.lwe_ciphertext<1024,4>, %arg1: !LowLFHE.glwe_ciphertext, %arg2: !LowLFHE.context) -> !LowLFHE.lwe_ciphertext<1024,4>
func @bootstrap_lwe(%arg0: !LowLFHE.lwe_ciphertext<1024,4>, %arg1: !LowLFHE.glwe_ciphertext) -> !LowLFHE.lwe_ciphertext<1024,4> {
// CHECK-NEXT: %[[ERR:.*]] = arith.constant 0 : index
// CHECK-NEXT: %[[C0:.*]] = arith.constant 1024 : index
// CHECK-NEXT: %[[V1:.*]] = call @allocate_lwe_ciphertext_u64(%[[ERR]], %[[C0]]) : (index, index) -> !LowLFHE.lwe_ciphertext<_,_>
// CHECK-NEXT: %[[V2:.*]] = call @getGlobalBootstrapKey() : () -> !LowLFHE.lwe_bootstrap_key
// CHECK-NEXT: %[[V2:.*]] = call @get_bootstrap_key(%arg2) : (!LowLFHE.context) -> !LowLFHE.lwe_bootstrap_key
// CHECK-NEXT: %[[V3:.*]] = builtin.unrealized_conversion_cast %arg0 : !LowLFHE.lwe_ciphertext<1024,4> to !LowLFHE.lwe_ciphertext<_,_>
// CHECK-NEXT: %[[V4:.*]] = builtin.unrealized_conversion_cast %arg1 : !LowLFHE.glwe_ciphertext to !LowLFHE.glwe_ciphertext
// CHECK-NEXT: call @bootstrap_lwe_u64(%[[ERR]], %[[V2]], %[[V1]], %[[V3]], %[[V4]]) : (index, !LowLFHE.lwe_bootstrap_key, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.glwe_ciphertext) -> ()

View File

@@ -1,21 +1,12 @@
// RUN: zamacompiler --passes lowlfhe-to-concrete-c-api --action=dump-std %s 2>&1| FileCheck %s
// CHECK-LABEL: module
// CHECK-NEXT: func private @runtime_foreign_plaintext_list_u64(index, tensor<16xi64>, i64, i32) -> !LowLFHE.foreign_plaintext_list
// CHECK-NEXT: func private @add_plaintext_list_glwe_ciphertext_u64(index, !LowLFHE.glwe_ciphertext, !LowLFHE.glwe_ciphertext, !LowLFHE.plaintext_list)
// CHECK-NEXT: func private @fill_plaintext_list_with_expansion_u64(index, !LowLFHE.plaintext_list, !LowLFHE.foreign_plaintext_list)
// CHECK-NEXT: func private @allocate_plaintext_list_u64(index, i32) -> !LowLFHE.plaintext_list
// CHECK-NEXT: func private @allocate_glwe_ciphertext_u64(index, i32, i32) -> !LowLFHE.glwe_ciphertext
// CHECK-NEXT: func private @keyswitch_lwe_u64(index, !LowLFHE.lwe_key_switch_key, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>)
// CHECK-NEXT: func private @getGlobalKeyswitchKey() -> !LowLFHE.lwe_key_switch_key
// CHECK-NEXT: func private @bootstrap_lwe_u64(index, !LowLFHE.lwe_bootstrap_key, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.glwe_ciphertext)
// CHECK-NEXT: func private @getGlobalBootstrapKey() -> !LowLFHE.lwe_bootstrap_key
// CHECK-NEXT: func private @negate_lwe_ciphertext_u64(index, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>)
// CHECK-NEXT: func private @mul_cleartext_lwe_ciphertext_u64(index, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.cleartext<_>)
// CHECK-NEXT: func private @add_plaintext_lwe_ciphertext_u64(index, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.plaintext<_>)
// CHECK-NEXT: func private @add_lwe_ciphertexts_u64(index, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>)
// CHECK-NEXT: func private @allocate_lwe_ciphertext_u64(index, index) -> !LowLFHE.lwe_ciphertext<_,_>
// CHECK-LABEL: func @glwe_from_table(%arg0: tensor<16xi64>) -> !LowLFHE.glwe_ciphertext
// CHECK: func private @runtime_foreign_plaintext_list_u64(index, tensor<16xi64>, i64, i32) -> !LowLFHE.foreign_plaintext_list
// CHECK: func private @add_plaintext_list_glwe_ciphertext_u64(index, !LowLFHE.glwe_ciphertext, !LowLFHE.glwe_ciphertext, !LowLFHE.plaintext_list)
// CHECK: func private @fill_plaintext_list_with_expansion_u64(index, !LowLFHE.plaintext_list, !LowLFHE.foreign_plaintext_list)
// CHECK: func private @allocate_plaintext_list_u64(index, i32) -> !LowLFHE.plaintext_list
// CHECK: func private @allocate_glwe_ciphertext_u64(index, i32, i32) -> !LowLFHE.glwe_ciphertext
// CHECK-LABEL: func @glwe_from_table(%arg0: tensor<16xi64>, %arg1: !LowLFHE.context) -> !LowLFHE.glwe_ciphertext
func @glwe_from_table(%arg0: tensor<16xi64>) -> !LowLFHE.glwe_ciphertext {
// CHECK-NEXT: %[[V0:.*]] = arith.constant 0 : index
// CHECK-NEXT: %[[C0:.*]] = arith.constant 1 : i32

View File

@@ -1,25 +1,17 @@
// RUN: zamacompiler --passes lowlfhe-to-concrete-c-api --action=dump-std %s 2>&1| FileCheck %s
// CHECK-LABEL: module
// CHECK-NEXT: func private @add_plaintext_list_glwe_ciphertext_u64(index, !LowLFHE.glwe_ciphertext, !LowLFHE.glwe_ciphertext, !LowLFHE.plaintext_list)
// CHECK-NEXT: func private @fill_plaintext_list_with_expansion_u64(index, !LowLFHE.plaintext_list, !LowLFHE.foreign_plaintext_list)
// CHECK-NEXT: func private @allocate_plaintext_list_u64(index, i32) -> !LowLFHE.plaintext_list
// CHECK-NEXT: func private @allocate_glwe_ciphertext_u64(index, i32, i32) -> !LowLFHE.glwe_ciphertext
// CHECK-NEXT: func private @keyswitch_lwe_u64(index, !LowLFHE.lwe_key_switch_key, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>)
// CHECK-NEXT: func private @getGlobalKeyswitchKey() -> !LowLFHE.lwe_key_switch_key
// CHECK-NEXT: func private @bootstrap_lwe_u64(index, !LowLFHE.lwe_bootstrap_key, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.glwe_ciphertext)
// CHECK-NEXT: func private @getGlobalBootstrapKey() -> !LowLFHE.lwe_bootstrap_key
// CHECK-NEXT: func private @negate_lwe_ciphertext_u64(index, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>)
// CHECK-NEXT: func private @mul_cleartext_lwe_ciphertext_u64(index, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.cleartext<_>)
// CHECK-NEXT: func private @add_plaintext_lwe_ciphertext_u64(index, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.plaintext<_>)
// CHECK-NEXT: func private @add_lwe_ciphertexts_u64(index, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>)
// CHECK-NEXT: func private @allocate_lwe_ciphertext_u64(index, index) -> !LowLFHE.lwe_ciphertext<_,_>
// CHECK-LABEL: func @keyswitch_lwe(%arg0: !LowLFHE.lwe_ciphertext<1024,4>) -> !LowLFHE.lwe_ciphertext<1024,4>
// CHECK: func private @keyswitch_lwe_u64(index, !LowLFHE.lwe_key_switch_key, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>)
// CHECK: func private @get_keyswitch_key(!LowLFHE.context) -> !LowLFHE.lwe_key_switch_key
// CHECK: func private @bootstrap_lwe_u64(index, !LowLFHE.lwe_bootstrap_key, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.glwe_ciphertext)
// CHECK: func private @get_bootstrap_key(!LowLFHE.context) -> !LowLFHE.lwe_bootstrap_key
// CHECK: func private @allocate_lwe_ciphertext_u64(index, index) -> !LowLFHE.lwe_ciphertext<_,_>
// CHECK-LABEL: func @keyswitch_lwe(%arg0: !LowLFHE.lwe_ciphertext<1024,4>, %arg1: !LowLFHE.context) -> !LowLFHE.lwe_ciphertext<1024,4>
func @keyswitch_lwe(%arg0: !LowLFHE.lwe_ciphertext<1024,4>) -> !LowLFHE.lwe_ciphertext<1024,4> {
// CHECK-NEXT: %[[ERR:.*]] = arith.constant 0 : index
// CHECK-NEXT: %[[C0:.*]] = arith.constant 1 : index
// CHECK-NEXT: %[[V1:.*]] = call @allocate_lwe_ciphertext_u64(%[[ERR]], %[[C0]]) : (index, index) -> !LowLFHE.lwe_ciphertext<_,_>
// CHECK-NEXT: %[[V2:.*]] = call @getGlobalKeyswitchKey() : () -> !LowLFHE.lwe_key_switch_key
// CHECK-NEXT: %[[V2:.*]] = call @get_keyswitch_key(%arg1) : (!LowLFHE.context) -> !LowLFHE.lwe_key_switch_key
// CHECK-NEXT: %[[V3:.*]] = builtin.unrealized_conversion_cast %arg0 : !LowLFHE.lwe_ciphertext<1024,4> to !LowLFHE.lwe_ciphertext<_,_>
// CHECK-NEXT: call @keyswitch_lwe_u64(%[[ERR]], %[[V2]], %[[V1]], %[[V3]]) : (index, !LowLFHE.lwe_key_switch_key, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>) -> ()
// CHECK-NEXT: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[V1]] : !LowLFHE.lwe_ciphertext<_,_> to !LowLFHE.lwe_ciphertext<1024,4>