mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
fix and test loading num_batches_tracked (#13538)
* fix and test loading num_batches_tracked
* add failing reverse case
* try reshape state dict if mismatch
* reshape for () and (1,)

---------

Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
This commit is contained in:
@@ -481,6 +481,20 @@ class TestNN(unittest.TestCase):
|
||||
np.testing.assert_allclose(layer.weight.numpy(), state_dict['weight'].numpy())
|
||||
np.testing.assert_allclose(layer.bias.numpy(), state_dict['bias'].numpy())
|
||||
|
||||
#https://github.com/pytorch/pytorch/blob/d38164a545b4a4e4e0cf73ce67173f70574890b6/torch/nn/modules/module.py#L2425
|
||||
def test_load_conv_num_batches_tracked(self):
  """Load a state dict carrying `num_batches_tracked` into a BatchNorm three ways.

  The test has no explicit assertions: it passes as long as none of the
  three `load_state_dict` calls raises.
  """
  bn = BatchNorm(sz=1, track_running_stats=False)
  weights = {'weight': Tensor.ones(1), 'bias': Tensor.ones(1), 'num_batches_tracked': Tensor.ones(1)}

  # 1) checkpoint entry created with Tensor.ones(1)
  load_state_dict(bn, weights)

  # 2) checkpoint entry created with Tensor.empty()
  #    NOTE(review): presumably a shape-() tensor, i.e. the "() vs (1,)" mismatch -- confirm
  weights['num_batches_tracked'] = Tensor.empty()
  load_state_dict(bn, weights)

  # 3) reverse direction: the layer attribute is Tensor.ones(1), the checkpoint entry is still Tensor.empty()
  bn.num_batches_tracked = Tensor.ones(1)
  load_state_dict(bn, weights)
|
||||
|
||||
@needs_second_gpu
|
||||
@unittest.skipIf(not_support_multi_device(), "no multi")
|
||||
def test_load_state_dict_sharded_model(self):
|
||||
|
||||
Reference in New Issue
Block a user