mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[Tests] Update test_core_amd.py from test_core.py.
1. Added some tests/configs that are passing at the moment. 2. Switched disabled parts from commented-out code to `if torch.version.hip` checks.
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
# flake8: noqa: F821,F841
|
||||
import itertools
|
||||
import os
|
||||
import re
|
||||
from typing import Optional, Union
|
||||
|
||||
@@ -1026,6 +1027,58 @@ def test_reduce2d(op, dtype_str, shape, axis, device='cuda'):
|
||||
else:
|
||||
np.testing.assert_equal(z_ref, z_tri)
|
||||
|
||||
# ---------------
|
||||
# test permute
|
||||
# ---------------
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype_str, shape, perm",
                         # TODO: bfloat16
                         list(itertools.product(['float16', 'float32'],
                                                [(64, 64), (128, 128)],
                                                [(1, 0)])))
def test_permute(dtype_str, shape, perm, device='cuda'):
    """Copy a 2D tile through a Triton kernel with swapped strides and
    compare the result against ``numpy.transpose``.

    The kernel is launched twice: once with the output strides swapped and
    once with the input strides swapped. On non-HIP builds the generated PTX
    of both launches is also checked for vectorized loads and stores.
    """
    check_type_supported(dtype_str)  # bfloat16 on cc < 80 will not be tested
    on_hip = torch.version.hip is not None
    if on_hip and (dtype_str, shape) == ('float32', (128, 128)):
        pytest.skip("Not supported: memory out of resource.")

    # triton kernel
    @triton.jit
    def kernel(X, stride_xm, stride_xn,
               Z, stride_zm, stride_zn,
               BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
        rows = tl.arange(0, BLOCK_M)[:, None]
        cols = tl.arange(0, BLOCK_N)[None, :]
        src = X + rows * stride_xm + cols * stride_xn
        dst = Z + rows * stride_zm + cols * stride_zn
        tl.store(dst, tl.load(src))

    # input
    x = numpy_random(shape, dtype_str=dtype_str)

    # triton result
    z_tri = to_triton(np.empty_like(x), device=device, dst_type=dtype_str)
    z_tri_contiguous = to_triton(np.empty_like(x), device=device, dst_type=dtype_str)
    x_tri = to_triton(x, device=device, dst_type=dtype_str)
    block = dict(BLOCK_M=shape[0], BLOCK_N=shape[1])
    # first launch: the output strides are passed swapped
    pgm = kernel[(1, 1)](x_tri, x_tri.stride(0), x_tri.stride(1),
                         z_tri, z_tri.stride(1), z_tri.stride(0),
                         **block)
    # second launch: the input strides are passed swapped
    pgm_contiguous = kernel[(1, 1)](x_tri, x_tri.stride(1), x_tri.stride(0),
                                    z_tri_contiguous,
                                    z_tri_contiguous.stride(0),
                                    z_tri_contiguous.stride(1),
                                    **block)

    # numpy result
    z_ref = x.transpose(*perm)

    # compare
    triton.testing.assert_almost_equal(z_tri, z_ref)
    triton.testing.assert_almost_equal(z_tri_contiguous, z_ref)

    if not on_hip:
        # parse ptx to make sure ld/st are vectorized
        for program in (pgm, pgm_contiguous):
            ptx = program.asm['ptx']
            assert 'ld.global.v4' in ptx
            assert 'st.global.v4' in ptx
|
||||
|
||||
# ---------------
|
||||
# test dot
|
||||
@@ -1034,7 +1087,7 @@ def test_reduce2d(op, dtype_str, shape, axis, device='cuda'):
|
||||
|
||||
@pytest.mark.parametrize("M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, dtype",
|
||||
[(*shape, 4, False, False, epilogue, allow_tf32, dtype)
|
||||
for shape in [(64, 64, 64)]
|
||||
for shape in [(64, 64, 64), (16, 16, 16)]
|
||||
for epilogue in ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols', 'softmax', 'chain-dot']
|
||||
for allow_tf32 in [True, False]
|
||||
for dtype in ['float16', 'float32']
|
||||
@@ -1045,11 +1098,10 @@ def test_reduce2d(op, dtype_str, shape, axis, device='cuda'):
|
||||
[128, 16, 32, 4],
|
||||
[32, 128, 64, 4],
|
||||
[128, 128, 64, 4],
|
||||
#[64, 128, 128, 4],
|
||||
[64, 128, 128, 4],
|
||||
[32, 128, 64, 2],
|
||||
[128, 128, 64, 2],
|
||||
#[64, 128, 128, 2]
|
||||
]
|
||||
[64, 128, 128, 2]]
|
||||
for allow_tf32 in [True]
|
||||
for col_a in [True, False]
|
||||
for col_b in [True, False]
|
||||
@@ -1063,6 +1115,13 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, dtype, devi
|
||||
pytest.skip("Only test int8 on devices with sm >= 80")
|
||||
elif dtype == 'float32' and allow_tf32:
|
||||
pytest.skip("Only test tf32 on devices with sm >= 80")
|
||||
if capability[0] == 7:
|
||||
if (M, N, K, num_warps) == (128, 256, 32, 8):
|
||||
pytest.skip("shared memory out of resource")
|
||||
|
||||
if torch.version.hip is not None:
|
||||
if (M, N, K) == (64, 128, 128):
|
||||
pytest.skip("Not supported: memory out of resource.")
|
||||
|
||||
torch.backends.cuda.matmul.allow_tf32 = allow_tf32
|
||||
|
||||
@@ -1175,16 +1234,19 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, dtype, devi
|
||||
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01, atol=1e-3)
|
||||
else:
|
||||
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
|
||||
# make sure ld/st are vectorized
|
||||
#ptx = pgm.asm['ptx']
|
||||
#assert 'ld.global.v4' in ptx
|
||||
#assert 'st.global.v4' in ptx
|
||||
#if dtype == 'float32' and allow_tf32:
|
||||
# assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' in ptx
|
||||
#elif dtype == 'float32' and allow_tf32:
|
||||
# assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' not in ptx
|
||||
#elif dtype == 'int8':
|
||||
# assert 'mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32' in ptx
|
||||
if torch.version.hip is None:
|
||||
# make sure ld/st are vectorized
|
||||
ptx = pgm.asm['ptx']
|
||||
if K > 16 or N > 16 or M > 16:
|
||||
# XXX: skip small sizes because they are not vectorized
|
||||
assert 'ld.global.v4' in ptx
|
||||
assert 'st.global.v4' in ptx
|
||||
if dtype == 'float32' and allow_tf32:
|
||||
assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' in ptx
|
||||
elif dtype == 'float32' and allow_tf32:
|
||||
assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' not in ptx
|
||||
elif dtype == 'int8':
|
||||
assert 'mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32' in ptx
|
||||
|
||||
|
||||
# ---------------
|
||||
@@ -1243,6 +1305,57 @@ def test_masked_load(dtype_str, size, size_diff, device='cuda'):
|
||||
reference_out = torch.cat((input, torch.ones((size_diff,), dtype=dtype, device=device)))
|
||||
triton.testing.allclose(output, reference_out)
|
||||
|
||||
# Test masked loads with an intermediate copy to shared memory.
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
def test_masked_load_shared_memory(dtype, device='cuda'):
    """Multiply two random matrices with masked loads feeding ``tl.dot`` and
    compare the result against ``torch.matmul``.

    Args:
        dtype: torch dtype of both inputs and of the output.
        device: device on which the tensors are allocated.
    """
    check_type_supported(dtype)  # bfloat16 on cc < 80 will not be tested

    M = 32
    N = 32
    K = 16

    in1 = torch.rand((M, K), dtype=dtype, device=device)
    in2 = torch.rand((K, N), dtype=dtype, device=device)
    out = torch.zeros((M, N), dtype=dtype, device=device)

    @triton.jit
    def _kernel(in1_ptr, in2_ptr, output_ptr,
                in_stride, in2_stride, out_stride,
                in_numel, in2_numel, out_numel,
                M: tl.constexpr, N: tl.constexpr, K: tl.constexpr):

        M_offsets = tl.arange(0, M)
        N_offsets = tl.arange(0, N)
        K_offsets = tl.arange(0, K)

        in_offsets = M_offsets[:, None] * in_stride + K_offsets[None, :]
        in2_offsets = K_offsets[:, None] * in2_stride + N_offsets[None, :]

        # Load inputs.
        x = tl.load(in1_ptr + in_offsets, mask=in_offsets < in_numel)
        w = tl.load(in2_ptr + in2_offsets, mask=in2_offsets < in2_numel)

        # Without a dot product the memory doesn't get promoted to shared.
        o = tl.dot(x, w)

        # Store output. The mask must be bounded by the *output* element
        # count: bounding it by in2_numel (K * N) would silently drop every
        # store past the first K * N elements of the M * N output.
        output_offsets = M_offsets[:, None] * out_stride + N_offsets[None, :]
        tl.store(output_ptr + output_offsets, o, mask=output_offsets < out_numel)

    _kernel[(1,)](in1, in2, out,
                  in1.stride()[0],
                  in2.stride()[0],
                  out.stride()[0],
                  in1.numel(),
                  in2.numel(),
                  out.numel(),
                  M=M, N=N, K=K)

    reference_out = torch.matmul(in1, in2)
    # Assert on the comparison; discarding its result means the test can
    # never fail on a numerical mismatch.
    assert triton.testing.allclose(out, reference_out)
