From 5bb58453aa6413fd18a166ccd3fe4ae75ad5441f Mon Sep 17 00:00:00 2001 From: Quentin Bourgerie Date: Wed, 27 Oct 2021 14:10:04 +0200 Subject: [PATCH] feat(compiler): MANP Analysis of HLFHELinalg.apply_lookup_table (close #175) --- compiler/lib/Dialect/HLFHE/Analysis/MANP.cpp | 3 +++ .../Dialect/HLFHE/Analysis/MANP_linalg.mlir | 20 +++++++++++++++++++ .../unittest/end_to_end_jit_hlfhelinalg.cc | 3 +-- 3 files changed, 24 insertions(+), 2 deletions(-) diff --git a/compiler/lib/Dialect/HLFHE/Analysis/MANP.cpp b/compiler/lib/Dialect/HLFHE/Analysis/MANP.cpp index 2c071919c..b8f4335d7 100644 --- a/compiler/lib/Dialect/HLFHE/Analysis/MANP.cpp +++ b/compiler/lib/Dialect/HLFHE/Analysis/MANP.cpp @@ -680,6 +680,9 @@ struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis { llvm::dyn_cast( op)) { norm2SqEquiv = getSqMANP(mulEintIntOp, operands); + } else if (llvm::isa( + op)) { + norm2SqEquiv = llvm::APInt{1, 1, false}; } // Tensor Operators // ExtractOp diff --git a/compiler/tests/Dialect/HLFHE/Analysis/MANP_linalg.mlir b/compiler/tests/Dialect/HLFHE/Analysis/MANP_linalg.mlir index 926041558..e4a0fdda9 100644 --- a/compiler/tests/Dialect/HLFHE/Analysis/MANP_linalg.mlir +++ b/compiler/tests/Dialect/HLFHE/Analysis/MANP_linalg.mlir @@ -92,3 +92,23 @@ func @chain_add_eint_int(%e: tensor<8x!HLFHE.eint<2>>) -> tensor<8x!HLFHE.eint<2 %3 = "HLFHELinalg.add_eint_int"(%2, %cst3) : (tensor<8x!HLFHE.eint<2>>, tensor<8xi3>) -> tensor<8x!HLFHE.eint<2>> return %3 : tensor<8x!HLFHE.eint<2>> } + +// ----- + +func @apply_lookup_table(%t: tensor<3x3x!HLFHE.eint<2>>) -> tensor<3x3x!HLFHE.eint<3>> { + %lut = std.constant dense<[1,3,5,7]> : tensor<4xi64> + // CHECK: %[[RES:.*]] = "HLFHELinalg.apply_lookup_table"(%[[T:.*]], %[[LUT:.*]]) {MANP = 1 : ui1} : (tensor<3x3x!HLFHE.eint<2>>, tensor<4xi64>) -> tensor<3x3x!HLFHE.eint<3>> + %res = "HLFHELinalg.apply_lookup_table"(%t, %lut) : (tensor<3x3x!HLFHE.eint<2>>, tensor<4xi64>) -> tensor<3x3x!HLFHE.eint<3>> + return %res : tensor<3x3x!HLFHE.eint<3>> +} + +// ----- + +func @apply_lookup_table_after_op(%t: tensor<8x!HLFHE.eint<2>>, %i: tensor<8xi3>) -> tensor<8x!HLFHE.eint<3>> { + %lut = std.constant dense<[1,3,5,7]> : tensor<4xi64> + // CHECK: %[[V0:.*]] = "HLFHELinalg.mul_eint_int"([[T:.*]], %[[I:.*]]) {MANP = 8 : ui{{[0-9]+}}} : (tensor<8x!HLFHE.eint<2>>, tensor<8xi3>) -> tensor<8x!HLFHE.eint<2>> + %0 = "HLFHELinalg.mul_eint_int"(%t, %i) : (tensor<8x!HLFHE.eint<2>>, tensor<8xi3>) -> tensor<8x!HLFHE.eint<2>> + // CHECK-NEXT: %[[RES:.*]] = "HLFHELinalg.apply_lookup_table"(%[[V0:.*]], %[[LUT:.*]]) {MANP = 1 : ui1} : (tensor<8x!HLFHE.eint<2>>, tensor<4xi64>) -> tensor<8x!HLFHE.eint<3>> + %res = "HLFHELinalg.apply_lookup_table"(%0, %lut) : (tensor<8x!HLFHE.eint<2>>, tensor<4xi64>) -> tensor<8x!HLFHE.eint<3>> + return %res : tensor<8x!HLFHE.eint<3>> +} \ No newline at end of file diff --git a/compiler/tests/unittest/end_to_end_jit_hlfhelinalg.cc b/compiler/tests/unittest/end_to_end_jit_hlfhelinalg.cc index be82813f2..89d644969 100644 --- a/compiler/tests/unittest/end_to_end_jit_hlfhelinalg.cc +++ b/compiler/tests/unittest/end_to_end_jit_hlfhelinalg.cc @@ -885,8 +885,7 @@ TEST(End2EndJit_HLFHELinalg, apply_lookup_table) { %res = "HLFHELinalg.apply_lookup_table"(%t, %lut) : (tensor<3x3x!HLFHE.eint<2>>, tensor<4xi64>) -> tensor<3x3x!HLFHE.eint<3>> return %res : tensor<3x3x!HLFHE.eint<3>> } -)XXX", - "main", true); +)XXX"); const uint8_t t[3][3]{ {0, 1, 2}, {3, 0, 1},