more scheduler graph_rewrite cleanups [run_process_replay] (#6479)

This commit is contained in:
qazal
2024-09-11 18:26:35 +08:00
committed by GitHub
parent bdd0c06f29
commit bce73c9a54

View File

@@ -6,7 +6,7 @@ from tinygrad.ops import REDUCE_ALU, MetaOps, ReduceOps, UNSAFE_PAD_OPS, UnaryOp
from tinygrad.ops import PatternMatcher, UPat, graph_rewrite
from tinygrad.engine.graph import log_lazybuffer, realized_lazybuffer
from tinygrad.helpers import GRAPH, DEBUG, MULTIOUTPUT, SAVE_SCHEDULE, FUSE_CONV_BW, FUSE_ARANGE, AST_REWRITE, \
GlobalCounters, all_same, colored, flatten, prod, dedup, all_int, merge_dicts, getenv, Metadata, unwrap
GlobalCounters, all_same, colored, prod, dedup, all_int, merge_dicts, getenv, Metadata, unwrap
from tinygrad.shape.symbolic import Variable, sint
from tinygrad.dtype import ConstType, ImageDType, PtrDType, dtypes
from tinygrad.lazy import LazyBuffer
@@ -90,6 +90,9 @@ def _recursive_uop(buf:LazyBuffer, st:ShapeTracker, outputs:Tuple[LazyBuffer, ..
if buf.op is UnaryOps.BITCAST: return cache.setdefault((buf, st), UOp(UOps.BITCAST, dtype, in_uops))
return cache.setdefault((buf, st), UOp(UOps.ALU, dtype, in_uops, buf.op))
# ** AST graph rewrite: UOp with SWIZZLE (movementops) -> UOp we can index **
# ***** helpers for doing movementops on uops *****
def st_fixup(u:UOp, apply_to_st:Callable[[ShapeTracker], ShapeTracker], cache:Dict[UOp, UOp]) -> UOp:
@@ -131,12 +134,6 @@ def push_swizzle_down_through_reduce(root:UOp, swizzle:UOp) -> UOp:
assert prod(swizzle.arg.shape) == prod(unwrap(swizzle.src[0].st).shape), "can't push expands down to STORE"
return UOp(UOps.REDUCE_AXIS, root.dtype, swizzle.src, root.arg).swizzle(ShapeTracker.from_shape(unwrap(swizzle.st).reduce(root.arg[1])))
def merge_double_reduce(root:UOp, first_reduce:UOp) -> UOp:
assert root.arg[0] == first_reduce.arg[0], "can't merge reduceops with different alu"
assert not any(x.op is UOps.REDUCE_AXIS for x in first_reduce.parents), "can't merge more than two reduceops at a time"
new_axis: Tuple[int, ...] = root.arg[1]+first_reduce.arg[1]
return UOp(UOps.REDUCE_AXIS, first_reduce.dtype, first_reduce.src, (first_reduce.arg[0], new_axis))
def push_swizzle_down_through_elementwise(root:UOp) -> Optional[UOp]:
swizzles = [x for x in root.src if x.op is UOps.SWIZZLE]
if len(swizzles) == 0: return None
@@ -147,6 +144,12 @@ def push_swizzle_down_through_elementwise(root:UOp) -> Optional[UOp]:
ret = UOp(root.op, root.dtype, tuple(new_srcs), root.arg)
return ret if ret.op is UOps.STORE else ret.swizzle(ShapeTracker.from_shape(sw_shape))
def merge_double_reduce(root:UOp, first_reduce:UOp) -> UOp:
assert root.arg[0] == first_reduce.arg[0], "can't merge reduceops with different alu"
assert not any(x.op is UOps.REDUCE_AXIS for x in first_reduce.parents), "can't merge more than two reduceops at a time"
new_axis: Tuple[int, ...] = root.arg[1]+first_reduce.arg[1]
return UOp(UOps.REDUCE_AXIS, first_reduce.dtype, first_reduce.src, (first_reduce.arg[0], new_axis))
reduceop_fusor = PatternMatcher([
# push a SWIZZLE up to LOAD, through a reduce (eg. expands)
(UPat(UOps.SWIZZLE, src=(UPat(UOps.REDUCE_AXIS, name="reduceop"),), name="swizzle"), push_swizzle_up_through_reduce),
@@ -157,15 +160,15 @@ reduceop_fusor = PatternMatcher([
(UPat(UOps.REDUCE_AXIS, src=(UPat(UOps.REDUCE_AXIS, name="first_reduce"),), name="root"), merge_double_reduce),
])
def _lower_lazybuffer(outs:List[LazyBuffer], realizes:Dict[LazyBuffer, None]) -> List[LBScheduleItem]:
def _lower_lazybuffer(outs:List[LazyBuffer], realizes:Dict[LazyBuffer, None]) -> LBScheduleItem:
"""describe the computation for a LazyBuffer with UOp + inputs + var_vals"""
if (out:=outs[0]).op is MetaOps.COPY and getenv("USE_COPY_KERNEL") and out.device.split(":")[0] == out.srcs[0].device.split(":")[0]:
st_uop = ShapeTracker.from_shape(out.arg).to_uop()
rd = UOp(UOps.LOAD, dtypes.uint8, (UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.uint8), (), 1), st_uop))
wr = UOp(UOps.STORE, dtypes.void, (UOp(UOps.DEFINE_GLOBAL, PtrDType(out.dtype), (), 0), st_uop, rd))
return [LBScheduleItem(UOp(UOps.SINK, dtypes.void, (wr,)), outs, [x.base for x in out.srcs])]
return LBScheduleItem(UOp(UOps.SINK, dtypes.void, (wr,)), outs, [x.base for x in out.srcs])
if out.op in {MetaOps.CUSTOM, MetaOps.COPY, MetaOps.EMPTY, MetaOps.VIEW}:
return [LBScheduleItem(UOp(UOps.EXT, out.dtype, (), (out.op, out.arg)), outs, [x.base for x in out.srcs])]
return LBScheduleItem(UOp(UOps.EXT, out.dtype, (), (out.op, out.arg)), outs, [x.base for x in out.srcs])
# create the stores
var_vals = merge_dicts([out.st.var_vals.copy() for out in outs])
assign_targets = {x.srcs[1]:x for x in outs if x.op is MetaOps.ASSIGN}
@@ -173,18 +176,17 @@ def _lower_lazybuffer(outs:List[LazyBuffer], realizes:Dict[LazyBuffer, None]) ->
ast: List[UOp] = []
inputs: Dict[LazyBuffer, int] = {}
for i, out in enumerate(outs):
output_st = ShapeTracker.from_shape(out.shape)
src = _recursive_uop(out, output_st, tuple(outs), var_vals, inputs, realizes, assign_targets, cache=cache)
src = _recursive_uop(out, output_st:=ShapeTracker.from_shape(out.shape), tuple(outs), var_vals, inputs, realizes, assign_targets, cache=cache)
if out.op is MetaOps.ASSIGN and out.arg:
assert out.arg[0].shape == out.shape, f"ASSIGN must not override output shape {out.arg[0].shape} != {out.shape}"
output_st = out.arg[0].reshape(out.shape)
output_st = out.arg[0]
output_st, vv = output_st.simplify().unbind()
if vv: var_vals.update(vv)
var_vals.update(vv)
ubuf = UOp(UOps.DEFINE_GLOBAL, out.dtype if isinstance(out.dtype, ImageDType) else PtrDType(out.dtype), (), i)
ast.append(UOp(UOps.STORE, dtypes.void, (ubuf, output_st.to_uop(), src)))
sink = UOp(UOps.SINK, dtypes.void, tuple(ast))
if AST_REWRITE: sink = graph_rewrite(sink, reduceop_fusor)
return [LBScheduleItem(sink, outs, list(inputs), var_vals, dedup([x[0].metadata for x in cache if x[0].metadata and x[0] not in inputs]))]
return LBScheduleItem(sink, outs, list(inputs), var_vals, dedup([x[0].metadata for x in cache if x[0].metadata and x[0] not in inputs]))
# *** DAG creation: decide which LazyBuffers should realize ***
@@ -363,7 +365,7 @@ def _graph_schedule(outs:List[LazyBuffer]) -> \
"""create a graph for realizing the outputs"""
output_groups, realizes, assign_targets = _get_output_groups(outs)
# preschedule all buffers in realizes
prescheduled = flatten([_lower_lazybuffer(group, realizes) for group in output_groups.values()])
prescheduled = [_lower_lazybuffer(group, realizes) for group in output_groups.values()]
schedule_targets = {out:lsi for lsi in prescheduled for out in lsi.outputs}
graph: DefaultDict[LBScheduleItem, List[LBScheduleItem]] = defaultdict(list)