fix #2 SLM in SharkStudio

This commit is contained in:
PhaneeshB
2023-05-18 00:32:45 +05:30
committed by Phaneesh Barwaria
parent aefcf80b48
commit 09bea17e59
2 changed files with 6 additions and 6 deletions

View File

@@ -251,7 +251,7 @@ def get_tokenizer():
# sharkStableLM = compile_stableLM(None, tuple([input_ids, attention_mask]), "stableLM_linalg_f32_seqLen256", "/home/shark/vivek/stableLM_shark_f32_seqLen256")
def generate(
new_text,
streamer,
# streamer,
max_new_tokens,
do_sample,
top_p,
@@ -265,7 +265,7 @@ def generate(
attention_mask=torch.randint(3, (1, 256)),
):
if tok == None:
tok = get_tokenizer
tok = get_tokenizer()
# Construct the input message string for the model by concatenating the current system message and conversation history
# Tokenize the messages string
# sharkStableLM = compile_stableLM(None, tuple([input_ids, attention_mask]), "stableLM_linalg_f32_seqLen256", "/home/shark/vivek/stableLM_shark_f32_seqLen256")

View File

@@ -128,12 +128,12 @@ def chat(curr_system_message, history, model):
)
# print(messages)
# Tokenize the messages string
streamer = TextIteratorStreamer(
tok, timeout=10.0, skip_prompt=True, skip_special_tokens=True
)
# streamer = TextIteratorStreamer(
# tok, timeout=10.0, skip_prompt=True, skip_special_tokens=True
# )
generate_kwargs = dict(
new_text=messages,
streamer=streamer,
# streamer=streamer,
max_new_tokens=512,
do_sample=True,
top_p=0.95,