[BACKEND] Prevent double rounding when doing f32 -> fp8 (#2583)

This commit is contained in:
Thomas Raoux
2023-11-01 22:32:16 -07:00
committed by GitHub
parent d0098da7b1
commit 218492cd65

View File

@@ -755,11 +755,11 @@ struct FpToFpOpConversion
return builder.launch(rewriter, loc, i16_ty, false);
}
static Value convertFp32ToFp16(Location loc,
ConversionPatternRewriter &rewriter,
const Value &v) {
static Value convertFp32ToFp16NZ(Location loc,
ConversionPatternRewriter &rewriter,
const Value &v) {
PTXBuilder builder;
auto &cvt = *builder.create("cvt.rn.f16.f32");
auto &cvt = *builder.create("cvt.rz.f16.f32");
auto res = builder.newOperand("=h");
auto operand = builder.newOperand(v, "r");
cvt(res, operand);
@@ -858,7 +858,7 @@ struct FpToFpOpConversion
}
if (useFP16IntermediateSrc)
for (Value &v : inVals)
v = convertFp32ToFp16(loc, rewriter, v);
v = convertFp32ToFp16NZ(loc, rewriter, v);
inVals.resize(numElements, undef(typeConverter->convertType(srcType)));
SmallVector<Value> outVals = cvtFunc(loc, rewriter, inVals);
assert(outVals.size() == inVals.size());