diff --git a/extra/models/resnet.py b/extra/models/resnet.py index 6d33928578..016f1d0759 100644 --- a/extra/models/resnet.py +++ b/extra/models/resnet.py @@ -137,7 +137,13 @@ class ResNet: self.url = model_urls[(self.num, self.groups, self.base_width)] for k, dat in torch_load(fetch(self.url)).items(): - obj: Tensor = get_child(self, k) + try: + obj: Tensor = get_child(self, k) + except AttributeError as e: + if 'fc.' in k and self.fc is None: + continue + + raise e if 'fc.' in k and obj.shape != dat.shape: print("skipping fully connected layer") diff --git a/test/models/test_resnet.py b/test/models/test_resnet.py index f6eb0401aa..f71d689056 100644 --- a/test/models/test_resnet.py +++ b/test/models/test_resnet.py @@ -9,6 +9,13 @@ class TestResnet(unittest.TestCase): model = resnet.ResNeXt50_32X4D() model.load_from_pretrained() + def test_model_load_no_fc_layer(self): + model = resnet.ResNet18(num_classes=None) + model.load_from_pretrained() + + model = resnet.ResNeXt50_32X4D(num_classes=None) + model.load_from_pretrained() + if __name__ == '__main__': unittest.main() \ No newline at end of file