mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 23:18:04 -05:00
update tests
This commit is contained in:
@@ -543,13 +543,12 @@ class TestShrinkMultiTensorShardedAxis(unittest.TestCase):
 
   def test_unsynced_backprop_sync_weights(self):
     from extra.lr_scheduler import OneCycleLR
-    from examples.hlb_cifar10 import UnsyncedBatchNorm
     from tinygrad.features.multi import MultiLazyBuffer
     GPUS = (d1, d2)
 
     with Tensor.train():
       conv = nn.Conv2d(3, 16, 3)
-      bn = UnsyncedBatchNorm(16, num_devices=len(GPUS))
+      bn = nn.BatchNorm2d(16, num_devices=len(GPUS))
 
     for p in get_parameters([conv, bn]):
       if not isinstance(p.lazydata, MultiLazyBuffer):
@@ -588,14 +587,13 @@ class TestShrinkMultiTensorShardedAxis(unittest.TestCase):
     bn_ts[0].cat(*bn_ts[1:]).numpy()
 
   def test_synced_vs_unsynced_bn(self):
-    from examples.hlb_cifar10 import UnsyncedBatchNorm
     from tinygrad.nn import BatchNorm2d
     devices = [f"{Device.DEFAULT}:{i}" for i in range(4)]
     x = Tensor.ones(8, 8, 8, 8).contiguous().realize().shard(devices, axis=0)
 
     with Tensor.train():
       synced_bn = BatchNorm2d(8)
-      unsynced_bn = UnsyncedBatchNorm(8, num_devices=len(devices))
+      unsynced_bn = BatchNorm2d(8, num_devices=len(devices))
 
     for p in get_parameters([synced_bn, unsynced_bn]):
       p.shard_(devices)
Reference in New Issue
Block a user