feat(compiler): keyswitching types and ops

This commit is contained in:
youben11
2021-05-27 16:32:34 +01:00
committed by Quentin Bourgerie
parent 7104e2600c
commit 8edf4a358e
4 changed files with 43 additions and 0 deletions

View File

@@ -56,4 +56,10 @@ def ReturnOp : MidLFHE_Op<"pbs_return", [NoSideEffect, ReturnLike, Terminator]>
let builders = [OpBuilder<(ins), [{ /* nothing to do */ }]>];
}
def KeySwitchOp : MidLFHE_Op<"keyswitch"> {
let arguments = (ins KeySwitchingKeyType:$ks, CipherTextType:$ct, I32Attr:$base_log, I32Attr:$level);
let results = (outs CipherTextType);
}
#endif

View File

@@ -134,6 +134,27 @@ def AnyCipherTextType : MidLFHE_Type<"AnyCipherText"> {
}];
}
def KeySwitchingKeyType : MidLFHE_Type<"KeySwitchingKey"> {
let mnemonic = "ksk";
let summary = "A KeySwitching key";
let description = [{
A KeySwitching key
}];
// We define the printer inline.
let printer = [{
$_printer << "ksk";
}];
// The parser is defined here also.
let parser = [{
return get($_ctxt);
}];
}
def CipherTextType: TypeConstraint<Or<[
LWECipherTextType.predicate,
GLWECipherTextType.predicate,

View File

@@ -29,6 +29,8 @@ void MidLFHEDialect::initialize() {
return GGSWCipherTextType::parse(this->getContext(), parser);
if(parser.parseOptionalKeyword("ciphertext").succeeded())
return AnyCipherTextType::parse(this->getContext(), parser);
if(parser.parseOptionalKeyword("ksk").succeeded())
return KeySwitchingKeyType::parse(this->getContext(), parser);
parser.emitError(parser.getCurrentLocation(), "Unknown MidLFHE type");
return ::mlir::Type();
}
@@ -51,6 +53,11 @@ void MidLFHEDialect::printType(::mlir::Type type,
ggsw.print(printer);
return;
}
mlir::zamalang::MidLFHE::KeySwitchingKeyType ksk = type.dyn_cast_or_null<mlir::zamalang::MidLFHE::KeySwitchingKeyType>();
if (ksk != nullptr) {
ksk.print(printer);
return;
}
mlir::zamalang::MidLFHE::AnyCipherTextType any = type.dyn_cast_or_null<mlir::zamalang::MidLFHE::AnyCipherTextType>();
if (any != nullptr) {
any.print(printer);

View File

@@ -0,0 +1,9 @@
// RUN: zamacompiler %s 2>&1| FileCheck %s
// CHECK-LABEL: func @keyswitch(%arg0: !MidLFHE.ksk, %arg1: !MidLFHE.ciphertext) -> !MidLFHE.ciphertext {
func @keyswitch(%arg0: !MidLFHE.ksk, %arg1: !MidLFHE.ciphertext) -> !MidLFHE.ciphertext {
// CHECK-NEXT: %[[V1:.*]] = "MidLFHE.keyswitch"(%arg0, %arg1) {base_log = 8 : i32, level = 2 : i32} : (!MidLFHE.ksk, !MidLFHE.ciphertext) -> !MidLFHE.ciphertext
%0 = "MidLFHE.keyswitch"(%arg0, %arg1) {base_log = 8 : i32, level = 2 : i32} : (!MidLFHE.ksk, !MidLFHE.ciphertext) -> !MidLFHE.ciphertext
// CHECK-NEXT: return %[[V1]] : !MidLFHE.ciphertext
return %0 : !MidLFHE.ciphertext
}