mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
check assign buffers in group [pr] (#7527)
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
import sys
|
||||
from collections import defaultdict, deque
|
||||
from typing import Tuple, List, Dict, DefaultDict
|
||||
from typing import Set, Tuple, List, Dict, DefaultDict
|
||||
from tinygrad.ops import GroupOp, MetaOps, ReduceOps, UOp, UnaryOps, resolve
|
||||
from tinygrad.helpers import FUSE_CONV_BW, FUSE_ARANGE, prod, dedup, all_int, merge_dicts
|
||||
from tinygrad.dtype import ImageDType
|
||||
@@ -12,7 +12,7 @@ from tinygrad.device import Buffer
|
||||
sys.setrecursionlimit(10000)
|
||||
|
||||
def _recurse_lb(buf:LazyBuffer, realizes:Dict[LazyBuffer, None], allbufs:Dict[LazyBuffer, None], simple_pads:Dict[LazyBuffer, None], \
|
||||
children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]], assign_targets:Dict[LazyBuffer, LazyBuffer], double_reduces:Dict[LazyBuffer, None], ctx):
|
||||
children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]], assign_targets:Set[Buffer], double_reduces:Dict[LazyBuffer, None], ctx):
|
||||
"""recursively search the entire graph for all LazyBuffers, insert realizes after expands"""
|
||||
if buf in allbufs or buf.base.op is MetaOps.CONST: return None
|
||||
if buf.base.realized is not None: return realizes.setdefault(buf.base)
|
||||
@@ -34,7 +34,7 @@ def _recurse_lb(buf:LazyBuffer, realizes:Dict[LazyBuffer, None], allbufs:Dict[La
|
||||
if ctx.buf_uops[buf.buffer] in ctx.realizes: realizes[buf] = None
|
||||
if buf.op in GroupOp.Reduce and buf.srcs[0].base.op is buf.op and buf.srcs[0] is not buf.srcs[0].base: double_reduces[buf] = None
|
||||
allbufs[buf] = None
|
||||
if buf.op is MetaOps.ASSIGN: assign_targets[buf.srcs[0]] = buf
|
||||
if buf.op is MetaOps.ASSIGN: assign_targets.add(buf.buffer)
|
||||
for x in buf.srcs:
|
||||
if x.base.realized is None: children[x.base][buf] = None
|
||||
_recurse_lb(x, realizes, allbufs, simple_pads, children, assign_targets, double_reduces, ctx)
|
||||
@@ -79,13 +79,13 @@ def _get_isolated_children(r:LazyBuffer, reduce_for_op:Dict[LazyBuffer, LazyBuff
|
||||
for tr in group: _recursive_group(tr, tr.st, tr, children, realizes, reduce_for_op, descendants, cache={})
|
||||
return merge_dicts([group, {} if any(tr in group for tr in descendants) else descendants])
|
||||
|
||||
def get_realizes(outs:List[LazyBuffer], ctx) -> Tuple[List[List[UOp]], Dict[Buffer, LazyBuffer], Dict[LazyBuffer, LazyBuffer]]:
|
||||
def get_realizes(outs:List[LazyBuffer], ctx) -> Tuple[List[List[UOp]], Dict[Buffer, LazyBuffer], Set[Buffer]]:
|
||||
"""search the graph for all the LazyBuffers that need to realize"""
|
||||
realizes: Dict[LazyBuffer, None] = {}
|
||||
allbufs: Dict[LazyBuffer, None] = {}
|
||||
simple_pads: Dict[LazyBuffer, None] = {}
|
||||
children: DefaultDict[LazyBuffer, Dict[LazyBuffer, None]] = defaultdict(dict)
|
||||
assign_targets: Dict[LazyBuffer, LazyBuffer] = {}
|
||||
assign_targets: Set[Buffer] = set()
|
||||
double_reduces: Dict[LazyBuffer, None] = {}
|
||||
for out in outs: _recurse_lb(out, realizes, allbufs, simple_pads, children, assign_targets, double_reduces, ctx)
|
||||
|
||||
@@ -112,8 +112,8 @@ def get_realizes(outs:List[LazyBuffer], ctx) -> Tuple[List[List[UOp]], Dict[Buff
|
||||
if not forced_realize and any(x.op is MetaOps.ASSIGN for x in group):
|
||||
parents = deque((r, *group))
|
||||
while parents and not forced_realize:
|
||||
if (p:=parents.pop().base).realized or p in realizes:
|
||||
if p in assign_targets and assign_targets[p] not in group: forced_realize, can_chase = True, False
|
||||
if (p:=parents.pop().base).is_realized() or p in realizes:
|
||||
if p.is_realized() and p.buffer in assign_targets and not any(x.buffer is p.buffer for x in group): forced_realize, can_chase = True, False
|
||||
continue
|
||||
parents.extend(p.srcs)
|
||||
if forced_realize or not group:
|
||||
|
||||
@@ -255,7 +255,7 @@ def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem]
|
||||
store_groups, lazybufs_to_realize, assigns = get_realizes(outs, ctx)
|
||||
# split realizes into small graphs
|
||||
graph_rewrite(big_graph, break_sched, ctx.realizes)
|
||||
assigned = {ubuf for x in assigns if (ubuf:=ctx.buf_uops.get(x.buffer)) is not None}
|
||||
assigned = {ubuf for b in assigns if (ubuf:=ctx.buf_uops.get(b)) is not None}
|
||||
small_graphs: List[Tuple[UOp, ScheduleItemContext]] = []
|
||||
metadata: List[Set[Metadata]] = []
|
||||
for stores in store_groups:
|
||||
|
||||
Reference in New Issue
Block a user