Mirror of https://github.com/tinygrad/tinygrad.git
we have a test

tensor.py | 15
@@ -61,6 +61,19 @@ def register(name, fxn):
 setattr(Tensor, name, partialmethod(fxn.apply, fxn))

 # **** implement a few functions ****

+class Mul(Function):
+  @staticmethod
+  def forward(ctx, x, y):
+    ctx.save_for_backward(x, y)
+    return x*y
+
+  @staticmethod
+  def backward(ctx, grad_output):
+    x,y = ctx.saved_tensors
+    return y*grad_output, x*grad_output
+register('mul', Mul)
+
+
 class ReLU(Function):
   @staticmethod
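Mul.backward is the elementwise product rule: with out = x*y, dL/dx = y * dL/dout and dL/dy = x * dL/dout, which is exactly the pair of arrays returned above. A standalone numpy sketch (not part of the commit; the names here are illustrative only) that checks one of those gradients against a finite difference:

import numpy as np

# Sketch only: verify the product-rule gradient that Mul.backward returns,
# using L = (x*y).sum() so that grad_output is all ones.
x = np.random.randn(1, 3)
y = np.random.randn(1, 3)
grad_output = np.ones_like(x)

gx = y * grad_output   # dL/dx, as returned by Mul.backward
gy = x * grad_output   # dL/dy

eps = 1e-6
xp = x.copy(); xp[0, 0] += eps                      # nudge one element of x
numeric = ((xp * y).sum() - (x * y).sum()) / eps    # finite-difference dL/dx[0,0]
assert abs(numeric - gx[0, 0]) < 1e-4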
@@ -108,7 +121,7 @@ class LogSoftmax(Function):
     def logsumexp(x):
       c = x.max(axis=1)
       return c + np.log(np.exp(x-c.reshape((-1, 1))).sum(axis=1))
-    output = input - logsumexp(input)
+    output = input - logsumexp(input).reshape((-1, 1))
     ctx.save_for_backward(output)
     return output
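The one-line change above fixes broadcasting in LogSoftmax.forward: logsumexp(input) has shape (N,), one value per row, so subtracting it directly from an (N, M) input either raises or broadcasts along the wrong axis; reshaping to (N, 1) subtracts each row's log-sum-exp from that row. With a (1, 3) input the two forms happen to agree, which is why the bug only shows up with a batch dimension larger than one. A quick numpy sketch of the shapes involved (illustrative only, not part of the commit):

import numpy as np

# Why the .reshape((-1, 1)) matters: per-row logsumexp has shape (N,).
x = np.random.randn(4, 3)     # batch of 4 rows, 3 columns
c = x.max(axis=1)
lse = c + np.log(np.exp(x - c.reshape((-1, 1))).sum(axis=1))
print(lse.shape)              # (4,)

# x - lse would raise "operands could not be broadcast together";
# reshaping to (4, 1) broadcasts one value across each row instead.
out = x - lse.reshape((-1, 1))
print(out.shape)              # (4, 3)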

test.py | 38
@@ -1,10 +1,38 @@
 import numpy as np
+import torch
 from tensor import Tensor

-x = np.random.randn(1,3)
-W = np.random.randn(3,3)
-out = x.dot(W)
-print(out)
+x_init = np.random.randn(1,3).astype(np.float32)
+W_init = np.random.randn(3,3).astype(np.float32)
+m_init = np.random.randn(1,3).astype(np.float32)
+
+def test_tinygrad():
+  x = Tensor(x_init)
+  W = Tensor(W_init)
+  m = Tensor(m_init)
+  out = x.dot(W)
+  outr = out.relu()
+  outl = outr.logsoftmax()
+  outm = outl.mul(m)
+  outx = outm.sum()
+  outx.backward()
+  return outx.data, x.grad, W.grad
+
+def test_pytorch():
+  x = torch.tensor(x_init, requires_grad=True)
+  W = torch.tensor(W_init, requires_grad=True)
+  m = torch.tensor(m_init)
+  out = x.matmul(W)
+  outr = out.relu()
+  outl = torch.nn.functional.log_softmax(outr, dim=1)
+  outm = outl.mul(m)
+  outx = outm.sum()
+  outx.backward()
+  return outx.detach().numpy(), x.grad, W.grad
+
+for x,y in zip(test_tinygrad(), test_pytorch()):
+  print(x,y)
+  np.testing.assert_allclose(x, y, atol=1e-6)
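The test builds the same graph twice, once with tinygrad Tensors and once with PyTorch, then checks that the scalar output and the gradients of x and W agree to within 1e-6; because the comparison loop sits at module level, running python test.py executes it directly. For a third reference point, the same forward pass can be written in plain numpy (a sketch, not part of the repo; forward_numpy is a made-up name):

import numpy as np

# Sketch: the forward pass from the test, in plain numpy.
def forward_numpy(x, W, m):
  out = np.maximum(x @ W, 0)                                   # dot + relu
  c = out.max(axis=1, keepdims=True)
  lse = c + np.log(np.exp(out - c).sum(axis=1, keepdims=True))
  return ((out - lse) * m).sum()                               # logsoftmax, mul, sum

x_init = np.random.randn(1, 3).astype(np.float32)
W_init = np.random.randn(3, 3).astype(np.float32)
m_init = np.random.randn(1, 3).astype(np.float32)
print(forward_numpy(x_init, W_init, m_init))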