mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-24 06:18:01 -05:00
* tinygrad.nn * flake8 * working on pylint * more pylint * more pylint * pylint passes * networkx * mypy can't infer that type * junk
246 lines
6.7 KiB
Python
246 lines
6.7 KiB
Python
#!/usr/bin/env python
|
|
import unittest
|
|
import numpy as np
|
|
from tinygrad.helpers import prod
|
|
from tinygrad.shapetracker import ShapeTracker
|
|
|
|
class DumbShapeTracker:
  """Reference model for ShapeTracker backed by a real numpy array.

  The array starts as arange(prod(shape)), so after any sequence of movement ops the
  value stored at each position is the flat index of the source element it came from.
  Each op method has the same name and argument layout as the ShapeTracker call it is
  compared against in TestShapeTracker.
  """

  def __init__(self, shape):
    # int64 rather than uint8: a uint8 arange cannot represent indices >= 256, so the
    # oracle would report wrong source indices for any shape with more than 256 elements.
    self.t = np.arange(prod(shape), dtype=np.int64).reshape(shape)

  @property
  def shape(self):
    """Current shape of the underlying numpy array."""
    return self.t.shape

  def reshape(self, *new_shape):
    """Reshape to new_shape (row-major reinterpretation of the current elements)."""
    self.t = self.t.reshape(new_shape)

  def permute(self, *axis):
    """Reorder axes: output axis i is input axis axis[i] (np.transpose semantics)."""
    self.t = np.transpose(self.t, axis)

  def expand(self, *new_shape):
    """Broadcast size-1 axes up to new_shape (numpy returns a read-only view)."""
    self.t = np.broadcast_to(self.t, new_shape)

  def flip(self, *axis):
    """Reverse the order of elements along each listed axis."""
    self.t = np.flip(self.t, axis)

  def shrink(self, *arg):
    """Keep [x[0], x[1]) along each axis; arg holds one (start, end) pair per axis."""
    self.t = self.t[tuple(slice(x[0], x[1]) for x in arg)]

  def stride(self, *arg):
    """Take every arg[i]-th element along axis i; a negative step also reverses."""
    self.t = self.t[tuple(slice(None, None, x) for x in arg)]

  def __getitem__(self, val):
    """Index into the array flattened in row-major (C) order."""
    return self.t.flatten()[val]
|
|
|
|
# Tensor.zeros(2, 4).permute(1,0).reshape(2, 4)
|
|
# (d1*4 + d0%4), d1=x//4, d0=x%4 = ((x//4)*4) + (x%4)%4
|
|
|
|
class TestComplexShapeTracker(unittest.TestCase):
  """Contiguity checks for multi-step ShapeTracker op sequences.

  Each test applies a chain of movement ops and checks what `contiguous` reports.
  unittest assertion methods are used instead of bare `assert` so the checks still
  run under `python -O`.
  """

  def test_add_1s(self):
    st = ShapeTracker((4, 4))
    st.permute(1,0)
    st.reshape(1,4,1,4,1)
    self.assertFalse(st.contiguous)
    # swapping the two size-4 axes back undoes the initial transpose
    st.permute(0,3,2,1,4)
    self.assertTrue(st.contiguous)

  def test_permute_1s_simple(self):
    # moving a size-1 axis does not reorder any elements
    st = ShapeTracker((1, 16, 9,9))
    st.permute(1,0,2,3)
    self.assertTrue(st.contiguous)
    # the same permute with a size-2 leading axis really does reorder elements
    st = ShapeTracker((2, 16, 9,9))
    st.permute(1,0,2,3)
    self.assertFalse(st.contiguous)

  def test_remove_1s_simple(self):
    st = ShapeTracker((1, 16, 1, 1))
    st.reshape(16,)
    self.assertTrue(st.contiguous)

  def test_remove_1s(self):
    st = ShapeTracker((1, 4, 1, 4, 1))
    st.permute(0,3,2,1,4)
    st.reshape(4,4)
    self.assertFalse(st.contiguous)
    st.permute(1,0)
    self.assertTrue(st.contiguous)

  @unittest.skip("reshape is even more complex")
  def test_super_complex(self):
    st = ShapeTracker((4, 4))
    st.permute(1,0)
    st.reshape(2, 2, 2, 2)
    st.permute(2,3,0,1)
    self.assertTrue(st.contiguous)

  def test_work(self):
    st = ShapeTracker((64, 1024, 4))
    st.reshape(1, 64, 128, 32)
    st.permute(0, 3, 1, 2)
    st.reshape(1, 32, 1, 64, 128)
    st.permute(0, 3, 4, 1, 2)
    self.assertTrue(st.contiguous)

  def test_work2(self):
    st = ShapeTracker((64, 1024, 4))
    st.reshape(1, 64, 128, 32)
    st.permute(0, 3, 1, 2)
    st.reshape(1, 1, 32, 64, 128)
    st.permute(0, 3, 4, 1, 2)
    st.reshape(64, 1024, 4)
    # the view stack is reported on failure instead of being printed unconditionally
    self.assertTrue(st.contiguous, f"views={st.views}")
