mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-25 03:00:12 -04:00
[vicuna] Integrate sharded vicuna in web (#1717)
Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
This commit is contained in:
@@ -7,6 +7,7 @@ from transformers import (
|
||||
)
|
||||
from apps.stable_diffusion.web.ui.utils import available_devices
|
||||
from datetime import datetime as dt
|
||||
import json
|
||||
|
||||
|
||||
def user(message, history):
|
||||
@@ -106,7 +107,15 @@ def set_vicuna_model(model):
|
||||
|
||||
|
||||
# TODO: Make chat reusable for UI and API
|
||||
def chat(curr_system_message, history, model, device, precision, cli=True):
|
||||
def chat(
|
||||
curr_system_message,
|
||||
history,
|
||||
model,
|
||||
devices,
|
||||
precision,
|
||||
config_file,
|
||||
cli=True,
|
||||
):
|
||||
global past_key_values
|
||||
|
||||
global vicuna_model
|
||||
@@ -121,10 +130,12 @@ def chat(curr_system_message, history, model, device, precision, cli=True):
|
||||
]:
|
||||
from apps.language_models.scripts.vicuna import (
|
||||
UnshardedVicuna,
|
||||
ShardedVicuna,
|
||||
)
|
||||
from apps.stable_diffusion.src import args
|
||||
|
||||
if vicuna_model == 0:
|
||||
device = devices[0]
|
||||
if "cuda" in device:
|
||||
device = "cuda"
|
||||
elif "sync" in device:
|
||||
@@ -137,14 +148,28 @@ def chat(curr_system_message, history, model, device, precision, cli=True):
|
||||
print("unrecognized device")
|
||||
|
||||
max_toks = 128 if model_name == "codegen" else 512
|
||||
vicuna_model = UnshardedVicuna(
|
||||
model_name,
|
||||
hf_model_path=model_path,
|
||||
hf_auth_token=args.hf_auth_token,
|
||||
device=device,
|
||||
precision=precision,
|
||||
max_num_tokens=max_toks,
|
||||
)
|
||||
if len(devices) == 1 and config_file is None:
|
||||
vicuna_model = UnshardedVicuna(
|
||||
model_name,
|
||||
hf_model_path=model_path,
|
||||
hf_auth_token=args.hf_auth_token,
|
||||
device=device,
|
||||
precision=precision,
|
||||
max_num_tokens=max_toks,
|
||||
)
|
||||
else:
|
||||
if config_file is not None:
|
||||
config_file = open(config_file)
|
||||
config_json = json.load(config_file)
|
||||
config_file.close()
|
||||
else:
|
||||
config_json = None
|
||||
vicuna_model = ShardedVicuna(
|
||||
model_name,
|
||||
device=device,
|
||||
precision=precision,
|
||||
config_json=config_json,
|
||||
)
|
||||
prompt = create_prompt(model_name, history)
|
||||
|
||||
for partial_text in vicuna_model.generate(prompt, cli=cli):
|
||||
@@ -307,13 +332,14 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
|
||||
supported_devices = supported_devices[-1:] + supported_devices[:-1]
|
||||
supported_devices = [x for x in supported_devices if "sync" not in x]
|
||||
print(supported_devices)
|
||||
device = gr.Dropdown(
|
||||
devices = gr.Dropdown(
|
||||
label="Device",
|
||||
value=supported_devices[0]
|
||||
if enabled
|
||||
else "Only CUDA Supported for now",
|
||||
choices=supported_devices,
|
||||
interactive=enabled,
|
||||
multiselect=True,
|
||||
)
|
||||
precision = gr.Radio(
|
||||
label="Precision",
|
||||
@@ -357,7 +383,7 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
|
||||
fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False
|
||||
).then(
|
||||
fn=chat,
|
||||
inputs=[system_msg, chatbot, model, device, precision],
|
||||
inputs=[system_msg, chatbot, model, devices, precision, config_file],
|
||||
outputs=[chatbot],
|
||||
queue=True,
|
||||
)
|
||||
@@ -365,7 +391,7 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
|
||||
fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False
|
||||
).then(
|
||||
fn=chat,
|
||||
inputs=[system_msg, chatbot, model, device, precision],
|
||||
inputs=[system_msg, chatbot, model, devices, precision, config_file],
|
||||
outputs=[chatbot],
|
||||
queue=True,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user