mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-07 22:23:55 -05:00
Fix FC layer ResNet load_from_pretrained error (#8387)
* validate that FC exists before loading pretrained weights * add test case for ResNet pretrained model without FC layer * remove extra newline * rename test case * reraise exception if not handled by check
This commit is contained in:
@@ -137,7 +137,13 @@ class ResNet:
|
||||
|
||||
self.url = model_urls[(self.num, self.groups, self.base_width)]
|
||||
for k, dat in torch_load(fetch(self.url)).items():
|
||||
obj: Tensor = get_child(self, k)
|
||||
try:
|
||||
obj: Tensor = get_child(self, k)
|
||||
except AttributeError as e:
|
||||
if 'fc.' in k and self.fc is None:
|
||||
continue
|
||||
|
||||
raise e
|
||||
|
||||
if 'fc.' in k and obj.shape != dat.shape:
|
||||
print("skipping fully connected layer")
|
||||
|
||||
@@ -9,6 +9,13 @@ class TestResnet(unittest.TestCase):
|
||||
model = resnet.ResNeXt50_32X4D()
|
||||
model.load_from_pretrained()
|
||||
|
||||
def test_model_load_no_fc_layer(self):
|
||||
model = resnet.ResNet18(num_classes=None)
|
||||
model.load_from_pretrained()
|
||||
|
||||
model = resnet.ResNeXt50_32X4D(num_classes=None)
|
||||
model.load_from_pretrained()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user