mirror of https://github.com/tinygrad/tinygrad.git
Shrink instead of reshape to unregister symbolic (#12241)
* Slice to unbind symbolic
* use vmax for now
* assert shape in reshape is valid
* update test_symbolic_ops to use shrink instead of reshape
* remove infer_with_bound_values for now
* symbolic output doesn't have symbolic strides
* symbolic jit tests use shrink to unregister symbolic
* update test
* update more tests
* wrap vmax in int()
* only create a new st if the store is not an assign
* unwrap st
* comments
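In short, the change replaces reshape-to-concrete with a shrink (a slice) up to the bound value when converting a symbolic-shaped result back to a concrete shape. A minimal sketch of the pattern, assuming a working tinygrad install (the values are illustrative, not from the commit):

    # minimal sketch of the pattern this commit adopts (illustrative values)
    import numpy as np
    from tinygrad import Tensor, Variable

    a = Tensor.rand(3, 10)
    i = 4
    vi = Variable("i", 1, 10).bind(i)   # symbolic dim, bound to a concrete value
    sym = a[:, :vi] + 1                 # result has the symbolic shape (3, i)
    # before: out = sym.reshape(3, i).numpy()   # reshape away the symbolic dim
    # after:  slice (a shrink) back to the concrete shape
    out = sym[:3, :i].numpy()
    np.testing.assert_allclose(out, a.numpy()[:, :i] + 1, atol=1e-6, rtol=1e-6)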
@@ -50,7 +50,7 @@ class TestBeamSearch(unittest.TestCase):
   def test_variable_shrink_prime_number(self):
     v = Variable("v", 1, 400).bind(367)
     a = rand(400, 367)
-    b = (a.shrink(((0,v), None))+1).reshape(367,367).realize()
+    b = (a.shrink(((0,v), None))+1)[:367,:367].realize()
     np.testing.assert_allclose(b.numpy(), a.numpy()[:367]+1, atol=1e-4, rtol=1e-4)

   def test_no_mutate_rawbuffers(self):
@@ -2,6 +2,7 @@ import unittest

 from test.helpers import assert_jit_cache_len
 from tinygrad import Variable, Tensor, TinyJit
+from tinygrad.helpers import RANGEIFY
 import numpy as np

 class TestSymbolicJit(unittest.TestCase):
@@ -11,7 +12,7 @@ class TestSymbolicJit(unittest.TestCase):
     a = Tensor.rand(3, 10)
     for i in range(1, 5):
       vi = Variable("i", 1, 10).bind(i)
-      symbolic = jf(a[:, :vi]).reshape(3, i).numpy()
+      symbolic = jf(a[:, :vi])[:3, :i].numpy()
       expected = f(a[:, :i]).numpy()
       np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
     assert_jit_cache_len(jf, 1)
@@ -26,7 +27,7 @@ class TestSymbolicJit(unittest.TestCase):
       symbolic = jf(a[:, :vi]).numpy()
       expected = f(a[:, :i]).numpy()
       np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
-    assert_jit_cache_len(jf, 2) # one add and one pad, can be one kernel?
+    assert_jit_cache_len(jf, 1 if RANGEIFY else 2) # one add and one pad, can be one kernel?

   def test_add(self):
     def f(a, b): return (a+b).realize()
@@ -35,7 +36,8 @@ class TestSymbolicJit(unittest.TestCase):
     b = Tensor.rand(3, 10)
     for i in range(1, 5):
       vi = Variable("i", 1, 10).bind(i)
-      symbolic = jf(a[:, :vi], b[:, :vi]).reshape(3, i).numpy()
+      symbolic = jf(a[:, :vi], b[:, :vi])
+      symbolic = symbolic[:3, :i].numpy()
       expected = f(a[:, :i], b[:, :i]).numpy()
       np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
     assert_jit_cache_len(jf, 1)
@@ -75,10 +77,10 @@ class TestSymbolicJit(unittest.TestCase):
     v = Tensor.rand(2, 10, 4, 8)
     for i in range(1, 5):
       vi = Variable("i", 1, 10).bind(i)
-      symbolic = jf(q, k[:, :vi], v[:, :vi]).reshape(2, 4, 1, 8).numpy()
+      symbolic = jf(q, k[:, :vi], v[:, :vi])[:2, :4, :1, :8].numpy()
       expected = f(q, k[:, :i], v[:, :i]).numpy()
       np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
-    assert_jit_cache_len(jf, 5)
+    assert_jit_cache_len(jf, 4 if RANGEIFY else 5)

   def test_cat_dim0(self):
     def f(a, b): return a.cat(b, dim=0).realize()
@@ -87,7 +89,7 @@ class TestSymbolicJit(unittest.TestCase):
     b = Tensor.rand(2, 3)
     for i in range(1, 5):
       vi = Variable("i", 1, 10).bind(i)
-      symbolic = jf(a[:vi], b).reshape(i+2, 3).numpy()
+      symbolic = jf(a[:vi], b)[:i+2, :3].numpy()
       expected = f(a[:i], b).numpy()
       np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
     assert_jit_cache_len(jf, 1)
@@ -99,7 +101,7 @@ class TestSymbolicJit(unittest.TestCase):
     b = Tensor.rand(3, 2)
     for i in range(1, 5):
       vi = Variable("i", 1, 10).bind(i)
-      symbolic = jf(a[:, :vi], b).reshape(3, i+2).numpy()
+      symbolic = jf(a[:, :vi], b)[:3, :i+2].numpy()
       expected = f(a[:, :i], b).numpy()
       np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
     assert_jit_cache_len(jf, 1)
@@ -113,7 +115,7 @@ class TestSymbolicJit(unittest.TestCase):
       for j in range(2, 5):
         vi = Variable("i", 1, 10).bind(i)
         vj = Variable("j", 1, 10).bind(j)
-        symbolic = jf(a[:vi], b[:vj]).reshape(i+j, 3).numpy()
+        symbolic = jf(a[:vi], b[:vj])[:i+j, :3].numpy()
         expected = f(a[:i], b[:j]).numpy()
         np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
     assert_jit_cache_len(jf, 1)
@@ -127,7 +129,7 @@ class TestSymbolicJit(unittest.TestCase):
       for j in range(2, 5):
         vi = Variable("i", 1, 10).bind(i)
         vj = Variable("j", 1, 10).bind(j)
-        symbolic = jf(a[:, :vi], b[:, :vj]).reshape(3, i+j).numpy()
+        symbolic = jf(a[:, :vi], b[:, :vj])[:3, :i+j].numpy()
         expected = f(a[:, :i], b[:, :j]).numpy()
         np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
     assert_jit_cache_len(jf, 1)
@@ -141,7 +143,7 @@ class TestSymbolicJit(unittest.TestCase):
       for j in range(2, 5):
         vi = Variable("i", 1, 10).bind(i)
         vj = Variable("j", 1, 10).bind(j)
-        symbolic = jf(a[:vi, :], b[:, :vj]).reshape(i, j).numpy()
+        symbolic = jf(a[:vi, :], b[:, :vj])[:i, :j].numpy()
         expected = f(a[:i, :], b[:, :j]).numpy()
         np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
     assert_jit_cache_len(jf, 1)
@@ -155,7 +157,7 @@ class TestSymbolicJit(unittest.TestCase):
       for j in range(2, 5):
         vi = Variable("i", 1, 10).bind(i)
         vj = Variable("j", 1, 10).bind(j)
-        symbolic = jf(a[:vj, :], b[:, :vi]).reshape(j, i).numpy()
+        symbolic = jf(a[:vj, :], b[:, :vi])[:j, :i].numpy()
         expected = f(a[:j, :], b[:, :i]).numpy()
         np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
     assert_jit_cache_len(jf, 1)
@@ -207,8 +209,8 @@ class TestSymbolicJit(unittest.TestCase):
       vi = Variable("i", 1, 10).bind(i)
       a = Tensor.ones(vi, 11).contiguous()
       symbolic = a[:, 1:2]
-      symbolic = jf(symbolic).reshape(i, 1).numpy()
-      expected = f(a.reshape(i, 11)[:, 1:2]).numpy()
+      symbolic = jf(symbolic)[:i, :1].numpy()
+      expected = f(a[:i, :][:, 1:2]).numpy()
       np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
     assert_jit_cache_len(jf, 1)

@@ -243,7 +245,7 @@ class TestSymbolicJit(unittest.TestCase):
       expected = b[:i].mean(0).numpy()
       np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
       # axis = 1
-      symbolic = jf1(c[:vi]).reshape(i).numpy()
+      symbolic = jf1(c[:vi])[:i].numpy()
       expected = c[:i].mean(1).numpy()
       np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)

@@ -266,11 +268,11 @@ class TestSymbolicJit(unittest.TestCase):
       expected = a[:i, :j].mean().numpy()
       np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
       # axis = 0
-      symbolic = jf0(b[:vi, :vj]).reshape(j).numpy()
+      symbolic = jf0(b[:vi, :vj])[:j].numpy()
      expected = b[:i, :j].mean(0).numpy()
       np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
       # axis = 1
-      symbolic = jf1(c[:vi, :vj]).reshape(i).numpy()
+      symbolic = jf1(c[:vi, :vj])[:i].numpy()
       expected = c[:i, :j].mean(1).numpy()
       np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)

@@ -295,7 +297,7 @@ class TestSymbolicJit(unittest.TestCase):
       expected = b[:i].var(0).numpy()
       np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
       # axis = 1
-      symbolic = jf1(c[:vi]).reshape(i).numpy()
+      symbolic = jf1(c[:vi])[:i].numpy()
       expected = c[:i].var(1).numpy()
       np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)

@@ -318,11 +320,11 @@ class TestSymbolicJit(unittest.TestCase):
       expected = a[:i, :j].var().numpy()
       np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
       # axis = 0
-      symbolic = jf0(b[:vi, :vj]).reshape(j).numpy()
+      symbolic = jf0(b[:vi, :vj])[:j].numpy()
       expected = b[:i, :j].var(0).numpy()
       np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
       # axis = 1
-      symbolic = jf1(c[:vi, :vj]).reshape(i).numpy()
+      symbolic = jf1(c[:vi, :vj])[:i].numpy()
       expected = c[:i, :j].var(1).numpy()
       np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)

@@ -13,7 +13,7 @@ class TestSymbolicOps(unittest.TestCase):
     a = Tensor.rand(3, 10)
     for i in range(1, 5):
       vi = Variable("i", 1, 10).bind(i)
-      symbolic = f(a[:, :vi]).reshape(3, i).numpy()
+      symbolic = f(a[:, :vi])[:3, :i].numpy()
       expected = f(a[:, :i]).numpy()
       np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)

@@ -32,7 +32,7 @@ class TestSymbolicOps(unittest.TestCase):
     b = Tensor.rand(3, 10)
     for i in range(1, 5):
       vi = Variable("i", 1, 10).bind(i)
-      symbolic = f(a[:, :vi], b[:, :vi]).reshape(3, i).numpy()
+      symbolic = f(a[:, :vi], b[:, :vi])[:, :i].numpy()
       expected = f(a[:, :i], b[:, :i]).numpy()
       np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)

@@ -55,7 +55,7 @@ class TestSymbolicOps(unittest.TestCase):
       vi = Variable("i", 1, 10).bind(i) if use_symbolic else i
       Tensor.realize(q, k, v)
       GlobalCounters.reset()
-      symbolic = f(q, k[:, :vi, :, :], v[:, :vi, :, :]).reshape(2, 4, 1, 8).numpy()
+      symbolic = f(q, k[:, :vi, :, :], v[:, :vi, :, :])[:2, :4, :1, :8].numpy()
       expected = f(q, k[:, :i, :, :], v[:, :i, :, :]).numpy()
       np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)

@@ -94,7 +94,7 @@ class TestSymbolicOps(unittest.TestCase):
     for i in range(1, 5):
       vi = Variable("i", 1, 10).bind(i)
       b = Tensor.rand(2, 3)
-      symbolic = f(a[:vi, :], b).reshape(i+2, 3).numpy()
+      symbolic = f(a[:vi, :], b)[:i+2, :3].numpy()
       expected = f(a[:i, :], b).numpy()
       np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)

@@ -104,7 +104,7 @@ class TestSymbolicOps(unittest.TestCase):
     for i in range(1, 5):
       vi = Variable("i", 1, 10).bind(i)
       b = Tensor.rand(3, 2)
-      symbolic = f(a[:, :vi], b).reshape(3, i+2).numpy()
+      symbolic = f(a[:, :vi], b)[:3, :i+2].numpy()
       expected = f(a[:, :i], b).numpy()
       np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)

@@ -116,7 +116,7 @@ class TestSymbolicOps(unittest.TestCase):
       for j in range(2, 5):
         vi = Variable("i", 1, 10).bind(i)
         vj = Variable("j", 1, 10).bind(j)
-        symbolic = f(a[:vi, :], b[:vj, :]).reshape(i+j, 3).numpy()
+        symbolic = f(a[:vi, :], b[:vj, :])[:i+j, :3].numpy()
         expected = f(a[:i, :], b[:j, :]).numpy()
         np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)

@@ -128,50 +128,41 @@ class TestSymbolicOps(unittest.TestCase):
       for j in range(2, 5):
         vi = Variable("i", 1, 10).bind(i)
         vj = Variable("j", 1, 10).bind(j)
-        symbolic = f(a[:, :vi], b[:, :vj]).reshape(3, i+j).numpy()
+        symbolic = f(a[:, :vi], b[:, :vj])[:3, :i+j].numpy()
         expected = f(a[:, :i], b[:, :j]).numpy()
         np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)

   def test_two_vars_plus1_ij(self):
     def f(a, b): return (a@b+1).realize()
-    a = Tensor.rand(10, 3)
-    b = Tensor.rand(3, 10)
+    a = Tensor.rand(10, 3).realize()
+    b = Tensor.rand(3, 10).realize()
     for i in range(2, 5):
       for j in range(2, 5):
         vi = Variable("i", 1, 10).bind(i)
         vj = Variable("j", 1, 10).bind(j)
-        symbolic = f(a[:vi, :], b[:, :vj]).reshape(i, j).numpy()
+        symbolic = f(a[:vi, :], b[:, :vj])[:i, :j].numpy()
         expected = f(a[:i, :], b[:, :j]).numpy()
         np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)

   def test_two_vars_plus1_ji(self):
     # reverse the order of variables
     def f(a, b): return (a@b+1).realize()
-    a = Tensor.rand(10, 3)
-    b = Tensor.rand(3, 10)
+    a = Tensor.rand(10, 3).realize()
+    b = Tensor.rand(3, 10).realize()
     for i in range(2, 5):
       for j in range(2, 5):
         vi = Variable("i", 1, 10).bind(i)
         vj = Variable("j", 1, 10).bind(j)
-        symbolic = f(a[:vj, :], b[:, :vi]).reshape(j, i).numpy()
+        symbolic = f(a[:vj, :], b[:, :vi])[:j, :i].numpy()
         expected = f(a[:j, :], b[:, :i]).numpy()
         np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)

   def test_reshape_from_symbolic(self):
     a = Tensor.rand(30)
     for i in range(3, 5):
       vi = Variable("i", 3, 10).bind(i)
       symbolic = a[:vi*3].reshape((3, 3)).numpy()
+      # To match symbolic reshape (potential implicit shrink), we need a shrink
       expected = a[:i*3].shrink(((0, 9),)).reshape((3, 3)).numpy()
       np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)

   def test_invalid_symbolic_reshape(self):
     a = Tensor.rand(30)
     for i in range(1, 5):
       vi = Variable("i", 1, 10).bind(i)
       # Cannot reshape into symbolic from non-symbolic
-      with self.assertRaises(AssertionError): a.reshape((3, vi))
+      with self.assertRaises(ValueError): a.reshape((3, vi))

   def test_shrink(self):
     for i in range(1, 5):
@@ -187,6 +178,7 @@ class TestSymbolicOps(unittest.TestCase):
       vi = Variable("i", 1, 10).bind(i)
       a = Tensor.rand(7, 11)
       symbolic = a[3:5, vi:vi+2]
+      print(symbolic.shape)
       symbolic = symbolic.numpy()
       expected = a[3:5, i:i+2].numpy()
       np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
@@ -195,7 +187,7 @@ class TestSymbolicOps(unittest.TestCase):
     a = Tensor.rand(7, 11)
     for i in range(1, 5):
       vi = Variable("i", 1, 10).bind(i)
-      symbolic = a[3:5, :vi:1].reshape(2, i).numpy()
+      symbolic = a[3:5, :vi:1][:2, :i].numpy()
       expected = a[3:5, :i:1].numpy()
       np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)

@@ -203,7 +195,7 @@ class TestSymbolicOps(unittest.TestCase):
     for i in range(1, 5):
       vi = Variable("i", 1, 10).bind(i)
       a = Tensor(1).unsqueeze(0).pad((0, 1)).unsqueeze(0)
-      symbolic = a.expand(vi, 2).reshape(i, 2).numpy()
+      symbolic = a.expand(vi, 2)[:i, :2].numpy()
       expected = a.expand(i, 2).numpy()
       np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)

@@ -211,8 +203,8 @@ class TestSymbolicOps(unittest.TestCase):
     for i in range(1, 5):
       vi = Variable("i", 1, 10).bind(i)
       a = Tensor.ones(vi, 11).contiguous()
-      symbolic = a[:, 1:2].reshape(i, 1).numpy()
-      expected = a.reshape(i, 11)[:, 1:2].numpy()
+      symbolic = a[:, 1:2][:i, :1].numpy()
+      expected = Tensor.ones(i, 11)[:, 1:2].numpy()
       np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)

   def test_ones_sum(self):
@@ -229,7 +221,11 @@ class TestSymbolicOps(unittest.TestCase):
       vi = Variable("i", 1, 10).bind(i)
       for axis in [None, 0, 1]:
         expected = a[:i].mean(axis).numpy()
-        symbolic = a[:vi].mean(axis).reshape(expected.shape).numpy()
+        symbolic = a[:vi].mean(axis)
+        if axis is None:
+          symbolic = symbolic.numpy()
+        else:
+          symbolic = symbolic[:expected.shape[0]].numpy()
         np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)

   def test_mean_2d(self):
@@ -240,7 +236,11 @@ class TestSymbolicOps(unittest.TestCase):
       vj = Variable("j", 1, 10).bind(j)
       for axis in [None, 0, 1]:
         expected = a[:i, :j].mean(axis).numpy()
-        symbolic = a[:vi, :vj].mean(axis).reshape(expected.shape).numpy()
+        symbolic = a[:vi, :vj].mean(axis)
+        if axis is None:
+          symbolic = symbolic.numpy()
+        else:
+          symbolic = symbolic[:expected.shape[0]].numpy()
         np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)

   def test_var(self):
@@ -249,7 +249,11 @@ class TestSymbolicOps(unittest.TestCase):
       vi = Variable("i", 1, 10).bind(i)
       for axis in [None, 0, 1]:
         expected = a[:i].var(axis).numpy()
-        symbolic = a[:vi].var(axis).reshape(expected.shape).numpy()
+        symbolic = a[:vi].var(axis)
+        if axis is None:
+          symbolic = symbolic.numpy()
+        else:
+          symbolic = symbolic[:expected.shape[0]].numpy()
         np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)

   def test_var_2d(self):
@@ -260,7 +264,11 @@ class TestSymbolicOps(unittest.TestCase):
       vj = Variable("j", 1, 10).bind(j)
       for axis in [None, 0, 1]:
         expected = a[:i, :j].var(axis).numpy()
-        symbolic = a[:vi, :vj].var(axis).reshape(expected.shape).numpy()
+        symbolic_result = a[:vi, :vj].var(axis)
+        if axis is None:
+          symbolic = symbolic_result.numpy()
+        else:
+          symbolic = symbolic_result[:expected.shape[0]].numpy()
         np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)

   def test_bitcast_down(self):
@@ -268,7 +276,11 @@ class TestSymbolicOps(unittest.TestCase):
     for i in range(1, 5):
       vi = Variable("i", 1, 10).bind(i)
       expected = a[:i].bitcast(dtypes.uint8).numpy()
-      symbolic = a[:vi].bitcast(dtypes.uint8).reshape(expected.shape).numpy()
+      symbolic_result = a[:vi].bitcast(dtypes.uint8)
+      if len(expected.shape) == 2:
+        symbolic = symbolic_result[:expected.shape[0], :expected.shape[1]].numpy()
+      else:
+        symbolic = symbolic_result[:].numpy()
       np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=0)

   @unittest.skipUnless(is_dtype_supported(dtypes.uint64), "no uint64")
@@ -277,7 +289,11 @@ class TestSymbolicOps(unittest.TestCase):
     for i in range(1, 5):
       vi = Variable("i", 1, 10).bind(i)
       expected = a[:i].bitcast(dtypes.uint64).numpy()
-      symbolic = a[:vi].bitcast(dtypes.uint64).reshape(expected.shape).numpy()
+      symbolic_result = a[:vi].bitcast(dtypes.uint64)
+      if len(expected.shape) == 2:
+        symbolic = symbolic_result[:expected.shape[0], :expected.shape[1]].numpy()
+      else:
+        symbolic = symbolic_result[:].numpy()
       np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=0)

   @unittest.expectedFailure
@@ -38,7 +38,7 @@ class TestTensorVariable(unittest.TestCase):
     vv = Variable("a", 1, 10).bind(2)
     vv2 = Variable("b", 1, 10).bind(2)
     t = Tensor.ones(10, 10).contiguous()[:vv2, :vv]
-    ret = t.mean(axis=1).reshape(2, 1).numpy()
+    ret = t.mean(axis=1)[:2].reshape(2, 1).numpy()
     assert np.all(ret == 1)

   def test_symbolic_mean_2d_add(self):
@@ -66,25 +66,25 @@ class TestTensorVariable(unittest.TestCase):
   def test_symbolic_arange(self):
     vv = Variable("a", 1, 10)
     ret = Tensor.arange(0, vv.bind(4))
-    self.assertListEqual(ret.reshape(4).tolist(), [0,1,2,3])
+    self.assertListEqual(ret[:4].tolist(), [0,1,2,3])

   def test_symbolic_arange_sym_start(self):
     vv = Variable("a", 1, 6)
     ret = Tensor.arange(vv.bind(4), 7)
-    self.assertListEqual(ret.reshape(3).tolist(), [4,5,6])
+    self.assertListEqual(ret[:3].tolist(), [4,5,6])

   # TODO: add vmin/vmax pattern for symbolic denominator
   @unittest.expectedFailure
   def test_symbolic_arange_sym_step(self):
     vv = Variable("step", 1, 3)
     ret = Tensor.arange(0, 10, vv.bind(2))
-    self.assertListEqual(ret.reshape(5).tolist(), [0,2,4,6,8])
+    self.assertListEqual(ret[:5].tolist(), [0,2,4,6,8])

   def test_symbolic_arange_two_vars(self):
     begin = Variable("b", 1, 5)
     end = Variable("e", 6, 10)
     ret = Tensor.arange(begin.bind(4), end.bind(7))
-    self.assertListEqual(ret.reshape(3).tolist(), [4,5,6])
+    self.assertListEqual(ret[:3].tolist(), [4,5,6])

   def test_variable_empty(self):
     v = Variable("i", 1, 10)
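Why these arange results need a slice at all: with symbolic endpoints the output length is itself a symbolic expression, so the realized buffer is sized for its upper bound and the valid prefix is recovered by slicing. A worked example mirroring the test above (the vmax arithmetic is our reading of it, not something the commit asserts):

    from tinygrad import Tensor, Variable

    # length of arange(b, e) is e - b; with b in [1, 5] and e in [6, 10] the
    # symbolic length has vmax = 10 - 1 = 9, but the bound values b=4, e=7
    # yield 3 valid entries, recovered with ret[:3]
    begin, end = Variable("b", 1, 5), Variable("e", 6, 10)
    ret = Tensor.arange(begin.bind(4), end.bind(7))
    print(ret[:3].tolist())   # [4, 5, 6]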
@@ -95,7 +95,7 @@ class TestTiny(unittest.TestCase):
     ones = Tensor.ones(10).contiguous()
     for s in [2,5]:
       ret = ones[:i.bind(s)] + 1
-      self.assertListEqual(ret.contiguous().reshape(s).tolist(), [2.0]*s)
+      self.assertListEqual(ret.contiguous()[:s].tolist(), [2.0]*s)

   def test_symbolic_reduce(self):
     i = Variable('i', 1, 10)
@@ -197,7 +197,7 @@ class TestSymbolicPad(unittest.TestCase):
   def test_pad(self):
     v = Variable("v", 1, 100).bind(5)
     t = Tensor.ones(100)[:v].pad(((4, 0),))
-    t = t.reshape(9)
+    t = t[:9]
     assert t.tolist() == [0,0,0,0,1,1,1,1,1]

@@ -120,7 +120,8 @@ def create_kernel(x:UOp, b:UOp|None=None):
   if b is None: b = UOp.new_buffer(x.device, x.size, x.dtype)
   kernel = UOp(Ops.KERNEL, src=(b,)+x.src, arg=Kernel(x.sink(), m if (m:=x.metadata) else ()))
   buffer = b.base if b.size == b.base.size else UOp(Ops.BUFFER_VIEW, b.dtype, (b.base,), (b.size, b.arg.views[0].offset))
-  return buffer.assign(kernel).shrink(((0, prod(x.shape)),)).reshape(x.shape)
+  # we have to shrink the buffer back to the symbolic shape
+  return buffer.assign(kernel).reshape(tuple(d.vmax if isinstance(d, UOp) else d for d in x.shape)).shrink(tuple((0, d) for d in x.shape))

 DONT_PLACE_IN_KERNEL = {Ops.KERNEL, Ops.ASSIGN, Ops.BUFFER, Ops.MSELECT, Ops.MSTACK, Ops.MULTI, Ops.BIND}
 def append_to_kernel(x:UOp):
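The new return line is the core mechanism: rather than shrinking to prod(x.shape) and reshaping, the buffer is reshaped to the concrete upper-bound shape (every symbolic dim replaced by its vmax) and then shrunk back to the symbolic shape, so the strides stay concrete. A rough Tensor-level analogue of the idea (a sketch, not the actual UOp code path):

    from tinygrad import Tensor, Variable

    v = Variable("i", 1, 10).bind(4)
    buf = Tensor.ones(10, 11).contiguous()   # allocated at the upper bound: 10 == v.vmax rows
    sym = buf[:v, :]                         # shrink to the symbolic shape (i, 11)
    print(sym.shape)                         # (i, 11); the strides are still those of the (10, 11) buffer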
@@ -148,6 +149,16 @@ create_kernels = PatternMatcher([
     lambda ms: UOp(Ops.MSTACK, ms.dtype, tuple(x.src[0] for x in ms.src)).reshape(ms.src[0].arg)),
 ])

+def add_stores(ctx, sink: UOp):
+  stores = []
+  for i,x in enumerate(sink.src):
+    gbl = UOp(Ops.DEFINE_GLOBAL, (s:=x.base).dtype.ptr(ctx[i].size), (), i)
+    # if this is an assign then we already have a buffer with a view that should be the target of the store
+    if x.op is Ops.ASSIGN: stores.append(UOp.store(gbl.view(unwrap(s.st)), s))
+    # otherwise we have to create the shapetracker and shrink it to the correct symbolic shape
+    else: stores.append(
+      UOp.store(gbl.reshape(tuple(int(d.vmax) if isinstance(d,UOp) else d for d in s.shape)).shrink(tuple((0,d) for d in s.shape)),s))
+  return UOp.sink(*stores, arg=sink.arg)
+
 # **** fix kernel AST

 def unbind_view(x:UOp):
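add_stores applies the same idea on the store side: an ASSIGN target already carries a buffer with the right view, while every other output gets a DEFINE_GLOBAL whose shapetracker is built at the vmax shape and shrunk back to the symbolic one. A small sketch of that shapetracker construction, assuming tinygrad's internal ShapeTracker API and standalone (hypothetical) usage:

    from tinygrad import Variable
    from tinygrad.shape.shapetracker import ShapeTracker

    v = Variable("i", 1, 10)
    st = ShapeTracker.from_shape((10, 11))   # int(v.vmax) = 10 rows allocated up front
    st = st.shrink(((0, v), (0, 11)))        # the view now has the symbolic shape (i, 11)
    print(st.shape)                          # (i, 11)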
@@ -168,9 +179,7 @@ replace_buffers = PatternMatcher([
   # no SINK for meta ops
   (UPat(Ops.SINK, src=(UPat(Ops.CONTIGUOUS, src=(UPat(GroupOp.Meta, name="x"),),))), lambda x:x),
   # STORE (except for meta ops)
-  (UPat(Ops.SINK, src=UPat(GroupOp.All-{Ops.STORE}), name="sink"), lambda ctx,sink:
-   UOp.sink(*[UOp.store(UOp(Ops.DEFINE_GLOBAL, (s:=x.base).dtype.ptr(ctx[i].size), (), i).view(s.st), s) for i,x in enumerate(sink.src)],
-            arg=sink.arg)),
+  (UPat(Ops.SINK, src=UPat(GroupOp.All-{Ops.STORE}), name="sink"), add_stores),
   # remove CONTIGUOUS/DEVICE from kernel AST
   (UPat((Ops.CONTIGUOUS, Ops.MSELECT), src=(UPat.var("x"),)), lambda x: x),
   (UPat(Ops.VIEW, src=(UPat(Ops.DEVICE),), name="view"), lambda view: view.replace(src=())),
@@ -312,10 +312,7 @@ class View:

     if not all(x >= 0 for x in new_shape): raise ValueError(f"shape can't contain negative numbers {new_shape}")
     # check for the same size
-    if all_int(self.shape):
-      # reshapes cannot introduce symbolic shape
-      assert all_int(new_shape), f"{self.shape=} -> {new_shape=} contains non int dims"
-      if prod(self.shape) != prod(new_shape): raise ValueError(f"size mismatched, can't reshape {self.shape=} -> {new_shape=}")
+    if resolve(prod(self.shape) != prod(new_shape), True): raise ValueError(f"size mismatched, can't reshape {self.shape=} -> {new_shape=}")

     if 0 in self.shape: return View.create(new_shape)
     if new_shape == () and self.mask and any(mx==my for (mx,my) in self.mask): return None
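The replacement check leans on resolve, which folds a symbolic comparison to a concrete bool when the operands' bounds decide it and otherwise returns the supplied default (True here, i.e. treat the sizes as mismatched and raise). An illustration of that behavior, assuming the Variable bounds shown:

    from tinygrad import Variable
    from tinygrad.uop.ops import resolve

    v = Variable("i", 1, 10)
    print(resolve(v*3 < 31, False))   # True: v*3 is at most 30, so the bounds decide it
    print(resolve(v*3 != 30, True))   # True via the default: holds unless i == 10, so undecidable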
@@ -8,7 +8,8 @@ from tinygrad.dtype import _from_np_dtype, _to_np_dtype
 from tinygrad.helpers import argfix, make_tuple, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, all_same, fully_flatten, dedup
 from tinygrad.helpers import IMAGE, WINO, Metadata, TRACEMETA, ceildiv, fetch, polyN, unwrap, DEBUG, is_numpy_ndarray, RANGEIFY, FUSE_ATTENTION
 from tinygrad.gradient import compute_gradient
-from tinygrad.uop.ops import smax, smin, resolve, UOp, Ops, sint, MathTrait, identity_element, all_metadata, index_to_concrete_int, sint_to_uop
+from tinygrad.uop.ops import smax, smin, resolve, UOp, Ops, sint, MathTrait, identity_element, all_metadata, index_to_concrete_int, sint_to_uop, \
+  srender
 from tinygrad.uop.spec import tensor_uop_spec, type_verify
 from tinygrad.device import Device, Buffer
 from tinygrad.engine.realize import run_schedule
@@ -994,6 +995,8 @@ class Tensor(MathTrait):
     # resolve -1
     if (c := new_shape.count(-1)) > 1: raise RuntimeError(f"only one dimension can be inferred using -1, getting {new_shape}")
     if c: new_shape = tuple([-prod(self.shape) // prod(new_shape) if s == -1 else s for s in new_shape])
+    if resolve(prod(self.shape) != prod(new_shape), True):
+      raise ValueError(f"size mismatch, can't reshape ({', '.join(srender(d) for d in self.shape)}) -> ({', '.join(srender(d) for d in new_shape)})")
     return self._apply_uop(UOp.reshape, arg=new_shape) if new_shape != self.shape else self

   def expand(self, shape, *args) -> Tensor:
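A quick worked example of the -1 inference kept just above the new check: prod(new_shape) still includes the -1, so negative floor division both infers the dim and rounds toward a larger value, letting the resolve() size check catch non-divisors (numbers illustrative):

    # shape (3, 10) -> reshape(5, -1): prod(self.shape) = 30, prod(new_shape) = -5
    print(-30 // -5)   # 6 -> new_shape becomes (5, 6), sizes match
    # reshape(4, -1): -30 // -4 rounds up to 7 -> (4, 7) has 28 != 30, so reshape raises
    print(-30 // -4)   # 7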
@@ -1174,6 +1177,9 @@ class Tensor(MathTrait):
         boundary, stride = [start, stop], step
         if all(isinstance(s, int) for s in (start,stop,step)):
           # handle int slicing
+          # if we're slicing a symbolic dimension into an int dimension, we can slice until the bind size
+          # TODO: right now this is using vmax instead of the bind size because jit doesn't update the bound value of the returned tensor
+          if isinstance(size, UOp): size = int(size.vmax)
           *boundary, stride = index.indices(cast(SupportsIndex, size))
           if stride * (boundary[1] - boundary[0]) < 0: boundary = [0, 0]
           elif stride < 0: boundary = [boundary[1] + 1, boundary[0] + 1]
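The int-slicing path above needs a concrete dimension size for Python's slice.indices, so a symbolic dim is stood in for by its vmax, per the TODO in the diff. A small sketch of that normalization (assumes a bound Variable still reports the declared upper bound as its vmax, which is what the TODO describes):

    from tinygrad import Variable

    size = Variable("i", 1, 10).bind(4)
    dim = int(size.vmax)                        # 10: the declared max, not the bound value 4
    start, stop, step = slice(None, 3, 1).indices(dim)
    print(start, stop, step)                    # 0 3 1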