Files
concrete/compilers/concrete-compiler/compiler/lib/Dialect/RT/Transforms/HoistAwaitFuturePass.cpp
Andi Drebes d620fa9a44 feat(compiler): Add pass hoisting RT.await_future out of scf.forall loops
The new pass hoists `RT.await_future` operations whose results are
yielded by scf.forall operations out of the loops in order to avoid
over-synchronization of data-flow tasks.

E.g., the following IR:

```
scf.forall (%arg) in (16)
  shared_outs(%o1 = %sometensor, %o2 = %someothertensor)
  -> (tensor<...>, tensor<...>)
{
  ...
  %rph = "RT.build_return_ptr_placeholder"() :
    () -> !RT.rtptr<!RT.future<tensor<...>>>
  "RT.create_async_task"(..., %rph, ...) { ... } : ...
  %future = "RT.deref_return_ptr_placeholder"(%rph) :
    (!RT.rtptr<!RT.future<...>>) -> !RT.future<tensor<...>>
  %res = "RT.await_future"(%future) : (!RT.future<tensor<...>>) -> tensor<...>
  ...
  scf.forall.in_parallel {
    ...
    tensor.parallel_insert_slice %res into %o1[..., %arg, ...] [...] [...] :
      tensor<...> into tensor<...>
    ...
  }
}
```

is transformed into:

```
%tensoroffutures = tensor.empty() : tensor<16x!RT.future<tensor<...>>>

scf.forall (%arg) in (16)
  shared_outs(%otfut = %tensoroffutures, %o2 = %someothertensor)
  -> (tensor<...>, tensor<...>)
{
  ...
  %rph = "RT.build_return_ptr_placeholder"() :
    () -> !RT.rtptr<!RT.future<tensor<...>>>
  "RT.create_async_task"(..., %rph, ...) { ... } : ...
  %future = "RT.deref_return_ptr_placeholder"(%rph) :
    (!RT.rtptr<!RT.future<...>>) -> !RT.future<tensor<...>>
  %wrappedfuture = tensor.from_elements %future :
    tensor<1x!RT.future<tensor<...>>>
  ...
  scf.forall.in_parallel {
    ...
    tensor.parallel_insert_slice %wrappedfuture into %otfut[%arg] [1] [1] :
      tensor<1x!RT.future<tensor<...>>> into tensor<16x!RT.future<tensor<...>>>
    ...
  }
}

scf.forall (%arg) in (16) shared_outs(%o = %sometensor) -> (tensor<...>) {
  %future = tensor.extract %tensoroffutures[%arg] :
    tensor<16x!RT.future<tensor<...>>>
  %res = "RT.await_future"(%future) : (!RT.future<tensor<...>>) -> tensor<...>
  scf.forall.in_parallel {
    tensor.parallel_insert_slice %res into %o[..., %arg, ...] [...] [...] :
      tensor<...> into tensor<...>
  }
}
```
2024-04-08 16:16:07 +02:00

260 lines
10 KiB
C++

// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#include <concretelang/Analysis/StaticLoops.h>
#include <concretelang/Dialect/RT/IR/RTDialect.h>
#include <concretelang/Dialect/RT/IR/RTOps.h>
#include <concretelang/Dialect/RT/Transforms/Passes.h>
#include <mlir/Dialect/Utils/StaticValueUtils.h>
#include <iterator>
#include <optional>
namespace {
/// Function pass hoisting `RT.await_future` operations out of `scf.forall`
/// loops.
///
/// An `RT.await_future` operation is hoisted if its only use is a
/// `tensor.parallel_insert_slice` writing into a shared output of the
/// directly enclosing `scf.forall` operation with static bounds. The
/// original loop is replaced with:
///
///   1. a loop with the original body, which stores the futures into a
///      tensor of futures instead of awaiting them, and
///   2. a subsequent loop, which awaits each future of that tensor and
///      inserts the awaited value into the original output tensor.
struct HoistAwaitFuturePass
    : public HoistAwaitFuturePassBase<HoistAwaitFuturePass> {
  // Checks if all values of `a` are sizes of non-dynamic dimensions
  bool allStatic(llvm::ArrayRef<int64_t> a) {
    return llvm::all_of(
        a, [](int64_t r) { return !mlir::ShapedType::isDynamic(r); });
  }

  void runOnOperation() override {
    mlir::func::FuncOp func = getOperation();

    // Collect all candidates before modifying any IR: the
    // transformation moves operations between blocks, which must not
    // happen while a walk is iterating over these very blocks
    llvm::SmallVector<mlir::concretelang::RT::AwaitFutureOp> awaitFutureOps;
    func.walk([&](mlir::concretelang::RT::AwaitFutureOp awaitFutureOp) {
      awaitFutureOps.push_back(awaitFutureOp);
    });

    llvm::SmallVector<mlir::Operation *> opsToErase;

    for (mlir::concretelang::RT::AwaitFutureOp awaitFutureOp : awaitFutureOps)
      hoistAwaitFuture(awaitFutureOp, opsToErase);

    for (mlir::Operation *op : opsToErase)
      op->erase();
  }

  // Attempts to hoist `awaitFutureOp` out of its parent `scf.forall`
  // operation. If the preconditions are not met, the IR is left
  // untouched. Upon success, the original `scf.forall` operation is
  // emptied and appended to `opsToErase`; erasing it is left to the
  // caller, since operations processed later may have been moved out
  // of it by this function.
  void hoistAwaitFuture(mlir::concretelang::RT::AwaitFutureOp awaitFutureOp,
                        llvm::SmallVectorImpl<mlir::Operation *> &opsToErase) {
    // Make sure there are no other consumers that rely on the
    // synchronization
    if (!awaitFutureOp.getResult().hasOneUse())
      return;

    mlir::scf::ForallOp forallOp =
        llvm::dyn_cast<mlir::scf::ForallOp>(awaitFutureOp->getParentOp());

    if (!forallOp)
      return;

    mlir::tensor::ParallelInsertSliceOp parallelInsertSliceOp =
        llvm::dyn_cast<mlir::tensor::ParallelInsertSliceOp>(
            awaitFutureOp.getResult().getUses().begin()->getOwner());

    if (!parallelInsertSliceOp)
      return;

    // Make sure that the original tensor into which the
    // synchronized values are inserted is a region out argument of
    // the forall op and thus being written to concurrently
    mlir::Value dst = parallelInsertSliceOp.getDest();

    if (!llvm::any_of(forallOp.getRegionOutArgs(),
                      [=](mlir::Value output) { return output == dst; }))
      return;

    // Currently, the tensor storing the futures must have a static
    // shape, so only loops with static trip counts are supported
    if (!(allStatic(forallOp.getStaticLowerBound()) &&
          allStatic(forallOp.getStaticUpperBound()) &&
          allStatic(forallOp.getStaticStep())))
      return;

    // The `tensor.parallel_insert_slice` operation is cloned into the
    // synchronization loop further below, where only the induction
    // variables, the destination and the awaited value are
    // remapped. Any other operand (offsets, sizes, strides) must
    // therefore be defined outside of the original loop, otherwise
    // the clone would reference a value local to another loop.
    const bool operandsUsableInSyncLoop = llvm::all_of(
        parallelInsertSliceOp->getOperands(), [&](mlir::Value operand) {
          if (operand == awaitFutureOp.getResult() || operand == dst)
            return true;

          if (llvm::is_contained(forallOp.getInductionVars(), operand))
            return true;

          return !forallOp.getRegion().isAncestor(operand.getParentRegion());
        });

    if (!operandsUsableInSyncLoop)
      return;

    llvm::SmallVector<int64_t> tripCounts;

    for (auto [lb, ub, step] : llvm::zip_equal(forallOp.getStaticLowerBound(),
                                               forallOp.getStaticUpperBound(),
                                               forallOp.getStaticStep())) {
      tripCounts.push_back(
          mlir::concretelang::getStaticTripCount(lb, ub, step));
    }

    mlir::IRRewriter rewriter(&getContext());
    rewriter.setInsertionPoint(forallOp);

    mlir::Value tensorOfFutures = rewriter.create<mlir::tensor::EmptyOp>(
        forallOp.getLoc(), tripCounts, awaitFutureOp.getInput().getType());

    // Assemble the list of shared outputs that are to be preserved
    // after the output storing the results of the `RT.await_future`
    // has been removed
    llvm::SmallVector<mlir::Value> newOutputs;
    mlir::Value tensorOfValues;
    size_t outputIdx = 0;

    // Always overwritten in the loop below, since `dst` is known to
    // be one of the region out arguments at this point
    size_t oldResultIdx = 0;

    for (auto [output, regionOutArg] : llvm::zip_equal(
             forallOp.getOutputs(), forallOp.getRegionOutArgs())) {
      if (regionOutArg != dst) {
        newOutputs.push_back(output);
      } else {
        tensorOfValues = output;
        oldResultIdx = outputIdx;
      }

      outputIdx++;
    }

    newOutputs.push_back(tensorOfFutures);

    // Create a new forall loop with the same shared outputs except
    // for the one previously storing the contents of the
    // `RT.await_future` ops is replaced with a tensor of futures
    rewriter.setInsertionPointAfter(forallOp);
    mlir::scf::ForallOp newForallOp = rewriter.create<mlir::scf::ForallOp>(
        forallOp.getLoc(), forallOp.getMixedLowerBound(),
        forallOp.getMixedUpperBound(), forallOp.getMixedStep(), newOutputs,
        std::nullopt);

    // Move all operations from the old forall op to the new one
    // (except for the terminator, which is handled separately below)
    auto &newOperations = newForallOp.getBody()->getOperations();
    mlir::Block *oldBody = forallOp.getBody();

    newOperations.splice(newOperations.begin(), oldBody->getOperations(),
                         oldBody->begin(), std::prev(oldBody->end()));

    // Wrap future in a tensor of one element, so that it can be
    // stored in the new shared output tensor of futures using
    // `tensor.parallel_insert_slice`
    rewriter.setInsertionPointAfter(awaitFutureOp);
    mlir::Value futureAsTensor =
        rewriter.create<mlir::tensor::FromElementsOp>(
            awaitFutureOp.getLoc(),
            mlir::ValueRange{awaitFutureOp.getInput()});

    // Move all operations from the old `scf.forall.in_parallel`
    // terminator to the new one
    mlir::scf::InParallelOp oldTerminator = forallOp.getTerminator();
    mlir::scf::InParallelOp newTerminator = newForallOp.getTerminator();

    mlir::Block::OpListType &oldTerminatorOps =
        oldTerminator.getRegion().getBlocks().begin()->getOperations();
    mlir::Block::OpListType &newTerminatorOps =
        newTerminator.getRegion().getBlocks().begin()->getOperations();

    newTerminatorOps.splice(newTerminatorOps.begin(), oldTerminatorOps,
                            oldTerminatorOps.begin(), oldTerminatorOps.end());

    // Remap IVs and out args
    for (auto [oldIV, newIV] : llvm::zip(forallOp.getInductionVars(),
                                         newForallOp.getInductionVars())) {
      oldIV.replaceAllUsesWith(newIV);
    }

    {
      size_t offs = 0;

      for (auto it : llvm::enumerate(forallOp.getRegionOutArgs())) {
        mlir::Value oldRegionOutArg = it.value();

        if (oldRegionOutArg != dst) {
          oldRegionOutArg.replaceAllUsesWith(
              newForallOp.getRegionOutArgs()[it.index() - offs]);
        } else {
          offs++;
        }
      }
    }

    // Create new `tensor.parallel_insert_slice` operation inserting
    // the future into the tensor of futures
    llvm::SmallVector<mlir::OpFoldResult> ones(tripCounts.size(),
                                               rewriter.getI64IntegerAttr(1));

    mlir::Value tensorOfFuturesRegionOutArg =
        newForallOp.getRegionOutArgs().back();

    // The implicit-location builder copies the current insertion
    // point of the rewriter, i.e., the position in the loop body right
    // after the wrapped future. It must be created before the
    // rewriter is repositioned into the terminator, such that any
    // operation normalizing the induction variables ends up in the
    // loop body and not in the terminator
    mlir::ImplicitLocOpBuilder ilob(parallelInsertSliceOp.getLoc(), rewriter);

    rewriter.setInsertionPointAfter(parallelInsertSliceOp);
    rewriter.create<mlir::tensor::ParallelInsertSliceOp>(
        parallelInsertSliceOp.getLoc(), futureAsTensor,
        tensorOfFuturesRegionOutArg,
        mlir::getAsOpFoldResult(mlir::concretelang::normalizeInductionVars(
            ilob, newForallOp.getInductionVars(),
            newForallOp.getMixedLowerBound(), newForallOp.getMixedStep())),
        ones, ones);

    // Create a new forall loop, that invokes `RT.await_future` on
    // all futures stored in the tensor of futures and writes the
    // contents into the original tensor with the results
    rewriter.setInsertionPointAfter(newForallOp);
    mlir::scf::ForallOp syncForallOp = rewriter.create<mlir::scf::ForallOp>(
        forallOp.getLoc(), forallOp.getMixedLowerBound(),
        forallOp.getMixedUpperBound(), forallOp.getMixedStep(),
        mlir::ValueRange{tensorOfValues}, std::nullopt);

    mlir::Value resultTensorOfFutures = newForallOp.getResults().back();

    rewriter.setInsertionPointToStart(syncForallOp.getBody());

    // The futures have been stored at the positions designated by the
    // normalized induction variables, so the same normalization must
    // be applied when reading them back. Using the raw induction
    // variables would address the wrong elements for loops with a
    // non-zero lower bound or a non-unit step.
    mlir::ImplicitLocOpBuilder syncIlob(awaitFutureOp.getLoc(), rewriter);
    auto normalizedSyncIVs = mlir::concretelang::normalizeInductionVars(
        syncIlob, syncForallOp.getInductionVars(),
        syncForallOp.getMixedLowerBound(), syncForallOp.getMixedStep());

    mlir::Value extractedFuture = rewriter.create<mlir::tensor::ExtractOp>(
        awaitFutureOp.getLoc(), resultTensorOfFutures,
        mlir::ValueRange(normalizedSyncIVs));

    mlir::concretelang::RT::AwaitFutureOp newAwaitFutureOp =
        rewriter.create<mlir::concretelang::RT::AwaitFutureOp>(
            awaitFutureOp.getLoc(), awaitFutureOp.getResult().getType(),
            extractedFuture);

    // Clone the original `tensor.parallel_insert_slice` operation
    // into the terminator of the synchronization loop with the
    // induction variables, the destination and the inserted value
    // replaced with their counterparts from the synchronization loop
    mlir::IRMapping syncMapping;

    for (auto [oldIV, newIV] :
         llvm::zip_equal(newForallOp.getInductionVars(),
                         syncForallOp.getInductionVars())) {
      syncMapping.map(oldIV, newIV);
    }

    syncMapping.map(dst, syncForallOp.getOutputBlockArguments().back());
    syncMapping.map(parallelInsertSliceOp.getSource(),
                    newAwaitFutureOp.getResult());

    mlir::scf::InParallelOp syncTerminator = syncForallOp.getTerminator();
    rewriter.setInsertionPointToStart(syncTerminator.getBody());
    rewriter.clone(*parallelInsertSliceOp.getOperation(), syncMapping);

    // Replace uses of the results of the original forall loop with:
    // either the corresponding result from the new forall loop if
    // this is a result unrelated to the futures or with the result
    // of the forall loop synchronizing the futures
    {
      size_t offs = 0;

      for (size_t resultIdx = 0; resultIdx < forallOp.getNumResults();
           resultIdx++) {
        if (resultIdx == oldResultIdx) {
          forallOp.getResult(resultIdx)
              .replaceAllUsesWith(syncForallOp.getResult(0));
          offs = 1;
        } else {
          forallOp.getResult(resultIdx)
              .replaceAllUsesWith(newForallOp.getResult(resultIdx - offs));
        }
      }
    }

    // Replace the use of the shared output with the results of the
    // original forall loop with the tensor outside of the loop so
    // that there are no more references to values that were local
    // to the original forall loop, enabling safe erasing of the old
    // operations within the original forall loop
    dst.replaceAllUsesWith(
        forallOp.getOutputs().drop_front(oldResultIdx).front());

    parallelInsertSliceOp->erase();
    awaitFutureOp.erase();

    // Defer erasing the original parallel loop that contained the
    // `RT.await_future` operation to the caller
    opsToErase.push_back(forallOp);
  }
};
} // namespace
namespace mlir::concretelang {
/// Instantiates the pass defined by `HoistAwaitFuturePass` and hands
/// ownership of the new instance over to the caller.
std::unique_ptr<OperationPass<func::FuncOp>> createHoistAwaitFuturePass() {
  std::unique_ptr<OperationPass<func::FuncOp>> pass =
      std::make_unique<HoistAwaitFuturePass>();

  return pass;
}
} // namespace mlir::concretelang