mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[BACKEND] Fix wrong conversion from float8e5m2 <> bfloat16 (#1391)
The exponent compensation should be 0x3800 (112) instead of 0x3000 (96). Also add a mantissa bit for the float16 conversion so that it rounds to the nearest float8e5m2. Co-authored-by: Philippe Tillet <phil@openai.com>
This commit is contained in:
@@ -144,19 +144,21 @@ struct FpToFpOpConversion
|
||||
convertFp8E5M2x4ToBf16x4(Location loc, ConversionPatternRewriter &rewriter,
|
||||
const Value &v0, const Value &v1, const Value &v2,
|
||||
const Value &v3) {
|
||||
auto *ptxAsm = "{ \n"
|
||||
".reg .b32 a<2>, b<2>; \n"
|
||||
"prmt.b32 a0, 0, $2, 0x5140; \n"
|
||||
"prmt.b32 a1, 0, $2, 0x7362; \n"
|
||||
"lop3.b32 b0, a0, 0x7fff7fff, 0, 0xc0; \n"
|
||||
"lop3.b32 b1, a1, 0x7fff7fff, 0, 0xc0; \n"
|
||||
"shr.b32 b0, b0, 3; \n"
|
||||
"shr.b32 b1, b1, 3; \n"
|
||||
"add.u32 b0, b0, 0x30003000; \n"
|
||||
"add.u32 b1, b1, 0x30003000; \n"
|
||||
"lop3.b32 $0, b0, 0x80008000, a0, 0xf8; \n"
|
||||
"lop3.b32 $1, b1, 0x80008000, a1, 0xf8; \n"
|
||||
"}";
|
||||
auto *ptxAsm = // WARN: subnormals (0bs00000xx) are not handled
|
||||
"{ \n"
|
||||
".reg .b32 a<2>, b<2>; \n" // if input = 0xf1f2f3f4
|
||||
"prmt.b32 a0, 0, $2, 0x5140; \n" // a0 = 0xf300f400
|
||||
"prmt.b32 a1, 0, $2, 0x7362; \n" // a1 = 0xf100f200
|
||||
"lop3.b32 b0, a0, 0x7fff7fff, 0, 0xc0; \n" // b0 = a0 & 0x7fff7fff
|
||||
"lop3.b32 b1, a1, 0x7fff7fff, 0, 0xc0; \n" // (strip sign)
|
||||
"shr.b32 b0, b0, 3; \n" // b0 >>= 3
|
||||
"shr.b32 b1, b1, 3; \n" // shift into bf16 position
|
||||
"add.u32 b0, b0, 0x38003800; \n" // b0.exp += 2**7-2**4
|
||||
// exponent compensate = 112
|
||||
"add.u32 b1, b1, 0x38003800; \n" // b1 += 112<<7 | 112<<7<<16
|
||||
"lop3.b32 $0, b0, 0x80008000, a0, 0xf8; \n" // out0 = b0|(0x80008000&a0)
|
||||
"lop3.b32 $1, b1, 0x80008000, a1, 0xf8; \n" // (restore sign)
|
||||
"}";
|
||||
return convertFp8x4ToBf16x4(loc, rewriter, ptxAsm, v0, v1, v2, v3);
|
||||
};
|
||||
|
||||
@@ -222,9 +224,17 @@ struct FpToFpOpConversion
|
||||
convertFp16x4ToFp8E5M2x4(Location loc, ConversionPatternRewriter &rewriter,
|
||||
const Value &v0, const Value &v1, const Value &v2,
|
||||
const Value &v3) {
|
||||
auto *ptxAsm = "{ \n"
|
||||
"prmt.b32 $0, $1, $2, 0x7531; \n\t"
|
||||
"}";
|
||||
auto *ptxAsm =
|
||||
"{ \n"
|
||||
".reg .b32 a<2>; \n"
|
||||
"and.b32 a0, $1, 0x7fff7fff; \n" // a0 &= 0x7fff7fff
|
||||
"and.b32 a1, $2, 0x7fff7fff; \n" // (strip sign)
|
||||
"add.u32 a0, a0, 0x00800080; \n" // a0 += 0x00800080
|
||||
"add.u32 a1, a1, 0x00800080; \n" // (round to nearest)
|
||||
"lop3.b32 a0, $1, 0x80008000, a0, 0xea; \n" // a0 = a0|(0x80008000&in0)
|
||||
"lop3.b32 a1, $2, 0x80008000, a1, 0xea; \n" // (restore sign)
|
||||
"prmt.b32 $0, a0, a1, 0x7531; \n\t" // output = a1a0
|
||||
"}";
|
||||
return convertFp16x4ToFp8x4(loc, rewriter, ptxAsm, v0, v1, v2, v3);
|
||||
}
|
||||
|
||||
@@ -336,12 +346,53 @@ struct FpToFpOpConversion
|
||||
return convertBf16x4ToFp8x4(loc, rewriter, ptxAsm, v0, v1, v2, v3);
|
||||
};
|
||||
|
||||
// TODO:
|
||||
// static SmallVector<Value>
|
||||
// convertBf16x4ToFp8E5M2x4(Location loc, ConversionPatternRewriter &rewriter,
|
||||
// const Value &v0, const Value &v1, const Value &v2,
|
||||
// const Value &v3) {
|
||||
// }
|
||||
static SmallVector<Value>
|
||||
convertBf16x4ToFp8E5M2x4(Location loc, ConversionPatternRewriter &rewriter,
|
||||
const Value &v0, const Value &v1, const Value &v2,
|
||||
const Value &v3) {
|
||||
auto *ptxAsm = // bf16 is clamped firstly to fp8 min/max
|
||||
"{ \n" // bf16=fp8>>3 + 112<<7
|
||||
".reg .u32 sign, sign<2>, nosign, nosign<2>; \n" // fp8_min = 0b00000000
|
||||
".reg .u32 fp8_min, fp8_max, rn_; \n" // fp8_max = 0b11111111
|
||||
"mov.u32 fp8_min, 0x38003800; \n" // so bf16_min = 0x3800
|
||||
"mov.u32 fp8_max, 0x57e057e0; \n" // so bf16_max = 0x57e0
|
||||
"mov.u32 rn_, 0x00100010; \n" // round to nearest
|
||||
"and.b32 sign0, $1, 0x80008000; \n" // sign0=in0&0x80008000
|
||||
"and.b32 sign1, $2, 0x80008000; \n" // (store sign)
|
||||
"prmt.b32 sign, sign0, sign1, 0x7531; \n"
|
||||
"and.b32 nosign0, $1, 0x7fff7fff; \n" // nosign0=in0&0x7fff7fff
|
||||
"and.b32 nosign1, $2, 0x7fff7fff; \n" // (strip sign)
|
||||
|
||||
// nosign = clamp(nosign, min, max)
|
||||
".reg .u32 nosign_0_<2>, nosign_1_<2>; \n"
|
||||
"and.b32 nosign_0_0, nosign0, 0xffff0000; \n"
|
||||
"max.u32 nosign_0_0, nosign_0_0, 0x38000000; \n"
|
||||
"min.u32 nosign_0_0, nosign_0_0, 0x57e00000; \n"
|
||||
"and.b32 nosign_0_1, nosign0, 0x0000ffff; \n"
|
||||
"max.u32 nosign_0_1, nosign_0_1, 0x3800; \n"
|
||||
"min.u32 nosign_0_1, nosign_0_1, 0x57e0; \n"
|
||||
"or.b32 nosign0, nosign_0_0, nosign_0_1; \n"
|
||||
"and.b32 nosign_1_0, nosign1, 0xffff0000; \n"
|
||||
"max.u32 nosign_1_0, nosign_1_0, 0x38000000; \n"
|
||||
"min.u32 nosign_1_0, nosign_1_0, 0x57e00000; \n"
|
||||
"and.b32 nosign_1_1, nosign1, 0x0000ffff; \n"
|
||||
"max.u32 nosign_1_1, nosign_1_1, 0x3800; \n"
|
||||
"min.u32 nosign_1_1, nosign_1_1, 0x57e0; \n"
|
||||
"or.b32 nosign1, nosign_1_0, nosign_1_1; \n"
|
||||
|
||||
"add.u32 nosign0, nosign0, rn_; \n" // nosign0 += rn_
|
||||
"add.u32 nosign1, nosign1, rn_; \n" // (round to nearest)
|
||||
"sub.u32 nosign0, nosign0, 0x38003800; \n" // nosign0-=0x38003800
|
||||
"sub.u32 nosign1, nosign1, 0x38003800; \n" // (compensate offset)
|
||||
"shl.b32 nosign0, nosign0, 3; \n" // nosign0 <<= 3
|
||||
"shl.b32 nosign1, nosign1, 3; \n" // shift into fp8e5 position
|
||||
"prmt.b32 nosign, nosign0, nosign1, 0x7531; \n" // nosign0 = 0xf100f200
|
||||
// nosign1 = 0xf300f400
|
||||
// nosign = 0xf3f4f1f2
|
||||
"or.b32 $0, nosign, sign; \n" // restore sign
|
||||
"}";
|
||||
return convertBf16x4ToFp8x4(loc, rewriter, ptxAsm, v0, v1, v2, v3);
|
||||
}
|
||||
|
||||
/* ------------------ */
|
||||
// FP8 -> FP32
|
||||
@@ -479,8 +530,7 @@ struct FpToFpOpConversion
|
||||
{{F8E5M2TyID, BF16TyID}, convertFp8E5M2x4ToBf16x4},
|
||||
// BF16 -> F8
|
||||
{{BF16TyID, F8E4M3TyID}, convertBf16x4ToFp8E4M3x4},
|
||||
// TODO:
|
||||
// {{BF16TyID, F8E5M2TyID}, convertBf16x4ToFp8E5M2x4},
|
||||
{{BF16TyID, F8E5M2TyID}, convertBf16x4ToFp8E5M2x4},
|
||||
// F8 -> F32
|
||||
{{F8E4M3TyID, F32TyID}, convertFp8E4M3x4ToFp32x4},
|
||||
{{F8E5M2TyID, F32TyID}, convertFp8E5M2x4ToFp32x4},
|
||||
|
||||
@@ -932,7 +932,7 @@ def test_convert_float16_to_float32(in_dtype):
|
||||
|
||||
|
||||
@pytest.mark.parametrize("in_dtype", [tl.float8e4, tl.float8e5])
|
||||
@pytest.mark.parametrize("out_dtype", [torch.float16, torch.float32])
|
||||
@pytest.mark.parametrize("out_dtype", [torch.float16, torch.bfloat16, torch.float32])
|
||||
def test_f8_xf16_roundtrip(in_dtype, out_dtype):
|
||||
"""Tests that converting an f8 to f16 and back to f8 doesn't change its value"""
|
||||
check_type_supported(out_dtype)
|
||||
@@ -970,7 +970,7 @@ def test_f8_xf16_roundtrip(in_dtype, out_dtype):
|
||||
assert torch.all(f8_tensor == f8_output_tensor)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("in_dtype", [tl.float8e4])
|
||||
@pytest.mark.parametrize("in_dtype", [tl.float8e4, tl.float8e5])
|
||||
@pytest.mark.parametrize("out_dtype", [torch.float16, torch.bfloat16])
|
||||
def test_f16_to_f8_rounding(in_dtype, out_dtype):
|
||||
"""Takes all float16s, converts them to float8 and back to float16. Checks that the absolute
|
||||
@@ -1017,8 +1017,10 @@ def test_f16_to_f8_rounding(in_dtype, out_dtype):
|
||||
|
||||
# WARN: only normalized numbers are handled
|
||||
f8_normal_min = 1 << in_dtype.fp_mantissa_width # 0b00001000 for float8e4
|
||||
f8_normal_max = 0b01111110
|
||||
f8_normal_max = 0b01111110 if in_dtype == tl.float8e4 else 0b01111011
|
||||
f16_min, f16_max, f16_max_minus_1 = convert_float_to_float32(torch.tensor([f8_normal_min, f8_normal_max, f8_normal_max - 1], dtype=torch.int8), in_dtype)
|
||||
assert torch.all(torch.isfinite(f16_min))
|
||||
assert torch.all(torch.isfinite(f16_max))
|
||||
thres_error = f16_max - f16_max_minus_1
|
||||
mismatch = torch.logical_and(
|
||||
torch.logical_or(abs_error != min_error, abs_error > thres_error), torch.logical_and(torch.isfinite(f16_input), torch.logical_and(torch.abs(f16_input) <= f16_max, torch.abs(f16_input) >= f16_min))
|
||||
|
||||
Reference in New Issue
Block a user