mirror of
https://github.com/acon96/home-llm.git
synced 2026-01-08 21:28:05 -05:00
filter out disallowed services from the prompt when the home-llm API is selected
This commit is contained in:
@@ -20,6 +20,8 @@ from .const import (
|
||||
DOMAIN,
|
||||
HOME_LLM_API_ID,
|
||||
SERVICE_TOOL_NAME,
|
||||
SERVICE_TOOL_ALLOWED_SERVICES,
|
||||
SERVICE_TOOL_ALLOWED_DOMAINS,
|
||||
)
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
@@ -82,13 +84,8 @@ class HassServiceTool(llm.Tool):
|
||||
vol.Optional('item'): str,
|
||||
})
|
||||
|
||||
ALLOWED_SERVICES: Final[list[str]] = [
|
||||
"turn_on", "turn_off", "toggle", "press", "increase_speed", "decrease_speed", "open_cover", "close_cover", "stop_cover",
|
||||
"lock", "unlock", "start", "stop", "return_to_base", "pause", "cancel", "add_item"
|
||||
]
|
||||
ALLOWED_DOMAINS: Final[list[str]] = [
|
||||
"light", "switch", "button", "fan", "cover", "lock", "media_player", "climate", "vacuum", "todo", "timer", "script",
|
||||
]
|
||||
ALLOWED_SERVICES: Final[list[str]] = SERVICE_TOOL_ALLOWED_SERVICES
|
||||
ALLOWED_DOMAINS: Final[list[str]] = SERVICE_TOOL_ALLOWED_DOMAINS
|
||||
|
||||
async def async_call(
|
||||
self, hass: HomeAssistant, tool_input: llm.ToolInput, llm_context: llm.LLMContext
|
||||
|
||||
@@ -4,6 +4,8 @@ import types, os
|
||||
DOMAIN = "llama_conversation"
|
||||
HOME_LLM_API_ID = "home-llm-service-api"
|
||||
SERVICE_TOOL_NAME = "HassCallService"
|
||||
SERVICE_TOOL_ALLOWED_SERVICES = ["turn_on", "turn_off", "toggle", "press", "increase_speed", "decrease_speed", "open_cover", "close_cover", "stop_cover", "lock", "unlock", "start", "stop", "return_to_base", "pause", "cancel", "add_item", "set_temperature", "set_humidity", "set_fan_mode", "set_hvac_mode", "set_preset_mode"]
|
||||
SERVICE_TOOL_ALLOWED_DOMAINS = ["light", "switch", "button", "fan", "cover", "lock", "media_player", "climate", "vacuum", "todo", "timer", "script"]
|
||||
CONF_PROMPT = "prompt"
|
||||
PERSONA_PROMPTS = {
|
||||
"en": "You are 'Al', a helpful AI Assistant that controls the devices in a house. Complete the following task as instructed with the information provided only.",
|
||||
|
||||
@@ -121,6 +121,8 @@ from .const import (
|
||||
TOOL_FORMAT_REDUCED,
|
||||
TOOL_FORMAT_MINIMAL,
|
||||
ALLOWED_SERVICE_CALL_ARGUMENTS,
|
||||
SERVICE_TOOL_ALLOWED_SERVICES,
|
||||
SERVICE_TOOL_ALLOWED_DOMAINS,
|
||||
CONF_BACKEND_TYPE,
|
||||
DEFAULT_BACKEND_TYPE,
|
||||
BACKEND_TYPE_LLAMA_HF,
|
||||
@@ -816,6 +818,9 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
|
||||
all_services = []
|
||||
scripts_added = False
|
||||
for domain in domains:
|
||||
if domain not in SERVICE_TOOL_ALLOWED_DOMAINS:
|
||||
continue
|
||||
|
||||
# scripts show up as individual services
|
||||
if domain == "script" and not scripts_added:
|
||||
all_services.extend([
|
||||
@@ -828,6 +833,9 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
|
||||
continue
|
||||
|
||||
for name, service in service_dict.get(domain, {}).items():
|
||||
if name not in SERVICE_TOOL_ALLOWED_SERVICES:
|
||||
continue
|
||||
|
||||
args = flatten_vol_schema(service.schema)
|
||||
args_to_expose = set(args).intersection(ALLOWED_SERVICE_CALL_ARGUMENTS)
|
||||
service_schema = vol.Schema({
|
||||
|
||||
Reference in New Issue
Block a user