From e36d1665ca2f816212fc80ee2633caa66a0066bf Mon Sep 17 00:00:00 2001 From: Thomas Raoux Date: Wed, 18 Oct 2023 08:37:13 -0700 Subject: [PATCH] [BACKEND] Fix unsupported view op created during optimizations (#2510) When propagating layout we were generating a view op with mismatching total number of element per threads. Lowering such op would require exchanging data across threads. This change prevents the optimizer from generating such cases. This may require further optimizations in the future. --- include/triton/Dialect/TritonGPU/IR/Dialect.h | 3 ++ .../TritonGPUToLLVM/ViewOpToLLVM.cpp | 2 ++ lib/Dialect/TritonGPU/IR/Dialect.cpp | 20 +++++++---- lib/Dialect/TritonGPU/Transforms/Utility.cpp | 9 ++++- test/TritonGPU/canonicalize.mlir | 34 +++++++++++++++++++ 5 files changed, 61 insertions(+), 7 deletions(-) create mode 100644 test/TritonGPU/canonicalize.mlir diff --git a/include/triton/Dialect/TritonGPU/IR/Dialect.h b/include/triton/Dialect/TritonGPU/IR/Dialect.h index 24d971701..998c6ecb9 100644 --- a/include/triton/Dialect/TritonGPU/IR/Dialect.h +++ b/include/triton/Dialect/TritonGPU/IR/Dialect.h @@ -107,6 +107,9 @@ bool isSharedEncoding(Value value); bool isExpensiveCat(CatOp cat, Attribute targetEncoding); +// Return true if a view between the two types cannot be implemented as a no-op. +bool isExpensiveView(Type srcType, Type dstType); + } // namespace gpu } // namespace triton } // namespace mlir diff --git a/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp index fdd47f2de..001c83783 100644 --- a/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp @@ -147,6 +147,8 @@ struct ViewOpConversion : public ConvertTritonGPUOpToLLVMPattern { matchAndRewrite(ViewOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = op->getLoc(); + assert(!triton::gpu::isExpensiveView(op.getSrc().getType(), op.getType()) && + "expensive view not supported"); auto resultTy = op.getType().template cast(); auto vals = this->getTypeConverter()->unpackLLElements( loc, adaptor.getSrc(), rewriter, op.getOperand().getType()); diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index 7bd34cd09..f9411221d 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -339,6 +339,10 @@ SmallVector getShapePerCTATile(Attribute layout, return shape; } +bool isExpensiveView(Type srcType, Type dstType) { + return getTotalElemsPerThread(srcType) != getTotalElemsPerThread(dstType); +} + namespace { /* Utility function used by getOrder and getCTAOrder of SliceEncodingAttr. @@ -1533,13 +1537,15 @@ struct CanonicalizeConvertFromView Operation *arg = op->getOperand(0).getDefiningOp(); if (!arg) return mlir::failure(); + auto convert = dyn_cast(arg); + if (!convert) + return failure(); + if (isExpensiveView(convert.getOperand().getType(), op.getType())) + return failure(); // view(convert) -> view - if (auto convert = dyn_cast(arg)) { - rewriter.replaceOpWithNewOp( - op, op->getResult(0).getType(), convert.getOperand()); - return mlir::success(); - } - return mlir::failure(); + rewriter.replaceOpWithNewOp(op, op->getResult(0).getType(), + convert.getOperand()); + return mlir::success(); } }; @@ -1584,6 +1590,8 @@ struct CanonicalizeConvertFromConvert return mlir::failure(); // cvt(view) -> view if (auto view = dyn_cast(arg)) { + if (isExpensiveView(view.getOperand().getType(), op.getType())) + return failure(); rewriter.replaceOpWithNewOp( op, op->getResult(0).getType(), view.getResult()); return mlir::success(); diff --git a/lib/Dialect/TritonGPU/Transforms/Utility.cpp b/lib/Dialect/TritonGPU/Transforms/Utility.cpp index f315fe5ad..8101ed45a 100644 --- a/lib/Dialect/TritonGPU/Transforms/Utility.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Utility.cpp @@ -342,8 +342,15 @@ bool canFoldIntoConversion(Operation *op, Attribute targetEncoding) { } return true; } + if (auto view = dyn_cast(op)) { + auto viewDstType = view.getType().cast(); + RankedTensorType newDstType = RankedTensorType::get( + viewDstType.getShape(), viewDstType.getElementType(), targetEncoding); + return !triton::gpu::isExpensiveView(view.getOperand().getType(), + newDstType); + } return isa(op); + triton::MakeRangeOp, triton::SplatOp>(op); } // diff --git a/test/TritonGPU/canonicalize.mlir b/test/TritonGPU/canonicalize.mlir new file mode 100644 index 000000000..cb2bba970 --- /dev/null +++ b/test/TritonGPU/canonicalize.mlir @@ -0,0 +1,34 @@ +// RUN: triton-opt %s -split-input-file -canonicalize | FileCheck %s + + +// CHECK-LABEL: @test_canonicalize_convert_view +// CHECK-SAME: (%[[ARG:.+]]: tensor<64x64xf32 +// CHECK-NOT: triton_gpu.convert_layout +// CHECK: %[[V:.+]] = tt.view %[[ARG]] +// CHECK: tt.return %[[V]] +#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [8, 1], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +tt.func @test_canonicalize_convert_view(%arg0: tensor<64x64xf32, #blocked0>) -> tensor<4096xf32, #blocked1> { + %c = triton_gpu.convert_layout %arg0 : (tensor<64x64xf32, #blocked0>) -> tensor<64x64xf32, #blocked2> + %r = tt.view %c : (tensor<64x64xf32, #blocked2>) -> tensor<4096xf32, #blocked1> + tt.return %r : tensor<4096xf32, #blocked1> +} + +// ----- + +// test that the convert doesn't get combined with view if the resulting operations +// is an expensive view which would require moving data across threads. +// CHECK-LABEL: @test_canonicalize_convert_expensive_view +// CHECK-SAME: (%[[ARG:.+]]: tensor<256x16xf32 +// CHECK: %[[C:.+]] = triton_gpu.convert_layout %[[ARG]] +// CHECK: %[[V:.+]] = tt.view %[[C]] +// CHECK: tt.return %[[V]] +#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [8, 1], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +tt.func @test_canonicalize_convert_expensive_view(%arg0: tensor<256x16xf32, #blocked0>) -> tensor<4096xf32, #blocked1> { + %c = triton_gpu.convert_layout %arg0 : (tensor<256x16xf32, #blocked0>) -> tensor<256x16xf32, #blocked2> + %r = tt.view %c : (tensor<256x16xf32, #blocked2>) -> tensor<4096xf32, #blocked1> + tt.return %r : tensor<4096xf32, #blocked1> +}