Files
concrete/compiler/include/zamalang/Support/CompilerEngine.h
Andi Drebes 1187cfbd62 refactor(compiler): Refactor CompilerEngine and related classes
This commit contains several incremental improvements towards a clear
interface for lambdas:

  - Unification of static and JIT compilation by using the static
    compilation path of `CompilerEngine` within a new subclass
    `JitCompilerEngine`.

  - Clear ownership for compilation artefacts through
    `CompilationContext`, making it impossible to destroy objects used
    directly or indirectly before destruction of their users.

  - Clear interface for lambdas generated by the compiler through
    `JitCompilerEngine::Lambda` with a templated call operator,
    encapsulating otherwise manual orchestration of `CompilerEngine`,
    `JITLambda`, and `CompilerEngine::Argument`.

  - Improved error handling through `llvm::Expected<T>` and proper
    error checking following the conventions for `llvm::Expected<T>`
    and `llvm::Error`.

Co-authored-by: youben11 <ayoub.benaissa@zama.ai>
2021-10-29 13:44:34 +02:00

140 lines
4.5 KiB
C++

#ifndef ZAMALANG_SUPPORT_COMPILER_ENGINE_H
#define ZAMALANG_SUPPORT_COMPILER_ENGINE_H
#include <llvm/IR/Module.h>
#include <llvm/Support/Error.h>
#include <llvm/Support/SourceMgr.h>
#include <mlir/IR/BuiltinOps.h>
#include <mlir/IR/MLIRContext.h>
#include <zamalang/Conversion/Utils/GlobalFHEContext.h>
#include <zamalang/Support/ClientParameters.h>
#include <zamalang/Support/KeySet.h>

#include <cstddef>
#include <memory>
#include <string>
#include <utility>
namespace mlir {
namespace zamalang {
// Compilation context that acts as the root owner of LLVM and MLIR
// data structures directly and indirectly referenced by artefacts
// produced by the `CompilerEngine`.
//
// Instances are meant to be shared (see `createShared()`); both
// `CompilerEngine` and `CompilerEngine::CompilationResult` hold a
// `std::shared_ptr` to their context so that it outlives its users.
class CompilationContext {
public:
  CompilationContext();
  ~CompilationContext();

  // The class holds raw pointers to the contexts it owns. A member-wise
  // copy would leave two objects claiming ownership of the same contexts,
  // so copying is disabled. Moving is not implicitly declared either,
  // because of the user-declared destructor. Share instances through
  // `createShared()` instead.
  CompilationContext(const CompilationContext &) = delete;
  CompilationContext &operator=(const CompilationContext &) = delete;

  // Returns the MLIR context held by this compilation context.
  mlir::MLIRContext *getMLIRContext();

  // Returns the LLVM context held by this compilation context.
  llvm::LLVMContext *getLLVMContext();

  // Creates a new compilation context managed by a `std::shared_ptr`.
  static std::shared_ptr<CompilationContext> createShared();

protected:
  // NOTE(review): assumed to be allocated by the constructor and released
  // by the destructor (both defined out of line) — confirm in the .cpp.
  mlir::MLIRContext *mlirContext;
  llvm::LLVMContext *llvmContext;
};
class CompilerEngine {
public:
  // Result of an invocation of the `CompilerEngine` with optional
  // fields for the results produced by different stages.
  class CompilationResult {
  public:
    // The result shares ownership of `compilationContext` so that the
    // LLVM/MLIR data structures referenced by the artefacts below stay
    // alive for as long as the result does. The by-value parameter is
    // moved into place to avoid an extra reference-count round trip.
    CompilationResult(std::shared_ptr<CompilationContext> compilationContext =
                          CompilationContext::createShared())
        : compilationContext(std::move(compilationContext)) {}

    // MLIR module produced by the compilation, if any.
    llvm::Optional<mlir::OwningModuleRef> mlirModuleRef;
    // Client parameters, if produced. NOTE(review): presumably only set
    // when requested via `setGenerateClientParameters(true)` — confirm.
    llvm::Optional<mlir::zamalang::ClientParameters> clientParameters;
    // Key set, if produced. NOTE(review): presumably only set when
    // requested via `setGenerateKeySet(true)` — confirm.
    std::unique_ptr<mlir::zamalang::KeySet> keySet;
    // LLVM IR module. NOTE(review): presumably only set for the
    // `LLVM_IR` and `OPTIMIZED_LLVM_IR` targets — confirm.
    std::unique_ptr<llvm::Module> llvmModule;
    // FHE parameters associated with the compilation, if any.
    llvm::Optional<mlir::zamalang::V0FHEContext> fheContext;

