filter out disallowed services from the prompt when the home-llm API is selected

This commit is contained in:
Alex O'Connell
2025-02-08 08:40:02 -05:00
parent 280cb454b6
commit b1d0191310
3 changed files with 14 additions and 7 deletions

View File

@@ -20,6 +20,8 @@ from .const import (
DOMAIN,
HOME_LLM_API_ID,
SERVICE_TOOL_NAME,
SERVICE_TOOL_ALLOWED_SERVICES,
SERVICE_TOOL_ALLOWED_DOMAINS,
)
_LOGGER = logging.getLogger(__name__)
@@ -82,13 +84,8 @@ class HassServiceTool(llm.Tool):
vol.Optional('item'): str,
})
ALLOWED_SERVICES: Final[list[str]] = [
"turn_on", "turn_off", "toggle", "press", "increase_speed", "decrease_speed", "open_cover", "close_cover", "stop_cover",
"lock", "unlock", "start", "stop", "return_to_base", "pause", "cancel", "add_item"
]
ALLOWED_DOMAINS: Final[list[str]] = [
"light", "switch", "button", "fan", "cover", "lock", "media_player", "climate", "vacuum", "todo", "timer", "script",
]
ALLOWED_SERVICES: Final[list[str]] = SERVICE_TOOL_ALLOWED_SERVICES
ALLOWED_DOMAINS: Final[list[str]] = SERVICE_TOOL_ALLOWED_DOMAINS
async def async_call(
self, hass: HomeAssistant, tool_input: llm.ToolInput, llm_context: llm.LLMContext

View File

@@ -4,6 +4,8 @@ import types, os
DOMAIN = "llama_conversation"
HOME_LLM_API_ID = "home-llm-service-api"
SERVICE_TOOL_NAME = "HassCallService"
SERVICE_TOOL_ALLOWED_SERVICES = ["turn_on", "turn_off", "toggle", "press", "increase_speed", "decrease_speed", "open_cover", "close_cover", "stop_cover", "lock", "unlock", "start", "stop", "return_to_base", "pause", "cancel", "add_item", "set_temperature", "set_humidity", "set_fan_mode", "set_hvac_mode", "set_preset_mode"]
SERVICE_TOOL_ALLOWED_DOMAINS = ["light", "switch", "button", "fan", "cover", "lock", "media_player", "climate", "vacuum", "todo", "timer", "script"]
CONF_PROMPT = "prompt"
PERSONA_PROMPTS = {
"en": "You are 'Al', a helpful AI Assistant that controls the devices in a house. Complete the following task as instructed with the information provided only.",

View File

@@ -121,6 +121,8 @@ from .const import (
TOOL_FORMAT_REDUCED,
TOOL_FORMAT_MINIMAL,
ALLOWED_SERVICE_CALL_ARGUMENTS,
SERVICE_TOOL_ALLOWED_SERVICES,
SERVICE_TOOL_ALLOWED_DOMAINS,
CONF_BACKEND_TYPE,
DEFAULT_BACKEND_TYPE,
BACKEND_TYPE_LLAMA_HF,
@@ -816,6 +818,9 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
all_services = []
scripts_added = False
for domain in domains:
if domain not in SERVICE_TOOL_ALLOWED_DOMAINS:
continue
# scripts show up as individual services
if domain == "script" and not scripts_added:
all_services.extend([
@@ -828,6 +833,9 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
continue
for name, service in service_dict.get(domain, {}).items():
if name not in SERVICE_TOOL_ALLOWED_SERVICES:
continue
args = flatten_vol_schema(service.schema)
args_to_expose = set(args).intersection(ALLOWED_SERVICE_CALL_ARGUMENTS)
service_schema = vol.Schema({