mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
fix nll loss in example
This commit is contained in:
@@ -12,6 +12,7 @@ X_train, Y_train, X_test, Y_test = fetch_mnist()
|
|||||||
|
|
||||||
# train a model
|
# train a model
|
||||||
|
|
||||||
|
np.random.seed(1337)
|
||||||
def layer_init(m, h):
|
def layer_init(m, h):
|
||||||
ret = np.random.uniform(-1., 1., size=(m,h))/np.sqrt(m*h)
|
ret = np.random.uniform(-1., 1., size=(m,h))/np.sqrt(m*h)
|
||||||
return ret.astype(np.float32)
|
return ret.astype(np.float32)
|
||||||
@@ -27,7 +28,7 @@ class TinyBobNet:
|
|||||||
# optimizer
|
# optimizer
|
||||||
|
|
||||||
model = TinyBobNet()
|
model = TinyBobNet()
|
||||||
optim = optim.SGD([model.l1, model.l2], lr=0.01)
|
optim = optim.SGD([model.l1, model.l2], lr=0.001)
|
||||||
#optim = optim.Adam([model.l1, model.l2], lr=0.001)
|
#optim = optim.Adam([model.l1, model.l2], lr=0.001)
|
||||||
|
|
||||||
BS = 128
|
BS = 128
|
||||||
@@ -38,7 +39,8 @@ for i in (t := trange(1000)):
|
|||||||
x = Tensor(X_train[samp].reshape((-1, 28*28)))
|
x = Tensor(X_train[samp].reshape((-1, 28*28)))
|
||||||
Y = Y_train[samp]
|
Y = Y_train[samp]
|
||||||
y = np.zeros((len(samp),10), np.float32)
|
y = np.zeros((len(samp),10), np.float32)
|
||||||
y[range(y.shape[0]),Y] = -1.0
|
# correct loss for NLL, torch NLL loss returns one per row
|
||||||
|
y[range(y.shape[0]),Y] = -10.0
|
||||||
y = Tensor(y)
|
y = Tensor(y)
|
||||||
|
|
||||||
# network
|
# network
|
||||||
|
|||||||
Reference in New Issue
Block a user