enhance(compiler): Add verify for HLFHE.eint, precision ]0;7] (#54)

This commit is contained in:
Quentin Bourgerie
2021-07-15 10:25:19 +02:00
parent 3c326c09d6
commit ea77ee696a
8 changed files with 81 additions and 48 deletions

View File

@@ -33,8 +33,11 @@ def EncryptedIntegerType : HLFHE_Type<"EncryptedInteger",
return Type();
if ($_parser.parseGreater())
return Type();
return get($_ctxt, width);
Location loc = $_parser.getEncodedSourceLoc($_parser.getNameLoc());
return getChecked(loc, loc.getContext(), width);
}];
let genVerifyDecl = true;
}
#endif

View File

@@ -43,3 +43,12 @@ void HLFHEDialect::printType(::mlir::Type type,
// Calling default printer if failed to print HLFHE type
printer.printType(type);
}
mlir::LogicalResult EncryptedIntegerType::verify(
llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, unsigned p) {
if (p == 0 || p > 7) {
emitError() << "HLFHE.eint support only precision in ]0;7]";
return mlir::failure();
}
return mlir::success();
}

View File

@@ -2,13 +2,13 @@
// Unranked types
func @dot_unranked(
%arg0: memref<?x!HLFHE.eint<0>>,
%arg0: memref<?x!HLFHE.eint<2>>,
%arg1: memref<?xi32>,
%arg2: memref<!HLFHE.eint<0>>)
%arg2: memref<!HLFHE.eint<2>>)
{
// expected-error @+1 {{'HLFHE.dot_eint_int' op operand #0}}
"HLFHE.dot_eint_int"(%arg0, %arg1, %arg2) :
(memref<?x!HLFHE.eint<0>>, memref<?xi32>, memref<!HLFHE.eint<0>>) -> ()
(memref<?x!HLFHE.eint<2>>, memref<?xi32>, memref<!HLFHE.eint<2>>) -> ()
return
}
@@ -17,13 +17,13 @@ func @dot_unranked(
// Incompatible shapes
func @dot_incompatible_shapes(
%arg0: memref<5x!HLFHE.eint<0>>,
%arg0: memref<5x!HLFHE.eint<2>>,
%arg1: memref<4xi32>,
%arg2: memref<!HLFHE.eint<0>>)
%arg2: memref<!HLFHE.eint<2>>)
{
// expected-error @+1 {{'HLFHE.dot_eint_int' op arguments have incompatible shapes}}
"HLFHE.dot_eint_int"(%arg0, %arg1, %arg2) :
(memref<5x!HLFHE.eint<0>>, memref<4xi32>, memref<!HLFHE.eint<0>>) -> ()
(memref<5x!HLFHE.eint<2>>, memref<4xi32>, memref<!HLFHE.eint<2>>) -> ()
return
}
@@ -32,13 +32,13 @@ func @dot_incompatible_shapes(
// Incompatible input types
func @dot_incompatible_input_types(
%arg0: memref<4x!HLFHE.eint<0>>,
%arg0: memref<4x!HLFHE.eint<2>>,
%arg1: memref<4xf32>,
%arg2: memref<!HLFHE.eint<0>>)
%arg2: memref<!HLFHE.eint<2>>)
{
// expected-error @+1 {{'HLFHE.dot_eint_int' op operand #1 must}}
"HLFHE.dot_eint_int"(%arg0, %arg1, %arg2) :
(memref<4x!HLFHE.eint<0>>, memref<4xf32>, memref<!HLFHE.eint<0>>) -> ()
(memref<4x!HLFHE.eint<2>>, memref<4xf32>, memref<!HLFHE.eint<2>>) -> ()
return
}
@@ -47,13 +47,13 @@ func @dot_incompatible_input_types(
// Wrong number of dimensions
func @dot_num_dims(
%arg0: memref<2x4x!HLFHE.eint<0>>,
%arg0: memref<2x4x!HLFHE.eint<2>>,
%arg1: memref<2x4xi32>,
%arg2: memref<!HLFHE.eint<0>>)
%arg2: memref<!HLFHE.eint<2>>)
{
// expected-error @+1 {{'HLFHE.dot_eint_int' op operand #0}}
"HLFHE.dot_eint_int"(%arg0, %arg1, %arg2) :
(memref<2x4x!HLFHE.eint<0>>, memref<2x4xi32>, memref<!HLFHE.eint<0>>) -> ()
(memref<2x4x!HLFHE.eint<2>>, memref<2x4xi32>, memref<!HLFHE.eint<2>>) -> ()
return
}

