[bounty] Remove using reshape to register symbolic shape (#11771)

* Modify tests and start work towards removing symbolic reshape

* Refactor symbolic reshape

* fix small error

* much cleaner + fix more tests

* Can remove this now

* Update test_symbolic_ops and test_tiny

* Couple more tests

* Unused import

* More tests and add EXPAND to Tensor.empty

* Fix test beam search

* all int

* Fix rangeify by adding shrink

* Remove OOB check and so fix test_symbolic_jit

* test_symbolic_jit doesn't need OOB Context anymore either

* Should remove that test now

* Cleanups part 1

* fix linters

* Final cleanups

* Don't reassign inside for loop

---------

Co-authored-by: chenyu <chenyu@fastmail.com>
Author: Ben Waldron
Date: 2025-08-28 16:30:49 +00:00
Committed by: GitHub
Parent: 53853ae49b
Commit: ea1be2e4cd
13 changed files with 281 additions and 358 deletions
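In short: a symbolic dimension is no longer registered by reshaping a concrete tensor into a Variable-shaped one; instead, the tensor is allocated at the Variable's upper bound and shrunk (sliced) down to the bound Variable. A minimal sketch of the two idioms (illustrative names, not from the diff):

from tinygrad import Tensor, Variable

vi = Variable("i", 1, 10).bind(3)      # symbolic dim in [1, 10], currently bound to 3

# old idiom, removed by this commit (required IGNORE_OOB in many tests):
#   a = Tensor.rand(3, 3).reshape(3, vi)

# new idiom: allocate at vmax, then slice to the symbolic size
a = Tensor.rand(3, 10)[:, :vi]         # shape (3, vi); indices stay in bounds
print((a + 1).reshape(3, 3).numpy())   # concretize before reading back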


@@ -16,9 +16,9 @@ class TestBeamSearch(unittest.TestCase):
     BEAM.value = self.old_beam
 
   def test_variable_ast_beam(self):
-    with Context(IGNORE_OOB=1):
-      a = rand(3, 3).reshape((Variable("a", 1, 10).bind(3), 3))
-      a = (a+1).realize()
+    vi = Variable("a", 1, 10).bind(3)
+    a = rand(10, 3)[:vi]
+    a = (a+1).realize()
 
   def test_big_prime_number(self):
     a = rand(367, 367)
@@ -42,18 +42,16 @@ class TestBeamSearch(unittest.TestCase):
   def test_variable_big_prime_number(self):
     v = Variable("v", 1, 400).bind(367)
-    a = rand(367, 367)
-    b = rand(367, 367)
-    with Context(IGNORE_OOB=1):
-      c = (a.reshape(367, v) @ b.reshape(v, 367)).realize()
-      np.testing.assert_allclose(c.numpy(), a.numpy() @ b.numpy(), atol=1e-4, rtol=1e-4)
+    a = rand(367, 400)
+    b = rand(400, 367)
+    c = (a[:, :v] @ b[:v, :]).realize()
+    np.testing.assert_allclose(c.numpy(), a[:, :367].numpy() @ b[:367, :].numpy(), atol=1e-4, rtol=1e-4)
 
   def test_variable_shrink_prime_number(self):
     v = Variable("v", 1, 400).bind(367)
     a = rand(400, 367)
-    with Context(IGNORE_OOB=1):
-      b = (a.shrink(((0,v), None))+1).reshape(367,367).realize()
-      np.testing.assert_allclose(b.numpy(), a.numpy()[:367]+1, atol=1e-4, rtol=1e-4)
+    b = (a.shrink(((0,v), None))+1).reshape(367,367).realize()
+    np.testing.assert_allclose(b.numpy(), a.numpy()[:367]+1, atol=1e-4, rtol=1e-4)
 
   def test_no_mutate_rawbuffers(self):
     a = rand(3, 3).realize()
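The IGNORE_OOB context disappears here because the old reshape idiom could address up to vmax elements in a buffer that only held the bound size, while the new slice never reads past the (vmax-sized) allocation. A minimal standalone version of the matmul test above:

from tinygrad import Tensor, Variable
import numpy as np

v = Variable("v", 1, 400).bind(367)
a, b = Tensor.rand(367, 400), Tensor.rand(400, 367)
c = (a[:, :v] @ b[:v, :]).realize()    # symbolic inner dim, no IGNORE_OOB needed
np.testing.assert_allclose(c.numpy(), a[:, :367].numpy() @ b[:367, :].numpy(), atol=1e-4, rtol=1e-4)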


@@ -2,50 +2,41 @@ import unittest
 from test.helpers import assert_jit_cache_len
 from tinygrad import Variable, Tensor, TinyJit
-from tinygrad.helpers import Context
 import numpy as np
 
 class TestSymbolicJit(unittest.TestCase):
-  def setUp(self):
-    # A lot of these test are out of bounds, so we ignore the bounds check
-    self.context = Context(IGNORE_OOB=1)
-    self.context.__enter__()
-  def tearDown(self):
-    self.context.__exit__(None, None, None)
   def test_plus1(self):
     def f(a): return (a+1).realize()
     jf = TinyJit(f)
+    a = Tensor.rand(3, 10)
     for i in range(1, 5):
       vi = Variable("i", 1, 10).bind(i)
-      a = Tensor.rand(3, i)
-      symbolic = jf(a.reshape(3, vi)).reshape(3, i).numpy()
-      expected = f(a).numpy()
+      symbolic = jf(a[:, :vi]).reshape(3, i).numpy()
+      expected = f(a[:, :i]).numpy()
       np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
     assert_jit_cache_len(jf, 1)
   def test_add(self):
     def f(a, b): return (a+b).realize()
     jf = TinyJit(f)
+    a = Tensor.rand(3, 10)
+    b = Tensor.rand(3, 10)
     for i in range(1, 5):
       vi = Variable("i", 1, 10).bind(i)
-      a = Tensor.rand(3, i)
-      b = Tensor.rand(3, i)
-      symbolic = jf(a.reshape(3, vi), b.reshape(3, vi)).reshape(3, i).numpy()
-      expected = f(a, b).numpy()
+      symbolic = jf(a[:, :vi], b[:, :vi]).reshape(3, i).numpy()
+      expected = f(a[:, :i], b[:, :i]).numpy()
       np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
     assert_jit_cache_len(jf, 1)
   def test_matmul(self):
     def f(a, b): return (a@b).realize()
     jf = TinyJit(f)
+    a = Tensor.rand(3, 10)
+    b = Tensor.rand(10, 5)
     for i in range(1, 5):
       vi = Variable("i", 1, 10).bind(i)
-      a = Tensor.rand(3, i)
-      b = Tensor.rand(i, 5)
-      symbolic = jf(a.reshape(3, vi), b.reshape(vi, 5)).numpy()
-      expected = f(a, b).numpy()
+      symbolic = jf(a[:, :vi], b[:vi, :]).numpy()
+      expected = f(a[:, :i], b[:i, :]).numpy()
       np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
     assert_jit_cache_len(jf, 1)
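A pattern worth noting in all the jit rewrites: the input tensor is now created once at the Variable's upper bound and sliced inside the loop, so every call hands the jit the same symbolic shape and one cached kernel serves all bound values. A minimal sketch:

from tinygrad import Tensor, Variable, TinyJit

@TinyJit
def f(a): return (a + 1).realize()

a = Tensor.rand(3, 10)                      # allocated once at vmax
for i in range(1, 5):
  vi = Variable("i", 1, 10).bind(i)
  out = f(a[:, :vi]).reshape(3, i).numpy()  # one cached kernel for every i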
@@ -55,119 +46,119 @@ class TestSymbolicJit(unittest.TestCase):
       s = (s+s).realize() # this one does not have symbols in input
       return s
     jf = TinyJit(f)
+    a = Tensor.rand(3, 10)
+    b = Tensor.rand(10, 5)
     for i in range(1, 5):
       vi = Variable("i", 1, 10).bind(i)
-      a = Tensor.rand(3, i)
-      b = Tensor.rand(i, 5)
-      symbolic = jf(a.reshape(3, vi), b.reshape(vi, 5)).numpy()
-      expected = f(a, b).numpy()
+      symbolic = jf(a[:, :vi], b[:vi, :]).numpy()
+      expected = f(a[:, :i], b[:i, :]).numpy()
       np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
     assert_jit_cache_len(jf, 2)
   def test_attention(self):
     def f(q, k, v): return Tensor.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)).realize()
     jf = TinyJit(f)
+    q = Tensor.rand(2, 1, 4, 8)
+    k = Tensor.rand(2, 10, 4, 8)
+    v = Tensor.rand(2, 10, 4, 8)
     for i in range(1, 5):
       vi = Variable("i", 1, 10).bind(i)
-      q = Tensor.rand(2, 1, 4, 8)
-      k = Tensor.rand(2, i, 4, 8)
-      v = Tensor.rand(2, i, 4, 8)
-      symbolic = jf(q, k.reshape(2, vi, 4, 8), v.reshape(2, vi, 4, 8)).reshape(2, 4, 1, 8).numpy()
-      expected = f(q, k, v).numpy()
+      symbolic = jf(q, k[:, :vi], v[:, :vi]).reshape(2, 4, 1, 8).numpy()
+      expected = f(q, k[:, :i], v[:, :i]).numpy()
       np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
     assert_jit_cache_len(jf, 5)
   def test_cat_dim0(self):
     def f(a, b): return a.cat(b, dim=0).realize()
     jf = TinyJit(f)
+    a = Tensor.rand(10, 3)
+    b = Tensor.rand(2, 3)
     for i in range(1, 5):
       vi = Variable("i", 1, 10).bind(i)
-      a = Tensor.rand(i, 3)
-      b = Tensor.rand(2, 3)
-      symbolic = jf(a.reshape(vi, 3), b).reshape(i+2, 3).numpy()
-      expected = f(a, b).numpy()
+      symbolic = jf(a[:vi], b).reshape(i+2, 3).numpy()
+      expected = f(a[:i], b).numpy()
       np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
     assert_jit_cache_len(jf, 1)
   def test_cat_dim1(self):
     def f(a, b): return a.cat(b, dim=1).realize()
     jf = TinyJit(f)
+    a = Tensor.rand(3, 10)
+    b = Tensor.rand(3, 2)
     for i in range(1, 5):
       vi = Variable("i", 1, 10).bind(i)
-      a = Tensor.rand(3, i)
-      b = Tensor.rand(3, 2)
-      symbolic = jf(a.reshape(3, vi), b).reshape(3, i+2).numpy()
-      expected = f(a, b).numpy()
+      symbolic = jf(a[:, :vi], b).reshape(3, i+2).numpy()
+      expected = f(a[:, :i], b).numpy()
       np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
     assert_jit_cache_len(jf, 1)
   def test_cat_dim0_two_vars(self):
     def f(a, b): return a.cat(b, dim=0).realize()
     jf = TinyJit(f)
+    a = Tensor.rand(10, 3)
+    b = Tensor.rand(10, 3)
     for i in range(1, 5):
       for j in range(1, 5):
         vi = Variable("i", 1, 10).bind(i)
         vj = Variable("j", 1, 10).bind(j)
-        a = Tensor.rand(i, 3)
-        b = Tensor.rand(j, 3)
-        symbolic = jf(a.reshape(vi, 3), b.reshape(vj, 3)).reshape(i+j, 3).numpy()
-        expected = f(a, b).numpy()
+        symbolic = jf(a[:vi], b[:vj]).reshape(i+j, 3).numpy()
+        expected = f(a[:i], b[:j]).numpy()
         np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
     assert_jit_cache_len(jf, 1)
   def test_cat_dim1_two_vars(self):
     def f(a, b): return a.cat(b, dim=1).realize()
     jf = TinyJit(f)
+    a = Tensor.rand(3, 10)
+    b = Tensor.rand(3, 10)
     for i in range(1, 5):
       for j in range(1, 5):
         vi = Variable("i", 1, 10).bind(i)
         vj = Variable("j", 1, 10).bind(j)
-        a = Tensor.rand(3, i)
-        b = Tensor.rand(3, j)
-        symbolic = jf(a.reshape(3, vi), b.reshape(3, vj)).reshape(3, i+j).numpy()
-        expected = f(a, b).numpy()
+        symbolic = jf(a[:, :vi], b[:, :vj]).reshape(3, i+j).numpy()
+        expected = f(a[:, :i], b[:, :j]).numpy()
         np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
     assert_jit_cache_len(jf, 1)
   def test_two_vars_plus1_ij(self):
     def f(a, b): return (a@b+1).realize()
     jf = TinyJit(f)
+    a = Tensor.rand(10, 3)
+    b = Tensor.rand(3, 10)
     for i in range(1, 5):
       for j in range(1, 5):
         vi = Variable("i", 1, 10).bind(i)
         vj = Variable("j", 1, 10).bind(j)
-        a = Tensor.rand(i, 3)
-        b = Tensor.rand(3, j)
-        symbolic = jf(a.reshape(vi, 3), b.reshape(3, vj)).reshape(i, j).numpy()
-        expected = f(a, b).numpy()
+        symbolic = jf(a[:vi, :], b[:, :vj]).reshape(i, j).numpy()
+        expected = f(a[:i, :], b[:, :j]).numpy()
         np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
     assert_jit_cache_len(jf, 1)
   def test_two_vars_plus1_ji(self):
     def f(a, b): return (a@b+1).realize()
     jf = TinyJit(f)
+    a = Tensor.rand(10, 3)
+    b = Tensor.rand(3, 10)
     for i in range(1, 5):
       for j in range(1, 5):
         vi = Variable("i", 1, 10).bind(i)
         vj = Variable("j", 1, 10).bind(j)
-        a = Tensor.rand(j, 3)
-        b = Tensor.rand(3, i)
-        symbolic = jf(a.reshape(vj, 3), b.reshape(3, vi)).reshape(j, i).numpy()
-        expected = f(a, b).numpy()
+        symbolic = jf(a[:vj, :], b[:, :vi]).reshape(j, i).numpy()
+        expected = f(a[:j, :], b[:, :i]).numpy()
         np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
     assert_jit_cache_len(jf, 1)
   def test_jit_symbolic_shape_mismatch(self):
     @TinyJit
     def add(a, b): return (a+b).realize()
+    a = Tensor.rand(3, 10)
+    b = Tensor.rand(3, 10)
     for i in range(1, 5):
       vi = Variable("i", 1, 10).bind(i)
-      a = Tensor.rand(3, i).reshape(3, vi)
-      b = Tensor.rand(3, i).reshape(3, vi)
-      add(a, b)
+      add(a[:, :vi], b[:, :vi])
     vi2 = Variable("i", 1, 10).bind(7)
-    a = Tensor.rand(3, 7).reshape(3, vi2)
-    bad = Tensor.rand(4, 7).reshape(4, vi2)
+    a = Tensor.rand(3, 7)[:, :vi2]
+    bad = Tensor.rand(4, 7)[:, :vi2]
     with self.assertRaises(AssertionError):
       add(a, bad)
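test_jit_symbolic_shape_mismatch shows the guard this leaves in place: once the jit has captured a (3, vi) signature, an input whose concrete dims disagree is rejected even though it carries the same bound variable. A minimal sketch (the loop is there because TinyJit captures after its warm-up calls):

from tinygrad import Tensor, Variable, TinyJit

@TinyJit
def add(a, b): return (a + b).realize()

a, b = Tensor.rand(3, 10), Tensor.rand(3, 10)
for i in range(1, 5):
  vi = Variable("i", 1, 10).bind(i)
  add(a[:, :vi], b[:, :vi])             # captured as (3, vi) + (3, vi)
vi2 = Variable("i", 1, 10).bind(7)
# add(Tensor.rand(3, 7)[:, :vi2], Tensor.rand(4, 7)[:, :vi2])  # AssertionError: (4, vi2) mismatches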
@@ -175,9 +166,9 @@ class TestSymbolicJit(unittest.TestCase):
     # shrink is a movement, so we pair it with a simple function to test the JIT interaction
     def f(a): return (a+1).realize()
     jf = TinyJit(f)
+    a = Tensor.rand(7, 11)
     for i in range(1, 5):
       vi = Variable("i", 1, 10).bind(i)
-      a = Tensor.rand(7, 11)
       symbolic = a.shrink(((3,5),(vi,vi+2)))
       symbolic = jf(symbolic).numpy()
       expected = f(a.shrink(((3,5),(i,i+2)))).numpy()
@@ -188,9 +179,9 @@ class TestSymbolicJit(unittest.TestCase):
     # slice is a movement, so we pair it with a simple function to test the JIT interaction
     def f(a): return (a+1).realize()
     jf = TinyJit(f)
+    a = Tensor.rand(7, 11)
     for i in range(1, 5):
       vi = Variable("i", 1, 10).bind(i)
-      a = Tensor.rand(7, 11)
       symbolic = a[3:5, vi:vi+2]
       symbolic = jf(symbolic).numpy()
       expected = f(a[3:5, i:i+2]).numpy()
@@ -212,11 +203,11 @@ class TestSymbolicJit(unittest.TestCase):
   def test_ones_sum(self):
     def f(a): return a.sum().realize()
     jf = TinyJit(f)
+    t = Tensor.ones(10)
     for i in range(1, 5):
       vi = Variable("i", 1, 10).bind(i)
-      t = Tensor.ones(i)
-      symbolic = jf(t.reshape(vi)).item()
-      expected = f(t).item()
+      symbolic = jf(t[:vi]).item()
+      expected = f(t[:i]).item()
       np.testing.assert_equal(symbolic, expected)
   def test_mean(self):
@@ -226,22 +217,22 @@ class TestSymbolicJit(unittest.TestCase):
     jf = TinyJit(f)
     jf0 = TinyJit(f0)
     jf1 = TinyJit(f1)
+    a = Tensor.rand(10, 3)
+    b = Tensor.rand(10, 3)
+    c = Tensor.rand(10, 3)
     for i in range(1, 5):
       vi = Variable("i", 1, 10).bind(i)
-      # aixs = None
-      a = Tensor.rand(i, 3)
-      symbolic = jf(a.reshape(vi, 3)).numpy()
-      expected = a.mean().numpy()
+      # axis = None
+      symbolic = jf(a[:vi]).numpy()
+      expected = a[:i].mean().numpy()
       np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
-      # aixs = 0
-      a = Tensor.rand(i, 3)
-      symbolic = jf0(a.reshape(vi, 3)).numpy()
-      expected = a.mean(0).numpy()
+      # axis = 0
+      symbolic = jf0(b[:vi]).numpy()
+      expected = b[:i].mean(0).numpy()
       np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
-      # aixs = 1
-      a = Tensor.rand(i, 3)
-      symbolic = jf1(a.reshape(vi, 3)).reshape(i).numpy()
-      expected = a.mean(1).numpy()
+      # axis = 1
+      symbolic = jf1(c[:vi]).reshape(i).numpy()
+      expected = c[:i].mean(1).numpy()
       np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
   def test_mean_2d(self):
@@ -251,24 +242,24 @@ class TestSymbolicJit(unittest.TestCase):
     jf = TinyJit(f)
     jf0 = TinyJit(f0)
     jf1 = TinyJit(f1)
+    a = Tensor.rand(10, 10)
+    b = Tensor.rand(10, 10)
+    c = Tensor.rand(10, 10)
     for i in range(1, 5):
       for j in range(1, 5):
         vi = Variable("i", 1, 10).bind(i)
         vj = Variable("j", 1, 10).bind(j)
-        # aixs = None
-        a = Tensor.rand(i, j)
-        symbolic = jf(a.reshape(vi, vj)).numpy()
-        expected = a.mean().numpy()
+        # axis = None
+        symbolic = jf(a[:vi, :vj]).numpy()
+        expected = a[:i, :j].mean().numpy()
         np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
-        # aixs = 0
-        a = Tensor.rand(i, j)
-        symbolic = jf0(a.reshape(vi, vj)).reshape(j).numpy()
-        expected = a.mean(0).numpy()
+        # axis = 0
+        symbolic = jf0(b[:vi, :vj]).reshape(j).numpy()
+        expected = b[:i, :j].mean(0).numpy()
         np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
-        # aixs = 1
-        a = Tensor.rand(i, j)
-        symbolic = jf1(a.reshape(vi, vj)).reshape(i).numpy()
-        expected = a.mean(1).numpy()
+        # axis = 1
+        symbolic = jf1(c[:vi, :vj]).reshape(i).numpy()
+        expected = c[:i, :j].mean(1).numpy()
         np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
   def test_var(self):
@@ -278,22 +269,22 @@ class TestSymbolicJit(unittest.TestCase):
     jf = TinyJit(f)
     jf0 = TinyJit(f0)
     jf1 = TinyJit(f1)
+    a = Tensor.rand(10, 3)
+    b = Tensor.rand(10, 3)
+    c = Tensor.rand(10, 3)
     for i in range(1, 5):
       vi = Variable("i", 1, 10).bind(i)
-      # aixs = None
-      a = Tensor.rand(i, 3)
-      symbolic = jf(a.reshape(vi, 3)).numpy()
-      expected = a.var().numpy()
+      # axis = None
+      symbolic = jf(a[:vi]).numpy()
+      expected = a[:i].var().numpy()
       np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
-      # aixs = 0
-      a = Tensor.rand(i, 3)
-      symbolic = jf0(a.reshape(vi, 3)).numpy()
-      expected = a.var(0).numpy()
+      # axis = 0
+      symbolic = jf0(b[:vi]).numpy()
+      expected = b[:i].var(0).numpy()
       np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
-      # aixs = 1
-      a = Tensor.rand(i, 3)
-      symbolic = jf1(a.reshape(vi, 3)).reshape(i).numpy()
-      expected = a.var(1).numpy()
+      # axis = 1
+      symbolic = jf1(c[:vi]).reshape(i).numpy()
+      expected = c[:i].var(1).numpy()
       np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
   def test_var_2d(self):
@@ -303,24 +294,24 @@ class TestSymbolicJit(unittest.TestCase):
     jf = TinyJit(f)
     jf0 = TinyJit(f0)
     jf1 = TinyJit(f1)
+    a = Tensor.rand(10, 10)
+    b = Tensor.rand(10, 10)
+    c = Tensor.rand(10, 10)
     for i in range(1, 5):
       for j in range(1, 5):
         vi = Variable("i", 1, 10).bind(i)
         vj = Variable("j", 1, 10).bind(j)
-        # aixs = None
-        a = Tensor.rand(i, j)
-        symbolic = jf(a.reshape(vi, vj)).numpy()
-        expected = a.var().numpy()
+        # axis = None
+        symbolic = jf(a[:vi, :vj]).numpy()
+        expected = a[:i, :j].var().numpy()
         np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
-        # aixs = 0
-        a = Tensor.rand(i, j)
-        symbolic = jf0(a.reshape(vi, vj)).reshape(j).numpy()
-        expected = a.var(0).numpy()
+        # axis = 0
+        symbolic = jf0(b[:vi, :vj]).reshape(j).numpy()
+        expected = b[:i, :j].var(0).numpy()
         np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
-        # aixs = 1
-        a = Tensor.rand(i, j)
-        symbolic = jf1(a.reshape(vi, vj)).reshape(i).numpy()
-        expected = a.var(1).numpy()
+        # axis = 1
+        symbolic = jf1(c[:vi, :vj]).reshape(i).numpy()
+        expected = c[:i, :j].var(1).numpy()
         np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
 
 if __name__ == '__main__':
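Reductions over a symbolic axis divide by the symbolic size, so the symbolic mean/var above match numpy on the sliced prefix. Outside the jit the same check is a one-liner; a minimal sketch:

from tinygrad import Tensor, Variable
import numpy as np

a = Tensor.rand(10, 3)
for i in range(1, 5):
  vi = Variable("i", 1, 10).bind(i)
  # mean over a symbolic axis divides by the symbolic size
  np.testing.assert_allclose(a[:vi].mean(0).numpy(), a[:i].mean(0).numpy(), atol=1e-6, rtol=1e-6)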


@@ -1,7 +1,7 @@
 import unittest
 from tinygrad import Tensor, Variable
 from tinygrad.shape.shapetracker import View
-from tinygrad.helpers import Context, GlobalCounters
+from tinygrad.helpers import GlobalCounters
 from tinygrad.uop.ops import sym_infer
 from tinygrad.dtype import dtypes
 from tinygrad.device import Device
@@ -9,54 +9,46 @@ from examples.gpt2 import Attention
 import numpy as np
 
 class TestSymbolicOps(unittest.TestCase):
-  def setUp(self):
-    # A lot of these test are out of bounds, so we ignore the bounds check
-    self.context = Context(IGNORE_OOB=1)
-    self.context.__enter__()
-  def tearDown(self):
-    self.context.__exit__(None, None, None)
   def test_plus1(self):
     def f(a): return (a+1).realize()
+    a = Tensor.rand(3, 10)
     for i in range(1, 5):
       vi = Variable("i", 1, 10).bind(i)
-      a = Tensor.rand(3, i)
-      symbolic = f(a.reshape(3, vi)).reshape(3, i).numpy()
-      expected = f(a).numpy()
+      symbolic = f(a[:, :vi]).reshape(3, i).numpy()
+      expected = f(a[:, :i]).numpy()
       np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
   def test_add(self):
     def f(a, b): return (a+b).realize()
+    a = Tensor.rand(3, 10)
+    b = Tensor.rand(3, 10)
     for i in range(1, 5):
       vi = Variable("i", 1, 10).bind(i)
-      a = Tensor.rand(3, i)
-      b = Tensor.rand(3, i)
-      symbolic = f(a.reshape(3, vi), b.reshape(3, vi)).reshape(3, i).numpy()
-      expected = f(a, b).numpy()
+      symbolic = f(a[:, :vi], b[:, :vi]).reshape(3, i).numpy()
+      expected = f(a[:, :i], b[:, :i]).numpy()
       np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
   def test_matmul(self):
     def f(a, b): return (a@b).realize()
+    a = Tensor.rand(3, 10)
+    b = Tensor.rand(10, 5)
     for i in range(1, 5):
       vi = Variable("i", 1, 10).bind(i)
-      a = Tensor.rand(3, i)
-      b = Tensor.rand(i, 5)
-      symbolic = f(a.reshape(3, vi), b.reshape(vi, 5)).numpy()
-      expected = f(a, b).numpy()
+      symbolic = f(a[:, :vi], b[:vi, :]).numpy()
+      expected = f(a[:, :i], b[:i, :]).numpy()
       np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
   def test_attention(self, dropout_p=0.0, imin=1, imax=5, use_symbolic=True):
     def f(q, k, v): return Tensor.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), dropout_p=dropout_p).realize()
+    q = Tensor.rand(2, 1, 4, 8)
+    k = Tensor.rand(2, 10, 4, 8)
+    v = Tensor.rand(2, 10, 4, 8)
     for i in range(imin, imax):
       vi = Variable("i", 1, 10).bind(i) if use_symbolic else i
-      q = Tensor.rand(2, 1, 4, 8)
-      k = Tensor.rand(2, i, 4, 8)
-      v = Tensor.rand(2, i, 4, 8)
       Tensor.realize(q, k, v)
       GlobalCounters.reset()
-      symbolic = f(q, k.reshape(2, vi, 4, 8), v.reshape(2, vi, 4, 8)).reshape(2, 4, 1, 8).numpy()
-      expected = f(q, k, v).numpy()
+      symbolic = f(q, k[:, :vi, :, :], v[:, :vi, :, :]).reshape(2, 4, 1, 8).numpy()
+      expected = f(q, k[:, :i, :, :], v[:, :i, :, :]).numpy()
       np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
   def test_attention_cmp_symbolic(self):
@@ -90,73 +82,89 @@ class TestSymbolicOps(unittest.TestCase):
   def test_cat_dim0(self):
     def f(a, b): return a.cat(b, dim=0).realize()
+    a = Tensor.rand(10, 3)
     for i in range(1, 5):
       vi = Variable("i", 1, 10).bind(i)
-      a = Tensor.rand(i, 3)
       b = Tensor.rand(2, 3)
-      symbolic = f(a.reshape(vi, 3), b).reshape(i+2, 3).numpy()
-      expected = f(a, b).numpy()
+      symbolic = f(a[:vi, :], b).reshape(i+2, 3).numpy()
+      expected = f(a[:i, :], b).numpy()
       np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
   def test_cat_dim1(self):
     def f(a, b): return a.cat(b, dim=1).realize()
+    a = Tensor.rand(3, 10)
     for i in range(1, 5):
       vi = Variable("i", 1, 10).bind(i)
-      a = Tensor.rand(3, i)
       b = Tensor.rand(3, 2)
-      symbolic = f(a.reshape(3, vi), b).reshape(3, i+2).numpy()
-      expected = f(a, b).numpy()
+      symbolic = f(a[:, :vi], b).reshape(3, i+2).numpy()
+      expected = f(a[:, :i], b).numpy()
       np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
   def test_cat_dim0_two_vars(self):
     def f(a, b): return a.cat(b, dim=0).realize()
+    a = Tensor.rand(10, 3)
+    b = Tensor.rand(10, 3)
     for i in range(1, 5):
       for j in range(1, 5):
         vi = Variable("i", 1, 10).bind(i)
         vj = Variable("j", 1, 10).bind(j)
-        a = Tensor.rand(i, 3)
-        b = Tensor.rand(j, 3)
-        symbolic = f(a.reshape(vi, 3), b.reshape(vj, 3)).reshape(i+j, 3).numpy()
-        expected = f(a, b).numpy()
+        symbolic = f(a[:vi, :], b[:vj, :]).reshape(i+j, 3).numpy()
+        expected = f(a[:i, :], b[:j, :]).numpy()
         np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
   def test_cat_dim1_two_vars(self):
     def f(a, b): return a.cat(b, dim=1).realize()
+    a = Tensor.rand(3, 10)
+    b = Tensor.rand(3, 10)
     for i in range(1, 5):
       for j in range(1, 5):
         vi = Variable("i", 1, 10).bind(i)
         vj = Variable("j", 1, 10).bind(j)
-        a = Tensor.rand(3, i)
-        b = Tensor.rand(3, j)
-        symbolic = f(a.reshape(3, vi), b.reshape(3, vj)).reshape(3, i+j).numpy()
-        expected = f(a, b).numpy()
+        symbolic = f(a[:, :vi], b[:, :vj]).reshape(3, i+j).numpy()
+        expected = f(a[:, :i], b[:, :j]).numpy()
         np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
   def test_two_vars_plus1_ij(self):
     def f(a, b): return (a@b+1).realize()
+    a = Tensor.rand(10, 3)
+    b = Tensor.rand(3, 10)
     for i in range(1, 5):
       for j in range(1, 5):
         vi = Variable("i", 1, 10).bind(i)
         vj = Variable("j", 1, 10).bind(j)
-        a = Tensor.rand(i, 3)
-        b = Tensor.rand(3, j)
-        symbolic = f(a.reshape(vi, 3), b.reshape(3, vj)).reshape(i, j).numpy()
-        expected = f(a, b).numpy()
+        symbolic = f(a[:vi, :], b[:, :vj]).reshape(i, j).numpy()
+        expected = f(a[:i, :], b[:, :j]).numpy()
         np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
   def test_two_vars_plus1_ji(self):
     # reverse the order of variables
     def f(a, b): return (a@b+1).realize()
+    a = Tensor.rand(10, 3)
+    b = Tensor.rand(3, 10)
     for i in range(1, 5):
       for j in range(1, 5):
         vi = Variable("i", 1, 10).bind(i)
         vj = Variable("j", 1, 10).bind(j)
-        a = Tensor.rand(j, 3)
-        b = Tensor.rand(3, i)
-        symbolic = f(a.reshape(vj, 3), b.reshape(3, vi)).reshape(j, i).numpy()
-        expected = f(a, b).numpy()
+        symbolic = f(a[:vj, :], b[:, :vi]).reshape(j, i).numpy()
+        expected = f(a[:j, :], b[:, :i]).numpy()
         np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
+  def test_reshape_from_symbolic(self):
+    a = Tensor.rand(30)
+    for i in range(3, 5):
+      vi = Variable("i", 3, 10).bind(i)
+      symbolic = a[:vi*3].reshape((3, 3)).numpy()
+      # To match symbolic reshape (potential implicit shrink), we need a shrink
+      expected = a[:i*3].shrink(((0, 9),)).reshape((3, 3)).numpy()
+      np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
+  def test_invalid_symbolic_reshape(self):
+    a = Tensor.rand(30)
+    for i in range(1, 5):
+      vi = Variable("i", 1, 10).bind(i)
+      # Cannot reshape into symbolic from non-symbolic
+      with self.assertRaises(AssertionError): a.reshape((3, vi))
   def test_shrink(self):
     for i in range(1, 5):
       vi = Variable("i", 1, 10).bind(i)
@@ -176,11 +184,10 @@ class TestSymbolicOps(unittest.TestCase):
       np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
   def test_slice_no_start(self):
+    a = Tensor.rand(7, 11)
     for i in range(1, 5):
       vi = Variable("i", 1, 10).bind(i)
-      a = Tensor.rand(7, 11)
-      symbolic = a[3:5, :vi:1].reshape(2,i)
-      symbolic = symbolic.numpy()
+      symbolic = a[3:5, :vi:1].reshape(2, i).numpy()
       expected = a[3:5, :i:1].numpy()
       np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
@@ -201,31 +208,31 @@ class TestSymbolicOps(unittest.TestCase):
       np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
   def test_ones_sum(self):
+    t = Tensor.ones(10)
     for i in range(1, 5):
       vi = Variable("i", 1, 10).bind(i)
-      t = Tensor.ones(i)
-      symbolic = t.reshape(vi).sum().item()
-      expected = t.sum().item()
+      symbolic = t[:vi].sum().item()
+      expected = t[:i].sum().item()
       np.testing.assert_equal(symbolic, expected)
   def test_mean(self):
+    a = Tensor.rand(10, 3)
     for i in range(1, 5):
       vi = Variable("i", 1, 10).bind(i)
       for axis in [None, 0, 1]:
-        a = Tensor.rand(i, 3)
-        expected = a.mean(axis).numpy()
-        symbolic = a.reshape(vi, 3).mean(axis).reshape(expected.shape).numpy()
+        expected = a[:i].mean(axis).numpy()
+        symbolic = a[:vi].mean(axis).reshape(expected.shape).numpy()
         np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
   def test_mean_2d(self):
+    a = Tensor.rand(10, 10)
     for i in range(1, 5):
       for j in range(1, 5):
         vi = Variable("i", 1, 10).bind(i)
         vj = Variable("j", 1, 10).bind(j)
         for axis in [None, 0, 1]:
-          a = Tensor.rand(i, j)
-          expected = a.mean(axis).numpy()
-          symbolic = a.reshape(vi, vj).mean(axis).reshape(expected.shape).numpy()
+          expected = a[:i, :j].mean(axis).numpy()
+          symbolic = a[:vi, :vj].mean(axis).reshape(expected.shape).numpy()
           np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
   def test_var(self):
@@ -233,43 +240,43 @@ class TestSymbolicOps(unittest.TestCase):
     for i in range(1, 5):
       vi = Variable("i", 1, 10).bind(i)
       for axis in [None, 0, 1]:
-        expected = a[:i, :].var(axis).numpy()
-        symbolic = a[:vi, :].var(axis).reshape(expected.shape).numpy()
+        expected = a[:i].var(axis).numpy()
+        symbolic = a[:vi].var(axis).reshape(expected.shape).numpy()
         np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
   def test_var_2d(self):
+    a = Tensor.rand(10, 10)
     for i in range(1, 5):
       for j in range(1, 5):
         vi = Variable("i", 1, 10).bind(i)
         vj = Variable("j", 1, 10).bind(j)
         for axis in [None, 0, 1]:
-          a = Tensor.rand(i, j)
-          expected = a.var(axis).numpy()
-          symbolic = a.reshape(vi, vj).var(axis).reshape(expected.shape).numpy()
+          expected = a[:i, :j].var(axis).numpy()
+          symbolic = a[:vi, :vj].var(axis).reshape(expected.shape).numpy()
           np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
   def test_bitcast_down(self):
+    a = Tensor.rand(10, 3)
     for i in range(1, 5):
       vi = Variable("i", 1, 10).bind(i)
-      a = Tensor.rand(i, 3)
-      expected = a.bitcast(dtypes.uint8).numpy()
-      symbolic = a.reshape(vi, 3).bitcast(dtypes.uint8).reshape(expected.shape).numpy()
+      expected = a[:i].bitcast(dtypes.uint8).numpy()
+      symbolic = a[:vi].bitcast(dtypes.uint8).reshape(expected.shape).numpy()
       np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=0)
   @unittest.skipIf(Device.DEFAULT == "WEBGPU", "no uint64")
   def test_bitcast_up(self):
+    a = Tensor.rand(10, 4)
     for i in range(1, 5):
       vi = Variable("i", 1, 10).bind(i)
-      a = Tensor.rand(i, 4)
-      expected = a.bitcast(dtypes.uint64).numpy()
-      symbolic = a.reshape(vi, 4).bitcast(dtypes.uint64).reshape(expected.shape).numpy()
+      expected = a[:i].bitcast(dtypes.uint64).numpy()
+      symbolic = a[:vi].bitcast(dtypes.uint64).reshape(expected.shape).numpy()
       np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=0)
   @unittest.expectedFailure
   def test_conv2d_ceildiv_edge_case(self):
     v = Variable('v', 11, 50_000)
     val = 39601
-    x = Tensor.randn(1, 22, 39601).reshape(1, 22, v.bind(val))
+    x = Tensor.randn(1, 22, 50_000)[:, :, :v.bind(val)]
     weight = Tensor.randn(256, 22, 12)
     result = x.conv2d(weight=weight, groups=1, stride=6, dilation=1, padding=(3, 3))
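The two new tests in this file pin down the reshape contract after the change: reshaping from a symbolic shape to a concrete one is allowed (and may implicitly shrink the data down to the concrete element count), while reshaping a concrete tensor into a symbolic shape now raises. A minimal sketch:

from tinygrad import Tensor, Variable

a = Tensor.rand(30)
vi = Variable("i", 3, 10).bind(3)
out = a[:vi*3].reshape((3, 3))   # ok: symbolic -> concrete, keeps the first 9 elements
try: a.reshape((3, vi))          # concrete -> symbolic: rejected
except AssertionError: pass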


@@ -1,7 +1,6 @@
 import unittest
 import numpy as np
 from tinygrad import Tensor, Variable
-from tinygrad.helpers import Context
 
 class TestTensorVariable(unittest.TestCase):
   def test_add_tvar(self):
@@ -23,43 +22,38 @@ class TestTensorVariable(unittest.TestCase):
     assert (Tensor(3) * (vv * 4)).item() == 24
 
   def test_symbolic_mean(self):
-    with Context(IGNORE_OOB=1):
-      vv = Variable("a", 1, 10).bind(2)
-      t = Tensor.ones(2, 2).contiguous().reshape(2, vv)
-      ret = t.mean().item()
-      assert ret == 1
+    vv = Variable("a", 1, 10).bind(2)
+    t = Tensor.ones(2, 10).contiguous()[:, :vv]
+    ret = t.mean().item()
+    assert ret == 1
 
   def test_symbolic_mean_2d(self):
-    with Context(IGNORE_OOB=1):
-      vv = Variable("a", 1, 10).bind(2)
-      vv2 = Variable("b", 1, 10).bind(2)
-      t = Tensor.ones(2, 2).contiguous().reshape(vv2, vv)
-      ret = t.mean().item()
-      assert ret == 1
+    vv = Variable("a", 1, 10).bind(2)
+    vv2 = Variable("b", 1, 10).bind(2)
+    t = Tensor.ones(10, 10).contiguous()[:vv2, :vv]
+    ret = t.mean().item()
+    assert ret == 1
 
   def test_symbolic_mean_2d_axis_1(self):
-    with Context(IGNORE_OOB=1):
-      vv = Variable("a", 1, 10).bind(2)
-      vv2 = Variable("b", 1, 10).bind(2)
-      t = Tensor.ones(2, 2).contiguous().reshape(vv2, vv)
-      ret = t.mean(axis=1).reshape(2, 1).numpy()
-      assert np.all(ret == 1)
+    vv = Variable("a", 1, 10).bind(2)
+    vv2 = Variable("b", 1, 10).bind(2)
+    t = Tensor.ones(10, 10).contiguous()[:vv2, :vv]
+    ret = t.mean(axis=1).reshape(2, 1).numpy()
+    assert np.all(ret == 1)
 
   def test_symbolic_mean_2d_add(self):
-    with Context(IGNORE_OOB=1):
-      add_term = Variable("c", 0, 10).bind(1)
-      vv = Variable("a", 1, 10).bind(1)
-      vv2 = Variable("b", 1, 10).bind(1)
-      t = Tensor.ones(2, 2).contiguous().reshape(vv2+add_term, vv+add_term)
-      ret = t.mean().item()
-      assert ret == 1
+    add_term = Variable("c", 0, 10).bind(1)
+    vv = Variable("a", 1, 10).bind(1)
+    vv2 = Variable("b", 1, 10).bind(1)
+    t = Tensor.ones(20, 20).contiguous()[:vv2+add_term, :vv+add_term]
+    ret = t.mean().item()
+    assert ret == 1
 
   def test_symbolic_var(self):
-    with Context(IGNORE_OOB=1):
-      vv = Variable("a", 1, 10).bind(2)
-      t = Tensor.ones(2, 2).contiguous().reshape(2, vv)
-      ret = t.var().item()
-      assert ret == 0
+    vv = Variable("a", 1, 10).bind(2)
+    t = Tensor.ones(2, 10).contiguous()[:, :vv]
+    ret = t.var().item()
+    assert ret == 0
 
   def test_symbolic_pad(self):
     vv = Variable("a", 1, 10).bind(2)


@@ -447,7 +447,7 @@ class TestUOpMethod(unittest.TestCase):
   def test_uop_variables(self):
     a = UOp.variable("a", 1, 10)
     uop_var = Tensor(a.bind(1))
-    st_var = Tensor.empty((2, 1)).reshape((2, a.bind(1)))
+    st_var = Tensor.empty((2, 10))[:, :a.bind(1)]
     _, var_vals = (uop_var+st_var).schedule_with_vars()
     self.assertEqual(len(var_vals), 1)
     self.assertEqual(list(var_vals)[0], a)
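For reference, a minimal sketch of what the updated test_uop_variables exercises: a variable can reach the schedule either as a tensor value or through a tensor's (shrunk) shape, and schedule_with_vars collects it once:

from tinygrad import Tensor
from tinygrad.uop.ops import UOp

a = UOp.variable("a", 1, 10)
uop_var = Tensor(a.bind(1))                     # variable as a value
st_var = Tensor.empty((2, 10))[:, :a.bind(1)]   # variable in the shape, via shrink
_, var_vals = (uop_var + st_var).schedule_with_vars()
assert list(var_vals) == [a]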


@@ -839,25 +839,22 @@ class TestRender(unittest.TestCase):
     self.assertEqual(idx.render(), "((ridx0*3)+ridx1)")
     self.assertEqual(valid.render(), "(ridx0<2)")
 
-class TestVariableReshape(unittest.TestCase):
-  def test_reshape(self):
-    st = ShapeTracker.from_shape((3,))
-    st = st.reshape((Variable("i", 1, 10),))
+class TestVariableShrink(unittest.TestCase):
+  def test_shrink(self):
+    st = ShapeTracker.from_shape((10,))
+    st = st.shrink(((0, Variable("i", 1, 10)),))
     assert len(st.views) == 1
-  def test_reshape_stride_0(self):
-    st = ShapeTracker.from_shape((3,), (0,))
-    st = st.reshape((Variable("i", 1, 10).bind(3),))
-    assert len(st.views) == 1, f"multiview {st}"
-  def test_reshape_bound(self):
-    st = ShapeTracker.from_shape((3,))
-    st = st.reshape((Variable("i", 1, 10).bind(3),))
+  def test_shrink_bound(self):
+    st = ShapeTracker.from_shape((10,))
+    st = st.shrink(((0, Variable("i", 1, 10).bind(3)),))
     assert len(st.views) == 1
-  def test_add(self):
-    st1 = ShapeTracker.from_shape((3,))
-    st2 = ShapeTracker.from_shape((Variable("i", 1, 10),))
+
+class TestVariableMerge(unittest.TestCase):
+  def test_add_reshape(self):
+    vi = Variable("i", 1, 10)
+    st1 = ShapeTracker.from_shape((vi,))
+    st2 = ShapeTracker.from_shape((1, vi,))
     st = st1+st2
     assert len(st.views) == 1
@@ -867,15 +864,17 @@ class TestVariableReshape(unittest.TestCase):
     st = st1+st2
     assert len(st.views) == 1, f"multiview {st}"
-  def test_add_bound(self):
-    st1 = ShapeTracker.from_shape((3,))
-    st2 = ShapeTracker.from_shape((Variable("i", 1, 10).bind(3),))
+  def test_add_reshape_bound(self):
+    vi = Variable("i", 1, 10).bind(3)
+    st1 = ShapeTracker.from_shape((vi,))
+    st2 = ShapeTracker.from_shape((1, vi,))
     st = st1+st2
     assert len(st.views) == 1
   def test_simplify(self):
-    st1 = ShapeTracker.from_shape((3,))
-    st2 = ShapeTracker.from_shape((Variable("i", 1, 10).bind(3),))
+    vi = Variable("i", 1, 10).bind(3)
+    st1 = ShapeTracker.from_shape((vi,))
+    st2 = ShapeTracker.from_shape((1, vi,))
     st = ShapeTracker((st1.views[0], st2.views[0]))
     st = st.simplify()
     assert len(st.views) == 1
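At the ShapeTracker level the new registration path is literally a shrink, and the renamed tests keep the key invariant: it must stay a single view. A minimal sketch:

from tinygrad import Variable
from tinygrad.shape.shapetracker import ShapeTracker

st = ShapeTracker.from_shape((10,)).shrink(((0, Variable("i", 1, 10).bind(3)),))
assert len(st.views) == 1     # a symbolic shrink does not force a multi-view tracker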


@@ -87,20 +87,6 @@ class TestShapeTrackerAdd(unittest.TestCase):
     assert not (st_equal(st1, st2))
 
 class TestShapeTrackerAddVariable(unittest.TestCase):
-  def test_self_add(self):
-    j = Variable("j", 0, 20).bind(10)
-    a = ShapeTracker.from_shape((10,10))
-    x = a.reshape((10, j))
-    out = x + x
-    assert out == x
-  def test_self_add_reshape(self):
-    j = Variable("j", 0, 20).bind(10)
-    a = ShapeTracker.from_shape((10,10))
-    x = a.reshape((10, j))
-    out = x.reshape((5, 2, j)) + x
-    assert out == x
   def test_merge_symbolic_views(self):
     var_i = Variable('i', 1, 10)
     var_j = Variable('i', 1, 10)


@@ -48,11 +48,11 @@ class TestSymbolic(unittest.TestCase):
     i = Variable("i", 1, 5).bind(3)
     j = Variable("j", 1, 5).bind(3)
     k = Variable("k", 1, 5).bind(3)
-    t = Tensor.rand(3, 4).reshape(i, 4).cat(Tensor.rand(3, 4).reshape(j, 4), dim=0).cat(Tensor.rand(3, 4).reshape(k, 4), dim=0)
+    t = Tensor.rand(5, 4)[:i].cat(Tensor.rand(5, 4)[:j], dim=0).cat(Tensor.rand(5, 4)[:k], dim=0)
     st = t.uop.st
     self.assert_tuple_equal(st.shape, (i+j+k, 4))
     assert st.real_strides() == (4, 1)
-    t = Tensor.rand(3, 3).reshape(i, 3).cat(Tensor.rand(3, 3).reshape(i, 3), dim=0).cat(Tensor.rand(3, 3), dim=0)
+    t = Tensor.rand(5, 3)[:i].cat(Tensor.rand(5, 3)[:i], dim=0).cat(Tensor.rand(3, 3), dim=0)
     st = t.uop.st
     self.assert_tuple_equal(st.shape, (2*i+3, 3))
     assert st.real_strides() == (3, 1)
@@ -61,7 +61,7 @@ class TestSymbolic(unittest.TestCase):
     i = Variable("i", 1, 5).bind(4)
     j = Variable("j", 1, 5).bind(4)
     k = Variable("k", 1, 5).bind(4)
-    t = Tensor.rand(3, 4).reshape(3, i).cat(Tensor.rand(3, 4).reshape(3, j), dim=1).cat(Tensor.rand(3, 4).reshape(3, k), dim=1)
+    t = Tensor.rand(3, 5)[:, :i].cat(Tensor.rand(3, 5)[:, :j], dim=1).cat(Tensor.rand(3, 5)[:, :k], dim=1)
     st = t.uop.st
     self.assert_tuple_equal(st.shape, (3, i+j+k))
     self.assert_tuple_equal(st.real_strides(), (i+j+k, 1))
@@ -109,60 +109,44 @@ class TestShapeTrackerUnbind(unittest.TestCase):
     assert unbound_view == View.create(shape=(v, 4))
     assert var_val == {v: 3}
-  def test_reshape_unbind(self):
-    v = Variable("v", 1, 100)
-    bv = Variable("v", 1, 100).bind(3)
-    t = Tensor.rand(3, 4).reshape(bv, 4)
-    unbound_st, var_val = t.uop.st.unbind()
-    assert unbound_st == ShapeTracker((View.create(shape=(v, 4)),))
-    assert var_val == {v: 3}
   def test_shrink_unbind(self):
     v = Variable("v", 1, 100)
     bv = Variable("v", 1, 100).bind(2)
+    t = Tensor.rand(3, 4).shrink(((0,bv),(0,4)))
+    unbound_st, var_val = t.uop.st.unbind()
+    assert unbound_st == ShapeTracker((View.create(shape=(v, 4)),))
+    assert var_val == {v: 2}
     t = Tensor.rand(3, 4).shrink(((bv, bv+1), (0, 4)))
     unbound_st, var_val = t.uop.st.unbind()
     assert unbound_st == ShapeTracker((View.create(shape=(1, 4), offset=4*v),))
     assert var_val == {v: 2}
 
-class TestSymbolicReshapeFromContiguous(unittest.TestCase):
-  def test_reshape_into_symbols_simple(self):
+class TestSymbolicReshape(unittest.TestCase):
+  def test_reshape(self):
+    a = Tensor.rand(5, 4)
+    b = Tensor.rand(5, 6)
     for i in range(1, 6):
       vi = Variable("i", 1, 5).bind(i)
-      t = Tensor.rand(i, 4).reshape(vi, 4)
-      assert t.shape == (vi, 4)
-      t = Tensor.rand(i, 6).reshape(vi, 2, 3)
-      assert t.shape == (vi, 2, 3)
-  def test_reshape_symbols_reshape_ints(self):
-    for i in range(1, 6):
-      vi = Variable("i", 1, 5).bind(i)
-      t = Tensor.rand(i, 4).reshape(vi, 4)
-      assert t.shape == (vi, 4)
-      t = t.reshape(i, 4)
-      assert t.shape == (i, 4)
-  @unittest.skip("works now")
-  def test_reshape_into_symbols_bad_shape(self):
-    vi = Variable("i", 1, 10).bind(4)
-    # TODO: this never actually worked, it relied on lazy
-    #with self.assertRaises(ValueError):
-    #  Tensor.rand(4, 6).reshape(vi, 6).reshape(1, 77) # reshape to a different size new shape through symbolic shape
-    with self.assertRaises(AssertionError):
-      Tensor.rand(3, 4).reshape(3, (vi+1)) # reshape into non-Variable Node
+      ret = a[:vi]
+      ret = ret.reshape((vi, 4))
+      assert ret.shape == (vi, 4)
+      ret = b[:vi]
+      ret = ret.reshape((vi, 2, 3))
+      assert ret.shape == (vi, 2, 3)
   def test_two_symbol_reshape(self):
+    t = Tensor.rand(5, 5)
     for i in range(1, 6):
       for j in range(1, 6):
         vi = Variable("i", 1, 5).bind(i)
         vj = Variable("j", 1, 5).bind(j)
-        t = Tensor.rand(i, j).reshape(vi, vj)
-        assert t.shape == (vi, vj)
-        # NOTE: this is currently not allowed
-        # t = t.reshape(1, vi*vj)
-        # assert t.shape == (1, vi*vj)
-        t = t.reshape(vj, vi)
-        assert t.shape == (vj, vi)
+        ret = t[:vi, :vj]
+        ret = ret.reshape(vj, vi)
+        assert ret.shape == (vj, vi)
+        ret = ret.reshape(vi, vj)
+        assert ret.shape == (vi, vj)
+        ret = ret.reshape(1, vi*vj)
+        assert ret.shape == (1, vi*vj)
   def test_symbolic_mask(self):
     # taken from gpt2 single kvcache
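Note the behavior change recorded in test_two_symbol_reshape: reshaping a symbolic (vi, vj) view to (1, vi*vj) used to be explicitly disallowed (see the removed NOTE) and works under the shrink-based representation. A minimal sketch:

from tinygrad import Tensor, Variable

vi, vj = Variable("i", 1, 5).bind(2), Variable("j", 1, 5).bind(3)
ret = Tensor.rand(5, 5)[:vi, :vj]
assert ret.reshape(1, vi*vj).shape == (1, vi*vj)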
@@ -175,41 +159,6 @@ class TestSymbolicReshapeFromContiguous(unittest.TestCase):
     new_shape = (2, (Variable('start_pos', 1, 128)+1), 16, 64)
     assert view.reshape(new_shape) is None
 
-class TestSymbolicReshapeFromNonContiguous(unittest.TestCase):
-  def test_reshape_from_const(self):
-    vi = Variable("i", 1, 5).bind(4)
-    t = Tensor.ones(3, 4).reshape(3, vi)
-    assert t.shape == (3, vi)
-    assert not t.uop.st.contiguous
-    assert len(t.uop.st.views) == 1
-  def test_reshape_not_allowed(self):
-    vi = Variable("i", 1, 5).bind(4)
-    with self.assertRaises(ValueError):
-      # different shape length # TODO: cases where contractions matched might be fine
-      Tensor.ones(3, 4, 1).reshape(3, vi)
-    with self.assertRaises(ValueError):
-      # size matched, but dimensions do not match
-      Tensor.ones(4, 3).reshape(3, vi)
-  def test_reshape_from_padded(self):
-    vi = Variable("i", 1, 5).bind(4)
-    t = Tensor.ones(3, 4).contiguous().expand(2, 3, 4).pad(((1, 1), None, None)).shrink((None, None, (1, 3)))
-    st = t.uop.st
-    assert len(st.views) == 1
-    view = st.views[0]
-    assert view.shape == (4, 3, 2)
-    t = t.reshape(vi, 3, 2)
-    st2 = t.uop.st
-    assert len(st2.views) == 1
-    view2 = st2.views[0]
-    # check only shape changed. strides, offset, mask, contiguous remained the same
-    assert view2.shape == (vi, 3, 2)
-    assert view.strides == view2.strides == (0, 4, 1)
-    assert view.offset == view2.offset == 1
-    assert view.mask == view2.mask == ((1, 3), (0, 3), (0, 2))
-    assert not view.contiguous and not view2.contiguous
 
 class TestSymbolicExpand(unittest.TestCase):
   def test_expand_into_symbols(self):
     vi = Variable("i", 1, 5).bind(3)
@@ -220,11 +169,12 @@ class TestSymbolicExpand(unittest.TestCase):
     assert a.shape == (3, vi, vj)
 
   def test_plus_expands_constant(self):
+    a = Tensor.rand(3, 5)
     for i in range(1, 6):
       vi = Variable("i", 1, 5).bind(i)
-      a = Tensor.rand(3, i).reshape(3, vi)
-      a = a + 1
-      self.assertTupleEqual(a.shape, (3, vi))
+      ret = a[:, :vi]
+      ret = ret + 1
+      self.assertTupleEqual(ret.shape, (3, vi))
 
   def test_pad_then_expand_into_symbols(self):
     vi = Variable("i", 1, 10).bind(3)
@@ -234,6 +184,11 @@ class TestSymbolicExpand(unittest.TestCase):
     self.assertEqual(a.reshape(vi*25).shape, (vi*25,))
 
 class TestSymbolicShrink(unittest.TestCase):
+  def test_shrink_symbols_simple(self):
+    vi = Variable("i", 1, 5)
+    t = Tensor.rand(5, 5).shrink(((0, 5),(0,vi)))
+    assert t.shape == (5, vi)
   def test_shrink_symbols(self):
     vi = Variable("i", 1, 5)
     t = Tensor.rand(3, 5).shrink(((0, 2), (vi, vi+1)))
@@ -242,10 +197,10 @@ class TestSymbolicShrink(unittest.TestCase):
 class TestSymbolicPad(unittest.TestCase):
   def test_pad(self):
     v = Variable("v", 1, 100).bind(5)
-    t = Tensor.ones(5).reshape(v).pad(((4, 0),)).reshape(9)
-    assert t.shape == (9,)
-    st = t.uop.st
-    print(st)
+    t = Tensor.ones(100)[:v].pad(((4, 0),))
+    t = t.reshape(9)
+    assert t.tolist() == [0,0,0,0,1,1,1,1,1]
 
 if __name__ == '__main__':
   unittest.main()


@@ -97,7 +97,7 @@ class TestTensorUopRepresentation(unittest.TestCase):
     is_pattern(a, UPat(Ops.RESHAPE, src=(UPat(Ops.BUFFER),)))
     vi = UOp.variable("i", 1, 3).bind(1)
     a = Tensor.empty(3, vi)
-    is_pattern(a, UPat(Ops.RESHAPE, src=(UPat(Ops.BUFFER),)))
+    is_pattern(a, UPat(Ops.RESHAPE, src=(UPat(Ops.SHRINK, src=(UPat(Ops.BUFFER),))),))
     self.assertEqual(a.uop.base.buffer.size, 9)
 
 if __name__ == '__main__':
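The pattern change above is the graph-level signature of this commit: Tensor.empty with a symbolic dim allocates the buffer at the product of each dim's vmax, shrinks to the symbolic element count, then reshapes, so the base is now RESHAPE(SHRINK(BUFFER)) instead of RESHAPE(BUFFER). A minimal sketch:

from tinygrad import Tensor
from tinygrad.uop.ops import UOp

vi = UOp.variable("i", 1, 3).bind(1)
a = Tensor.empty(3, vi)
assert a.uop.base.buffer.size == 9   # 3 * vmax(i) = 3 * 3, not 3 * 1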


@@ -119,7 +119,7 @@ def create_kernel(x:UOp, b:UOp|None=None):
   if b is None: b = UOp.new_buffer(x.device, x.size, x.dtype)
   kernel = UOp(Ops.KERNEL, src=(b,)+x.src, arg=Kernel(x.sink(), m if (m:=x.metadata) else ()))
   buffer = b.base if b.size == b.base.size else UOp(Ops.BUFFER_VIEW, b.dtype, (b.base,), (b.size, b.arg.views[0].offset))
-  return buffer.assign(kernel).reshape(x.shape)
+  return buffer.assign(kernel).shrink(((0, prod(x.shape)),)).reshape(x.shape)
 
 DONT_PLACE_IN_KERNEL = {Ops.KERNEL, Ops.ASSIGN, Ops.BUFFER, Ops.MSELECT, Ops.MSTACK, Ops.MULTI, Ops.BIND}
 
 def append_to_kernel(x:UOp):
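create_kernel applies the same rule internally: the output buffer covers the full allocation, and the shrink to prod(x.shape) in front of the reshape is what registers the symbolic shape on the result. A sketch of the call chain, simplified from the changed line above:

# buffer.assign(kernel)                  -> full-size output buffer
#       .shrink(((0, prod(x.shape)),))   -> cut to the (possibly symbolic) element count
#       .reshape(x.shape)                -> only then take the symbolic shape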


@@ -196,7 +196,8 @@ def map_contiguous(ctx:RangeifyContext, x:UOp):
   ranges = []
   for s in x.shape[len(x.src)-1:]:
     ranges.append(ctx.new_range(s) if resolve(s!=1) else UOp.const(dtypes.int, 0))
-  return x.src[0].index(*ranges).bufferize(*x.src[1:], *[x for x in ranges if x.op is not Ops.CONST], arg=x.device).forced_reshape(x.shape)
+  ret = x.src[0].index(*ranges).bufferize(*x.src[1:], *[x for x in ranges if x.op is not Ops.CONST], arg=x.device)
+  return ret.shrink(((0, prod(x.shape)),)).forced_reshape(x.shape)
 
 def map_reduce(ctx:RangeifyContext, idx:UOp, red:UOp):
   rngs = list(idx.src[1:])

@@ -3,7 +3,7 @@ import functools, operator, itertools
 from dataclasses import dataclass
 from typing import cast, Sequence
 from tinygrad.dtype import dtypes
-from tinygrad.uop.ops import resolve, UOp, Variable, sint, sym_infer, smax, smin, sint_to_uop, Ops, ssimplify
+from tinygrad.uop.ops import resolve, UOp, Variable, sint, smax, smin, sint_to_uop, Ops, ssimplify
 from tinygrad.helpers import prod, all_int, argsort, flatten, ceildiv
 
 # returns the axes to create new_shape if new_shape can be created by combining axis from old_shape
@@ -311,9 +311,10 @@ class View:
     if not all(x >= 0 for x in new_shape): raise ValueError(f"shape can't contain negative numbers {new_shape}")
     # check for the same size
-    if (self_all_int := all_int(self.shape)):
-      assert all(isinstance(s, (int, UOp)) for s in new_shape), f"{self.shape=} -> {new_shape=} contains non (int, Variable) dim"
-      if resolve(prod(self.shape) != prod(new_shape), False): raise ValueError(f"size mismatched, can't reshape {self.shape=} -> {new_shape=}")
+    if all_int(self.shape):
+      # reshapes cannot introduce symbolic shape
+      assert all_int(new_shape), f"{self.shape=} -> {new_shape=} contains non int dims"
+      if prod(self.shape) != prod(new_shape): raise ValueError(f"size mismatched, can't reshape {self.shape=} -> {new_shape=}")
 
     if 0 in self.shape: return View.create(new_shape)
     if new_shape == () and self.mask and any(mx==my for (mx,my) in self.mask): return None
@@ -321,15 +322,6 @@ class View:
     # after the asserts, it's okay to check contiguous
     if self.contiguous: return View.create(new_shape)
 
-    # if it's not contiguous and new shape is symbolic, check if it's directly replaceable
-    if self_all_int and not all_int(new_shape):
-      if len(self.shape) != len(new_shape): raise ValueError(f"cannot symbolic reshape non-contiguous {self} -> {new_shape}")
-      for si, so in zip(self.shape, new_shape):
-        if not isinstance(so, int): so = sym_infer(so, dict([v.unbind() for v in so.vars()]))
-        if si != so: raise ValueError(f"cannot symbolic reshape non-contiguous {self} -> {new_shape}")
-      # all dimensions matched, return the new view directly
-      return View(new_shape, self.strides, self.offset, self.mask, self.contiguous)
-
     r_strides, r_new_shape = [], reversed(new_shape)
     for merged_size, new_stride, real_size in reversed(merge_dims(self.shape, self.strides, self.mask)):
       # TODO: write with get_contraction
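With the symbolic-replacement branch deleted, View.reshape has a one-line contract: an all-int view may only reshape to an all-int shape; symbolic shapes enter exclusively through shrink. A minimal sketch (View import as in the test files above):

from tinygrad import Variable
from tinygrad.shape.shapetracker import View

v = View.create((30,))
assert v.reshape((3, 10)) is not None      # int -> int reshape is fine
try: v.reshape((3, Variable("i", 1, 10)))  # int -> symbolic now asserts
except AssertionError: pass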


@@ -442,7 +442,7 @@ class Tensor(MathTrait):
     if not isinstance(size:=prod([x.vmax if isinstance(x, UOp) else x for x in shape]), int): raise ValueError(f"size must be int {size}")
     # TODO: add test for multidevice tensor
     device = tuple(Device.canonicalize(d) for d in device) if isinstance(device, tuple) else Device.canonicalize(device)
-    return Tensor(UOp.new_buffer(device, size, dtype), device, dtype, **kwargs).reshape(shape)
+    return Tensor(UOp.new_buffer(device, size, dtype), device, dtype, **kwargs).shrink(((0,prod(shape)),)).reshape(shape)
 
   @staticmethod
   def from_blob(ptr:int, shape:tuple[int, ...], **kwargs) -> Tensor:
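And Tensor.empty is where user code meets the rule: the buffer is sized with each symbolic dim at its vmax (the `size` computation above), then shrunk to prod(shape) and reshaped. A usage sketch:

from tinygrad import Tensor, Variable

vi = Variable("i", 1, 10).bind(4)
t = Tensor.empty(3, vi)   # buffer holds 3*10 elements; logical shape is (3, vi)
assert t.shape == (3, vi)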