improved shapetracker

This commit is contained in:
George Hotz
2022-06-21 19:16:55 -07:00
parent c53c91f949
commit c833886bf5
2 changed files with 33 additions and 2 deletions

View File

@@ -36,8 +36,31 @@ class DumbShapeTracker:
# Tensor.zeros(2, 4).permute(1,0).reshape(2, 4)
# (d1*4 + d0%4), d1=x//4, d0=x%4 = ((x//4)*4) + (x%4)%4
@unittest.skip("reshape is more complex")
class TestComplexShapeTracker(unittest.TestCase):
def test_add_1s(self):
self.st = ShapeTracker((4, 4))
self.st.permute(1,0)
self.st.reshape(1,4,1,4,1)
assert not self.st.contiguous
self.st.permute(0,3,2,1,4)
assert self.st.contiguous
def test_remove_1s(self):
self.st = ShapeTracker((1, 4, 1, 4, 1))
self.st.permute(0,3,2,1,4)
self.st.reshape(4,4)
assert not self.st.contiguous
self.st.permute(1,0)
assert self.st.contiguous
@unittest.skip("reshape is even more complex")
def test_super_complex(self):
self.st = ShapeTracker((4, 4))
self.st.permute(1,0)
self.st.reshape(2, 2, 2, 2)
self.st.permute(2,3,0,1)
assert self.st.contiguous
def test_work(self):
self.st = ShapeTracker((64, 1024, 4))
self.st.reshape(1, 64, 128, 32)

View File

@@ -99,8 +99,16 @@ class ShapeTracker:
assert all([isinstance(x, int) for x in new_shape])
assert prod(self.shape) == prod(new_shape)
if self.shape == new_shape: return
# check if this is adding or removing 1s (only)
if tuple([x for x in self.shape if x != 1]) == tuple([x for x in new_shape if x != 1]):
old_strides = [y for x,y in zip(self.shape, self.strides) if x != 1]
new_strides = [0 if x == 1 else old_strides.pop(0) for x in new_shape]
self.views[-1] = View(new_shape, new_strides, self.offset)
return
view = View(new_shape, strides_for_shape(new_shape))
if self.contiguous: self.views[-1] = view
if self.contiguous: self.views[-1] = view # NOTE: if it's contiguous it can't have an offset
else: self.views.append(view)
def permute(self, *axis):