From 8edf4a358e954171df57f8225f306cd96e502528 Mon Sep 17 00:00:00 2001 From: youben11 Date: Thu, 27 May 2021 16:32:34 +0100 Subject: [PATCH] feat(compiler): keyswitching types and ops --- .../zamalang/Dialect/MidLFHE/IR/MidLFHEOps.td | 6 ++++++ .../Dialect/MidLFHE/IR/MidLFHETypes.td | 21 +++++++++++++++++++ .../lib/Dialect/MidLFHE/IR/MidLFHEDialect.cpp | 7 +++++++ compiler/tests/Dialect/MidLFHE/op_ks.mlir | 9 ++++++++ 4 files changed, 43 insertions(+) create mode 100644 compiler/tests/Dialect/MidLFHE/op_ks.mlir diff --git a/compiler/include/zamalang/Dialect/MidLFHE/IR/MidLFHEOps.td b/compiler/include/zamalang/Dialect/MidLFHE/IR/MidLFHEOps.td index 1b866aabd..5b890c4ee 100644 --- a/compiler/include/zamalang/Dialect/MidLFHE/IR/MidLFHEOps.td +++ b/compiler/include/zamalang/Dialect/MidLFHE/IR/MidLFHEOps.td @@ -56,4 +56,10 @@ def ReturnOp : MidLFHE_Op<"pbs_return", [NoSideEffect, ReturnLike, Terminator]> let builders = [OpBuilder<(ins), [{ /* nothing to do */ }]>]; } + +def KeySwitchOp : MidLFHE_Op<"keyswitch"> { + let arguments = (ins KeySwitchingKeyType:$ks, CipherTextType:$ct, I32Attr:$base_log, I32Attr:$level); + let results = (outs CipherTextType); +} + #endif diff --git a/compiler/include/zamalang/Dialect/MidLFHE/IR/MidLFHETypes.td b/compiler/include/zamalang/Dialect/MidLFHE/IR/MidLFHETypes.td index 49f7fe6a6..3bce1fc53 100644 --- a/compiler/include/zamalang/Dialect/MidLFHE/IR/MidLFHETypes.td +++ b/compiler/include/zamalang/Dialect/MidLFHE/IR/MidLFHETypes.td @@ -134,6 +134,27 @@ def AnyCipherTextType : MidLFHE_Type<"AnyCipherText"> { }]; } +def KeySwitchingKeyType : MidLFHE_Type<"KeySwitchingKey"> { + let mnemonic = "ksk"; + + let summary = "A KeySwitching key"; + + let description = [{ + A KeySwitching key + }]; + + // We define the printer inline. + let printer = [{ + $_printer << "ksk"; + }]; + + // The parser is defined here also. + let parser = [{ + return get($_ctxt); + }]; +} + + def CipherTextType: TypeConstraintgetContext(), parser); if(parser.parseOptionalKeyword("ciphertext").succeeded()) return AnyCipherTextType::parse(this->getContext(), parser); + if(parser.parseOptionalKeyword("ksk").succeeded()) + return KeySwitchingKeyType::parse(this->getContext(), parser); parser.emitError(parser.getCurrentLocation(), "Unknown MidLFHE type"); return ::mlir::Type(); } @@ -51,6 +53,11 @@ void MidLFHEDialect::printType(::mlir::Type type, ggsw.print(printer); return; } + mlir::zamalang::MidLFHE::KeySwitchingKeyType ksk = type.dyn_cast_or_null(); + if (ksk != nullptr) { + ksk.print(printer); + return; + } mlir::zamalang::MidLFHE::AnyCipherTextType any = type.dyn_cast_or_null(); if (any != nullptr) { any.print(printer); diff --git a/compiler/tests/Dialect/MidLFHE/op_ks.mlir b/compiler/tests/Dialect/MidLFHE/op_ks.mlir new file mode 100644 index 000000000..7f1651884 --- /dev/null +++ b/compiler/tests/Dialect/MidLFHE/op_ks.mlir @@ -0,0 +1,9 @@ +// RUN: zamacompiler %s 2>&1| FileCheck %s + +// CHECK-LABEL: func @keyswitch(%arg0: !MidLFHE.ksk, %arg1: !MidLFHE.ciphertext) -> !MidLFHE.ciphertext { +func @keyswitch(%arg0: !MidLFHE.ksk, %arg1: !MidLFHE.ciphertext) -> !MidLFHE.ciphertext { + // CHECK-NEXT: %[[V1:.*]] = "MidLFHE.keyswitch"(%arg0, %arg1) {base_log = 8 : i32, level = 2 : i32} : (!MidLFHE.ksk, !MidLFHE.ciphertext) -> !MidLFHE.ciphertext + %0 = "MidLFHE.keyswitch"(%arg0, %arg1) {base_log = 8 : i32, level = 2 : i32} : (!MidLFHE.ksk, !MidLFHE.ciphertext) -> !MidLFHE.ciphertext + // CHECK-NEXT: return %[[V1]] : !MidLFHE.ciphertext + return %0 : !MidLFHE.ciphertext +} \ No newline at end of file