diff --git a/compiler/include/zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgOps.h b/compiler/include/zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgOps.h index 4be869c13..80702d67c 100644 --- a/compiler/include/zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgOps.h +++ b/compiler/include/zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgOps.h @@ -14,6 +14,7 @@ namespace OpTrait { namespace impl { LogicalResult verifyTensorBroadcastingRules(mlir::Operation *op); LogicalResult verifyTensorBinaryEintInt(mlir::Operation *op); +LogicalResult verifyTensorBinaryIntEint(mlir::Operation *op); LogicalResult verifyTensorBinaryEint(mlir::Operation *op); } // namespace impl @@ -48,6 +49,19 @@ public: } }; +/// TensorBinaryEintInt verifies that the operation matches the following +/// signature +/// `(tensor<...xi$p'>, tensor<...x!HLFHE.eint<$p>>) -> +/// tensor<...x!HLFHE.eint<$p>>` where `$p <= $p+1`. +template +class TensorBinaryIntEint + : public mlir::OpTrait::TraitBase { +public: + static LogicalResult verifyTrait(Operation *op) { + return impl::verifyTensorBinaryIntEint(op); + } +}; + /// TensorBinary verify the operation match the following signature /// `(tensor<...x!HLFHE.eint<$p>>, tensor<...x!HLFHE.eint<$p>>) -> /// tensor<...x!HLFHE.eint<$p>>` diff --git a/compiler/include/zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgOps.td b/compiler/include/zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgOps.td index 12f86ae7d..763a92882 100644 --- a/compiler/include/zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgOps.td +++ b/compiler/include/zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgOps.td @@ -13,6 +13,7 @@ class HLFHELinalg_Op traits = []> : // TensorBroadcastingRules verify that the operands and result verify the broadcasting rules def TensorBroadcastingRules : NativeOpTrait<"TensorBroadcastingRules">; def TensorBinaryEintInt : NativeOpTrait<"TensorBinaryEintInt">; +def TensorBinaryIntEint : NativeOpTrait<"TensorBinaryIntEint">; def TensorBinaryEint : NativeOpTrait<"TensorBinaryEint">; @@ -121,4 +122,57 @@ def AddEintOp : HLFHELinalg_Op<"add_eint", [TensorBroadcastingRules, TensorBinar ]; } +def SubIntEintOp : HLFHELinalg_Op<"sub_int_eint", [TensorBroadcastingRules, TensorBinaryIntEint]> { + let summary = "Returns a tensor that contains the substraction of a tensor of clear integers and a tensor of encrypted integers."; + + let description = [{ + Performs a substraction following the broadcasting rules between a tensor of clear integers and a tensor of encrypted integers. + The width of the clear integers should be less or equals than the witdh of encrypted integers. + + Examples: + ```mlir + // Returns the term to term substraction of `%a0` with `%a1` + "HLFHELinalg.sub_eint_int"(%a0, %a1) : (tensor<4xi5>, tensor<4x!HLFHE.eint<4>>) -> tensor<4x!HLFHE.eint<4>> + + // Returns the term to term substraction of `%a0` with `%a1`, where dimensions equal to one are stretched. + "HLFHELinalg.sub_int_eint"(%a0, %a1) : (tensor<4x1x4xi5>, tensor<1x4x4x!HLFHE.eint<4>>) -> tensor<4x4x4x!HLFHE.eint<4>> + + // Returns the substraction of a 3x3 matrix of integers and a 3x1 matrix (a column) of encrypted integers. + // + // [1,2,3] [1] [0,2,3] + // [4,5,6] + [2] = [2,3,4] + // [7,8,9] [3] [4,5,6] + // + // The dimension #1 of operand #2 is stretched as it is equals to 1. + "HLFHELinalg.sub_int_eint(%a0, %a1)" : (tensor<3x1xi5>, tensor<3x4x!HLFHE.eint<4>>) -> tensor<3x3x!HLFHE.eint<4>> + + // Returns the substraction of a 3x3 matrix of integers and a 1x3 matrix (a line) of encrypted integers. + // + // [1,2,3] [0,0,0] + // [4,5,6] + [1,2,3] = [3,3,3] + // [7,8,9] [6,6,6] + // + // The dimension #2 of operand #2 is stretched as it is equals to 1. + "HLFHELinalg.sub_int_eint(%a0, %a1)" : (tensor<1x3xi5>, tensor<3x4x!HLFHE.eint<4>>) -> tensor<3x3x!HLFHE.eint<4>> + + // Same behavior than the previous one, but as the dimension #2 is missing of operand #2. + "HLFHELinalg.sub_int_eint(%a0, %a1)" : (tensor<3xi5>, tensor<3x4x!HLFHE.eint<4>>) -> tensor<4x4x4x!HLFHE.eint<4>> + + ``` + }]; + + let arguments = (ins + Type.predicate, HasStaticShapePred]>>:$lhs, + Type.predicate, HasStaticShapePred]>>:$rhs + ); + + let results = (outs Type.predicate, HasStaticShapePred]>>); + + let builders = [ + OpBuilder<(ins "Value":$rhs, "Value":$lhs), [{ + build($_builder, $_state, lhs.getType(), rhs, lhs); + }]> + ]; +} + #endif diff --git a/compiler/lib/Dialect/HLFHELinalg/IR/HLFHELinalgOps.cpp b/compiler/lib/Dialect/HLFHELinalg/IR/HLFHELinalgOps.cpp index ec76ae450..1bcce2b8b 100644 --- a/compiler/lib/Dialect/HLFHELinalg/IR/HLFHELinalgOps.cpp +++ b/compiler/lib/Dialect/HLFHELinalg/IR/HLFHELinalgOps.cpp @@ -118,7 +118,40 @@ LogicalResult verifyTensorBinaryEintInt(mlir::Operation *op) { "tensor of operand #1"; return mlir::failure(); } - // llvm::errs() << width << ""; + if (el1Ty.getWidth() > el0Ty.getWidth() + 1) { + op->emitOpError() + << "should have the width of integer values less or equals " + "than the width of encrypted values + 1"; + return mlir::failure(); + } + return mlir::success(); +} + +LogicalResult verifyTensorBinaryIntEint(mlir::Operation *op) { + if (op->getNumOperands() != 2) { + op->emitOpError() << "should have exactly 2 operands"; + return mlir::failure(); + } + auto op0Ty = op->getOperand(0).getType().dyn_cast_or_null(); + auto op1Ty = op->getOperand(1).getType().dyn_cast_or_null(); + if (op0Ty == nullptr || op1Ty == nullptr) { + op->emitOpError() << "should have both operands as tensor"; + return mlir::failure(); + } + auto el0Ty = op0Ty.getElementType().dyn_cast_or_null(); + if (el0Ty == nullptr) { + op->emitOpError() << "should have an integer as the element type of the " + "tensor of operand #0"; + return mlir::failure(); + } + auto el1Ty = + op1Ty.getElementType() + .dyn_cast_or_null(); + if (el1Ty == nullptr) { + op->emitOpError() << "should have a !HLFHE.eint as the element type of the " + "tensor of operand #1"; + return mlir::failure(); + } if (el1Ty.getWidth() > el0Ty.getWidth() + 1) { op->emitOpError() << "should have the width of integer values less or equals " diff --git a/compiler/tests/Dialect/HLFHELinalg/ops.mlir b/compiler/tests/Dialect/HLFHELinalg/ops.mlir index 7f0d8b1e1..e793816d1 100644 --- a/compiler/tests/Dialect/HLFHELinalg/ops.mlir +++ b/compiler/tests/Dialect/HLFHELinalg/ops.mlir @@ -106,4 +106,59 @@ func @add_eint_broadcast_1(%a0: tensor<1x4x5x!HLFHE.eint<2>>, %a1: tensor<3x4x1x func @add_eint_broadcast_2(%a0: tensor<4x!HLFHE.eint<2>>, %a1: tensor<3x4x!HLFHE.eint<2>>) -> tensor<3x4x!HLFHE.eint<2>> { %1 ="HLFHELinalg.add_eint"(%a0, %a1) : (tensor<4x!HLFHE.eint<2>>, tensor<3x4x!HLFHE.eint<2>>) -> tensor<3x4x!HLFHE.eint<2>> return %1: tensor<3x4x!HLFHE.eint<2>> +} + + +///////////////////////////////////////////////// +// HLFHELinalg.sub_eint_int +///////////////////////////////////////////////// + +// 1D tensor +// CHECK: func @sub_int_eint_1D(%[[a0:.*]]: tensor<4xi3>, %[[a1:.*]]: tensor<4x!HLFHE.eint<2>>) -> tensor<4x!HLFHE.eint<2>> { +// CHECK-NEXT: %[[V0:.*]] = "HLFHELinalg.sub_int_eint"(%[[a0]], %[[a1]]) : (tensor<4xi3>, tensor<4x!HLFHE.eint<2>>) -> tensor<4x!HLFHE.eint<2>> +// CHECK-NEXT: return %[[V0]] : tensor<4x!HLFHE.eint<2>> +// CHECK-NEXT: } +func @sub_int_eint_1D(%a0: tensor<4xi3>, %a1: tensor<4x!HLFHE.eint<2>>) -> tensor<4x!HLFHE.eint<2>> { + %1 = "HLFHELinalg.sub_int_eint"(%a0, %a1) : (tensor<4xi3>, tensor<4x!HLFHE.eint<2>>) -> tensor<4x!HLFHE.eint<2>> + return %1: tensor<4x!HLFHE.eint<2>> +} + +// 2D tensor +// CHECK: func @sub_int_eint_2D(%[[a0:.*]]: tensor<2x4xi3>, %[[a1:.*]]: tensor<2x4x!HLFHE.eint<2>>) -> tensor<2x4x!HLFHE.eint<2>> { +// CHECK-NEXT: %[[V0:.*]] = "HLFHELinalg.sub_int_eint"(%[[a0]], %[[a1]]) : (tensor<2x4xi3>, tensor<2x4x!HLFHE.eint<2>>) -> tensor<2x4x!HLFHE.eint<2>> +// CHECK-NEXT: return %[[V0]] : tensor<2x4x!HLFHE.eint<2>> +// CHECK-NEXT: } +func @sub_int_eint_2D(%a0: tensor<2x4xi3>, %a1: tensor<2x4x!HLFHE.eint<2>>) -> tensor<2x4x!HLFHE.eint<2>> { + %1 = "HLFHELinalg.sub_int_eint"(%a0, %a1) : (tensor<2x4xi3>, tensor<2x4x!HLFHE.eint<2>>) -> tensor<2x4x!HLFHE.eint<2>> + return %1: tensor<2x4x!HLFHE.eint<2>> +} + +// 10D tensor +// CHECK: func @sub_int_eint_10D(%[[a0:.*]]: tensor<1x2x3x4x5x6x7x8x9x10xi3>, %[[a1:.*]]: tensor<1x2x3x4x5x6x7x8x9x10x!HLFHE.eint<2>>) -> tensor<1x2x3x4x5x6x7x8x9x10x!HLFHE.eint<2>> { +// CHECK-NEXT: %[[V0:.*]] = "HLFHELinalg.sub_int_eint"(%[[a0]], %[[a1]]) : (tensor<1x2x3x4x5x6x7x8x9x10xi3>, tensor<1x2x3x4x5x6x7x8x9x10x!HLFHE.eint<2>>) -> tensor<1x2x3x4x5x6x7x8x9x10x!HLFHE.eint<2>> +// CHECK-NEXT: return %[[V0]] : tensor<1x2x3x4x5x6x7x8x9x10x!HLFHE.eint<2>> +// CHECK-NEXT: } +func @sub_int_eint_10D(%a0: tensor<1x2x3x4x5x6x7x8x9x10xi3>, %a1: tensor<1x2x3x4x5x6x7x8x9x10x!HLFHE.eint<2>>) -> tensor<1x2x3x4x5x6x7x8x9x10x!HLFHE.eint<2>> { + %1 = "HLFHELinalg.sub_int_eint"(%a0, %a1) : (tensor<1x2x3x4x5x6x7x8x9x10xi3>, tensor<1x2x3x4x5x6x7x8x9x10x!HLFHE.eint<2>>) -> tensor<1x2x3x4x5x6x7x8x9x10x!HLFHE.eint<2>> + return %1: tensor<1x2x3x4x5x6x7x8x9x10x!HLFHE.eint<2>> +} + +// Broadcasting with tensor with dimensions equals to one +// CHECK: func @sub_int_eint_broadcast_1(%[[a0:.*]]: tensor<3x4x1xi3>, %[[a1:.*]]: tensor<1x4x5x!HLFHE.eint<2>>) -> tensor<3x4x5x!HLFHE.eint<2>> { +// CHECK-NEXT: %[[V0:.*]] = "HLFHELinalg.sub_int_eint"(%[[a0]], %[[a1]]) : (tensor<3x4x1xi3>, tensor<1x4x5x!HLFHE.eint<2>>) -> tensor<3x4x5x!HLFHE.eint<2>> +// CHECK-NEXT: return %[[V0]] : tensor<3x4x5x!HLFHE.eint<2>> +// CHECK-NEXT: } +func @sub_int_eint_broadcast_1(%a0: tensor<3x4x1xi3>, %a1: tensor<1x4x5x!HLFHE.eint<2>>) -> tensor<3x4x5x!HLFHE.eint<2>> { + %1 = "HLFHELinalg.sub_int_eint"(%a0, %a1) : (tensor<3x4x1xi3>, tensor<1x4x5x!HLFHE.eint<2>>) -> tensor<3x4x5x!HLFHE.eint<2>> + return %1: tensor<3x4x5x!HLFHE.eint<2>> +} + +// Broadcasting with a tensor less dimensions of another +// CHECK: func @sub_int_eint_broadcast_2(%[[a0:.*]]: tensor<3x4xi3>, %[[a1:.*]]: tensor<4x!HLFHE.eint<2>>) -> tensor<3x4x!HLFHE.eint<2>> { +// CHECK-NEXT: %[[V0:.*]] = "HLFHELinalg.sub_int_eint"(%[[a0]], %[[a1]]) : (tensor<3x4xi3>, tensor<4x!HLFHE.eint<2>>) -> tensor<3x4x!HLFHE.eint<2>> +// CHECK-NEXT: return %[[V0]] : tensor<3x4x!HLFHE.eint<2>> +// CHECK-NEXT: } +func @sub_int_eint_broadcast_2(%a0: tensor<3x4xi3>, %a1: tensor<4x!HLFHE.eint<2>>) -> tensor<3x4x!HLFHE.eint<2>> { + %1 ="HLFHELinalg.sub_int_eint"(%a0, %a1) : (tensor<3x4xi3>, tensor<4x!HLFHE.eint<2>>) -> tensor<3x4x!HLFHE.eint<2>> + return %1: tensor<3x4x!HLFHE.eint<2>> } \ No newline at end of file