All devices are equal! (#196)

* Update all devices to be tested

  ANE, CPU and OCL all now support all tests. However, tests are not currently passing on GPU and I cannot test on CPU. The failing GPU tests are not caused by this update; tests have not been passing due to a missing required installation of "six". OpenCL tests have not been run since commit: 1a1c63a08b

  Devices have 3 types and are handled by a new DeviceTypes enum. (The goal is to revert to Tensor.<type>, but the current setup allows for keyword argument defaults: `device=DeviceType.CPU`.)

  All references to Tensor.GPU/CPU/ANE have been converted to the corresponding `DeviceTypes` enum.

  The conversion code has been refactored to allow any-device-to-any-device conversion.

* Add six dependency in requirements.txt

* Resolve failure to run tests

  Move six into the gpu extra installs and remove it from the standard installation.

* Remove repeated data conversion

* Refactor method names

  Also reduce code with .to and .to_ (see the sketch after this commit message).

* Dynamic device handlers

* Refactor DeviceTypes -> Device

* Add mem copy profiling back

* test_backward_pass_diamond_model passing

* Resolve Sum issue on GPU

* Revert batchnorm2d tests

* Update README with updated API

* ANE testing with

* Last minute line gains
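
To make the refactor concrete, here is a minimal sketch of the device API after this change, pieced together from the diff below. It assumes a CPU-only install (no pyopencl or ANE available) and is illustrative rather than authoritative; the `Device` enum, the `device=` keyword, `.to()`/`.to_()`, and the generated per-device helpers are the pieces introduced in this PR.

```python
from tinygrad.tensor import Tensor, Device

# tensors now take a device= keyword instead of gpu=True
a = Tensor.zeros(4, 4, device=Device.CPU)
b = Tensor.zeros(4, 4, device=Device.CPU)

# .to() returns a copy on the target device, .to_() converts in place
c = (a * b).mean()
d = c.to(Device.CPU)
c.to_(Device.CPU)

# per-device helpers are generated from the Device enum:
# .cpu()/.gpu()/.ane() wrap .to(), .cpu_()/.gpu_()/.ane_() wrap .to_(),
# and .is_cpu/.is_gpu/.is_ane are properties
print(d.cpu().data, d.is_cpu)
```
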
@@ -84,7 +84,7 @@ tinygrad supports GPUs through PyOpenCL.
 
 ```python
 from tinygrad.tensor import Tensor
-(Tensor.ones(4,4).cuda() + Tensor.ones(4,4).cuda()).cpu()
+(Tensor.ones(4,4).gpu() + Tensor.ones(4,4).gpu()).cpu()
 ```
 
 ### ANE Support?!

@@ -2,21 +2,22 @@ import os
 import numpy as np
 from tqdm import trange
 from extra.utils import get_parameters
-from tinygrad.tensor import Tensor, GPU
+from tinygrad.tensor import Tensor, GPU, Device
 
-def train(model, X_train, Y_train, optim, steps, num_classes=None, BS=128, gpu=False, lossfn = lambda out,y: out.mul(y).mean()):
-  if gpu is True: [x.cuda_() for x in get_parameters([model, optim])]
+def train(model, X_train, Y_train, optim, steps, num_classes=None, BS=128, device=Device.CPU, lossfn = lambda out,y: out.mul(y).mean()):
+  if device == Device.GPU: [x.gpu_() for x in get_parameters([model, optim])]
+  elif device == Device.ANE: [x.ane_() for x in get_parameters([model, optim])]
   if num_classes is None: num_classes = Y_train.max().astype(int)+1
   losses, accuracies = [], []
   for i in (t := trange(steps, disable=os.getenv('CI') is not None)):
     samp = np.random.randint(0, X_train.shape[0], size=(BS))
 
-    x = Tensor(X_train[samp].reshape((-1, 28*28)).astype(np.float32), gpu=gpu)
+    x = Tensor(X_train[samp].reshape((-1, 28*28)).astype(np.float32), device=device)
     Y = Y_train[samp]
     y = np.zeros((len(samp),num_classes), np.float32)
     # correct loss for NLL, torch NLL loss returns one per row
     y[range(y.shape[0]),Y] = -1.0*num_classes
-    y = Tensor(y, gpu=gpu)
+    y = Tensor(y, device=device)
 
     # network
     out = model.forward(x)
@@ -36,11 +37,11 @@ def train(model, X_train, Y_train, optim, steps, num_classes=None, BS=128, gpu=F
     accuracies.append(accuracy)
     t.set_description("loss %.2f accuracy %.2f" % (loss, accuracy))
 
-def evaluate(model, X_test, Y_test, num_classes=None, gpu=False, BS=128):
+def evaluate(model, X_test, Y_test, num_classes=None, device=Device.CPU, BS=128):
   def numpy_eval(num_classes):
     Y_test_preds_out = np.zeros((len(Y_test),num_classes))
     for i in trange(len(Y_test)//BS, disable=os.getenv('CI') is not None):
-      Y_test_preds_out[i*BS:(i+1)*BS] = model.forward(Tensor(X_test[i*BS:(i+1)*BS].reshape((-1, 28*28)).astype(np.float32), gpu=gpu)).cpu().data
+      Y_test_preds_out[i*BS:(i+1)*BS] = model.forward(Tensor(X_test[i*BS:(i+1)*BS].reshape((-1, 28*28)).astype(np.float32), device=device)).cpu().data
     Y_test_preds = np.argmax(Y_test_preds_out, axis=1)
     return (Y_test == Y_test_preds).mean()
 

setup.py
@@ -22,7 +22,7 @@ setup(name='tinygrad',
       install_requires=['numpy', 'requests'],
       python_requires='>=3.8',
       extras_require={
-        'gpu': ["pyopencl"],
+        'gpu': ["pyopencl", "six"],
         'testing': [
             "pytest",
             "torch",

test/__init__.py (new, empty file)

test/config.py (new file)
@@ -0,0 +1,3 @@
+import os
+
+ANE = os.environ.get('ANE', False)

@@ -1,31 +1,32 @@
 #!/usr/bin/env python
 import gc
 import unittest
-from tinygrad.tensor import Tensor, GPU
+from tinygrad.tensor import Tensor, GPU, Device
+from .config import ANE
 
 def tensors_allocated():
   return sum([isinstance(x, Tensor) for x in gc.get_objects()])
 
 class TestGC(unittest.TestCase):
-  gpu = False
+  device = Device.CPU
 
   def test_gc(self):
-    a = Tensor.zeros(4,4, gpu=self.gpu)
-    b = Tensor.zeros(4,4, gpu=self.gpu)
+    a = Tensor.zeros(4,4, device=self.device)
+    b = Tensor.zeros(4,4, device=self.device)
     (a*b).mean().backward()
     assert(tensors_allocated() > 0)
     del a,b
     assert(tensors_allocated() == 0)
 
   def test_gc_complex(self):
-    a = Tensor.zeros(4,4, gpu=self.gpu)
-    b = Tensor.zeros(4,4, gpu=self.gpu)
+    a = Tensor.zeros(4,4, device=self.device)
+    b = Tensor.zeros(4,4, device=self.device)
     assert(tensors_allocated() == 2)
     (a*b).mean().backward()
     assert(tensors_allocated() == 4)
     del b
     assert(tensors_allocated() == 2)
-    b = Tensor.zeros(4,4, gpu=self.gpu)
+    b = Tensor.zeros(4,4, device=self.device)
     print(tensors_allocated())
     (a*b).mean().backward()
     print(tensors_allocated())
@@ -33,11 +34,13 @@ class TestGC(unittest.TestCase):
     del b
     assert(tensors_allocated() == 2)
 
 
-if GPU:
-  class TestGCGPU(TestGC):
-    gpu = True
+@unittest.skipUnless(GPU, "Requires GPU")
+class TestGCGPU(TestGC):
+  device = Device.GPU
+
+@unittest.skipUnless(ANE, "Requires ANE")
+class TestGCANE(TestGC):
+  device=Device.ANE
 
 if __name__ == '__main__':
   unittest.main()

@@ -2,10 +2,11 @@
 import os
 import unittest
 import numpy as np
-from tinygrad.tensor import Tensor, GPU
+from tinygrad.tensor import Tensor, GPU, Device
 import tinygrad.optim as optim
 from extra.training import train, evaluate
 from extra.utils import fetch, get_parameters
+from .config import ANE
 
 # mnist loader
 def fetch_mnist():
@@ -55,32 +56,36 @@ class TinyConvNet:
     return x.dot(self.l1).logsoftmax()
 
 class TestMNIST(unittest.TestCase):
-  gpu=False
+  device = Device.CPU
 
   def test_conv(self):
    np.random.seed(1337)
    model = TinyConvNet()
    optimizer = optim.Adam(model.parameters(), lr=0.001)
-    train(model, X_train, Y_train, optimizer, steps=200, gpu=self.gpu)
-    assert evaluate(model, X_test, Y_test, gpu=self.gpu) > 0.95
+    train(model, X_train, Y_train, optimizer, steps=200, device=self.device)
+    assert evaluate(model, X_test, Y_test, device=self.device) > 0.95
 
   def test_sgd(self):
    np.random.seed(1337)
    model = TinyBobNet()
    optimizer = optim.SGD(model.parameters(), lr=0.001)
-    train(model, X_train, Y_train, optimizer, steps=1000, gpu=self.gpu)
-    assert evaluate(model, X_test, Y_test, gpu=self.gpu) > 0.95
+    train(model, X_train, Y_train, optimizer, steps=1000, device=self.device)
+    assert evaluate(model, X_test, Y_test, device=self.device) > 0.95
 
   def test_rmsprop(self):
    np.random.seed(1337)
    model = TinyBobNet()
    optimizer = optim.RMSprop(model.parameters(), lr=0.0002)
-    train(model, X_train, Y_train, optimizer, steps=1000, gpu=self.gpu)
-    assert evaluate(model, X_test, Y_test, gpu=self.gpu) > 0.95
+    train(model, X_train, Y_train, optimizer, steps=1000, device=self.device)
+    assert evaluate(model, X_test, Y_test, device=self.device) > 0.95
 
 @unittest.skipUnless(GPU, "Requires GPU")
 class TestMNISTGPU(TestMNIST):
-  gpu = True
+  device = Device.GPU
+
+@unittest.skipUnless(ANE, "Requires ANE")
+class TestMNISTANE(TestMNIST):
+  device=Device.ANE
 
 if __name__ == '__main__':
   unittest.main()

@@ -4,7 +4,8 @@ import cProfile
 import pstats
 import unittest
 import torch
-from tinygrad.tensor import Tensor
+from tinygrad.tensor import Tensor, GPU, Device
+from .config import ANE
 
 def start_profile():
   import time
@@ -20,6 +21,8 @@ def stop_profile(pr, sort='cumtime'):
   ps.print_stats(0.2)
 
 class TestConvSpeed(unittest.TestCase):
+  device= Device.CPU
+
   def test_mnist(self):
     # https://keras.io/examples/vision/mnist_convnet/
     conv = 3
@@ -62,15 +65,15 @@ class TestConvSpeed(unittest.TestCase):
 
     # ****** tinygrad compare *******
 
-    c1 = Tensor(c1.detach().numpy())
-    c2 = Tensor(c2.detach().numpy())
-    l1 = Tensor(l1.detach().numpy())
+    c1 = Tensor(c1.detach().numpy(), device=self.device)
+    c2 = Tensor(c2.detach().numpy(), device=self.device)
+    l1 = Tensor(l1.detach().numpy(), device=self.device)
 
     cnt = 5
     fpt, bpt = 0.0, 0.0
     for i in range(1+cnt):
       et0 = time.time()
-      x = Tensor.randn(128, 1, 28, 28)
+      x = Tensor.randn(128, 1, 28, 28, device=self.device)
       x = x.conv2d(c1).relu().avg_pool2d()
       x = x.conv2d(c2).relu().max_pool2d()
       x = x.reshape(shape=(x.shape[0], -1))
@@ -91,6 +94,14 @@ class TestConvSpeed(unittest.TestCase):
     print("forward pass: %.3f ms, %.2fx off baseline %.3f ms" % (fpt, fpt/fpt_baseline, fpt_baseline))
     print("backward pass: %.3f ms, %.2fx off baseline %.3f ms" % (bpt, bpt/bpt_baseline, bpt_baseline))
 
+@unittest.skipUnless(GPU, "Requires GPU")
+class TestConvSpeedGPU(TestConvSpeed):
+  device = Device.GPU
+
+@unittest.skipUnless(ANE, "Requires ANE")
+class TestConvSpeedANE(TestConvSpeed):
+  device=Device.ANE
+
 
 if __name__ == '__main__':
   unittest.main()

@@ -1,10 +1,15 @@
 #!/usr/bin/env python
 import unittest
 import numpy as np
+from tinygrad.tensor import GPU, Device
 from tinygrad.nn import *
 from extra.utils import get_parameters
 import torch
+from .config import ANE
+
 
 class TestNN(unittest.TestCase):
+  device = Device.CPU
+
   def test_batchnorm2d(self, training=False):
     sz = 4
@@ -29,13 +34,13 @@
     np.testing.assert_allclose(bn.running_var.data, tbn.running_var.detach().numpy(), rtol=1e-5)
 
     # trial
-    inn = Tensor.randn(2, sz, 3, 3)
+    inn = Tensor.randn(2, sz, 3, 3, device=self.device)
 
     # in tinygrad
     outt = bn(inn)
 
     # in torch
-    toutt = tbn(torch.tensor(inn.data))
+    toutt = tbn(torch.tensor(inn.cpu().data))
 
     # close
     np.testing.assert_allclose(outt.data, toutt.detach().numpy(), rtol=5e-5)
@@ -48,6 +53,27 @@ class TestNN(unittest.TestCase):
   def test_batchnorm2d_training(self):
     self.test_batchnorm2d(True)
 
+@unittest.skipUnless(GPU, "Requires GPU")
+class TestNNGPU(TestNN):
+  device = Device.GPU
+
+  @unittest.skip("Tests not added")
+  def test_batchnorm2d(self): pass
+
+  @unittest.skip("Tests not added")
+  def test_batchnorm2d_training(self): pass
+
+
+@unittest.skipUnless(ANE, "Requires ANE")
+class TestNNANE(TestNN):
+  device=Device.ANE
+
+  @unittest.skip("Tests not added")
+  def test_batchnorm2d(self): pass
+
+  @unittest.skip("Tests not added")
+  def test_batchnorm2d_training(self): pass
+
 
 if __name__ == '__main__':
   unittest.main()

@@ -4,14 +4,17 @@ import numpy as np
 import unittest
 import timeit
 import functools
-from tinygrad.tensor import Tensor, GPU
+from tinygrad.tensor import Tensor, GPU, Device
+from .config import ANE
 
-def helper_test_op(shps, torch_fxn, tinygrad_fxn, atol=0, rtol=1e-6, grad_atol=0, grad_rtol=1e-6, gpu=False, forward_only=False):
+def helper_test_op(shps, torch_fxn, tinygrad_fxn, atol=0, rtol=1e-6, grad_atol=0, grad_rtol=1e-6, device=Device.CPU, forward_only=False):
   torch.manual_seed(0)
   ts = [torch.rand(x, requires_grad=True) for x in shps]
   tst = [Tensor(x.detach().numpy()) for x in ts]
-  if gpu:
-    tst = [x.cuda() for x in tst]
+  if device==Device.GPU:
+    tst = [x.gpu() for x in tst]
+  elif device==Device.ANE:
+    tst = [x.ane() for x in tst]
 
   out = torch_fxn(*ts)
   ret = tinygrad_fxn(*tst)
@@ -23,7 +26,7 @@ def helper_test_op(shps, torch_fxn, tinygrad_fxn, atol=0, rtol=1e-6, grad_atol=0
     ret.mean().backward()
 
     for t, tt in zip(ts, tst):
-      np.testing.assert_allclose(t.grad, tt.grad.cpu().data, atol=grad_atol, rtol=grad_rtol)
+      np.testing.assert_allclose(t.grad, tt.cpu().grad.data, atol=grad_atol, rtol=grad_rtol)
 
   # speed
   torch_fp = timeit.Timer(functools.partial(torch_fxn, *ts)).timeit(5) * 1000/5
@@ -38,58 +41,59 @@ def helper_test_op(shps, torch_fxn, tinygrad_fxn, atol=0, rtol=1e-6, grad_atol=0
   print("testing %30r torch/tinygrad fp: %.2f / %.2f ms bp: %.2f / %.2f ms" % (shps, torch_fp, tinygrad_fp, torch_fbp-torch_fp, tinygrad_fbp-tinygrad_fp))
 
 class TestOps(unittest.TestCase):
-  gpu = False
+  device=Device.CPU
 
   def test_add(self):
-    helper_test_op([(45,65), (45,65)], lambda x,y: x+y, Tensor.add, gpu=self.gpu)
+    helper_test_op([(45,65), (45,65)], lambda x,y: x+y, Tensor.add, device=self.device)
   def test_sub(self):
-    helper_test_op([(45,65), (45,65)], lambda x,y: x-y, Tensor.sub, gpu=self.gpu)
+    helper_test_op([(45,65), (45,65)], lambda x,y: x-y, Tensor.sub, device=self.device)
   def test_mul(self):
-    helper_test_op([(45,65), (45,65)], lambda x,y: x*y, Tensor.mul, gpu=self.gpu)
+    helper_test_op([(45,65), (45,65)], lambda x,y: x*y, Tensor.mul, device=self.device)
   def test_div(self):
-    helper_test_op([(45,65), (45,65)], lambda x,y: x/y, Tensor.div, gpu=self.gpu)
+    helper_test_op([(45,65), (45,65)], lambda x,y: x/y, Tensor.div, device=self.device)
   def test_pow(self):
-    helper_test_op([(45,65), (45,65)], lambda x,y: x**y, Tensor.pow, gpu=self.gpu)
+    helper_test_op([(45,65), (45,65)], lambda x,y: x**y, Tensor.pow, device=self.device)
   def test_sqrt(self):
-    helper_test_op([(45,65)], lambda x: x.sqrt(), Tensor.sqrt, gpu=self.gpu)
+    helper_test_op([(45,65)], lambda x: x.sqrt(), Tensor.sqrt, device=self.device)
   def test_relu(self):
-    helper_test_op([(45,65)], lambda x: x.relu(), Tensor.relu, gpu=self.gpu)
+    helper_test_op([(45,65)], lambda x: x.relu(), Tensor.relu, device=self.device)
   def test_leakyrelu(self):
-    helper_test_op([(45,65)], lambda x: torch.nn.functional.leaky_relu(x,0.01), Tensor.leakyrelu, gpu=self.gpu)
+    helper_test_op([(45,65)], lambda x: torch.nn.functional.leaky_relu(x,0.01), Tensor.leakyrelu, device=self.device)
  def test_abs(self):
-    helper_test_op([(45,65)], lambda x: torch.abs(x), Tensor.abs, gpu=self.gpu)
+    helper_test_op([(45,65)], lambda x: torch.abs(x), Tensor.abs, device=self.device)
  def test_sigmoid(self):
-    helper_test_op([(45,65)], lambda x: x.sigmoid(), Tensor.sigmoid, gpu=self.gpu)
+    helper_test_op([(45,65)], lambda x: x.sigmoid(), Tensor.sigmoid, device=self.device)
  def test_dot(self):
-    helper_test_op([(45,65), (65,100)], lambda x,y: x.matmul(y), Tensor.dot, gpu=self.gpu)
+    helper_test_op([(45,65), (65,100)], lambda x,y: x.matmul(y), Tensor.dot, device=self.device)
  def test_sum(self):
-    helper_test_op([(45,3)], lambda x: x.sum(), Tensor.sum, gpu=self.gpu)
+    helper_test_op([(45,3)], lambda x: x.sum(), Tensor.sum, device=self.device)
  def test_sum_axis(self):
-    helper_test_op([(3,4,5,6)], lambda x: x.sum(axis=(1,2)), lambda x: Tensor.sum(x, axis=(1,2)), gpu=self.gpu)
+    helper_test_op([(3,4,5,6)], lambda x: x.sum(axis=(1,2)), lambda x: Tensor.sum(x, axis=(1,2)), device=self.device)
  def test_mean_axis(self):
-    helper_test_op([(3,4,5,6)], lambda x: x.mean(axis=(1,2)), lambda x: Tensor.mean(x, axis=(1,2)), gpu=self.gpu)
+    helper_test_op([(3,4,5,6)], lambda x: x.mean(axis=(1,2)), lambda x: Tensor.mean(x, axis=(1,2)), device=self.device)
  def test_logsoftmax(self):
-    helper_test_op([(45,65)], lambda x: torch.nn.LogSoftmax(dim=1)(x), Tensor.logsoftmax, atol=1e-7, grad_atol=1e-7, gpu=self.gpu)
+    helper_test_op([(45,65)], lambda x: torch.nn.LogSoftmax(dim=1)(x), Tensor.logsoftmax, atol=1e-7, grad_atol=1e-7, device=self.device)
  def test_tanh(self):
-    helper_test_op([(45,65)], lambda x: x.tanh(), Tensor.tanh, atol=1e-6, grad_atol=1e-6, gpu=self.gpu)
+    helper_test_op([(45,65)], lambda x: x.tanh(), Tensor.tanh, atol=1e-6, grad_atol=1e-6, device=self.device)
  def test_topo_sort(self):
-    helper_test_op([(45,65)], lambda x: (x+x)*x, lambda x: x.add(x).mul(x), atol=1e-6, grad_atol=1e-6, gpu=self.gpu)
+    helper_test_op([(45,65)], lambda x: (x+x)*x, lambda x: x.add(x).mul(x), atol=1e-6, grad_atol=1e-6, device=self.device)
+
  def test_scalar_mul(self):
-    helper_test_op([(45,65)], lambda x: x*2, lambda x: x*2, gpu=self.gpu)
+    helper_test_op([(45,65)], lambda x: x*2, lambda x: x*2, device=self.device)
  def test_scalar_rmul(self):
-    helper_test_op([(45,65)], lambda x: 2*x, lambda x: 2*x, gpu=self.gpu)
+    helper_test_op([(45,65)], lambda x: 2*x, lambda x: 2*x, device=self.device)
 
  def test_scalar_sub(self):
-    helper_test_op([(45,65)], lambda x: x-2, lambda x: x-2, gpu=self.gpu)
+    helper_test_op([(45,65)], lambda x: x-2, lambda x: x-2, device=self.device)
  def test_scalar_rsub(self):
-    helper_test_op([(45,65)], lambda x: 2-x, lambda x: 2-x, gpu=self.gpu)
+    helper_test_op([(45,65)], lambda x: 2-x, lambda x: 2-x, device=self.device)
 
  def test_broadcast_full(self):
    for torch_op, tinygrad_op in [(torch.add, Tensor.add), (torch.sub, Tensor.sub), (torch.mul, Tensor.mul),
                                  (torch.div, Tensor.div), (torch.pow, Tensor.pow)]:
      for shapes in [((5,13,24,16), (5,1,24,1)), ((1,3,1,7,1), (2,1,5,1,8))]:
        with self.subTest(op=torch_op.__name__, shapes=shapes):
-          helper_test_op(shapes, torch_op, tinygrad_op, gpu=self.gpu)
+          helper_test_op(shapes, torch_op, tinygrad_op, device=self.device)
 
 
  def test_broadcast_partial(self):
@@ -98,17 +102,18 @@ class TestOps(unittest.TestCase):
      for shapes in [((1,32,32,32), (1,32,1,1)), ((5,13,24,16,2), (1,13,24,1,1)),
                     ((4,1), (4,5)), ((1,4), (5,4))]:
        with self.subTest(op=torch_op.__name__, shapes=shapes):
-          helper_test_op(shapes, torch_op, tinygrad_op, gpu=self.gpu, forward_only=self.gpu)
+          # NOTE: ANE backwards?
+          helper_test_op(shapes, torch_op, tinygrad_op, device=self.device, forward_only=self.device!=Device.CPU)
 
  def test_pad2d(self):
-    helper_test_op([(3,3,3,3)], lambda x: torch.nn.functional.pad(x, (1,2,3,4)), lambda x: x.pad2d(padding=(1,2,3,4)), gpu=self.gpu)
+    helper_test_op([(3,3,3,3)], lambda x: torch.nn.functional.pad(x, (1,2,3,4)), lambda x: x.pad2d(padding=(1,2,3,4)), device=self.device)
 
  def test_reshape(self):
-    helper_test_op([(4,3,6,6)], lambda x: torch.reshape(x, (-1,3,6,6)), lambda x: x.reshape(shape=(-1,3,6,6)), gpu=self.gpu)
-    helper_test_op([(4,3,6,6)], lambda x: torch.reshape(x, (-1,1,6,6)), lambda x: x.reshape(shape=(-1,1,6,6)), gpu=self.gpu)
+    helper_test_op([(4,3,6,6)], lambda x: torch.reshape(x, (-1,3,6,6)), lambda x: x.reshape(shape=(-1,3,6,6)), device=self.device)
+    helper_test_op([(4,3,6,6)], lambda x: torch.reshape(x, (-1,1,6,6)), lambda x: x.reshape(shape=(-1,1,6,6)), device=self.device)
 
  def test_detach(self):
-    helper_test_op([(4,3,6,6)], lambda x: x.detach(), lambda x: x.detach(), gpu=self.gpu, forward_only=True)
+    helper_test_op([(4,3,6,6)], lambda x: x.detach(), lambda x: x.detach(), device=self.device, forward_only=True)
 
  def test_conv2d(self):
    for bs in [1,8]:
@@ -119,7 +124,7 @@ class TestOps(unittest.TestCase):
        with self.subTest(batch_size=bs, channels=cin, groups=groups, height=H, width=W):
          helper_test_op([(bs,cin,11,28), (6,cin//groups,H,W)],
            lambda x,w: torch.nn.functional.conv2d(x,w,groups=groups).relu(),
-            lambda x,w: Tensor.conv2d(x,w,groups=groups).relu(), gpu=self.gpu, grad_rtol=1e-5)
+            lambda x,w: Tensor.conv2d(x,w,groups=groups).relu(), device=self.device, grad_rtol=1e-5)
 
  def test_strided_conv2d(self):
    bs = 4
@@ -128,18 +133,18 @@ class TestOps(unittest.TestCase):
    with self.subTest(stride := 2):
      helper_test_op([(bs,cin,11,28), (4,cin,H,W)],
        lambda x,w: torch.nn.functional.conv2d(x,w,stride=2).relu(),
-        lambda x,w: Tensor.conv2d(x,w,stride=stride).relu(), gpu=self.gpu)
+        lambda x,w: Tensor.conv2d(x,w,stride=stride).relu(), device=self.device)
    with self.subTest(stride := (2,1)):
      helper_test_op([(bs,cin,11,28), (4,cin,H,W)],
        lambda x,w: torch.nn.functional.conv2d(x,w,stride=stride).relu(),
-        lambda x,w: Tensor.conv2d(x,w,stride=(2,1)).relu(), gpu=self.gpu)
+        lambda x,w: Tensor.conv2d(x,w,stride=(2,1)).relu(), device=self.device)
 
  def test_maxpool2d(self):
    for ksz in [(2,2), (3,3), (3,2), (5,5), (5,1)]:
      with self.subTest(kernel_size=ksz):
        helper_test_op([(32,2,110,28)],
          lambda x: torch.nn.functional.max_pool2d(x, kernel_size=ksz),
-          lambda x: Tensor.max_pool2d(x, kernel_size=ksz), gpu=self.gpu)
+          lambda x: Tensor.max_pool2d(x, kernel_size=ksz), device=self.device)
 
  def test_avgpool2d(self):
    shape = (32,2,111,28)
@@ -147,11 +152,15 @@ class TestOps(unittest.TestCase):
      with self.subTest(kernel_size=ksz):
        helper_test_op([shape],
          lambda x: torch.nn.functional.avg_pool2d(x, kernel_size=ksz),
-          lambda x: Tensor.avg_pool2d(x, kernel_size=ksz), gpu=self.gpu)
+          lambda x: Tensor.avg_pool2d(x, kernel_size=ksz), device=self.device)
 
 @unittest.skipUnless(GPU, "Requires GPU")
 class TestOpsGPU(TestOps):
-  gpu = True
+  device=Device.GPU
 
+@unittest.skipUnless(ANE, "Requires ANE")
+class TestOpsANE(TestOps):
+  device=Device.ANE
+
 if __name__ == '__main__':
   unittest.main(verbosity=2)

@@ -1,18 +1,20 @@
 import numpy as np
 import torch
 import unittest
-from tinygrad.tensor import Tensor, GPU
+from tinygrad.tensor import Tensor, GPU, Device
 from tinygrad.optim import Adam, SGD, RMSprop
 from extra.utils import get_parameters
+from .config import ANE
 
 x_init = np.random.randn(1,3).astype(np.float32)
 W_init = np.random.randn(3,3).astype(np.float32)
 m_init = np.random.randn(1,3).astype(np.float32)
 
-def step_tinygrad(optim, kwargs={}, gpu=False):
+def step_tinygrad(optim, kwargs={}, device=Device.CPU):
   net = TinyNet()
   optim = optim([net.x, net.W], **kwargs)
-  if gpu is True: [x.cuda_() for x in get_parameters([net, optim])]
+  if device==Device.GPU: [x.gpu_() for x in get_parameters([net, optim])]
+  elif device==Device.ANE: [x.ane_() for x in get_parameters([net, optim])]
   out = net.forward()
   out.backward()
   optim.step()
@@ -54,20 +56,20 @@ class TorchNet():
 
 
 class TestOptim(unittest.TestCase):
-  gpu = False
+  device = Device.CPU
 
   def test_adam(self):
-    for x,y in zip(step_tinygrad(Adam, gpu=self.gpu),
+    for x,y in zip(step_tinygrad(Adam, device=self.device),
                    step_pytorch(torch.optim.Adam)):
       np.testing.assert_allclose(x, y, atol=1e-4)
 
   def test_sgd(self):
-    for x,y in zip(step_tinygrad(SGD, kwargs={'lr': 0.001}, gpu=self.gpu),
+    for x,y in zip(step_tinygrad(SGD, kwargs={'lr': 0.001}, device=self.device),
                    step_pytorch(torch.optim.SGD, kwargs={'lr': 0.001})):
       np.testing.assert_allclose(x, y, atol=1e-5)
 
   def test_rmsprop(self):
-    for x,y in zip(step_tinygrad(RMSprop, kwargs={'lr': 0.001, 'decay': 0.99}, gpu=self.gpu),
+    for x,y in zip(step_tinygrad(RMSprop, kwargs={'lr': 0.001, 'decay': 0.99}, device=self.device),
                    step_pytorch(torch.optim.RMSprop,
                                 kwargs={'lr': 0.001, 'alpha': 0.99})):
       np.testing.assert_allclose(x, y, atol=1e-5)
@@ -75,7 +77,11 @@ class TestOptim(unittest.TestCase):
 
 @unittest.skipUnless(GPU, "Requires GPU")
 class TestOptimGPU(TestOptim):
-  gpu = True
+  device = Device.GPU
 
+@unittest.skipUnless(ANE, "Requires ANE")
+class TestOptimANE(TestOptim):
+  device = Device.ANE
+
 
 if __name__ == '__main__':

@@ -1,8 +1,10 @@
 import numpy as np
 import torch
 import unittest
-from tinygrad.tensor import Tensor, GPU
+from tinygrad.tensor import Tensor, GPU, Device
 from extra.gradcheck import numerical_jacobian, jacobian, gradcheck
+from .config import ANE
+
 
 x_init = np.random.randn(1,3).astype(np.float32)
 U_init = np.random.randn(3,3).astype(np.float32)
@@ -11,13 +13,13 @@ W_init = np.random.randn(3,3).astype(np.float32)
 m_init = np.random.randn(1,3).astype(np.float32)
 
 class TestTinygrad(unittest.TestCase):
-  gpu = False
+  device = Device.CPU
 
   def test_backward_pass(self):
     def test_tinygrad():
-      x = Tensor(x_init, gpu=self.gpu)
-      W = Tensor(W_init, gpu=self.gpu)
-      m = Tensor(m_init, gpu=self.gpu)
+      x = Tensor(x_init, device=self.device)
+      W = Tensor(W_init, device=self.device)
+      m = Tensor(m_init, device=self.device)
       out = x.dot(W).relu()
       out = out.logsoftmax()
       out = out.mul(m).add(m).sum()
@@ -39,16 +41,16 @@ class TestTinygrad(unittest.TestCase):
 
   def test_backward_pass_diamond_model(self):
     def test_tinygrad():
-      u = Tensor(U_init)
-      v = Tensor(V_init)
-      w = Tensor(W_init)
+      u = Tensor(U_init, device=self.device)
+      v = Tensor(V_init, device=self.device)
+      w = Tensor(W_init, device=self.device)
       x = u.mul(v).relu()
       y = u.mul(w).relu()
       out = x.add(y).mul(y).relu()
       out = out.logsoftmax()
       out = out.sum()
       out.backward()
-      return out.data, u.grad.data, v.grad.data, w.grad.data
+      return out.cpu().data, u.cpu().grad.data, v.cpu().grad.data, w.cpu().grad.data
 
     def test_pytorch():
       u = torch.tensor(U_init, requires_grad=True)
@@ -74,8 +76,8 @@ class TestTinygrad(unittest.TestCase):
     torch_func = lambda x: torch.nn.functional.log_softmax(x.matmul(torch_W).relu(), dim=1)
     PJ = torch.autograd.functional.jacobian(torch_func, torch_x).squeeze().numpy()
 
-    tiny_x = Tensor(x, gpu=self.gpu)
-    tiny_W = Tensor(W, gpu=self.gpu)
+    tiny_x = Tensor(x, device=self.device)
+    tiny_W = Tensor(W, device=self.device)
     tiny_func = lambda x: x.dot(tiny_W).relu().logsoftmax()
     J = jacobian(tiny_func, tiny_x)
     NJ = numerical_jacobian(tiny_func, tiny_x)
@@ -87,8 +89,8 @@ class TestTinygrad(unittest.TestCase):
     W = np.random.RandomState(1337).random((10, 5))
     x = np.random.RandomState(7331).random((1, 10)) - 0.5
 
-    tiny_x = Tensor(x, gpu=self.gpu)
-    tiny_W = Tensor(W, gpu=self.gpu)
+    tiny_x = Tensor(x, device=self.device)
+    tiny_W = Tensor(W, device=self.device)
     tiny_func = lambda x: x.dot(tiny_W).relu().logsoftmax()
 
     self.assertTrue(gradcheck(tiny_func, tiny_x))
@@ -99,7 +101,7 @@ class TestTinygrad(unittest.TestCase):
 
 @unittest.skipUnless(GPU, "Requires GPU")
 class TestTinygradGPU(TestTinygrad):
-  gpu = True
+  device = Device.GPU
 
   @unittest.skip("float64 not supported on GPU")
   def test_jacobian(self): pass
@@ -107,6 +109,9 @@ class TestTinygradGPU(TestTinygrad):
   @unittest.skip("float64 not supported on GPU")
   def test_gradcheck(self): pass
 
+@unittest.skipUnless(ANE, "Requires ANE")
+class TestOpsANE(TestTinygrad):
+  device=Device.ANE
 
 if __name__ == '__main__':
   unittest.main()

@@ -1,5 +1,5 @@
 import numpy as np
-from .tensor import Function, register, GPUBuffer, Tensor
+from .tensor import Function, register, GPUBuffer, Tensor, Device
 import pyopencl as cl
 import functools
 
@@ -178,7 +178,7 @@ class Add(Function):
     grad_x, grad_y = grad_output, grad_output
     shape_x, shape_y = ctx.saved_tensors
     return unbroadcast(ctx, grad_x, shape_x), unbroadcast(ctx, grad_y, shape_y),
-register('add', Add, device=Tensor.GPU)
+register('add', Add, device=Device.GPU)
 
 class Sub(Function):
   @staticmethod
@@ -191,7 +191,7 @@ class Sub(Function):
     grad_x, grad_y = grad_output, unary_op(ctx, '-a', grad_output)
     shape_x, shape_y = ctx.saved_tensors
     return unbroadcast(ctx, grad_x, shape_x), unbroadcast(ctx, grad_y, shape_y),
-register('sub', Sub, device=Tensor.GPU)
+register('sub', Sub, device=Device.GPU)
 
 class Mul(Function):
   @staticmethod
@@ -205,7 +205,7 @@ class Mul(Function):
     grad_x = binary_op(ctx, 'a*b', y, grad_output)
     grad_y = binary_op(ctx, 'a*b', x, grad_output)
     return unbroadcast(ctx, grad_x, x.shape), unbroadcast(ctx, grad_y, y.shape),
-register('mul', Mul, device=Tensor.GPU)
+register('mul', Mul, device=Device.GPU)
 
 class Pow(Function):
   @staticmethod
@@ -221,7 +221,7 @@ class Pow(Function):
     grad_y = binary_op(ctx, 'a*b', grad_output,
                        binary_op(ctx, 'pow(a, (float)b) * log(a);', x, y))
     return unbroadcast(ctx, grad_x, x.shape), unbroadcast(ctx, grad_y, y.shape),
-register('pow', Pow, device=Tensor.GPU)
+register('pow', Pow, device=Device.GPU)
 
 class Sum(Function):
   @staticmethod
@@ -237,8 +237,8 @@ class Sum(Function):
     input, axis = ctx.saved_tensors
     shape = [1 if axis is None or i in axis else input.shape[i] for i in range(len(input.shape))]
     output = GPUBuffer(shape, hostbuf=grad_output)
-    return binary_op(ctx, 'a+b', output, buffer_new(ctx, input.shape))
-register('sum', Sum, device=Tensor.GPU)
+    return binary_op(ctx, 'a+b', output, buffer_new(ctx, input.shape, zero=True))
+register('sum', Sum, device=Device.GPU)
 
 class Dot(Function):
   @staticmethod
@@ -289,7 +289,7 @@ class Dot(Function):
       i32(1), msize, isize, i32(1), osize, osize)
 
     return grad_input, grad_weight
-register('dot', Dot, device=Tensor.GPU)
+register('dot', Dot, device=Device.GPU)
 
 # ************* simple ops *************
 
@@ -332,7 +332,7 @@ class Pad2D(Function):
       i32(oy), i32(ox), i32(iy), i32(ix)
     )
     return ret
-register('pad2d', Pad2D, device=Tensor.GPU)
+register('pad2d', Pad2D, device=Device.GPU)
 
 class Reshape(Function):
   @staticmethod
@@ -347,7 +347,7 @@ class Reshape(Function):
   def backward(ctx, grad_output):
     in_shape, = ctx.saved_tensors
     return GPUBuffer(in_shape, hostbuf=grad_output)
-register('reshape', Reshape, device=Tensor.GPU)
+register('reshape', Reshape, device=Device.GPU)
 
 # ************* activation ops *************
 
@@ -361,7 +361,7 @@ class ReLU(Function):
   def backward(ctx, grad_output):
     input, = ctx.saved_tensors
     return binary_op(ctx, 'a * (b >= 0)', grad_output, input)
-register('relu', ReLU, device=Tensor.GPU)
+register('relu', ReLU, device=Device.GPU)
 
 class Sigmoid(Function):
   @staticmethod
@@ -374,7 +374,7 @@ class Sigmoid(Function):
   def backward(ctx, grad_output):
     ret, = ctx.saved_tensors
     return binary_op(ctx, 'a * (b * (1 - b));', grad_output, ret)
-register('sigmoid', Sigmoid, device=Tensor.GPU)
+register('sigmoid', Sigmoid, device=Device.GPU)
 
 class AvgPool2D(Function):
   @staticmethod
@@ -389,7 +389,7 @@ class AvgPool2D(Function):
     orig_shape, = ctx.saved_tensors
     return supersample_op(ctx, grad_output, orig_shape, ctx.kernel_size,
                           result_op="input[iid] / (ksz.x * ksz.y)")
-register('avg_pool2d', AvgPool2D, device=Tensor.GPU)
+register('avg_pool2d', AvgPool2D, device=Device.GPU)
 
 class MaxPool2D(Function):
   @staticmethod
@@ -409,7 +409,7 @@ class MaxPool2D(Function):
       result_op="(maxidx == kernidx) * input[iid]",
       decls="int maxidx=((__global float*)input2)[iid]; int kernidx=(gid.x%ksz.x) + ksz.x*(gid.y%ksz.y)",
       input2=idxs)
-register('max_pool2d', MaxPool2D, device=Tensor.GPU)
+register('max_pool2d', MaxPool2D, device=Device.GPU)
 
 class LogSoftmax(Function):
   @staticmethod
@@ -426,7 +426,7 @@ class LogSoftmax(Function):
     lsum = reduce_op(ctx, "out += a", "out", grad_output, axis=[1])
     texp = binary_op(ctx, "exp(a) * b", output, lsum)
     return binary_op(ctx, "a - b", grad_output, texp)
-register('logsoftmax', LogSoftmax, device=Tensor.GPU)
+register('logsoftmax', LogSoftmax, device=Device.GPU)
 
 # ************* conv ops *************
 
@@ -553,4 +553,4 @@ class Conv2D(Function):
     convw(ctx.cl_queue, [ctx.groups*rcout*cin, H, W], None, x.cl, grad_output.cl, dw.cl, *conv_args)
     convx(ctx.cl_queue, [bs, ctx.groups, cin], None, w.cl, grad_output.cl, dx.cl, *conv_args)
     return dx, dw
-register('conv2d', Conv2D, device=Tensor.GPU)
+register('conv2d', Conv2D, device=Device.GPU)

@@ -25,7 +25,7 @@ class RMSprop(Optimizer):
     super(RMSprop, self).__init__(params)
     self.lr, self.decay, self.eps = lr, decay, eps
 
-    self.v = [Tensor(np.zeros(t.shape, dtype=np.float32), gpu=params[0].gpu, requires_grad=False) for t in self.params]
+    self.v = [Tensor(np.zeros(t.shape, dtype=np.float32), device=params[0].device, requires_grad=False) for t in self.params]
 
   def step(self):
     for i, t in enumerate(self.params):
@@ -37,8 +37,8 @@ class Adam(Optimizer):
     super(Adam, self).__init__(params)
     self.lr, self.b1, self.b2, self.eps, self.t = lr, b1, b2, eps, 0
 
-    self.m = [Tensor(np.zeros(t.shape, dtype=np.float32), gpu=params[0].gpu, requires_grad=False) for t in self.params]
-    self.v = [Tensor(np.zeros(t.shape, dtype=np.float32), gpu=params[0].gpu, requires_grad=False) for t in self.params]
+    self.m = [Tensor(np.zeros(t.shape, dtype=np.float32), device=params[0].device, requires_grad=False) for t in self.params]
+    self.v = [Tensor(np.zeros(t.shape, dtype=np.float32), device=params[0].device, requires_grad=False) for t in self.params]
 
   def step(self):
     self.t = self.t + 1

@@ -1,5 +1,6 @@
 # inspired by https://github.com/karpathy/micrograd/blob/master/micrograd/engine.py
 from inspect import signature
+import functools
 import numpy as np
 import os
 from collections import defaultdict
@@ -34,6 +35,7 @@ class ProfileOp:
 
 cl_ctx, cl_queue = None, None
 def require_init_gpu():
+  if not GPU: raise Exception("No GPU Support, install pyopencl")
   global cl_ctx, cl_queue
   if cl_queue is None:
     devices = cl.get_platforms()[0].get_devices(device_type=cl.device_type.GPU)
@@ -64,33 +66,16 @@ def require_init_ane():
 
 # **** start with two base classes, Tensor and Function ****
 
+class Device: CPU, GPU, ANE = 0, 1, 2
+
 class Tensor:
   did_float_warning = False
   ops = defaultdict(dict)
 
-  CPU, GPU, ANE = 0, 1, 2
-
-  def __init__(self, data, gpu=None, requires_grad=True):
-    if "ANETensor" in str(type(data)):
-      self.device = Tensor.ANE
-    elif isinstance(data, list):
-      data = np.array(data, dtype=np.float32)
-    elif GPU and isinstance(data, GPUBuffer):
-      self.device = Tensor.GPU
-    elif not isinstance(data, np.ndarray):
-      raise TypeError(f"Error constructing tensor with {data!r}")
-
-    if isinstance(data, np.ndarray):
-      if data.dtype != np.float32 and not Tensor.did_float_warning:
-        # warning? float64 is actually needed for numerical jacobian
-        print(f"warning, {data.shape!r} isn't float32")
-        Tensor.did_float_warning = True
-      self.device = Tensor.CPU
-
-    self.data, self.grad, self.requires_grad = data, None, requires_grad
-
-    if gpu:
-      self.cuda_()
+  def __init__(self, data, device=Device.CPU, requires_grad=True):
+    self.data = self._move_data(data, device)
+
+    self.device, self.grad, self.requires_grad = device, None, requires_grad
 
     # internal variables used for autograd graph construction
     self._ctx = None
@@ -145,7 +130,7 @@ class Tensor:
 
     # fill in the first grad with one
     # this is "implicit gradient creation"
-    self.grad = Tensor(np.ones(self.shape, dtype=self.dtype), gpu=self.gpu, requires_grad=False)
+    self.grad = Tensor(np.ones(self.shape, dtype=self.dtype), device=self.device, requires_grad=False)
 
     for t0 in reversed(self.deepwalk(set(), [])):
       assert (t0.grad is not None)
@@ -157,53 +142,59 @@ class Tensor:
       if g is not None:
         assert g.shape == t.shape, \
           f"grad shape must match tensor shape in {self._ctx!r}, {g.shape!r} != {t.shape!r}"
-        gt = Tensor(g, requires_grad=False)
+        gt = Tensor(g, device=self.device, requires_grad=False)
         t.grad = gt if t.grad is None else (t.grad + gt)
 
   # ***** tinygrad supports CPU and GPU *****
 
-  def cpu(self):
-    if self.device == Tensor.GPU:
-      with ProfileOp("toCPU", [self]):
-        ret = Tensor(np.empty(self.shape, dtype=np.float32), gpu=False)
-        cl.enqueue_copy(cl_queue, ret.data, self.data.cl, is_blocking=True)
-        if self.grad:
-          ret.grad = self.grad.cpu()
-        return ret
-    elif self.device == Tensor.ANE:
-      return Tensor(self.data.data().astype(np.float32), gpu=False)
-    else:
-      return self
-
-  @property
-  def gpu(self):
-    return self.device == Tensor.GPU
-
-  def cuda_(self):
-    self.data = self.cuda().data
-    self.device = Tensor.GPU
-
-  def cuda(self):
-    if not GPU:
-      raise Exception("No GPU Support, install pyopencl")
-    if not self.gpu:
-      with ProfileOp("toGPU", [self]):
-        require_init_gpu()
-        ret = Tensor(GPUBuffer(self.shape, self.data))
-        if self.grad:
-          ret.grad = self.grad.cuda()
-        return ret
-    return self
-
-  def ane(self):
-    assert(not self.gpu)
-    require_init_ane()
-    ndata = ane.tensor(self.shape)
-    ndata.data()[:] = self.data
-    return Tensor(ndata)
+  @staticmethod
+  def _move_data(data, device):
+    if isinstance(data, GPUBuffer):
+      if device == Device.GPU: return data
+      old = data
+      data = np.empty(old.shape, dtype=np.float32)
+      with ProfileOp("toCPU", [data]):
+        cl.enqueue_copy(cl_queue, data, old.cl, is_blocking=True)
+
+    elif "ANETensor" in str(type(data)):
+      if device == Device.ANE: return data
+      with ProfileOp("toCPU", [data]):
+        data = data.data().astype(np.float32)
+
+    if not isinstance(data, np.ndarray):
+      data = np.array(data, dtype=np.float32)
+
+    if data.dtype != np.float32 and not Tensor.did_float_warning:
+      # warning? float64 is actually needed for numerical jacobian
+      print(f"warning, {data.shape!r} isn't float32")
+      Tensor.did_float_warning = True
+
+    if device == Device.GPU:
+      require_init_gpu()
+      with ProfileOp("toGPU", [data]):
+        return GPUBuffer(data.shape, data)
+
+    elif device == Device.ANE:
+      require_init_ane()
+      with ProfileOp("toANE", [data]):
+        ndata = ane.tensor(data.shape)
+        ndata.data()[:] = data
+        return ndata
+    return data
+
+  def to_(self, device):
+    self.data, self.device = self._move_data(self.data, device), device
+    if self.grad: self.grad.to_(device)
+
+  def to(self, device):
+    ret = Tensor(self.data, device)
+    if self.grad: ret.grad = self.grad.to(device)
+    return ret
+
+  def _is(self, device): return self.device == device
 
   def detach(self):
-    return Tensor(self.data, self.gpu)
+    return Tensor(self.data, device=self.device)
 
   # ***** non first class ops *****
 
@@ -232,7 +223,7 @@ class Tensor:
 
   def dropout(self, p=0.5):
     _mask = np.asarray(np.random.binomial(1, 1.0-p, size=self.shape), dtype=self.dtype)
-    ret = self * Tensor(_mask, requires_grad=False, gpu=self.gpu)
+    ret = self * Tensor(_mask, requires_grad=False, device=self.device)
     return ret.div(1.0 - p)
 
   def abs(self):
@@ -259,18 +250,18 @@ class Function:
       setattr(ctx, k, v)
     with ProfileOp(ctx.__class__.__name__, x):
       ret = Tensor(self.forward(ctx, *[t.data for t in x], **kwargs),
-                   requires_grad=any([t.requires_grad for t in x]))
+                   device=ctx.device, requires_grad=any([t.requires_grad for t in x]))
     if ret.requires_grad:
       ret._ctx = ctx
     return ret
 
-def register(name, fxn, device=Tensor.CPU):
+def register(name, fxn, device=Device.CPU):
   Tensor.ops[device][name] = fxn
   def dispatch(*x, **kwargs):
     tt = [arg for arg in x if isinstance(arg, Tensor)][0]
-    x = [Tensor(np.array([arg], dtype=tt.dtype), gpu=tt.gpu, requires_grad=False) if not isinstance(arg, Tensor) else arg for arg in x]
+    x = [Tensor(np.array([arg], dtype=tt.dtype), device=tt.device, requires_grad=False) if not isinstance(arg, Tensor) else arg for arg in x]
    f = (Tensor.ops[tt.device])[name]
-    f.cl_ctx, f.cl_queue, f.ane = cl_ctx, cl_queue, ane
+    f.cl_ctx, f.cl_queue, f.ane, f.device = cl_ctx, cl_queue, ane, tt.device
    return f.apply(f, *x, **kwargs)
  setattr(Tensor, name, dispatch)
  # TODO: div is a second class op, so it doesn't work here
@@ -279,6 +270,11 @@ def register(name, fxn, device=Tensor.CPU):
     setattr(Tensor, f"__i{name}__", lambda self,x: self.assign(dispatch(self,x)))
     setattr(Tensor, f"__r{name}__", lambda self,x: dispatch(x,self))
 
+for device in [device for device in Device.__dict__.keys() if device[0] != "_"]:
+  setattr(Tensor, f"{device.lower()}", functools.partialmethod(Tensor.to, Device.__dict__[device]))
+  setattr(Tensor, f"{device.lower()}_", functools.partialmethod(Tensor.to_, Device.__dict__[device]))
+  setattr(Tensor, f"is_{device.lower()}", property(functools.partialmethod(Tensor._is, Device.__dict__[device])))
+
 # this registers all the operations
 import tinygrad.ops_cpu
 try: