diff --git a/tinygrad/engine/fuse.py b/tinygrad/engine/fuse.py index a3b4228a3e..75157ede56 100644 --- a/tinygrad/engine/fuse.py +++ b/tinygrad/engine/fuse.py @@ -34,7 +34,6 @@ def _recurse_lb(buf:LazyBuffer, realizes:Dict[LazyBuffer, None], allbufs:Dict[La if buf.op is not MetaOps.CONST and ctx.buf_uops[buf.buffer] in ctx.realizes: realizes[buf] = None if buf.op in ReduceOps and buf.srcs[0].base.op is buf.op and buf.srcs[0] is not buf.srcs[0].base: double_reduces[buf] = None allbufs[buf] = None - if buf.op in MetaOps: realizes[buf] = None if buf.op is MetaOps.ASSIGN: assign_targets[(target:=buf.srcs[0])] = buf assert target._base is None, f"assign must be to base {target}" diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index eaac950455..9df5e311ba 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -231,14 +231,15 @@ if getenv("RUN_PROCESS_REPLAY"): # **** Schedule creation and BFS toposort def realize(ctx:Dict[UOp, UOp], b:UOp, load:UOp, store:UOp) -> UOp: - assert b in ctx, f"trying to realize {b} while not realized" ctx[b] = store return UOp(UOps.LOAD, load.dtype, (b, load.st_arg.to_uop())) def UPatLoadStore(to_store=UPat()): return UPat.load(b:=UPat.var("b"), UPat(), UPat.store(b, UPat(), to_store, name="store"), name="load") -break_sched = PatternMatcher([ - (UPatLoadStore(), lambda ctx,b,store,load: realize(ctx, b, load, store) if b in ctx else None), +do_realize = PatternMatcher([ + # always realize meta ops + (UPatLoadStore(UPat((UOps.ASSIGN, UOps.CONTIGUOUS, *METAOPS.values()))), realize), ]) +break_sched = PatternMatcher([(UPatLoadStore(), lambda ctx,b,store,load: realize(ctx, b, load, store) if b in ctx else None),]) @track_rewrites(named=True) def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem], Dict[Variable, int]]: @@ -249,6 +250,7 @@ def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem] cache: Dict[LazyBuffer, UOp] = {} big_graph = UOp.sink(*(to_uop(x, ctx, cache) for x in outs)) # get realizes + graph_rewrite(big_graph, do_realize, ctx.realizes) store_groups, lazybufs_to_realize, assigns = get_realizes(outs, ctx) # split realizes into small graphs graph_rewrite(big_graph, break_sched, ctx.realizes)