Mirror of https://github.com/ROCm/ROCm.git (synced 2026-04-05 03:01:17 -04:00)
[FRONTEND][BACKEND] Add support for float8e5m2 type (#1314)
This commit is contained in:
@@ -837,10 +837,11 @@ def test_load_store_same_ptr():
|
||||
assert torch.all(x == 2)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
|
||||
def test_f8_xf16_roundtrip(dtype):
|
||||
@pytest.mark.parametrize("in_dtype", [tl.float8e4, tl.float8e5])
|
||||
@pytest.mark.parametrize("out_dtype", [torch.float16, torch.float32])
|
||||
def test_f8_xf16_roundtrip(in_dtype, out_dtype):
|
||||
"""Tests that converting an f8 to f16 and back to f8 doesn't change its value"""
|
||||
check_type_supported(dtype)
|
||||
check_type_supported(out_dtype)
|
||||
|
||||
@triton.jit
|
||||
def copy_kernel(input_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
|
||||
@@ -851,20 +852,24 @@ def test_f8_xf16_roundtrip(dtype):
|
||||
tl.store(output_ptr + offsets, output, mask=mask)
|
||||
|
||||
f8_tensor = torch.tensor(range(-128, 128), dtype=torch.int8, device='cuda')
|
||||
f8 = triton.reinterpret(f8_tensor, tl.float8)
|
||||
# f32_to_f8 doesn't handle nan, so we make sure f8_tensor doesn't contain any nan
|
||||
all_exp_ones = (f8_tensor & 0b01111100) == 128 - 2**in_dtype.fp_mantissa_width
|
||||
f8_tensor[all_exp_ones] = 0
|
||||
f8 = triton.reinterpret(f8_tensor, in_dtype)
|
||||
n_elements = f8_tensor.numel()
|
||||
xf16 = torch.empty_like(f8_tensor, dtype=dtype)
|
||||
xf16 = torch.empty_like(f8_tensor, dtype=out_dtype)
|
||||
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
|
||||
copy_kernel[grid](f8, xf16, n_elements, BLOCK_SIZE=1024)
|
||||
|
||||
f8_output_tensor = torch.empty_like(xf16, dtype=torch.int8)
|
||||
f8_output = triton.reinterpret(f8_output_tensor, tl.float8)
|
||||
f8_output = triton.reinterpret(f8_output_tensor, in_dtype)
|
||||
copy_kernel[grid](xf16, f8_output, n_elements, BLOCK_SIZE=1024)
|
||||
|
||||
assert torch.all(f8_tensor == f8_output_tensor)
|
||||
|
||||
|
||||
def test_f16_to_f8_rounding():
|
||||
@pytest.mark.parametrize("in_dtype", [tl.float8e4])
|
||||
def test_f16_to_f8_rounding(in_dtype):
|
||||
"""Takes all float16s, converts them to float8 and back to float16. Checks that the absolute
|
||||
error is the minimum over all float8.
|
||||
Or the same explanation a bit mathier:
|
||||
@@ -887,7 +892,7 @@ def test_f16_to_f8_rounding():
|
||||
f16_input = torch.tensor(f16_input_np, dtype=torch.float16, device='cuda')
|
||||
n_elements = f16_input.numel()
|
||||
f8_output_tensor = torch.empty_like(f16_input, dtype=torch.int8)
|
||||
f8_output = triton.reinterpret(f8_output_tensor, tl.float8)
|
||||
f8_output = triton.reinterpret(f8_output_tensor, in_dtype)
|
||||
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
|
||||
copy_kernel[grid](f16_input, f8_output, n_elements, BLOCK_SIZE=1024)
|
||||
|
||||
@@ -897,7 +902,7 @@ def test_f16_to_f8_rounding():
|
||||
abs_error = torch.abs(f16_input - f16_output)
|
||||
|
||||
all_f8_vals_tensor = torch.tensor(range(2 ** 8), dtype=torch.uint8, device='cuda')
|
||||
all_f8_vals = triton.reinterpret(all_f8_vals_tensor, tl.float8)
|
||||
all_f8_vals = triton.reinterpret(all_f8_vals_tensor, in_dtype)
|
||||
all_f8_vals_in_f16 = torch.empty_like(all_f8_vals_tensor, dtype=torch.float16)
|
||||
copy_kernel[grid](all_f8_vals, all_f8_vals_in_f16, n_elements=256, BLOCK_SIZE=1024)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user