Mirror of https://github.com/nod-ai/SHARK-Studio.git (synced 2026-01-10 06:17:55 -05:00)
Fix multiple issues for Langchain

This commit fixes the following issues for Langchain:
1. The web UI is not able to fetch results.
2. The model is reloaded for each query.
3. The SHARK module does not use the user-provided device and precision.
4. The main Langchain code is not organized in a class (a `Langchain` class is introduced).
5. Miscellaneous issues.
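The core of the change is moving the module-level `gen` helpers (`get_model`, `get_score_model`, `evaluate`, `check_locals`, `get_context`) behind a `Langchain` class that is constructed with the user-selected device and precision, so the model is loaded once and reused across queries. The sketch below only illustrates that intended call pattern; the class body, placeholder loader, and echo-style `evaluate` are assumptions, not the actual `gen.Langchain` implementation.

```python
# Illustrative sketch only -- not the actual gen.Langchain implementation.
class Langchain:
    def __init__(self, device, precision):
        # Keep the user's choices on the instance instead of module globals.
        self.device = device
        self.precision = precision
        self._model = None  # cached so repeated queries do not reload it

    def get_model(self, base_model="h2oai/h2ogpt-gm-oasst1-en-2048-falcon-7b-v3", **kwargs):
        if self._model is None:
            # Placeholder for the real HF/SHARK load, which would honor
            # self.device and self.precision.
            self._model = {"base_model": base_model, "device": self.device}
        return self._model, "tokenizer-placeholder", self.device

    def evaluate(self, model_state, my_db_state=None, instruction="", **gen_kwargs):
        # The web UI expects a stream of dicts carrying a "response" key.
        yield {"response": f"(echo on {self.device}/{self.precision}) {instruction}"}


# Example: the UI constructs this once and reuses it for every chat turn.
lc = Langchain(device="cpu", precision="fp16")
model, tokenizer, device = lc.get_model()
for chunk in lc.evaluate({"model": model}, instruction="Hello"):
    print(chunk["response"])
```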
@@ -2,7 +2,7 @@ import copy
 import torch
 
 from evaluate_params import eval_func_param_names
-from gen import get_score_model, get_model, evaluate, check_locals
+from gen import Langchain
 from prompter import non_hf_types
 from utils import clear_torch_cache, NullContext, get_kwargs
 
@@ -87,7 +87,7 @@ def run_cli( # for local function:
     # unique to this function:
     cli_loop=None,
 ):
-    check_locals(**locals())
+    Langchain.check_locals(**locals())
 
     score_model = ""  # FIXME: For now, so user doesn't have to pass
     n_gpus = torch.cuda.device_count() if torch.cuda.is_available else 0
@@ -98,16 +98,20 @@ def run_cli( # for local function:
     from functools import partial
 
     # get score model
-    smodel, stokenizer, sdevice = get_score_model(
+    smodel, stokenizer, sdevice = Langchain.get_score_model(
         reward_type=True,
         **get_kwargs(
-            get_score_model, exclude_names=["reward_type"], **locals()
+            Langchain.get_score_model,
+            exclude_names=["reward_type"],
+            **locals()
         )
     )
 
-    model, tokenizer, device = get_model(
+    model, tokenizer, device = Langchain.get_model(
         reward_type=False,
-        **get_kwargs(get_model, exclude_names=["reward_type"], **locals())
+        **get_kwargs(
+            Langchain.get_model, exclude_names=["reward_type"], **locals()
+        )
     )
     model_dict = dict(
         base_model=base_model,
@@ -121,11 +125,11 @@ def run_cli( # for local function:
     model_state.update(model_dict)
     my_db_state = [None]
     fun = partial(
-        evaluate,
+        Langchain.evaluate,
        model_state,
        my_db_state,
        **get_kwargs(
-            evaluate,
+            Langchain.evaluate,
            exclude_names=["model_state", "my_db_state"]
            + eval_func_param_names,
            **locals()
 
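Both call sites funnel arguments through the `get_kwargs` helper imported from `utils`, which forwards only the subset of `locals()` that the target callable declares. The helper's source is not part of this diff; the sketch below is only the assumed behavior, with a hypothetical `get_score_model` signature for the usage example.

```python
import inspect


def get_kwargs(func, exclude_names=None, **kwargs):
    """Keep only the entries of **kwargs that match func's parameters,
    dropping anything listed in exclude_names (sketch of the assumed behavior)."""
    exclude_names = exclude_names or []
    params = inspect.signature(func).parameters
    return {
        k: v
        for k, v in kwargs.items()
        if k in params and k not in exclude_names
    }


def get_score_model(reward_type, base_model="", device="cpu"):
    return base_model, device


# Only base_model/device survive; reward_type is excluded, cli_loop is dropped.
print(get_kwargs(get_score_model, exclude_names=["reward_type"],
                 base_model="x", device="cpu", cli_loop=None, reward_type=True))
```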
@@ -7,7 +7,7 @@ import torch
 from matplotlib import pyplot as plt
 
 from evaluate_params import eval_func_param_names, eval_extra_columns
-from gen import get_context, get_score_model, get_model, evaluate, check_locals
+from gen import Langchain
 from prompter import Prompter
 from utils import clear_torch_cache, NullContext, get_kwargs
 
@@ -94,7 +94,7 @@ def run_eval( # for local function:
     force_langchain_evaluate=None,
     model_state_none=None,
 ):
-    check_locals(**locals())
+    Langchain.check_locals(**locals())
 
     if eval_prompts_only_num > 0:
         np.random.seed(eval_prompts_only_seed)
@@ -144,7 +144,7 @@ def run_eval( # for local function:
             ] = ""  # no input
             examplenew[
                 eval_func_param_names.index("context")
-            ] = get_context(chat_context, prompt_type)
+            ] = Langchain.get_context(chat_context, prompt_type)
             examples.append(examplenew)
             responses.append(output)
         else:
@@ -170,7 +170,7 @@ def run_eval( # for local function:
             ] = ""  # no input
             examplenew[
                 eval_func_param_names.index("context")
-            ] = get_context(chat_context, prompt_type)
+            ] = Langchain.get_context(chat_context, prompt_type)
             examples.append(examplenew)
             responses.append(output)
 
@@ -210,18 +210,22 @@ def run_eval( # for local function:
     from functools import partial
 
     # get score model
-    smodel, stokenizer, sdevice = get_score_model(
+    smodel, stokenizer, sdevice = Langchain.get_score_model(
         reward_type=True,
         **get_kwargs(
-            get_score_model, exclude_names=["reward_type"], **locals()
+            Langchain.get_score_model,
+            exclude_names=["reward_type"],
+            **locals()
         )
     )
 
     if not eval_as_output:
-        model, tokenizer, device = get_model(
+        model, tokenizer, device = Langchain.get_model(
             reward_type=False,
             **get_kwargs(
-                get_model, exclude_names=["reward_type"], **locals()
+                Langchain.get_model,
+                exclude_names=["reward_type"],
+                **locals()
             )
         )
         model_dict = dict(
@@ -236,11 +240,11 @@ def run_eval( # for local function:
         model_state.update(model_dict)
         my_db_state = [None]
         fun = partial(
-            evaluate,
+            Langchain.evaluate,
            model_state,
            my_db_state,
            **get_kwargs(
-                evaluate,
+                Langchain.evaluate,
                exclude_names=["model_state", "my_db_state"]
                + eval_func_param_names,
                **locals()
 
File diff suppressed because it is too large
@@ -34,7 +34,7 @@ from enums import (
     LangChainMode,
 )
 from evaluate_params import gen_hyper
-from gen import get_model, SEED
+from gen import Langchain, SEED
 from prompter import non_hf_types, PromptType, Prompter
 from utils import (
     wrapped_partial,
@@ -907,7 +907,7 @@ def get_llm(
         # model_name = 'h2oai/h2ogpt-oig-oasst1-512-6_9b'
         # model_name = 'h2oai/h2ogpt-oasst1-512-20b'
         inference_server = ""
-        model, tokenizer, device = get_model(
+        model, tokenizer, device = Langchain.get_model(
             load_8bit=True,
             base_model=model_name,
             inference_server=inference_server,
@@ -34,12 +34,11 @@ class H2OGPTSHARKModel(torch.nn.Module):
     def __init__(self):
         super().__init__()
         model_name = "h2ogpt_falcon_7b"
-        path_str = (
-            model_name + "_" + args.precision + "_" + args.device + ".vmfb"
+        extended_model_name = (
+            model_name + "_" + args.precision + "_" + args.device
         )
-        vmfb_path = Path(path_str)
-        path_str = model_name + "_" + args.precision + ".mlir"
-        mlir_path = Path(path_str)
+        vmfb_path = Path(extended_model_name + ".vmfb")
+        mlir_path = Path(model_name + "_" + args.precision + ".mlir")
         shark_module = None
 
         if not vmfb_path.exists():
@@ -50,7 +49,7 @@ class H2OGPTSHARKModel(torch.nn.Module):
                 # Downloading VMFB from shark_tank
                 print("Downloading vmfb from shark tank.")
                 download_public_file(
-                    "gs://shark_tank/langchain/" + path_str,
+                    "gs://shark_tank/langchain/" + str(vmfb_path),
                     vmfb_path.absolute(),
                     single_file=True,
                 )
@@ -61,11 +60,7 @@ class H2OGPTSHARKModel(torch.nn.Module):
             else:
                 # Downloading MLIR from shark_tank
                 download_public_file(
-                    "gs://shark_tank/langchain/"
-                    + model_name
-                    + "_"
-                    + args.precision
-                    + ".mlir",
+                    "gs://shark_tank/langchain/" + str(mlir_path),
                     mlir_path.absolute(),
                     single_file=True,
                 )
@@ -84,7 +79,7 @@ class H2OGPTSHARKModel(torch.nn.Module):
                 )
                 print(f"[DEBUG] generating vmfb.")
                 shark_module = _compile_module(
-                    shark_module, str(vmfb_path), []
+                    shark_module, extended_model_name, []
                 )
                 print("Saved newly generated vmfb.")
 
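Taken together, the hunks above keep one `extended_model_name` (`<model>_<precision>_<device>`), derive the `.vmfb` and `.mlir` paths from it, reuse a locally cached `.vmfb` when present, try a prebuilt one from `shark_tank` next, and only fall back to fetching the MLIR and compiling it. The sketch below is a condensed, hypothetical rendering of that decision flow: `try_download` stands in for `download_public_file`, and the compile step (done by `SharkInference`/`_compile_module` in the real code) is only indicated by a print.

```python
from pathlib import Path


def try_download(gs_url: str, destination: Path) -> bool:
    """Hypothetical stand-in for download_public_file(); returns True on success."""
    print(f"would fetch {gs_url} -> {destination}")
    return destination.exists()


def get_or_build_vmfb(model_name: str, precision: str, device: str) -> Path:
    # e.g. "h2ogpt_falcon_7b_fp16_vulkan"
    extended_model_name = f"{model_name}_{precision}_{device}"
    vmfb_path = Path(extended_model_name + ".vmfb")
    mlir_path = Path(f"{model_name}_{precision}.mlir")

    if vmfb_path.exists():
        return vmfb_path  # reuse the locally cached binary

    # Prefer a prebuilt vmfb from the public shark_tank bucket.
    if try_download("gs://shark_tank/langchain/" + str(vmfb_path), vmfb_path):
        return vmfb_path

    # Otherwise fetch the device-agnostic MLIR and compile it locally.
    try_download("gs://shark_tank/langchain/" + str(mlir_path), mlir_path)
    print(f"would compile {mlir_path} for {device} as {extended_model_name}.vmfb")
    return vmfb_path


get_or_build_vmfb("h2ogpt_falcon_7b", "fp16", "cpu")
```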
@@ -92,7 +87,7 @@ class H2OGPTSHARKModel(torch.nn.Module):
         if vmfb_path.exists():
             print("Compiled vmfb found. Loading it from: ", vmfb_path)
             shark_module = SharkInference(
-                None, device=global_device, mlir_dialect="linalg"
+                None, device=args.device, mlir_dialect="linalg"
             )
             shark_module.load_module(str(vmfb_path))
             print("Compiled vmfb loaded successfully.")
@@ -107,7 +102,7 @@ class H2OGPTSHARKModel(torch.nn.Module):
                 "forward",
                 (input_ids.to(device="cpu"), attention_mask.to(device="cpu")),
             )
-        ).to(device=global_device)
+        ).to(device=args.device)
         return result
 
 
@@ -123,14 +118,14 @@ def pad_or_truncate_inputs(
        num_add_token = max_padding_length - inp_shape[1]
        padded_input_ids = torch.cat(
            [
-                torch.tensor([[11] * num_add_token]).to(device=global_device),
+                torch.tensor([[11] * num_add_token]).to(device=args.device),
                input_ids,
            ],
            dim=1,
        )
        padded_attention_mask = torch.cat(
            [
-                torch.tensor([[0] * num_add_token]).to(device=global_device),
+                torch.tensor([[0] * num_add_token]).to(device=args.device),
                attention_mask,
            ],
            dim=1,
@@ -333,7 +328,7 @@ class H2OTextGenerationPipeline(TextGenerationPipeline):
                model_inputs["input_ids"], model_inputs["attention_mask"]
            )
 
-            if global_precision == "fp16":
+            if args.precision == "fp16":
                outputs = outputs.to(dtype=torch.float32)
            next_token_logits = outputs
 
@@ -460,7 +455,7 @@ class H2OTextGenerationPipeline(TextGenerationPipeline):
        if isinstance(eos_token_id, int):
            eos_token_id = [eos_token_id]
        self.eos_token_id_tensor = (
-            torch.tensor(eos_token_id).to(device=global_device)
+            torch.tensor(eos_token_id).to(device=args.device)
            if eos_token_id is not None
            else None
        )
@@ -538,7 +533,7 @@ class H2OTextGenerationPipeline(TextGenerationPipeline):
            self.input_ids = torch.cat(
                [
                    torch.tensor(self.truncated_input_ids)
-                    .to(device=global_device)
+                    .to(device=args.device)
                    .unsqueeze(dim=0),
                    self.input_ids,
                ],
@@ -617,22 +612,9 @@ class H2OTextGenerationPipeline(TextGenerationPipeline):
            **generate_kwargs,
        )
        out_b = generated_sequence.shape[0]
-        if self.framework == "pt":
-            generated_sequence = generated_sequence.reshape(
-                in_b, out_b // in_b, *generated_sequence.shape[1:]
-            )
-        elif self.framework == "tf":
-            from transformers import is_tf_available
-
-            if is_tf_available():
-                import tensorflow as tf
-
-                generated_sequence = tf.reshape(
-                    generated_sequence,
-                    (in_b, out_b // in_b, *generated_sequence.shape[1:]),
-                )
-            else:
-                raise ValueError("TF not avaialble.")
+        generated_sequence = generated_sequence.reshape(
+            in_b, out_b // in_b, *generated_sequence.shape[1:]
+        )
        return {
            "generated_sequence": generated_sequence,
            "input_ids": input_ids,
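The `global_device` / `global_precision` replacements above all follow one pattern: the pipeline reads `args.device` and `args.precision` at call time, and the web UI writes those fields before any model work happens (see the next file), which is what fixes "SHARK module not using user provided device and precision". A minimal sketch of that pattern, assuming `args` is the application's shared argparse-style namespace (the real object comes from SHARK-Studio's own parser):

```python
import argparse

import torch

# Stand-in for the application's shared argument namespace.
args = argparse.Namespace(device="cpu", precision="fp16")


def pad_tokens(input_ids: torch.Tensor, num_add_token: int) -> torch.Tensor:
    # Reads args.device at call time, so a UI change takes effect immediately.
    pad = torch.tensor([[11] * num_add_token]).to(device=args.device)
    return torch.cat([pad, input_ids], dim=1)


# The chat handler assigns the user's selection before building the model:
args.device = "cpu"      # e.g. "vulkan" when selected in the UI
args.precision = "fp16"
print(pad_tokens(torch.tensor([[1, 2, 3]]), num_add_token=2).shape)
```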
@@ -21,10 +21,8 @@ def user(message, history):
 
 
 sharkModel = 0
-sharded_model = 0
+h2ogpt_model = 0
 
 past_key_values = None
 
 # NOTE: Each `model_name` should have its own start message
 start_message = """
@@ -43,83 +41,76 @@ def create_prompt(history):
     return msg
 
 
-def chat(curr_system_message, history, model, device, precision):
+def chat(curr_system_message, history, device, precision):
     args.run_docuchat_web = True
-    global sharded_model
     global past_key_values
+    global h2ogpt_model
+    global h2ogpt_tokenizer
+    global model_state
+    global langchain
     global userpath_selector
 
-    model_name, model_path = list(map(str.strip, model.split("=>")))
-    print(f"In chat for {model_name}")
-    if h2ogpt_model == 0:
-        if "cuda" in device:
-            device = "cuda"
-        elif "sync" in device:
-            device = "cpu"
-        elif "task" in device:
-            device = "cpu"
-        elif "vulkan" in device:
-            device = "vulkan"
-        else:
-            print("unrecognized device")
 
+    # if h2ogpt_model == 0:
+    #     if "cuda" in device:
+    #         device = "cuda"
+    #     elif "sync" in device:
+    #         device = "cpu-sync"
+    #     elif "task" in device:
+    #         device = "cpu-task"
+    #     elif "vulkan" in device:
+    #         device = "vulkan"
+    #     else:
+    #         print("unrecognized device")
+    args.device = device
+    args.precision = precision
+
+    from apps.language_models.langchain.gen import Langchain
+
+    langchain = Langchain(device, precision)
+    h2ogpt_model, h2ogpt_tokenizer, _ = langchain.get_model(
+        load_8bit=True
+        if device == "cuda"
+        else False,  # load model in 4bit if device is cuda to save memory
+        load_gptq="",
+        use_safetensors=False,
+        infer_devices=True,
+        device=device,
+        base_model="h2oai/h2ogpt-gm-oasst1-en-2048-falcon-7b-v3",
+        inference_server="",
+        tokenizer_base_model="h2oai/h2ogpt-gm-oasst1-en-2048-falcon-7b-v3",
+        lora_weights="",
+        gpu_id=0,
+        reward_type=None,
+        local_files_only=False,
+        resume_download=True,
+        use_auth_token=False,
+        trust_remote_code=True,
+        offload_folder=None,
+        compile_model=False,
+        verbose=False,
+    )
+    model_state = dict(
+        model=h2ogpt_model,
+        tokenizer=h2ogpt_tokenizer,
+        device=device,
+        base_model="h2oai/h2ogpt-gm-oasst1-en-2048-falcon-7b-v3",
+        tokenizer_base_model="h2oai/h2ogpt-gm-oasst1-en-2048-falcon-7b-v3",
+        lora_weights="",
+        inference_server="",
+        prompt_type=None,
+        prompt_dict=None,
+    )
 
     # max_toks = 128 if model_name == "codegen" else 512
     # h2ogpt_model = UnshardedVicuna(
     #     model_name,
     #     hf_model_path=model_path,
     #     device=device,
     #     precision=precision,
     #     max_num_tokens=max_toks,
     # )
     prompt = create_prompt(history)
-    print(prompt)
+    # print("prompt = ", prompt)
 
     # for partial_text in h2ogpt_model.generate(prompt):
     #     history[-1][1] = partial_text
     #     yield history
-    model, tokenizer, device = gen.get_model(
-        load_half=True,
-        load_gptq="",
-        use_safetensors=False,
-        infer_devices=True,
-        base_model="h2oai/h2ogpt-gm-oasst1-en-2048-falcon-7b-v3",
-        inference_server="",
-        tokenizer_base_model="h2oai/h2ogpt-gm-oasst1-en-2048-falcon-7b-v3",
-        lora_weights="",
-        gpu_id=0,
-        reward_type=None,
-        local_files_only=False,
-        resume_download=True,
-        use_auth_token=False,
-        trust_remote_code=True,
-        offload_folder=None,
-        compile_model=False,
-        verbose=False,
-    )
-    print(tokenizer.model_max_length)
-    model_state = dict(
-        model=model,
-        tokenizer=tokenizer,
-        device=device,
-        base_model="h2oai/h2ogpt-gm-oasst1-en-2048-falcon-7b-v3",
-        tokenizer_base_model="h2oai/h2ogpt-gm-oasst1-en-2048-falcon-7b-v3",
-        lora_weights="",
-        inference_server="",
-        prompt_type=None,
-        prompt_dict=None,
-    )
-    output = gen.evaluate(
-        model_state,  # model_state
-        None,  # my_db_state
-        prompt,  # instruction
-        "",  # iinput
-        "",  # context
-        True,  # stream_output
-        "prompt_answer",  # prompt_type
-        {
+    output = langchain.evaluate(
+        model_state=model_state,
+        my_db_state=None,
+        instruction=prompt,
+        iinput="",
+        context="",
+        stream_output=True,
+        prompt_type="prompt_answer",
+        prompt_dict={
            "promptA": "",
            "promptB": "",
            "PreInstruct": "<|prompt|>",
@@ -135,27 +126,27 @@ def chat(curr_system_message, history, model, device, precision):
            "humanstr": "<|prompt|>",
            "botstr": "<|answer|>",
            "generates_leading_space": False,
-        },  # prompt_dict
-        0.1,  # temperature
-        0.75,  # top_p
-        40,  # top_k
-        1,  # num_beams
-        256,  # max_new_tokens
-        0,  # min_new_tokens
-        False,  # early_stopping
-        180,  # max_time
-        1.07,  # repetition_penalty
-        1,  # num_return_sequences
-        False,  # do_sample
-        True,  # chat
-        prompt,  # instruction_nochat
-        "",  # iinput_nochat
-        "UserData",  # langchain_mode
-        LangChainAction.QUERY.value,  # langchain_action
-        3,  # top_k_docs
-        True,  # chunk
-        512,  # chunk_size
-        [DocumentChoices.All_Relevant.name],  # document_choice
+        },
+        temperature=0.1,
+        top_p=0.75,
+        top_k=40,
+        num_beams=1,
+        max_new_tokens=256,
+        min_new_tokens=0,
+        early_stopping=False,
+        max_time=180,
+        repetition_penalty=1.07,
+        num_return_sequences=1,
+        do_sample=False,
+        chat=True,
+        instruction_nochat=prompt,
+        iinput_nochat="",
+        langchain_mode="UserData",
+        langchain_action=LangChainAction.QUERY.value,
+        top_k_docs=3,
+        chunk=True,
+        chunk_size=512,
+        document_choice=[DocumentChoices.All_Relevant.name],
        concurrency_count=1,
        memory_restriction_level=2,
        raise_generate_gpu_exceptions=False,
@@ -172,7 +163,7 @@ def chat(curr_system_message, history, model, device, precision):
        user_path=userpath_selector.value,
    )
    for partial_text in output:
-        history[-1][1] = partial_text
+        history[-1][1] = partial_text["response"]
        yield history
 
    return history
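With the class-based `evaluate`, each streamed chunk is a dict rather than a bare string, which is why the Gradio handler now reads `partial_text["response"]` and yields the updated history after every chunk so the web UI can render incremental results. The consumer-side sketch below uses a stand-in generator, not the real `Langchain.evaluate`:

```python
def fake_evaluate(prompt):
    # Stand-in for Langchain.evaluate: yields progressively longer responses.
    text = ""
    for word in ["SHARK", "runs", "h2oGPT", "locally."]:
        text += word + " "
        yield {"response": text.strip()}


history = [["What is SHARK?", ""]]
for partial_text in fake_evaluate(history[-1][0]):
    history[-1][1] = partial_text["response"]  # same access pattern as the diff
    # in the Gradio handler this would be `yield history` to update the chat box
print(history[-1][1])
```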