From efae2a79d76658dd23ba7a6175db495e066c409c Mon Sep 17 00:00:00 2001 From: youben11 Date: Fri, 28 May 2021 09:52:54 +0100 Subject: [PATCH] feat(compiler): h_add and h_mul --- .../zamalang/Dialect/MidLFHE/IR/MidLFHEOps.td | 25 +++++++++++++++ .../tests/Dialect/MidLFHE/op_add_plain.mlir | 32 +++++++++---------- compiler/tests/Dialect/MidLFHE/op_h_add.mlir | 11 +++++++ compiler/tests/Dialect/MidLFHE/op_h_mul.mlir | 11 +++++++ .../tests/Dialect/MidLFHE/op_mul_plain.mlir | 32 +++++++++---------- 5 files changed, 79 insertions(+), 32 deletions(-) create mode 100644 compiler/tests/Dialect/MidLFHE/op_h_add.mlir create mode 100644 compiler/tests/Dialect/MidLFHE/op_h_mul.mlir diff --git a/compiler/include/zamalang/Dialect/MidLFHE/IR/MidLFHEOps.td b/compiler/include/zamalang/Dialect/MidLFHE/IR/MidLFHEOps.td index 5b890c4ee..b759afbbd 100644 --- a/compiler/include/zamalang/Dialect/MidLFHE/IR/MidLFHEOps.td +++ b/compiler/include/zamalang/Dialect/MidLFHE/IR/MidLFHEOps.td @@ -62,4 +62,29 @@ def KeySwitchOp : MidLFHE_Op<"keyswitch"> { let results = (outs CipherTextType); } + +def HAddOp : MidLFHE_Op<"h_add"> { + let arguments = (ins GLWECipherTextType:$a, GLWECipherTextType:$b); + let results = (outs GLWECipherTextType); + + let builders = [ + OpBuilder<(ins "Value":$a, "Value":$b), [{ + build($_builder, $_state, a.getType(), a, b); + }]> + ]; +} + + +def HMulOp : MidLFHE_Op<"h_mul"> { + let arguments = (ins GLWECipherTextType:$a, GLWECipherTextType:$b); + let results = (outs GLWECipherTextType); + + let builders = [ + OpBuilder<(ins "Value":$a, "Value":$b), [{ + build($_builder, $_state, a.getType(), a, b); + }]> + ]; +} + + #endif diff --git a/compiler/tests/Dialect/MidLFHE/op_add_plain.mlir b/compiler/tests/Dialect/MidLFHE/op_add_plain.mlir index db4b31710..fc1946264 100644 --- a/compiler/tests/Dialect/MidLFHE/op_add_plain.mlir +++ b/compiler/tests/Dialect/MidLFHE/op_add_plain.mlir @@ -1,45 +1,45 @@ // RUN: zamacompiler %s 2>&1| FileCheck %s -// CHECK-LABEL: func @mul_plain_ciphertext(%arg0: !MidLFHE.ciphertext) -> !MidLFHE.ciphertext -func @mul_plain_ciphertext(%arg0: !MidLFHE.ciphertext) -> !MidLFHE.ciphertext { +// CHECK-LABEL: func @add_plain_ciphertext(%arg0: !MidLFHE.ciphertext) -> !MidLFHE.ciphertext +func @add_plain_ciphertext(%arg0: !MidLFHE.ciphertext) -> !MidLFHE.ciphertext { // CHECK-NEXT: %[[V1:.*]] = constant 1 : i32 - // CHECK-NEXT: %[[V2:.*]] = "MidLFHE.mul_plain"(%arg0, %[[V1]]) : (!MidLFHE.ciphertext, i32) -> !MidLFHE.ciphertext + // CHECK-NEXT: %[[V2:.*]] = "MidLFHE.add_plain"(%arg0, %[[V1]]) : (!MidLFHE.ciphertext, i32) -> !MidLFHE.ciphertext // CHECK-NEXT: return %[[V2]] : !MidLFHE.ciphertext %0 = constant 1 : i32 - %1 = "MidLFHE.mul_plain"(%arg0, %0): (!MidLFHE.ciphertext, i32) -> (!MidLFHE.ciphertext) + %1 = "MidLFHE.add_plain"(%arg0, %0): (!MidLFHE.ciphertext, i32) -> (!MidLFHE.ciphertext) return %1: !MidLFHE.ciphertext } -// CHECK-LABEL: func @mul_plain_lwe(%arg0: !MidLFHE.lwe<1024>) -> !MidLFHE.lwe<1024> -func @mul_plain_lwe(%arg0: !MidLFHE.lwe<1024>) -> !MidLFHE.lwe<1024> { +// CHECK-LABEL: func @add_plain_lwe(%arg0: !MidLFHE.lwe<1024>) -> !MidLFHE.lwe<1024> +func @add_plain_lwe(%arg0: !MidLFHE.lwe<1024>) -> !MidLFHE.lwe<1024> { // CHECK-NEXT: %[[V1:.*]] = constant 1 : i32 - // CHECK-NEXT: %[[V2:.*]] = "MidLFHE.mul_plain"(%arg0, %[[V1]]) : (!MidLFHE.lwe<1024>, i32) -> !MidLFHE.lwe<1024> + // CHECK-NEXT: %[[V2:.*]] = "MidLFHE.add_plain"(%arg0, %[[V1]]) : (!MidLFHE.lwe<1024>, i32) -> !MidLFHE.lwe<1024> // CHECK-NEXT: return %[[V2]] : !MidLFHE.lwe<1024> %0 = constant 1 : i32 - %1 = "MidLFHE.mul_plain"(%arg0, %0): (!MidLFHE.lwe<1024>, i32) -> (!MidLFHE.lwe<1024>) + %1 = "MidLFHE.add_plain"(%arg0, %0): (!MidLFHE.lwe<1024>, i32) -> (!MidLFHE.lwe<1024>) return %1: !MidLFHE.lwe<1024> } -// CHECK-LABEL: func @mul_plain_glwe(%arg0: !MidLFHE.glwe<1024,12>) -> !MidLFHE.glwe<1024,12> -func @mul_plain_glwe(%arg0: !MidLFHE.glwe<1024,12>) -> !MidLFHE.glwe<1024,12> { +// CHECK-LABEL: func @add_plain_glwe(%arg0: !MidLFHE.glwe<1024,12>) -> !MidLFHE.glwe<1024,12> +func @add_plain_glwe(%arg0: !MidLFHE.glwe<1024,12>) -> !MidLFHE.glwe<1024,12> { // CHECK-NEXT: %[[V1:.*]] = constant 1 : i32 - // CHECK-NEXT: %[[V2:.*]] = "MidLFHE.mul_plain"(%arg0, %[[V1]]) : (!MidLFHE.glwe<1024,12>, i32) -> !MidLFHE.glwe<1024,12> + // CHECK-NEXT: %[[V2:.*]] = "MidLFHE.add_plain"(%arg0, %[[V1]]) : (!MidLFHE.glwe<1024,12>, i32) -> !MidLFHE.glwe<1024,12> // CHECK-NEXT: return %[[V2]] : !MidLFHE.glwe<1024,12> %0 = constant 1 : i32 - %1 = "MidLFHE.mul_plain"(%arg0, %0): (!MidLFHE.glwe<1024,12>, i32) -> (!MidLFHE.glwe<1024,12>) + %1 = "MidLFHE.add_plain"(%arg0, %0): (!MidLFHE.glwe<1024,12>, i32) -> (!MidLFHE.glwe<1024,12>) return %1: !MidLFHE.glwe<1024,12> } -// CHECK-LABEL: func @mul_plain_ggsw(%arg0: !MidLFHE.ggsw<1024,12,3,2>) -> !MidLFHE.ggsw<1024,12,3,2> -func @mul_plain_ggsw(%arg0: !MidLFHE.ggsw<1024,12,3,2>) -> !MidLFHE.ggsw<1024,12,3,2> { +// CHECK-LABEL: func @add_plain_ggsw(%arg0: !MidLFHE.ggsw<1024,12,3,2>) -> !MidLFHE.ggsw<1024,12,3,2> +func @add_plain_ggsw(%arg0: !MidLFHE.ggsw<1024,12,3,2>) -> !MidLFHE.ggsw<1024,12,3,2> { // CHECK-NEXT: %[[V1:.*]] = constant 1 : i32 - // CHECK-NEXT: %[[V2:.*]] = "MidLFHE.mul_plain"(%arg0, %[[V1]]) : (!MidLFHE.ggsw<1024,12,3,2>, i32) -> !MidLFHE.ggsw<1024,12,3,2> + // CHECK-NEXT: %[[V2:.*]] = "MidLFHE.add_plain"(%arg0, %[[V1]]) : (!MidLFHE.ggsw<1024,12,3,2>, i32) -> !MidLFHE.ggsw<1024,12,3,2> // CHECK-NEXT: return %[[V2]] : !MidLFHE.ggsw<1024,12,3,2> %0 = constant 1 : i32 - %1 = "MidLFHE.mul_plain"(%arg0, %0): (!MidLFHE.ggsw<1024,12,3,2>, i32) -> (!MidLFHE.ggsw<1024,12,3,2>) + %1 = "MidLFHE.add_plain"(%arg0, %0): (!MidLFHE.ggsw<1024,12,3,2>, i32) -> (!MidLFHE.ggsw<1024,12,3,2>) return %1: !MidLFHE.ggsw<1024,12,3,2> } \ No newline at end of file diff --git a/compiler/tests/Dialect/MidLFHE/op_h_add.mlir b/compiler/tests/Dialect/MidLFHE/op_h_add.mlir new file mode 100644 index 000000000..3c90c460b --- /dev/null +++ b/compiler/tests/Dialect/MidLFHE/op_h_add.mlir @@ -0,0 +1,11 @@ +// RUN: zamacompiler %s 2>&1| FileCheck %s + + +// CHECK-LABEL: func @add_plain_glwe(%arg0: !MidLFHE.glwe<1024,12>, %arg1: !MidLFHE.glwe<1024,12>) -> !MidLFHE.glwe<1024,12> +func @add_plain_glwe(%arg0: !MidLFHE.glwe<1024,12>, %arg1: !MidLFHE.glwe<1024,12>) -> !MidLFHE.glwe<1024,12> { + // CHECK-NEXT: %[[V1:.*]] = "MidLFHE.h_mul"(%arg0, %arg1) : (!MidLFHE.glwe<1024,12>, !MidLFHE.glwe<1024,12>) -> !MidLFHE.glwe<1024,12> + // CHECK-NEXT: return %[[V1]] : !MidLFHE.glwe<1024,12> + + %0 = "MidLFHE.h_mul"(%arg0, %arg1): (!MidLFHE.glwe<1024,12>, !MidLFHE.glwe<1024,12>) -> (!MidLFHE.glwe<1024,12>) + return %0: !MidLFHE.glwe<1024,12> +} diff --git a/compiler/tests/Dialect/MidLFHE/op_h_mul.mlir b/compiler/tests/Dialect/MidLFHE/op_h_mul.mlir new file mode 100644 index 000000000..cdc1f716f --- /dev/null +++ b/compiler/tests/Dialect/MidLFHE/op_h_mul.mlir @@ -0,0 +1,11 @@ +// RUN: zamacompiler %s 2>&1| FileCheck %s + + +// CHECK-LABEL: func @add_plain_glwe(%arg0: !MidLFHE.glwe<1024,12>, %arg1: !MidLFHE.glwe<1024,12>) -> !MidLFHE.glwe<1024,12> +func @add_plain_glwe(%arg0: !MidLFHE.glwe<1024,12>, %arg1: !MidLFHE.glwe<1024,12>) -> !MidLFHE.glwe<1024,12> { + // CHECK-NEXT: %[[V1:.*]] = "MidLFHE.h_add"(%arg0, %arg1) : (!MidLFHE.glwe<1024,12>, !MidLFHE.glwe<1024,12>) -> !MidLFHE.glwe<1024,12> + // CHECK-NEXT: return %[[V1]] : !MidLFHE.glwe<1024,12> + + %0 = "MidLFHE.h_add"(%arg0, %arg1): (!MidLFHE.glwe<1024,12>, !MidLFHE.glwe<1024,12>) -> (!MidLFHE.glwe<1024,12>) + return %0: !MidLFHE.glwe<1024,12> +} diff --git a/compiler/tests/Dialect/MidLFHE/op_mul_plain.mlir b/compiler/tests/Dialect/MidLFHE/op_mul_plain.mlir index fc1946264..db4b31710 100644 --- a/compiler/tests/Dialect/MidLFHE/op_mul_plain.mlir +++ b/compiler/tests/Dialect/MidLFHE/op_mul_plain.mlir @@ -1,45 +1,45 @@ // RUN: zamacompiler %s 2>&1| FileCheck %s -// CHECK-LABEL: func @add_plain_ciphertext(%arg0: !MidLFHE.ciphertext) -> !MidLFHE.ciphertext -func @add_plain_ciphertext(%arg0: !MidLFHE.ciphertext) -> !MidLFHE.ciphertext { +// CHECK-LABEL: func @mul_plain_ciphertext(%arg0: !MidLFHE.ciphertext) -> !MidLFHE.ciphertext +func @mul_plain_ciphertext(%arg0: !MidLFHE.ciphertext) -> !MidLFHE.ciphertext { // CHECK-NEXT: %[[V1:.*]] = constant 1 : i32 - // CHECK-NEXT: %[[V2:.*]] = "MidLFHE.add_plain"(%arg0, %[[V1]]) : (!MidLFHE.ciphertext, i32) -> !MidLFHE.ciphertext + // CHECK-NEXT: %[[V2:.*]] = "MidLFHE.mul_plain"(%arg0, %[[V1]]) : (!MidLFHE.ciphertext, i32) -> !MidLFHE.ciphertext // CHECK-NEXT: return %[[V2]] : !MidLFHE.ciphertext %0 = constant 1 : i32 - %1 = "MidLFHE.add_plain"(%arg0, %0): (!MidLFHE.ciphertext, i32) -> (!MidLFHE.ciphertext) + %1 = "MidLFHE.mul_plain"(%arg0, %0): (!MidLFHE.ciphertext, i32) -> (!MidLFHE.ciphertext) return %1: !MidLFHE.ciphertext } -// CHECK-LABEL: func @add_plain_lwe(%arg0: !MidLFHE.lwe<1024>) -> !MidLFHE.lwe<1024> -func @add_plain_lwe(%arg0: !MidLFHE.lwe<1024>) -> !MidLFHE.lwe<1024> { +// CHECK-LABEL: func @mul_plain_lwe(%arg0: !MidLFHE.lwe<1024>) -> !MidLFHE.lwe<1024> +func @mul_plain_lwe(%arg0: !MidLFHE.lwe<1024>) -> !MidLFHE.lwe<1024> { // CHECK-NEXT: %[[V1:.*]] = constant 1 : i32 - // CHECK-NEXT: %[[V2:.*]] = "MidLFHE.add_plain"(%arg0, %[[V1]]) : (!MidLFHE.lwe<1024>, i32) -> !MidLFHE.lwe<1024> + // CHECK-NEXT: %[[V2:.*]] = "MidLFHE.mul_plain"(%arg0, %[[V1]]) : (!MidLFHE.lwe<1024>, i32) -> !MidLFHE.lwe<1024> // CHECK-NEXT: return %[[V2]] : !MidLFHE.lwe<1024> %0 = constant 1 : i32 - %1 = "MidLFHE.add_plain"(%arg0, %0): (!MidLFHE.lwe<1024>, i32) -> (!MidLFHE.lwe<1024>) + %1 = "MidLFHE.mul_plain"(%arg0, %0): (!MidLFHE.lwe<1024>, i32) -> (!MidLFHE.lwe<1024>) return %1: !MidLFHE.lwe<1024> } -// CHECK-LABEL: func @add_plain_glwe(%arg0: !MidLFHE.glwe<1024,12>) -> !MidLFHE.glwe<1024,12> -func @add_plain_glwe(%arg0: !MidLFHE.glwe<1024,12>) -> !MidLFHE.glwe<1024,12> { +// CHECK-LABEL: func @mul_plain_glwe(%arg0: !MidLFHE.glwe<1024,12>) -> !MidLFHE.glwe<1024,12> +func @mul_plain_glwe(%arg0: !MidLFHE.glwe<1024,12>) -> !MidLFHE.glwe<1024,12> { // CHECK-NEXT: %[[V1:.*]] = constant 1 : i32 - // CHECK-NEXT: %[[V2:.*]] = "MidLFHE.add_plain"(%arg0, %[[V1]]) : (!MidLFHE.glwe<1024,12>, i32) -> !MidLFHE.glwe<1024,12> + // CHECK-NEXT: %[[V2:.*]] = "MidLFHE.mul_plain"(%arg0, %[[V1]]) : (!MidLFHE.glwe<1024,12>, i32) -> !MidLFHE.glwe<1024,12> // CHECK-NEXT: return %[[V2]] : !MidLFHE.glwe<1024,12> %0 = constant 1 : i32 - %1 = "MidLFHE.add_plain"(%arg0, %0): (!MidLFHE.glwe<1024,12>, i32) -> (!MidLFHE.glwe<1024,12>) + %1 = "MidLFHE.mul_plain"(%arg0, %0): (!MidLFHE.glwe<1024,12>, i32) -> (!MidLFHE.glwe<1024,12>) return %1: !MidLFHE.glwe<1024,12> } -// CHECK-LABEL: func @add_plain_ggsw(%arg0: !MidLFHE.ggsw<1024,12,3,2>) -> !MidLFHE.ggsw<1024,12,3,2> -func @add_plain_ggsw(%arg0: !MidLFHE.ggsw<1024,12,3,2>) -> !MidLFHE.ggsw<1024,12,3,2> { +// CHECK-LABEL: func @mul_plain_ggsw(%arg0: !MidLFHE.ggsw<1024,12,3,2>) -> !MidLFHE.ggsw<1024,12,3,2> +func @mul_plain_ggsw(%arg0: !MidLFHE.ggsw<1024,12,3,2>) -> !MidLFHE.ggsw<1024,12,3,2> { // CHECK-NEXT: %[[V1:.*]] = constant 1 : i32 - // CHECK-NEXT: %[[V2:.*]] = "MidLFHE.add_plain"(%arg0, %[[V1]]) : (!MidLFHE.ggsw<1024,12,3,2>, i32) -> !MidLFHE.ggsw<1024,12,3,2> + // CHECK-NEXT: %[[V2:.*]] = "MidLFHE.mul_plain"(%arg0, %[[V1]]) : (!MidLFHE.ggsw<1024,12,3,2>, i32) -> !MidLFHE.ggsw<1024,12,3,2> // CHECK-NEXT: return %[[V2]] : !MidLFHE.ggsw<1024,12,3,2> %0 = constant 1 : i32 - %1 = "MidLFHE.add_plain"(%arg0, %0): (!MidLFHE.ggsw<1024,12,3,2>, i32) -> (!MidLFHE.ggsw<1024,12,3,2>) + %1 = "MidLFHE.mul_plain"(%arg0, %0): (!MidLFHE.ggsw<1024,12,3,2>, i32) -> (!MidLFHE.ggsw<1024,12,3,2>) return %1: !MidLFHE.ggsw<1024,12,3,2> } \ No newline at end of file