swizzle parents with graph rewrite (#7625)

* delete st_fixup

* refactor

* minimal diff
Author: qazal
Date: 2024-11-11 10:50:38 +02:00
Committed by: GitHub
Parent: fec977b966
Commit: 766a680588

@@ -1,9 +1,9 @@
 import sys, atexit, functools, itertools
 from collections import defaultdict, deque
 from dataclasses import dataclass, field
-from typing import Callable, Set, Tuple, List, Dict, Optional, DefaultDict, cast
+from typing import Set, Tuple, List, Dict, Optional, DefaultDict, cast
 from tinygrad.ops import GroupOp, UOp, Ops, PatternMatcher, UPat, Variable, can_pad, graph_rewrite, resolve, track_rewrites, sint
-from tinygrad.helpers import DEBUG, Metadata, all_int, all_same, colored, diskcache_put, prod, dedup, getenv, unwrap
+from tinygrad.helpers import DEBUG, Context, Metadata, all_int, all_same, colored, diskcache_put, prod, dedup, getenv, unwrap
 from tinygrad.dtype import ImageDType, dtypes
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.shape.view import View, strides_for_shape
@@ -91,13 +91,8 @@ def to_uop(buf:LazyBuffer, ctx:ScheduleContext, children, allbufs, double_reduce
 # ** helpers for doing movementops on uops
-def st_fixup(u:UOp, apply_to_st:Callable[[ShapeTracker], ShapeTracker], cache:Dict[UOp, UOp]) -> UOp:
-  if (n:=cache.get(u)) is not None: return n
-  if u.op is Ops.VIEW: return u.replace(arg=apply_to_st(u.arg))
-  if len(u.src) == 0 or (u.st is not None and u.st == apply_to_st(u.st)): return u
-  assert u.op is not Ops.REDUCE_AXIS, "can't push a fixup through a reduce"
-  cache[u] = ret = u.replace(src=tuple(st_fixup(x, apply_to_st, cache) for x in u.src))
-  return ret
+def apply_swizzle(u:UOp, arg:ShapeTracker) -> UOp:
+  with Context(TRACK_MATCH_STATS=0): return graph_rewrite(u.view(arg), view_left)

 def permute_reduce(input_st:ShapeTracker, axis:Tuple[int, ...]) -> Tuple[ShapeTracker, Tuple[sint, ...]]:
   permute_axis = tuple(i for i in range(len(input_st.shape)) if i not in axis)+axis
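The hunk above is the heart of the change: the hand-rolled recursive `st_fixup` is replaced by `apply_swizzle`, which wraps the subtree in a view (`u.view(arg)`) and lets `graph_rewrite` with the `view_left` pattern set push that view toward the loads, while `Context(TRACK_MATCH_STATS=0)` keeps this helper rewrite out of the match-stat counters. As a rough illustration only (a toy node type and a plain recursive pass, not tinygrad's real `UOp`/`PatternMatcher` machinery), this is what "push the view left until it folds into the leaves" looks like:

```python
# A minimal, self-contained sketch (a toy Node type, NOT tinygrad's real UOp/PatternMatcher
# API) of the idea behind apply_swizzle: instead of recursively patching every node's
# ShapeTracker (the old st_fixup), wrap the subtree in a VIEW node and rewrite until the
# view has been pushed left, i.e. folded into the loads.
from dataclasses import dataclass
from typing import Tuple

@dataclass(frozen=True)
class Node:
  op: str                               # "LOAD", "ADD" or "VIEW"
  src: Tuple["Node", ...] = ()
  view: Tuple[int, ...] = ()            # stand-in for a ShapeTracker on LOAD/VIEW nodes

def push_view_left(n: Node) -> Node:
  # not a VIEW: just recurse into the sources
  if n.op != "VIEW": return Node(n.op, tuple(push_view_left(s) for s in n.src), n.view)
  inner = n.src[0]
  # VIEW(LOAD): the view folds directly into the leaf
  if inner.op == "LOAD": return Node("LOAD", (), n.view)
  # VIEW(elementwise): push the view below the op onto each source, then keep rewriting
  return Node(inner.op, tuple(push_view_left(Node("VIEW", (s,), n.view)) for s in inner.src))

if __name__ == "__main__":
  tree = Node("ADD", (Node("LOAD", (), (4, 3)), Node("LOAD", (), (4, 3))))
  swizzled = Node("VIEW", (tree,), (3, 4))   # analogous to u.view(arg)
  print(push_view_left(swizzled))            # ADD of two LOADs that now carry the (3, 4) view
```

The payoff over `st_fixup` is that the rewrite rules already used elsewhere in scheduling decide how the view moves, instead of a second, parallel traversal that had to special-case `Ops.VIEW` and assert its way around `REDUCE_AXIS`.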
@@ -118,7 +113,7 @@ def swizzle_r(r:UOp, src:UOp, st:ShapeTracker) -> UOp:
   new_input_st = tmp + ShapeTracker(tuple(nv))
   _, new_rshape = permute_reduce(new_input_st, r.axis_arg)
   new_axis = tuple(range(len(new_input_st.shape)-len(new_rshape), len(new_input_st.shape)))
-  return st_fixup(src, lambda st:st+new_input_st, {}).r(r.arg[0], new_axis).view(ShapeTracker.from_shape(st.shape))
+  return apply_swizzle(src, new_input_st).r(r.arg[0], new_axis).view(ShapeTracker.from_shape(st.shape))

 def push_swizzle_down_through_reduce(root:UOp, swizzle:UOp, src:UOp) -> UOp:
   swizzle_st, src_st = unwrap(swizzle.st), unwrap(src.st)
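This hunk keeps `swizzle_r`'s shape bookkeeping intact and only swaps how the swizzle reaches the reduce input: `apply_swizzle(src, new_input_st)` now produces the re-viewed source, which is reduced over `new_axis` and viewed back to `st.shape`. At the Tensor level, the situation this targets is a movement op sitting on a reduce's output; a minimal reproduction might look like the following (whether this exact rule fires depends on the scheduler's grouping decisions):

```python
from tinygrad import Tensor

a = Tensor.rand(4, 6)
r = a.sum(axis=1)            # reduce -> shape (4,)
out = r.reshape(2, 2) + 1    # movement op ("swizzle") on the reduce output
out.realize()                # scheduling can move the view around the reduce via swizzle_r
```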
@@ -134,8 +129,7 @@ def push_swizzle_down_through_elementwise(root:UOp) -> Optional[UOp]:
   swizzle_shapes = [(unwrap(x.st).shape, unwrap(x.src[0].st).shape) for x in swizzles]
   assert all_same([(x, prod(x), prod(y)) for x,y in swizzle_shapes]), f"swizzles must have the same size {swizzle_shapes}"
   new_shape, new_input_shape = swizzle_shapes[0]
-  fixup_cache: Dict[UOp, UOp] = {}
-  ret = root.replace(src=tuple(x.src[0] if x in swizzles else st_fixup(x, lambda st:st.reshape(new_input_shape), fixup_cache) for x in root.src))
+  ret = root.replace(src=tuple(x.src[0] if x in swizzles else apply_swizzle(x, ShapeTracker.from_shape(new_input_shape)) for x in root.src))
   return ret if ret.op is Ops.STORE else ret.view(ShapeTracker.from_shape(new_shape))

 def merge_double_reduce(root:UOp, first_reduce:UOp) -> UOp:
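In the last hunk the per-source lambda reshape is replaced the same way: sources that already carry the swizzle are unwrapped to `x.src[0]`, and the remaining siblings get `apply_swizzle` with a contiguous ShapeTracker of `new_input_shape`, so the elementwise op is computed at the pre-swizzle shape and, unless it is a STORE, re-viewed to `new_shape` afterwards. The `all_same` assert above it guarantees every swizzle agrees on those two shapes and that they cover the same number of elements, which is what makes reshaping the siblings legal. A hedged Tensor-level sketch of the graph shape this rule targets (again, whether the rewrite actually fires depends on how the scheduler groups kernels):

```python
from tinygrad import Tensor

a = Tensor.rand(4, 6)
b = Tensor.rand(2, 2)
out = a.sum(axis=1).reshape(2, 2) + b   # swizzled reduce output feeding an elementwise ADD
out.realize()
# conceptually: the ADD can be computed at the reduce's own shape (4,), with b viewed as (4,),
# and the (2, 2) view pushed below the ADD -- prod((2, 2)) == prod((4,)) == 4 keeps this valid
```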