mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
swizzle parents with graph rewrite (#7625)
* delete st_fixup * refactor * minimal diff
This commit is contained in:
@@ -1,9 +1,9 @@
|
|||||||
import sys, atexit, functools, itertools
|
import sys, atexit, functools, itertools
|
||||||
from collections import defaultdict, deque
|
from collections import defaultdict, deque
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Callable, Set, Tuple, List, Dict, Optional, DefaultDict, cast
|
from typing import Set, Tuple, List, Dict, Optional, DefaultDict, cast
|
||||||
from tinygrad.ops import GroupOp, UOp, Ops, PatternMatcher, UPat, Variable, can_pad, graph_rewrite, resolve, track_rewrites, sint
|
from tinygrad.ops import GroupOp, UOp, Ops, PatternMatcher, UPat, Variable, can_pad, graph_rewrite, resolve, track_rewrites, sint
|
||||||
from tinygrad.helpers import DEBUG, Metadata, all_int, all_same, colored, diskcache_put, prod, dedup, getenv, unwrap
|
from tinygrad.helpers import DEBUG, Context, Metadata, all_int, all_same, colored, diskcache_put, prod, dedup, getenv, unwrap
|
||||||
from tinygrad.dtype import ImageDType, dtypes
|
from tinygrad.dtype import ImageDType, dtypes
|
||||||
from tinygrad.shape.shapetracker import ShapeTracker
|
from tinygrad.shape.shapetracker import ShapeTracker
|
||||||
from tinygrad.shape.view import View, strides_for_shape
|
from tinygrad.shape.view import View, strides_for_shape
|
||||||
@@ -91,13 +91,8 @@ def to_uop(buf:LazyBuffer, ctx:ScheduleContext, children, allbufs, double_reduce
|
|||||||
|
|
||||||
# ** helpers for doing movementops on uops
|
# ** helpers for doing movementops on uops
|
||||||
|
|
||||||
def st_fixup(u:UOp, apply_to_st:Callable[[ShapeTracker], ShapeTracker], cache:Dict[UOp, UOp]) -> UOp:
|
def apply_swizzle(u:UOp, arg:ShapeTracker) -> UOp:
|
||||||
if (n:=cache.get(u)) is not None: return n
|
with Context(TRACK_MATCH_STATS=0): return graph_rewrite(u.view(arg), view_left)
|
||||||
if u.op is Ops.VIEW: return u.replace(arg=apply_to_st(u.arg))
|
|
||||||
if len(u.src) == 0 or (u.st is not None and u.st == apply_to_st(u.st)): return u
|
|
||||||
assert u.op is not Ops.REDUCE_AXIS, "can't push a fixup through a reduce"
|
|
||||||
cache[u] = ret = u.replace(src=tuple(st_fixup(x, apply_to_st, cache) for x in u.src))
|
|
||||||
return ret
|
|
||||||
|
|
||||||
def permute_reduce(input_st:ShapeTracker, axis:Tuple[int, ...]) -> Tuple[ShapeTracker, Tuple[sint, ...]]:
|
def permute_reduce(input_st:ShapeTracker, axis:Tuple[int, ...]) -> Tuple[ShapeTracker, Tuple[sint, ...]]:
|
||||||
permute_axis = tuple(i for i in range(len(input_st.shape)) if i not in axis)+axis
|
permute_axis = tuple(i for i in range(len(input_st.shape)) if i not in axis)+axis
|
||||||
@@ -118,7 +113,7 @@ def swizzle_r(r:UOp, src:UOp, st:ShapeTracker) -> UOp:
|
|||||||
new_input_st = tmp + ShapeTracker(tuple(nv))
|
new_input_st = tmp + ShapeTracker(tuple(nv))
|
||||||
_, new_rshape = permute_reduce(new_input_st, r.axis_arg)
|
_, new_rshape = permute_reduce(new_input_st, r.axis_arg)
|
||||||
new_axis = tuple(range(len(new_input_st.shape)-len(new_rshape), len(new_input_st.shape)))
|
new_axis = tuple(range(len(new_input_st.shape)-len(new_rshape), len(new_input_st.shape)))
|
||||||
return st_fixup(src, lambda st:st+new_input_st, {}).r(r.arg[0], new_axis).view(ShapeTracker.from_shape(st.shape))
|
return apply_swizzle(src, new_input_st).r(r.arg[0], new_axis).view(ShapeTracker.from_shape(st.shape))
|
||||||
|
|
||||||
def push_swizzle_down_through_reduce(root:UOp, swizzle:UOp, src:UOp) -> UOp:
|
def push_swizzle_down_through_reduce(root:UOp, swizzle:UOp, src:UOp) -> UOp:
|
||||||
swizzle_st, src_st = unwrap(swizzle.st), unwrap(src.st)
|
swizzle_st, src_st = unwrap(swizzle.st), unwrap(src.st)
|
||||||
@@ -134,8 +129,7 @@ def push_swizzle_down_through_elementwise(root:UOp) -> Optional[UOp]:
|
|||||||
swizzle_shapes = [(unwrap(x.st).shape, unwrap(x.src[0].st).shape) for x in swizzles]
|
swizzle_shapes = [(unwrap(x.st).shape, unwrap(x.src[0].st).shape) for x in swizzles]
|
||||||
assert all_same([(x, prod(x), prod(y)) for x,y in swizzle_shapes]), f"swizzles must have the same size {swizzle_shapes}"
|
assert all_same([(x, prod(x), prod(y)) for x,y in swizzle_shapes]), f"swizzles must have the same size {swizzle_shapes}"
|
||||||
new_shape, new_input_shape = swizzle_shapes[0]
|
new_shape, new_input_shape = swizzle_shapes[0]
|
||||||
fixup_cache: Dict[UOp, UOp] = {}
|
ret = root.replace(src=tuple(x.src[0] if x in swizzles else apply_swizzle(x, ShapeTracker.from_shape(new_input_shape)) for x in root.src))
|
||||||
ret = root.replace(src=tuple(x.src[0] if x in swizzles else st_fixup(x, lambda st:st.reshape(new_input_shape), fixup_cache) for x in root.src))
|
|
||||||
return ret if ret.op is Ops.STORE else ret.view(ShapeTracker.from_shape(new_shape))
|
return ret if ret.op is Ops.STORE else ret.view(ShapeTracker.from_shape(new_shape))
|
||||||
|
|
||||||
def merge_double_reduce(root:UOp, first_reduce:UOp) -> UOp:
|
def merge_double_reduce(root:UOp, first_reduce:UOp) -> UOp:
|
||||||
|
|||||||
Reference in New Issue
Block a user