# home-llm/custom_components/llama_conversation/utils.py
import time
import os
import re
import ipaddress
import sys
import platform
import logging
import multiprocessing
import voluptuous as vol
import webcolors
import json
import base64
from pathlib import Path
from typing import Any, Dict, List, Sequence, Tuple, cast
from webcolors import CSS3
from importlib.metadata import version
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.components import conversation
from homeassistant.helpers import config_validation as cv
from homeassistant.helpers import intent, llm, aiohttp_client
from homeassistant.requirements import pip_kwargs
from homeassistant.util import color
from homeassistant.util.package import install_package, is_installed
from voluptuous_openapi import convert as convert_to_openapi
from .const import (
DOMAIN,
EMBEDDED_LLAMA_CPP_PYTHON_VERSION,
ALLOWED_SERVICE_CALL_ARGUMENTS,
SERVICE_TOOL_ALLOWED_SERVICES,
SERVICE_TOOL_ALLOWED_DOMAINS,
SERVICE_TOOL_NAME,
)
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from llama_cpp.llama_types import ChatCompletionRequestMessage, ChatCompletionTool
else:
ChatCompletionRequestMessage = Any
ChatCompletionTool = Any
_LOGGER = logging.getLogger(__name__)
CSS3_NAME_TO_RGB = {
name: webcolors.name_to_rgb(name, CSS3)
for name
in webcolors.names(CSS3)
}
class MissingQuantizationException(Exception):
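    """Raised when the requested GGUF quantization is not available in the Hugging Face repository;
    carries the list of quantizations that were found instead."""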
def __init__(self, missing_quant: str, available_quants: list[str]):
super().__init__(missing_quant, available_quants)
self.missing_quant = missing_quant
self.available_quants = available_quants
class MalformedToolCallException(Exception):
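    """Raised when the model emits a tool call that cannot be parsed or validated.
    as_tool_messages() renders the failure back into the conversation as an assistant
    tool call plus a tool result describing the error."""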
def __init__(self, agent_id: str, tool_call_id: str, tool_name: str, tool_args: str, error_msg: str):
super().__init__(agent_id, tool_call_id, tool_name, tool_args, error_msg)
self.agent_id = agent_id
self.tool_call_id = tool_call_id
self.tool_name = tool_name
self.tool_args = tool_args
self.error_msg = error_msg
def as_tool_messages(self) -> Sequence[conversation.Content]:
return [
conversation.AssistantContent(
self.agent_id, tool_calls=[llm.ToolInput(self.tool_name, {})]
),
conversation.ToolResultContent(
self.agent_id, self.tool_call_id, self.tool_name,
{"error": f"Error occurred calling tool with args='{self.tool_args}': {self.error_msg}" }
)]
def closest_color(requested_color):
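    """Return the CSS3 color name nearest to an (r, g, b) tuple by squared Euclidean distance.

    Illustrative example: closest_color((254, 0, 0)) should resolve to "red".
    """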
min_colors = {}
for name, rgb in CSS3_NAME_TO_RGB.items():
r_c, g_c, b_c = rgb
rd = (r_c - requested_color[0]) ** 2
gd = (g_c - requested_color[1]) ** 2
bd = (b_c - requested_color[2]) ** 2
min_colors[(rd + gd + bd)] = name
return min_colors[min(min_colors.keys())]
def flatten_vol_schema(schema):
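    """Flatten a voluptuous schema into a list of slash-separated key paths.

    Illustrative example: vol.Schema({vol.Required("rgb_color"): str}) flattens to ["rgb_color"].
    """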
flattened = []
def _flatten(current_schema, prefix=''):
if isinstance(current_schema, vol.Schema):
if isinstance(current_schema.schema, vol.validators._WithSubValidators):
for subval in current_schema.schema.validators:
_flatten(subval, prefix)
elif isinstance(current_schema.schema, dict):
for key, val in current_schema.schema.items():
_flatten(val, prefix + str(key) + '/')
elif isinstance(current_schema, vol.validators._WithSubValidators):
for subval in current_schema.validators:
_flatten(subval, prefix)
elif callable(current_schema):
flattened.append(prefix[:-1] if prefix else prefix)
_flatten(schema)
return flattened
def custom_custom_serializer(value):
"""a vol schema is really not straightforward to convert back into a dictionary"""
if value is cv.ensure_list:
return { "type": "list" }
if value is color.color_name_to_rgb:
return { "type": "string" }
if value is intent.non_empty_string:
return { "type": "string" }
    # media_player registers an intent slot validator as a lambda, which cannot be
    # identified by type, so probe it with a sample value instead
try:
if value(100) == 1:
return { "type": "integer" }
except Exception:
pass
    # vol.In still raises further down the chain (even though vol should arguably handle it), so convert it explicitly here
if isinstance(value, vol.In):
if isinstance(value.container, dict):
return { "enum": list(value.container.keys()) }
else:
return { "enum": list(value.container) }
if isinstance(value, list):
result = {}
for x in value:
result.update(custom_custom_serializer(x))
return result
return cv.custom_serializer(value)
def download_model_from_hf(model_name: str, quantization_type: str, storage_folder: str, file_lookup_only: bool = False):
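    """Resolve and download the GGUF file for the requested quantization from a Hugging Face repo,
    returning the local file path; with file_lookup_only=True only the local cache is consulted.

    Illustrative example (hypothetical repo name):
        download_model_from_hf("acon96/Home-3B-v3-GGUF", "Q4_K_M", "/config/models")
    """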
try:
from huggingface_hub import hf_hub_download, HfFileSystem
except Exception as ex:
raise Exception(f"Failed to import huggingface-hub library. Please re-install the integration.") from ex
fs = HfFileSystem()
potential_files = [ f for f in fs.glob(f"{model_name}/*.gguf") ]
wanted_file = [f for f in potential_files if (f"{quantization_type.lower()}.gguf" in f or f"{quantization_type.upper()}.gguf" in f)]
if len(wanted_file) != 1:
available_quants = [
re.split(r"\.|-", file.removesuffix(".gguf"))[-1].upper()
for file in potential_files
]
raise MissingQuantizationException(quantization_type, available_quants)
try:
os.makedirs(storage_folder, exist_ok=True)
except Exception as ex:
raise Exception(f"Failed to create the required folder for storing models! You may need to create the path '{storage_folder}' manually.") from ex
return hf_hub_download(
repo_id=model_name,
repo_type="model",
filename=wanted_file[0].removeprefix(model_name + "/"),
cache_dir=storage_folder,
local_files_only=file_lookup_only
)
def _load_extension():
"""
    Attempts to import llama-cpp-python so we can verify it will not crash Home Assistant.
    This must live at module level because the subprocess uses the 'spawn' start method.
    ModuleNotFoundError is ignored: it only means the package is not installed, not that loading it would crash HA.
"""
import importlib
try:
importlib.import_module("llama_cpp")
except ModuleNotFoundError:
pass
def validate_llama_cpp_python_installation():
"""
Spawns another process and tries to import llama.cpp to avoid crashing the main process
"""
    mp_ctx = multiprocessing.get_context('spawn') # 'spawn' is required because forking is unsafe alongside the running asyncio event loop
process = mp_ctx.Process(target=_load_extension)
process.start()
process.join()
if process.exitcode != 0:
raise Exception(f"Failed to properly initialize llama-cpp-python. (Exit code {process.exitcode}.)")
def get_llama_cpp_python_version():
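    """Return the installed llama-cpp-python version string, or None if it is not installed."""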
if not is_installed("llama-cpp-python"):
return None
return version("llama-cpp-python")
def get_runtime_and_platform_suffix() -> Tuple[str, str]:
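    """Return the CPython tag and normalized machine architecture used in wheel filenames.

    Illustrative example: on 64-bit ARM running Python 3.13 this returns ("cp313", "aarch64").
    """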
runtime_version = f"cp{sys.version_info.major}{sys.version_info.minor}"
platform_suffix = platform.machine()
# remap other names for architectures to the names we use
if platform_suffix == "arm64":
platform_suffix = "aarch64"
if platform_suffix == "i386" or platform_suffix == "amd64":
platform_suffix = "x86_64"
return runtime_version, platform_suffix
async def get_available_llama_cpp_versions(hass: HomeAssistant) -> List[Tuple[str, bool]]:
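    """Return installable llama-cpp-python builds as (version_or_wheel_name, is_local_wheel) tuples:
    release versions scraped from the GitHub Pages wheel index, plus any matching .whl files placed
    next to this integration for the current Python version and architecture."""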
github_index_url = "https://acon96.github.io/llama-cpp-python/whl/ha/llama-cpp-python/"
session = aiohttp_client.async_get_clientsession(hass)
try:
async with session.get(github_index_url) as resp:
if resp.status != 200:
raise Exception(f"Failed to fetch available versions from GitHub (HTTP {resp.status})")
text = await resp.text()
# pull version numbers out of h2 tags
versions = re.findall(r"<h2.*>(.+)</h2>", text)
remote = sorted([(v, False) for v in versions], reverse=True)
except Exception as ex:
_LOGGER.warning(f"Error fetching available versions from GitHub: {repr(ex)}")
remote = []
runtime_version, platform_suffix = get_runtime_and_platform_suffix()
folder = os.path.dirname(__file__)
potential_wheels = sorted([ path for path in os.listdir(folder) if path.endswith(f"{platform_suffix}.whl") ], reverse=True)
local = [ (wheel, True) for wheel in potential_wheels if runtime_version in wheel and "llama_cpp_python" in wheel]
return remote + local
def install_llama_cpp_python(config_dir: str, force_reinstall: bool = False, specific_version: str | None = None) -> bool:
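    """Install llama-cpp-python into Home Assistant's environment from either a local .whl file or the
    prebuilt wheels on GitHub (EMBEDDED_LLAMA_CPP_PYTHON_VERSION unless specific_version is given).
    Returns True when a usable installation is present afterwards, False otherwise."""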
installed_wrong_version = False
if is_installed("llama-cpp-python") and not force_reinstall:
if version("llama-cpp-python") != EMBEDDED_LLAMA_CPP_PYTHON_VERSION:
installed_wrong_version = True
else:
time.sleep(0.5) # I still don't know why this is required
return True
runtime_version, platform_suffix = get_runtime_and_platform_suffix()
if not specific_version:
specific_version = EMBEDDED_LLAMA_CPP_PYTHON_VERSION
if ".whl" in specific_version:
wheel_location = os.path.join(os.path.dirname(__file__), specific_version)
else:
wheel_location = f"https://github.com/acon96/llama-cpp-python/releases/download/{specific_version}/llama_cpp_python-{specific_version}-{runtime_version}-{runtime_version}-linux_{platform_suffix}.whl"
if install_package(wheel_location, **pip_kwargs(config_dir)):
_LOGGER.info("llama-cpp-python successfully installed")
return True
# if it is just the wrong version installed then ignore the installation error
if not installed_wrong_version:
_LOGGER.error(
"Error installing llama-cpp-python. Could not install the binary wheels from GitHub." + \
"Please manually build or download the wheels and place them in the `/config/custom_components/llama_conversation` directory." + \
"Make sure that you download the correct .whl file for your platform and python version from the GitHub releases page."
)
return False
else:
_LOGGER.info(
"Error installing llama-cpp-python. Could not install the binary wheels from GitHub." + \
f"You already have a version of llama-cpp-python ({version('llama-cpp-python')}) installed, however it may not be compatible!"
)
time.sleep(0.5) # I still don't know why this is required
return True
def format_url(*, hostname: str, port: str, ssl: bool, path: str):
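    """Build a base URL from its parts.

    Illustrative example: format_url(hostname="localhost", port="8080", ssl=False, path="/v1") -> "http://localhost:8080/v1"
    """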
return f"{'https' if ssl else 'http'}://{hostname}{ ':' + port if port else ''}{path}"
def get_oai_formatted_tools(llm_api: llm.APIInstance, domains: list[str]) -> List[ChatCompletionTool]:
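    """Convert the tools exposed by the selected LLM API into OpenAI-style "function" tool definitions.
    The generic Home Assistant service-call tool is expanded into one definition per allowed service."""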
result: List[ChatCompletionTool] = []
for tool in llm_api.tools:
        # when multiple Home Assistant LLM APIs are combined, tool names get a prefix to
        # differentiate them, so match on the suffix rather than the exact name
if tool.name.endswith(SERVICE_TOOL_NAME):
            result.extend([{
                "type": "function",
                "function": {
                    "name": service_tool["name"],
                    "description": f"Call the Home Assistant service '{service_tool['name']}'",
                    "parameters": convert_to_openapi(service_tool["arguments"], custom_serializer=llm_api.custom_serializer)
                }
            } for service_tool in get_home_llm_tools(llm_api, domains) ])
else:
result.append({
"type": "function",
"function": {
"name": tool.name,
"description": tool.description or "",
"parameters": convert_to_openapi(tool.parameters, custom_serializer=llm_api.custom_serializer)
}
})
return result
def get_oai_formatted_messages(conversation: Sequence[conversation.Content], user_content_as_list: bool = False, tool_args_to_str: bool = True) -> List[ChatCompletionRequestMessage]:
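    """Convert Home Assistant conversation content into OpenAI-style chat messages.
    Image attachments are inlined as base64: as image_url content parts when user_content_as_list
    is True, otherwise under an "images" key on the user message."""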
messages: List[ChatCompletionRequestMessage] = []
for message in conversation:
if message.role == "system":
messages.append({
"role": "system",
"content": message.content
})
elif message.role == "user":
images: list[str] = []
for attachment in message.attachments or ():
if not attachment.mime_type.startswith("image/"):
raise HomeAssistantError(
translation_domain=DOMAIN,
translation_key="unsupported_attachment_type",
)
images.append(get_file_contents_base64(attachment.path))
if user_content_as_list:
content = [{ "type": "text", "text": message.content }]
for image in images:
content.append({ "type": "image_url", "image_url": {"url": image } })
messages.append({
"role": "user",
"content": content
})
else:
message = {
"role": "user",
"content": message.content
}
if images:
message["images"] = images
messages.append(message)
elif message.role == "assistant":
if message.tool_calls:
messages.append({
"role": "assistant",
"content": str(message.content),
"tool_calls": [
{
"type" : "function",
"id": t.id,
"function": {
"arguments": cast(str, json.dumps(t.tool_args) if tool_args_to_str else t.tool_args),
"name": t.tool_name,
}
} for t in message.tool_calls
]
})
elif message.role == "tool_result":
messages.append({
"role": "tool",
# FIXME: what is the correct format for content here? gemma expects name and result
# "content": json.dumps(message.tool_result),
"content": {
"name": message.tool_name,
"response": { "result": message.tool_result },
},
"tool_call_id": message.tool_call_id
})
return messages
def get_home_llm_tools(llm_api: llm.APIInstance, domains: list[str]) -> List[Dict[str, Any]]:
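    """Build {"name", "arguments"} tool entries for services in the allowed domains, each with a
    reduced voluptuous schema (a required target_device plus whitelisted optional arguments)."""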
service_dict = llm_api.api.hass.services.async_services()
all_services = []
scripts_added = False
for domain in domains:
if domain not in SERVICE_TOOL_ALLOWED_DOMAINS:
continue
# scripts show up as individual services
if domain == "script" and not scripts_added:
all_services.extend([
("script.reload", vol.Schema({vol.Required("target_device"): str})),
("script.turn_on", vol.Schema({vol.Required("target_device"): str})),
("script.turn_off", vol.Schema({vol.Required("target_device"): str})),
("script.toggle", vol.Schema({vol.Required("target_device"): str})),
])
scripts_added = True
continue
for name, service in service_dict.get(domain, {}).items():
if name not in SERVICE_TOOL_ALLOWED_SERVICES:
continue
args = flatten_vol_schema(service.schema)
args_to_expose = set(args).intersection(ALLOWED_SERVICE_CALL_ARGUMENTS)
service_schema = vol.Schema({
vol.Required("target_device"): str,
**{vol.Optional(arg): str for arg in args_to_expose}
})
all_services.append((f"{domain}.{name}", service_schema))
tools: List[Dict[str, Any]] = [
{ "name": service[0], "arguments": service[1] } for service in all_services
]
return tools
def parse_raw_tool_call(raw_block: str | dict, agent_id: str) -> tuple[llm.ToolInput | None, str | None]:
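    """Parse a raw tool call emitted by the model into an (llm.ToolInput, to_say) pair.

    Accepted shapes (illustrative examples):
      {"name": "HassTurnOn", "arguments": {"name": "kitchen light"}}        # generic tool call
      {"service": "light.turn_on", "target_device": "light.kitchen"}        # legacy Home-LLM service call
      call:HassTurnOn{name:<escape>light.living_room_rgbww<escape>}         # "gemma" text format
    Raises MalformedToolCallException if nothing parses or validates.
    """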
if isinstance(raw_block, dict):
parsed_tool_call = raw_block
else:
try:
parsed_tool_call: dict = json.loads(raw_block)
except json.JSONDecodeError:
# handle the "gemma" tool calling format
# call:HassTurnOn{name:<escape>light.living_room_rgbww<escape>}
gemma_match = re.finditer(r"call:(?P<name>\w+){(?P<args>.+)}", raw_block)
for match in gemma_match:
tool_name = match.group("name")
raw_args = match.group("args")
args_dict = {}
for arg_match in re.finditer(r"(?P<key>\w+):<escape>(?P<value>.+?)<escape>", raw_args):
args_dict[arg_match.group("key")] = arg_match.group("value")
parsed_tool_call = {
"name": tool_name,
"arguments": args_dict
}
break # TODO: how do we properly handle multiple tool calls in one response?
else:
raise MalformedToolCallException(agent_id, "", "unknown", str(raw_block), "Tool call was not properly formatted JSON")
# try to validate either format
is_services_tool_call = False
try:
base_schema_to_validate = vol.Schema({
vol.Required("name"): str,
vol.Required("arguments"): vol.Union(str, dict),
})
base_schema_to_validate(parsed_tool_call)
except vol.Error as ex:
try:
home_llm_schema_to_validate = vol.Schema({
vol.Required('service'): str,
vol.Required('target_device'): str,
vol.Optional('rgb_color'): str,
vol.Optional('brightness'): vol.Coerce(float),
vol.Optional('temperature'): vol.Coerce(float),
vol.Optional('humidity'): vol.Coerce(float),
vol.Optional('fan_mode'): str,
vol.Optional('hvac_mode'): str,
vol.Optional('preset_mode'): str,
vol.Optional('duration'): str,
vol.Optional('item'): str,
})
home_llm_schema_to_validate(parsed_tool_call)
is_services_tool_call = True
except vol.Error as ex:
_LOGGER.info(f"LLM produced an improperly formatted response: {repr(ex)}")
raise MalformedToolCallException(agent_id, "", "unknown", str(raw_block), "Tool call was not properly formatted")
# try to fix certain arguments
args_dict = parsed_tool_call if is_services_tool_call else parsed_tool_call["arguments"]
tool_name = SERVICE_TOOL_NAME if is_services_tool_call else parsed_tool_call["name"]
if isinstance(args_dict, str):
if not args_dict.strip():
args_dict = {} # don't attempt to parse empty arguments
else:
try:
args_dict = json.loads(args_dict)
except json.JSONDecodeError:
raise MalformedToolCallException(agent_id, "", tool_name, str(args_dict), "Tool arguments were not properly formatted JSON")
# make sure brightness is 0-255 and not a percentage
if "brightness" in args_dict and 0.0 < args_dict["brightness"] <= 1.0:
args_dict["brightness"] = int(args_dict["brightness"] * 255)
# convert string "tuple" to a list for RGB colors
if "rgb_color" in args_dict and isinstance(args_dict["rgb_color"], str):
args_dict["rgb_color"] = [ int(x) for x in args_dict["rgb_color"][1:-1].split(",") ]
to_say = args_dict.pop("to_say", "")
tool_input = llm.ToolInput(
tool_name=tool_name,
tool_args=args_dict,
)
return tool_input, to_say
def is_valid_hostname(host: str) -> bool:
"""
Validates whether a string is a valid hostname or IP address,
rejecting URLs, paths, ports, query strings, etc.
"""
if not host or not isinstance(host, str):
return False
# Normalize: strip whitespace
host = host.strip().lower()
# Special case: localhost
if host == "localhost":
return True
# Try to parse as IPv4
try:
ipaddress.IPv4Address(host)
return True
except ipaddress.AddressValueError:
pass
# Try to parse as IPv6
try:
ipaddress.IPv6Address(host)
return True
except ipaddress.AddressValueError:
pass
    # Validate as a domain name (RFC 1034/1123)
    # Rules:
    # - Only a-z, 0-9, and hyphens within labels
    # - No leading/trailing hyphens in a label
    # - Max 63 chars per label
    # - Must end in a TLD of at least 2 letters (single-label names other than "localhost" are rejected)
    # - No empty labels (no consecutive dots)
domain_pattern = re.compile(r"^[a-z0-9]([a-z0-9\-]{0,61}[a-z0-9])?(\.[a-z0-9]([a-z0-9\-]{0,61}[a-z0-9])?)*\.[a-z]{2,}$")
return bool(domain_pattern.match(host))
def get_file_contents_base64(file_path: Path) -> str:
"""Reads a file and returns its contents encoded in base64."""
with open(file_path, "rb") as f:
encoded_bytes = base64.b64encode(f.read())
encoded_str = encoded_bytes.decode('utf-8')
return encoded_str