tk: don't flatten in clear (#13249)

This commit is contained in:
wozeparrot
2025-11-13 13:38:01 -08:00
committed by GitHub
parent 6b1bae6614
commit 4ada51618f

View File

@@ -30,10 +30,10 @@ class Group:
def clear(self, reg:UOp, value:float=0):
assert self.warps == 1
i = UOp.range(reg.size, Group.clear_rid)
Group.clear_rid += 1
rngs_for_shape = tuple(UOp.range(dim, Group.clear_rid + i) for i, dim in enumerate(reg.shape))
Group.clear_rid += len(reg.shape)
reg_store = reg.reshape((reg.size,))[i].store(value).end(i)
reg_store = reg[*rngs_for_shape].store(value).end(*rngs_for_shape)
self.ker.push_store(reg_store, reg)
return reg.after(reg_store).reshape(reg.shape)