mirror of
https://github.com/acon96/home-llm.git
synced 2026-01-09 13:48:05 -05:00
146 lines
4.9 KiB
Python
146 lines
4.9 KiB
Python
"""The Local LLM Conversation integration."""
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
from typing import Final
|
|
|
|
import homeassistant.components.conversation as ha_conversation
|
|
from homeassistant.config_entries import ConfigEntry
|
|
from homeassistant.const import ATTR_ENTITY_ID, Platform
|
|
from homeassistant.core import HomeAssistant
|
|
from homeassistant.exceptions import HomeAssistantError
|
|
from homeassistant.helpers import config_validation as cv, llm
|
|
from homeassistant.util.json import JsonObjectType
|
|
|
|
import voluptuous as vol
|
|
|
|
|
|
from .const import (
|
|
ALLOWED_SERVICE_CALL_ARGUMENTS,
|
|
DOMAIN,
|
|
HOME_LLM_API_ID,
|
|
SERVICE_TOOL_NAME,
|
|
SERVICE_TOOL_ALLOWED_SERVICES,
|
|
SERVICE_TOOL_ALLOWED_DOMAINS,
|
|
)
|
|
|
|
# Logger shared by every function and class in this module.
_LOGGER = logging.getLogger(name=__name__)

# Platforms forwarded to (and unloaded from) for each config entry.
PLATFORMS: tuple[Platform, ...] = (Platform.CONVERSATION,)

# The integration is set up through config entries only
# (see `cv.config_entry_only_config_schema`).
CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN)
|
|
|
|
|
|
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
    """Set up a Local LLM config entry.

    Registers the Home-LLM compatibility API if no API with
    ``HOME_LLM_API_ID`` is registered yet. Stores the entry in
    ``hass.data[DOMAIN]`` keyed by its entry id. Forwards setup to ``PLATFORMS``.

    Returns:
        Always True once the platforms have been forwarded.
    """
    # Only register the API once; later entries find it already present.
    # A generator (not a list) lets `any` stop at the first match.
    if not any(api.id == HOME_LLM_API_ID for api in llm.async_get_apis(hass)):
        llm.async_register_api(hass, HomeLLMAPI(hass))

    hass.data.setdefault(DOMAIN, {})[entry.entry_id] = entry

    await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS)
    return True
|
|
|
|
|
|
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
    """Unload a Local LLM config entry.

    Unloads the entry's platforms. Only if that succeeds is the entry
    removed from ``hass.data[DOMAIN]``. This function does not unregister
    the Home-LLM API.

    Returns:
        True if the platforms unloaded, False otherwise.
    """
    if not await hass.config_entries.async_unload_platforms(entry, PLATFORMS):
        return False

    # Tolerate a missing domain bucket or entry key instead of raising KeyError.
    hass.data.get(DOMAIN, {}).pop(entry.entry_id, None)
    return True
|
|
|
|
async def async_migrate_entry(hass: HomeAssistant, config_entry: ConfigEntry) -> bool:
    """Migrate an old config entry.

    Version 1 entries (created before v0.3) cannot be upgraded because v0.3
    was a breaking change. The user is asked to delete and re-create them.
    Entries of any other version are accepted unchanged.

    Returns:
        False for version 1 entries (migration refused), True otherwise.
    """
    _LOGGER.debug("Migrating from version %s", config_entry.version)

    # 1 -> 2: This was a breaking change so force users to re-create entries
    if config_entry.version == 1:
        _LOGGER.error("Cannot upgrade models that were created prior to v0.3. Please delete and re-create them.")
        return False

    _LOGGER.debug("Migration to version %s successful", config_entry.version)
    return True
