diff --git a/compiler/include/zamalang/Dialect/HLFHE/IR/HLFHETypes.td b/compiler/include/zamalang/Dialect/HLFHE/IR/HLFHETypes.td index ed4c045b0..903d820d3 100644 --- a/compiler/include/zamalang/Dialect/HLFHE/IR/HLFHETypes.td +++ b/compiler/include/zamalang/Dialect/HLFHE/IR/HLFHETypes.td @@ -33,8 +33,11 @@ def EncryptedIntegerType : HLFHE_Type<"EncryptedInteger", return Type(); if ($_parser.parseGreater()) return Type(); - return get($_ctxt, width); + Location loc = $_parser.getEncodedSourceLoc($_parser.getNameLoc()); + return getChecked(loc, loc.getContext(), width); }]; + + let genVerifyDecl = true; } #endif diff --git a/compiler/lib/Dialect/HLFHE/IR/HLFHEDialect.cpp b/compiler/lib/Dialect/HLFHE/IR/HLFHEDialect.cpp index 5453a5302..09f18d9bd 100644 --- a/compiler/lib/Dialect/HLFHE/IR/HLFHEDialect.cpp +++ b/compiler/lib/Dialect/HLFHE/IR/HLFHEDialect.cpp @@ -43,3 +43,12 @@ void HLFHEDialect::printType(::mlir::Type type, // Calling default printer if failed to print HLFHE type printer.printType(type); } + +mlir::LogicalResult EncryptedIntegerType::verify( + llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, unsigned p) { + if (p == 0 || p > 7) { + emitError() << "HLFHE.eint support only precision in ]0;7]"; + return mlir::failure(); + } + return mlir::success(); +} diff --git a/compiler/tests/Dialect/HLFHE/dot.invalid.mlir b/compiler/tests/Dialect/HLFHE/dot.invalid.mlir index 8ea2b2c26..7127999c2 100644 --- a/compiler/tests/Dialect/HLFHE/dot.invalid.mlir +++ b/compiler/tests/Dialect/HLFHE/dot.invalid.mlir @@ -2,13 +2,13 @@ // Unranked types func @dot_unranked( - %arg0: memref>, + %arg0: memref>, %arg1: memref, - %arg2: memref>) + %arg2: memref>) { // expected-error @+1 {{'HLFHE.dot_eint_int' op operand #0}} "HLFHE.dot_eint_int"(%arg0, %arg1, %arg2) : - (memref>, memref, memref>) -> () + (memref>, memref, memref>) -> () return } @@ -17,13 +17,13 @@ func @dot_unranked( // Incompatible shapes func @dot_incompatible_shapes( - %arg0: memref<5x!HLFHE.eint<0>>, + %arg0: memref<5x!HLFHE.eint<2>>, %arg1: memref<4xi32>, - %arg2: memref>) + %arg2: memref>) { // expected-error @+1 {{'HLFHE.dot_eint_int' op arguments have incompatible shapes}} "HLFHE.dot_eint_int"(%arg0, %arg1, %arg2) : - (memref<5x!HLFHE.eint<0>>, memref<4xi32>, memref>) -> () + (memref<5x!HLFHE.eint<2>>, memref<4xi32>, memref>) -> () return } @@ -32,13 +32,13 @@ func @dot_incompatible_shapes( // Incompatible input types func @dot_incompatible_input_types( - %arg0: memref<4x!HLFHE.eint<0>>, + %arg0: memref<4x!HLFHE.eint<2>>, %arg1: memref<4xf32>, - %arg2: memref>) + %arg2: memref>) { // expected-error @+1 {{'HLFHE.dot_eint_int' op operand #1 must}} "HLFHE.dot_eint_int"(%arg0, %arg1, %arg2) : - (memref<4x!HLFHE.eint<0>>, memref<4xf32>, memref>) -> () + (memref<4x!HLFHE.eint<2>>, memref<4xf32>, memref>) -> () return } @@ -47,13 +47,13 @@ func @dot_incompatible_input_types( // Wrong number of dimensions func @dot_num_dims( - %arg0: memref<2x4x!HLFHE.eint<0>>, + %arg0: memref<2x4x!HLFHE.eint<2>>, %arg1: memref<2x4xi32>, - %arg2: memref>) + %arg2: memref>) { // expected-error @+1 {{'HLFHE.dot_eint_int' op operand #0}} "HLFHE.dot_eint_int"(%arg0, %arg1, %arg2) : - (memref<2x4x!HLFHE.eint<0>>, memref<2x4xi32>, memref>) -> () + (memref<2x4x!HLFHE.eint<2>>, memref<2x4xi32>, memref>) -> () return } diff --git a/compiler/tests/Dialect/HLFHE/eint_error_p_too_big.mlir b/compiler/tests/Dialect/HLFHE/eint_error_p_too_big.mlir new file mode 100644 index 000000000..de76e776f --- /dev/null +++ b/compiler/tests/Dialect/HLFHE/eint_error_p_too_big.mlir @@ -0,0 +1,6 @@ +// RUN: not zamacompiler %s 2>&1| FileCheck %s + +// CHECK-LABEL: eint support only precision in ]0;7] +func @test(%arg0: !HLFHE.eint<8>) { + return +} diff --git a/compiler/tests/Dialect/HLFHE/eint_error_p_too_small.mlir b/compiler/tests/Dialect/HLFHE/eint_error_p_too_small.mlir new file mode 100644 index 000000000..300d49685 --- /dev/null +++ b/compiler/tests/Dialect/HLFHE/eint_error_p_too_small.mlir @@ -0,0 +1,6 @@ +// RUN: not zamacompiler %s 2>&1| FileCheck %s + +// CHECK-LABEL: eint support only precision in ]0;7] +func @test(%arg0: !HLFHE.eint<0>) { + return +} diff --git a/compiler/tests/Dialect/HLFHE/ops.mlir b/compiler/tests/Dialect/HLFHE/ops.mlir index f9fae3714..990d96950 100644 --- a/compiler/tests/Dialect/HLFHE/ops.mlir +++ b/compiler/tests/Dialect/HLFHE/ops.mlir @@ -1,44 +1,53 @@ // RUN: zamacompiler %s 2>&1| FileCheck %s -// CHECK-LABEL: func @add_eint_int(%arg0: !HLFHE.eint<0>) -> !HLFHE.eint<0> -func @add_eint_int(%arg0: !HLFHE.eint<0>) -> !HLFHE.eint<0> { +// CHECK-LABEL: func @add_eint_int(%arg0: !HLFHE.eint<2>) -> !HLFHE.eint<2> +func @add_eint_int(%arg0: !HLFHE.eint<2>) -> !HLFHE.eint<2> { // CHECK-NEXT: %[[V1:.*]] = constant 1 : i32 - // CHECK-NEXT: %[[V2:.*]] = "HLFHE.add_eint_int"(%arg0, %[[V1]]) : (!HLFHE.eint<0>, i32) -> !HLFHE.eint<0> - // CHECK-NEXT: return %[[V2]] : !HLFHE.eint<0> + // CHECK-NEXT: %[[V2:.*]] = "HLFHE.add_eint_int"(%arg0, %[[V1]]) : (!HLFHE.eint<2>, i32) -> !HLFHE.eint<2> + // CHECK-NEXT: return %[[V2]] : !HLFHE.eint<2> %0 = constant 1 : i32 - %1 = "HLFHE.add_eint_int"(%arg0, %0): (!HLFHE.eint<0>, i32) -> (!HLFHE.eint<0>) - return %1: !HLFHE.eint<0> + %1 = "HLFHE.add_eint_int"(%arg0, %0): (!HLFHE.eint<2>, i32) -> (!HLFHE.eint<2>) + return %1: !HLFHE.eint<2> } -// CHECK-LABEL: func @mul_eint_int(%arg0: !HLFHE.eint<0>) -> !HLFHE.eint<0> -func @mul_eint_int(%arg0: !HLFHE.eint<0>) -> !HLFHE.eint<0> { +// CHECK-LABEL: func @mul_eint_int(%arg0: !HLFHE.eint<2>) -> !HLFHE.eint<2> +func @mul_eint_int(%arg0: !HLFHE.eint<2>) -> !HLFHE.eint<2> { // CHECK-NEXT: %[[V1:.*]] = constant 1 : i32 - // CHECK-NEXT: %[[V2:.*]] = "HLFHE.mul_eint_int"(%arg0, %[[V1]]) : (!HLFHE.eint<0>, i32) -> !HLFHE.eint<0> - // CHECK-NEXT: return %[[V2]] : !HLFHE.eint<0> + // CHECK-NEXT: %[[V2:.*]] = "HLFHE.mul_eint_int"(%arg0, %[[V1]]) : (!HLFHE.eint<2>, i32) -> !HLFHE.eint<2> + // CHECK-NEXT: return %[[V2]] : !HLFHE.eint<2> %0 = constant 1 : i32 - %1 = "HLFHE.mul_eint_int"(%arg0, %0): (!HLFHE.eint<0>, i32) -> (!HLFHE.eint<0>) - return %1: !HLFHE.eint<0> + %1 = "HLFHE.mul_eint_int"(%arg0, %0): (!HLFHE.eint<2>, i32) -> (!HLFHE.eint<2>) + return %1: !HLFHE.eint<2> } -// CHECK-LABEL: func @add_eint(%arg0: !HLFHE.eint<0>, %arg1: !HLFHE.eint<0>) -> !HLFHE.eint<0> -func @add_eint(%arg0: !HLFHE.eint<0>, %arg1: !HLFHE.eint<0>) -> !HLFHE.eint<0> { - // CHECK-NEXT: %[[V1:.*]] = "HLFHE.add_eint"(%arg0, %arg1) : (!HLFHE.eint<0>, !HLFHE.eint<0>) -> !HLFHE.eint<0> - // CHECK-NEXT: return %[[V1]] : !HLFHE.eint<0> +// CHECK-LABEL: func @add_eint(%arg0: !HLFHE.eint<2>, %arg1: !HLFHE.eint<2>) -> !HLFHE.eint<2> +func @add_eint(%arg0: !HLFHE.eint<2>, %arg1: !HLFHE.eint<2>) -> !HLFHE.eint<2> { + // CHECK-NEXT: %[[V1:.*]] = "HLFHE.add_eint"(%arg0, %arg1) : (!HLFHE.eint<2>, !HLFHE.eint<2>) -> !HLFHE.eint<2> + // CHECK-NEXT: return %[[V1]] : !HLFHE.eint<2> - %1 = "HLFHE.add_eint"(%arg0, %arg1): (!HLFHE.eint<0>, !HLFHE.eint<0>) -> (!HLFHE.eint<0>) - return %1: !HLFHE.eint<0> + %1 = "HLFHE.add_eint"(%arg0, %arg1): (!HLFHE.eint<2>, !HLFHE.eint<2>) -> (!HLFHE.eint<2>) + return %1: !HLFHE.eint<2> } -// CHECK-LABEL: func @dot_eint_int(%arg0: memref<2x!HLFHE.eint<0>>, %arg1: memref<2xi32>, %arg2: memref>) -func @dot_eint_int(%arg0: memref<2x!HLFHE.eint<0>>, +// CHECK-LABEL: func @apply_lookup_table(%arg0: !HLFHE.eint<2>, %arg1: memref<4xi2>) -> !HLFHE.eint<2> +func @apply_lookup_table(%arg0: !HLFHE.eint<2>, %arg1: memref<4xi2>) -> !HLFHE.eint<2> { + // CHECK-NEXT: %[[V1:.*]] = "HLFHE.apply_lookup_table"(%arg0, %arg1) : (!HLFHE.eint<2>, memref<4xi2>) -> !HLFHE.eint<2> + // CHECK-NEXT: return %[[V1]] : !HLFHE.eint<2> + + %1 = "HLFHE.apply_lookup_table"(%arg0, %arg1): (!HLFHE.eint<2>, memref<4xi2>) -> (!HLFHE.eint<2>) + return %1: !HLFHE.eint<2> +} + +// CHECK-LABEL: func @dot_eint_int(%arg0: memref<2x!HLFHE.eint<2>>, %arg1: memref<2xi32>, %arg2: memref>) +func @dot_eint_int(%arg0: memref<2x!HLFHE.eint<2>>, %arg1: memref<2xi32>, - %arg2: memref>) + %arg2: memref>) { - // CHECK-NEXT: "HLFHE.dot_eint_int"(%arg0, %arg1, %arg2) : (memref<2x!HLFHE.eint<0>>, memref<2xi32>, memref>) -> () + // CHECK-NEXT: "HLFHE.dot_eint_int"(%arg0, %arg1, %arg2) : (memref<2x!HLFHE.eint<2>>, memref<2xi32>, memref>) -> () "HLFHE.dot_eint_int"(%arg0, %arg1, %arg2) : - (memref<2x!HLFHE.eint<0>>, memref<2xi32>, memref>) -> () + (memref<2x!HLFHE.eint<2>>, memref<2xi32>, memref>) -> () //CHECK-NEXT: return return diff --git a/compiler/tests/Dialect/HLFHE/tensor-ops-to-linalg.mlir b/compiler/tests/Dialect/HLFHE/tensor-ops-to-linalg.mlir index f117126f2..2fde5e0ba 100644 --- a/compiler/tests/Dialect/HLFHE/tensor-ops-to-linalg.mlir +++ b/compiler/tests/Dialect/HLFHE/tensor-ops-to-linalg.mlir @@ -3,20 +3,20 @@ // CHECK: #map0 = affine_map<(d0) -> (d0)> // CHECK-NEXT: #map1 = affine_map<(d0) -> ()> // CHECK-NEXT: module { -// CHECK-NEXT: func @dot_eint_int(%[[A0:.*]]: memref<2x!HLFHE.eint<0>>, %[[A1:.*]]: memref<2xi32>, %[[A2:.*]]: memref>) -// CHECK-NEXT: linalg.generic {indexing_maps = [#map0, #map0, #map1], iterator_types = ["reduction"]} ins(%[[A0]], %[[A1]] : memref<2x!HLFHE.eint<0>>, memref<2xi32>) outs(%arg2 : memref>) { -// CHECK-NEXT: ^bb0(%[[A3:.*]]: !HLFHE.eint<0>, %[[A4:.*]]: i32, %[[A5:.*]]: !HLFHE.eint<0>): // no predecessors -// CHECK-NEXT: %[[T0:.*]] = "HLFHE.mul_eint_int"(%[[A3]], %[[A4]]) : (!HLFHE.eint<0>, i32) -> !HLFHE.eint<0> -// CHECK-NEXT: %[[T1:.*]] = "HLFHE.add_eint"(%[[T0]], %[[A5]]) : (!HLFHE.eint<0>, !HLFHE.eint<0>) -> !HLFHE.eint<0> -// CHECK-NEXT: linalg.yield %[[T1]] : !HLFHE.eint<0> +// CHECK-NEXT: func @dot_eint_int(%[[A0:.*]]: memref<2x!HLFHE.eint<2>>, %[[A1:.*]]: memref<2xi3>, %[[A2:.*]]: memref>) +// CHECK-NEXT: linalg.generic {indexing_maps = [#map0, #map0, #map1], iterator_types = ["reduction"]} ins(%[[A0]], %[[A1]] : memref<2x!HLFHE.eint<2>>, memref<2xi3>) outs(%arg2 : memref>) { +// CHECK-NEXT: ^bb0(%[[A3:.*]]: !HLFHE.eint<2>, %[[A4:.*]]: i3, %[[A5:.*]]: !HLFHE.eint<2>): // no predecessors +// CHECK-NEXT: %[[T0:.*]] = "HLFHE.mul_eint_int"(%[[A3]], %[[A4]]) : (!HLFHE.eint<2>, i3) -> !HLFHE.eint<2> +// CHECK-NEXT: %[[T1:.*]] = "HLFHE.add_eint"(%[[T0]], %[[A5]]) : (!HLFHE.eint<2>, !HLFHE.eint<2>) -> !HLFHE.eint<2> +// CHECK-NEXT: linalg.yield %[[T1]] : !HLFHE.eint<2> // CHECK-NEXT: } // CHECK-NEXT: return // CHECK-NEXT: } -func @dot_eint_int(%arg0: memref<2x!HLFHE.eint<0>>, - %arg1: memref<2xi32>, - %arg2: memref>) +func @dot_eint_int(%arg0: memref<2x!HLFHE.eint<2>>, + %arg1: memref<2xi3>, + %arg2: memref>) { "HLFHE.dot_eint_int"(%arg0, %arg1, %arg2) : - (memref<2x!HLFHE.eint<0>>, memref<2xi32>, memref>) -> () + (memref<2x!HLFHE.eint<2>>, memref<2xi3>, memref>) -> () return } diff --git a/compiler/tests/Dialect/HLFHE/types.mlir b/compiler/tests/Dialect/HLFHE/types.mlir index cd2646a72..ecc07bcdc 100644 --- a/compiler/tests/Dialect/HLFHE/types.mlir +++ b/compiler/tests/Dialect/HLFHE/types.mlir @@ -1,6 +1,6 @@ // RUN: zamacompiler %s 2>&1| FileCheck %s -// CHECK-LABEL: func @memref_arg(%arg0: memref<2x!HLFHE.eint<0>> -func @memref_arg(%arg0: memref<2x!HLFHE.eint<0>>) { +// CHECK-LABEL: func @memref_arg(%arg0: memref<2x!HLFHE.eint<7>> +func @memref_arg(%arg0: memref<2x!HLFHE.eint<7>>) { return }