uop output_st is Optional [run_process_replay] (#6282)

This commit is contained in:
qazal
2024-08-26 17:58:55 +08:00
committed by GitHub
parent 1c0456af89
commit b4381e9777

View File

@@ -5,7 +5,7 @@ from typing import Callable, Tuple, List, Dict, Optional, Set, DefaultDict, cast
from tinygrad.ops import BUFFER_UOPS, REDUCE_ALU, MetaOps, PatternMatcher, ReduceOps, UNSAFE_PAD_OPS, UPat, UnaryOps, UOp, UOps, graph_rewrite
from tinygrad.engine.graph import log_lazybuffer, realized_lazybuffer
from tinygrad.helpers import GRAPH, DEBUG, MULTIOUTPUT, SAVE_SCHEDULE, FUSE_CONV_BW, FUSE_ARANGE, \
GlobalCounters, all_same, colored, prod, dedup, all_int, merge_dicts, getenv, Metadata
GlobalCounters, all_same, colored, prod, dedup, all_int, merge_dicts, getenv, Metadata, unwrap
from tinygrad.shape.symbolic import Variable, sint
from tinygrad.dtype import ConstType, ImageDType, PtrDType, dtypes
from tinygrad.lazy import LazyBuffer
@@ -133,11 +133,11 @@ def _recurse_reduceops(buf:LazyBuffer, st:ShapeTracker, realizes:Dict[LazyBuffer
# ***** helpers for doing movementops on uops *****
# NOTE(review): this is a diff rendering without +/- markers. Where two near-identical lines
# follow each other, the first appears to be the pre-commit line and the second its replacement
# (consistent with the commit title "uop output_st is Optional").
#
# get_output_st: look up / derive the ShapeTracker of `uop`, caching derived results in `uop_sts`.
#   uop:     the UOp whose output ShapeTracker is wanted.
#   uop_sts: cache mapping UOp -> ShapeTracker; read first and written on a successful derivation.
#   returns: old signature -> ShapeTracker; new signature -> Optional[ShapeTracker]
#            (None when the sources do not yield one common shape).
def get_output_st(uop:UOp, uop_sts:Dict[UOp, ShapeTracker]) -> ShapeTracker:
def get_output_st(uop:UOp, uop_sts:Dict[UOp, ShapeTracker]) -> Optional[ShapeTracker]:
# cache hit: reuse the ShapeTracker already recorded for this uop
if (st:=uop_sts.get(uop)): return st
# uops listed in BUFFER_UOPS supply their ShapeTracker directly through st_arg (not cached here)
if uop.op in BUFFER_UOPS: return uop.st_arg
# old version: recurse into every source and assert that all source shapes agree
src_sts = [get_output_st(x, uop_sts) for x in uop.src]
assert all_same([x.shape for x in src_sts]), f"inhomogeneous shape from\n{uop}\n{[x.shape for x in src_sts]}"
# new version: keep only sources that produced a ShapeTracker ...
src_sts = [xst for x in uop.src if (xst:=get_output_st(x, uop_sts)) is not None]
# ... and return None (instead of asserting) if any source had none or the shapes differ
if len(src_sts) != len(uop.src) or not all_same([x.shape for x in src_sts]): return None
# REDUCE_AXIS: build a fresh ShapeTracker from the first source's st reduced by uop.arg[1]
# (presumably the reduce axes - confirm); any other op inherits the first source's st.
# NOTE(review): a uop with no sources that is not in BUFFER_UOPS would index an empty list here,
# depending on what all_same returns for an empty list - confirm such uops cannot reach this line.
uop_sts[uop] = st = ShapeTracker.from_shape(src_sts[0].reduce(uop.arg[1])) if uop.op is UOps.REDUCE_AXIS else src_sts[0]
return st
@@ -175,15 +175,15 @@ def swizzle_reduceop(input_st:ShapeTracker, swizzle:ShapeTracker, axis:Tuple[int
# NOTE(review): diff rendering without +/- markers; of the two `new_input_st, new_axis = ...`
# lines the first appears to be the pre-commit line and the second its replacement.
#
# apply_swizzle: rebuild `root` so that its single source is `rsrc` rewritten by st_fixup to use the
# ShapeTracker returned by swizzle_reduceop, and so that arg[1] becomes the axis it returned.
#   root:    UOp whose arg is read as (arg[0], arg[1]); arg[1] is passed as swizzle_reduceop's axis argument.
#   rsrc:    source UOp whose output ShapeTracker is the input to swizzle_reduceop.
#   swizzle: UOp whose arg is passed as swizzle_reduceop's swizzle argument.
#   returns: the UOp produced by replace(root, ...) with new src and arg.
def apply_swizzle(root:UOp, rsrc:UOp, swizzle:UOp) -> UOp:
# fresh per-call cache shared by get_output_st and st_fixup below
uop_sts: Dict[UOp, ShapeTracker] = {}
new_input_st, new_axis = swizzle_reduceop(get_output_st(rsrc, uop_sts), swizzle.arg, root.arg[1])
# new version: get_output_st may now return None, so the result goes through unwrap first
# (unwrap comes from tinygrad.helpers; presumably it rejects None - confirm)
new_input_st, new_axis = swizzle_reduceop(unwrap(get_output_st(rsrc, uop_sts)), swizzle.arg, root.arg[1])
# the lambda ignores its argument and always yields new_input_st; root.arg[0] is kept unchanged
return replace(root, src=(st_fixup(rsrc, lambda _:new_input_st, uop_sts, {}),), arg=(root.arg[0], new_axis))
# NOTE(review): diff rendering without +/- markers; the `rshape = ...` line and the shape-equality
# check each appear twice - first the pre-commit line, then its replacement.
#
# push_reduceop_shape: if `root` has a REDUCE_AXIS among its parents, rewrite root via st_fixup so
# ShapeTrackers are reshaped to that reduce op's output shape.
#   root:    the UOp to inspect and possibly rewrite.
#   returns: None when there is no REDUCE_AXIS parent or the shapes already match;
#            otherwise the result of st_fixup(root, ...).
def push_reduceop_shape(root:UOp) -> Optional[UOp]:
reduceops = [x for x in root.parents if x.op is UOps.REDUCE_AXIS]
# nothing to push when root has no REDUCE_AXIS parent
if len(reduceops) == 0: return None
# fresh per-call cache shared by get_output_st and st_fixup below
uop_sts: Dict[UOp, ShapeTracker] = {}
# old version: compare the first reduce op's output shape with root.st_arg.shape
rshape = get_output_st(reduceops[0], uop_sts).shape
if rshape == root.st_arg.shape: return None
# new version: the reduce op's st is unwrapped because get_output_st is now Optional
rshape = unwrap(get_output_st(reduceops[0], uop_sts)).shape
# root's shape is now derived with get_output_st; if that yields None the early return is
# skipped and the function falls through to the reshape below
if (root_st:=get_output_st(root, uop_sts)) is not None and rshape == root_st.shape: return None
# each visited ShapeTracker is reshaped to the reduce op's output shape
return st_fixup(root, lambda st:st.reshape(rshape), uop_sts, {})
reduceop_fusor = PatternMatcher([