[Shard] Add sharding generation in shark studio

Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
This commit is contained in:
Gaurav Shukla
2023-07-28 21:53:22 +05:30
parent c9de2729b2
commit bd30044c0b
5 changed files with 91 additions and 19 deletions

View File

@@ -301,12 +301,13 @@ class CombinedModel(torch.nn.Module):
self.second_vicuna = SecondVicuna(second_vicuna_model_path)
def forward(self, input_ids):
first_output = self.first_vicuna(input_ids=input_ids, use_cache=True)
logits = first_output[0]
pkv = first_output[1:]
token = torch.argmax(torch.tensor(logits)[:, -1, :], dim=1)
token = token.to(torch.int64).reshape([1, 1])
secondVicunaInput = (token,) + tuple(pkv)
second_output = self.second_vicuna(secondVicunaInput)
first_output = self.first_vicuna(input_ids=input_ids)
# generate second vicuna
compilation_input_ids = torch.zeros([1, 1], dtype=torch.int64)
pkv = tuple(
(torch.zeros([1, 32, 19, 128], dtype=torch.float32))
for _ in range(64)
)
secondVicunaCompileInput = (compilation_input_ids,) + pkv
second_output = self.second_vicuna(*secondVicunaCompileInput)
return second_output

View File

@@ -154,6 +154,7 @@ if __name__ == "__main__":
upscaler_sendto_outpaint,
lora_train_web,
model_web,
model_config_web,
hf_models,
modelmanager_sendto_txt2img,
modelmanager_sendto_img2img,
@@ -221,6 +222,16 @@ if __name__ == "__main__":
outpaint_web.render()
with gr.TabItem(label="Upscaler", id=4):
upscaler_web.render()
with gr.TabItem(label="Model Manager", id=6):
model_web.render()
with gr.TabItem(label="Chat Bot(Experimental)", id=7):
stablelm_chat.render()
with gr.TabItem(label="Generate Sharding Config", id=8):
model_config_web.render()
with gr.TabItem(label="LoRA Training(Experimental)", id=9):
lora_train_web.render()
with gr.TabItem(label="MultiModal (Experimental)", id=10):
minigpt4_web.render()
if args.output_gallery:
with gr.TabItem(label="Output Gallery", id=5) as og_tab:
outputgallery_web.render()
@@ -236,15 +247,7 @@ if __name__ == "__main__":
upscaler_status,
]
)
with gr.TabItem(label="Model Manager", id=6):
model_web.render()
with gr.TabItem(label="LoRA Training (Experimental)", id=8):
lora_train_web.render()
with gr.TabItem(label="Chat Bot (Experimental)", id=7):
stablelm_chat.render()
with gr.TabItem(label="MultiModal (Experimental)", id=9):
minigpt4_web.render()
with gr.TabItem(label="DocuChat(Experimental)", id=10):
with gr.TabItem(label="DocuChat(Experimental)", id=11):
h2ogpt_web.render()
# send to buttons

View File

@@ -78,6 +78,7 @@ from apps.stable_diffusion.web.ui.stablelm_ui import (
stablelm_chat,
llm_chat_api,
)
from apps.stable_diffusion.web.ui.generate_config import model_config_web
from apps.stable_diffusion.web.ui.h2ogpt import h2ogpt_web
from apps.stable_diffusion.web.ui.minigpt4_ui import minigpt4_web
from apps.stable_diffusion.web.ui.outputgallery_ui import (

View File

@@ -0,0 +1,41 @@
import gradio as gr
import torch
from transformers import AutoTokenizer
from apps.language_models.src.model_wrappers.vicuna_model import CombinedModel
from shark.shark_generate_model_config import GenerateConfigFile
def get_model_config():
hf_model_path = "TheBloke/vicuna-7B-1.1-HF"
tokenizer = AutoTokenizer.from_pretrained(hf_model_path, use_fast=False)
compilation_prompt = "".join(["0" for _ in range(17)])
compilation_input_ids = tokenizer(
compilation_prompt,
return_tensors="pt",
).input_ids
compilation_input_ids = torch.tensor(compilation_input_ids).reshape(
[1, 19]
)
firstVicunaCompileInput = (compilation_input_ids,)
model = CombinedModel()
c = GenerateConfigFile(model, 1, ["gpu_id"], firstVicunaCompileInput)
return c.split_into_layers()
with gr.Blocks() as model_config_web:
with gr.Row():
hf_models = gr.Dropdown(
label="Model List",
choices=["Vicuna"],
value="Vicuna",
visible=True,
)
get_model_config_btn = gr.Button(value="Get Model Config")
json_view = gr.JSON()
get_model_config_btn.click(
fn=get_model_config,
inputs=[],
outputs=[json_view],
)