Test split (#231)

* Split tests

Split tests into "Test CPU" and "Test GPU".

Add test flag "TEST_DEVICES" which is a comma separated list of devices:
CPU,GPU,ANE

* Run tests based on the provided TEST_DEVICES flag

By default, all devices ("CPU,GPU,ANE") will be run.
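
As a rough illustration only (this TEST_DEVICES approach is reverted later in this same PR in favor of a plain GPU=1 switch, and the helper below is hypothetical rather than the actual implementation), such a flag could gate device-specific test classes like this:

import os
import unittest

# Hypothetical sketch: read the comma-separated TEST_DEVICES flag,
# defaulting to all devices when the variable is unset.
TEST_DEVICES = os.environ.get("TEST_DEVICES", "CPU,GPU,ANE").split(",")

def requires_device(name):
  # Skip the decorated test unless `name` is listed in TEST_DEVICES.
  return unittest.skipUnless(name in TEST_DEVICES, f"{name} not in TEST_DEVICES")

@requires_device("GPU")
class TestSomethingGPU(unittest.TestCase):
  def test_placeholder(self):
    pass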

* Fix bad quote

* Revert changes and use GPU=1

This is done by setting the default Tensor device to Device.GPU when
GPU=1 is set; otherwise it stays Device.CPU.

Run GPU tests: GPU=1 pytest -s -v
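
The mechanism (visible in the final hunk of this diff, in tinygrad's tensor module) is just the default-device selection below; this is a minimal standalone sketch of that logic, not the full Tensor class:

import os

class Device: CPU, GPU, ANE = 0, 1, 2

# Same selection as in the last hunk: default to GPU only when the
# GPU environment variable is exactly "1", otherwise stay on CPU.
DEFAULT_DEVICE = Device.CPU if os.environ.get("GPU", 0) != "1" else Device.GPU

# Tensors created without an explicit device= land on DEFAULT_DEVICE, so
# `GPU=1 python -m pytest -s -v` runs the whole suite on the GPU backend.

Since Device.CPU is 0, the `not DEFAULT_DEVICE` checks added in the test files below read as "only run this on CPU".
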
Author: Liam
Date: 2021-01-01 15:19:03 +01:00
Committed by: GitHub
Parent: 4a7cf2e420
Commit: ebd72ff437
10 changed files with 137 additions and 213 deletions


@@ -15,24 +15,6 @@ jobs:
- name: Check <1000 lines
run: sloccount tinygrad test examples; if [ $(sloccount tinygrad | sed -n 's/.*Total Physical Source Lines of Code (SLOC)[ ]*= \([^ ]*\).*/\1/p' | tr -d ',') -gt 1000 ]; then exit 1; fi
test:
name: Test
runs-on: ubuntu-latest
steps:
- name: Checkout Code
uses: actions/checkout@v2
- name: Install OpenCL
run: sudo apt-get install pocl-opencl-icd
- name: Set up Python 3.8
uses: actions/setup-python@v2
with:
python-version: 3.8
- name: Install Dependencies
run: pip install -e '.[gpu,testing]'
- name: Run Pytest
run: python -m pytest -s -v
linter:
name: Indentation Linter
runs-on: ubuntu-latest
@@ -53,3 +35,41 @@ jobs:
run: |
python -m pylint --disable=all -e W0311 --jobs=0 --indent-string=' ' **/*.py
testcpu:
name: CPU Tests
runs-on: ubuntu-latest
steps:
- name: Checkout Code
uses: actions/checkout@v2
- name: Install OpenCL
run: sudo apt-get install pocl-opencl-icd
- name: Set up Python 3.8
uses: actions/setup-python@v2
with:
python-version: 3.8
- name: Install Dependencies
run: pip install -e '.[testing]'
- name: Run Pytest
run: python -m pytest -s -v
testgpu:
name: GPU Tests
runs-on: ubuntu-latest
steps:
- name: Checkout Code
uses: actions/checkout@v2
- name: Install OpenCL
run: sudo apt-get install pocl-opencl-icd
- name: Set up Python 3.8
uses: actions/setup-python@v2
with:
python-version: 3.8
- name: Install Dependencies
run: pip install -e '.[gpu,testing]'
- name: Run Pytest
run: GPU=1 python -m pytest -s -v


@@ -11,18 +11,16 @@ def sparse_categorical_crossentropy(out, Y):
# correct loss for NLL, torch NLL loss returns one per row
y[range(y.shape[0]),YY] = -1.0*num_classes
y = y.reshape(list(Y.shape)+[num_classes])
y = Tensor(y, device=out.device)
y = Tensor(y)
return out.mul(y).mean()
def train(model, X_train, Y_train, optim, steps, BS=128, device=Device.CPU, lossfn=sparse_categorical_crossentropy):
def train(model, X_train, Y_train, optim, steps, BS=128, lossfn=sparse_categorical_crossentropy):
Tensor.training = True
if device == Device.GPU: [x.gpu_() for x in get_parameters([model, optim])]
elif device == Device.ANE: [x.ane_() for x in get_parameters([model, optim])]
losses, accuracies = [], []
for i in (t := trange(steps, disable=os.getenv('CI') is not None)):
samp = np.random.randint(0, X_train.shape[0], size=(BS))
x = Tensor(X_train[samp], device=device)
x = Tensor(X_train[samp])
y = Y_train[samp]
# network
@@ -42,12 +40,12 @@ def train(model, X_train, Y_train, optim, steps, BS=128, device=Device.CPU, loss
accuracies.append(accuracy)
t.set_description("loss %.2f accuracy %.2f" % (loss, accuracy))
def evaluate(model, X_test, Y_test, num_classes=None, device=Device.CPU, BS=128):
def evaluate(model, X_test, Y_test, num_classes=None, BS=128):
Tensor.training = False
def numpy_eval(num_classes):
Y_test_preds_out = np.zeros(list(Y_test.shape)+[num_classes])
for i in trange(len(Y_test)//BS, disable=os.getenv('CI') is not None):
Y_test_preds_out[i*BS:(i+1)*BS] = model.forward(Tensor(X_test[i*BS:(i+1)*BS], device=device)).cpu().data
Y_test_preds_out[i*BS:(i+1)*BS] = model.forward(Tensor(X_test[i*BS:(i+1)*BS])).cpu().data
Y_test_preds = np.argmax(Y_test_preds_out, axis=-1)
return (Y_test == Y_test_preds).mean()


@@ -1,31 +1,30 @@
#!/usr/bin/env python
import gc
import unittest
from tinygrad.tensor import Tensor, GPU, ANE, Device
from tinygrad.tensor import Tensor
def tensors_allocated():
return sum([isinstance(x, Tensor) for x in gc.get_objects()])
class TestGC(unittest.TestCase):
device = Device.CPU
def test_gc(self):
a = Tensor.zeros(4,4, device=self.device)
b = Tensor.zeros(4,4, device=self.device)
a = Tensor.zeros(4,4)
b = Tensor.zeros(4,4)
(a*b).mean().backward()
assert(tensors_allocated() > 0)
del a,b
assert(tensors_allocated() == 0)
def test_gc_complex(self):
a = Tensor.zeros(4,4, device=self.device)
b = Tensor.zeros(4,4, device=self.device)
a = Tensor.zeros(4,4)
b = Tensor.zeros(4,4)
assert(tensors_allocated() == 2)
(a*b).mean().backward()
assert(tensors_allocated() == 4)
del b
assert(tensors_allocated() == 2)
b = Tensor.zeros(4,4, device=self.device)
b = Tensor.zeros(4,4)
print(tensors_allocated())
(a*b).mean().backward()
print(tensors_allocated())
@@ -33,13 +32,5 @@ class TestGC(unittest.TestCase):
del b
assert(tensors_allocated() == 2)
@unittest.skipUnless(GPU, "Requires GPU")
class TestGCGPU(TestGC):
device = Device.GPU
@unittest.skipUnless(ANE, "Requires ANE")
class TestGCANE(TestGC):
device=Device.ANE
if __name__ == '__main__':
unittest.main()


@@ -2,7 +2,7 @@
import os
import unittest
import numpy as np
from tinygrad.tensor import Tensor, GPU, ANE, Device
from tinygrad.tensor import Tensor
import tinygrad.optim as optim
from extra.training import train, evaluate
from extra.utils import fetch, get_parameters
@@ -55,36 +55,27 @@ class TinyConvNet:
return x.dot(self.l1).logsoftmax()
class TestMNIST(unittest.TestCase):
device = Device.CPU
def test_conv(self):
np.random.seed(1337)
model = TinyConvNet()
optimizer = optim.Adam(model.parameters(), lr=0.001)
train(model, X_train, Y_train, optimizer, steps=200, device=self.device)
assert evaluate(model, X_test, Y_test, device=self.device) > 0.95
train(model, X_train, Y_train, optimizer, steps=200)
assert evaluate(model, X_test, Y_test) > 0.95
def test_sgd(self):
np.random.seed(1337)
model = TinyBobNet()
optimizer = optim.SGD(model.parameters(), lr=0.001)
train(model, X_train, Y_train, optimizer, steps=1000, device=self.device)
assert evaluate(model, X_test, Y_test, device=self.device) > 0.95
train(model, X_train, Y_train, optimizer, steps=1000)
assert evaluate(model, X_test, Y_test) > 0.95
def test_rmsprop(self):
np.random.seed(1337)
model = TinyBobNet()
optimizer = optim.RMSprop(model.parameters(), lr=0.0002)
train(model, X_train, Y_train, optimizer, steps=1000, device=self.device)
assert evaluate(model, X_test, Y_test, device=self.device) > 0.95
@unittest.skipUnless(GPU, "Requires GPU")
class TestMNISTGPU(TestMNIST):
device = Device.GPU
@unittest.skipUnless(ANE, "Requires ANE")
class TestMNISTANE(TestMNIST):
device=Device.ANE
train(model, X_train, Y_train, optimizer, steps=1000)
assert evaluate(model, X_test, Y_test) > 0.95
if __name__ == '__main__':
unittest.main()


@@ -4,7 +4,7 @@ import cProfile
import pstats
import unittest
import torch
from tinygrad.tensor import Tensor, GPU, ANE, Device
from tinygrad.tensor import Tensor
def start_profile():
import time
@@ -20,7 +20,6 @@ def stop_profile(pr, sort='cumtime'):
ps.print_stats(0.2)
class TestConvSpeed(unittest.TestCase):
device= Device.CPU
def test_mnist(self):
# https://keras.io/examples/vision/mnist_convnet/
@@ -64,15 +63,15 @@ class TestConvSpeed(unittest.TestCase):
# ****** tinygrad compare *******
c1 = Tensor(c1.detach().numpy(), device=self.device)
c2 = Tensor(c2.detach().numpy(), device=self.device)
l1 = Tensor(l1.detach().numpy(), device=self.device)
c1 = Tensor(c1.detach().numpy())
c2 = Tensor(c2.detach().numpy())
l1 = Tensor(l1.detach().numpy())
cnt = 5
fpt, bpt = 0.0, 0.0
for i in range(1+cnt):
et0 = time.time()
x = Tensor.randn(128, 1, 28, 28, device=self.device)
x = Tensor.randn(128, 1, 28, 28)
x = x.conv2d(c1).relu().avg_pool2d()
x = x.conv2d(c2).relu().max_pool2d()
x = x.reshape(shape=(x.shape[0], -1))
@@ -93,14 +92,6 @@ class TestConvSpeed(unittest.TestCase):
print("forward pass: %.3f ms, %.2fx off baseline %.3f ms" % (fpt, fpt/fpt_baseline, fpt_baseline))
print("backward pass: %.3f ms, %.2fx off baseline %.3f ms" % (bpt, bpt/bpt_baseline, bpt_baseline))
@unittest.skipUnless(GPU, "Requires GPU")
class TestConvSpeedGPU(TestConvSpeed):
device = Device.GPU
@unittest.skipUnless(ANE, "Requires ANE")
class TestConvSpeedANE(TestConvSpeed):
device=Device.ANE
if __name__ == '__main__':
unittest.main()


@@ -1,13 +1,13 @@
#!/usr/bin/env python
import unittest
import numpy as np
from tinygrad.tensor import GPU, ANE, Device
from tinygrad.tensor import Tensor, DEFAULT_DEVICE
from tinygrad.nn import *
from extra.utils import get_parameters
import torch
@unittest.skipUnless(not DEFAULT_DEVICE, "Not Implemented")
class TestNN(unittest.TestCase):
device = Device.CPU
def test_batchnorm2d(self, training=False):
sz = 4
@@ -33,7 +33,7 @@ class TestNN(unittest.TestCase):
np.testing.assert_allclose(bn.running_var.data, tbn.running_var.detach().numpy(), rtol=1e-5)
# trial
inn = Tensor.randn(2, sz, 3, 3, device=self.device)
inn = Tensor.randn(2, sz, 3, 3)
# in tinygrad
outt = bn(inn)
@@ -52,27 +52,5 @@ class TestNN(unittest.TestCase):
def test_batchnorm2d_training(self):
self.test_batchnorm2d(True)
@unittest.skipUnless(GPU, "Requires GPU")
class TestNNGPU(TestNN):
device = Device.GPU
@unittest.skip("Tests not added")
def test_batchnorm2d(self): pass
@unittest.skip("Tests not added")
def test_batchnorm2d_training(self): pass
@unittest.skipUnless(ANE, "Requires ANE")
class TestNNANE(TestNN):
device=Device.ANE
@unittest.skip("Tests not added")
def test_batchnorm2d(self): pass
@unittest.skip("Tests not added")
def test_batchnorm2d_training(self): pass
if __name__ == '__main__':
unittest.main()


@@ -4,20 +4,16 @@ import numpy as np
import unittest
import timeit
import functools
from tinygrad.tensor import Tensor, GPU, ANE, Device
from tinygrad.tensor import Tensor, DEFAULT_DEVICE, Device
def helper_test_op(shps, torch_fxn, tinygrad_fxn, atol=0, rtol=1e-6, grad_atol=0, grad_rtol=1e-6, device=Device.CPU, forward_only=False, vals=None):
def helper_test_op(shps, torch_fxn, tinygrad_fxn, atol=0, rtol=1e-6, grad_atol=0, grad_rtol=1e-6, forward_only=False, vals=None):
torch.manual_seed(0)
if shps is None:
ts = [torch.tensor(x, requires_grad=True) for x in vals]
else:
ts = [torch.rand(x, requires_grad=True) for x in shps]
tst = [Tensor(x.detach().numpy()) for x in ts]
if device==Device.GPU:
tst = [x.gpu() for x in tst]
elif device==Device.ANE:
tst = [x.ane() for x in tst]
tst = [Tensor(x.detach().numpy()) for x in ts]
out = torch_fxn(*ts)
ret = tinygrad_fxn(*tst)
@@ -42,82 +38,74 @@ def helper_test_op(shps, torch_fxn, tinygrad_fxn, atol=0, rtol=1e-6, grad_atol=0
print("testing %30r torch/tinygrad fp: %.2f / %.2f ms bp: %.2f / %.2f ms" % (shps, torch_fp, tinygrad_fp, torch_fbp-torch_fp, tinygrad_fbp-tinygrad_fp))
# TODO: everywhere you see this, make the op work on GPU
def cpu_only(func):
def wrapper(self):
if self.device == Device.CPU:
func(self)
return wrapper
class TestOps(unittest.TestCase):
device=Device.CPU
def test_add(self):
helper_test_op([(45,65), (45,65)], lambda x,y: x+y, Tensor.add, device=self.device)
helper_test_op([(45,65), (45,65)], lambda x,y: x+y, Tensor.add)
def test_sub(self):
helper_test_op([(45,65), (45,65)], lambda x,y: x-y, Tensor.sub, device=self.device)
helper_test_op([(45,65), (45,65)], lambda x,y: x-y, Tensor.sub)
def test_mul(self):
helper_test_op([(45,65), (45,65)], lambda x,y: x*y, Tensor.mul, device=self.device)
helper_test_op([(45,65), (45,65)], lambda x,y: x*y, Tensor.mul)
def test_div(self):
helper_test_op([(45,65), (45,65)], lambda x,y: x/y, Tensor.div, device=self.device)
helper_test_op([(45,65), (45,65)], lambda x,y: x/y, Tensor.div)
def test_pow(self):
helper_test_op([(45,65), (45,65)], lambda x,y: x**y, Tensor.pow, device=self.device)
helper_test_op([(45,65), (45,65)], lambda x,y: x**y, Tensor.pow)
def test_sqrt(self):
helper_test_op([(45,65)], lambda x: x.sqrt(), Tensor.sqrt, device=self.device)
helper_test_op([(45,65)], lambda x: x.sqrt(), Tensor.sqrt)
def test_relu(self):
helper_test_op([(45,65)], lambda x: x.relu(), Tensor.relu, device=self.device)
helper_test_op([(45,65)], lambda x: x.relu(), Tensor.relu)
def test_leakyrelu(self):
helper_test_op([(45,65)], lambda x: torch.nn.functional.leaky_relu(x,0.01), Tensor.leakyrelu, device=self.device)
helper_test_op([(45,65)], lambda x: torch.nn.functional.leaky_relu(x,0.01), Tensor.leakyrelu)
def test_abs(self):
helper_test_op([(45,65)], lambda x: torch.abs(x), Tensor.abs, device=self.device)
helper_test_op([(45,65)], lambda x: torch.abs(x), Tensor.abs)
def test_log(self):
helper_test_op([(45,65)], lambda x: torch.log(x), Tensor.log, device=self.device)
helper_test_op([(45,65)], lambda x: torch.log(x), Tensor.log)
def test_exp(self):
helper_test_op([(45,65)], lambda x: torch.exp(x), Tensor.exp, device=self.device)
helper_test_op([(45,65)], lambda x: torch.exp(x), Tensor.exp)
def test_sigmoid(self):
helper_test_op([(45,65)], lambda x: x.sigmoid(), Tensor.sigmoid, device=self.device)
helper_test_op([(45,65)], lambda x: x.sigmoid(), Tensor.sigmoid)
def test_dot(self):
helper_test_op([(45,65), (65,100)], lambda x,y: x.matmul(y), Tensor.dot, device=self.device)
helper_test_op([(45,65), (65,100)], lambda x,y: x.matmul(y), Tensor.dot)
def test_multidot(self):
helper_test_op([(10,45,65), (10,65,45)], lambda x,y: x @ y, Tensor.dot, device=self.device)
helper_test_op([(3,3,45,65), (3,3,65,45)], lambda x,y: x @ y, Tensor.dot, device=self.device)
helper_test_op([(10,45,65), (10,65,45)], lambda x,y: x @ y, Tensor.dot)
helper_test_op([(3,3,45,65), (3,3,65,45)], lambda x,y: x @ y, Tensor.dot)
def test_sum(self):
helper_test_op([(45,3)], lambda x: x.sum(), Tensor.sum, device=self.device)
helper_test_op([(3,4,5,6)], lambda x: x.sum(axis=(1,2)), lambda x: Tensor.sum(x, axis=(1,2)), device=self.device)
helper_test_op([(3,4,5,6)], lambda x: x.sum(axis=1), lambda x: Tensor.sum(x, axis=1), device=self.device)
helper_test_op([(45,3)], lambda x: x.sum(), Tensor.sum)
helper_test_op([(3,4,5,6)], lambda x: x.sum(axis=(1,2)), lambda x: Tensor.sum(x, axis=(1,2)))
helper_test_op([(3,4,5,6)], lambda x: x.sum(axis=1), lambda x: Tensor.sum(x, axis=1))
def test_max(self):
helper_test_op([(45,3)], lambda x: x.max(), Tensor.max, device=self.device)
helper_test_op([(45,3)], lambda x: x.max().mul(0.5), lambda x: Tensor.max(x).mul(0.5), device=self.device)
helper_test_op(None, lambda x: x.max().mul(0.5), lambda x: Tensor.max(x).mul(0.5), device=self.device,
helper_test_op([(45,3)], lambda x: x.max(), Tensor.max)
helper_test_op([(45,3)], lambda x: x.max().mul(0.5), lambda x: Tensor.max(x).mul(0.5))
helper_test_op(None, lambda x: x.max().mul(0.5), lambda x: Tensor.max(x).mul(0.5),
vals=[
[[1.0,1.0,0.0,1.0]],
])
helper_test_op([(3,4,5,6)], lambda x: x.max(axis=1)[0], lambda x: Tensor.max(x, axis=1), device=self.device)
helper_test_op([(3,4,5,6)], lambda x: x.max(axis=1)[0], lambda x: Tensor.max(x, axis=1))
def test_mean_axis(self):
helper_test_op([(3,4,5,6)], lambda x: x.mean(axis=(1,2)), lambda x: Tensor.mean(x, axis=(1,2)), device=self.device)
helper_test_op([(3,4,5,6)], lambda x: x.mean(axis=(1,2)), lambda x: Tensor.mean(x, axis=(1,2)))
def test_logsoftmax(self):
helper_test_op([(45,65)], lambda x: torch.nn.LogSoftmax(dim=1)(x), Tensor.logsoftmax, atol=1e-7, grad_atol=1e-7, device=self.device)
helper_test_op([(45,65)], lambda x: torch.nn.LogSoftmax(dim=1)(x), Tensor.logsoftmax, atol=1e-7, grad_atol=1e-7)
def test_tanh(self):
helper_test_op([(45,65)], lambda x: x.tanh(), Tensor.tanh, atol=1e-6, grad_atol=1e-6, device=self.device)
helper_test_op([(45,65)], lambda x: x.tanh(), Tensor.tanh, atol=1e-6, grad_atol=1e-6)
def test_topo_sort(self):
helper_test_op([(45,65)], lambda x: (x+x)*x, lambda x: x.add(x).mul(x), atol=1e-6, grad_atol=1e-6, device=self.device)
helper_test_op([(45,65)], lambda x: (x+x)*x, lambda x: x.add(x).mul(x), atol=1e-6, grad_atol=1e-6)
def test_scalar_mul(self):
helper_test_op([(45,65)], lambda x: x*2, lambda x: x*2, device=self.device)
helper_test_op([(45,65)], lambda x: x*2, lambda x: x*2)
def test_scalar_rmul(self):
helper_test_op([(45,65)], lambda x: 2*x, lambda x: 2*x, device=self.device)
helper_test_op([(45,65)], lambda x: 2*x, lambda x: 2*x)
def test_scalar_sub(self):
helper_test_op([(45,65)], lambda x: x-2, lambda x: x-2, device=self.device)
helper_test_op([(45,65)], lambda x: x-2, lambda x: x-2)
def test_scalar_rsub(self):
helper_test_op([(45,65)], lambda x: 2-x, lambda x: 2-x, device=self.device)
helper_test_op([(45,65)], lambda x: 2-x, lambda x: 2-x)
def test_broadcast_full(self):
for torch_op, tinygrad_op in [(torch.add, Tensor.add), (torch.sub, Tensor.sub), (torch.mul, Tensor.mul),
(torch.div, Tensor.div), (torch.pow, Tensor.pow)]:
for shapes in [((5,13,24,16), (5,1,24,1)), ((1,3,1,7,1), (2,1,5,1,8))]:
with self.subTest(op=torch_op.__name__, shapes=shapes):
helper_test_op(shapes, torch_op, tinygrad_op, device=self.device)
helper_test_op(shapes, torch_op, tinygrad_op)
def test_broadcast_partial(self):
@@ -127,28 +115,28 @@ class TestOps(unittest.TestCase):
((4,1), (4,5)), ((1,4), (5,4))]:
with self.subTest(op=torch_op.__name__, shapes=shapes):
# NOTE: ANE backwards?
helper_test_op(shapes, torch_op, tinygrad_op, device=self.device, forward_only=self.device!=Device.CPU)
helper_test_op(shapes, torch_op, tinygrad_op, forward_only=DEFAULT_DEVICE!=Device.CPU)
def test_slice(self):
helper_test_op([(3,3,3,3)], lambda x: x[1:2], lambda x: x[1:2], device=self.device)
helper_test_op([(3,3,3,3)], lambda x: x[1:2, 1:2], lambda x: x[1:2, 1:2], device=self.device)
helper_test_op([(3,3,3,3)], lambda x: x[1:2, 1:2, 0:-1], lambda x: x[1:2, 1:2, 0:-1], device=self.device)
helper_test_op([(3,3,3,3)], lambda x: x[1:2], lambda x: x[1:2])
helper_test_op([(3,3,3,3)], lambda x: x[1:2, 1:2], lambda x: x[1:2, 1:2])
helper_test_op([(3,3,3,3)], lambda x: x[1:2, 1:2, 0:-1], lambda x: x[1:2, 1:2, 0:-1])
def test_pad2d(self):
helper_test_op([(3,3,3,3)], lambda x: torch.nn.functional.pad(x, (1,2,3,4)), lambda x: x.pad2d(padding=(1,2,3,4)), device=self.device)
helper_test_op([(3,3,3,3)], lambda x: torch.nn.functional.pad(x, (1,2,3,4)), lambda x: x.pad2d(padding=(1,2,3,4)))
def test_transpose(self):
helper_test_op([(3,3,3)], lambda x: x.transpose(1,2), lambda x: x.transpose(order=(0,2,1)), device=self.device)
helper_test_op([(3,3,3)], lambda x: x.transpose(1,2), lambda x: x.transpose(order=(0,2,1)))
# This is failing on GPU because the dim is too large
#helper_test_op([(21,22,23,24)], lambda x: x.movedim((3,0,2,1),(0,1,2,3)), lambda x: x.transpose(order=(3,0,2,1)), device=self.device)
helper_test_op([(3,4,5,6)], lambda x: x.movedim((3,2,1,0),(0,1,2,3)), lambda x: x.transpose(order=(3,2,1,0)), device=self.device)
#helper_test_op([(21,22,23,24)], lambda x: x.movedim((3,0,2,1),(0,1,2,3)), lambda x: x.transpose(order=(3,0,2,1)))
helper_test_op([(3,4,5,6)], lambda x: x.movedim((3,2,1,0),(0,1,2,3)), lambda x: x.transpose(order=(3,2,1,0)))
def test_reshape(self):
helper_test_op([(4,3,6,6)], lambda x: torch.reshape(x, (-1,3,6,6)), lambda x: x.reshape(shape=(-1,3,6,6)), device=self.device)
helper_test_op([(4,3,6,6)], lambda x: torch.reshape(x, (-1,1,6,6)), lambda x: x.reshape(shape=(-1,1,6,6)), device=self.device)
helper_test_op([(4,3,6,6)], lambda x: torch.reshape(x, (-1,3,6,6)), lambda x: x.reshape(shape=(-1,3,6,6)))
helper_test_op([(4,3,6,6)], lambda x: torch.reshape(x, (-1,1,6,6)), lambda x: x.reshape(shape=(-1,1,6,6)))
def test_detach(self):
helper_test_op([(4,3,6,6)], lambda x: x.detach(), lambda x: x.detach(), device=self.device, forward_only=True)
helper_test_op([(4,3,6,6)], lambda x: x.detach(), lambda x: x.detach(), forward_only=True)
def test_conv2d(self):
for bs in [1,8]:
@@ -159,7 +147,7 @@ class TestOps(unittest.TestCase):
with self.subTest(batch_size=bs, channels=cin, groups=groups, height=H, width=W):
helper_test_op([(bs,cin,11,28), (6,cin//groups,H,W)],
lambda x,w: torch.nn.functional.conv2d(x,w,groups=groups).relu(),
lambda x,w: Tensor.conv2d(x,w,groups=groups).relu(), device=self.device, grad_rtol=1e-5)
lambda x,w: Tensor.conv2d(x,w,groups=groups).relu(), grad_rtol=1e-5)
def test_strided_conv2d(self):
bs = 4
@@ -168,18 +156,18 @@ class TestOps(unittest.TestCase):
with self.subTest(stride := 2):
helper_test_op([(bs,cin,11,28), (4,cin,H,W)],
lambda x,w: torch.nn.functional.conv2d(x,w,stride=2).relu(),
lambda x,w: Tensor.conv2d(x,w,stride=stride).relu(), device=self.device)
lambda x,w: Tensor.conv2d(x,w,stride=stride).relu())
with self.subTest(stride := (2,1)):
helper_test_op([(bs,cin,11,28), (4,cin,H,W)],
lambda x,w: torch.nn.functional.conv2d(x,w,stride=stride).relu(),
lambda x,w: Tensor.conv2d(x,w,stride=(2,1)).relu(), device=self.device)
lambda x,w: Tensor.conv2d(x,w,stride=(2,1)).relu())
def test_maxpool2d(self):
for ksz in [(2,2), (3,3), (3,2), (5,5), (5,1)]:
with self.subTest(kernel_size=ksz):
helper_test_op([(32,2,110,28)],
lambda x: torch.nn.functional.max_pool2d(x, kernel_size=ksz),
lambda x: Tensor.max_pool2d(x, kernel_size=ksz), device=self.device)
lambda x: Tensor.max_pool2d(x, kernel_size=ksz))
def test_avgpool2d(self):
shape = (32,2,111,28)
@@ -187,15 +175,7 @@ class TestOps(unittest.TestCase):
with self.subTest(kernel_size=ksz):
helper_test_op([shape],
lambda x: torch.nn.functional.avg_pool2d(x, kernel_size=ksz),
lambda x: Tensor.avg_pool2d(x, kernel_size=ksz), device=self.device, rtol=1e-5)
@unittest.skipUnless(GPU, "Requires GPU")
class TestOpsGPU(TestOps):
device=Device.GPU
@unittest.skipUnless(ANE, "Requires ANE")
class TestOpsANE(TestOps):
device=Device.ANE
lambda x: Tensor.avg_pool2d(x, kernel_size=ksz), rtol=1e-5)
if __name__ == '__main__':
unittest.main(verbosity=2)


@@ -1,7 +1,7 @@
import numpy as np
import torch
import unittest
from tinygrad.tensor import Tensor, GPU, ANE, Device
from tinygrad.tensor import Tensor
from tinygrad.optim import Adam, SGD, RMSprop
from extra.utils import get_parameters
@@ -9,11 +9,9 @@ x_init = np.random.randn(1,3).astype(np.float32)
W_init = np.random.randn(3,3).astype(np.float32)
m_init = np.random.randn(1,3).astype(np.float32)
def step_tinygrad(optim, kwargs={}, device=Device.CPU):
def step_tinygrad(optim, kwargs={}):
net = TinyNet()
optim = optim([net.x, net.W], **kwargs)
if device==Device.GPU: [x.gpu_() for x in get_parameters([net, optim])]
elif device==Device.ANE: [x.ane_() for x in get_parameters([net, optim])]
out = net.forward()
out.backward()
optim.step()
@@ -55,33 +53,22 @@ class TorchNet():
class TestOptim(unittest.TestCase):
device = Device.CPU
def test_adam(self):
for x,y in zip(step_tinygrad(Adam, device=self.device),
for x,y in zip(step_tinygrad(Adam),
step_pytorch(torch.optim.Adam)):
np.testing.assert_allclose(x, y, atol=1e-4)
def test_sgd(self):
for x,y in zip(step_tinygrad(SGD, kwargs={'lr': 0.001}, device=self.device),
for x,y in zip(step_tinygrad(SGD, kwargs={'lr': 0.001}),
step_pytorch(torch.optim.SGD, kwargs={'lr': 0.001})):
np.testing.assert_allclose(x, y, atol=1e-5)
def test_rmsprop(self):
for x,y in zip(step_tinygrad(RMSprop, kwargs={'lr': 0.001, 'decay': 0.99}, device=self.device),
for x,y in zip(step_tinygrad(RMSprop, kwargs={'lr': 0.001, 'decay': 0.99}),
step_pytorch(torch.optim.RMSprop,
kwargs={'lr': 0.001, 'alpha': 0.99})):
np.testing.assert_allclose(x, y, atol=1e-5)
@unittest.skipUnless(GPU, "Requires GPU")
class TestOptimGPU(TestOptim):
device = Device.GPU
@unittest.skipUnless(ANE, "Requires ANE")
class TestOptimANE(TestOptim):
device = Device.ANE
if __name__ == '__main__':
unittest.main()


@@ -1,7 +1,7 @@
import numpy as np
import torch
import unittest
from tinygrad.tensor import Tensor, GPU, ANE, Device
from tinygrad.tensor import Tensor, DEFAULT_DEVICE
from extra.gradcheck import numerical_jacobian, jacobian, gradcheck
x_init = np.random.randn(1,3).astype(np.float32)
@@ -11,13 +11,12 @@ W_init = np.random.randn(3,3).astype(np.float32)
m_init = np.random.randn(1,3).astype(np.float32)
class TestTinygrad(unittest.TestCase):
device = Device.CPU
def test_backward_pass(self):
def test_tinygrad():
x = Tensor(x_init, device=self.device)
W = Tensor(W_init, device=self.device)
m = Tensor(m_init, device=self.device)
x = Tensor(x_init)
W = Tensor(W_init)
m = Tensor(m_init)
out = x.dot(W).relu()
out = out.logsoftmax()
out = out.mul(m).add(m).sum()
@@ -39,9 +38,9 @@ class TestTinygrad(unittest.TestCase):
def test_backward_pass_diamond_model(self):
def test_tinygrad():
u = Tensor(U_init, device=self.device)
v = Tensor(V_init, device=self.device)
w = Tensor(W_init, device=self.device)
u = Tensor(U_init)
v = Tensor(V_init)
w = Tensor(W_init)
x = u.mul(v).relu()
y = u.mul(w).relu()
out = x.add(y).mul(y).relu()
@@ -65,6 +64,7 @@ class TestTinygrad(unittest.TestCase):
for x,y in zip(test_tinygrad(), test_pytorch()):
np.testing.assert_allclose(x, y, atol=1e-5)
@unittest.skipUnless(not DEFAULT_DEVICE, "float64 not supported on GPU")
def test_jacobian(self):
W = np.random.RandomState(1337).random((10, 5))
x = np.random.RandomState(7331).random((1, 10)) - 0.5
@@ -74,8 +74,8 @@ class TestTinygrad(unittest.TestCase):
torch_func = lambda x: torch.nn.functional.log_softmax(x.matmul(torch_W).relu(), dim=1)
PJ = torch.autograd.functional.jacobian(torch_func, torch_x).squeeze().numpy()
tiny_x = Tensor(x, device=self.device)
tiny_W = Tensor(W, device=self.device)
tiny_x = Tensor(x)
tiny_W = Tensor(W)
tiny_func = lambda x: x.dot(tiny_W).relu().logsoftmax()
J = jacobian(tiny_func, tiny_x)
NJ = numerical_jacobian(tiny_func, tiny_x)
@@ -83,12 +83,13 @@ class TestTinygrad(unittest.TestCase):
np.testing.assert_allclose(PJ, J, atol = 1e-5)
np.testing.assert_allclose(PJ, NJ, atol = 1e-5)
@unittest.skipUnless(not DEFAULT_DEVICE, "float64 not supported on GPU")
def test_gradcheck(self):
W = np.random.RandomState(1337).random((10, 5))
x = np.random.RandomState(7331).random((1, 10)) - 0.5
tiny_x = Tensor(x, device=self.device)
tiny_W = Tensor(W, device=self.device)
tiny_x = Tensor(x)
tiny_W = Tensor(W)
tiny_func = lambda x: x.dot(tiny_W).relu().logsoftmax()
self.assertTrue(gradcheck(tiny_func, tiny_x))
@@ -96,20 +97,5 @@ class TestTinygrad(unittest.TestCase):
# coarse approx. since a "big" eps and the non-linearities of the model
self.assertFalse(gradcheck(tiny_func, tiny_x, eps = 0.1))
@unittest.skipUnless(GPU, "Requires GPU")
class TestTinygradGPU(TestTinygrad):
device = Device.GPU
@unittest.skip("float64 not supported on GPU")
def test_jacobian(self): pass
@unittest.skip("float64 not supported on GPU")
def test_gradcheck(self): pass
@unittest.skipUnless(ANE, "Requires ANE")
class TestOpsANE(TestTinygrad):
device=Device.ANE
if __name__ == '__main__':
unittest.main()


@@ -67,15 +67,17 @@ def require_init_ane():
class Device: CPU, GPU, ANE = 0, 1, 2
DEFAULT_DEVICE = Device.CPU if os.environ.get("GPU", 0) != "1" else Device.GPU
class Tensor:
did_float_warning = False
training = True
ops = defaultdict(dict)
def __init__(self, data, device=Device.CPU, requires_grad=True):
self.data = self._move_data(data, device)
def __init__(self, data, device=DEFAULT_DEVICE, requires_grad=True):
self.device, self.data = device, self._move_data(data, device)
self.device, self.grad, self.requires_grad = device, None, requires_grad
self.grad, self.requires_grad = None, requires_grad
# internal variables used for autograd graph construction
self._ctx = None