Support scalar fp8 conversions by packing (#2379)
Support scalar fp8 conversions by packing the fp8 value with undef values. Also add simple unit tests to cover this change.
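
The packing change itself lives in the compiler backend and is not part of this diff, which only touches the tests. The snippet below is a conceptual sketch, in plain Python, of the idea named in the commit message; the pack width and helper names are made up for illustration.

# Conceptual sketch only (not the backend code): the fp8 conversion
# primitive works on a fixed-width pack of lanes, so a lone scalar is
# padded with don't-care ("undef") lanes, converted as a pack, and the
# single meaningful lane extracted afterwards.
PACK_WIDTH = 4  # assumed pack width

def convert_pack(lanes):
    # stand-in for the packed fp8 -> float conversion primitive
    return [float(v) for v in lanes]

def convert_scalar(value, undef=0.0):
    packed = [value] + [undef] * (PACK_WIDTH - 1)  # pack the scalar with undef lanes
    return convert_pack(packed)[0]                 # only lane 0 is meaningful

print(convert_scalar(3))  # 3.0
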
@@ -20,6 +20,7 @@ uint_dtypes = ['uint8', 'uint16', 'uint32', 'uint64']
 float_dtypes = ['float16', 'float32', 'float64']
 dtypes = int_dtypes + uint_dtypes + float_dtypes
 dtypes_with_bfloat16 = dtypes + ['bfloat16']
+torch_float8_dtypes = ['float8_e4m3fn', 'float8_e5m2']
 torch_dtypes = ['bool'] + int_dtypes + ['uint8'] + float_dtypes + ['bfloat16']
 
 # TODO: enable multiple cta cluster testing.
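
A note on the new torch_float8_dtypes list: these names are torch dtype attributes that only exist on recent PyTorch builds, which is why the fp8 cases added to test_cast below are gated on torch.__version__ >= "2.1". A standalone check (assumption: run outside the test suite):

import torch

# The tests look these names up with getattr(torch, name); on builds
# without fp8 support the lookup below prints None instead.
for name in ['float8_e4m3fn', 'float8_e5m2']:
    print(name, getattr(torch, name, None))
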
@@ -131,7 +132,7 @@ def check_type_supported(dtype, device):
         cc = torch.cuda.get_device_capability()
         if cc[0] < 8 and (dtype is tl.bfloat16 or dtype == "bfloat16" or dtype is torch.bfloat16):
             pytest.skip("bfloat16 is only supported on NVGPU with cc >= 80")
-        if cc[0] < 9 and (dtype is tl.float8e4nv or dtype == "float8e4nv"):
+        if cc[0] < 9 and dtype in {tl.float8e4nv, "float8e4nv", "float8_e4m3fn"}:
             pytest.skip("float8e4nv is only supported on NVGPU with cc >= 90")
 
 
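The widened condition above makes the cc >= 90 skip fire for the torch-style dtype string as well as the Triton type. A minimal illustration with plain strings (tl.float8e4nv omitted so the snippet runs without Triton; the capability tuple is assumed):

# cc mimics torch.cuda.get_device_capability(), e.g. (8, 0) on an A100;
# "float8_e4m3fn" is the torch-style name newly accepted by the check.
cc = (8, 0)
dtype = "float8_e4m3fn"
if cc[0] < 9 and dtype in {"float8e4nv", "float8_e4m3fn"}:
    print("would skip: float8e4nv is only supported on NVGPU with cc >= 90")
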
@@ -1281,23 +1282,33 @@ def test_atomic_cas(sem, num_ctas, device):
 # ---------------
 
 
-@pytest.mark.parametrize("dtype_x, dtype_z, bitcast", [
-    (dtype_x, dtype_z, False)
+@pytest.mark.parametrize("dtype_x, dtype_z, bitcast, size", [
+    (dtype_x, dtype_z, False, 1024)
     for dtype_x in dtypes
     for dtype_z in dtypes
 ] + [
-    ('float32', 'bfloat16', False),
-    ('bfloat16', 'float32', False),
-    ('float32', 'int32', True),
-    ('float32', 'int1', False),
-    ('int8', 'bfloat16', False),
+    ('float32', 'bfloat16', False, 1024),
+    ('bfloat16', 'float32', False, 1024),
+    ('float32', 'int32', True, 1024),
+    ('float32', 'int1', False, 1024),
+    ('int8', 'bfloat16', False, 1024),
 ] + [
-    (f'uint{x}', f'int{x}', True) for x in [8, 16, 32, 64]
+    (f'uint{x}', f'int{x}', True, 1024) for x in [8, 16, 32, 64]
 ] + [
-    (f'int{x}', f'uint{x}', True) for x in [8, 16, 32, 64]
-])
+    (f'int{x}', f'uint{x}', True, 1024) for x in [8, 16, 32, 64]
+] + (([
+    (dtype_x, dtype_z, False, size)
+    for dtype_x in torch_float8_dtypes
+    for dtype_z in ["float16", "float32"]
+    for size in [1024, 32]
+] + [
+    (dtype_x, dtype_z, False, size)
+    for dtype_z in torch_float8_dtypes
+    for dtype_x in ["float16", "float32"]
+    for size in [1024, 32]
+]) if torch.__version__ >= "2.1" else []))
 @pytest.mark.parametrize("num_ctas", num_ctas_list)
-def test_cast(dtype_x, dtype_z, bitcast, num_ctas, device):
+def test_cast(dtype_x, dtype_z, bitcast, size, num_ctas, device):
     # bfloat16 on cc < 80 will not be tested
     check_type_supported(dtype_x, device)
     check_type_supported(dtype_z, device)
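
For reference, these are the extra cases the two fp8 comprehensions above generate when torch >= 2.1 is available. The new size=32 runs presumably exercise the scalar (non-vectorized) conversion path this commit targets, while size=1024 keeps the packed path covered:

torch_float8_dtypes = ['float8_e4m3fn', 'float8_e5m2']

# fp8 -> float upcasts followed by float -> fp8 downcasts, mirroring the
# two comprehensions in the decorator above.
fp8_cases = [
    (dtype_x, dtype_z, False, size)
    for dtype_x in torch_float8_dtypes
    for dtype_z in ["float16", "float32"]
    for size in [1024, 32]
] + [
    (dtype_x, dtype_z, False, size)
    for dtype_z in torch_float8_dtypes
    for dtype_x in ["float16", "float32"]
    for size in [1024, 32]
]
print(len(fp8_cases))  # 16 new (dtype_x, dtype_z, bitcast, size) tuples
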
@@ -1305,10 +1316,11 @@ def test_cast(dtype_x, dtype_z, bitcast, num_ctas, device):
     if is_hip() and (dtype_z == "bfloat16"):
         pytest.skip(f'test_cast{(dtype_x, dtype_z)} cast to bfloat16 not supported on HIP.')
 
-    size = 1024
     # This is tricky because numpy doesn't have bfloat, and torch doesn't have uints.
     if dtype_x.startswith('bfloat'):
         x_tri = torch.randn(size, dtype=getattr(torch, dtype_x), device=device)
+    elif dtype_x.startswith('float8'):
+        x_tri = torch.randn(size, dtype=torch.half, device=device).to(dtype=getattr(torch, dtype_x))
     else:
         x = numpy_random(size, dtype_str=dtype_x, low=-10, high=10) * 10
         # Triton clamps negative values to zero, while numpy wraps around
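
The new elif branch works around the fact that torch.randn cannot draw fp8 values directly: the data is generated in fp16 and then converted. A standalone equivalent (assumes torch >= 2.1; the device argument is omitted so it runs on CPU):

import torch

size = 32
x_fp16 = torch.randn(size, dtype=torch.half)   # fp8 is not a valid randn dtype
x_fp8 = x_fp16.to(dtype=torch.float8_e4m3fn)   # convert to the fp8 input tensor
print(x_fp8.dtype, x_fp8.shape)                # torch.float8_e4m3fn torch.Size([32])
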
@@ -1331,11 +1343,13 @@ def test_cast(dtype_x, dtype_z, bitcast, num_ctas, device):
     # triton result
     if dtype_z.startswith('bfloat'):
         z_tri = torch.empty((size,), dtype=getattr(torch, dtype_z), device=device)
+    elif dtype_z.startswith('float8'):
+        z_tri = torch.empty((size,), dtype=torch.float, device=device)
     else:
         z_tri = to_triton(np.empty((size, ), dtype=getattr(np, dtype_z_np)), device=device)
     kernel[(1, )](x_tri, z_tri, BITCAST=bitcast, SIZE=size, num_warps=1, num_ctas=num_ctas)
     # torch result
-    if dtype_z.startswith('bfloat') or dtype_x.startswith('bfloat'):
+    if dtype_z.startswith('bfloat') or dtype_x.startswith('bfloat') or dtype_z.startswith('float8') or dtype_x.startswith('float8'):
         assert bitcast is False
         z_ref = x_tri.to(z_tri.dtype)
         torch.testing.assert_close(z_ref, z_tri, rtol=0, atol=0)
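
The kernel launched above is defined elsewhere in test_core.py and is not shown in this diff; the sketch below only illustrates the kind of cast kernel being exercised (the BITCAST and SIZE names come from the call site, the body is assumed):

import triton
import triton.language as tl

@triton.jit
def kernel(X, Z, BITCAST: tl.constexpr, SIZE: tl.constexpr):
    x_ptr = X + tl.arange(0, SIZE)
    z_ptr = Z + tl.arange(0, SIZE)
    x = tl.load(x_ptr)
    # Cast to the output pointer's element type; for the new fp8 cases this
    # is where the scalar fp8 conversion path added by this commit is hit.
    z = x.to(Z.dtype.element_ty, bitcast=BITCAST)
    tl.store(z_ptr, z)
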
@@ -3080,7 +3094,7 @@ def test_call(type, num_ctas, device):
             err_msg = str(e)
 
     if type == "noinline":
-        assert err_msg is not ""
+        assert err_msg != ""
     else:
         ans = rand_val * 1 * 2 * 1 * 2 * 3 * 4
         np.testing.assert_equal(to_numpy(rand_val_tri), ans)
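
On the last hunk: "is not" compares object identity rather than value, so the old assertion only checked that err_msg was not the same object as the empty-string literal, which relies on CPython string interning and triggers "SyntaxWarning: 'is not' with a literal" on Python 3.8+. The != form expresses the intended "error message is non-empty" check:

err_msg = "device-side assert triggered"   # hypothetical non-empty message
assert err_msg != ""                       # passes exactly when the message is non-empty

empty_msg = str()
assert not (empty_msg != "")               # and fails, as intended, when it is empty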