mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-03 03:00:17 -04:00
Add Llama-2-70b support to the web UI and CLI, and add an --hf_auth_token option
This commit is contained in:
@@ -101,7 +101,25 @@ parser.add_argument(
|
|||||||
default=128,
|
default=128,
|
||||||
help="Group size for per_group weight quantization. Default: 128.",
|
help="Group size for per_group weight quantization. Default: 128.",
|
||||||
)
|
)
|
||||||
parser.add_argument("--download_vmfb", default=False, action=argparse.BooleanOptionalAction, help="download vmfb from sharktank, system dependent, YMMV")
|
parser.add_argument(
|
||||||
|
"--download_vmfb",
|
||||||
|
default=False,
|
||||||
|
action=argparse.BooleanOptionalAction,
|
||||||
|
help="download vmfb from sharktank, system dependent, YMMV",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--model_name",
|
||||||
|
type=str,
|
||||||
|
default="vicuna",
|
||||||
|
choices=["vicuna", "llama2_7b", "llama2_70b"],
|
||||||
|
help="Specify which model to run.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--hf_auth_token",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Specify your own huggingface authentication tokens for models like Llama2.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def brevitas〇matmul_rhs_group_quant〡shape(lhs: List[int], rhs: List[int], rhs_scale: List[int], rhs_zero_point: List[int], rhs_bit_width: int, rhs_group_size: int) -> List[int]:
|
def brevitas〇matmul_rhs_group_quant〡shape(lhs: List[int], rhs: List[int], rhs_scale: List[int], rhs_zero_point: List[int], rhs_bit_width: int, rhs_group_size: int) -> List[int]:
|
||||||
@@ -870,6 +888,7 @@ class UnshardedVicuna(SharkLLMBase):
|
|||||||
self,
|
self,
|
||||||
model_name,
|
model_name,
|
||||||
hf_model_path="TheBloke/vicuna-7B-1.1-HF",
|
hf_model_path="TheBloke/vicuna-7B-1.1-HF",
|
||||||
|
hf_auth_token: str = None,
|
||||||
max_num_tokens=512,
|
max_num_tokens=512,
|
||||||
device="cuda",
|
device="cuda",
|
||||||
precision="fp32",
|
precision="fp32",
|
||||||
@@ -883,8 +902,15 @@ class UnshardedVicuna(SharkLLMBase):
|
|||||||
download_vmfb=False,
|
download_vmfb=False,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(model_name, hf_model_path, max_num_tokens)
|
super().__init__(model_name, hf_model_path, max_num_tokens)
|
||||||
if self.model_name == "llama2":
|
if "llama2" in self.model_name and hf_auth_token == None:
|
||||||
|
raise ValueError(
|
||||||
|
"HF auth token required. Pass it using --hf_auth_token flag."
|
||||||
|
)
|
||||||
|
self.hf_auth_token = hf_auth_token
|
||||||
|
if self.model_name == "llama2_7b":
|
||||||
self.hf_model_path = "meta-llama/Llama-2-7b-chat-hf"
|
self.hf_model_path = "meta-llama/Llama-2-7b-chat-hf"
|
||||||
|
elif self.model_name == "llama2_70b":
|
||||||
|
self.hf_model_path = "meta-llama/Llama-2-70b-chat-hf"
|
||||||
print(f"[DEBUG] hf model name: {self.hf_model_path}")
|
print(f"[DEBUG] hf model name: {self.hf_model_path}")
|
||||||
self.max_sequence_length = 256
|
self.max_sequence_length = 256
|
||||||
self.device = device
|
self.device = device
|
||||||
@@ -923,11 +949,7 @@ class UnshardedVicuna(SharkLLMBase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def get_tokenizer(self):
|
def get_tokenizer(self):
|
||||||
kwargs = {}
|
kwargs = {"use_auth_token": self.hf_auth_token}
|
||||||
if self.model_name == "llama2":
|
|
||||||
kwargs = {
|
|
||||||
"use_auth_token": "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"
|
|
||||||
}
|
|
||||||
if self.model_name == "codegen":
|
if self.model_name == "codegen":
|
||||||
tokenizer = AutoTokenizer.from_pretrained(
|
tokenizer = AutoTokenizer.from_pretrained(
|
||||||
self.hf_model_path,
|
self.hf_model_path,
|
||||||
@@ -942,9 +964,10 @@ class UnshardedVicuna(SharkLLMBase):
|
|||||||
return tokenizer
|
return tokenizer
|
||||||
|
|
||||||
def get_src_model(self):
|
def get_src_model(self):
|
||||||
kwargs = {"torch_dtype": torch.float}
|
kwargs = {
|
||||||
if self.model_name == "llama2":
|
"torch_dtype": torch.float,
|
||||||
kwargs["use_auth_token"] = "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"
|
"use_auth_token": self.hf_auth_token,
|
||||||
|
}
|
||||||
vicuna_model = AutoModelForCausalLM.from_pretrained(
|
vicuna_model = AutoModelForCausalLM.from_pretrained(
|
||||||
self.hf_model_path,
|
self.hf_model_path,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
@@ -1010,6 +1033,7 @@ class UnshardedVicuna(SharkLLMBase):
|
|||||||
self.precision,
|
self.precision,
|
||||||
self.weight_group_size,
|
self.weight_group_size,
|
||||||
self.model_name,
|
self.model_name,
|
||||||
|
self.hf_auth_token,
|
||||||
)
|
)
|
||||||
|
|
||||||
print(f"[DEBUG] generating torchscript graph")
|
print(f"[DEBUG] generating torchscript graph")
|
||||||
@@ -1174,6 +1198,7 @@ class UnshardedVicuna(SharkLLMBase):
|
|||||||
self.precision,
|
self.precision,
|
||||||
self.weight_group_size,
|
self.weight_group_size,
|
||||||
self.model_name,
|
self.model_name,
|
||||||
|
self.hf_auth_token,
|
||||||
)
|
)
|
||||||
|
|
||||||
print(f"[DEBUG] generating torchscript graph")
|
print(f"[DEBUG] generating torchscript graph")
|
||||||
@@ -1328,7 +1353,8 @@ class UnshardedVicuna(SharkLLMBase):
|
|||||||
):
|
):
|
||||||
if (self.device == "cuda" and self.precision == "fp16") or (
|
if (self.device == "cuda" and self.precision == "fp16") or (
|
||||||
self.device in ["cpu-sync", "cpu-task"]
|
self.device in ["cpu-sync", "cpu-task"]
|
||||||
and self.precision == "int8" and self.download_vmfb
|
and self.precision == "int8"
|
||||||
|
and self.download_vmfb
|
||||||
):
|
):
|
||||||
download_public_file(
|
download_public_file(
|
||||||
f"gs://shark_tank/{self.model_name}/unsharded/vmfb/{self.first_vicuna_vmfb_path.name}",
|
f"gs://shark_tank/{self.model_name}/unsharded/vmfb/{self.first_vicuna_vmfb_path.name}",
|
||||||
@@ -1350,7 +1376,8 @@ class UnshardedVicuna(SharkLLMBase):
|
|||||||
):
|
):
|
||||||
if (self.device == "cuda" and self.precision == "fp16") or (
|
if (self.device == "cuda" and self.precision == "fp16") or (
|
||||||
self.device in ["cpu-sync", "cpu-task"]
|
self.device in ["cpu-sync", "cpu-task"]
|
||||||
and self.precision == "int8" and self.download_vmfb
|
and self.precision == "int8"
|
||||||
|
and self.download_vmfb
|
||||||
):
|
):
|
||||||
download_public_file(
|
download_public_file(
|
||||||
f"gs://shark_tank/{self.model_name}/unsharded/vmfb/{self.second_vicuna_vmfb_path.name}",
|
f"gs://shark_tank/{self.model_name}/unsharded/vmfb/{self.second_vicuna_vmfb_path.name}",
|
||||||
@@ -1571,7 +1598,8 @@ if __name__ == "__main__":
|
|||||||
)
|
)
|
||||||
|
|
||||||
vic = UnshardedVicuna(
|
vic = UnshardedVicuna(
|
||||||
"vicuna",
|
model_name=args.model_name,
|
||||||
|
hf_auth_token=args.hf_auth_token,
|
||||||
device=args.device,
|
device=args.device,
|
||||||
precision=args.precision,
|
precision=args.precision,
|
||||||
first_vicuna_mlir_path=first_vic_mlir_path,
|
first_vicuna_mlir_path=first_vic_mlir_path,
|
||||||
@@ -1590,21 +1618,45 @@ if __name__ == "__main__":
|
|||||||
else:
|
else:
|
||||||
config_json = None
|
config_json = None
|
||||||
vic = ShardedVicuna(
|
vic = ShardedVicuna(
|
||||||
"vicuna",
|
model_name=args.model_name,
|
||||||
device=args.device,
|
device=args.device,
|
||||||
precision=args.precision,
|
precision=args.precision,
|
||||||
config_json=config_json,
|
config_json=config_json,
|
||||||
weight_group_size=args.weight_group_size,
|
weight_group_size=args.weight_group_size,
|
||||||
)
|
)
|
||||||
system_message = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n"
|
if args.model_name == "vicuna":
|
||||||
|
system_message = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n"
|
||||||
|
else:
|
||||||
|
system_message = """System: You are a helpful, respectful and honest assistant. Always answer "
|
||||||
|
as helpfully as possible, while being safe. Your answers should not
|
||||||
|
include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal
|
||||||
|
content. Please ensure that your responses are socially unbiased and positive
|
||||||
|
in nature. If a question does not make any sense, or is not factually coherent,
|
||||||
|
explain why instead of answering something not correct. If you don't know the
|
||||||
|
answer to a question, please don't share false information."""
|
||||||
prologue_prompt = "ASSISTANT:\n"
|
prologue_prompt = "ASSISTANT:\n"
|
||||||
|
|
||||||
from apps.stable_diffusion.web.ui.stablelm_ui import chat, set_vicuna_model
|
from apps.stable_diffusion.web.ui.stablelm_ui import chat, set_vicuna_model
|
||||||
|
|
||||||
history = []
|
history = []
|
||||||
set_vicuna_model(vic)
|
set_vicuna_model(vic)
|
||||||
|
|
||||||
|
model_list = {
|
||||||
|
"vicuna": "vicuna=>TheBloke/vicuna-7B-1.1-HF",
|
||||||
|
"llama2_7b": "llama2_7b=>meta-llama/Llama-2-7b-chat-hf",
|
||||||
|
"llama2_70b": "llama2_70b=>meta-llama/Llama-2-70b-chat-hf",
|
||||||
|
}
|
||||||
while True:
|
while True:
|
||||||
# TODO: Add break condition from user input
|
# TODO: Add break condition from user input
|
||||||
user_prompt = input("User: ")
|
user_prompt = input("User: ")
|
||||||
history.append([user_prompt,""])
|
history.append([user_prompt, ""])
|
||||||
history = list(chat(system_message, history, model="vicuna=>TheBloke/vicuna-7B-1.1-HF", device=args.device, precision=args.precision, cli=args.cli))[0]
|
history = list(
|
||||||
|
chat(
|
||||||
|
system_message,
|
||||||
|
history,
|
||||||
|
model=model_list[args.model_name],
|
||||||
|
device=args.device,
|
||||||
|
precision=args.precision,
|
||||||
|
cli=args.cli,
|
||||||
|
)
|
||||||
|
)[0]
|
||||||
|
|||||||
@@ -12,11 +12,12 @@ class FirstVicuna(torch.nn.Module):
|
|||||||
precision="fp32",
|
precision="fp32",
|
||||||
weight_group_size=128,
|
weight_group_size=128,
|
||||||
model_name="vicuna",
|
model_name="vicuna",
|
||||||
|
hf_auth_token: str = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
kwargs = {"torch_dtype": torch.float32}
|
kwargs = {"torch_dtype": torch.float32}
|
||||||
if model_name == "llama2":
|
if "llama2" in model_name:
|
||||||
kwargs["use_auth_token"] = "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"
|
kwargs["use_auth_token"] = hf_auth_token
|
||||||
self.model = AutoModelForCausalLM.from_pretrained(
|
self.model = AutoModelForCausalLM.from_pretrained(
|
||||||
model_path, low_cpu_mem_usage=True, **kwargs
|
model_path, low_cpu_mem_usage=True, **kwargs
|
||||||
)
|
)
|
||||||
@@ -54,11 +55,12 @@ class SecondVicuna(torch.nn.Module):
|
|||||||
precision="fp32",
|
precision="fp32",
|
||||||
weight_group_size=128,
|
weight_group_size=128,
|
||||||
model_name="vicuna",
|
model_name="vicuna",
|
||||||
|
hf_auth_token: str = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
kwargs = {"torch_dtype": torch.float32}
|
kwargs = {"torch_dtype": torch.float32}
|
||||||
if model_name == "llama2":
|
if "llama2" in model_name:
|
||||||
kwargs["use_auth_token"] = "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"
|
kwargs["use_auth_token"] = hf_auth_token
|
||||||
self.model = AutoModelForCausalLM.from_pretrained(
|
self.model = AutoModelForCausalLM.from_pretrained(
|
||||||
model_path, low_cpu_mem_usage=True, **kwargs
|
model_path, low_cpu_mem_usage=True, **kwargs
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -400,6 +400,13 @@ p.add_argument(
|
|||||||
help="Load and unload models for low VRAM.",
|
help="Load and unload models for low VRAM.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
p.add_argument(
|
||||||
|
"--hf_auth_token",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Specify your own huggingface authentication tokens for models like Llama2.",
|
||||||
|
)
|
||||||
|
|
||||||
##############################################################################
|
##############################################################################
|
||||||
# IREE - Vulkan supported flags
|
# IREE - Vulkan supported flags
|
||||||
##############################################################################
|
##############################################################################
|
||||||
|
|||||||
@@ -21,7 +21,8 @@ vicuna_model = 0
|
|||||||
past_key_values = None
|
past_key_values = None
|
||||||
|
|
||||||
model_map = {
|
model_map = {
|
||||||
"llama2": "meta-llama/Llama-2-7b-chat-hf",
|
"llama2_7b": "meta-llama/Llama-2-7b-chat-hf",
|
||||||
|
"llama2_70b": "meta-llama/Llama-2-70b-chat-hf",
|
||||||
"codegen": "Salesforce/codegen25-7b-multi",
|
"codegen": "Salesforce/codegen25-7b-multi",
|
||||||
"vicuna1p3": "lmsys/vicuna-7b-v1.3",
|
"vicuna1p3": "lmsys/vicuna-7b-v1.3",
|
||||||
"vicuna": "TheBloke/vicuna-7B-1.1-HF",
|
"vicuna": "TheBloke/vicuna-7B-1.1-HF",
|
||||||
@@ -30,7 +31,16 @@ model_map = {
|
|||||||
|
|
||||||
# NOTE: Each `model_name` should have its own start message
|
# NOTE: Each `model_name` should have its own start message
|
||||||
start_message = {
|
start_message = {
|
||||||
"llama2": (
|
"llama2_7b": (
|
||||||
|
"System: You are a helpful, respectful and honest assistant. Always answer "
|
||||||
|
"as helpfully as possible, while being safe. Your answers should not "
|
||||||
|
"include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal "
|
||||||
|
"content. Please ensure that your responses are socially unbiased and positive "
|
||||||
|
"in nature. If a question does not make any sense, or is not factually coherent, "
|
||||||
|
"explain why instead of answering something not correct. If you don't know the "
|
||||||
|
"answer to a question, please don't share false information."
|
||||||
|
),
|
||||||
|
"llama2_70b": (
|
||||||
"System: You are a helpful, respectful and honest assistant. Always answer "
|
"System: You are a helpful, respectful and honest assistant. Always answer "
|
||||||
"as helpfully as possible, while being safe. Your answers should not "
|
"as helpfully as possible, while being safe. Your answers should not "
|
||||||
"include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal "
|
"include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal "
|
||||||
@@ -67,7 +77,13 @@ start_message = {
|
|||||||
def create_prompt(model_name, history):
|
def create_prompt(model_name, history):
|
||||||
system_message = start_message[model_name]
|
system_message = start_message[model_name]
|
||||||
|
|
||||||
if model_name in ["StableLM", "vicuna", "vicuna1p3", "llama2"]:
|
if model_name in [
|
||||||
|
"StableLM",
|
||||||
|
"vicuna",
|
||||||
|
"vicuna1p3",
|
||||||
|
"llama2_7b",
|
||||||
|
"llama2_70b",
|
||||||
|
]:
|
||||||
conversation = "".join(
|
conversation = "".join(
|
||||||
[
|
[
|
||||||
"".join(["<|USER|>" + item[0], "<|ASSISTANT|>" + item[1]])
|
"".join(["<|USER|>" + item[0], "<|ASSISTANT|>" + item[1]])
|
||||||
@@ -96,10 +112,17 @@ def chat(curr_system_message, history, model, device, precision, cli=True):
|
|||||||
global vicuna_model
|
global vicuna_model
|
||||||
model_name, model_path = list(map(str.strip, model.split("=>")))
|
model_name, model_path = list(map(str.strip, model.split("=>")))
|
||||||
|
|
||||||
if model_name in ["vicuna", "vicuna1p3", "codegen", "llama2"]:
|
if model_name in [
|
||||||
|
"vicuna",
|
||||||
|
"vicuna1p3",
|
||||||
|
"codegen",
|
||||||
|
"llama2_7b",
|
||||||
|
"llama2_70b",
|
||||||
|
]:
|
||||||
from apps.language_models.scripts.vicuna import (
|
from apps.language_models.scripts.vicuna import (
|
||||||
UnshardedVicuna,
|
UnshardedVicuna,
|
||||||
)
|
)
|
||||||
|
from apps.stable_diffusion.src import args
|
||||||
|
|
||||||
if vicuna_model == 0:
|
if vicuna_model == 0:
|
||||||
if "cuda" in device:
|
if "cuda" in device:
|
||||||
@@ -117,6 +140,7 @@ def chat(curr_system_message, history, model, device, precision, cli=True):
|
|||||||
vicuna_model = UnshardedVicuna(
|
vicuna_model = UnshardedVicuna(
|
||||||
model_name,
|
model_name,
|
||||||
hf_model_path=model_path,
|
hf_model_path=model_path,
|
||||||
|
hf_auth_token=args.hf_auth_token,
|
||||||
device=device,
|
device=device,
|
||||||
precision=precision,
|
precision=precision,
|
||||||
max_num_tokens=max_toks,
|
max_num_tokens=max_toks,
|
||||||
|
|||||||
Reference in New Issue
Block a user