From b0e70ab04f1c06e48c48c93c6086cc7befe731e4 Mon Sep 17 00:00:00 2001 From: George Hotz Date: Fri, 27 Sep 2024 15:25:59 +0800 Subject: [PATCH] llm.c updates --- examples/llm.c/train_gpt2.py | 23 +++++++++++------------ 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/examples/llm.c/train_gpt2.py b/examples/llm.c/train_gpt2.py index 388061dd20..f378f9ca0d 100755 --- a/examples/llm.c/train_gpt2.py +++ b/examples/llm.c/train_gpt2.py @@ -120,6 +120,7 @@ if __name__ == "__main__": parser.add_argument("--num_iterations", type=int, default=10, help="number of iterations to run") parser.add_argument("--batch_size", type=int, default=4, help="batch size") parser.add_argument("--sequence_length", type=int, default=64, help="sequence length") + parser.add_argument("--skip_test", action="store_true", help="skip test") args = parser.parse_args() B, T = args.batch_size, args.sequence_length assert 1 <= T <= 1024 @@ -135,10 +136,7 @@ if __name__ == "__main__": # load the tokens # prefer to use tiny_shakespeare if it's available, otherwise use tiny_stories # we're using val instead of train split just because it is smaller/faster - shake_tokens_bin = "data/tiny_shakespeare_val.bin" - story_tokens_bin = "data/TinyStories_val.bin" - assert os.path.isfile(shake_tokens_bin) or os.path.isfile(story_tokens_bin), "you must run prepro on some dataset" - tokens_bin = shake_tokens_bin if os.path.isfile(shake_tokens_bin) else story_tokens_bin + tokens_bin = fetch("https://huggingface.co/datasets/karpathy/llmc-starter-pack/resolve/main/tiny_shakespeare_val.bin") assert os.path.isfile(tokens_bin) print(f"loading cached tokens in {tokens_bin}") with open(tokens_bin, "rb") as f: @@ -181,12 +179,13 @@ if __name__ == "__main__": t1 = time.time() print(f"iteration {i}, loss: {loss.item()}, time: {(t1-t0)*1000:.3f}ms") - start = "<|endoftext|>" - start_ids = encode(start) - x = (Tensor(start_ids)[None, ...]) - max_new_tokens = 16 - temperature = 1.0 - top_k = 40 - y = model.generate(x, max_new_tokens, temperature=temperature, top_k=top_k) - print(decode(y[0].tolist())) + if not args.skip_test: + start = "<|endoftext|>" + start_ids = encode(start) + x = (Tensor(start_ids)[None, ...]) + max_new_tokens = 16 + temperature = 1.0 + top_k = 40 + y = model.generate(x, max_new_tokens, temperature=temperature, top_k=top_k) + print(decode(y[0].tolist()))