From fae53948453355fe65112749b0123409cb9bc50d Mon Sep 17 00:00:00 2001
From: chenyu
Date: Fri, 8 Dec 2023 16:42:01 -0500
Subject: [PATCH] validate llama output (#2681)

* validate llama output

* does not work with quantize
---
 examples/llama.py | 18 ++++++++++++++++--
 1 file changed, 16 insertions(+), 2 deletions(-)

diff --git a/examples/llama.py b/examples/llama.py
index a9cdd0b702..6cf183ffdd 100755
--- a/examples/llama.py
+++ b/examples/llama.py
@@ -7,7 +7,7 @@ from pathlib import Path
 import sys, argparse, json
 import numpy as np
 np.set_printoptions(linewidth=200)
-from tinygrad.helpers import Timing, Profiling, getenv, DEBUG, dtypes
+from tinygrad.helpers import Timing, Profiling, getenv, DEBUG, dtypes, colored
 from tinygrad import Device
 from tinygrad.tensor import Tensor
 from tinygrad.nn.state import safe_load, torch_load, load_state_dict, get_parameters
@@ -193,7 +193,7 @@ class LLaMa:
     return output

 # **** main code ****
-"""
+r"""
 test:
 python3 examples/llama.py --temperature=0 --count=50 --prompt="Hello."
 output:
@@ -430,3 +430,17 @@ After you are done speaking, output [EOS]. You are not Chad.
       # stop after you have your answer
       if chatbot and outputted.endswith(end_delim): break
     if not chatbot: break
+
+  # validate output!
+  if args.temperature == 0 and args.count == 10 and args.prompt == "Hello." and not args.quantize:
+    text = llama.tokenizer.decode(toks)
+    key = (args.gen, args.size)
+    expected = {
+      ("1", "7B"): "Hello. I'm a 20 year old male",
+      ("2", "7B"): "Hello. I'm a 20 year old girl",
+    }
+    try:
+      assert text == expected[key], "invalid output: " + colored(text, "red")
+      print("\n" + colored("output validated", "green"))  # NOTE: "\n" inside colored does not render the color in github action
+    except KeyError:
+      pass
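
The added check only fires for one specific invocation: temperature 0, count 10, prompt "Hello.", and no quantization. Assuming --gen and --size are the argparse flags behind args.gen and args.size (they are defined outside this hunk), a run that would exercise the new validation looks like:

python3 examples/llama.py --gen=1 --size=7B --temperature=0 --count=10 --prompt="Hello."

Any other gen/size combination misses the expected dict, hits the KeyError handler, and skips validation silently.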