SLM on Sharkstudio (#1454)

* localize import, fix file reading, device cpu

* extract out model args
This commit is contained in:
Phaneesh Barwaria
2023-05-19 23:51:08 +05:30
committed by GitHub
parent 991f144598
commit f5ce121988
2 changed files with 32 additions and 30 deletions

View File

@@ -170,11 +170,16 @@ def compile_stableLM(model, model_inputs, model_name, model_vmfb_name):
# ADD Device Arg
from shark.shark_inference import SharkInference
vmfb_path = Path(model_vmfb_name + ".vmfb")
device = "cuda" # 'cpu'
vmfb_path = (
Path(model_name + f"_{device}.vmfb")
if model_vmfb_name is None
else Path(model_vmfb_name)
)
if vmfb_path.exists():
print("Loading ", vmfb_path)
print("Loading vmfb from: ", vmfb_path)
shark_module = SharkInference(
None, device="cuda", mlir_dialect="tm_tensor"
None, device=device, mlir_dialect="tm_tensor"
)
shark_module.load_module(vmfb_path)
print("Successfully loaded vmfb")
@@ -185,8 +190,8 @@ def compile_stableLM(model, model_inputs, model_name, model_vmfb_name):
f"[DEBUG] mlir path { mlir_path} {'exists' if mlir_path.exists() else 'does not exist'}"
)
if mlir_path.exists():
with open(mlir_path) as f:
bytecode = f.read("rb")
with open(mlir_path, "rb") as f:
bytecode = f.read()
else:
ts_graph = get_torch_mlir_module_bytecode(model, model_inputs)
module = torch_mlir.compile(
@@ -205,7 +210,7 @@ def compile_stableLM(model, model_inputs, model_name, model_vmfb_name):
f_.close()
shark_module = SharkInference(
mlir_module=bytecode, device="cuda", mlir_dialect="tm_tensor"
mlir_module=bytecode, device=device, mlir_dialect="tm_tensor"
)
shark_module.compile()

View File

@@ -1,16 +1,9 @@
import gradio as gr
import torch
import os
from apps.language_models.scripts.stablelm import (
compile_stableLM,
StopOnTokens,
generate,
get_tokenizer,
StableLMModel,
)
from pathlib import Path
from transformers import (
AutoModelForCausalLM,
TextIteratorStreamer,
StoppingCriteriaList,
)
from apps.stable_diffusion.web.ui.utils import available_devices
@@ -28,10 +21,6 @@ def user(message, history):
return "", history + [[message, ""]]
input_ids = torch.randint(3, (1, 256))
attention_mask = torch.randint(3, (1, 256))
sharkModel = 0
sharded_model = 0
@@ -41,6 +30,7 @@ past_key_values = None
def chat(curr_system_message, history, model):
print(f"In chat for {model}")
global sharded_model
global past_key_values
if "vicuna" in model:
@@ -97,22 +87,34 @@ def chat(curr_system_message, history, model):
print(new_sentence)
return history
# else Model is StableLM
global sharkModel
print("In chat")
from apps.language_models.scripts.stablelm import (
compile_stableLM,
StopOnTokens,
generate,
StableLMModel,
)
if sharkModel == 0:
tok = get_tokenizer()
# sharkModel = compile_stableLM(None, tuple([input_ids, attention_mask]), "stableLM_linalg_f32_seqLen256", "/home/shark/disk/phaneesh/stablelm_3b_f32_cuda_2048_newflags.vmfb")
max_sequence_len = 256
precision = "fp32"
model_name_template = (
f"stableLM_linalg_{precision}_seqLen{max_sequence_len}"
)
m = AutoModelForCausalLM.from_pretrained(
"stabilityai/stablelm-tuned-alpha-3b", torch_dtype=torch.float32
)
stableLMModel = StableLMModel(m)
input_ids = torch.randint(3, (1, 256))
attention_mask = torch.randint(3, (1, 256))
input_ids = torch.randint(3, (1, max_sequence_len))
attention_mask = torch.randint(3, (1, max_sequence_len))
sharkModel = compile_stableLM(
stableLMModel,
tuple([input_ids, attention_mask]),
"stableLM_linalg_f32_seqLen256",
os.getcwd(),
model_name_template,
None, # provide a fully qualified path to vmfb file if already exists
)
# Initialize a StopOnTokens object
stop = StopOnTokens()
@@ -126,14 +128,9 @@ def chat(curr_system_message, history, model):
for item in history
]
)
# print(messages)
# Tokenize the messages string
# streamer = TextIteratorStreamer(
# tok, timeout=10.0, skip_prompt=True, skip_special_tokens=True
# )
generate_kwargs = dict(
new_text=messages,
# streamer=streamer,
max_new_tokens=512,
do_sample=True,
top_p=0.95,