simple contiguous_while_contiguous prereqs [pr] (#8038)

* simple contiguous_while_contiguous prereqs [pr]

* early realize

* fine if it's folding a non-contig buffer
This commit is contained in:
qazal
2024-12-04 10:00:28 -05:00
committed by GitHub
parent c9e7701417
commit ff6def9ffb
2 changed files with 9 additions and 10 deletions

View File

@@ -226,7 +226,7 @@ if getenv("RUN_PROCESS_REPLAY"):
# Return u.src[1] for a scheduled op, or that node's first source when it is a CONTIGUOUS op that passes the base check below.
# NOTE(review): this excerpt looks like a diff rendering with the +/- markers stripped, so the two `return`
# lines are presumably the before/after forms of a single statement -- confirm against the commit.
def uval(u:UOp) -> UOp:
# only ops for which is_scheduled(u) holds are accepted; anything else trips the assert
assert is_scheduled(u), f"must be a scheduled op {u}"
# form that relies on the is_contiguous_base property: unwrap to src[0] when the property is true, else return u.src[1] itself
return to_store.src[0] if (to_store:=u.src[1]).is_contiguous_base else to_store
# form with the check written inline: CONTIGUOUS op whose source's base is not a VIEW with exactly two sources
return r.src[0] if (r:=u.src[1]).op is Ops.CONTIGUOUS and not (r.src[0].base.op is Ops.VIEW and len(r.src[0].base.src) == 2) else r
def recursive_group(tr:UOp, st:ShapeTracker, r:UOp, children:DefaultDict[UOp, Dict[UOp, None]], allbufs:Dict[UOp, UOp], realizes:Dict[UOp, UOp],
reduce_for_op:Dict[UOp, UOp], group:Dict[UOp, None], cache:Dict[Tuple[UOp, ShapeTracker], None]) -> None:
@@ -313,7 +313,7 @@ def group_realizes(ctx:ScheduleContext) -> List[List[UOp]]:
# maybe fuse arange with its children
for rbuf in reduce_of_const:
group = {tr:None for tr,rop in reduce_for_op.items() if rop is rbuf}
if any(ctx.allbufs[tr].src[1].is_contiguous_base for tr in group): continue
if any(ctx.lazybufs[tr].forced_realize for tr in group): continue
kernel_children = {c for tr in group for c in ctx.children[tr] if uval(ctx.allbufs[c]).op not in {Ops.COPY, Ops.BUFFER_VIEW}}
if len(kernel_children) == 0: continue
for tr in group: del ctx.realizes[tr]
@@ -413,18 +413,19 @@ def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem]
ctx = ScheduleContext()
cache: Dict[LazyBuffer, UOp] = {}
buffers: Dict[UOp, Buffer] = {}
big_graph = graph_rewrite(UOp.sink(*(to_uop(x, ctx, buffers, cache) for x in outs)), ops_folding+do_realize, ctx.realizes)
for u in big_graph.src: ctx.realizes[u.buf_uop] = u
for u in (big_graph:=UOp.sink(*(to_uop(x, ctx, buffers, cache) for x in outs))).src: ctx.realizes[u.buf_uop] = u
big_graph = graph_rewrite(big_graph, ops_folding+do_realize, ctx.realizes)
# group realizes into kernels
store_groups = group_realizes(ctx)
graph_rewrite(big_graph, break_sched, ctx)
# preschedule realize groups
prescheduled: List[ScheduleItem] = []
for store_uops in store_groups:
ast, ast_ctx = full_ast_rewrite(UOp.sink(*(ctx.realizes[u] for u in store_uops)), ctx)
prescheduled.append(ScheduleItem(ast, tuple(buffers[u] for u in ast_ctx.bufs if u.size != 0), tuple(ast_ctx.metadata),
frozenset(ubuf for ubuf,ops in ast_ctx.assign_adj.items() if any(x.op is Ops.PRELOAD for x in ops))))
for u in ast_ctx.sinked: del ast_ctx.lazybufs[u].srcs # can only schedule once
if len(stores:=[ctx.realizes[u] for u in store_uops if ctx.realizes[u].op is Ops.STORE]) != 0:
ast, ast_ctx = full_ast_rewrite(UOp.sink(*stores), ctx)
prescheduled.append(ScheduleItem(ast, tuple(buffers[u] for u in ast_ctx.bufs if u.size != 0), tuple(ast_ctx.metadata),
frozenset(ubuf for ubuf,ops in ast_ctx.assign_adj.items() if any(x.op is Ops.PRELOAD for x in ops))))
for u in ast_ctx.sinked: del ast_ctx.lazybufs[u].srcs # can only schedule once
# do BFS
schedule_targets = {out:si for si in prescheduled for out in si.outputs}
graph: DefaultDict[ScheduleItem, List[ScheduleItem]] = defaultdict(list)

View File

@@ -330,8 +330,6 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
# Build a REDUCE_AXIS UOp with this UOp's dtype, `self` as the single entry of the source tuple, and (op, axis) as the last argument.
def r(self, op:Ops, axis:Tuple[int, ...]): return UOp(Ops.REDUCE_AXIS, self.dtype, (self,), (op, axis))
# Build an ASSIGN UOp with this UOp's dtype whose source tuple is (self, x).
def assign(self, x:UOp): return UOp(Ops.ASSIGN, self.dtype, (self,x))
# Wrap this UOp in a CONTIGUOUS UOp of the same dtype, with `self` as its only source.
def contiguous(self): return UOp(Ops.CONTIGUOUS, self.dtype, (self,))
# True when this UOp is a CONTIGUOUS op, unless the base of its first source is a VIEW with exactly two sources.
@property
def is_contiguous_base(self): return self.op is Ops.CONTIGUOUS and not (self.src[0].base.op is Ops.VIEW and len(self.src[0].base.src) == 2)
# *** from LazyBuffer ***