update naming and start implementing new LLM API support

Alex O'Connell
2024-05-25 17:12:58 -04:00
parent 9cacc4d78e
commit 8a28dd61ad
8 changed files with 108 additions and 59 deletions

View File

@@ -1,6 +1,13 @@
# TODO
- [x] detection/mitigation of too many entities being exposed & blowing out the context length
- [ ] support new LLM APIs
- rewrite how services are called
- handle no API selected
- rewrite prompts + service block formats
- update dataset so new models will work with the API
- [ ] make ICL examples into conversation turns
- [ ] translate ICL examples + make better ones
- [ ] areas/room support
- [x] detection/mitigation of too many entities being exposed & blowing out the context length
- [ ] figure out DPO to improve response quality
- [ ] train the model to respond to house events
- present the model with an event plus a "prompt" from the user describing what it should do (e.g. "turn on the lights when I get home" = the model turns on the lights when your presence entity changes to home); one possible example format is sketched below
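The last item is the least specified, so a concrete shape helps. A minimal sketch of what one event-triggered training example could look like, assuming a simple dict format; every field name here is hypothetical and not part of the dataset yet:

```python
# Hypothetical example format for the "respond to house events" TODO item.
# None of these field names exist in the dataset; they only illustrate
# the idea of pairing an event with a standing user instruction.
example = {
    "user_instruction": "Turn on the lights when I get home.",
    "event": {
        "entity_id": "person.alex",
        "from_state": "not_home",
        "to_state": "home",
    },
    # Expected model output: a tool call rather than a chat reply.
    "expected_response": {
        "tool": "light.turn_on",
        "arguments": {"entity_id": "light.living_room"},
    },
}
```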

View File

@@ -1,4 +1,4 @@
"""The Local LLaMA Conversation integration."""
"""The Local LLM Conversation integration."""
from __future__ import annotations
import logging
@@ -9,8 +9,8 @@ from homeassistant.core import HomeAssistant
from homeassistant.helpers import config_validation as cv
from .agent import (
LLaMAAgent,
LocalLLaMAAgent,
LocalLLMAgent,
LlamaCppAgent,
GenericOpenAIAPIAgent,
TextGenerationWebuiAgent,
LlamaCppPythonAPIAgent,
@@ -38,19 +38,19 @@ async def update_listener(hass: HomeAssistant, entry: ConfigEntry):
hass.data[DOMAIN][entry.entry_id] = entry
# call update handler
agent: LLaMAAgent = await ha_conversation._get_agent_manager(hass).async_get_agent(entry.entry_id)
agent: LocalLLMAgent = await ha_conversation.get_agent_manager(hass).async_get_agent(entry.entry_id)
agent._update_options()
return True
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Set up Local LLaMA Conversation from a config entry."""
"""Set up Local LLM Conversation from a config entry."""
def create_agent(backend_type):
agent_cls = None
if backend_type in [ BACKEND_TYPE_LLAMA_HF, BACKEND_TYPE_LLAMA_EXISTING ]:
agent_cls = LocalLLaMAAgent
agent_cls = LlamaCppAgent
elif backend_type == BACKEND_TYPE_GENERIC_OPENAI:
agent_cls = GenericOpenAIAPIAgent
elif backend_type == BACKEND_TYPE_TEXT_GEN_WEBUI:
@@ -78,7 +78,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Unload Local LLaMA."""
"""Unload Local LLM."""
hass.data[DOMAIN].pop(entry.entry_id)
ha_conversation.async_unset_agent(hass, entry)
return True

View File

@@ -18,10 +18,10 @@ from homeassistant.components.conversation import ConversationInput, ConversationResult
from homeassistant.components.conversation.const import DOMAIN as CONVERSATION_DOMAIN
from homeassistant.components.homeassistant.exposed_entities import async_should_expose
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import ATTR_ENTITY_ID, CONF_HOST, CONF_PORT, CONF_SSL, MATCH_ALL
from homeassistant.const import ATTR_ENTITY_ID, CONF_HOST, CONF_PORT, CONF_SSL, MATCH_ALL, CONF_LLM_HASS_API
from homeassistant.core import HomeAssistant, callback
from homeassistant.exceptions import ConfigEntryNotReady, ConfigEntryError, TemplateError
from homeassistant.helpers import config_validation as cv, intent, template, entity_registry as er
from homeassistant.exceptions import ConfigEntryNotReady, ConfigEntryError, TemplateError, HomeAssistantError
from homeassistant.helpers import config_validation as cv, intent, template, entity_registry as er, llm
from homeassistant.helpers.event import async_track_state_change, async_call_later
from homeassistant.util import ulid
@@ -114,8 +114,8 @@ _LOGGER = logging.getLogger(__name__)
CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN)
class LLaMAAgent(AbstractConversationAgent):
"""Base LLaMA conversation agent."""
class LocalLLMAgent(AbstractConversationAgent):
"""Base Local LLM conversation agent."""
hass: HomeAssistant
entry_id: str
@@ -225,6 +225,22 @@ class LLaMAAgent(AbstractConversationAgent):
return ConversationResult(
response=intent_response, conversation_id=conversation_id
)
llm_api: llm.API | None = None
if self.entry.options.get(CONF_LLM_HASS_API):
try:
llm_api = llm.async_get_api(
self.hass, self.entry.options[CONF_LLM_HASS_API]
)
except HomeAssistantError as err:
_LOGGER.error("Error getting LLM API: %s", err)
intent_response = intent.IntentResponse(language=user_input.language)
intent_response.async_set_error(
intent.IntentResponseErrorCode.UNKNOWN,
f"Error preparing LLM API: {err}",
)
return ConversationResult(
response=intent_response, conversation_id=user_input.conversation_id
)
if user_input.conversation_id in self.history:
conversation_id = user_input.conversation_id
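The block added above follows the pattern Home Assistant's own conversation agents use: resolve the API the user selected in the integration options, and turn a resolution failure into an error response instead of crashing the turn. A condensed sketch of that pattern, using the same `llm.async_get_api` call the diff does (the option key literal stands in for `CONF_LLM_HASS_API`):

```python
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import llm


def resolve_llm_api(hass, options):
    """Return the configured llm.API, or None when no API is selected."""
    api_id = options.get("llm_hass_api")  # value behind CONF_LLM_HASS_API
    if not api_id:
        return None
    try:
        # Raises HomeAssistantError for unknown API ids; the diff above
        # surfaces that to the user as an error IntentResponse.
        return llm.async_get_api(hass, api_id)
    except HomeAssistantError:
        return None
```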
@@ -235,7 +251,7 @@ class LLaMAAgent(AbstractConversationAgent):
if len(conversation) == 0 or refresh_system_prompt:
try:
message = self._generate_system_prompt(raw_prompt)
message = self._generate_system_prompt(raw_prompt, llm_api)
except TemplateError as err:
_LOGGER.error("Error rendering prompt: %s", err)
intent_response = intent.IntentResponse(language=user_input.language)
@@ -407,7 +423,7 @@ class LLaMAAgent(AbstractConversationAgent):
_LOGGER.debug(formatted_prompt)
return formatted_prompt
def _generate_system_prompt(self, prompt_template: str) -> str:
def _generate_system_prompt(self, prompt_template: str, llm_api: llm.API | None) -> str:
"""Generate the system prompt with current entity states"""
entities_to_expose, domains = self._async_get_exposed_entities()
@@ -487,21 +503,14 @@ class LLaMAAgent(AbstractConversationAgent):
formatted_states = "\n".join(device_states) + "\n"
service_dict = self.hass.services.async_services()
all_services = []
all_service_names = []
for domain in domains:
# scripts show up as individual services
if domain == "script":
all_services.extend(["script.reload()", "script.turn_on()", "script.turn_off()", "script.toggle()"])
continue
for name, service in service_dict.get(domain, {}).items():
args = flatten_vol_schema(service.schema)
args_to_expose = set(args).intersection(allowed_service_call_arguments)
all_services.append(f"{domain}.{name}({','.join(args_to_expose)})")
all_service_names.append(f"{domain}.{name}")
formatted_services = ", ".join(all_services)
if llm_api:
tools = [
f"{tool.name}({flatten_vol_schema(tool.parameters)}) - {tool.description}"
for tool in llm_api.async_get_tools()
]
formatted_services = llm_api.prompt_template + "\n" + "\n".join(tools)
else:
formatted_services = "No tools exposed."
render_variables = {
"devices": formatted_states,
@@ -509,15 +518,16 @@ class LLaMAAgent(AbstractConversationAgent):
}
if self.in_context_examples:
num_examples = int(self.entry.options.get(CONF_NUM_IN_CONTEXT_EXAMPLES, DEFAULT_NUM_IN_CONTEXT_EXAMPLES))
render_variables["response_examples"] = "\n".join(icl_example_generator(num_examples, list(entities_to_expose.keys()), all_service_names))
# num_examples = int(self.entry.options.get(CONF_NUM_IN_CONTEXT_EXAMPLES, DEFAULT_NUM_IN_CONTEXT_EXAMPLES))
# render_variables["response_examples"] = "\n".join(icl_example_generator(num_examples, list(entities_to_expose.keys()), all_service_names))
render_variables["response_examples"] = ""
return template.Template(prompt_template, self.hass).async_render(
render_variables,
parse_result=False,
)
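The net effect of this hunk is that the hand-rolled `domain.service(args)` listing is replaced by whatever tools the selected LLM API exposes, prefixed with that API's own prompt template. A rough, self-contained sketch of the string being built; the tool names are illustrative, loosely modeled on Assist-style intents:

```python
# Illustrative only: approximates the prompt fragment the diff builds.
prompt_template = "Call tools to control devices in the smart home."
tools = [
    ("HassTurnOn", ["name", "area", "domain"], "Turns on a device"),
    ("HassTurnOff", ["name", "area", "domain"], "Turns off a device"),
]
formatted_services = prompt_template + "\n" + "\n".join(
    f"{name}({','.join(params)}) - {description}"
    for name, params, description in tools
)
print(formatted_services)
# Call tools to control devices in the smart home.
# HassTurnOn(name,area,domain) - Turns on a device
# HassTurnOff(name,area,domain) - Turns off a device
```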
class LocalLLaMAAgent(LLaMAAgent):
class LlamaCppAgent(LocalLLMAgent):
model_path: str
llm: LlamaType
grammar: Any
@@ -612,7 +622,7 @@ class LocalLLaMAAgent(LLaMAAgent):
self.grammar = None
def _update_options(self):
LLaMAAgent._update_options(self)
LocalLLMAgent._update_options(self)
model_reloaded = False
if self.loaded_model_settings[CONF_CONTEXT_LENGTH] != self.entry.options.get(CONF_CONTEXT_LENGTH, DEFAULT_CONTEXT_LENGTH) or \
@@ -662,7 +672,7 @@ class LocalLLaMAAgent(LLaMAAgent):
def _async_get_exposed_entities(self) -> tuple[dict[str, str], list[str]]:
"""Takes the super class function results and sorts the entities with the recently updated at the end"""
entities, domains = LLaMAAgent._async_get_exposed_entities(self)
entities, domains = LocalLLMAgent._async_get_exposed_entities(self)
# ignore sorting if prompt caching is disabled
if not self.entry.options.get(CONF_PROMPT_CACHING_ENABLED, DEFAULT_PROMPT_CACHING_ENABLED):
@@ -739,10 +749,20 @@ class LocalLLaMAAgent(LLaMAAgent):
self.cache_refresh_after_cooldown = True
return
llm_api: llm.API | None = None
if self.entry.options.get(CONF_LLM_HASS_API):
try:
llm_api = llm.async_get_api(
self.hass, self.entry.options[CONF_LLM_HASS_API]
)
except HomeAssistantError:
_LOGGER.exception("Failed to get LLM API when caching prompt!")
return
try:
raw_prompt = self.entry.options.get(CONF_PROMPT, DEFAULT_PROMPT)
prompt = self._format_prompt([
{ "role": "system", "message": self._generate_system_prompt(raw_prompt)},
{ "role": "system", "message": self._generate_system_prompt(raw_prompt, llm_api)},
{ "role": "user", "message": "" }
], include_generation_prompt=False)
@@ -839,7 +859,7 @@ class LocalLLaMAAgent(LLaMAAgent):
return result
class GenericOpenAIAPIAgent(LLaMAAgent):
class GenericOpenAIAPIAgent(LocalLLMAgent):
api_host: str
api_key: str
model_name: str
@@ -1046,7 +1066,7 @@ class LlamaCppPythonAPIAgent(GenericOpenAIAPIAgent):
return endpoint, request_params
class OllamaAPIAgent(LLaMAAgent):
class OllamaAPIAgent(LocalLLMAgent):
api_host: str
api_key: str
model_name: str

View File

@@ -1,4 +1,4 @@
"""Config flow for Local LLaMA Conversation integration."""
"""Config flow for Local LLM Conversation integration."""
from __future__ import annotations
import os
@@ -13,18 +13,20 @@ import voluptuous as vol
from homeassistant import config_entries
from homeassistant.core import HomeAssistant
from homeassistant.const import CONF_HOST, CONF_PORT, CONF_SSL, UnitOfTime
from homeassistant.const import CONF_HOST, CONF_PORT, CONF_SSL, CONF_LLM_HASS_API, UnitOfTime
from homeassistant.data_entry_flow import (
AbortFlow,
FlowHandler,
FlowManager,
FlowResult,
)
from homeassistant.helpers import llm
from homeassistant.helpers.selector import (
NumberSelector,
NumberSelectorConfig,
NumberSelectorMode,
TemplateSelector,
SelectOptionDict,
SelectSelector,
SelectSelectorConfig,
SelectSelectorMode,
@@ -279,7 +281,7 @@ class BaseLlamaConversationConfigFlow(FlowHandler, ABC):
""" Finish configuration """
class ConfigFlow(BaseLlamaConversationConfigFlow, config_entries.ConfigFlow, domain=DOMAIN):
"""Handle a config flow for Local LLaMA Conversation."""
"""Handle a config flow for Local LLM Conversation."""
VERSION = 1
install_wheel_task = None
@@ -584,7 +586,7 @@ class ConfigFlow(BaseLlamaConversationConfigFlow, config_entries.ConfigFlow, domain=DOMAIN):
persona = PERSONA_PROMPTS.get(self.selected_language, PERSONA_PROMPTS.get("en"))
selected_default_options[CONF_PROMPT] = selected_default_options[CONF_PROMPT].replace("<persona>", persona)
schema = vol.Schema(local_llama_config_option_schema(selected_default_options, backend_type))
schema = vol.Schema(local_llama_config_option_schema(self.hass, selected_default_options, backend_type))
if user_input:
self.options = user_input
@@ -626,7 +628,7 @@ class ConfigFlow(BaseLlamaConversationConfigFlow, config_entries.ConfigFlow, domain=DOMAIN):
class OptionsFlow(config_entries.OptionsFlow):
"""Local LLaMA config flow options handler."""
"""Local LLM config flow options handler."""
def __init__(self, config_entry: config_entries.ConfigEntry) -> None:
"""Initialize options flow."""
@@ -656,9 +658,10 @@ class OptionsFlow(config_entries.OptionsFlow):
description_placeholders["filename"] = filename
if len(errors) == 0:
return self.async_create_entry(title="LLaMA Conversation", data=user_input)
return self.async_create_entry(title="Local LLM Conversation", data=user_input)
schema = local_llama_config_option_schema(
self.hass,
self.config_entry.options,
self.config_entry.data[CONF_BACKEND_TYPE],
)
@@ -682,12 +685,31 @@ def insert_after_key(input_dict: dict, key_name: str, other_dict: dict):
return result
def local_llama_config_option_schema(options: MappingProxyType[str, Any], backend_type: str) -> dict:
"""Return a schema for Local LLaMA completion options."""
def local_llama_config_option_schema(hass: HomeAssistant, options: MappingProxyType[str, Any], backend_type: str) -> dict:
"""Return a schema for Local LLM completion options."""
if not options:
options = DEFAULT_OPTIONS
apis: list[SelectOptionDict] = [
SelectOptionDict(
label="No control",
value="none",
)
]
apis.extend(
SelectOptionDict(
label=api.name,
value=api.id,
)
for api in llm.async_get_apis(hass)
)
result = {
vol.Optional(
CONF_LLM_HASS_API,
description={"suggested_value": options.get(CONF_LLM_HASS_API)},
default="none",
): SelectSelector(SelectSelectorConfig(options=apis)),
vol.Required(
CONF_PROMPT,
description={"suggested_value": options.get(CONF_PROMPT)},

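One subtlety in this schema: when the user picks "No control", the selector stores the literal string "none", but the agent code added earlier gates the API lookup on the option's truthiness, and "none" is a truthy string. This is presumably what the TODO's "handle no API selected" item covers; a sketch of the guard it implies, assuming the sentinel value shown above:

```python
NO_API_SENTINEL = "none"  # the "No control" value stored by the selector above


def selected_api_id(options):
    """Map the stored option to a usable API id, treating "none" as unset."""
    api_id = options.get("llm_hass_api")  # value behind CONF_LLM_HASS_API
    if not api_id or api_id == NO_API_SENTINEL:
        return None
    return api_id
```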
View File

@@ -1,6 +1,6 @@
{
"domain": "llama_conversation",
"name": "LLaMA Conversation",
"name": "Local LLM Conversation",
"version": "0.2.17",
"codeowners": ["@acon96"],
"config_flow": true,

View File

@@ -35,7 +35,7 @@ The following link will open your Home Assistant installation and download the i
[![Open your Home Assistant instance and open a repository inside the Home Assistant Community Store.](https://my.home-assistant.io/badges/hacs_repository.svg)](https://my.home-assistant.io/redirect/hacs_repository/?category=Integration&repository=home-llm&owner=acon96)
After installation, A "LLaMA Conversation" device should show up in the `Settings > Devices and Services > [Devices]` tab now.
After installation, a "Local LLM Conversation" device should now show up in the `Settings > Devices and Services > [Devices]` tab.
## Path 1: Using the Home Model with the Llama.cpp Backend
### Overview
@@ -44,7 +44,7 @@ This setup path involves downloading a fine-tuned model from HuggingFace and int
### Step 1: Wheel Installation for llama-cpp-python
1. In Home Assistant: navigate to `Settings > Devices and Services`
2. Select the `+ Add Integration` button in the bottom right corner
3. Search for, and select `LLaMA Conversation`
3. Search for, and select `Local LLM Conversation`
4. With the `Llama.cpp (HuggingFace)` backend selected, click `Submit`
This should download and install `llama-cpp-python` from GitHub. If the installation fails for any reason, follow the manual installation instructions [here](./Backend%20Configuration.md#wheels).
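To confirm the wheel actually landed in Home Assistant's Python environment, a quick import check from that same interpreter is enough. A sketch (assuming `llama_cpp.__version__` is available, as it is in current llama-cpp-python releases):

```python
# Run inside Home Assistant's Python environment to verify the install.
import llama_cpp

print("llama-cpp-python version:", llama_cpp.__version__)
```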
@@ -82,7 +82,7 @@ In order to access the model from another machine, we need to run the Ollama API
1. In Home Assistant: navigate to `Settings > Devices and Services`
2. Select the `+ Add Integration` button in the bottom right corner
3. Search for, and select `LLaMA Conversation`
3. Search for, and select `Local LLM Conversation`
4. Select `Ollama API` from the dropdown and click `Submit`
5. Set up the connection to the API:
- **IP Address**: Fill out IP Address for the machine hosting Ollama
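Before entering the connection details, it can save a debugging round to confirm the Ollama server is reachable from the Home Assistant host. A sketch using Ollama's model-listing endpoint (`GET /api/tags` is part of Ollama's REST API; the host address is a placeholder):

```python
import requests

# Placeholder address for the machine hosting Ollama; 11434 is its default port.
OLLAMA_URL = "http://192.168.1.50:11434"

resp = requests.get(f"{OLLAMA_URL}/api/tags", timeout=5)
resp.raise_for_status()
models = [m["name"] for m in resp.json().get("models", [])]
print("Models available:", models)
```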

View File

@@ -1,6 +1,6 @@
{
"name": "LLaMA Conversation",
"homeassistant": "2023.10.0",
"name": "Local LLM Conversation",
"homeassistant": "2024.5.5",
"content_in_root": false,
"render_readme": true
}

View File

@@ -4,7 +4,7 @@ import pytest
import jinja2
from unittest.mock import patch, MagicMock, PropertyMock, AsyncMock, ANY
from custom_components.llama_conversation.agent import LocalLLaMAAgent, OllamaAPIAgent, TextGenerationWebuiAgent, GenericOpenAIAPIAgent
from custom_components.llama_conversation.agent import LlamaCppAgent, OllamaAPIAgent, TextGenerationWebuiAgent, GenericOpenAIAPIAgent
from custom_components.llama_conversation.const import (
CONF_CHAT_MODEL,
CONF_MAX_TOKENS,
@@ -140,10 +140,10 @@ def home_assistant_mock():
@pytest.fixture
def local_llama_agent_fixture(config_entry, home_assistant_mock):
with patch.object(LocalLLaMAAgent, '_load_icl_examples') as load_icl_examples_mock, \
patch.object(LocalLLaMAAgent, '_load_grammar') as load_grammar_mock, \
patch.object(LocalLLaMAAgent, 'entry', new_callable=PropertyMock) as entry_mock, \
patch.object(LocalLLaMAAgent, '_async_get_exposed_entities') as get_exposed_entities_mock, \
with patch.object(LlamaCppAgent, '_load_icl_examples') as load_icl_examples_mock, \
patch.object(LlamaCppAgent, '_load_grammar') as load_grammar_mock, \
patch.object(LlamaCppAgent, 'entry', new_callable=PropertyMock) as entry_mock, \
patch.object(LlamaCppAgent, '_async_get_exposed_entities') as get_exposed_entities_mock, \
patch('homeassistant.helpers.template.Template') as template_mock, \
patch('custom_components.llama_conversation.agent.importlib.import_module') as import_module_mock, \
patch('custom_components.llama_conversation.agent.install_llama_cpp_python') as install_llama_cpp_python_mock:
@@ -174,7 +174,7 @@ def local_llama_agent_fixture(config_entry, home_assistant_mock):
"target_device": "light.kitchen_light",
}).encode()
agent_obj = LocalLLaMAAgent(
agent_obj = LlamaCppAgent(
home_assistant_mock,
config_entry
)
@@ -191,7 +191,7 @@ def local_llama_agent_fixture(config_entry, home_assistant_mock):
async def test_local_llama_agent(local_llama_agent_fixture):
local_llama_agent: LocalLLaMAAgent
local_llama_agent: LlamaCppAgent
all_mocks: dict[str, MagicMock]
local_llama_agent, all_mocks = local_llama_agent_fixture
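These fixtures only pick up the class rename; nothing here exercises the new `llm_api` code path yet. A sketch of the extra patch a future test could apply, stubbing the same `llm.async_get_api` call the agent makes (the attribute names on the fake API mirror the ones used in the diff):

```python
from unittest.mock import MagicMock, patch


def test_llm_api_prompt_generation_sketch():
    """Illustrative only: stub llm.async_get_api for the new code path."""
    fake_api = MagicMock()
    fake_api.prompt_template = "Call tools to control the smart home."
    fake_api.async_get_tools.return_value = []  # no tools exposed

    with patch(
        "custom_components.llama_conversation.agent.llm.async_get_api",
        return_value=fake_api,
    ):
        # ...construct the agent fixture with CONF_LLM_HASS_API set, call
        # _generate_system_prompt, and assert the tool block is rendered...
        pass
```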