Mirror of https://github.com/nod-ai/AMD-SHARK-Studio.git (synced 2026-02-19 11:56:43 -05:00)
Add support for Llama-2-70b for web and cli, and for hf_auth_token
@@ -101,7 +101,25 @@ parser.add_argument(
     default=128,
     help="Group size for per_group weight quantization. Default: 128.",
 )
-parser.add_argument("--download_vmfb", default=False, action=argparse.BooleanOptionalAction, help="download vmfb from sharktank, system dependent, YMMV")
+parser.add_argument(
+    "--download_vmfb",
+    default=False,
+    action=argparse.BooleanOptionalAction,
+    help="download vmfb from sharktank, system dependent, YMMV",
+)
+parser.add_argument(
+    "--model_name",
+    type=str,
+    default="vicuna",
+    choices=["vicuna", "llama2_7b", "llama2_70b"],
+    help="Specify which model to run.",
+)
+parser.add_argument(
+    "--hf_auth_token",
+    type=str,
+    default=None,
+    help="Specify your own huggingface authentication tokens for models like Llama2.",
+)


 def brevitas〇matmul_rhs_group_quant〡shape(lhs: List[int], rhs: List[int], rhs_scale: List[int], rhs_zero_point: List[int], rhs_bit_width: int, rhs_group_size: int) -> List[int]:
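As an aside, here is a minimal standalone sketch (not part of the diff) of how the three flags above behave once parsed; the sample argv and the token value are hypothetical, and argparse.BooleanOptionalAction requires Python 3.9+:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--download_vmfb",
        default=False,
        action=argparse.BooleanOptionalAction,  # also accepts --no-download_vmfb
    )
    parser.add_argument(
        "--model_name",
        type=str,
        default="vicuna",
        choices=["vicuna", "llama2_7b", "llama2_70b"],
    )
    parser.add_argument("--hf_auth_token", type=str, default=None)

    # e.g. run the 70B chat model with a user-supplied token
    args = parser.parse_args(
        ["--model_name=llama2_70b", "--hf_auth_token=hf_xxx", "--download_vmfb"]
    )
    print(args.model_name, args.download_vmfb)  # llama2_70b True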
@@ -870,6 +888,7 @@ class UnshardedVicuna(SharkLLMBase):
         self,
         model_name,
         hf_model_path="TheBloke/vicuna-7B-1.1-HF",
+        hf_auth_token: str = None,
         max_num_tokens=512,
         device="cuda",
         precision="fp32",
@@ -883,8 +902,15 @@ class UnshardedVicuna(SharkLLMBase):
         download_vmfb=False,
     ) -> None:
         super().__init__(model_name, hf_model_path, max_num_tokens)
-        if self.model_name == "llama2":
+        if "llama2" in self.model_name and hf_auth_token == None:
+            raise ValueError(
+                "HF auth token required. Pass it using --hf_auth_token flag."
+            )
+        self.hf_auth_token = hf_auth_token
+        if self.model_name == "llama2_7b":
             self.hf_model_path = "meta-llama/Llama-2-7b-chat-hf"
+        elif self.model_name == "llama2_70b":
+            self.hf_model_path = "meta-llama/Llama-2-70b-chat-hf"
         print(f"[DEBUG] hf model name: {self.hf_model_path}")
         self.max_sequence_length = 256
         self.device = device
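A minimal sketch of the selection logic the constructor now applies, pulled out as a hypothetical helper purely for illustration (the function name is the editor's, not the repo's): both Llama-2 variants require a Hugging Face token and map to the gated meta-llama chat checkpoints.

    def resolve_hf_model_path(model_name: str, hf_auth_token: str = None) -> str:
        # Mirrors the checks added to UnshardedVicuna.__init__ above.
        if "llama2" in model_name and hf_auth_token is None:
            raise ValueError(
                "HF auth token required. Pass it using --hf_auth_token flag."
            )
        if model_name == "llama2_7b":
            return "meta-llama/Llama-2-7b-chat-hf"
        if model_name == "llama2_70b":
            return "meta-llama/Llama-2-70b-chat-hf"
        return "TheBloke/vicuna-7B-1.1-HF"  # default hf_model_path for "vicuna"

    # resolve_hf_model_path("llama2_70b")            -> raises ValueError
    # resolve_hf_model_path("llama2_70b", "hf_xxx")  -> "meta-llama/Llama-2-70b-chat-hf"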
@@ -923,11 +949,7 @@ class UnshardedVicuna(SharkLLMBase):
         )

     def get_tokenizer(self):
-        kwargs = {}
-        if self.model_name == "llama2":
-            kwargs = {
-                "use_auth_token": "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"
-            }
+        kwargs = {"use_auth_token": self.hf_auth_token}
         if self.model_name == "codegen":
             tokenizer = AutoTokenizer.from_pretrained(
                 self.hf_model_path,
@@ -942,9 +964,10 @@ class UnshardedVicuna(SharkLLMBase):
         return tokenizer

     def get_src_model(self):
-        kwargs = {"torch_dtype": torch.float}
-        if self.model_name == "llama2":
-            kwargs["use_auth_token"] = "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"
+        kwargs = {
+            "torch_dtype": torch.float,
+            "use_auth_token": self.hf_auth_token,
+        }
         vicuna_model = AutoModelForCausalLM.from_pretrained(
             self.hf_model_path,
             **kwargs,
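For context, a hedged sketch of where the token ends up: it is forwarded to transformers' from_pretrained via the use_auth_token keyword, which the gated meta-llama repos require (use_auth_token was the accepted argument in transformers at the time of this commit; newer releases prefer token). The token value below is a placeholder.

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    hf_model_path = "meta-llama/Llama-2-70b-chat-hf"
    hf_auth_token = "hf_xxx"  # placeholder; supply your own token for the gated repo

    tokenizer = AutoTokenizer.from_pretrained(
        hf_model_path, use_auth_token=hf_auth_token
    )
    model = AutoModelForCausalLM.from_pretrained(
        hf_model_path,
        torch_dtype=torch.float,
        use_auth_token=hf_auth_token,
    )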
@@ -1010,6 +1033,7 @@ class UnshardedVicuna(SharkLLMBase):
             self.precision,
             self.weight_group_size,
             self.model_name,
+            self.hf_auth_token,
         )

         print(f"[DEBUG] generating torchscript graph")
@@ -1174,6 +1198,7 @@ class UnshardedVicuna(SharkLLMBase):
             self.precision,
             self.weight_group_size,
             self.model_name,
+            self.hf_auth_token,
         )

         print(f"[DEBUG] generating torchscript graph")
@@ -1328,7 +1353,8 @@ class UnshardedVicuna(SharkLLMBase):
         ):
             if (self.device == "cuda" and self.precision == "fp16") or (
                 self.device in ["cpu-sync", "cpu-task"]
-                and self.precision == "int8" and self.download_vmfb
+                and self.precision == "int8"
+                and self.download_vmfb
             ):
                 download_public_file(
                     f"gs://shark_tank/{self.model_name}/unsharded/vmfb/{self.first_vicuna_vmfb_path.name}",
@@ -1350,7 +1376,8 @@ class UnshardedVicuna(SharkLLMBase):
         ):
             if (self.device == "cuda" and self.precision == "fp16") or (
                 self.device in ["cpu-sync", "cpu-task"]
-                and self.precision == "int8" and self.download_vmfb
+                and self.precision == "int8"
+                and self.download_vmfb
             ):
                 download_public_file(
                     f"gs://shark_tank/{self.model_name}/unsharded/vmfb/{self.second_vicuna_vmfb_path.name}",
@@ -1571,7 +1598,8 @@ if __name__ == "__main__":
         )

         vic = UnshardedVicuna(
-            "vicuna",
+            model_name=args.model_name,
+            hf_auth_token=args.hf_auth_token,
            device=args.device,
            precision=args.precision,
            first_vicuna_mlir_path=first_vic_mlir_path,
@@ -1590,21 +1618,45 @@ if __name__ == "__main__":
     else:
         config_json = None
         vic = ShardedVicuna(
-            "vicuna",
+            model_name=args.model_name,
             device=args.device,
             precision=args.precision,
             config_json=config_json,
             weight_group_size=args.weight_group_size,
         )
-    system_message = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n"
+    if args.model_name == "vicuna":
+        system_message = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n"
+    else:
+        system_message = """System: You are a helpful, respectful and honest assistant. Always answer "
+        as helpfully as possible, while being safe. Your answers should not
+        include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal
+        content. Please ensure that your responses are socially unbiased and positive
+        in nature. If a question does not make any sense, or is not factually coherent,
+        explain why instead of answering something not correct. If you don't know the
+        answer to a question, please don't share false information."""
     prologue_prompt = "ASSISTANT:\n"

     from apps.stable_diffusion.web.ui.stablelm_ui import chat, set_vicuna_model

     history = []
     set_vicuna_model(vic)
+
+    model_list = {
+        "vicuna": "vicuna=>TheBloke/vicuna-7B-1.1-HF",
+        "llama2_7b": "llama2_7b=>meta-llama/Llama-2-7b-chat-hf",
+        "llama2_70b": "llama2_70b=>meta-llama/Llama-2-70b-chat-hf",
+    }
     while True:
         # TODO: Add break condition from user input
         user_prompt = input("User: ")
-        history.append([user_prompt,""])
-        history = list(chat(system_message, history, model="vicuna=>TheBloke/vicuna-7B-1.1-HF", device=args.device, precision=args.precision, cli=args.cli))[0]
+        history.append([user_prompt, ""])
+        history = list(
+            chat(
+                system_message,
+                history,
+                model=model_list[args.model_name],
+                device=args.device,
+                precision=args.precision,
+                cli=args.cli,
+            )
+        )[0]
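The entries in model_list follow the "<model_name>=><hf path>" convention that the web UI's chat() expects; chat() splits the string on "=>" (see the model.split("=>") context line later in this diff). A small sketch of that round trip, for illustration only:

    model_list = {
        "vicuna": "vicuna=>TheBloke/vicuna-7B-1.1-HF",
        "llama2_7b": "llama2_7b=>meta-llama/Llama-2-7b-chat-hf",
        "llama2_70b": "llama2_70b=>meta-llama/Llama-2-70b-chat-hf",
    }

    model = model_list["llama2_70b"]
    model_name, model_path = list(map(str.strip, model.split("=>")))
    # model_name == "llama2_70b", model_path == "meta-llama/Llama-2-70b-chat-hf"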
@@ -12,11 +12,12 @@ class FirstVicuna(torch.nn.Module):
         precision="fp32",
         weight_group_size=128,
         model_name="vicuna",
+        hf_auth_token: str = None,
     ):
         super().__init__()
         kwargs = {"torch_dtype": torch.float32}
-        if model_name == "llama2":
-            kwargs["use_auth_token"] = "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"
+        if "llama2" in model_name:
+            kwargs["use_auth_token"] = hf_auth_token
         self.model = AutoModelForCausalLM.from_pretrained(
             model_path, low_cpu_mem_usage=True, **kwargs
         )
@@ -54,11 +55,12 @@ class SecondVicuna(torch.nn.Module):
         precision="fp32",
         weight_group_size=128,
         model_name="vicuna",
+        hf_auth_token: str = None,
     ):
         super().__init__()
         kwargs = {"torch_dtype": torch.float32}
-        if model_name == "llama2":
-            kwargs["use_auth_token"] = "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"
+        if "llama2" in model_name:
+            kwargs["use_auth_token"] = hf_auth_token
         self.model = AutoModelForCausalLM.from_pretrained(
             model_path, low_cpu_mem_usage=True, **kwargs
         )
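FirstVicuna and SecondVicuna now carry identical token logic; the `"llama2" in model_name` substring check deliberately covers both llama2_7b and llama2_70b. A possible shared helper, sketched by the editor and not part of this commit:

    import torch

    def hf_kwargs(model_name: str, hf_auth_token: str = None) -> dict:
        # Builds the from_pretrained kwargs both wrapper classes construct inline.
        kwargs = {"torch_dtype": torch.float32}
        if "llama2" in model_name:  # matches "llama2_7b" and "llama2_70b"
            kwargs["use_auth_token"] = hf_auth_token
        return kwargs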
@@ -400,6 +400,13 @@ p.add_argument(
     help="Load and unload models for low VRAM.",
 )

+p.add_argument(
+    "--hf_auth_token",
+    type=str,
+    default=None,
+    help="Specify your own huggingface authentication tokens for models like Llama2.",
+)
+
 ##############################################################################
 # IREE - Vulkan supported flags
 ##############################################################################
@@ -21,7 +21,8 @@ vicuna_model = 0
 past_key_values = None

 model_map = {
-    "llama2": "meta-llama/Llama-2-7b-chat-hf",
+    "llama2_7b": "meta-llama/Llama-2-7b-chat-hf",
+    "llama2_70b": "meta-llama/Llama-2-70b-chat-hf",
     "codegen": "Salesforce/codegen25-7b-multi",
     "vicuna1p3": "lmsys/vicuna-7b-v1.3",
     "vicuna": "TheBloke/vicuna-7B-1.1-HF",
@@ -30,7 +31,16 @@ model_map = {

 # NOTE: Each `model_name` should have its own start message
 start_message = {
-    "llama2": (
+    "llama2_7b": (
+        "System: You are a helpful, respectful and honest assistant. Always answer "
+        "as helpfully as possible, while being safe. Your answers should not "
+        "include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal "
+        "content. Please ensure that your responses are socially unbiased and positive "
+        "in nature. If a question does not make any sense, or is not factually coherent, "
+        "explain why instead of answering something not correct. If you don't know the "
+        "answer to a question, please don't share false information."
+    ),
+    "llama2_70b": (
         "System: You are a helpful, respectful and honest assistant. Always answer "
         "as helpfully as possible, while being safe. Your answers should not "
         "include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal "
@@ -67,7 +77,13 @@ start_message = {
 def create_prompt(model_name, history):
     system_message = start_message[model_name]

-    if model_name in ["StableLM", "vicuna", "vicuna1p3", "llama2"]:
+    if model_name in [
+        "StableLM",
+        "vicuna",
+        "vicuna1p3",
+        "llama2_7b",
+        "llama2_70b",
+    ]:
         conversation = "".join(
             [
                 "".join(["<|USER|>" + item[0], "<|ASSISTANT|>" + item[1]])
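To make the prompt format concrete, here is a small illustration of the <|USER|>/<|ASSISTANT|> concatenation shown above, assuming (as in the CLI loop earlier in this diff) that history holds [user_text, assistant_text] pairs and that the comprehension iterates over history (its tail falls outside the hunk):

    history = [
        ["Hi there", "Hello! How can I help?"],
        ["What is SHARK?", ""],  # latest turn, assistant reply still empty
    ]
    conversation = "".join(
        [
            "".join(["<|USER|>" + item[0], "<|ASSISTANT|>" + item[1]])
            for item in history
        ]
    )
    # '<|USER|>Hi there<|ASSISTANT|>Hello! How can I help?<|USER|>What is SHARK?<|ASSISTANT|>'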
@@ -96,10 +112,17 @@ def chat(curr_system_message, history, model, device, precision, cli=True):
     global vicuna_model
     model_name, model_path = list(map(str.strip, model.split("=>")))

-    if model_name in ["vicuna", "vicuna1p3", "codegen", "llama2"]:
+    if model_name in [
+        "vicuna",
+        "vicuna1p3",
+        "codegen",
+        "llama2_7b",
+        "llama2_70b",
+    ]:
         from apps.language_models.scripts.vicuna import (
             UnshardedVicuna,
         )
+        from apps.stable_diffusion.src import args

         if vicuna_model == 0:
             if "cuda" in device:
@@ -117,6 +140,7 @@ def chat(curr_system_message, history, model, device, precision, cli=True):
             vicuna_model = UnshardedVicuna(
                 model_name,
                 hf_model_path=model_path,
+                hf_auth_token=args.hf_auth_token,
                 device=device,
                 precision=precision,
                 max_num_tokens=max_toks,