mirror of https://github.com/acon96/home-llm.git (synced 2026-01-08 21:28:05 -05:00)

update naming and start implementing new LLM API support
TODO.md
@@ -1,6 +1,13 @@
 # TODO
-- [x] detection/mitigation of too many entities being exposed & blowing out the context length
+- [ ] support new LLM APIs
+    - rewrite how services are called
+    - handle no API selected
+    - rewrite prompts + service block formats
+    - update dataset so new models will work with the API
+- [ ] make ICL examples into conversation turns
+- [ ] translate ICL examples + make better ones
+- [ ] areas/room support
+- [x] detection/mitigation of too many entities being exposed & blowing out the context length
 - [ ] figure out DPO to improve response quality
 - [ ] train the model to respond to house events
     - present the model with an event + a "prompt" from the user of what you want it to do (i.e. turn on the lights when I get home = the model turns on lights when your entity presence triggers as being home)
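The first TODO item, "support new LLM APIs", refers to the `homeassistant.helpers.llm` abstraction added in Home Assistant 2024.5, which the code changes below start wiring in. A minimal sketch of the pattern, using only the calls that appear in this commit (`llm.async_get_api`, `async_get_tools`, `prompt_template`); the helper name and fallback string are illustrative, not the integration's actual code:

```python
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import llm


def render_tool_prompt(hass: HomeAssistant, api_id: str) -> str:
    """Render a registered LLM API's tools as a text block for the system prompt."""
    try:
        llm_api = llm.async_get_api(hass, api_id)
    except HomeAssistantError:
        # API id not registered (or nothing selected): expose nothing
        return "No tools exposed."

    # The integration flattens tool.parameters (a voluptuous schema) into
    # argument names; repr() is used here for brevity.
    tool_lines = [
        f"{tool.name}({tool.parameters}) - {tool.description}"
        for tool in llm_api.async_get_tools()
    ]
    return llm_api.prompt_template + "\n" + "\n".join(tool_lines)
```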
custom_components/llama_conversation/__init__.py
@@ -1,4 +1,4 @@
-"""The Local LLaMA Conversation integration."""
+"""The Local LLM Conversation integration."""
 from __future__ import annotations

 import logging
@@ -9,8 +9,8 @@ from homeassistant.core import HomeAssistant
 from homeassistant.helpers import config_validation as cv

 from .agent import (
-    LLaMAAgent,
-    LocalLLaMAAgent,
+    LocalLLMAgent,
+    LlamaCppAgent,
     GenericOpenAIAPIAgent,
     TextGenerationWebuiAgent,
     LlamaCppPythonAPIAgent,
@@ -38,19 +38,19 @@ async def update_listener(hass: HomeAssistant, entry: ConfigEntry):
     hass.data[DOMAIN][entry.entry_id] = entry

     # call update handler
-    agent: LLaMAAgent = await ha_conversation._get_agent_manager(hass).async_get_agent(entry.entry_id)
+    agent: LocalLLMAgent = await ha_conversation.get_agent_manager(hass).async_get_agent(entry.entry_id)
     agent._update_options()

     return True

 async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
-    """Set up Local LLaMA Conversation from a config entry."""
+    """Set up Local LLM Conversation from a config entry."""

     def create_agent(backend_type):
         agent_cls = None

         if backend_type in [ BACKEND_TYPE_LLAMA_HF, BACKEND_TYPE_LLAMA_EXISTING ]:
-            agent_cls = LocalLLaMAAgent
+            agent_cls = LlamaCppAgent
         elif backend_type == BACKEND_TYPE_GENERIC_OPENAI:
             agent_cls = GenericOpenAIAPIAgent
         elif backend_type == BACKEND_TYPE_TEXT_GEN_WEBUI:
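For contrast with the if/elif chain in `create_agent` above: the same backend-to-class dispatch can be written as a lookup table. A self-contained, hypothetical illustration; the string keys are stand-ins for the real `BACKEND_TYPE_*` constant values in `const.py`, and the classes are stubs:

```python
# Stub classes standing in for the agent classes imported above.
class LlamaCppAgent: ...
class GenericOpenAIAPIAgent: ...
class TextGenerationWebuiAgent: ...

AGENT_CLASSES = {
    "llama_cpp_hf": LlamaCppAgent,        # BACKEND_TYPE_LLAMA_HF (value assumed)
    "llama_cpp_existing": LlamaCppAgent,  # BACKEND_TYPE_LLAMA_EXISTING (value assumed)
    "generic_openai": GenericOpenAIAPIAgent,
    "text_gen_webui": TextGenerationWebuiAgent,
}

def create_agent(backend_type: str) -> type:
    """Map a configured backend type to its agent class."""
    try:
        return AGENT_CLASSES[backend_type]
    except KeyError:
        raise ValueError(f"unknown backend type: {backend_type}") from None
```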
@@ -78,7 +78,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:


 async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
-    """Unload Local LLaMA."""
+    """Unload Local LLM."""
     hass.data[DOMAIN].pop(entry.entry_id)
     ha_conversation.async_unset_agent(hass, entry)
     return True
custom_components/llama_conversation/agent.py
@@ -18,10 +18,10 @@ from homeassistant.components.conversation import ConversationInput, Conversatio
 from homeassistant.components.conversation.const import DOMAIN as CONVERSATION_DOMAIN
 from homeassistant.components.homeassistant.exposed_entities import async_should_expose
 from homeassistant.config_entries import ConfigEntry
-from homeassistant.const import ATTR_ENTITY_ID, CONF_HOST, CONF_PORT, CONF_SSL, MATCH_ALL
+from homeassistant.const import ATTR_ENTITY_ID, CONF_HOST, CONF_PORT, CONF_SSL, MATCH_ALL, CONF_LLM_HASS_API
 from homeassistant.core import HomeAssistant, callback
-from homeassistant.exceptions import ConfigEntryNotReady, ConfigEntryError, TemplateError
-from homeassistant.helpers import config_validation as cv, intent, template, entity_registry as er
+from homeassistant.exceptions import ConfigEntryNotReady, ConfigEntryError, TemplateError, HomeAssistantError
+from homeassistant.helpers import config_validation as cv, intent, template, entity_registry as er, llm
 from homeassistant.helpers.event import async_track_state_change, async_call_later
 from homeassistant.util import ulid
@@ -114,8 +114,8 @@ _LOGGER = logging.getLogger(__name__)

 CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN)

-class LLaMAAgent(AbstractConversationAgent):
-    """Base LLaMA conversation agent."""
+class LocalLLMAgent(AbstractConversationAgent):
+    """Base Local LLM conversation agent."""

     hass: HomeAssistant
     entry_id: str
@@ -225,6 +225,22 @@ class LLaMAAgent(AbstractConversationAgent):
             return ConversationResult(
                 response=intent_response, conversation_id=conversation_id
             )

+        llm_api: llm.API | None = None
+        if self.entry.options.get(CONF_LLM_HASS_API):
+            try:
+                llm_api = llm.async_get_api(
+                    self.hass, self.entry.options[CONF_LLM_HASS_API]
+                )
+            except HomeAssistantError as err:
+                _LOGGER.error("Error getting LLM API: %s", err)
+                intent_response.async_set_error(
+                    intent.IntentResponseErrorCode.UNKNOWN,
+                    f"Error preparing LLM API: {err}",
+                )
+                return conversation.ConversationResult(
+                    response=intent_response, conversation_id=user_input.conversation_id
+                )
+
         if user_input.conversation_id in self.history:
             conversation_id = user_input.conversation_id
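Fetching the API object is only half of the feature; per the TODO item "rewrite how services are called", model output still has to be routed through the API's tools. A heavily hedged sketch of what that could look like with HA 2024.5's `llm.ToolInput` and `API.async_call_tool`; this is not the commit's code, and the real `ToolInput` dataclass takes more fields than shown here:

```python
import json

from homeassistant.helpers import llm


async def execute_tool_call(llm_api: llm.API, model_output: str):
    """Parse a model's JSON tool call and execute it through the LLM API.

    Assumes the model emits {"name": ..., "arguments": {...}}. ToolInput
    fields beyond tool_name/tool_args (platform, context, user_prompt, ...)
    are omitted for brevity but may be required by the actual dataclass.
    """
    parsed = json.loads(model_output)
    tool_input = llm.ToolInput(
        tool_name=parsed["name"],
        tool_args=parsed.get("arguments", {}),
    )
    # async_call_tool returns a JSON-serializable result for the model
    return await llm_api.async_call_tool(tool_input)
```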
@@ -235,7 +251,7 @@ class LLaMAAgent(AbstractConversationAgent):

         if len(conversation) == 0 or refresh_system_prompt:
             try:
-                message = self._generate_system_prompt(raw_prompt)
+                message = self._generate_system_prompt(raw_prompt, llm_api)
             except TemplateError as err:
                 _LOGGER.error("Error rendering prompt: %s", err)
                 intent_response = intent.IntentResponse(language=user_input.language)
@@ -407,7 +423,7 @@ class LLaMAAgent(AbstractConversationAgent):
         _LOGGER.debug(formatted_prompt)
         return formatted_prompt

-    def _generate_system_prompt(self, prompt_template: str) -> str:
+    def _generate_system_prompt(self, prompt_template: str, llm_api: llm.API) -> str:
         """Generate the system prompt with current entity states"""
         entities_to_expose, domains = self._async_get_exposed_entities()
@@ -487,21 +503,14 @@ class LLaMAAgent(AbstractConversationAgent):

         formatted_states = "\n".join(device_states) + "\n"

-        service_dict = self.hass.services.async_services()
-        all_services = []
-        all_service_names = []
-        for domain in domains:
-            # scripts show up as individual services
-            if domain == "script":
-                all_services.extend(["script.reload()", "script.turn_on()", "script.turn_off()", "script.toggle()"])
-                continue
-
-            for name, service in service_dict.get(domain, {}).items():
-                args = flatten_vol_schema(service.schema)
-                args_to_expose = set(args).intersection(allowed_service_call_arguments)
-                all_services.append(f"{domain}.{name}({','.join(args_to_expose)})")
-                all_service_names.append(f"{domain}.{name}")
-
-        formatted_services = ", ".join(all_services)
+        if llm_api:
+            tools = [
+                f"{tool.name}({flatten_vol_schema(tool.parameters)}) - {tool.description}"
+                for tool in llm_api.async_get_tools()
+            ]
+            formatted_services = llm_api.prompt_template + "\n" + "\n".join(tools)
+        else:
+            formatted_services = "No tools exposed."

         render_variables = {
             "devices": formatted_states,
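Both the old service formatting and the new tool formatting lean on `flatten_vol_schema`, defined elsewhere in the integration. A hedged sketch of what such a helper does (recursively collecting the key names of a voluptuous schema); the project's actual implementation may differ in detail:

```python
import voluptuous as vol


def flatten_vol_schema(schema) -> list[str]:
    """Collect the key names of a (possibly nested) voluptuous schema."""
    flattened = []

    def _flatten(obj):
        if isinstance(obj, vol.Schema):
            _flatten(obj.schema)
        elif isinstance(obj, dict):
            for key, value in obj.items():
                # vol.Required/vol.Optional markers wrap the real key in .schema
                flattened.append(str(getattr(key, "schema", key)))
                _flatten(value)

    _flatten(schema)
    return flattened
```

For example, `flatten_vol_schema(vol.Schema({vol.Required("entity_id"): str, vol.Optional("brightness"): int}))` would return `["entity_id", "brightness"]`.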
@@ -509,15 +518,16 @@ class LLaMAAgent(AbstractConversationAgent):
         }

         if self.in_context_examples:
-            num_examples = int(self.entry.options.get(CONF_NUM_IN_CONTEXT_EXAMPLES, DEFAULT_NUM_IN_CONTEXT_EXAMPLES))
-            render_variables["response_examples"] = "\n".join(icl_example_generator(num_examples, list(entities_to_expose.keys()), all_service_names))
+            # num_examples = int(self.entry.options.get(CONF_NUM_IN_CONTEXT_EXAMPLES, DEFAULT_NUM_IN_CONTEXT_EXAMPLES))
+            # render_variables["response_examples"] = "\n".join(icl_example_generator(num_examples, list(entities_to_expose.keys()), all_service_names))
+            render_variables["response_examples"] = ""

         return template.Template(prompt_template, self.hass).async_render(
             render_variables,
             parse_result=False,
         )

-class LocalLLaMAAgent(LLaMAAgent):
+class LlamaCppAgent(LocalLLMAgent):
     model_path: str
     llm: LlamaType
     grammar: Any
@@ -612,7 +622,7 @@ class LocalLLaMAAgent(LLaMAAgent):
             self.grammar = None

     def _update_options(self):
-        LLaMAAgent._update_options(self)
+        LocalLLMAgent._update_options(self)

         model_reloaded = False
         if self.loaded_model_settings[CONF_CONTEXT_LENGTH] != self.entry.options.get(CONF_CONTEXT_LENGTH, DEFAULT_CONTEXT_LENGTH) or \
@@ -662,7 +672,7 @@ class LocalLLaMAAgent(LLaMAAgent):

     def _async_get_exposed_entities(self) -> tuple[dict[str, str], list[str]]:
         """Takes the super class function results and sorts the entities with the recently updated at the end"""
-        entities, domains = LLaMAAgent._async_get_exposed_entities(self)
+        entities, domains = LocalLLMAgent._async_get_exposed_entities(self)

         # ignore sorting if prompt caching is disabled
         if not self.entry.options.get(CONF_PROMPT_CACHING_ENABLED, DEFAULT_PROMPT_CACHING_ENABLED):
@@ -739,10 +749,20 @@ class LocalLLaMAAgent(LLaMAAgent):
                 self.cache_refresh_after_cooldown = True
                 return

+        llm_api: llm.API | None = None
+        if self.entry.options.get(CONF_LLM_HASS_API):
+            try:
+                llm_api = llm.async_get_api(
+                    self.hass, self.entry.options[CONF_LLM_HASS_API]
+                )
+            except HomeAssistantError:
+                _LOGGER.exception("Failed to get LLM API when caching prompt!")
+                return
+
         try:
             raw_prompt = self.entry.options.get(CONF_PROMPT, DEFAULT_PROMPT)
             prompt = self._format_prompt([
-                { "role": "system", "message": self._generate_system_prompt(raw_prompt)},
+                { "role": "system", "message": self._generate_system_prompt(raw_prompt, llm_api)},
                 { "role": "user", "message": "" }
             ], include_generation_prompt=False)
@@ -839,7 +859,7 @@ class LocalLLaMAAgent(LLaMAAgent):

         return result

-class GenericOpenAIAPIAgent(LLaMAAgent):
+class GenericOpenAIAPIAgent(LocalLLMAgent):
     api_host: str
     api_key: str
     model_name: str
@@ -1046,7 +1066,7 @@ class LlamaCppPythonAPIAgent(GenericOpenAIAPIAgent):

         return endpoint, request_params

-class OllamaAPIAgent(LLaMAAgent):
+class OllamaAPIAgent(LocalLLMAgent):
     api_host: str
     api_key: str
     model_name: str
custom_components/llama_conversation/config_flow.py
@@ -1,4 +1,4 @@
-"""Config flow for Local LLaMA Conversation integration."""
+"""Config flow for Local LLM Conversation integration."""
 from __future__ import annotations

 import os
@@ -13,18 +13,20 @@ import voluptuous as vol

 from homeassistant import config_entries
 from homeassistant.core import HomeAssistant
-from homeassistant.const import CONF_HOST, CONF_PORT, CONF_SSL, UnitOfTime
+from homeassistant.const import CONF_HOST, CONF_PORT, CONF_SSL, CONF_LLM_HASS_API, UnitOfTime
 from homeassistant.data_entry_flow import (
     AbortFlow,
     FlowHandler,
     FlowManager,
     FlowResult,
 )
+from homeassistant.helpers import llm
 from homeassistant.helpers.selector import (
     NumberSelector,
     NumberSelectorConfig,
     NumberSelectorMode,
     TemplateSelector,
+    SelectOptionDict,
     SelectSelector,
     SelectSelectorConfig,
     SelectSelectorMode,
@@ -279,7 +281,7 @@ class BaseLlamaConversationConfigFlow(FlowHandler, ABC):
         """ Finish configuration """

 class ConfigFlow(BaseLlamaConversationConfigFlow, config_entries.ConfigFlow, domain=DOMAIN):
-    """Handle a config flow for Local LLaMA Conversation."""
+    """Handle a config flow for Local LLM Conversation."""

     VERSION = 1
     install_wheel_task = None
@@ -584,7 +586,7 @@ class ConfigFlow(BaseLlamaConversationConfigFlow, config_entries.ConfigFlow, dom
         persona = PERSONA_PROMPTS.get(self.selected_language, PERSONA_PROMPTS.get("en"))
         selected_default_options[CONF_PROMPT] = selected_default_options[CONF_PROMPT].replace("<persona>", persona)

-        schema = vol.Schema(local_llama_config_option_schema(selected_default_options, backend_type))
+        schema = vol.Schema(local_llama_config_option_schema(self.hass, selected_default_options, backend_type))

         if user_input:
             self.options = user_input
@@ -626,7 +628,7 @@ class ConfigFlow(BaseLlamaConversationConfigFlow, config_entries.ConfigFlow, dom

 class OptionsFlow(config_entries.OptionsFlow):
-    """Local LLaMA config flow options handler."""
+    """Local LLM config flow options handler."""

     def __init__(self, config_entry: config_entries.ConfigEntry) -> None:
         """Initialize options flow."""
@@ -656,9 +658,10 @@ class OptionsFlow(config_entries.OptionsFlow):
             description_placeholders["filename"] = filename

         if len(errors) == 0:
-            return self.async_create_entry(title="LLaMA Conversation", data=user_input)
+            return self.async_create_entry(title="Local LLM Conversation", data=user_input)

         schema = local_llama_config_option_schema(
+            self.hass,
             self.config_entry.options,
             self.config_entry.data[CONF_BACKEND_TYPE],
         )
@@ -682,12 +685,31 @@ def insert_after_key(input_dict: dict, key_name: str, other_dict: dict):

     return result

-def local_llama_config_option_schema(options: MappingProxyType[str, Any], backend_type: str) -> dict:
-    """Return a schema for Local LLaMA completion options."""
+def local_llama_config_option_schema(hass: HomeAssistant, options: MappingProxyType[str, Any], backend_type: str) -> dict:
+    """Return a schema for Local LLM completion options."""
     if not options:
         options = DEFAULT_OPTIONS

+    apis: list[SelectOptionDict] = [
+        SelectOptionDict(
+            label="No control",
+            value="none",
+        )
+    ]
+    apis.extend(
+        SelectOptionDict(
+            label=api.name,
+            value=api.id,
+        )
+        for api in llm.async_get_apis(hass)
+    )
+
     result = {
+        vol.Optional(
+            CONF_LLM_HASS_API,
+            description={"suggested_value": options.get(CONF_LLM_HASS_API)},
+            default="none",
+        ): SelectSelector(SelectSelectorConfig(options=apis)),
         vol.Required(
             CONF_PROMPT,
             description={"suggested_value": options.get(CONF_PROMPT)},
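The selector above stores a plain string: either the "none" sentinel or a registered API id. A hedged sketch of how that saved option round-trips when read back (the option key's value and the sentinel handling are illustrative; note the agent code above checks only truthiness, so "none" must be filtered out somewhere like this):

```python
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import llm

CONF_LLM_HASS_API = "llm_hass_api"  # assumed value of the HA constant


def get_selected_api(hass, options: dict) -> llm.API | None:
    """Resolve the configured LLM API, treating "none" as no selection."""
    api_id = options.get(CONF_LLM_HASS_API)
    if not api_id or api_id == "none":
        return None
    try:
        return llm.async_get_api(hass, api_id)
    except HomeAssistantError:
        # Selected API is no longer registered
        return None
```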
custom_components/llama_conversation/manifest.json
@@ -1,6 +1,6 @@
 {
     "domain": "llama_conversation",
-    "name": "LLaMA Conversation",
+    "name": "Local LLM Conversation",
     "version": "0.2.17",
     "codeowners": ["@acon96"],
     "config_flow": true,
docs/Setup.md
@@ -35,7 +35,7 @@ The following link will open your Home Assistant installation and download the i

 [](https://my.home-assistant.io/redirect/hacs_repository/?category=Integration&repository=home-llm&owner=acon96)

-After installation, A "LLaMA Conversation" device should show up in the `Settings > Devices and Services > [Devices]` tab now.
+After installation, A "Local LLM Conversation" device should show up in the `Settings > Devices and Services > [Devices]` tab now.

 ## Path 1: Using the Home Model with the Llama.cpp Backend
 ### Overview
@@ -44,7 +44,7 @@ This setup path involves downloading a fine-tuned model from HuggingFace and int
 ### Step 1: Wheel Installation for llama-cpp-python
 1. In Home Assistant: navigate to `Settings > Devices and Services`
 2. Select the `+ Add Integration` button in the bottom right corner
-3. Search for, and select `LLaMA Conversation`
+3. Search for, and select `Local LLM Conversation`
 4. With the `Llama.cpp (HuggingFace)` backend selected, click `Submit`

 This should download and install `llama-cpp-python` from GitHub. If the installation fails for any reason, follow the manual installation instructions [here](./Backend%20Configuration.md#wheels).
@@ -82,7 +82,7 @@ In order to access the model from another machine, we need to run the Ollama API

 1. In Home Assistant: navigate to `Settings > Devices and Services`
 2. Select the `+ Add Integration` button in the bottom right corner
-3. Search for, and select `LLaMA Conversation`
+3. Search for, and select `Local LLM Conversation`
 4. Select `Ollama API` from the dropdown and click `Submit`
 5. Set up the connection to the API:
    - **IP Address**: Fill out IP Address for the machine hosting Ollama
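Before pointing the integration at an Ollama host, it can help to confirm the API is reachable. A quick sanity check against Ollama's standard REST endpoint (default port 11434); the host address and model name below are placeholders:

```python
import json
import urllib.request

# Placeholders: substitute your Ollama host and a model you have pulled
OLLAMA_URL = "http://192.168.1.10:11434/api/generate"
payload = {
    "model": "llama3",
    "prompt": "Say hello in five words.",
    "stream": False,
}

req = urllib.request.Request(
    OLLAMA_URL,
    data=json.dumps(payload).encode(),
    headers={"Content-Type": "application/json"},
)
# A successful response confirms the host, port, and model name are correct
with urllib.request.urlopen(req) as resp:
    print(json.load(resp)["response"])
```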
hacs.json
@@ -1,6 +1,6 @@
 {
-    "name": "LLaMA Conversation",
-    "homeassistant": "2023.10.0",
+    "name": "Local LLM Conversation",
+    "homeassistant": "2024.5.5",
     "content_in_root": false,
     "render_readme": true
 }
@@ -4,7 +4,7 @@ import pytest
 import jinja2
 from unittest.mock import patch, MagicMock, PropertyMock, AsyncMock, ANY

-from custom_components.llama_conversation.agent import LocalLLaMAAgent, OllamaAPIAgent, TextGenerationWebuiAgent, GenericOpenAIAPIAgent
+from custom_components.llama_conversation.agent import LlamaCppAgent, OllamaAPIAgent, TextGenerationWebuiAgent, GenericOpenAIAPIAgent
 from custom_components.llama_conversation.const import (
     CONF_CHAT_MODEL,
     CONF_MAX_TOKENS,
@@ -140,10 +140,10 @@ def home_assistant_mock():

 @pytest.fixture
 def local_llama_agent_fixture(config_entry, home_assistant_mock):
-    with patch.object(LocalLLaMAAgent, '_load_icl_examples') as load_icl_examples_mock, \
-        patch.object(LocalLLaMAAgent, '_load_grammar') as load_grammar_mock, \
-        patch.object(LocalLLaMAAgent, 'entry', new_callable=PropertyMock) as entry_mock, \
-        patch.object(LocalLLaMAAgent, '_async_get_exposed_entities') as get_exposed_entities_mock, \
+    with patch.object(LlamaCppAgent, '_load_icl_examples') as load_icl_examples_mock, \
+        patch.object(LlamaCppAgent, '_load_grammar') as load_grammar_mock, \
+        patch.object(LlamaCppAgent, 'entry', new_callable=PropertyMock) as entry_mock, \
+        patch.object(LlamaCppAgent, '_async_get_exposed_entities') as get_exposed_entities_mock, \
         patch('homeassistant.helpers.template.Template') as template_mock, \
         patch('custom_components.llama_conversation.agent.importlib.import_module') as import_module_mock, \
         patch('custom_components.llama_conversation.agent.install_llama_cpp_python') as install_llama_cpp_python_mock:
@@ -174,7 +174,7 @@ def local_llama_agent_fixture(config_entry, home_assistant_mock):
         "target_device": "light.kitchen_light",
     }).encode()

-    agent_obj = LocalLLaMAAgent(
+    agent_obj = LlamaCppAgent(
         home_assistant_mock,
         config_entry
     )
@@ -191,7 +191,7 @@ def local_llama_agent_fixture(config_entry, home_assistant_mock):

 async def test_local_llama_agent(local_llama_agent_fixture):

-    local_llama_agent: LocalLLaMAAgent
+    local_llama_agent: LlamaCppAgent
     all_mocks: dict[str, MagicMock]
     local_llama_agent, all_mocks = local_llama_agent_fixture