test(compiler): add unit tests for dataflow auto parallelization.

This commit is contained in:
Antoniu Pop
2021-12-20 11:14:56 +00:00
committed by Antoniu Pop
parent cdca7ca6f7
commit a1a694a686
4 changed files with 109 additions and 2 deletions

View File

@@ -50,7 +50,7 @@ test-python: python-bindings zamacompiler
test: test-check test-end-to-end-jit test-python
test-dataflow: test-end-to-end-jit-dfr
test-dataflow: test-end-to-end-jit-dfr test-end-to-end-jit-auto-parallelization
# Unittests
@@ -72,6 +72,9 @@ build-end-to-end-jit-lambda: build-initialized
build-end-to-end-jit-dfr: build-initialized
cmake --build $(BUILD_DIR) --target end_to_end_jit_dfr
build-end-to-end-jit-auto-parallelization: build-initialized
cmake --build $(BUILD_DIR) --target end_to_end_jit_auto_parallelization
build-end-to-end-jit: build-end-to-end-jit-test build-end-to-end-jit-clear-tensor build-end-to-end-jit-encrypted-tensor build-end-to-end-jit-hlfhelinalg
@@ -93,8 +96,12 @@ test-end-to-end-jit-lambda: build-initialized build-end-to-end-jit-lambda
test-end-to-end-jit-dfr: build-end-to-end-jit-dfr
$(BUILD_DIR)/bin/end_to_end_jit_dfr
test-end-to-end-jit-auto-parallelization: build-end-to-end-jit-auto-parallelization
$(BUILD_DIR)/bin/end_to_end_jit_auto_parallelization
test-end-to-end-jit: test-end-to-end-jit-test test-end-to-end-jit-clear-tensor test-end-to-end-jit-encrypted-tensor test-end-to-end-jit-hlfhelinalg
# LLVM/MLIR dependencies
all-deps: file-check not

View File

@@ -2,6 +2,18 @@ enable_testing()
include_directories(${PROJECT_SOURCE_DIR}/include)
if(ZAMALANG_PARALLEL_EXECUTION_ENABLED)
add_compile_options(
-DZAMALANG_PARALLEL_TESTING_ENABLED
)
link_libraries(
-Wl,-rpath,${CMAKE_BINARY_DIR}/lib/Runtime
-Wl,-rpath,${HPX_DIR}/../../
-Wl,--no-as-needed
DFRuntime
)
endif()
add_executable(
end_to_end_jit_test
end_to_end_jit_test.cc
@@ -84,8 +96,15 @@ if(ZAMALANG_PARALLEL_EXECUTION_ENABLED)
end_to_end_jit_dfr
end_to_end_jit_dfr.cc
)
add_executable(
end_to_end_jit_auto_parallelization
end_to_end_jit_auto_parallelization.cc
globals.cc
)
set_source_files_properties(
end_to_end_jit_dfr.cc
end_to_end_jit_auto_parallelization.cc
PROPERTIES COMPILE_FLAGS "-fno-rtti"
)
target_link_libraries(
@@ -97,5 +116,17 @@ if(ZAMALANG_PARALLEL_EXECUTION_ENABLED)
-Wl,--no-as-needed
DFRuntime
)
target_link_libraries(
end_to_end_jit_auto_parallelization
gtest_main
ZamalangSupport
-Wl,-rpath,${CMAKE_BINARY_DIR}/lib/Runtime
-Wl,-rpath,${HPX_DIR}/../../
-Wl,--no-as-needed
DFRuntime
)
gtest_discover_tests(end_to_end_jit_dfr)
gtest_discover_tests(end_to_end_jit_auto_parallelization)
endif()

View File

