mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 23:18:04 -05:00
fixup resnet
This commit is contained in:
@@ -7,7 +7,7 @@ from PIL import Image
|
||||
from tinygrad.tensor import Device
|
||||
from extra.utils import get_parameters
|
||||
from extra.training import train, evaluate
|
||||
from models.resnet import ResNet18, ResNet34, ResNet50
|
||||
from models.resnet import ResNet18, ResNet34, ResNet50, ResNet101, ResNet152
|
||||
from tinygrad.optim import Adam
|
||||
from test.test_mnist import fetch_mnist
|
||||
|
||||
|
||||
@@ -52,7 +52,7 @@ class Bottleneck:
|
||||
return out
|
||||
|
||||
class ResNet:
|
||||
def __init__(self, block, num_blocks, num_classes=10, url=None):
|
||||
def __init__(self, block, num_blocks, num_classes=10, url=None, pretrained=False):
|
||||
self.url = url
|
||||
self.in_planes = 64
|
||||
|
||||
@@ -64,6 +64,9 @@ class ResNet:
|
||||
self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
|
||||
self.fc = {"weight": Tensor.uniform(512 * block.expansion, num_classes), "bias": Tensor.zeros(num_classes)}
|
||||
|
||||
if pretrained:
|
||||
self.load_from_pretrained()
|
||||
|
||||
def _make_layer(self, block, planes, num_blocks, stride):
|
||||
strides = [stride] + [1] * (num_blocks-1)
|
||||
layers = []
|
||||
@@ -92,12 +95,13 @@ class ResNet:
|
||||
for k, v in state_dict.items():
|
||||
obj = get_child(self, k)
|
||||
dat = v.detach().numpy().T if "fc.weight" in k else v.detach().numpy()
|
||||
assert obj.shape == dat.shape
|
||||
obj.assign(dat)
|
||||
assert obj.shape == dat.shape or k.startswith("fc.")
|
||||
if obj.shape == dat.shape:
|
||||
obj.assign(dat)
|
||||
|
||||
ResNet18 = lambda: ResNet(BasicBlock, [2,2,2,2], 1000, 'https://download.pytorch.org/models/resnet18-5c106cde.pth')
|
||||
ResNet34 = lambda: ResNet(BasicBlock, [3,4,6,3], 1000, 'https://download.pytorch.org/models/resnet34-333f7ec4.pth')
|
||||
ResNet50 = lambda: ResNet(Bottleneck, [3,4,6,3], 1000, 'https://download.pytorch.org/models/resnet50-19c8e357.pth')
|
||||
ResNet101 = lambda: ResNet(Bottleneck, [3,4,23,3], 1000, 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth')
|
||||
ResNet101 = lambda: ResNet(Bottleneck, [3,8,36,3], 1000, 'https://download.pytorch.org/models/resnet152-b121ed2d.pth')
|
||||
ResNet18 = lambda num_classes=1000, pretrained=False: ResNet(BasicBlock, [2,2,2,2], num_classes, 'https://download.pytorch.org/models/resnet18-5c106cde.pth', pretrained=pretrained)
|
||||
ResNet34 = lambda num_classes=1000, pretrained=False: ResNet(BasicBlock, [3,4,6,3], num_classes, 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', pretrained=pretrained)
|
||||
ResNet50 = lambda num_classes=1000, pretrained=False: ResNet(Bottleneck, [3,4,6,3], num_classes, 'https://download.pytorch.org/models/resnet50-19c8e357.pth', pretrained=pretrained)
|
||||
ResNet101 = lambda num_classes=1000, pretrained=False: ResNet(Bottleneck, [3,4,23,3], num_classes, 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', pretrained=pretrained)
|
||||
ResNet152 = lambda num_classes=1000, pretrained=False: ResNet(Bottleneck, [3,8,36,3], num_classes, 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', pretrained=pretrained)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user