Fix multiple issues for Langchain

This commit fixes the following issues for Langchain:
1.) The Web UI not being able to fetch results.
2.) The model being reloaded for every query.
3.) The SHARK module not using the user-provided device and precision.
4.) The main Langchain code not being organized as a class; it is now wrapped in a Langchain class (sketched below).
5.) Miscellaneous issues.
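
The shape of the refactor behind items 2.) to 4.) is easier to see outside the diff hunks. The sketch below is hypothetical, not the actual gen.py: it only illustrates a Langchain class that keeps the user-provided device and precision on the instance and caches a loaded model so it is not reloaded on every query. The method names mirror the ones the diffs call (get_model, evaluate, check_locals), but the bodies and the load_base_model helper are placeholders; note that the CLI and eval paths call these methods on the class itself, so the real implementation may use class-level state or classmethods.

def load_base_model(base_model, device, precision, **kwargs):
    # Placeholder for the real loader (transformers / SHARK); the actual
    # commit keeps this logic inside the class methods.
    raise NotImplementedError


class Langchain:
    def __init__(self, device, precision):
        # Fix 3.): remember the user-provided device and precision instead of
        # relying on module-level globals.
        self.device = device
        self.precision = precision
        self._model_cache = {}  # Fix 2.): load each base model only once.

    def get_model(self, base_model, **kwargs):
        if base_model not in self._model_cache:
            model, tokenizer = load_base_model(
                base_model, device=self.device, precision=self.precision, **kwargs
            )
            self._model_cache[base_model] = (model, tokenizer, self.device)
        return self._model_cache[base_model]

    def evaluate(self, model_state, my_db_state, **gen_kwargs):
        # The real method streams partial results; the Web UI reads the
        # "response" field of each chunk (see the last hunk below).
        yield {"response": ""}
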
Vivek Khandelwal
2023-07-21 13:13:27 +00:00
parent a415f3f70e
commit d7092aafaa
6 changed files with 3268 additions and 3181 deletions

View File

@@ -2,7 +2,7 @@ import copy
import torch
from evaluate_params import eval_func_param_names
from gen import get_score_model, get_model, evaluate, check_locals
from gen import Langchain
from prompter import non_hf_types
from utils import clear_torch_cache, NullContext, get_kwargs
@@ -87,7 +87,7 @@ def run_cli( # for local function:
# unique to this function:
cli_loop=None,
):
check_locals(**locals())
Langchain.check_locals(**locals())
score_model = "" # FIXME: For now, so user doesn't have to pass
n_gpus = torch.cuda.device_count() if torch.cuda.is_available else 0
@@ -98,16 +98,20 @@ def run_cli( # for local function:
from functools import partial
# get score model
smodel, stokenizer, sdevice = get_score_model(
smodel, stokenizer, sdevice = Langchain.get_score_model(
reward_type=True,
**get_kwargs(
get_score_model, exclude_names=["reward_type"], **locals()
Langchain.get_score_model,
exclude_names=["reward_type"],
**locals()
)
)
model, tokenizer, device = get_model(
model, tokenizer, device = Langchain.get_model(
reward_type=False,
**get_kwargs(get_model, exclude_names=["reward_type"], **locals())
**get_kwargs(
Langchain.get_model, exclude_names=["reward_type"], **locals()
)
)
model_dict = dict(
base_model=base_model,
@@ -121,11 +125,11 @@ def run_cli( # for local function:
model_state.update(model_dict)
my_db_state = [None]
fun = partial(
evaluate,
Langchain.evaluate,
model_state,
my_db_state,
**get_kwargs(
evaluate,
Langchain.evaluate,
exclude_names=["model_state", "my_db_state"]
+ eval_func_param_names,
**locals()

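For reference, the get_kwargs helper imported from utils above is what lets these call sites pass **locals() while only forwarding the parameters that the target function actually accepts, minus the excluded names. The commit does not change that helper; the version below is a hypothetical reconstruction of its behaviour for illustration only, not the actual utils.get_kwargs.

import inspect


def get_kwargs(func, exclude_names=None, **kwargs):
    # Keep only the arguments that appear in func's signature and are not
    # explicitly excluded (e.g. "reward_type" in the hunks above).
    exclude_names = exclude_names or []
    accepted = inspect.signature(func).parameters
    return {
        k: v
        for k, v in kwargs.items()
        if k in accepted and k not in exclude_names
    }
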
View File

@@ -7,7 +7,7 @@ import torch
from matplotlib import pyplot as plt
from evaluate_params import eval_func_param_names, eval_extra_columns
from gen import get_context, get_score_model, get_model, evaluate, check_locals
from gen import Langchain
from prompter import Prompter
from utils import clear_torch_cache, NullContext, get_kwargs
@@ -94,7 +94,7 @@ def run_eval( # for local function:
force_langchain_evaluate=None,
model_state_none=None,
):
check_locals(**locals())
Langchain.check_locals(**locals())
if eval_prompts_only_num > 0:
np.random.seed(eval_prompts_only_seed)
@@ -144,7 +144,7 @@ def run_eval( # for local function:
] = "" # no input
examplenew[
eval_func_param_names.index("context")
] = get_context(chat_context, prompt_type)
] = Langchain.get_context(chat_context, prompt_type)
examples.append(examplenew)
responses.append(output)
else:
@@ -170,7 +170,7 @@ def run_eval( # for local function:
] = "" # no input
examplenew[
eval_func_param_names.index("context")
] = get_context(chat_context, prompt_type)
] = Langchain.get_context(chat_context, prompt_type)
examples.append(examplenew)
responses.append(output)
@@ -210,18 +210,22 @@ def run_eval( # for local function:
from functools import partial
# get score model
smodel, stokenizer, sdevice = get_score_model(
smodel, stokenizer, sdevice = Langchain.get_score_model(
reward_type=True,
**get_kwargs(
get_score_model, exclude_names=["reward_type"], **locals()
Langchain.get_score_model,
exclude_names=["reward_type"],
**locals()
)
)
if not eval_as_output:
model, tokenizer, device = get_model(
model, tokenizer, device = Langchain.get_model(
reward_type=False,
**get_kwargs(
get_model, exclude_names=["reward_type"], **locals()
Langchain.get_model,
exclude_names=["reward_type"],
**locals()
)
)
model_dict = dict(
@@ -236,11 +240,11 @@ def run_eval( # for local function:
model_state.update(model_dict)
my_db_state = [None]
fun = partial(
evaluate,
Langchain.evaluate,
model_state,
my_db_state,
**get_kwargs(
evaluate,
Langchain.evaluate,
exclude_names=["model_state", "my_db_state"]
+ eval_func_param_names,
**locals()

File diff suppressed because it is too large.

View File

@@ -34,7 +34,7 @@ from enums import (
LangChainMode,
)
from evaluate_params import gen_hyper
from gen import get_model, SEED
from gen import Langchain, SEED
from prompter import non_hf_types, PromptType, Prompter
from utils import (
wrapped_partial,
@@ -907,7 +907,7 @@ def get_llm(
# model_name = 'h2oai/h2ogpt-oig-oasst1-512-6_9b'
# model_name = 'h2oai/h2ogpt-oasst1-512-20b'
inference_server = ""
model, tokenizer, device = get_model(
model, tokenizer, device = Langchain.get_model(
load_8bit=True,
base_model=model_name,
inference_server=inference_server,

View File

@@ -34,12 +34,11 @@ class H2OGPTSHARKModel(torch.nn.Module):
def __init__(self):
super().__init__()
model_name = "h2ogpt_falcon_7b"
path_str = (
model_name + "_" + args.precision + "_" + args.device + ".vmfb"
extended_model_name = (
model_name + "_" + args.precision + "_" + args.device
)
vmfb_path = Path(path_str)
path_str = model_name + "_" + args.precision + ".mlir"
mlir_path = Path(path_str)
vmfb_path = Path(extended_model_name + ".vmfb")
mlir_path = Path(model_name + "_" + args.precision + ".mlir")
shark_module = None
if not vmfb_path.exists():
@@ -50,7 +49,7 @@ class H2OGPTSHARKModel(torch.nn.Module):
# Downloading VMFB from shark_tank
print("Downloading vmfb from shark tank.")
download_public_file(
"gs://shark_tank/langchain/" + path_str,
"gs://shark_tank/langchain/" + str(vmfb_path),
vmfb_path.absolute(),
single_file=True,
)
@@ -61,11 +60,7 @@ class H2OGPTSHARKModel(torch.nn.Module):
else:
# Downloading MLIR from shark_tank
download_public_file(
"gs://shark_tank/langchain/"
+ model_name
+ "_"
+ args.precision
+ ".mlir",
"gs://shark_tank/langchain/" + str(mlir_path),
mlir_path.absolute(),
single_file=True,
)
@@ -84,7 +79,7 @@ class H2OGPTSHARKModel(torch.nn.Module):
)
print(f"[DEBUG] generating vmfb.")
shark_module = _compile_module(
shark_module, str(vmfb_path), []
shark_module, extended_model_name, []
)
print("Saved newly generated vmfb.")
@@ -92,7 +87,7 @@ class H2OGPTSHARKModel(torch.nn.Module):
if vmfb_path.exists():
print("Compiled vmfb found. Loading it from: ", vmfb_path)
shark_module = SharkInference(
None, device=global_device, mlir_dialect="linalg"
None, device=args.device, mlir_dialect="linalg"
)
shark_module.load_module(str(vmfb_path))
print("Compiled vmfb loaded successfully.")
@@ -107,7 +102,7 @@ class H2OGPTSHARKModel(torch.nn.Module):
"forward",
(input_ids.to(device="cpu"), attention_mask.to(device="cpu")),
)
).to(device=global_device)
).to(device=args.device)
return result
@@ -123,14 +118,14 @@ def pad_or_truncate_inputs(
num_add_token = max_padding_length - inp_shape[1]
padded_input_ids = torch.cat(
[
torch.tensor([[11] * num_add_token]).to(device=global_device),
torch.tensor([[11] * num_add_token]).to(device=args.device),
input_ids,
],
dim=1,
)
padded_attention_mask = torch.cat(
[
torch.tensor([[0] * num_add_token]).to(device=global_device),
torch.tensor([[0] * num_add_token]).to(device=args.device),
attention_mask,
],
dim=1,
@@ -333,7 +328,7 @@ class H2OTextGenerationPipeline(TextGenerationPipeline):
model_inputs["input_ids"], model_inputs["attention_mask"]
)
if global_precision == "fp16":
if args.precision == "fp16":
outputs = outputs.to(dtype=torch.float32)
next_token_logits = outputs
@@ -460,7 +455,7 @@ class H2OTextGenerationPipeline(TextGenerationPipeline):
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
self.eos_token_id_tensor = (
torch.tensor(eos_token_id).to(device=global_device)
torch.tensor(eos_token_id).to(device=args.device)
if eos_token_id is not None
else None
)
@@ -538,7 +533,7 @@ class H2OTextGenerationPipeline(TextGenerationPipeline):
self.input_ids = torch.cat(
[
torch.tensor(self.truncated_input_ids)
.to(device=global_device)
.to(device=args.device)
.unsqueeze(dim=0),
self.input_ids,
],
@@ -617,22 +612,9 @@ class H2OTextGenerationPipeline(TextGenerationPipeline):
**generate_kwargs,
)
out_b = generated_sequence.shape[0]
if self.framework == "pt":
generated_sequence = generated_sequence.reshape(
in_b, out_b // in_b, *generated_sequence.shape[1:]
)
elif self.framework == "tf":
from transformers import is_tf_available
if is_tf_available():
import tensorflow as tf
generated_sequence = tf.reshape(
generated_sequence,
(in_b, out_b // in_b, *generated_sequence.shape[1:]),
)
else:
raise ValueError("TF not avaialble.")
generated_sequence = generated_sequence.reshape(
in_b, out_b // in_b, *generated_sequence.shape[1:]
)
return {
"generated_sequence": generated_sequence,
"input_ids": input_ids,

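Taken together, the H2OGPTSHARKModel hunks above change how the artifact paths are built and make the inference device come from args rather than a global. The condensed flow below is a hypothetical sketch for illustration, assuming SharkInference, download_public_file and _compile_module are in scope exactly as they are used in the diff; it is not the module's actual code.

from pathlib import Path


def load_or_download_vmfb(model_name, precision, device):
    # Naming convention from the commit: the compiled vmfb is keyed by model,
    # precision and device; the MLIR source only by model and precision.
    extended_model_name = model_name + "_" + precision + "_" + device
    vmfb_path = Path(extended_model_name + ".vmfb")

    if not vmfb_path.exists():
        # Fetch a prebuilt vmfb from shark_tank; the diff additionally falls
        # back to downloading "<model>_<precision>.mlir" and compiling it via
        # _compile_module(shark_module, extended_model_name, []) (elided here).
        download_public_file(
            "gs://shark_tank/langchain/" + str(vmfb_path),
            vmfb_path.absolute(),
            single_file=True,
        )

    # Fix 3.): the device is the user-provided args.device, not a global.
    shark_module = SharkInference(None, device=device, mlir_dialect="linalg")
    shark_module.load_module(str(vmfb_path))
    return shark_module
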
View File

@@ -21,10 +21,8 @@ def user(message, history):
sharkModel = 0
sharded_model = 0
h2ogpt_model = 0
past_key_values = None
# NOTE: Each `model_name` should have its own start message
start_message = """
@@ -43,83 +41,76 @@ def create_prompt(history):
return msg
def chat(curr_system_message, history, model, device, precision):
def chat(curr_system_message, history, device, precision):
args.run_docuchat_web = True
global sharded_model
global past_key_values
global h2ogpt_model
global h2ogpt_tokenizer
global model_state
global langchain
global userpath_selector
model_name, model_path = list(map(str.strip, model.split("=>")))
print(f"In chat for {model_name}")
if h2ogpt_model == 0:
if "cuda" in device:
device = "cuda"
elif "sync" in device:
device = "cpu"
elif "task" in device:
device = "cpu"
elif "vulkan" in device:
device = "vulkan"
else:
print("unrecognized device")
# if h2ogpt_model == 0:
# if "cuda" in device:
# device = "cuda"
# elif "sync" in device:
# device = "cpu-sync"
# elif "task" in device:
# device = "cpu-task"
# elif "vulkan" in device:
# device = "vulkan"
# else:
# print("unrecognized device")
args.device = device
args.precision = precision
from apps.language_models.langchain.gen import Langchain
langchain = Langchain(device, precision)
h2ogpt_model, h2ogpt_tokenizer, _ = langchain.get_model(
load_8bit=True
if device == "cuda"
else False, # load model in 4bit if device is cuda to save memory
load_gptq="",
use_safetensors=False,
infer_devices=True,
device=device,
base_model="h2oai/h2ogpt-gm-oasst1-en-2048-falcon-7b-v3",
inference_server="",
tokenizer_base_model="h2oai/h2ogpt-gm-oasst1-en-2048-falcon-7b-v3",
lora_weights="",
gpu_id=0,
reward_type=None,
local_files_only=False,
resume_download=True,
use_auth_token=False,
trust_remote_code=True,
offload_folder=None,
compile_model=False,
verbose=False,
)
model_state = dict(
model=h2ogpt_model,
tokenizer=h2ogpt_tokenizer,
device=device,
base_model="h2oai/h2ogpt-gm-oasst1-en-2048-falcon-7b-v3",
tokenizer_base_model="h2oai/h2ogpt-gm-oasst1-en-2048-falcon-7b-v3",
lora_weights="",
inference_server="",
prompt_type=None,
prompt_dict=None,
)
# max_toks = 128 if model_name == "codegen" else 512
# h2ogpt_model = UnshardedVicuna(
# model_name,
# hf_model_path=model_path,
# device=device,
# precision=precision,
# max_num_tokens=max_toks,
# )
prompt = create_prompt(history)
print(prompt)
# print("prompt = ", prompt)
# for partial_text in h2ogpt_model.generate(prompt):
# history[-1][1] = partial_text
# yield history
model, tokenizer, device = gen.get_model(
load_half=True,
load_gptq="",
use_safetensors=False,
infer_devices=True,
base_model="h2oai/h2ogpt-gm-oasst1-en-2048-falcon-7b-v3",
inference_server="",
tokenizer_base_model="h2oai/h2ogpt-gm-oasst1-en-2048-falcon-7b-v3",
lora_weights="",
gpu_id=0,
reward_type=None,
local_files_only=False,
resume_download=True,
use_auth_token=False,
trust_remote_code=True,
offload_folder=None,
compile_model=False,
verbose=False,
)
print(tokenizer.model_max_length)
model_state = dict(
model=model,
tokenizer=tokenizer,
device=device,
base_model="h2oai/h2ogpt-gm-oasst1-en-2048-falcon-7b-v3",
tokenizer_base_model="h2oai/h2ogpt-gm-oasst1-en-2048-falcon-7b-v3",
lora_weights="",
inference_server="",
prompt_type=None,
prompt_dict=None,
)
output = gen.evaluate(
model_state, # model_state
None, # my_db_state
prompt, # instruction
"", # iinput
"", # context
True, # stream_output
"prompt_answer", # prompt_type
{
output = langchain.evaluate(
model_state=model_state,
my_db_state=None,
instruction=prompt,
iinput="",
context="",
stream_output=True,
prompt_type="prompt_answer",
prompt_dict={
"promptA": "",
"promptB": "",
"PreInstruct": "<|prompt|>",
@@ -135,27 +126,27 @@ def chat(curr_system_message, history, model, device, precision):
"humanstr": "<|prompt|>",
"botstr": "<|answer|>",
"generates_leading_space": False,
}, # prompt_dict
0.1, # temperature
0.75, # top_p
40, # top_k
1, # num_beams
256, # max_new_tokens
0, # min_new_tokens
False, # early_stopping
180, # max_time
1.07, # repetition_penalty
1, # num_return_sequences
False, # do_sample
True, # chat
prompt, # instruction_nochat
"", # iinput_nochat
"UserData", # langchain_mode
LangChainAction.QUERY.value, # langchain_action
3, # top_k_docs
True, # chunk
512, # chunk_size
[DocumentChoices.All_Relevant.name], # document_choice
},
temperature=0.1,
top_p=0.75,
top_k=40,
num_beams=1,
max_new_tokens=256,
min_new_tokens=0,
early_stopping=False,
max_time=180,
repetition_penalty=1.07,
num_return_sequences=1,
do_sample=False,
chat=True,
instruction_nochat=prompt,
iinput_nochat="",
langchain_mode="UserData",
langchain_action=LangChainAction.QUERY.value,
top_k_docs=3,
chunk=True,
chunk_size=512,
document_choice=[DocumentChoices.All_Relevant.name],
concurrency_count=1,
memory_restriction_level=2,
raise_generate_gpu_exceptions=False,
@@ -172,7 +163,7 @@ def chat(curr_system_message, history, model, device, precision):
user_path=userpath_selector.value,
)
for partial_text in output:
history[-1][1] = partial_text
history[-1][1] = partial_text["response"]
yield history
return history
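
Item 1.) in the commit message comes down to the contract in this last hunk: evaluate now streams dict chunks rather than bare strings, so the Web UI must read the "response" field. A minimal, hypothetical consumer of that generator, illustrating the same pattern as the loop above:

def stream_to_history(output, history):
    # `output` is the generator returned by langchain.evaluate(...) above.
    # Each chunk is assumed to be a dict with the partial text under
    # "response"; overwriting history[-1][1] keeps the chat box updating.
    for chunk in output:
        history[-1][1] = chunk["response"]
        yield history
    return history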