mirror of
https://github.com/acon96/home-llm.git
synced 2026-01-10 14:18:00 -05:00
restrict services that can be called and add format url function to make behavior standard
This commit is contained in:
@@ -2,6 +2,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Final
|
||||
|
||||
import homeassistant.components.conversation as ha_conversation
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
@@ -107,8 +108,8 @@ async def async_migrate_entry(hass, config_entry: ConfigEntry):
|
||||
class HassServiceTool(llm.Tool):
|
||||
"""Tool to get the current time."""
|
||||
|
||||
name = SERVICE_TOOL_NAME
|
||||
description = "Executes a Home Assistant service"
|
||||
name: Final[str] = SERVICE_TOOL_NAME
|
||||
description: Final[str] = "Executes a Home Assistant service"
|
||||
|
||||
# Optional. A voluptuous schema of the input parameters.
|
||||
parameters = vol.Schema({
|
||||
@@ -125,6 +126,17 @@ class HassServiceTool(llm.Tool):
|
||||
vol.Optional('item'): str,
|
||||
})
|
||||
|
||||
ALLOWED_SERVICES: Final[list[str]] = [
|
||||
"turn_on", "turn_off", "toggle", "press", "increase_speed", "decrease_speed", "open_cover", "close_cover", "stop_cover",
|
||||
"lock", "unlock", "start", "stop", "return_to_base", "pause", "cancel", "add_item"
|
||||
]
|
||||
ALLOWED_DOMAINS: Final[list[str]] = [
|
||||
"light", "switch", "button", "fan", "cover", "lock", "media_player", "climate", "vacuum", "todo", "timer", "script",
|
||||
]
|
||||
ALLOWED_SERVICE_CALL_ARGUMENTS: Final[list[str]] = [
|
||||
"rgb_color", "brightness", "temperature", "humidity", "fan_mode", "hvac_mode", "preset_mode", "item", "duration",
|
||||
]
|
||||
|
||||
async def async_call(
|
||||
self, hass: HomeAssistant, tool_input: llm.ToolInput, llm_context: llm.LLMContext
|
||||
) -> JsonObjectType:
|
||||
@@ -132,8 +144,14 @@ class HassServiceTool(llm.Tool):
|
||||
domain, service = tuple(tool_input.tool_args["service"].split("."))
|
||||
target_device = tool_input.tool_args["target_device"]
|
||||
|
||||
if domain not in self.ALLOWED_DOMAINS or service not in self.ALLOWED_SERVICES:
|
||||
return { "result": "unknown service" }
|
||||
|
||||
if domain == "script" and service not in ["reload", "turn_on", "turn_off", "toggle"]:
|
||||
return { "result": "unknown service" }
|
||||
|
||||
service_data = {ATTR_ENTITY_ID: target_device}
|
||||
for attr in ALLOWED_LEGACY_SERVICE_CALL_ARGUMENTS:
|
||||
for attr in self.ALLOWED_SERVICE_CALL_ARGUMENTS:
|
||||
if attr in tool_input.tool_args.keys():
|
||||
service_data[attr] = tool_input.tool_args[attr]
|
||||
try:
|
||||
|
||||
@@ -29,8 +29,9 @@ from homeassistant.util import ulid, color
|
||||
|
||||
import voluptuous_serialize
|
||||
|
||||
from . import HassServiceTool
|
||||
from .utils import closest_color, flatten_vol_schema, custom_custom_serializer, install_llama_cpp_python, \
|
||||
validate_llama_cpp_python_installation
|
||||
validate_llama_cpp_python_installation, format_url
|
||||
from .const import (
|
||||
CONF_CHAT_MODEL,
|
||||
CONF_MAX_TOKENS,
|
||||
@@ -106,7 +107,6 @@ from .const import (
|
||||
TEXT_GEN_WEBUI_CHAT_MODE_CHAT,
|
||||
TEXT_GEN_WEBUI_CHAT_MODE_INSTRUCT,
|
||||
TEXT_GEN_WEBUI_CHAT_MODE_CHAT_INSTRUCT,
|
||||
ALLOWED_LEGACY_SERVICE_CALL_ARGUMENTS,
|
||||
DOMAIN,
|
||||
HOME_LLM_API_ID,
|
||||
SERVICE_TOOL_NAME,
|
||||
@@ -670,7 +670,7 @@ class LocalLLMAgent(AbstractConversationAgent):
|
||||
|
||||
for name, service in service_dict.get(domain, {}).items():
|
||||
args = flatten_vol_schema(service.schema)
|
||||
args_to_expose = set(args).intersection(ALLOWED_LEGACY_SERVICE_CALL_ARGUMENTS)
|
||||
args_to_expose = set(args).intersection(HassServiceTool.ALLOWED_SERVICE_CALL_ARGUMENTS)
|
||||
service_schema = vol.Schema({
|
||||
vol.Optional(arg): str for arg in args_to_expose
|
||||
})
|
||||
@@ -1042,7 +1042,13 @@ class GenericOpenAIAPIAgent(LocalLLMAgent):
|
||||
model_name: str
|
||||
|
||||
def _load_model(self, entry: ConfigEntry) -> None:
|
||||
self.api_host = f"{'https' if entry.data[CONF_SSL] else 'http'}://{entry.data[CONF_HOST]}:{entry.data[CONF_PORT]}"
|
||||
self.api_host = format_url(
|
||||
hostname=entry.data[CONF_HOST],
|
||||
port=entry.data[CONF_PORT],
|
||||
ssl=entry.data[CONF_SSL],
|
||||
path=""
|
||||
)
|
||||
|
||||
self.api_key = entry.data.get(CONF_OPENAI_API_KEY)
|
||||
self.model_name = entry.data.get(CONF_CHAT_MODEL)
|
||||
|
||||
@@ -1249,7 +1255,12 @@ class OllamaAPIAgent(LocalLLMAgent):
|
||||
model_name: str
|
||||
|
||||
def _load_model(self, entry: ConfigEntry) -> None:
|
||||
self.api_host = f"{'https' if entry.data[CONF_SSL] else 'http'}://{entry.data[CONF_HOST]}:{entry.data[CONF_PORT]}"
|
||||
self.api_host = format_url(
|
||||
hostname=entry.data[CONF_HOST],
|
||||
port=entry.data[CONF_PORT],
|
||||
ssl=entry.data[CONF_SSL],
|
||||
path=""
|
||||
)
|
||||
self.api_key = entry.data.get(CONF_OPENAI_API_KEY)
|
||||
self.model_name = entry.data.get(CONF_CHAT_MODEL)
|
||||
|
||||
|
||||
@@ -38,7 +38,7 @@ from homeassistant.helpers.selector import (
|
||||
from homeassistant.util.package import is_installed
|
||||
from importlib.metadata import version
|
||||
|
||||
from .utils import download_model_from_hf, install_llama_cpp_python, MissingQuantizationException
|
||||
from .utils import download_model_from_hf, install_llama_cpp_python, format_url, MissingQuantizationException
|
||||
from .const import (
|
||||
CONF_CHAT_MODEL,
|
||||
CONF_MAX_TOKENS,
|
||||
@@ -503,7 +503,12 @@ class ConfigFlow(BaseLlamaConversationConfigFlow, config_entries.ConfigFlow, dom
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
|
||||
models_result = requests.get(
|
||||
f"{'https' if self.model_config[CONF_SSL] else 'http'}://{self.model_config[CONF_HOST]}:{self.model_config[CONF_PORT]}/v1/internal/model/list",
|
||||
format_url(
|
||||
hostname=self.model_config[CONF_HOST],
|
||||
port=self.model_config[CONF_PORT],
|
||||
ssl=self.model_config[CONF_SSL],
|
||||
path="/v1/internal/model/list"
|
||||
),
|
||||
timeout=5, # quick timeout
|
||||
headers=headers
|
||||
)
|
||||
@@ -535,7 +540,12 @@ class ConfigFlow(BaseLlamaConversationConfigFlow, config_entries.ConfigFlow, dom
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
|
||||
models_result = requests.get(
|
||||
f"{'https' if self.model_config[CONF_SSL] else 'http'}://{self.model_config[CONF_HOST]}:{self.model_config[CONF_PORT]}/api/tags",
|
||||
format_url(
|
||||
hostname=self.model_config[CONF_HOST],
|
||||
port=self.model_config[CONF_PORT],
|
||||
ssl=self.model_config[CONF_SSL],
|
||||
path="/api/tags"
|
||||
),
|
||||
timeout=5, # quick timeout
|
||||
headers=headers
|
||||
)
|
||||
|
||||
@@ -76,7 +76,6 @@ DEFAULT_PORT = "5000"
|
||||
DEFAULT_SSL = False
|
||||
CONF_EXTRA_ATTRIBUTES_TO_EXPOSE = "extra_attributes_to_expose"
|
||||
DEFAULT_EXTRA_ATTRIBUTES_TO_EXPOSE = ["rgb_color", "brightness", "temperature", "humidity", "fan_mode", "media_title", "volume_level", "item", "wind_speed"]
|
||||
ALLOWED_LEGACY_SERVICE_CALL_ARGUMENTS = ["rgb_color", "brightness", "temperature", "humidity", "fan_mode", "hvac_mode", "preset_mode", "item", "duration"]
|
||||
CONF_PROMPT_TEMPLATE = "prompt_template"
|
||||
PROMPT_TEMPLATE_CHATML = "chatml"
|
||||
PROMPT_TEMPLATE_COMMAND_R = "command-r"
|
||||
|
||||
@@ -203,3 +203,6 @@ def install_llama_cpp_python(config_dir: str):
|
||||
time.sleep(0.5) # I still don't know why this is required
|
||||
|
||||
return True
|
||||
|
||||
def format_url(*, hostname: str, port: str, ssl: bool, path: str):
|
||||
return f"{'https' if ssl else 'http'}://{hostname}{ ':' + port if port else ''}{path}"
|
||||
Reference in New Issue
Block a user