Shrink instead of reshape to unregister symbolic (#12241)

* Slice to unbind symbolic

* use vmax for now

* assert shape in reshape is valid

* update test_symbolic_ops to use shrink instead of reshape

* remove infer_with_bound_values for now

* symbolic output doesn't have symbolic strides

* symbolic jit tests use shrink to unregister symbolic

* update test

* update more tests

* wrap vmax in int()

* only create a new st if the store is not an assign

* unwrap st

* comments
Sieds Lykles
2025-09-19 06:04:35 +02:00
committed by GitHub
parent a531a649fb
commit cc038b31b6
9 changed files with 99 additions and 69 deletions
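In short: the commit replaces .reshape(concrete_shape) with a slice (which lowers to a shrink) as the way callers turn a symbolic-shaped result back into a concrete one. A minimal sketch of the pattern the tests move to, assuming only the public Tensor/Variable API shown in the diff below:

from tinygrad import Tensor, Variable

a = Tensor.rand(3, 10)
i = 4
vi = Variable("i", 1, 10).bind(i)   # symbolic dim bound to 4
sym = a[:, :vi] + 1                 # result has symbolic shape (3, i)
# before this commit: out = sym.reshape(3, i).numpy()
out = sym[:3, :i].numpy()           # now: slice (a shrink) down to the bound size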

View File

@@ -50,7 +50,7 @@ class TestBeamSearch(unittest.TestCase):
def test_variable_shrink_prime_number(self):
v = Variable("v", 1, 400).bind(367)
a = rand(400, 367)
-b = (a.shrink(((0,v), None))+1).reshape(367,367).realize()
+b = (a.shrink(((0,v), None))+1)[:367,:367].realize()
np.testing.assert_allclose(b.numpy(), a.numpy()[:367]+1, atol=1e-4, rtol=1e-4)
def test_no_mutate_rawbuffers(self):

View File

@@ -2,6 +2,7 @@ import unittest
from test.helpers import assert_jit_cache_len
from tinygrad import Variable, Tensor, TinyJit
+from tinygrad.helpers import RANGEIFY
import numpy as np
class TestSymbolicJit(unittest.TestCase):
@@ -11,7 +12,7 @@ class TestSymbolicJit(unittest.TestCase):
a = Tensor.rand(3, 10)
for i in range(1, 5):
vi = Variable("i", 1, 10).bind(i)
-symbolic = jf(a[:, :vi]).reshape(3, i).numpy()
+symbolic = jf(a[:, :vi])[:3, :i].numpy()
expected = f(a[:, :i]).numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
assert_jit_cache_len(jf, 1)
@@ -26,7 +27,7 @@ class TestSymbolicJit(unittest.TestCase):
symbolic = jf(a[:, :vi]).numpy()
expected = f(a[:, :i]).numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
-assert_jit_cache_len(jf, 2) # one add and one pad, can be one kernel?
+assert_jit_cache_len(jf, 1 if RANGEIFY else 2) # one add and one pad, can be one kernel?
def test_add(self):
def f(a, b): return (a+b).realize()
@@ -35,7 +36,8 @@ class TestSymbolicJit(unittest.TestCase):
b = Tensor.rand(3, 10)
for i in range(1, 5):
vi = Variable("i", 1, 10).bind(i)
-symbolic = jf(a[:, :vi], b[:, :vi]).reshape(3, i).numpy()
+symbolic = jf(a[:, :vi], b[:, :vi])
+symbolic = symbolic[:3, :i].numpy()
expected = f(a[:, :i], b[:, :i]).numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
assert_jit_cache_len(jf, 1)
@@ -75,10 +77,10 @@ class TestSymbolicJit(unittest.TestCase):
v = Tensor.rand(2, 10, 4, 8)
for i in range(1, 5):
vi = Variable("i", 1, 10).bind(i)
-symbolic = jf(q, k[:, :vi], v[:, :vi]).reshape(2, 4, 1, 8).numpy()
+symbolic = jf(q, k[:, :vi], v[:, :vi])[:2, :4, :1, :8].numpy()
expected = f(q, k[:, :i], v[:, :i]).numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
-assert_jit_cache_len(jf, 5)
+assert_jit_cache_len(jf, 4 if RANGEIFY else 5)
def test_cat_dim0(self):
def f(a, b): return a.cat(b, dim=0).realize()
@@ -87,7 +89,7 @@ class TestSymbolicJit(unittest.TestCase):
b = Tensor.rand(2, 3)
for i in range(1, 5):
vi = Variable("i", 1, 10).bind(i)
-symbolic = jf(a[:vi], b).reshape(i+2, 3).numpy()
+symbolic = jf(a[:vi], b)[:i+2, :3].numpy()
expected = f(a[:i], b).numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
assert_jit_cache_len(jf, 1)
@@ -99,7 +101,7 @@ class TestSymbolicJit(unittest.TestCase):
b = Tensor.rand(3, 2)
for i in range(1, 5):
vi = Variable("i", 1, 10).bind(i)
-symbolic = jf(a[:, :vi], b).reshape(3, i+2).numpy()
+symbolic = jf(a[:, :vi], b)[:3, :i+2].numpy()
expected = f(a[:, :i], b).numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
assert_jit_cache_len(jf, 1)
@@ -113,7 +115,7 @@ class TestSymbolicJit(unittest.TestCase):
for j in range(2, 5):
vi = Variable("i", 1, 10).bind(i)
vj = Variable("j", 1, 10).bind(j)
-symbolic = jf(a[:vi], b[:vj]).reshape(i+j, 3).numpy()
+symbolic = jf(a[:vi], b[:vj])[:i+j, :3].numpy()
expected = f(a[:i], b[:j]).numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
assert_jit_cache_len(jf, 1)
@@ -127,7 +129,7 @@ class TestSymbolicJit(unittest.TestCase):
for j in range(2, 5):
vi = Variable("i", 1, 10).bind(i)
vj = Variable("j", 1, 10).bind(j)
-symbolic = jf(a[:, :vi], b[:, :vj]).reshape(3, i+j).numpy()
+symbolic = jf(a[:, :vi], b[:, :vj])[:3, :i+j].numpy()
expected = f(a[:, :i], b[:, :j]).numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
assert_jit_cache_len(jf, 1)
@@ -141,7 +143,7 @@ class TestSymbolicJit(unittest.TestCase):
for j in range(2, 5):
vi = Variable("i", 1, 10).bind(i)
vj = Variable("j", 1, 10).bind(j)
-symbolic = jf(a[:vi, :], b[:, :vj]).reshape(i, j).numpy()
+symbolic = jf(a[:vi, :], b[:, :vj])[:i, :j].numpy()
expected = f(a[:i, :], b[:, :j]).numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
assert_jit_cache_len(jf, 1)
@@ -155,7 +157,7 @@ class TestSymbolicJit(unittest.TestCase):
for j in range(2, 5):
vi = Variable("i", 1, 10).bind(i)
vj = Variable("j", 1, 10).bind(j)
-symbolic = jf(a[:vj, :], b[:, :vi]).reshape(j, i).numpy()
+symbolic = jf(a[:vj, :], b[:, :vi])[:j, :i].numpy()
expected = f(a[:j, :], b[:, :i]).numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
assert_jit_cache_len(jf, 1)
@@ -207,8 +209,8 @@ class TestSymbolicJit(unittest.TestCase):
vi = Variable("i", 1, 10).bind(i)
a = Tensor.ones(vi, 11).contiguous()
symbolic = a[:, 1:2]
-symbolic = jf(symbolic).reshape(i, 1).numpy()
-expected = f(a.reshape(i, 11)[:, 1:2]).numpy()
+symbolic = jf(symbolic)[:i, :1].numpy()
+expected = f(a[:i, :][:, 1:2]).numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
assert_jit_cache_len(jf, 1)
@@ -243,7 +245,7 @@ class TestSymbolicJit(unittest.TestCase):
expected = b[:i].mean(0).numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
# axis = 1
-symbolic = jf1(c[:vi]).reshape(i).numpy()
+symbolic = jf1(c[:vi])[:i].numpy()
expected = c[:i].mean(1).numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
@@ -266,11 +268,11 @@ class TestSymbolicJit(unittest.TestCase):
expected = a[:i, :j].mean().numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
# axis = 0
-symbolic = jf0(b[:vi, :vj]).reshape(j).numpy()
+symbolic = jf0(b[:vi, :vj])[:j].numpy()
expected = b[:i, :j].mean(0).numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
# axis = 1
-symbolic = jf1(c[:vi, :vj]).reshape(i).numpy()
+symbolic = jf1(c[:vi, :vj])[:i].numpy()
expected = c[:i, :j].mean(1).numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
@@ -295,7 +297,7 @@ class TestSymbolicJit(unittest.TestCase):
expected = b[:i].var(0).numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
# axis = 1
-symbolic = jf1(c[:vi]).reshape(i).numpy()
+symbolic = jf1(c[:vi])[:i].numpy()
expected = c[:i].var(1).numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
@@ -318,11 +320,11 @@ class TestSymbolicJit(unittest.TestCase):
expected = a[:i, :j].var().numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
# axis = 0
-symbolic = jf0(b[:vi, :vj]).reshape(j).numpy()
+symbolic = jf0(b[:vi, :vj])[:j].numpy()
expected = b[:i, :j].var(0).numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
# axis = 1
-symbolic = jf1(c[:vi, :vj]).reshape(i).numpy()
+symbolic = jf1(c[:vi, :vj])[:i].numpy()
expected = c[:i, :j].var(1).numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)

View File

@@ -13,7 +13,7 @@ class TestSymbolicOps(unittest.TestCase):
a = Tensor.rand(3, 10)
for i in range(1, 5):
vi = Variable("i", 1, 10).bind(i)
-symbolic = f(a[:, :vi]).reshape(3, i).numpy()
+symbolic = f(a[:, :vi])[:3, :i].numpy()
expected = f(a[:, :i]).numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
@@ -32,7 +32,7 @@ class TestSymbolicOps(unittest.TestCase):
b = Tensor.rand(3, 10)
for i in range(1, 5):
vi = Variable("i", 1, 10).bind(i)
-symbolic = f(a[:, :vi], b[:, :vi]).reshape(3, i).numpy()
+symbolic = f(a[:, :vi], b[:, :vi])[:, :i].numpy()
expected = f(a[:, :i], b[:, :i]).numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
@@ -55,7 +55,7 @@ class TestSymbolicOps(unittest.TestCase):
vi = Variable("i", 1, 10).bind(i) if use_symbolic else i
Tensor.realize(q, k, v)
GlobalCounters.reset()
-symbolic = f(q, k[:, :vi, :, :], v[:, :vi, :, :]).reshape(2, 4, 1, 8).numpy()
+symbolic = f(q, k[:, :vi, :, :], v[:, :vi, :, :])[:2, :4, :1, :8].numpy()
expected = f(q, k[:, :i, :, :], v[:, :i, :, :]).numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
@@ -94,7 +94,7 @@ class TestSymbolicOps(unittest.TestCase):
for i in range(1, 5):
vi = Variable("i", 1, 10).bind(i)
b = Tensor.rand(2, 3)
-symbolic = f(a[:vi, :], b).reshape(i+2, 3).numpy()
+symbolic = f(a[:vi, :], b)[:i+2, :3].numpy()
expected = f(a[:i, :], b).numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
@@ -104,7 +104,7 @@ class TestSymbolicOps(unittest.TestCase):
for i in range(1, 5):
vi = Variable("i", 1, 10).bind(i)
b = Tensor.rand(3, 2)
-symbolic = f(a[:, :vi], b).reshape(3, i+2).numpy()
+symbolic = f(a[:, :vi], b)[:3, :i+2].numpy()
expected = f(a[:, :i], b).numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
@@ -116,7 +116,7 @@ class TestSymbolicOps(unittest.TestCase):
for j in range(2, 5):
vi = Variable("i", 1, 10).bind(i)
vj = Variable("j", 1, 10).bind(j)
-symbolic = f(a[:vi, :], b[:vj, :]).reshape(i+j, 3).numpy()
+symbolic = f(a[:vi, :], b[:vj, :])[:i+j, :3].numpy()
expected = f(a[:i, :], b[:j, :]).numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
@@ -128,50 +128,41 @@ class TestSymbolicOps(unittest.TestCase):
for j in range(2, 5):
vi = Variable("i", 1, 10).bind(i)
vj = Variable("j", 1, 10).bind(j)
-symbolic = f(a[:, :vi], b[:, :vj]).reshape(3, i+j).numpy()
+symbolic = f(a[:, :vi], b[:, :vj])[:3, :i+j].numpy()
expected = f(a[:, :i], b[:, :j]).numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
def test_two_vars_plus1_ij(self):
def f(a, b): return (a@b+1).realize()
-a = Tensor.rand(10, 3)
-b = Tensor.rand(3, 10)
+a = Tensor.rand(10, 3).realize()
+b = Tensor.rand(3, 10).realize()
for i in range(2, 5):
for j in range(2, 5):
vi = Variable("i", 1, 10).bind(i)
vj = Variable("j", 1, 10).bind(j)
-symbolic = f(a[:vi, :], b[:, :vj]).reshape(i, j).numpy()
+symbolic = f(a[:vi, :], b[:, :vj])[:i, :j].numpy()
expected = f(a[:i, :], b[:, :j]).numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
def test_two_vars_plus1_ji(self):
# reverse the order of variables
def f(a, b): return (a@b+1).realize()
-a = Tensor.rand(10, 3)
-b = Tensor.rand(3, 10)
+a = Tensor.rand(10, 3).realize()
+b = Tensor.rand(3, 10).realize()
for i in range(2, 5):
for j in range(2, 5):
vi = Variable("i", 1, 10).bind(i)
vj = Variable("j", 1, 10).bind(j)
-symbolic = f(a[:vj, :], b[:, :vi]).reshape(j, i).numpy()
+symbolic = f(a[:vj, :], b[:, :vi])[:j, :i].numpy()
expected = f(a[:j, :], b[:, :i]).numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
-def test_reshape_from_symbolic(self):
-  a = Tensor.rand(30)
-  for i in range(3, 5):
-    vi = Variable("i", 3, 10).bind(i)
-    symbolic = a[:vi*3].reshape((3, 3)).numpy()
-    # To match symbolic reshape (potential implicit shrink), we need a shrink
-    expected = a[:i*3].shrink(((0, 9),)).reshape((3, 3)).numpy()
-    np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
def test_invalid_symbolic_reshape(self):
a = Tensor.rand(30)
for i in range(1, 5):
vi = Variable("i", 1, 10).bind(i)
# Cannot reshape into symbolic from non-symbolic
-with self.assertRaises(AssertionError): a.reshape((3, vi))
+with self.assertRaises(ValueError): a.reshape((3, vi))
def test_shrink(self):
for i in range(1, 5):
@@ -187,6 +178,7 @@ class TestSymbolicOps(unittest.TestCase):
vi = Variable("i", 1, 10).bind(i)
a = Tensor.rand(7, 11)
symbolic = a[3:5, vi:vi+2]
+print(symbolic.shape)
symbolic = symbolic.numpy()
expected = a[3:5, i:i+2].numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
@@ -195,7 +187,7 @@ class TestSymbolicOps(unittest.TestCase):
a = Tensor.rand(7, 11)
for i in range(1, 5):
vi = Variable("i", 1, 10).bind(i)
-symbolic = a[3:5, :vi:1].reshape(2, i).numpy()
+symbolic = a[3:5, :vi:1][:2, :i].numpy()
expected = a[3:5, :i:1].numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
@@ -203,7 +195,7 @@ class TestSymbolicOps(unittest.TestCase):
for i in range(1, 5):
vi = Variable("i", 1, 10).bind(i)
a = Tensor(1).unsqueeze(0).pad((0, 1)).unsqueeze(0)
-symbolic = a.expand(vi, 2).reshape(i, 2).numpy()
+symbolic = a.expand(vi, 2)[:i, :2].numpy()
expected = a.expand(i, 2).numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
@@ -211,8 +203,8 @@ class TestSymbolicOps(unittest.TestCase):
for i in range(1, 5):
vi = Variable("i", 1, 10).bind(i)
a = Tensor.ones(vi, 11).contiguous()
-symbolic = a[:, 1:2].reshape(i, 1).numpy()
-expected = a.reshape(i, 11)[:, 1:2].numpy()
+symbolic = a[:, 1:2][:i, :1].numpy()
+expected = Tensor.ones(i, 11)[:, 1:2].numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
def test_ones_sum(self):
@@ -229,7 +221,11 @@ class TestSymbolicOps(unittest.TestCase):
vi = Variable("i", 1, 10).bind(i)
for axis in [None, 0, 1]:
expected = a[:i].mean(axis).numpy()
-symbolic = a[:vi].mean(axis).reshape(expected.shape).numpy()
+symbolic = a[:vi].mean(axis)
+if axis is None:
+  symbolic = symbolic.numpy()
+else:
+  symbolic = symbolic[:expected.shape[0]].numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
def test_mean_2d(self):
@@ -240,7 +236,11 @@ class TestSymbolicOps(unittest.TestCase):
vj = Variable("j", 1, 10).bind(j)
for axis in [None, 0, 1]:
expected = a[:i, :j].mean(axis).numpy()
-symbolic = a[:vi, :vj].mean(axis).reshape(expected.shape).numpy()
+symbolic = a[:vi, :vj].mean(axis)
+if axis is None:
+  symbolic = symbolic.numpy()
+else:
+  symbolic = symbolic[:expected.shape[0]].numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
def test_var(self):
@@ -249,7 +249,11 @@ class TestSymbolicOps(unittest.TestCase):
vi = Variable("i", 1, 10).bind(i)
for axis in [None, 0, 1]:
expected = a[:i].var(axis).numpy()
-symbolic = a[:vi].var(axis).reshape(expected.shape).numpy()
+symbolic = a[:vi].var(axis)
+if axis is None:
+  symbolic = symbolic.numpy()
+else:
+  symbolic = symbolic[:expected.shape[0]].numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
def test_var_2d(self):
@@ -260,7 +264,11 @@ class TestSymbolicOps(unittest.TestCase):
vj = Variable("j", 1, 10).bind(j)
for axis in [None, 0, 1]:
expected = a[:i, :j].var(axis).numpy()
-symbolic = a[:vi, :vj].var(axis).reshape(expected.shape).numpy()
+symbolic_result = a[:vi, :vj].var(axis)
+if axis is None:
+  symbolic = symbolic_result.numpy()
+else:
+  symbolic = symbolic_result[:expected.shape[0]].numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
def test_bitcast_down(self):
@@ -268,7 +276,11 @@ class TestSymbolicOps(unittest.TestCase):
for i in range(1, 5):
vi = Variable("i", 1, 10).bind(i)
expected = a[:i].bitcast(dtypes.uint8).numpy()
-symbolic = a[:vi].bitcast(dtypes.uint8).reshape(expected.shape).numpy()
+symbolic_result = a[:vi].bitcast(dtypes.uint8)
+if len(expected.shape) == 2:
+  symbolic = symbolic_result[:expected.shape[0], :expected.shape[1]].numpy()
+else:
+  symbolic = symbolic_result[:].numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=0)
@unittest.skipUnless(is_dtype_supported(dtypes.uint64), "no uint64")
@@ -277,7 +289,11 @@ class TestSymbolicOps(unittest.TestCase):
for i in range(1, 5):
vi = Variable("i", 1, 10).bind(i)
expected = a[:i].bitcast(dtypes.uint64).numpy()
-symbolic = a[:vi].bitcast(dtypes.uint64).reshape(expected.shape).numpy()
+symbolic_result = a[:vi].bitcast(dtypes.uint64)
+if len(expected.shape) == 2:
+  symbolic = symbolic_result[:expected.shape[0], :expected.shape[1]].numpy()
+else:
+  symbolic = symbolic_result[:].numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=0)
@unittest.expectedFailure
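Note on test_invalid_symbolic_reshape above: reshaping a non-symbolic tensor into a symbolic shape is now rejected with a ValueError (from the new resolve-based size check) instead of an AssertionError. A hedged sketch of what the test guards against, using only the API shown in this diff:

from tinygrad import Tensor, Variable

a = Tensor.rand(30)
vi = Variable("i", 1, 10).bind(3)
try:
  a.reshape((3, vi))   # non-symbolic -> symbolic reshape
except ValueError:
  pass                 # rejected: 30 == 3*i can't be proven from i's bounds [1, 10]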

View File

@@ -38,7 +38,7 @@ class TestTensorVariable(unittest.TestCase):
vv = Variable("a", 1, 10).bind(2)
vv2 = Variable("b", 1, 10).bind(2)
t = Tensor.ones(10, 10).contiguous()[:vv2, :vv]
-ret = t.mean(axis=1).reshape(2, 1).numpy()
+ret = t.mean(axis=1)[:2].reshape(2, 1).numpy()
assert np.all(ret == 1)
def test_symbolic_mean_2d_add(self):
@@ -66,25 +66,25 @@ class TestTensorVariable(unittest.TestCase):
def test_symbolic_arange(self):
vv = Variable("a", 1, 10)
ret = Tensor.arange(0, vv.bind(4))
-self.assertListEqual(ret.reshape(4).tolist(), [0,1,2,3])
+self.assertListEqual(ret[:4].tolist(), [0,1,2,3])
def test_symbolic_arange_sym_start(self):
vv = Variable("a", 1, 6)
ret = Tensor.arange(vv.bind(4), 7)
-self.assertListEqual(ret.reshape(3).tolist(), [4,5,6])
+self.assertListEqual(ret[:3].tolist(), [4,5,6])
# TODO: add vmin/vmax pattern for symbolic denominator
@unittest.expectedFailure
def test_symbolic_arange_sym_step(self):
vv = Variable("step", 1, 3)
ret = Tensor.arange(0, 10, vv.bind(2))
-self.assertListEqual(ret.reshape(5).tolist(), [0,2,4,6,8])
+self.assertListEqual(ret[:5].tolist(), [0,2,4,6,8])
def test_symbolic_arange_two_vars(self):
begin = Variable("b", 1, 5)
end = Variable("e", 6, 10)
ret = Tensor.arange(begin.bind(4), end.bind(7))
-self.assertListEqual(ret.reshape(3).tolist(), [4,5,6])
+self.assertListEqual(ret[:3].tolist(), [4,5,6])
def test_variable_empty(self):
v = Variable("i", 1, 10)

View File

@@ -95,7 +95,7 @@ class TestTiny(unittest.TestCase):
ones = Tensor.ones(10).contiguous()
for s in [2,5]:
ret = ones[:i.bind(s)] + 1
-self.assertListEqual(ret.contiguous().reshape(s).tolist(), [2.0]*s)
+self.assertListEqual(ret.contiguous()[:s].tolist(), [2.0]*s)
def test_symbolic_reduce(self):
i = Variable('i', 1, 10)

View File

@@ -197,7 +197,7 @@ class TestSymbolicPad(unittest.TestCase):
def test_pad(self):
v = Variable("v", 1, 100).bind(5)
t = Tensor.ones(100)[:v].pad(((4, 0),))
-t = t.reshape(9)
+t = t[:9]
assert t.tolist() == [0,0,0,0,1,1,1,1,1]

View File

@@ -120,7 +120,8 @@ def create_kernel(x:UOp, b:UOp|None=None):
if b is None: b = UOp.new_buffer(x.device, x.size, x.dtype)
kernel = UOp(Ops.KERNEL, src=(b,)+x.src, arg=Kernel(x.sink(), m if (m:=x.metadata) else ()))
buffer = b.base if b.size == b.base.size else UOp(Ops.BUFFER_VIEW, b.dtype, (b.base,), (b.size, b.arg.views[0].offset))
-return buffer.assign(kernel).shrink(((0, prod(x.shape)),)).reshape(x.shape)
+# we have to shrink the buffer back to the symbolic shape
+return buffer.assign(kernel).reshape(tuple(d.vmax if isinstance(d, UOp) else d for d in x.shape)).shrink(tuple((0, d) for d in x.shape))
DONT_PLACE_IN_KERNEL = {Ops.KERNEL, Ops.ASSIGN, Ops.BUFFER, Ops.MSELECT, Ops.MSTACK, Ops.MULTI, Ops.BIND}
def append_to_kernel(x:UOp):
@@ -148,6 +149,16 @@ create_kernels = PatternMatcher([
lambda ms: UOp(Ops.MSTACK, ms.dtype, tuple(x.src[0] for x in ms.src)).reshape(ms.src[0].arg)),
])
+def add_stores(ctx, sink: UOp):
+  stores = []
+  for i,x in enumerate(sink.src):
+    gbl = UOp(Ops.DEFINE_GLOBAL, (s:=x.base).dtype.ptr(ctx[i].size), (), i)
+    # if this is an assign then we already have a buffer with a view that should be the target of the store
+    if x.op is Ops.ASSIGN: stores.append(UOp.store(gbl.view(unwrap(s.st)), s))
+    # otherwise we have to create the shapetracker and shrink it to the correct symbolic shape
+    else: stores.append(
+      UOp.store(gbl.reshape(tuple(int(d.vmax) if isinstance(d,UOp) else d for d in s.shape)).shrink(tuple((0,d) for d in s.shape)),s))
+  return UOp.sink(*stores, arg=sink.arg)
# **** fix kernel AST
def unbind_view(x:UOp):
@@ -168,9 +179,7 @@ replace_buffers = PatternMatcher([
# no SINK for meta ops
(UPat(Ops.SINK, src=(UPat(Ops.CONTIGUOUS, src=(UPat(GroupOp.Meta, name="x"),),))), lambda x:x),
# STORE (except for meta ops)
-(UPat(Ops.SINK, src=UPat(GroupOp.All-{Ops.STORE}), name="sink"), lambda ctx,sink:
-  UOp.sink(*[UOp.store(UOp(Ops.DEFINE_GLOBAL, (s:=x.base).dtype.ptr(ctx[i].size), (), i).view(s.st), s) for i,x in enumerate(sink.src)],
-    arg=sink.arg)),
+(UPat(Ops.SINK, src=UPat(GroupOp.All-{Ops.STORE}), name="sink"), add_stores),
# remove CONTIGUOUS/DEVICE from kernel AST
(UPat((Ops.CONTIGUOUS, Ops.MSELECT), src=(UPat.var("x"),)), lambda x: x),
(UPat(Ops.VIEW, src=(UPat(Ops.DEVICE),), name="view"), lambda view: view.replace(src=())),
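Both create_kernel and add_stores above use the same trick: materialize the buffer at each symbolic dim's upper bound (vmax), then shrink the view back to the symbolic extent. A small sketch of just the shape math, using the Variable/UOp API imported in these hunks (the names full/view are illustrative only):

from tinygrad import Variable
from tinygrad.uop.ops import UOp

v = Variable("i", 1, 10)
shape = (3, v)                                                         # symbolic shape
full = tuple(int(d.vmax) if isinstance(d, UOp) else d for d in shape)  # (3, 10): concrete size at the bound
view = tuple((0, d) for d in shape)                                    # ((0, 3), (0, v)): shrink arg back to symbolic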

View File

@@ -312,10 +312,7 @@ class View:
if not all(x >= 0 for x in new_shape): raise ValueError(f"shape can't contain negative numbers {new_shape}")
# check for the same size
-if all_int(self.shape):
-  # reshapes cannot introduce symbolic shape
-  assert all_int(new_shape), f"{self.shape=} -> {new_shape=} contains non int dims"
-  if prod(self.shape) != prod(new_shape): raise ValueError(f"size mismatched, can't reshape {self.shape=} -> {new_shape=}")
+if resolve(prod(self.shape) != prod(new_shape), True): raise ValueError(f"size mismatched, can't reshape {self.shape=} -> {new_shape=}")
if 0 in self.shape: return View.create(new_shape)
if new_shape == () and self.mask and any(mx==my for (mx,my) in self.mask): return None
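The size check now goes through resolve(), which reduces a possibly-symbolic boolean to a concrete one and falls back to the given default when the variable bounds can't decide it; defaulting to True means an undecidable size mismatch is rejected. A hedged illustration of that behavior:

from tinygrad import Variable
from tinygrad.uop.ops import resolve

v = Variable("i", 1, 10)
resolve(v < 11)        # True: provable from the bounds, since vmax = 10 < 11
resolve(v < 5, False)  # False: undecidable for v in [1, 10], so the default wins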

View File

@@ -8,7 +8,8 @@ from tinygrad.dtype import _from_np_dtype, _to_np_dtype
from tinygrad.helpers import argfix, make_tuple, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, all_same, fully_flatten, dedup
from tinygrad.helpers import IMAGE, WINO, Metadata, TRACEMETA, ceildiv, fetch, polyN, unwrap, DEBUG, is_numpy_ndarray, RANGEIFY, FUSE_ATTENTION
from tinygrad.gradient import compute_gradient
-from tinygrad.uop.ops import smax, smin, resolve, UOp, Ops, sint, MathTrait, identity_element, all_metadata, index_to_concrete_int, sint_to_uop
+from tinygrad.uop.ops import smax, smin, resolve, UOp, Ops, sint, MathTrait, identity_element, all_metadata, index_to_concrete_int, sint_to_uop, \
+  srender
from tinygrad.uop.spec import tensor_uop_spec, type_verify
from tinygrad.device import Device, Buffer
from tinygrad.engine.realize import run_schedule
@@ -994,6 +995,8 @@ class Tensor(MathTrait):
# resolve -1
if (c := new_shape.count(-1)) > 1: raise RuntimeError(f"only one dimension can be inferred using -1, getting {new_shape}")
if c: new_shape = tuple([-prod(self.shape) // prod(new_shape) if s == -1 else s for s in new_shape])
+if resolve(prod(self.shape) != prod(new_shape), True):
+  raise ValueError(f"size mismatch, can't reshape ({', '.join(srender(d) for d in self.shape)}) -> ({', '.join(srender(d) for d in new_shape)})")
return self._apply_uop(UOp.reshape, arg=new_shape) if new_shape != self.shape else self
def expand(self, shape, *args) -> Tensor:
@@ -1174,6 +1177,9 @@ class Tensor(MathTrait):
boundary, stride = [start, stop], step
if all(isinstance(s, int) for s in (start,stop,step)):
# handle int slicing
+# if we're slicing a symbolic dimension into an int dimension, we can slice until the bind size
+# TODO: right now this is using vmax instead of the bind size because the jit doesn't update the bound value of the returned tensor
+if isinstance(size, UOp): size = int(size.vmax)
*boundary, stride = index.indices(cast(SupportsIndex, size))
if stride * (boundary[1] - boundary[0]) < 0: boundary = [0, 0]
elif stride < 0: boundary = [boundary[1] + 1, boundary[0] + 1]
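A sketch of what this new int-slicing path does with a symbolic dimension: clamp the size to the variable's upper bound, then let Python's normal slice.indices machinery run. This is a standalone, hypothetical rendering of the committed logic, not the method itself:

from tinygrad import Variable
from tinygrad.uop.ops import UOp

size = Variable("i", 1, 10)                        # symbolic dim size
if isinstance(size, UOp): size = int(size.vmax)    # 10, per the TODO above
start, stop, step = slice(3, 5).indices(size)      # (3, 5, 1)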