Files
concrete/compiler/lib/Support/Jit.cpp
Andi Drebes a7051c2c9c enhance(client/server): Add support for scalar results
This patch adds support for scalar results to the client/server
protocol and tests. In addition to `TensorData`, a new type
`ScalarData` is added. Previous representations of scalar values using
one-dimensional `TensorData` instances have been replaced with proper
instantiations of `ScalarData`.

The generic use of `TensorData` for scalar and tensor values has been
replaced with uses of a new variant `ScalarOrTensorData`, which can
either hold an instance of `TensorData` or `ScalarData`.
2022-10-04 14:40:40 +02:00

193 lines
6.9 KiB
C++

// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#include "llvm/Support/Error.h"
#include <llvm/ADT/ArrayRef.h>
#include <llvm/ADT/SmallVector.h>
#include <llvm/ADT/StringRef.h>
#include <llvm/Support/TargetSelect.h>
#include <mlir/Dialect/LLVMIR/LLVMDialect.h>
#include <mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h>
#include "concretelang/Common/BitsSize.h"
#include <concretelang/Runtime/DFRuntime.hpp>
#include <concretelang/Support/Error.h>
#include <concretelang/Support/Jit.h>
#include <concretelang/Support/logging.h>
namespace mlir {
namespace concretelang {
llvm::Expected<std::unique_ptr<JITLambda>>
JITLambda::create(llvm::StringRef name, mlir::ModuleOp &module,
llvm::function_ref<llvm::Error(llvm::Module *)> optPipeline,
llvm::Optional<std::string> runtimeLibPath) {
// Looking for the function
auto rangeOps = module.getOps<mlir::LLVM::LLVMFuncOp>();
auto funcOp = llvm::find_if(rangeOps, [&](mlir::LLVM::LLVMFuncOp op) {
return op.getName() == name;
});
if (funcOp == rangeOps.end()) {
return llvm::make_error<llvm::StringError>(
"cannot find the function to JIT", llvm::inconvertibleErrorCode());
}
llvm::InitializeNativeTarget();
llvm::InitializeNativeTargetAsmPrinter();
mlir::registerLLVMDialectTranslation(*module->getContext());
// Create an MLIR execution engine. The execution engine eagerly
// JIT-compiles the module. If runtimeLibPath is specified, it's passed as a
// shared library to the JIT compiler.
std::vector<llvm::StringRef> sharedLibPaths;
if (runtimeLibPath.hasValue())
sharedLibPaths.push_back(runtimeLibPath.getValue());
mlir::ExecutionEngineOptions execOptions;
execOptions.transformer = optPipeline;
execOptions.sharedLibPaths = sharedLibPaths;
execOptions.jitCodeGenOptLevel = llvm::None;
execOptions.llvmModuleBuilder = nullptr;
auto maybeEngine = mlir::ExecutionEngine::create(module, execOptions);
if (!maybeEngine) {
return StreamStringError("failed to construct the MLIR ExecutionEngine");
}
auto &engine = maybeEngine.get();
auto lambda = std::make_unique<JITLambda>((*funcOp).getFunctionType(), name);
lambda->engine = std::move(engine);
return std::move(lambda);
}
llvm::Error JITLambda::invokeRaw(llvm::MutableArrayRef<void *> args) {
auto found = std::find(args.begin(), args.end(), nullptr);
if (found == args.end()) {
return this->engine->invokePacked(this->name, args);
}
int pos = found - args.begin();
return StreamStringError("invoke: argument at pos ")
<< pos << " is null or missing";
}
// memref is a struct which is flattened aligned, allocated pointers, offset,
// and two array of rank size for sizes and strides.
uint64_t numArgOfRankedMemrefCallingConvention(uint64_t rank) {
return 3 + 2 * rank;
}
llvm::Expected<std::unique_ptr<clientlib::PublicResult>>
JITLambda::call(clientlib::PublicArguments &args,
clientlib::EvaluationKeys &evaluationKeys) {
#ifndef CONCRETELANG_DATAFLOW_EXECUTION_ENABLED
if (this->useDataflow) {
return StreamStringError(
"call: current runtime doesn't support dataflow execution, while "
"compilation used dataflow parallelization");
}
#else
dfr::_dfr_set_jit(true);
// When using JIT on distributed systems, the compiler only
// generates work-functions and their registration calls. No results
// are returned and no inputs are needed.
if (!dfr::_dfr_is_root_node()) {
std::vector<void *> rawArgs;
if (auto err = invokeRaw(rawArgs)) {
return std::move(err);
}
std::vector<clientlib::ScalarOrTensorData> buffers;
return clientlib::PublicResult::fromBuffers(args.clientParameters,
std::move(buffers));
}
#endif
// invokeRaw needs to have pointers on arguments and a pointers on the result
// as last argument.
// Prepare the outputs vector to store the output value of the lambda.
auto numOutputs = 0;
for (auto &output : args.clientParameters.outputs) {
auto shape = args.clientParameters.bufferShape(output);
if (shape.size() == 0) {
// scalar gate
numOutputs += 1;
} else {
// buffer gate
numOutputs += numArgOfRankedMemrefCallingConvention(shape.size());
}
}
std::vector<uint64_t> outputs(numOutputs);
// Prepare the raw arguments of invokeRaw, i.e. a vector with pointer on
// inputs and outputs.
std::vector<void *> rawArgs(
args.preparedArgs.size() + 1 /*runtime context*/ + 1 /* outputs */
);
size_t i = 0;
// Pointers on inputs
for (auto &arg : args.preparedArgs) {
rawArgs[i++] = &arg;
}
RuntimeContext runtimeContext;
runtimeContext.evaluationKeys = evaluationKeys;
// Pointer on runtime context, the rawArgs take pointer on actual value that
// is passed to the compiled function.
auto rtCtxPtr = &runtimeContext;
rawArgs[i++] = &rtCtxPtr;
// Outputs
rawArgs[i++] = reinterpret_cast<void *>(outputs.data());
// Invoke
if (auto err = invokeRaw(rawArgs)) {
return std::move(err);
}
// Store the result to the PublicResult
std::vector<clientlib::ScalarOrTensorData> buffers;
{
size_t outputOffset = 0;
for (auto &output : args.clientParameters.outputs) {
auto shape = args.clientParameters.bufferShape(output);
if (shape.size() == 0) {
// scalar scalar
buffers.push_back(concretelang::clientlib::ScalarOrTensorData(
concretelang::clientlib::ScalarData(outputs[outputOffset++],
output.shape.sign,
output.shape.width)));
} else {
// buffer gate
auto rank = shape.size();
auto allocated = (uint64_t *)outputs[outputOffset++];
auto aligned = (uint64_t *)outputs[outputOffset++];
auto offset = (size_t)outputs[outputOffset++];
size_t *sizes = (size_t *)&outputs[outputOffset];
outputOffset += rank;
size_t *strides = (size_t *)&outputs[outputOffset];
outputOffset += rank;
size_t elementWidth = (output.isEncrypted())
? clientlib::EncryptedScalarElementWidth
: output.shape.width;
bool sign = (output.isEncrypted()) ? false : output.shape.sign;
concretelang::clientlib::TensorData td =
clientlib::tensorDataFromMemRef(rank, elementWidth, sign, allocated,
aligned, offset, sizes, strides);
buffers.push_back(
concretelang::clientlib::ScalarOrTensorData(std::move(td)));
}
}
}
return clientlib::PublicResult::fromBuffers(args.clientParameters,
std::move(buffers));
}
} // namespace concretelang
} // namespace mlir