support type conversion between fp8 formats and bf16/fp32 with HW instructions on MI300 (#414)

* Add type conversion between fp8 and bf16/fp32.
This commit is contained in:
Shucai Xiao
2024-01-15 17:14:49 -06:00
committed by GitHub
parent e231c41467
commit 1223f6077a
2 changed files with 241 additions and 109 deletions

View File

@@ -56,9 +56,10 @@ static const std::string Fp16_to_Fp8E5M2(bool hasNativeFP) {
// ROCM utility functions for data type conversion
#ifdef USE_ROCM
// convert fp16 to fp32
static Value cvtFp16ToFp32(Location loc,
ConversionPatternRewriter &rewriter,
const Value &v) {
ConversionPatternRewriter &rewriter,
const Value &v) {
GCNBuilder builder;
auto &cvt = *builder.create("v_cvt_f32_f16");
auto res = builder.newOperand("=v");
@@ -67,9 +68,10 @@ static Value cvtFp16ToFp32(Location loc,
return builder.launch(rewriter, loc, f32_ty, false);
}
// convert fp32 to f16
static Value cvtFp32ToFp16(Location loc,
ConversionPatternRewriter &rewriter,
const Value &v) {
ConversionPatternRewriter &rewriter,
const Value &v) {
GCNBuilder builder;
auto &cvt = *builder.create("v_cvt_f16_f32");
auto res = builder.newOperand("=v");
@@ -78,34 +80,8 @@ static Value cvtFp32ToFp16(Location loc,
return builder.launch(rewriter, loc, f16_ty, false);
}
// Convert two fp16 values to fp8 by first widening each to fp32 and then
// issuing the packed GCN instruction "v_cvt_pk_<fp8_format>_f32".
//   v0, v1      - the two fp16 inputs
//   fp8_format  - instruction-name infix; must be "fp8" or "bf8"
// Returns elements 0 and 1 of the 32-bit instruction result viewed as a
// vector of four i8 values.
static SmallVector<Value> convert_val_Fp16_to_Fp8(
    Location loc, ConversionPatternRewriter &rewriter,
    Value v0, Value v1, const std::string& fp8_format) {
  assert(fp8_format == "fp8" or fp8_format == "bf8");
  std::string mnemonic = "v_cvt_pk_" + fp8_format + "_f32";

  // The packed conversion reads f32 sources, so widen both inputs first.
  auto wide0 = cvtFp16ToFp32(loc, rewriter, v0);
  auto wide1 = cvtFp16ToFp32(loc, rewriter, v1);

  GCNBuilder gcn;
  auto &pkCvt = *gcn.create(mnemonic);
  auto dst = gcn.newOperand("=v");
  auto src0 = gcn.newOperand(wide0, "v");
  auto src1 = gcn.newOperand(wide1, "v");
  pkCvt(dst, src0, src1);
  auto packed = gcn.launch(rewriter, loc, i32_ty, false);

  // Reinterpret the i32 result as 4 x i8 and pick out the first two lanes.
  auto byteVecTy = vec_ty(i8_ty, 4);
  auto bytes = bitcast(packed, byteVecTy);
  SmallVector<Value> ret(2);
  for (int lane = 0; lane < 2; ++lane)
    ret[lane] = extract_element(i8_ty, bytes, i32_val(lane));
  return ret;
}
static SmallVector<Value> convert_val_Fp8_to_Fp16(
// convert fp8 to fp32
static SmallVector<Value> cvtFp8ToFp32(
Location loc, ConversionPatternRewriter &rewriter,
Value v0, Value v1, const std::string& fp8_format) {
assert(fp8_format == "fp8" or fp8_format == "bf8");
@@ -126,15 +102,96 @@ static SmallVector<Value> convert_val_Fp8_to_Fp16(
auto fp32x2VecTy = vec_ty(f32_ty, 2);
auto fp32x2Vec = bitcast(i64v, fp32x2VecTy);
auto f32_0 = extract_element(f32_ty, fp32x2Vec, i32_val(0));
auto f32_1 = extract_element(f32_ty, fp32x2Vec, i32_val(1));
SmallVector<Value> ret(2);
ret[0] = cvtFp32ToFp16(loc, rewriter, f32_0);
ret[1] = cvtFp32ToFp16(loc, rewriter, f32_1);
ret[0] = extract_element(f32_ty, fp32x2Vec, i32_val(0));
ret[1] = extract_element(f32_ty, fp32x2Vec, i32_val(1));
return ret;
}
// Convert two fp32 values to fp8 with the packed GCN instruction
// "v_cvt_pk_<fp8_format>_f32".
//   v0, v1      - the two fp32 inputs, passed as the instruction's sources
//   fp8_format  - instruction-name infix; must be "fp8" or "bf8"
// Returns elements 0 and 1 of the 32-bit instruction result viewed as a
// vector of four i8 values.
static SmallVector<Value> cvtFp32ToFp8(
    Location loc, ConversionPatternRewriter &rewriter,
    Value v0, Value v1, const std::string& fp8_format) {
  assert(fp8_format == "fp8" or fp8_format == "bf8");
  std::string mnemonic = "v_cvt_pk_" + fp8_format + "_f32";

  GCNBuilder gcn;
  auto &pkCvt = *gcn.create(mnemonic);
  auto dst = gcn.newOperand("=v");
  auto src0 = gcn.newOperand(v0, "v");
  auto src1 = gcn.newOperand(v1, "v");
  pkCvt(dst, src0, src1);
  auto packed = gcn.launch(rewriter, loc, i32_ty, false);

  // Reinterpret the i32 result as 4 x i8 and pick out the first two lanes.
  auto byteVecTy = vec_ty(i8_ty, 4);
  auto bytes = bitcast(packed, byteVecTy);
  SmallVector<Value> ret(2);
  for (int lane = 0; lane < 2; ++lane)
    ret[lane] = extract_element(i8_ty, bytes, i32_val(lane));
  return ret;
}
// Convert two fp16 values to fp8 (MI300 formats) by going through fp32.
//   v0, v1      - the two fp16 inputs
//   fp8_format  - must be "fp8" or "bf8"; forwarded to cvtFp32ToFp8
// Returns whatever cvtFp32ToFp8 produces for the widened inputs.
static SmallVector<Value> convert_val_Fp16_to_Fp8(
    Location loc, ConversionPatternRewriter &rewriter,
    Value v0, Value v1, const std::string& fp8_format) {
  assert(fp8_format == "fp8" or fp8_format == "bf8");

  // Step 1: fp16 -> fp32 for each input.
  auto wide0 = cvtFp16ToFp32(loc, rewriter, v0);
  auto wide1 = cvtFp16ToFp32(loc, rewriter, v1);

  // Step 2: fp32 pair -> fp8 pair.
  SmallVector<Value> narrowed =
      cvtFp32ToFp8(loc, rewriter, wide0, wide1, fp8_format);
  return narrowed;
}
// Convert two fp8 values (MI300 formats) to fp16 by going through fp32.
//   v0, v1      - the two fp8 inputs
//   fp8_format  - format selector forwarded unchanged to cvtFp8ToFp32
// Returns a two-element vector holding the fp16 results.
static SmallVector<Value> convert_val_Fp8_to_Fp16(
    Location loc, ConversionPatternRewriter &rewriter,
    Value v0, Value v1, const std::string& fp8_format) {
  // Step 1: fp8 pair -> fp32 pair.
  SmallVector<Value> vals = cvtFp8ToFp32(loc, rewriter, v0, v1, fp8_format);

  // Step 2: narrow each fp32 value to fp16, overwriting it in place.
  for (int i = 0; i < 2; ++i)
    vals[i] = cvtFp32ToFp16(loc, rewriter, vals[i]);
  return vals;
}
#endif
#ifdef USE_ROCM
// fp32 -> fp8e5m2fnuz: forwards the two inputs in `v` to cvtFp32ToFp8 with
// the "bf8" format selector. `v` must contain exactly two values.
static SmallVector<Value> Fp32_to_Fp8E5M2FNUZ(
    Location loc, ConversionPatternRewriter &rewriter,
    const SmallVector<Value>& v) {
  assert (v.size() == 2);
  const Value &first = v[0];
  const Value &second = v[1];
  return cvtFp32ToFp8(loc, rewriter, first, second, "bf8");
}
// fp32 -> fp8e4m3fnuz: forwards the two inputs in `v` to cvtFp32ToFp8 with
// the "fp8" format selector. `v` must contain exactly two values.
static SmallVector<Value> Fp32_to_Fp8E4M3FNUZ(
    Location loc, ConversionPatternRewriter &rewriter,
    const SmallVector<Value>& v) {
  assert (v.size() == 2);
  const Value &first = v[0];
  const Value &second = v[1];
  return cvtFp32ToFp8(loc, rewriter, first, second, "fp8");
}
// fp8e5m2fnuz -> fp32: forwards the two inputs in `v` to cvtFp8ToFp32 with
// the "bf8" format selector. `v` must contain exactly two values.
static SmallVector<Value> Fp8E5M2FNUZ_to_Fp32(
    Location loc, ConversionPatternRewriter &rewriter,
    const SmallVector<Value>& v) {
  assert (v.size() == 2);
  const Value &first = v[0];
  const Value &second = v[1];
  return cvtFp8ToFp32(loc, rewriter, first, second, "bf8");
}
// fp8e4m3fnuz -> fp32: forwards the two inputs in `v` to cvtFp8ToFp32 with
// the "fp8" format selector. `v` must contain exactly two values.
static SmallVector<Value> Fp8E4M3FNUZ_to_Fp32(
    Location loc, ConversionPatternRewriter &rewriter,
    const SmallVector<Value>& v) {
  assert (v.size() == 2);
  const Value &first = v[0];
  const Value &second = v[1];
  return cvtFp8ToFp32(loc, rewriter, first, second, "fp8");
}
#endif
#ifdef USE_ROCM
@@ -187,7 +244,6 @@ ConverterT Fp16_to_Fp8E5M2FNUZ(int computeCapability) {
}
#endif
#ifdef USE_ROCM
static SmallVector<Value>
Fp8E5M2_to_Fp16(Location loc, ConversionPatternRewriter &rewriter,
@@ -231,6 +287,55 @@ static const std::string Fp8E5M2_to_Fp16(bool hasNativeFP) {
}
#endif
// Widen a single bf16 value `v` to fp32.
// ROCm build: done with integer ops (no inline asm).
// Other builds: emits the PTX instruction "cvt.f32.bf16".
static Value convertBf16ToFp32(Location loc,
                               ConversionPatternRewriter &rewriter,
                               const Value &v) {
#ifdef USE_ROCM
  // Move the 16 input bits into the upper half of a 32-bit word, leaving the
  // lower half zero, then reinterpret that word as f32.
  auto bits16 = bitcast(v, i16_ty);
  auto bits32 = zext(i32_ty, bits16);
  auto upperHalf = shl(i32_ty, bits32, i32_val(16));
  return bitcast(upperHalf, f32_ty);
#else
  PTXBuilder ptx;
  auto &cvtInstr = *ptx.create("cvt.f32.bf16");
  auto dst = ptx.newOperand("=r");
  auto src = ptx.newOperand(v, "h");
  cvtInstr(dst, src);
  return ptx.launch(rewriter, loc, f32_ty, false);
#endif
}
// Narrow a single fp32 value `v` to bf16; the result is returned as an i16.
// ROCm build: done with integer ops on the fp32 bit pattern.
// Other builds: emits the PTX instruction "cvt.rn.bf16.f32".
static Value convertFp32ToBf16(Location loc,
ConversionPatternRewriter &rewriter,
const Value &v) {
#ifdef USE_ROCM
auto as_uint32 = bitcast(v, i32_ty);
// (~bits) & 0x7f800000 is zero exactly when every bit of the 0x7f800000
// field (the exponent field, per the variable names) is set in the input.
auto check_exponent = and_(i32_ty, xor_(i32_ty, as_uint32, i32_val(0xffffffff)), i32_val(0x7f800000));
auto exponent_not_all1s = icmp_ne(check_exponent, i32_val(0));
auto exponent_all1s = icmp_eq(check_exponent, i32_val(0));
// Rounding bias: 0x7fff plus bit 16 of the input (the lowest bit that
// survives the final shift), so ties in the discarded low half round
// toward an even upper half.
auto rounded = add(i32_ty, i32_val(0x7fff), and_(i32_ty, lshr(i32_ty, as_uint32, i32_val(16)), i32_val(1)) );
rounded = add(i32_ty, rounded, as_uint32);
// Only apply the rounding when the exponent field is not all ones;
// otherwise keep the input bits unchanged.
auto res = select(exponent_not_all1s, rounded, as_uint32);
// Exponent all ones with a non-zero low half: force bit 16 on so a
// non-zero bit remains after the low 16 bits are dropped.
auto preserve_nan = and_( i1_ty, exponent_all1s, icmp_ne(and_(i32_ty, as_uint32, i32_val(0xffff)), i32_val(0)) );
auto nan = or_(i32_ty, as_uint32, i32_val(0x10000));
res = select(preserve_nan, nan, res);
// Keep only the upper 16 bits as the result.
auto shifted = lshr(i32_ty, res, i32_val(16));
auto truncated = trunc(i16_ty, shifted);
return truncated;
#else
PTXBuilder builder;
auto &cvt = *builder.create("cvt.rn.bf16.f32");
auto res = builder.newOperand("=h");
auto operand = builder.newOperand(v, "r");
cvt(res, operand);
// TODO: This is a hack to get the right type. We should be able to invoke
// the type converter
return builder.launch(rewriter, loc, i16_ty, false);
#endif
}
#ifdef USE_ROCM
static Value Fp8E5M2FNUZ_to_Fp16_oneValue(
Location loc, ConversionPatternRewriter &rewriter, Value v) {
@@ -269,10 +374,10 @@ static Value Fp8E5M2FNUZ_to_Fp16_oneValue(
static SmallVector<Value>
Fp8E5M2FNUZ_to_Fp16_SW(Location loc, ConversionPatternRewriter &rewriter,
const SmallVector<Value> &v) {
SmallVector<Value> result(2);
result[0] = Fp8E5M2FNUZ_to_Fp16_oneValue(loc, rewriter, v[0]);
result[1] = Fp8E5M2FNUZ_to_Fp16_oneValue(loc, rewriter, v[1]);
return result;
SmallVector<Value> res(2);
res[0] = Fp8E5M2FNUZ_to_Fp16_oneValue(loc, rewriter, v[0]);
res[1] = Fp8E5M2FNUZ_to_Fp16_oneValue(loc, rewriter, v[1]);
return res;
}
static SmallVector<Value>
@@ -505,6 +610,53 @@ static const std::string Bf16_to_Fp8E5M2(bool hasNativeFP) {
return ret;
}
#endif
// ROCM type conversion between fp8 and bf16
#ifdef USE_ROCM
// fp8e4m3fnuz -> bf16, via fp32.
// `v` must hold exactly two fp8 values. They are widened with cvtFp8ToFp32
// ("fp8" selector) and the first two results are narrowed to bf16 in place.
static SmallVector<Value>
Fp8E4M3FNUZ_to_Bf16(Location loc, ConversionPatternRewriter &rewriter,
                    const SmallVector<Value> &v) {
  assert(v.size() == 2);
  auto vals = cvtFp8ToFp32(loc, rewriter, v[0], v[1], "fp8");
  for (int i = 0; i < 2; ++i)
    vals[i] = convertFp32ToBf16(loc, rewriter, vals[i]);
  return vals;
}
// bf16 -> fp8e4m3fnuz, via fp32.
// `v` must hold exactly two bf16 values. Each is widened with
// convertBf16ToFp32 and the pair is passed to cvtFp32ToFp8 ("fp8" selector).
static SmallVector<Value>
Bf16_to_Fp8E4M3FNUZ(Location loc, ConversionPatternRewriter &rewriter,
                    const SmallVector<Value> &v) {
  assert(v.size() == 2);
  Value wide0 = convertBf16ToFp32(loc, rewriter, v[0]);
  Value wide1 = convertBf16ToFp32(loc, rewriter, v[1]);
  return cvtFp32ToFp8(loc, rewriter, wide0, wide1, "fp8");
}
// fp8e5m2fnuz -> bf16, via fp32.
// `v` must hold exactly two fp8 values. They are widened with cvtFp8ToFp32
// ("bf8" selector) and the first two results are narrowed to bf16 in place.
static SmallVector<Value>
Fp8E5M2FNUZ_to_Bf16(Location loc, ConversionPatternRewriter &rewriter,
                    const SmallVector<Value> &v) {
  assert(v.size() == 2);
  auto vals = cvtFp8ToFp32(loc, rewriter, v[0], v[1], "bf8");
  for (int i = 0; i < 2; ++i)
    vals[i] = convertFp32ToBf16(loc, rewriter, vals[i]);
  return vals;
}
// bf16 -> fp8e5m2fnuz, via fp32.
// `v` must hold exactly two bf16 values. Each is widened with
// convertBf16ToFp32 and the pair is passed to cvtFp32ToFp8 ("bf8" selector).
static SmallVector<Value>
Bf16_to_Fp8E5M2FNUZ(Location loc, ConversionPatternRewriter &rewriter,
                    const SmallVector<Value> &v) {
  assert(v.size() == 2);
  Value wide0 = convertBf16ToFp32(loc, rewriter, v[0]);
  Value wide1 = convertBf16ToFp32(loc, rewriter, v[1]);
  return cvtFp32ToFp8(loc, rewriter, wide0, wide1, "bf8");
}
#endif
/* ----- FP8E4M3B15 ------ */
// This data-type is a variant of the standard FP8E4M3 format.
// It was designed for fast software conversion to FP16 on
@@ -752,7 +904,6 @@ static const std::string Fp16_to_Fp8E4M3B15x4 =
// Note: when handled by software, this format
// does not handle denormals and has
// more than a single NaN values.
#ifdef USE_ROCM
static Value Fp8E4M3FNUZ_to_Fp16_oneValue(
Location loc, ConversionPatternRewriter &rewriter, Value v) {
@@ -1524,24 +1675,6 @@ struct FpToFpOpConversion
: ElementwiseOpConversionBase(typeConverter, axisAnalysisPass, benefit),
computeCapability(computeCapability) {}
// Widen a single bf16 value `v` to fp32.
// ROCm build: done with integer ops (no inline asm).
// Other builds: emits the PTX instruction "cvt.f32.bf16".
static Value convertBf16ToFp32(Location loc,
                               ConversionPatternRewriter &rewriter,
                               const Value &v) {
#ifdef USE_ROCM
  // Move the 16 input bits into the upper half of a 32-bit word, leaving the
  // lower half zero, then reinterpret that word as f32.
  auto bits16 = bitcast(v, i16_ty);
  auto bits32 = zext(i32_ty, bits16);
  auto upperHalf = shl(i32_ty, bits32, i32_val(16));
  return bitcast(upperHalf, f32_ty);
#else
  PTXBuilder ptx;
  auto &cvtInstr = *ptx.create("cvt.f32.bf16");
  auto dst = ptx.newOperand("=r");
  auto src = ptx.newOperand(v, "h");
  cvtInstr(dst, src);
  return ptx.launch(rewriter, loc, f32_ty, false);
#endif
}
static Value convertFp16ToFp32(Location loc,
ConversionPatternRewriter &rewriter,
const Value &v) {
@@ -1557,37 +1690,6 @@ struct FpToFpOpConversion
#endif
}
// Narrow a single fp32 value `v` to bf16; the result is returned as an i16.
// ROCm build: done with integer ops on the fp32 bit pattern.
// Other builds: emits the PTX instruction "cvt.rn.bf16.f32".
static Value convertFp32ToBf16(Location loc,
ConversionPatternRewriter &rewriter,
const Value &v) {
#ifdef USE_ROCM
auto as_uint32 = bitcast(v, i32_ty);
// (~bits) & 0x7f800000 is zero exactly when every bit of the 0x7f800000
// field (the exponent field, per the variable names) is set in the input.
auto check_exponent = and_(i32_ty, xor_(i32_ty, as_uint32, i32_val(0xffffffff)), i32_val(0x7f800000));
auto exponent_not_all1s = icmp_ne(check_exponent, i32_val(0));
auto exponent_all1s = icmp_eq(check_exponent, i32_val(0));
// Rounding bias: 0x7fff plus bit 16 of the input (the lowest bit that
// survives the final shift), so ties in the discarded low half round
// toward an even upper half.
auto rounded = add(i32_ty, i32_val(0x7fff), and_(i32_ty, lshr(i32_ty, as_uint32, i32_val(16)), i32_val(1)) );
rounded = add(i32_ty, rounded, as_uint32);
// Only apply the rounding when the exponent field is not all ones;
// otherwise keep the input bits unchanged.
auto res = select(exponent_not_all1s, rounded, as_uint32);
// Exponent all ones with a non-zero low half: force bit 16 on so a
// non-zero bit remains after the low 16 bits are dropped.
auto preserve_nan = and_( i1_ty, exponent_all1s, icmp_ne(and_(i32_ty, as_uint32, i32_val(0xffff)), i32_val(0)) );
auto nan = or_(i32_ty, as_uint32, i32_val(0x10000));
res = select(preserve_nan, nan, res);
// Keep only the upper 16 bits as the result.
auto shifted = lshr(i32_ty, res, i32_val(16));
auto truncated = trunc(i16_ty, shifted);
return truncated;
#else
PTXBuilder builder;
auto &cvt = *builder.create("cvt.rn.bf16.f32");
auto res = builder.newOperand("=h");
auto operand = builder.newOperand(v, "r");
cvt(res, operand);
// TODO: This is a hack to get the right type. We should be able to invoke
// the type converter
return builder.launch(rewriter, loc, i16_ty, false);
#endif
}
static Value convertFp32ToFp16NZ(Location loc,
ConversionPatternRewriter &rewriter,
const Value &v) {
@@ -1650,6 +1752,8 @@ struct FpToFpOpConversion
// F8 -> BF16
#ifdef USE_ROCM
{{F8E5M2TyID, BF16TyID}, Fp8E5M2_to_Bf16},
{{F8E5M2FNUZTyID, BF16TyID}, Fp8E5M2FNUZ_to_Bf16},
{{F8E4M3FNUZTyID, BF16TyID}, Fp8E4M3FNUZ_to_Bf16},
#else
{{F8E5M2TyID, BF16TyID}, Fp8E5M2_to_Bf16(computeCapability >= 90)},
{{F8E4M3TyID, BF16TyID}, Fp8E4M3Nv_to_Bf16},
@@ -1658,10 +1762,20 @@ struct FpToFpOpConversion
// BF16 -> F8
#ifdef USE_ROCM
{{BF16TyID, F8E5M2TyID}, Bf16_to_Fp8E5M2},
{{BF16TyID, F8E5M2FNUZTyID}, Bf16_to_Fp8E5M2FNUZ},
{{BF16TyID, F8E4M3FNUZTyID}, Bf16_to_Fp8E4M3FNUZ},
#else
{{BF16TyID, F8E5M2TyID}, Bf16_to_Fp8E5M2(computeCapability >= 90)},
{{BF16TyID, F8E4M3TyID}, Bf16_to_Fp8E4M3Nv},
// F32 -> F8
#endif
// F32 <-> F8
#ifdef USE_ROCM
{{F32TyID, F8E4M3FNUZTyID}, Fp32_to_Fp8E4M3FNUZ},
{{F32TyID, F8E5M2FNUZTyID}, Fp32_to_Fp8E5M2FNUZ},
{{F8E4M3FNUZTyID, F32TyID}, Fp8E4M3FNUZ_to_Fp32},
{{F8E5M2FNUZTyID, F32TyID}, Fp8E5M2FNUZ_to_Fp32},
#else
{{F32TyID, F8E4M3TyID}, Fp32_to_Fp8E4M3Nv},
{{F32TyID, F8E5M2TyID}, Fp32_to_Fp8E5M2},
#endif
@@ -1727,15 +1841,26 @@ struct FpToFpOpConversion
}
bool useFP16IntermediateSrc =
#ifdef USE_ROCM
srcElementType.isF32();
srcElementType.isF32() &&
!(computeCapability >= 300 &&
(dstElementType.isFloat8E4M3FNUZ() || dstElementType.isFloat8E5M2FNUZ()));
#else
srcElementType.isF32() &&
!(computeCapability >= 90 &&
(dstElementType.isFloat8E4M3FNUZ() || dstElementType.isFloat8E5M2()));
#endif
bool isDstFP32 = dstElementType.isF32();
bool useFP16IntermediateDst =
#ifdef USE_ROCM
dstElementType.isF32() &&
!(computeCapability >= 300 &&
(srcElementType.isFloat8E4M3FNUZ() || srcElementType.isFloat8E5M2FNUZ()));
#else
dstElementType.isF32();
#endif
Type srcType = useFP16IntermediateSrc ? f16_ty : srcElementType;
Type dstType = isDstFP32 ? f16_ty : dstElementType;
Type dstType = useFP16IntermediateDst ? f16_ty : dstElementType;
auto cvtFunc = getConversionFunc(srcType, dstType);
SmallVector<Value> inVals;
for (unsigned i = 0; i < std::min(numElements, operands.size()); i++) {
@@ -1748,7 +1873,7 @@ struct FpToFpOpConversion
SmallVector<Value> outVals = cvtFunc(loc, rewriter, inVals);
assert(outVals.size() == inVals.size());
outVals.resize(std::min(numElements, operands.size()));
if (isDstFP32)
if (useFP16IntermediateDst)
for (Value &v : outVals)
v = convertFp16ToFp32(loc, rewriter, v);
// Pack values
@@ -1763,10 +1888,10 @@ template <typename OP>
Value EmitDualBF16ElementwiseOp(Location loc,
ConversionPatternRewriter &rewriter,
MultipleOperandsRange operands) {
auto v0 = FpToFpOpConversion::convertBf16ToFp32(loc, rewriter, operands[0][0]);
auto v1 = FpToFpOpConversion::convertBf16ToFp32(loc, rewriter, operands[0][1]);
auto v0 = convertBf16ToFp32(loc, rewriter, operands[0][0]);
auto v1 = convertBf16ToFp32(loc, rewriter, operands[0][1]);
auto result = rewriter.create<OP>(loc, f32_ty, v0, v1);
return FpToFpOpConversion::convertFp32ToBf16(loc, rewriter, result);
return convertFp32ToBf16(loc, rewriter, result);
}
struct CmpIOpConversion
@@ -2187,7 +2312,7 @@ struct SIToFPOpConversion
return outVals;
} else if (outElemTy.isBF16()) {
auto value = rewriter.create<LLVM::SIToFPOp>(loc, f32_ty, operands[0][0]);
return {FpToFpOpConversion::convertFp32ToBf16(loc, rewriter, value)};
return {convertFp32ToBf16(loc, rewriter, value)};
} else {
return {rewriter.create<LLVM::SIToFPOp>(loc, elemTy, operands[0][0])};
}
@@ -2208,7 +2333,7 @@ struct FPToSIOpConversion
auto inElemTy = getElementType(op.getIn());
if (inElemTy.isBF16()) {
auto value =
FpToFpOpConversion::convertBf16ToFp32(loc, rewriter, operands[0][0]);
convertBf16ToFp32(loc, rewriter, operands[0][0]);
return {rewriter.create<LLVM::FPToSIOp>(loc, elemTy, value)};
} else {
return {rewriter.create<LLVM::FPToSIOp>(loc, elemTy, operands[0][0])};
@@ -2231,8 +2356,7 @@ struct ExtFOpConversion
if (inElemTy.isBF16()) {
auto outElemTy = getElementType(op.getOut());
assert(outElemTy.isF32() && "unsupported conversion");
return {
FpToFpOpConversion::convertBf16ToFp32(loc, rewriter, operands[0][0])};
return {convertBf16ToFp32(loc, rewriter, operands[0][0])};
} else {
return {rewriter.create<LLVM::FPExtOp>(loc, elemTy, operands[0][0])};
}
@@ -2254,8 +2378,7 @@ struct TruncFOpConversion
if (outElemTy.isBF16()) {
auto inElemTy = getElementType(op.getIn());
assert(inElemTy.isF32() && "unsupported conversion");
return {
FpToFpOpConversion::convertFp32ToBf16(loc, rewriter, operands[0][0])};
return {convertFp32ToBf16(loc, rewriter, operands[0][0])};
} else {
return {rewriter.create<LLVM::FPTruncOp>(loc, elemTy, operands[0][0])};
}