mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 12:15:09 -05:00
This adds a new pass that is able to hoist operations implementing the
`BatchableOpInterface` out of a loop nest that applies the operation
to the elements of a tensor indexed by the loop induction variables.
Example:
scf.for %i = %c0 to %cN step %c1 {
scf.for %j = %c0 to %cM step %c1 {
scf.for %k = %c0 to %cK step %c1 {
%s = tensor.extract %T[%i, %j, %k]
%res = batchable_op %s
...
}
}
}
is replaced with:
%batchedSlice = tensor.extract_slice
%T[%c0, %c0, %c0] [%cN, %cM, %cK] [%c1, %c1, %c1]
%flatSlice = tensor.collapse_shape %batchedSlice
%resTFlat = batchedOp %flatSlice
%resT = tensor.expand_shape %resTFlat
scf.for %i = %c0 to %cN step %c1 {
scf.for %j = %c0 to %cM step %c1 {
scf.for %k = %c0 to %cK step %c1 {
%res = tensor.extract %resT[%i, %j, %k]
...
}
}
}
Every index of the tensor with the input values may be a quasi-affine
expression on a single loop induction variable, as long as the
difference between the results of the expression for any two
consecutive values of the referenced loop induction variable is
constant.
26 lines
760 B
C++
26 lines
760 B
C++
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
|
// Exceptions. See
|
|
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
|
// for license information.
|
|
|
|
#ifndef CONCRETELANG_TRANSFORMS_PASS_H
|
|
#define CONCRETELANG_TRANSFORMS_PASS_H
|
|
|
|
#include <mlir/Dialect/Func/IR/FuncOps.h>
|
|
#include <mlir/Dialect/MemRef/IR/MemRef.h>
|
|
#include <mlir/Dialect/SCF/IR/SCF.h>
|
|
#include <mlir/Pass/Pass.h>
|
|
|
|
#define GEN_PASS_CLASSES
|
|
#include <concretelang/Transforms/Passes.h.inc>
|
|
|
|
// Factory functions for the transformation passes declared in this header.
// Each factory returns a newly allocated pass, owned by the caller through
// std::unique_ptr, that is anchored on mlir::ModuleOp.
namespace mlir {
|
|
namespace concretelang {
|
|
|
|
/// Creates a module-level pass instance.
///
/// NOTE(review): judging by the name only, this presumably rewrites
/// sequential for loops into a parallel form; the implementation is not
/// visible from this header -- confirm before relying on this description.
///
/// @return Owning pointer to a pass operating on mlir::ModuleOp.
std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>> createForLoopToParallel();
|
|
/// Creates the module-level batching pass.
///
/// NOTE(review): per the change description that introduced this pass, it
/// hoists operations implementing `BatchableOpInterface` out of a loop nest
/// that applies the operation to tensor elements indexed by the loop
/// induction variables, replacing them with a single batched operation
/// applied to a slice of the tensor -- confirm against the implementation.
///
/// @return Owning pointer to a pass operating on mlir::ModuleOp.
std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>> createBatchingPass();
|
|
} // namespace concretelang
|
|
} // namespace mlir
|
|
|
|
#endif
|