mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 22:48:25 -05:00
trainer works with CIFAR
This commit is contained in:
@@ -116,7 +116,7 @@ class MBConvBlock:
|
||||
return x
|
||||
|
||||
class EfficientNet:
|
||||
def __init__(self, number=0):
|
||||
def __init__(self, number=0, classes=1000):
|
||||
self.number = number
|
||||
global_params = [
|
||||
# width, depth
|
||||
@@ -171,8 +171,8 @@ class EfficientNet:
|
||||
out_channels = round_filters(1280)
|
||||
self._conv_head = Tensor.zeros(out_channels, in_channels, 1, 1)
|
||||
self._bn1 = BatchNorm2D(out_channels)
|
||||
self._fc = Tensor.zeros(out_channels, 1000)
|
||||
self._fc_bias = Tensor.zeros(1000)
|
||||
self._fc = Tensor.zeros(out_channels, classes)
|
||||
self._fc_bias = Tensor.zeros(classes)
|
||||
|
||||
def forward(self, x):
|
||||
x = x.pad2d(padding=(0,1,0,1))
|
||||
|
||||
Reference in New Issue
Block a user