Mirror of https://github.com/nod-ai/AMD-SHARK-Studio.git (synced 2026-02-19 11:56:43 -05:00)
Add support for different compilation paths for DocuChat (#1665)
.gitignore (4 changed lines)

@@ -189,3 +189,7 @@ apps/stable_diffusion/web/models/
 
 # Stencil annotators.
 stencil_annotator/
+
+# For DocuChat
+apps/language_models/langchain/user_path/
+db_dir_UserData
@@ -6,10 +6,12 @@
 ```shell
 pip install -r apps/language_models/langchain/langchain_requirements.txt
 ```
-2.) Create a folder named `user_path` and all your docs into that folder.
+2.) Create a folder named `user_path` in `apps/language_models/langchain/` directory.
+
+Now, you are ready to use the model.
 
 3.) To run the model, run the following command:
 ```shell
-python apps/language_models/langchain/gen.py --user_path=<path_to_user_path_directory> --cli=True
+python apps/language_models/langchain/gen.py --cli=True
 ```
 
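With the new default `user_path` (see the `main()` signature change below), the CLI can be launched without passing a docs folder explicitly. A quick usage sketch; the second command, overriding the default with `--user_path`, is an assumption based on the parameter that `gen.py` still accepts:

```shell
# Uses the bundled default folder apps/language_models/langchain/user_path/
python apps/language_models/langchain/gen.py --cli=True

# Assumed to still work for a custom docs folder, as in the old instructions
python apps/language_models/langchain/gen.py --user_path=/path/to/your_docs --cli=True
```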
@@ -177,7 +177,7 @@ def main(
         LangChainAction.SUMMARIZE_MAP.value,
     ],
     document_choice: list = [DocumentChoices.All_Relevant.name],
-    user_path: str = None,
+    user_path: str = "apps/language_models/langchain/user_path/",
    detect_user_path_changes_every_query: bool = False,
    load_db_if_exists: bool = True,
    keep_sources_in_context: bool = False,
@@ -1,4 +1,5 @@
 import os
+from apps.stable_diffusion.src.utils.utils import _compile_module
 
 from transformers import TextGenerationPipeline
 from transformers.pipelines.text_generation import ReturnType
@@ -19,34 +20,79 @@ import gc
 from pathlib import Path
 from shark.shark_inference import SharkInference
 from shark.shark_downloader import download_public_file
 from apps.stable_diffusion.src import args
 
 global_device = "cuda"
 global_precision = "fp16"
 
+if not args.run_docuchat_web:
+    args.device = global_device
+    args.precision = global_precision
+
 
 class H2OGPTSHARKModel(torch.nn.Module):
     def __init__(self):
         super().__init__()
         model_name = "h2ogpt_falcon_7b"
         path_str = (
-            model_name + "_" + global_precision + "_" + global_device + ".vmfb"
+            model_name + "_" + args.precision + "_" + args.device + ".vmfb"
         )
         vmfb_path = Path(path_str)
+        path_str = model_name + "_" + args.precision + ".mlir"
+        mlir_path = Path(path_str)
+        shark_module = None
 
         if not vmfb_path.exists():
-            # Downloading VMFB from shark_tank
-            print("Downloading vmfb from shark tank.")
-            download_public_file(
-                "gs://shark_tank/langchain/" + path_str,
-                vmfb_path.absolute(),
-                single_file=True,
-            )
-        print("Compiled vmfb found. Loading it from: ", vmfb_path)
-        shark_module = SharkInference(
-            None, device=global_device, mlir_dialect="linalg"
-        )
-        shark_module.load_module(vmfb_path)
-        print("Compiled vmfb loaded successfully.")
+            if args.device == "cuda" and args.precision in ["fp16", "fp32"]:
+                # Downloading VMFB from shark_tank
+                print("Downloading vmfb from shark tank.")
+                download_public_file(
+                    "gs://shark_tank/langchain/" + path_str,
+                    vmfb_path.absolute(),
+                    single_file=True,
+                )
+            else:
+                if mlir_path.exists():
+                    with open(mlir_path, "rb") as f:
+                        bytecode = f.read()
+                else:
+                    # Downloading MLIR from shark_tank
+                    download_public_file(
+                        "gs://shark_tank/langchain/"
+                        + model_name
+                        + "_"
+                        + args.precision
+                        + ".mlir",
+                        mlir_path.absolute(),
+                        single_file=True,
+                    )
+                    if mlir_path.exists():
+                        with open(mlir_path, "rb") as f:
+                            bytecode = f.read()
+                    else:
+                        raise ValueError(
+                            f"MLIR not found at {mlir_path.absolute()}"
+                            " after downloading! Please check path and try again"
+                        )
+                shark_module = SharkInference(
+                    mlir_module=bytecode,
+                    device=args.device,
+                    mlir_dialect="linalg",
+                )
+                print(f"[DEBUG] generating vmfb.")
+                shark_module = _compile_module(shark_module, vmfb_path, [])
+                print("Saved newly generated vmfb.")
+
+        if shark_module is None:
+            if vmfb_path.exists():
+                print("Compiled vmfb found. Loading it from: ", vmfb_path)
+                shark_module = SharkInference(
+                    None, device=global_device, mlir_dialect="linalg"
+                )
+                shark_module.load_module(vmfb_path)
+                print("Compiled vmfb loaded successfully.")
+            else:
+                raise ValueError("Unable to download/generate a vmfb.")
 
         self.model = shark_module
 
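Condensed, the new flow in `H2OGPTSHARKModel` is: if a compiled `.vmfb` is not already on disk, download a prebuilt one from shark_tank only for CUDA with fp16/fp32; for every other device/precision combination, fetch (or reuse) the `.mlir` and compile a `.vmfb` locally; finally, load whichever `.vmfb` is now available. A minimal sketch of that decision tree using the same calls as the diff (the wrapper name `get_h2ogpt_module` is hypothetical and error handling is trimmed):

```python
from pathlib import Path

from apps.stable_diffusion.src import args
from apps.stable_diffusion.src.utils.utils import _compile_module
from shark.shark_downloader import download_public_file
from shark.shark_inference import SharkInference


def get_h2ogpt_module(model_name="h2ogpt_falcon_7b"):
    vmfb_path = Path(f"{model_name}_{args.precision}_{args.device}.vmfb")
    mlir_path = Path(f"{model_name}_{args.precision}.mlir")
    shark_module = None

    if not vmfb_path.exists():
        if args.device == "cuda" and args.precision in ["fp16", "fp32"]:
            # Path 1: a prebuilt vmfb is published on shark_tank for these configs.
            download_public_file(
                "gs://shark_tank/langchain/" + vmfb_path.name,
                vmfb_path.absolute(),
                single_file=True,
            )
        else:
            # Path 2: fetch the MLIR (if needed) and compile a vmfb locally.
            if not mlir_path.exists():
                download_public_file(
                    "gs://shark_tank/langchain/" + mlir_path.name,
                    mlir_path.absolute(),
                    single_file=True,
                )
            shark_module = SharkInference(
                mlir_module=mlir_path.read_bytes(),
                device=args.device,
                mlir_dialect="linalg",
            )
            shark_module = _compile_module(shark_module, vmfb_path, [])

    if shark_module is None:
        # Either the vmfb was already on disk or it was just downloaded.
        shark_module = SharkInference(None, device=args.device, mlir_dialect="linalg")
        shark_module.load_module(vmfb_path)
    return shark_module
```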
@@ -648,6 +648,16 @@ p.add_argument(
     help="Op to be optimized, options are matmul, bmm, conv and all.",
 )
 
+##############################################################################
+# DocuChat Flags
+##############################################################################
+
+p.add_argument(
+    "--run_docuchat_web",
+    default=False,
+    action=argparse.BooleanOptionalAction,
+    help="Specifies whether the docuchat's web version is running or not.",
+)
 
 args, unknown = p.parse_known_args()
 if args.import_debug:
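For context, `argparse.BooleanOptionalAction` (available since Python 3.9) registers a negated `--no-` form of the flag automatically. A small standalone demo, separate from the repo's actual parser:

```python
import argparse

# Minimal parser that mirrors the new flag definition (demo only).
p = argparse.ArgumentParser()
p.add_argument(
    "--run_docuchat_web",
    default=False,
    action=argparse.BooleanOptionalAction,  # also creates --no-run_docuchat_web
    help="Specifies whether the docuchat's web version is running or not.",
)

print(p.parse_args([]).run_docuchat_web)                         # False (default)
print(p.parse_args(["--run_docuchat_web"]).run_docuchat_web)     # True
print(p.parse_args(["--no-run_docuchat_web"]).run_docuchat_web)  # False
```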
@@ -12,6 +12,7 @@ from apps.language_models.langchain.enums import (
     LangChainAction,
 )
 import apps.language_models.langchain.gen as gen
+from apps.stable_diffusion.src import args
 
 
 def user(message, history):
@@ -80,6 +81,7 @@ def create_prompt(model_name, history):
 
 
 def chat(curr_system_message, history, model, device, precision):
+    args.run_docuchat_web = True
     global sharded_model
     global past_key_values
     global h2ogpt_model
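Together with the module-level block added to the pipeline file, the flag acts as a simple guard: `chat()` flips `args.run_docuchat_web` to True on the web path, so the hardcoded `global_device`/`global_precision` defaults only apply when DocuChat runs from the CLI. A standalone sketch of that interaction (the `args` namespace and `apply_cli_defaults` helper are stand-ins for illustration):

```python
from types import SimpleNamespace

# Stand-in for apps.stable_diffusion.src.args.
args = SimpleNamespace(run_docuchat_web=False, device="vulkan", precision="fp32")

global_device = "cuda"
global_precision = "fp16"


def apply_cli_defaults():
    # Mirrors the module-level override added to the pipeline file.
    if not args.run_docuchat_web:
        args.device = global_device
        args.precision = global_precision


# CLI path: the flag keeps its default, so the globals win.
apply_cli_defaults()
print(args.device, args.precision)  # cuda fp16

# Web path: chat() sets the flag first, so the UI-selected values survive.
args.run_docuchat_web = True
args.device, args.precision = "vulkan", "fp32"
apply_cli_defaults()
print(args.device, args.precision)  # vulkan fp32
```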