mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
cleanup render_reduceop (#4807)
* update acc key * refactor return type * remove return type * run all reduces * set acc key [run_process_replay] * local_idxs are copied in render_reduceop [run_process_replay]
This commit is contained in:
@@ -96,8 +96,7 @@ class Linearizer(Kernel):
|
||||
acc_count = 0
|
||||
for idx, valid, rep_idx in zip(e_idxs, e_valids, iter_idxs(expand_vars)):
|
||||
this_const, idx, valid = (invalid_value, NumNode(0), NumNode(1)) if valid.max == 0 else (const, idx, valid)
|
||||
# todo: when multiple reduceops are supported, clearly disambiguate and test acc load keys are unique for each reduceop
|
||||
key = f"{acc is not None}{localtype}{'CONST'+str(this_const) if this_const is not None and acc is None else (buf.idx if isinstance(buf, MemBuffer) else cast(LocalBuffer, buf).name)}{idx.render()}{valid.render()}" # noqa: E501
|
||||
key = f"{'' if acc is None else self.reduceops.index(acc)}{localtype}{'CONST'+str(this_const) if this_const is not None and acc is None else (buf.idx if isinstance(buf, MemBuffer) else cast(LocalBuffer, buf).name)}{idx.render()}{valid.render()}" # noqa: E501
|
||||
if key not in self.load_cache:
|
||||
if acc is not None:
|
||||
self.load_cache[key] = self.uops.add(UOps.DEFINE_ACC, localtype, loop_ctx, (self.get_reduce_acc(acc), i, acc_count))
|
||||
@@ -309,9 +308,9 @@ class Linearizer(Kernel):
|
||||
# end the late reduce loop
|
||||
self.load_cache.clear()
|
||||
|
||||
# all local indices which were used for group_for_reduce are not valid any more and should be replaced with fake NumNode(0), since they have
|
||||
# been rewritten with fake end_local_idxs.
|
||||
return (accs, loaded_buffers, fake_reduce_idxs, local_idxs[:self.local_dims] + [NumNode(0) for i in range(self.group_for_reduces)], upcast_idxs)
|
||||
# all local indices which were used for group_for_reduce are not valid any more and should be replaced with fake NumNode(0), since they have
|
||||
# been rewritten with fake end_local_idxs.
|
||||
return local_idxs[:self.local_dims] + [NumNode(0) for _ in range(self.group_for_reduces)], upcast_idxs
|
||||
|
||||
kernel_cnt: Final[DefaultDict[str, int]] = defaultdict(int)
|
||||
def linearize(self):
|
||||
@@ -401,9 +400,9 @@ class Linearizer(Kernel):
|
||||
fake_reduce_idxs = [x*0 for x in reduce_idxs]
|
||||
alias_buf_idxs = self.index_local_aliases(global_idxs,local_idxs,reduce_idxs,upcast_idxs,full_upcast_idxs)
|
||||
# render reduce op
|
||||
for reduceop in [self.reduceop] if self.reduceop is not None else []:
|
||||
accs,loaded_buffers,fake_reduce_idxs,local_idxs,upcast_idxs = \
|
||||
self.render_reduceop(reduceop,accs,loaded_buffers,global_idxs,local_idxs,upcast_idxs,full_upcast_idxs,reduce_idxs,fake_reduce_idxs,alias_buf_idxs[reduceop])
|
||||
for reduceop in self.reduceops:
|
||||
local_idxs, upcast_idxs = self.render_reduceop(reduceop,accs,loaded_buffers,global_idxs,local_idxs,upcast_idxs,
|
||||
full_upcast_idxs,reduce_idxs,fake_reduce_idxs,alias_buf_idxs[reduceop])
|
||||
|
||||
# load latebufs
|
||||
loaded_buffers.update({b:self.global_load(i, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs) \
|
||||
|
||||
Reference in New Issue
Block a user