fixup resnet

This commit is contained in:
George Hotz
2022-01-15 20:21:02 -08:00
parent e0bef0bd01
commit 4eabe677ed
2 changed files with 13 additions and 9 deletions

View File

@@ -7,7 +7,7 @@ from PIL import Image
from tinygrad.tensor import Device
from extra.utils import get_parameters
from extra.training import train, evaluate
from models.resnet import ResNet18, ResNet34, ResNet50
from models.resnet import ResNet18, ResNet34, ResNet50, ResNet101, ResNet152
from tinygrad.optim import Adam
from test.test_mnist import fetch_mnist

View File

@@ -52,7 +52,7 @@ class Bottleneck:
return out
class ResNet:
def __init__(self, block, num_blocks, num_classes=10, url=None):
def __init__(self, block, num_blocks, num_classes=10, url=None, pretrained=False):
self.url = url
self.in_planes = 64
@@ -64,6 +64,9 @@ class ResNet:
self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
self.fc = {"weight": Tensor.uniform(512 * block.expansion, num_classes), "bias": Tensor.zeros(num_classes)}
if pretrained:
self.load_from_pretrained()
def _make_layer(self, block, planes, num_blocks, stride):
strides = [stride] + [1] * (num_blocks-1)
layers = []
@@ -92,12 +95,13 @@ class ResNet:
for k, v in state_dict.items():
obj = get_child(self, k)
dat = v.detach().numpy().T if "fc.weight" in k else v.detach().numpy()
assert obj.shape == dat.shape
obj.assign(dat)
assert obj.shape == dat.shape or k.startswith("fc.")
if obj.shape == dat.shape:
obj.assign(dat)
ResNet18 = lambda: ResNet(BasicBlock, [2,2,2,2], 1000, 'https://download.pytorch.org/models/resnet18-5c106cde.pth')
ResNet34 = lambda: ResNet(BasicBlock, [3,4,6,3], 1000, 'https://download.pytorch.org/models/resnet34-333f7ec4.pth')
ResNet50 = lambda: ResNet(Bottleneck, [3,4,6,3], 1000, 'https://download.pytorch.org/models/resnet50-19c8e357.pth')
ResNet101 = lambda: ResNet(Bottleneck, [3,4,23,3], 1000, 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth')
ResNet101 = lambda: ResNet(Bottleneck, [3,8,36,3], 1000, 'https://download.pytorch.org/models/resnet152-b121ed2d.pth')
ResNet18 = lambda num_classes=1000, pretrained=False: ResNet(BasicBlock, [2,2,2,2], num_classes, 'https://download.pytorch.org/models/resnet18-5c106cde.pth', pretrained=pretrained)
ResNet34 = lambda num_classes=1000, pretrained=False: ResNet(BasicBlock, [3,4,6,3], num_classes, 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', pretrained=pretrained)
ResNet50 = lambda num_classes=1000, pretrained=False: ResNet(Bottleneck, [3,4,6,3], num_classes, 'https://download.pytorch.org/models/resnet50-19c8e357.pth', pretrained=pretrained)
ResNet101 = lambda num_classes=1000, pretrained=False: ResNet(Bottleneck, [3,4,23,3], num_classes, 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', pretrained=pretrained)
ResNet152 = lambda num_classes=1000, pretrained=False: ResNet(Bottleneck, [3,8,36,3], num_classes, 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', pretrained=pretrained)