Added missing #ifdef and fixed code style

This commit is contained in:
B1tway
2023-03-07 11:32:52 +00:00
parent a931a50719
commit b5dc18d7c9
4 changed files with 29 additions and 19 deletions

View File

@@ -253,9 +253,12 @@ for
prevWarps *= warpsPerCTA[i];
}
// Expand the last dimension to fill the remaining lanes and warps
#ifdef USE_ROCM
threadsPerWarp[order[rank-1]] = 64 / prevLanes;
#else
threadsPerWarp[order[rank-1]] = 32 / prevLanes;
#endif
warpsPerCTA[order[rank-1]] = numWarps / prevWarps;
return $_get(context, sizePerThread, threadsPerWarp, warpsPerCTA, order);
}]>

View File

@@ -426,9 +426,9 @@ struct MMA16816ConversionHelper {
loc(loc), ctx(mmaLayout.getContext()) {
helper.deduceMmaType(dotOperand);
#ifdef USE_ROCM
Value warpSize = i32_val(64);
Value warpSize = i32_val(64);
#else
Value warpSize = i32_val(32);
Value warpSize = i32_val(32);
#endif
lane = urem(thread, warpSize);
warp = udiv(thread, warpSize);

View File

@@ -709,8 +709,11 @@ private:
Value _2 = i32_val(2);
Value _4 = i32_val(4);
Value _16 = i32_val(16);
Value _32 = i32_val(32);
Value _64 = i32_val(64);
#ifdef USE_ROCM
Value warpSize = i32_val(64);
#else
Value warpSize = i32_val(32);
#endif
Value _fpw0 = i32_val(fpw[0]);
Value _fpw1 = i32_val(fpw[1]);
@@ -721,8 +724,8 @@ private:
SmallVector<int, 2> spw({aParam.spw[0], bParam.spw[1]});
SmallVector<unsigned, 2> shapePerCTA({spw[0] * wpt[0], spw[1] * wpt[1]});
Value lane = urem(thread, _64);
Value warp = udiv(thread, _64);
Value lane = urem(thread, warpSize);
Value warp = udiv(thread, warpSize);
Value warp0 = urem(warp, i32_val(wpt[0]));
Value warp12 = udiv(warp, i32_val(wpt[0]));
@@ -807,7 +810,11 @@ private:
SmallVector<Value> warpsPerCTA = {idx_val(_warpsPerCTA[0]),
idx_val(_warpsPerCTA[1])};
Value threadId = getThreadId(rewriter, loc);
#ifdef USE_ROCM
Value warpSize = idx_val(64);
#else
Value warpSize = idx_val(32);
#endif
Value laneId = urem(threadId, warpSize);
Value warpId = udiv(threadId, warpSize);
Value warpId0 = urem(urem(warpId, warpsPerCTA[0]), idx_val(shape[0] / 16));

View File

@@ -1856,18 +1856,18 @@ if torch.version.hip is not None:
]
else:
layouts = [
# MmaLayout(version=1, warps_per_cta=[1, 4]),
MmaLayout(version=(2, 0), warps_per_cta=[1, 4]),
# MmaLayout(version=1, warps_per_cta=[4, 1]),
MmaLayout(version=(2, 0), warps_per_cta=[4, 1]),
BlockedLayout([1, 8], [2, 16], [4, 1], [1, 0]),
BlockedLayout([1, 4], [4, 8], [2, 2], [1, 0]),
BlockedLayout([1, 1], [1, 32], [2, 2], [1, 0]),
# BlockedLayout([8, 1], [16, 2], [1, 4], [0, 1]),
BlockedLayout([4, 1], [8, 4], [2, 2], [0, 1]),
BlockedLayout([1, 1], [32, 1], [2, 2], [0, 1]),
BlockedLayout([4, 4], [1, 32], [4, 1], [1, 0])
]
# MmaLayout(version=1, warps_per_cta=[1, 4]),
MmaLayout(version=(2, 0), warps_per_cta=[1, 4]),
# MmaLayout(version=1, warps_per_cta=[4, 1]),
MmaLayout(version=(2, 0), warps_per_cta=[4, 1]),
BlockedLayout([1, 8], [2, 16], [4, 1], [1, 0]),
BlockedLayout([1, 4], [4, 8], [2, 2], [1, 0]),
BlockedLayout([1, 1], [1, 32], [2, 2], [1, 0]),
# BlockedLayout([8, 1], [16, 2], [1, 4], [0, 1]),
BlockedLayout([4, 1], [8, 4], [2, 2], [0, 1]),
BlockedLayout([1, 1], [32, 1], [2, 2], [0, 1]),
BlockedLayout([4, 4], [1, 32], [4, 1], [1, 0])
]
@pytest.mark.parametrize("shape", [(128, 128)])
@pytest.mark.parametrize("dtype", ['float16'])