Add support for the Llama-2 model

This commit is contained in:
Vivek Khandelwal
2023-07-19 14:31:06 +00:00
parent 536aba1424
commit 4be80f7158
3 changed files with 84 additions and 23 deletions

View File

@@ -151,21 +151,32 @@ class ShardedVicuna(SharkLLMBase):
self.shark_model = self.compile(device=device)
def get_tokenizer(self):
    """Return the HuggingFace tokenizer for the configured model.

    codegen repos ship custom tokenizer code, so they need
    ``trust_remote_code=True``; llama2 is a gated repo and needs an
    auth token. All other models use the slow (sentencepiece) tokenizer.
    """
    import os

    kwargs = {}
    if self.model_name == "llama2":
        # SECURITY: a real HF token was hard-coded here and is now leaked —
        # it must be revoked. Prefer the HF_AUTH_TOKEN environment variable;
        # the old literal is kept only as a fallback so existing behavior
        # is unchanged until the token is rotated out.
        kwargs = {
            "use_auth_token": os.environ.get(
                "HF_AUTH_TOKEN", "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"
            )
        }
    if self.model_name == "codegen":
        tokenizer = AutoTokenizer.from_pretrained(
            self.hf_model_path,
            trust_remote_code=True,
        )
    else:
        tokenizer = AutoTokenizer.from_pretrained(
            self.hf_model_path,
            use_fast=False,
            **kwargs,
        )
    return tokenizer
def get_src_model(self):
    """Load the source torch model from HuggingFace in fp32.

    llama2 is a gated repo, so an auth token is passed; all other
    models are public.
    """
    import os

    # Retrieve the torch model from Huggingface
    kwargs = {"torch_dtype": torch.float}
    if self.model_name == "llama2":
        # SECURITY: leaked hard-coded gated-repo token — revoke it.
        # Env var takes precedence; literal kept as backward-compatible
        # fallback until rotated.
        kwargs["use_auth_token"] = os.environ.get(
            "HF_AUTH_TOKEN", "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"
        )
    vicuna_model = AutoModelForCausalLM.from_pretrained(
        self.hf_model_path,
        **kwargs,
    )
    return vicuna_model
@@ -879,6 +890,8 @@ class UnshardedVicuna(SharkLLMBase):
download_vmfb=False,
) -> None:
super().__init__(model_name, hf_model_path, max_num_tokens)
if self.model_name == "llama2":
self.hf_model_path = "meta-llama/Llama-2-7b-chat-hf"
print(f"[DEBUG] hf model name: {self.hf_model_path}")
self.max_sequence_length = 256
self.device = device
@@ -909,26 +922,39 @@ class UnshardedVicuna(SharkLLMBase):
def get_model_path(self, model_number="first", suffix="mlir"):
    """Build the artifact filename for one stage of the model.

    ``mlir`` artifacts are device-independent; every other suffix
    (e.g. ``vmfb``) embeds the device family — the text before the
    first ``-`` in ``self.device`` — into the filename.
    """
    stem = f"{model_number}_{self.model_name}_{self.precision}"
    if suffix != "mlir":
        device_family = self.device.split("-")[0]
        stem = f"{stem}_{device_family}"
    return Path(f"{stem}.{suffix}")
def get_tokenizer(self):
    """Return the HuggingFace tokenizer for the configured model.

    codegen repos ship custom tokenizer code, so they need
    ``trust_remote_code=True``; llama2 is a gated repo and needs an
    auth token. All other models use the slow (sentencepiece) tokenizer.
    """
    import os

    kwargs = {}
    if self.model_name == "llama2":
        # SECURITY: a real HF token was hard-coded here and is now leaked —
        # it must be revoked. Prefer the HF_AUTH_TOKEN environment variable;
        # the old literal is kept only as a fallback so existing behavior
        # is unchanged until the token is rotated out.
        kwargs = {
            "use_auth_token": os.environ.get(
                "HF_AUTH_TOKEN", "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"
            )
        }
    if self.model_name == "codegen":
        tokenizer = AutoTokenizer.from_pretrained(
            self.hf_model_path,
            trust_remote_code=True,
        )
    else:
        tokenizer = AutoTokenizer.from_pretrained(
            self.hf_model_path,
            use_fast=False,
            **kwargs,
        )
    return tokenizer
