mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
move assign_targets assignment (#5578)
This commit is contained in:
@@ -73,14 +73,9 @@ def _recursive_lazyop(buf:LazyBuffer, inputs:List[LazyBuffer], outputs:Tuple[Laz
|
||||
return LazyOp(BufferOps.LOAD, (), MemBuffer(len(outputs)+inputs.index(buf), buf.dtype, unbound_st))
|
||||
|
||||
# if a CONTIGUOUS or ASSIGN made it all the way here, just skip it
|
||||
if buf.op is MetaOps.CONTIGUOUS:
|
||||
if buf.op in {MetaOps.CONTIGUOUS, MetaOps.ASSIGN}:
|
||||
assert buf in outputs
|
||||
return _recursive_lazyop(buf.srcs[0], inputs, outputs, var_vals, st, realizes, assign_targets, reduce_info, cache)
|
||||
if buf.op is MetaOps.ASSIGN:
|
||||
assert buf in outputs
|
||||
assert buf.srcs[1].base is buf.srcs[1], "assign must be to base"
|
||||
assert buf.srcs[1].realized is not None, f"assign must be already realized to schedule {buf.srcs[1]}"
|
||||
return _recursive_lazyop(buf.srcs[0], inputs, outputs, var_vals, st, realizes, assign_targets, reduce_info, cache)
|
||||
|
||||
# if it's a reduce, we have to change the shapetracker
|
||||
if buf.op in ReduceOps:
|
||||
@@ -137,8 +132,8 @@ def _lower_lazybuffer(outs:List[LazyBuffer], realizes:Dict[LazyBuffer, None]):
|
||||
|
||||
# *** DAG creation: decide which LazyBuffers should realize ***
|
||||
|
||||
def _recurse_lb(buf:LazyBuffer, realizes:Dict[LazyBuffer, None], allbufs:Dict[LazyBuffer, None],
|
||||
simple_pads:Set[LazyBuffer], children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]], scheduled=False):
|
||||
def _recurse_lb(buf:LazyBuffer, realizes:Dict[LazyBuffer, None], allbufs:Dict[LazyBuffer, None], simple_pads:Set[LazyBuffer],\
|
||||
children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]], assign_targets:Dict[LazyBuffer, LazyBuffer], scheduled=False):
|
||||
"""recursively search the entire graph for all LazyBuffers, insert realizes after expands"""
|
||||
if buf in allbufs or buf.base.realized is not None: return
|
||||
if GRAPH: log_lazybuffer(buf, scheduled)
|
||||
@@ -156,18 +151,21 @@ def _recurse_lb(buf:LazyBuffer, realizes:Dict[LazyBuffer, None], allbufs:Dict[La
|
||||
elif not FUSE_AS_ONE_KERNEL: realizes[buf.base] = None
|
||||
# check all other pads for safe fusion
|
||||
elif any(v.mask is not None for v in buf.st.views): simple_pads.add(buf.base)
|
||||
return _recurse_lb(buf.base, realizes, allbufs, simple_pads, children)
|
||||
return _recurse_lb(buf.base, realizes, allbufs, simple_pads, children, assign_targets)
|
||||
# base
|
||||
allbufs[buf] = None
|
||||
if buf.forced_realize: realizes[buf] = None
|
||||
if buf.op in MetaOps: realizes[buf.base] = None
|
||||
if buf.forced_realize or buf.op in MetaOps: realizes[buf] = None
|
||||
if buf.op is MetaOps.ASSIGN:
|
||||
assert buf.srcs[1].base is buf.srcs[1], f"assign must be to base {buf.srcs[1]}"
|
||||
assert buf.srcs[1].realized is not None, f"assign must be already realized to schedule {buf.srcs[1]}"
|
||||
assign_targets[buf.srcs[1]] = buf
|
||||
if buf.op is MetaOps.COPY:
|
||||
assert buf.srcs[0].st.contiguous and buf.srcs[0].size == buf.srcs[0].base.size, "can only copy contig"
|
||||
realizes[buf.srcs[0].base] = None
|
||||
if buf.op is MetaOps.VIEW: realizes[buf.srcs[0].base] = None
|
||||
for x in buf.srcs:
|
||||
if x.base.realized is None: children[x.base][buf] = None
|
||||
_recurse_lb(x, realizes, allbufs, simple_pads, children)
|
||||
_recurse_lb(x, realizes, allbufs, simple_pads, children, assign_targets)
|
||||
|
||||
def _is_padding_okay(buf:LazyBuffer, realizes:Dict[LazyBuffer, None]) -> bool:
|
||||
if buf in realizes or buf.realized is not None: return True
|
||||
@@ -199,8 +197,8 @@ def _graph_schedule(outs:List[LazyBuffer], seen:Set[LazyBuffer]):
|
||||
allbufs: Dict[LazyBuffer, None] = {}
|
||||
simple_pads: Set[LazyBuffer] = set()
|
||||
children: DefaultDict[LazyBuffer, Dict[LazyBuffer, None]] = defaultdict(dict)
|
||||
for out in outs: _recurse_lb(out.base, realizes, allbufs, simple_pads, children, scheduled=True)
|
||||
assign_targets = {x.srcs[1]:x for x in realizes if x.op is MetaOps.ASSIGN and x not in seen and x.realized is None}
|
||||
assign_targets: Dict[LazyBuffer, LazyBuffer] = {}
|
||||
for out in outs: _recurse_lb(out.base, realizes, allbufs, simple_pads, children, assign_targets, scheduled=True)
|
||||
|
||||
# check if we have to realize pads
|
||||
for p in simple_pads:
|
||||
|
||||
Reference in New Issue
Block a user