This was already implemented for JIT using mlir::ExecutionEngine, but library compilation and execution used a different, more complex path, which produced a bad calling convention at the assembly level on macOS M1 machines. This commit unifies the invocation of JIT- and library-compiled circuits, fixing the issue above, and also makes it possible to extend compiled libraries to support more than one returned value.
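For context, a minimal end-to-end sketch of that unified path through the JITLambda API defined below. This is illustrative only: `module` is assumed to be an MLIR module already lowered to the LLVM dialect, `publicArgs`/`evalKeys` to come from the Concrete client library, and "main" to be the circuit's entry function; mlir::makeOptimizingTransformer is the stock MLIR helper from mlir/ExecutionEngine/OptUtils.h.

    // Sketch only; names `module`, `publicArgs`, `evalKeys` are assumptions.
    auto lambdaOrErr = mlir::concretelang::JITLambda::create(
        "main", module, mlir::makeOptimizingTransformer(3, 0, nullptr),
        /*runtimeLibPath=*/std::nullopt);
    if (!lambdaOrErr)
      llvm::report_fatal_error(lambdaOrErr.takeError());
    auto resultOrErr = (*lambdaOrErr)->call(publicArgs, evalKeys);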
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.

#include "llvm/Support/Error.h"
#include <llvm/ADT/ArrayRef.h>
#include <llvm/ADT/SmallVector.h>
#include <llvm/ADT/StringRef.h>
#include <llvm/Support/TargetSelect.h>

#include <mlir/Dialect/LLVMIR/LLVMDialect.h>
#include <mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h>

#include "concretelang/Common/BitsSize.h"
#include "concretelang/Runtime/DFRuntime.hpp"
#include "concretelang/Support/Error.h"
#include "concretelang/Support/Jit.h"
#include "concretelang/Support/logging.h"
#include <concretelang/Support/Utils.h>

namespace mlir {
namespace concretelang {

llvm::Expected<std::unique_ptr<JITLambda>>
JITLambda::create(llvm::StringRef name, mlir::ModuleOp &module,
                  llvm::function_ref<llvm::Error(llvm::Module *)> optPipeline,
                  std::optional<std::string> runtimeLibPath) {

  // Looking for the function
  auto rangeOps = module.getOps<mlir::LLVM::LLVMFuncOp>();
  auto funcOp = llvm::find_if(rangeOps, [&](mlir::LLVM::LLVMFuncOp op) {
    return op.getName() == name;
  });
  if (funcOp == rangeOps.end()) {
    return llvm::make_error<llvm::StringError>(
        "cannot find the function to JIT", llvm::inconvertibleErrorCode());
  }

  llvm::InitializeNativeTarget();
  llvm::InitializeNativeTargetAsmPrinter();

  mlir::registerLLVMDialectTranslation(*module->getContext());

  // Create an MLIR execution engine. The execution engine eagerly
  // JIT-compiles the module. If runtimeLibPath is specified, it's passed as a
  // shared library to the JIT compiler.
  std::vector<llvm::StringRef> sharedLibPaths;
  if (runtimeLibPath.has_value())
    sharedLibPaths.push_back(runtimeLibPath.value());

  mlir::ExecutionEngineOptions execOptions;
  execOptions.transformer = optPipeline;
  execOptions.sharedLibPaths = sharedLibPaths;
  execOptions.jitCodeGenOptLevel = std::nullopt;
  execOptions.llvmModuleBuilder = nullptr;

  auto maybeEngine = mlir::ExecutionEngine::create(module, execOptions);
  if (!maybeEngine) {
    return StreamStringError("failed to construct the MLIR ExecutionEngine");
  }
  auto &engine = maybeEngine.get();
  auto lambda = std::make_unique<JITLambda>((*funcOp).getFunctionType(), name);
  lambda->engine = std::move(engine);

  return std::move(lambda);
}
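
// Illustrative sketch, not part of the original file: if the compiled
// circuit references runtime symbols (e.g. the dataflow runtime), the
// runtime shared library can be forwarded through `runtimeLibPath` so the
// JIT can resolve them at load time. The path below is hypothetical:
//
//   auto lambdaOrErr = JITLambda::create(
//       "main", module, optPipeline,
//       std::optional<std::string>("/path/to/libConcretelangRuntime.so"));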

llvm::Error JITLambda::invokeRaw(llvm::MutableArrayRef<void *> args) {
  auto found = std::find(args.begin(), args.end(), nullptr);
  if (found == args.end()) {
    return this->engine->invokePacked(this->name, args);
  }
  int pos = found - args.begin();
  return StreamStringError("invoke: argument at pos ")
         << pos << " is null or missing";
}
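
// Illustrative sketch, not part of the original file: invokeRaw forwards to
// mlir::ExecutionEngine::invokePacked, whose packed calling convention
// expects one void* per (scalar) argument, each pointing at the actual
// value, with storage for returned values appended at the end:
//
//   int64_t lhs = 1, rhs = 2, res = 0;
//   llvm::SmallVector<void *, 3> rawArgs = {&lhs, &rhs, &res};
//   if (auto err = lambda->invokeRaw(rawArgs))
//     llvm::errs() << llvm::toString(std::move(err)) << "\n";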

llvm::Expected<std::unique_ptr<clientlib::PublicResult>>
JITLambda::call(clientlib::PublicArguments &args,
                clientlib::EvaluationKeys &evaluationKeys) {
#ifndef CONCRETELANG_DATAFLOW_EXECUTION_ENABLED
  if (this->useDataflow) {
    return StreamStringError(
        "call: current runtime doesn't support dataflow execution, while "
        "compilation used dataflow parallelization");
  }
#else
  dfr::_dfr_set_jit(true);
  // When using JIT on distributed systems, the compiler only
  // generates work-functions and their registration calls. No results
  // are returned and no inputs are needed.
  if (!dfr::_dfr_is_root_node()) {
    std::vector<void *> rawArgs;
    if (auto err = invokeRaw(rawArgs)) {
      return std::move(err);
    }
    std::vector<clientlib::ScalarOrTensorData> buffers;
    return clientlib::PublicResult::fromBuffers(args.clientParameters,
                                                std::move(buffers));
  }
#endif

  return ::concretelang::invokeRawOnLambda(this, args.clientParameters,
                                           args.preparedArgs, evaluationKeys);
}

} // namespace concretelang
} // namespace mlir