mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[BACKEND] Fix small matmul dot (#1463)
https://github.com/openai/triton/issues/1449 In theory, we might be able to support even 8x8 dot if we also wrap around `cOff`.
This commit is contained in:
@@ -106,25 +106,29 @@ MMA16816SmemLoader::computeLdmatrixMatOffs(Value warpId, Value lane,
|
||||
Value kMatArr = kOrder == 1 ? s1 : s0;
|
||||
Value nkMatArr = kOrder == 1 ? s0 : s1;
|
||||
|
||||
// matrix coordinate inside a CTA, the matrix layout is [2x2wpt] for A and
|
||||
// [2wptx2] for B. e.g. Setting wpt=3, The data layout for A(kOrder=1) is
|
||||
// |0 0 1 1 2 2| -> 0,1,2 are the warpids
|
||||
// |0 0 1 1 2 2|
|
||||
//
|
||||
// for B(kOrder=0) is
|
||||
// |0 0| -> 0,1,2 are the warpids
|
||||
// |1 1|
|
||||
// |2 2|
|
||||
// Matrix coordinates inside a CTA,
|
||||
// the matrix layout is [2wpt[0], 2] for A and [2, 2wpt[1]] for B.
|
||||
// e.g., Setting wpt=4, the data layout for A(kOrder=1) is
|
||||
// |0 0| -> 0,1,2,3 are the warpids
|
||||
// |0 0|
|
||||
// |1 1|
|
||||
// |1 1|
|
||||
// |2 2|
|
||||
// |2 2|
|
||||
// |3 3|
|
||||
// |3 3|
|
||||
//
|
||||
// for B(kOrder=0) is
|
||||
// |0 1 2 3 0 1 2 3| -> 0,1,2,3 are the warpids
|
||||
// |0 1 2 3 0 1 2 3|
|
||||
// Note, for each warp, it handles a 2x2 matrices, that is the coordinate
|
||||
// address (s0,s1) annotates.
|
||||
|
||||
Value matOff[2];
|
||||
matOff[kOrder ^ 1] =
|
||||
add(mul(warpId, i32_val(warpOffStride)), // warp offset
|
||||
mul(nkMatArr, i32_val(matArrStride))); // matrix offset inside a warp
|
||||
add(mul(warpId, i32_val(warpOffStride)), // warp offset (kOrder=1)
|
||||
mul(nkMatArr,
|
||||
i32_val(matArrStride))); // matrix offset inside a warp (kOrder=1)
|
||||
matOff[kOrder] = kMatArr;
|
||||
|
||||
// Physical offset (before swizzling)
|
||||
@@ -138,7 +142,13 @@ MMA16816SmemLoader::computeLdmatrixMatOffs(Value warpId, Value lane,
|
||||
|
||||
SmallVector<Value> offs(numPtrs);
|
||||
Value phase = urem(udiv(sOffInMat, i32_val(perPhase)), i32_val(maxPhase));
|
||||
Value sOff = add(sOffInMat, mul(sMatOff, i32_val(sMatShape)));
|
||||
// To prevent out-of-bound access of B when wpt * 16 > tile_size.
|
||||
// In such a case, we need to wrap around the offset of B.
|
||||
// |0 1 2 3 0 1 2 3| -> | 0(0) 1(1) 2(2) 3(3) |
|
||||
// |0 1 2 3 0 1 2 3| | 0(0) 1(1) 2(2) 3(3) |
|
||||
// ~~~~~~~ out-of-bound access
|
||||
Value sOff = urem(add(sOffInMat, mul(sMatOff, i32_val(sMatShape))),
|
||||
i32_val(tileShape[order[1]]));
|
||||
for (int i = 0; i < numPtrs; ++i) {
|
||||
Value cMatOffI = add(cMatOff, i32_val(i * pLoadStrideInMat));
|
||||
cMatOffI = xor_(cMatOffI, phase);
|
||||
@@ -631,12 +641,6 @@ Value loadB(ConversionPatternRewriter &rewriter, Location loc, Value tensor,
|
||||
SmallVector<int64_t> shape(tensorTy.getShape().begin(),
|
||||
tensorTy.getShape().end());
|
||||
|
||||
// TODO[Superjomn]: transB cannot be accessed in ConvertLayoutOp.
|
||||
bool transB = false;
|
||||
if (transB) {
|
||||
std::swap(shape[0], shape[1]);
|
||||
}
|
||||
|
||||
int mmaInstrM = 16, mmaInstrN = 8, mmaInstrK = 4 * 64 / bitwidth;
|
||||
int matShapeM = 8, matShapeN = 8, matShapeK = 2 * 64 / bitwidth;
|
||||
|
||||
|
||||
@@ -1283,6 +1283,7 @@ def test_permute(dtype_str, shape, perm, device='cuda'):
|
||||
[64, 128, 128, 4],
|
||||
[32, 128, 64, 2],
|
||||
[64, 64, 32, 4],
|
||||
[32, 32, 128, 16],
|
||||
[128, 128, 64, 2],
|
||||
[64, 128, 128, 2]]
|
||||
for allow_tf32 in [True]
|
||||
@@ -1436,7 +1437,7 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, o
|
||||
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
|
||||
# make sure ld/st are vectorized
|
||||
ptx = pgm.asm['ptx']
|
||||
if K > 16 or N > 16 or M > 16:
|
||||
if (K > 16 or N > 16 or M > 16) and (M * N // (num_warps * 32) >= 4):
|
||||
# XXX: skip small sizes because they are not vectorized
|
||||
assert 'ld.global.v4' in ptx
|
||||
assert 'st.global.v4' in ptx
|
||||
|
||||
Reference in New Issue
Block a user