[vicuna] add default config in case of sharded vicuna

Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
Author: Gaurav Shukla
Date: 2023-08-07 21:59:39 +05:30
Parent: e8c1203be2
Commit: 8e90f1b81a
2 changed files with 28 additions and 3 deletions

View File

@@ -113,6 +113,31 @@ def set_vicuna_model(model):
     vicuna_model = model


+def get_default_config():
+    import torch
+    from transformers import AutoTokenizer
+
+    hf_model_path = "TheBloke/vicuna-7B-1.1-HF"
+    tokenizer = AutoTokenizer.from_pretrained(hf_model_path, use_fast=False)
+    compilation_prompt = "".join(["0" for _ in range(17)])
+    compilation_input_ids = tokenizer(
+        compilation_prompt,
+        return_tensors="pt",
+    ).input_ids
+    compilation_input_ids = torch.tensor(compilation_input_ids).reshape(
+        [1, 19]
+    )
+    firstVicunaCompileInput = (compilation_input_ids,)
+    from apps.language_models.src.model_wrappers.vicuna_model import (
+        CombinedModel,
+    )
+    from shark.shark_generate_model_config import GenerateConfigFile
+
+    model = CombinedModel()
+    c = GenerateConfigFile(model, 1, ["gpu_id"], firstVicunaCompileInput)
+    c.split_into_layers()
+
+
 # TODO: Make chat reusable for UI and API
 def chat(
     curr_system_message,
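As a side note, the hard-coded `[1, 19]` reshape above encodes an assumption about how the dummy prompt tokenizes with this tokenizer. A minimal stand-alone sketch (not part of the commit) to check that assumption:

```python
# Sanity-check sketch (assumed usage, not part of this commit): confirm the
# dummy compilation prompt tokenizes to the shape get_default_config() expects.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(
    "TheBloke/vicuna-7B-1.1-HF", use_fast=False
)
prompt = "".join(["0" for _ in range(17)])
input_ids = tokenizer(prompt, return_tensors="pt").input_ids
print(input_ids.shape)  # the reshape in get_default_config() expects [1, 19]
```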
@@ -185,7 +210,7 @@ def chat(
         config_json = json.load(config_file)
         config_file.close()
     else:
-        config_json = None
+        config_json = get_default_config()
     vicuna_model = Vicuna(
         model_name,
         device=device,
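For context, a minimal sketch of how the config-loading fallback reads after this change (the condition guarding the upload branch is not visible in the hunk and is assumed here):

```python
# Assumed shape of the config-loading path in chat(); the `if` condition is a
# guess, the branch bodies are taken from the hunk above.
if config_file is not None:
    config_json = json.load(config_file)
    config_file.close()
else:
    # New in this commit: generate a default sharding config instead of
    # passing None when no configuration file was uploaded.
    config_json = get_default_config()
```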
@@ -379,7 +404,7 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
         with gr.Group():
             config_file = gr.File(label="Upload sharding configuration")
             json_view_button = gr.Button("View as JSON")
-            json_view = gr.JSON()
+            json_view = gr.JSON(interactive=True)
             json_view_button.click(
                 fn=view_json_file, inputs=[config_file], outputs=[json_view]
             )
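The `view_json_file` callback wired up above is not part of this diff; a plausible implementation, assuming the default `gr.File` value exposes a temp-file path via `.name`, could look like:

```python
# Hypothetical helper (not shown in this commit): load the uploaded sharding
# configuration and return it for display in the gr.JSON component.
import json


def view_json_file(file_obj):
    if file_obj is None:
        return None
    with open(file_obj.name, "r") as f:
        return json.load(f)
```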

View File

@@ -144,4 +144,4 @@ if __name__ == "__main__":
     model = CombinedModel()
     c = GenerateConfigFile(model, 1, ["gpu_id"], firstVicunaCompileInput)
-    c.split_into_dispatches("vulkan")
+    c.split_into_layers()