concrete/compiler/include/zamalang/Support/JitCompilerEngine.h
youben11 99cce18c6a feat: get RT lib path from py and use as sharedlib
Try to find the runtime library automatically (this should only work on a
proper installation of the package); if it cannot be found, fail silently by
not passing any RT lib. The RT lib can also be specified manually. The RT lib
is used as a shared library by the JIT compiler.
2021-11-12 15:06:49 +01:00


#ifndef ZAMALANG_SUPPORT_JIT_COMPILER_ENGINE_H
#define ZAMALANG_SUPPORT_JIT_COMPILER_ENGINE_H
#include <mlir/Dialect/LLVMIR/LLVMDialect.h>
#include <zamalang/Support/CompilerEngine.h>
#include <zamalang/Support/Error.h>
#include <zamalang/Support/Jit.h>
#include <zamalang/Support/LambdaArgument.h>
namespace mlir {
namespace zamalang {
namespace {
// Generic function template as well as specializations of
// `typedResult` must be declared at namespace scope due to return
// type template specialization
// Helper function for `JitCompilerEngine::Lambda::operator()`
// implementing type-dependent preparation of the result.
template <typename ResT>
llvm::Expected<ResT> typedResult(JITLambda::Argument &arguments);
// Specialization of `typedResult()` for scalar results, forwarding
// scalar value to caller
template <>
inline llvm::Expected<uint64_t> typedResult(JITLambda::Argument &arguments) {
  uint64_t res = 0;

  if (auto err = arguments.getResult(0, res))
    return StreamStringError() << "Cannot retrieve result:" << err;

  return res;
}
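
// Illustrative usage sketch (not part of the interface): after a successful
// invocation, the scalar result held by a `JITLambda::Argument` instance --
// called `args` here purely for the example -- can be retrieved as:
//
//   llvm::Expected<uint64_t> res = typedResult<uint64_t>(args);
//   if (!res)
//     llvm::errs() << llvm::toString(res.takeError()) << "\n";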
template <typename T>
inline llvm::Expected<std::vector<T>>
typedVectorResult(JITLambda::Argument &arguments) {
  llvm::Expected<size_t> n = arguments.getResultVectorSize(0);

  if (auto err = n.takeError())
    return std::move(err);

  std::vector<T> res(*n);

  if (auto err = arguments.getResult(0, res.data(), res.size()))
    return StreamStringError() << "Cannot retrieve result:" << err;

  return std::move(res);
}
// Specializations of `typedResult()` for vector results, initializing
// an `std::vector` of the right size with the results and forwarding
// it to the caller with move semantics.
//
// These cannot be factored out into a single template
//
//   template <typename T>
//   inline llvm::Expected<std::vector<T>>
//   typedResult(JITLambda::Argument &arguments);
//
// due to ambiguity with the scalar template.
template <>
inline llvm::Expected<std::vector<uint8_t>>
typedResult(JITLambda::Argument &arguments) {
  return std::move(typedVectorResult<uint8_t>(arguments));
}

template <>
inline llvm::Expected<std::vector<uint16_t>>
typedResult(JITLambda::Argument &arguments) {
  return std::move(typedVectorResult<uint16_t>(arguments));
}

template <>
inline llvm::Expected<std::vector<uint32_t>>
typedResult(JITLambda::Argument &arguments) {
  return std::move(typedVectorResult<uint32_t>(arguments));
}

template <>
inline llvm::Expected<std::vector<uint64_t>>
typedResult(JITLambda::Argument &arguments) {
  return std::move(typedVectorResult<uint64_t>(arguments));
}
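
// Illustrative sketch (not part of the interface): requesting a vector
// result type dispatches to the matching specialization above, e.g. for a
// `JITLambda::Argument` instance `args` holding a tensor of 8-bit elements:
//
//   llvm::Expected<std::vector<uint8_t>> res =
//       typedResult<std::vector<uint8_t>>(args);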
template <typename T>
llvm::Expected<std::unique_ptr<LambdaArgument>>
buildTensorLambdaResult(JITLambda::Argument &arguments) {
  llvm::Expected<std::vector<T>> tensorOrError =
      typedResult<std::vector<T>>(arguments);

  if (!tensorOrError)
    return std::move(tensorOrError.takeError());

  llvm::Expected<std::vector<int64_t>> tensorDimOrError =
      arguments.getResultDimensions(0);

  if (!tensorDimOrError)
    return std::move(tensorDimOrError.takeError());

  return std::move(std::make_unique<TensorLambdaArgument<IntLambdaArgument<T>>>(
      *tensorOrError, *tensorDimOrError));
}
// Specialization of `typedResult()` for a single result wrapped into
// a `LambdaArgument`.
template <>
inline llvm::Expected<std::unique_ptr<LambdaArgument>>
typedResult(JITLambda::Argument &arguments) {
  llvm::Expected<enum JITLambda::Argument::ResultType> resTy =
      arguments.getResultType(0);

  if (!resTy)
    return std::move(resTy.takeError());

  switch (*resTy) {
  case JITLambda::Argument::ResultType::SCALAR: {
    uint64_t res;

    if (llvm::Error err = arguments.getResult(0, res))
      return std::move(err);

    return std::move(std::make_unique<IntLambdaArgument<uint64_t>>(res));
  }
  case JITLambda::Argument::ResultType::TENSOR: {
    llvm::Expected<size_t> width = arguments.getResultWidth(0);

    if (!width)
      return std::move(width.takeError());

    if (*width > 64)
      return StreamStringError("Cannot handle scalars with more than 64 bits");

    if (*width > 32)
      return buildTensorLambdaResult<uint64_t>(arguments);
    else if (*width > 16)
      return buildTensorLambdaResult<uint32_t>(arguments);
    else if (*width > 8)
      return buildTensorLambdaResult<uint16_t>(arguments);
    else
      return buildTensorLambdaResult<uint8_t>(arguments);
  }
  }

  return StreamStringError("Unknown result type");
}
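
// For reference, the width-based dispatch above means that, e.g., a tensor
// result with 5-bit elements is returned as a `std::vector<uint8_t>` and one
// with 12-bit elements as a `std::vector<uint16_t>`, wrapped in a
// `TensorLambdaArgument`.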
// Adaptor class that adds arguments specified as instances of
// `LambdaArgument` to a `JITLambda::Argument`.
class JITLambdaArgumentAdaptor {
public:
  // Checks if the argument `arg` is a plaintext / encrypted integer
  // argument or a plaintext / encrypted tensor argument with a
  // backing integer type `IntT` and adds the argument to `jla` at
  // position `pos`.
  //
  // Returns `true` if `arg` has one of the types above and its value
  // was successfully added to `jla`, `false` if none of the types
  // matches, or an error if a type matched but adding the argument to
  // `jla` failed.
  template <typename IntT>
  static inline llvm::Expected<bool>
  tryAddArg(JITLambda::Argument &jla, size_t pos, const LambdaArgument &arg) {
    if (auto ila = arg.dyn_cast<IntLambdaArgument<IntT>>()) {
      if (llvm::Error err = jla.setArg(pos, ila->getValue()))
        return std::move(err);
      else
        return true;
    } else if (auto tla = arg.dyn_cast<
                   TensorLambdaArgument<IntLambdaArgument<IntT>>>()) {
      if (llvm::Error err =
              jla.setArg(pos, tla->getValue(), tla->getDimensions()))
        return std::move(err);
      else
        return true;
    }

    return false;
  }

  // Recursive case for `tryAddArg<IntT>(...)`: try `IntT` first, then
  // fall back to the remaining integer types.
  template <typename IntT, typename NextIntT, typename... IntTs>
  static inline llvm::Expected<bool>
  tryAddArg(JITLambda::Argument &jla, size_t pos, const LambdaArgument &arg) {
    llvm::Expected<bool> successOrError = tryAddArg<IntT>(jla, pos, arg);

    if (!successOrError)
      return std::move(successOrError.takeError());

    if (successOrError.get() == false)
      return tryAddArg<NextIntT, IntTs...>(jla, pos, arg);
    else
      return true;
  }

  // Attempts to add a single argument `arg` to `jla` at position
  // `pos`. Returns an error if either the argument type is
  // unsupported or if the argument type is supported, but adding it
  // to `jla` failed.
  static inline llvm::Error addArgument(JITLambda::Argument &jla, size_t pos,
                                        const LambdaArgument &arg) {
    llvm::Expected<bool> successOrError =
        JITLambdaArgumentAdaptor::tryAddArg<uint64_t, uint32_t, uint16_t,
                                            uint8_t>(jla, pos, arg);

    if (!successOrError)
      return std::move(successOrError.takeError());

    if (successOrError.get() == false)
      return StreamStringError("Unknown argument type");
    else
      return llvm::Error::success();
  }
};
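
// Illustrative sketch (not part of the interface): assuming
// `IntLambdaArgument` can be constructed from a plain integer value, an
// argument can be added at position 0 of a `JITLambda::Argument` instance
// `jitArgs` as follows:
//
//   IntLambdaArgument<uint8_t> arg(7);
//   if (llvm::Error err = JITLambdaArgumentAdaptor::addArgument(jitArgs, 0, arg))
//     llvm::errs() << llvm::toString(std::move(err)) << "\n";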
} // namespace
// A compiler engine that JIT-compiles a source and produces a lambda
// object directly invocable through its call operator.
class JitCompilerEngine : public CompilerEngine {
public:
  // Wrapper class around `JITLambda` and `JITLambda::Argument` that
  // allows for direct invocation of a compiled function through
  // `operator ()`.
  class Lambda {
  public:
    Lambda(Lambda &&other)
        : innerLambda(std::move(other.innerLambda)),
          keySet(std::move(other.keySet)),
          compilationContext(other.compilationContext) {}

    Lambda(std::shared_ptr<CompilationContext> compilationContext,
           std::unique_ptr<JITLambda> lambda, std::unique_ptr<KeySet> keySet)
        : innerLambda(std::move(lambda)), keySet(std::move(keySet)),
          compilationContext(compilationContext) {}

    // Returns the number of arguments required for an invocation of
    // the lambda
    size_t getNumArguments() { return this->keySet->numInputs(); }

    // Returns the number of results an invocation of the lambda
    // produces
    size_t getNumResults() { return this->keySet->numOutputs(); }

    // Invocation with a dynamic list of arguments of different
    // types, specified as `LambdaArgument`s
    template <typename ResT = uint64_t>
    llvm::Expected<ResT>
    operator()(llvm::ArrayRef<LambdaArgument *> lambdaArgs) {
      // Create the arguments of the JIT lambda
      llvm::Expected<std::unique_ptr<JITLambda::Argument>> argsOrErr =
          mlir::zamalang::JITLambda::Argument::create(*this->keySet.get());

      if (llvm::Error err = argsOrErr.takeError())
        return StreamStringError("Could not create lambda arguments");

      // Set the arguments
      std::unique_ptr<JITLambda::Argument> arguments =
          std::move(argsOrErr.get());

      for (size_t i = 0; i < lambdaArgs.size(); i++) {
        if (llvm::Error err = JITLambdaArgumentAdaptor::addArgument(
                *arguments, i, *lambdaArgs[i])) {
          return std::move(err);
        }
      }

      // Invoke the lambda
      if (auto err = this->innerLambda->invoke(*arguments))
        return StreamStringError() << "Cannot invoke lambda:" << err;

      return std::move(typedResult<ResT>(*arguments));
    }

    // Invocation with an array of arguments of the same type
    template <typename T, typename ResT = uint64_t>
    llvm::Expected<ResT> operator()(const llvm::ArrayRef<T> args) {
      // Create the arguments of the JIT lambda
      llvm::Expected<std::unique_ptr<JITLambda::Argument>> argsOrErr =
          mlir::zamalang::JITLambda::Argument::create(*this->keySet.get());

      if (llvm::Error err = argsOrErr.takeError())
        return StreamStringError("Could not create lambda arguments");

      // Set the arguments
      std::unique_ptr<JITLambda::Argument> arguments =
          std::move(argsOrErr.get());

      for (size_t i = 0; i < args.size(); i++) {
        if (auto err = arguments->setArg(i, args[i])) {
          return StreamStringError()
                 << "Cannot push argument " << i << ": " << err;
        }
      }

      // Invoke the lambda
      if (auto err = this->innerLambda->invoke(*arguments))
        return StreamStringError() << "Cannot invoke lambda:" << err;

      return std::move(typedResult<ResT>(*arguments));
    }

    // Invocation with arguments of different types
    template <typename ResT = uint64_t, typename... Ts>
    llvm::Expected<ResT> operator()(const Ts... ts) {
      // Create the arguments of the JIT lambda
      llvm::Expected<std::unique_ptr<JITLambda::Argument>> argsOrErr =
          mlir::zamalang::JITLambda::Argument::create(*this->keySet.get());

      if (llvm::Error err = argsOrErr.takeError())
        return StreamStringError("Could not create lambda arguments");

      // Set the arguments
      std::unique_ptr<JITLambda::Argument> arguments =
          std::move(argsOrErr.get());

      if (llvm::Error err = this->addArgs<0>(arguments.get(), ts...))
        return std::move(err);

      // Invoke the lambda
      if (auto err = this->innerLambda->invoke(*arguments))
        return StreamStringError() << "Cannot invoke lambda:" << err;

      return std::move(typedResult<ResT>(*arguments));
    }

  protected:
    template <int pos>
    inline llvm::Error addArgs(JITLambda::Argument *jitArgs) {
      // base case -- nothing to do
      return llvm::Error::success();
    }

    // Recursive case for scalars: extract first scalar argument from
    // parameter pack and forward rest
    template <int pos, typename ArgT, typename... Ts>
    inline llvm::Error addArgs(JITLambda::Argument *jitArgs, ArgT arg,
                               Ts... remainder) {
      if (auto err = jitArgs->setArg(pos, arg)) {
        return StreamStringError()
               << "Cannot push scalar argument " << pos << ": " << err;
      }

      return this->addArgs<pos + 1>(jitArgs, remainder...);
    }

    // Recursive case for tensors: extract pointer and size from
    // parameter pack and forward rest
    template <int pos, typename ArgT, typename... Ts>
    inline llvm::Error addArgs(JITLambda::Argument *jitArgs, ArgT *arg,
                               size_t size, Ts... remainder) {
      if (auto err = jitArgs->setArg(pos, arg, size)) {
        return StreamStringError()
               << "Cannot push tensor argument " << pos << ": " << err;
      }

      return this->addArgs<pos + 1>(jitArgs, remainder...);
    }

    std::unique_ptr<JITLambda> innerLambda;
    std::unique_ptr<KeySet> keySet;
    std::shared_ptr<CompilationContext> compilationContext;
  };
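
  // Illustrative usage sketch (not part of the interface): `lambda` stands
  // for a `Lambda` built by `buildLambda()` below; the argument values are
  // made up for the example.
  //
  //   // Variadic invocation with scalar arguments and a 64-bit scalar result
  //   uint64_t a = 3, b = 5;
  //   llvm::Expected<uint64_t> res = lambda(a, b);
  //
  //   // A tensor argument is passed as a pointer followed by its size
  //   uint8_t t[4] = {1, 2, 3, 4};
  //   llvm::Expected<uint64_t> resT = lambda(t, static_cast<size_t>(4));
  //
  //   // A tensor result is requested by specifying a vector result type
  //   llvm::Expected<std::vector<uint64_t>> resV =
  //       lambda.operator()<std::vector<uint64_t>>(a, b);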
  JitCompilerEngine(std::shared_ptr<CompilationContext> compilationContext =
                        CompilationContext::createShared(),
                    unsigned int optimizationLevel = 3);

  /// Build a `Lambda` from an MLIR source, with `funcName` as the entry
  /// point. If `runtimeLibPath` is specified, it is used as a shared
  /// library by the JIT compiler.
  llvm::Expected<Lambda>
  buildLambda(llvm::StringRef src, llvm::StringRef funcName = "main",
              llvm::Optional<llvm::StringRef> runtimeLibPath = {});

  llvm::Expected<Lambda>
  buildLambda(std::unique_ptr<llvm::MemoryBuffer> buffer,
              llvm::StringRef funcName = "main",
              llvm::Optional<llvm::StringRef> runtimeLibPath = {});

  llvm::Expected<Lambda>
  buildLambda(llvm::SourceMgr &sm, llvm::StringRef funcName = "main",
              llvm::Optional<llvm::StringRef> runtimeLibPath = {});
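
  // Illustrative usage sketch (not part of the interface): `src` is a
  // placeholder for an MLIR source string with an entry point named "main",
  // and the runtime library path is a made-up example. When no
  // `runtimeLibPath` is given, the JIT compiler proceeds without loading a
  // runtime shared library.
  //
  //   JitCompilerEngine engine;
  //   llvm::Expected<JitCompilerEngine::Lambda> lambdaOrErr =
  //       engine.buildLambda(src, "main",
  //                          llvm::StringRef("/path/to/runtime/lib.so"));
  //   if (!lambdaOrErr)
  //     llvm::errs() << llvm::toString(lambdaOrErr.takeError()) << "\n";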
protected:
  llvm::Expected<mlir::LLVM::LLVMFuncOp> findLLVMFuncOp(mlir::ModuleOp module,
                                                        llvm::StringRef name);

  unsigned int optimizationLevel;
};
} // namespace zamalang
} // namespace mlir
#endif