make training work

This commit is contained in:
Francis Lata
2024-12-23 21:48:55 +00:00
parent 96a7d1d442
commit c1a18e13ef
3 changed files with 5 additions and 2 deletions

View File

@@ -16,7 +16,7 @@ def sigmoid_focal_loss(pred:Tensor, tgt:Tensor, alpha:float = 0.25, gamma:float
if alpha >= 0:
alpha_t = alpha * tgt + (1 - alpha) * (1 - tgt)
-loss *= alpha_t
+loss = loss * alpha_t
if reduction == "mean": loss = loss.mean()
elif reduction == "sum": loss = loss.sum()

View File

@@ -408,7 +408,7 @@ def train_retinanet():
resnet.BatchNorm = FrozenBatchNorm2d
# ** model setup **
-backbone = resnet.ResNeXt50_32X4D(num_classes=NUM_CLASSES)
+backbone = resnet.ResNeXt50_32X4D(num_classes=None)
loaded_keys = backbone.load_from_pretrained()
_freeze_backbone_layers(backbone, 3, loaded_keys)

View File

@@ -138,6 +138,9 @@ class ResNet:
self.url = model_urls[(self.num, self.groups, self.base_width)]
loaded_keys = []
for k, dat in torch_load(fetch(self.url)).items():
+if 'fc.' in k and self.fc is None:
+  continue
obj: Tensor = get_child(self, k)
if 'fc.' in k and obj.shape != dat.shape: