From c1a18e13ef65381502cbf4c39bb71ee13cddd102 Mon Sep 17 00:00:00 2001 From: Francis Lata Date: Mon, 23 Dec 2024 21:48:55 +0000 Subject: [PATCH] make training work --- examples/mlperf/losses.py | 2 +- examples/mlperf/model_train.py | 2 +- extra/models/resnet.py | 3 +++ 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/examples/mlperf/losses.py b/examples/mlperf/losses.py index b5e0917275..136071bdf8 100644 --- a/examples/mlperf/losses.py +++ b/examples/mlperf/losses.py @@ -16,7 +16,7 @@ def sigmoid_focal_loss(pred:Tensor, tgt:Tensor, alpha:float = 0.25, gamma:float if alpha >= 0: alpha_t = alpha * tgt + (1 - alpha) * (1 - tgt) - loss *= alpha_t + loss = loss * alpha_t if reduction == "mean": loss = loss.mean() elif reduction == "sum": loss = loss.sum() diff --git a/examples/mlperf/model_train.py b/examples/mlperf/model_train.py index 5426e3c710..de456c3f9c 100644 --- a/examples/mlperf/model_train.py +++ b/examples/mlperf/model_train.py @@ -408,7 +408,7 @@ def train_retinanet(): resnet.BatchNorm = FrozenBatchNorm2d # ** model setup ** - backbone = resnet.ResNeXt50_32X4D(num_classes=NUM_CLASSES) + backbone = resnet.ResNeXt50_32X4D(num_classes=None) loaded_keys = backbone.load_from_pretrained() _freeze_backbone_layers(backbone, 3, loaded_keys) diff --git a/extra/models/resnet.py b/extra/models/resnet.py index 60b7cdedcb..b29654a276 100644 --- a/extra/models/resnet.py +++ b/extra/models/resnet.py @@ -138,6 +138,9 @@ class ResNet: self.url = model_urls[(self.num, self.groups, self.base_width)] loaded_keys = [] for k, dat in torch_load(fetch(self.url)).items(): + if 'fc.' in k and self.fc is None: + continue + obj: Tensor = get_child(self, k) if 'fc.' in k and obj.shape != dat.shape: