mirror of
https://github.com/acon96/home-llm.git
synced 2026-01-09 13:48:05 -05:00
fix compatibility with the HA chat log API
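
This commit replaces the integration's hand-rolled conversation cache (a ULID-keyed self.history dict on the agent) with Home Assistant's chat session / chat log API: message history is now read from chat_log.content on each turn and converted into the backend's {"role", "message"} dict format, and conversation ids are taken from user_input rather than generated locally.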
@@ -16,7 +16,7 @@ import voluptuous as vol
 from typing import Literal, Any, Callable
 from homeassistant.components.conversation import ConversationInput, ConversationResult, AbstractConversationAgent, ConversationEntity
-from homeassistant.components import assist_pipeline, conversation as ha_conversation
+from homeassistant.components import assist_pipeline, conversation as conversation
 from homeassistant.components.conversation.const import DOMAIN as CONVERSATION_DOMAIN
 from homeassistant.components.homeassistant.exposed_entities import async_should_expose
 from homeassistant.config_entries import ConfigEntry
@@ -24,7 +24,7 @@ from homeassistant.const import ATTR_ENTITY_ID, CONF_HOST, CONF_PORT, CONF_SSL,
 from homeassistant.core import HomeAssistant, callback
 from homeassistant.exceptions import ConfigEntryNotReady, ConfigEntryError, TemplateError, HomeAssistantError
 from homeassistant.helpers import config_validation as cv, intent, template, entity_registry as er, llm, \
-    area_registry as ar, device_registry as dr
+    area_registry as ar, device_registry as dr, chat_session
 from homeassistant.helpers.entity_platform import AddEntitiesCallback
 from homeassistant.helpers.aiohttp_client import async_get_clientsession
 from homeassistant.helpers.event import async_track_state_change, async_call_later
@@ -159,12 +159,42 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry, async_add_e
 
     return True
 
+def _convert_content(
+    chat_content: conversation.Content
+) -> dict[str, str]:
+    """Create tool response content."""
+    role_name = None
+    if isinstance(chat_content, conversation.ToolResultContent):
+        role_name = "tool"
+    elif isinstance(chat_content, conversation.AssistantContent):
+        role_name = "assistant"
+    elif isinstance(chat_content, conversation.UserContent):
+        role_name = "user"
+    elif isinstance(chat_content, conversation.SystemContent):
+        role_name = "system"
+    else:
+        raise ValueError(f"Unexpected content type: {type(chat_content)}")
+
+    return { "role": role_name, "message": chat_content.content }
+
+def _convert_content_back(
+    agent_id: str,
+    message_history_entry: dict[str, str]
+) -> conversation.Content:
+    if message_history_entry["role"] == "tool":
+        return conversation.ToolResultContent(content=message_history_entry["message"])
+    if message_history_entry["role"] == "assistant":
+        return conversation.AssistantContent(agent_id=agent_id, content=message_history_entry["message"])
+    if message_history_entry["role"] == "user":
+        return conversation.UserContent(content=message_history_entry["message"])
+    if message_history_entry["role"] == "system":
+        return conversation.SystemContent(content=message_history_entry["message"])
+
 class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
     """Base Local LLM conversation agent."""
 
     hass: HomeAssistant
     entry_id: str
     history: dict[str, list[dict]]
     in_context_examples: list[dict]
 
     _attr_has_entity_name = True
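
The two helpers added above translate between HA's typed chat log entries and the plain dicts the LLM backends consume. A minimal round-trip sketch (not part of the commit; it assumes the conversation content dataclasses accept these keyword arguments, and the agent id is made up):

from homeassistant.components import conversation

# Hypothetical entries, mirroring what chat_log.content typically holds.
entries = [
    conversation.SystemContent(content="You are a helpful assistant."),
    conversation.UserContent(content="turn on the kitchen light"),
    conversation.AssistantContent(agent_id="conversation.local_llm", content="Done!"),
]

# Typed content -> {"role", "message"} dicts for the LLM backend.
as_dicts = [_convert_content(entry) for entry in entries]
assert [d["role"] for d in as_dicts] == ["system", "user", "assistant"]

# And back again, e.g. to write results into the chat log.
restored = [_convert_content_back("conversation.local_llm", d) for d in as_dicts]
assert restored[1].content == "turn on the kitchen light"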
@@ -176,7 +206,6 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
 
         self.hass = hass
         self.entry_id = entry.entry_id
-        self.history = {}
 
         self.backend_type = entry.data.get(
             CONF_BACKEND_TYPE, DEFAULT_BACKEND_TYPE
@@ -184,7 +213,7 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
 
         if self.entry.options.get(CONF_LLM_HASS_API):
             self._attr_supported_features = (
-                ha_conversation.ConversationEntityFeature.CONTROL
+                conversation.ConversationEntityFeature.CONTROL
             )
 
         self.in_context_examples = None
@@ -197,11 +226,11 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
         assist_pipeline.async_migrate_engine(
             self.hass, "conversation", self.entry.entry_id, self.entity_id
         )
-        ha_conversation.async_set_agent(self.hass, self.entry, self)
+        conversation.async_set_agent(self.hass, self.entry, self)
 
     async def async_will_remove_from_hass(self) -> None:
         """When entity will be removed from Home Assistant."""
-        ha_conversation.async_unset_agent(self.hass, self.entry)
+        conversation.async_unset_agent(self.hass, self.entry)
         await super().async_will_remove_from_hass()
 
     def _load_icl_examples(self, filename: str):
@@ -227,7 +256,7 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
     def _update_options(self):
         if self.entry.options.get(CONF_LLM_HASS_API):
             self._attr_supported_features = (
-                ha_conversation.ConversationEntityFeature.CONTROL
+                conversation.ConversationEntityFeature.CONTROL
             )
 
         if self.entry.options.get(CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES, DEFAULT_USE_IN_CONTEXT_LEARNING_EXAMPLES):
@@ -278,6 +307,19 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
         self, user_input: ConversationInput
     ) -> ConversationResult:
         """Process a sentence."""
+        with (
+            chat_session.async_get_chat_session(
+                self.hass, user_input.conversation_id
+            ) as session,
+            conversation.async_get_chat_log(self.hass, session, user_input) as chat_log,
+        ):
+            return await self._async_handle_message(user_input, chat_log)
+
+    async def _async_handle_message(
+        self,
+        user_input: conversation.ConversationInput,
+        chat_log: conversation.ChatLog,
+    ) -> conversation.ConversationResult:
 
         raw_prompt = self.entry.options.get(CONF_PROMPT, DEFAULT_PROMPT)
         prompt_template = self.entry.options.get(CONF_PROMPT_TEMPLATE, DEFAULT_PROMPT_TEMPLATE)
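
With this hunk, async_process becomes a thin wrapper: chat_session.async_get_chat_session creates or resumes the session for user_input.conversation_id, and conversation.async_get_chat_log yields the ChatLog that already holds the prior turns plus the just-received user message (which is why the explicit user-message append disappears in a later hunk). A hedged, illustrative helper showing what the handler can rely on inside the with block:

def _peek_roles(chat_log: conversation.ChatLog) -> list[str]:
    """Illustrative only: list the role of each entry currently in the log."""
    return [_convert_content(entry)["role"] for entry in chat_log.content]

# On a follow-up turn this might yield: ["system", "user", "assistant", "user"]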
@@ -297,8 +339,9 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
                 intent.IntentResponseErrorCode.UNKNOWN,
                 f"Sorry, there was a problem compiling the service call regex: {err}",
             )
+
             return ConversationResult(
-                response=intent_response, conversation_id=conversation_id
+                response=intent_response, conversation_id=user_input.conversation_id
             )
 
         llm_api: llm.APIInstance | None = None
@@ -312,7 +355,7 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
                     context=user_input.context,
                     user_prompt=user_input.text,
                     language=user_input.language,
-                    assistant=ha_conversation.DOMAIN,
+                    assistant=conversation.DOMAIN,
                     device_id=user_input.device_id,
                 )
             )
@@ -327,14 +370,10 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
                 response=intent_response, conversation_id=user_input.conversation_id
             )
 
-        if user_input.conversation_id in self.history:
-            conversation_id = user_input.conversation_id
-            conversation = self.history[conversation_id] if remember_conversation else [self.history[conversation_id][0]]
-        else:
-            conversation_id = ulid.ulid()
-            conversation = []
+        message_history = [ _convert_content(content) for content in chat_log.content ]
 
-        if len(conversation) == 0 or refresh_system_prompt:
+        # re-generate prompt if necessary
+        if len(message_history) == 0 or refresh_system_prompt:
             try:
                 message = self._generate_system_prompt(raw_prompt, llm_api)
             except TemplateError as err:
@@ -345,24 +384,20 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
                     f"Sorry, I had a problem with my template: {err}",
                 )
                 return ConversationResult(
-                    response=intent_response, conversation_id=conversation_id
+                    response=intent_response, conversation_id=user_input.conversation_id
                 )
 
             system_prompt = { "role": "system", "message": message }
 
-            if len(conversation) == 0:
-                conversation.append(system_prompt)
-                if not remember_conversation:
-                    self.history[conversation_id] = conversation
+            if len(message_history) == 0:
+                message_history.append(system_prompt)
             else:
-                conversation[0] = system_prompt
-
-            conversation.append({"role": "user", "message": user_input.text})
+                message_history[0] = system_prompt
 
         # generate a response
         try:
-            _LOGGER.debug(conversation)
-            response = await self._async_generate(conversation)
+            _LOGGER.debug(message_history)
+            response = await self._async_generate(message_history)
             _LOGGER.debug(response)
 
         except Exception as err:
@@ -374,7 +409,7 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
                 f"Sorry, there was a problem talking to the backend: {repr(err)}",
             )
             return ConversationResult(
-                response=intent_response, conversation_id=conversation_id
+                response=intent_response, conversation_id=user_input.conversation_id
             )
 
         # remove end of text token if it was returned
@@ -383,19 +418,19 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
         # remove think blocks
         response = re.sub(rf"^.*?{template_desc["chain_of_thought"]["suffix"]}", "", response, flags=re.DOTALL)
 
-        conversation.append({"role": "assistant", "message": response})
+        message_history.append({"role": "assistant", "message": response})
         if remember_conversation:
-            if remember_num_interactions and len(conversation) > (remember_num_interactions * 2) + 1:
+            if remember_num_interactions and len(message_history) > (remember_num_interactions * 2) + 1:
                 for i in range(0,2):
-                    conversation.pop(1)
-            self.history[conversation_id] = conversation
+                    message_history.pop(1)
+            # chat_log.content = [_convert_content_back(user_input.agent_id, message_history_entry) for message_history_entry in message_history ]
 
         if llm_api is None:
             # return the output without messing with it if there is no API exposed to the model
             intent_response = intent.IntentResponse(language=user_input.language)
             intent_response.async_set_speech(response.strip())
             return ConversationResult(
-                response=intent_response, conversation_id=conversation_id
+                response=intent_response, conversation_id=user_input.conversation_id
            )
 
         tool_response = None
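
The memory cap above keeps the system prompt at index 0 plus at most remember_num_interactions user/assistant pairs, popping index 1 twice to drop the oldest pair. A standalone worked example of the threshold arithmetic (values made up):

remember_num_interactions = 1

message_history = [
    {"role": "system", "message": "<system prompt>"},
    {"role": "user", "message": "turn on the light"},
    {"role": "assistant", "message": "Done."},
    {"role": "user", "message": "now turn it off"},
    {"role": "assistant", "message": "Done."},
]

# 5 > (1 * 2) + 1, so the oldest user/assistant pair is dropped.
if len(message_history) > (remember_num_interactions * 2) + 1:
    for _ in range(0, 2):
        message_history.pop(1)

assert [m["role"] for m in message_history] == ["system", "user", "assistant"]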
@@ -436,7 +471,7 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
                     f"I'm sorry, I didn't produce a correctly formatted tool call! Please see the logs for more info.",
                 )
                 return ConversationResult(
-                    response=intent_response, conversation_id=conversation_id
+                    response=intent_response, conversation_id=user_input.conversation_id
                 )
 
             _LOGGER.info(f"calling tool: {block}")
@@ -480,20 +515,20 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
                     f"I'm sorry! I encountered an error calling the tool. See the logs for more info.",
                 )
                 return ConversationResult(
-                    response=intent_response, conversation_id=conversation_id
+                    response=intent_response, conversation_id=user_input.conversation_id
                 )
 
         # handle models that generate a function call and wait for the result before providing a response
         if self.entry.options.get(CONF_TOOL_MULTI_TURN_CHAT, DEFAULT_TOOL_MULTI_TURN_CHAT) and tool_response is not None:
             try:
-                conversation.append({"role": "tool", "message": json.dumps(tool_response)})
+                message_history.append({"role": "tool", "message": json.dumps(tool_response)})
             except:
-                conversation.append({"role": "tool", "message": "No tools were used in this response."})
+                message_history.append({"role": "tool", "message": "No tools were used in this response."})
 
             # generate a response based on the tool result
             try:
-                _LOGGER.debug(conversation)
-                to_say = await self._async_generate(conversation)
+                _LOGGER.debug(message_history)
+                to_say = await self._async_generate(message_history)
                 _LOGGER.debug(to_say)
 
             except Exception as err:
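
When multi-turn tool chat is enabled, the executed tool's result is serialized and fed back to the model as a "tool" turn so it can phrase the final answer. A standalone sketch of the message shape appended above (the tool_response value is hypothetical):

import json

# Hypothetical tool result; the real value comes from llm_api.async_call_tool.
tool_response = {"success": True, "entity_id": "light.kitchen_light"}
tool_message = {"role": "tool", "message": json.dumps(tool_response)}
# -> {"role": "tool", "message": '{"success": true, "entity_id": "light.kitchen_light"}'}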
@@ -505,17 +540,17 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
                     f"Sorry, there was a problem talking to the backend: {repr(err)}",
                 )
                 return ConversationResult(
-                    response=intent_response, conversation_id=conversation_id
+                    response=intent_response, conversation_id=user_input.conversation_id
                 )
 
-            conversation.append({"role": "assistant", "message": response})
-            conversation.append({"role": "assistant", "message": to_say})
+            message_history.append({"role": "assistant", "message": response})
+            message_history.append({"role": "assistant", "message": to_say})
 
         # generate intent response to Home Assistant
         intent_response = intent.IntentResponse(language=user_input.language)
         intent_response.async_set_speech(to_say.strip())
         return ConversationResult(
-            response=intent_response, conversation_id=conversation_id
+            response=intent_response, conversation_id=user_input.conversation_id
         )
 
     def _async_get_exposed_entities(self) -> tuple[dict[str, str], list[str]]: