diff --git a/compiler/include/zamalang/Dialect/MidLFHE/IR/MidLFHEOps.td b/compiler/include/zamalang/Dialect/MidLFHE/IR/MidLFHEOps.td index 76b951d31..dca3d02b8 100644 --- a/compiler/include/zamalang/Dialect/MidLFHE/IR/MidLFHEOps.td +++ b/compiler/include/zamalang/Dialect/MidLFHE/IR/MidLFHEOps.td @@ -50,6 +50,15 @@ def SubIntGLWEOp : MidLFHE_Op<"sub_int_glwe"> { }]; } +def NegGLWEOp : MidLFHE_Op<"neg_glwe"> { + let arguments = (ins GLWECipherTextType:$a); + let results = (outs GLWECipherTextType); + + let verifier = [{ + return ::mlir::zamalang::MidLFHE::verifyUnaryGLWEOperator(*this); + }]; +} + def MulGLWEIntOp : MidLFHE_Op<"mul_glwe_int"> { let arguments = (ins GLWECipherTextType:$a, AnyInteger:$b); diff --git a/compiler/lib/Dialect/MidLFHE/IR/MidLFHEOps.cpp b/compiler/lib/Dialect/MidLFHE/IR/MidLFHEOps.cpp index c282dfc99..9f3924a3c 100644 --- a/compiler/lib/Dialect/MidLFHE/IR/MidLFHEOps.cpp +++ b/compiler/lib/Dialect/MidLFHE/IR/MidLFHEOps.cpp @@ -105,6 +105,36 @@ mlir::LogicalResult verifyBinaryGLWEOperator(Operator &op) { return mlir::success(); } +// verifyUnaryGLWEOperator verify parameters of operators that has the following +// signature (!MidLFHE.glwe<{dim,poly,bits}{p}>) -> +// (!MidLFHE.glwe<{dim,poly,bits}{p}>)) +template +mlir::LogicalResult verifyUnaryGLWEOperator(Operator &op) { + auto a = ((mlir::Type)(op.a().getType())).cast(); + auto result = + ((mlir::Type)(op.getResult().getType())).cast(); + + // verify consistency of a and result GLWE parameter + if (a.getDimension() != result.getDimension()) { + emitOpErrorForIncompatibleGLWEParameter(op, "dimension"); + return mlir::failure(); + } + if (a.getPolynomialSize() != result.getPolynomialSize()) { + emitOpErrorForIncompatibleGLWEParameter(op, "polynomialSize"); + return mlir::failure(); + } + if (a.getBits() != result.getBits()) { + emitOpErrorForIncompatibleGLWEParameter(op, "bits"); + return mlir::failure(); + } + if (a.getP() != result.getP()) { + emitOpErrorForIncompatibleGLWEParameter(op, "p"); + return mlir::failure(); + } + + return mlir::success(); +} + /// verifyApplyLookupTable verify the GLWE parameters follow the rules: /// - The l_cst argument must be a memref of one dimension of size 2^p /// - The lookup table contains integer values of the same width of the output diff --git a/compiler/tests/Dialect/MidLFHE/op_neg_glwe.invalid.mlir b/compiler/tests/Dialect/MidLFHE/op_neg_glwe.invalid.mlir new file mode 100644 index 000000000..00fd720c4 --- /dev/null +++ b/compiler/tests/Dialect/MidLFHE/op_neg_glwe.invalid.mlir @@ -0,0 +1,36 @@ +// RUN: zamacompiler --split-input-file --verify-diagnostics --action=roundtrip %s + +// GLWE p parameter +func @neg_glwe(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{6}> { + // expected-error @+1 {{'MidLFHE.neg_glwe' op should have the same GLWE 'p' parameter}} + %1 = "MidLFHE.neg_glwe"(%arg0): (!MidLFHE.glwe<{1024,12,64}{7}>) -> (!MidLFHE.glwe<{1024,12,64}{6}>) + return %1: !MidLFHE.glwe<{1024,12,64}{6}> +} + +// ----- + +// GLWE dimension parameter +func @neg_glwe(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{512,12,64}{7}> { + // expected-error @+1 {{'MidLFHE.neg_glwe' op should have the same GLWE 'dimension' parameter}} + %1 = "MidLFHE.neg_glwe"(%arg0): (!MidLFHE.glwe<{1024,12,64}{7}>) -> (!MidLFHE.glwe<{512,12,64}{7}>) + return %1: !MidLFHE.glwe<{512,12,64}{7}> +} + +// ----- + +// GLWE polynomialSize parameter +func @neg_glwe(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,11,64}{7}> { + // expected-error @+1 {{'MidLFHE.neg_glwe' op should have the same GLWE 'polynomialSize' parameter}} + %1 = "MidLFHE.neg_glwe"(%arg0): (!MidLFHE.glwe<{1024,12,64}{7}>) -> (!MidLFHE.glwe<{1024,11,64}{7}>) + return %1: !MidLFHE.glwe<{1024,11,64}{7}> +} + +// ----- + +// integer width doesn't match GLWE parameter +func @neg_glwe(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,11,64}{7}> { + // expected-error @+1 {{'MidLFHE.neg_glwe' op should have the same GLWE 'polynomialSize' parameter}} + %1 = "MidLFHE.neg_glwe"(%arg0): (!MidLFHE.glwe<{1024,12,64}{7}>) -> (!MidLFHE.glwe<{1024,11,64}{7}>) + return %1: !MidLFHE.glwe<{1024,11,64}{7}> +} + diff --git a/compiler/tests/Dialect/MidLFHE/op_neg_glwe.mlir b/compiler/tests/Dialect/MidLFHE/op_neg_glwe.mlir new file mode 100644 index 000000000..a47313c62 --- /dev/null +++ b/compiler/tests/Dialect/MidLFHE/op_neg_glwe.mlir @@ -0,0 +1,10 @@ +// RUN: zamacompiler --action=roundtrip %s 2>&1| FileCheck %s + +// CHECK-LABEL: func @neg_glwe(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{7}> +func @neg_glwe(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{7}> { + // CHECK-NEXT: %[[V1:.*]] = "MidLFHE.neg_glwe"(%arg0) : (!MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{7}> + // CHECK-NEXT: return %[[V1]] : !MidLFHE.glwe<{1024,12,64}{7}> + + %1 = "MidLFHE.neg_glwe"(%arg0): (!MidLFHE.glwe<{1024,12,64}{7}>) -> (!MidLFHE.glwe<{1024,12,64}{7}>) + return %1: !MidLFHE.glwe<{1024,12,64}{7}> +}