mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[BACKEND] Fix unsupported view op created during optimizations (#2510)
When propagating layouts we were generating a view op with a mismatched total number of elements per thread. Lowering such an op would require exchanging data across threads. This change prevents the optimizer from generating such cases. This may require further optimizations in the future.
This commit is contained in:
34
test/TritonGPU/canonicalize.mlir
Normal file
34
test/TritonGPU/canonicalize.mlir
Normal file
@@ -0,0 +1,34 @@
// RUN: triton-opt %s -split-input-file -canonicalize | FileCheck %s

// Check that a convert_layout feeding a view is folded away when the view is
// cheap (same number of elements per thread), so only the view remains.
// CHECK-LABEL: @test_canonicalize_convert_view
// CHECK-SAME: (%[[ARG:.+]]: tensor<64x64xf32
// CHECK-NOT: triton_gpu.convert_layout
// CHECK: %[[V:.+]] = tt.view %[[ARG]]
// CHECK: tt.return %[[V]]
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [8, 1], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
tt.func @test_canonicalize_convert_view(%arg0: tensor<64x64xf32, #blocked0>) -> tensor<4096xf32, #blocked1> {
    %c = triton_gpu.convert_layout %arg0 : (tensor<64x64xf32, #blocked0>) -> tensor<64x64xf32, #blocked2>
    %r = tt.view %c : (tensor<64x64xf32, #blocked2>) -> tensor<4096xf32, #blocked1>
    tt.return %r : tensor<4096xf32, #blocked1>
}
// -----

// Test that the convert doesn't get combined with the view if the resulting
// operation is an expensive view which would require moving data across threads.
// CHECK-LABEL: @test_canonicalize_convert_expensive_view
// CHECK-SAME: (%[[ARG:.+]]: tensor<256x16xf32
// CHECK: %[[C:.+]] = triton_gpu.convert_layout %[[ARG]]
// CHECK: %[[V:.+]] = tt.view %[[C]]
// CHECK: tt.return %[[V]]
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [8, 1], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
tt.func @test_canonicalize_convert_expensive_view(%arg0: tensor<256x16xf32, #blocked0>) -> tensor<4096xf32, #blocked1> {
    %c = triton_gpu.convert_layout %arg0 : (tensor<256x16xf32, #blocked0>) -> tensor<256x16xf32, #blocked2>
    %r = tt.view %c : (tensor<256x16xf32, #blocked2>) -> tensor<4096xf32, #blocked1>
    tt.return %r : tensor<4096xf32, #blocked1>
}
Reference in New Issue
Block a user