mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
test(compiler): Add tests for the HLFHELinalg tiling passes
This commit is contained in:
60
compiler/tests/Dialect/HLFHELinalg/tiling.mlir
Normal file
60
compiler/tests/Dialect/HLFHELinalg/tiling.mlir
Normal file
@@ -0,0 +1,60 @@
|
||||
// RUN: zamacompiler --action=dump-hlfhe %s 2>&1 --split-input-file | FileCheck %s
|
||||
|
||||
// CHECK: func @tiled_2x2(%[[Varg0:.*]]: tensor<8x4x!HLFHE.eint<6>>, %[[Varg1:.*]]: tensor<4x2xi7>) -> tensor<8x2x!HLFHE.eint<6>> {
|
||||
// CHECK-NEXT: %[[Vc2:.*]] = arith.constant 2 : index
|
||||
// CHECK-NEXT: %[[Vc0:.*]] = arith.constant 0 : index
|
||||
// CHECK-NEXT: %[[Vc8:.*]] = arith.constant 8 : index
|
||||
// CHECK-NEXT: %[[Vc4:.*]] = arith.constant 4 : index
|
||||
// CHECK-NEXT: %[[V0:.*]] = "HLFHELinalg.zero"() : () -> tensor<8x2x!HLFHE.eint<6>>
|
||||
// CHECK-NEXT: %[[V1:.*]] = scf.for %[[Varg2:.*]] = %[[Vc0]] to %[[Vc8]] step %[[Vc2]] iter_args(%[[Varg3:.*]] = %[[V0]]) -> (tensor<8x2x!HLFHE.eint<6>>) {
|
||||
// CHECK-NEXT: %[[V2:.*]] = scf.for %[[Varg4:.*]] = %[[Vc0]] to %[[Vc4]] step %[[Vc2]] iter_args(%[[Varg5:.*]] = %[[Varg3]]) -> (tensor<8x2x!HLFHE.eint<6>>) {
|
||||
// CHECK-NEXT: %[[V3:.*]] = scf.for %[[Varg6:.*]] = %[[Vc0]] to %[[Vc2]] step %[[Vc2]] iter_args(%[[Varg7:.*]] = %[[Varg5]]) -> (tensor<8x2x!HLFHE.eint<6>>) {
|
||||
// CHECK-NEXT: %[[V4:.*]] = tensor.extract_slice %[[Varg0]][%[[Varg2]], %[[Varg4]]] [2, 2] [1, 1] : tensor<8x4x!HLFHE.eint<6>> to tensor<2x2x!HLFHE.eint<6>>
|
||||
// CHECK-NEXT: %[[V5:.*]] = tensor.extract_slice %[[Varg1]][%[[Varg4]], %[[Varg6]]] [2, 2] [1, 1] : tensor<4x2xi7> to tensor<2x2xi7>
|
||||
// CHECK-NEXT: %[[V6:.*]] = tensor.extract_slice %[[Varg7]][%[[Varg2]], %[[Varg6]]] [2, 2] [1, 1] : tensor<8x2x!HLFHE.eint<6>> to tensor<2x2x!HLFHE.eint<6>>
|
||||
// CHECK-NEXT: %[[V7:.*]] = "HLFHELinalg.matmul_eint_int"(%[[V4]], %[[V5]]) : (tensor<2x2x!HLFHE.eint<6>>, tensor<2x2xi7>) -> tensor<2x2x!HLFHE.eint<6>>
|
||||
// CHECK-NEXT: %[[V8:.*]] = "HLFHELinalg.add_eint"(%[[V6]], %[[V7]]) : (tensor<2x2x!HLFHE.eint<6>>, tensor<2x2x!HLFHE.eint<6>>) -> tensor<2x2x!HLFHE.eint<6>>
|
||||
// CHECK-NEXT: %[[V9:.*]] = tensor.insert_slice %[[V8]] into %[[Varg7]][%[[Varg2]], %[[Varg6]]] [2, 2] [1, 1] : tensor<2x2x!HLFHE.eint<6>> into tensor<8x2x!HLFHE.eint<6>>
|
||||
// CHECK-NEXT: scf.yield %[[V9]] : tensor<8x2x!HLFHE.eint<6>>
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: scf.yield %[[V3]] : tensor<8x2x!HLFHE.eint<6>>
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: scf.yield %[[V2]] : tensor<8x2x!HLFHE.eint<6>>
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: return %[[V1]] : tensor<8x2x!HLFHE.eint<6>>
|
||||
// CHECK-NEXT: }
|
||||
func @tiled_2x2(%a: tensor<8x4x!HLFHE.eint<6>>, %b: tensor<4x2xi7>) -> tensor<8x2x!HLFHE.eint<6>> {
|
||||
%0 = "HLFHELinalg.matmul_eint_int"(%a, %b) { "tile-sizes" = [2,2,2] } : (tensor<8x4x!HLFHE.eint<6>>, tensor<4x2xi7>) -> tensor<8x2x!HLFHE.eint<6>>
|
||||
return %0 : tensor<8x2x!HLFHE.eint<6>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK: func @tiled_one_big_tile(%[[Varg0:.*]]: tensor<8x4x!HLFHE.eint<6>>, %[[Varg1:.*]]: tensor<4x2xi7>) -> tensor<8x2x!HLFHE.eint<6>> {
|
||||
// CHECK-NEXT: %[[Vc8:.*]] = arith.constant 8 : index
|
||||
// CHECK-NEXT: %[[Vc4:.*]] = arith.constant 4 : index
|
||||
// CHECK-NEXT: %[[Vc2:.*]] = arith.constant 2 : index
|
||||
// CHECK-NEXT: %[[Vc0:.*]] = arith.constant 0 : index
|
||||
// CHECK-NEXT: %[[V0:.*]] = "HLFHELinalg.zero"() : () -> tensor<8x2x!HLFHE.eint<6>>
|
||||
// CHECK-NEXT: %[[V1:.*]] = scf.for %[[Varg2:.*]] = %[[Vc0]] to %[[Vc8]] step %[[Vc8]] iter_args(%[[Varg3:.*]] = %[[V0]]) -> (tensor<8x2x!HLFHE.eint<6>>) {
|
||||
// CHECK-NEXT: %[[V2:.*]] = scf.for %[[Varg4:.*]] = %[[Vc0]] to %[[Vc4]] step %[[Vc4]] iter_args(%[[Varg5:.*]] = %[[Varg3]]) -> (tensor<8x2x!HLFHE.eint<6>>) {
|
||||
// CHECK-NEXT: %[[V3:.*]] = scf.for %[[Varg6:.*]] = %[[Vc0]] to %[[Vc2]] step %[[Vc2]] iter_args(%[[Varg7:.*]] = %[[Varg5]]) -> (tensor<8x2x!HLFHE.eint<6>>) {
|
||||
// CHECK-NEXT: %[[V4:.*]] = tensor.extract_slice %[[Varg0]][%[[Varg2]], %[[Varg4]]] [8, 4] [1, 1] : tensor<8x4x!HLFHE.eint<6>> to tensor<8x4x!HLFHE.eint<6>>
|
||||
// CHECK-NEXT: %[[V5:.*]] = tensor.extract_slice %[[Varg1]][%[[Varg4]], %[[Varg6]]] [4, 2] [1, 1] : tensor<4x2xi7> to tensor<4x2xi7>
|
||||
// CHECK-NEXT: %[[V6:.*]] = tensor.extract_slice %[[Varg7]][%[[Varg2]], %[[Varg6]]] [8, 2] [1, 1] : tensor<8x2x!HLFHE.eint<6>> to tensor<8x2x!HLFHE.eint<6>>
|
||||
// CHECK-NEXT: %[[V7:.*]] = "HLFHELinalg.matmul_eint_int"(%[[V4]], %[[V5]]) : (tensor<8x4x!HLFHE.eint<6>>, tensor<4x2xi7>) -> tensor<8x2x!HLFHE.eint<6>>
|
||||
// CHECK-NEXT: %[[V8:.*]] = "HLFHELinalg.add_eint"(%[[V6]], %[[V7]]) : (tensor<8x2x!HLFHE.eint<6>>, tensor<8x2x!HLFHE.eint<6>>) -> tensor<8x2x!HLFHE.eint<6>>
|
||||
// CHECK-NEXT: %[[V9:.*]] = tensor.insert_slice %[[V8]] into %[[Varg7]][%[[Varg2]], %[[Varg6]]] [8, 2] [1, 1] : tensor<8x2x!HLFHE.eint<6>> into tensor<8x2x!HLFHE.eint<6>>
|
||||
// CHECK-NEXT: scf.yield %[[V9]] : tensor<8x2x!HLFHE.eint<6>>
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: scf.yield %[[V3]] : tensor<8x2x!HLFHE.eint<6>>
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: scf.yield %[[V2]] : tensor<8x2x!HLFHE.eint<6>>
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: return %[[V1]] : tensor<8x2x!HLFHE.eint<6>>
|
||||
// CHECK-NEXT: }
|
||||
func @tiled_one_big_tile(%a: tensor<8x4x!HLFHE.eint<6>>, %b: tensor<4x2xi7>) -> tensor<8x2x!HLFHE.eint<6>> {
|
||||
%0 = "HLFHELinalg.matmul_eint_int"(%a, %b) { "tile-sizes" = [8,4,2] } : (tensor<8x4x!HLFHE.eint<6>>, tensor<4x2xi7>) -> tensor<8x2x!HLFHE.eint<6>>
|
||||
return %0 : tensor<8x2x!HLFHE.eint<6>>
|
||||
}
|
||||
|
||||
@@ -1470,3 +1470,78 @@ func @main() -> tensor<2x2x4x!HLFHE.eint<6>> {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
class TiledMatMulParametric
|
||||
: public ::testing::TestWithParam<std::vector<int64_t>> {};
|
||||
|
||||
TEST_P(TiledMatMulParametric, tiled_matmul_eint_int) {
|
||||
std::vector<int64_t> tiling = GetParam();
|
||||
std::ostringstream mlirProgram;
|
||||
|
||||
mlirProgram
|
||||
<< "func @main(%a: tensor<8x4x!HLFHE.eint<6>>, %b: tensor<4x2xi7>) ->\n"
|
||||
<< " tensor<8x2x!HLFHE.eint<6>> {\n"
|
||||
<< " %0 = \"HLFHELinalg.matmul_eint_int\"(%a, %b) { \"tile-sizes\" = "
|
||||
<< "[" << tiling[0] << ", " << tiling[1] << ", " << tiling[2] << "]"
|
||||
<< "} :\n"
|
||||
<< " (tensor<8x4x!HLFHE.eint<6>>, tensor<4x2xi7>) ->\n"
|
||||
<< " tensor<8x2x!HLFHE.eint<6>>\n"
|
||||
<< " return %0 : tensor<8x2x!HLFHE.eint<6>>\n"
|
||||
<< " }";
|
||||
|
||||
mlir::zamalang::JitCompilerEngine::Lambda lambda =
|
||||
checkedJit(mlirProgram.str());
|
||||
|
||||
const size_t rowsA = 8;
|
||||
const size_t colsA = 4;
|
||||
const uint8_t A[rowsA][colsA] = {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 0, 1, 2},
|
||||
{3, 4, 5, 6}, {7, 8, 9, 0}, {1, 2, 3, 4},
|
||||
{5, 6, 7, 8}, {9, 0, 1, 2}};
|
||||
|
||||
const size_t rowsB = 4;
|
||||
const size_t colsB = 2;
|
||||
const uint8_t B[rowsB][colsB]{{1, 2}, {3, 4}, {3, 1}, {0, 2}};
|
||||
|
||||
const size_t rowsC = rowsA;
|
||||
const size_t colsC = colsB;
|
||||
const uint8_t expected[rowsC][colsC]{
|
||||
{16, 21}, {44, 57}, {12, 23}, {30, 39},
|
||||
{58, 55}, {16, 21}, {44, 57}, {12, 23},
|
||||
};
|
||||
|
||||
mlir::zamalang::TensorLambdaArgument<
|
||||
mlir::zamalang::IntLambdaArgument<uint8_t>>
|
||||
aArg(llvm::ArrayRef<uint8_t>((const uint8_t *)A, rowsA * colsA),
|
||||
{rowsA, colsA});
|
||||
mlir::zamalang::TensorLambdaArgument<
|
||||
mlir::zamalang::IntLambdaArgument<uint8_t>>
|
||||
bArg(llvm::ArrayRef<uint8_t>((const uint8_t *)B, rowsB * colsB),
|
||||
{rowsB, colsB});
|
||||
|
||||
llvm::Expected<std::vector<uint64_t>> res =
|
||||
lambda.operator()<std::vector<uint64_t>>({&aArg, &bArg});
|
||||
|
||||
ASSERT_EXPECTED_SUCCESS(res);
|
||||
|
||||
ASSERT_EQ(res->size(), (uint64_t)(rowsC * colsC));
|
||||
|
||||
for (size_t i = 0; i < rowsC; i++) {
|
||||
for (size_t j = 0; j < colsC; j++) {
|
||||
EXPECT_EQ((*res)[i * colsC + j], expected[i][j])
|
||||
<< ", at pos(" << i << "," << j << ")";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TiledMatMul, TiledMatMulParametric,
|
||||
::testing::Values(
|
||||
// Element-sized tiles
|
||||
std::vector<int64_t>{1, 1, 1},
|
||||
|
||||
// Mixed tiles
|
||||
std::vector<int64_t>{2, 2, 2},
|
||||
std::vector<int64_t>{4, 4, 2},
|
||||
std::vector<int64_t>{2, 4, 2},
|
||||
|
||||
// Single, big tile
|
||||
std::vector<int64_t>{8, 4, 2}));
|
||||
|
||||
Reference in New Issue
Block a user