diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index b37bfa0550..a01ddba0c9 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -228,11 +228,15 @@ if getenv("RUN_PROCESS_REPLAY"): # **** Schedule creation and BFS toposort -def _add_realize(ctx:Dict[UOp, UOp], b:UOp, store:UOp, load:UOp) -> Optional[UOp]: - if b not in ctx: return None +def realize(ctx:Dict[UOp, UOp], b:UOp, load:UOp, store:UOp) -> UOp: + assert b in ctx, f"trying to realize {b} while not realized" ctx[b] = store return UOp(UOps.LOAD, load.dtype, (b, load.st_arg.to_uop())) -break_sched = PatternMatcher([(UPat.load(b:=UPat.var("b"), UPat(), UPat.store(b, UPat(), UPat(), name="store"), name="load"), _add_realize),]) + +def UPatLoadStore(to_store=UPat()): return UPat.load(b:=UPat.var("b"), UPat(), UPat.store(b, UPat(), to_store, name="store"), name="load") +break_sched = PatternMatcher([ + (UPatLoadStore(), lambda ctx,b,store,load: realize(ctx, b, load, store) if b in ctx else None), +]) @track_rewrites(named=True) def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem], Dict[Variable, int]]: