fix enet init

Author: George Hotz
Date:   2020-12-06 13:52:07 -08:00
Parent: 3b982f2f7a
Commit: da514c2918

2 changed files with 10 additions and 10 deletions
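
This commit replaces the all-zeros initialization of every convolution and fully-connected weight in the EfficientNet example with Tensor.uniform, and makes BatchNorm2D start from weight (gamma) = 1 and running_var = 1. A minimal numpy sketch, not the tinygrad code, of why the zero init was broken: with W = 0 every unit produces the same zero output and receives the same gradient, so the units never differentiate; any small random init breaks that symmetry.

import numpy as np

x = np.random.randn(4, 8).astype(np.float32)   # a batch of 4 inputs with 8 features

W_zero = np.zeros((8, 16), dtype=np.float32)                                   # old Tensor.zeros-style weight
W_unif = (np.random.uniform(-1, 1, (8, 16)) / np.sqrt(8)).astype(np.float32)   # small random weight

print((x @ W_zero).std())   # 0.0 -> no signal reaches later layers, units stay identical
print((x @ W_unif).std())   # > 0 -> units start out different and can learn distinct features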


@@ -73,7 +73,7 @@ class MBConvBlock:
   def __init__(self, kernel_size, strides, expand_ratio, input_filters, output_filters, se_ratio, has_se):
     oup = expand_ratio * input_filters
     if expand_ratio != 1:
-      self._expand_conv = Tensor.zeros(oup, input_filters, 1, 1)
+      self._expand_conv = Tensor.uniform(oup, input_filters, 1, 1)
       self._bn0 = BatchNorm2D(oup)
     else:
       self._expand_conv = None
@@ -84,18 +84,18 @@ class MBConvBlock:
     else:
       self.pad = [(kernel_size-1)//2]*4
-    self._depthwise_conv = Tensor.zeros(oup, 1, kernel_size, kernel_size)
+    self._depthwise_conv = Tensor.uniform(oup, 1, kernel_size, kernel_size)
     self._bn1 = BatchNorm2D(oup)
     self.has_se = has_se
     if self.has_se:
       num_squeezed_channels = max(1, int(input_filters * se_ratio))
-      self._se_reduce = Tensor.zeros(num_squeezed_channels, oup, 1, 1)
+      self._se_reduce = Tensor.uniform(num_squeezed_channels, oup, 1, 1)
       self._se_reduce_bias = Tensor.zeros(num_squeezed_channels)
-      self._se_expand = Tensor.zeros(oup, num_squeezed_channels, 1, 1)
+      self._se_expand = Tensor.uniform(oup, num_squeezed_channels, 1, 1)
       self._se_expand_bias = Tensor.zeros(oup)
-    self._project_conv = Tensor.zeros(output_filters, oup, 1, 1)
+    self._project_conv = Tensor.uniform(output_filters, oup, 1, 1)
     self._bn2 = BatchNorm2D(output_filters)

   def __call__(self, inputs):
@@ -148,7 +148,7 @@ class EfficientNet:
       return int(math.ceil(global_params[1] * repeats))
     out_channels = round_filters(32)
-    self._conv_stem = Tensor.zeros(out_channels, 3, 3, 3)
+    self._conv_stem = Tensor.uniform(out_channels, 3, 3, 3)
     self._bn0 = BatchNorm2D(out_channels)
     blocks_args = [
       [1, 3, (1,1), 1, 32, 16, 0.25],
@@ -172,9 +172,9 @@ class EfficientNet:
     in_channels = round_filters(320)
     out_channels = round_filters(1280)
-    self._conv_head = Tensor.zeros(out_channels, in_channels, 1, 1)
+    self._conv_head = Tensor.uniform(out_channels, in_channels, 1, 1)
     self._bn1 = BatchNorm2D(out_channels)
-    self._fc = Tensor.zeros(out_channels, classes)
+    self._fc = Tensor.uniform(out_channels, classes)
     self._fc_bias = Tensor.zeros(classes)

   def forward(self, x):
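
Note that only the weight tensors above switch to Tensor.uniform; the bias tensors (_se_reduce_bias, _se_expand_bias, _fc_bias) stay at zeros, which is fine because biases do not suffer from the symmetry problem. The diff does not show how Tensor.uniform scales its samples, so the fan-in-scaled bound below is an assumption, and uniform_conv_weight is a hypothetical helper for illustration, not part of tinygrad.

import numpy as np

def uniform_conv_weight(out_ch, in_ch, kh, kw):
  # Kaiming/LeCun-style bound: keep output variance roughly independent of fan-in (assumption)
  fan_in = in_ch * kh * kw
  bound = 1.0 / np.sqrt(fan_in)
  return np.random.uniform(-bound, bound, (out_ch, in_ch, kh, kw)).astype(np.float32)

w = uniform_conv_weight(32, 3, 3, 3)   # same shape as self._conv_stem above
print(w.shape, float(w.min()), float(w.max()))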


@@ -3,12 +3,12 @@ from tinygrad.tensor import Tensor
 class BatchNorm2D:
   def __init__(self, sz, eps=0.001):
     self.eps = eps
-    self.weight = Tensor.zeros(sz)
+    self.weight = Tensor.ones(sz)
     self.bias = Tensor.zeros(sz)
     # TODO: need running_mean and running_var
     self.running_mean = Tensor.zeros(sz)
-    self.running_var = Tensor.zeros(sz)
+    self.running_var = Tensor.ones(sz)
     self.num_batches_tracked = Tensor.zeros(1, requires_grad=False)

   def __call__(self, x):
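
With weight (gamma) = 1, bias (beta) = 0, running_mean = 0 and running_var = 1, BatchNorm2D now starts out as (nearly) the identity transform, the same defaults PyTorch's BatchNorm2d uses. Leaving running_var at zeros would have made inference divide by sqrt(eps) ≈ 0.032, scaling activations up by roughly 30x. A plain-numpy sketch of the inference-time formula with these initial statistics (not tinygrad's BatchNorm2D.__call__):

import numpy as np

def bn_eval(x, gamma, beta, mean, var, eps=0.001):
  # inference-time batchnorm: normalize with running stats, then scale and shift
  return gamma * (x - mean) / np.sqrt(var + eps) + beta

x = np.random.randn(2, 4, 8, 8).astype(np.float32)   # NCHW input
shape = (1, x.shape[1], 1, 1)                         # broadcast per channel
gamma, beta = np.ones(shape), np.zeros(shape)         # self.weight, self.bias
mean, var = np.zeros(shape), np.ones(shape)           # running_mean, running_var

print(np.abs(bn_eval(x, gamma, beta, mean, var) - x).max())   # ~0: the layer starts as identity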