Fix FC layer ResNet load_from_pretrained error (#8387)

* validate that FC exists before loading pretrained weights

* add test case for ResNet pretrained model without FC layer

* remove extra newline

* rename test case

* reraise exception if not handled by check
This commit is contained in:
Francis Lata
2024-12-26 18:11:27 -05:00
committed by GitHub
parent 90f1f0c9d5
commit 5755ac1f72
2 changed files with 14 additions and 1 deletions

View File

@@ -137,7 +137,13 @@ class ResNet:
self.url = model_urls[(self.num, self.groups, self.base_width)]
for k, dat in torch_load(fetch(self.url)).items():
obj: Tensor = get_child(self, k)
try:
obj: Tensor = get_child(self, k)
except AttributeError as e:
if 'fc.' in k and self.fc is None:
continue
raise e
if 'fc.' in k and obj.shape != dat.shape:
print("skipping fully connected layer")

View File

@@ -9,6 +9,13 @@ class TestResnet(unittest.TestCase):
model = resnet.ResNeXt50_32X4D()
model.load_from_pretrained()
def test_model_load_no_fc_layer(self):
model = resnet.ResNet18(num_classes=None)
model.load_from_pretrained()
model = resnet.ResNeXt50_32X4D(num_classes=None)
model.load_from_pretrained()
if __name__ == '__main__':
unittest.main()