diff --git a/custom_components/llama_conversation/__init__.py b/custom_components/llama_conversation/__init__.py
index 3c3fbb2..e5c9690 100644
--- a/custom_components/llama_conversation/__init__.py
+++ b/custom_components/llama_conversation/__init__.py
@@ -20,6 +20,8 @@ from .const import (
     DOMAIN,
     HOME_LLM_API_ID,
     SERVICE_TOOL_NAME,
+    SERVICE_TOOL_ALLOWED_SERVICES,
+    SERVICE_TOOL_ALLOWED_DOMAINS,
 )
 
 _LOGGER = logging.getLogger(__name__)
@@ -82,13 +84,8 @@ class HassServiceTool(llm.Tool):
         vol.Optional('item'): str,
     })
 
-    ALLOWED_SERVICES: Final[list[str]] = [
-        "turn_on", "turn_off", "toggle", "press", "increase_speed", "decrease_speed", "open_cover", "close_cover", "stop_cover",
-        "lock", "unlock", "start", "stop", "return_to_base", "pause", "cancel", "add_item"
-    ]
-    ALLOWED_DOMAINS: Final[list[str]] = [
-        "light", "switch", "button", "fan", "cover", "lock", "media_player", "climate", "vacuum", "todo", "timer", "script",
-    ]
+    ALLOWED_SERVICES: Final[list[str]] = SERVICE_TOOL_ALLOWED_SERVICES
+    ALLOWED_DOMAINS: Final[list[str]] = SERVICE_TOOL_ALLOWED_DOMAINS
 
     async def async_call(
         self, hass: HomeAssistant, tool_input: llm.ToolInput, llm_context: llm.LLMContext
diff --git a/custom_components/llama_conversation/const.py b/custom_components/llama_conversation/const.py
index 9c662e6..b7e4c03 100644
--- a/custom_components/llama_conversation/const.py
+++ b/custom_components/llama_conversation/const.py
@@ -4,6 +4,8 @@ import types, os
 DOMAIN = "llama_conversation"
 HOME_LLM_API_ID = "home-llm-service-api"
 SERVICE_TOOL_NAME = "HassCallService"
+SERVICE_TOOL_ALLOWED_SERVICES = ["turn_on", "turn_off", "toggle", "press", "increase_speed", "decrease_speed", "open_cover", "close_cover", "stop_cover", "lock", "unlock", "start", "stop", "return_to_base", "pause", "cancel", "add_item", "set_temperature", "set_humidity", "set_fan_mode", "set_hvac_mode", "set_preset_mode"]
+SERVICE_TOOL_ALLOWED_DOMAINS = ["light", "switch", "button", "fan", "cover", "lock", "media_player", "climate", "vacuum", "todo", "timer", "script"]
 CONF_PROMPT = "prompt"
 PERSONA_PROMPTS = {
     "en": "You are 'Al', a helpful AI Assistant that controls the devices in a house. Complete the following task as instructed with the information provided only.",
diff --git a/custom_components/llama_conversation/conversation.py b/custom_components/llama_conversation/conversation.py
index ea3cb55..91d5dd5 100644
--- a/custom_components/llama_conversation/conversation.py
+++ b/custom_components/llama_conversation/conversation.py
@@ -121,6 +121,8 @@ from .const import (
     TOOL_FORMAT_REDUCED,
     TOOL_FORMAT_MINIMAL,
     ALLOWED_SERVICE_CALL_ARGUMENTS,
+    SERVICE_TOOL_ALLOWED_SERVICES,
+    SERVICE_TOOL_ALLOWED_DOMAINS,
     CONF_BACKEND_TYPE,
     DEFAULT_BACKEND_TYPE,
     BACKEND_TYPE_LLAMA_HF,
@@ -816,6 +818,9 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
         all_services = []
         scripts_added = False
         for domain in domains:
+            if domain not in SERVICE_TOOL_ALLOWED_DOMAINS:
+                continue
+
             # scripts show up as individual services
             if domain == "script" and not scripts_added:
                 all_services.extend([
@@ -828,6 +833,9 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
                 continue
 
             for name, service in service_dict.get(domain, {}).items():
+                if name not in SERVICE_TOOL_ALLOWED_SERVICES:
+                    continue
+
                 args = flatten_vol_schema(service.schema)
                 args_to_expose = set(args).intersection(ALLOWED_SERVICE_CALL_ARGUMENTS)
                 service_schema = vol.Schema({