mirror of https://github.com/nod-ai/AMD-SHARK-Studio.git, synced 2026-04-03 03:00:17 -04:00
[Langchain] Patch for fixing streaming of tokens (#1709)
@@ -1129,7 +1129,7 @@ class Langchain:
max_time=max_time,
num_return_sequences=num_return_sequences,
)
for r in run_qa_db(
outr, extra = run_qa_db(
query=instruction,
iinput=iinput,
context=context,
@@ -1170,689 +1170,15 @@ class Langchain:
auto_reduce_chunks=auto_reduce_chunks,
max_chunks=max_chunks,
device=self.device,
):
(
outr,
extra,
) = r # doesn't accumulate, new answer every yield, so only save that full answer
yield dict(response=outr, sources=extra)
if save_dir:
extra_dict = gen_hyper_langchain.copy()
extra_dict.update(
prompt_type=prompt_type,
inference_server=inference_server,
langchain_mode=langchain_mode,
langchain_action=langchain_action,
document_choice=document_choice,
num_prompt_tokens=num_prompt_tokens,
instruction=instruction,
iinput=iinput,
context=context,
)
save_generate_output(
prompt=prompt,
output=outr,
base_model=base_model,
save_dir=save_dir,
where_from="run_qa_db",
extra_dict=extra_dict,
)
if verbose:
print(
"Post-Generate Langchain: %s decoded_output: %s"
% (str(datetime.now()), len(outr) if outr else -1),
flush=True,
)
)
response = dict(response=outr, sources=extra)
if outr or base_model in non_hf_types:
# if got no response (e.g. not showing sources and got no sources,
# so nothing to give to LLM), then slip through and ask LLM
# Or if llama/gptj, then just return since they had no response and can't go down below code path
# clear before return, since .then() never done if from API
clear_torch_cache()
return

if inference_server.startswith(
"openai"
) or inference_server.startswith("http"):
if inference_server.startswith("openai"):
import openai

where_from = "openai_client"

openai.api_key = os.getenv("OPENAI_API_KEY")
stop_sequences = list(
set(prompter.terminate_response + [prompter.PreResponse])
)
stop_sequences = [x for x in stop_sequences if x]
# OpenAI will complain if ask for too many new tokens, takes it as min in some sense, wrongly so.
max_new_tokens_openai = min(
max_new_tokens, model_max_length - num_prompt_tokens
)
gen_server_kwargs = dict(
temperature=temperature if do_sample else 0,
max_tokens=max_new_tokens_openai,
top_p=top_p if do_sample else 1,
frequency_penalty=0,
n=num_return_sequences,
presence_penalty=1.07
- repetition_penalty
+ 0.6, # so good default
)
if inference_server == "openai":
response = openai.Completion.create(
model=base_model,
prompt=prompt,
**gen_server_kwargs,
stop=stop_sequences,
stream=stream_output,
)
if not stream_output:
text = response["choices"][0]["text"]
yield dict(
response=prompter.get_response(
prompt + text,
prompt=prompt,
sanitize_bot_response=sanitize_bot_response,
),
sources="",
)
else:
collected_events = []
text = ""
for event in response:
collected_events.append(
event
) # save the event response
event_text = event["choices"][0][
"text"
] # extract the text
text += event_text # append the text
yield dict(
response=prompter.get_response(
prompt + text,
prompt=prompt,
sanitize_bot_response=sanitize_bot_response,
),
sources="",
)
elif inference_server == "openai_chat":
response = openai.ChatCompletion.create(
model=base_model,
messages=[
{
"role": "system",
"content": "You are a helpful assistant.",
},
{
"role": "user",
"content": prompt,
},
],
stream=stream_output,
**gen_server_kwargs,
)
if not stream_output:
text = response["choices"][0]["message"]["content"]
yield dict(
response=prompter.get_response(
prompt + text,
prompt=prompt,
sanitize_bot_response=sanitize_bot_response,
),
sources="",
)
else:
text = ""
for chunk in response:
delta = chunk["choices"][0]["delta"]
if "content" in delta:
text += delta["content"]
yield dict(
response=prompter.get_response(
prompt + text,
prompt=prompt,
sanitize_bot_response=sanitize_bot_response,
),
sources="",
)
else:
raise RuntimeError(
"No such OpenAI mode: %s" % inference_server
)
elif inference_server.startswith("http"):
inference_server, headers = get_hf_server(inference_server)
from gradio_utils.grclient import GradioClient
from text_generation import Client as HFClient

if isinstance(model, GradioClient):
gr_client = model
hf_client = None
elif isinstance(model, HFClient):
gr_client = None
hf_client = model
else:
(
inference_server,
gr_client,
hf_client,
) = self.get_client_from_inference_server(
inference_server, base_model=base_model
)

# quick sanity check to avoid long timeouts, just see if can reach server
requests.get(
inference_server,
timeout=int(os.getenv("REQUEST_TIMEOUT_FAST", "10")),
)

if gr_client is not None:
# Note: h2oGPT gradio server could handle input token size issues for prompt,
# but best to handle here so send less data to server

chat_client = False
where_from = "gr_client"
client_langchain_mode = "Disabled"
client_langchain_action = LangChainAction.QUERY.value
gen_server_kwargs = dict(
temperature=temperature,
top_p=top_p,
top_k=top_k,
num_beams=num_beams,
max_new_tokens=max_new_tokens,
min_new_tokens=min_new_tokens,
early_stopping=early_stopping,
max_time=max_time,
repetition_penalty=repetition_penalty,
num_return_sequences=num_return_sequences,
do_sample=do_sample,
chat=chat_client,
)
# account for gradio into gradio that handles prompting, avoid duplicating prompter prompt injection
if prompt_type in [
None,
"",
PromptType.plain.name,
PromptType.plain.value,
str(PromptType.plain.value),
]:
# if our prompt is plain, assume either correct or gradio server knows different prompt type,
# so pass empty prompt_type
gr_prompt_type = ""
gr_prompt_dict = ""
gr_prompt = prompt # already prepared prompt
gr_context = ""
gr_iinput = ""
else:
# if already have prompt_type that is not plain, None, or '', then already applied some prompting
# But assume server can handle prompting, and need to avoid double-up.
# Also assume server can do better job of using stopping.py to stop early, so avoid local prompting, let server handle
# So avoid "prompt" and let gradio server reconstruct from prompt_type we passed
# Note it's ok that prompter.get_response() has prompt+text, prompt=prompt passed,
# because just means extra processing and removal of prompt, but that has no human-bot prompting doesn't matter
# since those won't appear
gr_context = context
gr_prompt = instruction
gr_iinput = iinput
gr_prompt_type = prompt_type
gr_prompt_dict = prompt_dict
client_kwargs = dict(
instruction=gr_prompt
if chat_client
else "", # only for chat=True
iinput=gr_iinput, # only for chat=True
context=gr_context,
# streaming output is supported, loops over and outputs each generation in streaming mode
# but leave stream_output=False for simple input/output mode
stream_output=stream_output,
**gen_server_kwargs,
prompt_type=gr_prompt_type,
prompt_dict=gr_prompt_dict,
instruction_nochat=gr_prompt
if not chat_client
else "",
iinput_nochat=gr_iinput, # only for chat=False
langchain_mode=client_langchain_mode,
langchain_action=client_langchain_action,
top_k_docs=top_k_docs,
chunk=chunk,
chunk_size=chunk_size,
document_choice=[DocumentChoices.All_Relevant.name],
)
api_name = "/submit_nochat_api" # NOTE: like submit_nochat but stable API for string dict passing
if not stream_output:
res = gr_client.predict(
str(dict(client_kwargs)), api_name=api_name
)
res_dict = ast.literal_eval(res)
text = res_dict["response"]
sources = res_dict["sources"]
yield dict(
response=prompter.get_response(
prompt + text,
prompt=prompt,
sanitize_bot_response=sanitize_bot_response,
),
sources=sources,
)
else:
job = gr_client.submit(
str(dict(client_kwargs)), api_name=api_name
)
text = ""
sources = ""
res_dict = dict(response=text, sources=sources)
while not job.done():
outputs_list = job.communicator.job.outputs
if outputs_list:
res = job.communicator.job.outputs[-1]
res_dict = ast.literal_eval(res)
text = res_dict["response"]
sources = res_dict["sources"]
if gr_prompt_type == "plain":
# then gradio server passes back full prompt + text
prompt_and_text = text
else:
prompt_and_text = prompt + text
yield dict(
response=prompter.get_response(
prompt_and_text,
prompt=prompt,
sanitize_bot_response=sanitize_bot_response,
),
sources=sources,
)
time.sleep(0.01)
# ensure get last output to avoid race
res_all = job.outputs()
if len(res_all) > 0:
res = res_all[-1]
res_dict = ast.literal_eval(res)
text = res_dict["response"]
sources = res_dict["sources"]
else:
# go with old text if last call didn't work
e = job.future._exception
if e is not None:
stre = str(e)
strex = "".join(
traceback.format_tb(e.__traceback__)
)
else:
stre = ""
strex = ""

print(
"Bad final response: %s %s %s %s %s: %s %s"
% (
base_model,
inference_server,
res_all,
prompt,
text,
stre,
strex,
),
flush=True,
)
if gr_prompt_type == "plain":
# then gradio server passes back full prompt + text
prompt_and_text = text
else:
prompt_and_text = prompt + text
yield dict(
response=prompter.get_response(
prompt_and_text,
prompt=prompt,
sanitize_bot_response=sanitize_bot_response,
),
sources=sources,
)
elif hf_client:
# HF inference server needs control over input tokens
where_from = "hf_client"

# prompt must include all human-bot like tokens, already added by prompt
# https://github.com/huggingface/text-generation-inference/tree/main/clients/python#types
stop_sequences = list(
set(
prompter.terminate_response
+ [prompter.PreResponse]
)
)
stop_sequences = [x for x in stop_sequences if x]
gen_server_kwargs = dict(
do_sample=do_sample,
max_new_tokens=max_new_tokens,
# best_of=None,
repetition_penalty=repetition_penalty,
return_full_text=True,
seed=SEED,
stop_sequences=stop_sequences,
temperature=temperature,
top_k=top_k,
top_p=top_p,
# truncate=False, # behaves oddly
# typical_p=top_p,
# watermark=False,
# decoder_input_details=False,
)
# work-around for timeout at constructor time, will be issue if multi-threading,
# so just do something reasonable or max_time if larger
# lower bound because client is re-used if multi-threading
hf_client.timeout = max(300, max_time)
if not stream_output:
text = hf_client.generate(
prompt, **gen_server_kwargs
).generated_text
yield dict(
response=prompter.get_response(
text,
prompt=prompt,
sanitize_bot_response=sanitize_bot_response,
),
sources="",
)
else:
text = ""
for response in hf_client.generate_stream(
prompt, **gen_server_kwargs
):
if not response.token.special:
# stop_sequences
text_chunk = response.token.text
text += text_chunk
yield dict(
response=prompter.get_response(
prompt + text,
prompt=prompt,
sanitize_bot_response=sanitize_bot_response,
),
sources="",
)
else:
raise RuntimeError(
"Failed to get client: %s" % inference_server
)
else:
raise RuntimeError(
"No such inference_server %s" % inference_server
)

if save_dir and text:
# save prompt + new text
extra_dict = gen_server_kwargs.copy()
extra_dict.update(
dict(
inference_server=inference_server,
num_prompt_tokens=num_prompt_tokens,
)
)
save_generate_output(
prompt=prompt,
output=text,
base_model=base_model,
save_dir=save_dir,
where_from=where_from,
extra_dict=extra_dict,
)
return
else:
assert not inference_server, (
"inferene_server=%s not supported" % inference_server
)

if isinstance(tokenizer, str):
# pipeline
if tokenizer == "summarization":
key = "summary_text"
else:
raise RuntimeError("No such task type %s" % tokenizer)
# NOTE: uses max_length only
yield dict(
response=model(prompt, max_length=max_new_tokens)[0][key],
sources="",
)

if "mbart-" in base_model.lower():
assert src_lang is not None
tokenizer.src_lang = self.languages_covered()[src_lang]

stopping_criteria = get_stopping(
prompt_type,
prompt_dict,
tokenizer,
self.device,
model_max_length=tokenizer.model_max_length,
)

print(prompt)
# exit(0)
inputs = tokenizer(prompt, return_tensors="pt")
if debug and len(inputs["input_ids"]) > 0:
print("input_ids length", len(inputs["input_ids"][0]), flush=True)
input_ids = inputs["input_ids"].to(self.device)
# CRITICAL LIMIT else will fail
max_max_tokens = tokenizer.model_max_length
max_input_tokens = max_max_tokens - min_new_tokens
# NOTE: Don't limit up front due to max_new_tokens, let go up to max or reach max_max_tokens in stopping.py
input_ids = input_ids[:, -max_input_tokens:]
# required for falcon if multiple threads or asyncio accesses to model during generation
if use_cache is None:
use_cache = False if "falcon" in base_model else True
gen_config_kwargs = dict(
temperature=float(temperature),
top_p=float(top_p),
top_k=top_k,
num_beams=num_beams,
do_sample=do_sample,
repetition_penalty=float(repetition_penalty),
num_return_sequences=num_return_sequences,
renormalize_logits=True,
remove_invalid_values=True,
use_cache=use_cache,
)
token_ids = [
"eos_token_id",
"pad_token_id",
"bos_token_id",
"cls_token_id",
"sep_token_id",
]
for token_id in token_ids:
if (
hasattr(tokenizer, token_id)
and getattr(tokenizer, token_id) is not None
):
gen_config_kwargs.update(
{token_id: getattr(tokenizer, token_id)}
)
generation_config = GenerationConfig(**gen_config_kwargs)

gen_kwargs = dict(
input_ids=input_ids,
generation_config=generation_config,
return_dict_in_generate=True,
output_scores=True,
max_new_tokens=max_new_tokens, # prompt + new
min_new_tokens=min_new_tokens, # prompt + new
early_stopping=early_stopping, # False, True, "never"
max_time=max_time,
stopping_criteria=stopping_criteria,
)
if "gpt2" in base_model.lower():
gen_kwargs.update(
dict(
bos_token_id=tokenizer.bos_token_id,
pad_token_id=tokenizer.eos_token_id,
)
)
elif "mbart-" in base_model.lower():
assert tgt_lang is not None
tgt_lang = self.languages_covered()[tgt_lang]
gen_kwargs.update(
dict(forced_bos_token_id=tokenizer.lang_code_to_id[tgt_lang])
)
else:
token_ids = ["eos_token_id", "bos_token_id", "pad_token_id"]
for token_id in token_ids:
if (
hasattr(tokenizer, token_id)
and getattr(tokenizer, token_id) is not None
):
gen_kwargs.update({token_id: getattr(tokenizer, token_id)})

decoder_kwargs = dict(
skip_special_tokens=True, clean_up_tokenization_spaces=True
)

decoder = functools.partial(tokenizer.decode, **decoder_kwargs)
decoder_raw_kwargs = dict(
skip_special_tokens=False, clean_up_tokenization_spaces=True
)

decoder_raw = functools.partial(tokenizer.decode, **decoder_raw_kwargs)

with torch.no_grad():
have_lora_weights = lora_weights not in [no_lora_str, "", None]
context_class_cast = (
NullContext
if self.device == "cpu"
or have_lora_weights
or self.device == "mps"
else torch.autocast
)
with context_class_cast(self.device):
# protection for gradio not keeping track of closed users,
# else hit bitsandbytes lack of thread safety:
# https://github.com/h2oai/h2ogpt/issues/104
# but only makes sense if concurrency_count == 1
context_class = NullContext # if concurrency_count > 1 else filelock.FileLock
if verbose:
print("Pre-Generate: %s" % str(datetime.now()), flush=True)
decoded_output = None
with context_class("generate.lock"):
if verbose:
print("Generate: %s" % str(datetime.now()), flush=True)
# decoded tokenized prompt can deviate from prompt due to special characters
inputs_decoded = decoder(input_ids[0])
inputs_decoded_raw = decoder_raw(input_ids[0])
if inputs_decoded == prompt:
# normal
pass
elif inputs_decoded.lstrip() == prompt.lstrip():
# sometimes extra space in front, make prompt same for prompt removal
prompt = inputs_decoded
elif inputs_decoded_raw == prompt:
# some models specify special tokens that are part of normal prompt, so can't skip them
inputs_decoded = prompt = inputs_decoded_raw
decoder = decoder_raw
decoder_kwargs = decoder_raw_kwargs
elif inputs_decoded_raw.replace("<unk> ", "").replace(
"<unk>", ""
).replace("\n", " ").replace(" ", "") == prompt.replace(
"\n", " "
).replace(
" ", ""
):
inputs_decoded = prompt = inputs_decoded_raw
decoder = decoder_raw
decoder_kwargs = decoder_raw_kwargs
else:
if verbose:
print(
"WARNING: Special characters in prompt",
flush=True,
)
if stream_output:
skip_prompt = False
streamer = H2OTextIteratorStreamer(
tokenizer,
skip_prompt=skip_prompt,
block=False,
**decoder_kwargs,
)
gen_kwargs.update(dict(streamer=streamer))
target = wrapped_partial(
self.generate_with_exceptions,
model.generate,
prompt=prompt,
inputs_decoded=inputs_decoded,
raise_generate_gpu_exceptions=raise_generate_gpu_exceptions,
**gen_kwargs,
)
bucket = queue.Queue()
thread = EThread(
target=target, streamer=streamer, bucket=bucket
)
thread.start()
outputs = ""
try:
for new_text in streamer:
if bucket.qsize() > 0 or thread.exc:
thread.join()
outputs += new_text
yield dict(
response=prompter.get_response(
outputs,
prompt=inputs_decoded,
sanitize_bot_response=sanitize_bot_response,
),
sources="",
)
except BaseException:
# if any exception, raise that exception if was from thread, first
if thread.exc:
raise thread.exc
raise
finally:
# clear before return, since .then() never done if from API
clear_torch_cache()
# in case no exception and didn't join with thread yet, then join
if not thread.exc:
thread.join()
# in case raise StopIteration or broke queue loop in streamer, but still have exception
if thread.exc:
raise thread.exc
decoded_output = outputs
else:
try:
outputs = model.generate(**gen_kwargs)
finally:
clear_torch_cache() # has to be here for API submit_nochat_api since .then() not called
outputs = [decoder(s) for s in outputs.sequences]
yield dict(
response=prompter.get_response(
outputs,
prompt=inputs_decoded,
sanitize_bot_response=sanitize_bot_response,
),
sources="",
)
if outputs and len(outputs) >= 1:
decoded_output = prompt + outputs[0]
if save_dir and decoded_output:
extra_dict = gen_config_kwargs.copy()
extra_dict.update(
dict(num_prompt_tokens=num_prompt_tokens)
)
save_generate_output(
prompt=prompt,
output=decoded_output,
base_model=base_model,
save_dir=save_dir,
where_from="evaluate_%s" % str(stream_output),
extra_dict=gen_config_kwargs,
)
if verbose:
print(
"Post-Generate: %s decoded_output: %s"
% (
str(datetime.now()),
len(decoded_output) if decoded_output else -1,
),
flush=True,
)
return outputs[0]
return response

inputs_list_names = list(inspect.signature(evaluate).parameters)
global inputs_kwargs_list

@@ -2510,8 +2510,7 @@ def _run_qa_db(
formatted_doc_chunks = "\n\n".join(
[get_url(x) + "\n\n" + x.page_content for x in docs]
)
yield formatted_doc_chunks, ""
return
return formatted_doc_chunks, ""
if not docs and langchain_action in [
LangChainAction.SUMMARIZE_MAP.value,
LangChainAction.SUMMARIZE_ALL.value,
@@ -2523,8 +2522,7 @@ def _run_qa_db(
else "No documents to summarize."
)
extra = ""
yield ret, extra
return
return ret, extra
if not docs and langchain_mode not in [
LangChainMode.DISABLED.value,
LangChainMode.CHAT_LLM.value,
@@ -2536,8 +2534,7 @@ def _run_qa_db(
else "No documents to query."
)
extra = ""
yield ret, extra
return
return ret, extra

if chain is None and model_name not in non_hf_types:
# here if no docs at all and not HF type
@@ -2561,7 +2558,7 @@ def _run_qa_db(
if not use_context:
ret = answer["output_text"]
extra = ""
yield ret, extra
return ret, extra
elif answer is not None:
ret, extra = get_sources_answer(
query,
@@ -2571,7 +2568,7 @@ def _run_qa_db(
answer_with_sources,
verbose=verbose,
)
yield ret, extra
return ret, extra
return
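
The _run_qa_db hunks above replace each "return ret, extra" with "yield ret, extra" followed by a bare "return": the function is consumed as a generator, so a value attached to "return" would only end up inside StopIteration and never reach the streaming loop. A minimal sketch of that pattern, with hypothetical names rather than code from this commit:

def answer_stream(docs):
    # toy generator mirroring the yield-then-return fix above
    if not docs:
        yield "No documents to query.", ""
        return  # a bare return just ends the generator after the final yield
    for doc in docs:
        yield doc, "source: %s" % doc

for ret, extra in answer_stream([]):
    print(ret, extra)  # prints the fallback answer with empty sources
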
@@ -164,10 +164,7 @@ def chat(curr_system_message, history, device, precision):
model_lock=True,
user_path=userpath_selector.value,
)
for partial_text in output:
history[-1][1] = partial_text["response"]
yield history

history[-1][1] = output["response"]
return history
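
The chat() hunk above stops reading a single final output["response"] and instead iterates the streamed generator, overwriting the last history entry with each partial answer. A minimal sketch of that consumer pattern; fake_stream is an illustrative stand-in for the real streaming output, not part of this commit:

def fake_stream():
    # yields dicts with a "response" key, like the generator patched above
    for partial in ("Hel", "Hello", "Hello, wor", "Hello, world"):
        yield dict(response=partial, sources="")

history = [["user question", ""]]
for partial_text in fake_stream():
    history[-1][1] = partial_text["response"]  # keep only the latest partial answer
print(history[-1][1])  # -> Hello, world
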