Add support for Llama-2-70b for the web UI and CLI, and add hf_auth_token handling

Vivek Khandelwal
2023-07-20 08:05:20 +00:00
parent 3662224c04
commit 03c4d9e171
4 changed files with 111 additions and 26 deletions

View File (apps/language_models/scripts/vicuna.py)

@@ -101,7 +101,25 @@ parser.add_argument(
default=128,
help="Group size for per_group weight quantization. Default: 128.",
)
parser.add_argument("--download_vmfb", default=False, action=argparse.BooleanOptionalAction, help="download vmfb from sharktank, system dependent, YMMV")
parser.add_argument(
"--download_vmfb",
default=False,
action=argparse.BooleanOptionalAction,
help="download vmfb from sharktank, system dependent, YMMV",
)
parser.add_argument(
"--model_name",
type=str,
default="vicuna",
choices=["vicuna", "llama2_7b", "llama2_70b"],
help="Specify which model to run.",
)
parser.add_argument(
"--hf_auth_token",
type=str,
default=None,
help="Specify your own huggingface authentication tokens for models like Llama2.",
)
def brevitas〡matmul_rhs_group_quant〡shape(lhs: List[int], rhs: List[int], rhs_scale: List[int], rhs_zero_point: List[int], rhs_bit_width: int, rhs_group_size: int) -> List[int]:
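
For reference, a minimal standalone sketch of how the two new flags behave once parsed; the parser here re-creates only those two options, and the token value is a placeholder:

    import argparse

    # Re-creation of just the two new flags (names and choices taken from the diff).
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", default="vicuna",
                        choices=["vicuna", "llama2_7b", "llama2_70b"])
    parser.add_argument("--hf_auth_token", default=None)

    args = parser.parse_args(["--model_name=llama2_70b", "--hf_auth_token=hf_xxx"])
    assert args.model_name == "llama2_70b" and args.hf_auth_token == "hf_xxx"
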
@@ -870,6 +888,7 @@ class UnshardedVicuna(SharkLLMBase):
self,
model_name,
hf_model_path="TheBloke/vicuna-7B-1.1-HF",
hf_auth_token: str = None,
max_num_tokens=512,
device="cuda",
precision="fp32",
@@ -883,8 +902,15 @@ class UnshardedVicuna(SharkLLMBase):
download_vmfb=False,
) -> None:
super().__init__(model_name, hf_model_path, max_num_tokens)
if self.model_name == "llama2":
if "llama2" in self.model_name and hf_auth_token == None:
raise ValueError(
"HF auth token required. Pass it using --hf_auth_token flag."
)
self.hf_auth_token = hf_auth_token
if self.model_name == "llama2_7b":
self.hf_model_path = "meta-llama/Llama-2-7b-chat-hf"
elif self.model_name == "llama2_70b":
self.hf_model_path = "meta-llama/Llama-2-70b-chat-hf"
print(f"[DEBUG] hf model name: {self.hf_model_path}")
self.max_sequence_length = 256
self.device = device
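
A usage sketch of the updated constructor, assuming the remaining arguments keep their defaults and a valid token for the gated meta-llama checkpoints is available (instantiation kicks off model preparation, so this is illustrative rather than lightweight):

    # Sketch only; "<your_hf_token>" is a placeholder for a real Hugging Face token.
    from apps.language_models.scripts.vicuna import UnshardedVicuna

    vic = UnshardedVicuna(
        model_name="llama2_70b",          # resolves hf_model_path to meta-llama/Llama-2-70b-chat-hf
        hf_auth_token="<your_hf_token>",  # omitting this now raises ValueError for llama2 models
        device="cuda",
        precision="fp16",
    )
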
@@ -923,11 +949,7 @@ class UnshardedVicuna(SharkLLMBase):
)
def get_tokenizer(self):
kwargs = {}
if self.model_name == "llama2":
kwargs = {
"use_auth_token": "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"
}
kwargs = {"use_auth_token": self.hf_auth_token}
if self.model_name == "codegen":
tokenizer = AutoTokenizer.from_pretrained(
self.hf_model_path,
@@ -942,9 +964,10 @@ class UnshardedVicuna(SharkLLMBase):
return tokenizer
def get_src_model(self):
kwargs = {"torch_dtype": torch.float}
if self.model_name == "llama2":
kwargs["use_auth_token"] = "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"
kwargs = {
"torch_dtype": torch.float,
"use_auth_token": self.hf_auth_token,
}
vicuna_model = AutoModelForCausalLM.from_pretrained(
self.hf_model_path,
**kwargs,
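
The same transformers pattern shown standalone: use_auth_token is forwarded to the Hugging Face Hub so the gated meta-llama repositories can be downloaded. The token is a placeholder, and loading the 70B checkpoint additionally needs approved access and substantial memory, so treat this as a sketch:

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    hf_model_path = "meta-llama/Llama-2-70b-chat-hf"
    hf_auth_token = "<your_hf_token>"  # placeholder

    tokenizer = AutoTokenizer.from_pretrained(hf_model_path, use_auth_token=hf_auth_token)
    model = AutoModelForCausalLM.from_pretrained(
        hf_model_path,
        torch_dtype=torch.float,
        use_auth_token=hf_auth_token,
    )
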
@@ -1010,6 +1033,7 @@ class UnshardedVicuna(SharkLLMBase):
self.precision,
self.weight_group_size,
self.model_name,
self.hf_auth_token,
)
print(f"[DEBUG] generating torchscript graph")
@@ -1174,6 +1198,7 @@ class UnshardedVicuna(SharkLLMBase):
self.precision,
self.weight_group_size,
self.model_name,
self.hf_auth_token,
)
print(f"[DEBUG] generating torchscript graph")
@@ -1328,7 +1353,8 @@ class UnshardedVicuna(SharkLLMBase):
):
if (self.device == "cuda" and self.precision == "fp16") or (
self.device in ["cpu-sync", "cpu-task"]
and self.precision == "int8" and self.download_vmfb
and self.precision == "int8"
and self.download_vmfb
):
download_public_file(
f"gs://shark_tank/{self.model_name}/unsharded/vmfb/{self.first_vicuna_vmfb_path.name}",
@@ -1350,7 +1376,8 @@ class UnshardedVicuna(SharkLLMBase):
):
if (self.device == "cuda" and self.precision == "fp16") or (
self.device in ["cpu-sync", "cpu-task"]
and self.precision == "int8" and self.download_vmfb
and self.precision == "int8"
and self.download_vmfb
):
download_public_file(
f"gs://shark_tank/{self.model_name}/unsharded/vmfb/{self.second_vicuna_vmfb_path.name}",
@@ -1571,7 +1598,8 @@ if __name__ == "__main__":
)
vic = UnshardedVicuna(
"vicuna",
model_name=args.model_name,
hf_auth_token=args.hf_auth_token,
device=args.device,
precision=args.precision,
first_vicuna_mlir_path=first_vic_mlir_path,
@@ -1590,21 +1618,45 @@ if __name__ == "__main__":
else:
config_json = None
vic = ShardedVicuna(
"vicuna",
model_name=args.model_name,
device=args.device,
precision=args.precision,
config_json=config_json,
weight_group_size=args.weight_group_size,
)
system_message = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n"
if args.model_name == "vicuna":
system_message = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n"
else:
system_message = """System: You are a helpful, respectful and honest assistant. Always answer "
as helpfully as possible, while being safe. Your answers should not
include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal
content. Please ensure that your responses are socially unbiased and positive
in nature. If a question does not make any sense, or is not factually coherent,
explain why instead of answering something not correct. If you don't know the
answer to a question, please don't share false information."""
prologue_prompt = "ASSISTANT:\n"
from apps.stable_diffusion.web.ui.stablelm_ui import chat, set_vicuna_model
history = []
set_vicuna_model(vic)
model_list = {
"vicuna": "vicuna=>TheBloke/vicuna-7B-1.1-HF",
"llama2_7b": "llama2_7b=>meta-llama/Llama-2-7b-chat-hf",
"llama2_70b": "llama2_70b=>meta-llama/Llama-2-70b-chat-hf",
}
while True:
# TODO: Add break condition from user input
user_prompt = input("User: ")
history.append([user_prompt,""])
history = list(chat(system_message, history, model="vicuna=>TheBloke/vicuna-7B-1.1-HF", device=args.device, precision=args.precision, cli=args.cli))[0]
history.append([user_prompt, ""])
history = list(
chat(
system_message,
history,
model=model_list[args.model_name],
device=args.device,
precision=args.precision,
cli=args.cli,
)
)[0]

View File

@@ -12,11 +12,12 @@ class FirstVicuna(torch.nn.Module):
precision="fp32",
weight_group_size=128,
model_name="vicuna",
hf_auth_token: str = None,
):
super().__init__()
kwargs = {"torch_dtype": torch.float32}
if model_name == "llama2":
kwargs["use_auth_token"] = "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"
if "llama2" in model_name:
kwargs["use_auth_token"] = hf_auth_token
self.model = AutoModelForCausalLM.from_pretrained(
model_path, low_cpu_mem_usage=True, **kwargs
)
@@ -54,11 +55,12 @@ class SecondVicuna(torch.nn.Module):
precision="fp32",
weight_group_size=128,
model_name="vicuna",
hf_auth_token: str = None,
):
super().__init__()
kwargs = {"torch_dtype": torch.float32}
if model_name == "llama2":
kwargs["use_auth_token"] = "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"
if "llama2" in model_name:
kwargs["use_auth_token"] = hf_auth_token
self.model = AutoModelForCausalLM.from_pretrained(
model_path, low_cpu_mem_usage=True, **kwargs
)

View File

@@ -400,6 +400,13 @@ p.add_argument(
help="Load and unload models for low VRAM.",
)
p.add_argument(
"--hf_auth_token",
type=str,
default=None,
help="Specify your own huggingface authentication tokens for models like Llama2.",
)
##############################################################################
# IREE - Vulkan supported flags
##############################################################################
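
The --hf_auth_token flag added above is consumed elsewhere through the shared args module (the same import that appears in stablelm_ui.py later in this commit); a minimal sketch of that consumption, with the message text being illustrative:

    # Sketch: args is populated by the parser above when the app starts.
    from apps.stable_diffusion.src import args

    if args.hf_auth_token is None:
        print("No --hf_auth_token given; gated Llama 2 checkpoints cannot be fetched.")
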

View File (apps/stable_diffusion/web/ui/stablelm_ui.py)

@@ -21,7 +21,8 @@ vicuna_model = 0
past_key_values = None
model_map = {
"llama2": "meta-llama/Llama-2-7b-chat-hf",
"llama2_7b": "meta-llama/Llama-2-7b-chat-hf",
"llama2_70b": "meta-llama/Llama-2-70b-chat-hf",
"codegen": "Salesforce/codegen25-7b-multi",
"vicuna1p3": "lmsys/vicuna-7b-v1.3",
"vicuna": "TheBloke/vicuna-7B-1.1-HF",
@@ -30,7 +31,16 @@ model_map = {
# NOTE: Each `model_name` should have its own start message
start_message = {
"llama2": (
"llama2_7b": (
"System: You are a helpful, respectful and honest assistant. Always answer "
"as helpfully as possible, while being safe. Your answers should not "
"include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal "
"content. Please ensure that your responses are socially unbiased and positive "
"in nature. If a question does not make any sense, or is not factually coherent, "
"explain why instead of answering something not correct. If you don't know the "
"answer to a question, please don't share false information."
),
"llama2_70b": (
"System: You are a helpful, respectful and honest assistant. Always answer "
"as helpfully as possible, while being safe. Your answers should not "
"include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal "
@@ -67,7 +77,13 @@ start_message = {
def create_prompt(model_name, history):
system_message = start_message[model_name]
if model_name in ["StableLM", "vicuna", "vicuna1p3", "llama2"]:
if model_name in [
"StableLM",
"vicuna",
"vicuna1p3",
"llama2_7b",
"llama2_70b",
]:
conversation = "".join(
[
"".join(["<|USER|>" + item[0], "<|ASSISTANT|>" + item[1]])
@@ -96,10 +112,17 @@ def chat(curr_system_message, history, model, device, precision, cli=True):
global vicuna_model
model_name, model_path = list(map(str.strip, model.split("=>")))
if model_name in ["vicuna", "vicuna1p3", "codegen", "llama2"]:
if model_name in [
"vicuna",
"vicuna1p3",
"codegen",
"llama2_7b",
"llama2_70b",
]:
from apps.language_models.scripts.vicuna import (
UnshardedVicuna,
)
from apps.stable_diffusion.src import args
if vicuna_model == 0:
if "cuda" in device:
@@ -117,6 +140,7 @@ def chat(curr_system_message, history, model, device, precision, cli=True):
vicuna_model = UnshardedVicuna(
model_name,
hf_model_path=model_path,
hf_auth_token=args.hf_auth_token,
device=device,
precision=precision,
max_num_tokens=max_toks,
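
Putting the web-side pieces together, a hedged sketch of driving chat() directly with the new 70B entry; it mirrors the CLI loop added to vicuna.py above and assumes --hf_auth_token was passed (so args.hf_auth_token is set) and a device/precision combination for which a vmfb can be built or downloaded:

    from apps.stable_diffusion.web.ui.stablelm_ui import chat, start_message

    history = [["What is the capital of France?", ""]]
    # Consumed the same way the new CLI loop does: list(...)[0].
    history = list(
        chat(
            start_message["llama2_70b"],
            history,
            model="llama2_70b=>meta-llama/Llama-2-70b-chat-hf",
            device="cuda",
            precision="fp16",
            cli=True,
        )
    )[0]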