feat: don't use output reg as local reduce reg (#13203)

This commit is contained in:
wozeparrot
2025-11-11 14:35:16 -08:00
committed by GitHub
parent ece1415def
commit 787f0070ed

View File

@@ -32,7 +32,11 @@ class Group:
i = UOp.range(reg.size, Group.clear_rid)
Group.clear_rid += 1
return reg.reshape((reg.size,))[i].set(value, end=i).after(reg).reshape(reg.shape)
reg_store = reg.reshape((reg.size,))[i].store(value).end(i)
self.ker.push_store(reg_store, reg)
return reg.after(reg_store).reshape(reg.shape)
def zero(self, reg:UOp): return self.clear(reg, 0)
def neg_inf(self, reg:UOp): return self.clear(reg, -math.inf)
@@ -125,24 +129,38 @@ class Group:
red_local = UOp.placeholder((self.group_threads, 2), src.dtype.base, addrspace=AddrSpace.LOCAL, slot=slots.shared_slot)
slots.shared_slot += 1
red_reg = UOp.placeholder((2,), src.dtype.base, addrspace=AddrSpace.REG, slot=slots.register_slot)
slots.register_slot += 1
for height in self.ker.range(src.shape[-3], track=False):
i = UOp.range(red_reg.size, Group.clear_rid)
Group.clear_rid += 1
red_reg = red_reg.after(height, *[tkr._rng for tkr in self.ker.range_stack])
reg_store = red_reg.flatten()[i].store(0.).end(i)
red_reg = red_reg.after(reg_store).reshape(red_reg.shape)
for i_outer in self.ker.range(2, track=False):
for width in self.ker.range(src.shape[-2], AxisType.REDUCE, track=False):
for i_inner in self.ker.range(4, AxisType.REDUCE, track=False):
elem_index = i_inner + 2 * (i_inner // 2) + i_outer * 2
vec_store = vec[height, 0, i_outer].store(op(vec[height, 0, i_outer], src[height, width, elem_index])).end(width, i_inner, i_outer)
vec = vec.after(vec_store).reshape(vec.shape)
reg_store = red_reg[i_outer].store(op(red_reg[i_outer], src[height, width, elem_index])).end(i_inner, width, i_outer)
red_reg = red_reg.after(reg_store).reshape(red_reg.shape)
# store to shared memory
for i_outer in self.ker.range(2, track=False):
red_local_store = red_local[self.laneid, i_outer].store(vec[height, 0, i_outer]).end(i_outer)
red_local = red_local.after(red_local_store).reshape(red_local.shape)
red_local_store = red_local[self.laneid, i_outer].store(red_reg[i_outer]).end(i_outer)
red_local = red_local.after(red_local_store.barrier()).reshape(red_local.shape)
# reduce from shared memory
for i_outer in self.ker.range(2, track=False):
for i_inner in self.ker.range(3, AxisType.REDUCE, track=False):
offset = (self.laneid // 4) * 4 + ((self.laneid + 1 + i_inner) % 4)
vec_store = vec[height, 0, i_outer].store(op(vec[height, 0, i_outer], red_local[offset, i_outer])).end(i_inner, i_outer)
offset = (self.laneid // 4) * 4 + ((self.laneid + i_inner + 1) % 4)
reg_store = red_reg[i_outer].store(op(red_reg[i_outer], red_local[offset, i_outer])).end(i_inner, i_outer)
red_reg = red_reg.after(reg_store).reshape(red_reg.shape)
# reduce with vec
for i_outer in self.ker.range(2, track=False):
vec_store = vec[height, 0, i_outer].store(op(vec[height, 0, i_outer], red_reg[i_outer])).end(i_outer, height)
self.ker.push_store(vec_store, vec)
return vec.after(vec_store).reshape(vec.shape)