mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
support type conversion between fp8 formats and bf16/fp32 with HW instructions on MI300 (#414)
* add type conversion between fp8 and bf16/fp32.
This commit is contained in:
@@ -56,9 +56,10 @@ static const std::string Fp16_to_Fp8E5M2(bool hasNativeFP) {
|
||||
|
||||
// ROCM utility functions for data type conversion
|
||||
#ifdef USE_ROCM
|
||||
// convert fp16 to fp32
|
||||
static Value cvtFp16ToFp32(Location loc,
|
||||
ConversionPatternRewriter &rewriter,
|
||||
const Value &v) {
|
||||
ConversionPatternRewriter &rewriter,
|
||||
const Value &v) {
|
||||
GCNBuilder builder;
|
||||
auto &cvt = *builder.create("v_cvt_f32_f16");
|
||||
auto res = builder.newOperand("=v");
|
||||
@@ -67,9 +68,10 @@ static Value cvtFp16ToFp32(Location loc,
|
||||
return builder.launch(rewriter, loc, f32_ty, false);
|
||||
}
|
||||
|
||||
// convert fp32 to fp16
|
||||
static Value cvtFp32ToFp16(Location loc,
|
||||
ConversionPatternRewriter &rewriter,
|
||||
const Value &v) {
|
||||
ConversionPatternRewriter &rewriter,
|
||||
const Value &v) {
|
||||
GCNBuilder builder;
|
||||
auto &cvt = *builder.create("v_cvt_f16_f32");
|
||||
auto res = builder.newOperand("=v");
|
||||
@@ -78,34 +80,8 @@ static Value cvtFp32ToFp16(Location loc,
|
||||
return builder.launch(rewriter, loc, f16_ty, false);
|
||||
}
|
||||
|
||||
// Packed fp16 -> fp8 conversion.
// Each input goes through cvtFp16ToFp32 first; the two results are then fed
// to a single GCN `v_cvt_pk_<fp8_format>_f32` instruction whose i32 result is
// reinterpreted as <4 x i8>. Elements 0 and 1 of that vector are returned,
// in that order. `fp8_format` must be "fp8" or "bf8" (checked by assert).
static SmallVector<Value> convert_val_Fp16_to_Fp8(
    Location loc, ConversionPatternRewriter &rewriter,
    Value v0, Value v1, const std::string& fp8_format) {
  assert(fp8_format == "fp8" or fp8_format == "bf8");
  std::string mnemonic = "v_cvt_pk_" + fp8_format + "_f32";

  // Widen both inputs before the packed conversion.
  auto widened0 = cvtFp16ToFp32(loc, rewriter, v0);
  auto widened1 = cvtFp16ToFp32(loc, rewriter, v1);

  // Assemble the inline instruction: one "=v" output, two "v" inputs.
  GCNBuilder builder;
  auto &packInstr = *builder.create(mnemonic);
  auto dst = builder.newOperand("=v");
  auto src0 = builder.newOperand(widened0, "v");
  auto src1 = builder.newOperand(widened1, "v");
  packInstr(dst, src0, src1);
  auto packedWord = builder.launch(rewriter, loc, i32_ty, false);

  // View the 32-bit result as four bytes and pick the first two.
  auto byteVecTy = vec_ty(i8_ty, 4);
  auto bytes = bitcast(packedWord, byteVecTy);

  SmallVector<Value> ret(2);
  ret[0] = extract_element(i8_ty, bytes, i32_val(0));
  ret[1] = extract_element(i8_ty, bytes, i32_val(1));
  return ret;
}
|
||||
|
||||
static SmallVector<Value> convert_val_Fp8_to_Fp16(
|
||||
// convert fp8 to fp32
|
||||
static SmallVector<Value> cvtFp8ToFp32(
|
||||
Location loc, ConversionPatternRewriter &rewriter,
|
||||
Value v0, Value v1, const std::string& fp8_format) {
|
||||
assert(fp8_format == "fp8" or fp8_format == "bf8");
|
||||
@@ -126,15 +102,96 @@ static SmallVector<Value> convert_val_Fp8_to_Fp16(
|
||||
auto fp32x2VecTy = vec_ty(f32_ty, 2);
|
||||
auto fp32x2Vec = bitcast(i64v, fp32x2VecTy);
|
||||
|
||||
auto f32_0 = extract_element(f32_ty, fp32x2Vec, i32_val(0));
|
||||
auto f32_1 = extract_element(f32_ty, fp32x2Vec, i32_val(1));
|
||||
|
||||
SmallVector<Value> ret(2);
|
||||
ret[0] = cvtFp32ToFp16(loc, rewriter, f32_0);
|
||||
ret[1] = cvtFp32ToFp16(loc, rewriter, f32_1);
|
||||
ret[0] = extract_element(f32_ty, fp32x2Vec, i32_val(0));
|
||||
ret[1] = extract_element(f32_ty, fp32x2Vec, i32_val(1));
|
||||
|
||||
return ret;
|
||||
}
|
||||
|
||||
// convert fp32 to fp8
|
||||
static SmallVector<Value> cvtFp32ToFp8(
|
||||
Location loc, ConversionPatternRewriter &rewriter,
|
||||
Value v0, Value v1, const std::string& fp8_format) {
|
||||
assert(fp8_format == "fp8" or fp8_format == "bf8");
|
||||
std::string ins_str = "v_cvt_pk_" + fp8_format + "_f32";
|
||||
|
||||
GCNBuilder builder;
|
||||
auto &cvt = *builder.create(ins_str);
|
||||
auto res = builder.newOperand("=v");
|
||||
auto operand0 = builder.newOperand(v0, "v");
|
||||
auto operand1 = builder.newOperand(v1, "v");
|
||||
cvt(res, operand0, operand1);
|
||||
auto fp8x4Vec = builder.launch(rewriter, loc, i32_ty, false);
|
||||
|
||||
auto fp8x4VecTy = vec_ty(i8_ty, 4);
|
||||
auto a1 = bitcast(fp8x4Vec, fp8x4VecTy);
|
||||
|
||||
SmallVector<Value> ret(2);
|
||||
ret[0] = extract_element(i8_ty, a1, i32_val(0));
|
||||
ret[1] = extract_element(i8_ty, a1, i32_val(1));
|
||||
|
||||
return ret;
|
||||
}
|
||||
|
||||
// convert fp16 to fp8 for MI300 format
|
||||
static SmallVector<Value> convert_val_Fp16_to_Fp8(
|
||||
Location loc, ConversionPatternRewriter &rewriter,
|
||||
Value v0, Value v1, const std::string& fp8_format) {
|
||||
assert(fp8_format == "fp8" or fp8_format == "bf8");
|
||||
|
||||
// Convert fp16 to fp32
|
||||
auto f32_0 = cvtFp16ToFp32(loc, rewriter, v0);
|
||||
auto f32_1 = cvtFp16ToFp32(loc, rewriter, v1);
|
||||
|
||||
// Convert fp32 to fp8
|
||||
return cvtFp32ToFp8(loc, rewriter, f32_0, f32_1, fp8_format);
|
||||
}
|
||||
|
||||
// convert fp8 to fp16 for mi300 formats
|
||||
static SmallVector<Value> convert_val_Fp8_to_Fp16(
|
||||
Location loc, ConversionPatternRewriter &rewriter,
|
||||
Value v0, Value v1, const std::string& fp8_format) {
|
||||
|
||||
// Convert fp8 to fp32
|
||||
SmallVector<Value> ret = cvtFp8ToFp32(loc, rewriter, v0, v1, fp8_format);
|
||||
|
||||
// Convert fp32 to fp16
|
||||
ret[0] = cvtFp32ToFp16(loc, rewriter, ret[0]);
|
||||
ret[1] = cvtFp32ToFp16(loc, rewriter, ret[1]);
|
||||
|
||||
return ret;
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef USE_ROCM
|
||||
// fp32 pair -> fp8 (E5M2FNUZ) via cvtFp32ToFp8 with the "bf8" instruction
// variant. Expects exactly two input values.
static SmallVector<Value> Fp32_to_Fp8E5M2FNUZ(
    Location loc, ConversionPatternRewriter &rewriter,
    const SmallVector<Value>& v) {
  assert (v.size() == 2);
  const Value &first = v[0];
  const Value &second = v[1];
  return cvtFp32ToFp8(loc, rewriter, first, second, "bf8");
}
|
||||
|
||||
// fp32 pair -> fp8 (E4M3FNUZ) via cvtFp32ToFp8 with the "fp8" instruction
// variant. Expects exactly two input values.
static SmallVector<Value> Fp32_to_Fp8E4M3FNUZ(
    Location loc, ConversionPatternRewriter &rewriter,
    const SmallVector<Value>& v) {
  assert (v.size() == 2);
  const Value &first = v[0];
  const Value &second = v[1];
  return cvtFp32ToFp8(loc, rewriter, first, second, "fp8");
}
|
||||
|
||||
// fp8 (E5M2FNUZ) pair -> fp32 via cvtFp8ToFp32 with the "bf8" instruction
// variant. Expects exactly two input values.
static SmallVector<Value> Fp8E5M2FNUZ_to_Fp32(
    Location loc, ConversionPatternRewriter &rewriter,
    const SmallVector<Value>& v) {
  assert (v.size() == 2);
  const Value &first = v[0];
  const Value &second = v[1];
  return cvtFp8ToFp32(loc, rewriter, first, second, "bf8");
}
|
||||
|
||||
// fp8 (E4M3FNUZ) pair -> fp32 via cvtFp8ToFp32 with the "fp8" instruction
// variant. Expects exactly two input values.
static SmallVector<Value> Fp8E4M3FNUZ_to_Fp32(
    Location loc, ConversionPatternRewriter &rewriter,
    const SmallVector<Value>& v) {
  assert (v.size() == 2);
  const Value &first = v[0];
  const Value &second = v[1];
  return cvtFp8ToFp32(loc, rewriter, first, second, "fp8");
}
|
||||
#endif
|
||||
|
||||
#ifdef USE_ROCM
|
||||
@@ -187,7 +244,6 @@ ConverterT Fp16_to_Fp8E5M2FNUZ(int computeCapability) {
|
||||
}
|
||||
#endif
|
||||
|
||||
|
||||
#ifdef USE_ROCM
|
||||
static SmallVector<Value>
|
||||
Fp8E5M2_to_Fp16(Location loc, ConversionPatternRewriter &rewriter,
|
||||
@@ -231,6 +287,55 @@ static const std::string Fp8E5M2_to_Fp16(bool hasNativeFP) {
|
||||
}
|
||||
#endif
|
||||
|
||||
// bf16 -> fp32 for a single value.
// ROCm build: reinterprets the input as i16, zero-extends it to i32, shifts
// the bits left by 16 and reinterprets the result as f32.
// Otherwise: emits the PTX instruction `cvt.f32.bf16`.
static Value convertBf16ToFp32(Location loc,
                               ConversionPatternRewriter &rewriter,
                               const Value &v) {
#ifdef USE_ROCM
  auto bits16 = bitcast(v, i16_ty);
  auto bits32 = zext(i32_ty, bits16);
  // Move the 16 payload bits into the upper half of the 32-bit word.
  auto fp32Bits = shl(i32_ty, bits32, i32_val(16));
  return(bitcast(fp32Bits, f32_ty));
#else
  PTXBuilder builder;
  auto &cvtInstr = *builder.create("cvt.f32.bf16");
  auto dst = builder.newOperand("=r");
  auto src = builder.newOperand(v, "h");
  cvtInstr(dst, src);
  return builder.launch(rewriter, loc, f32_ty, false);
#endif
}
|
||||
|
||||
// fp32 -> bf16 for a single value; the result is produced as an i16.
// ROCm build: integer arithmetic on the fp32 bit pattern (see inline notes).
// Otherwise: emits the PTX instruction `cvt.rn.bf16.f32`.
static Value convertFp32ToBf16(Location loc,
                               ConversionPatternRewriter &rewriter,
                               const Value &v) {
#ifdef USE_ROCM
  auto bits = bitcast(v, i32_ty);
  // (~bits) & 0x7f800000 is zero exactly when all bits under that mask are set.
  auto invertedExp = and_(i32_ty, xor_(i32_ty, bits, i32_val(0xffffffff)), i32_val(0x7f800000));
  auto expNotAllOnes = icmp_ne(invertedExp, i32_val(0));
  auto expAllOnes = icmp_eq(invertedExp, i32_val(0));
  // Bias = 0x7fff + (bit 16 of the input); adding it to the input rounds at
  // bit 16 (presumably round-to-nearest-even - confirm).
  auto biased = add(i32_ty, i32_val(0x7fff), and_(i32_ty, lshr(i32_ty, bits, i32_val(16)), i32_val(1)) );
  biased = add(i32_ty, biased, bits);
  // Rounding is applied only when the masked exponent field is not all ones.
  auto picked = select(expNotAllOnes, biased, bits);

  // Exponent all ones with non-zero low 16 bits: force bit 16 on so the kept
  // upper half still has a non-zero bit below the exponent.
  auto keepNan = and_( i1_ty, expAllOnes, icmp_ne(and_(i32_ty, bits, i32_val(0xffff)), i32_val(0)) );
  auto nanBits = or_(i32_ty, bits, i32_val(0x10000));
  picked = select(keepNan, nanBits, picked);

  // Keep only the upper 16 bits.
  auto upperHalf = lshr(i32_ty, picked, i32_val(16));
  auto narrowed = trunc(i16_ty, upperHalf);
  return narrowed;
#else
  PTXBuilder builder;
  auto &cvtInstr = *builder.create("cvt.rn.bf16.f32");
  auto dst = builder.newOperand("=h");
  auto src = builder.newOperand(v, "r");
  cvtInstr(dst, src);
  // TODO: This is a hack to get the right type. We should be able to invoke
  // the type converter
  return builder.launch(rewriter, loc, i16_ty, false);
#endif
}
|
||||
|
||||
#ifdef USE_ROCM
|
||||
static Value Fp8E5M2FNUZ_to_Fp16_oneValue(
|
||||
Location loc, ConversionPatternRewriter &rewriter, Value v) {
|
||||
@@ -269,10 +374,10 @@ static Value Fp8E5M2FNUZ_to_Fp16_oneValue(
|
||||
// Software conversion path (per the `_SW` suffix) for a pair of values:
// applies Fp8E5M2FNUZ_to_Fp16_oneValue to v[0] and v[1] and returns the two
// results in the same order. Reads v[0] and v[1], so `v` must hold at least
// two values.
static SmallVector<Value>
Fp8E5M2FNUZ_to_Fp16_SW(Location loc, ConversionPatternRewriter &rewriter,
                       const SmallVector<Value> &v) {
  // The body previously carried a second, unreachable copy of these
  // statements after the first return; only one copy is kept.
  SmallVector<Value> res(2);
  res[0] = Fp8E5M2FNUZ_to_Fp16_oneValue(loc, rewriter, v[0]);
  res[1] = Fp8E5M2FNUZ_to_Fp16_oneValue(loc, rewriter, v[1]);
  return res;
}
|
||||
|
||||
static SmallVector<Value>
|
||||
@@ -505,6 +610,53 @@ static const std::string Bf16_to_Fp8E5M2(bool hasNativeFP) {
|
||||
return ret;
|
||||
}
|
||||
#endif
|
||||
|
||||
// ROCM type conversion between fp8 and bf16
|
||||
#ifdef USE_ROCM
|
||||
|
||||
// fp8e4m3fnuz to bf16
|
||||
static SmallVector<Value>
|
||||
Fp8E4M3FNUZ_to_Bf16(Location loc, ConversionPatternRewriter &rewriter,
|
||||
const SmallVector<Value> &v) {
|
||||
assert(v.size() == 2);
|
||||
auto ret = cvtFp8ToFp32(loc, rewriter, v[0], v[1], "fp8");
|
||||
ret[0] = convertFp32ToBf16(loc, rewriter, ret[0]);
|
||||
ret[1] = convertFp32ToBf16(loc, rewriter, ret[1]);
|
||||
return ret;
|
||||
}
|
||||
|
||||
// bf16 to fp8e4m3fnuz
|
||||
static SmallVector<Value>
|
||||
Bf16_to_Fp8E4M3FNUZ(Location loc, ConversionPatternRewriter &rewriter,
|
||||
const SmallVector<Value> &v) {
|
||||
assert(v.size() == 2);
|
||||
auto v0 = convertBf16ToFp32(loc, rewriter, v[0]);
|
||||
auto v1 = convertBf16ToFp32(loc, rewriter, v[1]);
|
||||
return cvtFp32ToFp8(loc, rewriter, v0, v1, "fp8");
|
||||
}
|
||||
|
||||
// fp8e5m2fnuz to bf16
|
||||
static SmallVector<Value>
|
||||
Fp8E5M2FNUZ_to_Bf16(Location loc, ConversionPatternRewriter &rewriter,
|
||||
const SmallVector<Value> &v) {
|
||||
assert(v.size() == 2);
|
||||
auto ret = cvtFp8ToFp32(loc, rewriter, v[0], v[1], "bf8");
|
||||
ret[0] = convertFp32ToBf16(loc, rewriter, ret[0]);
|
||||
ret[1] = convertFp32ToBf16(loc, rewriter, ret[1]);
|
||||
return ret;
|
||||
}
|
||||
|
||||
// bf16 to fp8e5m2fnuz
|
||||
static SmallVector<Value>
|
||||
Bf16_to_Fp8E5M2FNUZ(Location loc, ConversionPatternRewriter &rewriter,
|
||||
const SmallVector<Value> &v) {
|
||||
assert(v.size() == 2);
|
||||
auto v0 = convertBf16ToFp32(loc, rewriter, v[0]);
|
||||
auto v1 = convertBf16ToFp32(loc, rewriter, v[1]);
|
||||
return cvtFp32ToFp8(loc, rewriter, v0, v1, "bf8");
|
||||
}
|
||||
#endif
|
||||
|
||||
/* ----- FP8E4M3B15 ------ */
|
||||
// This data-type is a variant of the standard FP8E4M3 format.
|
||||
// It was designed for fast software conversion to FP16 on
|
||||
@@ -752,7 +904,6 @@ static const std::string Fp16_to_Fp8E4M3B15x4 =
|
||||
// Note: when handled by software, this format
|
||||
// does not handle denormals and has
|
||||
// more than a single NaN value.
|
||||
|
||||
#ifdef USE_ROCM
|
||||
static Value Fp8E4M3FNUZ_to_Fp16_oneValue(
|
||||
Location loc, ConversionPatternRewriter &rewriter, Value v) {
|
||||
@@ -1524,24 +1675,6 @@ struct FpToFpOpConversion
|
||||
: ElementwiseOpConversionBase(typeConverter, axisAnalysisPass, benefit),
|
||||
computeCapability(computeCapability) {}
|
||||
|
||||
// bf16 -> fp32 for a single value.
// ROCm build: the input is reinterpreted as i16, zero-extended to i32,
// shifted left by 16 and reinterpreted as f32.
// Otherwise: emits the PTX instruction `cvt.f32.bf16`.
static Value convertBf16ToFp32(Location loc,
                               ConversionPatternRewriter &rewriter,
                               const Value &v) {
#ifdef USE_ROCM
  auto rawHalf = bitcast(v, i16_ty);
  auto rawWord = zext(i32_ty, rawHalf);
  // Place the 16 input bits in the upper half of the word.
  auto movedUp = shl(i32_ty, rawWord, i32_val(16));
  return(bitcast(movedUp, f32_ty));
#else
  PTXBuilder builder;
  auto &instr = *builder.create("cvt.f32.bf16");
  auto out = builder.newOperand("=r");
  auto in = builder.newOperand(v, "h");
  instr(out, in);
  return builder.launch(rewriter, loc, f32_ty, false);
#endif
}
|
||||
|
||||
static Value convertFp16ToFp32(Location loc,
|
||||
ConversionPatternRewriter &rewriter,
|
||||
const Value &v) {
|
||||
@@ -1557,37 +1690,6 @@ struct FpToFpOpConversion
|
||||
#endif
|
||||
}
|
||||
|
||||
// fp32 -> bf16 for a single value; the result is produced as an i16.
// ROCm build: integer arithmetic on the fp32 bit pattern (see inline notes).
// Otherwise: emits the PTX instruction `cvt.rn.bf16.f32`.
static Value convertFp32ToBf16(Location loc,
                               ConversionPatternRewriter &rewriter,
                               const Value &v) {
#ifdef USE_ROCM
  auto raw = bitcast(v, i32_ty);
  // Zero iff every bit selected by 0x7f800000 is set in the input.
  auto expComplement = and_(i32_ty, xor_(i32_ty, raw, i32_val(0xffffffff)), i32_val(0x7f800000));
  auto expHasZeroBit = icmp_ne(expComplement, i32_val(0));
  auto expIsSaturated = icmp_eq(expComplement, i32_val(0));
  // 0x7fff plus bit 16 of the input, then added to the input itself
  // (presumably a round-to-nearest-even bias - confirm).
  auto withBias = add(i32_ty, i32_val(0x7fff), and_(i32_ty, lshr(i32_ty, raw, i32_val(16)), i32_val(1)) );
  withBias = add(i32_ty, withBias, raw);
  // Inputs with a saturated exponent field bypass the rounding add.
  auto selected = select(expHasZeroBit, withBias, raw);

  // Saturated exponent and non-zero low 16 bits: set bit 16 so the retained
  // upper half keeps a non-zero bit below the exponent.
  auto mustKeepNan = and_( i1_ty, expIsSaturated, icmp_ne(and_(i32_ty, raw, i32_val(0xffff)), i32_val(0)) );
  auto nanPattern = or_(i32_ty, raw, i32_val(0x10000));
  selected = select(mustKeepNan, nanPattern, selected);

  // Drop the low 16 bits and narrow to i16.
  auto highHalf = lshr(i32_ty, selected, i32_val(16));
  auto asI16 = trunc(i16_ty, highHalf);
  return asI16;
#else
  PTXBuilder builder;
  auto &instr = *builder.create("cvt.rn.bf16.f32");
  auto out = builder.newOperand("=h");
  auto in = builder.newOperand(v, "r");
  instr(out, in);
  // TODO: This is a hack to get the right type. We should be able to invoke
  // the type converter
  return builder.launch(rewriter, loc, i16_ty, false);
#endif
}
|
||||
|
||||
static Value convertFp32ToFp16NZ(Location loc,
|
||||
ConversionPatternRewriter &rewriter,
|
||||
const Value &v) {
|
||||
@@ -1650,6 +1752,8 @@ struct FpToFpOpConversion
|
||||
// F8 -> BF16
|
||||
#ifdef USE_ROCM
|
||||
{{F8E5M2TyID, BF16TyID}, Fp8E5M2_to_Bf16},
|
||||
{{F8E5M2FNUZTyID, BF16TyID}, Fp8E5M2FNUZ_to_Bf16},
|
||||
{{F8E4M3FNUZTyID, BF16TyID}, Fp8E4M3FNUZ_to_Bf16},
|
||||
#else
|
||||
{{F8E5M2TyID, BF16TyID}, Fp8E5M2_to_Bf16(computeCapability >= 90)},
|
||||
{{F8E4M3TyID, BF16TyID}, Fp8E4M3Nv_to_Bf16},
|
||||
@@ -1658,10 +1762,20 @@ struct FpToFpOpConversion
|
||||
// BF16 -> F8
|
||||
#ifdef USE_ROCM
|
||||
{{BF16TyID, F8E5M2TyID}, Bf16_to_Fp8E5M2},
|
||||
{{BF16TyID, F8E5M2FNUZTyID}, Bf16_to_Fp8E5M2FNUZ},
|
||||
{{BF16TyID, F8E4M3FNUZTyID}, Bf16_to_Fp8E4M3FNUZ},
|
||||
#else
|
||||
{{BF16TyID, F8E5M2TyID}, Bf16_to_Fp8E5M2(computeCapability >= 90)},
|
||||
{{BF16TyID, F8E4M3TyID}, Bf16_to_Fp8E4M3Nv},
|
||||
// F32 -> F8
|
||||
#endif
|
||||
|
||||
// F32 <-> F8
|
||||
#ifdef USE_ROCM
|
||||
{{F32TyID, F8E4M3FNUZTyID}, Fp32_to_Fp8E4M3FNUZ},
|
||||
{{F32TyID, F8E5M2FNUZTyID}, Fp32_to_Fp8E5M2FNUZ},
|
||||
{{F8E4M3FNUZTyID, F32TyID}, Fp8E4M3FNUZ_to_Fp32},
|
||||
{{F8E5M2FNUZTyID, F32TyID}, Fp8E5M2FNUZ_to_Fp32},
|
||||
#else
|
||||
{{F32TyID, F8E4M3TyID}, Fp32_to_Fp8E4M3Nv},
|
||||
{{F32TyID, F8E5M2TyID}, Fp32_to_Fp8E5M2},
|
||||
#endif
|
||||
@@ -1727,15 +1841,26 @@ struct FpToFpOpConversion
|
||||
}
|
||||
bool useFP16IntermediateSrc =
|
||||
#ifdef USE_ROCM
|
||||
srcElementType.isF32();
|
||||
srcElementType.isF32() &&
|
||||
!(computeCapability >= 300 &&
|
||||
(dstElementType.isFloat8E4M3FNUZ() || dstElementType.isFloat8E5M2FNUZ()));
|
||||
#else
|
||||
srcElementType.isF32() &&
|
||||
!(computeCapability >= 90 &&
|
||||
(dstElementType.isFloat8E4M3FNUZ() || dstElementType.isFloat8E5M2()));
|
||||
#endif
|
||||
bool isDstFP32 = dstElementType.isF32();
|
||||
|
||||
bool useFP16IntermediateDst =
|
||||
#ifdef USE_ROCM
|
||||
dstElementType.isF32() &&
|
||||
!(computeCapability >= 300 &&
|
||||
(srcElementType.isFloat8E4M3FNUZ() || srcElementType.isFloat8E5M2FNUZ()));
|
||||
#else
|
||||
dstElementType.isF32();
|
||||
#endif
|
||||
|
||||
Type srcType = useFP16IntermediateSrc ? f16_ty : srcElementType;
|
||||
Type dstType = isDstFP32 ? f16_ty : dstElementType;
|
||||
Type dstType = useFP16IntermediateDst ? f16_ty : dstElementType;
|
||||
auto cvtFunc = getConversionFunc(srcType, dstType);
|
||||
SmallVector<Value> inVals;
|
||||
for (unsigned i = 0; i < std::min(numElements, operands.size()); i++) {
|
||||
@@ -1748,7 +1873,7 @@ struct FpToFpOpConversion
|
||||
SmallVector<Value> outVals = cvtFunc(loc, rewriter, inVals);
|
||||
assert(outVals.size() == inVals.size());
|
||||
outVals.resize(std::min(numElements, operands.size()));
|
||||
if (isDstFP32)
|
||||
if (useFP16IntermediateDst)
|
||||
for (Value &v : outVals)
|
||||
v = convertFp16ToFp32(loc, rewriter, v);
|
||||
// Pack values
|
||||
@@ -1763,10 +1888,10 @@ template <typename OP>
|
||||
Value EmitDualBF16ElementwiseOp(Location loc,
|
||||
ConversionPatternRewriter &rewriter,
|
||||
MultipleOperandsRange operands) {
|
||||
auto v0 = FpToFpOpConversion::convertBf16ToFp32(loc, rewriter, operands[0][0]);
|
||||
auto v1 = FpToFpOpConversion::convertBf16ToFp32(loc, rewriter, operands[0][1]);
|
||||
auto v0 = convertBf16ToFp32(loc, rewriter, operands[0][0]);
|
||||
auto v1 = convertBf16ToFp32(loc, rewriter, operands[0][1]);
|
||||
auto result = rewriter.create<OP>(loc, f32_ty, v0, v1);
|
||||
return FpToFpOpConversion::convertFp32ToBf16(loc, rewriter, result);
|
||||
return convertFp32ToBf16(loc, rewriter, result);
|
||||
}
|
||||
|
||||
struct CmpIOpConversion
|
||||
@@ -2187,7 +2312,7 @@ struct SIToFPOpConversion
|
||||
return outVals;
|
||||
} else if (outElemTy.isBF16()) {
|
||||
auto value = rewriter.create<LLVM::SIToFPOp>(loc, f32_ty, operands[0][0]);
|
||||
return {FpToFpOpConversion::convertFp32ToBf16(loc, rewriter, value)};
|
||||
return {convertFp32ToBf16(loc, rewriter, value)};
|
||||
} else {
|
||||
return {rewriter.create<LLVM::SIToFPOp>(loc, elemTy, operands[0][0])};
|
||||
}
|
||||
@@ -2208,7 +2333,7 @@ struct FPToSIOpConversion
|
||||
auto inElemTy = getElementType(op.getIn());
|
||||
if (inElemTy.isBF16()) {
|
||||
auto value =
|
||||
FpToFpOpConversion::convertBf16ToFp32(loc, rewriter, operands[0][0]);
|
||||
convertBf16ToFp32(loc, rewriter, operands[0][0]);
|
||||
return {rewriter.create<LLVM::FPToSIOp>(loc, elemTy, value)};
|
||||
} else {
|
||||
return {rewriter.create<LLVM::FPToSIOp>(loc, elemTy, operands[0][0])};
|
||||
@@ -2231,8 +2356,7 @@ struct ExtFOpConversion
|
||||
if (inElemTy.isBF16()) {
|
||||
auto outElemTy = getElementType(op.getOut());
|
||||
assert(outElemTy.isF32() && "unsupported conversion");
|
||||
return {
|
||||
FpToFpOpConversion::convertBf16ToFp32(loc, rewriter, operands[0][0])};
|
||||
return {convertBf16ToFp32(loc, rewriter, operands[0][0])};
|
||||
} else {
|
||||
return {rewriter.create<LLVM::FPExtOp>(loc, elemTy, operands[0][0])};
|
||||
}
|
||||
@@ -2254,8 +2378,7 @@ struct TruncFOpConversion
|
||||
if (outElemTy.isBF16()) {
|
||||
auto inElemTy = getElementType(op.getIn());
|
||||
assert(inElemTy.isF32() && "unsupported conversion");
|
||||
return {
|
||||
FpToFpOpConversion::convertFp32ToBf16(loc, rewriter, operands[0][0])};
|
||||
return {convertFp32ToBf16(loc, rewriter, operands[0][0])};
|
||||
} else {
|
||||
return {rewriter.create<LLVM::FPTruncOp>(loc, elemTy, operands[0][0])};
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user