mirror of
https://github.com/nod-ai/SHARK-Studio.git
synced 2026-01-10 06:17:55 -05:00
Add vector database and add support on the web UI (#1699)
This commit is contained in:
@@ -1,5 +1,4 @@
|
||||
import os
|
||||
import fire
|
||||
|
||||
from gpt_langchain import (
|
||||
path_to_docs,
|
||||
@@ -202,7 +201,3 @@ def make_db_main(
|
||||
if verbose:
|
||||
print("DONE", flush=True)
|
||||
return db, collection_name
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(make_db_main)
|
||||
|
||||
@@ -115,6 +115,7 @@ if __name__ == "__main__":
|
||||
txt2img_sendto_inpaint,
|
||||
txt2img_sendto_outpaint,
|
||||
txt2img_sendto_upscaler,
|
||||
h2ogpt_upload,
|
||||
h2ogpt_web,
|
||||
img2img_web,
|
||||
img2img_custom_model,
|
||||
@@ -247,7 +248,9 @@ if __name__ == "__main__":
|
||||
upscaler_status,
|
||||
]
|
||||
)
|
||||
with gr.TabItem(label="DocuChat(Experimental)", id=11):
|
||||
with gr.TabItem(label="DocuChat Upload", id=11):
|
||||
h2ogpt_upload.render()
|
||||
with gr.TabItem(label="DocuChat(Experimental)", id=12):
|
||||
h2ogpt_web.render()
|
||||
|
||||
# send to buttons
|
||||
|
||||
@@ -79,7 +79,7 @@ from apps.stable_diffusion.web.ui.stablelm_ui import (
|
||||
llm_chat_api,
|
||||
)
|
||||
from apps.stable_diffusion.web.ui.generate_config import model_config_web
|
||||
from apps.stable_diffusion.web.ui.h2ogpt import h2ogpt_web
|
||||
from apps.stable_diffusion.web.ui.h2ogpt import h2ogpt_upload, h2ogpt_web
|
||||
from apps.stable_diffusion.web.ui.minigpt4_ui import minigpt4_web
|
||||
from apps.stable_diffusion.web.ui.outputgallery_ui import (
|
||||
outputgallery_web,
|
||||
|
||||
@@ -12,6 +12,10 @@ from apps.language_models.langchain.enums import (
|
||||
LangChainAction,
|
||||
)
|
||||
import apps.language_models.langchain.gen as gen
|
||||
from gpt_langchain import (
|
||||
path_to_docs,
|
||||
create_or_update_db,
|
||||
)
|
||||
from apps.stable_diffusion.src import args
|
||||
|
||||
|
||||
@@ -168,7 +172,14 @@ def chat(curr_system_message, history, device, precision):
|
||||
return history
|
||||
|
||||
|
||||
with gr.Blocks(title="H2OGPT") as h2ogpt_web:
|
||||
userpath_selector = gr.Textbox(
|
||||
label="Document Directory",
|
||||
value=str(os.path.abspath("apps/language_models/langchain/user_path/")),
|
||||
interactive=True,
|
||||
container=True,
|
||||
)
|
||||
|
||||
with gr.Blocks(title="DocuChat") as h2ogpt_web:
|
||||
with gr.Row():
|
||||
supported_devices = available_devices
|
||||
enabled = len(supported_devices) > 0
|
||||
@@ -195,14 +206,6 @@ with gr.Blocks(title="H2OGPT") as h2ogpt_web:
|
||||
],
|
||||
visible=True,
|
||||
)
|
||||
userpath_selector = gr.Textbox(
|
||||
label="Document Directory",
|
||||
value=str(
|
||||
os.path.abspath("apps/language_models/langchain/user_path/")
|
||||
),
|
||||
interactive=True,
|
||||
container=True,
|
||||
)
|
||||
chatbot = gr.Chatbot(height=500)
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
@@ -246,3 +249,100 @@ with gr.Blocks(title="H2OGPT") as h2ogpt_web:
|
||||
queue=False,
|
||||
)
|
||||
clear.click(lambda: None, None, [chatbot], queue=False)
|
||||
|
||||
|
||||
with gr.Blocks(title="DocuChat Upload") as h2ogpt_upload:
|
||||
import pathlib
|
||||
|
||||
upload_path = None
|
||||
database = None
|
||||
database_directory = os.path.abspath(
|
||||
"apps/language_models/langchain/db_path/"
|
||||
)
|
||||
|
||||
def read_path():
|
||||
global upload_path
|
||||
filenames = [
|
||||
[f]
|
||||
for f in os.listdir(upload_path)
|
||||
if os.path.isfile(os.path.join(upload_path, f))
|
||||
]
|
||||
filenames.sort()
|
||||
return filenames
|
||||
|
||||
def upload_file(f):
|
||||
names = []
|
||||
for tmpfile in f:
|
||||
name = tmpfile.name.split("/")[-1]
|
||||
basename = os.path.join(upload_path, name)
|
||||
with open(basename, "wb") as w:
|
||||
with open(tmpfile.name, "rb") as r:
|
||||
w.write(r.read())
|
||||
update_or_create_db()
|
||||
return read_path()
|
||||
|
||||
def update_userpath(newpath):
|
||||
global upload_path
|
||||
upload_path = newpath
|
||||
pathlib.Path(upload_path).mkdir(parents=True, exist_ok=True)
|
||||
return read_path()
|
||||
|
||||
def update_or_create_db():
|
||||
global database
|
||||
global upload_path
|
||||
|
||||
sources = path_to_docs(
|
||||
upload_path,
|
||||
verbose=True,
|
||||
fail_any_exception=False,
|
||||
n_jobs=-1,
|
||||
chunk=True,
|
||||
chunk_size=512,
|
||||
url=None,
|
||||
enable_captions=False,
|
||||
captions_model=None,
|
||||
caption_loader=None,
|
||||
enable_ocr=False,
|
||||
)
|
||||
|
||||
pathlib.Path(database_directory).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
database = create_or_update_db(
|
||||
"chroma",
|
||||
database_directory,
|
||||
"UserData",
|
||||
sources,
|
||||
False,
|
||||
True,
|
||||
True,
|
||||
"sentence-transformers/all-MiniLM-L6-v2",
|
||||
)
|
||||
|
||||
def first_run():
|
||||
global database
|
||||
if database is None:
|
||||
update_or_create_db()
|
||||
|
||||
update_userpath(
|
||||
os.path.abspath("apps/language_models/langchain/user_path/")
|
||||
)
|
||||
h2ogpt_upload.load(fn=first_run)
|
||||
h2ogpt_web.load(fn=first_run)
|
||||
|
||||
with gr.Column():
|
||||
text = gr.DataFrame(
|
||||
col_count=(1, "fixed"),
|
||||
type="array",
|
||||
label="Documents",
|
||||
value=read_path(),
|
||||
)
|
||||
with gr.Row():
|
||||
upload = gr.UploadButton(
|
||||
label="Upload documents",
|
||||
file_count="multiple",
|
||||
)
|
||||
upload.upload(fn=upload_file, inputs=upload, outputs=text)
|
||||
userpath_selector.render()
|
||||
userpath_selector.input(
|
||||
fn=update_userpath, inputs=userpath_selector, outputs=text
|
||||
).then(fn=update_or_create_db)
|
||||
|
||||
Reference in New Issue
Block a user