ROCM IFU: Enabled conversion between fp8e4m3b15x4 and fp16. Refactored conversion between fp8e4m3nv and fp16. (#335)

This commit is contained in:
wenchenvincent
2023-09-26 23:23:34 -05:00
committed by Jason Furmanek
parent 634f66a090
commit 42a5bf9c7c
2 changed files with 82 additions and 61 deletions

View File

@@ -922,6 +922,7 @@ def test_load_store_same_ptr():
kernel[(65536,)](x, num_warps=16)
assert torch.all(x == 2)
def convert_float_to_float32(fp: torch.tensor, dtype=None):
if not dtype:
dtype = getattr(tl, torch_dtype_name(fp.dtype))
@@ -941,8 +942,8 @@ def convert_float_to_float32(fp: torch.tensor, dtype=None):
extended_exp = ((1 << (tl.float32.primitive_bitwidth - tl.float32.fp_mantissa_width - 1)) - 1) << tl.float32.fp_mantissa_width
# special cases, exp is 0b11..1
if dtype in [tl.float8e4, tl.float8e4b15]:
# float8e4m3 does not have infinities
if dtype in [tl.float8e4nv, tl.float8e4b15]:
# float8e4m3nv does not have infinities
output[fp == 0b01111111] = torch.nan
output[fp == 0b11111111] = torch.nan
else:
@@ -969,40 +970,36 @@ def test_convert_float16_to_float32(in_dtype, device):
def serialize_fp8(np_data, in_dtype):
    """Convert fp8 bytes from natural order to Triton's packed f8e4b15x4 layout.

    Args:
        np_data: NumPy array holding raw fp8 bytes. For the packed dtype it is
            reinterpreted as ``uint32`` words, so its byte length along the
            last axis must be a multiple of 4 for the view to succeed.
        in_dtype: Triton dtype describing ``np_data``.

    Returns:
        For ``tl.float8e4b15x4``, a new ``int8`` view of the re-packed words.
        For every other dtype, ``np_data`` itself, unchanged.
    """
    if in_dtype != tl.float8e4b15x4:
        return np_data

    # triton's f8e4b15 format is optimized for software emulation
    # as a result, each pack of 4xfp8 values:
    # s0b0s1b1s2b2s3b3 (for s, b sign and bits respectively)
    # is actually internally stored as
    # s0s2b0b2s1s3b1b3
    # we apply the conversion here
    f8x4 = np_data.view(np.uint32)
    byte_offsets = range(0, 32, 8)
    # Isolate each byte's sign bit and move it to bit 31 of the word.
    s = [(f8x4 & (0x80000000 >> i)) << i for i in byte_offsets]
    # Isolate each byte's 7 non-sign bits and move them to bits 30..24.
    b = [(f8x4 & (0x7f000000 >> i)) << i for i in byte_offsets]
    # Signs land at bits 31 (s0), 30 (s2), 15 (s1) and 14 (s3).
    signs = (s[0] >> 0) | (s[1] >> 16) | (s[2] >> 1) | (s[3] >> 17)
    # Payloads land at bits 29..23 (b0), 22..16 (b2), 13..7 (b1), 6..0 (b3).
    bits = (b[0] >> 1) | (b[1] >> 17) | (b[2] >> 8) | (b[3] >> 24)
    # tensor of triton fp8 data
    return (signs | bits).view(np.int8)
# inverse of `serialize_fp8`
def deserialize_fp8(np_data, in_dtype):
    """Convert Triton's packed f8e4b15x4 layout back to natural fp8 byte order.

    This undoes the bit shuffle performed by ``serialize_fp8``: a word stored
    as ``s0s2b0b2s1s3b1b3`` is restored to ``s0b0s1b1s2b2s3b3`` (``s`` = sign
    bit, ``b`` = the remaining 7 bits of each fp8 value).

    Args:
        np_data: NumPy array holding raw fp8 bytes. For the packed dtype it is
            reinterpreted as ``uint32`` words, so its byte length along the
            last axis must be a multiple of 4 for the view to succeed.
        in_dtype: Triton dtype describing ``np_data``.

    Returns:
        For ``tl.float8e4b15x4``, a new ``int8`` view of the unpacked words.
        For every other dtype, ``np_data`` itself, unchanged.
    """
    if in_dtype != tl.float8e4b15x4:
        return np_data

    f8x4 = np_data.view(np.uint32)
    # Packed sign positions, listed in s0, s1, s2, s3 order: bits 31, 15, 30, 14.
    # Each is moved up to bit 31.
    s = [(f8x4 & (0x80000000 >> i)) << i for i in [0, 16, 1, 17]]
    # Packed payload positions, listed in b0, b1, b2, b3 order.
    # Each 7-bit field is moved up to bits 30..24.
    b = [(f8x4 & (0x7f000000 >> i)) << i for i in [1, 17, 8, 24]]
    # Drop value k back into byte k of the word (byte 0 is the most significant).
    signs = (s[0] >> 0) | (s[1] >> 8) | (s[2] >> 16) | (s[3] >> 24)
    bits = (b[0] >> 0) | (b[1] >> 8) | (b[2] >> 16) | (b[3] >> 24)
    return (signs | bits).view(np.int8)
@pytest.mark.parametrize("in_dtype", [tl.float8e4b15, tl.float8e4b15x4, tl.float8e4nv, tl.float8e5])
@@ -1014,6 +1011,7 @@ def test_fp8_fpN_roundtrip(in_dtype, out_dtype, device):
- conversion tri_fp8 = convert(input=tri_fp16, out=out_dtype) matches the original
this is only possible if both conversions are correct
"""
check_type_supported(in_dtype, device)
check_type_supported(out_dtype, device)
@triton.jit