move assign_targets assignment (#5578)

This commit is contained in:
qazal
2024-07-19 20:29:50 +08:00
committed by GitHub
parent 0ad87021e2
commit ecf88bb775

View File

@@ -73,14 +73,9 @@ def _recursive_lazyop(buf:LazyBuffer, inputs:List[LazyBuffer], outputs:Tuple[Laz
return LazyOp(BufferOps.LOAD, (), MemBuffer(len(outputs)+inputs.index(buf), buf.dtype, unbound_st))
# if a CONTIGUOUS or ASSIGN made it all the way here, just skip it
if buf.op is MetaOps.CONTIGUOUS:
if buf.op in {MetaOps.CONTIGUOUS, MetaOps.ASSIGN}:
assert buf in outputs
return _recursive_lazyop(buf.srcs[0], inputs, outputs, var_vals, st, realizes, assign_targets, reduce_info, cache)
if buf.op is MetaOps.ASSIGN:
assert buf in outputs
assert buf.srcs[1].base is buf.srcs[1], "assign must be to base"
assert buf.srcs[1].realized is not None, f"assign must be already realized to schedule {buf.srcs[1]}"
return _recursive_lazyop(buf.srcs[0], inputs, outputs, var_vals, st, realizes, assign_targets, reduce_info, cache)
# if it's a reduce, we have to change the shapetracker
if buf.op in ReduceOps:
@@ -137,8 +132,8 @@ def _lower_lazybuffer(outs:List[LazyBuffer], realizes:Dict[LazyBuffer, None]):
# *** DAG creation: decide which LazyBuffers should realize ***
def _recurse_lb(buf:LazyBuffer, realizes:Dict[LazyBuffer, None], allbufs:Dict[LazyBuffer, None],
simple_pads:Set[LazyBuffer], children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]], scheduled=False):
def _recurse_lb(buf:LazyBuffer, realizes:Dict[LazyBuffer, None], allbufs:Dict[LazyBuffer, None], simple_pads:Set[LazyBuffer],\
children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]], assign_targets:Dict[LazyBuffer, LazyBuffer], scheduled=False):
"""recursively search the entire graph for all LazyBuffers, insert realizes after expands"""
if buf in allbufs or buf.base.realized is not None: return
if GRAPH: log_lazybuffer(buf, scheduled)
@@ -156,18 +151,21 @@ def _recurse_lb(buf:LazyBuffer, realizes:Dict[LazyBuffer, None], allbufs:Dict[La
elif not FUSE_AS_ONE_KERNEL: realizes[buf.base] = None
# check all other pads for safe fusion
elif any(v.mask is not None for v in buf.st.views): simple_pads.add(buf.base)
return _recurse_lb(buf.base, realizes, allbufs, simple_pads, children)
return _recurse_lb(buf.base, realizes, allbufs, simple_pads, children, assign_targets)
# base
allbufs[buf] = None
if buf.forced_realize: realizes[buf] = None
if buf.op in MetaOps: realizes[buf.base] = None
if buf.forced_realize or buf.op in MetaOps: realizes[buf] = None
if buf.op is MetaOps.ASSIGN:
assert buf.srcs[1].base is buf.srcs[1], f"assign must be to base {buf.srcs[1]}"
assert buf.srcs[1].realized is not None, f"assign must be already realized to schedule {buf.srcs[1]}"
assign_targets[buf.srcs[1]] = buf
if buf.op is MetaOps.COPY:
assert buf.srcs[0].st.contiguous and buf.srcs[0].size == buf.srcs[0].base.size, "can only copy contig"
realizes[buf.srcs[0].base] = None
if buf.op is MetaOps.VIEW: realizes[buf.srcs[0].base] = None
for x in buf.srcs:
if x.base.realized is None: children[x.base][buf] = None
_recurse_lb(x, realizes, allbufs, simple_pads, children)
_recurse_lb(x, realizes, allbufs, simple_pads, children, assign_targets)
def _is_padding_okay(buf:LazyBuffer, realizes:Dict[LazyBuffer, None]) -> bool:
if buf in realizes or buf.realized is not None: return True
@@ -199,8 +197,8 @@ def _graph_schedule(outs:List[LazyBuffer], seen:Set[LazyBuffer]):
allbufs: Dict[LazyBuffer, None] = {}
simple_pads: Set[LazyBuffer] = set()
children: DefaultDict[LazyBuffer, Dict[LazyBuffer, None]] = defaultdict(dict)
for out in outs: _recurse_lb(out.base, realizes, allbufs, simple_pads, children, scheduled=True)
assign_targets = {x.srcs[1]:x for x in realizes if x.op is MetaOps.ASSIGN and x not in seen and x.realized is None}
assign_targets: Dict[LazyBuffer, LazyBuffer] = {}
for out in outs: _recurse_lb(out.base, realizes, allbufs, simple_pads, children, assign_targets, scheduled=True)
# check if we have to realize pads
for p in simple_pads: