fix gpt2 with benchmark (#12736)

`CPU=1 python3 examples/gpt2.py --benchmark 128` works now
This commit is contained in:
chenyu
2025-10-16 09:55:20 -04:00
committed by GitHub
parent 55db1b0e0e
commit f34f26bca0

View File

@@ -232,7 +232,7 @@ if __name__ == "__main__":
gpt2 = GPT2.build_gguf(args.model_size) if args.model_size.startswith("gpt2_gguf_") else GPT2.build(args.model_size)
if args.benchmark != -1:
gpt2.model(Tensor.rand(args.batch_size, args.benchmark), Variable("a", 0, MAX_CONTEXT).bind(0)).realize()
gpt2.model(Tensor.randint(args.batch_size, args.benchmark), Variable("a", 0, MAX_CONTEXT).bind(0)).realize()
else:
texts = gpt2.generate(args.prompt, args.count, args.temperature, timing=args.timing, batch_size=args.batch_size)
if not args.noshow: