fix and test loading num_batches_tracked (#13538)

* fix and test loading num_batches_tracked

* add failing reverse case

* try reshape state dict if mismatch

* reshape for () and (1,)

---------

Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
This commit is contained in:
Rory Clear
2025-12-04 09:22:49 +00:00
committed by GitHub
parent 877a7fdd61
commit 6eab756578
2 changed files with 16 additions and 1 deletions

View File

@@ -481,6 +481,20 @@ class TestNN(unittest.TestCase):
np.testing.assert_allclose(layer.weight.numpy(), state_dict['weight'].numpy())
np.testing.assert_allclose(layer.bias.numpy(), state_dict['bias'].numpy())
#https://github.com/pytorch/pytorch/blob/d38164a545b4a4e4e0cf73ce67173f70574890b6/torch/nn/modules/module.py#L2425
def test_load_conv_num_batches_tracked(self):
# Smoke test: load_state_dict is called three times on a BatchNorm layer with a
# state dict that carries a 'num_batches_tracked' entry. There are no explicit
# assertions; the test passes as long as none of the calls raise.
# The layer is constructed with track_running_stats=False.
layer = BatchNorm(sz=1, track_running_stats=False)
state_dict = {
'weight': Tensor.ones(1),
'bias': Tensor.ones(1),
'num_batches_tracked': Tensor.ones(1),
}
# Case 1: 'num_batches_tracked' in the state dict is Tensor.ones(1).
load_state_dict(layer, state_dict)
# Case 2: replace the entry with Tensor.empty() (no shape arguments) and load again.
# NOTE(review): presumably this yields a shape-() tensor, so the load exercises a
# () vs (1,) shape mismatch -- confirm against Tensor.empty.
state_dict['num_batches_tracked'] = Tensor.empty()
load_state_dict(layer, state_dict)
# Case 3 (reverse direction): set the attribute on the layer itself to
# Tensor.ones(1) while the state dict still holds the Tensor.empty() entry.
layer.num_batches_tracked = Tensor.ones(1)
load_state_dict(layer, state_dict)
@needs_second_gpu
@unittest.skipIf(not_support_multi_device(), "no multi")
def test_load_state_dict_sharded_model(self):