  protected:
    std::shared_ptr<CompilationContext> compilationContext;
  };

  // Specification of the exit stage of the compilation pipeline
  enum class Target {
    // Only read sources and produce corresponding MLIR module
    ROUND_TRIP,

    // Read sources and exit before any lowering
    HLFHE,

    // Read sources and attempt to run the Minimal Arithmetic Noise
    // Padding pass
    HLFHE_MANP,

    // Read sources and lower all HLFHE operations to MidLFHE
    // operations
    MIDLFHE,

    // Read sources and lower all HLFHE and MidLFHE operations to LowLFHE
    // operations
    LOWLFHE,

    // Read sources and lower all HLFHE, MidLFHE and LowLFHE
    // operations to canonical MLIR dialects. Cryptographic operations
    // are lowered to invocations of the concrete library.
    STD,

    // Read sources and lower all HLFHE, MidLFHE and LowLFHE
    // operations to operations from the LLVM dialect. Cryptographic
    // operations are lowered to invocations of the concrete library.
    LLVM,

    // Same as `LLVM`, but lowers to actual LLVM IR instead of the
    // LLVM dialect
    LLVM_IR,

    // Same as `LLVM_IR`, but invokes the LLVM optimization pipeline
    // to produce optimized LLVM IR
    OPTIMIZED_LLVM_IR
  };

  // Creates an engine that shares ownership of `compilationContext`.
  // `explicit` prevents accidental implicit conversion from a
  // `std::shared_ptr<CompilationContext>`. All other settings take the
  // defaults given by the in-class member initializers below.
  explicit CompilerEngine(
      std::shared_ptr<CompilationContext> compilationContext)
      : compilationContext(std::move(compilationContext)) {}

  // Compiles the sources contained in `s`, stopping the pipeline after
  // the stage designated by `target`.
  llvm::Expected<CompilationResult> compile(llvm::StringRef s, Target target);

  // Same as above, for sources held by `buffer`.
  llvm::Expected<CompilationResult>
  compile(std::unique_ptr<llvm::MemoryBuffer> buffer, Target target);

  // Same as above, for sources managed by the source manager `sm`.
  llvm::Expected<CompilationResult> compile(llvm::SourceMgr &sm, Target target);

  // Configuration setters.
  // NOTE(review): the implementations are out of line; each presumably
  // stores its argument in the like-named protected member below
  // (`setFHEConstraints` presumably feeds both override members) — confirm.
  void setFHEConstraints(const mlir::zamalang::V0FHEConstraint &c);
  void setMaxEintPrecision(size_t v);
  void setMaxMANP(size_t v);
  void setVerifyDiagnostics(bool v);
  void setGenerateKeySet(bool v);
  void setGenerateClientParameters(bool v);
  void setParametrizeMidLFHE(bool v);
  void setClientParametersFuncName(const llvm::StringRef &name);

protected:
  llvm::Optional<size_t> overrideMaxEintPrecision;
  llvm::Optional<size_t> overrideMaxMANP;
  llvm::Optional<std::string> clientParametersFuncName;
  bool verifyDiagnostics = false;
  bool generateKeySet = false;
  bool generateClientParameters = false;
  bool parametrizeMidLFHE = true;
  std::shared_ptr<CompilationContext> compilationContext;

  // Helper enum identifying an FHE dialect (`HLFHE`, `MIDLFHE`, `LOWLFHE`)
  // or indicating that no FHE dialect is used (`NONE`).
  enum class FHEDialect { HLFHE, MIDLFHE, LOWLFHE, NONE };

  // Returns the highest FHE dialect detected in `module`, or `NONE`.
  static FHEDialect detectHighestFHEDialect(mlir::ModuleOp module);

private:
  // NOTE(review): presumably runs the lowering steps that depend on the FHE
  // parameters, up to `target`, updating `res` — confirm in the .cpp.
  llvm::Error lowerParamDependentHalf(Target target, CompilationResult &res);

  // NOTE(review): presumably fills `res.fheContext`; `noOverride`
  // presumably disables the `overrideMax*` settings — confirm in the .cpp.
  llvm::Error determineFHEParameters(CompilationResult &res, bool noOverride);
};
} // namespace zamalang
} // namespace mlir
#endif