From 4eabe677ed7944682ab82102fb52067a0a96ee9c Mon Sep 17 00:00:00 2001 From: George Hotz Date: Sat, 15 Jan 2022 20:21:02 -0800 Subject: [PATCH] fixup resnet --- examples/train_resnet.py | 2 +- models/resnet.py | 20 ++++++++++++-------- 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/examples/train_resnet.py b/examples/train_resnet.py index 8f8a25eff4..a73a98e40d 100755 --- a/examples/train_resnet.py +++ b/examples/train_resnet.py @@ -7,7 +7,7 @@ from PIL import Image from tinygrad.tensor import Device from extra.utils import get_parameters from extra.training import train, evaluate -from models.resnet import ResNet18, ResNet34, ResNet50 +from models.resnet import ResNet18, ResNet34, ResNet50, ResNet101, ResNet152 from tinygrad.optim import Adam from test.test_mnist import fetch_mnist diff --git a/models/resnet.py b/models/resnet.py index b023c88542..9eebf5e42a 100644 --- a/models/resnet.py +++ b/models/resnet.py @@ -52,7 +52,7 @@ class Bottleneck: return out class ResNet: - def __init__(self, block, num_blocks, num_classes=10, url=None): + def __init__(self, block, num_blocks, num_classes=10, url=None, pretrained=False): self.url = url self.in_planes = 64 @@ -64,6 +64,9 @@ class ResNet: self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) self.fc = {"weight": Tensor.uniform(512 * block.expansion, num_classes), "bias": Tensor.zeros(num_classes)} + if pretrained: + self.load_from_pretrained() + def _make_layer(self, block, planes, num_blocks, stride): strides = [stride] + [1] * (num_blocks-1) layers = [] @@ -92,12 +95,13 @@ class ResNet: for k, v in state_dict.items(): obj = get_child(self, k) dat = v.detach().numpy().T if "fc.weight" in k else v.detach().numpy() - assert obj.shape == dat.shape - obj.assign(dat) + assert obj.shape == dat.shape or k.startswith("fc.") + if obj.shape == dat.shape: + obj.assign(dat) -ResNet18 = lambda: ResNet(BasicBlock, [2,2,2,2], 1000, 'https://download.pytorch.org/models/resnet18-5c106cde.pth') -ResNet34 = lambda: ResNet(BasicBlock, [3,4,6,3], 1000, 'https://download.pytorch.org/models/resnet34-333f7ec4.pth') -ResNet50 = lambda: ResNet(Bottleneck, [3,4,6,3], 1000, 'https://download.pytorch.org/models/resnet50-19c8e357.pth') -ResNet101 = lambda: ResNet(Bottleneck, [3,4,23,3], 1000, 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth') -ResNet101 = lambda: ResNet(Bottleneck, [3,8,36,3], 1000, 'https://download.pytorch.org/models/resnet152-b121ed2d.pth') +ResNet18 = lambda num_classes=1000, pretrained=False: ResNet(BasicBlock, [2,2,2,2], num_classes, 'https://download.pytorch.org/models/resnet18-5c106cde.pth', pretrained=pretrained) +ResNet34 = lambda num_classes=1000, pretrained=False: ResNet(BasicBlock, [3,4,6,3], num_classes, 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', pretrained=pretrained) +ResNet50 = lambda num_classes=1000, pretrained=False: ResNet(Bottleneck, [3,4,6,3], num_classes, 'https://download.pytorch.org/models/resnet50-19c8e357.pth', pretrained=pretrained) +ResNet101 = lambda num_classes=1000, pretrained=False: ResNet(Bottleneck, [3,4,23,3], num_classes, 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', pretrained=pretrained) +ResNet152 = lambda num_classes=1000, pretrained=False: ResNet(Bottleneck, [3,8,36,3], num_classes, 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', pretrained=pretrained)