From be72b78dcbd82f1ef7f5201cca802776d576323a Mon Sep 17 00:00:00 2001 From: wozeparrot Date: Wed, 19 Nov 2025 00:58:50 -0800 Subject: [PATCH] tk: small fixes (#13345) * fix: handle case where final uop isn't a tk wrapped one * clean: remove after from mma --- extra/thunder/tiny/tk/group.py | 8 ++++---- extra/thunder/tiny/tk/kernel.py | 6 +++++- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/extra/thunder/tiny/tk/group.py b/extra/thunder/tiny/tk/group.py index 71a71a6ea5..c2e4cf009d 100644 --- a/extra/thunder/tiny/tk/group.py +++ b/extra/thunder/tiny/tk/group.py @@ -56,7 +56,7 @@ class Group: self.ker.push_store(dst_store, dst) return dst.after(dst_store).reshape(dst.shape) - def mma_AB(self, c:UOp|RT, a:UOp|RT, b:UOp|RT, after=True): + def mma_AB(self, c:UOp|RT, a:UOp|RT, b:UOp|RT): c, a, b = cast(UOp, c), cast(UOp, a), cast(UOp, b) assert self.warps == 1 @@ -77,9 +77,9 @@ class Group: c_store = UOp.group(*c_i).end(height, width, inner) self.ker.push_store(c_store, c) - return c.after(c_store).reshape(c.shape) if after else c_store + return c.after(c_store).reshape(c.shape) - def mma_ABt(self, c:UOp|RT, a:UOp|RT, b:UOp|RT, after=True): + def mma_ABt(self, c:UOp|RT, a:UOp|RT, b:UOp|RT): c, a, b = cast(UOp, c), cast(UOp, a), cast(UOp, b) assert self.warps == 1 @@ -100,7 +100,7 @@ class Group: c_store = UOp.group(*c_i).end(height, width, inner) self.ker.push_store(c_store, c) - return c.after(c_store).reshape(c.shape) if after else c_store + return c.after(c_store).reshape(c.shape) map_rid = 400 def map(self, a:ALL_TILES, op:Callable[[UOp], UOp]|Callable[[UOp, tuple], UOp]): diff --git a/extra/thunder/tiny/tk/kernel.py b/extra/thunder/tiny/tk/kernel.py index 29df120e12..7a3270967f 100644 --- a/extra/thunder/tiny/tk/kernel.py +++ b/extra/thunder/tiny/tk/kernel.py @@ -80,7 +80,11 @@ class Kernel(AbstractContextManager): rngs = [] while self.range_stack: rngs.append(self.range_stack.pop(0)._rng) - return self.store_stack.pop()[0]._uop.end(*rngs).sink(arg=KernelInfo(opts_to_apply=())).simplify() + last_store = self.store_stack.pop()[0] + if hasattr(last_store, '_uop'): uop = last_store._uop + else: uop = last_store + + return uop.end(*rngs).sink(arg=KernelInfo(opts_to_apply=())).simplify() def endrange(self): last_store = self.store_stack.pop()