mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 22:48:25 -05:00
examples of new GPT2 and JIT change (#2261)
* var_vals are global * working with global ish * better * fix export model * fix tests * better kv cache * does it run? * use where for kvmask * fix excessive var_vals * fix import * how does multigpu use this? * llama kinda work * faster and simpler * cleanup * fix conversation mode * test cleanups * fix one more test * test cleanup --------- Co-authored-by: George Hotz <geohot@gmail.com>
This commit is contained in:
@@ -19,19 +19,6 @@ class TestSymbolicJit(unittest.TestCase):
|
||||
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
|
||||
assert len(jf.jit_cache) == 1
|
||||
|
||||
def test_reshape_inside_plus1(self):
|
||||
def f(a, jit=False, jit_ctx=None):
|
||||
if jit: a = a.reshape(3, Variable("i", 1, 10).bind(a.shape[1]))
|
||||
return (a+1).realize()
|
||||
jf = TinyJit(f)
|
||||
for i in range(1, 5):
|
||||
vi = Variable("i", 1, 10)
|
||||
a = Tensor.rand(3, i)
|
||||
symbolic = jf(a, jit=True, jit_ctx={vi: i}).reshape(3, i).numpy()
|
||||
expected = f(a).numpy()
|
||||
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
|
||||
assert len(jf.jit_cache) == 1
|
||||
|
||||
def test_add(self):
|
||||
def f(a, b): return (a+b).realize()
|
||||
jf = TinyJit(f)
|
||||
|
||||
Reference in New Issue
Block a user