mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
tk: small changes from fa bwd (#13732)
This commit is contained in:
@@ -387,7 +387,7 @@ class Group:
|
||||
idxs = tuple(idx * rv.length if i == 3 else idx for i, idx in enumerate(idxs))
|
||||
src_i = ((idxs[0] * src.shape[-3] + idxs[1]) * src.shape[-2] + idxs[2]) * src.shape[-1] + idxs[3]
|
||||
|
||||
for outer in self.ker.range(dst.shape[-2]):
|
||||
for outer in self.ker.range(dst.shape[-2], track=False):
|
||||
src_i += outer * reductions + (laneid % reductions)
|
||||
|
||||
src_load = srcf[src_i]
|
||||
|
||||
@@ -89,7 +89,7 @@ class Kernel(AbstractContextManager):
|
||||
|
||||
# end stores stores
|
||||
store_uops = []
|
||||
for _i in range(stores):
|
||||
for _ in range(stores):
|
||||
store = self.store_stack.pop()[0]
|
||||
if hasattr(store, '_uop'): store_uops.append(store._uop)
|
||||
else: store_uops.append(store)
|
||||
@@ -97,7 +97,12 @@ class Kernel(AbstractContextManager):
|
||||
|
||||
return uop.end(*rngs).sink(arg=KernelInfo(name=self.name, opts_to_apply=())).simplify()
|
||||
|
||||
def endrange(self):
|
||||
def endrange(self, ranges:int=1):
|
||||
last_store = self.store_stack.pop()
|
||||
last_range = self.range_stack.pop()
|
||||
return last_store[1].after(last_store[0].end(last_range._rng)).reshape(last_store[1].shape)
|
||||
|
||||
rngs = []
|
||||
for _ in range(ranges):
|
||||
last_range = self.range_stack.pop()
|
||||
rngs.append(last_range._rng)
|
||||
|
||||
return last_store[1].after(last_store[0].end(*rngs)).reshape(last_store[1].shape)
|
||||
|
||||
Reference in New Issue
Block a user