mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
uop output_st is Optional [run_process_replay] (#6282)
This commit is contained in:
@@ -5,7 +5,7 @@ from typing import Callable, Tuple, List, Dict, Optional, Set, DefaultDict, cast
|
||||
from tinygrad.ops import BUFFER_UOPS, REDUCE_ALU, MetaOps, PatternMatcher, ReduceOps, UNSAFE_PAD_OPS, UPat, UnaryOps, UOp, UOps, graph_rewrite
|
||||
from tinygrad.engine.graph import log_lazybuffer, realized_lazybuffer
|
||||
from tinygrad.helpers import GRAPH, DEBUG, MULTIOUTPUT, SAVE_SCHEDULE, FUSE_CONV_BW, FUSE_ARANGE, \
|
||||
GlobalCounters, all_same, colored, prod, dedup, all_int, merge_dicts, getenv, Metadata
|
||||
GlobalCounters, all_same, colored, prod, dedup, all_int, merge_dicts, getenv, Metadata, unwrap
|
||||
from tinygrad.shape.symbolic import Variable, sint
|
||||
from tinygrad.dtype import ConstType, ImageDType, PtrDType, dtypes
|
||||
from tinygrad.lazy import LazyBuffer
|
||||
@@ -133,11 +133,11 @@ def _recurse_reduceops(buf:LazyBuffer, st:ShapeTracker, realizes:Dict[LazyBuffer
|
||||
|
||||
# ***** helpers for doing movementops on uops *****
|
||||
|
||||
def get_output_st(uop:UOp, uop_sts:Dict[UOp, ShapeTracker]) -> ShapeTracker:
|
||||
def get_output_st(uop:UOp, uop_sts:Dict[UOp, ShapeTracker]) -> Optional[ShapeTracker]:
|
||||
if (st:=uop_sts.get(uop)): return st
|
||||
if uop.op in BUFFER_UOPS: return uop.st_arg
|
||||
src_sts = [get_output_st(x, uop_sts) for x in uop.src]
|
||||
assert all_same([x.shape for x in src_sts]), f"inhomogeneous shape from\n{uop}\n{[x.shape for x in src_sts]}"
|
||||
src_sts = [xst for x in uop.src if (xst:=get_output_st(x, uop_sts)) is not None]
|
||||
if len(src_sts) != len(uop.src) or not all_same([x.shape for x in src_sts]): return None
|
||||
uop_sts[uop] = st = ShapeTracker.from_shape(src_sts[0].reduce(uop.arg[1])) if uop.op is UOps.REDUCE_AXIS else src_sts[0]
|
||||
return st
|
||||
|
||||
@@ -175,15 +175,15 @@ def swizzle_reduceop(input_st:ShapeTracker, swizzle:ShapeTracker, axis:Tuple[int
|
||||
|
||||
def apply_swizzle(root:UOp, rsrc:UOp, swizzle:UOp) -> UOp:
|
||||
uop_sts: Dict[UOp, ShapeTracker] = {}
|
||||
new_input_st, new_axis = swizzle_reduceop(get_output_st(rsrc, uop_sts), swizzle.arg, root.arg[1])
|
||||
new_input_st, new_axis = swizzle_reduceop(unwrap(get_output_st(rsrc, uop_sts)), swizzle.arg, root.arg[1])
|
||||
return replace(root, src=(st_fixup(rsrc, lambda _:new_input_st, uop_sts, {}),), arg=(root.arg[0], new_axis))
|
||||
|
||||
def push_reduceop_shape(root:UOp) -> Optional[UOp]:
|
||||
reduceops = [x for x in root.parents if x.op is UOps.REDUCE_AXIS]
|
||||
if len(reduceops) == 0: return None
|
||||
uop_sts: Dict[UOp, ShapeTracker] = {}
|
||||
rshape = get_output_st(reduceops[0], uop_sts).shape
|
||||
if rshape == root.st_arg.shape: return None
|
||||
rshape = unwrap(get_output_st(reduceops[0], uop_sts)).shape
|
||||
if (root_st:=get_output_st(root, uop_sts)) is not None and rshape == root_st.shape: return None
|
||||
return st_fixup(root, lambda st:st.reshape(rshape), uop_sts, {})
|
||||
|
||||
reduceop_fusor = PatternMatcher([
|
||||
|
||||
Reference in New Issue
Block a user