[BACKEND] Calculate correct warp ids for small matrices (#1180)

Fixing https://github.com/openai/triton/issues/1162

Add a 16x16x16 test case
This commit is contained in:
Keren Zhou
2023-02-14 00:28:03 -05:00
committed by GitHub
parent 30db959dae
commit 6413c7b9de
4 changed files with 12 additions and 21 deletions

View File

@@ -434,15 +434,15 @@ struct MMA16816ConversionHelper {
// Get a warpId for the M axis.
// Clamps the warp index by (M / instrShapeM) so that when the tile is
// smaller than wpt[0] instruction tiles (small matrices, e.g. 16x16x16),
// extra warps wrap around onto valid tile rows instead of addressing
// out-of-range data. Uses the MMA *instruction* shape, not the matrix
// (ldmatrix) shape — using the mat shape computed wrong warp ids.
Value getWarpM(int M) const {
  auto matInstrShape = helper.getMmaInstrShape();
  return urem(urem(warp, i32_val(wpt[0])), i32_val(M / matInstrShape[0]));
}
// Get a warpId for the N axis.
// First strips off the M-axis component (udiv by wpt[0]), then clamps by
// (N / instrShapeN) so warps beyond the number of N-axis instruction tiles
// wrap onto valid tiles for small matrices. Uses the MMA *instruction*
// shape for the clamp, consistent with getWarpM.
Value getWarpN(int N) const {
  auto matInstrShape = helper.getMmaInstrShape();
  Value warpMN = udiv(warp, i32_val(wpt[0]));
  return urem(urem(warpMN, i32_val(wpt[1])), i32_val(N / matInstrShape[1]));
}
// Get the mmaInstrShape deducing either from $a or $b.

View File

@@ -391,8 +391,6 @@ public:
getSwizzledSharedPtrs(loc, inVec, srcTy, dstSharedLayout, dstElemTy,
smemObj, rewriter, offsetVals, srcStrides);
std::map<unsigned, Value> cache0;
std::map<unsigned, Value> cache1;
for (unsigned i = 0; i < numElems; ++i) {
if (i % minVec == 0)
word = undef(wordTy);
@@ -733,8 +731,9 @@ private:
Value warpSize = idx_val(32);
Value laneId = urem(threadId, warpSize);
Value warpId = udiv(threadId, warpSize);
Value warpId0 = urem(warpId, warpsPerCTA[0]);
Value warpId1 = urem(udiv(warpId, warpsPerCTA[0]), warpsPerCTA[1]);
Value warpId0 = urem(urem(warpId, warpsPerCTA[0]), idx_val(shape[0] / 16));
Value warpId1 = urem(urem(udiv(warpId, warpsPerCTA[0]), warpsPerCTA[1]),
idx_val(shape[1] / 8));
Value offWarp0 = mul(warpId0, idx_val(16));
Value offWarp1 = mul(warpId1, idx_val(8));

View File

@@ -219,16 +219,6 @@ SmallVector<unsigned> getShapePerCTA(const Attribute &layout,
assert(0 && "DotOperandEncodingAttr non-MmaEncodingAttr parent not "
"supported yet");
}
} else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
if (mmaLayout.isAmpere()) {
return {16 * mmaLayout.getWarpsPerCTA()[0],
8 * mmaLayout.getWarpsPerCTA()[1]};
} else if (mmaLayout.isVolta()) {
return {16 * mmaLayout.getWarpsPerCTA()[0],
16 * mmaLayout.getWarpsPerCTA()[1]};
} else {
llvm_unreachable("Unexpected mma version");
}
} else {
assert(0 && "Unimplemented usage of getShapePerCTA");
}

View File

@@ -1083,7 +1083,7 @@ def test_permute(dtype_str, shape, perm, device='cuda'):
@pytest.mark.parametrize("M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, dtype",
[(*shape, 4, False, False, epilogue, allow_tf32, dtype)
for shape in [(64, 64, 64)]
for shape in [(64, 64, 64), (16, 16, 16)]
for epilogue in ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols', 'softmax', 'chain-dot']
for allow_tf32 in [True, False]
for dtype in ['float16', 'float32']
@@ -1228,8 +1228,10 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, dtype, devi
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
# make sure ld/st are vectorized
ptx = pgm.asm['ptx']
assert 'ld.global.v4' in ptx
assert 'st.global.v4' in ptx
if K > 16 or N > 16 or M > 16:
# XXX: skip small sizes because they are not vectorized
assert 'ld.global.v4' in ptx
assert 'st.global.v4' in ptx
if dtype == 'float32' and allow_tf32:
assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' in ptx
elif dtype == 'float32' and allow_tf32: