All devices are equal! (#196)

* Update all devices to be tested

ANE, CPU, and OpenCL now all support all tests.

However, tests are not currently passing on GPU, and I cannot test on CPU.

The failing GPU tests are not caused by this update; they have not been
passing because the required "six" package was missing from the installation.

OpenCL tests have not been run since commit 1a1c63a08b.

Devices now come in three types, handled by a new DeviceTypes enum. (The goal
is to revert to Tensor.<type>, but the current setup allows keyword-argument
defaults such as `device=DeviceTypes.CPU`.)
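
For illustration, a minimal sketch of the pattern (not the library code verbatim; `train_step` is a made-up function): the device constants live on a small class, so any API can take a `device=` keyword with a CPU default.

```python
# Hypothetical sketch: DeviceTypes mirrors the class added in this PR
# (later renamed to Device); train_step is purely illustrative.
class DeviceTypes:
    CPU, GPU, ANE = 0, 1, 2

def train_step(batch, device=DeviceTypes.CPU):
    # callers opt into an accelerator explicitly; CPU stays the default
    if device == DeviceTypes.GPU:
        return "ran on GPU"
    elif device == DeviceTypes.ANE:
        return "ran on ANE"
    return "ran on CPU"

print(train_step(None))                           # CPU by default
print(train_step(None, device=DeviceTypes.GPU))   # explicit device
```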

All references to Tensor.GPU/CPU/ANE have been converted to the
corresponding `DeviceTypes` enum values.

The conversion code has been refactored to allow conversion from any device
to any other device.
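
A rough sketch of that conversion strategy, with a stand-in buffer type (the real code uses GPUBuffer and ANETensor): every move first pulls the data back to a host numpy array, then pushes it to the target device, so one helper covers every (source, target) pair.

```python
# Illustrative only: FakeDeviceBuffer stands in for GPUBuffer / ANETensor.
import numpy as np

class Device:
    CPU, GPU, ANE = 0, 1, 2

class FakeDeviceBuffer:
    def __init__(self, host, device):
        self.host, self.device = host, device

def move_data(data, target):
    # step 1: whatever the source is, recover a host numpy array
    host = data.host if isinstance(data, FakeDeviceBuffer) else np.asarray(data, dtype=np.float32)
    # step 2: wrap it for the target device (CPU keeps the bare array)
    return host if target == Device.CPU else FakeDeviceBuffer(host, target)

gpu_buf = move_data(np.ones((2, 2), dtype=np.float32), Device.GPU)   # CPU -> GPU
print(move_data(gpu_buf, Device.CPU))                                # GPU -> CPU
```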

* Add six dependency in requirements.txt

* Resolve failure to run tests

Move six into the GPU extras and remove it from the standard installation.

* Remove repeated data conversion

* Refactor method names

Also reduce code by using `.to` and `.to_`.
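
A usage sketch of the new conversion API, assuming a GPU-enabled install at this commit: `.to()` returns a copy on the target device, `.to_()` converts in place, and the short per-device methods wrap them.

```python
from tinygrad.tensor import Tensor, Device

a = Tensor.ones(4, 4)       # created on Device.CPU
b = a.to(Device.GPU)        # copy on the GPU; `a` is unchanged
a.to_(Device.GPU)           # convert `a` itself in place
c = (a + b).cpu()           # .cpu()/.gpu()/.ane() are generated wrappers around .to
print(c.data)               # back to a host numpy array
```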

* Dynamic device handlers
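
The "dynamic device handlers" here refer to generating the per-device helper methods instead of writing each one by hand. A self-contained sketch of that technique (not the PR's exact code, which also attaches `is_<device>` properties):

```python
import functools

class Device:
    CPU, GPU, ANE = 0, 1, 2

class Tensor:
    def __init__(self, device=Device.CPU):
        self.device = device
    def to(self, device):       # placeholder: return a copy on the target device
        return Tensor(device)
    def to_(self, device):      # placeholder: convert this tensor in place
        self.device = device

# one copy method (.cpu/.gpu/.ane) and one in-place method (.cpu_/...) per device name
for name in [d for d in Device.__dict__ if not d.startswith("_")]:
    setattr(Tensor, name.lower(), functools.partialmethod(Tensor.to, Device.__dict__[name]))
    setattr(Tensor, f"{name.lower()}_", functools.partialmethod(Tensor.to_, Device.__dict__[name]))

t = Tensor()
t.gpu_()                                     # generated in-place mover
print(t.device == Device.GPU)                # True
print(Tensor().ane().device == Device.ANE)   # True, via the generated copy method
```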

* Refactor DeviceTypes -> Device

* Add mem copy profiling back

* test_backward_pass_diamond_model passing

* Resolve Sum issue on GPU

* Revert batchnorm2d tests

* Update README with updated API

* ANE testing with

* Last minute line gains

Author: Liam (committed by GitHub)
Date: 2020-12-16 08:44:08 +01:00
Parent: 78210b5e40
Commit: bcf1518309
15 changed files with 246 additions and 181 deletions

@@ -84,7 +84,7 @@ tinygrad supports GPUs through PyOpenCL.
```python
from tinygrad.tensor import Tensor
(Tensor.ones(4,4).cuda() + Tensor.ones(4,4).cuda()).cpu()
(Tensor.ones(4,4).gpu() + Tensor.ones(4,4).gpu()).cpu()
```
### ANE Support?!

@@ -2,21 +2,22 @@ import os
import numpy as np
from tqdm import trange
from extra.utils import get_parameters
from tinygrad.tensor import Tensor, GPU
from tinygrad.tensor import Tensor, GPU, Device
def train(model, X_train, Y_train, optim, steps, num_classes=None, BS=128, gpu=False, lossfn = lambda out,y: out.mul(y).mean()):
if gpu is True: [x.cuda_() for x in get_parameters([model, optim])]
def train(model, X_train, Y_train, optim, steps, num_classes=None, BS=128, device=Device.CPU, lossfn = lambda out,y: out.mul(y).mean()):
if device == Device.GPU: [x.gpu_() for x in get_parameters([model, optim])]
elif device == Device.ANE: [x.ane_() for x in get_parameters([model, optim])]
if num_classes is None: num_classes = Y_train.max().astype(int)+1
losses, accuracies = [], []
for i in (t := trange(steps, disable=os.getenv('CI') is not None)):
samp = np.random.randint(0, X_train.shape[0], size=(BS))
x = Tensor(X_train[samp].reshape((-1, 28*28)).astype(np.float32), gpu=gpu)
x = Tensor(X_train[samp].reshape((-1, 28*28)).astype(np.float32), device=device)
Y = Y_train[samp]
y = np.zeros((len(samp),num_classes), np.float32)
# correct loss for NLL, torch NLL loss returns one per row
y[range(y.shape[0]),Y] = -1.0*num_classes
y = Tensor(y, gpu=gpu)
y = Tensor(y, device=device)
# network
out = model.forward(x)
@@ -36,11 +37,11 @@ def train(model, X_train, Y_train, optim, steps, num_classes=None, BS=128, gpu=F
accuracies.append(accuracy)
t.set_description("loss %.2f accuracy %.2f" % (loss, accuracy))
def evaluate(model, X_test, Y_test, num_classes=None, gpu=False, BS=128):
def evaluate(model, X_test, Y_test, num_classes=None, device=Device.CPU, BS=128):
def numpy_eval(num_classes):
Y_test_preds_out = np.zeros((len(Y_test),num_classes))
for i in trange(len(Y_test)//BS, disable=os.getenv('CI') is not None):
Y_test_preds_out[i*BS:(i+1)*BS] = model.forward(Tensor(X_test[i*BS:(i+1)*BS].reshape((-1, 28*28)).astype(np.float32), gpu=gpu)).cpu().data
Y_test_preds_out[i*BS:(i+1)*BS] = model.forward(Tensor(X_test[i*BS:(i+1)*BS].reshape((-1, 28*28)).astype(np.float32), device=device)).cpu().data
Y_test_preds = np.argmax(Y_test_preds_out, axis=1)
return (Y_test == Y_test_preds).mean()

@@ -22,7 +22,7 @@ setup(name='tinygrad',
install_requires=['numpy', 'requests'],
python_requires='>=3.8',
extras_require={
'gpu': ["pyopencl"],
'gpu': ["pyopencl", "six"],
'testing': [
"pytest",
"torch",

test/__init__.py (new, empty file)

test/config.py (new file):

@@ -0,0 +1,3 @@
import os
ANE = os.environ.get('ANE', False)

@@ -1,31 +1,32 @@
#!/usr/bin/env python
import gc
import unittest
from tinygrad.tensor import Tensor, GPU
from tinygrad.tensor import Tensor, GPU, Device
from .config import ANE
def tensors_allocated():
return sum([isinstance(x, Tensor) for x in gc.get_objects()])
class TestGC(unittest.TestCase):
gpu = False
device = Device.CPU
def test_gc(self):
a = Tensor.zeros(4,4, gpu=self.gpu)
b = Tensor.zeros(4,4, gpu=self.gpu)
a = Tensor.zeros(4,4, device=self.device)
b = Tensor.zeros(4,4, device=self.device)
(a*b).mean().backward()
assert(tensors_allocated() > 0)
del a,b
assert(tensors_allocated() == 0)
def test_gc_complex(self):
a = Tensor.zeros(4,4, gpu=self.gpu)
b = Tensor.zeros(4,4, gpu=self.gpu)
a = Tensor.zeros(4,4, device=self.device)
b = Tensor.zeros(4,4, device=self.device)
assert(tensors_allocated() == 2)
(a*b).mean().backward()
assert(tensors_allocated() == 4)
del b
assert(tensors_allocated() == 2)
b = Tensor.zeros(4,4, gpu=self.gpu)
b = Tensor.zeros(4,4, device=self.device)
print(tensors_allocated())
(a*b).mean().backward()
print(tensors_allocated())
@@ -33,11 +34,13 @@ class TestGC(unittest.TestCase):
del b
assert(tensors_allocated() == 2)
@unittest.skipUnless(GPU, "Requires GPU")
class TestGCGPU(TestGC):
device = Device.GPU
if GPU:
class TestGCGPU(TestGC):
gpu = True
@unittest.skipUnless(ANE, "Requires ANE")
class TestGCANE(TestGC):
device=Device.ANE
if __name__ == '__main__':
unittest.main()

@@ -2,10 +2,11 @@
import os
import unittest
import numpy as np
from tinygrad.tensor import Tensor, GPU
from tinygrad.tensor import Tensor, GPU, Device
import tinygrad.optim as optim
from extra.training import train, evaluate
from extra.utils import fetch, get_parameters
from .config import ANE
# mnist loader
def fetch_mnist():
@@ -55,32 +56,36 @@ class TinyConvNet:
return x.dot(self.l1).logsoftmax()
class TestMNIST(unittest.TestCase):
gpu=False
device = Device.CPU
def test_conv(self):
np.random.seed(1337)
model = TinyConvNet()
optimizer = optim.Adam(model.parameters(), lr=0.001)
train(model, X_train, Y_train, optimizer, steps=200, gpu=self.gpu)
assert evaluate(model, X_test, Y_test, gpu=self.gpu) > 0.95
train(model, X_train, Y_train, optimizer, steps=200, device=self.device)
assert evaluate(model, X_test, Y_test, device=self.device) > 0.95
def test_sgd(self):
np.random.seed(1337)
model = TinyBobNet()
optimizer = optim.SGD(model.parameters(), lr=0.001)
train(model, X_train, Y_train, optimizer, steps=1000, gpu=self.gpu)
assert evaluate(model, X_test, Y_test, gpu=self.gpu) > 0.95
train(model, X_train, Y_train, optimizer, steps=1000, device=self.device)
assert evaluate(model, X_test, Y_test, device=self.device) > 0.95
def test_rmsprop(self):
np.random.seed(1337)
model = TinyBobNet()
optimizer = optim.RMSprop(model.parameters(), lr=0.0002)
train(model, X_train, Y_train, optimizer, steps=1000, gpu=self.gpu)
assert evaluate(model, X_test, Y_test, gpu=self.gpu) > 0.95
train(model, X_train, Y_train, optimizer, steps=1000, device=self.device)
assert evaluate(model, X_test, Y_test, device=self.device) > 0.95
@unittest.skipUnless(GPU, "Requires GPU")
class TestMNISTGPU(TestMNIST):
gpu = True
device = Device.GPU
@unittest.skipUnless(ANE, "Requires ANE")
class TestMNISTANE(TestMNIST):
device=Device.ANE
if __name__ == '__main__':
unittest.main()

@@ -4,7 +4,8 @@ import cProfile
import pstats
import unittest
import torch
from tinygrad.tensor import Tensor
from tinygrad.tensor import Tensor, GPU, Device
from .config import ANE
def start_profile():
import time
@@ -20,6 +21,8 @@ def stop_profile(pr, sort='cumtime'):
ps.print_stats(0.2)
class TestConvSpeed(unittest.TestCase):
device= Device.CPU
def test_mnist(self):
# https://keras.io/examples/vision/mnist_convnet/
conv = 3
@@ -62,15 +65,15 @@ class TestConvSpeed(unittest.TestCase):
# ****** tinygrad compare *******
c1 = Tensor(c1.detach().numpy())
c2 = Tensor(c2.detach().numpy())
l1 = Tensor(l1.detach().numpy())
c1 = Tensor(c1.detach().numpy(), device=self.device)
c2 = Tensor(c2.detach().numpy(), device=self.device)
l1 = Tensor(l1.detach().numpy(), device=self.device)
cnt = 5
fpt, bpt = 0.0, 0.0
for i in range(1+cnt):
et0 = time.time()
x = Tensor.randn(128, 1, 28, 28)
x = Tensor.randn(128, 1, 28, 28, device=self.device)
x = x.conv2d(c1).relu().avg_pool2d()
x = x.conv2d(c2).relu().max_pool2d()
x = x.reshape(shape=(x.shape[0], -1))
@@ -91,6 +94,14 @@ class TestConvSpeed(unittest.TestCase):
print("forward pass: %.3f ms, %.2fx off baseline %.3f ms" % (fpt, fpt/fpt_baseline, fpt_baseline))
print("backward pass: %.3f ms, %.2fx off baseline %.3f ms" % (bpt, bpt/bpt_baseline, bpt_baseline))
@unittest.skipUnless(GPU, "Requires GPU")
class TestConvSpeedGPU(TestConvSpeed):
device = Device.GPU
@unittest.skipUnless(ANE, "Requires ANE")
class TestConvSpeedANE(TestConvSpeed):
device=Device.ANE
if __name__ == '__main__':
unittest.main()

@@ -1,10 +1,15 @@
#!/usr/bin/env python
import unittest
import numpy as np
from tinygrad.tensor import GPU, Device
from tinygrad.nn import *
from extra.utils import get_parameters
import torch
from .config import ANE
class TestNN(unittest.TestCase):
device = Device.CPU
def test_batchnorm2d(self, training=False):
sz = 4
@@ -29,13 +34,13 @@ class TestNN(unittest.TestCase):
np.testing.assert_allclose(bn.running_var.data, tbn.running_var.detach().numpy(), rtol=1e-5)
# trial
inn = Tensor.randn(2, sz, 3, 3)
inn = Tensor.randn(2, sz, 3, 3, device=self.device)
# in tinygrad
outt = bn(inn)
# in torch
toutt = tbn(torch.tensor(inn.data))
toutt = tbn(torch.tensor(inn.cpu().data))
# close
np.testing.assert_allclose(outt.data, toutt.detach().numpy(), rtol=5e-5)
@@ -48,6 +53,27 @@ class TestNN(unittest.TestCase):
def test_batchnorm2d_training(self):
self.test_batchnorm2d(True)
@unittest.skipUnless(GPU, "Requires GPU")
class TestNNGPU(TestNN):
device = Device.GPU
@unittest.skip("Tests not added")
def test_batchnorm2d(self): pass
@unittest.skip("Tests not added")
def test_batchnorm2d_training(self): pass
@unittest.skipUnless(ANE, "Requires ANE")
class TestNNANE(TestNN):
device=Device.ANE
@unittest.skip("Tests not added")
def test_batchnorm2d(self): pass
@unittest.skip("Tests not added")
def test_batchnorm2d_training(self): pass
if __name__ == '__main__':
unittest.main()

@@ -4,14 +4,17 @@ import numpy as np
import unittest
import timeit
import functools
from tinygrad.tensor import Tensor, GPU
from tinygrad.tensor import Tensor, GPU, Device
from .config import ANE
def helper_test_op(shps, torch_fxn, tinygrad_fxn, atol=0, rtol=1e-6, grad_atol=0, grad_rtol=1e-6, gpu=False, forward_only=False):
def helper_test_op(shps, torch_fxn, tinygrad_fxn, atol=0, rtol=1e-6, grad_atol=0, grad_rtol=1e-6, device=Device.CPU, forward_only=False):
torch.manual_seed(0)
ts = [torch.rand(x, requires_grad=True) for x in shps]
tst = [Tensor(x.detach().numpy()) for x in ts]
if gpu:
tst = [x.cuda() for x in tst]
if device==Device.GPU:
tst = [x.gpu() for x in tst]
elif device==Device.ANE:
tst = [x.ane() for x in tst]
out = torch_fxn(*ts)
ret = tinygrad_fxn(*tst)
@@ -23,7 +26,7 @@ def helper_test_op(shps, torch_fxn, tinygrad_fxn, atol=0, rtol=1e-6, grad_atol=0
ret.mean().backward()
for t, tt in zip(ts, tst):
np.testing.assert_allclose(t.grad, tt.grad.cpu().data, atol=grad_atol, rtol=grad_rtol)
np.testing.assert_allclose(t.grad, tt.cpu().grad.data, atol=grad_atol, rtol=grad_rtol)
# speed
torch_fp = timeit.Timer(functools.partial(torch_fxn, *ts)).timeit(5) * 1000/5
@@ -38,58 +41,59 @@ def helper_test_op(shps, torch_fxn, tinygrad_fxn, atol=0, rtol=1e-6, grad_atol=0
print("testing %30r torch/tinygrad fp: %.2f / %.2f ms bp: %.2f / %.2f ms" % (shps, torch_fp, tinygrad_fp, torch_fbp-torch_fp, tinygrad_fbp-tinygrad_fp))
class TestOps(unittest.TestCase):
gpu = False
device=Device.CPU
def test_add(self):
helper_test_op([(45,65), (45,65)], lambda x,y: x+y, Tensor.add, gpu=self.gpu)
helper_test_op([(45,65), (45,65)], lambda x,y: x+y, Tensor.add, device=self.device)
def test_sub(self):
helper_test_op([(45,65), (45,65)], lambda x,y: x-y, Tensor.sub, gpu=self.gpu)
helper_test_op([(45,65), (45,65)], lambda x,y: x-y, Tensor.sub, device=self.device)
def test_mul(self):
helper_test_op([(45,65), (45,65)], lambda x,y: x*y, Tensor.mul, gpu=self.gpu)
helper_test_op([(45,65), (45,65)], lambda x,y: x*y, Tensor.mul, device=self.device)
def test_div(self):
helper_test_op([(45,65), (45,65)], lambda x,y: x/y, Tensor.div, gpu=self.gpu)
helper_test_op([(45,65), (45,65)], lambda x,y: x/y, Tensor.div, device=self.device)
def test_pow(self):
helper_test_op([(45,65), (45,65)], lambda x,y: x**y, Tensor.pow, gpu=self.gpu)
helper_test_op([(45,65), (45,65)], lambda x,y: x**y, Tensor.pow, device=self.device)
def test_sqrt(self):
helper_test_op([(45,65)], lambda x: x.sqrt(), Tensor.sqrt, gpu=self.gpu)
helper_test_op([(45,65)], lambda x: x.sqrt(), Tensor.sqrt, device=self.device)
def test_relu(self):
helper_test_op([(45,65)], lambda x: x.relu(), Tensor.relu, gpu=self.gpu)
helper_test_op([(45,65)], lambda x: x.relu(), Tensor.relu, device=self.device)
def test_leakyrelu(self):
helper_test_op([(45,65)], lambda x: torch.nn.functional.leaky_relu(x,0.01), Tensor.leakyrelu, gpu=self.gpu)
helper_test_op([(45,65)], lambda x: torch.nn.functional.leaky_relu(x,0.01), Tensor.leakyrelu, device=self.device)
def test_abs(self):
helper_test_op([(45,65)], lambda x: torch.abs(x), Tensor.abs, gpu=self.gpu)
helper_test_op([(45,65)], lambda x: torch.abs(x), Tensor.abs, device=self.device)
def test_sigmoid(self):
helper_test_op([(45,65)], lambda x: x.sigmoid(), Tensor.sigmoid, gpu=self.gpu)
helper_test_op([(45,65)], lambda x: x.sigmoid(), Tensor.sigmoid, device=self.device)
def test_dot(self):
helper_test_op([(45,65), (65,100)], lambda x,y: x.matmul(y), Tensor.dot, gpu=self.gpu)
helper_test_op([(45,65), (65,100)], lambda x,y: x.matmul(y), Tensor.dot, device=self.device)
def test_sum(self):
helper_test_op([(45,3)], lambda x: x.sum(), Tensor.sum, gpu=self.gpu)
helper_test_op([(45,3)], lambda x: x.sum(), Tensor.sum, device=self.device)
def test_sum_axis(self):
helper_test_op([(3,4,5,6)], lambda x: x.sum(axis=(1,2)), lambda x: Tensor.sum(x, axis=(1,2)), gpu=self.gpu)
helper_test_op([(3,4,5,6)], lambda x: x.sum(axis=(1,2)), lambda x: Tensor.sum(x, axis=(1,2)), device=self.device)
def test_mean_axis(self):
helper_test_op([(3,4,5,6)], lambda x: x.mean(axis=(1,2)), lambda x: Tensor.mean(x, axis=(1,2)), gpu=self.gpu)
helper_test_op([(3,4,5,6)], lambda x: x.mean(axis=(1,2)), lambda x: Tensor.mean(x, axis=(1,2)), device=self.device)
def test_logsoftmax(self):
helper_test_op([(45,65)], lambda x: torch.nn.LogSoftmax(dim=1)(x), Tensor.logsoftmax, atol=1e-7, grad_atol=1e-7, gpu=self.gpu)
helper_test_op([(45,65)], lambda x: torch.nn.LogSoftmax(dim=1)(x), Tensor.logsoftmax, atol=1e-7, grad_atol=1e-7, device=self.device)
def test_tanh(self):
helper_test_op([(45,65)], lambda x: x.tanh(), Tensor.tanh, atol=1e-6, grad_atol=1e-6, gpu=self.gpu)
helper_test_op([(45,65)], lambda x: x.tanh(), Tensor.tanh, atol=1e-6, grad_atol=1e-6, device=self.device)
def test_topo_sort(self):
helper_test_op([(45,65)], lambda x: (x+x)*x, lambda x: x.add(x).mul(x), atol=1e-6, grad_atol=1e-6, gpu=self.gpu)
helper_test_op([(45,65)], lambda x: (x+x)*x, lambda x: x.add(x).mul(x), atol=1e-6, grad_atol=1e-6, device=self.device)
def test_scalar_mul(self):
helper_test_op([(45,65)], lambda x: x*2, lambda x: x*2, gpu=self.gpu)
helper_test_op([(45,65)], lambda x: x*2, lambda x: x*2, device=self.device)
def test_scalar_rmul(self):
helper_test_op([(45,65)], lambda x: 2*x, lambda x: 2*x, gpu=self.gpu)
helper_test_op([(45,65)], lambda x: 2*x, lambda x: 2*x, device=self.device)
def test_scalar_sub(self):
helper_test_op([(45,65)], lambda x: x-2, lambda x: x-2, gpu=self.gpu)
helper_test_op([(45,65)], lambda x: x-2, lambda x: x-2, device=self.device)
def test_scalar_rsub(self):
helper_test_op([(45,65)], lambda x: 2-x, lambda x: 2-x, gpu=self.gpu)
helper_test_op([(45,65)], lambda x: 2-x, lambda x: 2-x, device=self.device)
def test_broadcast_full(self):
for torch_op, tinygrad_op in [(torch.add, Tensor.add), (torch.sub, Tensor.sub), (torch.mul, Tensor.mul),
(torch.div, Tensor.div), (torch.pow, Tensor.pow)]:
for shapes in [((5,13,24,16), (5,1,24,1)), ((1,3,1,7,1), (2,1,5,1,8))]:
with self.subTest(op=torch_op.__name__, shapes=shapes):
helper_test_op(shapes, torch_op, tinygrad_op, gpu=self.gpu)
helper_test_op(shapes, torch_op, tinygrad_op, device=self.device)
def test_broadcast_partial(self):
@@ -98,17 +102,18 @@ class TestOps(unittest.TestCase):
for shapes in [((1,32,32,32), (1,32,1,1)), ((5,13,24,16,2), (1,13,24,1,1)),
((4,1), (4,5)), ((1,4), (5,4))]:
with self.subTest(op=torch_op.__name__, shapes=shapes):
helper_test_op(shapes, torch_op, tinygrad_op, gpu=self.gpu, forward_only=self.gpu)
# NOTE: ANE backwards?
helper_test_op(shapes, torch_op, tinygrad_op, device=self.device, forward_only=self.device!=Device.CPU)
def test_pad2d(self):
helper_test_op([(3,3,3,3)], lambda x: torch.nn.functional.pad(x, (1,2,3,4)), lambda x: x.pad2d(padding=(1,2,3,4)), gpu=self.gpu)
helper_test_op([(3,3,3,3)], lambda x: torch.nn.functional.pad(x, (1,2,3,4)), lambda x: x.pad2d(padding=(1,2,3,4)), device=self.device)
def test_reshape(self):
helper_test_op([(4,3,6,6)], lambda x: torch.reshape(x, (-1,3,6,6)), lambda x: x.reshape(shape=(-1,3,6,6)), gpu=self.gpu)
helper_test_op([(4,3,6,6)], lambda x: torch.reshape(x, (-1,1,6,6)), lambda x: x.reshape(shape=(-1,1,6,6)), gpu=self.gpu)
helper_test_op([(4,3,6,6)], lambda x: torch.reshape(x, (-1,3,6,6)), lambda x: x.reshape(shape=(-1,3,6,6)), device=self.device)
helper_test_op([(4,3,6,6)], lambda x: torch.reshape(x, (-1,1,6,6)), lambda x: x.reshape(shape=(-1,1,6,6)), device=self.device)
def test_detach(self):
helper_test_op([(4,3,6,6)], lambda x: x.detach(), lambda x: x.detach(), gpu=self.gpu, forward_only=True)
helper_test_op([(4,3,6,6)], lambda x: x.detach(), lambda x: x.detach(), device=self.device, forward_only=True)
def test_conv2d(self):
for bs in [1,8]:
@@ -119,7 +124,7 @@ class TestOps(unittest.TestCase):
with self.subTest(batch_size=bs, channels=cin, groups=groups, height=H, width=W):
helper_test_op([(bs,cin,11,28), (6,cin//groups,H,W)],
lambda x,w: torch.nn.functional.conv2d(x,w,groups=groups).relu(),
lambda x,w: Tensor.conv2d(x,w,groups=groups).relu(), gpu=self.gpu, grad_rtol=1e-5)
lambda x,w: Tensor.conv2d(x,w,groups=groups).relu(), device=self.device, grad_rtol=1e-5)
def test_strided_conv2d(self):
bs = 4
@@ -128,18 +133,18 @@ class TestOps(unittest.TestCase):
with self.subTest(stride := 2):
helper_test_op([(bs,cin,11,28), (4,cin,H,W)],
lambda x,w: torch.nn.functional.conv2d(x,w,stride=2).relu(),
lambda x,w: Tensor.conv2d(x,w,stride=stride).relu(), gpu=self.gpu)
lambda x,w: Tensor.conv2d(x,w,stride=stride).relu(), device=self.device)
with self.subTest(stride := (2,1)):
helper_test_op([(bs,cin,11,28), (4,cin,H,W)],
lambda x,w: torch.nn.functional.conv2d(x,w,stride=stride).relu(),
lambda x,w: Tensor.conv2d(x,w,stride=(2,1)).relu(), gpu=self.gpu)
lambda x,w: Tensor.conv2d(x,w,stride=(2,1)).relu(), device=self.device)
def test_maxpool2d(self):
for ksz in [(2,2), (3,3), (3,2), (5,5), (5,1)]:
with self.subTest(kernel_size=ksz):
helper_test_op([(32,2,110,28)],
lambda x: torch.nn.functional.max_pool2d(x, kernel_size=ksz),
lambda x: Tensor.max_pool2d(x, kernel_size=ksz), gpu=self.gpu)
lambda x: Tensor.max_pool2d(x, kernel_size=ksz), device=self.device)
def test_avgpool2d(self):
shape = (32,2,111,28)
@@ -147,11 +152,15 @@ class TestOps(unittest.TestCase):
with self.subTest(kernel_size=ksz):
helper_test_op([shape],
lambda x: torch.nn.functional.avg_pool2d(x, kernel_size=ksz),
lambda x: Tensor.avg_pool2d(x, kernel_size=ksz), gpu=self.gpu)
lambda x: Tensor.avg_pool2d(x, kernel_size=ksz), device=self.device)
@unittest.skipUnless(GPU, "Requires GPU")
class TestOpsGPU(TestOps):
gpu = True
device=Device.GPU
@unittest.skipUnless(ANE, "Requires ANE")
class TestOpsANE(TestOps):
device=Device.ANE
if __name__ == '__main__':
unittest.main(verbosity=2)

@@ -1,18 +1,20 @@
import numpy as np
import torch
import unittest
from tinygrad.tensor import Tensor, GPU
from tinygrad.tensor import Tensor, GPU, Device
from tinygrad.optim import Adam, SGD, RMSprop
from extra.utils import get_parameters
from .config import ANE
x_init = np.random.randn(1,3).astype(np.float32)
W_init = np.random.randn(3,3).astype(np.float32)
m_init = np.random.randn(1,3).astype(np.float32)
def step_tinygrad(optim, kwargs={}, gpu=False):
def step_tinygrad(optim, kwargs={}, device=Device.CPU):
net = TinyNet()
optim = optim([net.x, net.W], **kwargs)
if gpu is True: [x.cuda_() for x in get_parameters([net, optim])]
if device==Device.GPU: [x.gpu_() for x in get_parameters([net, optim])]
elif device==Device.ANE: [x.ane_() for x in get_parameters([net, optim])]
out = net.forward()
out.backward()
optim.step()
@@ -54,20 +56,20 @@ class TorchNet():
class TestOptim(unittest.TestCase):
gpu = False
device = Device.CPU
def test_adam(self):
for x,y in zip(step_tinygrad(Adam, gpu=self.gpu),
for x,y in zip(step_tinygrad(Adam, device=self.device),
step_pytorch(torch.optim.Adam)):
np.testing.assert_allclose(x, y, atol=1e-4)
def test_sgd(self):
for x,y in zip(step_tinygrad(SGD, kwargs={'lr': 0.001}, gpu=self.gpu),
for x,y in zip(step_tinygrad(SGD, kwargs={'lr': 0.001}, device=self.device),
step_pytorch(torch.optim.SGD, kwargs={'lr': 0.001})):
np.testing.assert_allclose(x, y, atol=1e-5)
def test_rmsprop(self):
for x,y in zip(step_tinygrad(RMSprop, kwargs={'lr': 0.001, 'decay': 0.99}, gpu=self.gpu),
for x,y in zip(step_tinygrad(RMSprop, kwargs={'lr': 0.001, 'decay': 0.99}, device=self.device),
step_pytorch(torch.optim.RMSprop,
kwargs={'lr': 0.001, 'alpha': 0.99})):
np.testing.assert_allclose(x, y, atol=1e-5)
@@ -75,7 +77,11 @@ class TestOptim(unittest.TestCase):
@unittest.skipUnless(GPU, "Requires GPU")
class TestOptimGPU(TestOptim):
gpu = True
device = Device.GPU
@unittest.skipUnless(ANE, "Requires ANE")
class TestOptimANE(TestOptim):
device = Device.ANE
if __name__ == '__main__':

@@ -1,8 +1,10 @@
import numpy as np
import torch
import unittest
from tinygrad.tensor import Tensor, GPU
from tinygrad.tensor import Tensor, GPU, Device
from extra.gradcheck import numerical_jacobian, jacobian, gradcheck
from .config import ANE
x_init = np.random.randn(1,3).astype(np.float32)
U_init = np.random.randn(3,3).astype(np.float32)
@@ -11,13 +13,13 @@ W_init = np.random.randn(3,3).astype(np.float32)
m_init = np.random.randn(1,3).astype(np.float32)
class TestTinygrad(unittest.TestCase):
gpu = False
device = Device.CPU
def test_backward_pass(self):
def test_tinygrad():
x = Tensor(x_init, gpu=self.gpu)
W = Tensor(W_init, gpu=self.gpu)
m = Tensor(m_init, gpu=self.gpu)
x = Tensor(x_init, device=self.device)
W = Tensor(W_init, device=self.device)
m = Tensor(m_init, device=self.device)
out = x.dot(W).relu()
out = out.logsoftmax()
out = out.mul(m).add(m).sum()
@@ -39,16 +41,16 @@ class TestTinygrad(unittest.TestCase):
def test_backward_pass_diamond_model(self):
def test_tinygrad():
u = Tensor(U_init)
v = Tensor(V_init)
w = Tensor(W_init)
u = Tensor(U_init, device=self.device)
v = Tensor(V_init, device=self.device)
w = Tensor(W_init, device=self.device)
x = u.mul(v).relu()
y = u.mul(w).relu()
out = x.add(y).mul(y).relu()
out = out.logsoftmax()
out = out.sum()
out.backward()
return out.data, u.grad.data, v.grad.data, w.grad.data
return out.cpu().data, u.cpu().grad.data, v.cpu().grad.data, w.cpu().grad.data
def test_pytorch():
u = torch.tensor(U_init, requires_grad=True)
@@ -74,8 +76,8 @@ class TestTinygrad(unittest.TestCase):
torch_func = lambda x: torch.nn.functional.log_softmax(x.matmul(torch_W).relu(), dim=1)
PJ = torch.autograd.functional.jacobian(torch_func, torch_x).squeeze().numpy()
tiny_x = Tensor(x, gpu=self.gpu)
tiny_W = Tensor(W, gpu=self.gpu)
tiny_x = Tensor(x, device=self.device)
tiny_W = Tensor(W, device=self.device)
tiny_func = lambda x: x.dot(tiny_W).relu().logsoftmax()
J = jacobian(tiny_func, tiny_x)
NJ = numerical_jacobian(tiny_func, tiny_x)
@@ -87,8 +89,8 @@ class TestTinygrad(unittest.TestCase):
W = np.random.RandomState(1337).random((10, 5))
x = np.random.RandomState(7331).random((1, 10)) - 0.5
tiny_x = Tensor(x, gpu=self.gpu)
tiny_W = Tensor(W, gpu=self.gpu)
tiny_x = Tensor(x, device=self.device)
tiny_W = Tensor(W, device=self.device)
tiny_func = lambda x: x.dot(tiny_W).relu().logsoftmax()
self.assertTrue(gradcheck(tiny_func, tiny_x))
@@ -99,7 +101,7 @@ class TestTinygrad(unittest.TestCase):
@unittest.skipUnless(GPU, "Requires GPU")
class TestTinygradGPU(TestTinygrad):
gpu = True
device = Device.GPU
@unittest.skip("float64 not supported on GPU")
def test_jacobian(self): pass
@@ -107,6 +109,9 @@ class TestTinygradGPU(TestTinygrad):
@unittest.skip("float64 not supported on GPU")
def test_gradcheck(self): pass
@unittest.skipUnless(ANE, "Requires ANE")
class TestOpsANE(TestTinygrad):
device=Device.ANE
if __name__ == '__main__':
unittest.main()

@@ -1,5 +1,5 @@
import numpy as np
from .tensor import Function, register, GPUBuffer, Tensor
from .tensor import Function, register, GPUBuffer, Tensor, Device
import pyopencl as cl
import functools
@@ -178,7 +178,7 @@ class Add(Function):
grad_x, grad_y = grad_output, grad_output
shape_x, shape_y = ctx.saved_tensors
return unbroadcast(ctx, grad_x, shape_x), unbroadcast(ctx, grad_y, shape_y),
register('add', Add, device=Tensor.GPU)
register('add', Add, device=Device.GPU)
class Sub(Function):
@staticmethod
@@ -191,7 +191,7 @@ class Sub(Function):
grad_x, grad_y = grad_output, unary_op(ctx, '-a', grad_output)
shape_x, shape_y = ctx.saved_tensors
return unbroadcast(ctx, grad_x, shape_x), unbroadcast(ctx, grad_y, shape_y),
register('sub', Sub, device=Tensor.GPU)
register('sub', Sub, device=Device.GPU)
class Mul(Function):
@staticmethod
@@ -205,7 +205,7 @@ class Mul(Function):
grad_x = binary_op(ctx, 'a*b', y, grad_output)
grad_y = binary_op(ctx, 'a*b', x, grad_output)
return unbroadcast(ctx, grad_x, x.shape), unbroadcast(ctx, grad_y, y.shape),
register('mul', Mul, device=Tensor.GPU)
register('mul', Mul, device=Device.GPU)
class Pow(Function):
@staticmethod
@@ -221,7 +221,7 @@ class Pow(Function):
grad_y = binary_op(ctx, 'a*b', grad_output,
binary_op(ctx, 'pow(a, (float)b) * log(a);', x, y))
return unbroadcast(ctx, grad_x, x.shape), unbroadcast(ctx, grad_y, y.shape),
register('pow', Pow, device=Tensor.GPU)
register('pow', Pow, device=Device.GPU)
class Sum(Function):
@staticmethod
@@ -237,8 +237,8 @@ class Sum(Function):
input, axis = ctx.saved_tensors
shape = [1 if axis is None or i in axis else input.shape[i] for i in range(len(input.shape))]
output = GPUBuffer(shape, hostbuf=grad_output)
return binary_op(ctx, 'a+b', output, buffer_new(ctx, input.shape))
register('sum', Sum, device=Tensor.GPU)
return binary_op(ctx, 'a+b', output, buffer_new(ctx, input.shape, zero=True))
register('sum', Sum, device=Device.GPU)
class Dot(Function):
@staticmethod
@@ -289,7 +289,7 @@ class Dot(Function):
i32(1), msize, isize, i32(1), osize, osize)
return grad_input, grad_weight
register('dot', Dot, device=Tensor.GPU)
register('dot', Dot, device=Device.GPU)
# ************* simple ops *************
@@ -332,7 +332,7 @@ class Pad2D(Function):
i32(oy), i32(ox), i32(iy), i32(ix)
)
return ret
register('pad2d', Pad2D, device=Tensor.GPU)
register('pad2d', Pad2D, device=Device.GPU)
class Reshape(Function):
@staticmethod
@@ -347,7 +347,7 @@ class Reshape(Function):
def backward(ctx, grad_output):
in_shape, = ctx.saved_tensors
return GPUBuffer(in_shape, hostbuf=grad_output)
register('reshape', Reshape, device=Tensor.GPU)
register('reshape', Reshape, device=Device.GPU)
# ************* activation ops *************
@@ -361,7 +361,7 @@ class ReLU(Function):
def backward(ctx, grad_output):
input, = ctx.saved_tensors
return binary_op(ctx, 'a * (b >= 0)', grad_output, input)
register('relu', ReLU, device=Tensor.GPU)
register('relu', ReLU, device=Device.GPU)
class Sigmoid(Function):
@staticmethod
@@ -374,7 +374,7 @@ class Sigmoid(Function):
def backward(ctx, grad_output):
ret, = ctx.saved_tensors
return binary_op(ctx, 'a * (b * (1 - b));', grad_output, ret)
register('sigmoid', Sigmoid, device=Tensor.GPU)
register('sigmoid', Sigmoid, device=Device.GPU)
class AvgPool2D(Function):
@staticmethod
@@ -389,7 +389,7 @@ class AvgPool2D(Function):
orig_shape, = ctx.saved_tensors
return supersample_op(ctx, grad_output, orig_shape, ctx.kernel_size,
result_op="input[iid] / (ksz.x * ksz.y)")
register('avg_pool2d', AvgPool2D, device=Tensor.GPU)
register('avg_pool2d', AvgPool2D, device=Device.GPU)
class MaxPool2D(Function):
@staticmethod
@@ -409,7 +409,7 @@ class MaxPool2D(Function):
result_op="(maxidx == kernidx) * input[iid]",
decls="int maxidx=((__global float*)input2)[iid]; int kernidx=(gid.x%ksz.x) + ksz.x*(gid.y%ksz.y)",
input2=idxs)
register('max_pool2d', MaxPool2D, device=Tensor.GPU)
register('max_pool2d', MaxPool2D, device=Device.GPU)
class LogSoftmax(Function):
@staticmethod
@@ -426,7 +426,7 @@ class LogSoftmax(Function):
lsum = reduce_op(ctx, "out += a", "out", grad_output, axis=[1])
texp = binary_op(ctx, "exp(a) * b", output, lsum)
return binary_op(ctx, "a - b", grad_output, texp)
register('logsoftmax', LogSoftmax, device=Tensor.GPU)
register('logsoftmax', LogSoftmax, device=Device.GPU)
# ************* conv ops *************
@@ -553,4 +553,4 @@ class Conv2D(Function):
convw(ctx.cl_queue, [ctx.groups*rcout*cin, H, W], None, x.cl, grad_output.cl, dw.cl, *conv_args)
convx(ctx.cl_queue, [bs, ctx.groups, cin], None, w.cl, grad_output.cl, dx.cl, *conv_args)
return dx, dw
register('conv2d', Conv2D, device=Tensor.GPU)
register('conv2d', Conv2D, device=Device.GPU)

@@ -25,7 +25,7 @@ class RMSprop(Optimizer):
super(RMSprop, self).__init__(params)
self.lr, self.decay, self.eps = lr, decay, eps
self.v = [Tensor(np.zeros(t.shape, dtype=np.float32), gpu=params[0].gpu, requires_grad=False) for t in self.params]
self.v = [Tensor(np.zeros(t.shape, dtype=np.float32), device=params[0].device, requires_grad=False) for t in self.params]
def step(self):
for i, t in enumerate(self.params):
@@ -37,8 +37,8 @@ class Adam(Optimizer):
super(Adam, self).__init__(params)
self.lr, self.b1, self.b2, self.eps, self.t = lr, b1, b2, eps, 0
self.m = [Tensor(np.zeros(t.shape, dtype=np.float32), gpu=params[0].gpu, requires_grad=False) for t in self.params]
self.v = [Tensor(np.zeros(t.shape, dtype=np.float32), gpu=params[0].gpu, requires_grad=False) for t in self.params]
self.m = [Tensor(np.zeros(t.shape, dtype=np.float32), device=params[0].device, requires_grad=False) for t in self.params]
self.v = [Tensor(np.zeros(t.shape, dtype=np.float32), device=params[0].device, requires_grad=False) for t in self.params]
def step(self):
self.t = self.t + 1

@@ -1,5 +1,6 @@
# inspired by https://github.com/karpathy/micrograd/blob/master/micrograd/engine.py
from inspect import signature
import functools
import numpy as np
import os
from collections import defaultdict
@@ -34,6 +35,7 @@ class ProfileOp:
cl_ctx, cl_queue = None, None
def require_init_gpu():
if not GPU: raise Exception("No GPU Support, install pyopencl")
global cl_ctx, cl_queue
if cl_queue is None:
devices = cl.get_platforms()[0].get_devices(device_type=cl.device_type.GPU)
@@ -64,33 +66,16 @@ def require_init_ane():
# **** start with two base classes, Tensor and Function ****
class Device: CPU, GPU, ANE = 0, 1, 2
class Tensor:
did_float_warning = False
ops = defaultdict(dict)
CPU, GPU, ANE = 0, 1, 2
def __init__(self, data, device=Device.CPU, requires_grad=True):
self.data = self._move_data(data, device)
def __init__(self, data, gpu=None, requires_grad=True):
if "ANETensor" in str(type(data)):
self.device = Tensor.ANE
elif isinstance(data, list):
data = np.array(data, dtype=np.float32)
elif GPU and isinstance(data, GPUBuffer):
self.device = Tensor.GPU
elif not isinstance(data, np.ndarray):
raise TypeError(f"Error constructing tensor with {data!r}")
if isinstance(data, np.ndarray):
if data.dtype != np.float32 and not Tensor.did_float_warning:
# warning? float64 is actually needed for numerical jacobian
print(f"warning, {data.shape!r} isn't float32")
Tensor.did_float_warning = True
self.device = Tensor.CPU
self.data, self.grad, self.requires_grad = data, None, requires_grad
if gpu:
self.cuda_()
self.device, self.grad, self.requires_grad = device, None, requires_grad
# internal variables used for autograd graph construction
self._ctx = None
@@ -145,7 +130,7 @@ class Tensor:
# fill in the first grad with one
# this is "implicit gradient creation"
self.grad = Tensor(np.ones(self.shape, dtype=self.dtype), gpu=self.gpu, requires_grad=False)
self.grad = Tensor(np.ones(self.shape, dtype=self.dtype), device=self.device, requires_grad=False)
for t0 in reversed(self.deepwalk(set(), [])):
assert (t0.grad is not None)
@@ -157,53 +142,59 @@ class Tensor:
if g is not None:
assert g.shape == t.shape, \
f"grad shape must match tensor shape in {self._ctx!r}, {g.shape!r} != {t.shape!r}"
gt = Tensor(g, requires_grad=False)
gt = Tensor(g, device=self.device, requires_grad=False)
t.grad = gt if t.grad is None else (t.grad + gt)
# ***** tinygrad supports CPU and GPU *****
def cpu(self):
if self.device == Tensor.GPU:
with ProfileOp("toCPU", [self]):
ret = Tensor(np.empty(self.shape, dtype=np.float32), gpu=False)
cl.enqueue_copy(cl_queue, ret.data, self.data.cl, is_blocking=True)
if self.grad:
ret.grad = self.grad.cpu()
return ret
elif self.device == Tensor.ANE:
return Tensor(self.data.data().astype(np.float32), gpu=False)
else:
return self
@staticmethod
def _move_data(data, device):
if isinstance(data, GPUBuffer):
if device == Device.GPU: return data
old = data
data = np.empty(old.shape, dtype=np.float32)
with ProfileOp("toCPU", [data]):
cl.enqueue_copy(cl_queue, data, old.cl, is_blocking=True)
@property
def gpu(self):
return self.device == Tensor.GPU
elif "ANETensor" in str(type(data)):
if device == Device.ANE: return data
with ProfileOp("toCPU", [data]):
data = data.data().astype(np.float32)
def cuda_(self):
self.data = self.cuda().data
self.device = Tensor.GPU
if not isinstance(data, np.ndarray):
data = np.array(data, dtype=np.float32)
def cuda(self):
if not GPU:
raise Exception("No GPU Support, install pyopencl")
if not self.gpu:
with ProfileOp("toGPU", [self]):
require_init_gpu()
ret = Tensor(GPUBuffer(self.shape, self.data))
if self.grad:
ret.grad = self.grad.cuda()
return ret
return self
if data.dtype != np.float32 and not Tensor.did_float_warning:
# warning? float64 is actually needed for numerical jacobian
print(f"warning, {data.shape!r} isn't float32")
Tensor.did_float_warning = True
def ane(self):
assert(not self.gpu)
require_init_ane()
ndata = ane.tensor(self.shape)
ndata.data()[:] = self.data
return Tensor(ndata)
if device == Device.GPU:
require_init_gpu()
with ProfileOp("toGPU", [data]):
return GPUBuffer(data.shape, data)
elif device == Device.ANE:
require_init_ane()
with ProfileOp("toANE", [data]):
ndata = ane.tensor(data.shape)
ndata.data()[:] = data
return ndata
return data
def to_(self, device):
self.data, self.device = self._move_data(self.data, device), device
if self.grad: self.grad.to_(device)
def to(self, device):
ret = Tensor(self.data, device)
if self.grad: ret.grad = self.grad.to(device)
return ret
def _is(self, device): return self.device == device
def detach(self):
return Tensor(self.data, self.gpu)
return Tensor(self.data, device=self.device)
# ***** non first class ops *****
@@ -232,7 +223,7 @@ class Tensor:
def dropout(self, p=0.5):
_mask = np.asarray(np.random.binomial(1, 1.0-p, size=self.shape), dtype=self.dtype)
ret = self * Tensor(_mask, requires_grad=False, gpu=self.gpu)
ret = self * Tensor(_mask, requires_grad=False, device=self.device)
return ret.div(1.0 - p)
def abs(self):
@@ -259,18 +250,18 @@ class Function:
setattr(ctx, k, v)
with ProfileOp(ctx.__class__.__name__, x):
ret = Tensor(self.forward(ctx, *[t.data for t in x], **kwargs),
requires_grad=any([t.requires_grad for t in x]))
device=ctx.device, requires_grad=any([t.requires_grad for t in x]))
if ret.requires_grad:
ret._ctx = ctx
return ret
def register(name, fxn, device=Tensor.CPU):
def register(name, fxn, device=Device.CPU):
Tensor.ops[device][name] = fxn
def dispatch(*x, **kwargs):
tt = [arg for arg in x if isinstance(arg, Tensor)][0]
x = [Tensor(np.array([arg], dtype=tt.dtype), gpu=tt.gpu, requires_grad=False) if not isinstance(arg, Tensor) else arg for arg in x]
x = [Tensor(np.array([arg], dtype=tt.dtype), device=tt.device, requires_grad=False) if not isinstance(arg, Tensor) else arg for arg in x]
f = (Tensor.ops[tt.device])[name]
f.cl_ctx, f.cl_queue, f.ane = cl_ctx, cl_queue, ane
f.cl_ctx, f.cl_queue, f.ane, f.device = cl_ctx, cl_queue, ane, tt.device
return f.apply(f, *x, **kwargs)
setattr(Tensor, name, dispatch)
# TODO: div is a second class op, so it doesn't work here
@@ -279,6 +270,11 @@ def register(name, fxn, device=Tensor.CPU):
setattr(Tensor, f"__i{name}__", lambda self,x: self.assign(dispatch(self,x)))
setattr(Tensor, f"__r{name}__", lambda self,x: dispatch(x,self))
for device in [device for device in Device.__dict__.keys() if device[0] != "_"]:
setattr(Tensor, f"{device.lower()}", functools.partialmethod(Tensor.to, Device.__dict__[device]))
setattr(Tensor, f"{device.lower()}_", functools.partialmethod(Tensor.to_, Device.__dict__[device]))
setattr(Tensor, f"is_{device.lower()}", property(functools.partialmethod(Tensor._is, Device.__dict__[device])))
# this registers all the operations
import tinygrad.ops_cpu
try: