check assign buffers in group [pr] (#7527)

This commit is contained in:
qazal
2024-11-04 14:27:22 +02:00
committed by GitHub
parent 9fe596ce6e
commit bf31585444
2 changed files with 8 additions and 8 deletions

View File

@@ -1,6 +1,6 @@
import sys
from collections import defaultdict, deque
from typing import Tuple, List, Dict, DefaultDict
from typing import Set, Tuple, List, Dict, DefaultDict
from tinygrad.ops import GroupOp, MetaOps, ReduceOps, UOp, UnaryOps, resolve
from tinygrad.helpers import FUSE_CONV_BW, FUSE_ARANGE, prod, dedup, all_int, merge_dicts
from tinygrad.dtype import ImageDType
@@ -12,7 +12,7 @@ from tinygrad.device import Buffer
sys.setrecursionlimit(10000)
def _recurse_lb(buf:LazyBuffer, realizes:Dict[LazyBuffer, None], allbufs:Dict[LazyBuffer, None], simple_pads:Dict[LazyBuffer, None], \
children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]], assign_targets:Dict[LazyBuffer, LazyBuffer], double_reduces:Dict[LazyBuffer, None], ctx):
children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]], assign_targets:Set[Buffer], double_reduces:Dict[LazyBuffer, None], ctx):
"""recursively search the entire graph for all LazyBuffers, insert realizes after expands"""
if buf in allbufs or buf.base.op is MetaOps.CONST: return None
if buf.base.realized is not None: return realizes.setdefault(buf.base)
@@ -34,7 +34,7 @@ def _recurse_lb(buf:LazyBuffer, realizes:Dict[LazyBuffer, None], allbufs:Dict[La
if ctx.buf_uops[buf.buffer] in ctx.realizes: realizes[buf] = None
if buf.op in GroupOp.Reduce and buf.srcs[0].base.op is buf.op and buf.srcs[0] is not buf.srcs[0].base: double_reduces[buf] = None
allbufs[buf] = None
if buf.op is MetaOps.ASSIGN: assign_targets[buf.srcs[0]] = buf
if buf.op is MetaOps.ASSIGN: assign_targets.add(buf.buffer)
for x in buf.srcs:
if x.base.realized is None: children[x.base][buf] = None
_recurse_lb(x, realizes, allbufs, simple_pads, children, assign_targets, double_reduces, ctx)
@@ -79,13 +79,13 @@ def _get_isolated_children(r:LazyBuffer, reduce_for_op:Dict[LazyBuffer, LazyBuff
for tr in group: _recursive_group(tr, tr.st, tr, children, realizes, reduce_for_op, descendants, cache={})
return merge_dicts([group, {} if any(tr in group for tr in descendants) else descendants])
def get_realizes(outs:List[LazyBuffer], ctx) -> Tuple[List[List[UOp]], Dict[Buffer, LazyBuffer], Dict[LazyBuffer, LazyBuffer]]:
def get_realizes(outs:List[LazyBuffer], ctx) -> Tuple[List[List[UOp]], Dict[Buffer, LazyBuffer], Set[Buffer]]:
"""search the graph for all the LazyBuffers that need to realize"""
realizes: Dict[LazyBuffer, None] = {}
allbufs: Dict[LazyBuffer, None] = {}
simple_pads: Dict[LazyBuffer, None] = {}
children: DefaultDict[LazyBuffer, Dict[LazyBuffer, None]] = defaultdict(dict)
assign_targets: Dict[LazyBuffer, LazyBuffer] = {}
assign_targets: Set[Buffer] = set()
double_reduces: Dict[LazyBuffer, None] = {}
for out in outs: _recurse_lb(out, realizes, allbufs, simple_pads, children, assign_targets, double_reduces, ctx)
@@ -112,8 +112,8 @@ def get_realizes(outs:List[LazyBuffer], ctx) -> Tuple[List[List[UOp]], Dict[Buff
if not forced_realize and any(x.op is MetaOps.ASSIGN for x in group):
parents = deque((r, *group))
while parents and not forced_realize:
if (p:=parents.pop().base).realized or p in realizes:
if p in assign_targets and assign_targets[p] not in group: forced_realize, can_chase = True, False
if (p:=parents.pop().base).is_realized() or p in realizes:
if p.is_realized() and p.buffer in assign_targets and not any(x.buffer is p.buffer for x in group): forced_realize, can_chase = True, False
continue
parents.extend(p.srcs)
if forced_realize or not group:

View File

@@ -255,7 +255,7 @@ def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem]
store_groups, lazybufs_to_realize, assigns = get_realizes(outs, ctx)
# split realizes into small graphs
graph_rewrite(big_graph, break_sched, ctx.realizes)
assigned = {ubuf for x in assigns if (ubuf:=ctx.buf_uops.get(x.buffer)) is not None}
assigned = {ubuf for b in assigns if (ubuf:=ctx.buf_uops.get(b)) is not None}
small_graphs: List[Tuple[UOp, ScheduleItemContext]] = []
metadata: List[Set[Metadata]] = []
for stores in store_groups: