mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
tk: don't flatten in clear (#13249)
This commit is contained in:
@@ -30,10 +30,10 @@ class Group:
|
||||
def clear(self, reg:UOp, value:float=0):
|
||||
assert self.warps == 1
|
||||
|
||||
i = UOp.range(reg.size, Group.clear_rid)
|
||||
Group.clear_rid += 1
|
||||
rngs_for_shape = tuple(UOp.range(dim, Group.clear_rid + i) for i, dim in enumerate(reg.shape))
|
||||
Group.clear_rid += len(reg.shape)
|
||||
|
||||
reg_store = reg.reshape((reg.size,))[i].store(value).end(i)
|
||||
reg_store = reg[*rngs_for_shape].store(value).end(*rngs_for_shape)
|
||||
|
||||
self.ker.push_store(reg_store, reg)
|
||||
return reg.after(reg_store).reshape(reg.shape)
|
||||
|
||||
Reference in New Issue
Block a user