fix compatibility with the HA chat log API

Alex O'Connell
2025-04-13 18:00:26 -04:00
parent 811c1ea23b
commit 3b01f8ace2


@@ -16,7 +16,7 @@ import voluptuous as vol
from typing import Literal, Any, Callable
from homeassistant.components.conversation import ConversationInput, ConversationResult, AbstractConversationAgent, ConversationEntity
-from homeassistant.components import assist_pipeline, conversation as ha_conversation
+from homeassistant.components import assist_pipeline, conversation as conversation
from homeassistant.components.conversation.const import DOMAIN as CONVERSATION_DOMAIN
from homeassistant.components.homeassistant.exposed_entities import async_should_expose
from homeassistant.config_entries import ConfigEntry
@@ -24,7 +24,7 @@ from homeassistant.const import ATTR_ENTITY_ID, CONF_HOST, CONF_PORT, CONF_SSL,
from homeassistant.core import HomeAssistant, callback
from homeassistant.exceptions import ConfigEntryNotReady, ConfigEntryError, TemplateError, HomeAssistantError
from homeassistant.helpers import config_validation as cv, intent, template, entity_registry as er, llm, \
-    area_registry as ar, device_registry as dr
+    area_registry as ar, device_registry as dr, chat_session
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.aiohttp_client import async_get_clientsession
from homeassistant.helpers.event import async_track_state_change, async_call_later
@@ -159,12 +159,42 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry, async_add_e
return True
+def _convert_content(
+    chat_content: conversation.Content
+) -> dict[str, str]:
+    """Convert chat log content into a role/message history entry."""
+    role_name = None
+    if isinstance(chat_content, conversation.ToolResultContent):
+        role_name = "tool"
+    elif isinstance(chat_content, conversation.AssistantContent):
+        role_name = "assistant"
+    elif isinstance(chat_content, conversation.UserContent):
+        role_name = "user"
+    elif isinstance(chat_content, conversation.SystemContent):
+        role_name = "system"
+    else:
+        raise ValueError(f"Unexpected content type: {type(chat_content)}")
+    return { "role": role_name, "message": chat_content.content }
+def _convert_content_back(
+    agent_id: str,
+    message_history_entry: dict[str, str]
+) -> conversation.Content:
+    if message_history_entry["role"] == "tool":
+        return conversation.ToolResultContent(content=message_history_entry["message"])
+    if message_history_entry["role"] == "assistant":
+        return conversation.AssistantContent(agent_id=agent_id, content=message_history_entry["message"])
+    if message_history_entry["role"] == "user":
+        return conversation.UserContent(content=message_history_entry["message"])
+    if message_history_entry["role"] == "system":
+        return conversation.SystemContent(content=message_history_entry["message"])
class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
"""Base Local LLM conversation agent."""
hass: HomeAssistant
entry_id: str
history: dict[str, list[dict]]
in_context_examples: list[dict]
_attr_has_entity_name = True
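
For illustration, a round trip through the two helpers added above. This is an editorial sketch, not part of the commit; the agent id is a made-up example value, and it assumes the chat log dataclasses accept these keyword arguments, as in recent Home Assistant cores:

    from homeassistant.components import conversation

    entry = _convert_content(conversation.UserContent(content="turn on the kitchen light"))
    # entry == {"role": "user", "message": "turn on the kitchen light"}

    restored = _convert_content_back("conversation.llama_local", entry)
    # restored is a conversation.UserContent with the same text; the agent_id
    # argument is only used when rebuilding "assistant" entries.

Note that _convert_content_back falls through and returns None for any role it does not recognize.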
@@ -176,7 +206,6 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
self.hass = hass
self.entry_id = entry.entry_id
-self.history = {}
self.backend_type = entry.data.get(
CONF_BACKEND_TYPE, DEFAULT_BACKEND_TYPE
@@ -184,7 +213,7 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
if self.entry.options.get(CONF_LLM_HASS_API):
self._attr_supported_features = (
-ha_conversation.ConversationEntityFeature.CONTROL
+conversation.ConversationEntityFeature.CONTROL
)
self.in_context_examples = None
@@ -197,11 +226,11 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
assist_pipeline.async_migrate_engine(
self.hass, "conversation", self.entry.entry_id, self.entity_id
)
-ha_conversation.async_set_agent(self.hass, self.entry, self)
+conversation.async_set_agent(self.hass, self.entry, self)
async def async_will_remove_from_hass(self) -> None:
"""When entity will be removed from Home Assistant."""
-ha_conversation.async_unset_agent(self.hass, self.entry)
+conversation.async_unset_agent(self.hass, self.entry)
await super().async_will_remove_from_hass()
def _load_icl_examples(self, filename: str):
@@ -227,7 +256,7 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
def _update_options(self):
if self.entry.options.get(CONF_LLM_HASS_API):
self._attr_supported_features = (
-ha_conversation.ConversationEntityFeature.CONTROL
+conversation.ConversationEntityFeature.CONTROL
)
if self.entry.options.get(CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES, DEFAULT_USE_IN_CONTEXT_LEARNING_EXAMPLES):
@@ -278,6 +307,19 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
self, user_input: ConversationInput
) -> ConversationResult:
"""Process a sentence."""
+with (
+    chat_session.async_get_chat_session(
+        self.hass, user_input.conversation_id
+    ) as session,
+    conversation.async_get_chat_log(self.hass, session, user_input) as chat_log,
+):
+    return await self._async_handle_message(user_input, chat_log)
+
+async def _async_handle_message(
+    self,
+    user_input: conversation.ConversationInput,
+    chat_log: conversation.ChatLog,
+) -> conversation.ConversationResult:
raw_prompt = self.entry.options.get(CONF_PROMPT, DEFAULT_PROMPT)
prompt_template = self.entry.options.get(CONF_PROMPT_TEMPLATE, DEFAULT_PROMPT_TEMPLATE)
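
The pair of context managers above is the crux of the chat log API migration: chat_session.async_get_chat_session looks up or creates a session keyed by user_input.conversation_id, and conversation.async_get_chat_log yields a ChatLog whose content already holds the prior turns plus the incoming user message, with a system entry at index 0 (as of recent Home Assistant cores). An editorial sketch of roughly what the log contains when _async_handle_message runs for a second turn:

    # chat_log.content (illustrative, not part of the commit):
    # [
    #     SystemContent(...),                    # slot for the system prompt
    #     UserContent(content="..."),            # previous turn
    #     AssistantContent(...),                 # previous reply
    #     UserContent(content=user_input.text),  # the turn being handled now
    # ]

Because the user turn arrives inside the log, the old explicit {"role": "user", ...} append is dropped further down.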
@@ -297,8 +339,9 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
intent.IntentResponseErrorCode.UNKNOWN,
f"Sorry, there was a problem compiling the service call regex: {err}",
)
return ConversationResult(
-response=intent_response, conversation_id=conversation_id
+response=intent_response, conversation_id=user_input.conversation_id
)
llm_api: llm.APIInstance | None = None
@@ -312,7 +355,7 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
context=user_input.context,
user_prompt=user_input.text,
language=user_input.language,
-assistant=ha_conversation.DOMAIN,
+assistant=conversation.DOMAIN,
device_id=user_input.device_id,
)
)
@@ -327,14 +370,10 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
response=intent_response, conversation_id=user_input.conversation_id
)
-if user_input.conversation_id in self.history:
-    conversation_id = user_input.conversation_id
-    conversation = self.history[conversation_id] if remember_conversation else [self.history[conversation_id][0]]
-else:
-    conversation_id = ulid.ulid()
-    conversation = []
+message_history = [ _convert_content(content) for content in chat_log.content ]
-if len(conversation) == 0 or refresh_system_prompt:
+# re-generate prompt if necessary
+if len(message_history) == 0 or refresh_system_prompt:
try:
message = self._generate_system_prompt(raw_prompt, llm_api)
except TemplateError as err:
@@ -345,24 +384,20 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
f"Sorry, I had a problem with my template: {err}",
)
return ConversationResult(
-response=intent_response, conversation_id=conversation_id
+response=intent_response, conversation_id=user_input.conversation_id
)
system_prompt = { "role": "system", "message": message }
-if len(conversation) == 0:
-    conversation.append(system_prompt)
-    if not remember_conversation:
-        self.history[conversation_id] = conversation
+if len(message_history) == 0:
+    message_history.append(system_prompt)
else:
-    conversation[0] = system_prompt
-conversation.append({"role": "user", "message": user_input.text})
+    message_history[0] = system_prompt
# generate a response
try:
-_LOGGER.debug(conversation)
-response = await self._async_generate(conversation)
+_LOGGER.debug(message_history)
+response = await self._async_generate(message_history)
_LOGGER.debug(response)
except Exception as err:
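
At this point message_history is a plain list of JSON-ready dicts, which is the shape the backend request builders consume. An editorial sketch of a first-turn payload passed to _async_generate (values are made up):

    # [
    #     {"role": "system", "message": "<rendered system prompt>"},
    #     {"role": "user", "message": "turn on the kitchen light"},
    # ]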
@@ -374,7 +409,7 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
f"Sorry, there was a problem talking to the backend: {repr(err)}",
)
return ConversationResult(
-response=intent_response, conversation_id=conversation_id
+response=intent_response, conversation_id=user_input.conversation_id
)
# remove end of text token if it was returned
@@ -383,19 +418,19 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
# remove think blocks
response = re.sub(rf"^.*?{template_desc["chain_of_thought"]["suffix"]}", "", response, flags=re.DOTALL)
conversation.append({"role": "assistant", "message": response})
message_history.append({"role": "assistant", "message": response})
if remember_conversation:
-if remember_num_interactions and len(conversation) > (remember_num_interactions * 2) + 1:
+if remember_num_interactions and len(message_history) > (remember_num_interactions * 2) + 1:
for i in range(0,2):
-    conversation.pop(1)
-self.history[conversation_id] = conversation
+    message_history.pop(1)
+# chat_log.content = [_convert_content_back(user_input.agent_id, message_history_entry) for message_history_entry in message_history ]
if llm_api is None:
# return the output without messing with it if there is no API exposed to the model
intent_response = intent.IntentResponse(language=user_input.language)
intent_response.async_set_speech(response.strip())
return ConversationResult(
-response=intent_response, conversation_id=conversation_id
+response=intent_response, conversation_id=user_input.conversation_id
)
tool_response = None
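
For clarity on the retention rule above: the history is capped at one system entry plus remember_num_interactions user/assistant exchanges, i.e. (N * 2) + 1 entries, and each overflow drops the two oldest non-system entries. A worked editorial sketch with remember_num_interactions = 1:

    message_history = [
        {"role": "system", "message": "<prompt>"},
        {"role": "user", "message": "first question"},
        {"role": "assistant", "message": "first answer"},
        {"role": "user", "message": "second question"},
        {"role": "assistant", "message": "second answer"},
    ]
    # len(message_history) == 5 > (1 * 2) + 1, so index 1 is popped twice,
    # dropping "first question" and "first answer" and leaving the system
    # prompt plus the most recent exchange.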
@@ -436,7 +471,7 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
f"I'm sorry, I didn't produce a correctly formatted tool call! Please see the logs for more info.",
)
return ConversationResult(
-response=intent_response, conversation_id=conversation_id
+response=intent_response, conversation_id=user_input.conversation_id
)
_LOGGER.info(f"calling tool: {block}")
@@ -480,20 +515,20 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
f"I'm sorry! I encountered an error calling the tool. See the logs for more info.",
)
return ConversationResult(
-response=intent_response, conversation_id=conversation_id
+response=intent_response, conversation_id=user_input.conversation_id
)
# handle models that generate a function call and wait for the result before providing a response
if self.entry.options.get(CONF_TOOL_MULTI_TURN_CHAT, DEFAULT_TOOL_MULTI_TURN_CHAT) and tool_response is not None:
try:
conversation.append({"role": "tool", "message": json.dumps(tool_response)})
message_history.append({"role": "tool", "message": json.dumps(tool_response)})
except:
conversation.append({"role": "tool", "message": "No tools were used in this response."})
message_history.append({"role": "tool", "message": "No tools were used in this response."})
# generate a response based on the tool result
try:
-_LOGGER.debug(conversation)
-to_say = await self._async_generate(conversation)
+_LOGGER.debug(message_history)
+to_say = await self._async_generate(message_history)
_LOGGER.debug(to_say)
except Exception as err:
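
The json.dumps call above is wrapped because a tool can hand back a result that is not JSON-serializable; on failure the code appends a placeholder tool message instead of aborting. An editorial sketch of the turn appended on success (the payload shape is hypothetical):

    import json

    tool_response = {"success": True, "entity_id": "light.kitchen"}
    message_history.append({"role": "tool", "message": json.dumps(tool_response)})
    # appends {"role": "tool", "message": '{"success": true, "entity_id": "light.kitchen"}'}

The follow-up _async_generate call then lets the model phrase a final answer that takes the tool result into account.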
@@ -505,17 +540,17 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
f"Sorry, there was a problem talking to the backend: {repr(err)}",
)
return ConversationResult(
-response=intent_response, conversation_id=conversation_id
+response=intent_response, conversation_id=user_input.conversation_id
)
conversation.append({"role": "assistant", "message": response})
conversation.append({"role": "assistant", "message": to_say})
message_history.append({"role": "assistant", "message": response})
message_history.append({"role": "assistant", "message": to_say})
# generate intent response to Home Assistant
intent_response = intent.IntentResponse(language=user_input.language)
intent_response.async_set_speech(to_say.strip())
return ConversationResult(
-response=intent_response, conversation_id=conversation_id
+response=intent_response, conversation_id=user_input.conversation_id
)
def _async_get_exposed_entities(self) -> tuple[dict[str, str], list[str]]: