From 231ed2c6560f08a5019edea835f4f4e8c1735e35 Mon Sep 17 00:00:00 2001
From: qazal <77887910+Qazalin@users.noreply.github.com>
Date: Sun, 2 Jun 2024 04:46:52 +0800
Subject: [PATCH] compute aliased buffer idxs pre reduce (#4788)

---
 tinygrad/codegen/linearizer.py | 55 ++++++++++++++++++----------------
 1 file changed, 30 insertions(+), 25 deletions(-)

diff --git a/tinygrad/codegen/linearizer.py b/tinygrad/codegen/linearizer.py
index 451173beb7..b75e2bf121 100644
--- a/tinygrad/codegen/linearizer.py
+++ b/tinygrad/codegen/linearizer.py
@@ -177,8 +177,7 @@ class Linearizer(Kernel):
     self.loop_uops.update(new_loops)
     return tuple(new_loops.values())
 
-  def render_reduceop(self, reduceop:LazyOp, accs:Dict[LazyOp, List[UOp]], loaded_buffers:Dict[Union[MemBuffer, ConstBuffer, LocalBuffer], List[UOp]],
-                      global_idxs, local_idxs, upcast_idxs, full_upcast_idxs, reduce_idxs, fake_reduce_idxs):
+  def index_local_aliases(self, global_idxs, local_idxs, reduce_idxs, upcast_idxs, full_upcast_idxs):
     def calc_tc_idxs(local_sizes: List[int], aliases: List[List[int]]):
       replace_idxs, thread_idxs, thread_idx = [], [], Variable("_uidx_tc", 0, prod(local_sizes)-1)
       for s in local_sizes:
@@ -194,33 +193,38 @@ class Linearizer(Kernel):
         replace_idxs.append(full_var)
       return replace_idxs
 
-    # compute local aliases - modify idxs if necessary for TC
-    alias_buf_idxs = []
-    for i in (local_alias:=self.local_alias[reduceop]):
-      localbuf_idx = self.bufs.index(local_alias[i])
-      buf_idxs = [idx*0 if s == 0 else idx for idx,s in zip(global_idxs+local_idxs+reduce_idxs+full_upcast_idxs,self.sts[i].real_strides())]
-      if (tc:=self.tensor_core):
-        min_alias_idx = min(local_alias.keys())
-        replace_input_idxs = calc_tc_idxs(tc.thread_local_sizes[i-min_alias_idx], tc.thread_local_aliases[i-min_alias_idx])
-        for n in range(len(tc.threads)):
-          buf_idxs[self.global_dims+n] = replace_input_idxs[n] # replace locals
-        for n in range(tc.num_upcasts()):
-          buf_idxs[self.shape_len-self.upcasted+n] = replace_input_idxs[len(tc.threads)+n] # replace upcasts
-      if DEBUG >= 3: print(f"{localbuf_idx} alias {i}: sts={self.sts[i]} idxs={buf_idxs}")
-      alias_buf_idxs.append((i, localbuf_idx, buf_idxs,))
-
-    # reduce loop
-    loop_ctx = self.render_loop(reduce_idxs, 2)
-
-    # define accumulator - modify idxs if necessary for TC
-    out_buf = -1 if self.group_for_reduces else 0
+    # compute local aliases
+    alias_buf_idxs: DefaultDict[LazyOp, List[Tuple[int, int, List]]] = defaultdict(list)
+    for op, local_alias in self.local_alias.items():
+      for i in local_alias:
+        localbuf_idx = self.bufs.index(local_alias[i])
+        buf_idxs = [idx*0 if s == 0 else idx for idx,s in zip(global_idxs+local_idxs+reduce_idxs+full_upcast_idxs,self.sts[i].real_strides())]
+        if (tc:=self.tensor_core):
+          min_alias_idx = min(local_alias.keys())
+          replace_input_idxs = calc_tc_idxs(tc.thread_local_sizes[i-min_alias_idx], tc.thread_local_aliases[i-min_alias_idx])
+          for n in range(len(tc.threads)):
+            buf_idxs[self.global_dims+n] = replace_input_idxs[n] # replace locals
+          for n in range(tc.num_upcasts()):
+            buf_idxs[self.shape_len-self.upcasted+n] = replace_input_idxs[len(tc.threads)+n] # replace upcasts
+        if DEBUG >= 3: print(f"{localbuf_idx} alias {i}: sts={self.sts[i]} idxs={buf_idxs}")
+        alias_buf_idxs[op].append((i, localbuf_idx, buf_idxs))
+    # modify idxs if necessary for TC
     if (tc:=self.tensor_core):
       replace_acc_idxs = calc_tc_idxs(tc.thread_local_sizes[2], tc.thread_local_aliases[2])
       for n in range(len(tc.threads)):
         local_idxs[n] = replace_acc_idxs[n] # replace locals
       for n in range(len(replace_acc_idxs)-len(tc.threads)):
         upcast_idxs[n] = replace_acc_idxs[len(tc.threads)+n] # replace upcasts
-      if DEBUG >= 3: print(f"store alias: sts={self.sts[0]} idxs={global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs}")
+      if DEBUG >= 3: print(f"store alias: sts={self.sts[0]} idxs={global_idxs+local_idxs+upcast_idxs}")
+    return alias_buf_idxs
+
+  def render_reduceop(self, reduceop:LazyOp, accs:Dict[LazyOp, List[UOp]], loaded_buffers:Dict[Union[MemBuffer, ConstBuffer, LocalBuffer], List[UOp]],
+                      global_idxs, local_idxs, upcast_idxs, full_upcast_idxs, reduce_idxs, fake_reduce_idxs, alias_buf_idxs):
+    # reduce loop
+    loop_ctx = self.render_loop(reduce_idxs, 2)
+
+    # define accumulator - modify idxs if necessary for TC
+    out_buf = -1 if self.group_for_reduces else 0
     accs[reduceop] = self.global_load(out_buf, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, acc=reduceop, loop_ctx=loop_ctx)
 
     # store local aliases
@@ -253,7 +257,7 @@ class Linearizer(Kernel):
     assert not locals_to_store, "storing locals isn't supported here"
 
     # load earlybufs
-    loaded_buffers.update({b:self.global_load(self.bufs.index(local_alias[i]) if i in self.local_alias else i,
+    loaded_buffers.update({b:self.global_load(self.bufs.index(self.local_alias[reduceop][i]) if i in self.local_alias else i,
                                               global_idxs+local_idxs+reduce_idxs+full_upcast_idxs) for i,b in enumerate(self.bufs) if b in self.earlybufs})
 
     # run early AST (with reduce)
@@ -395,10 +399,11 @@ class Linearizer(Kernel):
     full_upcast_idxs = [Variable(f"_uidx{i}", 0, s-1) for i, s in enumerate(self.full_shape[self.shape_len-self.upcasted:])]
     reduce_idxs = [Variable(f"ridx{i}", 0, self.full_shape[i]-1) for i in range(self.first_reduce+self.group_for_reduces, self.shape_len-self.upcasted)]  # noqa: E501
     fake_reduce_idxs = [x*0 for x in reduce_idxs]
+    alias_buf_idxs = self.index_local_aliases(global_idxs,local_idxs,reduce_idxs,upcast_idxs,full_upcast_idxs)
     # render reduce op
     for reduceop in [self.reduceop] if self.reduceop is not None else []:
       accs,loaded_buffers,fake_reduce_idxs,local_idxs,upcast_idxs = \
-        self.render_reduceop(reduceop,accs,loaded_buffers,global_idxs,local_idxs,upcast_idxs,full_upcast_idxs,reduce_idxs,fake_reduce_idxs)
+        self.render_reduceop(reduceop,accs,loaded_buffers,global_idxs,local_idxs,upcast_idxs,full_upcast_idxs,reduce_idxs,fake_reduce_idxs,alias_buf_idxs[reduceop])
 
     # load latebufs
     loaded_buffers.update({b:self.global_load(i, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs) \
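
For reviewers skimming the hunks above, here is a minimal standalone sketch of the restructuring this patch makes: alias-buffer index computation moves out of render_reduceop into a single pre-reduce pass (index_local_aliases) whose result is a defaultdict keyed by reduce op. Everything below is a toy stand-in, not tinygrad code: strings replace LazyOp/LocalBuffer and the real stride/tensor-core index math is elided.

from collections import defaultdict
from typing import DefaultDict, Dict, List, Tuple

def index_local_aliases(local_alias: Dict[str, Dict[int, str]]) -> DefaultDict[str, List[Tuple[int, str]]]:
  # one pre-pass over every op's aliases, mirroring the new Linearizer.index_local_aliases
  alias_buf_idxs: DefaultDict[str, List[Tuple[int, str]]] = defaultdict(list)
  for op, aliases in local_alias.items():
    for i, localbuf in aliases.items():
      alias_buf_idxs[op].append((i, f"idxs({localbuf})"))  # stand-in for the computed buf_idxs
  return alias_buf_idxs

def render_reduceop(op: str, alias_buf_idxs: List[Tuple[int, str]]) -> None:
  # the render step now only consumes precomputed entries for its op
  for i, idxs in alias_buf_idxs:
    print(f"render {op}: buf {i} -> {idxs}")

local_alias = {"reduce0": {1: "lbuf1", 2: "lbuf2"}}
alias_buf_idxs = index_local_aliases(local_alias)  # computed pre reduce, per the patch title
for op in local_alias:
  render_reduceop(op, alias_buf_idxs[op])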