feat(Testlib): lib for testing libs generated by concretecompiler

Closes #201
This commit is contained in:
rudy
2022-01-03 11:05:20 +01:00
committed by rudy-6-4
parent d1fef75aea
commit 58e02fd035
19 changed files with 2489 additions and 12 deletions

3
.gitignore vendored
View File

@@ -48,3 +48,6 @@ _build/
# macOS
.DS_Store
compiler/tests/TestLib/out/

View File

@@ -69,12 +69,18 @@ test-check: concretecompiler file-check not
test-python: python-bindings concretecompiler
PYTHONPATH=${PYTHONPATH}:$(BUILD_DIR)/tools/concretelang/python_packages/concretelang_core LD_PRELOAD=$(BUILD_DIR)/lib/libConcretelangRuntime.so pytest -vs tests/python
test: test-check test-end-to-end-jit test-python support-unit-test
test: test-check test-end-to-end-jit test-python support-unit-test testlib-unit-test
test-dataflow: test-end-to-end-jit-dfr test-end-to-end-jit-auto-parallelization
# unit-test
testlib-unit-test: build-testlib-unit-test
$(BUILD_DIR)/bin/testlib_unit_test
build-testlib-unit-test: build-initialized
cmake --build $(BUILD_DIR) --target testlib_unit_test
support-unit-test: build-support-unit-test
$(BUILD_DIR)/bin/support_unit_test

View File

@@ -15,6 +15,8 @@
namespace mlir {
namespace concretelang {
size_t bitWidthAsWord(size_t exactBitWidth);
/// JITLambda is a tool to JIT compile an mlir module and to invoke a function
/// of the module.
class JITLambda {

View File

@@ -0,0 +1,105 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/homomorphizer/blob/master/LICENSE.txt for license
// information.
#ifndef CONCRETELANG_TESTLIB_ARGUMENTS_H
#define CONCRETELANG_TESTLIB_ARGUMENTS_H
#include "concretelang/ClientLib/ClientParameters.h"
#include "concretelang/ClientLib/KeySet.h"
namespace mlir {
namespace concretelang {
class DynamicLambda;
// Arguments prepares the argument list for a call to a compiled FHE
// function: pushed values are encrypted with the bound KeySet whenever the
// corresponding input gate is encrypted, and are laid out in `preparedArgs`
// following the compiler's calling convention (memref expansion included).
class Arguments {
public:
  // Binds the key set and installs the runtime context into it.
  Arguments(KeySet &keySet) : currentPos(0), keySet(keySet) {
    keySet.setRuntimeContext(context);
  }
  // Frees all ciphertexts and ciphertext buffers allocated while pushing.
  ~Arguments();
  // Create EncryptedArgument that use the given KeySet to perform encryption
  // and decryption operations.
  static std::shared_ptr<Arguments> create(KeySet &keySet);
  // Add a scalar argument.
  llvm::Error pushArg(uint64_t arg);
  // Add a vector-tensor argument.
  llvm::Error pushArg(std::vector<uint8_t> arg);
  // Add a fixed-size 1D tensor argument (8-bit elements).
  template <size_t size> llvm::Error pushArg(std::array<uint8_t, size> arg) {
    return pushArg(8, (void *)arg.data(), {size});
  }
  // Add a matrix-tensor argument.
  template <size_t size0, size_t size1>
  llvm::Error pushArg(std::array<std::array<uint8_t, size1>, size0> arg) {
    return pushArg(8, (void *)arg.data(), {size0, size1});
  }
  // Add a rank3 tensor.
  template <size_t size0, size_t size1, size_t size2>
  llvm::Error pushArg(
      std::array<std::array<std::array<uint8_t, size2>, size1>, size0> arg) {
    return pushArg(8, (void *)arg.data(), {size0, size1, size2});
  }
  // Generalize by computing shape by template recursion
  // Set a argument at the given pos as a 1D tensor of T.
  template <typename T> llvm::Error pushArg(T *data, int64_t dim1) {
    return pushArg<T>(data, llvm::ArrayRef<int64_t>(&dim1, 1));
  }
  // Set a argument at the given pos as a tensor of T. The element width is
  // derived from sizeof(T).
  template <typename T>
  llvm::Error pushArg(T *data, llvm::ArrayRef<int64_t> shape) {
    return pushArg(8 * sizeof(T), static_cast<void *>(data), shape);
  }
  // Low-level push: `width` bits per element, raw `data`, tensor `shape`.
  llvm::Error pushArg(size_t width, void *data, llvm::ArrayRef<int64_t> shape);
  // Push the runtime context to the argument list, this must be called
  // after each argument was pushed.
  llvm::Error pushContext();
  // Push a heterogeneous pack of arguments, then the runtime context.
  template <typename Arg0, typename... OtherArgs>
  llvm::Error pushArgs(Arg0 arg0, OtherArgs... others) {
    auto err = pushArg(arg0);
    if (err) {
      return err;
    }
    return pushArgs(others...);
  }
  // Recursion terminator: every argument was pushed, append the context.
  llvm::Error pushArgs() { return pushContext(); }

private:
  friend DynamicLambda;
  template <typename Result>
  friend llvm::Expected<Result> invoke(DynamicLambda &lambda,
                                       const Arguments &args);
  // Fails once more arguments are pushed than the function's arity allows.
  llvm::Error checkPushTooManyArgs();
  // Position of the next pushed argument
  size_t currentPos;
  // Raw argument words handed to the compiled function.
  std::vector<void *> preparedArgs;
  // Store allocated lwe ciphertexts (for free)
  std::vector<LweCiphertext_u64 *> allocatedCiphertexts;
  // Store buffers of ciphertexts
  std::vector<LweCiphertext_u64 **> ciphertextBuffers;
  KeySet &keySet;
  RuntimeContext context;
};
} // namespace concretelang
} // namespace mlir
#endif

View File

@@ -0,0 +1,123 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/homomorphizer/blob/master/LICENSE.txt for license
// information.
#ifndef CONCRETELANG_TESTLIB_DYNAMIC_LAMBDA_H
#define CONCRETELANG_TESTLIB_DYNAMIC_LAMBDA_H
#include "concretelang/ClientLib/ClientParameters.h"
#include "concretelang/ClientLib/KeySet.h"
#include "concretelang/ClientLib/KeySetCache.h"
#include "concretelang/TestLib/Arguments.h"
#include "concretelang/TestLib/DynamicModule.h"
namespace mlir {
namespace concretelang {
template <size_t N> struct MemRefDescriptor;
// Primary template: only the explicit specializations declared below are
// callable. Instantiating this primary is a programming error, so fail at
// compile time with a readable diagnostic instead of the previous
// ill-formed cast-through-void trick. `!sizeof(Result)` is a dependent
// false, so the assert only fires on instantiation.
template <typename Result>
llvm::Expected<Result> invoke(DynamicLambda &lambda, const Arguments &args) {
  static_assert(!sizeof(Result),
                "invoke does not accept this kind of Result; see the "
                "specializations declared below");
}
// Supported result types: an encrypted scalar, or nested vectors for
// encrypted tensors of rank 1 to 3. Definitions live in DynamicLambda.cpp.
// NOTE(review): `u_int64_t` is a non-standard alias; `uint64_t` is used
// everywhere else — consider unifying.
template <>
llvm::Expected<u_int64_t> invoke<u_int64_t>(DynamicLambda &lambda,
                                            const Arguments &args);
template <>
llvm::Expected<std::vector<uint64_t>>
invoke<std::vector<uint64_t>>(DynamicLambda &lambda, const Arguments &args);
template <>
llvm::Expected<std::vector<std::vector<uint64_t>>>
invoke<std::vector<std::vector<uint64_t>>>(DynamicLambda &lambda,
                                           const Arguments &args);
template <>
llvm::Expected<std::vector<std::vector<std::vector<uint64_t>>>>
invoke<std::vector<std::vector<std::vector<uint64_t>>>>(DynamicLambda &lambda,
                                                        const Arguments &args);
// DynamicLambda wraps one entry point of a compiled shared library: it
// resolves the symbol, holds the matching client parameters and key set,
// and dispatches typed calls through the invoke<> specializations.
class DynamicLambda {

private:
  // Builds an Arguments object from a typed argument pack; fails when the
  // key set has not been generated yet (see generateKeySet).
  template <typename... Args>
  llvm::Expected<std::shared_ptr<Arguments>> createArguments(Args... args) {
    if (keySet == nullptr) {
      return StreamStringError("keySet was not initialized");
    }
    auto arg = Arguments::create(*keySet);
    auto err = arg->pushArgs(args...);
    if (err) {
      return StreamStringError(llvm::toString(std::move(err)));
    }
    return arg;
  }

public:
  // Loads `funcName` from the library at `outputLib`.
  static llvm::Expected<DynamicLambda> load(std::string funcName,
                                            std::string outputLib);
  // Loads `funcName` from an already-opened module.
  static llvm::Expected<DynamicLambda>
  load(std::shared_ptr<DynamicModule> module, std::string funcName);

