diff --git a/compiler/tests/Dialect/HLFHELinalg/tiling.mlir b/compiler/tests/Dialect/HLFHELinalg/tiling.mlir
new file mode 100644
index 000000000..61affb0b0
--- /dev/null
+++ b/compiler/tests/Dialect/HLFHELinalg/tiling.mlir
@@ -0,0 +1,60 @@
+// RUN: zamacompiler --action=dump-hlfhe %s 2>&1 --split-input-file | FileCheck %s
+
+// CHECK: func @tiled_2x2(%[[Varg0:.*]]: tensor<8x4x!HLFHE.eint<6>>, %[[Varg1:.*]]: tensor<4x2xi7>) -> tensor<8x2x!HLFHE.eint<6>> {
+// CHECK-NEXT: %[[Vc2:.*]] = arith.constant 2 : index
+// CHECK-NEXT: %[[Vc0:.*]] = arith.constant 0 : index
+// CHECK-NEXT: %[[Vc8:.*]] = arith.constant 8 : index
+// CHECK-NEXT: %[[Vc4:.*]] = arith.constant 4 : index
+// CHECK-NEXT: %[[V0:.*]] = "HLFHELinalg.zero"() : () -> tensor<8x2x!HLFHE.eint<6>>
+// CHECK-NEXT: %[[V1:.*]] = scf.for %[[Varg2:.*]] = %[[Vc0]] to %[[Vc8]] step %[[Vc2]] iter_args(%[[Varg3:.*]] = %[[V0]]) -> (tensor<8x2x!HLFHE.eint<6>>) {
+// CHECK-NEXT: %[[V2:.*]] = scf.for %[[Varg4:.*]] = %[[Vc0]] to %[[Vc4]] step %[[Vc2]] iter_args(%[[Varg5:.*]] = %[[Varg3]]) -> (tensor<8x2x!HLFHE.eint<6>>) {
+// CHECK-NEXT: %[[V3:.*]] = scf.for %[[Varg6:.*]] = %[[Vc0]] to %[[Vc2]] step %[[Vc2]] iter_args(%[[Varg7:.*]] = %[[Varg5]]) -> (tensor<8x2x!HLFHE.eint<6>>) {
+// CHECK-NEXT: %[[V4:.*]] = tensor.extract_slice %[[Varg0]][%[[Varg2]], %[[Varg4]]] [2, 2] [1, 1] : tensor<8x4x!HLFHE.eint<6>> to tensor<2x2x!HLFHE.eint<6>>
+// CHECK-NEXT: %[[V5:.*]] = tensor.extract_slice %[[Varg1]][%[[Varg4]], %[[Varg6]]] [2, 2] [1, 1] : tensor<4x2xi7> to tensor<2x2xi7>
+// CHECK-NEXT: %[[V6:.*]] = tensor.extract_slice %[[Varg7]][%[[Varg2]], %[[Varg6]]] [2, 2] [1, 1] : tensor<8x2x!HLFHE.eint<6>> to tensor<2x2x!HLFHE.eint<6>>
+// CHECK-NEXT: %[[V7:.*]] = "HLFHELinalg.matmul_eint_int"(%[[V4]], %[[V5]]) : (tensor<2x2x!HLFHE.eint<6>>, tensor<2x2xi7>) -> tensor<2x2x!HLFHE.eint<6>>
+// CHECK-NEXT: %[[V8:.*]] = "HLFHELinalg.add_eint"(%[[V6]], %[[V7]]) : (tensor<2x2x!HLFHE.eint<6>>, tensor<2x2x!HLFHE.eint<6>>) -> tensor<2x2x!HLFHE.eint<6>>
+// CHECK-NEXT: %[[V9:.*]] = tensor.insert_slice %[[V8]] into %[[Varg7]][%[[Varg2]], %[[Varg6]]] [2, 2] [1, 1] : tensor<2x2x!HLFHE.eint<6>> into tensor<8x2x!HLFHE.eint<6>>
+// CHECK-NEXT: scf.yield %[[V9]] : tensor<8x2x!HLFHE.eint<6>>
+// CHECK-NEXT: }
+// CHECK-NEXT: scf.yield %[[V3]] : tensor<8x2x!HLFHE.eint<6>>
+// CHECK-NEXT: }
+// CHECK-NEXT: scf.yield %[[V2]] : tensor<8x2x!HLFHE.eint<6>>
+// CHECK-NEXT: }
+// CHECK-NEXT: return %[[V1]] : tensor<8x2x!HLFHE.eint<6>>
+// CHECK-NEXT: }
+func @tiled_2x2(%a: tensor<8x4x!HLFHE.eint<6>>, %b: tensor<4x2xi7>) -> tensor<8x2x!HLFHE.eint<6>> {
+  %0 = "HLFHELinalg.matmul_eint_int"(%a, %b) { "tile-sizes" = [2,2,2] } : (tensor<8x4x!HLFHE.eint<6>>, tensor<4x2xi7>) -> tensor<8x2x!HLFHE.eint<6>>
+  return %0 : tensor<8x2x!HLFHE.eint<6>>
+}
+
+// -----
+
+// CHECK: func @tiled_one_big_tile(%[[Varg0:.*]]: tensor<8x4x!HLFHE.eint<6>>, %[[Varg1:.*]]: tensor<4x2xi7>) -> tensor<8x2x!HLFHE.eint<6>> {
+// CHECK-NEXT: %[[Vc8:.*]] = arith.constant 8 : index
+// CHECK-NEXT: %[[Vc4:.*]] = arith.constant 4 : index
+// CHECK-NEXT: %[[Vc2:.*]] = arith.constant 2 : index
+// CHECK-NEXT: %[[Vc0:.*]] = arith.constant 0 : index
+// CHECK-NEXT: %[[V0:.*]] = "HLFHELinalg.zero"() : () -> tensor<8x2x!HLFHE.eint<6>>
+// CHECK-NEXT: %[[V1:.*]] = scf.for %[[Varg2:.*]] = %[[Vc0]] to %[[Vc8]] step %[[Vc8]] iter_args(%[[Varg3:.*]] = %[[V0]]) -> (tensor<8x2x!HLFHE.eint<6>>) {
+// CHECK-NEXT: %[[V2:.*]] = scf.for %[[Varg4:.*]] = %[[Vc0]] to %[[Vc4]] step %[[Vc4]] iter_args(%[[Varg5:.*]] = %[[Varg3]]) -> (tensor<8x2x!HLFHE.eint<6>>) {
+// CHECK-NEXT: %[[V3:.*]] = scf.for %[[Varg6:.*]] = %[[Vc0]] to %[[Vc2]] step %[[Vc2]] iter_args(%[[Varg7:.*]] = %[[Varg5]]) -> (tensor<8x2x!HLFHE.eint<6>>) {
+// CHECK-NEXT: %[[V4:.*]] = tensor.extract_slice %[[Varg0]][%[[Varg2]], %[[Varg4]]] [8, 4] [1, 1] : tensor<8x4x!HLFHE.eint<6>> to tensor<8x4x!HLFHE.eint<6>>
+// CHECK-NEXT: %[[V5:.*]] = tensor.extract_slice %[[Varg1]][%[[Varg4]], %[[Varg6]]] [4, 2] [1, 1] : tensor<4x2xi7> to tensor<4x2xi7>
+// CHECK-NEXT: %[[V6:.*]] = tensor.extract_slice %[[Varg7]][%[[Varg2]], %[[Varg6]]] [8, 2] [1, 1] : tensor<8x2x!HLFHE.eint<6>> to tensor<8x2x!HLFHE.eint<6>>
+// CHECK-NEXT: %[[V7:.*]] = "HLFHELinalg.matmul_eint_int"(%[[V4]], %[[V5]]) : (tensor<8x4x!HLFHE.eint<6>>, tensor<4x2xi7>) -> tensor<8x2x!HLFHE.eint<6>>
+// CHECK-NEXT: %[[V8:.*]] = "HLFHELinalg.add_eint"(%[[V6]], %[[V7]]) : (tensor<8x2x!HLFHE.eint<6>>, tensor<8x2x!HLFHE.eint<6>>) -> tensor<8x2x!HLFHE.eint<6>>
+// CHECK-NEXT: %[[V9:.*]] = tensor.insert_slice %[[V8]] into %[[Varg7]][%[[Varg2]], %[[Varg6]]] [8, 2] [1, 1] : tensor<8x2x!HLFHE.eint<6>> into tensor<8x2x!HLFHE.eint<6>>
+// CHECK-NEXT: scf.yield %[[V9]] : tensor<8x2x!HLFHE.eint<6>>
+// CHECK-NEXT: }
+// CHECK-NEXT: scf.yield %[[V3]] : tensor<8x2x!HLFHE.eint<6>>
+// CHECK-NEXT: }
+// CHECK-NEXT: scf.yield %[[V2]] : tensor<8x2x!HLFHE.eint<6>>
+// CHECK-NEXT: }
+// CHECK-NEXT: return %[[V1]] : tensor<8x2x!HLFHE.eint<6>>
+// CHECK-NEXT: }
+func @tiled_one_big_tile(%a: tensor<8x4x!HLFHE.eint<6>>, %b: tensor<4x2xi7>) -> tensor<8x2x!HLFHE.eint<6>> {
+  %0 = "HLFHELinalg.matmul_eint_int"(%a, %b) { "tile-sizes" = [8,4,2] } : (tensor<8x4x!HLFHE.eint<6>>, tensor<4x2xi7>) -> tensor<8x2x!HLFHE.eint<6>>
+  return %0 : tensor<8x2x!HLFHE.eint<6>>
+}
+
diff --git a/compiler/tests/unittest/end_to_end_jit_hlfhelinalg.cc b/compiler/tests/unittest/end_to_end_jit_hlfhelinalg.cc
index d460e080a..ea63465c7 100644
--- a/compiler/tests/unittest/end_to_end_jit_hlfhelinalg.cc
+++ b/compiler/tests/unittest/end_to_end_jit_hlfhelinalg.cc
@@ -1470,3 +1470,78 @@ func @main() -> tensor<2x2x4x!HLFHE.eint<6>> {
 }
 }
 }
