Restrict which services can be called and add a format_url helper function to standardize URL construction

This commit is contained in:
Alex O'Connell
2024-06-07 08:48:50 -04:00
parent 301c73a6f7
commit 5ddf0d09d5
5 changed files with 53 additions and 12 deletions

View File

@@ -2,6 +2,7 @@
from __future__ import annotations
import logging
from typing import Final
import homeassistant.components.conversation as ha_conversation
from homeassistant.config_entries import ConfigEntry
@@ -107,8 +108,8 @@ async def async_migrate_entry(hass, config_entry: ConfigEntry):
class HassServiceTool(llm.Tool):
"""Tool to execute a Home Assistant service."""
name = SERVICE_TOOL_NAME
description = "Executes a Home Assistant service"
name: Final[str] = SERVICE_TOOL_NAME
description: Final[str] = "Executes a Home Assistant service"
# Optional. A voluptuous schema of the input parameters.
parameters = vol.Schema({
@@ -125,6 +126,17 @@ class HassServiceTool(llm.Tool):
vol.Optional('item'): str,
})
ALLOWED_SERVICES: Final[list[str]] = [
"turn_on", "turn_off", "toggle", "press", "increase_speed", "decrease_speed", "open_cover", "close_cover", "stop_cover",
"lock", "unlock", "start", "stop", "return_to_base", "pause", "cancel", "add_item"
]
ALLOWED_DOMAINS: Final[list[str]] = [
"light", "switch", "button", "fan", "cover", "lock", "media_player", "climate", "vacuum", "todo", "timer", "script",
]
ALLOWED_SERVICE_CALL_ARGUMENTS: Final[list[str]] = [
"rgb_color", "brightness", "temperature", "humidity", "fan_mode", "hvac_mode", "preset_mode", "item", "duration",
]
async def async_call(
self, hass: HomeAssistant, tool_input: llm.ToolInput, llm_context: llm.LLMContext
) -> JsonObjectType:
@@ -132,8 +144,14 @@ class HassServiceTool(llm.Tool):
domain, service = tuple(tool_input.tool_args["service"].split("."))
target_device = tool_input.tool_args["target_device"]
if domain not in self.ALLOWED_DOMAINS or service not in self.ALLOWED_SERVICES:
return { "result": "unknown service" }
if domain == "script" and service not in ["reload", "turn_on", "turn_off", "toggle"]:
return { "result": "unknown service" }
service_data = {ATTR_ENTITY_ID: target_device}
for attr in ALLOWED_LEGACY_SERVICE_CALL_ARGUMENTS:
for attr in self.ALLOWED_SERVICE_CALL_ARGUMENTS:
if attr in tool_input.tool_args.keys():
service_data[attr] = tool_input.tool_args[attr]
try:

View File

