diff --git a/test/test_multitensor.py b/test/test_multitensor.py
index 6262f28fa2..82941bebff 100644
--- a/test/test_multitensor.py
+++ b/test/test_multitensor.py
@@ -543,13 +543,12 @@ class TestShrinkMultiTensorShardedAxis(unittest.TestCase):
 
   def test_unsynced_backprop_sync_weights(self):
     from extra.lr_scheduler import OneCycleLR
-    from examples.hlb_cifar10 import UnsyncedBatchNorm
     from tinygrad.features.multi import MultiLazyBuffer
     GPUS = (d1, d2)
 
     with Tensor.train():
       conv = nn.Conv2d(3, 16, 3)
-      bn = UnsyncedBatchNorm(16, num_devices=len(GPUS))
+      bn = nn.BatchNorm2d(16, num_devices=len(GPUS))
 
       for p in get_parameters([conv, bn]):
         if not isinstance(p.lazydata, MultiLazyBuffer):
@@ -588,14 +587,13 @@ class TestShrinkMultiTensorShardedAxis(unittest.TestCase):
       bn_ts[0].cat(*bn_ts[1:]).numpy()
 
   def test_synced_vs_unsynced_bn(self):
-    from examples.hlb_cifar10 import UnsyncedBatchNorm
     from tinygrad.nn import BatchNorm2d
     devices = [f"{Device.DEFAULT}:{i}" for i in range(4)]
     x = Tensor.ones(8, 8, 8, 8).contiguous().realize().shard(devices, axis=0)
 
     with Tensor.train():
       synced_bn = BatchNorm2d(8)
-      unsynced_bn = UnsyncedBatchNorm(8, num_devices=len(devices))
+      unsynced_bn = BatchNorm2d(8, num_devices=len(devices))
 
       for p in get_parameters([synced_bn, unsynced_bn]):
         p.shard_(devices)
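
For reference, this is roughly the usage pattern the updated `test_synced_vs_unsynced_bn` exercises — a minimal standalone sketch that assumes the `num_devices` argument this diff gives `nn.BatchNorm2d` (previously provided by `examples.hlb_cifar10.UnsyncedBatchNorm`); the imports, shapes, and sharding calls mirror the test itself and are not part of the diff.

```python
# Sketch of the sharded-BatchNorm pattern the tests above rely on.
# Assumption: nn.BatchNorm2d accepts num_devices as introduced by this diff;
# Tensor.shard, Tensor.shard_, and get_parameters are existing tinygrad calls used by the tests.
from tinygrad import Tensor, Device, nn
from tinygrad.nn.state import get_parameters

devices = [f"{Device.DEFAULT}:{i}" for i in range(4)]
# shard the batch dimension of the input across the four devices
x = Tensor.ones(8, 8, 8, 8).contiguous().realize().shard(devices, axis=0)

with Tensor.train():
  synced_bn = nn.BatchNorm2d(8)                              # one set of batch statistics shared by all shards
  unsynced_bn = nn.BatchNorm2d(8, num_devices=len(devices))  # per-device statistics via the new argument
  for p in get_parameters([synced_bn, unsynced_bn]):
    p.shard_(devices)                                        # place every parameter/buffer on the same devices
  synced_out = synced_bn(x)
  unsynced_out = unsynced_bn(x)
```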