feat(compiler): Add reinstantiating rewrite patterns for scf.forall and scf.forall.in_parallel

This commit is contained in:
Andi Drebes
2023-09-19 06:44:27 +02:00
parent ff20f88c44
commit 64eaeb068e
2 changed files with 107 additions and 0 deletions

View File

@@ -23,6 +23,26 @@ TypeConvertingReinstantiationPattern<scf::ForOp, false>::matchAndRewrite(
scf::ForOp oldOp, mlir::OpConversionPattern<scf::ForOp>::OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const;
//
// Specializations for ForallOp
//
template <>
mlir::LogicalResult
TypeConvertingReinstantiationPattern<scf::ForallOp, false>::matchAndRewrite(
scf::ForallOp oldOp,
mlir::OpConversionPattern<scf::ForallOp>::OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const;
//
// Specializations for InParallelOp
//
template <>
mlir::LogicalResult
TypeConvertingReinstantiationPattern<scf::InParallelOp, false>::matchAndRewrite(
scf::InParallelOp oldOp,
mlir::OpConversionPattern<scf::InParallelOp>::OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const;
} // namespace concretelang
} // namespace mlir

View File

@@ -39,5 +39,92 @@ TypeConvertingReinstantiationPattern<scf::ForOp, false>::matchAndRewrite(
return mlir::success();
}
//
// Specializations for ForallOp
//
template <>
mlir::LogicalResult
TypeConvertingReinstantiationPattern<scf::ForallOp, false>::matchAndRewrite(
scf::ForallOp oldOp,
mlir::OpConversionPattern<scf::ForallOp>::OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const {
// Create new forall operation with empty body, but converted iter
// args
llvm::SmallVector<mlir::OpFoldResult> lbs = getMixedValues(
adaptor.getStaticLowerBound(), adaptor.getDynamicLowerBound(), rewriter);
llvm::SmallVector<mlir::OpFoldResult> ubs = getMixedValues(
adaptor.getStaticUpperBound(), adaptor.getDynamicUpperBound(), rewriter);
llvm::SmallVector<mlir::OpFoldResult> step = getMixedValues(
adaptor.getStaticStep(), adaptor.getDynamicStep(), rewriter);
rewriter.setInsertionPoint(oldOp);
scf::ForallOp newForallOp = rewriter.create<scf::ForallOp>(
oldOp.getLoc(), lbs, ubs, step, adaptor.getOutputs(),
adaptor.getMapping());
newForallOp->setAttrs(adaptor.getAttributes());
// Move operations from old for op to new one
auto &newOperations = newForallOp.getBody()->getOperations();
mlir::Block *oldBody = oldOp.getBody();
newOperations.splice(newOperations.begin(), oldBody->getOperations(),
oldBody->begin(), std::prev(oldBody->end()));
// Move operations from `scf.forall.in_parallel` terminator of the
// old op to the terminator of the new op
mlir::scf::InParallelOp oldTerminator =
llvm::dyn_cast<mlir::scf::InParallelOp>(*std::prev(oldBody->end()));
assert(oldTerminator && "Last operation of `scf.forall` op expected be a "
"`scf.forall.in_parallel` op");
mlir::scf::InParallelOp newTerminator = newForallOp.getTerminator();
mlir::Block::OpListType &oldTerminatorOps =
oldTerminator.getRegion().getBlocks().begin()->getOperations();
mlir::Block::OpListType &newTerminatorOps =
newTerminator.getRegion().getBlocks().begin()->getOperations();
newTerminatorOps.splice(newTerminatorOps.begin(), oldTerminatorOps,
oldTerminatorOps.begin(), oldTerminatorOps.end());
// Remap iter args and IV
for (auto argsPair : llvm::zip(oldOp.getBody()->getArguments(),
newForallOp.getBody()->getArguments())) {
std::get<0>(argsPair).replaceAllUsesWith(std::get<1>(argsPair));
}
rewriter.replaceOp(oldOp, newForallOp.getResults());
return mlir::success();
}
//
// Specializations for InParallelOp
//
template <>
mlir::LogicalResult
TypeConvertingReinstantiationPattern<scf::InParallelOp, false>::matchAndRewrite(
scf::InParallelOp oldOp,
mlir::OpConversionPattern<scf::InParallelOp>::OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const {
// Create new for loop with empty body, but converted iter args
scf::InParallelOp newInParallelOp =
rewriter.replaceOpWithNewOp<scf::InParallelOp>(oldOp);
// Move operations from old for op to new one
auto &newOperations = newInParallelOp.getBody()->getOperations();
mlir::Block *oldBody = oldOp.getBody();
newOperations.splice(newOperations.begin(), oldBody->getOperations(),
oldBody->begin(), oldBody->end());
return mlir::success();
}
} // namespace concretelang
} // namespace mlir