mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
feat(Testlib): lib for testing libs generated by concretecompiler
Closes #201
This commit is contained in:
3
.gitignore
vendored
3
.gitignore
vendored
@@ -48,3 +48,6 @@ _build/
|
||||
|
||||
# macOS
|
||||
.DS_Store
|
||||
|
||||
|
||||
compiler/tests/TestLib/out/
|
||||
|
||||
@@ -69,12 +69,18 @@ test-check: concretecompiler file-check not
|
||||
test-python: python-bindings concretecompiler
|
||||
PYTHONPATH=${PYTHONPATH}:$(BUILD_DIR)/tools/concretelang/python_packages/concretelang_core LD_PRELOAD=$(BUILD_DIR)/lib/libConcretelangRuntime.so pytest -vs tests/python
|
||||
|
||||
test: test-check test-end-to-end-jit test-python support-unit-test
|
||||
test: test-check test-end-to-end-jit test-python support-unit-test testlib-unit-test
|
||||
|
||||
test-dataflow: test-end-to-end-jit-dfr test-end-to-end-jit-auto-parallelization
|
||||
|
||||
# unit-test
|
||||
|
||||
testlib-unit-test: build-testlib-unit-test
|
||||
$(BUILD_DIR)/bin/testlib_unit_test
|
||||
|
||||
build-testlib-unit-test: build-initialized
|
||||
cmake --build $(BUILD_DIR) --target testlib_unit_test
|
||||
|
||||
support-unit-test: build-support-unit-test
|
||||
$(BUILD_DIR)/bin/support_unit_test
|
||||
|
||||
|
||||
@@ -15,6 +15,8 @@
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
|
||||
size_t bitWidthAsWord(size_t exactBitWidth);
|
||||
|
||||
/// JITLambda is a tool to JIT compile an mlir module and to invoke a function
|
||||
/// of the module.
|
||||
class JITLambda {
|
||||
|
||||
105
compiler/include/concretelang/TestLib/Arguments.h
Normal file
105
compiler/include/concretelang/TestLib/Arguments.h
Normal file
@@ -0,0 +1,105 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/homomorphizer/blob/master/LICENSE.txt for license
|
||||
// information.
|
||||
|
||||
#ifndef CONCRETELANG_TESTLIB_ARGUMENTS_H
|
||||
#define CONCRETELANG_TESTLIB_ARGUMENTS_H
|
||||
|
||||
#include "concretelang/ClientLib/ClientParameters.h"
|
||||
#include "concretelang/ClientLib/KeySet.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
|
||||
class DynamicLambda;
|
||||
|
||||
class Arguments {
|
||||
public:
|
||||
Arguments(KeySet &keySet) : currentPos(0), keySet(keySet) {
|
||||
keySet.setRuntimeContext(context);
|
||||
}
|
||||
|
||||
~Arguments();
|
||||
|
||||
// Create EncryptedArgument that use the given KeySet to perform encryption
|
||||
// and decryption operations.
|
||||
static std::shared_ptr<Arguments> create(KeySet &keySet);
|
||||
|
||||
// Add a scalar argument.
|
||||
llvm::Error pushArg(uint64_t arg);
|
||||
|
||||
// Add a vector-tensor argument.
|
||||
llvm::Error pushArg(std::vector<uint8_t> arg);
|
||||
|
||||
template <size_t size> llvm::Error pushArg(std::array<uint8_t, size> arg) {
|
||||
return pushArg(8, (void *)arg.data(), {size});
|
||||
}
|
||||
|
||||
// Add a matrix-tensor argument.
|
||||
template <size_t size0, size_t size1>
|
||||
llvm::Error pushArg(std::array<std::array<uint8_t, size1>, size0> arg) {
|
||||
return pushArg(8, (void *)arg.data(), {size0, size1});
|
||||
}
|
||||
|
||||
// Add a rank3 tensor.
|
||||
template <size_t size0, size_t size1, size_t size2>
|
||||
llvm::Error pushArg(
|
||||
std::array<std::array<std::array<uint8_t, size2>, size1>, size0> arg) {
|
||||
return pushArg(8, (void *)arg.data(), {size0, size1, size2});
|
||||
}
|
||||
|
||||
// Generalize by computing shape by template recursion
|
||||
|
||||
// Set a argument at the given pos as a 1D tensor of T.
|
||||
template <typename T> llvm::Error pushArg(T *data, int64_t dim1) {
|
||||
return pushArg<T>(data, llvm::ArrayRef<int64_t>(&dim1, 1));
|
||||
}
|
||||
|
||||
// Set a argument at the given pos as a tensor of T.
|
||||
template <typename T>
|
||||
llvm::Error pushArg(T *data, llvm::ArrayRef<int64_t> shape) {
|
||||
return pushArg(8 * sizeof(T), static_cast<void *>(data), shape);
|
||||
}
|
||||
|
||||
llvm::Error pushArg(size_t width, void *data, llvm::ArrayRef<int64_t> shape);
|
||||
|
||||
// Push the runtime context to the argument list, this must be called
|
||||
// after each argument was pushed.
|
||||
llvm::Error pushContext();
|
||||
|
||||
template <typename Arg0, typename... OtherArgs>
|
||||
llvm::Error pushArgs(Arg0 arg0, OtherArgs... others) {
|
||||
auto err = pushArg(arg0);
|
||||
if (err) {
|
||||
return err;
|
||||
}
|
||||
return pushArgs(others...);
|
||||
}
|
||||
|
||||
llvm::Error pushArgs() { return pushContext(); }
|
||||
|
||||
private:
|
||||
friend DynamicLambda;
|
||||
template <typename Result>
|
||||
friend llvm::Expected<Result> invoke(DynamicLambda &lambda,
|
||||
const Arguments &args);
|
||||
llvm::Error checkPushTooManyArgs();
|
||||
|
||||
// Position of the next pushed argument
|
||||
size_t currentPos;
|
||||
std::vector<void *> preparedArgs;
|
||||
|
||||
// Store allocated lwe ciphertexts (for free)
|
||||
std::vector<LweCiphertext_u64 *> allocatedCiphertexts;
|
||||
// Store buffers of ciphertexts
|
||||
std::vector<LweCiphertext_u64 **> ciphertextBuffers;
|
||||
|
||||
KeySet &keySet;
|
||||
RuntimeContext context;
|
||||
};
|
||||
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
||||
123
compiler/include/concretelang/TestLib/DynamicLambda.h
Normal file
123
compiler/include/concretelang/TestLib/DynamicLambda.h
Normal file
@@ -0,0 +1,123 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/homomorphizer/blob/master/LICENSE.txt for license
|
||||
// information.
|
||||
|
||||
#ifndef CONCRETELANG_TESTLIB_DYNAMIC_LAMBDA_H
|
||||
#define CONCRETELANG_TESTLIB_DYNAMIC_LAMBDA_H
|
||||
|
||||
#include "concretelang/ClientLib/ClientParameters.h"
|
||||
#include "concretelang/ClientLib/KeySet.h"
|
||||
#include "concretelang/ClientLib/KeySetCache.h"
|
||||
#include "concretelang/TestLib/Arguments.h"
|
||||
#include "concretelang/TestLib/DynamicModule.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
|
||||
template <size_t N> struct MemRefDescriptor;
|
||||
|
||||
template <typename Result>
|
||||
llvm::Expected<Result> invoke(DynamicLambda &lambda, const Arguments &args) {
|
||||
// compile time error if used
|
||||
using COMPATIBLE_RESULT_TYPE = void;
|
||||
return (Result)(
|
||||
COMPATIBLE_RESULT_TYPE)0; // invoke does not accept this kind of Result
|
||||
}
|
||||
|
||||
template <>
|
||||
llvm::Expected<u_int64_t> invoke<u_int64_t>(DynamicLambda &lambda,
|
||||
const Arguments &args);
|
||||
|
||||
template <>
|
||||
llvm::Expected<std::vector<uint64_t>>
|
||||
invoke<std::vector<uint64_t>>(DynamicLambda &lambda, const Arguments &args);
|
||||
|
||||
template <>
|
||||
llvm::Expected<std::vector<std::vector<uint64_t>>>
|
||||
invoke<std::vector<std::vector<uint64_t>>>(DynamicLambda &lambda,
|
||||
const Arguments &args);
|
||||
|
||||
template <>
|
||||
llvm::Expected<std::vector<std::vector<std::vector<uint64_t>>>>
|
||||
invoke<std::vector<std::vector<std::vector<uint64_t>>>>(DynamicLambda &lambda,
|
||||
const Arguments &args);
|
||||
|
||||
class DynamicLambda {
|
||||
private:
|
||||
template <typename... Args>
|
||||
llvm::Expected<std::shared_ptr<Arguments>> createArguments(Args... args) {
|
||||
if (keySet == nullptr) {
|
||||
return StreamStringError("keySet was not initialized");
|
||||
}
|
||||
auto arg = Arguments::create(*keySet);
|
||||
auto err = arg->pushArgs(args...);
|
||||
if (err) {
|
||||
return StreamStringError(llvm::toString(std::move(err)));
|
||||
}
|
||||
return arg;
|
||||
}
|
||||
|
||||
public:
|
||||
static llvm::Expected<DynamicLambda> load(std::string funcName,
|
||||
std::string outputLib);
|
||||
|
||||
static llvm::Expected<DynamicLambda>
|
||||
load(std::shared_ptr<DynamicModule> module, std::string funcName);
|
||||
|
||||
template <typename Result, typename... Args>
|
||||
llvm::Expected<Result> call(Args... args) {
|
||||
auto argOrErr = createArguments(args...);
|
||||
if (!argOrErr) {
|
||||
return argOrErr.takeError();
|
||||
}
|
||||
auto arg = argOrErr.get();
|
||||
return invoke<Result>(*this, *arg);
|
||||
}
|
||||
|
||||
llvm::Error generateKeySet(llvm::Optional<KeySetCache> cache = llvm::None,
|
||||
uint64_t seed_msb = 0, uint64_t seed_lsb = 0);
|
||||
|
||||
protected:
|
||||
template <typename Result>
|
||||
friend llvm::Expected<Result> invoke(DynamicLambda &lambda,
|
||||
const Arguments &args);
|
||||
|
||||
template <size_t Rank>
|
||||
llvm::Expected<MemRefDescriptor<Rank>>
|
||||
invokeMemRefDecriptor(const Arguments &args);
|
||||
|
||||
ClientParameters clientParameters;
|
||||
std::shared_ptr<KeySet> keySet;
|
||||
void *(*func)(void *...);
|
||||
// Retain module and open shared lib alive
|
||||
std::shared_ptr<DynamicModule> module;
|
||||
};
|
||||
|
||||
template <typename Result, typename... Args>
|
||||
class TypedDynamicLambda : public DynamicLambda {
|
||||
|
||||
public:
|
||||
static llvm::Expected<TypedDynamicLambda<Result, Args...>>
|
||||
load(std::string funcName, std::string outputLib) {
|
||||
auto lambda = DynamicLambda::load(funcName, outputLib);
|
||||
if (!lambda) {
|
||||
return lambda.takeError();
|
||||
}
|
||||
return TypedDynamicLambda(*lambda);
|
||||
}
|
||||
|
||||
llvm::Expected<Result> call(Args... args) {
|
||||
return DynamicLambda::call<Result>(args...);
|
||||
}
|
||||
|
||||
// TODO: check parameter types
|
||||
TypedDynamicLambda(DynamicLambda &lambda) : DynamicLambda(lambda) {
|
||||
// TODO: add static check on types vs lambda inputs/outpus
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
||||
33
compiler/include/concretelang/TestLib/DynamicModule.h
Normal file
33
compiler/include/concretelang/TestLib/DynamicModule.h
Normal file
@@ -0,0 +1,33 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/homomorphizer/blob/master/LICENSE.txt for license
|
||||
// information.
|
||||
|
||||
#ifndef CONCRETELANG_TESTLIB_DYNAMIC_MODULE_H
|
||||
#define CONCRETELANG_TESTLIB_DYNAMIC_MODULE_H
|
||||
|
||||
#include "concretelang/ClientLib/ClientParameters.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
|
||||
class DynamicModule {
|
||||
public:
|
||||
~DynamicModule();
|
||||
static llvm::Expected<std::shared_ptr<DynamicModule>>
|
||||
open(std::string libraryPath);
|
||||
|
||||
private:
|
||||
llvm::Error loadClientParametersJSON(std::string path);
|
||||
llvm::Error loadSharedLibrary(std::string path);
|
||||
|
||||
private:
|
||||
std::vector<ClientParameters> clientParametersList;
|
||||
void *libraryHandle;
|
||||
|
||||
friend class DynamicLambda;
|
||||
};
|
||||
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
#endif
|
||||
1468
compiler/include/concretelang/TestLib/dynamicArityCall.h
Normal file
1468
compiler/include/concretelang/TestLib/dynamicArityCall.h
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,6 @@
|
||||
# Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions.
|
||||
# See https://github.com/zama-ai/homomorphizer/blob/master/LICENSE.txt for license information.
|
||||
|
||||
for i in range(128):
|
||||
args = ','.join(f'args[{j}]' for j in range(i))
|
||||
print(f' case {i}: return func({args});')
|
||||
@@ -4,6 +4,7 @@ add_subdirectory(Support)
|
||||
add_subdirectory(Runtime)
|
||||
add_subdirectory(ClientLib)
|
||||
add_subdirectory(Bindings)
|
||||
add_subdirectory(TestLib)
|
||||
|
||||
# CAPI needed only for python bindings
|
||||
if (CONCRETELANG_BINDINGS_PYTHON_ENABLED)
|
||||
|
||||
161
compiler/lib/TestLib/Arguments.cpp
Normal file
161
compiler/lib/TestLib/Arguments.cpp
Normal file
@@ -0,0 +1,161 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/homomorphizer/blob/master/LICENSE.txt for license
|
||||
// information.
|
||||
|
||||
#include "concretelang/TestLib/Arguments.h"
|
||||
#include "concretelang/Support/CompilerEngine.h"
|
||||
#include "concretelang/Support/Error.h"
|
||||
#include "concretelang/Support/Jit.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
|
||||
Arguments::~Arguments() {
|
||||
for (auto ct : allocatedCiphertexts) {
|
||||
int err;
|
||||
free_lwe_ciphertext_u64(&err, ct);
|
||||
}
|
||||
for (auto ctBuffer : ciphertextBuffers) {
|
||||
free(ctBuffer);
|
||||
}
|
||||
}
|
||||
|
||||
std::shared_ptr<Arguments> Arguments::create(KeySet &keySet) {
|
||||
auto args = std::make_shared<Arguments>(keySet);
|
||||
return args;
|
||||
}
|
||||
|
||||
llvm::Error Arguments::pushArg(uint64_t arg) {
|
||||
if (auto err = checkPushTooManyArgs()) {
|
||||
return err;
|
||||
}
|
||||
|
||||
auto pos = currentPos++;
|
||||
CircuitGate input = keySet.inputGate(pos);
|
||||
if (input.shape.size != 0) {
|
||||
return StreamStringError("argument #") << pos << " is not a scalar";
|
||||
}
|
||||
if (!input.encryption.hasValue()) {
|
||||
// clear scalar: just push the argument
|
||||
if (input.shape.width != 64) {
|
||||
return StreamStringError(
|
||||
"scalar argument of with != 64 is not supported for DynamicLambda");
|
||||
}
|
||||
preparedArgs.push_back((void *)arg);
|
||||
return llvm::Error::success();
|
||||
}
|
||||
// encrypted scalar: allocate, encrypt and push
|
||||
LweCiphertext_u64 *ctArg;
|
||||
if (auto err = keySet.allocate_lwe(pos, &ctArg)) {
|
||||
return err;
|
||||
}
|
||||
allocatedCiphertexts.push_back(ctArg);
|
||||
if (auto err = keySet.encrypt_lwe(pos, ctArg, arg)) {
|
||||
return err;
|
||||
}
|
||||
preparedArgs.push_back((void *)ctArg);
|
||||
return llvm::Error::success();
|
||||
}
|
||||
|
||||
llvm::Error Arguments::pushArg(std::vector<uint8_t> arg) {
|
||||
return pushArg(8, (void *)arg.data(), {(int64_t)arg.size()});
|
||||
}
|
||||
|
||||
llvm::Error Arguments::pushArg(size_t width, void *data,
|
||||
llvm::ArrayRef<int64_t> shape) {
|
||||
if (auto err = checkPushTooManyArgs()) {
|
||||
return err;
|
||||
}
|
||||
auto pos = currentPos;
|
||||
currentPos = currentPos + 1;
|
||||
CircuitGate input = keySet.inputGate(pos);
|
||||
// Check the width of data
|
||||
if (input.shape.width > 64) {
|
||||
return StreamStringError("argument #")
|
||||
<< pos << " width > 64 bits is not supported";
|
||||
}
|
||||
auto roundedSize = bitWidthAsWord(input.shape.width);
|
||||
if (width != roundedSize) {
|
||||
return StreamStringError("argument #")
|
||||
<< pos << "width mismatch, got " << width << " expected "
|
||||
<< roundedSize;
|
||||
}
|
||||
// Check the shape of tensor
|
||||
if (input.shape.dimensions.empty()) {
|
||||
return StreamStringError("argument #") << pos << "is not a tensor";
|
||||
}
|
||||
if (shape.size() != input.shape.dimensions.size()) {
|
||||
return StreamStringError("argument #")
|
||||
<< pos << "has not the expected number of dimension, got "
|
||||
<< shape.size() << " expected " << input.shape.dimensions.size();
|
||||
}
|
||||
for (size_t i = 0; i < shape.size(); i++) {
|
||||
if (shape[i] != input.shape.dimensions[i]) {
|
||||
return StreamStringError("argument #")
|
||||
<< pos << " has not the expected dimension #" << i << " , got "
|
||||
<< shape[i] << " expected " << input.shape.dimensions[i];
|
||||
}
|
||||
}
|
||||
if (input.encryption.hasValue()) {
|
||||
// Encrypted tensor: for now we support only 8 bits for encrypted tensor
|
||||
if (width != 8) {
|
||||
return StreamStringError("argument #")
|
||||
<< pos << " width mismatch, expected 8 got " << width;
|
||||
}
|
||||
const uint8_t *data8 = (const uint8_t *)data;
|
||||
|
||||
// Allocate a buffer for ciphertexts of size of tensor
|
||||
auto ctBuffer = (LweCiphertext_u64 **)malloc(input.shape.size *
|
||||
sizeof(LweCiphertext_u64 *));
|
||||
ciphertextBuffers.push_back(ctBuffer);
|
||||
// Allocate ciphertexts and encrypt, for every values in tensor
|
||||
for (size_t i = 0; i < input.shape.size; i++) {
|
||||
if (auto err = this->keySet.allocate_lwe(pos, &ctBuffer[i])) {
|
||||
return err;
|
||||
}
|
||||
allocatedCiphertexts.push_back(ctBuffer[i]);
|
||||
if (auto err = this->keySet.encrypt_lwe(pos, ctBuffer[i], data8[i])) {
|
||||
return err;
|
||||
}
|
||||
}
|
||||
// Replace the data by the buffer to ciphertext
|
||||
data = (void *)ctBuffer;
|
||||
}
|
||||
// allocated
|
||||
preparedArgs.push_back(nullptr);
|
||||
// aligned
|
||||
preparedArgs.push_back(data);
|
||||
// offset
|
||||
preparedArgs.push_back((void *)0);
|
||||
// sizes
|
||||
for (size_t i = 0; i < shape.size(); i++) {
|
||||
preparedArgs.push_back((void *)shape[i]);
|
||||
}
|
||||
// strides - FIXME make it works
|
||||
// strides is an array of size equals to numDim
|
||||
for (size_t i = 0; i < shape.size(); i++) {
|
||||
preparedArgs.push_back((void *)0);
|
||||
}
|
||||
return llvm::Error::success();
|
||||
}
|
||||
|
||||
llvm::Error Arguments::pushContext() {
|
||||
if (currentPos < keySet.numInputs()) {
|
||||
return StreamStringError("Missing arguments");
|
||||
}
|
||||
preparedArgs.push_back(&context);
|
||||
return llvm::Error::success();
|
||||
}
|
||||
|
||||
llvm::Error Arguments::checkPushTooManyArgs() {
|
||||
size_t arity = keySet.numInputs();
|
||||
if (currentPos < arity) {
|
||||
return llvm::Error::success();
|
||||
}
|
||||
return StreamStringError("function has arity ")
|
||||
<< arity << " but is applied to too many arguments";
|
||||
}
|
||||
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
16
compiler/lib/TestLib/CMakeLists.txt
Normal file
16
compiler/lib/TestLib/CMakeLists.txt
Normal file
@@ -0,0 +1,16 @@
|
||||
add_mlir_library(ConcretelangTestLib
|
||||
Arguments.cpp
|
||||
DynamicLambda.cpp
|
||||
DynamicModule.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${PROJECT_SOURCE_DIR}/include/concretelang/TestLib
|
||||
|
||||
DEPENDS
|
||||
MLIRConversionPassIncGen
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
|
||||
ConcretelangSupport
|
||||
ConcretelangClientLib
|
||||
)
|
||||
224
compiler/lib/TestLib/DynamicLambda.cpp
Normal file
224
compiler/lib/TestLib/DynamicLambda.cpp
Normal file
@@ -0,0 +1,224 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/homomorphizer/blob/master/LICENSE.txt for license
|
||||
// information.
|
||||
|
||||
#include <dlfcn.h>
|
||||
|
||||
#include "concretelang/Support/CompilerEngine.h"
|
||||
#include "concretelang/Support/Error.h"
|
||||
#include "concretelang/TestLib/DynamicLambda.h"
|
||||
#include "concretelang/TestLib/dynamicArityCall.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
|
||||
template <size_t N> struct MemRefDescriptor {
|
||||
LweCiphertext_u64 **allocated;
|
||||
LweCiphertext_u64 **aligned;
|
||||
size_t offset;
|
||||
size_t sizes[N];
|
||||
size_t strides[N];
|
||||
};
|
||||
|
||||
llvm::Expected<std::vector<uint64_t>> decryptSlice(LweCiphertext_u64 **aligned,
|
||||
KeySet &keySet, size_t start,
|
||||
size_t size,
|
||||
size_t stride = 1) {
|
||||
stride = (stride == 0) ? 1 : stride;
|
||||
std::vector<uint64_t> result(size);
|
||||
for (size_t i = 0; i < size; i++) {
|
||||
size_t offset = start + i * stride;
|
||||
auto err = keySet.decrypt_lwe(0, aligned[offset], result[i]);
|
||||
if (err) {
|
||||
return StreamStringError()
|
||||
<< "cannot decrypt result #" << i << ", err:" << err;
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
llvm::Expected<mlir::concretelang::DynamicLambda>
|
||||
DynamicLambda::load(std::string funcName, std::string outputLib) {
|
||||
auto moduleOrErr = mlir::concretelang::DynamicModule::open(outputLib);
|
||||
if (!moduleOrErr) {
|
||||
return moduleOrErr.takeError();
|
||||
}
|
||||
return mlir::concretelang::DynamicLambda::load(*moduleOrErr, funcName);
|
||||
}
|
||||
|
||||
llvm::Expected<DynamicLambda>
|
||||
DynamicLambda::load(std::shared_ptr<DynamicModule> module,
|
||||
std::string funcName) {
|
||||
DynamicLambda lambda;
|
||||
lambda.module =
|
||||
module; // prevent module and library handler from being destroyed
|
||||
lambda.func =
|
||||
(void *(*)(void *, ...))dlsym(module->libraryHandle, funcName.c_str());
|
||||
|
||||
if (auto err = dlerror()) {
|
||||
return StreamStringError("Cannot open lambda: ") << err;
|
||||
}
|
||||
|
||||
auto param =
|
||||
llvm::find_if(module->clientParametersList, [&](ClientParameters param) {
|
||||
return param.functionName == funcName;
|
||||
});
|
||||
|
||||
if (param == module->clientParametersList.end()) {
|
||||
return StreamStringError("cannot find function ")
|
||||
<< funcName << "in client parameters";
|
||||
}
|
||||
|
||||
if (param->outputs.size() != 1) {
|
||||
return StreamStringError("DynamicLambda: output arity (")
|
||||
<< std::to_string(param->outputs.size())
|
||||
<< ") != 1 is not supported";
|
||||
}
|
||||
|
||||
if (!param->outputs[0].encryption.hasValue()) {
|
||||
return StreamStringError(
|
||||
"DynamicLambda: clear output is not yet supported");
|
||||
}
|
||||
|
||||
lambda.clientParameters = *param;
|
||||
return lambda;
|
||||
}
|
||||
|
||||
template <>
|
||||
llvm::Expected<uint64_t> invoke<uint64_t>(DynamicLambda &lambda,
|
||||
const Arguments &args) {
|
||||
auto output = lambda.clientParameters.outputs[0];
|
||||
if (output.shape.size != 0) {
|
||||
return StreamStringError("the function doesn't return a scalar");
|
||||
}
|
||||
// Scalar encrypted result
|
||||
auto fCasted = (LweCiphertext_u64 * (*)(void *...))(lambda.func);
|
||||
;
|
||||
LweCiphertext_u64 *lweResult =
|
||||
mlir::concretelang::call(fCasted, args.preparedArgs);
|
||||
|
||||
uint64_t decryptedResult;
|
||||
if (auto err = lambda.keySet->decrypt_lwe(0, lweResult, decryptedResult)) {
|
||||
return std::move(err);
|
||||
}
|
||||
return decryptedResult;
|
||||
}
|
||||
|
||||
template <size_t Rank>
|
||||
llvm::Expected<MemRefDescriptor<Rank>>
|
||||
DynamicLambda::invokeMemRefDecriptor(const Arguments &args) {
|
||||
auto output = clientParameters.outputs[0];
|
||||
if (output.shape.size == 0) {
|
||||
return StreamStringError("the function doesn't return a tensor");
|
||||
}
|
||||
if (output.shape.dimensions.size() != Rank) {
|
||||
return StreamStringError("the function doesn't return a tensor of rank ")
|
||||
<< Rank;
|
||||
}
|
||||
// Tensor encrypted result
|
||||
auto fCasted = (MemRefDescriptor<Rank>(*)(void *...))(func);
|
||||
auto encryptedResult = mlir::concretelang::call(fCasted, args.preparedArgs);
|
||||
|
||||
for (size_t dim = 0; dim < Rank; dim++) {
|
||||
size_t actual_size = encryptedResult.sizes[dim];
|
||||
size_t expected_size = output.shape.dimensions[dim];
|
||||
if (actual_size != expected_size) {
|
||||
return StreamStringError("the function returned a vector of size ")
|
||||
<< actual_size << " instead of size " << expected_size;
|
||||
}
|
||||
}
|
||||
return encryptedResult;
|
||||
}
|
||||
|
||||
template <>
|
||||
llvm::Expected<std::vector<uint64_t>>
|
||||
invoke<std::vector<uint64_t>>(DynamicLambda &lambda, const Arguments &args) {
|
||||
auto encryptedResultOrErr = lambda.invokeMemRefDecriptor<1>(args);
|
||||
if (!encryptedResultOrErr) {
|
||||
return encryptedResultOrErr.takeError();
|
||||
}
|
||||
auto &encryptedResult = encryptedResultOrErr.get();
|
||||
auto &keySet = lambda.keySet;
|
||||
return decryptSlice(encryptedResult.aligned, *keySet, encryptedResult.offset,
|
||||
encryptedResult.sizes[0], encryptedResult.strides[0]);
|
||||
}
|
||||
|
||||
template <>
|
||||
llvm::Expected<std::vector<std::vector<uint64_t>>>
|
||||
invoke<std::vector<std::vector<uint64_t>>>(DynamicLambda &lambda,
|
||||
const Arguments &args) {
|
||||
auto encryptedResultOrErr = lambda.invokeMemRefDecriptor<2>(args);
|
||||
if (!encryptedResultOrErr) {
|
||||
return encryptedResultOrErr.takeError();
|
||||
}
|
||||
auto &encryptedResult = encryptedResultOrErr.get();
|
||||
auto &keySet = lambda.keySet;
|
||||
|
||||
std::vector<std::vector<uint64_t>> result;
|
||||
result.reserve(encryptedResult.sizes[0]);
|
||||
for (size_t i = 0; i < encryptedResult.sizes[0]; i++) {
|
||||
// TODO : strides
|
||||
int offset = encryptedResult.offset + i * encryptedResult.sizes[1];
|
||||
auto slice =
|
||||
decryptSlice(encryptedResult.aligned, *keySet, offset,
|
||||
encryptedResult.sizes[1], encryptedResult.strides[1]);
|
||||
if (!slice) {
|
||||
return StreamStringError(llvm::toString(slice.takeError()));
|
||||
}
|
||||
result.push_back(slice.get());
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
template <>
|
||||
llvm::Expected<std::vector<std::vector<std::vector<uint64_t>>>>
|
||||
invoke<std::vector<std::vector<std::vector<uint64_t>>>>(DynamicLambda &lambda,
|
||||
const Arguments &args) {
|
||||
auto encryptedResultOrErr = lambda.invokeMemRefDecriptor<3>(args);
|
||||
if (!encryptedResultOrErr) {
|
||||
return encryptedResultOrErr.takeError();
|
||||
}
|
||||
auto &encryptedResult = encryptedResultOrErr.get();
|
||||
auto &keySet = lambda.keySet;
|
||||
|
||||
std::vector<std::vector<std::vector<uint64_t>>> result0;
|
||||
result0.reserve(encryptedResult.sizes[0]);
|
||||
for (size_t i = 0; i < encryptedResult.sizes[0]; i++) {
|
||||
std::vector<std::vector<uint64_t>> result1;
|
||||
result1.reserve(encryptedResult.sizes[1]);
|
||||
for (size_t j = 0; j < encryptedResult.sizes[1]; j++) {
|
||||
// TODO : strides
|
||||
int offset = encryptedResult.offset +
|
||||
i * encryptedResult.sizes[1] * encryptedResult.sizes[2] +
|
||||
j * encryptedResult.sizes[2];
|
||||
auto slice =
|
||||
decryptSlice(encryptedResult.aligned, *keySet, offset,
|
||||
encryptedResult.sizes[2], encryptedResult.strides[2]);
|
||||
if (!slice) {
|
||||
return StreamStringError(llvm::toString(slice.takeError()));
|
||||
}
|
||||
result1.push_back(slice.get());
|
||||
}
|
||||
result0.push_back(result1);
|
||||
}
|
||||
return result0;
|
||||
}
|
||||
|
||||
llvm::Error DynamicLambda::generateKeySet(llvm::Optional<KeySetCache> cache,
|
||||
uint64_t seed_msb,
|
||||
uint64_t seed_lsb) {
|
||||
auto maybeKeySet =
|
||||
cache.hasValue()
|
||||
? cache->tryLoadOrGenerateSave(clientParameters, seed_msb, seed_lsb)
|
||||
: KeySet::generate(clientParameters, seed_msb, seed_lsb);
|
||||
|
||||
if (auto err = maybeKeySet.takeError()) {
|
||||
return err;
|
||||
}
|
||||
keySet = std::move(maybeKeySet.get());
|
||||
return llvm::Error::success();
|
||||
}
|
||||
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
61
compiler/lib/TestLib/DynamicModule.cpp
Normal file
61
compiler/lib/TestLib/DynamicModule.cpp
Normal file
@@ -0,0 +1,61 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/homomorphizer/blob/master/LICENSE.txt for license
|
||||
// information.
|
||||
|
||||
#include <dlfcn.h>
|
||||
#include <fstream>
|
||||
|
||||
#include "concretelang/Support/CompilerEngine.h"
|
||||
#include "concretelang/Support/Error.h"
|
||||
|
||||
#include "concretelang/TestLib/DynamicModule.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
|
||||
DynamicModule::~DynamicModule() {
|
||||
if (libraryHandle != nullptr) {
|
||||
dlclose(libraryHandle);
|
||||
}
|
||||
}
|
||||
|
||||
llvm::Expected<std::shared_ptr<DynamicModule>>
|
||||
DynamicModule::open(std::string path) {
|
||||
std::shared_ptr<DynamicModule> module = std::make_shared<DynamicModule>();
|
||||
if (auto err = module->loadClientParametersJSON(path)) {
|
||||
return StreamStringError("Cannot load client parameters: ")
|
||||
<< llvm::toString(std::move(err));
|
||||
}
|
||||
if (auto err = module->loadSharedLibrary(path)) {
|
||||
return StreamStringError("Cannot load client parameters: ")
|
||||
<< llvm::toString(std::move(err));
|
||||
}
|
||||
return module;
|
||||
}
|
||||
|
||||
llvm::Error DynamicModule::loadSharedLibrary(std::string path) {
|
||||
libraryHandle = dlopen(
|
||||
CompilerEngine::Library::getSharedLibraryPath(path).c_str(), RTLD_LAZY);
|
||||
if (!libraryHandle) {
|
||||
return StreamStringError("Cannot open shared library") << dlerror();
|
||||
}
|
||||
return llvm::Error::success();
|
||||
}
|
||||
|
||||
llvm::Error DynamicModule::loadClientParametersJSON(std::string path) {
|
||||
|
||||
std::ifstream file(CompilerEngine::Library::getClientParametersPath(path));
|
||||
std::string content((std::istreambuf_iterator<char>(file)),
|
||||
(std::istreambuf_iterator<char>()));
|
||||
llvm::Expected<std::vector<ClientParameters>> expectedClientParams =
|
||||
llvm::json::parse<std::vector<ClientParameters>>(content);
|
||||
if (auto err = expectedClientParams.takeError()) {
|
||||
return StreamStringError("Cannot open client parameters: ") << err;
|
||||
}
|
||||
this->clientParametersList = *expectedClientParams;
|
||||
return llvm::Error::success();
|
||||
}
|
||||
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
@@ -22,6 +22,7 @@ if(CONCRETELANG_PARALLEL_EXECUTION_ENABLED)
|
||||
RTDialect
|
||||
|
||||
ConcretelangSupport
|
||||
ConcretelangTestLib
|
||||
|
||||
-Wl,-rpath,${CMAKE_BINARY_DIR}/lib/Runtime
|
||||
-Wl,-rpath,${HPX_DIR}/../../
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
if (CONCRETELANG_UNIT_TESTS)
|
||||
add_subdirectory(unittest)
|
||||
add_subdirectory(Support)
|
||||
add_subdirectory(TestLib)
|
||||
endif()
|
||||
|
||||
24
compiler/tests/TestLib/CMakeLists.txt
Normal file
24
compiler/tests/TestLib/CMakeLists.txt
Normal file
@@ -0,0 +1,24 @@
|
||||
enable_testing()
|
||||
|
||||
include_directories(${PROJECT_SOURCE_DIR}/include)
|
||||
|
||||
add_executable(
|
||||
testlib_unit_test
|
||||
testlib_unit_test.cpp
|
||||
)
|
||||
|
||||
set_source_files_properties(
|
||||
testlib_unit_test.cpp
|
||||
|
||||
PROPERTIES COMPILE_FLAGS "-fno-rtti"
|
||||
)
|
||||
|
||||
target_link_libraries(
|
||||
testlib_unit_test
|
||||
gtest_main
|
||||
ConcretelangRuntime
|
||||
ConcretelangTestLib
|
||||
)
|
||||
|
||||
include(GoogleTest)
|
||||
gtest_discover_tests(testlib_unit_test)
|
||||
0
compiler/tests/TestLib/out/.keep
Normal file
0
compiler/tests/TestLib/out/.keep
Normal file
241
compiler/tests/TestLib/testlib_unit_test.cpp
Normal file
241
compiler/tests/TestLib/testlib_unit_test.cpp
Normal file
@@ -0,0 +1,241 @@
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <numeric>
|
||||
#include <cassert>
|
||||
|
||||
#include "../unittest/end_to_end_jit_test.h"
|
||||
#include "concretelang/TestLib/DynamicLambda.h"
|
||||
|
||||
const std::string FUNCNAME = "main";
|
||||
|
||||
template<typename... Params>
|
||||
using TypedDynamicLambda = mlir::concretelang::TypedDynamicLambda<Params...>;
|
||||
|
||||
using scalar = uint64_t;
|
||||
using tensor1_in = std::vector<uint8_t>;
|
||||
using tensor1_out = std::vector<uint64_t>;
|
||||
using tensor2_out = std::vector<std::vector<uint64_t>>;
|
||||
using tensor3_out = std::vector<std::vector<std::vector<uint64_t>>>;
|
||||
|
||||
std::vector<uint8_t>
|
||||
values_7bits() {
|
||||
return {0, 1, 2, 63, 64, 65, 125, 126};
|
||||
}
|
||||
|
||||
llvm::Expected<mlir::concretelang::CompilerEngine::Library>
|
||||
compile(std::string outputLib, std::string source) {
|
||||
std::vector<std::string> sources = {source};
|
||||
std::shared_ptr<mlir::concretelang::CompilationContext> ccx =
|
||||
mlir::concretelang::CompilationContext::createShared();
|
||||
mlir::concretelang::JitCompilerEngine ce {ccx};
|
||||
ce.setClientParametersFuncName(FUNCNAME);
|
||||
return ce.compile(sources, outputLib);
|
||||
}
|
||||
|
||||
template<typename Info>
|
||||
std::string outputLibFromThis(Info *info) {
|
||||
return "tests/TestLib/out/" + std::string(info->name());
|
||||
}
|
||||
|
||||
TEST(CompiledModule, call_1s_1s) {
|
||||
std::string source = R"(
|
||||
func @main(%arg0: !FHE.eint<7>) -> !FHE.eint<7> {
|
||||
return %arg0: !FHE.eint<7>
|
||||
}
|
||||
)";
|
||||
std::string outputLib = outputLibFromThis(this->test_info_);
|
||||
auto compiled = compile(outputLib, source);
|
||||
ASSERT_EXPECTED_SUCCESS(compiled);
|
||||
auto lambda = TypedDynamicLambda<scalar, scalar>::load(FUNCNAME, outputLib);
|
||||
ASSERT_EXPECTED_SUCCESS(lambda);
|
||||
ASSERT_LLVM_ERROR(lambda->generateKeySet(getTestKeySetCache()));
|
||||
for(auto a: values_7bits()) {
|
||||
auto res = lambda->call(a);
|
||||
ASSERT_EXPECTED_VALUE(res, a);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(CompiledModule, call_2s_1s) {
|
||||
std::string source = R"(
|
||||
func @main(%arg0: !FHE.eint<7>, %arg1: !FHE.eint<7>) -> !FHE.eint<7> {
|
||||
%1 = "FHE.add_eint"(%arg0, %arg1): (!FHE.eint<7>, !FHE.eint<7>) -> (!FHE.eint<7>)
|
||||
return %1: !FHE.eint<7>
|
||||
}
|
||||
)";
|
||||
std::string outputLib = outputLibFromThis(this->test_info_);
|
||||
auto compiled = compile(outputLib, source);
|
||||
ASSERT_EXPECTED_SUCCESS(compiled);
|
||||
auto lambda = TypedDynamicLambda<scalar, scalar, scalar>::load(FUNCNAME, outputLib);
|
||||
ASSERT_EXPECTED_SUCCESS(lambda);
|
||||
ASSERT_LLVM_ERROR(lambda->generateKeySet(getTestKeySetCache()));
|
||||
for(auto a: values_7bits()) for(auto b: values_7bits()) {
|
||||
auto res = lambda->call(a, b);
|
||||
ASSERT_EXPECTED_VALUE(res, a + b);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(CompiledModule, call_1s_1t) {
|
||||
std::string source = R"(
|
||||
func @main(%arg0: !FHE.eint<7>) -> tensor<1x!FHE.eint<7>> {
|
||||
%1 = tensor.from_elements %arg0 : tensor<1x!FHE.eint<7>>
|
||||
return %1: tensor<1x!FHE.eint<7>>
|
||||
}
|
||||
)";
|
||||
std::string outputLib = outputLibFromThis(this->test_info_);
|
||||
auto compiled = compile(outputLib, source);
|
||||
ASSERT_EXPECTED_SUCCESS(compiled);
|
||||
auto lambda = TypedDynamicLambda<tensor1_out, scalar>::load(FUNCNAME, outputLib);
|
||||
ASSERT_EXPECTED_SUCCESS(lambda);
|
||||
ASSERT_LLVM_ERROR(lambda->generateKeySet(getTestKeySetCache()));
|
||||
for(auto a: values_7bits()) {
|
||||
auto res = lambda->call(a);
|
||||
ASSERT_EXPECTED_SUCCESS(res);
|
||||
tensor1_out v = res.get();
|
||||
EXPECT_EQ(v[0], a);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(CompiledModule, call_2s_1t) {
|
||||
std::string source = R"(
|
||||
func @main(%arg0: !FHE.eint<7>, %arg1: !FHE.eint<7>) -> tensor<2x!FHE.eint<7>> {
|
||||
%1 = tensor.from_elements %arg0, %arg1 : tensor<2x!FHE.eint<7>>
|
||||
return %1: tensor<2x!FHE.eint<7>>
|
||||
}
|
||||
)";
|
||||
std::string outputLib = outputLibFromThis(this->test_info_);
|
||||
auto compiled = compile(outputLib, source);
|
||||
ASSERT_EXPECTED_SUCCESS(compiled);
|
||||
auto lambda = TypedDynamicLambda<tensor1_out, scalar, scalar>::load(FUNCNAME, outputLib);
|
||||
ASSERT_EXPECTED_SUCCESS(lambda);
|
||||
ASSERT_LLVM_ERROR(lambda->generateKeySet(getTestKeySetCache()));
|
||||
for(auto a : values_7bits()) {
|
||||
auto res = lambda->call(a, a+1);
|
||||
ASSERT_EXPECTED_SUCCESS(res);
|
||||
tensor1_out v = res.get();
|
||||
EXPECT_EQ((scalar)v[0], a);
|
||||
EXPECT_EQ((scalar)v[1], a + 1u);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(CompiledModule, call_1t_1s) {
|
||||
std::string source = R"(
|
||||
func @main(%arg0: tensor<1x!FHE.eint<7>>) -> !FHE.eint<7> {
|
||||
%c0 = arith.constant 0 : index
|
||||
%1 = tensor.extract %arg0[%c0] : tensor<1x!FHE.eint<7>>
|
||||
return %1: !FHE.eint<7>
|
||||
}
|
||||
)";
|
||||
std::string outputLib = outputLibFromThis(this->test_info_);
|
||||
auto compiled = compile(outputLib, source);
|
||||
ASSERT_EXPECTED_SUCCESS(compiled);
|
||||
auto lambda = TypedDynamicLambda<scalar, tensor1_in>::load(FUNCNAME, outputLib);
|
||||
ASSERT_EXPECTED_SUCCESS(lambda);
|
||||
ASSERT_LLVM_ERROR(lambda->generateKeySet(getTestKeySetCache()));
|
||||
for(uint8_t a : values_7bits()) {
|
||||
tensor1_in ta = {a};
|
||||
auto res = lambda->call(ta);
|
||||
ASSERT_EXPECTED_VALUE(res, a);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(CompiledModule, call_1t_1t) {
|
||||
std::string source = R"(
|
||||
func @main(%arg0: tensor<3x!FHE.eint<7>>) -> tensor<3x!FHE.eint<7>> {
|
||||
return %arg0: tensor<3x!FHE.eint<7>>
|
||||
}
|
||||
)";
|
||||
std::string outputLib = outputLibFromThis(this->test_info_);
|
||||
auto compiled = compile(outputLib, source);
|
||||
ASSERT_EXPECTED_SUCCESS(compiled);
|
||||
auto lambda = TypedDynamicLambda<tensor1_out, tensor1_in>::load(FUNCNAME, outputLib);
|
||||
ASSERT_EXPECTED_SUCCESS(lambda);
|
||||
ASSERT_LLVM_ERROR(lambda->generateKeySet(getTestKeySetCache()));
|
||||
tensor1_in ta = {1, 2, 3};
|
||||
auto res = lambda->call(ta);
|
||||
ASSERT_EXPECTED_SUCCESS(res);
|
||||
tensor1_out v = res.get();
|
||||
for(size_t i = 0; i < v.size(); i++) {
|
||||
EXPECT_EQ(v[i], ta[i]);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(CompiledModule, call_2t_1s) {
|
||||
std::string source = R"(
|
||||
func @main(%arg0: tensor<3x!FHE.eint<7>>, %arg1: tensor<3x!FHE.eint<7>>) -> !FHE.eint<7> {
|
||||
%1 = "FHELinalg.add_eint"(%arg0, %arg1) : (tensor<3x!FHE.eint<7>>, tensor<3x!FHE.eint<7>>) -> tensor<3x!FHE.eint<7>>
|
||||
%c1 = arith.constant 1 : i8
|
||||
%2 = tensor.from_elements %c1, %c1, %c1 : tensor<3xi8>
|
||||
%3 = "FHELinalg.dot_eint_int"(%1, %2) : (tensor<3x!FHE.eint<7>>, tensor<3xi8>) -> !FHE.eint<7>
|
||||
return %3: !FHE.eint<7>
|
||||
}
|
||||
)";
|
||||
std::string outputLib = outputLibFromThis(this->test_info_);
|
||||
auto compiled = compile(outputLib, source);
|
||||
ASSERT_EXPECTED_SUCCESS(compiled);
|
||||
auto lambda = TypedDynamicLambda<scalar, tensor1_in, std::array<uint8_t, 3>>::load(FUNCNAME, outputLib);
|
||||
ASSERT_EXPECTED_SUCCESS(lambda);
|
||||
ASSERT_LLVM_ERROR(lambda->generateKeySet(getTestKeySetCache()));
|
||||
tensor1_in ta {1, 2, 3};
|
||||
std::array<uint8_t, 3> tb {5, 7, 9};
|
||||
auto res = lambda->call(ta, tb);
|
||||
auto expected = std::accumulate(ta.begin(), ta.end(), 0u) +
|
||||
std::accumulate(tb.begin(), tb.end(), 0u);
|
||||
ASSERT_EXPECTED_VALUE(res, expected);
|
||||
}
|
||||
|
||||
TEST(CompiledModule, call_1tr2_1tr2) {
|
||||
std::string source = R"(
|
||||
func @main(%arg0: tensor<2x3x!FHE.eint<7>>) -> tensor<2x3x!FHE.eint<7>> {
|
||||
return %arg0: tensor<2x3x!FHE.eint<7>>
|
||||
}
|
||||
)";
|
||||
using tensor2_in = std::array<std::array<uint8_t, 3>, 2>;
|
||||
std::string outputLib = outputLibFromThis(this->test_info_);
|
||||
auto compiled = compile(outputLib, source);
|
||||
ASSERT_EXPECTED_SUCCESS(compiled);
|
||||
auto lambda = TypedDynamicLambda<tensor2_out, tensor2_in>::load(FUNCNAME, outputLib);
|
||||
ASSERT_EXPECTED_SUCCESS(lambda);
|
||||
ASSERT_LLVM_ERROR(lambda->generateKeySet(getTestKeySetCache()));
|
||||
tensor2_in ta = {{
|
||||
{1, 2, 3},
|
||||
{4, 5, 6}
|
||||
}};
|
||||
auto res = lambda->call(ta);
|
||||
ASSERT_EXPECTED_SUCCESS(res);
|
||||
tensor2_out v = res.get();
|
||||
for(size_t i = 0; i < v.size(); i++) {
|
||||
for(size_t j = 0; j < v.size(); j++) {
|
||||
EXPECT_EQ(v[i][j], ta[i][j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
TEST(CompiledModule, call_1tr3_1tr3) {
|
||||
std::string source = R"(
|
||||
func @main(%arg0: tensor<2x3x1x!FHE.eint<7>>) -> tensor<2x3x1x!FHE.eint<7>> {
|
||||
return %arg0: tensor<2x3x1x!FHE.eint<7>>
|
||||
}
|
||||
)";
|
||||
using tensor3_in = std::array<std::array<std::array<uint8_t, 1>, 3>, 2>;
|
||||
std::string outputLib = outputLibFromThis(this->test_info_);
|
||||
auto compiled = compile(outputLib, source);
|
||||
ASSERT_EXPECTED_SUCCESS(compiled);
|
||||
auto lambda = TypedDynamicLambda<tensor3_out, tensor3_in>::load(FUNCNAME, outputLib);
|
||||
ASSERT_EXPECTED_SUCCESS(lambda);
|
||||
ASSERT_LLVM_ERROR(lambda->generateKeySet(getTestKeySetCache()));
|
||||
tensor3_in ta = {{
|
||||
{{ {1}, {2}, {3} }},
|
||||
{{ {4}, {5}, {6} }}
|
||||
}};
|
||||
auto res = lambda->call(ta);
|
||||
ASSERT_EXPECTED_SUCCESS(res);
|
||||
tensor3_out v = res.get();
|
||||
for(size_t i = 0; i < v.size(); i++) {
|
||||
for(size_t j = 0; j < v[i].size(); j++) {
|
||||
for(size_t k = 0; k < v[i][j].size(); k++) {
|
||||
EXPECT_EQ(v[i][j][k], ta[i][j][k]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -93,6 +93,17 @@ static bool assert_expected_value(llvm::Expected<T> &&val, const V &exp) {
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
static inline llvm::Optional<mlir::concretelang::KeySetCache> getTestKeySetCache() {
|
||||
|
||||
llvm::SmallString<0> cachePath;
|
||||
llvm::sys::path::system_temp_directory(true, cachePath);
|
||||
llvm::sys::path::append(cachePath, "KeySetCache");
|
||||
|
||||
auto cachePathStr = std::string(cachePath);
|
||||
return llvm::Optional<mlir::concretelang::KeySetCache>(
|
||||
mlir::concretelang::KeySetCache(cachePathStr));
|
||||
}
|
||||
|
||||
// Jit-compiles the function specified by `func` from `src` and
|
||||
// returns the corresponding lambda. Any compilation errors are caught
|
||||
// and reult in abnormal termination.
|
||||
@@ -103,16 +114,6 @@ internalCheckedJit(F checkFunc, llvm::StringRef src,
|
||||
bool useDefaultFHEConstraints = false,
|
||||
bool autoParallelize = false) {
|
||||
|
||||
llvm::SmallString<0> cachePath;
|
||||
|
||||
llvm::sys::path::system_temp_directory(true, cachePath);
|
||||
|
||||
llvm::sys::path::append(cachePath, "KeySetCache");
|
||||
|
||||
auto cachePathStr = std::string(cachePath);
|
||||
auto optCache = llvm::Optional<mlir::concretelang::KeySetCache>(
|
||||
mlir::concretelang::KeySetCache(cachePathStr));
|
||||
|
||||
mlir::concretelang::JitCompilerEngine engine;
|
||||
|
||||
if (useDefaultFHEConstraints)
|
||||
@@ -124,7 +125,7 @@ internalCheckedJit(F checkFunc, llvm::StringRef src,
|
||||
#endif
|
||||
|
||||
llvm::Expected<mlir::concretelang::JitCompilerEngine::Lambda> lambdaOrErr =
|
||||
engine.buildLambda(src, func, optCache);
|
||||
engine.buildLambda(src, func, getTestKeySetCache());
|
||||
|
||||
if (!lambdaOrErr) {
|
||||
std::cout << llvm::toString(lambdaOrErr.takeError()) << std::endl;
|
||||
|
||||
Reference in New Issue
Block a user