mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 12:15:09 -05:00
enhance(compiler/lowlfhe): Give the runtime context as function argument instead of a global variable (close #195)
This commit is contained in:
@@ -210,7 +210,23 @@ def LweBootstrapKeyType : LowLFHE_Type<"LweBootstrapKey"> {
|
||||
}];
|
||||
}
|
||||
|
||||
def Context : LowLFHE_Type<"Context"> {
|
||||
let mnemonic = "context";
|
||||
|
||||
let summary = "Runtime context";
|
||||
|
||||
let description = [{
|
||||
An abstract runtime context to pass contextual value, like public keys, ...
|
||||
}];
|
||||
|
||||
let printer = [{
|
||||
$_printer << "context";
|
||||
}];
|
||||
|
||||
let parser = [{
|
||||
return get($_ctxt);
|
||||
}];
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -8,23 +8,8 @@ typedef struct RuntimeContext {
|
||||
struct LweBootstrapKey_u64 *bsk;
|
||||
} RuntimeContext;
|
||||
|
||||
extern RuntimeContext *globalRuntimeContext;
|
||||
LweKeyswitchKey_u64 *get_keyswitch_key(RuntimeContext *context);
|
||||
|
||||
RuntimeContext *createRuntimeContext(LweKeyswitchKey_u64 *ksk,
|
||||
LweBootstrapKey_u64 *bsk);
|
||||
|
||||
void setGlobalRuntimeContext(RuntimeContext *context);
|
||||
|
||||
RuntimeContext *getGlobalRuntimeContext();
|
||||
|
||||
LweKeyswitchKey_u64 *getGlobalKeyswitchKey();
|
||||
|
||||
LweBootstrapKey_u64 *getGlobalBootstrapKey();
|
||||
|
||||
LweKeyswitchKey_u64 *getKeyswitckKeyFromContext(RuntimeContext *context);
|
||||
|
||||
LweBootstrapKey_u64 *getBootstrapKeyFromContext(RuntimeContext *context);
|
||||
|
||||
bool checkError(int *err);
|
||||
LweBootstrapKey_u64 *get_bootstrap_key(RuntimeContext *context);
|
||||
|
||||
#endif
|
||||
|
||||
@@ -96,6 +96,7 @@ public:
|
||||
std::vector<LweCiphertext_u64 **> ciphertextBuffers;
|
||||
|
||||
KeySet &keySet;
|
||||
RuntimeContext context;
|
||||
};
|
||||
JITLambda(mlir::LLVM::LLVMFunctionType type, llvm::StringRef name)
|
||||
: type(type), name(name){};
|
||||
|
||||
@@ -41,10 +41,9 @@ public:
|
||||
CircuitGate inputGate(size_t pos) { return std::get<0>(inputs[pos]); }
|
||||
CircuitGate outputGate(size_t pos) { return std::get<0>(outputs[pos]); }
|
||||
|
||||
void initGlobalRuntimeContext() {
|
||||
auto ksk = std::get<1>(this->keyswitchKeys["ksk_v0"]);
|
||||
auto bsk = std::get<1>(this->bootstrapKeys["bsk_v0"]);
|
||||
setGlobalRuntimeContext(createRuntimeContext(ksk, bsk));
|
||||
void setRuntimeContext(RuntimeContext &context) {
|
||||
context.ksk = std::get<1>(this->keyswitchKeys["ksk_v0"]);
|
||||
context.bsk = std::get<1>(this->bootstrapKeys["bsk_v0"]);
|
||||
}
|
||||
|
||||
protected:
|
||||
|
||||
@@ -133,6 +133,9 @@ mlir::LogicalResult insertForwardDeclarations(mlir::Operation *op,
|
||||
auto genericCleartextType = getGenericCleartextType(rewriter.getContext());
|
||||
auto genericBSKType = getGenericLweBootstrapKeyType(rewriter.getContext());
|
||||
auto genericKSKType = getGenericLweKeySwitchKeyType(rewriter.getContext());
|
||||
auto contextType =
|
||||
mlir::zamalang::LowLFHE::ContextType::get(rewriter.getContext());
|
||||
|
||||
auto errType = mlir::IndexType::get(rewriter.getContext());
|
||||
|
||||
// Insert forward declaration of allocate lwe ciphertext
|
||||
@@ -211,10 +214,9 @@ mlir::LogicalResult insertForwardDeclarations(mlir::Operation *op,
|
||||
}
|
||||
// Insert forward declaration of the getBsk function
|
||||
{
|
||||
auto funcType =
|
||||
mlir::FunctionType::get(rewriter.getContext(), {}, {genericBSKType});
|
||||
if (insertForwardDeclaration(op, rewriter, "getGlobalBootstrapKey",
|
||||
funcType)
|
||||
auto funcType = mlir::FunctionType::get(rewriter.getContext(),
|
||||
{contextType}, {genericBSKType});
|
||||
if (insertForwardDeclaration(op, rewriter, "get_bootstrap_key", funcType)
|
||||
.failed()) {
|
||||
return mlir::failure();
|
||||
}
|
||||
@@ -237,10 +239,9 @@ mlir::LogicalResult insertForwardDeclarations(mlir::Operation *op,
|
||||
}
|
||||
// Insert forward declaration of the getKsk function
|
||||
{
|
||||
auto funcType =
|
||||
mlir::FunctionType::get(rewriter.getContext(), {}, {genericKSKType});
|
||||
if (insertForwardDeclaration(op, rewriter, "getGlobalKeyswitchKey",
|
||||
funcType)
|
||||
auto funcType = mlir::FunctionType::get(rewriter.getContext(),
|
||||
{contextType}, {genericKSKType});
|
||||
if (insertForwardDeclaration(op, rewriter, "get_keyswitch_key", funcType)
|
||||
.failed()) {
|
||||
return mlir::failure();
|
||||
}
|
||||
@@ -563,6 +564,25 @@ struct GlweFromTableOpPattern
|
||||
};
|
||||
};
|
||||
|
||||
mlir::Value getContextArgument(mlir::Operation *op) {
|
||||
mlir::Block *block = op->getBlock();
|
||||
while (block != nullptr) {
|
||||
if (llvm::isa<mlir::FuncOp>(block->getParentOp())) {
|
||||
|
||||
mlir::Value context = block->getArguments().back();
|
||||
|
||||
assert(context.getType().isa<mlir::zamalang::LowLFHE::ContextType>() &&
|
||||
"the LowLFHE.context should be the last argument of the enclosing "
|
||||
"function of the op");
|
||||
|
||||
return context;
|
||||
}
|
||||
block = block->getParentOp()->getBlock();
|
||||
}
|
||||
assert("can't find a function that enclose the op");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Rewrite a BootstrapLweOp with a series of ops:
|
||||
// - allocate the result LWE ciphertext
|
||||
// - get the global bootstrapping key
|
||||
@@ -592,10 +612,10 @@ struct LowLFHEBootstrapLweOpPattern
|
||||
op.getLoc(), "allocate_lwe_ciphertext_u64",
|
||||
getGenericLweCiphertextType(rewriter.getContext()), allocLweCtOperands);
|
||||
// get bsk
|
||||
mlir::SmallVector<mlir::Value> getBskOperands{};
|
||||
auto getBskOp = rewriter.create<mlir::CallOp>(
|
||||
op.getLoc(), "getGlobalBootstrapKey",
|
||||
getGenericLweBootstrapKeyType(rewriter.getContext()), getBskOperands);
|
||||
op.getLoc(), "get_bootstrap_key",
|
||||
getGenericLweBootstrapKeyType(rewriter.getContext()),
|
||||
mlir::SmallVector<mlir::Value>{getContextArgument(op)});
|
||||
// bootstrap
|
||||
// cast input ciphertext to a generic type
|
||||
mlir::Value lweToBootstrap =
|
||||
@@ -651,10 +671,10 @@ struct LowLFHEKeySwitchLweOpPattern
|
||||
op.getLoc(), "allocate_lwe_ciphertext_u64",
|
||||
getGenericLweCiphertextType(rewriter.getContext()), allocLweCtOperands);
|
||||
// get ksk
|
||||
mlir::SmallVector<mlir::Value> getkskOperands{};
|
||||
auto getKskOp = rewriter.create<mlir::CallOp>(
|
||||
op.getLoc(), "getGlobalKeyswitchKey",
|
||||
getGenericLweKeySwitchKeyType(rewriter.getContext()), getkskOperands);
|
||||
op.getLoc(), "get_keyswitch_key",
|
||||
getGenericLweKeySwitchKeyType(rewriter.getContext()),
|
||||
mlir::SmallVector<mlir::Value>{getContextArgument(op)});
|
||||
// keyswitch
|
||||
// cast input ciphertext to a generic type
|
||||
mlir::Value lweToKeyswitch =
|
||||
@@ -703,6 +723,73 @@ void populateLowLFHEToConcreteCAPICall(mlir::RewritePatternSet &patterns) {
|
||||
patterns.add<LowLFHEBootstrapLweOpPattern>(patterns.getContext());
|
||||
}
|
||||
|
||||
struct AddRuntimeContextToFuncOpPattern
|
||||
: public mlir::OpRewritePattern<mlir::FuncOp> {
|
||||
AddRuntimeContextToFuncOpPattern(mlir::MLIRContext *context,
|
||||
mlir::PatternBenefit benefit = 1)
|
||||
: mlir::OpRewritePattern<mlir::FuncOp>(context, benefit) {}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(mlir::FuncOp oldFuncOp,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
mlir::OpBuilder::InsertionGuard guard(rewriter);
|
||||
mlir::FunctionType oldFuncType = oldFuncOp.getType();
|
||||
|
||||
// Add a LowLFHE.context to the function signature
|
||||
mlir::SmallVector<mlir::Type> newInputs(oldFuncType.getInputs().begin(),
|
||||
oldFuncType.getInputs().end());
|
||||
newInputs.push_back(
|
||||
rewriter.getType<mlir::zamalang::LowLFHE::ContextType>());
|
||||
mlir::FunctionType newFuncTy = rewriter.getType<mlir::FunctionType>(
|
||||
newInputs, oldFuncType.getResults());
|
||||
// Create the new func
|
||||
mlir::FuncOp newFuncOp = rewriter.create<mlir::FuncOp>(
|
||||
oldFuncOp.getLoc(), oldFuncOp.getName(), newFuncTy);
|
||||
|
||||
// Create the arguments of the new func
|
||||
mlir::Region &newFuncBody = newFuncOp.body();
|
||||
mlir::Block *newFuncEntryBlock = new mlir::Block();
|
||||
newFuncEntryBlock->addArguments(newFuncTy.getInputs());
|
||||
newFuncBody.push_back(newFuncEntryBlock);
|
||||
|
||||
// Clone the old body to the new one
|
||||
mlir::BlockAndValueMapping map;
|
||||
for (auto arg : llvm::enumerate(oldFuncOp.getArguments())) {
|
||||
map.map(arg.value(), newFuncEntryBlock->getArgument(arg.index()));
|
||||
}
|
||||
for (auto &op : oldFuncOp.body().front()) {
|
||||
newFuncEntryBlock->push_back(op.clone(map));
|
||||
}
|
||||
rewriter.eraseOp(oldFuncOp);
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
// Legal function are one that are private or has a LowLFHE.context as last
|
||||
// arguments.
|
||||
static bool isLegal(mlir::FuncOp funcOp) {
|
||||
if (!funcOp.isPublic()) {
|
||||
return true;
|
||||
}
|
||||
// TODO : Don't need to add a runtime context for function that doesn't
|
||||
// manipulates lowlfhe types.
|
||||
//
|
||||
// if (!llvm::any_of(funcOp.getType().getInputs(), [](mlir::Type t) {
|
||||
// if (auto tensorTy = t.dyn_cast_or_null<mlir::TensorType>()) {
|
||||
// t = tensorTy.getElementType();
|
||||
// }
|
||||
// return llvm::isa<mlir::zamalang::LowLFHE::LowLFHEDialect>(
|
||||
// t.getDialect());
|
||||
// })) {
|
||||
// return true;
|
||||
// }
|
||||
return funcOp.getType().getNumInputs() >= 1 &&
|
||||
funcOp.getType()
|
||||
.getInputs()
|
||||
.back()
|
||||
.isa<mlir::zamalang::LowLFHE::ContextType>();
|
||||
}
|
||||
};
|
||||
|
||||
namespace {
|
||||
struct LowLFHEToConcreteCAPIPass
|
||||
: public LowLFHEToConcreteCAPIBase<LowLFHEToConcreteCAPIPass> {
|
||||
@@ -711,27 +798,49 @@ struct LowLFHEToConcreteCAPIPass
|
||||
} // namespace
|
||||
|
||||
void LowLFHEToConcreteCAPIPass::runOnOperation() {
|
||||
// Setup the conversion target.
|
||||
mlir::ConversionTarget target(getContext());
|
||||
target.addIllegalDialect<mlir::zamalang::LowLFHE::LowLFHEDialect>();
|
||||
target.addLegalDialect<mlir::BuiltinDialect, mlir::StandardOpsDialect,
|
||||
mlir::memref::MemRefDialect,
|
||||
mlir::arith::ArithmeticDialect>();
|
||||
|
||||
// Setup rewrite patterns
|
||||
mlir::RewritePatternSet patterns(&getContext());
|
||||
populateLowLFHEToConcreteCAPICall(patterns);
|
||||
|
||||
// Insert forward declarations
|
||||
mlir::ModuleOp op = getOperation();
|
||||
|
||||
// First of all add the LowLFHE.context to the block arguments of function
|
||||
// that manipulates ciphertexts.
|
||||
{
|
||||
mlir::ConversionTarget target(getContext());
|
||||
mlir::RewritePatternSet patterns(&getContext());
|
||||
|
||||
target.addDynamicallyLegalOp<mlir::FuncOp>([&](mlir::FuncOp funcOp) {
|
||||
return AddRuntimeContextToFuncOpPattern::isLegal(funcOp);
|
||||
});
|
||||
|
||||
patterns.add<AddRuntimeContextToFuncOpPattern>(patterns.getContext());
|
||||
|
||||
// Apply the conversion
|
||||
if (mlir::applyPartialConversion(op, target, std::move(patterns))
|
||||
.failed()) {
|
||||
this->signalPassFailure();
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// Insert forward declaration
|
||||
mlir::IRRewriter rewriter(&getContext());
|
||||
if (insertForwardDeclarations(op, rewriter).failed()) {
|
||||
this->signalPassFailure();
|
||||
}
|
||||
// Rewrite LowLFHE ops to CallOp to the Concrete C API
|
||||
{
|
||||
mlir::ConversionTarget target(getContext());
|
||||
mlir::RewritePatternSet patterns(&getContext());
|
||||
|
||||
// Apply the conversion
|
||||
if (mlir::applyPartialConversion(op, target, std::move(patterns)).failed()) {
|
||||
this->signalPassFailure();
|
||||
target.addIllegalDialect<mlir::zamalang::LowLFHE::LowLFHEDialect>();
|
||||
target.addLegalDialect<mlir::BuiltinDialect, mlir::StandardOpsDialect,
|
||||
mlir::memref::MemRefDialect,
|
||||
mlir::arith::ArithmeticDialect>();
|
||||
|
||||
populateLowLFHEToConcreteCAPICall(patterns);
|
||||
|
||||
if (mlir::applyPartialConversion(op, target, std::move(patterns))
|
||||
.failed()) {
|
||||
this->signalPassFailure();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -70,6 +70,7 @@ MLIRLowerableDialectsToLLVMPass::convertTypes(mlir::Type type) {
|
||||
type.isa<mlir::zamalang::LowLFHE::GlweCiphertextType>() ||
|
||||
type.isa<mlir::zamalang::LowLFHE::LweKeySwitchKeyType>() ||
|
||||
type.isa<mlir::zamalang::LowLFHE::LweBootstrapKeyType>() ||
|
||||
type.isa<mlir::zamalang::LowLFHE::ContextType>() ||
|
||||
type.isa<mlir::zamalang::LowLFHE::ForeignPlaintextListType>() ||
|
||||
type.isa<mlir::zamalang::LowLFHE::PlaintextListType>()) {
|
||||
return mlir::LLVM::LLVMPointerType::get(
|
||||
|
||||
@@ -25,24 +25,9 @@ void LowLFHEDialect::initialize() {
|
||||
mlir::Type type;
|
||||
|
||||
std::string types_str[] = {
|
||||
"enc_rand_gen",
|
||||
"secret_rand_gen",
|
||||
"plaintext",
|
||||
"plaintext_list",
|
||||
"foreign_plaintext_list",
|
||||
"lwe_ciphertext",
|
||||
"lwe_key_switch_key",
|
||||
"lwe_bootstrap_key",
|
||||
"lwe_secret_key",
|
||||
"lwe_size",
|
||||
"glwe_ciphertext",
|
||||
"glwe_secret_key",
|
||||
"glwe_size",
|
||||
"polynomial_size",
|
||||
"decomp_level_count",
|
||||
"decomp_base_log",
|
||||
"variance",
|
||||
"cleartext",
|
||||
"plaintext", "plaintext_list", "foreign_plaintext_list",
|
||||
"lwe_ciphertext", "lwe_key_switch_key", "lwe_bootstrap_key",
|
||||
"glwe_ciphertext", "cleartext", "context",
|
||||
};
|
||||
|
||||
for (const std::string &type_str : types_str) {
|
||||
@@ -53,8 +38,7 @@ void LowLFHEDialect::initialize() {
|
||||
}
|
||||
|
||||
parser.emitError(parser.getCurrentLocation(), "Unknown LowLFHE type");
|
||||
// call default parser
|
||||
parser.parseType(type);
|
||||
|
||||
return type;
|
||||
}
|
||||
|
||||
|
||||
@@ -1,51 +1,10 @@
|
||||
#include "zamalang/Runtime/context.h"
|
||||
#include <stdio.h>
|
||||
|
||||
RuntimeContext *globalRuntimeContext;
|
||||
|
||||
RuntimeContext *createRuntimeContext(LweKeyswitchKey_u64 *ksk,
|
||||
LweBootstrapKey_u64 *bsk) {
|
||||
RuntimeContext *context = (RuntimeContext *)malloc(sizeof(RuntimeContext));
|
||||
context->ksk = ksk;
|
||||
context->bsk = bsk;
|
||||
return context;
|
||||
}
|
||||
|
||||
void setGlobalRuntimeContext(RuntimeContext *context) {
|
||||
globalRuntimeContext = context;
|
||||
}
|
||||
|
||||
RuntimeContext *getGlobalRuntimeContext() { return globalRuntimeContext; }
|
||||
|
||||
LweKeyswitchKey_u64 *getGlobalKeyswitchKey() {
|
||||
return globalRuntimeContext->ksk;
|
||||
}
|
||||
|
||||
LweBootstrapKey_u64 *getGlobalBootstrapKey() {
|
||||
return globalRuntimeContext->bsk;
|
||||
}
|
||||
|
||||
LweKeyswitchKey_u64 *getKeyswitckKeyFromContext(RuntimeContext *context) {
|
||||
LweKeyswitchKey_u64 *get_keyswitch_key(RuntimeContext *context) {
|
||||
return context->ksk;
|
||||
}
|
||||
|
||||
LweBootstrapKey_u64 *getBootstrapKeyFromContext(RuntimeContext *context) {
|
||||
LweBootstrapKey_u64 *get_bootstrap_key(RuntimeContext *context) {
|
||||
return context->bsk;
|
||||
}
|
||||
|
||||
bool checkError(int *err) {
|
||||
switch (*err) {
|
||||
case ERR_INDEX_OUT_OF_BOUND:
|
||||
fprintf(stderr, "Runtime: index out of bound");
|
||||
break;
|
||||
case ERR_NULL_POINTER:
|
||||
fprintf(stderr, "Runtime: null pointer");
|
||||
break;
|
||||
case ERR_SIZE_MISMATCH:
|
||||
fprintf(stderr, "Runtime: size mismatch");
|
||||
break;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
@@ -125,8 +125,11 @@ createClientParametersForV0(V0FHEContext fheContext, llvm::StringRef name,
|
||||
|
||||
// Create input and output circuit gate parameters
|
||||
auto funcType = (*funcOp).getType();
|
||||
for (auto inType : funcType.getInputs()) {
|
||||
auto gate = gateFromMLIRType("big", precision, encryptionVariance, inType);
|
||||
bool hasContext =
|
||||
funcType.getInputs().back().isa<mlir::zamalang::LowLFHE::ContextType>();
|
||||
for (auto inType = funcType.getInputs().begin();
|
||||
inType < funcType.getInputs().end() - hasContext; inType++) {
|
||||
auto gate = gateFromMLIRType("big", precision, encryptionVariance, *inType);
|
||||
if (auto err = gate.takeError()) {
|
||||
return std::move(err);
|
||||
}
|
||||
|
||||
@@ -78,8 +78,8 @@ llvm::Error JITLambda::invoke(Argument &args) {
|
||||
|
||||
JITLambda::Argument::Argument(KeySet &keySet) : keySet(keySet) {
|
||||
// Setting the inputs
|
||||
auto numInputs = 0;
|
||||
{
|
||||
auto numInputs = 0;
|
||||
for (size_t i = 0; i < keySet.numInputs(); i++) {
|
||||
auto offset = numInputs;
|
||||
auto gate = keySet.inputGate(i);
|
||||
@@ -95,6 +95,8 @@ JITLambda::Argument::Argument(KeySet &keySet) : keySet(keySet) {
|
||||
// dimension of the tensor.
|
||||
numInputs = numInputs + 2 * keySet.inputGate(i).shape.dimensions.size();
|
||||
}
|
||||
// Reserve for the context argument
|
||||
numInputs = numInputs + 1;
|
||||
inputs = std::vector<const void *>(numInputs);
|
||||
}
|
||||
|
||||
@@ -128,8 +130,10 @@ JITLambda::Argument::Argument(KeySet &keySet) : keySet(keySet) {
|
||||
rawArg[i] = &outputs[i - inputs.size()];
|
||||
}
|
||||
|
||||
// Setup runtime context with appropriate keys
|
||||
keySet.initGlobalRuntimeContext();
|
||||
// Set the context argument
|
||||
keySet.setRuntimeContext(context);
|
||||
inputs[numInputs - 1] = &context;
|
||||
rawArg[numInputs - 1] = &inputs[numInputs - 1];
|
||||
}
|
||||
|
||||
JITLambda::Argument::~Argument() {
|
||||
|
||||
@@ -140,7 +140,8 @@ lowerLowLFHEToStd(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
std::function<bool(mlir::Pass *)> enablePass) {
|
||||
mlir::PassManager pm(&context);
|
||||
pipelinePrinting("LowLFHEToStd", pm, context);
|
||||
pm.addPass(mlir::zamalang::createConvertLowLFHEToConcreteCAPIPass());
|
||||
addPotentiallyNestedPass(
|
||||
pm, mlir::zamalang::createConvertLowLFHEToConcreteCAPIPass(), enablePass);
|
||||
return pm.run(module.getOperation());
|
||||
}
|
||||
|
||||
|
||||
@@ -1,25 +1,17 @@
|
||||
// RUN: zamacompiler --passes lowlfhe-to-concrete-c-api --action=dump-std %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: module
|
||||
// CHECK-NEXT: func private @add_plaintext_list_glwe_ciphertext_u64(index, !LowLFHE.glwe_ciphertext, !LowLFHE.glwe_ciphertext, !LowLFHE.plaintext_list)
|
||||
// CHECK-NEXT: func private @fill_plaintext_list_with_expansion_u64(index, !LowLFHE.plaintext_list, !LowLFHE.foreign_plaintext_list)
|
||||
// CHECK-NEXT: func private @allocate_plaintext_list_u64(index, i32) -> !LowLFHE.plaintext_list
|
||||
// CHECK-NEXT: func private @allocate_glwe_ciphertext_u64(index, i32, i32) -> !LowLFHE.glwe_ciphertext
|
||||
// CHECK-NEXT: func private @keyswitch_lwe_u64(index, !LowLFHE.lwe_key_switch_key, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>)
|
||||
// CHECK-NEXT: func private @getGlobalKeyswitchKey() -> !LowLFHE.lwe_key_switch_key
|
||||
// CHECK-NEXT: func private @bootstrap_lwe_u64(index, !LowLFHE.lwe_bootstrap_key, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.glwe_ciphertext)
|
||||
// CHECK-NEXT: func private @getGlobalBootstrapKey() -> !LowLFHE.lwe_bootstrap_key
|
||||
// CHECK-NEXT: func private @negate_lwe_ciphertext_u64(index, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>)
|
||||
// CHECK-NEXT: func private @mul_cleartext_lwe_ciphertext_u64(index, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.cleartext<_>)
|
||||
// CHECK-NEXT: func private @add_plaintext_lwe_ciphertext_u64(index, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.plaintext<_>)
|
||||
// CHECK-NEXT: func private @add_lwe_ciphertexts_u64(index, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>)
|
||||
// CHECK-NEXT: func private @allocate_lwe_ciphertext_u64(index, index) -> !LowLFHE.lwe_ciphertext<_,_>
|
||||
// CHECK-LABEL: func @bootstrap_lwe(%arg0: !LowLFHE.lwe_ciphertext<1024,4>, %arg1: !LowLFHE.glwe_ciphertext) -> !LowLFHE.lwe_ciphertext<1024,4>
|
||||
// CHECK: func private @keyswitch_lwe_u64(index, !LowLFHE.lwe_key_switch_key, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>)
|
||||
// CHECK: func private @get_keyswitch_key(!LowLFHE.context) -> !LowLFHE.lwe_key_switch_key
|
||||
// CHECK: func private @bootstrap_lwe_u64(index, !LowLFHE.lwe_bootstrap_key, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.glwe_ciphertext)
|
||||
// CHECK: func private @get_bootstrap_key(!LowLFHE.context) -> !LowLFHE.lwe_bootstrap_key
|
||||
// CHECK: func private @allocate_lwe_ciphertext_u64(index, index) -> !LowLFHE.lwe_ciphertext<_,_>
|
||||
// CHECK-LABEL: func @bootstrap_lwe(%arg0: !LowLFHE.lwe_ciphertext<1024,4>, %arg1: !LowLFHE.glwe_ciphertext, %arg2: !LowLFHE.context) -> !LowLFHE.lwe_ciphertext<1024,4>
|
||||
func @bootstrap_lwe(%arg0: !LowLFHE.lwe_ciphertext<1024,4>, %arg1: !LowLFHE.glwe_ciphertext) -> !LowLFHE.lwe_ciphertext<1024,4> {
|
||||
// CHECK-NEXT: %[[ERR:.*]] = arith.constant 0 : index
|
||||
// CHECK-NEXT: %[[C0:.*]] = arith.constant 1024 : index
|
||||
// CHECK-NEXT: %[[V1:.*]] = call @allocate_lwe_ciphertext_u64(%[[ERR]], %[[C0]]) : (index, index) -> !LowLFHE.lwe_ciphertext<_,_>
|
||||
// CHECK-NEXT: %[[V2:.*]] = call @getGlobalBootstrapKey() : () -> !LowLFHE.lwe_bootstrap_key
|
||||
// CHECK-NEXT: %[[V2:.*]] = call @get_bootstrap_key(%arg2) : (!LowLFHE.context) -> !LowLFHE.lwe_bootstrap_key
|
||||
// CHECK-NEXT: %[[V3:.*]] = builtin.unrealized_conversion_cast %arg0 : !LowLFHE.lwe_ciphertext<1024,4> to !LowLFHE.lwe_ciphertext<_,_>
|
||||
// CHECK-NEXT: %[[V4:.*]] = builtin.unrealized_conversion_cast %arg1 : !LowLFHE.glwe_ciphertext to !LowLFHE.glwe_ciphertext
|
||||
// CHECK-NEXT: call @bootstrap_lwe_u64(%[[ERR]], %[[V2]], %[[V1]], %[[V3]], %[[V4]]) : (index, !LowLFHE.lwe_bootstrap_key, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.glwe_ciphertext) -> ()
|
||||
|
||||
@@ -1,21 +1,12 @@
|
||||
// RUN: zamacompiler --passes lowlfhe-to-concrete-c-api --action=dump-std %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: module
|
||||
// CHECK-NEXT: func private @runtime_foreign_plaintext_list_u64(index, tensor<16xi64>, i64, i32) -> !LowLFHE.foreign_plaintext_list
|
||||
// CHECK-NEXT: func private @add_plaintext_list_glwe_ciphertext_u64(index, !LowLFHE.glwe_ciphertext, !LowLFHE.glwe_ciphertext, !LowLFHE.plaintext_list)
|
||||
// CHECK-NEXT: func private @fill_plaintext_list_with_expansion_u64(index, !LowLFHE.plaintext_list, !LowLFHE.foreign_plaintext_list)
|
||||
// CHECK-NEXT: func private @allocate_plaintext_list_u64(index, i32) -> !LowLFHE.plaintext_list
|
||||
// CHECK-NEXT: func private @allocate_glwe_ciphertext_u64(index, i32, i32) -> !LowLFHE.glwe_ciphertext
|
||||
// CHECK-NEXT: func private @keyswitch_lwe_u64(index, !LowLFHE.lwe_key_switch_key, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>)
|
||||
// CHECK-NEXT: func private @getGlobalKeyswitchKey() -> !LowLFHE.lwe_key_switch_key
|
||||
// CHECK-NEXT: func private @bootstrap_lwe_u64(index, !LowLFHE.lwe_bootstrap_key, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.glwe_ciphertext)
|
||||
// CHECK-NEXT: func private @getGlobalBootstrapKey() -> !LowLFHE.lwe_bootstrap_key
|
||||
// CHECK-NEXT: func private @negate_lwe_ciphertext_u64(index, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>)
|
||||
// CHECK-NEXT: func private @mul_cleartext_lwe_ciphertext_u64(index, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.cleartext<_>)
|
||||
// CHECK-NEXT: func private @add_plaintext_lwe_ciphertext_u64(index, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.plaintext<_>)
|
||||
// CHECK-NEXT: func private @add_lwe_ciphertexts_u64(index, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>)
|
||||
// CHECK-NEXT: func private @allocate_lwe_ciphertext_u64(index, index) -> !LowLFHE.lwe_ciphertext<_,_>
|
||||
// CHECK-LABEL: func @glwe_from_table(%arg0: tensor<16xi64>) -> !LowLFHE.glwe_ciphertext
|
||||
// CHECK: func private @runtime_foreign_plaintext_list_u64(index, tensor<16xi64>, i64, i32) -> !LowLFHE.foreign_plaintext_list
|
||||
// CHECK: func private @add_plaintext_list_glwe_ciphertext_u64(index, !LowLFHE.glwe_ciphertext, !LowLFHE.glwe_ciphertext, !LowLFHE.plaintext_list)
|
||||
// CHECK: func private @fill_plaintext_list_with_expansion_u64(index, !LowLFHE.plaintext_list, !LowLFHE.foreign_plaintext_list)
|
||||
// CHECK: func private @allocate_plaintext_list_u64(index, i32) -> !LowLFHE.plaintext_list
|
||||
// CHECK: func private @allocate_glwe_ciphertext_u64(index, i32, i32) -> !LowLFHE.glwe_ciphertext
|
||||
// CHECK-LABEL: func @glwe_from_table(%arg0: tensor<16xi64>, %arg1: !LowLFHE.context) -> !LowLFHE.glwe_ciphertext
|
||||
func @glwe_from_table(%arg0: tensor<16xi64>) -> !LowLFHE.glwe_ciphertext {
|
||||
// CHECK-NEXT: %[[V0:.*]] = arith.constant 0 : index
|
||||
// CHECK-NEXT: %[[C0:.*]] = arith.constant 1 : i32
|
||||
|
||||
@@ -1,25 +1,17 @@
|
||||
// RUN: zamacompiler --passes lowlfhe-to-concrete-c-api --action=dump-std %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: module
|
||||
// CHECK-NEXT: func private @add_plaintext_list_glwe_ciphertext_u64(index, !LowLFHE.glwe_ciphertext, !LowLFHE.glwe_ciphertext, !LowLFHE.plaintext_list)
|
||||
// CHECK-NEXT: func private @fill_plaintext_list_with_expansion_u64(index, !LowLFHE.plaintext_list, !LowLFHE.foreign_plaintext_list)
|
||||
// CHECK-NEXT: func private @allocate_plaintext_list_u64(index, i32) -> !LowLFHE.plaintext_list
|
||||
// CHECK-NEXT: func private @allocate_glwe_ciphertext_u64(index, i32, i32) -> !LowLFHE.glwe_ciphertext
|
||||
// CHECK-NEXT: func private @keyswitch_lwe_u64(index, !LowLFHE.lwe_key_switch_key, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>)
|
||||
// CHECK-NEXT: func private @getGlobalKeyswitchKey() -> !LowLFHE.lwe_key_switch_key
|
||||
// CHECK-NEXT: func private @bootstrap_lwe_u64(index, !LowLFHE.lwe_bootstrap_key, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.glwe_ciphertext)
|
||||
// CHECK-NEXT: func private @getGlobalBootstrapKey() -> !LowLFHE.lwe_bootstrap_key
|
||||
// CHECK-NEXT: func private @negate_lwe_ciphertext_u64(index, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>)
|
||||
// CHECK-NEXT: func private @mul_cleartext_lwe_ciphertext_u64(index, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.cleartext<_>)
|
||||
// CHECK-NEXT: func private @add_plaintext_lwe_ciphertext_u64(index, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.plaintext<_>)
|
||||
// CHECK-NEXT: func private @add_lwe_ciphertexts_u64(index, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>)
|
||||
// CHECK-NEXT: func private @allocate_lwe_ciphertext_u64(index, index) -> !LowLFHE.lwe_ciphertext<_,_>
|
||||
// CHECK-LABEL: func @keyswitch_lwe(%arg0: !LowLFHE.lwe_ciphertext<1024,4>) -> !LowLFHE.lwe_ciphertext<1024,4>
|
||||
// CHECK: func private @keyswitch_lwe_u64(index, !LowLFHE.lwe_key_switch_key, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>)
|
||||
// CHECK: func private @get_keyswitch_key(!LowLFHE.context) -> !LowLFHE.lwe_key_switch_key
|
||||
// CHECK: func private @bootstrap_lwe_u64(index, !LowLFHE.lwe_bootstrap_key, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.glwe_ciphertext)
|
||||
// CHECK: func private @get_bootstrap_key(!LowLFHE.context) -> !LowLFHE.lwe_bootstrap_key
|
||||
// CHECK: func private @allocate_lwe_ciphertext_u64(index, index) -> !LowLFHE.lwe_ciphertext<_,_>
|
||||
// CHECK-LABEL: func @keyswitch_lwe(%arg0: !LowLFHE.lwe_ciphertext<1024,4>, %arg1: !LowLFHE.context) -> !LowLFHE.lwe_ciphertext<1024,4>
|
||||
func @keyswitch_lwe(%arg0: !LowLFHE.lwe_ciphertext<1024,4>) -> !LowLFHE.lwe_ciphertext<1024,4> {
|
||||
// CHECK-NEXT: %[[ERR:.*]] = arith.constant 0 : index
|
||||
// CHECK-NEXT: %[[C0:.*]] = arith.constant 1 : index
|
||||
// CHECK-NEXT: %[[V1:.*]] = call @allocate_lwe_ciphertext_u64(%[[ERR]], %[[C0]]) : (index, index) -> !LowLFHE.lwe_ciphertext<_,_>
|
||||
// CHECK-NEXT: %[[V2:.*]] = call @getGlobalKeyswitchKey() : () -> !LowLFHE.lwe_key_switch_key
|
||||
// CHECK-NEXT: %[[V2:.*]] = call @get_keyswitch_key(%arg1) : (!LowLFHE.context) -> !LowLFHE.lwe_key_switch_key
|
||||
// CHECK-NEXT: %[[V3:.*]] = builtin.unrealized_conversion_cast %arg0 : !LowLFHE.lwe_ciphertext<1024,4> to !LowLFHE.lwe_ciphertext<_,_>
|
||||
// CHECK-NEXT: call @keyswitch_lwe_u64(%[[ERR]], %[[V2]], %[[V1]], %[[V3]]) : (index, !LowLFHE.lwe_key_switch_key, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>) -> ()
|
||||
// CHECK-NEXT: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[V1]] : !LowLFHE.lwe_ciphertext<_,_> to !LowLFHE.lwe_ciphertext<1024,4>
|
||||
|
||||
Reference in New Issue
Block a user