mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[BACKEND] Prevent double rounding when doing f32 -> fp8 (#2583)
This commit is contained in:
@@ -755,11 +755,11 @@ struct FpToFpOpConversion
|
||||
return builder.launch(rewriter, loc, i16_ty, false);
|
||||
}
|
||||
|
||||
static Value convertFp32ToFp16(Location loc,
|
||||
ConversionPatternRewriter &rewriter,
|
||||
const Value &v) {
|
||||
static Value convertFp32ToFp16NZ(Location loc,
|
||||
ConversionPatternRewriter &rewriter,
|
||||
const Value &v) {
|
||||
PTXBuilder builder;
|
||||
auto &cvt = *builder.create("cvt.rn.f16.f32");
|
||||
auto &cvt = *builder.create("cvt.rz.f16.f32");
|
||||
auto res = builder.newOperand("=h");
|
||||
auto operand = builder.newOperand(v, "r");
|
||||
cvt(res, operand);
|
||||
@@ -858,7 +858,7 @@ struct FpToFpOpConversion
|
||||
}
|
||||
if (useFP16IntermediateSrc)
|
||||
for (Value &v : inVals)
|
||||
v = convertFp32ToFp16(loc, rewriter, v);
|
||||
v = convertFp32ToFp16NZ(loc, rewriter, v);
|
||||
inVals.resize(numElements, undef(typeConverter->convertType(srcType)));
|
||||
SmallVector<Value> outVals = cvtFunc(loc, rewriter, inVals);
|
||||
assert(outVals.size() == inVals.size());
|
||||
|
||||
Reference in New Issue
Block a user