mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
move groups to uop [pr] (#7640)
* override group post chase [pr]
* key reduceop on ubuf
* fix type
This commit is contained in:
@@ -7,7 +7,7 @@ from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.engine.lazy import LazyBuffer
|
||||
|
||||
def _recursive_group(tr:LazyBuffer, st:ShapeTracker, r:LazyBuffer, children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]],
|
||||
realizes:Dict[LazyBuffer, None], reduce_for_op:Dict[LazyBuffer, LazyBuffer], group:Dict[LazyBuffer, None],
|
||||
realizes:Dict[LazyBuffer, None], reduce_for_op:Dict[LazyBuffer, UOp], group:Dict[LazyBuffer, None],
|
||||
cache:Dict[Tuple[LazyBuffer, ShapeTracker], None]) -> None:
|
||||
"""recursively search the LazyBuffer for groupable children, realize the LazyBuffer if a child can't group"""
|
||||
if (tr, st) in cache: return
|
||||
@@ -24,7 +24,7 @@ def _recursive_group(tr:LazyBuffer, st:ShapeTracker, r:LazyBuffer, children:Defa
|
||||
if len(st_childs:=dedup(s for s in tr_next.srcs if s.base == tr)) > 1: return group.setdefault(r)
|
||||
_recursive_group(tr_next, st+st_childs[0].st, r, children, realizes, reduce_for_op, group, cache)
|
||||
|
||||
def _get_isolated_children(r:LazyBuffer, reduce_for_op:Dict[LazyBuffer, LazyBuffer], children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]],\
|
||||
def _get_isolated_children(r:LazyBuffer, reduce_for_op:Dict[LazyBuffer, UOp], children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]],\
|
||||
realizes:Dict[LazyBuffer, None], group:Dict[LazyBuffer, None]) -> Dict[LazyBuffer, None]:
|
||||
rc_parents, cache = deque(group), set()
|
||||
while rc_parents:
|
||||
@@ -48,8 +48,8 @@ def get_realizes(children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]], allbu
|
||||
if (ubuf:=buf_uops[r.buffer]) in ubuf_realizes: realizes[r] = None
|
||||
if r.op is Ops.ASSIGN: assigns.add(ubuf)
|
||||
# find all reduces, and pair them to an elementwise op. if they can't be cleanly paired, force realize the reduce (or a contig child)
|
||||
reduce_for_op: Dict[LazyBuffer, LazyBuffer] = {}
|
||||
reduce_of_const: List[LazyBuffer] = []
|
||||
reduce_for_op: Dict[LazyBuffer, UOp] = {}
|
||||
reduce_of_const: List[UOp] = []
|
||||
for r in allbufs:
|
||||
if r in realizes or r.op not in GroupOp.Reduce: continue
|
||||
group: Dict[LazyBuffer, None] = {}
|
||||
@@ -84,10 +84,11 @@ def get_realizes(children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]], allbu
|
||||
# don't cast to higher size before store (tr cannot be realized if forced_realize)
|
||||
if tr.op is UnaryOps.CAST and tr.arg.itemsize > tr.srcs[0].dtype.itemsize:
|
||||
tr = tr.srcs[0].base
|
||||
reduce_for_op[tr] = r
|
||||
group = {tr: None}
|
||||
realizes[tr] = None
|
||||
else: reduce_for_op.update((tr, r) for tr in group)
|
||||
if FUSE_ARANGE and r.op is ReduceOps.SUM and r.srcs[0].base.op is MetaOps.CONST: reduce_of_const.append(r)
|
||||
rbuf = buf_uops[r.buffer]
|
||||
reduce_for_op.update((tr, rbuf) for tr in group)
|
||||
if FUSE_ARANGE and r.op is ReduceOps.SUM and r.srcs[0].base.op is MetaOps.CONST: reduce_of_const.append(rbuf)
|
||||
|
||||
# fuse double reduces with no other child
|
||||
if FUSE_CONV_BW:
|
||||
@@ -97,8 +98,8 @@ def get_realizes(children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]], allbu
|
||||
del realizes[top_reduce]
|
||||
if (ubuf:=buf_uops[top_reduce.buffer]) in ubuf_realizes: del ubuf_realizes[ubuf]
|
||||
|
||||
for r in reduce_of_const:
|
||||
group = {tr:None for tr,rop in reduce_for_op.items() if rop is r}
|
||||
for rbuf in reduce_of_const:
|
||||
group = {tr:None for tr,rop in reduce_for_op.items() if rop is rbuf}
|
||||
if any(tr.forced_realize for tr in group): continue
|
||||
kernel_children = {c for tr in group for c in children[tr] if c.op not in {MetaOps.COPY, MetaOps.BUFFER_VIEW}}
|
||||
if len(kernel_children) == 0: continue
|
||||
@@ -106,8 +107,8 @@ def get_realizes(children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]], allbu
|
||||
del realizes[tr]
|
||||
if (ubuf:=buf_uops[tr.buffer]) in ubuf_realizes: del ubuf_realizes[ubuf]
|
||||
|
||||
output_groups: DefaultDict[LazyBuffer, List[UOp]] = defaultdict(list)
|
||||
output_groups: DefaultDict[UOp, List[UOp]] = defaultdict(list)
|
||||
for buf in realizes:
|
||||
output_groups[reduce_for_op.get(buf, buf)].append(ubuf:=buf_uops[buf.buffer])
|
||||
output_groups[reduce_for_op.get(buf, ubuf:=buf_uops[buf.buffer])].append(ubuf)
|
||||
ubuf_realizes[ubuf] = ubuf
|
||||
return list(output_groups.values())
|
||||
|
||||
Reference in New Issue
Block a user