Make ShapeTracker Immutable (#1909)

* ugh

* ops test pass

* fix shapetracker tests

* sym shapetracker

* shapetracker is a tuple of views now

* from_shape

* fix has variable shape

* key isn't needed

* post init assert
This commit is contained in:
George Hotz
2023-09-24 21:09:03 +08:00
committed by GitHub
parent 45f02393f0
commit 20059dc55b
11 changed files with 143 additions and 128 deletions

View File

@@ -14,42 +14,51 @@ def shapetracker_getitem(st, val):
class CheckingShapeTracker:
def __init__(self, shape):
self.st = ShapeTracker(shape)
self.st = ShapeTracker.from_shape(shape)
self.t = np.arange(prod(shape), dtype=np.int32).reshape(shape)
@property
def shape(self):
return self.t.shape
def simplify(self): self.st.simplify()
def simplify(self):
self.st = self.st.simplify()
return self
def reshape(self, new_shape):
self.st.reshape(new_shape)
self.st = self.st.reshape(new_shape)
self.t = self.t.reshape(new_shape)
return self
def permute(self, axis):
self.st.permute(axis)
self.st = self.st.permute(axis)
self.t = np.transpose(self.t, axis)
return self
def expand(self, new_shape):
self.st.expand(new_shape)
self.st = self.st.expand(new_shape)
self.t = np.broadcast_to(self.t, new_shape)
return self
def flip(self, axis):
self.st.stride(tuple(-1 if i in axis else 1 for i in range(len(self.shape))))
self.st = self.st.stride(tuple(-1 if i in axis else 1 for i in range(len(self.shape))))
self.t = np.flip(self.t, axis)
return self
def shrink(self, arg):
self.st.shrink(arg)
self.st = self.st.shrink(arg)
self.t = self.t[tuple([slice(x[0], x[1]) for x in arg])]
return self
def pad(self, arg):
self.st.pad(arg)
self.st = self.st.pad(arg)
self.t = np.pad(self.t, arg, constant_values=-1)
return self
def stride(self, arg):
self.st.stride(arg)
self.st = self.st.stride(arg)
self.t = self.t[tuple([slice(None, None, x) for x in arg])]
return self
def __getitem__(self, val):
return self.t.flatten()[val]
@@ -70,7 +79,7 @@ class CheckingShapeTracker:
class TestRealIssues(unittest.TestCase):
def test_reshape_doesnt_multiview(self):
self.st = ShapeTracker((256, 256, 2, 2, 2, 2, 2, 256, 8, 2), views=[View.create((256, 256, 2, 2, 2, 2, 2, 256, 8, 2), (0, 8, 0, 4, 0, 0, 2, 16384, 2048, 1), 0, None)])
self.st = ShapeTracker((View.create((256, 256, 2, 2, 2, 2, 2, 256, 8, 2), (0, 8, 0, 4, 0, 0, 2, 16384, 2048, 1), 0, None),))
self.st.reshape((128, 2, 256, 2, 2, 2, 2, 2, 256, 8, 2))
assert len(self.st.views) == 1
@@ -78,27 +87,27 @@ class TestRealDoesntSimplify(unittest.TestCase):
def tearDown(self):
st = self.st.real_strides()
print(st)
self.st.simplify()
self.st = self.st.simplify()
assert len(self.st.views) != 1
assert None in st
def test_1(self):
self.st = ShapeTracker((8, 6, 11), views=[
self.st = ShapeTracker((
View.create((8, 3, 1, 2, 11, 1), (33, 11, 0, 0, 1, 0), 0, None),
View.create((8, 6, 11), (66, 11, 1), 0, None)])
View.create((8, 6, 11), (66, 11, 1), 0, None)))
assert self.st.real_strides() == (33, None, 1)
def test_2(self):
self.st = ShapeTracker((4, 4, 3, 3), views=[
self.st = ShapeTracker((
View.create((2, 2, 4, 3, 3), (72, 9, 18, -3, -1), 8, None),
View.create((4, 4, 3, 3), (36, 9, 3, 1), 0, None)])
View.create((4, 4, 3, 3), (36, 9, 3, 1), 0, None)))
assert self.st.real_strides() == (None, 18, -3, -1)
class TestRealStrides(unittest.TestCase):
def test_1(self):
self.st = ShapeTracker((16, 32, 4), views=[
self.st = ShapeTracker((
View.create((2048,), (1,), 0, ((0, 512),)),
View.create((16, 32, 4), (128, 4, 1), 0, None)])
View.create((16, 32, 4), (128, 4, 1), 0, None)))
st = self.st.real_strides()
print(self.st, st)
assert st == (None, 4, 1)
@@ -106,27 +115,27 @@ class TestRealStrides(unittest.TestCase):
class TestRealSimplifies(unittest.TestCase):
def tearDown(self):
st = self.st.real_strides()
self.st.simplify()
self.st = self.st.simplify()
assert len(self.st.views) == 1
print(self.st.views[-1].strides, st)
assert self.st.views[-1].strides == st
def test_1(self):
self.st = ShapeTracker((1, 3, 2, 11, 26, 1, 1, 3), views=[
self.st = ShapeTracker((
View.create((1, 3, 2, 11, 4, 28), (0, 308, 0, 28, 0, 1), 0, None),
View.create((1, 3, 2, 11, 26, 1, 1, 3), (0, 2464, 0, 112, 1, 0, 0, 29), 0, None)])
View.create((1, 3, 2, 11, 26, 1, 1, 3), (0, 2464, 0, 112, 1, 0, 0, 29), 0, None)))
def test_2(self):
self.st = ShapeTracker((8, 1, 6, 10, 28, 3, 2, 1), views=[
self.st = ShapeTracker((
View.create((8, 3, 3, 11, 2, 28), (924, 308, 0, 28, 0, 1), 0, None),
View.create((8, 1, 6, 10, 28, 3, 2, 1), (5544, 0, 0, 56, 1, 1848, 672, 0), 0, None)])
View.create((8, 1, 6, 10, 28, 3, 2, 1), (5544, 0, 0, 56, 1, 1848, 672, 0), 0, None)))
class TestIndexExpressions2d(unittest.TestCase):
def setUp(self):
shapes = [(30, 5), (15, 10), (15, 1), (5, 10), (5, 1)] # Make sure dim0 is a multiple of 5, one of the tests divides this dimension by 5
offsets = [0, 1, 15, 28, 10000]
self.sts = [ShapeTracker(base_shape, [View.create(base_shape, offset=offset)]) for base_shape in shapes for offset in offsets]
self.sts = [ShapeTracker((View.create(base_shape, offset=offset),)) for base_shape in shapes for offset in offsets]
self.offset = [Variable.num(offset) for base_shape in shapes for offset in offsets]
self.shapes = [shape for shape in shapes for offset in offsets]
self.node_exprs = []
@@ -171,36 +180,52 @@ class TestIndexExpressions2d(unittest.TestCase):
self.idxs_exprs.append(lambda idxs, base_shape=base_shape, offset=offset: idxs[0]*base_shape[1] + idxs[1] + offset)
def test_permute(self):
new_st = []
for st, base_shape, offset in zip(self.sts, self.shapes, self.offset):
st.permute((1, 0))
st = st.permute((1, 0))
self.node_exprs.append(lambda idx, base_shape=base_shape, offset=offset: idx%base_shape[0]*base_shape[1] + idx//base_shape[0]%base_shape[1] + offset)
self.idxs_exprs.append(lambda idxs, base_shape=base_shape, offset=offset: idxs[0] + idxs[1]*base_shape[1] + offset)
new_st.append(st)
self.sts = new_st
def test_reshape(self):
new_st = []
for st, base_shape, offset in zip(self.sts, self.shapes, self.offset):
st.reshape((base_shape[0], 1, base_shape[1]))
st = st.reshape((base_shape[0], 1, base_shape[1]))
self.node_exprs.append(lambda idx, base_shape=base_shape, offset=offset: idx%prod(base_shape) + offset)
self.idxs_exprs.append(lambda idxs, base_shape=base_shape, offset=offset: idxs[0]*base_shape[1] + idxs[2] + offset)
new_st.append(st)
self.sts = new_st
def test_reshape_expand(self):
new_st = []
for st, base_shape, offset in zip(self.sts, self.shapes, self.offset):
st.reshape((base_shape[0], 1, base_shape[1]))
st.expand((base_shape[0], base_shape[1], base_shape[1]))
st = st.reshape((base_shape[0], 1, base_shape[1]))
st = st.expand((base_shape[0], base_shape[1], base_shape[1]))
self.node_exprs.append(lambda idx, base_shape=base_shape, offset=offset: idx//(base_shape[1]*base_shape[1])%base_shape[0]*base_shape[1] + idx%base_shape[1] + offset)
self.idxs_exprs.append(lambda idxs, base_shape=base_shape, offset=offset: idxs[0]*base_shape[1] + idxs[2] + offset)
new_st.append(st)
self.sts = new_st
def test_permute_reshape_1(self): # This tests multiple views
new_st = []
for st, base_shape, offset in zip(self.sts, self.shapes, self.offset):
st.permute((1, 0))
st.reshape((base_shape[0]//5, 1, base_shape[1]*5))
st = st.permute((1, 0))
st = st.reshape((base_shape[0]//5, 1, base_shape[1]*5))
self.node_exprs.append(lambda idx, base_shape=base_shape, offset=offset: idx%prod(base_shape)%base_shape[0]*base_shape[1] + idx//base_shape[0]%base_shape[1] + offset)
self.idxs_exprs.append(lambda idxs, base_shape=base_shape, offset=offset: (idxs[0]*(base_shape[1]*5)+idxs[2])%base_shape[0]*base_shape[1] + (idxs[0]*(base_shape[1]*5)+idxs[2])//base_shape[0] + offset)
new_st.append(st)
self.sts = new_st
def test_permute_reshape_2(self):
new_st = []
for st, base_shape, offset in zip(self.sts, self.shapes, self.offset):
st.permute((1, 0))
st.reshape((1, base_shape[0]//5, base_shape[1]*5))
st = st.permute((1, 0))
st = st.reshape((1, base_shape[0]//5, base_shape[1]*5))
self.node_exprs.append(lambda idx, base_shape=base_shape, offset=offset: idx%prod(base_shape)%base_shape[0]*base_shape[1] + idx//base_shape[0]%base_shape[1] + offset)
self.idxs_exprs.append(lambda idxs, base_shape=base_shape, offset=offset: (idxs[1]*(base_shape[1]*5)+idxs[2])%base_shape[0]*base_shape[1] + (idxs[1]*(base_shape[1]*5)+idxs[2])//base_shape[0] + offset)
new_st.append(st)
self.sts = new_st
class TestSimplifyingShapeTracker(unittest.TestCase):
def setUp(self):
@@ -211,14 +236,14 @@ class TestSimplifyingShapeTracker(unittest.TestCase):
# multiview simplify
def test_expand_contract_simple(self):
self.st.expand((10, 10))
self.st.reshape((100,))
self.st = self.st.expand((10, 10))
self.st = self.st.reshape((100,))
print(self.st.views)
assert(len(self.st.views) == 2)
self.st.reshape((10, 10))
self.st = self.st.reshape((10, 10))
print(self.st.views)
self.st.simplify()
self.st = self.st.simplify()
print(self.st.views)
assert(len(self.st.views) == 1)
@@ -231,7 +256,7 @@ class TestSimplifyingShapeTracker(unittest.TestCase):
self.st.reshape((2, 5, 2, 5))
print(self.st.views)
self.st.simplify()
self.st = self.st.simplify()
print(self.st.views)
assert(len(self.st.views) == 1)
@@ -243,7 +268,7 @@ class TestSimplifyingShapeTracker(unittest.TestCase):
assert(len(self.st.views) == 2)
self.st.reshape((5, 20))
self.st.simplify()
self.st = self.st.simplify()
print(self.st.views)
assert(len(self.st.views) == 2)
@@ -387,7 +412,7 @@ class TestShapeTrackerFuzzFailures(unittest.TestCase):
self.st.reshape((1, 4))
self.st.shrink(((0, 1), (1, 3)))
print(self.st.st)
self.st.simplify()
self.st = self.st.simplify()
print(self.st.st)
def test_case_2(self):
self.st.stride( (1, 1, -2) )