cleanup render_reduceop (#4807)

* update acc key

* refactor return type

* remove return type

* run all reduces

* set acc key [run_process_replay]

* local_idxs are copied in render_reduceop [run_process_replay]
This commit is contained in:
qazal
2024-06-03 19:39:02 +08:00
committed by GitHub
parent dd84f7d35e
commit 6e0c16dfb0

View File

@@ -96,8 +96,7 @@ class Linearizer(Kernel):
acc_count = 0
for idx, valid, rep_idx in zip(e_idxs, e_valids, iter_idxs(expand_vars)):
this_const, idx, valid = (invalid_value, NumNode(0), NumNode(1)) if valid.max == 0 else (const, idx, valid)
# todo: when multiple reduceops are supported, clearly disambiguate and test acc load keys are unique for each reduceop
key = f"{acc is not None}{localtype}{'CONST'+str(this_const) if this_const is not None and acc is None else (buf.idx if isinstance(buf, MemBuffer) else cast(LocalBuffer, buf).name)}{idx.render()}{valid.render()}" # noqa: E501
key = f"{'' if acc is None else self.reduceops.index(acc)}{localtype}{'CONST'+str(this_const) if this_const is not None and acc is None else (buf.idx if isinstance(buf, MemBuffer) else cast(LocalBuffer, buf).name)}{idx.render()}{valid.render()}" # noqa: E501
if key not in self.load_cache:
if acc is not None:
self.load_cache[key] = self.uops.add(UOps.DEFINE_ACC, localtype, loop_ctx, (self.get_reduce_acc(acc), i, acc_count))
@@ -309,9 +308,9 @@ class Linearizer(Kernel):
# end the late reduce loop
self.load_cache.clear()
# all local indices which were used for group_for_reduce are not valid any more and should be replaced with fake NumNode(0), since they have
# been rewritten with fake end_local_idxs.
return (accs, loaded_buffers, fake_reduce_idxs, local_idxs[:self.local_dims] + [NumNode(0) for i in range(self.group_for_reduces)], upcast_idxs)
# all local indices which were used for group_for_reduce are not valid any more and should be replaced with fake NumNode(0), since they have
# been rewritten with fake end_local_idxs.
return local_idxs[:self.local_dims] + [NumNode(0) for _ in range(self.group_for_reduces)], upcast_idxs
kernel_cnt: Final[DefaultDict[str, int]] = defaultdict(int)
def linearize(self):
@@ -401,9 +400,9 @@ class Linearizer(Kernel):
fake_reduce_idxs = [x*0 for x in reduce_idxs]
alias_buf_idxs = self.index_local_aliases(global_idxs,local_idxs,reduce_idxs,upcast_idxs,full_upcast_idxs)
# render reduce op
for reduceop in [self.reduceop] if self.reduceop is not None else []:
accs,loaded_buffers,fake_reduce_idxs,local_idxs,upcast_idxs = \
self.render_reduceop(reduceop,accs,loaded_buffers,global_idxs,local_idxs,upcast_idxs,full_upcast_idxs,reduce_idxs,fake_reduce_idxs,alias_buf_idxs[reduceop])
for reduceop in self.reduceops:
local_idxs, upcast_idxs = self.render_reduceop(reduceop,accs,loaded_buffers,global_idxs,local_idxs,upcast_idxs,
full_upcast_idxs,reduce_idxs,fake_reduce_idxs,alias_buf_idxs[reduceop])
# load latebufs
loaded_buffers.update({b:self.global_load(i, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs) \