From bce73c9a54ec4e9fd45b7cd24ee65ad9c7692d7e Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Wed, 11 Sep 2024 18:26:35 +0800 Subject: [PATCH] more scheduler graph_rewrite cleanups [run_process_replay] (#6479) --- tinygrad/engine/schedule.py | 34 ++++++++++++++++++---------------- 1 file changed, 18 insertions(+), 16 deletions(-) diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index b0325d483b..b778bf47a4 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -6,7 +6,7 @@ from tinygrad.ops import REDUCE_ALU, MetaOps, ReduceOps, UNSAFE_PAD_OPS, UnaryOp from tinygrad.ops import PatternMatcher, UPat, graph_rewrite from tinygrad.engine.graph import log_lazybuffer, realized_lazybuffer from tinygrad.helpers import GRAPH, DEBUG, MULTIOUTPUT, SAVE_SCHEDULE, FUSE_CONV_BW, FUSE_ARANGE, AST_REWRITE, \ - GlobalCounters, all_same, colored, flatten, prod, dedup, all_int, merge_dicts, getenv, Metadata, unwrap + GlobalCounters, all_same, colored, prod, dedup, all_int, merge_dicts, getenv, Metadata, unwrap from tinygrad.shape.symbolic import Variable, sint from tinygrad.dtype import ConstType, ImageDType, PtrDType, dtypes from tinygrad.lazy import LazyBuffer @@ -90,6 +90,9 @@ def _recursive_uop(buf:LazyBuffer, st:ShapeTracker, outputs:Tuple[LazyBuffer, .. if buf.op is UnaryOps.BITCAST: return cache.setdefault((buf, st), UOp(UOps.BITCAST, dtype, in_uops)) return cache.setdefault((buf, st), UOp(UOps.ALU, dtype, in_uops, buf.op)) + +# ** AST graph rewrite: UOp with SWIZZLE (movementops) -> UOp we can index ** + # ***** helpers for doing movementops on uops ***** def st_fixup(u:UOp, apply_to_st:Callable[[ShapeTracker], ShapeTracker], cache:Dict[UOp, UOp]) -> UOp: @@ -131,12 +134,6 @@ def push_swizzle_down_through_reduce(root:UOp, swizzle:UOp) -> UOp: assert prod(swizzle.arg.shape) == prod(unwrap(swizzle.src[0].st).shape), "can't push expands down to STORE" return UOp(UOps.REDUCE_AXIS, root.dtype, swizzle.src, root.arg).swizzle(ShapeTracker.from_shape(unwrap(swizzle.st).reduce(root.arg[1]))) -def merge_double_reduce(root:UOp, first_reduce:UOp) -> UOp: - assert root.arg[0] == first_reduce.arg[0], "can't merge reduceops with different alu" - assert not any(x.op is UOps.REDUCE_AXIS for x in first_reduce.parents), "can't merge more than two reduceops at a time" - new_axis: Tuple[int, ...] = root.arg[1]+first_reduce.arg[1] - return UOp(UOps.REDUCE_AXIS, first_reduce.dtype, first_reduce.src, (first_reduce.arg[0], new_axis)) - def push_swizzle_down_through_elementwise(root:UOp) -> Optional[UOp]: swizzles = [x for x in root.src if x.op is UOps.SWIZZLE] if len(swizzles) == 0: return None @@ -147,6 +144,12 @@ def push_swizzle_down_through_elementwise(root:UOp) -> Optional[UOp]: ret = UOp(root.op, root.dtype, tuple(new_srcs), root.arg) return ret if ret.op is UOps.STORE else ret.swizzle(ShapeTracker.from_shape(sw_shape)) +def merge_double_reduce(root:UOp, first_reduce:UOp) -> UOp: + assert root.arg[0] == first_reduce.arg[0], "can't merge reduceops with different alu" + assert not any(x.op is UOps.REDUCE_AXIS for x in first_reduce.parents), "can't merge more than two reduceops at a time" + new_axis: Tuple[int, ...] = root.arg[1]+first_reduce.arg[1] + return UOp(UOps.REDUCE_AXIS, first_reduce.dtype, first_reduce.src, (first_reduce.arg[0], new_axis)) + reduceop_fusor = PatternMatcher([ # push a SWIZZLE up to LOAD, through a reduce (eg. expands) (UPat(UOps.SWIZZLE, src=(UPat(UOps.REDUCE_AXIS, name="reduceop"),), name="swizzle"), push_swizzle_up_through_reduce), @@ -157,15 +160,15 @@ reduceop_fusor = PatternMatcher([ (UPat(UOps.REDUCE_AXIS, src=(UPat(UOps.REDUCE_AXIS, name="first_reduce"),), name="root"), merge_double_reduce), ]) -def _lower_lazybuffer(outs:List[LazyBuffer], realizes:Dict[LazyBuffer, None]) -> List[LBScheduleItem]: +def _lower_lazybuffer(outs:List[LazyBuffer], realizes:Dict[LazyBuffer, None]) -> LBScheduleItem: """describe the computation for a LazyBuffer with UOp + inputs + var_vals""" if (out:=outs[0]).op is MetaOps.COPY and getenv("USE_COPY_KERNEL") and out.device.split(":")[0] == out.srcs[0].device.split(":")[0]: st_uop = ShapeTracker.from_shape(out.arg).to_uop() rd = UOp(UOps.LOAD, dtypes.uint8, (UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.uint8), (), 1), st_uop)) wr = UOp(UOps.STORE, dtypes.void, (UOp(UOps.DEFINE_GLOBAL, PtrDType(out.dtype), (), 0), st_uop, rd)) - return [LBScheduleItem(UOp(UOps.SINK, dtypes.void, (wr,)), outs, [x.base for x in out.srcs])] + return LBScheduleItem(UOp(UOps.SINK, dtypes.void, (wr,)), outs, [x.base for x in out.srcs]) if out.op in {MetaOps.CUSTOM, MetaOps.COPY, MetaOps.EMPTY, MetaOps.VIEW}: - return [LBScheduleItem(UOp(UOps.EXT, out.dtype, (), (out.op, out.arg)), outs, [x.base for x in out.srcs])] + return LBScheduleItem(UOp(UOps.EXT, out.dtype, (), (out.op, out.arg)), outs, [x.base for x in out.srcs]) # create the stores var_vals = merge_dicts([out.st.var_vals.copy() for out in outs]) assign_targets = {x.srcs[1]:x for x in outs if x.op is MetaOps.ASSIGN} @@ -173,18 +176,17 @@ def _lower_lazybuffer(outs:List[LazyBuffer], realizes:Dict[LazyBuffer, None]) -> ast: List[UOp] = [] inputs: Dict[LazyBuffer, int] = {} for i, out in enumerate(outs): - output_st = ShapeTracker.from_shape(out.shape) - src = _recursive_uop(out, output_st, tuple(outs), var_vals, inputs, realizes, assign_targets, cache=cache) + src = _recursive_uop(out, output_st:=ShapeTracker.from_shape(out.shape), tuple(outs), var_vals, inputs, realizes, assign_targets, cache=cache) if out.op is MetaOps.ASSIGN and out.arg: assert out.arg[0].shape == out.shape, f"ASSIGN must not override output shape {out.arg[0].shape} != {out.shape}" - output_st = out.arg[0].reshape(out.shape) + output_st = out.arg[0] output_st, vv = output_st.simplify().unbind() - if vv: var_vals.update(vv) + var_vals.update(vv) ubuf = UOp(UOps.DEFINE_GLOBAL, out.dtype if isinstance(out.dtype, ImageDType) else PtrDType(out.dtype), (), i) ast.append(UOp(UOps.STORE, dtypes.void, (ubuf, output_st.to_uop(), src))) sink = UOp(UOps.SINK, dtypes.void, tuple(ast)) if AST_REWRITE: sink = graph_rewrite(sink, reduceop_fusor) - return [LBScheduleItem(sink, outs, list(inputs), var_vals, dedup([x[0].metadata for x in cache if x[0].metadata and x[0] not in inputs]))] + return LBScheduleItem(sink, outs, list(inputs), var_vals, dedup([x[0].metadata for x in cache if x[0].metadata and x[0] not in inputs])) # *** DAG creation: decide which LazyBuffers should realize *** @@ -363,7 +365,7 @@ def _graph_schedule(outs:List[LazyBuffer]) -> \ """create a graph for realizing the outputs""" output_groups, realizes, assign_targets = _get_output_groups(outs) # preschedule all buffers in realizes - prescheduled = flatten([_lower_lazybuffer(group, realizes) for group in output_groups.values()]) + prescheduled = [_lower_lazybuffer(group, realizes) for group in output_groups.values()] schedule_targets = {out:lsi for lsi in prescheduled for out in lsi.outputs} graph: DefaultDict[LBScheduleItem, List[LBScheduleItem]] = defaultdict(list)