mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 23:48:01 -05:00
[bounty] Remove using reshape to register symbolic shape (#11771)
* Modify tests and start work towards removing symbolic reshape
* Refactor symbolic reshape
* fix small error
* much cleaner + fix more tests
* Can remove this now
* Update test_symbolic_ops and test_tiny
* Couple more tests
* Unused import
* More tests and add EXPAND to Tensor.empty
* Fix test beam search
* all int
* Fix rangeify by adding shrink
* Remove OOB check and so fix test_symbolic_jit
* test_symbolic_jit doesn't need OOB Context anymore either
* Should remove that test now
* Cleanups part 1
* fix linters
* Final cleanups
* Don't reassign inside for loop

---------

Co-authored-by: chenyu <chenyu@fastmail.com>
This commit is contained in:
@@ -16,9 +16,9 @@ class TestBeamSearch(unittest.TestCase):
|
||||
BEAM.value = self.old_beam
|
||||
|
||||
def test_variable_ast_beam(self):
|
||||
with Context(IGNORE_OOB=1):
|
||||
a = rand(3, 3).reshape((Variable("a", 1, 10).bind(3), 3))
|
||||
a = (a+1).realize()
|
||||
vi = Variable("a", 1, 10).bind(3)
|
||||
a = rand(10, 3)[:vi]
|
||||
a = (a+1).realize()
|
||||
|
||||
def test_big_prime_number(self):
|
||||
a = rand(367, 367)
|
||||
@@ -42,18 +42,16 @@ class TestBeamSearch(unittest.TestCase):
|
||||
|
||||
def test_variable_big_prime_number(self):
|
||||
v = Variable("v", 1, 400).bind(367)
|
||||
a = rand(367, 367)
|
||||
b = rand(367, 367)
|
||||
with Context(IGNORE_OOB=1):
|
||||
c = (a.reshape(367, v) @ b.reshape(v, 367)).realize()
|
||||
np.testing.assert_allclose(c.numpy(), a.numpy() @ b.numpy(), atol=1e-4, rtol=1e-4)
|
||||
a = rand(367, 400)
|
||||
b = rand(400, 367)
|
||||
c = (a[:, :v] @ b[:v, :]).realize()
|
||||
np.testing.assert_allclose(c.numpy(), a[:, :367].numpy() @ b[:367, :].numpy(), atol=1e-4, rtol=1e-4)
|
||||
|
||||
def test_variable_shrink_prime_number(self):
|
||||
v = Variable("v", 1, 400).bind(367)
|
||||
a = rand(400, 367)
|
||||
with Context(IGNORE_OOB=1):
|
||||
b = (a.shrink(((0,v), None))+1).reshape(367,367).realize()
|
||||
np.testing.assert_allclose(b.numpy(), a.numpy()[:367]+1, atol=1e-4, rtol=1e-4)
|
||||
b = (a.shrink(((0,v), None))+1).reshape(367,367).realize()
|
||||
np.testing.assert_allclose(b.numpy(), a.numpy()[:367]+1, atol=1e-4, rtol=1e-4)
|
||||
|
||||
def test_no_mutate_rawbuffers(self):
|
||||
a = rand(3, 3).realize()
|
||||
|
||||
@@ -2,50 +2,41 @@ import unittest
|
||||
|
||||
from test.helpers import assert_jit_cache_len
|
||||
from tinygrad import Variable, Tensor, TinyJit
|
||||
from tinygrad.helpers import Context
|
||||
import numpy as np
|
||||
|
||||
class TestSymbolicJit(unittest.TestCase):
|
||||
def setUp(self):
|
||||
# A lot of these test are out of bounds, so we ignore the bounds check
|
||||
self.context = Context(IGNORE_OOB=1)
|
||||
self.context.__enter__()
|
||||
|
||||
def tearDown(self):
|
||||
self.context.__exit__(None, None, None)
|
||||
|
||||
def test_plus1(self):
|
||||
def f(a): return (a+1).realize()
|
||||
jf = TinyJit(f)
|
||||
a = Tensor.rand(3, 10)
|
||||
for i in range(1, 5):
|
||||
vi = Variable("i", 1, 10).bind(i)
|
||||
a = Tensor.rand(3, i)
|
||||
symbolic = jf(a.reshape(3, vi)).reshape(3, i).numpy()
|
||||
expected = f(a).numpy()
|
||||
symbolic = jf(a[:, :vi]).reshape(3, i).numpy()
|
||||
expected = f(a[:, :i]).numpy()
|
||||
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
|
||||
assert_jit_cache_len(jf, 1)
|
||||
|
||||
def test_add(self):
|
||||
def f(a, b): return (a+b).realize()
|
||||
jf = TinyJit(f)
|
||||
a = Tensor.rand(3, 10)
|
||||
b = Tensor.rand(3, 10)
|
||||
for i in range(1, 5):
|
||||
vi = Variable("i", 1, 10).bind(i)
|
||||
a = Tensor.rand(3, i)
|
||||
b = Tensor.rand(3, i)
|
||||
symbolic = jf(a.reshape(3, vi), b.reshape(3, vi)).reshape(3, i).numpy()
|
||||
expected = f(a, b).numpy()
|
||||
symbolic = jf(a[:, :vi], b[:, :vi]).reshape(3, i).numpy()
|
||||
expected = f(a[:, :i], b[:, :i]).numpy()
|
||||
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
|
||||
assert_jit_cache_len(jf, 1)
|
||||
|
||||
def test_matmul(self):
|
||||
def f(a, b): return (a@b).realize()
|
||||
jf = TinyJit(f)
|
||||
a = Tensor.rand(3, 10)
|
||||
b = Tensor.rand(10, 5)
|
||||
for i in range(1, 5):
|
||||
vi = Variable("i", 1, 10).bind(i)
|
||||
a = Tensor.rand(3, i)
|
||||
b = Tensor.rand(i, 5)
|
||||
symbolic = jf(a.reshape(3, vi), b.reshape(vi, 5)).numpy()
|
||||
expected = f(a, b).numpy()
|
||||
symbolic = jf(a[:, :vi], b[:vi, :]).numpy()
|
||||
expected = f(a[:, :i], b[:i, :]).numpy()
|
||||
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
|
||||
assert_jit_cache_len(jf, 1)
|
||||
|
||||
@@ -55,119 +46,119 @@ class TestSymbolicJit(unittest.TestCase):
|
||||
s = (s+s).realize() # this one does not have symbols in input
|
||||
return s
|
||||
jf = TinyJit(f)
|
||||
a = Tensor.rand(3, 10)
|
||||
b = Tensor.rand(10, 5)
|
||||
for i in range(1, 5):
|
||||
vi = Variable("i", 1, 10).bind(i)
|
||||
a = Tensor.rand(3, i)
|
||||
b = Tensor.rand(i, 5)
|
||||
symbolic = jf(a.reshape(3, vi), b.reshape(vi, 5)).numpy()
|
||||
expected = f(a, b).numpy()
|
||||
symbolic = jf(a[:, :vi], b[:vi, :]).numpy()
|
||||
expected = f(a[:, :i], b[:i, :]).numpy()
|
||||
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
|
||||
assert_jit_cache_len(jf, 2)
|
||||
|
||||
def test_attention(self):
|
||||
def f(q, k, v): return Tensor.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)).realize()
|
||||
jf = TinyJit(f)
|
||||
q = Tensor.rand(2, 1, 4, 8)
|
||||
k = Tensor.rand(2, 10, 4, 8)
|
||||
v = Tensor.rand(2, 10, 4, 8)
|
||||
for i in range(1, 5):
|
||||
vi = Variable("i", 1, 10).bind(i)
|
||||
q = Tensor.rand(2, 1, 4, 8)
|
||||
k = Tensor.rand(2, i, 4, 8)
|
||||
v = Tensor.rand(2, i, 4, 8)
|
||||
symbolic = jf(q, k.reshape(2, vi, 4, 8), v.reshape(2, vi, 4, 8)).reshape(2, 4, 1, 8).numpy()
|
||||
expected = f(q, k, v).numpy()
|
||||
symbolic = jf(q, k[:, :vi], v[:, :vi]).reshape(2, 4, 1, 8).numpy()
|
||||
expected = f(q, k[:, :i], v[:, :i]).numpy()
|
||||
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
|
||||
assert_jit_cache_len(jf, 5)
|
||||
|
||||
def test_cat_dim0(self):
|
||||
def f(a, b): return a.cat(b, dim=0).realize()
|
||||
jf = TinyJit(f)
|
||||
a = Tensor.rand(10, 3)
|
||||
b = Tensor.rand(2, 3)
|
||||
for i in range(1, 5):
|
||||
vi = Variable("i", 1, 10).bind(i)
|
||||
a = Tensor.rand(i, 3)
|
||||
b = Tensor.rand(2, 3)
|
||||
symbolic = jf(a.reshape(vi, 3), b).reshape(i+2, 3).numpy()
|
||||
expected = f(a, b).numpy()
|
||||
symbolic = jf(a[:vi], b).reshape(i+2, 3).numpy()
|
||||
expected = f(a[:i], b).numpy()
|
||||
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
|
||||
assert_jit_cache_len(jf, 1)
|
||||
|
||||
def test_cat_dim1(self):
|
||||
def f(a, b): return a.cat(b, dim=1).realize()
|
||||
jf = TinyJit(f)
|
||||
a = Tensor.rand(3, 10)
|
||||
b = Tensor.rand(3, 2)
|
||||
for i in range(1, 5):
|
||||
vi = Variable("i", 1, 10).bind(i)
|
||||
a = Tensor.rand(3, i)
|
||||
b = Tensor.rand(3, 2)
|
||||
symbolic = jf(a.reshape(3, vi), b).reshape(3, i+2).numpy()
|
||||
expected = f(a, b).numpy()
|
||||
symbolic = jf(a[:, :vi], b).reshape(3, i+2).numpy()
|
||||
expected = f(a[:, :i], b).numpy()
|
||||
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
|
||||
assert_jit_cache_len(jf, 1)
|
||||
|
||||
def test_cat_dim0_two_vars(self):
|
||||
def f(a, b): return a.cat(b, dim=0).realize()
|
||||
jf = TinyJit(f)
|
||||
a = Tensor.rand(10, 3)
|
||||
b = Tensor.rand(10, 3)
|
||||
for i in range(1, 5):
|
||||
for j in range(1, 5):
|
||||
vi = Variable("i", 1, 10).bind(i)
|
||||
vj = Variable("j", 1, 10).bind(j)
|
||||
a = Tensor.rand(i, 3)
|
||||
b = Tensor.rand(j, 3)
|
||||
symbolic = jf(a.reshape(vi, 3), b.reshape(vj, 3)).reshape(i+j, 3).numpy()
|
||||
expected = f(a, b).numpy()
|
||||
symbolic = jf(a[:vi], b[:vj]).reshape(i+j, 3).numpy()
|
||||
expected = f(a[:i], b[:j]).numpy()
|
||||
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
|
||||
assert_jit_cache_len(jf, 1)
|
||||
|
||||
def test_cat_dim1_two_vars(self):
|
||||
def f(a, b): return a.cat(b, dim=1).realize()
|
||||
jf = TinyJit(f)
|
||||
a = Tensor.rand(3, 10)
|
||||
b = Tensor.rand(3, 10)
|
||||
for i in range(1, 5):
|
||||
for j in range(1, 5):
|
||||
vi = Variable("i", 1, 10).bind(i)
|
||||
vj = Variable("j", 1, 10).bind(j)
|
||||
a = Tensor.rand(3, i)
|
||||
b = Tensor.rand(3, j)
|
||||
symbolic = jf(a.reshape(3, vi), b.reshape(3, vj)).reshape(3, i+j).numpy()
|
||||
expected = f(a, b).numpy()
|
||||
symbolic = jf(a[:, :vi], b[:, :vj]).reshape(3, i+j).numpy()
|
||||
expected = f(a[:, :i], b[:, :j]).numpy()
|
||||
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
|
||||
assert_jit_cache_len(jf, 1)
|
||||
|
||||
def test_two_vars_plus1_ij(self):
|
||||
def f(a, b): return (a@b+1).realize()
|
||||
jf = TinyJit(f)
|
||||
a = Tensor.rand(10, 3)
|
||||
b = Tensor.rand(3, 10)
|
||||
for i in range(1, 5):
|
||||
for j in range(1, 5):
|
||||
vi = Variable("i", 1, 10).bind(i)
|
||||
vj = Variable("j", 1, 10).bind(j)
|
||||
a = Tensor.rand(i, 3)
|
||||
b = Tensor.rand(3, j)
|
||||
symbolic = jf(a.reshape(vi, 3), b.reshape(3, vj)).reshape(i, j).numpy()
|
||||
expected = f(a, b).numpy()
|
||||
symbolic = jf(a[:vi, :], b[:, :vj]).reshape(i, j).numpy()
|
||||
expected = f(a[:i, :], b[:, :j]).numpy()
|
||||
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
|
||||
assert_jit_cache_len(jf, 1)
|
||||
|
||||
def test_two_vars_plus1_ji(self):
|
||||
def f(a, b): return (a@b+1).realize()
|
||||
jf = TinyJit(f)
|
||||
a = Tensor.rand(10, 3)
|
||||
b = Tensor.rand(3, 10)
|
||||
for i in range(1, 5):
|
||||
for j in range(1, 5):
|
||||
vi = Variable("i", 1, 10).bind(i)
|
||||
vj = Variable("j", 1, 10).bind(j)
|
||||
a = Tensor.rand(j, 3)
|
||||
b = Tensor.rand(3, i)
|
||||
symbolic = jf(a.reshape(vj, 3), b.reshape(3, vi)).reshape(j, i).numpy()
|
||||
expected = f(a, b).numpy()
|
||||
symbolic = jf(a[:vj, :], b[:, :vi]).reshape(j, i).numpy()
|
||||
expected = f(a[:j, :], b[:, :i]).numpy()
|
||||
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
|
||||
assert_jit_cache_len(jf, 1)
|
||||
|
||||
def test_jit_symbolic_shape_mismatch(self):
|
||||
@TinyJit
|
||||
def add(a, b): return (a+b).realize()
|
||||
a = Tensor.rand(3, 10)
|
||||
b = Tensor.rand(3, 10)
|
||||
for i in range(1, 5):
|
||||
vi = Variable("i", 1, 10).bind(i)
|
||||
a = Tensor.rand(3, i).reshape(3, vi)
|
||||
b = Tensor.rand(3, i).reshape(3, vi)
|
||||
add(a, b)
|
||||
add(a[:, :vi], b[:, :vi])
|
||||
vi2 = Variable("i", 1, 10).bind(7)
|
||||
a = Tensor.rand(3, 7).reshape(3, vi2)
|
||||
bad = Tensor.rand(4, 7).reshape(4, vi2)
|
||||
a = Tensor.rand(3, 7)[:, :vi2]
|
||||
bad = Tensor.rand(4, 7)[:, :vi2]
|
||||
with self.assertRaises(AssertionError):
|
||||
add(a, bad)
|
||||
|
||||
@@ -175,9 +166,9 @@ class TestSymbolicJit(unittest.TestCase):
|
||||
# shrink is a movement, so we pair it with a simple function to test the JIT interaction
|
||||
def f(a): return (a+1).realize()
|
||||
jf = TinyJit(f)
|
||||
a = Tensor.rand(7, 11)
|
||||
for i in range(1, 5):
|
||||
vi = Variable("i", 1, 10).bind(i)
|
||||
a = Tensor.rand(7, 11)
|
||||
symbolic = a.shrink(((3,5),(vi,vi+2)))
|
||||
symbolic = jf(symbolic).numpy()
|
||||
expected = f(a.shrink(((3,5),(i,i+2)))).numpy()
|
||||
@@ -188,9 +179,9 @@ class TestSymbolicJit(unittest.TestCase):
|
||||
# slice is a movement, so we pair it with a simple function to test the JIT interaction
|
||||
def f(a): return (a+1).realize()
|
||||
jf = TinyJit(f)
|
||||
a = Tensor.rand(7, 11)
|
||||
for i in range(1, 5):
|
||||
vi = Variable("i", 1, 10).bind(i)
|
||||
a = Tensor.rand(7, 11)
|
||||
symbolic = a[3:5, vi:vi+2]
|
||||
symbolic = jf(symbolic).numpy()
|
||||
expected = f(a[3:5, i:i+2]).numpy()
|
||||
@@ -212,11 +203,11 @@ class TestSymbolicJit(unittest.TestCase):
|
||||
def test_ones_sum(self):
|
||||
def f(a): return a.sum().realize()
|
||||
jf = TinyJit(f)
|
||||
t = Tensor.ones(10)
|
||||
for i in range(1, 5):
|
||||
vi = Variable("i", 1, 10).bind(i)
|
||||
t = Tensor.ones(i)
|
||||
symbolic = jf(t.reshape(vi)).item()
|
||||
expected = f(t).item()
|
||||
symbolic = jf(t[:vi]).item()
|
||||
expected = f(t[:i]).item()
|
||||
np.testing.assert_equal(symbolic, expected)
|
||||
|
||||
def test_mean(self):
|
||||
@@ -226,22 +217,22 @@ class TestSymbolicJit(unittest.TestCase):
|
||||
jf = TinyJit(f)
|
||||
jf0 = TinyJit(f0)
|
||||
jf1 = TinyJit(f1)
|
||||
a = Tensor.rand(10, 3)
|
||||
b = Tensor.rand(10, 3)
|
||||
c = Tensor.rand(10, 3)
|
||||
for i in range(1, 5):
|
||||
vi = Variable("i", 1, 10).bind(i)
|
||||
# aixs = None
|
||||
a = Tensor.rand(i, 3)
|
||||
symbolic = jf(a.reshape(vi, 3)).numpy()
|
||||
expected = a.mean().numpy()
|
||||
# axis = None
|
||||
symbolic = jf(a[:vi]).numpy()
|
||||
expected = a[:i].mean().numpy()
|
||||
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
|
||||
# aixs = 0
|
||||
a = Tensor.rand(i, 3)
|
||||
symbolic = jf0(a.reshape(vi, 3)).numpy()
|
||||
expected = a.mean(0).numpy()
|
||||
# axis = 0
|
||||
symbolic = jf0(b[:vi]).numpy()
|
||||
expected = b[:i].mean(0).numpy()
|
||||
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
|
||||
# aixs = 1
|
||||
a = Tensor.rand(i, 3)
|
||||
symbolic = jf1(a.reshape(vi, 3)).reshape(i).numpy()
|
||||
expected = a.mean(1).numpy()
|
||||
# axis = 1
|
||||
symbolic = jf1(c[:vi]).reshape(i).numpy()
|
||||
expected = c[:i].mean(1).numpy()
|
||||
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
|
||||
|
||||
def test_mean_2d(self):
|
||||
@@ -251,24 +242,24 @@ class TestSymbolicJit(unittest.TestCase):
|
||||
jf = TinyJit(f)
|
||||
jf0 = TinyJit(f0)
|
||||
jf1 = TinyJit(f1)
|
||||
a = Tensor.rand(10, 10)
|
||||
b = Tensor.rand(10, 10)
|
||||
c = Tensor.rand(10, 10)
|
||||
for i in range(1, 5):
|
||||
for j in range(1, 5):
|
||||
vi = Variable("i", 1, 10).bind(i)
|
||||
vj = Variable("j", 1, 10).bind(j)
|
||||
# aixs = None
|
||||
a = Tensor.rand(i, j)
|
||||
symbolic = jf(a.reshape(vi, vj)).numpy()
|
||||
expected = a.mean().numpy()
|
||||
# axis = None
|
||||
symbolic = jf(a[:vi, :vj]).numpy()
|
||||
expected = a[:i, :j].mean().numpy()
|
||||
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
|
||||
# aixs = 0
|
||||
a = Tensor.rand(i, j)
|
||||
symbolic = jf0(a.reshape(vi, vj)).reshape(j).numpy()
|
||||
expected = a.mean(0).numpy()
|
||||
# axis = 0
|
||||
symbolic = jf0(b[:vi, :vj]).reshape(j).numpy()
|
||||
expected = b[:i, :j].mean(0).numpy()
|
||||
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
|
||||
# aixs = 1
|
||||
a = Tensor.rand(i, j)
|
||||
symbolic = jf1(a.reshape(vi, vj)).reshape(i).numpy()
|
||||
expected = a.mean(1).numpy()
|
||||
# axis = 1
|
||||
symbolic = jf1(c[:vi, :vj]).reshape(i).numpy()
|
||||
expected = c[:i, :j].mean(1).numpy()
|
||||
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
|
||||
|
||||
def test_var(self):
|
||||
@@ -278,22 +269,22 @@ class TestSymbolicJit(unittest.TestCase):
|
||||
jf = TinyJit(f)
|
||||
jf0 = TinyJit(f0)
|
||||
jf1 = TinyJit(f1)
|
||||
a = Tensor.rand(10, 3)
|
||||
b = Tensor.rand(10, 3)
|
||||
c = Tensor.rand(10, 3)
|
||||
for i in range(1, 5):
|
||||
vi = Variable("i", 1, 10).bind(i)
|
||||
# aixs = None
|
||||
a = Tensor.rand(i, 3)
|
||||
symbolic = jf(a.reshape(vi, 3)).numpy()
|
||||
expected = a.var().numpy()
|
||||
# axis = None
|
||||
symbolic = jf(a[:vi]).numpy()
|
||||
expected = a[:i].var().numpy()
|
||||
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
|
||||
# aixs = 0
|
||||
a = Tensor.rand(i, 3)
|
||||
symbolic = jf0(a.reshape(vi, 3)).numpy()
|
||||
expected = a.var(0).numpy()
|
||||
# axis = 0
|
||||
symbolic = jf0(b[:vi]).numpy()
|
||||
expected = b[:i].var(0).numpy()
|
||||
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
|
||||
# aixs = 1
|
||||
a = Tensor.rand(i, 3)
|
||||
symbolic = jf1(a.reshape(vi, 3)).reshape(i).numpy()
|
||||
expected = a.var(1).numpy()
|
||||
# axis = 1
|
||||
symbolic = jf1(c[:vi]).reshape(i).numpy()
|
||||
expected = c[:i].var(1).numpy()
|
||||
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
|
||||
|
||||
def test_var_2d(self):
|
||||
@@ -303,24 +294,24 @@ class TestSymbolicJit(unittest.TestCase):
|
||||
jf = TinyJit(f)
|
||||
jf0 = TinyJit(f0)
|
||||
jf1 = TinyJit(f1)
|
||||
a = Tensor.rand(10, 10)
|
||||
b = Tensor.rand(10, 10)
|
||||
c = Tensor.rand(10, 10)
|
||||
for i in range(1, 5):
|
||||
for j in range(1, 5):
|
||||
vi = Variable("i", 1, 10).bind(i)
|
||||
vj = Variable("j", 1, 10).bind(j)
|
||||
# aixs = None
|
||||
a = Tensor.rand(i, j)
|
||||
symbolic = jf(a.reshape(vi, vj)).numpy()
|
||||
expected = a.var().numpy()
|
||||
# axis = None
|
||||
symbolic = jf(a[:vi, :vj]).numpy()
|
||||
expected = a[:i, :j].var().numpy()
|
||||
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
|
||||
# aixs = 0
|
||||
a = Tensor.rand(i, j)
|
||||
symbolic = jf0(a.reshape(vi, vj)).reshape(j).numpy()
|
||||
expected = a.var(0).numpy()
|
||||
# axis = 0
|
||||
symbolic = jf0(b[:vi, :vj]).reshape(j).numpy()
|
||||
expected = b[:i, :j].var(0).numpy()
|
||||
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
|
||||
# aixs = 1
|
||||
a = Tensor.rand(i, j)
|
||||
symbolic = jf1(a.reshape(vi, vj)).reshape(i).numpy()
|
||||
expected = a.var(1).numpy()
|
||||
# axis = 1
|
||||
symbolic = jf1(c[:vi, :vj]).reshape(i).numpy()
|
||||
expected = c[:i, :j].var(1).numpy()
|
||||
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import unittest
|
||||
from tinygrad import Tensor, Variable
|
||||
from tinygrad.shape.shapetracker import View
|
||||
from tinygrad.helpers import Context, GlobalCounters
|
||||
from tinygrad.helpers import GlobalCounters
|
||||
from tinygrad.uop.ops import sym_infer
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.device import Device
|
||||
@@ -9,54 +9,46 @@ from examples.gpt2 import Attention
|
||||
import numpy as np
|
||||
|
||||
class TestSymbolicOps(unittest.TestCase):
|
||||
def setUp(self):
|
||||
# A lot of these test are out of bounds, so we ignore the bounds check
|
||||
self.context = Context(IGNORE_OOB=1)
|
||||
self.context.__enter__()
|
||||
|
||||
def tearDown(self):
|
||||
self.context.__exit__(None, None, None)
|
||||
|
||||
def test_plus1(self):
|
||||
def f(a): return (a+1).realize()
|
||||
a = Tensor.rand(3, 10)
|
||||
for i in range(1, 5):
|
||||
vi = Variable("i", 1, 10).bind(i)
|
||||
a = Tensor.rand(3, i)
|
||||
symbolic = f(a.reshape(3, vi)).reshape(3, i).numpy()
|
||||
expected = f(a).numpy()
|
||||
symbolic = f(a[:, :vi]).reshape(3, i).numpy()
|
||||
expected = f(a[:, :i]).numpy()
|
||||
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
|
||||
|
||||
def test_add(self):
|
||||
def f(a, b): return (a+b).realize()
|
||||
a = Tensor.rand(3, 10)
|
||||
b = Tensor.rand(3, 10)
|
||||
for i in range(1, 5):
|
||||
vi = Variable("i", 1, 10).bind(i)
|
||||
a = Tensor.rand(3, i)
|
||||
b = Tensor.rand(3, i)
|
||||
symbolic = f(a.reshape(3, vi), b.reshape(3, vi)).reshape(3, i).numpy()
|
||||
expected = f(a, b).numpy()
|
||||
symbolic = f(a[:, :vi], b[:, :vi]).reshape(3, i).numpy()
|
||||
expected = f(a[:, :i], b[:, :i]).numpy()
|
||||
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
|
||||
|
||||
def test_matmul(self):
|
||||
def f(a, b): return (a@b).realize()
|
||||
a = Tensor.rand(3, 10)
|
||||
b = Tensor.rand(10, 5)
|
||||
for i in range(1, 5):
|
||||
vi = Variable("i", 1, 10).bind(i)
|
||||
a = Tensor.rand(3, i)
|
||||
b = Tensor.rand(i, 5)
|
||||
symbolic = f(a.reshape(3, vi), b.reshape(vi, 5)).numpy()
|
||||
expected = f(a, b).numpy()
|
||||
symbolic = f(a[:, :vi], b[:vi, :]).numpy()
|
||||
expected = f(a[:, :i], b[:i, :]).numpy()
|
||||
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
|
||||
|
||||
def test_attention(self, dropout_p=0.0, imin=1, imax=5, use_symbolic=True):
|
||||
def f(q, k, v): return Tensor.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), dropout_p=dropout_p).realize()
|
||||
q = Tensor.rand(2, 1, 4, 8)
|
||||
k = Tensor.rand(2, 10, 4, 8)
|
||||
v = Tensor.rand(2, 10, 4, 8)
|
||||
for i in range(imin, imax):
|
||||
vi = Variable("i", 1, 10).bind(i) if use_symbolic else i
|
||||
q = Tensor.rand(2, 1, 4, 8)
|
||||
k = Tensor.rand(2, i, 4, 8)
|
||||
v = Tensor.rand(2, i, 4, 8)
|
||||
Tensor.realize(q, k, v)
|
||||
GlobalCounters.reset()
|
||||
symbolic = f(q, k.reshape(2, vi, 4, 8), v.reshape(2, vi, 4, 8)).reshape(2, 4, 1, 8).numpy()
|
||||
expected = f(q, k, v).numpy()
|
||||
symbolic = f(q, k[:, :vi, :, :], v[:, :vi, :, :]).reshape(2, 4, 1, 8).numpy()
|
||||
expected = f(q, k[:, :i, :, :], v[:, :i, :, :]).numpy()
|
||||
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
|
||||
|
||||
def test_attention_cmp_symbolic(self):
|
||||
@@ -90,73 +82,89 @@ class TestSymbolicOps(unittest.TestCase):
|
||||
|
||||
def test_cat_dim0(self):
|
||||
def f(a, b): return a.cat(b, dim=0).realize()
|
||||
a = Tensor.rand(10, 3)
|
||||
for i in range(1, 5):
|
||||
vi = Variable("i", 1, 10).bind(i)
|
||||
a = Tensor.rand(i, 3)
|
||||
b = Tensor.rand(2, 3)
|
||||
symbolic = f(a.reshape(vi, 3), b).reshape(i+2, 3).numpy()
|
||||
expected = f(a, b).numpy()
|
||||
symbolic = f(a[:vi, :], b).reshape(i+2, 3).numpy()
|
||||
expected = f(a[:i, :], b).numpy()
|
||||
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
|
||||
|
||||
def test_cat_dim1(self):
|
||||
def f(a, b): return a.cat(b, dim=1).realize()
|
||||
a = Tensor.rand(3, 10)
|
||||
for i in range(1, 5):
|
||||
vi = Variable("i", 1, 10).bind(i)
|
||||
a = Tensor.rand(3, i)
|
||||
b = Tensor.rand(3, 2)
|
||||
symbolic = f(a.reshape(3, vi), b).reshape(3, i+2).numpy()
|
||||
expected = f(a, b).numpy()
|
||||
symbolic = f(a[:, :vi], b).reshape(3, i+2).numpy()
|
||||
expected = f(a[:, :i], b).numpy()
|
||||
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
|
||||
|
||||
def test_cat_dim0_two_vars(self):
|
||||
def f(a, b): return a.cat(b, dim=0).realize()
|
||||
a = Tensor.rand(10, 3)
|
||||
b = Tensor.rand(10, 3)
|
||||
for i in range(1, 5):
|
||||
for j in range(1, 5):
|
||||
vi = Variable("i", 1, 10).bind(i)
|
||||
vj = Variable("j", 1, 10).bind(j)
|
||||
a = Tensor.rand(i, 3)
|
||||
b = Tensor.rand(j, 3)
|
||||
symbolic = f(a.reshape(vi, 3), b.reshape(vj, 3)).reshape(i+j, 3).numpy()
|
||||
expected = f(a, b).numpy()
|
||||
symbolic = f(a[:vi, :], b[:vj, :]).reshape(i+j, 3).numpy()
|
||||
expected = f(a[:i, :], b[:j, :]).numpy()
|
||||
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
|
||||
|
||||
def test_cat_dim1_two_vars(self):
|
||||
def f(a, b): return a.cat(b, dim=1).realize()
|
||||
a = Tensor.rand(3, 10)
|
||||
b = Tensor.rand(3, 10)
|
||||
for i in range(1, 5):
|
||||
for j in range(1, 5):
|
||||
vi = Variable("i", 1, 10).bind(i)
|
||||
vj = Variable("j", 1, 10).bind(j)
|
||||
a = Tensor.rand(3, i)
|
||||
b = Tensor.rand(3, j)
|
||||
symbolic = f(a.reshape(3, vi), b.reshape(3, vj)).reshape(3, i+j).numpy()
|
||||
expected = f(a, b).numpy()
|
||||
symbolic = f(a[:, :vi], b[:, :vj]).reshape(3, i+j).numpy()
|
||||
expected = f(a[:, :i], b[:, :j]).numpy()
|
||||
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
|
||||
|
||||
def test_two_vars_plus1_ij(self):
|
||||
def f(a, b): return (a@b+1).realize()
|
||||
a = Tensor.rand(10, 3)
|
||||
b = Tensor.rand(3, 10)
|
||||
for i in range(1, 5):
|
||||
for j in range(1, 5):
|
||||
vi = Variable("i", 1, 10).bind(i)
|
||||
vj = Variable("j", 1, 10).bind(j)
|
||||
a = Tensor.rand(i, 3)
|
||||
b = Tensor.rand(3, j)
|
||||
symbolic = f(a.reshape(vi, 3), b.reshape(3, vj)).reshape(i, j).numpy()
|
||||
expected = f(a, b).numpy()
|
||||
symbolic = f(a[:vi, :], b[:, :vj]).reshape(i, j).numpy()
|
||||
expected = f(a[:i, :], b[:, :j]).numpy()
|
||||
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
|
||||
|
||||
def test_two_vars_plus1_ji(self):
|
||||
# reverse the order of variables
|
||||
def f(a, b): return (a@b+1).realize()
|
||||
a = Tensor.rand(10, 3)
|
||||
b = Tensor.rand(3, 10)
|
||||
for i in range(1, 5):
|
||||
for j in range(1, 5):
|
||||
vi = Variable("i", 1, 10).bind(i)
|
||||
vj = Variable("j", 1, 10).bind(j)
|
||||
a = Tensor.rand(j, 3)
|
||||
b = Tensor.rand(3, i)
|
||||
symbolic = f(a.reshape(vj, 3), b.reshape(3, vi)).reshape(j, i).numpy()
|
||||
expected = f(a, b).numpy()
|
||||
symbolic = f(a[:vj, :], b[:, :vi]).reshape(j, i).numpy()
|
||||
expected = f(a[:j, :], b[:, :i]).numpy()
|
||||
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
|
||||
|
||||
def test_reshape_from_symbolic(self):
|
||||
a = Tensor.rand(30)
|
||||
for i in range(3, 5):
|
||||
vi = Variable("i", 3, 10).bind(i)
|
||||
symbolic = a[:vi*3].reshape((3, 3)).numpy()
|
||||
# To match symbolic reshape (potential implicit shrink), we need a shrink
|
||||
expected = a[:i*3].shrink(((0, 9),)).reshape((3, 3)).numpy()
|
||||
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
|
||||
|
||||
def test_invalid_symbolic_reshape(self):
|
||||
a = Tensor.rand(30)
|
||||
for i in range(1, 5):
|
||||
vi = Variable("i", 1, 10).bind(i)
|
||||
# Cannot reshape into symbolic from non-symbolic
|
||||
with self.assertRaises(AssertionError): a.reshape((3, vi))
|
||||
|
||||
def test_shrink(self):
|
||||
for i in range(1, 5):
|
||||
vi = Variable("i", 1, 10).bind(i)
|
||||
@@ -176,11 +184,10 @@ class TestSymbolicOps(unittest.TestCase):
|
||||
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
|
||||
|
||||
def test_slice_no_start(self):
|
||||
a = Tensor.rand(7, 11)
|
||||
for i in range(1, 5):
|
||||
vi = Variable("i", 1, 10).bind(i)
|
||||
a = Tensor.rand(7, 11)
|
||||
symbolic = a[3:5, :vi:1].reshape(2,i)
|
||||
symbolic = symbolic.numpy()
|
||||
symbolic = a[3:5, :vi:1].reshape(2, i).numpy()
|
||||
expected = a[3:5, :i:1].numpy()
|
||||
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
|
||||
|
||||
@@ -201,31 +208,31 @@ class TestSymbolicOps(unittest.TestCase):
|
||||
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
|
||||
|
||||
def test_ones_sum(self):
|
||||
t = Tensor.ones(10)
|
||||
for i in range(1, 5):
|
||||
vi = Variable("i", 1, 10).bind(i)
|
||||
t = Tensor.ones(i)
|
||||
symbolic = t.reshape(vi).sum().item()
|
||||
expected = t.sum().item()
|
||||
symbolic = t[:vi].sum().item()
|
||||
expected = t[:i].sum().item()
|
||||
np.testing.assert_equal(symbolic, expected)
|
||||
|
||||
def test_mean(self):
|
||||
a = Tensor.rand(10, 3)
|
||||
for i in range(1, 5):
|
||||
vi = Variable("i", 1, 10).bind(i)
|
||||
for axis in [None, 0, 1]:
|
||||
a = Tensor.rand(i, 3)
|
||||
expected = a.mean(axis).numpy()
|
||||
symbolic = a.reshape(vi, 3).mean(axis).reshape(expected.shape).numpy()
|
||||
expected = a[:i].mean(axis).numpy()
|
||||
symbolic = a[:vi].mean(axis).reshape(expected.shape).numpy()
|
||||
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
|
||||
|
||||
def test_mean_2d(self):
|
||||
a = Tensor.rand(10, 10)
|
||||
for i in range(1, 5):
|
||||
for j in range(1, 5):
|
||||
vi = Variable("i", 1, 10).bind(i)
|
||||
vj = Variable("j", 1, 10).bind(j)
|
||||
for axis in [None, 0, 1]:
|
||||
a = Tensor.rand(i, j)
|
||||
expected = a.mean(axis).numpy()
|
||||
symbolic = a.reshape(vi, vj).mean(axis).reshape(expected.shape).numpy()
|
||||
expected = a[:i, :j].mean(axis).numpy()
|
||||
symbolic = a[:vi, :vj].mean(axis).reshape(expected.shape).numpy()
|
||||
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
|
||||
|
||||
def test_var(self):
|
||||
@@ -233,43 +240,43 @@ class TestSymbolicOps(unittest.TestCase):
|
||||
for i in range(1, 5):
|
||||
vi = Variable("i", 1, 10).bind(i)
|
||||
for axis in [None, 0, 1]:
|
||||
expected = a[:i, :].var(axis).numpy()
|
||||
symbolic = a[:vi, :].var(axis).reshape(expected.shape).numpy()
|
||||
expected = a[:i].var(axis).numpy()
|
||||
symbolic = a[:vi].var(axis).reshape(expected.shape).numpy()
|
||||
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
|
||||
|
||||
def test_var_2d(self):
|
||||
a = Tensor.rand(10, 10)
|
||||
for i in range(1, 5):
|
||||
for j in range(1, 5):
|
||||
vi = Variable("i", 1, 10).bind(i)
|
||||
vj = Variable("j", 1, 10).bind(j)
|
||||
for axis in [None, 0, 1]:
|
||||
a = Tensor.rand(i, j)
|
||||
expected = a.var(axis).numpy()
|
||||
symbolic = a.reshape(vi, vj).var(axis).reshape(expected.shape).numpy()
|
||||
expected = a[:i, :j].var(axis).numpy()
|
||||
symbolic = a[:vi, :vj].var(axis).reshape(expected.shape).numpy()
|
||||
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
|
||||
|
||||
def test_bitcast_down(self):
|
||||
a = Tensor.rand(10, 3)
|
||||
for i in range(1, 5):
|
||||
vi = Variable("i", 1, 10).bind(i)
|
||||
a = Tensor.rand(i, 3)
|
||||
expected = a.bitcast(dtypes.uint8).numpy()
|
||||
symbolic = a.reshape(vi, 3).bitcast(dtypes.uint8).reshape(expected.shape).numpy()
|
||||
expected = a[:i].bitcast(dtypes.uint8).numpy()
|
||||
symbolic = a[:vi].bitcast(dtypes.uint8).reshape(expected.shape).numpy()
|
||||
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=0)
|
||||
|
||||
@unittest.skipIf(Device.DEFAULT == "WEBGPU", "no uint64")
|
||||
def test_bitcast_up(self):
|
||||
a = Tensor.rand(10, 4)
|
||||
for i in range(1, 5):
|
||||
vi = Variable("i", 1, 10).bind(i)
|
||||
a = Tensor.rand(i, 4)
|
||||
expected = a.bitcast(dtypes.uint64).numpy()
|
||||
symbolic = a.reshape(vi, 4).bitcast(dtypes.uint64).reshape(expected.shape).numpy()
|
||||
expected = a[:i].bitcast(dtypes.uint64).numpy()
|
||||
symbolic = a[:vi].bitcast(dtypes.uint64).reshape(expected.shape).numpy()
|
||||
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=0)
|
||||
|
||||
@unittest.expectedFailure
|
||||
def test_conv2d_ceildiv_edge_case(self):
|
||||
v = Variable('v', 11, 50_000)
|
||||
val = 39601
|
||||
x = Tensor.randn(1, 22, 39601).reshape(1, 22, v.bind(val))
|
||||
x = Tensor.randn(1, 22, 50_000)[:, :, :v.bind(val)]
|
||||
weight = Tensor.randn(256, 22, 12)
|
||||
|
||||
result = x.conv2d(weight=weight, groups=1, stride=6, dilation=1, padding=(3, 3))
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import unittest
|
||||
import numpy as np
|
||||
from tinygrad import Tensor, Variable
|
||||
from tinygrad.helpers import Context
|
||||
|
||||
class TestTensorVariable(unittest.TestCase):
|
||||
def test_add_tvar(self):
|
||||
@@ -23,43 +22,38 @@ class TestTensorVariable(unittest.TestCase):
|
||||
assert (Tensor(3) * (vv * 4)).item() == 24
|
||||
|
||||
def test_symbolic_mean(self):
|
||||
with Context(IGNORE_OOB=1):
|
||||
vv = Variable("a", 1, 10).bind(2)
|
||||
t = Tensor.ones(2, 2).contiguous().reshape(2, vv)
|
||||
ret = t.mean().item()
|
||||
assert ret == 1
|
||||
vv = Variable("a", 1, 10).bind(2)
|
||||
t = Tensor.ones(2, 10).contiguous()[:, :vv]
|
||||
ret = t.mean().item()
|
||||
assert ret == 1
|
||||
|
||||
def test_symbolic_mean_2d(self):
|
||||
with Context(IGNORE_OOB=1):
|
||||
vv = Variable("a", 1, 10).bind(2)
|
||||
vv2 = Variable("b", 1, 10).bind(2)
|
||||
t = Tensor.ones(2, 2).contiguous().reshape(vv2, vv)
|
||||
ret = t.mean().item()
|
||||
assert ret == 1
|
||||
vv = Variable("a", 1, 10).bind(2)
|
||||
vv2 = Variable("b", 1, 10).bind(2)
|
||||
t = Tensor.ones(10, 10).contiguous()[:vv2, :vv]
|
||||
ret = t.mean().item()
|
||||
assert ret == 1
|
||||
|
||||
def test_symbolic_mean_2d_axis_1(self):
|
||||
with Context(IGNORE_OOB=1):
|
||||
vv = Variable("a", 1, 10).bind(2)
|
||||
vv2 = Variable("b", 1, 10).bind(2)
|
||||
t = Tensor.ones(2, 2).contiguous().reshape(vv2, vv)
|
||||
ret = t.mean(axis=1).reshape(2, 1).numpy()
|
||||
assert np.all(ret == 1)
|
||||
vv = Variable("a", 1, 10).bind(2)
|
||||
vv2 = Variable("b", 1, 10).bind(2)
|
||||
t = Tensor.ones(10, 10).contiguous()[:vv2, :vv]
|
||||
ret = t.mean(axis=1).reshape(2, 1).numpy()
|
||||
assert np.all(ret == 1)
|
||||
|
||||
def test_symbolic_mean_2d_add(self):
|
||||
with Context(IGNORE_OOB=1):
|
||||
add_term = Variable("c", 0, 10).bind(1)
|
||||
vv = Variable("a", 1, 10).bind(1)
|
||||
vv2 = Variable("b", 1, 10).bind(1)
|
||||
t = Tensor.ones(2, 2).contiguous().reshape(vv2+add_term, vv+add_term)
|
||||
ret = t.mean().item()
|
||||
assert ret == 1
|
||||
add_term = Variable("c", 0, 10).bind(1)
|
||||
vv = Variable("a", 1, 10).bind(1)
|
||||
vv2 = Variable("b", 1, 10).bind(1)
|
||||
t = Tensor.ones(20, 20).contiguous()[:vv2+add_term, :vv+add_term]
|
||||
ret = t.mean().item()
|
||||
assert ret == 1
|
||||
|
||||
def test_symbolic_var(self):
|
||||
with Context(IGNORE_OOB=1):
|
||||
vv = Variable("a", 1, 10).bind(2)
|
||||
t = Tensor.ones(2, 2).contiguous().reshape(2, vv)
|
||||
ret = t.var().item()
|
||||
assert ret == 0
|
||||
vv = Variable("a", 1, 10).bind(2)
|
||||
t = Tensor.ones(2, 10).contiguous()[:, :vv]
|
||||
ret = t.var().item()
|
||||
assert ret == 0
|
||||
|
||||
def test_symbolic_pad(self):
|
||||
vv = Variable("a", 1, 10).bind(2)
|
||||
|
||||
@@ -447,7 +447,7 @@ class TestUOpMethod(unittest.TestCase):
|
||||
def test_uop_variables(self):
|
||||
a = UOp.variable("a", 1, 10)
|
||||
uop_var = Tensor(a.bind(1))
|
||||
st_var = Tensor.empty((2, 1)).reshape((2, a.bind(1)))
|
||||
st_var = Tensor.empty((2, 10))[:, :a.bind(1)]
|
||||
_, var_vals = (uop_var+st_var).schedule_with_vars()
|
||||
self.assertEqual(len(var_vals), 1)
|
||||
self.assertEqual(list(var_vals)[0], a)
|
||||
|
||||
@@ -839,25 +839,22 @@ class TestRender(unittest.TestCase):
|
||||
self.assertEqual(idx.render(), "((ridx0*3)+ridx1)")
|
||||
self.assertEqual(valid.render(), "(ridx0<2)")
|
||||
|
||||
class TestVariableReshape(unittest.TestCase):
|
||||
def test_reshape(self):
|
||||
st = ShapeTracker.from_shape((3,))
|
||||
st = st.reshape((Variable("i", 1, 10),))
|
||||
class TestVariableShrink(unittest.TestCase):
|
||||
def test_shrink(self):
|
||||
st = ShapeTracker.from_shape((10,))
|
||||
st = st.shrink(((0, Variable("i", 1, 10)),))
|
||||
assert len(st.views) == 1
|
||||
|
||||
def test_reshape_stride_0(self):
|
||||
st = ShapeTracker.from_shape((3,), (0,))
|
||||
st = st.reshape((Variable("i", 1, 10).bind(3),))
|
||||
assert len(st.views) == 1, f"multiview {st}"
|
||||
|
||||
def test_reshape_bound(self):
|
||||
st = ShapeTracker.from_shape((3,))
|
||||
st = st.reshape((Variable("i", 1, 10).bind(3),))
|
||||
def test_shrink_bound(self):
|
||||
st = ShapeTracker.from_shape((10,))
|
||||
st = st.shrink(((0, Variable("i", 1, 10).bind(3)),))
|
||||
assert len(st.views) == 1
|
||||
|
||||
def test_add(self):
|
||||
st1 = ShapeTracker.from_shape((3,))
|
||||
st2 = ShapeTracker.from_shape((Variable("i", 1, 10),))
|
||||
class TestVariableMerge(unittest.TestCase):
|
||||
def test_add_reshape(self):
|
||||
vi = Variable("i", 1, 10)
|
||||
st1 = ShapeTracker.from_shape((vi,))
|
||||
st2 = ShapeTracker.from_shape((1, vi,))
|
||||
st = st1+st2
|
||||
assert len(st.views) == 1
|
||||
|
||||
@@ -867,15 +864,17 @@ class TestVariableReshape(unittest.TestCase):
|
||||
st = st1+st2
|
||||
assert len(st.views) == 1, f"multiview {st}"
|
||||
|
||||
def test_add_bound(self):
|
||||
st1 = ShapeTracker.from_shape((3,))
|
||||
st2 = ShapeTracker.from_shape((Variable("i", 1, 10).bind(3),))
|
||||
def test_add_reshape_bound(self):
|
||||
vi = Variable("i", 1, 10).bind(3)
|
||||
st1 = ShapeTracker.from_shape((vi,))
|
||||
st2 = ShapeTracker.from_shape((1, vi,))
|
||||
st = st1+st2
|
||||
assert len(st.views) == 1
|
||||
|
||||
def test_simplify(self):
|
||||
st1 = ShapeTracker.from_shape((3,))
|
||||
st2 = ShapeTracker.from_shape((Variable("i", 1, 10).bind(3),))
|
||||
vi = Variable("i", 1, 10).bind(3)
|
||||
st1 = ShapeTracker.from_shape((vi,))
|
||||
st2 = ShapeTracker.from_shape((1, vi,))
|
||||
st = ShapeTracker((st1.views[0], st2.views[0]))
|
||||
st = st.simplify()
|
||||
assert len(st.views) == 1
|
||||
|
||||
@@ -87,20 +87,6 @@ class TestShapeTrackerAdd(unittest.TestCase):
|
||||
assert not (st_equal(st1, st2))
|
||||
|
||||
class TestShapeTrackerAddVariable(unittest.TestCase):
|
||||
def test_self_add(self):
|
||||
j = Variable("j", 0, 20).bind(10)
|
||||
a = ShapeTracker.from_shape((10,10))
|
||||
x = a.reshape((10, j))
|
||||
out = x + x
|
||||
assert out == x
|
||||
|
||||
def test_self_add_reshape(self):
|
||||
j = Variable("j", 0, 20).bind(10)
|
||||
a = ShapeTracker.from_shape((10,10))
|
||||
x = a.reshape((10, j))
|
||||
out = x.reshape((5, 2, j)) + x
|
||||
assert out == x
|
||||
|
||||
def test_merge_symbolic_views(self):
|
||||
var_i = Variable('i', 1, 10)
|
||||
var_j = Variable('i', 1, 10)
|
||||
|
||||
@@ -48,11 +48,11 @@ class TestSymbolic(unittest.TestCase):
|
||||
i = Variable("i", 1, 5).bind(3)
|
||||
j = Variable("j", 1, 5).bind(3)
|
||||
k = Variable("k", 1, 5).bind(3)
|
||||
t = Tensor.rand(3, 4).reshape(i, 4).cat(Tensor.rand(3, 4).reshape(j, 4), dim=0).cat(Tensor.rand(3, 4).reshape(k, 4), dim=0)
|
||||
t = Tensor.rand(5, 4)[:i].cat(Tensor.rand(5, 4)[:j], dim=0).cat(Tensor.rand(5, 4)[:k], dim=0)
|
||||
st = t.uop.st
|
||||
self.assert_tuple_equal(st.shape, (i+j+k, 4))
|
||||
assert st.real_strides() == (4, 1)
|
||||
t = Tensor.rand(3, 3).reshape(i, 3).cat(Tensor.rand(3, 3).reshape(i, 3), dim=0).cat(Tensor.rand(3, 3), dim=0)
|
||||
t = Tensor.rand(5, 3)[:i].cat(Tensor.rand(5, 3)[:i], dim=0).cat(Tensor.rand(3, 3), dim=0)
|
||||
st = t.uop.st
|
||||
self.assert_tuple_equal(st.shape, (2*i+3, 3))
|
||||
assert st.real_strides() == (3, 1)
|
||||
@@ -61,7 +61,7 @@ class TestSymbolic(unittest.TestCase):
|
||||
i = Variable("i", 1, 5).bind(4)
|
||||
j = Variable("j", 1, 5).bind(4)
|
||||
k = Variable("k", 1, 5).bind(4)
|
||||
t = Tensor.rand(3, 4).reshape(3, i).cat(Tensor.rand(3, 4).reshape(3, j), dim=1).cat(Tensor.rand(3, 4).reshape(3, k), dim=1)
|
||||
t = Tensor.rand(3, 5)[:, :i].cat(Tensor.rand(3, 5)[:, :j], dim=1).cat(Tensor.rand(3, 5)[:, :k], dim=1)
|
||||
st = t.uop.st
|
||||
self.assert_tuple_equal(st.shape, (3, i+j+k))
|
||||
self.assert_tuple_equal(st.real_strides(), (i+j+k, 1))
|
||||
@@ -109,60 +109,44 @@ class TestShapeTrackerUnbind(unittest.TestCase):
|
||||
assert unbound_view == View.create(shape=(v, 4))
|
||||
assert var_val == {v: 3}
|
||||
|
||||
def test_reshape_unbind(self):
|
||||
v = Variable("v", 1, 100)
|
||||
bv = Variable("v", 1, 100).bind(3)
|
||||
t = Tensor.rand(3, 4).reshape(bv, 4)
|
||||
unbound_st, var_val = t.uop.st.unbind()
|
||||
assert unbound_st == ShapeTracker((View.create(shape=(v, 4)),))
|
||||
assert var_val == {v: 3}
|
||||
|
||||
def test_shrink_unbind(self):
|
||||
v = Variable("v", 1, 100)
|
||||
bv = Variable("v", 1, 100).bind(2)
|
||||
t = Tensor.rand(3, 4).shrink(((0,bv),(0,4)))
|
||||
unbound_st, var_val = t.uop.st.unbind()
|
||||
assert unbound_st == ShapeTracker((View.create(shape=(v, 4)),))
|
||||
assert var_val == {v: 2}
|
||||
t = Tensor.rand(3, 4).shrink(((bv, bv+1), (0, 4)))
|
||||
unbound_st, var_val = t.uop.st.unbind()
|
||||
assert unbound_st == ShapeTracker((View.create(shape=(1, 4), offset=4*v),))
|
||||
assert var_val == {v: 2}
|
||||
|
||||
class TestSymbolicReshapeFromContiguous(unittest.TestCase):
|
||||
def test_reshape_into_symbols_simple(self):
|
||||
class TestSymbolicReshape(unittest.TestCase):
|
||||
def test_reshape(self):
|
||||
a = Tensor.rand(5, 4)
|
||||
b = Tensor.rand(5, 6)
|
||||
for i in range(1, 6):
|
||||
vi = Variable("i", 1, 5).bind(i)
|
||||
t = Tensor.rand(i, 4).reshape(vi, 4)
|
||||
assert t.shape == (vi, 4)
|
||||
t = Tensor.rand(i, 6).reshape(vi, 2, 3)
|
||||
assert t.shape == (vi, 2, 3)
|
||||
|
||||
def test_reshape_symbols_reshape_ints(self):
|
||||
for i in range(1, 6):
|
||||
vi = Variable("i", 1, 5).bind(i)
|
||||
t = Tensor.rand(i, 4).reshape(vi, 4)
|
||||
assert t.shape == (vi, 4)
|
||||
t = t.reshape(i, 4)
|
||||
assert t.shape == (i, 4)
|
||||
|
||||
@unittest.skip("works now")
|
||||
def test_reshape_into_symbols_bad_shape(self):
|
||||
vi = Variable("i", 1, 10).bind(4)
|
||||
# TODO: this never actually worked, it relied on lazy
|
||||
#with self.assertRaises(ValueError):
|
||||
# Tensor.rand(4, 6).reshape(vi, 6).reshape(1, 77) # reshape to a different size new shape through symbolic shape
|
||||
with self.assertRaises(AssertionError):
|
||||
Tensor.rand(3, 4).reshape(3, (vi+1)) # reshape into non-Variable Node
|
||||
ret = a[:vi]
|
||||
ret = ret.reshape((vi, 4))
|
||||
assert ret.shape == (vi, 4)
|
||||
ret = b[:vi]
|
||||
ret = ret.reshape((vi, 2, 3))
|
||||
assert ret.shape == (vi, 2, 3)
|
||||
|
||||
def test_two_symbol_reshape(self):
|
||||
t = Tensor.rand(5, 5)
|
||||
for i in range(1, 6):
|
||||
for j in range(1, 6):
|
||||
vi = Variable("i", 1, 5).bind(i)
|
||||
vj = Variable("j", 1, 5).bind(j)
|
||||
t = Tensor.rand(i, j).reshape(vi, vj)
|
||||
assert t.shape == (vi, vj)
|
||||
# NOTE: this is currently not allowed
|
||||
# t = t.reshape(1, vi*vj)
|
||||
# assert t.shape == (1, vi*vj)
|
||||
t = t.reshape(vj, vi)
|
||||
assert t.shape == (vj, vi)
|
||||
ret = t[:vi, :vj]
|
||||
ret = ret.reshape(vj, vi)
|
||||
assert ret.shape == (vj, vi)
|
||||
ret = ret.reshape(vi, vj)
|
||||
assert ret.shape == (vi, vj)
|
||||
ret = ret.reshape(1, vi*vj)
|
||||
assert ret.shape == (1, vi*vj)
|
||||
|
||||
def test_symbolic_mask(self):
|
||||
# taken from gpt2 single kvcache
|
||||
@@ -175,41 +159,6 @@ class TestSymbolicReshapeFromContiguous(unittest.TestCase):
|
||||
new_shape = (2, (Variable('start_pos', 1, 128)+1), 16, 64)
|
||||
assert view.reshape(new_shape) is None
|
||||
|
||||
class TestSymbolicReshapeFromNonContiguous(unittest.TestCase):
|
||||
def test_reshape_from_const(self):
|
||||
vi = Variable("i", 1, 5).bind(4)
|
||||
t = Tensor.ones(3, 4).reshape(3, vi)
|
||||
assert t.shape == (3, vi)
|
||||
assert not t.uop.st.contiguous
|
||||
assert len(t.uop.st.views) == 1
|
||||
|
||||
def test_reshape_not_allowed(self):
|
||||
vi = Variable("i", 1, 5).bind(4)
|
||||
with self.assertRaises(ValueError):
|
||||
# different shape length # TODO: cases where contractions matched might be fine
|
||||
Tensor.ones(3, 4, 1).reshape(3, vi)
|
||||
with self.assertRaises(ValueError):
|
||||
# size matched, but dimensions do not match
|
||||
Tensor.ones(4, 3).reshape(3, vi)
|
||||
|
||||
def test_reshape_from_padded(self):
|
||||
vi = Variable("i", 1, 5).bind(4)
|
||||
t = Tensor.ones(3, 4).contiguous().expand(2, 3, 4).pad(((1, 1), None, None)).shrink((None, None, (1, 3)))
|
||||
st = t.uop.st
|
||||
assert len(st.views) == 1
|
||||
view = st.views[0]
|
||||
assert view.shape == (4, 3, 2)
|
||||
t = t.reshape(vi, 3, 2)
|
||||
st2 = t.uop.st
|
||||
assert len(st2.views) == 1
|
||||
view2 = st2.views[0]
|
||||
# check only shape changed. strides, offset, mask, contiguous remained the same
|
||||
assert view2.shape == (vi, 3, 2)
|
||||
assert view.strides == view2.strides == (0, 4, 1)
|
||||
assert view.offset == view2.offset == 1
|
||||
assert view.mask == view2.mask == ((1, 3), (0, 3), (0, 2))
|
||||
assert not view.contiguous and not view2.contiguous
|
||||
|
||||
class TestSymbolicExpand(unittest.TestCase):
|
||||
def test_expand_into_symbols(self):
|
||||
vi = Variable("i", 1, 5).bind(3)
|
||||
@@ -220,11 +169,12 @@ class TestSymbolicExpand(unittest.TestCase):
|
||||
assert a.shape == (3, vi, vj)
|
||||
|
||||
def test_plus_expands_constant(self):
|
||||
a = Tensor.rand(3, 5)
|
||||
for i in range(1, 6):
|
||||
vi = Variable("i", 1, 5).bind(i)
|
||||
a = Tensor.rand(3, i).reshape(3, vi)
|
||||
a = a + 1
|
||||
self.assertTupleEqual(a.shape, (3, vi))
|
||||
ret = a[:, :vi]
|
||||
ret = ret + 1
|
||||
self.assertTupleEqual(ret.shape, (3, vi))
|
||||
|
||||
def test_pad_then_expand_into_symbols(self):
|
||||
vi = Variable("i", 1, 10).bind(3)
|
||||
@@ -234,6 +184,11 @@ class TestSymbolicExpand(unittest.TestCase):
|
||||
self.assertEqual(a.reshape(vi*25).shape, (vi*25,))
|
||||
|
||||
class TestSymbolicShrink(unittest.TestCase):
|
||||
def test_shrink_symbols_simple(self):
|
||||
vi = Variable("i", 1, 5)
|
||||
t = Tensor.rand(5, 5).shrink(((0, 5),(0,vi)))
|
||||
assert t.shape == (5, vi)
|
||||
|
||||
def test_shrink_symbols(self):
|
||||
vi = Variable("i", 1, 5)
|
||||
t = Tensor.rand(3, 5).shrink(((0, 2), (vi, vi+1)))
|
||||
@@ -242,10 +197,10 @@ class TestSymbolicShrink(unittest.TestCase):
|
||||
class TestSymbolicPad(unittest.TestCase):
|
||||
def test_pad(self):
|
||||
v = Variable("v", 1, 100).bind(5)
|
||||
t = Tensor.ones(5).reshape(v).pad(((4, 0),)).reshape(9)
|
||||
assert t.shape == (9,)
|
||||
st = t.uop.st
|
||||
print(st)
|
||||
t = Tensor.ones(100)[:v].pad(((4, 0),))
|
||||
t = t.reshape(9)
|
||||
assert t.tolist() == [0,0,0,0,1,1,1,1,1]
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
@@ -97,7 +97,7 @@ class TestTensorUopRepresentation(unittest.TestCase):
|
||||
is_pattern(a, UPat(Ops.RESHAPE, src=(UPat(Ops.BUFFER),)))
|
||||
vi = UOp.variable("i", 1, 3).bind(1)
|
||||
a = Tensor.empty(3, vi)
|
||||
is_pattern(a, UPat(Ops.RESHAPE, src=(UPat(Ops.BUFFER),)))
|
||||
is_pattern(a, UPat(Ops.RESHAPE, src=(UPat(Ops.SHRINK, src=(UPat(Ops.BUFFER),))),))
|
||||
self.assertEqual(a.uop.base.buffer.size, 9)
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
@@ -119,7 +119,7 @@ def create_kernel(x:UOp, b:UOp|None=None):
|
||||
if b is None: b = UOp.new_buffer(x.device, x.size, x.dtype)
|
||||
kernel = UOp(Ops.KERNEL, src=(b,)+x.src, arg=Kernel(x.sink(), m if (m:=x.metadata) else ()))
|
||||
buffer = b.base if b.size == b.base.size else UOp(Ops.BUFFER_VIEW, b.dtype, (b.base,), (b.size, b.arg.views[0].offset))
|
||||
return buffer.assign(kernel).reshape(x.shape)
|
||||
return buffer.assign(kernel).shrink(((0, prod(x.shape)),)).reshape(x.shape)
|
||||
|
||||
DONT_PLACE_IN_KERNEL = {Ops.KERNEL, Ops.ASSIGN, Ops.BUFFER, Ops.MSELECT, Ops.MSTACK, Ops.MULTI, Ops.BIND}
|
||||
def append_to_kernel(x:UOp):
|
||||
|
||||
@@ -196,7 +196,8 @@ def map_contiguous(ctx:RangeifyContext, x:UOp):
|
||||
ranges = []
|
||||
for s in x.shape[len(x.src)-1:]:
|
||||
ranges.append(ctx.new_range(s) if resolve(s!=1) else UOp.const(dtypes.int, 0))
|
||||
return x.src[0].index(*ranges).bufferize(*x.src[1:], *[x for x in ranges if x.op is not Ops.CONST], arg=x.device).forced_reshape(x.shape)
|
||||
ret = x.src[0].index(*ranges).bufferize(*x.src[1:], *[x for x in ranges if x.op is not Ops.CONST], arg=x.device)
|
||||
return ret.shrink(((0, prod(x.shape)),)).forced_reshape(x.shape)
|
||||
|
||||
def map_reduce(ctx:RangeifyContext, idx:UOp, red:UOp):
|
||||
rngs = list(idx.src[1:])
|
||||
|
||||
@@ -3,7 +3,7 @@ import functools, operator, itertools
|
||||
from dataclasses import dataclass
|
||||
from typing import cast, Sequence
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.uop.ops import resolve, UOp, Variable, sint, sym_infer, smax, smin, sint_to_uop, Ops, ssimplify
|
||||
from tinygrad.uop.ops import resolve, UOp, Variable, sint, smax, smin, sint_to_uop, Ops, ssimplify
|
||||
from tinygrad.helpers import prod, all_int, argsort, flatten, ceildiv
|
||||
|
||||
# returns the axes to create new_shape if new_shape can be created by combining axis from old_shape
|
||||
@@ -311,9 +311,10 @@ class View:
|
||||
|
||||
if not all(x >= 0 for x in new_shape): raise ValueError(f"shape can't contain negative numbers {new_shape}")
|
||||
# check for the same size
|
||||
if (self_all_int := all_int(self.shape)):
|
||||
assert all(isinstance(s, (int, UOp)) for s in new_shape), f"{self.shape=} -> {new_shape=} contains non (int, Variable) dim"
|
||||
if resolve(prod(self.shape) != prod(new_shape), False): raise ValueError(f"size mismatched, can't reshape {self.shape=} -> {new_shape=}")
|
||||
if all_int(self.shape):
|
||||
# reshapes cannot introduce symbolic shape
|
||||
assert all_int(new_shape), f"{self.shape=} -> {new_shape=} contains non int dims"
|
||||
if prod(self.shape) != prod(new_shape): raise ValueError(f"size mismatched, can't reshape {self.shape=} -> {new_shape=}")
|
||||
|
||||
if 0 in self.shape: return View.create(new_shape)
|
||||
if new_shape == () and self.mask and any(mx==my for (mx,my) in self.mask): return None
|
||||
@@ -321,15 +322,6 @@ class View:
|
||||
# after the asserts, it's okay to check contiguous
|
||||
if self.contiguous: return View.create(new_shape)
|
||||
|
||||
# if it's not contiguous and new shape is symbolic, check if it's directly replaceable
|
||||
if self_all_int and not all_int(new_shape):
|
||||
if len(self.shape) != len(new_shape): raise ValueError(f"cannot symbolic reshape non-contiguous {self} -> {new_shape}")
|
||||
for si, so in zip(self.shape, new_shape):
|
||||
if not isinstance(so, int): so = sym_infer(so, dict([v.unbind() for v in so.vars()]))
|
||||
if si != so: raise ValueError(f"cannot symbolic reshape non-contiguous {self} -> {new_shape}")
|
||||
# all dimensions matched, return the new view directly
|
||||
return View(new_shape, self.strides, self.offset, self.mask, self.contiguous)
|
||||
|
||||
r_strides, r_new_shape = [], reversed(new_shape)
|
||||
for merged_size, new_stride, real_size in reversed(merge_dims(self.shape, self.strides, self.mask)):
|
||||
# TODO: write with get_contraction
|
||||
|
||||
@@ -442,7 +442,7 @@ class Tensor(MathTrait):
|
||||
if not isinstance(size:=prod([x.vmax if isinstance(x, UOp) else x for x in shape]), int): raise ValueError(f"size must be int {size}")
|
||||
# TODO: add test for multidevice tensor
|
||||
device = tuple(Device.canonicalize(d) for d in device) if isinstance(device, tuple) else Device.canonicalize(device)
|
||||
return Tensor(UOp.new_buffer(device, size, dtype), device, dtype, **kwargs).reshape(shape)
|
||||
return Tensor(UOp.new_buffer(device, size, dtype), device, dtype, **kwargs).shrink(((0,prod(shape)),)).reshape(shape)
|
||||
|
||||
@staticmethod
|
||||
def from_blob(ptr:int, shape:tuple[int, ...], **kwargs) -> Tensor:
|
||||
|
||||
Reference in New Issue
Block a user