use hw for fp8 type conversion (#386)

* use hardware instruction for type conversion between fp8 and fp32

* move gpu_matrix_core_version from semantics.py to hip_backend.py

---------

Co-authored-by: Aleksandr Efimov <efimov.alexander@gmail.com>
This commit is contained in:
Shucai Xiao
2023-11-24 10:26:40 -06:00
committed by GitHub
parent e1513b34e1
commit d9219e0eba
5 changed files with 266 additions and 169 deletions

View File

@@ -4,6 +4,10 @@ using namespace mlir;
using namespace mlir::triton;
using ::mlir::triton::gpu::getTotalElemsPerThread;
// Callable signature used by the type-conversion routines in this file: it
// receives the location, the pattern rewriter and a vector of input values,
// and returns the vector of converted values.
using ConverterT = std::function<SmallVector<Value>(
    Location, ConversionPatternRewriter &, const SmallVector<Value> &)>;
/* ----- FP8E5M2 ------ */
// This data-type is the standard FP8E5M2 format
#ifdef USE_ROCM
@@ -50,8 +54,93 @@ static const std::string Fp16_to_Fp8E5M2(bool hasNativeFP) {
}
#endif
// ROCM utility functions for data type conversion
#ifdef USE_ROCM
static Value convert_val_Fp16_to_Fp8E5M2FNUZ(
// Emits the GCN instruction `v_cvt_f32_f16` with `v` as its single input and
// returns the instruction result typed as f32.
static Value cvtFp16ToFp32(Location loc, ConversionPatternRewriter &rewriter,
                           const Value &v) {
  GCNBuilder gcn;
  auto &instr = *gcn.create("v_cvt_f32_f16");
  // Destination operand ("=v") is registered first, then the source ("v").
  auto dst = gcn.newOperand("=v");
  auto src = gcn.newOperand(v, "v");
  instr(dst, src);
  // NOTE(review): the trailing `false` is presumably the has-side-effects
  // flag of GCNBuilder::launch -- confirm against its declaration.
  return gcn.launch(rewriter, loc, f32_ty, false);
}
// Emits the GCN instruction `v_cvt_f16_f32` with `v` as its single input and
// returns the instruction result typed as f16.
static Value cvtFp32ToFp16(Location loc, ConversionPatternRewriter &rewriter,
                           const Value &v) {
  GCNBuilder gcn;
  auto &instr = *gcn.create("v_cvt_f16_f32");
  // Destination operand ("=v") is registered first, then the source ("v").
  auto dst = gcn.newOperand("=v");
  auto src = gcn.newOperand(v, "v");
  instr(dst, src);
  // NOTE(review): the trailing `false` is presumably the has-side-effects
  // flag of GCNBuilder::launch -- confirm against its declaration.
  return gcn.launch(rewriter, loc, f16_ty, false);
}
// Converts the two fp16 values `v0` and `v1` into two 8-bit values using the
// packed instruction `v_cvt_pk_<fp8_format>_f32`.
//
// `fp8_format` must be either "fp8" or "bf8" (checked with an assert only).
// Each input is first widened to f32 via cvtFp16ToFp32; the i32 result of the
// packed instruction is then viewed as a 4 x i8 vector, and elements 0 and 1
// of that vector are returned, in that order.
static SmallVector<Value>
convert_val_Fp16_to_Fp8(Location loc, ConversionPatternRewriter &rewriter,
                        Value v0, Value v1, const std::string &fp8_format) {
  assert(fp8_format == "fp8" || fp8_format == "bf8");
  std::string packInstrName = "v_cvt_pk_" + fp8_format + "_f32";

  // Widen both halves to f32, which is what the packed instruction consumes.
  auto wide0 = cvtFp16ToFp32(loc, rewriter, v0);
  auto wide1 = cvtFp16ToFp32(loc, rewriter, v1);

  GCNBuilder gcn;
  auto &packInstr = *gcn.create(packInstrName);
  auto dst = gcn.newOperand("=v");
  auto src0 = gcn.newOperand(wide0, "v");
  auto src1 = gcn.newOperand(wide1, "v");
  packInstr(dst, src0, src1);
  auto packedWord = gcn.launch(rewriter, loc, i32_ty, false);

  // Reinterpret the 32-bit word as four bytes and pick out the first two.
  auto i8x4Ty = vec_ty(i8_ty, 4);
  auto lanes = bitcast(packedWord, i8x4Ty);
  Value lane0 = extract_element(i8_ty, lanes, i32_val(0));
  Value lane1 = extract_element(i8_ty, lanes, i32_val(1));
  return {lane0, lane1};
}
// Converts the two 8-bit values `v0` and `v1` into two fp16 values using the
// packed instruction `v_cvt_pk_f32_<fp8_format>`.
//
// `fp8_format` must be either "fp8" or "bf8" (checked with an assert only).
// The inputs are placed in elements 0 and 1 of a 4 x i8 vector (elements 2
// and 3 stay undefined), which is bitcast to i32 and fed to the instruction.
// Its i64 result is viewed as a 2 x f32 vector, and each f32 element is
// narrowed to fp16 via cvtFp32ToFp16. Results are returned in input order.
static SmallVector<Value>
convert_val_Fp8_to_Fp16(Location loc, ConversionPatternRewriter &rewriter,
                        Value v0, Value v1, const std::string &fp8_format) {
  assert(fp8_format == "fp8" || fp8_format == "bf8");
  std::string unpackInstrName = "v_cvt_pk_f32_" + fp8_format;

  // Pack the two input bytes into the low lanes of a 32-bit word.
  auto i8x4Ty = vec_ty(i8_ty, 4);
  Value lanes = undef(i8x4Ty);
  lanes = insert_element(i8x4Ty, lanes, v0, i32_val(0));
  lanes = insert_element(i8x4Ty, lanes, v1, i32_val(1));
  auto packedWord = bitcast(lanes, i32_ty);

  GCNBuilder gcn;
  auto &unpackInstr = *gcn.create(unpackInstrName);
  auto dst = gcn.newOperand("=v");
  auto src = gcn.newOperand(packedWord, "v");
  unpackInstr(dst, src);
  auto widePair = gcn.launch(rewriter, loc, i64_ty, false);

  // Split the 64-bit result into its two f32 halves.
  auto f32x2Ty = vec_ty(f32_ty, 2);
  auto floats = bitcast(widePair, f32x2Ty);
  auto wide0 = extract_element(f32_ty, floats, i32_val(0));
  auto wide1 = extract_element(f32_ty, floats, i32_val(1));

  Value half0 = cvtFp32ToFp16(loc, rewriter, wide0);
  Value half1 = cvtFp32ToFp16(loc, rewriter, wide1);
  return {half0, half1};
}
#endif
#ifdef USE_ROCM
// Depending on whether we focus more on performance, we may skip
// the processing of subnormal values
static Value Fp16_to_Fp8E5M2FNUZ_oneValue(
Location loc, ConversionPatternRewriter &rewriter, Value v) {
auto vi16 = bitcast(v, i16_ty);
auto e = and_(i16_ty, vi16, int_val(16, 0x7C00));
@@ -79,28 +168,26 @@ static Value convert_val_Fp16_to_Fp8E5M2FNUZ(
}
// Software path for Fp16 -> Fp8E5M2FNUZ: every element of `v` is converted
// independently by the per-value helper, and the converted values are
// returned in the same order as the inputs.
static SmallVector<Value>
Fp16_to_Fp8E5M2FNUZ(Location loc, ConversionPatternRewriter &rewriter,
Fp16_to_Fp8E5M2FNUZ_SW(Location loc, ConversionPatternRewriter &rewriter,
const SmallVector<Value> &v) {
SmallVector<Value> result(4);
result[0] = convert_val_Fp16_to_Fp8E5M2FNUZ(loc, rewriter, v[0]);
result[1] = convert_val_Fp16_to_Fp8E5M2FNUZ(loc, rewriter, v[1]);
result[2] = convert_val_Fp16_to_Fp8E5M2FNUZ(loc, rewriter, v[2]);
result[3] = convert_val_Fp16_to_Fp8E5M2FNUZ(loc, rewriter, v[3]);
SmallVector<Value> result(2);
result[0] = Fp16_to_Fp8E5M2FNUZ_oneValue(loc, rewriter, v[0]);
result[1] = Fp16_to_Fp8E5M2FNUZ_oneValue(loc, rewriter, v[1]);
return result;
}
#else
const std::string Fp16_to_Fp8E5M2FNUZ =
"{ \n"
".reg .b32 a<2>; \n"
"and.b32 a0, $1, 0xfffefffe; \n" // a0 &= 0xfffefffe
"and.b32 a1, $2, 0xfffefffe; \n" // (strip lowest bit)
"add.u32 a0, a0, 0x00800080; \n" // a0 += 0x00800080
"add.u32 a1, a1, 0x00800080; \n" // (round to nearest)
"prmt.b32 $0, a0, a1, 0x7531; \n\t" // output = a1a0
"}";
// Hardware path for Fp16 -> Fp8E5M2FNUZ: forwards the first two elements of
// `v` to convert_val_Fp16_to_Fp8 with the "bf8" format and returns its result.
static SmallVector<Value>
Fp16_to_Fp8E5M2FNUZ_HW(Location loc, ConversionPatternRewriter &rewriter,
                       const SmallVector<Value> &v) {
  Value first = v[0];
  Value second = v[1];
  return convert_val_Fp16_to_Fp8(loc, rewriter, first, second, "bf8");
}
// Selects the Fp16 -> Fp8E5M2FNUZ converter: the hardware variant when
// `computeCapability` is at least 300, the software variant otherwise.
// NOTE(review): unlike the E4M3 selectors in this file, this function is not
// declared `static` -- confirm whether external linkage is intended.
ConverterT Fp16_to_Fp8E5M2FNUZ(int computeCapability) {
  if (computeCapability >= 300)
    return Fp16_to_Fp8E5M2FNUZ_HW;
  return Fp16_to_Fp8E5M2FNUZ_SW;
}
#endif
#ifdef USE_ROCM
static SmallVector<Value>
Fp8E5M2_to_Fp16(Location loc, ConversionPatternRewriter &rewriter,
@@ -145,8 +232,7 @@ static const std::string Fp8E5M2_to_Fp16(bool hasNativeFP) {
#endif
#ifdef USE_ROCM
static Value convert_val_Fp8E5M2FNUZ_to_Fp16(
static Value Fp8E5M2FNUZ_to_Fp16_oneValue(
Location loc, ConversionPatternRewriter &rewriter, Value v) {
auto fp8x2VecTy = vec_ty(i8_ty, 2);
Value a = undef(fp8x2VecTy);
@@ -181,17 +267,23 @@ static Value convert_val_Fp8E5M2FNUZ_to_Fp16(
}
// Software path for Fp8E5M2FNUZ -> Fp16: every element of `v` is converted
// independently by the per-value helper, and the converted values are
// returned in the same order as the inputs.
static SmallVector<Value>
Fp8E5M2FNUZ_to_Fp16(Location loc, ConversionPatternRewriter &rewriter,
Fp8E5M2FNUZ_to_Fp16_SW(Location loc, ConversionPatternRewriter &rewriter,
const SmallVector<Value> &v) {
SmallVector<Value> result(4);
result[0] = convert_val_Fp8E5M2FNUZ_to_Fp16(loc, rewriter, v[0]);
result[1] = convert_val_Fp8E5M2FNUZ_to_Fp16(loc, rewriter, v[1]);
result[2] = convert_val_Fp8E5M2FNUZ_to_Fp16(loc, rewriter, v[2]);
result[3] = convert_val_Fp8E5M2FNUZ_to_Fp16(loc, rewriter, v[3]);
SmallVector<Value> result(2);
result[0] = Fp8E5M2FNUZ_to_Fp16_oneValue(loc, rewriter, v[0]);
result[1] = Fp8E5M2FNUZ_to_Fp16_oneValue(loc, rewriter, v[1]);
return result;
}
// Hardware path for Fp8E5M2FNUZ -> Fp16: forwards the first two elements of
// `v` to convert_val_Fp8_to_Fp16 with the "bf8" format and returns its result.
static SmallVector<Value>
Fp8E5M2FNUZ_to_Fp16_HW(Location loc, ConversionPatternRewriter &rewriter,
                       const SmallVector<Value> &v) {
  Value first = v[0];
  Value second = v[1];
  return convert_val_Fp8_to_Fp16(loc, rewriter, first, second, "bf8");
}
// Selects the Fp8E5M2FNUZ -> Fp16 converter: the hardware variant when
// `computeCapability` is at least 300, the software variant otherwise.
// NOTE(review): unlike the E4M3 selectors in this file, this function is not
// declared `static` -- confirm whether external linkage is intended.
ConverterT Fp8E5M2FNUZ_to_Fp16(int computeCapability) {
  if (computeCapability >= 300)
    return Fp8E5M2FNUZ_to_Fp16_HW;
  return Fp8E5M2FNUZ_to_Fp16_SW;
}
#endif
#ifdef USE_ROCM
@@ -655,7 +747,7 @@ static const std::string Fp16_to_Fp8E4M3B15x4 =
// more than a single NaN values.
#ifdef USE_ROCM
static Value convert_val_Fp8E4M3FNUZ_to_Fp16(
static Value Fp8E4M3FNUZ_to_Fp16_oneValue(
Location loc, ConversionPatternRewriter &rewriter, Value v) {
auto fp8x2VecTy = vec_ty(i8_ty, 2);
Value a = undef(fp8x2VecTy);
@@ -686,37 +778,30 @@ static Value convert_val_Fp8E4M3FNUZ_to_Fp16(
return bitcast(io, f16_ty);
}
// Fp8E4M3FNUZ -> Fp16 (packed)
// Software path for Fp8E4M3FNUZ -> Fp16: the first two elements of `v` are
// converted independently by the per-value helper and returned in input order.
static SmallVector<Value>
Fp8E4M3FNUZ_to_Fp16(Location loc, ConversionPatternRewriter &rewriter,
Fp8E4M3FNUZ_to_Fp16_SW(Location loc, ConversionPatternRewriter &rewriter,
const SmallVector<Value> &v) {
SmallVector<Value> result(2);
result[0] = convert_val_Fp8E4M3FNUZ_to_Fp16(loc, rewriter, v[0]);
result[1] = convert_val_Fp8E4M3FNUZ_to_Fp16(loc, rewriter, v[1]);
result[0] = Fp8E4M3FNUZ_to_Fp16_oneValue(loc, rewriter, v[0]);
result[1] = Fp8E4M3FNUZ_to_Fp16_oneValue(loc, rewriter, v[1]);
return result;
}
#else
const std::string Fp8E4M3FNUZ_to_Fp16 =
"{ \n"
".reg .b32 a<2>, b<2>; \n" // if input = 0xf1f2f3f4
"prmt.b32 a0, 0, $2, 0x5040; \n" // a0 = 0xf300f400
"prmt.b32 a1, 0, $2, 0x7060; \n" // a1 = 0xf100f200
"lop3.b32 b0, a0, 0x7fff7fff, 0, 0xc0; \n" // b0 = a0 & 0x7fff7fff
"lop3.b32 b1, a1, 0x7fff7fff, 0, 0xc0; \n" // (strip sign)
"shr.b32 b0, b0, 1; \n" // b0 >>= 1
"shr.b32 b1, b1, 1; \n" // shift into fp16 position
"add.u32 b0, b0, 0x20002000; \n" // b0.exp += 2**4-2**3
// exponent compensate = 8
"add.u32 b1, b1, 0x20002000; \n" // b1 += 8<<10 | 8<<10<<16
"lop3.b32 $0, b0, 0x80008000, a0, 0xf8; \n" // out0 = b0|(0x80008000&a0)
"lop3.b32 $1, b1, 0x80008000, a1, 0xf8; \n" // (restore sign)
"}";
// Hardware path for Fp8E4M3FNUZ -> Fp16: forwards the first two elements of
// `v` to convert_val_Fp8_to_Fp16 with the "fp8" format and returns its result.
static SmallVector<Value>
Fp8E4M3FNUZ_to_Fp16_HW(Location loc, ConversionPatternRewriter &rewriter,
                       const SmallVector<Value> &v) {
  Value first = v[0];
  Value second = v[1];
  return convert_val_Fp8_to_Fp16(loc, rewriter, first, second, "fp8");
}
// Selects the Fp8E4M3FNUZ -> Fp16 converter: the hardware variant when
// `computeCapability` is at least 300, the software variant otherwise.
static ConverterT Fp8E4M3FNUZ_to_Fp16(int computeCapability) {
  if (computeCapability >= 300)
    return Fp8E4M3FNUZ_to_Fp16_HW;
  return Fp8E4M3FNUZ_to_Fp16_SW;
}
#endif
// Fp16 -> Fp8E4M3 (packed)
#ifdef USE_ROCM
static Value convert_val_Fp16_to_Fp8E4M3FNUZ(
static Value Fp16_to_Fp8E4M3FNUZ_oneValue(
Location loc, ConversionPatternRewriter &rewriter, Value v) {
auto vi16 = bitcast(v, i16_ty);
auto e10 = and_(vi16, int_val(16, 0x7C00));
@@ -749,33 +834,25 @@ static Value convert_val_Fp16_to_Fp8E4M3FNUZ(
}
// Software path for Fp16 -> Fp8E4M3FNUZ: the first two elements of `v` are
// converted independently by the per-value helper and returned in input order.
static SmallVector<Value>
Fp16_to_Fp8E4M3FNUZ(Location loc, ConversionPatternRewriter &rewriter,
Fp16_to_Fp8E4M3FNUZ_SW(Location loc, ConversionPatternRewriter &rewriter,
const SmallVector<Value> &v) {
SmallVector<Value> result(2);
result[0] = convert_val_Fp16_to_Fp8E4M3FNUZ(loc, rewriter, v[0]);
result[1] = convert_val_Fp16_to_Fp8E4M3FNUZ(loc, rewriter, v[1]);
result[0] = Fp16_to_Fp8E4M3FNUZ_oneValue(loc, rewriter, v[0]);
result[1] = Fp16_to_Fp8E4M3FNUZ_oneValue(loc, rewriter, v[1]);
return result;
}
#else
const std::string Fp16_to_Fp8E4M3FNUZ =
"{ \n"
".reg .b32 a<2>, b<2>; \n" // see Fp8E4M3x4ToFp16x4
"sub.u32 a0, $1, 0x20002000; \n" // a0 = input0 - 0x20002000
// (compensate offset)
"sub.u32 a1, $2, 0x20002000; \n" // a1 = input1 - 0x20002000
// (8 << 10 | 8 << 10 << 16)
"shl.b32 a0, a0, 1; \n" // a0 <<= 1
"shl.b32 a1, a1, 1; \n" // shift into fp8e4 position
"lop3.b32 a0, a0, 0x7fff7fff, 0, 0xc0; \n" // a0 &= 0x7fff7fff
"lop3.b32 a1, a1, 0x7fff7fff, 0, 0xc0; \n" // (strip sign)
"add.u32 a0, a0, 0x00800080; \n" // a0 += 0x00800080
"add.u32 a1, a1, 0x00800080; \n" // (round to nearest)
"lop3.b32 b0, $1, 0x80008000, a0, 0xea; \n" // b0 = a0|(0x80008000&in0)
"lop3.b32 b1, $2, 0x80008000, a1, 0xea; \n" // (restore sign)
"prmt.b32 $0, b0, b1, 0x7531; \n" // output = b1b0
"}";
// Hardware path for Fp16 -> Fp8E4M3FNUZ: forwards the first two elements of
// `v` to convert_val_Fp16_to_Fp8 with the "fp8" format and returns its result.
static SmallVector<Value>
Fp16_to_Fp8E4M3FNUZ_HW(Location loc, ConversionPatternRewriter &rewriter,
                       const SmallVector<Value> &v) {
  Value first = v[0];
  Value second = v[1];
  return convert_val_Fp16_to_Fp8(loc, rewriter, first, second, "fp8");
}
// Selects the Fp16 -> Fp8E4M3FNUZ converter: the hardware variant when
// `computeCapability` is at least 300, the software variant otherwise.
static ConverterT Fp16_to_Fp8E4M3FNUZ(int computeCapability) {
  if (computeCapability >= 300)
    return Fp16_to_Fp8E4M3FNUZ_HW;
  return Fp16_to_Fp8E4M3FNUZ_SW;
}
#endif
// WARN: subnormal (0bs0000xxx) are not handled
@@ -1144,10 +1221,6 @@ inline SmallVector<Value> packI32(const SmallVector<Value> &inValues,
return outValues;
}
typedef std::function<SmallVector<Value>(Location, ConversionPatternRewriter &,
const SmallVector<Value> &)>
ConverterT;
static ConverterT makeConverterFromPtx(const std::string &ptxAsm, Type inType,
Type outType,
const int inVecWidthBits = 32,
@@ -1351,12 +1424,7 @@ struct FpToFpOpConversion
ConversionPatternRewriter &rewriter,
const Value &v) {
#ifdef USE_ROCM
GCNBuilder builder;
auto &cvt = *builder.create("v_cvt_f32_f16");
auto res = builder.newOperand("=v");
auto operand = builder.newOperand(v, "v");
cvt(res, operand);
return builder.launch(rewriter, loc, f32_ty, false);
return cvtFp16ToFp32(loc, rewriter, v);
#else
PTXBuilder builder;
auto &cvt = *builder.create("cvt.f32.f16");
@@ -1402,12 +1470,7 @@ struct FpToFpOpConversion
ConversionPatternRewriter &rewriter,
const Value &v) {
#ifdef USE_ROCM
GCNBuilder builder;
auto &cvt = *builder.create("v_cvt_f16_f32");
auto res = builder.newOperand("=v");
auto operand = builder.newOperand(v, "v");
cvt(res, operand);
return builder.launch(rewriter, loc, f16_ty, false);
return cvtFp32ToFp16(loc, rewriter, v);
#else
PTXBuilder builder;
auto &cvt = *builder.create("cvt.rn.f16.f32");
@@ -1420,7 +1483,11 @@ struct FpToFpOpConversion
ConverterT getConversionFunc(Type srcTy, Type dstTy) const {
auto F8E4M3B15TyID = TypeID::get<mlir::Float8E4M3B11FNUZType>();
#ifdef USE_ROCM
auto F8E4M3FNUZTyID = TypeID::get<mlir::Float8E4M3FNUZType>();
#else
auto F8E4M3TyID = TypeID::get<mlir::Float8E4M3FNUZType>();
#endif
auto F8E4M3FNTyID = TypeID::get<mlir::Float8E4M3FNType>();
auto F8E5M2TyID = TypeID::get<mlir::Float8E5M2Type>();
auto F8E5M2FNUZTyID = TypeID::get<mlir::Float8E5M2FNUZType>();
@@ -1436,37 +1503,37 @@ struct FpToFpOpConversion
// F8 -> F16
{{F8E4M3B15TyID, F16TyID}, Fp8E4M3B15_to_Fp16},
{{F8E4M3FNTyID, F16TyID}, Fp8E4M3B15x4_to_Fp16},
{{F8E4M3FNUZTyID, F16TyID}, Fp8E4M3FNUZ_to_Fp16},
#ifdef USE_ROCM
{{F8E4M3FNUZTyID, F16TyID}, Fp8E4M3FNUZ_to_Fp16(computeCapability)},
{{F8E5M2FNUZTyID, F16TyID}, Fp8E5M2FNUZ_to_Fp16(computeCapability)},
{{F8E5M2TyID, F16TyID}, Fp8E5M2_to_Fp16},
{{F8E5M2FNUZTyID, F16TyID}, Fp8E5M2FNUZ_to_Fp16},
#else
{{F8E4M3TyID, F16TyID}, Fp8E4M3Nv_to_Fp16},
{{F8E5M2TyID, F16TyID}, Fp8E5M2_to_Fp16(computeCapability >= 90)},
#endif
// F16 -> F8
{{F16TyID, F8E4M3FNTyID}, Fp16_to_Fp8E4M3B15x4},
#ifdef USE_ROCM
{{F16TyID, F8E4M3B15TyID}, Fp16_to_Fp8E4M3B15},
{{F16TyID, F8E5M2FNUZTyID}, Fp16_to_Fp8E5M2FNUZ(computeCapability)},
{{F16TyID, F8E4M3FNUZTyID}, Fp16_to_Fp8E4M3FNUZ(computeCapability)},
{{F16TyID, F8E5M2TyID}, Fp16_to_Fp8E5M2},
#else
{{F16TyID, F8E4M3B15TyID}, Fp16_to_Fp8E4M3B15(computeCapability >= 80)},
#endif
{{F16TyID, F8E4M3FNTyID}, Fp16_to_Fp8E4M3B15x4},
{{F16TyID, F8E4M3FNUZTyID}, Fp16_to_Fp8E4M3FNUZ},
#ifdef USE_ROCM
{{F16TyID, F8E5M2TyID}, Fp16_to_Fp8E5M2},
{{F16TyID, F8E5M2FNUZTyID}, Fp16_to_Fp8E5M2FNUZ},
#else
{{F16TyID, F8E4M3TyID}, Fp16_to_Fp8E4M3Nv},
{{F16TyID, F8E5M2TyID}, Fp16_to_Fp8E5M2(computeCapability >= 90)},
#endif
// F8 -> BF16
// F8 -> BF16
#ifdef USE_ROCM
{{F8E5M2TyID, BF16TyID}, Fp8E5M2_to_Bf16},
{{F8E5M2TyID, BF16TyID}, Fp8E5M2_to_Bf16},
#else
{{F8E5M2TyID, BF16TyID}, Fp8E5M2_to_Bf16(computeCapability >= 90)},
{{F8E5M2TyID, BF16TyID}, Fp8E5M2_to_Bf16(computeCapability >= 90)},
{{F8E4M3TyID, BF16TyID}, Fp8E4M3Nv_to_Bf16},
#endif
// BF16 -> F8
// BF16 -> F8
#ifdef USE_ROCM
{{BF16TyID, F8E5M2TyID}, Bf16_to_Fp8E5M2},
#else
@@ -1477,6 +1544,16 @@ struct FpToFpOpConversion
{{F32TyID, F8E5M2TyID}, Fp32_to_Fp8E5M2},
#endif
};
std::pair<TypeID, TypeID> key = {srcTy.getTypeID(), dstTy.getTypeID()};
if (srcMap.count(key) == 0) {
llvm::errs() << "Unsupported conversion from " << srcTy << " to " << dstTy
<< "\n";
llvm_unreachable("");
}
#ifdef USE_ROCM
return srcMap.lookup(key);
#else
int inVecWidthBits = 32;
int outVecWidthBits = 32;
if (srcTy.isFloat8E4M3FNUZ() ||
@@ -1490,15 +1567,6 @@ struct FpToFpOpConversion
outVecWidthBits = 16;
}
std::pair<TypeID, TypeID> key = {srcTy.getTypeID(), dstTy.getTypeID()};
if (srcMap.count(key) == 0) {
llvm::errs() << "Unsupported conversion from " << srcTy << " to " << dstTy
<< "\n";
llvm_unreachable("");
}
#ifdef USE_ROCM
return srcMap.lookup(key);
#else
if (computeCapability < 90 &&
(srcTy.isFloat8E4M3FNUZ() || dstTy.isFloat8E4M3FNUZ())) {
llvm::errs() << "Conversion from/to f8e4m3nv is only supported on "
@@ -1523,14 +1591,24 @@ struct FpToFpOpConversion
size_t numElements = 4;
if (srcElementType.isFloat8E4M3FNUZ() ||
dstElementType.isFloat8E4M3FNUZ() ||
#ifdef USE_ROCM
srcElementType.isFloat8E5M2FNUZ() ||
dstElementType.isFloat8E5M2FNUZ())
#else
(computeCapability >= 90 &&
(srcElementType.isFloat8E5M2() || dstElementType.isFloat8E5M2()))) {
(srcElementType.isFloat8E5M2() || dstElementType.isFloat8E5M2())))
#endif
{
numElements = 2;
}
bool useFP16IntermediateSrc =
#ifdef USE_ROCM
srcElementType.isF32();
#else
srcElementType.isF32() &&
!(computeCapability >= 90 &&
(dstElementType.isFloat8E4M3FNUZ() || dstElementType.isFloat8E5M2()));
#endif
bool isDstFP32 = dstElementType.isF32();
auto cvtFunc =
getConversionFunc(useFP16IntermediateSrc ? f16_ty : srcElementType,