From cc038b31b605ec94d4c8b96b1ca304737b70b70b Mon Sep 17 00:00:00 2001 From: Sieds Lykles <93992551+S-Lykles@users.noreply.github.com> Date: Fri, 19 Sep 2025 06:04:35 +0200 Subject: [PATCH] Shrink instead of reshape to unregister symbolic (#12241) * Slice to unbind symbolic * use vmax for now * assert shape in reshape is valid * update test_symbolic_ops to use shrink instead of reshape * remove infer_with_bound_values for npw * symbolic output doesnt have symbolic strides * symbolic jit tests use shrink to unregister symbolic * update test * update more tests * wrap vmax in int() * only create a new st if the store is not an assigne * unwrap st * comments --- extra/optimization/test_beam_search.py | 2 +- test/test_symbolic_jit.py | 40 ++++++------ test/test_symbolic_ops.py | 82 +++++++++++++++---------- test/test_tensor_variable.py | 10 +-- test/test_tiny.py | 2 +- test/unit/test_symbolic_shapetracker.py | 2 +- tinygrad/schedule/kernelize.py | 17 +++-- tinygrad/shape/view.py | 5 +- tinygrad/tensor.py | 8 ++- 9 files changed, 99 insertions(+), 69 deletions(-) diff --git a/extra/optimization/test_beam_search.py b/extra/optimization/test_beam_search.py index 7042ea914e..f493ec48eb 100644 --- a/extra/optimization/test_beam_search.py +++ b/extra/optimization/test_beam_search.py @@ -50,7 +50,7 @@ class TestBeamSearch(unittest.TestCase): def test_variable_shrink_prime_number(self): v = Variable("v", 1, 400).bind(367) a = rand(400, 367) - b = (a.shrink(((0,v), None))+1).reshape(367,367).realize() + b = (a.shrink(((0,v), None))+1)[:367,:367].realize() np.testing.assert_allclose(b.numpy(), a.numpy()[:367]+1, atol=1e-4, rtol=1e-4) def test_no_mutate_rawbuffers(self): diff --git a/test/test_symbolic_jit.py b/test/test_symbolic_jit.py index f312539e7a..f28d274dcc 100644 --- a/test/test_symbolic_jit.py +++ b/test/test_symbolic_jit.py @@ -2,6 +2,7 @@ import unittest from test.helpers import assert_jit_cache_len from tinygrad import Variable, Tensor, TinyJit +from tinygrad.helpers import RANGEIFY import numpy as np class TestSymbolicJit(unittest.TestCase): @@ -11,7 +12,7 @@ class TestSymbolicJit(unittest.TestCase): a = Tensor.rand(3, 10) for i in range(1, 5): vi = Variable("i", 1, 10).bind(i) - symbolic = jf(a[:, :vi]).reshape(3, i).numpy() + symbolic = jf(a[:, :vi])[:3, :i].numpy() expected = f(a[:, :i]).numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) assert_jit_cache_len(jf, 1) @@ -26,7 +27,7 @@ class TestSymbolicJit(unittest.TestCase): symbolic = jf(a[:, :vi]).numpy() expected = f(a[:, :i]).numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) - assert_jit_cache_len(jf, 2) # one add and one pad, can be one kernel? + assert_jit_cache_len(jf, 1 if RANGEIFY else 2) # one add and one pad, can be one kernel? def test_add(self): def f(a, b): return (a+b).realize() @@ -35,7 +36,8 @@ class TestSymbolicJit(unittest.TestCase): b = Tensor.rand(3, 10) for i in range(1, 5): vi = Variable("i", 1, 10).bind(i) - symbolic = jf(a[:, :vi], b[:, :vi]).reshape(3, i).numpy() + symbolic = jf(a[:, :vi], b[:, :vi]) + symbolic = symbolic[:3, :i].numpy() expected = f(a[:, :i], b[:, :i]).numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) assert_jit_cache_len(jf, 1) @@ -75,10 +77,10 @@ class TestSymbolicJit(unittest.TestCase): v = Tensor.rand(2, 10, 4, 8) for i in range(1, 5): vi = Variable("i", 1, 10).bind(i) - symbolic = jf(q, k[:, :vi], v[:, :vi]).reshape(2, 4, 1, 8).numpy() + symbolic = jf(q, k[:, :vi], v[:, :vi])[:2, :4, :1, :8].numpy() expected = f(q, k[:, :i], v[:, :i]).numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) - assert_jit_cache_len(jf, 5) + assert_jit_cache_len(jf, 4 if RANGEIFY else 5) def test_cat_dim0(self): def f(a, b): return a.cat(b, dim=0).realize() @@ -87,7 +89,7 @@ class TestSymbolicJit(unittest.TestCase): b = Tensor.rand(2, 3) for i in range(1, 5): vi = Variable("i", 1, 10).bind(i) - symbolic = jf(a[:vi], b).reshape(i+2, 3).numpy() + symbolic = jf(a[:vi], b)[:i+2, :3].numpy() expected = f(a[:i], b).numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) assert_jit_cache_len(jf, 1) @@ -99,7 +101,7 @@ class TestSymbolicJit(unittest.TestCase): b = Tensor.rand(3, 2) for i in range(1, 5): vi = Variable("i", 1, 10).bind(i) - symbolic = jf(a[:, :vi], b).reshape(3, i+2).numpy() + symbolic = jf(a[:, :vi], b)[:3, :i+2].numpy() expected = f(a[:, :i], b).numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) assert_jit_cache_len(jf, 1) @@ -113,7 +115,7 @@ class TestSymbolicJit(unittest.TestCase): for j in range(2, 5): vi = Variable("i", 1, 10).bind(i) vj = Variable("j", 1, 10).bind(j) - symbolic = jf(a[:vi], b[:vj]).reshape(i+j, 3).numpy() + symbolic = jf(a[:vi], b[:vj])[:i+j, :3].numpy() expected = f(a[:i], b[:j]).numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) assert_jit_cache_len(jf, 1) @@ -127,7 +129,7 @@ class TestSymbolicJit(unittest.TestCase): for j in range(2, 5): vi = Variable("i", 1, 10).bind(i) vj = Variable("j", 1, 10).bind(j) - symbolic = jf(a[:, :vi], b[:, :vj]).reshape(3, i+j).numpy() + symbolic = jf(a[:, :vi], b[:, :vj])[:3, :i+j].numpy() expected = f(a[:, :i], b[:, :j]).numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) assert_jit_cache_len(jf, 1) @@ -141,7 +143,7 @@ class TestSymbolicJit(unittest.TestCase): for j in range(2, 5): vi = Variable("i", 1, 10).bind(i) vj = Variable("j", 1, 10).bind(j) - symbolic = jf(a[:vi, :], b[:, :vj]).reshape(i, j).numpy() + symbolic = jf(a[:vi, :], b[:, :vj])[:i, :j].numpy() expected = f(a[:i, :], b[:, :j]).numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) assert_jit_cache_len(jf, 1) @@ -155,7 +157,7 @@ class TestSymbolicJit(unittest.TestCase): for j in range(2, 5): vi = Variable("i", 1, 10).bind(i) vj = Variable("j", 1, 10).bind(j) - symbolic = jf(a[:vj, :], b[:, :vi]).reshape(j, i).numpy() + symbolic = jf(a[:vj, :], b[:, :vi])[:j, :i].numpy() expected = f(a[:j, :], b[:, :i]).numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) assert_jit_cache_len(jf, 1) @@ -207,8 +209,8 @@ class TestSymbolicJit(unittest.TestCase): vi = Variable("i", 1, 10).bind(i) a = Tensor.ones(vi, 11).contiguous() symbolic = a[:, 1:2] - symbolic = jf(symbolic).reshape(i, 1).numpy() - expected = f(a.reshape(i, 11)[:, 1:2]).numpy() + symbolic = jf(symbolic)[:i, :1].numpy() + expected = f(a[:i, :][:, 1:2]).numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) assert_jit_cache_len(jf, 1) @@ -243,7 +245,7 @@ class TestSymbolicJit(unittest.TestCase): expected = b[:i].mean(0).numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) # axis = 1 - symbolic = jf1(c[:vi]).reshape(i).numpy() + symbolic = jf1(c[:vi])[:i].numpy() expected = c[:i].mean(1).numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) @@ -266,11 +268,11 @@ class TestSymbolicJit(unittest.TestCase): expected = a[:i, :j].mean().numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) # axis = 0 - symbolic = jf0(b[:vi, :vj]).reshape(j).numpy() + symbolic = jf0(b[:vi, :vj])[:j].numpy() expected = b[:i, :j].mean(0).numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) # axis = 1 - symbolic = jf1(c[:vi, :vj]).reshape(i).numpy() + symbolic = jf1(c[:vi, :vj])[:i].numpy() expected = c[:i, :j].mean(1).numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) @@ -295,7 +297,7 @@ class TestSymbolicJit(unittest.TestCase): expected = b[:i].var(0).numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) # axis = 1 - symbolic = jf1(c[:vi]).reshape(i).numpy() + symbolic = jf1(c[:vi])[:i].numpy() expected = c[:i].var(1).numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) @@ -318,11 +320,11 @@ class TestSymbolicJit(unittest.TestCase): expected = a[:i, :j].var().numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) # axis = 0 - symbolic = jf0(b[:vi, :vj]).reshape(j).numpy() + symbolic = jf0(b[:vi, :vj])[:j].numpy() expected = b[:i, :j].var(0).numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) # axis = 1 - symbolic = jf1(c[:vi, :vj]).reshape(i).numpy() + symbolic = jf1(c[:vi, :vj])[:i].numpy() expected = c[:i, :j].var(1).numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) diff --git a/test/test_symbolic_ops.py b/test/test_symbolic_ops.py index 885953891c..991a9dcc93 100644 --- a/test/test_symbolic_ops.py +++ b/test/test_symbolic_ops.py @@ -13,7 +13,7 @@ class TestSymbolicOps(unittest.TestCase): a = Tensor.rand(3, 10) for i in range(1, 5): vi = Variable("i", 1, 10).bind(i) - symbolic = f(a[:, :vi]).reshape(3, i).numpy() + symbolic = f(a[:, :vi])[:3, :i].numpy() expected = f(a[:, :i]).numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) @@ -32,7 +32,7 @@ class TestSymbolicOps(unittest.TestCase): b = Tensor.rand(3, 10) for i in range(1, 5): vi = Variable("i", 1, 10).bind(i) - symbolic = f(a[:, :vi], b[:, :vi]).reshape(3, i).numpy() + symbolic = f(a[:, :vi], b[:, :vi])[:, :i].numpy() expected = f(a[:, :i], b[:, :i]).numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) @@ -55,7 +55,7 @@ class TestSymbolicOps(unittest.TestCase): vi = Variable("i", 1, 10).bind(i) if use_symbolic else i Tensor.realize(q, k, v) GlobalCounters.reset() - symbolic = f(q, k[:, :vi, :, :], v[:, :vi, :, :]).reshape(2, 4, 1, 8).numpy() + symbolic = f(q, k[:, :vi, :, :], v[:, :vi, :, :])[:2, :4, :1, :8].numpy() expected = f(q, k[:, :i, :, :], v[:, :i, :, :]).numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) @@ -94,7 +94,7 @@ class TestSymbolicOps(unittest.TestCase): for i in range(1, 5): vi = Variable("i", 1, 10).bind(i) b = Tensor.rand(2, 3) - symbolic = f(a[:vi, :], b).reshape(i+2, 3).numpy() + symbolic = f(a[:vi, :], b)[:i+2, :3].numpy() expected = f(a[:i, :], b).numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) @@ -104,7 +104,7 @@ class TestSymbolicOps(unittest.TestCase): for i in range(1, 5): vi = Variable("i", 1, 10).bind(i) b = Tensor.rand(3, 2) - symbolic = f(a[:, :vi], b).reshape(3, i+2).numpy() + symbolic = f(a[:, :vi], b)[:3, :i+2].numpy() expected = f(a[:, :i], b).numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) @@ -116,7 +116,7 @@ class TestSymbolicOps(unittest.TestCase): for j in range(2, 5): vi = Variable("i", 1, 10).bind(i) vj = Variable("j", 1, 10).bind(j) - symbolic = f(a[:vi, :], b[:vj, :]).reshape(i+j, 3).numpy() + symbolic = f(a[:vi, :], b[:vj, :])[:i+j, :3].numpy() expected = f(a[:i, :], b[:j, :]).numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) @@ -128,50 +128,41 @@ class TestSymbolicOps(unittest.TestCase): for j in range(2, 5): vi = Variable("i", 1, 10).bind(i) vj = Variable("j", 1, 10).bind(j) - symbolic = f(a[:, :vi], b[:, :vj]).reshape(3, i+j).numpy() + symbolic = f(a[:, :vi], b[:, :vj])[:3, :i+j].numpy() expected = f(a[:, :i], b[:, :j]).numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) def test_two_vars_plus1_ij(self): def f(a, b): return (a@b+1).realize() - a = Tensor.rand(10, 3) - b = Tensor.rand(3, 10) + a = Tensor.rand(10, 3).realize() + b = Tensor.rand(3, 10).realize() for i in range(2, 5): for j in range(2, 5): vi = Variable("i", 1, 10).bind(i) vj = Variable("j", 1, 10).bind(j) - symbolic = f(a[:vi, :], b[:, :vj]).reshape(i, j).numpy() + symbolic = f(a[:vi, :], b[:, :vj])[:i, :j].numpy() expected = f(a[:i, :], b[:, :j]).numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) def test_two_vars_plus1_ji(self): # reverse the order of variables def f(a, b): return (a@b+1).realize() - a = Tensor.rand(10, 3) - b = Tensor.rand(3, 10) + a = Tensor.rand(10, 3).realize() + b = Tensor.rand(3, 10).realize() for i in range(2, 5): for j in range(2, 5): vi = Variable("i", 1, 10).bind(i) vj = Variable("j", 1, 10).bind(j) - symbolic = f(a[:vj, :], b[:, :vi]).reshape(j, i).numpy() + symbolic = f(a[:vj, :], b[:, :vi])[:j, :i].numpy() expected = f(a[:j, :], b[:, :i]).numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) - def test_reshape_from_symbolic(self): - a = Tensor.rand(30) - for i in range(3, 5): - vi = Variable("i", 3, 10).bind(i) - symbolic = a[:vi*3].reshape((3, 3)).numpy() - # To match symbolic reshape (potential implicit shrink), we need a shrink - expected = a[:i*3].shrink(((0, 9),)).reshape((3, 3)).numpy() - np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) - def test_invalid_symbolic_reshape(self): a = Tensor.rand(30) for i in range(1, 5): vi = Variable("i", 1, 10).bind(i) # Cannot reshape into symbolic from non-symbolic - with self.assertRaises(AssertionError): a.reshape((3, vi)) + with self.assertRaises(ValueError): a.reshape((3, vi)) def test_shrink(self): for i in range(1, 5): @@ -187,6 +178,7 @@ class TestSymbolicOps(unittest.TestCase): vi = Variable("i", 1, 10).bind(i) a = Tensor.rand(7, 11) symbolic = a[3:5, vi:vi+2] + print(symbolic.shape) symbolic = symbolic.numpy() expected = a[3:5, i:i+2].numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) @@ -195,7 +187,7 @@ class TestSymbolicOps(unittest.TestCase): a = Tensor.rand(7, 11) for i in range(1, 5): vi = Variable("i", 1, 10).bind(i) - symbolic = a[3:5, :vi:1].reshape(2, i).numpy() + symbolic = a[3:5, :vi:1][:2, :i].numpy() expected = a[3:5, :i:1].numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) @@ -203,7 +195,7 @@ class TestSymbolicOps(unittest.TestCase): for i in range(1, 5): vi = Variable("i", 1, 10).bind(i) a = Tensor(1).unsqueeze(0).pad((0, 1)).unsqueeze(0) - symbolic = a.expand(vi, 2).reshape(i, 2).numpy() + symbolic = a.expand(vi, 2)[:i, :2].numpy() expected = a.expand(i, 2).numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) @@ -211,8 +203,8 @@ class TestSymbolicOps(unittest.TestCase): for i in range(1, 5): vi = Variable("i", 1, 10).bind(i) a = Tensor.ones(vi, 11).contiguous() - symbolic = a[:, 1:2].reshape(i, 1).numpy() - expected = a.reshape(i, 11)[:, 1:2].numpy() + symbolic = a[:, 1:2][:i, :1].numpy() + expected = Tensor.ones(i, 11)[:, 1:2].numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) def test_ones_sum(self): @@ -229,7 +221,11 @@ class TestSymbolicOps(unittest.TestCase): vi = Variable("i", 1, 10).bind(i) for axis in [None, 0, 1]: expected = a[:i].mean(axis).numpy() - symbolic = a[:vi].mean(axis).reshape(expected.shape).numpy() + symbolic = a[:vi].mean(axis) + if axis is None: + symbolic = symbolic.numpy() + else: + symbolic = symbolic[:expected.shape[0]].numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) def test_mean_2d(self): @@ -240,7 +236,11 @@ class TestSymbolicOps(unittest.TestCase): vj = Variable("j", 1, 10).bind(j) for axis in [None, 0, 1]: expected = a[:i, :j].mean(axis).numpy() - symbolic = a[:vi, :vj].mean(axis).reshape(expected.shape).numpy() + symbolic = a[:vi, :vj].mean(axis) + if axis is None: + symbolic = symbolic.numpy() + else: + symbolic = symbolic[:expected.shape[0]].numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) def test_var(self): @@ -249,7 +249,11 @@ class TestSymbolicOps(unittest.TestCase): vi = Variable("i", 1, 10).bind(i) for axis in [None, 0, 1]: expected = a[:i].var(axis).numpy() - symbolic = a[:vi].var(axis).reshape(expected.shape).numpy() + symbolic = a[:vi].var(axis) + if axis is None: + symbolic = symbolic.numpy() + else: + symbolic = symbolic[:expected.shape[0]].numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) def test_var_2d(self): @@ -260,7 +264,11 @@ class TestSymbolicOps(unittest.TestCase): vj = Variable("j", 1, 10).bind(j) for axis in [None, 0, 1]: expected = a[:i, :j].var(axis).numpy() - symbolic = a[:vi, :vj].var(axis).reshape(expected.shape).numpy() + symbolic_result = a[:vi, :vj].var(axis) + if axis is None: + symbolic = symbolic_result.numpy() + else: + symbolic = symbolic_result[:expected.shape[0]].numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) def test_bitcast_down(self): @@ -268,7 +276,11 @@ class TestSymbolicOps(unittest.TestCase): for i in range(1, 5): vi = Variable("i", 1, 10).bind(i) expected = a[:i].bitcast(dtypes.uint8).numpy() - symbolic = a[:vi].bitcast(dtypes.uint8).reshape(expected.shape).numpy() + symbolic_result = a[:vi].bitcast(dtypes.uint8) + if len(expected.shape) == 2: + symbolic = symbolic_result[:expected.shape[0], :expected.shape[1]].numpy() + else: + symbolic = symbolic_result[:].numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=0) @unittest.skipUnless(is_dtype_supported(dtypes.uint64), "no uint64") @@ -277,7 +289,11 @@ class TestSymbolicOps(unittest.TestCase): for i in range(1, 5): vi = Variable("i", 1, 10).bind(i) expected = a[:i].bitcast(dtypes.uint64).numpy() - symbolic = a[:vi].bitcast(dtypes.uint64).reshape(expected.shape).numpy() + symbolic_result = a[:vi].bitcast(dtypes.uint64) + if len(expected.shape) == 2: + symbolic = symbolic_result[:expected.shape[0], :expected.shape[1]].numpy() + else: + symbolic = symbolic_result[:].numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=0) @unittest.expectedFailure diff --git a/test/test_tensor_variable.py b/test/test_tensor_variable.py index 0fa165e462..a046555d1b 100644 --- a/test/test_tensor_variable.py +++ b/test/test_tensor_variable.py @@ -38,7 +38,7 @@ class TestTensorVariable(unittest.TestCase): vv = Variable("a", 1, 10).bind(2) vv2 = Variable("b", 1, 10).bind(2) t = Tensor.ones(10, 10).contiguous()[:vv2, :vv] - ret = t.mean(axis=1).reshape(2, 1).numpy() + ret = t.mean(axis=1)[:2].reshape(2, 1).numpy() assert np.all(ret == 1) def test_symbolic_mean_2d_add(self): @@ -66,25 +66,25 @@ class TestTensorVariable(unittest.TestCase): def test_symbolic_arange(self): vv = Variable("a", 1, 10) ret = Tensor.arange(0, vv.bind(4)) - self.assertListEqual(ret.reshape(4).tolist(), [0,1,2,3]) + self.assertListEqual(ret[:4].tolist(), [0,1,2,3]) def test_symbolic_arange_sym_start(self): vv = Variable("a", 1, 6) ret = Tensor.arange(vv.bind(4), 7) - self.assertListEqual(ret.reshape(3).tolist(), [4,5,6]) + self.assertListEqual(ret[:3].tolist(), [4,5,6]) # TODO: add vmin/vmax pattern for symbolic denominator @unittest.expectedFailure def test_symbolic_arange_sym_step(self): vv = Variable("step", 1, 3) ret = Tensor.arange(0, 10, vv.bind(2)) - self.assertListEqual(ret.reshape(5).tolist(), [0,2,4,6,8]) + self.assertListEqual(ret[:5].tolist(), [0,2,4,6,8]) def test_symbolic_arange_two_vars(self): begin = Variable("b", 1, 5) end = Variable("e", 6, 10) ret = Tensor.arange(begin.bind(4), end.bind(7)) - self.assertListEqual(ret.reshape(3).tolist(), [4,5,6]) + self.assertListEqual(ret[:3].tolist(), [4,5,6]) def test_variable_empty(self): v = Variable("i", 1, 10) diff --git a/test/test_tiny.py b/test/test_tiny.py index a767749eb3..31bb84f595 100644 --- a/test/test_tiny.py +++ b/test/test_tiny.py @@ -95,7 +95,7 @@ class TestTiny(unittest.TestCase): ones = Tensor.ones(10).contiguous() for s in [2,5]: ret = ones[:i.bind(s)] + 1 - self.assertListEqual(ret.contiguous().reshape(s).tolist(), [2.0]*s) + self.assertListEqual(ret.contiguous()[:s].tolist(), [2.0]*s) def test_symbolic_reduce(self): i = Variable('i', 1, 10) diff --git a/test/unit/test_symbolic_shapetracker.py b/test/unit/test_symbolic_shapetracker.py index 565408cc62..c89419a2f9 100644 --- a/test/unit/test_symbolic_shapetracker.py +++ b/test/unit/test_symbolic_shapetracker.py @@ -197,7 +197,7 @@ class TestSymbolicPad(unittest.TestCase): def test_pad(self): v = Variable("v", 1, 100).bind(5) t = Tensor.ones(100)[:v].pad(((4, 0),)) - t = t.reshape(9) + t = t[:9] assert t.tolist() == [0,0,0,0,1,1,1,1,1] diff --git a/tinygrad/schedule/kernelize.py b/tinygrad/schedule/kernelize.py index cf4db55cb9..ec2f7c046a 100644 --- a/tinygrad/schedule/kernelize.py +++ b/tinygrad/schedule/kernelize.py @@ -120,7 +120,8 @@ def create_kernel(x:UOp, b:UOp|None=None): if b is None: b = UOp.new_buffer(x.device, x.size, x.dtype) kernel = UOp(Ops.KERNEL, src=(b,)+x.src, arg=Kernel(x.sink(), m if (m:=x.metadata) else ())) buffer = b.base if b.size == b.base.size else UOp(Ops.BUFFER_VIEW, b.dtype, (b.base,), (b.size, b.arg.views[0].offset)) - return buffer.assign(kernel).shrink(((0, prod(x.shape)),)).reshape(x.shape) + # we have to shrink the buffer back to the symbolic shape + return buffer.assign(kernel).reshape(tuple(d.vmax if isinstance(d, UOp) else d for d in x.shape)).shrink(tuple((0, d) for d in x.shape)) DONT_PLACE_IN_KERNEL = {Ops.KERNEL, Ops.ASSIGN, Ops.BUFFER, Ops.MSELECT, Ops.MSTACK, Ops.MULTI, Ops.BIND} def append_to_kernel(x:UOp): @@ -148,6 +149,16 @@ create_kernels = PatternMatcher([ lambda ms: UOp(Ops.MSTACK, ms.dtype, tuple(x.src[0] for x in ms.src)).reshape(ms.src[0].arg)), ]) +def add_stores(ctx, sink: UOp): + stores = [] + for i,x in enumerate(sink.src): + gbl = UOp(Ops.DEFINE_GLOBAL, (s:=x.base).dtype.ptr(ctx[i].size), (), i) + # if this is an assign then we already have a buffer with a view that should be the target of the store + if x.op is Ops.ASSIGN: stores.append(UOp.store(gbl.view(unwrap(s.st)), s)) + # otherwise we have to create the shapetracker and shrink it to the correct symbolic shape + else: stores.append( + UOp.store(gbl.reshape(tuple(int(d.vmax) if isinstance(d,UOp) else d for d in s.shape)).shrink(tuple((0,d) for d in s.shape)),s)) + return UOp.sink(*stores, arg=sink.arg) # **** fix kernel AST def unbind_view(x:UOp): @@ -168,9 +179,7 @@ replace_buffers = PatternMatcher([ # no SINK for meta ops (UPat(Ops.SINK, src=(UPat(Ops.CONTIGUOUS, src=(UPat(GroupOp.Meta, name="x"),),))), lambda x:x), # STORE (except for meta ops) - (UPat(Ops.SINK, src=UPat(GroupOp.All-{Ops.STORE}), name="sink"), lambda ctx,sink: - UOp.sink(*[UOp.store(UOp(Ops.DEFINE_GLOBAL, (s:=x.base).dtype.ptr(ctx[i].size), (), i).view(s.st), s) for i,x in enumerate(sink.src)], - arg=sink.arg)), + (UPat(Ops.SINK, src=UPat(GroupOp.All-{Ops.STORE}), name="sink"), add_stores), # remove CONTIGUOUS/DEVICE from kernel AST (UPat((Ops.CONTIGUOUS, Ops.MSELECT), src=(UPat.var("x"),)), lambda x: x), (UPat(Ops.VIEW, src=(UPat(Ops.DEVICE),), name="view"), lambda view: view.replace(src=())), diff --git a/tinygrad/shape/view.py b/tinygrad/shape/view.py index 22f2661585..37da15642c 100644 --- a/tinygrad/shape/view.py +++ b/tinygrad/shape/view.py @@ -312,10 +312,7 @@ class View: if not all(x >= 0 for x in new_shape): raise ValueError(f"shape can't contain negative numbers {new_shape}") # check for the same size - if all_int(self.shape): - # reshapes cannot introduce symbolic shape - assert all_int(new_shape), f"{self.shape=} -> {new_shape=} contains non int dims" - if prod(self.shape) != prod(new_shape): raise ValueError(f"size mismatched, can't reshape {self.shape=} -> {new_shape=}") + if resolve(prod(self.shape) != prod(new_shape), True): raise ValueError(f"size mismatched, can't reshape {self.shape=} -> {new_shape=}") if 0 in self.shape: return View.create(new_shape) if new_shape == () and self.mask and any(mx==my for (mx,my) in self.mask): return None diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 20cb713fdb..a1b0c6ac76 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -8,7 +8,8 @@ from tinygrad.dtype import _from_np_dtype, _to_np_dtype from tinygrad.helpers import argfix, make_tuple, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, all_same, fully_flatten, dedup from tinygrad.helpers import IMAGE, WINO, Metadata, TRACEMETA, ceildiv, fetch, polyN, unwrap, DEBUG, is_numpy_ndarray, RANGEIFY, FUSE_ATTENTION from tinygrad.gradient import compute_gradient -from tinygrad.uop.ops import smax, smin, resolve, UOp, Ops, sint, MathTrait, identity_element, all_metadata, index_to_concrete_int, sint_to_uop +from tinygrad.uop.ops import smax, smin, resolve, UOp, Ops, sint, MathTrait, identity_element, all_metadata, index_to_concrete_int, sint_to_uop, \ + srender from tinygrad.uop.spec import tensor_uop_spec, type_verify from tinygrad.device import Device, Buffer from tinygrad.engine.realize import run_schedule @@ -994,6 +995,8 @@ class Tensor(MathTrait): # resolve -1 if (c := new_shape.count(-1)) > 1: raise RuntimeError(f"only one dimension can be inferred using -1, getting {new_shape}") if c: new_shape = tuple([-prod(self.shape) // prod(new_shape) if s == -1 else s for s in new_shape]) + if resolve(prod(self.shape) != prod(new_shape), True): + raise ValueError(f"size mismatch, can't reshape ({', '.join(srender(d) for d in self.shape)}) -> ({', '.join(srender(d) for d in new_shape)})") return self._apply_uop(UOp.reshape, arg=new_shape) if new_shape != self.shape else self def expand(self, shape, *args) -> Tensor: @@ -1174,6 +1177,9 @@ class Tensor(MathTrait): boundary, stride = [start, stop], step if all(isinstance(s, int) for s in (start,stop,step)): # handle int slicing + # if we're slicing a symbolic dimension into a int dimension, we can slice untill the bind size + # TODO: right now this is using vmax instead of the bind size because jit doesnt update the bound value of the returned tensor + if isinstance(size, UOp): size = int(size.vmax) *boundary, stride = index.indices(cast(SupportsIndex, size)) if stride * (boundary[1] - boundary[0]) < 0: boundary = [0, 0] elif stride < 0: boundary = [boundary[1] + 1, boundary[0] + 1]