mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
Fix hard-coded warp size when lowering the reduce op: query it from the source layout instead of assuming 32 (#471)
This commit is contained in:
@@ -449,7 +449,8 @@ private:
|
||||
Location loc = op.getLoc();
|
||||
|
||||
Value threadId = getThreadId(rewriter, loc);
|
||||
Value warpSize = i32_val(32);
|
||||
unsigned wavefront_size = triton::gpu::getWarpSize(srcLayout);
|
||||
Value warpSize = i32_val(wavefront_size);
|
||||
Value laneId = urem(threadId, warpSize);
|
||||
Value zero = i32_val(0);
|
||||
|
||||
@@ -481,8 +482,6 @@ private:
|
||||
icmp_eq(laneIdModSizeInterWarps, zero);
|
||||
Value pred = and_(threadIsNeeded, laneIdModSizeInterWarpsIsZero);
|
||||
|
||||
auto srcLayout = helper.getSrcLayout();
|
||||
unsigned wavefront_size = triton::gpu::getWarpSize(srcLayout);
|
||||
|
||||
for (unsigned i = 0; i < op.getNumOperands(); ++i) {
|
||||
#if USE_ROCM
|
||||
|
||||
Reference in New Issue
Block a user