feat(compiler): Make TFHE.keyswitch_glwe and TFHE.bootstrap_glwe batchable

This commit is contained in:
Andi Drebes
2023-03-14 17:16:38 +01:00
parent b24709a1ec
commit b495f9dd5c
2 changed files with 78 additions and 3 deletions

View File

@@ -14,6 +14,7 @@
#include "concretelang/Dialect/TFHE/IR/TFHEAttrs.h"
#include "concretelang/Dialect/TFHE/IR/TFHETypes.h"
#include "concretelang/Interfaces/BatchableInterface.h"
#define GET_OP_CLASSES
#include "concretelang/Dialect/TFHE/IR/TFHEOps.h.inc"

View File

@@ -12,9 +12,10 @@
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "concretelang/Dialect/TFHE/IR/TFHEAttrs.td"
include "concretelang/Dialect/TFHE/IR/TFHEDialect.td"
include "concretelang/Dialect/TFHE/IR/TFHETypes.td"
include "concretelang/Dialect/TFHE/IR/TFHEAttrs.td"
include "concretelang/Interfaces/BatchableInterface.td"
class TFHE_Op<string mnemonic, list<Trait> traits = []>
: Op<TFHE_Dialect, mnemonic, traits>;
@@ -121,7 +122,18 @@ def TFHE_MulGLWEIntOp : TFHE_Op<"mul_glwe_int", [Pure]> {
let hasVerifier = 1;
}
def TFHE_KeySwitchGLWEOp : TFHE_Op<"keyswitch_glwe", [Pure]> {
def TFHE_BatchedKeySwitchGLWEOp : TFHE_Op<"batched_keyswitch_glwe", [Pure]> {
let summary = "Batched version of KeySwitchGLWEOp";
let arguments = (ins
1DTensorOf<[TFHE_GLWECipherTextType]> : $ciphertexts,
TFHE_KeyswitchKeyAttr : $key
);
let results = (outs 1DTensorOf<[TFHE_GLWECipherTextType]> : $result);
}
def TFHE_KeySwitchGLWEOp : TFHE_Op<"keyswitch_glwe", [Pure, BatchableOpInterface]> {
let summary = "Change the encryption parameters of a glwe ciphertext by "
"applying a keyswitch";
@@ -132,10 +144,44 @@ def TFHE_KeySwitchGLWEOp : TFHE_Op<"keyswitch_glwe", [Pure]> {
let results = (outs TFHE_GLWECipherTextType : $result);
let extraClassDeclaration = [{
::mlir::OpOperand& getBatchableOperand() {
return getOperation()->getOpOperand(0);
}
::mlir::OperandRange getNonBatchableOperands() {
return getOperation()->getOperands().drop_front();
}
::mlir::Value createBatchedOperation(::mlir::ImplicitLocOpBuilder& builder,
::mlir::Value batchedOperands,
::mlir::ValueRange hoistedNonBatchableOperands) {
::mlir::RankedTensorType resType = ::mlir::RankedTensorType::get(
batchedOperands.getType().cast<::mlir::RankedTensorType>().getShape(),
getResult().getType());
return builder.create<BatchedKeySwitchGLWEOp>(
mlir::TypeRange{resType},
mlir::ValueRange{batchedOperands},
getOperation()->getAttrs());
}
}];
}
def TFHE_BatchedBootstrapGLWEOp : TFHE_Op<"batched_bootstrap_glwe", [Pure]> {
let summary = "Batched version of KeySwitchGLWEOp";
def TFHE_BootstrapGLWEOp : TFHE_Op<"bootstrap_glwe", [Pure]> {
let arguments = (ins
1DTensorOf<[TFHE_GLWECipherTextType]> : $ciphertexts,
1DTensorOf<[I64]> : $lookup_table,
TFHE_BootstrapKeyAttr: $key
);
let results = (outs 1DTensorOf<[TFHE_GLWECipherTextType]> : $result);
}
def TFHE_BootstrapGLWEOp : TFHE_Op<"bootstrap_glwe", [Pure, BatchableOpInterface]> {
let summary =
"Programmable bootstraping of a GLWE ciphertext with a lookup table";
@@ -146,6 +192,34 @@ def TFHE_BootstrapGLWEOp : TFHE_Op<"bootstrap_glwe", [Pure]> {
);
let results = (outs TFHE_GLWECipherTextType : $result);
let extraClassDeclaration = [{
::mlir::OpOperand& getBatchableOperand() {
return getOperation()->getOpOperand(0);
}
::mlir::OperandRange getNonBatchableOperands() {
return getOperation()->getOperands().drop_front();
}
::mlir::Value createBatchedOperation(::mlir::ImplicitLocOpBuilder& builder,
::mlir::Value batchedOperands,
::mlir::ValueRange hoistedNonBatchableOperands) {
::mlir::RankedTensorType resType = ::mlir::RankedTensorType::get(
batchedOperands.getType().cast<::mlir::RankedTensorType>().getShape(),
getResult().getType());
::llvm::SmallVector<::mlir::Value> operands;
operands.push_back(batchedOperands);
operands.append(hoistedNonBatchableOperands.begin(),
hoistedNonBatchableOperands.end());
return builder.create<BatchedBootstrapGLWEOp>(
mlir::TypeRange{resType},
operands,
getOperation()->getAttrs());
}
}];
}
def TFHE_WopPBSGLWEOp : TFHE_Op<"wop_pbs_glwe", [Pure]> {