|
|
|
|
class HassServiceTool(llm.Tool):
    """Tool that lets the model call an allow-listed Home Assistant service.

    The model passes ``service`` as ``"<domain>.<service>"`` and the entity to
    act on as ``target_device``, optionally with extra service arguments. The
    call is rejected unless both the domain and the service name are in the
    configured allow-lists.
    """

    name: Final[str] = SERVICE_TOOL_NAME
    description: Final[str] = "Executes a Home Assistant service"

    # Voluptuous schema of the arguments the model may supply.
    parameters = vol.Schema({
        vol.Required('service'): str,
        vol.Required('target_device'): str,
        vol.Optional('rgb_color'): str,
        vol.Optional('brightness'): float,
        vol.Optional('temperature'): float,
        vol.Optional('humidity'): float,
        vol.Optional('fan_mode'): str,
        vol.Optional('hvac_mode'): str,
        vol.Optional('preset_mode'): str,
        vol.Optional('duration'): str,
        vol.Optional('item'): str,
    })

    ALLOWED_SERVICES: Final[list[str]] = SERVICE_TOOL_ALLOWED_SERVICES
    ALLOWED_DOMAINS: Final[list[str]] = SERVICE_TOOL_ALLOWED_DOMAINS

    # In the `script` domain only these generic services may be called,
    # even if other names pass the ALLOWED_SERVICES check.
    ALLOWED_SCRIPT_SERVICES: Final[tuple[str, ...]] = ("reload", "turn_on", "turn_off", "toggle")

    async def async_call(
        self, hass: HomeAssistant, tool_input: llm.ToolInput, llm_context: llm.LLMContext
    ) -> JsonObjectType:
        """Call the requested service on the target device.

        Args:
            hass: Home Assistant instance used to perform the service call.
            tool_input: Tool invocation; ``tool_args`` holds ``service``,
                ``target_device`` and any optional service arguments.
            llm_context: Context of the calling LLM (unused here).

        Returns:
            ``{"result": "unknown service"}`` if the service name is malformed
            or not allowed, ``{"result": "failed"}`` if the service call
            raised, and ``{"result": "success"}`` otherwise.
        """
        tool_args = tool_input.tool_args

        try:
            # Exactly one "." is expected. Zero or several dots make the
            # two-name unpacking raise ValueError.
            domain, service = tool_args["service"].split(".")
        except ValueError:
            return {"result": "unknown service"}

        if domain not in self.ALLOWED_DOMAINS or service not in self.ALLOWED_SERVICES:
            return {"result": "unknown service"}

        if domain == "script" and service not in self.ALLOWED_SCRIPT_SERVICES:
            return {"result": "unknown service"}

        # Read the target only after the service passes validation, so a
        # rejected call never fails on a missing target.
        service_data = {ATTR_ENTITY_ID: tool_args["target_device"]}
        # Forward only the allow-listed optional arguments the model supplied.
        service_data.update(
            (attr, tool_args[attr])
            for attr in ALLOWED_SERVICE_CALL_ARGUMENTS
            if attr in tool_args
        )

        try:
            await hass.services.async_call(
                domain,
                service,
                service_data=service_data,
                blocking=True,
            )
        except Exception:  # noqa: BLE001 - tool boundary: report failure to the model instead of raising
            _LOGGER.exception("Failed to execute service for model")
            return {"result": "failed"}

        return {"result": "success"}
|
|
|
|
class HomeLLMAPI(llm.API):
    """Compatibility LLM API for the older (v3 and earlier) Home LLM models.

    Each API instance exposes a single `HassServiceTool`, which calls Home
    Assistant services by name on a target device.
    """

    def __init__(self, hass: HomeAssistant) -> None:
        """Initialise the API with the Home-LLM id and display name."""
        super().__init__(hass=hass, id=HOME_LLM_API_ID, name="Home-LLM (v1-v3)")

    async def async_get_api_instance(self, llm_context: llm.LLMContext) -> llm.APIInstance:
        """Build an API instance bound to the given LLM context.

        A fresh `HassServiceTool` is created for every instance returned.
        """
        prompt = "Call services in Home Assistant by passing the service name and the device to control."
        service_tool = HassServiceTool()
        return llm.APIInstance(
            api=self,
            api_prompt=prompt,
            llm_context=llm_context,
            tools=[service_tool],
        )
|