mirror of https://github.com/nod-ai/SHARK-Studio.git
Fix Llama2 on CPU (#2133)
@@ -13,7 +13,7 @@ import torch
 from transformers import AutoTokenizer, AutoModelForCausalLM
 
 llm_model_map = {
-    "llama2_7b": {
+    "meta-llama/Llama-2-7b-chat-hf": {
         "initializer": stateless_llama.export_transformer_model,
         "hf_model_name": "meta-llama/Llama-2-7b-chat-hf",
         "compile_flags": ["--iree-opt-const-expr-hoisting=False"],
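Note: with this rename, entries in llm_model_map are keyed by the full Hugging Face model ID rather than the short "llama2_7b" alias. A minimal lookup sketch, using only the fields visible in this diff:

from apps.shark_studio.api.llm import llm_model_map

entry = llm_model_map["meta-llama/Llama-2-7b-chat-hf"]
print(entry["hf_model_name"])    # "meta-llama/Llama-2-7b-chat-hf"
print(entry["compile_flags"])    # ["--iree-opt-const-expr-hoisting=False"]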
@@ -258,7 +258,8 @@ class LanguageModel:
 
         history.append(format_out(token))
         while (
-            format_out(token) != llm_model_map["llama2_7b"]["stop_token"]
+            format_out(token)
+            != llm_model_map["meta-llama/Llama-2-7b-chat-hf"]["stop_token"]
             and len(history) < self.max_tokens
         ):
             dec_time = time.time()
@@ -272,7 +273,10 @@ class LanguageModel:
 
             self.prev_token_len = token_len + len(history)
 
-            if format_out(token) == llm_model_map["llama2_7b"]["stop_token"]:
+            if (
+                format_out(token)
+                == llm_model_map["meta-llama/Llama-2-7b-chat-hf"]["stop_token"]
+            ):
                 break
 
         for i in range(len(history)):
@@ -306,7 +310,7 @@ class LanguageModel:
             self.first_input = False
 
         history.append(int(token))
-        while token != llm_model_map["llama2_7b"]["stop_token"]:
+        while token != llm_model_map["meta-llama/Llama-2-7b-chat-hf"]["stop_token"]:
            dec_time = time.time()
            result = self.hf_mod(token.reshape([1, 1]), past_key_values=pkv)
            history.append(int(token))
@@ -317,7 +321,7 @@ class LanguageModel:
 
            self.prev_token_len = token_len + len(history)
 
-           if token == llm_model_map["llama2_7b"]["stop_token"]:
+           if token == llm_model_map["meta-llama/Llama-2-7b-chat-hf"]["stop_token"]:
                break
        for i in range(len(history)):
            if type(history[i]) != int:
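Note: the two hunks above patch the eager Hugging Face (CPU) decode loop, which feeds one token at a time back into the model together with the cached past_key_values. A standalone sketch of that pattern, with an illustrative prompt, token budget, and eos check rather than SHARK-Studio's exact code:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-2-7b-chat-hf"  # gated repo; requires HF access
tok = AutoTokenizer.from_pretrained(model_id)
mod = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float32)

inputs = tok("Hello", return_tensors="pt")
out = mod(inputs.input_ids)                  # prefill the prompt
token = out.logits[:, -1, :].argmax(dim=-1)  # greedy pick of the next token
pkv = out.past_key_values
history = [int(token)]

while int(token) != tok.eos_token_id and len(history) < 64:
    out = mod(token.reshape([1, 1]), past_key_values=pkv)  # reuse the KV cache
    token = out.logits[:, -1, :].argmax(dim=-1)
    pkv = out.past_key_values
    history.append(int(token))

print(tok.decode(history))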
@@ -347,7 +351,11 @@ def llm_chat_api(InputData: dict):
     else:
         print(f"prompt : {InputData['prompt']}")
 
-    model_name = InputData["model"] if "model" in InputData.keys() else "llama2_7b"
+    model_name = (
+        InputData["model"]
+        if "model" in InputData.keys()
+        else "meta-llama/Llama-2-7b-chat-hf"
+    )
     model_path = llm_model_map[model_name]
     device = InputData["device"] if "device" in InputData.keys() else "cpu"
     precision = "fp16"
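Note: with the default changed, a request that omits "model" now resolves to the full model ID, which is also a valid key into the renamed llm_model_map. Hypothetical payloads for illustration:

from apps.shark_studio.api.llm import llm_model_map

# One request without a "model" field, one with it set explicitly.
for InputData in (
    {"prompt": "Hello"},
    {"prompt": "Hello", "model": "meta-llama/Llama-2-7b-chat-hf"},
):
    model_name = (
        InputData["model"]
        if "model" in InputData.keys()
        else "meta-llama/Llama-2-7b-chat-hf"
    )
    entry = llm_model_map[model_name]  # works now that the map uses HF IDs as keys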
@@ -9,6 +9,7 @@ from apps.shark_studio.api.llm import (
     llm_model_map,
     LanguageModel,
 )
+from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
 import apps.shark_studio.web.utils.globals as global_obj
 
 B_SYS, E_SYS = "<s>", "</s>"
@@ -64,6 +65,7 @@ def chat_fn(
             external_weights="safetensors",
             use_system_prompt=prompt_prefix,
             streaming_llm=streaming_llm,
+            hf_auth_token=cmd_opts.hf_auth_token,
         )
         history[-1][-1] = "Getting the model ready... Done"
         yield history, ""
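Note: passing hf_auth_token=cmd_opts.hf_auth_token matters because meta-llama/Llama-2-7b-chat-hf is a gated repository. A rough sketch of forwarding such a token to the Hugging Face loaders; the helper name and the token= keyword are assumptions for illustration, not SHARK-Studio's exact plumbing:

from transformers import AutoModelForCausalLM, AutoTokenizer

def load_gated_model(model_id, hf_auth_token=None):
    # Gated repos reject anonymous downloads, so the token has to reach from_pretrained().
    kwargs = {"token": hf_auth_token} if hf_auth_token else {}
    tok = AutoTokenizer.from_pretrained(model_id, **kwargs)
    mod = AutoModelForCausalLM.from_pretrained(model_id, **kwargs)
    return tok, mod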
@@ -35,6 +35,7 @@ safetensors==0.3.1
 py-cpuinfo
 pydantic==2.4.1 # pin until pyinstaller-hooks-contrib works with beta versions
 mpmath==1.3.0
+optimum
 
 # Keep PyInstaller at the end. Sometimes Windows Defender flags it but most folks can continue even if it errors
 pefile