refactor linearize to render_block, P1 (#4839)

* refactor to render_block

* move rendering the reduce to its own thing

* add todo and cleanups [run_process_replay]

* inplace update of idxs [run_process_replay]
This commit is contained in:
qazal
2024-06-06 20:31:43 +08:00
committed by GitHub
parent b932ce0f1d
commit eeb5a7af39

View File

@@ -4,7 +4,7 @@ import itertools, math, functools
from collections import defaultdict
from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType, ConstType
from tinygrad.helpers import colored, DEBUG, prod, getenv, to_function_name
from tinygrad.helpers import colored, DEBUG, dedup, prod, getenv, to_function_name
from tinygrad.ops import LazyOp, UnaryOps, BinaryOps, TernaryOps, ReduceOps, ConstBuffer, MemBuffer, BufferOps, get_lazyop_info
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.symbolic import Variable, NumNode, Node, SumNode, MulNode, DivNode, ModNode, LtNode, AndNode, create_lt_node
@@ -373,6 +373,7 @@ class Linearizer(Kernel):
global_idxs, loop_global_idxs = get_grouped_dims("gidx", 0, self.full_shape[:self.global_dims], 3 if self.opts.has_local else 0)
local_idxs, loop_local_idxs = get_grouped_dims("lidx", self.global_dims, self.full_shape[self.global_dims:self.first_reduce+self.group_for_reduces], 3 if self.opts.has_local else 0) # noqa: E501
upcast_idxs = [Variable(f"_uidx{i}", 0, s-1) for i, s in enumerate(self.output_shape[self.shape_len-self.upcasted:])]
full_upcast_idxs = [Variable(f"_uidx{i}", 0, s-1) for i, s in enumerate(self.full_shape[self.shape_len-self.upcasted:])]
# set global/local size
self.global_size: Optional[List[int]] = None
@@ -389,29 +390,20 @@ class Linearizer(Kernel):
if self.global_size is not None: self.global_size += [1]*(3-len(self.global_size))
if self.local_size is not None: self.local_size += [1]*(3-len(self.local_size))
# define idxs for aliased buffers TODO: this doesn't belong in Kernel, but it can't exist in Block either (because of multireduce tensor cores)
reduce_idxs = [Variable(f"ridx{i}", 0, self.full_shape[i]-1) for i in range(self.first_reduce+self.group_for_reduces, self.shape_len-self.upcasted)] # noqa: E501
alias_buf_idxs = self.index_local_aliases(global_idxs,local_idxs,reduce_idxs,upcast_idxs,full_upcast_idxs)
# parse AST
self.load_cache: Dict[str, UOp] = {}
loaded_buffers:Dict[Union[MemBuffer, ConstBuffer, LocalBuffer], List[UOp]] = {}
accs: Dict[LazyOp, List[UOp]] = {}
self.load_cache: Dict[str, UOp] = {}
# define indexs
full_upcast_idxs = [Variable(f"_uidx{i}", 0, s-1) for i, s in enumerate(self.full_shape[self.shape_len-self.upcasted:])]
reduce_idxs = [Variable(f"ridx{i}", 0, self.full_shape[i]-1) for i in range(self.first_reduce+self.group_for_reduces, self.shape_len-self.upcasted)] # noqa: E501
fake_reduce_idxs = [x*0 for x in reduce_idxs]
alias_buf_idxs = self.index_local_aliases(global_idxs,local_idxs,reduce_idxs,upcast_idxs,full_upcast_idxs)
# render reduce op
# render reduceops by depth
for reduceop in self.reduceops:
local_idxs, upcast_idxs = self.render_reduceop(reduceop,accs,loaded_buffers,global_idxs,local_idxs,upcast_idxs,
full_upcast_idxs,reduce_idxs,fake_reduce_idxs,alias_buf_idxs[reduceop])
self.render_block((reduceop, ), global_idxs, local_idxs, upcast_idxs, full_upcast_idxs, alias_buf_idxs, loaded_buffers, accs)
# load latebufs
loaded_buffers.update({b:self.global_load(i, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs) \
for i,b in enumerate(self.bufs) if b not in self.earlybufs and b.__class__ is not LocalBuffer})
# run late AST (without the store)
for op in self.ast:
val = self.ast_parse(op.src[0], accs, None, loaded_buffers)
self.global_store(op.arg.idx, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, val)
self.render_block(self.ast, global_idxs, local_idxs, upcast_idxs, full_upcast_idxs, alias_buf_idxs, loaded_buffers, accs)
# maybe graph the uops
if DEBUG >= 5: self.uops.print()
@@ -424,6 +416,25 @@ class Linearizer(Kernel):
self.applied_opts_cache = self.applied_opts[:]
return self
def render_block(self, outputs:Tuple[LazyOp, ...], global_idxs, local_idxs, upcast_idxs, full_upcast_idxs,
alias_buf_idxs, loaded_buffers, accs) -> List[List[UOp]]:
assert len(reduceops:=dedup(x for x in outputs if x.op in ReduceOps)) <= 1, "max one reduceop per block"
reduce_idxs = [Variable(f"ridx{i}", 0, self.full_shape[i]-1) for i in range(self.first_reduce+self.group_for_reduces, self.shape_len-self.upcasted)] # noqa: E501
fake_reduce_idxs = [x*0 for x in reduce_idxs]
if len(reduceops) != 0:
# TODO: delete render_reduceop and move the logic for group_for_reduces to Block
local_idxs[:], upcast_idxs[:] = self.render_reduceop((r:=reduceops[0]),accs,loaded_buffers,global_idxs,local_idxs,upcast_idxs, full_upcast_idxs,
reduce_idxs,fake_reduce_idxs,alias_buf_idxs[r])
return accs[r]
# load latebufs
loaded_buffers.update({b:self.global_load(i, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs) \
for i,b in enumerate(self.bufs) if b not in self.earlybufs and b.__class__ is not LocalBuffer})
# run late AST (without the store)
store_vals = {op.arg.idx:self.ast_parse(op.src[0], accs, None, loaded_buffers) for op in self.ast}
return [self.global_store(i, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, val) for i, val in store_vals.items()]
def ast_parse(self, x:LazyOp, accs:Dict[LazyOp, List[UOp]], offs:Optional[List[int]], loaded_buffers:Dict[Union[MemBuffer, ConstBuffer, LocalBuffer], List[UOp]], reduce_acc:Optional[List[UOp]]=None, cache=None) -> List[UOp]: # noqa: E501
if cache is None: cache = {}
if x in cache: return cache[x]