Mirror of https://github.com/ROCm/ROCm.git
Fix the issue when CTA coverage is larger than the tile
@@ -271,8 +271,9 @@ unsigned ScanLoweringHelper::getAxisBlockStride() {
   for (unsigned dim : order) {
     if (dim == getAxis())
       return stride;
-    stride *= type.getShape()[dim] /
-              (sizePerThreads[dim] * threadsPerWarp[dim] * warpsPerCTA[dim]);
+    stride *= std::max<unsigned int>(
+        1, type.getShape()[dim] /
+               (sizePerThreads[dim] * threadsPerWarp[dim] * warpsPerCTA[dim]));
   }
   llvm_unreachable("Axis not found in order");
 }
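A quick worked example of why the clamp matters (illustrative numbers, not taken from the commit): when the CTA covers more elements along a dimension than the tile actually has, the integer division shape / (sizePerThread * threadsPerWarp * warpsPerCTA) truncates to 0 and zeroes out the stride accumulated for every later dimension. A minimal Python sketch of the loop above:

# Hypothetical sizes: dimension 0 is covered 128-wide by the CTA,
# but the tile is only 16 wide, so the unclamped division yields 0.
shape = [16, 32]
size_per_thread = [2, 2]
threads_per_warp = [16, 4]
warps_per_cta = [4, 1]
order = [0, 1]

def axis_block_stride(axis, clamp):
    stride = 1
    for dim in order:
        if dim == axis:
            return stride
        coverage = size_per_thread[dim] * threads_per_warp[dim] * warps_per_cta[dim]
        blocks = shape[dim] // coverage  # 16 // 128 == 0 for dim 0
        stride *= max(1, blocks) if clamp else blocks
    raise AssertionError("Axis not found in order")

print(axis_block_stride(1, clamp=False))  # 0 -- the pre-fix stride collapses
print(axis_block_stride(1, clamp=True))   # 1 -- the clamped version stays valid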
@@ -390,7 +391,8 @@ bool isMfmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
          dotOperandLayout.getOpIdx() == 0 &&
          dotOperandLayout.getKWidth() == 8 &&
          dotOperandLayout.getParent() == mfmaLayout &&
-         mfmaLayout.getIsTransposed() && (srcTy.getElementType().isF16() || srcTy.getElementType().isBF16());
+         mfmaLayout.getIsTransposed() &&
+         (srcTy.getElementType().isF16() || srcTy.getElementType().isBF16());
 }
 #endif

@@ -235,11 +235,11 @@ static Value shflUp_amd(Location loc, ConversionPatternRewriter &rewriter,
   GCNBuilder builder;
   Value threadId =
       rewriter
-          .create<UnrealizedConversionCastOp>(
-              loc, TypeRange{i32_ty},
-              ValueRange{rewriter.create<::mlir::gpu::ThreadIdOp>(
+          .create<UnrealizedConversionCastOp>(
+              loc, TypeRange{i32_ty},
+              ValueRange{rewriter.create<::mlir::gpu::ThreadIdOp>(
               loc, rewriter.getIndexType(), ::mlir::gpu::Dimension::x)})
-          .getResult(0);
+          .getResult(0);
   Value warpSize = i32_val(64);
   Value laneId = urem(threadId, warpSize);
   Value mask = icmp_slt(laneId, i32_val(i));
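The unchanged lines around this hunk spell out the shuffle-up contract on a 64-lane AMD wavefront: laneId = threadId % 64, and the icmp_slt mask flags lanes with laneId < i, which have no source lane i positions below and keep their own value. A NumPy sketch of that intended semantics (an illustration only, not the emitted GCN):

import numpy as np

def shfl_up(wave, i):
    # Lane l reads from lane l - i; lanes with lane_id < i (the mask
    # computed above) have no source and keep their own value.
    lane_id = np.arange(64)
    out = wave.copy()
    valid = lane_id >= i
    out[valid] = wave[lane_id[valid] - i]
    return out

wave = np.arange(64)
print(shfl_up(wave, 4)[:8])  # [0 1 2 3 0 1 2 3]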
@@ -1188,6 +1188,62 @@ def test_reduce2d(op, dtype_str, shape, axis, device='cuda'):
     else:
         np.testing.assert_equal(z_ref, z_tri)
 
+
+# ---------------
+# test scan
+# ---------------
+
+
+scan2d_shapes = [(8, 32), (16, 32), (32, 16), (2, 1024), (1024, 2), (32, 32), (1, 1024)]
+
+scan_configs = [
+    (op, type, shape, axis, num_warps)
+    for num_warps in [4, 16]
+    for type in ['int32', 'float32']
+    for axis in [1, 0]
+    for shape in scan2d_shapes
+    for op in ['cumsum', 'cumprod']
+]
+
+
+@pytest.mark.parametrize("op, dtype_str, shape, axis, num_warps", scan_configs)
+def test_scan2d(op, dtype_str, shape, axis, num_warps, device):
+    check_type_supported(dtype_str, device)
+
+    # triton kernel
+    @triton.jit
+    def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, AXIS: tl.constexpr):
+        range_m = tl.arange(0, BLOCK_M)
+        range_n = tl.arange(0, BLOCK_N)
+        x = tl.load(X + range_m[:, None] * BLOCK_N + range_n[None, :])
+        z = GENERATE_TEST_HERE
+        tl.store(Z + range_m[:, None] * BLOCK_N + range_n[None, :], z)
+
+    kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.{op}(x, axis={axis})'})
+    # input
+    rs = RandomState(17)
+    x = numpy_random(shape, dtype_str=dtype_str, rs=rs)
+    z = np.empty_like(x)
+    x_tri = to_triton(x, device=device)
+    numpy_op = {'cumsum': np.cumsum, 'cumprod': np.cumprod}[op]
+    z_dtype_str = dtype_str
+    z_ref = numpy_op(x, axis=axis).astype(getattr(np, z_dtype_str))
+    # triton result
+    z_tri = to_triton(z, device=device)
+    kernel[(1,)](x_tri, z_tri, BLOCK_M=shape[0], BLOCK_N=shape[1], AXIS=axis, num_warps=num_warps)
+    z_tri = to_numpy(z_tri)
+    # compare
+    if dtype_str == 'float32':
+        if op == 'cumprod':
+            np.testing.assert_allclose(z_ref, z_tri, rtol=0.01, atol=1e-3)
+        else:
+            np.testing.assert_allclose(z_ref, z_tri, rtol=0.01)
+    else:
+        np.testing.assert_equal(z_ref, z_tri)
+
+
 # ---------------
 # test permute
 # ---------------
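As a concrete instance of the templating in the hunk above: with op='cumsum' and axis=1, patch_kernel rewrites the GENERATE_TEST_HERE placeholder so the JIT-compiled kernel body is effectively:

import triton
import triton.language as tl

@triton.jit
def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, AXIS: tl.constexpr):
    range_m = tl.arange(0, BLOCK_M)
    range_n = tl.arange(0, BLOCK_N)
    x = tl.load(X + range_m[:, None] * BLOCK_N + range_n[None, :])
    z = tl.cumsum(x, axis=1)  # substituted for GENERATE_TEST_HERE
    tl.store(Z + range_m[:, None] * BLOCK_N + range_n[None, :], z)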
@@ -2555,6 +2611,75 @@ def test_reduce_layouts(M, N, src_layout, axis, device='cuda'):
     np.testing.assert_allclose(z_ref, z_tri.cpu().numpy(), rtol=0.01, atol=1e-3)
 
 
+scan_layouts = [
+    BlockedLayout([1, 4], [4, 16], [4, 1], [0, 1]),
+    BlockedLayout([1, 4], [8, 8], [4, 1], [0, 1]),
+    BlockedLayout([4, 1], [4, 16], [1, 4], [0, 1]),
+    BlockedLayout([2, 2], [4, 16], [2, 2], [0, 1]),
+    BlockedLayout([2, 2], [8, 8], [2, 2], [0, 1]),
+
+    BlockedLayout([1, 4], [4, 16], [4, 1], [1, 0]),
+    BlockedLayout([1, 4], [8, 8], [4, 1], [1, 0]),
+    BlockedLayout([4, 1], [4, 16], [1, 4], [1, 0]),
+    BlockedLayout([2, 2], [4, 16], [2, 2], [1, 0]),
+    BlockedLayout([2, 2], [8, 8], [2, 2], [1, 0]),
+]
+
+
+@pytest.mark.parametrize("M, N", [[32, 32], [32, 64], [64, 32]])
+@pytest.mark.parametrize("src_layout", scan_layouts)
+@pytest.mark.parametrize("axis", [0, 1])
+def test_scan_layouts(M, N, src_layout, axis, device):
+    ir = f"""
+    #blocked = {src_layout}
+    module attributes {{"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32}} {{
+    tt.func public @kernel_0d1d(%arg0: !tt.ptr<i32> {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr<i32> {{tt.divisibility = 16 : i32}}) {{
+      %cst = arith.constant dense<{N}> : tensor<{M}x1xi32, #blocked>
+      %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #blocked}}>>
+      %1 = tt.expand_dims %0 {{axis = 1 : i32}} : (tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #blocked}}>>) -> tensor<{M}x1xi32, #blocked>
+      %2 = arith.muli %1, %cst : tensor<{M}x1xi32, #blocked>
+      %3 = tt.splat %arg0 : (!tt.ptr<i32>) -> tensor<{M}x1x!tt.ptr<i32>, #blocked>
+      %4 = tt.addptr %3, %2 : tensor<{M}x1x!tt.ptr<i32>, #blocked>, tensor<{M}x1xi32, #blocked>
+      %5 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, parent = #blocked}}>>
+      %6 = tt.expand_dims %5 {{axis = 0 : i32}} : (tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, parent = #blocked}}>>) -> tensor<1x{N}xi32, #blocked>
+      %7 = tt.broadcast %4 : (tensor<{M}x1x!tt.ptr<i32>, #blocked>) -> tensor<{M}x{N}x!tt.ptr<i32>, #blocked>
+      %8 = tt.broadcast %6 : (tensor<1x{N}xi32, #blocked>) -> tensor<{M}x{N}xi32, #blocked>
+      %9 = tt.addptr %7, %8 : tensor<{M}x{N}x!tt.ptr<i32>, #blocked>, tensor<{M}x{N}xi32, #blocked>
+      %10 = tt.load %9 {{cache = 1 : i32, evict = 1 : i32, isVolatile = false}} : tensor<{M}x{N}xi32, #blocked>
+      %11 = "tt.scan"(%10) <{{axis = {axis} : i32}}> ({{
+      ^bb0(%arg2: i32, %arg3: i32):
+        %16 = arith.addi %arg2, %arg3 : i32
+        tt.scan.return %16 : i32
+      }}) : (tensor<{M}x{N}xi32, #blocked>) -> tensor<{M}x{N}xi32, #blocked>
+      %12 = tt.splat %arg1 : (!tt.ptr<i32>) -> tensor<{M}x1x!tt.ptr<i32>, #blocked>
+      %13 = tt.addptr %12, %2 : tensor<{M}x1x!tt.ptr<i32>, #blocked>, tensor<{M}x1xi32, #blocked>
+      %14 = tt.broadcast %13 : (tensor<{M}x1x!tt.ptr<i32>, #blocked>) -> tensor<{M}x{N}x!tt.ptr<i32>, #blocked>
+      %15 = tt.addptr %14, %8 : tensor<{M}x{N}x!tt.ptr<i32>, #blocked>, tensor<{M}x{N}xi32, #blocked>
+      tt.store %15, %11 {{cache = 1 : i32, evict = 1 : i32}} : tensor<{M}x{N}xi32, #blocked>
+      tt.return
+    }}
+    }}
+    """
+
+    import tempfile
+    with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f:
+        f.write(ir)
+        f.flush()
+        kernel = triton.compile(f.name)
+    rs = RandomState(17)
+    x = rs.randint(-100, 100, (M, N)).astype('int32')
+
+    z = np.zeros((M, N)).astype('int32')
+    x_tri = torch.tensor(x, device=device)
+    z_tri = torch.tensor(z, device=device)
+
+    kernel[(1, 1, 1)](x_tri, z_tri)
+
+    z_ref = np.cumsum(x, axis=axis)
+
+    np.testing.assert_equal(z_ref, z_tri.cpu().numpy())
+
+
 @pytest.mark.parametrize("shape", [(64, 64)])
 @pytest.mark.parametrize("dtype", ['float16'])
 @pytest.mark.parametrize("src_layout", [MfmaLayout(non_k_dim=32, warps_per_cta=[2, 1], isTransposed=False), MfmaLayout(non_k_dim=32, warps_per_cta=[4, 1], isTransposed=True)])
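The scan_layouts entries added above are (sizePerThread, threadsPerWarp, warpsPerCTA, order), and several of them make the CTA cover more elements along one dimension than the 32x32 tile provides, which is exactly the situation the getAxisBlockStride clamp in the first hunk fixes. A small sketch computing that per-dimension coverage (reading the layout parameters this way is an assumption based on the BlockedLayout test helper):

# Coverage per dimension = sizePerThread * threadsPerWarp * warpsPerCTA;
# the order argument does not affect the product.
layouts = [
    ([1, 4], [4, 16], [4, 1]),
    ([1, 4], [8, 8], [4, 1]),
    ([4, 1], [4, 16], [1, 4]),
    ([2, 2], [4, 16], [2, 2]),
    ([2, 2], [8, 8], [2, 2]),
]
tile = (32, 32)  # smallest M, N in the test matrix
for spt, tpw, wpc in layouts:
    coverage = [spt[d] * tpw[d] * wpc[d] for d in range(2)]
    over = [c > t for c, t in zip(coverage, tile)]
    print(f"coverage={coverage}, exceeds tile per dim: {over}")
# Three of the five base layouts cover 64 elements along dim 1 of a
# 32-wide tile, so shape // coverage truncates to 0 without the clamp.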