mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
feat(compiler): keyswitching types and ops
This commit is contained in:
committed by
Quentin Bourgerie
parent
7104e2600c
commit
8edf4a358e
@@ -56,4 +56,10 @@ def ReturnOp : MidLFHE_Op<"pbs_return", [NoSideEffect, ReturnLike, Terminator]>
|
||||
let builders = [OpBuilder<(ins), [{ /* nothing to do */ }]>];
|
||||
}
|
||||
|
||||
|
||||
def KeySwitchOp : MidLFHE_Op<"keyswitch"> {
|
||||
let arguments = (ins KeySwitchingKeyType:$ks, CipherTextType:$ct, I32Attr:$base_log, I32Attr:$level);
|
||||
let results = (outs CipherTextType);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
@@ -134,6 +134,27 @@ def AnyCipherTextType : MidLFHE_Type<"AnyCipherText"> {
|
||||
}];
|
||||
}
|
||||
|
||||
def KeySwitchingKeyType : MidLFHE_Type<"KeySwitchingKey"> {
|
||||
let mnemonic = "ksk";
|
||||
|
||||
let summary = "A KeySwitching key";
|
||||
|
||||
let description = [{
|
||||
A KeySwitching key
|
||||
}];
|
||||
|
||||
// We define the printer inline.
|
||||
let printer = [{
|
||||
$_printer << "ksk";
|
||||
}];
|
||||
|
||||
// The parser is defined here also.
|
||||
let parser = [{
|
||||
return get($_ctxt);
|
||||
}];
|
||||
}
|
||||
|
||||
|
||||
def CipherTextType: TypeConstraint<Or<[
|
||||
LWECipherTextType.predicate,
|
||||
GLWECipherTextType.predicate,
|
||||
|
||||
@@ -29,6 +29,8 @@ void MidLFHEDialect::initialize() {
|
||||
return GGSWCipherTextType::parse(this->getContext(), parser);
|
||||
if(parser.parseOptionalKeyword("ciphertext").succeeded())
|
||||
return AnyCipherTextType::parse(this->getContext(), parser);
|
||||
if(parser.parseOptionalKeyword("ksk").succeeded())
|
||||
return KeySwitchingKeyType::parse(this->getContext(), parser);
|
||||
parser.emitError(parser.getCurrentLocation(), "Unknown MidLFHE type");
|
||||
return ::mlir::Type();
|
||||
}
|
||||
@@ -51,6 +53,11 @@ void MidLFHEDialect::printType(::mlir::Type type,
|
||||
ggsw.print(printer);
|
||||
return;
|
||||
}
|
||||
mlir::zamalang::MidLFHE::KeySwitchingKeyType ksk = type.dyn_cast_or_null<mlir::zamalang::MidLFHE::KeySwitchingKeyType>();
|
||||
if (ksk != nullptr) {
|
||||
ksk.print(printer);
|
||||
return;
|
||||
}
|
||||
mlir::zamalang::MidLFHE::AnyCipherTextType any = type.dyn_cast_or_null<mlir::zamalang::MidLFHE::AnyCipherTextType>();
|
||||
if (any != nullptr) {
|
||||
any.print(printer);
|
||||
|
||||
9
compiler/tests/Dialect/MidLFHE/op_ks.mlir
Normal file
9
compiler/tests/Dialect/MidLFHE/op_ks.mlir
Normal file
@@ -0,0 +1,9 @@
|
||||
// RUN: zamacompiler %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @keyswitch(%arg0: !MidLFHE.ksk, %arg1: !MidLFHE.ciphertext) -> !MidLFHE.ciphertext {
|
||||
func @keyswitch(%arg0: !MidLFHE.ksk, %arg1: !MidLFHE.ciphertext) -> !MidLFHE.ciphertext {
|
||||
// CHECK-NEXT: %[[V1:.*]] = "MidLFHE.keyswitch"(%arg0, %arg1) {base_log = 8 : i32, level = 2 : i32} : (!MidLFHE.ksk, !MidLFHE.ciphertext) -> !MidLFHE.ciphertext
|
||||
%0 = "MidLFHE.keyswitch"(%arg0, %arg1) {base_log = 8 : i32, level = 2 : i32} : (!MidLFHE.ksk, !MidLFHE.ciphertext) -> !MidLFHE.ciphertext
|
||||
// CHECK-NEXT: return %[[V1]] : !MidLFHE.ciphertext
|
||||
return %0 : !MidLFHE.ciphertext
|
||||
}
|
||||
Reference in New Issue
Block a user