refactor(compiler): Refactor CompilerEngine and related classes

This commit contains several incremental improvements towards a clear
interface for lambdas:

  - Unification of static and JIT compilation by using the static
    compilation path of `CompilerEngine` within a new subclass
    `JitCompilerEngine`.

  - Clear ownership for compilation artefacts through
    `CompilationContext`, making it impossible to destroy objects used
    directly or indirectly before destruction of their users.

  - Clear interface for lambdas generated by the compiler through
    `JitCompilerEngine::Lambda` with a templated call operator,
    encapsulating otherwise manual orchestration of `CompilerEngine`,
    `JITLambda`, and `CompilerEngine::Argument`.

  - Improved error handling through `llvm::Expected<T>` and proper
    error checking following the conventions for `llvm::Expected<T>`
    and `llvm::Error`.

Co-authored-by: youben11 <ayoub.benaissa@zama.ai>
This commit is contained in:
Andi Drebes
2021-10-18 15:38:12 +02:00
parent d738104c4b
commit 1187cfbd62
61 changed files with 1690 additions and 997 deletions

View File

@@ -5,15 +5,17 @@
#include "mlir-c/Registration.h"
#include "zamalang/Support/CompilerEngine.h"
#include "zamalang/Support/ExecutionArgument.h"
#include "zamalang/Support/Jit.h"
#include "zamalang/Support/JitCompilerEngine.h"
#ifdef __cplusplus
extern "C" {
#endif
struct compilerEngine {
mlir::zamalang::CompilerEngine *ptr;
struct lambda {
mlir::zamalang::JitCompilerEngine::Lambda *ptr;
};
typedef struct compilerEngine compilerEngine;
typedef struct lambda lambda;
struct executionArguments {
mlir::zamalang::ExecutionArgument *data;
@@ -21,13 +23,12 @@ struct executionArguments {
};
typedef struct executionArguments exectuionArguments;
// Compile an MLIR module
MLIR_CAPI_EXPORTED void compilerEngineCompile(compilerEngine engine,
const char *module);
MLIR_CAPI_EXPORTED mlir::zamalang::JitCompilerEngine::Lambda
buildLambda(const char *module, const char *funcName);
// Run the compiled module
MLIR_CAPI_EXPORTED uint64_t compilerEngineRun(compilerEngine e,
executionArguments args);
MLIR_CAPI_EXPORTED uint64_t invokeLambda(lambda l, executionArguments args);
MLIR_CAPI_EXPORTED std::string roundTrip(const char *module);
#ifdef __cplusplus
}

View File

@@ -1,49 +1,138 @@
#ifndef ZAMALANG_SUPPORT_COMPILER_ENGINE_H
#define ZAMALANG_SUPPORT_COMPILER_ENGINE_H
#include "Jit.h"
#include <llvm/IR/Module.h>
#include <llvm/Support/Error.h>
#include <llvm/Support/SourceMgr.h>
#include <mlir/IR/BuiltinOps.h>
#include <mlir/IR/MLIRContext.h>
#include <zamalang/Conversion/Utils/GlobalFHEContext.h>
#include <zamalang/Support/ClientParameters.h>
#include <zamalang/Support/KeySet.h>
namespace mlir {
namespace zamalang {
/// CompilerEngine is an tools that provides tools to implements the compilation
/// flow and manage the compilation flow state.
// Compilation context that acts as the root owner of LLVM and MLIR
// data structures directly and indirectly referenced by artefacts
// produced by the `CompilerEngine`.
class CompilationContext {
public:
CompilationContext();
~CompilationContext();
mlir::MLIRContext *getMLIRContext();
llvm::LLVMContext *getLLVMContext();
static std::shared_ptr<CompilationContext> createShared();
protected:
mlir::MLIRContext *mlirContext;
llvm::LLVMContext *llvmContext;
};
class CompilerEngine {
public:
CompilerEngine() {
context = new mlir::MLIRContext();
loadDialects();
}
~CompilerEngine() {
if (context != nullptr)
delete context;
}
// Result of an invocation of the `CompilerEngine` with optional
// fields for the results produced by different stages.
class CompilationResult {
public:
CompilationResult(std::shared_ptr<CompilationContext> compilationContext =
CompilationContext::createShared())
: compilationContext(compilationContext) {}
// Compile an mlir programs from it's textual representation.
llvm::Error compile(
std::string mlirStr,
llvm::Optional<mlir::zamalang::V0FHEConstraint> overrideConstraints = {});
llvm::Optional<mlir::OwningModuleRef> mlirModuleRef;
llvm::Optional<mlir::zamalang::ClientParameters> clientParameters;
std::unique_ptr<mlir::zamalang::KeySet> keySet;
std::unique_ptr<llvm::Module> llvmModule;
llvm::Optional<mlir::zamalang::V0FHEContext> fheContext;
// Build the jit lambda argument.
llvm::Expected<std::unique_ptr<JITLambda::Argument>> buildArgument();
protected:
std::shared_ptr<CompilationContext> compilationContext;
};
// Call the compiled function with and argument object.
llvm::Error invoke(JITLambda::Argument &arg);
// Specification of the exit stage of the compilation pipeline
enum class Target {
// Only read sources and produce corresponding MLIR module
ROUND_TRIP,
// Call the compiled function with a list of integer arguments.
llvm::Expected<uint64_t> run(std::vector<uint64_t> args);
// Read sources and exit before any lowering
HLFHE,
// Get a printable representation of the compiled module
std::string getCompiledModule();
// Read sources and attempt to run the Minimal Arithmetic Noise
// Padding pass
HLFHE_MANP,
// Read sources and lower all HLFHE operations to MidLFHE
// operations
MIDLFHE,
// Read sources and lower all HLFHE and MidLFHE operations to LowLFHE
// operations
LOWLFHE,
// Read sources and lower all HLFHE, MidLFHE and LowLFHE
// operations to canonical MLIR dialects. Cryptographic operations
// are lowered to invocations of the concrete library.
STD,
// Read sources and lower all HLFHE, MidLFHE and LowLFHE
// operations to operations from the LLVM dialect. Cryptographic
// operations are lowered to invocations of the concrete library.
LLVM,
// Same as `LLVM`, but lowers to actual LLVM IR instead of the
// LLVM dialect
LLVM_IR,
// Same as `LLVM_IR`, but invokes the LLVM optimization pipeline
// to produce optimized LLVM IR
OPTIMIZED_LLVM_IR
};
CompilerEngine(std::shared_ptr<CompilationContext> compilationContext)
: overrideMaxEintPrecision(), overrideMaxMANP(),
clientParametersFuncName(), verifyDiagnostics(false),
generateKeySet(false), generateClientParameters(false),
parametrizeMidLFHE(true), compilationContext(compilationContext) {}
llvm::Expected<CompilationResult> compile(llvm::StringRef s, Target target);
llvm::Expected<CompilationResult>
compile(std::unique_ptr<llvm::MemoryBuffer> buffer, Target target);
llvm::Expected<CompilationResult> compile(llvm::SourceMgr &sm, Target target);
void setFHEConstraints(const mlir::zamalang::V0FHEConstraint &c);
void setMaxEintPrecision(size_t v);
void setMaxMANP(size_t v);
void setVerifyDiagnostics(bool v);
void setGenerateKeySet(bool v);
void setGenerateClientParameters(bool v);
void setParametrizeMidLFHE(bool v);
void setClientParametersFuncName(const llvm::StringRef &name);
protected:
llvm::Optional<size_t> overrideMaxEintPrecision;
llvm::Optional<size_t> overrideMaxMANP;
llvm::Optional<std::string> clientParametersFuncName;
bool verifyDiagnostics;
bool generateKeySet;
bool generateClientParameters;
bool parametrizeMidLFHE;
std::shared_ptr<CompilationContext> compilationContext;
// Helper enum identifying an FHE dialect (`HLFHE`, `MIDLFHE`, `LOWLFHE`)
// or indicating that no FHE dialect is used (`NONE`).
enum class FHEDialect { HLFHE, MIDLFHE, LOWLFHE, NONE };
static FHEDialect detectHighestFHEDialect(mlir::ModuleOp module);
private:
// Load the necessary dialects into the engine's context
void loadDialects();
mlir::OwningModuleRef module_ref;
mlir::MLIRContext *context;
std::unique_ptr<mlir::zamalang::KeySet> keySet;
llvm::Error lowerParamDependentHalf(Target target, CompilationResult &res);
llvm::Error determineFHEParameters(CompilationResult &res, bool noOverride);
};
} // namespace zamalang
} // namespace mlir

View File

@@ -9,11 +9,6 @@
namespace mlir {
namespace zamalang {
mlir::LogicalResult
runJit(mlir::ModuleOp module, llvm::StringRef func,
llvm::ArrayRef<uint64_t> funcArgs, mlir::zamalang::KeySet &keySet,
std::function<llvm::Error(llvm::Module *)> optPipeline,
llvm::raw_ostream &os);
/// JITLambda is a tool to JIT compile an mlir module and to invoke a function
/// of the module.

View File

@@ -0,0 +1,296 @@
#ifndef ZAMALANG_SUPPORT_JIT_COMPILER_ENGINE_H
#define ZAMALANG_SUPPORT_JIT_COMPILER_ENGINE_H
#include <mlir/Dialect/LLVMIR/LLVMDialect.h>
#include <zamalang/Support/CompilerEngine.h>
#include <zamalang/Support/Error.h>
#include <zamalang/Support/Jit.h>
#include <zamalang/Support/LambdaArgument.h>
namespace mlir {
namespace zamalang {
namespace {
// Generic function template as well as specializations of
// `typedResult` must be declared at namespace scope due to return
// type template specialization
// Helper function for `JitCompilerEngine::Lambda::operator()`
// implementing type-dependent preparation of the result.
template <typename ResT>
llvm::Expected<ResT> typedResult(JITLambda::Argument &arguments);
// Specialization of `typedResult()` for scalar results, forwarding
// scalar value to caller
template <>
inline llvm::Expected<uint64_t> typedResult(JITLambda::Argument &arguments) {
uint64_t res = 0;
if (auto err = arguments.getResult(0, res))
return StreamStringError() << "Cannot retrieve result:" << err;
return res;
}
// Specialization of `typedResult()` for vector results, initializing
// an `std::vector` of the right size with the results and forwarding
// it to the caller with move semantics.
template <>
inline llvm::Expected<std::vector<uint64_t>>
typedResult(JITLambda::Argument &arguments) {
llvm::Expected<size_t> n = arguments.getResultVectorSize(0);
if (auto err = n.takeError())
return std::move(err);
std::vector<uint64_t> res(*n);
if (auto err = arguments.getResult(0, res.data(), res.size()))
return StreamStringError() << "Cannot retrieve result:" << err;
return std::move(res);
}
// Adaptor class that adds arguments specified as instances of
// `LambdaArgument` to `JitLambda::Argument`.
class JITLambdaArgumentAdaptor {
public:
// Checks if the argument `arg` is an plaintext / encrypted integer
// argument or a plaintext / encrypted tensor argument with a
// backing integer type `IntT` and adds the argument to `jla` at
// position `pos`.
//
// Returns `true` if `arg` has one of the types above and its value
// was successfully added to `jla`, `false` if none of the types
// matches or an error if a type matched, but adding the argument to
// `jla` failed.
template <typename IntT>
static inline llvm::Expected<bool>
tryAddArg(JITLambda::Argument &jla, size_t pos, const LambdaArgument &arg) {
if (auto ila = arg.dyn_cast<IntLambdaArgument<IntT>>()) {
if (llvm::Error err = jla.setArg(pos, ila->getValue()))
return std::move(err);
else
return true;
} else if (auto tla = arg.dyn_cast<
TensorLambdaArgument<IntLambdaArgument<IntT>>>()) {
llvm::Expected<size_t> numElements = tla->getNumElements();
if (!numElements)
return std::move(numElements.takeError());
if (llvm::Error err = jla.setArg(pos, tla->getValue(), *numElements))
return std::move(err);
else
return true;
}
return false;
}
// Recursive case for `tryAddArg<IntT>(...)`
template <typename IntT, typename NextIntT, typename... IntTs>
static inline llvm::Expected<bool>
tryAddArg(JITLambda::Argument &jla, size_t pos, const LambdaArgument &arg) {
llvm::Expected<bool> successOrError = tryAddArg<IntT>(jla, pos, arg);
if (!successOrError)
return std::move(successOrError.takeError());
if (successOrError.get() == false)
return tryAddArg<NextIntT, IntTs...>(jla, pos, arg);
else
return true;
}
// Attempts to add a single argument `arg` to `jla` at position
// `pos`. Returns an error if either the argument type is
// unsupported or if the argument types is supported, but adding it
// to `jla` failed.
static inline llvm::Error addArgument(JITLambda::Argument &jla, size_t pos,
const LambdaArgument &arg) {
llvm::Expected<bool> successOrError =
JITLambdaArgumentAdaptor::tryAddArg<uint64_t, uint32_t, uint16_t,
uint8_t>(jla, pos, arg);
if (!successOrError)
return std::move(successOrError.takeError());
if (successOrError.get() == false)
return StreamStringError("Unknown argument type");
else
return llvm::Error::success();
}
};
} // namespace
// A compiler engine that JIT-compiles a source and produces a lambda
// object directly invocable through its call operator.
class JitCompilerEngine : public CompilerEngine {
public:
// Wrapper class around `JITLambda` and `JITLambda::Argument` that
// allows for direct invocation of a compiled function through
// `operator ()`.
class Lambda {
public:
Lambda(Lambda &&other)
: innerLambda(std::move(other.innerLambda)),
keySet(std::move(other.keySet)),
compilationContext(other.compilationContext) {}
Lambda(std::shared_ptr<CompilationContext> compilationContext,
std::unique_ptr<JITLambda> lambda, std::unique_ptr<KeySet> keySet)
: innerLambda(std::move(lambda)), keySet(std::move(keySet)),
compilationContext(compilationContext) {}
// Returns the number of arguments required for an invocation of
// the lambda
size_t getNumArguments() { return this->keySet->numInputs(); }
// Returns the number of results an invocation of the lambda
// produces
size_t getNumResults() { return this->keySet->numOutputs(); }
// Invocation with an dynamic list of arguments of different
// types, specified as `LambdaArgument`s
template <typename ResT = uint64_t>
llvm::Expected<ResT>
operator()(llvm::ArrayRef<LambdaArgument *> lambdaArgs) {
// Create the arguments of the JIT lambda
llvm::Expected<std::unique_ptr<JITLambda::Argument>> argsOrErr =
mlir::zamalang::JITLambda::Argument::create(*this->keySet.get());
if (llvm::Error err = argsOrErr.takeError())
return StreamStringError("Could not create lambda arguments");
// Set the arguments
std::unique_ptr<JITLambda::Argument> arguments =
std::move(argsOrErr.get());
for (size_t i = 0; i < lambdaArgs.size(); i++) {
if (llvm::Error err = JITLambdaArgumentAdaptor::addArgument(
*arguments, i, *lambdaArgs[i])) {
return std::move(err);
}
}
// Invoke the lambda
if (auto err = this->innerLambda->invoke(*arguments))
return StreamStringError() << "Cannot invoke lambda:" << err;
return std::move(typedResult<ResT>(*arguments));
}
// Invocation with an array of arguments of the same type
template <typename T, typename ResT = uint64_t>
llvm::Expected<ResT> operator()(const llvm::ArrayRef<T> args) {
// Create the arguments of the JIT lambda
llvm::Expected<std::unique_ptr<JITLambda::Argument>> argsOrErr =
mlir::zamalang::JITLambda::Argument::create(*this->keySet.get());
if (llvm::Error err = argsOrErr.takeError())
return StreamStringError("Could not create lambda arguments");
// Set the arguments
std::unique_ptr<JITLambda::Argument> arguments =
std::move(argsOrErr.get());
for (size_t i = 0; i < args.size(); i++) {
if (auto err = arguments->setArg(i, args[i])) {
return StreamStringError()
<< "Cannot push argument " << i << ": " << err;
}
}
// Invoke the lambda
if (auto err = this->innerLambda->invoke(*arguments))
return StreamStringError() << "Cannot invoke lambda:" << err;
return std::move(typedResult<ResT>(*arguments));
}
// Invocation with arguments of different types
template <typename ResT = uint64_t, typename... Ts>
llvm::Expected<ResT> operator()(const Ts... ts) {
// Create the arguments of the JIT lambda
llvm::Expected<std::unique_ptr<JITLambda::Argument>> argsOrErr =
mlir::zamalang::JITLambda::Argument::create(*this->keySet.get());
if (llvm::Error err = argsOrErr.takeError())
return StreamStringError("Could not create lambda arguments");
// Set the arguments
std::unique_ptr<JITLambda::Argument> arguments =
std::move(argsOrErr.get());
if (llvm::Error err = this->addArgs<0>(arguments.get(), ts...))
return std::move(err);
// Invoke the lambda
if (auto err = this->innerLambda->invoke(*arguments))
return StreamStringError() << "Cannot invoke lambda:" << err;
return std::move(typedResult<ResT>(*arguments));
}
protected:
template <int pos>
inline llvm::Error addArgs(JITLambda::Argument *jitArgs) {
// base case -- nothing to do
return llvm::Error::success();
}
// Recursive case for scalars: extract first scalar argument from
// parameter pack and forward rest
template <int pos, typename ArgT, typename... Ts>
inline llvm::Error addArgs(JITLambda::Argument *jitArgs, ArgT arg,
Ts... remainder) {
if (auto err = jitArgs->setArg(pos, arg)) {
return StreamStringError()
<< "Cannot push scalar argument " << pos << ": " << err;
}
return this->addArgs<pos + 1>(jitArgs, remainder...);
}
// Recursive case for tensors: extract pointer and size from
// parameter pack and forward rest
template <int pos, typename ArgT, typename... Ts>
inline llvm::Error addArgs(JITLambda::Argument *jitArgs, ArgT *arg,
size_t size, Ts... remainder) {
if (auto err = jitArgs->setArg(pos, arg, size)) {
return StreamStringError()
<< "Cannot push tensor argument " << pos << ": " << err;
}
return this->addArgs<pos + 1>(jitArgs, remainder...);
}
std::unique_ptr<JITLambda> innerLambda;
std::unique_ptr<KeySet> keySet;
std::shared_ptr<CompilationContext> compilationContext;
};
JitCompilerEngine(std::shared_ptr<CompilationContext> compilationContext =
CompilationContext::createShared(),
unsigned int optimizationLevel = 3);
llvm::Expected<Lambda> buildLambda(llvm::StringRef src,
llvm::StringRef funcName = "main");
llvm::Expected<Lambda> buildLambda(std::unique_ptr<llvm::MemoryBuffer> buffer,
llvm::StringRef funcName = "main");
llvm::Expected<Lambda> buildLambda(llvm::SourceMgr &sm,
llvm::StringRef funcName = "main");
protected:
llvm::Expected<mlir::LLVM::LLVMFuncOp> findLLVMFuncOp(mlir::ModuleOp module,
llvm::StringRef name);
unsigned int optimizationLevel;
};
} // namespace zamalang
} // namespace mlir
#endif

View File

@@ -0,0 +1,157 @@
#ifndef ZAMALANG_SUPPORT_LAMBDA_ARGUMENT_H
#define ZAMALANG_SUPPORT_LAMBDA_ARGUMENT_H
#include <cstdint>
#include <limits>
#include <llvm/ADT/ArrayRef.h>
#include <llvm/Support/Casting.h>
#include <llvm/Support/ExtensibleRTTI.h>
#include <zamalang/Support/Error.h>
namespace mlir {
namespace zamalang {
// Abstract base class for lambda arguments
class LambdaArgument
: public llvm::RTTIExtends<LambdaArgument, llvm::RTTIRoot> {
public:
LambdaArgument(LambdaArgument &) = delete;
template <typename T> bool isa() const { return llvm::isa<T>(*this); }
// Cast functions on constant instances
template <typename T> const T &cast() const { return llvm::cast<T>(*this); }
template <typename T> const T *dyn_cast() const {
return llvm::dyn_cast<T>(this);
}
// Cast functions for mutable instances
template <typename T> T &cast() { return llvm::cast<T>(*this); }
template <typename T> T *dyn_cast() { return llvm::dyn_cast<T>(this); }
static char ID;
protected:
LambdaArgument(){};
};
// Class for integer arguments. `BackingIntType` is used as the data
// type to hold the argument's value. The precision is the actual
// precision of the value, which might be different from the precision
// of the backing integer type.
template <typename BackingIntType = uint64_t>
class IntLambdaArgument
: public llvm::RTTIExtends<IntLambdaArgument<BackingIntType>,
LambdaArgument> {
public:
typedef BackingIntType value_type;
IntLambdaArgument(BackingIntType value,
unsigned int precision = 8 * sizeof(BackingIntType))
: precision(precision) {
if (precision < 8 * sizeof(BackingIntType)) {
this->value = value & (1 << (this->precision - 1));
} else {
this->value = value;
}
}
unsigned int getPrecision() const { return this->precision; }
BackingIntType getValue() const { return this->value; }
static char ID;
protected:
unsigned int precision;
BackingIntType value;
};
template <typename BackingIntType>
char IntLambdaArgument<BackingIntType>::ID = 0;
namespace {
// Calculates `accu *= factor` or returns an error if the result
// would overflow
template <typename AccuT, typename ValT>
llvm::Error safeUnsignedMul(AccuT &accu, ValT factor) {
static_assert(std::numeric_limits<AccuT>::is_integer &&
std::numeric_limits<ValT>::is_integer &&
!std::numeric_limits<AccuT>::is_signed &&
!std::numeric_limits<ValT>::is_signed,
"Only unsigned integers are supported");
const AccuT left = std::numeric_limits<AccuT>::max() / accu;
if (left > factor) {
accu *= factor;
return llvm::Error::success();
}
return StreamStringError("Multiplying value ")
<< accu << " with " << factor << " would cause an overflow";
}
} // namespace
// Class for Tensor arguments. This can either be plaintext tensors
// (for `ScalarArgumentT = IntLambaArgument<T>`) or tensors
// representing encrypted integers (for `ScalarArgumentT =
// EIntLambaArgument<T>`).
template <typename ScalarArgumentT>
class TensorLambdaArgument
: public llvm::RTTIExtends<TensorLambdaArgument<ScalarArgumentT>,
LambdaArgument> {
public:
typedef ScalarArgumentT scalar_type;
// Construct tensor argument from the one-dimensional array `value`,
// but interpreting the array's values as a linearized
// multi-dimensional tensor with the sizes of the dimensions
// specified in `dimensions`.
TensorLambdaArgument(
llvm::MutableArrayRef<typename ScalarArgumentT::value_type> value,
llvm::ArrayRef<unsigned int> dimensions)
: value(value), dimensions(dimensions.vec()) {}
// Construct a one-dimensional tensor argument from the
// array `value`.
TensorLambdaArgument(
llvm::MutableArrayRef<typename ScalarArgumentT::value_type> value)
: TensorLambdaArgument(value, {(unsigned int)value.size()}) {}
const std::vector<unsigned int> &getDimensions() const {
return this->dimensions;
}
// Returns the total number of elements in the tensor. If the number
// of elements cannot be represented as a `size_t`, the method
// returns an error.
llvm::Expected<size_t> getNumElements() const {
size_t accu = 1;
for (unsigned int dimSize : dimensions)
if (llvm::Error err = safeUnsignedMul(accu, dimSize))
return std::move(err);
return accu;
}
// Returns a bare pointer to the linearized values of the tensor.
typename ScalarArgumentT::value_type *getValue() const {
return this->value.data();
}
static char ID;
protected:
llvm::MutableArrayRef<typename ScalarArgumentT::value_type> value;
std::vector<unsigned int> dimensions;
};
template <typename ScalarArgumentT>
char TensorLambdaArgument<ScalarArgumentT>::ID = 0;
} // namespace zamalang
} // namespace mlir
#endif

View File

@@ -1,8 +1,9 @@
#include "CompilerAPIModule.h"
#include "zamalang-c/Support/CompilerEngine.h"
#include "zamalang/Dialect/HLFHE/IR/HLFHEOpsDialect.h.inc"
#include "zamalang/Support/CompilerEngine.h"
#include "zamalang/Support/ExecutionArgument.h"
#include "zamalang/Support/Jit.h"
#include "zamalang/Support/JitCompilerEngine.h"
#include <mlir/Dialect/MemRef/IR/MemRef.h>
#include <mlir/Dialect/StandardOps/IR/Ops.h>
#include <mlir/ExecutionEngine/OptUtils.h>
@@ -14,27 +15,15 @@
#include <stdexcept>
#include <string>
using mlir::zamalang::CompilerEngine;
using mlir::zamalang::ExecutionArgument;
using mlir::zamalang::JitCompilerEngine;
/// Populate the compiler API python module.
void mlir::zamalang::python::populateCompilerAPISubmodule(pybind11::module &m) {
m.doc() = "Zamalang compiler python API";
m.def("round_trip", [](std::string mlir_input) {
mlir::MLIRContext context;
context.getOrLoadDialect<mlir::zamalang::HLFHE::HLFHEDialect>();
context.getOrLoadDialect<mlir::StandardOpsDialect>();
context.getOrLoadDialect<mlir::memref::MemRefDialect>();
auto module_ref = mlir::parseSourceString(mlir_input, &context);
if (!module_ref) {
throw std::logic_error("mlir parsing failed");
}
std::string result;
llvm::raw_string_ostream os(result);
module_ref->print(os);
return os.str();
});
m.def("round_trip",
[](std::string mlir_input) { return roundTrip(mlir_input.c_str()); });
pybind11::class_<ExecutionArgument, std::shared_ptr<ExecutionArgument>>(
m, "ExecutionArgument")
@@ -45,20 +34,19 @@ void mlir::zamalang::python::populateCompilerAPISubmodule(pybind11::module &m) {
.def("is_tensor", &ExecutionArgument::isTensor)
.def("is_int", &ExecutionArgument::isInt);
pybind11::class_<CompilerEngine>(m, "CompilerEngine")
pybind11::class_<JitCompilerEngine>(m, "JitCompilerEngine")
.def(pybind11::init())
.def("run",
[](CompilerEngine &engine, std::vector<ExecutionArgument> args) {
// wrap and call CAPI
compilerEngine e{&engine};
exectuionArguments a{args.data(), args.size()};
return compilerEngineRun(e, a);
})
.def("compile_fhe",
[](CompilerEngine &engine, std::string mlir_input) {
// wrap and call CAPI
compilerEngine e{&engine};
compilerEngineCompile(e, mlir_input.c_str());
})
.def("get_compiled_module", &CompilerEngine::getCompiledModule);
.def_static("build_lambda",
[](std::string mlir_input, std::string func_name) {
return buildLambda(mlir_input.c_str(), func_name.c_str());
});
pybind11::class_<JitCompilerEngine::Lambda>(m, "Lambda")
.def("invoke", [](JitCompilerEngine::Lambda &py_lambda,
std::vector<ExecutionArgument> args) {
// wrap and call CAPI
lambda c_lambda{&py_lambda};
exectuionArguments a{args.data(), args.size()};
return invokeLambda(c_lambda, a);
});
}

View File

@@ -1,10 +1,9 @@
"""Compiler submodule"""
from typing import List, Union
from mlir._mlir_libs._zamalang._compiler import CompilerEngine as _CompilerEngine
from mlir._mlir_libs._zamalang._compiler import JitCompilerEngine as _JitCompilerEngine
from mlir._mlir_libs._zamalang._compiler import ExecutionArgument as _ExecutionArgument
from mlir._mlir_libs._zamalang._compiler import round_trip as _round_trip
def round_trip(mlir_str: str) -> str:
"""Parse the MLIR input, then return it back.
@@ -49,25 +48,24 @@ def create_execution_argument(value: Union[int, List[int]]) -> "_ExecutionArgume
class CompilerEngine:
def __init__(self, mlir_str: str = None):
self._engine = _CompilerEngine()
self._engine = _JitCompilerEngine()
self._lambda = None
if mlir_str is not None:
self.compile_fhe(mlir_str)
def compile_fhe(self, mlir_str: str) -> "CompilerEngine":
"""Compile the MLIR input and build a CompilerEngine.
def compile_fhe(self, mlir_str: str, func_name: str = "main"):
"""Compile the MLIR input.
Args:
mlir_str (str): MLIR to compile.
func_name (str): name of the function to set as entrypoint.
Raises:
TypeError: if the argument is not an str.
Returns:
CompilerEngine: engine used for execution.
"""
if not isinstance(mlir_str, str):
raise TypeError("input must be an `str`")
return self._engine.compile_fhe(mlir_str)
self._lambda = self._engine.build_lambda(mlir_str, func_name)
def run(self, *args: List[Union[int, List[int]]]) -> int:
"""Run the compiled code.
@@ -77,17 +75,12 @@ class CompilerEngine:
Raises:
TypeError: if execution arguments can't be constructed
RuntimeError: if the engine has not compiled any code yet
Returns:
int: result of execution.
"""
if self._lambda is None:
raise RuntimeError("need to compile an MLIR code first")
execution_arguments = [create_execution_argument(arg) for arg in args]
return self._engine.run(execution_arguments)
def get_compiled_module(self) -> str:
"""Compiled module in printable form.
Returns:
str: Compiled module in printable form.
"""
return self._engine.get_compiled_module()
return self._lambda.invoke(execution_arguments)

View File

@@ -1,62 +1,83 @@
#include "zamalang-c/Support/CompilerEngine.h"
#include "zamalang/Support/CompilerEngine.h"
#include "zamalang/Support/ExecutionArgument.h"
#include "zamalang/Support/Jit.h"
#include "zamalang/Support/JitCompilerEngine.h"
#include "zamalang/Support/logging.h"
using mlir::zamalang::CompilerEngine;
// using mlir::zamalang::CompilerEngine;
using mlir::zamalang::ExecutionArgument;
using mlir::zamalang::JitCompilerEngine;
void compilerEngineCompile(compilerEngine engine, const char *module) {
auto error = engine.ptr->compile(module);
if (error) {
llvm::errs() << "Compilation failed: " << error << "\n";
llvm::consumeError(std::move(error));
mlir::zamalang::JitCompilerEngine::Lambda buildLambda(const char *module,
const char *funcName) {
mlir::zamalang::JitCompilerEngine engine;
llvm::Expected<mlir::zamalang::JitCompilerEngine::Lambda> lambdaOrErr =
engine.buildLambda(module, funcName);
if (!lambdaOrErr) {
mlir::zamalang::log_error()
<< "Compilation failed: "
<< llvm::toString(std::move(lambdaOrErr.takeError())) << "\n";
throw std::runtime_error(
"failed compiling, see previous logs for more info");
}
return std::move(*lambdaOrErr);
}
uint64_t compilerEngineRun(compilerEngine engine, exectuionArguments args) {
auto args_size = args.size;
auto maybeArgument = engine.ptr->buildArgument();
if (auto err = maybeArgument.takeError()) {
llvm::errs() << "Execution failed: " << err << "\n";
llvm::consumeError(std::move(err));
throw std::runtime_error(
"failed building arguments, see previous logs for more info");
uint64_t invokeLambda(lambda l, executionArguments args) {
mlir::zamalang::JitCompilerEngine::Lambda *lambda_ptr =
(mlir::zamalang::JitCompilerEngine::Lambda *)l.ptr;
if (args.size != lambda_ptr->getNumArguments()) {
throw std::invalid_argument("wrong number of arguments");
}
// Set the integer/tensor arguments
auto arguments = std::move(maybeArgument.get());
for (auto i = 0; i < args_size; i++) {
std::vector<mlir::zamalang::LambdaArgument *> lambdaArgumentsRef;
for (auto i = 0; i < args.size; i++) {
if (args.data[i].isInt()) { // integer argument
if (auto err = arguments->setArg(i, args.data[i].getIntegerArgument())) {
llvm::errs() << "Execution failed: " << err << "\n";
llvm::consumeError(std::move(err));
throw std::runtime_error("failed pushing integer argument, see "
"previous logs for more info");
}
lambdaArgumentsRef.push_back(new mlir::zamalang::IntLambdaArgument<>(
args.data[i].getIntegerArgument()));
} else { // tensor argument
assert(args.data[i].isTensor() && "should be tensor argument");
if (auto err = arguments->setArg(i, args.data[i].getTensorArgument(),
args.data[i].getTensorSize())) {
llvm::errs() << "Execution failed: " << err << "\n";
llvm::consumeError(std::move(err));
throw std::runtime_error("failed pushing tensor argument, see "
"previous logs for more info");
}
llvm::MutableArrayRef<uint8_t> tensor(args.data[i].getTensorArgument(),
args.data[i].getTensorSize());
lambdaArgumentsRef.push_back(
new mlir::zamalang::TensorLambdaArgument<
mlir::zamalang::IntLambdaArgument<uint8_t>>(tensor));
}
}
// Invoke the lambda
if (auto err = engine.ptr->invoke(*arguments)) {
llvm::errs() << "Execution failed: " << err << "\n";
llvm::consumeError(std::move(err));
throw std::runtime_error("failed running, see previous logs for more info");
}
uint64_t result = 0;
if (auto err = arguments->getResult(0, result)) {
llvm::errs() << "Execution failed: " << err << "\n";
llvm::consumeError(std::move(err));
// Run lambda
llvm::Expected<uint64_t> resOrError = (*lambda_ptr)(
llvm::ArrayRef<mlir::zamalang::LambdaArgument *>(lambdaArgumentsRef));
// Free heap
for (size_t i = 0; i < lambdaArgumentsRef.size(); i++)
delete lambdaArgumentsRef[i];
if (!resOrError) {
mlir::zamalang::log_error()
<< "Lambda invokation failed: "
<< llvm::toString(std::move(resOrError.takeError())) << "\n";
throw std::runtime_error(
"failed getting result, see previous logs for more info");
"failed invoking lambda, see previous logs for more info");
}
return result;
}
return *resOrError;
}
std::string roundTrip(const char *module) {
std::shared_ptr<mlir::zamalang::CompilationContext> ccx =
mlir::zamalang::CompilationContext::createShared();
mlir::zamalang::JitCompilerEngine ce{ccx};
llvm::Expected<mlir::zamalang::CompilerEngine::CompilationResult> retOrErr =
ce.compile(module, mlir::zamalang::CompilerEngine::Target::ROUND_TRIP);
if (!retOrErr) {
mlir::zamalang::log_error()
<< llvm::toString(std::move(retOrErr.takeError())) << "\n";
throw std::runtime_error(
"mlir parsing failed, see previous logs for more info");
}
std::string result;
llvm::raw_string_ostream os(result);
retOrErr->mlirModuleRef->get().print(os);
return os.str();
}

View File

@@ -3,6 +3,8 @@ add_mlir_library(ZamalangSupport
Pipeline.cpp
Jit.cpp
CompilerEngine.cpp
JitCompilerEngine.cpp
LambdaArgument.cpp
V0Parameters.cpp
V0Curves.cpp
ClientParameters.cpp

View File

@@ -1,3 +1,5 @@
#include <llvm/Support/Error.h>
#include <llvm/Support/SMLoc.h>
#include <mlir/Dialect/LLVMIR/LLVMDialect.h>
#include <mlir/Dialect/Linalg/IR/LinalgOps.h>
#include <mlir/Dialect/MemRef/IR/MemRef.h>
@@ -9,155 +11,419 @@
#include <zamalang/Dialect/LowLFHE/IR/LowLFHEDialect.h>
#include <zamalang/Dialect/MidLFHE/IR/MidLFHEDialect.h>
#include <zamalang/Support/CompilerEngine.h>
#include <zamalang/Support/Error.h>
#include <zamalang/Support/Jit.h>
#include <zamalang/Support/Pipeline.h>
namespace mlir {
namespace zamalang {
void CompilerEngine::loadDialects() {
context->getOrLoadDialect<mlir::zamalang::HLFHE::HLFHEDialect>();
context->getOrLoadDialect<mlir::zamalang::MidLFHE::MidLFHEDialect>();
context->getOrLoadDialect<mlir::zamalang::LowLFHE::LowLFHEDialect>();
context->getOrLoadDialect<mlir::StandardOpsDialect>();
context->getOrLoadDialect<mlir::memref::MemRefDialect>();
context->getOrLoadDialect<mlir::linalg::LinalgDialect>();
context->getOrLoadDialect<mlir::LLVM::LLVMDialect>();
// Creates a new compilation context that can be shared across
// compilation engines and results
std::shared_ptr<CompilationContext> CompilationContext::createShared() {
return std::make_shared<CompilationContext>();
}
std::string CompilerEngine::getCompiledModule() {
std::string compiledModule;
llvm::raw_string_ostream os(compiledModule);
module_ref->print(os);
return os.str();
CompilationContext::CompilationContext()
: mlirContext(nullptr), llvmContext(nullptr) {}
CompilationContext::~CompilationContext() {
delete this->mlirContext;
delete this->llvmContext;
}
llvm::Error CompilerEngine::compile(
std::string mlirStr,
llvm::Optional<mlir::zamalang::V0FHEConstraint> overrideConstraints) {
module_ref = mlir::parseSourceString(mlirStr, context);
if (!module_ref) {
return llvm::make_error<llvm::StringError>("mlir parsing failed",
llvm::inconvertibleErrorCode());
// Returns the MLIR context for a compilation context. Creates and
// initializes a new MLIR context if necessary.
mlir::MLIRContext *CompilationContext::getMLIRContext() {
if (this->mlirContext == nullptr) {
this->mlirContext = new mlir::MLIRContext();
this->mlirContext->getOrLoadDialect<mlir::zamalang::HLFHE::HLFHEDialect>();
this->mlirContext
->getOrLoadDialect<mlir::zamalang::MidLFHE::MidLFHEDialect>();
this->mlirContext
->getOrLoadDialect<mlir::zamalang::LowLFHE::LowLFHEDialect>();
this->mlirContext->getOrLoadDialect<mlir::StandardOpsDialect>();
this->mlirContext->getOrLoadDialect<mlir::memref::MemRefDialect>();
this->mlirContext->getOrLoadDialect<mlir::linalg::LinalgDialect>();
this->mlirContext->getOrLoadDialect<mlir::LLVM::LLVMDialect>();
}
mlir::ModuleOp module = module_ref.get();
return this->mlirContext;
}
llvm::Optional<mlir::zamalang::V0FHEConstraint> fheConstraintsOpt =
overrideConstraints;
// Returns the LLVM context for a compilation context. Creates and
// initializes a new LLVM context if necessary.
llvm::LLVMContext *CompilationContext::getLLVMContext() {
if (this->llvmContext == nullptr)
this->llvmContext = new llvm::LLVMContext();
if (!fheConstraintsOpt.hasValue()) {
return this->llvmContext;
}
// Sets the FHE constraints for the compilation. Overrides any
// automatically detected configuration and prevents the autodetection
// pass from running.
void CompilerEngine::setFHEConstraints(
const mlir::zamalang::V0FHEConstraint &c) {
this->overrideMaxEintPrecision = c.p;
this->overrideMaxMANP = c.norm2;
}
void CompilerEngine::setVerifyDiagnostics(bool v) {
this->verifyDiagnostics = v;
}
void CompilerEngine::setGenerateKeySet(bool v) { this->generateKeySet = v; }
void CompilerEngine::setGenerateClientParameters(bool v) {
this->generateClientParameters = v;
}
void CompilerEngine::setMaxEintPrecision(size_t v) {
this->overrideMaxEintPrecision = v;
}
void CompilerEngine::setParametrizeMidLFHE(bool v) {
this->parametrizeMidLFHE = v;
}
void CompilerEngine::setMaxMANP(size_t v) { this->overrideMaxMANP = v; }
void CompilerEngine::setClientParametersFuncName(const llvm::StringRef &name) {
this->clientParametersFuncName = name.str();
}
// Helper function detecting the FHE dialect with the highest level of
// abstraction used in `module`. If no FHE dialect is used, the
// function returns `CompilerEngine::FHEDialect::NONE`.
CompilerEngine::FHEDialect
CompilerEngine::detectHighestFHEDialect(mlir::ModuleOp module) {
CompilerEngine::FHEDialect highestDialect = CompilerEngine::FHEDialect::NONE;
mlir::TypeID hlfheID =
mlir::TypeID::get<mlir::zamalang::HLFHE::HLFHEDialect>();
mlir::TypeID midlfheID =
mlir::TypeID::get<mlir::zamalang::MidLFHE::MidLFHEDialect>();
mlir::TypeID lowlfheID =
mlir::TypeID::get<mlir::zamalang::LowLFHE::LowLFHEDialect>();
// Helper lambda updating the currently highest dialect if necessary
// by dialect type ID
auto updateDialectFromDialectID = [&](mlir::TypeID dialectID) {
if (dialectID == hlfheID) {
highestDialect = CompilerEngine::FHEDialect::HLFHE;
return true;
} else if (dialectID == lowlfheID &&
highestDialect == CompilerEngine::FHEDialect::NONE) {
highestDialect = CompilerEngine::FHEDialect::LOWLFHE;
} else if (dialectID == midlfheID &&
(highestDialect == CompilerEngine::FHEDialect::NONE ||
highestDialect == CompilerEngine::FHEDialect::LOWLFHE)) {
highestDialect = CompilerEngine::FHEDialect::MIDLFHE;
}
return false;
};
// Helper lambda updating the currently highest dialect if necessary
// by value type
std::function<bool(mlir::Type)> updateDialectFromType =
[&](mlir::Type ty) -> bool {
if (updateDialectFromDialectID(ty.getDialect().getTypeID()))
return true;
if (mlir::TensorType tensorTy = ty.dyn_cast_or_null<mlir::TensorType>())
return updateDialectFromType(tensorTy.getElementType());
return false;
};
module.walk([&](mlir::Operation *op) {
// Check operation itself
if (updateDialectFromDialectID(op->getDialect()->getTypeID()))
return mlir::WalkResult::interrupt();
// Check types of operands
for (mlir::Value operand : op->getOperands()) {
if (updateDialectFromType(operand.getType()))
return mlir::WalkResult::interrupt();
}
// Check types of results
for (mlir::Value res : op->getResults()) {
if (updateDialectFromType(res.getType())) {
return mlir::WalkResult::interrupt();
}
}
return mlir::WalkResult::advance();
});
return highestDialect;
}
// Sets the FHE parameters of `res` either through autodetection or
// fixed constraints provided in
// `CompilerEngine::overrideMaxEintPrecision` and
// `CompilerEngine::overrideMaxMANP`.
//
// Autodetected values can be partially or fully overridden through
// `CompilerEngine::overrideMaxEintPrecision` and
// `CompilerEngine::overrideMaxMANP`.
//
// If `noOverrideAutodetected` is true, autodetected values are not
// overriden and used directly for `res`.
//
// Return an error if autodetection fails.
llvm::Error
CompilerEngine::determineFHEParameters(CompilationResult &res,
bool noOverrideAutodetected) {
mlir::MLIRContext &mlirContext = *this->compilationContext->getMLIRContext();
mlir::ModuleOp module = res.mlirModuleRef->get();
llvm::Optional<mlir::zamalang::V0FHEConstraint> fheConstraints;
// Determine FHE constraints either through autodetection or through
// overridden values
if (this->overrideMaxEintPrecision.hasValue() &&
this->overrideMaxMANP.hasValue() && !noOverrideAutodetected) {
fheConstraints.emplace(mlir::zamalang::V0FHEConstraint{
this->overrideMaxMANP.getValue(),
this->overrideMaxEintPrecision.getValue()});
} else {
llvm::Expected<llvm::Optional<mlir::zamalang::V0FHEConstraint>>
fheConstraintsOrErr =
mlir::zamalang::pipeline::getFHEConstraintsFromHLFHE(*context,
mlir::zamalang::pipeline::getFHEConstraintsFromHLFHE(mlirContext,
module);
if (auto err = fheConstraintsOrErr.takeError())
return std::move(err);
if (!fheConstraintsOrErr.get().hasValue()) {
return llvm::make_error<llvm::StringError>(
"Could not determine maximum required precision for encrypted "
"integers "
"and maximum value for the Minimal Arithmetic Noise Padding",
llvm::inconvertibleErrorCode());
return StreamStringError("Could not determine maximum required precision "
"for encrypted integers and maximum value for "
"the Minimal Arithmetic Noise Padding");
}
fheConstraintsOpt = fheConstraintsOrErr.get();
if (noOverrideAutodetected)
return llvm::Error::success();
fheConstraints = fheConstraintsOrErr.get();
// Override individual values if requested
if (this->overrideMaxEintPrecision.hasValue())
fheConstraints->p = this->overrideMaxEintPrecision.getValue();
if (this->overrideMaxMANP.hasValue())
fheConstraints->norm2 = this->overrideMaxMANP.getValue();
}
mlir::zamalang::V0FHEConstraint fheConstraints = fheConstraintsOpt.getValue();
const mlir::zamalang::V0Parameter *parameter = getV0Parameter(fheConstraints);
const mlir::zamalang::V0Parameter *fheParams =
getV0Parameter(fheConstraints.getValue());
if (!parameter) {
std::string buffer;
llvm::raw_string_ostream strs(buffer);
strs << "Could not determine V0 parameters for 2-norm of "
<< fheConstraints.norm2 << " and p of " << fheConstraints.p;
return llvm::make_error<llvm::StringError>(strs.str(),
llvm::inconvertibleErrorCode());
if (!fheParams) {
return StreamStringError()
<< "Could not determine V0 parameters for 2-norm of "
<< fheConstraints->norm2 << " and p of " << fheConstraints->p;
}
mlir::zamalang::V0FHEContext fheContext{fheConstraints, *parameter};
res.fheContext.emplace(
mlir::zamalang::V0FHEContext{*fheConstraints, *fheParams});
// Lower to MLIR Std
if (mlir::zamalang::pipeline::lowerHLFHEToStd(*context, module, fheContext,
false)
.failed()) {
return llvm::make_error<llvm::StringError>("failed to lower to MLIR Std",
llvm::inconvertibleErrorCode());
}
// Create the client parameters
auto clientParameter = mlir::zamalang::createClientParametersForV0(
fheContext, "main", module_ref.get());
if (auto err = clientParameter.takeError()) {
return std::move(err);
}
auto maybeKeySet =
mlir::zamalang::KeySet::generate(clientParameter.get(), 0, 0);
if (auto err = maybeKeySet.takeError()) {
return std::move(err);
}
keySet = std::move(maybeKeySet.get());
// Lower to MLIR LLVM Dialect
if (mlir::zamalang::pipeline::lowerStdToLLVMDialect(*context, module, false)
.failed()) {
return llvm::make_error<llvm::StringError>(
"failed to lower to LLVM dialect", llvm::inconvertibleErrorCode());
}
return llvm::Error::success();
}
llvm::Expected<std::unique_ptr<JITLambda::Argument>>
CompilerEngine::buildArgument() {
if (keySet.get() == nullptr) {
return llvm::make_error<llvm::StringError>(
"CompilerEngine::buildArgument: invalid engine state, the keySet has "
"not been generated",
llvm::inconvertibleErrorCode());
}
return JITLambda::Argument::create(*keySet);
}
// Performs all lowering from HLFHE to the FHE dialect with the lwoest
// level of abstraction that requires FHE parameters.
//
// Returns an error if any of the lowerings fails.
llvm::Error CompilerEngine::lowerParamDependentHalf(Target target,
CompilationResult &res) {
mlir::MLIRContext &mlirContext = *this->compilationContext->getMLIRContext();
mlir::ModuleOp module = res.mlirModuleRef->get();
llvm::Error CompilerEngine::invoke(JITLambda::Argument &arg) {
// Create the JIT lambda
auto defaultOptPipeline = mlir::makeOptimizingTransformer(3, 0, nullptr);
auto module = module_ref.get();
auto maybeLambda =
mlir::zamalang::JITLambda::create("main", module, defaultOptPipeline);
if (auto err = maybeLambda.takeError()) {
return std::move(err);
// HLFHE -> MidLFHE
if (mlir::zamalang::pipeline::lowerHLFHEToMidLFHE(mlirContext, module, false)
.failed()) {
return StreamStringError("Lowering from HLFHE to MidLFHE failed");
}
// Invoke the lambda
if (auto err = maybeLambda.get()->invoke(arg)) {
return std::move(err);
if (target == Target::MIDLFHE)
return llvm::Error::success();
// MidLFHE -> LowLFHE
if (mlir::zamalang::pipeline::lowerMidLFHEToLowLFHE(
mlirContext, module, *res.fheContext, this->parametrizeMidLFHE)
.failed()) {
return StreamStringError("Lowering from MidLFHE to LowLFHE failed");
}
return llvm::Error::success();
}
llvm::Expected<uint64_t> CompilerEngine::run(std::vector<uint64_t> args) {
// Build the argument of the JIT lambda.
auto maybeArgument = buildArgument();
if (auto err = maybeArgument.takeError()) {
return std::move(err);
// Compile the sources managed by the source manager `sm` to the
// target dialect `target`. If successful, the result can be retrieved
// using `getModule()` and `getLLVMModule()`, respectively depending
// on the target dialect.
llvm::Expected<CompilerEngine::CompilationResult>
CompilerEngine::compile(llvm::SourceMgr &sm, Target target) {
CompilationResult res(this->compilationContext);
mlir::MLIRContext &mlirContext = *this->compilationContext->getMLIRContext();
mlir::SourceMgrDiagnosticVerifierHandler smHandler(sm, &mlirContext);
mlirContext.printOpOnDiagnostic(false);
mlir::OwningModuleRef mlirModuleRef =
mlir::parseSourceFile<mlir::ModuleOp>(sm, &mlirContext);
if (this->verifyDiagnostics) {
if (smHandler.verify().failed())
return StreamStringError("Verification of diagnostics failed");
else
return res;
}
// Set the integer arguments
auto arguments = std::move(maybeArgument.get());
for (auto i = 0; i < args.size(); i++) {
if (auto err = arguments->setArg(i, args[i])) {
if (!mlirModuleRef)
return StreamStringError("Could not parse source");
res.mlirModuleRef = std::move(mlirModuleRef);
mlir::ModuleOp module = res.mlirModuleRef->get();
if (target == Target::HLFHE || target == Target::ROUND_TRIP)
return res;
// Detect highest FHE dialect and check if FHE parameter
// autodetection / lowering of parameter-dependent dialects can be
// skipped
FHEDialect highestFHEDialect = this->detectHighestFHEDialect(module);
if (highestFHEDialect == FHEDialect::HLFHE ||
highestFHEDialect == FHEDialect::MIDLFHE ||
this->generateClientParameters) {
bool noOverrideAutoDetected = (target == Target::HLFHE_MANP);
if (auto err = this->determineFHEParameters(res, noOverrideAutoDetected))
return std::move(err);
}
// return early if only the MANP pass was requested
if (target == Target::HLFHE_MANP)
return res;
if (highestFHEDialect == FHEDialect::HLFHE ||
highestFHEDialect == FHEDialect::MIDLFHE) {
if (llvm::Error err = this->lowerParamDependentHalf(target, res))
return std::move(err);
}
if (target == Target::HLFHE_MANP || target == Target::MIDLFHE ||
target == Target::LOWLFHE)
return res;
// LowLFHE -> Canonical dialects
if (mlir::zamalang::pipeline::lowerLowLFHEToStd(mlirContext, module)
.failed()) {
return StreamStringError(
"Lowering from LowLFHE to canonical MLIR dialects failed");
}
if (target == Target::STD)
return res;
// Generate client parameters if requested
if (this->generateClientParameters) {
if (!this->clientParametersFuncName.hasValue()) {
return StreamStringError(
"Generation of client parameters requested, but no function name "
"specified");
}
llvm::Expected<mlir::zamalang::ClientParameters> clientParametersOrErr =
mlir::zamalang::createClientParametersForV0(
*res.fheContext, *this->clientParametersFuncName, module);
if (llvm::Error err = clientParametersOrErr.takeError())
return std::move(err);
res.clientParameters = clientParametersOrErr.get();
}
// Invoke the lambda
if (auto err = invoke(*arguments)) {
return std::move(err);
// Generate Key set if requested
if (this->generateKeySet) {
if (!res.clientParameters.hasValue()) {
return StreamStringError("Generation of keyset requested without request "
"for generation of client parameters");
}
llvm::Expected<std::unique_ptr<mlir::zamalang::KeySet>> keySetOrErr =
mlir::zamalang::KeySet::generate(*res.clientParameters, 0, 0);
if (auto err = keySetOrErr.takeError())
return std::move(err);
res.keySet = std::move(*keySetOrErr);
}
uint64_t res = 0;
if (auto err = arguments->getResult(0, res)) {
return std::move(err);
// MLIR canonical dialects -> LLVM Dialect
if (mlir::zamalang::pipeline::lowerStdToLLVMDialect(mlirContext, module,
false)
.failed()) {
return StreamStringError("Failed to lower to LLVM dialect");
}
if (target == Target::LLVM)
return res;
// Lowering to actual LLVM IR (i.e., not the LLVM dialect)
llvm::LLVMContext &llvmContext = *this->compilationContext->getLLVMContext();
res.llvmModule = mlir::zamalang::pipeline::lowerLLVMDialectToLLVMIR(
mlirContext, llvmContext, module);
if (!res.llvmModule)
return StreamStringError("Failed to convert from LLVM dialect to LLVM IR");
if (target == Target::LLVM_IR)
return res;
if (mlir::zamalang::pipeline::optimizeLLVMModule(llvmContext, *res.llvmModule)
.failed()) {
return StreamStringError("Failed to optimize LLVM IR");
}
if (target == Target::OPTIMIZED_LLVM_IR)
return res;
return res;
} // namespace zamalang
// Compile the source `s` to the target dialect `target`. If successful, the
// result can be retrieved using `getModule()` and `getLLVMModule()`,
// respectively depending on the target dialect.
llvm::Expected<CompilerEngine::CompilationResult>
CompilerEngine::compile(llvm::StringRef s, Target target) {
std::unique_ptr<llvm::MemoryBuffer> mb = llvm::MemoryBuffer::getMemBuffer(s);
llvm::Expected<CompilationResult> res = this->compile(std::move(mb), target);
return std::move(res);
}
// Compile the contained in `buffer` to the target dialect
// `target`. If successful, the result can be retrieved using
// `getModule()` and `getLLVMModule()`, respectively depending on the
// target dialect.
llvm::Expected<CompilerEngine::CompilationResult>
CompilerEngine::compile(std::unique_ptr<llvm::MemoryBuffer> buffer,
Target target) {
llvm::SourceMgr sm;
sm.AddNewSourceBuffer(std::move(buffer), llvm::SMLoc());
llvm::Expected<CompilationResult> res = this->compile(sm, target);
return std::move(res);
}
} // namespace zamalang
} // namespace mlir

View File

@@ -1,3 +1,4 @@
#include "llvm/Support/Error.h"
#include <llvm/ADT/ArrayRef.h>
#include <llvm/ADT/SmallVector.h>
#include <llvm/ADT/StringRef.h>
@@ -12,56 +13,6 @@
namespace mlir {
namespace zamalang {
// JIT-compiles `module` invokes `func` with the arguments passed in
// `jitArguments` and `keySet`
mlir::LogicalResult
runJit(mlir::ModuleOp module, llvm::StringRef func,
llvm::ArrayRef<uint64_t> funcArgs, mlir::zamalang::KeySet &keySet,
std::function<llvm::Error(llvm::Module *)> optPipeline,
llvm::raw_ostream &os) {
// Create the JIT lambda
auto maybeLambda =
mlir::zamalang::JITLambda::create(func, module, optPipeline);
if (!maybeLambda) {
return mlir::failure();
}
auto lambda = std::move(maybeLambda.get());
// Create the arguments of the JIT lambda
auto maybeArguments = mlir::zamalang::JITLambda::Argument::create(keySet);
if (auto err = maybeArguments.takeError()) {
::mlir::zamalang::log_error()
<< "Cannot create lambda arguments: " << err << "\n";
llvm::consumeError(std::move(err));
return mlir::failure();
}
// Set the arguments
auto arguments = std::move(maybeArguments.get());
for (size_t i = 0; i < funcArgs.size(); i++) {
if (auto err = arguments->setArg(i, funcArgs[i])) {
::mlir::zamalang::log_error()
<< "Cannot push argument " << i << ": " << err << "\n";
llvm::consumeError(std::move(err));
return mlir::failure();
}
}
// Invoke the lambda
if (auto err = lambda->invoke(*arguments)) {
::mlir::zamalang::log_error() << "Cannot invoke : " << err << "\n";
llvm::consumeError(std::move(err));
return mlir::failure();
}
uint64_t res = 0;
if (auto err = arguments->getResult(0, res)) {
::mlir::zamalang::log_error() << "Cannot get result : " << err << "\n";
llvm::consumeError(std::move(err));
return mlir::failure();
}
llvm::errs() << res << "\n";
return mlir::success();
}
llvm::Expected<std::unique_ptr<JITLambda>>
JITLambda::create(llvm::StringRef name, mlir::ModuleOp &module,
llvm::function_ref<llvm::Error(llvm::Module *)> optPipeline) {

View File

@@ -0,0 +1,105 @@
#include "llvm/Support/Error.h"
#include <llvm/ADT/STLExtras.h>
#include <llvm/Support/TargetSelect.h>
#include <mlir/ExecutionEngine/OptUtils.h>
#include <mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h>
#include <zamalang/Support/JitCompilerEngine.h>
namespace mlir {
namespace zamalang {
JitCompilerEngine::JitCompilerEngine(
std::shared_ptr<CompilationContext> compilationContext,
unsigned int optimizationLevel)
: CompilerEngine(compilationContext), optimizationLevel(optimizationLevel) {
}
// Returns the `LLVMFuncOp` operation in the compiled module with the
// specified name. If no LLVMFuncOp with that name exists or if there
// was no prior call to `compile()` resulting in an MLIR module in the
// LLVM dialect, an error is returned.
llvm::Expected<mlir::LLVM::LLVMFuncOp>
JitCompilerEngine::findLLVMFuncOp(mlir::ModuleOp module, llvm::StringRef name) {
auto funcOps = module.getOps<mlir::LLVM::LLVMFuncOp>();
auto funcOp = llvm::find_if(
funcOps, [&](mlir::LLVM::LLVMFuncOp op) { return op.getName() == name; });
if (funcOp == funcOps.end()) {
return StreamStringError()
<< "Module does not contain function named '" << name.str() << "'";
}
return *funcOp;
}
// Build a lambda from the function with the name given in
// `funcName` from the sources in `buffer`.
llvm::Expected<JitCompilerEngine::Lambda>
JitCompilerEngine::buildLambda(std::unique_ptr<llvm::MemoryBuffer> buffer,
llvm::StringRef funcName) {
llvm::SourceMgr sm;
sm.AddNewSourceBuffer(std::move(buffer), llvm::SMLoc());
llvm::Expected<JitCompilerEngine::Lambda> res =
this->buildLambda(sm, funcName);
return std::move(res);
}
// Build a lambda from the function with the name given in `funcName`
// from the source string `s`.
llvm::Expected<JitCompilerEngine::Lambda>
JitCompilerEngine::buildLambda(llvm::StringRef s, llvm::StringRef funcName) {
std::unique_ptr<llvm::MemoryBuffer> mb = llvm::MemoryBuffer::getMemBuffer(s);
llvm::Expected<JitCompilerEngine::Lambda> res =
this->buildLambda(std::move(mb), funcName);
return std::move(res);
}
// Build a lambda from the function with the name given in
// `funcName` from the sources managed by the source manager `sm`.
llvm::Expected<JitCompilerEngine::Lambda>
JitCompilerEngine::buildLambda(llvm::SourceMgr &sm, llvm::StringRef funcName) {
MLIRContext &mlirContext = *this->compilationContext->getMLIRContext();
this->setGenerateKeySet(true);
this->setGenerateClientParameters(true);
this->setClientParametersFuncName(funcName);
// First, compile to LLVM Dialect
llvm::Expected<CompilerEngine::CompilationResult> compResOrErr =
this->compile(sm, Target::LLVM_IR);
if (!compResOrErr)
return std::move(compResOrErr.takeError());
mlir::ModuleOp module = compResOrErr->mlirModuleRef->get();
// Locate function to JIT-compile
llvm::Expected<mlir::LLVM::LLVMFuncOp> funcOrError =
this->findLLVMFuncOp(compResOrErr->mlirModuleRef->get(), funcName);
if (!funcOrError)
return std::move(funcOrError.takeError());
// Prepare LLVM infrastructure for JIT compilation
llvm::InitializeNativeTarget();
llvm::InitializeNativeTargetAsmPrinter();
mlir::registerLLVMDialectTranslation(mlirContext);
std::function<llvm::Error(llvm::Module *)> optPipeline =
mlir::makeOptimizingTransformer(3, 0, nullptr);
llvm::Expected<std::unique_ptr<JITLambda>> lambdaOrErr =
mlir::zamalang::JITLambda::create(funcName, module, optPipeline);
if (!lambdaOrErr)
return std::move(lambdaOrErr.takeError());
return Lambda{this->compilationContext, std::move(lambdaOrErr.get()),
std::move(compResOrErr->keySet)};
}
} // namespace zamalang
} // namespace mlir

View File

@@ -0,0 +1,7 @@
#include <zamalang/Support/LambdaArgument.h>
namespace mlir {
namespace zamalang {
char LambdaArgument::ID = 0;
} // namespace zamalang
} // namespace mlir

View File

@@ -1,3 +1,4 @@
#include <cstdint>
#include <iostream>
#include <llvm/Support/CommandLine.h>
@@ -22,15 +23,15 @@
#include "zamalang/Dialect/LowLFHE/IR/LowLFHETypes.h"
#include "zamalang/Dialect/MidLFHE/IR/MidLFHEDialect.h"
#include "zamalang/Dialect/MidLFHE/IR/MidLFHETypes.h"
#include "zamalang/Support/Jit.h"
#include "zamalang/Support/Error.h"
#include "zamalang/Support/JitCompilerEngine.h"
#include "zamalang/Support/KeySet.h"
#include "zamalang/Support/Pipeline.h"
#include "zamalang/Support/logging.h"
enum EntryDialect { HLFHE, MIDLFHE, LOWLFHE, STD, LLVM };
enum Action {
ROUND_TRIP,
DUMP_HLFHE,
DUMP_HLFHE_MANP,
DUMP_MIDLFHE,
DUMP_LOWLFHE,
@@ -80,26 +81,6 @@ llvm::cl::opt<bool> parametrizeMidLFHE(
llvm::cl::desc("Perform MidLFHE global parametrization pass"),
llvm::cl::init<bool>(true));
static llvm::cl::opt<enum EntryDialect> entryDialect(
"e", "entry-dialect", llvm::cl::desc("Entry dialect"),
llvm::cl::init<enum EntryDialect>(EntryDialect::HLFHE),
llvm::cl::ValueRequired, llvm::cl::NumOccurrencesFlag::Required,
llvm::cl::values(
clEnumValN(EntryDialect::HLFHE, "hlfhe",
"Input module is composed of HLFHE operations")),
llvm::cl::values(
clEnumValN(EntryDialect::MIDLFHE, "midlfhe",
"Input module is composed of MidLFHE operations")),
llvm::cl::values(
clEnumValN(EntryDialect::LOWLFHE, "lowlfhe",
"Input module is composed of LowLFHE operations")),
llvm::cl::values(
clEnumValN(EntryDialect::STD, "std",
"Input module is composed of operations from std")),
llvm::cl::values(
clEnumValN(EntryDialect::LLVM, "llvm",
"Input module is composed of operations from llvm")));
static llvm::cl::opt<enum Action> action(
"a", "action", llvm::cl::desc("output mode"), llvm::cl::ValueRequired,
llvm::cl::NumOccurrencesFlag::Required,
@@ -109,6 +90,8 @@ static llvm::cl::opt<enum Action> action(
llvm::cl::values(clEnumValN(Action::DUMP_HLFHE_MANP, "dump-hlfhe-manp",
"Dump HLFHE module after running the Minimal "
"Arithmetic Noise Padding pass")),
llvm::cl::values(clEnumValN(Action::DUMP_HLFHE, "dump-hlfhe",
"Dump HLFHE module")),
llvm::cl::values(clEnumValN(Action::DUMP_MIDLFHE, "dump-midlfhe",
"Lower to MidLFHE and dump result")),
llvm::cl::values(clEnumValN(Action::DUMP_LOWLFHE, "dump-lowlfhe",
@@ -158,50 +141,7 @@ llvm::cl::opt<llvm::Optional<size_t>, false, OptionalSizeTParser> assumeMaxMANP(
llvm::cl::desc(
"Assume a maximum for the Minimum Arithmetic Noise Padding"));
}; // namespace cmdline
std::function<llvm::Error(llvm::Module *)> defaultOptPipeline =
mlir::makeOptimizingTransformer(3, 0, nullptr);
std::unique_ptr<mlir::zamalang::KeySet>
generateKeySet(mlir::ModuleOp &module, mlir::zamalang::V0FHEContext &fheContext,
const std::string &jitFuncName) {
std::unique_ptr<mlir::zamalang::KeySet> keySet;
mlir::zamalang::log_verbose()
<< "### Global FHE constraint: {norm2:" << fheContext.constraint.norm2
<< ", p:" << fheContext.constraint.p << "}\n";
mlir::zamalang::log_verbose()
<< "### FHE parameters for the atomic pattern: {k: "
<< fheContext.parameter.k
<< ", polynomialSize: " << fheContext.parameter.polynomialSize
<< ", nSmall: " << fheContext.parameter.nSmall
<< ", brLevel: " << fheContext.parameter.brLevel
<< ", brLogBase: " << fheContext.parameter.brLogBase
<< ", ksLevel: " << fheContext.parameter.ksLevel
<< ", ksLogBase: " << fheContext.parameter.ksLogBase << "}\n";
// Create the client parameters
auto clientParameter = mlir::zamalang::createClientParametersForV0(
fheContext, jitFuncName, module);
if (auto err = clientParameter.takeError()) {
mlir::zamalang::log_error()
<< "cannot generate client parameters: " << err << "\n";
return nullptr;
}
mlir::zamalang::log_verbose() << "### Generate the key set\n";
auto maybeKeySet = mlir::zamalang::KeySet::generate(clientParameter.get(), 0,
0); // TODO: seed
if (auto err = maybeKeySet.takeError()) {
llvm::errs() << err;
return nullptr;
}
return std::move(maybeKeySet.get());
}
} // namespace cmdline
llvm::Expected<mlir::zamalang::V0FHEContext> buildFHEContext(
llvm::Optional<mlir::zamalang::V0FHEConstraint> autoFHEConstraints,
@@ -209,65 +149,48 @@ llvm::Expected<mlir::zamalang::V0FHEContext> buildFHEContext(
llvm::Optional<size_t> overrideMaxMANP) {
if (!autoFHEConstraints.hasValue() &&
(!overrideMaxMANP.hasValue() || !overrideMaxEintPrecision.hasValue())) {
return llvm::make_error<llvm::StringError>(
return mlir::zamalang::StreamStringError(
"Maximum encrypted integer precision and maximum for the Minimal"
"Arithmetic Noise Passing are required, but were neither specified"
"explicitly nor determined automatically",
llvm::inconvertibleErrorCode());
"explicitly nor determined automatically");
}
mlir::zamalang::V0FHEConstraint fheConstraints{
.norm2 = overrideMaxMANP.hasValue() ? overrideMaxMANP.getValue()
: autoFHEConstraints.getValue().norm2,
.p = overrideMaxEintPrecision.hasValue()
? overrideMaxEintPrecision.getValue()
: autoFHEConstraints.getValue().p};
overrideMaxMANP.hasValue() ? overrideMaxMANP.getValue()
: autoFHEConstraints.getValue().norm2,
overrideMaxEintPrecision.hasValue() ? overrideMaxEintPrecision.getValue()
: autoFHEConstraints.getValue().p};
const mlir::zamalang::V0Parameter *parameter = getV0Parameter(fheConstraints);
if (!parameter) {
std::string buffer;
llvm::raw_string_ostream strs(buffer);
strs << "Could not determine V0 parameters for 2-norm of "
<< fheConstraints.norm2 << " and p of " << fheConstraints.p;
return llvm::make_error<llvm::StringError>(strs.str(),
llvm::inconvertibleErrorCode());
return mlir::zamalang::StreamStringError()
<< "Could not determine V0 parameters for 2-norm of "
<< fheConstraints.norm2 << " and p of " << fheConstraints.p;
}
return mlir::zamalang::V0FHEContext{fheConstraints, *parameter};
}
mlir::LogicalResult buildAssignFHEContext(
llvm::Optional<mlir::zamalang::V0FHEContext> &fheContext,
llvm::Optional<mlir::zamalang::V0FHEConstraint> autoFHEConstraints,
llvm::Optional<size_t> overrideMaxEintPrecision,
llvm::Optional<size_t> overrideMaxMANP) {
namespace llvm {
// This needs to be wrapped into the llvm namespace for proper
// operator lookup
llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
const llvm::ArrayRef<uint64_t> arr) {
os << "(";
for (size_t i = 0; i < arr.size(); i++) {
os << arr[i];
if (fheContext.hasValue())
return mlir::success();
llvm::Expected<mlir::zamalang::V0FHEContext> fheContextOrErr =
buildFHEContext(autoFHEConstraints, overrideMaxEintPrecision,
overrideMaxMANP);
if (auto err = fheContextOrErr.takeError()) {
mlir::zamalang::log_error() << err;
return mlir::failure();
if (i != arr.size() - 1)
os << ", ";
}
fheContext.emplace(fheContextOrErr.get());
return mlir::success();
return os;
}
} // namespace llvm
// Process a single source buffer
//
// The parameter `entryDialect` must specify the FHE dialect to which
// belong all FHE operations used in the source buffer. The input
// program must only contain FHE operations from that single FHE
// dialect, otherwise processing might fail.
//
// The parameter `action` specifies how the buffer should be processed
// and thus defines the output.
//
@@ -276,15 +199,14 @@ mlir::LogicalResult buildAssignFHEContext(
// using the parameters given in `jitArgs`.
//
// The parameter `parametrizeMidLFHE` defines, whether the
// parametrization pass for MidLFHE is executed. If the pair of
// `entryDialect` and `action` does not involve any MidlFHE
// manipulation, this parameter does not have any effect.
// parametrization pass for MidLFHE is executed. If the `action` does
// not involve any MidlFHE manipulation, this parameter does not have
// any effect.
//
// The parameters `overrideMaxEintPrecision` and `overrideMaxMANP`, if
// set, override the values for the maximum required precision of
// encrypted integers and the maximum value for the Minimum Arithmetic
// Noise Padding otherwise determined automatically if the entry
// dialect is HLFHE..
// Noise Padding otherwise determined automatically.
//
// If `verifyDiagnostics` is `true`, the procedure only checks if the
// diagnostic messages provided in the source buffer using
@@ -292,164 +214,103 @@ mlir::LogicalResult buildAssignFHEContext(
// the procedure checks if the parsed module is valid and if all
// requested transformations succeeded.
//
// If `verbose` is true, debug messages are displayed throughout the
// compilation process.
//
// Compilation output is written to the stream specified by `os`.
mlir::LogicalResult processInputBuffer(
mlir::MLIRContext &context, std::unique_ptr<llvm::MemoryBuffer> buffer,
enum EntryDialect entryDialect, enum Action action,
const std::string &jitFuncName, llvm::ArrayRef<uint64_t> jitArgs,
bool parametrizeMidlHFE, llvm::Optional<size_t> overrideMaxEintPrecision,
llvm::Optional<size_t> overrideMaxMANP, bool verifyDiagnostics,
bool verbose, llvm::raw_ostream &os) {
llvm::SourceMgr sourceMgr;
sourceMgr.AddNewSourceBuffer(std::move(buffer), llvm::SMLoc());
mlir::LogicalResult
processInputBuffer(std::unique_ptr<llvm::MemoryBuffer> buffer,
enum Action action, const std::string &jitFuncName,
llvm::ArrayRef<uint64_t> jitArgs, bool parametrizeMidlHFE,
llvm::Optional<size_t> overrideMaxEintPrecision,
llvm::Optional<size_t> overrideMaxMANP,
bool verifyDiagnostics, llvm::raw_ostream &os) {
std::shared_ptr<mlir::zamalang::CompilationContext> ccx =
mlir::zamalang::CompilationContext::createShared();
mlir::SourceMgrDiagnosticVerifierHandler sourceMgrHandler(sourceMgr,
&context);
mlir::OwningModuleRef moduleRef = mlir::parseSourceFile(sourceMgr, &context);
mlir::zamalang::JitCompilerEngine ce{ccx};
llvm::Optional<mlir::zamalang::V0FHEConstraint> fheConstraints;
llvm::Optional<mlir::zamalang::V0FHEContext> fheContext;
ce.setVerifyDiagnostics(verifyDiagnostics);
ce.setParametrizeMidLFHE(parametrizeMidlHFE);
std::unique_ptr<mlir::zamalang::KeySet> keySet = nullptr;
if (overrideMaxEintPrecision.hasValue())
ce.setMaxEintPrecision(overrideMaxEintPrecision.getValue());
if (verbose)
context.disableMultithreading();
if (overrideMaxMANP.hasValue())
ce.setMaxMANP(overrideMaxMANP.getValue());
if (verifyDiagnostics)
return sourceMgrHandler.verify();
if (action == Action::JIT_INVOKE) {
llvm::Expected<mlir::zamalang::JitCompilerEngine::Lambda> lambdaOrErr =
ce.buildLambda(std::move(buffer), jitFuncName);
if (!moduleRef)
return mlir::failure();
mlir::ModuleOp module = moduleRef.get();
if (action == Action::ROUND_TRIP) {
module->print(os);
return mlir::success();
}
// Lowering pipeline. Each stage is represented as a label in the
// switch statement, from the most abstract dialect to the lowest
// level. Every labels acts as an entry point into the pipeline with
// a fallthrough mechanism to the next stage. Actions act as exit
// points from the pipeline.
switch (entryDialect) {
case EntryDialect::HLFHE:
if (action == Action::DUMP_HLFHE_MANP) {
if (mlir::zamalang::pipeline::invokeMANPPass(context, module, false)
.failed()) {
return mlir::failure();
}
module.print(os);
return mlir::success();
} else {
llvm::Expected<llvm::Optional<mlir::zamalang::V0FHEConstraint>>
fheConstraintsOrErr =
mlir::zamalang::pipeline::getFHEConstraintsFromHLFHE(context,
module);
if (auto err = fheConstraintsOrErr.takeError()) {
mlir::zamalang::log_error() << err;
return mlir::failure();
} else {
fheConstraints = fheConstraintsOrErr.get();
}
}
if (mlir::zamalang::pipeline::lowerHLFHEToMidLFHE(context, module, verbose)
.failed())
return mlir::failure();
// fallthrough
case EntryDialect::MIDLFHE:
if (action == Action::DUMP_MIDLFHE) {
module.print(os);
return mlir::success();
}
if (buildAssignFHEContext(fheContext, fheConstraints,
overrideMaxEintPrecision, overrideMaxMANP)
.failed()) {
return mlir::failure();
}
if (mlir::zamalang::pipeline::lowerMidLFHEToLowLFHE(
context, module, fheContext.getValue(), parametrizeMidlHFE)
.failed())
return mlir::failure();
// fallthrough
case EntryDialect::LOWLFHE:
if (action == Action::DUMP_LOWLFHE) {
module.print(os);
return mlir::success();
}
if (mlir::zamalang::pipeline::lowerLowLFHEToStd(context, module).failed())
return mlir::failure();
// fallthrough
case EntryDialect::STD:
if (action == Action::DUMP_STD) {
module.print(os);
return mlir::success();
} else if (action == Action::JIT_INVOKE) {
if (buildAssignFHEContext(fheContext, fheConstraints,
overrideMaxEintPrecision, overrideMaxMANP)
.failed()) {
return mlir::failure();
}
keySet = generateKeySet(module, fheContext.getValue(), jitFuncName);
}
if (mlir::zamalang::pipeline::lowerStdToLLVMDialect(context, module,
verbose)
.failed())
return mlir::failure();
// fallthrough
case EntryDialect::LLVM: {
if (action == Action::DUMP_LLVM_DIALECT) {
module.print(os);
return mlir::success();
} else if (action == Action::JIT_INVOKE) {
return mlir::zamalang::runJit(module, jitFuncName, jitArgs, *keySet,
defaultOptPipeline, os);
}
llvm::LLVMContext llvmContext;
std::unique_ptr<llvm::Module> llvmModule =
mlir::zamalang::pipeline::lowerLLVMDialectToLLVMIR(context, llvmContext,
module);
if (!llvmModule) {
if (!lambdaOrErr) {
mlir::zamalang::log_error()
<< "Failed to translate LLVM dialect to LLVM IR\n";
<< "Failed to JIT-compile " << jitFuncName << ": "
<< llvm::toString(std::move(lambdaOrErr.takeError()));
return mlir::failure();
}
if (action == Action::DUMP_LLVM_IR) {
llvmModule->dump();
return mlir::success();
}
llvm::Expected<uint64_t> resOrErr = (*lambdaOrErr)(jitArgs);
if (mlir::zamalang::pipeline::optimizeLLVMModule(llvmContext, *llvmModule)
.failed()) {
mlir::zamalang::log_error() << "Failed to optimize LLVM IR\n";
if (!resOrErr) {
mlir::zamalang::log_error()
<< "Failed to JIT-invoke " << jitFuncName << " with arguments "
<< jitArgs << ": " << llvm::toString(std::move(resOrErr.takeError()));
return mlir::failure();
}
if (action == Action::DUMP_OPTIMIZED_LLVM_IR) {
llvmModule->dump();
return mlir::success();
os << *resOrErr << "\n";
} else {
enum mlir::zamalang::CompilerEngine::Target target;
switch (action) {
case Action::ROUND_TRIP:
target = mlir::zamalang::CompilerEngine::Target::ROUND_TRIP;
break;
case Action::DUMP_HLFHE:
target = mlir::zamalang::CompilerEngine::Target::HLFHE;
break;
case Action::DUMP_HLFHE_MANP:
target = mlir::zamalang::CompilerEngine::Target::HLFHE_MANP;
break;
case Action::DUMP_MIDLFHE:
target = mlir::zamalang::CompilerEngine::Target::MIDLFHE;
break;
case Action::DUMP_LOWLFHE:
target = mlir::zamalang::CompilerEngine::Target::LOWLFHE;
break;
case Action::DUMP_STD:
target = mlir::zamalang::CompilerEngine::Target::STD;
break;
case Action::DUMP_LLVM_DIALECT:
target = mlir::zamalang::CompilerEngine::Target::LLVM;
break;
case Action::DUMP_LLVM_IR:
target = mlir::zamalang::CompilerEngine::Target::LLVM_IR;
break;
case Action::DUMP_OPTIMIZED_LLVM_IR:
target = mlir::zamalang::CompilerEngine::Target::OPTIMIZED_LLVM_IR;
break;
case JIT_INVOKE:
// Case just here to satisfy the compiler; already handled above
break;
}
break;
}
llvm::Expected<mlir::zamalang::CompilerEngine::CompilationResult> retOrErr =
ce.compile(std::move(buffer), target);
if (!retOrErr) {
mlir::zamalang::log_error()
<< llvm::toString(std::move(retOrErr.takeError())) << "\n";
return mlir::failure();
}
if (verifyDiagnostics) {
return mlir::success();
} else if (action == Action::DUMP_LLVM_IR ||
action == Action::DUMP_OPTIMIZED_LLVM_IR) {
retOrErr->llvmModule->print(os, nullptr);
} else {
retOrErr->mlirModuleRef->get().print(os);
}
}
return mlir::success();
@@ -459,44 +320,11 @@ mlir::LogicalResult compilerMain(int argc, char **argv) {
// Parse command line arguments
llvm::cl::ParseCommandLineOptions(argc, argv);
// Initialize the MLIR context
mlir::MLIRContext context;
mlir::zamalang::setupLogging(cmdline::verbose);
// String for error messages from library functions
std::string errorMessage;
if (cmdline::action == Action::DUMP_HLFHE_MANP &&
cmdline::entryDialect != EntryDialect::HLFHE) {
mlir::zamalang::log_error()
<< "Can only invoke Minimal Arithmetic Noise pass on HLFHE programs";
return mlir::failure();
}
if (cmdline::action == Action::JIT_INVOKE &&
cmdline::entryDialect != EntryDialect::HLFHE &&
cmdline::entryDialect != EntryDialect::MIDLFHE &&
cmdline::entryDialect != EntryDialect::LOWLFHE &&
cmdline::entryDialect != EntryDialect::STD) {
mlir::zamalang::log_error()
<< "Can only JIT invoke HLFHE / MidLFHE / LowLFHE / STD programs";
return mlir::failure();
}
// Load our Dialect in this MLIR Context.
context.getOrLoadDialect<mlir::zamalang::HLFHE::HLFHEDialect>();
context.getOrLoadDialect<mlir::zamalang::MidLFHE::MidLFHEDialect>();
context.getOrLoadDialect<mlir::zamalang::LowLFHE::LowLFHEDialect>();
context.getOrLoadDialect<mlir::StandardOpsDialect>();
context.getOrLoadDialect<mlir::memref::MemRefDialect>();
context.getOrLoadDialect<mlir::linalg::LinalgDialect>();
context.getOrLoadDialect<mlir::tensor::TensorDialect>();
context.getOrLoadDialect<mlir::LLVM::LLVMDialect>();
if (cmdline::verifyDiagnostics)
context.printOpOnDiagnostic(false);
auto output = mlir::openOutputFile(cmdline::output, &errorMessage);
if (!output) {
@@ -523,20 +351,20 @@ mlir::LogicalResult compilerMain(int argc, char **argv) {
[&](std::unique_ptr<llvm::MemoryBuffer> inputBuffer,
llvm::raw_ostream &os) {
return processInputBuffer(
context, std::move(inputBuffer), cmdline::entryDialect,
cmdline::action, cmdline::jitFuncName, cmdline::jitArgs,
std::move(inputBuffer), cmdline::action,
cmdline::jitFuncName, cmdline::jitArgs,
cmdline::parametrizeMidLFHE,
cmdline::assumeMaxEintPrecision, cmdline::assumeMaxMANP,
cmdline::verifyDiagnostics, cmdline::verbose, os);
cmdline::verifyDiagnostics, os);
},
output->os())))
return mlir::failure();
} else {
return processInputBuffer(
context, std::move(file), cmdline::entryDialect, cmdline::action,
cmdline::jitFuncName, cmdline::jitArgs, cmdline::parametrizeMidLFHE,
std::move(file), cmdline::action, cmdline::jitFuncName,
cmdline::jitArgs, cmdline::parametrizeMidLFHE,
cmdline::assumeMaxEintPrecision, cmdline::assumeMaxMANP,
cmdline::verifyDiagnostics, cmdline::verbose, output->os());
cmdline::verifyDiagnostics, output->os());
}
}

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler %s --entry-dialect=hlfhe --action=dump-midlfhe 2>&1| FileCheck %s
// RUN: zamacompiler %s --action=dump-midlfhe 2>&1| FileCheck %s
// CHECK-LABEL: func @add_eint(%arg0: !MidLFHE.glwe<{_,_,_}{7}>, %arg1: !MidLFHE.glwe<{_,_,_}{7}>) -> !MidLFHE.glwe<{_,_,_}{7}>
func @add_eint(%arg0: !HLFHE.eint<7>, %arg1: !HLFHE.eint<7>) -> !HLFHE.eint<7> {

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler %s --entry-dialect=hlfhe --action=dump-midlfhe 2>&1| FileCheck %s
// RUN: zamacompiler %s --action=dump-midlfhe 2>&1| FileCheck %s
// CHECK-LABEL: func @add_eint_int(%arg0: !MidLFHE.glwe<{_,_,_}{7}>) -> !MidLFHE.glwe<{_,_,_}{7}>
func @add_eint_int(%arg0: !HLFHE.eint<7>) -> !HLFHE.eint<7> {

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler %s --entry-dialect=hlfhe --action=dump-midlfhe 2>&1| FileCheck %s
// RUN: zamacompiler %s --action=dump-midlfhe 2>&1| FileCheck %s
// CHECK-LABEL: func @apply_lookup_table(%arg0: !MidLFHE.glwe<{_,_,_}{2}>, %arg1: tensor<4xi64>) -> !MidLFHE.glwe<{_,_,_}{2}>
func @apply_lookup_table(%arg0: !HLFHE.eint<2>, %arg1: tensor<4xi64>) -> !HLFHE.eint<2> {

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler %s --entry-dialect=hlfhe --action=dump-midlfhe 2>&1| FileCheck %s
// RUN: zamacompiler %s --action=dump-midlfhe 2>&1| FileCheck %s
// CHECK-LABEL: func @apply_lookup_table_cst(%arg0: !MidLFHE.glwe<{_,_,_}{7}>) -> !MidLFHE.glwe<{_,_,_}{7}>
func @apply_lookup_table_cst(%arg0: !HLFHE.eint<7>) -> !HLFHE.eint<7> {

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler %s --entry-dialect=hlfhe --action=dump-midlfhe 2>&1| FileCheck %s
// RUN: zamacompiler %s --action=dump-midlfhe --assume-max-manp=10 --assume-max-eint-precision=2 2>&1| FileCheck %s
// CHECK: #map0 = affine_map<(d0) -> (d0)>
// CHECK-NEXT: #map1 = affine_map<(d0) -> (0)>

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler %s --entry-dialect=hlfhe --action=dump-midlfhe 2>&1| FileCheck %s
// RUN: zamacompiler %s --action=dump-midlfhe 2>&1| FileCheck %s
// CHECK-LABEL: func @mul_eint_int(%arg0: !MidLFHE.glwe<{_,_,_}{7}>) -> !MidLFHE.glwe<{_,_,_}{7}>
func @mul_eint_int(%arg0: !HLFHE.eint<7>) -> !HLFHE.eint<7> {

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler %s --entry-dialect=hlfhe --action=dump-midlfhe 2>&1| FileCheck %s
// RUN: zamacompiler %s --action=dump-midlfhe 2>&1| FileCheck %s
// CHECK-LABEL: func @sub_int_eint(%arg0: !MidLFHE.glwe<{_,_,_}{7}>) -> !MidLFHE.glwe<{_,_,_}{7}>
func @sub_int_eint(%arg0: !HLFHE.eint<7>) -> !HLFHE.eint<7> {

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler --entry-dialect=lowlfhe --action=dump-std %s 2>&1| FileCheck %s
// RUN: zamacompiler --action=dump-std %s 2>&1| FileCheck %s
// CHECK-LABEL: module
// CHECK-NEXT: func private @add_plaintext_list_glwe_ciphertext_u64(index, !LowLFHE.glwe_ciphertext, !LowLFHE.glwe_ciphertext, !LowLFHE.plaintext_list)

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler --entry-dialect=lowlfhe --action=dump-std %s 2>&1| FileCheck %s
// RUN: zamacompiler --action=dump-std %s 2>&1| FileCheck %s
// CHECK-LABEL: module
// CHECK-NEXT: func private @runtime_foreign_plaintext_list_u64(index, tensor<16xi64>, i64, i32) -> !LowLFHE.foreign_plaintext_list

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler --entry-dialect=lowlfhe --action=dump-std %s 2>&1| FileCheck %s
// RUN: zamacompiler --action=dump-std %s 2>&1| FileCheck %s
// CHECK-LABEL: module
// CHECK-NEXT: func private @add_plaintext_list_glwe_ciphertext_u64(index, !LowLFHE.glwe_ciphertext, !LowLFHE.glwe_ciphertext, !LowLFHE.plaintext_list)

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler --entry-dialect=midlfhe --action=dump-lowlfhe --parametrize-midlfhe=false --assume-max-eint-precision=7 --assume-max-manp=10 %s 2>&1| FileCheck %s
// RUN: zamacompiler --action=dump-lowlfhe --parametrize-midlfhe=false --assume-max-eint-precision=7 --assume-max-manp=10 %s 2>&1| FileCheck %s
// CHECK-LABEL: func @add_glwe(%arg0: !LowLFHE.lwe_ciphertext<2048,7>, %arg1: !LowLFHE.lwe_ciphertext<2048,7>) -> !LowLFHE.lwe_ciphertext<2048,7>
func @add_glwe(%arg0: !MidLFHE.glwe<{2048,1,64}{7}>, %arg1: !MidLFHE.glwe<{2048,1,64}{7}>) -> !MidLFHE.glwe<{2048,1,64}{7}> {

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler --entry-dialect=midlfhe --action=dump-lowlfhe --parametrize-midlfhe=false --assume-max-eint-precision=7 --assume-max-manp=10 %s 2>&1| FileCheck %s
// RUN: zamacompiler --action=dump-lowlfhe --parametrize-midlfhe=false --assume-max-eint-precision=7 --assume-max-manp=10 %s 2>&1| FileCheck %s
// CHECK-LABEL: func @add_glwe_const_int(%arg0: !LowLFHE.lwe_ciphertext<1024,7>) -> !LowLFHE.lwe_ciphertext<1024,7>
func @add_glwe_const_int(%arg0: !MidLFHE.glwe<{1024,1,64}{7}>) -> !MidLFHE.glwe<{1024,1,64}{7}> {

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler --entry-dialect=midlfhe --action=dump-lowlfhe --parametrize-midlfhe=false --assume-max-eint-precision=7 --assume-max-manp=10 %s 2>&1| FileCheck %s
// RUN: zamacompiler --action=dump-lowlfhe --parametrize-midlfhe=false --assume-max-eint-precision=7 --assume-max-manp=10 %s 2>&1| FileCheck %s
// CHECK-LABEL: func @apply_lookup_table(%arg0: !LowLFHE.lwe_ciphertext<1024,4>, %arg1: tensor<16xi64>) -> !LowLFHE.lwe_ciphertext<1024,4>
func @apply_lookup_table(%arg0: !MidLFHE.glwe<{1024,1,64}{4}>, %arg1: tensor<16xi64>) -> !MidLFHE.glwe<{1024,1,64}{4}> {

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler --entry-dialect=midlfhe --action=dump-lowlfhe --parametrize-midlfhe=false --assume-max-eint-precision=7 --assume-max-manp=10 %s 2>&1| FileCheck %s
// RUN: zamacompiler --action=dump-lowlfhe --parametrize-midlfhe=false --assume-max-eint-precision=7 --assume-max-manp=10 %s 2>&1| FileCheck %s
// CHECK-LABEL: func @apply_lookup_table_cst(%arg0: !LowLFHE.lwe_ciphertext<2048,4>) -> !LowLFHE.lwe_ciphertext<2048,4>
func @apply_lookup_table_cst(%arg0: !MidLFHE.glwe<{2048,1,64}{4}>) -> !MidLFHE.glwe<{2048,1,64}{4}> {

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler --entry-dialect=midlfhe --action=dump-lowlfhe --parametrize-midlfhe=false --assume-max-eint-precision=7 --assume-max-manp=10 %s 2>&1| FileCheck %s
// RUN: zamacompiler --action=dump-lowlfhe --parametrize-midlfhe=false --assume-max-eint-precision=7 --assume-max-manp=10 %s 2>&1| FileCheck %s
// CHECK-LABEL: func @mul_glwe_const_int(%arg0: !LowLFHE.lwe_ciphertext<1024,7>) -> !LowLFHE.lwe_ciphertext<1024,7>
func @mul_glwe_const_int(%arg0: !MidLFHE.glwe<{1024,1,64}{7}>) -> !MidLFHE.glwe<{1024,1,64}{7}> {

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler --entry-dialect=midlfhe --action=dump-lowlfhe --parametrize-midlfhe=false --assume-max-eint-precision=7 --assume-max-manp=10 %s 2>&1| FileCheck %s
// RUN: zamacompiler --action=dump-lowlfhe --parametrize-midlfhe=false --assume-max-eint-precision=7 --assume-max-manp=10 %s 2>&1| FileCheck %s
// CHECK-LABEL: func @sub_const_int_glwe(%arg0: !LowLFHE.lwe_ciphertext<1024,7>) -> !LowLFHE.lwe_ciphertext<1024,7>
func @sub_const_int_glwe(%arg0: !MidLFHE.glwe<{1024,1,64}{7}>) -> !MidLFHE.glwe<{1024,1,64}{7}> {

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler --split-input-file --entry-dialect=hlfhe --action=dump-hlfhe-manp %s 2>&1 | FileCheck %s
// RUN: zamacompiler --split-input-file --action=dump-hlfhe-manp %s 2>&1 | FileCheck %s
func @single_zero() -> !HLFHE.eint<2>
{

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler --split-input-file --verify-diagnostics --entry-dialect=hlfhe --action=roundtrip %s
// RUN: zamacompiler --split-input-file --verify-diagnostics --action=roundtrip %s
// Incompatible shapes
func @dot_incompatible_shapes(

View File

@@ -1,4 +1,4 @@
// RUN: not zamacompiler --entry-dialect=hlfhe --action=roundtrip %s 2>&1| FileCheck %s
// RUN: not zamacompiler --action=roundtrip %s 2>&1| FileCheck %s
// CHECK-LABEL: eint support only precision in ]0;7]
func @test(%arg0: !HLFHE.eint<8>) {

View File

@@ -1,4 +1,4 @@
// RUN: not zamacompiler --entry-dialect=hlfhe --action=roundtrip %s 2>&1| FileCheck %s
// RUN: not zamacompiler --action=roundtrip %s 2>&1| FileCheck %s
// CHECK-LABEL: eint support only precision in ]0;7]
func @test(%arg0: !HLFHE.eint<0>) {

View File

@@ -1,4 +1,4 @@
// RUN: not zamacompiler --entry-dialect=hlfhe --action=roundtrip %s 2>&1| FileCheck %s
// RUN: not zamacompiler --action=roundtrip %s 2>&1| FileCheck %s
// CHECK-LABEL: error: 'HLFHE.add_eint' op should have the width of encrypted inputs equals
func @add_eint(%arg0: !HLFHE.eint<2>, %arg1: !HLFHE.eint<3>) -> !HLFHE.eint<2> {

View File

@@ -1,4 +1,4 @@
// RUN: not zamacompiler --entry-dialect=hlfhe --action=roundtrip %s 2>&1| FileCheck %s
// RUN: not zamacompiler --action=roundtrip %s 2>&1| FileCheck %s
// CHECK-LABEL: error: 'HLFHE.add_eint' op should have the width of encrypted inputs and result equals
func @add_eint(%arg0: !HLFHE.eint<2>, %arg1: !HLFHE.eint<2>) -> !HLFHE.eint<3> {

View File

@@ -1,4 +1,4 @@
// RUN: not zamacompiler --entry-dialect=hlfhe --action=roundtrip %s 2>&1| FileCheck %s
// RUN: not zamacompiler --action=roundtrip %s 2>&1| FileCheck %s
// CHECK-LABEL: error: 'HLFHE.add_eint_int' op should have the width of plain input equals to width of encrypted input + 1
func @add_eint_int(%arg0: !HLFHE.eint<2>) -> !HLFHE.eint<2> {

View File

@@ -1,4 +1,4 @@
// RUN: not zamacompiler --entry-dialect=hlfhe --action=roundtrip %s 2>&1| FileCheck %s
// RUN: not zamacompiler --action=roundtrip %s 2>&1| FileCheck %s
// CHECK-LABEL: error: 'HLFHE.add_eint_int' op should have the width of encrypted inputs and result equals
func @add_eint_int(%arg0: !HLFHE.eint<2>) -> !HLFHE.eint<3> {

View File

@@ -1,4 +1,4 @@
// RUN: not zamacompiler --entry-dialect=hlfhe --action=roundtrip %s 2>&1| FileCheck %s
// RUN: not zamacompiler --action=roundtrip %s 2>&1| FileCheck %s
// CHECK-LABEL: error: 'HLFHE.apply_lookup_table' op should have as `l_cst` argument a shape of one dimension equals to 2^p, where p is the width of the `ct` argument.
func @apply_lookup_table(%arg0: !HLFHE.eint<2>, %arg1: tensor<8xi3>) -> !HLFHE.eint<2> {

View File

@@ -1,4 +1,4 @@
// RUN: not zamacompiler --entry-dialect=hlfhe --action=roundtrip %s 2>&1| FileCheck %s
// RUN: not zamacompiler --action=roundtrip %s 2>&1| FileCheck %s
// CHECK-LABEL: error: 'HLFHE.mul_eint_int' op should have the width of plain input equals to width of encrypted input + 1
func @mul_eint_int(%arg0: !HLFHE.eint<2>) -> !HLFHE.eint<2> {

View File

@@ -1,4 +1,4 @@
// RUN: not zamacompiler --entry-dialect=hlfhe --action=roundtrip %s 2>&1| FileCheck %s
// RUN: not zamacompiler --action=roundtrip %s 2>&1| FileCheck %s
// CHECK-LABEL: error: 'HLFHE.mul_eint_int' op should have the width of encrypted inputs and result equals
func @mul_eint_int(%arg0: !HLFHE.eint<2>) -> !HLFHE.eint<3> {

View File

@@ -1,4 +1,4 @@
// RUN: not zamacompiler --entry-dialect=hlfhe --action=roundtrip %s 2>&1| FileCheck %s
// RUN: not zamacompiler --action=roundtrip %s 2>&1| FileCheck %s
// CHECK-LABEL: error: 'HLFHE.sub_int_eint' op should have the width of plain input equals to width of encrypted input + 1
func @sub_int_eint(%arg0: !HLFHE.eint<2>) -> !HLFHE.eint<2> {

View File

@@ -1,4 +1,4 @@
// RUN: not zamacompiler --entry-dialect=hlfhe --action=roundtrip %s 2>&1| FileCheck %s
// RUN: not zamacompiler --action=roundtrip %s 2>&1| FileCheck %s
// CHECK-LABEL: error: 'HLFHE.sub_int_eint' op should have the width of encrypted inputs and result equals
func @sub_int_eint(%arg0: !HLFHE.eint<2>) -> !HLFHE.eint<3> {

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler --entry-dialect=hlfhe --action=roundtrip %s 2>&1| FileCheck %s
// RUN: zamacompiler --action=roundtrip %s 2>&1| FileCheck %s
// CHECK-LABEL: func @zero() -> !HLFHE.eint<2>
func @zero() -> !HLFHE.eint<2> {

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler %s --entry-dialect=hlfhe --action=dump-midlfhe 2>&1 | FileCheck %s
// RUN: zamacompiler %s --action=dump-midlfhe 2>&1 | FileCheck %s
//CHECK: #map0 = affine_map<(d0) -> (d0)>
//CHECK-NEXT: #map1 = affine_map<(d0) -> (0)>

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler --entry-dialect=hlfhe --action=roundtrip %s 2>&1| FileCheck %s
// RUN: zamacompiler --action=roundtrip %s 2>&1| FileCheck %s
// CHECK-LABEL: func @memref_arg(%arg0: memref<2x!HLFHE.eint<7>>
func @memref_arg(%arg0: memref<2x!HLFHE.eint<7>>) {

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler --entry-dialect=lowlfhe --action=roundtrip %s 2>&1| FileCheck %s
// RUN: zamacompiler --action=roundtrip %s 2>&1| FileCheck %s
// CHECK-LABEL: func @add_lwe_ciphertexts(%arg0: !LowLFHE.lwe_ciphertext<2048,7>, %arg1: !LowLFHE.lwe_ciphertext<2048,7>) -> !LowLFHE.lwe_ciphertext<2048,7>
func @add_lwe_ciphertexts(%arg0: !LowLFHE.lwe_ciphertext<2048,7>, %arg1: !LowLFHE.lwe_ciphertext<2048,7>) -> !LowLFHE.lwe_ciphertext<2048,7> {

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler --entry-dialect=lowlfhe --action=roundtrip %s 2>&1| FileCheck %s
// RUN: zamacompiler --action=roundtrip %s 2>&1| FileCheck %s
// CHECK-LABEL: func @type_enc_rand_gen(%arg0: !LowLFHE.enc_rand_gen) -> !LowLFHE.enc_rand_gen
func @type_enc_rand_gen(%arg0: !LowLFHE.enc_rand_gen) -> !LowLFHE.enc_rand_gen {

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler --split-input-file --verify-diagnostics --entry-dialect=midlfhe --action=roundtrip %s
// RUN: zamacompiler --split-input-file --verify-diagnostics --action=roundtrip %s
// GLWE p parameter result
func @add_glwe(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>, %arg1: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{6}> {

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler --entry-dialect=midlfhe --action=roundtrip %s 2>&1| FileCheck %s
// RUN: zamacompiler --action=roundtrip %s 2>&1| FileCheck %s
// CHECK-LABEL: func @add_glwe(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>, %arg1: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{7}>
func @add_glwe(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>, %arg1: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{7}> {

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler --split-input-file --verify-diagnostics --entry-dialect=midlfhe --action=roundtrip %s
// RUN: zamacompiler --split-input-file --verify-diagnostics --action=roundtrip %s
// GLWE p parameter
func @add_glwe_int(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{6}> {

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler --entry-dialect=midlfhe --action=roundtrip %s 2>&1| FileCheck %s
// RUN: zamacompiler --action=roundtrip %s 2>&1| FileCheck %s
// CHECK-LABEL: func @add_glwe_int(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{7}>
func @add_glwe_int(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{7}> {

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler --split-input-file --verify-diagnostics --entry-dialect=midlfhe --action=roundtrip %s
// RUN: zamacompiler --split-input-file --verify-diagnostics --action=roundtrip %s
// Bad dimension of the lookup table
func @apply_lookup_table(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>, %arg1: tensor<4xi2>) -> !MidLFHE.glwe<{512,10,64}{2}> {

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler --entry-dialect=midlfhe --action=roundtrip %s 2>&1| FileCheck %s
// RUN: zamacompiler --action=roundtrip %s 2>&1| FileCheck %s
// CHECK-LABEL: func @apply_lookup_table(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>, %arg1: tensor<128xi64>) -> !MidLFHE.glwe<{512,10,64}{2}>
func @apply_lookup_table(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>, %arg1: tensor<128xi64>) -> !MidLFHE.glwe<{512,10,64}{2}> {

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler --split-input-file --verify-diagnostics --entry-dialect=midlfhe --action=roundtrip %s
// RUN: zamacompiler --split-input-file --verify-diagnostics --action=roundtrip %s
// GLWE p parameter
func @mul_glwe_int(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{6}> {

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler --entry-dialect=midlfhe --action=roundtrip %s 2>&1| FileCheck %s
// RUN: zamacompiler --action=roundtrip %s 2>&1| FileCheck %s
// CHECK-LABEL: func @mul_glwe_int(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{7}>
func @mul_glwe_int(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{7}> {

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler --split-input-file --verify-diagnostics --entry-dialect=midlfhe --action=roundtrip %s
// RUN: zamacompiler --split-input-file --verify-diagnostics --action=roundtrip %s
// GLWE p parameter
func @sub_int_glwe(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{6}> {

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler --entry-dialect=midlfhe --action=roundtrip %s 2>&1| FileCheck %s
// RUN: zamacompiler --action=roundtrip %s 2>&1| FileCheck %s
// CHECK-LABEL: func @sub_int_glwe(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{7}>
func @sub_int_glwe(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{7}> {

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler %s --entry-dialect=midlfhe --action=roundtrip 2>&1| FileCheck %s
// RUN: zamacompiler %s --action=roundtrip 2>&1| FileCheck %s
// CHECK-LABEL: func @glwe_0(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{7}>
func @glwe_0(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{7}> {

View File

@@ -56,7 +56,7 @@ def test_compile_and_run(mlir_input, args, expected_result):
def test_compile_and_run_invalid_arg_number(mlir_input, args):
engine = CompilerEngine()
engine.compile_fhe(mlir_input)
with pytest.raises(RuntimeError, match=r"failed pushing integer argument"):
with pytest.raises(ValueError, match=r"wrong number of arguments"):
engine.run(*args)

View File

@@ -1,8 +1,11 @@
#include <cstdint>
#include <gtest/gtest.h>
#include <type_traits>
#include "zamalang/Support/CompilerEngine.h"
#include "zamalang/Support/JitCompilerEngine.h"
mlir::zamalang::V0FHEConstraint defaultV0Constraints = {.norm2 = 10, .p = 7};
mlir::zamalang::V0FHEConstraint defaultV0Constraints = {10, 7};
#define ASSERT_LLVM_ERROR(err) \
if (err) { \
@@ -10,384 +13,405 @@ mlir::zamalang::V0FHEConstraint defaultV0Constraints = {.norm2 = 10, .p = 7};
ASSERT_TRUE(false); \
}
// Checks that the value `val` is not in an error state. Returns
// `true` if the test passes, otherwise `false`.
template <typename T>
static bool assert_expected_success(llvm::Expected<T> &val) {
if (!((bool)val)) {
llvm::errs() << llvm::toString(std::move(val.takeError()));
return false;
}
return true;
}
// Checks that the value `val` is not in an error state. Returns
// `true` if the test passes, otherwise `false`.
template <typename T>
static bool assert_expected_success(llvm::Expected<T> &&val) {
return assert_expected_success(val);
}
// Checks that the value `val` of type `llvm::Expected<T>` is not in
// an error state.
#define ASSERT_EXPECTED_SUCCESS(val) \
do { \
if (!assert_expected_success(val)) \
GTEST_FATAL_FAILURE_("Expected<T> contained in error state"); \
} while (0)
// Checks that the value `val` is not in an error state and is equal
// to the value given in `exp`. Returns `true` if the test passes,
// otherwise `false`.
template <typename T, typename V>
static bool assert_expected_value(llvm::Expected<T> &val, const V &exp) {
if (!assert_expected_success(val))
return false;
if (!(val.get() == static_cast<T>(exp))) {
llvm::errs() << "Expected value " << exp << ", but got " << val.get()
<< "\n";
return false;
}
return true;
}
// Checks that the value `val` is not in an error state and is equal
// to the value given in `exp`. Returns `true` if the test passes,
// otherwise `false`.
template <typename T, typename V>
static bool assert_expected_value(llvm::Expected<T> &&val, const V &exp) {
return assert_expected_value(val, exp);
}
// Checks that the value `val` of type `llvm::Expected<T>` is not in
// an error state and is equal to the value of type `T` given in
// `exp`.
#define ASSERT_EXPECTED_VALUE(val, exp) \
do { \
if (!assert_expected_value(val, exp)) { \
GTEST_FATAL_FAILURE_("Expected<T> with wrong value"); \
} \
} while (0)
// Jit-compiles the function specified by `func` from `src` and
// returns the corresponding lambda. Any compilation errors are caught
// and reult in abnormal termination.
template <typename F>
mlir::zamalang::JitCompilerEngine::Lambda
internalCheckedJit(F checkfunc, llvm::StringRef src,
llvm::StringRef func = "main",
bool useDefaultFHEConstraints = false) {
mlir::zamalang::JitCompilerEngine engine;
if (useDefaultFHEConstraints)
engine.setFHEConstraints(defaultV0Constraints);
llvm::Expected<mlir::zamalang::JitCompilerEngine::Lambda> lambdaOrErr =
engine.buildLambda(src, func);
checkfunc(lambdaOrErr);
return std::move(*lambdaOrErr);
}
// Shorthands to create integer literals of a specific type
uint8_t operator"" _u8(unsigned long long int v) { return v; }
uint16_t operator"" _u16(unsigned long long int v) { return v; }
uint32_t operator"" _u32(unsigned long long int v) { return v; }
uint64_t operator"" _u64(unsigned long long int v) { return v; }
// Evaluates to the number of elements of a statically initialized
// array
#define ARRAY_SIZE(arr) (sizeof(arr) / sizeof(arr[0]))
// Wrapper around `internalCheckedJit` that causes
// `ASSERT_EXPECTED_SUCCESS` to use the file and line number of the
// caller instead of `internalCheckedJit`.
#define checkedJit(...) \
internalCheckedJit( \
[](llvm::Expected<mlir::zamalang::JitCompilerEngine::Lambda> &lambda) { \
ASSERT_EXPECTED_SUCCESS(lambda); \
}, \
__VA_ARGS__)
TEST(CompileAndRunHLFHE, add_eint) {
mlir::zamalang::CompilerEngine engine;
auto mlirStr = R"XXX(
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
func @main(%arg0: !HLFHE.eint<7>, %arg1: !HLFHE.eint<7>) -> !HLFHE.eint<7> {
%1 = "HLFHE.add_eint"(%arg0, %arg1): (!HLFHE.eint<7>, !HLFHE.eint<7>) -> (!HLFHE.eint<7>)
return %1: !HLFHE.eint<7>
}
)XXX";
ASSERT_FALSE(engine.compile(mlirStr));
auto maybeResult = engine.run({1, 2});
ASSERT_TRUE((bool)maybeResult);
uint64_t result = maybeResult.get();
ASSERT_EQ(result, 3);
)XXX");
ASSERT_EXPECTED_VALUE(lambda(1_u64, 2_u64), 3);
ASSERT_EXPECTED_VALUE(lambda(4_u64, 5_u64), 9);
ASSERT_EXPECTED_VALUE(lambda(1_u64, 1_u64), 2);
}
// Same as CompileAndRunHLFHE::add_eint above, but using
// `LambdaArgument` instances
TEST(CompileAndRunHLFHE, add_eint_lambda_argument) {
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
func @main(%arg0: !HLFHE.eint<7>, %arg1: !HLFHE.eint<7>) -> !HLFHE.eint<7> {
%1 = "HLFHE.add_eint"(%arg0, %arg1): (!HLFHE.eint<7>, !HLFHE.eint<7>) -> (!HLFHE.eint<7>)
return %1: !HLFHE.eint<7>
}
)XXX");
mlir::zamalang::IntLambdaArgument<> ila1(1);
mlir::zamalang::IntLambdaArgument<> ila2(2);
mlir::zamalang::IntLambdaArgument<> ila7(7);
mlir::zamalang::IntLambdaArgument<> ila9(9);
ASSERT_EXPECTED_VALUE(lambda({&ila1, &ila2}), 3);
ASSERT_EXPECTED_VALUE(lambda({&ila7, &ila9}), 16);
ASSERT_EXPECTED_VALUE(lambda({&ila1, &ila7}), 8);
ASSERT_EXPECTED_VALUE(lambda({&ila1, &ila9}), 10);
ASSERT_EXPECTED_VALUE(lambda({&ila2, &ila7}), 9);
}
TEST(CompileAndRunHLFHE, add_u64) {
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
func @main(%arg0: i64, %arg1: i64) -> i64 {
%1 = addi %arg0, %arg1 : i64
return %1: i64
}
)XXX",
"main", true);
ASSERT_EXPECTED_VALUE(lambda(1_u64, 2_u64), (uint64_t)3);
ASSERT_EXPECTED_VALUE(lambda(4_u64, 5_u64), (uint64_t)9);
ASSERT_EXPECTED_VALUE(lambda(1_u64, 1_u64), (uint64_t)2);
}
TEST(CompileAndRunTensorStd, extract_64) {
mlir::zamalang::CompilerEngine engine;
auto mlirStr = R"XXX(
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
func @main(%t: tensor<10xi64>, %i: index) -> i64{
%c = tensor.extract %t[%i] : tensor<10xi64>
return %c : i64
}
)XXX";
ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints));
const size_t size = 10;
uint64_t t_arg[size]{0xFFFFFFFFFFFFFFFF,
0,
8978,
2587490,
90,
197864,
698735,
72132,
87474,
42};
for (size_t i = 0; i < size; i++) {
auto maybeArgument = engine.buildArgument();
ASSERT_LLVM_ERROR(maybeArgument.takeError());
auto argument = std::move(maybeArgument.get());
// Set the %t argument
ASSERT_LLVM_ERROR(argument->setArg(0, t_arg, size));
// Set the %i argument
ASSERT_LLVM_ERROR(argument->setArg(1, i));
// Invoke the function
ASSERT_LLVM_ERROR(engine.invoke(*argument));
// Get and assert the result
uint64_t res = 0;
ASSERT_LLVM_ERROR(argument->getResult(0, res));
ASSERT_EQ(res, t_arg[i]);
}
)XXX",
"main", "true");
static uint64_t t_arg[] = {0xFFFFFFFFFFFFFFFF,
0,
8978,
2587490,
90,
197864,
698735,
72132,
87474,
42};
for (size_t i = 0; i < ARRAY_SIZE(t_arg); i++)
ASSERT_EXPECTED_VALUE(lambda(t_arg, ARRAY_SIZE(t_arg), i), t_arg[i]);
}
TEST(CompileAndRunTensorStd, extract_32) {
mlir::zamalang::CompilerEngine engine;
auto mlirStr = R"XXX(
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
func @main(%t: tensor<10xi32>, %i: index) -> i32{
%c = tensor.extract %t[%i] : tensor<10xi32>
return %c : i32
}
)XXX";
ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints));
const size_t size = 10;
uint32_t t_arg[size]{0xFFFFFFFF, 0, 8978, 2587490, 90,
197864, 698735, 72132, 87474, 42};
for (size_t i = 0; i < size; i++) {
auto maybeArgument = engine.buildArgument();
ASSERT_LLVM_ERROR(maybeArgument.takeError());
auto argument = std::move(maybeArgument.get());
// Set the %t argument
ASSERT_LLVM_ERROR(argument->setArg(0, t_arg, size));
// Set the %i argument
ASSERT_LLVM_ERROR(argument->setArg(1, i));
// Invoke the function
ASSERT_LLVM_ERROR(engine.invoke(*argument));
// Get and assert the result
uint64_t res = 0;
ASSERT_LLVM_ERROR(argument->getResult(0, res));
ASSERT_EQ(res, t_arg[i]);
)XXX",
"main", "true");
static uint32_t t_arg[] = {0xFFFFFFFF, 0, 8978, 2587490, 90,
197864, 698735, 72132, 87474, 42};
for (size_t i = 0; i < ARRAY_SIZE(t_arg); i++)
ASSERT_EXPECTED_VALUE(lambda(t_arg, ARRAY_SIZE(t_arg), i), t_arg[i]);
}
// Same as `CompileAndRunTensorStd::extract_32` above, but using
// `LambdaArgument` instances
TEST(CompileAndRunTensorStd, extract_32_lambda_argument) {
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
func @main(%t: tensor<10xi32>, %i: index) -> i32{
%c = tensor.extract %t[%i] : tensor<10xi32>
return %c : i32
}
)XXX",
"main", "true");
static std::vector<uint32_t> t_arg{0xFFFFFFFF, 0, 8978, 2587490, 90,
197864, 698735, 72132, 87474, 42};
mlir::zamalang::TensorLambdaArgument<
mlir::zamalang::IntLambdaArgument<uint32_t>>
tla(t_arg);
for (size_t i = 0; i < ARRAY_SIZE(t_arg); i++) {
mlir::zamalang::IntLambdaArgument<size_t> idx(i);
ASSERT_EXPECTED_VALUE(lambda({&tla, &idx}), t_arg[i]);
}
}
TEST(CompileAndRunTensorStd, extract_16) {
mlir::zamalang::CompilerEngine engine;
auto mlirStr = R"XXX(
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
func @main(%t: tensor<10xi16>, %i: index) -> i16{
%c = tensor.extract %t[%i] : tensor<10xi16>
return %c : i16
}
)XXX";
ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints));
const size_t size = 10;
uint16_t t_arg[size]{0xFFFF, 0, 59589, 47826, 16227,
63269, 36435, 52380, 7401, 13313};
for (size_t i = 0; i < size; i++) {
auto maybeArgument = engine.buildArgument();
ASSERT_LLVM_ERROR(maybeArgument.takeError());
auto argument = std::move(maybeArgument.get());
// Set the %t argument
ASSERT_LLVM_ERROR(argument->setArg(0, t_arg, size));
// Set the %i argument
ASSERT_LLVM_ERROR(argument->setArg(1, i));
// Invoke the function
ASSERT_LLVM_ERROR(engine.invoke(*argument));
// Get and assert the result
uint64_t res = 0;
ASSERT_LLVM_ERROR(argument->getResult(0, res));
ASSERT_EQ(res, t_arg[i]);
}
)XXX",
"main", "true");
uint16_t t_arg[] = {0xFFFF, 0, 59589, 47826, 16227,
63269, 36435, 52380, 7401, 13313};
for (size_t i = 0; i < ARRAY_SIZE(t_arg); i++)
ASSERT_EXPECTED_VALUE(lambda(t_arg, ARRAY_SIZE(t_arg), i), t_arg[i]);
}
TEST(CompileAndRunTensorStd, extract_8) {
mlir::zamalang::CompilerEngine engine;
auto mlirStr = R"XXX(
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
func @main(%t: tensor<10xi8>, %i: index) -> i8{
%c = tensor.extract %t[%i] : tensor<10xi8>
return %c : i8
}
)XXX";
ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints));
const size_t size = 10;
uint8_t t_arg[size]{0xFF, 0, 120, 225, 14, 177, 131, 84, 174, 93};
for (size_t i = 0; i < size; i++) {
auto maybeArgument = engine.buildArgument();
ASSERT_LLVM_ERROR(maybeArgument.takeError());
auto argument = std::move(maybeArgument.get());
// Set the %t argument
ASSERT_LLVM_ERROR(argument->setArg(0, t_arg, size));
// Set the %i argument
ASSERT_LLVM_ERROR(argument->setArg(1, i));
// Invoke the function
ASSERT_LLVM_ERROR(engine.invoke(*argument));
// Get and assert the result
uint64_t res = 0;
ASSERT_LLVM_ERROR(argument->getResult(0, res));
ASSERT_EQ(res, t_arg[i]);
}
)XXX",
"main", "true");
static uint8_t t_arg[] = {0xFF, 0, 120, 225, 14, 177, 131, 84, 174, 93};
for (size_t i = 0; i < ARRAY_SIZE(t_arg); i++)
ASSERT_EXPECTED_VALUE(lambda(t_arg, ARRAY_SIZE(t_arg), i), t_arg[i]);
}
TEST(CompileAndRunTensorStd, extract_5) {
mlir::zamalang::CompilerEngine engine;
auto mlirStr = R"XXX(
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
func @main(%t: tensor<10xi5>, %i: index) -> i5{
%c = tensor.extract %t[%i] : tensor<10xi5>
return %c : i5
}
)XXX";
ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints));
const size_t size = 10;
uint8_t t_arg[size]{32, 0, 10, 25, 14, 25, 18, 28, 14, 7};
for (size_t i = 0; i < size; i++) {
auto maybeArgument = engine.buildArgument();
ASSERT_LLVM_ERROR(maybeArgument.takeError());
auto argument = std::move(maybeArgument.get());
// Set the %t argument
ASSERT_LLVM_ERROR(argument->setArg(0, t_arg, size));
// Set the %i argument
ASSERT_LLVM_ERROR(argument->setArg(1, i));
// Invoke the function
ASSERT_LLVM_ERROR(engine.invoke(*argument));
// Get and assert the result
uint64_t res = 0;
ASSERT_LLVM_ERROR(argument->getResult(0, res));
ASSERT_EQ(res, t_arg[i]);
}
)XXX",
"main", "true");
static uint8_t t_arg[] = {32, 0, 10, 25, 14, 25, 18, 28, 14, 7};
for (size_t i = 0; i < ARRAY_SIZE(t_arg); i++)
ASSERT_EXPECTED_VALUE(lambda(t_arg, ARRAY_SIZE(t_arg), i), t_arg[i]);
}
TEST(CompileAndRunTensorStd, extract_1) {
mlir::zamalang::CompilerEngine engine;
auto mlirStr = R"XXX(
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
func @main(%t: tensor<10xi1>, %i: index) -> i1{
%c = tensor.extract %t[%i] : tensor<10xi1>
return %c : i1
}
)XXX";
ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints));
const size_t size = 10;
uint8_t t_arg[size]{0, 0, 1, 0, 1, 1, 0, 1, 1, 0};
for (size_t i = 0; i < size; i++) {
auto maybeArgument = engine.buildArgument();
ASSERT_LLVM_ERROR(maybeArgument.takeError());
auto argument = std::move(maybeArgument.get());
// Set the %t argument
ASSERT_LLVM_ERROR(argument->setArg(0, t_arg, size));
// Set the %i argument
ASSERT_LLVM_ERROR(argument->setArg(1, i));
// Invoke the function
ASSERT_LLVM_ERROR(engine.invoke(*argument));
// Get and assert the result
uint64_t res = 0;
ASSERT_LLVM_ERROR(argument->getResult(0, res));
ASSERT_EQ(res, t_arg[i]);
}
)XXX",
"main", "true");
static uint8_t t_arg[] = {0, 0, 1, 0, 1, 1, 0, 1, 1, 0};
for (size_t i = 0; i < ARRAY_SIZE(t_arg); i++)
ASSERT_EXPECTED_VALUE(lambda(t_arg, ARRAY_SIZE(t_arg), i), t_arg[i]);
}
TEST(CompileAndRunTensorEncrypted, extract_5) {
mlir::zamalang::CompilerEngine engine;
auto mlirStr = R"XXX(
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
func @main(%t: tensor<10x!HLFHE.eint<5>>, %i: index) -> !HLFHE.eint<5>{
%c = tensor.extract %t[%i] : tensor<10x!HLFHE.eint<5>>
return %c : !HLFHE.eint<5>
}
)XXX";
ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints));
const size_t size = 10;
uint8_t t_arg[size]{32, 0, 10, 25, 14, 25, 18, 28, 14, 7};
for (size_t i = 0; i < size; i++) {
auto maybeArgument = engine.buildArgument();
ASSERT_LLVM_ERROR(maybeArgument.takeError());
auto argument = std::move(maybeArgument.get());
// Set the %t argument
ASSERT_LLVM_ERROR(argument->setArg(0, t_arg, size));
// Set the %i argument
ASSERT_LLVM_ERROR(argument->setArg(1, i));
// Invoke the function
ASSERT_LLVM_ERROR(engine.invoke(*argument));
// Get and assert the result
uint64_t res = 0;
ASSERT_LLVM_ERROR(argument->getResult(0, res));
ASSERT_EQ(res, t_arg[i]);
}
)XXX");
static uint8_t t_arg[] = {32, 0, 10, 25, 14, 25, 18, 28, 14, 7};
for (size_t i = 0; i < ARRAY_SIZE(t_arg); i++)
ASSERT_EXPECTED_VALUE(lambda(t_arg, ARRAY_SIZE(t_arg), i), t_arg[i]);
}
TEST(CompileAndRunTensorEncrypted, extract_twice_and_add_5) {
mlir::zamalang::CompilerEngine engine;
auto mlirStr = R"XXX(
func @main(%t: tensor<10x!HLFHE.eint<5>>, %i: index, %j: index) -> !HLFHE.eint<5>{
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
func @main(%t: tensor<10x!HLFHE.eint<5>>, %i: index, %j: index) ->
!HLFHE.eint<5>{
%ti = tensor.extract %t[%i] : tensor<10x!HLFHE.eint<5>>
%tj = tensor.extract %t[%j] : tensor<10x!HLFHE.eint<5>>
%c = "HLFHE.add_eint"(%ti, %tj) : (!HLFHE.eint<5>, !HLFHE.eint<5>) -> !HLFHE.eint<5>
return %c : !HLFHE.eint<5>
%c = "HLFHE.add_eint"(%ti, %tj) : (!HLFHE.eint<5>, !HLFHE.eint<5>) ->
!HLFHE.eint<5> return %c : !HLFHE.eint<5>
}
)XXX";
ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints));
const size_t size = 10;
uint8_t t_arg[size]{32, 0, 10, 25, 14, 25, 18, 28, 14, 7};
for (size_t i = 0; i < size; i++) {
for (size_t j = 0; j < size; j++) {
auto maybeArgument = engine.buildArgument();
ASSERT_LLVM_ERROR(maybeArgument.takeError());
auto argument = std::move(maybeArgument.get());
// Set the %t argument
ASSERT_LLVM_ERROR(argument->setArg(0, t_arg, size));
// Set the %i argument
ASSERT_LLVM_ERROR(argument->setArg(1, i));
// Set the %j argument
ASSERT_LLVM_ERROR(argument->setArg(2, j));
// Invoke the function
ASSERT_LLVM_ERROR(engine.invoke(*argument));
// Get and assert the result
uint64_t res = 0;
ASSERT_LLVM_ERROR(argument->getResult(0, res));
ASSERT_EQ(res, t_arg[i] + t_arg[j]);
}
}
)XXX");
static uint8_t t_arg[] = {3, 0, 7, 12, 14, 6, 5, 4, 1, 2};
for (size_t i = 0; i < ARRAY_SIZE(t_arg); i++)
for (size_t j = 0; j < ARRAY_SIZE(t_arg); j++)
ASSERT_EXPECTED_VALUE(lambda(t_arg, ARRAY_SIZE(t_arg), i, j),
t_arg[i] + t_arg[j]);
}
TEST(CompileAndRunTensorEncrypted, dim_5) {
mlir::zamalang::CompilerEngine engine;
auto mlirStr = R"XXX(
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
func @main(%t: tensor<10x!HLFHE.eint<5>>) -> index{
%c0 = constant 0 : index
%c = tensor.dim %t, %c0 : tensor<10x!HLFHE.eint<5>>
return %c : index
}
)XXX";
ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints));
const size_t size = 10;
uint8_t t_arg[size]{32, 0, 10, 25, 14, 25, 18, 28, 14, 7};
auto maybeArgument = engine.buildArgument();
ASSERT_LLVM_ERROR(maybeArgument.takeError());
auto argument = std::move(maybeArgument.get());
// Set the %t argument
ASSERT_LLVM_ERROR(argument->setArg(0, t_arg, size));
// Invoke the function
ASSERT_LLVM_ERROR(engine.invoke(*argument));
// Get and assert the result
uint64_t res = 0;
ASSERT_LLVM_ERROR(argument->getResult(0, res));
ASSERT_EQ(res, size);
)XXX");
static uint8_t t_arg[] = {32, 0, 10, 25, 14, 25, 18, 28, 14, 7};
ASSERT_EXPECTED_VALUE(lambda(t_arg, ARRAY_SIZE(t_arg)), ARRAY_SIZE(t_arg));
}
TEST(CompileAndRunTensorEncrypted, from_elements_5) {
mlir::zamalang::CompilerEngine engine;
auto mlirStr = R"XXX(
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
func @main(%0: !HLFHE.eint<5>) -> tensor<1x!HLFHE.eint<5>> {
%t = tensor.from_elements %0 : tensor<1x!HLFHE.eint<5>>
return %t: tensor<1x!HLFHE.eint<5>>
}
)XXX";
ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints));
auto maybeArgument = engine.buildArgument();
ASSERT_LLVM_ERROR(maybeArgument.takeError());
auto argument = std::move(maybeArgument.get());
// Set the %t argument
ASSERT_LLVM_ERROR(argument->setArg(0, 10));
// Invoke the function
ASSERT_LLVM_ERROR(engine.invoke(*argument));
// Get and assert the result
size_t size_res = 1;
uint64_t t_res[size_res];
ASSERT_LLVM_ERROR(argument->getResult(0, t_res, size_res));
ASSERT_EQ(t_res[0], 10);
)XXX");
llvm::Expected<std::vector<uint64_t>> res =
lambda.operator()<std::vector<uint64_t>>(10_u64);
ASSERT_EXPECTED_SUCCESS(res);
ASSERT_EQ(res->size(), (size_t)1);
ASSERT_EQ(res->at(0), 10_u64);
}
TEST(CompileAndRunTensorEncrypted, in_out_tensor_with_op_5) {
mlir::zamalang::CompilerEngine engine;
auto mlirStr = R"XXX(
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
func @main(%in: tensor<2x!HLFHE.eint<5>>) -> tensor<3x!HLFHE.eint<5>> {
%c_0 = constant 0 : index
%c_1 = constant 1 : index
%a = tensor.extract %in[%c_0] : tensor<2x!HLFHE.eint<5>>
%b = tensor.extract %in[%c_1] : tensor<2x!HLFHE.eint<5>>
%aplusa = "HLFHE.add_eint"(%a, %a): (!HLFHE.eint<5>, !HLFHE.eint<5>) -> (!HLFHE.eint<5>)
%aplusb = "HLFHE.add_eint"(%a, %b): (!HLFHE.eint<5>, !HLFHE.eint<5>) -> (!HLFHE.eint<5>)
%bplusb = "HLFHE.add_eint"(%b, %b): (!HLFHE.eint<5>, !HLFHE.eint<5>) -> (!HLFHE.eint<5>)
%out = tensor.from_elements %aplusa, %aplusb, %bplusb : tensor<3x!HLFHE.eint<5>>
%aplusa = "HLFHE.add_eint"(%a, %a): (!HLFHE.eint<5>, !HLFHE.eint<5>) ->
(!HLFHE.eint<5>) %aplusb = "HLFHE.add_eint"(%a, %b): (!HLFHE.eint<5>,
!HLFHE.eint<5>) -> (!HLFHE.eint<5>) %bplusb = "HLFHE.add_eint"(%b, %b):
(!HLFHE.eint<5>, !HLFHE.eint<5>) -> (!HLFHE.eint<5>) %out =
tensor.from_elements %aplusa, %aplusb, %bplusb : tensor<3x!HLFHE.eint<5>>
return %out: tensor<3x!HLFHE.eint<5>>
}
)XXX";
ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints));
auto maybeArgument = engine.buildArgument();
ASSERT_LLVM_ERROR(maybeArgument.takeError());
auto argument = std::move(maybeArgument.get());
// Set the argument
const size_t in_size = 2;
uint8_t in[in_size] = {2, 16};
ASSERT_LLVM_ERROR(argument->setArg(0, in, in_size));
// Invoke the function
ASSERT_LLVM_ERROR(engine.invoke(*argument));
// Get and assert the result
const size_t size_res = 3;
uint64_t t_res[size_res];
ASSERT_LLVM_ERROR(argument->getResult(0, t_res, size_res));
ASSERT_EQ(t_res[0], in[0] + in[0]);
ASSERT_EQ(t_res[1], in[0] + in[1]);
ASSERT_EQ(t_res[2], in[1] + in[1]);
)XXX");
static uint8_t in[] = {2, 16};
llvm::Expected<std::vector<uint64_t>> res =
lambda.operator()<std::vector<uint64_t>>(in, ARRAY_SIZE(in));
ASSERT_EXPECTED_SUCCESS(res);
ASSERT_EQ(res->size(), (size_t)3);
ASSERT_EQ(res->at(0), (uint64_t)(in[0] + in[0]));
ASSERT_EQ(res->at(1), (uint64_t)(in[0] + in[1]));
ASSERT_EQ(res->at(2), (uint64_t)(in[1] + in[1]));
}
TEST(CompileAndRunTensorEncrypted, linalg_generic) {
mlir::zamalang::CompilerEngine engine;
auto mlirStr = R"XXX(
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
#map0 = affine_map<(d0) -> (d0)>
#map1 = affine_map<(d0) -> (0)>
func @main(%arg0: tensor<2x!HLFHE.eint<7>>, %arg1: tensor<2xi8>, %acc: !HLFHE.eint<7>) -> !HLFHE.eint<7> {
func @main(%arg0: tensor<2x!HLFHE.eint<7>>, %arg1: tensor<2xi8>, %acc:
!HLFHE.eint<7>) -> !HLFHE.eint<7> {
%tacc = tensor.from_elements %acc : tensor<1x!HLFHE.eint<7>>
%2 = linalg.generic {indexing_maps = [#map0, #map0, #map1], iterator_types = ["reduction"]} ins(%arg0, %arg1 : tensor<2x!HLFHE.eint<7>>, tensor<2xi8>) outs(%tacc : tensor<1x!HLFHE.eint<7>>) {
^bb0(%arg2: !HLFHE.eint<7>, %arg3: i8, %arg4: !HLFHE.eint<7>): // no predecessors
%4 = "HLFHE.mul_eint_int"(%arg2, %arg3) : (!HLFHE.eint<7>, i8) -> !HLFHE.eint<7>
%5 = "HLFHE.add_eint"(%4, %arg4) : (!HLFHE.eint<7>, !HLFHE.eint<7>) -> !HLFHE.eint<7>
linalg.yield %5 : !HLFHE.eint<7>
%2 = linalg.generic {indexing_maps = [#map0, #map0, #map1], iterator_types
= ["reduction"]} ins(%arg0, %arg1 : tensor<2x!HLFHE.eint<7>>, tensor<2xi8>)
outs(%tacc : tensor<1x!HLFHE.eint<7>>) { ^bb0(%arg2: !HLFHE.eint<7>, %arg3:
i8, %arg4: !HLFHE.eint<7>): // no predecessors
%4 = "HLFHE.mul_eint_int"(%arg2, %arg3) : (!HLFHE.eint<7>, i8) ->
!HLFHE.eint<7> %5 = "HLFHE.add_eint"(%4, %arg4) : (!HLFHE.eint<7>,
!HLFHE.eint<7>) -> !HLFHE.eint<7> linalg.yield %5 : !HLFHE.eint<7>
} -> tensor<1x!HLFHE.eint<7>>
%c0 = constant 0 : index
%ret = tensor.extract %2[%c0] : tensor<1x!HLFHE.eint<7>>
return %ret : !HLFHE.eint<7>
}
)XXX";
ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints));
auto maybeArgument = engine.buildArgument();
ASSERT_LLVM_ERROR(maybeArgument.takeError());
auto argument = std::move(maybeArgument.get());
// Set arg0, arg1, acc
const size_t in_size = 2;
uint8_t arg0[in_size] = {2, 8};
ASSERT_LLVM_ERROR(argument->setArg(0, arg0, in_size));
uint8_t arg1[in_size] = {6, 8};
ASSERT_LLVM_ERROR(argument->setArg(1, arg1, in_size));
ASSERT_LLVM_ERROR(argument->setArg(2, 0));
// Invoke the function
ASSERT_LLVM_ERROR(engine.invoke(*argument));
// Get and assert the result
uint64_t res;
ASSERT_LLVM_ERROR(argument->getResult(0, res));
ASSERT_EQ(res, 76);
)XXX",
"main", "true");
static uint8_t arg0[] = {2, 8};
static uint8_t arg1[] = {6, 8};
llvm::Expected<uint64_t> res =
lambda(arg0, ARRAY_SIZE(arg0), arg1, ARRAY_SIZE(arg1), 0_u64);
ASSERT_EXPECTED_VALUE(res, 76);
}
TEST(CompileAndRunTensorEncrypted, dot_eint_int_7) {
mlir::zamalang::CompilerEngine engine;
auto mlirStr = R"XXX(
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
func @main(%arg0: tensor<4x!HLFHE.eint<7>>,
%arg1: tensor<4xi8>) -> !HLFHE.eint<7>
{
@@ -395,77 +419,70 @@ func @main(%arg0: tensor<4x!HLFHE.eint<7>>,
(tensor<4x!HLFHE.eint<7>>, tensor<4xi8>) -> !HLFHE.eint<7>
return %ret : !HLFHE.eint<7>
}
)XXX";
ASSERT_LLVM_ERROR(engine.compile(mlirStr));
auto maybeArgument = engine.buildArgument();
ASSERT_LLVM_ERROR(maybeArgument.takeError());
auto argument = std::move(maybeArgument.get());
// Set arg0, arg1, acc
const size_t in_size = 4;
uint8_t arg0[in_size] = {0, 1, 2, 3};
ASSERT_LLVM_ERROR(argument->setArg(0, arg0, in_size));
uint8_t arg1[in_size] = {0, 1, 2, 3};
ASSERT_LLVM_ERROR(argument->setArg(1, arg1, in_size));
// Invoke the function
ASSERT_LLVM_ERROR(engine.invoke(*argument));
// Get and assert the result
uint64_t res;
ASSERT_LLVM_ERROR(argument->getResult(0, res));
ASSERT_EQ(res, 14);
)XXX");
static uint8_t arg0[] = {0, 1, 2, 3};
static uint8_t arg1[] = {0, 1, 2, 3};
llvm::Expected<uint64_t> res =
lambda(arg0, ARRAY_SIZE(arg0), arg1, ARRAY_SIZE(arg1));
ASSERT_EXPECTED_VALUE(res, 14);
}
class CompileAndRunWithPrecision : public ::testing::TestWithParam<int> {
protected:
mlir::zamalang::CompilerEngine engine;
void compile(std::string mlirStr) { ASSERT_FALSE(engine.compile(mlirStr)); }
void run(std::vector<uint64_t> args, uint64_t expected) {
auto maybeResult = engine.run(args);
ASSERT_TRUE((bool)maybeResult);
uint64_t result = maybeResult.get();
if (result == expected) {
ASSERT_TRUE(true);
} else {
// TODO: Better way to test the probability of exactness
llvm::errs() << "one fail retry\n";
maybeResult = engine.run(args);
ASSERT_TRUE((bool)maybeResult);
result = maybeResult.get();
ASSERT_EQ(result, expected);
}
}
};
class CompileAndRunWithPrecision : public ::testing::TestWithParam<int> {};
TEST_P(CompileAndRunWithPrecision, identity_func) {
int precision = GetParam();
uint64_t precision = GetParam();
std::ostringstream mlirProgram;
auto sizeOfTLU = 1 << precision;
mlirProgram << "func @main(%arg0: !HLFHE.eint<" << precision
<< ">) -> !HLFHE.eint<" << precision << "> { \n";
mlirProgram << " %tlu = std.constant dense<[0";
for (auto i = 1; i < sizeOfTLU; i++) {
mlirProgram << ", " << i;
}
mlirProgram << "]> : tensor<" << sizeOfTLU << "xi64>\n";
mlirProgram << " %1 = \"HLFHE.apply_lookup_table\"(%arg0, %tlu): "
"(!HLFHE.eint<"
<< precision << ">, tensor<" << sizeOfTLU
<< "xi64>) -> (!HLFHE.eint<" << precision << ">)\n ";
mlirProgram << "return %1: !HLFHE.eint<" << precision << ">\n";
uint64_t sizeOfTLU = 1 << precision;
mlirProgram << "}\n";
llvm::errs() << mlirProgram.str();
compile(mlirProgram.str());
for (auto i = 0; i < sizeOfTLU; i++) {
run({(uint64_t)i}, i);
mlirProgram << "func @main(%arg0: !HLFHE.eint<" << precision
<< ">) -> !HLFHE.eint<" << precision << "> { \n"
<< " %tlu = std.constant dense<[0";
for (uint64_t i = 1; i < sizeOfTLU; i++)
mlirProgram << ", " << i;
mlirProgram << "]> : tensor<" << sizeOfTLU << "xi64>\n"
<< " %1 = \"HLFHE.apply_lookup_table\"(%arg0, %tlu): "
<< "(!HLFHE.eint<" << precision << ">, tensor<" << sizeOfTLU
<< "xi64>) -> (!HLFHE.eint<" << precision << ">)\n "
<< "return %1: !HLFHE.eint<" << precision << ">\n"
<< "}\n";
mlir::zamalang::JitCompilerEngine::Lambda lambda =
checkedJit(mlirProgram.str());
if (precision == 7) {
// Test fails with a probability of 5% for a precision of 7. The
// probability of the test failing 5 times in a row is .05^5,
// which is less than 1:10,000 and comparable to the probability
// of failure for the other values.
static const int max_tries = 3;
for (uint64_t i = 0; i < sizeOfTLU; i++) {
for (int retry = 0; retry <= max_tries; retry++) {
if (retry == max_tries)
GTEST_FATAL_FAILURE_("Maximum number of tries exceeded");
llvm::Expected<uint64_t> val = lambda(i);
ASSERT_EXPECTED_SUCCESS(val);
if (*val == i)
break;
}
}
} else {
for (uint64_t i = 0; i < sizeOfTLU; i++)
ASSERT_EXPECTED_VALUE(lambda(i), i);
}
}
INSTANTIATE_TEST_CASE_P(TestHLFHEApplyLookupTable, CompileAndRunWithPrecision,
::testing::Values(1, 2, 3, 4, 5, 6, 7));
INSTANTIATE_TEST_SUITE_P(TestHLFHEApplyLookupTable, CompileAndRunWithPrecision,
::testing::Values(1, 2, 3, 4, 5, 6, 7));
TEST(TestHLFHEApplyLookupTable, multiple_precision) {
mlir::zamalang::CompilerEngine engine;
auto mlirStr = R"XXX(
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
func @main(%arg0: !HLFHE.eint<6>, %arg1: !HLFHE.eint<3>) -> !HLFHE.eint<6> {
%tlu_7 = std.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63]> : tensor<64xi64>
%tlu_3 = std.constant dense<[0, 1, 2, 3, 4, 5, 6, 7]> : tensor<8xi64>
@@ -474,45 +491,22 @@ func @main(%arg0: !HLFHE.eint<6>, %arg1: !HLFHE.eint<3>) -> !HLFHE.eint<6> {
%a_plus_b = "HLFHE.add_eint"(%a, %b): (!HLFHE.eint<6>, !HLFHE.eint<6>) -> (!HLFHE.eint<6>)
return %a_plus_b: !HLFHE.eint<6>
}
)XXX";
ASSERT_FALSE(engine.compile(mlirStr));
uint64_t arg0 = 23;
uint64_t arg1 = 7;
uint64_t expected = 30;
auto maybeResult = engine.run({arg0, arg1});
ASSERT_TRUE((bool)maybeResult);
uint64_t result = maybeResult.get();
ASSERT_EQ(result, expected);
)XXX");
ASSERT_EXPECTED_VALUE(lambda(23_u64, 7_u64), 30);
}
TEST(CompileAndRunTLU, random_func) {
mlir::zamalang::CompilerEngine engine;
auto mlirStr = R"XXX(
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
func @main(%arg0: !HLFHE.eint<6>) -> !HLFHE.eint<6> {
%tlu = std.constant dense<[16, 91, 16, 83, 80, 74, 21, 96, 1, 63, 49, 122, 76, 89, 74, 55, 109, 110, 103, 54, 105, 14, 66, 47, 52, 89, 7, 10, 73, 44, 119, 92, 25, 104, 123, 100, 108, 86, 29, 121, 118, 52, 107, 48, 34, 37, 13, 122, 107, 48, 74, 59, 96, 36, 50, 55, 120, 72, 27, 45, 12, 5, 96, 12]> : tensor<64xi64>
%1 = "HLFHE.apply_lookup_table"(%arg0, %tlu): (!HLFHE.eint<6>, tensor<64xi64>) -> (!HLFHE.eint<6>)
return %1: !HLFHE.eint<6>
}
)XXX";
ASSERT_FALSE(engine.compile(mlirStr));
// first value
auto maybeResult = engine.run({5});
ASSERT_TRUE((bool)maybeResult);
uint64_t result = maybeResult.get();
ASSERT_EQ(result, 74);
// second value
maybeResult = engine.run({62});
ASSERT_TRUE((bool)maybeResult);
result = maybeResult.get();
ASSERT_EQ(result, 96);
// edge value low
maybeResult = engine.run({0});
ASSERT_TRUE((bool)maybeResult);
result = maybeResult.get();
ASSERT_EQ(result, 16);
// edge value high
maybeResult = engine.run({63});
ASSERT_TRUE((bool)maybeResult);
result = maybeResult.get();
ASSERT_EQ(result, 12);
)XXX");
ASSERT_EXPECTED_VALUE(lambda(5_u64), 74);
ASSERT_EXPECTED_VALUE(lambda(62_u64), 96);
ASSERT_EXPECTED_VALUE(lambda(0_u64), 16);
ASSERT_EXPECTED_VALUE(lambda(63_u64), 12);
}