feat(compiler): Add batching pass

This adds a new pass that is able to hoist operations implementing the
`BatchableOpInterface` out of a loop nest that applies the operation
to the elements of a tensor indexed by the loop induction variables.

Example:

  scf.for %i = c0 to %cN step %c1 {
    scf.for %j = c0 to %cM step %c1 {
      scf.for %k = c0 to %cK step %c1 {
        %s = tensor.extract %T[%i, %j, %k]
        %res = batchable_op %s
        ...
      }
    }
  }

is replaced with:

  %batchedSlice = tensor.extract_slice
       %T[%c0, %c0, %c0] [%cN, %cM, %cK] [%c1, %c1, %c1]
  %flatSlice = tensor.collapse_shape %batchedSlice
  %resTFlat = batchedOp %flatSlice
  %resT = tensor.expand_shape %resTFlat

  scf.for %i = c0 to %cN step %c1 {
    scf.for %j = c0 to %cM step %c1 {
      scf.for %k = c0 to %cK step %c1 {
        %res = tensor.extract %resT[%i, %j, %k]
        ...
      }
    }
  }

Every index of the tensor with the input values may be a quasi-affine
expression on a single loop induction variable, as long as the
difference between the results of the expression for any two
consecutive values of the referenced loop induction variable is
constant.
This commit is contained in:
Andi Drebes
2022-11-10 16:44:42 +01:00
parent 3ce7c96f3f
commit c367a4b6fd
4 changed files with 931 additions and 0 deletions

View File

@@ -18,6 +18,7 @@ namespace mlir {
namespace concretelang {
std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>> createForLoopToParallel();
std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>> createBatchingPass();
} // namespace concretelang
} // namespace mlir

View File

@@ -11,4 +11,11 @@ def ForLoopToParallel : Pass<"for-loop-to-parallel", "mlir::ModuleOp"> {
let dependentDialects = ["mlir::scf::SCFDialect"];
}
def Batching : Pass<"concrete", "mlir::ModuleOp"> {
let summary =
"Hoists operation for which a batched version exists out of loops applying "
"the operation to values stored in a tensor.";
let constructor = "mlir::concretelang::createBatchingPass()";
}
#endif