|
|
|
|
class TestSingleShapeTracker(unittest.TestCase):
  """Contiguity of a fresh (7, 4) ShapeTracker after one or two movement ops.

  Uses unittest assertion methods rather than bare `assert`, so the checks are not
  stripped when Python runs with -O.
  """

  def setUp(self):
    self.st = ShapeTracker((7,4))

  def test_reshape(self):
    self.st.reshape(7,1,4)
    self.assertTrue(self.st.contiguous)

  def test_permute(self):
    self.st.permute(1,0)
    self.assertFalse(self.st.contiguous)

  def test_shrink(self):
    # keeps only row 1 of the 7 rows
    self.st.shrink((1,2), (0,4))
    self.assertFalse(self.st.contiguous)

  def test_double_permute(self):
    # two transposes cancel out
    self.st.permute(1,0)
    self.st.permute(1,0)
    self.assertTrue(self.st.contiguous)

  def test_reshape_permute(self):
    # identity permutation
    self.st.reshape(7,1,4)
    self.st.permute(0,1,2)
    self.assertTrue(self.st.contiguous)

  def test_reshape_permute_yes(self):
    # only the size-1 axis moves, so element order is unchanged
    self.st.reshape(7,1,4)
    self.st.permute(0,2,1)
    self.assertTrue(self.st.contiguous)

  def test_reshape_permute_no(self):
    self.st.reshape(4,7)
    self.st.permute(1,0)
    self.assertFalse(self.st.contiguous)
|
|
|
|
def shapetracker_getitem(st, val):
  """Evaluate st's generated index expression at flat position val.

  st.expr() is executed as Python source with `idx` preset to val and `valid` preset
  to 1; the code may reassign either name. Returns the final `idx`, or -1 when
  `valid` ends up falsy.
  """
  # Named `scope` rather than `locals` so the builtin locals() is not shadowed.
  # exec is acceptable only because the source comes from the tracker under test,
  # never from untrusted input.
  scope = {"idx": val, "valid": 1}
  exec(st.expr(), None, scope)
  return scope["idx"] if scope["valid"] else -1
