diff --git a/tinygrad/engine/fuse.py b/tinygrad/engine/fuse.py index a1c201d104..d8e0730e69 100644 --- a/tinygrad/engine/fuse.py +++ b/tinygrad/engine/fuse.py @@ -14,7 +14,7 @@ sys.setrecursionlimit(10000) def _recurse_lb(buf:LazyBuffer, realizes:Dict[LazyBuffer, None], allbufs:Dict[LazyBuffer, None], simple_pads:Dict[LazyBuffer, None], \ children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]], assign_targets:Dict[LazyBuffer, LazyBuffer], double_reduces:Dict[LazyBuffer, None], ctx): """recursively search the entire graph for all LazyBuffers, insert realizes after expands""" - if buf in allbufs: return None + if buf in allbufs or buf.base.op is MetaOps.CONST: return None if buf.base.realized is not None: return realizes.setdefault(buf.base) # check if we need to realize views if buf is not buf.base: @@ -31,7 +31,7 @@ def _recurse_lb(buf:LazyBuffer, realizes:Dict[LazyBuffer, None], allbufs:Dict[La # check all other pads for safe fusion elif any(v.mask is not None for v in buf.st.views): simple_pads[buf.base] = None return _recurse_lb(buf.base, realizes, allbufs, simple_pads, children, assign_targets, double_reduces, ctx) - if buf.op is not MetaOps.CONST and ctx.buf_uops[buf.buffer] in ctx.realizes: realizes[buf] = None + if ctx.buf_uops[buf.buffer] in ctx.realizes: realizes[buf] = None if buf.op in ReduceOps and buf.srcs[0].base.op is buf.op and buf.srcs[0] is not buf.srcs[0].base: double_reduces[buf] = None allbufs[buf] = None if buf.op is MetaOps.ASSIGN: @@ -155,7 +155,7 @@ def get_realizes(outs:List[LazyBuffer], ctx) -> Tuple[List[List[UOp]], Dict[Buff output_groups: DefaultDict[LazyBuffer, List[UOp]] = defaultdict(list) lazybufs_to_realize: Dict[Buffer, LazyBuffer] = {} for buf in realizes: - if buf.realized is None and buf.op is not MetaOps.CONST: + if buf.realized is None: if (dup:=lazybufs_to_realize.get(buf.buffer)) is not None: raise RuntimeError(f"can't double realize in one schedule, Buffer is realizing both {dup} and {buf}") lazybufs_to_realize[buf.buffer] = buf