move gpt2/llama sampling inside the model call (#3013)

* move gpt2/llama sampling inside the model call

* argmax uses one more kernel
This commit is contained in:
chenyu
2024-01-04 17:01:50 -05:00
committed by GitHub
parent c2a044ed83
commit f88506e630
6 changed files with 21 additions and 24 deletions

View File

@@ -77,7 +77,7 @@ class TestRealWorld(unittest.TestCase):
@TinyJit
def test(t): return model(t, 0).realize()
# TODO: test first token vs rest properly
-        helper_test("test_llama", lambda: (Tensor([[1,2,3,4]]),), test, 0.27 if CI else 14.9, 190 if CI else 718, all_jitted=True)
+        helper_test("test_llama", lambda: (Tensor([[1,2,3,4]]),), test, 0.27 if CI else 14.9, 191 if CI else 719, all_jitted=True)
@unittest.skipIf(Device.DEFAULT in ["LLVM", "GPU"] and CI, "too long on CI LLVM, GPU requires cl_khr_fp16")
def test_gpt2(self):