refactor schedule edges to tuple[LazyBuffer, ...] [run_process_replay] (#6797)

commit 5e1221845f
parent 68e59eb3f5
Author: qazal
Committer: GitHub
Date: 2024-09-29 11:34:39 +08:00

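In short: the scheduler's edge set (the buffers that will be realized) is now threaded through lowering as an immutable tuple[LazyBuffer, ...] named bufs instead of a Dict[LazyBuffer, None] used as an ordered set named realizes. The [run_process_replay] tag in the title is meant to have CI replay scheduling and confirm the refactor leaves the generated schedules unchanged. Below is a minimal, self-contained sketch of the shape of the change; Buf, get_output_groups and lower are illustrative stand-ins, not tinygrad's actual definitions:

from typing import Dict, List, Tuple

class Buf:  # stand-in for LazyBuffer
  def __init__(self, name:str): self.name = name
  def __repr__(self): return self.name

def get_output_groups(outs:List[Buf]) -> Tuple[Tuple[Buf, ...], Dict[Buf, Buf]]:
  realizes: Dict[Buf, None] = {b: None for b in outs}  # still built as an ordered set internally
  assign_targets: Dict[Buf, Buf] = {}
  return tuple(realizes), assign_targets               # frozen into a tuple at the boundary

def lower(out:Buf, bufs:Tuple[Buf, ...]) -> str:
  # downstream lowering only needs ordered membership tests, so a tuple suffices
  return f"realize {out}" if out in bufs else f"fuse {out}"

outs = [Buf("a"), Buf("b")]
bufs, _ = get_output_groups(outs)
print([lower(b, bufs) for b in outs + [Buf("c")]])  # ['realize a', 'realize b', 'fuse c']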

@@ -128,7 +128,7 @@ def full_ast_rewrite(sink:UOp, ctx:ScheduleItemContext) -> UOp:
 # *** List[LazyBuffer] lowering to ScheduleItem ***

 def _recursive_uop(buf:LazyBuffer, st:ShapeTracker, outputs:Tuple[LazyBuffer, ...], var_vals:Dict[Variable, int], inputs:List[LazyBuffer],
-                   realizes:Dict[LazyBuffer, None], assign_targets:Dict[LazyBuffer, LazyBuffer],
+                   bufs:Tuple[LazyBuffer, ...], assign_targets:Dict[LazyBuffer, LazyBuffer],
                    cache:Dict[Tuple[LazyBuffer, ShapeTracker], UOp]) -> UOp:
   """recursively create a UOp"""
   if buf is not buf.base: st, buf = buf.st+st, buf.base
@@ -137,7 +137,7 @@ def _recursive_uop(buf:LazyBuffer, st:ShapeTracker, outputs:Tuple[LazyBuffer, ..
   dtype = buf.dtype.base if isinstance(buf.dtype, ImageDType) else buf.dtype

   # buffer ops define ShapeTracker
-  if buf in realizes and buf not in outputs:
+  if buf in bufs and buf not in outputs:
     unbound_st, st_var_vals = st.simplify().unbind()
     var_vals.update(st_var_vals)
     # if it's a const, we generate it
@@ -158,11 +158,11 @@ def _recursive_uop(buf:LazyBuffer, st:ShapeTracker, outputs:Tuple[LazyBuffer, ..
   # reduce ops change ShapeTracker
   if buf.op in ReduceOps:
-    rsrc = _recursive_uop(buf.srcs[0], ShapeTracker.from_shape(buf.srcs[0].shape), outputs, var_vals, inputs, realizes, assign_targets, cache)
+    rsrc = _recursive_uop(buf.srcs[0], ShapeTracker.from_shape(buf.srcs[0].shape), outputs, var_vals, inputs, bufs, assign_targets, cache)
     return cache.setdefault((buf, st), UOp(UOps.REDUCE_AXIS, dtype, (rsrc,), (REDUCE_ALU[cast(ReduceOps, buf.op)], buf.arg)).swizzle(st))
   # elementwise ops pass shapetracker
-  in_uops = tuple(_recursive_uop(x, st, outputs, var_vals, inputs, realizes, assign_targets, cache) for x in buf.srcs)
+  in_uops = tuple(_recursive_uop(x, st, outputs, var_vals, inputs, bufs, assign_targets, cache) for x in buf.srcs)
   if buf.op is MetaOps.CONTIGUOUS:
     assert buf in outputs, f"{buf.op} must be writable"
     return in_uops[0]
@@ -171,7 +171,7 @@ def _recursive_uop(buf:LazyBuffer, st:ShapeTracker, outputs:Tuple[LazyBuffer, ..
   if buf.op is UnaryOps.BITCAST: return cache.setdefault((buf, st), UOp(UOps.BITCAST, dtype, in_uops))
   return cache.setdefault((buf, st), UOp(UOps.ALU, dtype, in_uops, buf.op))

-def _lower_lazybuffer(outs:List[LazyBuffer], realizes:Dict[LazyBuffer, None]) -> Tuple[LBScheduleItem, Dict[Variable, int]]:
+def _lower_lazybuffer(outs:List[LazyBuffer], bufs:Tuple[LazyBuffer, ...]) -> Tuple[LBScheduleItem, Dict[Variable, int]]:
   """describe the computation for a LazyBuffer with UOp + inputs + var_vals"""
   if (out:=outs[0]).op in {MetaOps.CUSTOM, MetaOps.COPY, MetaOps.EMPTY, MetaOps.VIEW}:
     return LBScheduleItem(UOp(UOps.EXT, out.dtype, (), (out.op, out.arg)), (out,)+tuple(x.base for x in out.srcs)), {}
@@ -182,7 +182,7 @@ def _lower_lazybuffer(outs:List[LazyBuffer], realizes:Dict[LazyBuffer, None]) ->
   ast: List[UOp] = []
   inputs: List[LazyBuffer] = []
   for out in outs:
-    src = _recursive_uop(out, output_st:=ShapeTracker.from_shape(out.shape), tuple(outs), var_vals, inputs, realizes, assign_targets, cache=cache)
+    src = _recursive_uop(out, output_st:=ShapeTracker.from_shape(out.shape), tuple(outs), var_vals, inputs, bufs, assign_targets, cache=cache)
     if out.op is MetaOps.ASSIGN and out.arg:
       assert out.arg[0].shape == out.shape, f"ASSIGN must not override output shape {out.arg[0].shape} != {out.shape}"
       output_st = out.arg[0]
@@ -273,7 +273,7 @@ def _get_isolated_children(r:LazyBuffer, reduce_for_op:Dict[LazyBuffer, LazyBuff
 def _get_output_groups(outs:List[LazyBuffer]) -> \
     Tuple[DefaultDict[LazyBuffer, List[LazyBuffer]], # these are the output groups
-          Dict[LazyBuffer, None], # these are all the realizes in the graph
+          Tuple[LazyBuffer, ...], # these are all the realizes in the graph
           Dict[LazyBuffer, LazyBuffer]]: # these are the buffers we ASSIGN to in this schedule
   """find all the realizes in the graph, group the output LazyBuffers into kernels."""
   # start by just realizing the buffers passed in
@@ -363,7 +363,7 @@ def _get_output_groups(outs:List[LazyBuffer]) -> \
       assert not hasattr(buf.buffer, '_buf'), "can't fixup allocated buffer"
       buf.buffer.dtype = dtypes.float32
       buf.buffer.options = None
-  return output_groups, realizes, assign_targets
+  return output_groups, tuple(realizes), assign_targets

 SCHEDULES: List[Tuple[DefaultDict[LBScheduleItem, List[LBScheduleItem]], DefaultDict[LBScheduleItem, int]]] = []
 def _graph_schedule(outs:List[LazyBuffer]) -> \
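Note that realizes is still accumulated as a Dict[LazyBuffer, None] inside _get_output_groups; it is only frozen with tuple(realizes) on the way out, which keeps insertion order and membership behaviour. A tiny illustration with plain strings standing in for LazyBuffers:

realizes = {"a": None, "b": None}  # dict used as an ordered set while grouping
bufs = tuple(realizes)             # ("a", "b"): key order is preserved
assert all((k in realizes) == (k in bufs) for k in ("a", "b", "c"))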
@@ -371,12 +371,12 @@ def _graph_schedule(outs:List[LazyBuffer]) -> \
           DefaultDict[LBScheduleItem, int], # this is the in-degree of the graph
           Dict[Variable, int]]: # this has all the var values of the schedule
   """create a graph for realizing the outputs"""
-  output_groups, realizes, assign_targets = _get_output_groups(outs)
+  output_groups, bufs, assign_targets = _get_output_groups(outs)
   # preschedule all buffers in realizes
   prescheduled: List[LBScheduleItem] = []
   var_vals: Dict[Variable, int] = {}
   for group in output_groups.values():
-    prescheduled.append((ret:=_lower_lazybuffer(group, realizes))[0])
+    prescheduled.append((ret:=_lower_lazybuffer(group, bufs))[0])
     var_vals = merge_dicts([var_vals, ret[1]])
   schedule_targets = {out:lsi for lsi in prescheduled for out in lsi.outputs}
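The prescheduling loop above leans on an assignment expression to keep the LBScheduleItem while also folding each group's var_vals into the running dictionary. A standalone illustration of that pattern with generic stand-in names (not tinygrad code, and using plain dict union instead of merge_dicts):

def lower_group(group:list) -> tuple:
  return f"item({group})", {"n": len(group)}  # stand-in for (LBScheduleItem, var_vals)

items, var_vals = [], {}
for group in ([1], [2, 3]):
  items.append((ret:=lower_group(group))[0])  # keep the schedule item...
  var_vals |= ret[1]                          # ...and merge its var_vals into the running dict
print(items, var_vals)  # ['item([1])', 'item([2, 3])'] {'n': 2}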