diff --git a/include/triton/Dialect/TritonGPU/IR/Dialect.h b/include/triton/Dialect/TritonGPU/IR/Dialect.h
index 24d971701..998c6ecb9 100644
--- a/include/triton/Dialect/TritonGPU/IR/Dialect.h
+++ b/include/triton/Dialect/TritonGPU/IR/Dialect.h
@@ -107,6 +107,9 @@ bool isSharedEncoding(Value value);
 
 bool isExpensiveCat(CatOp cat, Attribute targetEncoding);
 
+// Return true if a view between the two types cannot be implemented as a no-op.
+bool isExpensiveView(Type srcType, Type dstType);
+
 } // namespace gpu
 } // namespace triton
 } // namespace mlir
diff --git a/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp
index fdd47f2de..001c83783 100644
--- a/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp
@@ -147,6 +147,8 @@ struct ViewOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::ViewOp> {
   matchAndRewrite(ViewOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Location loc = op->getLoc();
+    assert(!triton::gpu::isExpensiveView(op.getSrc().getType(), op.getType()) &&
+           "expensive view not supported");
     auto resultTy = op.getType().template cast<RankedTensorType>();
     auto vals = this->getTypeConverter()->unpackLLElements(
         loc, adaptor.getSrc(), rewriter, op.getOperand().getType());
diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp
index 7bd34cd09..f9411221d 100644
--- a/lib/Dialect/TritonGPU/IR/Dialect.cpp
+++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp
@@ -339,6 +339,10 @@ SmallVector<unsigned> getShapePerCTATile(Attribute layout,
   return shape;
 }
 
+bool isExpensiveView(Type srcType, Type dstType) {
+  return getTotalElemsPerThread(srcType) != getTotalElemsPerThread(dstType);
+}
+
 namespace {
 
 /* Utility function used by getOrder and getCTAOrder of SliceEncodingAttr.
@@ -1533,13 +1537,15 @@ struct CanonicalizeConvertFromView
     Operation *arg = op->getOperand(0).getDefiningOp();
     if (!arg)
       return mlir::failure();
+    auto convert = dyn_cast<ConvertLayoutOp>(arg);
+    if (!convert)
+      return failure();
+    if (isExpensiveView(convert.getOperand().getType(), op.getType()))
+      return failure();
     // view(convert) -> view
-    if (auto convert = dyn_cast<ConvertLayoutOp>(arg)) {
-      rewriter.replaceOpWithNewOp<triton::ViewOp>(
-          op, op->getResult(0).getType(), convert.getOperand());
-      return mlir::success();
-    }
-    return mlir::failure();
+    rewriter.replaceOpWithNewOp<triton::ViewOp>(op, op->getResult(0).getType(),
+                                                convert.getOperand());
+    return mlir::success();
   }
 };
 
@@ -1584,6 +1590,8 @@ struct CanonicalizeConvertFromConvert
       return mlir::failure();
     // cvt(view) -> view
     if (auto view = dyn_cast<triton::ViewOp>(arg)) {
+      if (isExpensiveView(view.getOperand().getType(), op.getType()))
+        return failure();
       rewriter.replaceOpWithNewOp<triton::ViewOp>(
           op, op->getResult(0).getType(), view.getResult());
       return mlir::success();
diff --git a/lib/Dialect/TritonGPU/Transforms/Utility.cpp b/lib/Dialect/TritonGPU/Transforms/Utility.cpp
index f315fe5ad..8101ed45a 100644
--- a/lib/Dialect/TritonGPU/Transforms/Utility.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/Utility.cpp
@@ -342,8 +342,15 @@ bool canFoldIntoConversion(Operation *op, Attribute targetEncoding) {
     }
     return true;
   }
+  if (auto view = dyn_cast<triton::ViewOp>(op)) {
+    auto viewDstType = view.getType().cast<RankedTensorType>();
+    RankedTensorType newDstType = RankedTensorType::get(
+        viewDstType.getShape(), viewDstType.getElementType(), targetEncoding);
+    return !triton::gpu::isExpensiveView(view.getOperand().getType(),
+                                         newDstType);
+  }
   return isa<triton::gpu::ConvertLayoutOp, arith::ConstantOp,
-             triton::MakeRangeOp, triton::SplatOp, triton::ViewOp>(op);
+             triton::MakeRangeOp, triton::SplatOp>(op);
 }
 
 //
diff --git a/test/TritonGPU/canonicalize.mlir b/test/TritonGPU/canonicalize.mlir
new file mode 100644
index 000000000..cb2bba970
--- /dev/null
+++ b/test/TritonGPU/canonicalize.mlir
@@ -0,0 +1,34 @@
+// RUN: triton-opt %s -split-input-file -canonicalize | FileCheck %s
+
+
+// CHECK-LABEL: @test_canonicalize_convert_view
+// CHECK-SAME: (%[[ARG:.+]]: tensor<64x64xf32
+// CHECK-NOT: triton_gpu.convert_layout
+// CHECK: %[[V:.+]] = tt.view %[[ARG]]
+// CHECK: tt.return %[[V]]
+#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
+#blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
+#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [8, 1], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
+tt.func @test_canonicalize_convert_view(%arg0: tensor<64x64xf32, #blocked0>) -> tensor<4096xf32, #blocked1> {
+  %c = triton_gpu.convert_layout %arg0 : (tensor<64x64xf32, #blocked0>) -> tensor<64x64xf32, #blocked2>
+  %r = tt.view %c : (tensor<64x64xf32, #blocked2>) -> tensor<4096xf32, #blocked1>
+  tt.return %r : tensor<4096xf32, #blocked1>
+}
+
+// -----
+
+// Test that the convert doesn't get combined with the view if the resulting
+// operation is an expensive view that would require moving data across threads.
+// CHECK-LABEL: @test_canonicalize_convert_expensive_view
+// CHECK-SAME: (%[[ARG:.+]]: tensor<256x16xf32
+// CHECK: %[[C:.+]] = triton_gpu.convert_layout %[[ARG]]
+// CHECK: %[[V:.+]] = tt.view %[[C]]
+// CHECK: tt.return %[[V]]
+#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
+#blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
+#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [8, 1], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
+tt.func @test_canonicalize_convert_expensive_view(%arg0: tensor<256x16xf32, #blocked0>) -> tensor<4096xf32, #blocked1> {
+  %c = triton_gpu.convert_layout %arg0 : (tensor<256x16xf32, #blocked0>) -> tensor<256x16xf32, #blocked2>
+  %r = tt.view %c : (tensor<256x16xf32, #blocked2>) -> tensor<4096xf32, #blocked1>
+  tt.return %r : tensor<4096xf32, #blocked1>
+}
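
Note on the two test cases: the expensive-view check reduces to comparing per-thread element counts of the source and destination tensor types via getTotalElemsPerThread. The standalone C++ sketch below is illustrative only and not part of the patch; elemsPerThread is a hypothetical helper that reimplements the blocked-layout count (ignoring CGA/CTA splitting, which is trivial, [1, 1], in these tests). It shows why the 64x64 view folds (16 elements per thread on both sides) while the 256x16 view does not (64 vs. 16), meaning data would have to move across threads.

// Standalone sketch of the comparison isExpensiveView performs, specialized
// to blocked layouts. Not Triton code: the real implementation calls
// triton::gpu::getTotalElemsPerThread on the two tensor types.
#include <cstdint>
#include <iostream>
#include <vector>

// Per dimension, one CTA tile covers sizePerThread * threadsPerWarp *
// warpsPerCTA elements; the tensor wraps around that tile, and each thread
// owns sizePerThread elements per tile replica.
static int64_t elemsPerThread(const std::vector<int64_t> &shape,
                              const std::vector<int64_t> &sizePerThread,
                              const std::vector<int64_t> &threadsPerWarp,
                              const std::vector<int64_t> &warpsPerCTA) {
  int64_t total = 1;
  for (size_t d = 0; d < shape.size(); ++d) {
    int64_t tile = sizePerThread[d] * threadsPerWarp[d] * warpsPerCTA[d];
    int64_t replicas = (shape[d] + tile - 1) / tile; // ceil(shape / tile)
    total *= replicas * sizePerThread[d];
  }
  return total;
}

int main() {
  // Destination of both views, #blocked1: 4096 elements over 256 threads.
  int64_t dst = elemsPerThread({4096}, {1}, {32}, {8});             // 16
  // First test: 64x64 source under #blocked0 -> 16 per thread (cheap).
  int64_t src1 = elemsPerThread({64, 64}, {1, 8}, {4, 8}, {8, 1});  // 16
  // Second test: 256x16 source under #blocked0 -> 64 per thread (expensive).
  int64_t src2 = elemsPerThread({256, 16}, {1, 8}, {4, 8}, {8, 1}); // 64
  std::cout << (src1 != dst) << " " << (src2 != dst) << "\n";       // 0 1
}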