diff --git a/TODO.md b/TODO.md
index 9ce830b..1d6c9fd 100644
--- a/TODO.md
+++ b/TODO.md
@@ -1,6 +1,13 @@
 # TODO
-- [x] detection/mitigation of too many entities being exposed & blowing out the context length
+- [ ] support new LLM APIs
+  - rewrite how services are called
+  - handle no API selected
+  - rewrite prompts + service block formats
+  - update dataset so new models will work with the API
+- [ ] make ICL examples into conversation turns
+- [ ] translate ICL examples + make better ones
 - [ ] areas/room support
+- [x] detection/mitigation of too many entities being exposed & blowing out the context length
 - [ ] figure out DPO to improve response quality
 - [ ] train the model to respond to house events
   - present the model with an event + a "prompt" from the user of what you want it to do (i.e. turn on the lights when I get home = the model turns on lights when your entity presence triggers as being home)
diff --git a/custom_components/llama_conversation/__init__.py b/custom_components/llama_conversation/__init__.py
index d00c9b1..0d35ff1 100644
--- a/custom_components/llama_conversation/__init__.py
+++ b/custom_components/llama_conversation/__init__.py
@@ -1,4 +1,4 @@
-"""The Local LLaMA Conversation integration."""
+"""The Local LLM Conversation integration."""
 from __future__ import annotations
 
 import logging
@@ -9,8 +9,8 @@ from homeassistant.core import HomeAssistant
 from homeassistant.helpers import config_validation as cv
 
 from .agent import (
-    LLaMAAgent,
-    LocalLLaMAAgent,
+    LocalLLMAgent,
+    LlamaCppAgent,
     GenericOpenAIAPIAgent,
     TextGenerationWebuiAgent,
     LlamaCppPythonAPIAgent,
@@ -38,19 +38,19 @@ async def update_listener(hass: HomeAssistant, entry: ConfigEntry):
     hass.data[DOMAIN][entry.entry_id] = entry
 
     # call update handler
-    agent: LLaMAAgent = await ha_conversation._get_agent_manager(hass).async_get_agent(entry.entry_id)
+    agent: LocalLLMAgent = await ha_conversation.get_agent_manager(hass).async_get_agent(entry.entry_id)
     agent._update_options()
 
     return True
 
 
 async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
-    """Set up Local LLaMA Conversation from a config entry."""
+    """Set up Local LLM Conversation from a config entry."""
 
     def create_agent(backend_type):
         agent_cls = None
         if backend_type in [ BACKEND_TYPE_LLAMA_HF, BACKEND_TYPE_LLAMA_EXISTING ]:
-            agent_cls = LocalLLaMAAgent
+            agent_cls = LlamaCppAgent
         elif backend_type == BACKEND_TYPE_GENERIC_OPENAI:
             agent_cls = GenericOpenAIAPIAgent
         elif backend_type == BACKEND_TYPE_TEXT_GEN_WEBUI:
@@ -78,7 +78,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
 
 
 async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
-    """Unload Local LLaMA."""
+    """Unload Local LLM."""
     hass.data[DOMAIN].pop(entry.entry_id)
     ha_conversation.async_unset_agent(hass, entry)
     return True
diff --git a/custom_components/llama_conversation/agent.py b/custom_components/llama_conversation/agent.py
index 7df88e8..6d9d156 100644
--- a/custom_components/llama_conversation/agent.py
+++ b/custom_components/llama_conversation/agent.py
@@ -18,10 +18,10 @@ from homeassistant.components.conversation import ConversationInput, Conversatio
 from homeassistant.components.conversation.const import DOMAIN as CONVERSATION_DOMAIN
 from homeassistant.components.homeassistant.exposed_entities import async_should_expose
 from homeassistant.config_entries import ConfigEntry
-from homeassistant.const import ATTR_ENTITY_ID, CONF_HOST, CONF_PORT, CONF_SSL, MATCH_ALL
+from homeassistant.const import ATTR_ENTITY_ID, CONF_HOST, CONF_PORT, CONF_SSL, MATCH_ALL, CONF_LLM_HASS_API
 from homeassistant.core import HomeAssistant, callback
-from homeassistant.exceptions import ConfigEntryNotReady, ConfigEntryError, TemplateError
-from homeassistant.helpers import config_validation as cv, intent, template, entity_registry as er
+from homeassistant.exceptions import ConfigEntryNotReady, ConfigEntryError, TemplateError, HomeAssistantError
+from homeassistant.helpers import config_validation as cv, intent, template, entity_registry as er, llm
 from homeassistant.helpers.event import async_track_state_change, async_call_later
 from homeassistant.util import ulid
 
@@ -114,8 +114,8 @@ _LOGGER = logging.getLogger(__name__)
 
 CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN)
 
-class LLaMAAgent(AbstractConversationAgent):
-    """Base LLaMA conversation agent."""
+class LocalLLMAgent(AbstractConversationAgent):
+    """Base Local LLM conversation agent."""
 
     hass: HomeAssistant
     entry_id: str
@@ -225,6 +225,22 @@ class LLaMAAgent(AbstractConversationAgent):
             return ConversationResult(
                 response=intent_response, conversation_id=conversation_id
             )
+
+        llm_api: llm.API | None = None
+        if self.entry.options.get(CONF_LLM_HASS_API):
+            try:
+                llm_api = llm.async_get_api(
+                    self.hass, self.entry.options[CONF_LLM_HASS_API]
+                )
+            except HomeAssistantError as err:
+                _LOGGER.error("Error getting LLM API: %s", err)
+                intent_response.async_set_error(
+                    intent.IntentResponseErrorCode.UNKNOWN,
+                    f"Error preparing LLM API: {err}",
+                )
+                return conversation.ConversationResult(
+                    response=intent_response, conversation_id=user_input.conversation_id
+                )
 
         if user_input.conversation_id in self.history:
             conversation_id = user_input.conversation_id
@@ -235,7 +251,7 @@ class LLaMAAgent(AbstractConversationAgent):
 
         if len(conversation) == 0 or refresh_system_prompt:
             try:
-                message = self._generate_system_prompt(raw_prompt)
+                message = self._generate_system_prompt(raw_prompt, llm_api)
             except TemplateError as err:
                 _LOGGER.error("Error rendering prompt: %s", err)
                 intent_response = intent.IntentResponse(language=user_input.language)
@@ -407,7 +423,7 @@ class LLaMAAgent(AbstractConversationAgent):
 
         _LOGGER.debug(formatted_prompt)
         return formatted_prompt
 
-    def _generate_system_prompt(self, prompt_template: str) -> str:
+    def _generate_system_prompt(self, prompt_template: str, llm_api: llm.API) -> str:
        """Generate the system prompt with current entity states"""
        entities_to_expose, domains = self._async_get_exposed_entities()
@@ -487,21 +503,14 @@ class LLaMAAgent(AbstractConversationAgent):
 
         formatted_states = "\n".join(device_states) + "\n"
 
-        service_dict = self.hass.services.async_services()
-        all_services = []
-        all_service_names = []
-        for domain in domains:
-            # scripts show up as individual services
-            if domain == "script":
-                all_services.extend(["script.reload()", "script.turn_on()", "script.turn_off()", "script.toggle()"])
-                continue
-
-            for name, service in service_dict.get(domain, {}).items():
-                args = flatten_vol_schema(service.schema)
-                args_to_expose = set(args).intersection(allowed_service_call_arguments)
-                all_services.append(f"{domain}.{name}({','.join(args_to_expose)})")
-                all_service_names.append(f"{domain}.{name}")
-        formatted_services = ", ".join(all_services)
+        if llm_api:
+            tools = [
+                f"{tool.name}({flatten_vol_schema(tool.parameters)}) - {tool.description}"
+                for tool in llm_api.async_get_tools()
+            ]
+            formatted_services = llm_api.prompt_template + "\n" + "\n".join(tools)
+        else:
+            formatted_services = "No tools exposed."
 
         render_variables = {
             "devices": formatted_states,
@@ -509,15 +518,16 @@
         }
 
         if self.in_context_examples:
-            num_examples = int(self.entry.options.get(CONF_NUM_IN_CONTEXT_EXAMPLES, DEFAULT_NUM_IN_CONTEXT_EXAMPLES))
-            render_variables["response_examples"] = "\n".join(icl_example_generator(num_examples, list(entities_to_expose.keys()), all_service_names))
+            # num_examples = int(self.entry.options.get(CONF_NUM_IN_CONTEXT_EXAMPLES, DEFAULT_NUM_IN_CONTEXT_EXAMPLES))
+            # render_variables["response_examples"] = "\n".join(icl_example_generator(num_examples, list(entities_to_expose.keys()), all_service_names))
+            render_variables["response_examples"] = ""
 
         return template.Template(prompt_template, self.hass).async_render(
             render_variables,
             parse_result=False,
         )
 
-class LocalLLaMAAgent(LLaMAAgent):
+class LlamaCppAgent(LocalLLMAgent):
     model_path: str
     llm: LlamaType
     grammar: Any
@@ -612,7 +622,7 @@ class LocalLLaMAAgent(LLaMAAgent):
             self.grammar = None
 
     def _update_options(self):
-        LLaMAAgent._update_options(self)
+        LocalLLMAgent._update_options(self)
 
         model_reloaded = False
         if self.loaded_model_settings[CONF_CONTEXT_LENGTH] != self.entry.options.get(CONF_CONTEXT_LENGTH, DEFAULT_CONTEXT_LENGTH) or \
@@ -662,7 +672,7 @@ class LocalLLaMAAgent(LLaMAAgent):
 
     def _async_get_exposed_entities(self) -> tuple[dict[str, str], list[str]]:
         """Takes the super class function results and sorts the entities with the recently updated at the end"""
-        entities, domains = LLaMAAgent._async_get_exposed_entities(self)
+        entities, domains = LocalLLMAgent._async_get_exposed_entities(self)
 
         # ignore sorting if prompt caching is disabled
         if not self.entry.options.get(CONF_PROMPT_CACHING_ENABLED, DEFAULT_PROMPT_CACHING_ENABLED):
@@ -739,10 +749,20 @@ class LocalLLaMAAgent(LLaMAAgent):
 
                 self.cache_refresh_after_cooldown = True
                 return
 
+            llm_api: llm.API | None = None
+            if self.entry.options.get(CONF_LLM_HASS_API):
+                try:
+                    llm_api = llm.async_get_api(
+                        self.hass, self.entry.options[CONF_LLM_HASS_API]
+                    )
+                except HomeAssistantError:
+                    _LOGGER.exception("Failed to get LLM API when caching prompt!")
+                    return
+
             try:
                 raw_prompt = self.entry.options.get(CONF_PROMPT, DEFAULT_PROMPT)
                 prompt = self._format_prompt([
-                    { "role": "system", "message": self._generate_system_prompt(raw_prompt)},
+                    { "role": "system", "message": self._generate_system_prompt(raw_prompt, llm_api)},
                     { "role": "user", "message": "" }
                 ], include_generation_prompt=False)
@@ -839,7 +859,7 @@ class LocalLLaMAAgent(LLaMAAgent):
 
         return result
 
-class GenericOpenAIAPIAgent(LLaMAAgent):
+class GenericOpenAIAPIAgent(LocalLLMAgent):
     api_host: str
     api_key: str
     model_name: str
@@ -1046,7 +1066,7 @@ class LlamaCppPythonAPIAgent(GenericOpenAIAPIAgent):
 
         return endpoint, request_params
 
-class OllamaAPIAgent(LLaMAAgent):
+class OllamaAPIAgent(LocalLLMAgent):
     api_host: str
     api_key: str
     model_name: str
diff --git a/custom_components/llama_conversation/config_flow.py b/custom_components/llama_conversation/config_flow.py
index ed7e17f..5cca95c 100644
--- a/custom_components/llama_conversation/config_flow.py
+++ b/custom_components/llama_conversation/config_flow.py
@@ -1,4 +1,4 @@
-"""Config flow for Local LLaMA Conversation integration."""
+"""Config flow for Local LLM Conversation integration."""
 from __future__ import annotations
 
 import os
@@ -13,18 +13,20 @@ import voluptuous as vol
 
 from homeassistant import config_entries
 from homeassistant.core import HomeAssistant
-from homeassistant.const import CONF_HOST, CONF_PORT, CONF_SSL, UnitOfTime
+from homeassistant.const import CONF_HOST, CONF_PORT, CONF_SSL, CONF_LLM_HASS_API, UnitOfTime
 from homeassistant.data_entry_flow import (
     AbortFlow,
     FlowHandler,
     FlowManager,
     FlowResult,
 )
+from homeassistant.helpers import llm
 from homeassistant.helpers.selector import (
     NumberSelector,
     NumberSelectorConfig,
     NumberSelectorMode,
     TemplateSelector,
+    SelectOptionDict,
     SelectSelector,
     SelectSelectorConfig,
     SelectSelectorMode,
@@ -279,7 +281,7 @@ class BaseLlamaConversationConfigFlow(FlowHandler, ABC):
         """ Finish configuration """
 
 class ConfigFlow(BaseLlamaConversationConfigFlow, config_entries.ConfigFlow, domain=DOMAIN):
-    """Handle a config flow for Local LLaMA Conversation."""
+    """Handle a config flow for Local LLM Conversation."""
 
     VERSION = 1
     install_wheel_task = None
@@ -584,7 +586,7 @@ class ConfigFlow(BaseLlamaConversationConfigFlow, config_entries.ConfigFlow, dom
             persona = PERSONA_PROMPTS.get(self.selected_language, PERSONA_PROMPTS.get("en"))
             selected_default_options[CONF_PROMPT] = selected_default_options[CONF_PROMPT].replace("", persona)
 
-        schema = vol.Schema(local_llama_config_option_schema(selected_default_options, backend_type))
+        schema = vol.Schema(local_llama_config_option_schema(self.hass, selected_default_options, backend_type))
 
         if user_input:
             self.options = user_input
@@ -626,7 +628,7 @@ class ConfigFlow(BaseLlamaConversationConfigFlow, config_entries.ConfigFlow, dom
 
 
 class OptionsFlow(config_entries.OptionsFlow):
-    """Local LLaMA config flow options handler."""
+    """Local LLM config flow options handler."""
 
     def __init__(self, config_entry: config_entries.ConfigEntry) -> None:
        """Initialize options flow."""
@@ -656,9 +658,10 @@ class OptionsFlow(config_entries.OptionsFlow):
                 description_placeholders["filename"] = filename
 
             if len(errors) == 0:
-                return self.async_create_entry(title="LLaMA Conversation", data=user_input)
+                return self.async_create_entry(title="Local LLM Conversation", data=user_input)
 
         schema = local_llama_config_option_schema(
+            self.hass,
             self.config_entry.options,
             self.config_entry.data[CONF_BACKEND_TYPE],
         )
@@ -682,12 +685,31 @@ def insert_after_key(input_dict: dict, key_name: str, other_dict: dict):
     return result
 
 
-def local_llama_config_option_schema(options: MappingProxyType[str, Any], backend_type: str) -> dict:
-    """Return a schema for Local LLaMA completion options."""
+def local_llama_config_option_schema(hass: HomeAssistant, options: MappingProxyType[str, Any], backend_type: str) -> dict:
+    """Return a schema for Local LLM completion options."""
     if not options:
         options = DEFAULT_OPTIONS
 
+    apis: list[SelectOptionDict] = [
+        SelectOptionDict(
+            label="No control",
+            value="none",
+        )
+    ]
+    apis.extend(
+        SelectOptionDict(
+            label=api.name,
+            value=api.id,
+        )
+        for api in llm.async_get_apis(hass)
+    )
+
     result = {
+        vol.Optional(
+            CONF_LLM_HASS_API,
+            description={"suggested_value": options.get(CONF_LLM_HASS_API)},
+            default="none",
+        ): SelectSelector(SelectSelectorConfig(options=apis)),
         vol.Required(
             CONF_PROMPT,
             description={"suggested_value": options.get(CONF_PROMPT)},
diff --git a/custom_components/llama_conversation/manifest.json b/custom_components/llama_conversation/manifest.json
index 9f09f63..e90a943 100644
--- a/custom_components/llama_conversation/manifest.json
+++ b/custom_components/llama_conversation/manifest.json
@@ -1,6 +1,6 @@
 {
   "domain": "llama_conversation",
-  "name": "LLaMA Conversation",
+  "name": "Local LLM Conversation",
   "version": "0.2.17",
   "codeowners": ["@acon96"],
   "config_flow": true,
diff --git a/docs/Setup.md b/docs/Setup.md
index 95b449c..f888212 100644
--- a/docs/Setup.md
+++ b/docs/Setup.md
@@ -35,7 +35,7 @@ The following link will open your Home Assistant installation and download the i
 
 [![Open your Home Assistant instance and open a repository inside the Home Assistant Community Store.](https://my.home-assistant.io/badges/hacs_repository.svg)](https://my.home-assistant.io/redirect/hacs_repository/?category=Integration&repository=home-llm&owner=acon96)
 
-After installation, A "LLaMA Conversation" device should show up in the `Settings > Devices and Services > [Devices]` tab now.
+After installation, A "Local LLM Conversation" device should show up in the `Settings > Devices and Services > [Devices]` tab now.
 
 ## Path 1: Using the Home Model with the Llama.cpp Backend
 ### Overview
@@ -44,7 +44,7 @@ This setup path involves downloading a fine-tuned model from HuggingFace and int
 ### Step 1: Wheel Installation for llama-cpp-python
 1. In Home Assistant: navigate to `Settings > Devices and Services`
 2. Select the `+ Add Integration` button in the bottom right corner
-3. Search for, and select `LLaMA Conversation`
+3. Search for, and select `Local LLM Conversation`
 4. With the `Llama.cpp (HuggingFace)` backend selected, click `Submit`
 
 This should download and install `llama-cpp-python` from GitHub. If the installation fails for any reason, follow the manual installation instructions [here](./Backend%20Configuration.md#wheels).
@@ -82,7 +82,7 @@ In order to access the model from another machine, we need to run the Ollama API
 
 1. In Home Assistant: navigate to `Settings > Devices and Services`
 2. Select the `+ Add Integration` button in the bottom right corner
-3. Search for, and select `LLaMA Conversation`
+3. Search for, and select `Local LLM Conversation`
 4. Select `Ollama API` from the dropdown and click `Submit`
 5. Set up the connection to the API:
     - **IP Address**: Fill out IP Address for the machine hosting Ollama
diff --git a/hacs.json b/hacs.json
index b4513aa..d0b9dbc 100644
--- a/hacs.json
+++ b/hacs.json
@@ -1,6 +1,6 @@
 {
-    "name": "LLaMA Conversation",
-    "homeassistant": "2023.10.0",
+    "name": "Local LLM Conversation",
+    "homeassistant": "2024.5.5",
     "content_in_root": false,
     "render_readme": true
 }
diff --git a/tests/llama_conversation/test_agent.py b/tests/llama_conversation/test_agent.py
index 49df0ce..05b6722 100644
--- a/tests/llama_conversation/test_agent.py
+++ b/tests/llama_conversation/test_agent.py
@@ -4,7 +4,7 @@ import pytest
 import jinja2
 from unittest.mock import patch, MagicMock, PropertyMock, AsyncMock, ANY
 
-from custom_components.llama_conversation.agent import LocalLLaMAAgent, OllamaAPIAgent, TextGenerationWebuiAgent, GenericOpenAIAPIAgent
+from custom_components.llama_conversation.agent import LlamaCppAgent, OllamaAPIAgent, TextGenerationWebuiAgent, GenericOpenAIAPIAgent
 from custom_components.llama_conversation.const import (
     CONF_CHAT_MODEL,
     CONF_MAX_TOKENS,
@@ -140,10 +140,10 @@ def home_assistant_mock():
 
 @pytest.fixture
 def local_llama_agent_fixture(config_entry, home_assistant_mock):
-    with patch.object(LocalLLaMAAgent, '_load_icl_examples') as load_icl_examples_mock, \
-        patch.object(LocalLLaMAAgent, '_load_grammar') as load_grammar_mock, \
-        patch.object(LocalLLaMAAgent, 'entry', new_callable=PropertyMock) as entry_mock, \
-        patch.object(LocalLLaMAAgent, '_async_get_exposed_entities') as get_exposed_entities_mock, \
+    with patch.object(LlamaCppAgent, '_load_icl_examples') as load_icl_examples_mock, \
+        patch.object(LlamaCppAgent, '_load_grammar') as load_grammar_mock, \
+        patch.object(LlamaCppAgent, 'entry', new_callable=PropertyMock) as entry_mock, \
+        patch.object(LlamaCppAgent, '_async_get_exposed_entities') as get_exposed_entities_mock, \
         patch('homeassistant.helpers.template.Template') as template_mock, \
         patch('custom_components.llama_conversation.agent.importlib.import_module') as import_module_mock, \
         patch('custom_components.llama_conversation.agent.install_llama_cpp_python') as install_llama_cpp_python_mock:
@@ -174,7 +174,7 @@ def local_llama_agent_fixture(config_entry, home_assistant_mock):
         "target_device": "light.kitchen_light",
     }).encode()
 
-    agent_obj = LocalLLaMAAgent(
+    agent_obj = LlamaCppAgent(
         home_assistant_mock,
         config_entry
     )
@@ -191,7 +191,7 @@ def local_llama_agent_fixture(config_entry, home_assistant_mock):
 
 
 async def test_local_llama_agent(local_llama_agent_fixture):
-    local_llama_agent: LocalLLaMAAgent
+    local_llama_agent: LlamaCppAgent
     all_mocks: dict[str, MagicMock]
 
     local_llama_agent, all_mocks = local_llama_agent_fixture
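Not part of the patch: a minimal, self-contained sketch of the new prompt-formatting behavior introduced in `_generate_system_prompt`, where the selected Home Assistant LLM API's tools are rendered into the prompt instead of enumerating `hass.services`. The `StubTool`/`StubAPI` classes are hypothetical stand-ins for the `homeassistant.helpers.llm` objects so the snippet runs outside Home Assistant, and the real code flattens `tool.parameters` with `flatten_vol_schema` rather than taking a plain list of argument names.

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class StubTool:
    """Hypothetical stand-in for a Home Assistant llm.Tool."""
    name: str
    description: str
    parameters: list = field(default_factory=list)  # simplified: plain list of argument names


@dataclass
class StubAPI:
    """Hypothetical stand-in for a Home Assistant llm.API."""
    prompt_template: str
    tools: list = field(default_factory=list)

    def async_get_tools(self):
        return self.tools


def format_tools_for_prompt(llm_api: Optional[StubAPI]) -> str:
    """Mirrors the llm_api branch added to _generate_system_prompt in agent.py."""
    if not llm_api:
        # same fallback string the patch uses when no API is selected
        return "No tools exposed."
    tools = [
        f"{tool.name}({','.join(tool.parameters)}) - {tool.description}"
        for tool in llm_api.async_get_tools()
    ]
    return llm_api.prompt_template + "\n" + "\n".join(tools)


if __name__ == "__main__":
    api = StubAPI(
        prompt_template="Call tools using the following format:",
        tools=[StubTool("HassTurnOn", "Turns on a device", ["name", "area"])],
    )
    print(format_tools_for_prompt(api))
```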