Support scalar fp8 conversions by packing (#2379)

Support fp8 scalar conversions by packing fp8 with undef values.

Also add simple unittests to cover this change.
This commit is contained in:
Ying Zhang
2023-09-27 08:29:53 -07:00
committed by GitHub
parent bf3171f5c7
commit 78c28bf5f6
2 changed files with 35 additions and 19 deletions

View File

@@ -20,6 +20,7 @@ uint_dtypes = ['uint8', 'uint16', 'uint32', 'uint64']
float_dtypes = ['float16', 'float32', 'float64']
dtypes = int_dtypes + uint_dtypes + float_dtypes
dtypes_with_bfloat16 = dtypes + ['bfloat16']
torch_float8_dtypes = ['float8_e4m3fn', 'float8_e5m2']
torch_dtypes = ['bool'] + int_dtypes + ['uint8'] + float_dtypes + ['bfloat16']
# TODO: enable multiple cta cluster testing.
@@ -131,7 +132,7 @@ def check_type_supported(dtype, device):
cc = torch.cuda.get_device_capability()
if cc[0] < 8 and (dtype is tl.bfloat16 or dtype == "bfloat16" or dtype is torch.bfloat16):
pytest.skip("bfloat16 is only supported on NVGPU with cc >= 80")
if cc[0] < 9 and (dtype is tl.float8e4nv or dtype == "float8e4nv"):
if cc[0] < 9 and dtype in {tl.float8e4nv, "float8e4nv", "float8_e4m3fn"}:
pytest.skip("float8e4nv is only supported on NVGPU with cc >= 90")
@@ -1281,23 +1282,33 @@ def test_atomic_cas(sem, num_ctas, device):
# ---------------
@pytest.mark.parametrize("dtype_x, dtype_z, bitcast", [
(dtype_x, dtype_z, False)
@pytest.mark.parametrize("dtype_x, dtype_z, bitcast, size", [
(dtype_x, dtype_z, False, 1024)
for dtype_x in dtypes
for dtype_z in dtypes
] + [
('float32', 'bfloat16', False),
('bfloat16', 'float32', False),
('float32', 'int32', True),
('float32', 'int1', False),
('int8', 'bfloat16', False),
('float32', 'bfloat16', False, 1024),
('bfloat16', 'float32', False, 1024),
('float32', 'int32', True, 1024),
('float32', 'int1', False, 1024),
('int8', 'bfloat16', False, 1024),
] + [
(f'uint{x}', f'int{x}', True) for x in [8, 16, 32, 64]
(f'uint{x}', f'int{x}', True, 1024) for x in [8, 16, 32, 64]
] + [
(f'int{x}', f'uint{x}', True) for x in [8, 16, 32, 64]
])
(f'int{x}', f'uint{x}', True, 1024) for x in [8, 16, 32, 64]
] + (([
(dtype_x, dtype_z, False, size)
for dtype_x in torch_float8_dtypes
for dtype_z in ["float16", "float32"]
for size in [1024, 32]
] + [
(dtype_x, dtype_z, False, size)
for dtype_z in torch_float8_dtypes
for dtype_x in ["float16", "float32"]
for size in [1024, 32]
]) if torch.__version__ >= "2.1" else []))
@pytest.mark.parametrize("num_ctas", num_ctas_list)
def test_cast(dtype_x, dtype_z, bitcast, num_ctas, device):
def test_cast(dtype_x, dtype_z, bitcast, size, num_ctas, device):
# bfloat16 on cc < 80 will not be tested
check_type_supported(dtype_x, device)
check_type_supported(dtype_z, device)
@@ -1305,10 +1316,11 @@ def test_cast(dtype_x, dtype_z, bitcast, num_ctas, device):
if is_hip() and (dtype_z == "bfloat16"):
pytest.skip(f'test_cast{(dtype_x, dtype_z)} cast to bfloat16 not supported on HIP.')
size = 1024
# This is tricky because numpy doesn't have bfloat, and torch doesn't have uints.
if dtype_x.startswith('bfloat'):
x_tri = torch.randn(size, dtype=getattr(torch, dtype_x), device=device)
elif dtype_x.startswith('float8'):
x_tri = torch.randn(size, dtype=torch.half, device=device).to(dtype=getattr(torch, dtype_x))
else:
x = numpy_random(size, dtype_str=dtype_x, low=-10, high=10) * 10
# Triton clamps negative values to zero, while numpy wraps around
@@ -1331,11 +1343,13 @@ def test_cast(dtype_x, dtype_z, bitcast, num_ctas, device):
# triton result
if dtype_z.startswith('bfloat'):
z_tri = torch.empty((size,), dtype=getattr(torch, dtype_z), device=device)
elif dtype_z.startswith('float8'):
z_tri = torch.empty((size,), dtype=torch.float, device=device)
else:
z_tri = to_triton(np.empty((size, ), dtype=getattr(np, dtype_z_np)), device=device)
kernel[(1, )](x_tri, z_tri, BITCAST=bitcast, SIZE=size, num_warps=1, num_ctas=num_ctas)
# torch result
if dtype_z.startswith('bfloat') or dtype_x.startswith('bfloat'):
if dtype_z.startswith('bfloat') or dtype_x.startswith('bfloat') or dtype_z.startswith('float8') or dtype_x.startswith('float8'):
assert bitcast is False
z_ref = x_tri.to(z_tri.dtype)
torch.testing.assert_close(z_ref, z_tri, rtol=0, atol=0)
@@ -3080,7 +3094,7 @@ def test_call(type, num_ctas, device):
err_msg = str(e)
if type == "noinline":
assert err_msg is not ""
assert err_msg != ""
else:
ans = rand_val * 1 * 2 * 1 * 2 * 3 * 4
np.testing.assert_equal(to_numpy(rand_val_tri), ans)