[FRONTEND][BACKEND] changed float8e4b15 clipping semantics from +-1.875 to +-1.75 (#2422)

Clipping float8e4b15 to +-1.875 is a bad idea because those values are encoded as 0x7f and 0xff, which are +-NaN on H100 when interpreted as float8e4nv. We lose two representable values, but this makes compatibility with float8e4nv much less painful (it becomes just a matter of adjusting the exponent bias).
Author: Philippe Tillet
Date: 2023-09-29 23:33:28 -07:00
Committed by: GitHub
Parent: ee013d8978
Commit: 533efd0cac
2 changed files with 12 additions and 8 deletions
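
For reference, a rough sketch of the bit-level reasoning, assuming the usual 1-sign / 4-exponent / 3-mantissa layout with exponent bias 15 (decode_e4b15 is a hypothetical helper written for illustration, not part of the Triton code base):

def decode_e4b15(byte):
    # Decode a raw float8e4b15 byte: 1 sign bit, 4 exponent bits, 3 mantissa bits, bias 15.
    sign = -1.0 if byte & 0x80 else 1.0
    exp = (byte >> 3) & 0xF
    man = byte & 0x7
    if exp == 0:  # subnormal range
        return sign * (man / 8.0) * 2.0 ** (1 - 15)
    return sign * (1 + man / 8.0) * 2.0 ** (exp - 15)

print(decode_e4b15(0x7F), decode_e4b15(0xFF))  # 1.875 -1.875 -- the old clipping bounds
print(decode_e4b15(0x7E), decode_e4b15(0xFE))  # 1.75 -1.75   -- the new clipping bounds

The same bit patterns 0x7f/0xff encode +-NaN under float8e4nv (E4M3), which is why saturating to 0x7e/0xfe keeps every clipped value meaningful in both formats.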


@@ -109,8 +109,8 @@ const std::string Fp16_to_Fp8E4M3B15(bool has_minx2) {
".reg .b16 c<4>; \n"
".reg .b16 max_val_f16; \n"
".reg .b32 max_val_f16x2; \n"
"mov.b16 max_val_f16, 0x3F80; \n"
"mov.b32 max_val_f16x2, 0x3F803F80; \n"
"mov.b16 max_val_f16, 0x3F00; \n"
"mov.b32 max_val_f16x2, 0x3F003F00; \n"
"and.b32 a0, $1, 0x7fff7fff; \n"
"and.b32 a1, $2, 0x7fff7fff; \n";
if (has_minx2)
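
For context, 0x3F80 and 0x3F00 are the fp16 bit patterns of the old and new clamp bounds; a quick NumPy check (illustrative only, not part of the change):

import numpy as np
# Reinterpret the fp16 bit patterns used as max_val_f16 in the PTX above.
print(np.array([0x3F80, 0x3F00], dtype=np.uint16).view(np.float16))  # [1.875 1.75]

The 32-bit constants 0x3F803F80 / 0x3F003F00 are the same bound replicated into both halves of a register for the packed f16x2 path.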


@@ -1526,22 +1526,26 @@ def test_fp8_fpN_roundtrip(in_dtype, out_dtype, device):
# initialize array containing all possible f8 values except NaN
ref_fp8 = np.array(range(-128, 128), dtype=np.int8)
-is_nan = (ref_fp8 & 0b01111100) == 128 - 2**in_dtype.fp_mantissa_width
exp_mask = 0b01111111 ^ ((1 << in_dtype.fp_mantissa_width) - 1)
+is_nan = (ref_fp8 & 0b01111100) == 128 - 2**in_dtype.fp_mantissa_width
is_subnormal = np.logical_or((ref_fp8 & exp_mask) == 0, (ref_fp8 & exp_mask) == exp_mask)
ref_fp8[is_nan] = 0
ref_fp8[is_subnormal] = 0
tri_fp8 = torch.from_numpy(serialize_fp8(ref_fp8, in_dtype)).cuda()
# check that non-subnormal fp8 are correctly converted to fp16
tri_fp16 = torch.empty(256, dtype=out_dtype, device="cuda")
copy_kernel[(1,)](triton.reinterpret(tri_fp8, in_dtype), tri_fp16, tri_fp16.shape[0], BLOCK_SIZE=1024)
ref_fp8 = torch.from_numpy(ref_fp8).cuda()
ref_fp16 = convert_float_to_float32(ref_fp8, in_dtype)
assert torch.all(tri_fp16[~is_subnormal] == ref_fp16[~is_subnormal])
# check that values are properly converted back to float8
ref_fp8 = torch.empty_like(tri_fp16, dtype=torch.int8)
copy_kernel[(1,)](tri_fp16, triton.reinterpret(ref_fp8, in_dtype), tri_fp16.shape[0], BLOCK_SIZE=1024)
-assert torch.all(tri_fp8 == ref_fp8)
+if in_dtype == tl.float8e4b15:
+    assert torch.all(tri_fp8[:127] == ref_fp8[:127])
+    assert torch.all(tri_fp8[128:255] == ref_fp8[128:255])
+    assert ref_fp8[126] == ref_fp8[127]  # -1.875 saturates to -1.75
+    assert ref_fp8[254] == ref_fp8[255]  # 1.875 saturates to 1.75
+else:
+    assert torch.all(tri_fp8[~is_subnormal] == ref_fp8[~is_subnormal])
# ---------------
# test reduce
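
The indices in the new float8e4b15 asserts above follow from how ref_fp8 enumerates every byte pattern; a small sketch of the mapping (outside the test, for illustration):

import numpy as np
# ref_fp8 holds int8 values -128..127, so its raw bytes run 0x80..0xff then 0x00..0x7f.
ref = np.arange(-128, 128, dtype=np.int8)
bits = ref.view(np.uint8)
print([hex(bits[i]) for i in (126, 127, 254, 255)])
# ['0xfe', '0xff', '0x7e', '0x7f'] -> -1.75, -1.875, +1.75, +1.875 in float8e4b15,
# so once +-1.875 saturates to +-1.75, the round-tripped bytes at 126/127 and at
# 254/255 must be equal, which is exactly what the new asserts check.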