mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 23:48:01 -05:00
rangeify: fix ram usage in multi (#12286)
This commit is contained in:
3
.github/workflows/test.yml
vendored
3
.github/workflows/test.yml
vendored
@@ -531,7 +531,8 @@ jobs:
|
||||
run: CPU=1 RANGEIFY=1 python3 -m pytest -n auto --durations 20 test/test_const_folding.py -k "not test_cast_padded and not TestReduceOpsConstFolding and not TestMultiConstFolding"
|
||||
- name: Test multitensor
|
||||
run: |
|
||||
CPU=1 RANGEIFY=1 python3 test/test_multitensor.py TestMultiTensor.test_matmul_shard_1_1 TestMultiTensor.test_simple_add_W TestMultiTensor.test_simple_reduce
|
||||
CPU=1 RANGEIFY=1 python3 test/test_multitensor.py TestMultiTensor.test_matmul_shard_1_1 TestMultiTensor.test_simple_add_W TestMultiTensor.test_simple_reduce \
|
||||
TestMultiTensor.test_elementwise_dtype TestMultiTensor.test_shard_no_recompile
|
||||
CPU=1 RANGEIFY=1 python3 -m pytest test/test_multitensor.py::TestMultiAssign -k 'not (multi_assign_piece_noncontig or multi_assign_var_offset)'
|
||||
- name: Test CPU=1 RANGEIFY=2
|
||||
run: CPU=1 CPU_LLVM=0 RANGEIFY=2 python3 -m pytest -n auto test/test_tiny.py test/test_rangeify.py test/test_ops.py --durations 20
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import cast
|
||||
from typing import cast, TypeVar
|
||||
import functools, itertools, operator
|
||||
from tinygrad.helpers import all_same, all_int, prod, DEBUG, RING, getenv, unwrap
|
||||
from tinygrad.uop.ops import Ops, UOp, sint, PatternMatcher, UPat, GroupOp, resolve
|
||||
@@ -82,9 +82,10 @@ def handle_allreduce(buf:UOp, red:UOp) -> UOp|None:
|
||||
|
||||
# ***** multi rewrite MSELECT/MSTACK *****
|
||||
|
||||
def _replace_dnum(st:ShapeTracker, val:int) -> ShapeTracker:
|
||||
# replace dnum in ShapeTracker with literal const for this mselect
|
||||
if (dnums:=[x for x in st.vars() if x.op is Ops.DEFINE_VAR and x.arg[0] == '_device_num']):
|
||||
T = TypeVar("T", bound=ShapeTracker|sint)
|
||||
def _replace_dnum(st:T, val:int) -> T:
|
||||
# replace dnum in ShapeTracker (or UOp) with literal const for this mselect
|
||||
if not isinstance(st, int) and (dnums:=[x for x in st.vars() if x.op is Ops.DEFINE_VAR and x.arg[0] == '_device_num']):
|
||||
assert len(dnums) == 1, f"view must have exactly 0 or 1 dnum, got {dnums}"
|
||||
st = st.substitute({dnums[0]:dnums[0].const_like(val)})
|
||||
return st
|
||||
@@ -94,20 +95,23 @@ def mstack_reorder_view(ms:UOp):
|
||||
if not all_same(args) or len([x for x in args[0].vars() if x.arg[0] == '_device_num']) != 0: return None
|
||||
return UOp(Ops.MSTACK, ms.dtype, tuple(x.src[0] for x in ms.src)).view(args[0])
|
||||
|
||||
def mstack_early_shrink(view:UOp, ms:UOp):
|
||||
if resolve(prod(view.shape) >= prod(ms.shape)) or _replace_dnum(unwrap(view.st), 0) == view.st: return None
|
||||
# NOTE: view path is for RANGEIFY=0, there should only be one way of doing this
|
||||
def mstack_early_shrink(ms:UOp, view:UOp|None=None, shrink:UOp|None=None):
|
||||
if view is not None and (resolve(prod(view.shape) >= prod(ms.shape)) or _replace_dnum(unwrap(view.st), 0) == view.st): return None
|
||||
ret = []
|
||||
def apply_shrink(s:UOp, i:int) -> UOp:
|
||||
if view is not None: return s.view(_replace_dnum(unwrap(view.st), i))
|
||||
return s.shrink(tuple(tuple(_replace_dnum(x, i) for x in ss) for ss in unwrap(shrink).arg))
|
||||
for i, x in enumerate(ms.src):
|
||||
new_view = _replace_dnum(unwrap(view.st), i)
|
||||
if x.op is Ops.COPY:
|
||||
# if src device doesn't have a renderer, we have to view after the copy
|
||||
# TODO: a way to understand this
|
||||
if x.src[0].device in {"DISK", "NPY"}:
|
||||
ret.append(x.view(new_view))
|
||||
ret.append(apply_shrink(x, i))
|
||||
else:
|
||||
ret.append(x.src[0].view(new_view).copy_to_device(x.device))
|
||||
ret.append(apply_shrink(x.src[0], i).copy_to_device(x.device))
|
||||
else:
|
||||
ret.append(x.view(new_view).contiguous())
|
||||
ret.append(apply_shrink(x, i).contiguous())
|
||||
return ms.replace(src=tuple(ret))
|
||||
|
||||
replace_allreduce = PatternMatcher([
|
||||
@@ -128,6 +132,9 @@ replace_allreduce = PatternMatcher([
|
||||
(UPat(Ops.MSTACK, src=UPat(Ops.VIEW), name="ms"), mstack_reorder_view),
|
||||
# move shrink before MSTACK
|
||||
(UPat(Ops.VIEW, src=(UPat(Ops.MSTACK, name="ms"),), name="view"), mstack_early_shrink),
|
||||
# *** new movement ops reordering
|
||||
# move shrink before MSTACK
|
||||
(UPat(Ops.SHRINK, src=(UPat(Ops.MSTACK, name="ms"),), name="shrink"), mstack_early_shrink),
|
||||
])
|
||||
|
||||
# ***** multi functions *****
|
||||
|
||||
Reference in New Issue
Block a user