test(parallelization): add small NN test for auto parallelization.
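The diff below does two things: it threads a new fourth boolean argument through checkedJit (the pre-existing scalar test now passes false, the new nn_small_parallel test passes true, and nn_small_sequential passes false, so the new flag presumably toggles the dataflow auto-parallelization being tested), and it adds a pair of small neural-network tests that run the same FHE program with and without parallelization and compare results. For orientation, here is a hypothetical plaintext sketch of the layer the new tests compute; the function name and parameter names are mine, and it assumes FHELinalg.matmul_eint_int, add_eint_int, and apply_lookup_table have plain matmul, element-wise add, and table-indexing semantics:

#include <array>
#include <cstdint>

// Plaintext reference (assumed semantics) for the 4x5 -> 4x7 layer below:
// y = lut[(x . W + b) mod 32], with every eint<5> value wrapping mod 2^5.
std::array<uint64_t, 4 * 7>
referenceLayer(const std::array<uint8_t, 4 * 5> &x,
               const std::array<uint8_t, 5 * 7> &w,  // %cst_0 in the MLIR
               const std::array<uint8_t, 4 * 7> &b,  // %cst
               const std::array<uint64_t, 32> &lut) { // %cst_1
  std::array<uint64_t, 4 * 7> y{};
  for (int i = 0; i < 4; ++i)
    for (int j = 0; j < 7; ++j) {
      uint64_t acc = b[i * 7 + j]; // bias add
      for (int k = 0; k < 5; ++k)
        acc += uint64_t(x[i * 5 + k]) * w[k * 7 + j]; // matmul row . column
      y[i * 7 + j] = lut[acc % 32] % 32; // LUT activation, result is eint<5>
    }
  return y;
}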
@@ -1,6 +1,8 @@
+#include <concretelang/Runtime/DFRuntime.hpp>
+
 #include <cstdint>
 #include <gtest/gtest.h>
 #include <tuple>
 #include <type_traits>
 
 #include "end_to_end_jit_test.h"
@@ -55,10 +57,108 @@ func.func @main(%arg0: !FHE.eint<7>, %arg1: !FHE.eint<7>, %arg2: !FHE.eint<7>, %
   return %35: !FHE.eint<7>
 }
 )XXX",
-             "main", false, true);
+             "main", false, true, false);
 
-  ASSERT_EXPECTED_VALUE(lambda(1_u64, 2_u64, 3_u64, 4_u64), 150);
-  ASSERT_EXPECTED_VALUE(lambda(4_u64, 5_u64, 6_u64, 7_u64), 74);
-  ASSERT_EXPECTED_VALUE(lambda(1_u64, 1_u64, 1_u64, 1_u64), 60);
-  ASSERT_EXPECTED_VALUE(lambda(5_u64, 7_u64, 11_u64, 13_u64), 28);
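+  // Under the dataflow runtime only the root node materializes results;
+  // on every other node the call is expected to fail.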
+  if (_dfr_is_root_node()) {
+    llvm::Expected<uint64_t> res_1 = lambda(1_u64, 2_u64, 3_u64, 4_u64);
+    llvm::Expected<uint64_t> res_2 = lambda(4_u64, 5_u64, 6_u64, 7_u64);
+    llvm::Expected<uint64_t> res_3 = lambda(1_u64, 1_u64, 1_u64, 1_u64);
+    llvm::Expected<uint64_t> res_4 = lambda(5_u64, 7_u64, 11_u64, 13_u64);
+    ASSERT_EXPECTED_SUCCESS(res_1);
+    ASSERT_EXPECTED_SUCCESS(res_2);
+    ASSERT_EXPECTED_SUCCESS(res_3);
+    ASSERT_EXPECTED_SUCCESS(res_4);
+    ASSERT_EXPECTED_VALUE(res_1, 150);
+    ASSERT_EXPECTED_VALUE(res_2, 74);
+    ASSERT_EXPECTED_VALUE(res_3, 60);
+    ASSERT_EXPECTED_VALUE(res_4, 28);
+  } else {
+    ASSERT_EXPECTED_FAILURE(lambda(1_u64, 2_u64, 3_u64, 4_u64));
+    ASSERT_EXPECTED_FAILURE(lambda(4_u64, 5_u64, 6_u64, 7_u64));
+    ASSERT_EXPECTED_FAILURE(lambda(1_u64, 1_u64, 1_u64, 1_u64));
+    ASSERT_EXPECTED_FAILURE(lambda(5_u64, 7_u64, 11_u64, 13_u64));
+  }
 }
+
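+// Written by nn_small_parallel and read by nn_small_sequential, so the
+// sequential run can be checked against the parallel one.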
+std::vector<uint64_t> parallel_results;
+
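+// Small one-layer NN on a batch of 4: matmul with constant 5x7 weights,
+// bias add, then a 32-entry lookup table as activation.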
+TEST(ParallelizeAndRunFHE, nn_small_parallel) {
+  checkedJit(lambda, R"XXX(
+func @main(%arg0: tensor<4x5x!FHE.eint<5>>) -> tensor<4x7x!FHE.eint<5>> {
+  %cst = arith.constant dense<[[0, 0, 1, 0, 1, 1, 0], [1, 1, 1, 0, 1, 0, 0], [1, 1, 0, 0, 0, 0, 0], [0, 0, 0, 0, 1, 1, 1]]> : tensor<4x7xi6>
+  %cst_0 = arith.constant dense<[[1, 0, 1, 1, 0, 1, 1], [0, 1, 0, 0, 0, 0, 1], [0, 1, 1, 1, 1, 0, 0], [0, 1, 1, 0, 0, 0, 0], [0, 1, 1, 0, 0, 0, 1]]> : tensor<5x7xi6>
+  %0 = "FHELinalg.matmul_eint_int"(%arg0, %cst_0) : (tensor<4x5x!FHE.eint<5>>, tensor<5x7xi6>) -> tensor<4x7x!FHE.eint<5>>
+  %1 = "FHELinalg.add_eint_int"(%0, %cst) : (tensor<4x7x!FHE.eint<5>>, tensor<4x7xi6>) -> tensor<4x7x!FHE.eint<5>>
+  %cst_1 = arith.constant dense<[0, 3, 7, 10, 14, 17, 21, 24, 28, 31, 35, 38, 42, 45, 49, 52, 56, 59, 63, 66, 70, 73, 77, 80, 84, 87, 91, 94, 98, 101, 105, 108]> : tensor<32xi64>
+  %2 = "FHELinalg.apply_lookup_table"(%1, %cst_1) : (tensor<4x7x!FHE.eint<5>>, tensor<32xi64>) -> tensor<4x7x!FHE.eint<5>>
+  return %2 : tensor<4x7x!FHE.eint<5>>
+}
+)XXX",
+             "main", false, true, true);
+
+  const size_t numDim = 2;
+  const size_t dim0 = 4;
+  const size_t dim1 = 5;
+  const size_t dim2 = 7;
+  const int64_t dims[numDim]{dim0, dim1};
+  const llvm::ArrayRef<int64_t> shape2D(dims, numDim);
+  std::vector<uint8_t> input;
+  input.reserve(dim0 * dim1);
+
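+  // Values in [0, 3] (i % 17 % 4): with 0/1 weights the accumulated
+  // matmul + bias stays at most 16, inside the 5-bit message space.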
+  for (int i = 0; i < dim0 * dim1; ++i)
+    input.push_back(i % 17 % 4);
+
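+  // Wrap the flat buffer as a 4x5 tensor argument for the JIT lambda.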
+  mlir::concretelang::TensorLambdaArgument<
+      mlir::concretelang::IntLambdaArgument<uint8_t>>
+      arg(input, shape2D);
+
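+  // As in the scalar test above, only the root node receives results.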
+  if (_dfr_is_root_node()) {
+    llvm::Expected<std::vector<uint64_t>> res =
+        lambda.operator()<std::vector<uint64_t>>({&arg});
+    ASSERT_EXPECTED_SUCCESS(res);
+    ASSERT_EQ(res->size(), dim0 * dim2);
+    parallel_results = *res;
+  } else {
+    ASSERT_EXPECTED_FAILURE(lambda.operator()<std::vector<uint64_t>>({&arg}));
+  }
+}
+
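+// Same program and inputs as nn_small_parallel, compiled with all
+// parallelization flags off; its output must match the parallel run.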
+TEST(ParallelizeAndRunFHE, nn_small_sequential) {
+  checkedJit(lambda, R"XXX(
+func @main(%arg0: tensor<4x5x!FHE.eint<5>>) -> tensor<4x7x!FHE.eint<5>> {
+  %cst = arith.constant dense<[[0, 0, 1, 0, 1, 1, 0], [1, 1, 1, 0, 1, 0, 0], [1, 1, 0, 0, 0, 0, 0], [0, 0, 0, 0, 1, 1, 1]]> : tensor<4x7xi6>
+  %cst_0 = arith.constant dense<[[1, 0, 1, 1, 0, 1, 1], [0, 1, 0, 0, 0, 0, 1], [0, 1, 1, 1, 1, 0, 0], [0, 1, 1, 0, 0, 0, 0], [0, 1, 1, 0, 0, 0, 1]]> : tensor<5x7xi6>
+  %0 = "FHELinalg.matmul_eint_int"(%arg0, %cst_0) : (tensor<4x5x!FHE.eint<5>>, tensor<5x7xi6>) -> tensor<4x7x!FHE.eint<5>>
+  %1 = "FHELinalg.add_eint_int"(%0, %cst) : (tensor<4x7x!FHE.eint<5>>, tensor<4x7xi6>) -> tensor<4x7x!FHE.eint<5>>
+  %cst_1 = arith.constant dense<[0, 3, 7, 10, 14, 17, 21, 24, 28, 31, 35, 38, 42, 45, 49, 52, 56, 59, 63, 66, 70, 73, 77, 80, 84, 87, 91, 94, 98, 101, 105, 108]> : tensor<32xi64>
+  %2 = "FHELinalg.apply_lookup_table"(%1, %cst_1) : (tensor<4x7x!FHE.eint<5>>, tensor<32xi64>) -> tensor<4x7x!FHE.eint<5>>
+  return %2 : tensor<4x7x!FHE.eint<5>>
+}
+)XXX",
+             "main", false, false, false);
+
+  const size_t numDim = 2;
+  const size_t dim0 = 4;
+  const size_t dim1 = 5;
+  const size_t dim2 = 7;
+  const int64_t dims[numDim]{dim0, dim1};
+  const llvm::ArrayRef<int64_t> shape2D(dims, numDim);
+  std::vector<uint8_t> input;
+  input.reserve(dim0 * dim1);
+
+  for (int i = 0; i < dim0 * dim1; ++i)
+    input.push_back(i % 17 % 4);
+
+  mlir::concretelang::TensorLambdaArgument<
+      mlir::concretelang::IntLambdaArgument<uint8_t>>
+      arg(input, shape2D);
+
+  // This is sequential: only execute on root node.
+  if (_dfr_is_root_node()) {
+    llvm::Expected<std::vector<uint64_t>> res =
+        lambda.operator()<std::vector<uint64_t>>({&arg});
+    ASSERT_EXPECTED_SUCCESS(res);
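+    // Compare element-wise with the decrypted values recorded by the
+    // parallel test; both runs must produce identical results.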
+    for (size_t i = 0; i < dim0 * dim2; i++)
+      EXPECT_EQ(parallel_results[i], (*res)[i]) << "results differ at pos " << i;
+  }
+}