|
|
|
|
class TestShapeTracker(unittest.TestCase):
  """Differential tests: each op is applied to both a ShapeTracker and a DumbShapeTracker.

  tearDown evaluates the ShapeTracker's index expression at every flat position and
  compares it with the numpy reference. Most tests therefore only apply ops, and the
  real check happens in tearDown.
  """

  def setUp(self):
    self.st = ShapeTracker((7,4))
    self.dt = DumbShapeTracker((7,4))

  def apply(self, fxn):
    """Call fxn on the ShapeTracker, then on the DumbShapeTracker; return both results."""
    return [fxn(x) for x in [self.st, self.dt]]

  def tearDown(self):
    x = [shapetracker_getitem(self.st, i) for i in range(prod(self.st.shape))]
    y = [self.dt[i] for i in range(prod(self.dt.shape))]
    # diagnostics go into the failure message instead of being printed on every run
    msg = f"st.shape={self.st.shape} dt.shape={self.dt.shape} expr={self.st.expr()}"
    self.assertEqual(self.st.shape, self.dt.shape, msg)
    self.assertEqual(x, y, msg)

  def test_noop(self):
    """tearDown alone checks the untouched (7, 4) tracker."""

  def test_simple_split(self):
    self.test_permute()
    # prod(self.st.shape) is unchanged by the reshape, so reading it lazily is safe
    self.apply(lambda x: x.reshape(prod(self.st.shape)))

  def test_reshape(self):
    self.assertEqual(self.st.shape, self.dt.shape)
    new_shape = self.st.shape[::-1]
    self.apply(lambda x: x.reshape(*new_shape))

  def test_permute(self):
    self.assertEqual(self.st.shape, self.dt.shape)
    if len(self.st.shape) == 2: self.apply(lambda x: x.permute(1,0))
    elif len(self.st.shape) == 3: self.apply(lambda x: x.permute(2,0,1))

  def test_reshape_with_1(self):
    self.assertEqual(self.st.shape, self.dt.shape)
    new_shape = [self.st.shape[0], 1, self.st.shape[1]]
    self.apply(lambda x: x.reshape(*new_shape))

  def test_expand(self):
    self.test_reshape_with_1()
    new_shape = list(self.st.shape)
    new_shape[1] = 2  # expand the size-1 axis inserted above
    self.apply(lambda x: x.expand(*new_shape))

  def test_flip_0(self):
    self.apply(lambda x: x.flip(0))

  def test_flip_1(self):
    self.apply(lambda x: x.flip(1))

  def test_flip_01(self):
    self.apply(lambda x: x.flip(0,1))

  def test_slice_0(self):
    self.apply(lambda x: x.shrink((1, x.shape[0]), (0, x.shape[1])))

  def test_slice_1(self):
    self.apply(lambda x: x.shrink((0, x.shape[0]), (1, x.shape[1])))

  def test_slice_1c1(self):
    self.apply(lambda x: x.shrink((0, 1), (0, 1)))

  def test_slice_1c2(self):
    self.apply(lambda x: x.shrink((1, 2), (1, 2)))

  def test_double_permute(self):
    self.apply(lambda x: x.permute(1, 0))
    self.apply(lambda x: x.permute(1, 0))

  def test_slice_permute(self):
    self.apply(lambda x: x.shrink((0, 2), (2, 4)))
    self.apply(lambda x: x.permute(1, 0))

  def test_slice_expand(self):
    self.apply(lambda x: x.shrink((0, 2), (3, 4)))
    self.apply(lambda x: x.expand(2, 10))

  def test_double_stride(self):
    self.apply(lambda x: x.stride(1, 2))
    self.apply(lambda x: x.stride(2, 1))

  def test_stride(self): self.apply(lambda x: x.stride(2,1))
  def test_stride_int(self): self.apply(lambda x: x.stride(1,2))
  def test_stride_2(self): self.apply(lambda x: x.stride(2,2))
  def test_stride_n(self): self.apply(lambda x: x.stride(-2,1))
  def test_stride_int_n(self): self.apply(lambda x: x.stride(-1,2))
  def test_stride_2_n(self): self.apply(lambda x: x.stride(-2,-2))

  def test_reshape_then_permute(self):
    self.test_reshape()
    self.test_permute()

  def test_reshape_then_expand(self):
    self.test_reshape()
    self.test_expand()

  def test_permute_then_reshape(self):
    self.test_permute()
    self.test_reshape()

  def test_expand_then_reshape(self):
    self.test_expand()
    self.test_reshape()

  def test_combo(self):
    self.test_permute()
    self.test_reshape()
    self.test_slice_1()
    self.test_expand()
    self.test_permute()
|
|
|
|
# Running this file directly executes every TestCase defined above.
if __name__ == '__main__': unittest.main()
|