Mirror of https://github.com/nod-ai/AMD-SHARK-Studio.git (synced 2026-04-03 03:00:17 -04:00)
Add support for Llama-2-70b for web and cli, and for hf_auth_token
@@ -21,7 +21,8 @@ vicuna_model = 0
 past_key_values = None
 
 model_map = {
-    "llama2": "meta-llama/Llama-2-7b-chat-hf",
+    "llama2_7b": "meta-llama/Llama-2-7b-chat-hf",
+    "llama2_70b": "meta-llama/Llama-2-70b-chat-hf",
     "codegen": "Salesforce/codegen25-7b-multi",
     "vicuna1p3": "lmsys/vicuna-7b-v1.3",
     "vicuna": "TheBloke/vicuna-7B-1.1-HF",
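
Both new model_map entries point at gated meta-llama repositories on Hugging Face, which is why the rest of the commit threads an auth token through. A minimal sketch (not part of this commit; check_gated_access is a hypothetical helper) of verifying that a token can see such a repo via huggingface_hub:

# Hypothetical helper, not part of this commit: confirm a Hugging Face
# token grants access to a gated repo such as meta-llama/Llama-2-70b-chat-hf.
from huggingface_hub import HfApi

def check_gated_access(repo_id: str, token: str) -> bool:
    try:
        HfApi().model_info(repo_id, token=token)  # raises on 401/403
        return True
    except Exception:
        return False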
@@ -30,7 +31,16 @@ model_map = {
 
 # NOTE: Each `model_name` should have its own start message
 start_message = {
-    "llama2": (
+    "llama2_7b": (
         "System: You are a helpful, respectful and honest assistant. Always answer "
         "as helpfully as possible, while being safe. Your answers should not "
         "include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal "
+        "content. Please ensure that your responses are socially unbiased and positive "
+        "in nature. If a question does not make any sense, or is not factually coherent, "
+        "explain why instead of answering something not correct. If you don't know the "
+        "answer to a question, please don't share false information."
+    ),
+    "llama2_70b": (
+        "System: You are a helpful, respectful and honest assistant. Always answer "
+        "as helpfully as possible, while being safe. Your answers should not "
+        "include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal "
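
The llama2_70b start message added here appears identical to the llama2_7b one (the hunk's alignment reuses the same trailing lines). A possible follow-up refactor, a sketch only and not what this commit does, would share the string:

# Sketch of a deduplication the commit does not perform: both Llama-2
# entries could reference one shared system prompt.
_LLAMA2_SYSTEM = (
    "System: You are a helpful, respectful and honest assistant. Always answer "
    "as helpfully as possible, while being safe. Your answers should not "
    "include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal "
    "content. Please ensure that your responses are socially unbiased and positive "
    "in nature. If a question does not make any sense, or is not factually coherent, "
    "explain why instead of answering something not correct. If you don't know the "
    "answer to a question, please don't share false information."
)

start_message = {
    "llama2_7b": _LLAMA2_SYSTEM,
    "llama2_70b": _LLAMA2_SYSTEM,
}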
@@ -67,7 +77,13 @@ start_message = {
 def create_prompt(model_name, history):
     system_message = start_message[model_name]
 
-    if model_name in ["StableLM", "vicuna", "vicuna1p3", "llama2"]:
+    if model_name in [
+        "StableLM",
+        "vicuna",
+        "vicuna1p3",
+        "llama2_7b",
+        "llama2_70b",
+    ]:
         conversation = "".join(
             [
                 "".join(["<|USER|>" + item[0], "<|ASSISTANT|>" + item[1]])
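
For the models in this branch, create_prompt flattens the chat history into a single tagged string. An illustration, assuming history is a list of (user, assistant) pairs as the indexing above implies:

# Mirrors the join performed in create_prompt above (illustration only).
history = [("Hi there", "Hello! How can I help?"), ("Name three colors.", "")]
conversation = "".join(
    [
        "".join(["<|USER|>" + item[0], "<|ASSISTANT|>" + item[1]])
        for item in history
    ]
)
# conversation ==
# "<|USER|>Hi there<|ASSISTANT|>Hello! How can I help?"
# "<|USER|>Name three colors.<|ASSISTANT|>"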
@@ -96,10 +112,17 @@ def chat(curr_system_message, history, model, device, precision, cli=True):
     global vicuna_model
     model_name, model_path = list(map(str.strip, model.split("=>")))
 
-    if model_name in ["vicuna", "vicuna1p3", "codegen", "llama2"]:
+    if model_name in [
+        "vicuna",
+        "vicuna1p3",
+        "codegen",
+        "llama2_7b",
+        "llama2_70b",
+    ]:
         from apps.language_models.scripts.vicuna import (
             UnshardedVicuna,
         )
+        from apps.stable_diffusion.src import args
 
         if vicuna_model == 0:
             if "cuda" in device:
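
chat() receives the model selection as a single "name => path" string; the split above recovers both halves. For example:

# The "=>" split used by chat(), shown on concrete input.
model = "llama2_70b => meta-llama/Llama-2-70b-chat-hf"
model_name, model_path = list(map(str.strip, model.split("=>")))
assert model_name == "llama2_70b"
assert model_path == "meta-llama/Llama-2-70b-chat-hf"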
@@ -117,6 +140,7 @@ def chat(curr_system_message, history, model, device, precision, cli=True):
             vicuna_model = UnshardedVicuna(
                 model_name,
                 hf_model_path=model_path,
+                hf_auth_token=args.hf_auth_token,
                 device=device,
                 precision=precision,
                 max_num_tokens=max_toks,
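
The token reaches UnshardedVicuna through the shared argparse namespace imported from apps.stable_diffusion.src, so it is presumably supplied on the command line. A sketch of the assumed flag shape (the name --hf_auth_token is inferred from the attribute used above; the real definition lives in the SHARK args module):

# Assumed shape of the CLI flag backing args.hf_auth_token (sketch only).
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--hf_auth_token",
    type=str,
    default=None,
    help="Hugging Face token for gated repos such as meta-llama.",
)
args = parser.parse_args(["--hf_auth_token", "hf_example"])
assert args.hf_auth_token == "hf_example"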