mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[MFMA] Change kWidth parameter semantics
This PR changes kWidth semantics "from elements per instruction" to "elements per thread per instruction" along k axis.
This commit is contained in:
committed by
Lixun Zhang
parent
10795d8fd3
commit
d80cd2d374
@@ -2398,11 +2398,11 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war
|
||||
%17 = tt.addptr %16, %8 : tensor<32x32x!tt.ptr<f16>, #blocked>, tensor<32x32xi32, #blocked>
|
||||
%18 = tt.load %9 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x32xf16, #blocked>
|
||||
%19 = triton_gpu.convert_layout %18 : (tensor<32x32xf16, #blocked>) -> tensor<32x32xf16, #shared1>
|
||||
%20 = triton_gpu.convert_layout %19 : (tensor<32x32xf16, #shared1>) -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA = [4, 1], isTransposed = false}>, kWidth=8}>>
|
||||
%20 = triton_gpu.convert_layout %19 : (tensor<32x32xf16, #shared1>) -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA = [4, 1], isTransposed = false}>, kWidth=4}>>
|
||||
%21 = tt.load %13 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x32xf16, #blocked>
|
||||
%22 = triton_gpu.convert_layout %21 : (tensor<32x32xf16, #blocked>) -> tensor<32x32xf16, #shared2>
|
||||
%23 = triton_gpu.convert_layout %22 : (tensor<32x32xf16, #shared2>) -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA = [4, 1], isTransposed = false}>, kWidth=8}>>
|
||||
%24 = tt.dot %20, %23, %cst {allowTF32 = false} : tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA = [4, 1], isTransposed = false}>, kWidth=8}>> * tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #triton_gpu.mfma<{nonKDim = 32, kDim = 8, warpsPerCTA = [4, 1], isTransposed = false}>, kWidth=8}>> -> tensor<32x32xf32, #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA = [4, 1], isTransposed = false}>>
|
||||
%23 = triton_gpu.convert_layout %22 : (tensor<32x32xf16, #shared2>) -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA = [4, 1], isTransposed = false}>, kWidth=4}>>
|
||||
%24 = tt.dot %20, %23, %cst {allowTF32 = false} : tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA = [4, 1], isTransposed = false}>, kWidth=4}>> * tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #triton_gpu.mfma<{nonKDim = 32, kDim = 8, warpsPerCTA = [4, 1], isTransposed = false}>, kWidth=4}>> -> tensor<32x32xf32, #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA = [4, 1], isTransposed = false}>>
|
||||
%25 = triton_gpu.convert_layout %24 : (tensor<32x32xf32, #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA = [4, 1], isTransposed = false}>>) -> tensor<32x32xf32, #blocked>
|
||||
%26 = arith.truncf %25 : tensor<32x32xf32, #blocked> to tensor<32x32xf16, #blocked>
|
||||
tt.store %17, %26 {cache = 1 : i32, evict = 1 : i32} : tensor<32x32xf16, #blocked>
|
||||
|
||||
Reference in New Issue
Block a user