diff --git a/test/helpers.py b/test/helpers.py index cee64595f3..98b5978f98 100644 --- a/test/helpers.py +++ b/test/helpers.py @@ -1,4 +1,4 @@ -import time, struct +import time, struct, unittest from typing import Any, Callable import numpy as np from tinygrad import Tensor, dtypes, Device @@ -7,7 +7,7 @@ from tinygrad.tensor import _to_np_dtype from tinygrad.engine.realize import Runner from tinygrad.dtype import DType from tinygrad.nn.state import get_parameters -from tinygrad.helpers import T, CI +from tinygrad.helpers import T, CI, RANGEIFY from tinygrad.codegen import full_rewrite from tinygrad.runtime.ops_python import PythonProgram, PythonRenderer, PythonCompiler @@ -62,3 +62,6 @@ def not_support_multi_device(): # NOTE: This will open REMOTE if it's the default device REAL_DEV = (Device.DEFAULT if Device.DEFAULT != "REMOTE" else Device['REMOTE'].properties.real_device) + +def expect_rangeify_fails(fxn): return (unittest.expectedFailure if RANGEIFY else (lambda f:f))(fxn) +def expect_nonrangeify_fails(fxn): return (unittest.expectedFailure if not RANGEIFY else (lambda f:f))(fxn) diff --git a/test/test_multitensor.py b/test/test_multitensor.py index fa6086da11..d1c24b0d65 100644 --- a/test/test_multitensor.py +++ b/test/test_multitensor.py @@ -7,7 +7,7 @@ from tinygrad.nn.state import get_parameters, get_state_dict from tinygrad.engine.realize import lower_schedule, BufferCopy, CompiledRunner, run_schedule import numpy as np from hypothesis import given, strategies as strat, settings -from test.helpers import REAL_DEV, not_support_multi_device +from test.helpers import REAL_DEV, not_support_multi_device, expect_rangeify_fails settings.register_profile("my_profile", max_examples=200, deadline=None, derandomize=getenv("DERANDOMIZE_CI", False)) settings.load_profile("my_profile") @@ -201,6 +201,14 @@ class TestMultiTensor(unittest.TestCase): fn = f(n) np.testing.assert_allclose(fX.numpy(), fn, rtol=1e-6, atol=1e-6) + @expect_rangeify_fails # TODO: fix + def test_allreduce_shard_ring_sum(self): + for axis in (0, 1, None): + for use_ring in (0, 2): + t = Tensor([1, 2, 3, 4]).reshape(2, 2) + with Context(RING=use_ring): + np.testing.assert_equal(t.shard(devices_2, axis=axis).sum().item(), 10) + def test_allreduce_naive(self): with Context(RING=0): a,b = _test_allreduce(Tensor.rand(256, 256)) diff --git a/test/test_schedule.py b/test/test_schedule.py index 34a982fb99..7c358fe016 100644 --- a/test/test_schedule.py +++ b/test/test_schedule.py @@ -18,6 +18,7 @@ from tinygrad.helpers import CI, DEBUG, SPLIT_REDUCEOP, GlobalCounters, Context, from tinygrad.schedule.kernelize import merge_views, get_kernelize_map, Kernel from tinygrad.engine.schedule import create_schedule_with_vars from tinygrad.engine.realize import CompiledRunner, run_schedule, lower_schedule +from test.helpers import expect_rangeify_fails, expect_nonrangeify_fails class KernelCountException(Exception): pass def check_schedule(t:Tensor|list[Tensor]|UOp, allowed:int, to_prerealize:list[Tensor]|None=None, filter_sink=True): @@ -42,9 +43,6 @@ def check_schedule(t:Tensor|list[Tensor]|UOp, allowed:int, to_prerealize:list[Te raise KernelCountException(f"{kernel_cnt} != {allowed}") return sched -def expect_rangeify_fails(fxn): return (unittest.expectedFailure if RANGEIFY else (lambda f:f))(fxn) -def expect_nonrangeify_fails(fxn): return (unittest.expectedFailure if not RANGEIFY else (lambda f:f))(fxn) - def _realize_weights(m): for p in nn.state.get_parameters(m): p.realize()