mirror of
https://github.com/acon96/home-llm.git
synced 2026-01-09 13:48:05 -05:00
Add a platform so agents show up in the UI correctly + bump llama-cpp-python
This commit is contained in:
@@ -6,7 +6,7 @@ from typing import Final
|
||||
|
||||
import homeassistant.components.conversation as ha_conversation
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.const import ATTR_ENTITY_ID
|
||||
from homeassistant.const import ATTR_ENTITY_ID, Platform
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers import config_validation as cv, llm
|
||||
@@ -14,24 +14,8 @@ from homeassistant.util.json import JsonObjectType
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
from .agent import (
|
||||
LocalLLMAgent,
|
||||
LlamaCppAgent,
|
||||
GenericOpenAIAPIAgent,
|
||||
TextGenerationWebuiAgent,
|
||||
LlamaCppPythonAPIAgent,
|
||||
OllamaAPIAgent,
|
||||
)
|
||||
|
||||
from .const import (
|
||||
CONF_BACKEND_TYPE,
|
||||
DEFAULT_BACKEND_TYPE,
|
||||
BACKEND_TYPE_LLAMA_HF,
|
||||
BACKEND_TYPE_LLAMA_EXISTING,
|
||||
BACKEND_TYPE_TEXT_GEN_WEBUI,
|
||||
BACKEND_TYPE_GENERIC_OPENAI,
|
||||
BACKEND_TYPE_LLAMA_CPP_PYTHON_SERVER,
|
||||
BACKEND_TYPE_OLLAMA,
|
||||
ALLOWED_SERVICE_CALL_ARGUMENTS,
|
||||
DOMAIN,
|
||||
HOME_LLM_API_ID,
|
||||
@@ -42,61 +26,26 @@ _LOGGER = logging.getLogger(__name__)
|
||||
|
||||
CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN)
|
||||
|
||||
async def update_listener(hass: HomeAssistant, entry: ConfigEntry):
|
||||
"""Handle options update."""
|
||||
hass.data[DOMAIN][entry.entry_id] = entry
|
||||
|
||||
# call update handler
|
||||
agent: LocalLLMAgent = ha_conversation.get_agent_manager(hass).async_get_agent(entry.entry_id)
|
||||
await hass.async_add_executor_job(agent._update_options)
|
||||
PLATFORMS = (Platform.CONVERSATION,)
|
||||
|
||||
return True
|
||||
|
||||
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||
"""Set up Local LLM Conversation from a config entry."""
|
||||
|
||||
# make sure the API is registered
|
||||
if not any([x.id == HOME_LLM_API_ID for x in llm.async_get_apis(hass)]):
|
||||
llm.async_register_api(hass, HomeLLMAPI(hass))
|
||||
|
||||
def create_agent(backend_type):
|
||||
agent_cls = None
|
||||
|
||||
if backend_type in [ BACKEND_TYPE_LLAMA_HF, BACKEND_TYPE_LLAMA_EXISTING ]:
|
||||
agent_cls = LlamaCppAgent
|
||||
elif backend_type == BACKEND_TYPE_GENERIC_OPENAI:
|
||||
agent_cls = GenericOpenAIAPIAgent
|
||||
elif backend_type == BACKEND_TYPE_TEXT_GEN_WEBUI:
|
||||
agent_cls = TextGenerationWebuiAgent
|
||||
elif backend_type == BACKEND_TYPE_LLAMA_CPP_PYTHON_SERVER:
|
||||
agent_cls = LlamaCppPythonAPIAgent
|
||||
elif backend_type == BACKEND_TYPE_OLLAMA:
|
||||
agent_cls = OllamaAPIAgent
|
||||
|
||||
return agent_cls(hass, entry)
|
||||
|
||||
# create the agent in an executor job because the constructor calls `open()`
|
||||
backend_type = entry.data.get(CONF_BACKEND_TYPE, DEFAULT_BACKEND_TYPE)
|
||||
agent = await hass.async_add_executor_job(create_agent, backend_type)
|
||||
|
||||
# call load model
|
||||
await agent._async_load_model(entry)
|
||||
|
||||
# handle updates to the options
|
||||
entry.async_on_unload(entry.add_update_listener(update_listener))
|
||||
|
||||
ha_conversation.async_set_agent(hass, entry, agent)
|
||||
|
||||
hass.data.setdefault(DOMAIN, {})
|
||||
hass.data[DOMAIN][entry.entry_id] = entry
|
||||
hass.data.setdefault(DOMAIN, {})[entry.entry_id] = entry
|
||||
|
||||
await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS)
|
||||
return True
|
||||
|
||||
|
||||
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||
"""Unload Local LLM."""
|
||||
"""Unload Ollama."""
|
||||
if not await hass.config_entries.async_unload_platforms(entry, PLATFORMS):
|
||||
return False
|
||||
hass.data[DOMAIN].pop(entry.entry_id)
|
||||
ha_conversation.async_unset_agent(hass, entry)
|
||||
return True
|
||||
|
||||
async def async_migrate_entry(hass, config_entry: ConfigEntry):
|
||||
|
||||
@@ -384,4 +384,4 @@ OPTIONS_OVERRIDES = {
|
||||
}
|
||||
|
||||
INTEGRATION_VERSION = "0.3.4"
|
||||
EMBEDDED_LLAMA_CPP_PYTHON_VERSION = "0.2.84"
|
||||
EMBEDDED_LLAMA_CPP_PYTHON_VERSION = "0.2.87"
|
||||
@@ -15,8 +15,8 @@ import time
|
||||
import voluptuous as vol
|
||||
from typing import Literal, Any, Callable
|
||||
|
||||
from homeassistant.components.conversation import ConversationInput, ConversationResult, AbstractConversationAgent
|
||||
import homeassistant.components.conversation as ha_conversation
|
||||
from homeassistant.components.conversation import ConversationInput, ConversationResult, AbstractConversationAgent, ConversationEntity
|
||||
from homeassistant.components import assist_pipeline, conversation as ha_conversation
|
||||
from homeassistant.components.conversation.const import DOMAIN as CONVERSATION_DOMAIN
|
||||
from homeassistant.components.homeassistant.exposed_entities import async_should_expose
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
@@ -25,11 +25,13 @@ from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.exceptions import ConfigEntryNotReady, ConfigEntryError, TemplateError, HomeAssistantError
|
||||
from homeassistant.helpers import config_validation as cv, intent, template, entity_registry as er, llm, \
|
||||
area_registry as ar, device_registry as dr
|
||||
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||
from homeassistant.helpers.aiohttp_client import async_get_clientsession
|
||||
from homeassistant.helpers.event import async_track_state_change, async_call_later
|
||||
from homeassistant.components.sensor import SensorEntity
|
||||
from homeassistant.util import ulid, color
|
||||
|
||||
|
||||
import voluptuous_serialize
|
||||
|
||||
from .utils import closest_color, flatten_vol_schema, custom_custom_serializer, install_llama_cpp_python, \
|
||||
@@ -119,6 +121,14 @@ from .const import (
|
||||
TOOL_FORMAT_REDUCED,
|
||||
TOOL_FORMAT_MINIMAL,
|
||||
ALLOWED_SERVICE_CALL_ARGUMENTS,
|
||||
CONF_BACKEND_TYPE,
|
||||
DEFAULT_BACKEND_TYPE,
|
||||
BACKEND_TYPE_LLAMA_HF,
|
||||
BACKEND_TYPE_LLAMA_EXISTING,
|
||||
BACKEND_TYPE_TEXT_GEN_WEBUI,
|
||||
BACKEND_TYPE_GENERIC_OPENAI,
|
||||
BACKEND_TYPE_LLAMA_CPP_PYTHON_SERVER,
|
||||
BACKEND_TYPE_OLLAMA,
|
||||
)
|
||||
|
||||
# make type checking work for llama-cpp-python without importing it directly at runtime
|
||||
@@ -132,7 +142,50 @@ _LOGGER = logging.getLogger(__name__)
|
||||
|
||||
CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN)
|
||||
|
||||
class LocalLLMAgent(AbstractConversationAgent):
|
||||
async def update_listener(hass: HomeAssistant, entry: ConfigEntry):
|
||||
"""Handle options update."""
|
||||
hass.data[DOMAIN][entry.entry_id] = entry
|
||||
|
||||
# call update handler
|
||||
agent: LocalLLMAgent = ha_conversation.get_agent_manager(hass).async_get_agent(entry.entry_id)
|
||||
await hass.async_add_executor_job(agent._update_options)
|
||||
|
||||
return True
|
||||
|
||||
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry, async_add_entities: AddEntitiesCallback) -> bool:
|
||||
"""Set up Local LLM Conversation from a config entry."""
|
||||
|
||||
def create_agent(backend_type):
|
||||
agent_cls = None
|
||||
|
||||
if backend_type in [ BACKEND_TYPE_LLAMA_HF, BACKEND_TYPE_LLAMA_EXISTING ]:
|
||||
agent_cls = LlamaCppAgent
|
||||
elif backend_type == BACKEND_TYPE_GENERIC_OPENAI:
|
||||
agent_cls = GenericOpenAIAPIAgent
|
||||
elif backend_type == BACKEND_TYPE_TEXT_GEN_WEBUI:
|
||||
agent_cls = TextGenerationWebuiAgent
|
||||
elif backend_type == BACKEND_TYPE_LLAMA_CPP_PYTHON_SERVER:
|
||||
agent_cls = LlamaCppPythonAPIAgent
|
||||
elif backend_type == BACKEND_TYPE_OLLAMA:
|
||||
agent_cls = OllamaAPIAgent
|
||||
|
||||
return agent_cls(hass, entry)
|
||||
|
||||
# create the agent in an executor job because the constructor calls `open()`
|
||||
backend_type = entry.data.get(CONF_BACKEND_TYPE, DEFAULT_BACKEND_TYPE)
|
||||
agent = await hass.async_add_executor_job(create_agent, backend_type)
|
||||
|
||||
# call load model
|
||||
await agent._async_load_model(entry)
|
||||
|
||||
# handle updates to the options
|
||||
entry.async_on_unload(entry.add_update_listener(update_listener))
|
||||
|
||||
async_add_entities([agent])
|
||||
|
||||
return True
|
||||
|
||||
class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
|
||||
"""Base Local LLM conversation agent."""
|
||||
|
||||
hass: HomeAssistant
|
||||
@@ -140,8 +193,13 @@ class LocalLLMAgent(AbstractConversationAgent):
|
||||
history: dict[str, list[dict]]
|
||||
in_context_examples: list[dict]
|
||||
|
||||
_attr_has_entity_name = True
|
||||
|
||||
def __init__(self, hass: HomeAssistant, entry: ConfigEntry) -> None:
|
||||
"""Initialize the agent."""
|
||||
self._attr_name = entry.title
|
||||
self._attr_unique_id = entry.entry_id
|
||||
|
||||
self.hass = hass
|
||||
self.entry_id = entry.entry_id
|
||||
self.history = {}
|
||||
@@ -150,10 +208,28 @@ class LocalLLMAgent(AbstractConversationAgent):
|
||||
CONF_BACKEND_TYPE, DEFAULT_BACKEND_TYPE
|
||||
)
|
||||
|
||||
if self.entry.options.get(CONF_LLM_HASS_API):
|
||||
self._attr_supported_features = (
|
||||
ha_conversation.ConversationEntityFeature.CONTROL
|
||||
)
|
||||
|
||||
self.in_context_examples = None
|
||||
if entry.options.get(CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES, DEFAULT_USE_IN_CONTEXT_LEARNING_EXAMPLES):
|
||||
self._load_icl_examples(entry.options.get(CONF_IN_CONTEXT_EXAMPLES_FILE, DEFAULT_IN_CONTEXT_EXAMPLES_FILE))
|
||||
|
||||
async def async_added_to_hass(self) -> None:
|
||||
"""When entity is added to Home Assistant."""
|
||||
await super().async_added_to_hass()
|
||||
assist_pipeline.async_migrate_engine(
|
||||
self.hass, "conversation", self.entry.entry_id, self.entity_id
|
||||
)
|
||||
ha_conversation.async_set_agent(self.hass, self.entry, self)
|
||||
|
||||
async def async_will_remove_from_hass(self) -> None:
|
||||
"""When entity will be removed from Home Assistant."""
|
||||
ha_conversation.async_unset_agent(self.hass, self.entry)
|
||||
await super().async_will_remove_from_hass()
|
||||
|
||||
def _load_icl_examples(self, filename: str):
|
||||
"""Load info used for generating in context learning examples"""
|
||||
try:
|
||||
@@ -175,6 +251,11 @@ class LocalLLMAgent(AbstractConversationAgent):
|
||||
self.in_context_examples = None
|
||||
|
||||
def _update_options(self):
|
||||
if self.entry.options.get(CONF_LLM_HASS_API):
|
||||
self._attr_supported_features = (
|
||||
ha_conversation.ConversationEntityFeature.CONTROL
|
||||
)
|
||||
|
||||
if self.entry.options.get(CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES, DEFAULT_USE_IN_CONTEXT_LEARNING_EXAMPLES):
|
||||
self._load_icl_examples(self.entry.options.get(CONF_IN_CONTEXT_EXAMPLES_FILE, DEFAULT_IN_CONTEXT_EXAMPLES_FILE))
|
||||
else:
|
||||
@@ -5,6 +5,7 @@
|
||||
"codeowners": ["@acon96"],
|
||||
"config_flow": true,
|
||||
"dependencies": ["conversation"],
|
||||
"after_dependencies": ["assist_pipeline"],
|
||||
"documentation": "https://github.com/acon96/home-llm",
|
||||
"integration_type": "service",
|
||||
"iot_class": "local_polling",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
#!/bin/bash
|
||||
|
||||
VERSION_TO_BUILD="v0.2.77"
|
||||
VERSION_TO_BUILD="v0.2.87"
|
||||
|
||||
# make python11 wheels
|
||||
# docker run -it --rm \
|
||||
|
||||
@@ -4,7 +4,7 @@ import pytest
|
||||
import jinja2
|
||||
from unittest.mock import patch, MagicMock, PropertyMock, AsyncMock, ANY
|
||||
|
||||
from custom_components.llama_conversation.agent import LlamaCppAgent, OllamaAPIAgent, TextGenerationWebuiAgent, GenericOpenAIAPIAgent
|
||||
from custom_components.llama_conversation.conversation import LlamaCppAgent, OllamaAPIAgent, TextGenerationWebuiAgent, GenericOpenAIAPIAgent
|
||||
from custom_components.llama_conversation.const import (
|
||||
CONF_CHAT_MODEL,
|
||||
CONF_MAX_TOKENS,
|
||||
|
||||
Reference in New Issue
Block a user