temp fix for symbolic shape view add [pr] (#8337)

something is still wrong with symbolic shape shrink, but it should not recurse forever
This commit is contained in:
chenyu
2024-12-19 16:10:42 -05:00
committed by GitHub
parent 791a80a1c7
commit 2bf47b75da
2 changed files with 9 additions and 2 deletions

View File

@@ -31,7 +31,13 @@ class TestSymbolic(unittest.TestCase):
def test_merge_view_recursion_err(self):
vm2 = View(shape=(Variable('j', 1, 10),), strides=(0,), offset=0, mask=None, contiguous=False)
vm1 = View(shape=(1,), strides=(0,), offset=0, mask=None, contiguous=True)
vm2.__add__(vm1)
self.assertEqual(vm2+vm1, vm1)
def test_merge_view_recursion_err2(self):
vm2 = View(shape=(Variable('a', 1, 10).bind(4),), strides=(0,), offset=0, mask=None, contiguous=False)
vm1 = View(shape=(Variable('a', 1, 10).bind(4),), strides=(1,), offset=0, mask=((0, Variable('a', 1, 10).bind(4)),), contiguous=False)
# TODO: this should not be None?
self.assertEqual(vm2+vm1, None)
def test_cat_dim0_strides(self):
i = Variable("i", 1, 5).bind(3)