mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 11:35:02 -05:00
feat(compiler): Batching: Favor batching of all operands for binary TFHE ops
This changes the order of batching variants for binary TFHE operations, such that batching of both operands is favored over batching of a single operand.
This commit is contained in:
@@ -30,21 +30,21 @@ class TFHE_BatchableBinaryOp<
|
||||
{
|
||||
let extraClassDeclaration = [{
|
||||
struct BatchingVariant {
|
||||
static const unsigned BATCHED_SCALAR = 0;
|
||||
static const unsigned SCALAR_BATCHED = 1;
|
||||
static const unsigned ALL_BATCHED = 2;
|
||||
static const unsigned ALL_BATCHED = 0;
|
||||
static const unsigned BATCHED_SCALAR = 1;
|
||||
static const unsigned SCALAR_BATCHED = 2;
|
||||
};
|
||||
|
||||
unsigned getNumBatchingVariants() { return 3; }
|
||||
|
||||
::llvm::MutableArrayRef<::mlir::OpOperand> getBatchableOperands(unsigned variant) {
|
||||
switch(variant) {
|
||||
case BatchingVariant::ALL_BATCHED:
|
||||
return getOperation()->getOpOperands();
|
||||
case BatchingVariant::BATCHED_SCALAR:
|
||||
return getOperation()->getOpOperands().take_front();
|
||||
case BatchingVariant::SCALAR_BATCHED:
|
||||
return getOperation()->getOpOperands().drop_front().take_front();
|
||||
case BatchingVariant::ALL_BATCHED:
|
||||
return getOperation()->getOpOperands();
|
||||
}
|
||||
|
||||
llvm_unreachable("Unknown batching variant");
|
||||
@@ -61,6 +61,12 @@ class TFHE_BatchableBinaryOp<
|
||||
::llvm::SmallVector<::mlir::Value> operands;
|
||||
|
||||
switch(variant) {
|
||||
case BatchingVariant::ALL_BATCHED:
|
||||
operands = batchedOperands;
|
||||
return builder.create<}] # all_batched_opname # [{>(
|
||||
mlir::TypeRange{resType},
|
||||
operands,
|
||||
getOperation()->getAttrs());
|
||||
case BatchingVariant::BATCHED_SCALAR:
|
||||
operands.push_back(batchedOperands[0]);
|
||||
operands.push_back(hoistedNonBatchableOperands[0]);
|
||||
@@ -75,12 +81,6 @@ class TFHE_BatchableBinaryOp<
|
||||
mlir::TypeRange{resType},
|
||||
operands,
|
||||
getOperation()->getAttrs());
|
||||
case BatchingVariant::ALL_BATCHED:
|
||||
operands = batchedOperands;
|
||||
return builder.create<}] # all_batched_opname # [{>(
|
||||
mlir::TypeRange{resType},
|
||||
operands,
|
||||
getOperation()->getAttrs());
|
||||
}
|
||||
|
||||
llvm_unreachable("Unknown batching variant");
|
||||
|
||||
Reference in New Issue
Block a user