Merge branch 'master' into rtoposort

This commit is contained in:
George Hotz
2025-10-08 19:15:38 +08:00
committed by GitHub

View File

@@ -1,8 +1,7 @@
from typing import cast, TypeVar
from typing import cast
import functools, itertools, operator
from tinygrad.helpers import all_same, all_int, prod, DEBUG, RING, getenv, unwrap
from tinygrad.uop.ops import Ops, UOp, sint, PatternMatcher, UPat, GroupOp, resolve, track_rewrites, graph_rewrite_map
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.helpers import all_same, all_int, prod, DEBUG, RING, getenv
from tinygrad.uop.ops import Ops, UOp, sint, PatternMatcher, UPat, GroupOp, track_rewrites, graph_rewrite_map
from tinygrad.device import Device
# *** allreduce implementation ***
@@ -82,26 +81,13 @@ def handle_allreduce(buf:UOp, red:UOp) -> UOp|None:
# ***** multi rewrite MSELECT/MSTACK *****
T = TypeVar("T", bound=ShapeTracker|sint)
def _replace_dnum(st:T, val:int) -> T:
# replace dnum in ShapeTracker (or UOp) with literal const for this mselect
if not isinstance(st, int) and (dnums:=[x for x in st.vars() if x.op is Ops.DEFINE_VAR and x.arg[0] == '_device_num']):
assert len(dnums) == 1, f"view must have exactly 0 or 1 dnum, got {dnums}"
st = st.substitute({dnums[0]:dnums[0].const_like(val)})
return st
def mstack_reorder_view(ms:UOp):
args = [x.arg for x in ms.src]
if not all_same(args) or len([x for x in args[0].vars() if x.arg[0] == '_device_num']) != 0: return None
return UOp(Ops.MSTACK, ms.dtype, tuple(x.src[0] for x in ms.src)).view(args[0])
# NOTE: view path is for RANGEIFY=0, there should only be one way of doing this
def mstack_early_shrink(ms:UOp, view:UOp|None=None, shrink:UOp|None=None):
if view is not None and (resolve(prod(view.shape) >= prod(ms.shape)) or _replace_dnum(unwrap(view.st), 0) == view.st): return None
ret = []
def mstack_early_shrink(ms:UOp, shrink:UOp):
ret:list[UOp] = []
def apply_shrink(s:UOp, i:int) -> UOp:
if view is not None: return s.view(_replace_dnum(unwrap(view.st), i))
return s.shrink(tuple(tuple(_replace_dnum(x, i) for x in ss) for ss in unwrap(shrink).arg))
new_arg = [tuple([x.substitute({dvar[0]:dvar[0].const_like(i)}) if isinstance(x, UOp) and
(dvar:=[v for v in x.vars() if v.op is Ops.DEFINE_VAR and v.arg[0]=='_device_num']) else x for x in ss]) for ss in shrink.arg]
return s.shrink(tuple(new_arg))
for i, x in enumerate(ms.src):
if x.op is Ops.COPY:
# if src device doesn't have a renderer, we have to view after the copy
@@ -125,14 +111,6 @@ replace_allreduce = PatternMatcher([
x.mselect(0).copy_to_device(c.device) if isinstance(c.device, str) and isinstance(x.device, tuple) else None),
# MSELECT on MSTACK is replaced with nothing
(UPat(Ops.MSELECT, src=(UPat(Ops.MSTACK, name="mstack"),), name="ms"), lambda mstack, ms: mstack.src[ms.arg]),
# MSELECT must select a base, if there are views apply them after selecting the base
(UPat(Ops.MSELECT, src=(UPat(Ops.VIEW, src=(UPat.var("base"),), name="view"),), name="ms"), lambda ms, view, base:
base.mselect(ms.arg).view(_replace_dnum(unwrap(view.st), ms.arg))),
# move view through MSTACK
(UPat(Ops.MSTACK, src=UPat(Ops.VIEW), name="ms"), mstack_reorder_view),
# move shrink before MSTACK
(UPat(Ops.VIEW, src=(UPat(Ops.MSTACK, name="ms"),), name="view"), mstack_early_shrink),
# *** new movement ops reordering
# move shrink before MSTACK
(UPat(Ops.SHRINK, src=(UPat(Ops.MSTACK, name="ms"),), name="shrink"), mstack_early_shrink),
# move MSELECT before movement ops