[BACKEND] Fix wrong conversion from float8e4m3 <> bfloat16 (#1384)

The exponent compensation should be 0x3c00 (bias difference 120) instead of 0x3800 (112)
This commit is contained in:
xndcn
2023-03-22 09:58:13 +08:00
committed by GitHub
parent 08f705d193
commit 65d8d802d5
2 changed files with 65 additions and 64 deletions

View File

@@ -971,7 +971,8 @@ def test_f8_xf16_roundtrip(in_dtype, out_dtype):
@pytest.mark.parametrize("in_dtype", [tl.float8e4])
def test_f16_to_f8_rounding(in_dtype):
@pytest.mark.parametrize("out_dtype", [torch.float16, torch.bfloat16])
def test_f16_to_f8_rounding(in_dtype, out_dtype):
"""Takes all float16s, converts them to float8 and back to float16. Checks that the absolute
error is the minimum over all float8.
Or the same explanation a bit mathier:
@@ -984,28 +985,22 @@ def test_f16_to_f8_rounding(in_dtype):
output = input
tl.store(output_ptr + offsets, output, mask=mask)
# torch.view with a dtype isn't supported in triton's torch yet so use numpy's view
f16_input_np = (
np.array(
range(-int(2 ** (16 - 1)), int(2 ** (16 - 1))), dtype=np.int16,
)
.view(np.float16)
)
f16_input = torch.tensor(f16_input_np, dtype=torch.float16, device='cuda')
i16_input = torch.tensor(range(-int(2 ** (16 - 1)), int(2 ** (16 - 1))), dtype=torch.int16, device='cuda')
f16_input = i16_input.view(out_dtype)
n_elements = f16_input.numel()
f8_output_tensor = torch.empty_like(f16_input, dtype=torch.int8)
f8_output = triton.reinterpret(f8_output_tensor, in_dtype)
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
copy_kernel[grid](f16_input, f8_output, n_elements, BLOCK_SIZE=1024)
f16_output = torch.empty_like(f16_input, dtype=torch.float16)
f16_output = torch.empty_like(f16_input, dtype=out_dtype)
copy_kernel[grid](f8_output, f16_output, n_elements, BLOCK_SIZE=1024)
abs_error = torch.abs(f16_input - f16_output)
all_f8_vals_tensor = torch.tensor(range(2 ** 8), dtype=torch.uint8, device='cuda')
all_f8_vals = triton.reinterpret(all_f8_vals_tensor, in_dtype)
all_f8_vals_in_f16 = torch.empty_like(all_f8_vals_tensor, dtype=torch.float16)
all_f8_vals_in_f16 = torch.empty_like(all_f8_vals_tensor, dtype=out_dtype)
copy_kernel[grid](all_f8_vals, all_f8_vals_in_f16, n_elements=256, BLOCK_SIZE=1024)
all_finite_f8_vals_in_f16 = all_f8_vals_in_f16[
@@ -1023,9 +1018,10 @@ def test_f16_to_f8_rounding(in_dtype):
# WARN: only normalized numbers are handled
f8_normal_min = 1 << in_dtype.fp_mantissa_width # 0b00001000 for float8e4
f8_normal_max = 0b01111110
f16_min, f16_max = convert_float_to_float32(torch.tensor([f8_normal_min, f8_normal_max], dtype=torch.int8), in_dtype)
f16_min, f16_max, f16_max_minus_1 = convert_float_to_float32(torch.tensor([f8_normal_min, f8_normal_max, f8_normal_max - 1], dtype=torch.int8), in_dtype)
thres_error = f16_max - f16_max_minus_1
mismatch = torch.logical_and(
abs_error != min_error, torch.logical_and(torch.isfinite(f16_input), torch.logical_and(torch.abs(f16_input) <= f16_max, torch.abs(f16_input) >= f16_min))
torch.logical_or(abs_error != min_error, abs_error > thres_error), torch.logical_and(torch.isfinite(f16_input), torch.logical_and(torch.abs(f16_input) <= f16_max, torch.abs(f16_input) >= f16_min))
)
assert torch.all(
torch.logical_not(mismatch)