From fa1b2dc056d96737c5bdda416f171d8ce76446f1 Mon Sep 17 00:00:00 2001 From: Antoniu Pop Date: Fri, 2 Dec 2022 15:34:30 +0000 Subject: [PATCH] test(SDFG): add unit tests for SDFG and stream emulator. --- .../unit_tests/concretelang/CMakeLists.txt | 1 + .../concretelang/SDFG/CMakeLists.txt | 22 ++ .../concretelang/SDFG/SDFG_unit_tests.cpp | 191 ++++++++++++++++++ 3 files changed, 214 insertions(+) create mode 100644 compiler/tests/unit_tests/concretelang/SDFG/CMakeLists.txt create mode 100644 compiler/tests/unit_tests/concretelang/SDFG/SDFG_unit_tests.cpp diff --git a/compiler/tests/unit_tests/concretelang/CMakeLists.txt b/compiler/tests/unit_tests/concretelang/CMakeLists.txt index a8e616428..8ad5ef904 100644 --- a/compiler/tests/unit_tests/concretelang/CMakeLists.txt +++ b/compiler/tests/unit_tests/concretelang/CMakeLists.txt @@ -1,4 +1,5 @@ add_custom_target(ConcretelangUnitTests) add_subdirectory(ClientLib) +add_subdirectory(SDFG) add_subdirectory(TestLib) diff --git a/compiler/tests/unit_tests/concretelang/SDFG/CMakeLists.txt b/compiler/tests/unit_tests/concretelang/SDFG/CMakeLists.txt new file mode 100644 index 000000000..1da230275 --- /dev/null +++ b/compiler/tests/unit_tests/concretelang/SDFG/CMakeLists.txt @@ -0,0 +1,22 @@ +add_custom_target(SDFGUnitTests) + +add_dependencies(ConcretelangUnitTests SDFGUnitTests) + +function(add_concretecompiler_lib_test test_name) + add_unittest(SDFGUnitTests ${test_name} ${ARGN}) + target_link_libraries(${test_name} PRIVATE ConcretelangSupport) + set_source_files_properties(${ARGN} PROPERTIES COMPILE_FLAGS "-fno-rtti") +endfunction() + +if(NOT ${CMAKE_SYSTEM_NAME} MATCHES "Darwin") + link_libraries( + # usefull for old gcc versions + -Wl,--allow-multiple-definition # static concrete-optimizer and concrete shares some code + ) +endif() + +if(CONCRETELANG_DATAFLOW_EXECUTION_ENABLED) + add_compile_options(-DCONCRETELANG_DATAFLOW_TESTING_ENABLED) +endif() + +add_concretecompiler_lib_test(unit_tests_concretelang_SDFG SDFG_unit_tests.cpp) diff --git a/compiler/tests/unit_tests/concretelang/SDFG/SDFG_unit_tests.cpp b/compiler/tests/unit_tests/concretelang/SDFG/SDFG_unit_tests.cpp new file mode 100644 index 000000000..010f20227 --- /dev/null +++ b/compiler/tests/unit_tests/concretelang/SDFG/SDFG_unit_tests.cpp @@ -0,0 +1,191 @@ +#include + +#include +#include +#include +#include + +#include "boost/outcome.h" + +#include "concretelang/ClientLib/ClientLambda.h" +#include "concretelang/Common/Error.h" +#include "concretelang/Support/CompilerEngine.h" +#include "concretelang/TestLib/TestTypedLambda.h" + +#include "tests_tools/GtestEnvironment.h" +#include "tests_tools/assert.h" +#include "tests_tools/keySetCache.h" + +testing::Environment *const dfr_env = + testing::AddGlobalTestEnvironment(new DFREnvironment); + +const std::string FUNCNAME = "main"; + +using namespace concretelang::testlib; + +using concretelang::clientlib::scalar_in; +using concretelang::clientlib::scalar_out; +using concretelang::clientlib::tensor1_in; +using concretelang::clientlib::tensor1_out; +using concretelang::clientlib::tensor2_in; +using concretelang::clientlib::tensor2_out; +using concretelang::clientlib::tensor3_out; + +std::vector values_3bits() { return {0, 1, 2, 5, 7}; } +std::vector values_6bits() { return {0, 1, 2, 13, 22, 59, 62, 63}; } +std::vector values_7bits() { return {0, 1, 2, 63, 64, 65, 125, 126}; } + +mlir::concretelang::CompilerEngine::Library +compile(std::string outputLib, std::string source, + std::string funcname = FUNCNAME) { + std::vector sources = {source}; + std::shared_ptr ccx = + mlir::concretelang::CompilationContext::createShared(); + mlir::concretelang::CompilerEngine ce{ccx}; + mlir::concretelang::CompilationOptions options(funcname); + options.emitSDFGOps = true; +#ifdef CONCRETELANG_DATAFLOW_TESTING_ENABLED + // options.dataflowParallelize = true; +#endif + ce.setCompilationOptions(options); + auto result = ce.compile(sources, outputLib); + if (!result) { + llvm::errs() << result.takeError(); + assert(false); + } + assert(result); + return result.get(); +} + +static const std::string CURRENT_FILE = __FILE__; +static const std::string THIS_TEST_DIRECTORY = + CURRENT_FILE.substr(0, CURRENT_FILE.find_last_of("/\\")); +static const std::string OUT_DIRECTORY = "/tmp"; + +template std::string outputLibFromThis(Info *info) { + return OUT_DIRECTORY + "/" + std::string(info->name()); +} + +template Lambda load(std::string outputLib) { + auto l = Lambda::load(FUNCNAME, outputLib, 0, 0, getTestKeySetCachePtr()); + assert(l.has_value()); + return l.value(); +} + +TEST(SDFG_unit_tests, add_eint) { + std::string source = R"( +func.func @main(%arg0: !FHE.eint<7>, %arg1: !FHE.eint<7>) -> !FHE.eint<7> { + %1 = "FHE.add_eint"(%arg0, %arg1): (!FHE.eint<7>, !FHE.eint<7>) -> (!FHE.eint<7>) + return %1: !FHE.eint<7> +} +)"; + std::string outputLib = outputLibFromThis(this->test_info_); + auto compiled = compile(outputLib, source); + auto lambda = + load>(outputLib); + for (auto a : values_7bits()) + for (auto b : values_7bits()) { + if (a > b) { + continue; + } + auto res = lambda.call(a, b); + ASSERT_EQ_OUTCOME(res, (scalar_out)a + b); + } +} + +TEST(SDFG_unit_tests, add_eint_int) { + std::string source = R"( +func.func @main(%arg0: !FHE.eint<7>, %arg1: i8) -> !FHE.eint<7> { + %1 = "FHE.add_eint_int"(%arg0, %arg1): (!FHE.eint<7>, i8) -> (!FHE.eint<7>) + return %1: !FHE.eint<7> +} +)"; + std::string outputLib = outputLibFromThis(this->test_info_); + auto compiled = compile(outputLib, source); + auto lambda = + load>(outputLib); + for (auto a : values_7bits()) + for (auto b : values_7bits()) { + if (a > b) { + continue; + } + auto res = lambda.call(a, b); + ASSERT_EQ_OUTCOME(res, (scalar_out)a + b); + } +} + +TEST(SDFG_unit_tests, mul_eint_int) { + std::string source = R"( +func.func @main(%arg0: !FHE.eint<7>, %arg1: i8) -> !FHE.eint<7> { + %1 = "FHE.mul_eint_int"(%arg0, %arg1): (!FHE.eint<7>, i8) -> (!FHE.eint<7>) + return %1: !FHE.eint<7> +} +)"; + std::string outputLib = outputLibFromThis(this->test_info_); + auto compiled = compile(outputLib, source); + auto lambda = + load>(outputLib); + for (auto a : values_3bits()) + for (auto b : values_3bits()) { + if (a > b) { + continue; + } + auto res = lambda.call(a, b); + ASSERT_EQ_OUTCOME(res, (scalar_out)a * b); + } +} + +TEST(SDFG_unit_tests, neg_eint) { + std::string source = R"( +func.func @main(%arg0: !FHE.eint<7>) -> !FHE.eint<7> { + %1 = "FHE.neg_eint"(%arg0): (!FHE.eint<7>) -> (!FHE.eint<7>) + return %1: !FHE.eint<7> +} +)"; + std::string outputLib = outputLibFromThis(this->test_info_); + auto compiled = compile(outputLib, source); + auto lambda = load>(outputLib); + for (auto a : values_7bits()) { + auto res = lambda.call(a); + ASSERT_EQ_OUTCOME(res, (scalar_out)((a == 0) ? 0 : 256 - a)); + } +} + +TEST(SDFG_unit_tests, add_eint_tree) { + std::string source = R"( +func.func @main(%arg0: !FHE.eint<7>, %arg1: !FHE.eint<7>, %arg2: !FHE.eint<7>, %arg3: !FHE.eint<7>) -> !FHE.eint<7> { + %1 = "FHE.add_eint"(%arg0, %arg1): (!FHE.eint<7>, !FHE.eint<7>) -> (!FHE.eint<7>) + %2 = "FHE.add_eint"(%arg2, %arg3): (!FHE.eint<7>, !FHE.eint<7>) -> (!FHE.eint<7>) + %3 = "FHE.add_eint"(%1, %2): (!FHE.eint<7>, !FHE.eint<7>) -> (!FHE.eint<7>) + return %3: !FHE.eint<7> +} +)"; + std::string outputLib = outputLibFromThis(this->test_info_); + auto compiled = compile(outputLib, source); + auto lambda = load< + TestTypedLambda>( + outputLib); + for (auto a : values_3bits()) { + for (auto b : values_3bits()) { + auto res = lambda.call(a, a, b, b); + ASSERT_EQ_OUTCOME(res, (scalar_out)a + a + b + b); + } + } +} + +TEST(SDFG_unit_tests, tlu) { + std::string source = R"( +func.func @main(%arg0: !FHE.eint<3>) -> !FHE.eint<3> { + %tlu_3 = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7]> : tensor<8xi64> + %1 = "FHE.apply_lookup_table"(%arg0, %tlu_3): (!FHE.eint<3>, tensor<8xi64>) -> (!FHE.eint<3>) + return %1: !FHE.eint<3> +} +)"; + std::string outputLib = outputLibFromThis(this->test_info_); + auto compiled = compile(outputLib, source); + auto lambda = load>(outputLib); + for (auto a : values_3bits()) { + auto res = lambda.call(a); + ASSERT_EQ_OUTCOME(res, (scalar_out)a); + } +}