mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 20:25:34 -05:00
feat(compiler): Add passes to lower mlir to mlir llvm ir and run jit and emit llvm code (#63)
This commit is contained in:
committed by
Ayoub Benaissa
parent
c58abe6565
commit
b4e57984b1
@@ -5,18 +5,19 @@
|
||||
#include <llvm/Support/ToolOutputFile.h>
|
||||
#include <mlir/Dialect/Linalg/IR/LinalgOps.h>
|
||||
#include <mlir/Dialect/MemRef/IR/MemRef.h>
|
||||
#include <mlir/ExecutionEngine/OptUtils.h>
|
||||
#include <mlir/Parser.h>
|
||||
#include <mlir/Pass/PassManager.h>
|
||||
#include <mlir/Support/FileUtilities.h>
|
||||
#include <mlir/Support/LogicalResult.h>
|
||||
#include <mlir/Support/ToolUtilities.h>
|
||||
|
||||
#include "zamalang/Conversion/HLFHETensorOpsToLinalg/Pass.h"
|
||||
#include "zamalang/Conversion/HLFHEToMidLFHE/Pass.h"
|
||||
#include "zamalang/Conversion/Passes.h"
|
||||
#include "zamalang/Dialect/HLFHE/IR/HLFHEDialect.h"
|
||||
#include "zamalang/Dialect/HLFHE/IR/HLFHETypes.h"
|
||||
#include "zamalang/Dialect/MidLFHE/IR/MidLFHEDialect.h"
|
||||
#include "zamalang/Dialect/MidLFHE/IR/MidLFHETypes.h"
|
||||
#include "zamalang/Support/CompilerTools.h"
|
||||
|
||||
namespace cmdline {
|
||||
|
||||
@@ -49,24 +50,58 @@ llvm::cl::opt<bool> splitInputFile(
|
||||
llvm::cl::desc("Split the input file into pieces and process each "
|
||||
"chunk independently"),
|
||||
llvm::cl::init(false));
|
||||
|
||||
llvm::cl::opt<bool> runJit("run-jit", llvm::cl::desc("JIT the code and run it"),
|
||||
llvm::cl::init<bool>(false));
|
||||
llvm::cl::list<int>
|
||||
jitArgs("jit-args",
|
||||
llvm::cl::desc("Value of arguments to pass to the main func"),
|
||||
llvm::cl::value_desc("passname"), llvm::cl::ZeroOrMore);
|
||||
|
||||
llvm::cl::opt<bool> toLLVM("to-llvm", llvm::cl::desc("Compile to llvm and "),
|
||||
llvm::cl::init<bool>(false));
|
||||
}; // namespace cmdline
|
||||
|
||||
void addPassCmdLineFiltered(mlir::PassManager &pm,
|
||||
std::unique_ptr<mlir::Pass> pass) {
|
||||
if (cmdline::roundTrip)
|
||||
return;
|
||||
auto passName = pass->getName();
|
||||
if (cmdline::passes.size() == 0 ||
|
||||
std::any_of(
|
||||
cmdline::passes.begin(), cmdline::passes.end(),
|
||||
[&](const std::string &p) { return pass->getArgument() == p; })) {
|
||||
if (*pass->getOpName() == "module") {
|
||||
pm.addPass(std::move(pass));
|
||||
} else {
|
||||
pm.nest(*pass->getOpName()).addPass(std::move(pass));
|
||||
}
|
||||
auto defaultOptPipeline = mlir::makeOptimizingTransformer(3, 0, nullptr);
|
||||
|
||||
mlir::LogicalResult dumpLLVMIR(mlir::ModuleOp module, llvm::raw_ostream &os) {
|
||||
llvm::LLVMContext context;
|
||||
auto llvmModule = mlir::zamalang::CompilerTools::toLLVMModule(
|
||||
context, module, defaultOptPipeline);
|
||||
if (!llvmModule) {
|
||||
return mlir::failure();
|
||||
}
|
||||
return;
|
||||
os << **llvmModule;
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
mlir::LogicalResult runJit(mlir::ModuleOp module, llvm::raw_ostream &os) {
|
||||
// Create the JIT lambda
|
||||
auto maybeLambda =
|
||||
mlir::zamalang::JITLambda::create("main", module, defaultOptPipeline);
|
||||
if (!maybeLambda) {
|
||||
return mlir::failure();
|
||||
}
|
||||
auto lambda = maybeLambda.get().get();
|
||||
|
||||
// Create buffer to copy argument
|
||||
std::vector<int64_t> dummy(cmdline::jitArgs.size());
|
||||
llvm::SmallVector<void *> llvmArgs;
|
||||
for (auto i = 0; i < cmdline::jitArgs.size(); i++) {
|
||||
dummy[i] = cmdline::jitArgs[i];
|
||||
llvmArgs.push_back(&dummy[i]);
|
||||
}
|
||||
// Add the result pointer
|
||||
uint64_t res = 0;
|
||||
llvmArgs.push_back(&res);
|
||||
|
||||
// Invoke the lambda
|
||||
if (lambda->invokeRaw(llvmArgs)) {
|
||||
return mlir::failure();
|
||||
}
|
||||
|
||||
std::cerr << res << "\n";
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
// Process a single source buffer
|
||||
@@ -82,14 +117,11 @@ mlir::LogicalResult
|
||||
processInputBuffer(mlir::MLIRContext &context,
|
||||
std::unique_ptr<llvm::MemoryBuffer> buffer,
|
||||
llvm::raw_ostream &os, bool verifyDiagnostics) {
|
||||
mlir::PassManager pm(&context);
|
||||
|
||||
llvm::SourceMgr sourceMgr;
|
||||
sourceMgr.AddNewSourceBuffer(std::move(buffer), llvm::SMLoc());
|
||||
|
||||
mlir::SourceMgrDiagnosticVerifierHandler sourceMgrHandler(sourceMgr,
|
||||
&context);
|
||||
|
||||
auto module = mlir::parseSourceFile(sourceMgr, &context);
|
||||
|
||||
if (verifyDiagnostics)
|
||||
@@ -98,17 +130,30 @@ processInputBuffer(mlir::MLIRContext &context,
|
||||
if (!module)
|
||||
return mlir::failure();
|
||||
|
||||
addPassCmdLineFiltered(pm,
|
||||
mlir::zamalang::createConvertHLFHETensorOpsToLinalg());
|
||||
addPassCmdLineFiltered(pm, mlir::zamalang::createConvertHLFHEToMidLFHEPass());
|
||||
if (cmdline::roundTrip) {
|
||||
module->print(os);
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
if (pm.run(*module).failed()) {
|
||||
llvm::errs() << "Could not run passes!\n";
|
||||
if (mlir::zamalang::CompilerTools::lowerHLFHEToMlirLLVMDialect(
|
||||
context, *module,
|
||||
[](std::string passName) {
|
||||
return cmdline::passes.size() == 0 ||
|
||||
std::any_of(
|
||||
cmdline::passes.begin(), cmdline::passes.end(),
|
||||
[&](const std::string &p) { return passName == p; });
|
||||
})
|
||||
.failed()) {
|
||||
return mlir::failure();
|
||||
}
|
||||
|
||||
if (cmdline::runJit) {
|
||||
return runJit(module.get(), os);
|
||||
}
|
||||
if (cmdline::toLLVM) {
|
||||
return dumpLLVMIR(module.get(), os);
|
||||
}
|
||||
module->print(os);
|
||||
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user