mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
move gpt2/llama sampling inside the model call (#3013)
* move gpt2/llama sampling inside the model call * argmax uses one more kernel
This commit is contained in:
@@ -77,7 +77,7 @@ class TestRealWorld(unittest.TestCase):
|
||||
@TinyJit
|
||||
def test(t): return model(t, 0).realize()
|
||||
# TODO: test first token vs rest properly
|
||||
helper_test("test_llama", lambda: (Tensor([[1,2,3,4]]),), test, 0.27 if CI else 14.9, 190 if CI else 718, all_jitted=True)
|
||||
helper_test("test_llama", lambda: (Tensor([[1,2,3,4]]),), test, 0.27 if CI else 14.9, 191 if CI else 719, all_jitted=True)
|
||||
|
||||
@unittest.skipIf(Device.DEFAULT in ["LLVM", "GPU"] and CI, "too long on CI LLVM, GPU requires cl_khr_fp16")
|
||||
def test_gpt2(self):
|
||||
|
||||
Reference in New Issue
Block a user