mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
fix train script
This commit is contained in:
@@ -423,8 +423,8 @@ def train_retinanet():
|
||||
|
||||
# ** model setup **
|
||||
backbone = resnet.ResNeXt50_32X4D(num_classes=None)
|
||||
loaded_keys = backbone.load_from_pretrained()
|
||||
_freeze_backbone_layers(backbone, 3, loaded_keys)
|
||||
backbone.load_from_pretrained()
|
||||
_freeze_backbone_layers(backbone, 3)
|
||||
|
||||
model = RetinaNet(backbone, num_classes=NUM_CLASSES)
|
||||
params = get_parameters(model)
|
||||
|
||||
Reference in New Issue
Block a user