From 29e63097a04c31e8a425e99571a38d156eb85cfd Mon Sep 17 00:00:00 2001
From: qazal <77887910+Qazalin@users.noreply.github.com>
Date: Tue, 10 Sep 2024 08:30:35 +0800
Subject: [PATCH] st is a cached_property on UOp [run_process_replay] (#6433)

---
 test/test_schedule.py       |  6 +++---
 tinygrad/engine/schedule.py | 30 ++++++++++--------------------
 tinygrad/ops.py             |  8 ++++++++
 3 files changed, 21 insertions(+), 23 deletions(-)

diff --git a/test/test_schedule.py b/test/test_schedule.py
index abab8899cf..f88f23e234 100644
--- a/test/test_schedule.py
+++ b/test/test_schedule.py
@@ -15,7 +15,7 @@ from tinygrad.ops import BinaryOps, MetaOps, UOp, UnaryOps, UOps
 from tinygrad.ops import graph_rewrite
 from tinygrad.helpers import AST_REWRITE, CI, DEBUG, FUSE_ARANGE, FUSE_CONV_BW, GlobalCounters, flatten, getenv, SPLIT_REDUCEOP
 from tinygrad.codegen.kernel import Kernel, verify_ast
-from tinygrad.engine.schedule import create_schedule, get_output_st, reduceop_fusor, st_fixup, ScheduleItem
+from tinygrad.engine.schedule import create_schedule, reduceop_fusor, st_fixup, ScheduleItem
 from tinygrad.engine.realize import CompiledRunner, run_schedule
 from test.helpers import assert_equiv_uops, ast_const, is_dtype_supported, Context, timeit
 from tinygrad.lazy import LazyBuffer, view_supported_devices
@@ -1650,8 +1650,8 @@ class TestScheduleRewrite(unittest.TestCase):
     a = Tensor([1,2,3,4]).realize()
     for _ in range(24): a = a + a
     ast = a.schedule()[0].ast
-    new_uop, et = timeit(st_fixup, ast.src[0].src[2], lambda st:st.reshape((4, 1)), {}, {})
-    self.assertEqual(get_output_st(new_uop, {}), ShapeTracker.from_shape((4,)).reshape((4, 1)))
+    new_uop, et = timeit(st_fixup, ast.src[0].src[2], lambda st:st.reshape((4, 1)), {})
+    self.assertEqual(new_uop.st, ShapeTracker.from_shape((4,)).reshape((4, 1)))
     self.assertLess(et, 1e3)
 
   def test_no_rewrite_elementwise(self):
diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py
index c6efe3af4a..f7dc737ea1 100644
--- a/tinygrad/engine/schedule.py
+++ b/tinygrad/engine/schedule.py
@@ -2,11 +2,11 @@ import sys, pickle, atexit, importlib, contextlib
 from collections import defaultdict, deque
 from dataclasses import dataclass, field
 from typing import Callable, Tuple, List, Dict, Optional, DefaultDict, cast, get_args
-from tinygrad.ops import BUFFER_UOPS, REDUCE_ALU, MetaOps, ReduceOps, UNSAFE_PAD_OPS, UnaryOps, UOp, UOps
+from tinygrad.ops import REDUCE_ALU, MetaOps, ReduceOps, UNSAFE_PAD_OPS, UnaryOps, UOp, UOps
 from tinygrad.ops import PatternMatcher, UPat, graph_rewrite
 from tinygrad.engine.graph import log_lazybuffer, realized_lazybuffer
 from tinygrad.helpers import GRAPH, DEBUG, MULTIOUTPUT, SAVE_SCHEDULE, FUSE_CONV_BW, FUSE_ARANGE, AST_REWRITE, \
-  GlobalCounters, all_same, colored, flatten, prod, dedup, all_int, merge_dicts, getenv, Metadata, unwrap
+  GlobalCounters, colored, flatten, prod, dedup, all_int, merge_dicts, getenv, Metadata, unwrap
 from tinygrad.shape.symbolic import Variable, sint
 from tinygrad.dtype import ConstType, ImageDType, PtrDType, dtypes
 from tinygrad.lazy import LazyBuffer
@@ -133,21 +133,13 @@ def _recurse_reduceops(buf:LazyBuffer, st:ShapeTracker, realizes:Dict[LazyBuffer
 
 # ***** helpers for doing movementops on uops *****
 
-def get_output_st(uop:UOp, uop_sts:Dict[UOp, ShapeTracker]) -> Optional[ShapeTracker]:
-  if (st:=uop_sts.get(uop)): return st
-  if uop.op in BUFFER_UOPS: return uop.st_arg
-  src_sts = [xst for x in uop.src if (xst:=get_output_st(x, uop_sts)) is not None]
-  if len(src_sts) != len(uop.src) or not all_same([x.shape for x in src_sts]): return None
-  uop_sts[uop] = st = ShapeTracker.from_shape(src_sts[0].reduce(uop.arg[1])) if uop.op is UOps.REDUCE_AXIS else src_sts[0]
-  return st
-
-def st_fixup(u:UOp, apply_to_st:Callable[[ShapeTracker], ShapeTracker], uop_sts:Dict[UOp, ShapeTracker], cache:Dict[UOp, UOp]) -> UOp:
+def st_fixup(u:UOp, apply_to_st:Callable[[ShapeTracker], ShapeTracker], cache:Dict[UOp, UOp]) -> UOp:
   if (n:=cache.get(u)): return n
-  if (st:=uop_sts.get(u)) and st == apply_to_st(st): return u
   if u.op is UOps.SHAPETRACKER:
     new_st = apply_to_st(u.arg)
     return u if u.arg == new_st else UOp(UOps.SHAPETRACKER, None, (), new_st)
-  new_srcs = tuple(st_fixup(x, apply_to_st, uop_sts, cache) for x in u.src)
+  if len(u.src) == 0 or (u.st is not None and u.st == apply_to_st(u.st)): return u
+  new_srcs = tuple(st_fixup(x, apply_to_st, cache) for x in u.src)
   cache[u] = ret = u if new_srcs == u.src else UOp(u.op, u.dtype, new_srcs, u.arg)
   return ret
 
@@ -174,10 +166,9 @@ def swizzle_reduceop(input_st:ShapeTracker, swizzle:ShapeTracker, axis:Tuple[int
 # ***** reduceop fusor *****
 
 def push_swizzle_through_reduce(swizzle:UOp, reduceop:UOp) -> UOp:
-  uop_sts: Dict[UOp, ShapeTracker] = {}
   rsrc = reduceop.src[0]
-  new_input_st, new_axis = swizzle_reduceop(unwrap(get_output_st(rsrc, uop_sts)), swizzle.arg, reduceop.arg[1])
-  return UOp(UOps.REDUCE_AXIS, reduceop.dtype, (st_fixup(rsrc, lambda _:new_input_st, uop_sts, {}),), (reduceop.arg[0], new_axis))
+  new_input_st, new_axis = swizzle_reduceop(unwrap(rsrc.st), swizzle.arg, reduceop.arg[1])
+  return UOp(UOps.REDUCE_AXIS, reduceop.dtype, (st_fixup(rsrc, lambda _:new_input_st, {}),), (reduceop.arg[0], new_axis))
 
 def merge_double_reduce(root:UOp, first_reduce:UOp) -> UOp:
   assert root.arg[0] == first_reduce.arg[0], "can't merge reduceops with different alu"
@@ -188,10 +179,9 @@ def merge_double_reduce(root:UOp, first_reduce:UOp) -> UOp:
 def push_reduceop_shape(root:UOp) -> Optional[UOp]:
   reduceops = [x for x in root.parents if x.op is UOps.REDUCE_AXIS]
   if len(reduceops) == 0: return None
-  uop_sts: Dict[UOp, ShapeTracker] = {}
-  rshape = unwrap(get_output_st(reduceops[0], uop_sts)).shape
-  if (root_st:=get_output_st(root, uop_sts)) is not None and rshape == root_st.shape: return None
-  return st_fixup(root, lambda st:st.reshape(rshape), uop_sts, {})
+  rshape = unwrap(reduceops[0].st).shape
+  if root.st is not None and rshape == root.st.shape: return None
+  return st_fixup(root, lambda st:st.reshape(rshape), {})
 
 reduceop_fusor = PatternMatcher([
   (UPat(UOps.SWIZZLE, src=(UPat(UOps.REDUCE_AXIS, name="reduceop"),), name="swizzle"), push_swizzle_through_reduce),
diff --git a/tinygrad/ops.py b/tinygrad/ops.py
index 19e2ff8787..956a766712 100644
--- a/tinygrad/ops.py
+++ b/tinygrad/ops.py
@@ -339,6 +339,14 @@ class UOp(MathTrait):
   src: Tuple[UOp, ...] = tuple()
   arg: Any = None
   @functools.cached_property
+  def st(self) -> Optional[ShapeTracker]:
+    from tinygrad.shape.shapetracker import ShapeTracker
+    if len(self.src) == 0: return None
+    if self.op in BUFFER_UOPS: return self.st_arg
+    src_sts = [x.st for x in self.src if x.st is not None]
+    if len(src_sts) != len(self.src) or not all_same([x.shape for x in src_sts]): return None
+    return ShapeTracker.from_shape(src_sts[0].reduce(self.arg[1])) if self.op is UOps.REDUCE_AXIS else src_sts[0]
+  @functools.cached_property
   def cmp_tuple(self) -> Tuple[int, Any, Optional[DType], Tuple[UOp, ...]]:
     # NOTE: this sort of DEFINE_VAR shouldn't have to be here. only for PTX
     return (self.op.value, (self.arg if self.op is not UOps.DEFINE_VAR else self.arg[0]) if self.op is not UOps.ALU else \
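
Note on the change: the recursive get_output_st(uop, uop_sts) helper, which threaded an explicit uop_sts memo dict through st_fixup, push_swizzle_through_reduce, and push_reduceop_shape, is replaced by a functools.cached_property named st on UOp, so each UOp derives its output ShapeTracker once and callers simply read uop.st (see the new property added to tinygrad/ops.py above). The sketch below is a minimal standalone illustration of that caching pattern, not tinygrad code: MiniUOp and MiniST are hypothetical stand-ins for UOp and ShapeTracker, and leaves carry their shape in arg here rather than going through the BUFFER_UOPS/st_arg case of the real patch.

import functools
from dataclasses import dataclass
from typing import Optional, Tuple

@dataclass(frozen=True)
class MiniST:  # hypothetical stand-in for ShapeTracker
  shape: Tuple[int, ...]
  def reduce(self, axis: Tuple[int, ...]) -> Tuple[int, ...]:
    # reducing an axis collapses it to size 1, mirroring ShapeTracker.reduce
    return tuple(1 if i in axis else s for i, s in enumerate(self.shape))

@dataclass(frozen=True, eq=False)
class MiniUOp:  # hypothetical stand-in for UOp
  op: str
  src: Tuple["MiniUOp", ...] = tuple()
  arg: object = None
  @functools.cached_property  # computed once per node, then read back from __dict__
  def st(self) -> Optional[MiniST]:
    # leaves hold a MiniST in arg in this sketch; the real patch returns st_arg for BUFFER_UOPS
    if len(self.src) == 0: return self.arg if isinstance(self.arg, MiniST) else None
    src_sts = [x.st for x in self.src if x.st is not None]
    if len(src_sts) != len(self.src) or len({x.shape for x in src_sts}) != 1: return None
    return MiniST(src_sts[0].reduce(self.arg)) if self.op == "REDUCE_AXIS" else src_sts[0]

if __name__ == "__main__":
  buf = MiniUOp("LOAD", arg=MiniST((4, 4)))
  add = MiniUOp("ADD", (buf, buf))
  red = MiniUOp("REDUCE_AXIS", (add,), arg=(1,))
  print(add.st)  # MiniST(shape=(4, 4))
  print(red.st)  # MiniST(shape=(4, 1))

Because functools.cached_property stores its result directly in the instance's __dict__ (which also works on frozen dataclasses), repeated reads of .st on the same node are cheap, which is what lets the uop_sts parameter disappear from the schedule helpers in the diff above.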