Mirror of https://github.com/zama-ai/concrete.git
fix(compiler/hlfhe): More verification on dot_eint_int
@@ -97,14 +97,7 @@ def Dot : HLFHE_Op<"dot_eint_int"> {
     Type<And<[TensorOf<[AnyInteger]>.predicate, HasStaticShapePred, HasAnyRankOfPred<[1]>]>>:$rhs);
   let results = (outs EncryptedIntegerType:$out);
   let verifier = [{
-    if(::mlir::failed(
-          mlir::verifyCompatibleShape(
-            lhs().getType(),
-            rhs().getType()))) {
-      return this->emitOpError("arguments have incompatible shapes");
-    }
-
-    return ::mlir::success();
+    return ::mlir::zamalang::HLFHE::verifyDotEintInt(*this);
   }];
 }

 #endif

@@ -118,6 +118,33 @@ bool verifyEncryptedIntegerInputsConsistency(::mlir::OpState &op,
   return mlir::success();
 }

+::mlir::LogicalResult verifyDotEintInt(Dot &op) {
+  if (::mlir::failed(mlir::verifyCompatibleShape(op.lhs().getType(),
+                                                 op.rhs().getType()))) {
+    return op.emitOpError("arguments have incompatible shapes");
+  }
+  auto lhsEltType = op.lhs()
+                        .getType()
+                        .cast<mlir::TensorType>()
+                        .getElementType()
+                        .cast<EncryptedIntegerType>();
+  auto rhsEltType = op.rhs()
+                        .getType()
+                        .cast<mlir::TensorType>()
+                        .getElementType()
+                        .cast<mlir::IntegerType>();
+  auto resultType = op.getResult().getType().cast<EncryptedIntegerType>();
+  if (!verifyEncryptedIntegerAndIntegerInputsConsistency(op, lhsEltType,
+                                                         rhsEltType)) {
+    return ::mlir::failure();
+  }
+  if (!verifyEncryptedIntegerInputAndResultConsistency(op, lhsEltType,
+                                                       resultType)) {
+    return ::mlir::failure();
+  }
+  return ::mlir::success();
+}
+
 } // namespace HLFHE
 } // namespace zamalang
 } // namespace mlir

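A note on the new verifier, since its helpers are not part of this hunk: verifyDotEintInt first checks that the two tensor operands have compatible shapes, then extracts their element types and delegates to verifyEncryptedIntegerAndIntegerInputsConsistency and verifyEncryptedIntegerInputAndResultConsistency. Judging from the expected-error messages in the tests further below, those helpers require the clear operand to be one bit wider than the encrypted element type and the result to have the same width as the encrypted inputs. The standalone C++ sketch below mirrors that control flow over plain structs, purely as an illustration; none of its names belong to the compiler.

#include <cstdio>

// Illustrative stand-ins for the MLIR types involved; only the bit widths and
// the 1-D tensor lengths matter for the checks sketched here.
struct EintType { unsigned width; };  // plays the role of !HLFHE.eint<p>
struct IntType { unsigned width; };   // plays the role of iN
struct DotLikeOp {
  unsigned lhsLen, rhsLen;  // dot_eint_int only takes rank-1 static tensors
  EintType lhsElt;          // encrypted element type
  IntType rhsElt;           // clear element type
  EintType result;          // result type
};

// Simplified analogue of verifyDotEintInt: shape compatibility first, then the
// two width rules implied by the expected-error messages in the tests below.
bool verifyDotSketch(const DotLikeOp &op) {
  if (op.lhsLen != op.rhsLen) {
    std::fprintf(stderr, "arguments have incompatible shapes\n");
    return false;
  }
  if (op.rhsElt.width != op.lhsElt.width + 1) {
    std::fprintf(stderr, "plain input must be one bit wider than encrypted input\n");
    return false;
  }
  if (op.result.width != op.lhsElt.width) {
    std::fprintf(stderr, "encrypted inputs and result must have the same width\n");
    return false;
  }
  return true;
}

int main() {
  DotLikeOp ok{4, 4, {2}, {3}, {2}};      // eint<2> x i3 -> eint<2>: accepted
  DotLikeOp badInt{4, 4, {2}, {4}, {2}};  // eint<2> x i4: clear operand too wide
  std::printf("ok=%d badInt=%d\n", verifyDotSketch(ok), verifyDotSketch(badInt));
  return 0;
}
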
@@ -21,7 +21,7 @@ func @dot_incompatible_input_types(
 {
   // expected-error @+1 {{'HLFHE.dot_eint_int' op operand #1 must}}
   %ret = "HLFHE.dot_eint_int"(%arg0, %arg1) :
-    (tensor<5x!HLFHE.eint<2>>, tensor<4xf32>) -> !HLFHE.eint<0>
+    (tensor<5x!HLFHE.eint<2>>, tensor<4xf32>) -> !HLFHE.eint<2>

   return %ret : !HLFHE.eint<2>
 }

@@ -31,11 +31,39 @@ func @dot_incompatible_input_types(
 // Wrong number of dimensions
 func @dot_num_dims(
   %arg0: tensor<2x4x!HLFHE.eint<2>>,
-  %arg1: tensor<2x4xi32>) -> !HLFHE.eint<2>
+  %arg1: tensor<2x4xi3>) -> !HLFHE.eint<2>
 {
   // expected-error @+1 {{'HLFHE.dot_eint_int' op operand #0}}
   %ret = "HLFHE.dot_eint_int"(%arg0, %arg1) :
-    (tensor<2x4x!HLFHE.eint<2>>, tensor<2x4xi32>) -> !HLFHE.eint<2>
+    (tensor<2x4x!HLFHE.eint<2>>, tensor<2x4xi3>) -> !HLFHE.eint<2>

   return %ret : !HLFHE.eint<2>
 }
+
+// -----
+
+// Wrong returns type
+func @dot_incompatible_return(
+  %arg0: tensor<4x!HLFHE.eint<2>>,
+  %arg1: tensor<4xi3>) -> !HLFHE.eint<3>
+{
+  // expected-error @+1 {{'HLFHE.dot_eint_int' op should have the width of encrypted inputs and result equals}}
+  %ret = "HLFHE.dot_eint_int"(%arg0, %arg1) :
+    (tensor<4x!HLFHE.eint<2>>, tensor<4xi3>) -> !HLFHE.eint<3>
+
+  return %ret : !HLFHE.eint<3>
+}
+
+// -----
+
+// Wrong integer size
+func @dot_incompatible_int(
+  %arg0: tensor<4x!HLFHE.eint<2>>,
+  %arg1: tensor<4xi4>) -> !HLFHE.eint<2>
+{
+  // expected-error @+1 {{'HLFHE.dot_eint_int' op should have the width of plain input equals to width of encrypted input + 1}}
+  %ret = "HLFHE.dot_eint_int"(%arg0, %arg1) :
+    (tensor<4x!HLFHE.eint<2>>, tensor<4xi4>) -> !HLFHE.eint<2>
+
+  return %ret : !HLFHE.eint<2>
+}

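Spelling out the arithmetic these negative tests exercise: under the rules quoted in the error messages, a dot over tensor<Nx!HLFHE.eint<p>> only type-checks with a clear operand of element type i(p+1) and a result of !HLFHE.eint<p>. For the eint<2> operands used here that means i3 and eint<2>; the i4 operand and the eint<3> result are the two violations the tests expect to be rejected. A tiny hypothetical helper (not part of the compiler) makes the rule explicit:

// Hypothetical helpers encoding the two width rules quoted in the
// expected-error messages above; purely illustrative.
constexpr unsigned requiredClearWidth(unsigned eintWidth) { return eintWidth + 1; }
constexpr unsigned requiredResultWidth(unsigned eintWidth) { return eintWidth; }

int main() {
  static_assert(requiredClearWidth(2) == 3, "eint<2> pairs with i3");
  static_assert(requiredClearWidth(2) != 4, "the i4 operand above is rejected");
  static_assert(requiredResultWidth(2) != 3, "the eint<3> result above is rejected");
  return 0;
}
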
@@ -60,13 +60,13 @@ func @apply_lookup_table(%arg0: !HLFHE.eint<2>, %arg1: memref<4xi2>) -> !HLFHE.e
   return %1: !HLFHE.eint<2>
 }

-// CHECK-LABEL: func @dot_eint_int(%arg0: tensor<2x!HLFHE.eint<2>>, %arg1: tensor<2xi32>) -> !HLFHE.eint<2>
+// CHECK-LABEL: func @dot_eint_int(%arg0: tensor<2x!HLFHE.eint<2>>, %arg1: tensor<2xi3>) -> !HLFHE.eint<2>
 func @dot_eint_int(%arg0: tensor<2x!HLFHE.eint<2>>,
-           %arg1: tensor<2xi32>) -> !HLFHE.eint<2>
+           %arg1: tensor<2xi3>) -> !HLFHE.eint<2>
 {
-  // CHECK-NEXT: %[[RET:.*]] = "HLFHE.dot_eint_int"(%arg0, %arg1) : (tensor<2x!HLFHE.eint<2>>, tensor<2xi32>) -> !HLFHE.eint<2>
+  // CHECK-NEXT: %[[RET:.*]] = "HLFHE.dot_eint_int"(%arg0, %arg1) : (tensor<2x!HLFHE.eint<2>>, tensor<2xi3>) -> !HLFHE.eint<2>
   %ret = "HLFHE.dot_eint_int"(%arg0, %arg1) :
-    (tensor<2x!HLFHE.eint<2>>, tensor<2xi32>) -> !HLFHE.eint<2>
+    (tensor<2x!HLFHE.eint<2>>, tensor<2xi3>) -> !HLFHE.eint<2>

   //CHECK-NEXT: return %[[RET]] : !HLFHE.eint<2>
   return %ret : !HLFHE.eint<2>

@@ -1,14 +1,14 @@
-// RUN: zamacompiler %s --convert-hlfhe-tensor-ops-to-linalg 2>&1 | FileCheck %s
+// RUN: zamacompiler %s --passes hlfhe-tensor-ops-to-linalg 2>&1 | FileCheck %s

 //CHECK: #map0 = affine_map<(d0) -> (d0)>
 //CHECK-NEXT: #map1 = affine_map<(d0) -> (0)>
 //CHECK-NEXT: module {
-//CHECK-NEXT: func @dot_eint_int(%arg0: tensor<2x!HLFHE.eint<2>>, %arg1: tensor<2xi32>) -> !HLFHE.eint<2> {
+//CHECK-NEXT: func @dot_eint_int(%arg0: tensor<2x!HLFHE.eint<2>>, %arg1: tensor<2xi3>) -> !HLFHE.eint<2> {
 //CHECK-NEXT: %0 = "HLFHE.zero"() : () -> !HLFHE.eint<2>
 //CHECK-NEXT: %1 = tensor.from_elements %0 : tensor<1x!HLFHE.eint<2>>
-//CHECK-NEXT: %2 = linalg.generic {indexing_maps = [#map0, #map0, #map1], iterator_types = ["reduction"]} ins(%arg0, %arg1 : tensor<2x!HLFHE.eint<2>>, tensor<2xi32>) outs(%1 : tensor<1x!HLFHE.eint<2>>) {
-//CHECK-NEXT: ^bb0(%arg2: !HLFHE.eint<2>, %arg3: i32, %arg4: !HLFHE.eint<2>): // no predecessors
-//CHECK-NEXT: %4 = "HLFHE.mul_eint_int"(%arg2, %arg3) : (!HLFHE.eint<2>, i32) -> !HLFHE.eint<2>
+//CHECK-NEXT: %2 = linalg.generic {indexing_maps = [#map0, #map0, #map1], iterator_types = ["reduction"]} ins(%arg0, %arg1 : tensor<2x!HLFHE.eint<2>>, tensor<2xi3>) outs(%1 : tensor<1x!HLFHE.eint<2>>) {
+//CHECK-NEXT: ^bb0(%arg2: !HLFHE.eint<2>, %arg3: i3, %arg4: !HLFHE.eint<2>): // no predecessors
+//CHECK-NEXT: %4 = "HLFHE.mul_eint_int"(%arg2, %arg3) : (!HLFHE.eint<2>, i3) -> !HLFHE.eint<2>
 //CHECK-NEXT: %5 = "HLFHE.add_eint"(%4, %arg4) : (!HLFHE.eint<2>, !HLFHE.eint<2>) -> !HLFHE.eint<2>
 //CHECK-NEXT: linalg.yield %5 : !HLFHE.eint<2>
 //CHECK-NEXT: } -> tensor<1x!HLFHE.eint<2>>

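For readers tracing the CHECK lines above: the lowering turns dot_eint_int into a linalg.generic reduction that seeds an accumulator with HLFHE.zero, multiplies each encrypted element by its clear counterpart with HLFHE.mul_eint_int, and folds the product into the accumulator with HLFHE.add_eint. The loop below is a plain-integer analogue of that dataflow, assuming nothing about the real ciphertext representation; uint64_t and int8_t merely stand in for !HLFHE.eint<2> and i3.

#include <array>
#include <cstddef>
#include <cstdint>
#include <cstdio>

int main() {
  // Stand-ins for tensor<2x!HLFHE.eint<2>> and tensor<2xi3>.
  std::array<std::uint64_t, 2> encrypted = {1, 2};
  std::array<std::int8_t, 2> clear = {3, 1};

  std::uint64_t acc = 0;  // corresponds to the "HLFHE.zero" seed
  for (std::size_t i = 0; i < encrypted.size(); ++i) {
    // HLFHE.mul_eint_int: encrypted element times clear element.
    std::uint64_t prod = encrypted[i] * static_cast<std::uint64_t>(clear[i]);
    // HLFHE.add_eint: fold the product into the accumulator.
    acc += prod;
  }
  std::printf("dot = %llu\n", static_cast<unsigned long long>(acc));
  return 0;
}
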
@@ -18,9 +18,9 @@
 //CHECK-NEXT: }
 //CHECK-NEXT: }
 func @dot_eint_int(%arg0: tensor<2x!HLFHE.eint<2>>,
-           %arg1: tensor<2xi32>) -> !HLFHE.eint<2>
+           %arg1: tensor<2xi3>) -> !HLFHE.eint<2>
 {
   %o = "HLFHE.dot_eint_int"(%arg0, %arg1) :
-    (tensor<2x!HLFHE.eint<2>>, tensor<2xi32>) -> !HLFHE.eint<2>
+    (tensor<2x!HLFHE.eint<2>>, tensor<2xi3>) -> !HLFHE.eint<2>
   return %o : !HLFHE.eint<2>
 }