mirror of
https://github.com/zama-ai/concrete.git
synced 2026-04-17 03:00:54 -04:00
feat(compiler): Make TFHE.keyswitch_glwe and TFHE.bootstrap_glwe batchable
This commit is contained in:
@@ -14,6 +14,7 @@
|
||||
|
||||
#include "concretelang/Dialect/TFHE/IR/TFHEAttrs.h"
|
||||
#include "concretelang/Dialect/TFHE/IR/TFHETypes.h"
|
||||
#include "concretelang/Interfaces/BatchableInterface.h"
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "concretelang/Dialect/TFHE/IR/TFHEOps.h.inc"
|
||||
|
||||
@@ -12,9 +12,10 @@
|
||||
|
||||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||
include "mlir/Interfaces/ControlFlowInterfaces.td"
|
||||
include "concretelang/Dialect/TFHE/IR/TFHEAttrs.td"
|
||||
include "concretelang/Dialect/TFHE/IR/TFHEDialect.td"
|
||||
include "concretelang/Dialect/TFHE/IR/TFHETypes.td"
|
||||
include "concretelang/Dialect/TFHE/IR/TFHEAttrs.td"
|
||||
include "concretelang/Interfaces/BatchableInterface.td"
|
||||
|
||||
class TFHE_Op<string mnemonic, list<Trait> traits = []>
|
||||
: Op<TFHE_Dialect, mnemonic, traits>;
|
||||
@@ -121,7 +122,18 @@ def TFHE_MulGLWEIntOp : TFHE_Op<"mul_glwe_int", [Pure]> {
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
def TFHE_KeySwitchGLWEOp : TFHE_Op<"keyswitch_glwe", [Pure]> {
|
||||
def TFHE_BatchedKeySwitchGLWEOp : TFHE_Op<"batched_keyswitch_glwe", [Pure]> {
|
||||
let summary = "Batched version of KeySwitchGLWEOp";
|
||||
|
||||
let arguments = (ins
|
||||
1DTensorOf<[TFHE_GLWECipherTextType]> : $ciphertexts,
|
||||
TFHE_KeyswitchKeyAttr : $key
|
||||
);
|
||||
|
||||
let results = (outs 1DTensorOf<[TFHE_GLWECipherTextType]> : $result);
|
||||
}
|
||||
|
||||
def TFHE_KeySwitchGLWEOp : TFHE_Op<"keyswitch_glwe", [Pure, BatchableOpInterface]> {
|
||||
let summary = "Change the encryption parameters of a glwe ciphertext by "
|
||||
"applying a keyswitch";
|
||||
|
||||
@@ -132,10 +144,44 @@ def TFHE_KeySwitchGLWEOp : TFHE_Op<"keyswitch_glwe", [Pure]> {
|
||||
|
||||
|
||||
let results = (outs TFHE_GLWECipherTextType : $result);
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
::mlir::OpOperand& getBatchableOperand() {
|
||||
return getOperation()->getOpOperand(0);
|
||||
}
|
||||
|
||||
::mlir::OperandRange getNonBatchableOperands() {
|
||||
return getOperation()->getOperands().drop_front();
|
||||
}
|
||||
|
||||
::mlir::Value createBatchedOperation(::mlir::ImplicitLocOpBuilder& builder,
|
||||
::mlir::Value batchedOperands,
|
||||
::mlir::ValueRange hoistedNonBatchableOperands) {
|
||||
::mlir::RankedTensorType resType = ::mlir::RankedTensorType::get(
|
||||
batchedOperands.getType().cast<::mlir::RankedTensorType>().getShape(),
|
||||
getResult().getType());
|
||||
|
||||
return builder.create<BatchedKeySwitchGLWEOp>(
|
||||
mlir::TypeRange{resType},
|
||||
mlir::ValueRange{batchedOperands},
|
||||
getOperation()->getAttrs());
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def TFHE_BatchedBootstrapGLWEOp : TFHE_Op<"batched_bootstrap_glwe", [Pure]> {
|
||||
let summary = "Batched version of KeySwitchGLWEOp";
|
||||
|
||||
def TFHE_BootstrapGLWEOp : TFHE_Op<"bootstrap_glwe", [Pure]> {
|
||||
let arguments = (ins
|
||||
1DTensorOf<[TFHE_GLWECipherTextType]> : $ciphertexts,
|
||||
1DTensorOf<[I64]> : $lookup_table,
|
||||
TFHE_BootstrapKeyAttr: $key
|
||||
);
|
||||
|
||||
let results = (outs 1DTensorOf<[TFHE_GLWECipherTextType]> : $result);
|
||||
}
|
||||
|
||||
def TFHE_BootstrapGLWEOp : TFHE_Op<"bootstrap_glwe", [Pure, BatchableOpInterface]> {
|
||||
let summary =
|
||||
"Programmable bootstraping of a GLWE ciphertext with a lookup table";
|
||||
|
||||
@@ -146,6 +192,34 @@ def TFHE_BootstrapGLWEOp : TFHE_Op<"bootstrap_glwe", [Pure]> {
|
||||
);
|
||||
|
||||
let results = (outs TFHE_GLWECipherTextType : $result);
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
::mlir::OpOperand& getBatchableOperand() {
|
||||
return getOperation()->getOpOperand(0);
|
||||
}
|
||||
|
||||
::mlir::OperandRange getNonBatchableOperands() {
|
||||
return getOperation()->getOperands().drop_front();
|
||||
}
|
||||
|
||||
::mlir::Value createBatchedOperation(::mlir::ImplicitLocOpBuilder& builder,
|
||||
::mlir::Value batchedOperands,
|
||||
::mlir::ValueRange hoistedNonBatchableOperands) {
|
||||
::mlir::RankedTensorType resType = ::mlir::RankedTensorType::get(
|
||||
batchedOperands.getType().cast<::mlir::RankedTensorType>().getShape(),
|
||||
getResult().getType());
|
||||
|
||||
::llvm::SmallVector<::mlir::Value> operands;
|
||||
operands.push_back(batchedOperands);
|
||||
operands.append(hoistedNonBatchableOperands.begin(),
|
||||
hoistedNonBatchableOperands.end());
|
||||
|
||||
return builder.create<BatchedBootstrapGLWEOp>(
|
||||
mlir::TypeRange{resType},
|
||||
operands,
|
||||
getOperation()->getAttrs());
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def TFHE_WopPBSGLWEOp : TFHE_Op<"wop_pbs_glwe", [Pure]> {
|
||||
|
||||
Reference in New Issue
Block a user