HF-Reference LLM mode + Update test result to match latest Turbine. (#2080)

* HF-Reference LLM mode.

* Fixup test to match current output from Turbine.

* lint

* Fix test error message + Only initialize HF torch model when used.

* Remove redundant format_out change.
This commit is contained in:
Stanley Winata
2024-02-01 09:46:22 -08:00
committed by GitHub
parent 05b498267e
commit 6bf51f1f1d
2 changed files with 51 additions and 3 deletions

View File

@@ -9,7 +9,7 @@ from itertools import chain
import gc
import os
import torch
from transformers import AutoTokenizer
from transformers import AutoTokenizer, AutoModelForCausalLM
llm_model_map = {
"llama2_7b": {
@@ -109,6 +109,7 @@ class LanguageModel:
self.global_iter = 0
self.prev_token_len = 0
self.first_input = True
self.hf_auth_token = hf_auth_token
if self.external_weight_file is not None:
if not os.path.exists(self.external_weight_file):
print(
@@ -164,6 +165,8 @@ class LanguageModel:
use_auth_token=hf_auth_token,
)
self.compile()
# Reserved for running HF torch model as reference.
self.hf_mod = None
def compile(self) -> None:
# this comes with keys: "vmfb", "config", and "temp_file_to_unlink".
@@ -267,6 +270,50 @@ class LanguageModel:
self.global_iter += 1
return result_output, total_time
# Reference HF model function for sanity checks.
def chat_hf(self, prompt):
if self.hf_mod is None:
self.hf_mod = AutoModelForCausalLM.from_pretrained(
self.hf_model_name,
torch_dtype=torch.float,
token=self.hf_auth_token,
)
prompt = self.sanitize_prompt(prompt)
input_tensor = self.tokenizer(prompt, return_tensors="pt").input_ids
history = []
for iter in range(self.max_tokens):
token_len = input_tensor.shape[-1]
if self.first_input:
st_time = time.time()
result = self.hf_mod(input_tensor)
token = torch.argmax(result.logits[:, -1, :], dim=1)
total_time = time.time() - st_time
token_len += 1
pkv = result.past_key_values
self.first_input = False
history.append(int(token))
while token != llm_model_map["llama2_7b"]["stop_token"]:
dec_time = time.time()
result = self.hf_mod(token.reshape([1, 1]), past_key_values=pkv)
history.append(int(token))
total_time = time.time() - dec_time
token = torch.argmax(result.logits[:, -1, :], dim=1)
pkv = result.past_key_values
yield self.tokenizer.decode(history), total_time
self.prev_token_len = token_len + len(history)
if token == llm_model_map["llama2_7b"]["stop_token"]:
break
for i in range(len(history)):
if type(history[i]) != int:
history[i] = int(history[i])
result_output = self.tokenizer.decode(history)
self.global_iter += 1
return result_output, total_time
if __name__ == "__main__":
lm = LanguageModel(

View File

@@ -20,14 +20,15 @@ class LLMAPITest(unittest.TestCase):
quantization="None",
)
count = 0
label = "Turkishoure Turkish"
for msg, _ in lm.chat("hi, what are you?"):
# skip first token output
if count == 0:
count += 1
continue
assert (
msg.strip(" ") == "Turkish Turkish Turkish"
), f"LLM API failed to return correct response, expected 'Turkish Turkish Turkish', received {msg}"
msg.strip(" ") == label
), f"LLM API failed to return correct response, expected '{label}', received {msg}"
break
del lm
gc.collect()