swizzle parents with graph rewrite (#7625)

* delete st_fixup

* refactor

* minimal diff
This commit is contained in:
qazal
2024-11-11 10:50:38 +02:00
committed by GitHub
parent fec977b966
commit 766a680588

View File

@@ -1,9 +1,9 @@
import sys, atexit, functools, itertools import sys, atexit, functools, itertools
from collections import defaultdict, deque from collections import defaultdict, deque
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Callable, Set, Tuple, List, Dict, Optional, DefaultDict, cast from typing import Set, Tuple, List, Dict, Optional, DefaultDict, cast
from tinygrad.ops import GroupOp, UOp, Ops, PatternMatcher, UPat, Variable, can_pad, graph_rewrite, resolve, track_rewrites, sint from tinygrad.ops import GroupOp, UOp, Ops, PatternMatcher, UPat, Variable, can_pad, graph_rewrite, resolve, track_rewrites, sint
from tinygrad.helpers import DEBUG, Metadata, all_int, all_same, colored, diskcache_put, prod, dedup, getenv, unwrap from tinygrad.helpers import DEBUG, Context, Metadata, all_int, all_same, colored, diskcache_put, prod, dedup, getenv, unwrap
from tinygrad.dtype import ImageDType, dtypes from tinygrad.dtype import ImageDType, dtypes
from tinygrad.shape.shapetracker import ShapeTracker from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View, strides_for_shape from tinygrad.shape.view import View, strides_for_shape
@@ -91,13 +91,8 @@ def to_uop(buf:LazyBuffer, ctx:ScheduleContext, children, allbufs, double_reduce
# ** helpers for doing movementops on uops # ** helpers for doing movementops on uops
def st_fixup(u:UOp, apply_to_st:Callable[[ShapeTracker], ShapeTracker], cache:Dict[UOp, UOp]) -> UOp: def apply_swizzle(u:UOp, arg:ShapeTracker) -> UOp:
if (n:=cache.get(u)) is not None: return n with Context(TRACK_MATCH_STATS=0): return graph_rewrite(u.view(arg), view_left)
if u.op is Ops.VIEW: return u.replace(arg=apply_to_st(u.arg))
if len(u.src) == 0 or (u.st is not None and u.st == apply_to_st(u.st)): return u
assert u.op is not Ops.REDUCE_AXIS, "can't push a fixup through a reduce"
cache[u] = ret = u.replace(src=tuple(st_fixup(x, apply_to_st, cache) for x in u.src))
return ret
def permute_reduce(input_st:ShapeTracker, axis:Tuple[int, ...]) -> Tuple[ShapeTracker, Tuple[sint, ...]]: def permute_reduce(input_st:ShapeTracker, axis:Tuple[int, ...]) -> Tuple[ShapeTracker, Tuple[sint, ...]]:
permute_axis = tuple(i for i in range(len(input_st.shape)) if i not in axis)+axis permute_axis = tuple(i for i in range(len(input_st.shape)) if i not in axis)+axis
@@ -118,7 +113,7 @@ def swizzle_r(r:UOp, src:UOp, st:ShapeTracker) -> UOp:
new_input_st = tmp + ShapeTracker(tuple(nv)) new_input_st = tmp + ShapeTracker(tuple(nv))
_, new_rshape = permute_reduce(new_input_st, r.axis_arg) _, new_rshape = permute_reduce(new_input_st, r.axis_arg)
new_axis = tuple(range(len(new_input_st.shape)-len(new_rshape), len(new_input_st.shape))) new_axis = tuple(range(len(new_input_st.shape)-len(new_rshape), len(new_input_st.shape)))
return st_fixup(src, lambda st:st+new_input_st, {}).r(r.arg[0], new_axis).view(ShapeTracker.from_shape(st.shape)) return apply_swizzle(src, new_input_st).r(r.arg[0], new_axis).view(ShapeTracker.from_shape(st.shape))
def push_swizzle_down_through_reduce(root:UOp, swizzle:UOp, src:UOp) -> UOp: def push_swizzle_down_through_reduce(root:UOp, swizzle:UOp, src:UOp) -> UOp:
swizzle_st, src_st = unwrap(swizzle.st), unwrap(src.st) swizzle_st, src_st = unwrap(swizzle.st), unwrap(src.st)
@@ -134,8 +129,7 @@ def push_swizzle_down_through_elementwise(root:UOp) -> Optional[UOp]:
swizzle_shapes = [(unwrap(x.st).shape, unwrap(x.src[0].st).shape) for x in swizzles] swizzle_shapes = [(unwrap(x.st).shape, unwrap(x.src[0].st).shape) for x in swizzles]
assert all_same([(x, prod(x), prod(y)) for x,y in swizzle_shapes]), f"swizzles must have the same size {swizzle_shapes}" assert all_same([(x, prod(x), prod(y)) for x,y in swizzle_shapes]), f"swizzles must have the same size {swizzle_shapes}"
new_shape, new_input_shape = swizzle_shapes[0] new_shape, new_input_shape = swizzle_shapes[0]
fixup_cache: Dict[UOp, UOp] = {} ret = root.replace(src=tuple(x.src[0] if x in swizzles else apply_swizzle(x, ShapeTracker.from_shape(new_input_shape)) for x in root.src))
ret = root.replace(src=tuple(x.src[0] if x in swizzles else st_fixup(x, lambda st:st.reshape(new_input_shape), fixup_cache) for x in root.src))
return ret if ret.op is Ops.STORE else ret.view(ShapeTracker.from_shape(new_shape)) return ret if ret.op is Ops.STORE else ret.view(ShapeTracker.from_shape(new_shape))
def merge_double_reduce(root:UOp, first_reduce:UOp) -> UOp: def merge_double_reduce(root:UOp, first_reduce:UOp) -> UOp: