Fix Llama2 on CPU (#2133)

This commit is contained in:
gpetters-amd
2024-04-29 13:18:16 -04:00
committed by GitHub
parent e003d0abe8
commit 81d6e059ac
3 changed files with 17 additions and 6 deletions

View File

@@ -13,7 +13,7 @@ import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
llm_model_map = {
"llama2_7b": {
"meta-llama/Llama-2-7b-chat-hf": {
"initializer": stateless_llama.export_transformer_model,
"hf_model_name": "meta-llama/Llama-2-7b-chat-hf",
"compile_flags": ["--iree-opt-const-expr-hoisting=False"],
@@ -258,7 +258,8 @@ class LanguageModel:
history.append(format_out(token))
while (
format_out(token) != llm_model_map["llama2_7b"]["stop_token"]
format_out(token)
!= llm_model_map["meta-llama/Llama-2-7b-chat-hf"]["stop_token"]
and len(history) < self.max_tokens
):
dec_time = time.time()
@@ -272,7 +273,10 @@ class LanguageModel:
self.prev_token_len = token_len + len(history)
if format_out(token) == llm_model_map["llama2_7b"]["stop_token"]:
if (
format_out(token)
== llm_model_map["meta-llama/Llama-2-7b-chat-hf"]["stop_token"]
):
break
for i in range(len(history)):
@@ -306,7 +310,7 @@ class LanguageModel:
self.first_input = False
history.append(int(token))
while token != llm_model_map["llama2_7b"]["stop_token"]:
while token != llm_model_map["meta-llama/Llama-2-7b-chat-hf"]["stop_token"]:
dec_time = time.time()
result = self.hf_mod(token.reshape([1, 1]), past_key_values=pkv)
history.append(int(token))
@@ -317,7 +321,7 @@ class LanguageModel:
self.prev_token_len = token_len + len(history)
if token == llm_model_map["llama2_7b"]["stop_token"]:
if token == llm_model_map["meta-llama/Llama-2-7b-chat-hf"]["stop_token"]:
break
for i in range(len(history)):
if type(history[i]) != int:
@@ -347,7 +351,11 @@ def llm_chat_api(InputData: dict):
else:
print(f"prompt : {InputData['prompt']}")
model_name = InputData["model"] if "model" in InputData.keys() else "llama2_7b"
model_name = (
InputData["model"]
if "model" in InputData.keys()
else "meta-llama/Llama-2-7b-chat-hf"
)
model_path = llm_model_map[model_name]
device = InputData["device"] if "device" in InputData.keys() else "cpu"
precision = "fp16"

View File

@@ -9,6 +9,7 @@ from apps.shark_studio.api.llm import (
llm_model_map,
LanguageModel,
)
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
import apps.shark_studio.web.utils.globals as global_obj
B_SYS, E_SYS = "<s>", "</s>"
@@ -64,6 +65,7 @@ def chat_fn(
external_weights="safetensors",
use_system_prompt=prompt_prefix,
streaming_llm=streaming_llm,
hf_auth_token=cmd_opts.hf_auth_token,
)
history[-1][-1] = "Getting the model ready... Done"
yield history, ""

View File

@@ -35,6 +35,7 @@ safetensors==0.3.1
py-cpuinfo
pydantic==2.4.1 # pin until pyinstaller-hooks-contrib works with beta versions
mpmath==1.3.0
optimum
# Keep PyInstaller at the end. Sometimes Windows Defender flags it but most folks can continue even if it errors
pefile