mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
* schedule_with_vars -> linear_with_vars in tests * tests batch 1 * batch 2 * estimate_uop * simpler * rm
53 lines
2.2 KiB
Python
53 lines
2.2 KiB
Python
import unittest
|
|
from tinygrad import Tensor, dtypes
|
|
from tinygrad.helpers import Context
|
|
from tinygrad.uop.ops import Ops
|
|
|
|
class TestRingAllReduce(unittest.TestCase):
  """Checks the COPY schedule and the numeric result of a sharded sum run under RING=2."""

  def test_schedule_ring(self):
    # Inspects the COPY items of the linearized sum of a tensor sharded over N CPU devices.
    # NOTE(review): RING=2 presumably selects the ring all-reduce path -- confirm against the env var docs.
    N = 4
    with Context(RING=2):
      devices = tuple(f"CPU:{i}" for i in range(N))
      sharded = Tensor.empty(N, N*100).shard(devices, axis=0).realize()
      linear = sharded.sum(0).linear_with_vars()[0]

      # keep only the items whose first source is a COPY op
      copy_items = list(filter(lambda item: item.src[0].op is Ops.COPY, linear.src))

      # one (device, device) pair per copy, taken from the buffers of src[1] and src[2]
      device_pairs = []
      for item in copy_items:
        device_pairs.append((item.src[1].buffer.device, item.src[2].buffer.device))

      # N*(N-1) copies for the scatter reduce plus N*(N-1) for the allgather
      expected_copies = 2 * N * (N - 1)
      self.assertEqual(len(device_pairs), expected_copies)

      # a ring topology has exactly N distinct device pairs
      self.assertEqual(len(set(device_pairs)), N)

  def test_correct_ring(self):
    # Summing an all-ones (N, N*100) tensor over axis 0 must give N*100 entries equal to 4.
    N = 4
    with Context(RING=2):
      devices = tuple(f"CPU:{i}" for i in range(N))
      ones = Tensor.ones(N, N*100).contiguous().shard(devices, axis=0).realize()
      column_sums = ones.sum(0).tolist()
      expected = [4] * (N * 100)
      self.assertListEqual(column_sums, expected)
|
|
|
|
class TestAllreduceCast(unittest.TestCase):
  """Checks which scalar dtypes the all-reduce COPY buffers use with ALLREDUCE_CAST on and off."""

  def _get_copy_dtypes(self, dtype, allreduce_cast):
    """Return the set of scalar dtypes found on the COPY items of a sharded sum.

    A 4x4 tensor of `dtype` is sharded over two CPU devices and summed over axis 0 with
    ALLREDUCE_CAST set to `allreduce_cast` (RING and SCACHE are both set to 0). For every
    item of the linearized result whose first source is a COPY op, the scalar dtype of the
    buffer of src[1] is collected.
    """
    devices = tuple(f"CPU:{i}" for i in range(2))
    with Context(ALLREDUCE_CAST=allreduce_cast, RING=0, SCACHE=0):
      sharded = Tensor.empty(4, 4, dtype=dtype).shard(devices, axis=0)
      linear = sharded.sum(0).linear_with_vars()[0]
      seen = set()
      for item in linear.src:
        if item.src[0].op is Ops.COPY:
          seen.add(item.src[1].buffer.dtype.scalar())
      return seen

  def test_allreduce_cast_bf16(self):
    # with ALLREDUCE_CAST on, the allreduce copies stay in bfloat16 instead of promoting to float32
    cast_on = self._get_copy_dtypes(dtypes.bfloat16, allreduce_cast=1)
    self.assertNotIn(dtypes.float, cast_on)
    # with it off, a float copy is expected to show up
    cast_off = self._get_copy_dtypes(dtypes.bfloat16, allreduce_cast=0)
    self.assertIn(dtypes.float, cast_off)

  def test_allreduce_cast_half(self):
    # same check as the bfloat16 case, for half precision
    cast_on = self._get_copy_dtypes(dtypes.half, allreduce_cast=1)
    self.assertNotIn(dtypes.float, cast_on)
    cast_off = self._get_copy_dtypes(dtypes.half, allreduce_cast=0)
    self.assertIn(dtypes.float, cast_off)

  def test_allreduce_cast_float32_noop(self):
    # float32 needs no promotion, so ALLREDUCE_CAST should leave the copy dtypes unchanged
    dtypes_on, dtypes_off = (self._get_copy_dtypes(dtypes.float, allreduce_cast=flag) for flag in (1, 0))
    self.assertEqual(dtypes_on, dtypes_off)
|
|
|
|
# Run this module's tests when the file is executed directly as a script.
if __name__ == "__main__":
  unittest.main()
|