mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 22:48:25 -05:00
Variable.bind newer (#2017)
* Variable.bind attempt 2 * ShapeTracker.unbind * fix llama * fix types * test case * View.vars cleanup * include mask in symbolic source * mask can be sint * st.unbind in bufferops * assert ast contain free Variable only * cleanup * conservative unbinding reduce op arg * move reduceop unbind * fix llama JIT arg behavior
This commit is contained in:
@@ -11,8 +11,8 @@ class TestSymbolicJit(unittest.TestCase):
|
||||
def test_plus1(self):
|
||||
def f(a): return (a+1).realize()
|
||||
jf = TinyJit(f)
|
||||
vi = Variable("i", 1, 10)
|
||||
for i in range(1, 5):
|
||||
vi = Variable("i", 1, 10).bind(i)
|
||||
a = Tensor.rand(3, i)
|
||||
symbolic = jf(a.reshape(3, vi)).reshape(3, i).numpy()
|
||||
expected = f(a).numpy()
|
||||
@@ -20,12 +20,12 @@ class TestSymbolicJit(unittest.TestCase):
|
||||
assert len(jf.jit_cache) == 1
|
||||
|
||||
def test_reshape_inside_plus1(self):
|
||||
vi = Variable("i", 1, 10)
|
||||
def f(a, jit=False, jit_ctx=None):
|
||||
if jit: a = a.reshape(3, vi)
|
||||
if jit: a = a.reshape(3, Variable("i", 1, 10).bind(a.shape[1]))
|
||||
return (a+1).realize()
|
||||
jf = TinyJit(f)
|
||||
for i in range(1, 5):
|
||||
vi = Variable("i", 1, 10)
|
||||
a = Tensor.rand(3, i)
|
||||
symbolic = jf(a, jit=True, jit_ctx={vi: i}).reshape(3, i).numpy()
|
||||
expected = f(a).numpy()
|
||||
@@ -35,8 +35,8 @@ class TestSymbolicJit(unittest.TestCase):
|
||||
def test_add(self):
|
||||
def f(a, b): return (a+b).realize()
|
||||
jf = TinyJit(f)
|
||||
vi = Variable("i", 1, 10)
|
||||
for i in range(1, 5):
|
||||
vi = Variable("i", 1, 10).bind(i)
|
||||
a = Tensor.rand(3, i)
|
||||
b = Tensor.rand(3, i)
|
||||
symbolic = jf(a.reshape(3, vi), b.reshape(3, vi)).reshape(3, i).numpy()
|
||||
@@ -47,8 +47,8 @@ class TestSymbolicJit(unittest.TestCase):
|
||||
def test_matmul(self):
|
||||
def f(a, b): return (a@b).realize()
|
||||
jf = TinyJit(f)
|
||||
vi = Variable("i", 1, 10)
|
||||
for i in range(1, 5):
|
||||
vi = Variable("i", 1, 10).bind(i)
|
||||
a = Tensor.rand(3, i)
|
||||
b = Tensor.rand(i, 5)
|
||||
symbolic = jf(a.reshape(3, vi), b.reshape(vi, 5)).numpy()
|
||||
@@ -63,7 +63,7 @@ class TestSymbolicJit(unittest.TestCase):
|
||||
return s
|
||||
jf = TinyJit(f)
|
||||
for i in range(1, 5):
|
||||
vi = Variable("i", 1, 10)
|
||||
vi = Variable("i", 1, 10).bind(i)
|
||||
a = Tensor.rand(3, i)
|
||||
b = Tensor.rand(i, 5)
|
||||
symbolic = jf(a.reshape(3, vi), b.reshape(vi, 5)).numpy()
|
||||
@@ -74,8 +74,8 @@ class TestSymbolicJit(unittest.TestCase):
|
||||
def test_attention(self):
|
||||
def f(q, k, v): return Tensor.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)).realize()
|
||||
jf = TinyJit(f)
|
||||
vi = Variable("i", 1, 10)
|
||||
for i in range(1, 5):
|
||||
vi = Variable("i", 1, 10).bind(i)
|
||||
q = Tensor.rand(2, 1, 4, 8)
|
||||
k = Tensor.rand(2, i, 4, 8)
|
||||
v = Tensor.rand(2, i, 4, 8)
|
||||
@@ -87,8 +87,8 @@ class TestSymbolicJit(unittest.TestCase):
|
||||
def test_cat_dim0(self):
|
||||
def f(a, b): return a.cat(b, dim=0).realize()
|
||||
jf = TinyJit(f)
|
||||
vi = Variable("i", 1, 10)
|
||||
for i in range(1, 5):
|
||||
vi = Variable("i", 1, 10).bind(i)
|
||||
a = Tensor.rand(i, 3)
|
||||
b = Tensor.rand(2, 3)
|
||||
symbolic = jf(a.reshape(vi, 3), b).reshape(i+2, 3).numpy()
|
||||
@@ -99,8 +99,8 @@ class TestSymbolicJit(unittest.TestCase):
|
||||
def test_cat_dim1(self):
|
||||
def f(a, b): return a.cat(b, dim=1).realize()
|
||||
jf = TinyJit(f)
|
||||
vi = Variable("i", 1, 10)
|
||||
for i in range(1, 5):
|
||||
vi = Variable("i", 1, 10).bind(i)
|
||||
a = Tensor.rand(3, i)
|
||||
b = Tensor.rand(3, 2)
|
||||
symbolic = jf(a.reshape(3, vi), b).reshape(3, i+2).numpy()
|
||||
@@ -111,10 +111,10 @@ class TestSymbolicJit(unittest.TestCase):
|
||||
def test_cat_dim0_two_vars(self):
|
||||
def f(a, b): return a.cat(b, dim=0).realize()
|
||||
jf = TinyJit(f)
|
||||
vi = Variable("i", 1, 10)
|
||||
vj = Variable("j", 1, 10)
|
||||
for i in range(1, 5):
|
||||
for j in range(1, 5):
|
||||
vi = Variable("i", 1, 10).bind(i)
|
||||
vj = Variable("j", 1, 10).bind(j)
|
||||
a = Tensor.rand(i, 3)
|
||||
b = Tensor.rand(j, 3)
|
||||
symbolic = jf(a.reshape(vi, 3), b.reshape(vj, 3)).reshape(i+j, 3).numpy()
|
||||
@@ -125,10 +125,10 @@ class TestSymbolicJit(unittest.TestCase):
|
||||
def test_cat_dim1_two_vars(self):
|
||||
def f(a, b): return a.cat(b, dim=1).realize()
|
||||
jf = TinyJit(f)
|
||||
vi = Variable("i", 1, 10)
|
||||
vj = Variable("j", 1, 10)
|
||||
for i in range(1, 5):
|
||||
for j in range(1, 5):
|
||||
vi = Variable("i", 1, 10).bind(i)
|
||||
vj = Variable("j", 1, 10).bind(j)
|
||||
a = Tensor.rand(3, i)
|
||||
b = Tensor.rand(3, j)
|
||||
symbolic = jf(a.reshape(3, vi), b.reshape(3, vj)).reshape(3, i+j).numpy()
|
||||
@@ -139,10 +139,10 @@ class TestSymbolicJit(unittest.TestCase):
|
||||
def test_two_vars_plus1(self):
|
||||
def f(a, b): return (a@b+1).realize()
|
||||
jf = TinyJit(f)
|
||||
vi = Variable("i", 1, 10)
|
||||
vj = Variable("j", 1, 10)
|
||||
for i in range(1, 5):
|
||||
for j in range(1, 5):
|
||||
vi = Variable("i", 1, 10).bind(i)
|
||||
vj = Variable("j", 1, 10).bind(j)
|
||||
a = Tensor.rand(i, 3)
|
||||
b = Tensor.rand(3, j)
|
||||
symbolic = jf(a.reshape(vi, 3), b.reshape(3, vj)).reshape(i, j).numpy()
|
||||
@@ -153,13 +153,14 @@ class TestSymbolicJit(unittest.TestCase):
|
||||
def test_jit_symbolic_shape_mismatch(self):
|
||||
@TinyJit
|
||||
def add(a, b): return (a+b).realize()
|
||||
vi = Variable("i", 1, 10)
|
||||
for i in range(1, 5):
|
||||
vi = Variable("i", 1, 10).bind(i)
|
||||
a = Tensor.rand(3, i).reshape(3, vi)
|
||||
b = Tensor.rand(3, i).reshape(3, vi)
|
||||
c = add(a, b)
|
||||
a = Tensor.rand(3, 7).reshape(3, vi)
|
||||
bad = Tensor.rand(4, 7).reshape(4, vi)
|
||||
vi2 = Variable("i", 1, 10).bind(7)
|
||||
a = Tensor.rand(3, 7).reshape(3, vi2)
|
||||
bad = Tensor.rand(4, 7).reshape(4, vi2)
|
||||
with self.assertRaises(AssertionError):
|
||||
add(a, bad)
|
||||
|
||||
@@ -167,11 +168,10 @@ class TestSymbolicJit(unittest.TestCase):
|
||||
# shrink is a movement, so we pair it with a simple function to test the JIT interaction
|
||||
def f(a): return (a+1).realize()
|
||||
jf = TinyJit(f)
|
||||
vi = Variable("i", 1, 10)
|
||||
for i in range(1, 5):
|
||||
vi = Variable("i", 1, 10).bind(i)
|
||||
a = Tensor.rand(7, 11)
|
||||
symbolic = a.shrink(((3,5),(vi,vi+2)))
|
||||
symbolic.lazydata.var_vals[vi] = i
|
||||
symbolic = jf(symbolic).numpy()
|
||||
expected = f(a.shrink(((3,5),(i,i+2)))).numpy()
|
||||
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
|
||||
|
||||
Reference in New Issue
Block a user