split to get_realizes [pr] (#7225)

This commit is contained in:
qazal
2024-10-23 10:22:36 +03:00
committed by GitHub
parent f890d1cbbd
commit 3ce1c69c9c

View File

@@ -282,9 +282,8 @@ def _get_isolated_children(r:LazyBuffer, reduce_for_op:Dict[LazyBuffer, LazyBuff
for tr in group: _recursive_group(tr, tr.st, tr, children, realizes, reduce_for_op, descendants, cache={})
return merge_dicts([group, {} if any(tr in group for tr in descendants) else descendants])
@track_rewrites(named=True)
def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem], Dict[Variable, int]]:
"""create a graph for realizing the outputs"""
def get_realizes(outs:List[LazyBuffer]) -> Tuple[Dict[LazyBuffer, None], Dict[LazyBuffer, LazyBuffer]]:
"""search the graph for all the LazyBuffers that need to realize"""
# start by just realizing the buffers passed in
realizes: Dict[LazyBuffer, None] = {x.base:None for x in outs if x.base.realized is None}
allbufs: Dict[LazyBuffer, None] = {}
@@ -356,11 +355,16 @@ def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem]
if len(kernel_children) == 0: continue
if DEBUG_ARANGE: print(colored(f"folding {r}", "green"))
for tr in group: del realizes[tr]
return realizes, reduce_for_op
@track_rewrites(named=True)
def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem], Dict[Variable, int]]:
realizes, reduce_for_op = get_realizes(outs)
output_groups: DefaultDict[LazyBuffer, List[UOp]] = defaultdict(list)
buf_uops: Dict[Buffer, UOp] = {}
uop_bufs: Dict[UOp, Buffer] = {}
var_vals: Dict[Variable, int] = {}
assign_targets: Dict[LazyBuffer, LazyBuffer] = {}
lazybufs_to_realize: Dict[Buffer, LazyBuffer] = {}
for buf in realizes:
if buf.realized is None and buf.op is not MetaOps.CONST:
@@ -386,7 +390,9 @@ def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem]
if buf.buffer not in buf_uops:
buf_uops[buf.buffer] = uop
uop_bufs[uop] = buf.buffer
if buf.realized is None and buf.op is not MetaOps.CONST: output_groups[reduce_for_op.get(buf, buf)].append(buf_uops[buf.buffer])
if buf.realized is None:
if buf.op is MetaOps.ASSIGN: assign_targets[buf.srcs[0]] = buf
if buf.op is not MetaOps.CONST:output_groups[reduce_for_op.get(buf, buf)].append(buf_uops[buf.buffer])
# preschedule all buffers in realizes
prescheduled = [_lower_lazybuffer([lazybufs_to_realize[uop_bufs[b]] for b in outs], buf_uops, var_vals) for outs in output_groups.values()]