trainer works with CIFAR

This commit is contained in:
George Hotz
2020-12-06 12:20:14 -08:00
parent 80a9c777ba
commit 609d11e699
2 changed files with 58 additions and 20 deletions

View File

@@ -116,7 +116,7 @@ class MBConvBlock:
return x
class EfficientNet:
def __init__(self, number=0):
def __init__(self, number=0, classes=1000):
self.number = number
global_params = [
# width, depth
@@ -171,8 +171,8 @@ class EfficientNet:
out_channels = round_filters(1280)
self._conv_head = Tensor.zeros(out_channels, in_channels, 1, 1)
self._bn1 = BatchNorm2D(out_channels)
self._fc = Tensor.zeros(out_channels, 1000)
self._fc_bias = Tensor.zeros(1000)
self._fc = Tensor.zeros(out_channels, classes)
self._fc_bias = Tensor.zeros(classes)
def forward(self, x):
x = x.pad2d(padding=(0,1,0,1))