Files
tinygrad/test/unit/test_allreduce.py
nimlgen d3378010ee schedule() -> schedule_linear() in tests (batch 1) (#15915)
* schedule_with_vars -> linear_with_vars in tests

* tests batch 1

* batch 2

* estimate_uop

* simpler

* rm
2026-04-24 23:40:53 +03:00

53 lines
2.2 KiB
Python

import unittest
from tinygrad import Tensor, dtypes
from tinygrad.helpers import Context
from tinygrad.uop.ops import Ops
class TestRingAllReduce(unittest.TestCase):
  """Tests for summing a tensor sharded across several CPU devices with RING=2."""

  def test_schedule_ring(self):
    """Count the COPY items in the linearized sum and the distinct device pairs they use."""
    device_count = 4
    devices = tuple(f"CPU:{idx}" for idx in range(device_count))
    with Context(RING=2):
      sharded = Tensor.empty(device_count, device_count * 100).shard(devices, axis=0).realize()
      linear = sharded.sum(0).linear_with_vars()[0]
      # Collect the devices of the src[1] / src[2] buffers for every COPY item.
      device_pairs = []
      for item in linear.src:
        if item.src[0].op is not Ops.COPY:
          continue
        device_pairs.append((item.src[1].buffer.device, item.src[2].buffer.device))
      # N*(N-1) copies for the scatter reduce, plus N*(N-1) for the allgather
      self.assertEqual(len(device_pairs), 2 * device_count * (device_count - 1))
      # only N distinct device pairs appear: the copy topology forms a ring
      self.assertEqual(len(set(device_pairs)), device_count)

  def test_correct_ring(self):
    """Summing a sharded all-ones tensor over axis 0 yields 4 in every output slot."""
    device_count = 4
    devices = tuple(f"CPU:{idx}" for idx in range(device_count))
    with Context(RING=2):
      ones = Tensor.ones(device_count, device_count * 100).contiguous().shard(devices, axis=0).realize()
      summed = ones.sum(0)
      expected = [4] * device_count * 100
      self.assertListEqual(summed.tolist(), expected)
class TestAllreduceCast(unittest.TestCase):
  """Tests for how ALLREDUCE_CAST affects the dtypes seen on allreduce COPY items."""

  def _get_copy_dtypes(self, dtype, allreduce_cast):
    """Return the set of scalar dtypes of the src[1] buffers of all COPY items.

    A 4x4 tensor of `dtype` is sharded over two CPU devices and summed over axis 0,
    with ALLREDUCE_CAST set to `allreduce_cast`, RING=0 and SCACHE=0.
    """
    devices = tuple(f"CPU:{idx}" for idx in range(2))
    with Context(ALLREDUCE_CAST=allreduce_cast, RING=0, SCACHE=0):
      sharded = Tensor.empty(4, 4, dtype=dtype).shard(devices, axis=0)
      linear = sharded.sum(0).linear_with_vars()[0]
      copy_dtypes = set()
      for item in linear.src:
        if item.src[0].op is Ops.COPY:
          copy_dtypes.add(item.src[1].buffer.dtype.scalar())
      return copy_dtypes

  def test_allreduce_cast_bf16(self):
    """bfloat16 input: float copies appear only when ALLREDUCE_CAST is off."""
    # with ALLREDUCE_CAST, allreduce copies stay in bfloat16 instead of promoting to float32
    with_cast = self._get_copy_dtypes(dtypes.bfloat16, allreduce_cast=1)
    self.assertNotIn(dtypes.float, with_cast)
    without_cast = self._get_copy_dtypes(dtypes.bfloat16, allreduce_cast=0)
    self.assertIn(dtypes.float, without_cast)

  def test_allreduce_cast_half(self):
    """half input: float copies appear only when ALLREDUCE_CAST is off."""
    with_cast = self._get_copy_dtypes(dtypes.half, allreduce_cast=1)
    self.assertNotIn(dtypes.float, with_cast)
    without_cast = self._get_copy_dtypes(dtypes.half, allreduce_cast=0)
    self.assertIn(dtypes.float, without_cast)

  def test_allreduce_cast_float32_noop(self):
    """float32 input: the copy dtypes are the same with ALLREDUCE_CAST on or off."""
    # float32 should not be affected by ALLREDUCE_CAST (no promotion happens)
    with_cast = self._get_copy_dtypes(dtypes.float, allreduce_cast=1)
    without_cast = self._get_copy_dtypes(dtypes.float, allreduce_cast=0)
    self.assertEqual(with_cast, without_cast)
# Run the test cases defined above when this file is executed directly as a script.
if __name__ == '__main__':
  unittest.main()