Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-26 23:38:58 -05:00
refactor linearize to render_block, P1 (#4839)
* refactor to render_block
* move rendering the reduce to its own thing
* add todo and cleanups [run_process_replay]
* inplace update of idxs [run_process_replay]
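The "inplace update of idxs" bullet refers to render_block writing its results back through list slices (local_idxs[:], upcast_idxs[:] = ...), so the caller's index lists are mutated in place instead of the old pattern of rebinding local_idxs, upcast_idxs. A minimal sketch of that Python idiom (illustrative only, not tinygrad code):

def rebind(idxs):
  idxs = [0, 0, 0]       # rebinds the local name only; the caller's list is untouched

def update_inplace(idxs):
  idxs[:] = [0, 0, 0]    # slice assignment mutates the caller's list in place

a = [1, 2, 3]
rebind(a);         print(a)   # [1, 2, 3]
update_inplace(a); print(a)   # [0, 0, 0]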
@@ -4,7 +4,7 @@ import itertools, math, functools
 from collections import defaultdict

 from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType, ConstType
-from tinygrad.helpers import colored, DEBUG, prod, getenv, to_function_name
+from tinygrad.helpers import colored, DEBUG, dedup, prod, getenv, to_function_name
 from tinygrad.ops import LazyOp, UnaryOps, BinaryOps, TernaryOps, ReduceOps, ConstBuffer, MemBuffer, BufferOps, get_lazyop_info
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.shape.symbolic import Variable, NumNode, Node, SumNode, MulNode, DivNode, ModNode, LtNode, AndNode, create_lt_node
@@ -373,6 +373,7 @@ class Linearizer(Kernel):
     global_idxs, loop_global_idxs = get_grouped_dims("gidx", 0, self.full_shape[:self.global_dims], 3 if self.opts.has_local else 0)
     local_idxs, loop_local_idxs = get_grouped_dims("lidx", self.global_dims, self.full_shape[self.global_dims:self.first_reduce+self.group_for_reduces], 3 if self.opts.has_local else 0)  # noqa: E501
     upcast_idxs = [Variable(f"_uidx{i}", 0, s-1) for i, s in enumerate(self.output_shape[self.shape_len-self.upcasted:])]
+    full_upcast_idxs = [Variable(f"_uidx{i}", 0, s-1) for i, s in enumerate(self.full_shape[self.shape_len-self.upcasted:])]

     # set global/local size
     self.global_size: Optional[List[int]] = None
@@ -389,29 +390,20 @@ class Linearizer(Kernel):
     if self.global_size is not None: self.global_size += [1]*(3-len(self.global_size))
     if self.local_size is not None: self.local_size += [1]*(3-len(self.local_size))

+    # define idxs for aliased buffers TODO: this doesn't belong in Kernel, but it can't exist in Block either (because of multireduce tensor cores)
+    reduce_idxs = [Variable(f"ridx{i}", 0, self.full_shape[i]-1) for i in range(self.first_reduce+self.group_for_reduces, self.shape_len-self.upcasted)]  # noqa: E501
+    alias_buf_idxs = self.index_local_aliases(global_idxs,local_idxs,reduce_idxs,upcast_idxs,full_upcast_idxs)
+
     # parse AST
+    self.load_cache: Dict[str, UOp] = {}
     loaded_buffers:Dict[Union[MemBuffer, ConstBuffer, LocalBuffer], List[UOp]] = {}
     accs: Dict[LazyOp, List[UOp]] = {}
-    self.load_cache: Dict[str, UOp] = {}

-    # define indexs
-    full_upcast_idxs = [Variable(f"_uidx{i}", 0, s-1) for i, s in enumerate(self.full_shape[self.shape_len-self.upcasted:])]
-    reduce_idxs = [Variable(f"ridx{i}", 0, self.full_shape[i]-1) for i in range(self.first_reduce+self.group_for_reduces, self.shape_len-self.upcasted)]  # noqa: E501
-    fake_reduce_idxs = [x*0 for x in reduce_idxs]
-    alias_buf_idxs = self.index_local_aliases(global_idxs,local_idxs,reduce_idxs,upcast_idxs,full_upcast_idxs)
-    # render reduce op
+    # render reduceops by depth
     for reduceop in self.reduceops:
-      local_idxs, upcast_idxs = self.render_reduceop(reduceop,accs,loaded_buffers,global_idxs,local_idxs,upcast_idxs,
-                                                     full_upcast_idxs,reduce_idxs,fake_reduce_idxs,alias_buf_idxs[reduceop])
+      self.render_block((reduceop, ), global_idxs, local_idxs, upcast_idxs, full_upcast_idxs, alias_buf_idxs, loaded_buffers, accs)

-    # load latebufs
-    loaded_buffers.update({b:self.global_load(i, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs) \
-                           for i,b in enumerate(self.bufs) if b not in self.earlybufs and b.__class__ is not LocalBuffer})
-
-    # run late AST (without the store)
-    for op in self.ast:
-      val = self.ast_parse(op.src[0], accs, None, loaded_buffers)
-      self.global_store(op.arg.idx, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, val)
+    self.render_block(self.ast, global_idxs, local_idxs, upcast_idxs, full_upcast_idxs, alias_buf_idxs, loaded_buffers, accs)

     # maybe graph the uops
     if DEBUG >= 5: self.uops.print()
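Both the removed code above and the new render_block below build fake_reduce_idxs = [x*0 for x in reduce_idxs]: multiplying a symbolic index by zero collapses it to the constant 0, so the late loads and stores address their buffers as if the reduce dimensions were not being iterated. A rough stand-in illustration (a hypothetical mock class, not tinygrad's symbolic Variable/Node):

class MockIdx:
  # hypothetical stand-in for a symbolic loop index with a [min, max] range
  def __init__(self, name, lo, hi): self.name, self.lo, self.hi = name, lo, hi
  def __mul__(self, b):
    assert b == 0, "mock only supports multiplying by 0"
    return 0  # like tinygrad's symbolic nodes, x*0 folds to the constant 0
  def __repr__(self): return f"{self.name}[{self.lo},{self.hi}]"

reduce_idxs = [MockIdx(f"ridx{i}", 0, 7) for i in range(2)]
fake_reduce_idxs = [x*0 for x in reduce_idxs]
print(reduce_idxs)       # [ridx0[0,7], ridx1[0,7]]
print(fake_reduce_idxs)  # [0, 0] -- the reduce dims contribute nothing to the address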
@@ -424,6 +416,25 @@ class Linearizer(Kernel):
     self.applied_opts_cache = self.applied_opts[:]
     return self

+  def render_block(self, outputs:Tuple[LazyOp, ...], global_idxs, local_idxs, upcast_idxs, full_upcast_idxs,
+                   alias_buf_idxs, loaded_buffers, accs) -> List[List[UOp]]:
+    assert len(reduceops:=dedup(x for x in outputs if x.op in ReduceOps)) <= 1, "max one reduceop per block"
+    reduce_idxs = [Variable(f"ridx{i}", 0, self.full_shape[i]-1) for i in range(self.first_reduce+self.group_for_reduces, self.shape_len-self.upcasted)]  # noqa: E501
+    fake_reduce_idxs = [x*0 for x in reduce_idxs]
+
+    if len(reduceops) != 0:
+      # TODO: delete render_reduceop and move the logic for group_for_reduces to Block
+      local_idxs[:], upcast_idxs[:] = self.render_reduceop((r:=reduceops[0]),accs,loaded_buffers,global_idxs,local_idxs,upcast_idxs, full_upcast_idxs,
+                                                           reduce_idxs,fake_reduce_idxs,alias_buf_idxs[r])
+      return accs[r]
+
+    # load latebufs
+    loaded_buffers.update({b:self.global_load(i, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs) \
+                           for i,b in enumerate(self.bufs) if b not in self.earlybufs and b.__class__ is not LocalBuffer})
+    # run late AST (without the store)
+    store_vals = {op.arg.idx:self.ast_parse(op.src[0], accs, None, loaded_buffers) for op in self.ast}
+    return [self.global_store(i, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, val) for i, val in store_vals.items()]
+
   def ast_parse(self, x:LazyOp, accs:Dict[LazyOp, List[UOp]], offs:Optional[List[int]], loaded_buffers:Dict[Union[MemBuffer, ConstBuffer, LocalBuffer], List[UOp]], reduce_acc:Optional[List[UOp]]=None, cache=None) -> List[UOp]:  # noqa: E501
     if cache is None: cache = {}
     if x in cache: return cache[x]
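render_block's first line packs a binding and a check into one assert: an assignment expression names the deduplicated reduceops while the assert enforces "max one reduceop per block". A small sketch of that idiom, with a local stand-in for tinygrad.helpers.dedup (order-preserving de-duplication) and string ops instead of LazyOps:

def dedup(xs): return list(dict.fromkeys(xs))   # stand-in for tinygrad.helpers.dedup

def check_block(outputs):
  # the walrus expression binds `reduceops` and the assert checks it in the same statement
  assert len(reduceops := dedup(x for x in outputs if x.startswith("REDUCE"))) <= 1, "max one reduceop per block"
  return reduceops

print(check_block(["LOAD", "REDUCE_SUM", "REDUCE_SUM", "STORE"]))  # ['REDUCE_SUM']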