diff --git a/compiler/lib/Dialect/HLFHE/Analysis/MANP.cpp b/compiler/lib/Dialect/HLFHE/Analysis/MANP.cpp index 2c071919c..b8f4335d7 100644 --- a/compiler/lib/Dialect/HLFHE/Analysis/MANP.cpp +++ b/compiler/lib/Dialect/HLFHE/Analysis/MANP.cpp @@ -680,6 +680,9 @@ struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis { llvm::dyn_cast( op)) { norm2SqEquiv = getSqMANP(mulEintIntOp, operands); + } else if (llvm::isa( + op)) { + norm2SqEquiv = llvm::APInt{1, 1, false}; } // Tensor Operators // ExtractOp diff --git a/compiler/tests/Dialect/HLFHE/Analysis/MANP_linalg.mlir b/compiler/tests/Dialect/HLFHE/Analysis/MANP_linalg.mlir index 926041558..e4a0fdda9 100644 --- a/compiler/tests/Dialect/HLFHE/Analysis/MANP_linalg.mlir +++ b/compiler/tests/Dialect/HLFHE/Analysis/MANP_linalg.mlir @@ -92,3 +92,23 @@ func @chain_add_eint_int(%e: tensor<8x!HLFHE.eint<2>>) -> tensor<8x!HLFHE.eint<2 %3 = "HLFHELinalg.add_eint_int"(%2, %cst3) : (tensor<8x!HLFHE.eint<2>>, tensor<8xi3>) -> tensor<8x!HLFHE.eint<2>> return %3 : tensor<8x!HLFHE.eint<2>> } + +// ----- + +func @apply_lookup_table(%t: tensor<3x3x!HLFHE.eint<2>>) -> tensor<3x3x!HLFHE.eint<3>> { + %lut = std.constant dense<[1,3,5,7]> : tensor<4xi64> + // CHECK: %[[RES:.*]] = "HLFHELinalg.apply_lookup_table"(%[[T:.*]], %[[LUT:.*]]) {MANP = 1 : ui1} : (tensor<3x3x!HLFHE.eint<2>>, tensor<4xi64>) -> tensor<3x3x!HLFHE.eint<3>> + %res = "HLFHELinalg.apply_lookup_table"(%t, %lut) : (tensor<3x3x!HLFHE.eint<2>>, tensor<4xi64>) -> tensor<3x3x!HLFHE.eint<3>> + return %res : tensor<3x3x!HLFHE.eint<3>> +} + +// ----- + +func @apply_lookup_table_after_op(%t: tensor<8x!HLFHE.eint<2>>, %i: tensor<8xi3>) -> tensor<8x!HLFHE.eint<3>> { + %lut = std.constant dense<[1,3,5,7]> : tensor<4xi64> + // CHECK: %[[V0:.*]] = "HLFHELinalg.mul_eint_int"([[T:.*]], %[[I:.*]]) {MANP = 8 : ui{{[0-9]+}}} : (tensor<8x!HLFHE.eint<2>>, tensor<8xi3>) -> tensor<8x!HLFHE.eint<2>> + %0 = "HLFHELinalg.mul_eint_int"(%t, %i) : (tensor<8x!HLFHE.eint<2>>, tensor<8xi3>) -> tensor<8x!HLFHE.eint<2>> + // CHECK-NEXT: %[[RES:.*]] = "HLFHELinalg.apply_lookup_table"(%[[V0:.*]], %[[LUT:.*]]) {MANP = 1 : ui1} : (tensor<8x!HLFHE.eint<2>>, tensor<4xi64>) -> tensor<8x!HLFHE.eint<3>> + %res = "HLFHELinalg.apply_lookup_table"(%0, %lut) : (tensor<8x!HLFHE.eint<2>>, tensor<4xi64>) -> tensor<8x!HLFHE.eint<3>> + return %res : tensor<8x!HLFHE.eint<3>> +} \ No newline at end of file diff --git a/compiler/tests/unittest/end_to_end_jit_hlfhelinalg.cc b/compiler/tests/unittest/end_to_end_jit_hlfhelinalg.cc index be82813f2..89d644969 100644 --- a/compiler/tests/unittest/end_to_end_jit_hlfhelinalg.cc +++ b/compiler/tests/unittest/end_to_end_jit_hlfhelinalg.cc @@ -885,8 +885,7 @@ TEST(End2EndJit_HLFHELinalg, apply_lookup_table) { %res = "HLFHELinalg.apply_lookup_table"(%t, %lut) : (tensor<3x3x!HLFHE.eint<2>>, tensor<4xi64>) -> tensor<3x3x!HLFHE.eint<3>> return %res : tensor<3x3x!HLFHE.eint<3>> } -)XXX", - "main", true); +)XXX"); const uint8_t t[3][3]{ {0, 1, 2}, {3, 0, 1},