diff --git a/ane/2_compile/simple/gemm.plist b/ane/2_compile/simple/gemm.plist index ebeadf4619..fff1d3ebd1 100644 --- a/ane/2_compile/simple/gemm.plist +++ b/ane/2_compile/simple/gemm.plist @@ -31,6 +31,8 @@ image + BatchSize + 512 InputChannels 512 InputHeight diff --git a/ane/lib/testconv.py b/ane/lib/testconv.py index c5709df4a1..3b8542d581 100755 --- a/ane/lib/testconv.py +++ b/ane/lib/testconv.py @@ -15,7 +15,7 @@ def benchmark(ane): ret = ane.run(comp, tin, tout) et = time.time() ts = (et-st) - ops = 1000*512*512*512*2 + ops = 1000*512*512*2 print("%.2f ms, %.2f gigaops/sec" % (ts*1000, ops*1e-9/ts)) diff --git a/examples/transformer.py b/examples/transformer.py new file mode 100755 index 0000000000..f83c5d3063 --- /dev/null +++ b/examples/transformer.py @@ -0,0 +1,64 @@ +#!/usr/bin/env python3 +import numpy as np +import random +from tinygrad.tensor import Tensor + +# dataset idea from https://github.com/karpathy/minGPT/blob/master/play_math.ipynb +def make_dataset(): + ds = [] + for i in range(100): + for j in range(100): + s = i+j + ds.append([i//10, i%10, j//10, j%10, s//100, (s//10)%10, s%10]) + random.shuffle(ds) + ds = np.array(ds) + ds_X = ds[:, 0:6] + ds_Y = np.copy(ds[:, 1:]) + ds_X_train, ds_X_test = ds_X[0:8000], ds_X[8000:] + ds_Y_train, ds_Y_test = ds_Y[0:8000], ds_Y[8000:] + + return ds_X_train, ds_Y_train, ds_X_test, ds_Y_test + +#X_train, Y_train, X_test, Y_test = make_dataset() + +class TransformerBlock: + def __init__(self, embed_dim, num_heads): + # Multi-Head Attention + self.num_heads = num_heads + self.projection_dim = embed_dim // num_heads + assert self.projection_dim * self.num_heads == embed_dim + + # looks like bias is useless + self.query_dense = Tensor.uniform(embed_dim, embed_dim) + self.key_dense = Tensor.uniform(embed_dim, embed_dim) + self.value_dense = Tensor.uniform(embed_dim, embed_dim) + + self.ff1 = Tensor.uniform(embed_dim, embed_dim) + self.ff2 = Tensor.uniform(embed_dim, embed_dim) + + def __call__(self, x): + bs = x.shape[0] + x = x.reshape(shape=(-1, self.num_heads * self.projection_dim)) + + # run multi head attention + qkv = [x.dot(y) \ + .reshape(shape=(bs, -1, self.num_heads, self.projection_dim)) \ + .transpose(order=(0,2,1,3)) \ + for y in [self.query_dense, self.key_dense, self.value_dense]] + + print(qkv[0].shape) + + #query = self.query_dense(x).reshape((bs, -1, self.num_heads, self.projection_dim)) + #key = self.key_dense(x).reshape((bs, -1, self.num_heads, self.projection_dim)) + #value = self.value_dense(x).reshape((bs, -1, self.num_heads, self.projection_dim)) + + #x = self.ff2(self.ff1(x).relu()) + #return x + +if __name__ == "__main__": + tb = TransformerBlock(128, 4) + tmp = Tensor.zeros(20, 10, 128) + ret = tb(tmp) + ret.backward() + print(ret) + diff --git a/test/config.py b/test/config.py deleted file mode 100644 index ab20e8b39e..0000000000 --- a/test/config.py +++ /dev/null @@ -1,3 +0,0 @@ -import os - -ANE = os.environ.get('ANE', False) diff --git a/test/test_gc.py b/test/test_gc.py index 17e86a9114..56b5af66c2 100644 --- a/test/test_gc.py +++ b/test/test_gc.py @@ -1,8 +1,7 @@ #!/usr/bin/env python import gc import unittest -from tinygrad.tensor import Tensor, GPU, Device -from .config import ANE +from tinygrad.tensor import Tensor, GPU, ANE, Device def tensors_allocated(): return sum([isinstance(x, Tensor) for x in gc.get_objects()]) diff --git a/test/test_mnist.py b/test/test_mnist.py index 8edae37f4d..3c38b7683b 100644 --- a/test/test_mnist.py +++ b/test/test_mnist.py @@ -2,11 +2,10 @@ import os import unittest import numpy as np -from tinygrad.tensor import Tensor, GPU, Device +from tinygrad.tensor import Tensor, GPU, ANE, Device import tinygrad.optim as optim from extra.training import train, evaluate from extra.utils import fetch, get_parameters -from .config import ANE # mnist loader def fetch_mnist(): diff --git a/test/test_net_speed.py b/test/test_net_speed.py index 99827d3e96..38a42cfed3 100644 --- a/test/test_net_speed.py +++ b/test/test_net_speed.py @@ -4,8 +4,7 @@ import cProfile import pstats import unittest import torch -from tinygrad.tensor import Tensor, GPU, Device -from .config import ANE +from tinygrad.tensor import Tensor, GPU, ANE, Device def start_profile(): import time diff --git a/test/test_nn.py b/test/test_nn.py index ba00c7340b..adc55d008d 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -1,11 +1,10 @@ #!/usr/bin/env python import unittest import numpy as np -from tinygrad.tensor import GPU, Device +from tinygrad.tensor import GPU, ANE, Device from tinygrad.nn import * from extra.utils import get_parameters import torch -from .config import ANE class TestNN(unittest.TestCase): device = Device.CPU diff --git a/test/test_ops.py b/test/test_ops.py index 294993dc94..4c55c43899 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -4,8 +4,7 @@ import numpy as np import unittest import timeit import functools -from tinygrad.tensor import Tensor, GPU, Device -from .config import ANE +from tinygrad.tensor import Tensor, GPU, ANE, Device def helper_test_op(shps, torch_fxn, tinygrad_fxn, atol=0, rtol=1e-6, grad_atol=0, grad_rtol=1e-6, device=Device.CPU, forward_only=False): torch.manual_seed(0) @@ -108,6 +107,12 @@ class TestOps(unittest.TestCase): def test_pad2d(self): helper_test_op([(3,3,3,3)], lambda x: torch.nn.functional.pad(x, (1,2,3,4)), lambda x: x.pad2d(padding=(1,2,3,4)), device=self.device) + def test_transpose(self): + # TODO: transpose for GPU + if self.device == Device.GPU: + return + helper_test_op([(3,3,3)], lambda x: x.transpose(1,2), lambda x: x.transpose(order=(0,2,1)), device=self.device) + def test_reshape(self): helper_test_op([(4,3,6,6)], lambda x: torch.reshape(x, (-1,3,6,6)), lambda x: x.reshape(shape=(-1,3,6,6)), device=self.device) helper_test_op([(4,3,6,6)], lambda x: torch.reshape(x, (-1,1,6,6)), lambda x: x.reshape(shape=(-1,1,6,6)), device=self.device) diff --git a/test/test_optim.py b/test/test_optim.py index 99ddcff57c..be93e1b60f 100644 --- a/test/test_optim.py +++ b/test/test_optim.py @@ -1,10 +1,9 @@ import numpy as np import torch import unittest -from tinygrad.tensor import Tensor, GPU, Device +from tinygrad.tensor import Tensor, GPU, ANE, Device from tinygrad.optim import Adam, SGD, RMSprop from extra.utils import get_parameters -from .config import ANE x_init = np.random.randn(1,3).astype(np.float32) W_init = np.random.randn(3,3).astype(np.float32) diff --git a/test/test_tensor.py b/test/test_tensor.py index f54527b084..413a4acf13 100644 --- a/test/test_tensor.py +++ b/test/test_tensor.py @@ -1,10 +1,8 @@ import numpy as np import torch import unittest -from tinygrad.tensor import Tensor, GPU, Device +from tinygrad.tensor import Tensor, GPU, ANE, Device from extra.gradcheck import numerical_jacobian, jacobian, gradcheck -from .config import ANE - x_init = np.random.randn(1,3).astype(np.float32) U_init = np.random.randn(3,3).astype(np.float32) diff --git a/tinygrad/ops_cpu.py b/tinygrad/ops_cpu.py index 38756a7a29..7b8258eec9 100644 --- a/tinygrad/ops_cpu.py +++ b/tinygrad/ops_cpu.py @@ -113,6 +113,16 @@ class Reshape(Function): return grad_output.reshape(in_shape) register('reshape', Reshape) +class Transpose(Function): + @staticmethod + def forward(ctx, x, order): + ctx.save_for_backward(order) + return np.transpose(x, order) + + @staticmethod + def backward(ctx, x): + return np.transpose(x, np.argsort(ctx.order)) +register('transpose', Transpose) # ************* activation ops ************* diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 88929dce91..b24cf11e4b 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -284,3 +284,4 @@ try: except ImportError: # no GPU support GPU = False +ANE = False