mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
enhance(compiler): Add verify for HLFHE.eint, precision ]0;7] (#54)
This commit is contained in:
@@ -33,8 +33,11 @@ def EncryptedIntegerType : HLFHE_Type<"EncryptedInteger",
|
||||
return Type();
|
||||
if ($_parser.parseGreater())
|
||||
return Type();
|
||||
return get($_ctxt, width);
|
||||
Location loc = $_parser.getEncodedSourceLoc($_parser.getNameLoc());
|
||||
return getChecked(loc, loc.getContext(), width);
|
||||
}];
|
||||
|
||||
let genVerifyDecl = true;
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
@@ -43,3 +43,12 @@ void HLFHEDialect::printType(::mlir::Type type,
|
||||
// Calling default printer if failed to print HLFHE type
|
||||
printer.printType(type);
|
||||
}
|
||||
|
||||
mlir::LogicalResult EncryptedIntegerType::verify(
|
||||
llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, unsigned p) {
|
||||
if (p == 0 || p > 7) {
|
||||
emitError() << "HLFHE.eint support only precision in ]0;7]";
|
||||
return mlir::failure();
|
||||
}
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
@@ -2,13 +2,13 @@
|
||||
|
||||
// Unranked types
|
||||
func @dot_unranked(
|
||||
%arg0: memref<?x!HLFHE.eint<0>>,
|
||||
%arg0: memref<?x!HLFHE.eint<2>>,
|
||||
%arg1: memref<?xi32>,
|
||||
%arg2: memref<!HLFHE.eint<0>>)
|
||||
%arg2: memref<!HLFHE.eint<2>>)
|
||||
{
|
||||
// expected-error @+1 {{'HLFHE.dot_eint_int' op operand #0}}
|
||||
"HLFHE.dot_eint_int"(%arg0, %arg1, %arg2) :
|
||||
(memref<?x!HLFHE.eint<0>>, memref<?xi32>, memref<!HLFHE.eint<0>>) -> ()
|
||||
(memref<?x!HLFHE.eint<2>>, memref<?xi32>, memref<!HLFHE.eint<2>>) -> ()
|
||||
|
||||
return
|
||||
}
|
||||
@@ -17,13 +17,13 @@ func @dot_unranked(
|
||||
|
||||
// Incompatible shapes
|
||||
func @dot_incompatible_shapes(
|
||||
%arg0: memref<5x!HLFHE.eint<0>>,
|
||||
%arg0: memref<5x!HLFHE.eint<2>>,
|
||||
%arg1: memref<4xi32>,
|
||||
%arg2: memref<!HLFHE.eint<0>>)
|
||||
%arg2: memref<!HLFHE.eint<2>>)
|
||||
{
|
||||
// expected-error @+1 {{'HLFHE.dot_eint_int' op arguments have incompatible shapes}}
|
||||
"HLFHE.dot_eint_int"(%arg0, %arg1, %arg2) :
|
||||
(memref<5x!HLFHE.eint<0>>, memref<4xi32>, memref<!HLFHE.eint<0>>) -> ()
|
||||
(memref<5x!HLFHE.eint<2>>, memref<4xi32>, memref<!HLFHE.eint<2>>) -> ()
|
||||
|
||||
return
|
||||
}
|
||||
@@ -32,13 +32,13 @@ func @dot_incompatible_shapes(
|
||||
|
||||
// Incompatible input types
|
||||
func @dot_incompatible_input_types(
|
||||
%arg0: memref<4x!HLFHE.eint<0>>,
|
||||
%arg0: memref<4x!HLFHE.eint<2>>,
|
||||
%arg1: memref<4xf32>,
|
||||
%arg2: memref<!HLFHE.eint<0>>)
|
||||
%arg2: memref<!HLFHE.eint<2>>)
|
||||
{
|
||||
// expected-error @+1 {{'HLFHE.dot_eint_int' op operand #1 must}}
|
||||
"HLFHE.dot_eint_int"(%arg0, %arg1, %arg2) :
|
||||
(memref<4x!HLFHE.eint<0>>, memref<4xf32>, memref<!HLFHE.eint<0>>) -> ()
|
||||
(memref<4x!HLFHE.eint<2>>, memref<4xf32>, memref<!HLFHE.eint<2>>) -> ()
|
||||
|
||||
return
|
||||
}
|
||||
@@ -47,13 +47,13 @@ func @dot_incompatible_input_types(
|
||||
|
||||
// Wrong number of dimensions
|
||||
func @dot_num_dims(
|
||||
%arg0: memref<2x4x!HLFHE.eint<0>>,
|
||||
%arg0: memref<2x4x!HLFHE.eint<2>>,
|
||||
%arg1: memref<2x4xi32>,
|
||||
%arg2: memref<!HLFHE.eint<0>>)
|
||||
%arg2: memref<!HLFHE.eint<2>>)
|
||||
{
|
||||
// expected-error @+1 {{'HLFHE.dot_eint_int' op operand #0}}
|
||||
"HLFHE.dot_eint_int"(%arg0, %arg1, %arg2) :
|
||||
(memref<2x4x!HLFHE.eint<0>>, memref<2x4xi32>, memref<!HLFHE.eint<0>>) -> ()
|
||||
(memref<2x4x!HLFHE.eint<2>>, memref<2x4xi32>, memref<!HLFHE.eint<2>>) -> ()
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
6
compiler/tests/Dialect/HLFHE/eint_error_p_too_big.mlir
Normal file
6
compiler/tests/Dialect/HLFHE/eint_error_p_too_big.mlir
Normal file
@@ -0,0 +1,6 @@
|
||||
// RUN: not zamacompiler %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: eint support only precision in ]0;7]
|
||||
func @test(%arg0: !HLFHE.eint<8>) {
|
||||
return
|
||||
}
|
||||
6
compiler/tests/Dialect/HLFHE/eint_error_p_too_small.mlir
Normal file
6
compiler/tests/Dialect/HLFHE/eint_error_p_too_small.mlir
Normal file
@@ -0,0 +1,6 @@
|
||||
// RUN: not zamacompiler %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: eint support only precision in ]0;7]
|
||||
func @test(%arg0: !HLFHE.eint<0>) {
|
||||
return
|
||||
}
|
||||
@@ -1,44 +1,53 @@
|
||||
// RUN: zamacompiler %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @add_eint_int(%arg0: !HLFHE.eint<0>) -> !HLFHE.eint<0>
|
||||
func @add_eint_int(%arg0: !HLFHE.eint<0>) -> !HLFHE.eint<0> {
|
||||
// CHECK-LABEL: func @add_eint_int(%arg0: !HLFHE.eint<2>) -> !HLFHE.eint<2>
|
||||
func @add_eint_int(%arg0: !HLFHE.eint<2>) -> !HLFHE.eint<2> {
|
||||
// CHECK-NEXT: %[[V1:.*]] = constant 1 : i32
|
||||
// CHECK-NEXT: %[[V2:.*]] = "HLFHE.add_eint_int"(%arg0, %[[V1]]) : (!HLFHE.eint<0>, i32) -> !HLFHE.eint<0>
|
||||
// CHECK-NEXT: return %[[V2]] : !HLFHE.eint<0>
|
||||
// CHECK-NEXT: %[[V2:.*]] = "HLFHE.add_eint_int"(%arg0, %[[V1]]) : (!HLFHE.eint<2>, i32) -> !HLFHE.eint<2>
|
||||
// CHECK-NEXT: return %[[V2]] : !HLFHE.eint<2>
|
||||
|
||||
%0 = constant 1 : i32
|
||||
%1 = "HLFHE.add_eint_int"(%arg0, %0): (!HLFHE.eint<0>, i32) -> (!HLFHE.eint<0>)
|
||||
return %1: !HLFHE.eint<0>
|
||||
%1 = "HLFHE.add_eint_int"(%arg0, %0): (!HLFHE.eint<2>, i32) -> (!HLFHE.eint<2>)
|
||||
return %1: !HLFHE.eint<2>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @mul_eint_int(%arg0: !HLFHE.eint<0>) -> !HLFHE.eint<0>
|
||||
func @mul_eint_int(%arg0: !HLFHE.eint<0>) -> !HLFHE.eint<0> {
|
||||
// CHECK-LABEL: func @mul_eint_int(%arg0: !HLFHE.eint<2>) -> !HLFHE.eint<2>
|
||||
func @mul_eint_int(%arg0: !HLFHE.eint<2>) -> !HLFHE.eint<2> {
|
||||
// CHECK-NEXT: %[[V1:.*]] = constant 1 : i32
|
||||
// CHECK-NEXT: %[[V2:.*]] = "HLFHE.mul_eint_int"(%arg0, %[[V1]]) : (!HLFHE.eint<0>, i32) -> !HLFHE.eint<0>
|
||||
// CHECK-NEXT: return %[[V2]] : !HLFHE.eint<0>
|
||||
// CHECK-NEXT: %[[V2:.*]] = "HLFHE.mul_eint_int"(%arg0, %[[V1]]) : (!HLFHE.eint<2>, i32) -> !HLFHE.eint<2>
|
||||
// CHECK-NEXT: return %[[V2]] : !HLFHE.eint<2>
|
||||
|
||||
%0 = constant 1 : i32
|
||||
%1 = "HLFHE.mul_eint_int"(%arg0, %0): (!HLFHE.eint<0>, i32) -> (!HLFHE.eint<0>)
|
||||
return %1: !HLFHE.eint<0>
|
||||
%1 = "HLFHE.mul_eint_int"(%arg0, %0): (!HLFHE.eint<2>, i32) -> (!HLFHE.eint<2>)
|
||||
return %1: !HLFHE.eint<2>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @add_eint(%arg0: !HLFHE.eint<0>, %arg1: !HLFHE.eint<0>) -> !HLFHE.eint<0>
|
||||
func @add_eint(%arg0: !HLFHE.eint<0>, %arg1: !HLFHE.eint<0>) -> !HLFHE.eint<0> {
|
||||
// CHECK-NEXT: %[[V1:.*]] = "HLFHE.add_eint"(%arg0, %arg1) : (!HLFHE.eint<0>, !HLFHE.eint<0>) -> !HLFHE.eint<0>
|
||||
// CHECK-NEXT: return %[[V1]] : !HLFHE.eint<0>
|
||||
// CHECK-LABEL: func @add_eint(%arg0: !HLFHE.eint<2>, %arg1: !HLFHE.eint<2>) -> !HLFHE.eint<2>
|
||||
func @add_eint(%arg0: !HLFHE.eint<2>, %arg1: !HLFHE.eint<2>) -> !HLFHE.eint<2> {
|
||||
// CHECK-NEXT: %[[V1:.*]] = "HLFHE.add_eint"(%arg0, %arg1) : (!HLFHE.eint<2>, !HLFHE.eint<2>) -> !HLFHE.eint<2>
|
||||
// CHECK-NEXT: return %[[V1]] : !HLFHE.eint<2>
|
||||
|
||||
%1 = "HLFHE.add_eint"(%arg0, %arg1): (!HLFHE.eint<0>, !HLFHE.eint<0>) -> (!HLFHE.eint<0>)
|
||||
return %1: !HLFHE.eint<0>
|
||||
%1 = "HLFHE.add_eint"(%arg0, %arg1): (!HLFHE.eint<2>, !HLFHE.eint<2>) -> (!HLFHE.eint<2>)
|
||||
return %1: !HLFHE.eint<2>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @dot_eint_int(%arg0: memref<2x!HLFHE.eint<0>>, %arg1: memref<2xi32>, %arg2: memref<!HLFHE.eint<0>>)
|
||||
func @dot_eint_int(%arg0: memref<2x!HLFHE.eint<0>>,
|
||||
// CHECK-LABEL: func @apply_lookup_table(%arg0: !HLFHE.eint<2>, %arg1: memref<4xi2>) -> !HLFHE.eint<2>
|
||||
func @apply_lookup_table(%arg0: !HLFHE.eint<2>, %arg1: memref<4xi2>) -> !HLFHE.eint<2> {
|
||||
// CHECK-NEXT: %[[V1:.*]] = "HLFHE.apply_lookup_table"(%arg0, %arg1) : (!HLFHE.eint<2>, memref<4xi2>) -> !HLFHE.eint<2>
|
||||
// CHECK-NEXT: return %[[V1]] : !HLFHE.eint<2>
|
||||
|
||||
%1 = "HLFHE.apply_lookup_table"(%arg0, %arg1): (!HLFHE.eint<2>, memref<4xi2>) -> (!HLFHE.eint<2>)
|
||||
return %1: !HLFHE.eint<2>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @dot_eint_int(%arg0: memref<2x!HLFHE.eint<2>>, %arg1: memref<2xi32>, %arg2: memref<!HLFHE.eint<2>>)
|
||||
func @dot_eint_int(%arg0: memref<2x!HLFHE.eint<2>>,
|
||||
%arg1: memref<2xi32>,
|
||||
%arg2: memref<!HLFHE.eint<0>>)
|
||||
%arg2: memref<!HLFHE.eint<2>>)
|
||||
{
|
||||
// CHECK-NEXT: "HLFHE.dot_eint_int"(%arg0, %arg1, %arg2) : (memref<2x!HLFHE.eint<0>>, memref<2xi32>, memref<!HLFHE.eint<0>>) -> ()
|
||||
// CHECK-NEXT: "HLFHE.dot_eint_int"(%arg0, %arg1, %arg2) : (memref<2x!HLFHE.eint<2>>, memref<2xi32>, memref<!HLFHE.eint<2>>) -> ()
|
||||
"HLFHE.dot_eint_int"(%arg0, %arg1, %arg2) :
|
||||
(memref<2x!HLFHE.eint<0>>, memref<2xi32>, memref<!HLFHE.eint<0>>) -> ()
|
||||
(memref<2x!HLFHE.eint<2>>, memref<2xi32>, memref<!HLFHE.eint<2>>) -> ()
|
||||
|
||||
//CHECK-NEXT: return
|
||||
return
|
||||
|
||||
@@ -3,20 +3,20 @@
|
||||
// CHECK: #map0 = affine_map<(d0) -> (d0)>
|
||||
// CHECK-NEXT: #map1 = affine_map<(d0) -> ()>
|
||||
// CHECK-NEXT: module {
|
||||
// CHECK-NEXT: func @dot_eint_int(%[[A0:.*]]: memref<2x!HLFHE.eint<0>>, %[[A1:.*]]: memref<2xi32>, %[[A2:.*]]: memref<!HLFHE.eint<0>>)
|
||||
// CHECK-NEXT: linalg.generic {indexing_maps = [#map0, #map0, #map1], iterator_types = ["reduction"]} ins(%[[A0]], %[[A1]] : memref<2x!HLFHE.eint<0>>, memref<2xi32>) outs(%arg2 : memref<!HLFHE.eint<0>>) {
|
||||
// CHECK-NEXT: ^bb0(%[[A3:.*]]: !HLFHE.eint<0>, %[[A4:.*]]: i32, %[[A5:.*]]: !HLFHE.eint<0>): // no predecessors
|
||||
// CHECK-NEXT: %[[T0:.*]] = "HLFHE.mul_eint_int"(%[[A3]], %[[A4]]) : (!HLFHE.eint<0>, i32) -> !HLFHE.eint<0>
|
||||
// CHECK-NEXT: %[[T1:.*]] = "HLFHE.add_eint"(%[[T0]], %[[A5]]) : (!HLFHE.eint<0>, !HLFHE.eint<0>) -> !HLFHE.eint<0>
|
||||
// CHECK-NEXT: linalg.yield %[[T1]] : !HLFHE.eint<0>
|
||||
// CHECK-NEXT: func @dot_eint_int(%[[A0:.*]]: memref<2x!HLFHE.eint<2>>, %[[A1:.*]]: memref<2xi3>, %[[A2:.*]]: memref<!HLFHE.eint<2>>)
|
||||
// CHECK-NEXT: linalg.generic {indexing_maps = [#map0, #map0, #map1], iterator_types = ["reduction"]} ins(%[[A0]], %[[A1]] : memref<2x!HLFHE.eint<2>>, memref<2xi3>) outs(%arg2 : memref<!HLFHE.eint<2>>) {
|
||||
// CHECK-NEXT: ^bb0(%[[A3:.*]]: !HLFHE.eint<2>, %[[A4:.*]]: i3, %[[A5:.*]]: !HLFHE.eint<2>): // no predecessors
|
||||
// CHECK-NEXT: %[[T0:.*]] = "HLFHE.mul_eint_int"(%[[A3]], %[[A4]]) : (!HLFHE.eint<2>, i3) -> !HLFHE.eint<2>
|
||||
// CHECK-NEXT: %[[T1:.*]] = "HLFHE.add_eint"(%[[T0]], %[[A5]]) : (!HLFHE.eint<2>, !HLFHE.eint<2>) -> !HLFHE.eint<2>
|
||||
// CHECK-NEXT: linalg.yield %[[T1]] : !HLFHE.eint<2>
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: return
|
||||
// CHECK-NEXT: }
|
||||
func @dot_eint_int(%arg0: memref<2x!HLFHE.eint<0>>,
|
||||
%arg1: memref<2xi32>,
|
||||
%arg2: memref<!HLFHE.eint<0>>)
|
||||
func @dot_eint_int(%arg0: memref<2x!HLFHE.eint<2>>,
|
||||
%arg1: memref<2xi3>,
|
||||
%arg2: memref<!HLFHE.eint<2>>)
|
||||
{
|
||||
"HLFHE.dot_eint_int"(%arg0, %arg1, %arg2) :
|
||||
(memref<2x!HLFHE.eint<0>>, memref<2xi32>, memref<!HLFHE.eint<0>>) -> ()
|
||||
(memref<2x!HLFHE.eint<2>>, memref<2xi3>, memref<!HLFHE.eint<2>>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
// RUN: zamacompiler %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @memref_arg(%arg0: memref<2x!HLFHE.eint<0>>
|
||||
func @memref_arg(%arg0: memref<2x!HLFHE.eint<0>>) {
|
||||
// CHECK-LABEL: func @memref_arg(%arg0: memref<2x!HLFHE.eint<7>>
|
||||
func @memref_arg(%arg0: memref<2x!HLFHE.eint<7>>) {
|
||||
return
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user