Fix Langchain Prompt issue and add web UI support (#1682)

This commit is contained in:
Vivek Khandelwal
2023-07-21 19:06:55 +05:30
committed by GitHub
parent c292e5c9d7
commit a415f3f70e
3 changed files with 103 additions and 80 deletions

View File

@@ -161,7 +161,7 @@ def main(
extra_model_options: typing.List[str] = [],
extra_lora_options: typing.List[str] = [],
extra_server_options: typing.List[str] = [],
score_model: str = None,
score_model: str = "OpenAssistant/reward-model-deberta-v3-large-v2",
eval_filename: str = None,
eval_prompts_only_num: int = 0,
eval_prompts_only_seed: int = 1234,
@@ -1704,9 +1704,11 @@ def evaluate(
assert tokenizer, "Tokenizer is missing"
# choose chat or non-chat mode
print(instruction)
if not chat:
instruction = instruction_nochat
iinput = iinput_nochat
print(instruction)
# in some cases, like lean nochat API, don't want to force sending prompt_type, allow default choice
model_lower = base_model.lower()
@@ -1741,7 +1743,7 @@ def evaluate(
max_new_tokens=max_new_tokens,
max_max_new_tokens=max_max_new_tokens,
)
model_max_length = get_model_max_length(chosen_model_state)
model_max_length = 2048 # get_model_max_length(chosen_model_state)
max_new_tokens = min(max(1, int(max_new_tokens)), max_max_new_tokens)
min_new_tokens = min(max(0, int(min_new_tokens)), max_new_tokens)
max_time = min(max(0, max_time), max_max_time)
@@ -1761,6 +1763,7 @@ def evaluate(
# restrict instruction, typically what has large input
from h2oai_pipeline import H2OTextGenerationPipeline
print(instruction)
instruction, num_prompt_tokens1 = H2OTextGenerationPipeline.limit_prompt(
instruction, tokenizer
)
@@ -2318,6 +2321,8 @@ def evaluate(
model_max_length=tokenizer.model_max_length,
)
print(prompt)
# exit(0)
inputs = tokenizer(prompt, return_tensors="pt")
if debug and len(inputs["input_ids"]) > 0:
print("input_ids length", len(inputs["input_ids"][0]), flush=True)

View File

@@ -83,7 +83,9 @@ class H2OGPTSHARKModel(torch.nn.Module):
mlir_dialect="linalg",
)
print(f"[DEBUG] generating vmfb.")
shark_module = _compile_module(shark_module, vmfb_path, [])
shark_module = _compile_module(
shark_module, str(vmfb_path), []
)
print("Saved newly generated vmfb.")
if shark_module is None:
@@ -92,7 +94,7 @@ class H2OGPTSHARKModel(torch.nn.Module):
shark_module = SharkInference(
None, device=global_device, mlir_dialect="linalg"
)
shark_module.load_module(vmfb_path)
shark_module.load_module(str(vmfb_path))
print("Compiled vmfb loaded successfully.")
else:
raise ValueError("Unable to download/generate a vmfb.")

View File

@@ -26,54 +26,17 @@ h2ogpt_model = 0
past_key_values = None
model_map = {
"codegen": "Salesforce/codegen25-7b-multi",
"vicuna1p3": "lmsys/vicuna-7b-v1.3",
"vicuna": "TheBloke/vicuna-7B-1.1-HF",
"StableLM": "stabilityai/stablelm-tuned-alpha-3b",
}
# NOTE: Each `model_name` should have its own start message
start_message = {
"StableLM": (
"<|SYSTEM|># StableLM Tuned (Alpha version)"
"\n- StableLM is a helpful and harmless open-source AI language model "
"developed by StabilityAI."
"\n- StableLM is excited to be able to help the user, but will refuse "
"to do anything that could be considered harmful to the user."
"\n- StableLM is more than just an information source, StableLM is also "
"able to write poetry, short stories, and make jokes."
"\n- StableLM will refuse to participate in anything that "
"could harm a human."
),
"vicuna": (
"A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's "
"questions.\n"
),
"vicuna1p3": (
"A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's "
"questions.\n"
),
"codegen": "",
}
start_message = """
SHARK DocuChat
Chat with an AI, contextualized with provided files.
"""
def create_prompt(model_name, history):
system_message = start_message[model_name]
def create_prompt(history):
system_message = start_message
if model_name in ["StableLM", "vicuna", "vicuna1p3"]:
conversation = "".join(
[
"".join(["<|USER|>" + item[0], "<|ASSISTANT|>" + item[1]])
for item in history
]
)
else:
conversation = "".join(
["".join([item[0], item[1]]) for item in history]
)
conversation = "".join(["".join([item[0], item[1]]) for item in history])
msg = system_message + conversation
msg = msg.strip()
@@ -85,6 +48,7 @@ def chat(curr_system_message, history, model, device, precision):
global sharded_model
global past_key_values
global h2ogpt_model
global userpath_selector
model_name, model_path = list(map(str.strip, model.split("=>")))
print(f"In chat for {model_name}")
@@ -109,36 +73,84 @@ def chat(curr_system_message, history, model, device, precision):
# precision=precision,
# max_num_tokens=max_toks,
# )
# prompt = create_prompt(model_name, history)
prompt = create_prompt(history)
print(prompt)
# print("prompt = ", prompt)
# for partial_text in h2ogpt_model.generate(prompt):
# history[-1][1] = partial_text
# yield history
model, tokenizer, device = gen.get_model(
load_half=True,
load_gptq="",
use_safetensors=False,
infer_devices=True,
base_model="h2oai/h2ogpt-gm-oasst1-en-2048-falcon-7b-v3",
inference_server="",
tokenizer_base_model="h2oai/h2ogpt-gm-oasst1-en-2048-falcon-7b-v3",
lora_weights="",
gpu_id=0,
reward_type=None,
local_files_only=False,
resume_download=True,
use_auth_token=False,
trust_remote_code=True,
offload_folder=None,
compile_model=False,
verbose=False,
)
print(tokenizer.model_max_length)
model_state = dict(
model=model,
tokenizer=tokenizer,
device=device,
base_model="h2oai/h2ogpt-gm-oasst1-en-2048-falcon-7b-v3",
tokenizer_base_model="h2oai/h2ogpt-gm-oasst1-en-2048-falcon-7b-v3",
lora_weights="",
inference_server="",
prompt_type=None,
prompt_dict=None,
)
output = gen.evaluate(
None, # model_state
model_state, # model_state
None, # my_db_state
None, # instruction
None, # iinput
history, # context
False, # stream_output
None, # prompt_type
None, # prompt_dict
None, # temperature
None, # top_p
None, # top_k
None, # num_beams
None, # max_new_tokens
None, # min_new_tokens
None, # early_stopping
None, # max_time
None, # repetition_penalty
None, # num_return_sequences
prompt, # instruction
"", # iinput
"", # context
True, # stream_output
"prompt_answer", # prompt_type
{
"promptA": "",
"promptB": "",
"PreInstruct": "<|prompt|>",
"PreInput": None,
"PreResponse": "<|answer|>",
"terminate_response": [
"<|prompt|>",
"<|answer|>",
"<|endoftext|>",
],
"chat_sep": "<|endoftext|>",
"chat_turn_sep": "<|endoftext|>",
"humanstr": "<|prompt|>",
"botstr": "<|answer|>",
"generates_leading_space": False,
}, # prompt_dict
0.1, # temperature
0.75, # top_p
40, # top_k
1, # num_beams
256, # max_new_tokens
0, # min_new_tokens
False, # early_stopping
180, # max_time
1.07, # repetition_penalty
1, # num_return_sequences
False, # do_sample
False, # chat
None, # instruction_nochat
curr_system_message, # iinput_nochat
"Disabled", # langchain_mode
True, # chat
prompt, # instruction_nochat
"", # iinput_nochat
"UserData", # langchain_mode
LangChainAction.QUERY.value, # langchain_action
3, # top_k_docs
True, # chunk
@@ -154,6 +166,10 @@ def chat(curr_system_message, history, model, device, precision):
db_type="chroma",
n_jobs=-1,
first_para=False,
max_max_time=60 * 2,
model_state0=model_state,
model_lock=True,
user_path=userpath_selector.value,
)
for partial_text in output:
history[-1][1] = partial_text
@@ -164,14 +180,6 @@ def chat(curr_system_message, history, model, device, precision):
with gr.Blocks(title="H2OGPT") as h2ogpt_web:
with gr.Row():
model_choices = list(
map(lambda x: f"{x[0]: <10} => {x[1]}", model_map.items())
)
model = gr.Dropdown(
label="Select Model",
value=model_choices[0],
choices=model_choices,
)
supported_devices = available_devices
enabled = len(supported_devices) > 0
# show cpu-task device first in list for chatbot
@@ -197,6 +205,14 @@ with gr.Blocks(title="H2OGPT") as h2ogpt_web:
],
visible=True,
)
userpath_selector = gr.Textbox(
label="Document Directory",
value=str(
os.path.abspath("apps/language_models/langchain/user_path/")
),
interactive=True,
container=True,
)
chatbot = gr.Chatbot(height=500)
with gr.Row():
with gr.Column():
@@ -220,7 +236,7 @@ with gr.Blocks(title="H2OGPT") as h2ogpt_web:
fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False
).then(
fn=chat,
inputs=[system_msg, chatbot, model, device, precision],
inputs=[system_msg, chatbot, device, precision],
outputs=[chatbot],
queue=True,
)
@@ -228,7 +244,7 @@ with gr.Blocks(title="H2OGPT") as h2ogpt_web:
fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False
).then(
fn=chat,
inputs=[system_msg, chatbot, model, device, precision],
inputs=[system_msg, chatbot, device, precision],
outputs=[chatbot],
queue=True,
)