mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[BACKEND] Calculate correct warp ids for small matrices (#1180)
Fixes https://github.com/openai/triton/issues/1162. Adds tests for the 16x16x16 (M, N, K) shape.
This commit is contained in:
@@ -434,15 +434,15 @@ struct MMA16816ConversionHelper {
|
||||
|
||||
// Get a warpId for M axis.
|
||||
Value getWarpM(int M) const {
|
||||
auto matShape = helper.getMmaMatShape();
|
||||
return urem(urem(warp, i32_val(wpt[0])), i32_val(M / matShape[0]));
|
||||
auto matInstrShape = helper.getMmaInstrShape();
|
||||
return urem(urem(warp, i32_val(wpt[0])), i32_val(M / matInstrShape[0]));
|
||||
}
|
||||
|
||||
// Get a warpId for N axis.
|
||||
Value getWarpN(int N) const {
|
||||
auto matShape = helper.getMmaMatShape();
|
||||
auto matInstrShape = helper.getMmaInstrShape();
|
||||
Value warpMN = udiv(warp, i32_val(wpt[0]));
|
||||
return urem(urem(warpMN, i32_val(wpt[1])), i32_val(N / matShape[1]));
|
||||
return urem(urem(warpMN, i32_val(wpt[1])), i32_val(N / matInstrShape[1]));
|
||||
}
|
||||
|
||||
// Get the mmaInstrShape deducing either from $a or $b.
|
||||
|
||||
@@ -391,8 +391,6 @@ public:
|
||||
getSwizzledSharedPtrs(loc, inVec, srcTy, dstSharedLayout, dstElemTy,
|
||||
smemObj, rewriter, offsetVals, srcStrides);
|
||||
|
||||
std::map<unsigned, Value> cache0;
|
||||
std::map<unsigned, Value> cache1;
|
||||
for (unsigned i = 0; i < numElems; ++i) {
|
||||
if (i % minVec == 0)
|
||||
word = undef(wordTy);
|
||||
@@ -733,8 +731,9 @@ private:
|
||||
Value warpSize = idx_val(32);
|
||||
Value laneId = urem(threadId, warpSize);
|
||||
Value warpId = udiv(threadId, warpSize);
|
||||
Value warpId0 = urem(warpId, warpsPerCTA[0]);
|
||||
Value warpId1 = urem(udiv(warpId, warpsPerCTA[0]), warpsPerCTA[1]);
|
||||
Value warpId0 = urem(urem(warpId, warpsPerCTA[0]), idx_val(shape[0] / 16));
|
||||
Value warpId1 = urem(urem(udiv(warpId, warpsPerCTA[0]), warpsPerCTA[1]),
|
||||
idx_val(shape[1] / 8));
|
||||
Value offWarp0 = mul(warpId0, idx_val(16));
|
||||
Value offWarp1 = mul(warpId1, idx_val(8));
|
||||
|
||||
|
||||
@@ -219,16 +219,6 @@ SmallVector<unsigned> getShapePerCTA(const Attribute &layout,
|
||||
assert(0 && "DotOperandEncodingAttr non-MmaEncodingAttr parent not "
|
||||
"supported yet");
|
||||
}
|
||||
} else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
|
||||
if (mmaLayout.isAmpere()) {
|
||||
return {16 * mmaLayout.getWarpsPerCTA()[0],
|
||||
8 * mmaLayout.getWarpsPerCTA()[1]};
|
||||
} else if (mmaLayout.isVolta()) {
|
||||
return {16 * mmaLayout.getWarpsPerCTA()[0],
|
||||
16 * mmaLayout.getWarpsPerCTA()[1]};
|
||||
} else {
|
||||
llvm_unreachable("Unexpected mma version");
|
||||
}
|
||||
} else {
|
||||
assert(0 && "Unimplemented usage of getShapePerCTA");
|
||||
}
|
||||
|
||||
@@ -1083,7 +1083,7 @@ def test_permute(dtype_str, shape, perm, device='cuda'):
|
||||
|
||||
@pytest.mark.parametrize("M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, dtype",
|
||||
[(*shape, 4, False, False, epilogue, allow_tf32, dtype)
|
||||
for shape in [(64, 64, 64)]
|
||||
for shape in [(64, 64, 64), (16, 16, 16)]
|
||||
for epilogue in ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols', 'softmax', 'chain-dot']
|
||||
for allow_tf32 in [True, False]
|
||||
for dtype in ['float16', 'float32']
|
||||
@@ -1228,8 +1228,10 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, dtype, devi
|
||||
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
|
||||
# make sure ld/st are vectorized
|
||||
ptx = pgm.asm['ptx']
|
||||
assert 'ld.global.v4' in ptx
|
||||
assert 'st.global.v4' in ptx
|
||||
if K > 16 or N > 16 or M > 16:
|
||||
# XXX: skip small sizes because they are not vectorized
|
||||
assert 'ld.global.v4' in ptx
|
||||
assert 'st.global.v4' in ptx
|
||||
if dtype == 'float32' and allow_tf32:
|
||||
assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' in ptx
|
||||
elif dtype == 'float32' and allow_tf32:
|
||||
|
||||
Reference in New Issue
Block a user