[BACKEND] Support MMA V3 with float16 accumulator (#2049)

Also fixes a bug exposed in convertLayout lowering for float16. We
shouldn't be using cvt.pack.sat.u16.s32 to pack 16bits values as this
needs to take a 32bits register. Also this prevented optimization at
llvm ir level.
This commit is contained in:
Thomas
2023-08-07 15:55:44 -07:00
committed by GitHub
parent 521cfae44d
commit 98523bcc48
6 changed files with 79 additions and 72 deletions

View File

@@ -697,6 +697,15 @@ private:
return success();
}
// Pack two 16-bit values into a 32-bit register.
static Value pack16bitsTo32(ConversionPatternRewriter &rewriter, Location loc,
Value hb, Value lb) {
hb = zext(i32_ty, bitcast(hb, i16_ty));
lb = zext(i32_ty, bitcast(lb, i16_ty));
Value pack = or_(lb, shl(hb, i32_val(16)));
return pack;
}
// blocked -> shared.
// Swizzling in shared memory to avoid bank conflict. Normally used for
// A/B operands of dots.
@@ -768,14 +777,14 @@ private:
numElemsPerSwizzlingRow, true);
Value addr = gep(elemPtrTy, smemBase, offset);
Value data0 = rewriter.create<triton::nvgpu::CvtPackOp>(
loc, i32_ty, inVals[elemIdx + 1], inVals[elemIdx + 0]);
Value data1 = rewriter.create<triton::nvgpu::CvtPackOp>(
loc, i32_ty, inVals[elemIdx + 3], inVals[elemIdx + 2]);
Value data2 = rewriter.create<triton::nvgpu::CvtPackOp>(
loc, i32_ty, inVals[elemIdx + 5], inVals[elemIdx + 4]);
Value data3 = rewriter.create<triton::nvgpu::CvtPackOp>(
loc, i32_ty, inVals[elemIdx + 7], inVals[elemIdx + 6]);
Value data0 = pack16bitsTo32(rewriter, loc, inVals[elemIdx + 1],
inVals[elemIdx + 0]);
Value data1 = pack16bitsTo32(rewriter, loc, inVals[elemIdx + 3],
inVals[elemIdx + 2]);
Value data2 = pack16bitsTo32(rewriter, loc, inVals[elemIdx + 5],
inVals[elemIdx + 4]);
Value data3 = pack16bitsTo32(rewriter, loc, inVals[elemIdx + 7],
inVals[elemIdx + 6]);
rewriter.create<triton::nvgpu::StoreMatrixOp>(
loc, bitcast(addr, ptrI8SharedTy),