st is a cached_property on UOp [run_process_replay] (#6433)

Author: qazal
Date: 2024-09-10 08:30:35 +08:00
Committed by: GitHub
parent cf64f8bb40
commit 29e63097a0
3 changed files with 21 additions and 23 deletions
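
At a glance: the scheduler helper get_output_st(uop, uop_sts), which threaded a memo dictionary through every caller, is replaced by a functools.cached_property named st on UOp itself, so each UOp computes its ShapeTracker once and callers read it as an attribute. The toy below is a minimal sketch of that pattern only; Node, total and total_of are invented names for illustration, not tinygrad code.

import functools
from dataclasses import dataclass
from typing import Dict, Tuple

# toy graph node, not a tinygrad UOp
@dataclass(frozen=True, eq=False)
class Node:
  value: int
  src: Tuple["Node", ...] = ()

  # "after" style: the node memoizes its own derived value, the way UOp.st now caches the ShapeTracker
  @functools.cached_property
  def total(self) -> int:
    return self.value + sum(s.total for s in self.src)

# "before" style: a memo dict has to be created by the caller and threaded through every recursive call
def total_of(node: Node, memo: Dict[Node, int]) -> int:
  if (t := memo.get(node)) is not None: return t
  memo[node] = t = node.value + sum(total_of(s, memo) for s in node.src)
  return t

shared = Node(1)
root = Node(10, (Node(2, (shared,)), shared))
assert total_of(root, {}) == root.total == 14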

@@ -15,7 +15,7 @@ from tinygrad.ops import BinaryOps, MetaOps, UOp, UnaryOps, UOps
 from tinygrad.ops import graph_rewrite
 from tinygrad.helpers import AST_REWRITE, CI, DEBUG, FUSE_ARANGE, FUSE_CONV_BW, GlobalCounters, flatten, getenv, SPLIT_REDUCEOP
 from tinygrad.codegen.kernel import Kernel, verify_ast
-from tinygrad.engine.schedule import create_schedule, get_output_st, reduceop_fusor, st_fixup, ScheduleItem
+from tinygrad.engine.schedule import create_schedule, reduceop_fusor, st_fixup, ScheduleItem
 from tinygrad.engine.realize import CompiledRunner, run_schedule
 from test.helpers import assert_equiv_uops, ast_const, is_dtype_supported, Context, timeit
 from tinygrad.lazy import LazyBuffer, view_supported_devices
@@ -1650,8 +1650,8 @@ class TestScheduleRewrite(unittest.TestCase):
     a = Tensor([1,2,3,4]).realize()
     for _ in range(24): a = a + a
     ast = a.schedule()[0].ast
-    new_uop, et = timeit(st_fixup, ast.src[0].src[2], lambda st:st.reshape((4, 1)), {}, {})
-    self.assertEqual(get_output_st(new_uop, {}), ShapeTracker.from_shape((4,)).reshape((4, 1)))
+    new_uop, et = timeit(st_fixup, ast.src[0].src[2], lambda st:st.reshape((4, 1)), {})
+    self.assertEqual(new_uop.st, ShapeTracker.from_shape((4,)).reshape((4, 1)))
     self.assertLess(et, 1e3)
   def test_no_rewrite_elementwise(self):

@@ -2,11 +2,11 @@ import sys, pickle, atexit, importlib, contextlib
 from collections import defaultdict, deque
 from dataclasses import dataclass, field
 from typing import Callable, Tuple, List, Dict, Optional, DefaultDict, cast, get_args
-from tinygrad.ops import BUFFER_UOPS, REDUCE_ALU, MetaOps, ReduceOps, UNSAFE_PAD_OPS, UnaryOps, UOp, UOps
+from tinygrad.ops import REDUCE_ALU, MetaOps, ReduceOps, UNSAFE_PAD_OPS, UnaryOps, UOp, UOps
 from tinygrad.ops import PatternMatcher, UPat, graph_rewrite
 from tinygrad.engine.graph import log_lazybuffer, realized_lazybuffer
 from tinygrad.helpers import GRAPH, DEBUG, MULTIOUTPUT, SAVE_SCHEDULE, FUSE_CONV_BW, FUSE_ARANGE, AST_REWRITE, \
-  GlobalCounters, all_same, colored, flatten, prod, dedup, all_int, merge_dicts, getenv, Metadata, unwrap
+  GlobalCounters, colored, flatten, prod, dedup, all_int, merge_dicts, getenv, Metadata, unwrap
 from tinygrad.shape.symbolic import Variable, sint
 from tinygrad.dtype import ConstType, ImageDType, PtrDType, dtypes
 from tinygrad.lazy import LazyBuffer
@@ -133,21 +133,13 @@ def _recurse_reduceops(buf:LazyBuffer, st:ShapeTracker, realizes:Dict[LazyBuffer
 # ***** helpers for doing movementops on uops *****
-def get_output_st(uop:UOp, uop_sts:Dict[UOp, ShapeTracker]) -> Optional[ShapeTracker]:
-  if (st:=uop_sts.get(uop)): return st
-  if uop.op in BUFFER_UOPS: return uop.st_arg
-  src_sts = [xst for x in uop.src if (xst:=get_output_st(x, uop_sts)) is not None]
-  if len(src_sts) != len(uop.src) or not all_same([x.shape for x in src_sts]): return None
-  uop_sts[uop] = st = ShapeTracker.from_shape(src_sts[0].reduce(uop.arg[1])) if uop.op is UOps.REDUCE_AXIS else src_sts[0]
-  return st
-def st_fixup(u:UOp, apply_to_st:Callable[[ShapeTracker], ShapeTracker], uop_sts:Dict[UOp, ShapeTracker], cache:Dict[UOp, UOp]) -> UOp:
+def st_fixup(u:UOp, apply_to_st:Callable[[ShapeTracker], ShapeTracker], cache:Dict[UOp, UOp]) -> UOp:
   if (n:=cache.get(u)): return n
-  if (st:=uop_sts.get(u)) and st == apply_to_st(st): return u
   if u.op is UOps.SHAPETRACKER:
     new_st = apply_to_st(u.arg)
     return u if u.arg == new_st else UOp(UOps.SHAPETRACKER, None, (), new_st)
-  new_srcs = tuple(st_fixup(x, apply_to_st, uop_sts, cache) for x in u.src)
+  if len(u.src) == 0 or (u.st is not None and u.st == apply_to_st(u.st)): return u
+  new_srcs = tuple(st_fixup(x, apply_to_st, cache) for x in u.src)
   cache[u] = ret = u if new_srcs == u.src else UOp(u.op, u.dtype, new_srcs, u.arg)
   return ret
@@ -174,10 +166,9 @@ def swizzle_reduceop(input_st:ShapeTracker, swizzle:ShapeTracker, axis:Tuple[int
 # ***** reduceop fusor *****
 def push_swizzle_through_reduce(swizzle:UOp, reduceop:UOp) -> UOp:
-  uop_sts: Dict[UOp, ShapeTracker] = {}
   rsrc = reduceop.src[0]
-  new_input_st, new_axis = swizzle_reduceop(unwrap(get_output_st(rsrc, uop_sts)), swizzle.arg, reduceop.arg[1])
-  return UOp(UOps.REDUCE_AXIS, reduceop.dtype, (st_fixup(rsrc, lambda _:new_input_st, uop_sts, {}),), (reduceop.arg[0], new_axis))
+  new_input_st, new_axis = swizzle_reduceop(unwrap(rsrc.st), swizzle.arg, reduceop.arg[1])
+  return UOp(UOps.REDUCE_AXIS, reduceop.dtype, (st_fixup(rsrc, lambda _:new_input_st, {}),), (reduceop.arg[0], new_axis))
 def merge_double_reduce(root:UOp, first_reduce:UOp) -> UOp:
   assert root.arg[0] == first_reduce.arg[0], "can't merge reduceops with different alu"
@@ -188,10 +179,9 @@ def merge_double_reduce(root:UOp, first_reduce:UOp) -> UOp:
 def push_reduceop_shape(root:UOp) -> Optional[UOp]:
   reduceops = [x for x in root.parents if x.op is UOps.REDUCE_AXIS]
   if len(reduceops) == 0: return None
-  uop_sts: Dict[UOp, ShapeTracker] = {}
-  rshape = unwrap(get_output_st(reduceops[0], uop_sts)).shape
-  if (root_st:=get_output_st(root, uop_sts)) is not None and rshape == root_st.shape: return None
-  return st_fixup(root, lambda st:st.reshape(rshape), uop_sts, {})
+  rshape = unwrap(reduceops[0].st).shape
+  if root.st is not None and rshape == root.st.shape: return None
+  return st_fixup(root, lambda st:st.reshape(rshape), {})
 reduceop_fusor = PatternMatcher([
   (UPat(UOps.SWIZZLE, src=(UPat(UOps.REDUCE_AXIS, name="reduceop"),), name="swizzle"), push_swizzle_through_reduce),
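
For context on the st_fixup calls above: after this change the only state threaded through the recursion is the rebuild cache, since each UOp can answer u.st on its own. The sketch below imitates that rebuild pattern on a made-up Node class (Node, fixup and apply_to_arg are illustrative names, not tinygrad APIs): subtrees the transform would not change are returned as-is, and interior nodes that were already rebuilt are reused through the cache.

from dataclasses import dataclass
from typing import Callable, Dict, Tuple

# toy immutable graph node, not a tinygrad UOp
@dataclass(frozen=True, eq=False)
class Node:
  op: str
  src: Tuple["Node", ...] = ()
  arg: int = 0

def fixup(n: Node, apply_to_arg: Callable[[int], int], cache: Dict[Node, Node]) -> Node:
  # rebuild bottom-up, in the same shape as st_fixup: only a node -> rebuilt-node cache is passed along
  if (hit := cache.get(n)) is not None: return hit
  if n.op == "leaf":
    new_arg = apply_to_arg(n.arg)
    return n if new_arg == n.arg else Node("leaf", (), new_arg)
  new_src = tuple(fixup(s, apply_to_arg, cache) for s in n.src)
  cache[n] = ret = n if new_src == n.src else Node(n.op, new_src, n.arg)
  return ret

shared = Node("leaf", (), 3)
root = Node("add", (Node("add", (shared, Node("leaf", (), 0))), shared))
assert fixup(root, lambda a: a, {}) is root           # a no-op transform returns the original graph
assert fixup(root, lambda a: a * 2, {}) is not root   # a changed leaf forces a rebuild along its path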

@@ -339,6 +339,14 @@ class UOp(MathTrait):
   src: Tuple[UOp, ...] = tuple()
   arg: Any = None
+  @functools.cached_property
+  def st(self) -> Optional[ShapeTracker]:
+    from tinygrad.shape.shapetracker import ShapeTracker
+    if len(self.src) == 0: return None
+    if self.op in BUFFER_UOPS: return self.st_arg
+    src_sts = [x.st for x in self.src if x.st is not None]
+    if len(src_sts) != len(self.src) or not all_same([x.shape for x in src_sts]): return None
+    return ShapeTracker.from_shape(src_sts[0].reduce(self.arg[1])) if self.op is UOps.REDUCE_AXIS else src_sts[0]
   @functools.cached_property
   def cmp_tuple(self) -> Tuple[int, Any, Optional[DType], Tuple[UOp, ...]]:
     # NOTE: this sort of DEFINE_VAR shouldn't have to be here. only for PTX
     return (self.op.value, (self.arg if self.op is not UOps.DEFINE_VAR else self.arg[0]) if self.op is not UOps.ALU else \
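
One implementation detail worth noting: functools.cached_property stores its result directly in the instance's __dict__, so it keeps working even on objects that forbid normal attribute assignment, such as frozen dataclasses. The fields shown for UOp above look dataclass-like, but whether UOp is frozen is not visible in this hunk; the snippet below only demonstrates the general mechanism on a made-up class.

import functools
from dataclasses import dataclass

@dataclass(frozen=True)
class Frozen:
  x: int

  @functools.cached_property
  def doubled(self) -> int:
    # computed on first access; cached_property then writes the value into self.__dict__,
    # bypassing the frozen dataclass's __setattr__ guard
    return 2 * self.x

f = Frozen(21)
assert f.doubled == 42
assert "doubled" in f.__dict__  # the cached value lives on the instance after first access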