mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 06:58:11 -05:00
write sgd class
This commit is contained in:
@@ -35,9 +35,20 @@ class TinyBobNet:
|
||||
def forward(self, x):
|
||||
return x.dot(self.l1).relu().dot(self.l2).logsoftmax()
|
||||
|
||||
model = TinyBobNet()
|
||||
# optimizer
|
||||
|
||||
class SGD:
    """Plain gradient-descent optimizer over a fixed collection of tensors.

    Each tensor handed to the optimizer is expected to expose a ``data``
    attribute (the current values) and a ``grad`` attribute (the gradient
    used for the update).
    """

    def __init__(self, tensors, lr):
        """Remember which tensors to update and the learning rate to use.

        tensors: iterable of objects exposing ``data`` and ``grad``.
        lr: scalar multiplier applied to every gradient in ``step``.
        """
        self.lr = lr
        self.tensors = tensors

    def step(self):
        """Subtract ``lr * grad`` from the ``data`` of every tracked tensor."""
        step_size = self.lr
        for param in self.tensors:
            # Augmented assignment on purpose: array-like ``data`` is
            # updated in place rather than rebound to a new object.
            param.data -= step_size * param.grad
|
||||
|
||||
model = TinyBobNet()
|
||||
optim = SGD([model.l1, model.l2], lr=0.01)
|
||||
|
||||
lr = 0.01
|
||||
BS = 128
|
||||
losses, accuracies = [], []
|
||||
for i in (t := trange(1000)):
|
||||
@@ -55,13 +66,11 @@ for i in (t := trange(1000)):
|
||||
# NLL loss function
|
||||
loss = outs.mul(y).mean()
|
||||
loss.backward()
|
||||
optim.step()
|
||||
|
||||
cat = np.argmax(outs.data, axis=1)
|
||||
accuracy = (cat == Y).mean()
|
||||
|
||||
# SGD
|
||||
model.l1.data = model.l1.data - lr*model.l1.grad
|
||||
model.l2.data = model.l2.data - lr*model.l2.grad
|
||||
|
||||
# printing
|
||||
loss = loss.data
|
||||
|
||||
@@ -37,7 +37,8 @@ class Tensor:
|
||||
grads = [grads]
|
||||
for t,g in zip(self._ctx.parents, grads):
|
||||
if g.shape != t.data.shape:
|
||||
print("grad shape must match tensor shape in %r, %r != %r" % (self._ctx, g.shape, t.data.shape))
|
||||
print("grad shape must match tensor shape in %r, %r != %r" %
|
||||
(self._ctx, g.shape, t.data.shape))
|
||||
assert(False)
|
||||
t.grad = g
|
||||
t.backward(False)
|
||||
@@ -46,7 +47,7 @@ class Tensor:
|
||||
div = Tensor(np.array([1/self.data.size]))
|
||||
return self.sum().mul(div)
|
||||
|
||||
# The Function is the Context
|
||||
# An instantiation of the Function is the Context
|
||||
class Function:
|
||||
def __init__(self, *tensors):
|
||||
self.parents = tensors
|
||||
|
||||
Reference in New Issue
Block a user