mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
multi test case for sharded ring allreduce (#12462)
* multi test case for sharded ring allreduce; it triggers `children not making progress` with RANGEIFY
* expect_rangeify_fails
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
import time, struct
|
||||
import time, struct, unittest
|
||||
from typing import Any, Callable
|
||||
import numpy as np
|
||||
from tinygrad import Tensor, dtypes, Device
|
||||
@@ -7,7 +7,7 @@ from tinygrad.tensor import _to_np_dtype
|
||||
from tinygrad.engine.realize import Runner
|
||||
from tinygrad.dtype import DType
|
||||
from tinygrad.nn.state import get_parameters
|
||||
from tinygrad.helpers import T, CI
|
||||
from tinygrad.helpers import T, CI, RANGEIFY
|
||||
from tinygrad.codegen import full_rewrite
|
||||
from tinygrad.runtime.ops_python import PythonProgram, PythonRenderer, PythonCompiler
|
||||
|
||||
@@ -62,3 +62,6 @@ def not_support_multi_device():
|
||||
|
||||
# NOTE: This will open REMOTE if it's the default device
|
||||
REAL_DEV = (Device.DEFAULT if Device.DEFAULT != "REMOTE" else Device['REMOTE'].properties.real_device)
|
||||
|
||||
def expect_rangeify_fails(fxn):
  """Decorator: wrap `fxn` in `unittest.expectedFailure` when RANGEIFY is truthy, else return `fxn` as-is."""
  if RANGEIFY: return unittest.expectedFailure(fxn)
  return fxn
|
||||
def expect_nonrangeify_fails(fxn):
  """Decorator: wrap `fxn` in `unittest.expectedFailure` when RANGEIFY is falsy, else return `fxn` as-is."""
  if RANGEIFY: return fxn
  return unittest.expectedFailure(fxn)
|
||||
|
||||
@@ -7,7 +7,7 @@ from tinygrad.nn.state import get_parameters, get_state_dict
|
||||
from tinygrad.engine.realize import lower_schedule, BufferCopy, CompiledRunner, run_schedule
|
||||
import numpy as np
|
||||
from hypothesis import given, strategies as strat, settings
|
||||
from test.helpers import REAL_DEV, not_support_multi_device
|
||||
from test.helpers import REAL_DEV, not_support_multi_device, expect_rangeify_fails
|
||||
|
||||
settings.register_profile("my_profile", max_examples=200, deadline=None, derandomize=getenv("DERANDOMIZE_CI", False))
|
||||
settings.load_profile("my_profile")
|
||||
@@ -201,6 +201,14 @@ class TestMultiTensor(unittest.TestCase):
|
||||
fn = f(n)
|
||||
np.testing.assert_allclose(fX.numpy(), fn, rtol=1e-6, atol=1e-6)
|
||||
|
||||
# Shards a 2x2 tensor holding 1..4 across devices_2 along each axis (0, 1, None), with RING set to 0 and to 2,
# and checks that the summed result is 10. Marked as an expected failure when RANGEIFY is enabled.
@expect_rangeify_fails # TODO: fix
|
||||
def test_allreduce_shard_ring_sum(self):
|
||||
for axis in (0, 1, None):
|
||||
for use_ring in (0, 2):
|
||||
t = Tensor([1, 2, 3, 4]).reshape(2, 2)
|
||||
with Context(RING=use_ring):
|
||||
np.testing.assert_equal(t.shard(devices_2, axis=axis).sum().item(), 10)
|
||||
|
||||
def test_allreduce_naive(self):
|
||||
with Context(RING=0):
|
||||
a,b = _test_allreduce(Tensor.rand(256, 256))
|
||||
|
||||
@@ -18,6 +18,7 @@ from tinygrad.helpers import CI, DEBUG, SPLIT_REDUCEOP, GlobalCounters, Context,
|
||||
from tinygrad.schedule.kernelize import merge_views, get_kernelize_map, Kernel
|
||||
from tinygrad.engine.schedule import create_schedule_with_vars
|
||||
from tinygrad.engine.realize import CompiledRunner, run_schedule, lower_schedule
|
||||
from test.helpers import expect_rangeify_fails, expect_nonrangeify_fails
|
||||
|
||||
class KernelCountException(Exception):
  """Raised when the number of kernels in a schedule does not equal the allowed count."""
|
||||
def check_schedule(t:Tensor|list[Tensor]|UOp, allowed:int, to_prerealize:list[Tensor]|None=None, filter_sink=True):
|
||||
@@ -42,9 +43,6 @@ def check_schedule(t:Tensor|list[Tensor]|UOp, allowed:int, to_prerealize:list[Te
|
||||
raise KernelCountException(f"{kernel_cnt} != {allowed}")
|
||||
return sched
|
||||
|
||||
def expect_rangeify_fails(fxn):
  """Decorator: wrap `fxn` in `unittest.expectedFailure` when RANGEIFY is truthy, else return `fxn` as-is."""
  if RANGEIFY: return unittest.expectedFailure(fxn)
  return fxn
|
||||
def expect_nonrangeify_fails(fxn):
  """Decorator: wrap `fxn` in `unittest.expectedFailure` when RANGEIFY is falsy, else return `fxn` as-is."""
  if RANGEIFY: return fxn
  return unittest.expectedFailure(fxn)
|
||||
|
||||
def _realize_weights(m):
|
||||
for p in nn.state.get_parameters(m): p.realize()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user