tk: small fixes (#13345)

* fix: handle the case where the final uop isn't a tk-wrapped one

* clean: remove the `after` parameter from mma_AB and mma_ABt
This commit is contained in:
wozeparrot
2025-11-19 00:58:50 -08:00
committed by GitHub
parent e4fbde5b3b
commit be72b78dcb
2 changed files with 9 additions and 5 deletions

View File

@@ -56,7 +56,7 @@ class Group:
self.ker.push_store(dst_store, dst)
return dst.after(dst_store).reshape(dst.shape)
def mma_AB(self, c:UOp|RT, a:UOp|RT, b:UOp|RT, after=True):
def mma_AB(self, c:UOp|RT, a:UOp|RT, b:UOp|RT):
c, a, b = cast(UOp, c), cast(UOp, a), cast(UOp, b)
assert self.warps == 1
@@ -77,9 +77,9 @@ class Group:
c_store = UOp.group(*c_i).end(height, width, inner)
self.ker.push_store(c_store, c)
return c.after(c_store).reshape(c.shape) if after else c_store
return c.after(c_store).reshape(c.shape)
def mma_ABt(self, c:UOp|RT, a:UOp|RT, b:UOp|RT, after=True):
def mma_ABt(self, c:UOp|RT, a:UOp|RT, b:UOp|RT):
c, a, b = cast(UOp, c), cast(UOp, a), cast(UOp, b)
assert self.warps == 1
@@ -100,7 +100,7 @@ class Group:
c_store = UOp.group(*c_i).end(height, width, inner)
self.ker.push_store(c_store, c)
return c.after(c_store).reshape(c.shape) if after else c_store
return c.after(c_store).reshape(c.shape)
map_rid = 400
def map(self, a:ALL_TILES, op:Callable[[UOp], UOp]|Callable[[UOp, tuple], UOp]):

View File

@@ -80,7 +80,11 @@ class Kernel(AbstractContextManager):
rngs = []
while self.range_stack: rngs.append(self.range_stack.pop(0)._rng)
return self.store_stack.pop()[0]._uop.end(*rngs).sink(arg=KernelInfo(opts_to_apply=())).simplify()
last_store = self.store_stack.pop()[0]
if hasattr(last_store, '_uop'): uop = last_store._uop
else: uop = last_store
return uop.end(*rngs).sink(arg=KernelInfo(opts_to_apply=())).simplify()
def endrange(self):
last_store = self.store_stack.pop()