mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-29 08:48:15 -05:00
realize meta ops from graph_rewrite [pr] (#7474)
This commit is contained in:
@@ -34,7 +34,6 @@ def _recurse_lb(buf:LazyBuffer, realizes:Dict[LazyBuffer, None], allbufs:Dict[La
|
||||
if buf.op is not MetaOps.CONST and ctx.buf_uops[buf.buffer] in ctx.realizes: realizes[buf] = None
|
||||
if buf.op in ReduceOps and buf.srcs[0].base.op is buf.op and buf.srcs[0] is not buf.srcs[0].base: double_reduces[buf] = None
|
||||
allbufs[buf] = None
|
||||
if buf.op in MetaOps: realizes[buf] = None
|
||||
if buf.op is MetaOps.ASSIGN:
|
||||
assign_targets[(target:=buf.srcs[0])] = buf
|
||||
assert target._base is None, f"assign must be to base {target}"
|
||||
|
||||
@@ -231,14 +231,15 @@ if getenv("RUN_PROCESS_REPLAY"):
|
||||
# **** Schedule creation and BFS toposort
|
||||
|
||||
def realize(ctx:Dict[UOp, UOp], b:UOp, load:UOp, store:UOp) -> UOp:
|
||||
assert b in ctx, f"trying to realize {b} while not realized"
|
||||
ctx[b] = store
|
||||
return UOp(UOps.LOAD, load.dtype, (b, load.st_arg.to_uop()))
|
||||
|
||||
def UPatLoadStore(to_store=UPat()): return UPat.load(b:=UPat.var("b"), UPat(), UPat.store(b, UPat(), to_store, name="store"), name="load")
|
||||
break_sched = PatternMatcher([
|
||||
(UPatLoadStore(), lambda ctx,b,store,load: realize(ctx, b, load, store) if b in ctx else None),
|
||||
do_realize = PatternMatcher([
|
||||
# always realize meta ops
|
||||
(UPatLoadStore(UPat((UOps.ASSIGN, UOps.CONTIGUOUS, *METAOPS.values()))), realize),
|
||||
])
|
||||
break_sched = PatternMatcher([(UPatLoadStore(), lambda ctx,b,store,load: realize(ctx, b, load, store) if b in ctx else None),])
|
||||
|
||||
@track_rewrites(named=True)
|
||||
def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem], Dict[Variable, int]]:
|
||||
@@ -249,6 +250,7 @@ def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem]
|
||||
cache: Dict[LazyBuffer, UOp] = {}
|
||||
big_graph = UOp.sink(*(to_uop(x, ctx, cache) for x in outs))
|
||||
# get realizes
|
||||
graph_rewrite(big_graph, do_realize, ctx.realizes)
|
||||
store_groups, lazybufs_to_realize, assigns = get_realizes(outs, ctx)
|
||||
# split realizes into small graphs
|
||||
graph_rewrite(big_graph, break_sched, ctx.realizes)
|
||||
|
||||
Reference in New Issue
Block a user