diff --git a/compiler/include/zamalang/Support/Jit.h b/compiler/include/zamalang/Support/Jit.h new file mode 100644 index 000000000..c72d67f82 --- /dev/null +++ b/compiler/include/zamalang/Support/Jit.h @@ -0,0 +1,19 @@ +#ifndef COMPILER_JIT_H +#define COMPILER_JIT_H + +#include "zamalang/Support/CompilerTools.h" + +#include +#include + +namespace mlir { +namespace zamalang { +mlir::LogicalResult +runJit(mlir::ModuleOp module, llvm::StringRef func, + llvm::ArrayRef funcArgs, mlir::zamalang::KeySet &keySet, + std::function optPipeline, + llvm::raw_ostream &os); +} // namespace zamalang +} // namespace mlir + +#endif // COMPILER_JIT_H diff --git a/compiler/lib/Support/CMakeLists.txt b/compiler/lib/Support/CMakeLists.txt index 55fc97e2c..61280fc1f 100644 --- a/compiler/lib/Support/CMakeLists.txt +++ b/compiler/lib/Support/CMakeLists.txt @@ -6,6 +6,7 @@ add_mlir_library(ZamalangSupport ClientParameters.cpp KeySet.cpp logging.cpp + Jit.cpp ADDITIONAL_HEADER_DIRS ${PROJECT_SOURCE_DIR}/include/zamalang/Support diff --git a/compiler/lib/Support/Jit.cpp b/compiler/lib/Support/Jit.cpp new file mode 100644 index 000000000..8c95d6a1d --- /dev/null +++ b/compiler/lib/Support/Jit.cpp @@ -0,0 +1,58 @@ +#include +#include +#include + +#include +#include + +namespace mlir { +namespace zamalang { + +// JIT-compiles `module` invokes `func` with the arguments passed in +// `jitArguments` and `keySet` +mlir::LogicalResult +runJit(mlir::ModuleOp module, llvm::StringRef func, + llvm::ArrayRef funcArgs, mlir::zamalang::KeySet &keySet, + std::function optPipeline, + llvm::raw_ostream &os) { + // Create the JIT lambda + auto maybeLambda = + mlir::zamalang::JITLambda::create(func, module, optPipeline); + if (!maybeLambda) { + return mlir::failure(); + } + auto lambda = std::move(maybeLambda.get()); + + // Create the arguments of the JIT lambda + auto maybeArguments = mlir::zamalang::JITLambda::Argument::create(keySet); + if (auto err = maybeArguments.takeError()) { + ::mlir::zamalang::log_error() + << "Cannot create lambda arguments: " << err << "\n"; + return mlir::failure(); + } + + // Set the arguments + auto arguments = std::move(maybeArguments.get()); + for (size_t i = 0; i < funcArgs.size(); i++) { + if (auto err = arguments->setArg(i, funcArgs[i])) { + ::mlir::zamalang::log_error() + << "Cannot push argument " << i << ": " << err << "\n"; + return mlir::failure(); + } + } + // Invoke the lambda + if (auto err = lambda->invoke(*arguments)) { + ::mlir::zamalang::log_error() << "Cannot invoke : " << err << "\n"; + return mlir::failure(); + } + uint64_t res = 0; + if (auto err = arguments->getResult(0, res)) { + ::mlir::zamalang::log_error() << "Cannot get result : " << err << "\n"; + return mlir::failure(); + } + llvm::errs() << res << "\n"; + return mlir::success(); +} + +} // namespace zamalang +} // namespace mlir diff --git a/compiler/src/main.cpp b/compiler/src/main.cpp index a6441a789..ec15e7b80 100644 --- a/compiler/src/main.cpp +++ b/compiler/src/main.cpp @@ -21,6 +21,7 @@ #include "zamalang/Dialect/MidLFHE/IR/MidLFHETypes.h" #include "zamalang/Support/CompilerTools.h" #include "zamalang/Support/logging.h" +#include "zamalang/Support/Jit.h" namespace cmdline { @@ -79,7 +80,8 @@ llvm::cl::opt toLLVM("to-llvm", llvm::cl::desc("Compile to llvm and "), llvm::cl::init(false)); }; // namespace cmdline -auto defaultOptPipeline = mlir::makeOptimizingTransformer(3, 0, nullptr); +std::function defaultOptPipeline = + mlir::makeOptimizingTransformer(3, 0, nullptr); mlir::LogicalResult dumpLLVMIR(mlir::ModuleOp module, llvm::raw_ostream &os) { llvm::LLVMContext context; @@ -92,48 +94,6 @@ mlir::LogicalResult dumpLLVMIR(mlir::ModuleOp module, llvm::raw_ostream &os) { return mlir::success(); } -mlir::LogicalResult runJit(mlir::ModuleOp module, - mlir::zamalang::KeySet &keySet, - llvm::raw_ostream &os) { - // Create the JIT lambda - auto maybeLambda = mlir::zamalang::JITLambda::create( - cmdline::jitFuncname, module, defaultOptPipeline); - if (!maybeLambda) { - return mlir::failure(); - } - auto lambda = std::move(maybeLambda.get()); - - // Create the arguments of the JIT lambda - auto maybeArguments = mlir::zamalang::JITLambda::Argument::create(keySet); - if (auto err = maybeArguments.takeError()) { - - mlir::zamalang::log_error() - << "Cannot create lambda arguments: " << err << "\n"; - return mlir::failure(); - } - // Set the arguments - auto arguments = std::move(maybeArguments.get()); - for (auto i = 0; i < cmdline::jitArgs.size(); i++) { - if (auto err = arguments->setArg(i, cmdline::jitArgs[i])) { - mlir::zamalang::log_error() - << "Cannot push argument " << i << ": " << err << "\n"; - return mlir::failure(); - } - } - // Invoke the lambda - if (auto err = lambda->invoke(*arguments)) { - mlir::zamalang::log_error() << "Cannot invoke : " << err << "\n"; - return mlir::failure(); - } - uint64_t res = 0; - if (auto err = arguments->getResult(0, res)) { - mlir::zamalang::log_error() << "Cannot get result : " << err << "\n"; - return mlir::failure(); - } - llvm::errs() << res << "\n"; - return mlir::success(); -} - // Process a single source buffer // // If `verifyDiagnostics` is `true`, the procedure only checks if the @@ -226,7 +186,9 @@ processInputBuffer(mlir::MLIRContext &context, if (cmdline::runJit) { mlir::zamalang::log_verbose() << "### JIT compile & running\n"; - return runJit(module.get(), *keySet, os); + return mlir::zamalang::runJit(module.get(), cmdline::jitFuncname, + cmdline::jitArgs, *keySet, + defaultOptPipeline, os); } if (cmdline::toLLVM) { return dumpLLVMIR(module.get(), os);