mirror of
https://github.com/zama-ai/concrete.git
synced 2026-04-17 03:00:54 -04:00
Merge branch 'master' into hlfhelinalg-binary-op-lowering
This commit is contained in:
@@ -5,15 +5,17 @@
|
||||
#include "mlir-c/Registration.h"
|
||||
#include "zamalang/Support/CompilerEngine.h"
|
||||
#include "zamalang/Support/ExecutionArgument.h"
|
||||
#include "zamalang/Support/Jit.h"
|
||||
#include "zamalang/Support/JitCompilerEngine.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
struct compilerEngine {
|
||||
mlir::zamalang::CompilerEngine *ptr;
|
||||
struct lambda {
|
||||
mlir::zamalang::JitCompilerEngine::Lambda *ptr;
|
||||
};
|
||||
typedef struct compilerEngine compilerEngine;
|
||||
typedef struct lambda lambda;
|
||||
|
||||
struct executionArguments {
|
||||
mlir::zamalang::ExecutionArgument *data;
|
||||
@@ -21,13 +23,12 @@ struct executionArguments {
|
||||
};
|
||||
typedef struct executionArguments exectuionArguments;
|
||||
|
||||
// Compile an MLIR module
|
||||
MLIR_CAPI_EXPORTED void compilerEngineCompile(compilerEngine engine,
|
||||
const char *module);
|
||||
MLIR_CAPI_EXPORTED mlir::zamalang::JitCompilerEngine::Lambda
|
||||
buildLambda(const char *module, const char *funcName);
|
||||
|
||||
// Run the compiled module
|
||||
MLIR_CAPI_EXPORTED uint64_t compilerEngineRun(compilerEngine e,
|
||||
executionArguments args);
|
||||
MLIR_CAPI_EXPORTED uint64_t invokeLambda(lambda l, executionArguments args);
|
||||
|
||||
MLIR_CAPI_EXPORTED std::string roundTrip(const char *module);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
|
||||
@@ -1,49 +1,128 @@
|
||||
#ifndef ZAMALANG_SUPPORT_COMPILER_ENGINE_H
|
||||
#define ZAMALANG_SUPPORT_COMPILER_ENGINE_H
|
||||
|
||||
#include "Jit.h"
|
||||
#include <llvm/IR/Module.h>
|
||||
#include <llvm/Support/Error.h>
|
||||
#include <llvm/Support/SourceMgr.h>
|
||||
#include <mlir/IR/BuiltinOps.h>
|
||||
#include <mlir/IR/MLIRContext.h>
|
||||
#include <mlir/Pass/Pass.h>
|
||||
#include <zamalang/Conversion/Utils/GlobalFHEContext.h>
|
||||
#include <zamalang/Support/ClientParameters.h>
|
||||
|
||||
namespace mlir {
|
||||
namespace zamalang {
|
||||
|
||||
/// CompilerEngine is an tools that provides tools to implements the compilation
|
||||
/// flow and manage the compilation flow state.
|
||||
// Compilation context that acts as the root owner of LLVM and MLIR
|
||||
// data structures directly and indirectly referenced by artefacts
|
||||
// produced by the `CompilerEngine`.
|
||||
class CompilationContext {
|
||||
public:
|
||||
CompilationContext();
|
||||
~CompilationContext();
|
||||
|
||||
mlir::MLIRContext *getMLIRContext();
|
||||
llvm::LLVMContext *getLLVMContext();
|
||||
|
||||
static std::shared_ptr<CompilationContext> createShared();
|
||||
|
||||
protected:
|
||||
mlir::MLIRContext *mlirContext;
|
||||
llvm::LLVMContext *llvmContext;
|
||||
};
|
||||
|
||||
class CompilerEngine {
|
||||
public:
|
||||
CompilerEngine() {
|
||||
context = new mlir::MLIRContext();
|
||||
loadDialects();
|
||||
}
|
||||
~CompilerEngine() {
|
||||
if (context != nullptr)
|
||||
delete context;
|
||||
}
|
||||
// Result of an invocation of the `CompilerEngine` with optional
|
||||
// fields for the results produced by different stages.
|
||||
class CompilationResult {
|
||||
public:
|
||||
CompilationResult(std::shared_ptr<CompilationContext> compilationContext =
|
||||
CompilationContext::createShared())
|
||||
: compilationContext(compilationContext) {}
|
||||
|
||||
// Compile an mlir programs from it's textual representation.
|
||||
llvm::Error compile(
|
||||
std::string mlirStr,
|
||||
llvm::Optional<mlir::zamalang::V0FHEConstraint> overrideConstraints = {});
|
||||
llvm::Optional<mlir::OwningModuleRef> mlirModuleRef;
|
||||
llvm::Optional<mlir::zamalang::ClientParameters> clientParameters;
|
||||
std::unique_ptr<llvm::Module> llvmModule;
|
||||
llvm::Optional<mlir::zamalang::V0FHEContext> fheContext;
|
||||
|
||||
// Build the jit lambda argument.
|
||||
llvm::Expected<std::unique_ptr<JITLambda::Argument>> buildArgument();
|
||||
protected:
|
||||
std::shared_ptr<CompilationContext> compilationContext;
|
||||
};
|
||||
|
||||
// Call the compiled function with and argument object.
|
||||
llvm::Error invoke(JITLambda::Argument &arg);
|
||||
// Specification of the exit stage of the compilation pipeline
|
||||
enum class Target {
|
||||
// Only read sources and produce corresponding MLIR module
|
||||
ROUND_TRIP,
|
||||
|
||||
// Call the compiled function with a list of integer arguments.
|
||||
llvm::Expected<uint64_t> run(std::vector<uint64_t> args);
|
||||
// Read sources and exit before any lowering
|
||||
HLFHE,
|
||||
|
||||
// Get a printable representation of the compiled module
|
||||
std::string getCompiledModule();
|
||||
// Read sources and lower all HLFHE operations to MidLFHE
|
||||
// operations
|
||||
MIDLFHE,
|
||||
|
||||
// Read sources and lower all HLFHE and MidLFHE operations to LowLFHE
|
||||
// operations
|
||||
LOWLFHE,
|
||||
|
||||
// Read sources and lower all HLFHE, MidLFHE and LowLFHE
|
||||
// operations to canonical MLIR dialects. Cryptographic operations
|
||||
// are lowered to invocations of the concrete library.
|
||||
STD,
|
||||
|
||||
// Read sources and lower all HLFHE, MidLFHE and LowLFHE
|
||||
// operations to operations from the LLVM dialect. Cryptographic
|
||||
// operations are lowered to invocations of the concrete library.
|
||||
LLVM,
|
||||
|
||||
// Same as `LLVM`, but lowers to actual LLVM IR instead of the
|
||||
// LLVM dialect
|
||||
LLVM_IR,
|
||||
|
||||
// Same as `LLVM_IR`, but invokes the LLVM optimization pipeline
|
||||
// to produce optimized LLVM IR
|
||||
OPTIMIZED_LLVM_IR
|
||||
};
|
||||
|
||||
CompilerEngine(std::shared_ptr<CompilationContext> compilationContext)
|
||||
: overrideMaxEintPrecision(), overrideMaxMANP(),
|
||||
clientParametersFuncName(), verifyDiagnostics(false),
|
||||
generateClientParameters(false),
|
||||
enablePass([](mlir::Pass *pass) { return true; }),
|
||||
compilationContext(compilationContext) {}
|
||||
|
||||
llvm::Expected<CompilationResult> compile(llvm::StringRef s, Target target);
|
||||
|
||||
llvm::Expected<CompilationResult>
|
||||
compile(std::unique_ptr<llvm::MemoryBuffer> buffer, Target target);
|
||||
|
||||
llvm::Expected<CompilationResult> compile(llvm::SourceMgr &sm, Target target);
|
||||
|
||||
void setFHEConstraints(const mlir::zamalang::V0FHEConstraint &c);
|
||||
void setMaxEintPrecision(size_t v);
|
||||
void setMaxMANP(size_t v);
|
||||
void setVerifyDiagnostics(bool v);
|
||||
void setGenerateClientParameters(bool v);
|
||||
void setClientParametersFuncName(const llvm::StringRef &name);
|
||||
void setEnablePass(std::function<bool(mlir::Pass *)> enablePass);
|
||||
|
||||
protected:
|
||||
llvm::Optional<size_t> overrideMaxEintPrecision;
|
||||
llvm::Optional<size_t> overrideMaxMANP;
|
||||
llvm::Optional<std::string> clientParametersFuncName;
|
||||
bool verifyDiagnostics;
|
||||
bool generateClientParameters;
|
||||
std::function<bool(mlir::Pass *)> enablePass;
|
||||
|
||||
std::shared_ptr<CompilationContext> compilationContext;
|
||||
|
||||
private:
|
||||
// Load the necessary dialects into the engine's context
|
||||
void loadDialects();
|
||||
|
||||
mlir::OwningModuleRef module_ref;
|
||||
mlir::MLIRContext *context;
|
||||
std::unique_ptr<mlir::zamalang::KeySet> keySet;
|
||||
llvm::Expected<llvm::Optional<mlir::zamalang::V0FHEConstraint>>
|
||||
getV0FHEConstraint(CompilationResult &res);
|
||||
llvm::Error determineFHEParameters(CompilationResult &res);
|
||||
};
|
||||
|
||||
} // namespace zamalang
|
||||
} // namespace mlir
|
||||
|
||||
|
||||
53
compiler/include/zamalang/Support/Error.h
Normal file
53
compiler/include/zamalang/Support/Error.h
Normal file
@@ -0,0 +1,53 @@
|
||||
#ifndef ZAMALANG_SUPPORT_STRING_ERROR_H
|
||||
#define ZAMALANG_SUPPORT_STRING_ERROR_H
|
||||
|
||||
#include <llvm/Support/Error.h>
|
||||
|
||||
namespace mlir {
|
||||
namespace zamalang {
|
||||
|
||||
// Internal error class that allows for composing `llvm::Error`s
|
||||
// similar to `llvm::createStringError()`, but using stream-like
|
||||
// composition with `operator<<`.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// llvm::Error foo(int i, size_t s, ...) {
|
||||
// ...
|
||||
// if(...) {
|
||||
// return StreamStringError()
|
||||
// << "Some error message with an integer: "
|
||||
// << i << " and a size_t: " << s;
|
||||
// }
|
||||
// ...
|
||||
// }
|
||||
class StreamStringError {
|
||||
public:
|
||||
StreamStringError(const llvm::StringRef &s) : buffer(s.str()), os(buffer){};
|
||||
StreamStringError() : buffer(""), os(buffer){};
|
||||
|
||||
template <typename T> StreamStringError &operator<<(const T &v) {
|
||||
this->os << v;
|
||||
return *this;
|
||||
}
|
||||
|
||||
operator llvm::Error() {
|
||||
return llvm::make_error<llvm::StringError>(os.str(),
|
||||
llvm::inconvertibleErrorCode());
|
||||
}
|
||||
|
||||
template <typename T> operator llvm::Expected<T>() {
|
||||
return this->operator llvm::Error();
|
||||
}
|
||||
|
||||
protected:
|
||||
std::string buffer;
|
||||
llvm::raw_string_ostream os;
|
||||
};
|
||||
|
||||
StreamStringError &operator<<(StreamStringError &se, llvm::Error &err);
|
||||
|
||||
} // namespace zamalang
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
||||
@@ -9,11 +9,6 @@
|
||||
|
||||
namespace mlir {
|
||||
namespace zamalang {
|
||||
mlir::LogicalResult
|
||||
runJit(mlir::ModuleOp module, llvm::StringRef func,
|
||||
llvm::ArrayRef<uint64_t> funcArgs, mlir::zamalang::KeySet &keySet,
|
||||
std::function<llvm::Error(llvm::Module *)> optPipeline,
|
||||
llvm::raw_ostream &os);
|
||||
|
||||
/// JITLambda is a tool to JIT compile an mlir module and to invoke a function
|
||||
/// of the module.
|
||||
@@ -53,6 +48,10 @@ public:
|
||||
// - or the size of the `res` buffser doesn't match the size of the tensor.
|
||||
llvm::Error getResult(size_t pos, uint64_t *res, size_t size);
|
||||
|
||||
// Returns the number of elements of the result vector at position
|
||||
// `pos` or an error if the result is a scalar value
|
||||
llvm::Expected<size_t> getResultVectorSize(size_t pos);
|
||||
|
||||
private:
|
||||
llvm::Error setArg(size_t pos, size_t width, void *data,
|
||||
llvm::ArrayRef<int64_t> shape);
|
||||
@@ -97,7 +96,7 @@ public:
|
||||
|
||||
private:
|
||||
mlir::LLVM::LLVMFunctionType type;
|
||||
llvm::StringRef name;
|
||||
std::string name;
|
||||
std::unique_ptr<mlir::ExecutionEngine> engine;
|
||||
};
|
||||
|
||||
|
||||
292
compiler/include/zamalang/Support/JitCompilerEngine.h
Normal file
292
compiler/include/zamalang/Support/JitCompilerEngine.h
Normal file
@@ -0,0 +1,292 @@
|
||||
#ifndef ZAMALANG_SUPPORT_JIT_COMPILER_ENGINE_H
|
||||
#define ZAMALANG_SUPPORT_JIT_COMPILER_ENGINE_H
|
||||
|
||||
#include <mlir/Dialect/LLVMIR/LLVMDialect.h>
|
||||
#include <zamalang/Support/CompilerEngine.h>
|
||||
#include <zamalang/Support/Error.h>
|
||||
#include <zamalang/Support/Jit.h>
|
||||
#include <zamalang/Support/LambdaArgument.h>
|
||||
|
||||
namespace mlir {
|
||||
namespace zamalang {
|
||||
|
||||
namespace {
|
||||
// Generic function template as well as specializations of
|
||||
// `typedResult` must be declared at namespace scope due to return
|
||||
// type template specialization
|
||||
|
||||
// Helper function for `JitCompilerEngine::Lambda::operator()`
|
||||
// implementing type-dependent preparation of the result.
|
||||
template <typename ResT>
|
||||
llvm::Expected<ResT> typedResult(JITLambda::Argument &arguments);
|
||||
|
||||
// Specialization of `typedResult()` for scalar results, forwarding
|
||||
// scalar value to caller
|
||||
template <>
|
||||
inline llvm::Expected<uint64_t> typedResult(JITLambda::Argument &arguments) {
|
||||
uint64_t res = 0;
|
||||
|
||||
if (auto err = arguments.getResult(0, res))
|
||||
return StreamStringError() << "Cannot retrieve result:" << err;
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
// Specialization of `typedResult()` for vector results, initializing
|
||||
// an `std::vector` of the right size with the results and forwarding
|
||||
// it to the caller with move semantics.
|
||||
template <>
|
||||
inline llvm::Expected<std::vector<uint64_t>>
|
||||
typedResult(JITLambda::Argument &arguments) {
|
||||
llvm::Expected<size_t> n = arguments.getResultVectorSize(0);
|
||||
|
||||
if (auto err = n.takeError())
|
||||
return std::move(err);
|
||||
|
||||
std::vector<uint64_t> res(*n);
|
||||
|
||||
if (auto err = arguments.getResult(0, res.data(), res.size()))
|
||||
return StreamStringError() << "Cannot retrieve result:" << err;
|
||||
|
||||
return std::move(res);
|
||||
}
|
||||
|
||||
// Adaptor class that adds arguments specified as instances of
|
||||
// `LambdaArgument` to `JitLambda::Argument`.
|
||||
class JITLambdaArgumentAdaptor {
|
||||
public:
|
||||
// Checks if the argument `arg` is an plaintext / encrypted integer
|
||||
// argument or a plaintext / encrypted tensor argument with a
|
||||
// backing integer type `IntT` and adds the argument to `jla` at
|
||||
// position `pos`.
|
||||
//
|
||||
// Returns `true` if `arg` has one of the types above and its value
|
||||
// was successfully added to `jla`, `false` if none of the types
|
||||
// matches or an error if a type matched, but adding the argument to
|
||||
// `jla` failed.
|
||||
template <typename IntT>
|
||||
static inline llvm::Expected<bool>
|
||||
tryAddArg(JITLambda::Argument &jla, size_t pos, const LambdaArgument &arg) {
|
||||
if (auto ila = arg.dyn_cast<IntLambdaArgument<IntT>>()) {
|
||||
if (llvm::Error err = jla.setArg(pos, ila->getValue()))
|
||||
return std::move(err);
|
||||
else
|
||||
return true;
|
||||
} else if (auto tla = arg.dyn_cast<
|
||||
TensorLambdaArgument<IntLambdaArgument<IntT>>>()) {
|
||||
if (llvm::Error err =
|
||||
jla.setArg(pos, tla->getValue(), tla->getDimensions()))
|
||||
return std::move(err);
|
||||
else
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
// Recursive case for `tryAddArg<IntT>(...)`
|
||||
template <typename IntT, typename NextIntT, typename... IntTs>
|
||||
static inline llvm::Expected<bool>
|
||||
tryAddArg(JITLambda::Argument &jla, size_t pos, const LambdaArgument &arg) {
|
||||
llvm::Expected<bool> successOrError = tryAddArg<IntT>(jla, pos, arg);
|
||||
|
||||
if (!successOrError)
|
||||
return std::move(successOrError.takeError());
|
||||
|
||||
if (successOrError.get() == false)
|
||||
return tryAddArg<NextIntT, IntTs...>(jla, pos, arg);
|
||||
else
|
||||
return true;
|
||||
}
|
||||
|
||||
// Attempts to add a single argument `arg` to `jla` at position
|
||||
// `pos`. Returns an error if either the argument type is
|
||||
// unsupported or if the argument types is supported, but adding it
|
||||
// to `jla` failed.
|
||||
static inline llvm::Error addArgument(JITLambda::Argument &jla, size_t pos,
|
||||
const LambdaArgument &arg) {
|
||||
llvm::Expected<bool> successOrError =
|
||||
JITLambdaArgumentAdaptor::tryAddArg<uint64_t, uint32_t, uint16_t,
|
||||
uint8_t>(jla, pos, arg);
|
||||
|
||||
if (!successOrError)
|
||||
return std::move(successOrError.takeError());
|
||||
|
||||
if (successOrError.get() == false)
|
||||
return StreamStringError("Unknown argument type");
|
||||
else
|
||||
return llvm::Error::success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
// A compiler engine that JIT-compiles a source and produces a lambda
|
||||
// object directly invocable through its call operator.
|
||||
class JitCompilerEngine : public CompilerEngine {
|
||||
public:
|
||||
// Wrapper class around `JITLambda` and `JITLambda::Argument` that
|
||||
// allows for direct invocation of a compiled function through
|
||||
// `operator ()`.
|
||||
class Lambda {
|
||||
public:
|
||||
Lambda(Lambda &&other)
|
||||
: innerLambda(std::move(other.innerLambda)),
|
||||
keySet(std::move(other.keySet)),
|
||||
compilationContext(other.compilationContext) {}
|
||||
|
||||
Lambda(std::shared_ptr<CompilationContext> compilationContext,
|
||||
std::unique_ptr<JITLambda> lambda, std::unique_ptr<KeySet> keySet)
|
||||
: innerLambda(std::move(lambda)), keySet(std::move(keySet)),
|
||||
compilationContext(compilationContext) {}
|
||||
|
||||
// Returns the number of arguments required for an invocation of
|
||||
// the lambda
|
||||
size_t getNumArguments() { return this->keySet->numInputs(); }
|
||||
|
||||
// Returns the number of results an invocation of the lambda
|
||||
// produces
|
||||
size_t getNumResults() { return this->keySet->numOutputs(); }
|
||||
|
||||
// Invocation with an dynamic list of arguments of different
|
||||
// types, specified as `LambdaArgument`s
|
||||
template <typename ResT = uint64_t>
|
||||
llvm::Expected<ResT>
|
||||
operator()(llvm::ArrayRef<LambdaArgument *> lambdaArgs) {
|
||||
// Create the arguments of the JIT lambda
|
||||
llvm::Expected<std::unique_ptr<JITLambda::Argument>> argsOrErr =
|
||||
mlir::zamalang::JITLambda::Argument::create(*this->keySet.get());
|
||||
|
||||
if (llvm::Error err = argsOrErr.takeError())
|
||||
return StreamStringError("Could not create lambda arguments");
|
||||
|
||||
// Set the arguments
|
||||
std::unique_ptr<JITLambda::Argument> arguments =
|
||||
std::move(argsOrErr.get());
|
||||
|
||||
for (size_t i = 0; i < lambdaArgs.size(); i++) {
|
||||
if (llvm::Error err = JITLambdaArgumentAdaptor::addArgument(
|
||||
*arguments, i, *lambdaArgs[i])) {
|
||||
return std::move(err);
|
||||
}
|
||||
}
|
||||
|
||||
// Invoke the lambda
|
||||
if (auto err = this->innerLambda->invoke(*arguments))
|
||||
return StreamStringError() << "Cannot invoke lambda:" << err;
|
||||
|
||||
return std::move(typedResult<ResT>(*arguments));
|
||||
}
|
||||
|
||||
// Invocation with an array of arguments of the same type
|
||||
template <typename T, typename ResT = uint64_t>
|
||||
llvm::Expected<ResT> operator()(const llvm::ArrayRef<T> args) {
|
||||
// Create the arguments of the JIT lambda
|
||||
llvm::Expected<std::unique_ptr<JITLambda::Argument>> argsOrErr =
|
||||
mlir::zamalang::JITLambda::Argument::create(*this->keySet.get());
|
||||
|
||||
if (llvm::Error err = argsOrErr.takeError())
|
||||
return StreamStringError("Could not create lambda arguments");
|
||||
|
||||
// Set the arguments
|
||||
std::unique_ptr<JITLambda::Argument> arguments =
|
||||
std::move(argsOrErr.get());
|
||||
|
||||
for (size_t i = 0; i < args.size(); i++) {
|
||||
if (auto err = arguments->setArg(i, args[i])) {
|
||||
return StreamStringError()
|
||||
<< "Cannot push argument " << i << ": " << err;
|
||||
}
|
||||
}
|
||||
|
||||
// Invoke the lambda
|
||||
if (auto err = this->innerLambda->invoke(*arguments))
|
||||
return StreamStringError() << "Cannot invoke lambda:" << err;
|
||||
|
||||
return std::move(typedResult<ResT>(*arguments));
|
||||
}
|
||||
|
||||
// Invocation with arguments of different types
|
||||
template <typename ResT = uint64_t, typename... Ts>
|
||||
llvm::Expected<ResT> operator()(const Ts... ts) {
|
||||
// Create the arguments of the JIT lambda
|
||||
llvm::Expected<std::unique_ptr<JITLambda::Argument>> argsOrErr =
|
||||
mlir::zamalang::JITLambda::Argument::create(*this->keySet.get());
|
||||
|
||||
if (llvm::Error err = argsOrErr.takeError())
|
||||
return StreamStringError("Could not create lambda arguments");
|
||||
|
||||
// Set the arguments
|
||||
std::unique_ptr<JITLambda::Argument> arguments =
|
||||
std::move(argsOrErr.get());
|
||||
|
||||
if (llvm::Error err = this->addArgs<0>(arguments.get(), ts...))
|
||||
return std::move(err);
|
||||
|
||||
// Invoke the lambda
|
||||
if (auto err = this->innerLambda->invoke(*arguments))
|
||||
return StreamStringError() << "Cannot invoke lambda:" << err;
|
||||
|
||||
return std::move(typedResult<ResT>(*arguments));
|
||||
}
|
||||
|
||||
protected:
|
||||
template <int pos>
|
||||
inline llvm::Error addArgs(JITLambda::Argument *jitArgs) {
|
||||
// base case -- nothing to do
|
||||
return llvm::Error::success();
|
||||
}
|
||||
|
||||
// Recursive case for scalars: extract first scalar argument from
|
||||
// parameter pack and forward rest
|
||||
template <int pos, typename ArgT, typename... Ts>
|
||||
inline llvm::Error addArgs(JITLambda::Argument *jitArgs, ArgT arg,
|
||||
Ts... remainder) {
|
||||
if (auto err = jitArgs->setArg(pos, arg)) {
|
||||
return StreamStringError()
|
||||
<< "Cannot push scalar argument " << pos << ": " << err;
|
||||
}
|
||||
|
||||
return this->addArgs<pos + 1>(jitArgs, remainder...);
|
||||
}
|
||||
|
||||
// Recursive case for tensors: extract pointer and size from
|
||||
// parameter pack and forward rest
|
||||
template <int pos, typename ArgT, typename... Ts>
|
||||
inline llvm::Error addArgs(JITLambda::Argument *jitArgs, ArgT *arg,
|
||||
size_t size, Ts... remainder) {
|
||||
if (auto err = jitArgs->setArg(pos, arg, size)) {
|
||||
return StreamStringError()
|
||||
<< "Cannot push tensor argument " << pos << ": " << err;
|
||||
}
|
||||
|
||||
return this->addArgs<pos + 1>(jitArgs, remainder...);
|
||||
}
|
||||
|
||||
std::unique_ptr<JITLambda> innerLambda;
|
||||
std::unique_ptr<KeySet> keySet;
|
||||
std::shared_ptr<CompilationContext> compilationContext;
|
||||
};
|
||||
|
||||
JitCompilerEngine(std::shared_ptr<CompilationContext> compilationContext =
|
||||
CompilationContext::createShared(),
|
||||
unsigned int optimizationLevel = 3);
|
||||
|
||||
llvm::Expected<Lambda> buildLambda(llvm::StringRef src,
|
||||
llvm::StringRef funcName = "main");
|
||||
|
||||
llvm::Expected<Lambda> buildLambda(std::unique_ptr<llvm::MemoryBuffer> buffer,
|
||||
llvm::StringRef funcName = "main");
|
||||
|
||||
llvm::Expected<Lambda> buildLambda(llvm::SourceMgr &sm,
|
||||
llvm::StringRef funcName = "main");
|
||||
|
||||
protected:
|
||||
llvm::Expected<mlir::LLVM::LLVMFuncOp> findLLVMFuncOp(mlir::ModuleOp module,
|
||||
llvm::StringRef name);
|
||||
unsigned int optimizationLevel;
|
||||
};
|
||||
|
||||
} // namespace zamalang
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
||||
170
compiler/include/zamalang/Support/LambdaArgument.h
Normal file
170
compiler/include/zamalang/Support/LambdaArgument.h
Normal file
@@ -0,0 +1,170 @@
|
||||
#ifndef ZAMALANG_SUPPORT_LAMBDA_ARGUMENT_H
|
||||
#define ZAMALANG_SUPPORT_LAMBDA_ARGUMENT_H
|
||||
|
||||
#include <cstdint>
|
||||
#include <limits>
|
||||
|
||||
#include <llvm/ADT/ArrayRef.h>
|
||||
#include <llvm/Support/Casting.h>
|
||||
#include <llvm/Support/ExtensibleRTTI.h>
|
||||
#include <zamalang/Support/Error.h>
|
||||
|
||||
namespace mlir {
|
||||
namespace zamalang {
|
||||
|
||||
// Abstract base class for lambda arguments
|
||||
class LambdaArgument
|
||||
: public llvm::RTTIExtends<LambdaArgument, llvm::RTTIRoot> {
|
||||
public:
|
||||
LambdaArgument(LambdaArgument &) = delete;
|
||||
|
||||
template <typename T> bool isa() const { return llvm::isa<T>(*this); }
|
||||
|
||||
// Cast functions on constant instances
|
||||
template <typename T> const T &cast() const { return llvm::cast<T>(*this); }
|
||||
template <typename T> const T *dyn_cast() const {
|
||||
return llvm::dyn_cast<T>(this);
|
||||
}
|
||||
|
||||
// Cast functions for mutable instances
|
||||
template <typename T> T &cast() { return llvm::cast<T>(*this); }
|
||||
template <typename T> T *dyn_cast() { return llvm::dyn_cast<T>(this); }
|
||||
|
||||
static char ID;
|
||||
|
||||
protected:
|
||||
LambdaArgument(){};
|
||||
};
|
||||
|
||||
// Class for integer arguments. `BackingIntType` is used as the data
|
||||
// type to hold the argument's value. The precision is the actual
|
||||
// precision of the value, which might be different from the precision
|
||||
// of the backing integer type.
|
||||
template <typename BackingIntType = uint64_t>
|
||||
class IntLambdaArgument
|
||||
: public llvm::RTTIExtends<IntLambdaArgument<BackingIntType>,
|
||||
LambdaArgument> {
|
||||
public:
|
||||
typedef BackingIntType value_type;
|
||||
|
||||
IntLambdaArgument(BackingIntType value,
|
||||
unsigned int precision = 8 * sizeof(BackingIntType))
|
||||
: precision(precision) {
|
||||
if (precision < 8 * sizeof(BackingIntType)) {
|
||||
this->value = value & (1 << (this->precision - 1));
|
||||
} else {
|
||||
this->value = value;
|
||||
}
|
||||
}
|
||||
|
||||
unsigned int getPrecision() const { return this->precision; }
|
||||
BackingIntType getValue() const { return this->value; }
|
||||
|
||||
static char ID;
|
||||
|
||||
protected:
|
||||
unsigned int precision;
|
||||
BackingIntType value;
|
||||
};
|
||||
|
||||
template <typename BackingIntType>
|
||||
char IntLambdaArgument<BackingIntType>::ID = 0;
|
||||
|
||||
// Class for encrypted integer arguments. `BackingIntType` is used as
|
||||
// the data type to hold the argument's plaintext value. The precision
|
||||
// is the actual precision of the value, which might be different from
|
||||
// the precision of the backing integer type.
|
||||
template <typename BackingIntType = uint64_t>
|
||||
class EIntLambdaArgument
|
||||
: public llvm::RTTIExtends<EIntLambdaArgument<BackingIntType>,
|
||||
IntLambdaArgument<BackingIntType>> {
|
||||
public:
|
||||
static char ID;
|
||||
};
|
||||
|
||||
template <typename BackingIntType>
|
||||
char EIntLambdaArgument<BackingIntType>::ID = 0;
|
||||
|
||||
namespace {
|
||||
// Calculates `accu *= factor` or returns an error if the result
|
||||
// would overflow
|
||||
template <typename AccuT, typename ValT>
|
||||
llvm::Error safeUnsignedMul(AccuT &accu, ValT factor) {
|
||||
static_assert(std::numeric_limits<AccuT>::is_integer &&
|
||||
std::numeric_limits<ValT>::is_integer &&
|
||||
!std::numeric_limits<AccuT>::is_signed &&
|
||||
!std::numeric_limits<ValT>::is_signed,
|
||||
"Only unsigned integers are supported");
|
||||
|
||||
const AccuT left = std::numeric_limits<AccuT>::max() / accu;
|
||||
|
||||
if (left > factor) {
|
||||
accu *= factor;
|
||||
return llvm::Error::success();
|
||||
}
|
||||
|
||||
return StreamStringError("Multiplying value ")
|
||||
<< accu << " with " << factor << " would cause an overflow";
|
||||
}
|
||||
} // namespace
|
||||
|
||||
// Class for Tensor arguments. This can either be plaintext tensors
|
||||
// (for `ScalarArgumentT = IntLambaArgument<T>`) or tensors
|
||||
// representing encrypted integers (for `ScalarArgumentT =
|
||||
// EIntLambaArgument<T>`).
|
||||
template <typename ScalarArgumentT>
|
||||
class TensorLambdaArgument
|
||||
: public llvm::RTTIExtends<TensorLambdaArgument<ScalarArgumentT>,
|
||||
LambdaArgument> {
|
||||
public:
|
||||
typedef ScalarArgumentT scalar_type;
|
||||
|
||||
// Construct tensor argument from the one-dimensional array `value`,
|
||||
// but interpreting the array's values as a linearized
|
||||
// multi-dimensional tensor with the sizes of the dimensions
|
||||
// specified in `dimensions`.
|
||||
TensorLambdaArgument(
|
||||
llvm::MutableArrayRef<typename ScalarArgumentT::value_type> value,
|
||||
llvm::ArrayRef<int64_t> dimensions)
|
||||
: value(value), dimensions(dimensions.vec()) {}
|
||||
|
||||
// Construct a one-dimensional tensor argument from the
|
||||
// array `value`.
|
||||
TensorLambdaArgument(
|
||||
llvm::MutableArrayRef<typename ScalarArgumentT::value_type> value)
|
||||
: TensorLambdaArgument(value, {(int64_t)value.size()}) {}
|
||||
|
||||
const std::vector<int64_t> &getDimensions() const { return this->dimensions; }
|
||||
|
||||
// Returns the total number of elements in the tensor. If the number
|
||||
// of elements cannot be represented as a `size_t`, the method
|
||||
// returns an error.
|
||||
llvm::Expected<size_t> getNumElements() const {
|
||||
size_t accu = 1;
|
||||
|
||||
for (unsigned int dimSize : dimensions)
|
||||
if (llvm::Error err = safeUnsignedMul(accu, dimSize))
|
||||
return std::move(err);
|
||||
|
||||
return accu;
|
||||
}
|
||||
|
||||
// Returns a bare pointer to the linearized values of the tensor.
|
||||
typename ScalarArgumentT::value_type *getValue() const {
|
||||
return this->value.data();
|
||||
}
|
||||
|
||||
static char ID;
|
||||
|
||||
protected:
|
||||
llvm::MutableArrayRef<typename ScalarArgumentT::value_type> value;
|
||||
std::vector<int64_t> dimensions;
|
||||
};
|
||||
|
||||
template <typename ScalarArgumentT>
|
||||
char TensorLambdaArgument<ScalarArgumentT>::ID = 0;
|
||||
|
||||
} // namespace zamalang
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
||||
@@ -4,43 +4,43 @@
|
||||
#include <llvm/IR/Module.h>
|
||||
#include <mlir/Dialect/LLVMIR/LLVMTypes.h>
|
||||
#include <mlir/Support/LogicalResult.h>
|
||||
#include <mlir/Transforms/Passes.h>
|
||||
|
||||
#include <zamalang/Support/V0Parameters.h>
|
||||
|
||||
namespace mlir {
|
||||
namespace zamalang {
|
||||
namespace pipeline {
|
||||
|
||||
mlir::LogicalResult invokeMANPPass(mlir::MLIRContext &context,
|
||||
mlir::ModuleOp &module, bool debug);
|
||||
|
||||
llvm::Expected<llvm::Optional<mlir::zamalang::V0FHEConstraint>>
|
||||
getFHEConstraintsFromHLFHE(mlir::MLIRContext &context, mlir::ModuleOp &module);
|
||||
getFHEConstraintsFromHLFHE(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
std::function<bool(mlir::Pass *)> enablePass);
|
||||
|
||||
mlir::LogicalResult lowerHLFHEToMidLFHE(mlir::MLIRContext &context,
|
||||
mlir::ModuleOp &module, bool verbose);
|
||||
mlir::LogicalResult
|
||||
lowerHLFHEToMidLFHE(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
std::function<bool(mlir::Pass *)> enablePass);
|
||||
|
||||
mlir::LogicalResult lowerMidLFHEToLowLFHE(mlir::MLIRContext &context,
|
||||
mlir::ModuleOp &module,
|
||||
V0FHEContext &fheContext,
|
||||
bool parametrize);
|
||||
mlir::LogicalResult
|
||||
lowerMidLFHEToLowLFHE(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
llvm::Optional<V0FHEContext> &fheContext,
|
||||
std::function<bool(mlir::Pass *)> enablePass);
|
||||
|
||||
mlir::LogicalResult lowerLowLFHEToStd(mlir::MLIRContext &context,
|
||||
mlir::ModuleOp &module);
|
||||
mlir::LogicalResult
|
||||
lowerLowLFHEToStd(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
std::function<bool(mlir::Pass *)> enablePass);
|
||||
|
||||
mlir::LogicalResult lowerStdToLLVMDialect(mlir::MLIRContext &context,
|
||||
mlir::ModuleOp &module, bool verbose);
|
||||
mlir::LogicalResult
|
||||
lowerStdToLLVMDialect(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
std::function<bool(mlir::Pass *)> enablePass);
|
||||
|
||||
mlir::LogicalResult optimizeLLVMModule(llvm::LLVMContext &llvmContext,
|
||||
llvm::Module &module);
|
||||
|
||||
mlir::LogicalResult lowerHLFHEToStd(mlir::MLIRContext &context,
|
||||
mlir::ModuleOp &module,
|
||||
V0FHEContext &fheContext, bool verbose);
|
||||
|
||||
std::unique_ptr<llvm::Module>
|
||||
lowerLLVMDialectToLLVMIR(mlir::MLIRContext &context,
|
||||
llvm::LLVMContext &llvmContext,
|
||||
mlir::ModuleOp &module);
|
||||
|
||||
} // namespace pipeline
|
||||
} // namespace zamalang
|
||||
} // namespace mlir
|
||||
|
||||
@@ -32,6 +32,7 @@ private:
|
||||
StreamWrap<llvm::raw_ostream> &log_error(void);
|
||||
StreamWrap<llvm::raw_ostream> &log_verbose(void);
|
||||
void setupLogging(bool verbose);
|
||||
bool isVerbose();
|
||||
} // namespace zamalang
|
||||
} // namespace mlir
|
||||
|
||||
|
||||
Reference in New Issue
Block a user