[BACKEND] Fix wrong conversion between float8e5m2 and bfloat16 (#1391)

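The exponent compensation should be 0x3800 (112) instead of 0x3000 (96):
112 is the difference between the bfloat16 exponent bias (127) and the
float8e5m2 exponent bias (15).
Also add a mantissa bit in the float16 conversion so that the result
rounds to the nearest float8e5m2.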

Co-authored-by: Philippe Tillet <phil@openai.com>
This commit is contained in:
xndcn
2023-03-24 12:42:08 +08:00
committed by GitHub
parent c9f47d9094
commit ff1d0377e0
2 changed files with 79 additions and 27 deletions

View File

@@ -932,7 +932,7 @@ def test_convert_float16_to_float32(in_dtype):
@pytest.mark.parametrize("in_dtype", [tl.float8e4, tl.float8e5])
@pytest.mark.parametrize("out_dtype", [torch.float16, torch.float32])
@pytest.mark.parametrize("out_dtype", [torch.float16, torch.bfloat16, torch.float32])
def test_f8_xf16_roundtrip(in_dtype, out_dtype):
"""Tests that converting an f8 to f16 and back to f8 doesn't change its value"""
check_type_supported(out_dtype)
@@ -970,7 +970,7 @@ def test_f8_xf16_roundtrip(in_dtype, out_dtype):
assert torch.all(f8_tensor == f8_output_tensor)
@pytest.mark.parametrize("in_dtype", [tl.float8e4])
@pytest.mark.parametrize("in_dtype", [tl.float8e4, tl.float8e5])
@pytest.mark.parametrize("out_dtype", [torch.float16, torch.bfloat16])
def test_f16_to_f8_rounding(in_dtype, out_dtype):
"""Takes all float16s, converts them to float8 and back to float16. Checks that the absolute
@@ -1017,8 +1017,10 @@ def test_f16_to_f8_rounding(in_dtype, out_dtype):
# WARN: only normalized numbers are handled
f8_normal_min = 1 << in_dtype.fp_mantissa_width # 0b00001000 for float8e4
f8_normal_max = 0b01111110
f8_normal_max = 0b01111110 if in_dtype == tl.float8e4 else 0b01111011
f16_min, f16_max, f16_max_minus_1 = convert_float_to_float32(torch.tensor([f8_normal_min, f8_normal_max, f8_normal_max - 1], dtype=torch.int8), in_dtype)
assert torch.all(torch.isfinite(f16_min))
assert torch.all(torch.isfinite(f16_max))
thres_error = f16_max - f16_max_minus_1
mismatch = torch.logical_and(
torch.logical_or(abs_error != min_error, abs_error > thres_error), torch.logical_and(torch.isfinite(f16_input), torch.logical_and(torch.abs(f16_input) <= f16_max, torch.abs(f16_input) >= f16_min))