[Shard] Add sharding generation in shark studio

Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
This commit is contained in:
Gaurav Shukla
2023-07-28 21:53:22 +05:30
parent c9de2729b2
commit bd30044c0b
5 changed files with 91 additions and 19 deletions

View File

@@ -154,6 +154,7 @@ if __name__ == "__main__":
upscaler_sendto_outpaint,
lora_train_web,
model_web,
model_config_web,
hf_models,
modelmanager_sendto_txt2img,
modelmanager_sendto_img2img,
@@ -221,6 +222,16 @@ if __name__ == "__main__":
outpaint_web.render()
with gr.TabItem(label="Upscaler", id=4):
upscaler_web.render()
with gr.TabItem(label="Model Manager", id=6):
model_web.render()
with gr.TabItem(label="Chat Bot(Experimental)", id=7):
stablelm_chat.render()
with gr.TabItem(label="Generate Sharding Config", id=8):
model_config_web.render()
with gr.TabItem(label="LoRA Training(Experimental)", id=9):
lora_train_web.render()
with gr.TabItem(label="MultiModal (Experimental)", id=10):
minigpt4_web.render()
if args.output_gallery:
with gr.TabItem(label="Output Gallery", id=5) as og_tab:
outputgallery_web.render()
@@ -236,15 +247,7 @@ if __name__ == "__main__":
upscaler_status,
]
)
with gr.TabItem(label="Model Manager", id=6):
model_web.render()
with gr.TabItem(label="LoRA Training (Experimental)", id=8):
lora_train_web.render()
with gr.TabItem(label="Chat Bot (Experimental)", id=7):
stablelm_chat.render()
with gr.TabItem(label="MultiModal (Experimental)", id=9):
minigpt4_web.render()
with gr.TabItem(label="DocuChat(Experimental)", id=10):
with gr.TabItem(label="DocuChat(Experimental)", id=11):
h2ogpt_web.render()
# send to buttons

View File

@@ -78,6 +78,7 @@ from apps.stable_diffusion.web.ui.stablelm_ui import (
stablelm_chat,
llm_chat_api,
)
from apps.stable_diffusion.web.ui.generate_config import model_config_web
from apps.stable_diffusion.web.ui.h2ogpt import h2ogpt_web
from apps.stable_diffusion.web.ui.minigpt4_ui import minigpt4_web
from apps.stable_diffusion.web.ui.outputgallery_ui import (

View File

@@ -0,0 +1,41 @@
import gradio as gr
import torch
from transformers import AutoTokenizer
from apps.language_models.src.model_wrappers.vicuna_model import CombinedModel
from shark.shark_generate_model_config import GenerateConfigFile
def get_model_config(hf_model_path="TheBloke/vicuna-7B-1.1-HF"):
    """Build a sharding config for the combined Vicuna model.

    Tokenizes a fixed compilation prompt, wraps it as the example input
    expected by ``GenerateConfigFile``, and returns the per-layer split
    produced by ``split_into_layers()``.

    Args:
        hf_model_path: Hugging Face repo id of the tokenizer to use.
            Defaults to the Vicuna 7B v1.1 checkpoint.

    Returns:
        Whatever ``GenerateConfigFile.split_into_layers()`` yields —
        rendered into the UI as JSON by the caller.
    """
    tokenizer = AutoTokenizer.from_pretrained(hf_model_path, use_fast=False)
    # Fixed-content prompt used only to produce example input ids for
    # compilation; the actual text is irrelevant.
    compilation_prompt = "0" * 17
    compilation_input_ids = tokenizer(
        compilation_prompt,
        return_tensors="pt",
    ).input_ids
    # return_tensors="pt" already yields a torch.Tensor, so reshape it
    # directly instead of re-wrapping with torch.tensor (which copies and
    # warns). Shape [1, 19] assumes the 17-char prompt tokenizes to 19 ids
    # (special tokens included) — TODO confirm for non-default tokenizers.
    compilation_input_ids = compilation_input_ids.reshape([1, 19])
    firstVicunaCompileInput = (compilation_input_ids,)
    model = CombinedModel()
    # 1 shard group, placed on "gpu_id"; presumably a placeholder device
    # list — verify against GenerateConfigFile's expectations.
    c = GenerateConfigFile(model, 1, ["gpu_id"], firstVicunaCompileInput)
    return c.split_into_layers()
# Gradio UI for the "Generate Sharding Config" tab: a model dropdown, a
# button, and a JSON viewer showing the generated per-layer config.
# NOTE(review): component creation order inside the `with` blocks determines
# layout, so statements here must stay in this order.
with gr.Blocks() as model_config_web:
    with gr.Row():
        # Only Vicuna is selectable for now; the value is not yet wired
        # into get_model_config (which uses its own default model path).
        hf_models = gr.Dropdown(
            label="Model List",
            choices=["Vicuna"],
            value="Vicuna",
            visible=True,
        )
        get_model_config_btn = gr.Button(value="Get Model Config")
    # Renders the return value of get_model_config as interactive JSON.
    json_view = gr.JSON()
    get_model_config_btn.click(
        fn=get_model_config,
        inputs=[],
        outputs=[json_view],
    )