  // Encrypts the arguments, runs the compiled function, and decrypts the
  // result as `Result` (one of the supported invoke<> result types).
  template <typename Result, typename... Args>
  llvm::Expected<Result> call(Args... args) {
    auto argOrErr = createArguments(args...);
    if (!argOrErr) {
      return argOrErr.takeError();
    }
    auto arg = argOrErr.get();
    return invoke<Result>(*this, *arg);
  }

  // Generates (or loads from `cache`) the key set for clientParameters.
  llvm::Error generateKeySet(llvm::Optional<KeySetCache> cache = llvm::None,
                             uint64_t seed_msb = 0, uint64_t seed_lsb = 0);

protected:
  template <typename Result>
  friend llvm::Expected<Result> invoke(DynamicLambda &lambda,
                                       const Arguments &args);

  // Calls the compiled function and validates the returned rank-`Rank`
  // memref descriptor against the expected output shape.
  template <size_t Rank>
  llvm::Expected<MemRefDescriptor<Rank>>
  invokeMemRefDecriptor(const Arguments &args);

  ClientParameters clientParameters;
  std::shared_ptr<KeySet> keySet;
  // Symbol resolved from the shared library (variadic calling convention).
  void *(*func)(void *...);
  // Retain module and open shared lib alive
  std::shared_ptr<DynamicModule> module;
};
// TypedDynamicLambda pins the result and argument types of a DynamicLambda
// at compile time, giving call sites a checked signature.
template <typename Result, typename... Args>
class TypedDynamicLambda : public DynamicLambda {
public:
  // Loads `funcName` from `outputLib` and wraps it in the typed interface.
  static llvm::Expected<TypedDynamicLambda<Result, Args...>>
  load(std::string funcName, std::string outputLib) {
    auto lambda = DynamicLambda::load(funcName, outputLib);
    if (!lambda) {
      return lambda.takeError();
    }
    return TypedDynamicLambda(*lambda);
  }

  // Forwards to DynamicLambda::call with the fixed Result type.
  llvm::Expected<Result> call(Args... args) {
    return DynamicLambda::call<Result>(args...);
  }

  // TODO: check parameter types
  TypedDynamicLambda(DynamicLambda &lambda) : DynamicLambda(lambda) {
    // TODO: add static check on types vs lambda inputs/outputs
  }
};
} // namespace concretelang
} // namespace mlir
#endif

View File

@@ -0,0 +1,33 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/homomorphizer/blob/master/LICENSE.txt for license
// information.
#ifndef CONCRETELANG_TESTLIB_DYNAMIC_MODULE_H
#define CONCRETELANG_TESTLIB_DYNAMIC_MODULE_H
#include "concretelang/ClientLib/ClientParameters.h"
namespace mlir {
namespace concretelang {
// DynamicModule owns a compiled shared library produced by the compiler,
// together with the client parameters describing its entry points.
class DynamicModule {
public:
  // Closes the shared library when one was successfully opened.
  ~DynamicModule();

