diff --git a/compiler/include/zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgOps.td b/compiler/include/zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgOps.td index 763a92882..d40332f61 100644 --- a/compiler/include/zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgOps.td +++ b/compiler/include/zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgOps.td @@ -175,4 +175,52 @@ def SubIntEintOp : HLFHELinalg_Op<"sub_int_eint", [TensorBroadcastingRules, Tens ]; } +def MulEintIntOp : HLFHELinalg_Op<"mul_eint_int", [TensorBroadcastingRules, TensorBinaryEintInt]> { + let summary = "Returns a tensor that contains the multiplication of a tensor of encrypted integers and a tensor of clear integers."; + + let description = [{ + Performs a multiplication following the broadcasting rules between a tensor of encrypted integers and a tensor of clear integers. + The width of the clear integers should be less or equals than the witdh of encrypted integers. + + Examples: + ```mlir + // Returns the term to term multiplication of `%a0` with `%a1` + "HLFHELinalg.mul_eint_int"(%a0, %a1) : (tensor<4x!HLFHE.eint<4>>, tensor<4xi5>) -> tensor<4x!HLFHE.eint<4>> + + // Returns the term to term multiplication of `%a0` with `%a1`, where dimensions equal to one are stretched. + "HLFHELinalg.mul_eint_int"(%a0, %a1) : (tensor<4x1x4x!HLFHE.eint<4>>, tensor<1x4x4xi5>) -> tensor<4x4x4x!HLFHE.eint<4>> + + // Returns the multiplication of a 3x3 matrix of encrypted integers and a 3x1 matrix (a column) of integers. + // + // [1,2,3] [1] [1,2,3] + // [4,5,6] * [2] = [8,10,18] + // [7,8,9] [3] [21,24,27] + // + // The dimension #1 of operand #2 is stretched as it is equals to 1. + "HLFHELinalg.mul_eint_int(%a0, %a1)" : (tensor<3x4x!HLFHE.eint<4>>, tensor<3x1xi5>) -> tensor<3x3x!HLFHE.eint<4>> + + // Returns the multiplication of a 3x3 matrix of encrypted integers and a 1x3 matrix (a line) of integers. + // + // [1,2,3] [2,4,6] + // [4,5,6] * [1,2,3] = [5,7,9] + // [7,8,9] [8,10,12] + // + // The dimension #2 of operand #2 is stretched as it is equals to 1. + "HLFHELinalg.mul_eint_int(%a0, %a1)" : (tensor<3x4x!HLFHE.eint<4>>, tensor<1x3xi5>) -> tensor<3x3x!HLFHE.eint<4>> + + // Same behavior than the previous one, but as the dimension #2 is missing of operand #2. + "HLFHELinalg.mul_eint_int(%a0, %a1)" : (tensor<3x4x!HLFHE.eint<4>>, tensor<3xi5>) -> tensor<4x4x4x!HLFHE.eint<4>> + + ``` + }]; + + let arguments = (ins + Type.predicate, HasStaticShapePred]>>:$lhs, + Type.predicate, HasStaticShapePred]>>:$rhs + ); + + let results = (outs Type.predicate, HasStaticShapePred]>>); + +} + #endif diff --git a/compiler/tests/Dialect/HLFHELinalg/ops.invalid.mlir b/compiler/tests/Dialect/HLFHELinalg/ops.invalid.mlir index 2649dfc37..761dd8ca4 100644 --- a/compiler/tests/Dialect/HLFHELinalg/ops.invalid.mlir +++ b/compiler/tests/Dialect/HLFHELinalg/ops.invalid.mlir @@ -76,4 +76,46 @@ func @main(%a0: tensor<2x3x4x!HLFHE.eint<2>>, %a1: tensor<2x3x4x!HLFHE.eint<3>>) // expected-error @+1 {{'HLFHELinalg.add_eint' op should have the width of encrypted equals, got 3 expect 2}} %1 = "HLFHELinalg.add_eint"(%a0, %a1) : (tensor<2x3x4x!HLFHE.eint<2>>, tensor<2x3x4x!HLFHE.eint<3>>) -> tensor<2x3x4x!HLFHE.eint<2>> return %1 : tensor<2x3x4x!HLFHE.eint<2>> -} \ No newline at end of file +} + +// ----- + +///////////////////////////////////////////////// +// HLFHELinalg.mul_eint_int +///////////////////////////////////////////////// + +// Incompatible dimension of operands +func @main(%a0: tensor<2x2x3x4x!HLFHE.eint<2>>, %a1: tensor<2x2x2x4xi3>) -> tensor<2x2x3x4x!HLFHE.eint<2>> { + // expected-error @+1 {{'HLFHELinalg.mul_eint_int' op has the dimension #2 of the operand #1 incompatible with other operands, got 2 expect 1 or 3}} + %1 = "HLFHELinalg.mul_eint_int"(%a0, %a1) : (tensor<2x2x3x4x!HLFHE.eint<2>>, tensor<2x2x2x4xi3>) -> tensor<2x2x3x4x!HLFHE.eint<2>> + return %1 : tensor<2x2x3x4x!HLFHE.eint<2>> +} + +// ----- + +// Incompatible dimension of result +func @main(%a0: tensor<2x2x3x4x!HLFHE.eint<2>>, %a1: tensor<2x2x2x4xi3>) -> tensor<2x10x3x4x!HLFHE.eint<2>> { + // expected-error @+1 {{'HLFHELinalg.mul_eint_int' op has the dimension #3 of the result incompatible with operands dimension, got 10 expect 2}} + %1 = "HLFHELinalg.mul_eint_int"(%a0, %a1) : (tensor<2x2x3x4x!HLFHE.eint<2>>, tensor<2x2x2x4xi3>) -> tensor<2x10x3x4x!HLFHE.eint<2>> + return %1 : tensor<2x10x3x4x!HLFHE.eint<2>> +} + +// ----- + +// Incompatible number of dimension between operands and result +func @main(%a0: tensor<2x2x3x4x!HLFHE.eint<2>>, %a1: tensor<2x2x2x4xi3>) -> tensor<2x3x4x!HLFHE.eint<2>> { + // expected-error @+1 {{'HLFHELinalg.mul_eint_int' op should have the number of dimensions of the result equal to the highest number of dimensions of operands, got 3 expect 4}} + %1 = "HLFHELinalg.mul_eint_int"(%a0, %a1) : (tensor<2x2x3x4x!HLFHE.eint<2>>, tensor<2x2x2x4xi3>) -> tensor<2x3x4x!HLFHE.eint<2>> + return %1 : tensor<2x3x4x!HLFHE.eint<2>> +} + +// ----- + +// Incompatible width between clear and encrypted witdh +func @main(%a0: tensor<2x3x4x!HLFHE.eint<2>>, %a1: tensor<2x3x4xi4>) -> tensor<2x3x4x!HLFHE.eint<2>> { + // expected-error @+1 {{'HLFHELinalg.mul_eint_int' op should have the width of integer values less or equals than the width of encrypted values + 1}} + %1 = "HLFHELinalg.mul_eint_int"(%a0, %a1) : (tensor<2x3x4x!HLFHE.eint<2>>, tensor<2x3x4xi4>) -> tensor<2x3x4x!HLFHE.eint<2>> + return %1 : tensor<2x3x4x!HLFHE.eint<2>> +} + +// ----- \ No newline at end of file diff --git a/compiler/tests/Dialect/HLFHELinalg/ops.mlir b/compiler/tests/Dialect/HLFHELinalg/ops.mlir index e793816d1..81ebd6dbb 100644 --- a/compiler/tests/Dialect/HLFHELinalg/ops.mlir +++ b/compiler/tests/Dialect/HLFHELinalg/ops.mlir @@ -161,4 +161,59 @@ func @sub_int_eint_broadcast_1(%a0: tensor<3x4x1xi3>, %a1: tensor<1x4x5x!HLFHE.e func @sub_int_eint_broadcast_2(%a0: tensor<3x4xi3>, %a1: tensor<4x!HLFHE.eint<2>>) -> tensor<3x4x!HLFHE.eint<2>> { %1 ="HLFHELinalg.sub_int_eint"(%a0, %a1) : (tensor<3x4xi3>, tensor<4x!HLFHE.eint<2>>) -> tensor<3x4x!HLFHE.eint<2>> return %1: tensor<3x4x!HLFHE.eint<2>> +} + + +///////////////////////////////////////////////// +// HLFHELinalg.mul_eint_int +///////////////////////////////////////////////// + +// 1D tensor +// CHECK: func @mul_eint_int_1D(%[[a0:.*]]: tensor<4x!HLFHE.eint<2>>, %[[a1:.*]]: tensor<4xi3>) -> tensor<4x!HLFHE.eint<2>> { +// CHECK-NEXT: %[[V0:.*]] = "HLFHELinalg.mul_eint_int"(%[[a0]], %[[a1]]) : (tensor<4x!HLFHE.eint<2>>, tensor<4xi3>) -> tensor<4x!HLFHE.eint<2>> +// CHECK-NEXT: return %[[V0]] : tensor<4x!HLFHE.eint<2>> +// CHECK-NEXT: } +func @mul_eint_int_1D(%a0: tensor<4x!HLFHE.eint<2>>, %a1: tensor<4xi3>) -> tensor<4x!HLFHE.eint<2>> { + %1 = "HLFHELinalg.mul_eint_int"(%a0, %a1) : (tensor<4x!HLFHE.eint<2>>, tensor<4xi3>) -> tensor<4x!HLFHE.eint<2>> + return %1: tensor<4x!HLFHE.eint<2>> +} + +// 2D tensor +// CHECK: func @mul_eint_int_2D(%[[a0:.*]]: tensor<2x4x!HLFHE.eint<2>>, %[[a1:.*]]: tensor<2x4xi3>) -> tensor<2x4x!HLFHE.eint<2>> { +// CHECK-NEXT: %[[V0:.*]] = "HLFHELinalg.mul_eint_int"(%[[a0]], %[[a1]]) : (tensor<2x4x!HLFHE.eint<2>>, tensor<2x4xi3>) -> tensor<2x4x!HLFHE.eint<2>> +// CHECK-NEXT: return %[[V0]] : tensor<2x4x!HLFHE.eint<2>> +// CHECK-NEXT: } +func @mul_eint_int_2D(%a0: tensor<2x4x!HLFHE.eint<2>>, %a1: tensor<2x4xi3>) -> tensor<2x4x!HLFHE.eint<2>> { + %1 = "HLFHELinalg.mul_eint_int"(%a0, %a1) : (tensor<2x4x!HLFHE.eint<2>>, tensor<2x4xi3>) -> tensor<2x4x!HLFHE.eint<2>> + return %1: tensor<2x4x!HLFHE.eint<2>> +} + +// 10D tensor +// CHECK: func @mul_eint_int_10D(%[[a0:.*]]: tensor<1x2x3x4x5x6x7x8x9x10x!HLFHE.eint<2>>, %[[a1:.*]]: tensor<1x2x3x4x5x6x7x8x9x10xi3>) -> tensor<1x2x3x4x5x6x7x8x9x10x!HLFHE.eint<2>> { +// CHECK-NEXT: %[[V0:.*]] = "HLFHELinalg.mul_eint_int"(%[[a0]], %[[a1]]) : (tensor<1x2x3x4x5x6x7x8x9x10x!HLFHE.eint<2>>, tensor<1x2x3x4x5x6x7x8x9x10xi3>) -> tensor<1x2x3x4x5x6x7x8x9x10x!HLFHE.eint<2>> +// CHECK-NEXT: return %[[V0]] : tensor<1x2x3x4x5x6x7x8x9x10x!HLFHE.eint<2>> +// CHECK-NEXT: } +func @mul_eint_int_10D(%a0: tensor<1x2x3x4x5x6x7x8x9x10x!HLFHE.eint<2>>, %a1: tensor<1x2x3x4x5x6x7x8x9x10xi3>) -> tensor<1x2x3x4x5x6x7x8x9x10x!HLFHE.eint<2>> { + %1 = "HLFHELinalg.mul_eint_int"(%a0, %a1) : (tensor<1x2x3x4x5x6x7x8x9x10x!HLFHE.eint<2>>, tensor<1x2x3x4x5x6x7x8x9x10xi3>) -> tensor<1x2x3x4x5x6x7x8x9x10x!HLFHE.eint<2>> + return %1: tensor<1x2x3x4x5x6x7x8x9x10x!HLFHE.eint<2>> +} + +// Broadcasting with tensor with dimensions equals to one +// CHECK: func @mul_eint_int_broadcast_1(%[[a0:.*]]: tensor<1x4x5x!HLFHE.eint<2>>, %[[a1:.*]]: tensor<3x4x1xi3>) -> tensor<3x4x5x!HLFHE.eint<2>> { +// CHECK-NEXT: %[[V0:.*]] = "HLFHELinalg.mul_eint_int"(%[[a0]], %[[a1]]) : (tensor<1x4x5x!HLFHE.eint<2>>, tensor<3x4x1xi3>) -> tensor<3x4x5x!HLFHE.eint<2>> +// CHECK-NEXT: return %[[V0]] : tensor<3x4x5x!HLFHE.eint<2>> +// CHECK-NEXT: } +func @mul_eint_int_broadcast_1(%a0: tensor<1x4x5x!HLFHE.eint<2>>, %a1: tensor<3x4x1xi3>) -> tensor<3x4x5x!HLFHE.eint<2>> { + %1 = "HLFHELinalg.mul_eint_int"(%a0, %a1) : (tensor<1x4x5x!HLFHE.eint<2>>, tensor<3x4x1xi3>) -> tensor<3x4x5x!HLFHE.eint<2>> + return %1: tensor<3x4x5x!HLFHE.eint<2>> +} + +// Broadcasting with a tensor less dimensions of another +// CHECK: func @mul_eint_int_broadcast_2(%[[a0:.*]]: tensor<4x!HLFHE.eint<2>>, %[[a1:.*]]: tensor<3x4xi3>) -> tensor<3x4x!HLFHE.eint<2>> { +// CHECK-NEXT: %[[V0:.*]] = "HLFHELinalg.mul_eint_int"(%[[a0]], %[[a1]]) : (tensor<4x!HLFHE.eint<2>>, tensor<3x4xi3>) -> tensor<3x4x!HLFHE.eint<2>> +// CHECK-NEXT: return %[[V0]] : tensor<3x4x!HLFHE.eint<2>> +// CHECK-NEXT: } +func @mul_eint_int_broadcast_2(%a0: tensor<4x!HLFHE.eint<2>>, %a1: tensor<3x4xi3>) -> tensor<3x4x!HLFHE.eint<2>> { + %1 ="HLFHELinalg.mul_eint_int"(%a0, %a1) : (tensor<4x!HLFHE.eint<2>>, tensor<3x4xi3>) -> tensor<3x4x!HLFHE.eint<2>> + return %1: tensor<3x4x!HLFHE.eint<2>> } \ No newline at end of file