[bounty] Remove using reshape to register symbolic shape (#11771)

* Modify tests and start work towards removing symbolic reshape

* Refactor symbolic reshape

* fix small error

* much cleaner + fix more tests

* Can remove this now

* Update test_symbolic_ops and test_tiny

* Couple more tests

* Unused import

* More tests and add EXPAND to Tensor.empty

* Fix test beam search

* all int

* Fix rangeify by adding shrink

* Remove OOB check and so fix test_symbolic_jit

* test_symbolic_jit doesn't need OOB Context anymore either

* Should remove that test now

* Cleanups part 1

* fix linters

* Final cleanups

* Don't reassign inside for loop

---------

Co-authored-by: chenyu <chenyu@fastmail.com>
This commit is contained in:
Ben Waldron
2025-08-28 16:30:49 +00:00
committed by GitHub
parent 53853ae49b
commit ea1be2e4cd
13 changed files with 281 additions and 358 deletions

View File

@@ -16,9 +16,9 @@ class TestBeamSearch(unittest.TestCase):
BEAM.value = self.old_beam
def test_variable_ast_beam(self):
with Context(IGNORE_OOB=1):
a = rand(3, 3).reshape((Variable("a", 1, 10).bind(3), 3))
a = (a+1).realize()
vi = Variable("a", 1, 10).bind(3)
a = rand(10, 3)[:vi]
a = (a+1).realize()
def test_big_prime_number(self):
a = rand(367, 367)
@@ -42,18 +42,16 @@ class TestBeamSearch(unittest.TestCase):
def test_variable_big_prime_number(self):
v = Variable("v", 1, 400).bind(367)
a = rand(367, 367)
b = rand(367, 367)
with Context(IGNORE_OOB=1):
c = (a.reshape(367, v) @ b.reshape(v, 367)).realize()
np.testing.assert_allclose(c.numpy(), a.numpy() @ b.numpy(), atol=1e-4, rtol=1e-4)
a = rand(367, 400)
b = rand(400, 367)
c = (a[:, :v] @ b[:v, :]).realize()
np.testing.assert_allclose(c.numpy(), a[:, :367].numpy() @ b[:367, :].numpy(), atol=1e-4, rtol=1e-4)
def test_variable_shrink_prime_number(self):
v = Variable("v", 1, 400).bind(367)
a = rand(400, 367)
with Context(IGNORE_OOB=1):
b = (a.shrink(((0,v), None))+1).reshape(367,367).realize()
np.testing.assert_allclose(b.numpy(), a.numpy()[:367]+1, atol=1e-4, rtol=1e-4)
b = (a.shrink(((0,v), None))+1).reshape(367,367).realize()
np.testing.assert_allclose(b.numpy(), a.numpy()[:367]+1, atol=1e-4, rtol=1e-4)
def test_no_mutate_rawbuffers(self):
a = rand(3, 3).realize()