LLM Pipeline Wrapper (#1477)

* [LLM] Add LLM pipeline

Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>

* add base pipeline and stableLM

* StableLM on UI - full block

* add SLM default model name

* add vicuna with pipeline

* add one token gen api for vic

* Fix stableLM bugs

* debug vic memory

* lint fix

---------

Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
Co-authored-by: Gaurav Shukla <gaurav@nod-labs.com>
This commit is contained in:
Phaneesh Barwaria
2023-05-31 22:47:20 +05:30
committed by GitHub
parent 1ddef26af5
commit f0a4e59758
7 changed files with 1085 additions and 31 deletions

View File

@@ -88,35 +88,16 @@ def chat(curr_system_message, history, model):
# else Model is StableLM
global sharkModel
from apps.language_models.scripts.stablelm import (
compile_stableLM,
StopOnTokens,
generate,
StableLMModel,
from apps.language_models.src.pipelines.stablelm_pipeline import (
SharkStableLM,
)
if sharkModel == 0:
# sharkModel = compile_stableLM(None, tuple([input_ids, attention_mask]), "stableLM_linalg_f32_seqLen256", "/home/shark/disk/phaneesh/stablelm_3b_f32_cuda_2048_newflags.vmfb")
max_sequence_len = 256
precision = "fp32"
model_name_template = (
f"stableLM_linalg_{precision}_seqLen{max_sequence_len}"
)
# max_new_tokens=512
shark_slm = SharkStableLM(
"StableLM"
) # pass elements from UI as required
m = AutoModelForCausalLM.from_pretrained(
"stabilityai/stablelm-tuned-alpha-3b", torch_dtype=torch.float32
)
stableLMModel = StableLMModel(m)
input_ids = torch.randint(3, (1, max_sequence_len))
attention_mask = torch.randint(3, (1, max_sequence_len))
sharkModel = compile_stableLM(
stableLMModel,
tuple([input_ids, attention_mask]),
model_name_template,
None, # provide a fully qualified path to vmfb file if already exists
)
# Initialize a StopOnTokens object
stop = StopOnTokens()
# Construct the input message string for the model by concatenating the current system message and conversation history
if len(curr_system_message.split()) > 160:
print("clearing context")
@@ -128,12 +109,10 @@ def chat(curr_system_message, history, model):
]
)
generate_kwargs = dict(
new_text=messages,
max_new_tokens=512,
sharkStableLM=sharkModel,
)
words_list = generate(**generate_kwargs)
generate_kwargs = dict(prompt=messages)
words_list = shark_slm.generate(**generate_kwargs)
partial_text = ""
for new_text in words_list:
# print(new_text)