mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
improvements to apps.llm (#15631)
This commit is contained in:
@@ -394,6 +394,7 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--model", "-m", default=list(models.keys())[0], help=f"Model choice ({', '.join(models.keys())}) or path to a local GGUF file")
|
||||
parser.add_argument("--max_context", type=int, default=4096, help="Max Context Length")
|
||||
parser.add_argument("--serve", nargs='?', type=int, const=11434, metavar="PORT", help="Run OpenAI compatible API (optional port, default 11434)")
|
||||
parser.add_argument("--warmup", action="store_true", help="warmup the JIT")
|
||||
parser.add_argument("--benchmark", nargs='?', type=int, const=20, metavar="COUNT", help="Benchmark tok/s (optional count, default 20)")
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -412,8 +413,17 @@ if __name__ == "__main__":
|
||||
bos_id: int|None = kv.get('tokenizer.ggml.bos_token_id') if kv.get('tokenizer.ggml.add_bos_token', True) else None
|
||||
eos_id: int = kv['tokenizer.ggml.eos_token_id']
|
||||
|
||||
# warmup the JIT
|
||||
if args.warmup or args.serve:
|
||||
# run 2 tokens through the model twice to capture the JIT before serving, benchmarking, or chatting
|
||||
with Context(DEBUG=max(DEBUG.value, 1)):
|
||||
for _ in range(2): list(zip(range(2), model.generate([0])))
|
||||
|
||||
# start server
|
||||
if args.serve: TCPServerWithReuse(('', args.serve), Handler).serve_forever()
|
||||
|
||||
# do benchmark
|
||||
if args.benchmark:
|
||||
if args.benchmark is not None:
|
||||
gen = model.generate(toks:=[bos_id or 0])
|
||||
for _ in range(args.benchmark):
|
||||
GlobalCounters.reset()
|
||||
@@ -422,13 +432,6 @@ if __name__ == "__main__":
|
||||
tok.decode(toks).replace("\n", "\\n")): next(gen)
|
||||
exit(0)
|
||||
|
||||
# start server
|
||||
if args.serve:
|
||||
# warmup: run 2 tokens through the model twice to capture the JIT before serving
|
||||
with Context(DEBUG=max(DEBUG.value, 1)):
|
||||
for _ in range(2): list(zip(range(2), model.generate([0])))
|
||||
TCPServerWithReuse(('', args.serve), Handler).serve_forever()
|
||||
|
||||
# interactive chat
|
||||
ids: list[int] = [bos_id] if bos_id is not None else []
|
||||
while 1:
|
||||
|
||||
Reference in New Issue
Block a user