mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-25 14:58:46 -05:00)
st is a cached_property on UOp [run_process_replay] (#6433)
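In summary (as read from the hunks below, which appear to come from the schedule test, the scheduler, and the UOp definition): the standalone get_output_st helper is deleted from the scheduler, and a UOp's ShapeTracker becomes a functools.cached_property named st on UOp itself. st_fixup and the reduceop-fusor rules drop their explicit uop_sts memo dict, and the test reads new_uop.st directly. Schematically (names from the diff, not standalone code):

# old: st = get_output_st(uop, uop_sts)   # free function, caller threads a memo dict
# new: st = uop.st                        # cached_property, memoized per UOp instance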
@@ -15,7 +15,7 @@ from tinygrad.ops import BinaryOps, MetaOps, UOp, UnaryOps, UOps
 from tinygrad.ops import graph_rewrite
 from tinygrad.helpers import AST_REWRITE, CI, DEBUG, FUSE_ARANGE, FUSE_CONV_BW, GlobalCounters, flatten, getenv, SPLIT_REDUCEOP
 from tinygrad.codegen.kernel import Kernel, verify_ast
-from tinygrad.engine.schedule import create_schedule, get_output_st, reduceop_fusor, st_fixup, ScheduleItem
+from tinygrad.engine.schedule import create_schedule, reduceop_fusor, st_fixup, ScheduleItem
 from tinygrad.engine.realize import CompiledRunner, run_schedule
 from test.helpers import assert_equiv_uops, ast_const, is_dtype_supported, Context, timeit
 from tinygrad.lazy import LazyBuffer, view_supported_devices
@@ -1650,8 +1650,8 @@ class TestScheduleRewrite(unittest.TestCase):
     a = Tensor([1,2,3,4]).realize()
     for _ in range(24): a = a + a
     ast = a.schedule()[0].ast
-    new_uop, et = timeit(st_fixup, ast.src[0].src[2], lambda st:st.reshape((4, 1)), {}, {})
-    self.assertEqual(get_output_st(new_uop, {}), ShapeTracker.from_shape((4,)).reshape((4, 1)))
+    new_uop, et = timeit(st_fixup, ast.src[0].src[2], lambda st:st.reshape((4, 1)), {})
+    self.assertEqual(new_uop.st, ShapeTracker.from_shape((4,)).reshape((4, 1)))
     self.assertLess(et, 1e3)

   def test_no_rewrite_elementwise(self):
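A hedged usage sketch of the new property, mirroring the test above; it assumes a tinygrad checkout at this commit and that, as in the test, a scheduled kernel's AST is a UOp whose first source is the output STORE:

from tinygrad import Tensor

a = Tensor([1, 2, 3, 4]).realize()
ast = (a + a).schedule()[0].ast      # the kernel AST is a UOp, as in the test above
st = ast.src[0].st                   # Optional[ShapeTracker], computed on first access
print(st)                            # the output ShapeTracker of this elementwise kernel
print(st is ast.src[0].st)           # True: the value is memoized on the UOp instance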
@@ -2,11 +2,11 @@ import sys, pickle, atexit, importlib, contextlib
 from collections import defaultdict, deque
 from dataclasses import dataclass, field
 from typing import Callable, Tuple, List, Dict, Optional, DefaultDict, cast, get_args
-from tinygrad.ops import BUFFER_UOPS, REDUCE_ALU, MetaOps, ReduceOps, UNSAFE_PAD_OPS, UnaryOps, UOp, UOps
+from tinygrad.ops import REDUCE_ALU, MetaOps, ReduceOps, UNSAFE_PAD_OPS, UnaryOps, UOp, UOps
 from tinygrad.ops import PatternMatcher, UPat, graph_rewrite
 from tinygrad.engine.graph import log_lazybuffer, realized_lazybuffer
 from tinygrad.helpers import GRAPH, DEBUG, MULTIOUTPUT, SAVE_SCHEDULE, FUSE_CONV_BW, FUSE_ARANGE, AST_REWRITE, \
-  GlobalCounters, all_same, colored, flatten, prod, dedup, all_int, merge_dicts, getenv, Metadata, unwrap
+  GlobalCounters, colored, flatten, prod, dedup, all_int, merge_dicts, getenv, Metadata, unwrap
 from tinygrad.shape.symbolic import Variable, sint
 from tinygrad.dtype import ConstType, ImageDType, PtrDType, dtypes
 from tinygrad.lazy import LazyBuffer
@@ -133,21 +133,13 @@ def _recurse_reduceops(buf:LazyBuffer, st:ShapeTracker, realizes:Dict[LazyBuffer

 # ***** helpers for doing movementops on uops *****

-def get_output_st(uop:UOp, uop_sts:Dict[UOp, ShapeTracker]) -> Optional[ShapeTracker]:
-  if (st:=uop_sts.get(uop)): return st
-  if uop.op in BUFFER_UOPS: return uop.st_arg
-  src_sts = [xst for x in uop.src if (xst:=get_output_st(x, uop_sts)) is not None]
-  if len(src_sts) != len(uop.src) or not all_same([x.shape for x in src_sts]): return None
-  uop_sts[uop] = st = ShapeTracker.from_shape(src_sts[0].reduce(uop.arg[1])) if uop.op is UOps.REDUCE_AXIS else src_sts[0]
-  return st
-
-def st_fixup(u:UOp, apply_to_st:Callable[[ShapeTracker], ShapeTracker], uop_sts:Dict[UOp, ShapeTracker], cache:Dict[UOp, UOp]) -> UOp:
+def st_fixup(u:UOp, apply_to_st:Callable[[ShapeTracker], ShapeTracker], cache:Dict[UOp, UOp]) -> UOp:
   if (n:=cache.get(u)): return n
-  if (st:=uop_sts.get(u)) and st == apply_to_st(st): return u
   if u.op is UOps.SHAPETRACKER:
     new_st = apply_to_st(u.arg)
     return u if u.arg == new_st else UOp(UOps.SHAPETRACKER, None, (), new_st)
-  new_srcs = tuple(st_fixup(x, apply_to_st, uop_sts, cache) for x in u.src)
+  if len(u.src) == 0 or (u.st is not None and u.st == apply_to_st(u.st)): return u
+  new_srcs = tuple(st_fixup(x, apply_to_st, cache) for x in u.src)
   cache[u] = ret = u if new_srcs == u.src else UOp(u.op, u.dtype, new_srcs, u.arg)
   return ret

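st_fixup keeps its cache:Dict[UOp, UOp] argument; only the separate uop_sts shape memo goes away, since u.st now carries that information. A self-contained toy of the same rewrite-with-cache pattern (not tinygrad code): rebuild a DAG bottom-up, memoize per node, and hand back the original node when nothing underneath changed.

from typing import Dict, Tuple

class N:
  def __init__(self, tag: str, srcs: Tuple["N", ...] = ()): self.tag, self.srcs = tag, srcs

def rewrite(n: N, f, cache: Dict[N, N]) -> N:
  if (hit := cache.get(n)) is not None: return hit          # each node is visited once
  new_srcs = tuple(rewrite(s, f, cache) for s in n.srcs)
  new_tag = f(n.tag)
  cache[n] = ret = n if new_srcs == n.srcs and new_tag == n.tag else N(new_tag, new_srcs)
  return ret

leaf = N("x")
root = N("add", (leaf, leaf))                 # shared child, like the repeated a+a in the test
print(rewrite(root, lambda t: t, {}) is root) # True: identity rewrite reuses the original nodes
out = rewrite(root, lambda t: t.upper(), {})
print(out.srcs[0] is out.srcs[1])             # True: the shared leaf is rewritten only once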
@@ -174,10 +166,9 @@ def swizzle_reduceop(input_st:ShapeTracker, swizzle:ShapeTracker, axis:Tuple[int
 # ***** reduceop fusor *****

 def push_swizzle_through_reduce(swizzle:UOp, reduceop:UOp) -> UOp:
-  uop_sts: Dict[UOp, ShapeTracker] = {}
   rsrc = reduceop.src[0]
-  new_input_st, new_axis = swizzle_reduceop(unwrap(get_output_st(rsrc, uop_sts)), swizzle.arg, reduceop.arg[1])
-  return UOp(UOps.REDUCE_AXIS, reduceop.dtype, (st_fixup(rsrc, lambda _:new_input_st, uop_sts, {}),), (reduceop.arg[0], new_axis))
+  new_input_st, new_axis = swizzle_reduceop(unwrap(rsrc.st), swizzle.arg, reduceop.arg[1])
+  return UOp(UOps.REDUCE_AXIS, reduceop.dtype, (st_fixup(rsrc, lambda _:new_input_st, {}),), (reduceop.arg[0], new_axis))

 def merge_double_reduce(root:UOp, first_reduce:UOp) -> UOp:
   assert root.arg[0] == first_reduce.arg[0], "can't merge reduceops with different alu"
@@ -188,10 +179,9 @@ def merge_double_reduce(root:UOp, first_reduce:UOp) -> UOp:
 def push_reduceop_shape(root:UOp) -> Optional[UOp]:
   reduceops = [x for x in root.parents if x.op is UOps.REDUCE_AXIS]
   if len(reduceops) == 0: return None
-  uop_sts: Dict[UOp, ShapeTracker] = {}
-  rshape = unwrap(get_output_st(reduceops[0], uop_sts)).shape
-  if (root_st:=get_output_st(root, uop_sts)) is not None and rshape == root_st.shape: return None
-  return st_fixup(root, lambda st:st.reshape(rshape), uop_sts, {})
+  rshape = unwrap(reduceops[0].st).shape
+  if root.st is not None and rshape == root.st.shape: return None
+  return st_fixup(root, lambda st:st.reshape(rshape), {})

 reduceop_fusor = PatternMatcher([
   (UPat(UOps.SWIZZLE, src=(UPat(UOps.REDUCE_AXIS, name="reduceop"),), name="swizzle"), push_swizzle_through_reduce),
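Both fusor helpers above read UOp.st, which is Optional, so the new code wraps the access in unwrap (imported from tinygrad.helpers in the schedule.py import hunk). tinygrad's definition isn't reproduced here; a stand-in with what I take to be the same contract:

from typing import Optional, TypeVar

T = TypeVar("T")

def unwrap_sketch(x: Optional[T]) -> T:
  # fail loudly instead of propagating None, so callers get a concrete ShapeTracker
  assert x is not None, "expected a value, got None"
  return x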
@@ -339,6 +339,14 @@ class UOp(MathTrait):
   src: Tuple[UOp, ...] = tuple()
   arg: Any = None
+  @functools.cached_property
+  def st(self) -> Optional[ShapeTracker]:
+    from tinygrad.shape.shapetracker import ShapeTracker
+    if len(self.src) == 0: return None
+    if self.op in BUFFER_UOPS: return self.st_arg
+    src_sts = [x.st for x in self.src if x.st is not None]
+    if len(src_sts) != len(self.src) or not all_same([x.shape for x in src_sts]): return None
+    return ShapeTracker.from_shape(src_sts[0].reduce(self.arg[1])) if self.op is UOps.REDUCE_AXIS else src_sts[0]
   @functools.cached_property
   def cmp_tuple(self) -> Tuple[int, Any, Optional[DType], Tuple[UOp, ...]]:
     # NOTE: this sort of DEFINE_VAR shouldn't have to be here. only for PTX
     return (self.op.value, (self.arg if self.op is not UOps.DEFINE_VAR else self.arg[0]) if self.op is not UOps.ALU else \
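The property imports ShapeTracker inside its body, presumably to avoid a circular import between tinygrad.ops and the shape module. For the caching itself, a self-contained toy (not tinygrad code) of the functools.cached_property behavior UOp.st now relies on: the body runs once per instance, and later reads come from the instance __dict__.

import functools

class Box:
  def __init__(self, xs): self.xs = xs

  @functools.cached_property
  def total(self) -> int:
    print("computing")           # printed only on the first access
    return sum(self.xs)

b = Box([1, 2, 3])
print(b.total)                   # "computing", then 6
print(b.total)                   # 6 again, straight from b.__dict__["total"]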