mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 12:15:09 -05:00
refactor(compiler): Refactor CompilerEngine and related classes
This commit contains several incremental improvements towards a clear
interface for lambdas:
- Unification of static and JIT compilation by using the static
compilation path of `CompilerEngine` within a new subclass
`JitCompilerEngine`.
- Clear ownership for compilation artefacts through
`CompilationContext`, making it impossible to destroy objects used
directly or indirectly before destruction of their users.
- Clear interface for lambdas generated by the compiler through
`JitCompilerEngine::Lambda` with a templated call operator,
encapsulating otherwise manual orchestration of `CompilerEngine`,
`JITLambda`, and `CompilerEngine::Argument`.
- Improved error handling through `llvm::Expected<T>` and proper
error checking following the conventions for `llvm::Expected<T>`
and `llvm::Error`.
Co-authored-by: youben11 <ayoub.benaissa@zama.ai>
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
#include <cstdint>
|
||||
#include <iostream>
|
||||
|
||||
#include <llvm/Support/CommandLine.h>
|
||||
@@ -22,15 +23,15 @@
|
||||
#include "zamalang/Dialect/LowLFHE/IR/LowLFHETypes.h"
|
||||
#include "zamalang/Dialect/MidLFHE/IR/MidLFHEDialect.h"
|
||||
#include "zamalang/Dialect/MidLFHE/IR/MidLFHETypes.h"
|
||||
#include "zamalang/Support/Jit.h"
|
||||
#include "zamalang/Support/Error.h"
|
||||
#include "zamalang/Support/JitCompilerEngine.h"
|
||||
#include "zamalang/Support/KeySet.h"
|
||||
#include "zamalang/Support/Pipeline.h"
|
||||
#include "zamalang/Support/logging.h"
|
||||
|
||||
enum EntryDialect { HLFHE, MIDLFHE, LOWLFHE, STD, LLVM };
|
||||
|
||||
enum Action {
|
||||
ROUND_TRIP,
|
||||
DUMP_HLFHE,
|
||||
DUMP_HLFHE_MANP,
|
||||
DUMP_MIDLFHE,
|
||||
DUMP_LOWLFHE,
|
||||
@@ -80,26 +81,6 @@ llvm::cl::opt<bool> parametrizeMidLFHE(
|
||||
llvm::cl::desc("Perform MidLFHE global parametrization pass"),
|
||||
llvm::cl::init<bool>(true));
|
||||
|
||||
static llvm::cl::opt<enum EntryDialect> entryDialect(
|
||||
"e", "entry-dialect", llvm::cl::desc("Entry dialect"),
|
||||
llvm::cl::init<enum EntryDialect>(EntryDialect::HLFHE),
|
||||
llvm::cl::ValueRequired, llvm::cl::NumOccurrencesFlag::Required,
|
||||
llvm::cl::values(
|
||||
clEnumValN(EntryDialect::HLFHE, "hlfhe",
|
||||
"Input module is composed of HLFHE operations")),
|
||||
llvm::cl::values(
|
||||
clEnumValN(EntryDialect::MIDLFHE, "midlfhe",
|
||||
"Input module is composed of MidLFHE operations")),
|
||||
llvm::cl::values(
|
||||
clEnumValN(EntryDialect::LOWLFHE, "lowlfhe",
|
||||
"Input module is composed of LowLFHE operations")),
|
||||
llvm::cl::values(
|
||||
clEnumValN(EntryDialect::STD, "std",
|
||||
"Input module is composed of operations from std")),
|
||||
llvm::cl::values(
|
||||
clEnumValN(EntryDialect::LLVM, "llvm",
|
||||
"Input module is composed of operations from llvm")));
|
||||
|
||||
static llvm::cl::opt<enum Action> action(
|
||||
"a", "action", llvm::cl::desc("output mode"), llvm::cl::ValueRequired,
|
||||
llvm::cl::NumOccurrencesFlag::Required,
|
||||
@@ -109,6 +90,8 @@ static llvm::cl::opt<enum Action> action(
|
||||
llvm::cl::values(clEnumValN(Action::DUMP_HLFHE_MANP, "dump-hlfhe-manp",
|
||||
"Dump HLFHE module after running the Minimal "
|
||||
"Arithmetic Noise Padding pass")),
|
||||
llvm::cl::values(clEnumValN(Action::DUMP_HLFHE, "dump-hlfhe",
|
||||
"Dump HLFHE module")),
|
||||
llvm::cl::values(clEnumValN(Action::DUMP_MIDLFHE, "dump-midlfhe",
|
||||
"Lower to MidLFHE and dump result")),
|
||||
llvm::cl::values(clEnumValN(Action::DUMP_LOWLFHE, "dump-lowlfhe",
|
||||
@@ -158,50 +141,7 @@ llvm::cl::opt<llvm::Optional<size_t>, false, OptionalSizeTParser> assumeMaxMANP(
|
||||
llvm::cl::desc(
|
||||
"Assume a maximum for the Minimum Arithmetic Noise Padding"));
|
||||
|
||||
}; // namespace cmdline
|
||||
|
||||
std::function<llvm::Error(llvm::Module *)> defaultOptPipeline =
|
||||
mlir::makeOptimizingTransformer(3, 0, nullptr);
|
||||
|
||||
std::unique_ptr<mlir::zamalang::KeySet>
|
||||
generateKeySet(mlir::ModuleOp &module, mlir::zamalang::V0FHEContext &fheContext,
|
||||
const std::string &jitFuncName) {
|
||||
std::unique_ptr<mlir::zamalang::KeySet> keySet;
|
||||
|
||||
mlir::zamalang::log_verbose()
|
||||
<< "### Global FHE constraint: {norm2:" << fheContext.constraint.norm2
|
||||
<< ", p:" << fheContext.constraint.p << "}\n";
|
||||
mlir::zamalang::log_verbose()
|
||||
<< "### FHE parameters for the atomic pattern: {k: "
|
||||
<< fheContext.parameter.k
|
||||
<< ", polynomialSize: " << fheContext.parameter.polynomialSize
|
||||
<< ", nSmall: " << fheContext.parameter.nSmall
|
||||
<< ", brLevel: " << fheContext.parameter.brLevel
|
||||
<< ", brLogBase: " << fheContext.parameter.brLogBase
|
||||
<< ", ksLevel: " << fheContext.parameter.ksLevel
|
||||
<< ", ksLogBase: " << fheContext.parameter.ksLogBase << "}\n";
|
||||
|
||||
// Create the client parameters
|
||||
auto clientParameter = mlir::zamalang::createClientParametersForV0(
|
||||
fheContext, jitFuncName, module);
|
||||
|
||||
if (auto err = clientParameter.takeError()) {
|
||||
mlir::zamalang::log_error()
|
||||
<< "cannot generate client parameters: " << err << "\n";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
mlir::zamalang::log_verbose() << "### Generate the key set\n";
|
||||
|
||||
auto maybeKeySet = mlir::zamalang::KeySet::generate(clientParameter.get(), 0,
|
||||
0); // TODO: seed
|
||||
if (auto err = maybeKeySet.takeError()) {
|
||||
llvm::errs() << err;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
return std::move(maybeKeySet.get());
|
||||
}
|
||||
} // namespace cmdline
|
||||
|
||||
llvm::Expected<mlir::zamalang::V0FHEContext> buildFHEContext(
|
||||
llvm::Optional<mlir::zamalang::V0FHEConstraint> autoFHEConstraints,
|
||||
@@ -209,65 +149,48 @@ llvm::Expected<mlir::zamalang::V0FHEContext> buildFHEContext(
|
||||
llvm::Optional<size_t> overrideMaxMANP) {
|
||||
if (!autoFHEConstraints.hasValue() &&
|
||||
(!overrideMaxMANP.hasValue() || !overrideMaxEintPrecision.hasValue())) {
|
||||
return llvm::make_error<llvm::StringError>(
|
||||
return mlir::zamalang::StreamStringError(
|
||||
"Maximum encrypted integer precision and maximum for the Minimal"
|
||||
"Arithmetic Noise Passing are required, but were neither specified"
|
||||
"explicitly nor determined automatically",
|
||||
llvm::inconvertibleErrorCode());
|
||||
"explicitly nor determined automatically");
|
||||
}
|
||||
|
||||
mlir::zamalang::V0FHEConstraint fheConstraints{
|
||||
.norm2 = overrideMaxMANP.hasValue() ? overrideMaxMANP.getValue()
|
||||
: autoFHEConstraints.getValue().norm2,
|
||||
.p = overrideMaxEintPrecision.hasValue()
|
||||
? overrideMaxEintPrecision.getValue()
|
||||
: autoFHEConstraints.getValue().p};
|
||||
overrideMaxMANP.hasValue() ? overrideMaxMANP.getValue()
|
||||
: autoFHEConstraints.getValue().norm2,
|
||||
overrideMaxEintPrecision.hasValue() ? overrideMaxEintPrecision.getValue()
|
||||
: autoFHEConstraints.getValue().p};
|
||||
|
||||
const mlir::zamalang::V0Parameter *parameter = getV0Parameter(fheConstraints);
|
||||
|
||||
if (!parameter) {
|
||||
std::string buffer;
|
||||
llvm::raw_string_ostream strs(buffer);
|
||||
strs << "Could not determine V0 parameters for 2-norm of "
|
||||
<< fheConstraints.norm2 << " and p of " << fheConstraints.p;
|
||||
|
||||
return llvm::make_error<llvm::StringError>(strs.str(),
|
||||
llvm::inconvertibleErrorCode());
|
||||
return mlir::zamalang::StreamStringError()
|
||||
<< "Could not determine V0 parameters for 2-norm of "
|
||||
<< fheConstraints.norm2 << " and p of " << fheConstraints.p;
|
||||
}
|
||||
|
||||
return mlir::zamalang::V0FHEContext{fheConstraints, *parameter};
|
||||
}
|
||||
|
||||
mlir::LogicalResult buildAssignFHEContext(
|
||||
llvm::Optional<mlir::zamalang::V0FHEContext> &fheContext,
|
||||
llvm::Optional<mlir::zamalang::V0FHEConstraint> autoFHEConstraints,
|
||||
llvm::Optional<size_t> overrideMaxEintPrecision,
|
||||
llvm::Optional<size_t> overrideMaxMANP) {
|
||||
namespace llvm {
|
||||
// This needs to be wrapped into the llvm namespace for proper
|
||||
// operator lookup
|
||||
llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
|
||||
const llvm::ArrayRef<uint64_t> arr) {
|
||||
os << "(";
|
||||
for (size_t i = 0; i < arr.size(); i++) {
|
||||
os << arr[i];
|
||||
|
||||
if (fheContext.hasValue())
|
||||
return mlir::success();
|
||||
|
||||
llvm::Expected<mlir::zamalang::V0FHEContext> fheContextOrErr =
|
||||
buildFHEContext(autoFHEConstraints, overrideMaxEintPrecision,
|
||||
overrideMaxMANP);
|
||||
|
||||
if (auto err = fheContextOrErr.takeError()) {
|
||||
mlir::zamalang::log_error() << err;
|
||||
return mlir::failure();
|
||||
if (i != arr.size() - 1)
|
||||
os << ", ";
|
||||
}
|
||||
|
||||
fheContext.emplace(fheContextOrErr.get());
|
||||
|
||||
return mlir::success();
|
||||
return os;
|
||||
}
|
||||
} // namespace llvm
|
||||
|
||||
// Process a single source buffer
|
||||
//
|
||||
// The parameter `entryDialect` must specify the FHE dialect to which
|
||||
// belong all FHE operations used in the source buffer. The input
|
||||
// program must only contain FHE operations from that single FHE
|
||||
// dialect, otherwise processing might fail.
|
||||
//
|
||||
// The parameter `action` specifies how the buffer should be processed
|
||||
// and thus defines the output.
|
||||
//
|
||||
@@ -276,15 +199,14 @@ mlir::LogicalResult buildAssignFHEContext(
|
||||
// using the parameters given in `jitArgs`.
|
||||
//
|
||||
// The parameter `parametrizeMidLFHE` defines, whether the
|
||||
// parametrization pass for MidLFHE is executed. If the pair of
|
||||
// `entryDialect` and `action` does not involve any MidlFHE
|
||||
// manipulation, this parameter does not have any effect.
|
||||
// parametrization pass for MidLFHE is executed. If the `action` does
|
||||
// not involve any MidlFHE manipulation, this parameter does not have
|
||||
// any effect.
|
||||
//
|
||||
// The parameters `overrideMaxEintPrecision` and `overrideMaxMANP`, if
|
||||
// set, override the values for the maximum required precision of
|
||||
// encrypted integers and the maximum value for the Minimum Arithmetic
|
||||
// Noise Padding otherwise determined automatically if the entry
|
||||
// dialect is HLFHE..
|
||||
// Noise Padding otherwise determined automatically.
|
||||
//
|
||||
// If `verifyDiagnostics` is `true`, the procedure only checks if the
|
||||
// diagnostic messages provided in the source buffer using
|
||||
@@ -292,164 +214,103 @@ mlir::LogicalResult buildAssignFHEContext(
|
||||
// the procedure checks if the parsed module is valid and if all
|
||||
// requested transformations succeeded.
|
||||
//
|
||||
// If `verbose` is true, debug messages are displayed throughout the
|
||||
// compilation process.
|
||||
//
|
||||
// Compilation output is written to the stream specified by `os`.
|
||||
mlir::LogicalResult processInputBuffer(
|
||||
mlir::MLIRContext &context, std::unique_ptr<llvm::MemoryBuffer> buffer,
|
||||
enum EntryDialect entryDialect, enum Action action,
|
||||
const std::string &jitFuncName, llvm::ArrayRef<uint64_t> jitArgs,
|
||||
bool parametrizeMidlHFE, llvm::Optional<size_t> overrideMaxEintPrecision,
|
||||
llvm::Optional<size_t> overrideMaxMANP, bool verifyDiagnostics,
|
||||
bool verbose, llvm::raw_ostream &os) {
|
||||
llvm::SourceMgr sourceMgr;
|
||||
sourceMgr.AddNewSourceBuffer(std::move(buffer), llvm::SMLoc());
|
||||
mlir::LogicalResult
|
||||
processInputBuffer(std::unique_ptr<llvm::MemoryBuffer> buffer,
|
||||
enum Action action, const std::string &jitFuncName,
|
||||
llvm::ArrayRef<uint64_t> jitArgs, bool parametrizeMidlHFE,
|
||||
llvm::Optional<size_t> overrideMaxEintPrecision,
|
||||
llvm::Optional<size_t> overrideMaxMANP,
|
||||
bool verifyDiagnostics, llvm::raw_ostream &os) {
|
||||
std::shared_ptr<mlir::zamalang::CompilationContext> ccx =
|
||||
mlir::zamalang::CompilationContext::createShared();
|
||||
|
||||
mlir::SourceMgrDiagnosticVerifierHandler sourceMgrHandler(sourceMgr,
|
||||
&context);
|
||||
mlir::OwningModuleRef moduleRef = mlir::parseSourceFile(sourceMgr, &context);
|
||||
mlir::zamalang::JitCompilerEngine ce{ccx};
|
||||
|
||||
llvm::Optional<mlir::zamalang::V0FHEConstraint> fheConstraints;
|
||||
llvm::Optional<mlir::zamalang::V0FHEContext> fheContext;
|
||||
ce.setVerifyDiagnostics(verifyDiagnostics);
|
||||
ce.setParametrizeMidLFHE(parametrizeMidlHFE);
|
||||
|
||||
std::unique_ptr<mlir::zamalang::KeySet> keySet = nullptr;
|
||||
if (overrideMaxEintPrecision.hasValue())
|
||||
ce.setMaxEintPrecision(overrideMaxEintPrecision.getValue());
|
||||
|
||||
if (verbose)
|
||||
context.disableMultithreading();
|
||||
if (overrideMaxMANP.hasValue())
|
||||
ce.setMaxMANP(overrideMaxMANP.getValue());
|
||||
|
||||
if (verifyDiagnostics)
|
||||
return sourceMgrHandler.verify();
|
||||
if (action == Action::JIT_INVOKE) {
|
||||
llvm::Expected<mlir::zamalang::JitCompilerEngine::Lambda> lambdaOrErr =
|
||||
ce.buildLambda(std::move(buffer), jitFuncName);
|
||||
|
||||
if (!moduleRef)
|
||||
return mlir::failure();
|
||||
|
||||
mlir::ModuleOp module = moduleRef.get();
|
||||
|
||||
if (action == Action::ROUND_TRIP) {
|
||||
module->print(os);
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
// Lowering pipeline. Each stage is represented as a label in the
|
||||
// switch statement, from the most abstract dialect to the lowest
|
||||
// level. Every labels acts as an entry point into the pipeline with
|
||||
// a fallthrough mechanism to the next stage. Actions act as exit
|
||||
// points from the pipeline.
|
||||
switch (entryDialect) {
|
||||
case EntryDialect::HLFHE:
|
||||
if (action == Action::DUMP_HLFHE_MANP) {
|
||||
if (mlir::zamalang::pipeline::invokeMANPPass(context, module, false)
|
||||
.failed()) {
|
||||
return mlir::failure();
|
||||
}
|
||||
|
||||
module.print(os);
|
||||
return mlir::success();
|
||||
} else {
|
||||
llvm::Expected<llvm::Optional<mlir::zamalang::V0FHEConstraint>>
|
||||
fheConstraintsOrErr =
|
||||
mlir::zamalang::pipeline::getFHEConstraintsFromHLFHE(context,
|
||||
module);
|
||||
if (auto err = fheConstraintsOrErr.takeError()) {
|
||||
mlir::zamalang::log_error() << err;
|
||||
return mlir::failure();
|
||||
} else {
|
||||
fheConstraints = fheConstraintsOrErr.get();
|
||||
}
|
||||
}
|
||||
|
||||
if (mlir::zamalang::pipeline::lowerHLFHEToMidLFHE(context, module, verbose)
|
||||
.failed())
|
||||
return mlir::failure();
|
||||
|
||||
// fallthrough
|
||||
case EntryDialect::MIDLFHE:
|
||||
if (action == Action::DUMP_MIDLFHE) {
|
||||
module.print(os);
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
if (buildAssignFHEContext(fheContext, fheConstraints,
|
||||
overrideMaxEintPrecision, overrideMaxMANP)
|
||||
.failed()) {
|
||||
return mlir::failure();
|
||||
}
|
||||
|
||||
if (mlir::zamalang::pipeline::lowerMidLFHEToLowLFHE(
|
||||
context, module, fheContext.getValue(), parametrizeMidlHFE)
|
||||
.failed())
|
||||
return mlir::failure();
|
||||
|
||||
// fallthrough
|
||||
case EntryDialect::LOWLFHE:
|
||||
if (action == Action::DUMP_LOWLFHE) {
|
||||
module.print(os);
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
if (mlir::zamalang::pipeline::lowerLowLFHEToStd(context, module).failed())
|
||||
return mlir::failure();
|
||||
|
||||
// fallthrough
|
||||
case EntryDialect::STD:
|
||||
if (action == Action::DUMP_STD) {
|
||||
module.print(os);
|
||||
return mlir::success();
|
||||
} else if (action == Action::JIT_INVOKE) {
|
||||
if (buildAssignFHEContext(fheContext, fheConstraints,
|
||||
overrideMaxEintPrecision, overrideMaxMANP)
|
||||
.failed()) {
|
||||
return mlir::failure();
|
||||
}
|
||||
|
||||
keySet = generateKeySet(module, fheContext.getValue(), jitFuncName);
|
||||
}
|
||||
|
||||
if (mlir::zamalang::pipeline::lowerStdToLLVMDialect(context, module,
|
||||
verbose)
|
||||
.failed())
|
||||
return mlir::failure();
|
||||
|
||||
// fallthrough
|
||||
case EntryDialect::LLVM: {
|
||||
if (action == Action::DUMP_LLVM_DIALECT) {
|
||||
module.print(os);
|
||||
return mlir::success();
|
||||
} else if (action == Action::JIT_INVOKE) {
|
||||
return mlir::zamalang::runJit(module, jitFuncName, jitArgs, *keySet,
|
||||
defaultOptPipeline, os);
|
||||
}
|
||||
|
||||
llvm::LLVMContext llvmContext;
|
||||
std::unique_ptr<llvm::Module> llvmModule =
|
||||
mlir::zamalang::pipeline::lowerLLVMDialectToLLVMIR(context, llvmContext,
|
||||
module);
|
||||
|
||||
if (!llvmModule) {
|
||||
if (!lambdaOrErr) {
|
||||
mlir::zamalang::log_error()
|
||||
<< "Failed to translate LLVM dialect to LLVM IR\n";
|
||||
<< "Failed to JIT-compile " << jitFuncName << ": "
|
||||
<< llvm::toString(std::move(lambdaOrErr.takeError()));
|
||||
return mlir::failure();
|
||||
}
|
||||
|
||||
if (action == Action::DUMP_LLVM_IR) {
|
||||
llvmModule->dump();
|
||||
return mlir::success();
|
||||
}
|
||||
llvm::Expected<uint64_t> resOrErr = (*lambdaOrErr)(jitArgs);
|
||||
|
||||
if (mlir::zamalang::pipeline::optimizeLLVMModule(llvmContext, *llvmModule)
|
||||
.failed()) {
|
||||
mlir::zamalang::log_error() << "Failed to optimize LLVM IR\n";
|
||||
if (!resOrErr) {
|
||||
mlir::zamalang::log_error()
|
||||
<< "Failed to JIT-invoke " << jitFuncName << " with arguments "
|
||||
<< jitArgs << ": " << llvm::toString(std::move(resOrErr.takeError()));
|
||||
return mlir::failure();
|
||||
}
|
||||
|
||||
if (action == Action::DUMP_OPTIMIZED_LLVM_IR) {
|
||||
llvmModule->dump();
|
||||
return mlir::success();
|
||||
os << *resOrErr << "\n";
|
||||
} else {
|
||||
enum mlir::zamalang::CompilerEngine::Target target;
|
||||
|
||||
switch (action) {
|
||||
case Action::ROUND_TRIP:
|
||||
target = mlir::zamalang::CompilerEngine::Target::ROUND_TRIP;
|
||||
break;
|
||||
case Action::DUMP_HLFHE:
|
||||
target = mlir::zamalang::CompilerEngine::Target::HLFHE;
|
||||
break;
|
||||
case Action::DUMP_HLFHE_MANP:
|
||||
target = mlir::zamalang::CompilerEngine::Target::HLFHE_MANP;
|
||||
break;
|
||||
case Action::DUMP_MIDLFHE:
|
||||
target = mlir::zamalang::CompilerEngine::Target::MIDLFHE;
|
||||
break;
|
||||
case Action::DUMP_LOWLFHE:
|
||||
target = mlir::zamalang::CompilerEngine::Target::LOWLFHE;
|
||||
break;
|
||||
case Action::DUMP_STD:
|
||||
target = mlir::zamalang::CompilerEngine::Target::STD;
|
||||
break;
|
||||
case Action::DUMP_LLVM_DIALECT:
|
||||
target = mlir::zamalang::CompilerEngine::Target::LLVM;
|
||||
break;
|
||||
case Action::DUMP_LLVM_IR:
|
||||
target = mlir::zamalang::CompilerEngine::Target::LLVM_IR;
|
||||
break;
|
||||
case Action::DUMP_OPTIMIZED_LLVM_IR:
|
||||
target = mlir::zamalang::CompilerEngine::Target::OPTIMIZED_LLVM_IR;
|
||||
break;
|
||||
case JIT_INVOKE:
|
||||
// Case just here to satisfy the compiler; already handled above
|
||||
break;
|
||||
}
|
||||
|
||||
break;
|
||||
}
|
||||
llvm::Expected<mlir::zamalang::CompilerEngine::CompilationResult> retOrErr =
|
||||
ce.compile(std::move(buffer), target);
|
||||
|
||||
if (!retOrErr) {
|
||||
mlir::zamalang::log_error()
|
||||
<< llvm::toString(std::move(retOrErr.takeError())) << "\n";
|
||||
|
||||
return mlir::failure();
|
||||
}
|
||||
|
||||
if (verifyDiagnostics) {
|
||||
return mlir::success();
|
||||
} else if (action == Action::DUMP_LLVM_IR ||
|
||||
action == Action::DUMP_OPTIMIZED_LLVM_IR) {
|
||||
retOrErr->llvmModule->print(os, nullptr);
|
||||
} else {
|
||||
retOrErr->mlirModuleRef->get().print(os);
|
||||
}
|
||||
}
|
||||
|
||||
return mlir::success();
|
||||
@@ -459,44 +320,11 @@ mlir::LogicalResult compilerMain(int argc, char **argv) {
|
||||
// Parse command line arguments
|
||||
llvm::cl::ParseCommandLineOptions(argc, argv);
|
||||
|
||||
// Initialize the MLIR context
|
||||
mlir::MLIRContext context;
|
||||
|
||||
mlir::zamalang::setupLogging(cmdline::verbose);
|
||||
|
||||
// String for error messages from library functions
|
||||
std::string errorMessage;
|
||||
|
||||
if (cmdline::action == Action::DUMP_HLFHE_MANP &&
|
||||
cmdline::entryDialect != EntryDialect::HLFHE) {
|
||||
mlir::zamalang::log_error()
|
||||
<< "Can only invoke Minimal Arithmetic Noise pass on HLFHE programs";
|
||||
return mlir::failure();
|
||||
}
|
||||
|
||||
if (cmdline::action == Action::JIT_INVOKE &&
|
||||
cmdline::entryDialect != EntryDialect::HLFHE &&
|
||||
cmdline::entryDialect != EntryDialect::MIDLFHE &&
|
||||
cmdline::entryDialect != EntryDialect::LOWLFHE &&
|
||||
cmdline::entryDialect != EntryDialect::STD) {
|
||||
mlir::zamalang::log_error()
|
||||
<< "Can only JIT invoke HLFHE / MidLFHE / LowLFHE / STD programs";
|
||||
return mlir::failure();
|
||||
}
|
||||
|
||||
// Load our Dialect in this MLIR Context.
|
||||
context.getOrLoadDialect<mlir::zamalang::HLFHE::HLFHEDialect>();
|
||||
context.getOrLoadDialect<mlir::zamalang::MidLFHE::MidLFHEDialect>();
|
||||
context.getOrLoadDialect<mlir::zamalang::LowLFHE::LowLFHEDialect>();
|
||||
context.getOrLoadDialect<mlir::StandardOpsDialect>();
|
||||
context.getOrLoadDialect<mlir::memref::MemRefDialect>();
|
||||
context.getOrLoadDialect<mlir::linalg::LinalgDialect>();
|
||||
context.getOrLoadDialect<mlir::tensor::TensorDialect>();
|
||||
context.getOrLoadDialect<mlir::LLVM::LLVMDialect>();
|
||||
|
||||
if (cmdline::verifyDiagnostics)
|
||||
context.printOpOnDiagnostic(false);
|
||||
|
||||
auto output = mlir::openOutputFile(cmdline::output, &errorMessage);
|
||||
|
||||
if (!output) {
|
||||
@@ -523,20 +351,20 @@ mlir::LogicalResult compilerMain(int argc, char **argv) {
|
||||
[&](std::unique_ptr<llvm::MemoryBuffer> inputBuffer,
|
||||
llvm::raw_ostream &os) {
|
||||
return processInputBuffer(
|
||||
context, std::move(inputBuffer), cmdline::entryDialect,
|
||||
cmdline::action, cmdline::jitFuncName, cmdline::jitArgs,
|
||||
std::move(inputBuffer), cmdline::action,
|
||||
cmdline::jitFuncName, cmdline::jitArgs,
|
||||
cmdline::parametrizeMidLFHE,
|
||||
cmdline::assumeMaxEintPrecision, cmdline::assumeMaxMANP,
|
||||
cmdline::verifyDiagnostics, cmdline::verbose, os);
|
||||
cmdline::verifyDiagnostics, os);
|
||||
},
|
||||
output->os())))
|
||||
return mlir::failure();
|
||||
} else {
|
||||
return processInputBuffer(
|
||||
context, std::move(file), cmdline::entryDialect, cmdline::action,
|
||||
cmdline::jitFuncName, cmdline::jitArgs, cmdline::parametrizeMidLFHE,
|
||||
std::move(file), cmdline::action, cmdline::jitFuncName,
|
||||
cmdline::jitArgs, cmdline::parametrizeMidLFHE,
|
||||
cmdline::assumeMaxEintPrecision, cmdline::assumeMaxMANP,
|
||||
cmdline::verifyDiagnostics, cmdline::verbose, output->os());
|
||||
cmdline::verifyDiagnostics, output->os());
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user