[BACKEND] Fix wrong conversion from float8e4m3 <> float16 (#1375)

after shifting into position, the exponent compensation offset must not be forgotten;
also restore some explanatory comments from `legacy_backend`
This commit is contained in:
xndcn
2023-03-21 12:45:25 +08:00
committed by GitHub
parent e281bd9fe9
commit 84ffefc368
2 changed files with 93 additions and 25 deletions

View File

@@ -51,17 +51,21 @@ struct FpToFpOpConversion
convertFp8E4M3x4ToFp16x4(Location loc, ConversionPatternRewriter &rewriter,
const Value &v0, const Value &v1, const Value &v2,
const Value &v3) {
auto *ptxAsm = "{ \n"
".reg .b32 a<2>, b<2>; \n"
"prmt.b32 a0, 0, $2, 0x5040; \n"
"prmt.b32 a1, 0, $2, 0x7060; \n"
"lop3.b32 b0, a0, 0x7fff7fff, 0, 0xc0; \n"
"lop3.b32 b1, a1, 0x7fff7fff, 0, 0xc0; \n"
"shr.b32 b0, b0, 1; \n"
"shr.b32 b1, b1, 1; \n"
"lop3.b32 $0, b0, 0x80008000, a0, 0xf8; \n"
"lop3.b32 $1, b1, 0x80008000, a1, 0xf8; \n"
"}";
auto *ptxAsm = // WARN: subnormal (0bs0000xxx) are not handled
"{ \n"
".reg .b32 a<2>, b<2>; \n" // if input = 0xf1f2f3f4
"prmt.b32 a0, 0, $2, 0x5040; \n" // a0 = 0xf300f400
"prmt.b32 a1, 0, $2, 0x7060; \n" // a1 = 0xf100f200
"lop3.b32 b0, a0, 0x7fff7fff, 0, 0xc0; \n" // b0 = a0 & 0x7fff7fff
"lop3.b32 b1, a1, 0x7fff7fff, 0, 0xc0; \n" // (strip sign)
"shr.b32 b0, b0, 1; \n" // b0 >>= 1
"shr.b32 b1, b1, 1; \n" // shift into fp16 position
"add.u32 b0, b0, 0x20002000; \n" // b0.exp += 2**4-2**3
// exponent compensate = 8
"add.u32 b1, b1, 0x20002000; \n" // b1 += 8<<10 | 8<<10<<16
"lop3.b32 $0, b0, 0x80008000, a0, 0xf8; \n" // out0 = b0|(0x80008000&a0)
"lop3.b32 $1, b1, 0x80008000, a1, 0xf8; \n" // (restore sign)
"}";
return convertFp8x4ToFp16x4(loc, rewriter, ptxAsm, v0, v1, v2, v3);
}
@@ -69,6 +73,7 @@ struct FpToFpOpConversion
convertFp8E5M2x4ToFp16x4(Location loc, ConversionPatternRewriter &rewriter,
const Value &v0, const Value &v1, const Value &v2,
const Value &v3) {
// exponent bias of Fp8E5M2 and Fp16 are the same
auto *ptxAsm = "{ \n"
"prmt.b32 $0, 0, $2, 0x5140; \n\t"
"prmt.b32 $1, 0, $2, 0x7362; \n\t"
@@ -193,18 +198,23 @@ struct FpToFpOpConversion
convertFp16x4ToFp8E4M3x4(Location loc, ConversionPatternRewriter &rewriter,
const Value &v0, const Value &v1, const Value &v2,
const Value &v3) {
auto *ptxAsm = "{ \n"
".reg .b32 a<2>, b<2>; \n"
"shl.b32 a0, $1, 1; \n"
"shl.b32 a1, $2, 1; \n"
"lop3.b32 a0, a0, 0x7fff7fff, 0, 0xc0; \n"
"lop3.b32 a1, a1, 0x7fff7fff, 0, 0xc0; \n"
"add.u32 a0, a0, 0x00800080; \n"
"add.u32 a1, a1, 0x00800080; \n"
"lop3.b32 b0, $1, 0x80008000, a0, 0xea; \n"
"lop3.b32 b1, $2, 0x80008000, a1, 0xea; \n"
"prmt.b32 $0, b0, b1, 0x7531; \n"
"}";
auto *ptxAsm = // WARN: subnormal Fp8s are not handled
"{ \n"
".reg .b32 a<2>, b<2>; \n" // see Fp8E4M3x4ToFp16x4
"sub.u32 a0, $1, 0x20002000; \n" // a0 = input0 - 0x20002000
// (compensate offset)
"sub.u32 a1, $2, 0x20002000; \n" // a1 = input1 - 0x20002000
// (8 << 10 | 8 << 10 << 16)
"shl.b32 a0, a0, 1; \n" // a0 <<= 1
"shl.b32 a1, a1, 1; \n" // shift into fp8e4 position
"lop3.b32 a0, a0, 0x7fff7fff, 0, 0xc0; \n" // a0 &= 0x7fff7fff
"lop3.b32 a1, a1, 0x7fff7fff, 0, 0xc0; \n" // (strip sign)
"add.u32 a0, a0, 0x00800080; \n" // a0 += 0x00800080
"add.u32 a1, a1, 0x00800080; \n" // (round to nearest)
"lop3.b32 b0, $1, 0x80008000, a0, 0xea; \n" // b0 = a0|(0x80008000&in0)
"lop3.b32 b1, $2, 0x80008000, a1, 0xea; \n" // (restore sign)
"prmt.b32 $0, b0, b1, 0x7531; \n" // output = b1b0
"}";
return convertFp16x4ToFp8x4(loc, rewriter, ptxAsm, v0, v1, v2, v3);
}

View File

@@ -885,6 +885,52 @@ def test_load_store_same_ptr():
assert torch.all(x == 2)
def convert_float_to_float32(fp: torch.tensor, dtype=None):
    """Reference-decode a float tensor to float32 by explicit IEEE-754 bit manipulation.

    `fp` is reinterpreted as raw integer bits of width `dtype.primitive_bitwidth`
    and decoded field-by-field (sign / exponent / mantissa), so it works for
    formats torch cannot natively cast (e.g. the triton float8 types).

    :param fp: input tensor; its bits are reinterpreted, not value-converted
    :param dtype: a triton dtype (e.g. tl.float8e4) describing the bit layout;
                  if falsy, derived from fp's torch dtype via torch_dtype_name
    :return: float32 tensor with the decoded values
    """
    if not dtype:
        # Infer the triton dtype descriptor from the tensor's own torch dtype.
        dtype = getattr(tl, torch_dtype_name(fp.dtype))

    # Reinterpret the raw bits as a same-width signed integer tensor.
    fp = fp.view(getattr(torch, f"int{dtype.primitive_bitwidth}"))
    exp_width = dtype.primitive_bitwidth - dtype.fp_mantissa_width - 1
    exp_bias = 2 ** (exp_width - 1) - 1
    # Extract the three IEEE-754 fields from the integer bit pattern.
    sign = ((fp >> (dtype.primitive_bitwidth - 1)) & 0x01).int()
    exp = ((fp >> dtype.fp_mantissa_width) & ((1 << exp_width) - 1)).int()
    frac = (fp & ((1 << dtype.fp_mantissa_width) - 1)).int()

    output = torch.where(exp == 0,
                         # subnormal: implicit leading 0, fixed exponent 1 - bias
                         ((-1.0) ** sign) * (2.0 ** (1 - exp_bias)) * (frac / (2.0 ** dtype.fp_mantissa_width)),
                         # normal: implicit leading 1, biased exponent
                         ((-1.0) ** sign) * (2.0 ** (exp - exp_bias)) * (1.0 + frac / (2.0 ** dtype.fp_mantissa_width))).float()

    # All-ones exponent field of float32, pre-shifted into position.
    extended_exp = ((1 << (tl.float32.primitive_bitwidth - tl.float32.fp_mantissa_width - 1)) - 1) << tl.float32.fp_mantissa_width
    # special cases, exp is 0b11..1
    if dtype == tl.float8e4:
        # float8e4m3 does not have infinities: both all-ones patterns are NaN
        output[fp == torch.tensor(0b01111111, dtype=torch.int8)] = torch.nan
        output[fp == torch.tensor(0b11111111, dtype=torch.int8)] = torch.nan
    else:
        # Rebuild inf/NaN bit-exactly in float32 (mantissa left-aligned into
        # the wider float32 mantissa field, sign preserved).
        output = torch.where(exp == (1 << exp_width) - 1,
                             ((sign << (tl.float32.primitive_bitwidth - 1)) | extended_exp | (frac << (tl.float32.fp_mantissa_width - dtype.fp_mantissa_width))).view(torch.float32),
                             output)
    return output
@pytest.mark.parametrize("in_dtype", [torch.float16, torch.bfloat16])
def test_convert_float16_to_float32(in_dtype):
    """Tests that check convert_float_to_float32 function"""
    check_type_supported(in_dtype)

    # Enumerate every possible 16-bit pattern and reinterpret it as the
    # float type under test, then decode with the reference converter.
    all_bits = torch.tensor(range(-int(2 ** (16 - 1)), int(2 ** (16 - 1))), dtype=torch.int16)
    src = all_bits.view(in_dtype)
    dst = convert_float_to_float32(src)

    nan_mask = src.isnan()
    inf_mask = src.isinf()
    finite_mask = torch.logical_not(torch.logical_or(nan_mask, inf_mask))

    # NaNs decode to NaNs, infinities to infinities, and every finite
    # value round-trips exactly (float16/bfloat16 embed losslessly in float32).
    assert torch.all(dst[nan_mask].isnan())
    assert torch.all(dst[inf_mask].isinf())
    assert torch.all(src[finite_mask] == dst[finite_mask])
@pytest.mark.parametrize("in_dtype", [tl.float8e4, tl.float8e5])
@pytest.mark.parametrize("out_dtype", [torch.float16, torch.float32])
def test_f8_xf16_roundtrip(in_dtype, out_dtype):
@@ -909,6 +955,14 @@ def test_f8_xf16_roundtrip(in_dtype, out_dtype):
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
copy_kernel[grid](f8, xf16, n_elements, BLOCK_SIZE=1024)
# exponent_mask = 0b01111100 for float8e5
# exponent_mask = 0b01111000 for float8e4
exponent_mask = 0b01111111 ^ ((1 << in_dtype.fp_mantissa_width) - 1)
normal = torch.logical_and((f8_tensor & exponent_mask) != 0, (f8_tensor & exponent_mask) != exponent_mask)
ref16 = convert_float_to_float32(f8_tensor, in_dtype)
# WARN: currently only normal float8s are handled
assert torch.all(xf16[normal] == ref16[normal])
f8_output_tensor = torch.empty_like(xf16, dtype=torch.int8)
f8_output = triton.reinterpret(f8_output_tensor, in_dtype)
copy_kernel[grid](xf16, f8_output, n_elements, BLOCK_SIZE=1024)
@@ -965,9 +1019,13 @@ def test_f16_to_f8_rounding(in_dtype):
),
dim=1,
)[0]
# 1.9375 is float8 max
# WARN: only normalized numbers are handled
f8_normal_min = 1 << in_dtype.fp_mantissa_width # 0b00001000 for float8e4
f8_normal_max = 0b01111110
f16_min, f16_max = convert_float_to_float32(torch.tensor([f8_normal_min, f8_normal_max], dtype=torch.int8), in_dtype)
mismatch = torch.logical_and(
abs_error != min_error, torch.logical_and(torch.isfinite(f16_input), torch.abs(f16_input) < 1.9375)
abs_error != min_error, torch.logical_and(torch.isfinite(f16_input), torch.logical_and(torch.abs(f16_input) <= f16_max, torch.abs(f16_input) >= f16_min))
)
assert torch.all(
torch.logical_not(mismatch)