mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
simple tests, repr not str
This commit is contained in:
@@ -3,17 +3,13 @@ import numpy as np
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.utils import layer_init_uniform, fetch_mnist
|
||||
import tinygrad.optim as optim
|
||||
|
||||
from tqdm import trange
|
||||
|
||||
# load the mnist dataset
|
||||
|
||||
X_train, Y_train, X_test, Y_test = fetch_mnist()
|
||||
|
||||
# train a model
|
||||
|
||||
np.random.seed(1337)
|
||||
|
||||
# load the mnist dataset
|
||||
X_train, Y_train, X_test, Y_test = fetch_mnist()
|
||||
|
||||
# create a model
|
||||
class TinyBobNet:
|
||||
def __init__(self):
|
||||
self.l1 = Tensor(layer_init_uniform(784, 128))
|
||||
@@ -22,8 +18,6 @@ class TinyBobNet:
|
||||
def forward(self, x):
|
||||
return x.dot(self.l1).relu().dot(self.l2).logsoftmax()
|
||||
|
||||
# optimizer
|
||||
|
||||
model = TinyBobNet()
|
||||
optim = optim.SGD([model.l1, model.l2], lr=0.001)
|
||||
#optim = optim.Adam([model.l1, model.l2], lr=0.001)
|
||||
|
||||
28
test/test.py
28
test/test.py
@@ -10,31 +10,25 @@ def test_tinygrad():
|
||||
x = Tensor(x_init)
|
||||
W = Tensor(W_init)
|
||||
m = Tensor(m_init)
|
||||
out = x.dot(W)
|
||||
outr = out.relu()
|
||||
outl = outr.logsoftmax()
|
||||
outm = outl.mul(m)
|
||||
outa = outm.add(m)
|
||||
outx = outa.sum()
|
||||
outx.backward()
|
||||
return outx.data, x.grad, W.grad
|
||||
out = x.dot(W).relu()
|
||||
out = out.logsoftmax()
|
||||
out = out.mul(m).add(m).sum()
|
||||
out.backward()
|
||||
return out.data, x.grad, W.grad
|
||||
|
||||
def test_pytorch():
|
||||
x = torch.tensor(x_init, requires_grad=True)
|
||||
W = torch.tensor(W_init, requires_grad=True)
|
||||
m = torch.tensor(m_init)
|
||||
out = x.matmul(W)
|
||||
outr = out.relu()
|
||||
outl = torch.nn.functional.log_softmax(outr, dim=1)
|
||||
outm = outl.mul(m)
|
||||
outa = outm.add(m)
|
||||
outx = outa.sum()
|
||||
outx.backward()
|
||||
return outx.detach().numpy(), x.grad, W.grad
|
||||
out = x.matmul(W).relu()
|
||||
out = torch.nn.functional.log_softmax(out, dim=1)
|
||||
out = out.mul(m).add(m).sum()
|
||||
out.backward()
|
||||
return out.detach().numpy(), x.grad, W.grad
|
||||
|
||||
for x,y in zip(test_tinygrad(), test_pytorch()):
|
||||
print(x,y)
|
||||
np.testing.assert_allclose(x, y, atol=1e-6)
|
||||
np.testing.assert_allclose(x, y, atol=1e-5)
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -16,7 +16,7 @@ class Tensor:
|
||||
# internal variables used for autograd graph construction
|
||||
self._ctx = None
|
||||
|
||||
def __str__(self):
|
||||
def __repr__(self):
|
||||
return "Tensor %r with grad %r" % (self.data, self.grad)
|
||||
|
||||
def backward(self, allow_fill=True):
|
||||
|
||||
Reference in New Issue
Block a user