fix nll loss in example

George Hotz
2020-10-18 14:27:29 -07:00
parent 28c9d31e49
commit a139f34bb6


@@ -12,6 +12,7 @@ X_train, Y_train, X_test, Y_test = fetch_mnist()
 # train a model
+np.random.seed(1337)
 def layer_init(m, h):
   ret = np.random.uniform(-1., 1., size=(m,h))/np.sqrt(m*h)
   return ret.astype(np.float32)
@@ -27,7 +28,7 @@ class TinyBobNet:
 # optimizer
 model = TinyBobNet()
-optim = optim.SGD([model.l1, model.l2], lr=0.01)
+optim = optim.SGD([model.l1, model.l2], lr=0.001)
 #optim = optim.Adam([model.l1, model.l2], lr=0.001)
 BS = 128
@@ -38,7 +39,8 @@ for i in (t := trange(1000)):
   x = Tensor(X_train[samp].reshape((-1, 28*28)))
   Y = Y_train[samp]
   y = np.zeros((len(samp),10), np.float32)
-  y[range(y.shape[0]),Y] = -1.0
+  # correct loss for NLL, torch NLL loss returns one per row
+  y[range(y.shape[0]),Y] = -10.0
   y = Tensor(y)
   # network
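
Note on the change: a minimal sketch of why the target entry moves from -1.0 to -10.0, assuming the example computes the loss as out.mul(y).mean() over the (BS, 10) matrix of log-probabilities (that line is outside these hunks). .mean() divides by BS*10, so scaling the target by the number of classes (10) makes the result agree with torch's NLL loss, whose default mean reduction averages one value per row. Presumably the SGD learning rate drops from 0.01 to 0.001 for the same reason: the loss, and so the gradients, grow by a factor of 10. The names below (logits, labels, example_loss) are illustrative, not taken from the example.

import numpy as np
import torch
import torch.nn.functional as F

BS, num_classes = 128, 10
np.random.seed(1337)
logits = np.random.randn(BS, num_classes).astype(np.float32)
# log-softmax over the logits (the example's network is assumed to end in logsoftmax)
logprobs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
labels = np.random.randint(0, num_classes, size=BS)

# loss built the way the example appears to: -num_classes at the true class, 0 elsewhere
y = np.zeros((BS, num_classes), np.float32)
y[range(BS), labels] = -float(num_classes)   # the -10.0 from this commit
example_loss = (logprobs * y).mean()

# torch's NLL loss with the default mean reduction returns one value per row, averaged
torch_loss = F.nll_loss(torch.tensor(logprobs), torch.tensor(labels)).item()
assert abs(example_loss - torch_loss) < 1e-4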