Shrink instead of reshape to unregister symbolic (#12241)

* Slice to unbind symbolic

* use vmax for now

* assert shape in reshape is valid

* update test_symbolic_ops to use shrink instead of reshape

* remove infer_with_bound_values for now

* symbolic output doesn't have symbolic strides

* symbolic jit tests use shrink to unregister symbolic

* update test

* update more tests

* wrap vmax in int()

* only create a new st if the store is not an assign

* unwrap st

* comments
Sieds Lykles authored on 2025-09-19 06:04:35 +02:00, committed by GitHub
parent a531a649fb, commit cc038b31b6
9 changed files with 99 additions and 69 deletions
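
For context, a minimal sketch of the pattern the tests below move to (jf, a, and the shapes are illustrative, mirroring the diff): the jitted output keeps its symbolic shape, and a shrink (slice) with the concrete bound value replaces the old reshape:

from tinygrad import Tensor, Variable, TinyJit

@TinyJit
def jf(x): return (x + 1).realize()

a = Tensor.rand(3, 10)
for i in range(1, 5):
  vi = Variable("i", 1, 10).bind(i)
  # before this commit: jf(a[:, :vi]).reshape(3, i) unbound the symbolic dim
  # now: shrink via slicing with the concrete value to unregister the symbolic variable
  out = jf(a[:, :vi])[:3, :i].numpy()
  assert out.shape == (3, i)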


@@ -2,6 +2,7 @@ import unittest
from test.helpers import assert_jit_cache_len
from tinygrad import Variable, Tensor, TinyJit
+from tinygrad.helpers import RANGEIFY
import numpy as np
class TestSymbolicJit(unittest.TestCase):
@@ -11,7 +12,7 @@ class TestSymbolicJit(unittest.TestCase):
a = Tensor.rand(3, 10)
for i in range(1, 5):
vi = Variable("i", 1, 10).bind(i)
-symbolic = jf(a[:, :vi]).reshape(3, i).numpy()
+symbolic = jf(a[:, :vi])[:3, :i].numpy()
expected = f(a[:, :i]).numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
assert_jit_cache_len(jf, 1)
@@ -26,7 +27,7 @@ class TestSymbolicJit(unittest.TestCase):
symbolic = jf(a[:, :vi]).numpy()
expected = f(a[:, :i]).numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
-assert_jit_cache_len(jf, 2) # one add and one pad, can be one kernel?
+assert_jit_cache_len(jf, 1 if RANGEIFY else 2) # one add and one pad, can be one kernel?
def test_add(self):
def f(a, b): return (a+b).realize()
@@ -35,7 +36,8 @@ class TestSymbolicJit(unittest.TestCase):
b = Tensor.rand(3, 10)
for i in range(1, 5):
vi = Variable("i", 1, 10).bind(i)
-symbolic = jf(a[:, :vi], b[:, :vi]).reshape(3, i).numpy()
+symbolic = jf(a[:, :vi], b[:, :vi])
+symbolic = symbolic[:3, :i].numpy()
expected = f(a[:, :i], b[:, :i]).numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
assert_jit_cache_len(jf, 1)
@@ -75,10 +77,10 @@ class TestSymbolicJit(unittest.TestCase):
v = Tensor.rand(2, 10, 4, 8)
for i in range(1, 5):
vi = Variable("i", 1, 10).bind(i)
-symbolic = jf(q, k[:, :vi], v[:, :vi]).reshape(2, 4, 1, 8).numpy()
+symbolic = jf(q, k[:, :vi], v[:, :vi])[:2, :4, :1, :8].numpy()
expected = f(q, k[:, :i], v[:, :i]).numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
-assert_jit_cache_len(jf, 5)
+assert_jit_cache_len(jf, 4 if RANGEIFY else 5)
def test_cat_dim0(self):
def f(a, b): return a.cat(b, dim=0).realize()
@@ -87,7 +89,7 @@ class TestSymbolicJit(unittest.TestCase):
b = Tensor.rand(2, 3)
for i in range(1, 5):
vi = Variable("i", 1, 10).bind(i)
-symbolic = jf(a[:vi], b).reshape(i+2, 3).numpy()
+symbolic = jf(a[:vi], b)[:i+2, :3].numpy()
expected = f(a[:i], b).numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
assert_jit_cache_len(jf, 1)
@@ -99,7 +101,7 @@ class TestSymbolicJit(unittest.TestCase):
b = Tensor.rand(3, 2)
for i in range(1, 5):
vi = Variable("i", 1, 10).bind(i)
-symbolic = jf(a[:, :vi], b).reshape(3, i+2).numpy()
+symbolic = jf(a[:, :vi], b)[:3, :i+2].numpy()
expected = f(a[:, :i], b).numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
assert_jit_cache_len(jf, 1)
@@ -113,7 +115,7 @@ class TestSymbolicJit(unittest.TestCase):
for j in range(2, 5):
vi = Variable("i", 1, 10).bind(i)
vj = Variable("j", 1, 10).bind(j)
-symbolic = jf(a[:vi], b[:vj]).reshape(i+j, 3).numpy()
+symbolic = jf(a[:vi], b[:vj])[:i+j, :3].numpy()
expected = f(a[:i], b[:j]).numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
assert_jit_cache_len(jf, 1)
@@ -127,7 +129,7 @@ class TestSymbolicJit(unittest.TestCase):
for j in range(2, 5):
vi = Variable("i", 1, 10).bind(i)
vj = Variable("j", 1, 10).bind(j)
-symbolic = jf(a[:, :vi], b[:, :vj]).reshape(3, i+j).numpy()
+symbolic = jf(a[:, :vi], b[:, :vj])[:3, :i+j].numpy()
expected = f(a[:, :i], b[:, :j]).numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
assert_jit_cache_len(jf, 1)
@@ -141,7 +143,7 @@ class TestSymbolicJit(unittest.TestCase):
for j in range(2, 5):
vi = Variable("i", 1, 10).bind(i)
vj = Variable("j", 1, 10).bind(j)
-symbolic = jf(a[:vi, :], b[:, :vj]).reshape(i, j).numpy()
+symbolic = jf(a[:vi, :], b[:, :vj])[:i, :j].numpy()
expected = f(a[:i, :], b[:, :j]).numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
assert_jit_cache_len(jf, 1)
@@ -155,7 +157,7 @@ class TestSymbolicJit(unittest.TestCase):
for j in range(2, 5):
vi = Variable("i", 1, 10).bind(i)
vj = Variable("j", 1, 10).bind(j)
-symbolic = jf(a[:vj, :], b[:, :vi]).reshape(j, i).numpy()
+symbolic = jf(a[:vj, :], b[:, :vi])[:j, :i].numpy()
expected = f(a[:j, :], b[:, :i]).numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
assert_jit_cache_len(jf, 1)
@@ -207,8 +209,8 @@ class TestSymbolicJit(unittest.TestCase):
vi = Variable("i", 1, 10).bind(i)
a = Tensor.ones(vi, 11).contiguous()
symbolic = a[:, 1:2]
-symbolic = jf(symbolic).reshape(i, 1).numpy()
-expected = f(a.reshape(i, 11)[:, 1:2]).numpy()
+symbolic = jf(symbolic)[:i, :1].numpy()
+expected = f(a[:i, :][:, 1:2]).numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
assert_jit_cache_len(jf, 1)
@@ -243,7 +245,7 @@ class TestSymbolicJit(unittest.TestCase):
expected = b[:i].mean(0).numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
# axis = 1
-symbolic = jf1(c[:vi]).reshape(i).numpy()
+symbolic = jf1(c[:vi])[:i].numpy()
expected = c[:i].mean(1).numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
@@ -266,11 +268,11 @@ class TestSymbolicJit(unittest.TestCase):
expected = a[:i, :j].mean().numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
# axis = 0
-symbolic = jf0(b[:vi, :vj]).reshape(j).numpy()
+symbolic = jf0(b[:vi, :vj])[:j].numpy()
expected = b[:i, :j].mean(0).numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
# axis = 1
-symbolic = jf1(c[:vi, :vj]).reshape(i).numpy()
+symbolic = jf1(c[:vi, :vj])[:i].numpy()
expected = c[:i, :j].mean(1).numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
@@ -295,7 +297,7 @@ class TestSymbolicJit(unittest.TestCase):
expected = b[:i].var(0).numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
# axis = 1
-symbolic = jf1(c[:vi]).reshape(i).numpy()
+symbolic = jf1(c[:vi])[:i].numpy()
expected = c[:i].var(1).numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
@@ -318,11 +320,11 @@ class TestSymbolicJit(unittest.TestCase):
expected = a[:i, :j].var().numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
# axis = 0
-symbolic = jf0(b[:vi, :vj]).reshape(j).numpy()
+symbolic = jf0(b[:vi, :vj])[:j].numpy()
expected = b[:i, :j].var(0).numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
# axis = 1
-symbolic = jf1(c[:vi, :vj]).reshape(i).numpy()
+symbolic = jf1(c[:vi, :vj])[:i].numpy()
expected = c[:i, :j].var(1).numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
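
Why the shrink recovers the right data, following the "use vmax" and "wrap vmax in int()" notes above (a sketch under that assumption; vi, out, and concrete are illustrative names): with a bound Variable in the shape, the realized buffer is allocated at the variable's upper bound, so slicing at the concrete value yields the valid region.

from tinygrad import Tensor, Variable

vi = Variable("i", 1, 10).bind(4)
out = (Tensor.rand(3, 10)[:, :vi] + 1).realize()
# the output buffer is sized at the variable's upper bound (vmax, wrapped in int()),
# so only the first 4 columns hold the result; the shrink selects exactly those:
concrete = out[:3, :4].numpy()
assert concrete.shape == (3, 4)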