mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
ROCM IFU: Enabled conversion between fp8e4m3b15x4 and fp16. Refactored conversion between fp8e4m3nv and fp16. (#335)
This commit is contained in:
committed by
Jason Furmanek
parent
634f66a090
commit
42a5bf9c7c
@@ -922,6 +922,7 @@ def test_load_store_same_ptr():
|
||||
kernel[(65536,)](x, num_warps=16)
|
||||
assert torch.all(x == 2)
|
||||
|
||||
|
||||
def convert_float_to_float32(fp: torch.tensor, dtype=None):
|
||||
if not dtype:
|
||||
dtype = getattr(tl, torch_dtype_name(fp.dtype))
|
||||
@@ -941,8 +942,8 @@ def convert_float_to_float32(fp: torch.tensor, dtype=None):
|
||||
|
||||
extended_exp = ((1 << (tl.float32.primitive_bitwidth - tl.float32.fp_mantissa_width - 1)) - 1) << tl.float32.fp_mantissa_width
|
||||
# special cases, exp is 0b11..1
|
||||
if dtype in [tl.float8e4, tl.float8e4b15]:
|
||||
# float8e4m3 does not have infinities
|
||||
if dtype in [tl.float8e4nv, tl.float8e4b15]:
|
||||
# float8e4m3nv does not have infinities
|
||||
output[fp == 0b01111111] = torch.nan
|
||||
output[fp == 0b11111111] = torch.nan
|
||||
else:
|
||||
@@ -969,40 +970,36 @@ def test_convert_float16_to_float32(in_dtype, device):
|
||||
|
||||
|
||||
def serialize_fp8(np_data, in_dtype):
    """Repack plain fp8 bytes into the storage layout used for `in_dtype`.

    Only `tl.float8e4b15x4` is stored differently from the plain byte
    sequence; for every other dtype `np_data` is returned as-is (the same
    object, not a copy).

    For `tl.float8e4b15x4`, `np_data` is reinterpreted as 32-bit words that
    each hold four fp8 values.  Reading a word from its most significant
    bit, the plain layout is

        s0 b0 s1 b1 s2 b2 s3 b3

    (`s` = the sign bit, `b` = the remaining 7 bits of one value) and the
    layout produced here is

        s0 s2 b0 b2 s1 s3 b1 b3

    Args:
        np_data: NumPy array holding the raw fp8 bytes.  It must be
            reinterpretable as `np.uint32` (NumPy raises otherwise).
        in_dtype: Triton dtype the bytes are meant to represent.

    Returns:
        A new `np.int8` array with the repacked bytes when `in_dtype` is
        `tl.float8e4b15x4`; otherwise `np_data` itself.
    """
    if in_dtype == tl.float8e4b15x4:
        # triton's f8e4b15 format is optimized for software emulation, so
        # the four values of a pack are interleaved instead of being stored
        # back to back; we apply that conversion here.
        f8x4 = np_data.view(np.uint32)
        # Isolate the sign bit / 7 payload bits of each of the four values
        # and left-align every one of them at the top of the word.
        s = [(f8x4 & (0x80000000 >> i)) << i for i in range(0, 32, 8)]
        b = [(f8x4 & (0x7f000000 >> i)) << i for i in range(0, 32, 8)]
        # Drop each field at its position in the interleaved layout:
        # signs land on bits 31, 15, 30, 14; payloads on bits 29-23, 13-7,
        # 22-16, 6-0 (for values 0, 1, 2, 3 respectively).
        signs = (s[0] >> 0) | (s[1] >> 16) | (s[2] >> 1) | (s[3] >> 17)
        bits = (b[0] >> 1) | (b[1] >> 17) | (b[2] >> 8) | (b[3] >> 24)
        # tensor of triton fp8 data
        return (signs | bits).view(np.int8)
    return np_data
|
||||
|
||||
# inverse of `serialize_fp8`
|
||||
|
||||
|
||||
def deserialize_fp8(np_data, in_dtype):
    """Undo `serialize_fp8`: turn stored fp8 bytes back into plain bytes.

    Only `tl.float8e4b15x4` is stored differently from the plain byte
    sequence; for every other dtype `np_data` is returned as-is (the same
    object, not a copy).

    For `tl.float8e4b15x4`, `np_data` is reinterpreted as 32-bit words that
    each hold four fp8 values.  Reading a word from its most significant
    bit, the stored layout is

        s0 s2 b0 b2 s1 s3 b1 b3

    (`s` = the sign bit, `b` = the remaining 7 bits of one value) and the
    layout produced here is the plain one

        s0 b0 s1 b1 s2 b2 s3 b3

    Args:
        np_data: NumPy array holding the stored fp8 bytes.  It must be
            reinterpretable as `np.uint32` (NumPy raises otherwise).
        in_dtype: Triton dtype the bytes represent.

    Returns:
        A new `np.int8` array with the plain bytes when `in_dtype` is
        `tl.float8e4b15x4`; otherwise `np_data` itself.
    """
    if in_dtype == tl.float8e4b15x4:
        f8x4 = np_data.view(np.uint32)
        # Pick each value's sign bit (bits 31, 15, 30, 14) and 7-bit payload
        # (bits 29-23, 13-7, 22-16, 6-0) out of the interleaved word and
        # left-align every field at the top of the word.
        s = [(f8x4 & (0x80000000 >> i)) << i for i in [0, 16, 1, 17]]
        b = [(f8x4 & (0x7f000000 >> i)) << i for i in [1, 17, 8, 24]]
        # Put value k back into its own byte: sign on the byte's top bit,
        # payload on the 7 bits below it.
        signs = (s[0] >> 0) | (s[1] >> 8) | (s[2] >> 16) | (s[3] >> 24)
        bits = (b[0] >> 0) | (b[1] >> 8) | (b[2] >> 16) | (b[3] >> 24)
        return (signs | bits).view(np.int8)
    return np_data
|
||||
|
||||
|
||||
@pytest.mark.parametrize("in_dtype", [tl.float8e4b15, tl.float8e4b15x4, tl.float8e4nv, tl.float8e5])
|
||||
@@ -1014,6 +1011,7 @@ def test_fp8_fpN_roundtrip(in_dtype, out_dtype, device):
|
||||
- conversion tri_fp8 = convert(input=tri_fp16, out=out_dtype) matches the original
|
||||
this is only possible if both conversions are correct
|
||||
"""
|
||||
check_type_supported(in_dtype, device)
|
||||
check_type_supported(out_dtype, device)
|
||||
|
||||
@triton.jit
|
||||
|
||||
Reference in New Issue
Block a user