mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
make training work
This commit is contained in:
@@ -16,7 +16,7 @@ def sigmoid_focal_loss(pred:Tensor, tgt:Tensor, alpha:float = 0.25, gamma:float
|
||||
|
||||
if alpha >= 0:
|
||||
alpha_t = alpha * tgt + (1 - alpha) * (1 - tgt)
|
||||
loss *= alpha_t
|
||||
loss = loss * alpha_t
|
||||
|
||||
if reduction == "mean": loss = loss.mean()
|
||||
elif reduction == "sum": loss = loss.sum()
|
||||
|
||||
@@ -408,7 +408,7 @@ def train_retinanet():
|
||||
resnet.BatchNorm = FrozenBatchNorm2d
|
||||
|
||||
# ** model setup **
|
||||
backbone = resnet.ResNeXt50_32X4D(num_classes=NUM_CLASSES)
|
||||
backbone = resnet.ResNeXt50_32X4D(num_classes=None)
|
||||
loaded_keys = backbone.load_from_pretrained()
|
||||
_freeze_backbone_layers(backbone, 3, loaded_keys)
|
||||
|
||||
|
||||
@@ -138,6 +138,9 @@ class ResNet:
|
||||
self.url = model_urls[(self.num, self.groups, self.base_width)]
|
||||
loaded_keys = []
|
||||
for k, dat in torch_load(fetch(self.url)).items():
|
||||
if 'fc.' in k and self.fc is None:
|
||||
continue
|
||||
|
||||
obj: Tensor = get_child(self, k)
|
||||
|
||||
if 'fc.' in k and obj.shape != dat.shape:
|
||||
|
||||
Reference in New Issue
Block a user