Add support for different compilation paths for DocuChat (#1665)

Author: Vivek Khandelwal
Date: 2023-07-18 22:19:44 +05:30
committed by GitHub
parent 11f62d7fac
commit b0136593df
6 changed files with 81 additions and 17 deletions

.gitignore vendored

@@ -189,3 +189,7 @@ apps/stable_diffusion/web/models/
 # Stencil annotators.
 stencil_annotator/
+
+# For DocuChat
+apps/language_models/langchain/user_path/
+db_dir_UserData


@@ -6,10 +6,12 @@
 ```shell
 pip install -r apps/language_models/langchain/langchain_requirements.txt
 ```
-2.) Create a folder named `user_path` and put all your docs into that folder.
+2.) Create a folder named `user_path` in the `apps/language_models/langchain/` directory.
+Now, you are ready to use the model.
 3.) To run the model, run the following command:
 ```shell
-python apps/language_models/langchain/gen.py --user_path=<path_to_user_path_directory> --cli=True
+python apps/language_models/langchain/gen.py --cli=True
 ```


@@ -177,7 +177,7 @@ def main(
         LangChainAction.SUMMARIZE_MAP.value,
     ],
     document_choice: list = [DocumentChoices.All_Relevant.name],
-    user_path: str = None,
+    user_path: str = "apps/language_models/langchain/user_path/",
    detect_user_path_changes_every_query: bool = False,
    load_db_if_exists: bool = True,
    keep_sources_in_context: bool = False,
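With the default now pointing at the repo-local folder, `--user_path` becomes optional: the README command above works as-is, and a programmatic call can lean on the same fallback. A minimal sketch, assuming the CLI flags (including `--cli`) map directly onto `main`'s keyword arguments, as the `--cli=True` form suggests:

```python
# Illustrative only: relies on the new user_path default shown above.
import apps.language_models.langchain.gen as gen

# Documents are picked up from "apps/language_models/langchain/user_path/".
gen.main(cli=True)

# An explicit path still overrides the default.
gen.main(cli=True, user_path="/data/my_docs/")
```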


@@ -1,4 +1,5 @@
 import os
+from apps.stable_diffusion.src.utils.utils import _compile_module
 from transformers import TextGenerationPipeline
 from transformers.pipelines.text_generation import ReturnType
@@ -19,34 +20,79 @@ import gc
 from pathlib import Path
 from shark.shark_inference import SharkInference
 from shark.shark_downloader import download_public_file
+from apps.stable_diffusion.src import args
 global_device = "cuda"
 global_precision = "fp16"
+if not args.run_docuchat_web:
+    args.device = global_device
+    args.precision = global_precision
 class H2OGPTSHARKModel(torch.nn.Module):
     def __init__(self):
         super().__init__()
         model_name = "h2ogpt_falcon_7b"
         path_str = (
-            model_name + "_" + global_precision + "_" + global_device + ".vmfb"
+            model_name + "_" + args.precision + "_" + args.device + ".vmfb"
         )
         vmfb_path = Path(path_str)
+        path_str = model_name + "_" + args.precision + ".mlir"
+        mlir_path = Path(path_str)
+        shark_module = None
         if not vmfb_path.exists():
-            # Downloading VMFB from shark_tank
-            print("Downloading vmfb from shark tank.")
-            download_public_file(
-                "gs://shark_tank/langchain/" + path_str,
-                vmfb_path.absolute(),
-                single_file=True,
-            )
-        print("Compiled vmfb found. Loading it from: ", vmfb_path)
-        shark_module = SharkInference(
-            None, device=global_device, mlir_dialect="linalg"
-        )
-        shark_module.load_module(vmfb_path)
-        print("Compiled vmfb loaded successfully.")
+            if args.device == "cuda" and args.precision in ["fp16", "fp32"]:
+                # Downloading VMFB from shark_tank
+                print("Downloading vmfb from shark tank.")
+                download_public_file(
+                    "gs://shark_tank/langchain/" + path_str,
+                    vmfb_path.absolute(),
+                    single_file=True,
+                )
+            else:
+                if mlir_path.exists():
+                    with open(mlir_path, "rb") as f:
+                        bytecode = f.read()
+                else:
+                    # Downloading MLIR from shark_tank
+                    download_public_file(
+                        "gs://shark_tank/langchain/"
+                        + model_name
+                        + "_"
+                        + args.precision
+                        + ".mlir",
+                        mlir_path.absolute(),
+                        single_file=True,
+                    )
+                    if mlir_path.exists():
+                        with open(mlir_path, "rb") as f:
+                            bytecode = f.read()
+                    else:
+                        raise ValueError(
+                            f"MLIR not found at {mlir_path.absolute()}"
+                            " after downloading! Please check path and try again"
+                        )
+                shark_module = SharkInference(
+                    mlir_module=bytecode,
+                    device=args.device,
+                    mlir_dialect="linalg",
+                )
+                print(f"[DEBUG] generating vmfb.")
+                shark_module = _compile_module(shark_module, vmfb_path, [])
+                print("Saved newly generated vmfb.")
+        if shark_module is None:
+            if vmfb_path.exists():
+                print("Compiled vmfb found. Loading it from: ", vmfb_path)
+                shark_module = SharkInference(
+                    None, device=global_device, mlir_dialect="linalg"
+                )
+                shark_module.load_module(vmfb_path)
+                print("Compiled vmfb loaded successfully.")
+            else:
+                raise ValueError("Unable to download/generate a vmfb.")
         self.model = shark_module
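Taken together, the hunk above gives the pipeline two ways to obtain a vmfb: download a prebuilt one from shark_tank when the device/precision pair has a published artifact, otherwise fetch (or reuse) the MLIR and compile it locally. The gist, condensed into one hypothetical helper for readability (a sketch of the flow, not the literal source; the error-handling branches from the diff are omitted):

```python
from pathlib import Path

from apps.stable_diffusion.src import args
from apps.stable_diffusion.src.utils.utils import _compile_module
from shark.shark_downloader import download_public_file
from shark.shark_inference import SharkInference


def load_or_build_vmfb(model_name="h2ogpt_falcon_7b"):
    vmfb_path = Path(f"{model_name}_{args.precision}_{args.device}.vmfb")
    mlir_path = Path(f"{model_name}_{args.precision}.mlir")
    shark_module = None

    if not vmfb_path.exists():
        if args.device == "cuda" and args.precision in ["fp16", "fp32"]:
            # Path 1: a prebuilt vmfb is published on shark_tank; download it.
            download_public_file(
                "gs://shark_tank/langchain/" + vmfb_path.name,
                vmfb_path.absolute(),
                single_file=True,
            )
        else:
            # Path 2: fetch (or reuse) the MLIR and compile a vmfb locally.
            if not mlir_path.exists():
                download_public_file(
                    "gs://shark_tank/langchain/" + mlir_path.name,
                    mlir_path.absolute(),
                    single_file=True,
                )
            shark_module = SharkInference(
                mlir_module=mlir_path.read_bytes(),
                device=args.device,
                mlir_dialect="linalg",
            )
            shark_module = _compile_module(shark_module, vmfb_path, [])

    if shark_module is None:
        # Either the vmfb already existed or it was just downloaded above.
        # (The diff pins device=global_device on this load path.)
        shark_module = SharkInference(None, device=args.device, mlir_dialect="linalg")
        shark_module.load_module(vmfb_path)
    return shark_module
```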


@@ -648,6 +648,16 @@ p.add_argument(
     help="Op to be optimized, options are matmul, bmm, conv and all.",
 )
+##############################################################################
+# DocuChat Flags
+##############################################################################
+p.add_argument(
+    "--run_docuchat_web",
+    default=False,
+    action=argparse.BooleanOptionalAction,
+    help="Specifies whether the docuchat's web version is running or not.",
+)
 args, unknown = p.parse_known_args()
 if args.import_debug:
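Because the flag uses `argparse.BooleanOptionalAction` (Python 3.9+), argparse also generates a negated form automatically. A small self-contained sketch of how the new flag parses:

```python
import argparse

p = argparse.ArgumentParser()
p.add_argument(
    "--run_docuchat_web",
    default=False,
    action=argparse.BooleanOptionalAction,
    help="Specifies whether the docuchat's web version is running or not.",
)

print(p.parse_args([]).run_docuchat_web)                         # False (default)
print(p.parse_args(["--run_docuchat_web"]).run_docuchat_web)     # True
print(p.parse_args(["--no-run_docuchat_web"]).run_docuchat_web)  # False
```

And since the real parser calls `parse_known_args()`, unrecognized flags on the command line are tolerated rather than treated as errors.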


@@ -12,6 +12,7 @@ from apps.language_models.langchain.enums import (
     LangChainAction,
 )
 import apps.language_models.langchain.gen as gen
+from apps.stable_diffusion.src import args
 def user(message, history):
@@ -80,6 +81,7 @@ def create_prompt(model_name, history):
 def chat(curr_system_message, history, model, device, precision):
+    args.run_docuchat_web = True
     global sharded_model
     global past_key_values
     global h2ogpt_model
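This is the other half of the handshake with the pipeline change above: the web entry point marks itself before the model is built, so the pipeline skips the CLI override and uses whatever device/precision the web app has put into `args`. Roughly (the body of `chat()` beyond the flag is elided here):

```python
from apps.stable_diffusion.src import args

def chat(curr_system_message, history, model, device, precision):
    # Signal the shared args namespace that DocuChat is running via the web UI,
    # so the pipeline does not apply the "args.device = global_device" /
    # "args.precision = global_precision" defaults shown in the earlier hunk.
    args.run_docuchat_web = True
    ...
```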