update tests

David Hou
2024-02-29 13:48:40 -08:00
parent 78de0ea9ee
commit 7ebf65d89a

@@ -543,13 +543,12 @@ class TestShrinkMultiTensorShardedAxis(unittest.TestCase):
   def test_unsynced_backprop_sync_weights(self):
     from extra.lr_scheduler import OneCycleLR
-    from examples.hlb_cifar10 import UnsyncedBatchNorm
     from tinygrad.features.multi import MultiLazyBuffer
     GPUS = (d1, d2)
     with Tensor.train():
       conv = nn.Conv2d(3, 16, 3)
-      bn = UnsyncedBatchNorm(16, num_devices=len(GPUS))
+      bn = nn.BatchNorm2d(16, num_devices=len(GPUS))
       for p in get_parameters([conv, bn]):
         if not isinstance(p.lazydata, MultiLazyBuffer):
@@ -588,14 +587,13 @@ class TestShrinkMultiTensorShardedAxis(unittest.TestCase):
       bn_ts[0].cat(*bn_ts[1:]).numpy()
   def test_synced_vs_unsynced_bn(self):
-    from examples.hlb_cifar10 import UnsyncedBatchNorm
     from tinygrad.nn import BatchNorm2d
     devices = [f"{Device.DEFAULT}:{i}" for i in range(4)]
     x = Tensor.ones(8, 8, 8, 8).contiguous().realize().shard(devices, axis=0)
     with Tensor.train():
       synced_bn = BatchNorm2d(8)
-      unsynced_bn = UnsyncedBatchNorm(8, num_devices=len(devices))
+      unsynced_bn = BatchNorm2d(8, num_devices=len(devices))
       for p in get_parameters([synced_bn, unsynced_bn]):
         p.shard_(devices)
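
For context, a minimal sketch of the API these tests now target, assuming (per this diff) that tinygrad's nn.BatchNorm2d accepts a num_devices argument for per-device running statistics; the device list and tensor shapes below are illustrative, mirroring the test:

from tinygrad import Tensor, Device, nn
from tinygrad.nn.state import get_parameters

# shard a batch across several devices along the batch axis
devices = tuple(f"{Device.DEFAULT}:{i}" for i in range(4))
x = Tensor.ones(8, 8, 8, 8).contiguous().realize().shard(devices, axis=0)

with Tensor.train():
  # num_devices > 1 keeps per-device (unsynced) statistics, replacing the
  # old examples.hlb_cifar10.UnsyncedBatchNorm helper these tests imported
  bn = nn.BatchNorm2d(8, num_devices=len(devices))
  for p in get_parameters(bn): p.shard_(devices)  # replicate params, as in the test
  bn(x).realize()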