test(compiler): Add unit tests for HLFHELinalg.zero

Add unit tests for `HLFHELinalg.zero`, including a test for the
integration into the MANP pass.
This commit is contained in:
Andi Drebes
2021-11-23 15:33:15 +01:00
parent 4883eebfa3
commit bf9a831c3d
4 changed files with 91 additions and 1 deletions

View File

@@ -323,3 +323,13 @@ func @matmul_int_eint_cst_p_2_n_1(%arg0: tensor<2x3x!HLFHE.eint<2>>) -> tensor<2
%1 = "HLFHELinalg.matmul_int_eint"(%0, %arg0): (tensor<2x2xi3>, tensor<2x3x!HLFHE.eint<2>>) -> tensor<2x3x!HLFHE.eint<2>>
return %1 : tensor<2x3x!HLFHE.eint<2>>
}
// -----
func @zero() -> tensor<8x!HLFHE.eint<2>>
{
// CHECK: %[[ret:.*]] = "HLFHELinalg.zero"() {MANP = 1 : ui{{[0-9]+}}} : () -> tensor<8x!HLFHE.eint<2>>
%0 = "HLFHELinalg.zero"() : () -> tensor<8x!HLFHE.eint<2>>
return %0 : tensor<8x!HLFHE.eint<2>>
}

View File

@@ -229,3 +229,23 @@ func @matmul_int_eint(%arg0: tensor<3x4xi3>, %arg1: tensor<4x2x!HLFHE.eint<2>>)
%1 = "HLFHELinalg.matmul_int_eint"(%arg0, %arg1): (tensor<3x4xi3>, tensor<4x2x!HLFHE.eint<2>>) -> tensor<4x2x!HLFHE.eint<2>>
return %1 : tensor<4x2x!HLFHE.eint<2>>
}
// -----
/////////////////////////////////////////////////
// HLFHELinalg.zero
/////////////////////////////////////////////////
func @zero_1D_scalar() -> tensor<4x!HLFHE.eint<2>> {
// expected-error @+1 {{'HLFHELinalg.zero' op}}
%0 = "HLFHELinalg.zero"() : () -> !HLFHE.eint<2>
return %0 : !HLFHE.eint<2>
}
// -----
func @zero_plaintext() -> tensor<4x9xi32> {
// expected-error @+1 {{'HLFHELinalg.zero' op}}
%0 = "HLFHELinalg.zero"() : () -> tensor<4x9xi32>
return %0 : tensor<4x9xi32>
}

View File

@@ -328,4 +328,26 @@ func @matmul_int_eint(%arg0: tensor<3x4xi3>, %arg1: tensor<4x2x!HLFHE.eint<2>>)
%1 = "HLFHELinalg.matmul_int_eint"(%arg0, %arg1): (tensor<3x4xi3>, tensor<4x2x!HLFHE.eint<2>>) -> tensor<3x2x!HLFHE.eint<2>>
return %1 : tensor<3x2x!HLFHE.eint<2>>
}
}
/////////////////////////////////////////////////
// HLFHELinalg.zero
/////////////////////////////////////////////////
// CHECK: func @zero_1D() -> tensor<4x!HLFHE.eint<2>> {
// CHECK-NEXT: %[[v0:.*]] = "HLFHELinalg.zero"() : () -> tensor<4x!HLFHE.eint<2>>
// CHECK-NEXT: return %[[v0]] : tensor<4x!HLFHE.eint<2>>
// CHECK-NEXT: }
func @zero_1D() -> tensor<4x!HLFHE.eint<2>> {
%0 = "HLFHELinalg.zero"() : () -> tensor<4x!HLFHE.eint<2>>
return %0 : tensor<4x!HLFHE.eint<2>>
}
// CHECK: func @zero_2D() -> tensor<4x9x!HLFHE.eint<2>> {
// CHECK-NEXT: %[[v0:.*]] = "HLFHELinalg.zero"() : () -> tensor<4x9x!HLFHE.eint<2>>
// CHECK-NEXT: return %[[v0]] : tensor<4x9x!HLFHE.eint<2>>
// CHECK-NEXT: }
func @zero_2D() -> tensor<4x9x!HLFHE.eint<2>> {
%0 = "HLFHELinalg.zero"() : () -> tensor<4x9x!HLFHE.eint<2>>
return %0 : tensor<4x9x!HLFHE.eint<2>>
}

View File

@@ -1432,3 +1432,41 @@ func @main(%a: tensor<2x8x!HLFHE.eint<6>>) -> tensor<2x2x4x!HLFHE.eint<6>> {
}
}
}
///////////////////////////////////////////////////////////////////////////////
// HLFHELinalg.zero ///////////////////////////////////////////////
///////////////////////////////////////////////////////////////////////////////
TEST(End2EndJit_Linalg, zero) {
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
func @main() -> tensor<2x2x4x!HLFHE.eint<6>> {
%0 = "HLFHELinalg.zero"() : () -> tensor<2x2x4x!HLFHE.eint<6>>
return %0 : tensor<2x2x4x!HLFHE.eint<6>>
}
)XXX");
llvm::Expected<std::unique_ptr<mlir::zamalang::LambdaArgument>> res =
lambda.operator()<std::unique_ptr<mlir::zamalang::LambdaArgument>>();
ASSERT_EXPECTED_SUCCESS(res);
mlir::zamalang::TensorLambdaArgument<mlir::zamalang::IntLambdaArgument<>>
&resp = (*res)
->cast<mlir::zamalang::TensorLambdaArgument<
mlir::zamalang::IntLambdaArgument<>>>();
ASSERT_EQ(resp.getDimensions().size(), (size_t)3);
ASSERT_EQ(resp.getDimensions().at(0), 2);
ASSERT_EQ(resp.getDimensions().at(1), 2);
ASSERT_EQ(resp.getDimensions().at(2), 4);
ASSERT_EXPECTED_VALUE(resp.getNumElements(), 2 * 2 * 4);
for (size_t i = 0; i < 2; i++) {
for (size_t j = 0; j < 2; j++) {
for (size_t k = 0; k < 4; k++) {
EXPECT_EQ(resp.getValue()[i * 8 + j * 4 + k], 0)
<< ", at pos(" << i << "," << j << "," << k << ")";
}
}
}
}