mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 22:48:25 -05:00
Shrink instead of reshape to unregister symbolic (#12241)
* Slice to unbind symbolic * use vmax for now * assert shape in reshape is valid * update test_symbolic_ops to use shrink instead of reshape * remove infer_with_bound_values for npw * symbolic output doesnt have symbolic strides * symbolic jit tests use shrink to unregister symbolic * update test * update more tests * wrap vmax in int() * only create a new st if the store is not an assigne * unwrap st * comments
This commit is contained in:
@@ -2,6 +2,7 @@ import unittest
|
||||
|
||||
from test.helpers import assert_jit_cache_len
|
||||
from tinygrad import Variable, Tensor, TinyJit
|
||||
from tinygrad.helpers import RANGEIFY
|
||||
import numpy as np
|
||||
|
||||
class TestSymbolicJit(unittest.TestCase):
|
||||
@@ -11,7 +12,7 @@ class TestSymbolicJit(unittest.TestCase):
|
||||
a = Tensor.rand(3, 10)
|
||||
for i in range(1, 5):
|
||||
vi = Variable("i", 1, 10).bind(i)
|
||||
symbolic = jf(a[:, :vi]).reshape(3, i).numpy()
|
||||
symbolic = jf(a[:, :vi])[:3, :i].numpy()
|
||||
expected = f(a[:, :i]).numpy()
|
||||
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
|
||||
assert_jit_cache_len(jf, 1)
|
||||
@@ -26,7 +27,7 @@ class TestSymbolicJit(unittest.TestCase):
|
||||
symbolic = jf(a[:, :vi]).numpy()
|
||||
expected = f(a[:, :i]).numpy()
|
||||
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
|
||||
assert_jit_cache_len(jf, 2) # one add and one pad, can be one kernel?
|
||||
assert_jit_cache_len(jf, 1 if RANGEIFY else 2) # one add and one pad, can be one kernel?
|
||||
|
||||
def test_add(self):
|
||||
def f(a, b): return (a+b).realize()
|
||||
@@ -35,7 +36,8 @@ class TestSymbolicJit(unittest.TestCase):
|
||||
b = Tensor.rand(3, 10)
|
||||
for i in range(1, 5):
|
||||
vi = Variable("i", 1, 10).bind(i)
|
||||
symbolic = jf(a[:, :vi], b[:, :vi]).reshape(3, i).numpy()
|
||||
symbolic = jf(a[:, :vi], b[:, :vi])
|
||||
symbolic = symbolic[:3, :i].numpy()
|
||||
expected = f(a[:, :i], b[:, :i]).numpy()
|
||||
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
|
||||
assert_jit_cache_len(jf, 1)
|
||||
@@ -75,10 +77,10 @@ class TestSymbolicJit(unittest.TestCase):
|
||||
v = Tensor.rand(2, 10, 4, 8)
|
||||
for i in range(1, 5):
|
||||
vi = Variable("i", 1, 10).bind(i)
|
||||
symbolic = jf(q, k[:, :vi], v[:, :vi]).reshape(2, 4, 1, 8).numpy()
|
||||
symbolic = jf(q, k[:, :vi], v[:, :vi])[:2, :4, :1, :8].numpy()
|
||||
expected = f(q, k[:, :i], v[:, :i]).numpy()
|
||||
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
|
||||
assert_jit_cache_len(jf, 5)
|
||||
assert_jit_cache_len(jf, 4 if RANGEIFY else 5)
|
||||
|
||||
def test_cat_dim0(self):
|
||||
def f(a, b): return a.cat(b, dim=0).realize()
|
||||
@@ -87,7 +89,7 @@ class TestSymbolicJit(unittest.TestCase):
|
||||
b = Tensor.rand(2, 3)
|
||||
for i in range(1, 5):
|
||||
vi = Variable("i", 1, 10).bind(i)
|
||||
symbolic = jf(a[:vi], b).reshape(i+2, 3).numpy()
|
||||
symbolic = jf(a[:vi], b)[:i+2, :3].numpy()
|
||||
expected = f(a[:i], b).numpy()
|
||||
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
|
||||
assert_jit_cache_len(jf, 1)
|
||||
@@ -99,7 +101,7 @@ class TestSymbolicJit(unittest.TestCase):
|
||||
b = Tensor.rand(3, 2)
|
||||
for i in range(1, 5):
|
||||
vi = Variable("i", 1, 10).bind(i)
|
||||
symbolic = jf(a[:, :vi], b).reshape(3, i+2).numpy()
|
||||
symbolic = jf(a[:, :vi], b)[:3, :i+2].numpy()
|
||||
expected = f(a[:, :i], b).numpy()
|
||||
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
|
||||
assert_jit_cache_len(jf, 1)
|
||||
@@ -113,7 +115,7 @@ class TestSymbolicJit(unittest.TestCase):
|
||||
for j in range(2, 5):
|
||||
vi = Variable("i", 1, 10).bind(i)
|
||||
vj = Variable("j", 1, 10).bind(j)
|
||||
symbolic = jf(a[:vi], b[:vj]).reshape(i+j, 3).numpy()
|
||||
symbolic = jf(a[:vi], b[:vj])[:i+j, :3].numpy()
|
||||
expected = f(a[:i], b[:j]).numpy()
|
||||
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
|
||||
assert_jit_cache_len(jf, 1)
|
||||
@@ -127,7 +129,7 @@ class TestSymbolicJit(unittest.TestCase):
|
||||
for j in range(2, 5):
|
||||
vi = Variable("i", 1, 10).bind(i)
|
||||
vj = Variable("j", 1, 10).bind(j)
|
||||
symbolic = jf(a[:, :vi], b[:, :vj]).reshape(3, i+j).numpy()
|
||||
symbolic = jf(a[:, :vi], b[:, :vj])[:3, :i+j].numpy()
|
||||
expected = f(a[:, :i], b[:, :j]).numpy()
|
||||
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
|
||||
assert_jit_cache_len(jf, 1)
|
||||
@@ -141,7 +143,7 @@ class TestSymbolicJit(unittest.TestCase):
|
||||
for j in range(2, 5):
|
||||
vi = Variable("i", 1, 10).bind(i)
|
||||
vj = Variable("j", 1, 10).bind(j)
|
||||
symbolic = jf(a[:vi, :], b[:, :vj]).reshape(i, j).numpy()
|
||||
symbolic = jf(a[:vi, :], b[:, :vj])[:i, :j].numpy()
|
||||
expected = f(a[:i, :], b[:, :j]).numpy()
|
||||
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
|
||||
assert_jit_cache_len(jf, 1)
|
||||
@@ -155,7 +157,7 @@ class TestSymbolicJit(unittest.TestCase):
|
||||
for j in range(2, 5):
|
||||
vi = Variable("i", 1, 10).bind(i)
|
||||
vj = Variable("j", 1, 10).bind(j)
|
||||
symbolic = jf(a[:vj, :], b[:, :vi]).reshape(j, i).numpy()
|
||||
symbolic = jf(a[:vj, :], b[:, :vi])[:j, :i].numpy()
|
||||
expected = f(a[:j, :], b[:, :i]).numpy()
|
||||
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
|
||||
assert_jit_cache_len(jf, 1)
|
||||
@@ -207,8 +209,8 @@ class TestSymbolicJit(unittest.TestCase):
|
||||
vi = Variable("i", 1, 10).bind(i)
|
||||
a = Tensor.ones(vi, 11).contiguous()
|
||||
symbolic = a[:, 1:2]
|
||||
symbolic = jf(symbolic).reshape(i, 1).numpy()
|
||||
expected = f(a.reshape(i, 11)[:, 1:2]).numpy()
|
||||
symbolic = jf(symbolic)[:i, :1].numpy()
|
||||
expected = f(a[:i, :][:, 1:2]).numpy()
|
||||
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
|
||||
assert_jit_cache_len(jf, 1)
|
||||
|
||||
@@ -243,7 +245,7 @@ class TestSymbolicJit(unittest.TestCase):
|
||||
expected = b[:i].mean(0).numpy()
|
||||
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
|
||||
# axis = 1
|
||||
symbolic = jf1(c[:vi]).reshape(i).numpy()
|
||||
symbolic = jf1(c[:vi])[:i].numpy()
|
||||
expected = c[:i].mean(1).numpy()
|
||||
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
|
||||
|
||||
@@ -266,11 +268,11 @@ class TestSymbolicJit(unittest.TestCase):
|
||||
expected = a[:i, :j].mean().numpy()
|
||||
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
|
||||
# axis = 0
|
||||
symbolic = jf0(b[:vi, :vj]).reshape(j).numpy()
|
||||
symbolic = jf0(b[:vi, :vj])[:j].numpy()
|
||||
expected = b[:i, :j].mean(0).numpy()
|
||||
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
|
||||
# axis = 1
|
||||
symbolic = jf1(c[:vi, :vj]).reshape(i).numpy()
|
||||
symbolic = jf1(c[:vi, :vj])[:i].numpy()
|
||||
expected = c[:i, :j].mean(1).numpy()
|
||||
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
|
||||
|
||||
@@ -295,7 +297,7 @@ class TestSymbolicJit(unittest.TestCase):
|
||||
expected = b[:i].var(0).numpy()
|
||||
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
|
||||
# axis = 1
|
||||
symbolic = jf1(c[:vi]).reshape(i).numpy()
|
||||
symbolic = jf1(c[:vi])[:i].numpy()
|
||||
expected = c[:i].var(1).numpy()
|
||||
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
|
||||
|
||||
@@ -318,11 +320,11 @@ class TestSymbolicJit(unittest.TestCase):
|
||||
expected = a[:i, :j].var().numpy()
|
||||
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
|
||||
# axis = 0
|
||||
symbolic = jf0(b[:vi, :vj]).reshape(j).numpy()
|
||||
symbolic = jf0(b[:vi, :vj])[:j].numpy()
|
||||
expected = b[:i, :j].var(0).numpy()
|
||||
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
|
||||
# axis = 1
|
||||
symbolic = jf1(c[:vi, :vj]).reshape(i).numpy()
|
||||
symbolic = jf1(c[:vi, :vj])[:i].numpy()
|
||||
expected = c[:i, :j].var(1).numpy()
|
||||
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user