better device prompting with area support + fix circular import

This commit is contained in:
Alex O'Connell
2024-06-07 17:42:03 -04:00
parent 5ddf0d09d5
commit 249298bb99
4 changed files with 66 additions and 24 deletions

View File

@@ -32,7 +32,7 @@ from .const import (
BACKEND_TYPE_GENERIC_OPENAI,
BACKEND_TYPE_LLAMA_CPP_PYTHON_SERVER,
BACKEND_TYPE_OLLAMA,
ALLOWED_LEGACY_SERVICE_CALL_ARGUMENTS,
ALLOWED_SERVICE_CALL_ARGUMENTS,
DOMAIN,
HOME_LLM_API_ID,
SERVICE_TOOL_NAME,
@@ -133,9 +133,6 @@ class HassServiceTool(llm.Tool):
ALLOWED_DOMAINS: Final[list[str]] = [
"light", "switch", "button", "fan", "cover", "lock", "media_player", "climate", "vacuum", "todo", "timer", "script",
]
ALLOWED_SERVICE_CALL_ARGUMENTS: Final[list[str]] = [
"rgb_color", "brightness", "temperature", "humidity", "fan_mode", "hvac_mode", "preset_mode", "item", "duration",
]
async def async_call(
self, hass: HomeAssistant, tool_input: llm.ToolInput, llm_context: llm.LLMContext
@@ -151,7 +148,7 @@ class HassServiceTool(llm.Tool):
return { "result": "unknown service" }
service_data = {ATTR_ENTITY_ID: target_device}
for attr in self.ALLOWED_SERVICE_CALL_ARGUMENTS:
for attr in ALLOWED_SERVICE_CALL_ARGUMENTS:
if attr in tool_input.tool_args.keys():
service_data[attr] = tool_input.tool_args[attr]
try:

View File

@@ -23,13 +23,12 @@ from homeassistant.config_entries import ConfigEntry
from homeassistant.const import ATTR_ENTITY_ID, CONF_HOST, CONF_PORT, CONF_SSL, MATCH_ALL, CONF_LLM_HASS_API
from homeassistant.core import HomeAssistant, callback
from homeassistant.exceptions import ConfigEntryNotReady, ConfigEntryError, TemplateError, HomeAssistantError
from homeassistant.helpers import config_validation as cv, intent, template, entity_registry as er, llm
from homeassistant.helpers import config_validation as cv, intent, template, entity_registry as er, llm, area_registry as ar
from homeassistant.helpers.event import async_track_state_change, async_call_later
from homeassistant.util import ulid, color
import voluptuous_serialize
from . import HassServiceTool
from .utils import closest_color, flatten_vol_schema, custom_custom_serializer, install_llama_cpp_python, \
validate_llama_cpp_python_installation, format_url
from .const import (
@@ -114,6 +113,7 @@ from .const import (
TOOL_FORMAT_FULL,
TOOL_FORMAT_REDUCED,
TOOL_FORMAT_MINIMAL,
ALLOWED_SERVICE_CALL_ARGUMENTS,
)
# make type checking work for llama-cpp-python without importing it directly at runtime
@@ -445,6 +445,7 @@ class LocalLLMAgent(AbstractConversationAgent):
entity_states = {}
domains = set()
entity_registry = er.async_get(self.hass)
area_registry = ar.async_get(self.hass)
for state in self.hass.states.async_all():
if not async_should_expose(self.hass, CONVERSATION_DOMAIN, state.entity_id):
@@ -456,11 +457,15 @@ class LocalLLMAgent(AbstractConversationAgent):
attributes["state"] = state.state
if entity and entity.aliases:
attributes["aliases"] = entity.aliases
if entity and entity.area_id:
area = area_registry.async_get_area(entity.area_id)
attributes["area_id"] = area.id
attributes["area_name"] = area.name
entity_states[state.entity_id] = attributes
domains.add(state.domain)
# _LOGGER.debug(f"Exposed entities: {entity_states}")
return entity_states, list(domains)
def _format_prompt(
@@ -556,6 +561,9 @@ class LocalLLMAgent(AbstractConversationAgent):
entity_names = entity_names[:]
entity_domains = set([x.split(".")[0] for x in entity_names])
area_registry = ar.async_get(self.hass)
all_areas = list(area_registry.async_list_areas())
in_context_examples = [
x for x in self.in_context_examples
if x["type"] in entity_domains
@@ -575,7 +583,7 @@ class LocalLLMAgent(AbstractConversationAgent):
response = chosen_example["response"]
random_device = [ x for x in entity_names if x.split(".")[0] == chosen_example["type"] ][0]
random_area = "bedroom" # todo, pick a random area
random_area = random.choice(all_areas).name
random_brightness = round(random.random(), 2)
random_color = random.choice(list(color.COLORS.keys()))
@@ -619,8 +627,8 @@ class LocalLLMAgent(AbstractConversationAgent):
extra_attributes_to_expose = self.entry.options \
.get(CONF_EXTRA_ATTRIBUTES_TO_EXPOSE, DEFAULT_EXTRA_ATTRIBUTES_TO_EXPOSE)
def expose_attributes(attributes):
result = attributes["state"]
def expose_attributes(attributes) -> list[str]:
result = []
for attribute_name in extra_attributes_to_expose:
if attribute_name not in attributes:
continue
@@ -644,19 +652,38 @@ class LocalLLMAgent(AbstractConversationAgent):
elif attribute_name == "humidity":
value = f"{value}%"
result = result + ";" + str(value)
result.append(str(value))
return result
device_states = []
devices = []
formatted_devices = ""
# expose devices and their alias as well
for name, attributes in entities_to_expose.items():
device_states.append(f"{name} '{attributes.get('friendly_name')}' = {expose_attributes(attributes)}")
state = attributes["state"]
exposed_attributes = expose_attributes(attributes)
str_attributes = ";".join([state] + exposed_attributes)
formatted_devices = formatted_devices + f"{name} '{attributes.get('friendly_name')}' = {str_attributes}\n"
devices.append({
"entity_id": name,
"name": attributes.get('friendly_name'),
"state": state,
"attributes": exposed_attributes,
"area_name": attributes.get("area_name"),
"area_id": attributes.get("area_id")
})
if "aliases" in attributes:
for alias in attributes["aliases"]:
device_states.append(f"{name} '{alias}' = {expose_attributes(attributes)}")
formatted_states = "\n".join(device_states) + "\n"
formatted_devices = formatted_devices + f"{name} '{alias}' = {str_attributes}\n"
devices.append({
"entity_id": name,
"name": alias,
"state": state,
"attributes": exposed_attributes,
"area_name": attributes.get("area_name"),
"area_id": attributes.get("area_id")
})
if llm_api:
if llm_api.api.id == HOME_LLM_API_ID:
@@ -670,7 +697,7 @@ class LocalLLMAgent(AbstractConversationAgent):
for name, service in service_dict.get(domain, {}).items():
args = flatten_vol_schema(service.schema)
args_to_expose = set(args).intersection(HassServiceTool.ALLOWED_SERVICE_CALL_ARGUMENTS)
args_to_expose = set(args).intersection(ALLOWED_SERVICE_CALL_ARGUMENTS)
service_schema = vol.Schema({
vol.Optional(arg): str for arg in args_to_expose
})
@@ -681,17 +708,21 @@ class LocalLLMAgent(AbstractConversationAgent):
self._format_tool(*tool)
for tool in all_services
]
formatted_tools = ", ".join(tools)
else:
tools = [
self._format_tool(tool.name, tool.parameters, tool.description)
for tool in llm_api.tools
]
formatted_tools = json.dumps(tools)
else:
tools = "No tools were provided. If the user requests you interact with a device, tell them you are unable to do so."
render_variables = {
"devices": formatted_states,
"devices": devices,
"formatted_devices": formatted_devices,
"tools": tools,
"formatted_tools": formatted_tools,
"response_examples": []
}

View File

@@ -15,12 +15,25 @@ DEFAULT_PROMPT_BASE = """<persona>
The current time and date is {{ (as_timestamp(now()) | timestamp_custom("%I:%M %p on %A %B %d, %Y", "")) }}
Tools: {{ tools | to_json }}
Devices:
{{ devices }}"""
{% for device in devices | selectattr('area_id', 'none') %}
{{ device.entity_id }} '{{ device.name }}' = {{ device.state }}{{ ([""] + device.attributes) | join(";") }}
{% endfor %}
{% for area in devices | rejectattr('area_id', 'none') | groupby('area_name') %}
## Area: {{ area.grouper }}
{% for device in area.list %}
{{ device.entity_id }} '{{ device.name }}' = {{ device.state }};{{ device.attributes | join(";") }}
{% endfor %}
{% endfor %}
{% for item in response_examples %}
{{ item.request }}
{{ item.response }}
<functioncall> {{ item.tool | to_json }}
{% endfor %}"""
DEFAULT_PROMPT_BASE_LEGACY = """<persona>
The current time and date is {{ (as_timestamp(now()) | timestamp_custom("%I:%M %p on %A %B %d, %Y", "")) }}
Services: {{ tools | join(", ") }}
Services: {{ formatted_tools }}
Devices:
{{ devices }}"""
{{ formatted_devices }}"""
ICL_EXTRAS = """
{% for item in response_examples %}
{{ item.request }}
@@ -76,6 +89,7 @@ DEFAULT_PORT = "5000"
DEFAULT_SSL = False
CONF_EXTRA_ATTRIBUTES_TO_EXPOSE = "extra_attributes_to_expose"
DEFAULT_EXTRA_ATTRIBUTES_TO_EXPOSE = ["rgb_color", "brightness", "temperature", "humidity", "fan_mode", "media_title", "volume_level", "item", "wind_speed"]
ALLOWED_SERVICE_CALL_ARGUMENTS = ["rgb_color", "brightness", "temperature", "humidity", "fan_mode", "hvac_mode", "preset_mode", "item", "duration" ]
CONF_PROMPT_TEMPLATE = "prompt_template"
PROMPT_TEMPLATE_CHATML = "chatml"
PROMPT_TEMPLATE_COMMAND_R = "command-r"

View File

@@ -1,7 +1,7 @@
{
"domain": "llama_conversation",
"name": "Local LLM Conversation",
"version": "0.2.17",
"version": "0.3",
"codeowners": ["@acon96"],
"config_flow": true,
"dependencies": ["conversation"],