[BACKEND] Fix wrong conversion from float8e5m2 <> bfloat16 (#1391)

The exponent compensation should be 0x3800 (a bias delta of 112) instead of 0x3000 (96). Also add a mantissa bit to the float16 conversion so that it rounds to the nearest float8e5m2.

Co-authored-by: Philippe Tillet <phil@openai.com>
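Why 0x3800: float8e5m2 shares float16's exponent layout (5 exponent bits, bias 15), while bfloat16 has 8 exponent bits with bias 127, so a widening conversion must add 127 - 15 = 112 to the exponent, which is 112 << 7 = 0x3800 once the field sits in bfloat16's bit positions (0x3000 is only 96 << 7). Below is a minimal Python sketch of this bit manipulation, valid for normal numbers only -- an illustration of the arithmetic, not the backend's actual LLVM/PTX lowering; subnormals, infinities, and NaNs need separate handling.

import struct

def f8e5m2_to_bf16_bits(f8: int) -> int:
    # A float8e5m2 pattern (s eeeee mm) is bit-identical to the top
    # byte of a float16 pattern, so widen to 16 bits first.
    f16 = (f8 & 0xFF) << 8
    sign = f16 & 0x8000
    # Shift exponent and mantissa into bfloat16's field positions,
    # then rebias the exponent: (127 - 15) << 7 == 0x3800.
    # Using 0x3000 (96 << 7) here is exactly the bug being fixed.
    return sign | (((f16 & 0x7FFF) >> 3) + 0x3800)

def bf16_bits_to_float(bits: int) -> float:
    # bfloat16 is the top half of a float32 pattern.
    return struct.unpack(">f", struct.pack(">I", (bits & 0xFFFF) << 16))[0]

assert bf16_bits_to_float(f8e5m2_to_bf16_bits(0b00111100)) == 1.0      # e5m2 1.0
assert bf16_bits_to_float(f8e5m2_to_bf16_bits(0b01111011)) == 57344.0  # max finite e5m2

The test changes below exercise both halves of the fix.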
@@ -932,7 +932,7 @@ def test_convert_float16_to_float32(in_dtype):
 
 @pytest.mark.parametrize("in_dtype", [tl.float8e4, tl.float8e5])
-@pytest.mark.parametrize("out_dtype", [torch.float16, torch.float32])
+@pytest.mark.parametrize("out_dtype", [torch.float16, torch.bfloat16, torch.float32])
 def test_f8_xf16_roundtrip(in_dtype, out_dtype):
     """Tests that converting an f8 to f16 and back to f8 doesn't change its value"""
     check_type_supported(out_dtype)
@@ -970,7 +970,7 @@ def test_f8_xf16_roundtrip(in_dtype, out_dtype):
     assert torch.all(f8_tensor == f8_output_tensor)
 
 
-@pytest.mark.parametrize("in_dtype", [tl.float8e4])
+@pytest.mark.parametrize("in_dtype", [tl.float8e4, tl.float8e5])
 @pytest.mark.parametrize("out_dtype", [torch.float16, torch.bfloat16])
 def test_f16_to_f8_rounding(in_dtype, out_dtype):
     """Takes all float16s, converts them to float8 and back to float16. Checks that the absolute
@@ -1017,8 +1017,10 @@ def test_f16_to_f8_rounding(in_dtype, out_dtype):
 
     # WARN: only normalized numbers are handled
     f8_normal_min = 1 << in_dtype.fp_mantissa_width  # 0b00001000 for float8e4
-    f8_normal_max = 0b01111110
+    f8_normal_max = 0b01111110 if in_dtype == tl.float8e4 else 0b01111011
+    f16_min, f16_max, f16_max_minus_1 = convert_float_to_float32(torch.tensor([f8_normal_min, f8_normal_max, f8_normal_max - 1], dtype=torch.int8), in_dtype)
     assert torch.all(torch.isfinite(f16_min))
     assert torch.all(torch.isfinite(f16_max))
+    thres_error = f16_max - f16_max_minus_1
     mismatch = torch.logical_and(
+        torch.logical_or(abs_error != min_error, abs_error > thres_error), torch.logical_and(torch.isfinite(f16_input), torch.logical_and(torch.abs(f16_input) <= f16_max, torch.abs(f16_input) >= f16_min))
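Two details in the test changes are worth spelling out. First, for float8e5m2 the encoding 0b01111100 (exponent all ones, mantissa zero) is infinity, so the largest finite pattern is 0b01111011; that is why f8_normal_max now depends on the dtype. Second, the new thres_error bound exists because the conversion now rounds to nearest rather than truncating: a correctly rounded result is off by at most the spacing between adjacent float8 values, which the test measures as f16_max - f16_max_minus_1. A hedged sketch of the rounding trick named in the commit message -- float16 -> float8e5m2 keeps the top byte, so "adding a mantissa bit" means adding half of the 8 dropped mantissa bits before truncating (simplified: true round-to-nearest-even would also inspect the retained LSB on ties, and NaN/overflow need care):

def f16_bits_to_f8e5m2_rn(f16: int) -> int:
    # Truncation would be a plain (f16 >> 8); adding 1 << 7 first
    # (half of the dropped mantissa) turns it into round-to-nearest.
    return (((f16 & 0xFFFF) + 0x0080) >> 8) & 0xFF

assert f16_bits_to_f8e5m2_rn(0x3C00) == 0b00111100  # 1.0 stays 1.0
assert f16_bits_to_f8e5m2_rn(0x3C40) == 0b00111100  # 1.0625 rounds down to 1.0
assert f16_bits_to_f8e5m2_rn(0x3CC0) == 0b00111101  # 1.1875 rounds up to 1.25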