@@ -29,8 +29,9 @@ from homeassistant.util import ulid, color
import voluptuous_serialize
from . import HassServiceTool
from .utils import closest_color, flatten_vol_schema, custom_custom_serializer, install_llama_cpp_python, \
validate_llama_cpp_python_installation
validate_llama_cpp_python_installation, format_url
from .const import (
CONF_CHAT_MODEL,
CONF_MAX_TOKENS,
@@ -106,7 +107,6 @@ from .const import (
TEXT_GEN_WEBUI_CHAT_MODE_CHAT,
TEXT_GEN_WEBUI_CHAT_MODE_INSTRUCT,
TEXT_GEN_WEBUI_CHAT_MODE_CHAT_INSTRUCT,
ALLOWED_LEGACY_SERVICE_CALL_ARGUMENTS,
DOMAIN,
HOME_LLM_API_ID,
SERVICE_TOOL_NAME,
@@ -670,7 +670,7 @@ class LocalLLMAgent(AbstractConversationAgent):
for name, service in service_dict.get(domain, {}).items():
args = flatten_vol_schema(service.schema)
args_to_expose = set(args).intersection(ALLOWED_LEGACY_SERVICE_CALL_ARGUMENTS)
args_to_expose = set(args).intersection(HassServiceTool.ALLOWED_SERVICE_CALL_ARGUMENTS)
service_schema = vol.Schema({
vol.Optional(arg): str for arg in args_to_expose
})
@@ -1042,7 +1042,13 @@ class GenericOpenAIAPIAgent(LocalLLMAgent):
model_name: str
def _load_model(self, entry: ConfigEntry) -> None:
self.api_host = f"{'https' if entry.data[CONF_SSL] else 'http'}://{entry.data[CONF_HOST]}:{entry.data[CONF_PORT]}"
self.api_host = format_url(
hostname=entry.data[CONF_HOST],
port=entry.data[CONF_PORT],
ssl=entry.data[CONF_SSL],
path=""
)
self.api_key = entry.data.get(CONF_OPENAI_API_KEY)
self.model_name = entry.data.get(CONF_CHAT_MODEL)
@@ -1249,7 +1255,12 @@ class OllamaAPIAgent(LocalLLMAgent):
model_name: str
def _load_model(self, entry: ConfigEntry) -> None:
self.api_host = f"{'https' if entry.data[CONF_SSL] else 'http'}://{entry.data[CONF_HOST]}:{entry.data[CONF_PORT]}"
self.api_host = format_url(
hostname=entry.data[CONF_HOST],
port=entry.data[CONF_PORT],
ssl=entry.data[CONF_SSL],
path=""
)
self.api_key = entry.data.get(CONF_OPENAI_API_KEY)
self.model_name = entry.data.get(CONF_CHAT_MODEL)

View File

@@ -38,7 +38,7 @@ from homeassistant.helpers.selector import (
from homeassistant.util.package import is_installed
from importlib.metadata import version
from .utils import download_model_from_hf, install_llama_cpp_python, MissingQuantizationException
from .utils import download_model_from_hf, install_llama_cpp_python, format_url, MissingQuantizationException
from .const import (
CONF_CHAT_MODEL,
CONF_MAX_TOKENS,
@@ -503,7 +503,12 @@ class ConfigFlow(BaseLlamaConversationConfigFlow, config_entries.ConfigFlow, dom
headers["Authorization"] = f"Bearer {api_key}"
models_result = requests.get(
f"{'https' if self.model_config[CONF_SSL] else 'http'}://{self.model_config[CONF_HOST]}:{self.model_config[CONF_PORT]}/v1/internal/model/list",
format_url(
hostname=self.model_config[CONF_HOST],
port=self.model_config[CONF_PORT],
ssl=self.model_config[CONF_SSL],
path="/v1/internal/model/list"
),
timeout=5, # quick timeout
headers=headers
)
@@ -535,7 +540,12 @@ class ConfigFlow(BaseLlamaConversationConfigFlow, config_entries.ConfigFlow, dom
headers["Authorization"] = f"Bearer {api_key}"
models_result = requests.get(
f"{'https' if self.model_config[CONF_SSL] else 'http'}://{self.model_config[CONF_HOST]}:{self.model_config[CONF_PORT]}/api/tags",
format_url(
hostname=self.model_config[CONF_HOST],
port=self.model_config[CONF_PORT],
ssl=self.model_config[CONF_SSL],
path="/api/tags"
),
timeout=5, # quick timeout
headers=headers
)

View File

@@ -76,7 +76,6 @@ DEFAULT_PORT = "5000"
DEFAULT_SSL = False
CONF_EXTRA_ATTRIBUTES_TO_EXPOSE = "extra_attributes_to_expose"
DEFAULT_EXTRA_ATTRIBUTES_TO_EXPOSE = ["rgb_color", "brightness", "temperature", "humidity", "fan_mode", "media_title", "volume_level", "item", "wind_speed"]
ALLOWED_LEGACY_SERVICE_CALL_ARGUMENTS = ["rgb_color", "brightness", "temperature", "humidity", "fan_mode", "hvac_mode", "preset_mode", "item", "duration"]
CONF_PROMPT_TEMPLATE = "prompt_template"
PROMPT_TEMPLATE_CHATML = "chatml"
PROMPT_TEMPLATE_COMMAND_R = "command-r"

View File

@@ -203,3 +203,6 @@ def install_llama_cpp_python(config_dir: str):
time.sleep(0.5) # I still don't know why this is required
return True
def format_url(*, hostname: str, port: str, ssl: bool, path: str):
    """Assemble a base URL from its components.

    Args:
        hostname: Host name or IP address, without scheme or port.
        port: Port number as a string; a falsy value omits the port part.
        ssl: When True the scheme is ``https``, otherwise ``http``.
        path: Path appended verbatim (include the leading slash, or "" for none).

    Returns:
        The assembled URL, e.g. ``https://host:5000/api/tags``.
    """
    scheme = "https" if ssl else "http"
    # Skip the ":port" suffix entirely when no port is given, so callers can
    # target hosts on the scheme's default port.
    netloc = f"{hostname}:{port}" if port else hostname
    return f"{scheme}://{netloc}{path}"