mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 22:48:25 -05:00
feat: don't use output reg as local reduce reg (#13203)
This commit is contained in:
@@ -32,7 +32,11 @@ class Group:
|
||||
|
||||
i = UOp.range(reg.size, Group.clear_rid)
|
||||
Group.clear_rid += 1
|
||||
return reg.reshape((reg.size,))[i].set(value, end=i).after(reg).reshape(reg.shape)
|
||||
|
||||
reg_store = reg.reshape((reg.size,))[i].store(value).end(i)
|
||||
|
||||
self.ker.push_store(reg_store, reg)
|
||||
return reg.after(reg_store).reshape(reg.shape)
|
||||
|
||||
def zero(self, reg:UOp):
  """Clear `reg` to zero by forwarding to `self.clear` with a fill value of 0.

  Returns whatever `self.clear(reg, 0)` returns.
  """
  fill = 0
  return self.clear(reg, fill)
|
||||
def neg_inf(self, reg:UOp):
  """Clear `reg` to negative infinity by forwarding to `self.clear` with `-math.inf`.

  Returns whatever `self.clear(reg, -math.inf)` returns.
  """
  fill = -math.inf
  return self.clear(reg, fill)
|
||||
@@ -125,24 +129,38 @@ class Group:
|
||||
red_local = UOp.placeholder((self.group_threads, 2), src.dtype.base, addrspace=AddrSpace.LOCAL, slot=slots.shared_slot)
|
||||
slots.shared_slot += 1
|
||||
|
||||
red_reg = UOp.placeholder((2,), src.dtype.base, addrspace=AddrSpace.REG, slot=slots.register_slot)
|
||||
slots.register_slot += 1
|
||||
|
||||
for height in self.ker.range(src.shape[-3], track=False):
|
||||
i = UOp.range(red_reg.size, Group.clear_rid)
|
||||
Group.clear_rid += 1
|
||||
red_reg = red_reg.after(height, *[tkr._rng for tkr in self.ker.range_stack])
|
||||
reg_store = red_reg.flatten()[i].store(0.).end(i)
|
||||
red_reg = red_reg.after(reg_store).reshape(red_reg.shape)
|
||||
|
||||
for i_outer in self.ker.range(2, track=False):
|
||||
for width in self.ker.range(src.shape[-2], AxisType.REDUCE, track=False):
|
||||
for i_inner in self.ker.range(4, AxisType.REDUCE, track=False):
|
||||
elem_index = i_inner + 2 * (i_inner // 2) + i_outer * 2
|
||||
vec_store = vec[height, 0, i_outer].store(op(vec[height, 0, i_outer], src[height, width, elem_index])).end(width, i_inner, i_outer)
|
||||
vec = vec.after(vec_store).reshape(vec.shape)
|
||||
reg_store = red_reg[i_outer].store(op(red_reg[i_outer], src[height, width, elem_index])).end(i_inner, width, i_outer)
|
||||
red_reg = red_reg.after(reg_store).reshape(red_reg.shape)
|
||||
|
||||
# store to shared memory
|
||||
for i_outer in self.ker.range(2, track=False):
|
||||
red_local_store = red_local[self.laneid, i_outer].store(vec[height, 0, i_outer]).end(i_outer)
|
||||
red_local = red_local.after(red_local_store).reshape(red_local.shape)
|
||||
red_local_store = red_local[self.laneid, i_outer].store(red_reg[i_outer]).end(i_outer)
|
||||
red_local = red_local.after(red_local_store.barrier()).reshape(red_local.shape)
|
||||
|
||||
# reduce from shared memory
|
||||
for i_outer in self.ker.range(2, track=False):
|
||||
for i_inner in self.ker.range(3, AxisType.REDUCE, track=False):
|
||||
offset = (self.laneid // 4) * 4 + ((self.laneid + 1 + i_inner) % 4)
|
||||
vec_store = vec[height, 0, i_outer].store(op(vec[height, 0, i_outer], red_local[offset, i_outer])).end(i_inner, i_outer)
|
||||
offset = (self.laneid // 4) * 4 + ((self.laneid + i_inner + 1) % 4)
|
||||
reg_store = red_reg[i_outer].store(op(red_reg[i_outer], red_local[offset, i_outer])).end(i_inner, i_outer)
|
||||
red_reg = red_reg.after(reg_store).reshape(red_reg.shape)
|
||||
|
||||
# reduce with vec
|
||||
for i_outer in self.ker.range(2, track=False):
|
||||
vec_store = vec[height, 0, i_outer].store(op(vec[height, 0, i_outer], red_reg[i_outer])).end(i_outer, height)
|
||||
|
||||
self.ker.push_store(vec_store, vec)
|
||||
return vec.after(vec_store).reshape(vec.shape)
|
||||
|
||||
Reference in New Issue
Block a user