From dea1be9d524a12d2552ffa587c472031c804c446 Mon Sep 17 00:00:00 2001 From: Quentin Bourgerie Date: Fri, 22 Oct 2021 17:22:31 +0200 Subject: [PATCH] feat(compiler): HLFHELinalg.apply_lookup_table definition --- .../Dialect/HLFHELinalg/IR/HLFHELinalgOps.td | 38 +++++++++++++++++++ .../Dialect/HLFHELinalg/IR/HLFHELinalgOps.cpp | 33 ++++++++++++++++ .../Dialect/HLFHELinalg/ops.invalid.mlir | 28 +++++++++++++- compiler/tests/Dialect/HLFHELinalg/ops.mlir | 13 +++++++ 4 files changed, 111 insertions(+), 1 deletion(-) diff --git a/compiler/include/zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgOps.td b/compiler/include/zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgOps.td index d40332f61..48ea3b907 100644 --- a/compiler/include/zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgOps.td +++ b/compiler/include/zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgOps.td @@ -220,7 +220,45 @@ def MulEintIntOp : HLFHELinalg_Op<"mul_eint_int", [TensorBroadcastingRules, Tens ); let results = (outs Type.predicate, HasStaticShapePred]>>); +} +def ApplyLookupTableEintOp : HLFHELinalg_Op<"apply_lookup_table", []> { + let summary = "Returns a tensor that contains the result of the lookup on a table."; + + let description = [{ + Performs for each encrypted indices a lookup on a table of clear integers. + + ```mlir + // The result of this operation, is a tensor that contains the result of the lookup on a table. + // i.e. %res[i, ..., k] = %lut[%t[i, ..., k]] + %res = HLFHELinalg.apply_lookup_table(%t, %lut): tensor>, tensor -> tensor> + ``` + + The `%lut` argument should be a tensor with one dimension, where its dimension is equals to `2^p` where `p` is the width of the encrypted integers. + + Examples: + ```mlir + + // Returns the lookup of 3x3 matrix of encrypted indices of with 2 on a table of size 4=2² of clear integers. + // + // [0,1,2] [2,4,6] + // [3,0,1] lut [2,4,6,8] = [8,2,4] + // [2,3,0] [6,8,0] + // + "HLFHELinalg.apply_lookup_table(%t, %lut)" : (tensor<3x3x!HLFHE.eint<2>>, tensor<4xi4>) -> tensor<3x3x!HLFHE.eint<3>> + ``` + }]; + + let arguments = (ins + Type.predicate, HasStaticShapePred]>>:$t, + Type.predicate, HasStaticShapePred]>>:$lut + ); + + let results = (outs Type.predicate, HasStaticShapePred]>>); + + let verifier = [{ + return ::mlir::zamalang::HLFHELinalg::verifyApplyLookupTable(*this); + }]; } #endif diff --git a/compiler/lib/Dialect/HLFHELinalg/IR/HLFHELinalgOps.cpp b/compiler/lib/Dialect/HLFHELinalg/IR/HLFHELinalgOps.cpp index 1bcce2b8b..222cc1c7b 100644 --- a/compiler/lib/Dialect/HLFHELinalg/IR/HLFHELinalgOps.cpp +++ b/compiler/lib/Dialect/HLFHELinalg/IR/HLFHELinalgOps.cpp @@ -202,5 +202,38 @@ LogicalResult verifyTensorBinaryEint(mlir::Operation *op) { } // namespace OpTrait } // namespace mlir +namespace mlir { +namespace zamalang { +namespace HLFHELinalg { + +mlir::LogicalResult verifyApplyLookupTable(ApplyLookupTableEintOp &op) { + auto tTy = op.t().getType().cast(); + auto tEltTy = + tTy.getElementType().cast(); + auto lutTy = op.lut().getType().cast(); + auto lutEltTy = lutTy.getElementType().cast(); + auto resultTy = op.getResult().getType().cast(); + + // Check the shape of lut argument + auto tEltwidth = tEltTy.getWidth(); + mlir::SmallVector expectedShape{1 << tEltwidth}; + if (!lutTy.hasStaticShape(expectedShape) || !lutEltTy.isInteger(64)) { + op.emitOpError() + << "should have as operand #2 a tensor<2^pxi64>, where p is the width " + "of the encrypted integer of the operand #1," + << "expect tensor <" << expectedShape[0] << "xi64>"; + return mlir::failure(); + } + if (!resultTy.hasStaticShape(tTy.getShape())) { + op.emitOpError() + << " should have same shapes for operand #1 and the result"; + } + return mlir::success(); +} + +} // namespace HLFHELinalg +} // namespace zamalang +} // namespace mlir + #define GET_OP_CLASSES #include "zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgOps.cpp.inc" diff --git a/compiler/tests/Dialect/HLFHELinalg/ops.invalid.mlir b/compiler/tests/Dialect/HLFHELinalg/ops.invalid.mlir index 761dd8ca4..7eff1be0e 100644 --- a/compiler/tests/Dialect/HLFHELinalg/ops.invalid.mlir +++ b/compiler/tests/Dialect/HLFHELinalg/ops.invalid.mlir @@ -118,4 +118,30 @@ func @main(%a0: tensor<2x3x4x!HLFHE.eint<2>>, %a1: tensor<2x3x4xi4>) -> tensor<2 return %1 : tensor<2x3x4x!HLFHE.eint<2>> } -// ----- \ No newline at end of file +// ----- + +///////////////////////////////////////////////// +// HLFHELinalg.apply_lookup_table +///////////////////////////////////////////////// + +func @apply_lookup_table(%arg0: tensor<2x3x4x!HLFHE.eint<2>>, %arg1: tensor<4xi32>) -> tensor<2x3x4x!HLFHE.eint<2>> { + // expected-error @+1 {{'HLFHELinalg.apply_lookup_table' op should have as operand #2 a tensor<2^pxi64>, where p is the width of the encrypted integer of the operand #1,expect tensor <4xi64>}} + %1 = "HLFHELinalg.apply_lookup_table"(%arg0, %arg1): (tensor<2x3x4x!HLFHE.eint<2>>, tensor<4xi32>) -> (tensor<2x3x4x!HLFHE.eint<2>>) + return %1: tensor<2x3x4x!HLFHE.eint<2>> +} + +// ----- + +func @apply_lookup_table(%arg0: tensor<2x3x4x!HLFHE.eint<2>>, %arg1: tensor<12xi64>) -> tensor<2x3x4x!HLFHE.eint<2>> { + // expected-error @+1 {{'HLFHELinalg.apply_lookup_table' op should have as operand #2 a tensor<2^pxi64>, where p is the width of the encrypted integer of the operand #1,expect tensor <4xi64>}} + %1 = "HLFHELinalg.apply_lookup_table"(%arg0, %arg1): (tensor<2x3x4x!HLFHE.eint<2>>, tensor<12xi64>) -> (tensor<2x3x4x!HLFHE.eint<2>>) + return %1: tensor<2x3x4x!HLFHE.eint<2>> +} + +// ----- + +func @apply_lookup_table(%arg0: tensor<3x4x!HLFHE.eint<2>>, %arg1: tensor<4xi64>) -> tensor<2x3x4x!HLFHE.eint<2>> { + // expected-error @+1 {{'HLFHELinalg.apply_lookup_table' op should have same shapes for operand #1 and the result}} + %1 = "HLFHELinalg.apply_lookup_table"(%arg0, %arg1): (tensor<3x4x!HLFHE.eint<2>>, tensor<4xi64>) -> (tensor<2x3x4x!HLFHE.eint<2>>) + return %1: tensor<2x3x4x!HLFHE.eint<2>> +} \ No newline at end of file diff --git a/compiler/tests/Dialect/HLFHELinalg/ops.mlir b/compiler/tests/Dialect/HLFHELinalg/ops.mlir index 81ebd6dbb..b7bf7fdc1 100644 --- a/compiler/tests/Dialect/HLFHELinalg/ops.mlir +++ b/compiler/tests/Dialect/HLFHELinalg/ops.mlir @@ -216,4 +216,17 @@ func @mul_eint_int_broadcast_1(%a0: tensor<1x4x5x!HLFHE.eint<2>>, %a1: tensor<3x func @mul_eint_int_broadcast_2(%a0: tensor<4x!HLFHE.eint<2>>, %a1: tensor<3x4xi3>) -> tensor<3x4x!HLFHE.eint<2>> { %1 ="HLFHELinalg.mul_eint_int"(%a0, %a1) : (tensor<4x!HLFHE.eint<2>>, tensor<3x4xi3>) -> tensor<3x4x!HLFHE.eint<2>> return %1: tensor<3x4x!HLFHE.eint<2>> +} + +///////////////////////////////////////////////// +// HLFHELinalg.apply_lookup_table +///////////////////////////////////////////////// + +// CHECK-LABEL: func @apply_lookup_table(%arg0: tensor<2x3x4x!HLFHE.eint<2>>, %arg1: tensor<4xi64>) -> tensor<2x3x4x!HLFHE.eint<2>> +func @apply_lookup_table(%arg0: tensor<2x3x4x!HLFHE.eint<2>>, %arg1: tensor<4xi64>) -> tensor<2x3x4x!HLFHE.eint<2>> { + // CHECK-NEXT: %[[V1:.*]] = "HLFHELinalg.apply_lookup_table"(%arg0, %arg1) : (tensor<2x3x4x!HLFHE.eint<2>>, tensor<4xi64>) -> tensor<2x3x4x!HLFHE.eint<2>> + // CHECK-NEXT: return %[[V1]] : tensor<2x3x4x!HLFHE.eint<2>> + + %1 = "HLFHELinalg.apply_lookup_table"(%arg0, %arg1): (tensor<2x3x4x!HLFHE.eint<2>>, tensor<4xi64>) -> (tensor<2x3x4x!HLFHE.eint<2>>) + return %1: tensor<2x3x4x!HLFHE.eint<2>> } \ No newline at end of file