Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-07 22:23:55 -05:00.
fix nll loss in example
This commit is contained in:
@@ -12,6 +12,7 @@ X_train, Y_train, X_test, Y_test = fetch_mnist()
|
||||
|
||||
# train a model
|
||||
|
||||
np.random.seed(1337)
|
||||
def layer_init(m, h):
  # Draw an (m, h) weight matrix uniformly from [-1, 1), scaled by
  # 1/sqrt(m*h) so the initial activations stay small, as float32.
  weights = np.random.uniform(-1., 1., size=(m, h)) / np.sqrt(m * h)
  return weights.astype(np.float32)
|
||||
@@ -27,7 +28,7 @@ class TinyBobNet:
|
||||
# optimizer
|
||||
|
||||
model = TinyBobNet()
|
||||
- optim = optim.SGD([model.l1, model.l2], lr=0.01)
+ optim = optim.SGD([model.l1, model.l2], lr=0.001)
|
||||
#optim = optim.Adam([model.l1, model.l2], lr=0.001)
|
||||
|
||||
BS = 128
|
||||
@@ -38,7 +39,8 @@ for i in (t := trange(1000)):
|
||||
x = Tensor(X_train[samp].reshape((-1, 28*28)))
|
||||
Y = Y_train[samp]
|
||||
y = np.zeros((len(samp),10), np.float32)
|
||||
- y[range(y.shape[0]),Y] = -1.0
+ # correct loss for NLL, torch NLL loss returns one per row
+ y[range(y.shape[0]),Y] = -10.0
|
||||
y = Tensor(y)
|
||||
|
||||
# network
|
||||
|
||||
Reference in New Issue
Block a user