mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
losing lines (#678)
* losing lines * FLIP -> STRIDE * shapetracker refactor
This commit is contained in:
@@ -23,31 +23,31 @@ class CheckingShapeTracker:
|
||||
def simplify(self): self.st.simplify()
|
||||
|
||||
def reshape(self, new_shape):
|
||||
self.st.reshape(new_shape)
|
||||
self.st._reshape(new_shape)
|
||||
self.t = self.t.reshape(new_shape)
|
||||
|
||||
def permute(self, axis):
|
||||
self.st.permute(axis)
|
||||
self.st._permute(axis)
|
||||
self.t = np.transpose(self.t, axis)
|
||||
|
||||
def expand(self, new_shape):
|
||||
self.st.expand(new_shape)
|
||||
self.st._expand(new_shape)
|
||||
self.t = np.broadcast_to(self.t, new_shape)
|
||||
|
||||
def flip(self, axis):
|
||||
self.st.flip(axis)
|
||||
self.st._stride(tuple(-1 if i in axis else 1 for i in range(len(self.shape))))
|
||||
self.t = np.flip(self.t, axis)
|
||||
|
||||
def shrink(self, arg):
|
||||
self.st.shrink(arg)
|
||||
self.st._shrink(arg)
|
||||
self.t = self.t[tuple([slice(x[0], x[1]) for x in arg])]
|
||||
|
||||
def pad(self, arg):
|
||||
self.st.pad(arg)
|
||||
self.st._pad(arg)
|
||||
self.t = np.pad(self.t, arg, constant_values=-1)
|
||||
|
||||
def stride(self, arg):
|
||||
self.st.stride(arg)
|
||||
self.st._stride(arg)
|
||||
self.t = self.t[tuple([slice(None, None, x) for x in arg])]
|
||||
|
||||
def __getitem__(self, val):
|
||||
@@ -148,7 +148,7 @@ class TestSimplifyingShapeTracker(unittest.TestCase):
|
||||
|
||||
class TestComplexShapeTracker(unittest.TestCase):
|
||||
def test_add_1s(self):
|
||||
self.st = ShapeTracker((4, 4))
|
||||
self.st = CheckingShapeTracker((4, 4))
|
||||
self.st.permute((1,0))
|
||||
self.st.reshape((1,4,1,4,1))
|
||||
assert not self.st.contiguous
|
||||
@@ -156,20 +156,20 @@ class TestComplexShapeTracker(unittest.TestCase):
|
||||
assert self.st.contiguous
|
||||
|
||||
def test_permute_1s_simple(self):
|
||||
self.st = ShapeTracker((1, 16, 9,9))
|
||||
self.st = CheckingShapeTracker((1, 16, 9,9))
|
||||
self.st.permute((1,0,2,3))
|
||||
assert self.st.contiguous
|
||||
self.st = ShapeTracker((2, 16, 9,9))
|
||||
self.st = CheckingShapeTracker((2, 16, 9,9))
|
||||
self.st.permute((1,0,2,3))
|
||||
assert not self.st.contiguous
|
||||
|
||||
def test_remove_1s_simple(self):
|
||||
self.st = ShapeTracker((1, 16, 1, 1))
|
||||
self.st = CheckingShapeTracker((1, 16, 1, 1))
|
||||
self.st.reshape((16,))
|
||||
assert self.st.contiguous
|
||||
|
||||
def test_remove_1s(self):
|
||||
self.st = ShapeTracker((1, 4, 1, 4, 1))
|
||||
self.st = CheckingShapeTracker((1, 4, 1, 4, 1))
|
||||
self.st.permute((0,3,2,1,4))
|
||||
self.st.reshape((4,4))
|
||||
assert not self.st.contiguous
|
||||
@@ -177,46 +177,46 @@ class TestComplexShapeTracker(unittest.TestCase):
|
||||
assert self.st.contiguous
|
||||
|
||||
def test_permute_reshape(self):
|
||||
self.st = ShapeTracker((4, 4))
|
||||
self.st = CheckingShapeTracker((4, 4))
|
||||
self.st.permute((1,0))
|
||||
self.st.reshape((2, 2, 2, 2))
|
||||
# TODO: should also be tested by test_super_complex
|
||||
assert len(self.st.views) == 1
|
||||
|
||||
def test_factorize_split(self):
|
||||
self.st = ShapeTracker((4, 4))
|
||||
self.st = CheckingShapeTracker((4, 4))
|
||||
self.st.permute((1,0))
|
||||
self.st.reshape((2, 2, 2, 2))
|
||||
self.st.permute((2,3,0,1))
|
||||
assert self.st.contiguous
|
||||
|
||||
def test_factorize_combine(self):
|
||||
self.st = ShapeTracker((4, 4, 4))
|
||||
self.st = CheckingShapeTracker((4, 4, 4))
|
||||
self.st.permute((2, 0, 1))
|
||||
self.st.reshape((4, 16))
|
||||
self.st.permute((1, 0))
|
||||
assert self.st.contiguous
|
||||
|
||||
def test_factorize_combine_add_ones(self):
|
||||
self.st = ShapeTracker((4, 4, 4))
|
||||
self.st = CheckingShapeTracker((4, 4, 4))
|
||||
self.st.permute((2, 0, 1))
|
||||
self.st.reshape((4, 16, 1, 1))
|
||||
self.st.permute((1, 0, 2, 3))
|
||||
assert self.st.contiguous
|
||||
|
||||
def test_fancy_factorize(self):
|
||||
self.st = ShapeTracker((32, 3, 3, 1))
|
||||
self.st = CheckingShapeTracker((32, 3, 3, 1))
|
||||
self.st.reshape((8, 4, 3, 3))
|
||||
assert len(self.st.views) == 1
|
||||
|
||||
def test_super_complex_2_fail(self):
|
||||
self.st = ShapeTracker((4, 4, 4))
|
||||
self.st = CheckingShapeTracker((4, 4, 4))
|
||||
self.st.permute((2, 0, 1))
|
||||
self.st.reshape((16, 4))
|
||||
assert len(self.st.views) != 1
|
||||
|
||||
def test_work(self):
|
||||
self.st = ShapeTracker((64, 1024, 4))
|
||||
self.st = CheckingShapeTracker((64, 1024, 4))
|
||||
self.st.reshape((1, 64, 128, 32))
|
||||
self.st.permute((0, 3, 1, 2))
|
||||
self.st.reshape((1, 32, 1, 64, 128))
|
||||
@@ -224,7 +224,7 @@ class TestComplexShapeTracker(unittest.TestCase):
|
||||
assert self.st.contiguous
|
||||
|
||||
def test_work2(self):
|
||||
self.st = ShapeTracker((64, 1024, 4))
|
||||
self.st = CheckingShapeTracker((64, 1024, 4))
|
||||
self.st.reshape((1, 64, 128, 32))
|
||||
self.st.permute((0, 3, 1, 2))
|
||||
self.st.reshape((1, 1, 32, 64, 128))
|
||||
|
||||
Reference in New Issue
Block a user