mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-03 03:00:17 -04:00
SLM on Sharkstudio (#1454)
* localize import, fix file reading, device cpu * extract out model args
This commit is contained in:
committed by
GitHub
parent
991f144598
commit
f5ce121988
@@ -170,11 +170,16 @@ def compile_stableLM(model, model_inputs, model_name, model_vmfb_name):
|
||||
# ADD Device Arg
|
||||
from shark.shark_inference import SharkInference
|
||||
|
||||
vmfb_path = Path(model_vmfb_name + ".vmfb")
|
||||
device = "cuda" # 'cpu'
|
||||
vmfb_path = (
|
||||
Path(model_name + f"_{device}.vmfb")
|
||||
if model_vmfb_name is None
|
||||
else Path(model_vmfb_name)
|
||||
)
|
||||
if vmfb_path.exists():
|
||||
print("Loading ", vmfb_path)
|
||||
print("Loading vmfb from: ", vmfb_path)
|
||||
shark_module = SharkInference(
|
||||
None, device="cuda", mlir_dialect="tm_tensor"
|
||||
None, device=device, mlir_dialect="tm_tensor"
|
||||
)
|
||||
shark_module.load_module(vmfb_path)
|
||||
print("Successfully loaded vmfb")
|
||||
@@ -185,8 +190,8 @@ def compile_stableLM(model, model_inputs, model_name, model_vmfb_name):
|
||||
f"[DEBUG] mlir path { mlir_path} {'exists' if mlir_path.exists() else 'does not exist'}"
|
||||
)
|
||||
if mlir_path.exists():
|
||||
with open(mlir_path) as f:
|
||||
bytecode = f.read("rb")
|
||||
with open(mlir_path, "rb") as f:
|
||||
bytecode = f.read()
|
||||
else:
|
||||
ts_graph = get_torch_mlir_module_bytecode(model, model_inputs)
|
||||
module = torch_mlir.compile(
|
||||
@@ -205,7 +210,7 @@ def compile_stableLM(model, model_inputs, model_name, model_vmfb_name):
|
||||
f_.close()
|
||||
|
||||
shark_module = SharkInference(
|
||||
mlir_module=bytecode, device="cuda", mlir_dialect="tm_tensor"
|
||||
mlir_module=bytecode, device=device, mlir_dialect="tm_tensor"
|
||||
)
|
||||
shark_module.compile()
|
||||
|
||||
|
||||
@@ -1,16 +1,9 @@
|
||||
import gradio as gr
|
||||
import torch
|
||||
import os
|
||||
from apps.language_models.scripts.stablelm import (
|
||||
compile_stableLM,
|
||||
StopOnTokens,
|
||||
generate,
|
||||
get_tokenizer,
|
||||
StableLMModel,
|
||||
)
|
||||
from pathlib import Path
|
||||
from transformers import (
|
||||
AutoModelForCausalLM,
|
||||
TextIteratorStreamer,
|
||||
StoppingCriteriaList,
|
||||
)
|
||||
from apps.stable_diffusion.web.ui.utils import available_devices
|
||||
@@ -28,10 +21,6 @@ def user(message, history):
|
||||
return "", history + [[message, ""]]
|
||||
|
||||
|
||||
input_ids = torch.randint(3, (1, 256))
|
||||
attention_mask = torch.randint(3, (1, 256))
|
||||
|
||||
|
||||
sharkModel = 0
|
||||
sharded_model = 0
|
||||
|
||||
@@ -41,6 +30,7 @@ past_key_values = None
|
||||
|
||||
|
||||
def chat(curr_system_message, history, model):
|
||||
print(f"In chat for {model}")
|
||||
global sharded_model
|
||||
global past_key_values
|
||||
if "vicuna" in model:
|
||||
@@ -97,22 +87,34 @@ def chat(curr_system_message, history, model):
|
||||
print(new_sentence)
|
||||
return history
|
||||
|
||||
# else Model is StableLM
|
||||
global sharkModel
|
||||
print("In chat")
|
||||
from apps.language_models.scripts.stablelm import (
|
||||
compile_stableLM,
|
||||
StopOnTokens,
|
||||
generate,
|
||||
StableLMModel,
|
||||
)
|
||||
|
||||
if sharkModel == 0:
|
||||
tok = get_tokenizer()
|
||||
# sharkModel = compile_stableLM(None, tuple([input_ids, attention_mask]), "stableLM_linalg_f32_seqLen256", "/home/shark/disk/phaneesh/stablelm_3b_f32_cuda_2048_newflags.vmfb")
|
||||
max_sequence_len = 256
|
||||
precision = "fp32"
|
||||
model_name_template = (
|
||||
f"stableLM_linalg_{precision}_seqLen{max_sequence_len}"
|
||||
)
|
||||
|
||||
m = AutoModelForCausalLM.from_pretrained(
|
||||
"stabilityai/stablelm-tuned-alpha-3b", torch_dtype=torch.float32
|
||||
)
|
||||
stableLMModel = StableLMModel(m)
|
||||
input_ids = torch.randint(3, (1, 256))
|
||||
attention_mask = torch.randint(3, (1, 256))
|
||||
input_ids = torch.randint(3, (1, max_sequence_len))
|
||||
attention_mask = torch.randint(3, (1, max_sequence_len))
|
||||
sharkModel = compile_stableLM(
|
||||
stableLMModel,
|
||||
tuple([input_ids, attention_mask]),
|
||||
"stableLM_linalg_f32_seqLen256",
|
||||
os.getcwd(),
|
||||
model_name_template,
|
||||
None, # provide a fully qualified path to vmfb file if already exists
|
||||
)
|
||||
# Initialize a StopOnTokens object
|
||||
stop = StopOnTokens()
|
||||
@@ -126,14 +128,9 @@ def chat(curr_system_message, history, model):
|
||||
for item in history
|
||||
]
|
||||
)
|
||||
# print(messages)
|
||||
# Tokenize the messages string
|
||||
# streamer = TextIteratorStreamer(
|
||||
# tok, timeout=10.0, skip_prompt=True, skip_special_tokens=True
|
||||
# )
|
||||
|
||||
generate_kwargs = dict(
|
||||
new_text=messages,
|
||||
# streamer=streamer,
|
||||
max_new_tokens=512,
|
||||
do_sample=True,
|
||||
top_p=0.95,
|
||||
|
||||
Reference in New Issue
Block a user