[BACKEND] Fix small matmul dot (#1463)

https://github.com/openai/triton/issues/1449

In theory, we might be able to support even an 8x8 dot if we also wrap
around `cOff`.
Author: Keren Zhou
Date: 2023-04-01 19:05:05 -07:00
Committed by: GitHub
Parent: 801bb9d3b5
Commit: 0855cacdd8
2 changed files with 24 additions and 19 deletions
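Before the file diffs, here is a minimal sketch (not part of the commit) of the small-dot case this change enables. The shape and num_warps mirror the new [32, 32, 128, 16] entry added to test_dot below; the kernel and driver code are an illustrative assumption built on the public triton/torch APIs, not code from the repository.

import torch
import triton
import triton.language as tl

@triton.jit
def small_dot_kernel(a_ptr, b_ptr, c_ptr,
                     M: tl.constexpr, N: tl.constexpr, K: tl.constexpr):
    # One program computes the whole 32x32 output tile; with num_warps=16
    # there are more warps than the narrow B tile can cover, which is the
    # situation the wrapped offset in this commit handles.
    offs_m = tl.arange(0, M)
    offs_n = tl.arange(0, N)
    offs_k = tl.arange(0, K)
    a = tl.load(a_ptr + offs_m[:, None] * K + offs_k[None, :])
    b = tl.load(b_ptr + offs_k[:, None] * N + offs_n[None, :])
    c = tl.dot(a, b)
    tl.store(c_ptr + offs_m[:, None] * N + offs_n[None, :], c)

M, N, K = 32, 32, 128
a = torch.randn((M, K), device='cuda', dtype=torch.float16)
b = torch.randn((K, N), device='cuda', dtype=torch.float16)
c = torch.empty((M, N), device='cuda', dtype=torch.float32)
small_dot_kernel[(1,)](a, b, c, M, N, K, num_warps=16)
torch.testing.assert_close(c, a.float() @ b.float(), rtol=1e-2, atol=1e-2)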


@@ -106,25 +106,29 @@ MMA16816SmemLoader::computeLdmatrixMatOffs(Value warpId, Value lane,
Value kMatArr = kOrder == 1 ? s1 : s0;
Value nkMatArr = kOrder == 1 ? s0 : s1;
// matrix coordinate inside a CTA, the matrix layout is [2x2wpt] for A and
// [2wptx2] for B. e.g. Setting wpt=3, The data layout for A(kOrder=1) is
// |0 0 1 1 2 2| -> 0,1,2 are the warpids
// |0 0 1 1 2 2|
//
// for B(kOrder=0) is
// |0 0| -> 0,1,2 are the warpids
// |1 1|
// |2 2|
// Matrix coordinates inside a CTA,
// the matrix layout is [2wpt[0], 2] for A and [2, 2wpt[1]] for B.
// e.g., Setting wpt=4, the data layout for A(kOrder=1) is
// |0 0| -> 0,1,2,3 are the warpids
// |0 0|
// |1 1|
// |1 1|
// |2 2|
// |2 2|
// |3 3|
// |3 3|
//
// for B(kOrder=0) is
// |0 1 2 3 0 1 2 3| -> 0,1,2,3 are the warpids
// |0 1 2 3 0 1 2 3|
// Note: each warp handles a 2x2 block of mats, which is what the
// coordinates (s0, s1) index.
Value matOff[2];
matOff[kOrder ^ 1] =
add(mul(warpId, i32_val(warpOffStride)), // warp offset
mul(nkMatArr, i32_val(matArrStride))); // matrix offset inside a warp
add(mul(warpId, i32_val(warpOffStride)), // warp offset (kOrder=1)
mul(nkMatArr,
i32_val(matArrStride))); // matrix offset inside a warp (kOrder=1)
matOff[kOrder] = kMatArr;
// Physical offset (before swizzling)
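For reference, the A(kOrder=1) mapping in the comment above can be reproduced with a few lines of Python. This is only an illustrative model: the concrete warpOffStride/matArrStride values below are assumptions chosen to match that picture, not values read out of the loader.

wpt = 4                                  # warps along M, as in the comment
warp_off_stride, mat_arr_stride = 2, 1   # assumed strides, in units of mats
rows = []
for warp_id in range(wpt):
    for nk_mat_arr in range(2):          # each warp owns two mats along M
        mat_off = warp_id * warp_off_stride + nk_mat_arr * mat_arr_stride
        rows.append((mat_off, warp_id))
for mat_off, warp_id in sorted(rows):
    print(f"mat {mat_off}: |{warp_id} {warp_id}|")
# prints |0 0|, |0 0|, |1 1|, |1 1|, ..., |3 3|, |3 3| -- the A column
# layout for wpt=4 shown in the comment above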
@@ -138,7 +142,13 @@ MMA16816SmemLoader::computeLdmatrixMatOffs(Value warpId, Value lane,
SmallVector<Value> offs(numPtrs);
Value phase = urem(udiv(sOffInMat, i32_val(perPhase)), i32_val(maxPhase));
Value sOff = add(sOffInMat, mul(sMatOff, i32_val(sMatShape)));
// To prevent out-of-bounds accesses to B when wpt * 16 > tile_size,
// we need to wrap the offset of B around the tile.
// |0 1 2 3 0 1 2 3| -> | 0(0) 1(1) 2(2) 3(3) |
// |0 1 2 3 0 1 2 3| | 0(0) 1(1) 2(2) 3(3) |
// ~~~~~~~ out-of-bound access
Value sOff = urem(add(sOffInMat, mul(sMatOff, i32_val(sMatShape))),
i32_val(tileShape[order[1]]));
for (int i = 0; i < numPtrs; ++i) {
Value cMatOffI = add(cMatOff, i32_val(i * pLoadStrideInMat));
cMatOffI = xor_(cMatOffI, phase);
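The effect of the new urem can be modelled in isolation. The numbers below are assumptions chosen to match the B(kOrder=0), wpt=4 picture in the comment above, not values taken from the loader.

s_mat_shape = 8      # elements covered by one mat along the non-K dim of B
tile_n = 32          # only 4 mats fit in the tile, but 2 * wpt = 8 are assigned
for s_mat_off in range(8):
    unwrapped = s_mat_off * s_mat_shape      # 0, 8, 16, ..., 56
    wrapped = unwrapped % tile_n             # 0, 8, 16, 24, 0, 8, 16, 24
    print(s_mat_off, unwrapped, wrapped)
# Offsets 32..56 would read past the 32-wide tile; the modulo folds them back
# onto columns 0..24, which is what the
# "|0 1 2 3 0 1 2 3| -> | 0(0) 1(1) 2(2) 3(3) |" picture illustrates.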
@@ -631,12 +641,6 @@ Value loadB(ConversionPatternRewriter &rewriter, Location loc, Value tensor,
SmallVector<int64_t> shape(tensorTy.getShape().begin(),
tensorTy.getShape().end());
// TODO[Superjomn]: transB cannot be accessed in ConvertLayoutOp.
bool transB = false;
if (transB) {
std::swap(shape[0], shape[1]);
}
int mmaInstrM = 16, mmaInstrN = 8, mmaInstrK = 4 * 64 / bitwidth;
int matShapeM = 8, matShapeN = 8, matShapeK = 2 * 64 / bitwidth;


@@ -1283,6 +1283,7 @@ def test_permute(dtype_str, shape, perm, device='cuda'):
[64, 128, 128, 4],
[32, 128, 64, 2],
[64, 64, 32, 4],
[32, 32, 128, 16],
[128, 128, 64, 2],
[64, 128, 128, 2]]
for allow_tf32 in [True]
@@ -1436,7 +1437,7 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, o
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
# make sure ld/st are vectorized
ptx = pgm.asm['ptx']
if K > 16 or N > 16 or M > 16:
if (K > 16 or N > 16 or M > 16) and (M * N // (num_warps * 32) >= 4):
# XXX: skip small sizes because they are not vectorized
assert 'ld.global.v4' in ptx
assert 'st.global.v4' in ptx
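For the newly added shape, the extra clause in the guard can be checked by hand: with [M, N, K, num_warps] = [32, 32, 128, 16] there are 16 * 32 = 512 threads for only 32 * 32 = 1024 outputs, i.e. 2 elements per thread, so the v4 load/store assertions are skipped for this case.

M, N, K, num_warps = 32, 32, 128, 16
assert K > 16 or N > 16 or M > 16        # first clause still true (K = 128)
assert M * N // (num_warps * 32) == 2    # 2 < 4, so the v4 check is skipped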