mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
[bounty] Remove using reshape to register symbolic shape (#11771)
* Modify tests and start work towards removing symbolic reshape * Refactor symbolic reshape * fix small error * much cleaner + fix more tests * Can remove this now * Update test_symbolic_ops and test_tiny * Couple more tests * Unused import * More tests and add EXPAND to Tensor.empty * Fix test beam search * all int * Fix rangeify by adding shrink * Remove OOB check and so fix test_symbolic_jit * test_symbolic_jit doesn't need OOB Context anymore either * Should remove that test now * Cleanups part 1 * fix linters * Final cleanups * Don't reassign inside for loop --------- Co-authored-by: chenyu <chenyu@fastmail.com>
This commit is contained in:
@@ -16,9 +16,9 @@ class TestBeamSearch(unittest.TestCase):
|
||||
BEAM.value = self.old_beam
|
||||
|
||||
def test_variable_ast_beam(self):
|
||||
with Context(IGNORE_OOB=1):
|
||||
a = rand(3, 3).reshape((Variable("a", 1, 10).bind(3), 3))
|
||||
a = (a+1).realize()
|
||||
vi = Variable("a", 1, 10).bind(3)
|
||||
a = rand(10, 3)[:vi]
|
||||
a = (a+1).realize()
|
||||
|
||||
def test_big_prime_number(self):
|
||||
a = rand(367, 367)
|
||||
@@ -42,18 +42,16 @@ class TestBeamSearch(unittest.TestCase):
|
||||
|
||||
def test_variable_big_prime_number(self):
|
||||
v = Variable("v", 1, 400).bind(367)
|
||||
a = rand(367, 367)
|
||||
b = rand(367, 367)
|
||||
with Context(IGNORE_OOB=1):
|
||||
c = (a.reshape(367, v) @ b.reshape(v, 367)).realize()
|
||||
np.testing.assert_allclose(c.numpy(), a.numpy() @ b.numpy(), atol=1e-4, rtol=1e-4)
|
||||
a = rand(367, 400)
|
||||
b = rand(400, 367)
|
||||
c = (a[:, :v] @ b[:v, :]).realize()
|
||||
np.testing.assert_allclose(c.numpy(), a[:, :367].numpy() @ b[:367, :].numpy(), atol=1e-4, rtol=1e-4)
|
||||
|
||||
def test_variable_shrink_prime_number(self):
|
||||
v = Variable("v", 1, 400).bind(367)
|
||||
a = rand(400, 367)
|
||||
with Context(IGNORE_OOB=1):
|
||||
b = (a.shrink(((0,v), None))+1).reshape(367,367).realize()
|
||||
np.testing.assert_allclose(b.numpy(), a.numpy()[:367]+1, atol=1e-4, rtol=1e-4)
|
||||
b = (a.shrink(((0,v), None))+1).reshape(367,367).realize()
|
||||
np.testing.assert_allclose(b.numpy(), a.numpy()[:367]+1, atol=1e-4, rtol=1e-4)
|
||||
|
||||
def test_no_mutate_rawbuffers(self):
|
||||
a = rand(3, 3).realize()
|
||||
|
||||
Reference in New Issue
Block a user