mirror of https://github.com/ROCm/ROCm.git
[BACKEND] Fix wrong conversion from float8e4m3 <> float16 (#1375)
After the offset shift, the exponent compensation should not be forgotten. Also adds back some comments from `legacy_backend`.
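Note (illustration, not the backend code): the bug is easiest to see in a minimal sketch of widening float8e4m3 bits to float16 bits. After the payload is shifted into place, the exponent field is still biased by 7 while float16 uses a bias of 15, so the difference must be added in. The function name and layout handling below are my own; normal (non-zero, non-NaN) inputs only.

    import torch

    def f8e4m3_bits_to_f16(bits: torch.Tensor) -> torch.Tensor:
        # bits: raw float8e4m3 bit patterns as torch.uint8 (normal numbers only)
        b = bits.to(torch.int32)              # zero-extend the raw bits
        sign = (b & 0x80) << 8                # sign: bit 7 -> bit 15
        payload = (b & 0x7F) << 7             # 4-bit exp + 3-bit mantissa shifted into place
        payload = payload + ((15 - 7) << 10)  # the exponent compensation: rebias 7 -> 15
        out = sign | payload
        out = torch.where(out >= 0x8000, out - 0x10000, out)  # fit the int16 range
        return out.to(torch.int16).view(torch.float16)

    # 0b00111000 encodes 1.0 in e4m3 (biased exponent 7 == bias, mantissa 0)
    assert f8e4m3_bits_to_f16(torch.tensor([0b00111000], dtype=torch.uint8)).item() == 1.0

Dropping the `(15 - 7) << 10` term yields the wrong-by-a-power-of-256 results this commit fixes.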
@@ -885,6 +885,52 @@ def test_load_store_same_ptr():
     assert torch.all(x == 2)
 
 
+def convert_float_to_float32(fp: torch.tensor, dtype=None):
+    if not dtype:
+        dtype = getattr(tl, torch_dtype_name(fp.dtype))
+
+    fp = fp.view(getattr(torch, f"int{dtype.primitive_bitwidth}"))
+    exp_width = dtype.primitive_bitwidth - dtype.fp_mantissa_width - 1
+    exp_bias = 2 ** (exp_width - 1) - 1
+    sign = ((fp >> (dtype.primitive_bitwidth - 1)) & 0x01).int()
+    exp = ((fp >> dtype.fp_mantissa_width) & ((1 << exp_width) - 1)).int()
+    frac = (fp & ((1 << dtype.fp_mantissa_width) - 1)).int()
+
+    output = torch.where(exp == 0,
+                         # subnormal
+                         ((-1.0) ** sign) * (2.0 ** (1 - exp_bias)) * (frac / (2.0 ** dtype.fp_mantissa_width)),
+                         # normal
+                         ((-1.0) ** sign) * (2.0 ** (exp - exp_bias)) * (1.0 + frac / (2.0 ** dtype.fp_mantissa_width))).float()
+
+    extended_exp = ((1 << (tl.float32.primitive_bitwidth - tl.float32.fp_mantissa_width - 1)) - 1) << tl.float32.fp_mantissa_width
+    # special cases, exp is 0b11..1
+    if dtype == tl.float8e4:
+        # float8e4m3 does not have infinities
+        output[fp == torch.tensor(0b01111111, dtype=torch.int8)] = torch.nan
+        output[fp == torch.tensor(0b11111111, dtype=torch.int8)] = torch.nan
+    else:
+        output = torch.where(exp == (1 << exp_width) - 1,
+                             ((sign << (tl.float32.primitive_bitwidth - 1)) | extended_exp | (frac << (tl.float32.fp_mantissa_width - dtype.fp_mantissa_width))).view(torch.float32),
+                             output)
+    return output
+
+
+@pytest.mark.parametrize("in_dtype", [torch.float16, torch.bfloat16])
+def test_convert_float16_to_float32(in_dtype):
+    """Tests that check convert_float_to_float32 function"""
+    check_type_supported(in_dtype)
+
+    f16_input = torch.tensor(range(-int(2 ** (16 - 1)), int(2 ** (16 - 1))), dtype=torch.int16).view(in_dtype)
+    f32_output = convert_float_to_float32(f16_input)
+
+    nan = f16_input.isnan()
+    assert torch.all(f32_output[nan].isnan())
+    inf = f16_input.isinf()
+    assert torch.all(f32_output[inf].isinf())
+    other = torch.logical_not(torch.logical_or(nan, inf))
+    assert torch.all(f16_input[other] == f32_output[other])
+
+
 @pytest.mark.parametrize("in_dtype", [tl.float8e4, tl.float8e5])
 @pytest.mark.parametrize("out_dtype", [torch.float16, torch.float32])
 def test_f8_xf16_roundtrip(in_dtype, out_dtype):
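A quick hand check of the helper added above (an illustrative example, not part of the diff; assumes the test module's imports are in scope):

    import torch
    import triton.language as tl

    # 0b01000001 in float8e4m3: sign=0, exp=0b1000=8, frac=0b001, bias=2**(4-1)-1=7
    # value = (-1)**0 * 2.0**(8 - 7) * (1 + 1/8) = 2.25
    byte = torch.tensor([0b01000001], dtype=torch.int8)
    assert convert_float_to_float32(byte, tl.float8e4).item() == 2.25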
@@ -909,6 +955,14 @@ def test_f8_xf16_roundtrip(in_dtype, out_dtype):
     grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
     copy_kernel[grid](f8, xf16, n_elements, BLOCK_SIZE=1024)
 
+    # exponent_mask = 0b01111100 for float8e5
+    # exponent_mask = 0b01111000 for float8e4
+    exponent_mask = 0b01111111 ^ ((1 << in_dtype.fp_mantissa_width) - 1)
+    normal = torch.logical_and((f8_tensor & exponent_mask) != 0, (f8_tensor & exponent_mask) != exponent_mask)
+    ref16 = convert_float_to_float32(f8_tensor, in_dtype)
+    # WARN: currently only normal float8s are handled
+    assert torch.all(xf16[normal] == ref16[normal])
+
     f8_output_tensor = torch.empty_like(xf16, dtype=torch.int8)
     f8_output = triton.reinterpret(f8_output_tensor, in_dtype)
     copy_kernel[grid](xf16, f8_output, n_elements, BLOCK_SIZE=1024)
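Not part of the diff, but a hand check of the mask arithmetic quoted in the comments above: the XOR keeps exactly the exponent bits, and a float8 is normal precisely when those bits are neither all zero (subnormal) nor all one (inf/NaN region).

    assert 0b01111111 ^ ((1 << 3) - 1) == 0b01111000  # float8e4 (e4m3), 3 mantissa bits
    assert 0b01111111 ^ ((1 << 2) - 1) == 0b01111100  # float8e5 (e5m2), 2 mantissa bits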
@@ -965,9 +1019,13 @@ def test_f16_to_f8_rounding(in_dtype):
         ),
         dim=1,
     )[0]
-    # 1.9375 is float8 max
+
+    # WARN: only normalized numbers are handled
+    f8_normal_min = 1 << in_dtype.fp_mantissa_width  # 0b00001000 for float8e4
+    f8_normal_max = 0b01111110
+    f16_min, f16_max = convert_float_to_float32(torch.tensor([f8_normal_min, f8_normal_max], dtype=torch.int8), in_dtype)
     mismatch = torch.logical_and(
-        abs_error != min_error, torch.logical_and(torch.isfinite(f16_input), torch.abs(f16_input) < 1.9375)
+        abs_error != min_error, torch.logical_and(torch.isfinite(f16_input), torch.logical_and(torch.abs(f16_input) <= f16_max, torch.abs(f16_input) >= f16_min))
     )
     assert torch.all(
         torch.logical_not(mismatch)
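Hand arithmetic for the new bounds under the e4m3 layout (illustrative, not part of the diff): the hardcoded 1.9375 cutoff is replaced by a range derived from the format's own parameters.

    # f8_normal_min = 0b00001000: exp=1,  frac=0 -> 2.0**(1 - 7)              = 0.015625
    # f8_normal_max = 0b01111110: exp=15, frac=6 -> 2.0**(15 - 7) * (1 + 6/8) = 448.0
    # (0b01111111 is NaN in e4m3, so 0b01111110 is the largest finite encoding.)
    assert 2.0 ** (1 - 7) == 0.015625
    assert 2.0 ** (15 - 7) * (1 + 6 / 8) == 448.0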