Mirror of https://github.com/nod-ai/AMD-SHARK-Studio.git (synced 2026-04-25 03:00:12 -04:00)
LLM Pipeline Wrapper (#1477)
* [LLM] Add LLM pipeline
  Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
* add base pipeline and stableLM
* StableLM on UI - full block
* add SLM default model name
* add vicuna with pipeline
* add one token gen api for vic
* Fix stableLM bugs
* debug vic memory
* lint fix

---------

Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
Co-authored-by: Gaurav Shukla <gaurav@nod-labs.com>
committed by GitHub
parent 1ddef26af5
commit f0a4e59758
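The "base pipeline" mentioned in the commit message is the shared interface that the StableLM and Vicuna pipelines build on. The sketch below is hypothetical: it is reconstructed only from the call pattern visible in the diff (construct with a model name, then iterate over generate(prompt=...)); the real class name, constructor parameters, and method signatures in the repository may differ.

    # Hypothetical sketch of the shared pipeline interface implied by this
    # commit. All names and signatures here are assumptions, not the
    # repository's actual code.
    class LLMPipelineBase:
        def __init__(self, model_name, max_num_tokens=512):
            self.model_name = model_name
            self.max_num_tokens = max_num_tokens
            # Compile the model up front (or load a cached artifact) so the
            # UI callback does not have to do it.
            self.shark_model = self.compile()

        def compile(self):
            # Subclasses lower the HF model through the SHARK toolchain and
            # return an executable module wrapper.
            raise NotImplementedError

        def generate(self, prompt):
            # Subclasses run the compiled model and yield decoded text
            # chunks, which lets the UI stream partial output.
            raise NotImplementedError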
@@ -88,35 +88,16 @@ def chat(curr_system_message, history, model):
     # else Model is StableLM
-    global sharkModel
-    from apps.language_models.scripts.stablelm import (
-        compile_stableLM,
-        StopOnTokens,
-        generate,
-        StableLMModel,
-    )
-
-    if sharkModel == 0:
-        # sharkModel = compile_stableLM(None, tuple([input_ids, attention_mask]), "stableLM_linalg_f32_seqLen256", "/home/shark/disk/phaneesh/stablelm_3b_f32_cuda_2048_newflags.vmfb")
-        max_sequence_len = 256
-        precision = "fp32"
-        model_name_template = (
-            f"stableLM_linalg_{precision}_seqLen{max_sequence_len}"
-        )
-        m = AutoModelForCausalLM.from_pretrained(
-            "stabilityai/stablelm-tuned-alpha-3b", torch_dtype=torch.float32
-        )
-        stableLMModel = StableLMModel(m)
-        input_ids = torch.randint(3, (1, max_sequence_len))
-        attention_mask = torch.randint(3, (1, max_sequence_len))
-        sharkModel = compile_stableLM(
-            stableLMModel,
-            tuple([input_ids, attention_mask]),
-            model_name_template,
-            None,  # provide a fully qualified path to vmfb file if already exists
-        )
-    # Initialize a StopOnTokens object
-    stop = StopOnTokens()
+    from apps.language_models.src.pipelines.stablelm_pipeline import (
+        SharkStableLM,
+    )
+
+    # max_new_tokens=512
+    shark_slm = SharkStableLM(
+        "StableLM"
+    )  # pass elements from UI as required
     # Construct the input message string for the model by concatenating the current system message and conversation history
     if len(curr_system_message.split()) > 160:
         print("clearing context")
@@ -128,12 +109,10 @@ def chat(curr_system_message, history, model):
             ]
         )
-    generate_kwargs = dict(
-        new_text=messages,
-        max_new_tokens=512,
-        sharkStableLM=sharkModel,
-    )
-    words_list = generate(**generate_kwargs)
+    generate_kwargs = dict(prompt=messages)
+
+    words_list = shark_slm.generate(**generate_kwargs)
+
     partial_text = ""
     for new_text in words_list:
         # print(new_text)
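With this change the UI chat callback no longer compiles StableLM inline; it delegates to the pipeline object. A minimal usage sketch follows, assuming only what the diff shows: SharkStableLM takes a model name, and generate(prompt=...) returns an iterable of text pieces. The prompt text and variable names below are illustrative.

    from apps.language_models.src.pipelines.stablelm_pipeline import (
        SharkStableLM,
    )

    # One pipeline instance per session; compilation now happens inside the
    # pipeline class instead of in the UI callback.
    shark_slm = SharkStableLM("StableLM")

    # generate() yields text incrementally (the diff iterates over its
    # result), so the caller can stream partial output to the UI.
    partial_text = ""
    for new_text in shark_slm.generate(prompt="What is the capital of France?"):
        partial_text += new_text
    print(partial_text)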