From ea1be2e4cdee9b44cdd4ba0e1265eaa570cd638a Mon Sep 17 00:00:00 2001 From: Ben Waldron <140399313+ben-waldron-1@users.noreply.github.com> Date: Thu, 28 Aug 2025 16:30:49 +0000 Subject: [PATCH] [bounty] Remove using reshape to register symbolic shape (#11771) * Modify tests and start work towards removing symbolic reshape * Refactor symbolic reshape * fix small error * much cleaner + fix more tests * Can remove this now * Update test_symbolic_ops and test_tiny * Couple more tests * Unused import * More tests and add EXPAND to Tensor.empty * Fix test beam search * all int * Fix rangeify by adding shrink * Remove OOB check and so fix test_symbolic_jit * test_symbolic_jit doesn't need OOB Context anymore either * Should remove that test now * Cleanups part 1 * fix linters * Final cleanups * Don't reassign inside for loop --------- Co-authored-by: chenyu --- extra/optimization/test_beam_search.py | 20 +- test/test_symbolic_jit.py | 213 ++++++++++---------- test/test_symbolic_ops.py | 149 +++++++------- test/test_tensor_variable.py | 54 +++-- test/test_uops.py | 2 +- test/unit/test_shapetracker.py | 39 ++-- test/unit/test_shapetracker_math.py | 14 -- test/unit/test_symbolic_shapetracker.py | 121 ++++------- test/unit/test_tensor_uop_representation.py | 2 +- tinygrad/schedule/kernelize.py | 2 +- tinygrad/schedule/rangeify.py | 3 +- tinygrad/shape/view.py | 18 +- tinygrad/tensor.py | 2 +- 13 files changed, 281 insertions(+), 358 deletions(-) diff --git a/extra/optimization/test_beam_search.py b/extra/optimization/test_beam_search.py index 24c3f943b7..7042ea914e 100644 --- a/extra/optimization/test_beam_search.py +++ b/extra/optimization/test_beam_search.py @@ -16,9 +16,9 @@ class TestBeamSearch(unittest.TestCase): BEAM.value = self.old_beam def test_variable_ast_beam(self): - with Context(IGNORE_OOB=1): - a = rand(3, 3).reshape((Variable("a", 1, 10).bind(3), 3)) - a = (a+1).realize() + vi = Variable("a", 1, 10).bind(3) + a = rand(10, 3)[:vi] + a = (a+1).realize() def test_big_prime_number(self): a = rand(367, 367) @@ -42,18 +42,16 @@ class TestBeamSearch(unittest.TestCase): def test_variable_big_prime_number(self): v = Variable("v", 1, 400).bind(367) - a = rand(367, 367) - b = rand(367, 367) - with Context(IGNORE_OOB=1): - c = (a.reshape(367, v) @ b.reshape(v, 367)).realize() - np.testing.assert_allclose(c.numpy(), a.numpy() @ b.numpy(), atol=1e-4, rtol=1e-4) + a = rand(367, 400) + b = rand(400, 367) + c = (a[:, :v] @ b[:v, :]).realize() + np.testing.assert_allclose(c.numpy(), a[:, :367].numpy() @ b[:367, :].numpy(), atol=1e-4, rtol=1e-4) def test_variable_shrink_prime_number(self): v = Variable("v", 1, 400).bind(367) a = rand(400, 367) - with Context(IGNORE_OOB=1): - b = (a.shrink(((0,v), None))+1).reshape(367,367).realize() - np.testing.assert_allclose(b.numpy(), a.numpy()[:367]+1, atol=1e-4, rtol=1e-4) + b = (a.shrink(((0,v), None))+1).reshape(367,367).realize() + np.testing.assert_allclose(b.numpy(), a.numpy()[:367]+1, atol=1e-4, rtol=1e-4) def test_no_mutate_rawbuffers(self): a = rand(3, 3).realize() diff --git a/test/test_symbolic_jit.py b/test/test_symbolic_jit.py index 881ce33489..f983f027cd 100644 --- a/test/test_symbolic_jit.py +++ b/test/test_symbolic_jit.py @@ -2,50 +2,41 @@ import unittest from test.helpers import assert_jit_cache_len from tinygrad import Variable, Tensor, TinyJit -from tinygrad.helpers import Context import numpy as np class TestSymbolicJit(unittest.TestCase): - def setUp(self): - # A lot of these test are out of bounds, so we ignore the bounds check - self.context 
= Context(IGNORE_OOB=1) - self.context.__enter__() - - def tearDown(self): - self.context.__exit__(None, None, None) - def test_plus1(self): def f(a): return (a+1).realize() jf = TinyJit(f) + a = Tensor.rand(3, 10) for i in range(1, 5): vi = Variable("i", 1, 10).bind(i) - a = Tensor.rand(3, i) - symbolic = jf(a.reshape(3, vi)).reshape(3, i).numpy() - expected = f(a).numpy() + symbolic = jf(a[:, :vi]).reshape(3, i).numpy() + expected = f(a[:, :i]).numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) assert_jit_cache_len(jf, 1) def test_add(self): def f(a, b): return (a+b).realize() jf = TinyJit(f) + a = Tensor.rand(3, 10) + b = Tensor.rand(3, 10) for i in range(1, 5): vi = Variable("i", 1, 10).bind(i) - a = Tensor.rand(3, i) - b = Tensor.rand(3, i) - symbolic = jf(a.reshape(3, vi), b.reshape(3, vi)).reshape(3, i).numpy() - expected = f(a, b).numpy() + symbolic = jf(a[:, :vi], b[:, :vi]).reshape(3, i).numpy() + expected = f(a[:, :i], b[:, :i]).numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) assert_jit_cache_len(jf, 1) def test_matmul(self): def f(a, b): return (a@b).realize() jf = TinyJit(f) + a = Tensor.rand(3, 10) + b = Tensor.rand(10, 5) for i in range(1, 5): vi = Variable("i", 1, 10).bind(i) - a = Tensor.rand(3, i) - b = Tensor.rand(i, 5) - symbolic = jf(a.reshape(3, vi), b.reshape(vi, 5)).numpy() - expected = f(a, b).numpy() + symbolic = jf(a[:, :vi], b[:vi, :]).numpy() + expected = f(a[:, :i], b[:i, :]).numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) assert_jit_cache_len(jf, 1) @@ -55,119 +46,119 @@ class TestSymbolicJit(unittest.TestCase): s = (s+s).realize() # this one does not have symbols in input return s jf = TinyJit(f) + a = Tensor.rand(3, 10) + b = Tensor.rand(10, 5) for i in range(1, 5): vi = Variable("i", 1, 10).bind(i) - a = Tensor.rand(3, i) - b = Tensor.rand(i, 5) - symbolic = jf(a.reshape(3, vi), b.reshape(vi, 5)).numpy() - expected = f(a, b).numpy() + symbolic = jf(a[:, :vi], b[:vi, :]).numpy() + expected = f(a[:, :i], b[:i, :]).numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) assert_jit_cache_len(jf, 2) def test_attention(self): def f(q, k, v): return Tensor.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)).realize() jf = TinyJit(f) + q = Tensor.rand(2, 1, 4, 8) + k = Tensor.rand(2, 10, 4, 8) + v = Tensor.rand(2, 10, 4, 8) for i in range(1, 5): vi = Variable("i", 1, 10).bind(i) - q = Tensor.rand(2, 1, 4, 8) - k = Tensor.rand(2, i, 4, 8) - v = Tensor.rand(2, i, 4, 8) - symbolic = jf(q, k.reshape(2, vi, 4, 8), v.reshape(2, vi, 4, 8)).reshape(2, 4, 1, 8).numpy() - expected = f(q, k, v).numpy() + symbolic = jf(q, k[:, :vi], v[:, :vi]).reshape(2, 4, 1, 8).numpy() + expected = f(q, k[:, :i], v[:, :i]).numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) assert_jit_cache_len(jf, 5) def test_cat_dim0(self): def f(a, b): return a.cat(b, dim=0).realize() jf = TinyJit(f) + a = Tensor.rand(10, 3) + b = Tensor.rand(2, 3) for i in range(1, 5): vi = Variable("i", 1, 10).bind(i) - a = Tensor.rand(i, 3) - b = Tensor.rand(2, 3) - symbolic = jf(a.reshape(vi, 3), b).reshape(i+2, 3).numpy() - expected = f(a, b).numpy() + symbolic = jf(a[:vi], b).reshape(i+2, 3).numpy() + expected = f(a[:i], b).numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) assert_jit_cache_len(jf, 1) def test_cat_dim1(self): def f(a, b): return a.cat(b, dim=1).realize() jf = TinyJit(f) + a = Tensor.rand(3, 10) + b = 
Tensor.rand(3, 2) for i in range(1, 5): vi = Variable("i", 1, 10).bind(i) - a = Tensor.rand(3, i) - b = Tensor.rand(3, 2) - symbolic = jf(a.reshape(3, vi), b).reshape(3, i+2).numpy() - expected = f(a, b).numpy() + symbolic = jf(a[:, :vi], b).reshape(3, i+2).numpy() + expected = f(a[:, :i], b).numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) assert_jit_cache_len(jf, 1) def test_cat_dim0_two_vars(self): def f(a, b): return a.cat(b, dim=0).realize() jf = TinyJit(f) + a = Tensor.rand(10, 3) + b = Tensor.rand(10, 3) for i in range(1, 5): for j in range(1, 5): vi = Variable("i", 1, 10).bind(i) vj = Variable("j", 1, 10).bind(j) - a = Tensor.rand(i, 3) - b = Tensor.rand(j, 3) - symbolic = jf(a.reshape(vi, 3), b.reshape(vj, 3)).reshape(i+j, 3).numpy() - expected = f(a, b).numpy() + symbolic = jf(a[:vi], b[:vj]).reshape(i+j, 3).numpy() + expected = f(a[:i], b[:j]).numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) assert_jit_cache_len(jf, 1) def test_cat_dim1_two_vars(self): def f(a, b): return a.cat(b, dim=1).realize() jf = TinyJit(f) + a = Tensor.rand(3, 10) + b = Tensor.rand(3, 10) for i in range(1, 5): for j in range(1, 5): vi = Variable("i", 1, 10).bind(i) vj = Variable("j", 1, 10).bind(j) - a = Tensor.rand(3, i) - b = Tensor.rand(3, j) - symbolic = jf(a.reshape(3, vi), b.reshape(3, vj)).reshape(3, i+j).numpy() - expected = f(a, b).numpy() + symbolic = jf(a[:, :vi], b[:, :vj]).reshape(3, i+j).numpy() + expected = f(a[:, :i], b[:, :j]).numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) assert_jit_cache_len(jf, 1) def test_two_vars_plus1_ij(self): def f(a, b): return (a@b+1).realize() jf = TinyJit(f) + a = Tensor.rand(10, 3) + b = Tensor.rand(3, 10) for i in range(1, 5): for j in range(1, 5): vi = Variable("i", 1, 10).bind(i) vj = Variable("j", 1, 10).bind(j) - a = Tensor.rand(i, 3) - b = Tensor.rand(3, j) - symbolic = jf(a.reshape(vi, 3), b.reshape(3, vj)).reshape(i, j).numpy() - expected = f(a, b).numpy() + symbolic = jf(a[:vi, :], b[:, :vj]).reshape(i, j).numpy() + expected = f(a[:i, :], b[:, :j]).numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) assert_jit_cache_len(jf, 1) def test_two_vars_plus1_ji(self): def f(a, b): return (a@b+1).realize() jf = TinyJit(f) + a = Tensor.rand(10, 3) + b = Tensor.rand(3, 10) for i in range(1, 5): for j in range(1, 5): vi = Variable("i", 1, 10).bind(i) vj = Variable("j", 1, 10).bind(j) - a = Tensor.rand(j, 3) - b = Tensor.rand(3, i) - symbolic = jf(a.reshape(vj, 3), b.reshape(3, vi)).reshape(j, i).numpy() - expected = f(a, b).numpy() + symbolic = jf(a[:vj, :], b[:, :vi]).reshape(j, i).numpy() + expected = f(a[:j, :], b[:, :i]).numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) assert_jit_cache_len(jf, 1) def test_jit_symbolic_shape_mismatch(self): @TinyJit def add(a, b): return (a+b).realize() + a = Tensor.rand(3, 10) + b = Tensor.rand(3, 10) for i in range(1, 5): vi = Variable("i", 1, 10).bind(i) - a = Tensor.rand(3, i).reshape(3, vi) - b = Tensor.rand(3, i).reshape(3, vi) - add(a, b) + add(a[:, :vi], b[:, :vi]) vi2 = Variable("i", 1, 10).bind(7) - a = Tensor.rand(3, 7).reshape(3, vi2) - bad = Tensor.rand(4, 7).reshape(4, vi2) + a = Tensor.rand(3, 7)[:, :vi2] + bad = Tensor.rand(4, 7)[:, :vi2] with self.assertRaises(AssertionError): add(a, bad) @@ -175,9 +166,9 @@ class TestSymbolicJit(unittest.TestCase): # shrink is a movement, so we pair it with a simple function to test the JIT interaction def f(a): return (a+1).realize() jf 
= TinyJit(f) + a = Tensor.rand(7, 11) for i in range(1, 5): vi = Variable("i", 1, 10).bind(i) - a = Tensor.rand(7, 11) symbolic = a.shrink(((3,5),(vi,vi+2))) symbolic = jf(symbolic).numpy() expected = f(a.shrink(((3,5),(i,i+2)))).numpy() @@ -188,9 +179,9 @@ class TestSymbolicJit(unittest.TestCase): # slice is a movement, so we pair it with a simple function to test the JIT interaction def f(a): return (a+1).realize() jf = TinyJit(f) + a = Tensor.rand(7, 11) for i in range(1, 5): vi = Variable("i", 1, 10).bind(i) - a = Tensor.rand(7, 11) symbolic = a[3:5, vi:vi+2] symbolic = jf(symbolic).numpy() expected = f(a[3:5, i:i+2]).numpy() @@ -212,11 +203,11 @@ class TestSymbolicJit(unittest.TestCase): def test_ones_sum(self): def f(a): return a.sum().realize() jf = TinyJit(f) + t = Tensor.ones(10) for i in range(1, 5): vi = Variable("i", 1, 10).bind(i) - t = Tensor.ones(i) - symbolic = jf(t.reshape(vi)).item() - expected = f(t).item() + symbolic = jf(t[:vi]).item() + expected = f(t[:i]).item() np.testing.assert_equal(symbolic, expected) def test_mean(self): @@ -226,22 +217,22 @@ class TestSymbolicJit(unittest.TestCase): jf = TinyJit(f) jf0 = TinyJit(f0) jf1 = TinyJit(f1) + a = Tensor.rand(10, 3) + b = Tensor.rand(10, 3) + c = Tensor.rand(10, 3) for i in range(1, 5): vi = Variable("i", 1, 10).bind(i) - # aixs = None - a = Tensor.rand(i, 3) - symbolic = jf(a.reshape(vi, 3)).numpy() - expected = a.mean().numpy() + # axis = None + symbolic = jf(a[:vi]).numpy() + expected = a[:i].mean().numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) - # aixs = 0 - a = Tensor.rand(i, 3) - symbolic = jf0(a.reshape(vi, 3)).numpy() - expected = a.mean(0).numpy() + # axis = 0 + symbolic = jf0(b[:vi]).numpy() + expected = b[:i].mean(0).numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) - # aixs = 1 - a = Tensor.rand(i, 3) - symbolic = jf1(a.reshape(vi, 3)).reshape(i).numpy() - expected = a.mean(1).numpy() + # axis = 1 + symbolic = jf1(c[:vi]).reshape(i).numpy() + expected = c[:i].mean(1).numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) def test_mean_2d(self): @@ -251,24 +242,24 @@ class TestSymbolicJit(unittest.TestCase): jf = TinyJit(f) jf0 = TinyJit(f0) jf1 = TinyJit(f1) + a = Tensor.rand(10, 10) + b = Tensor.rand(10, 10) + c = Tensor.rand(10, 10) for i in range(1, 5): for j in range(1, 5): vi = Variable("i", 1, 10).bind(i) vj = Variable("j", 1, 10).bind(j) - # aixs = None - a = Tensor.rand(i, j) - symbolic = jf(a.reshape(vi, vj)).numpy() - expected = a.mean().numpy() + # axis = None + symbolic = jf(a[:vi, :vj]).numpy() + expected = a[:i, :j].mean().numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) - # aixs = 0 - a = Tensor.rand(i, j) - symbolic = jf0(a.reshape(vi, vj)).reshape(j).numpy() - expected = a.mean(0).numpy() + # axis = 0 + symbolic = jf0(b[:vi, :vj]).reshape(j).numpy() + expected = b[:i, :j].mean(0).numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) - # aixs = 1 - a = Tensor.rand(i, j) - symbolic = jf1(a.reshape(vi, vj)).reshape(i).numpy() - expected = a.mean(1).numpy() + # axis = 1 + symbolic = jf1(c[:vi, :vj]).reshape(i).numpy() + expected = c[:i, :j].mean(1).numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) def test_var(self): @@ -278,22 +269,22 @@ class TestSymbolicJit(unittest.TestCase): jf = TinyJit(f) jf0 = TinyJit(f0) jf1 = TinyJit(f1) + a = Tensor.rand(10, 3) + b = Tensor.rand(10, 3) + c = Tensor.rand(10, 3) for i in range(1, 5): vi = 
Variable("i", 1, 10).bind(i) - # aixs = None - a = Tensor.rand(i, 3) - symbolic = jf(a.reshape(vi, 3)).numpy() - expected = a.var().numpy() + # axis = None + symbolic = jf(a[:vi]).numpy() + expected = a[:i].var().numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) - # aixs = 0 - a = Tensor.rand(i, 3) - symbolic = jf0(a.reshape(vi, 3)).numpy() - expected = a.var(0).numpy() + # axis = 0 + symbolic = jf0(b[:vi]).numpy() + expected = b[:i].var(0).numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) - # aixs = 1 - a = Tensor.rand(i, 3) - symbolic = jf1(a.reshape(vi, 3)).reshape(i).numpy() - expected = a.var(1).numpy() + # axis = 1 + symbolic = jf1(c[:vi]).reshape(i).numpy() + expected = c[:i].var(1).numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) def test_var_2d(self): @@ -303,24 +294,24 @@ class TestSymbolicJit(unittest.TestCase): jf = TinyJit(f) jf0 = TinyJit(f0) jf1 = TinyJit(f1) + a = Tensor.rand(10, 10) + b = Tensor.rand(10, 10) + c = Tensor.rand(10, 10) for i in range(1, 5): for j in range(1, 5): vi = Variable("i", 1, 10).bind(i) vj = Variable("j", 1, 10).bind(j) - # aixs = None - a = Tensor.rand(i, j) - symbolic = jf(a.reshape(vi, vj)).numpy() - expected = a.var().numpy() + # axis = None + symbolic = jf(a[:vi, :vj]).numpy() + expected = a[:i, :j].var().numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) - # aixs = 0 - a = Tensor.rand(i, j) - symbolic = jf0(a.reshape(vi, vj)).reshape(j).numpy() - expected = a.var(0).numpy() + # axis = 0 + symbolic = jf0(b[:vi, :vj]).reshape(j).numpy() + expected = b[:i, :j].var(0).numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) - # aixs = 1 - a = Tensor.rand(i, j) - symbolic = jf1(a.reshape(vi, vj)).reshape(i).numpy() - expected = a.var(1).numpy() + # axis = 1 + symbolic = jf1(c[:vi, :vj]).reshape(i).numpy() + expected = c[:i, :j].var(1).numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) if __name__ == '__main__': diff --git a/test/test_symbolic_ops.py b/test/test_symbolic_ops.py index f626446315..8f4074da7d 100644 --- a/test/test_symbolic_ops.py +++ b/test/test_symbolic_ops.py @@ -1,7 +1,7 @@ import unittest from tinygrad import Tensor, Variable from tinygrad.shape.shapetracker import View -from tinygrad.helpers import Context, GlobalCounters +from tinygrad.helpers import GlobalCounters from tinygrad.uop.ops import sym_infer from tinygrad.dtype import dtypes from tinygrad.device import Device @@ -9,54 +9,46 @@ from examples.gpt2 import Attention import numpy as np class TestSymbolicOps(unittest.TestCase): - def setUp(self): - # A lot of these test are out of bounds, so we ignore the bounds check - self.context = Context(IGNORE_OOB=1) - self.context.__enter__() - - def tearDown(self): - self.context.__exit__(None, None, None) - def test_plus1(self): def f(a): return (a+1).realize() + a = Tensor.rand(3, 10) for i in range(1, 5): vi = Variable("i", 1, 10).bind(i) - a = Tensor.rand(3, i) - symbolic = f(a.reshape(3, vi)).reshape(3, i).numpy() - expected = f(a).numpy() + symbolic = f(a[:, :vi]).reshape(3, i).numpy() + expected = f(a[:, :i]).numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) def test_add(self): def f(a, b): return (a+b).realize() + a = Tensor.rand(3, 10) + b = Tensor.rand(3, 10) for i in range(1, 5): vi = Variable("i", 1, 10).bind(i) - a = Tensor.rand(3, i) - b = Tensor.rand(3, i) - symbolic = f(a.reshape(3, vi), b.reshape(3, vi)).reshape(3, i).numpy() - expected 
= f(a, b).numpy() + symbolic = f(a[:, :vi], b[:, :vi]).reshape(3, i).numpy() + expected = f(a[:, :i], b[:, :i]).numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) def test_matmul(self): def f(a, b): return (a@b).realize() + a = Tensor.rand(3, 10) + b = Tensor.rand(10, 5) for i in range(1, 5): vi = Variable("i", 1, 10).bind(i) - a = Tensor.rand(3, i) - b = Tensor.rand(i, 5) - symbolic = f(a.reshape(3, vi), b.reshape(vi, 5)).numpy() - expected = f(a, b).numpy() + symbolic = f(a[:, :vi], b[:vi, :]).numpy() + expected = f(a[:, :i], b[:i, :]).numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) def test_attention(self, dropout_p=0.0, imin=1, imax=5, use_symbolic=True): def f(q, k, v): return Tensor.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), dropout_p=dropout_p).realize() + q = Tensor.rand(2, 1, 4, 8) + k = Tensor.rand(2, 10, 4, 8) + v = Tensor.rand(2, 10, 4, 8) for i in range(imin, imax): vi = Variable("i", 1, 10).bind(i) if use_symbolic else i - q = Tensor.rand(2, 1, 4, 8) - k = Tensor.rand(2, i, 4, 8) - v = Tensor.rand(2, i, 4, 8) Tensor.realize(q, k, v) GlobalCounters.reset() - symbolic = f(q, k.reshape(2, vi, 4, 8), v.reshape(2, vi, 4, 8)).reshape(2, 4, 1, 8).numpy() - expected = f(q, k, v).numpy() + symbolic = f(q, k[:, :vi, :, :], v[:, :vi, :, :]).reshape(2, 4, 1, 8).numpy() + expected = f(q, k[:, :i, :, :], v[:, :i, :, :]).numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) def test_attention_cmp_symbolic(self): @@ -90,73 +82,89 @@ class TestSymbolicOps(unittest.TestCase): def test_cat_dim0(self): def f(a, b): return a.cat(b, dim=0).realize() + a = Tensor.rand(10, 3) for i in range(1, 5): vi = Variable("i", 1, 10).bind(i) - a = Tensor.rand(i, 3) b = Tensor.rand(2, 3) - symbolic = f(a.reshape(vi, 3), b).reshape(i+2, 3).numpy() - expected = f(a, b).numpy() + symbolic = f(a[:vi, :], b).reshape(i+2, 3).numpy() + expected = f(a[:i, :], b).numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) def test_cat_dim1(self): def f(a, b): return a.cat(b, dim=1).realize() + a = Tensor.rand(3, 10) for i in range(1, 5): vi = Variable("i", 1, 10).bind(i) - a = Tensor.rand(3, i) b = Tensor.rand(3, 2) - symbolic = f(a.reshape(3, vi), b).reshape(3, i+2).numpy() - expected = f(a, b).numpy() + symbolic = f(a[:, :vi], b).reshape(3, i+2).numpy() + expected = f(a[:, :i], b).numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) def test_cat_dim0_two_vars(self): def f(a, b): return a.cat(b, dim=0).realize() + a = Tensor.rand(10, 3) + b = Tensor.rand(10, 3) for i in range(1, 5): for j in range(1, 5): vi = Variable("i", 1, 10).bind(i) vj = Variable("j", 1, 10).bind(j) - a = Tensor.rand(i, 3) - b = Tensor.rand(j, 3) - symbolic = f(a.reshape(vi, 3), b.reshape(vj, 3)).reshape(i+j, 3).numpy() - expected = f(a, b).numpy() + symbolic = f(a[:vi, :], b[:vj, :]).reshape(i+j, 3).numpy() + expected = f(a[:i, :], b[:j, :]).numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) def test_cat_dim1_two_vars(self): def f(a, b): return a.cat(b, dim=1).realize() + a = Tensor.rand(3, 10) + b = Tensor.rand(3, 10) for i in range(1, 5): for j in range(1, 5): vi = Variable("i", 1, 10).bind(i) vj = Variable("j", 1, 10).bind(j) - a = Tensor.rand(3, i) - b = Tensor.rand(3, j) - symbolic = f(a.reshape(3, vi), b.reshape(3, vj)).reshape(3, i+j).numpy() - expected = f(a, b).numpy() + symbolic = f(a[:, :vi], b[:, :vj]).reshape(3, i+j).numpy() + expected = f(a[:, :i], 
b[:, :j]).numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) def test_two_vars_plus1_ij(self): def f(a, b): return (a@b+1).realize() + a = Tensor.rand(10, 3) + b = Tensor.rand(3, 10) for i in range(1, 5): for j in range(1, 5): vi = Variable("i", 1, 10).bind(i) vj = Variable("j", 1, 10).bind(j) - a = Tensor.rand(i, 3) - b = Tensor.rand(3, j) - symbolic = f(a.reshape(vi, 3), b.reshape(3, vj)).reshape(i, j).numpy() - expected = f(a, b).numpy() + symbolic = f(a[:vi, :], b[:, :vj]).reshape(i, j).numpy() + expected = f(a[:i, :], b[:, :j]).numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) def test_two_vars_plus1_ji(self): # reverse the order of variables def f(a, b): return (a@b+1).realize() + a = Tensor.rand(10, 3) + b = Tensor.rand(3, 10) for i in range(1, 5): for j in range(1, 5): vi = Variable("i", 1, 10).bind(i) vj = Variable("j", 1, 10).bind(j) - a = Tensor.rand(j, 3) - b = Tensor.rand(3, i) - symbolic = f(a.reshape(vj, 3), b.reshape(3, vi)).reshape(j, i).numpy() - expected = f(a, b).numpy() + symbolic = f(a[:vj, :], b[:, :vi]).reshape(j, i).numpy() + expected = f(a[:j, :], b[:, :i]).numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) + def test_reshape_from_symbolic(self): + a = Tensor.rand(30) + for i in range(3, 5): + vi = Variable("i", 3, 10).bind(i) + symbolic = a[:vi*3].reshape((3, 3)).numpy() + # To match symbolic reshape (potential implicit shrink), we need a shrink + expected = a[:i*3].shrink(((0, 9),)).reshape((3, 3)).numpy() + np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) + + def test_invalid_symbolic_reshape(self): + a = Tensor.rand(30) + for i in range(1, 5): + vi = Variable("i", 1, 10).bind(i) + # Cannot reshape into symbolic from non-symbolic + with self.assertRaises(AssertionError): a.reshape((3, vi)) + def test_shrink(self): for i in range(1, 5): vi = Variable("i", 1, 10).bind(i) @@ -176,11 +184,10 @@ class TestSymbolicOps(unittest.TestCase): np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) def test_slice_no_start(self): + a = Tensor.rand(7, 11) for i in range(1, 5): vi = Variable("i", 1, 10).bind(i) - a = Tensor.rand(7, 11) - symbolic = a[3:5, :vi:1].reshape(2,i) - symbolic = symbolic.numpy() + symbolic = a[3:5, :vi:1].reshape(2, i).numpy() expected = a[3:5, :i:1].numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) @@ -201,31 +208,31 @@ class TestSymbolicOps(unittest.TestCase): np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) def test_ones_sum(self): + t = Tensor.ones(10) for i in range(1, 5): vi = Variable("i", 1, 10).bind(i) - t = Tensor.ones(i) - symbolic = t.reshape(vi).sum().item() - expected = t.sum().item() + symbolic = t[:vi].sum().item() + expected = t[:i].sum().item() np.testing.assert_equal(symbolic, expected) def test_mean(self): + a = Tensor.rand(10, 3) for i in range(1, 5): vi = Variable("i", 1, 10).bind(i) for axis in [None, 0, 1]: - a = Tensor.rand(i, 3) - expected = a.mean(axis).numpy() - symbolic = a.reshape(vi, 3).mean(axis).reshape(expected.shape).numpy() + expected = a[:i].mean(axis).numpy() + symbolic = a[:vi].mean(axis).reshape(expected.shape).numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) def test_mean_2d(self): + a = Tensor.rand(10, 10) for i in range(1, 5): for j in range(1, 5): vi = Variable("i", 1, 10).bind(i) vj = Variable("j", 1, 10).bind(j) for axis in [None, 0, 1]: - a = Tensor.rand(i, j) - expected = a.mean(axis).numpy() - symbolic = a.reshape(vi, 
vj).mean(axis).reshape(expected.shape).numpy() + expected = a[:i, :j].mean(axis).numpy() + symbolic = a[:vi, :vj].mean(axis).reshape(expected.shape).numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) def test_var(self): @@ -233,43 +240,43 @@ class TestSymbolicOps(unittest.TestCase): for i in range(1, 5): vi = Variable("i", 1, 10).bind(i) for axis in [None, 0, 1]: - expected = a[:i, :].var(axis).numpy() - symbolic = a[:vi, :].var(axis).reshape(expected.shape).numpy() + expected = a[:i].var(axis).numpy() + symbolic = a[:vi].var(axis).reshape(expected.shape).numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) def test_var_2d(self): + a = Tensor.rand(10, 10) for i in range(1, 5): for j in range(1, 5): vi = Variable("i", 1, 10).bind(i) vj = Variable("j", 1, 10).bind(j) for axis in [None, 0, 1]: - a = Tensor.rand(i, j) - expected = a.var(axis).numpy() - symbolic = a.reshape(vi, vj).var(axis).reshape(expected.shape).numpy() + expected = a[:i, :j].var(axis).numpy() + symbolic = a[:vi, :vj].var(axis).reshape(expected.shape).numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) def test_bitcast_down(self): + a = Tensor.rand(10, 3) for i in range(1, 5): vi = Variable("i", 1, 10).bind(i) - a = Tensor.rand(i, 3) - expected = a.bitcast(dtypes.uint8).numpy() - symbolic = a.reshape(vi, 3).bitcast(dtypes.uint8).reshape(expected.shape).numpy() + expected = a[:i].bitcast(dtypes.uint8).numpy() + symbolic = a[:vi].bitcast(dtypes.uint8).reshape(expected.shape).numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=0) @unittest.skipIf(Device.DEFAULT == "WEBGPU", "no uint64") def test_bitcast_up(self): + a = Tensor.rand(10, 4) for i in range(1, 5): vi = Variable("i", 1, 10).bind(i) - a = Tensor.rand(i, 4) - expected = a.bitcast(dtypes.uint64).numpy() - symbolic = a.reshape(vi, 4).bitcast(dtypes.uint64).reshape(expected.shape).numpy() + expected = a[:i].bitcast(dtypes.uint64).numpy() + symbolic = a[:vi].bitcast(dtypes.uint64).reshape(expected.shape).numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=0) @unittest.expectedFailure def test_conv2d_ceildiv_edge_case(self): v = Variable('v', 11, 50_000) val = 39601 - x = Tensor.randn(1, 22, 39601).reshape(1, 22, v.bind(val)) + x = Tensor.randn(1, 22, 50_000)[:, :, :v.bind(val)] weight = Tensor.randn(256, 22, 12) result = x.conv2d(weight=weight, groups=1, stride=6, dilation=1, padding=(3, 3)) diff --git a/test/test_tensor_variable.py b/test/test_tensor_variable.py index 85ef0c5104..0fa165e462 100644 --- a/test/test_tensor_variable.py +++ b/test/test_tensor_variable.py @@ -1,7 +1,6 @@ import unittest import numpy as np from tinygrad import Tensor, Variable -from tinygrad.helpers import Context class TestTensorVariable(unittest.TestCase): def test_add_tvar(self): @@ -23,43 +22,38 @@ class TestTensorVariable(unittest.TestCase): assert (Tensor(3) * (vv * 4)).item() == 24 def test_symbolic_mean(self): - with Context(IGNORE_OOB=1): - vv = Variable("a", 1, 10).bind(2) - t = Tensor.ones(2, 2).contiguous().reshape(2, vv) - ret = t.mean().item() - assert ret == 1 + vv = Variable("a", 1, 10).bind(2) + t = Tensor.ones(2, 10).contiguous()[:, :vv] + ret = t.mean().item() + assert ret == 1 def test_symbolic_mean_2d(self): - with Context(IGNORE_OOB=1): - vv = Variable("a", 1, 10).bind(2) - vv2 = Variable("b", 1, 10).bind(2) - t = Tensor.ones(2, 2).contiguous().reshape(vv2, vv) - ret = t.mean().item() - assert ret == 1 + vv = Variable("a", 1, 10).bind(2) + vv2 = Variable("b", 1, 
10).bind(2) + t = Tensor.ones(10, 10).contiguous()[:vv2, :vv] + ret = t.mean().item() + assert ret == 1 def test_symbolic_mean_2d_axis_1(self): - with Context(IGNORE_OOB=1): - vv = Variable("a", 1, 10).bind(2) - vv2 = Variable("b", 1, 10).bind(2) - t = Tensor.ones(2, 2).contiguous().reshape(vv2, vv) - ret = t.mean(axis=1).reshape(2, 1).numpy() - assert np.all(ret == 1) + vv = Variable("a", 1, 10).bind(2) + vv2 = Variable("b", 1, 10).bind(2) + t = Tensor.ones(10, 10).contiguous()[:vv2, :vv] + ret = t.mean(axis=1).reshape(2, 1).numpy() + assert np.all(ret == 1) def test_symbolic_mean_2d_add(self): - with Context(IGNORE_OOB=1): - add_term = Variable("c", 0, 10).bind(1) - vv = Variable("a", 1, 10).bind(1) - vv2 = Variable("b", 1, 10).bind(1) - t = Tensor.ones(2, 2).contiguous().reshape(vv2+add_term, vv+add_term) - ret = t.mean().item() - assert ret == 1 + add_term = Variable("c", 0, 10).bind(1) + vv = Variable("a", 1, 10).bind(1) + vv2 = Variable("b", 1, 10).bind(1) + t = Tensor.ones(20, 20).contiguous()[:vv2+add_term, :vv+add_term] + ret = t.mean().item() + assert ret == 1 def test_symbolic_var(self): - with Context(IGNORE_OOB=1): - vv = Variable("a", 1, 10).bind(2) - t = Tensor.ones(2, 2).contiguous().reshape(2, vv) - ret = t.var().item() - assert ret == 0 + vv = Variable("a", 1, 10).bind(2) + t = Tensor.ones(2, 10).contiguous()[:, :vv] + ret = t.var().item() + assert ret == 0 def test_symbolic_pad(self): vv = Variable("a", 1, 10).bind(2) diff --git a/test/test_uops.py b/test/test_uops.py index dfd6a7f967..7fefc02273 100644 --- a/test/test_uops.py +++ b/test/test_uops.py @@ -447,7 +447,7 @@ class TestUOpMethod(unittest.TestCase): def test_uop_variables(self): a = UOp.variable("a", 1, 10) uop_var = Tensor(a.bind(1)) - st_var = Tensor.empty((2, 1)).reshape((2, a.bind(1))) + st_var = Tensor.empty((2, 10))[:, :a.bind(1)] _, var_vals = (uop_var+st_var).schedule_with_vars() self.assertEqual(len(var_vals), 1) self.assertEqual(list(var_vals)[0], a) diff --git a/test/unit/test_shapetracker.py b/test/unit/test_shapetracker.py index bf2bf3d36c..a3cfc80687 100644 --- a/test/unit/test_shapetracker.py +++ b/test/unit/test_shapetracker.py @@ -839,25 +839,22 @@ class TestRender(unittest.TestCase): self.assertEqual(idx.render(), "((ridx0*3)+ridx1)") self.assertEqual(valid.render(), "(ridx0<2)") -class TestVariableReshape(unittest.TestCase): - def test_reshape(self): - st = ShapeTracker.from_shape((3,)) - st = st.reshape((Variable("i", 1, 10),)) +class TestVariableShrink(unittest.TestCase): + def test_shrink(self): + st = ShapeTracker.from_shape((10,)) + st = st.shrink(((0, Variable("i", 1, 10)),)) assert len(st.views) == 1 - def test_reshape_stride_0(self): - st = ShapeTracker.from_shape((3,), (0,)) - st = st.reshape((Variable("i", 1, 10).bind(3),)) - assert len(st.views) == 1, f"multiview {st}" - - def test_reshape_bound(self): - st = ShapeTracker.from_shape((3,)) - st = st.reshape((Variable("i", 1, 10).bind(3),)) + def test_shrink_bound(self): + st = ShapeTracker.from_shape((10,)) + st = st.shrink(((0, Variable("i", 1, 10).bind(3)),)) assert len(st.views) == 1 - def test_add(self): - st1 = ShapeTracker.from_shape((3,)) - st2 = ShapeTracker.from_shape((Variable("i", 1, 10),)) +class TestVariableMerge(unittest.TestCase): + def test_add_reshape(self): + vi = Variable("i", 1, 10) + st1 = ShapeTracker.from_shape((vi,)) + st2 = ShapeTracker.from_shape((1, vi,)) st = st1+st2 assert len(st.views) == 1 @@ -867,15 +864,17 @@ class TestVariableReshape(unittest.TestCase): st = st1+st2 assert len(st.views) == 1, 
f"multiview {st}" - def test_add_bound(self): - st1 = ShapeTracker.from_shape((3,)) - st2 = ShapeTracker.from_shape((Variable("i", 1, 10).bind(3),)) + def test_add_reshape_bound(self): + vi = Variable("i", 1, 10).bind(3) + st1 = ShapeTracker.from_shape((vi,)) + st2 = ShapeTracker.from_shape((1, vi,)) st = st1+st2 assert len(st.views) == 1 def test_simplify(self): - st1 = ShapeTracker.from_shape((3,)) - st2 = ShapeTracker.from_shape((Variable("i", 1, 10).bind(3),)) + vi = Variable("i", 1, 10).bind(3) + st1 = ShapeTracker.from_shape((vi,)) + st2 = ShapeTracker.from_shape((1, vi,)) st = ShapeTracker((st1.views[0], st2.views[0])) st = st.simplify() assert len(st.views) == 1 diff --git a/test/unit/test_shapetracker_math.py b/test/unit/test_shapetracker_math.py index efd017f509..3a74ae30b1 100644 --- a/test/unit/test_shapetracker_math.py +++ b/test/unit/test_shapetracker_math.py @@ -87,20 +87,6 @@ class TestShapeTrackerAdd(unittest.TestCase): assert not (st_equal(st1, st2)) class TestShapeTrackerAddVariable(unittest.TestCase): - def test_self_add(self): - j = Variable("j", 0, 20).bind(10) - a = ShapeTracker.from_shape((10,10)) - x = a.reshape((10, j)) - out = x + x - assert out == x - - def test_self_add_reshape(self): - j = Variable("j", 0, 20).bind(10) - a = ShapeTracker.from_shape((10,10)) - x = a.reshape((10, j)) - out = x.reshape((5, 2, j)) + x - assert out == x - def test_merge_symbolic_views(self): var_i = Variable('i', 1, 10) var_j = Variable('i', 1, 10) diff --git a/test/unit/test_symbolic_shapetracker.py b/test/unit/test_symbolic_shapetracker.py index ed065d0576..0af8820708 100644 --- a/test/unit/test_symbolic_shapetracker.py +++ b/test/unit/test_symbolic_shapetracker.py @@ -48,11 +48,11 @@ class TestSymbolic(unittest.TestCase): i = Variable("i", 1, 5).bind(3) j = Variable("j", 1, 5).bind(3) k = Variable("k", 1, 5).bind(3) - t = Tensor.rand(3, 4).reshape(i, 4).cat(Tensor.rand(3, 4).reshape(j, 4), dim=0).cat(Tensor.rand(3, 4).reshape(k, 4), dim=0) + t = Tensor.rand(5, 4)[:i].cat(Tensor.rand(5, 4)[:j], dim=0).cat(Tensor.rand(5, 4)[:k], dim=0) st = t.uop.st self.assert_tuple_equal(st.shape, (i+j+k, 4)) assert st.real_strides() == (4, 1) - t = Tensor.rand(3, 3).reshape(i, 3).cat(Tensor.rand(3, 3).reshape(i, 3), dim=0).cat(Tensor.rand(3, 3), dim=0) + t = Tensor.rand(5, 3)[:i].cat(Tensor.rand(5, 3)[:i], dim=0).cat(Tensor.rand(3, 3), dim=0) st = t.uop.st self.assert_tuple_equal(st.shape, (2*i+3, 3)) assert st.real_strides() == (3, 1) @@ -61,7 +61,7 @@ class TestSymbolic(unittest.TestCase): i = Variable("i", 1, 5).bind(4) j = Variable("j", 1, 5).bind(4) k = Variable("k", 1, 5).bind(4) - t = Tensor.rand(3, 4).reshape(3, i).cat(Tensor.rand(3, 4).reshape(3, j), dim=1).cat(Tensor.rand(3, 4).reshape(3, k), dim=1) + t = Tensor.rand(3, 5)[:, :i].cat(Tensor.rand(3, 5)[:, :j], dim=1).cat(Tensor.rand(3, 5)[:, :k], dim=1) st = t.uop.st self.assert_tuple_equal(st.shape, (3, i+j+k)) self.assert_tuple_equal(st.real_strides(), (i+j+k, 1)) @@ -109,60 +109,44 @@ class TestShapeTrackerUnbind(unittest.TestCase): assert unbound_view == View.create(shape=(v, 4)) assert var_val == {v: 3} - def test_reshape_unbind(self): - v = Variable("v", 1, 100) - bv = Variable("v", 1, 100).bind(3) - t = Tensor.rand(3, 4).reshape(bv, 4) - unbound_st, var_val = t.uop.st.unbind() - assert unbound_st == ShapeTracker((View.create(shape=(v, 4)),)) - assert var_val == {v: 3} - def test_shrink_unbind(self): v = Variable("v", 1, 100) bv = Variable("v", 1, 100).bind(2) + t = Tensor.rand(3, 4).shrink(((0,bv),(0,4))) + unbound_st, var_val = 
t.uop.st.unbind() + assert unbound_st == ShapeTracker((View.create(shape=(v, 4)),)) + assert var_val == {v: 2} t = Tensor.rand(3, 4).shrink(((bv, bv+1), (0, 4))) unbound_st, var_val = t.uop.st.unbind() assert unbound_st == ShapeTracker((View.create(shape=(1, 4), offset=4*v),)) assert var_val == {v: 2} -class TestSymbolicReshapeFromContiguous(unittest.TestCase): - def test_reshape_into_symbols_simple(self): +class TestSymbolicReshape(unittest.TestCase): + def test_reshape(self): + a = Tensor.rand(5, 4) + b = Tensor.rand(5, 6) for i in range(1, 6): vi = Variable("i", 1, 5).bind(i) - t = Tensor.rand(i, 4).reshape(vi, 4) - assert t.shape == (vi, 4) - t = Tensor.rand(i, 6).reshape(vi, 2, 3) - assert t.shape == (vi, 2, 3) - - def test_reshape_symbols_reshape_ints(self): - for i in range(1, 6): - vi = Variable("i", 1, 5).bind(i) - t = Tensor.rand(i, 4).reshape(vi, 4) - assert t.shape == (vi, 4) - t = t.reshape(i, 4) - assert t.shape == (i, 4) - - @unittest.skip("works now") - def test_reshape_into_symbols_bad_shape(self): - vi = Variable("i", 1, 10).bind(4) - # TODO: this never actually worked, it relied on lazy - #with self.assertRaises(ValueError): - # Tensor.rand(4, 6).reshape(vi, 6).reshape(1, 77) # reshape to a different size new shape through symbolic shape - with self.assertRaises(AssertionError): - Tensor.rand(3, 4).reshape(3, (vi+1)) # reshape into non-Variable Node + ret = a[:vi] + ret = ret.reshape((vi, 4)) + assert ret.shape == (vi, 4) + ret = b[:vi] + ret = ret.reshape((vi, 2, 3)) + assert ret.shape == (vi, 2, 3) def test_two_symbol_reshape(self): + t = Tensor.rand(5, 5) for i in range(1, 6): for j in range(1, 6): vi = Variable("i", 1, 5).bind(i) vj = Variable("j", 1, 5).bind(j) - t = Tensor.rand(i, j).reshape(vi, vj) - assert t.shape == (vi, vj) - # NOTE: this is currently not allowed - # t = t.reshape(1, vi*vj) - # assert t.shape == (1, vi*vj) - t = t.reshape(vj, vi) - assert t.shape == (vj, vi) + ret = t[:vi, :vj] + ret = ret.reshape(vj, vi) + assert ret.shape == (vj, vi) + ret = ret.reshape(vi, vj) + assert ret.shape == (vi, vj) + ret = ret.reshape(1, vi*vj) + assert ret.shape == (1, vi*vj) def test_symbolic_mask(self): # taken from gpt2 single kvcache @@ -175,41 +159,6 @@ class TestSymbolicReshapeFromContiguous(unittest.TestCase): new_shape = (2, (Variable('start_pos', 1, 128)+1), 16, 64) assert view.reshape(new_shape) is None -class TestSymbolicReshapeFromNonContiguous(unittest.TestCase): - def test_reshape_from_const(self): - vi = Variable("i", 1, 5).bind(4) - t = Tensor.ones(3, 4).reshape(3, vi) - assert t.shape == (3, vi) - assert not t.uop.st.contiguous - assert len(t.uop.st.views) == 1 - - def test_reshape_not_allowed(self): - vi = Variable("i", 1, 5).bind(4) - with self.assertRaises(ValueError): - # different shape length # TODO: cases where contractions matched might be fine - Tensor.ones(3, 4, 1).reshape(3, vi) - with self.assertRaises(ValueError): - # size matched, but dimensions do not match - Tensor.ones(4, 3).reshape(3, vi) - - def test_reshape_from_padded(self): - vi = Variable("i", 1, 5).bind(4) - t = Tensor.ones(3, 4).contiguous().expand(2, 3, 4).pad(((1, 1), None, None)).shrink((None, None, (1, 3))) - st = t.uop.st - assert len(st.views) == 1 - view = st.views[0] - assert view.shape == (4, 3, 2) - t = t.reshape(vi, 3, 2) - st2 = t.uop.st - assert len(st2.views) == 1 - view2 = st2.views[0] - # check only shape changed. 
strides, offset, mask, contiguous remained the same - assert view2.shape == (vi, 3, 2) - assert view.strides == view2.strides == (0, 4, 1) - assert view.offset == view2.offset == 1 - assert view.mask == view2.mask == ((1, 3), (0, 3), (0, 2)) - assert not view.contiguous and not view2.contiguous - class TestSymbolicExpand(unittest.TestCase): def test_expand_into_symbols(self): vi = Variable("i", 1, 5).bind(3) @@ -220,11 +169,12 @@ class TestSymbolicExpand(unittest.TestCase): assert a.shape == (3, vi, vj) def test_plus_expands_constant(self): + a = Tensor.rand(3, 5) for i in range(1, 6): vi = Variable("i", 1, 5).bind(i) - a = Tensor.rand(3, i).reshape(3, vi) - a = a + 1 - self.assertTupleEqual(a.shape, (3, vi)) + ret = a[:, :vi] + ret = ret + 1 + self.assertTupleEqual(ret.shape, (3, vi)) def test_pad_then_expand_into_symbols(self): vi = Variable("i", 1, 10).bind(3) @@ -234,6 +184,11 @@ class TestSymbolicExpand(unittest.TestCase): self.assertEqual(a.reshape(vi*25).shape, (vi*25,)) class TestSymbolicShrink(unittest.TestCase): + def test_shrink_symbols_simple(self): + vi = Variable("i", 1, 5) + t = Tensor.rand(5, 5).shrink(((0, 5),(0,vi))) + assert t.shape == (5, vi) + def test_shrink_symbols(self): vi = Variable("i", 1, 5) t = Tensor.rand(3, 5).shrink(((0, 2), (vi, vi+1))) @@ -242,10 +197,10 @@ class TestSymbolicShrink(unittest.TestCase): class TestSymbolicPad(unittest.TestCase): def test_pad(self): v = Variable("v", 1, 100).bind(5) - t = Tensor.ones(5).reshape(v).pad(((4, 0),)).reshape(9) - assert t.shape == (9,) - st = t.uop.st - print(st) + t = Tensor.ones(100)[:v].pad(((4, 0),)) + t = t.reshape(9) + assert t.tolist() == [0,0,0,0,1,1,1,1,1] + if __name__ == '__main__': unittest.main() diff --git a/test/unit/test_tensor_uop_representation.py b/test/unit/test_tensor_uop_representation.py index 59d28a4345..a1b2f0526d 100644 --- a/test/unit/test_tensor_uop_representation.py +++ b/test/unit/test_tensor_uop_representation.py @@ -97,7 +97,7 @@ class TestTensorUopRepresentation(unittest.TestCase): is_pattern(a, UPat(Ops.RESHAPE, src=(UPat(Ops.BUFFER),))) vi = UOp.variable("i", 1, 3).bind(1) a = Tensor.empty(3, vi) - is_pattern(a, UPat(Ops.RESHAPE, src=(UPat(Ops.BUFFER),))) + is_pattern(a, UPat(Ops.RESHAPE, src=(UPat(Ops.SHRINK, src=(UPat(Ops.BUFFER),))),)) self.assertEqual(a.uop.base.buffer.size, 9) if __name__ == '__main__': diff --git a/tinygrad/schedule/kernelize.py b/tinygrad/schedule/kernelize.py index d02ec8fc0e..b1e90bc60d 100644 --- a/tinygrad/schedule/kernelize.py +++ b/tinygrad/schedule/kernelize.py @@ -119,7 +119,7 @@ def create_kernel(x:UOp, b:UOp|None=None): if b is None: b = UOp.new_buffer(x.device, x.size, x.dtype) kernel = UOp(Ops.KERNEL, src=(b,)+x.src, arg=Kernel(x.sink(), m if (m:=x.metadata) else ())) buffer = b.base if b.size == b.base.size else UOp(Ops.BUFFER_VIEW, b.dtype, (b.base,), (b.size, b.arg.views[0].offset)) - return buffer.assign(kernel).reshape(x.shape) + return buffer.assign(kernel).shrink(((0, prod(x.shape)),)).reshape(x.shape) DONT_PLACE_IN_KERNEL = {Ops.KERNEL, Ops.ASSIGN, Ops.BUFFER, Ops.MSELECT, Ops.MSTACK, Ops.MULTI, Ops.BIND} def append_to_kernel(x:UOp): diff --git a/tinygrad/schedule/rangeify.py b/tinygrad/schedule/rangeify.py index 8db1ed4f8b..83578efb71 100644 --- a/tinygrad/schedule/rangeify.py +++ b/tinygrad/schedule/rangeify.py @@ -196,7 +196,8 @@ def map_contiguous(ctx:RangeifyContext, x:UOp): ranges = [] for s in x.shape[len(x.src)-1:]: ranges.append(ctx.new_range(s) if resolve(s!=1) else UOp.const(dtypes.int, 0)) - return 
x.src[0].index(*ranges).bufferize(*x.src[1:], *[x for x in ranges if x.op is not Ops.CONST], arg=x.device).forced_reshape(x.shape) + ret = x.src[0].index(*ranges).bufferize(*x.src[1:], *[x for x in ranges if x.op is not Ops.CONST], arg=x.device) + return ret.shrink(((0, prod(x.shape)),)).forced_reshape(x.shape) def map_reduce(ctx:RangeifyContext, idx:UOp, red:UOp): rngs = list(idx.src[1:]) diff --git a/tinygrad/shape/view.py b/tinygrad/shape/view.py index 0475fc6506..fcef6265ae 100644 --- a/tinygrad/shape/view.py +++ b/tinygrad/shape/view.py @@ -3,7 +3,7 @@ import functools, operator, itertools from dataclasses import dataclass from typing import cast, Sequence from tinygrad.dtype import dtypes -from tinygrad.uop.ops import resolve, UOp, Variable, sint, sym_infer, smax, smin, sint_to_uop, Ops, ssimplify +from tinygrad.uop.ops import resolve, UOp, Variable, sint, smax, smin, sint_to_uop, Ops, ssimplify from tinygrad.helpers import prod, all_int, argsort, flatten, ceildiv # returns the axes to create new_shape if new_shape can be created by combining axis from old_shape @@ -311,9 +311,10 @@ class View: if not all(x >= 0 for x in new_shape): raise ValueError(f"shape can't contain negative numbers {new_shape}") # check for the same size - if (self_all_int := all_int(self.shape)): - assert all(isinstance(s, (int, UOp)) for s in new_shape), f"{self.shape=} -> {new_shape=} contains non (int, Variable) dim" - if resolve(prod(self.shape) != prod(new_shape), False): raise ValueError(f"size mismatched, can't reshape {self.shape=} -> {new_shape=}") + if all_int(self.shape): + # reshapes cannot introduce symbolic shape + assert all_int(new_shape), f"{self.shape=} -> {new_shape=} contains non int dims" + if prod(self.shape) != prod(new_shape): raise ValueError(f"size mismatched, can't reshape {self.shape=} -> {new_shape=}") if 0 in self.shape: return View.create(new_shape) if new_shape == () and self.mask and any(mx==my for (mx,my) in self.mask): return None @@ -321,15 +322,6 @@ class View: # after the asserts, it's okay to check contiguous if self.contiguous: return View.create(new_shape) - # if it's not contiguous and new shape is symbolic, check if it's directly replaceable - if self_all_int and not all_int(new_shape): - if len(self.shape) != len(new_shape): raise ValueError(f"cannot symbolic reshape non-contiguous {self} -> {new_shape}") - for si, so in zip(self.shape, new_shape): - if not isinstance(so, int): so = sym_infer(so, dict([v.unbind() for v in so.vars()])) - if si != so: raise ValueError(f"cannot symbolic reshape non-contiguous {self} -> {new_shape}") - # all dimensions matched, return the new view directly - return View(new_shape, self.strides, self.offset, self.mask, self.contiguous) - r_strides, r_new_shape = [], reversed(new_shape) for merged_size, new_stride, real_size in reversed(merge_dims(self.shape, self.strides, self.mask)): # TODO: write with get_contraction diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 1b422aa1cc..0754af30dd 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -442,7 +442,7 @@ class Tensor(MathTrait): if not isinstance(size:=prod([x.vmax if isinstance(x, UOp) else x for x in shape]), int): raise ValueError(f"size must be int {size}") # TODO: add test for multidevice tensor device = tuple(Device.canonicalize(d) for d in device) if isinstance(device, tuple) else Device.canonicalize(device) - return Tensor(UOp.new_buffer(device, size, dtype), device, dtype, **kwargs).reshape(shape) + return Tensor(UOp.new_buffer(device, size, dtype), device, 
dtype, **kwargs).shrink(((0,prod(shape)),)).reshape(shape) @staticmethod def from_blob(ptr:int, shape:tuple[int, ...], **kwargs) -> Tensor:
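
Usage sketch (not taken verbatim from the patch; built only from the Tensor/Variable calls exercised in the tests above): the pattern this patch establishes is that a symbolic dimension is no longer registered by reshaping a concretely-shaped tensor into a bound Variable, but by allocating a buffer at the Variable's upper bound and then shrinking/slicing it with the bound Variable, which is also why the Context(IGNORE_OOB=1) wrappers disappear from the tests.

    from tinygrad import Tensor, Variable

    i = 3
    vi = Variable("i", 1, 10).bind(i)

    # old idiom (removed by this patch): reshape a size-i buffer into a symbolic shape
    #   a = Tensor.rand(3, i).reshape(3, vi)   # needed Context(IGNORE_OOB=1) in many tests

    # new idiom: allocate at the Variable's upper bound, then slice with the bound Variable
    a = Tensor.rand(3, 10)                     # buffer sized to vi's upper bound (10)
    sym = a[:, :vi]                            # symbolic shape (3, vi), same data as a[:, :i]
    out = (sym + 1).reshape(3, i).numpy()      # reshape back to ints to compare numerically

    # reshape can no longer introduce a symbolic dim from an all-int shape
    # (mirrors test_invalid_symbolic_reshape and the View.reshape assert)
    try:
        Tensor.rand(30).reshape((3, vi))
    except AssertionError:
        pass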