mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
Added missing #ifdef and fixed code style
This commit is contained in:
@@ -253,9 +253,12 @@ for
|
||||
prevWarps *= warpsPerCTA[i];
|
||||
}
|
||||
// Expand the last dimension to fill the remaining lanes and warps
|
||||
#ifdef USE_ROCM
|
||||
threadsPerWarp[order[rank-1]] = 64 / prevLanes;
|
||||
#else
|
||||
threadsPerWarp[order[rank-1]] = 32 / prevLanes;
|
||||
#endif
|
||||
warpsPerCTA[order[rank-1]] = numWarps / prevWarps;
|
||||
|
||||
return $_get(context, sizePerThread, threadsPerWarp, warpsPerCTA, order);
|
||||
|
||||
}]>
|
||||
|
||||
@@ -426,9 +426,9 @@ struct MMA16816ConversionHelper {
|
||||
loc(loc), ctx(mmaLayout.getContext()) {
|
||||
helper.deduceMmaType(dotOperand);
|
||||
#ifdef USE_ROCM
|
||||
Value warpSize = i32_val(64);
|
||||
Value warpSize = i32_val(64);
|
||||
#else
|
||||
Value warpSize = i32_val(32);
|
||||
Value warpSize = i32_val(32);
|
||||
#endif
|
||||
lane = urem(thread, warpSize);
|
||||
warp = udiv(thread, warpSize);
|
||||
|
||||
@@ -709,8 +709,11 @@ private:
|
||||
Value _2 = i32_val(2);
|
||||
Value _4 = i32_val(4);
|
||||
Value _16 = i32_val(16);
|
||||
Value _32 = i32_val(32);
|
||||
Value _64 = i32_val(64);
|
||||
#ifdef USE_ROCM
|
||||
Value warpSize = i32_val(64);
|
||||
#else
|
||||
Value warpSize = i32_val(32);
|
||||
#endif
|
||||
Value _fpw0 = i32_val(fpw[0]);
|
||||
Value _fpw1 = i32_val(fpw[1]);
|
||||
|
||||
@@ -721,8 +724,8 @@ private:
|
||||
SmallVector<int, 2> spw({aParam.spw[0], bParam.spw[1]});
|
||||
SmallVector<unsigned, 2> shapePerCTA({spw[0] * wpt[0], spw[1] * wpt[1]});
|
||||
|
||||
Value lane = urem(thread, _64);
|
||||
Value warp = udiv(thread, _64);
|
||||
Value lane = urem(thread, warpSize);
|
||||
Value warp = udiv(thread, warpSize);
|
||||
|
||||
Value warp0 = urem(warp, i32_val(wpt[0]));
|
||||
Value warp12 = udiv(warp, i32_val(wpt[0]));
|
||||
@@ -807,7 +810,11 @@ private:
|
||||
SmallVector<Value> warpsPerCTA = {idx_val(_warpsPerCTA[0]),
|
||||
idx_val(_warpsPerCTA[1])};
|
||||
Value threadId = getThreadId(rewriter, loc);
|
||||
#ifdef USE_ROCM
|
||||
Value warpSize = idx_val(64);
|
||||
#else
|
||||
Value warpSize = idx_val(32);
|
||||
#endif
|
||||
Value laneId = urem(threadId, warpSize);
|
||||
Value warpId = udiv(threadId, warpSize);
|
||||
Value warpId0 = urem(urem(warpId, warpsPerCTA[0]), idx_val(shape[0] / 16));
|
||||
|
||||
@@ -1856,18 +1856,18 @@ if torch.version.hip is not None:
|
||||
]
|
||||
else:
|
||||
layouts = [
|
||||
# MmaLayout(version=1, warps_per_cta=[1, 4]),
|
||||
MmaLayout(version=(2, 0), warps_per_cta=[1, 4]),
|
||||
# MmaLayout(version=1, warps_per_cta=[4, 1]),
|
||||
MmaLayout(version=(2, 0), warps_per_cta=[4, 1]),
|
||||
BlockedLayout([1, 8], [2, 16], [4, 1], [1, 0]),
|
||||
BlockedLayout([1, 4], [4, 8], [2, 2], [1, 0]),
|
||||
BlockedLayout([1, 1], [1, 32], [2, 2], [1, 0]),
|
||||
# BlockedLayout([8, 1], [16, 2], [1, 4], [0, 1]),
|
||||
BlockedLayout([4, 1], [8, 4], [2, 2], [0, 1]),
|
||||
BlockedLayout([1, 1], [32, 1], [2, 2], [0, 1]),
|
||||
BlockedLayout([4, 4], [1, 32], [4, 1], [1, 0])
|
||||
]
|
||||
# MmaLayout(version=1, warps_per_cta=[1, 4]),
|
||||
MmaLayout(version=(2, 0), warps_per_cta=[1, 4]),
|
||||
# MmaLayout(version=1, warps_per_cta=[4, 1]),
|
||||
MmaLayout(version=(2, 0), warps_per_cta=[4, 1]),
|
||||
BlockedLayout([1, 8], [2, 16], [4, 1], [1, 0]),
|
||||
BlockedLayout([1, 4], [4, 8], [2, 2], [1, 0]),
|
||||
BlockedLayout([1, 1], [1, 32], [2, 2], [1, 0]),
|
||||
# BlockedLayout([8, 1], [16, 2], [1, 4], [0, 1]),
|
||||
BlockedLayout([4, 1], [8, 4], [2, 2], [0, 1]),
|
||||
BlockedLayout([1, 1], [32, 1], [2, 2], [0, 1]),
|
||||
BlockedLayout([4, 4], [1, 32], [4, 1], [1, 0])
|
||||
]
|
||||
|
||||
@pytest.mark.parametrize("shape", [(128, 128)])
|
||||
@pytest.mark.parametrize("dtype", ['float16'])
|
||||
|
||||
Reference in New Issue
Block a user