Variable.bind newer (#2017)

* Variable.bind attempt 2

* ShapeTracker.unbind

* fix llama

* fix types

* test case

* View.vars cleanup

* include mask in symbolic source

* mask can be sint

* st.unbind in bufferops

* assert ast contains free Variables only

* cleanup

* conservative unbinding of reduce op arg

* move reduceop unbind

* fix llama JIT arg behavior
Author: chenyu
Date: 2023-10-10 13:03:01 -04:00
Committed by: GitHub
Parent: 71d93ffd79
Commit: e2b83f1b42
15 changed files with 204 additions and 199 deletions
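
The pattern these diffs converge on, condensed into a standalone sketch. This is an illustration of the change, not part of it; the imports are the ones this test file uses, and the loop mirrors test_plus1 below. The old style declared a single unbound Variable and supplied its concrete value out of band; the new style binds the value to the Variable at the point of use, so the JIT can unbind it and cache one kernel for every i.

from tinygrad.tensor import Tensor
from tinygrad.jit import TinyJit
from tinygrad.shape.symbolic import Variable

@TinyJit
def f(a): return (a + 1).realize()

for i in range(1, 5):
  # bind attaches the concrete value i to the symbolic axis; the JIT
  # sees the free Variable "i" plus its current value, per the commit
  # message ("assert ast contains free Variables only").
  vi = Variable("i", 1, 10).bind(i)
  a = Tensor.rand(3, i)
  out = f(a.reshape(3, vi)).reshape(3, i).numpy()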

test/test_symbolic_jit.py

@@ -11,8 +11,8 @@ class TestSymbolicJit(unittest.TestCase):
   def test_plus1(self):
     def f(a): return (a+1).realize()
     jf = TinyJit(f)
-    vi = Variable("i", 1, 10)
     for i in range(1, 5):
+      vi = Variable("i", 1, 10).bind(i)
       a = Tensor.rand(3, i)
       symbolic = jf(a.reshape(3, vi)).reshape(3, i).numpy()
       expected = f(a).numpy()
@@ -20,12 +20,12 @@ class TestSymbolicJit(unittest.TestCase):
     assert len(jf.jit_cache) == 1
 
   def test_reshape_inside_plus1(self):
-    vi = Variable("i", 1, 10)
     def f(a, jit=False, jit_ctx=None):
-      if jit: a = a.reshape(3, vi)
+      if jit: a = a.reshape(3, Variable("i", 1, 10).bind(a.shape[1]))
       return (a+1).realize()
     jf = TinyJit(f)
     for i in range(1, 5):
+      vi = Variable("i", 1, 10)
       a = Tensor.rand(3, i)
       symbolic = jf(a, jit=True, jit_ctx={vi: i}).reshape(3, i).numpy()
       expected = f(a).numpy()
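
This hunk is the one spot where an unbound Variable survives: inside f the reshape now binds directly to the concrete runtime dimension, while the vi recreated in the loop stays unbound because it only serves as the jit_ctx key. Condensed, with the same names as the test:

def f(a, jit=False, jit_ctx=None):
  # a.shape[1] is a plain int when f runs, so the symbolic axis created
  # here already carries its concrete value.
  if jit: a = a.reshape(3, Variable("i", 1, 10).bind(a.shape[1]))
  return (a + 1).realize()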
@@ -35,8 +35,8 @@ class TestSymbolicJit(unittest.TestCase):
   def test_add(self):
     def f(a, b): return (a+b).realize()
     jf = TinyJit(f)
-    vi = Variable("i", 1, 10)
     for i in range(1, 5):
+      vi = Variable("i", 1, 10).bind(i)
       a = Tensor.rand(3, i)
       b = Tensor.rand(3, i)
       symbolic = jf(a.reshape(3, vi), b.reshape(3, vi)).reshape(3, i).numpy()
@@ -47,8 +47,8 @@ class TestSymbolicJit(unittest.TestCase):
   def test_matmul(self):
     def f(a, b): return (a@b).realize()
     jf = TinyJit(f)
-    vi = Variable("i", 1, 10)
     for i in range(1, 5):
+      vi = Variable("i", 1, 10).bind(i)
       a = Tensor.rand(3, i)
       b = Tensor.rand(i, 5)
       symbolic = jf(a.reshape(3, vi), b.reshape(vi, 5)).numpy()
@@ -63,7 +63,7 @@ class TestSymbolicJit(unittest.TestCase):
       return s
     jf = TinyJit(f)
     for i in range(1, 5):
-      vi = Variable("i", 1, 10)
+      vi = Variable("i", 1, 10).bind(i)
       a = Tensor.rand(3, i)
       b = Tensor.rand(i, 5)
       symbolic = jf(a.reshape(3, vi), b.reshape(vi, 5)).numpy()
@@ -74,8 +74,8 @@ class TestSymbolicJit(unittest.TestCase):
   def test_attention(self):
     def f(q, k, v): return Tensor.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)).realize()
     jf = TinyJit(f)
-    vi = Variable("i", 1, 10)
     for i in range(1, 5):
+      vi = Variable("i", 1, 10).bind(i)
       q = Tensor.rand(2, 1, 4, 8)
       k = Tensor.rand(2, i, 4, 8)
       v = Tensor.rand(2, i, 4, 8)
@@ -87,8 +87,8 @@ class TestSymbolicJit(unittest.TestCase):
   def test_cat_dim0(self):
     def f(a, b): return a.cat(b, dim=0).realize()
     jf = TinyJit(f)
-    vi = Variable("i", 1, 10)
     for i in range(1, 5):
+      vi = Variable("i", 1, 10).bind(i)
       a = Tensor.rand(i, 3)
       b = Tensor.rand(2, 3)
       symbolic = jf(a.reshape(vi, 3), b).reshape(i+2, 3).numpy()
@@ -99,8 +99,8 @@ class TestSymbolicJit(unittest.TestCase):
   def test_cat_dim1(self):
     def f(a, b): return a.cat(b, dim=1).realize()
     jf = TinyJit(f)
-    vi = Variable("i", 1, 10)
     for i in range(1, 5):
+      vi = Variable("i", 1, 10).bind(i)
       a = Tensor.rand(3, i)
       b = Tensor.rand(3, 2)
       symbolic = jf(a.reshape(3, vi), b).reshape(3, i+2).numpy()
@@ -111,10 +111,10 @@ class TestSymbolicJit(unittest.TestCase):
   def test_cat_dim0_two_vars(self):
     def f(a, b): return a.cat(b, dim=0).realize()
     jf = TinyJit(f)
-    vi = Variable("i", 1, 10)
-    vj = Variable("j", 1, 10)
     for i in range(1, 5):
       for j in range(1, 5):
+        vi = Variable("i", 1, 10).bind(i)
+        vj = Variable("j", 1, 10).bind(j)
         a = Tensor.rand(i, 3)
         b = Tensor.rand(j, 3)
         symbolic = jf(a.reshape(vi, 3), b.reshape(vj, 3)).reshape(i+j, 3).numpy()
@@ -125,10 +125,10 @@ class TestSymbolicJit(unittest.TestCase):
   def test_cat_dim1_two_vars(self):
     def f(a, b): return a.cat(b, dim=1).realize()
     jf = TinyJit(f)
-    vi = Variable("i", 1, 10)
-    vj = Variable("j", 1, 10)
     for i in range(1, 5):
       for j in range(1, 5):
+        vi = Variable("i", 1, 10).bind(i)
+        vj = Variable("j", 1, 10).bind(j)
         a = Tensor.rand(3, i)
         b = Tensor.rand(3, j)
         symbolic = jf(a.reshape(3, vi), b.reshape(3, vj)).reshape(3, i+j).numpy()
@@ -139,10 +139,10 @@ class TestSymbolicJit(unittest.TestCase):
   def test_two_vars_plus1(self):
     def f(a, b): return (a@b+1).realize()
     jf = TinyJit(f)
-    vi = Variable("i", 1, 10)
-    vj = Variable("j", 1, 10)
     for i in range(1, 5):
       for j in range(1, 5):
+        vi = Variable("i", 1, 10).bind(i)
+        vj = Variable("j", 1, 10).bind(j)
         a = Tensor.rand(i, 3)
         b = Tensor.rand(3, j)
         symbolic = jf(a.reshape(vi, 3), b.reshape(3, vj)).reshape(i, j).numpy()
@@ -153,13 +153,14 @@ class TestSymbolicJit(unittest.TestCase):
   def test_jit_symbolic_shape_mismatch(self):
     @TinyJit
     def add(a, b): return (a+b).realize()
-    vi = Variable("i", 1, 10)
     for i in range(1, 5):
+      vi = Variable("i", 1, 10).bind(i)
       a = Tensor.rand(3, i).reshape(3, vi)
       b = Tensor.rand(3, i).reshape(3, vi)
       c = add(a, b)
-    a = Tensor.rand(3, 7).reshape(3, vi)
-    bad = Tensor.rand(4, 7).reshape(4, vi)
+    vi2 = Variable("i", 1, 10).bind(7)
+    a = Tensor.rand(3, 7).reshape(3, vi2)
+    bad = Tensor.rand(4, 7).reshape(4, vi2)
     with self.assertRaises(AssertionError):
       add(a, bad)
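
Once bound, a Variable carries a single concrete value, so the size-7 check above cannot reuse a vi bound inside the loop; it creates a fresh binding. A two-line sketch of that reading (my inference from the diff: the symbol name and bounds are shared, only the bound value differs):

vi = Variable("i", 1, 10).bind(4)   # one binding per concrete size
vi2 = Variable("i", 1, 10).bind(7)  # same symbol and range, new value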
@@ -167,11 +168,10 @@ class TestSymbolicJit(unittest.TestCase):
     # shrink is a movement, so we pair it with a simple function to test the JIT interaction
     def f(a): return (a+1).realize()
     jf = TinyJit(f)
-    vi = Variable("i", 1, 10)
     for i in range(1, 5):
+      vi = Variable("i", 1, 10).bind(i)
       a = Tensor.rand(7, 11)
       symbolic = a.shrink(((3,5),(vi,vi+2)))
-      symbolic.lazydata.var_vals[vi] = i
       symbolic = jf(symbolic).numpy()
       expected = f(a.shrink(((3,5),(i,i+2)))).numpy()
       np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
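
Before this change the shrink test had to write the concrete value into lazydata.var_vals by hand; with a bound Variable the value travels with the shape expression and that bookkeeping line disappears. Usage, mirroring the final hunk (f and jf as defined there), with the symbolic shrink bounds relating to the "mask can be sint" bullet in the commit message:

for i in range(1, 5):
  vi = Variable("i", 1, 10).bind(i)
  a = Tensor.rand(7, 11)
  # the shrink window (vi, vi+2) is symbolic; its concrete position
  # comes from the value bound to vi, not from manual var_vals edits
  symbolic = a.shrink(((3, 5), (vi, vi + 2)))
  out = jf(symbolic).numpy()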