support symbolic reshape with non-contiguous (#4844)

* support symbolic reshape with non-contiguous

pre-requisite for symbolic arange (make symbolic ones that can be folded).

* test cases

* typo

* shorter
This commit is contained in:
chenyu
2024-06-05 16:01:19 -04:00
committed by GitHub
parent a352b6d9ce
commit 99e7a1d5e9
3 changed files with 51 additions and 5 deletions

View File

@@ -182,8 +182,7 @@ class TestSymbolicJit(unittest.TestCase):
jf = TinyJit(f)
for i in range(1, 5):
vi = Variable("i", 1, 10).bind(i)
# TODO: without contiguous, the CONST shape are different in jit
t = Tensor.ones(i).contiguous()
t = Tensor.ones(i)
symbolic = jf(t.reshape(vi)).item()
expected = f(t).item()
np.testing.assert_equal(symbolic, expected)

View File

@@ -102,7 +102,7 @@ class TestShapeTrackerUnbind(unittest.TestCase):
assert unbound_st == ShapeTracker((View.create(shape=(1, 4), offset=4*v),))
assert var_val == {v: 2}
class TestSymbolicReshape(unittest.TestCase):
class TestSymbolicReshapeFromContiguous(unittest.TestCase):
def test_reshape_into_symbols_simple(self):
for i in range(1, 6):
vi = Variable("i", 1, 5).bind(i)
@@ -151,6 +151,41 @@ class TestSymbolicReshape(unittest.TestCase):
new_shape = (2, (NumNode(1)+Variable('start_pos', 1, 128)), 16, 64)
assert view.reshape(new_shape) is None
class TestSymbolicReshapeFromNonContiguous(unittest.TestCase):
def test_reshape_from_const(self):
vi = Variable("i", 1, 5).bind(4)
t = Tensor.ones(3, 4).reshape(3, vi)
assert t.shape == (3, vi)
assert not t.lazydata.st.contiguous
assert len(t.lazydata.st.views) == 1
def test_reshape_not_allowed(self):
vi = Variable("i", 1, 5).bind(4)
with self.assertRaises(ValueError):
# different shape length # TODO: cases where contractions matched might be fine
Tensor.ones(3, 4, 1).reshape(3, vi)
with self.assertRaises(ValueError):
# size matched, but dimensions do not match
Tensor.ones(4, 3).reshape(3, vi)
def test_reshape_from_padded(self):
vi = Variable("i", 1, 5).bind(4)
t = Tensor.ones(3, 4).contiguous().expand(2, 3, 4).pad(((1, 1), None, None)).shrink((None, None, (1, 3)))
st = t.lazydata.st
assert len(st.views) == 1
view = st.views[0]
assert view.shape == (4, 3, 2)
t = t.reshape(vi, 3, 2)
st2 = t.lazydata.st
assert len(st2.views) == 1
view2 = st2.views[0]
# check only shape changed. strides, offset, mask, contiguous remained the same
assert view2.shape == (vi, 3, 2)
assert view.strides == view2.strides == (0, 4, 1)
assert view.offset == view2.offset == 1
assert view.mask == view2.mask == ((1, 3), (0, 3), (0, 2))
assert not view.contiguous and not view2.contiguous
class TestSymbolicExpand(unittest.TestCase):
def test_expand_into_symbols(self):
vi = Variable("i", 1, 5).bind(3)