mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
simple contiguous_while_contiguous prereqs [pr] (#8038)
* simple contiguous_while_contiguous prereqs [pr]
* early realize
* fine if it's folding a non-contig buffer
This commit is contained in:
@@ -226,7 +226,7 @@ if getenv("RUN_PROCESS_REPLAY"):
|
||||
|
||||
def uval(u:UOp) -> UOp:
|
||||
assert is_scheduled(u), f"must be a scheduled op {u}"
|
||||
return to_store.src[0] if (to_store:=u.src[1]).is_contiguous_base else to_store
|
||||
return r.src[0] if (r:=u.src[1]).op is Ops.CONTIGUOUS and not (r.src[0].base.op is Ops.VIEW and len(r.src[0].base.src) == 2) else r
|
||||
|
||||
def recursive_group(tr:UOp, st:ShapeTracker, r:UOp, children:DefaultDict[UOp, Dict[UOp, None]], allbufs:Dict[UOp, UOp], realizes:Dict[UOp, UOp],
|
||||
reduce_for_op:Dict[UOp, UOp], group:Dict[UOp, None], cache:Dict[Tuple[UOp, ShapeTracker], None]) -> None:
|
||||
@@ -313,7 +313,7 @@ def group_realizes(ctx:ScheduleContext) -> List[List[UOp]]:
|
||||
# maybe fuse arange with its children
|
||||
for rbuf in reduce_of_const:
|
||||
group = {tr:None for tr,rop in reduce_for_op.items() if rop is rbuf}
|
||||
if any(ctx.allbufs[tr].src[1].is_contiguous_base for tr in group): continue
|
||||
if any(ctx.lazybufs[tr].forced_realize for tr in group): continue
|
||||
kernel_children = {c for tr in group for c in ctx.children[tr] if uval(ctx.allbufs[c]).op not in {Ops.COPY, Ops.BUFFER_VIEW}}
|
||||
if len(kernel_children) == 0: continue
|
||||
for tr in group: del ctx.realizes[tr]
|
||||
@@ -413,18 +413,19 @@ def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem]
|
||||
ctx = ScheduleContext()
|
||||
cache: Dict[LazyBuffer, UOp] = {}
|
||||
buffers: Dict[UOp, Buffer] = {}
|
||||
big_graph = graph_rewrite(UOp.sink(*(to_uop(x, ctx, buffers, cache) for x in outs)), ops_folding+do_realize, ctx.realizes)
|
||||
for u in big_graph.src: ctx.realizes[u.buf_uop] = u
|
||||
for u in (big_graph:=UOp.sink(*(to_uop(x, ctx, buffers, cache) for x in outs))).src: ctx.realizes[u.buf_uop] = u
|
||||
big_graph = graph_rewrite(big_graph, ops_folding+do_realize, ctx.realizes)
|
||||
# group realizes into kernels
|
||||
store_groups = group_realizes(ctx)
|
||||
graph_rewrite(big_graph, break_sched, ctx)
|
||||
# preschedule realize groups
|
||||
prescheduled: List[ScheduleItem] = []
|
||||
for store_uops in store_groups:
|
||||
ast, ast_ctx = full_ast_rewrite(UOp.sink(*(ctx.realizes[u] for u in store_uops)), ctx)
|
||||
prescheduled.append(ScheduleItem(ast, tuple(buffers[u] for u in ast_ctx.bufs if u.size != 0), tuple(ast_ctx.metadata),
|
||||
frozenset(ubuf for ubuf,ops in ast_ctx.assign_adj.items() if any(x.op is Ops.PRELOAD for x in ops))))
|
||||
for u in ast_ctx.sinked: del ast_ctx.lazybufs[u].srcs # can only schedule once
|
||||
if len(stores:=[ctx.realizes[u] for u in store_uops if ctx.realizes[u].op is Ops.STORE]) != 0:
|
||||
ast, ast_ctx = full_ast_rewrite(UOp.sink(*stores), ctx)
|
||||
prescheduled.append(ScheduleItem(ast, tuple(buffers[u] for u in ast_ctx.bufs if u.size != 0), tuple(ast_ctx.metadata),
|
||||
frozenset(ubuf for ubuf,ops in ast_ctx.assign_adj.items() if any(x.op is Ops.PRELOAD for x in ops))))
|
||||
for u in ast_ctx.sinked: del ast_ctx.lazybufs[u].srcs # can only schedule once
|
||||
# do BFS
|
||||
schedule_targets = {out:si for si in prescheduled for out in si.outputs}
|
||||
graph: DefaultDict[ScheduleItem, List[ScheduleItem]] = defaultdict(list)
|
||||
|
||||
@@ -330,8 +330,6 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
def r(self, op:Ops, axis:Tuple[int, ...]): return UOp(Ops.REDUCE_AXIS, self.dtype, (self,), (op, axis))
|
||||
def assign(self, x:UOp): return UOp(Ops.ASSIGN, self.dtype, (self,x))
|
||||
def contiguous(self): return UOp(Ops.CONTIGUOUS, self.dtype, (self,))
|
||||
@property
|
||||
def is_contiguous_base(self): return self.op is Ops.CONTIGUOUS and not (self.src[0].base.op is Ops.VIEW and len(self.src[0].base.src) == 2)
|
||||
|
||||
# *** from LazyBuffer ***
|
||||
|
||||
|
||||
Reference in New Issue
Block a user