feat(compiler): HLFHELinalg.add_eint definition

Quentin Bourgerie
2021-10-21 16:40:47 +02:00
committed by Andi Drebes
parent 0d4e10169b
commit 3b02a16f7b
5 changed files with 198 additions and 0 deletions


@@ -14,6 +14,7 @@ namespace OpTrait {
namespace impl {
LogicalResult verifyTensorBroadcastingRules(mlir::Operation *op);
LogicalResult verifyTensorBinaryEintInt(mlir::Operation *op);
LogicalResult verifyTensorBinaryEint(mlir::Operation *op);
} // namespace impl
/// TensorBroadcastingRules is a trait for operators that should respect the
@@ -47,6 +48,18 @@ public:
}
};
/// TensorBinaryEint verifies that the operation matches the following signature
/// `(tensor<...x!HLFHE.eint<$p>>, tensor<...x!HLFHE.eint<$p>>) ->
/// tensor<...x!HLFHE.eint<$p>>`
template <typename ConcreteType>
class TensorBinaryEint
: public mlir::OpTrait::TraitBase<ConcreteType, TensorBinaryEint> {
public:
static LogicalResult verifyTrait(Operation *op) {
return impl::verifyTensorBinaryEint(op);
}
};
} // namespace OpTrait
} // namespace mlir


@@ -13,6 +13,8 @@ class HLFHELinalg_Op<string mnemonic, list<OpTrait> traits = []> :
// TensorBroadcastingRules verifies that the operands and the result respect the broadcasting rules
def TensorBroadcastingRules : NativeOpTrait<"TensorBroadcastingRules">;
def TensorBinaryEintInt : NativeOpTrait<"TensorBinaryEintInt">;
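// TensorBinaryEint verifies that the operands are tensors of encrypted integers of the same width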
def TensorBinaryEint : NativeOpTrait<"TensorBinaryEint">;
def AddEintIntOp : HLFHELinalg_Op<"add_eint_int", [TensorBroadcastingRules, TensorBinaryEintInt]> {
let summary = "Returns a tensor that contains the addition of a tensor of encrypted integers and a tensor of clear integers.";
@@ -67,4 +69,56 @@ def AddEintIntOp : HLFHELinalg_Op<"add_eint_int", [TensorBroadcastingRules, Tens
];
}
def AddEintOp : HLFHELinalg_Op<"add_eint", [TensorBroadcastingRules, TensorBinaryEint]> {
let summary = "Returns a tensor that contains the addition of two tensor of encrypted integers.";
let description = [{
Performs an addition follwing the broadcasting rules between two tensors of encrypted integers.
The width of the encrypted integers should be equals.
Examples:
```mlir
// Returns the element-wise addition of `%a0` and `%a1`
"HLFHELinalg.add_eint"(%a0, %a1) : (tensor<4x!HLFHE.eint<4>>, tensor<4x!HLFHE.eint<4>>) -> tensor<4x!HLFHE.eint<4>>
// Returns the element-wise addition of `%a0` and `%a1`, where dimensions equal to one are stretched.
"HLFHELinalg.add_eint"(%a0, %a1) : (tensor<4x1x4x!HLFHE.eint<4>>, tensor<1x4x4x!HLFHE.eint<4>>) -> tensor<4x4x4x!HLFHE.eint<4>>
// Returns the addition of a 3x3 matrix of encrypted integers and a 3x1 matrix (a column) of encrypted integers.
//
// [1,2,3] [1] [2,3,4]
// [4,5,6] + [2] = [6,7,8]
// [7,8,9] [3] [10,11,12]
//
// The second dimension of operand #2 is stretched as it is equal to 1.
"HLFHELinalg.add_eint"(%a0, %a1) : (tensor<3x3x!HLFHE.eint<4>>, tensor<3x1x!HLFHE.eint<4>>) -> tensor<3x3x!HLFHE.eint<4>>
// Returns the addition of a 3x3 matrix of encrypted integers and a 1x3 matrix (a row) of encrypted integers.
//
// [1,2,3] [2,4,6]
// [4,5,6] + [1,2,3] = [5,7,9]
// [7,8,9] [8,10,12]
//
// The first dimension of operand #2 is stretched as it is equal to 1.
"HLFHELinalg.add_eint"(%a0, %a1) : (tensor<3x3x!HLFHE.eint<4>>, tensor<1x3x!HLFHE.eint<4>>) -> tensor<3x3x!HLFHE.eint<4>>
// Same behavior as the previous one, except that the leading dimension of operand #2 is omitted (it is implicitly treated as 1).
"HLFHELinalg.add_eint"(%a0, %a1) : (tensor<3x3x!HLFHE.eint<4>>, tensor<3x!HLFHE.eint<4>>) -> tensor<3x3x!HLFHE.eint<4>>
```
}];
let arguments = (ins
Type<And<[TensorOf<[EncryptedIntegerType]>.predicate, HasStaticShapePred]>>:$lhs,
Type<And<[TensorOf<[EncryptedIntegerType]>.predicate, HasStaticShapePred]>>:$rhs
);
let results = (outs Type<And<[TensorOf<[EncryptedIntegerType]>.predicate, HasStaticShapePred]>>);
let builders = [
OpBuilder<(ins "Value":$lhs, "Value":$rhs), [{
build($_builder, $_state, lhs.getType(), lhs, rhs);
}]>
];
}
#endif


@@ -127,6 +127,43 @@ LogicalResult verifyTensorBinaryEintInt(mlir::Operation *op) {
}
return mlir::success();
}
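// Verifies that the operation has exactly two tensor operands whose element
// types are !HLFHE.eint values of the same width, as required by the
// TensorBinaryEint trait.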
LogicalResult verifyTensorBinaryEint(mlir::Operation *op) {
if (op->getNumOperands() != 2) {
op->emitOpError() << "should have exactly 2 operands";
return mlir::failure();
}
auto op0Ty = op->getOperand(0).getType().dyn_cast_or_null<mlir::TensorType>();
auto op1Ty = op->getOperand(1).getType().dyn_cast_or_null<mlir::TensorType>();
if (op0Ty == nullptr || op1Ty == nullptr) {
op->emitOpError() << "should have both operands as tensor";
return mlir::failure();
}
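// Check that the element type of each operand is an encrypted integer.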
auto el0Ty =
op0Ty.getElementType()
.dyn_cast_or_null<mlir::zamalang::HLFHE::EncryptedIntegerType>();
if (el0Ty == nullptr) {
op->emitOpError() << "should have a !HLFHE.eint as the element type of the "
"tensor of operand #0";
return mlir::failure();
}
auto el1Ty =
op1Ty.getElementType()
.dyn_cast_or_null<mlir::zamalang::HLFHE::EncryptedIntegerType>();
if (el1Ty == nullptr) {
op->emitOpError() << "should have a !HLFHE.eint as the element type of the "
"tensor of operand #1";
return mlir::failure();
}
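// Both element types are encrypted integers; their widths must match.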
if (el1Ty.getWidth() != el0Ty.getWidth()) {
op->emitOpError() << "should have the width of encrypted equals"
", got "
<< el1Ty.getWidth() << " expect " << el0Ty.getWidth();
return mlir::failure();
}
return mlir::success();
}
} // namespace impl
} // namespace OpTrait


@@ -36,4 +36,44 @@ func @main(%a0: tensor<2x3x4x!HLFHE.eint<2>>, %a1: tensor<2x3x4xi4>) -> tensor<2
// expected-error @+1 {{'HLFHELinalg.add_eint_int' op should have the width of integer values less or equals than the width of encrypted values + 1}}
%1 = "HLFHELinalg.add_eint_int"(%a0, %a1) : (tensor<2x3x4x!HLFHE.eint<2>>, tensor<2x3x4xi4>) -> tensor<2x3x4x!HLFHE.eint<2>>
return %1 : tensor<2x3x4x!HLFHE.eint<2>>
}
// -----
/////////////////////////////////////////////////
// HLFHELinalg.add_eint
/////////////////////////////////////////////////
// Incompatible dimension of operands
func @main(%a0: tensor<2x2x3x4x!HLFHE.eint<2>>, %a1: tensor<2x2x2x4x!HLFHE.eint<2>>) -> tensor<2x2x3x4x!HLFHE.eint<2>> {
// expected-error @+1 {{'HLFHELinalg.add_eint' op has the dimension #2 of the operand #1 incompatible with other operands, got 2 expect 1 or 3}}
%1 = "HLFHELinalg.add_eint"(%a0, %a1) : (tensor<2x2x3x4x!HLFHE.eint<2>>, tensor<2x2x2x4x!HLFHE.eint<2>>) -> tensor<2x2x3x4x!HLFHE.eint<2>>
return %1 : tensor<2x2x3x4x!HLFHE.eint<2>>
}
// -----
// Incompatible dimension of result
func @main(%a0: tensor<2x2x3x4x!HLFHE.eint<2>>, %a1: tensor<2x2x2x4x!HLFHE.eint<2>>) -> tensor<2x10x3x4x!HLFHE.eint<2>> {
// expected-error @+1 {{'HLFHELinalg.add_eint' op has the dimension #3 of the result incompatible with operands dimension, got 10 expect 2}}
%1 = "HLFHELinalg.add_eint"(%a0, %a1) : (tensor<2x2x3x4x!HLFHE.eint<2>>, tensor<2x2x2x4x!HLFHE.eint<2>>) -> tensor<2x10x3x4x!HLFHE.eint<2>>
return %1 : tensor<2x10x3x4x!HLFHE.eint<2>>
}
// -----
// Incompatible number of dimensions between operands and result
func @main(%a0: tensor<2x2x3x4x!HLFHE.eint<2>>, %a1: tensor<2x2x2x4x!HLFHE.eint<2>>) -> tensor<2x3x4x!HLFHE.eint<2>> {
// expected-error @+1 {{'HLFHELinalg.add_eint' op should have the number of dimensions of the result equal to the highest number of dimensions of operands, got 3 expect 4}}
%1 = "HLFHELinalg.add_eint"(%a0, %a1) : (tensor<2x2x3x4x!HLFHE.eint<2>>, tensor<2x2x2x4x!HLFHE.eint<2>>) -> tensor<2x3x4x!HLFHE.eint<2>>
return %1 : tensor<2x3x4x!HLFHE.eint<2>>
}
// -----
// Incompatible widths between encrypted operands
func @main(%a0: tensor<2x3x4x!HLFHE.eint<2>>, %a1: tensor<2x3x4x!HLFHE.eint<3>>) -> tensor<2x3x4x!HLFHE.eint<2>> {
// expected-error @+1 {{'HLFHELinalg.add_eint' op should have encrypted integers of equal width, got 3 expect 2}}
%1 = "HLFHELinalg.add_eint"(%a0, %a1) : (tensor<2x3x4x!HLFHE.eint<2>>, tensor<2x3x4x!HLFHE.eint<3>>) -> tensor<2x3x4x!HLFHE.eint<2>>
return %1 : tensor<2x3x4x!HLFHE.eint<2>>
}


@@ -52,4 +52,58 @@ func @add_eint_int_broadcast_1(%a0: tensor<1x4x5x!HLFHE.eint<2>>, %a1: tensor<3x
func @add_eint_int_broadcast_2(%a0: tensor<4x!HLFHE.eint<2>>, %a1: tensor<3x4xi3>) -> tensor<3x4x!HLFHE.eint<2>> {
%1 ="HLFHELinalg.add_eint_int"(%a0, %a1) : (tensor<4x!HLFHE.eint<2>>, tensor<3x4xi3>) -> tensor<3x4x!HLFHE.eint<2>>
return %1: tensor<3x4x!HLFHE.eint<2>>
}
/////////////////////////////////////////////////
// HLFHELinalg.add_eint
/////////////////////////////////////////////////
// 1D tensor
// CHECK: func @add_eint_1D(%[[a0:.*]]: tensor<4x!HLFHE.eint<2>>, %[[a1:.*]]: tensor<4x!HLFHE.eint<2>>) -> tensor<4x!HLFHE.eint<2>> {
// CHECK-NEXT: %[[V0:.*]] = "HLFHELinalg.add_eint"(%[[a0]], %[[a1]]) : (tensor<4x!HLFHE.eint<2>>, tensor<4x!HLFHE.eint<2>>) -> tensor<4x!HLFHE.eint<2>>
// CHECK-NEXT: return %[[V0]] : tensor<4x!HLFHE.eint<2>>
// CHECK-NEXT: }
func @add_eint_1D(%a0: tensor<4x!HLFHE.eint<2>>, %a1: tensor<4x!HLFHE.eint<2>>) -> tensor<4x!HLFHE.eint<2>> {
%1 = "HLFHELinalg.add_eint"(%a0, %a1) : (tensor<4x!HLFHE.eint<2>>, tensor<4x!HLFHE.eint<2>>) -> tensor<4x!HLFHE.eint<2>>
return %1: tensor<4x!HLFHE.eint<2>>
}
// 2D tensor
// CHECK: func @add_eint_2D(%[[a0:.*]]: tensor<2x4x!HLFHE.eint<2>>, %[[a1:.*]]: tensor<2x4x!HLFHE.eint<2>>) -> tensor<2x4x!HLFHE.eint<2>> {
// CHECK-NEXT: %[[V0:.*]] = "HLFHELinalg.add_eint"(%[[a0]], %[[a1]]) : (tensor<2x4x!HLFHE.eint<2>>, tensor<2x4x!HLFHE.eint<2>>) -> tensor<2x4x!HLFHE.eint<2>>
// CHECK-NEXT: return %[[V0]] : tensor<2x4x!HLFHE.eint<2>>
// CHECK-NEXT: }
func @add_eint_2D(%a0: tensor<2x4x!HLFHE.eint<2>>, %a1: tensor<2x4x!HLFHE.eint<2>>) -> tensor<2x4x!HLFHE.eint<2>> {
%1 = "HLFHELinalg.add_eint"(%a0, %a1) : (tensor<2x4x!HLFHE.eint<2>>, tensor<2x4x!HLFHE.eint<2>>) -> tensor<2x4x!HLFHE.eint<2>>
return %1: tensor<2x4x!HLFHE.eint<2>>
}
// 10D tensor
// CHECK: func @add_eint_10D(%[[a0:.*]]: tensor<1x2x3x4x5x6x7x8x9x10x!HLFHE.eint<2>>, %[[a1:.*]]: tensor<1x2x3x4x5x6x7x8x9x10x!HLFHE.eint<2>>) -> tensor<1x2x3x4x5x6x7x8x9x10x!HLFHE.eint<2>> {
// CHECK-NEXT: %[[V0:.*]] = "HLFHELinalg.add_eint"(%[[a0]], %[[a1]]) : (tensor<1x2x3x4x5x6x7x8x9x10x!HLFHE.eint<2>>, tensor<1x2x3x4x5x6x7x8x9x10x!HLFHE.eint<2>>) -> tensor<1x2x3x4x5x6x7x8x9x10x!HLFHE.eint<2>>
// CHECK-NEXT: return %[[V0]] : tensor<1x2x3x4x5x6x7x8x9x10x!HLFHE.eint<2>>
// CHECK-NEXT: }
func @add_eint_10D(%a0: tensor<1x2x3x4x5x6x7x8x9x10x!HLFHE.eint<2>>, %a1: tensor<1x2x3x4x5x6x7x8x9x10x!HLFHE.eint<2>>) -> tensor<1x2x3x4x5x6x7x8x9x10x!HLFHE.eint<2>> {
%1 = "HLFHELinalg.add_eint"(%a0, %a1) : (tensor<1x2x3x4x5x6x7x8x9x10x!HLFHE.eint<2>>, tensor<1x2x3x4x5x6x7x8x9x10x!HLFHE.eint<2>>) -> tensor<1x2x3x4x5x6x7x8x9x10x!HLFHE.eint<2>>
return %1: tensor<1x2x3x4x5x6x7x8x9x10x!HLFHE.eint<2>>
}
// Broadcasting with tensors that have dimensions equal to one
// CHECK: func @add_eint_broadcast_1(%[[a0:.*]]: tensor<1x4x5x!HLFHE.eint<2>>, %[[a1:.*]]: tensor<3x4x1x!HLFHE.eint<2>>) -> tensor<3x4x5x!HLFHE.eint<2>> {
// CHECK-NEXT: %[[V0:.*]] = "HLFHELinalg.add_eint"(%[[a0]], %[[a1]]) : (tensor<1x4x5x!HLFHE.eint<2>>, tensor<3x4x1x!HLFHE.eint<2>>) -> tensor<3x4x5x!HLFHE.eint<2>>
// CHECK-NEXT: return %[[V0]] : tensor<3x4x5x!HLFHE.eint<2>>
// CHECK-NEXT: }
func @add_eint_broadcast_1(%a0: tensor<1x4x5x!HLFHE.eint<2>>, %a1: tensor<3x4x1x!HLFHE.eint<2>>) -> tensor<3x4x5x!HLFHE.eint<2>> {
%1 = "HLFHELinalg.add_eint"(%a0, %a1) : (tensor<1x4x5x!HLFHE.eint<2>>, tensor<3x4x1x!HLFHE.eint<2>>) -> tensor<3x4x5x!HLFHE.eint<2>>
return %1: tensor<3x4x5x!HLFHE.eint<2>>
}
// Broadcasting with a tensor that has fewer dimensions than the other
// CHECK: func @add_eint_broadcast_2(%[[a0:.*]]: tensor<4x!HLFHE.eint<2>>, %[[a1:.*]]: tensor<3x4x!HLFHE.eint<2>>) -> tensor<3x4x!HLFHE.eint<2>> {
// CHECK-NEXT: %[[V0:.*]] = "HLFHELinalg.add_eint"(%[[a0]], %[[a1]]) : (tensor<4x!HLFHE.eint<2>>, tensor<3x4x!HLFHE.eint<2>>) -> tensor<3x4x!HLFHE.eint<2>>
// CHECK-NEXT: return %[[V0]] : tensor<3x4x!HLFHE.eint<2>>
// CHECK-NEXT: }
func @add_eint_broadcast_2(%a0: tensor<4x!HLFHE.eint<2>>, %a1: tensor<3x4x!HLFHE.eint<2>>) -> tensor<3x4x!HLFHE.eint<2>> {
%1 ="HLFHELinalg.add_eint"(%a0, %a1) : (tensor<4x!HLFHE.eint<2>>, tensor<3x4x!HLFHE.eint<2>>) -> tensor<3x4x!HLFHE.eint<2>>
return %1: tensor<3x4x!HLFHE.eint<2>>
}