mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 12:15:09 -05:00
test(compiler): Add unit tests for HLFHELinalg.zero
Add unit tests for `HLFHELinalg.zero`, including a test for the integration into the MANP pass.
This commit is contained in:
@@ -323,3 +323,13 @@ func @matmul_int_eint_cst_p_2_n_1(%arg0: tensor<2x3x!HLFHE.eint<2>>) -> tensor<2
|
||||
%1 = "HLFHELinalg.matmul_int_eint"(%0, %arg0): (tensor<2x2xi3>, tensor<2x3x!HLFHE.eint<2>>) -> tensor<2x3x!HLFHE.eint<2>>
|
||||
return %1 : tensor<2x3x!HLFHE.eint<2>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @zero() -> tensor<8x!HLFHE.eint<2>>
|
||||
{
|
||||
// CHECK: %[[ret:.*]] = "HLFHELinalg.zero"() {MANP = 1 : ui{{[0-9]+}}} : () -> tensor<8x!HLFHE.eint<2>>
|
||||
%0 = "HLFHELinalg.zero"() : () -> tensor<8x!HLFHE.eint<2>>
|
||||
|
||||
return %0 : tensor<8x!HLFHE.eint<2>>
|
||||
}
|
||||
|
||||
@@ -229,3 +229,23 @@ func @matmul_int_eint(%arg0: tensor<3x4xi3>, %arg1: tensor<4x2x!HLFHE.eint<2>>)
|
||||
%1 = "HLFHELinalg.matmul_int_eint"(%arg0, %arg1): (tensor<3x4xi3>, tensor<4x2x!HLFHE.eint<2>>) -> tensor<4x2x!HLFHE.eint<2>>
|
||||
return %1 : tensor<4x2x!HLFHE.eint<2>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
/////////////////////////////////////////////////
|
||||
// HLFHELinalg.zero
|
||||
/////////////////////////////////////////////////
|
||||
|
||||
func @zero_1D_scalar() -> tensor<4x!HLFHE.eint<2>> {
|
||||
// expected-error @+1 {{'HLFHELinalg.zero' op}}
|
||||
%0 = "HLFHELinalg.zero"() : () -> !HLFHE.eint<2>
|
||||
return %0 : !HLFHE.eint<2>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @zero_plaintext() -> tensor<4x9xi32> {
|
||||
// expected-error @+1 {{'HLFHELinalg.zero' op}}
|
||||
%0 = "HLFHELinalg.zero"() : () -> tensor<4x9xi32>
|
||||
return %0 : tensor<4x9xi32>
|
||||
}
|
||||
|
||||
@@ -328,4 +328,26 @@ func @matmul_int_eint(%arg0: tensor<3x4xi3>, %arg1: tensor<4x2x!HLFHE.eint<2>>)
|
||||
|
||||
%1 = "HLFHELinalg.matmul_int_eint"(%arg0, %arg1): (tensor<3x4xi3>, tensor<4x2x!HLFHE.eint<2>>) -> tensor<3x2x!HLFHE.eint<2>>
|
||||
return %1 : tensor<3x2x!HLFHE.eint<2>>
|
||||
}
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////
|
||||
// HLFHELinalg.zero
|
||||
/////////////////////////////////////////////////
|
||||
|
||||
// CHECK: func @zero_1D() -> tensor<4x!HLFHE.eint<2>> {
|
||||
// CHECK-NEXT: %[[v0:.*]] = "HLFHELinalg.zero"() : () -> tensor<4x!HLFHE.eint<2>>
|
||||
// CHECK-NEXT: return %[[v0]] : tensor<4x!HLFHE.eint<2>>
|
||||
// CHECK-NEXT: }
|
||||
func @zero_1D() -> tensor<4x!HLFHE.eint<2>> {
|
||||
%0 = "HLFHELinalg.zero"() : () -> tensor<4x!HLFHE.eint<2>>
|
||||
return %0 : tensor<4x!HLFHE.eint<2>>
|
||||
}
|
||||
|
||||
// CHECK: func @zero_2D() -> tensor<4x9x!HLFHE.eint<2>> {
|
||||
// CHECK-NEXT: %[[v0:.*]] = "HLFHELinalg.zero"() : () -> tensor<4x9x!HLFHE.eint<2>>
|
||||
// CHECK-NEXT: return %[[v0]] : tensor<4x9x!HLFHE.eint<2>>
|
||||
// CHECK-NEXT: }
|
||||
func @zero_2D() -> tensor<4x9x!HLFHE.eint<2>> {
|
||||
%0 = "HLFHELinalg.zero"() : () -> tensor<4x9x!HLFHE.eint<2>>
|
||||
return %0 : tensor<4x9x!HLFHE.eint<2>>
|
||||
}
|
||||
|
||||
@@ -1432,3 +1432,41 @@ func @main(%a: tensor<2x8x!HLFHE.eint<6>>) -> tensor<2x2x4x!HLFHE.eint<6>> {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// HLFHELinalg.zero ///////////////////////////////////////////////
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(End2EndJit_Linalg, zero) {
|
||||
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
|
||||
func @main() -> tensor<2x2x4x!HLFHE.eint<6>> {
|
||||
%0 = "HLFHELinalg.zero"() : () -> tensor<2x2x4x!HLFHE.eint<6>>
|
||||
return %0 : tensor<2x2x4x!HLFHE.eint<6>>
|
||||
}
|
||||
)XXX");
|
||||
|
||||
llvm::Expected<std::unique_ptr<mlir::zamalang::LambdaArgument>> res =
|
||||
lambda.operator()<std::unique_ptr<mlir::zamalang::LambdaArgument>>();
|
||||
|
||||
ASSERT_EXPECTED_SUCCESS(res);
|
||||
|
||||
mlir::zamalang::TensorLambdaArgument<mlir::zamalang::IntLambdaArgument<>>
|
||||
&resp = (*res)
|
||||
->cast<mlir::zamalang::TensorLambdaArgument<
|
||||
mlir::zamalang::IntLambdaArgument<>>>();
|
||||
|
||||
ASSERT_EQ(resp.getDimensions().size(), (size_t)3);
|
||||
ASSERT_EQ(resp.getDimensions().at(0), 2);
|
||||
ASSERT_EQ(resp.getDimensions().at(1), 2);
|
||||
ASSERT_EQ(resp.getDimensions().at(2), 4);
|
||||
ASSERT_EXPECTED_VALUE(resp.getNumElements(), 2 * 2 * 4);
|
||||
|
||||
for (size_t i = 0; i < 2; i++) {
|
||||
for (size_t j = 0; j < 2; j++) {
|
||||
for (size_t k = 0; k < 4; k++) {
|
||||
EXPECT_EQ(resp.getValue()[i * 8 + j * 4 + k], 0)
|
||||
<< ", at pos(" << i << "," << j << "," << k << ")";
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user