mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 20:25:34 -05:00
feat(compiler): Add passes to lower mlir to mlir llvm ir and run jit and emit llvm code (#63)
This commit is contained in:
committed by
Ayoub Benaissa
parent
c58abe6565
commit
b4e57984b1
@@ -0,0 +1,16 @@
|
||||
|
||||
#ifndef ZAMALANG_CONVERSION_MLIRLOWERABLEDIALECTSTOLLVM_PASS_H_
|
||||
#define ZAMALANG_CONVERSION_MLIRLOWERABLEDIALECTSTOLLVM_PASS_H_
|
||||
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
namespace mlir {
|
||||
template <typename T> class OperationPass;
|
||||
namespace zamalang {
|
||||
/// Create a pass to convert MLIR lowerable dialects to LLVM.
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
createConvertMLIRLowerableDialectsToLLVMPass();
|
||||
} // namespace zamalang
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
||||
@@ -6,6 +6,9 @@
|
||||
#include "mlir/Dialect/SCF/SCF.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
|
||||
#include "zamalang/Conversion/HLFHETensorOpsToLinalg/Pass.h"
|
||||
#include "zamalang/Conversion/HLFHEToMidLFHE/Pass.h"
|
||||
#include "zamalang/Conversion/MLIRLowerableDialectsToLLVM/Pass.h"
|
||||
#include "zamalang/Dialect/HLFHE/IR/HLFHEDialect.h"
|
||||
|
||||
#define GEN_PASS_CLASSES
|
||||
|
||||
@@ -17,4 +17,11 @@ def HLFHEToMidLFHE : Pass<"hlfhe-to-midlfhe", "mlir::ModuleOp"> {
|
||||
let dependentDialects = ["mlir::linalg::LinalgDialect"];
|
||||
}
|
||||
|
||||
def MLIRLowerableDialectsToLLVM : Pass<"mlir-lowerable-dialects-to-llvm", "mlir::ModuleOp"> {
|
||||
let summary = "Lowers operations from MLIR lowerable dialects to LLVM";
|
||||
let constructor = "mlir::zamalang::createConvertMLIRLowerableDialectsToLLVMPass()";
|
||||
let dependentDialects = ["mlir::StandardOpsDialect", "mlir::scf::SCFDialect", "mlir::LLVM::LLVMDialect"];
|
||||
let options = [];
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
56
compiler/include/zamalang/Support/CompilerTools.h
Normal file
56
compiler/include/zamalang/Support/CompilerTools.h
Normal file
@@ -0,0 +1,56 @@
|
||||
#ifndef ZAMALANG_SUPPORT_COMPILERTOOLS_H_
|
||||
#define ZAMALANG_SUPPORT_COMPILERTOOLS_H_
|
||||
|
||||
#include <mlir/Dialect/LLVMIR/LLVMTypes.h>
|
||||
#include <mlir/ExecutionEngine/ExecutionEngine.h>
|
||||
#include <mlir/Pass/PassManager.h>
|
||||
|
||||
namespace mlir {
|
||||
namespace zamalang {
|
||||
|
||||
class CompilerTools {
|
||||
public:
|
||||
/// lowerHLFHEToMlirLLVMDialect run all passes to lower FHE dialects to mlir
|
||||
/// LLVM dialect.
|
||||
static mlir::LogicalResult lowerHLFHEToMlirLLVMDialect(
|
||||
mlir::MLIRContext &context, mlir::Operation *module,
|
||||
llvm::function_ref<bool(std::string)> enablePass = [](std::string pass) {
|
||||
return true;
|
||||
});
|
||||
|
||||
static llvm::Expected<std::unique_ptr<llvm::Module>>
|
||||
toLLVMModule(llvm::LLVMContext &context, mlir::ModuleOp &module,
|
||||
llvm::function_ref<llvm::Error(llvm::Module *)> optPipeline);
|
||||
};
|
||||
|
||||
/// JITLambda is a tool to JIT compile an mlir module and to invoke a function
|
||||
/// of the module.
|
||||
class JITLambda {
|
||||
public:
|
||||
JITLambda(mlir::LLVM::LLVMFunctionType type, llvm::StringRef name)
|
||||
: type(type), name(name){};
|
||||
|
||||
/// create a JITLambda that point to the function name of the given module.
|
||||
static llvm::Expected<std::unique_ptr<JITLambda>>
|
||||
create(llvm::StringRef name, mlir::ModuleOp &module,
|
||||
llvm::function_ref<llvm::Error(llvm::Module *)> optPipeline);
|
||||
|
||||
/// invokeRaw execute the jit lambda with a lits of arguments, the last one is
|
||||
/// used to store the result of the computation.
|
||||
/// Example:
|
||||
/// uin64_t arg0 = 1;
|
||||
/// uin64_t res;
|
||||
/// llvm::SmallVector<void *> args{&arg1, &res};
|
||||
/// lambda.invokeRaw(args);
|
||||
llvm::Error invokeRaw(llvm::MutableArrayRef<void *> args);
|
||||
|
||||
private:
|
||||
mlir::LLVM::LLVMFunctionType type;
|
||||
llvm::StringRef name;
|
||||
std::unique_ptr<mlir::ExecutionEngine> engine;
|
||||
};
|
||||
|
||||
} // namespace zamalang
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
||||
@@ -1,5 +1,6 @@
|
||||
add_subdirectory(Dialect)
|
||||
add_subdirectory(Conversion)
|
||||
add_subdirectory(Support)
|
||||
|
||||
# CAPI needed only for python bindings
|
||||
if (ZAMALANG_BINDINGS_PYTHON_ENABLED)
|
||||
|
||||
@@ -1,2 +1,3 @@
|
||||
add_subdirectory(HLFHEToMidLFHE)
|
||||
add_subdirectory(HLFHETensorOpsToLinalg)
|
||||
add_subdirectory(HLFHETensorOpsToLinalg)
|
||||
add_subdirectory(MLIRLowerableDialectsToLLVM)
|
||||
|
||||
@@ -119,7 +119,7 @@ void HLFHETensorOpsToLinalg::runOnFunction() {
|
||||
|
||||
namespace mlir {
|
||||
namespace zamalang {
|
||||
std::unique_ptr<mlir::Pass> createConvertHLFHETensorOpsToLinalg() {
|
||||
std::unique_ptr<mlir::FunctionPass> createConvertHLFHETensorOpsToLinalg() {
|
||||
return std::make_unique<HLFHETensorOpsToLinalg>();
|
||||
}
|
||||
} // namespace zamalang
|
||||
|
||||
@@ -0,0 +1,21 @@
|
||||
get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
|
||||
get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS)
|
||||
|
||||
add_mlir_dialect_library(MLIRLowerableDialectsToLLVM
|
||||
MLIRLowerableDialectsToLLVM.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
|
||||
DEPENDS
|
||||
MLIRConversionPassIncGen
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
${dialect_libs}
|
||||
${conversion_libs}
|
||||
MLIRIR
|
||||
MLIRTransforms
|
||||
MLIRLLVMIR
|
||||
MLIRLLVMToLLVMIRTranslation
|
||||
MLIRMath)
|
||||
|
||||
target_link_libraries(MLIRLowerableDialectsToLLVM PUBLIC MLIRIR)
|
||||
@@ -0,0 +1,61 @@
|
||||
#include <iostream>
|
||||
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
#include "zamalang/Conversion/Passes.h"
|
||||
#include "zamalang/Dialect/HLFHE/IR/HLFHETypes.h"
|
||||
|
||||
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
|
||||
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
|
||||
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
|
||||
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
|
||||
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/SCF/SCF.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "llvm/ADT/Sequence.h"
|
||||
|
||||
namespace {
|
||||
struct MLIRLowerableDialectsToLLVMPass
|
||||
: public MLIRLowerableDialectsToLLVMBase<MLIRLowerableDialectsToLLVMPass> {
|
||||
void runOnOperation() final;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
void MLIRLowerableDialectsToLLVMPass::runOnOperation() {
|
||||
// Setup the conversion target. We reuse the LLVMConversionTarget that
|
||||
// legalize LLVM dialect.
|
||||
mlir::LLVMConversionTarget target(getContext());
|
||||
target.addLegalOp<mlir::ModuleOp>();
|
||||
|
||||
// Setup the LLVMTypeConverter (that converts `std` types to `llvm` types) and
|
||||
// add our types conversion to `llvm` compatible type.
|
||||
mlir::LLVMTypeConverter typeConverter(&getContext());
|
||||
|
||||
// Setup the set of the patterns rewriter. At this point we want to
|
||||
// convert the `scf` operations to `std` and `std` operations to `llvm`.
|
||||
mlir::RewritePatternSet patterns(&getContext());
|
||||
mlir::populateLoopToStdConversionPatterns(patterns);
|
||||
mlir::populateStdToLLVMConversionPatterns(typeConverter, patterns);
|
||||
|
||||
// Apply a `FullConversion` to `llvm`.
|
||||
auto module = getOperation();
|
||||
if (mlir::applyFullConversion(module, target, std::move(patterns)).failed()) {
|
||||
signalPassFailure();
|
||||
}
|
||||
}
|
||||
|
||||
namespace mlir {
|
||||
namespace zamalang {
|
||||
/// Create a pass for lowering operations the remaining mlir dialects
|
||||
/// operations, to the LLVM dialect for codegen.
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
createConvertMLIRLowerableDialectsToLLVMPass() {
|
||||
return std::make_unique<MLIRLowerableDialectsToLLVMPass>();
|
||||
}
|
||||
} // namespace zamalang
|
||||
} // namespace mlir
|
||||
16
compiler/lib/Support/CMakeLists.txt
Normal file
16
compiler/lib/Support/CMakeLists.txt
Normal file
@@ -0,0 +1,16 @@
|
||||
add_mlir_library(ZamalangSupport
|
||||
CompilerTools.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${PROJECT_SOURCE_DIR}/include/zamalang/Support
|
||||
|
||||
DEPENDS
|
||||
MLIRConversionPassIncGen
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
HLFHETensorOpsToLinalg
|
||||
HLFHEToMidLFHE
|
||||
MLIRLowerableDialectsToLLVM
|
||||
|
||||
MLIRExecutionEngine
|
||||
${LLVM_PTHREAD_LIB})
|
||||
115
compiler/lib/Support/CompilerTools.cpp
Normal file
115
compiler/lib/Support/CompilerTools.cpp
Normal file
@@ -0,0 +1,115 @@
|
||||
#include <llvm/Support/TargetSelect.h>
|
||||
#include <mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h>
|
||||
#include <mlir/Target/LLVMIR/Export.h>
|
||||
|
||||
#include "zamalang/Conversion/Passes.h"
|
||||
#include "zamalang/Support/CompilerTools.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace zamalang {
|
||||
|
||||
void initLLVMNativeTarget() {
|
||||
// Initialize LLVM targets.
|
||||
llvm::InitializeNativeTarget();
|
||||
llvm::InitializeNativeTargetAsmPrinter();
|
||||
}
|
||||
|
||||
void addFilteredPassToPassManager(
|
||||
mlir::PassManager &pm, std::unique_ptr<mlir::Pass> pass,
|
||||
llvm::function_ref<bool(std::string)> enablePass) {
|
||||
if (!enablePass(pass->getArgument().str())) {
|
||||
return;
|
||||
}
|
||||
if (*pass->getOpName() == "module") {
|
||||
pm.addPass(std::move(pass));
|
||||
} else {
|
||||
pm.nest(*pass->getOpName()).addPass(std::move(pass));
|
||||
}
|
||||
};
|
||||
|
||||
mlir::LogicalResult CompilerTools::lowerHLFHEToMlirLLVMDialect(
|
||||
mlir::MLIRContext &context, mlir::Operation *module,
|
||||
llvm::function_ref<bool(std::string)> enablePass) {
|
||||
mlir::PassManager pm(&context);
|
||||
|
||||
// Add all passes to lower from HLFHE to LLVM Dialect
|
||||
addFilteredPassToPassManager(
|
||||
pm, mlir::zamalang::createConvertHLFHETensorOpsToLinalg(), enablePass);
|
||||
addFilteredPassToPassManager(
|
||||
pm, mlir::zamalang::createConvertHLFHEToMidLFHEPass(), enablePass);
|
||||
addFilteredPassToPassManager(
|
||||
pm, mlir::zamalang::createConvertMLIRLowerableDialectsToLLVMPass(),
|
||||
enablePass);
|
||||
|
||||
// Run the passes
|
||||
if (pm.run(module).failed()) {
|
||||
return mlir::failure();
|
||||
}
|
||||
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
llvm::Expected<std::unique_ptr<llvm::Module>> CompilerTools::toLLVMModule(
|
||||
llvm::LLVMContext &context, mlir::ModuleOp &module,
|
||||
llvm::function_ref<llvm::Error(llvm::Module *)> optPipeline) {
|
||||
|
||||
initLLVMNativeTarget();
|
||||
mlir::registerLLVMDialectTranslation(*module->getContext());
|
||||
|
||||
auto llvmModule = mlir::translateModuleToLLVMIR(module, context);
|
||||
if (!llvmModule) {
|
||||
return llvm::make_error<llvm::StringError>(
|
||||
"failed to translate MLIR to LLVM IR", llvm::inconvertibleErrorCode());
|
||||
}
|
||||
|
||||
if (auto err = optPipeline(llvmModule.get())) {
|
||||
return llvm::make_error<llvm::StringError>("failed to optimize LLVM IR",
|
||||
llvm::inconvertibleErrorCode());
|
||||
}
|
||||
|
||||
return std::move(llvmModule);
|
||||
}
|
||||
|
||||
llvm::Expected<std::unique_ptr<JITLambda>>
|
||||
JITLambda::create(llvm::StringRef name, mlir::ModuleOp &module,
|
||||
llvm::function_ref<llvm::Error(llvm::Module *)> optPipeline) {
|
||||
|
||||
// Looking for the function
|
||||
auto rangeOps = module.getOps<mlir::LLVM::LLVMFuncOp>();
|
||||
auto funcOp = llvm::find_if(rangeOps, [&](mlir::LLVM::LLVMFuncOp op) {
|
||||
return op.getName() == name;
|
||||
});
|
||||
if (funcOp == rangeOps.end()) {
|
||||
return llvm::make_error<llvm::StringError>(
|
||||
"cannot find the function to JIT", llvm::inconvertibleErrorCode());
|
||||
}
|
||||
initLLVMNativeTarget();
|
||||
mlir::registerLLVMDialectTranslation(*module->getContext());
|
||||
|
||||
// Create an MLIR execution engine. The execution engine eagerly
|
||||
// JIT-compiles the module.
|
||||
auto maybeEngine = mlir::ExecutionEngine::create(
|
||||
module, /*llvmModuleBuilder=*/nullptr, optPipeline);
|
||||
if (!maybeEngine) {
|
||||
return llvm::make_error<llvm::StringError>(
|
||||
"failed to construct the MLIR ExecutionEngine",
|
||||
llvm::inconvertibleErrorCode());
|
||||
}
|
||||
auto &engine = maybeEngine.get();
|
||||
auto lambda = std::make_unique<JITLambda>((*funcOp).getType(), name);
|
||||
lambda->engine = std::move(engine);
|
||||
|
||||
return std::move(lambda);
|
||||
}
|
||||
|
||||
llvm::Error JITLambda::invokeRaw(llvm::MutableArrayRef<void *> args) {
|
||||
if (this->type.getNumParams() == args.size() - 1 /*For the result*/) {
|
||||
return this->engine->invokePacked(this->name, args);
|
||||
}
|
||||
return llvm::make_error<llvm::StringError>(
|
||||
"wrong number of argument when invoke the JIT lambda",
|
||||
llvm::inconvertibleErrorCode());
|
||||
}
|
||||
|
||||
} // namespace zamalang
|
||||
} // namespace mlir
|
||||
@@ -1,14 +1,24 @@
|
||||
add_llvm_tool(zamacompiler main.cpp)
|
||||
target_compile_options(zamacompiler PRIVATE -fexceptions)
|
||||
llvm_update_compile_flags(zamacompiler)
|
||||
|
||||
get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
|
||||
get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS)
|
||||
|
||||
target_link_libraries(zamacompiler
|
||||
PRIVATE
|
||||
${dialect_libs}
|
||||
${conversion_libs}
|
||||
|
||||
MLIRTransforms
|
||||
MidLFHEDialect
|
||||
HLFHEDialect
|
||||
|
||||
HLFHETensorOpsToLinalg
|
||||
HLFHEToMidLFHE
|
||||
MLIRIR
|
||||
MLIRLLVMIR
|
||||
MLIRLLVMToLLVMIRTranslation
|
||||
|
||||
ZamalangSupport
|
||||
)
|
||||
|
||||
mlir_check_all_link_libraries(zamacompiler)
|
||||
|
||||
@@ -5,18 +5,19 @@
|
||||
#include <llvm/Support/ToolOutputFile.h>
|
||||
#include <mlir/Dialect/Linalg/IR/LinalgOps.h>
|
||||
#include <mlir/Dialect/MemRef/IR/MemRef.h>
|
||||
#include <mlir/ExecutionEngine/OptUtils.h>
|
||||
#include <mlir/Parser.h>
|
||||
#include <mlir/Pass/PassManager.h>
|
||||
#include <mlir/Support/FileUtilities.h>
|
||||
#include <mlir/Support/LogicalResult.h>
|
||||
#include <mlir/Support/ToolUtilities.h>
|
||||
|
||||
#include "zamalang/Conversion/HLFHETensorOpsToLinalg/Pass.h"
|
||||
#include "zamalang/Conversion/HLFHEToMidLFHE/Pass.h"
|
||||
#include "zamalang/Conversion/Passes.h"
|
||||
#include "zamalang/Dialect/HLFHE/IR/HLFHEDialect.h"
|
||||
#include "zamalang/Dialect/HLFHE/IR/HLFHETypes.h"
|
||||
#include "zamalang/Dialect/MidLFHE/IR/MidLFHEDialect.h"
|
||||
#include "zamalang/Dialect/MidLFHE/IR/MidLFHETypes.h"
|
||||
#include "zamalang/Support/CompilerTools.h"
|
||||
|
||||
namespace cmdline {
|
||||
|
||||
@@ -49,24 +50,58 @@ llvm::cl::opt<bool> splitInputFile(
|
||||
llvm::cl::desc("Split the input file into pieces and process each "
|
||||
"chunk independently"),
|
||||
llvm::cl::init(false));
|
||||
|
||||
llvm::cl::opt<bool> runJit("run-jit", llvm::cl::desc("JIT the code and run it"),
|
||||
llvm::cl::init<bool>(false));
|
||||
llvm::cl::list<int>
|
||||
jitArgs("jit-args",
|
||||
llvm::cl::desc("Value of arguments to pass to the main func"),
|
||||
llvm::cl::value_desc("passname"), llvm::cl::ZeroOrMore);
|
||||
|
||||
llvm::cl::opt<bool> toLLVM("to-llvm", llvm::cl::desc("Compile to llvm and "),
|
||||
llvm::cl::init<bool>(false));
|
||||
}; // namespace cmdline
|
||||
|
||||
void addPassCmdLineFiltered(mlir::PassManager &pm,
|
||||
std::unique_ptr<mlir::Pass> pass) {
|
||||
if (cmdline::roundTrip)
|
||||
return;
|
||||
auto passName = pass->getName();
|
||||
if (cmdline::passes.size() == 0 ||
|
||||
std::any_of(
|
||||
cmdline::passes.begin(), cmdline::passes.end(),
|
||||
[&](const std::string &p) { return pass->getArgument() == p; })) {
|
||||
if (*pass->getOpName() == "module") {
|
||||
pm.addPass(std::move(pass));
|
||||
} else {
|
||||
pm.nest(*pass->getOpName()).addPass(std::move(pass));
|
||||
}
|
||||
auto defaultOptPipeline = mlir::makeOptimizingTransformer(3, 0, nullptr);
|
||||
|
||||
mlir::LogicalResult dumpLLVMIR(mlir::ModuleOp module, llvm::raw_ostream &os) {
|
||||
llvm::LLVMContext context;
|
||||
auto llvmModule = mlir::zamalang::CompilerTools::toLLVMModule(
|
||||
context, module, defaultOptPipeline);
|
||||
if (!llvmModule) {
|
||||
return mlir::failure();
|
||||
}
|
||||
return;
|
||||
os << **llvmModule;
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
mlir::LogicalResult runJit(mlir::ModuleOp module, llvm::raw_ostream &os) {
|
||||
// Create the JIT lambda
|
||||
auto maybeLambda =
|
||||
mlir::zamalang::JITLambda::create("main", module, defaultOptPipeline);
|
||||
if (!maybeLambda) {
|
||||
return mlir::failure();
|
||||
}
|
||||
auto lambda = maybeLambda.get().get();
|
||||
|
||||
// Create buffer to copy argument
|
||||
std::vector<int64_t> dummy(cmdline::jitArgs.size());
|
||||
llvm::SmallVector<void *> llvmArgs;
|
||||
for (auto i = 0; i < cmdline::jitArgs.size(); i++) {
|
||||
dummy[i] = cmdline::jitArgs[i];
|
||||
llvmArgs.push_back(&dummy[i]);
|
||||
}
|
||||
// Add the result pointer
|
||||
uint64_t res = 0;
|
||||
llvmArgs.push_back(&res);
|
||||
|
||||
// Invoke the lambda
|
||||
if (lambda->invokeRaw(llvmArgs)) {
|
||||
return mlir::failure();
|
||||
}
|
||||
|
||||
std::cerr << res << "\n";
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
// Process a single source buffer
|
||||
@@ -82,14 +117,11 @@ mlir::LogicalResult
|
||||
processInputBuffer(mlir::MLIRContext &context,
|
||||
std::unique_ptr<llvm::MemoryBuffer> buffer,
|
||||
llvm::raw_ostream &os, bool verifyDiagnostics) {
|
||||
mlir::PassManager pm(&context);
|
||||
|
||||
llvm::SourceMgr sourceMgr;
|
||||
sourceMgr.AddNewSourceBuffer(std::move(buffer), llvm::SMLoc());
|
||||
|
||||
mlir::SourceMgrDiagnosticVerifierHandler sourceMgrHandler(sourceMgr,
|
||||
&context);
|
||||
|
||||
auto module = mlir::parseSourceFile(sourceMgr, &context);
|
||||
|
||||
if (verifyDiagnostics)
|
||||
@@ -98,17 +130,30 @@ processInputBuffer(mlir::MLIRContext &context,
|
||||
if (!module)
|
||||
return mlir::failure();
|
||||
|
||||
addPassCmdLineFiltered(pm,
|
||||
mlir::zamalang::createConvertHLFHETensorOpsToLinalg());
|
||||
addPassCmdLineFiltered(pm, mlir::zamalang::createConvertHLFHEToMidLFHEPass());
|
||||
if (cmdline::roundTrip) {
|
||||
module->print(os);
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
if (pm.run(*module).failed()) {
|
||||
llvm::errs() << "Could not run passes!\n";
|
||||
if (mlir::zamalang::CompilerTools::lowerHLFHEToMlirLLVMDialect(
|
||||
context, *module,
|
||||
[](std::string passName) {
|
||||
return cmdline::passes.size() == 0 ||
|
||||
std::any_of(
|
||||
cmdline::passes.begin(), cmdline::passes.end(),
|
||||
[&](const std::string &p) { return passName == p; });
|
||||
})
|
||||
.failed()) {
|
||||
return mlir::failure();
|
||||
}
|
||||
|
||||
if (cmdline::runJit) {
|
||||
return runJit(module.get(), os);
|
||||
}
|
||||
if (cmdline::toLLVM) {
|
||||
return dumpLLVMIR(module.get(), os);
|
||||
}
|
||||
module->print(os);
|
||||
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user