@@ -0,0 +1,63 @@
#include <cstdint>
#include <gtest/gtest.h>
#include <type_traits>
#include "end_to_end_jit_test.h"
///////////////////////////////////////////////////////////////////////////////
// Auto-parallelize independent HLFHE ops /////////////////////////////////////
///////////////////////////////////////////////////////////////////////////////
TEST(ParallelizeAndRunHLFHE, add_eint_tree) {
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
func @main(%arg0: !HLFHE.eint<7>, %arg1: !HLFHE.eint<7>, %arg2: !HLFHE.eint<7>, %arg3: !HLFHE.eint<7>) -> !HLFHE.eint<7> {
%1 = "HLFHE.add_eint"(%arg0, %arg1): (!HLFHE.eint<7>, !HLFHE.eint<7>) -> (!HLFHE.eint<7>)
%2 = "HLFHE.add_eint"(%arg0, %arg2): (!HLFHE.eint<7>, !HLFHE.eint<7>) -> (!HLFHE.eint<7>)
%3 = "HLFHE.add_eint"(%arg0, %arg3): (!HLFHE.eint<7>, !HLFHE.eint<7>) -> (!HLFHE.eint<7>)
%4 = "HLFHE.add_eint"(%arg1, %arg2): (!HLFHE.eint<7>, !HLFHE.eint<7>) -> (!HLFHE.eint<7>)
%5 = "HLFHE.add_eint"(%arg1, %arg3): (!HLFHE.eint<7>, !HLFHE.eint<7>) -> (!HLFHE.eint<7>)
%6 = "HLFHE.add_eint"(%arg2, %arg3): (!HLFHE.eint<7>, !HLFHE.eint<7>) -> (!HLFHE.eint<7>)
%7 = "HLFHE.add_eint"(%1, %2): (!HLFHE.eint<7>, !HLFHE.eint<7>) -> (!HLFHE.eint<7>)
%8 = "HLFHE.add_eint"(%1, %3): (!HLFHE.eint<7>, !HLFHE.eint<7>) -> (!HLFHE.eint<7>)
%9 = "HLFHE.add_eint"(%1, %4): (!HLFHE.eint<7>, !HLFHE.eint<7>) -> (!HLFHE.eint<7>)
%10 = "HLFHE.add_eint"(%1, %5): (!HLFHE.eint<7>, !HLFHE.eint<7>) -> (!HLFHE.eint<7>)
%11 = "HLFHE.add_eint"(%1, %6): (!HLFHE.eint<7>, !HLFHE.eint<7>) -> (!HLFHE.eint<7>)
%12 = "HLFHE.add_eint"(%2, %3): (!HLFHE.eint<7>, !HLFHE.eint<7>) -> (!HLFHE.eint<7>)
%13 = "HLFHE.add_eint"(%2, %4): (!HLFHE.eint<7>, !HLFHE.eint<7>) -> (!HLFHE.eint<7>)
%14 = "HLFHE.add_eint"(%2, %5): (!HLFHE.eint<7>, !HLFHE.eint<7>) -> (!HLFHE.eint<7>)
%15 = "HLFHE.add_eint"(%2, %6): (!HLFHE.eint<7>, !HLFHE.eint<7>) -> (!HLFHE.eint<7>)
%16 = "HLFHE.add_eint"(%3, %4): (!HLFHE.eint<7>, !HLFHE.eint<7>) -> (!HLFHE.eint<7>)
%17 = "HLFHE.add_eint"(%3, %5): (!HLFHE.eint<7>, !HLFHE.eint<7>) -> (!HLFHE.eint<7>)
%18 = "HLFHE.add_eint"(%3, %6): (!HLFHE.eint<7>, !HLFHE.eint<7>) -> (!HLFHE.eint<7>)
%19 = "HLFHE.add_eint"(%4, %5): (!HLFHE.eint<7>, !HLFHE.eint<7>) -> (!HLFHE.eint<7>)
%20 = "HLFHE.add_eint"(%4, %6): (!HLFHE.eint<7>, !HLFHE.eint<7>) -> (!HLFHE.eint<7>)
%21 = "HLFHE.add_eint"(%5, %6): (!HLFHE.eint<7>, !HLFHE.eint<7>) -> (!HLFHE.eint<7>)
%22 = "HLFHE.add_eint"(%7, %8): (!HLFHE.eint<7>, !HLFHE.eint<7>) -> (!HLFHE.eint<7>)
%23 = "HLFHE.add_eint"(%9, %10): (!HLFHE.eint<7>, !HLFHE.eint<7>) -> (!HLFHE.eint<7>)
%24 = "HLFHE.add_eint"(%11, %12): (!HLFHE.eint<7>, !HLFHE.eint<7>) -> (!HLFHE.eint<7>)
%25 = "HLFHE.add_eint"(%13, %14): (!HLFHE.eint<7>, !HLFHE.eint<7>) -> (!HLFHE.eint<7>)
%26 = "HLFHE.add_eint"(%15, %16): (!HLFHE.eint<7>, !HLFHE.eint<7>) -> (!HLFHE.eint<7>)
%27 = "HLFHE.add_eint"(%17, %18): (!HLFHE.eint<7>, !HLFHE.eint<7>) -> (!HLFHE.eint<7>)
%28 = "HLFHE.add_eint"(%19, %20): (!HLFHE.eint<7>, !HLFHE.eint<7>) -> (!HLFHE.eint<7>)
%29 = "HLFHE.add_eint"(%22, %23): (!HLFHE.eint<7>, !HLFHE.eint<7>) -> (!HLFHE.eint<7>)
%30 = "HLFHE.add_eint"(%24, %25): (!HLFHE.eint<7>, !HLFHE.eint<7>) -> (!HLFHE.eint<7>)
%31 = "HLFHE.add_eint"(%26, %27): (!HLFHE.eint<7>, !HLFHE.eint<7>) -> (!HLFHE.eint<7>)
%32 = "HLFHE.add_eint"(%21, %28): (!HLFHE.eint<7>, !HLFHE.eint<7>) -> (!HLFHE.eint<7>)
%33 = "HLFHE.add_eint"(%29, %30): (!HLFHE.eint<7>, !HLFHE.eint<7>) -> (!HLFHE.eint<7>)
%34 = "HLFHE.add_eint"(%31, %32): (!HLFHE.eint<7>, !HLFHE.eint<7>) -> (!HLFHE.eint<7>)
%35 = "HLFHE.add_eint"(%33, %34): (!HLFHE.eint<7>, !HLFHE.eint<7>) -> (!HLFHE.eint<7>)
return %35: !HLFHE.eint<7>
}
)XXX", "main", false, true);
ASSERT_EXPECTED_VALUE(lambda(1_u64, 2_u64, 3_u64, 4_u64), 150);
ASSERT_EXPECTED_VALUE(lambda(4_u64, 5_u64, 6_u64, 7_u64), 74);
ASSERT_EXPECTED_VALUE(lambda(1_u64, 1_u64, 1_u64, 1_u64), 60);
ASSERT_EXPECTED_VALUE(lambda(5_u64, 7_u64, 11_u64, 13_u64), 28);
}

View File

@@ -100,7 +100,8 @@ template <typename F>
mlir::zamalang::JitCompilerEngine::Lambda
internalCheckedJit(F checkFunc, llvm::StringRef src,
llvm::StringRef func = "main",
bool useDefaultFHEConstraints = false) {
bool useDefaultFHEConstraints = false,
bool autoParallelize = false) {
llvm::SmallString<0> cachePath;
@@ -116,6 +117,11 @@ internalCheckedJit(F checkFunc, llvm::StringRef src,
if (useDefaultFHEConstraints)
engine.setFHEConstraints(defaultV0Constraints);
#ifdef ZAMALANG_PARALLEL_TESTING_ENABLED
engine.setAutoParallelize(true);
#else
engine.setAutoParallelize(autoParallelize);
#endif
llvm::Expected<mlir::zamalang::JitCompilerEngine::Lambda> lambdaOrErr =
engine.buildLambda(src, func, optCache);