From 3b02a16f7babda4b9697c8446f4d4b56458bab62 Mon Sep 17 00:00:00 2001 From: Quentin Bourgerie Date: Thu, 21 Oct 2021 16:40:47 +0200 Subject: [PATCH] feat(compiler): HLFHELinalg.add_eint definition --- .../Dialect/HLFHELinalg/IR/HLFHELinalgOps.h | 13 +++++ .../Dialect/HLFHELinalg/IR/HLFHELinalgOps.td | 54 +++++++++++++++++++ .../Dialect/HLFHELinalg/IR/HLFHELinalgOps.cpp | 37 +++++++++++++ .../Dialect/HLFHELinalg/ops.invalid.mlir | 40 ++++++++++++++ compiler/tests/Dialect/HLFHELinalg/ops.mlir | 54 +++++++++++++++++++ 5 files changed, 198 insertions(+) diff --git a/compiler/include/zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgOps.h b/compiler/include/zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgOps.h index 692a7c2de..4be869c13 100644 --- a/compiler/include/zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgOps.h +++ b/compiler/include/zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgOps.h @@ -14,6 +14,7 @@ namespace OpTrait { namespace impl { LogicalResult verifyTensorBroadcastingRules(mlir::Operation *op); LogicalResult verifyTensorBinaryEintInt(mlir::Operation *op); +LogicalResult verifyTensorBinaryEint(mlir::Operation *op); } // namespace impl /// TensorBroadcastingRules is a trait for operators that should respect the @@ -47,6 +48,18 @@ public: } }; +/// TensorBinaryEint verifies that the operation matches the following signature +/// `(tensor<...x!HLFHE.eint<$p>>, tensor<...x!HLFHE.eint<$p>>) -> +/// tensor<...x!HLFHE.eint<$p>>` +template +class TensorBinaryEint + : public mlir::OpTrait::TraitBase { +public: + static LogicalResult verifyTrait(Operation *op) { + return impl::verifyTensorBinaryEint(op); + } +}; + } // namespace OpTrait } // namespace mlir diff --git a/compiler/include/zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgOps.td b/compiler/include/zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgOps.td index e4fd26256..12f86ae7d 100644 --- a/compiler/include/zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgOps.td +++ b/compiler/include/zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgOps.td @@ -13,6 +13,8 
@@ class HLFHELinalg_Op traits = []> : // TensorBroadcastingRules verify that the operands and result verify the broadcasting rules def TensorBroadcastingRules : NativeOpTrait<"TensorBroadcastingRules">; def TensorBinaryEintInt : NativeOpTrait<"TensorBinaryEintInt">; +def TensorBinaryEint : NativeOpTrait<"TensorBinaryEint">; + def AddEintIntOp : HLFHELinalg_Op<"add_eint_int", [TensorBroadcastingRules, TensorBinaryEintInt]> { let summary = "Returns a tensor that contains the addition of a tensor of encrypted integers and a tensor of clear integers."; @@ -67,4 +69,56 @@ def AddEintIntOp : HLFHELinalg_Op<"add_eint_int", [TensorBroadcastingRules, Tens ]; } +def AddEintOp : HLFHELinalg_Op<"add_eint", [TensorBroadcastingRules, TensorBinaryEint]> { + let summary = "Returns a tensor that contains the addition of two tensors of encrypted integers."; + + let description = [{ + Performs an addition following the broadcasting rules between two tensors of encrypted integers. + The width of the encrypted integers should be equal. + + Examples: + ```mlir + // Returns the term to term addition of `%a0` with `%a1` + "HLFHELinalg.add_eint"(%a0, %a1) : (tensor<4x!HLFHE.eint<4>>, tensor<4x!HLFHE.eint<4>>) -> tensor<4x!HLFHE.eint<4>> + + // Returns the term to term addition of `%a0` with `%a1`, where dimensions equal to one are stretched. + "HLFHELinalg.add_eint"(%a0, %a1) : (tensor<4x1x4x!HLFHE.eint<4>>, tensor<1x4x4x!HLFHE.eint<4>>) -> tensor<4x4x4x!HLFHE.eint<4>> + + // Returns the addition of a 3x3 matrix of encrypted integers and a 3x1 matrix (a column) of encrypted integers. + // + // [1,2,3] [1] [2,3,4] + // [4,5,6] + [2] = [6,7,8] + // [7,8,9] [3] [10,11,12] + // + // The dimension #1 of operand #2 is stretched as it is equal to 1. + "HLFHELinalg.add_eint"(%a0, %a1) : (tensor<3x3x!HLFHE.eint<4>>, tensor<3x1x!HLFHE.eint<4>>) -> tensor<3x3x!HLFHE.eint<4>> + + // Returns the addition of a 3x3 matrix of encrypted integers and a 1x3 matrix (a line) of encrypted integers. 
+ // + // [1,2,3] [2,4,6] + // [4,5,6] + [1,2,3] = [5,7,9] + // [7,8,9] [8,10,12] + // + // The dimension #2 of operand #2 is stretched as it is equal to 1. + "HLFHELinalg.add_eint"(%a0, %a1) : (tensor<3x3x!HLFHE.eint<4>>, tensor<1x3x!HLFHE.eint<4>>) -> tensor<3x3x!HLFHE.eint<4>> + + // Same behavior as the previous one, but the dimension #2 of operand #2 is missing. + "HLFHELinalg.add_eint"(%a0, %a1) : (tensor<3x3x!HLFHE.eint<4>>, tensor<3x!HLFHE.eint<4>>) -> tensor<3x3x!HLFHE.eint<4>> + ``` + }]; + + let arguments = (ins + Type.predicate, HasStaticShapePred]>>:$lhs, + Type.predicate, HasStaticShapePred]>>:$rhs + ); + + let results = (outs Type.predicate, HasStaticShapePred]>>); + + let builders = [ + OpBuilder<(ins "Value":$rhs, "Value":$lhs), [{ + build($_builder, $_state, rhs.getType(), rhs, lhs); + }]> + ]; +} + #endif diff --git a/compiler/lib/Dialect/HLFHELinalg/IR/HLFHELinalgOps.cpp b/compiler/lib/Dialect/HLFHELinalg/IR/HLFHELinalgOps.cpp index ba38f1a5c..ec76ae450 100644 --- a/compiler/lib/Dialect/HLFHELinalg/IR/HLFHELinalgOps.cpp +++ b/compiler/lib/Dialect/HLFHELinalg/IR/HLFHELinalgOps.cpp @@ -127,6 +127,43 @@ LogicalResult verifyTensorBinaryEintInt(mlir::Operation *op) { } return mlir::success(); } + +LogicalResult verifyTensorBinaryEint(mlir::Operation *op) { + if (op->getNumOperands() != 2) { + op->emitOpError() << "should have exactly 2 operands"; + return mlir::failure(); + } + auto op0Ty = op->getOperand(0).getType().dyn_cast_or_null(); + auto op1Ty = op->getOperand(1).getType().dyn_cast_or_null(); + if (op0Ty == nullptr || op1Ty == nullptr) { + op->emitOpError() << "should have both operands as tensor"; + return mlir::failure(); + } + auto el0Ty = + op0Ty.getElementType() + .dyn_cast_or_null(); + if (el0Ty == nullptr) { + op->emitOpError() << "should have a !HLFHE.eint as the element type of the " + "tensor of operand #0"; + return mlir::failure(); + } + auto el1Ty = + op1Ty.getElementType() + .dyn_cast_or_null(); + if (el1Ty == nullptr) { + 
op->emitOpError() << "should have a !HLFHE.eint as the element type of the " + "tensor of operand #1"; + return mlir::failure(); + } + if (el1Ty.getWidth() != el0Ty.getWidth()) { + op->emitOpError() << "should have the width of encrypted equals" + ", got " + << el1Ty.getWidth() << " expect " << el0Ty.getWidth(); + return mlir::failure(); + } + return mlir::success(); +} + } // namespace impl } // namespace OpTrait diff --git a/compiler/tests/Dialect/HLFHELinalg/ops.invalid.mlir b/compiler/tests/Dialect/HLFHELinalg/ops.invalid.mlir index 17ce71261..2649dfc37 100644 --- a/compiler/tests/Dialect/HLFHELinalg/ops.invalid.mlir +++ b/compiler/tests/Dialect/HLFHELinalg/ops.invalid.mlir @@ -36,4 +36,44 @@ func @main(%a0: tensor<2x3x4x!HLFHE.eint<2>>, %a1: tensor<2x3x4xi4>) -> tensor<2 // expected-error @+1 {{'HLFHELinalg.add_eint_int' op should have the width of integer values less or equals than the width of encrypted values + 1}} %1 = "HLFHELinalg.add_eint_int"(%a0, %a1) : (tensor<2x3x4x!HLFHE.eint<2>>, tensor<2x3x4xi4>) -> tensor<2x3x4x!HLFHE.eint<2>> return %1 : tensor<2x3x4x!HLFHE.eint<2>> +} + +// ----- + +///////////////////////////////////////////////// +// HLFHELinalg.add_eint +///////////////////////////////////////////////// + +// Incompatible dimension of operands +func @main(%a0: tensor<2x2x3x4x!HLFHE.eint<2>>, %a1: tensor<2x2x2x4x!HLFHE.eint<2>>) -> tensor<2x2x3x4x!HLFHE.eint<2>> { + // expected-error @+1 {{'HLFHELinalg.add_eint' op has the dimension #2 of the operand #1 incompatible with other operands, got 2 expect 1 or 3}} + %1 = "HLFHELinalg.add_eint"(%a0, %a1) : (tensor<2x2x3x4x!HLFHE.eint<2>>, tensor<2x2x2x4x!HLFHE.eint<2>>) -> tensor<2x2x3x4x!HLFHE.eint<2>> + return %1 : tensor<2x2x3x4x!HLFHE.eint<2>> +} + +// ----- + +// Incompatible dimension of result +func @main(%a0: tensor<2x2x3x4x!HLFHE.eint<2>>, %a1: tensor<2x2x2x4x!HLFHE.eint<2>>) -> tensor<2x10x3x4x!HLFHE.eint<2>> { + // expected-error @+1 {{'HLFHELinalg.add_eint' op has the dimension #3 of the 
result incompatible with operands dimension, got 10 expect 2}} + %1 = "HLFHELinalg.add_eint"(%a0, %a1) : (tensor<2x2x3x4x!HLFHE.eint<2>>, tensor<2x2x2x4x!HLFHE.eint<2>>) -> tensor<2x10x3x4x!HLFHE.eint<2>> + return %1 : tensor<2x10x3x4x!HLFHE.eint<2>> +} + +// ----- + +// Incompatible number of dimension between operands and result +func @main(%a0: tensor<2x2x3x4x!HLFHE.eint<2>>, %a1: tensor<2x2x2x4x!HLFHE.eint<2>>) -> tensor<2x3x4x!HLFHE.eint<2>> { + // expected-error @+1 {{'HLFHELinalg.add_eint' op should have the number of dimensions of the result equal to the highest number of dimensions of operands, got 3 expect 4}} + %1 = "HLFHELinalg.add_eint"(%a0, %a1) : (tensor<2x2x3x4x!HLFHE.eint<2>>, tensor<2x2x2x4x!HLFHE.eint<2>>) -> tensor<2x3x4x!HLFHE.eint<2>> + return %1 : tensor<2x3x4x!HLFHE.eint<2>> +} + +// ----- + +// Incompatible width between the two encrypted operands +func @main(%a0: tensor<2x3x4x!HLFHE.eint<2>>, %a1: tensor<2x3x4x!HLFHE.eint<3>>) -> tensor<2x3x4x!HLFHE.eint<2>> { + // expected-error @+1 {{'HLFHELinalg.add_eint' op should have the width of encrypted equals, got 3 expect 2}} + %1 = "HLFHELinalg.add_eint"(%a0, %a1) : (tensor<2x3x4x!HLFHE.eint<2>>, tensor<2x3x4x!HLFHE.eint<3>>) -> tensor<2x3x4x!HLFHE.eint<2>> + return %1 : tensor<2x3x4x!HLFHE.eint<2>> } \ No newline at end of file diff --git a/compiler/tests/Dialect/HLFHELinalg/ops.mlir b/compiler/tests/Dialect/HLFHELinalg/ops.mlir index ec1a765e9..7f0d8b1e1 100644 --- a/compiler/tests/Dialect/HLFHELinalg/ops.mlir +++ b/compiler/tests/Dialect/HLFHELinalg/ops.mlir @@ -52,4 +52,58 @@ func @add_eint_int_broadcast_1(%a0: tensor<1x4x5x!HLFHE.eint<2>>, %a1: tensor<3x func @add_eint_int_broadcast_2(%a0: tensor<4x!HLFHE.eint<2>>, %a1: tensor<3x4xi3>) -> tensor<3x4x!HLFHE.eint<2>> { %1 ="HLFHELinalg.add_eint_int"(%a0, %a1) : (tensor<4x!HLFHE.eint<2>>, tensor<3x4xi3>) -> tensor<3x4x!HLFHE.eint<2>> return %1: tensor<3x4x!HLFHE.eint<2>> +} + +///////////////////////////////////////////////// +// 
HLFHELinalg.add_eint +///////////////////////////////////////////////// + +// 1D tensor +// CHECK: func @add_eint_1D(%[[a0:.*]]: tensor<4x!HLFHE.eint<2>>, %[[a1:.*]]: tensor<4x!HLFHE.eint<2>>) -> tensor<4x!HLFHE.eint<2>> { +// CHECK-NEXT: %[[V0:.*]] = "HLFHELinalg.add_eint"(%[[a0]], %[[a1]]) : (tensor<4x!HLFHE.eint<2>>, tensor<4x!HLFHE.eint<2>>) -> tensor<4x!HLFHE.eint<2>> +// CHECK-NEXT: return %[[V0]] : tensor<4x!HLFHE.eint<2>> +// CHECK-NEXT: } +func @add_eint_1D(%a0: tensor<4x!HLFHE.eint<2>>, %a1: tensor<4x!HLFHE.eint<2>>) -> tensor<4x!HLFHE.eint<2>> { + %1 = "HLFHELinalg.add_eint"(%a0, %a1) : (tensor<4x!HLFHE.eint<2>>, tensor<4x!HLFHE.eint<2>>) -> tensor<4x!HLFHE.eint<2>> + return %1: tensor<4x!HLFHE.eint<2>> +} + +// 2D tensor +// CHECK: func @add_eint_2D(%[[a0:.*]]: tensor<2x4x!HLFHE.eint<2>>, %[[a1:.*]]: tensor<2x4x!HLFHE.eint<2>>) -> tensor<2x4x!HLFHE.eint<2>> { +// CHECK-NEXT: %[[V0:.*]] = "HLFHELinalg.add_eint"(%[[a0]], %[[a1]]) : (tensor<2x4x!HLFHE.eint<2>>, tensor<2x4x!HLFHE.eint<2>>) -> tensor<2x4x!HLFHE.eint<2>> +// CHECK-NEXT: return %[[V0]] : tensor<2x4x!HLFHE.eint<2>> +// CHECK-NEXT: } +func @add_eint_2D(%a0: tensor<2x4x!HLFHE.eint<2>>, %a1: tensor<2x4x!HLFHE.eint<2>>) -> tensor<2x4x!HLFHE.eint<2>> { + %1 = "HLFHELinalg.add_eint"(%a0, %a1) : (tensor<2x4x!HLFHE.eint<2>>, tensor<2x4x!HLFHE.eint<2>>) -> tensor<2x4x!HLFHE.eint<2>> + return %1: tensor<2x4x!HLFHE.eint<2>> +} + +// 10D tensor +// CHECK: func @add_eint_10D(%[[a0:.*]]: tensor<1x2x3x4x5x6x7x8x9x10x!HLFHE.eint<2>>, %[[a1:.*]]: tensor<1x2x3x4x5x6x7x8x9x10x!HLFHE.eint<2>>) -> tensor<1x2x3x4x5x6x7x8x9x10x!HLFHE.eint<2>> { +// CHECK-NEXT: %[[V0:.*]] = "HLFHELinalg.add_eint"(%[[a0]], %[[a1]]) : (tensor<1x2x3x4x5x6x7x8x9x10x!HLFHE.eint<2>>, tensor<1x2x3x4x5x6x7x8x9x10x!HLFHE.eint<2>>) -> tensor<1x2x3x4x5x6x7x8x9x10x!HLFHE.eint<2>> +// CHECK-NEXT: return %[[V0]] : tensor<1x2x3x4x5x6x7x8x9x10x!HLFHE.eint<2>> +// CHECK-NEXT: } +func @add_eint_10D(%a0: tensor<1x2x3x4x5x6x7x8x9x10x!HLFHE.eint<2>>, %a1: 
tensor<1x2x3x4x5x6x7x8x9x10x!HLFHE.eint<2>>) -> tensor<1x2x3x4x5x6x7x8x9x10x!HLFHE.eint<2>> { + %1 = "HLFHELinalg.add_eint"(%a0, %a1) : (tensor<1x2x3x4x5x6x7x8x9x10x!HLFHE.eint<2>>, tensor<1x2x3x4x5x6x7x8x9x10x!HLFHE.eint<2>>) -> tensor<1x2x3x4x5x6x7x8x9x10x!HLFHE.eint<2>> + return %1: tensor<1x2x3x4x5x6x7x8x9x10x!HLFHE.eint<2>> +} + +// Broadcasting with tensors with dimensions equal to one +// CHECK: func @add_eint_broadcast_1(%[[a0:.*]]: tensor<1x4x5x!HLFHE.eint<2>>, %[[a1:.*]]: tensor<3x4x1x!HLFHE.eint<2>>) -> tensor<3x4x5x!HLFHE.eint<2>> { +// CHECK-NEXT: %[[V0:.*]] = "HLFHELinalg.add_eint"(%[[a0]], %[[a1]]) : (tensor<1x4x5x!HLFHE.eint<2>>, tensor<3x4x1x!HLFHE.eint<2>>) -> tensor<3x4x5x!HLFHE.eint<2>> +// CHECK-NEXT: return %[[V0]] : tensor<3x4x5x!HLFHE.eint<2>> +// CHECK-NEXT: } +func @add_eint_broadcast_1(%a0: tensor<1x4x5x!HLFHE.eint<2>>, %a1: tensor<3x4x1x!HLFHE.eint<2>>) -> tensor<3x4x5x!HLFHE.eint<2>> { + %1 = "HLFHELinalg.add_eint"(%a0, %a1) : (tensor<1x4x5x!HLFHE.eint<2>>, tensor<3x4x1x!HLFHE.eint<2>>) -> tensor<3x4x5x!HLFHE.eint<2>> + return %1: tensor<3x4x5x!HLFHE.eint<2>> +} + +// Broadcasting with a tensor with fewer dimensions than another +// CHECK: func @add_eint_broadcast_2(%[[a0:.*]]: tensor<4x!HLFHE.eint<2>>, %[[a1:.*]]: tensor<3x4x!HLFHE.eint<2>>) -> tensor<3x4x!HLFHE.eint<2>> { +// CHECK-NEXT: %[[V0:.*]] = "HLFHELinalg.add_eint"(%[[a0]], %[[a1]]) : (tensor<4x!HLFHE.eint<2>>, tensor<3x4x!HLFHE.eint<2>>) -> tensor<3x4x!HLFHE.eint<2>> +// CHECK-NEXT: return %[[V0]] : tensor<3x4x!HLFHE.eint<2>> +// CHECK-NEXT: } +func @add_eint_broadcast_2(%a0: tensor<4x!HLFHE.eint<2>>, %a1: tensor<3x4x!HLFHE.eint<2>>) -> tensor<3x4x!HLFHE.eint<2>> { + %1 ="HLFHELinalg.add_eint"(%a0, %a1) : (tensor<4x!HLFHE.eint<2>>, tensor<3x4x!HLFHE.eint<2>>) -> tensor<3x4x!HLFHE.eint<2>> + return %1: tensor<3x4x!HLFHE.eint<2>> } \ No newline at end of file