mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-10 12:44:57 -05:00
refactor(compiler): Move JIT functionality to separate source file
This commit is contained in:
committed by
Quentin Bourgerie
parent
b9e2690823
commit
6a76177a47
19
compiler/include/zamalang/Support/Jit.h
Normal file
19
compiler/include/zamalang/Support/Jit.h
Normal file
@@ -0,0 +1,19 @@
|
||||
#ifndef COMPILER_JIT_H
|
||||
#define COMPILER_JIT_H
|
||||
|
||||
#include "zamalang/Support/CompilerTools.h"
|
||||
|
||||
#include <mlir/IR/BuiltinOps.h>
|
||||
#include <mlir/Support/LogicalResult.h>
|
||||
|
||||
namespace mlir {
|
||||
namespace zamalang {
|
||||
mlir::LogicalResult
|
||||
runJit(mlir::ModuleOp module, llvm::StringRef func,
|
||||
llvm::ArrayRef<uint64_t> funcArgs, mlir::zamalang::KeySet &keySet,
|
||||
std::function<llvm::Error(llvm::Module *)> optPipeline,
|
||||
llvm::raw_ostream &os);
|
||||
} // namespace zamalang
|
||||
} // namespace mlir
|
||||
|
||||
#endif // COMPILER_JIT_H
|
||||
@@ -6,6 +6,7 @@ add_mlir_library(ZamalangSupport
|
||||
ClientParameters.cpp
|
||||
KeySet.cpp
|
||||
logging.cpp
|
||||
Jit.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${PROJECT_SOURCE_DIR}/include/zamalang/Support
|
||||
|
||||
58
compiler/lib/Support/Jit.cpp
Normal file
58
compiler/lib/Support/Jit.cpp
Normal file
@@ -0,0 +1,58 @@
|
||||
#include <llvm/ADT/ArrayRef.h>
|
||||
#include <llvm/ADT/SmallVector.h>
|
||||
#include <llvm/ADT/StringRef.h>
|
||||
|
||||
#include <zamalang/Support/Jit.h>
|
||||
#include <zamalang/Support/logging.h>
|
||||
|
||||
namespace mlir {
|
||||
namespace zamalang {
|
||||
|
||||
// JIT-compiles `module` invokes `func` with the arguments passed in
|
||||
// `jitArguments` and `keySet`
|
||||
mlir::LogicalResult
|
||||
runJit(mlir::ModuleOp module, llvm::StringRef func,
|
||||
llvm::ArrayRef<uint64_t> funcArgs, mlir::zamalang::KeySet &keySet,
|
||||
std::function<llvm::Error(llvm::Module *)> optPipeline,
|
||||
llvm::raw_ostream &os) {
|
||||
// Create the JIT lambda
|
||||
auto maybeLambda =
|
||||
mlir::zamalang::JITLambda::create(func, module, optPipeline);
|
||||
if (!maybeLambda) {
|
||||
return mlir::failure();
|
||||
}
|
||||
auto lambda = std::move(maybeLambda.get());
|
||||
|
||||
// Create the arguments of the JIT lambda
|
||||
auto maybeArguments = mlir::zamalang::JITLambda::Argument::create(keySet);
|
||||
if (auto err = maybeArguments.takeError()) {
|
||||
::mlir::zamalang::log_error()
|
||||
<< "Cannot create lambda arguments: " << err << "\n";
|
||||
return mlir::failure();
|
||||
}
|
||||
|
||||
// Set the arguments
|
||||
auto arguments = std::move(maybeArguments.get());
|
||||
for (size_t i = 0; i < funcArgs.size(); i++) {
|
||||
if (auto err = arguments->setArg(i, funcArgs[i])) {
|
||||
::mlir::zamalang::log_error()
|
||||
<< "Cannot push argument " << i << ": " << err << "\n";
|
||||
return mlir::failure();
|
||||
}
|
||||
}
|
||||
// Invoke the lambda
|
||||
if (auto err = lambda->invoke(*arguments)) {
|
||||
::mlir::zamalang::log_error() << "Cannot invoke : " << err << "\n";
|
||||
return mlir::failure();
|
||||
}
|
||||
uint64_t res = 0;
|
||||
if (auto err = arguments->getResult(0, res)) {
|
||||
::mlir::zamalang::log_error() << "Cannot get result : " << err << "\n";
|
||||
return mlir::failure();
|
||||
}
|
||||
llvm::errs() << res << "\n";
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
} // namespace zamalang
|
||||
} // namespace mlir
|
||||
@@ -21,6 +21,7 @@
|
||||
#include "zamalang/Dialect/MidLFHE/IR/MidLFHETypes.h"
|
||||
#include "zamalang/Support/CompilerTools.h"
|
||||
#include "zamalang/Support/logging.h"
|
||||
#include "zamalang/Support/Jit.h"
|
||||
|
||||
namespace cmdline {
|
||||
|
||||
@@ -79,7 +80,8 @@ llvm::cl::opt<bool> toLLVM("to-llvm", llvm::cl::desc("Compile to llvm and "),
|
||||
llvm::cl::init<bool>(false));
|
||||
}; // namespace cmdline
|
||||
|
||||
auto defaultOptPipeline = mlir::makeOptimizingTransformer(3, 0, nullptr);
|
||||
std::function<llvm::Error(llvm::Module *)> defaultOptPipeline =
|
||||
mlir::makeOptimizingTransformer(3, 0, nullptr);
|
||||
|
||||
mlir::LogicalResult dumpLLVMIR(mlir::ModuleOp module, llvm::raw_ostream &os) {
|
||||
llvm::LLVMContext context;
|
||||
@@ -92,48 +94,6 @@ mlir::LogicalResult dumpLLVMIR(mlir::ModuleOp module, llvm::raw_ostream &os) {
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
mlir::LogicalResult runJit(mlir::ModuleOp module,
|
||||
mlir::zamalang::KeySet &keySet,
|
||||
llvm::raw_ostream &os) {
|
||||
// Create the JIT lambda
|
||||
auto maybeLambda = mlir::zamalang::JITLambda::create(
|
||||
cmdline::jitFuncname, module, defaultOptPipeline);
|
||||
if (!maybeLambda) {
|
||||
return mlir::failure();
|
||||
}
|
||||
auto lambda = std::move(maybeLambda.get());
|
||||
|
||||
// Create the arguments of the JIT lambda
|
||||
auto maybeArguments = mlir::zamalang::JITLambda::Argument::create(keySet);
|
||||
if (auto err = maybeArguments.takeError()) {
|
||||
|
||||
mlir::zamalang::log_error()
|
||||
<< "Cannot create lambda arguments: " << err << "\n";
|
||||
return mlir::failure();
|
||||
}
|
||||
// Set the arguments
|
||||
auto arguments = std::move(maybeArguments.get());
|
||||
for (auto i = 0; i < cmdline::jitArgs.size(); i++) {
|
||||
if (auto err = arguments->setArg(i, cmdline::jitArgs[i])) {
|
||||
mlir::zamalang::log_error()
|
||||
<< "Cannot push argument " << i << ": " << err << "\n";
|
||||
return mlir::failure();
|
||||
}
|
||||
}
|
||||
// Invoke the lambda
|
||||
if (auto err = lambda->invoke(*arguments)) {
|
||||
mlir::zamalang::log_error() << "Cannot invoke : " << err << "\n";
|
||||
return mlir::failure();
|
||||
}
|
||||
uint64_t res = 0;
|
||||
if (auto err = arguments->getResult(0, res)) {
|
||||
mlir::zamalang::log_error() << "Cannot get result : " << err << "\n";
|
||||
return mlir::failure();
|
||||
}
|
||||
llvm::errs() << res << "\n";
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
// Process a single source buffer
|
||||
//
|
||||
// If `verifyDiagnostics` is `true`, the procedure only checks if the
|
||||
@@ -226,7 +186,9 @@ processInputBuffer(mlir::MLIRContext &context,
|
||||
|
||||
if (cmdline::runJit) {
|
||||
mlir::zamalang::log_verbose() << "### JIT compile & running\n";
|
||||
return runJit(module.get(), *keySet, os);
|
||||
return mlir::zamalang::runJit(module.get(), cmdline::jitFuncname,
|
||||
cmdline::jitArgs, *keySet,
|
||||
defaultOptPipeline, os);
|
||||
}
|
||||
if (cmdline::toLLVM) {
|
||||
return dumpLLVMIR(module.get(), os);
|
||||
|
||||
Reference in New Issue
Block a user