mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-03 03:00:17 -04:00
Add support for the Llama-2 model
This commit is contained in:
@@ -151,21 +151,32 @@ class ShardedVicuna(SharkLLMBase):
|
||||
self.shark_model = self.compile(device=device)
|
||||
|
||||
def get_tokenizer(self):
|
||||
kwargs = {}
|
||||
if self.model_name == "llama2":
|
||||
kwargs = {
|
||||
"use_auth_token": "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"
|
||||
}
|
||||
if self.model_name == "codegen":
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
self.hf_model_path, trust_remote_code=True,
|
||||
self.hf_model_path,
|
||||
trust_remote_code=True,
|
||||
)
|
||||
else:
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
self.hf_model_path, use_fast=False
|
||||
self.hf_model_path,
|
||||
use_fast=False,
|
||||
**kwargs,
|
||||
)
|
||||
return tokenizer
|
||||
|
||||
def get_src_model(self):
|
||||
# Retrieve the torch model from Huggingface
|
||||
kwargs = {"torch_dtype": torch.float}
|
||||
if self.model_name == "llama2":
|
||||
kwargs["use_auth_token"] = "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"
|
||||
vicuna_model = AutoModelForCausalLM.from_pretrained(
|
||||
self.hf_model_path, **kwargs
|
||||
self.hf_model_path,
|
||||
**kwargs,
|
||||
)
|
||||
return vicuna_model
|
||||
|
||||
@@ -879,6 +890,8 @@ class UnshardedVicuna(SharkLLMBase):
|
||||
download_vmfb=False,
|
||||
) -> None:
|
||||
super().__init__(model_name, hf_model_path, max_num_tokens)
|
||||
if self.model_name == "llama2":
|
||||
self.hf_model_path = "meta-llama/Llama-2-7b-chat-hf"
|
||||
print(f"[DEBUG] hf model name: {self.hf_model_path}")
|
||||
self.max_sequence_length = 256
|
||||
self.device = device
|
||||
@@ -909,26 +922,39 @@ class UnshardedVicuna(SharkLLMBase):
|
||||
def get_model_path(self, model_number="first", suffix="mlir"):
|
||||
safe_device = self.device.split("-")[0]
|
||||
if suffix == "mlir":
|
||||
return Path(f"{model_number}_{self.model_name}_{self.precision}.{suffix}")
|
||||
return Path(
|
||||
f"{model_number}_{self.model_name}_{self.precision}.{suffix}"
|
||||
)
|
||||
return Path(
|
||||
f"{model_number}_{self.model_name}_{self.precision}_{safe_device}.{suffix}"
|
||||
)
|
||||
|
||||
def get_tokenizer(self):
|
||||
kwargs = {}
|
||||
if self.model_name == "llama2":
|
||||
kwargs = {
|
||||
"use_auth_token": "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"
|
||||
}
|
||||
if self.model_name == "codegen":
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
self.hf_model_path, trust_remote_code=True,
|
||||
self.hf_model_path,
|
||||
trust_remote_code=True,
|
||||
)
|
||||
else:
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
self.hf_model_path, use_fast=False
|
||||
self.hf_model_path,
|
||||
use_fast=False,
|
||||
**kwargs,
|
||||
)
|
||||
return tokenizer
|
||||
|
||||
def get_src_model(self):
|
||||
kwargs = {"torch_dtype": torch.float}
|
||||
if self.model_name == "llama2":
|
||||
kwargs["use_auth_token"] = "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"
|
||||
vicuna_model = AutoModelForCausalLM.from_pretrained(
|
||||
self.hf_model_path, **kwargs
|
||||
self.hf_model_path,
|
||||
**kwargs,
|
||||
)
|
||||
return vicuna_model
|
||||
|
||||
@@ -962,9 +988,9 @@ class UnshardedVicuna(SharkLLMBase):
|
||||
bytecode = f.read()
|
||||
mlir_generated = True
|
||||
else:
|
||||
raise ValueError(
|
||||
print(
|
||||
f"MLIR not found at {self.first_vicuna_mlir_path.absolute()}"
|
||||
" after downloading! Please check path and try again"
|
||||
" after downloading! Generating mlir on device."
|
||||
)
|
||||
else:
|
||||
print(
|
||||
@@ -987,14 +1013,18 @@ class UnshardedVicuna(SharkLLMBase):
|
||||
).reshape([1, 19])
|
||||
firstVicunaCompileInput = (compilation_input_ids,)
|
||||
model = FirstVicuna(
|
||||
self.hf_model_path, self.precision, self.weight_group_size
|
||||
self.hf_model_path,
|
||||
self.precision,
|
||||
self.weight_group_size,
|
||||
self.model_name,
|
||||
)
|
||||
|
||||
print(f"[DEBUG] generating torchscript graph")
|
||||
ts_graph = import_with_fx(
|
||||
model,
|
||||
firstVicunaCompileInput,
|
||||
is_f16=self.precision == "fp16", # TODO: Remove from import_with_fx args and fix all calls
|
||||
is_f16=self.precision
|
||||
== "fp16", # TODO: Remove from import_with_fx args and fix all calls
|
||||
precision=self.precision,
|
||||
f16_input_mask=[False, False],
|
||||
mlir_type="torchscript",
|
||||
@@ -1130,9 +1160,9 @@ class UnshardedVicuna(SharkLLMBase):
|
||||
bytecode = f.read()
|
||||
mlir_generated = True
|
||||
else:
|
||||
raise ValueError(
|
||||
f"MLIR not found at {self.second_vicuna_mlir_path.absolute()}"
|
||||
" after downloading! Please check path and try again"
|
||||
print(
|
||||
f"MLIR not found at {self.first_vicuna_mlir_path.absolute()}"
|
||||
" after downloading! Generating mlir on device."
|
||||
)
|
||||
else:
|
||||
print(
|
||||
@@ -1147,7 +1177,10 @@ class UnshardedVicuna(SharkLLMBase):
|
||||
)
|
||||
secondVicunaCompileInput = (compilation_input_ids,) + pkv
|
||||
model = SecondVicuna(
|
||||
self.hf_model_path, self.precision, self.weight_group_size
|
||||
self.hf_model_path,
|
||||
self.precision,
|
||||
self.weight_group_size,
|
||||
self.model_name,
|
||||
)
|
||||
|
||||
print(f"[DEBUG] generating torchscript graph")
|
||||
@@ -1305,7 +1338,7 @@ class UnshardedVicuna(SharkLLMBase):
|
||||
and self.precision == "int8" and self.download_vmfb
|
||||
):
|
||||
download_public_file(
|
||||
f"gs://shark_tank/vicuna/unsharded/vmfb/{self.first_vicuna_vmfb_path.name}",
|
||||
f"gs://shark_tank/{self.model_name}/unsharded/vmfb/{self.first_vicuna_vmfb_path.name}",
|
||||
self.first_vicuna_vmfb_path.absolute(),
|
||||
single_file=True,
|
||||
)
|
||||
@@ -1327,7 +1360,7 @@ class UnshardedVicuna(SharkLLMBase):
|
||||
and self.precision == "int8" and self.download_vmfb
|
||||
):
|
||||
download_public_file(
|
||||
f"gs://shark_tank/vicuna/unsharded/vmfb/{self.second_vicuna_vmfb_path.name}",
|
||||
f"gs://shark_tank/{self.model_name}/unsharded/vmfb/{self.second_vicuna_vmfb_path.name}",
|
||||
self.second_vicuna_vmfb_path.absolute(),
|
||||
single_file=True,
|
||||
)
|
||||
@@ -1349,7 +1382,9 @@ class UnshardedVicuna(SharkLLMBase):
|
||||
res_tokens[i] = int(res_tokens[i][0])
|
||||
|
||||
skip_sp_tok = True if self.model_name == "codegen" else False
|
||||
res_str = self.tokenizer.decode(res_tokens, skip_special_tokens=skip_sp_tok)
|
||||
res_str = self.tokenizer.decode(
|
||||
res_tokens, skip_special_tokens=skip_sp_tok
|
||||
)
|
||||
return res_str
|
||||
|
||||
def generate(self, prompt, cli=True):
|
||||
@@ -1421,7 +1456,7 @@ class UnshardedVicuna(SharkLLMBase):
|
||||
yield part_str
|
||||
|
||||
if self.device == "cuda":
|
||||
del sec_vic, pkv, logits
|
||||
del params["sv"], pkv, logits
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
|
||||
|
||||
@@ -6,9 +6,17 @@ from brevitas_examples.llm.llm_quant.run_utils import get_model_impl
|
||||
|
||||
|
||||
class FirstVicuna(torch.nn.Module):
|
||||
def __init__(self, model_path, precision="fp32", weight_group_size=128):
|
||||
def __init__(
|
||||
self,
|
||||
model_path,
|
||||
precision="fp32",
|
||||
weight_group_size=128,
|
||||
model_name="vicuna",
|
||||
):
|
||||
super().__init__()
|
||||
kwargs = {"torch_dtype": torch.float32}
|
||||
if model_name == "llama2":
|
||||
kwargs["use_auth_token"] = "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"
|
||||
self.model = AutoModelForCausalLM.from_pretrained(
|
||||
model_path, low_cpu_mem_usage=True, **kwargs
|
||||
)
|
||||
@@ -47,9 +55,17 @@ class FirstVicuna(torch.nn.Module):
|
||||
|
||||
|
||||
class SecondVicuna(torch.nn.Module):
|
||||
def __init__(self, model_path, precision="fp32", weight_group_size=128):
|
||||
def __init__(
|
||||
self,
|
||||
model_path,
|
||||
precision="fp32",
|
||||
weight_group_size=128,
|
||||
model_name="vicuna",
|
||||
):
|
||||
super().__init__()
|
||||
kwargs = {"torch_dtype": torch.float32}
|
||||
if model_name == "llama2":
|
||||
kwargs["use_auth_token"] = "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"
|
||||
self.model = AutoModelForCausalLM.from_pretrained(
|
||||
model_path, low_cpu_mem_usage=True, **kwargs
|
||||
)
|
||||
|
||||
@@ -21,6 +21,7 @@ vicuna_model = 0
|
||||
past_key_values = None
|
||||
|
||||
model_map = {
|
||||
"llama2": "meta-llama/Llama-2-7b-chat-hf",
|
||||
"codegen": "Salesforce/codegen25-7b-multi",
|
||||
"vicuna1p3": "lmsys/vicuna-7b-v1.3",
|
||||
"vicuna": "TheBloke/vicuna-7B-1.1-HF",
|
||||
@@ -29,6 +30,15 @@ model_map = {
|
||||
|
||||
# NOTE: Each `model_name` should have its own start message
|
||||
start_message = {
|
||||
"llama2": (
|
||||
"System: You are a helpful, respectful and honest assistant. Always answer "
|
||||
"as helpfully as possible, while being safe. Your answers should not "
|
||||
"include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal "
|
||||
"content. Please ensure that your responses are socially unbiased and positive "
|
||||
"in nature. If a question does not make any sense, or is not factually coherent, "
|
||||
"explain why instead of answering something not correct. If you don't know the "
|
||||
"answer to a question, please don't share false information."
|
||||
),
|
||||
"StableLM": (
|
||||
"<|SYSTEM|># StableLM Tuned (Alpha version)"
|
||||
"\n- StableLM is a helpful and harmless open-source AI language model "
|
||||
@@ -57,7 +67,7 @@ start_message = {
|
||||
def create_prompt(model_name, history):
|
||||
system_message = start_message[model_name]
|
||||
|
||||
if model_name in ["StableLM", "vicuna", "vicuna1p3"]:
|
||||
if model_name in ["StableLM", "vicuna", "vicuna1p3", "llama2"]:
|
||||
conversation = "".join(
|
||||
[
|
||||
"".join(["<|USER|>" + item[0], "<|ASSISTANT|>" + item[1]])
|
||||
@@ -86,7 +96,7 @@ def chat(curr_system_message, history, model, device, precision, cli=True):
|
||||
global vicuna_model
|
||||
model_name, model_path = list(map(str.strip, model.split("=>")))
|
||||
|
||||
if model_name in ["vicuna", "vicuna1p3", "codegen"]:
|
||||
if model_name in ["vicuna", "vicuna1p3", "codegen", "llama2"]:
|
||||
from apps.language_models.scripts.vicuna import (
|
||||
UnshardedVicuna,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user