  // Loads the client parameters and the shared library found at
  // `libraryPath`; returns an error if either step fails.
  static llvm::Expected<std::shared_ptr<DynamicModule>>
  open(std::string libraryPath);

private:
  llvm::Error loadClientParametersJSON(std::string path);
  llvm::Error loadSharedLibrary(std::string path);

private:
  std::vector<ClientParameters> clientParametersList;
  // In-class initializer added: open() can fail before loadSharedLibrary
  // runs, and the destructor would otherwise dlclose an uninitialized
  // handle (undefined behavior).
  void *libraryHandle = nullptr;
  friend class DynamicLambda;
};
} // namespace concretelang
} // namespace mlir
#endif

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,6 @@
# Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions.
# See https://github.com/zama-ai/homomorphizer/blob/master/LICENSE.txt for license information.
# Generates the body of the arity-dispatch switch used by dynamicArityCall.h:
# one `case` per arity (0..127) forwarding the prepared argument array to
# `func`. Output is meant to be pasted/included into the C++ switch.
for i in range(128):
    args = ','.join(f'args[{j}]' for j in range(i))
    print(f' case {i}: return func({args});')

View File

@@ -4,6 +4,7 @@ add_subdirectory(Support)
add_subdirectory(Runtime)
add_subdirectory(ClientLib)
add_subdirectory(Bindings)
add_subdirectory(TestLib)
# CAPI needed only for python bindings
if (CONCRETELANG_BINDINGS_PYTHON_ENABLED)

View File

@@ -0,0 +1,161 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/homomorphizer/blob/master/LICENSE.txt for license
// information.
#include "concretelang/TestLib/Arguments.h"
#include "concretelang/Support/CompilerEngine.h"
#include "concretelang/Support/Error.h"
#include "concretelang/Support/Jit.h"
namespace mlir {
namespace concretelang {
// Release every ciphertext allocated during argument preparation, then the
// pointer buffers that referenced them.
Arguments::~Arguments() {
  int freeStatus;
  for (auto *ciphertext : allocatedCiphertexts) {
    free_lwe_ciphertext_u64(&freeStatus, ciphertext);
  }
  for (auto *buffer : ciphertextBuffers) {
    free(buffer);
  }
}
// Factory: wrap construction in a shared_ptr so callers share ownership.
std::shared_ptr<Arguments> Arguments::create(KeySet &keySet) {
  return std::make_shared<Arguments>(keySet);
}
// Push a scalar argument. Clear scalars are forwarded as-is (64-bit only);
// encrypted scalars are allocated and encrypted under the input gate's key.
llvm::Error Arguments::pushArg(uint64_t arg) {
  if (auto err = checkPushTooManyArgs()) {
    return err;
  }
  auto pos = currentPos++;
  CircuitGate input = keySet.inputGate(pos);
  if (input.shape.size != 0) {
    return StreamStringError("argument #") << pos << " is not a scalar";
  }
  if (!input.encryption.hasValue()) {
    // clear scalar: just push the argument
    if (input.shape.width != 64) {
      // Fixed message typo: "of with" -> "of width".
      return StreamStringError(
          "scalar argument of width != 64 is not supported for DynamicLambda");
    }
    preparedArgs.push_back((void *)arg);
    return llvm::Error::success();
  }
  // encrypted scalar: allocate, encrypt and push
  LweCiphertext_u64 *ctArg;
  if (auto err = keySet.allocate_lwe(pos, &ctArg)) {
    return err;
  }
  // Track the ciphertext so the destructor can free it.
  allocatedCiphertexts.push_back(ctArg);
  if (auto err = keySet.encrypt_lwe(pos, ctArg, arg)) {
    return err;
  }
  preparedArgs.push_back((void *)ctArg);
  return llvm::Error::success();
}
// Push a 1-D byte tensor; width 8 matches the uint8_t element type.
llvm::Error Arguments::pushArg(std::vector<uint8_t> arg) {
  return pushArg(8, (void *)arg.data(), {(int64_t)arg.size()});
}
// Low-level tensor push. Validates element width and shape against the
// declared input gate, encrypts element-wise when the gate is encrypted,
// then expands the tensor into the memref calling convention
// (allocated, aligned, offset, sizes..., strides...).
llvm::Error Arguments::pushArg(size_t width, void *data,
                               llvm::ArrayRef<int64_t> shape) {
  if (auto err = checkPushTooManyArgs()) {
    return err;
  }
  auto pos = currentPos;
  currentPos = currentPos + 1;
  CircuitGate input = keySet.inputGate(pos);
  // Check the width of data
  if (input.shape.width > 64) {
    return StreamStringError("argument #")
           << pos << " width > 64 bits is not supported";
  }
  auto roundedSize = bitWidthAsWord(input.shape.width);
  if (width != roundedSize) {
    // Fixed missing separator space before "width mismatch".
    return StreamStringError("argument #")
           << pos << " width mismatch, got " << width << " expected "
           << roundedSize;
  }
  // Check the shape of tensor
  if (input.shape.dimensions.empty()) {
    // Fixed missing separator space before "is not a tensor".
    return StreamStringError("argument #") << pos << " is not a tensor";
  }
  if (shape.size() != input.shape.dimensions.size()) {
    // Fixed missing separator space before "has not ...".
    return StreamStringError("argument #")
           << pos << " has not the expected number of dimension, got "
           << shape.size() << " expected " << input.shape.dimensions.size();
  }
  for (size_t i = 0; i < shape.size(); i++) {
    if (shape[i] != input.shape.dimensions[i]) {
      return StreamStringError("argument #")
             << pos << " has not the expected dimension #" << i << " , got "
             << shape[i] << " expected " << input.shape.dimensions[i];
    }
  }
  if (input.encryption.hasValue()) {
    // Encrypted tensor: for now we support only 8 bits for encrypted tensor
    if (width != 8) {
      return StreamStringError("argument #")
             << pos << " width mismatch, expected 8 got " << width;
    }
    const uint8_t *data8 = (const uint8_t *)data;
    // Allocate a buffer for ciphertexts of size of tensor
    auto ctBuffer = (LweCiphertext_u64 **)malloc(input.shape.size *
                                                 sizeof(LweCiphertext_u64 *));
    ciphertextBuffers.push_back(ctBuffer);
    // Allocate ciphertexts and encrypt, for every values in tensor
    for (size_t i = 0; i < input.shape.size; i++) {
      if (auto err = this->keySet.allocate_lwe(pos, &ctBuffer[i])) {
        return err;
      }
      allocatedCiphertexts.push_back(ctBuffer[i]);
      if (auto err = this->keySet.encrypt_lwe(pos, ctBuffer[i], data8[i])) {
        return err;
      }
    }
    // Replace the data by the buffer to ciphertext
    data = (void *)ctBuffer;
  }
  // Memref expansion follows.
  // allocated
  preparedArgs.push_back(nullptr);
  // aligned
  preparedArgs.push_back(data);
  // offset
  preparedArgs.push_back((void *)0);
  // sizes
  for (size_t i = 0; i < shape.size(); i++) {
    preparedArgs.push_back((void *)shape[i]);
  }
  // strides - FIXME make it works
  // strides is an array of size equals to numDim
  for (size_t i = 0; i < shape.size(); i++) {
    preparedArgs.push_back((void *)0);
  }
  return llvm::Error::success();
}
// Append the runtime context pointer; must happen after every declared
// input argument has been pushed.
llvm::Error Arguments::pushContext() {
  if (currentPos >= keySet.numInputs()) {
    preparedArgs.push_back(&context);
    return llvm::Error::success();
  }
  return StreamStringError("Missing arguments");
}
// Reject a push once the function's declared arity has been reached.
llvm::Error Arguments::checkPushTooManyArgs() {
  const size_t arity = keySet.numInputs();
  if (currentPos >= arity) {
    return StreamStringError("function has arity ")
           << arity << " but is applied to too many arguments";
  }
  return llvm::Error::success();
}
} // namespace concretelang
} // namespace mlir

View File

@@ -0,0 +1,16 @@
# TestLib support library: helpers to load and invoke lambdas from shared
# libraries produced by the compiler (links client-lib for key handling).
add_mlir_library(ConcretelangTestLib
  Arguments.cpp
  DynamicLambda.cpp
  DynamicModule.cpp

  ADDITIONAL_HEADER_DIRS
  ${PROJECT_SOURCE_DIR}/include/concretelang/TestLib

  DEPENDS
  MLIRConversionPassIncGen

  LINK_LIBS PUBLIC
  ConcretelangSupport
  ConcretelangClientLib
)

View File

@@ -0,0 +1,224 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/homomorphizer/blob/master/LICENSE.txt for license
// information.
#include <dlfcn.h>
#include "concretelang/Support/CompilerEngine.h"
#include "concretelang/Support/Error.h"
#include "concretelang/TestLib/DynamicLambda.h"
#include "concretelang/TestLib/dynamicArityCall.h"
namespace mlir {
namespace concretelang {
// Mirror of the MLIR memref descriptor ABI for a rank-N memref of
// LweCiphertext_u64*: base pointers (allocated/aligned), element offset,
// then N sizes and N strides. The field order and types must match the
// compiled code's return convention exactly — do not reorder.
template <size_t N> struct MemRefDescriptor {
  LweCiphertext_u64 **allocated;
  LweCiphertext_u64 **aligned;
  size_t offset;
  size_t sizes[N];
  size_t strides[N];
};
// Decrypts `size` ciphertexts starting at aligned[start], stepping by
// `stride`. A stride of 0 is normalized to 1 because the current pushArg
// fills strides with 0 (see the FIXME there). Uses output gate 0's key.
llvm::Expected<std::vector<uint64_t>> decryptSlice(LweCiphertext_u64 **aligned,
                                                   KeySet &keySet, size_t start,
                                                   size_t size,
                                                   size_t stride = 1) {
  stride = (stride == 0) ? 1 : stride;
  std::vector<uint64_t> result(size);
  for (size_t i = 0; i < size; i++) {
    size_t offset = start + i * stride;
    auto err = keySet.decrypt_lwe(0, aligned[offset], result[i]);
    if (err) {
      // NOTE(review): streaming an llvm::Error presumably consumes it via a
      // project operator<< — confirm; an unconsumed Error aborts on destroy.
      return StreamStringError()
             << "cannot decrypt result #" << i << ", err:" << err;
    }
  }
  return result;
}
// Convenience overload: opens the module found at `outputLib`, then
// resolves `funcName` inside it.
llvm::Expected<mlir::concretelang::DynamicLambda>
DynamicLambda::load(std::string funcName, std::string outputLib) {
  auto moduleOrError = mlir::concretelang::DynamicModule::open(outputLib);
  if (!moduleOrError) {
    return moduleOrError.takeError();
  }
  return mlir::concretelang::DynamicLambda::load(*moduleOrError, funcName);
}
// Resolves `funcName` in the module's shared library and pairs it with the
// matching client parameters. Only single encrypted outputs are supported.
llvm::Expected<DynamicLambda>
DynamicLambda::load(std::shared_ptr<DynamicModule> module,
                    std::string funcName) {
  DynamicLambda lambda;
  lambda.module =
      module; // prevent module and library handler from being destroyed
  // Clear any stale dl* error state: per POSIX, dlerror() returns the error
  // of the most recent dl* call, so without this reset the check below could
  // report an unrelated earlier failure.
  dlerror();
  lambda.func =
      (void *(*)(void *, ...))dlsym(module->libraryHandle, funcName.c_str());
  if (auto err = dlerror()) {
    return StreamStringError("Cannot open lambda: ") << err;
  }
  auto param =
      llvm::find_if(module->clientParametersList, [&](ClientParameters param) {
        return param.functionName == funcName;
      });
  if (param == module->clientParametersList.end()) {
    // Fixed missing separator space before "in client parameters".
    return StreamStringError("cannot find function ")
           << funcName << " in client parameters";
  }
  if (param->outputs.size() != 1) {
    return StreamStringError("DynamicLambda: output arity (")
           << std::to_string(param->outputs.size())
           << ") != 1 is not supported";
  }
  if (!param->outputs[0].encryption.hasValue()) {
    return StreamStringError(
        "DynamicLambda: clear output is not yet supported");
  }
  lambda.clientParameters = *param;
  return lambda;
}
// Specialization for a single encrypted scalar result: calls the compiled
// function and decrypts the returned ciphertext with output key 0.
template <>
llvm::Expected<uint64_t> invoke<uint64_t>(DynamicLambda &lambda,
                                          const Arguments &args) {
  auto output = lambda.clientParameters.outputs[0];
  if (output.shape.size != 0) {
    return StreamStringError("the function doesn't return a scalar");
  }
  // Scalar encrypted result. (Removed a stray empty statement `;` that
  // followed this cast in the original.)
  auto fCasted = (LweCiphertext_u64 * (*)(void *...))(lambda.func);
  LweCiphertext_u64 *lweResult =
      mlir::concretelang::call(fCasted, args.preparedArgs);
  uint64_t decryptedResult;
  // NOTE(review): lweResult is not freed here — possible leak; confirm
  // whether ownership stays with the compiled runtime.
  if (auto err = lambda.keySet->decrypt_lwe(0, lweResult, decryptedResult)) {
    return std::move(err);
  }
  return decryptedResult;
}
// Calls the compiled function and interprets its return value as a
// rank-`Rank` memref descriptor of ciphertexts, checking every returned
// dimension against the client parameters. (The "Decriptor" spelling is
// part of the declared interface in the header.)
template <size_t Rank>
llvm::Expected<MemRefDescriptor<Rank>>
DynamicLambda::invokeMemRefDecriptor(const Arguments &args) {
  auto output = clientParameters.outputs[0];
  if (output.shape.size == 0) {
    return StreamStringError("the function doesn't return a tensor");
  }
  if (output.shape.dimensions.size() != Rank) {
    return StreamStringError("the function doesn't return a tensor of rank ")
           << Rank;
  }
  // Tensor encrypted result
  auto fCasted = (MemRefDescriptor<Rank>(*)(void *...))(func);
  auto encryptedResult = mlir::concretelang::call(fCasted, args.preparedArgs);
  for (size_t dim = 0; dim < Rank; dim++) {
    size_t actual_size = encryptedResult.sizes[dim];
    size_t expected_size = output.shape.dimensions[dim];
    if (actual_size != expected_size) {
      return StreamStringError("the function returned a vector of size ")
             << actual_size << " instead of size " << expected_size;
    }
  }
  return encryptedResult;
}
// Specialization for a rank-1 encrypted tensor result: decrypt the whole
// slice described by the returned memref descriptor.
template <>
llvm::Expected<std::vector<uint64_t>>
invoke<std::vector<uint64_t>>(DynamicLambda &lambda, const Arguments &args) {
  auto descriptorOrErr = lambda.invokeMemRefDecriptor<1>(args);
  if (!descriptorOrErr) {
    return descriptorOrErr.takeError();
  }
  auto &descriptor = descriptorOrErr.get();
  return decryptSlice(descriptor.aligned, *lambda.keySet, descriptor.offset,
                      descriptor.sizes[0], descriptor.strides[0]);
}
// Specialization for a rank-2 encrypted tensor result: decrypt row by row.
template <>
llvm::Expected<std::vector<std::vector<uint64_t>>>
invoke<std::vector<std::vector<uint64_t>>>(DynamicLambda &lambda,
                                           const Arguments &args) {
  auto encryptedResultOrErr = lambda.invokeMemRefDecriptor<2>(args);
  if (!encryptedResultOrErr) {
    return encryptedResultOrErr.takeError();
  }
  auto &encryptedResult = encryptedResultOrErr.get();
  auto &keySet = lambda.keySet;
  std::vector<std::vector<uint64_t>> result;
  result.reserve(encryptedResult.sizes[0]);
  for (size_t i = 0; i < encryptedResult.sizes[0]; i++) {
    // TODO : strides
    // size_t (was int): offsets are unsigned and could overflow int for
    // large tensors.
    size_t offset = encryptedResult.offset + i * encryptedResult.sizes[1];
    auto slice =
        decryptSlice(encryptedResult.aligned, *keySet, offset,
                     encryptedResult.sizes[1], encryptedResult.strides[1]);
    if (!slice) {
      return StreamStringError(llvm::toString(slice.takeError()));
    }
    result.push_back(std::move(slice.get()));
  }
  return result;
}
// Specialization for a rank-3 encrypted tensor result: decrypt innermost
// slices, building the nested vectors outside-in.
template <>
llvm::Expected<std::vector<std::vector<std::vector<uint64_t>>>>
invoke<std::vector<std::vector<std::vector<uint64_t>>>>(DynamicLambda &lambda,
                                                        const Arguments &args) {
  auto encryptedResultOrErr = lambda.invokeMemRefDecriptor<3>(args);
  if (!encryptedResultOrErr) {
    return encryptedResultOrErr.takeError();
  }
  auto &encryptedResult = encryptedResultOrErr.get();
  auto &keySet = lambda.keySet;
  std::vector<std::vector<std::vector<uint64_t>>> result0;
  result0.reserve(encryptedResult.sizes[0]);
  for (size_t i = 0; i < encryptedResult.sizes[0]; i++) {
    std::vector<std::vector<uint64_t>> result1;
    result1.reserve(encryptedResult.sizes[1]);
    for (size_t j = 0; j < encryptedResult.sizes[1]; j++) {
      // TODO : strides
      // size_t (was int): offsets are unsigned and could overflow int for
      // large tensors.
      size_t offset = encryptedResult.offset +
                      i * encryptedResult.sizes[1] * encryptedResult.sizes[2] +
                      j * encryptedResult.sizes[2];
      auto slice =
          decryptSlice(encryptedResult.aligned, *keySet, offset,
                       encryptedResult.sizes[2], encryptedResult.strides[2]);
      if (!slice) {
        return StreamStringError(llvm::toString(slice.takeError()));
      }
      result1.push_back(std::move(slice.get()));
    }
    result0.push_back(std::move(result1));
  }
  return result0;
}
// Builds the key set matching clientParameters, going through the on-disk
// cache when one is provided.
llvm::Error DynamicLambda::generateKeySet(llvm::Optional<KeySetCache> cache,
                                          uint64_t seed_msb,
                                          uint64_t seed_lsb) {
  if (!cache.hasValue()) {
    auto generated = KeySet::generate(clientParameters, seed_msb, seed_lsb);
    if (auto err = generated.takeError()) {
      return err;
    }
    keySet = std::move(generated.get());
    return llvm::Error::success();
  }
  auto cached =
      cache->tryLoadOrGenerateSave(clientParameters, seed_msb, seed_lsb);
  if (auto err = cached.takeError()) {
    return err;
  }
  keySet = std::move(cached.get());
  return llvm::Error::success();
}
} // namespace concretelang
} // namespace mlir

View File

@@ -0,0 +1,61 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/homomorphizer/blob/master/LICENSE.txt for license
// information.
#include <dlfcn.h>
#include <fstream>
#include "concretelang/Support/CompilerEngine.h"
#include "concretelang/Support/Error.h"
#include "concretelang/TestLib/DynamicModule.h"
namespace mlir {
namespace concretelang {
// Close the shared library handle when one was successfully opened.
DynamicModule::~DynamicModule() {
  if (libraryHandle == nullptr) {
    return;
  }
  dlclose(libraryHandle);
}
// Loads the client parameters JSON and then the shared library found at
// `path`. Returns an owning shared_ptr, or the first error encountered.
llvm::Expected<std::shared_ptr<DynamicModule>>
DynamicModule::open(std::string path) {
  std::shared_ptr<DynamicModule> module = std::make_shared<DynamicModule>();
  if (auto err = module->loadClientParametersJSON(path)) {
    return StreamStringError("Cannot load client parameters: ")
           << llvm::toString(std::move(err));
  }
  if (auto err = module->loadSharedLibrary(path)) {
    // Was a copy/paste of the message above; report the actual failing step.
    return StreamStringError("Cannot load shared library: ")
           << llvm::toString(std::move(err));
  }
  return module;
}
// dlopen the compiled shared library next to `path` (lazy binding).
llvm::Error DynamicModule::loadSharedLibrary(std::string path) {
  libraryHandle = dlopen(
      CompilerEngine::Library::getSharedLibraryPath(path).c_str(), RTLD_LAZY);
  if (!libraryHandle) {
    // Added the ": " separator so the dlerror text doesn't run into the
    // message.
    return StreamStringError("Cannot open shared library: ") << dlerror();
  }
  return llvm::Error::success();
}
// Reads and parses the client-parameters JSON file associated with `path`.
llvm::Error DynamicModule::loadClientParametersJSON(std::string path) {
  auto jsonPath = CompilerEngine::Library::getClientParametersPath(path);
  std::ifstream file(jsonPath);
  if (!file.is_open()) {
    // Previously a missing file fell through to the JSON parser with empty
    // content, producing a confusing parse error; fail early instead.
    return StreamStringError("Cannot open client parameters file: ")
           << jsonPath;
  }
  std::string content((std::istreambuf_iterator<char>(file)),
                      (std::istreambuf_iterator<char>()));
  llvm::Expected<std::vector<ClientParameters>> expectedClientParams =
      llvm::json::parse<std::vector<ClientParameters>>(content);
  if (auto err = expectedClientParams.takeError()) {
    return StreamStringError("Cannot open client parameters: ") << err;
  }
  this->clientParametersList = *expectedClientParams;
  return llvm::Error::success();
}
} // namespace concretelang
} // namespace mlir

View File

@@ -22,6 +22,7 @@ if(CONCRETELANG_PARALLEL_EXECUTION_ENABLED)
RTDialect
ConcretelangSupport
ConcretelangTestLib
-Wl,-rpath,${CMAKE_BINARY_DIR}/lib/Runtime
-Wl,-rpath,${HPX_DIR}/../../

View File

@@ -1,4 +1,5 @@
if (CONCRETELANG_UNIT_TESTS)
add_subdirectory(unittest)
add_subdirectory(Support)
add_subdirectory(TestLib)
endif()

View File

@@ -0,0 +1,24 @@
# Unit-test executable for the TestLib; discovered and run through CTest.
enable_testing()

include_directories(${PROJECT_SOURCE_DIR}/include)

add_executable(
  testlib_unit_test
  testlib_unit_test.cpp
)

# gtest is built without RTTI here, so match it to avoid link issues.
set_source_files_properties(
  testlib_unit_test.cpp
  PROPERTIES COMPILE_FLAGS "-fno-rtti"
)

target_link_libraries(
  testlib_unit_test
  gtest_main
  ConcretelangRuntime
  ConcretelangTestLib
)

include(GoogleTest)
gtest_discover_tests(testlib_unit_test)

View File

View File

@@ -0,0 +1,241 @@
#include <gtest/gtest.h>
#include <numeric>
#include <cassert>
#include "../unittest/end_to_end_jit_test.h"
#include "concretelang/TestLib/DynamicLambda.h"
const std::string FUNCNAME = "main";
template<typename... Params>
using TypedDynamicLambda = mlir::concretelang::TypedDynamicLambda<Params...>;
using scalar = uint64_t;
using tensor1_in = std::vector<uint8_t>;
using tensor1_out = std::vector<uint64_t>;
using tensor2_out = std::vector<std::vector<uint64_t>>;
using tensor3_out = std::vector<std::vector<std::vector<uint64_t>>>;
// Sample inputs spanning the 7-bit range, including the power-of-two
// boundary values 63/64/65 and near-max values 125/126.
std::vector<uint8_t> values_7bits() {
  std::vector<uint8_t> samples{0, 1, 2, 63, 64, 65, 125, 126};
  return samples;
}
// Compiles `source` into a shared library rooted at `outputLib`, exporting
// client parameters for FUNCNAME; returns the compiled library description.
llvm::Expected<mlir::concretelang::CompilerEngine::Library>
compile(std::string outputLib, std::string source) {
  std::vector<std::string> sources = {source};
  std::shared_ptr<mlir::concretelang::CompilationContext> ccx =
      mlir::concretelang::CompilationContext::createShared();
  mlir::concretelang::JitCompilerEngine ce{ccx};
  ce.setClientParametersFuncName(FUNCNAME);
  return ce.compile(sources, outputLib);
}
// Derives a per-test output directory from a gtest TestInfo-like object
// exposing name().
template <typename Info> std::string outputLibFromThis(Info *info) {
  const std::string prefix = "tests/TestLib/out/";
  return prefix + std::string(info->name());
}
// End-to-end: identity circuit on one encrypted 7-bit scalar —
// decrypt(eval(encrypt(a))) must equal a for every sample value.
TEST(CompiledModule, call_1s_1s) {
  std::string source = R"(
func @main(%arg0: !FHE.eint<7>) -> !FHE.eint<7> {
return %arg0: !FHE.eint<7>
}
)";
  std::string outputLib = outputLibFromThis(this->test_info_);
  auto compiled = compile(outputLib, source);
  ASSERT_EXPECTED_SUCCESS(compiled);
  auto lambda = TypedDynamicLambda<scalar, scalar>::load(FUNCNAME, outputLib);
  ASSERT_EXPECTED_SUCCESS(lambda);
  ASSERT_LLVM_ERROR(lambda->generateKeySet(getTestKeySetCache()));
  for (auto a : values_7bits()) {
    auto res = lambda->call(a);
    ASSERT_EXPECTED_VALUE(res, a);
  }
}
// Homomorphic addition of two encrypted scalars; checks a + b for the
// cartesian product of the sample values (sums stay within 8 bits).
TEST(CompiledModule, call_2s_1s) {
  std::string source = R"(
func @main(%arg0: !FHE.eint<7>, %arg1: !FHE.eint<7>) -> !FHE.eint<7> {
%1 = "FHE.add_eint"(%arg0, %arg1): (!FHE.eint<7>, !FHE.eint<7>) -> (!FHE.eint<7>)
return %1: !FHE.eint<7>
}
)";
  std::string outputLib = outputLibFromThis(this->test_info_);
  auto compiled = compile(outputLib, source);
  ASSERT_EXPECTED_SUCCESS(compiled);
  auto lambda =
      TypedDynamicLambda<scalar, scalar, scalar>::load(FUNCNAME, outputLib);
  ASSERT_EXPECTED_SUCCESS(lambda);
  ASSERT_LLVM_ERROR(lambda->generateKeySet(getTestKeySetCache()));
  for (auto a : values_7bits())
    for (auto b : values_7bits()) {
      auto res = lambda->call(a, b);
      ASSERT_EXPECTED_VALUE(res, a + b);
    }
}
// Scalar in, 1-element tensor out: the returned tensor must wrap the input.
TEST(CompiledModule, call_1s_1t) {
  std::string source = R"(
func @main(%arg0: !FHE.eint<7>) -> tensor<1x!FHE.eint<7>> {
%1 = tensor.from_elements %arg0 : tensor<1x!FHE.eint<7>>
return %1: tensor<1x!FHE.eint<7>>
}
)";
  std::string outputLib = outputLibFromThis(this->test_info_);
  auto compiled = compile(outputLib, source);
  ASSERT_EXPECTED_SUCCESS(compiled);
  auto lambda =
      TypedDynamicLambda<tensor1_out, scalar>::load(FUNCNAME, outputLib);
  ASSERT_EXPECTED_SUCCESS(lambda);
  ASSERT_LLVM_ERROR(lambda->generateKeySet(getTestKeySetCache()));
  for (auto a : values_7bits()) {
    auto res = lambda->call(a);
    ASSERT_EXPECTED_SUCCESS(res);
    tensor1_out v = res.get();
    EXPECT_EQ(v[0], a);
  }
}
// Two scalars in, 2-element tensor out: checks element order is preserved.
TEST(CompiledModule, call_2s_1t) {
  std::string source = R"(
func @main(%arg0: !FHE.eint<7>, %arg1: !FHE.eint<7>) -> tensor<2x!FHE.eint<7>> {
%1 = tensor.from_elements %arg0, %arg1 : tensor<2x!FHE.eint<7>>
return %1: tensor<2x!FHE.eint<7>>
}
)";
  std::string outputLib = outputLibFromThis(this->test_info_);
  auto compiled = compile(outputLib, source);
  ASSERT_EXPECTED_SUCCESS(compiled);
  auto lambda = TypedDynamicLambda<tensor1_out, scalar, scalar>::load(
      FUNCNAME, outputLib);
  ASSERT_EXPECTED_SUCCESS(lambda);
  ASSERT_LLVM_ERROR(lambda->generateKeySet(getTestKeySetCache()));
  for (auto a : values_7bits()) {
    auto res = lambda->call(a, a + 1);
    ASSERT_EXPECTED_SUCCESS(res);
    tensor1_out v = res.get();
    EXPECT_EQ((scalar)v[0], a);
    EXPECT_EQ((scalar)v[1], a + 1u);
  }
}
// 1-element tensor in, scalar out: extracts element 0 of the input tensor.
TEST(CompiledModule, call_1t_1s) {
  std::string source = R"(
func @main(%arg0: tensor<1x!FHE.eint<7>>) -> !FHE.eint<7> {
%c0 = arith.constant 0 : index
%1 = tensor.extract %arg0[%c0] : tensor<1x!FHE.eint<7>>
return %1: !FHE.eint<7>
}
)";
  std::string outputLib = outputLibFromThis(this->test_info_);
  auto compiled = compile(outputLib, source);
  ASSERT_EXPECTED_SUCCESS(compiled);
  auto lambda =
      TypedDynamicLambda<scalar, tensor1_in>::load(FUNCNAME, outputLib);
  ASSERT_EXPECTED_SUCCESS(lambda);
  ASSERT_LLVM_ERROR(lambda->generateKeySet(getTestKeySetCache()));
  for (uint8_t a : values_7bits()) {
    tensor1_in ta = {a};
    auto res = lambda->call(ta);
    ASSERT_EXPECTED_VALUE(res, a);
  }
}
// Identity on a rank-1 tensor: the decrypted output must equal the input
// element-wise.
TEST(CompiledModule, call_1t_1t) {
  std::string source = R"(
func @main(%arg0: tensor<3x!FHE.eint<7>>) -> tensor<3x!FHE.eint<7>> {
return %arg0: tensor<3x!FHE.eint<7>>
}
)";
  std::string outputLib = outputLibFromThis(this->test_info_);
  auto compiled = compile(outputLib, source);
  ASSERT_EXPECTED_SUCCESS(compiled);
  auto lambda =
      TypedDynamicLambda<tensor1_out, tensor1_in>::load(FUNCNAME, outputLib);
  ASSERT_EXPECTED_SUCCESS(lambda);
  ASSERT_LLVM_ERROR(lambda->generateKeySet(getTestKeySetCache()));
  tensor1_in ta = {1, 2, 3};
  auto res = lambda->call(ta);
  ASSERT_EXPECTED_SUCCESS(res);
  tensor1_out v = res.get();
  for (size_t i = 0; i < v.size(); i++) {
    EXPECT_EQ(v[i], ta[i]);
  }
}
// Tensor add followed by a dot product with an all-ones clear vector:
// the result is the sum of both input tensors' elements.
TEST(CompiledModule, call_2t_1s) {
  std::string source = R"(
func @main(%arg0: tensor<3x!FHE.eint<7>>, %arg1: tensor<3x!FHE.eint<7>>) -> !FHE.eint<7> {
%1 = "FHELinalg.add_eint"(%arg0, %arg1) : (tensor<3x!FHE.eint<7>>, tensor<3x!FHE.eint<7>>) -> tensor<3x!FHE.eint<7>>
%c1 = arith.constant 1 : i8
%2 = tensor.from_elements %c1, %c1, %c1 : tensor<3xi8>
%3 = "FHELinalg.dot_eint_int"(%1, %2) : (tensor<3x!FHE.eint<7>>, tensor<3xi8>) -> !FHE.eint<7>
return %3: !FHE.eint<7>
}
)";
  std::string outputLib = outputLibFromThis(this->test_info_);
  auto compiled = compile(outputLib, source);
  ASSERT_EXPECTED_SUCCESS(compiled);
  auto lambda =
      TypedDynamicLambda<scalar, tensor1_in, std::array<uint8_t, 3>>::load(
          FUNCNAME, outputLib);
  ASSERT_EXPECTED_SUCCESS(lambda);
  ASSERT_LLVM_ERROR(lambda->generateKeySet(getTestKeySetCache()));
  tensor1_in ta{1, 2, 3};
  std::array<uint8_t, 3> tb{5, 7, 9};
  auto res = lambda->call(ta, tb);
  // Expected: sum of all elements of both tensors (accumulated as unsigned).
  auto expected = std::accumulate(ta.begin(), ta.end(), 0u) +
                  std::accumulate(tb.begin(), tb.end(), 0u);
  ASSERT_EXPECTED_VALUE(res, expected);
}
// Round-trip of a 2x3 encrypted tensor through an identity circuit.
TEST(CompiledModule, call_1tr2_1tr2) {
  std::string source = R"(
func @main(%arg0: tensor<2x3x!FHE.eint<7>>) -> tensor<2x3x!FHE.eint<7>> {
return %arg0: tensor<2x3x!FHE.eint<7>>
}
)";
  using tensor2_in = std::array<std::array<uint8_t, 3>, 2>;
  std::string outputLib = outputLibFromThis(this->test_info_);
  auto compiled = compile(outputLib, source);
  ASSERT_EXPECTED_SUCCESS(compiled);
  auto lambda =
      TypedDynamicLambda<tensor2_out, tensor2_in>::load(FUNCNAME, outputLib);
  ASSERT_EXPECTED_SUCCESS(lambda);
  ASSERT_LLVM_ERROR(lambda->generateKeySet(getTestKeySetCache()));
  tensor2_in ta = {{
      {1, 2, 3},
      {4, 5, 6}
  }};
  auto res = lambda->call(ta);
  ASSERT_EXPECTED_SUCCESS(res);
  tensor2_out v = res.get();
  for (size_t i = 0; i < v.size(); i++) {
    // BUG FIX: the inner bound was `v.size()` (the row count, 2), which
    // silently skipped the last element of every 3-element row; iterate over
    // the row's own size instead.
    for (size_t j = 0; j < v[i].size(); j++) {
      EXPECT_EQ(v[i][j], ta[i][j]);
    }
  }
}
// Round-trip of a 2x3x1 encrypted tensor through an identity circuit; the
// nested loops use each level's own size, covering every element.
TEST(CompiledModule, call_1tr3_1tr3) {
  std::string source = R"(
func @main(%arg0: tensor<2x3x1x!FHE.eint<7>>) -> tensor<2x3x1x!FHE.eint<7>> {
return %arg0: tensor<2x3x1x!FHE.eint<7>>
}
)";
  using tensor3_in = std::array<std::array<std::array<uint8_t, 1>, 3>, 2>;
  std::string outputLib = outputLibFromThis(this->test_info_);
  auto compiled = compile(outputLib, source);
  ASSERT_EXPECTED_SUCCESS(compiled);
  auto lambda =
      TypedDynamicLambda<tensor3_out, tensor3_in>::load(FUNCNAME, outputLib);
  ASSERT_EXPECTED_SUCCESS(lambda);
  ASSERT_LLVM_ERROR(lambda->generateKeySet(getTestKeySetCache()));
  tensor3_in ta = {{
      {{ {1}, {2}, {3} }},
      {{ {4}, {5}, {6} }}
  }};
  auto res = lambda->call(ta);
  ASSERT_EXPECTED_SUCCESS(res);
  tensor3_out v = res.get();
  for (size_t i = 0; i < v.size(); i++) {
    for (size_t j = 0; j < v[i].size(); j++) {
      for (size_t k = 0; k < v[i][j].size(); k++) {
        EXPECT_EQ(v[i][j][k], ta[i][j][k]);
      }
    }
  }
}

