[BACKEND] Fix unsupported view op created during optimizations (#2510)
When propagating layouts we were generating a view op with a mismatching total number of elements per thread. Lowering such an op would require exchanging data across threads. This change prevents the optimizer from generating such cases. This may require further optimizations in the future.
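To make the constraint concrete, below is a minimal standalone sketch (not Triton's actual implementation) of the check this commit introduces. It assumes the usual blocked-layout bookkeeping: each dimension contributes ceil(shape[d] / (sizePerThread[d] * threadsPerWarp[d] * warpsPerCTA[d])) * sizePerThread[d] elements per thread, and a view is "expensive" when the per-thread totals of source and result differ, because the reshape can then no longer be a pure register reinterpretation. The BlockedLayout struct and helper names here are hypothetical; the layout parameters are taken from the test file added by this commit.

#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical stand-in for a TritonGPU blocked encoding (CTA splits omitted).
struct BlockedLayout {
  std::vector<int64_t> sizePerThread, threadsPerWarp, warpsPerCTA;
};

// Assumed per-thread element count: each dimension contributes
// ceil(shape[d] / tile[d]) * sizePerThread[d], where tile[d] is the shape
// covered by one CTA tile; the total is the product over dimensions.
int64_t totalElemsPerThread(const std::vector<int64_t> &shape,
                            const BlockedLayout &l) {
  int64_t total = 1;
  for (size_t d = 0; d < shape.size(); ++d) {
    int64_t tile = l.sizePerThread[d] * l.threadsPerWarp[d] * l.warpsPerCTA[d];
    total *= ((shape[d] + tile - 1) / tile) * l.sizePerThread[d];
  }
  return total;
}

// Mirrors the spirit of triton::gpu::isExpensiveView from this commit:
// the view is free only if both sides keep the same elements per thread.
bool isExpensiveView(const std::vector<int64_t> &srcShape,
                     const BlockedLayout &src,
                     const std::vector<int64_t> &dstShape,
                     const BlockedLayout &dst) {
  return totalElemsPerThread(srcShape, src) !=
         totalElemsPerThread(dstShape, dst);
}

int main() {
  BlockedLayout blocked0{{1, 8}, {4, 8}, {8, 1}}; // #blocked0 in the tests
  BlockedLayout blocked1{{1}, {32}, {8}};         // #blocked1 in the tests
  // 64x64 -> 4096: 16 elements per thread on both sides -> cheap view.
  std::cout << isExpensiveView({64, 64}, blocked0, {4096}, blocked1) << "\n"; // 0
  // 256x16 -> 4096: 64 vs 16 elements per thread -> expensive view.
  std::cout << isExpensiveView({256, 16}, blocked0, {4096}, blocked1) << "\n"; // 1
}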
@@ -107,6 +107,9 @@ bool isSharedEncoding(Value value);
 
 bool isExpensiveCat(CatOp cat, Attribute targetEncoding);
 
+// Return true if a view between the two types cannot be implemented as a no-op.
+bool isExpensiveView(Type srcType, Type dstType);
+
 } // namespace gpu
 } // namespace triton
 } // namespace mlir
@@ -147,6 +147,8 @@ struct ViewOpConversion : public ConvertTritonGPUOpToLLVMPattern<ViewOp> {
   matchAndRewrite(ViewOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Location loc = op->getLoc();
+    assert(!triton::gpu::isExpensiveView(op.getSrc().getType(), op.getType()) &&
+           "expensive view not supported");
     auto resultTy = op.getType().template cast<RankedTensorType>();
     auto vals = this->getTypeConverter()->unpackLLElements(
         loc, adaptor.getSrc(), rewriter, op.getOperand().getType());
@@ -339,6 +339,10 @@ SmallVector<unsigned> getShapePerCTATile(Attribute layout,
   return shape;
 }
 
+bool isExpensiveView(Type srcType, Type dstType) {
+  return getTotalElemsPerThread(srcType) != getTotalElemsPerThread(dstType);
+}
+
 namespace {
 
 /* Utility function used by getOrder and getCTAOrder of SliceEncodingAttr.
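In words: getTotalElemsPerThread counts how many tensor elements each thread owns under a given layout, so the comparison above treats a view as free exactly when every thread can reinterpret the registers it already holds. Any mismatch would force data exchange across threads, which is the situation the commit message rules out.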
@@ -1533,13 +1537,15 @@ struct CanonicalizeConvertFromView
     Operation *arg = op->getOperand(0).getDefiningOp();
     if (!arg)
       return mlir::failure();
+    auto convert = dyn_cast<ConvertLayoutOp>(arg);
+    if (!convert)
+      return failure();
+    if (isExpensiveView(convert.getOperand().getType(), op.getType()))
+      return failure();
     // view(convert) -> view
-    if (auto convert = dyn_cast<ConvertLayoutOp>(arg)) {
-      rewriter.replaceOpWithNewOp<triton::ViewOp>(
-          op, op->getResult(0).getType(), convert.getOperand());
-      return mlir::success();
-    }
-    return mlir::failure();
+    rewriter.replaceOpWithNewOp<triton::ViewOp>(op, op->getResult(0).getType(),
+                                                convert.getOperand());
+    return mlir::success();
   }
 };
 
@@ -1584,6 +1590,8 @@ struct CanonicalizeConvertFromConvert
       return mlir::failure();
     // cvt(view) -> view
     if (auto view = dyn_cast<triton::ViewOp>(arg)) {
+      if (isExpensiveView(view.getOperand().getType(), op.getType()))
+        return failure();
       rewriter.replaceOpWithNewOp<triton::ViewOp>(
           op, op->getResult(0).getType(), view.getResult());
       return mlir::success();
@@ -342,8 +342,15 @@ bool canFoldIntoConversion(Operation *op, Attribute targetEncoding) {
     }
     return true;
   }
+  if (auto view = dyn_cast<triton::ViewOp>(op)) {
+    auto viewDstType = view.getType().cast<RankedTensorType>();
+    RankedTensorType newDstType = RankedTensorType::get(
+        viewDstType.getShape(), viewDstType.getElementType(), targetEncoding);
+    return !triton::gpu::isExpensiveView(view.getOperand().getType(),
+                                         newDstType);
+  }
   return isa<triton::gpu::ConvertLayoutOp, arith::ConstantOp,
-             triton::MakeRangeOp, triton::SplatOp, triton::ViewOp>(op);
+             triton::MakeRangeOp, triton::SplatOp>(op);
 }
 
 //
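Note the design choice in the fold check: before asking isExpensiveView, it rebuilds the view's result type with the candidate target encoding. The cost question is about the view that would exist after folding the conversion, not the one currently in the IR, which is also why tt.view drops out of the blanket isa<> list below.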
34  test/TritonGPU/canonicalize.mlir  Normal file
@@ -0,0 +1,34 @@
+// RUN: triton-opt %s -split-input-file -canonicalize | FileCheck %s
+
+
+// CHECK-LABEL: @test_canonicalize_convert_view
+// CHECK-SAME: (%[[ARG:.+]]: tensor<64x64xf32
+// CHECK-NOT: triton_gpu.convert_layout
+// CHECK: %[[V:.+]] = tt.view %[[ARG]]
+// CHECK: tt.return %[[V]]
+#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
+#blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
+#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [8, 1], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
+tt.func @test_canonicalize_convert_view(%arg0: tensor<64x64xf32, #blocked0>) -> tensor<4096xf32, #blocked1> {
+  %c = triton_gpu.convert_layout %arg0 : (tensor<64x64xf32, #blocked0>) -> tensor<64x64xf32, #blocked2>
+  %r = tt.view %c : (tensor<64x64xf32, #blocked2>) -> tensor<4096xf32, #blocked1>
+  tt.return %r : tensor<4096xf32, #blocked1>
+}
+
+// -----
+
+// test that the convert doesn't get combined with view if the resulting operation
+// is an expensive view which would require moving data across threads.
+// CHECK-LABEL: @test_canonicalize_convert_expensive_view
+// CHECK-SAME: (%[[ARG:.+]]: tensor<256x16xf32
+// CHECK: %[[C:.+]] = triton_gpu.convert_layout %[[ARG]]
+// CHECK: %[[V:.+]] = tt.view %[[C]]
+// CHECK: tt.return %[[V]]
+#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
+#blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
+#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [8, 1], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
+tt.func @test_canonicalize_convert_expensive_view(%arg0: tensor<256x16xf32, #blocked0>) -> tensor<4096xf32, #blocked1> {
+  %c = triton_gpu.convert_layout %arg0 : (tensor<256x16xf32, #blocked0>) -> tensor<256x16xf32, #blocked2>
+  %r = tt.view %c : (tensor<256x16xf32, #blocked2>) -> tensor<4096xf32, #blocked1>
+  tt.return %r : tensor<4096xf32, #blocked1>
+}
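For reference, under the blocked-layout formula assumed in the sketch above: in the first test the view source tensor<64x64xf32, #blocked0> holds ceil(64/32)*1 * ceil(64/64)*8 = 16 elements per thread and the result tensor<4096xf32, #blocked1> holds ceil(4096/256)*1 = 16, so the view is free and the convert_layout folds away. In the second test the source tensor<256x16xf32, #blocked0> holds ceil(256/32)*1 * ceil(16/64)*8 = 64 elements per thread versus 16 for the result, so the view is expensive and the convert_layout must stay.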