fixed warp size in lowering reduce op (#471)

This commit is contained in:
Shucai Xiao
2024-01-18 09:38:41 -06:00
committed by GitHub
parent e7033218d6
commit 2c7d850c2d

View File

@@ -449,7 +449,8 @@ private:
Location loc = op.getLoc();
Value threadId = getThreadId(rewriter, loc);
Value warpSize = i32_val(32);
unsigned wavefront_size = triton::gpu::getWarpSize(srcLayout);
Value warpSize = i32_val(wavefront_size);
Value laneId = urem(threadId, warpSize);
Value zero = i32_val(0);
@@ -481,8 +482,6 @@ private:
icmp_eq(laneIdModSizeInterWarps, zero);
Value pred = and_(threadIsNeeded, laneIdModSizeInterWarpsIsZero);
auto srcLayout = helper.getSrcLayout();
unsigned wavefront_size = triton::gpu::getWarpSize(srcLayout);
for (unsigned i = 0; i < op.getNumOperands(); ++i) {
#if USE_ROCM