mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-03 03:00:17 -04:00
Fix Langchain Prompt issue and add web UI support (#1682)
This commit is contained in:
@@ -161,7 +161,7 @@ def main(
|
||||
extra_model_options: typing.List[str] = [],
|
||||
extra_lora_options: typing.List[str] = [],
|
||||
extra_server_options: typing.List[str] = [],
|
||||
score_model: str = None,
|
||||
score_model: str = "OpenAssistant/reward-model-deberta-v3-large-v2",
|
||||
eval_filename: str = None,
|
||||
eval_prompts_only_num: int = 0,
|
||||
eval_prompts_only_seed: int = 1234,
|
||||
@@ -1704,9 +1704,11 @@ def evaluate(
|
||||
assert tokenizer, "Tokenizer is missing"
|
||||
|
||||
# choose chat or non-chat mode
|
||||
print(instruction)
|
||||
if not chat:
|
||||
instruction = instruction_nochat
|
||||
iinput = iinput_nochat
|
||||
print(instruction)
|
||||
|
||||
# in some cases, like lean nochat API, don't want to force sending prompt_type, allow default choice
|
||||
model_lower = base_model.lower()
|
||||
@@ -1741,7 +1743,7 @@ def evaluate(
|
||||
max_new_tokens=max_new_tokens,
|
||||
max_max_new_tokens=max_max_new_tokens,
|
||||
)
|
||||
model_max_length = get_model_max_length(chosen_model_state)
|
||||
model_max_length = 2048 # get_model_max_length(chosen_model_state)
|
||||
max_new_tokens = min(max(1, int(max_new_tokens)), max_max_new_tokens)
|
||||
min_new_tokens = min(max(0, int(min_new_tokens)), max_new_tokens)
|
||||
max_time = min(max(0, max_time), max_max_time)
|
||||
@@ -1761,6 +1763,7 @@ def evaluate(
|
||||
# restrict instruction, typically what has large input
|
||||
from h2oai_pipeline import H2OTextGenerationPipeline
|
||||
|
||||
print(instruction)
|
||||
instruction, num_prompt_tokens1 = H2OTextGenerationPipeline.limit_prompt(
|
||||
instruction, tokenizer
|
||||
)
|
||||
@@ -2318,6 +2321,8 @@ def evaluate(
|
||||
model_max_length=tokenizer.model_max_length,
|
||||
)
|
||||
|
||||
print(prompt)
|
||||
# exit(0)
|
||||
inputs = tokenizer(prompt, return_tensors="pt")
|
||||
if debug and len(inputs["input_ids"]) > 0:
|
||||
print("input_ids length", len(inputs["input_ids"][0]), flush=True)
|
||||
|
||||
@@ -83,7 +83,9 @@ class H2OGPTSHARKModel(torch.nn.Module):
|
||||
mlir_dialect="linalg",
|
||||
)
|
||||
print(f"[DEBUG] generating vmfb.")
|
||||
shark_module = _compile_module(shark_module, vmfb_path, [])
|
||||
shark_module = _compile_module(
|
||||
shark_module, str(vmfb_path), []
|
||||
)
|
||||
print("Saved newly generated vmfb.")
|
||||
|
||||
if shark_module is None:
|
||||
@@ -92,7 +94,7 @@ class H2OGPTSHARKModel(torch.nn.Module):
|
||||
shark_module = SharkInference(
|
||||
None, device=global_device, mlir_dialect="linalg"
|
||||
)
|
||||
shark_module.load_module(vmfb_path)
|
||||
shark_module.load_module(str(vmfb_path))
|
||||
print("Compiled vmfb loaded successfully.")
|
||||
else:
|
||||
raise ValueError("Unable to download/generate a vmfb.")
|
||||
|
||||
@@ -26,54 +26,17 @@ h2ogpt_model = 0
|
||||
|
||||
past_key_values = None
|
||||
|
||||
model_map = {
|
||||
"codegen": "Salesforce/codegen25-7b-multi",
|
||||
"vicuna1p3": "lmsys/vicuna-7b-v1.3",
|
||||
"vicuna": "TheBloke/vicuna-7B-1.1-HF",
|
||||
"StableLM": "stabilityai/stablelm-tuned-alpha-3b",
|
||||
}
|
||||
|
||||
# NOTE: Each `model_name` should have its own start message
|
||||
start_message = {
|
||||
"StableLM": (
|
||||
"<|SYSTEM|># StableLM Tuned (Alpha version)"
|
||||
"\n- StableLM is a helpful and harmless open-source AI language model "
|
||||
"developed by StabilityAI."
|
||||
"\n- StableLM is excited to be able to help the user, but will refuse "
|
||||
"to do anything that could be considered harmful to the user."
|
||||
"\n- StableLM is more than just an information source, StableLM is also "
|
||||
"able to write poetry, short stories, and make jokes."
|
||||
"\n- StableLM will refuse to participate in anything that "
|
||||
"could harm a human."
|
||||
),
|
||||
"vicuna": (
|
||||
"A chat between a curious user and an artificial intelligence assistant. "
|
||||
"The assistant gives helpful, detailed, and polite answers to the user's "
|
||||
"questions.\n"
|
||||
),
|
||||
"vicuna1p3": (
|
||||
"A chat between a curious user and an artificial intelligence assistant. "
|
||||
"The assistant gives helpful, detailed, and polite answers to the user's "
|
||||
"questions.\n"
|
||||
),
|
||||
"codegen": "",
|
||||
}
|
||||
start_message = """
|
||||
SHARK DocuChat
|
||||
Chat with an AI, contextualized with provided files.
|
||||
"""
|
||||
|
||||
|
||||
def create_prompt(model_name, history):
|
||||
system_message = start_message[model_name]
|
||||
def create_prompt(history):
|
||||
system_message = start_message
|
||||
|
||||
if model_name in ["StableLM", "vicuna", "vicuna1p3"]:
|
||||
conversation = "".join(
|
||||
[
|
||||
"".join(["<|USER|>" + item[0], "<|ASSISTANT|>" + item[1]])
|
||||
for item in history
|
||||
]
|
||||
)
|
||||
else:
|
||||
conversation = "".join(
|
||||
["".join([item[0], item[1]]) for item in history]
|
||||
)
|
||||
conversation = "".join(["".join([item[0], item[1]]) for item in history])
|
||||
|
||||
msg = system_message + conversation
|
||||
msg = msg.strip()
|
||||
@@ -85,6 +48,7 @@ def chat(curr_system_message, history, model, device, precision):
|
||||
global sharded_model
|
||||
global past_key_values
|
||||
global h2ogpt_model
|
||||
global userpath_selector
|
||||
|
||||
model_name, model_path = list(map(str.strip, model.split("=>")))
|
||||
print(f"In chat for {model_name}")
|
||||
@@ -109,36 +73,84 @@ def chat(curr_system_message, history, model, device, precision):
|
||||
# precision=precision,
|
||||
# max_num_tokens=max_toks,
|
||||
# )
|
||||
# prompt = create_prompt(model_name, history)
|
||||
prompt = create_prompt(history)
|
||||
print(prompt)
|
||||
# print("prompt = ", prompt)
|
||||
|
||||
# for partial_text in h2ogpt_model.generate(prompt):
|
||||
# history[-1][1] = partial_text
|
||||
# yield history
|
||||
model, tokenizer, device = gen.get_model(
|
||||
load_half=True,
|
||||
load_gptq="",
|
||||
use_safetensors=False,
|
||||
infer_devices=True,
|
||||
base_model="h2oai/h2ogpt-gm-oasst1-en-2048-falcon-7b-v3",
|
||||
inference_server="",
|
||||
tokenizer_base_model="h2oai/h2ogpt-gm-oasst1-en-2048-falcon-7b-v3",
|
||||
lora_weights="",
|
||||
gpu_id=0,
|
||||
reward_type=None,
|
||||
local_files_only=False,
|
||||
resume_download=True,
|
||||
use_auth_token=False,
|
||||
trust_remote_code=True,
|
||||
offload_folder=None,
|
||||
compile_model=False,
|
||||
verbose=False,
|
||||
)
|
||||
print(tokenizer.model_max_length)
|
||||
model_state = dict(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
device=device,
|
||||
base_model="h2oai/h2ogpt-gm-oasst1-en-2048-falcon-7b-v3",
|
||||
tokenizer_base_model="h2oai/h2ogpt-gm-oasst1-en-2048-falcon-7b-v3",
|
||||
lora_weights="",
|
||||
inference_server="",
|
||||
prompt_type=None,
|
||||
prompt_dict=None,
|
||||
)
|
||||
output = gen.evaluate(
|
||||
None, # model_state
|
||||
model_state, # model_state
|
||||
None, # my_db_state
|
||||
None, # instruction
|
||||
None, # iinput
|
||||
history, # context
|
||||
False, # stream_output
|
||||
None, # prompt_type
|
||||
None, # prompt_dict
|
||||
None, # temperature
|
||||
None, # top_p
|
||||
None, # top_k
|
||||
None, # num_beams
|
||||
None, # max_new_tokens
|
||||
None, # min_new_tokens
|
||||
None, # early_stopping
|
||||
None, # max_time
|
||||
None, # repetition_penalty
|
||||
None, # num_return_sequences
|
||||
prompt, # instruction
|
||||
"", # iinput
|
||||
"", # context
|
||||
True, # stream_output
|
||||
"prompt_answer", # prompt_type
|
||||
{
|
||||
"promptA": "",
|
||||
"promptB": "",
|
||||
"PreInstruct": "<|prompt|>",
|
||||
"PreInput": None,
|
||||
"PreResponse": "<|answer|>",
|
||||
"terminate_response": [
|
||||
"<|prompt|>",
|
||||
"<|answer|>",
|
||||
"<|endoftext|>",
|
||||
],
|
||||
"chat_sep": "<|endoftext|>",
|
||||
"chat_turn_sep": "<|endoftext|>",
|
||||
"humanstr": "<|prompt|>",
|
||||
"botstr": "<|answer|>",
|
||||
"generates_leading_space": False,
|
||||
}, # prompt_dict
|
||||
0.1, # temperature
|
||||
0.75, # top_p
|
||||
40, # top_k
|
||||
1, # num_beams
|
||||
256, # max_new_tokens
|
||||
0, # min_new_tokens
|
||||
False, # early_stopping
|
||||
180, # max_time
|
||||
1.07, # repetition_penalty
|
||||
1, # num_return_sequences
|
||||
False, # do_sample
|
||||
False, # chat
|
||||
None, # instruction_nochat
|
||||
curr_system_message, # iinput_nochat
|
||||
"Disabled", # langchain_mode
|
||||
True, # chat
|
||||
prompt, # instruction_nochat
|
||||
"", # iinput_nochat
|
||||
"UserData", # langchain_mode
|
||||
LangChainAction.QUERY.value, # langchain_action
|
||||
3, # top_k_docs
|
||||
True, # chunk
|
||||
@@ -154,6 +166,10 @@ def chat(curr_system_message, history, model, device, precision):
|
||||
db_type="chroma",
|
||||
n_jobs=-1,
|
||||
first_para=False,
|
||||
max_max_time=60 * 2,
|
||||
model_state0=model_state,
|
||||
model_lock=True,
|
||||
user_path=userpath_selector.value,
|
||||
)
|
||||
for partial_text in output:
|
||||
history[-1][1] = partial_text
|
||||
@@ -164,14 +180,6 @@ def chat(curr_system_message, history, model, device, precision):
|
||||
|
||||
with gr.Blocks(title="H2OGPT") as h2ogpt_web:
|
||||
with gr.Row():
|
||||
model_choices = list(
|
||||
map(lambda x: f"{x[0]: <10} => {x[1]}", model_map.items())
|
||||
)
|
||||
model = gr.Dropdown(
|
||||
label="Select Model",
|
||||
value=model_choices[0],
|
||||
choices=model_choices,
|
||||
)
|
||||
supported_devices = available_devices
|
||||
enabled = len(supported_devices) > 0
|
||||
# show cpu-task device first in list for chatbot
|
||||
@@ -197,6 +205,14 @@ with gr.Blocks(title="H2OGPT") as h2ogpt_web:
|
||||
],
|
||||
visible=True,
|
||||
)
|
||||
userpath_selector = gr.Textbox(
|
||||
label="Document Directory",
|
||||
value=str(
|
||||
os.path.abspath("apps/language_models/langchain/user_path/")
|
||||
),
|
||||
interactive=True,
|
||||
container=True,
|
||||
)
|
||||
chatbot = gr.Chatbot(height=500)
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
@@ -220,7 +236,7 @@ with gr.Blocks(title="H2OGPT") as h2ogpt_web:
|
||||
fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False
|
||||
).then(
|
||||
fn=chat,
|
||||
inputs=[system_msg, chatbot, model, device, precision],
|
||||
inputs=[system_msg, chatbot, device, precision],
|
||||
outputs=[chatbot],
|
||||
queue=True,
|
||||
)
|
||||
@@ -228,7 +244,7 @@ with gr.Blocks(title="H2OGPT") as h2ogpt_web:
|
||||
fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False
|
||||
).then(
|
||||
fn=chat,
|
||||
inputs=[system_msg, chatbot, model, device, precision],
|
||||
inputs=[system_msg, chatbot, device, precision],
|
||||
outputs=[chatbot],
|
||||
queue=True,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user