mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
split to get_realizes [pr] (#7225)
This commit is contained in:
@@ -282,9 +282,8 @@ def _get_isolated_children(r:LazyBuffer, reduce_for_op:Dict[LazyBuffer, LazyBuff
|
||||
for tr in group: _recursive_group(tr, tr.st, tr, children, realizes, reduce_for_op, descendants, cache={})
|
||||
return merge_dicts([group, {} if any(tr in group for tr in descendants) else descendants])
|
||||
|
||||
@track_rewrites(named=True)
|
||||
def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem], Dict[Variable, int]]:
|
||||
"""create a graph for realizing the outputs"""
|
||||
def get_realizes(outs:List[LazyBuffer]) -> Tuple[Dict[LazyBuffer, None], Dict[LazyBuffer, LazyBuffer]]:
|
||||
"""search the graph for all the LazyBuffers that need to realize"""
|
||||
# start by just realizing the buffers passed in
|
||||
realizes: Dict[LazyBuffer, None] = {x.base:None for x in outs if x.base.realized is None}
|
||||
allbufs: Dict[LazyBuffer, None] = {}
|
||||
@@ -356,11 +355,16 @@ def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem]
|
||||
if len(kernel_children) == 0: continue
|
||||
if DEBUG_ARANGE: print(colored(f"folding {r}", "green"))
|
||||
for tr in group: del realizes[tr]
|
||||
return realizes, reduce_for_op
|
||||
|
||||
@track_rewrites(named=True)
|
||||
def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem], Dict[Variable, int]]:
|
||||
realizes, reduce_for_op = get_realizes(outs)
|
||||
output_groups: DefaultDict[LazyBuffer, List[UOp]] = defaultdict(list)
|
||||
buf_uops: Dict[Buffer, UOp] = {}
|
||||
uop_bufs: Dict[UOp, Buffer] = {}
|
||||
var_vals: Dict[Variable, int] = {}
|
||||
assign_targets: Dict[LazyBuffer, LazyBuffer] = {}
|
||||
lazybufs_to_realize: Dict[Buffer, LazyBuffer] = {}
|
||||
for buf in realizes:
|
||||
if buf.realized is None and buf.op is not MetaOps.CONST:
|
||||
@@ -386,7 +390,9 @@ def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem]
|
||||
if buf.buffer not in buf_uops:
|
||||
buf_uops[buf.buffer] = uop
|
||||
uop_bufs[uop] = buf.buffer
|
||||
if buf.realized is None and buf.op is not MetaOps.CONST: output_groups[reduce_for_op.get(buf, buf)].append(buf_uops[buf.buffer])
|
||||
if buf.realized is None:
|
||||
if buf.op is MetaOps.ASSIGN: assign_targets[buf.srcs[0]] = buf
|
||||
if buf.op is not MetaOps.CONST:output_groups[reduce_for_op.get(buf, buf)].append(buf_uops[buf.buffer])
|
||||
|
||||
# preschedule all buffers in realizes
|
||||
prescheduled = [_lower_lazybuffer([lazybufs_to_realize[uop_bufs[b]] for b in outs], buf_uops, var_vals) for outs in output_groups.values()]
|
||||
|
||||
Reference in New Issue
Block a user