[ROCM] Fixed implementation of fp32 to bf16 conversion on ROCm.
@@ -1112,10 +1112,21 @@ struct FpToFpOpConversion
                                     ConversionPatternRewriter &rewriter,
                                     const Value &v) {
 #ifdef USE_ROCM
-    auto as_int32 = bitcast(v, i32_ty);
-    auto shifted = lshr(i32_ty, as_int32, i32_val(16));
+    auto as_uint32 = bitcast(v, i32_ty);
+    auto check_exponent = and_(i32_ty, xor_(i32_ty, as_uint32, i32_val(0xffffffff)), i32_val(0x7f800000));
+    auto exponent_not_all1s = icmp_ne(check_exponent, i32_val(0));
+    auto exponent_all1s = icmp_eq(check_exponent, i32_val(0));
+    auto rounded = add(i32_ty, i32_val(0x7fff), and_(i32_ty, lshr(i32_ty, as_uint32, i32_val(16)), i32_val(1)));
+    rounded = add(i32_ty, rounded, as_uint32);
+    auto res = select(exponent_not_all1s, rounded, as_uint32);
+
+    auto preserve_nan = and_(i1_ty, exponent_all1s, icmp_ne(and_(i32_ty, as_uint32, i32_val(0xffff)), i32_val(0)));
+    auto nan = or_(i32_ty, as_uint32, i32_val(0x10000));
+    res = select(preserve_nan, nan, res);
+
+    auto shifted = lshr(i32_ty, res, i32_val(16));
     auto truncated = trunc(i16_ty, shifted);
-    return(bitcast(truncated, i16_ty));
+    return truncated;
 #else
     PTXBuilder builder;
     auto &cvt = *builder.create("cvt.rn.bf16.f32");
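In plain terms, the new ROCm path replaces the old shift-and-truncate with round-to-nearest(-even) plus NaN preservation: it adds 0x7fff plus the lowest surviving mantissa bit before taking the upper 16 bits, leaves Inf/NaN inputs unrounded (their exponent is all ones), and sets an extra mantissa bit when a NaN's payload lives only in the discarded low 16 bits so the result does not collapse to Inf. A minimal NumPy sketch of the same bit manipulation (an illustration only, not part of the commit):

import numpy as np

def fp32_to_bf16_bits(x):
    """Return bfloat16 bit patterns (uint16) for an array of float32 values."""
    as_uint32 = np.asarray(x, dtype=np.float32).view(np.uint32)
    # Inf/NaN have an all-ones exponent; those inputs must not be rounded.
    exponent_all1s = (as_uint32 & np.uint32(0x7F800000)) == np.uint32(0x7F800000)
    # Round to nearest, ties to even: add 0x7fff plus the lowest surviving mantissa bit.
    rounded = as_uint32 + np.uint32(0x7FFF) + ((as_uint32 >> np.uint32(16)) & np.uint32(1))
    res = np.where(exponent_all1s, as_uint32, rounded)
    # A NaN whose payload sits only in the discarded low 16 bits would truncate
    # to Inf; set bit 16 so the result stays a NaN.
    preserve_nan = exponent_all1s & ((as_uint32 & np.uint32(0xFFFF)) != 0)
    res = np.where(preserve_nan, as_uint32 | np.uint32(0x10000), res)
    return (res >> np.uint32(16)).astype(np.uint16)

# bf16(1.0) keeps the pattern 0x3f80, +Inf stays 0x7f80, NaN stays a NaN pattern.
print([hex(v) for v in fp32_to_bf16_bits([1.0, float("inf"), float("nan")])])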
@@ -840,26 +840,29 @@ def test_atomic_cas():
 ] + [
     (f'int{x}', f'uint{x}', True) for x in [8, 16, 32, 64]
 ])
-def test_cast(dtype_x, dtype_z, bitcast, device='cuda'):
+def test_cast(dtype_x, dtype_z, bitcast, device):
     # bfloat16 on cc < 80 will not be tested
-    check_type_supported(dtype_x)
-    check_type_supported(dtype_z)
+    check_type_supported(dtype_x, device)
+    check_type_supported(dtype_z, device)
 
+    size = 1024
     # This is tricky because numpy doesn't have bfloat, and torch doesn't have uints.
-    x0 = 43 if dtype_x in int_dtypes else 43.5
-    if dtype_x in float_dtypes and dtype_z == 'int1':
-        x0 = 0.5
     if dtype_x.startswith('bfloat'):
-        x_tri = torch.tensor([x0], dtype=getattr(torch, dtype_x), device=device)
+        x_tri = torch.randn(size, dtype=getattr(torch, dtype_x), device=device)
     else:
-        x = np.array([x0], dtype=getattr(np, dtype_x))
-        x_tri = to_triton(x)
+        x = numpy_random(size, dtype_str=dtype_x, low=-10, high=10) * 10
+        # Triton clamps negative values to zero, while numpy wraps around
+        # intmax, so avoid negatives for now.
+        # TODO: figure out which one should actually be happening, and test it
+        if dtype_z in uint_dtypes:
+            x = np.absolute(x)
+        x_tri = to_triton(x, device=device)
 
     # triton kernel
     @triton.jit
-    def kernel(X, Z, BITCAST: tl.constexpr):
-        x_ptr = X + tl.arange(0, 1)
-        z_ptr = Z + tl.arange(0, 1)
+    def kernel(X, Z, BITCAST: tl.constexpr, SIZE: tl.constexpr):
+        x_ptr = X + tl.arange(0, SIZE)
+        z_ptr = Z + tl.arange(0, SIZE)
         x = tl.load(x_ptr)
         z = x.to(Z.dtype.element_ty, bitcast=BITCAST)
         tl.store(z_ptr, z)
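For context, a standalone driver for a kernel like the one above might look as follows. This is only a sketch, assuming a GPU build of Triton and PyTorch; cast_kernel and the launch parameters are illustrative and not part of the commit.

import torch
import triton
import triton.language as tl

@triton.jit
def cast_kernel(X, Z, SIZE: tl.constexpr):
    # Load SIZE fp32 values, cast them to the output element type, store them.
    offs = tl.arange(0, SIZE)
    x = tl.load(X + offs)
    tl.store(Z + offs, x.to(Z.dtype.element_ty))

size = 1024
x = torch.randn(size, dtype=torch.float32, device='cuda')
z = torch.empty(size, dtype=torch.bfloat16, device='cuda')
cast_kernel[(1,)](x, z, SIZE=size, num_warps=1)
# torch also rounds to nearest when casting fp32 to bf16, so with this fix the
# comparison can be exact, which is what the updated assertions below rely on.
torch.testing.assert_close(x.to(torch.bfloat16), z, rtol=0, atol=0)

Note that tl.arange requires its length to be a power of two, which the test's size = 1024 satisfies.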
@@ -867,21 +870,21 @@ def test_cast(dtype_x, dtype_z, bitcast, device='cuda'):
     dtype_z_np = dtype_z if dtype_z != 'int1' else 'bool_'
     # triton result
     if dtype_z.startswith('bfloat'):
-        z_tri = torch.empty((1,), dtype=getattr(torch, dtype_z), device=device)
+        z_tri = torch.empty((size,), dtype=getattr(torch, dtype_z), device=device)
     else:
-        z_tri = to_triton(np.empty((1, ), dtype=getattr(np, dtype_z_np)), device=device)
-    kernel[(1, )](x_tri, z_tri, BITCAST=bitcast)
+        z_tri = to_triton(np.empty((size, ), dtype=getattr(np, dtype_z_np)), device=device)
+    kernel[(1, )](x_tri, z_tri, BITCAST=bitcast, SIZE=size, num_warps=1)
     # torch result
     if dtype_z.startswith('bfloat') or dtype_x.startswith('bfloat'):
         assert bitcast is False
         z_ref = x_tri.to(z_tri.dtype)
-        assert z_tri == z_ref
+        torch.testing.assert_close(z_ref, z_tri, rtol=0, atol=0)
     else:
         if bitcast:
             z_ref = x.view(getattr(np, dtype_z_np))
         else:
             z_ref = x.astype(getattr(np, dtype_z_np))
-        assert to_numpy(z_tri) == z_ref
+        np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0, atol=0)
 
 
 @pytest.mark.parametrize("dtype_str", list(torch_dtypes))
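The random inputs above exercise the rounding path; the NaN-preservation branch from the first hunk can be spot-checked separately with something like the following sketch (reusing the illustrative cast_kernel from above; not part of the commit):

# 0x7f800001: exponent all ones, NaN payload only in the low 16 mantissa bits.
# Truncating the top 16 bits alone would yield 0x7f80 (+Inf); the fixed
# conversion sets an extra mantissa bit so the result remains a NaN.
nan_bits = torch.full((1024,), 0x7F800001, dtype=torch.int32, device='cuda')
x = nan_bits.view(torch.float32)
z = torch.empty(1024, dtype=torch.bfloat16, device='cuda')
cast_kernel[(1,)](x, z, SIZE=1024, num_warps=1)
assert torch.isnan(z).all()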