diff --git a/compiler/tests/Dialect/HLFHE/Analysis/MANP_linalg.mlir b/compiler/tests/Dialect/HLFHE/Analysis/MANP_linalg.mlir index 557803cec..ed15c5c9c 100644 --- a/compiler/tests/Dialect/HLFHE/Analysis/MANP_linalg.mlir +++ b/compiler/tests/Dialect/HLFHE/Analysis/MANP_linalg.mlir @@ -323,3 +323,13 @@ func @matmul_int_eint_cst_p_2_n_1(%arg0: tensor<2x3x!HLFHE.eint<2>>) -> tensor<2 %1 = "HLFHELinalg.matmul_int_eint"(%0, %arg0): (tensor<2x2xi3>, tensor<2x3x!HLFHE.eint<2>>) -> tensor<2x3x!HLFHE.eint<2>> return %1 : tensor<2x3x!HLFHE.eint<2>> } + +// ----- + +func @zero() -> tensor<8x!HLFHE.eint<2>> +{ + // CHECK: %[[ret:.*]] = "HLFHELinalg.zero"() {MANP = 1 : ui{{[0-9]+}}} : () -> tensor<8x!HLFHE.eint<2>> + %0 = "HLFHELinalg.zero"() : () -> tensor<8x!HLFHE.eint<2>> + + return %0 : tensor<8x!HLFHE.eint<2>> +} diff --git a/compiler/tests/Dialect/HLFHELinalg/ops.invalid.mlir b/compiler/tests/Dialect/HLFHELinalg/ops.invalid.mlir index 6131e9107..e828353dd 100644 --- a/compiler/tests/Dialect/HLFHELinalg/ops.invalid.mlir +++ b/compiler/tests/Dialect/HLFHELinalg/ops.invalid.mlir @@ -229,3 +229,23 @@ func @matmul_int_eint(%arg0: tensor<3x4xi3>, %arg1: tensor<4x2x!HLFHE.eint<2>>) %1 = "HLFHELinalg.matmul_int_eint"(%arg0, %arg1): (tensor<3x4xi3>, tensor<4x2x!HLFHE.eint<2>>) -> tensor<4x2x!HLFHE.eint<2>> return %1 : tensor<4x2x!HLFHE.eint<2>> } + +// ----- + +///////////////////////////////////////////////// +// HLFHELinalg.zero +///////////////////////////////////////////////// + +func @zero_1D_scalar() -> tensor<4x!HLFHE.eint<2>> { + // expected-error @+1 {{'HLFHELinalg.zero' op}} + %0 = "HLFHELinalg.zero"() : () -> !HLFHE.eint<2> + return %0 : !HLFHE.eint<2> +} + +// ----- + +func @zero_plaintext() -> tensor<4x9xi32> { + // expected-error @+1 {{'HLFHELinalg.zero' op}} + %0 = "HLFHELinalg.zero"() : () -> tensor<4x9xi32> + return %0 : tensor<4x9xi32> +} diff --git a/compiler/tests/Dialect/HLFHELinalg/ops.mlir b/compiler/tests/Dialect/HLFHELinalg/ops.mlir index b854b13f1..d9db4af31 100644 --- a/compiler/tests/Dialect/HLFHELinalg/ops.mlir +++ b/compiler/tests/Dialect/HLFHELinalg/ops.mlir @@ -328,4 +328,26 @@ func @matmul_int_eint(%arg0: tensor<3x4xi3>, %arg1: tensor<4x2x!HLFHE.eint<2>>) %1 = "HLFHELinalg.matmul_int_eint"(%arg0, %arg1): (tensor<3x4xi3>, tensor<4x2x!HLFHE.eint<2>>) -> tensor<3x2x!HLFHE.eint<2>> return %1 : tensor<3x2x!HLFHE.eint<2>> -} \ No newline at end of file +} + +///////////////////////////////////////////////// +// HLFHELinalg.zero +///////////////////////////////////////////////// + +// CHECK: func @zero_1D() -> tensor<4x!HLFHE.eint<2>> { +// CHECK-NEXT: %[[v0:.*]] = "HLFHELinalg.zero"() : () -> tensor<4x!HLFHE.eint<2>> +// CHECK-NEXT: return %[[v0]] : tensor<4x!HLFHE.eint<2>> +// CHECK-NEXT: } +func @zero_1D() -> tensor<4x!HLFHE.eint<2>> { + %0 = "HLFHELinalg.zero"() : () -> tensor<4x!HLFHE.eint<2>> + return %0 : tensor<4x!HLFHE.eint<2>> +} + +// CHECK: func @zero_2D() -> tensor<4x9x!HLFHE.eint<2>> { +// CHECK-NEXT: %[[v0:.*]] = "HLFHELinalg.zero"() : () -> tensor<4x9x!HLFHE.eint<2>> +// CHECK-NEXT: return %[[v0]] : tensor<4x9x!HLFHE.eint<2>> +// CHECK-NEXT: } +func @zero_2D() -> tensor<4x9x!HLFHE.eint<2>> { + %0 = "HLFHELinalg.zero"() : () -> tensor<4x9x!HLFHE.eint<2>> + return %0 : tensor<4x9x!HLFHE.eint<2>> +} diff --git a/compiler/tests/unittest/end_to_end_jit_hlfhelinalg.cc b/compiler/tests/unittest/end_to_end_jit_hlfhelinalg.cc index 8d2b82457..d460e080a 100644 --- a/compiler/tests/unittest/end_to_end_jit_hlfhelinalg.cc +++ b/compiler/tests/unittest/end_to_end_jit_hlfhelinalg.cc @@ -1432,3 +1432,41 @@ func @main(%a: tensor<2x8x!HLFHE.eint<6>>) -> tensor<2x2x4x!HLFHE.eint<6>> { } } } + +/////////////////////////////////////////////////////////////////////////////// +// HLFHELinalg.zero /////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +TEST(End2EndJit_Linalg, zero) { + mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX( +func @main() -> tensor<2x2x4x!HLFHE.eint<6>> { + %0 = "HLFHELinalg.zero"() : () -> tensor<2x2x4x!HLFHE.eint<6>> + return %0 : tensor<2x2x4x!HLFHE.eint<6>> +} +)XXX"); + + llvm::Expected> res = + lambda.operator()>(); + + ASSERT_EXPECTED_SUCCESS(res); + + mlir::zamalang::TensorLambdaArgument> + &resp = (*res) + ->cast>>(); + + ASSERT_EQ(resp.getDimensions().size(), (size_t)3); + ASSERT_EQ(resp.getDimensions().at(0), 2); + ASSERT_EQ(resp.getDimensions().at(1), 2); + ASSERT_EQ(resp.getDimensions().at(2), 4); + ASSERT_EXPECTED_VALUE(resp.getNumElements(), 2 * 2 * 4); + + for (size_t i = 0; i < 2; i++) { + for (size_t j = 0; j < 2; j++) { + for (size_t k = 0; k < 4; k++) { + EXPECT_EQ(resp.getValue()[i * 8 + j * 4 + k], 0) + << ", at pos(" << i << "," << j << "," << k << ")"; + } + } + } +}