we have a test

George Hotz
2020-10-18 10:05:53 -07:00
parent fb1004103d
commit 5939427795
2 changed files with 47 additions and 6 deletions

tensor.py

@@ -61,6 +61,19 @@ def register(name, fxn):
   setattr(Tensor, name, partialmethod(fxn.apply, fxn))
 
 # **** implement a few functions ****
 
+class Mul(Function):
+  @staticmethod
+  def forward(ctx, x, y):
+    ctx.save_for_backward(x, y)
+    return x*y
+
+  @staticmethod
+  def backward(ctx, grad_output):
+    x,y = ctx.saved_tensors
+    return y*grad_output, x*grad_output
+register('mul', Mul)
+
 class ReLU(Function):
   @staticmethod
@@ -108,7 +121,7 @@ class LogSoftmax(Function):
     def logsumexp(x):
       c = x.max(axis=1)
       return c + np.log(np.exp(x-c.reshape((-1, 1))).sum(axis=1))
-    output = input - logsumexp(input)
+    output = input - logsumexp(input).reshape((-1, 1))
     ctx.save_for_backward(output)
     return output
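The second hunk fixes a broadcasting bug: logsumexp(input) returns one value per row (shape (N,)), so it has to be reshaped into a column of shape (N, 1) before being subtracted from the (N, C) input. The old line only happens to work for a single-row input like the test's (1, 3). A standalone numpy sketch (not part of the commit) with a batch of 4 rows:

import numpy as np

def logsumexp(x):
  c = x.max(axis=1)
  return c + np.log(np.exp(x - c.reshape((-1, 1))).sum(axis=1))

x = np.random.randn(4, 3)
lse = logsumexp(x)                  # shape (4,)
# x - lse raises a ValueError: (4, 3) does not broadcast against (4,);
# reshaping to a column subtracts each row's logsumexp from that row only
out = x - lse.reshape((-1, 1))      # shape (4, 3)
print(np.exp(out).sum(axis=1))      # each row of a log-softmax sums to ~1.0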

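The test below drives everything through Tensor methods (x.dot(W), out.relu(), outr.logsoftmax(), outl.mul(m)); those methods all come from register(), whose setattr/partialmethod body is the first context line of the hunk above. A stripped-down, self-contained sketch of that registration trick (illustrative only: no ctx, no backward, no autograd graph):

from functools import partialmethod
import numpy as np

class Tensor:
  def __init__(self, data):
    self.data = data

class Function:
  # called as a Tensor method: `self` is the first operand and `arg` is the
  # op class that partialmethod pre-bound in register() below
  def apply(self, arg, *x):
    return Tensor(arg.forward(self.data, *[t.data for t in x]))

class Mul(Function):
  @staticmethod
  def forward(x, y):
    return x * y

def register(name, fxn):
  # same trick as the diff: attach the op to Tensor under the given name
  setattr(Tensor, name, partialmethod(fxn.apply, fxn))

register('mul', Mul)

a, b = Tensor(np.array([1.0, 2.0])), Tensor(np.array([3.0, 4.0]))
print(a.mul(b).data)   # [3. 8.]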
test.py

@@ -1,10 +1,38 @@
 import numpy as np
+import torch
 from tensor import Tensor
 
-x = np.random.randn(1,3)
-W = np.random.randn(3,3)
-out = x.dot(W)
-print(out)
+x_init = np.random.randn(1,3).astype(np.float32)
+W_init = np.random.randn(3,3).astype(np.float32)
+m_init = np.random.randn(1,3).astype(np.float32)
+
+def test_tinygrad():
+  x = Tensor(x_init)
+  W = Tensor(W_init)
+  m = Tensor(m_init)
+  out = x.dot(W)
+  outr = out.relu()
+  outl = outr.logsoftmax()
+  outm = outl.mul(m)
+  outx = outm.sum()
+  outx.backward()
+  return outx.data, x.grad, W.grad
+
+def test_pytorch():
+  x = torch.tensor(x_init, requires_grad=True)
+  W = torch.tensor(W_init, requires_grad=True)
+  m = torch.tensor(m_init)
+  out = x.matmul(W)
+  outr = out.relu()
+  outl = torch.nn.functional.log_softmax(outr, dim=1)
+  outm = outl.mul(m)
+  outx = outm.sum()
+  outx.backward()
+  return outx.detach().numpy(), x.grad, W.grad
+
+for x,y in zip(test_tinygrad(), test_pytorch()):
+  print(x,y)
+  np.testing.assert_allclose(x, y, atol=1e-6)
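The assert_allclose loop checks the forward value plus the x and W gradients against PyTorch. Another way to sanity-check Mul.backward's y*grad_output, x*grad_output rule, without a second framework, is a finite-difference check; a sketch (not part of the commit), using loss = (x*y).sum() so grad_output is all ones, as it is for the final mul in the test chain:

import numpy as np

x = np.random.randn(1, 3)
y = np.random.randn(1, 3)

# loss = (x*y).sum(), so grad_output (d loss / d (x*y)) is all ones
grad_output = np.ones_like(x)
analytic_dx = y * grad_output            # Mul.backward's rule for d loss / dx

eps = 1e-4
numeric_dx = np.zeros_like(x)
for i in range(x.size):
  xp = x.copy(); xp.flat[i] += eps
  xm = x.copy(); xm.flat[i] -= eps
  numeric_dx.flat[i] = ((xp * y).sum() - (xm * y).sum()) / (2 * eps)

np.testing.assert_allclose(analytic_dx, numeric_dx, atol=1e-6)
print("mul gradient matches finite differences")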