mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 22:48:25 -05:00
if you wait 24 seconds, that gets 98%
This commit is contained in:
@@ -25,14 +25,15 @@ class TinyBobNet:
|
||||
class TinyConvNet:
|
||||
def __init__(self):
|
||||
conv = 7
|
||||
chans = 4
|
||||
chans = 16
|
||||
self.c1 = Tensor(layer_init_uniform(chans,1,conv,conv))
|
||||
self.l1 = Tensor(layer_init_uniform(((28-conv+1)**2)*chans, 128))
|
||||
self.l2 = Tensor(layer_init_uniform(128, 10))
|
||||
|
||||
def forward(self, x):
|
||||
x.data = x.data.reshape((-1, 1, 28, 28)) # hacks
|
||||
x = x.conv2d(self.c1).reshape(Tensor(np.array((x.shape[0], -1)))).relu()
|
||||
x = x.conv2d(self.c1).relu()
|
||||
x = x.reshape(Tensor(np.array((x.shape[0], -1))))
|
||||
return x.dot(self.l1).relu().dot(self.l2).logsoftmax()
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user