+
+class TiledMatMulParametric
+    : public ::testing::TestWithParam<std::vector<int>> {};
+
+TEST_P(TiledMatMulParametric, tiled_matmul_eint_int) {
+  std::vector<int> tiling = GetParam();
+  std::ostringstream mlirProgram;
+
+  mlirProgram
+      << "func @main(%a: tensor<8x4x!HLFHE.eint<6>>, %b: tensor<4x2xi7>) ->\n"
+      << " tensor<8x2x!HLFHE.eint<6>> {\n"
+      << " %0 = \"HLFHELinalg.matmul_eint_int\"(%a, %b) { \"tile-sizes\" = "
+      << "[" << tiling[0] << ", " << tiling[1] << ", " << tiling[2] << "]"
+      << "} :\n"
+      << " (tensor<8x4x!HLFHE.eint<6>>, tensor<4x2xi7>) ->\n"
+      << " tensor<8x2x!HLFHE.eint<6>>\n"
+      << " return %0 : tensor<8x2x!HLFHE.eint<6>>\n"
+      << " }";
+
+  mlir::zamalang::JitCompilerEngine::Lambda lambda =
+      checkedJit(mlirProgram.str());
+
+  const size_t rowsA = 8;
+  const size_t colsA = 4;
+  const uint8_t A[rowsA][colsA] = {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 0, 1, 2},
+                                   {3, 4, 5, 6}, {7, 8, 9, 0}, {1, 2, 3, 4},
+                                   {5, 6, 7, 8}, {9, 0, 1, 2}};
+
+  const size_t rowsB = 4;
+  const size_t colsB = 2;
+  const uint8_t B[rowsB][colsB]{{1, 2}, {3, 4}, {3, 1}, {0, 2}};
+
+  const size_t rowsC = rowsA;
+  const size_t colsC = colsB;
+  const uint8_t expected[rowsC][colsC]{
+      {16, 21}, {44, 57}, {12, 23}, {30, 39},
+      {58, 55}, {16, 21}, {44, 57}, {12, 23},
+  };
+
+  mlir::zamalang::TensorLambdaArgument<
+      mlir::zamalang::IntLambdaArgument<uint8_t>>
+      aArg(llvm::ArrayRef<uint8_t>((const uint8_t *)A, rowsA * colsA),
+           {rowsA, colsA});
+  mlir::zamalang::TensorLambdaArgument<
+      mlir::zamalang::IntLambdaArgument<uint8_t>>
+      bArg(llvm::ArrayRef<uint8_t>((const uint8_t *)B, rowsB * colsB),
+           {rowsB, colsB});
+
+  llvm::Expected<std::vector<uint64_t>> res =
+      lambda.operator()<std::vector<uint64_t>>({&aArg, &bArg});
+
+  ASSERT_EXPECTED_SUCCESS(res);
+
+  ASSERT_EQ(res->size(), (uint64_t)(rowsC * colsC));
+
+  for (size_t i = 0; i < rowsC; i++) {
+    for (size_t j = 0; j < colsC; j++) {
+      EXPECT_EQ((*res)[i * colsC + j], expected[i][j])
+          << ", at pos(" << i << "," << j << ")";
+    }
+  }
+}
+
+INSTANTIATE_TEST_SUITE_P(TiledMatMul, TiledMatMulParametric,
+                         ::testing::Values(
+                             // Element-sized tiles
+                             std::vector<int>{1, 1, 1},
+
+                             // Mixed tiles
+                             std::vector<int>{2, 2, 2},
+                             std::vector<int>{4, 4, 2},
+                             std::vector<int>{2, 4, 2},
+
+                             // Single, big tile
+                             std::vector<int>{8, 4, 2}));