From eeb5a7af39d3b1da94eca729aa4d9449f1cde409 Mon Sep 17 00:00:00 2001
From: qazal <77887910+Qazalin@users.noreply.github.com>
Date: Thu, 6 Jun 2024 20:31:43 +0800
Subject: [PATCH] refactor `linearize` to render_block, P1 (#4839)

* refactor to render_block

* move rendering the reduce to its own thing

* add todo and cleanups [run_process_replay]

* inplace update of idxs [run_process_replay]
---
 tinygrad/codegen/linearizer.py | 47 +++++++++++++++++++++-------------
 1 file changed, 29 insertions(+), 18 deletions(-)

diff --git a/tinygrad/codegen/linearizer.py b/tinygrad/codegen/linearizer.py
index 0698f618bc..607c7c4d73 100644
--- a/tinygrad/codegen/linearizer.py
+++ b/tinygrad/codegen/linearizer.py
@@ -4,7 +4,7 @@ import itertools, math, functools
 from collections import defaultdict
 
 from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType, ConstType
-from tinygrad.helpers import colored, DEBUG, prod, getenv, to_function_name
+from tinygrad.helpers import colored, DEBUG, dedup, prod, getenv, to_function_name
 from tinygrad.ops import LazyOp, UnaryOps, BinaryOps, TernaryOps, ReduceOps, ConstBuffer, MemBuffer, BufferOps, get_lazyop_info
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.shape.symbolic import Variable, NumNode, Node, SumNode, MulNode, DivNode, ModNode, LtNode, AndNode, create_lt_node
@@ -373,6 +373,7 @@ class Linearizer(Kernel):
     global_idxs, loop_global_idxs = get_grouped_dims("gidx", 0, self.full_shape[:self.global_dims], 3 if self.opts.has_local else 0)
     local_idxs, loop_local_idxs = get_grouped_dims("lidx", self.global_dims, self.full_shape[self.global_dims:self.first_reduce+self.group_for_reduces], 3 if self.opts.has_local else 0)  # noqa: E501
     upcast_idxs = [Variable(f"_uidx{i}", 0, s-1) for i, s in enumerate(self.output_shape[self.shape_len-self.upcasted:])]
+    full_upcast_idxs = [Variable(f"_uidx{i}", 0, s-1) for i, s in enumerate(self.full_shape[self.shape_len-self.upcasted:])]
 
     # set global/local size
     self.global_size: Optional[List[int]] = None
@@ -389,29 +390,20 @@ class Linearizer(Kernel):
       if self.global_size is not None: self.global_size += [1]*(3-len(self.global_size))
       if self.local_size is not None: self.local_size += [1]*(3-len(self.local_size))
 
+    # define idxs for aliased buffers TODO: this doesn't belong in Kernel, but it can't exist in Block either (because of multireduce tensor cores)
+    reduce_idxs = [Variable(f"ridx{i}", 0, self.full_shape[i]-1) for i in range(self.first_reduce+self.group_for_reduces, self.shape_len-self.upcasted)]  # noqa: E501
+    alias_buf_idxs = self.index_local_aliases(global_idxs,local_idxs,reduce_idxs,upcast_idxs,full_upcast_idxs)
+
     # parse AST
+    self.load_cache: Dict[str, UOp] = {}
     loaded_buffers:Dict[Union[MemBuffer, ConstBuffer, LocalBuffer], List[UOp]] = {}
     accs: Dict[LazyOp, List[UOp]] = {}
-    self.load_cache: Dict[str, UOp] = {}
 
-    # define indexs
-    full_upcast_idxs = [Variable(f"_uidx{i}", 0, s-1) for i, s in enumerate(self.full_shape[self.shape_len-self.upcasted:])]
-    reduce_idxs = [Variable(f"ridx{i}", 0, self.full_shape[i]-1) for i in range(self.first_reduce+self.group_for_reduces, self.shape_len-self.upcasted)]  # noqa: E501
-    fake_reduce_idxs = [x*0 for x in reduce_idxs]
-    alias_buf_idxs = self.index_local_aliases(global_idxs,local_idxs,reduce_idxs,upcast_idxs,full_upcast_idxs)
-    # render reduce op
+    # render reduceops by depth
     for reduceop in self.reduceops:
-      local_idxs, upcast_idxs = self.render_reduceop(reduceop,accs,loaded_buffers,global_idxs,local_idxs,upcast_idxs,
-                                                     full_upcast_idxs,reduce_idxs,fake_reduce_idxs,alias_buf_idxs[reduceop])
+      self.render_block((reduceop, ), global_idxs, local_idxs, upcast_idxs, full_upcast_idxs, alias_buf_idxs, loaded_buffers, accs)
 
-    # load latebufs
-    loaded_buffers.update({b:self.global_load(i, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs) \
-                           for i,b in enumerate(self.bufs) if b not in self.earlybufs and b.__class__ is not LocalBuffer})
-
-    # run late AST (without the store)
-    for op in self.ast:
-      val = self.ast_parse(op.src[0], accs, None, loaded_buffers)
-      self.global_store(op.arg.idx, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, val)
+    self.render_block(self.ast, global_idxs, local_idxs, upcast_idxs, full_upcast_idxs, alias_buf_idxs, loaded_buffers, accs)
 
     # maybe graph the uops
     if DEBUG >= 5: self.uops.print()
@@ -424,6 +416,25 @@ class Linearizer(Kernel):
     self.applied_opts_cache = self.applied_opts[:]
     return self
 
+  def render_block(self, outputs:Tuple[LazyOp, ...], global_idxs, local_idxs, upcast_idxs, full_upcast_idxs,
+                   alias_buf_idxs, loaded_buffers, accs) -> List[List[UOp]]:
+    assert len(reduceops:=dedup(x for x in outputs if x.op in ReduceOps)) <= 1, "max one reduceop per block"
+    reduce_idxs = [Variable(f"ridx{i}", 0, self.full_shape[i]-1) for i in range(self.first_reduce+self.group_for_reduces, self.shape_len-self.upcasted)]  # noqa: E501
+    fake_reduce_idxs = [x*0 for x in reduce_idxs]
+
+    if len(reduceops) != 0:
+      # TODO: delete render_reduceop and move the logic for group_for_reduces to Block
+      local_idxs[:], upcast_idxs[:] = self.render_reduceop((r:=reduceops[0]),accs,loaded_buffers,global_idxs,local_idxs,upcast_idxs, full_upcast_idxs,
+                                                           reduce_idxs,fake_reduce_idxs,alias_buf_idxs[r])
+      return accs[r]
+
+    # load latebufs
+    loaded_buffers.update({b:self.global_load(i, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs) \
+                           for i,b in enumerate(self.bufs) if b not in self.earlybufs and b.__class__ is not LocalBuffer})
+    # run late AST (without the store)
+    store_vals = {op.arg.idx:self.ast_parse(op.src[0], accs, None, loaded_buffers) for op in self.ast}
+    return [self.global_store(i, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, val) for i, val in store_vals.items()]
+
   def ast_parse(self, x:LazyOp, accs:Dict[LazyOp, List[UOp]], offs:Optional[List[int]], loaded_buffers:Dict[Union[MemBuffer, ConstBuffer, LocalBuffer], List[UOp]], reduce_acc:Optional[List[UOp]]=None, cache=None) -> List[UOp]:  # noqa: E501
     if cache is None: cache = {}
     if x in cache: return cache[x]
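
Editor's note (not part of the patch): the core shape of this refactor is that linearize now pushes every block, reduce or not, through the single render_block entry point, with at most one reduceop per block. The standalone Python below is a minimal sketch of that dispatch using hypothetical stand-ins (string ops, int lists standing in for UOps, a toy REDUCE_OPS set); it is not the tinygrad API.

    from typing import Dict, List, Tuple

    def dedup(xs):  # same contract as tinygrad.helpers.dedup: order-preserving unique
      return list(dict.fromkeys(xs))

    REDUCE_OPS = {"SUM", "MAX"}  # hypothetical stand-in for tinygrad's ReduceOps

    def render_block(outputs: Tuple[str, ...], accs: Dict[str, List[int]]) -> List[int]:
      # a block may carry at most one reduceop, mirroring the assert in the patch
      reduceops = dedup(x for x in outputs if x in REDUCE_OPS)
      assert len(reduceops) <= 1, "max one reduceop per block"
      if reduceops:
        # reduce path: render the accumulator and return it, as the patch returns accs[r]
        accs[(r := reduceops[0])] = [0]
        return accs[r]
      # non-reduce path: render the late AST and return the stores
      return [len(accs)]

    # usage mirroring the refactored linearize(): one call per reduceop, then one for the late AST
    accs: Dict[str, List[int]] = {}
    for reduceop in ("SUM",):
      render_block((reduceop,), accs)
    stores = render_block(("ADD", "MUL"), accs)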
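A second note, on the "inplace update of idxs" bullet: render_block writes local_idxs[:], upcast_idxs[:] = ... so that the index lists shrunken by render_reduceop are visible to linearize without being returned; slice assignment mutates the list object the caller passed in, while plain rebinding would not. A minimal illustration, independent of tinygrad:

    def shrink(idxs: list) -> None:
      idxs[:] = idxs[:1]  # mutates the caller's list; `idxs = idxs[:1]` would only rebind the local name

    local_idxs = ["lidx0", "lidx1", "lidx2"]
    shrink(local_idxs)
    print(local_idxs)  # ['lidx0'] -- the caller observes the in-place update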