mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-19 08:54:26 -05:00
The current batching pass only supports batching of operations that
have a single batchable operand, that can only be batched in one way,
and that operate on scalar values. However, this does not allow for
efficient batching of all arithmetic operations in TFHE, since these
are often applied to pairs of scalar values from tensors, to tensors
and scalars, or to tensors that can be grouped into higher-order tensors.
This commit introduces three new features for batching:
1. Support for multiple batchable operands
The operation interface for batching now allows for the
specification of multiple batchable operands. This set can be
composed of any subset of an operation's operands, i.e., it is
not limited to sets of operands with contiguous operand indexes.
2. Support for multiple batching variants
To account for multiple kinds of batching, the batching operation
interface `BatchableOpInterface` now supports variants. The
batching pass attempts to batch an operation by trying the
batching variants expressed via the interface in order until one
succeeds.
3. Support for batching of tensor values
Some operations that could be batched already operate on tensor
values. The new batching pass detects those patterns and groups
the batchable tensors' values into higher-dimensional tensors.
55 lines · 1.7 KiB · TableGen
#ifndef CONCRETELANG_INTERFACES_BATCHABLEINTERFACE
#define CONCRETELANG_INTERFACES_BATCHABLEINTERFACE

include "mlir/IR/OpBase.td"

// Interface implemented by operations that the batching pass may group into
// a single batched operation. An implementing operation exposes one or more
// batching variants. For each variant, it names the operands that can be
// batched and builds the corresponding batched operation.
def BatchableOpInterface : OpInterface<"BatchableOpInterface"> {
  let description = [{
    Interface for operations that can be batched if invoked multiple times
    with different, independent operands. The batchable operands may be
    scalars or tensors. An operation may support several batching variants,
    each designating its own set of batchable operands and its own batched
    operation.
  }];

  let cppNamespace = "::mlir::concretelang";

  let methods = [
    InterfaceMethod<[{
        Return the number of batching variants supported by the operation.
        A variant is identified by an index passed as `variant` to the other
        methods of this interface; that index is expected to lie in
        `[0, getNumBatchingVariants())`. Defaults to a single variant.
      }],
      /*retTy=*/"unsigned",
      /*methodName=*/"getNumBatchingVariants",
      /*args=*/(ins),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        return 1;
      }]
    >,
    // NOTE(review): `MutableArrayRef<OpOperand>` is a view over contiguous
    // storage, presumably a slice of the operation's operand list. Such a
    // view cannot describe a non-contiguous subset of operands. Confirm this
    // matches the intended contract for multiple batchable operands.
    InterfaceMethod<[{
        Return the operands that can be batched for batching variant
        `variant`. Their values, taken from several independent invocations
        of the operation, are grouped into tensors. These tensors are passed
        as `batchedOperands` to `createBatchedOperation` for the same
        variant. Implementations must override this method; the default
        implementation is unreachable.
      }],
      /*retTy=*/"::llvm::MutableArrayRef<::mlir::OpOperand>",
      /*methodName=*/"getBatchableOperands",
      /*args=*/(ins "unsigned":$variant),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        llvm_unreachable("getBatchableOperands not implemented");
      }]
    >,
    InterfaceMethod<[{
        Create the batched operation for batching variant `variant` using
        `builder`, and return its result as a value.

        - `batchedOperands` are the batched values corresponding to the
          operands returned by `getBatchableOperands(variant)`.
        - `hoistedNonBatchableOperands` are the remaining (non-batchable)
          operands as hoisted by the batching pass. They are presumably in
          original operand order; confirm against the pass.

        Implementations must override this method; the default
        implementation is unreachable.
      }],
      /*retTy=*/"::mlir::Value",
      /*methodName=*/"createBatchedOperation",
      /*args=*/(ins "unsigned":$variant,
                   "::mlir::ImplicitLocOpBuilder&":$builder,
                   "::mlir::ValueRange":$batchedOperands,
                   "::mlir::ValueRange":$hoistedNonBatchableOperands),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        llvm_unreachable("createBatchedOperation not implemented");
      }]
    >
  ];
}

#endif // CONCRETELANG_INTERFACES_BATCHABLEINTERFACE