mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[BACKEND] Support MMA V3 with float16 accumulator (#2049)
Also fixes a bug exposed in the convertLayout lowering for float16. We shouldn't use cvt.pack.sat.u16.s32 to pack 16-bit values, since that instruction needs to take a 32-bit register. Using it also prevented optimization at the LLVM IR level.
This commit is contained in:
@@ -697,6 +697,15 @@ private:
|
||||
return success();
|
||||
}
|
||||
|
||||
// Pack two 16-bit values into a 32-bit register.
|
||||
static Value pack16bitsTo32(ConversionPatternRewriter &rewriter, Location loc,
|
||||
Value hb, Value lb) {
|
||||
hb = zext(i32_ty, bitcast(hb, i16_ty));
|
||||
lb = zext(i32_ty, bitcast(lb, i16_ty));
|
||||
Value pack = or_(lb, shl(hb, i32_val(16)));
|
||||
return pack;
|
||||
}
|
||||
|
||||
// blocked -> shared.
|
||||
// Swizzling in shared memory to avoid bank conflict. Normally used for
|
||||
// A/B operands of dots.
|
||||
@@ -768,14 +777,14 @@ private:
|
||||
numElemsPerSwizzlingRow, true);
|
||||
|
||||
Value addr = gep(elemPtrTy, smemBase, offset);
|
||||
Value data0 = rewriter.create<triton::nvgpu::CvtPackOp>(
|
||||
loc, i32_ty, inVals[elemIdx + 1], inVals[elemIdx + 0]);
|
||||
Value data1 = rewriter.create<triton::nvgpu::CvtPackOp>(
|
||||
loc, i32_ty, inVals[elemIdx + 3], inVals[elemIdx + 2]);
|
||||
Value data2 = rewriter.create<triton::nvgpu::CvtPackOp>(
|
||||
loc, i32_ty, inVals[elemIdx + 5], inVals[elemIdx + 4]);
|
||||
Value data3 = rewriter.create<triton::nvgpu::CvtPackOp>(
|
||||
loc, i32_ty, inVals[elemIdx + 7], inVals[elemIdx + 6]);
|
||||
Value data0 = pack16bitsTo32(rewriter, loc, inVals[elemIdx + 1],
|
||||
inVals[elemIdx + 0]);
|
||||
Value data1 = pack16bitsTo32(rewriter, loc, inVals[elemIdx + 3],
|
||||
inVals[elemIdx + 2]);
|
||||
Value data2 = pack16bitsTo32(rewriter, loc, inVals[elemIdx + 5],
|
||||
inVals[elemIdx + 4]);
|
||||
Value data3 = pack16bitsTo32(rewriter, loc, inVals[elemIdx + 7],
|
||||
inVals[elemIdx + 6]);
|
||||
|
||||
rewriter.create<triton::nvgpu::StoreMatrixOp>(
|
||||
loc, bitcast(addr, ptrI8SharedTy),
|
||||
|
||||
Reference in New Issue
Block a user