enhance(compiler/runtime): Add runtime tools to handle tensor inputs and outputs

This commit is contained in:
Quentin Bourgerie
2021-08-24 15:02:45 +02:00
parent dba76a1e1b
commit af0789f128
11 changed files with 701 additions and 73 deletions

View File

@@ -1,5 +1,5 @@
#ifndef ZAMALANG_CONVERSION_GLOBALFHECONTEXT_PATTERNS_H_
#define ZAMALANG_CONVERSION_GLOBALFHECONTEXT_PATTERNS_H_
#ifndef ZAMALANG_CONVERSION_GLOBALFHECONTEXT_H_
#define ZAMALANG_CONVERSION_GLOBALFHECONTEXT_H_
#include <cstddef>
namespace mlir {

View File

@@ -56,7 +56,10 @@ struct EncryptionGate {
};
struct CircuitGateShape {
uint64_t size;
// Width of the scalar value
size_t width;
// Size of the buffer
size_t size;
};
struct CircuitGate {
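For illustration, this is how the new shape fields end up being filled by gateFromMLIRType further down in this commit: a size of 0 marks a scalar gate, a non-zero size marks a rank-1 tensor gate. Sketch only, assuming the header defining CircuitGateShape is included.

// Sketch only: values as gateFromMLIRType produces them below.
CircuitGateShape scalarEint7{.width = 7, .size = 0};  // e.g. !HLFHE.eint<7>
CircuitGateShape tensorI8{.width = 8, .size = 10};    // e.g. tensor<10xi8>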

View File

@@ -15,6 +15,9 @@
namespace mlir {
namespace zamalang {
/// CompilerEngine provides the tools to implement the compilation flow and
/// to manage the compilation state.
class CompilerEngine {
public:
CompilerEngine() {
@@ -26,10 +29,16 @@ public:
delete context;
}
// Compile an MLIR input
llvm::Expected<mlir::LogicalResult> compileFHE(std::string mlir_input);
// Compile an MLIR program from its textual representation.
llvm::Error compile(std::string mlirStr);
// Run the compiled module
// Build the JIT lambda argument.
llvm::Expected<std::unique_ptr<JITLambda::Argument>> buildArgument();
// Call the compiled function with an argument object.
llvm::Error invoke(JITLambda::Argument &arg);
// Call the compiled function with a list of integer arguments.
llvm::Expected<uint64_t> run(std::vector<uint64_t> args);
// Get a printable representation of the compiled module
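For context, a minimal sketch of how the new entry points compose end to end; it mirrors the extract_64 test added at the end of this commit, and the wrapper function name is illustrative only.

#include "zamalang/Support/CompilerEngine.h"

// Sketch only: compile an MLIR string, build an argument object, set a tensor
// input and a scalar input, invoke, and read back a scalar result.
llvm::Expected<uint64_t> compileAndExtract(mlir::zamalang::CompilerEngine &engine) {
  auto mlirStr = R"XXX(
func @main(%t: tensor<10xi64>, %i: index) -> i64 {
  %c = tensor.extract %t[%i] : tensor<10xi64>
  return %c : i64
}
)XXX";
  if (auto err = engine.compile(mlirStr))
    return std::move(err);
  auto maybeArgument = engine.buildArgument();
  if (auto err = maybeArgument.takeError())
    return std::move(err);
  auto argument = std::move(maybeArgument.get());
  uint64_t t[10] = {10, 11, 12, 13, 14, 15, 16, 17, 18, 19};
  if (auto err = argument->setArg(0, t, 10)) // tensor argument %t
    return std::move(err);
  if (auto err = argument->setArg(1, 1))     // scalar argument %i
    return std::move(err);
  if (auto err = engine.invoke(*argument))
    return std::move(err);
  uint64_t res = 0;
  if (auto err = argument->getResult(0, res))
    return std::move(err);
  return res; // expected: t[1] == 11
}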

View File

@@ -51,17 +51,54 @@ public:
// and decryption operations.
static llvm::Expected<std::unique_ptr<Argument>> create(KeySet &keySet);
// Set the argument at the given pos as a uint64_t.
// Set a scalar argument at the given pos as a uint64_t.
llvm::Error setArg(size_t pos, uint64_t arg);
// Set an argument at the given pos as a tensor of uint64.
llvm::Error setArg(size_t pos, uint64_t *data, size_t size) {
return setArg(pos, 64, (void *)data, size);
}
// Set an argument at the given pos as a tensor of uint32.
llvm::Error setArg(size_t pos, uint32_t *data, size_t size) {
return setArg(pos, 32, (void *)data, size);
}
// Set an argument at the given pos as a tensor of uint16.
llvm::Error setArg(size_t pos, uint16_t *data, size_t size) {
return setArg(pos, 16, (void *)data, size);
}
// Set an argument at the given pos as a tensor of uint8.
llvm::Error setArg(size_t pos, uint8_t *data, size_t size) {
return setArg(pos, 8, (void *)data, size);
}
// Get the result at the given pos as a uint64_t.
llvm::Error getResult(size_t pos, uint64_t &res);
// Fill the result buffer at the given pos with a tensor of uint64.
llvm::Error getResult(size_t pos, uint64_t *res, size_t size);
private:
llvm::Error setArg(size_t pos, size_t width, void *data, size_t size);
friend JITLambda;
// Store pointers to the input and output values
std::vector<void *> rawArg;
// Store the input values
std::vector<void *> inputs;
std::vector<void *> results;
// Store the output values
std::vector<void *> outputs;
// Store the input gate descriptions and the offset of each argument.
std::vector<std::tuple<CircuitGate, size_t /*offset*/>> inputGates;
// Store the output gate descriptions and the offset of each argument.
std::vector<std::tuple<CircuitGate, size_t /*offset*/>> outputGates;
// Store the allocated LWE ciphertexts (to free them on destruction)
std::vector<LweCiphertext_u64 *> allocatedCiphertexts;
// Store buffers of ciphertexts
std::vector<LweCiphertext_u64 **> ciphertextBuffers;
KeySet &keySet;
};
JITLambda(mlir::LLVM::LLVMFunctionType type, llvm::StringRef name)
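As a reading aid for the overloads above, here is an illustrative helper (not part of this commit) that encodes the width rules enforced by the private setArg in JITLambda.cpp: e.g. an i5 or !HLFHE.eint<5> tensor goes through the uint8_t overload, and encrypted gates currently accept only the uint8_t overload.

#include <cstddef>

// Illustrative only: which buffer element width the setArg width checks in
// JITLambda.cpp accept for a gate of `gateWidth` bits (0 means unsupported).
constexpr size_t expectedBufferWidth(size_t gateWidth) {
  if (gateWidth <= 8)  return 8;   // uint8_t overload
  if (gateWidth <= 16) return 16;  // uint16_t overload
  if (gateWidth <= 32) return 32;  // uint32_t overload
  if (gateWidth <= 64) return 64;  // uint64_t overload
  return 0;                        // width > 64 is rejected
}

static_assert(expectedBufferWidth(5) == 8, "i5 / eint<5> tensors use uint8_t");
static_assert(expectedBufferWidth(64) == 64, "i64 tensors use uint64_t");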

View File

@@ -37,6 +37,9 @@ public:
size_t numInputs() { return inputs.size(); }
size_t numOutputs() { return outputs.size(); }
CircuitGate inputGate(size_t pos) { return std::get<0>(inputs[pos]); }
CircuitGate outputGate(size_t pos) { return std::get<0>(outputs[pos]); }
protected:
llvm::Error generateSecretKey(LweSecretKeyID id, LweSecretKeyParam param,
SecretRandomGenerator *generator);

View File

@@ -14,12 +14,25 @@ llvm::Expected<CircuitGate> gateFromMLIRType(std::string secretKeyID,
Precision precision,
mlir::Type type) {
if (type.isIntOrIndex()) {
// TODO - The index type depends on the target architecture; for now we
// assume a 64-bit target, but we should use the word size of the target
// system instead.
size_t width = 64;
if (!type.isIndex()) {
width = type.getIntOrFloatBitWidth();
}
return CircuitGate{
.encryption = llvm::None,
.shape = {.size = 0},
.shape =
{
.width = width,
.size = 0,
},
};
}
if (type.isa<mlir::zamalang::LowLFHE::LweCiphertextType>()) {
// TODO - Get the width from the LweCiphertextType instead of the global
// precision (possible after merging lowlfhe-ciphertext-parameter)
return CircuitGate{
.encryption = llvm::Optional<EncryptionGate>({
.secretKeyID = secretKeyID,
@@ -27,17 +40,17 @@ llvm::Expected<CircuitGate> gateFromMLIRType(std::string secretKeyID,
.variance = 0.,
.encoding = {.precision = precision},
}),
.shape = {.size = 0},
.shape = {.width = precision, .size = 0},
};
}
auto memref = type.dyn_cast_or_null<mlir::MemRefType>();
if (memref != nullptr) {
auto tensor = type.dyn_cast_or_null<mlir::RankedTensorType>();
if (tensor != nullptr) {
auto gate =
gateFromMLIRType(secretKeyID, precision, memref.getElementType());
gateFromMLIRType(secretKeyID, precision, tensor.getElementType());
if (auto err = gate.takeError()) {
return std::move(err);
}
gate->shape.size = memref.getDimSize(0);
gate->shape.size = tensor.getDimSize(0);
return gate;
}
return llvm::make_error<llvm::StringError>(

View File

@@ -23,9 +23,8 @@ std::string CompilerEngine::getCompiledModule() {
return os.str();
}
llvm::Expected<mlir::LogicalResult>
CompilerEngine::compileFHE(std::string mlir_input) {
module_ref = mlir::parseSourceString(mlir_input, context);
llvm::Error CompilerEngine::compile(std::string mlirStr) {
module_ref = mlir::parseSourceString(mlirStr, context);
if (!module_ref) {
return llvm::make_error<llvm::StringError>("mlir parsing failed",
llvm::inconvertibleErrorCode());
@@ -60,29 +59,44 @@ CompilerEngine::compileFHE(std::string mlir_input) {
return llvm::make_error<llvm::StringError>(
"failed to lower to LLVM dialect", llvm::inconvertibleErrorCode());
}
return mlir::success();
return llvm::Error::success();
}
llvm::Expected<uint64_t> CompilerEngine::run(std::vector<uint64_t> args) {
llvm::Expected<std::unique_ptr<JITLambda::Argument>>
CompilerEngine::buildArgument() {
if (keySet.get() == nullptr) {
return llvm::make_error<llvm::StringError>(
"CompilerEngine::buildArgument: invalid engine state, the keySet has "
"not be generated",
llvm::inconvertibleErrorCode());
}
return JITLambda::Argument::create(*keySet);
}
llvm::Error CompilerEngine::invoke(JITLambda::Argument &arg) {
// Create the JIT lambda
auto defaultOptPipeline = mlir::makeOptimizingTransformer(3, 0, nullptr);
auto module = module_ref.get();
auto maybeLambda =
mlir::zamalang::JITLambda::create("main", module, defaultOptPipeline);
if (!maybeLambda) {
return llvm::make_error<llvm::StringError>("couldn't create lambda",
llvm::inconvertibleErrorCode());
if (auto err = maybeLambda.takeError()) {
return std::move(err);
}
auto lambda = std::move(maybeLambda.get());
// Invoke the lambda
if (auto err = maybeLambda.get()->invoke(arg)) {
return std::move(err);
}
return llvm::Error::success();
}
// Create the arguments of the JIT lambda
auto maybeArguments = mlir::zamalang::JITLambda::Argument::create(*keySet);
if (auto err = maybeArguments.takeError()) {
return llvm::make_error<llvm::StringError>("cannot create lambda args",
llvm::inconvertibleErrorCode());
llvm::Expected<uint64_t> CompilerEngine::run(std::vector<uint64_t> args) {
// Build the argument of the JIT lambda.
auto maybeArgument = buildArgument();
if (auto err = maybeArgument.takeError()) {
return std::move(err);
}
// Set the arguments
auto arguments = std::move(maybeArguments.get());
// Set the integer arguments
auto arguments = std::move(maybeArgument.get());
for (auto i = 0; i < args.size(); i++) {
if (auto err = arguments->setArg(i, args[i])) {
return llvm::make_error<llvm::StringError>(
@@ -90,14 +104,12 @@ llvm::Expected<uint64_t> CompilerEngine::run(std::vector<uint64_t> args) {
}
}
// Invoke the lambda
if (lambda->invoke(*arguments)) {
return llvm::make_error<llvm::StringError>("failed execution",
llvm::inconvertibleErrorCode());
if (auto err = invoke(*arguments)) {
return std::move(err);
}
uint64_t res = 0;
if (auto err = arguments->getResult(0, res)) {
return llvm::make_error<llvm::StringError>("cannot get result",
llvm::inconvertibleErrorCode());
return std::move(err);
}
return res;
}

View File

@@ -141,10 +141,16 @@ JITLambda::create(llvm::StringRef name, mlir::ModuleOp &module,
}
llvm::Error JITLambda::invokeRaw(llvm::MutableArrayRef<void *> args) {
if (this->type.getNumParams() != args.size() - 1) {
return llvm::make_error<llvm::StringError>(
"invokeRaw: wrong number of argument", llvm::inconvertibleErrorCode());
}
size_t nbReturn = 0;
// TODO - This check breaks with memrefs, as we then have 5 return args.
// if (!this->type.getReturnType().isa<mlir::LLVM::LLVMVoidType>()) {
// nbReturn = 1;
// }
// if (this->type.getNumParams() != args.size() - nbReturn) {
// return llvm::make_error<llvm::StringError>(
// "invokeRaw: wrong number of argument",
// llvm::inconvertibleErrorCode());
// }
if (llvm::find(args, nullptr) != args.end()) {
return llvm::make_error<llvm::StringError>(
"invoke: some arguments are null", llvm::inconvertibleErrorCode());
@@ -157,24 +163,58 @@ llvm::Error JITLambda::invoke(Argument &args) {
}
JITLambda::Argument::Argument(KeySet &keySet) : keySet(keySet) {
inputs = std::vector<void *>(keySet.numInputs());
results = std::vector<void *>(keySet.numOutputs());
// Setting the inputs
{
auto numInputs = 0;
for (size_t i = 0; i < keySet.numInputs(); i++) {
auto offset = numInputs;
auto gate = keySet.inputGate(i);
inputGates.push_back({gate, offset});
if (keySet.inputGate(i).shape.size == 0) {
// scalar gate
numInputs = numInputs + 1;
continue;
}
// memref gate: the standard calling convention lowers it to 5 arguments
numInputs = numInputs + 5;
}
inputs = std::vector<void *>(numInputs);
}
// Setting the outputs
{
auto numOutputs = 0;
for (size_t i = 0; i < keySet.numOutputs(); i++) {
auto offset = numOutputs;
auto gate = keySet.outputGate(i);
outputGates.push_back({gate, offset});
if (gate.shape.size == 0) {
// scalar gate
numOutputs = numOutputs + 1;
continue;
}
// memref gate: the standard calling convention lowers it to 5 arguments
numOutputs = numOutputs + 5;
}
outputs = std::vector<void *>(numOutputs);
}
// The raw argument contains pointers to inputs and pointers to store the
// results
rawArg =
std::vector<void *>(keySet.numInputs() + keySet.numOutputs(), nullptr);
// Set the results pointer on the rawArg
for (auto i = keySet.numInputs(); i < rawArg.size(); i++) {
rawArg[i] = &results[i - keySet.numInputs()];
rawArg = std::vector<void *>(inputs.size() + outputs.size(), nullptr);
// Point the output entries of rawArg at the outputs vector
for (auto i = inputs.size(); i < rawArg.size(); i++) {
rawArg[i] = &outputs[i - inputs.size()];
}
}
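The "+5" above comes from the standard MLIR calling convention for a rank-1 memref, which lowers one tensor to five scalar entries. A sketch of what each entry holds (the struct is illustrative, not a type used by this commit; the field names match the comments in setArg below):

#include <cstdint>

// Sketch only: the five entries a rank-1 memref contributes to inputs/outputs.
struct Rank1MemRefDescriptor {
  void *allocated;  // pointer used for deallocation
  void *aligned;    // pointer to the actual element data
  intptr_t offset;  // element offset into the data
  intptr_t size;    // number of elements of dimension 0
  intptr_t stride;  // stride of dimension 0, in elements
};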
JITLambda::Argument::~Argument() {
int err;
for (auto i = 0; i < keySet.numInputs(); i++) {
if (keySet.isInputEncrypted(i)) {
free_lwe_ciphertext_u64(&err, (LweCiphertext_u64 *)(inputs[i]));
}
for (auto ct : allocatedCiphertexts) {
free_lwe_ciphertext_u64(&err, ct);
}
for (auto buffer : ciphertextBuffers) {
free(buffer);
}
}
@@ -185,38 +225,206 @@ JITLambda::Argument::create(KeySet &keySet) {
}
llvm::Error JITLambda::Argument::setArg(size_t pos, uint64_t arg) {
if (pos >= inputGates.size()) {
return llvm::make_error<llvm::StringError>(
llvm::Twine("argument index out of bound: pos=")
.concat(llvm::Twine(pos)),
llvm::inconvertibleErrorCode());
}
auto gate = inputGates[pos];
auto info = std::get<0>(gate);
auto offset = std::get<1>(gate);
// Check if the argument is a scalar
if (info.shape.size != 0) {
return llvm::make_error<llvm::StringError>(
llvm::Twine("argument is not a scalar: pos=").concat(llvm::Twine(pos)),
llvm::inconvertibleErrorCode());
}
// If the argument is not encrypted, just store it.
if (!keySet.isInputEncrypted(pos)) {
inputs[pos] = (void *)arg;
rawArg[pos] = &inputs[pos];
if (!info.encryption.hasValue()) {
inputs[offset] = (void *)arg;
rawArg[offset] = &inputs[offset];
return llvm::Error::success();
}
// Else if is encryted, allocate ciphertext.
// Otherwise it is encrypted: allocate a ciphertext and encrypt.
LweCiphertext_u64 *ctArg;
if (auto err = this->keySet.allocate_lwe(pos, &ctArg)) {
return std::move(err);
}
allocatedCiphertexts.push_back(ctArg);
if (auto err = this->keySet.encrypt_lwe(pos, ctArg, arg)) {
return std::move(err);
}
inputs[pos] = ctArg;
rawArg[pos] = &inputs[pos];
inputs[offset] = ctArg;
rawArg[offset] = &inputs[offset];
return llvm::Error::success();
}
llvm::Error JITLambda::Argument::setArg(size_t pos, size_t width, void *data,
size_t size) {
auto gate = inputGates[pos];
auto info = std::get<0>(gate);
auto offset = std::get<1>(gate);
// Check if the width is compatible
// TODO - These rules were found empirically; is there a spec somewhere?
if (info.shape.width <= 8 && width != 8) {
return llvm::make_error<llvm::StringError>(
llvm::Twine("argument width should be 8: pos=")
.concat(llvm::Twine(pos)),
llvm::inconvertibleErrorCode());
}
if (info.shape.width > 8 && info.shape.width <= 16 && width != 16) {
return llvm::make_error<llvm::StringError>(
llvm::Twine("argument width should be 16: pos=")
.concat(llvm::Twine(pos)),
llvm::inconvertibleErrorCode());
}
if (info.shape.width > 16 && info.shape.width <= 32 && width != 32) {
return llvm::make_error<llvm::StringError>(
llvm::Twine("argument width should be 32: pos=")
.concat(llvm::Twine(pos)),
llvm::inconvertibleErrorCode());
}
if (info.shape.width > 32 && info.shape.width <= 64 && width != 64) {
return llvm::make_error<llvm::StringError>(
llvm::Twine("argument width should be 64: pos=")
.concat(llvm::Twine(pos)),
llvm::inconvertibleErrorCode());
}
if (info.shape.width > 64) {
return llvm::make_error<llvm::StringError>(
llvm::Twine("argument width not supported: pos=")
.concat(llvm::Twine(pos)),
llvm::inconvertibleErrorCode());
}
// Check the size
if (info.shape.size == 0) {
return llvm::make_error<llvm::StringError>(
llvm::Twine("argument is not a vector: pos=").concat(llvm::Twine(pos)),
llvm::inconvertibleErrorCode());
}
if (info.shape.size != size) {
return llvm::make_error<llvm::StringError>(
llvm::Twine("vector argument has not the expected size")
.concat(llvm::Twine(pos)),
llvm::inconvertibleErrorCode());
}
// If the argument is encrypted, allocate and encrypt ciphertexts first;
// otherwise the raw data is passed through unchanged.
if (info.encryption.hasValue()) {
// For the moment we only support 8-bit inputs for encrypted gates.
uint8_t *data8 = (uint8_t *)data;
if (width != 8) {
return llvm::make_error<llvm::StringError>(
llvm::Twine(
"argument width > 8 for encrypted gates are not supported: pos=")
.concat(llvm::Twine(pos)),
llvm::inconvertibleErrorCode());
}
// Allocate a buffer for ciphertexts.
auto ctBuffer =
(LweCiphertext_u64 **)malloc(size * sizeof(LweCiphertext_u64 *));
ciphertextBuffers.push_back(ctBuffer);
// Allocate ciphertexts and encrypt
for (auto i = 0; i < size; i++) {
if (auto err = this->keySet.allocate_lwe(pos, &ctBuffer[i])) {
return std::move(err);
}
allocatedCiphertexts.push_back(ctBuffer[i]);
if (auto err = this->keySet.encrypt_lwe(pos, ctBuffer[i], data8[i])) {
return std::move(err);
}
}
// Replace the data with the buffer of ciphertexts
data = (void *)ctBuffer;
}
// Set the buffer as the memref calling convention expects.
// allocated
inputs[offset] = (void *)0; // TODO - Better understand how it is used.
rawArg[offset] = &inputs[offset];
// aligned
inputs[offset + 1] = data;
rawArg[offset + 1] = &inputs[offset + 1];
// offset
inputs[offset + 2] = (void *)0;
rawArg[offset + 2] = &inputs[offset + 2];
// size
inputs[offset + 3] = (void *)size;
rawArg[offset + 3] = &inputs[offset + 3];
// stride
inputs[offset + 4] = (void *)0;
rawArg[offset + 4] = &inputs[offset + 4];
return llvm::Error::success();
}
llvm::Error JITLambda::Argument::getResult(size_t pos, uint64_t &res) {
auto gate = outputGates[pos];
auto info = std::get<0>(gate);
auto offset = std::get<1>(gate);
// Check if the output is a scalar
if (info.shape.size != 0) {
return llvm::make_error<llvm::StringError>(
llvm::Twine("output is not a scalar, pos=").concat(llvm::Twine(pos)),
llvm::inconvertibleErrorCode());
}
// If the result is not encrypted, just set it
if (!keySet.isOutputEncrypted(pos)) {
res = (uint64_t)(results[pos]);
if (!info.encryption.hasValue()) {
res = (uint64_t)(outputs[offset]);
return llvm::Error::success();
}
// Otherwise it is encrypted: decrypt
LweCiphertext_u64 *ct = (LweCiphertext_u64 *)(results[pos]);
LweCiphertext_u64 *ct = (LweCiphertext_u64 *)(outputs[offset]);
if (auto err = this->keySet.decrypt_lwe(pos, ct, res)) {
return std::move(err);
}
return llvm::Error::success();
}
llvm::Error JITLambda::Argument::getResult(size_t pos, uint64_t *res,
size_t size) {
auto gate = outputGates[pos];
auto info = std::get<0>(gate);
auto offset = std::get<1>(gate);
// Check if the output is a tensor
if (info.shape.size == 0) {
return llvm::make_error<llvm::StringError>(
llvm::Twine("output is not a tensor, pos=").concat(llvm::Twine(pos)),
llvm::inconvertibleErrorCode());
}
if (!info.encryption.hasValue()) {
return llvm::make_error<llvm::StringError>(
"unencrypted result as tensor output NYI",
llvm::inconvertibleErrorCode());
}
// Get the values as the memref calling convention expects.
void *allocated = outputs[offset]; // TODO - Better understand how it is used.
// aligned
void *aligned = outputs[offset + 1];
// offset
size_t offset_r = (size_t)outputs[offset + 2];
// size
size_t size_r = (size_t)outputs[offset + 3];
// stride
size_t stride = (size_t)outputs[offset + 4];
// Check the sizes
if (info.shape.size != size || size_r != size) {
return llvm::make_error<llvm::StringError>("output bad result buffer size",
llvm::inconvertibleErrorCode());
}
// decrypt and fill the result buffer
for (auto i = 0; i < size_r; i++) {
LweCiphertext_u64 *ct = ((LweCiphertext_u64 **)(aligned))[i];
if (auto err = this->keySet.decrypt_lwe(pos, ct, res[i])) {
return std::move(err);
}
}
return llvm::Error::success();
}
} // namespace zamalang
} // namespace mlir

View File

@@ -42,7 +42,7 @@ KeySet::generate(ClientParameters &params, uint64_t seed_msb,
auto e = keySet->generateSecretKey(secretKeyParam.first,
secretKeyParam.second, generator);
if (e) {
return e;
return std::move(e);
}
}
CAPI_ERR_TO_LLVM_ERROR(free_secret_generator(&err, generator),
@@ -60,7 +60,7 @@ KeySet::generate(ClientParameters &params, uint64_t seed_msb,
bootstrapKeyParam.second,
keySet->encryptionRandomGenerator);
if (e) {
return e;
return std::move(e);
}
}
for (auto keyswitchParam : params.keyswitchKeys) {
@@ -68,7 +68,7 @@ KeySet::generate(ClientParameters &params, uint64_t seed_msb,
keyswitchParam.second,
keySet->encryptionRandomGenerator);
if (e) {
return e;
return std::move(e);
}
}
}
@@ -112,9 +112,8 @@ llvm::Error KeySet::generateSecretKey(LweSecretKeyID id,
LweSecretKeyParam param,
SecretRandomGenerator *generator) {
LweSecretKey_u64 *sk;
CAPI_ERR_TO_LLVM_ERROR(
sk = allocate_lwe_secret_key_u64(&err, {_0 : param.size}),
"cannot allocate secret key");
CAPI_ERR_TO_LLVM_ERROR(sk = allocate_lwe_secret_key_u64(&err, {param.size}),
"cannot allocate secret key");
CAPI_ERR_TO_LLVM_ERROR(fill_lwe_secret_key_u64(&err, sk, generator),
"cannot fill secret key with random generator")
secretKeys[id] = {param, sk};
@@ -250,6 +249,7 @@ llvm::Error KeySet::encrypt_lwe(size_t argPos, LweCiphertext_u64 *ciphertext,
llvm::Error KeySet::decrypt_lwe(size_t argPos, LweCiphertext_u64 *ciphertext,
uint64_t &output) {
if (argPos >= outputs.size()) {
return llvm::make_error<llvm::StringError>(
"decrypt_lwe: position of argument is too high",
@@ -262,13 +262,14 @@ llvm::Error KeySet::decrypt_lwe(size_t argPos, LweCiphertext_u64 *ciphertext,
llvm::inconvertibleErrorCode());
}
// Decrypt
Plaintext_u64 plaintext;
Plaintext_u64 plaintext = {0};
CAPI_ERR_TO_LLVM_ERROR(
decrypt_lwe_u64(&err, std::get<2>(outputSk), ciphertext, &plaintext),
"cannot decrypt");
// Decode
output = plaintext._0 >>
(64 - (std::get<0>(outputSk).encryption->encoding.precision + 1));
return llvm::Error::success();
}
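The decode step above shifts the message back out of the most significant bits of the 64-bit plaintext. A small worked sketch for a precision of 5; the encoding shown here is the assumed mirror of that shift, it is not part of this hunk.

#include <cassert>
#include <cstdint>

int main() {
  const uint64_t precision = 5;
  const uint64_t message = 11;                             // fits in 5 bits
  uint64_t plaintext = message << (64 - (precision + 1));  // assumed encoding
  uint64_t decoded = plaintext >> (64 - (precision + 1));  // decode, as above
  assert(decoded == message);
  return 0;
}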

View File

@@ -1,5 +1,8 @@
enable_testing()
include_directories(${PROJECT_SOURCE_DIR}/include)
add_executable(
hello_test
hello_test.cc
@@ -7,6 +10,7 @@ add_executable(
target_link_libraries(
hello_test
gtest_main
ZamalangSupport
)
include(GoogleTest)

View File

@@ -1,9 +1,347 @@
#include <gtest/gtest.h>
// Demonstrate some basic assertions.
TEST(HelloTest, BasicAssertions) {
// Expect two strings not to be equal.
EXPECT_STRNE("hello", "world");
// Expect equality.
EXPECT_EQ(7 * 6, 42);
#include "zamalang/Support/CompilerEngine.h"
#define ASSERT_LLVM_ERROR(err) \
if (err) { \
llvm::errs() << "error: " << std::move(err) << "\n"; \
ASSERT_TRUE(false); \
}
TEST(CompileAndRunHLFHE, add_eint) {
mlir::zamalang::CompilerEngine engine;
auto mlirStr = R"XXX(
func @main(%arg0: !HLFHE.eint<7>, %arg1: !HLFHE.eint<7>) -> !HLFHE.eint<7> {
%1 = "HLFHE.add_eint"(%arg0, %arg1): (!HLFHE.eint<7>, !HLFHE.eint<7>) -> (!HLFHE.eint<7>)
return %1: !HLFHE.eint<7>
}
)XXX";
ASSERT_FALSE(engine.compile(mlirStr));
auto maybeResult = engine.run({1, 2});
ASSERT_TRUE((bool)maybeResult);
uint64_t result = maybeResult.get();
ASSERT_EQ(result, 3);
}
TEST(CompileAndRunTensorStd, extract_64) {
mlir::zamalang::CompilerEngine engine;
auto mlirStr = R"XXX(
func @main(%t: tensor<10xi64>, %i: index) -> i64{
%c = tensor.extract %t[%i] : tensor<10xi64>
return %c : i64
}
)XXX";
ASSERT_LLVM_ERROR(engine.compile(mlirStr));
const size_t size = 10;
uint64_t t_arg[size]{0xFFFFFFFFFFFFFFFF,
0,
8978,
2587490,
90,
197864,
698735,
72132,
87474,
42};
for (size_t i = 0; i < size; i++) {
auto maybeArgument = engine.buildArgument();
ASSERT_LLVM_ERROR(maybeArgument.takeError());
auto argument = std::move(maybeArgument.get());
// Set the %t argument
ASSERT_LLVM_ERROR(argument->setArg(0, t_arg, size));
// Set the %i argument
ASSERT_LLVM_ERROR(argument->setArg(1, i));
// Invoke the function
ASSERT_LLVM_ERROR(engine.invoke(*argument));
// Get and assert the result
uint64_t res = 0;
ASSERT_LLVM_ERROR(argument->getResult(0, res));
ASSERT_EQ(res, t_arg[i]);
}
}
TEST(CompileAndRunTensorStd, extract_32) {
mlir::zamalang::CompilerEngine engine;
auto mlirStr = R"XXX(
func @main(%t: tensor<10xi32>, %i: index) -> i32{
%c = tensor.extract %t[%i] : tensor<10xi32>
return %c : i32
}
)XXX";
ASSERT_LLVM_ERROR(engine.compile(mlirStr));
const size_t size = 10;
uint32_t t_arg[size]{0xFFFFFFFF, 0, 8978, 2587490, 90,
197864, 698735, 72132, 87474, 42};
for (size_t i = 0; i < size; i++) {
auto maybeArgument = engine.buildArgument();
ASSERT_LLVM_ERROR(maybeArgument.takeError());
auto argument = std::move(maybeArgument.get());
// Set the %t argument
ASSERT_LLVM_ERROR(argument->setArg(0, t_arg, size));
// Set the %i argument
ASSERT_LLVM_ERROR(argument->setArg(1, i));
// Invoke the function
ASSERT_LLVM_ERROR(engine.invoke(*argument));
// Get and assert the result
uint64_t res = 0;
ASSERT_LLVM_ERROR(argument->getResult(0, res));
ASSERT_EQ(res, t_arg[i]);
}
}
TEST(CompileAndRunTensorStd, extract_16) {
mlir::zamalang::CompilerEngine engine;
auto mlirStr = R"XXX(
func @main(%t: tensor<10xi16>, %i: index) -> i16{
%c = tensor.extract %t[%i] : tensor<10xi16>
return %c : i16
}
)XXX";
ASSERT_LLVM_ERROR(engine.compile(mlirStr));
const size_t size = 10;
uint16_t t_arg[size]{0xFFFF, 0, 59589, 47826, 16227,
63269, 36435, 52380, 7401, 13313};
for (size_t i = 0; i < size; i++) {
auto maybeArgument = engine.buildArgument();
ASSERT_LLVM_ERROR(maybeArgument.takeError());
auto argument = std::move(maybeArgument.get());
// Set the %t argument
ASSERT_LLVM_ERROR(argument->setArg(0, t_arg, size));
// Set the %i argument
ASSERT_LLVM_ERROR(argument->setArg(1, i));
// Invoke the function
ASSERT_LLVM_ERROR(engine.invoke(*argument));
// Get and assert the result
uint64_t res = 0;
ASSERT_LLVM_ERROR(argument->getResult(0, res));
ASSERT_EQ(res, t_arg[i]);
}
}
TEST(CompileAndRunTensorStd, extract_8) {
mlir::zamalang::CompilerEngine engine;
auto mlirStr = R"XXX(
func @main(%t: tensor<10xi8>, %i: index) -> i8{
%c = tensor.extract %t[%i] : tensor<10xi8>
return %c : i8
}
)XXX";
ASSERT_LLVM_ERROR(engine.compile(mlirStr));
const size_t size = 10;
uint8_t t_arg[size]{0xFF, 0, 120, 225, 14, 177, 131, 84, 174, 93};
for (size_t i = 0; i < size; i++) {
auto maybeArgument = engine.buildArgument();
ASSERT_LLVM_ERROR(maybeArgument.takeError());
auto argument = std::move(maybeArgument.get());
// Set the %t argument
ASSERT_LLVM_ERROR(argument->setArg(0, t_arg, size));
// Set the %i argument
ASSERT_LLVM_ERROR(argument->setArg(1, i));
// Invoke the function
ASSERT_LLVM_ERROR(engine.invoke(*argument));
// Get and assert the result
uint64_t res = 0;
ASSERT_LLVM_ERROR(argument->getResult(0, res));
ASSERT_EQ(res, t_arg[i]);
}
}
TEST(CompileAndRunTensorStd, extract_5) {
mlir::zamalang::CompilerEngine engine;
auto mlirStr = R"XXX(
func @main(%t: tensor<10xi5>, %i: index) -> i5{
%c = tensor.extract %t[%i] : tensor<10xi5>
return %c : i5
}
)XXX";
ASSERT_LLVM_ERROR(engine.compile(mlirStr));
const size_t size = 10;
uint8_t t_arg[size]{32, 0, 10, 25, 14, 25, 18, 28, 14, 7};
for (size_t i = 0; i < size; i++) {
auto maybeArgument = engine.buildArgument();
ASSERT_LLVM_ERROR(maybeArgument.takeError());
auto argument = std::move(maybeArgument.get());
// Set the %t argument
ASSERT_LLVM_ERROR(argument->setArg(0, t_arg, size));
// Set the %i argument
ASSERT_LLVM_ERROR(argument->setArg(1, i));
// Invoke the function
ASSERT_LLVM_ERROR(engine.invoke(*argument));
// Get and assert the result
uint64_t res = 0;
ASSERT_LLVM_ERROR(argument->getResult(0, res));
ASSERT_EQ(res, t_arg[i]);
}
}
TEST(CompileAndRunTensorStd, extract_1) {
mlir::zamalang::CompilerEngine engine;
auto mlirStr = R"XXX(
func @main(%t: tensor<10xi1>, %i: index) -> i1{
%c = tensor.extract %t[%i] : tensor<10xi1>
return %c : i1
}
)XXX";
ASSERT_LLVM_ERROR(engine.compile(mlirStr));
const size_t size = 10;
uint8_t t_arg[size]{0, 0, 1, 0, 1, 1, 0, 1, 1, 0};
for (size_t i = 0; i < size; i++) {
auto maybeArgument = engine.buildArgument();
ASSERT_LLVM_ERROR(maybeArgument.takeError());
auto argument = std::move(maybeArgument.get());
// Set the %t argument
ASSERT_LLVM_ERROR(argument->setArg(0, t_arg, size));
// Set the %i argument
ASSERT_LLVM_ERROR(argument->setArg(1, i));
// Invoke the function
ASSERT_LLVM_ERROR(engine.invoke(*argument));
// Get and assert the result
uint64_t res = 0;
ASSERT_LLVM_ERROR(argument->getResult(0, res));
ASSERT_EQ(res, t_arg[i]);
}
}
TEST(CompileAndRunTensorEncrypted, extract_5) {
mlir::zamalang::CompilerEngine engine;
auto mlirStr = R"XXX(
func @main(%t: tensor<10x!HLFHE.eint<5>>, %i: index) -> !HLFHE.eint<5>{
%c = tensor.extract %t[%i] : tensor<10x!HLFHE.eint<5>>
return %c : !HLFHE.eint<5>
}
)XXX";
ASSERT_LLVM_ERROR(engine.compile(mlirStr));
const size_t size = 10;
uint8_t t_arg[size]{32, 0, 10, 25, 14, 25, 18, 28, 14, 7};
for (size_t i = 0; i < size; i++) {
auto maybeArgument = engine.buildArgument();
ASSERT_LLVM_ERROR(maybeArgument.takeError());
auto argument = std::move(maybeArgument.get());
// Set the %t argument
ASSERT_LLVM_ERROR(argument->setArg(0, t_arg, size));
// Set the %i argument
ASSERT_LLVM_ERROR(argument->setArg(1, i));
// Invoke the function
ASSERT_LLVM_ERROR(engine.invoke(*argument));
// Get and assert the result
uint64_t res = 0;
ASSERT_LLVM_ERROR(argument->getResult(0, res));
ASSERT_EQ(res, t_arg[i]);
}
}
TEST(CompileAndRunTensorEncrypted, extract_twice_and_add_5) {
mlir::zamalang::CompilerEngine engine;
auto mlirStr = R"XXX(
func @main(%t: tensor<10x!HLFHE.eint<5>>, %i: index, %j: index) -> !HLFHE.eint<5>{
%ti = tensor.extract %t[%i] : tensor<10x!HLFHE.eint<5>>
%tj = tensor.extract %t[%j] : tensor<10x!HLFHE.eint<5>>
%c = "HLFHE.add_eint"(%ti, %tj) : (!HLFHE.eint<5>, !HLFHE.eint<5>) -> !HLFHE.eint<5>
return %c : !HLFHE.eint<5>
}
)XXX";
ASSERT_LLVM_ERROR(engine.compile(mlirStr));
const size_t size = 10;
uint8_t t_arg[size]{32, 0, 10, 25, 14, 25, 18, 28, 14, 7};
for (size_t i = 0; i < size; i++) {
for (size_t j = 0; j < size; j++) {
auto maybeArgument = engine.buildArgument();
ASSERT_LLVM_ERROR(maybeArgument.takeError());
auto argument = std::move(maybeArgument.get());
// Set the %t argument
ASSERT_LLVM_ERROR(argument->setArg(0, t_arg, size));
// Set the %i argument
ASSERT_LLVM_ERROR(argument->setArg(1, i));
// Set the %j argument
ASSERT_LLVM_ERROR(argument->setArg(2, j));
// Invoke the function
ASSERT_LLVM_ERROR(engine.invoke(*argument));
// Get and assert the result
uint64_t res = 0;
ASSERT_LLVM_ERROR(argument->getResult(0, res));
ASSERT_EQ(res, t_arg[i] + t_arg[j]);
}
}
}
TEST(CompileAndRunTensorEncrypted, dim_5) {
mlir::zamalang::CompilerEngine engine;
auto mlirStr = R"XXX(
func @main(%t: tensor<10x!HLFHE.eint<5>>) -> index{
%c0 = constant 0 : index
%c = tensor.dim %t, %c0 : tensor<10x!HLFHE.eint<5>>
return %c : index
}
)XXX";
ASSERT_LLVM_ERROR(engine.compile(mlirStr));
const size_t size = 10;
uint8_t t_arg[size]{32, 0, 10, 25, 14, 25, 18, 28, 14, 7};
auto maybeArgument = engine.buildArgument();
ASSERT_LLVM_ERROR(maybeArgument.takeError());
auto argument = std::move(maybeArgument.get());
// Set the %t argument
ASSERT_LLVM_ERROR(argument->setArg(0, t_arg, size));
// Invoke the function
ASSERT_LLVM_ERROR(engine.invoke(*argument));
// Get and assert the result
uint64_t res = 0;
ASSERT_LLVM_ERROR(argument->getResult(0, res));
ASSERT_EQ(res, size);
}
TEST(CompileAndRunTensorEncrypted, from_elements_5) {
mlir::zamalang::CompilerEngine engine;
auto mlirStr = R"XXX(
func @main(%0: !HLFHE.eint<5>) -> tensor<1x!HLFHE.eint<5>> {
%t = tensor.from_elements %0 : tensor<1x!HLFHE.eint<5>>
return %t: tensor<1x!HLFHE.eint<5>>
}
)XXX";
ASSERT_LLVM_ERROR(engine.compile(mlirStr));
auto maybeArgument = engine.buildArgument();
ASSERT_LLVM_ERROR(maybeArgument.takeError());
auto argument = std::move(maybeArgument.get());
// Set the %t argument
ASSERT_LLVM_ERROR(argument->setArg(0, 10));
// Invoke the function
ASSERT_LLVM_ERROR(engine.invoke(*argument));
// Get and assert the result
const size_t size_res = 1;
uint64_t t_res[size_res];
ASSERT_LLVM_ERROR(argument->getResult(0, t_res, size_res));
ASSERT_EQ(t_res[0], 10);
}
TEST(CompileAndRunTensorEncrypted, in_out_tensor_with_op_5) {
mlir::zamalang::CompilerEngine engine;
auto mlirStr = R"XXX(
func @main(%in: tensor<2x!HLFHE.eint<5>>) -> tensor<3x!HLFHE.eint<5>> {
%c_0 = constant 0 : index
%c_1 = constant 1 : index
%a = tensor.extract %in[%c_0] : tensor<2x!HLFHE.eint<5>>
%b = tensor.extract %in[%c_1] : tensor<2x!HLFHE.eint<5>>
%aplusa = "HLFHE.add_eint"(%a, %a): (!HLFHE.eint<5>, !HLFHE.eint<5>) -> (!HLFHE.eint<5>)
%aplusb = "HLFHE.add_eint"(%a, %b): (!HLFHE.eint<5>, !HLFHE.eint<5>) -> (!HLFHE.eint<5>)
%bplusb = "HLFHE.add_eint"(%b, %b): (!HLFHE.eint<5>, !HLFHE.eint<5>) -> (!HLFHE.eint<5>)
%out = tensor.from_elements %aplusa, %aplusb, %bplusb : tensor<3x!HLFHE.eint<5>>
return %out: tensor<3x!HLFHE.eint<5>>
}
)XXX";
ASSERT_LLVM_ERROR(engine.compile(mlirStr));
auto maybeArgument = engine.buildArgument();
ASSERT_LLVM_ERROR(maybeArgument.takeError());
auto argument = std::move(maybeArgument.get());
// Set the argument
const size_t in_size = 2;
uint8_t in[in_size] = {2, 16};
ASSERT_LLVM_ERROR(argument->setArg(0, in, in_size));
// Invoke the function
ASSERT_LLVM_ERROR(engine.invoke(*argument));
// Get and assert the result
const size_t size_res = 3;
uint64_t t_res[size_res];
ASSERT_LLVM_ERROR(argument->getResult(0, t_res, size_res));
ASSERT_EQ(t_res[0], in[0] + in[0]);
ASSERT_EQ(t_res[1], in[0] + in[1]);
ASSERT_EQ(t_res[2], in[1] + in[1]);
}