early skip const [pr] (#7480)

This commit is contained in:
qazal
2024-11-02 07:18:45 +02:00
committed by GitHub
parent c56364fad0
commit 24d7fde63d

View File

@@ -14,7 +14,7 @@ sys.setrecursionlimit(10000)
def _recurse_lb(buf:LazyBuffer, realizes:Dict[LazyBuffer, None], allbufs:Dict[LazyBuffer, None], simple_pads:Dict[LazyBuffer, None], \
children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]], assign_targets:Dict[LazyBuffer, LazyBuffer], double_reduces:Dict[LazyBuffer, None], ctx):
"""recursively search the entire graph for all LazyBuffers, insert realizes after expands"""
if buf in allbufs: return None
if buf in allbufs or buf.base.op is MetaOps.CONST: return None
if buf.base.realized is not None: return realizes.setdefault(buf.base)
# check if we need to realize views
if buf is not buf.base:
@@ -31,7 +31,7 @@ def _recurse_lb(buf:LazyBuffer, realizes:Dict[LazyBuffer, None], allbufs:Dict[La
# check all other pads for safe fusion
elif any(v.mask is not None for v in buf.st.views): simple_pads[buf.base] = None
return _recurse_lb(buf.base, realizes, allbufs, simple_pads, children, assign_targets, double_reduces, ctx)
if buf.op is not MetaOps.CONST and ctx.buf_uops[buf.buffer] in ctx.realizes: realizes[buf] = None
if ctx.buf_uops[buf.buffer] in ctx.realizes: realizes[buf] = None
if buf.op in ReduceOps and buf.srcs[0].base.op is buf.op and buf.srcs[0] is not buf.srcs[0].base: double_reduces[buf] = None
allbufs[buf] = None
if buf.op is MetaOps.ASSIGN:
@@ -155,7 +155,7 @@ def get_realizes(outs:List[LazyBuffer], ctx) -> Tuple[List[List[UOp]], Dict[Buff
output_groups: DefaultDict[LazyBuffer, List[UOp]] = defaultdict(list)
lazybufs_to_realize: Dict[Buffer, LazyBuffer] = {}
for buf in realizes:
if buf.realized is None and buf.op is not MetaOps.CONST:
if buf.realized is None:
if (dup:=lazybufs_to_realize.get(buf.buffer)) is not None:
raise RuntimeError(f"can't double realize in one schedule, Buffer is realizing both {dup} and {buf}")
lazybufs_to_realize[buf.buffer] = buf