mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
test(SDFG): add unit tests for SDFG and stream emulator.
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
add_custom_target(ConcretelangUnitTests)
|
||||
|
||||
add_subdirectory(ClientLib)
|
||||
add_subdirectory(SDFG)
|
||||
add_subdirectory(TestLib)
|
||||
|
||||
22
compiler/tests/unit_tests/concretelang/SDFG/CMakeLists.txt
Normal file
22
compiler/tests/unit_tests/concretelang/SDFG/CMakeLists.txt
Normal file
@@ -0,0 +1,22 @@
|
||||
add_custom_target(SDFGUnitTests)
|
||||
|
||||
add_dependencies(ConcretelangUnitTests SDFGUnitTests)
|
||||
|
||||
function(add_concretecompiler_lib_test test_name)
|
||||
add_unittest(SDFGUnitTests ${test_name} ${ARGN})
|
||||
target_link_libraries(${test_name} PRIVATE ConcretelangSupport)
|
||||
set_source_files_properties(${ARGN} PROPERTIES COMPILE_FLAGS "-fno-rtti")
|
||||
endfunction()
|
||||
|
||||
if(NOT ${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
|
||||
link_libraries(
|
||||
# usefull for old gcc versions
|
||||
-Wl,--allow-multiple-definition # static concrete-optimizer and concrete shares some code
|
||||
)
|
||||
endif()
|
||||
|
||||
if(CONCRETELANG_DATAFLOW_EXECUTION_ENABLED)
|
||||
add_compile_options(-DCONCRETELANG_DATAFLOW_TESTING_ENABLED)
|
||||
endif()
|
||||
|
||||
add_concretecompiler_lib_test(unit_tests_concretelang_SDFG SDFG_unit_tests.cpp)
|
||||
191
compiler/tests/unit_tests/concretelang/SDFG/SDFG_unit_tests.cpp
Normal file
191
compiler/tests/unit_tests/concretelang/SDFG/SDFG_unit_tests.cpp
Normal file
@@ -0,0 +1,191 @@
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <cassert>
|
||||
#include <chrono>
|
||||
#include <iostream>
|
||||
#include <thread>
|
||||
|
||||
#include "boost/outcome.h"
|
||||
|
||||
#include "concretelang/ClientLib/ClientLambda.h"
|
||||
#include "concretelang/Common/Error.h"
|
||||
#include "concretelang/Support/CompilerEngine.h"
|
||||
#include "concretelang/TestLib/TestTypedLambda.h"
|
||||
|
||||
#include "tests_tools/GtestEnvironment.h"
|
||||
#include "tests_tools/assert.h"
|
||||
#include "tests_tools/keySetCache.h"
|
||||
|
||||
testing::Environment *const dfr_env =
|
||||
testing::AddGlobalTestEnvironment(new DFREnvironment);
|
||||
|
||||
const std::string FUNCNAME = "main";
|
||||
|
||||
using namespace concretelang::testlib;
|
||||
|
||||
using concretelang::clientlib::scalar_in;
|
||||
using concretelang::clientlib::scalar_out;
|
||||
using concretelang::clientlib::tensor1_in;
|
||||
using concretelang::clientlib::tensor1_out;
|
||||
using concretelang::clientlib::tensor2_in;
|
||||
using concretelang::clientlib::tensor2_out;
|
||||
using concretelang::clientlib::tensor3_out;
|
||||
|
||||
std::vector<uint8_t> values_3bits() { return {0, 1, 2, 5, 7}; }
|
||||
std::vector<uint8_t> values_6bits() { return {0, 1, 2, 13, 22, 59, 62, 63}; }
|
||||
std::vector<uint8_t> values_7bits() { return {0, 1, 2, 63, 64, 65, 125, 126}; }
|
||||
|
||||
mlir::concretelang::CompilerEngine::Library
|
||||
compile(std::string outputLib, std::string source,
|
||||
std::string funcname = FUNCNAME) {
|
||||
std::vector<std::string> sources = {source};
|
||||
std::shared_ptr<mlir::concretelang::CompilationContext> ccx =
|
||||
mlir::concretelang::CompilationContext::createShared();
|
||||
mlir::concretelang::CompilerEngine ce{ccx};
|
||||
mlir::concretelang::CompilationOptions options(funcname);
|
||||
options.emitSDFGOps = true;
|
||||
#ifdef CONCRETELANG_DATAFLOW_TESTING_ENABLED
|
||||
// options.dataflowParallelize = true;
|
||||
#endif
|
||||
ce.setCompilationOptions(options);
|
||||
auto result = ce.compile(sources, outputLib);
|
||||
if (!result) {
|
||||
llvm::errs() << result.takeError();
|
||||
assert(false);
|
||||
}
|
||||
assert(result);
|
||||
return result.get();
|
||||
}
|
||||
|
||||
static const std::string CURRENT_FILE = __FILE__;
|
||||
static const std::string THIS_TEST_DIRECTORY =
|
||||
CURRENT_FILE.substr(0, CURRENT_FILE.find_last_of("/\\"));
|
||||
static const std::string OUT_DIRECTORY = "/tmp";
|
||||
|
||||
template <typename Info> std::string outputLibFromThis(Info *info) {
|
||||
return OUT_DIRECTORY + "/" + std::string(info->name());
|
||||
}
|
||||
|
||||
template <typename Lambda> Lambda load(std::string outputLib) {
|
||||
auto l = Lambda::load(FUNCNAME, outputLib, 0, 0, getTestKeySetCachePtr());
|
||||
assert(l.has_value());
|
||||
return l.value();
|
||||
}
|
||||
|
||||
TEST(SDFG_unit_tests, add_eint) {
|
||||
std::string source = R"(
|
||||
func.func @main(%arg0: !FHE.eint<7>, %arg1: !FHE.eint<7>) -> !FHE.eint<7> {
|
||||
%1 = "FHE.add_eint"(%arg0, %arg1): (!FHE.eint<7>, !FHE.eint<7>) -> (!FHE.eint<7>)
|
||||
return %1: !FHE.eint<7>
|
||||
}
|
||||
)";
|
||||
std::string outputLib = outputLibFromThis(this->test_info_);
|
||||
auto compiled = compile(outputLib, source);
|
||||
auto lambda =
|
||||
load<TestTypedLambda<scalar_out, scalar_in, scalar_in>>(outputLib);
|
||||
for (auto a : values_7bits())
|
||||
for (auto b : values_7bits()) {
|
||||
if (a > b) {
|
||||
continue;
|
||||
}
|
||||
auto res = lambda.call(a, b);
|
||||
ASSERT_EQ_OUTCOME(res, (scalar_out)a + b);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(SDFG_unit_tests, add_eint_int) {
|
||||
std::string source = R"(
|
||||
func.func @main(%arg0: !FHE.eint<7>, %arg1: i8) -> !FHE.eint<7> {
|
||||
%1 = "FHE.add_eint_int"(%arg0, %arg1): (!FHE.eint<7>, i8) -> (!FHE.eint<7>)
|
||||
return %1: !FHE.eint<7>
|
||||
}
|
||||
)";
|
||||
std::string outputLib = outputLibFromThis(this->test_info_);
|
||||
auto compiled = compile(outputLib, source);
|
||||
auto lambda =
|
||||
load<TestTypedLambda<scalar_out, scalar_in, scalar_in>>(outputLib);
|
||||
for (auto a : values_7bits())
|
||||
for (auto b : values_7bits()) {
|
||||
if (a > b) {
|
||||
continue;
|
||||
}
|
||||
auto res = lambda.call(a, b);
|
||||
ASSERT_EQ_OUTCOME(res, (scalar_out)a + b);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(SDFG_unit_tests, mul_eint_int) {
|
||||
std::string source = R"(
|
||||
func.func @main(%arg0: !FHE.eint<7>, %arg1: i8) -> !FHE.eint<7> {
|
||||
%1 = "FHE.mul_eint_int"(%arg0, %arg1): (!FHE.eint<7>, i8) -> (!FHE.eint<7>)
|
||||
return %1: !FHE.eint<7>
|
||||
}
|
||||
)";
|
||||
std::string outputLib = outputLibFromThis(this->test_info_);
|
||||
auto compiled = compile(outputLib, source);
|
||||
auto lambda =
|
||||
load<TestTypedLambda<scalar_out, scalar_in, scalar_in>>(outputLib);
|
||||
for (auto a : values_3bits())
|
||||
for (auto b : values_3bits()) {
|
||||
if (a > b) {
|
||||
continue;
|
||||
}
|
||||
auto res = lambda.call(a, b);
|
||||
ASSERT_EQ_OUTCOME(res, (scalar_out)a * b);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(SDFG_unit_tests, neg_eint) {
|
||||
std::string source = R"(
|
||||
func.func @main(%arg0: !FHE.eint<7>) -> !FHE.eint<7> {
|
||||
%1 = "FHE.neg_eint"(%arg0): (!FHE.eint<7>) -> (!FHE.eint<7>)
|
||||
return %1: !FHE.eint<7>
|
||||
}
|
||||
)";
|
||||
std::string outputLib = outputLibFromThis(this->test_info_);
|
||||
auto compiled = compile(outputLib, source);
|
||||
auto lambda = load<TestTypedLambda<scalar_out, scalar_in>>(outputLib);
|
||||
for (auto a : values_7bits()) {
|
||||
auto res = lambda.call(a);
|
||||
ASSERT_EQ_OUTCOME(res, (scalar_out)((a == 0) ? 0 : 256 - a));
|
||||
}
|
||||
}
|
||||
|
||||
TEST(SDFG_unit_tests, add_eint_tree) {
|
||||
std::string source = R"(
|
||||
func.func @main(%arg0: !FHE.eint<7>, %arg1: !FHE.eint<7>, %arg2: !FHE.eint<7>, %arg3: !FHE.eint<7>) -> !FHE.eint<7> {
|
||||
%1 = "FHE.add_eint"(%arg0, %arg1): (!FHE.eint<7>, !FHE.eint<7>) -> (!FHE.eint<7>)
|
||||
%2 = "FHE.add_eint"(%arg2, %arg3): (!FHE.eint<7>, !FHE.eint<7>) -> (!FHE.eint<7>)
|
||||
%3 = "FHE.add_eint"(%1, %2): (!FHE.eint<7>, !FHE.eint<7>) -> (!FHE.eint<7>)
|
||||
return %3: !FHE.eint<7>
|
||||
}
|
||||
)";
|
||||
std::string outputLib = outputLibFromThis(this->test_info_);
|
||||
auto compiled = compile(outputLib, source);
|
||||
auto lambda = load<
|
||||
TestTypedLambda<scalar_out, scalar_in, scalar_in, scalar_in, scalar_in>>(
|
||||
outputLib);
|
||||
for (auto a : values_3bits()) {
|
||||
for (auto b : values_3bits()) {
|
||||
auto res = lambda.call(a, a, b, b);
|
||||
ASSERT_EQ_OUTCOME(res, (scalar_out)a + a + b + b);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST(SDFG_unit_tests, tlu) {
|
||||
std::string source = R"(
|
||||
func.func @main(%arg0: !FHE.eint<3>) -> !FHE.eint<3> {
|
||||
%tlu_3 = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7]> : tensor<8xi64>
|
||||
%1 = "FHE.apply_lookup_table"(%arg0, %tlu_3): (!FHE.eint<3>, tensor<8xi64>) -> (!FHE.eint<3>)
|
||||
return %1: !FHE.eint<3>
|
||||
}
|
||||
)";
|
||||
std::string outputLib = outputLibFromThis(this->test_info_);
|
||||
auto compiled = compile(outputLib, source);
|
||||
auto lambda = load<TestTypedLambda<scalar_out, scalar_in>>(outputLib);
|
||||
for (auto a : values_3bits()) {
|
||||
auto res = lambda.call(a);
|
||||
ASSERT_EQ_OUTCOME(res, (scalar_out)a);
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user