support all the enet sizes

This commit is contained in:
George Hotz
2020-11-09 18:04:16 -08:00
parent 9db95ab942
commit 8b23033fa9

View File

@@ -5,6 +5,7 @@
import os
GPU = os.getenv("GPU", None) is not None
import sys
import math
import io
import time
import numpy as np
@@ -63,7 +64,34 @@ class MBConvBlock:
return x
class EfficientNet:
def __init__(self):
def __init__(self, number=0):
self.number = number
global_params = [
# width, depth
(1.0, 1.0), # b0
(1.0, 1.1), # b1
(1.1, 1.2), # b2
(1.2, 1.4), # b3
(1.4, 1.8), # b4
(1.6, 2.2), # b5
(1.8, 2.6), # b6
(2.0, 3.1), # b7
(2.2, 3.6), # b8
(4.3, 5.3), # l2
][number]
def round_filters(filters):
multiplier = global_params[0]
divisor = 8
filters *= multiplier
new_filters = max(divisor, int(filters + divisor / 2) // divisor * divisor)
if new_filters < 0.9 * filters: # prevent rounding by more than 10%
new_filters += divisor
return int(new_filters)
def round_repeats(repeats):
return int(math.ceil(global_params[1] * repeats))
self._conv_stem = Tensor.zeros(32, 3, 3, 3)
self._bn0 = BatchNorm2D(32)
blocks_args = [
@@ -79,13 +107,18 @@ class EfficientNet:
# num_repeats, kernel_size, strides, expand_ratio, input_filters, output_filters, se_ratio
for b in blocks_args:
args = b[1:]
for n in range(b[0]):
args[3] = round_filters(args[3])
args[4] = round_filters(args[4])
for n in range(round_repeats(b[0])):
self._blocks.append(MBConvBlock(*args))
args[3] = args[4]
args[1] = (1,1)
self._conv_head = Tensor.zeros(1280, 320, 1, 1)
self._bn1 = BatchNorm2D(1280)
self._fc = Tensor.zeros(1280, 1000)
in_channels = round_filters(320)
out_channels = round_filters(1280)
self._conv_head = Tensor.zeros(out_channels, in_channels, 1, 1)
self._bn1 = BatchNorm2D(out_channels)
self._fc = Tensor.zeros(out_channels, 1000)
self._fc_bias = Tensor.zeros(1000)
def forward(self, x):
@@ -96,14 +129,19 @@ class EfficientNet:
x = block(x)
x = swish(self._bn1(x.conv2d(self._conv_head)))
x = x.avg_pool2d(kernel_size=x.shape[2:4])
x = x.reshape(shape=(-1, 1280))
x = x.reshape(shape=(-1, x.shape[1]))
#x = x.dropout(0.2)
return x.dot(self._fc).add(self._fc_bias.reshape(shape=[1,-1]))
def load_weights_from_torch(self):
# load b0
import torch
b0 = fetch("https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth")
if self.number == 0:
b0 = fetch("https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth")
elif self.number == 2:
b0 = fetch("https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b2-8bb594d6.pth")
else:
raise Exception("no pretrained weights")
b0 = torch.load(io.BytesIO(b0))
for k,v in b0.items():
@@ -162,7 +200,7 @@ def infer(model, img):
if __name__ == "__main__":
# instantiate my net
model = EfficientNet()
model = EfficientNet(0)
model.load_weights_from_torch()
# category labels