This adds a new pattern to the batching pass that folds operations on
tensors of constants into new tensors of constants. E.g.,

  %cst = arith.constant dense<...> : tensor<Nxi9>
  %res = scf.for %i = %c0 to %cN {
    %cst_i9 = tensor.extract %cst[%i]
    %cst_i64 = arith.extui %cst_i9 : i64
    ...
  }

becomes:

  %cst = arith.constant dense<...> : tensor<Nxi64>
  %res = scf.for %i = %c0 to %cN {
    %cst_i64 = tensor.extract %cst[%i]
    ...
  }

The pattern only applies to static loops, to indexes that are
quasi-affine expressions of a single loop induction variable with a
constant step size across iterations, and to foldable operations that
have a single result.
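
The effect of the folding on the constant data can be sketched in plain
C++. This is a simplified model and not the pass implementation: it
assumes the index is the induction variable itself, so the folded table
has the same shape as the original one. The names `foldConstantTable`
and `extUI9To64` are hypothetical.

```cpp
#include <cassert>
#include <cstdint>
#include <functional>
#include <vector>

// Hypothetical helper: applies a foldable, single-result operation to every
// element of a constant table and returns the new constant table.
std::vector<uint64_t>
foldConstantTable(const std::vector<uint16_t> &table,
                  const std::function<uint64_t(uint16_t)> &foldOp) {
  assert(foldOp && "a foldable operation must be provided");
  std::vector<uint64_t> folded;
  folded.reserve(table.size());
  for (uint16_t v : table)
    folded.push_back(foldOp(v));
  return folded;
}

// Models `arith.extui` from i9 to i64: keep the low 9 bits and zero-extend.
uint64_t extUI9To64(uint16_t v) { return static_cast<uint64_t>(v & 0x1FF); }
```

After such a fold, the loop body only needs to extract the precomputed
64-bit value, which corresponds to the second IR snippet above.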
The current batching pass only supports batching of operations that
have a single batchable operand, that can only be batched in one way
and that operate on scalar values. However, this does not allow for
efficient batching of all arithmetic operations in TFHE, since these
are often applied to pairs of scalar values from tensors, to tensors
and scalars or to tensors that can be grouped in higher-order tensors.
This commit introduces three new features for batching:

  1. Support of multiple batchable operands

     The operation interface for batching now allows for the
     specification of multiple batchable operands. This set can be
     composed of any subset of an operation's operands, i.e., it is
     not limited to sets of operands with contiguous operand indexes.

  2. Support for multiple batching variants

     To account for multiple kinds of batching, the batching operation
     interface `BatchableOpInterface` now supports variants. The
     batching pass attempts to batch an operation by trying the
     batching variants expressed via the interface in order until it
     succeeds.

  3. Support for batching of tensor values

     Some operations that could be batched already operate on tensor
     values. The new batching pass detects those patterns and groups
     the batchable tensors' values into higher-dimensional tensors.

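
The "try the variants in order until one succeeds" strategy can be
sketched as follows. This is a standalone model and not the actual
`BatchableOpInterface` API: the `BatchingVariant` struct, its fields and
`batchWithFirstMatchingVariant` are hypothetical names chosen for
illustration.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <optional>
#include <string>
#include <vector>

// Hypothetical model of one batching variant: the operand indexes it batches
// (not necessarily contiguous) and a callback that attempts the rewrite.
struct BatchingVariant {
  std::string name;
  std::vector<std::size_t> batchableOperands;
  std::function<bool()> tryBatch; // returns true if batching succeeded
};

// Tries the variants in declaration order and returns the index of the first
// one that succeeds, or std::nullopt if no variant applies.
std::optional<std::size_t>
batchWithFirstMatchingVariant(const std::vector<BatchingVariant> &variants) {
  for (std::size_t i = 0; i < variants.size(); ++i) {
    assert(variants[i].tryBatch && "each variant needs a callback");
    if (variants[i].tryBatch())
      return i;
  }
  return std::nullopt;
}
```

Under this model, the order in which an operation lists its variants
acts as a priority: a variant batching operands 0 and 2 together can be
listed first, with a fallback that batches only operand 0.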
The batching pass erroneously assumes that any expression solely
composed of an induction variable has static bounds. This commit adds
a check that the lower bound, upper bound and step are indeed static
before attempting to determine their static values.
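
The kind of check described here can be illustrated with a small
standalone sketch. It is not the code of the pass: `Bound`, `LoopBounds`
and `staticTripCount` are hypothetical names, and a bound is modeled as
an optional that is empty whenever the value is not a compile-time
constant.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>

// Hypothetical model of a loop bound: holds a value only if the bound is a
// compile-time constant.
using Bound = std::optional<int64_t>;

struct LoopBounds {
  Bound lb, ub, step;
};

// Returns the trip count only if lower bound, upper bound and step are all
// static and the step is positive; otherwise returns std::nullopt instead of
// reading a value that does not exist.
std::optional<int64_t> staticTripCount(const LoopBounds &b) {
  if (!b.lb || !b.ub || !b.step)
    return std::nullopt;
  if (*b.step <= 0)
    return std::nullopt;
  if (*b.ub <= *b.lb)
    return 0;
  int64_t count = (*b.ub - *b.lb + *b.step - 1) / *b.step;
  assert(count >= 0);
  return count;
}
```

The point of the sketch is the ordering: the static values are only
read after all three bounds have been confirmed to be static.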
The current scheme, which updates the types of values in place, may
cause operations to be recognized as legal, and thus left unconverted,
when the operations producing their operands have been converted
earlier: the operand types have already been updated, and legality is
based solely on types.
For example, the conversion pattern for a `tensor.insert_slice`
operation working on tensors of encrypted values may not trigger if
the operations producing its operands have already been converted. The
operation is then left with updated operand types, which carry the
extra dimension added by the type conversion from TFHE to Concrete,
but with unmodified sizes, strides and offsets that do not take the
extra dimension into account. This causes the verifier of the affected
operation to fail and the compilation to abort.
By using op conversion patterns, the original types of each operation
are preserved during the actual rewrite, correctly triggering all
conversion patterns based on the legality of data types.
The reinstantiating rewrite pattern for `scf.for` operations,
`TypeConvertingReinstantiationPattern<scf::ForOp, false>`, calls
`mlir::ConversionPatternRewriter::replaceOpWithNewOp()` before moving
the operations of the original loop to the newly created loop. Since
`replaceOpWithNewOp()` indirectly marks all operations of the old loop
as ignored for dialect conversion, the dialect converter never
descends recursively into the newly created loop.
This causes operations that are illegal to be preserved, which results
in illegal IR after dialect conversion.
This commit splits the replacement into three steps:

  1. Creation of the new loop via
     `mlir::ConversionPatternRewriter::create()`

  2. Moving operations from the old loop to the newly created one

  3. Replacement of the original loop with the results of the new one
     via `mlir::ConversionPatternRewriter::replaceOp()`

As a result, the operations of the loop are no longer ignored, which
fixes dialect conversion.