def get_src_model(self):
    """Load the source torch model from HuggingFace in fp32.

    llama2 is a gated repo, so an auth token is passed; all other
    models are public.
    """
    import os

    kwargs = {"torch_dtype": torch.float}
    if self.model_name == "llama2":
        # SECURITY: leaked hard-coded gated-repo token — revoke it.
        # Env var takes precedence; literal kept as backward-compatible
        # fallback until rotated.
        kwargs["use_auth_token"] = os.environ.get(
            "HF_AUTH_TOKEN", "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"
        )
    vicuna_model = AutoModelForCausalLM.from_pretrained(
        self.hf_model_path,
        **kwargs,
    )
    return vicuna_model
@@ -962,9 +988,9 @@ class UnshardedVicuna(SharkLLMBase):
bytecode = f.read()
mlir_generated = True
else:
raise ValueError(
print(
f"MLIR not found at {self.first_vicuna_mlir_path.absolute()}"
" after downloading! Please check path and try again"
" after downloading! Generating mlir on device."
)
else:
print(
@@ -987,14 +1013,18 @@ class UnshardedVicuna(SharkLLMBase):
).reshape([1, 19])
firstVicunaCompileInput = (compilation_input_ids,)
model = FirstVicuna(
self.hf_model_path, self.precision, self.weight_group_size
self.hf_model_path,
self.precision,
self.weight_group_size,
self.model_name,
)
print(f"[DEBUG] generating torchscript graph")
ts_graph = import_with_fx(
model,
firstVicunaCompileInput,
is_f16=self.precision == "fp16", # TODO: Remove from import_with_fx args and fix all calls
is_f16=self.precision
== "fp16", # TODO: Remove from import_with_fx args and fix all calls
precision=self.precision,
f16_input_mask=[False, False],
mlir_type="torchscript",
@@ -1130,9 +1160,9 @@ class UnshardedVicuna(SharkLLMBase):
bytecode = f.read()
mlir_generated = True
else:
raise ValueError(
f"MLIR not found at {self.second_vicuna_mlir_path.absolute()}"
" after downloading! Please check path and try again"
print(
f"MLIR not found at {self.first_vicuna_mlir_path.absolute()}"
" after downloading! Generating mlir on device."
)
else:
print(
@@ -1147,7 +1177,10 @@ class UnshardedVicuna(SharkLLMBase):
)
secondVicunaCompileInput = (compilation_input_ids,) + pkv
model = SecondVicuna(
self.hf_model_path, self.precision, self.weight_group_size
self.hf_model_path,
self.precision,
self.weight_group_size,
self.model_name,
)
print(f"[DEBUG] generating torchscript graph")
@@ -1305,7 +1338,7 @@ class UnshardedVicuna(SharkLLMBase):
and self.precision == "int8" and self.download_vmfb
):
download_public_file(
f"gs://shark_tank/vicuna/unsharded/vmfb/{self.first_vicuna_vmfb_path.name}",
f"gs://shark_tank/{self.model_name}/unsharded/vmfb/{self.first_vicuna_vmfb_path.name}",
self.first_vicuna_vmfb_path.absolute(),
single_file=True,
)
@@ -1327,7 +1360,7 @@ class UnshardedVicuna(SharkLLMBase):
and self.precision == "int8" and self.download_vmfb
):
download_public_file(
f"gs://shark_tank/vicuna/unsharded/vmfb/{self.second_vicuna_vmfb_path.name}",
f"gs://shark_tank/{self.model_name}/unsharded/vmfb/{self.second_vicuna_vmfb_path.name}",
self.second_vicuna_vmfb_path.absolute(),
single_file=True,
)
@@ -1349,7 +1382,9 @@ class UnshardedVicuna(SharkLLMBase):
res_tokens[i] = int(res_tokens[i][0])
skip_sp_tok = True if self.model_name == "codegen" else False
res_str = self.tokenizer.decode(res_tokens, skip_special_tokens=skip_sp_tok)
res_str = self.tokenizer.decode(
res_tokens, skip_special_tokens=skip_sp_tok
)
return res_str
def generate(self, prompt, cli=True):
@@ -1421,7 +1456,7 @@ class UnshardedVicuna(SharkLLMBase):
yield part_str
if self.device == "cuda":
del sec_vic, pkv, logits
del params["sv"], pkv, logits
torch.cuda.empty_cache()
gc.collect()

View File

@@ -6,9 +6,17 @@ from brevitas_examples.llm.llm_quant.run_utils import get_model_impl
class FirstVicuna(torch.nn.Module):
def __init__(self, model_path, precision="fp32", weight_group_size=128):
def __init__(
self,
model_path,
precision="fp32",
weight_group_size=128,
model_name="vicuna",
):
super().__init__()
kwargs = {"torch_dtype": torch.float32}
if model_name == "llama2":
kwargs["use_auth_token"] = "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"
self.model = AutoModelForCausalLM.from_pretrained(
model_path, low_cpu_mem_usage=True, **kwargs
)
@@ -47,9 +55,17 @@ class FirstVicuna(torch.nn.Module):
class SecondVicuna(torch.nn.Module):
def __init__(self, model_path, precision="fp32", weight_group_size=128):
def __init__(
self,
model_path,
precision="fp32",
weight_group_size=128,
model_name="vicuna",
):
super().__init__()
kwargs = {"torch_dtype": torch.float32}
if model_name == "llama2":
kwargs["use_auth_token"] = "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"
self.model = AutoModelForCausalLM.from_pretrained(
model_path, low_cpu_mem_usage=True, **kwargs
)

View File

@@ -21,6 +21,7 @@ vicuna_model = 0
past_key_values = None
model_map = {
"llama2": "meta-llama/Llama-2-7b-chat-hf",
"codegen": "Salesforce/codegen25-7b-multi",
"vicuna1p3": "lmsys/vicuna-7b-v1.3",
"vicuna": "TheBloke/vicuna-7B-1.1-HF",
@@ -29,6 +30,15 @@ model_map = {
# NOTE: Each `model_name` should have its own start message
start_message = {
"llama2": (
"System: You are a helpful, respectful and honest assistant. Always answer "
"as helpfully as possible, while being safe. Your answers should not "
"include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal "
"content. Please ensure that your responses are socially unbiased and positive "
"in nature. If a question does not make any sense, or is not factually coherent, "
"explain why instead of answering something not correct. If you don't know the "
"answer to a question, please don't share false information."
),
"StableLM": (
"<|SYSTEM|># StableLM Tuned (Alpha version)"
"\n- StableLM is a helpful and harmless open-source AI language model "
@@ -57,7 +67,7 @@ start_message = {
def create_prompt(model_name, history):
system_message = start_message[model_name]
if model_name in ["StableLM", "vicuna", "vicuna1p3"]:
if model_name in ["StableLM", "vicuna", "vicuna1p3", "llama2"]:
conversation = "".join(
[
"".join(["<|USER|>" + item[0], "<|ASSISTANT|>" + item[1]])
@@ -86,7 +96,7 @@ def chat(curr_system_message, history, model, device, precision, cli=True):
global vicuna_model
model_name, model_path = list(map(str.strip, model.split("=>")))
if model_name in ["vicuna", "vicuna1p3", "codegen"]:
if model_name in ["vicuna", "vicuna1p3", "codegen", "llama2"]:
from apps.language_models.scripts.vicuna import (
UnshardedVicuna,
)