refactor(compiler): Use signature conversion for conversion of ops with nested blocks

The current scheme used by reinstantiating conversion patterns in
`lib/Conversion/Utils/Dialects` for operations with blocks is to
create a new operation with empty blocks, to move the operations from
the old blocks and then to replace any references to block
arguments. However, such in-place updates of the types of block
arguments leave conversion patterns for operations nested in the
blocks without the ability to determine the original types of values
from before the update.

This change uses proper signature conversion for block arguments, such
that the original types of block arguments with converted types is
preserved, while the new types are made available through the dialect
conversion infrastructure via the respective adaptors.
This commit is contained in:
Andi Drebes
2024-04-03 12:07:03 +02:00
parent 48d919bd25
commit 3ad3dcb08f
4 changed files with 48 additions and 92 deletions

View File

@@ -22,6 +22,12 @@ mlir::Value globalMemrefFromArrayAttr(mlir::RewriterBase &rewriter,
mlir::Location loc,
mlir::ArrayAttr arrAttr);
mlir::Operation *convertOpWithBlocks(mlir::Operation *op,
mlir::ValueRange newOperands,
mlir::TypeRange newResultTypes,
mlir::TypeConverter &typeConverter,
mlir::ConversionPatternRewriter &rewriter);
} // namespace concretelang
} // namespace mlir
#endif

View File

