examples of new GPT2 and JIT change (#2261)

* var_vals are global

* working with global ish

* better

* fix export model

* fix tests

* better kv cache

* does it run?

* use where for kvmask

* fix excessive var_vals

* fix import

* how does multigpu use this?

* llama kinda work

* faster and simpler

* cleanup

* fix conversation mode

* test cleanups

* fix one more test

* test cleanup

---------

Co-authored-by: George Hotz <geohot@gmail.com>
This commit is contained in:
chenyu
2023-11-10 15:07:02 -05:00
committed by GitHub
parent b6aaf12df7
commit a753c8e071
15 changed files with 189 additions and 249 deletions

View File

@@ -19,19 +19,6 @@ class TestSymbolicJit(unittest.TestCase):
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
assert len(jf.jit_cache) == 1
def test_reshape_inside_plus1(self):
def f(a, jit=False, jit_ctx=None):
if jit: a = a.reshape(3, Variable("i", 1, 10).bind(a.shape[1]))
return (a+1).realize()
jf = TinyJit(f)
for i in range(1, 5):
vi = Variable("i", 1, 10)
a = Tensor.rand(3, i)
symbolic = jf(a, jit=True, jit_ctx={vi: i}).reshape(3, i).numpy()
expected = f(a).numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
assert len(jf.jit_cache) == 1
def test_add(self):
def f(a, b): return (a+b).realize()
jf = TinyJit(f)