[BACKEND] Fix wrong conversion from float8e5m2 <> bfloat16 (#1391)

The exponent compensation should be 0x3800 (112) instead of 0x3000 (96):
the bias difference between bfloat16 (127) and float8e5m2 (15) is 112,
which sits at bit 7 of the bfloat16 pattern, i.e. 112 << 7 = 0x3800.
Also add a rounding bit so the float16 conversion rounds to the nearest
float8e5m2.

Co-authored-by: Philippe Tillet <phil@openai.com>
Author: xndcn
Date: 2023-03-24 12:42:08 +08:00
Committed by: GitHub
Parent: c9f47d9094
Commit: ff1d0377e0
2 changed files with 79 additions and 27 deletions
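
The arithmetic behind the exponent fix can be checked in plain Python. The sketch below is illustrative only (not part of the commit; the helper names are ours) and handles normal numbers only. It mirrors the PTX in the first hunk: strip the sign, shift the fp8 bits into bf16 position, add the exponent compensation, restore the sign. The old constant 96 under-biased the exponent by 16, making every converted value 2**16 times too small:

    import struct

    def fp8e5m2_to_bf16_bits(x, comp):
        # mirrors the PTX sequence: strip sign, align, re-bias, restore sign
        # (normal numbers only)
        sign = (x & 0x80) << 8            # sign: bit 7 -> bit 15
        nosign = ((x & 0x7f) << 8) >> 3   # align exponent/mantissa with bf16
        return sign | (nosign + (comp << 7))

    def bf16_to_float(bits):
        return struct.unpack(">f", struct.pack(">I", bits << 16))[0]

    one = 0x3C  # 0b0_01111_00 = 1.0 in float8e5m2 (exponent bias 15)
    print(bf16_to_float(fp8e5m2_to_bf16_bits(one, 112)))  # 1.0
    print(bf16_to_float(fp8e5m2_to_bf16_bits(one, 96)))   # 2**-16: the old bug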


@@ -144,19 +144,21 @@ struct FpToFpOpConversion
 convertFp8E5M2x4ToBf16x4(Location loc, ConversionPatternRewriter &rewriter,
 const Value &v0, const Value &v1, const Value &v2,
 const Value &v3) {
-auto *ptxAsm = "{ \n"
-".reg .b32 a<2>, b<2>; \n"
-"prmt.b32 a0, 0, $2, 0x5140; \n"
-"prmt.b32 a1, 0, $2, 0x7362; \n"
-"lop3.b32 b0, a0, 0x7fff7fff, 0, 0xc0; \n"
-"lop3.b32 b1, a1, 0x7fff7fff, 0, 0xc0; \n"
-"shr.b32 b0, b0, 3; \n"
-"shr.b32 b1, b1, 3; \n"
-"add.u32 b0, b0, 0x30003000; \n"
-"add.u32 b1, b1, 0x30003000; \n"
-"lop3.b32 $0, b0, 0x80008000, a0, 0xf8; \n"
-"lop3.b32 $1, b1, 0x80008000, a1, 0xf8; \n"
-"}";
+auto *ptxAsm = // WARN: subnormals (0bs00000xx) are not handled
+"{ \n"
+".reg .b32 a<2>, b<2>; \n" // if input = 0xf1f2f3f4
+"prmt.b32 a0, 0, $2, 0x5140; \n" // a0 = 0xf300f400
+"prmt.b32 a1, 0, $2, 0x7362; \n" // a1 = 0xf100f200
+"lop3.b32 b0, a0, 0x7fff7fff, 0, 0xc0; \n" // b0 = a0 & 0x7fff7fff
+"lop3.b32 b1, a1, 0x7fff7fff, 0, 0xc0; \n" // (strip sign)
+"shr.b32 b0, b0, 3; \n" // b0 >>= 3
+"shr.b32 b1, b1, 3; \n" // shift into bf16 position
+"add.u32 b0, b0, 0x38003800; \n" // b0.exp += 2**7 - 2**4
+// exponent compensation = 112
+"add.u32 b1, b1, 0x38003800; \n" // b1 += 112<<7 | 112<<7<<16
+"lop3.b32 $0, b0, 0x80008000, a0, 0xf8; \n" // out0 = b0|(0x80008000&a0)
+"lop3.b32 $1, b1, 0x80008000, a1, 0xf8; \n" // (restore sign)
+"}";
 return convertFp8x4ToBf16x4(loc, rewriter, ptxAsm, v0, v1, v2, v3);
 };
@@ -222,9 +224,17 @@ struct FpToFpOpConversion
 convertFp16x4ToFp8E5M2x4(Location loc, ConversionPatternRewriter &rewriter,
 const Value &v0, const Value &v1, const Value &v2,
 const Value &v3) {
-auto *ptxAsm = "{ \n"
-"prmt.b32 $0, $1, $2, 0x7531; \n\t"
-"}";
+auto *ptxAsm =
+"{ \n"
+".reg .b32 a<2>; \n"
+"and.b32 a0, $1, 0x7fff7fff; \n" // a0 &= 0x7fff7fff
+"and.b32 a1, $2, 0x7fff7fff; \n" // (strip sign)
+"add.u32 a0, a0, 0x00800080; \n" // a0 += 0x00800080
+"add.u32 a1, a1, 0x00800080; \n" // (round to nearest)
+"lop3.b32 a0, $1, 0x80008000, a0, 0xea; \n" // a0 = a0|(0x80008000&in0)
+"lop3.b32 a1, $2, 0x80008000, a1, 0xea; \n" // (restore sign)
+"prmt.b32 $0, a0, a1, 0x7531; \n\t" // output = a1a0
+"}";
 return convertFp16x4ToFp8x4(loc, rewriter, ptxAsm, v0, v1, v2, v3);
 }
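
float16 and float8e5m2 share the same 5-bit exponent (bias 15), so this conversion only shortens the mantissa from 10 bits to 2. The new code adds 0x0080, half of the dropped 8-bit range, to the sign-stripped magnitude before taking each lane's high byte, which turns plain truncation into round-to-nearest (ties away from zero). A scalar Python sketch of one lane (illustrative only, not part of the commit; NaN and overflow at the top of the range are not handled):

    def fp16_bits_to_fp8e5m2(h):
        # strip sign, round the magnitude, keep the high byte, restore sign
        sign = h & 0x8000
        mag = (h & 0x7fff) + 0x0080  # half of the 8 dropped mantissa bits
        return ((sign | mag) >> 8) & 0xff

    print(hex(fp16_bits_to_fp8e5m2(0x3C00)))  # 0x3c: 1.0 stays 1.0
    print(hex(fp16_bits_to_fp8e5m2(0x3C80)))  # 0x3d: 1.125 rounds up to 1.25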
@@ -336,12 +346,53 @@ struct FpToFpOpConversion
 return convertBf16x4ToFp8x4(loc, rewriter, ptxAsm, v0, v1, v2, v3);
 };
-// TODO:
-// static SmallVector<Value>
-// convertBf16x4ToFp8E5M2x4(Location loc, ConversionPatternRewriter &rewriter,
-//                          const Value &v0, const Value &v1, const Value &v2,
-//                          const Value &v3) {
-// }
+static SmallVector<Value>
+convertBf16x4ToFp8E5M2x4(Location loc, ConversionPatternRewriter &rewriter,
+const Value &v0, const Value &v1, const Value &v2,
+const Value &v3) {
+auto *ptxAsm = // bf16 is first clamped to the fp8 min/max
+"{ \n" // bf16 = fp8>>3 + 112<<7
+".reg .u32 sign, sign<2>, nosign, nosign<2>; \n" // fp8_min = 0b00000000
+".reg .u32 fp8_min, fp8_max, rn_; \n" // fp8_max = 0b11111111
+"mov.u32 fp8_min, 0x38003800; \n" // so bf16_min = 0x3800
+"mov.u32 fp8_max, 0x57e057e0; \n" // so bf16_max = 0x57e0
+"mov.u32 rn_, 0x00100010; \n" // round to nearest
+"and.b32 sign0, $1, 0x80008000; \n" // sign0 = in0 & 0x80008000
+"and.b32 sign1, $2, 0x80008000; \n" // (save sign)
+"prmt.b32 sign, sign0, sign1, 0x7531; \n"
+"and.b32 nosign0, $1, 0x7fff7fff; \n" // nosign0 = in0 & 0x7fff7fff
+"and.b32 nosign1, $2, 0x7fff7fff; \n" // (strip sign)
+// nosign = clamp(nosign, min, max)
+".reg .u32 nosign_0_<2>, nosign_1_<2>; \n"
+"and.b32 nosign_0_0, nosign0, 0xffff0000; \n"
+"max.u32 nosign_0_0, nosign_0_0, 0x38000000; \n"
+"min.u32 nosign_0_0, nosign_0_0, 0x57e00000; \n"
+"and.b32 nosign_0_1, nosign0, 0x0000ffff; \n"
+"max.u32 nosign_0_1, nosign_0_1, 0x3800; \n"
+"min.u32 nosign_0_1, nosign_0_1, 0x57e0; \n"
+"or.b32 nosign0, nosign_0_0, nosign_0_1; \n"
+"and.b32 nosign_1_0, nosign1, 0xffff0000; \n"
+"max.u32 nosign_1_0, nosign_1_0, 0x38000000; \n"
+"min.u32 nosign_1_0, nosign_1_0, 0x57e00000; \n"
+"and.b32 nosign_1_1, nosign1, 0x0000ffff; \n"
+"max.u32 nosign_1_1, nosign_1_1, 0x3800; \n"
+"min.u32 nosign_1_1, nosign_1_1, 0x57e0; \n"
+"or.b32 nosign1, nosign_1_0, nosign_1_1; \n"
+"add.u32 nosign0, nosign0, rn_; \n" // nosign0 += rn_
+"add.u32 nosign1, nosign1, rn_; \n" // (round to nearest)
+"sub.u32 nosign0, nosign0, 0x38003800; \n" // nosign0 -= 0x38003800
+"sub.u32 nosign1, nosign1, 0x38003800; \n" // (compensate the offset)
+"shl.b32 nosign0, nosign0, 3; \n" // nosign0 <<= 3
+"shl.b32 nosign1, nosign1, 3; \n" // shift into fp8e5 position
+"prmt.b32 nosign, nosign0, nosign1, 0x7531; \n" // nosign0 = 0xf100f200
+// nosign1 = 0xf300f400
+// nosign = 0xf3f4f1f2
+"or.b32 $0, nosign, sign; \n" // restore sign
+"}";
+return convertBf16x4ToFp8x4(loc, rewriter, ptxAsm, v0, v1, v2, v3);
+}
 /* ------------------ */
 // FP8 -> FP32
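
The new bf16 path above works per 16-bit lane: clamp the sign-stripped value into the window [0x3800, 0x57e0] that the code maps to fp8 0b00000000..0b11111111, add rn_ = 0x0010 (half of the 5 dropped mantissa bits) to round, subtract the 0x3800 exponent offset, and shift left by 3 so the fp8 byte lands in the lane's high byte, which prmt then gathers. A scalar Python sketch of one lane (illustrative only, not part of the commit):

    def bf16_bits_to_fp8e5m2(b):
        # clamp, round, de-bias, and reposition one bf16 value, as the PTX does
        sign = (b & 0x8000) >> 8
        mag = max(0x3800, min(0x57e0, b & 0x7fff))  # clamp to the fp8 window
        mag = ((mag + 0x0010 - 0x3800) << 3) >> 8   # round, de-bias, reposition
        return (sign | mag) & 0xff

    print(hex(bf16_bits_to_fp8e5m2(0x3F80)))  # 0x3c: 1.0
    print(hex(bf16_bits_to_fp8e5m2(0x3F90)))  # 0x3d: 1.125 rounds up to 1.25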
@@ -479,8 +530,7 @@ struct FpToFpOpConversion
 {{F8E5M2TyID, BF16TyID}, convertFp8E5M2x4ToBf16x4},
 // BF16 -> F8
 {{BF16TyID, F8E4M3TyID}, convertBf16x4ToFp8E4M3x4},
-// TODO:
-// {{BF16TyID, F8E5M2TyID}, convertBf16x4ToFp8E5M2x4},
+{{BF16TyID, F8E5M2TyID}, convertBf16x4ToFp8E5M2x4},
 // F8 -> F32
 {{F8E4M3TyID, F32TyID}, convertFp8E4M3x4ToFp32x4},
 {{F8E5M2TyID, F32TyID}, convertFp8E5M2x4ToFp32x4},


@@ -932,7 +932,7 @@ def test_convert_float16_to_float32(in_dtype):
 @pytest.mark.parametrize("in_dtype", [tl.float8e4, tl.float8e5])
-@pytest.mark.parametrize("out_dtype", [torch.float16, torch.float32])
+@pytest.mark.parametrize("out_dtype", [torch.float16, torch.bfloat16, torch.float32])
 def test_f8_xf16_roundtrip(in_dtype, out_dtype):
     """Tests that converting an f8 to f16 and back to f8 doesn't change its value"""
     check_type_supported(out_dtype)
@@ -970,7 +970,7 @@ def test_f8_xf16_roundtrip(in_dtype, out_dtype):
     assert torch.all(f8_tensor == f8_output_tensor)
-@pytest.mark.parametrize("in_dtype", [tl.float8e4])
+@pytest.mark.parametrize("in_dtype", [tl.float8e4, tl.float8e5])
 @pytest.mark.parametrize("out_dtype", [torch.float16, torch.bfloat16])
 def test_f16_to_f8_rounding(in_dtype, out_dtype):
     """Takes all float16s, converts them to float8 and back to float16. Checks that the absolute
@@ -1017,8 +1017,10 @@ def test_f16_to_f8_rounding(in_dtype, out_dtype):
     # WARN: only normalized numbers are handled
     f8_normal_min = 1 << in_dtype.fp_mantissa_width  # 0b00001000 for float8e4
-    f8_normal_max = 0b01111110
+    f8_normal_max = 0b01111110 if in_dtype == tl.float8e4 else 0b01111011
+    f16_min, f16_max, f16_max_minus_1 = convert_float_to_float32(torch.tensor([f8_normal_min, f8_normal_max, f8_normal_max - 1], dtype=torch.int8), in_dtype)
     assert torch.all(torch.isfinite(f16_min))
     assert torch.all(torch.isfinite(f16_max))
+    thres_error = f16_max - f16_max_minus_1
     mismatch = torch.logical_and(
+        torch.logical_or(abs_error != min_error, abs_error > thres_error), torch.logical_and(torch.isfinite(f16_input), torch.logical_and(torch.abs(f16_input) <= f16_max, torch.abs(f16_input) >= f16_min))