tk: small changes from fa bwd (#13732)

This commit is contained in:
wozeparrot
2025-12-16 22:44:36 -08:00
committed by GitHub
parent fda73c8180
commit 5151a341b3
2 changed files with 10 additions and 5 deletions

View File

@@ -387,7 +387,7 @@ class Group:
idxs = tuple(idx * rv.length if i == 3 else idx for i, idx in enumerate(idxs))
src_i = ((idxs[0] * src.shape[-3] + idxs[1]) * src.shape[-2] + idxs[2]) * src.shape[-1] + idxs[3]
for outer in self.ker.range(dst.shape[-2]):
for outer in self.ker.range(dst.shape[-2], track=False):
src_i += outer * reductions + (laneid % reductions)
src_load = srcf[src_i]

View File

@@ -89,7 +89,7 @@ class Kernel(AbstractContextManager):
# end stores stores
store_uops = []
for _i in range(stores):
for _ in range(stores):
store = self.store_stack.pop()[0]
if hasattr(store, '_uop'): store_uops.append(store._uop)
else: store_uops.append(store)
@@ -97,7 +97,12 @@ class Kernel(AbstractContextManager):
return uop.end(*rngs).sink(arg=KernelInfo(name=self.name, opts_to_apply=())).simplify()
def endrange(self):
def endrange(self, ranges:int=1):
last_store = self.store_stack.pop()
last_range = self.range_stack.pop()
return last_store[1].after(last_store[0].end(last_range._rng)).reshape(last_store[1].shape)
rngs = []
for _ in range(ranges):
last_range = self.range_stack.pop()
rngs.append(last_range._rng)
return last_store[1].after(last_store[0].end(*rngs)).reshape(last_store[1].shape)