mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
Merge branch 'master' into rtoposort
This commit is contained in:
@@ -1,8 +1,7 @@
|
||||
from typing import cast, TypeVar
|
||||
from typing import cast
|
||||
import functools, itertools, operator
|
||||
from tinygrad.helpers import all_same, all_int, prod, DEBUG, RING, getenv, unwrap
|
||||
from tinygrad.uop.ops import Ops, UOp, sint, PatternMatcher, UPat, GroupOp, resolve, track_rewrites, graph_rewrite_map
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.helpers import all_same, all_int, prod, DEBUG, RING, getenv
|
||||
from tinygrad.uop.ops import Ops, UOp, sint, PatternMatcher, UPat, GroupOp, track_rewrites, graph_rewrite_map
|
||||
from tinygrad.device import Device
|
||||
|
||||
# *** allreduce implementation ***
|
||||
@@ -82,26 +81,13 @@ def handle_allreduce(buf:UOp, red:UOp) -> UOp|None:
|
||||
|
||||
# ***** multi rewrite MSELECT/MSTACK *****
|
||||
|
||||
T = TypeVar("T", bound=ShapeTracker|sint)
|
||||
def _replace_dnum(st:T, val:int) -> T:
|
||||
# replace dnum in ShapeTracker (or UOp) with literal const for this mselect
|
||||
if not isinstance(st, int) and (dnums:=[x for x in st.vars() if x.op is Ops.DEFINE_VAR and x.arg[0] == '_device_num']):
|
||||
assert len(dnums) == 1, f"view must have exactly 0 or 1 dnum, got {dnums}"
|
||||
st = st.substitute({dnums[0]:dnums[0].const_like(val)})
|
||||
return st
|
||||
|
||||
def mstack_reorder_view(ms:UOp):
|
||||
args = [x.arg for x in ms.src]
|
||||
if not all_same(args) or len([x for x in args[0].vars() if x.arg[0] == '_device_num']) != 0: return None
|
||||
return UOp(Ops.MSTACK, ms.dtype, tuple(x.src[0] for x in ms.src)).view(args[0])
|
||||
|
||||
# NOTE: view path is for RANGEIFY=0, there should only be one way of doing this
|
||||
def mstack_early_shrink(ms:UOp, view:UOp|None=None, shrink:UOp|None=None):
|
||||
if view is not None and (resolve(prod(view.shape) >= prod(ms.shape)) or _replace_dnum(unwrap(view.st), 0) == view.st): return None
|
||||
ret = []
|
||||
def mstack_early_shrink(ms:UOp, shrink:UOp):
|
||||
ret:list[UOp] = []
|
||||
def apply_shrink(s:UOp, i:int) -> UOp:
|
||||
if view is not None: return s.view(_replace_dnum(unwrap(view.st), i))
|
||||
return s.shrink(tuple(tuple(_replace_dnum(x, i) for x in ss) for ss in unwrap(shrink).arg))
|
||||
new_arg = [tuple([x.substitute({dvar[0]:dvar[0].const_like(i)}) if isinstance(x, UOp) and
|
||||
(dvar:=[v for v in x.vars() if v.op is Ops.DEFINE_VAR and v.arg[0]=='_device_num']) else x for x in ss]) for ss in shrink.arg]
|
||||
return s.shrink(tuple(new_arg))
|
||||
for i, x in enumerate(ms.src):
|
||||
if x.op is Ops.COPY:
|
||||
# if src device doesn't have a renderer, we have to view after the copy
|
||||
@@ -125,14 +111,6 @@ replace_allreduce = PatternMatcher([
|
||||
x.mselect(0).copy_to_device(c.device) if isinstance(c.device, str) and isinstance(x.device, tuple) else None),
|
||||
# MSELECT on MSTACK is replaced with nothing
|
||||
(UPat(Ops.MSELECT, src=(UPat(Ops.MSTACK, name="mstack"),), name="ms"), lambda mstack, ms: mstack.src[ms.arg]),
|
||||
# MSELECT must select a base, if there are views apply them after selecting the base
|
||||
(UPat(Ops.MSELECT, src=(UPat(Ops.VIEW, src=(UPat.var("base"),), name="view"),), name="ms"), lambda ms, view, base:
|
||||
base.mselect(ms.arg).view(_replace_dnum(unwrap(view.st), ms.arg))),
|
||||
# move view through MSTACK
|
||||
(UPat(Ops.MSTACK, src=UPat(Ops.VIEW), name="ms"), mstack_reorder_view),
|
||||
# move shrink before MSTACK
|
||||
(UPat(Ops.VIEW, src=(UPat(Ops.MSTACK, name="ms"),), name="view"), mstack_early_shrink),
|
||||
# *** new movement ops reordering
|
||||
# move shrink before MSTACK
|
||||
(UPat(Ops.SHRINK, src=(UPat(Ops.MSTACK, name="ms"),), name="shrink"), mstack_early_shrink),
|
||||
# move MSELECT before movement ops
|
||||
|
||||
Reference in New Issue
Block a user