refactor(compiler): Refactor CompilerEngine and related classes

This commit contains several incremental improvements towards a clear
interface for lambdas:

  - Unification of static and JIT compilation by using the static
    compilation path of `CompilerEngine` within a new subclass
    `JitCompilerEngine`.

  - Clear ownership for compilation artefacts through
    `CompilationContext`, making it impossible to destroy objects used
    directly or indirectly before destruction of their users.

  - Clear interface for lambdas generated by the compiler through
    `JitCompilerEngine::Lambda` with a templated call operator,
    encapsulating otherwise manual orchestration of `CompilerEngine`,
    `JITLambda`, and `CompilerEngine::Argument`.

  - Improved error handling through `llvm::Expected<T>` and proper
    error checking following the conventions for `llvm::Expected<T>`
    and `llvm::Error`.

Co-authored-by: youben11 <ayoub.benaissa@zama.ai>
This commit is contained in:
Andi Drebes
2021-10-18 15:38:12 +02:00
parent d738104c4b
commit 1187cfbd62
61 changed files with 1690 additions and 997 deletions

View File

@@ -5,15 +5,17 @@
#include "mlir-c/Registration.h"
#include "zamalang/Support/CompilerEngine.h"
#include "zamalang/Support/ExecutionArgument.h"
#include "zamalang/Support/Jit.h"
#include "zamalang/Support/JitCompilerEngine.h"
#ifdef __cplusplus
extern "C" {
#endif
struct compilerEngine {
mlir::zamalang::CompilerEngine *ptr;
struct lambda {
mlir::zamalang::JitCompilerEngine::Lambda *ptr;
};
typedef struct compilerEngine compilerEngine;
typedef struct lambda lambda;
struct executionArguments {
mlir::zamalang::ExecutionArgument *data;
@@ -21,13 +23,12 @@ struct executionArguments {
};
typedef struct executionArguments exectuionArguments;
// Compile an MLIR module
MLIR_CAPI_EXPORTED void compilerEngineCompile(compilerEngine engine,
const char *module);
MLIR_CAPI_EXPORTED mlir::zamalang::JitCompilerEngine::Lambda
buildLambda(const char *module, const char *funcName);
// Run the compiled module
MLIR_CAPI_EXPORTED uint64_t compilerEngineRun(compilerEngine e,
executionArguments args);
MLIR_CAPI_EXPORTED uint64_t invokeLambda(lambda l, executionArguments args);
MLIR_CAPI_EXPORTED std::string roundTrip(const char *module);
#ifdef __cplusplus
}

View File

@@ -1,49 +1,138 @@
#ifndef ZAMALANG_SUPPORT_COMPILER_ENGINE_H
#define ZAMALANG_SUPPORT_COMPILER_ENGINE_H
#include "Jit.h"
#include <llvm/IR/Module.h>
#include <llvm/Support/Error.h>
#include <llvm/Support/SourceMgr.h>
#include <mlir/IR/BuiltinOps.h>
#include <mlir/IR/MLIRContext.h>
#include <zamalang/Conversion/Utils/GlobalFHEContext.h>
#include <zamalang/Support/ClientParameters.h>
#include <zamalang/Support/KeySet.h>
namespace mlir {
namespace zamalang {
/// CompilerEngine provides the tools to implement the compilation flow and to
/// manage its state.
// Compilation context that acts as the root owner of LLVM and MLIR
// data structures directly and indirectly referenced by artefacts
// produced by the `CompilerEngine`.
class CompilationContext {
public:
  CompilationContext();
  ~CompilationContext();

  // Returns the MLIR context associated with this compilation context.
  mlir::MLIRContext *getMLIRContext();

  // Returns the LLVM context associated with this compilation context.
  llvm::LLVMContext *getLLVMContext();

  // Creates a new compilation context wrapped in a `shared_ptr`, so that
  // compilation artefacts referencing it can share ownership and keep it
  // alive for as long as they exist.
  static std::shared_ptr<CompilationContext> createShared();

protected:
  // NOTE(review): raw owning pointers; presumably allocated by the getters
  // or the constructor and released in the destructor -- confirm against
  // the implementation file.
  mlir::MLIRContext *mlirContext;
  llvm::LLVMContext *llvmContext;
};
class CompilerEngine {
public:
CompilerEngine() {
context = new mlir::MLIRContext();
loadDialects();
}
~CompilerEngine() {
if (context != nullptr)
delete context;
}
// Result of an invocation of the `CompilerEngine` with optional
// fields for the results produced by different stages.
class CompilationResult {
public:
CompilationResult(std::shared_ptr<CompilationContext> compilationContext =
CompilationContext::createShared())
: compilationContext(compilationContext) {}
// Compile an MLIR program from its textual representation.
llvm::Error compile(
std::string mlirStr,
llvm::Optional<mlir::zamalang::V0FHEConstraint> overrideConstraints = {});
llvm::Optional<mlir::OwningModuleRef> mlirModuleRef;
llvm::Optional<mlir::zamalang::ClientParameters> clientParameters;
std::unique_ptr<mlir::zamalang::KeySet> keySet;
std::unique_ptr<llvm::Module> llvmModule;
llvm::Optional<mlir::zamalang::V0FHEContext> fheContext;
// Build the jit lambda argument.
llvm::Expected<std::unique_ptr<JITLambda::Argument>> buildArgument();
protected:
std::shared_ptr<CompilationContext> compilationContext;
};
// Call the compiled function with an argument object.
llvm::Error invoke(JITLambda::Argument &arg);
// Specification of the exit stage of the compilation pipeline
enum class Target {
// Only read sources and produce corresponding MLIR module
ROUND_TRIP,
// Call the compiled function with a list of integer arguments.
llvm::Expected<uint64_t> run(std::vector<uint64_t> args);
// Read sources and exit before any lowering
HLFHE,
// Get a printable representation of the compiled module
std::string getCompiledModule();
// Read sources and attempt to run the Minimal Arithmetic Noise
// Padding pass
HLFHE_MANP,
// Read sources and lower all HLFHE operations to MidLFHE
// operations
MIDLFHE,
// Read sources and lower all HLFHE and MidLFHE operations to LowLFHE
// operations
LOWLFHE,
// Read sources and lower all HLFHE, MidLFHE and LowLFHE
// operations to canonical MLIR dialects. Cryptographic operations
// are lowered to invocations of the concrete library.
STD,
// Read sources and lower all HLFHE, MidLFHE and LowLFHE
// operations to operations from the LLVM dialect. Cryptographic
// operations are lowered to invocations of the concrete library.
LLVM,
// Same as `LLVM`, but lowers to actual LLVM IR instead of the
// LLVM dialect
LLVM_IR,
// Same as `LLVM_IR`, but invokes the LLVM optimization pipeline
// to produce optimized LLVM IR
OPTIMIZED_LLVM_IR
};
CompilerEngine(std::shared_ptr<CompilationContext> compilationContext)
: overrideMaxEintPrecision(), overrideMaxMANP(),
clientParametersFuncName(), verifyDiagnostics(false),
generateKeySet(false), generateClientParameters(false),
parametrizeMidLFHE(true), compilationContext(compilationContext) {}
llvm::Expected<CompilationResult> compile(llvm::StringRef s, Target target);
llvm::Expected<CompilationResult>
compile(std::unique_ptr<llvm::MemoryBuffer> buffer, Target target);
llvm::Expected<CompilationResult> compile(llvm::SourceMgr &sm, Target target);
void setFHEConstraints(const mlir::zamalang::V0FHEConstraint &c);
void setMaxEintPrecision(size_t v);
void setMaxMANP(size_t v);
void setVerifyDiagnostics(bool v);
void setGenerateKeySet(bool v);
void setGenerateClientParameters(bool v);
void setParametrizeMidLFHE(bool v);
void setClientParametersFuncName(const llvm::StringRef &name);
protected:
llvm::Optional<size_t> overrideMaxEintPrecision;
llvm::Optional<size_t> overrideMaxMANP;
llvm::Optional<std::string> clientParametersFuncName;
bool verifyDiagnostics;
bool generateKeySet;
bool generateClientParameters;
bool parametrizeMidLFHE;
std::shared_ptr<CompilationContext> compilationContext;
// Helper enum identifying an FHE dialect (`HLFHE`, `MIDLFHE`, `LOWLFHE`)
// or indicating that no FHE dialect is used (`NONE`).
enum class FHEDialect { HLFHE, MIDLFHE, LOWLFHE, NONE };
static FHEDialect detectHighestFHEDialect(mlir::ModuleOp module);
private:
// Load the necessary dialects into the engine's context
void loadDialects();
mlir::OwningModuleRef module_ref;
mlir::MLIRContext *context;
std::unique_ptr<mlir::zamalang::KeySet> keySet;
llvm::Error lowerParamDependentHalf(Target target, CompilationResult &res);
llvm::Error determineFHEParameters(CompilationResult &res, bool noOverride);
};
} // namespace zamalang
} // namespace mlir

View File

@@ -9,11 +9,6 @@
namespace mlir {
namespace zamalang {
mlir::LogicalResult
runJit(mlir::ModuleOp module, llvm::StringRef func,
llvm::ArrayRef<uint64_t> funcArgs, mlir::zamalang::KeySet &keySet,
std::function<llvm::Error(llvm::Module *)> optPipeline,
llvm::raw_ostream &os);
/// JITLambda is a tool to JIT compile an mlir module and to invoke a function
/// of the module.

View File

@@ -0,0 +1,296 @@
#ifndef ZAMALANG_SUPPORT_JIT_COMPILER_ENGINE_H
#define ZAMALANG_SUPPORT_JIT_COMPILER_ENGINE_H
#include <mlir/Dialect/LLVMIR/LLVMDialect.h>
#include <zamalang/Support/CompilerEngine.h>
#include <zamalang/Support/Error.h>
#include <zamalang/Support/Jit.h>
#include <zamalang/Support/LambdaArgument.h>
namespace mlir {
namespace zamalang {
namespace {
// Generic function template as well as specializations of
// `typedResult` must be declared at namespace scope due to return
// type template specialization
// Helper function for `JitCompilerEngine::Lambda::operator()`
// implementing type-dependent preparation of the result.
template <typename ResT>
llvm::Expected<ResT> typedResult(JITLambda::Argument &arguments);
// Specialization of `typedResult()` for a single scalar result:
// reads result slot 0 and hands the raw value back to the caller.
template <>
inline llvm::Expected<uint64_t> typedResult(JITLambda::Argument &arguments) {
  uint64_t scalarValue = 0;
  llvm::Error err = arguments.getResult(0, scalarValue);

  if (err)
    return StreamStringError() << "Cannot retrieve result:" << err;

  return scalarValue;
}
// Specialization of `typedResult()` for vector results: allocates an
// `std::vector` of the right size, fills it with the result values
// and returns it to the caller with move semantics.
template <>
inline llvm::Expected<std::vector<uint64_t>>
typedResult(JITLambda::Argument &arguments) {
  llvm::Expected<size_t> numResults = arguments.getResultVectorSize(0);

  if (!numResults)
    return numResults.takeError();

  std::vector<uint64_t> values(*numResults);

  if (llvm::Error err = arguments.getResult(0, values.data(), values.size()))
    return StreamStringError() << "Cannot retrieve result:" << err;

  return std::move(values);
}
// Adaptor class that adds arguments specified as instances of
// `LambdaArgument` to `JitLambda::Argument`.
class JITLambdaArgumentAdaptor {
public:
  // Checks if the argument `arg` is a plaintext / encrypted integer
  // argument or a plaintext / encrypted tensor argument with a
  // backing integer type `IntT` and adds the argument to `jla` at
  // position `pos`.
  //
  // Returns `true` if `arg` has one of the types above and its value
  // was successfully added to `jla`, `false` if none of the types
  // matches or an error if a type matched, but adding the argument to
  // `jla` failed.
  template <typename IntT>
  static inline llvm::Expected<bool>
  tryAddArg(JITLambda::Argument &jla, size_t pos, const LambdaArgument &arg) {
    if (auto ila = arg.dyn_cast<IntLambdaArgument<IntT>>()) {
      if (llvm::Error err = jla.setArg(pos, ila->getValue()))
        return std::move(err);
      else
        return true;
    } else if (auto tla = arg.dyn_cast<
                   TensorLambdaArgument<IntLambdaArgument<IntT>>>()) {
      llvm::Expected<size_t> numElements = tla->getNumElements();

      if (!numElements)
        return std::move(numElements.takeError());

      if (llvm::Error err = jla.setArg(pos, tla->getValue(), *numElements))
        return std::move(err);
      else
        return true;
    }

    // `arg` matched neither the scalar nor the tensor case for `IntT`.
    return false;
  }

  // Recursive case for `tryAddArg<IntT>(...)`: tries `IntT` first and,
  // if the argument does not match, recurses on the remaining integer
  // types of the parameter pack. Note that this overload is selected
  // only when at least two types remain; the single-type overload
  // above terminates the recursion.
  template <typename IntT, typename NextIntT, typename... IntTs>
  static inline llvm::Expected<bool>
  tryAddArg(JITLambda::Argument &jla, size_t pos, const LambdaArgument &arg) {
    llvm::Expected<bool> successOrError = tryAddArg<IntT>(jla, pos, arg);

    if (!successOrError)
      return std::move(successOrError.takeError());

    if (successOrError.get() == false)
      return tryAddArg<NextIntT, IntTs...>(jla, pos, arg);
    else
      return true;
  }

  // Attempts to add a single argument `arg` to `jla` at position
  // `pos`. Returns an error if either the argument type is
  // unsupported or if the argument type is supported, but adding it
  // to `jla` failed.
  static inline llvm::Error addArgument(JITLambda::Argument &jla, size_t pos,
                                        const LambdaArgument &arg) {
    // Try all supported unsigned backing integer types in turn.
    llvm::Expected<bool> successOrError =
        JITLambdaArgumentAdaptor::tryAddArg<uint64_t, uint32_t, uint16_t,
                                            uint8_t>(jla, pos, arg);

    if (!successOrError)
      return std::move(successOrError.takeError());

    if (successOrError.get() == false)
      return StreamStringError("Unknown argument type");
    else
      return llvm::Error::success();
  }
};
} // namespace
// A compiler engine that JIT-compiles a source and produces a lambda
// object directly invocable through its call operator.
class JitCompilerEngine : public CompilerEngine {
public:
  // Wrapper class around `JITLambda` and `JITLambda::Argument` that
  // allows for direct invocation of a compiled function through
  // `operator ()`.
  class Lambda {
  public:
    Lambda(Lambda &&other)
        : innerLambda(std::move(other.innerLambda)),
          keySet(std::move(other.keySet)),
          compilationContext(other.compilationContext) {}

    Lambda(std::shared_ptr<CompilationContext> compilationContext,
           std::unique_ptr<JITLambda> lambda, std::unique_ptr<KeySet> keySet)
        : innerLambda(std::move(lambda)), keySet(std::move(keySet)),
          compilationContext(compilationContext) {}

    // Returns the number of arguments required for an invocation of
    // the lambda
    size_t getNumArguments() { return this->keySet->numInputs(); }

    // Returns the number of results an invocation of the lambda
    // produces
    size_t getNumResults() { return this->keySet->numOutputs(); }

    // Invocation with a dynamic list of arguments of different
    // types, specified as `LambdaArgument`s
    template <typename ResT = uint64_t>
    llvm::Expected<ResT>
    operator()(llvm::ArrayRef<LambdaArgument *> lambdaArgs) {
      // Create the arguments of the JIT lambda
      llvm::Expected<std::unique_ptr<JITLambda::Argument>> argsOrErr =
          mlir::zamalang::JITLambda::Argument::create(*this->keySet.get());

      if (llvm::Error err = argsOrErr.takeError()) {
        // Consume the original error before replacing it; dropping a
        // failed `llvm::Error` trips the unchecked-error assertion in
        // debug builds of LLVM.
        llvm::consumeError(std::move(err));
        return StreamStringError("Could not create lambda arguments");
      }

      // Set the arguments
      std::unique_ptr<JITLambda::Argument> arguments =
          std::move(argsOrErr.get());

      for (size_t i = 0; i < lambdaArgs.size(); i++) {
        if (llvm::Error err = JITLambdaArgumentAdaptor::addArgument(
                *arguments, i, *lambdaArgs[i])) {
          return std::move(err);
        }
      }

      // Invoke the lambda
      if (auto err = this->innerLambda->invoke(*arguments))
        return StreamStringError() << "Cannot invoke lambda:" << err;

      // `typedResult` yields a prvalue; returning it directly enables
      // copy elision (a wrapping `std::move` would inhibit it).
      return typedResult<ResT>(*arguments);
    }

    // Invocation with an array of arguments of the same type
    template <typename T, typename ResT = uint64_t>
    llvm::Expected<ResT> operator()(const llvm::ArrayRef<T> args) {
      // Create the arguments of the JIT lambda
      llvm::Expected<std::unique_ptr<JITLambda::Argument>> argsOrErr =
          mlir::zamalang::JITLambda::Argument::create(*this->keySet.get());

      if (llvm::Error err = argsOrErr.takeError()) {
        // See above: consume the error before substituting our own.
        llvm::consumeError(std::move(err));
        return StreamStringError("Could not create lambda arguments");
      }

      // Set the arguments
      std::unique_ptr<JITLambda::Argument> arguments =
          std::move(argsOrErr.get());

      for (size_t i = 0; i < args.size(); i++) {
        if (auto err = arguments->setArg(i, args[i])) {
          return StreamStringError()
                 << "Cannot push argument " << i << ": " << err;
        }
      }

      // Invoke the lambda
      if (auto err = this->innerLambda->invoke(*arguments))
        return StreamStringError() << "Cannot invoke lambda:" << err;

      return typedResult<ResT>(*arguments);
    }

    // Invocation with arguments of different types
    template <typename ResT = uint64_t, typename... Ts>
    llvm::Expected<ResT> operator()(const Ts... ts) {
      // Create the arguments of the JIT lambda
      llvm::Expected<std::unique_ptr<JITLambda::Argument>> argsOrErr =
          mlir::zamalang::JITLambda::Argument::create(*this->keySet.get());

      if (llvm::Error err = argsOrErr.takeError()) {
        // See above: consume the error before substituting our own.
        llvm::consumeError(std::move(err));
        return StreamStringError("Could not create lambda arguments");
      }

      // Set the arguments
      std::unique_ptr<JITLambda::Argument> arguments =
          std::move(argsOrErr.get());

      if (llvm::Error err = this->addArgs<0>(arguments.get(), ts...))
        return std::move(err);

      // Invoke the lambda
      if (auto err = this->innerLambda->invoke(*arguments))
        return StreamStringError() << "Cannot invoke lambda:" << err;

      return typedResult<ResT>(*arguments);
    }

  protected:
    // Base case for `addArgs`: the parameter pack is exhausted.
    template <int pos>
    inline llvm::Error addArgs(JITLambda::Argument *jitArgs) {
      // base case -- nothing to do
      return llvm::Error::success();
    }

    // Recursive case for scalars: extract first scalar argument from
    // parameter pack and forward rest
    template <int pos, typename ArgT, typename... Ts>
    inline llvm::Error addArgs(JITLambda::Argument *jitArgs, ArgT arg,
                               Ts... remainder) {
      if (auto err = jitArgs->setArg(pos, arg)) {
        return StreamStringError()
               << "Cannot push scalar argument " << pos << ": " << err;
      }

      return this->addArgs<pos + 1>(jitArgs, remainder...);
    }

    // Recursive case for tensors: extract pointer and size from
    // parameter pack and forward rest
    template <int pos, typename ArgT, typename... Ts>
    inline llvm::Error addArgs(JITLambda::Argument *jitArgs, ArgT *arg,
                               size_t size, Ts... remainder) {
      if (auto err = jitArgs->setArg(pos, arg, size)) {
        return StreamStringError()
               << "Cannot push tensor argument " << pos << ": " << err;
      }

      return this->addArgs<pos + 1>(jitArgs, remainder...);
    }

    std::unique_ptr<JITLambda> innerLambda;
    std::unique_ptr<KeySet> keySet;
    std::shared_ptr<CompilationContext> compilationContext;
  };

  JitCompilerEngine(std::shared_ptr<CompilationContext> compilationContext =
                        CompilationContext::createShared(),
                    unsigned int optimizationLevel = 3);

  // Build an invocable lambda for the function `funcName` of the
  // program given as a string, a memory buffer or a source manager.
  llvm::Expected<Lambda> buildLambda(llvm::StringRef src,
                                     llvm::StringRef funcName = "main");
  llvm::Expected<Lambda> buildLambda(std::unique_ptr<llvm::MemoryBuffer> buffer,
                                     llvm::StringRef funcName = "main");
  llvm::Expected<Lambda> buildLambda(llvm::SourceMgr &sm,
                                     llvm::StringRef funcName = "main");

protected:
  // Looks up the LLVM-dialect function `name` in `module`.
  llvm::Expected<mlir::LLVM::LLVMFuncOp> findLLVMFuncOp(mlir::ModuleOp module,
                                                        llvm::StringRef name);

  unsigned int optimizationLevel;
};
} // namespace zamalang
} // namespace mlir
#endif

View File

@@ -0,0 +1,157 @@
#ifndef ZAMALANG_SUPPORT_LAMBDA_ARGUMENT_H
#define ZAMALANG_SUPPORT_LAMBDA_ARGUMENT_H
#include <cstdint>
#include <limits>
#include <llvm/ADT/ArrayRef.h>
#include <llvm/Support/Casting.h>
#include <llvm/Support/ExtensibleRTTI.h>
#include <zamalang/Support/Error.h>
namespace mlir {
namespace zamalang {
// Abstract base class for lambda arguments. Uses LLVM's extensible
// RTTI mechanism so that concrete argument types can be identified
// and converted at runtime via `isa` / `cast` / `dyn_cast`.
class LambdaArgument
    : public llvm::RTTIExtends<LambdaArgument, llvm::RTTIRoot> {
public:
  // Lambda arguments are not copy-constructible.
  LambdaArgument(LambdaArgument &) = delete;

  // Checks whether this argument is an instance of `T`.
  template <typename T> bool isa() const { return llvm::isa<T>(*this); }

  // Cast functions on constant instances
  template <typename T> const T &cast() const { return llvm::cast<T>(*this); }
  template <typename T> const T *dyn_cast() const {
    return llvm::dyn_cast<T>(this);
  }

  // Cast functions for mutable instances
  template <typename T> T &cast() { return llvm::cast<T>(*this); }
  template <typename T> T *dyn_cast() { return llvm::dyn_cast<T>(this); }

  // Type ID required by LLVM's extensible RTTI
  static char ID;

protected:
  // Only constructible through derived classes
  LambdaArgument(){};
};
// Class for integer arguments. `BackingIntType` is used as the data
// type to hold the argument's value. The precision is the actual
// precision of the value, which might be different from the precision
// of the backing integer type.
template <typename BackingIntType = uint64_t>
class IntLambdaArgument
    : public llvm::RTTIExtends<IntLambdaArgument<BackingIntType>,
                               LambdaArgument> {
public:
  typedef BackingIntType value_type;

  // Constructs an integer argument from `value` with the given
  // `precision` in bits (defaulting to the full width of
  // `BackingIntType`). A value wider than `precision` is truncated to
  // its `precision` least significant bits.
  IntLambdaArgument(BackingIntType value,
                    unsigned int precision = 8 * sizeof(BackingIntType))
      : precision(precision) {
    if (precision < 8 * sizeof(BackingIntType)) {
      // Mask off everything beyond `precision` bits. The previous
      // expression `value & (1 << (precision - 1))` kept only the
      // single top bit of the range and shifted in `int`, which
      // overflows for precisions >= 31; build the mask in
      // `BackingIntType` instead.
      const BackingIntType mask =
          (static_cast<BackingIntType>(1) << precision) -
          static_cast<BackingIntType>(1);
      this->value = value & mask;
    } else {
      this->value = value;
    }
  }

  // Returns the precision of the value in bits
  unsigned int getPrecision() const { return this->precision; }

  // Returns the (possibly truncated) value of the argument
  BackingIntType getValue() const { return this->value; }

  static char ID;

protected:
  unsigned int precision;
  BackingIntType value;
};
namespace {
// Calculates `accu *= factor` or returns an error if the result
// would overflow
template <typename AccuT, typename ValT>
llvm::Error safeUnsignedMul(AccuT &accu, ValT factor) {
  static_assert(std::numeric_limits<AccuT>::is_integer &&
                    std::numeric_limits<ValT>::is_integer &&
                    !std::numeric_limits<AccuT>::is_signed &&
                    !std::numeric_limits<ValT>::is_signed,
                "Only unsigned integers are supported");

  // A zero operand can never overflow; handle it up front, since the
  // division below would otherwise divide by zero when `accu == 0`
  // (reachable through a zero-sized tensor dimension).
  if (accu == 0 || factor == 0) {
    accu = 0;
    return llvm::Error::success();
  }

  // `accu * factor` fits iff `factor <= max / accu`. The previous
  // strict comparison spuriously reported an overflow when the
  // product was exactly the maximum representable value.
  if (factor <= std::numeric_limits<AccuT>::max() / accu) {
    accu *= factor;
    return llvm::Error::success();
  }

  return StreamStringError("Multiplying value ")
         << accu << " with " << factor << " would cause an overflow";
}
} // namespace
// Class for Tensor arguments. This can either be plaintext tensors
// (for `ScalarArgumentT = IntLambaArgument<T>`) or tensors
// representing encrypted integers (for `ScalarArgumentT =
// EIntLambaArgument<T>`).
template <typename ScalarArgumentT>
class TensorLambdaArgument
    : public llvm::RTTIExtends<TensorLambdaArgument<ScalarArgumentT>,
                               LambdaArgument> {
public:
  typedef ScalarArgumentT scalar_type;

  // Construct tensor argument from the one-dimensional array `value`,
  // but interpreting the array's values as a linearized
  // multi-dimensional tensor with the sizes of the dimensions
  // specified in `dimensions`.
  TensorLambdaArgument(
      llvm::MutableArrayRef<typename ScalarArgumentT::value_type> value,
      llvm::ArrayRef<unsigned int> dimensions)
      : value(value), dimensions(dimensions.vec()) {}

  // Construct a one-dimensional tensor argument from the array
  // `value`.
  TensorLambdaArgument(
      llvm::MutableArrayRef<typename ScalarArgumentT::value_type> value)
      : TensorLambdaArgument(value, {(unsigned int)value.size()}) {}

  // Returns the sizes of the tensor's dimensions.
  const std::vector<unsigned int> &getDimensions() const {
    return this->dimensions;
  }

  // Returns the total number of elements in the tensor, i.e., the
  // product of all dimension sizes. If that product cannot be
  // represented as a `size_t`, the method returns an error.
  llvm::Expected<size_t> getNumElements() const {
    size_t count = 1;

    for (unsigned int dimensionSize : this->dimensions) {
      if (llvm::Error err = safeUnsignedMul(count, dimensionSize))
        return std::move(err);
    }

    return count;
  }

  // Returns a bare pointer to the linearized values of the tensor.
  typename ScalarArgumentT::value_type *getValue() const {
    return this->value.data();
  }

  static char ID;

protected:
  llvm::MutableArrayRef<typename ScalarArgumentT::value_type> value;
  std::vector<unsigned int> dimensions;
};
template <typename ScalarArgumentT>
char TensorLambdaArgument<ScalarArgumentT>::ID = 0;
} // namespace zamalang
} // namespace mlir
#endif