refactor(compiler): Introduce compilation pipeline with multiple entries / exits

This refactoring commit restructures the compilation pipeline of
`zamacompiler`, such that it is possible to enter and exit the
pipeline at different points, effectively defining the level of
abstraction at the input and the required level of abstraction for the
output.

The entry point is specified using the `--entry-dialect`
argument. Valid choices are:

  `--entry-dialect=hlfhe`:   Source contains HLFHE operations
  `--entry-dialect=midlfhe`: Source contains MidLFHE operations
  `--entry-dialect=lowlfhe`: Source contains LowLFHE operations
  `--entry-dialect=std`:     Source does not contain any FHE Operations
  `--entry-dialect=llvm`:    Source is in LLVM dialect

The exit point is defined by an action, specified using --action.

  `--action=roundtrip`:
     Parse the source file to in-memory representation and immediately
     dump as text without any processing

  `--action=dump-midlfhe`:
     Lower source to MidLFHE and dump result as text

  `--action=dump-lowlfhe`:
     Lower source to LowLFHE and dump result as text

  `--action=dump-std`:
     Lower source to only standard MLIR dialects (i.e., all FHE
     operations have already been lowered)

  `--action=dump-llvm-dialect`:
     Lower source to MLIR's LLVM dialect (i.e., the LLVM dialect, not
     LLVM IR)

  `--action=dump-llvm-ir`:
     Lower source to plain LLVM IR (i.e., not the LLVM dialect, but
     actual LLVM IR)

  `--action=dump-optimized-llvm-ir`:
     Lower source to plain LLVM IR (i.e., not the LLVM dialect, but
     actual LLVM IR), pass the result through the LLVM optimizer and
     print the result.

  `--action=dump-jit-invoke`:
     Execute the full lowering pipeline to optimized LLVM IR, JIT
     compile the result, invoke the function specified in
     `--jit-funcname` with the parameters from `--jit-args` and print
     the functions return value.
This commit is contained in:
Andi Drebes
2021-09-17 10:45:53 +02:00
committed by Quentin Bourgerie
parent ddebedd1d6
commit 30374ebb2c
58 changed files with 1014 additions and 862 deletions

View File

@@ -19,7 +19,7 @@ struct V0Parameter {
size_t ksLevel;
size_t ksLogBase;
V0Parameter() {}
V0Parameter() = delete;
V0Parameter(size_t k, size_t polynomialSize, size_t nSmall, size_t brLevel,
size_t brLogBase, size_t ksLevel, size_t ksLogBase)
@@ -31,11 +31,14 @@ struct V0Parameter {
};
struct V0FHEContext {
V0FHEContext() = delete;
V0FHEContext(const V0FHEConstraint &constraint, const V0Parameter &parameter)
: constraint(constraint), parameter(parameter) {}
V0FHEConstraint constraint;
V0Parameter parameter;
};
} // namespace zamalang
} // namespace mlir
#endif
#endif

View File

@@ -1,17 +1,7 @@
#ifndef ZAMALANG_SUPPORT_COMPILER_ENGINE_H
#define ZAMALANG_SUPPORT_COMPILER_ENGINE_H
#include "zamalang/Dialect/HLFHE/IR/HLFHEDialect.h"
#include "zamalang/Dialect/HLFHE/IR/HLFHETypes.h"
#include "zamalang/Dialect/LowLFHE/IR/LowLFHEDialect.h"
#include "zamalang/Dialect/LowLFHE/IR/LowLFHETypes.h"
#include "zamalang/Dialect/MidLFHE/IR/MidLFHEDialect.h"
#include "zamalang/Dialect/MidLFHE/IR/MidLFHETypes.h"
#include "zamalang/Support/CompilerTools.h"
#include <mlir/Dialect/Linalg/IR/LinalgOps.h>
#include <mlir/Dialect/MemRef/IR/MemRef.h>
#include <mlir/Dialect/StandardOps/IR/Ops.h>
#include <string>
#include "Jit.h"
namespace mlir {
namespace zamalang {
@@ -55,4 +45,4 @@ private:
} // namespace zamalang
} // namespace mlir
#endif
#endif

View File

@@ -1,138 +0,0 @@
#ifndef ZAMALANG_SUPPORT_COMPILERTOOLS_H_
#define ZAMALANG_SUPPORT_COMPILERTOOLS_H_
#include <mlir/Dialect/LLVMIR/LLVMTypes.h>
#include <mlir/ExecutionEngine/ExecutionEngine.h>
#include <mlir/Pass/PassManager.h>
#include "zamalang/Support/ClientParameters.h"
#include "zamalang/Support/KeySet.h"
#include "zamalang/Support/V0Parameters.h"
namespace mlir {
namespace zamalang {
class CompilerTools {
public:
struct LowerOptions {
llvm::function_ref<bool(std::string)> enablePass;
bool verbose;
LowerOptions()
: verbose(false), enablePass([](std::string pass) { return true; }){};
};
/// lowerHLFHEToMlirLLVMDialect run all passes to lower FHE dialects to mlir
/// lowerable to llvm dialect.
/// The given module MLIR operation would be modified and the constraints set.
static mlir::LogicalResult
lowerHLFHEToMlirStdsDialect(mlir::MLIRContext &context,
mlir::Operation *module, V0FHEContext &fheContext,
LowerOptions options = LowerOptions());
/// lowerMlirStdsDialectToMlirLLVMDialect run all passes to lower MLIR
/// dialects to MLIR LLVM dialect. The given module MLIR operation would be
/// modified.
static mlir::LogicalResult
lowerMlirStdsDialectToMlirLLVMDialect(mlir::MLIRContext &context,
mlir::Operation *module,
LowerOptions options = LowerOptions());
static llvm::Expected<std::unique_ptr<llvm::Module>>
toLLVMModule(llvm::LLVMContext &llvmContext, mlir::ModuleOp &module,
llvm::function_ref<llvm::Error(llvm::Module *)> optPipeline);
};
/// JITLambda is a tool to JIT compile an mlir module and to invoke a function
/// of the module.
class JITLambda {
public:
class Argument {
public:
Argument(KeySet &keySet);
~Argument();
// Create lambda Argument that use the given KeySet to perform encryption
// and decryption operations.
static llvm::Expected<std::unique_ptr<Argument>> create(KeySet &keySet);
// Set a scalar argument at the given pos as a uint64_t.
llvm::Error setArg(size_t pos, uint64_t arg);
// Set a argument at the given pos as a tensor of int64.
llvm::Error setArg(size_t pos, uint64_t *data, size_t size) {
return setArg(pos, 64, (void *)data, size);
}
// Set a argument at the given pos as a tensor of int32.
llvm::Error setArg(size_t pos, uint32_t *data, size_t size) {
return setArg(pos, 32, (void *)data, size);
}
// Set a argument at the given pos as a tensor of int32.
llvm::Error setArg(size_t pos, uint16_t *data, size_t size) {
return setArg(pos, 16, (void *)data, size);
}
// Set a tensor argument at the given pos as a uint64_t.
llvm::Error setArg(size_t pos, uint8_t *data, size_t size) {
return setArg(pos, 8, (void *)data, size);
}
// Get the result at the given pos as an uint64_t.
llvm::Error getResult(size_t pos, uint64_t &res);
// Fill the result.
llvm::Error getResult(size_t pos, uint64_t *res, size_t size);
private:
llvm::Error setArg(size_t pos, size_t width, void *data, size_t size);
friend JITLambda;
// Store the pointer on inputs values and outputs values
std::vector<void *> rawArg;
// Store the values of inputs
std::vector<void *> inputs;
// Store the values of outputs
std::vector<void *> outputs;
// Store the input gates description and the offset of the argument.
std::vector<std::tuple<CircuitGate, size_t /*offet*/>> inputGates;
// Store the outputs gates description and the offset of the argument.
std::vector<std::tuple<CircuitGate, size_t /*offet*/>> outputGates;
// Store allocated lwe ciphertexts (for free)
std::vector<LweCiphertext_u64 *> allocatedCiphertexts;
// Store buffers of ciphertexts
std::vector<LweCiphertext_u64 **> ciphertextBuffers;
KeySet &keySet;
};
JITLambda(mlir::LLVM::LLVMFunctionType type, llvm::StringRef name)
: type(type), name(name){};
/// create a JITLambda that point to the function name of the given module.
static llvm::Expected<std::unique_ptr<JITLambda>>
create(llvm::StringRef name, mlir::ModuleOp &module,
llvm::function_ref<llvm::Error(llvm::Module *)> optPipeline);
/// invokeRaw execute the jit lambda with a list of Argument, the last one is
/// used to store the result of the computation.
/// Example:
/// uin64_t arg0 = 1;
/// uin64_t res;
/// llvm::SmallVector<void *> args{&arg1, &res};
/// lambda.invokeRaw(args);
llvm::Error invokeRaw(llvm::MutableArrayRef<void *> args);
/// invoke the jit lambda with the Argument.
llvm::Error invoke(Argument &args);
private:
mlir::LLVM::LLVMFunctionType type;
llvm::StringRef name;
std::unique_ptr<mlir::ExecutionEngine> engine;
};
} // namespace zamalang
} // namespace mlir
#endif

View File

@@ -1,11 +1,12 @@
#ifndef COMPILER_JIT_H
#define COMPILER_JIT_H
#include "zamalang/Support/CompilerTools.h"
#include <mlir/ExecutionEngine/ExecutionEngine.h>
#include <mlir/IR/BuiltinOps.h>
#include <mlir/Support/LogicalResult.h>
#include <zamalang/Support/KeySet.h>
namespace mlir {
namespace zamalang {
mlir::LogicalResult
@@ -13,6 +14,96 @@ runJit(mlir::ModuleOp module, llvm::StringRef func,
llvm::ArrayRef<uint64_t> funcArgs, mlir::zamalang::KeySet &keySet,
std::function<llvm::Error(llvm::Module *)> optPipeline,
llvm::raw_ostream &os);
/// JITLambda is a tool to JIT compile an mlir module and to invoke a function
/// of the module.
class JITLambda {
public:
class Argument {
public:
Argument(KeySet &keySet);
~Argument();
// Create lambda Argument that use the given KeySet to perform encryption
// and decryption operations.
static llvm::Expected<std::unique_ptr<Argument>> create(KeySet &keySet);
// Set a scalar argument at the given pos as a uint64_t.
llvm::Error setArg(size_t pos, uint64_t arg);
// Set a argument at the given pos as a tensor of int64.
llvm::Error setArg(size_t pos, uint64_t *data, size_t size) {
return setArg(pos, 64, (void *)data, size);
}
// Set a argument at the given pos as a tensor of int32.
llvm::Error setArg(size_t pos, uint32_t *data, size_t size) {
return setArg(pos, 32, (void *)data, size);
}
// Set a argument at the given pos as a tensor of int32.
llvm::Error setArg(size_t pos, uint16_t *data, size_t size) {
return setArg(pos, 16, (void *)data, size);
}
// Set a tensor argument at the given pos as a uint64_t.
llvm::Error setArg(size_t pos, uint8_t *data, size_t size) {
return setArg(pos, 8, (void *)data, size);
}
// Get the result at the given pos as an uint64_t.
llvm::Error getResult(size_t pos, uint64_t &res);
// Fill the result.
llvm::Error getResult(size_t pos, uint64_t *res, size_t size);
private:
llvm::Error setArg(size_t pos, size_t width, void *data, size_t size);
friend JITLambda;
// Store the pointer on inputs values and outputs values
std::vector<void *> rawArg;
// Store the values of inputs
std::vector<void *> inputs;
// Store the values of outputs
std::vector<void *> outputs;
// Store the input gates description and the offset of the argument.
std::vector<std::tuple<CircuitGate, size_t /*offet*/>> inputGates;
// Store the outputs gates description and the offset of the argument.
std::vector<std::tuple<CircuitGate, size_t /*offet*/>> outputGates;
// Store allocated lwe ciphertexts (for free)
std::vector<LweCiphertext_u64 *> allocatedCiphertexts;
// Store buffers of ciphertexts
std::vector<LweCiphertext_u64 **> ciphertextBuffers;
KeySet &keySet;
};
JITLambda(mlir::LLVM::LLVMFunctionType type, llvm::StringRef name)
: type(type), name(name){};
/// create a JITLambda that point to the function name of the given module.
static llvm::Expected<std::unique_ptr<JITLambda>>
create(llvm::StringRef name, mlir::ModuleOp &module,
llvm::function_ref<llvm::Error(llvm::Module *)> optPipeline);
/// invokeRaw execute the jit lambda with a list of Argument, the last one is
/// used to store the result of the computation.
/// Example:
/// uin64_t arg0 = 1;
/// uin64_t res;
/// llvm::SmallVector<void *> args{&arg1, &res};
/// lambda.invokeRaw(args);
llvm::Error invokeRaw(llvm::MutableArrayRef<void *> args);
/// invoke the jit lambda with the Argument.
llvm::Error invoke(Argument &args);
private:
mlir::LLVM::LLVMFunctionType type;
llvm::StringRef name;
std::unique_ptr<mlir::ExecutionEngine> engine;
};
} // namespace zamalang
} // namespace mlir

View File

@@ -0,0 +1,42 @@
#ifndef ZAMALANG_SUPPORT_PIPELINE_H_
#define ZAMALANG_SUPPORT_PIPELINE_H_
#include <llvm/IR/Module.h>
#include <mlir/Dialect/LLVMIR/LLVMTypes.h>
#include <mlir/Support/LogicalResult.h>
#include <zamalang/Support/V0Parameters.h>
namespace mlir {
namespace zamalang {
namespace pipeline {
mlir::LogicalResult lowerHLFHEToMidLFHE(mlir::MLIRContext &context,
mlir::ModuleOp &module, bool verbose);
mlir::LogicalResult lowerMidLFHEToLowLFHE(mlir::MLIRContext &context,
mlir::ModuleOp &module,
V0FHEContext &fheContext,
bool parametrize);
mlir::LogicalResult lowerLowLFHEToStd(mlir::MLIRContext &context,
mlir::ModuleOp &module);
mlir::LogicalResult lowerStdToLLVMDialect(mlir::MLIRContext &context,
mlir::ModuleOp &module, bool verbose);
mlir::LogicalResult optimizeLLVMModule(llvm::LLVMContext &llvmContext,
llvm::Module &module);
mlir::LogicalResult lowerHLFHEToStd(mlir::MLIRContext &context,
mlir::ModuleOp &module,
V0FHEContext &fheContext, bool verbose);
std::unique_ptr<llvm::Module>
lowerLLVMDialectToLLVMIR(mlir::MLIRContext &context,
llvm::LLVMContext &llvmContext,
mlir::ModuleOp &module);
} // namespace pipeline
} // namespace zamalang
} // namespace mlir
#endif

