diff --git a/compiler/include/zamalang/Dialect/MidLFHE/IR/MidLFHEOps.td b/compiler/include/zamalang/Dialect/MidLFHE/IR/MidLFHEOps.td index 0273e0380..d7d5c47e4 100644 --- a/compiler/include/zamalang/Dialect/MidLFHE/IR/MidLFHEOps.td +++ b/compiler/include/zamalang/Dialect/MidLFHE/IR/MidLFHEOps.td @@ -29,6 +29,17 @@ def AddPlainOp : MidLFHE_Op<"add_plain"> { ]; } +def MulPlainOp : MidLFHE_Op<"mul_plain"> { + let arguments = (ins CipherTextType:$a, AnyInteger:$b); + let results = (outs CipherTextType); + + let builders = [ + OpBuilder<(ins "Value":$a, "Value":$b), [{ + build($_builder, $_state, a.getType(), a, b); + }]> + ]; +} + def PBSRegion : Region< CPred<"::mlir::zamalang::predPBSRegion($_self)">, "pbs region needs one block with one any integer argument">; diff --git a/compiler/tests/Dialect/MidLFHE/op_add_plain.mlir b/compiler/tests/Dialect/MidLFHE/op_add_plain.mlir index fc1946264..db4b31710 100644 --- a/compiler/tests/Dialect/MidLFHE/op_add_plain.mlir +++ b/compiler/tests/Dialect/MidLFHE/op_add_plain.mlir @@ -1,45 +1,45 @@ // RUN: zamacompiler %s 2>&1| FileCheck %s -// CHECK-LABEL: func @add_plain_ciphertext(%arg0: !MidLFHE.ciphertext) -> !MidLFHE.ciphertext -func @add_plain_ciphertext(%arg0: !MidLFHE.ciphertext) -> !MidLFHE.ciphertext { +// CHECK-LABEL: func @mul_plain_ciphertext(%arg0: !MidLFHE.ciphertext) -> !MidLFHE.ciphertext +func @mul_plain_ciphertext(%arg0: !MidLFHE.ciphertext) -> !MidLFHE.ciphertext { // CHECK-NEXT: %[[V1:.*]] = constant 1 : i32 - // CHECK-NEXT: %[[V2:.*]] = "MidLFHE.add_plain"(%arg0, %[[V1]]) : (!MidLFHE.ciphertext, i32) -> !MidLFHE.ciphertext + // CHECK-NEXT: %[[V2:.*]] = "MidLFHE.mul_plain"(%arg0, %[[V1]]) : (!MidLFHE.ciphertext, i32) -> !MidLFHE.ciphertext // CHECK-NEXT: return %[[V2]] : !MidLFHE.ciphertext %0 = constant 1 : i32 - %1 = "MidLFHE.add_plain"(%arg0, %0): (!MidLFHE.ciphertext, i32) -> (!MidLFHE.ciphertext) + %1 = "MidLFHE.mul_plain"(%arg0, %0): (!MidLFHE.ciphertext, i32) -> (!MidLFHE.ciphertext) return %1: !MidLFHE.ciphertext } -// CHECK-LABEL: func @add_plain_lwe(%arg0: !MidLFHE.lwe<1024>) -> !MidLFHE.lwe<1024> -func @add_plain_lwe(%arg0: !MidLFHE.lwe<1024>) -> !MidLFHE.lwe<1024> { +// CHECK-LABEL: func @mul_plain_lwe(%arg0: !MidLFHE.lwe<1024>) -> !MidLFHE.lwe<1024> +func @mul_plain_lwe(%arg0: !MidLFHE.lwe<1024>) -> !MidLFHE.lwe<1024> { // CHECK-NEXT: %[[V1:.*]] = constant 1 : i32 - // CHECK-NEXT: %[[V2:.*]] = "MidLFHE.add_plain"(%arg0, %[[V1]]) : (!MidLFHE.lwe<1024>, i32) -> !MidLFHE.lwe<1024> + // CHECK-NEXT: %[[V2:.*]] = "MidLFHE.mul_plain"(%arg0, %[[V1]]) : (!MidLFHE.lwe<1024>, i32) -> !MidLFHE.lwe<1024> // CHECK-NEXT: return %[[V2]] : !MidLFHE.lwe<1024> %0 = constant 1 : i32 - %1 = "MidLFHE.add_plain"(%arg0, %0): (!MidLFHE.lwe<1024>, i32) -> (!MidLFHE.lwe<1024>) + %1 = "MidLFHE.mul_plain"(%arg0, %0): (!MidLFHE.lwe<1024>, i32) -> (!MidLFHE.lwe<1024>) return %1: !MidLFHE.lwe<1024> } -// CHECK-LABEL: func @add_plain_glwe(%arg0: !MidLFHE.glwe<1024,12>) -> !MidLFHE.glwe<1024,12> -func @add_plain_glwe(%arg0: !MidLFHE.glwe<1024,12>) -> !MidLFHE.glwe<1024,12> { +// CHECK-LABEL: func @mul_plain_glwe(%arg0: !MidLFHE.glwe<1024,12>) -> !MidLFHE.glwe<1024,12> +func @mul_plain_glwe(%arg0: !MidLFHE.glwe<1024,12>) -> !MidLFHE.glwe<1024,12> { // CHECK-NEXT: %[[V1:.*]] = constant 1 : i32 - // CHECK-NEXT: %[[V2:.*]] = "MidLFHE.add_plain"(%arg0, %[[V1]]) : (!MidLFHE.glwe<1024,12>, i32) -> !MidLFHE.glwe<1024,12> + // CHECK-NEXT: %[[V2:.*]] = "MidLFHE.mul_plain"(%arg0, %[[V1]]) : (!MidLFHE.glwe<1024,12>, i32) -> !MidLFHE.glwe<1024,12> // CHECK-NEXT: return %[[V2]] : !MidLFHE.glwe<1024,12> %0 = constant 1 : i32 - %1 = "MidLFHE.add_plain"(%arg0, %0): (!MidLFHE.glwe<1024,12>, i32) -> (!MidLFHE.glwe<1024,12>) + %1 = "MidLFHE.mul_plain"(%arg0, %0): (!MidLFHE.glwe<1024,12>, i32) -> (!MidLFHE.glwe<1024,12>) return %1: !MidLFHE.glwe<1024,12> } -// CHECK-LABEL: func @add_plain_ggsw(%arg0: !MidLFHE.ggsw<1024,12,3,2>) -> !MidLFHE.ggsw<1024,12,3,2> -func @add_plain_ggsw(%arg0: !MidLFHE.ggsw<1024,12,3,2>) -> !MidLFHE.ggsw<1024,12,3,2> { +// CHECK-LABEL: func @mul_plain_ggsw(%arg0: !MidLFHE.ggsw<1024,12,3,2>) -> !MidLFHE.ggsw<1024,12,3,2> +func @mul_plain_ggsw(%arg0: !MidLFHE.ggsw<1024,12,3,2>) -> !MidLFHE.ggsw<1024,12,3,2> { // CHECK-NEXT: %[[V1:.*]] = constant 1 : i32 - // CHECK-NEXT: %[[V2:.*]] = "MidLFHE.add_plain"(%arg0, %[[V1]]) : (!MidLFHE.ggsw<1024,12,3,2>, i32) -> !MidLFHE.ggsw<1024,12,3,2> + // CHECK-NEXT: %[[V2:.*]] = "MidLFHE.mul_plain"(%arg0, %[[V1]]) : (!MidLFHE.ggsw<1024,12,3,2>, i32) -> !MidLFHE.ggsw<1024,12,3,2> // CHECK-NEXT: return %[[V2]] : !MidLFHE.ggsw<1024,12,3,2> %0 = constant 1 : i32 - %1 = "MidLFHE.add_plain"(%arg0, %0): (!MidLFHE.ggsw<1024,12,3,2>, i32) -> (!MidLFHE.ggsw<1024,12,3,2>) + %1 = "MidLFHE.mul_plain"(%arg0, %0): (!MidLFHE.ggsw<1024,12,3,2>, i32) -> (!MidLFHE.ggsw<1024,12,3,2>) return %1: !MidLFHE.ggsw<1024,12,3,2> } \ No newline at end of file diff --git a/compiler/tests/Dialect/MidLFHE/op_mul_plain.mlir b/compiler/tests/Dialect/MidLFHE/op_mul_plain.mlir new file mode 100644 index 000000000..fc1946264 --- /dev/null +++ b/compiler/tests/Dialect/MidLFHE/op_mul_plain.mlir @@ -0,0 +1,45 @@ +// RUN: zamacompiler %s 2>&1| FileCheck %s + +// CHECK-LABEL: func @add_plain_ciphertext(%arg0: !MidLFHE.ciphertext) -> !MidLFHE.ciphertext +func @add_plain_ciphertext(%arg0: !MidLFHE.ciphertext) -> !MidLFHE.ciphertext { + // CHECK-NEXT: %[[V1:.*]] = constant 1 : i32 + // CHECK-NEXT: %[[V2:.*]] = "MidLFHE.add_plain"(%arg0, %[[V1]]) : (!MidLFHE.ciphertext, i32) -> !MidLFHE.ciphertext + // CHECK-NEXT: return %[[V2]] : !MidLFHE.ciphertext + + %0 = constant 1 : i32 + %1 = "MidLFHE.add_plain"(%arg0, %0): (!MidLFHE.ciphertext, i32) -> (!MidLFHE.ciphertext) + return %1: !MidLFHE.ciphertext +} + +// CHECK-LABEL: func @add_plain_lwe(%arg0: !MidLFHE.lwe<1024>) -> !MidLFHE.lwe<1024> +func @add_plain_lwe(%arg0: !MidLFHE.lwe<1024>) -> !MidLFHE.lwe<1024> { + // CHECK-NEXT: %[[V1:.*]] = constant 1 : i32 + // CHECK-NEXT: %[[V2:.*]] = "MidLFHE.add_plain"(%arg0, %[[V1]]) : (!MidLFHE.lwe<1024>, i32) -> !MidLFHE.lwe<1024> + // CHECK-NEXT: return %[[V2]] : !MidLFHE.lwe<1024> + + %0 = constant 1 : i32 + %1 = "MidLFHE.add_plain"(%arg0, %0): (!MidLFHE.lwe<1024>, i32) -> (!MidLFHE.lwe<1024>) + return %1: !MidLFHE.lwe<1024> +} + +// CHECK-LABEL: func @add_plain_glwe(%arg0: !MidLFHE.glwe<1024,12>) -> !MidLFHE.glwe<1024,12> +func @add_plain_glwe(%arg0: !MidLFHE.glwe<1024,12>) -> !MidLFHE.glwe<1024,12> { + // CHECK-NEXT: %[[V1:.*]] = constant 1 : i32 + // CHECK-NEXT: %[[V2:.*]] = "MidLFHE.add_plain"(%arg0, %[[V1]]) : (!MidLFHE.glwe<1024,12>, i32) -> !MidLFHE.glwe<1024,12> + // CHECK-NEXT: return %[[V2]] : !MidLFHE.glwe<1024,12> + + %0 = constant 1 : i32 + %1 = "MidLFHE.add_plain"(%arg0, %0): (!MidLFHE.glwe<1024,12>, i32) -> (!MidLFHE.glwe<1024,12>) + return %1: !MidLFHE.glwe<1024,12> +} + +// CHECK-LABEL: func @add_plain_ggsw(%arg0: !MidLFHE.ggsw<1024,12,3,2>) -> !MidLFHE.ggsw<1024,12,3,2> +func @add_plain_ggsw(%arg0: !MidLFHE.ggsw<1024,12,3,2>) -> !MidLFHE.ggsw<1024,12,3,2> { + // CHECK-NEXT: %[[V1:.*]] = constant 1 : i32 + // CHECK-NEXT: %[[V2:.*]] = "MidLFHE.add_plain"(%arg0, %[[V1]]) : (!MidLFHE.ggsw<1024,12,3,2>, i32) -> !MidLFHE.ggsw<1024,12,3,2> + // CHECK-NEXT: return %[[V2]] : !MidLFHE.ggsw<1024,12,3,2> + + %0 = constant 1 : i32 + %1 = "MidLFHE.add_plain"(%arg0, %0): (!MidLFHE.ggsw<1024,12,3,2>, i32) -> (!MidLFHE.ggsw<1024,12,3,2>) + return %1: !MidLFHE.ggsw<1024,12,3,2> +} \ No newline at end of file