mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[FRONTEND][BACKEND] changed float8e4b15 clipping semantics from +-1.875 to +-1.75 (#2422)
Clipping float8e4b15 to ±1.875 is a bad idea because those two values are encoded as 0x7f and 0xff, which are the ±NaN bit patterns for float8e4nv on H100. Clipping to ±1.75 instead costs us two representable values, but it makes compatibility with float8e4nv far less painful: converting between the two formats becomes just a matter of adjusting the exponent bias.
This commit is contained in:
@@ -1526,22 +1526,26 @@ def test_fp8_fpN_roundtrip(in_dtype, out_dtype, device):
|
||||
|
||||
# initialize array containing all possible f8 values except NaN
|
||||
ref_fp8 = np.array(range(-128, 128), dtype=np.int8)
|
||||
is_nan = (ref_fp8 & 0b01111100) == 128 - 2**in_dtype.fp_mantissa_width
|
||||
exp_mask = 0b01111111 ^ ((1 << in_dtype.fp_mantissa_width) - 1)
|
||||
is_nan = (ref_fp8 & 0b01111100) == 128 - 2**in_dtype.fp_mantissa_width
|
||||
is_subnormal = np.logical_or((ref_fp8 & exp_mask) == 0, (ref_fp8 & exp_mask) == exp_mask)
|
||||
ref_fp8[is_nan] = 0
|
||||
ref_fp8[is_subnormal] = 0
|
||||
tri_fp8 = torch.from_numpy(serialize_fp8(ref_fp8, in_dtype)).cuda()
|
||||
# check that non-subnormal fp8 are correctly converted to fp16
|
||||
tri_fp16 = torch.empty(256, dtype=out_dtype, device="cuda")
|
||||
copy_kernel[(1,)](triton.reinterpret(tri_fp8, in_dtype), tri_fp16, tri_fp16.shape[0], BLOCK_SIZE=1024)
|
||||
|
||||
ref_fp8 = torch.from_numpy(ref_fp8).cuda()
|
||||
ref_fp16 = convert_float_to_float32(ref_fp8, in_dtype)
|
||||
assert torch.all(tri_fp16[~is_subnormal] == ref_fp16[~is_subnormal])
|
||||
|
||||
# check that values are properly converted back to float8
|
||||
ref_fp8 = torch.empty_like(tri_fp16, dtype=torch.int8)
|
||||
copy_kernel[(1,)](tri_fp16, triton.reinterpret(ref_fp8, in_dtype), tri_fp16.shape[0], BLOCK_SIZE=1024)
|
||||
assert torch.all(tri_fp8 == ref_fp8)
|
||||
if in_dtype == tl.float8e4b15:
|
||||
assert torch.all(tri_fp8[:127] == ref_fp8[:127])
|
||||
assert torch.all(tri_fp8[128:255] == ref_fp8[128:255])
|
||||
assert ref_fp8[126] == ref_fp8[127] # -1.875 saturates to -1.75
|
||||
assert ref_fp8[254] == ref_fp8[255] # 1.875 saturates to 1.75
|
||||
else:
|
||||
assert torch.all(tri_fp8[~is_subnormal] == ref_fp8[~is_subnormal])
|
||||
|
||||
# ---------------
|
||||
# test reduce
|
||||
|
||||
Reference in New Issue
Block a user