[BACKEND] Fix unsupported view op created during optimizations (#2510)

When propagating layout we were generating a view op with mismatching
total number of element per threads. Lowering such op would require
exchanging data across threads.
This change prevents the optimizer from generating such cases. This may
require further optimizations in the future.
This commit is contained in:
Thomas Raoux
2023-10-18 08:37:13 -07:00
committed by GitHub
parent 768fc1fcd9
commit e36d1665ca
5 changed files with 61 additions and 7 deletions

View File

@@ -107,6 +107,9 @@ bool isSharedEncoding(Value value);
bool isExpensiveCat(CatOp cat, Attribute targetEncoding);
// Return true if a view between the two types cannot be implemented as a no-op.
bool isExpensiveView(Type srcType, Type dstType);
} // namespace gpu
} // namespace triton
} // namespace mlir

View File

@@ -147,6 +147,8 @@ struct ViewOpConversion : public ConvertTritonGPUOpToLLVMPattern<ViewOp> {
matchAndRewrite(ViewOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
assert(!triton::gpu::isExpensiveView(op.getSrc().getType(), op.getType()) &&
"expensive view not supported");
auto resultTy = op.getType().template cast<RankedTensorType>();
auto vals = this->getTypeConverter()->unpackLLElements(
loc, adaptor.getSrc(), rewriter, op.getOperand().getType());

View File

@@ -339,6 +339,10 @@ SmallVector<unsigned> getShapePerCTATile(Attribute layout,
return shape;
}
bool isExpensiveView(Type srcType, Type dstType) {
return getTotalElemsPerThread(srcType) != getTotalElemsPerThread(dstType);
}
namespace {
/* Utility function used by getOrder and getCTAOrder of SliceEncodingAttr.
@@ -1533,13 +1537,15 @@ struct CanonicalizeConvertFromView
Operation *arg = op->getOperand(0).getDefiningOp();
if (!arg)
return mlir::failure();
auto convert = dyn_cast<ConvertLayoutOp>(arg);
if (!convert)
return failure();
if (isExpensiveView(convert.getOperand().getType(), op.getType()))
return failure();
// view(convert) -> view
if (auto convert = dyn_cast<ConvertLayoutOp>(arg)) {
rewriter.replaceOpWithNewOp<triton::ViewOp>(
op, op->getResult(0).getType(), convert.getOperand());
return mlir::success();
}
return mlir::failure();
rewriter.replaceOpWithNewOp<triton::ViewOp>(op, op->getResult(0).getType(),
convert.getOperand());
return mlir::success();
}
};
@@ -1584,6 +1590,8 @@ struct CanonicalizeConvertFromConvert
return mlir::failure();
// cvt(view) -> view
if (auto view = dyn_cast<triton::ViewOp>(arg)) {
if (isExpensiveView(view.getOperand().getType(), op.getType()))
return failure();
rewriter.replaceOpWithNewOp<triton::ViewOp>(
op, op->getResult(0).getType(), view.getResult());
return mlir::success();

View File

@@ -342,8 +342,15 @@ bool canFoldIntoConversion(Operation *op, Attribute targetEncoding) {
}
return true;
}
if (auto view = dyn_cast<triton::ViewOp>(op)) {
auto viewDstType = view.getType().cast<RankedTensorType>();
RankedTensorType newDstType = RankedTensorType::get(
viewDstType.getShape(), viewDstType.getElementType(), targetEncoding);
return !triton::gpu::isExpensiveView(view.getOperand().getType(),
newDstType);
}
return isa<triton::gpu::ConvertLayoutOp, arith::ConstantOp,
triton::MakeRangeOp, triton::SplatOp, triton::ViewOp>(op);
triton::MakeRangeOp, triton::SplatOp>(op);
}
//

View File

@@ -0,0 +1,34 @@
// RUN: triton-opt %s -split-input-file -canonicalize | FileCheck %s
// CHECK-LABEL: @test_canonicalize_convert_view
// CHECK-SAME: (%[[ARG:.+]]: tensor<64x64xf32
// CHECK-NOT: triton_gpu.convert_layout
// CHECK: %[[V:.+]] = tt.view %[[ARG]]
// CHECK: tt.return %[[V]]
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [8, 1], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
tt.func @test_canonicalize_convert_view(%arg0: tensor<64x64xf32, #blocked0>) -> tensor<4096xf32, #blocked1> {
%c = triton_gpu.convert_layout %arg0 : (tensor<64x64xf32, #blocked0>) -> tensor<64x64xf32, #blocked2>
%r = tt.view %c : (tensor<64x64xf32, #blocked2>) -> tensor<4096xf32, #blocked1>
tt.return %r : tensor<4096xf32, #blocked1>
}
// -----
// test that the convert doesn't get combined with view if the resulting operations
// is an expensive view which would require moving data across threads.
// CHECK-LABEL: @test_canonicalize_convert_expensive_view
// CHECK-SAME: (%[[ARG:.+]]: tensor<256x16xf32
// CHECK: %[[C:.+]] = triton_gpu.convert_layout %[[ARG]]
// CHECK: %[[V:.+]] = tt.view %[[C]]
// CHECK: tt.return %[[V]]
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [8, 1], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
tt.func @test_canonicalize_convert_expensive_view(%arg0: tensor<256x16xf32, #blocked0>) -> tensor<4096xf32, #blocked1> {
%c = triton_gpu.convert_layout %arg0 : (tensor<256x16xf32, #blocked0>) -> tensor<256x16xf32, #blocked2>
%r = tt.view %c : (tensor<256x16xf32, #blocked2>) -> tensor<4096xf32, #blocked1>
tt.return %r : tensor<4096xf32, #blocked1>
}