[vicuna] Integrate sharded vicuna in web (#1717)

Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
This commit is contained in:
Gaurav Shukla
2023-08-04 22:16:53 +05:30
committed by GitHub
parent bd30044c0b
commit 51ec1a1360

View File

@@ -7,6 +7,7 @@ from transformers import (
)
from apps.stable_diffusion.web.ui.utils import available_devices
from datetime import datetime as dt
import json
def user(message, history):
@@ -106,7 +107,15 @@ def set_vicuna_model(model):
# TODO: Make chat reusable for UI and API
def chat(curr_system_message, history, model, device, precision, cli=True):
def chat(
curr_system_message,
history,
model,
devices,
precision,
config_file,
cli=True,
):
global past_key_values
global vicuna_model
@@ -121,10 +130,12 @@ def chat(curr_system_message, history, model, device, precision, cli=True):
]:
from apps.language_models.scripts.vicuna import (
UnshardedVicuna,
ShardedVicuna,
)
from apps.stable_diffusion.src import args
if vicuna_model == 0:
device = devices[0]
if "cuda" in device:
device = "cuda"
elif "sync" in device:
@@ -137,14 +148,28 @@ def chat(curr_system_message, history, model, device, precision, cli=True):
print("unrecognized device")
max_toks = 128 if model_name == "codegen" else 512
vicuna_model = UnshardedVicuna(
model_name,
hf_model_path=model_path,
hf_auth_token=args.hf_auth_token,
device=device,
precision=precision,
max_num_tokens=max_toks,
)
if len(devices) == 1 and config_file is None:
vicuna_model = UnshardedVicuna(
model_name,
hf_model_path=model_path,
hf_auth_token=args.hf_auth_token,
device=device,
precision=precision,
max_num_tokens=max_toks,
)
else:
if config_file is not None:
config_file = open(config_file)
config_json = json.load(config_file)
config_file.close()
else:
config_json = None
vicuna_model = ShardedVicuna(
model_name,
device=device,
precision=precision,
config_json=config_json,
)
prompt = create_prompt(model_name, history)
for partial_text in vicuna_model.generate(prompt, cli=cli):
@@ -307,13 +332,14 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
supported_devices = supported_devices[-1:] + supported_devices[:-1]
supported_devices = [x for x in supported_devices if "sync" not in x]
print(supported_devices)
device = gr.Dropdown(
devices = gr.Dropdown(
label="Device",
value=supported_devices[0]
if enabled
else "Only CUDA Supported for now",
choices=supported_devices,
interactive=enabled,
multiselect=True,
)
precision = gr.Radio(
label="Precision",
@@ -357,7 +383,7 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False
).then(
fn=chat,
inputs=[system_msg, chatbot, model, device, precision],
inputs=[system_msg, chatbot, model, devices, precision, config_file],
outputs=[chatbot],
queue=True,
)
@@ -365,7 +391,7 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False
).then(
fn=chat,
inputs=[system_msg, chatbot, model, device, precision],
inputs=[system_msg, chatbot, model, devices, precision, config_file],
outputs=[chatbot],
queue=True,
)