[Tests] Update test_core_amd.py from test_core.py.

1. Added some tests/configs that are passing at the moment.
2. Changed the way unsupported parts are switched off, from commenting them out to guarding with `if torch.version.hip` checks (see the sketch below).
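
For reference, a minimal sketch of that guard pattern, assuming a pytest-style test; the example `(dtype_str, shape)` condition is illustrative, while the real skip conditions are the ones visible in the hunks below:

```python
import pytest
import torch


def test_example(dtype_str='float32', shape=(128, 128)):
    # torch.version.hip is None on CUDA builds of PyTorch and holds the
    # ROCm version string on HIP builds, so it doubles as a backend check.
    if torch.version.hip is not None:
        # Illustrative condition only: the actual skips are per-test,
        # e.g. specific (dtype, shape) or (M, N, K) combinations.
        if (dtype_str, shape) == ('float32', (128, 128)):
            pytest.skip("Not supported: memory out of resource.")
    # ... run the Triton kernel and compare against the reference here ...
    if torch.version.hip is None:
        # CUDA-only checks (e.g. PTX inspection) stay behind the same guard.
        pass
```

Compared to commenting code out, the guard keeps the skipped configurations visible to pytest and makes it easy to re-enable them once the ROCm backend supports them.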
Daniil Fukalov
2023-02-23 23:09:58 +01:00
parent af1ed3a01f
commit 7133d23512


@@ -1,5 +1,6 @@
 # flake8: noqa: F821,F841
 import itertools
+import os
 import re
 from typing import Optional, Union
@@ -1026,6 +1027,58 @@ def test_reduce2d(op, dtype_str, shape, axis, device='cuda'):
     else:
         np.testing.assert_equal(z_ref, z_tri)
+
+
+# ---------------
+# test permute
+# ---------------
+
+
+@pytest.mark.parametrize("dtype_str, shape, perm",
+                         [(dtype, shape, perm)
+                          # TODO: bfloat16
+                          for dtype in ['float16', 'float32']
+                          for shape in [(64, 64), (128, 128)]
+                          for perm in [(1, 0)]])
+def test_permute(dtype_str, shape, perm, device='cuda'):
+    check_type_supported(dtype_str)  # bfloat16 on cc < 80 will not be tested
+    if torch.version.hip is not None:
+        if (dtype_str, shape) == ('float32', (128, 128)):
+            pytest.skip("Not supported: memory out of resource.")
+
+    # triton kernel
+    @triton.jit
+    def kernel(X, stride_xm, stride_xn,
+               Z, stride_zm, stride_zn,
+               BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
+        off_m = tl.arange(0, BLOCK_M)
+        off_n = tl.arange(0, BLOCK_N)
+        Xs = X + off_m[:, None] * stride_xm + off_n[None, :] * stride_xn
+        Zs = Z + off_m[:, None] * stride_zm + off_n[None, :] * stride_zn
+        tl.store(Zs, tl.load(Xs))
+    # input
+    x = numpy_random(shape, dtype_str=dtype_str)
+    # triton result
+    z_tri = to_triton(np.empty_like(x), device=device, dst_type=dtype_str)
+    z_tri_contiguous = to_triton(np.empty_like(x), device=device, dst_type=dtype_str)
+    x_tri = to_triton(x, device=device, dst_type=dtype_str)
+    pgm = kernel[(1, 1)](x_tri, x_tri.stride(0), x_tri.stride(1),
+                         z_tri, z_tri.stride(1), z_tri.stride(0),
+                         BLOCK_M=shape[0], BLOCK_N=shape[1])
+    pgm_contiguous = kernel[(1, 1)](x_tri, x_tri.stride(1), x_tri.stride(0),
+                                    z_tri_contiguous, z_tri_contiguous.stride(0), z_tri_contiguous.stride(1),
+                                    BLOCK_M=shape[0], BLOCK_N=shape[1])
+    # numpy result
+    z_ref = x.transpose(*perm)
+    # compare
+    triton.testing.assert_almost_equal(z_tri, z_ref)
+    triton.testing.assert_almost_equal(z_tri_contiguous, z_ref)
+    if torch.version.hip is None:
+        # parse ptx to make sure ld/st are vectorized
+        ptx = pgm.asm['ptx']
+        assert 'ld.global.v4' in ptx
+        assert 'st.global.v4' in ptx
+        ptx = pgm_contiguous.asm['ptx']
+        assert 'ld.global.v4' in ptx
+        assert 'st.global.v4' in ptx
+
+
 # ---------------
 # test dot
@@ -1034,7 +1087,7 @@ def test_reduce2d(op, dtype_str, shape, axis, device='cuda'):
 @pytest.mark.parametrize("M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, dtype",
                          [(*shape, 4, False, False, epilogue, allow_tf32, dtype)
-                          for shape in [(64, 64, 64)]
+                          for shape in [(64, 64, 64), (16, 16, 16)]
                           for epilogue in ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols', 'softmax', 'chain-dot']
                           for allow_tf32 in [True, False]
                           for dtype in ['float16', 'float32']
@@ -1045,11 +1098,10 @@ def test_reduce2d(op, dtype_str, shape, axis, device='cuda'):
                                              [128, 16, 32, 4],
                                              [32, 128, 64, 4],
                                              [128, 128, 64, 4],
-                                             #[64, 128, 128, 4],
+                                             [64, 128, 128, 4],
                                              [32, 128, 64, 2],
                                              [128, 128, 64, 2],
-                                             #[64, 128, 128, 2]
-                                             ]
+                                             [64, 128, 128, 2]]
                             for allow_tf32 in [True]
                             for col_a in [True, False]
                             for col_b in [True, False]
@@ -1063,6 +1115,13 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, dtype, devi
pytest.skip("Only test int8 on devices with sm >= 80")
elif dtype == 'float32' and allow_tf32:
pytest.skip("Only test tf32 on devices with sm >= 80")
if capability[0] == 7:
if (M, N, K, num_warps) == (128, 256, 32, 8):
pytest.skip("shared memory out of resource")
if torch.version.hip is not None:
if (M, N, K) == (64, 128, 128):
pytest.skip("Not supported: memory out of resource.")
torch.backends.cuda.matmul.allow_tf32 = allow_tf32
@@ -1175,16 +1234,19 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, dtype, devi
         np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01, atol=1e-3)
     else:
         np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
-    # make sure ld/st are vectorized
-    #ptx = pgm.asm['ptx']
-    #assert 'ld.global.v4' in ptx
-    #assert 'st.global.v4' in ptx
-    #if dtype == 'float32' and allow_tf32:
-    #    assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' in ptx
-    #elif dtype == 'float32' and allow_tf32:
-    #    assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' not in ptx
-    #elif dtype == 'int8':
-    #    assert 'mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32' in ptx
+    if torch.version.hip is None:
+        # make sure ld/st are vectorized
+        ptx = pgm.asm['ptx']
+        if K > 16 or N > 16 or M > 16:
+            # XXX: skip small sizes because they are not vectorized
+            assert 'ld.global.v4' in ptx
+            assert 'st.global.v4' in ptx
+        if dtype == 'float32' and allow_tf32:
+            assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' in ptx
+        elif dtype == 'float32' and allow_tf32:
+            assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' not in ptx
+        elif dtype == 'int8':
+            assert 'mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32' in ptx
 # ---------------
@@ -1243,6 +1305,57 @@ def test_masked_load(dtype_str, size, size_diff, device='cuda'):
     reference_out = torch.cat((input, torch.ones((size_diff,), dtype=dtype, device=device)))
     triton.testing.allclose(output, reference_out)
+
+
+# Testing masked loads with an intermediate copy to shared memory run.
+@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
+def test_masked_load_shared_memory(dtype, device='cuda'):
+    check_type_supported(dtype)  # bfloat16 on cc < 80 will not be tested
+    M = 32
+    N = 32
+    K = 16
+
+    in1 = torch.rand((M, K), dtype=dtype, device=device)
+    in2 = torch.rand((K, N), dtype=dtype, device=device)
+    out = torch.zeros((M, N), dtype=dtype, device=device)
+
+    @triton.jit
+    def _kernel(in1_ptr, in2_ptr, output_ptr,
+                in_stride, in2_stride, out_stride,
+                in_numel, in2_numel, out_numel,
+                M: tl.constexpr, N: tl.constexpr, K: tl.constexpr):
+        M_offsets = tl.arange(0, M)
+        N_offsets = tl.arange(0, N)
+        K_offsets = tl.arange(0, K)
+
+        in_offsets = M_offsets[:, None] * in_stride + K_offsets[None, :]
+        in2_offsets = K_offsets[:, None] * in2_stride + N_offsets[None, :]
+
+        # Load inputs.
+        x = tl.load(in1_ptr + in_offsets, mask=in_offsets < in_numel)
+        w = tl.load(in2_ptr + in2_offsets, mask=in2_offsets < in2_numel)
+
+        # Without a dot product the memory doesn't get promoted to shared.
+        o = tl.dot(x, w)
+
+        # Store output
+        output_offsets = M_offsets[:, None] * out_stride + N_offsets[None, :]
+        tl.store(output_ptr + output_offsets, o, mask=output_offsets < in2_numel)
+
+    pgm = _kernel[(1,)](in1, in2, out,
+                        in1.stride()[0],
+                        in2.stride()[0],
+                        out.stride()[0],
+                        in1.numel(),
+                        in2.numel(),
+                        out.numel(),
+                        M=M, N=N, K=K)
+
+    reference_out = torch.matmul(in1, in2)
+    triton.testing.allclose(out, reference_out)
+
+
 # ---------------
 # test store
@@ -1298,18 +1411,18 @@ def test_noop(device='cuda'):
     kernel[(1, )](x)


-@pytest.mark.parametrize("device", ['cuda', 'cpu'])
+@pytest.mark.parametrize("device", ['cuda', 'cpu', 'cpu_pinned'])
 def test_pointer_arguments(device):
     @triton.jit
     def kernel(x):
         pass
-    x = torch.empty(1024, device=device)
-    result = True
-    try:
-        kernel[(1,)](x)
-    except ValueError:
-        result = True if device == 'cpu' else False
-    assert result
+    pin_memory = 'pinned' in device
+    x = torch.empty(1024, device=device.split('_')[0], pin_memory=pin_memory)
+    if device == "cpu":
+        with pytest.raises(ValueError):
+            kernel[(1,)](x)
+    else:
+        kernel[(1, )](x)


 @pytest.mark.parametrize("value, value_type", [
@@ -1801,4 +1914,3 @@ def test_load_scalar_with_mask():
     Out = torch.empty_like(Index, device='cuda')
     kernel[(1,)](Input, Index, Out, Index.numel())
     assert Out.data[0] == 0