feat(compiler): Batching: Favor batching of all operands for binary TFHE ops

This changes the order of batching variants for binary TFHE
operations, such that batching of both operands is favored over
batching of a single operand.
This commit is contained in:
Andi Drebes
2023-06-09 16:57:30 +02:00
committed by Antoniu Pop
parent 81eaaa7560
commit 549d2ded86

View File

@@ -30,21 +30,21 @@ class TFHE_BatchableBinaryOp<
{
let extraClassDeclaration = [{
struct BatchingVariant {
static const unsigned BATCHED_SCALAR = 0;
static const unsigned SCALAR_BATCHED = 1;
static const unsigned ALL_BATCHED = 2;
static const unsigned ALL_BATCHED = 0;
static const unsigned BATCHED_SCALAR = 1;
static const unsigned SCALAR_BATCHED = 2;
};
unsigned getNumBatchingVariants() { return 3; }
::llvm::MutableArrayRef<::mlir::OpOperand> getBatchableOperands(unsigned variant) {
switch(variant) {
case BatchingVariant::ALL_BATCHED:
return getOperation()->getOpOperands();
case BatchingVariant::BATCHED_SCALAR:
return getOperation()->getOpOperands().take_front();
case BatchingVariant::SCALAR_BATCHED:
return getOperation()->getOpOperands().drop_front().take_front();
case BatchingVariant::ALL_BATCHED:
return getOperation()->getOpOperands();
}
llvm_unreachable("Unknown batching variant");
@@ -61,6 +61,12 @@ class TFHE_BatchableBinaryOp<
::llvm::SmallVector<::mlir::Value> operands;
switch(variant) {
case BatchingVariant::ALL_BATCHED:
operands = batchedOperands;
return builder.create<}] # all_batched_opname # [{>(
mlir::TypeRange{resType},
operands,
getOperation()->getAttrs());
case BatchingVariant::BATCHED_SCALAR:
operands.push_back(batchedOperands[0]);
operands.push_back(hoistedNonBatchableOperands[0]);
@@ -75,12 +81,6 @@ class TFHE_BatchableBinaryOp<
mlir::TypeRange{resType},
operands,
getOperation()->getAttrs());
case BatchingVariant::ALL_BATCHED:
operands = batchedOperands;
return builder.create<}] # all_batched_opname # [{>(
mlir::TypeRange{resType},
operands,
getOperation()->getAttrs());
}
llvm_unreachable("Unknown batching variant");