mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-03 03:00:17 -04:00
[Langchain] Expand pipelines to fix token streaming issue
This commit is contained in:
572
apps/language_models/langchain/exp_hf_pipelines.py
Normal file
572
apps/language_models/langchain/exp_hf_pipelines.py
Normal file
@@ -0,0 +1,572 @@
|
||||
"""Wrapper around HuggingFace Pipeline APIs."""
|
||||
import importlib.util
|
||||
import logging
|
||||
from typing import Any, List, Mapping, Optional
|
||||
|
||||
from pydantic import Extra
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.llms.utils import enforce_stop_tokens
|
||||
|
||||
import enum
|
||||
import warnings
|
||||
from transformers.pipelines.base import PIPELINE_INIT_ARGS, Pipeline
|
||||
from transformers.utils import add_end_docstrings
|
||||
from transformers import (
|
||||
MODEL_FOR_CAUSAL_LM_MAPPING,
|
||||
TF_MODEL_FOR_CAUSAL_LM_MAPPING,
|
||||
)
|
||||
|
||||
|
||||
# Defaults used by HuggingFacePipeline.from_model_id when the caller does not
# override them.
DEFAULT_MODEL_ID = "gpt2"
DEFAULT_TASK = "text-generation"
# The only pipeline tasks this wrapper knows how to post-process in _call.
VALID_TASKS = ("text2text-generation", "text-generation", "summarization")

logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HuggingFacePipeline(LLM):
    """Wrapper around HuggingFace Pipeline API.

    To use, you should have the ``transformers`` python package installed.

    Only supports `text-generation`, `text2text-generation` and `summarization` for now.

    Example using from_model_id:
        .. code-block:: python

            from langchain.llms import HuggingFacePipeline
            hf = HuggingFacePipeline.from_model_id(
                model_id="gpt2",
                task="text-generation",
                pipeline_kwargs={"max_new_tokens": 10},
            )
    Example passing pipeline in directly:
        .. code-block:: python

            from langchain.llms import HuggingFacePipeline
            from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

            model_id = "gpt2"
            tokenizer = AutoTokenizer.from_pretrained(model_id)
            model = AutoModelForCausalLM.from_pretrained(model_id)
            pipe = pipeline(
                "text-generation", model=model, tokenizer=tokenizer, max_new_tokens=10
            )
            hf = HuggingFacePipeline(pipeline=pipe)
    """

    pipeline: Any  #: :meta private:
    model_id: str = DEFAULT_MODEL_ID
    """Model name to use."""
    model_kwargs: Optional[dict] = None
    """Key word arguments passed to the model."""
    pipeline_kwargs: Optional[dict] = None
    """Key word arguments passed to the pipeline."""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @classmethod
    def from_model_id(
        cls,
        model_id: str,
        task: str,
        device: int = -1,
        model_kwargs: Optional[dict] = None,
        pipeline_kwargs: Optional[dict] = None,
        **kwargs: Any,
    ) -> LLM:
        """Construct the pipeline object from model_id and task.

        Args:
            model_id: HuggingFace Hub model identifier (e.g. ``"gpt2"``).
            task: one of VALID_TASKS; selects causal-LM vs seq2seq loading.
            device: -1 for CPU (default) or a CUDA device index.
            model_kwargs: forwarded to ``from_pretrained`` and the pipeline.
            pipeline_kwargs: forwarded to the ``transformers.pipeline`` factory.
            **kwargs: forwarded to the LLM constructor.

        Raises:
            ValueError: if transformers is missing, the task is unsupported,
                or ``device`` is out of range for the visible CUDA devices.
        """
        try:
            from transformers import (
                AutoModelForCausalLM,
                AutoModelForSeq2SeqLM,
                AutoTokenizer,
            )
            from transformers import pipeline as hf_pipeline

        except ImportError:
            # NOTE: callers of this API historically catch ValueError, so the
            # import failure is deliberately surfaced as one.
            raise ValueError(
                "Could not import transformers python package. "
                "Please install it with `pip install transformers`."
            )

        _model_kwargs = model_kwargs or {}
        tokenizer = AutoTokenizer.from_pretrained(model_id, **_model_kwargs)

        try:
            # Task determines which auto-model family can produce the output
            # keys _call expects ("generated_text" / "summary_text").
            if task == "text-generation":
                model = AutoModelForCausalLM.from_pretrained(
                    model_id, **_model_kwargs
                )
            elif task in ("text2text-generation", "summarization"):
                model = AutoModelForSeq2SeqLM.from_pretrained(
                    model_id, **_model_kwargs
                )
            else:
                raise ValueError(
                    f"Got invalid task {task}, "
                    f"currently only {VALID_TASKS} are supported"
                )
        except ImportError as e:
            raise ValueError(
                f"Could not load the {task} model due to missing dependencies."
            ) from e

        if importlib.util.find_spec("torch") is not None:
            import torch

            cuda_device_count = torch.cuda.device_count()
            if device < -1 or (device >= cuda_device_count):
                raise ValueError(
                    f"Got device=={device}, "
                    f"device is required to be within [-1, {cuda_device_count})"
                )
            if device < 0 and cuda_device_count > 0:
                # BUGFIX: the original message concatenated
                # "...to use available" + "GPUs..." with no separating space,
                # logging "availableGPUs". Trailing space restored below.
                logger.warning(
                    "Device has %d GPUs available. "
                    "Provide device={deviceId} to `from_model_id` to use available "
                    "GPUs for execution. deviceId is -1 (default) for CPU and "
                    "can be a positive integer associated with CUDA device id.",
                    cuda_device_count,
                )
        if "trust_remote_code" in _model_kwargs:
            # trust_remote_code was already consumed by from_pretrained above;
            # presumably the pipeline factory rejects it in model_kwargs, so it
            # is stripped here — TODO confirm against hf_pipeline's signature.
            _model_kwargs = {
                k: v
                for k, v in _model_kwargs.items()
                if k != "trust_remote_code"
            }
        _pipeline_kwargs = pipeline_kwargs or {}
        pipeline = hf_pipeline(
            task=task,
            model=model,
            tokenizer=tokenizer,
            device=device,
            model_kwargs=_model_kwargs,
            **_pipeline_kwargs,
        )
        if pipeline.task not in VALID_TASKS:
            raise ValueError(
                f"Got invalid task {pipeline.task}, "
                f"currently only {VALID_TASKS} are supported"
            )
        return cls(
            pipeline=pipeline,
            model_id=model_id,
            model_kwargs=_model_kwargs,
            pipeline_kwargs=_pipeline_kwargs,
            **kwargs,
        )

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {
            "model_id": self.model_id,
            "model_kwargs": self.model_kwargs,
            "pipeline_kwargs": self.pipeline_kwargs,
        }

    @property
    def _llm_type(self) -> str:
        """Tag used by langchain to identify this LLM implementation."""
        return "huggingface_pipeline"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Run the wrapped pipeline on ``prompt`` and return the text output.

        NOTE(review): ``run_manager`` and ``**kwargs`` are accepted for
        interface compatibility but are not forwarded to the pipeline call.
        """
        response = self.pipeline(prompt)
        if self.pipeline.task == "text-generation":
            # Text generation return includes the starter text; strip the
            # prompt so only the continuation is returned.
            text = response[0]["generated_text"][len(prompt) :]
        elif self.pipeline.task == "text2text-generation":
            text = response[0]["generated_text"]
        elif self.pipeline.task == "summarization":
            text = response[0]["summary_text"]
        else:
            raise ValueError(
                f"Got invalid task {self.pipeline.task}, "
                f"currently only {VALID_TASKS} are supported"
            )
        if stop is not None:
            # This is a bit hacky, but I can't figure out a better way to enforce
            # stop tokens when making calls to huggingface_hub.
            text = enforce_stop_tokens(text, stop)
        return text
|
||||
|
||||
|
||||
##### TextGenerationPipeline
|
||||
|
||||
|
||||
class ReturnType(enum.Enum):
    """Output format selector consumed by TextGenerationPipeline.postprocess."""

    TENSORS = 0  # raw generated token ids, no decoding
    NEW_TEXT = 1  # decoded text with the prompt portion stripped
    FULL_TEXT = 2  # prompt text followed by the generated continuation
|
||||
|
||||
|
||||
@add_end_docstrings(PIPELINE_INIT_ARGS)
class TextGenerationPipeline(Pipeline):
    """
    Language generation pipeline using any `ModelWithLMHead`. This pipeline predicts the words that will follow a
    specified text prompt.

    Example:

    ```python
    >>> from transformers import pipeline

    >>> generator = pipeline(model="gpt2")
    >>> generator("I can't believe you did such a ", do_sample=False)
    [{'generated_text': "I can't believe you did such a icky thing to me. I'm so sorry. I'm so sorry. I'm so sorry. I'm so sorry. I'm so sorry. I'm so sorry. I'm so sorry. I"}]

    >>> # These parameters will return suggestions, and only the newly created text making it easier for prompting suggestions.
    >>> outputs = generator("My tart needs some", num_return_sequences=4, return_full_text=False)
    ```

    Learn more about the basics of using a pipeline in the [pipeline tutorial](../pipeline_tutorial)

    This language generation pipeline can currently be loaded from [`pipeline`] using the following task identifier:
    `"text-generation"`.

    The models that this pipeline can use are models that have been trained with an autoregressive language modeling
    objective, which includes the uni-directional models in the library (e.g. gpt2). See the list of available models
    on [huggingface.co/models](https://huggingface.co/models?filter=text-generation).
    """

    # Prefix text to help Transformer-XL and XLNet with short prompts as proposed by Aman Rusia
    # in https://github.com/rusiaaman/XLNet-gen#methodology
    # and https://medium.com/@amanrusia/xlnet-speaks-comparison-to-gpt-2-ea1a4e9ba39e

    XL_PREFIX = """
    In 1991, the remains of Russian Tsar Nicholas II and his family (except for Alexei and Maria) are discovered. The
    voice of Nicholas's young son, Tsarevich Alexei Nikolaevich, narrates the remainder of the story. 1883 Western
    Siberia, a young Grigori Rasputin is asked by his father and a group of men to perform magic. Rasputin has a vision
    and denounces one of the men as a horse thief. Although his father initially slaps him for making such an
    accusation, Rasputin watches as the man is chased outside and beaten. Twenty years later, Rasputin sees a vision of
    the Virgin Mary, prompting him to become a priest. Rasputin quickly becomes famous, with people, even a bishop,
    begging for his blessing. <eod> </s> <eos>
    """

    def __init__(self, *args, **kwargs):
        # Reject checkpoints that are not causal-LM heads for the active
        # framework (TF vs PT).
        super().__init__(*args, **kwargs)
        self.check_model_type(
            TF_MODEL_FOR_CAUSAL_LM_MAPPING
            if self.framework == "tf"
            else MODEL_FOR_CAUSAL_LM_MAPPING
        )
        if "prefix" not in self._preprocess_params:
            # This is very specific. The logic is quite complex and needs to be done
            # as a "default".
            # It also defines both some preprocess_kwargs and generate_kwargs
            # which is why we cannot put them in their respective methods.
            prefix = None
            if self.model.config.prefix is not None:
                prefix = self.model.config.prefix
            if prefix is None and self.model.__class__.__name__ in [
                "XLNetLMHeadModel",
                "TransfoXLLMHeadModel",
                "TFXLNetLMHeadModel",
                "TFTransfoXLLMHeadModel",
            ]:
                # For XLNet and TransformerXL we add an article to the prompt to give more state to the model.
                prefix = self.XL_PREFIX
            if prefix is not None:
                # Recalculate some generate_kwargs linked to prefix.
                (
                    preprocess_params,
                    forward_params,
                    _,
                ) = self._sanitize_parameters(
                    prefix=prefix, **self._forward_params
                )
                self._preprocess_params = {
                    **self._preprocess_params,
                    **preprocess_params,
                }
                self._forward_params = {
                    **self._forward_params,
                    **forward_params,
                }

    def _sanitize_parameters(
        self,
        return_full_text=None,
        return_tensors=None,
        return_text=None,
        return_type=None,
        clean_up_tokenization_spaces=None,
        prefix=None,
        handle_long_generation=None,
        stop_sequence=None,
        **generate_kwargs,
    ):
        """Split caller kwargs into (preprocess, forward, postprocess) params.

        Also resolves the mutually exclusive return_* flags into a single
        ReturnType and translates stop_sequence into an eos_token_id.
        """
        preprocess_params = {}
        if prefix is not None:
            preprocess_params["prefix"] = prefix
        if prefix:
            # Tokenize the prefix once so _forward can subtract its length
            # from the generation budget.
            prefix_inputs = self.tokenizer(
                prefix,
                padding=False,
                add_special_tokens=False,
                return_tensors=self.framework,
            )
            generate_kwargs["prefix_length"] = prefix_inputs[
                "input_ids"
            ].shape[-1]

        if handle_long_generation is not None:
            if handle_long_generation not in {"hole"}:
                raise ValueError(
                    f"{handle_long_generation} is not a valid value for `handle_long_generation` parameter expected"
                    " [None, 'hole']"
                )
            preprocess_params[
                "handle_long_generation"
            ] = handle_long_generation

        preprocess_params.update(generate_kwargs)
        forward_params = generate_kwargs

        postprocess_params = {}
        if return_full_text is not None and return_type is None:
            if return_text is not None:
                raise ValueError(
                    "`return_text` is mutually exclusive with `return_full_text`"
                )
            if return_tensors is not None:
                raise ValueError(
                    "`return_full_text` is mutually exclusive with `return_tensors`"
                )
            return_type = (
                ReturnType.FULL_TEXT
                if return_full_text
                else ReturnType.NEW_TEXT
            )
        if return_tensors is not None and return_type is None:
            if return_text is not None:
                raise ValueError(
                    "`return_text` is mutually exclusive with `return_tensors`"
                )
            return_type = ReturnType.TENSORS
        if return_type is not None:
            postprocess_params["return_type"] = return_type
        if clean_up_tokenization_spaces is not None:
            postprocess_params[
                "clean_up_tokenization_spaces"
            ] = clean_up_tokenization_spaces

        if stop_sequence is not None:
            stop_sequence_ids = self.tokenizer.encode(
                stop_sequence, add_special_tokens=False
            )
            if len(stop_sequence_ids) > 1:
                warnings.warn(
                    "Stopping on a multiple token sequence is not yet supported on transformers. The first token of"
                    " the stop sequence will be used as the stop sequence string in the interim."
                )
            generate_kwargs["eos_token_id"] = stop_sequence_ids[0]

        return preprocess_params, forward_params, postprocess_params

    # overriding _parse_and_tokenize to allow for unusual language-modeling tokenizer arguments
    def _parse_and_tokenize(self, *args, **kwargs):
        """
        Parse arguments and tokenize
        """
        # Parse arguments
        if self.model.__class__.__name__ in ["TransfoXLLMHeadModel"]:
            kwargs.update({"add_space_before_punct_symbol": True})

        return super()._parse_and_tokenize(*args, **kwargs)

    def __call__(self, text_inputs, **kwargs):
        """
        Complete the prompt(s) given as inputs.

        Args:
            args (`str` or `List[str]`):
                One or several prompts (or one list of prompts) to complete.
            return_tensors (`bool`, *optional*, defaults to `False`):
                Whether or not to return the tensors of predictions (as token indices) in the outputs. If set to
                `True`, the decoded text is not returned.
            return_text (`bool`, *optional*, defaults to `True`):
                Whether or not to return the decoded texts in the outputs.
            return_full_text (`bool`, *optional*, defaults to `True`):
                If set to `False` only added text is returned, otherwise the full text is returned. Only meaningful if
                *return_text* is set to True.
            clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
                Whether or not to clean up the potential extra spaces in the text output.
            prefix (`str`, *optional*):
                Prefix added to prompt.
            handle_long_generation (`str`, *optional*):
                By default, this pipelines does not handle long generation (ones that exceed in one form or the other
                the model maximum length). There is no perfect way to adress this (more info
                :https://github.com/huggingface/transformers/issues/14033#issuecomment-948385227). This provides common
                strategies to work around that problem depending on your use case.

                - `None` : default strategy where nothing in particular happens
                - `"hole"`: Truncates left of input, and leaves a gap wide enough to let generation happen (might
                  truncate a lot of the prompt and not suitable when generation exceed the model capacity)

            generate_kwargs:
                Additional keyword arguments to pass along to the generate method of the model (see the generate method
                corresponding to your framework [here](./model#generative-models)).

        Return:
            A list or a list of list of `dict`: Returns one of the following dictionaries (cannot return a combination
            of both `generated_text` and `generated_token_ids`):

            - **generated_text** (`str`, present when `return_text=True`) -- The generated text.
            - **generated_token_ids** (`torch.Tensor` or `tf.Tensor`, present when `return_tensors=True`) -- The token
              ids of the generated text.
        """
        return super().__call__(text_inputs, **kwargs)

    def preprocess(
        self,
        prompt_text,
        prefix="",
        handle_long_generation=None,
        **generate_kwargs,
    ):
        """Tokenize prefix + prompt; optionally left-truncate for "hole" mode."""
        inputs = self.tokenizer(
            prefix + prompt_text,
            padding=False,
            add_special_tokens=False,
            return_tensors=self.framework,
        )
        # Carried through _forward to postprocess so FULL_TEXT can re-attach
        # the original prompt.
        inputs["prompt_text"] = prompt_text

        if handle_long_generation == "hole":
            cur_len = inputs["input_ids"].shape[-1]
            if "max_new_tokens" in generate_kwargs:
                new_tokens = generate_kwargs["max_new_tokens"]
            else:
                new_tokens = (
                    generate_kwargs.get(
                        "max_length", self.model.config.max_length
                    )
                    - cur_len
                )
            if new_tokens < 0:
                raise ValueError(
                    "We cannot infer how many new tokens are expected"
                )
            if cur_len + new_tokens > self.tokenizer.model_max_length:
                keep_length = self.tokenizer.model_max_length - new_tokens
                if keep_length <= 0:
                    raise ValueError(
                        "We cannot use `hole` to handle this generation the number of desired tokens exceeds the"
                        " models max length"
                    )

                # Keep only the rightmost tokens so prompt + generation fits.
                inputs["input_ids"] = inputs["input_ids"][:, -keep_length:]
                if "attention_mask" in inputs:
                    inputs["attention_mask"] = inputs["attention_mask"][
                        :, -keep_length:
                    ]

        return inputs

    def _forward(self, model_inputs, **generate_kwargs):
        """Run model.generate and reshape the output per input row."""
        input_ids = model_inputs["input_ids"]
        attention_mask = model_inputs.get("attention_mask", None)
        # Allow empty prompts
        if input_ids.shape[1] == 0:
            input_ids = None
            attention_mask = None
            in_b = 1
        else:
            in_b = input_ids.shape[0]
        prompt_text = model_inputs.pop("prompt_text")

        # If there is a prefix, we may need to adjust the generation length. Do so without permanently modifying
        # generate_kwargs, as some of the parameterization may come from the initialization of the pipeline.
        prefix_length = generate_kwargs.pop("prefix_length", 0)
        if prefix_length > 0:
            has_max_new_tokens = "max_new_tokens" in generate_kwargs or (
                "generation_config" in generate_kwargs
                and generate_kwargs["generation_config"].max_new_tokens
                is not None
            )
            if not has_max_new_tokens:
                generate_kwargs["max_length"] = (
                    generate_kwargs.get("max_length")
                    or self.model.config.max_length
                )
                generate_kwargs["max_length"] += prefix_length
            has_min_new_tokens = "min_new_tokens" in generate_kwargs or (
                "generation_config" in generate_kwargs
                and generate_kwargs["generation_config"].min_new_tokens
                is not None
            )
            if not has_min_new_tokens and "min_length" in generate_kwargs:
                generate_kwargs["min_length"] += prefix_length

        # BS x SL
        generated_sequence = self.model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **generate_kwargs,
        )
        out_b = generated_sequence.shape[0]
        if self.framework == "pt":
            # Group num_return_sequences outputs under their input row:
            # (in_b, out_b // in_b, seq_len, ...).
            generated_sequence = generated_sequence.reshape(
                in_b, out_b // in_b, *generated_sequence.shape[1:]
            )
        return {
            "generated_sequence": generated_sequence,
            "input_ids": input_ids,
            "prompt_text": prompt_text,
        }

    def postprocess(
        self,
        model_outputs,
        return_type=ReturnType.FULL_TEXT,
        clean_up_tokenization_spaces=True,
    ):
        """Decode generated sequences into records shaped per return_type."""
        generated_sequence = model_outputs["generated_sequence"][0]
        input_ids = model_outputs["input_ids"]
        prompt_text = model_outputs["prompt_text"]
        # NOTE(review): .numpy() assumes the sequence is on CPU — a CUDA
        # tensor would need .cpu() first; confirm against callers.
        generated_sequence = generated_sequence.numpy().tolist()
        records = []
        for sequence in generated_sequence:
            if return_type == ReturnType.TENSORS:
                record = {"generated_token_ids": sequence}
            elif return_type in {ReturnType.NEW_TEXT, ReturnType.FULL_TEXT}:
                # Decode text
                text = self.tokenizer.decode(
                    sequence,
                    skip_special_tokens=True,
                    clean_up_tokenization_spaces=clean_up_tokenization_spaces,
                )

                # Remove PADDING prompt of the sequence if XLNet or Transfo-XL model is used
                if input_ids is None:
                    prompt_length = 0
                else:
                    prompt_length = len(
                        self.tokenizer.decode(
                            input_ids[0],
                            skip_special_tokens=True,
                            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
                        )
                    )

                if return_type == ReturnType.FULL_TEXT:
                    all_text = prompt_text + text[prompt_length:]
                else:
                    all_text = text[prompt_length:]

                record = {"generated_text": all_text}
            records.append(record)

        return records
|
||||
@@ -1,4 +1,3 @@
|
||||
"""Load question answering chains."""
|
||||
from __future__ import annotations
|
||||
from typing import (
|
||||
Any,
|
||||
@@ -11,23 +10,34 @@ from typing import (
|
||||
Union,
|
||||
Protocol,
|
||||
)
|
||||
import inspect
|
||||
import json
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
import yaml
|
||||
from abc import ABC, abstractmethod
|
||||
import langchain
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.chains.question_answering import stuff_prompt
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.docstore.document import Document
|
||||
from abc import ABC, abstractmethod
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.callbacks.manager import (
|
||||
CallbackManager,
|
||||
CallbackManagerForChainRun,
|
||||
Callbacks,
|
||||
)
|
||||
from langchain.load.serializable import Serializable
|
||||
from langchain.schema import RUN_KEY, BaseMemory, RunInfo
|
||||
from langchain.input import get_colored_text
|
||||
from langchain.load.dump import dumpd
|
||||
from langchain.prompts.prompt import PromptTemplate
|
||||
from langchain.schema import LLMResult, PromptValue
|
||||
from pydantic import Extra, Field, root_validator
|
||||
from pydantic import Extra, Field, root_validator, validator
|
||||
|
||||
|
||||
def _get_verbosity() -> bool:
    # Read the process-wide `langchain.verbose` flag lazily, so it is used
    # as a pydantic default_factory rather than frozen at import time.
    return langchain.verbose
|
||||
|
||||
|
||||
def format_document(doc: Document, prompt: BasePromptTemplate) -> str:
|
||||
@@ -48,6 +58,257 @@ def format_document(doc: Document, prompt: BasePromptTemplate) -> str:
|
||||
return prompt.format(**document_info)
|
||||
|
||||
|
||||
# NOTE(review): this local Chain shadows the `Chain` imported above from
# langchain.chains.base — confirm the shadowing is intentional.
class Chain(Serializable, ABC):
    """Base interface that all chains should implement."""

    # Optional conversation memory; consulted in prep_inputs/prep_outputs.
    memory: Optional[BaseMemory] = None
    # Callbacks attached to every run of this chain.
    callbacks: Callbacks = Field(default=None, exclude=True)
    # Deprecated alias for `callbacks` (see raise_deprecation below).
    callback_manager: Optional[BaseCallbackManager] = Field(
        default=None, exclude=True
    )
    verbose: bool = Field(
        default_factory=_get_verbosity
    )  # Whether to print the response text
    tags: Optional[List[str]] = None

    class Config:
        """Configuration for this pydantic object."""

        arbitrary_types_allowed = True

    @property
    def _chain_type(self) -> str:
        raise NotImplementedError("Saving not supported for this chain type.")

    @root_validator()
    def raise_deprecation(cls, values: Dict) -> Dict:
        """Raise deprecation warning if callback_manager is used."""
        if values.get("callback_manager") is not None:
            warnings.warn(
                "callback_manager is deprecated. Please use callbacks instead.",
                DeprecationWarning,
            )
            # Migrate the deprecated value onto the new field.
            values["callbacks"] = values.pop("callback_manager", None)
        return values

    @validator("verbose", pre=True, always=True)
    def set_verbose(cls, verbose: Optional[bool]) -> bool:
        """If verbose is None, set it.

        This allows users to pass in None as verbose to access the global setting.
        """
        if verbose is None:
            return _get_verbosity()
        else:
            return verbose

    @property
    @abstractmethod
    def input_keys(self) -> List[str]:
        """Input keys this chain expects."""

    @property
    @abstractmethod
    def output_keys(self) -> List[str]:
        """Output keys this chain expects."""

    def _validate_inputs(self, inputs: Dict[str, Any]) -> None:
        """Check that all inputs are present."""
        missing_keys = set(self.input_keys).difference(inputs)
        if missing_keys:
            raise ValueError(f"Missing some input keys: {missing_keys}")

    def _validate_outputs(self, outputs: Dict[str, Any]) -> None:
        # Check that the subclass produced every declared output key.
        missing_keys = set(self.output_keys).difference(outputs)
        if missing_keys:
            raise ValueError(f"Missing some output keys: {missing_keys}")

    @abstractmethod
    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Run the logic of this chain and return the output."""

    def __call__(
        self,
        inputs: Union[Dict[str, Any], Any],
        return_only_outputs: bool = False,
        callbacks: Callbacks = None,
        *,
        tags: Optional[List[str]] = None,
        include_run_info: bool = False,
    ) -> Dict[str, Any]:
        """Run the logic of this chain and add to output if desired.

        Args:
            inputs: Dictionary of inputs, or single input if chain expects
                only one param.
            return_only_outputs: boolean for whether to return only outputs in the
                response. If True, only new keys generated by this chain will be
                returned. If False, both input keys and new keys generated by this
                chain will be returned. Defaults to False.
            callbacks: Callbacks to use for this chain run. If not provided, will
                use the callbacks provided to the chain.
            include_run_info: Whether to include run info in the response. Defaults
                to False.
        """
        inputs = self.prep_inputs(inputs)
        callback_manager = CallbackManager.configure(
            callbacks, self.callbacks, self.verbose, tags, self.tags
        )
        # Older _call implementations may not accept run_manager; detect it.
        new_arg_supported = inspect.signature(self._call).parameters.get(
            "run_manager"
        )
        run_manager = callback_manager.on_chain_start(
            dumpd(self),
            inputs,
        )
        try:
            outputs = (
                self._call(inputs, run_manager=run_manager)
                if new_arg_supported
                else self._call(inputs)
            )
        except (KeyboardInterrupt, Exception) as e:
            # Report the failure to callbacks before propagating.
            run_manager.on_chain_error(e)
            raise e
        run_manager.on_chain_end(outputs)
        final_outputs: Dict[str, Any] = self.prep_outputs(
            inputs, outputs, return_only_outputs
        )
        if include_run_info:
            final_outputs[RUN_KEY] = RunInfo(run_id=run_manager.run_id)
        return final_outputs

    def prep_outputs(
        self,
        inputs: Dict[str, str],
        outputs: Dict[str, str],
        return_only_outputs: bool = False,
    ) -> Dict[str, str]:
        """Validate and prep outputs."""
        self._validate_outputs(outputs)
        if self.memory is not None:
            # Persist this turn into memory before returning.
            self.memory.save_context(inputs, outputs)
        if return_only_outputs:
            return outputs
        else:
            return {**inputs, **outputs}

    def prep_inputs(
        self, inputs: Union[Dict[str, Any], Any]
    ) -> Dict[str, str]:
        """Validate and prep inputs."""
        if not isinstance(inputs, dict):
            _input_keys = set(self.input_keys)
            if self.memory is not None:
                # If there are multiple input keys, but some get set by memory so that
                # only one is not set, we can still figure out which key it is.
                _input_keys = _input_keys.difference(
                    self.memory.memory_variables
                )
            if len(_input_keys) != 1:
                raise ValueError(
                    f"A single string input was passed in, but this chain expects "
                    f"multiple inputs ({_input_keys}). When a chain expects "
                    f"multiple inputs, please call it by passing in a dictionary, "
                    "eg `chain({'foo': 1, 'bar': 2})`"
                )
            inputs = {list(_input_keys)[0]: inputs}
        if self.memory is not None:
            # Merge memory-provided variables into the call inputs.
            external_context = self.memory.load_memory_variables(inputs)
            inputs = dict(inputs, **external_context)
        self._validate_inputs(inputs)
        return inputs

    def apply(
        self, input_list: List[Dict[str, Any]], callbacks: Callbacks = None
    ) -> List[Dict[str, str]]:
        """Call the chain on all inputs in the list."""
        return [self(inputs, callbacks=callbacks) for inputs in input_list]

    def run(
        self,
        *args: Any,
        callbacks: Callbacks = None,
        tags: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> str:
        """Run the chain as text in, text out or multiple variables, text out."""
        if len(self.output_keys) != 1:
            raise ValueError(
                f"`run` not supported when there is not exactly "
                f"one output key. Got {self.output_keys}."
            )

        if args and not kwargs:
            if len(args) != 1:
                raise ValueError(
                    "`run` supports only one positional argument."
                )
            return self(args[0], callbacks=callbacks, tags=tags)[
                self.output_keys[0]
            ]

        if kwargs and not args:
            return self(kwargs, callbacks=callbacks, tags=tags)[
                self.output_keys[0]
            ]

        if not kwargs and not args:
            raise ValueError(
                "`run` supported with either positional arguments or keyword arguments,"
                " but none were provided."
            )

        raise ValueError(
            f"`run` supported with either positional arguments or keyword arguments"
            f" but not both. Got args: {args} and kwargs: {kwargs}."
        )

    def dict(self, **kwargs: Any) -> Dict:
        """Return dictionary representation of chain."""
        if self.memory is not None:
            raise ValueError("Saving of memory is not yet supported.")
        _dict = super().dict()
        _dict["_type"] = self._chain_type
        return _dict

    def save(self, file_path: Union[Path, str]) -> None:
        """Save the chain.

        Args:
            file_path: Path to file to save the chain to.

        Example:
        .. code-block:: python

            chain.save(file_path="path/chain.yaml")
        """
        # Convert file to Path object.
        if isinstance(file_path, str):
            save_path = Path(file_path)
        else:
            save_path = file_path

        directory_path = save_path.parent
        directory_path.mkdir(parents=True, exist_ok=True)

        # Fetch dictionary to save
        chain_dict = self.dict()

        # Serialization format is chosen from the file suffix.
        if save_path.suffix == ".json":
            with open(file_path, "w") as f:
                json.dump(chain_dict, f, indent=4)
        elif save_path.suffix == ".yaml":
            with open(file_path, "w") as f:
                yaml.dump(chain_dict, f, default_flow_style=False)
        else:
            raise ValueError(f"{save_path} must be json or yaml")
|
||||
|
||||
|
||||
class BaseCombineDocumentsChain(Chain, ABC):
|
||||
"""Base interface for chains combining documents."""
|
||||
|
||||
@@ -79,12 +340,6 @@ class BaseCombineDocumentsChain(Chain, ABC):
|
||||
"""
|
||||
return None
|
||||
|
||||
@abstractmethod
|
||||
def combine_docs(
|
||||
self, docs: List[Document], **kwargs: Any
|
||||
) -> Tuple[str, dict]:
|
||||
"""Combine documents into a single string."""
|
||||
|
||||
def _call(
|
||||
self,
|
||||
inputs: Dict[str, List[Document]],
|
||||
@@ -96,9 +351,27 @@ class BaseCombineDocumentsChain(Chain, ABC):
|
||||
docs = inputs[self.input_key]
|
||||
# Other keys are assumed to be needed for LLM prediction
|
||||
other_keys = {k: v for k, v in inputs.items() if k != self.input_key}
|
||||
output, extra_return_dict = self.combine_docs(
|
||||
docs, callbacks=_run_manager.get_child(), **other_keys
|
||||
doc_strings = [
|
||||
format_document(doc, self.document_prompt) for doc in docs
|
||||
]
|
||||
# Join the documents together to put them in the prompt.
|
||||
inputs = {
|
||||
k: v
|
||||
for k, v in other_keys.items()
|
||||
if k in self.llm_chain.prompt.input_variables
|
||||
}
|
||||
inputs[self.document_variable_name] = self.document_separator.join(
|
||||
doc_strings
|
||||
)
|
||||
|
||||
# Call predict on the LLM.
|
||||
output, extra_return_dict = (
|
||||
self.llm_chain(inputs, callbacks=_run_manager.get_child())[
|
||||
self.llm_chain.output_key
|
||||
],
|
||||
{},
|
||||
)
|
||||
|
||||
extra_return_dict[self.output_key] = output
|
||||
return extra_return_dict
|
||||
|
||||
@@ -153,21 +426,13 @@ class LLMChain(Chain):
|
||||
inputs: Dict[str, Any],
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, str]:
|
||||
response = self.generate([inputs], run_manager=run_manager)
|
||||
return self.create_outputs(response)[0]
|
||||
|
||||
def generate(
|
||||
self,
|
||||
input_list: List[Dict[str, Any]],
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
) -> LLMResult:
|
||||
"""Generate LLM result from inputs."""
|
||||
prompts, stop = self.prep_prompts(input_list, run_manager=run_manager)
|
||||
return self.llm.generate_prompt(
|
||||
prompts, stop = self.prep_prompts([inputs], run_manager=run_manager)
|
||||
response = self.llm.generate_prompt(
|
||||
prompts,
|
||||
stop,
|
||||
callbacks=run_manager.get_child() if run_manager else None,
|
||||
)
|
||||
return self.create_outputs(response)[0]
|
||||
|
||||
def prep_prompts(
|
||||
self,
|
||||
@@ -223,23 +488,6 @@ class LLMChain(Chain):
|
||||
for generation in response.generations
|
||||
]
|
||||
|
||||
def predict(self, callbacks: Callbacks = None, **kwargs: Any) -> str:
|
||||
"""Format prompt with kwargs and pass to LLM.
|
||||
|
||||
Args:
|
||||
callbacks: Callbacks to pass to LLMChain
|
||||
**kwargs: Keys to pass to prompt template.
|
||||
|
||||
Returns:
|
||||
Completion from LLM.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
completion = llm.predict(adjective="funny")
|
||||
"""
|
||||
return self(kwargs, callbacks=callbacks)[self.output_key]
|
||||
|
||||
def predict_and_parse(
|
||||
self, callbacks: Callbacks = None, **kwargs: Any
|
||||
) -> Union[str, List[str], Dict[str, Any]]:
|
||||
@@ -350,14 +598,6 @@ class StuffDocumentsChain(BaseCombineDocumentsChain):
|
||||
prompt = self.llm_chain.prompt.format(**inputs)
|
||||
return self.llm_chain.llm.get_num_tokens(prompt)
|
||||
|
||||
def combine_docs(
|
||||
self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
|
||||
) -> Tuple[str, dict]:
|
||||
"""Stuff all documents into one prompt and pass to LLM."""
|
||||
inputs = self._get_inputs(docs, **kwargs)
|
||||
# Call predict on the LLM.
|
||||
return self.llm_chain.predict(callbacks=callbacks, **inputs), {}
|
||||
|
||||
@property
|
||||
def _chain_type(self) -> str:
|
||||
return "stuff_documents_chain"
|
||||
|
||||
@@ -968,7 +968,7 @@ def get_llm(
|
||||
# not built in prompt removal that is less general and not specific for our model
|
||||
pipe.task = "text2text-generation"
|
||||
|
||||
from langchain.llms import HuggingFacePipeline
|
||||
from exp_hf_pipelines import HuggingFacePipeline
|
||||
|
||||
llm = HuggingFacePipeline(pipeline=pipe)
|
||||
return llm, model_name, streamer, prompt_type
|
||||
|
||||
@@ -3,13 +3,10 @@ from apps.stable_diffusion.src.utils.utils import _compile_module
|
||||
from io import BytesIO
|
||||
import torch_mlir
|
||||
|
||||
from transformers import TextGenerationPipeline
|
||||
from transformers.pipelines.text_generation import ReturnType
|
||||
|
||||
from stopping import get_stopping
|
||||
from prompter import Prompter, PromptType
|
||||
|
||||
|
||||
from exp_hf_pipelines import TextGenerationPipeline, ReturnType
|
||||
from transformers.generation import (
|
||||
GenerationConfig,
|
||||
LogitsProcessorList,
|
||||
|
||||
Reference in New Issue
Block a user