mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
use hw for fp8 type conversion (#386)
* use hardware instruction for type conversion between fp8 and fp32 * move gpu_matrix_core_version from semantics.py to hip_backend.py --------- Co-authored-by: Aleksandr Efimov <efimov.alexander@gmail.com>
This commit is contained in:
@@ -4,6 +4,10 @@ using namespace mlir;
|
||||
using namespace mlir::triton;
|
||||
using ::mlir::triton::gpu::getTotalElemsPerThread;
|
||||
|
||||
typedef std::function<SmallVector<Value>(Location, ConversionPatternRewriter &,
|
||||
const SmallVector<Value> &)>
|
||||
ConverterT;
|
||||
|
||||
/* ----- FP8E5M2 ------ */
|
||||
// This data-type is the standard FP8E5M2 format
|
||||
#ifdef USE_ROCM
|
||||
@@ -50,8 +54,93 @@ static const std::string Fp16_to_Fp8E5M2(bool hasNativeFP) {
|
||||
}
|
||||
#endif
|
||||
|
||||
// ROCM utility functions for data type conversion
|
||||
#ifdef USE_ROCM
|
||||
static Value convert_val_Fp16_to_Fp8E5M2FNUZ(
|
||||
static Value cvtFp16ToFp32(Location loc,
|
||||
ConversionPatternRewriter &rewriter,
|
||||
const Value &v) {
|
||||
GCNBuilder builder;
|
||||
auto &cvt = *builder.create("v_cvt_f32_f16");
|
||||
auto res = builder.newOperand("=v");
|
||||
auto operand = builder.newOperand(v, "v");
|
||||
cvt(res, operand);
|
||||
return builder.launch(rewriter, loc, f32_ty, false);
|
||||
}
|
||||
|
||||
static Value cvtFp32ToFp16(Location loc,
|
||||
ConversionPatternRewriter &rewriter,
|
||||
const Value &v) {
|
||||
GCNBuilder builder;
|
||||
auto &cvt = *builder.create("v_cvt_f16_f32");
|
||||
auto res = builder.newOperand("=v");
|
||||
auto operand = builder.newOperand(v, "v");
|
||||
cvt(res, operand);
|
||||
return builder.launch(rewriter, loc, f16_ty, false);
|
||||
}
|
||||
|
||||
static SmallVector<Value> convert_val_Fp16_to_Fp8(
|
||||
Location loc, ConversionPatternRewriter &rewriter,
|
||||
Value v0, Value v1, const std::string& fp8_format) {
|
||||
assert(fp8_format == "fp8" or fp8_format == "bf8");
|
||||
std::string ins_str = "v_cvt_pk_" + fp8_format + "_f32";
|
||||
|
||||
auto f32_0 = cvtFp16ToFp32(loc, rewriter, v0);
|
||||
auto f32_1 = cvtFp16ToFp32(loc, rewriter, v1);
|
||||
|
||||
GCNBuilder builder;
|
||||
auto &cvt = *builder.create(ins_str);
|
||||
auto res = builder.newOperand("=v");
|
||||
auto operand0 = builder.newOperand(f32_0, "v");
|
||||
auto operand1 = builder.newOperand(f32_1, "v");
|
||||
cvt(res, operand0, operand1);
|
||||
auto fp8x4Vec = builder.launch(rewriter, loc, i32_ty, false);
|
||||
|
||||
auto fp8x4VecTy = vec_ty(i8_ty, 4);
|
||||
auto a1 = bitcast(fp8x4Vec, fp8x4VecTy);
|
||||
|
||||
SmallVector<Value> ret(2);
|
||||
ret[0] = extract_element(i8_ty, a1, i32_val(0));
|
||||
ret[1] = extract_element(i8_ty, a1, i32_val(1));
|
||||
|
||||
return ret;
|
||||
}
|
||||
|
||||
static SmallVector<Value> convert_val_Fp8_to_Fp16(
|
||||
Location loc, ConversionPatternRewriter &rewriter,
|
||||
Value v0, Value v1, const std::string& fp8_format) {
|
||||
assert(fp8_format == "fp8" or fp8_format == "bf8");
|
||||
std::string ins_str = "v_cvt_pk_f32_" + fp8_format;
|
||||
|
||||
auto fp8x4VecTy = vec_ty(i8_ty, 4);
|
||||
Value fp8x4Vec = undef(fp8x4VecTy);
|
||||
fp8x4Vec = insert_element(fp8x4VecTy, fp8x4Vec, v0, i32_val(0));
|
||||
fp8x4Vec = insert_element(fp8x4VecTy, fp8x4Vec, v1, i32_val(1));
|
||||
auto i32v = bitcast(fp8x4Vec, i32_ty);
|
||||
|
||||
GCNBuilder builder1;
|
||||
auto &cvt = *builder1.create(ins_str);
|
||||
auto res = builder1.newOperand("=v");
|
||||
auto operand = builder1.newOperand(i32v, "v");
|
||||
cvt(res, operand);
|
||||
auto i64v = builder1.launch(rewriter, loc, i64_ty, false);
|
||||
auto fp32x2VecTy = vec_ty(f32_ty, 2);
|
||||
auto fp32x2Vec = bitcast(i64v, fp32x2VecTy);
|
||||
|
||||
auto f32_0 = extract_element(f32_ty, fp32x2Vec, i32_val(0));
|
||||
auto f32_1 = extract_element(f32_ty, fp32x2Vec, i32_val(1));
|
||||
|
||||
SmallVector<Value> ret(2);
|
||||
ret[0] = cvtFp32ToFp16(loc, rewriter, f32_0);
|
||||
ret[1] = cvtFp32ToFp16(loc, rewriter, f32_1);
|
||||
|
||||
return ret;
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef USE_ROCM
|
||||
// Depend on whether we focus more on performance, we may skip
|
||||
// the processing of submornal values
|
||||
static Value Fp16_to_Fp8E5M2FNUZ_oneValue(
|
||||
Location loc, ConversionPatternRewriter &rewriter, Value v) {
|
||||
auto vi16 = bitcast(v, i16_ty);
|
||||
auto e = and_(i16_ty, vi16, int_val(16, 0x7C00));
|
||||
@@ -79,28 +168,26 @@ static Value convert_val_Fp16_to_Fp8E5M2FNUZ(
|
||||
}
|
||||
|
||||
static SmallVector<Value>
|
||||
Fp16_to_Fp8E5M2FNUZ(Location loc, ConversionPatternRewriter &rewriter,
|
||||
Fp16_to_Fp8E5M2FNUZ_SW(Location loc, ConversionPatternRewriter &rewriter,
|
||||
const SmallVector<Value> &v) {
|
||||
SmallVector<Value> result(4);
|
||||
result[0] = convert_val_Fp16_to_Fp8E5M2FNUZ(loc, rewriter, v[0]);
|
||||
result[1] = convert_val_Fp16_to_Fp8E5M2FNUZ(loc, rewriter, v[1]);
|
||||
result[2] = convert_val_Fp16_to_Fp8E5M2FNUZ(loc, rewriter, v[2]);
|
||||
result[3] = convert_val_Fp16_to_Fp8E5M2FNUZ(loc, rewriter, v[3]);
|
||||
|
||||
SmallVector<Value> result(2);
|
||||
result[0] = Fp16_to_Fp8E5M2FNUZ_oneValue(loc, rewriter, v[0]);
|
||||
result[1] = Fp16_to_Fp8E5M2FNUZ_oneValue(loc, rewriter, v[1]);
|
||||
return result;
|
||||
}
|
||||
#else
|
||||
const std::string Fp16_to_Fp8E5M2FNUZ =
|
||||
"{ \n"
|
||||
".reg .b32 a<2>; \n"
|
||||
"and.b32 a0, $1, 0xfffefffe; \n" // a0 &= 0xfffefffe
|
||||
"and.b32 a1, $2, 0xfffefffe; \n" // (strip lowest bit)
|
||||
"add.u32 a0, a0, 0x00800080; \n" // a0 += 0x00800080
|
||||
"add.u32 a1, a1, 0x00800080; \n" // (round to nearest)
|
||||
"prmt.b32 $0, a0, a1, 0x7531; \n\t" // output = a1a0
|
||||
"}";
|
||||
|
||||
static SmallVector<Value> Fp16_to_Fp8E5M2FNUZ_HW(
|
||||
Location loc, ConversionPatternRewriter &rewriter,
|
||||
const SmallVector<Value>& v) {
|
||||
return convert_val_Fp16_to_Fp8(loc, rewriter, v[0], v[1], "bf8");
|
||||
}
|
||||
|
||||
ConverterT Fp16_to_Fp8E5M2FNUZ(int computeCapability) {
|
||||
return computeCapability >= 300 ? Fp16_to_Fp8E5M2FNUZ_HW : Fp16_to_Fp8E5M2FNUZ_SW;
|
||||
}
|
||||
#endif
|
||||
|
||||
|
||||
#ifdef USE_ROCM
|
||||
static SmallVector<Value>
|
||||
Fp8E5M2_to_Fp16(Location loc, ConversionPatternRewriter &rewriter,
|
||||
@@ -145,8 +232,7 @@ static const std::string Fp8E5M2_to_Fp16(bool hasNativeFP) {
|
||||
#endif
|
||||
|
||||
#ifdef USE_ROCM
|
||||
|
||||
static Value convert_val_Fp8E5M2FNUZ_to_Fp16(
|
||||
static Value Fp8E5M2FNUZ_to_Fp16_oneValue(
|
||||
Location loc, ConversionPatternRewriter &rewriter, Value v) {
|
||||
auto fp8x2VecTy = vec_ty(i8_ty, 2);
|
||||
Value a = undef(fp8x2VecTy);
|
||||
@@ -181,17 +267,23 @@ static Value convert_val_Fp8E5M2FNUZ_to_Fp16(
|
||||
}
|
||||
|
||||
static SmallVector<Value>
|
||||
Fp8E5M2FNUZ_to_Fp16(Location loc, ConversionPatternRewriter &rewriter,
|
||||
Fp8E5M2FNUZ_to_Fp16_SW(Location loc, ConversionPatternRewriter &rewriter,
|
||||
const SmallVector<Value> &v) {
|
||||
|
||||
SmallVector<Value> result(4);
|
||||
result[0] = convert_val_Fp8E5M2FNUZ_to_Fp16(loc, rewriter, v[0]);
|
||||
result[1] = convert_val_Fp8E5M2FNUZ_to_Fp16(loc, rewriter, v[1]);
|
||||
result[2] = convert_val_Fp8E5M2FNUZ_to_Fp16(loc, rewriter, v[2]);
|
||||
result[3] = convert_val_Fp8E5M2FNUZ_to_Fp16(loc, rewriter, v[3]);
|
||||
|
||||
SmallVector<Value> result(2);
|
||||
result[0] = Fp8E5M2FNUZ_to_Fp16_oneValue(loc, rewriter, v[0]);
|
||||
result[1] = Fp8E5M2FNUZ_to_Fp16_oneValue(loc, rewriter, v[1]);
|
||||
return result;
|
||||
}
|
||||
|
||||
static SmallVector<Value>
|
||||
Fp8E5M2FNUZ_to_Fp16_HW(Location loc, ConversionPatternRewriter &rewriter,
|
||||
const SmallVector<Value> &v) {
|
||||
return convert_val_Fp8_to_Fp16(loc, rewriter, v[0], v[1], "bf8");
|
||||
}
|
||||
|
||||
ConverterT Fp8E5M2FNUZ_to_Fp16(int computeCapability) {
|
||||
return (computeCapability >= 300) ? Fp8E5M2FNUZ_to_Fp16_HW : Fp8E5M2FNUZ_to_Fp16_SW;
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef USE_ROCM
|
||||
@@ -655,7 +747,7 @@ static const std::string Fp16_to_Fp8E4M3B15x4 =
|
||||
// more than a single NaN values.
|
||||
|
||||
#ifdef USE_ROCM
|
||||
static Value convert_val_Fp8E4M3FNUZ_to_Fp16(
|
||||
static Value Fp8E4M3FNUZ_to_Fp16_oneValue(
|
||||
Location loc, ConversionPatternRewriter &rewriter, Value v) {
|
||||
auto fp8x2VecTy = vec_ty(i8_ty, 2);
|
||||
Value a = undef(fp8x2VecTy);
|
||||
@@ -686,37 +778,30 @@ static Value convert_val_Fp8E4M3FNUZ_to_Fp16(
|
||||
return bitcast(io, f16_ty);
|
||||
}
|
||||
|
||||
// Fp8E4M3FNUZ -> Fp16 (packed)
|
||||
static SmallVector<Value>
|
||||
Fp8E4M3FNUZ_to_Fp16(Location loc, ConversionPatternRewriter &rewriter,
|
||||
Fp8E4M3FNUZ_to_Fp16_SW(Location loc, ConversionPatternRewriter &rewriter,
|
||||
const SmallVector<Value> &v) {
|
||||
SmallVector<Value> result(2);
|
||||
result[0] = convert_val_Fp8E4M3FNUZ_to_Fp16(loc, rewriter, v[0]);
|
||||
result[1] = convert_val_Fp8E4M3FNUZ_to_Fp16(loc, rewriter, v[1]);
|
||||
|
||||
result[0] = Fp8E4M3FNUZ_to_Fp16_oneValue(loc, rewriter, v[0]);
|
||||
result[1] = Fp8E4M3FNUZ_to_Fp16_oneValue(loc, rewriter, v[1]);
|
||||
return result;
|
||||
}
|
||||
#else
|
||||
const std::string Fp8E4M3FNUZ_to_Fp16 =
|
||||
"{ \n"
|
||||
".reg .b32 a<2>, b<2>; \n" // if input = 0xf1f2f3f4
|
||||
"prmt.b32 a0, 0, $2, 0x5040; \n" // a0 = 0xf300f400
|
||||
"prmt.b32 a1, 0, $2, 0x7060; \n" // a1 = 0xf100f200
|
||||
"lop3.b32 b0, a0, 0x7fff7fff, 0, 0xc0; \n" // b0 = a0 & 0x7fff7fff
|
||||
"lop3.b32 b1, a1, 0x7fff7fff, 0, 0xc0; \n" // (strip sign)
|
||||
"shr.b32 b0, b0, 1; \n" // b0 >>= 1
|
||||
"shr.b32 b1, b1, 1; \n" // shift into fp16 position
|
||||
"add.u32 b0, b0, 0x20002000; \n" // b0.exp += 2**4-2**3
|
||||
// exponent compensate = 8
|
||||
"add.u32 b1, b1, 0x20002000; \n" // b1 += 8<<10 | 8<<10<<16
|
||||
"lop3.b32 $0, b0, 0x80008000, a0, 0xf8; \n" // out0 = b0|(0x80008000&a0)
|
||||
"lop3.b32 $1, b1, 0x80008000, a1, 0xf8; \n" // (restore sign)
|
||||
"}";
|
||||
|
||||
static SmallVector<Value>
|
||||
Fp8E4M3FNUZ_to_Fp16_HW(Location loc, ConversionPatternRewriter &rewriter,
|
||||
const SmallVector<Value> &v) {
|
||||
return convert_val_Fp8_to_Fp16(loc, rewriter, v[0], v[1], "fp8");
|
||||
}
|
||||
|
||||
static ConverterT
|
||||
Fp8E4M3FNUZ_to_Fp16(int computeCapability) {
|
||||
return computeCapability >= 300 ? Fp8E4M3FNUZ_to_Fp16_HW : Fp8E4M3FNUZ_to_Fp16_SW;
|
||||
}
|
||||
#endif
|
||||
|
||||
// Fp16 -> Fp8E4M3 (packed)
|
||||
#ifdef USE_ROCM
|
||||
static Value convert_val_Fp16_to_Fp8E4M3FNUZ(
|
||||
static Value Fp16_to_Fp8E4M3FNUZ_oneValue(
|
||||
Location loc, ConversionPatternRewriter &rewriter, Value v) {
|
||||
auto vi16 = bitcast(v, i16_ty);
|
||||
auto e10 = and_(vi16, int_val(16, 0x7C00));
|
||||
@@ -749,33 +834,25 @@ static Value convert_val_Fp16_to_Fp8E4M3FNUZ(
|
||||
}
|
||||
|
||||
static SmallVector<Value>
|
||||
Fp16_to_Fp8E4M3FNUZ(Location loc, ConversionPatternRewriter &rewriter,
|
||||
Fp16_to_Fp8E4M3FNUZ_SW(Location loc, ConversionPatternRewriter &rewriter,
|
||||
const SmallVector<Value> &v) {
|
||||
|
||||
SmallVector<Value> result(2);
|
||||
result[0] = convert_val_Fp16_to_Fp8E4M3FNUZ(loc, rewriter, v[0]);
|
||||
result[1] = convert_val_Fp16_to_Fp8E4M3FNUZ(loc, rewriter, v[1]);
|
||||
result[0] = Fp16_to_Fp8E4M3FNUZ_oneValue(loc, rewriter, v[0]);
|
||||
result[1] = Fp16_to_Fp8E4M3FNUZ_oneValue(loc, rewriter, v[1]);
|
||||
|
||||
return result;
|
||||
}
|
||||
#else
|
||||
const std::string Fp16_to_Fp8E4M3FNUZ =
|
||||
"{ \n"
|
||||
".reg .b32 a<2>, b<2>; \n" // see Fp8E4M3x4ToFp16x4
|
||||
"sub.u32 a0, $1, 0x20002000; \n" // a0 = input0 - 0x20002000
|
||||
// (compensate offset)
|
||||
"sub.u32 a1, $2, 0x20002000; \n" // a1 = input1 - 0x20002000
|
||||
// (8 << 10 | 8 << 10 << 16)
|
||||
"shl.b32 a0, a0, 1; \n" // a0 <<= 1
|
||||
"shl.b32 a1, a1, 1; \n" // shift into fp8e4 position
|
||||
"lop3.b32 a0, a0, 0x7fff7fff, 0, 0xc0; \n" // a0 &= 0x7fff7fff
|
||||
"lop3.b32 a1, a1, 0x7fff7fff, 0, 0xc0; \n" // (strip sign)
|
||||
"add.u32 a0, a0, 0x00800080; \n" // a0 += 0x00800080
|
||||
"add.u32 a1, a1, 0x00800080; \n" // (round to nearest)
|
||||
"lop3.b32 b0, $1, 0x80008000, a0, 0xea; \n" // b0 = a0|(0x80008000&in0)
|
||||
"lop3.b32 b1, $2, 0x80008000, a1, 0xea; \n" // (restore sign)
|
||||
"prmt.b32 $0, b0, b1, 0x7531; \n" // output = b1b0
|
||||
"}";
|
||||
|
||||
static SmallVector<Value>
|
||||
Fp16_to_Fp8E4M3FNUZ_HW(Location loc, ConversionPatternRewriter &rewriter,
|
||||
const SmallVector<Value> &v) {
|
||||
return convert_val_Fp16_to_Fp8(loc, rewriter, v[0], v[1], "fp8");
|
||||
}
|
||||
|
||||
static ConverterT
|
||||
Fp16_to_Fp8E4M3FNUZ(int computeCapability) {
|
||||
return computeCapability >= 300 ? Fp16_to_Fp8E4M3FNUZ_HW : Fp16_to_Fp8E4M3FNUZ_SW;
|
||||
}
|
||||
#endif
|
||||
|
||||
// WARN: subnormal (0bs0000xxx) are not handled
|
||||
@@ -1144,10 +1221,6 @@ inline SmallVector<Value> packI32(const SmallVector<Value> &inValues,
|
||||
return outValues;
|
||||
}
|
||||
|
||||
typedef std::function<SmallVector<Value>(Location, ConversionPatternRewriter &,
|
||||
const SmallVector<Value> &)>
|
||||
ConverterT;
|
||||
|
||||
static ConverterT makeConverterFromPtx(const std::string &ptxAsm, Type inType,
|
||||
Type outType,
|
||||
const int inVecWidthBits = 32,
|
||||
@@ -1351,12 +1424,7 @@ struct FpToFpOpConversion
|
||||
ConversionPatternRewriter &rewriter,
|
||||
const Value &v) {
|
||||
#ifdef USE_ROCM
|
||||
GCNBuilder builder;
|
||||
auto &cvt = *builder.create("v_cvt_f32_f16");
|
||||
auto res = builder.newOperand("=v");
|
||||
auto operand = builder.newOperand(v, "v");
|
||||
cvt(res, operand);
|
||||
return builder.launch(rewriter, loc, f32_ty, false);
|
||||
return cvtFp16ToFp32(loc, rewriter, v);
|
||||
#else
|
||||
PTXBuilder builder;
|
||||
auto &cvt = *builder.create("cvt.f32.f16");
|
||||
@@ -1402,12 +1470,7 @@ struct FpToFpOpConversion
|
||||
ConversionPatternRewriter &rewriter,
|
||||
const Value &v) {
|
||||
#ifdef USE_ROCM
|
||||
GCNBuilder builder;
|
||||
auto &cvt = *builder.create("v_cvt_f16_f32");
|
||||
auto res = builder.newOperand("=v");
|
||||
auto operand = builder.newOperand(v, "v");
|
||||
cvt(res, operand);
|
||||
return builder.launch(rewriter, loc, f16_ty, false);
|
||||
return cvtFp32ToFp16(loc, rewriter, v);
|
||||
#else
|
||||
PTXBuilder builder;
|
||||
auto &cvt = *builder.create("cvt.rn.f16.f32");
|
||||
@@ -1420,7 +1483,11 @@ struct FpToFpOpConversion
|
||||
|
||||
ConverterT getConversionFunc(Type srcTy, Type dstTy) const {
|
||||
auto F8E4M3B15TyID = TypeID::get<mlir::Float8E4M3B11FNUZType>();
|
||||
#ifdef USE_ROCM
|
||||
auto F8E4M3FNUZTyID = TypeID::get<mlir::Float8E4M3FNUZType>();
|
||||
#else
|
||||
auto F8E4M3TyID = TypeID::get<mlir::Float8E4M3FNUZType>();
|
||||
#endif
|
||||
auto F8E4M3FNTyID = TypeID::get<mlir::Float8E4M3FNType>();
|
||||
auto F8E5M2TyID = TypeID::get<mlir::Float8E5M2Type>();
|
||||
auto F8E5M2FNUZTyID = TypeID::get<mlir::Float8E5M2FNUZType>();
|
||||
@@ -1436,37 +1503,37 @@ struct FpToFpOpConversion
|
||||
// F8 -> F16
|
||||
{{F8E4M3B15TyID, F16TyID}, Fp8E4M3B15_to_Fp16},
|
||||
{{F8E4M3FNTyID, F16TyID}, Fp8E4M3B15x4_to_Fp16},
|
||||
{{F8E4M3FNUZTyID, F16TyID}, Fp8E4M3FNUZ_to_Fp16},
|
||||
#ifdef USE_ROCM
|
||||
{{F8E4M3FNUZTyID, F16TyID}, Fp8E4M3FNUZ_to_Fp16(computeCapability)},
|
||||
{{F8E5M2FNUZTyID, F16TyID}, Fp8E5M2FNUZ_to_Fp16(computeCapability)},
|
||||
{{F8E5M2TyID, F16TyID}, Fp8E5M2_to_Fp16},
|
||||
{{F8E5M2FNUZTyID, F16TyID}, Fp8E5M2FNUZ_to_Fp16},
|
||||
#else
|
||||
{{F8E4M3TyID, F16TyID}, Fp8E4M3Nv_to_Fp16},
|
||||
{{F8E5M2TyID, F16TyID}, Fp8E5M2_to_Fp16(computeCapability >= 90)},
|
||||
#endif
|
||||
|
||||
// F16 -> F8
|
||||
{{F16TyID, F8E4M3FNTyID}, Fp16_to_Fp8E4M3B15x4},
|
||||
#ifdef USE_ROCM
|
||||
{{F16TyID, F8E4M3B15TyID}, Fp16_to_Fp8E4M3B15},
|
||||
{{F16TyID, F8E5M2FNUZTyID}, Fp16_to_Fp8E5M2FNUZ(computeCapability)},
|
||||
{{F16TyID, F8E4M3FNUZTyID}, Fp16_to_Fp8E4M3FNUZ(computeCapability)},
|
||||
{{F16TyID, F8E5M2TyID}, Fp16_to_Fp8E5M2},
|
||||
#else
|
||||
{{F16TyID, F8E4M3B15TyID}, Fp16_to_Fp8E4M3B15(computeCapability >= 80)},
|
||||
#endif
|
||||
{{F16TyID, F8E4M3FNTyID}, Fp16_to_Fp8E4M3B15x4},
|
||||
{{F16TyID, F8E4M3FNUZTyID}, Fp16_to_Fp8E4M3FNUZ},
|
||||
#ifdef USE_ROCM
|
||||
{{F16TyID, F8E5M2TyID}, Fp16_to_Fp8E5M2},
|
||||
{{F16TyID, F8E5M2FNUZTyID}, Fp16_to_Fp8E5M2FNUZ},
|
||||
#else
|
||||
{{F16TyID, F8E4M3TyID}, Fp16_to_Fp8E4M3Nv},
|
||||
{{F16TyID, F8E5M2TyID}, Fp16_to_Fp8E5M2(computeCapability >= 90)},
|
||||
#endif
|
||||
// F8 -> BF16
|
||||
|
||||
// F8 -> BF16
|
||||
#ifdef USE_ROCM
|
||||
{{F8E5M2TyID, BF16TyID}, Fp8E5M2_to_Bf16},
|
||||
{{F8E5M2TyID, BF16TyID}, Fp8E5M2_to_Bf16},
|
||||
#else
|
||||
{{F8E5M2TyID, BF16TyID}, Fp8E5M2_to_Bf16(computeCapability >= 90)},
|
||||
{{F8E5M2TyID, BF16TyID}, Fp8E5M2_to_Bf16(computeCapability >= 90)},
|
||||
{{F8E4M3TyID, BF16TyID}, Fp8E4M3Nv_to_Bf16},
|
||||
#endif
|
||||
// BF16 -> F8
|
||||
|
||||
// BF16 -> F8
|
||||
#ifdef USE_ROCM
|
||||
{{BF16TyID, F8E5M2TyID}, Bf16_to_Fp8E5M2},
|
||||
#else
|
||||
@@ -1477,6 +1544,16 @@ struct FpToFpOpConversion
|
||||
{{F32TyID, F8E5M2TyID}, Fp32_to_Fp8E5M2},
|
||||
#endif
|
||||
};
|
||||
|
||||
std::pair<TypeID, TypeID> key = {srcTy.getTypeID(), dstTy.getTypeID()};
|
||||
if (srcMap.count(key) == 0) {
|
||||
llvm::errs() << "Unsupported conversion from " << srcTy << " to " << dstTy
|
||||
<< "\n";
|
||||
llvm_unreachable("");
|
||||
}
|
||||
#ifdef USE_ROCM
|
||||
return srcMap.lookup(key);
|
||||
#else
|
||||
int inVecWidthBits = 32;
|
||||
int outVecWidthBits = 32;
|
||||
if (srcTy.isFloat8E4M3FNUZ() ||
|
||||
@@ -1490,15 +1567,6 @@ struct FpToFpOpConversion
|
||||
outVecWidthBits = 16;
|
||||
}
|
||||
|
||||
std::pair<TypeID, TypeID> key = {srcTy.getTypeID(), dstTy.getTypeID()};
|
||||
if (srcMap.count(key) == 0) {
|
||||
llvm::errs() << "Unsupported conversion from " << srcTy << " to " << dstTy
|
||||
<< "\n";
|
||||
llvm_unreachable("");
|
||||
}
|
||||
#ifdef USE_ROCM
|
||||
return srcMap.lookup(key);
|
||||
#else
|
||||
if (computeCapability < 90 &&
|
||||
(srcTy.isFloat8E4M3FNUZ() || dstTy.isFloat8E4M3FNUZ())) {
|
||||
llvm::errs() << "Conversion from/to f8e4m3nv is only supported on "
|
||||
@@ -1523,14 +1591,24 @@ struct FpToFpOpConversion
|
||||
size_t numElements = 4;
|
||||
if (srcElementType.isFloat8E4M3FNUZ() ||
|
||||
dstElementType.isFloat8E4M3FNUZ() ||
|
||||
#ifdef USE_ROCM
|
||||
srcElementType.isFloat8E5M2FNUZ() ||
|
||||
dstElementType.isFloat8E5M2FNUZ())
|
||||
#else
|
||||
(computeCapability >= 90 &&
|
||||
(srcElementType.isFloat8E5M2() || dstElementType.isFloat8E5M2()))) {
|
||||
(srcElementType.isFloat8E5M2() || dstElementType.isFloat8E5M2())))
|
||||
#endif
|
||||
{
|
||||
numElements = 2;
|
||||
}
|
||||
bool useFP16IntermediateSrc =
|
||||
#ifdef USE_ROCM
|
||||
srcElementType.isF32();
|
||||
#else
|
||||
srcElementType.isF32() &&
|
||||
!(computeCapability >= 90 &&
|
||||
(dstElementType.isFloat8E4M3FNUZ() || dstElementType.isFloat8E5M2()));
|
||||
#endif
|
||||
bool isDstFP32 = dstElementType.isF32();
|
||||
auto cvtFunc =
|
||||
getConversionFunc(useFP16IntermediateSrc ? f16_ty : srcElementType,
|
||||
|
||||
Reference in New Issue
Block a user