fix train script

This commit is contained in:
Francis Lata
2025-02-19 20:43:02 +00:00
parent fc36f09b1e
commit 7dba815c47

View File

@@ -423,8 +423,8 @@ def train_retinanet():
# ** model setup **
backbone = resnet.ResNeXt50_32X4D(num_classes=None)
loaded_keys = backbone.load_from_pretrained()
_freeze_backbone_layers(backbone, 3, loaded_keys)
backbone.load_from_pretrained()
_freeze_backbone_layers(backbone, 3)
model = RetinaNet(backbone, num_classes=NUM_CLASSES)
params = get_parameters(model)