refactor test/external/external_llama_eval.py (#10567)
Co-authored-by: wozeparrot <wozeparrot@gmail.com>
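This refactor ports the eval adaptor from the old lm_eval.base.BaseLM interface to the current lm_eval.api.model.LM one, whose required surface is just three methods. A minimal sketch of that interface for orientation (the EchoLM name and trivial bodies are illustrative, not part of this commit; the signatures mirror the ones in the added file below):

    from lm_eval.api.instance import Instance
    from lm_eval.api.model import LM

    class EchoLM(LM):
      # hypothetical stub: the three abstract methods an LM subclass must provide
      def generate_until(self, requests: list[Instance]) -> list[str]:
        return ["" for _ in requests]  # one continuation per request
      def loglikelihood(self, requests: list[Instance]) -> list[tuple[float, bool]]:
        raise NotImplementedError()
      def loglikelihood_rolling(self, requests: list[Instance]) -> list[tuple[float, bool]]:
        raise NotImplementedError()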
102  test/external/external_llama_eval.py (vendored)
@@ -1,102 +0,0 @@
from lm_eval.base import BaseLM
from lm_eval import evaluator, tasks
import torch, json, argparse

from examples.llama import LLaMa
from tinygrad.tensor import Tensor
from tinygrad import Device

class LLaMaAdaptor(BaseLM):
  def __init__(
    self,
    model_size="7B",
    model_gen=1,
    device="",
    quantize=False,
    batch_size=1,
    max_batch_size=1,
    do_sample=False,
    temperature=1.0,
    checkpoint_path="",
    tokenizer_path="",
  ):
    super().__init__()

    if batch_size is None:
      batch_size = 1
    self.do_sample = do_sample
    self.temperature = temperature
    self._device = device

    assert isinstance(model_gen, int)
    assert isinstance(model_size, str)
    assert isinstance(batch_size, int)
    assert isinstance(checkpoint_path, str)
    assert isinstance(tokenizer_path, str)

    self.llama = LLaMa.build(checkpoint_path, tokenizer_path, model_gen, model_size, quantize)

  @classmethod
  def create_from_arg_string(cls, arg_string, additional_config=None):
    kwargs = {el.split("=")[0]: el.split("=")[1] for el in arg_string.split(",")}
    return cls(**kwargs, **additional_config)

  @property
  def eot_token_id(self):
    # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
    return self.llama.tokenizer.eos_id()

  @property
  def max_length(self):
    return 1024

  @property
  def max_gen_toks(self):
    return 256

  @property
  def batch_size(self):
    return 1

  @property
  def device(self):
    return self._device

  def tok_encode(self, string: str):
    return [self.llama.tokenizer.bos_id()] + self.llama.tokenizer.encode(string)

  def tok_decode(self, tokens):
    return self.llama.tokenizer.decode(tokens)

  def _model_call(self, inps):
    return torch.Tensor(self.llama.model(Tensor(inps.numpy()), 0).numpy())

  def greedy_until(self, requests):
    continuations = []
    for request in requests:
      prompt, until = request[0], request[1]['until']
      output = self.llama.greedy_until(prompt, until, max_length=128, temperature=0.0)
      continuations.append(output[len(prompt):])
    return continuations

  def _model_generate(self, context, max_length, eos_token_id):
    raise NotImplementedError()

if __name__ == '__main__':
  print(f"using {Device.DEFAULT} backend")

  parser = argparse.ArgumentParser(description='Run LLaMA evals in tinygrad', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
  parser.add_argument('--size', type=str, default="7B", help="Size of model to use [7B, 13B, 30B, 65B] for Gen 1, [7B, 13B] for Gen 2")
  parser.add_argument('--gen', type=int, default="1", help="Generation of the model to use [1, 2]")
  parser.add_argument('--quantize', action='store_true', help="Quantize the weights to int8 in memory")
  parser.add_argument('--eval', type=str, default="arc_easy", help="Run in evaluation mode")
  parser.add_argument('--limit', type=int, default=None, help="Limit tests in eval")
  parser.add_argument('--weights', type=str, default="./weights/LLaMa/", help="Location of the weights")
  parser.add_argument('--tokenizer', type=str, default="./weights/LLaMa/tokenizer.model", help="Location of the tokenizer")
  args = parser.parse_args()

  # run eval and exit
  adaptor = LLaMaAdaptor(model_gen=args.gen, model_size=args.size, quantize=args.quantize,
                         checkpoint_path=args.weights, tokenizer_path=args.tokenizer, device="cpu")
  results = evaluator.evaluate(adaptor, tasks.get_task_dict(args.eval.split(",")), False, 0, args.limit)
  print(json.dumps(results, indent=2))
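One detail of the deleted adaptor worth noting: create_from_arg_string parses comma-separated key=value pairs with a dict comprehension, so every parsed value arrives as a string. A small standalone illustration of that behavior (the example arg_string is hypothetical):

    arg_string = "model_size=7B,quantize=True"
    kwargs = {el.split("=")[0]: el.split("=")[1] for el in arg_string.split(",")}
    print(kwargs)  # {'model_size': '7B', 'quantize': 'True'} -- 'True' is a str (truthy), not the bool True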
100  test/external/sglang_llama/external_llama_eval.py (vendored, new file)
@@ -0,0 +1,100 @@
from lm_eval import simple_evaluate
from lm_eval.api.instance import Instance
from lm_eval.api.model import LM
from lm_eval.tasks import TaskManager
from pathlib import Path
import json, argparse

from examples.llama3 import build_transformer, Tokenizer, MODEL_PARAMS
from tinygrad import Tensor, Device
from tinygrad.helpers import tqdm

class LLaMaAdaptor(LM):
  def __init__(
    self,
    model_size: str,
    checkpoint_path: Path,
    max_length: int,
    quantize: str | None,
  ):
    super().__init__()
    self.max_length = max_length
    self.tokenizer = Tokenizer(str((checkpoint_path if checkpoint_path.is_dir() else checkpoint_path.parent) / "tokenizer.model"))
    self.model = build_transformer(checkpoint_path, model_size=model_size, quantize=quantize, max_context=self.max_length)
    self.last_seen_toks = []

  def _prefill(self, toks, temperature) -> int:
    start_pos = 0
    # we can skip part of the prompt if it is the same as last
    for i, (a, b) in enumerate(zip(toks, self.last_seen_toks)):
      if a != b: break
    else: i = min(len(toks), len(self.last_seen_toks))
    start_pos += i
    self.last_seen_toks = toks
    toks = toks[i:]

    # prefill the model
    for tok in toks:
      self.model(Tensor([[tok]]), start_pos, temperature).realize()
      start_pos += 1
    return start_pos

  @property
  def tokenizer_name(self) -> str: pass
  def chat_template(self, chat_template: bool | str = False) -> str: pass
  def apply_chat_template(self, chat_history: list[dict[str, str]], add_generation_prompt: bool = True) -> str:
    ret = ""
    for message in chat_history:
      ret += f"<|start_header_id|>{message['role']}<|end_header_id|>\n\n{message['content'].strip()}<|eot_id|>"
    if add_generation_prompt: ret += "<|start_header_id|>assistant<|end_header_id|>\n\n"
    return ret

  def generate_until(self, requests: list[Instance]) -> list[str]:
    continuations = []
    for request in tqdm(requests):
      prompt, args = request.args
      until = [self.tokenizer.encode(tok) for tok in args.get("until", [])]
      toks = [self.tokenizer.bos_id] + self.tokenizer.encode(prompt, allow_special=True)
      prompt_len = len(toks)
      max_gen_toks = args.get("max_gen_toks") or args.get("max_length") or self.max_length - prompt_len
      assert self.max_length >= max_gen_toks, "This eval needs a longer context length"
      temperature = args.get("temperature", 0.0)
      start_pos = self._prefill(toks[:-1], temperature)

      for _ in range(max_gen_toks):
        next_tok = self.model(Tensor([toks[start_pos:]]), start_pos, temperature).item()
        if next_tok in self.tokenizer.stop_tokens or next_tok in until: break
        toks.append(next_tok)
        start_pos += 1

      continuations.append(self.tokenizer.decode(toks[prompt_len:]))
    return continuations

  def loglikelihood(self, requests: list[Instance]) -> list[tuple[float, bool]]: raise NotImplementedError()  # needs changes to extra/models/llama.py
  def loglikelihood_rolling(self, requests: list[Instance]) -> list[tuple[float, bool]]: raise NotImplementedError()

if __name__ == '__main__':
  print(f"using {Device.DEFAULT} backend")

  parser = argparse.ArgumentParser(description='Run LLaMA evals in tinygrad', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
  parser.add_argument('--size', type=str, default="8B", help=f"Size of model to use [{', '.join(list(MODEL_PARAMS.keys()))}]")
  parser.add_argument('--chat', action='store_true', help="Use chat model")
  parser.add_argument('--ctx', type=int, default=8192, help="Max context length")
  parser.add_argument('--quantize', type=str, default=None, help="Quantize the weights to int8 or int4 in memory")
  parser.add_argument('--eval', type=str, default="mgsm_en_cot_sglang", help="Run in evaluation mode")
  parser.add_argument('--limit', type=int, default=None, help="Limit tests in eval")
  parser.add_argument('--num_fewshot', type=int, default=None, help="Number of examples to add to context")
  parser.add_argument('--model', type=Path, default="./weights/LLaMa/", help="Location of the weights")
  parser.add_argument('--output_path', type=Path, default=None, help="Location of the log file")
  args = parser.parse_args()

  # run eval and exit
  adaptor = LLaMaAdaptor(model_size=args.size, quantize=args.quantize,
                         checkpoint_path=args.model, max_length=args.ctx)
  task_manager = TaskManager(include_path="./")
  results = simple_evaluate(model=adaptor, tasks=args.eval.split(","), task_manager=task_manager, apply_chat_template=args.chat,
                            num_fewshot=args.num_fewshot, limit=args.limit)

  if args.output_path: args.output_path.write_text(json.dumps(results, indent=2))
  for task_name, val in results["results"].items():
    print(f"{task_name}:")
    print("\n".join(f"\t{k}: {v}" for k, v in val.items() if k != "alias"))
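The prompt caching in _prefill can be read in isolation: it counts the leading tokens shared with the previous request and only feeds the model the suffix that changed, which matters when few-shot evals reuse a long shared preamble. A minimal sketch of that common-prefix computation (the helper name is illustrative, not from the commit):

    def common_prefix_len(toks: list[int], last_seen: list[int]) -> int:
      # count matching leading tokens; mirrors the for/else loop in _prefill
      for i, (a, b) in enumerate(zip(toks, last_seen)):
        if a != b: return i
      return min(len(toks), len(last_seen))

    assert common_prefix_len([1, 2, 3, 4], [1, 2, 9]) == 2  # only the suffix [3, 4] needs prefilling
    assert common_prefix_len([1, 2], [1, 2, 3]) == 2        # whole prompt already cached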
36  test/external/sglang_llama/mgsm.yaml (vendored, new file)
@@ -0,0 +1,36 @@
# https://github.com/sgl-project/sglang/blob/main/python/sglang/test/simple_eval_mgsm.py#L41
task: mgsm_en_cot_sglang
dataset_path: juletxara/mgsm
dataset_name: en
output_type: generate_until
training_split: train
test_split: test
doc_to_target: '{{answer[21:] if answer is not none else answer_number|string}}'
doc_to_text: >-
  {{'Solve this math problem. Give the reasoning steps before giving the final answer on the last line by itself in the format of "Answer:".
  Do not add anything other than the integer answer after "Answer:".\n\n'
  +(question[10:] if answer is not none else question)}}
generation_kwargs:
  do_sample: false
  temperature: 0.0
  until: []
metric_list:
  - metric: exact_match
    aggregation: mean
    higher_is_better: true
    ignore_case: true
    ignore_punctuation: true
filter_list:
  - name: "strict-match"
    filter:
      - function: "regex"
        regex_pattern: 'Answer:\s*([\-]?[0-9\.\,]+)'
      - function: "take_first"
  - name: flexible-extract
    filter:
      - function: regex
        group_select: -1
        regex_pattern: (-?[$0-9.,]{2,})|(-?[0-9]+)
      - function: take_first
metadata:
  version: 3.0
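The strict-match filter pulls the final number out of the model's completion with a plain regex, and take_first then keeps the first extracted match. A quick check of that pattern with Python's re module (the sample completion is illustrative, not part of the harness):

    import re
    strict = re.compile(r'Answer:\s*([\-]?[0-9\.\,]+)')
    m = strict.search("Step 1: 7 + 5 = 12.\nAnswer: 12")
    print(m.group(1))  # -> '12', the value compared against doc_to_target via exact_match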