diff --git a/tensor.py b/tensor.py
index d69028c04e..2a1ba6541c 100644
--- a/tensor.py
+++ b/tensor.py
@@ -61,6 +61,19 @@ def register(name, fxn):
   setattr(Tensor, name, partialmethod(fxn.apply, fxn))
 
 # **** implement a few functions ****
+
+class Mul(Function):
+  @staticmethod
+  def forward(ctx, x, y):
+    ctx.save_for_backward(x, y)
+    return x*y
+
+  @staticmethod
+  def backward(ctx, grad_output):
+    x,y = ctx.saved_tensors
+    return y*grad_output, x*grad_output
+register('mul', Mul)
+
 class ReLU(Function):
   @staticmethod
@@ -108,7 +121,7 @@ class LogSoftmax(Function):
     def logsumexp(x):
       c = x.max(axis=1)
       return c + np.log(np.exp(x-c.reshape((-1, 1))).sum(axis=1))
-    output = input - logsumexp(input)
+    output = input - logsumexp(input).reshape((-1, 1))
     ctx.save_for_backward(output)
     return output
diff --git a/test.py b/test.py
index c324238566..31d4da3390 100644
--- a/test.py
+++ b/test.py
@@ -1,10 +1,38 @@
+import numpy as np
 import torch
 from tensor import Tensor
 
-x = np.random.randn(1,3)
-W = np.random.randn(3,3)
-out = x.dot(W)
-print(out)
-
+x_init = np.random.randn(1,3).astype(np.float32)
+W_init = np.random.randn(3,3).astype(np.float32)
+m_init = np.random.randn(1,3).astype(np.float32)
+
+def test_tinygrad():
+  x = Tensor(x_init)
+  W = Tensor(W_init)
+  m = Tensor(m_init)
+  out = x.dot(W)
+  outr = out.relu()
+  outl = outr.logsoftmax()
+  outm = outl.mul(m)
+  outx = outm.sum()
+  outx.backward()
+  return outx.data, x.grad, W.grad
+
+def test_pytorch():
+  x = torch.tensor(x_init, requires_grad=True)
+  W = torch.tensor(W_init, requires_grad=True)
+  m = torch.tensor(m_init)
+  out = x.matmul(W)
+  outr = out.relu()
+  outl = torch.nn.functional.log_softmax(outr, dim=1)
+  outm = outl.mul(m)
+  outx = outm.sum()
+  outx.backward()
+  return outx.detach().numpy(), x.grad, W.grad
+
+for x,y in zip(test_tinygrad(), test_pytorch()):
+  print(x,y)
+  np.testing.assert_allclose(x, y, atol=1e-6)
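Not part of the patch: a quick standalone sanity check of the two tensor.py changes -- a minimal sketch assuming only numpy, with illustrative variable names (gx, lse, inp, etc.) that don't appear in the diff. It confirms that Mul.backward returns the product-rule gradients (y*grad_output, x*grad_output) and that the logsumexp result needs reshape((-1, 1)) so each row of the input has its own logsumexp subtracted from it.

import numpy as np

x = np.random.randn(1, 3)
y = np.random.randn(1, 3)
g = np.ones_like(x)                          # upstream gradient of sum(x*y)

# product rule: d(x*y)/dx = y and d(x*y)/dy = x, as returned by Mul.backward
gx, gy = y*g, x*g
eps = 1e-6
num_gx = ((x+eps)*y - (x-eps)*y) / (2*eps)   # finite-difference check of the x gradient
np.testing.assert_allclose(gx, num_gx, atol=1e-6)

# logsumexp over axis 1 has shape (N,); reshape((-1, 1)) turns it into a column
# so "input - logsumexp(input)" subtracts each row's constant from that row only
inp = np.random.randn(4, 3)
c = inp.max(axis=1)
lse = c + np.log(np.exp(inp - c.reshape((-1, 1))).sum(axis=1))
out = inp - lse.reshape((-1, 1))
np.testing.assert_allclose(np.exp(out).sum(axis=1), 1.0)   # every row is a normalized log-prob row

The reshape matters once the batch has more than one row: the (1,3) input in the test happens to broadcast correctly against a shape-(1,) logsumexp, but an (N, D) input with N != D raises a broadcast error, and N == D silently subtracts along the wrong axis.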