cleanup multi tests (#10635)

chenyu
2025-06-05 00:28:44 -04:00
committed by GitHub
parent 571c0296a9
commit d0969f5a1f


@@ -1,13 +1,12 @@
 import unittest, functools, random
-from typing import List
 from tinygrad import Tensor, Device, nn, GlobalCounters, TinyJit, dtypes, Variable
+from tinygrad.device import is_dtype_supported
 from tinygrad.uop.ops import Ops, UOp
 from tinygrad.helpers import CI, getenv, prod, Context, OSX
 from tinygrad.nn.state import get_parameters, get_state_dict
 from tinygrad.engine.realize import lower_schedule, BufferCopy, CompiledRunner, run_schedule
 import numpy as np
 from hypothesis import given, strategies as strat, settings
-from tinygrad.device import is_dtype_supported
 from test.helpers import REAL_DEV, not_support_multi_device
 settings.register_profile("my_profile", max_examples=200, deadline=None, derandomize=getenv("DERANDOMIZE_CI", False))
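
Note on the fixtures used throughout the hunks below: `devices_2`/`devices_3`/`devices_4` and `d0`–`d3` are the multi-device tuples this test file builds from `Device.DEFAULT`. A minimal sketch of that setup and of `shard` (assumed from the test code here, not part of the diff):

```python
# Sketch of the multi-device fixtures the tests below rely on (assumed, not from this diff)
from tinygrad import Device, Tensor

d0, d1, d2, d3 = (f"{Device.DEFAULT}:{i}" for i in range(4))
devices_2, devices_3, devices_4 = (d0, d1), (d0, d1, d2), (d0, d1, d2, d3)

# shard splits a tensor across devices along axis (and replicates it when axis=None)
t = Tensor.rand(16, 16).shard(devices_2, axis=0)  # rows 0-7 on d0, rows 8-15 on d1
print(t.device)  # a tuple of device names, e.g. ('CPU:0', 'CPU:1')
```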
@@ -575,12 +574,9 @@ class TestMultiTensor(unittest.TestCase):
     scheds = [sched for sched in out.schedule() if sched.bufs[0].device in devices_4 and sched.ast.op is not Ops.COPY]
     assert set(sched.bufs[0].device for sched in scheds) == set(devices_4), "should have ast on each shard device"
     asts = [sched.ast for sched in scheds]
-    assert len(asts)
-    # test case to show that ast can be different on devices
-    # TODO: make ast identical on devices
-    #assert len(set(asts)) == 4, len(asts)
-    # for i, ast in enumerate(asts):
-    #   print(f"{i} {ast}")
+    self.assertEqual(len(asts), 4)
+    # ast are the same on devices
+    self.assertEqual(len(set(asts)), 1)
 
   def test_reshape_on_axis(self):
     t0 = Tensor.rand((26, 15, 7)).shard(devices_3, axis=1)
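
The rewritten assertions above are strictly stronger than the old commented-out check: every one of the four shards must now lower to the same AST, so a single kernel can be compiled and reused across devices. A hedged sketch of that inspection pattern (the example op and shapes are assumptions; `devices_4` as above):

```python
# Sketch: check that all shards of a sharded elementwise op lower to one identical AST
from tinygrad import Tensor
from tinygrad.uop.ops import Ops

out = (Tensor.rand(16, 16).shard(devices_4, axis=0) + 1).contiguous()
scheds = [si for si in out.schedule() if si.ast.op is not Ops.COPY]  # drop the shard copies
asts = [si.ast for si in scheds]
assert len(set(asts)) == 1, "identical ASTs mean the kernel compiles once for all shards"
```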
@@ -768,32 +764,6 @@ class TestMultiTensor(unittest.TestCase):
     assert set(unique) == {0, 2}, unique
     assert 100 < counts[0] < 156, counts[0]
 
-  @unittest.skip("test depends on UOp order. TODO: fix it")
-  def test_broadcast_const(self):
-    for axis in (None, 0, 1):
-      t = Tensor.zeros(16, 16).contiguous().shard(devices_4, axis).realize()
-      t = t + 1
-      for si in t.schedule():
-        ast = si.ast.src[0]
-        assert ast.op is Ops.STORE
-        assert ast.src[2].op is Ops.ADD
-        assert ast.src[2].src[0].op is Ops.LOAD
-        assert ast.src[2].src[1].src[1].op is Ops.CONST and ast.src[2].src[1].src[1].arg == 1
-      t = 2 * t
-      for si in t.schedule():
-        ast = si.ast.src[0]
-        assert ast.op is Ops.STORE
-        assert ast.src[2].op is Ops.MUL
-        assert ast.src[2].src[0].src[1].op is Ops.CONST and ast.src[2].src[0].src[1].arg == 2
-        assert ast.src[2].src[1].op is Ops.LOAD
-      t = t + t.full_like(3)
-      for si in t.schedule():
-        ast = si.ast.src[0]
-        assert ast.op is Ops.STORE
-        assert ast.src[2].op is Ops.ADD
-        assert ast.src[2].src[0].op is Ops.LOAD
-        assert ast.src[2].src[1].src[1].op is Ops.CONST and ast.src[2].src[1].src[1].arg == 3
-
   @unittest.skip("TODO: this requires forced_realize to be deleted.")
   def test_shard_memory(self):
     devices = (d0, d1, d2, d3)
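
The deleted `test_broadcast_const` hard-coded exact `src` indices into the lowered AST (e.g. `ast.src[2].src[1].src[1]`), which is exactly why it broke when UOp ordering changed. If one wanted the same coverage without the brittleness, scanning the AST's toposort for the expected constant is one order-tolerant option (a sketch, not from this commit):

```python
# Sketch: assert the broadcast constant shows up in each shard's kernel,
# without pinning exact src positions in the AST (hypothetical variant)
from tinygrad import Tensor
from tinygrad.uop.ops import Ops

t = Tensor.zeros(16, 16).contiguous().shard(devices_4, axis=0).realize()
for si in (t + 1).schedule():
  consts = [u.arg for u in si.ast.toposort() if u.op is Ops.CONST]
  assert 1 in consts, "the broadcast +1 should appear as a CONST somewhere in the AST"
```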
@@ -809,28 +779,15 @@ class TestMultiTensor(unittest.TestCase):
     t = Tensor.rand(16, 16).shard(devices_2, axis=0)
     np.testing.assert_allclose(t.numpy(), t.clone().numpy())
 
-  @unittest.skip("this test looks wrong, times 0 is 0")
-  def test_multi_const_folding(self):
-    with Context(TRACK_MATCH_STATS=0):
-      a = Tensor.arange(3).realize()
-      zeros = Tensor.zeros(3).realize()
-      b = a.to(devices_2)*zeros.to(devices_2)
-      sched = b.schedule()
-      self.assertEqual(len(sched), 6)
-      # notably, only two copies (for the arange) - vs 4 copies if we didn't fold the const copy
-      self.assertEqual(len([x for x in sched if any(u.op is Ops.COPY for u in x.ast.toposort())]), 2)
-      run_schedule(sched)
-      self.assertEqual(len(sched), 0)
-      self.assertListEqual(b.tolist(), [0, 0, 0])
 
   @unittest.skip("not sure what this tests")
   def test_dont_realize_intermediate_expand(self):
     a = Tensor.empty(16, 1).shard_(devices_2, axis=0)
     b = Tensor.empty(16, 16).to_(devices_2)
     c = Tensor.empty(16, 16).shard_(devices_2, axis=1)
     d = a+b
     (d*c).realize()
     assert not d.lazydata.is_realized
 
 @unittest.skipIf(not_support_multi_device(), "no multi")
 class TestHandleData(unittest.TestCase):
   def test_copied_to_device(self):
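
On the removed `test_multi_const_folding`: the skip reason argues the test asserted the wrong thing, since `arange * zeros` can constant-fold to zeros outright, in which case even the two remaining copies for the arange are unnecessary. A sketch of inspecting that behavior, mirroring the deleted test's calls (the ideal copy count of 0 is an assumption, not something the commit asserts):

```python
# Sketch: how much does x * 0 fold on sharded tensors? (devices_2 as above)
from tinygrad import Tensor
from tinygrad.uop.ops import Ops
from tinygrad.engine.realize import run_schedule

a = Tensor.arange(3).realize()
b = a.to(devices_2) * Tensor.zeros(3).realize().to(devices_2)
sched = b.schedule()
copies = [si for si in sched if any(u.op is Ops.COPY for u in si.ast.toposort())]
print(len(copies))  # the old test pinned this at 2; perfect folding would make it 0
run_schedule(sched)
assert b.tolist() == [0, 0, 0]
```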
@@ -875,19 +832,6 @@ class TestShrinkMultiTensorShardedAxis(unittest.TestCase):
     a.schedule()
     assert a.shape == (2, 8)
     # real is no longer used, so these are on None and we can pad them however
-    """
-    with self.assertRaises(AssertionError):
-      # cannot pad sharded and non-sharded axis at the same time
-      p = a.pad(((0, 6), (0, 1)))
-      p.schedule()
-
-    with self.assertRaises(AssertionError):
-      # can only pad to whole axis
-      p = a.pad(((1, 5), (0, 0)))
-      p.schedule()
-    """
     p = a.pad(((0, 6), (0, 0)))
     p.schedule()
     assert p.shape == (8, 8)
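
The surviving pad call is the legal one: after shrinking down to a single 2-row shard, `a` can be zero-padded back to the full 8 rows along the sharded axis. A sketch of that shrink/pad round trip (shapes follow the test above; `devices_4` as before):

```python
# Sketch: shrink a sharded tensor to one partition, then zero-pad it back
from tinygrad import Tensor

t = Tensor.arange(64).reshape(8, 8).contiguous().shard(devices_4, axis=0)
a = t.shrink(((0, 2), None))   # keep only the first shard: shape (2, 8)
p = a.pad(((0, 6), (0, 0)))    # zero-pad back to (8, 8) along the sharded axis
assert a.shape == (2, 8) and p.shape == (8, 8)
```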
@@ -956,7 +900,7 @@ class TestShrinkMultiTensorShardedAxis(unittest.TestCase):
     np.testing.assert_equal((a+a).numpy(), na+na)
     np.testing.assert_equal((b+b).numpy(), nb+nb)
 
-  @unittest.skip("why didn't this work?")
+  # @unittest.skip("why didn't this work?")
   def test_add_two_partitions(self):
     t = Tensor.arange(64).reshape(8, 8).contiguous().realize()
     t.shard_([f"{Device.DEFAULT}:{i}" for i in range(4)], axis=0)
@@ -967,16 +911,9 @@ class TestShrinkMultiTensorShardedAxis(unittest.TestCase):
     nb = t.numpy()[6:8]
     np.testing.assert_equal(a.numpy(), na)
     np.testing.assert_equal(b.numpy(), nb)
-    self.assertEqual(a.lazydata.real, (False, True, False, False))
-    self.assertEqual(b.lazydata.real, (False, False, False, True))
-    with self.assertRaises(AssertionError):
-      # cannot add directly
-      c = a + b
-      c.schedule()
+    np.testing.assert_equal((a+b).numpy(), na+nb)
 
     c = a.pad(((2, 4), None)) + b.pad(((6, 0), None))
     c.realize()
-    self.assertEqual(c.lazydata.real, (True, True, True, True))
     expected = np.concatenate([np.zeros_like(t.numpy()[0:2]), na, np.zeros_like(t.numpy()[4:6]), nb])
     np.testing.assert_equal(c.numpy(), expected)
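
After the cleanup, adding two differently-placed partitions directly (`a + b`) simply works; the pad-then-add path remains as the explicit way to reassemble partitions into a full tensor, with zeros everywhere neither operand covers. A NumPy analogue of that reassembly (illustration only):

```python
# NumPy analogue of the pad-then-add above (illustration only)
import numpy as np

t = np.arange(64).reshape(8, 8)
na, nb = t[2:4], t[6:8]
c = np.zeros_like(t)
c[2:4] += na   # a.pad(((2, 4), None)) places na at rows 2:4
c[6:8] += nb   # b.pad(((6, 0), None)) places nb at rows 6:8
expected = np.concatenate([np.zeros_like(t[0:2]), na, np.zeros_like(t[4:6]), nb])
assert (c == expected).all()
```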
@@ -988,7 +925,7 @@ class TestShrinkMultiTensorShardedAxis(unittest.TestCase):
     for i in range(len(devices)):
       to_add.append((Tensor.ones(2, 8) * i).shard(devices))
 
-    added:List[Tensor] = []
+    added:list[Tensor] = []
     for bound, a in zip(x.lazydata.bounds, to_add):
       added.append(x[bound[0]:bound[1]] + a)
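
The `List` → `list` change here (and in `TestBatchNorm` below) is PEP 585: since Python 3.9 the builtin `list` is subscriptable as a generic, which is what lets the first hunk drop `from typing import List` entirely. For example:

```python
# PEP 585 (Python 3.9+): builtin generics, no typing.List import needed
from tinygrad import Tensor, nn

added: list[Tensor] = []        # was: added:List[Tensor] = []
bns: list[nn.BatchNorm2d] = []  # the same change appears in TestBatchNorm below
```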
@@ -1032,7 +969,7 @@ class TestBatchNorm(unittest.TestCase):
   class BatchNorm:
     def __init__(self, num_features):
-      self.bns:List[nn.BatchNorm2d] = []
+      self.bns:list[nn.BatchNorm2d] = []
       for _ in GPUS:
         bn = nn.BatchNorm2d(num_features, track_running_stats=False, eps=1e-12, momentum=0.85, affine=True)
         self.bns.append(bn)
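
For reference, this helper builds one `nn.BatchNorm2d` per GPU with `track_running_stats=False`, so each shard normalizes with its own batch statistics and nothing is synced across devices. A minimal sketch of one such instance in isolation (shapes are assumptions):

```python
# Minimal sketch: one unsynced BatchNorm2d as configured above (shapes assumed)
from tinygrad import Tensor, nn

bn = nn.BatchNorm2d(8, track_running_stats=False, eps=1e-12, momentum=0.85, affine=True)
x = Tensor.rand(4, 8, 16, 16)  # (batch, num_features, height, width)
y = bn(x).realize()            # normalized with this batch's own statistics
assert y.shape == x.shape
```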