From e3ea7cc4b49fe4e089583315b4f962f4aebd6c57 Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Fri, 1 Nov 2024 23:50:20 +0200 Subject: [PATCH] prep refactor to UPatLoadStore [pr] (#7472) * prep refactor to UPatLoadStore [pr] * [pr] --- tinygrad/engine/schedule.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index b37bfa0550..a01ddba0c9 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -228,11 +228,15 @@ if getenv("RUN_PROCESS_REPLAY"): # **** Schedule creation and BFS toposort -def _add_realize(ctx:Dict[UOp, UOp], b:UOp, store:UOp, load:UOp) -> Optional[UOp]: - if b not in ctx: return None +def realize(ctx:Dict[UOp, UOp], b:UOp, load:UOp, store:UOp) -> UOp: + assert b in ctx, f"trying to realize {b} while not realized" ctx[b] = store return UOp(UOps.LOAD, load.dtype, (b, load.st_arg.to_uop())) -break_sched = PatternMatcher([(UPat.load(b:=UPat.var("b"), UPat(), UPat.store(b, UPat(), UPat(), name="store"), name="load"), _add_realize),]) + +def UPatLoadStore(to_store=UPat()): return UPat.load(b:=UPat.var("b"), UPat(), UPat.store(b, UPat(), to_store, name="store"), name="load") +break_sched = PatternMatcher([ + (UPatLoadStore(), lambda ctx,b,store,load: realize(ctx, b, load, store) if b in ctx else None), +]) @track_rewrites(named=True) def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem], Dict[Variable, int]]: