mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-02-16 17:45:38 -05:00)
refactor schedule edges to tuple[LazyBuffer, ...] [run_process_replay] (#6797)
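The change threads a single idea through the scheduler: the set of buffers to realize, previously passed around as Dict[LazyBuffer, None] (a dict used as an insertion-ordered set), is now passed as tuple[LazyBuffer, ...]. Below is a minimal sketch of the equivalence this relies on; the Node class is a hypothetical stand-in for LazyBuffer and is not part of the diff.

# minimal sketch: a dict-as-ordered-set and a tuple built from it agree on the
# two operations the scheduler uses here: membership tests and iteration order
class Node:
    def __init__(self, name: str): self.name = name
    def __repr__(self): return f"Node({self.name})"

a, b, c = Node("a"), Node("b"), Node("c")

realizes: dict[Node, None] = {a: None, b: None, c: None}  # old shape of the edge set
bufs: tuple[Node, ...] = tuple(realizes)                  # new shape; dict keys keep insertion order

assert a in realizes and a in bufs
assert Node("d") not in realizes and Node("d") not in bufs
assert list(realizes) == list(bufs)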
@@ -128,7 +128,7 @@ def full_ast_rewrite(sink:UOp, ctx:ScheduleItemContext) -> UOp:
 # *** List[LazyBuffer] lowering to ScheduleItem ***

 def _recursive_uop(buf:LazyBuffer, st:ShapeTracker, outputs:Tuple[LazyBuffer, ...], var_vals:Dict[Variable, int], inputs:List[LazyBuffer],
-                   realizes:Dict[LazyBuffer, None], assign_targets:Dict[LazyBuffer, LazyBuffer],
+                   bufs:Tuple[LazyBuffer, ...], assign_targets:Dict[LazyBuffer, LazyBuffer],
                    cache:Dict[Tuple[LazyBuffer, ShapeTracker], UOp]) -> UOp:
   """recursively create a UOp"""
   if buf is not buf.base: st, buf = buf.st+st, buf.base
@@ -137,7 +137,7 @@ def _recursive_uop(buf:LazyBuffer, st:ShapeTracker, outputs:Tuple[LazyBuffer, ..
   dtype = buf.dtype.base if isinstance(buf.dtype, ImageDType) else buf.dtype

   # buffer ops define ShapeTracker
-  if buf in realizes and buf not in outputs:
+  if buf in bufs and buf not in outputs:
     unbound_st, st_var_vals = st.simplify().unbind()
     var_vals.update(st_var_vals)
     # if it's a const, we generate it
@@ -158,11 +158,11 @@ def _recursive_uop(buf:LazyBuffer, st:ShapeTracker, outputs:Tuple[LazyBuffer, ..

   # reduce ops change ShapeTracker
   if buf.op in ReduceOps:
-    rsrc = _recursive_uop(buf.srcs[0], ShapeTracker.from_shape(buf.srcs[0].shape), outputs, var_vals, inputs, realizes, assign_targets, cache)
+    rsrc = _recursive_uop(buf.srcs[0], ShapeTracker.from_shape(buf.srcs[0].shape), outputs, var_vals, inputs, bufs, assign_targets, cache)
     return cache.setdefault((buf, st), UOp(UOps.REDUCE_AXIS, dtype, (rsrc,), (REDUCE_ALU[cast(ReduceOps, buf.op)], buf.arg)).swizzle(st))

   # elementwise ops pass shapetracker
-  in_uops = tuple(_recursive_uop(x, st, outputs, var_vals, inputs, realizes, assign_targets, cache) for x in buf.srcs)
+  in_uops = tuple(_recursive_uop(x, st, outputs, var_vals, inputs, bufs, assign_targets, cache) for x in buf.srcs)
   if buf.op is MetaOps.CONTIGUOUS:
     assert buf in outputs, f"{buf.op} must be writable"
     return in_uops[0]
@@ -171,7 +171,7 @@ def _recursive_uop(buf:LazyBuffer, st:ShapeTracker, outputs:Tuple[LazyBuffer, ..
   if buf.op is UnaryOps.BITCAST: return cache.setdefault((buf, st), UOp(UOps.BITCAST, dtype, in_uops))
   return cache.setdefault((buf, st), UOp(UOps.ALU, dtype, in_uops, buf.op))

-def _lower_lazybuffer(outs:List[LazyBuffer], realizes:Dict[LazyBuffer, None]) -> Tuple[LBScheduleItem, Dict[Variable, int]]:
+def _lower_lazybuffer(outs:List[LazyBuffer], bufs:Tuple[LazyBuffer, ...]) -> Tuple[LBScheduleItem, Dict[Variable, int]]:
   """describe the computation for a LazyBuffer with UOp + inputs + var_vals"""
   if (out:=outs[0]).op in {MetaOps.CUSTOM, MetaOps.COPY, MetaOps.EMPTY, MetaOps.VIEW}:
     return LBScheduleItem(UOp(UOps.EXT, out.dtype, (), (out.op, out.arg)), (out,)+tuple(x.base for x in out.srcs)), {}
@@ -182,7 +182,7 @@ def _lower_lazybuffer(outs:List[LazyBuffer], realizes:Dict[LazyBuffer, None]) ->
   ast: List[UOp] = []
   inputs: List[LazyBuffer] = []
   for out in outs:
-    src = _recursive_uop(out, output_st:=ShapeTracker.from_shape(out.shape), tuple(outs), var_vals, inputs, realizes, assign_targets, cache=cache)
+    src = _recursive_uop(out, output_st:=ShapeTracker.from_shape(out.shape), tuple(outs), var_vals, inputs, bufs, assign_targets, cache=cache)
     if out.op is MetaOps.ASSIGN and out.arg:
       assert out.arg[0].shape == out.shape, f"ASSIGN must not override output shape {out.arg[0].shape} != {out.shape}"
       output_st = out.arg[0]
@@ -273,7 +273,7 @@ def _get_isolated_children(r:LazyBuffer, reduce_for_op:Dict[LazyBuffer, LazyBuff

 def _get_output_groups(outs:List[LazyBuffer]) -> \
     Tuple[DefaultDict[LazyBuffer, List[LazyBuffer]], # these are the output groups
-          Dict[LazyBuffer, None], # these are all the realizes in the graph
+          Tuple[LazyBuffer, ...], # these are all the realizes in the graph
           Dict[LazyBuffer, LazyBuffer]]: # these are the buffers we ASSIGN to in this schedule
   """find all the realizes in the graph, group the output LazyBuffers into kernels."""
   # start by just realizing the buffers passed in
@@ -363,7 +363,7 @@ def _get_output_groups(outs:List[LazyBuffer]) -> \
       assert not hasattr(buf.buffer, '_buf'), "can't fixup allocated buffer"
       buf.buffer.dtype = dtypes.float32
       buf.buffer.options = None
-  return output_groups, realizes, assign_targets
+  return output_groups, tuple(realizes), assign_targets

 SCHEDULES: List[Tuple[DefaultDict[LBScheduleItem, List[LBScheduleItem]], DefaultDict[LBScheduleItem, int]]] = []
 def _graph_schedule(outs:List[LazyBuffer]) -> \
@@ -371,12 +371,12 @@ def _graph_schedule(outs:List[LazyBuffer]) -> \
           DefaultDict[LBScheduleItem, int], # this is the in-degree of the graph
           Dict[Variable, int]]: # this has all the var values of the schedule
   """create a graph for realizing the outputs"""
-  output_groups, realizes, assign_targets = _get_output_groups(outs)
+  output_groups, bufs, assign_targets = _get_output_groups(outs)
   # preschedule all buffers in realizes
   prescheduled: List[LBScheduleItem] = []
   var_vals: Dict[Variable, int] = {}
   for group in output_groups.values():
-    prescheduled.append((ret:=_lower_lazybuffer(group, realizes))[0])
+    prescheduled.append((ret:=_lower_lazybuffer(group, bufs))[0])
     var_vals = merge_dicts([var_vals, ret[1]])
   schedule_targets = {out:lsi for lsi in prescheduled for out in lsi.outputs}

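A side note not claimed by the commit itself: membership tests such as buf in bufs now scan a tuple linearly, whereas buf in realizes was a hash lookup. The sketch below illustrates the asymptotic difference with plain integers standing in for LazyBuffers; whether it matters in practice depends on how large the realize set gets in real graphs.

import timeit

n = 10_000
as_dict: dict[int, None] = {i: None for i in range(n)}
as_tuple: tuple[int, ...] = tuple(as_dict)

# hash lookup: roughly constant time regardless of n
print(timeit.timeit(lambda: (n - 1) in as_dict, number=10_000))
# linear scan: cost grows with the position of the element in the tuple
print(timeit.timeit(lambda: (n - 1) in as_tuple, number=10_000))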