diff --git a/compiler/include/zamalang/Dialect/HLFHE/IR/HLFHEOps.td b/compiler/include/zamalang/Dialect/HLFHE/IR/HLFHEOps.td index 69fa3558d..a54b82be9 100644 --- a/compiler/include/zamalang/Dialect/HLFHE/IR/HLFHEOps.td +++ b/compiler/include/zamalang/Dialect/HLFHE/IR/HLFHEOps.td @@ -97,14 +97,7 @@ def Dot : HLFHE_Op<"dot_eint_int"> { Type.predicate, HasStaticShapePred, HasAnyRankOfPred<[1]>]>>:$rhs); let results = (outs EncryptedIntegerType:$out); let verifier = [{ - if(::mlir::failed( - mlir::verifyCompatibleShape( - lhs().getType(), - rhs().getType()))) { - return this->emitOpError("arguments have incompatible shapes"); - } - - return ::mlir::success(); + return ::mlir::zamalang::HLFHE::verifyDotEintInt(*this); }]; } #endif diff --git a/compiler/lib/Dialect/HLFHE/IR/HLFHEOps.cpp b/compiler/lib/Dialect/HLFHE/IR/HLFHEOps.cpp index b2a3e0de5..122cd8004 100644 --- a/compiler/lib/Dialect/HLFHE/IR/HLFHEOps.cpp +++ b/compiler/lib/Dialect/HLFHE/IR/HLFHEOps.cpp @@ -118,6 +118,33 @@ bool verifyEncryptedIntegerInputsConsistency(::mlir::OpState &op, return mlir::success(); } +::mlir::LogicalResult verifyDotEintInt(Dot &op) { + if (::mlir::failed(mlir::verifyCompatibleShape(op.lhs().getType(), + op.rhs().getType()))) { + return op.emitOpError("arguments have incompatible shapes"); + } + auto lhsEltType = op.lhs() + .getType() + .cast() + .getElementType() + .cast(); + auto rhsEltType = op.rhs() + .getType() + .cast() + .getElementType() + .cast(); + auto resultType = op.getResult().getType().cast(); + if (!verifyEncryptedIntegerAndIntegerInputsConsistency(op, lhsEltType, + rhsEltType)) { + return ::mlir::failure(); + } + if (!verifyEncryptedIntegerInputAndResultConsistency(op, lhsEltType, + resultType)) { + return ::mlir::failure(); + } + return ::mlir::success(); +} + } // namespace HLFHE } // namespace zamalang } // namespace mlir diff --git a/compiler/tests/Dialect/HLFHE/dot.invalid.mlir b/compiler/tests/Dialect/HLFHE/dot.invalid.mlir index 9345c8b91..2e2f31a65 100644 --- a/compiler/tests/Dialect/HLFHE/dot.invalid.mlir +++ b/compiler/tests/Dialect/HLFHE/dot.invalid.mlir @@ -21,7 +21,7 @@ func @dot_incompatible_input_types( { // expected-error @+1 {{'HLFHE.dot_eint_int' op operand #1 must}} %ret = "HLFHE.dot_eint_int"(%arg0, %arg1) : - (tensor<5x!HLFHE.eint<2>>, tensor<4xf32>) -> !HLFHE.eint<0> + (tensor<5x!HLFHE.eint<2>>, tensor<4xf32>) -> !HLFHE.eint<2> return %ret : !HLFHE.eint<2> } @@ -31,11 +31,39 @@ func @dot_incompatible_input_types( // Wrong number of dimensions func @dot_num_dims( %arg0: tensor<2x4x!HLFHE.eint<2>>, - %arg1: tensor<2x4xi32>) -> !HLFHE.eint<2> + %arg1: tensor<2x4xi3>) -> !HLFHE.eint<2> { // expected-error @+1 {{'HLFHE.dot_eint_int' op operand #0}} %ret = "HLFHE.dot_eint_int"(%arg0, %arg1) : - (tensor<2x4x!HLFHE.eint<2>>, tensor<2x4xi32>) -> !HLFHE.eint<2> + (tensor<2x4x!HLFHE.eint<2>>, tensor<2x4xi3>) -> !HLFHE.eint<2> return %ret : !HLFHE.eint<2> } + +// ----- + +// Wrong returns type +func @dot_incompatible_return( + %arg0: tensor<4x!HLFHE.eint<2>>, + %arg1: tensor<4xi3>) -> !HLFHE.eint<3> +{ + // expected-error @+1 {{'HLFHE.dot_eint_int' op should have the width of encrypted inputs and result equals}} + %ret = "HLFHE.dot_eint_int"(%arg0, %arg1) : + (tensor<4x!HLFHE.eint<2>>, tensor<4xi3>) -> !HLFHE.eint<3> + + return %ret : !HLFHE.eint<3> +} + +// ----- + +// Wrong integer size +func @dot_incompatible_int( + %arg0: tensor<4x!HLFHE.eint<2>>, + %arg1: tensor<4xi4>) -> !HLFHE.eint<2> +{ + // expected-error @+1 {{'HLFHE.dot_eint_int' op should have the width of plain input equals to width of encrypted input + 1}} + %ret = "HLFHE.dot_eint_int"(%arg0, %arg1) : + (tensor<4x!HLFHE.eint<2>>, tensor<4xi4>) -> !HLFHE.eint<2> + + return %ret : !HLFHE.eint<2> +} \ No newline at end of file diff --git a/compiler/tests/Dialect/HLFHE/ops.mlir b/compiler/tests/Dialect/HLFHE/ops.mlir index 31fda9dc9..9636116d1 100644 --- a/compiler/tests/Dialect/HLFHE/ops.mlir +++ b/compiler/tests/Dialect/HLFHE/ops.mlir @@ -60,13 +60,13 @@ func @apply_lookup_table(%arg0: !HLFHE.eint<2>, %arg1: memref<4xi2>) -> !HLFHE.e return %1: !HLFHE.eint<2> } -// CHECK-LABEL: func @dot_eint_int(%arg0: tensor<2x!HLFHE.eint<2>>, %arg1: tensor<2xi32>) -> !HLFHE.eint<2> +// CHECK-LABEL: func @dot_eint_int(%arg0: tensor<2x!HLFHE.eint<2>>, %arg1: tensor<2xi3>) -> !HLFHE.eint<2> func @dot_eint_int(%arg0: tensor<2x!HLFHE.eint<2>>, - %arg1: tensor<2xi32>) -> !HLFHE.eint<2> + %arg1: tensor<2xi3>) -> !HLFHE.eint<2> { - // CHECK-NEXT: %[[RET:.*]] = "HLFHE.dot_eint_int"(%arg0, %arg1) : (tensor<2x!HLFHE.eint<2>>, tensor<2xi32>) -> !HLFHE.eint<2> + // CHECK-NEXT: %[[RET:.*]] = "HLFHE.dot_eint_int"(%arg0, %arg1) : (tensor<2x!HLFHE.eint<2>>, tensor<2xi3>) -> !HLFHE.eint<2> %ret = "HLFHE.dot_eint_int"(%arg0, %arg1) : - (tensor<2x!HLFHE.eint<2>>, tensor<2xi32>) -> !HLFHE.eint<2> + (tensor<2x!HLFHE.eint<2>>, tensor<2xi3>) -> !HLFHE.eint<2> //CHECK-NEXT: return %[[RET]] : !HLFHE.eint<2> return %ret : !HLFHE.eint<2> diff --git a/compiler/tests/Dialect/HLFHE/tensor-ops-to-linalg.mlir b/compiler/tests/Dialect/HLFHE/tensor-ops-to-linalg.mlir index 752200b71..f313d7501 100644 --- a/compiler/tests/Dialect/HLFHE/tensor-ops-to-linalg.mlir +++ b/compiler/tests/Dialect/HLFHE/tensor-ops-to-linalg.mlir @@ -1,14 +1,14 @@ -// RUN: zamacompiler %s --convert-hlfhe-tensor-ops-to-linalg 2>&1 | FileCheck %s +// RUN: zamacompiler %s --passes hlfhe-tensor-ops-to-linalg 2>&1 | FileCheck %s //CHECK: #map0 = affine_map<(d0) -> (d0)> //CHECK-NEXT: #map1 = affine_map<(d0) -> (0)> //CHECK-NEXT: module { -//CHECK-NEXT: func @dot_eint_int(%arg0: tensor<2x!HLFHE.eint<2>>, %arg1: tensor<2xi32>) -> !HLFHE.eint<2> { +//CHECK-NEXT: func @dot_eint_int(%arg0: tensor<2x!HLFHE.eint<2>>, %arg1: tensor<2xi3>) -> !HLFHE.eint<2> { //CHECK-NEXT: %0 = "HLFHE.zero"() : () -> !HLFHE.eint<2> //CHECK-NEXT: %1 = tensor.from_elements %0 : tensor<1x!HLFHE.eint<2>> -//CHECK-NEXT: %2 = linalg.generic {indexing_maps = [#map0, #map0, #map1], iterator_types = ["reduction"]} ins(%arg0, %arg1 : tensor<2x!HLFHE.eint<2>>, tensor<2xi32>) outs(%1 : tensor<1x!HLFHE.eint<2>>) { -//CHECK-NEXT: ^bb0(%arg2: !HLFHE.eint<2>, %arg3: i32, %arg4: !HLFHE.eint<2>): // no predecessors -//CHECK-NEXT: %4 = "HLFHE.mul_eint_int"(%arg2, %arg3) : (!HLFHE.eint<2>, i32) -> !HLFHE.eint<2> +//CHECK-NEXT: %2 = linalg.generic {indexing_maps = [#map0, #map0, #map1], iterator_types = ["reduction"]} ins(%arg0, %arg1 : tensor<2x!HLFHE.eint<2>>, tensor<2xi3>) outs(%1 : tensor<1x!HLFHE.eint<2>>) { +//CHECK-NEXT: ^bb0(%arg2: !HLFHE.eint<2>, %arg3: i3, %arg4: !HLFHE.eint<2>): // no predecessors +//CHECK-NEXT: %4 = "HLFHE.mul_eint_int"(%arg2, %arg3) : (!HLFHE.eint<2>, i3) -> !HLFHE.eint<2> //CHECK-NEXT: %5 = "HLFHE.add_eint"(%4, %arg4) : (!HLFHE.eint<2>, !HLFHE.eint<2>) -> !HLFHE.eint<2> //CHECK-NEXT: linalg.yield %5 : !HLFHE.eint<2> //CHECK-NEXT: } -> tensor<1x!HLFHE.eint<2>> @@ -18,9 +18,9 @@ //CHECK-NEXT: } //CHECK-NEXT: } func @dot_eint_int(%arg0: tensor<2x!HLFHE.eint<2>>, - %arg1: tensor<2xi32>) -> !HLFHE.eint<2> + %arg1: tensor<2xi3>) -> !HLFHE.eint<2> { %o = "HLFHE.dot_eint_int"(%arg0, %arg1) : - (tensor<2x!HLFHE.eint<2>>, tensor<2xi32>) -> !HLFHE.eint<2> + (tensor<2x!HLFHE.eint<2>>, tensor<2xi3>) -> !HLFHE.eint<2> return %o : !HLFHE.eint<2> }