@@ -4,7 +4,7 @@
// for license information.
#include "concretelang/Conversion/Utils/Dialects/SCF.h"
#include "mlir/Transforms/RegionUtils.h"
#include "concretelang/Conversion/Utils/Utils.h"
namespace mlir {
namespace concretelang {
@@ -13,29 +13,16 @@ mlir::LogicalResult
TypeConvertingReinstantiationPattern<scf::ForOp, false>::matchAndRewrite(
scf::ForOp oldOp, mlir::OpConversionPattern<scf::ForOp>::OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const {
// Create new for loop with empty body, but converted iter args
scf::ForOp newForOp = rewriter.create<scf::ForOp>(
oldOp.getLoc(), adaptor.getLowerBound(), adaptor.getUpperBound(),
adaptor.getStep(), adaptor.getInitArgs(),
[&](OpBuilder &builder, Location loc, Value iv, ValueRange args) {});
mlir::TypeConverter &typeConverter = *getTypeConverter();
llvm::SmallVector<mlir::Type> convertedResultTypes;
newForOp->setAttrs(adaptor.getAttributes());
// Move operations from old for op to new one
auto &newOperations = newForOp.getBody()->getOperations();
mlir::Block *oldBody = oldOp.getBody();
newOperations.splice(newOperations.begin(), oldBody->getOperations(),
oldBody->begin(), oldBody->end());
// Remap iter args and IV
for (auto argsPair : llvm::zip(oldOp.getBody()->getArguments(),
newForOp.getBody()->getArguments())) {
replaceAllUsesInRegionWith(std::get<0>(argsPair), std::get<1>(argsPair),
newForOp.getRegion());
if (typeConverter.convertTypes(oldOp.getResultTypes(), convertedResultTypes)
.failed()) {
return mlir::failure();
}
rewriter.replaceOp(oldOp, newForOp.getResults());
convertOpWithBlocks(oldOp, adaptor.getOperands(), convertedResultTypes,
typeConverter, rewriter);
return mlir::success();
}
@@ -49,56 +36,10 @@ TypeConvertingReinstantiationPattern<scf::ForallOp, false>::matchAndRewrite(
scf::ForallOp oldOp,
mlir::OpConversionPattern<scf::ForallOp>::OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const {
// Create new forall operation with empty body, but converted iter
// args
llvm::SmallVector<mlir::OpFoldResult> lbs = getMixedValues(
adaptor.getStaticLowerBound(), adaptor.getDynamicLowerBound(), rewriter);
llvm::SmallVector<mlir::OpFoldResult> ubs = getMixedValues(
adaptor.getStaticUpperBound(), adaptor.getDynamicUpperBound(), rewriter);
llvm::SmallVector<mlir::OpFoldResult> step = getMixedValues(
adaptor.getStaticStep(), adaptor.getDynamicStep(), rewriter);
rewriter.setInsertionPoint(oldOp);
scf::ForallOp newForallOp = rewriter.create<scf::ForallOp>(
oldOp.getLoc(), lbs, ubs, step, adaptor.getOutputs(),
adaptor.getMapping());
newForallOp->setAttrs(adaptor.getAttributes());
// Move operations from old for op to new one
auto &newOperations = newForallOp.getBody()->getOperations();
mlir::Block *oldBody = oldOp.getBody();
newOperations.splice(newOperations.begin(), oldBody->getOperations(),
oldBody->begin(), std::prev(oldBody->end()));
// Move operations from `scf.forall.in_parallel` terminator of the
// old op to the terminator of the new op
mlir::scf::InParallelOp oldTerminator =
llvm::dyn_cast<mlir::scf::InParallelOp>(*std::prev(oldBody->end()));
assert(oldTerminator && "Last operation of `scf.forall` op expected be a "
"`scf.forall.in_parallel` op");
mlir::scf::InParallelOp newTerminator = newForallOp.getTerminator();
mlir::Block::OpListType &oldTerminatorOps =
oldTerminator.getRegion().getBlocks().begin()->getOperations();
mlir::Block::OpListType &newTerminatorOps =
newTerminator.getRegion().getBlocks().begin()->getOperations();
newTerminatorOps.splice(newTerminatorOps.begin(), oldTerminatorOps,
oldTerminatorOps.begin(), oldTerminatorOps.end());
// Remap iter args and IV
for (auto argsPair : llvm::zip(oldOp.getBody()->getArguments(),
newForallOp.getBody()->getArguments())) {
std::get<0>(argsPair).replaceAllUsesWith(std::get<1>(argsPair));
}
rewriter.replaceOp(oldOp, newForallOp.getResults());
convertOpWithBlocks(oldOp, adaptor.getOperands(),
adaptor.getOutputs().getTypes(), *getTypeConverter(),
rewriter);
return mlir::success();
}

View File

@@ -4,8 +4,7 @@
// for license information.
#include "concretelang/Conversion/Utils/Dialects/Tensor.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/STLExtras.h"
#include "concretelang/Conversion/Utils/Utils.h"
namespace mlir {
namespace concretelang {
@@ -79,26 +78,8 @@ TypeConvertingReinstantiationPattern<tensor::GenerateOp, true>::matchAndRewrite(
mlir::ConversionPatternRewriter &rewriter) const {
mlir::SmallVector<mlir::Type> resultTypes = convertResultTypes(oldOp);
rewriter.setInsertionPointAfter(oldOp);
tensor::GenerateOp newGenerateOp = rewriter.create<tensor::GenerateOp>(
oldOp.getLoc(), resultTypes, adaptor.getOperands(), oldOp->getAttrs());
mlir::Block &oldBlock = oldOp.getBody().getBlocks().front();
mlir::Block &newBlock = newGenerateOp.getBody().getBlocks().front();
auto begin = oldBlock.begin();
auto nOps = oldBlock.getOperations().size();
newBlock.getOperations().splice(newBlock.getOperations().begin(),
oldBlock.getOperations(), begin,
std::next(begin, nOps - 1));
for (auto argsPair : llvm::zip(oldOp.getRegion().getArguments(),
newGenerateOp.getRegion().getArguments())) {
replaceAllUsesInRegionWith(std::get<0>(argsPair), std::get<1>(argsPair),
newGenerateOp.getRegion());
}
rewriter.replaceOp(oldOp, newGenerateOp.getResult());
convertOpWithBlocks(oldOp, adaptor.getOperands(), resultTypes,
*getTypeConverter(), rewriter);
return mlir::success();
}

View File

@@ -6,6 +6,7 @@
#include "concretelang/Conversion/Utils/Utils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/Transforms/BufferUtils.h"
#include "mlir/Transforms/RegionUtils.h"
namespace mlir {
namespace concretelang {
@@ -56,5 +57,32 @@ mlir::Value globalMemrefFromArrayAttr(mlir::RewriterBase &rewriter,
return mlir::concretelang::getCastedMemRef(rewriter, globalRef);
}
// Converts an operation `op` with nested blocks using a type
// converter and a conversion pattern rewriter, such that the newly
// created operation uses the operands specified in `newOperands` and
// returns a value of the types `newResultTypes`.
mlir::Operation *
convertOpWithBlocks(mlir::Operation *op, mlir::ValueRange newOperands,
mlir::TypeRange newResultTypes,
mlir::TypeConverter &typeConverter,
mlir::ConversionPatternRewriter &rewriter) {
mlir::OperationState state(op->getLoc(), op->getName().getStringRef(),
newOperands, newResultTypes, op->getAttrs(),
op->getSuccessors());
for (Region &region : op->getRegions()) {
Region *newRegion = state.addRegion();
rewriter.inlineRegionBefore(region, *newRegion, newRegion->begin());
TypeConverter::SignatureConversion result(newRegion->getNumArguments());
(void)typeConverter.convertSignatureArgs(newRegion->getArgumentTypes(),
result);
rewriter.applySignatureConversion(newRegion, result);
}
Operation *newOp = rewriter.create(state);
rewriter.replaceOp(op, newOp->getResults());
return newOp;
}
} // namespace concretelang
} // namespace mlir