From d3de63d9984e775c9fceb2dec9811ceccc0b3f3c Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Tue, 7 Apr 2026 20:34:05 +0800 Subject: [PATCH] improvements to apps.llm (#15631) --- tinygrad/apps/llm.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/tinygrad/apps/llm.py b/tinygrad/apps/llm.py index e1480b2e8c..6fe4a19201 100644 --- a/tinygrad/apps/llm.py +++ b/tinygrad/apps/llm.py @@ -394,6 +394,7 @@ if __name__ == "__main__": parser.add_argument("--model", "-m", default=list(models.keys())[0], help=f"Model choice ({', '.join(models.keys())}) or path to a local GGUF file") parser.add_argument("--max_context", type=int, default=4096, help="Max Context Length") parser.add_argument("--serve", nargs='?', type=int, const=11434, metavar="PORT", help="Run OpenAI compatible API (optional port, default 11434)") + parser.add_argument("--warmup", action="store_true", help="warmup the JIT") parser.add_argument("--benchmark", nargs='?', type=int, const=20, metavar="COUNT", help="Benchmark tok/s (optional count, default 20)") args = parser.parse_args() @@ -412,8 +413,17 @@ if __name__ == "__main__": bos_id: int|None = kv.get('tokenizer.ggml.bos_token_id') if kv.get('tokenizer.ggml.add_bos_token', True) else None eos_id: int = kv['tokenizer.ggml.eos_token_id'] + # warmup the JIT + if args.warmup or args.serve: + # run 2 tokens through the model twice to capture the JIT before serving + with Context(DEBUG=max(DEBUG.value, 1)): + for _ in range(2): list(zip(range(2), model.generate([0]))) + + # start server + if args.serve: TCPServerWithReuse(('', args.serve), Handler).serve_forever() + # do benchmark - if args.benchmark: + if args.benchmark is not None: gen = model.generate(toks:=[bos_id or 0]) for _ in range(args.benchmark): GlobalCounters.reset() @@ -422,13 +432,6 @@ if __name__ == "__main__": tok.decode(toks).replace("\n", "\\n")): next(gen) exit(0) - # start server - if args.serve: - # warmup: run 2 tokens through the model twice to capture the JIT before serving - with Context(DEBUG=max(DEBUG.value, 1)): - for _ in range(2): list(zip(range(2), model.generate([0]))) - TCPServerWithReuse(('', args.serve), Handler).serve_forever() - # interactive chat ids: list[int] = [bos_id] if bos_id is not None else [] while 1: