mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-03 03:00:17 -04:00
fix #2 SLM in SharkStudio
This commit is contained in:
committed by
Phaneesh Barwaria
parent
aefcf80b48
commit
09bea17e59
@@ -251,7 +251,7 @@ def get_tokenizer():
|
||||
# sharkStableLM = compile_stableLM(None, tuple([input_ids, attention_mask]), "stableLM_linalg_f32_seqLen256", "/home/shark/vivek/stableLM_shark_f32_seqLen256")
|
||||
def generate(
|
||||
new_text,
|
||||
streamer,
|
||||
# streamer,
|
||||
max_new_tokens,
|
||||
do_sample,
|
||||
top_p,
|
||||
@@ -265,7 +265,7 @@ def generate(
|
||||
attention_mask=torch.randint(3, (1, 256)),
|
||||
):
|
||||
if tok == None:
|
||||
tok = get_tokenizer
|
||||
tok = get_tokenizer()
|
||||
# Construct the input message string for the model by concatenating the current system message and conversation history
|
||||
# Tokenize the messages string
|
||||
# sharkStableLM = compile_stableLM(None, tuple([input_ids, attention_mask]), "stableLM_linalg_f32_seqLen256", "/home/shark/vivek/stableLM_shark_f32_seqLen256")
|
||||
|
||||
@@ -128,12 +128,12 @@ def chat(curr_system_message, history, model):
|
||||
)
|
||||
# print(messages)
|
||||
# Tokenize the messages string
|
||||
streamer = TextIteratorStreamer(
|
||||
tok, timeout=10.0, skip_prompt=True, skip_special_tokens=True
|
||||
)
|
||||
# streamer = TextIteratorStreamer(
|
||||
# tok, timeout=10.0, skip_prompt=True, skip_special_tokens=True
|
||||
# )
|
||||
generate_kwargs = dict(
|
||||
new_text=messages,
|
||||
streamer=streamer,
|
||||
# streamer=streamer,
|
||||
max_new_tokens=512,
|
||||
do_sample=True,
|
||||
top_p=0.95,
|
||||
|
||||
Reference in New Issue
Block a user