diff --git a/compiler/include/zamalang/Dialect/HLFHE/IR/HLFHEOps.td b/compiler/include/zamalang/Dialect/HLFHE/IR/HLFHEOps.td index 8281ad7a7..a68ee12f3 100644 --- a/compiler/include/zamalang/Dialect/HLFHE/IR/HLFHEOps.td +++ b/compiler/include/zamalang/Dialect/HLFHE/IR/HLFHEOps.td @@ -27,6 +27,10 @@ def AddEintIntOp : HLFHE_Op<"add_eint_int"> { build($_builder, $_state, a.getType(), a, b); }]> ]; + + let verifier = [{ + return ::mlir::zamalang::HLFHE::verifyAddEintIntOp(*this); + }]; } def AddEintOp : HLFHE_Op<"add_eint"> { @@ -38,6 +42,10 @@ def AddEintOp : HLFHE_Op<"add_eint"> { build($_builder, $_state, a.getType(), a, b); }]> ]; + + let verifier = [{ + return ::mlir::zamalang::HLFHE::verifyAddEintOp(*this); + }]; } def MulEintIntOp : HLFHE_Op<"mul_eint_int"> { @@ -49,6 +57,10 @@ def MulEintIntOp : HLFHE_Op<"mul_eint_int"> { build($_builder, $_state, a.getType(), a, b); }]> ]; + + let verifier = [{ + return ::mlir::zamalang::HLFHE::verifyMulEintIntOp(*this); + }]; } def ApplyLookupTable : HLFHE_Op<"apply_lookup_table"> { diff --git a/compiler/lib/Dialect/HLFHE/IR/HLFHEOps.cpp b/compiler/lib/Dialect/HLFHE/IR/HLFHEOps.cpp index d563cc961..4df577034 100644 --- a/compiler/lib/Dialect/HLFHE/IR/HLFHEOps.cpp +++ b/compiler/lib/Dialect/HLFHE/IR/HLFHEOps.cpp @@ -8,9 +8,81 @@ namespace mlir { namespace zamalang { namespace HLFHE { +using mlir::zamalang::HLFHE::AddEintOp; using mlir::zamalang::HLFHE::ApplyLookupTable; using mlir::zamalang::HLFHE::EncryptedIntegerType; +bool verifyEncryptedIntegerInputAndResultConsistency( + ::mlir::OpState &op, EncryptedIntegerType &input, + EncryptedIntegerType &result) { + if (input.getWidth() != result.getWidth()) { + op.emitOpError( + " should have the width of encrypted inputs and result equals"); + return false; + } + return true; +} + +bool verifyEncryptedIntegerAndIntegerInputsConsistency(::mlir::OpState &op, + EncryptedIntegerType &a, + IntegerType &b) { + if (a.getWidth() + 1 != b.getWidth()) { + op.emitOpError(" should have the width of plain input equals to width of " + "encrypted input + 1"); + return false; + } + return true; +} + +bool verifyEncryptedIntegerInputsConsistency(::mlir::OpState &op, + EncryptedIntegerType &a, + EncryptedIntegerType &b) { + if (a.getWidth() != b.getWidth()) { + op.emitOpError(" should have the width of encrypted inputs equals"); + return false; + } + return true; +} + +::mlir::LogicalResult verifyAddEintIntOp(AddEintIntOp &op) { + auto a = op.a().getType().cast(); + auto b = op.b().getType().cast(); + auto out = op.getResult().getType().cast(); + if (!verifyEncryptedIntegerInputAndResultConsistency(op, a, out)) { + return ::mlir::failure(); + } + if (!verifyEncryptedIntegerAndIntegerInputsConsistency(op, a, b)) { + return ::mlir::failure(); + } + return ::mlir::success(); +} + +::mlir::LogicalResult verifyAddEintOp(AddEintOp &op) { + auto a = op.a().getType().cast(); + auto b = op.b().getType().cast(); + auto out = op.getResult().getType().cast(); + if (!verifyEncryptedIntegerInputAndResultConsistency(op, a, out)) { + return ::mlir::failure(); + } + if (!verifyEncryptedIntegerInputsConsistency(op, a, b)) { + return ::mlir::failure(); + } + return ::mlir::success(); +} + +::mlir::LogicalResult verifyMulEintIntOp(MulEintIntOp &op) { + auto a = op.a().getType().cast(); + auto b = op.b().getType().cast(); + auto out = op.getResult().getType().cast(); + if (!verifyEncryptedIntegerInputAndResultConsistency(op, a, out)) { + return ::mlir::failure(); + } + if (!verifyEncryptedIntegerAndIntegerInputsConsistency(op, a, b)) { + return ::mlir::failure(); + } + return ::mlir::success(); +} + ::mlir::LogicalResult verifyApplyLookupTable(ApplyLookupTable &op) { auto ct = op.ct().getType().cast(); auto l_cst = op.l_cst().getType().cast(); diff --git a/compiler/tests/Dialect/HLFHE/op_add_eint_err_inputs.mlir b/compiler/tests/Dialect/HLFHE/op_add_eint_err_inputs.mlir new file mode 100644 index 000000000..01bd02a02 --- /dev/null +++ b/compiler/tests/Dialect/HLFHE/op_add_eint_err_inputs.mlir @@ -0,0 +1,7 @@ +// RUN: not zamacompiler %s 2>&1| FileCheck %s + +// CHECK-LABEL: error: 'HLFHE.add_eint' op should have the width of encrypted inputs equals +func @add_eint(%arg0: !HLFHE.eint<2>, %arg1: !HLFHE.eint<3>) -> !HLFHE.eint<2> { + %1 = "HLFHE.add_eint"(%arg0, %arg1): (!HLFHE.eint<2>, !HLFHE.eint<3>) -> (!HLFHE.eint<2>) + return %1: !HLFHE.eint<2> +} \ No newline at end of file diff --git a/compiler/tests/Dialect/HLFHE/op_add_eint_err_result.mlir b/compiler/tests/Dialect/HLFHE/op_add_eint_err_result.mlir new file mode 100644 index 000000000..3c9e7b0cb --- /dev/null +++ b/compiler/tests/Dialect/HLFHE/op_add_eint_err_result.mlir @@ -0,0 +1,7 @@ +// RUN: not zamacompiler %s 2>&1| FileCheck %s + +// CHECK-LABEL: error: 'HLFHE.add_eint' op should have the width of encrypted inputs and result equals +func @add_eint(%arg0: !HLFHE.eint<2>, %arg1: !HLFHE.eint<2>) -> !HLFHE.eint<3> { + %1 = "HLFHE.add_eint"(%arg0, %arg1): (!HLFHE.eint<2>, !HLFHE.eint<2>) -> (!HLFHE.eint<3>) + return %1: !HLFHE.eint<3> +} \ No newline at end of file diff --git a/compiler/tests/Dialect/HLFHE/op_add_eint_int_err_inputs.mlir b/compiler/tests/Dialect/HLFHE/op_add_eint_int_err_inputs.mlir new file mode 100644 index 000000000..680c79b57 --- /dev/null +++ b/compiler/tests/Dialect/HLFHE/op_add_eint_int_err_inputs.mlir @@ -0,0 +1,8 @@ +// RUN: not zamacompiler %s 2>&1| FileCheck %s + +// CHECK-LABEL: error: 'HLFHE.add_eint_int' op should have the width of plain input equals to width of encrypted input + 1 +func @add_eint_int(%arg0: !HLFHE.eint<2>) -> !HLFHE.eint<2> { + %0 = constant 1 : i4 + %1 = "HLFHE.add_eint_int"(%arg0, %0): (!HLFHE.eint<2>, i4) -> (!HLFHE.eint<2>) + return %1: !HLFHE.eint<2> +} \ No newline at end of file diff --git a/compiler/tests/Dialect/HLFHE/op_add_eint_int_err_result.mlir b/compiler/tests/Dialect/HLFHE/op_add_eint_int_err_result.mlir new file mode 100644 index 000000000..299d2771a --- /dev/null +++ b/compiler/tests/Dialect/HLFHE/op_add_eint_int_err_result.mlir @@ -0,0 +1,8 @@ +// RUN: not zamacompiler %s 2>&1| FileCheck %s + +// CHECK-LABEL: error: 'HLFHE.add_eint_int' op should have the width of encrypted inputs and result equals +func @add_eint_int(%arg0: !HLFHE.eint<2>) -> !HLFHE.eint<3> { + %0 = constant 1 : i2 + %1 = "HLFHE.add_eint_int"(%arg0, %0): (!HLFHE.eint<2>, i2) -> (!HLFHE.eint<3>) + return %1: !HLFHE.eint<3> +} \ No newline at end of file diff --git a/compiler/tests/Dialect/HLFHE/op_mul_eint_int_err_inputs.mlir b/compiler/tests/Dialect/HLFHE/op_mul_eint_int_err_inputs.mlir new file mode 100644 index 000000000..6ea308618 --- /dev/null +++ b/compiler/tests/Dialect/HLFHE/op_mul_eint_int_err_inputs.mlir @@ -0,0 +1,8 @@ +// RUN: not zamacompiler %s 2>&1| FileCheck %s + +// CHECK-LABEL: error: 'HLFHE.mul_eint_int' op should have the width of plain input equals to width of encrypted input + 1 +func @mul_eint_int(%arg0: !HLFHE.eint<2>) -> !HLFHE.eint<2> { + %0 = constant 1 : i4 + %1 = "HLFHE.mul_eint_int"(%arg0, %0): (!HLFHE.eint<2>, i4) -> (!HLFHE.eint<2>) + return %1: !HLFHE.eint<2> +} \ No newline at end of file diff --git a/compiler/tests/Dialect/HLFHE/op_mul_eint_int_err_result.mlir b/compiler/tests/Dialect/HLFHE/op_mul_eint_int_err_result.mlir new file mode 100644 index 000000000..bc7ef7637 --- /dev/null +++ b/compiler/tests/Dialect/HLFHE/op_mul_eint_int_err_result.mlir @@ -0,0 +1,8 @@ +// RUN: not zamacompiler %s 2>&1| FileCheck %s + +// CHECK-LABEL: error: 'HLFHE.mul_eint_int' op should have the width of encrypted inputs and result equals +func @mul_eint_int(%arg0: !HLFHE.eint<2>) -> !HLFHE.eint<3> { + %0 = constant 1 : i2 + %1 = "HLFHE.mul_eint_int"(%arg0, %0): (!HLFHE.eint<2>, i2) -> (!HLFHE.eint<3>) + return %1: !HLFHE.eint<3> +} \ No newline at end of file diff --git a/compiler/tests/Dialect/HLFHE/ops.mlir b/compiler/tests/Dialect/HLFHE/ops.mlir index 990d96950..41508ee58 100644 --- a/compiler/tests/Dialect/HLFHE/ops.mlir +++ b/compiler/tests/Dialect/HLFHE/ops.mlir @@ -2,23 +2,23 @@ // CHECK-LABEL: func @add_eint_int(%arg0: !HLFHE.eint<2>) -> !HLFHE.eint<2> func @add_eint_int(%arg0: !HLFHE.eint<2>) -> !HLFHE.eint<2> { - // CHECK-NEXT: %[[V1:.*]] = constant 1 : i32 - // CHECK-NEXT: %[[V2:.*]] = "HLFHE.add_eint_int"(%arg0, %[[V1]]) : (!HLFHE.eint<2>, i32) -> !HLFHE.eint<2> + // CHECK-NEXT: %[[V1:.*]] = constant 1 : i3 + // CHECK-NEXT: %[[V2:.*]] = "HLFHE.add_eint_int"(%arg0, %[[V1]]) : (!HLFHE.eint<2>, i3) -> !HLFHE.eint<2> // CHECK-NEXT: return %[[V2]] : !HLFHE.eint<2> - %0 = constant 1 : i32 - %1 = "HLFHE.add_eint_int"(%arg0, %0): (!HLFHE.eint<2>, i32) -> (!HLFHE.eint<2>) + %0 = constant 1 : i3 + %1 = "HLFHE.add_eint_int"(%arg0, %0): (!HLFHE.eint<2>, i3) -> (!HLFHE.eint<2>) return %1: !HLFHE.eint<2> } // CHECK-LABEL: func @mul_eint_int(%arg0: !HLFHE.eint<2>) -> !HLFHE.eint<2> func @mul_eint_int(%arg0: !HLFHE.eint<2>) -> !HLFHE.eint<2> { - // CHECK-NEXT: %[[V1:.*]] = constant 1 : i32 - // CHECK-NEXT: %[[V2:.*]] = "HLFHE.mul_eint_int"(%arg0, %[[V1]]) : (!HLFHE.eint<2>, i32) -> !HLFHE.eint<2> + // CHECK-NEXT: %[[V1:.*]] = constant 1 : i3 + // CHECK-NEXT: %[[V2:.*]] = "HLFHE.mul_eint_int"(%arg0, %[[V1]]) : (!HLFHE.eint<2>, i3) -> !HLFHE.eint<2> // CHECK-NEXT: return %[[V2]] : !HLFHE.eint<2> - %0 = constant 1 : i32 - %1 = "HLFHE.mul_eint_int"(%arg0, %0): (!HLFHE.eint<2>, i32) -> (!HLFHE.eint<2>) + %0 = constant 1 : i3 + %1 = "HLFHE.mul_eint_int"(%arg0, %0): (!HLFHE.eint<2>, i3) -> (!HLFHE.eint<2>) return %1: !HLFHE.eint<2> }