# Patch summary (commentary only; this text precedes the first `diff --git` header and belongs to no hunk).
#
# tinygrad/nn/state.py, load_state_dict:
#   When the model tensor `v` and `state_dict[k]` differ in shape, the ValueError is
#   no longer raised if the two shapes are exactly () and (1,), in either order.
#   In that case the state-dict tensor is reshaped to `v.shape` and stored back into
#   `state_dict[k]`, so the dict passed in by the caller is modified.
#   Every other shape mismatch still raises the same ValueError with the same message.
#
# test/test_nn.py, TestNN.test_load_conv_num_batches_tracked:
#   Calls load_state_dict three times on BatchNorm(sz=1, track_running_stats=False):
#   first with 'num_batches_tracked' set to Tensor.ones(1), then with that entry
#   replaced by Tensor.empty(), then after reassigning layer.num_batches_tracked to
#   Tensor.ones(1). The test contains no assertions; it passes if no call raises.
#
# NOTE(review): the test name says "conv" but the layer under test is BatchNorm — confirm intended name.
# NOTE(review): this patch arrived with line breaks and indentation collapsed. The layout below was
#   rebuilt from the hunk line counts, assuming two-space indentation — confirm against the target
#   files before applying.
diff --git a/test/test_nn.py b/test/test_nn.py
index b5bea0dd4d..0ffd7f3d28 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -481,6 +481,20 @@ class TestNN(unittest.TestCase):
     np.testing.assert_allclose(layer.weight.numpy(), state_dict['weight'].numpy())
     np.testing.assert_allclose(layer.bias.numpy(), state_dict['bias'].numpy())
 
+  #https://github.com/pytorch/pytorch/blob/d38164a545b4a4e4e0cf73ce67173f70574890b6/torch/nn/modules/module.py#L2425
+  def test_load_conv_num_batches_tracked(self):
+    layer = BatchNorm(sz=1, track_running_stats=False)
+    state_dict = {
+      'weight': Tensor.ones(1),
+      'bias': Tensor.ones(1),
+      'num_batches_tracked': Tensor.ones(1),
+    }
+    load_state_dict(layer, state_dict)
+    state_dict['num_batches_tracked'] = Tensor.empty()
+    load_state_dict(layer, state_dict)
+    layer.num_batches_tracked = Tensor.ones(1)
+    load_state_dict(layer, state_dict)
+
   @needs_second_gpu
   @unittest.skipIf(not_support_multi_device(), "no multi")
   def test_load_state_dict_sharded_model(self):
diff --git a/tinygrad/nn/state.py b/tinygrad/nn/state.py
index af30f445ef..effd7f1674 100644
--- a/tinygrad/nn/state.py
+++ b/tinygrad/nn/state.py
@@ -151,7 +151,8 @@ def load_state_dict(model, state_dict:dict[str, Tensor], strict=True, verbose=Tr
         if DEBUG >= 1: print(f"WARNING: not loading {k}")
         continue
       if v.shape != state_dict[k].shape:
-        raise ValueError(f'Shape mismatch in layer `{k}`: Expected shape {v.shape}, but found {state_dict[k].shape} in state dict.')
+        if {(), (1,)} == {state_dict[k].shape, v.shape}: state_dict[k] = state_dict[k].reshape(v.shape)
+        else: raise ValueError(f'Shape mismatch in layer `{k}`: Expected shape {v.shape}, but found {state_dict[k].shape} in state dict.')
       if isinstance(v.device, tuple):
         if isinstance(state_dict[k].device, tuple): v.replace(state_dict[k])
         else: v.replace(state_dict[k].shard(v.device, v.uop.axis))