prep refactor to UPatLoadStore [pr] (#7472)

* prep refactor to UPatLoadStore [pr]

* [pr]
This commit is contained in:
qazal
2024-11-01 23:50:20 +02:00
committed by GitHub
parent 6febd20fcf
commit e3ea7cc4b4

View File

@@ -228,11 +228,15 @@ if getenv("RUN_PROCESS_REPLAY"):
# **** Schedule creation and BFS toposort
def _add_realize(ctx:Dict[UOp, UOp], b:UOp, store:UOp, load:UOp) -> Optional[UOp]:
if b not in ctx: return None
def realize(ctx:Dict[UOp, UOp], b:UOp, load:UOp, store:UOp) -> UOp:
assert b in ctx, f"trying to realize {b} while not realized"
ctx[b] = store
return UOp(UOps.LOAD, load.dtype, (b, load.st_arg.to_uop()))
break_sched = PatternMatcher([(UPat.load(b:=UPat.var("b"), UPat(), UPat.store(b, UPat(), UPat(), name="store"), name="load"), _add_realize),])
def UPatLoadStore(to_store=UPat()): return UPat.load(b:=UPat.var("b"), UPat(), UPat.store(b, UPat(), to_store, name="store"), name="load")
break_sched = PatternMatcher([
(UPatLoadStore(), lambda ctx,b,store,load: realize(ctx, b, load, store) if b in ctx else None),
])
@track_rewrites(named=True)
def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem], Dict[Variable, int]]: