add support for the add op

This commit is contained in:
George Hotz
2020-10-18 10:33:12 -07:00
parent 1ea9ab3e9c
commit 472e4592d0
2 changed files with 13 additions and 2 deletions

View File

@@ -78,6 +78,15 @@ class Mul(Function):
return y*grad_output, x*grad_output
register('mul', Mul)
class Add(Function):
    """Elementwise addition op: forward computes x + y.

    Since d(x + y)/dx = d(x + y)/dy = 1, the backward pass simply
    routes the incoming gradient unchanged to both inputs.
    """
    @staticmethod
    def forward(ctx, x, y):
        """Return the elementwise sum of the two inputs."""
        return x + y

    @staticmethod
    def backward(ctx, grad_output):
        """Propagate the upstream gradient to both operands unchanged."""
        return (grad_output,) * 2
register('add', Add)
class ReLU(Function):
@staticmethod

View File

@@ -14,7 +14,8 @@ def test_tinygrad():
outr = out.relu()
outl = outr.logsoftmax()
outm = outl.mul(m)
outx = outm.sum()
outa = outm.add(m)
outx = outa.sum()
outx.backward()
return outx.data, x.grad, W.grad
@@ -26,7 +27,8 @@ def test_pytorch():
outr = out.relu()
outl = torch.nn.functional.log_softmax(outr, dim=1)
outm = outl.mul(m)
outx = outm.sum()
outa = outm.add(m)
outx = outa.sum()
outx.backward()
return outx.detach().numpy(), x.grad, W.grad