mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-11 23:25:04 -05:00
Make ShapeTracker Immutable (#1909)
* ugh * ops test pass * fix shapetracker tests * sym shapetracker * shapetracker is a tuple of views now * from_shape * fix has variable shape * key isn't needed * post init assert
This commit is contained in:
@@ -14,42 +14,51 @@ def shapetracker_getitem(st, val):
|
||||
|
||||
class CheckingShapeTracker:
|
||||
def __init__(self, shape):
|
||||
self.st = ShapeTracker(shape)
|
||||
self.st = ShapeTracker.from_shape(shape)
|
||||
self.t = np.arange(prod(shape), dtype=np.int32).reshape(shape)
|
||||
|
||||
@property
|
||||
def shape(self):
|
||||
return self.t.shape
|
||||
|
||||
def simplify(self): self.st.simplify()
|
||||
def simplify(self):
|
||||
self.st = self.st.simplify()
|
||||
return self
|
||||
|
||||
def reshape(self, new_shape):
|
||||
self.st.reshape(new_shape)
|
||||
self.st = self.st.reshape(new_shape)
|
||||
self.t = self.t.reshape(new_shape)
|
||||
return self
|
||||
|
||||
def permute(self, axis):
|
||||
self.st.permute(axis)
|
||||
self.st = self.st.permute(axis)
|
||||
self.t = np.transpose(self.t, axis)
|
||||
return self
|
||||
|
||||
def expand(self, new_shape):
|
||||
self.st.expand(new_shape)
|
||||
self.st = self.st.expand(new_shape)
|
||||
self.t = np.broadcast_to(self.t, new_shape)
|
||||
return self
|
||||
|
||||
def flip(self, axis):
|
||||
self.st.stride(tuple(-1 if i in axis else 1 for i in range(len(self.shape))))
|
||||
self.st = self.st.stride(tuple(-1 if i in axis else 1 for i in range(len(self.shape))))
|
||||
self.t = np.flip(self.t, axis)
|
||||
return self
|
||||
|
||||
def shrink(self, arg):
|
||||
self.st.shrink(arg)
|
||||
self.st = self.st.shrink(arg)
|
||||
self.t = self.t[tuple([slice(x[0], x[1]) for x in arg])]
|
||||
return self
|
||||
|
||||
def pad(self, arg):
|
||||
self.st.pad(arg)
|
||||
self.st = self.st.pad(arg)
|
||||
self.t = np.pad(self.t, arg, constant_values=-1)
|
||||
return self
|
||||
|
||||
def stride(self, arg):
|
||||
self.st.stride(arg)
|
||||
self.st = self.st.stride(arg)
|
||||
self.t = self.t[tuple([slice(None, None, x) for x in arg])]
|
||||
return self
|
||||
|
||||
def __getitem__(self, val):
|
||||
return self.t.flatten()[val]
|
||||
@@ -70,7 +79,7 @@ class CheckingShapeTracker:
|
||||
|
||||
class TestRealIssues(unittest.TestCase):
|
||||
def test_reshape_doesnt_multiview(self):
|
||||
self.st = ShapeTracker((256, 256, 2, 2, 2, 2, 2, 256, 8, 2), views=[View.create((256, 256, 2, 2, 2, 2, 2, 256, 8, 2), (0, 8, 0, 4, 0, 0, 2, 16384, 2048, 1), 0, None)])
|
||||
self.st = ShapeTracker((View.create((256, 256, 2, 2, 2, 2, 2, 256, 8, 2), (0, 8, 0, 4, 0, 0, 2, 16384, 2048, 1), 0, None),))
|
||||
self.st.reshape((128, 2, 256, 2, 2, 2, 2, 2, 256, 8, 2))
|
||||
assert len(self.st.views) == 1
|
||||
|
||||
@@ -78,27 +87,27 @@ class TestRealDoesntSimplify(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
st = self.st.real_strides()
|
||||
print(st)
|
||||
self.st.simplify()
|
||||
self.st = self.st.simplify()
|
||||
assert len(self.st.views) != 1
|
||||
assert None in st
|
||||
|
||||
def test_1(self):
|
||||
self.st = ShapeTracker((8, 6, 11), views=[
|
||||
self.st = ShapeTracker((
|
||||
View.create((8, 3, 1, 2, 11, 1), (33, 11, 0, 0, 1, 0), 0, None),
|
||||
View.create((8, 6, 11), (66, 11, 1), 0, None)])
|
||||
View.create((8, 6, 11), (66, 11, 1), 0, None)))
|
||||
assert self.st.real_strides() == (33, None, 1)
|
||||
|
||||
def test_2(self):
|
||||
self.st = ShapeTracker((4, 4, 3, 3), views=[
|
||||
self.st = ShapeTracker((
|
||||
View.create((2, 2, 4, 3, 3), (72, 9, 18, -3, -1), 8, None),
|
||||
View.create((4, 4, 3, 3), (36, 9, 3, 1), 0, None)])
|
||||
View.create((4, 4, 3, 3), (36, 9, 3, 1), 0, None)))
|
||||
assert self.st.real_strides() == (None, 18, -3, -1)
|
||||
|
||||
class TestRealStrides(unittest.TestCase):
|
||||
def test_1(self):
|
||||
self.st = ShapeTracker((16, 32, 4), views=[
|
||||
self.st = ShapeTracker((
|
||||
View.create((2048,), (1,), 0, ((0, 512),)),
|
||||
View.create((16, 32, 4), (128, 4, 1), 0, None)])
|
||||
View.create((16, 32, 4), (128, 4, 1), 0, None)))
|
||||
st = self.st.real_strides()
|
||||
print(self.st, st)
|
||||
assert st == (None, 4, 1)
|
||||
@@ -106,27 +115,27 @@ class TestRealStrides(unittest.TestCase):
|
||||
class TestRealSimplifies(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
st = self.st.real_strides()
|
||||
self.st.simplify()
|
||||
self.st = self.st.simplify()
|
||||
assert len(self.st.views) == 1
|
||||
print(self.st.views[-1].strides, st)
|
||||
assert self.st.views[-1].strides == st
|
||||
|
||||
def test_1(self):
|
||||
self.st = ShapeTracker((1, 3, 2, 11, 26, 1, 1, 3), views=[
|
||||
self.st = ShapeTracker((
|
||||
View.create((1, 3, 2, 11, 4, 28), (0, 308, 0, 28, 0, 1), 0, None),
|
||||
View.create((1, 3, 2, 11, 26, 1, 1, 3), (0, 2464, 0, 112, 1, 0, 0, 29), 0, None)])
|
||||
View.create((1, 3, 2, 11, 26, 1, 1, 3), (0, 2464, 0, 112, 1, 0, 0, 29), 0, None)))
|
||||
|
||||
def test_2(self):
|
||||
self.st = ShapeTracker((8, 1, 6, 10, 28, 3, 2, 1), views=[
|
||||
self.st = ShapeTracker((
|
||||
View.create((8, 3, 3, 11, 2, 28), (924, 308, 0, 28, 0, 1), 0, None),
|
||||
View.create((8, 1, 6, 10, 28, 3, 2, 1), (5544, 0, 0, 56, 1, 1848, 672, 0), 0, None)])
|
||||
View.create((8, 1, 6, 10, 28, 3, 2, 1), (5544, 0, 0, 56, 1, 1848, 672, 0), 0, None)))
|
||||
|
||||
class TestIndexExpressions2d(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
shapes = [(30, 5), (15, 10), (15, 1), (5, 10), (5, 1)] # Make sure dim0 is a multiple of 5, one of the tests divides this dimension by 5
|
||||
offsets = [0, 1, 15, 28, 10000]
|
||||
self.sts = [ShapeTracker(base_shape, [View.create(base_shape, offset=offset)]) for base_shape in shapes for offset in offsets]
|
||||
self.sts = [ShapeTracker((View.create(base_shape, offset=offset),)) for base_shape in shapes for offset in offsets]
|
||||
self.offset = [Variable.num(offset) for base_shape in shapes for offset in offsets]
|
||||
self.shapes = [shape for shape in shapes for offset in offsets]
|
||||
self.node_exprs = []
|
||||
@@ -171,36 +180,52 @@ class TestIndexExpressions2d(unittest.TestCase):
|
||||
self.idxs_exprs.append(lambda idxs, base_shape=base_shape, offset=offset: idxs[0]*base_shape[1] + idxs[1] + offset)
|
||||
|
||||
def test_permute(self):
|
||||
new_st = []
|
||||
for st, base_shape, offset in zip(self.sts, self.shapes, self.offset):
|
||||
st.permute((1, 0))
|
||||
st = st.permute((1, 0))
|
||||
self.node_exprs.append(lambda idx, base_shape=base_shape, offset=offset: idx%base_shape[0]*base_shape[1] + idx//base_shape[0]%base_shape[1] + offset)
|
||||
self.idxs_exprs.append(lambda idxs, base_shape=base_shape, offset=offset: idxs[0] + idxs[1]*base_shape[1] + offset)
|
||||
new_st.append(st)
|
||||
self.sts = new_st
|
||||
|
||||
def test_reshape(self):
|
||||
new_st = []
|
||||
for st, base_shape, offset in zip(self.sts, self.shapes, self.offset):
|
||||
st.reshape((base_shape[0], 1, base_shape[1]))
|
||||
st = st.reshape((base_shape[0], 1, base_shape[1]))
|
||||
self.node_exprs.append(lambda idx, base_shape=base_shape, offset=offset: idx%prod(base_shape) + offset)
|
||||
self.idxs_exprs.append(lambda idxs, base_shape=base_shape, offset=offset: idxs[0]*base_shape[1] + idxs[2] + offset)
|
||||
new_st.append(st)
|
||||
self.sts = new_st
|
||||
|
||||
def test_reshape_expand(self):
|
||||
new_st = []
|
||||
for st, base_shape, offset in zip(self.sts, self.shapes, self.offset):
|
||||
st.reshape((base_shape[0], 1, base_shape[1]))
|
||||
st.expand((base_shape[0], base_shape[1], base_shape[1]))
|
||||
st = st.reshape((base_shape[0], 1, base_shape[1]))
|
||||
st = st.expand((base_shape[0], base_shape[1], base_shape[1]))
|
||||
self.node_exprs.append(lambda idx, base_shape=base_shape, offset=offset: idx//(base_shape[1]*base_shape[1])%base_shape[0]*base_shape[1] + idx%base_shape[1] + offset)
|
||||
self.idxs_exprs.append(lambda idxs, base_shape=base_shape, offset=offset: idxs[0]*base_shape[1] + idxs[2] + offset)
|
||||
new_st.append(st)
|
||||
self.sts = new_st
|
||||
|
||||
def test_permute_reshape_1(self): # This tests multiple views
|
||||
new_st = []
|
||||
for st, base_shape, offset in zip(self.sts, self.shapes, self.offset):
|
||||
st.permute((1, 0))
|
||||
st.reshape((base_shape[0]//5, 1, base_shape[1]*5))
|
||||
st = st.permute((1, 0))
|
||||
st = st.reshape((base_shape[0]//5, 1, base_shape[1]*5))
|
||||
self.node_exprs.append(lambda idx, base_shape=base_shape, offset=offset: idx%prod(base_shape)%base_shape[0]*base_shape[1] + idx//base_shape[0]%base_shape[1] + offset)
|
||||
self.idxs_exprs.append(lambda idxs, base_shape=base_shape, offset=offset: (idxs[0]*(base_shape[1]*5)+idxs[2])%base_shape[0]*base_shape[1] + (idxs[0]*(base_shape[1]*5)+idxs[2])//base_shape[0] + offset)
|
||||
new_st.append(st)
|
||||
self.sts = new_st
|
||||
|
||||
def test_permute_reshape_2(self):
|
||||
new_st = []
|
||||
for st, base_shape, offset in zip(self.sts, self.shapes, self.offset):
|
||||
st.permute((1, 0))
|
||||
st.reshape((1, base_shape[0]//5, base_shape[1]*5))
|
||||
st = st.permute((1, 0))
|
||||
st = st.reshape((1, base_shape[0]//5, base_shape[1]*5))
|
||||
self.node_exprs.append(lambda idx, base_shape=base_shape, offset=offset: idx%prod(base_shape)%base_shape[0]*base_shape[1] + idx//base_shape[0]%base_shape[1] + offset)
|
||||
self.idxs_exprs.append(lambda idxs, base_shape=base_shape, offset=offset: (idxs[1]*(base_shape[1]*5)+idxs[2])%base_shape[0]*base_shape[1] + (idxs[1]*(base_shape[1]*5)+idxs[2])//base_shape[0] + offset)
|
||||
new_st.append(st)
|
||||
self.sts = new_st
|
||||
|
||||
class TestSimplifyingShapeTracker(unittest.TestCase):
|
||||
def setUp(self):
|
||||
@@ -211,14 +236,14 @@ class TestSimplifyingShapeTracker(unittest.TestCase):
|
||||
|
||||
# multiview simplify
|
||||
def test_expand_contract_simple(self):
|
||||
self.st.expand((10, 10))
|
||||
self.st.reshape((100,))
|
||||
self.st = self.st.expand((10, 10))
|
||||
self.st = self.st.reshape((100,))
|
||||
print(self.st.views)
|
||||
assert(len(self.st.views) == 2)
|
||||
self.st.reshape((10, 10))
|
||||
self.st = self.st.reshape((10, 10))
|
||||
print(self.st.views)
|
||||
|
||||
self.st.simplify()
|
||||
self.st = self.st.simplify()
|
||||
print(self.st.views)
|
||||
assert(len(self.st.views) == 1)
|
||||
|
||||
@@ -231,7 +256,7 @@ class TestSimplifyingShapeTracker(unittest.TestCase):
|
||||
self.st.reshape((2, 5, 2, 5))
|
||||
print(self.st.views)
|
||||
|
||||
self.st.simplify()
|
||||
self.st = self.st.simplify()
|
||||
print(self.st.views)
|
||||
assert(len(self.st.views) == 1)
|
||||
|
||||
@@ -243,7 +268,7 @@ class TestSimplifyingShapeTracker(unittest.TestCase):
|
||||
assert(len(self.st.views) == 2)
|
||||
self.st.reshape((5, 20))
|
||||
|
||||
self.st.simplify()
|
||||
self.st = self.st.simplify()
|
||||
print(self.st.views)
|
||||
assert(len(self.st.views) == 2)
|
||||
|
||||
@@ -387,7 +412,7 @@ class TestShapeTrackerFuzzFailures(unittest.TestCase):
|
||||
self.st.reshape((1, 4))
|
||||
self.st.shrink(((0, 1), (1, 3)))
|
||||
print(self.st.st)
|
||||
self.st.simplify()
|
||||
self.st = self.st.simplify()
|
||||
print(self.st.st)
|
||||
def test_case_2(self):
|
||||
self.st.stride( (1, 1, -2) )
|
||||
|
||||
Reference in New Issue
Block a user