From 4ada51618f2b2e96baa6e71a3eaf4941edd53382 Mon Sep 17 00:00:00 2001 From: wozeparrot Date: Thu, 13 Nov 2025 13:38:01 -0800 Subject: [PATCH] tk: don't flatten in clear (#13249) --- extra/thunder/tiny/tk/group.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/extra/thunder/tiny/tk/group.py b/extra/thunder/tiny/tk/group.py index 5ab1e4a01b..e86e599edd 100644 --- a/extra/thunder/tiny/tk/group.py +++ b/extra/thunder/tiny/tk/group.py @@ -30,10 +30,10 @@ class Group: def clear(self, reg:UOp, value:float=0): assert self.warps == 1 - i = UOp.range(reg.size, Group.clear_rid) - Group.clear_rid += 1 + rngs_for_shape = tuple(UOp.range(dim, Group.clear_rid + i) for i, dim in enumerate(reg.shape)) + Group.clear_rid += len(reg.shape) - reg_store = reg.reshape((reg.size,))[i].store(value).end(i) + reg_store = reg[*rngs_for_shape].store(value).end(*rngs_for_shape) self.ker.push_store(reg_store, reg) return reg.after(reg_store).reshape(reg.shape)