View File

@@ -93,6 +93,17 @@ static bool assert_expected_value(llvm::Expected<T> &&val, const V &exp) {
} \
} while (0)
// Shared on-disk key-set cache under the system temp directory, so unit
// tests reuse generated keys across runs instead of regenerating them.
static inline llvm::Optional<mlir::concretelang::KeySetCache>
getTestKeySetCache() {
  llvm::SmallString<0> tempDir;
  llvm::sys::path::system_temp_directory(true, tempDir);
  llvm::sys::path::append(tempDir, "KeySetCache");
  auto cacheLocation = std::string(tempDir);
  return llvm::Optional<mlir::concretelang::KeySetCache>(
      mlir::concretelang::KeySetCache(cacheLocation));
}
// Jit-compiles the function specified by `func` from `src` and
// returns the corresponding lambda. Any compilation errors are caught
// and reult in abnormal termination.
@@ -103,16 +114,6 @@ internalCheckedJit(F checkFunc, llvm::StringRef src,
bool useDefaultFHEConstraints = false,
bool autoParallelize = false) {
llvm::SmallString<0> cachePath;
llvm::sys::path::system_temp_directory(true, cachePath);
llvm::sys::path::append(cachePath, "KeySetCache");
auto cachePathStr = std::string(cachePath);
auto optCache = llvm::Optional<mlir::concretelang::KeySetCache>(
mlir::concretelang::KeySetCache(cachePathStr));
mlir::concretelang::JitCompilerEngine engine;
if (useDefaultFHEConstraints)
@@ -124,7 +125,7 @@ internalCheckedJit(F checkFunc, llvm::StringRef src,
#endif
llvm::Expected<mlir::concretelang::JitCompilerEngine::Lambda> lambdaOrErr =
engine.buildLambda(src, func, optCache);
engine.buildLambda(src, func, getTestKeySetCache());
if (!lambdaOrErr) {
std::cout << llvm::toString(lambdaOrErr.takeError()) << std::endl;