ROCM IFU: Enabled conversion between fp8e4m3b15x4 and fp16. Refactored conversion between fp8e4m3nv and fp16. (#335)

This commit is contained in:
wenchenvincent
2023-09-26 23:23:34 -05:00
committed by Jason Furmanek
parent 634f66a090
commit 42a5bf9c7c
2 changed files with 82 additions and 61 deletions

View File

@@ -417,7 +417,30 @@ const std::string Fp16_to_Fp8E4M3B15(bool has_minx2) {
static SmallVector<Value>
Fp8E4M3B15x4_to_Fp16(Location loc, ConversionPatternRewriter &rewriter,
                     const SmallVector<Value> &v) {
  // Converts four packed fp8e4m3b15x4 values to four f16 values using plain
  // integer bit manipulation (software emulation, no conversion intrinsics).
  //
  // Pack the four fp8 bytes into a single i32 so all four lanes can be
  // processed with a handful of scalar bit operations.
  auto fp8x4VecTy = vec_ty(i8_ty, 4);
  Value fp8x4Vec = undef(fp8x4VecTy);
  fp8x4Vec = insert_element(fp8x4VecTy, fp8x4Vec, v[0], i32_val(0));
  fp8x4Vec = insert_element(fp8x4VecTy, fp8x4Vec, v[1], i32_val(1));
  fp8x4Vec = insert_element(fp8x4VecTy, fp8x4Vec, v[2], i32_val(2));
  fp8x4Vec = insert_element(fp8x4VecTy, fp8x4Vec, v[3], i32_val(3));
  fp8x4Vec = bitcast(fp8x4Vec, i32_ty);

  // a0 = x << 1 (via self-add), a1 = x << 7: shifted copies from which the
  // sign and exponent/mantissa fields of the first f16 pair are masked out.
  // NOTE(review): the b15x4 layout interleaves lanes across the 32-bit word
  // (see serialize_fp8 in the python tests) -- lane mapping assumed, confirm.
  Value a0 = add(i32_ty, fp8x4Vec, fp8x4Vec);
  Value a1 = shl(i32_ty, fp8x4Vec, i32_val(7));
  // First f16 pair: sign bits from a0, exponent/mantissa bits from a1.
  Value fp16x2Vec0 = and_(i32_ty, a0, i32_val(0x80008000));
  fp16x2Vec0 = or_(i32_ty, fp16x2Vec0, and_(i32_ty, a1, i32_val(0x3f803f80)));
  // Second f16 pair is already bit-aligned in the packed word; just mask.
  Value fp16x2Vec1 = and_(i32_ty, fp8x4Vec, i32_val(0xbf80bf80));

  auto fp16x2VecTy = vec_ty(f16_ty, 2);
  fp16x2Vec0 = bitcast(fp16x2Vec0, fp16x2VecTy);
  fp16x2Vec1 = bitcast(fp16x2Vec1, fp16x2VecTy);
  return {extract_element(f16_ty, fp16x2Vec0, i32_val(0)),
          extract_element(f16_ty, fp16x2Vec0, i32_val(1)),
          extract_element(f16_ty, fp16x2Vec1, i32_val(0)),
          extract_element(f16_ty, fp16x2Vec1, i32_val(1))};
}
#else
const std::string Fp8E4M3B15x4_to_Fp16 =
@@ -442,7 +465,33 @@ const std::string Fp8E4M3B15x4_to_Fp16 =
static SmallVector<Value>
Fp16_to_Fp8E4M3B15x4(Location loc, ConversionPatternRewriter &rewriter,
                     const SmallVector<Value> &v) {
  // Converts four f16 values to four packed fp8e4m3b15x4 bytes using plain
  // integer bit manipulation; inverse of Fp8E4M3B15x4_to_Fp16.
  //
  // Pack the four f16 inputs into two i32 words so the bit surgery can be
  // done with scalar ops.
  auto fp16x2VecTy = vec_ty(f16_ty, 2);
  Value fp16x2Vec0 = undef(fp16x2VecTy);
  Value fp16x2Vec1 = undef(fp16x2VecTy);
  fp16x2Vec0 = insert_element(fp16x2VecTy, fp16x2Vec0, v[0], i32_val(0));
  fp16x2Vec0 = insert_element(fp16x2VecTy, fp16x2Vec0, v[1], i32_val(1));
  fp16x2Vec1 = insert_element(fp16x2VecTy, fp16x2Vec1, v[2], i32_val(0));
  fp16x2Vec1 = insert_element(fp16x2VecTy, fp16x2Vec1, v[3], i32_val(1));
  fp16x2Vec0 = bitcast(fp16x2Vec0, i32_ty);
  fp16x2Vec1 = bitcast(fp16x2Vec1, i32_ty);

  // Shifted copies of the first pair: >>1 re-aligns the exponent MSB,
  // >>7 re-aligns the remaining payload bits back to fp8 byte positions.
  Value a0 = lshr(i32_ty, fp16x2Vec0, i32_val(1));
  Value a1 = lshr(i32_ty, fp16x2Vec0, i32_val(7));
  // Assemble the packed word: fields of the first pair from a0/a1, fields of
  // the second pair taken in place from fp16x2Vec1 (already bit-aligned).
  Value fp8x4Vec = and_(i32_ty, a0, i32_val(0x40004000));
  fp8x4Vec = or_(i32_ty, fp8x4Vec, and_(i32_ty, a1, i32_val(0x007f007f)));
  fp8x4Vec = or_(i32_ty, fp8x4Vec, and_(i32_ty, fp16x2Vec1, i32_val(0xbf80bf80)));

  auto fp8x4VecTy = vec_ty(i8_ty, 4);
  fp8x4Vec = bitcast(fp8x4Vec, fp8x4VecTy);
  return {extract_element(i8_ty, fp8x4Vec, i32_val(0)),
          extract_element(i8_ty, fp8x4Vec, i32_val(1)),
          extract_element(i8_ty, fp8x4Vec, i32_val(2)),
          extract_element(i8_ty, fp8x4Vec, i32_val(3))};
}
#else
const std::string Fp16_to_Fp8E4M3B15x4 =
@@ -474,33 +523,19 @@ Fp8E4M3_to_Fp16(Location loc, ConversionPatternRewriter &rewriter,
a0 = insert_element(fp8x4VecTy, a0, v[1], i32_val(3));
a0 = bitcast(a0, i32_ty);
Value a1 = undef(fp8x4VecTy);
a1 = insert_element(fp8x4VecTy, a1, int_val(8,0), i32_val(0));
a1 = insert_element(fp8x4VecTy, a1, v[2], i32_val(1));
a1 = insert_element(fp8x4VecTy, a1, int_val(8,0), i32_val(2));
a1 = insert_element(fp8x4VecTy, a1, v[3], i32_val(3));
a1 = bitcast(a1, i32_ty);
Value b0 = and_(i32_ty, a0, i32_val(0x7fff7fff));
Value b1 = and_(i32_ty, a1, i32_val(0x7fff7fff));
b0 = lshr(i32_ty, b0, i32_val(1));
b1 = lshr(i32_ty, b1, i32_val(1));
b0 = add(i32_ty, b0, i32_val(0x20002000));
b1 = add(i32_ty, b1, i32_val(0x20002000));
b0 = or_( i32_ty, b0, and_(i32_ty, a0, i32_val(0x80008000)) );
b1 = or_( i32_ty, b1, and_(i32_ty, a1, i32_val(0x80008000)) );
auto fp16x2VecTy = vec_ty(f16_ty, 2);
auto fp16x2Vec0 = bitcast(b0, fp16x2VecTy);
auto fp16x2Vec1 = bitcast(b1, fp16x2VecTy);
return { extract_element(f16_ty, fp16x2Vec0, i32_val(0)),
extract_element(f16_ty, fp16x2Vec0, i32_val(1)),
extract_element(f16_ty, fp16x2Vec1, i32_val(0)),
extract_element(f16_ty, fp16x2Vec1, i32_val(1))
extract_element(f16_ty, fp16x2Vec0, i32_val(1))
};
}
#else
@@ -528,35 +563,23 @@ Fp16_to_Fp8E4M3(Location loc, ConversionPatternRewriter &rewriter,
const SmallVector<Value> &v) {
auto fp16x2VecTy = vec_ty(f16_ty, 2);
Value fp16x2Vec0 = undef(fp16x2VecTy);
Value fp16x2Vec1 = undef(fp16x2VecTy);
fp16x2Vec0 = insert_element(fp16x2VecTy, fp16x2Vec0, v[0], i32_val(0));
fp16x2Vec0 = insert_element(fp16x2VecTy, fp16x2Vec0, v[1], i32_val(1));
fp16x2Vec1 = insert_element(fp16x2VecTy, fp16x2Vec1, v[2], i32_val(0));
fp16x2Vec1 = insert_element(fp16x2VecTy, fp16x2Vec1, v[3], i32_val(1));
fp16x2Vec0 = bitcast(fp16x2Vec0, i32_ty);
fp16x2Vec1 = bitcast(fp16x2Vec1, i32_ty);
fp16x2Vec0 = sub(i32_ty, fp16x2Vec0, i32_val(0x20002000));
fp16x2Vec1 = sub(i32_ty, fp16x2Vec1, i32_val(0x20002000));
Value a0 = shl(i32_ty, fp16x2Vec0, i32_val(1));
Value a1 = shl(i32_ty, fp16x2Vec1, i32_val(1));
a0 = and_(i32_ty, a0, i32_val(0x7fff7fff));
a1 = and_(i32_ty, a1, i32_val(0x7fff7fff));
a0 = add(i32_ty, a0, i32_val(0x00800080));
a1 = add(i32_ty, a1, i32_val(0x00800080));
Value b0 = or_( i32_ty, and_(i32_ty, fp16x2Vec0, i32_val(0x80008000)), a0 );
Value b1 = or_( i32_ty, and_(i32_ty, fp16x2Vec1, i32_val(0x80008000)), a1 );
auto fp8x4VecTy = vec_ty(i8_ty, 4);
b0 = bitcast(b0, fp8x4VecTy);
b1 = bitcast(b1, fp8x4VecTy);
return {extract_element(i8_ty, b0, i32_val(1)),
extract_element(i8_ty, b0, i32_val(3)),
extract_element(i8_ty, b1, i32_val(1)),
extract_element(i8_ty, b1, i32_val(3))
extract_element(i8_ty, b0, i32_val(3))
};
}
#else

View File

@@ -922,6 +922,7 @@ def test_load_store_same_ptr():
kernel[(65536,)](x, num_warps=16)
assert torch.all(x == 2)
def convert_float_to_float32(fp: torch.tensor, dtype=None):
if not dtype:
dtype = getattr(tl, torch_dtype_name(fp.dtype))
@@ -941,8 +942,8 @@ def convert_float_to_float32(fp: torch.tensor, dtype=None):
extended_exp = ((1 << (tl.float32.primitive_bitwidth - tl.float32.fp_mantissa_width - 1)) - 1) << tl.float32.fp_mantissa_width
# special cases, exp is 0b11..1
if dtype in [tl.float8e4, tl.float8e4b15]:
# float8e4m3 does not have infinities
if dtype in [tl.float8e4nv, tl.float8e4b15]:
# float8e4m3nv does not have infinities
output[fp == 0b01111111] = torch.nan
output[fp == 0b11111111] = torch.nan
else:
@@ -969,40 +970,36 @@ def test_convert_float16_to_float32(in_dtype, device):
def serialize_fp8(np_data, in_dtype):
    """Rearrange linearly laid-out fp8 data into triton's internal layout.

    triton's f8e4b15x4 format is optimized for software emulation: each pack
    of 4 fp8 values s0b0 s1b1 s2b2 s3b3 (s = sign bit, b = the remaining 7
    bits) is stored internally as s0s2b0b2 s1s3b1b3.  Data of every other
    fp8 dtype is returned unchanged.
    """
    if in_dtype == tl.float8e4b15x4:
        # View each pack of 4 bytes as one uint32 and shuffle bit fields.
        f8x4 = np_data.view(np.uint32)
        s = [(f8x4 & (0x80000000 >> i)) << i for i in range(0, 32, 8)]
        b = [(f8x4 & (0x7f000000 >> i)) << i for i in range(0, 32, 8)]
        signs = (s[0] >> 0) | (s[1] >> 16) | (s[2] >> 1) | (s[3] >> 17)
        bits = (b[0] >> 1) | (b[1] >> 17) | (b[2] >> 8) | (b[3] >> 24)
        # tensor of triton fp8 data
        return (signs | bits).view(np.int8)
    else:
        return np_data
# inverse of `serialize_fp8`
def deserialize_fp8(np_data, in_dtype):
    """Rearrange fp8 data from triton's internal f8e4b15x4 layout back to the
    linear s0b0 s1b1 s2b2 s3b3 layout; inverse of `serialize_fp8`.  Data of
    every other fp8 dtype is returned unchanged.
    """
    if in_dtype == tl.float8e4b15x4:
        # View each pack of 4 bytes as one uint32 and undo the bit shuffle.
        f8x4 = np_data.view(np.uint32)
        s = [(f8x4 & (0x80000000 >> i)) << i for i in [0, 16, 1, 17]]
        b = [(f8x4 & (0x7f000000 >> i)) << i for i in [1, 17, 8, 24]]
        signs = (s[0] >> 0) | (s[1] >> 8) | (s[2] >> 16) | (s[3] >> 24)
        bits = (b[0] >> 0) | (b[1] >> 8) | (b[2] >> 16) | (b[3] >> 24)
        return (signs | bits).view(np.int8)
    else:
        return np_data
@pytest.mark.parametrize("in_dtype", [tl.float8e4b15, tl.float8e4b15x4, tl.float8e4nv, tl.float8e5])
@@ -1014,6 +1011,7 @@ def test_fp8_fpN_roundtrip(in_dtype, out_dtype, device):
- conversion tri_fp8 = convert(input=tri_fp16, out=out_dtype) matches the original
this is only possible if both conversions are correct
"""
check_type_supported(in_dtype, device)
check_type_supported(out_dtype, device)
@triton.jit