multi test case for sharded ring allreduce (#12462)

* multi test case for sharded ring allreduce

triggers `children not making progress` with RANGEIFY

* expect_rangeify_fails
chenyu authored 2025-10-06 11:18:24 +08:00, committed by GitHub
parent 1823a5043f
commit c1e85f699c
3 changed files with 15 additions and 6 deletions
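The new test exercises sharded allreduce under both the naive path (Context(RING=0)) and the forced ring path (RING=2). For reference, a ring allreduce is the usual reduce-scatter followed by all-gather; below is a minimal pure-numpy sketch of that pattern, not tinygrad's implementation, with the device loops standing in for real peer-to-peer copies.

import numpy as np

def ring_allreduce(inputs):
  # inputs: one equal-shape array per simulated device; returns what each device ends up holding
  n = len(inputs)
  shape = inputs[0].shape
  chunks = [np.array_split(x.ravel().astype(np.float64), n) for x in inputs]
  # reduce-scatter: chunk c circulates the ring starting at device c, picking up every device's contribution
  for c in range(n):
    acc = chunks[c][c].copy()
    for step in range(1, n):
      acc = acc + chunks[(c + step) % n][c]
    chunks[(c - 1) % n][c] = acc          # the last device visited owns the fully reduced chunk
  # all-gather: each owner passes its reduced chunk around so every device holds all n reduced chunks
  for c in range(n):
    owner = (c - 1) % n
    for d in range(n):
      chunks[d][c] = chunks[owner][c]
  return [np.concatenate(ch).reshape(shape) for ch in chunks]

# every device ends up with the elementwise sum of all inputs
xs = [np.full((2, 2), float(i + 1)) for i in range(4)]
for out in ring_allreduce(xs):
  np.testing.assert_allclose(out, sum(xs))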

View File

@@ -1,4 +1,4 @@
-import time, struct
+import time, struct, unittest
from typing import Any, Callable
import numpy as np
from tinygrad import Tensor, dtypes, Device
@@ -7,7 +7,7 @@ from tinygrad.tensor import _to_np_dtype
from tinygrad.engine.realize import Runner
from tinygrad.dtype import DType
from tinygrad.nn.state import get_parameters
-from tinygrad.helpers import T, CI
+from tinygrad.helpers import T, CI, RANGEIFY
from tinygrad.codegen import full_rewrite
from tinygrad.runtime.ops_python import PythonProgram, PythonRenderer, PythonCompiler
@@ -62,3 +62,6 @@ def not_support_multi_device():
# NOTE: This will open REMOTE if it's the default device
REAL_DEV = (Device.DEFAULT if Device.DEFAULT != "REMOTE" else Device['REMOTE'].properties.real_device)
+def expect_rangeify_fails(fxn): return (unittest.expectedFailure if RANGEIFY else (lambda f:f))(fxn)
+def expect_nonrangeify_fails(fxn): return (unittest.expectedFailure if not RANGEIFY else (lambda f:f))(fxn)
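Both helpers are thin gates around unittest.expectedFailure, now living in test/helpers.py so other test files can import them. A self-contained sketch of the behaviour, with RANGEIFY hard-coded here in place of the real tinygrad.helpers ContextVar:

import unittest

RANGEIFY = 1  # stand-in for tinygrad.helpers.RANGEIFY (truthy when rangeify is enabled)

def expect_rangeify_fails(fxn):
  # apply expectedFailure only when RANGEIFY is set; otherwise leave the test untouched
  return (unittest.expectedFailure if RANGEIFY else (lambda f: f))(fxn)

class TestDemo(unittest.TestCase):
  @expect_rangeify_fails  # reported as an expected failure when RANGEIFY is truthy
  def test_known_rangeify_bug(self):
    raise RuntimeError("children not making progress")

if __name__ == "__main__":
  unittest.main()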

View File

@@ -7,7 +7,7 @@ from tinygrad.nn.state import get_parameters, get_state_dict
from tinygrad.engine.realize import lower_schedule, BufferCopy, CompiledRunner, run_schedule
import numpy as np
from hypothesis import given, strategies as strat, settings
-from test.helpers import REAL_DEV, not_support_multi_device
+from test.helpers import REAL_DEV, not_support_multi_device, expect_rangeify_fails
settings.register_profile("my_profile", max_examples=200, deadline=None, derandomize=getenv("DERANDOMIZE_CI", False))
settings.load_profile("my_profile")
@@ -201,6 +201,14 @@ class TestMultiTensor(unittest.TestCase):
fn = f(n)
np.testing.assert_allclose(fX.numpy(), fn, rtol=1e-6, atol=1e-6)
+  @expect_rangeify_fails # TODO: fix
+  def test_allreduce_shard_ring_sum(self):
+    for axis in (0, 1, None):
+      for use_ring in (0, 2):
+        t = Tensor([1, 2, 3, 4]).reshape(2, 2)
+        with Context(RING=use_ring):
+          np.testing.assert_equal(t.shard(devices_2, axis=axis).sum().item(), 10)
  def test_allreduce_naive(self):
    with Context(RING=0):
      a,b = _test_allreduce(Tensor.rand(256, 256))
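For the two sharded axes the invariant is simply that per-device partial sums, once all-reduced, match the global sum; axis=None instead replicates the tensor on both devices and must still report the same total of 10. A pure-numpy sketch of the sharded cases (not tinygrad code; np.split stands in for Tensor.shard):

import numpy as np

t = np.array([1, 2, 3, 4]).reshape(2, 2)
for axis in (0, 1):
  shards = np.split(t, 2, axis=axis)     # one slice per simulated device
  partials = [s.sum() for s in shards]   # local reduction on each device
  assert sum(partials) == t.sum() == 10  # allreduce of the partials equals the global sum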

View File

@@ -18,6 +18,7 @@ from tinygrad.helpers import CI, DEBUG, SPLIT_REDUCEOP, GlobalCounters, Context,
from tinygrad.schedule.kernelize import merge_views, get_kernelize_map, Kernel
from tinygrad.engine.schedule import create_schedule_with_vars
from tinygrad.engine.realize import CompiledRunner, run_schedule, lower_schedule
+from test.helpers import expect_rangeify_fails, expect_nonrangeify_fails
class KernelCountException(Exception): pass
def check_schedule(t:Tensor|list[Tensor]|UOp, allowed:int, to_prerealize:list[Tensor]|None=None, filter_sink=True):
@@ -42,9 +43,6 @@ def check_schedule(t:Tensor|list[Tensor]|UOp, allowed:int, to_prerealize:list[Te
    raise KernelCountException(f"{kernel_cnt} != {allowed}")
  return sched
-def expect_rangeify_fails(fxn): return (unittest.expectedFailure if RANGEIFY else (lambda f:f))(fxn)
-def expect_nonrangeify_fails(fxn): return (unittest.expectedFailure if not RANGEIFY else (lambda f:f))(fxn)
def _realize_weights(m):
  for p in nn.state.get_parameters(m): p.realize()