From ebd72ff437006fc5855a599a46d12cfebc69594a Mon Sep 17 00:00:00 2001
From: Liam <3579535@myuwc.ac.za>
Date: Fri, 1 Jan 2021 15:19:03 +0100
Subject: [PATCH] Test split (#231)

* Split tests

Split tests into "Test CPU" and "Test GPU".
Add test flag "TEST_DEVICES" which is a comma separated list of devices: CPU,GPU,ANE

* Run tests based on provided TEST_DEVICES flag

By default will run all "CPU,GPU,ANE"

* fix bad quote

* Revert changes and use GPU=1

This is done through setting the default Tensor Device to Device.CPU unless GPU=1 is set.

Run GPU tests:
GPU=1 pytest -s -v
---
 .github/workflows/test.yml |  56 +++++++++++------
 extra/training.py          |  12 ++--
 test/test_gc.py            |  21 ++-----
 test/test_mnist.py         |  23 +++----
 test/test_net_speed.py     |  19 ++----
 test/test_nn.py            |  28 +--------
 test/test_ops.py           | 120 ++++++++++++++++----------------------
 test/test_optim.py         |  23 ++-----
 test/test_tensor.py        |  40 ++++---------
 tinygrad/tensor.py         |   8 ++-
 10 files changed, 137 insertions(+), 213 deletions(-)

diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index fab97e0da4..0457702310 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -15,24 +15,6 @@ jobs:
     - name: Check <1000 lines
       run: sloccount tinygrad test examples; if [ $(sloccount tinygrad | sed -n 's/.*Total Physical Source Lines of Code (SLOC)[ ]*= \([^ ]*\).*/\1/p' | tr -d ',') -gt 1000 ]; then exit 1; fi

-  test:
-    name: Test
-    runs-on: ubuntu-latest
-
-    steps:
-    - name: Checkout Code
-      uses: actions/checkout@v2
-    - name: Install OpenCL
-      run: sudo apt-get install pocl-opencl-icd
-    - name: Set up Python 3.8
-      uses: actions/setup-python@v2
-      with:
-        python-version: 3.8
-    - name: Install Dependencies
-      run: pip install -e '.[gpu,testing]'
-    - name: Run Pytest
-      run: python -m pytest -s -v
-
   linter:
     name: Indentation Linter
     runs-on: ubuntu-latest
@@ -53,3 +35,41 @@
       run: |
         python -m pylint --disable=all -e W0311 --jobs=0 --indent-string='  ' **/*.py

+  testcpu:
+    name: CPU Tests
+    runs-on: ubuntu-latest
+
+    steps:
+    - name: Checkout Code
+      uses: actions/checkout@v2
+    - name: Install OpenCL
+      run: sudo apt-get install pocl-opencl-icd
+    - name: Set up Python 3.8
+      uses: actions/setup-python@v2
+      with:
+        python-version: 3.8
+    - name: Install Dependencies
+      run: pip install -e '.[testing]'
+    - name: Run Pytest
+      run: python -m pytest -s -v
+
+  testgpu:
+    name: GPU Tests
+    runs-on: ubuntu-latest
+
+    steps:
+    - name: Checkout Code
+      uses: actions/checkout@v2
+    - name: Install OpenCL
+      run: sudo apt-get install pocl-opencl-icd
+    - name: Set up Python 3.8
+      uses: actions/setup-python@v2
+      with:
+        python-version: 3.8
+    - name: Install Dependencies
+      run: pip install -e '.[gpu,testing]'
+    - name: Run Pytest
+      run: GPU=1 python -m pytest -s -v
+
+
+
diff --git a/extra/training.py b/extra/training.py
index d3462fa4b8..585bc083cf 100644
--- a/extra/training.py
+++ b/extra/training.py
@@ -11,18 +11,16 @@ def sparse_categorical_crossentropy(out, Y):
   # correct loss for NLL, torch NLL loss returns one per row
   y[range(y.shape[0]),YY] = -1.0*num_classes
   y = y.reshape(list(Y.shape)+[num_classes])
-  y = Tensor(y, device=out.device)
+  y = Tensor(y)
   return out.mul(y).mean()

-def train(model, X_train, Y_train, optim, steps, BS=128, device=Device.CPU, lossfn=sparse_categorical_crossentropy):
+def train(model, X_train, Y_train, optim, steps, BS=128, lossfn=sparse_categorical_crossentropy):
   Tensor.training = True
-  if device == Device.GPU: [x.gpu_() for x in get_parameters([model, optim])]
-  elif device == Device.ANE: [x.ane_() for x in get_parameters([model, optim])]
   losses, accuracies = [], []
   for i in (t := trange(steps, disable=os.getenv('CI') is not None)):
     samp = np.random.randint(0, X_train.shape[0], size=(BS))
-    x = Tensor(X_train[samp], device=device)
+    x = Tensor(X_train[samp])
     y = Y_train[samp]

     # network
@@ -42,12 +40,12 @@ def train(model, X_train, Y_train, optim, steps, BS=128, device=Device.CPU, loss
     accuracies.append(accuracy)
     t.set_description("loss %.2f accuracy %.2f" % (loss, accuracy))

-def evaluate(model, X_test, Y_test, num_classes=None, device=Device.CPU, BS=128):
+def evaluate(model, X_test, Y_test, num_classes=None, BS=128):
   Tensor.training = False
   def numpy_eval(num_classes):
     Y_test_preds_out = np.zeros(list(Y_test.shape)+[num_classes])
     for i in trange(len(Y_test)//BS, disable=os.getenv('CI') is not None):
-      Y_test_preds_out[i*BS:(i+1)*BS] = model.forward(Tensor(X_test[i*BS:(i+1)*BS], device=device)).cpu().data
+      Y_test_preds_out[i*BS:(i+1)*BS] = model.forward(Tensor(X_test[i*BS:(i+1)*BS])).cpu().data
     Y_test_preds = np.argmax(Y_test_preds_out, axis=-1)
     return (Y_test == Y_test_preds).mean()

diff --git a/test/test_gc.py b/test/test_gc.py
index 56b5af66c2..da76aeaaab 100644
--- a/test/test_gc.py
+++ b/test/test_gc.py
@@ -1,31 +1,30 @@
 #!/usr/bin/env python
 import gc
 import unittest
-from tinygrad.tensor import Tensor, GPU, ANE, Device
+from tinygrad.tensor import Tensor

 def tensors_allocated():
   return sum([isinstance(x, Tensor) for x in gc.get_objects()])

 class TestGC(unittest.TestCase):
-  device = Device.CPU

   def test_gc(self):
-    a = Tensor.zeros(4,4, device=self.device)
-    b = Tensor.zeros(4,4, device=self.device)
+    a = Tensor.zeros(4,4)
+    b = Tensor.zeros(4,4)
     (a*b).mean().backward()
     assert(tensors_allocated() > 0)
     del a,b
     assert(tensors_allocated() == 0)

   def test_gc_complex(self):
-    a = Tensor.zeros(4,4, device=self.device)
-    b = Tensor.zeros(4,4, device=self.device)
+    a = Tensor.zeros(4,4)
+    b = Tensor.zeros(4,4)
     assert(tensors_allocated() == 2)
     (a*b).mean().backward()
     assert(tensors_allocated() == 4)
     del b
     assert(tensors_allocated() == 2)
-    b = Tensor.zeros(4,4, device=self.device)
+    b = Tensor.zeros(4,4)
     print(tensors_allocated())
     (a*b).mean().backward()
     print(tensors_allocated())
@@ -33,13 +32,5 @@ class TestGC(unittest.TestCase):
     del b
     assert(tensors_allocated() == 2)

-@unittest.skipUnless(GPU, "Requires GPU")
-class TestGCGPU(TestGC):
-  device = Device.GPU
-
-@unittest.skipUnless(ANE, "Requires ANE")
-class TestGCANE(TestGC):
-  device=Device.ANE
-
 if __name__ == '__main__':
   unittest.main()
diff --git a/test/test_mnist.py b/test/test_mnist.py
index b6f089ea20..d3a33bccac 100644
--- a/test/test_mnist.py
+++ b/test/test_mnist.py
@@ -2,7 +2,7 @@
 import os
 import unittest
 import numpy as np
-from tinygrad.tensor import Tensor, GPU, ANE, Device
+from tinygrad.tensor import Tensor
 import tinygrad.optim as optim
 from extra.training import train, evaluate
 from extra.utils import fetch, get_parameters
@@ -55,36 +55,27 @@ class TinyConvNet:
     return x.dot(self.l1).logsoftmax()

 class TestMNIST(unittest.TestCase):
-  device = Device.CPU

   def test_conv(self):
     np.random.seed(1337)
     model = TinyConvNet()
     optimizer = optim.Adam(model.parameters(), lr=0.001)
-    train(model, X_train, Y_train, optimizer, steps=200, device=self.device)
-    assert evaluate(model, X_test, Y_test, device=self.device) > 0.95
+    train(model, X_train, Y_train, optimizer, steps=200)
+    assert evaluate(model, X_test, Y_test) > 0.95

   def test_sgd(self):
     np.random.seed(1337)
     model = TinyBobNet()
     optimizer = optim.SGD(model.parameters(), lr=0.001)
-    train(model, X_train, Y_train, optimizer, steps=1000, device=self.device)
-    assert evaluate(model, X_test, Y_test, device=self.device) > 0.95
+    train(model, X_train, Y_train, optimizer, steps=1000)
+    assert evaluate(model, X_test, Y_test) > 0.95

   def test_rmsprop(self):
     np.random.seed(1337)
     model = TinyBobNet()
     optimizer = optim.RMSprop(model.parameters(), lr=0.0002)
-    train(model, X_train, Y_train, optimizer, steps=1000, device=self.device)
-    assert evaluate(model, X_test, Y_test, device=self.device) > 0.95
-
-@unittest.skipUnless(GPU, "Requires GPU")
-class TestMNISTGPU(TestMNIST):
-  device = Device.GPU
-
-@unittest.skipUnless(ANE, "Requires ANE")
-class TestMNISTANE(TestMNIST):
-  device=Device.ANE
+    train(model, X_train, Y_train, optimizer, steps=1000)
+    assert evaluate(model, X_test, Y_test) > 0.95

 if __name__ == '__main__':
   unittest.main()
diff --git a/test/test_net_speed.py b/test/test_net_speed.py
index 38a42cfed3..eaf1f90a80 100644
--- a/test/test_net_speed.py
+++ b/test/test_net_speed.py
@@ -4,7 +4,7 @@ import cProfile
 import pstats
 import unittest
 import torch
-from tinygrad.tensor import Tensor, GPU, ANE, Device
+from tinygrad.tensor import Tensor

 def start_profile():
   import time
@@ -20,7 +20,6 @@ def stop_profile(pr, sort='cumtime'):
   ps.print_stats(0.2)

 class TestConvSpeed(unittest.TestCase):
-  device= Device.CPU

   def test_mnist(self):
     # https://keras.io/examples/vision/mnist_convnet/
@@ -64,15 +63,15 @@ class TestConvSpeed(unittest.TestCase):

     # ****** tinygrad compare *******

-    c1 = Tensor(c1.detach().numpy(), device=self.device)
-    c2 = Tensor(c2.detach().numpy(), device=self.device)
-    l1 = Tensor(l1.detach().numpy(), device=self.device)
+    c1 = Tensor(c1.detach().numpy())
+    c2 = Tensor(c2.detach().numpy())
+    l1 = Tensor(l1.detach().numpy())

     cnt = 5
     fpt, bpt = 0.0, 0.0
     for i in range(1+cnt):
       et0 = time.time()
-      x = Tensor.randn(128, 1, 28, 28, device=self.device)
+      x = Tensor.randn(128, 1, 28, 28)
       x = x.conv2d(c1).relu().avg_pool2d()
       x = x.conv2d(c2).relu().max_pool2d()
       x = x.reshape(shape=(x.shape[0], -1))
@@ -93,14 +92,6 @@ class TestConvSpeed(unittest.TestCase):
       print("forward pass: %.3f ms, %.2fx off baseline %.3f ms" % (fpt, fpt/fpt_baseline, fpt_baseline))
       print("backward pass: %.3f ms, %.2fx off baseline %.3f ms" % (bpt, bpt/bpt_baseline, bpt_baseline))

-@unittest.skipUnless(GPU, "Requires GPU")
-class TestConvSpeedGPU(TestConvSpeed):
-  device = Device.GPU
-
-@unittest.skipUnless(ANE, "Requires ANE")
-class TestConvSpeedANE(TestConvSpeed):
-  device=Device.ANE
-
 if __name__ == '__main__':
   unittest.main()
diff --git a/test/test_nn.py b/test/test_nn.py
index adc55d008d..c502f9f871 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -1,13 +1,13 @@
 #!/usr/bin/env python
 import unittest
 import numpy as np
-from tinygrad.tensor import GPU, ANE, Device
+from tinygrad.tensor import Tensor, DEFAULT_DEVICE
 from tinygrad.nn import *
 from extra.utils import get_parameters
 import torch

+@unittest.skipUnless(not DEFAULT_DEVICE, "Not Implemented")
 class TestNN(unittest.TestCase):
-  device = Device.CPU

   def test_batchnorm2d(self, training=False):
     sz = 4
@@ -33,7 +33,7 @@ class TestNN(unittest.TestCase):
     np.testing.assert_allclose(bn.running_var.data, tbn.running_var.detach().numpy(), rtol=1e-5)

     # trial
-    inn = Tensor.randn(2, sz, 3, 3, device=self.device)
+    inn = Tensor.randn(2, sz, 3, 3)

     # in tinygrad
     outt = bn(inn)
@@ -52,27 +52,5 @@ class TestNN(unittest.TestCase):
   def test_batchnorm2d_training(self):
     self.test_batchnorm2d(True)

-@unittest.skipUnless(GPU, "Requires GPU")
-class TestNNGPU(TestNN):
-  device = Device.GPU
-
-  @unittest.skip("Tests not added")
-  def test_batchnorm2d(self): pass
-
-  @unittest.skip("Tests not added")
-  def test_batchnorm2d_training(self): pass
-
-
-@unittest.skipUnless(ANE, "Requires ANE")
-class TestNNANE(TestNN):
-  device=Device.ANE
-
-  @unittest.skip("Tests not added")
-  def test_batchnorm2d(self): pass
-
-  @unittest.skip("Tests not added")
-  def test_batchnorm2d_training(self): pass
-
-
 if __name__ == '__main__':
   unittest.main()
diff --git a/test/test_ops.py b/test/test_ops.py
index 5946501e15..96eceb98f7 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -4,20 +4,16 @@ import numpy as np
 import unittest
 import timeit
 import functools
-from tinygrad.tensor import Tensor, GPU, ANE, Device
+from tinygrad.tensor import Tensor, DEFAULT_DEVICE, Device

-def helper_test_op(shps, torch_fxn, tinygrad_fxn, atol=0, rtol=1e-6, grad_atol=0, grad_rtol=1e-6, device=Device.CPU, forward_only=False, vals=None):
+def helper_test_op(shps, torch_fxn, tinygrad_fxn, atol=0, rtol=1e-6, grad_atol=0, grad_rtol=1e-6, forward_only=False, vals=None):
   torch.manual_seed(0)
   if shps is None:
     ts = [torch.tensor(x, requires_grad=True) for x in vals]
   else:
     ts = [torch.rand(x, requires_grad=True) for x in shps]
-  tst = [Tensor(x.detach().numpy()) for x in ts]
-  if device==Device.GPU:
-    tst = [x.gpu() for x in tst]
-  elif device==Device.ANE:
-    tst = [x.ane() for x in tst]

+  tst = [Tensor(x.detach().numpy()) for x in ts]
   out = torch_fxn(*ts)
   ret = tinygrad_fxn(*tst)
@@ -42,82 +38,74 @@ def helper_test_op(shps, torch_fxn, tinygrad_fxn, atol=0, rtol=1e-6, grad_atol=0
     print("testing %30r torch/tinygrad fp: %.2f / %.2f ms bp: %.2f / %.2f ms" % (shps, torch_fp, tinygrad_fp, torch_fbp-torch_fp, tinygrad_fbp-tinygrad_fp))

-# TODO: everywhere you see this, make the op work on GPU
-def cpu_only(func):
-  def wrapper(self):
-    if self.device == Device.CPU:
-      func(self)
-  return wrapper
-
 class TestOps(unittest.TestCase):
-  device=Device.CPU

   def test_add(self):
-    helper_test_op([(45,65), (45,65)], lambda x,y: x+y, Tensor.add, device=self.device)
+    helper_test_op([(45,65), (45,65)], lambda x,y: x+y, Tensor.add)
   def test_sub(self):
-    helper_test_op([(45,65), (45,65)], lambda x,y: x-y, Tensor.sub, device=self.device)
+    helper_test_op([(45,65), (45,65)], lambda x,y: x-y, Tensor.sub)
   def test_mul(self):
-    helper_test_op([(45,65), (45,65)], lambda x,y: x*y, Tensor.mul, device=self.device)
+    helper_test_op([(45,65), (45,65)], lambda x,y: x*y, Tensor.mul)
   def test_div(self):
-    helper_test_op([(45,65), (45,65)], lambda x,y: x/y, Tensor.div, device=self.device)
+    helper_test_op([(45,65), (45,65)], lambda x,y: x/y, Tensor.div)
   def test_pow(self):
-    helper_test_op([(45,65), (45,65)], lambda x,y: x**y, Tensor.pow, device=self.device)
+    helper_test_op([(45,65), (45,65)], lambda x,y: x**y, Tensor.pow)
   def test_sqrt(self):
-    helper_test_op([(45,65)], lambda x: x.sqrt(), Tensor.sqrt, device=self.device)
+    helper_test_op([(45,65)], lambda x: x.sqrt(), Tensor.sqrt)
   def test_relu(self):
-    helper_test_op([(45,65)], lambda x: x.relu(), Tensor.relu, device=self.device)
+    helper_test_op([(45,65)], lambda x: x.relu(), Tensor.relu)
   def test_leakyrelu(self):
-    helper_test_op([(45,65)], lambda x: torch.nn.functional.leaky_relu(x,0.01), Tensor.leakyrelu, device=self.device)
+    helper_test_op([(45,65)], lambda x: torch.nn.functional.leaky_relu(x,0.01), Tensor.leakyrelu)
   def test_abs(self):
-    helper_test_op([(45,65)], lambda x: torch.abs(x), Tensor.abs, device=self.device)
+    helper_test_op([(45,65)], lambda x: torch.abs(x), Tensor.abs)
   def test_log(self):
-    helper_test_op([(45,65)], lambda x: torch.log(x), Tensor.log, device=self.device)
+    helper_test_op([(45,65)], lambda x: torch.log(x), Tensor.log)
   def test_exp(self):
-    helper_test_op([(45,65)], lambda x: torch.exp(x), Tensor.exp, device=self.device)
+    helper_test_op([(45,65)], lambda x: torch.exp(x), Tensor.exp)
   def test_sigmoid(self):
-    helper_test_op([(45,65)], lambda x: x.sigmoid(), Tensor.sigmoid, device=self.device)
+    helper_test_op([(45,65)], lambda x: x.sigmoid(), Tensor.sigmoid)
   def test_dot(self):
-    helper_test_op([(45,65), (65,100)], lambda x,y: x.matmul(y), Tensor.dot, device=self.device)
+    helper_test_op([(45,65), (65,100)], lambda x,y: x.matmul(y), Tensor.dot)
   def test_multidot(self):
-    helper_test_op([(10,45,65), (10,65,45)], lambda x,y: x @ y, Tensor.dot, device=self.device)
-    helper_test_op([(3,3,45,65), (3,3,65,45)], lambda x,y: x @ y, Tensor.dot, device=self.device)
+    helper_test_op([(10,45,65), (10,65,45)], lambda x,y: x @ y, Tensor.dot)
+    helper_test_op([(3,3,45,65), (3,3,65,45)], lambda x,y: x @ y, Tensor.dot)
   def test_sum(self):
-    helper_test_op([(45,3)], lambda x: x.sum(), Tensor.sum, device=self.device)
-    helper_test_op([(3,4,5,6)], lambda x: x.sum(axis=(1,2)), lambda x: Tensor.sum(x, axis=(1,2)), device=self.device)
-    helper_test_op([(3,4,5,6)], lambda x: x.sum(axis=1), lambda x: Tensor.sum(x, axis=1), device=self.device)
+    helper_test_op([(45,3)], lambda x: x.sum(), Tensor.sum)
+    helper_test_op([(3,4,5,6)], lambda x: x.sum(axis=(1,2)), lambda x: Tensor.sum(x, axis=(1,2)))
+    helper_test_op([(3,4,5,6)], lambda x: x.sum(axis=1), lambda x: Tensor.sum(x, axis=1))
   def test_max(self):
-    helper_test_op([(45,3)], lambda x: x.max(), Tensor.max, device=self.device)
-    helper_test_op([(45,3)], lambda x: x.max().mul(0.5), lambda x: Tensor.max(x).mul(0.5), device=self.device)
-    helper_test_op(None, lambda x: x.max().mul(0.5), lambda x: Tensor.max(x).mul(0.5), device=self.device,
+    helper_test_op([(45,3)], lambda x: x.max(), Tensor.max)
+    helper_test_op([(45,3)], lambda x: x.max().mul(0.5), lambda x: Tensor.max(x).mul(0.5))
+    helper_test_op(None, lambda x: x.max().mul(0.5), lambda x: Tensor.max(x).mul(0.5),
                    vals=[
                      [[1.0,1.0,0.0,1.0]],
                    ])
-    helper_test_op([(3,4,5,6)], lambda x: x.max(axis=1)[0], lambda x: Tensor.max(x, axis=1), device=self.device)
+    helper_test_op([(3,4,5,6)], lambda x: x.max(axis=1)[0], lambda x: Tensor.max(x, axis=1))
   def test_mean_axis(self):
-    helper_test_op([(3,4,5,6)], lambda x: x.mean(axis=(1,2)), lambda x: Tensor.mean(x, axis=(1,2)), device=self.device)
+    helper_test_op([(3,4,5,6)], lambda x: x.mean(axis=(1,2)), lambda x: Tensor.mean(x, axis=(1,2)))
   def test_logsoftmax(self):
-    helper_test_op([(45,65)], lambda x: torch.nn.LogSoftmax(dim=1)(x), Tensor.logsoftmax, atol=1e-7, grad_atol=1e-7, device=self.device)
+    helper_test_op([(45,65)], lambda x: torch.nn.LogSoftmax(dim=1)(x), Tensor.logsoftmax, atol=1e-7, grad_atol=1e-7)
   def test_tanh(self):
-    helper_test_op([(45,65)], lambda x: x.tanh(), Tensor.tanh, atol=1e-6, grad_atol=1e-6, device=self.device)
+    helper_test_op([(45,65)], lambda x: x.tanh(), Tensor.tanh, atol=1e-6, grad_atol=1e-6)
   def test_topo_sort(self):
-    helper_test_op([(45,65)], lambda x: (x+x)*x, lambda x: x.add(x).mul(x), atol=1e-6, grad_atol=1e-6, device=self.device)
+    helper_test_op([(45,65)], lambda x: (x+x)*x, lambda x: x.add(x).mul(x), atol=1e-6, grad_atol=1e-6)
   def test_scalar_mul(self):
-    helper_test_op([(45,65)], lambda x: x*2, lambda x: x*2, device=self.device)
+    helper_test_op([(45,65)], lambda x: x*2, lambda x: x*2)
   def test_scalar_rmul(self):
-    helper_test_op([(45,65)], lambda x: 2*x, lambda x: 2*x, device=self.device)
+    helper_test_op([(45,65)], lambda x: 2*x, lambda x: 2*x)
   def test_scalar_sub(self):
-    helper_test_op([(45,65)], lambda x: x-2, lambda x: x-2, device=self.device)
+    helper_test_op([(45,65)], lambda x: x-2, lambda x: x-2)
   def test_scalar_rsub(self):
-    helper_test_op([(45,65)], lambda x: 2-x, lambda x: 2-x, device=self.device)
+    helper_test_op([(45,65)], lambda x: 2-x, lambda x: 2-x)

   def test_broadcast_full(self):
     for torch_op, tinygrad_op in [(torch.add, Tensor.add), (torch.sub, Tensor.sub), (torch.mul, Tensor.mul),
                                   (torch.div, Tensor.div), (torch.pow, Tensor.pow)]:
       for shapes in [((5,13,24,16), (5,1,24,1)), ((1,3,1,7,1), (2,1,5,1,8))]:
         with self.subTest(op=torch_op.__name__, shapes=shapes):
-          helper_test_op(shapes, torch_op, tinygrad_op, device=self.device)
+          helper_test_op(shapes, torch_op, tinygrad_op)

   def test_broadcast_partial(self):
@@ -127,28 +115,28 @@ class TestOps(unittest.TestCase):
                      ((4,1), (4,5)), ((1,4), (5,4))]:
       with self.subTest(op=torch_op.__name__, shapes=shapes):
         # NOTE: ANE backwards?
-        helper_test_op(shapes, torch_op, tinygrad_op, device=self.device, forward_only=self.device!=Device.CPU)
+        helper_test_op(shapes, torch_op, tinygrad_op, forward_only=DEFAULT_DEVICE!=Device.CPU)

   def test_slice(self):
-    helper_test_op([(3,3,3,3)], lambda x: x[1:2], lambda x: x[1:2], device=self.device)
-    helper_test_op([(3,3,3,3)], lambda x: x[1:2, 1:2], lambda x: x[1:2, 1:2], device=self.device)
-    helper_test_op([(3,3,3,3)], lambda x: x[1:2, 1:2, 0:-1], lambda x: x[1:2, 1:2, 0:-1], device=self.device)
+    helper_test_op([(3,3,3,3)], lambda x: x[1:2], lambda x: x[1:2])
+    helper_test_op([(3,3,3,3)], lambda x: x[1:2, 1:2], lambda x: x[1:2, 1:2])
+    helper_test_op([(3,3,3,3)], lambda x: x[1:2, 1:2, 0:-1], lambda x: x[1:2, 1:2, 0:-1])

   def test_pad2d(self):
-    helper_test_op([(3,3,3,3)], lambda x: torch.nn.functional.pad(x, (1,2,3,4)), lambda x: x.pad2d(padding=(1,2,3,4)), device=self.device)
+    helper_test_op([(3,3,3,3)], lambda x: torch.nn.functional.pad(x, (1,2,3,4)), lambda x: x.pad2d(padding=(1,2,3,4)))

   def test_transpose(self):
-    helper_test_op([(3,3,3)], lambda x: x.transpose(1,2), lambda x: x.transpose(order=(0,2,1)), device=self.device)
+    helper_test_op([(3,3,3)], lambda x: x.transpose(1,2), lambda x: x.transpose(order=(0,2,1)))
     # This is failing on GPU because the dim is too large
-    #helper_test_op([(21,22,23,24)], lambda x: x.movedim((3,0,2,1),(0,1,2,3)), lambda x: x.transpose(order=(3,0,2,1)), device=self.device)
-    helper_test_op([(3,4,5,6)], lambda x: x.movedim((3,2,1,0),(0,1,2,3)), lambda x: x.transpose(order=(3,2,1,0)), device=self.device)
+    #helper_test_op([(21,22,23,24)], lambda x: x.movedim((3,0,2,1),(0,1,2,3)), lambda x: x.transpose(order=(3,0,2,1)))
+    helper_test_op([(3,4,5,6)], lambda x: x.movedim((3,2,1,0),(0,1,2,3)), lambda x: x.transpose(order=(3,2,1,0)))

   def test_reshape(self):
-    helper_test_op([(4,3,6,6)], lambda x: torch.reshape(x, (-1,3,6,6)), lambda x: x.reshape(shape=(-1,3,6,6)), device=self.device)
-    helper_test_op([(4,3,6,6)], lambda x: torch.reshape(x, (-1,1,6,6)), lambda x: x.reshape(shape=(-1,1,6,6)), device=self.device)
+    helper_test_op([(4,3,6,6)], lambda x: torch.reshape(x, (-1,3,6,6)), lambda x: x.reshape(shape=(-1,3,6,6)))
+    helper_test_op([(4,3,6,6)], lambda x: torch.reshape(x, (-1,1,6,6)), lambda x: x.reshape(shape=(-1,1,6,6)))

   def test_detach(self):
-    helper_test_op([(4,3,6,6)], lambda x: x.detach(), lambda x: x.detach(), device=self.device, forward_only=True)
+    helper_test_op([(4,3,6,6)], lambda x: x.detach(), lambda x: x.detach(), forward_only=True)

   def test_conv2d(self):
     for bs in [1,8]:
@@ -159,7 +147,7 @@ class TestOps(unittest.TestCase):
           with self.subTest(batch_size=bs, channels=cin, groups=groups, height=H, width=W):
             helper_test_op([(bs,cin,11,28), (6,cin//groups,H,W)],
               lambda x,w: torch.nn.functional.conv2d(x,w,groups=groups).relu(),
-              lambda x,w: Tensor.conv2d(x,w,groups=groups).relu(), device=self.device, grad_rtol=1e-5)
+              lambda x,w: Tensor.conv2d(x,w,groups=groups).relu(), grad_rtol=1e-5)

   def test_strided_conv2d(self):
     bs = 4
@@ -168,18 +156,18 @@ class TestOps(unittest.TestCase):
     with self.subTest(stride := 2):
       helper_test_op([(bs,cin,11,28), (4,cin,H,W)],
         lambda x,w: torch.nn.functional.conv2d(x,w,stride=2).relu(),
-        lambda x,w: Tensor.conv2d(x,w,stride=stride).relu(), device=self.device)
+        lambda x,w: Tensor.conv2d(x,w,stride=stride).relu())
     with self.subTest(stride := (2,1)):
       helper_test_op([(bs,cin,11,28), (4,cin,H,W)],
         lambda x,w: torch.nn.functional.conv2d(x,w,stride=stride).relu(),
-        lambda x,w: Tensor.conv2d(x,w,stride=(2,1)).relu(), device=self.device)
+        lambda x,w: Tensor.conv2d(x,w,stride=(2,1)).relu())

   def test_maxpool2d(self):
     for ksz in [(2,2), (3,3), (3,2), (5,5), (5,1)]:
       with self.subTest(kernel_size=ksz):
         helper_test_op([(32,2,110,28)],
           lambda x: torch.nn.functional.max_pool2d(x, kernel_size=ksz),
-          lambda x: Tensor.max_pool2d(x, kernel_size=ksz), device=self.device)
+          lambda x: Tensor.max_pool2d(x, kernel_size=ksz))

   def test_avgpool2d(self):
     shape = (32,2,111,28)
@@ -187,15 +175,7 @@ class TestOps(unittest.TestCase):
       with self.subTest(kernel_size=ksz):
         helper_test_op([shape],
           lambda x: torch.nn.functional.avg_pool2d(x, kernel_size=ksz),
-          lambda x: Tensor.avg_pool2d(x, kernel_size=ksz), device=self.device, rtol=1e-5)
-
-@unittest.skipUnless(GPU, "Requires GPU")
-class TestOpsGPU(TestOps):
-  device=Device.GPU
-
-@unittest.skipUnless(ANE, "Requires ANE")
-class TestOpsANE(TestOps):
-  device=Device.ANE
+          lambda x: Tensor.avg_pool2d(x, kernel_size=ksz), rtol=1e-5)

 if __name__ == '__main__':
   unittest.main(verbosity=2)
diff --git a/test/test_optim.py b/test/test_optim.py
index be93e1b60f..2bc1268286 100644
--- a/test/test_optim.py
+++ b/test/test_optim.py
@@ -1,7 +1,7 @@
 import numpy as np
 import torch
 import unittest
-from tinygrad.tensor import Tensor, GPU, ANE, Device
+from tinygrad.tensor import Tensor
 from tinygrad.optim import Adam, SGD, RMSprop
 from extra.utils import get_parameters

@@ -9,11 +9,9 @@ x_init = np.random.randn(1,3).astype(np.float32)
 W_init = np.random.randn(3,3).astype(np.float32)
 m_init = np.random.randn(1,3).astype(np.float32)

-def step_tinygrad(optim, kwargs={}, device=Device.CPU):
+def step_tinygrad(optim, kwargs={}):
   net = TinyNet()
   optim = optim([net.x, net.W], **kwargs)
-  if device==Device.GPU: [x.gpu_() for x in get_parameters([net, optim])]
-  elif device==Device.ANE: [x.ane_() for x in get_parameters([net, optim])]
   out = net.forward()
   out.backward()
   optim.step()
@@ -55,33 +53,22 @@ class TorchNet():


 class TestOptim(unittest.TestCase):
-  device = Device.CPU

   def test_adam(self):
-    for x,y in zip(step_tinygrad(Adam, device=self.device),
+    for x,y in zip(step_tinygrad(Adam),
                    step_pytorch(torch.optim.Adam)):
       np.testing.assert_allclose(x, y, atol=1e-4)

   def test_sgd(self):
-    for x,y in zip(step_tinygrad(SGD, kwargs={'lr': 0.001}, device=self.device),
+    for x,y in zip(step_tinygrad(SGD, kwargs={'lr': 0.001}),
                    step_pytorch(torch.optim.SGD, kwargs={'lr': 0.001})):
       np.testing.assert_allclose(x, y, atol=1e-5)

   def test_rmsprop(self):
-    for x,y in zip(step_tinygrad(RMSprop, kwargs={'lr': 0.001, 'decay': 0.99}, device=self.device),
+    for x,y in zip(step_tinygrad(RMSprop, kwargs={'lr': 0.001, 'decay': 0.99}),
                    step_pytorch(torch.optim.RMSprop, kwargs={'lr': 0.001, 'alpha': 0.99})):
       np.testing.assert_allclose(x, y, atol=1e-5)

-
-@unittest.skipUnless(GPU, "Requires GPU")
-class TestOptimGPU(TestOptim):
-  device = Device.GPU
-
-@unittest.skipUnless(ANE, "Requires ANE")
-class TestOptimANE(TestOptim):
-  device = Device.ANE
-
-
 if __name__ == '__main__':
   unittest.main()
diff --git a/test/test_tensor.py b/test/test_tensor.py
index 413a4acf13..06391114af 100644
--- a/test/test_tensor.py
+++ b/test/test_tensor.py
@@ -1,7 +1,7 @@
 import numpy as np
 import torch
 import unittest
-from tinygrad.tensor import Tensor, GPU, ANE, Device
+from tinygrad.tensor import Tensor, DEFAULT_DEVICE
 from extra.gradcheck import numerical_jacobian, jacobian, gradcheck

 x_init = np.random.randn(1,3).astype(np.float32)
@@ -11,13 +11,12 @@ W_init = np.random.randn(3,3).astype(np.float32)
 m_init = np.random.randn(1,3).astype(np.float32)

 class TestTinygrad(unittest.TestCase):
-  device = Device.CPU

   def test_backward_pass(self):
     def test_tinygrad():
-      x = Tensor(x_init, device=self.device)
-      W = Tensor(W_init, device=self.device)
-      m = Tensor(m_init, device=self.device)
+      x = Tensor(x_init)
+      W = Tensor(W_init)
+      m = Tensor(m_init)
       out = x.dot(W).relu()
       out = out.logsoftmax()
       out = out.mul(m).add(m).sum()
@@ -39,9 +38,9 @@ class TestTinygrad(unittest.TestCase):

   def test_backward_pass_diamond_model(self):
     def test_tinygrad():
-      u = Tensor(U_init, device=self.device)
-      v = Tensor(V_init, device=self.device)
-      w = Tensor(W_init, device=self.device)
+      u = Tensor(U_init)
+      v = Tensor(V_init)
+      w = Tensor(W_init)
       x = u.mul(v).relu()
       y = u.mul(w).relu()
       out = x.add(y).mul(y).relu()
@@ -65,6 +64,7 @@ class TestTinygrad(unittest.TestCase):
     for x,y in zip(test_tinygrad(), test_pytorch()):
       np.testing.assert_allclose(x, y, atol=1e-5)

+  @unittest.skipUnless(not DEFAULT_DEVICE, "float64 not supported on GPU")
   def test_jacobian(self):
     W = np.random.RandomState(1337).random((10, 5))
     x = np.random.RandomState(7331).random((1, 10)) - 0.5
@@ -74,8 +74,8 @@ class TestTinygrad(unittest.TestCase):
     torch_func = lambda x: torch.nn.functional.log_softmax(x.matmul(torch_W).relu(), dim=1)
     PJ = torch.autograd.functional.jacobian(torch_func, torch_x).squeeze().numpy()

-    tiny_x = Tensor(x, device=self.device)
-    tiny_W = Tensor(W, device=self.device)
+    tiny_x = Tensor(x)
+    tiny_W = Tensor(W)
     tiny_func = lambda x: x.dot(tiny_W).relu().logsoftmax()
     J = jacobian(tiny_func, tiny_x)
     NJ = numerical_jacobian(tiny_func, tiny_x)
@@ -83,12 +83,13 @@ class TestTinygrad(unittest.TestCase):
     np.testing.assert_allclose(PJ, J, atol = 1e-5)
     np.testing.assert_allclose(PJ, NJ, atol = 1e-5)

+  @unittest.skipUnless(not DEFAULT_DEVICE, "float64 not supported on GPU")
   def test_gradcheck(self):
     W = np.random.RandomState(1337).random((10, 5))
     x = np.random.RandomState(7331).random((1, 10)) - 0.5

-    tiny_x = Tensor(x, device=self.device)
-    tiny_W = Tensor(W, device=self.device)
+    tiny_x = Tensor(x)
+    tiny_W = Tensor(W)
     tiny_func = lambda x: x.dot(tiny_W).relu().logsoftmax()

     self.assertTrue(gradcheck(tiny_func, tiny_x))
@@ -96,20 +97,5 @@ class TestTinygrad(unittest.TestCase):
     # coarse approx. since a "big" eps and the non-linearities of the model
     self.assertFalse(gradcheck(tiny_func, tiny_x, eps = 0.1))

-
-@unittest.skipUnless(GPU, "Requires GPU")
-class TestTinygradGPU(TestTinygrad):
-  device = Device.GPU
-
-  @unittest.skip("float64 not supported on GPU")
-  def test_jacobian(self): pass
-
-  @unittest.skip("float64 not supported on GPU")
-  def test_gradcheck(self): pass
-
-@unittest.skipUnless(ANE, "Requires ANE")
-class TestOpsANE(TestTinygrad):
-  device=Device.ANE
-
 if __name__ == '__main__':
   unittest.main()
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index 0631381699..60b94dabee 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -67,15 +67,17 @@ def require_init_ane():
 class Device: CPU, GPU, ANE = 0, 1, 2

+DEFAULT_DEVICE = Device.CPU if os.environ.get("GPU", 0) != "1" else Device.GPU
+
 class Tensor:
   did_float_warning = False
   training = True
   ops = defaultdict(dict)

-  def __init__(self, data, device=Device.CPU, requires_grad=True):
-    self.data = self._move_data(data, device)
+  def __init__(self, data, device=DEFAULT_DEVICE, requires_grad=True):
+    self.device, self.data = device, self._move_data(data, device)

-    self.device, self.grad, self.requires_grad = device, None, requires_grad
+    self.grad, self.requires_grad = None, requires_grad

     # internal variables used for autograd graph construction
     self._ctx = None
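For reference, a minimal sketch (not part of the patch) of how a test written after this change picks up its device from the GPU=1 flag instead of a per-class device attribute. The class and test names below are hypothetical; Tensor, Device and DEFAULT_DEVICE are the names this patch leaves in tinygrad/tensor.py. Run it with python -m pytest for the CPU job or GPU=1 python -m pytest for the GPU job, mirroring the two workflows added above.

import unittest
import numpy as np
from tinygrad.tensor import Tensor, Device, DEFAULT_DEVICE

class TestDefaultDevice(unittest.TestCase):
  def test_runs_on_default_device(self):
    # No device= argument: tensors land on DEFAULT_DEVICE, which is Device.GPU
    # only when the GPU=1 environment variable is set.
    a = Tensor(np.random.randn(4, 4).astype(np.float32))
    b = Tensor(np.random.randn(4, 4).astype(np.float32))
    assert a.device == DEFAULT_DEVICE
    (a * b).mean().backward()

  @unittest.skipUnless(not DEFAULT_DEVICE, "CPU-only example")
  def test_cpu_only(self):
    # Same skip idiom the patch uses in test_nn.py and test_tensor.py:
    # DEFAULT_DEVICE is 0 (Device.CPU) unless GPU=1 selects Device.GPU.
    assert DEFAULT_DEVICE == Device.CPU

if __name__ == '__main__':
  unittest.main()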