Merge branch 'develop' into feature/dataset-customization

This commit is contained in:
Alex O'Connell
2024-02-13 20:23:25 -05:00
17 changed files with 67 additions and 30 deletions


@@ -208,6 +208,7 @@ It is highly recommended to set up text-generation-webui on a separate machine tha
## Version History
| Version | Description |
| ------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| v0.2.6 | Bug fixes, add options for limiting chat history, HTTPS endpoint support, added zephyr prompt format. |
| v0.2.5 | Fix Ollama max tokens parameter, fix GGUF download from Hugging Face, update included llama-cpp-python to 0.2.32, add parameters to function calling for dataset + component, and a model update |
| v0.2.4 | Fix API key auth on model load for text-generation-webui, and add support for Ollama API backend |
| v0.2.3 | Fix API key auth, Support chat completion endpoint, and refactor to make it easier to add more remote backends |


@@ -22,6 +22,7 @@
- [ ] figure out DPO for refusals + fixing incorrect entity id
- [ ] mixtral + prompting (no fine tuning)
- [ ] use varied system prompts to add behaviors
- [ ] setup github actions to build wheels that are optimized for RPIs
## more complicated ideas
- [ ] "context requests"
@@ -32,4 +33,4 @@
- [ ] RAG for getting info for setting up new devices
- set up vectordb
- ingest home assistant docs
- "context request" from above to initiate a RAG search
- "context request" from above to initiate a RAG search


@@ -23,10 +23,9 @@ RUN \
python3-venv \
python3-pip \
\
&& git clone https://github.com/oobabooga/text-generation-webui.git ${APP_DIR} --branch snapshot-2024-01-28 \
&& git clone https://github.com/oobabooga/text-generation-webui.git ${APP_DIR} --branch snapshot-2024-02-11 \
&& python3 -m pip install torch torchvision torchaudio py-cpuinfo==9.0.0 \
&& python3 -m pip install -r ${APP_DIR}/requirements_cpu_only_noavx2.txt -r ${APP_DIR}/extensions/openai/requirements.txt llama-cpp-python \
&& python3 -m pip install llama-cpp-python==0.2.32 \
&& apt-get purge -y --auto-remove \
git \
build-essential \


@@ -1,6 +1,6 @@
---
name: oobabooga-text-generation-webui
version: 2024.01.28
version: 2024.02.11
slug: text-generation-webui
description: "A tool for running Large Language Models"
url: "https://github.com/oobabooga/text-generation-webui"


@@ -21,7 +21,7 @@ from homeassistant.components.homeassistant.exposed_entities import (
async_should_expose,
)
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import ATTR_ENTITY_ID, CONF_HOST, CONF_PORT, MATCH_ALL
from homeassistant.const import ATTR_ENTITY_ID, CONF_HOST, CONF_PORT, CONF_SSL, MATCH_ALL
from homeassistant.core import (
HomeAssistant,
ServiceCall,
@@ -54,6 +54,8 @@ from .const import (
CONF_OPENAI_API_KEY,
CONF_TEXT_GEN_WEBUI_ADMIN_KEY,
CONF_REFRESH_SYSTEM_PROMPT,
CONF_REMEMBER_CONVERSATION,
CONF_REMEMBER_NUM_INTERACTIONS,
CONF_SERVICE_CALL_REGEX,
CONF_REMOTE_USE_CHAT_ENDPOINT,
CONF_TEXT_GEN_WEBUI_CHAT_MODE,
@@ -68,6 +70,7 @@ from .const import (
DEFAULT_PROMPT_TEMPLATE,
DEFAULT_USE_GBNF_GRAMMAR,
DEFAULT_REFRESH_SYSTEM_PROMPT,
DEFAULT_REMEMBER_CONVERSATION,
DEFAULT_SERVICE_CALL_REGEX,
DEFAULT_REMOTE_USE_CHAT_ENDPOINT,
DEFAULT_TEXT_GEN_WEBUI_CHAT_MODE,
@@ -179,7 +182,7 @@ def flatten_schema(schema):
if isinstance(current_schema.schema, vol.validators._WithSubValidators):
for subval in current_schema.schema.validators:
_flatten(subval, prefix)
else:
elif isinstance(current_schema.schema, dict):
for key, val in current_schema.schema.items():
_flatten(val, prefix + str(key) + '/')
elif isinstance(current_schema, vol.validators._WithSubValidators):
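
The `elif isinstance(current_schema.schema, dict)` guard above replaces a bare `else:` because a voluptuous `Schema` can wrap a plain validator instead of a mapping, in which case `.schema` has no `.items()`. A minimal sketch of the failure mode the change avoids (assuming stock voluptuous behavior):

```python
import voluptuous as vol

mapping = vol.Schema({vol.Required("host"): str})
plain = vol.Schema(str)  # wraps a bare validator, not a dict

print(isinstance(mapping.schema, dict))  # True  -> safe to iterate .items()
print(isinstance(plain.schema, dict))    # False -> the old else-branch would
                                         # have called .items() on `str` and raised
```
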
@@ -236,6 +239,8 @@ class LLaMAAgent(AbstractConversationAgent):
raw_prompt = self.entry.options.get(CONF_PROMPT, DEFAULT_PROMPT)
refresh_system_prompt = self.entry.options.get(CONF_REFRESH_SYSTEM_PROMPT, DEFAULT_REFRESH_SYSTEM_PROMPT)
remember_conversation = self.entry.options.get(CONF_REMEMBER_CONVERSATION, DEFAULT_REMEMBER_CONVERSATION)
remember_num_interactions = self.entry.options.get(CONF_REMEMBER_NUM_INTERACTIONS, False)
service_call_regex = self.entry.options.get(CONF_SERVICE_CALL_REGEX, DEFAULT_SERVICE_CALL_REGEX)
extra_attributes_to_expose = self.entry.options \
.get(CONF_EXTRA_ATTRIBUTES_TO_EXPOSE, DEFAULT_EXTRA_ATTRIBUTES_TO_EXPOSE)
@@ -256,7 +261,7 @@ class LLaMAAgent(AbstractConversationAgent):
if user_input.conversation_id in self.history:
conversation_id = user_input.conversation_id
conversation = self.history[conversation_id]
conversation = self.history[conversation_id] if remember_conversation else [self.history[conversation_id][0]]
else:
conversation_id = ulid.ulid()
conversation = []
@@ -279,6 +284,8 @@ class LLaMAAgent(AbstractConversationAgent):
if len(conversation) == 0:
conversation.append(system_prompt)
if not remember_conversation:
self.history[conversation_id] = conversation
else:
conversation[0] = system_prompt
@@ -302,7 +309,11 @@ class LLaMAAgent(AbstractConversationAgent):
)
conversation.append({"role": "assistant", "message": response})
self.history[conversation_id] = conversation
if remember_conversation:
if remember_num_interactions and len(conversation) > (remember_num_interactions * 2) + 1:
for i in range(0,2):
conversation.pop(1)
self.history[conversation_id] = conversation
exposed_entities = list(self._async_get_exposed_entities()[0].keys())
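
The trimming above caps the stored history at `2 * N + 1` messages: the system prompt plus the `N` most recent user/assistant pairs, dropping the two oldest non-system messages whenever the cap is exceeded. A standalone sketch of the same arithmetic (hypothetical message shapes, outside Home Assistant):

```python
def trim_history(conversation: list[dict], num_interactions: int) -> list[dict]:
    """Keep the system prompt plus the most recent N user/assistant pairs."""
    max_len = num_interactions * 2 + 1  # system prompt + N pairs
    while len(conversation) > max_len:
        conversation.pop(1)  # oldest user message
        conversation.pop(1)  # its assistant reply
    return conversation

history = [{"role": "system", "message": "<system prompt>"}]
for turn in range(3):
    history.append({"role": "user", "message": f"question {turn}"})
    history.append({"role": "assistant", "message": f"answer {turn}"})

# With N=2 the pair from turn 0 is dropped, leaving 5 messages.
assert len(trim_history(history, 2)) == 5
```
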
@@ -545,8 +556,7 @@ class GenericOpenAIAPIAgent(LLaMAAgent):
model_name: str
def _load_model(self, entry: ConfigEntry) -> None:
# TODO: https
self.api_host = f"http://{entry.data[CONF_HOST]}:{entry.data[CONF_PORT]}"
self.api_host = f"{'https' if entry.data[CONF_SSL] else 'http'}://{entry.data[CONF_HOST]}:{entry.data[CONF_PORT]}"
self.api_key = entry.data.get(CONF_OPENAI_API_KEY)
self.model_name = entry.data.get(CONF_CHAT_MODEL)
@@ -647,8 +657,6 @@ class TextGenerationWebuiAgent(GenericOpenAIAPIAgent):
else:
_LOGGER.info(f"Model is not {self.model_name} loaded on the remote backend. Loading it now...")
load_result = requests.post(
f"{self.api_host}/v1/internal/model/load",
json={
@@ -675,7 +683,6 @@ class TextGenerationWebuiAgent(GenericOpenAIAPIAgent):
if preset:
request_params["character"] = preset
elif chat_mode == TEXT_GEN_WEBUI_CHAT_MODE_INSTRUCT:
# TODO: handle uppercase properly?
request_params["instruction_template"] = self.entry.options.get(CONF_PROMPT_TEMPLATE, DEFAULT_PROMPT_TEMPLATE)
return endpoint, request_params
@@ -739,8 +746,7 @@ class OllamaAPIAgent(LLaMAAgent):
model_name: str
def _load_model(self, entry: ConfigEntry) -> None:
# TODO: https
self.api_host = f"http://{entry.data[CONF_HOST]}:{entry.data[CONF_PORT]}"
self.api_host = f"{'https' if entry.data[CONF_SSL] else 'http'}://{entry.data[CONF_HOST]}:{entry.data[CONF_PORT]}"
self.api_key = entry.data.get(CONF_OPENAI_API_KEY)
self.model_name = entry.data.get(CONF_CHAT_MODEL)


@@ -19,7 +19,7 @@ from homeassistant import config_entries
from homeassistant.core import HomeAssistant
from homeassistant.requirements import pip_kwargs
from homeassistant.util.package import install_package, is_installed
from homeassistant.const import CONF_HOST, CONF_PORT
from homeassistant.const import CONF_HOST, CONF_PORT, CONF_SSL
from homeassistant.data_entry_flow import (
AbortFlow,
FlowHandler,
@@ -54,6 +54,8 @@ from .const import (
CONF_EXTRA_ATTRIBUTES_TO_EXPOSE,
CONF_TEXT_GEN_WEBUI_PRESET,
CONF_REFRESH_SYSTEM_PROMPT,
CONF_REMEMBER_CONVERSATION,
CONF_REMEMBER_NUM_INTERACTIONS,
CONF_OPENAI_API_KEY,
CONF_TEXT_GEN_WEBUI_ADMIN_KEY,
CONF_SERVICE_CALL_REGEX,
@@ -62,6 +64,7 @@ from .const import (
DEFAULT_CHAT_MODEL,
DEFAULT_HOST,
DEFAULT_PORT,
DEFAULT_SSL,
DEFAULT_MAX_TOKENS,
DEFAULT_PROMPT,
DEFAULT_TEMPERATURE,
@@ -74,6 +77,7 @@ from .const import (
DEFAULT_USE_GBNF_GRAMMAR,
DEFAULT_EXTRA_ATTRIBUTES_TO_EXPOSE,
DEFAULT_REFRESH_SYSTEM_PROMPT,
DEFAULT_REMEMBER_CONVERSATION,
DEFAULT_SERVICE_CALL_REGEX,
DEFAULT_OPTIONS,
DEFAULT_REMOTE_USE_CHAT_ENDPOINT,
@@ -84,11 +88,7 @@ from .const import (
BACKEND_TYPE_GENERIC_OPENAI,
BACKEND_TYPE_LLAMA_CPP_PYTHON_SERVER,
BACKEND_TYPE_OLLAMA,
PROMPT_TEMPLATE_CHATML,
PROMPT_TEMPLATE_ALPACA,
PROMPT_TEMPLATE_VICUNA,
PROMPT_TEMPLATE_MISTRAL,
PROMPT_TEMPLATE_NONE,
PROMPT_TEMPLATE_DESCRIPTIONS,
TEXT_GEN_WEBUI_CHAT_MODE_CHAT,
TEXT_GEN_WEBUI_CHAT_MODE_INSTRUCT,
TEXT_GEN_WEBUI_CHAT_MODE_CHAT_INSTRUCT,
@@ -136,7 +136,7 @@ def STEP_LOCAL_SETUP_DOWNLOAD_DATA_SCHEMA(*, chat_model=None, downloaded_model_q
}
)
def STEP_REMOTE_SETUP_DATA_SCHEMA(backend_type: str, *, host=None, port=None, chat_model=None, use_chat_endpoint=None, webui_preset="", webui_chat_mode=""):
def STEP_REMOTE_SETUP_DATA_SCHEMA(backend_type: str, *, host=None, port=None, ssl=None, chat_model=None, use_chat_endpoint=None, webui_preset="", webui_chat_mode=""):
extra1, extra2 = ({}, {})
default_port = DEFAULT_PORT
@@ -158,6 +158,7 @@ def STEP_REMOTE_SETUP_DATA_SCHEMA(backend_type: str, *, host=None, port=None, ch
{
vol.Required(CONF_HOST, default=host if host else DEFAULT_HOST): str,
vol.Required(CONF_PORT, default=port if port else default_port): str,
vol.Required(CONF_SSL, default=ssl if ssl else DEFAULT_SSL): bool,
vol.Required(CONF_CHAT_MODEL, default=chat_model if chat_model else DEFAULT_CHAT_MODEL): str,
vol.Required(CONF_REMOTE_USE_CHAT_ENDPOINT, default=use_chat_endpoint if use_chat_endpoint else DEFAULT_REMOTE_USE_CHAT_ENDPOINT): bool,
**extra1,
@@ -469,7 +470,7 @@ class ConfigFlow(BaseLlamaConversationConfigFlow, config_entries.ConfigFlow, dom
headers["Authorization"] = f"Bearer {api_key}"
models_result = requests.get(
f"http://{self.model_config[CONF_HOST]}:{self.model_config[CONF_PORT]}/v1/internal/model/list",
f"{'https' if self.model_config[CONF_SSL] else 'http'}://{self.model_config[CONF_HOST]}:{self.model_config[CONF_PORT]}/v1/internal/model/list",
headers=headers
)
models_result.raise_for_status()
@@ -508,6 +509,7 @@ class ConfigFlow(BaseLlamaConversationConfigFlow, config_entries.ConfigFlow, dom
backend_type,
host=user_input[CONF_HOST],
port=user_input[CONF_PORT],
ssl=user_input[CONF_SSL],
chat_model=user_input[CONF_CHAT_MODEL],
use_chat_endpoint=user_input[CONF_REMOTE_USE_CHAT_ENDPOINT],
webui_preset=user_input.get(CONF_TEXT_GEN_WEBUI_PRESET),
@@ -601,7 +603,7 @@ def local_llama_config_option_schema(options: MappingProxyType[str, Any], backen
description={"suggested_value": options.get(CONF_PROMPT_TEMPLATE)},
default=DEFAULT_PROMPT_TEMPLATE,
): SelectSelector(SelectSelectorConfig(
options=[PROMPT_TEMPLATE_CHATML, PROMPT_TEMPLATE_ALPACA, PROMPT_TEMPLATE_VICUNA, PROMPT_TEMPLATE_MISTRAL, PROMPT_TEMPLATE_NONE],
options=list(PROMPT_TEMPLATE_DESCRIPTIONS.keys()),
translation_key=CONF_PROMPT_TEMPLATE,
multiple=False,
mode=SelectSelectorMode.DROPDOWN,
@@ -626,6 +628,15 @@ def local_llama_config_option_schema(options: MappingProxyType[str, Any], backen
description={"suggested_value": options.get(CONF_REFRESH_SYSTEM_PROMPT)},
default=DEFAULT_REFRESH_SYSTEM_PROMPT,
): bool,
vol.Required(
CONF_REMEMBER_CONVERSATION,
description={"suggested_value": options.get(CONF_REMEMBER_CONVERSATION)},
default=DEFAULT_REMEMBER_CONVERSATION,
): bool,
vol.Optional(
CONF_REMEMBER_NUM_INTERACTIONS,
description={"suggested_value": options.get(CONF_REMEMBER_NUM_INTERACTIONS)},
): int,
}
if is_local_backend(backend_type):


@@ -34,6 +34,7 @@ CONF_DOWNLOADED_MODEL_FILE = "downloaded_model_file"
DEFAULT_DOWNLOADED_MODEL_FILE = ""
DEFAULT_HOST = "127.0.0.1"
DEFAULT_PORT = "5000"
DEFAULT_SSL = False
CONF_EXTRA_ATTRIBUTES_TO_EXPOSE = "extra_attributes_to_expose"
DEFAULT_EXTRA_ATTRIBUTES_TO_EXPOSE = ["rgb_color", "brightness", "temperature", "humidity", "fan_mode", "media_title", "volume_level"]
GBNF_GRAMMAR_FILE = "output.gbnf"
@@ -44,6 +45,7 @@ PROMPT_TEMPLATE_VICUNA = "vicuna"
PROMPT_TEMPLATE_MISTRAL = "mistral"
PROMPT_TEMPLATE_LLAMA2 = "llama2"
PROMPT_TEMPLATE_NONE = "no_prompt_template"
PROMPT_TEMPLATE_ZEPHYR = "zephyr"
DEFAULT_PROMPT_TEMPLATE = PROMPT_TEMPLATE_CHATML
PROMPT_TEMPLATE_DESCRIPTIONS = {
PROMPT_TEMPLATE_CHATML: {
@@ -75,6 +77,12 @@ PROMPT_TEMPLATE_DESCRIPTIONS = {
"user": { "prefix": "[INST]", "suffix": "[/INST]" },
"assistant": { "prefix": "", "suffix": "</s>" },
"generation_prompt": ""
},
PROMPT_TEMPLATE_ZEPHYR: {
"system": { "prefix": "<|system|>\n", "suffix": "<|endoftext|>" },
"user": { "prefix": "<|user|>\n", "suffix": "<|endoftext|>" },
"assistant": { "prefix": "<|assistant|>\n", "suffix": "<|endoftext|>" },
"generation_prompt": "<|assistant|>\n"
}
}
CONF_USE_GBNF_GRAMMAR = "gbnf_grammar"
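
The new `zephyr` entry follows the same shape as the other templates: per-role prefix/suffix strings plus a generation prompt. As an illustration only (not the integration's actual formatting code), a prompt could be assembled from that structure like this:

```python
ZEPHYR = {
    "system": {"prefix": "<|system|>\n", "suffix": "<|endoftext|>"},
    "user": {"prefix": "<|user|>\n", "suffix": "<|endoftext|>"},
    "assistant": {"prefix": "<|assistant|>\n", "suffix": "<|endoftext|>"},
    "generation_prompt": "<|assistant|>\n",
}

def format_prompt(messages: list[dict], template: dict) -> str:
    """Wrap each message in its role's prefix/suffix, then cue the assistant."""
    parts = [
        template[m["role"]]["prefix"] + m["message"] + template[m["role"]]["suffix"]
        for m in messages
    ]
    return "\n".join(parts) + "\n" + template["generation_prompt"]

print(format_prompt(
    [{"role": "system", "message": "You control the devices in a house."},
     {"role": "user", "message": "turn on the kitchen light"}],
    ZEPHYR,
))
```
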
@@ -83,7 +91,10 @@ CONF_TEXT_GEN_WEBUI_PRESET = "text_generation_webui_preset"
CONF_OPENAI_API_KEY = "openai_api_key"
CONF_TEXT_GEN_WEBUI_ADMIN_KEY = "text_generation_webui_admin_key"
CONF_REFRESH_SYSTEM_PROMPT = "refresh_prompt_per_tern"
CONF_REMEMBER_CONVERSATION = "remember_conversation"
CONF_REMEMBER_NUM_INTERACTIONS = "remember_num_interactions"
DEFAULT_REFRESH_SYSTEM_PROMPT = True
DEFAULT_REMEMBER_CONVERSATION = True
CONF_SERVICE_CALL_REGEX = "service_call_regex"
DEFAULT_SERVICE_CALL_REGEX = r"```homeassistant\n([\S \t\n]*?)```"
CONF_REMOTE_USE_CHAT_ENDPOINT = "remote_use_chat_endpoint"
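
For context, `DEFAULT_SERVICE_CALL_REGEX` above matches fenced `homeassistant` blocks in the model's response and captures their contents (the docs note it should contain one repeated capture group). A quick extraction sketch with a made-up response:

```python
import re

SERVICE_CALL_REGEX = r"```homeassistant\n([\S \t\n]*?)```"

response = (
    "Turning on the light now.\n"
    "```homeassistant\n"
    '{"service": "light.turn_on", "target_device": "light.kitchen"}\n'
    "```"
)

for block in re.findall(SERVICE_CALL_REGEX, response):
    print(block.strip())  # each captured service call payload
```
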
@@ -106,6 +117,7 @@ DEFAULT_OPTIONS = types.MappingProxyType(
CONF_USE_GBNF_GRAMMAR: DEFAULT_USE_GBNF_GRAMMAR,
CONF_EXTRA_ATTRIBUTES_TO_EXPOSE: DEFAULT_EXTRA_ATTRIBUTES_TO_EXPOSE,
CONF_REFRESH_SYSTEM_PROMPT: DEFAULT_REFRESH_SYSTEM_PROMPT,
CONF_REMEMBER_CONVERSATION: DEFAULT_REMEMBER_CONVERSATION,
CONF_SERVICE_CALL_REGEX: DEFAULT_SERVICE_CALL_REGEX,
CONF_REMOTE_USE_CHAT_ENDPOINT: DEFAULT_REMOTE_USE_CHAT_ENDPOINT,
CONF_TEXT_GEN_WEBUI_CHAT_MODE: DEFAULT_TEXT_GEN_WEBUI_CHAT_MODE,


@@ -1,7 +1,7 @@
{
"domain": "llama_conversation",
"name": "LLaMA Conversation",
"version": "0.2.5",
"version": "0.2.6",
"codeowners": ["@acon96"],
"config_flow": true,
"dependencies": ["conversation"],


@@ -29,6 +29,7 @@
"host": "API Hostname",
"huggingface_model": "Model Name",
"port": "API Port",
"ssl": "Use HTTPS",
"openai_api_key": "API Key",
"text_generation_webui_admin_key": "Admin Key",
"text_generation_webui_preset": "Generation Preset/Character Name",
@@ -65,6 +66,8 @@
"text_generation_webui_admin_key": "Admin Key",
"service_call_regex": "Service Call Regex",
"refresh_prompt_per_tern": "Refresh System Prompt Every Turn",
"remember_conversation": "Remember conversation",
"remember_num_interactions": "Number of past interactions to remember",
"text_generation_webui_preset": "Generation Preset/Character Name",
"remote_use_chat_endpoint": "Use chat completions endpoint",
"text_generation_webui_chat_mode": "Chat Mode"
@@ -79,6 +82,7 @@
"vicuna": "Vicuna",
"alpaca": "Alpaca",
"mistral": "Mistral",
"zephyr": "Zephyr",
"no_prompt_template": "None"
}
},


@@ -38,6 +38,7 @@ Supported datasets right now are:
Please note that the supported datasets all have different licenses. Be aware that the license of the resulting data mixture might be different from the license of this dataset alone.
## Adding new Home Assistant functionality
Adding new functionality to the model is done by providing examples of a user asking the assistant for the
TODO:
## Adding a new personality
TODO:

Binary file not shown.

Binary file not shown.

dist/run_docker.sh vendored

@@ -3,4 +3,4 @@
docker run -it --rm \
--entrypoint bash \
-v $(pwd):/tmp/dist \
homeassistant/home-assistant /tmp/dist/make_wheel.sh v0.2.32
homeassistant/home-assistant /tmp/dist/make_wheel.sh v0.2.38


@@ -11,6 +11,8 @@ There are multiple backends to choose for running the model that the Home Assist
| Additional attributes to expose in the context | Extra attributes that will be exposed to the model via the `{{ devices }}` template variable | |
| Service Call Regex | The regular expression used to extract service calls from the model response; should contain 1 repeated capture group | |
| Refresh System Prompt Every Turn | Flag to update the system prompt with updated device states on every chat turn. Disabling can significantly improve agent response times when using a backend that supports prefix caching (Llama.cpp) | Enabled |
| Remember conversation | Flag to remember the conversation history (excluding system prompt) in the model context. | Enabled |
| Number of past interactions to remember | If `Remember conversation` is enabled, number of user-assistant interaction pairs to keep in history. | |
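
For example, with `Number of past interactions to remember` set to 3, the context holds at most 2 × 3 + 1 = 7 messages: the system prompt plus the three most recent user/assistant pairs, with older pairs dropped oldest-first.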
# Llama.cpp
For details about the sampling parameters, see here: https://github.com/oobabooga/text-generation-webui/wiki/03-%E2%80%90-Parameters-Tab#parameters-description


@@ -5,7 +5,7 @@ This integration allows for full customization of the system prompt using Home A
## System Prompt Template
The default system prompt is:
```
You are 'Al', a helpful AI Assistant that controls the devices in a house. Complete the following task ask instructed with the information provided only.
You are 'Al', a helpful AI Assistant that controls the devices in a house. Complete the following task as instructed with the information provided only.
Services: {{ services }}
Devices:
{{ devices }}
@@ -23,4 +23,4 @@ Currently supported prompt formats are:
2. Vicuna
3. Alpaca
4. Mistral
5. None (useful for foundation models)