From ce69aab7562f0b0ff4c53cb3971add6e190a1478 Mon Sep 17 00:00:00 2001 From: Quentin Bourgerie Date: Tue, 25 May 2021 10:46:34 +0200 Subject: [PATCH] feat(compiler): Add !MidLFHE.ggsw --- .../Dialect/MidLFHE/IR/MidLFHETypes.td | 44 +++++++++++++++++++ .../lib/Dialect/MidLFHE/IR/MidLFHEDialect.cpp | 12 +++-- .../tests/Dialect/MidLFHE/types_ggsw.mlir | 7 +++ 3 files changed, 60 insertions(+), 3 deletions(-) create mode 100644 compiler/tests/Dialect/MidLFHE/types_ggsw.mlir diff --git a/compiler/include/zamalang/Dialect/MidLFHE/IR/MidLFHETypes.td b/compiler/include/zamalang/Dialect/MidLFHE/IR/MidLFHETypes.td index b458d462d..1ab1162df 100644 --- a/compiler/include/zamalang/Dialect/MidLFHE/IR/MidLFHETypes.td +++ b/compiler/include/zamalang/Dialect/MidLFHE/IR/MidLFHETypes.td @@ -68,4 +68,48 @@ def GLWECipherTextType : MidLFHE_Type<"GLWECipherText"> { }]; } +def GGSWCipherTextType : MidLFHE_Type<"GGSWCipherText"> { + let mnemonic = "ggsw"; + + let summary = "An GGSW cipher text"; + + let description = [{ + An GGSW cipher text + }]; + + let parameters = (ins "unsigned":$size, "unsigned":$N, "unsigned": $level, "unsigned":$base_log); + + // We define the printer inline. + let printer = [{ + $_printer << "ggsw<" << getImpl()->size << "," << getImpl()->N << "," << getImpl()->level << "," << getImpl()->base_log << ">"; + }]; + + // The parser is defined here also. + let parser = [{ + if ($_parser.parseLess()) + return Type(); + int size; + if ($_parser.parseInteger(size)) + return Type(); + if ($_parser.parseComma()) + return Type(); + int N; + if ($_parser.parseInteger(N)) + return Type(); + if ($_parser.parseComma()) + return Type(); + int level; + if ($_parser.parseInteger(level)) + return Type(); + if ($_parser.parseComma()) + return Type(); + int base_log; + if ($_parser.parseInteger(base_log)) + return Type(); + if ($_parser.parseGreater()) + return Type(); + return get($_ctxt, size, N, level, base_log); + }]; +} + #endif diff --git a/compiler/lib/Dialect/MidLFHE/IR/MidLFHEDialect.cpp b/compiler/lib/Dialect/MidLFHE/IR/MidLFHEDialect.cpp index 1425bd432..036b1ff0a 100644 --- a/compiler/lib/Dialect/MidLFHE/IR/MidLFHEDialect.cpp +++ b/compiler/lib/Dialect/MidLFHE/IR/MidLFHEDialect.cpp @@ -21,11 +21,12 @@ void MidLFHEDialect::initialize() { ::mlir::Type MidLFHEDialect::parseType(::mlir::DialectAsmParser &parser) const { - if(parser.parseOptionalKeyword("glwe").succeeded()) - return GLWECipherTextType::parse(this->getContext(), parser); if(parser.parseOptionalKeyword("lwe").succeeded()) return LWECipherTextType::parse(this->getContext(), parser); - + if(parser.parseOptionalKeyword("glwe").succeeded()) + return GLWECipherTextType::parse(this->getContext(), parser); + if(parser.parseOptionalKeyword("ggsw").succeeded()) + return GGSWCipherTextType::parse(this->getContext(), parser); return ::mlir::Type(); } @@ -42,6 +43,11 @@ void MidLFHEDialect::printType(::mlir::Type type, glwe.print(printer); return; } + mlir::zamalang::MidLFHE::GGSWCipherTextType ggsw = type.dyn_cast_or_null(); + if (ggsw != nullptr) { + ggsw.print(printer); + return; + } // TODO - What should be done here? printer << "unknwontype"; } \ No newline at end of file diff --git a/compiler/tests/Dialect/MidLFHE/types_ggsw.mlir b/compiler/tests/Dialect/MidLFHE/types_ggsw.mlir new file mode 100644 index 000000000..53e615a90 --- /dev/null +++ b/compiler/tests/Dialect/MidLFHE/types_ggsw.mlir @@ -0,0 +1,7 @@ +// RUN: zamacompiler %s 2>&1| FileCheck %s + +// CHECK-LABEL: func @ggsw(%arg0: !MidLFHE.ggsw<1024,12,3,2>) -> !MidLFHE.ggsw<1024,12,3,2> +func @ggsw(%arg0: !MidLFHE.ggsw<1024,12,3,2>) -> !MidLFHE.ggsw<1024,12,3,2> { + // CHECK-LABEL: return %arg0 : !MidLFHE.ggsw<1024,12,3,2> + return %arg0: !MidLFHE.ggsw<1024,12,3,2> +} \ No newline at end of file