mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-07 22:23:55 -05:00
swizzle parents with graph rewrite (#7625)
* delete st_fixup
* refactor
* minimal diff
This commit is contained in:
@@ -1,9 +1,9 @@
|
||||
import sys, atexit, functools, itertools
|
||||
from collections import defaultdict, deque
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Callable, Set, Tuple, List, Dict, Optional, DefaultDict, cast
|
||||
from typing import Set, Tuple, List, Dict, Optional, DefaultDict, cast
|
||||
from tinygrad.ops import GroupOp, UOp, Ops, PatternMatcher, UPat, Variable, can_pad, graph_rewrite, resolve, track_rewrites, sint
|
||||
from tinygrad.helpers import DEBUG, Metadata, all_int, all_same, colored, diskcache_put, prod, dedup, getenv, unwrap
|
||||
from tinygrad.helpers import DEBUG, Context, Metadata, all_int, all_same, colored, diskcache_put, prod, dedup, getenv, unwrap
|
||||
from tinygrad.dtype import ImageDType, dtypes
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.shape.view import View, strides_for_shape
|
||||
@@ -91,13 +91,8 @@ def to_uop(buf:LazyBuffer, ctx:ScheduleContext, children, allbufs, double_reduce
|
||||
|
||||
# ** helpers for doing movementops on uops
|
||||
|
||||
def st_fixup(u:UOp, apply_to_st:Callable[[ShapeTracker], ShapeTracker], cache:Dict[UOp, UOp]) -> UOp:
  """Rebuild the graph under u with apply_to_st applied to its ShapeTrackers.

  VIEW uops get their arg transformed directly; a uop whose st is unchanged by
  apply_to_st (or that has no sources) is returned as-is. Rebuilt nodes are
  memoized in cache keyed by the original uop.
  """
  if (hit:=cache.get(u)) is not None: return hit
  # a VIEW carries the ShapeTracker as its arg; transform it in place
  if u.op is Ops.VIEW: return u.replace(arg=apply_to_st(u.arg))
  # leaves, and nodes the fixup doesn't change, pass through untouched
  if not u.src or (u.st is not None and u.st == apply_to_st(u.st)): return u
  assert u.op is not Ops.REDUCE_AXIS, "can't push a fixup through a reduce"
  new_src = tuple(st_fixup(s, apply_to_st, cache) for s in u.src)
  cache[u] = fixed = u.replace(src=new_src)
  return fixed
|
||||
def apply_swizzle(u:UOp, arg:ShapeTracker) -> UOp:
  """View u through arg and push the resulting VIEW left with graph_rewrite.

  Match-stat tracking is disabled for this internal rewrite.
  """
  with Context(TRACK_MATCH_STATS=0):
    return graph_rewrite(u.view(arg), view_left)
|
||||
|
||||
def permute_reduce(input_st:ShapeTracker, axis:Tuple[int, ...]) -> Tuple[ShapeTracker, Tuple[sint, ...]]:
|
||||
permute_axis = tuple(i for i in range(len(input_st.shape)) if i not in axis)+axis
|
||||
@@ -118,7 +113,7 @@ def swizzle_r(r:UOp, src:UOp, st:ShapeTracker) -> UOp:
|
||||
new_input_st = tmp + ShapeTracker(tuple(nv))
|
||||
_, new_rshape = permute_reduce(new_input_st, r.axis_arg)
|
||||
new_axis = tuple(range(len(new_input_st.shape)-len(new_rshape), len(new_input_st.shape)))
|
||||
return st_fixup(src, lambda st:st+new_input_st, {}).r(r.arg[0], new_axis).view(ShapeTracker.from_shape(st.shape))
|
||||
return apply_swizzle(src, new_input_st).r(r.arg[0], new_axis).view(ShapeTracker.from_shape(st.shape))
|
||||
|
||||
def push_swizzle_down_through_reduce(root:UOp, swizzle:UOp, src:UOp) -> UOp:
|
||||
swizzle_st, src_st = unwrap(swizzle.st), unwrap(src.st)
|
||||
@@ -134,8 +129,7 @@ def push_swizzle_down_through_elementwise(root:UOp) -> Optional[UOp]:
|
||||
swizzle_shapes = [(unwrap(x.st).shape, unwrap(x.src[0].st).shape) for x in swizzles]
|
||||
assert all_same([(x, prod(x), prod(y)) for x,y in swizzle_shapes]), f"swizzles must have the same size {swizzle_shapes}"
|
||||
new_shape, new_input_shape = swizzle_shapes[0]
|
||||
fixup_cache: Dict[UOp, UOp] = {}
|
||||
ret = root.replace(src=tuple(x.src[0] if x in swizzles else st_fixup(x, lambda st:st.reshape(new_input_shape), fixup_cache) for x in root.src))
|
||||
ret = root.replace(src=tuple(x.src[0] if x in swizzles else apply_swizzle(x, ShapeTracker.from_shape(new_input_shape)) for x in root.src))
|
||||
return ret if ret.op is Ops.STORE else ret.view(ShapeTracker.from_shape(new_shape))
|
||||
|
||||
def merge_double_reduce(root:UOp, first_reduce:UOp) -> UOp:
|
||||
|
||||
Reference in New Issue
Block a user