more rewrite work for new LLM API

Author: Alex O'Connell
Date: 2024-05-25 21:24:45 -04:00
parent 8a28dd61ad
commit 367607b14f
8 changed files with 65 additions and 87 deletions

View File

@@ -5,6 +5,7 @@ import logging
import threading
import importlib
from typing import Literal, Any, Callable
import voluptuous as vol
import requests
import re
@@ -15,6 +16,7 @@ import random
import time
from homeassistant.components.conversation import ConversationInput, ConversationResult, AbstractConversationAgent
import homeassistant.components.conversation as ha_conversation
from homeassistant.components.conversation.const import DOMAIN as CONVERSATION_DOMAIN
from homeassistant.components.homeassistant.exposed_entities import async_should_expose
from homeassistant.config_entries import ConfigEntry
@@ -39,7 +41,6 @@ from .const import (
CONF_BACKEND_TYPE,
CONF_DOWNLOADED_MODEL_FILE,
CONF_EXTRA_ATTRIBUTES_TO_EXPOSE,
CONF_ALLOWED_SERVICE_CALL_ARGUMENTS,
CONF_PROMPT_TEMPLATE,
CONF_ENABLE_FLASH_ATTENTION,
CONF_USE_GBNF_GRAMMAR,
@@ -74,7 +75,6 @@ from .const import (
DEFAULT_BACKEND_TYPE,
DEFAULT_REQUEST_TIMEOUT,
DEFAULT_EXTRA_ATTRIBUTES_TO_EXPOSE,
DEFAULT_ALLOWED_SERVICE_CALL_ARGUMENTS,
DEFAULT_PROMPT_TEMPLATE,
DEFAULT_ENABLE_FLASH_ATTENTION,
DEFAULT_USE_GBNF_GRAMMAR,
@@ -209,8 +209,6 @@ class LocalLLMAgent(AbstractConversationAgent):
remember_conversation = self.entry.options.get(CONF_REMEMBER_CONVERSATION, DEFAULT_REMEMBER_CONVERSATION)
remember_num_interactions = self.entry.options.get(CONF_REMEMBER_NUM_INTERACTIONS, DEFAULT_REMEMBER_NUM_INTERACTIONS)
service_call_regex = self.entry.options.get(CONF_SERVICE_CALL_REGEX, DEFAULT_SERVICE_CALL_REGEX)
allowed_service_call_arguments = self.entry.options \
.get(CONF_ALLOWED_SERVICE_CALL_ARGUMENTS, DEFAULT_ALLOWED_SERVICE_CALL_ARGUMENTS)
try:
service_call_pattern = re.compile(service_call_regex)
@@ -302,72 +300,49 @@ class LocalLLMAgent(AbstractConversationAgent):
# parse response
exposed_entities = list(self._async_get_exposed_entities()[0].keys())
to_say = service_call_pattern.sub("", response).strip()
to_say = ""
for block in service_call_pattern.findall(response.strip()):
services = block.split("\n")
_LOGGER.info(f"running services: {' '.join(services)}")
_LOGGER.info(f"calling tool: {block}")
for line in services:
if len(line) == 0:
break
parsed_tool_call = json.loads(block)
# parse old format or JSON format
try:
json_output = json.loads(line)
service = json_output["service"]
entity = json_output["target_device"]
domain, service = tuple(service.split("."))
if "to_say" in json_output:
to_say = to_say + json_output.pop("to_say")
to_say = to_say + parsed_tool_call.get("to_say", "")
extra_arguments = { k: v for k, v in json_output.items() if k not in [ "service", "target_device" ] }
except Exception:
try:
service = line.split("(")[0]
entity = line.split("(")[1][:-1]
domain, service = tuple(service.split("."))
extra_arguments = {}
except Exception:
to_say += f" Failed to parse call from '{line}'!"
continue
# try to fix certain arguments
# make sure brightness is 0-255 and not a percentage
if "brightness" in parsed_tool_call["arguments"] and 0.0 < parsed_tool_call["arguments"]["brightness"] <= 1.0:
parsed_tool_call["arguments"]["brightness"] = int(parsed_tool_call["arguments"]["brightness"] * 255)
# fix certain arguments
# make sure brightness is 0-255 and not a percentage
if "brightness" in extra_arguments and 0.0 < extra_arguments["brightness"] <= 1.0:
extra_arguments["brightness"] = int(extra_arguments["brightness"] * 255)
# convert string "tuple" to a list for RGB colors
if "rgb_color" in parsed_tool_call["arguments"] and isinstance(parsed_tool_call["arguments"]["rgb_color"], str):
parsed_tool_call["arguments"]["rgb_color"] = [ int(x) for x in parsed_tool_call["arguments"]["rgb_color"][1:-1].split(",") ]
tool_input = llm.ToolInput(
tool_name=parsed_tool_call["tool"],
tool_args=parsed_tool_call["arguments"],
platform=DOMAIN,
context=user_input.context,
user_prompt=user_input.text,
language=user_input.language,
assistant=ha_conversation.DOMAIN,
)
# convert string "tuple" to a list for RGB colors
if "rgb_color" in extra_arguments and isinstance(extra_arguments["rgb_color"], str):
extra_arguments["rgb_color"] = [ int(x) for x in extra_arguments["rgb_color"][1:-1].split(",") ]
# TODO: multi-turn with the model where it acts on the response from the tool?
try:
tool_response = await llm_api.async_call_tool(
self.hass, tool_input
)
except (HomeAssistantError, vol.Invalid) as e:
tool_response = {"error": type(e).__name__}
if str(e):
tool_response["error_text"] = str(e)
# only acknowledge requests to exposed entities
if entity not in exposed_entities:
to_say += f" Can't find device '{entity}'!"
else:
# copy arguments to service call
service_data = {ATTR_ENTITY_ID: entity}
for attr in allowed_service_call_arguments:
if attr in extra_arguments.keys():
service_data[attr] = extra_arguments[attr]
try:
_LOGGER.debug(f"service data: {service_data}")
await self.hass.services.async_call(
domain,
service,
service_data=service_data,
blocking=True,
)
except Exception as err:
to_say += f"\nFailed to run: {line}"
_LOGGER.exception(f"Failed to run: {line}")
if template_desc["assistant"]["suffix"]:
to_say = to_say.replace(template_desc["assistant"]["suffix"], "") # remove the eos token if it is returned (some backends + the old model does this)
_LOGGER.debug("Tool response: %s", tool_response)
# generate intent response to Home Assistant
intent_response = intent.IntentResponse(language=user_input.language)
intent_response.async_set_speech(to_say)
intent_response.set
return ConversationResult(
response=intent_response, conversation_id=conversation_id
)
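The hunk above swaps the old service/target_device parsing for the new Home Assistant LLM API: each matched block is decoded as a single JSON tool call, its arguments are normalized, and it is forwarded as an llm.ToolInput to llm_api.async_call_tool. A minimal, self-contained sketch of that parsing and the argument fix-ups (hypothetical tool name and values, not the integration's exact code):

import json

# One matched block from the model, in the new tool-call format
# {"to_say": ..., "tool": ..., "arguments": {...}} (values are made up).
raw_block = (
    '{"to_say": "Turning on the lamp.", "tool": "HassTurnOn", '
    '"arguments": {"name": "light.lamp", "brightness": 0.5, "rgb_color": "(255, 180, 107)"}}'
)

parsed_tool_call = json.loads(raw_block)
to_say = parsed_tool_call.get("to_say", "")
arguments = parsed_tool_call["arguments"]

# brightness: the model may answer with a 0-1 fraction; Home Assistant expects 0-255
if "brightness" in arguments and 0.0 < arguments["brightness"] <= 1.0:
    arguments["brightness"] = int(arguments["brightness"] * 255)

# rgb_color: convert a stringified tuple like "(255, 180, 107)" into a list of ints
if "rgb_color" in arguments and isinstance(arguments["rgb_color"], str):
    arguments["rgb_color"] = [int(x) for x in arguments["rgb_color"][1:-1].split(",")]

print(parsed_tool_call["tool"], arguments, to_say)
# HassTurnOn {'name': 'light.lamp', 'brightness': 127, 'rgb_color': [255, 180, 107]} Turning on the lamp.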
@@ -429,8 +404,6 @@ class LocalLLMAgent(AbstractConversationAgent):
extra_attributes_to_expose = self.entry.options \
.get(CONF_EXTRA_ATTRIBUTES_TO_EXPOSE, DEFAULT_EXTRA_ATTRIBUTES_TO_EXPOSE)
allowed_service_call_arguments = self.entry.options \
.get(CONF_ALLOWED_SERVICE_CALL_ARGUMENTS, DEFAULT_ALLOWED_SERVICE_CALL_ARGUMENTS)
def icl_example_generator(num_examples, entity_names, service_names):
entity_domains = set([x.split(".")[0] for x in entity_names])
@@ -459,8 +432,8 @@ class LocalLLMAgent(AbstractConversationAgent):
device = [ x for x in entity_names if x.split(".")[0] == chosen_service.split(".")[0] ][0]
example = {
"to_say": chosen_example["response"],
"service": chosen_service,
"target_device": device,
"tool": chosen_service,
"arguments": { "name": device },
}
yield json.dumps(example) + "\n"
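With the key rename above, the in-context-learning examples yielded into the prompt switch from the old "service"/"target_device" keys to "tool"/"arguments". Illustrative shape of one generated example (hypothetical entity and tool names):

import json

example = {
    "to_say": "Turning on the light for you.",
    "tool": "light.turn_on",                  # previously the "service" key
    "arguments": {"name": "light.kitchen"},   # previously "target_device"
}
print(json.dumps(example))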
@@ -505,7 +478,7 @@ class LocalLLMAgent(AbstractConversationAgent):
if llm_api:
tools = [
f"{tool.name}({flatten_vol_schema(tool.parameters)}) - {tool.description}"
f"{tool.name}({', '.join(flatten_vol_schema(tool.parameters))}) - {tool.description}"
for tool in llm_api.async_get_tools()
]
formatted_services = llm_api.prompt_template + "\n" + "\n".join(tools)
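The join added above matters because flatten_vol_schema() returns a list of parameter names; interpolating the list directly would print a Python list literal inside the tool signature. A small sketch of the difference (hypothetical tool name and parameters):

params = ["name", "domain", "brightness"]                             # hypothetical flattened schema
tool_name, description = "HassTurnOn", "Turns on a device or entity"  # hypothetical

print(f"{tool_name}({params}) - {description}")             # HassTurnOn(['name', 'domain', 'brightness']) - ...
print(f"{tool_name}({', '.join(params)}) - {description}")  # HassTurnOn(name, domain, brightness) - ...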

View File

@@ -60,7 +60,6 @@ from .const import (
CONF_USE_GBNF_GRAMMAR,
CONF_GBNF_GRAMMAR_FILE,
CONF_EXTRA_ATTRIBUTES_TO_EXPOSE,
CONF_ALLOWED_SERVICE_CALL_ARGUMENTS,
CONF_TEXT_GEN_WEBUI_PRESET,
CONF_REFRESH_SYSTEM_PROMPT,
CONF_REMEMBER_CONVERSATION,
@@ -100,7 +99,6 @@ from .const import (
DEFAULT_USE_GBNF_GRAMMAR,
DEFAULT_GBNF_GRAMMAR_FILE,
DEFAULT_EXTRA_ATTRIBUTES_TO_EXPOSE,
DEFAULT_ALLOWED_SERVICE_CALL_ARGUMENTS,
DEFAULT_REFRESH_SYSTEM_PROMPT,
DEFAULT_REMEMBER_CONVERSATION,
DEFAULT_REMEMBER_NUM_INTERACTIONS,
@@ -589,10 +587,14 @@ class ConfigFlow(BaseLlamaConversationConfigFlow, config_entries.ConfigFlow, dom
schema = vol.Schema(local_llama_config_option_schema(self.hass, selected_default_options, backend_type))
if user_input:
self.options = user_input
if user_input[CONF_LLM_HASS_API] == "none":
user_input.pop(CONF_LLM_HASS_API)
try:
# validate input
schema(user_input)
self.options = user_input
return await self.async_step_finish()
except Exception as ex:
_LOGGER.exception("An unknown error has occurred!")
@@ -657,6 +659,9 @@ class OptionsFlow(config_entries.OptionsFlow):
errors["base"] = "missing_icl_file"
description_placeholders["filename"] = filename
if user_input[CONF_LLM_HASS_API] == "none":
user_input.pop(CONF_LLM_HASS_API)
if len(errors) == 0:
return self.async_create_entry(title="Local LLM Conversation", data=user_input)
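Both the config flow and the options flow above treat "none" as a sentinel from the LLM API selector: the key is dropped before validation so no API reference is stored when the user opts out. A minimal sketch, assuming CONF_LLM_HASS_API is the standard "llm_hass_api" option key:

CONF_LLM_HASS_API = "llm_hass_api"  # assumption: matches the Home Assistant constant

user_input = {"llm_hass_api": "none", "max_new_tokens": 128}

# "none" is the selector choice meaning "no LLM API"; dropping the key keeps the
# stored options free of a meaningless value before schema validation.
if user_input.get(CONF_LLM_HASS_API) == "none":
    user_input.pop(CONF_LLM_HASS_API)

print(user_input)  # {'max_new_tokens': 128}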
@@ -750,11 +755,6 @@ def local_llama_config_option_schema(hass: HomeAssistant, options: MappingProxyT
description={"suggested_value": options.get(CONF_EXTRA_ATTRIBUTES_TO_EXPOSE)},
default=DEFAULT_EXTRA_ATTRIBUTES_TO_EXPOSE,
): TextSelector(TextSelectorConfig(multiple=True)),
vol.Required(
CONF_ALLOWED_SERVICE_CALL_ARGUMENTS,
description={"suggested_value": options.get(CONF_ALLOWED_SERVICE_CALL_ARGUMENTS)},
default=DEFAULT_ALLOWED_SERVICE_CALL_ARGUMENTS,
): TextSelector(TextSelectorConfig(multiple=True)),
vol.Required(
CONF_SERVICE_CALL_REGEX,
description={"suggested_value": options.get(CONF_SERVICE_CALL_REGEX)},

View File

@@ -59,8 +59,6 @@ DEFAULT_PORT = "5000"
DEFAULT_SSL = False
CONF_EXTRA_ATTRIBUTES_TO_EXPOSE = "extra_attributes_to_expose"
DEFAULT_EXTRA_ATTRIBUTES_TO_EXPOSE = ["rgb_color", "brightness", "temperature", "humidity", "fan_mode", "media_title", "volume_level", "item", "wind_speed"]
CONF_ALLOWED_SERVICE_CALL_ARGUMENTS = "allowed_service_call_arguments"
DEFAULT_ALLOWED_SERVICE_CALL_ARGUMENTS = ["rgb_color", "brightness", "temperature", "humidity", "fan_mode", "hvac_mode", "preset_mode", "item", "duration"]
CONF_PROMPT_TEMPLATE = "prompt_template"
PROMPT_TEMPLATE_CHATML = "chatml"
PROMPT_TEMPLATE_COMMAND_R = "command-r"
@@ -197,7 +195,6 @@ DEFAULT_OPTIONS = types.MappingProxyType(
CONF_ENABLE_FLASH_ATTENTION: DEFAULT_ENABLE_FLASH_ATTENTION,
CONF_USE_GBNF_GRAMMAR: DEFAULT_USE_GBNF_GRAMMAR,
CONF_EXTRA_ATTRIBUTES_TO_EXPOSE: DEFAULT_EXTRA_ATTRIBUTES_TO_EXPOSE,
CONF_ALLOWED_SERVICE_CALL_ARGUMENTS: DEFAULT_ALLOWED_SERVICE_CALL_ARGUMENTS,
CONF_REFRESH_SYSTEM_PROMPT: DEFAULT_REFRESH_SYSTEM_PROMPT,
CONF_REMEMBER_CONVERSATION: DEFAULT_REMEMBER_CONVERSATION,
CONF_REMEMBER_NUM_INTERACTIONS: DEFAULT_REMEMBER_NUM_INTERACTIONS,

View File

@@ -51,6 +51,7 @@
"model_parameters": {
"data": {
"max_new_tokens": "Maximum tokens to return in response",
"llm_hass_api": "Selected LLM API",
"prompt": "System Prompt",
"prompt_template": "Prompt Format",
"temperature": "Temperature",
@@ -62,7 +63,6 @@
"ollama_keep_alive": "Keep Alive/Inactivity Timeout (minutes)",
"ollama_json_mode": "JSON Output Mode",
"extra_attributes_to_expose": "Additional attribute to expose in the context",
"allowed_service_call_arguments": "Arguments allowed to be pass to service calls",
"enable_flash_attention": "Enable Flash Attention",
"gbnf_grammar": "Enable GBNF Grammar",
"gbnf_grammar_file": "GBNF Grammar Filename",
@@ -86,11 +86,11 @@
"n_batch_threads": "Batch Thread Count"
},
"data_description": {
"llm_hass_api": "Select 'Assist' if you want the model to be able to control devices.",
"prompt": "See [here](https://github.com/acon96/home-llm/blob/develop/docs/Model%20Prompting.md) for more information on model prompting.",
"in_context_examples": "If you are using a model that is not specifically fine-tuned for use with this integration: enable this",
"remote_use_chat_endpoint": "If this is enabled, then the integration will use the chat completion HTTP endpoint instead of the text completion one.",
"extra_attributes_to_expose": "This is the list of Home Assistant 'attributes' that are exposed to the model. This limits how much information the model is able to see and answer questions on.",
"allowed_service_call_arguments": "This is the list of parameters that are allowed to be passed to Home Assistant service calls.",
"gbnf_grammar": "Forces the model to output properly formatted responses. Ensure the file specified below exists in the integration directory.",
"prompt_caching": "Prompt caching attempts to pre-process the prompt (house state) and cache the processing that needs to be done to understand the prompt. Enabling this will cause the model to re-process the prompt any time an entity state changes in the house, restricted by the interval below."
},
@@ -103,6 +103,7 @@
"step": {
"init": {
"data": {
"llm_hass_api": "Selected LLM API",
"max_new_tokens": "Maximum tokens to return in response",
"prompt": "System Prompt",
"prompt_template": "Prompt Format",
@@ -115,7 +116,6 @@
"ollama_keep_alive": "Keep Alive/Inactivity Timeout (minutes)",
"ollama_json_mode": "JSON Output Mode",
"extra_attributes_to_expose": "Additional attribute to expose in the context",
"allowed_service_call_arguments": "Arguments allowed to be pass to service calls",
"enable_flash_attention": "Enable Flash Attention",
"gbnf_grammar": "Enable GBNF Grammar",
"gbnf_grammar_file": "GBNF Grammar Filename",
@@ -139,11 +139,11 @@
"n_batch_threads": "Batch Thread Count"
},
"data_description": {
"llm_hass_api": "Select 'Assist' if you want the model to be able to control devices.",
"prompt": "See [here](https://github.com/acon96/home-llm/blob/develop/docs/Model%20Prompting.md) for more information on model prompting.",
"in_context_examples": "If you are using a model that is not specifically fine-tuned for use with this integration: enable this",
"remote_use_chat_endpoint": "If this is enabled, then the integration will use the chat completion HTTP endpoint instead of the text completion one.",
"extra_attributes_to_expose": "This is the list of Home Assistant 'attributes' that are exposed to the model. This limits how much information the model is able to see and answer questions on.",
"allowed_service_call_arguments": "This is the list of parameters that are allowed to be passed to Home Assistant service calls.",
"gbnf_grammar": "Forces the model to output properly formatted responses. Ensure the file specified below exists in the integration directory.",
"prompt_caching": "Prompt caching attempts to pre-process the prompt (house state) and cache the processing that needs to be done to understand the prompt. Enabling this will cause the model to re-process the prompt any time an entity state changes in the house, restricted by the interval below."
}

View File

@@ -37,10 +37,15 @@ def flatten_vol_schema(schema):
_flatten(subval, prefix)
elif isinstance(current_schema.schema, dict):
for key, val in current_schema.schema.items():
if isinstance(key, vol.Any):
key = "|".join(key.validators)
if isinstance(key, vol.Optional):
key = "?" + str(key)
_flatten(val, prefix + str(key) + '/')
elif isinstance(current_schema, vol.validators._WithSubValidators):
for subval in current_schema.validators:
_flatten(subval, prefix)
_flatten(subval, prefix)
elif callable(current_schema):
flattened.append(prefix[:-1] if prefix else prefix)
_flatten(schema)
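The new vol.Any and vol.Optional handling above changes how schema keys are rendered in the flattened parameter list. A hedged, self-contained re-statement of the helper for illustration (simplified, not the integration's exact code), with a small example schema:

import voluptuous as vol

def flatten_vol_schema(schema):
    """Compact illustration of the helper shown above (simplified)."""
    flattened = []
    def _flatten(current, prefix=""):
        if isinstance(current, vol.Schema):
            _flatten(current.schema, prefix)
        elif isinstance(current, dict):
            for key, val in current.items():
                if isinstance(key, vol.Any):
                    key = "|".join(key.validators)   # alternative keys render as 'a|b'
                if isinstance(key, vol.Optional):
                    key = "?" + str(key)             # optional keys gain a '?' prefix
                _flatten(val, prefix + str(key) + "/")
        elif callable(current):
            flattened.append(prefix[:-1] if prefix else prefix)
    _flatten(schema)
    return flattened

schema = vol.Schema({
    vol.Required("name"): str,
    vol.Optional("brightness"): int,
    vol.Any("area", "floor"): str,
})
print(flatten_vol_schema(schema))  # ['name', '?brightness', 'area|floor']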

View File

@@ -1,3 +1,4 @@
# training + dataset requirements
transformers
tensorboard
datasets
@@ -11,9 +12,17 @@ sentencepiece
deep-translator
langcodes
# integration requirements
requests==2.31.0
huggingface-hub==0.23.0
webcolors==1.13
# types from Home Assistant
homeassistant
hassil
home-assistant-intents
# testing requirements
pytest
pytest-asyncio
pytest-homeassistant-custom-component

View File

@@ -18,7 +18,6 @@ from custom_components.llama_conversation.const import (
CONF_BACKEND_TYPE,
CONF_DOWNLOADED_MODEL_FILE,
CONF_EXTRA_ATTRIBUTES_TO_EXPOSE,
CONF_ALLOWED_SERVICE_CALL_ARGUMENTS,
CONF_PROMPT_TEMPLATE,
CONF_ENABLE_FLASH_ATTENTION,
CONF_USE_GBNF_GRAMMAR,
@@ -54,7 +53,6 @@ from custom_components.llama_conversation.const import (
DEFAULT_BACKEND_TYPE,
DEFAULT_REQUEST_TIMEOUT,
DEFAULT_EXTRA_ATTRIBUTES_TO_EXPOSE,
DEFAULT_ALLOWED_SERVICE_CALL_ARGUMENTS,
DEFAULT_PROMPT_TEMPLATE,
DEFAULT_ENABLE_FLASH_ATTENTION,
DEFAULT_USE_GBNF_GRAMMAR,

View File

@@ -24,7 +24,6 @@ from custom_components.llama_conversation.const import (
CONF_BACKEND_TYPE,
CONF_DOWNLOADED_MODEL_FILE,
CONF_EXTRA_ATTRIBUTES_TO_EXPOSE,
CONF_ALLOWED_SERVICE_CALL_ARGUMENTS,
CONF_PROMPT_TEMPLATE,
CONF_ENABLE_FLASH_ATTENTION,
CONF_USE_GBNF_GRAMMAR,
@@ -66,7 +65,6 @@ from custom_components.llama_conversation.const import (
DEFAULT_BACKEND_TYPE,
DEFAULT_REQUEST_TIMEOUT,
DEFAULT_EXTRA_ATTRIBUTES_TO_EXPOSE,
DEFAULT_ALLOWED_SERVICE_CALL_ARGUMENTS,
DEFAULT_PROMPT_TEMPLATE,
DEFAULT_ENABLE_FLASH_ATTENTION,
DEFAULT_USE_GBNF_GRAMMAR,
@@ -171,7 +169,6 @@ async def test_validate_config_flow_generic_openai(mock_setup_entry, hass: HomeA
CONF_REQUEST_TIMEOUT: DEFAULT_REQUEST_TIMEOUT,
CONF_PROMPT_TEMPLATE: DEFAULT_PROMPT_TEMPLATE,
CONF_EXTRA_ATTRIBUTES_TO_EXPOSE: DEFAULT_EXTRA_ATTRIBUTES_TO_EXPOSE,
CONF_ALLOWED_SERVICE_CALL_ARGUMENTS: DEFAULT_ALLOWED_SERVICE_CALL_ARGUMENTS,
CONF_REFRESH_SYSTEM_PROMPT: DEFAULT_REFRESH_SYSTEM_PROMPT,
CONF_REMEMBER_CONVERSATION: DEFAULT_REMEMBER_CONVERSATION,
CONF_REMEMBER_NUM_INTERACTIONS: DEFAULT_REMEMBER_NUM_INTERACTIONS,
@@ -261,7 +258,6 @@ async def test_validate_config_flow_ollama(mock_setup_entry, hass: HomeAssistant
CONF_REQUEST_TIMEOUT: DEFAULT_REQUEST_TIMEOUT,
CONF_PROMPT_TEMPLATE: DEFAULT_PROMPT_TEMPLATE,
CONF_EXTRA_ATTRIBUTES_TO_EXPOSE: DEFAULT_EXTRA_ATTRIBUTES_TO_EXPOSE,
CONF_ALLOWED_SERVICE_CALL_ARGUMENTS: DEFAULT_ALLOWED_SERVICE_CALL_ARGUMENTS,
CONF_REFRESH_SYSTEM_PROMPT: DEFAULT_REFRESH_SYSTEM_PROMPT,
CONF_REMEMBER_CONVERSATION: DEFAULT_REMEMBER_CONVERSATION,
CONF_REMEMBER_NUM_INTERACTIONS: DEFAULT_REMEMBER_NUM_INTERACTIONS,
@@ -299,7 +295,7 @@ def test_validate_options_schema():
universal_options = [
CONF_PROMPT, CONF_PROMPT_TEMPLATE,
CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES, CONF_IN_CONTEXT_EXAMPLES_FILE, CONF_NUM_IN_CONTEXT_EXAMPLES,
CONF_MAX_TOKENS, CONF_EXTRA_ATTRIBUTES_TO_EXPOSE, CONF_ALLOWED_SERVICE_CALL_ARGUMENTS,
CONF_MAX_TOKENS, CONF_EXTRA_ATTRIBUTES_TO_EXPOSE,
CONF_SERVICE_CALL_REGEX, CONF_REFRESH_SYSTEM_PROMPT, CONF_REMEMBER_CONVERSATION, CONF_REMEMBER_NUM_INTERACTIONS,
]