mirror of
https://github.com/nod-ai/SHARK-Studio.git
synced 2026-01-09 13:57:54 -05:00
HF-Reference LLM mode + Update test result to match latest Turbine. (#2080)
* HF-Reference LLM mode. * Fixup test to match current output from Turbine. * lint * Fix test error message + Only initialize HF torch model when used. * Remove redundant format_out change.
This commit is contained in:
@@ -9,7 +9,7 @@ from itertools import chain
|
||||
import gc
|
||||
import os
|
||||
import torch
|
||||
from transformers import AutoTokenizer
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
|
||||
llm_model_map = {
|
||||
"llama2_7b": {
|
||||
@@ -109,6 +109,7 @@ class LanguageModel:
|
||||
self.global_iter = 0
|
||||
self.prev_token_len = 0
|
||||
self.first_input = True
|
||||
self.hf_auth_token = hf_auth_token
|
||||
if self.external_weight_file is not None:
|
||||
if not os.path.exists(self.external_weight_file):
|
||||
print(
|
||||
@@ -164,6 +165,8 @@ class LanguageModel:
|
||||
use_auth_token=hf_auth_token,
|
||||
)
|
||||
self.compile()
|
||||
# Reserved for running HF torch model as reference.
|
||||
self.hf_mod = None
|
||||
|
||||
def compile(self) -> None:
|
||||
# this comes with keys: "vmfb", "config", and "temp_file_to_unlink".
|
||||
@@ -267,6 +270,50 @@ class LanguageModel:
|
||||
self.global_iter += 1
|
||||
return result_output, total_time
|
||||
|
||||
# Reference HF model function for sanity checks.
|
||||
def chat_hf(self, prompt):
|
||||
if self.hf_mod is None:
|
||||
self.hf_mod = AutoModelForCausalLM.from_pretrained(
|
||||
self.hf_model_name,
|
||||
torch_dtype=torch.float,
|
||||
token=self.hf_auth_token,
|
||||
)
|
||||
prompt = self.sanitize_prompt(prompt)
|
||||
|
||||
input_tensor = self.tokenizer(prompt, return_tensors="pt").input_ids
|
||||
history = []
|
||||
for iter in range(self.max_tokens):
|
||||
token_len = input_tensor.shape[-1]
|
||||
if self.first_input:
|
||||
st_time = time.time()
|
||||
result = self.hf_mod(input_tensor)
|
||||
token = torch.argmax(result.logits[:, -1, :], dim=1)
|
||||
total_time = time.time() - st_time
|
||||
token_len += 1
|
||||
pkv = result.past_key_values
|
||||
self.first_input = False
|
||||
|
||||
history.append(int(token))
|
||||
while token != llm_model_map["llama2_7b"]["stop_token"]:
|
||||
dec_time = time.time()
|
||||
result = self.hf_mod(token.reshape([1, 1]), past_key_values=pkv)
|
||||
history.append(int(token))
|
||||
total_time = time.time() - dec_time
|
||||
token = torch.argmax(result.logits[:, -1, :], dim=1)
|
||||
pkv = result.past_key_values
|
||||
yield self.tokenizer.decode(history), total_time
|
||||
|
||||
self.prev_token_len = token_len + len(history)
|
||||
|
||||
if token == llm_model_map["llama2_7b"]["stop_token"]:
|
||||
break
|
||||
for i in range(len(history)):
|
||||
if type(history[i]) != int:
|
||||
history[i] = int(history[i])
|
||||
result_output = self.tokenizer.decode(history)
|
||||
self.global_iter += 1
|
||||
return result_output, total_time
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
lm = LanguageModel(
|
||||
|
||||
@@ -20,14 +20,15 @@ class LLMAPITest(unittest.TestCase):
|
||||
quantization="None",
|
||||
)
|
||||
count = 0
|
||||
label = "Turkishoure Turkish"
|
||||
for msg, _ in lm.chat("hi, what are you?"):
|
||||
# skip first token output
|
||||
if count == 0:
|
||||
count += 1
|
||||
continue
|
||||
assert (
|
||||
msg.strip(" ") == "Turkish Turkish Turkish"
|
||||
), f"LLM API failed to return correct response, expected 'Turkish Turkish Turkish', received {msg}"
|
||||
msg.strip(" ") == label
|
||||
), f"LLM API failed to return correct response, expected '{label}', received {msg}"
|
||||
break
|
||||
del lm
|
||||
gc.collect()
|
||||
|
||||
Reference in New Issue
Block a user