From ed40b1ef91816e12ded330035a54f434bab2b61b Mon Sep 17 00:00:00 2001 From: Quentin Bourgerie Date: Mon, 19 Jul 2021 11:41:14 +0200 Subject: [PATCH] feat(compiler): Add HLFHE.sub_int_eint (#54) --- .../zamalang/Dialect/HLFHE/IR/HLFHEOps.td | 9 +++++++++ compiler/lib/Dialect/HLFHE/IR/HLFHEOps.cpp | 17 +++++++++++++---- .../HLFHE/op_sub_int_eint_err_inputs.mlir | 8 ++++++++ .../HLFHE/op_sub_int_eint_err_result.mlir | 8 ++++++++ compiler/tests/Dialect/HLFHE/ops.mlir | 11 +++++++++++ 5 files changed, 49 insertions(+), 4 deletions(-) create mode 100644 compiler/tests/Dialect/HLFHE/op_sub_int_eint_err_inputs.mlir create mode 100644 compiler/tests/Dialect/HLFHE/op_sub_int_eint_err_result.mlir diff --git a/compiler/include/zamalang/Dialect/HLFHE/IR/HLFHEOps.td b/compiler/include/zamalang/Dialect/HLFHE/IR/HLFHEOps.td index a68ee12f3..d3220b4c3 100644 --- a/compiler/include/zamalang/Dialect/HLFHE/IR/HLFHEOps.td +++ b/compiler/include/zamalang/Dialect/HLFHE/IR/HLFHEOps.td @@ -48,6 +48,15 @@ def AddEintOp : HLFHE_Op<"add_eint"> { }]; } +def SubIntEintOp : HLFHE_Op<"sub_int_eint"> { + let arguments = (ins AnyInteger:$a, EncryptedIntegerType:$b); + let results = (outs EncryptedIntegerType); + + let verifier = [{ + return ::mlir::zamalang::HLFHE::verifySubIntEintOp(*this); + }]; +} + def MulEintIntOp : HLFHE_Op<"mul_eint_int"> { let arguments = (ins EncryptedIntegerType:$a, AnyInteger:$b); let results = (outs EncryptedIntegerType); diff --git a/compiler/lib/Dialect/HLFHE/IR/HLFHEOps.cpp b/compiler/lib/Dialect/HLFHE/IR/HLFHEOps.cpp index 4df577034..bcdb716ea 100644 --- a/compiler/lib/Dialect/HLFHE/IR/HLFHEOps.cpp +++ b/compiler/lib/Dialect/HLFHE/IR/HLFHEOps.cpp @@ -8,10 +8,6 @@ namespace mlir { namespace zamalang { namespace HLFHE { -using mlir::zamalang::HLFHE::AddEintOp; -using mlir::zamalang::HLFHE::ApplyLookupTable; -using mlir::zamalang::HLFHE::EncryptedIntegerType; - bool verifyEncryptedIntegerInputAndResultConsistency( ::mlir::OpState &op, EncryptedIntegerType &input, EncryptedIntegerType &result) { @@ -70,6 +66,19 @@ bool verifyEncryptedIntegerInputsConsistency(::mlir::OpState &op, return ::mlir::success(); } +::mlir::LogicalResult verifySubIntEintOp(SubIntEintOp &op) { + auto a = op.a().getType().cast(); + auto b = op.b().getType().cast(); + auto out = op.getResult().getType().cast(); + if (!verifyEncryptedIntegerInputAndResultConsistency(op, b, out)) { + return ::mlir::failure(); + } + if (!verifyEncryptedIntegerAndIntegerInputsConsistency(op, b, a)) { + return ::mlir::failure(); + } + return ::mlir::success(); +} + ::mlir::LogicalResult verifyMulEintIntOp(MulEintIntOp &op) { auto a = op.a().getType().cast(); auto b = op.b().getType().cast(); diff --git a/compiler/tests/Dialect/HLFHE/op_sub_int_eint_err_inputs.mlir b/compiler/tests/Dialect/HLFHE/op_sub_int_eint_err_inputs.mlir new file mode 100644 index 000000000..73c036e5e --- /dev/null +++ b/compiler/tests/Dialect/HLFHE/op_sub_int_eint_err_inputs.mlir @@ -0,0 +1,8 @@ +// RUN: not zamacompiler %s 2>&1| FileCheck %s + +// CHECK-LABEL: error: 'HLFHE.sub_int_eint' op should have the width of plain input equals to width of encrypted input + 1 +func @sub_int_eint(%arg0: !HLFHE.eint<2>) -> !HLFHE.eint<2> { + %0 = constant 1 : i4 + %1 = "HLFHE.sub_int_eint"(%0, %arg0): (i4, !HLFHE.eint<2>) -> (!HLFHE.eint<2>) + return %1: !HLFHE.eint<2> +} \ No newline at end of file diff --git a/compiler/tests/Dialect/HLFHE/op_sub_int_eint_err_result.mlir b/compiler/tests/Dialect/HLFHE/op_sub_int_eint_err_result.mlir new file mode 100644 index 000000000..6d2892e8a --- /dev/null +++ b/compiler/tests/Dialect/HLFHE/op_sub_int_eint_err_result.mlir @@ -0,0 +1,8 @@ +// RUN: not zamacompiler %s 2>&1| FileCheck %s + +// CHECK-LABEL: error: 'HLFHE.sub_int_eint' op should have the width of encrypted inputs and result equals +func @sub_int_eint(%arg0: !HLFHE.eint<2>) -> !HLFHE.eint<3> { + %0 = constant 1 : i2 + %1 = "HLFHE.sub_int_eint"(%0, %arg0): (i2, !HLFHE.eint<2>) -> (!HLFHE.eint<3>) + return %1: !HLFHE.eint<3> +} \ No newline at end of file diff --git a/compiler/tests/Dialect/HLFHE/ops.mlir b/compiler/tests/Dialect/HLFHE/ops.mlir index 41508ee58..8dcc2ea40 100644 --- a/compiler/tests/Dialect/HLFHE/ops.mlir +++ b/compiler/tests/Dialect/HLFHE/ops.mlir @@ -11,6 +11,17 @@ func @add_eint_int(%arg0: !HLFHE.eint<2>) -> !HLFHE.eint<2> { return %1: !HLFHE.eint<2> } +// CHECK-LABEL: func @sub_int_eint(%arg0: !HLFHE.eint<2>) -> !HLFHE.eint<2> +func @sub_int_eint(%arg0: !HLFHE.eint<2>) -> !HLFHE.eint<2> { + // CHECK-NEXT: %[[V1:.*]] = constant 1 : i3 + // CHECK-NEXT: %[[V2:.*]] = "HLFHE.sub_int_eint"(%[[V1]], %arg0) : (i3, !HLFHE.eint<2>) -> !HLFHE.eint<2> + // CHECK-NEXT: return %[[V2]] : !HLFHE.eint<2> + + %0 = constant 1 : i3 + %1 = "HLFHE.sub_int_eint"(%0, %arg0): (i3, !HLFHE.eint<2>) -> (!HLFHE.eint<2>) + return %1: !HLFHE.eint<2> +} + // CHECK-LABEL: func @mul_eint_int(%arg0: !HLFHE.eint<2>) -> !HLFHE.eint<2> func @mul_eint_int(%arg0: !HLFHE.eint<2>) -> !HLFHE.eint<2> { // CHECK-NEXT: %[[V1:.*]] = constant 1 : i3