diff --git a/compiler/include/zamalang/Dialect/HLFHE/IR/HLFHEOps.td b/compiler/include/zamalang/Dialect/HLFHE/IR/HLFHEOps.td index bcbce40aa..c893d7657 100644 --- a/compiler/include/zamalang/Dialect/HLFHE/IR/HLFHEOps.td +++ b/compiler/include/zamalang/Dialect/HLFHE/IR/HLFHEOps.td @@ -69,6 +69,21 @@ def SubIntEintOp : HLFHE_Op<"sub_int_eint"> { }]; } +def NegEintOp : HLFHE_Op<"neg_eint"> { + let arguments = (ins EncryptedIntegerType:$a); + let results = (outs EncryptedIntegerType); + + let builders = [ + OpBuilder<(ins "Value":$a), [{ + build($_builder, $_state, a.getType(), a); + }]> + ]; + + let verifier = [{ + return ::mlir::zamalang::HLFHE::verifyNegEintOp(*this); + }]; +} + def MulEintIntOp : HLFHE_Op<"mul_eint_int"> { let arguments = (ins EncryptedIntegerType:$a, AnyInteger:$b); let results = (outs EncryptedIntegerType); diff --git a/compiler/lib/Dialect/HLFHE/IR/HLFHEOps.cpp b/compiler/lib/Dialect/HLFHE/IR/HLFHEOps.cpp index ffcdc47fc..54decddf9 100644 --- a/compiler/lib/Dialect/HLFHE/IR/HLFHEOps.cpp +++ b/compiler/lib/Dialect/HLFHE/IR/HLFHEOps.cpp @@ -79,6 +79,15 @@ bool verifyEncryptedIntegerInputsConsistency(::mlir::OpState &op, return ::mlir::success(); } +::mlir::LogicalResult verifyNegEintOp(NegEintOp &op) { + auto a = op.a().getType().cast(); + auto out = op.getResult().getType().cast(); + if (!verifyEncryptedIntegerInputAndResultConsistency(op, a, out)) { + return ::mlir::failure(); + } + return ::mlir::success(); +} + ::mlir::LogicalResult verifyMulEintIntOp(MulEintIntOp &op) { auto a = op.a().getType().cast(); auto b = op.b().getType().cast(); diff --git a/compiler/tests/Dialect/HLFHE/op_neg_eint_err_result.mlir b/compiler/tests/Dialect/HLFHE/op_neg_eint_err_result.mlir new file mode 100644 index 000000000..84af2b155 --- /dev/null +++ b/compiler/tests/Dialect/HLFHE/op_neg_eint_err_result.mlir @@ -0,0 +1,7 @@ +// RUN: not zamacompiler --action=roundtrip %s 2>&1| FileCheck %s + +// CHECK-LABEL: error: 'HLFHE.neg_eint' op should have the width of encrypted inputs and result equals +func @sub_int_eint(%arg0: !HLFHE.eint<2>) -> !HLFHE.eint<3> { + %1 = "HLFHE.neg_eint"(%arg0): (!HLFHE.eint<2>) -> (!HLFHE.eint<3>) + return %1: !HLFHE.eint<3> +} diff --git a/compiler/tests/Dialect/HLFHE/ops.mlir b/compiler/tests/Dialect/HLFHE/ops.mlir index f5fb96dd3..a8c45b600 100644 --- a/compiler/tests/Dialect/HLFHE/ops.mlir +++ b/compiler/tests/Dialect/HLFHE/ops.mlir @@ -31,6 +31,15 @@ func @sub_int_eint(%arg0: !HLFHE.eint<2>) -> !HLFHE.eint<2> { return %1: !HLFHE.eint<2> } +// CHECK-LABEL: func @neg_eint(%arg0: !HLFHE.eint<2>) -> !HLFHE.eint<2> +func @neg_eint(%arg0: !HLFHE.eint<2>) -> !HLFHE.eint<2> { + // CHECK-NEXT: %[[V1:.*]] = "HLFHE.neg_eint"(%arg0) : (!HLFHE.eint<2>) -> !HLFHE.eint<2> + // CHECK-NEXT: return %[[V1]] : !HLFHE.eint<2> + + %1 = "HLFHE.neg_eint"(%arg0): (!HLFHE.eint<2>) -> (!HLFHE.eint<2>) + return %1: !HLFHE.eint<2> +} + // CHECK-LABEL: func @mul_eint_int(%arg0: !HLFHE.eint<2>) -> !HLFHE.eint<2> func @mul_eint_int(%arg0: !HLFHE.eint<2>) -> !HLFHE.eint<2> { // CHECK-NEXT: %[[V1:.*]] = arith.constant 1 : i3