mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
support all the enet sizes
This commit is contained in:
@@ -5,6 +5,7 @@
|
||||
import os
|
||||
GPU = os.getenv("GPU", None) is not None
|
||||
import sys
|
||||
import math
|
||||
import io
|
||||
import time
|
||||
import numpy as np
|
||||
@@ -63,7 +64,34 @@ class MBConvBlock:
|
||||
return x
|
||||
|
||||
class EfficientNet:
|
||||
def __init__(self):
|
||||
def __init__(self, number=0):
|
||||
self.number = number
|
||||
global_params = [
|
||||
# width, depth
|
||||
(1.0, 1.0), # b0
|
||||
(1.0, 1.1), # b1
|
||||
(1.1, 1.2), # b2
|
||||
(1.2, 1.4), # b3
|
||||
(1.4, 1.8), # b4
|
||||
(1.6, 2.2), # b5
|
||||
(1.8, 2.6), # b6
|
||||
(2.0, 3.1), # b7
|
||||
(2.2, 3.6), # b8
|
||||
(4.3, 5.3), # l2
|
||||
][number]
|
||||
|
||||
def round_filters(filters):
|
||||
multiplier = global_params[0]
|
||||
divisor = 8
|
||||
filters *= multiplier
|
||||
new_filters = max(divisor, int(filters + divisor / 2) // divisor * divisor)
|
||||
if new_filters < 0.9 * filters: # prevent rounding by more than 10%
|
||||
new_filters += divisor
|
||||
return int(new_filters)
|
||||
|
||||
def round_repeats(repeats):
|
||||
return int(math.ceil(global_params[1] * repeats))
|
||||
|
||||
self._conv_stem = Tensor.zeros(32, 3, 3, 3)
|
||||
self._bn0 = BatchNorm2D(32)
|
||||
blocks_args = [
|
||||
@@ -79,13 +107,18 @@ class EfficientNet:
|
||||
# num_repeats, kernel_size, strides, expand_ratio, input_filters, output_filters, se_ratio
|
||||
for b in blocks_args:
|
||||
args = b[1:]
|
||||
for n in range(b[0]):
|
||||
args[3] = round_filters(args[3])
|
||||
args[4] = round_filters(args[4])
|
||||
for n in range(round_repeats(b[0])):
|
||||
self._blocks.append(MBConvBlock(*args))
|
||||
args[3] = args[4]
|
||||
args[1] = (1,1)
|
||||
self._conv_head = Tensor.zeros(1280, 320, 1, 1)
|
||||
self._bn1 = BatchNorm2D(1280)
|
||||
self._fc = Tensor.zeros(1280, 1000)
|
||||
|
||||
in_channels = round_filters(320)
|
||||
out_channels = round_filters(1280)
|
||||
self._conv_head = Tensor.zeros(out_channels, in_channels, 1, 1)
|
||||
self._bn1 = BatchNorm2D(out_channels)
|
||||
self._fc = Tensor.zeros(out_channels, 1000)
|
||||
self._fc_bias = Tensor.zeros(1000)
|
||||
|
||||
def forward(self, x):
|
||||
@@ -96,14 +129,19 @@ class EfficientNet:
|
||||
x = block(x)
|
||||
x = swish(self._bn1(x.conv2d(self._conv_head)))
|
||||
x = x.avg_pool2d(kernel_size=x.shape[2:4])
|
||||
x = x.reshape(shape=(-1, 1280))
|
||||
x = x.reshape(shape=(-1, x.shape[1]))
|
||||
#x = x.dropout(0.2)
|
||||
return x.dot(self._fc).add(self._fc_bias.reshape(shape=[1,-1]))
|
||||
|
||||
def load_weights_from_torch(self):
|
||||
# load b0
|
||||
import torch
|
||||
b0 = fetch("https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth")
|
||||
if self.number == 0:
|
||||
b0 = fetch("https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth")
|
||||
elif self.number == 2:
|
||||
b0 = fetch("https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b2-8bb594d6.pth")
|
||||
else:
|
||||
raise Exception("no pretrained weights")
|
||||
b0 = torch.load(io.BytesIO(b0))
|
||||
|
||||
for k,v in b0.items():
|
||||
@@ -162,7 +200,7 @@ def infer(model, img):
|
||||
|
||||
if __name__ == "__main__":
|
||||
# instantiate my net
|
||||
model = EfficientNet()
|
||||
model = EfficientNet(0)
|
||||
model.load_weights_from_torch()
|
||||
|
||||
# category labels
|
||||
|
||||
Reference in New Issue
Block a user