|
||||
|
||||
|
||||
# ---------------
|
||||
# test store
|
||||
@@ -1298,18 +1411,18 @@ def test_noop(device='cuda'):
|
||||
kernel[(1, )](x)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", ['cuda', 'cpu'])
|
||||
@pytest.mark.parametrize("device", ['cuda', 'cpu', 'cpu_pinned'])
|
||||
def test_pointer_arguments(device):
|
||||
@triton.jit
|
||||
def kernel(x):
|
||||
pass
|
||||
x = torch.empty(1024, device=device)
|
||||
result = True
|
||||
try:
|
||||
kernel[(1,)](x)
|
||||
except ValueError:
|
||||
result = True if device == 'cpu' else False
|
||||
assert result
|
||||
pin_memory = 'pinned' in device
|
||||
x = torch.empty(1024, device=device.split('_')[0], pin_memory=pin_memory)
|
||||
if device == "cpu":
|
||||
with pytest.raises(ValueError):
|
||||
kernel[(1,)](x)
|
||||
else:
|
||||
kernel[(1, )](x)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("value, value_type", [
|
||||
@@ -1801,4 +1914,3 @@ def test_load_scalar_with_mask():
|
||||
Out = torch.empty_like(Index, device='cuda')
|
||||
kernel[(1,)](Input, Index, Out, Index.numel())
|
||||
assert Out.data[0] == 0
|
||||
|
||||
|
||||
Reference in New Issue
Block a user