mirror of
https://github.com/zama-ai/concrete.git
synced 2026-04-17 03:00:54 -04:00
test(compiler): add unit tests for dataflow auto parallelization.
This commit is contained in:
@@ -50,7 +50,7 @@ test-python: python-bindings zamacompiler
|
||||
|
||||
test: test-check test-end-to-end-jit test-python
|
||||
|
||||
test-dataflow: test-end-to-end-jit-dfr
|
||||
test-dataflow: test-end-to-end-jit-dfr test-end-to-end-jit-auto-parallelization
|
||||
|
||||
# Unittests
|
||||
|
||||
@@ -72,6 +72,9 @@ build-end-to-end-jit-lambda: build-initialized
|
||||
build-end-to-end-jit-dfr: build-initialized
|
||||
cmake --build $(BUILD_DIR) --target end_to_end_jit_dfr
|
||||
|
||||
build-end-to-end-jit-auto-parallelization: build-initialized
|
||||
cmake --build $(BUILD_DIR) --target end_to_end_jit_auto_parallelization
|
||||
|
||||
build-end-to-end-jit: build-end-to-end-jit-test build-end-to-end-jit-clear-tensor build-end-to-end-jit-encrypted-tensor build-end-to-end-jit-hlfhelinalg
|
||||
|
||||
|
||||
@@ -93,8 +96,12 @@ test-end-to-end-jit-lambda: build-initialized build-end-to-end-jit-lambda
|
||||
test-end-to-end-jit-dfr: build-end-to-end-jit-dfr
|
||||
$(BUILD_DIR)/bin/end_to_end_jit_dfr
|
||||
|
||||
test-end-to-end-jit-auto-parallelization: build-end-to-end-jit-auto-parallelization
|
||||
$(BUILD_DIR)/bin/end_to_end_jit_auto_parallelization
|
||||
|
||||
test-end-to-end-jit: test-end-to-end-jit-test test-end-to-end-jit-clear-tensor test-end-to-end-jit-encrypted-tensor test-end-to-end-jit-hlfhelinalg
|
||||
|
||||
|
||||
# LLVM/MLIR dependencies
|
||||
|
||||
all-deps: file-check not
|
||||
|
||||
@@ -2,6 +2,18 @@ enable_testing()
|
||||
|
||||
include_directories(${PROJECT_SOURCE_DIR}/include)
|
||||
|
||||
if(ZAMALANG_PARALLEL_EXECUTION_ENABLED)
|
||||
add_compile_options(
|
||||
-DZAMALANG_PARALLEL_TESTING_ENABLED
|
||||
)
|
||||
link_libraries(
|
||||
-Wl,-rpath,${CMAKE_BINARY_DIR}/lib/Runtime
|
||||
-Wl,-rpath,${HPX_DIR}/../../
|
||||
-Wl,--no-as-needed
|
||||
DFRuntime
|
||||
)
|
||||
endif()
|
||||
|
||||
add_executable(
|
||||
end_to_end_jit_test
|
||||
end_to_end_jit_test.cc
|
||||
@@ -84,8 +96,15 @@ if(ZAMALANG_PARALLEL_EXECUTION_ENABLED)
|
||||
end_to_end_jit_dfr
|
||||
end_to_end_jit_dfr.cc
|
||||
)
|
||||
add_executable(
|
||||
end_to_end_jit_auto_parallelization
|
||||
end_to_end_jit_auto_parallelization.cc
|
||||
globals.cc
|
||||
)
|
||||
|
||||
set_source_files_properties(
|
||||
end_to_end_jit_dfr.cc
|
||||
end_to_end_jit_auto_parallelization.cc
|
||||
PROPERTIES COMPILE_FLAGS "-fno-rtti"
|
||||
)
|
||||
target_link_libraries(
|
||||
@@ -97,5 +116,17 @@ if(ZAMALANG_PARALLEL_EXECUTION_ENABLED)
|
||||
-Wl,--no-as-needed
|
||||
DFRuntime
|
||||
)
|
||||
|
||||
target_link_libraries(
|
||||
end_to_end_jit_auto_parallelization
|
||||
gtest_main
|
||||
ZamalangSupport
|
||||
-Wl,-rpath,${CMAKE_BINARY_DIR}/lib/Runtime
|
||||
-Wl,-rpath,${HPX_DIR}/../../
|
||||
-Wl,--no-as-needed
|
||||
DFRuntime
|
||||
)
|
||||
|
||||
gtest_discover_tests(end_to_end_jit_dfr)
|
||||
gtest_discover_tests(end_to_end_jit_auto_parallelization)
|
||||
endif()
|
||||
|
||||
@@ -0,0 +1,63 @@
|
||||
|
||||
#include <cstdint>
|
||||
#include <gtest/gtest.h>
|
||||
#include <type_traits>
|
||||
|
||||
#include "end_to_end_jit_test.h"
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Auto-parallelize independent HLFHE ops /////////////////////////////////////
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(ParallelizeAndRunHLFHE, add_eint_tree) {
|
||||
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
|
||||
func @main(%arg0: !HLFHE.eint<7>, %arg1: !HLFHE.eint<7>, %arg2: !HLFHE.eint<7>, %arg3: !HLFHE.eint<7>) -> !HLFHE.eint<7> {
|
||||
%1 = "HLFHE.add_eint"(%arg0, %arg1): (!HLFHE.eint<7>, !HLFHE.eint<7>) -> (!HLFHE.eint<7>)
|
||||
%2 = "HLFHE.add_eint"(%arg0, %arg2): (!HLFHE.eint<7>, !HLFHE.eint<7>) -> (!HLFHE.eint<7>)
|
||||
%3 = "HLFHE.add_eint"(%arg0, %arg3): (!HLFHE.eint<7>, !HLFHE.eint<7>) -> (!HLFHE.eint<7>)
|
||||
%4 = "HLFHE.add_eint"(%arg1, %arg2): (!HLFHE.eint<7>, !HLFHE.eint<7>) -> (!HLFHE.eint<7>)
|
||||
%5 = "HLFHE.add_eint"(%arg1, %arg3): (!HLFHE.eint<7>, !HLFHE.eint<7>) -> (!HLFHE.eint<7>)
|
||||
%6 = "HLFHE.add_eint"(%arg2, %arg3): (!HLFHE.eint<7>, !HLFHE.eint<7>) -> (!HLFHE.eint<7>)
|
||||
|
||||
%7 = "HLFHE.add_eint"(%1, %2): (!HLFHE.eint<7>, !HLFHE.eint<7>) -> (!HLFHE.eint<7>)
|
||||
%8 = "HLFHE.add_eint"(%1, %3): (!HLFHE.eint<7>, !HLFHE.eint<7>) -> (!HLFHE.eint<7>)
|
||||
%9 = "HLFHE.add_eint"(%1, %4): (!HLFHE.eint<7>, !HLFHE.eint<7>) -> (!HLFHE.eint<7>)
|
||||
%10 = "HLFHE.add_eint"(%1, %5): (!HLFHE.eint<7>, !HLFHE.eint<7>) -> (!HLFHE.eint<7>)
|
||||
%11 = "HLFHE.add_eint"(%1, %6): (!HLFHE.eint<7>, !HLFHE.eint<7>) -> (!HLFHE.eint<7>)
|
||||
%12 = "HLFHE.add_eint"(%2, %3): (!HLFHE.eint<7>, !HLFHE.eint<7>) -> (!HLFHE.eint<7>)
|
||||
%13 = "HLFHE.add_eint"(%2, %4): (!HLFHE.eint<7>, !HLFHE.eint<7>) -> (!HLFHE.eint<7>)
|
||||
%14 = "HLFHE.add_eint"(%2, %5): (!HLFHE.eint<7>, !HLFHE.eint<7>) -> (!HLFHE.eint<7>)
|
||||
%15 = "HLFHE.add_eint"(%2, %6): (!HLFHE.eint<7>, !HLFHE.eint<7>) -> (!HLFHE.eint<7>)
|
||||
%16 = "HLFHE.add_eint"(%3, %4): (!HLFHE.eint<7>, !HLFHE.eint<7>) -> (!HLFHE.eint<7>)
|
||||
%17 = "HLFHE.add_eint"(%3, %5): (!HLFHE.eint<7>, !HLFHE.eint<7>) -> (!HLFHE.eint<7>)
|
||||
%18 = "HLFHE.add_eint"(%3, %6): (!HLFHE.eint<7>, !HLFHE.eint<7>) -> (!HLFHE.eint<7>)
|
||||
%19 = "HLFHE.add_eint"(%4, %5): (!HLFHE.eint<7>, !HLFHE.eint<7>) -> (!HLFHE.eint<7>)
|
||||
%20 = "HLFHE.add_eint"(%4, %6): (!HLFHE.eint<7>, !HLFHE.eint<7>) -> (!HLFHE.eint<7>)
|
||||
%21 = "HLFHE.add_eint"(%5, %6): (!HLFHE.eint<7>, !HLFHE.eint<7>) -> (!HLFHE.eint<7>)
|
||||
|
||||
%22 = "HLFHE.add_eint"(%7, %8): (!HLFHE.eint<7>, !HLFHE.eint<7>) -> (!HLFHE.eint<7>)
|
||||
%23 = "HLFHE.add_eint"(%9, %10): (!HLFHE.eint<7>, !HLFHE.eint<7>) -> (!HLFHE.eint<7>)
|
||||
%24 = "HLFHE.add_eint"(%11, %12): (!HLFHE.eint<7>, !HLFHE.eint<7>) -> (!HLFHE.eint<7>)
|
||||
%25 = "HLFHE.add_eint"(%13, %14): (!HLFHE.eint<7>, !HLFHE.eint<7>) -> (!HLFHE.eint<7>)
|
||||
%26 = "HLFHE.add_eint"(%15, %16): (!HLFHE.eint<7>, !HLFHE.eint<7>) -> (!HLFHE.eint<7>)
|
||||
%27 = "HLFHE.add_eint"(%17, %18): (!HLFHE.eint<7>, !HLFHE.eint<7>) -> (!HLFHE.eint<7>)
|
||||
%28 = "HLFHE.add_eint"(%19, %20): (!HLFHE.eint<7>, !HLFHE.eint<7>) -> (!HLFHE.eint<7>)
|
||||
|
||||
%29 = "HLFHE.add_eint"(%22, %23): (!HLFHE.eint<7>, !HLFHE.eint<7>) -> (!HLFHE.eint<7>)
|
||||
%30 = "HLFHE.add_eint"(%24, %25): (!HLFHE.eint<7>, !HLFHE.eint<7>) -> (!HLFHE.eint<7>)
|
||||
%31 = "HLFHE.add_eint"(%26, %27): (!HLFHE.eint<7>, !HLFHE.eint<7>) -> (!HLFHE.eint<7>)
|
||||
%32 = "HLFHE.add_eint"(%21, %28): (!HLFHE.eint<7>, !HLFHE.eint<7>) -> (!HLFHE.eint<7>)
|
||||
|
||||
%33 = "HLFHE.add_eint"(%29, %30): (!HLFHE.eint<7>, !HLFHE.eint<7>) -> (!HLFHE.eint<7>)
|
||||
%34 = "HLFHE.add_eint"(%31, %32): (!HLFHE.eint<7>, !HLFHE.eint<7>) -> (!HLFHE.eint<7>)
|
||||
|
||||
%35 = "HLFHE.add_eint"(%33, %34): (!HLFHE.eint<7>, !HLFHE.eint<7>) -> (!HLFHE.eint<7>)
|
||||
return %35: !HLFHE.eint<7>
|
||||
}
|
||||
)XXX", "main", false, true);
|
||||
|
||||
ASSERT_EXPECTED_VALUE(lambda(1_u64, 2_u64, 3_u64, 4_u64), 150);
|
||||
ASSERT_EXPECTED_VALUE(lambda(4_u64, 5_u64, 6_u64, 7_u64), 74);
|
||||
ASSERT_EXPECTED_VALUE(lambda(1_u64, 1_u64, 1_u64, 1_u64), 60);
|
||||
ASSERT_EXPECTED_VALUE(lambda(5_u64, 7_u64, 11_u64, 13_u64), 28);
|
||||
}
|
||||
@@ -100,7 +100,8 @@ template <typename F>
|
||||
mlir::zamalang::JitCompilerEngine::Lambda
|
||||
internalCheckedJit(F checkFunc, llvm::StringRef src,
|
||||
llvm::StringRef func = "main",
|
||||
bool useDefaultFHEConstraints = false) {
|
||||
bool useDefaultFHEConstraints = false,
|
||||
bool autoParallelize = false) {
|
||||
|
||||
llvm::SmallString<0> cachePath;
|
||||
|
||||
@@ -116,6 +117,11 @@ internalCheckedJit(F checkFunc, llvm::StringRef src,
|
||||
|
||||
if (useDefaultFHEConstraints)
|
||||
engine.setFHEConstraints(defaultV0Constraints);
|
||||
#ifdef ZAMALANG_PARALLEL_TESTING_ENABLED
|
||||
engine.setAutoParallelize(true);
|
||||
#else
|
||||
engine.setAutoParallelize(autoParallelize);
|
||||
#endif
|
||||
|
||||
llvm::Expected<mlir::zamalang::JitCompilerEngine::Lambda> lambdaOrErr =
|
||||
engine.buildLambda(src, func, optCache);
|
||||
|
||||
Reference in New Issue
Block a user