write sgd class

This commit is contained in:
George Hotz
2020-10-18 13:27:59 -07:00
parent f4e0cb5945
commit 118c2eebe3
2 changed files with 17 additions and 7 deletions

View File

@@ -35,9 +35,20 @@ class TinyBobNet:
def forward(self, x):
    """Forward pass: x.dot(l1) -> relu -> .dot(l2) -> logsoftmax.

    Assumes `x` is a Tensor compatible with `self.l1` for dot — shapes
    not visible here; confirm against the model's layer initialization.
    """
    hidden = x.dot(self.l1).relu()
    return hidden.dot(self.l2).logsoftmax()
model = TinyBobNet()
# optimizer
class SGD:
    """Vanilla stochastic gradient descent.

    Holds a list of parameter tensors and a fixed learning rate; each
    `step()` applies the update `data -= lr * grad` in place.  Assumes
    every tensor already has a populated `.grad` (i.e. `backward()` was
    called) — there is no None-guard and no `zero_grad` here.
    """

    def __init__(self, tensors, lr):
        # parameters to optimize and the (constant) step size
        self.tensors = tensors
        self.lr = lr

    def step(self):
        # in-place update on each parameter's underlying array
        for param in self.tensors:
            param.data -= self.lr * param.grad
model = TinyBobNet()
optim = SGD([model.l1, model.l2], lr=0.01)
lr = 0.01
BS = 128
losses, accuracies = [], []
for i in (t := trange(1000)):
@@ -55,13 +66,11 @@ for i in (t := trange(1000)):
# NLL loss function
loss = outs.mul(y).mean()
loss.backward()
optim.step()
cat = np.argmax(outs.data, axis=1)
accuracy = (cat == Y).mean()
# SGD
model.l1.data = model.l1.data - lr*model.l1.grad
model.l2.data = model.l2.data - lr*model.l2.grad
# printing
loss = loss.data

View File

@@ -37,7 +37,8 @@ class Tensor:
grads = [grads]
for t,g in zip(self._ctx.parents, grads):
if g.shape != t.data.shape:
print("grad shape must match tensor shape in %r, %r != %r" % (self._ctx, g.shape, t.data.shape))
print("grad shape must match tensor shape in %r, %r != %r" %
(self._ctx, g.shape, t.data.shape))
assert(False)
t.grad = g
t.backward(False)
@@ -46,7 +47,7 @@ class Tensor:
div = Tensor(np.array([1/self.data.size]))
return self.sum().mul(div)
# The Function is the Context
# An instantiation of the Function is the Context
class Function:
def __init__(self, *tensors):
self.parents = tensors