[ROCM] Fixed implementation of fp32 to bf16 conversion on ROCm.

Author: Wen Chen
Date: 2023-08-30 04:39:39 +00:00
Committed by: jayfurmanek
Parent: 2d3e38e182
Commit: ffc230ebfe
2 changed files with 34 additions and 20 deletions

@@ -840,26 +840,29 @@ def test_atomic_cas():
 ] + [
     (f'int{x}', f'uint{x}', True) for x in [8, 16, 32, 64]
 ])
-def test_cast(dtype_x, dtype_z, bitcast, device='cuda'):
+def test_cast(dtype_x, dtype_z, bitcast, device):
     # bfloat16 on cc < 80 will not be tested
-    check_type_supported(dtype_x)
-    check_type_supported(dtype_z)
+    check_type_supported(dtype_x, device)
+    check_type_supported(dtype_z, device)
 
+    size = 1024
     # This is tricky because numpy doesn't have bfloat, and torch doesn't have uints.
-    x0 = 43 if dtype_x in int_dtypes else 43.5
-    if dtype_x in float_dtypes and dtype_z == 'int1':
-        x0 = 0.5
     if dtype_x.startswith('bfloat'):
-        x_tri = torch.tensor([x0], dtype=getattr(torch, dtype_x), device=device)
+        x_tri = torch.randn(size, dtype=getattr(torch, dtype_x), device=device)
     else:
-        x = np.array([x0], dtype=getattr(np, dtype_x))
-        x_tri = to_triton(x)
+        x = numpy_random(size, dtype_str=dtype_x, low=-10, high=10) * 10
+        # Triton clamps negative values to zero, while numpy wraps around
+        # intmax, so avoid negatives for now.
+        # TODO: figure out which one should actually be happening, and test it
+        if dtype_z in uint_dtypes:
+            x = np.absolute(x)
+        x_tri = to_triton(x, device=device)
     # triton kernel
     @triton.jit
-    def kernel(X, Z, BITCAST: tl.constexpr):
-        x_ptr = X + tl.arange(0, 1)
-        z_ptr = Z + tl.arange(0, 1)
+    def kernel(X, Z, BITCAST: tl.constexpr, SIZE: tl.constexpr):
+        x_ptr = X + tl.arange(0, SIZE)
+        z_ptr = Z + tl.arange(0, SIZE)
         x = tl.load(x_ptr)
         z = x.to(Z.dtype.element_ty, bitcast=BITCAST)
         tl.store(z_ptr, z)
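
(Aside: SIZE is passed as a tl.constexpr because tl.arange needs a
compile-time-constant, power-of-two extent, which 1024 satisfies. Below is a
minimal standalone sketch of the same launch pattern; cast_kernel and the
tensor names are illustrative, not part of this commit.)

import torch
import triton
import triton.language as tl

@triton.jit
def cast_kernel(X, Z, SIZE: tl.constexpr):
    # tl.arange needs a power-of-two extent known at compile time.
    offs = tl.arange(0, SIZE)
    x = tl.load(X + offs)
    # Element-type cast, e.g. fp32 -> bf16, the path this commit exercises.
    tl.store(Z + offs, x.to(Z.dtype.element_ty))

# 'cuda' is also the device string PyTorch uses on ROCm (HIP) builds.
x = torch.randn(1024, dtype=torch.float32, device='cuda')
z = torch.empty(1024, dtype=torch.bfloat16, device='cuda')
cast_kernel[(1,)](x, z, SIZE=1024)  # one program covers all 1024 elements
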
@@ -867,21 +870,21 @@ def test_cast(dtype_x, dtype_z, bitcast, device='cuda'):
     dtype_z_np = dtype_z if dtype_z != 'int1' else 'bool_'
     # triton result
     if dtype_z.startswith('bfloat'):
-        z_tri = torch.empty((1,), dtype=getattr(torch, dtype_z), device=device)
+        z_tri = torch.empty((size,), dtype=getattr(torch, dtype_z), device=device)
     else:
-        z_tri = to_triton(np.empty((1, ), dtype=getattr(np, dtype_z_np)), device=device)
-    kernel[(1, )](x_tri, z_tri, BITCAST=bitcast)
+        z_tri = to_triton(np.empty((size, ), dtype=getattr(np, dtype_z_np)), device=device)
+    kernel[(1, )](x_tri, z_tri, BITCAST=bitcast, SIZE=size, num_warps=1)
     # torch result
     if dtype_z.startswith('bfloat') or dtype_x.startswith('bfloat'):
         assert bitcast is False
         z_ref = x_tri.to(z_tri.dtype)
-        assert z_tri == z_ref
+        torch.testing.assert_close(z_ref, z_tri, rtol=0, atol=0)
     else:
         if bitcast:
             z_ref = x.view(getattr(np, dtype_z_np))
         else:
             z_ref = x.astype(getattr(np, dtype_z_np))
-        assert to_numpy(z_tri) == z_ref
+        np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0, atol=0)
 
 
 @pytest.mark.parametrize("dtype_str", list(torch_dtypes))
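
Why the test needed widening: a single handpicked scalar such as 43.5 converts
identically under truncation and round-to-nearest-even, so the old one-element
test could pass even with a broken fp32 -> bf16 lowering. Over 1024 random
values the two rounding schemes almost surely disagree somewhere, which is
what the exact (rtol=0, atol=0) comparisons above rely on. A NumPy-only sketch
of that distinction (the helper names are illustrative; the bit tricks assume
finite inputs and ignore NaN handling):

import numpy as np

def f32_to_bf16_trunc(x):
    # Truncation: drop the low 16 bits of the fp32 encoding.
    bits = x.astype(np.float32).view(np.uint32)
    return (bits & 0xFFFF0000).view(np.float32)

def f32_to_bf16_rne(x):
    # Round to nearest even: add a bias derived from the bit that becomes
    # the new LSB, then drop the low 16 bits.
    bits = x.astype(np.float32).view(np.uint32)
    bias = 0x7FFF + ((bits >> 16) & 1)
    return ((bits + bias) & 0xFFFF0000).view(np.float32)

# Mirrors the test's input distribution: uniform in [-10, 10), scaled by 10.
x = (np.random.uniform(-10, 10, size=1024) * 10).astype(np.float32)
diff = np.count_nonzero(f32_to_bf16_trunc(x) != f32_to_bf16_rne(x))
print(f"{diff} of 1024 values round differently")  # roughly half, in expectation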