View File

@@ -0,0 +1,6 @@
// RUN: not zamacompiler %s 2>&1| FileCheck %s
// CHECK-LABEL: eint support only precision in ]0;7]
func @test(%arg0: !HLFHE.eint<8>) {
return
}

View File

@@ -0,0 +1,6 @@
// RUN: not zamacompiler %s 2>&1| FileCheck %s
// CHECK-LABEL: eint support only precision in ]0;7]
func @test(%arg0: !HLFHE.eint<0>) {
return
}

View File

@@ -1,44 +1,53 @@
// RUN: zamacompiler %s 2>&1| FileCheck %s
// CHECK-LABEL: func @add_eint_int(%arg0: !HLFHE.eint<0>) -> !HLFHE.eint<0>
func @add_eint_int(%arg0: !HLFHE.eint<0>) -> !HLFHE.eint<0> {
// CHECK-LABEL: func @add_eint_int(%arg0: !HLFHE.eint<2>) -> !HLFHE.eint<2>
func @add_eint_int(%arg0: !HLFHE.eint<2>) -> !HLFHE.eint<2> {
// CHECK-NEXT: %[[V1:.*]] = constant 1 : i32
// CHECK-NEXT: %[[V2:.*]] = "HLFHE.add_eint_int"(%arg0, %[[V1]]) : (!HLFHE.eint<0>, i32) -> !HLFHE.eint<0>
// CHECK-NEXT: return %[[V2]] : !HLFHE.eint<0>
// CHECK-NEXT: %[[V2:.*]] = "HLFHE.add_eint_int"(%arg0, %[[V1]]) : (!HLFHE.eint<2>, i32) -> !HLFHE.eint<2>
// CHECK-NEXT: return %[[V2]] : !HLFHE.eint<2>
%0 = constant 1 : i32
%1 = "HLFHE.add_eint_int"(%arg0, %0): (!HLFHE.eint<0>, i32) -> (!HLFHE.eint<0>)
return %1: !HLFHE.eint<0>
%1 = "HLFHE.add_eint_int"(%arg0, %0): (!HLFHE.eint<2>, i32) -> (!HLFHE.eint<2>)
return %1: !HLFHE.eint<2>
}
// CHECK-LABEL: func @mul_eint_int(%arg0: !HLFHE.eint<0>) -> !HLFHE.eint<0>
func @mul_eint_int(%arg0: !HLFHE.eint<0>) -> !HLFHE.eint<0> {
// CHECK-LABEL: func @mul_eint_int(%arg0: !HLFHE.eint<2>) -> !HLFHE.eint<2>
func @mul_eint_int(%arg0: !HLFHE.eint<2>) -> !HLFHE.eint<2> {
// CHECK-NEXT: %[[V1:.*]] = constant 1 : i32
// CHECK-NEXT: %[[V2:.*]] = "HLFHE.mul_eint_int"(%arg0, %[[V1]]) : (!HLFHE.eint<0>, i32) -> !HLFHE.eint<0>
// CHECK-NEXT: return %[[V2]] : !HLFHE.eint<0>
// CHECK-NEXT: %[[V2:.*]] = "HLFHE.mul_eint_int"(%arg0, %[[V1]]) : (!HLFHE.eint<2>, i32) -> !HLFHE.eint<2>
// CHECK-NEXT: return %[[V2]] : !HLFHE.eint<2>
%0 = constant 1 : i32
%1 = "HLFHE.mul_eint_int"(%arg0, %0): (!HLFHE.eint<0>, i32) -> (!HLFHE.eint<0>)
return %1: !HLFHE.eint<0>
%1 = "HLFHE.mul_eint_int"(%arg0, %0): (!HLFHE.eint<2>, i32) -> (!HLFHE.eint<2>)
return %1: !HLFHE.eint<2>
}
// CHECK-LABEL: func @add_eint(%arg0: !HLFHE.eint<0>, %arg1: !HLFHE.eint<0>) -> !HLFHE.eint<0>
func @add_eint(%arg0: !HLFHE.eint<0>, %arg1: !HLFHE.eint<0>) -> !HLFHE.eint<0> {
// CHECK-NEXT: %[[V1:.*]] = "HLFHE.add_eint"(%arg0, %arg1) : (!HLFHE.eint<0>, !HLFHE.eint<0>) -> !HLFHE.eint<0>
// CHECK-NEXT: return %[[V1]] : !HLFHE.eint<0>
// CHECK-LABEL: func @add_eint(%arg0: !HLFHE.eint<2>, %arg1: !HLFHE.eint<2>) -> !HLFHE.eint<2>
func @add_eint(%arg0: !HLFHE.eint<2>, %arg1: !HLFHE.eint<2>) -> !HLFHE.eint<2> {
// CHECK-NEXT: %[[V1:.*]] = "HLFHE.add_eint"(%arg0, %arg1) : (!HLFHE.eint<2>, !HLFHE.eint<2>) -> !HLFHE.eint<2>
// CHECK-NEXT: return %[[V1]] : !HLFHE.eint<2>
%1 = "HLFHE.add_eint"(%arg0, %arg1): (!HLFHE.eint<0>, !HLFHE.eint<0>) -> (!HLFHE.eint<0>)
return %1: !HLFHE.eint<0>
%1 = "HLFHE.add_eint"(%arg0, %arg1): (!HLFHE.eint<2>, !HLFHE.eint<2>) -> (!HLFHE.eint<2>)
return %1: !HLFHE.eint<2>
}
// CHECK-LABEL: func @dot_eint_int(%arg0: memref<2x!HLFHE.eint<0>>, %arg1: memref<2xi32>, %arg2: memref<!HLFHE.eint<0>>)
func @dot_eint_int(%arg0: memref<2x!HLFHE.eint<0>>,
// CHECK-LABEL: func @apply_lookup_table(%arg0: !HLFHE.eint<2>, %arg1: memref<4xi2>) -> !HLFHE.eint<2>
func @apply_lookup_table(%arg0: !HLFHE.eint<2>, %arg1: memref<4xi2>) -> !HLFHE.eint<2> {
// CHECK-NEXT: %[[V1:.*]] = "HLFHE.apply_lookup_table"(%arg0, %arg1) : (!HLFHE.eint<2>, memref<4xi2>) -> !HLFHE.eint<2>
// CHECK-NEXT: return %[[V1]] : !HLFHE.eint<2>
%1 = "HLFHE.apply_lookup_table"(%arg0, %arg1): (!HLFHE.eint<2>, memref<4xi2>) -> (!HLFHE.eint<2>)
return %1: !HLFHE.eint<2>
}
// CHECK-LABEL: func @dot_eint_int(%arg0: memref<2x!HLFHE.eint<2>>, %arg1: memref<2xi32>, %arg2: memref<!HLFHE.eint<2>>)
func @dot_eint_int(%arg0: memref<2x!HLFHE.eint<2>>,
%arg1: memref<2xi32>,
%arg2: memref<!HLFHE.eint<0>>)
%arg2: memref<!HLFHE.eint<2>>)
{
// CHECK-NEXT: "HLFHE.dot_eint_int"(%arg0, %arg1, %arg2) : (memref<2x!HLFHE.eint<0>>, memref<2xi32>, memref<!HLFHE.eint<0>>) -> ()
// CHECK-NEXT: "HLFHE.dot_eint_int"(%arg0, %arg1, %arg2) : (memref<2x!HLFHE.eint<2>>, memref<2xi32>, memref<!HLFHE.eint<2>>) -> ()
"HLFHE.dot_eint_int"(%arg0, %arg1, %arg2) :
(memref<2x!HLFHE.eint<0>>, memref<2xi32>, memref<!HLFHE.eint<0>>) -> ()
(memref<2x!HLFHE.eint<2>>, memref<2xi32>, memref<!HLFHE.eint<2>>) -> ()
//CHECK-NEXT: return
return

View File

@@ -3,20 +3,20 @@
// CHECK: #map0 = affine_map<(d0) -> (d0)>
// CHECK-NEXT: #map1 = affine_map<(d0) -> ()>
// CHECK-NEXT: module {
// CHECK-NEXT: func @dot_eint_int(%[[A0:.*]]: memref<2x!HLFHE.eint<0>>, %[[A1:.*]]: memref<2xi32>, %[[A2:.*]]: memref<!HLFHE.eint<0>>)
// CHECK-NEXT: linalg.generic {indexing_maps = [#map0, #map0, #map1], iterator_types = ["reduction"]} ins(%[[A0]], %[[A1]] : memref<2x!HLFHE.eint<0>>, memref<2xi32>) outs(%arg2 : memref<!HLFHE.eint<0>>) {
// CHECK-NEXT: ^bb0(%[[A3:.*]]: !HLFHE.eint<0>, %[[A4:.*]]: i32, %[[A5:.*]]: !HLFHE.eint<0>): // no predecessors
// CHECK-NEXT: %[[T0:.*]] = "HLFHE.mul_eint_int"(%[[A3]], %[[A4]]) : (!HLFHE.eint<0>, i32) -> !HLFHE.eint<0>
// CHECK-NEXT: %[[T1:.*]] = "HLFHE.add_eint"(%[[T0]], %[[A5]]) : (!HLFHE.eint<0>, !HLFHE.eint<0>) -> !HLFHE.eint<0>
// CHECK-NEXT: linalg.yield %[[T1]] : !HLFHE.eint<0>
// CHECK-NEXT: func @dot_eint_int(%[[A0:.*]]: memref<2x!HLFHE.eint<2>>, %[[A1:.*]]: memref<2xi3>, %[[A2:.*]]: memref<!HLFHE.eint<2>>)
// CHECK-NEXT: linalg.generic {indexing_maps = [#map0, #map0, #map1], iterator_types = ["reduction"]} ins(%[[A0]], %[[A1]] : memref<2x!HLFHE.eint<2>>, memref<2xi3>) outs(%arg2 : memref<!HLFHE.eint<2>>) {
// CHECK-NEXT: ^bb0(%[[A3:.*]]: !HLFHE.eint<2>, %[[A4:.*]]: i3, %[[A5:.*]]: !HLFHE.eint<2>): // no predecessors
// CHECK-NEXT: %[[T0:.*]] = "HLFHE.mul_eint_int"(%[[A3]], %[[A4]]) : (!HLFHE.eint<2>, i3) -> !HLFHE.eint<2>
// CHECK-NEXT: %[[T1:.*]] = "HLFHE.add_eint"(%[[T0]], %[[A5]]) : (!HLFHE.eint<2>, !HLFHE.eint<2>) -> !HLFHE.eint<2>
// CHECK-NEXT: linalg.yield %[[T1]] : !HLFHE.eint<2>
// CHECK-NEXT: }
// CHECK-NEXT: return
// CHECK-NEXT: }
func @dot_eint_int(%arg0: memref<2x!HLFHE.eint<0>>,
%arg1: memref<2xi32>,
%arg2: memref<!HLFHE.eint<0>>)
func @dot_eint_int(%arg0: memref<2x!HLFHE.eint<2>>,
%arg1: memref<2xi3>,
%arg2: memref<!HLFHE.eint<2>>)
{
"HLFHE.dot_eint_int"(%arg0, %arg1, %arg2) :
(memref<2x!HLFHE.eint<0>>, memref<2xi32>, memref<!HLFHE.eint<0>>) -> ()
(memref<2x!HLFHE.eint<2>>, memref<2xi3>, memref<!HLFHE.eint<2>>) -> ()
return
}

View File

@@ -1,6 +1,6 @@
// RUN: zamacompiler %s 2>&1| FileCheck %s
// CHECK-LABEL: func @memref_arg(%arg0: memref<2x!HLFHE.eint<0>>
func @memref_arg(%arg0: memref<2x!HLFHE.eint<0>>) {
// CHECK-LABEL: func @memref_arg(%arg0: memref<2x!HLFHE.eint<7>>
func @memref_arg(%arg0: memref<2x!HLFHE.eint<7>>) {
return
}