diff --git a/tensor.py b/tensor.py
index 1eb222a93b..e672241bfd 100644
--- a/tensor.py
+++ b/tensor.py
@@ -78,6 +78,15 @@ class Mul(Function):
     return y*grad_output, x*grad_output
 register('mul', Mul)
 
+class Add(Function):
+  @staticmethod
+  def forward(ctx, x, y):
+    return x+y
+
+  @staticmethod
+  def backward(ctx, grad_output):
+    return grad_output, grad_output
+register('add', Add)
 
 class ReLU(Function):
   @staticmethod
diff --git a/test.py b/test.py
index 31d4da3390..558d3794e8 100644
--- a/test.py
+++ b/test.py
@@ -14,7 +14,8 @@ def test_tinygrad():
   outr = out.relu()
   outl = outr.logsoftmax()
   outm = outl.mul(m)
-  outx = outm.sum()
+  outa = outm.add(m)
+  outx = outa.sum()
   outx.backward()
   return outx.data, x.grad, W.grad
 
@@ -26,7 +27,8 @@ def test_pytorch():
   outr = out.relu()
   outl = torch.nn.functional.log_softmax(outr, dim=1)
   outm = outl.mul(m)
-  outx = outm.sum()
+  outa = outm.add(m)
+  outx = outa.sum()
   outx.backward()
   return outx.detach().numpy(), x.grad, W.grad
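
Not part of the patch: a minimal sketch of how the new op can be exercised against the patched tensor.py, mirroring test_tinygrad above. For z = x + y, dz/dx = dz/dy = 1, which is why Add.backward passes grad_output through unchanged to both inputs. The 3x3 shapes, the variable names, and the assumption that .grad holds a NumPy array (as the existing tests imply) are illustrative.

import numpy as np
from tensor import Tensor

a = Tensor(np.random.randn(3, 3).astype(np.float32))
b = Tensor(np.random.randn(3, 3).astype(np.float32))

out = a.add(b).sum()   # .add is the op installed via register('add', Add)
out.backward()

# Since d(sum(a+b))/da = d(sum(a+b))/db = 1 everywhere, both operands
# should receive an all-ones gradient. (Illustrative check only.)
assert np.allclose(a.grad, np.ones((3, 3)))
assert np.allclose(b.grad, np.ones((3, 3)))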