mirror of
https://github.com/acon96/home-llm.git
synced 2026-01-10 14:18:00 -05:00
better device prompting with area support + fix circular import
This commit is contained in:
@@ -32,7 +32,7 @@ from .const import (
|
||||
BACKEND_TYPE_GENERIC_OPENAI,
|
||||
BACKEND_TYPE_LLAMA_CPP_PYTHON_SERVER,
|
||||
BACKEND_TYPE_OLLAMA,
|
||||
ALLOWED_LEGACY_SERVICE_CALL_ARGUMENTS,
|
||||
ALLOWED_SERVICE_CALL_ARGUMENTS,
|
||||
DOMAIN,
|
||||
HOME_LLM_API_ID,
|
||||
SERVICE_TOOL_NAME,
|
||||
@@ -133,9 +133,6 @@ class HassServiceTool(llm.Tool):
|
||||
ALLOWED_DOMAINS: Final[list[str]] = [
|
||||
"light", "switch", "button", "fan", "cover", "lock", "media_player", "climate", "vacuum", "todo", "timer", "script",
|
||||
]
|
||||
ALLOWED_SERVICE_CALL_ARGUMENTS: Final[list[str]] = [
|
||||
"rgb_color", "brightness", "temperature", "humidity", "fan_mode", "hvac_mode", "preset_mode", "item", "duration",
|
||||
]
|
||||
|
||||
async def async_call(
|
||||
self, hass: HomeAssistant, tool_input: llm.ToolInput, llm_context: llm.LLMContext
|
||||
@@ -151,7 +148,7 @@ class HassServiceTool(llm.Tool):
|
||||
return { "result": "unknown service" }
|
||||
|
||||
service_data = {ATTR_ENTITY_ID: target_device}
|
||||
for attr in self.ALLOWED_SERVICE_CALL_ARGUMENTS:
|
||||
for attr in ALLOWED_SERVICE_CALL_ARGUMENTS:
|
||||
if attr in tool_input.tool_args.keys():
|
||||
service_data[attr] = tool_input.tool_args[attr]
|
||||
try:
|
||||
|
||||
@@ -23,13 +23,12 @@ from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.const import ATTR_ENTITY_ID, CONF_HOST, CONF_PORT, CONF_SSL, MATCH_ALL, CONF_LLM_HASS_API
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.exceptions import ConfigEntryNotReady, ConfigEntryError, TemplateError, HomeAssistantError
|
||||
from homeassistant.helpers import config_validation as cv, intent, template, entity_registry as er, llm
|
||||
from homeassistant.helpers import config_validation as cv, intent, template, entity_registry as er, llm, area_registry as ar
|
||||
from homeassistant.helpers.event import async_track_state_change, async_call_later
|
||||
from homeassistant.util import ulid, color
|
||||
|
||||
import voluptuous_serialize
|
||||
|
||||
from . import HassServiceTool
|
||||
from .utils import closest_color, flatten_vol_schema, custom_custom_serializer, install_llama_cpp_python, \
|
||||
validate_llama_cpp_python_installation, format_url
|
||||
from .const import (
|
||||
@@ -114,6 +113,7 @@ from .const import (
|
||||
TOOL_FORMAT_FULL,
|
||||
TOOL_FORMAT_REDUCED,
|
||||
TOOL_FORMAT_MINIMAL,
|
||||
ALLOWED_SERVICE_CALL_ARGUMENTS,
|
||||
)
|
||||
|
||||
# make type checking work for llama-cpp-python without importing it directly at runtime
|
||||
@@ -445,6 +445,7 @@ class LocalLLMAgent(AbstractConversationAgent):
|
||||
entity_states = {}
|
||||
domains = set()
|
||||
entity_registry = er.async_get(self.hass)
|
||||
area_registry = ar.async_get(self.hass)
|
||||
|
||||
for state in self.hass.states.async_all():
|
||||
if not async_should_expose(self.hass, CONVERSATION_DOMAIN, state.entity_id):
|
||||
@@ -456,11 +457,15 @@ class LocalLLMAgent(AbstractConversationAgent):
|
||||
attributes["state"] = state.state
|
||||
if entity and entity.aliases:
|
||||
attributes["aliases"] = entity.aliases
|
||||
|
||||
if entity and entity.area_id:
|
||||
area = area_registry.async_get_area(entity.area_id)
|
||||
attributes["area_id"] = area.id
|
||||
attributes["area_name"] = area.name
|
||||
|
||||
entity_states[state.entity_id] = attributes
|
||||
domains.add(state.domain)
|
||||
|
||||
# _LOGGER.debug(f"Exposed entities: {entity_states}")
|
||||
|
||||
return entity_states, list(domains)
|
||||
|
||||
def _format_prompt(
|
||||
@@ -556,6 +561,9 @@ class LocalLLMAgent(AbstractConversationAgent):
|
||||
entity_names = entity_names[:]
|
||||
entity_domains = set([x.split(".")[0] for x in entity_names])
|
||||
|
||||
area_registry = ar.async_get(self.hass)
|
||||
all_areas = list(area_registry.async_list_areas())
|
||||
|
||||
in_context_examples = [
|
||||
x for x in self.in_context_examples
|
||||
if x["type"] in entity_domains
|
||||
@@ -575,7 +583,7 @@ class LocalLLMAgent(AbstractConversationAgent):
|
||||
response = chosen_example["response"]
|
||||
|
||||
random_device = [ x for x in entity_names if x.split(".")[0] == chosen_example["type"] ][0]
|
||||
random_area = "bedroom" # todo, pick a random area
|
||||
random_area = random.choice(all_areas).name
|
||||
random_brightness = round(random.random(), 2)
|
||||
random_color = random.choice(list(color.COLORS.keys()))
|
||||
|
||||
@@ -619,8 +627,8 @@ class LocalLLMAgent(AbstractConversationAgent):
|
||||
extra_attributes_to_expose = self.entry.options \
|
||||
.get(CONF_EXTRA_ATTRIBUTES_TO_EXPOSE, DEFAULT_EXTRA_ATTRIBUTES_TO_EXPOSE)
|
||||
|
||||
def expose_attributes(attributes):
|
||||
result = attributes["state"]
|
||||
def expose_attributes(attributes) -> list[str]:
|
||||
result = []
|
||||
for attribute_name in extra_attributes_to_expose:
|
||||
if attribute_name not in attributes:
|
||||
continue
|
||||
@@ -644,19 +652,38 @@ class LocalLLMAgent(AbstractConversationAgent):
|
||||
elif attribute_name == "humidity":
|
||||
value = f"{value}%"
|
||||
|
||||
result = result + ";" + str(value)
|
||||
result.append(str(value))
|
||||
return result
|
||||
|
||||
device_states = []
|
||||
devices = []
|
||||
formatted_devices = ""
|
||||
|
||||
# expose devices and their alias as well
|
||||
for name, attributes in entities_to_expose.items():
|
||||
device_states.append(f"{name} '{attributes.get('friendly_name')}' = {expose_attributes(attributes)}")
|
||||
state = attributes["state"]
|
||||
exposed_attributes = expose_attributes(attributes)
|
||||
str_attributes = ";".join([state] + exposed_attributes)
|
||||
|
||||
formatted_devices = formatted_devices + f"{name} '{attributes.get('friendly_name')}' = {str_attributes}\n"
|
||||
devices.append({
|
||||
"entity_id": name,
|
||||
"name": attributes.get('friendly_name'),
|
||||
"state": state,
|
||||
"attributes": exposed_attributes,
|
||||
"area_name": attributes.get("area_name"),
|
||||
"area_id": attributes.get("area_id")
|
||||
})
|
||||
if "aliases" in attributes:
|
||||
for alias in attributes["aliases"]:
|
||||
device_states.append(f"{name} '{alias}' = {expose_attributes(attributes)}")
|
||||
|
||||
formatted_states = "\n".join(device_states) + "\n"
|
||||
formatted_devices = formatted_devices + f"{name} '{alias}' = {str_attributes}\n"
|
||||
devices.append({
|
||||
"entity_id": name,
|
||||
"name": alias,
|
||||
"state": state,
|
||||
"attributes": exposed_attributes,
|
||||
"area_name": attributes.get("area_name"),
|
||||
"area_id": attributes.get("area_id")
|
||||
})
|
||||
|
||||
if llm_api:
|
||||
if llm_api.api.id == HOME_LLM_API_ID:
|
||||
@@ -670,7 +697,7 @@ class LocalLLMAgent(AbstractConversationAgent):
|
||||
|
||||
for name, service in service_dict.get(domain, {}).items():
|
||||
args = flatten_vol_schema(service.schema)
|
||||
args_to_expose = set(args).intersection(HassServiceTool.ALLOWED_SERVICE_CALL_ARGUMENTS)
|
||||
args_to_expose = set(args).intersection(ALLOWED_SERVICE_CALL_ARGUMENTS)
|
||||
service_schema = vol.Schema({
|
||||
vol.Optional(arg): str for arg in args_to_expose
|
||||
})
|
||||
@@ -681,17 +708,21 @@ class LocalLLMAgent(AbstractConversationAgent):
|
||||
self._format_tool(*tool)
|
||||
for tool in all_services
|
||||
]
|
||||
formatted_tools = ", ".join(tools)
|
||||
else:
|
||||
tools = [
|
||||
self._format_tool(tool.name, tool.parameters, tool.description)
|
||||
for tool in llm_api.tools
|
||||
]
|
||||
formatted_tools = json.dumps(tools)
|
||||
else:
|
||||
tools = "No tools were provided. If the user requests you interact with a device, tell them you are unable to do so."
|
||||
|
||||
render_variables = {
|
||||
"devices": formatted_states,
|
||||
"devices": devices,
|
||||
"formatted_devices": formatted_devices,
|
||||
"tools": tools,
|
||||
"formatted_tools": formatted_tools,
|
||||
"response_examples": []
|
||||
}
|
||||
|
||||
|
||||
@@ -15,12 +15,25 @@ DEFAULT_PROMPT_BASE = """<persona>
|
||||
The current time and date is {{ (as_timestamp(now()) | timestamp_custom("%I:%M %p on %A %B %d, %Y", "")) }}
|
||||
Tools: {{ tools | to_json }}
|
||||
Devices:
|
||||
{{ devices }}"""
|
||||
{% for device in devices | selectattr('area_id', 'none'): %}
|
||||
{{ device.entity_id }} '{{ device.name }}' = {{ device.state }}{{ ([""] + device.attributes) | join(";") }}
|
||||
{% endfor %}
|
||||
{% for area in devices | rejectattr('area_id', 'none') | groupby('area_name') %}
|
||||
## Area: {{ area.grouper }}
|
||||
{% for device in area.list %}
|
||||
{{ device.entity_id }} '{{ device.name }}' = {{ device.state }};{{ device.attributes | join(";") }}
|
||||
{% endfor %}
|
||||
{% endfor %}
|
||||
{% for item in response_examples %}
|
||||
{{ item.request }}
|
||||
{{ item.response }}
|
||||
<functioncall> {{ item.tool | to_json }}
|
||||
{% endfor %}"""
|
||||
DEFAULT_PROMPT_BASE_LEGACY = """<persona>
|
||||
The current time and date is {{ (as_timestamp(now()) | timestamp_custom("%I:%M %p on %A %B %d, %Y", "")) }}
|
||||
Services: {{ tools | join(", ") }}
|
||||
Services: {{ formatted_tools }}
|
||||
Devices:
|
||||
{{ devices }}"""
|
||||
{{ formatted_devices }}"""
|
||||
ICL_EXTRAS = """
|
||||
{% for item in response_examples %}
|
||||
{{ item.request }}
|
||||
@@ -76,6 +89,7 @@ DEFAULT_PORT = "5000"
|
||||
DEFAULT_SSL = False
|
||||
CONF_EXTRA_ATTRIBUTES_TO_EXPOSE = "extra_attributes_to_expose"
|
||||
DEFAULT_EXTRA_ATTRIBUTES_TO_EXPOSE = ["rgb_color", "brightness", "temperature", "humidity", "fan_mode", "media_title", "volume_level", "item", "wind_speed"]
|
||||
ALLOWED_SERVICE_CALL_ARGUMENTS = ["rgb_color", "brightness", "temperature", "humidity", "fan_mode", "hvac_mode", "preset_mode", "item", "duration" ]
|
||||
CONF_PROMPT_TEMPLATE = "prompt_template"
|
||||
PROMPT_TEMPLATE_CHATML = "chatml"
|
||||
PROMPT_TEMPLATE_COMMAND_R = "command-r"
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
{
|
||||
"domain": "llama_conversation",
|
||||
"name": "Local LLM Conversation",
|
||||
"version": "0.2.17",
|
||||
"version": "0.3",
|
||||
"codeowners": ["@acon96"],
|
||||
"config_flow": true,
|
||||
"dependencies": ["conversation"],
|
||||
|
||||
Reference in New Issue
Block a user