fix(compiler/hlfhe): More verification on dot_eint_int

This commit is contained in:
Quentin Bourgerie
2021-08-17 14:21:38 +02:00
parent 8b9c9f2da1
commit 7372cd3d0a
5 changed files with 70 additions and 22 deletions

View File

@@ -97,14 +97,7 @@ def Dot : HLFHE_Op<"dot_eint_int"> {
Type<And<[TensorOf<[AnyInteger]>.predicate, HasStaticShapePred, HasAnyRankOfPred<[1]>]>>:$rhs);
let results = (outs EncryptedIntegerType:$out);
let verifier = [{
if(::mlir::failed(
mlir::verifyCompatibleShape(
lhs().getType(),
rhs().getType()))) {
return this->emitOpError("arguments have incompatible shapes");
}
return ::mlir::success();
return ::mlir::zamalang::HLFHE::verifyDotEintInt(*this);
}];
}
#endif

View File

@@ -118,6 +118,33 @@ bool verifyEncryptedIntegerInputsConsistency(::mlir::OpState &op,
return mlir::success();
}
::mlir::LogicalResult verifyDotEintInt(Dot &op) {
if (::mlir::failed(mlir::verifyCompatibleShape(op.lhs().getType(),
op.rhs().getType()))) {
return op.emitOpError("arguments have incompatible shapes");
}
auto lhsEltType = op.lhs()
.getType()
.cast<mlir::TensorType>()
.getElementType()
.cast<EncryptedIntegerType>();
auto rhsEltType = op.rhs()
.getType()
.cast<mlir::TensorType>()
.getElementType()
.cast<mlir::IntegerType>();
auto resultType = op.getResult().getType().cast<EncryptedIntegerType>();
if (!verifyEncryptedIntegerAndIntegerInputsConsistency(op, lhsEltType,
rhsEltType)) {
return ::mlir::failure();
}
if (!verifyEncryptedIntegerInputAndResultConsistency(op, lhsEltType,
resultType)) {
return ::mlir::failure();
}
return ::mlir::success();
}
} // namespace HLFHE
} // namespace zamalang
} // namespace mlir

View File

