realize meta ops from graph_rewrite [pr] (#7474)

This commit is contained in:
qazal
2024-11-02 01:48:57 +02:00
committed by GitHub
parent e149777b52
commit 3819f5cf4d
2 changed files with 5 additions and 4 deletions

View File

@@ -34,7 +34,6 @@ def _recurse_lb(buf:LazyBuffer, realizes:Dict[LazyBuffer, None], allbufs:Dict[La
if buf.op is not MetaOps.CONST and ctx.buf_uops[buf.buffer] in ctx.realizes: realizes[buf] = None
if buf.op in ReduceOps and buf.srcs[0].base.op is buf.op and buf.srcs[0] is not buf.srcs[0].base: double_reduces[buf] = None
allbufs[buf] = None
if buf.op in MetaOps: realizes[buf] = None
if buf.op is MetaOps.ASSIGN:
assign_targets[(target:=buf.srcs[0])] = buf
assert target._base is None, f"assign must be to base {target}"

View File

@@ -231,14 +231,15 @@ if getenv("RUN_PROCESS_REPLAY"):
# **** Schedule creation and BFS toposort
# Record `store` as the realization of buffer `b` in `ctx`, then return a fresh LOAD of `b`
# to stand in for the matched load.
#   ctx:   mapping keyed by buffer UOp; the entry for `b` is overwritten in place
#   b:     the buffer UOp being realized
#   load:  the matched load; supplies the dtype and `st_arg` of the replacement
#   store: the matched store, saved as ctx[b]
# Returns a UOps.LOAD UOp whose sources are (b, load.st_arg.to_uop()).
def realize(ctx:Dict[UOp, UOp], b:UOp, load:UOp, store:UOp) -> UOp:
# NOTE(review): this text is a diff whose +/- markers were stripped, so the assert below may be
# a removed line; do_realize passes `realize` without a `b in ctx` guard -- confirm against the repo.
assert b in ctx, f"trying to realize {b} while not realized"
ctx[b] = store
# the replacement load is built with only two sources, so the store is not one of them
return UOp(UOps.LOAD, load.dtype, (b, load.st_arg.to_uop()))
# Build a pattern for a load from buffer `b` whose third source is a store into that same `b`.
# `to_store` is the pattern for the stored value (default `UPat()`, presumably match-anything -- confirm).
# The buffer, the store and the load are bound to the names "b", "store" and "load".
def UPatLoadStore(to_store=UPat()): return UPat.load(b:=UPat.var("b"), UPat(), UPat.store(b, UPat(), to_store, name="store"), name="load")
# NOTE(review): diff markers were stripped from this page; the two-line `break_sched` directly
# below looks like the removed form and the one-line `break_sched` at the end its replacement
# -- confirm against the repository before relying on this.
break_sched = PatternMatcher([
(UPatLoadStore(), lambda ctx,b,store,load: realize(ctx, b, load, store) if b in ctx else None),
# do_realize: calls `realize` unconditionally for load/store pairs whose stored value is an
# ASSIGN, a CONTIGUOUS, or any op in METAOPS.values().
do_realize = PatternMatcher([
# always realize meta ops
(UPatLoadStore(UPat((UOps.ASSIGN, UOps.CONTIGUOUS, *METAOPS.values()))), realize),
])
# break_sched: calls `realize` only when buffer `b` is already a key of ctx; otherwise the
# callback returns None (presumably leaving the match unrewritten -- confirm).
break_sched = PatternMatcher([(UPatLoadStore(), lambda ctx,b,store,load: realize(ctx, b, load, store) if b in ctx else None),])
@track_rewrites(named=True)
def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem], Dict[Variable, int]]:
@@ -249,6 +250,7 @@ def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem]
cache: Dict[LazyBuffer, UOp] = {}
big_graph = UOp.sink(*(to_uop(x, ctx, cache) for x in outs))
# get realizes
graph_rewrite(big_graph, do_realize, ctx.realizes)
store_groups, lazybufs_to_realize, assigns = get_realizes(outs, ctx)
# split realizes into small graphs
graph_rewrite(big_graph, break_sched, ctx.realizes)