View File

@@ -1,5 +1,6 @@
add_mlir_library(ZamalangSupport
CompilerTools.cpp
Pipeline.cpp
Jit.cpp
CompilerEngine.cpp
V0Parameters.cpp
V0Curves.cpp

View File

@@ -1,8 +1,16 @@
#include "zamalang/Support/CompilerEngine.h"
#include "zamalang/Conversion/Passes.h"
#include <mlir/Dialect/LLVMIR/LLVMDialect.h>
#include <mlir/Dialect/Linalg/IR/LinalgOps.h>
#include <mlir/Dialect/MemRef/IR/MemRef.h>
#include <mlir/Dialect/StandardOps/IR/Ops.h>
#include <mlir/ExecutionEngine/OptUtils.h>
#include <mlir/Parser.h>
#include <zamalang/Dialect/HLFHE/IR/HLFHEDialect.h>
#include <zamalang/Dialect/LowLFHE/IR/LowLFHEDialect.h>
#include <zamalang/Dialect/MidLFHE/IR/MidLFHEDialect.h>
#include <zamalang/Support/CompilerEngine.h>
#include <zamalang/Support/Pipeline.h>
namespace mlir {
namespace zamalang {
@@ -29,10 +37,20 @@ llvm::Error CompilerEngine::compile(std::string mlirStr) {
return llvm::make_error<llvm::StringError>("mlir parsing failed",
llvm::inconvertibleErrorCode());
}
mlir::zamalang::V0FHEContext fheContext;
mlir::zamalang::V0FHEConstraint defaultGlobalFHECircuitConstraint{.norm2 = 10,
.p = 7};
const mlir::zamalang::V0Parameter *parameter =
getV0Parameter(defaultGlobalFHECircuitConstraint);
mlir::zamalang::V0FHEContext fheContext{defaultGlobalFHECircuitConstraint,
*parameter};
mlir::ModuleOp module = module_ref.get();
// Lower to MLIR Std
if (mlir::zamalang::CompilerTools::lowerHLFHEToMlirStdsDialect(
*context, module_ref.get(), fheContext)
if (mlir::zamalang::pipeline::lowerHLFHEToStd(*context, module, fheContext,
false)
.failed()) {
return llvm::make_error<llvm::StringError>("failed to lower to MLIR Std",
llvm::inconvertibleErrorCode());
@@ -53,8 +71,7 @@ llvm::Error CompilerEngine::compile(std::string mlirStr) {
keySet = std::move(maybeKeySet.get());
// Lower to MLIR LLVM Dialect
if (mlir::zamalang::CompilerTools::lowerMlirStdsDialectToMlirLLVMDialect(
*context, module_ref.get())
if (mlir::zamalang::pipeline::lowerStdToLLVMDialect(*context, module, false)
.failed()) {
return llvm::make_error<llvm::StringError>(
"failed to lower to LLVM dialect", llvm::inconvertibleErrorCode());
@@ -114,4 +131,4 @@ llvm::Expected<uint64_t> CompilerEngine::run(std::vector<uint64_t> args) {
return res;
}
} // namespace zamalang
} // namespace mlir
} // namespace mlir

View File

@@ -1,467 +0,0 @@
#include "mlir/Dialect/Tensor/Transforms/Passes.h"
#include <llvm/Support/TargetSelect.h>
#include <mlir/Dialect/Linalg/Passes.h>
#include <mlir/Dialect/StandardOps/Transforms/Passes.h>
#include <mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h>
#include <mlir/Target/LLVMIR/Export.h>
#include <mlir/Transforms/Passes.h>
#include "zamalang/Conversion/Passes.h"
#include "zamalang/Support/CompilerTools.h"
namespace mlir {
namespace zamalang {
// This is temporary while we doesn't yet have the high-level verification pass
V0FHEConstraint defaultGlobalFHECircuitConstraint{.norm2 = 10, .p = 7};
void initLLVMNativeTarget() {
// Initialize LLVM targets.
llvm::InitializeNativeTarget();
llvm::InitializeNativeTargetAsmPrinter();
}
void addFilteredPassToPassManager(
mlir::PassManager &pm, std::unique_ptr<mlir::Pass> pass,
llvm::function_ref<bool(std::string)> enablePass) {
if (!enablePass(pass->getArgument().str())) {
return;
}
if (*pass->getOpName() == "module") {
pm.addPass(std::move(pass));
} else {
pm.nest(*pass->getOpName()).addPass(std::move(pass));
}
};
mlir::LogicalResult CompilerTools::lowerHLFHEToMlirStdsDialect(
mlir::MLIRContext &context, mlir::Operation *module,
V0FHEContext &fheContext, LowerOptions options) {
mlir::PassManager pm(&context);
if (options.verbose) {
llvm::errs() << "##################################################\n";
llvm::errs() << "### HLFHEToMlirStdsDialect pipeline\n";
context.disableMultithreading();
pm.enableIRPrinting();
pm.enableStatistics();
pm.enableTiming();
pm.enableVerifier();
}
fheContext.constraint = defaultGlobalFHECircuitConstraint;
fheContext.parameter = *getV0Parameter(fheContext.constraint);
// Add all passes to lower from HLFHE to LLVM Dialect
addFilteredPassToPassManager(
pm, mlir::zamalang::createConvertHLFHETensorOpsToLinalg(),
options.enablePass);
addFilteredPassToPassManager(
pm, mlir::zamalang::createConvertHLFHEToMidLFHEPass(),
options.enablePass);
addFilteredPassToPassManager(
pm,
mlir::zamalang::createConvertMidLFHEGlobalParametrizationPass(fheContext),
options.enablePass);
addFilteredPassToPassManager(
pm, mlir::zamalang::createConvertMidLFHEToLowLFHEPass(),
options.enablePass);
addFilteredPassToPassManager(
pm, mlir::zamalang::createConvertLowLFHEToConcreteCAPIPass(),
options.enablePass);
// Run the passes
if (pm.run(module).failed()) {
return mlir::failure();
}
return mlir::success();
}
mlir::LogicalResult CompilerTools::lowerMlirStdsDialectToMlirLLVMDialect(
mlir::MLIRContext &context, mlir::Operation *module, LowerOptions options) {
mlir::PassManager pm(&context);
if (options.verbose) {
llvm::errs() << "##################################################\n";
llvm::errs() << "### MlirStdsDialectToMlirLLVMDialect pipeline\n";
context.disableMultithreading();
pm.enableIRPrinting();
pm.enableStatistics();
pm.enableTiming();
pm.enableVerifier();
}
// Unparametrize LowLFHE
addFilteredPassToPassManager(
pm, mlir::zamalang::createConvertLowLFHEUnparametrizePass(),
options.enablePass);
// Bufferize
addFilteredPassToPassManager(pm, mlir::createTensorConstantBufferizePass(),
options.enablePass);
addFilteredPassToPassManager(pm, mlir::createStdBufferizePass(),
options.enablePass);
addFilteredPassToPassManager(pm, mlir::createTensorBufferizePass(),
options.enablePass);
addFilteredPassToPassManager(pm, mlir::createLinalgBufferizePass(),
options.enablePass);
addFilteredPassToPassManager(pm, mlir::createConvertLinalgToLoopsPass(),
options.enablePass);
addFilteredPassToPassManager(pm, mlir::createFuncBufferizePass(),
options.enablePass);
addFilteredPassToPassManager(pm, mlir::createFinalizingBufferizePass(),
options.enablePass);
// Convert to MLIR LLVM Dialect
addFilteredPassToPassManager(
pm, mlir::zamalang::createConvertMLIRLowerableDialectsToLLVMPass(),
options.enablePass);
if (pm.run(module).failed()) {
return mlir::failure();
}
return mlir::success();
}
llvm::Expected<std::unique_ptr<llvm::Module>> CompilerTools::toLLVMModule(
llvm::LLVMContext &llvmContext, mlir::ModuleOp &module,
llvm::function_ref<llvm::Error(llvm::Module *)> optPipeline) {
initLLVMNativeTarget();
mlir::registerLLVMDialectTranslation(*module->getContext());
auto llvmModule = mlir::translateModuleToLLVMIR(module, llvmContext);
if (!llvmModule) {
return llvm::make_error<llvm::StringError>(
"failed to translate MLIR to LLVM IR", llvm::inconvertibleErrorCode());
}
if (auto err = optPipeline(llvmModule.get())) {
return llvm::make_error<llvm::StringError>("failed to optimize LLVM IR",
llvm::inconvertibleErrorCode());
}
return std::move(llvmModule);
}
llvm::Expected<std::unique_ptr<JITLambda>>
JITLambda::create(llvm::StringRef name, mlir::ModuleOp &module,
llvm::function_ref<llvm::Error(llvm::Module *)> optPipeline) {
// Looking for the function
auto rangeOps = module.getOps<mlir::LLVM::LLVMFuncOp>();
auto funcOp = llvm::find_if(rangeOps, [&](mlir::LLVM::LLVMFuncOp op) {
return op.getName() == name;
});
if (funcOp == rangeOps.end()) {
return llvm::make_error<llvm::StringError>(
"cannot find the function to JIT", llvm::inconvertibleErrorCode());
}
initLLVMNativeTarget();
mlir::registerLLVMDialectTranslation(*module->getContext());
// Create an MLIR execution engine. The execution engine eagerly
// JIT-compiles the module.
auto maybeEngine = mlir::ExecutionEngine::create(
module, /*llvmModuleBuilder=*/nullptr, optPipeline);
if (!maybeEngine) {
return llvm::make_error<llvm::StringError>(
"failed to construct the MLIR ExecutionEngine",
llvm::inconvertibleErrorCode());
}
auto &engine = maybeEngine.get();
auto lambda = std::make_unique<JITLambda>((*funcOp).getType(), name);
lambda->engine = std::move(engine);
return std::move(lambda);
}
llvm::Error JITLambda::invokeRaw(llvm::MutableArrayRef<void *> args) {
size_t nbReturn = 0;
// TODO - This check break with memref as we have 5 returns args.
// if (!this->type.getReturnType().isa<mlir::LLVM::LLVMVoidType>()) {
// nbReturn = 1;
// }
// if (this->type.getNumParams() != args.size() - nbReturn) {
// return llvm::make_error<llvm::StringError>(
// "invokeRaw: wrong number of argument",
// llvm::inconvertibleErrorCode());
// }
if (llvm::find(args, nullptr) != args.end()) {
return llvm::make_error<llvm::StringError>(
"invoke: some arguments are null", llvm::inconvertibleErrorCode());
}
return this->engine->invokePacked(this->name, args);
}
llvm::Error JITLambda::invoke(Argument &args) {
return std::move(invokeRaw(args.rawArg));
}
JITLambda::Argument::Argument(KeySet &keySet) : keySet(keySet) {
// Setting the inputs
{
auto numInputs = 0;
for (size_t i = 0; i < keySet.numInputs(); i++) {
auto offset = numInputs;
auto gate = keySet.inputGate(i);
inputGates.push_back({gate, offset});
if (keySet.inputGate(i).shape.size == 0) {
// scalar gate
numInputs = numInputs + 1;
continue;
}
// memref gate, as we follow the standard calling convention
numInputs = numInputs + 5;
}
inputs = std::vector<void *>(numInputs);
}
// Setting the outputs
{
auto numOutputs = 0;
for (size_t i = 0; i < keySet.numOutputs(); i++) {
auto offset = numOutputs;
auto gate = keySet.outputGate(i);
outputGates.push_back({gate, offset});
if (gate.shape.size == 0) {
// scalar gate
numOutputs = numOutputs + 1;
continue;
}
// memref gate, as we follow the standard calling convention
numOutputs = numOutputs + 5;
}
outputs = std::vector<void *>(numOutputs);
}
// The raw argument contains pointers to inputs and pointers to store the
// results
rawArg = std::vector<void *>(inputs.size() + outputs.size(), nullptr);
// Set the pointer on outputs on rawArg
for (auto i = inputs.size(); i < rawArg.size(); i++) {
rawArg[i] = &outputs[i - inputs.size()];
}
// Setup runtime context with appropriate keys
keySet.initGlobalRuntimeContext();
}
JITLambda::Argument::~Argument() {
int err;
for (auto ct : allocatedCiphertexts) {
free_lwe_ciphertext_u64(&err, ct);
}
for (auto buffer : ciphertextBuffers) {
free(buffer);
}
}
llvm::Expected<std::unique_ptr<JITLambda::Argument>>
JITLambda::Argument::create(KeySet &keySet) {
auto args = std::make_unique<JITLambda::Argument>(keySet);
return std::move(args);
}
llvm::Error JITLambda::Argument::setArg(size_t pos, uint64_t arg) {
if (pos >= inputGates.size()) {
return llvm::make_error<llvm::StringError>(
llvm::Twine("argument index out of bound: pos=")
.concat(llvm::Twine(pos)),
llvm::inconvertibleErrorCode());
}
auto gate = inputGates[pos];
auto info = std::get<0>(gate);
auto offset = std::get<1>(gate);
// Check is the argument is a scalar
if (info.shape.size != 0) {
return llvm::make_error<llvm::StringError>(
llvm::Twine("argument is not a scalar: pos=").concat(llvm::Twine(pos)),
llvm::inconvertibleErrorCode());
}
// If argument is not encrypted, just save.
if (!info.encryption.hasValue()) {
inputs[offset] = (void *)arg;
rawArg[offset] = &inputs[offset];
return llvm::Error::success();
}
// Else if is encryted, allocate ciphertext and encrypt.
LweCiphertext_u64 *ctArg;
if (auto err = this->keySet.allocate_lwe(pos, &ctArg)) {
return std::move(err);
}
allocatedCiphertexts.push_back(ctArg);
if (auto err = this->keySet.encrypt_lwe(pos, ctArg, arg)) {
return std::move(err);
}
inputs[offset] = ctArg;
rawArg[offset] = &inputs[offset];
return llvm::Error::success();
}
llvm::Error JITLambda::Argument::setArg(size_t pos, size_t width, void *data,
size_t size) {
auto gate = inputGates[pos];
auto info = std::get<0>(gate);
auto offset = std::get<1>(gate);
// Check if the width is compatible
// TODO - I found this rules empirically, they are a spec somewhere?
if (info.shape.width <= 8 && width != 8) {
return llvm::make_error<llvm::StringError>(
llvm::Twine("argument width should be 8: pos=")
.concat(llvm::Twine(pos)),
llvm::inconvertibleErrorCode());
}
if (info.shape.width > 8 && info.shape.width <= 16 && width != 16) {
return llvm::make_error<llvm::StringError>(
llvm::Twine("argument width should be 16: pos=")
.concat(llvm::Twine(pos)),
llvm::inconvertibleErrorCode());
}
if (info.shape.width > 16 && info.shape.width <= 32 && width != 32) {
return llvm::make_error<llvm::StringError>(
llvm::Twine("argument width should be 32: pos=")
.concat(llvm::Twine(pos)),
llvm::inconvertibleErrorCode());
}
if (info.shape.width > 32 && info.shape.width <= 64 && width != 64) {
return llvm::make_error<llvm::StringError>(
llvm::Twine("argument width should be 64: pos=")
.concat(llvm::Twine(pos)),
llvm::inconvertibleErrorCode());
}
if (info.shape.width > 64) {
return llvm::make_error<llvm::StringError>(
llvm::Twine("argument width not supported: pos=")
.concat(llvm::Twine(pos)),
llvm::inconvertibleErrorCode());
}
// Check the size
if (info.shape.size == 0) {
return llvm::make_error<llvm::StringError>(
llvm::Twine("argument is not a vector: pos=").concat(llvm::Twine(pos)),
llvm::inconvertibleErrorCode());
}
if (info.shape.size != size) {
return llvm::make_error<llvm::StringError>(
llvm::Twine("vector argument has not the expected size")
.concat(llvm::Twine(pos)),
llvm::inconvertibleErrorCode());
}
// If argument is not encrypted, just save with the right calling convention.
if (info.encryption.hasValue()) {
// Else if is encrypted
// For moment we support only 8 bits inputs
uint8_t *data8 = (uint8_t *)data;
if (width != 8) {
return llvm::make_error<llvm::StringError>(
llvm::Twine(
"argument width > 8 for encrypted gates are not supported: pos=")
.concat(llvm::Twine(pos)),
llvm::inconvertibleErrorCode());
}
// Allocate a buffer for ciphertexts.
auto ctBuffer =
(LweCiphertext_u64 **)malloc(size * sizeof(LweCiphertext_u64 *));
ciphertextBuffers.push_back(ctBuffer);
// Allocate ciphertexts and encrypt
for (auto i = 0; i < size; i++) {
if (auto err = this->keySet.allocate_lwe(pos, &ctBuffer[i])) {
return std::move(err);
}
allocatedCiphertexts.push_back(ctBuffer[i]);
if (auto err = this->keySet.encrypt_lwe(pos, ctBuffer[i], data8[i])) {
return std::move(err);
}
}
// Replace the data by the buffer to ciphertext
data = (void *)ctBuffer;
}
// Set the buffer as the memref calling convention expect.
// allocated
inputs[offset] = (void *)0; // TODO - Better understand how it is used.
rawArg[offset] = &inputs[offset];
// aligned
inputs[offset + 1] = data;
rawArg[offset + 1] = &inputs[offset + 1];
// offset
inputs[offset + 2] = (void *)0;
rawArg[offset + 2] = &inputs[offset + 2];
// size
inputs[offset + 3] = (void *)size;
rawArg[offset + 3] = &inputs[offset + 3];
// stride
inputs[offset + 4] = (void *)0;
rawArg[offset + 4] = &inputs[offset + 4];
return llvm::Error::success();
}
llvm::Error JITLambda::Argument::getResult(size_t pos, uint64_t &res) {
auto gate = outputGates[pos];
auto info = std::get<0>(gate);
auto offset = std::get<1>(gate);
// Check is the argument is a scalar
if (info.shape.size != 0) {
return llvm::make_error<llvm::StringError>(
llvm::Twine("output is not a scalar, pos=").concat(llvm::Twine(pos)),
llvm::inconvertibleErrorCode());
}
// If result is not encrypted, just set the result
if (!info.encryption.hasValue()) {
res = (uint64_t)(outputs[offset]);
return llvm::Error::success();
}
// Else if is encryted, decrypt
LweCiphertext_u64 *ct = (LweCiphertext_u64 *)(outputs[offset]);
if (auto err = this->keySet.decrypt_lwe(pos, ct, res)) {
return std::move(err);
}
return llvm::Error::success();
}
llvm::Error JITLambda::Argument::getResult(size_t pos, uint64_t *res,
size_t size) {
auto gate = outputGates[pos];
auto info = std::get<0>(gate);
auto offset = std::get<1>(gate);
// Check is the argument is a scalar
if (info.shape.size == 0) {
return llvm::make_error<llvm::StringError>(
llvm::Twine("output is not a tensor, pos=").concat(llvm::Twine(pos)),
llvm::inconvertibleErrorCode());
}
if (!info.encryption.hasValue()) {
return llvm::make_error<llvm::StringError>(
"unencrypted result as tensor output NYI",
llvm::inconvertibleErrorCode());
}
// Get the values as the memref calling convention expect.
void *allocated = outputs[offset]; // TODO - Better understand how it is used.
// aligned
void *aligned = outputs[offset + 1];
// offset
size_t offset_r = (size_t)outputs[offset + 2];
// size
size_t size_r = (size_t)outputs[offset + 3];
// stride
size_t stride = (size_t)outputs[offset + 4];
// Check the sizes
if (info.shape.size != size || size_r != size) {
return llvm::make_error<llvm::StringError>("output bad result buffer size",
llvm::inconvertibleErrorCode());
}
// decrypt and fill the result buffer
for (auto i = 0; i < size_r; i++) {
LweCiphertext_u64 *ct = ((LweCiphertext_u64 **)(aligned))[i];
if (auto err = this->keySet.decrypt_lwe(pos, ct, res[i])) {
return std::move(err);
}
}
return llvm::Error::success();
}
} // namespace zamalang
} // namespace mlir

View File

@@ -1,6 +1,10 @@
#include <llvm/ADT/ArrayRef.h>
#include <llvm/ADT/SmallVector.h>
#include <llvm/ADT/StringRef.h>
#include <llvm/Support/TargetSelect.h>
#include <mlir/Dialect/LLVMIR/LLVMDialect.h>
#include <mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h>
#include <zamalang/Support/Jit.h>
#include <zamalang/Support/logging.h>
@@ -54,5 +58,329 @@ runJit(mlir::ModuleOp module, llvm::StringRef func,
return mlir::success();
}
llvm::Expected<std::unique_ptr<JITLambda>>
JITLambda::create(llvm::StringRef name, mlir::ModuleOp &module,
llvm::function_ref<llvm::Error(llvm::Module *)> optPipeline) {
// Looking for the function
auto rangeOps = module.getOps<mlir::LLVM::LLVMFuncOp>();
auto funcOp = llvm::find_if(rangeOps, [&](mlir::LLVM::LLVMFuncOp op) {
return op.getName() == name;
});
if (funcOp == rangeOps.end()) {
return llvm::make_error<llvm::StringError>(
"cannot find the function to JIT", llvm::inconvertibleErrorCode());
}
llvm::InitializeNativeTarget();
llvm::InitializeNativeTargetAsmPrinter();
mlir::registerLLVMDialectTranslation(*module->getContext());
// Create an MLIR execution engine. The execution engine eagerly
// JIT-compiles the module.
auto maybeEngine = mlir::ExecutionEngine::create(
module, /*llvmModuleBuilder=*/nullptr, optPipeline);
if (!maybeEngine) {
return llvm::make_error<llvm::StringError>(
"failed to construct the MLIR ExecutionEngine",
llvm::inconvertibleErrorCode());
}
auto &engine = maybeEngine.get();
auto lambda = std::make_unique<JITLambda>((*funcOp).getType(), name);
lambda->engine = std::move(engine);
return std::move(lambda);
}
llvm::Error JITLambda::invokeRaw(llvm::MutableArrayRef<void *> args) {
size_t nbReturn = 0;
// TODO - This check break with memref as we have 5 returns args.
// if (!this->type.getReturnType().isa<mlir::LLVM::LLVMVoidType>()) {
// nbReturn = 1;
// }
// if (this->type.getNumParams() != args.size() - nbReturn) {
// return llvm::make_error<llvm::StringError>(
// "invokeRaw: wrong number of argument",
// llvm::inconvertibleErrorCode());
// }
if (llvm::find(args, nullptr) != args.end()) {
return llvm::make_error<llvm::StringError>(
"invoke: some arguments are null", llvm::inconvertibleErrorCode());
}
return this->engine->invokePacked(this->name, args);
}
llvm::Error JITLambda::invoke(Argument &args) {
return std::move(invokeRaw(args.rawArg));
}
JITLambda::Argument::Argument(KeySet &keySet) : keySet(keySet) {
// Setting the inputs
{
auto numInputs = 0;
for (size_t i = 0; i < keySet.numInputs(); i++) {
auto offset = numInputs;
auto gate = keySet.inputGate(i);
inputGates.push_back({gate, offset});
if (keySet.inputGate(i).shape.size == 0) {
// scalar gate
numInputs = numInputs + 1;
continue;
}
// memref gate, as we follow the standard calling convention
numInputs = numInputs + 5;
}
inputs = std::vector<void *>(numInputs);
}
// Setting the outputs
{
auto numOutputs = 0;
for (size_t i = 0; i < keySet.numOutputs(); i++) {
auto offset = numOutputs;
auto gate = keySet.outputGate(i);
outputGates.push_back({gate, offset});
if (gate.shape.size == 0) {
// scalar gate
numOutputs = numOutputs + 1;
continue;
}
// memref gate, as we follow the standard calling convention
numOutputs = numOutputs + 5;
}
outputs = std::vector<void *>(numOutputs);
}
// The raw argument contains pointers to inputs and pointers to store the
// results
rawArg = std::vector<void *>(inputs.size() + outputs.size(), nullptr);
// Set the pointer on outputs on rawArg
for (auto i = inputs.size(); i < rawArg.size(); i++) {
rawArg[i] = &outputs[i - inputs.size()];
}
// Setup runtime context with appropriate keys
keySet.initGlobalRuntimeContext();
}
JITLambda::Argument::~Argument() {
int err;
for (auto ct : allocatedCiphertexts) {
free_lwe_ciphertext_u64(&err, ct);
}
for (auto buffer : ciphertextBuffers) {
free(buffer);
}
}
llvm::Expected<std::unique_ptr<JITLambda::Argument>>
JITLambda::Argument::create(KeySet &keySet) {
auto args = std::make_unique<JITLambda::Argument>(keySet);
return std::move(args);
}
llvm::Error JITLambda::Argument::setArg(size_t pos, uint64_t arg) {
if (pos >= inputGates.size()) {
return llvm::make_error<llvm::StringError>(
llvm::Twine("argument index out of bound: pos=")
.concat(llvm::Twine(pos)),
llvm::inconvertibleErrorCode());
}
auto gate = inputGates[pos];
auto info = std::get<0>(gate);
auto offset = std::get<1>(gate);
// Check is the argument is a scalar
if (info.shape.size != 0) {
return llvm::make_error<llvm::StringError>(
llvm::Twine("argument is not a scalar: pos=").concat(llvm::Twine(pos)),
llvm::inconvertibleErrorCode());
}
// If argument is not encrypted, just save.
if (!info.encryption.hasValue()) {
inputs[offset] = (void *)arg;
rawArg[offset] = &inputs[offset];
return llvm::Error::success();
}
// Else if is encryted, allocate ciphertext and encrypt.
LweCiphertext_u64 *ctArg;
if (auto err = this->keySet.allocate_lwe(pos, &ctArg)) {
return std::move(err);
}
allocatedCiphertexts.push_back(ctArg);
if (auto err = this->keySet.encrypt_lwe(pos, ctArg, arg)) {
return std::move(err);
}
inputs[offset] = ctArg;
rawArg[offset] = &inputs[offset];
return llvm::Error::success();
}
llvm::Error JITLambda::Argument::setArg(size_t pos, size_t width, void *data,
size_t size) {
auto gate = inputGates[pos];
auto info = std::get<0>(gate);
auto offset = std::get<1>(gate);
// Check if the width is compatible
// TODO - I found this rules empirically, they are a spec somewhere?
if (info.shape.width <= 8 && width != 8) {
return llvm::make_error<llvm::StringError>(
llvm::Twine("argument width should be 8: pos=")
.concat(llvm::Twine(pos)),
llvm::inconvertibleErrorCode());
}
if (info.shape.width > 8 && info.shape.width <= 16 && width != 16) {
return llvm::make_error<llvm::StringError>(
llvm::Twine("argument width should be 16: pos=")
.concat(llvm::Twine(pos)),
llvm::inconvertibleErrorCode());
}
if (info.shape.width > 16 && info.shape.width <= 32 && width != 32) {
return llvm::make_error<llvm::StringError>(
llvm::Twine("argument width should be 32: pos=")
.concat(llvm::Twine(pos)),
llvm::inconvertibleErrorCode());
}
if (info.shape.width > 32 && info.shape.width <= 64 && width != 64) {
return llvm::make_error<llvm::StringError>(
llvm::Twine("argument width should be 64: pos=")
.concat(llvm::Twine(pos)),
llvm::inconvertibleErrorCode());
}
if (info.shape.width > 64) {
return llvm::make_error<llvm::StringError>(
llvm::Twine("argument width not supported: pos=")
.concat(llvm::Twine(pos)),
llvm::inconvertibleErrorCode());
}
// Check the size
if (info.shape.size == 0) {
return llvm::make_error<llvm::StringError>(
llvm::Twine("argument is not a vector: pos=").concat(llvm::Twine(pos)),
llvm::inconvertibleErrorCode());
}
if (info.shape.size != size) {
return llvm::make_error<llvm::StringError>(
llvm::Twine("vector argument has not the expected size")
.concat(llvm::Twine(pos)),
llvm::inconvertibleErrorCode());
}
// If argument is not encrypted, just save with the right calling convention.
if (info.encryption.hasValue()) {
// Else if is encrypted
// For moment we support only 8 bits inputs
uint8_t *data8 = (uint8_t *)data;
if (width != 8) {
return llvm::make_error<llvm::StringError>(
llvm::Twine(
"argument width > 8 for encrypted gates are not supported: pos=")
.concat(llvm::Twine(pos)),
llvm::inconvertibleErrorCode());
}
// Allocate a buffer for ciphertexts.
auto ctBuffer =
(LweCiphertext_u64 **)malloc(size * sizeof(LweCiphertext_u64 *));
ciphertextBuffers.push_back(ctBuffer);
// Allocate ciphertexts and encrypt
for (auto i = 0; i < size; i++) {
if (auto err = this->keySet.allocate_lwe(pos, &ctBuffer[i])) {
return std::move(err);
}
allocatedCiphertexts.push_back(ctBuffer[i]);
if (auto err = this->keySet.encrypt_lwe(pos, ctBuffer[i], data8[i])) {
return std::move(err);
}
}
// Replace the data by the buffer to ciphertext
data = (void *)ctBuffer;
}
// Set the buffer as the memref calling convention expect.
// allocated
inputs[offset] = (void *)0; // TODO - Better understand how it is used.
rawArg[offset] = &inputs[offset];
// aligned
inputs[offset + 1] = data;
rawArg[offset + 1] = &inputs[offset + 1];
// offset
inputs[offset + 2] = (void *)0;
rawArg[offset + 2] = &inputs[offset + 2];
// size
inputs[offset + 3] = (void *)size;
rawArg[offset + 3] = &inputs[offset + 3];
// stride
inputs[offset + 4] = (void *)0;
rawArg[offset + 4] = &inputs[offset + 4];
return llvm::Error::success();
}
llvm::Error JITLambda::Argument::getResult(size_t pos, uint64_t &res) {
auto gate = outputGates[pos];
auto info = std::get<0>(gate);
auto offset = std::get<1>(gate);
// Check is the argument is a scalar
if (info.shape.size != 0) {
return llvm::make_error<llvm::StringError>(
llvm::Twine("output is not a scalar, pos=").concat(llvm::Twine(pos)),
llvm::inconvertibleErrorCode());
}
// If result is not encrypted, just set the result
if (!info.encryption.hasValue()) {
res = (uint64_t)(outputs[offset]);
return llvm::Error::success();
}
// Else if is encryted, decrypt
LweCiphertext_u64 *ct = (LweCiphertext_u64 *)(outputs[offset]);
if (auto err = this->keySet.decrypt_lwe(pos, ct, res)) {
return std::move(err);
}
return llvm::Error::success();
}
llvm::Error JITLambda::Argument::getResult(size_t pos, uint64_t *res,
size_t size) {
auto gate = outputGates[pos];
auto info = std::get<0>(gate);
auto offset = std::get<1>(gate);
// Check is the argument is a scalar
if (info.shape.size == 0) {
return llvm::make_error<llvm::StringError>(
llvm::Twine("output is not a tensor, pos=").concat(llvm::Twine(pos)),
llvm::inconvertibleErrorCode());
}
if (!info.encryption.hasValue()) {
return llvm::make_error<llvm::StringError>(
"unencrypted result as tensor output NYI",
llvm::inconvertibleErrorCode());
}
// Get the values as the memref calling convention expect.
void *allocated = outputs[offset]; // TODO - Better understand how it is used.
// aligned
void *aligned = outputs[offset + 1];
// offset
size_t offset_r = (size_t)outputs[offset + 2];
// size
size_t size_r = (size_t)outputs[offset + 3];
// stride
size_t stride = (size_t)outputs[offset + 4];
// Check the sizes
if (info.shape.size != size || size_r != size) {
return llvm::make_error<llvm::StringError>("output bad result buffer size",
llvm::inconvertibleErrorCode());
}
// decrypt and fill the result buffer
for (auto i = 0; i < size_r; i++) {
LweCiphertext_u64 *ct = ((LweCiphertext_u64 **)(aligned))[i];
if (auto err = this->keySet.decrypt_lwe(pos, ct, res[i])) {
return std::move(err);
}
}
return llvm::Error::success();
}
} // namespace zamalang
} // namespace mlir

View File

@@ -0,0 +1,148 @@
#include <llvm/Support/TargetSelect.h>
#include <mlir/Dialect/Linalg/Passes.h>
#include <mlir/Dialect/StandardOps/Transforms/Passes.h>
#include <mlir/Dialect/Tensor/Transforms/Passes.h>
#include <mlir/ExecutionEngine/OptUtils.h>
#include <mlir/Pass/PassManager.h>
#include <mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h>
#include <mlir/Target/LLVMIR/Export.h>
#include <mlir/Transforms/Passes.h>
#include <zamalang/Conversion/Passes.h>
#include <zamalang/Support/Pipeline.h>
#include <zamalang/Support/logging.h>
namespace mlir {
namespace zamalang {
namespace pipeline {
static void addPotentiallyNestedPass(mlir::PassManager &pm,
std::unique_ptr<Pass> pass) {
if (!pass->getOpName() || *pass->getOpName() == "module") {
pm.addPass(std::move(pass));
} else {
pm.nest(*pass->getOpName()).addPass(std::move(pass));
}
}
mlir::LogicalResult lowerHLFHEToMidLFHE(mlir::MLIRContext &context,
mlir::ModuleOp &module, bool verbose) {
mlir::PassManager pm(&context);
if (verbose) {
mlir::zamalang::log_verbose()
<< "##################################################\n"
<< "### HLFHE to MidLFHE pipeline\n";
pm.enableIRPrinting();
pm.enableStatistics();
pm.enableTiming();
pm.enableVerifier();
}
addPotentiallyNestedPass(
pm, mlir::zamalang::createConvertHLFHETensorOpsToLinalg());
addPotentiallyNestedPass(pm,
mlir::zamalang::createConvertHLFHEToMidLFHEPass());
return pm.run(module.getOperation());
}
mlir::LogicalResult lowerMidLFHEToLowLFHE(mlir::MLIRContext &context,
mlir::ModuleOp &module,
V0FHEContext &fheContext,
bool parametrize) {
mlir::PassManager pm(&context);
if (parametrize) {
addPotentiallyNestedPass(
pm, mlir::zamalang::createConvertMidLFHEGlobalParametrizationPass(
fheContext));
}
addPotentiallyNestedPass(pm,
mlir::zamalang::createConvertMidLFHEToLowLFHEPass());
return pm.run(module.getOperation());
}
mlir::LogicalResult lowerLowLFHEToStd(mlir::MLIRContext &context,
mlir::ModuleOp &module) {
mlir::PassManager pm(&context);
pm.addPass(mlir::zamalang::createConvertLowLFHEToConcreteCAPIPass());
return pm.run(module.getOperation());
}
mlir::LogicalResult lowerStdToLLVMDialect(mlir::MLIRContext &context,
mlir::ModuleOp &module,
bool verbose) {
mlir::PassManager pm(&context);
if (verbose) {
mlir::zamalang::log_verbose()
<< "##################################################\n"
<< "### MlirStdsDialectToMlirLLVMDialect pipeline\n";
context.disableMultithreading();
pm.enableIRPrinting();
pm.enableStatistics();
pm.enableTiming();
pm.enableVerifier();
}
// Unparametrize LowLFHE
addPotentiallyNestedPass(
pm, mlir::zamalang::createConvertLowLFHEUnparametrizePass());
// Bufferize
addPotentiallyNestedPass(pm, mlir::createTensorConstantBufferizePass());
addPotentiallyNestedPass(pm, mlir::createStdBufferizePass());
addPotentiallyNestedPass(pm, mlir::createTensorBufferizePass());
addPotentiallyNestedPass(pm, mlir::createLinalgBufferizePass());
addPotentiallyNestedPass(pm, mlir::createConvertLinalgToLoopsPass());
addPotentiallyNestedPass(pm, mlir::createFuncBufferizePass());
addPotentiallyNestedPass(pm, mlir::createFinalizingBufferizePass());
// Convert to MLIR LLVM Dialect
addPotentiallyNestedPass(
pm, mlir::zamalang::createConvertMLIRLowerableDialectsToLLVMPass());
return pm.run(module);
}
std::unique_ptr<llvm::Module>
lowerLLVMDialectToLLVMIR(mlir::MLIRContext &context,
llvm::LLVMContext &llvmContext,
mlir::ModuleOp &module) {
llvm::InitializeNativeTarget();
llvm::InitializeNativeTargetAsmPrinter();
mlir::registerLLVMDialectTranslation(*module->getContext());
return mlir::translateModuleToLLVMIR(module, llvmContext);
}
mlir::LogicalResult optimizeLLVMModule(llvm::LLVMContext &llvmContext,
llvm::Module &module) {
std::function<llvm::Error(llvm::Module *)> optPipeline =
mlir::makeOptimizingTransformer(3, 0, nullptr);
if (optPipeline(&module))
return mlir::failure();
else
return mlir::success();
}
mlir::LogicalResult lowerHLFHEToStd(mlir::MLIRContext &context,
mlir::ModuleOp &module,
V0FHEContext &fheContext, bool verbose) {
if (lowerHLFHEToMidLFHE(context, module, verbose).failed() ||
lowerMidLFHEToLowLFHE(context, module, fheContext, true).failed() ||
lowerLowLFHEToStd(context, module).failed()) {
return mlir::failure();
} else {
return mlir::success();
}
}
} // namespace pipeline
} // namespace zamalang
} // namespace mlir

View File

@@ -7,7 +7,6 @@
#include "zamalang/Dialect/MidLFHE/IR/MidLFHEDialect.h"
#include "zamalang/Dialect/MidLFHE/IR/MidLFHETypes.h"
#include "zamalang/Support/CompilerEngine.h"
#include "zamalang/Support/CompilerTools.h"
#include <mlir/Dialect/MemRef/IR/MemRef.h>
#include <mlir/Dialect/StandardOps/IR/Ops.h>
#include <mlir/ExecutionEngine/OptUtils.h>

View File

@@ -12,16 +12,32 @@
#include <mlir/Support/LogicalResult.h>
#include <mlir/Support/ToolUtilities.h>
#include "mlir/IR/BuiltinOps.h"
#include "zamalang/Conversion/Passes.h"
#include "zamalang/Conversion/Utils/GlobalFHEContext.h"
#include "zamalang/Dialect/HLFHE/IR/HLFHEDialect.h"
#include "zamalang/Dialect/HLFHE/IR/HLFHETypes.h"
#include "zamalang/Dialect/LowLFHE/IR/LowLFHEDialect.h"
#include "zamalang/Dialect/LowLFHE/IR/LowLFHETypes.h"
#include "zamalang/Dialect/MidLFHE/IR/MidLFHEDialect.h"
#include "zamalang/Dialect/MidLFHE/IR/MidLFHETypes.h"
#include "zamalang/Support/CompilerTools.h"
#include "zamalang/Support/logging.h"
#include "zamalang/Support/Jit.h"
#include "zamalang/Support/KeySet.h"
#include "zamalang/Support/Pipeline.h"
#include "zamalang/Support/logging.h"
enum EntryDialect { HLFHE, MIDLFHE, LOWLFHE, STD, LLVM };
enum Action {
ROUND_TRIP,
DUMP_MIDLFHE,
DUMP_LOWLFHE,
DUMP_STD,
DUMP_LLVM_DIALECT,
DUMP_LLVM_IR,
DUMP_OPTIMIZED_LLVM_IR,
JIT_INVOKE
};
namespace cmdline {
@@ -37,14 +53,53 @@ llvm::cl::opt<std::string> output("o",
llvm::cl::opt<bool> verbose("verbose", llvm::cl::desc("verbose logs"),
llvm::cl::init<bool>(false));
llvm::cl::list<std::string> passes(
"passes",
llvm::cl::desc("Specify the passes to run (use only for compiler tests)"),
llvm::cl::value_desc("passname"), llvm::cl::ZeroOrMore);
llvm::cl::opt<bool> parametrizeMidLFHE(
"parametrize-midlfhe",
llvm::cl::desc("Perform MidLFHE global parametrization pass"),
llvm::cl::init<bool>(true));
llvm::cl::opt<bool> roundTrip("round-trip",
llvm::cl::desc("Just parse and dump"),
llvm::cl::init(false));
static llvm::cl::opt<enum EntryDialect> entryDialect(
"e", "entry-dialect", llvm::cl::desc("Entry dialect"),
llvm::cl::init<enum EntryDialect>(EntryDialect::HLFHE),
llvm::cl::ValueRequired, llvm::cl::NumOccurrencesFlag::Required,
llvm::cl::values(
clEnumValN(EntryDialect::HLFHE, "hlfhe",
"Input module is composed of HLFHE operations")),
llvm::cl::values(
clEnumValN(EntryDialect::MIDLFHE, "midlfhe",
"Input module is composed of MidLFHE operations")),
llvm::cl::values(
clEnumValN(EntryDialect::LOWLFHE, "lowlfhe",
"Input module is composed of LowLFHE operations")),
llvm::cl::values(
clEnumValN(EntryDialect::STD, "std",
"Input module is composed of operations from std")),
llvm::cl::values(
clEnumValN(EntryDialect::LLVM, "llvm",
"Input module is composed of operations from llvm")));
static llvm::cl::opt<enum Action> action(
"a", "action", llvm::cl::desc("output mode"), llvm::cl::ValueRequired,
llvm::cl::NumOccurrencesFlag::Required,
llvm::cl::values(
clEnumValN(Action::ROUND_TRIP, "roundtrip",
"Parse input module and regenerate textual representation")),
llvm::cl::values(clEnumValN(Action::DUMP_MIDLFHE, "dump-midlfhe",
"Lower to MidLFHE and dump result")),
llvm::cl::values(clEnumValN(Action::DUMP_LOWLFHE, "dump-lowlfhe",
"Lower to LowLFHE and dump result")),
llvm::cl::values(clEnumValN(Action::DUMP_STD, "dump-std",
"Lower to std and dump result")),
llvm::cl::values(clEnumValN(Action::DUMP_LLVM_DIALECT, "dump-llvm-dialect",
"Lower to LLVM dialect and dump result")),
llvm::cl::values(clEnumValN(Action::DUMP_LLVM_IR, "dump-llvm-ir",
"Lower to LLVM-IR and dump result")),
llvm::cl::values(clEnumValN(Action::DUMP_OPTIMIZED_LLVM_IR,
"dump-optimized-llvm-ir",
"Lower to LLVM-IR, optimize and dump result")),
llvm::cl::values(clEnumValN(Action::JIT_INVOKE, "jit-invoke",
"Lower and JIT-compile input module and invoke "
"function specified with --jit-funcname")));
llvm::cl::opt<bool> verifyDiagnostics(
"verify-diagnostics",
@@ -58,15 +113,7 @@ llvm::cl::opt<bool> splitInputFile(
"chunk independently"),
llvm::cl::init(false));
llvm::cl::opt<bool> generateKeySet(
"generate-keyset",
llvm::cl::desc("[tmp] Generate a key set for the compiled fhe circuit"),
llvm::cl::init<bool>(false));
llvm::cl::opt<bool> runJit("run-jit", llvm::cl::desc("JIT the code and run it"),
llvm::cl::init<bool>(false));
llvm::cl::opt<std::string> jitFuncname(
llvm::cl::opt<std::string> jitFuncName(
"jit-funcname",
llvm::cl::desc("Name of the function to execute, default 'main'"),
llvm::cl::init<std::string>("main"));
@@ -75,73 +122,16 @@ llvm::cl::list<uint64_t>
jitArgs("jit-args",
llvm::cl::desc("Value of arguments to pass to the main func"),
llvm::cl::value_desc("argument(uint64)"), llvm::cl::ZeroOrMore);
llvm::cl::opt<bool> toLLVM("to-llvm", llvm::cl::desc("Compile to llvm and "),
llvm::cl::init<bool>(false));
}; // namespace cmdline
std::function<llvm::Error(llvm::Module *)> defaultOptPipeline =
mlir::makeOptimizingTransformer(3, 0, nullptr);
mlir::LogicalResult dumpLLVMIR(mlir::ModuleOp module, llvm::raw_ostream &os) {
llvm::LLVMContext context;
auto llvmModule = mlir::zamalang::CompilerTools::toLLVMModule(
context, module, defaultOptPipeline);
if (!llvmModule) {
return mlir::failure();
}
os << **llvmModule;
return mlir::success();
}
std::unique_ptr<mlir::zamalang::KeySet>
generateKeySet(mlir::ModuleOp &module, mlir::zamalang::V0FHEContext &fheContext,
const std::string &jitFuncName) {
std::unique_ptr<mlir::zamalang::KeySet> keySet;
// Process a single source buffer
//
// If `verifyDiagnostics` is `true`, the procedure only checks if the
// diagnostic messages provided in the source buffer using
// `expected-error` are produced.
//
// If `verifyDiagnostics` is `false`, the procedure checks if the
// parsed module is valid and if all requested transformations
// succeeded.
mlir::LogicalResult
processInputBuffer(mlir::MLIRContext &context,
std::unique_ptr<llvm::MemoryBuffer> buffer,
llvm::raw_ostream &os, bool verifyDiagnostics) {
llvm::SourceMgr sourceMgr;
sourceMgr.AddNewSourceBuffer(std::move(buffer), llvm::SMLoc());
mlir::SourceMgrDiagnosticVerifierHandler sourceMgrHandler(sourceMgr,
&context);
auto module = mlir::parseSourceFile(sourceMgr, &context);
if (verifyDiagnostics)
return sourceMgrHandler.verify();
if (!module)
return mlir::failure();
if (cmdline::roundTrip) {
module->print(os);
return mlir::success();
}
auto enablePass = [](std::string passName) {
return cmdline::passes.size() == 0 ||
std::any_of(cmdline::passes.begin(), cmdline::passes.end(),
[&](const std::string &p) { return passName == p; });
};
// Lower to MLIR Stds Dialects and compute the constraint on the FHE Circuit.
mlir::zamalang::CompilerTools::LowerOptions lowerOptions;
lowerOptions.enablePass = enablePass;
lowerOptions.verbose = cmdline::verbose;
mlir::zamalang::V0FHEContext fheContext;
if (mlir::zamalang::CompilerTools::lowerHLFHEToMlirStdsDialect(
context, *module, fheContext, lowerOptions)
.failed()) {
return mlir::failure();
}
mlir::zamalang::log_verbose()
<< "### Global FHE constraint: {norm2:" << fheContext.constraint.norm2
<< ", p:" << fheContext.constraint.p << "}\n";
@@ -155,45 +145,196 @@ processInputBuffer(mlir::MLIRContext &context,
<< ", ksLevel: " << fheContext.parameter.ksLevel
<< ", ksLogBase: " << fheContext.parameter.ksLogBase << "}\n";
// Generate the keySet
std::unique_ptr<mlir::zamalang::KeySet> keySet;
if (cmdline::generateKeySet || cmdline::runJit) {
// Create the client parameters
auto clientParameter = mlir::zamalang::createClientParametersForV0(
fheContext, cmdline::jitFuncname, *module);
if (auto err = clientParameter.takeError()) {
mlir::zamalang::log_error()
<< "cannot generate client parameters: " << err << "\n";
return mlir::failure();
}
mlir::zamalang::log_verbose() << "### Generate the key set\n";
auto maybeKeySet =
mlir::zamalang::KeySet::generate(clientParameter.get(), 0,
0); // TODO: seed
if (auto err = maybeKeySet.takeError()) {
llvm::errs() << err;
return mlir::failure();
}
keySet = std::move(maybeKeySet.get());
// Create the client parameters
auto clientParameter = mlir::zamalang::createClientParametersForV0(
fheContext, jitFuncName, module);
if (auto err = clientParameter.takeError()) {
mlir::zamalang::log_error()
<< "cannot generate client parameters: " << err << "\n";
return nullptr;
}
// Lower to MLIR LLVM Dialect
if (mlir::zamalang::CompilerTools::lowerMlirStdsDialectToMlirLLVMDialect(
context, *module, lowerOptions)
.failed()) {
mlir::zamalang::log_verbose() << "### Generate the key set\n";
auto maybeKeySet = mlir::zamalang::KeySet::generate(clientParameter.get(), 0,
0); // TODO: seed
if (auto err = maybeKeySet.takeError()) {
llvm::errs() << err;
return nullptr;
}
return std::move(maybeKeySet.get());
}
// Process a single source buffer
//
// The parameter `entryDialect` must specify the FHE dialect to which
// belong all FHE operations used in the source buffer. The input
// program must only contain FHE operations from that single FHE
// dialect, otherwise processing might fail.
//
// The parameter `action` specifies how the buffer should be processed
// and thus defines the output.
//
// If the specified action involves JIT compilation, `jitFuncName`
// designates the function to JIT compile. This function is invoked
// using the parameters given in `jitArgs`.
//
// The parameter `parametrizeMidLFHE` defines, whether the
// parametrization pass for MidLFHE is executed. If the pair of
// `entryDialect` and `action` does not involve any MidlFHE
// manipulation, this parameter does not have any effect.
//
// If `verifyDiagnostics` is `true`, the procedure only checks if the
// diagnostic messages provided in the source buffer using
// `expected-error` are produced. If `verifyDiagnostics` is `false`,
// the procedure checks if the parsed module is valid and if all
// requested transformations succeeded.
//
// If `verbose` is true, debug messages are displayed throughout the
// compilation process.
//
// Compilation output is written to the stream specified by `os`.
mlir::LogicalResult processInputBuffer(
mlir::MLIRContext &context, std::unique_ptr<llvm::MemoryBuffer> buffer,
enum EntryDialect entryDialect, enum Action action,
const std::string &jitFuncName, llvm::ArrayRef<uint64_t> jitArgs,
bool parametrizeMidlHFE, bool verifyDiagnostics, bool verbose,
llvm::raw_ostream &os) {
llvm::SourceMgr sourceMgr;
sourceMgr.AddNewSourceBuffer(std::move(buffer), llvm::SMLoc());
mlir::SourceMgrDiagnosticVerifierHandler sourceMgrHandler(sourceMgr,
&context);
mlir::OwningModuleRef moduleRef = mlir::parseSourceFile(sourceMgr, &context);
// This is temporary until we have the high-level verification pass
// determining these parameters automatically
mlir::zamalang::V0FHEConstraint defaultGlobalFHECircuitConstraint{.norm2 = 10,
.p = 7};
std::unique_ptr<mlir::zamalang::KeySet> keySet = nullptr;
const mlir::zamalang::V0Parameter *parameter =
getV0Parameter(defaultGlobalFHECircuitConstraint);
if (!parameter) {
mlir::zamalang::log_error()
<< "Could not determine V0 parameters for 2-norm of "
<< defaultGlobalFHECircuitConstraint.norm2 << " and p of "
<< defaultGlobalFHECircuitConstraint.p << "\n";
return mlir::failure();
}
if (cmdline::runJit) {
mlir::zamalang::log_verbose() << "### JIT compile & running\n";
return mlir::zamalang::runJit(module.get(), cmdline::jitFuncname,
cmdline::jitArgs, *keySet,
defaultOptPipeline, os);
mlir::zamalang::V0FHEContext fheContext{defaultGlobalFHECircuitConstraint,
*parameter};
if (verbose)
context.disableMultithreading();
if (verifyDiagnostics)
return sourceMgrHandler.verify();
if (!moduleRef)
return mlir::failure();
mlir::ModuleOp module = moduleRef.get();
if (action == Action::ROUND_TRIP) {
module->print(os);
return mlir::success();
}
if (cmdline::toLLVM) {
return dumpLLVMIR(module.get(), os);
// Lowering pipeline. Each stage is represented as a label in the
// switch statement, from the most abstract dialect to the lowest
// level. Every labels acts as an entry point into the pipeline with
// a fallthrough mechanism to the next stage. Actions act as exit
// points from the pipeline.
switch (entryDialect) {
case EntryDialect::HLFHE:
if (mlir::zamalang::pipeline::lowerHLFHEToMidLFHE(context, module, verbose)
.failed())
return mlir::failure();
// fallthrough
case EntryDialect::MIDLFHE:
if (action == Action::DUMP_MIDLFHE) {
module.print(os);
return mlir::success();
}
if (mlir::zamalang::pipeline::lowerMidLFHEToLowLFHE(
context, module, fheContext, parametrizeMidlHFE)
.failed())
return mlir::failure();
// fallthrough
case EntryDialect::LOWLFHE:
if (action == Action::DUMP_LOWLFHE) {
module.print(os);
return mlir::success();
}
if (mlir::zamalang::pipeline::lowerLowLFHEToStd(context, module).failed())
return mlir::failure();
// fallthrough
case EntryDialect::STD:
if (action == Action::DUMP_STD) {
module.print(os);
return mlir::success();
} else if (action == Action::JIT_INVOKE) {
keySet = generateKeySet(module, fheContext, jitFuncName);
}
if (mlir::zamalang::pipeline::lowerStdToLLVMDialect(context, module,
verbose)
.failed())
return mlir::failure();
// fallthrough
case EntryDialect::LLVM: {
if (action == Action::DUMP_LLVM_DIALECT) {
module.print(os);
return mlir::success();
} else if (action == Action::JIT_INVOKE) {
return mlir::zamalang::runJit(module, jitFuncName, jitArgs, *keySet,
defaultOptPipeline, os);
}
llvm::LLVMContext llvmContext;
std::unique_ptr<llvm::Module> llvmModule =
mlir::zamalang::pipeline::lowerLLVMDialectToLLVMIR(context, llvmContext,
module);
if (!llvmModule) {
mlir::zamalang::log_error()
<< "Failed to translate LLVM dialect to LLVM IR\n";
return mlir::failure();
}
if (action == Action::DUMP_LLVM_IR) {
llvmModule->dump();
return mlir::success();
}
if (mlir::zamalang::pipeline::optimizeLLVMModule(llvmContext, *llvmModule)
.failed()) {
mlir::zamalang::log_error() << "Failed to optimize LLVM IR\n";
return mlir::failure();
}
if (action == Action::DUMP_OPTIMIZED_LLVM_IR) {
llvmModule->dump();
return mlir::success();
}
break;
}
module->print(os);
}
return mlir::success();
}
@@ -209,6 +350,16 @@ mlir::LogicalResult compilerMain(int argc, char **argv) {
// String for error messages from library functions
std::string errorMessage;
if (cmdline::action == Action::JIT_INVOKE &&
cmdline::entryDialect != EntryDialect::HLFHE &&
cmdline::entryDialect != EntryDialect::MIDLFHE &&
cmdline::entryDialect != EntryDialect::LOWLFHE &&
cmdline::entryDialect != EntryDialect::STD) {
mlir::zamalang::log_error()
<< "Can only JIT invoke HLFHE / MidLFHE / LowLFHE / STD programs";
return mlir::failure();
}
// Load our Dialect in this MLIR Context.
context.getOrLoadDialect<mlir::zamalang::HLFHE::HLFHEDialect>();
context.getOrLoadDialect<mlir::zamalang::MidLFHE::MidLFHEDialect>();
@@ -229,7 +380,7 @@ mlir::LogicalResult compilerMain(int argc, char **argv) {
return mlir::failure();
}
// Iterate over all inpiut files specified on the command line
// Iterate over all input files specified on the command line
for (const auto &fileName : cmdline::inputs) {
auto file = mlir::openInputFile(fileName, &errorMessage);
@@ -247,14 +398,19 @@ mlir::LogicalResult compilerMain(int argc, char **argv) {
std::move(file),
[&](std::unique_ptr<llvm::MemoryBuffer> inputBuffer,
llvm::raw_ostream &os) {
return processInputBuffer(context, std::move(inputBuffer), os,
cmdline::verifyDiagnostics);
return processInputBuffer(
context, std::move(inputBuffer), cmdline::entryDialect,
cmdline::action, cmdline::jitFuncName, cmdline::jitArgs,
cmdline::parametrizeMidLFHE, cmdline::verifyDiagnostics,
cmdline::verbose, os);
},
output->os())))
return mlir::failure();
} else {
return processInputBuffer(context, std::move(file), output->os(),
cmdline::verifyDiagnostics);
return processInputBuffer(
context, std::move(file), cmdline::entryDialect, cmdline::action,
cmdline::jitFuncName, cmdline::jitArgs, cmdline::parametrizeMidLFHE,
cmdline::verifyDiagnostics, cmdline::verbose, output->os());
}
}

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler %s --passes hlfhe-to-midlfhe 2>&1| FileCheck %s
// RUN: zamacompiler %s --entry-dialect=hlfhe --action=dump-midlfhe 2>&1| FileCheck %s
// CHECK-LABEL: func @add_eint(%arg0: !MidLFHE.glwe<{_,_,_}{7}>, %arg1: !MidLFHE.glwe<{_,_,_}{7}>) -> !MidLFHE.glwe<{_,_,_}{7}>
func @add_eint(%arg0: !HLFHE.eint<7>, %arg1: !HLFHE.eint<7>) -> !HLFHE.eint<7> {
@@ -7,4 +7,4 @@ func @add_eint(%arg0: !HLFHE.eint<7>, %arg1: !HLFHE.eint<7>) -> !HLFHE.eint<7> {
%1 = "HLFHE.add_eint"(%arg0, %arg1): (!HLFHE.eint<7>, !HLFHE.eint<7>) -> (!HLFHE.eint<7>)
return %1: !HLFHE.eint<7>
}
}

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler %s --passes hlfhe-to-midlfhe 2>&1| FileCheck %s
// RUN: zamacompiler %s --entry-dialect=hlfhe --action=dump-midlfhe 2>&1| FileCheck %s
// CHECK-LABEL: func @add_eint_int(%arg0: !MidLFHE.glwe<{_,_,_}{7}>) -> !MidLFHE.glwe<{_,_,_}{7}>
func @add_eint_int(%arg0: !HLFHE.eint<7>) -> !HLFHE.eint<7> {

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler %s --passes hlfhe-to-midlfhe 2>&1| FileCheck %s
// RUN: zamacompiler %s --entry-dialect=hlfhe --action=dump-midlfhe 2>&1| FileCheck %s
// CHECK-LABEL: func @apply_lookup_table(%arg0: !MidLFHE.glwe<{_,_,_}{2}>, %arg1: tensor<4xi2>) -> !MidLFHE.glwe<{_,_,_}{2}>
func @apply_lookup_table(%arg0: !HLFHE.eint<2>, %arg1: tensor<4xi2>) -> !HLFHE.eint<2> {
@@ -7,4 +7,4 @@ func @apply_lookup_table(%arg0: !HLFHE.eint<2>, %arg1: tensor<4xi2>) -> !HLFHE.e
%1 = "HLFHE.apply_lookup_table"(%arg0, %arg1): (!HLFHE.eint<2>, tensor<4xi2>) -> (!HLFHE.eint<2>)
return %1: !HLFHE.eint<2>
}
}

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler %s --passes hlfhe-to-midlfhe 2>&1| FileCheck %s
// RUN: zamacompiler %s --entry-dialect=hlfhe --action=dump-midlfhe 2>&1| FileCheck %s
// CHECK-LABEL: func @apply_lookup_table_cst(%arg0: !MidLFHE.glwe<{_,_,_}{7}>) -> !MidLFHE.glwe<{_,_,_}{7}>
func @apply_lookup_table_cst(%arg0: !HLFHE.eint<7>) -> !HLFHE.eint<7> {
@@ -8,4 +8,4 @@ func @apply_lookup_table_cst(%arg0: !HLFHE.eint<7>) -> !HLFHE.eint<7> {
%tlu = std.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127]> : tensor<128xi64>
%1 = "HLFHE.apply_lookup_table"(%arg0, %tlu): (!HLFHE.eint<7>, tensor<128xi64>) -> (!HLFHE.eint<7>)
return %1: !HLFHE.eint<7>
}
}

View File

@@ -1,19 +1,18 @@
// RUN: zamacompiler %s --passes hlfhe-to-midlfhe 2>&1| FileCheck %s
//CHECK: #map0 = affine_map<(d0) -> (d0)>
//CHECK-NEXT: #map1 = affine_map<(d0) -> (0)>
//CHECK-NEXT: module {
//CHECK-NEXT: func @linalg_generic(%arg0: tensor<2x!MidLFHE.glwe<{_,_,_}{2}>>, %arg1: tensor<2xi3>, %arg2: tensor<1x!MidLFHE.glwe<{_,_,_}{2}>>) {
//CHECK-NEXT: %0 = linalg.generic {indexing_maps = [#map0, #map0, #map1], iterator_types = ["reduction"]} ins(%arg0, %arg1 : tensor<2x!MidLFHE.glwe<{_,_,_}{2}>>, tensor<2xi3>) outs(%arg2 : tensor<1x!MidLFHE.glwe<{_,_,_}{2}>>) {
//CHECK-NEXT: ^bb0(%arg3: !MidLFHE.glwe<{_,_,_}{2}>, %arg4: i3, %arg5: !MidLFHE.glwe<{_,_,_}{2}>): // no predecessors
//CHECK-NEXT: %1 = "MidLFHE.mul_glwe_int"(%arg3, %arg4) : (!MidLFHE.glwe<{_,_,_}{2}>, i3) -> !MidLFHE.glwe<{_,_,_}{2}>
//CHECK-NEXT: %2 = "MidLFHE.add_glwe"(%1, %arg5) : (!MidLFHE.glwe<{_,_,_}{2}>, !MidLFHE.glwe<{_,_,_}{2}>) -> !MidLFHE.glwe<{_,_,_}{2}>
//CHECK-NEXT: linalg.yield %2 : !MidLFHE.glwe<{_,_,_}{2}>
//CHECK-NEXT: } -> tensor<1x!MidLFHE.glwe<{_,_,_}{2}>>
//CHECK-NEXT: return
//CHECK-NEXT: }
//CHECK-NEXT: }
// RUN: zamacompiler %s --entry-dialect=hlfhe --action=dump-midlfhe 2>&1| FileCheck %s
// CHECK: #map0 = affine_map<(d0) -> (d0)>
// CHECK-NEXT: #map1 = affine_map<(d0) -> (0)>
// CHECK-NEXT: module {
// CHECK-NEXT: func @linalg_generic(%arg0: tensor<2x!MidLFHE.glwe<{_,_,_}{2}>>, %arg1: tensor<2xi3>, %arg2: tensor<1x!MidLFHE.glwe<{_,_,_}{2}>>) {
// CHECK-NEXT: %0 = linalg.generic {indexing_maps = [#map0, #map0, #map1], iterator_types = ["reduction"]} ins(%arg0, %arg1 : tensor<2x!MidLFHE.glwe<{_,_,_}{2}>>, tensor<2xi3>) outs(%arg2 : tensor<1x!MidLFHE.glwe<{_,_,_}{2}>>) {
// CHECK-NEXT: ^bb0(%arg3: !MidLFHE.glwe<{_,_,_}{2}>, %arg4: i3, %arg5: !MidLFHE.glwe<{_,_,_}{2}>): // no predecessors
// CHECK-NEXT: %1 = "MidLFHE.mul_glwe_int"(%arg3, %arg4) : (!MidLFHE.glwe<{_,_,_}{2}>, i3) -> !MidLFHE.glwe<{_,_,_}{2}>
// CHECK-NEXT: %2 = "MidLFHE.add_glwe"(%1, %arg5) : (!MidLFHE.glwe<{_,_,_}{2}>, !MidLFHE.glwe<{_,_,_}{2}>) -> !MidLFHE.glwe<{_,_,_}{2}>
// CHECK-NEXT: linalg.yield %2 : !MidLFHE.glwe<{_,_,_}{2}>
// CHECK-NEXT: } -> tensor<1x!MidLFHE.glwe<{_,_,_}{2}>>
// CHECK-NEXT: return
// CHECK-NEXT: }
// CHECK-NEXT: }
#map0 = affine_map<(d0) -> (d0)>
#map1 = affine_map<(d0) -> (0)>

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler %s --passes hlfhe-to-midlfhe 2>&1| FileCheck %s
// RUN: zamacompiler %s --entry-dialect=hlfhe --action=dump-midlfhe 2>&1| FileCheck %s
// CHECK-LABEL: func @mul_eint_int(%arg0: !MidLFHE.glwe<{_,_,_}{7}>) -> !MidLFHE.glwe<{_,_,_}{7}>
func @mul_eint_int(%arg0: !HLFHE.eint<7>) -> !HLFHE.eint<7> {
@@ -9,4 +9,4 @@ func @mul_eint_int(%arg0: !HLFHE.eint<7>) -> !HLFHE.eint<7> {
%0 = constant 1 : i8
%1 = "HLFHE.mul_eint_int"(%arg0, %0): (!HLFHE.eint<7>, i8) -> (!HLFHE.eint<7>)
return %1: !HLFHE.eint<7>
}
}

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler %s --passes hlfhe-to-midlfhe 2>&1| FileCheck %s
// RUN: zamacompiler %s --entry-dialect=hlfhe --action=dump-midlfhe 2>&1| FileCheck %s
// CHECK-LABEL: func @sub_int_eint(%arg0: !MidLFHE.glwe<{_,_,_}{7}>) -> !MidLFHE.glwe<{_,_,_}{7}>
func @sub_int_eint(%arg0: !HLFHE.eint<7>) -> !HLFHE.eint<7> {

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler --passes lowlfhe-to-concrete-c-api %s 2>&1| FileCheck %s
// RUN: zamacompiler --entry-dialect=lowlfhe --action=dump-std %s 2>&1| FileCheck %s
// CHECK-LABEL: module
// CHECK-NEXT: func private @add_plaintext_list_glwe_ciphertext_u64(index, !LowLFHE.glwe_ciphertext, !LowLFHE.glwe_ciphertext, !LowLFHE.plaintext_list)
@@ -27,4 +27,4 @@ func @bootstrap_lwe(%arg0: !LowLFHE.lwe_ciphertext<1024,4>, %arg1: !LowLFHE.glwe
// CHECK-NEXT: return %[[RES]] : !LowLFHE.lwe_ciphertext<1024,4>
%1 = "LowLFHE.bootstrap_lwe"(%arg0, %arg1) {baseLog = 2 : i32, k = 1 : i32, level = 3 : i32, polynomialSize = 1024 : i32} : (!LowLFHE.lwe_ciphertext<1024,4>, !LowLFHE.glwe_ciphertext) -> !LowLFHE.lwe_ciphertext<1024,4>
return %1: !LowLFHE.lwe_ciphertext<1024,4>
}
}

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler --passes lowlfhe-to-concrete-c-api %s 2>&1| FileCheck %s
// RUN: zamacompiler --entry-dialect=lowlfhe --action=dump-std %s 2>&1| FileCheck %s
// CHECK-LABEL: module
// CHECK-NEXT: func private @runtime_foreign_plaintext_list_u64(index, tensor<16xi64>, i64, i32) -> !LowLFHE.foreign_plaintext_list
@@ -31,4 +31,4 @@ func @glwe_from_table(%arg0: tensor<16xi64>) -> !LowLFHE.glwe_ciphertext {
// CHECK-NEXT: return %[[V1]] : !LowLFHE.glwe_ciphertext
%1 = "LowLFHE.glwe_from_table"(%arg0) {k = 1 : i32, p = 4 : i32, polynomialSize = 1024 : i32} : (tensor<16xi64>) -> !LowLFHE.glwe_ciphertext
return %1: !LowLFHE.glwe_ciphertext
}
}

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler --passes lowlfhe-to-concrete-c-api %s 2>&1| FileCheck %s
// RUN: zamacompiler --entry-dialect=lowlfhe --action=dump-std %s 2>&1| FileCheck %s
// CHECK-LABEL: module
// CHECK-NEXT: func private @add_plaintext_list_glwe_ciphertext_u64(index, !LowLFHE.glwe_ciphertext, !LowLFHE.glwe_ciphertext, !LowLFHE.plaintext_list)
@@ -26,4 +26,4 @@ func @keyswitch_lwe(%arg0: !LowLFHE.lwe_ciphertext<1024,4>) -> !LowLFHE.lwe_ciph
// CHECK-NEXT: return %[[RES]] : !LowLFHE.lwe_ciphertext<1024,4>
%1 = "LowLFHE.keyswitch_lwe"(%arg0) {baseLog = 2 : i32, inputLweSize = 1 : i32, level = 3 : i32, outputLweSize = 1 : i32} : (!LowLFHE.lwe_ciphertext<1024,4>) -> !LowLFHE.lwe_ciphertext<1024,4>
return %1: !LowLFHE.lwe_ciphertext<1024,4>
}
}

View File

@@ -1,7 +0,0 @@
// RUN: zamacompiler --passes lowlfhe-unparametrize %s 2>&1| FileCheck %s
// CHECK-LABEL: func @main(%arg0: !LowLFHE.lwe_ciphertext<_,_>) -> !LowLFHE.lwe_ciphertext<_,_>
func @main(%arg0: !LowLFHE.lwe_ciphertext<1024,4>) -> !LowLFHE.lwe_ciphertext<1024,4> {
// CHECK-NEXT: return %arg0 : !LowLFHE.lwe_ciphertext<_,_>
return %arg0: !LowLFHE.lwe_ciphertext<1024,4>
}

View File

@@ -1,8 +0,0 @@
// RUN: zamacompiler --passes lowlfhe-unparametrize %s 2>&1| FileCheck %s
// CHECK-LABEL: func @main(%arg0: !LowLFHE.lwe_ciphertext<_,_>) -> !LowLFHE.lwe_ciphertext<_,_>
func @main(%arg0: !LowLFHE.lwe_ciphertext<1024,4>) -> !LowLFHE.lwe_ciphertext<_,_> {
// CHECK-NEXT: return %arg0 : !LowLFHE.lwe_ciphertext<_,_>
%0 = unrealized_conversion_cast %arg0 : !LowLFHE.lwe_ciphertext<1024,4> to !LowLFHE.lwe_ciphertext<_,_>
return %0: !LowLFHE.lwe_ciphertext<_,_>
}

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler --passes midlfhe-to-lowlfhe %s 2>&1| FileCheck %s
// RUN: zamacompiler --entry-dialect=midlfhe --action=dump-lowlfhe --parametrize-midlfhe=false %s 2>&1| FileCheck %s
// CHECK-LABEL: func @add_glwe(%arg0: !LowLFHE.lwe_ciphertext<2048,7>, %arg1: !LowLFHE.lwe_ciphertext<2048,7>) -> !LowLFHE.lwe_ciphertext<2048,7>
func @add_glwe(%arg0: !MidLFHE.glwe<{2048,1,64}{7}>, %arg1: !MidLFHE.glwe<{2048,1,64}{7}>) -> !MidLFHE.glwe<{2048,1,64}{7}> {

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler --passes midlfhe-to-lowlfhe %s 2>&1| FileCheck %s
// RUN: zamacompiler --entry-dialect=midlfhe --action=dump-lowlfhe --parametrize-midlfhe=false %s 2>&1| FileCheck %s
// CHECK-LABEL: func @add_glwe_const_int(%arg0: !LowLFHE.lwe_ciphertext<1024,7>) -> !LowLFHE.lwe_ciphertext<1024,7>
func @add_glwe_const_int(%arg0: !MidLFHE.glwe<{1024,1,64}{7}>) -> !MidLFHE.glwe<{1024,1,64}{7}> {
@@ -19,4 +19,4 @@ func @add_glwe_int(%arg0: !MidLFHE.glwe<{1024,1,64}{4}>, %arg1: i5) -> !MidLFHE.
// CHECK-NEXT: return %[[V2]] : !LowLFHE.lwe_ciphertext<1024,4>
%1 = "MidLFHE.add_glwe_int"(%arg0, %arg1): (!MidLFHE.glwe<{1024,1,64}{4}>, i5) -> (!MidLFHE.glwe<{1024,1,64}{4}>)
return %1: !MidLFHE.glwe<{1024,1,64}{4}>
}
}

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler --passes midlfhe-to-lowlfhe %s 2>&1| FileCheck %s
// RUN: zamacompiler --entry-dialect=midlfhe --action=dump-lowlfhe --parametrize-midlfhe=false %s 2>&1| FileCheck %s
// CHECK-LABEL: func @apply_lookup_table(%arg0: !LowLFHE.lwe_ciphertext<1024,4>, %arg1: tensor<16xi4>) -> !LowLFHE.lwe_ciphertext<1024,4>
func @apply_lookup_table(%arg0: !MidLFHE.glwe<{1024,1,64}{4}>, %arg1: tensor<16xi4>) -> !MidLFHE.glwe<{1024,1,64}{4}> {
@@ -8,4 +8,4 @@ func @apply_lookup_table(%arg0: !MidLFHE.glwe<{1024,1,64}{4}>, %arg1: tensor<16x
// CHECK-NEXT: return %[[V3]] : !LowLFHE.lwe_ciphertext<1024,4>
%1 = "MidLFHE.apply_lookup_table"(%arg0, %arg1){k=1:i32, polynomialSize=1024:i32, levelKS=3:i32, baseLogKS=2:i32, levelBS=5:i32, baseLogBS=4:i32, outputSizeKS=600:i32}: (!MidLFHE.glwe<{1024,1,64}{4}>, tensor<16xi4>) -> (!MidLFHE.glwe<{1024,1,64}{4}>)
return %1: !MidLFHE.glwe<{1024,1,64}{4}>
}
}

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler --passes midlfhe-to-lowlfhe %s 2>&1| FileCheck %s
// RUN: zamacompiler --entry-dialect=midlfhe --action=dump-lowlfhe --parametrize-midlfhe=false %s 2>&1| FileCheck %s
// CHECK-LABEL: func @apply_lookup_table_cst(%arg0: !LowLFHE.lwe_ciphertext<2048,4>) -> !LowLFHE.lwe_ciphertext<2048,4>
func @apply_lookup_table_cst(%arg0: !MidLFHE.glwe<{2048,1,64}{4}>) -> !MidLFHE.glwe<{2048,1,64}{4}> {
@@ -10,4 +10,4 @@ func @apply_lookup_table_cst(%arg0: !MidLFHE.glwe<{2048,1,64}{4}>) -> !MidLFHE.g
%tlu = std.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : tensor<16xi4>
%1 = "MidLFHE.apply_lookup_table"(%arg0, %tlu){k=1:i32, polynomialSize=2048:i32, levelKS=3:i32, baseLogKS=2:i32, levelBS=5:i32, baseLogBS=4:i32, outputSizeKS=600:i32}: (!MidLFHE.glwe<{2048,1,64}{4}>, tensor<16xi4>) -> (!MidLFHE.glwe<{2048,1,64}{4}>)
return %1: !MidLFHE.glwe<{2048,1,64}{4}>
}
}

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler --passes midlfhe-to-lowlfhe %s 2>&1| FileCheck %s
// RUN: zamacompiler --entry-dialect=midlfhe --action=dump-lowlfhe --parametrize-midlfhe=false %s 2>&1| FileCheck %s
// CHECK-LABEL: func @mul_glwe_const_int(%arg0: !LowLFHE.lwe_ciphertext<1024,7>) -> !LowLFHE.lwe_ciphertext<1024,7>
func @mul_glwe_const_int(%arg0: !MidLFHE.glwe<{1024,1,64}{7}>) -> !MidLFHE.glwe<{1024,1,64}{7}> {
@@ -19,4 +19,4 @@ func @mul_glwe_int(%arg0: !MidLFHE.glwe<{1024,1,64}{4}>, %arg1: i5) -> !MidLFHE.
// CHECK-NEXT: return %[[V2]] : !LowLFHE.lwe_ciphertext<1024,4>
%1 = "MidLFHE.mul_glwe_int"(%arg0, %arg1): (!MidLFHE.glwe<{1024,1,64}{4}>, i5) -> (!MidLFHE.glwe<{1024,1,64}{4}>)
return %1: !MidLFHE.glwe<{1024,1,64}{4}>
}
}

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler --passes midlfhe-to-lowlfhe %s 2>&1| FileCheck %s
// RUN: zamacompiler --entry-dialect=midlfhe --action=dump-lowlfhe --parametrize-midlfhe=false %s 2>&1| FileCheck %s
// CHECK-LABEL: func @sub_const_int_glwe(%arg0: !LowLFHE.lwe_ciphertext<1024,7>) -> !LowLFHE.lwe_ciphertext<1024,7>
func @sub_const_int_glwe(%arg0: !MidLFHE.glwe<{1024,1,64}{7}>) -> !MidLFHE.glwe<{1024,1,64}{7}> {
@@ -20,4 +20,4 @@ func @sub_int_glwe(%arg0: !MidLFHE.glwe<{1024,1,64}{4}>, %arg1: i5) -> !MidLFHE.
// CHECK-NEXT: return %[[V2]] : !LowLFHE.lwe_ciphertext<1024,4>
%1 = "MidLFHE.sub_int_glwe"(%arg1, %arg0): (i5, !MidLFHE.glwe<{1024,1,64}{4}>) -> (!MidLFHE.glwe<{1024,1,64}{4}>)
return %1: !MidLFHE.glwe<{1024,1,64}{4}>
}
}

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler --split-input-file --verify-diagnostics %s
// RUN: zamacompiler --split-input-file --verify-diagnostics --entry-dialect=hlfhe --action=roundtrip %s
// Incompatible shapes
func @dot_incompatible_shapes(
@@ -66,4 +66,4 @@ func @dot_incompatible_int(
(tensor<4x!HLFHE.eint<2>>, tensor<4xi4>) -> !HLFHE.eint<2>
return %ret : !HLFHE.eint<2>
}
}

View File

@@ -1,4 +1,4 @@
// RUN: not zamacompiler %s 2>&1| FileCheck %s
// RUN: not zamacompiler --entry-dialect=hlfhe --action=roundtrip %s 2>&1| FileCheck %s
// CHECK-LABEL: eint support only precision in ]0;7]
func @test(%arg0: !HLFHE.eint<8>) {

View File

@@ -1,4 +1,4 @@
// RUN: not zamacompiler %s 2>&1| FileCheck %s
// RUN: not zamacompiler --entry-dialect=hlfhe --action=roundtrip %s 2>&1| FileCheck %s
// CHECK-LABEL: eint support only precision in ]0;7]
func @test(%arg0: !HLFHE.eint<0>) {

View File

@@ -1,7 +1,7 @@
// RUN: not zamacompiler %s 2>&1| FileCheck %s
// RUN: not zamacompiler --entry-dialect=hlfhe --action=roundtrip %s 2>&1| FileCheck %s
// CHECK-LABEL: error: 'HLFHE.add_eint' op should have the width of encrypted inputs equals
func @add_eint(%arg0: !HLFHE.eint<2>, %arg1: !HLFHE.eint<3>) -> !HLFHE.eint<2> {
%1 = "HLFHE.add_eint"(%arg0, %arg1): (!HLFHE.eint<2>, !HLFHE.eint<3>) -> (!HLFHE.eint<2>)
return %1: !HLFHE.eint<2>
}
}

View File

@@ -1,7 +1,7 @@
// RUN: not zamacompiler %s 2>&1| FileCheck %s
// RUN: not zamacompiler --entry-dialect=hlfhe --action=roundtrip %s 2>&1| FileCheck %s
// CHECK-LABEL: error: 'HLFHE.add_eint' op should have the width of encrypted inputs and result equals
func @add_eint(%arg0: !HLFHE.eint<2>, %arg1: !HLFHE.eint<2>) -> !HLFHE.eint<3> {
%1 = "HLFHE.add_eint"(%arg0, %arg1): (!HLFHE.eint<2>, !HLFHE.eint<2>) -> (!HLFHE.eint<3>)
return %1: !HLFHE.eint<3>
}
}

View File

@@ -1,8 +1,8 @@
// RUN: not zamacompiler %s 2>&1| FileCheck %s
// RUN: not zamacompiler --entry-dialect=hlfhe --action=roundtrip %s 2>&1| FileCheck %s
// CHECK-LABEL: error: 'HLFHE.add_eint_int' op should have the width of plain input equals to width of encrypted input + 1
func @add_eint_int(%arg0: !HLFHE.eint<2>) -> !HLFHE.eint<2> {
%0 = constant 1 : i4
%1 = "HLFHE.add_eint_int"(%arg0, %0): (!HLFHE.eint<2>, i4) -> (!HLFHE.eint<2>)
return %1: !HLFHE.eint<2>
}
}

View File

@@ -1,8 +1,8 @@
// RUN: not zamacompiler %s 2>&1| FileCheck %s
// RUN: not zamacompiler --entry-dialect=hlfhe --action=roundtrip %s 2>&1| FileCheck %s
// CHECK-LABEL: error: 'HLFHE.add_eint_int' op should have the width of encrypted inputs and result equals
func @add_eint_int(%arg0: !HLFHE.eint<2>) -> !HLFHE.eint<3> {
%0 = constant 1 : i2
%1 = "HLFHE.add_eint_int"(%arg0, %0): (!HLFHE.eint<2>, i2) -> (!HLFHE.eint<3>)
return %1: !HLFHE.eint<3>
}
}

View File

@@ -1,7 +1,7 @@
// RUN: not zamacompiler %s 2>&1| FileCheck %s
// RUN: not zamacompiler --entry-dialect=hlfhe --action=roundtrip %s 2>&1| FileCheck %s
// CHECK-LABEL: error: 'HLFHE.apply_lookup_table' op should have as `l_cst` argument a shape of one dimension equals to 2^p, where p is the width of the `ct` argument.
func @apply_lookup_table(%arg0: !HLFHE.eint<2>, %arg1: tensor<8xi3>) -> !HLFHE.eint<2> {
%1 = "HLFHE.apply_lookup_table"(%arg0, %arg1): (!HLFHE.eint<2>, tensor<8xi3>) -> (!HLFHE.eint<2>)
return %1: !HLFHE.eint<2>
}
}

View File

@@ -1,8 +1,8 @@
// RUN: not zamacompiler %s 2>&1| FileCheck %s
// RUN: not zamacompiler --entry-dialect=hlfhe --action=roundtrip %s 2>&1| FileCheck %s
// CHECK-LABEL: error: 'HLFHE.mul_eint_int' op should have the width of plain input equals to width of encrypted input + 1
func @mul_eint_int(%arg0: !HLFHE.eint<2>) -> !HLFHE.eint<2> {
%0 = constant 1 : i4
%1 = "HLFHE.mul_eint_int"(%arg0, %0): (!HLFHE.eint<2>, i4) -> (!HLFHE.eint<2>)
return %1: !HLFHE.eint<2>
}
}

View File

@@ -1,8 +1,8 @@
// RUN: not zamacompiler %s 2>&1| FileCheck %s
// RUN: not zamacompiler --entry-dialect=hlfhe --action=roundtrip %s 2>&1| FileCheck %s
// CHECK-LABEL: error: 'HLFHE.mul_eint_int' op should have the width of encrypted inputs and result equals
func @mul_eint_int(%arg0: !HLFHE.eint<2>) -> !HLFHE.eint<3> {
%0 = constant 1 : i2
%1 = "HLFHE.mul_eint_int"(%arg0, %0): (!HLFHE.eint<2>, i2) -> (!HLFHE.eint<3>)
return %1: !HLFHE.eint<3>
}
}

View File

@@ -1,8 +1,8 @@
// RUN: not zamacompiler %s 2>&1| FileCheck %s
// RUN: not zamacompiler --entry-dialect=hlfhe --action=roundtrip %s 2>&1| FileCheck %s
// CHECK-LABEL: error: 'HLFHE.sub_int_eint' op should have the width of plain input equals to width of encrypted input + 1
func @sub_int_eint(%arg0: !HLFHE.eint<2>) -> !HLFHE.eint<2> {
%0 = constant 1 : i4
%1 = "HLFHE.sub_int_eint"(%0, %arg0): (i4, !HLFHE.eint<2>) -> (!HLFHE.eint<2>)
return %1: !HLFHE.eint<2>
}
}

View File

@@ -1,8 +1,8 @@
// RUN: not zamacompiler %s 2>&1| FileCheck %s
// RUN: not zamacompiler --entry-dialect=hlfhe --action=roundtrip %s 2>&1| FileCheck %s
// CHECK-LABEL: error: 'HLFHE.sub_int_eint' op should have the width of encrypted inputs and result equals
func @sub_int_eint(%arg0: !HLFHE.eint<2>) -> !HLFHE.eint<3> {
%0 = constant 1 : i2
%1 = "HLFHE.sub_int_eint"(%0, %arg0): (i2, !HLFHE.eint<2>) -> (!HLFHE.eint<3>)
return %1: !HLFHE.eint<3>
}
}

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler --round-trip %s 2>&1| FileCheck %s
// RUN: zamacompiler --entry-dialect=hlfhe --action=roundtrip %s 2>&1| FileCheck %s
// CHECK-LABEL: func @zero() -> !HLFHE.eint<2>
func @zero() -> !HLFHE.eint<2> {

View File

@@ -1,22 +1,22 @@
// RUN: zamacompiler %s --passes hlfhe-tensor-ops-to-linalg 2>&1 | FileCheck %s
// RUN: zamacompiler %s --entry-dialect=hlfhe --action=dump-midlfhe 2>&1 | FileCheck %s
//CHECK: #map0 = affine_map<(d0) -> (d0)>
//CHECK-NEXT: #map1 = affine_map<(d0) -> (0)>
//CHECK-NEXT: module {
//CHECK-NEXT: func @dot_eint_int(%arg0: tensor<2x!HLFHE.eint<2>>, %arg1: tensor<2xi3>) -> !HLFHE.eint<2> {
//CHECK-NEXT: %0 = "HLFHE.zero"() : () -> !HLFHE.eint<2>
//CHECK-NEXT: %1 = tensor.from_elements %0 : tensor<1x!HLFHE.eint<2>>
//CHECK-NEXT: %2 = linalg.generic {indexing_maps = [#map0, #map0, #map1], iterator_types = ["reduction"]} ins(%arg0, %arg1 : tensor<2x!HLFHE.eint<2>>, tensor<2xi3>) outs(%1 : tensor<1x!HLFHE.eint<2>>) {
//CHECK-NEXT: ^bb0(%arg2: !HLFHE.eint<2>, %arg3: i3, %arg4: !HLFHE.eint<2>): // no predecessors
//CHECK-NEXT: %4 = "HLFHE.mul_eint_int"(%arg2, %arg3) : (!HLFHE.eint<2>, i3) -> !HLFHE.eint<2>
//CHECK-NEXT: %5 = "HLFHE.add_eint"(%4, %arg4) : (!HLFHE.eint<2>, !HLFHE.eint<2>) -> !HLFHE.eint<2>
//CHECK-NEXT: linalg.yield %5 : !HLFHE.eint<2>
//CHECK-NEXT: } -> tensor<1x!HLFHE.eint<2>>
//CHECK-NEXT: func @dot_eint_int(%arg0: tensor<2x!MidLFHE.glwe<{_,_,_}{2}>>, %arg1: tensor<2xi3>) -> !MidLFHE.glwe<{_,_,_}{2}> {
//CHECK-NEXT: %0 = "MidLFHE.zero"() : () -> !MidLFHE.glwe<{_,_,_}{2}>
//CHECK-NEXT: %1 = tensor.from_elements %0 : tensor<1x!MidLFHE.glwe<{_,_,_}{2}>>
//CHECK-NEXT: %2 = linalg.generic {indexing_maps = [#map0, #map0, #map1], iterator_types = ["reduction"]} ins(%arg0, %arg1 : tensor<2x!MidLFHE.glwe<{_,_,_}{2}>>, tensor<2xi3>) outs(%1 : tensor<1x!MidLFHE.glwe<{_,_,_}{2}>>) {
//CHECK-NEXT: ^bb0(%arg2: !MidLFHE.glwe<{_,_,_}{2}>, %arg3: i3, %arg4: !MidLFHE.glwe<{_,_,_}{2}>): // no predecessors
//CHECK-NEXT: %4 = "MidLFHE.mul_glwe_int"(%arg2, %arg3) : (!MidLFHE.glwe<{_,_,_}{2}>, i3) -> !MidLFHE.glwe<{_,_,_}{2}>
//CHECK-NEXT: %5 = "MidLFHE.add_glwe"(%4, %arg4) : (!MidLFHE.glwe<{_,_,_}{2}>, !MidLFHE.glwe<{_,_,_}{2}>) -> !MidLFHE.glwe<{_,_,_}{2}>
//CHECK-NEXT: linalg.yield %5 : !MidLFHE.glwe<{_,_,_}{2}>
//CHECK-NEXT: } -> tensor<1x!MidLFHE.glwe<{_,_,_}{2}>>
//CHECK-NEXT: %c0 = constant 0 : index
//CHECK-NEXT: %3 = tensor.extract %2[%c0] : tensor<1x!HLFHE.eint<2>>
//CHECK-NEXT: return %3 : !HLFHE.eint<2>
//CHECK-NEXT: %3 = tensor.extract %2[%c0] : tensor<1x!MidLFHE.glwe<{_,_,_}{2}>>
//CHECK-NEXT: return %3 : !MidLFHE.glwe<{_,_,_}{2}>
//CHECK-NEXT: }
//CHECK-NEXT: }
//CHECK-NEXT: }
func @dot_eint_int(%arg0: tensor<2x!HLFHE.eint<2>>,
%arg1: tensor<2xi3>) -> !HLFHE.eint<2>
{

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler --round-trip %s 2>&1| FileCheck %s
// RUN: zamacompiler --entry-dialect=hlfhe --action=roundtrip %s 2>&1| FileCheck %s
// CHECK-LABEL: func @memref_arg(%arg0: memref<2x!HLFHE.eint<7>>
func @memref_arg(%arg0: memref<2x!HLFHE.eint<7>>) {

View File

@@ -1,5 +1,4 @@
// RUN: zamacompiler --round-trip %s 2>&1| FileCheck %s
// RUN: zamacompiler --entry-dialect=lowlfhe --action=roundtrip %s 2>&1| FileCheck %s
// CHECK-LABEL: func @add_lwe_ciphertexts(%arg0: !LowLFHE.lwe_ciphertext<2048,7>, %arg1: !LowLFHE.lwe_ciphertext<2048,7>) -> !LowLFHE.lwe_ciphertext<2048,7>
func @add_lwe_ciphertexts(%arg0: !LowLFHE.lwe_ciphertext<2048,7>, %arg1: !LowLFHE.lwe_ciphertext<2048,7>) -> !LowLFHE.lwe_ciphertext<2048,7> {

View File

@@ -1,5 +1,4 @@
// RUN: zamacompiler --round-trip %s 2>&1| FileCheck %s
// RUN: zamacompiler --entry-dialect=lowlfhe --action=roundtrip %s 2>&1| FileCheck %s
// CHECK-LABEL: func @type_enc_rand_gen(%arg0: !LowLFHE.enc_rand_gen) -> !LowLFHE.enc_rand_gen
func @type_enc_rand_gen(%arg0: !LowLFHE.enc_rand_gen) -> !LowLFHE.enc_rand_gen {

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler --split-input-file --verify-diagnostics %s
// RUN: zamacompiler --split-input-file --verify-diagnostics --entry-dialect=midlfhe --action=roundtrip %s
// GLWE p parameter result
func @add_glwe(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>, %arg1: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{6}> {

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler --round-trip %s 2>&1| FileCheck %s
// RUN: zamacompiler --entry-dialect=midlfhe --action=roundtrip %s 2>&1| FileCheck %s
// CHECK-LABEL: func @add_glwe(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>, %arg1: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{7}>
func @add_glwe(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>, %arg1: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{7}> {

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler --split-input-file --verify-diagnostics %s
// RUN: zamacompiler --split-input-file --verify-diagnostics --entry-dialect=midlfhe --action=roundtrip %s
// GLWE p parameter
func @add_glwe_int(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{6}> {

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler --round-trip %s 2>&1| FileCheck %s
// RUN: zamacompiler --entry-dialect=midlfhe --action=roundtrip %s 2>&1| FileCheck %s
// CHECK-LABEL: func @add_glwe_int(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{7}>
func @add_glwe_int(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{7}> {
@@ -9,4 +9,4 @@ func @add_glwe_int(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024
%0 = constant 1 : i8
%1 = "MidLFHE.add_glwe_int"(%arg0, %0): (!MidLFHE.glwe<{1024,12,64}{7}>, i8) -> (!MidLFHE.glwe<{1024,12,64}{7}>)
return %1: !MidLFHE.glwe<{1024,12,64}{7}>
}
}

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler --split-input-file --verify-diagnostics %s
// RUN: zamacompiler --split-input-file --verify-diagnostics --entry-dialect=midlfhe --action=roundtrip %s
// Bad dimension of the lookup table
func @apply_lookup_table(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>, %arg1: tensor<4xi2>) -> !MidLFHE.glwe<{512,10,64}{2}> {

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler --round-trip %s 2>&1| FileCheck %s
// RUN: zamacompiler --entry-dialect=midlfhe --action=roundtrip %s 2>&1| FileCheck %s
// CHECK-LABEL: func @apply_lookup_table(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>, %arg1: tensor<128xi2>) -> !MidLFHE.glwe<{512,10,64}{2}>
func @apply_lookup_table(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>, %arg1: tensor<128xi2>) -> !MidLFHE.glwe<{512,10,64}{2}> {
@@ -7,4 +7,4 @@ func @apply_lookup_table(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>, %arg1: tensor<12
%1 = "MidLFHE.apply_lookup_table"(%arg0, %arg1) {k = 1 : i32, polynomialSize = 1024 : i32, levelKS = 2 : i32, baseLogKS = -82 : i32, levelBS = 3 : i32, baseLogBS = -83 : i32, outputSizeKS = 600 : i32} : (!MidLFHE.glwe<{1024,12,64}{7}>, tensor<128xi2>) -> (!MidLFHE.glwe<{512,10,64}{2}>)
return %1: !MidLFHE.glwe<{512,10,64}{2}>
}
}

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler --split-input-file --verify-diagnostics %s
// RUN: zamacompiler --split-input-file --verify-diagnostics --entry-dialect=midlfhe --action=roundtrip %s
// GLWE p parameter
func @mul_glwe_int(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{6}> {

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler --round-trip %s 2>&1| FileCheck %s
// RUN: zamacompiler --entry-dialect=midlfhe --action=roundtrip %s 2>&1| FileCheck %s
// CHECK-LABEL: func @mul_glwe_int(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{7}>
func @mul_glwe_int(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{7}> {
@@ -9,4 +9,4 @@ func @mul_glwe_int(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024
%0 = constant 1 : i8
%1 = "MidLFHE.mul_glwe_int"(%arg0, %0): (!MidLFHE.glwe<{1024,12,64}{7}>, i8) -> (!MidLFHE.glwe<{1024,12,64}{7}>)
return %1: !MidLFHE.glwe<{1024,12,64}{7}>
}
}

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler --split-input-file --verify-diagnostics %s
// RUN: zamacompiler --split-input-file --verify-diagnostics --entry-dialect=midlfhe --action=roundtrip %s
// GLWE p parameter
func @sub_int_glwe(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{6}> {

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler --round-trip %s 2>&1| FileCheck %s
// RUN: zamacompiler --entry-dialect=midlfhe --action=roundtrip %s 2>&1| FileCheck %s
// CHECK-LABEL: func @sub_int_glwe(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{7}>
func @sub_int_glwe(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{7}> {
@@ -9,4 +9,4 @@ func @sub_int_glwe(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024
%0 = constant 1 : i8
%1 = "MidLFHE.sub_int_glwe"(%0, %arg0): (i8, !MidLFHE.glwe<{1024,12,64}{7}>) -> (!MidLFHE.glwe<{1024,12,64}{7}>)
return %1: !MidLFHE.glwe<{1024,12,64}{7}>
}
}

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler %s --round-trip 2>&1| FileCheck %s
// RUN: zamacompiler %s --entry-dialect=midlfhe --action=roundtrip 2>&1| FileCheck %s
// CHECK-LABEL: func @glwe_0(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{7}>
func @glwe_0(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{7}> {
@@ -10,4 +10,4 @@ func @glwe_0(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64
func @glwe_1(%arg0: !MidLFHE.glwe<{_,_,_}{7}>) -> !MidLFHE.glwe<{_,_,_}{7}> {
// CHECK-LABEL: return %arg0 : !MidLFHE.glwe<{_,_,_}{7}>
return %arg0: !MidLFHE.glwe<{_,_,_}{7}>
}
}