test(SDFG): add unit tests for SDFG and stream emulator.

This commit is contained in:
Antoniu Pop
2022-12-02 15:34:30 +00:00
committed by Andi Drebes
parent 0dbb86bb36
commit fa1b2dc056
3 changed files with 214 additions and 0 deletions

View File

@@ -1,4 +1,5 @@
add_custom_target(ConcretelangUnitTests)
add_subdirectory(ClientLib)
add_subdirectory(SDFG)
add_subdirectory(TestLib)

View File

@@ -0,0 +1,22 @@
add_custom_target(SDFGUnitTests)
add_dependencies(ConcretelangUnitTests SDFGUnitTests)
function(add_concretecompiler_lib_test test_name)
add_unittest(SDFGUnitTests ${test_name} ${ARGN})
target_link_libraries(${test_name} PRIVATE ConcretelangSupport)
set_source_files_properties(${ARGN} PROPERTIES COMPILE_FLAGS "-fno-rtti")
endfunction()
if(NOT ${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
link_libraries(
# usefull for old gcc versions
-Wl,--allow-multiple-definition # static concrete-optimizer and concrete shares some code
)
endif()
if(CONCRETELANG_DATAFLOW_EXECUTION_ENABLED)
add_compile_options(-DCONCRETELANG_DATAFLOW_TESTING_ENABLED)
endif()
add_concretecompiler_lib_test(unit_tests_concretelang_SDFG SDFG_unit_tests.cpp)

View File

@@ -0,0 +1,191 @@
#include <gtest/gtest.h>
#include <cassert>
#include <chrono>
#include <iostream>
#include <thread>
#include "boost/outcome.h"
#include "concretelang/ClientLib/ClientLambda.h"
#include "concretelang/Common/Error.h"
#include "concretelang/Support/CompilerEngine.h"
#include "concretelang/TestLib/TestTypedLambda.h"
#include "tests_tools/GtestEnvironment.h"
#include "tests_tools/assert.h"
#include "tests_tools/keySetCache.h"
testing::Environment *const dfr_env =
testing::AddGlobalTestEnvironment(new DFREnvironment);
const std::string FUNCNAME = "main";
using namespace concretelang::testlib;
using concretelang::clientlib::scalar_in;
using concretelang::clientlib::scalar_out;
using concretelang::clientlib::tensor1_in;
using concretelang::clientlib::tensor1_out;
using concretelang::clientlib::tensor2_in;
using concretelang::clientlib::tensor2_out;
using concretelang::clientlib::tensor3_out;
std::vector<uint8_t> values_3bits() { return {0, 1, 2, 5, 7}; }
std::vector<uint8_t> values_6bits() { return {0, 1, 2, 13, 22, 59, 62, 63}; }
std::vector<uint8_t> values_7bits() { return {0, 1, 2, 63, 64, 65, 125, 126}; }
mlir::concretelang::CompilerEngine::Library
compile(std::string outputLib, std::string source,
std::string funcname = FUNCNAME) {
std::vector<std::string> sources = {source};
std::shared_ptr<mlir::concretelang::CompilationContext> ccx =
mlir::concretelang::CompilationContext::createShared();
mlir::concretelang::CompilerEngine ce{ccx};
mlir::concretelang::CompilationOptions options(funcname);
options.emitSDFGOps = true;
#ifdef CONCRETELANG_DATAFLOW_TESTING_ENABLED
// options.dataflowParallelize = true;
#endif
ce.setCompilationOptions(options);
auto result = ce.compile(sources, outputLib);
if (!result) {
llvm::errs() << result.takeError();
assert(false);
}
assert(result);
return result.get();
}
static const std::string CURRENT_FILE = __FILE__;
static const std::string THIS_TEST_DIRECTORY =
CURRENT_FILE.substr(0, CURRENT_FILE.find_last_of("/\\"));
static const std::string OUT_DIRECTORY = "/tmp";
template <typename Info> std::string outputLibFromThis(Info *info) {
return OUT_DIRECTORY + "/" + std::string(info->name());
}
template <typename Lambda> Lambda load(std::string outputLib) {
auto l = Lambda::load(FUNCNAME, outputLib, 0, 0, getTestKeySetCachePtr());
assert(l.has_value());
return l.value();
}
TEST(SDFG_unit_tests, add_eint) {
std::string source = R"(
func.func @main(%arg0: !FHE.eint<7>, %arg1: !FHE.eint<7>) -> !FHE.eint<7> {
%1 = "FHE.add_eint"(%arg0, %arg1): (!FHE.eint<7>, !FHE.eint<7>) -> (!FHE.eint<7>)
return %1: !FHE.eint<7>
}
)";
std::string outputLib = outputLibFromThis(this->test_info_);
auto compiled = compile(outputLib, source);
auto lambda =
load<TestTypedLambda<scalar_out, scalar_in, scalar_in>>(outputLib);
for (auto a : values_7bits())
for (auto b : values_7bits()) {
if (a > b) {
continue;
}
auto res = lambda.call(a, b);
ASSERT_EQ_OUTCOME(res, (scalar_out)a + b);
}
}
TEST(SDFG_unit_tests, add_eint_int) {
std::string source = R"(
func.func @main(%arg0: !FHE.eint<7>, %arg1: i8) -> !FHE.eint<7> {
%1 = "FHE.add_eint_int"(%arg0, %arg1): (!FHE.eint<7>, i8) -> (!FHE.eint<7>)
return %1: !FHE.eint<7>
}
)";
std::string outputLib = outputLibFromThis(this->test_info_);
auto compiled = compile(outputLib, source);
auto lambda =
load<TestTypedLambda<scalar_out, scalar_in, scalar_in>>(outputLib);
for (auto a : values_7bits())
for (auto b : values_7bits()) {
if (a > b) {
continue;
}
auto res = lambda.call(a, b);
ASSERT_EQ_OUTCOME(res, (scalar_out)a + b);
}
}
TEST(SDFG_unit_tests, mul_eint_int) {
std::string source = R"(
func.func @main(%arg0: !FHE.eint<7>, %arg1: i8) -> !FHE.eint<7> {
%1 = "FHE.mul_eint_int"(%arg0, %arg1): (!FHE.eint<7>, i8) -> (!FHE.eint<7>)
return %1: !FHE.eint<7>
}
)";
std::string outputLib = outputLibFromThis(this->test_info_);
auto compiled = compile(outputLib, source);
auto lambda =
load<TestTypedLambda<scalar_out, scalar_in, scalar_in>>(outputLib);
for (auto a : values_3bits())
for (auto b : values_3bits()) {
if (a > b) {
continue;
}
auto res = lambda.call(a, b);
ASSERT_EQ_OUTCOME(res, (scalar_out)a * b);
}
}
TEST(SDFG_unit_tests, neg_eint) {
std::string source = R"(
func.func @main(%arg0: !FHE.eint<7>) -> !FHE.eint<7> {
%1 = "FHE.neg_eint"(%arg0): (!FHE.eint<7>) -> (!FHE.eint<7>)
return %1: !FHE.eint<7>
}
)";
std::string outputLib = outputLibFromThis(this->test_info_);
auto compiled = compile(outputLib, source);
auto lambda = load<TestTypedLambda<scalar_out, scalar_in>>(outputLib);
for (auto a : values_7bits()) {
auto res = lambda.call(a);
ASSERT_EQ_OUTCOME(res, (scalar_out)((a == 0) ? 0 : 256 - a));
}
}
TEST(SDFG_unit_tests, add_eint_tree) {
std::string source = R"(
func.func @main(%arg0: !FHE.eint<7>, %arg1: !FHE.eint<7>, %arg2: !FHE.eint<7>, %arg3: !FHE.eint<7>) -> !FHE.eint<7> {
%1 = "FHE.add_eint"(%arg0, %arg1): (!FHE.eint<7>, !FHE.eint<7>) -> (!FHE.eint<7>)
%2 = "FHE.add_eint"(%arg2, %arg3): (!FHE.eint<7>, !FHE.eint<7>) -> (!FHE.eint<7>)
%3 = "FHE.add_eint"(%1, %2): (!FHE.eint<7>, !FHE.eint<7>) -> (!FHE.eint<7>)
return %3: !FHE.eint<7>
}
)";
std::string outputLib = outputLibFromThis(this->test_info_);
auto compiled = compile(outputLib, source);
auto lambda = load<
TestTypedLambda<scalar_out, scalar_in, scalar_in, scalar_in, scalar_in>>(
outputLib);
for (auto a : values_3bits()) {
for (auto b : values_3bits()) {
auto res = lambda.call(a, a, b, b);
ASSERT_EQ_OUTCOME(res, (scalar_out)a + a + b + b);
}
}
}
TEST(SDFG_unit_tests, tlu) {
std::string source = R"(
func.func @main(%arg0: !FHE.eint<3>) -> !FHE.eint<3> {
%tlu_3 = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7]> : tensor<8xi64>
%1 = "FHE.apply_lookup_table"(%arg0, %tlu_3): (!FHE.eint<3>, tensor<8xi64>) -> (!FHE.eint<3>)
return %1: !FHE.eint<3>
}
)";
std::string outputLib = outputLibFromThis(this->test_info_);
auto compiled = compile(outputLib, source);
auto lambda = load<TestTypedLambda<scalar_out, scalar_in>>(outputLib);
for (auto a : values_3bits()) {
auto res = lambda.call(a);
ASSERT_EQ_OUTCOME(res, (scalar_out)a);
}
}