diff --git a/examples/gpt2.py b/examples/gpt2.py index 7b508c1b3a..5c4dd28f2b 100644 --- a/examples/gpt2.py +++ b/examples/gpt2.py @@ -232,7 +232,7 @@ if __name__ == "__main__": gpt2 = GPT2.build_gguf(args.model_size) if args.model_size.startswith("gpt2_gguf_") else GPT2.build(args.model_size) if args.benchmark != -1: - gpt2.model(Tensor.rand(args.batch_size, args.benchmark), Variable("a", 0, MAX_CONTEXT).bind(0)).realize() + gpt2.model(Tensor.randint(args.batch_size, args.benchmark), Variable("a", 0, MAX_CONTEXT).bind(0)).realize() else: texts = gpt2.generate(args.prompt, args.count, args.temperature, timing=args.timing, batch_size=args.batch_size) if not args.noshow: