mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[ROCM] Added ROCm support for the conversions of following data types:
[float8e4m3, float8e4m3b15, float8e5m2] <-> [float16, bfloat16]
This commit is contained in:
@@ -45,6 +45,9 @@ def numpy_random(shape, dtype_str, rs: Optional[RandomState] = None, low=None, h
|
||||
x = rs.randint(low, high, shape, dtype=dtype)
|
||||
x[x == 0] = 1 # Hack. Never return zero so tests of division don't error out.
|
||||
return x
|
||||
elif dtype_str and 'float8' in dtype_str:
|
||||
x = rs.randint(20, 40, shape, dtype=np.int8)
|
||||
return x
|
||||
elif dtype_str in float_dtypes:
|
||||
return rs.normal(0, 1, shape).astype(dtype_str)
|
||||
elif dtype_str == 'bfloat16':
|
||||
@@ -68,6 +71,8 @@ def to_triton(x: np.ndarray, device='cuda', dst_type=None) -> Union[TensorWrappe
|
||||
x_signed = x.astype(getattr(np, signed_type_name))
|
||||
return reinterpret(torch.tensor(x_signed, device=device), getattr(tl, t))
|
||||
else:
|
||||
if dst_type and 'float8' in dst_type:
|
||||
return reinterpret(torch.tensor(x, device=device), getattr(tl, dst_type))
|
||||
if t == 'float32' and dst_type == 'bfloat16':
|
||||
return torch.tensor(x, device=device).bfloat16()
|
||||
return torch.tensor(x, device=device)
|
||||
@@ -102,14 +107,14 @@ def patch_kernel(template, to_replace):
|
||||
return kernel
|
||||
|
||||
|
||||
def check_type_supported(dtype):
|
||||
def check_type_supported(dtype, device='cuda'):
|
||||
'''
|
||||
skip test if dtype is not supported on the current device
|
||||
'''
|
||||
cc = torch.cuda.get_device_capability()
|
||||
if cc[0] < 8 and (dtype is tl.bfloat16 or dtype == "bfloat16" or dtype is torch.bfloat16):
|
||||
pytest.skip("bfloat16 is only supported on NVGPU with cc >= 80")
|
||||
|
||||
if device in ['cuda']:
|
||||
cc = torch.cuda.get_device_capability()
|
||||
if cc[0] < 8 and (dtype is tl.bfloat16 or dtype == "bfloat16" or dtype is torch.bfloat16):
|
||||
pytest.skip("bfloat16 is only supported on NVGPU with cc >= 80")
|
||||
|
||||
@pytest.mark.parametrize("dtype_x", list(dtypes) + ["bfloat16"])
|
||||
def test_empty_kernel(dtype_x, device='cuda'):
|
||||
@@ -913,12 +918,99 @@ def test_load_store_same_ptr():
|
||||
kernel[(65536,)](x, num_warps=16)
|
||||
assert torch.all(x == 2)
|
||||
|
||||
def convert_float_to_float32(fp: torch.tensor, dtype=None):
|
||||
if not dtype:
|
||||
dtype = getattr(tl, torch_dtype_name(fp.dtype))
|
||||
|
||||
@pytest.mark.parametrize("in_dtype", [tl.float8e4, tl.float8e5])
|
||||
fp = fp.view(getattr(torch, f"int{dtype.primitive_bitwidth}"))
|
||||
exp_width = dtype.primitive_bitwidth - dtype.fp_mantissa_width - 1
|
||||
exp_bias = dtype.exponent_bias
|
||||
sign = ((fp >> (dtype.primitive_bitwidth - 1)) & 0x01).int()
|
||||
exp = ((fp >> dtype.fp_mantissa_width) & ((1 << exp_width) - 1)).int()
|
||||
frac = (fp & ((1 << dtype.fp_mantissa_width) - 1)).int()
|
||||
|
||||
output = torch.where(exp == 0,
|
||||
# subnormal
|
||||
((-1.0) ** sign) * (2.0 ** (1 - exp_bias)) * (frac / (2.0 ** dtype.fp_mantissa_width)),
|
||||
# normal
|
||||
((-1.0) ** sign) * (2.0 ** (exp - exp_bias)) * (1.0 + frac / (2.0 ** dtype.fp_mantissa_width))).float()
|
||||
|
||||
extended_exp = ((1 << (tl.float32.primitive_bitwidth - tl.float32.fp_mantissa_width - 1)) - 1) << tl.float32.fp_mantissa_width
|
||||
# special cases, exp is 0b11..1
|
||||
if dtype in [tl.float8e4, tl.float8e4b15]:
|
||||
# float8e4m3 does not have infinities
|
||||
output[fp == 0b01111111] = torch.nan
|
||||
output[fp == 0b11111111] = torch.nan
|
||||
else:
|
||||
output = torch.where(exp == (1 << exp_width) - 1,
|
||||
((sign << (tl.float32.primitive_bitwidth - 1)) | extended_exp | (frac << (tl.float32.fp_mantissa_width - dtype.fp_mantissa_width))).view(torch.float32),
|
||||
output)
|
||||
return output
|
||||
|
||||
|
||||
@pytest.mark.parametrize("in_dtype", [torch.float16, torch.bfloat16])
|
||||
def test_convert_float16_to_float32(in_dtype, device):
|
||||
"""Tests that check convert_float_to_float32 function"""
|
||||
check_type_supported(in_dtype, device)
|
||||
|
||||
f16_input = torch.tensor(range(-int(2 ** (16 - 1)), int(2 ** (16 - 1))), dtype=torch.int16).view(in_dtype)
|
||||
f32_output = convert_float_to_float32(f16_input)
|
||||
|
||||
nan = f16_input.isnan()
|
||||
assert torch.all(f32_output[nan].isnan())
|
||||
inf = f16_input.isinf()
|
||||
assert torch.all(f32_output[inf].isinf())
|
||||
other = torch.logical_not(torch.logical_or(nan, inf))
|
||||
assert torch.all(f16_input[other] == f32_output[other])
|
||||
|
||||
|
||||
def serialize_fp8(np_data, in_dtype):
|
||||
return np_data
|
||||
# def serialize_fp8(np_data, in_dtype):
|
||||
# if in_dtype == tl.float8e4b15:
|
||||
# # triton's f8e4b15 format is optimized for software emulation
|
||||
# # as a result, each pack of 4xfp8 values:
|
||||
# # s0b0s1b1s2b2s3b3 (for s, b sign and bits respectively)
|
||||
# # is actually internally stored as
|
||||
# # s0s2b0b2s1s3b1b3
|
||||
# # we apply the conversion here
|
||||
# f8x4 = np_data.view(np.uint32)
|
||||
# s = [(f8x4 & (0x80000000 >> i)) << i for i in range(0, 32, 8)]
|
||||
# b = [(f8x4 & (0x7f000000 >> i)) << i for i in range(0, 32, 8)]
|
||||
# signs = (s[0] >> 0) | (s[1] >> 16) | (s[2] >> 1) | (s[3] >> 17)
|
||||
# bits = (b[0] >> 1) | (b[1] >> 17) | (b[2] >> 8) | (b[3] >> 24)
|
||||
# # tensor of triton fp8 data
|
||||
# return (signs | bits).view(np.int8)
|
||||
# else:
|
||||
# return np_data
|
||||
|
||||
# inverse of `serialize_fp8`
|
||||
|
||||
|
||||
def deserialize_fp8(np_data, in_dtype):
|
||||
return np_data
|
||||
# def deserialize_fp8(np_data, in_dtype):
|
||||
# if in_dtype == tl.float8e4b15:
|
||||
# f8x4 = np_data.view(np.uint32)
|
||||
# s = [(f8x4 & (0x80000000 >> i)) << i for i in [0, 16, 1, 17]]
|
||||
# b = [(f8x4 & (0x7f000000 >> i)) << i for i in [1, 17, 8, 24]]
|
||||
# signs = (s[0] >> 0) | (s[1] >> 8) | (s[2] >> 16) | (s[3] >> 24)
|
||||
# bits = (b[0] >> 0) | (b[1] >> 8) | (b[2] >> 16) | (b[3] >> 24)
|
||||
# return (signs | bits).view(np.int8)
|
||||
# else:
|
||||
# return np_data
|
||||
|
||||
|
||||
@pytest.mark.parametrize("in_dtype", [tl.float8e4b15, tl.float8e4, tl.float8e5])
|
||||
@pytest.mark.parametrize("out_dtype", [torch.float16, torch.float32])
|
||||
def test_f8_xf16_roundtrip(in_dtype, out_dtype):
|
||||
"""Tests that converting an f8 to f16 and back to f8 doesn't change its value"""
|
||||
check_type_supported(out_dtype)
|
||||
def test_fp8_fpN_roundtrip(in_dtype, out_dtype, device):
|
||||
"""
|
||||
For all possible float8 values (ref_fp8 = range(0, 256)), test that:
|
||||
- conversion tri_fp16 = convert(input=ref_fp8, out=out_dtype) matches the reference
|
||||
- conversion tri_fp8 = convert(input=tri_fp16, out=out_dtype) matches the original
|
||||
this is only possible if both conversions are correct
|
||||
"""
|
||||
check_type_supported(out_dtype, device)
|
||||
|
||||
@triton.jit
|
||||
def copy_kernel(input_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
|
||||
@@ -928,84 +1020,24 @@ def test_f8_xf16_roundtrip(in_dtype, out_dtype):
|
||||
output = input
|
||||
tl.store(output_ptr + offsets, output, mask=mask)
|
||||
|
||||
if torch.version.hip is not None:
|
||||
pytest.skip("test_masked_load_shared_memory[bfloat16] is only supported on AMDGPU")
|
||||
f8_tensor = torch.tensor(range(-128, 128), dtype=torch.int8, device='cuda')
|
||||
# f32_to_f8 doesn't handle nan, so we make sure f8_tensor doesn't contain any nan
|
||||
all_exp_ones = (f8_tensor & 0b01111100) == 128 - 2**in_dtype.fp_mantissa_width
|
||||
f8_tensor[all_exp_ones] = 0
|
||||
f8 = triton.reinterpret(f8_tensor, in_dtype)
|
||||
n_elements = f8_tensor.numel()
|
||||
xf16 = torch.empty_like(f8_tensor, dtype=out_dtype)
|
||||
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
|
||||
copy_kernel[grid](f8, xf16, n_elements, BLOCK_SIZE=1024)
|
||||
# initialize array containing all possible f8 values except NaN
|
||||
ref_fp8 = np.array(range(-128, 128), dtype=np.int8)
|
||||
is_nan = (ref_fp8 & 0b01111100) == 128 - 2**in_dtype.fp_mantissa_width
|
||||
exp_mask = 0b01111111 ^ ((1 << in_dtype.fp_mantissa_width) - 1)
|
||||
is_subnormal = np.logical_or((ref_fp8 & exp_mask) == 0, (ref_fp8 & exp_mask) == exp_mask)
|
||||
ref_fp8[is_nan] = 0
|
||||
ref_fp8[is_subnormal] = 0
|
||||
tri_fp8 = torch.from_numpy(serialize_fp8(ref_fp8, in_dtype)).cuda()
|
||||
tri_fp16 = torch.empty(256, dtype=out_dtype, device="cuda")
|
||||
copy_kernel[(1,)](triton.reinterpret(tri_fp8, in_dtype), tri_fp16, tri_fp16.shape[0], BLOCK_SIZE=1024)
|
||||
|
||||
f8_output_tensor = torch.empty_like(xf16, dtype=torch.int8)
|
||||
f8_output = triton.reinterpret(f8_output_tensor, in_dtype)
|
||||
copy_kernel[grid](xf16, f8_output, n_elements, BLOCK_SIZE=1024)
|
||||
|
||||
assert torch.all(f8_tensor == f8_output_tensor)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("in_dtype", [tl.float8e4])
|
||||
def test_f16_to_f8_rounding(in_dtype):
|
||||
"""Takes all float16s, converts them to float8 and back to float16. Checks that the absolute
|
||||
error is the minimum over all float8.
|
||||
Or the same explanation a bit mathier:
|
||||
for all f16 |f16 - fromf8(tof8(f16))| == min over all f8 |f16 - fromf8(f8)|"""
|
||||
@triton.jit
|
||||
def copy_kernel(input_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
|
||||
offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
||||
mask = offsets < n_elements
|
||||
input = tl.load(input_ptr + offsets, mask=mask)
|
||||
output = input
|
||||
tl.store(output_ptr + offsets, output, mask=mask)
|
||||
|
||||
if torch.version.hip is not None:
|
||||
pytest.skip("test_masked_load_shared_memory[bfloat16] is only supported on AMDGPU")
|
||||
# torch.view with a dtype isn't supported in triton's torch yet so use numpy's view
|
||||
f16_input_np = (
|
||||
np.array(
|
||||
range(-int(2 ** (16 - 1)), int(2 ** (16 - 1))), dtype=np.int16,
|
||||
)
|
||||
.view(np.float16)
|
||||
)
|
||||
f16_input = torch.tensor(f16_input_np, dtype=torch.float16, device='cuda')
|
||||
n_elements = f16_input.numel()
|
||||
f8_output_tensor = torch.empty_like(f16_input, dtype=torch.int8)
|
||||
f8_output = triton.reinterpret(f8_output_tensor, in_dtype)
|
||||
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
|
||||
copy_kernel[grid](f16_input, f8_output, n_elements, BLOCK_SIZE=1024)
|
||||
|
||||
f16_output = torch.empty_like(f16_input, dtype=torch.float16)
|
||||
copy_kernel[grid](f8_output, f16_output, n_elements, BLOCK_SIZE=1024)
|
||||
|
||||
abs_error = torch.abs(f16_input - f16_output)
|
||||
|
||||
all_f8_vals_tensor = torch.tensor(range(2 ** 8), dtype=torch.uint8, device='cuda')
|
||||
all_f8_vals = triton.reinterpret(all_f8_vals_tensor, in_dtype)
|
||||
all_f8_vals_in_f16 = torch.empty_like(all_f8_vals_tensor, dtype=torch.float16)
|
||||
copy_kernel[grid](all_f8_vals, all_f8_vals_in_f16, n_elements=256, BLOCK_SIZE=1024)
|
||||
|
||||
all_finite_f8_vals_in_f16 = all_f8_vals_in_f16[
|
||||
torch.isfinite(all_f8_vals_in_f16)
|
||||
]
|
||||
|
||||
min_error = torch.min(
|
||||
torch.abs(
|
||||
f16_input.reshape((-1, 1))
|
||||
- all_finite_f8_vals_in_f16.reshape((1, -1))
|
||||
),
|
||||
dim=1,
|
||||
)[0]
|
||||
# 1.9375 is float8 max
|
||||
mismatch = torch.logical_and(
|
||||
abs_error != min_error, torch.logical_and(torch.isfinite(f16_input), torch.abs(f16_input) < 1.9375)
|
||||
)
|
||||
assert torch.all(
|
||||
torch.logical_not(mismatch)
|
||||
), f"f16_input[mismatch]={f16_input[mismatch]} f16_output[mismatch]={f16_output[mismatch]} abs_error[mismatch]={abs_error[mismatch]} min_error[mismatch]={min_error[mismatch]}"
|
||||
ref_fp8 = torch.from_numpy(ref_fp8).cuda()
|
||||
ref_fp16 = convert_float_to_float32(ref_fp8, in_dtype)
|
||||
assert torch.all(tri_fp16[~is_subnormal] == ref_fp16[~is_subnormal])
|
||||
|
||||
ref_fp8 = torch.empty_like(tri_fp16, dtype=torch.int8)
|
||||
copy_kernel[(1,)](tri_fp16, triton.reinterpret(ref_fp8, in_dtype), tri_fp16.shape[0], BLOCK_SIZE=1024)
|
||||
assert torch.all(tri_fp8 == ref_fp8)
|
||||
|
||||
# ---------------
|
||||
# test reduce
|
||||
@@ -1160,7 +1192,7 @@ def test_reduce2d(op, dtype_str, shape, axis, device='cuda'):
|
||||
@pytest.mark.parametrize("dtype_str, shape, perm",
|
||||
[(dtype, shape, perm)
|
||||
# TODO: bfloat16
|
||||
for dtype in ['float16', 'float32']
|
||||
for dtype in ['float8e4b15', 'float16', 'float32']
|
||||
for shape in [(64, 64), (128, 128)]
|
||||
for perm in [(1, 0)]])
|
||||
def test_permute(dtype_str, shape, perm, device='cuda'):
|
||||
@@ -1192,7 +1224,13 @@ def test_permute(dtype_str, shape, perm, device='cuda'):
|
||||
z_tri_contiguous, z_tri_contiguous.stride(0), z_tri_contiguous.stride(1),
|
||||
BLOCK_M=shape[0], BLOCK_N=shape[1])
|
||||
# numpy result
|
||||
z_ref = x.transpose(*perm)
|
||||
if dtype_str == 'float8e4b15':
|
||||
ty = tl.float8e4b15
|
||||
z_ref = serialize_fp8(deserialize_fp8(x, ty).T.copy(), ty)
|
||||
z_tri = z_tri.base
|
||||
z_tri_contiguous = z_tri_contiguous.base
|
||||
else:
|
||||
z_ref = x.transpose(*perm)
|
||||
# compare
|
||||
np.testing.assert_allclose(to_numpy(z_tri), z_ref)
|
||||
np.testing.assert_allclose(to_numpy(z_tri_contiguous), z_ref)
|
||||
|
||||
Reference in New Issue
Block a user