@@ -21,7 +21,7 @@ func @dot_incompatible_input_types(
{
// expected-error @+1 {{'HLFHE.dot_eint_int' op operand #1 must}}
%ret = "HLFHE.dot_eint_int"(%arg0, %arg1) :
(tensor<5x!HLFHE.eint<2>>, tensor<4xf32>) -> !HLFHE.eint<0>
(tensor<5x!HLFHE.eint<2>>, tensor<4xf32>) -> !HLFHE.eint<2>
return %ret : !HLFHE.eint<2>
}
@@ -31,11 +31,39 @@ func @dot_incompatible_input_types(
// Wrong number of dimensions
func @dot_num_dims(
%arg0: tensor<2x4x!HLFHE.eint<2>>,
%arg1: tensor<2x4xi32>) -> !HLFHE.eint<2>
%arg1: tensor<2x4xi3>) -> !HLFHE.eint<2>
{
// expected-error @+1 {{'HLFHE.dot_eint_int' op operand #0}}
%ret = "HLFHE.dot_eint_int"(%arg0, %arg1) :
(tensor<2x4x!HLFHE.eint<2>>, tensor<2x4xi32>) -> !HLFHE.eint<2>
(tensor<2x4x!HLFHE.eint<2>>, tensor<2x4xi3>) -> !HLFHE.eint<2>
return %ret : !HLFHE.eint<2>
}
// -----
// Wrong returns type
func @dot_incompatible_return(
%arg0: tensor<4x!HLFHE.eint<2>>,
%arg1: tensor<4xi3>) -> !HLFHE.eint<3>
{
// expected-error @+1 {{'HLFHE.dot_eint_int' op should have the width of encrypted inputs and result equals}}
%ret = "HLFHE.dot_eint_int"(%arg0, %arg1) :
(tensor<4x!HLFHE.eint<2>>, tensor<4xi3>) -> !HLFHE.eint<3>
return %ret : !HLFHE.eint<3>
}
// -----
// Wrong integer size
func @dot_incompatible_int(
%arg0: tensor<4x!HLFHE.eint<2>>,
%arg1: tensor<4xi4>) -> !HLFHE.eint<2>
{
// expected-error @+1 {{'HLFHE.dot_eint_int' op should have the width of plain input equals to width of encrypted input + 1}}
%ret = "HLFHE.dot_eint_int"(%arg0, %arg1) :
(tensor<4x!HLFHE.eint<2>>, tensor<4xi4>) -> !HLFHE.eint<2>
return %ret : !HLFHE.eint<2>
}

View File

@@ -60,13 +60,13 @@ func @apply_lookup_table(%arg0: !HLFHE.eint<2>, %arg1: memref<4xi2>) -> !HLFHE.e
return %1: !HLFHE.eint<2>
}
// CHECK-LABEL: func @dot_eint_int(%arg0: tensor<2x!HLFHE.eint<2>>, %arg1: tensor<2xi32>) -> !HLFHE.eint<2>
// CHECK-LABEL: func @dot_eint_int(%arg0: tensor<2x!HLFHE.eint<2>>, %arg1: tensor<2xi3>) -> !HLFHE.eint<2>
func @dot_eint_int(%arg0: tensor<2x!HLFHE.eint<2>>,
%arg1: tensor<2xi32>) -> !HLFHE.eint<2>
%arg1: tensor<2xi3>) -> !HLFHE.eint<2>
{
// CHECK-NEXT: %[[RET:.*]] = "HLFHE.dot_eint_int"(%arg0, %arg1) : (tensor<2x!HLFHE.eint<2>>, tensor<2xi32>) -> !HLFHE.eint<2>
// CHECK-NEXT: %[[RET:.*]] = "HLFHE.dot_eint_int"(%arg0, %arg1) : (tensor<2x!HLFHE.eint<2>>, tensor<2xi3>) -> !HLFHE.eint<2>
%ret = "HLFHE.dot_eint_int"(%arg0, %arg1) :
(tensor<2x!HLFHE.eint<2>>, tensor<2xi32>) -> !HLFHE.eint<2>
(tensor<2x!HLFHE.eint<2>>, tensor<2xi3>) -> !HLFHE.eint<2>
//CHECK-NEXT: return %[[RET]] : !HLFHE.eint<2>
return %ret : !HLFHE.eint<2>

View File

@@ -1,14 +1,14 @@
// RUN: zamacompiler %s --convert-hlfhe-tensor-ops-to-linalg 2>&1 | FileCheck %s
// RUN: zamacompiler %s --passes hlfhe-tensor-ops-to-linalg 2>&1 | FileCheck %s
//CHECK: #map0 = affine_map<(d0) -> (d0)>
//CHECK-NEXT: #map1 = affine_map<(d0) -> (0)>
//CHECK-NEXT: module {
//CHECK-NEXT: func @dot_eint_int(%arg0: tensor<2x!HLFHE.eint<2>>, %arg1: tensor<2xi32>) -> !HLFHE.eint<2> {
//CHECK-NEXT: func @dot_eint_int(%arg0: tensor<2x!HLFHE.eint<2>>, %arg1: tensor<2xi3>) -> !HLFHE.eint<2> {
//CHECK-NEXT: %0 = "HLFHE.zero"() : () -> !HLFHE.eint<2>
//CHECK-NEXT: %1 = tensor.from_elements %0 : tensor<1x!HLFHE.eint<2>>
//CHECK-NEXT: %2 = linalg.generic {indexing_maps = [#map0, #map0, #map1], iterator_types = ["reduction"]} ins(%arg0, %arg1 : tensor<2x!HLFHE.eint<2>>, tensor<2xi32>) outs(%1 : tensor<1x!HLFHE.eint<2>>) {
//CHECK-NEXT: ^bb0(%arg2: !HLFHE.eint<2>, %arg3: i32, %arg4: !HLFHE.eint<2>): // no predecessors
//CHECK-NEXT: %4 = "HLFHE.mul_eint_int"(%arg2, %arg3) : (!HLFHE.eint<2>, i32) -> !HLFHE.eint<2>
//CHECK-NEXT: %2 = linalg.generic {indexing_maps = [#map0, #map0, #map1], iterator_types = ["reduction"]} ins(%arg0, %arg1 : tensor<2x!HLFHE.eint<2>>, tensor<2xi3>) outs(%1 : tensor<1x!HLFHE.eint<2>>) {
//CHECK-NEXT: ^bb0(%arg2: !HLFHE.eint<2>, %arg3: i3, %arg4: !HLFHE.eint<2>): // no predecessors
//CHECK-NEXT: %4 = "HLFHE.mul_eint_int"(%arg2, %arg3) : (!HLFHE.eint<2>, i3) -> !HLFHE.eint<2>
//CHECK-NEXT: %5 = "HLFHE.add_eint"(%4, %arg4) : (!HLFHE.eint<2>, !HLFHE.eint<2>) -> !HLFHE.eint<2>
//CHECK-NEXT: linalg.yield %5 : !HLFHE.eint<2>
//CHECK-NEXT: } -> tensor<1x!HLFHE.eint<2>>
@@ -18,9 +18,9 @@
//CHECK-NEXT: }
//CHECK-NEXT: }
func @dot_eint_int(%arg0: tensor<2x!HLFHE.eint<2>>,
%arg1: tensor<2xi32>) -> !HLFHE.eint<2>
%arg1: tensor<2xi3>) -> !HLFHE.eint<2>
{
%o = "HLFHE.dot_eint_int"(%arg0, %arg1) :
(tensor<2x!HLFHE.eint<2>>, tensor<2xi32>) -> !HLFHE.eint<2>
(tensor<2x!HLFHE.eint<2>>, tensor<2xi3>) -> !HLFHE.eint<2>
return %o : !HLFHE.eint<2>
}