rangeify: fix ram usage in multi (#12286)

This commit is contained in:
qazal
2025-09-24 13:48:58 +03:00
committed by GitHub
parent e8945c74de
commit 154c865966
2 changed files with 19 additions and 11 deletions

View File

@@ -531,7 +531,8 @@ jobs:
run: CPU=1 RANGEIFY=1 python3 -m pytest -n auto --durations 20 test/test_const_folding.py -k "not test_cast_padded and not TestReduceOpsConstFolding and not TestMultiConstFolding"
- name: Test multitensor
run: |
CPU=1 RANGEIFY=1 python3 test/test_multitensor.py TestMultiTensor.test_matmul_shard_1_1 TestMultiTensor.test_simple_add_W TestMultiTensor.test_simple_reduce
CPU=1 RANGEIFY=1 python3 test/test_multitensor.py TestMultiTensor.test_matmul_shard_1_1 TestMultiTensor.test_simple_add_W TestMultiTensor.test_simple_reduce \
TestMultiTensor.test_elementwise_dtype TestMultiTensor.test_shard_no_recompile
CPU=1 RANGEIFY=1 python3 -m pytest test/test_multitensor.py::TestMultiAssign -k 'not (multi_assign_piece_noncontig or multi_assign_var_offset)'
- name: Test CPU=1 RANGEIFY=2
run: CPU=1 CPU_LLVM=0 RANGEIFY=2 python3 -m pytest -n auto test/test_tiny.py test/test_rangeify.py test/test_ops.py --durations 20

View File

@@ -1,4 +1,4 @@
from typing import cast
from typing import cast, TypeVar
import functools, itertools, operator
from tinygrad.helpers import all_same, all_int, prod, DEBUG, RING, getenv, unwrap
from tinygrad.uop.ops import Ops, UOp, sint, PatternMatcher, UPat, GroupOp, resolve
@@ -82,9 +82,10 @@ def handle_allreduce(buf:UOp, red:UOp) -> UOp|None:
# ***** multi rewrite MSELECT/MSTACK *****
# Rendering note: this hunk is shown without +/- markers or indentation. Per the hunk header (-82,9 +82,10),
# the next three lines (signature, comment, condition) are the removed versions; the four lines after them replace them.
def _replace_dnum(st:ShapeTracker, val:int) -> ShapeTracker:
# replace dnum in ShapeTracker with literal const for this mselect
if (dnums:=[x for x in st.vars() if x.op is Ops.DEFINE_VAR and x.arg[0] == '_device_num']):
# T links the return annotation to the argument annotation: the declared bound is ShapeTracker|sint
T = TypeVar("T", bound=ShapeTracker|sint)
def _replace_dnum(st:T, val:int) -> T:
# replace dnum in ShapeTracker (or UOp) with literal const for this mselect
# an int st short-circuits here before st.vars() is called; otherwise collect the DEFINE_VAR variables whose arg[0] is '_device_num'
if not isinstance(st, int) and (dnums:=[x for x in st.vars() if x.op is Ops.DEFINE_VAR and x.arg[0] == '_device_num']):
# more than one matching variable is treated as a programming error
assert len(dnums) == 1, f"view must have exactly 0 or 1 dnum, got {dnums}"
# substitute the single matching variable with dnums[0].const_like(val)
st = st.substitute({dnums[0]:dnums[0].const_like(val)})
# st comes back untouched when the condition above was falsy (int input, or no '_device_num' variable found)
return st
@@ -94,20 +95,23 @@ def mstack_reorder_view(ms:UOp):
if not all_same(args) or len([x for x in args[0].vars() if x.arg[0] == '_device_num']) != 0: return None
return UOp(Ops.MSTACK, ms.dtype, tuple(x.src[0] for x in ms.src)).view(args[0])
# Rendering note: this hunk is shown without +/- markers or indentation. Per the hunk header (-94,20 +95,23),
# the removed lines are the first signature and its guard, the `new_view = ...` assignment, and every
# `ret.append(...view(new_view)...)` line; each of those appends is directly followed by its `apply_shrink` replacement.
def mstack_early_shrink(view:UOp, ms:UOp):
if resolve(prod(view.shape) >= prod(ms.shape)) or _replace_dnum(unwrap(view.st), 0) == view.st: return None
# NOTE: view path is for RANGEIFY=0, there should only be one way of doing this
# view and shrink both default to None; apply_shrink below uses view when it is given and unwrap(shrink) otherwise
def mstack_early_shrink(ms:UOp, view:UOp|None=None, shrink:UOp|None=None):
# view path only: return None (no rewrite) when prod(view.shape) >= prod(ms.shape) resolves true,
# or when substituting device number 0 leaves the view's ShapeTracker equal to the original
if view is not None and (resolve(prod(view.shape) >= prod(ms.shape)) or _replace_dnum(unwrap(view.st), 0) == view.st): return None
ret = []
# apply_shrink(s, i): apply the movement op to s with the device-number variable replaced by i
def apply_shrink(s:UOp, i:int) -> UOp:
# view path: reuse the view's ShapeTracker with the device number substituted
if view is not None: return s.view(_replace_dnum(unwrap(view.st), i))
# shrink path: rebuild shrink.arg as a tuple of tuples, running _replace_dnum on every element
return s.shrink(tuple(tuple(_replace_dnum(x, i) for x in ss) for ss in unwrap(shrink).arg))
# i is the position of x in ms.src and is the value substituted for the device number
for i, x in enumerate(ms.src):
new_view = _replace_dnum(unwrap(view.st), i)
if x.op is Ops.COPY:
# if src device doesn't have a renderer, we have to view after the copy
# TODO: a way to understand this
if x.src[0].device in {"DISK", "NPY"}:
ret.append(x.view(new_view))
ret.append(apply_shrink(x, i))
else:
# otherwise apply the movement op to the COPY's source, then copy that result to x.device
ret.append(x.src[0].view(new_view).copy_to_device(x.device))
ret.append(apply_shrink(x.src[0], i).copy_to_device(x.device))
else:
# non-COPY source: apply the movement op directly, followed by .contiguous()
ret.append(x.view(new_view).contiguous())
ret.append(apply_shrink(x, i).contiguous())
# same MSTACK node, with its sources replaced by the rewritten ones
return ms.replace(src=tuple(ret))
replace_allreduce = PatternMatcher([
@@ -128,6 +132,9 @@ replace_allreduce = PatternMatcher([
(UPat(Ops.MSTACK, src=UPat(Ops.VIEW), name="ms"), mstack_reorder_view),
# move shrink before MSTACK
(UPat(Ops.VIEW, src=(UPat(Ops.MSTACK, name="ms"),), name="view"), mstack_early_shrink),
# *** new movement ops reordering
# move shrink before MSTACK
(UPat(Ops.SHRINK, src=(UPat(Ops.MSTACK, name="ms"),), name="shrink"), mstack_early_shrink),
])
# ***** multi functions *****