remove intermediate dict format and pass around Home Assistant model objects

Alex O'Connell
2025-09-15 22:10:09 -04:00
committed by Alex O'Connell
parent 53052af641
commit da0a0e4dbc
6 changed files with 173 additions and 196 deletions

View File

@@ -1,18 +1,20 @@
"""Defines the OpenAI API compatible agents"""
from __future__ import annotations
import json
import aiohttp
import asyncio
import datetime
import logging
from typing import List, Dict, Tuple, Optional, Any
from typing import List, Dict, Tuple, AsyncGenerator, Any
from homeassistant.components import conversation as conversation
from homeassistant.components import conversation
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_HOST, CONF_PORT, CONF_SSL
from homeassistant.helpers.aiohttp_client import async_get_clientsession
from homeassistant.helpers import llm
from custom_components.llama_conversation.utils import format_url
from custom_components.llama_conversation.utils import format_url, get_oai_formatted_messages, get_oai_formatted_tools
from custom_components.llama_conversation.const import (
CONF_CHAT_MODEL,
CONF_MAX_TOKENS,
@@ -30,7 +32,6 @@ from custom_components.llama_conversation.const import (
DEFAULT_REQUEST_TIMEOUT,
DEFAULT_REMEMBER_CONVERSATION,
DEFAULT_REMEMBER_CONVERSATION_TIME_MINUTES,
DEFAULT_REMOTE_USE_CHAT_ENDPOINT,
DEFAULT_GENERIC_OPENAI_PATH,
)
from custom_components.llama_conversation.conversation import LocalLLMAgent, TextGenerationResult
@@ -50,10 +51,10 @@ class BaseOpenAICompatibleAPIAgent(LocalLLMAgent):
path=""
)
self.api_key = entry.data.get(CONF_OPENAI_API_KEY)
self.model_name = entry.data.get(CONF_CHAT_MODEL)
self.api_key = entry.data.get(CONF_OPENAI_API_KEY, "")
self.model_name = entry.data.get(CONF_CHAT_MODEL, "")
async def _async_generate_with_parameters(self, endpoint: str, additional_params: dict) -> TextGenerationResult:
async def _async_generate_with_parameters(self, endpoint: str, stream: bool, additional_params: dict):
"""Generate a response using the OpenAI-compatible API"""
max_tokens = self.entry.options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS)
@@ -63,6 +64,7 @@ class BaseOpenAICompatibleAPIAgent(LocalLLMAgent):
request_params = {
"model": self.model_name,
"stream": stream,
"max_tokens": max_tokens,
"temperature": temperature,
"top_p": top_p,
@@ -84,18 +86,20 @@ class BaseOpenAICompatibleAPIAgent(LocalLLMAgent):
headers=headers
) as response:
response.raise_for_status()
result = await response.json()
if stream:
async for line_bytes in response.content:
chunk = line_bytes.decode("utf-8").strip()
yield self._extract_response(json.loads(chunk))
else:
response_json = await response.json()
yield self._extract_response(response_json)
except asyncio.TimeoutError:
return TextGenerationResult(raise_error=True, error_msg="The generation request timed out! Please check your connection settings, increase the timeout in settings, or decrease the number of exposed entities.")
yield TextGenerationResult(raise_error=True, error_msg="The generation request timed out! Please check your connection settings, increase the timeout in settings, or decrease the number of exposed entities.")
except aiohttp.ClientError as err:
_LOGGER.debug(f"Err was: {err}")
_LOGGER.debug(f"Request was: {request_params}")
_LOGGER.debug(f"Result was: {response}")
return TextGenerationResult(raise_error=True, error_msg=f"Failed to communicate with the API! {err}")
_LOGGER.debug(result)
return self._extract_response(result)
yield TextGenerationResult(raise_error=True, error_msg=f"Failed to communicate with the API! {err}")
def _extract_response(self, response_json: dict) -> TextGenerationResult:
raise NotImplementedError("Subclasses must implement _extract_response()")
@@ -103,24 +107,6 @@ class BaseOpenAICompatibleAPIAgent(LocalLLMAgent):
class GenericOpenAIAPIAgent(BaseOpenAICompatibleAPIAgent):
"""Implements the OpenAPI-compatible text completion and chat completion API backends."""
def _chat_completion_params(self, conversation: List[Dict[str, str]]) -> Tuple[str, Dict[str, Any]]:
request_params = {}
api_base_path = self.entry.data.get(CONF_GENERIC_OPENAI_PATH, DEFAULT_GENERIC_OPENAI_PATH)
endpoint = f"/{api_base_path}/chat/completions"
request_params["messages"] = [ { "role": x["role"], "content": x["message"] } for x in conversation ]
return endpoint, request_params
def _completion_params(self, conversation: List[Dict[str, str]]) -> Tuple[str, Dict[str, Any]]:
request_params = {}
api_base_path = self.entry.data.get(CONF_GENERIC_OPENAI_PATH, DEFAULT_GENERIC_OPENAI_PATH)
endpoint = f"/{api_base_path}/completions"
request_params["prompt"] = self._format_prompt(conversation)
return endpoint, request_params
def _extract_response(self, response_json: dict) -> TextGenerationResult:
choice = response_json["choices"][0]
if response_json["object"] == "chat.completion":
@@ -140,26 +126,26 @@ class GenericOpenAIAPIAgent(BaseOpenAICompatibleAPIAgent):
return TextGenerationResult(
response=response_text,
stop_reason=choice["finish_reason"],
response_streamed=streamed
response_streamed=streamed,
)
async def _async_generate(self, conversation: List[Dict[str, str]]) -> TextGenerationResult:
use_chat_api = self.entry.options.get(CONF_REMOTE_USE_CHAT_ENDPOINT, DEFAULT_REMOTE_USE_CHAT_ENDPOINT)
def _generate_stream(self, conversation: List[conversation.Content], llm_api: llm.APIInstance | None) -> AsyncGenerator[TextGenerationResult, None]:
request_params = {}
api_base_path = self.entry.data.get(CONF_GENERIC_OPENAI_PATH, DEFAULT_GENERIC_OPENAI_PATH)
if use_chat_api:
endpoint, additional_params = self._chat_completion_params(conversation)
else:
endpoint, additional_params = self._completion_params(conversation)
endpoint = f"/{api_base_path}/chat/completions"
request_params["messages"] = get_oai_formatted_messages(conversation)
result = await self._async_generate_with_parameters(endpoint, additional_params)
if llm_api:
request_params["tools"] = get_oai_formatted_tools(llm_api)
return result
return self._async_generate_with_parameters(endpoint, True, request_params)
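For reference, _extract_response above branches on the "object" field to handle the two payload shapes an OpenAI-compatible server can return. Illustrative examples of both, with made-up values:

# Illustrative payloads only (values made up), showing the two "object" kinds
# the method distinguishes.
chat_completion_payload = {
    "object": "chat.completion",
    "choices": [{
        "index": 0,
        "message": {"role": "assistant", "content": "The kitchen light is now on."},
        "finish_reason": "stop",
    }],
}

text_completion_payload = {
    "object": "text_completion",
    "choices": [{
        "index": 0,
        "text": "The kitchen light is now on.",
        "finish_reason": "stop",
    }],
}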
class GenericOpenAIResponsesAPIAgent(BaseOpenAICompatibleAPIAgent):
"""Implements the OpenAPI-compatible Responses API backend."""
_last_response_id: str | None = None
_last_response_id_time: datetime.datetime = None
_last_response_id_time: datetime.datetime | None = None
def _responses_params(self, conversation: List[Dict[str, str]]) -> Tuple[str, Dict[str, Any]]:
request_params = {}
@@ -254,8 +240,8 @@ class GenericOpenAIResponsesAPIAgent(BaseOpenAICompatibleAPIAgent):
return to_return
async def _async_generate(self, conversation: List[Dict[str, str]]) -> TextGenerationResult:
async def _async_generate(self, conv: List[Dict[str, str]], user_input: conversation.ConversationInput, chat_log: conversation.chat_log.ChatLog):
"""Generate a response using the OpenAI-compatible Responses API"""
endpoint, additional_params = self._responses_params(conversation)
return await self._async_generate_with_parameters(endpoint, additional_params)
endpoint, additional_params = self._responses_params(conv)
return self._async_generate_with_parameters(endpoint, additional_params)
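With this refactor, _async_generate_with_parameters is an async generator: with stream=True it yields one TextGenerationResult per chunk read from the HTTP response. A standalone sketch of consuming an OpenAI-compatible streaming /chat/completions endpoint with aiohttp follows; it is not this integration's code, and the base URL, model name, and SSE "data:" framing are assumptions about a typical backend:

# Standalone sketch (not this integration's code).
import asyncio
import json

import aiohttp


async def stream_chat(base_url: str, payload: dict) -> str:
    """Collect the assistant text from a streamed chat completion."""
    parts: list[str] = []
    async with aiohttp.ClientSession() as session:
        async with session.post(f"{base_url}/v1/chat/completions", json=payload) as response:
            response.raise_for_status()
            async for line_bytes in response.content:
                line = line_bytes.decode("utf-8").strip()
                if not line.startswith("data:"):
                    continue  # skip keep-alives and blank lines
                data = line.removeprefix("data:").strip()
                if data == "[DONE]":
                    break
                chunk = json.loads(data)
                delta = chunk["choices"][0].get("delta", {})
                if delta.get("content"):
                    parts.append(delta["content"])
    return "".join(parts)


if __name__ == "__main__":
    payload = {
        "model": "my-local-model",  # placeholder
        "stream": True,
        "messages": [{"role": "user", "content": "Turn on the kitchen light."}],
    }
    print(asyncio.run(stream_chat("http://localhost:8080", payload)))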

View File

@@ -6,7 +6,7 @@ import logging
import os
import threading
import time
from typing import Any, Callable, Generator, Optional, List, Dict, AsyncIterable
from typing import Any, Callable, Generator, Optional, List, Dict, AsyncIterable, AsyncGenerator, Sequence
from homeassistant.components import conversation as conversation
from homeassistant.components.conversation.const import DOMAIN as CONVERSATION_DOMAIN
@@ -18,7 +18,7 @@ from homeassistant.exceptions import ConfigEntryError, HomeAssistantError
from homeassistant.helpers import llm
from homeassistant.helpers.event import async_track_state_change, async_call_later
from custom_components.llama_conversation.utils import install_llama_cpp_python, validate_llama_cpp_python_installation
from custom_components.llama_conversation.utils import install_llama_cpp_python, validate_llama_cpp_python_installation, get_oai_formatted_messages, get_oai_formatted_tools
from custom_components.llama_conversation.const import (
CONF_MAX_TOKENS,
CONF_PROMPT,
@@ -65,6 +65,7 @@ else:
_LOGGER = logging.getLogger(__name__)
class LlamaCppAgent(LocalLLMAgent):
model_path: str
llm: LlamaType
@@ -351,10 +352,20 @@ class LlamaCppAgent(LocalLLMAgent):
refresh_delay = self.entry.options.get(CONF_PROMPT_CACHING_INTERVAL, DEFAULT_PROMPT_CACHING_INTERVAL)
async_call_later(self.hass, float(refresh_delay), refresh_if_requested)
async def _async_generate_completion(self, chat_completion) -> AsyncGenerator[TextGenerationResult, None]:
for token in chat_completion:
if isinstance(token, str):
yield TextGenerationResult(
response=token,
response_streamed=True
)
else:
token["choices"][0]["delta"].get("tool_calls")
def _generate_stream(self, conversation: List[Dict[str, str]], user_input: conversation.ConversationInput, chat_log: conversation.ChatLog) -> TextGenerationResult:
prompt = self._format_prompt(conversation)
def _generate_stream(self, conversation: List[conversation.Content], llm_api: llm.APIInstance | None) -> AsyncGenerator[TextGenerationResult, None]:
"""Async generator that yields TextGenerationResult as tokens are produced."""
# prompt = self._format_prompt(conversation)
max_tokens = self.entry.options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS)
temperature = self.entry.options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE)
top_k = int(self.entry.options.get(CONF_TOP_K, DEFAULT_TOP_K))
@@ -364,49 +375,37 @@ class LlamaCppAgent(LocalLLMAgent):
_LOGGER.debug(f"Options: {self.entry.options}")
with self.model_lock:
input_tokens = self.llm.tokenize(
prompt.encode(), add_bos=False
)
# with self.model_lock:
# # FIXME: use the high level API so we can use the built-in prompt formatting
# input_tokens = self.llm.tokenize(
# prompt.encode(), add_bos=False
# )
context_len = self.entry.options.get(CONF_CONTEXT_LENGTH, DEFAULT_CONTEXT_LENGTH)
if len(input_tokens) >= context_len:
num_entities = len(self._async_get_exposed_entities()[0])
context_size = self.entry.options.get(CONF_CONTEXT_LENGTH, DEFAULT_CONTEXT_LENGTH)
self._warn_context_size()
raise Exception(f"The model failed to produce a result because too many devices are exposed ({num_entities} devices) for the context size ({context_size} tokens)!")
if len(input_tokens) + max_tokens >= context_len:
self._warn_context_size()
# context_len = self.entry.options.get(CONF_CONTEXT_LENGTH, DEFAULT_CONTEXT_LENGTH)
# if len(input_tokens) >= context_len:
# num_entities = len(self._async_get_exposed_entities()[0])
# context_size = self.entry.options.get(CONF_CONTEXT_LENGTH, DEFAULT_CONTEXT_LENGTH)
# self._warn_context_size()
# raise Exception(f"The model failed to produce a result because too many devices are exposed ({num_entities} devices) for the context size ({context_size} tokens)!")
# if len(input_tokens) + max_tokens >= context_len:
# self._warn_context_size()
_LOGGER.debug(f"Processing {len(input_tokens)} input tokens...")
output_tokens = self.llm.generate(
input_tokens,
temp=temperature,
top_k=top_k,
top_p=top_p,
min_p=min_p,
typical_p=typical_p,
grammar=self.grammar
)
# _LOGGER.debug(f"Processing {len(input_tokens)} input tokens...")
result_tokens = []
async def async_iterator():
num_tokens = 0
for token in output_tokens:
result_tokens.append(token)
yield TextGenerationResult(response=self.llm.detokenize([token]).decode(), response_streamed=True)
messages = get_oai_formatted_messages(conversation)
tools = None
if llm_api:
tools = get_oai_formatted_tools(llm_api)
if token == self.llm.token_eos():
break
return self._async_generate_completion(self.llm.create_chat_completion(
messages,
tools=tools,
temperature=temperature,
top_k=top_k,
top_p=top_p,
min_p=min_p,
typical_p=typical_p,
max_tokens=max_tokens,
grammar=self.grammar
))
if len(result_tokens) >= max_tokens:
break
num_tokens += 1
self._transform_result_stream(async_iterator(), user_input=user_input, chat_log=chat_log)
response = TextGenerationResult(
response=self.llm.detokenize(result_tokens).decode(),
response_streamed=True,
)
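The switch to llm.create_chat_completion means chunk handling now follows the OpenAI-style streaming format that llama-cpp-python emits. A minimal standalone sketch of that high-level API, where the model path and prompt are placeholders and this is not the integration's code:

# Standalone sketch (assumed usage, not this integration's code).
from llama_cpp import Llama

llm = Llama(model_path="/path/to/model.gguf", n_ctx=2048)  # placeholder path

stream = llm.create_chat_completion(
    messages=[{"role": "user", "content": "Turn off the porch light."}],
    temperature=0.1,
    max_tokens=128,
    stream=True,
)

for chunk in stream:
    # Streamed chunks mirror the OpenAI format: choices[0]["delta"]
    delta = chunk["choices"][0].get("delta", {})
    if delta.get("content"):
        print(delta["content"], end="", flush=True)
print()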

View File

@@ -1,5 +1,6 @@
"""Defines the ollama compatible agent"""
from __future__ import annotations
from warnings import deprecated
import aiohttp
import asyncio
@@ -43,6 +44,7 @@ from custom_components.llama_conversation.conversation import LocalLLMAgent, Tex
_LOGGER = logging.getLogger(__name__)
@deprecated("Use the built-in Ollama integration instead")
class OllamaAPIAgent(LocalLLMAgent):
api_host: str
api_key: Optional[str]
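The @deprecated decorator used above comes from warnings.deprecated, which is only available from Python 3.13 (PEP 702); on older interpreters an equivalent decorator ships in typing_extensions. A small self-contained illustration, where the class, message, and fallback import are assumptions rather than code from this commit:

# Illustration only: the deprecation warning fires when the decorated class is
# instantiated.
import warnings

try:
    from warnings import deprecated  # Python 3.13+ (PEP 702)
except ImportError:
    from typing_extensions import deprecated  # assumed fallback for older Pythons


@deprecated("Use NewThing instead")
class OldThing:
    pass


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    OldThing()
    print(caught[0].category.__name__, caught[0].message)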

View File

@@ -107,37 +107,6 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry, async_add_e
async_add_entities([entry.runtime_data])
return True
def _convert_content(
chat_content: conversation.Content
) -> dict[str, str]:
"""Create tool response content."""
role_name = None
if isinstance(chat_content, conversation.ToolResultContent):
role_name = "tool"
elif isinstance(chat_content, conversation.AssistantContent):
role_name = "assistant"
elif isinstance(chat_content, conversation.UserContent):
role_name = "user"
elif isinstance(chat_content, conversation.SystemContent):
role_name = "system"
else:
raise ValueError(f"Unexpected content type: {type(chat_content)}")
return { "role": role_name, "message": chat_content.content }
def _convert_content_back(
agent_id: str,
message_history_entry: dict[str, str]
) -> Optional[conversation.Content]:
if message_history_entry["role"] == "tool":
return conversation.ToolResultContent(agent_id=agent_id, content=message_history_entry["message"])
if message_history_entry["role"] == "assistant":
return conversation.AssistantContent(agent_id=agent_id, content=message_history_entry["message"])
if message_history_entry["role"] == "user":
return conversation.UserContent(content=message_history_entry["message"])
if message_history_entry["role"] == "system":
return conversation.SystemContent(content=message_history_entry["message"])
def _parse_raw_tool_call(raw_block: str, llm_api: llm.APIInstance, user_input: ConversationInput) -> tuple[bool, ConversationResult | llm.ToolInput, str | None]:
parsed_tool_call: dict = json.loads(raw_block)
@@ -298,21 +267,29 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
self._load_model, entry
)
def _generate_stream(self, conversation: List[Dict[str, str]]) -> TextGenerationResult:
def _generate_stream(self, conversation: List[conversation.Content], llm_api: llm.APIInstance | None) -> AsyncGenerator[TextGenerationResult, None]:
"""Async generator for streaming responses. Subclasses should implement."""
raise NotImplementedError()
def _generate(self, conversation: List[Dict[str, str]]) -> TextGenerationResult:
async def _generate(self, conversation: List[conversation.Content], llm_api: llm.APIInstance | None) -> TextGenerationResult:
"""Call the backend to generate a response from the conversation. Implemented by sub-classes"""
raise NotImplementedError()
async def _async_generate(self, conversation: List[Dict[str, str]]) -> TextGenerationResult:
"""Default implementation is to call _generate() which probably does blocking stuff"""
async def _async_generate(self, conv: List[conversation.Content], user_input: ConversationInput, chat_log: conversation.chat_log.ChatLog):
"""Default implementation: if streaming is supported, consume the async generator and return the full result."""
if hasattr(self, '_generate_stream'):
return await self.hass.async_add_executor_job(
self._generate_stream, conversation
# Try to stream and collect the full response
return await self._transform_result_stream(self._generate_stream(conv, chat_log.llm_api), user_input, chat_log)
# Fallback to "blocking" generate
blocking_result = await self._generate(conv, chat_log.llm_api)
return chat_log.async_add_assistant_content(
conversation.AssistantContent(
agent_id=user_input.agent_id,
content=blocking_result.response,
tool_calls=blocking_result.tool_calls
)
return await self.hass.async_add_executor_job(
self._generate, conversation
)
def _warn_context_size(self):
@@ -322,7 +299,7 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
f"{self.entry.data[CONF_CHAT_MODEL]} and it exceeded the context size for the model. " +
f"Please reduce the number of entities exposed ({num_entities}) or increase the model's context size ({int(context_size)})")
def _transform_result_stream(
async def _transform_result_stream(
self,
result: AsyncIterator[TextGenerationResult],
user_input: ConversationInput,
@@ -335,8 +312,7 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
tool_calls=input_chunk.tool_calls
)
chat_log.async_add_delta_content_stream(user_input.agent_id, stream=async_iterator())
return chat_log.async_add_delta_content_stream(user_input.agent_id, stream=async_iterator())
async def async_process(
self, user_input: ConversationInput
@@ -405,7 +381,7 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
)
if remember_conversation:
message_history = [ _convert_content(content) for content in chat_log.content ]
message_history = chat_log.content[:]
else:
message_history = []
@@ -436,16 +412,14 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
multi_turn_enabled = self.entry.options.get(CONF_TOOL_MULTI_TURN_CHAT, DEFAULT_TOOL_MULTI_TURN_CHAT)
MAX_TOOL_CALL_ITERATIONS = 3 if multi_turn_enabled else 1
tool_calls: List[Tuple[llm.ToolInput, Any]] = []
for _ in range(MAX_TOOL_CALL_ITERATIONS):
# generate a response
try:
_LOGGER.debug(message_history)
generation_result = await self._async_generate(message_history)
generation_result = await self._async_generate(message_history, user_input, chat_log)
_LOGGER.debug(generation_result)
except Exception as err:
_LOGGER.exception("There was a problem talking to the backend")
intent_response = intent.IntentResponse(language=user_input.language)
intent_response.async_set_error(
intent.IntentResponseErrorCode.FAILED_TO_HANDLE,
@@ -455,67 +429,28 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
response=intent_response, conversation_id=user_input.conversation_id
)
response = generation_result.response or ""
last_message_had_tool_calls = False
async for message in generation_result:
message_history.append(message)
if message.role == "assistant":
if message.tool_calls and len(message.tool_calls) > 0:
last_message_had_tool_calls = True
else:
last_message_had_tool_calls = False
# remove think blocks
response = re.sub(rf"^.*?{template_desc["chain_of_thought"]["suffix"]}", "", response, flags=re.DOTALL)
message_history.append({"role": "assistant", "message": response})
if llm_api is None or (generation_result.tool_calls and len(generation_result.tool_calls) == 0):
if not generation_result.response_streamed:
chat_log.async_add_assistant_content_without_tools(conversation.AssistantContent(
agent_id=user_input.agent_id,
content=response,
))
# return the output without messing with it if there is no API exposed to the model
intent_response = intent.IntentResponse(language=user_input.language)
return ConversationResult(
response=intent_response, conversation_id=user_input.conversation_id
)
# execute the tool calls
tool_calls: List[Tuple[llm.ToolInput, Any]] = []
for tool_input in generation_result.tool_calls or []:
tool_response = None
try:
tool_response = llm_api.async_call_tool(tool_input)
_LOGGER.debug("Tool response: %s", tool_response)
tool_calls.append((tool_input, tool_response))
except (HomeAssistantError, vol.Invalid) as e:
tool_response = {"error": type(e).__name__}
if str(e):
tool_response["error_text"] = str(e)
_LOGGER.debug("Tool response: %s", tool_response)
intent_response = intent.IntentResponse(language=user_input.language)
intent_response.async_set_error(
intent.IntentResponseErrorCode.NO_INTENT_MATCH,
f"I'm sorry! I encountered an error calling the tool. See the logs for more info.",
)
return ConversationResult(
response=intent_response, conversation_id=user_input.conversation_id
)
if tool_response and multi_turn_enabled:
async for tool_result in chat_log.async_add_assistant_content(
conversation.AssistantContent(
agent_id=user_input.agent_id,
content=response,
tool_calls=generation_result.tool_calls,
),
tool_call_tasks={ x[0].tool_name: x[1] for x in tool_calls}
):
message_history.append({"role": "tool", "content": json.dumps(tool_result.tool_result) })
# If not multi-turn, break after first tool call
# also break if no tool calls were made
if not multi_turn_enabled or not last_message_had_tool_calls:
break
# generate intent response to Home Assistant
intent_response = intent.IntentResponse(language=user_input.language)
if len(tool_calls) > 0:
str_tools = [f"{input.tool_name}({', '.join(input.tool_args.values())})" for input, response in tool_calls]
str_tools = [f"{input.tool_name}({', '.join(str(x) for x in input.tool_args.values())})" for input, response in tool_calls]
tools_str = '\n'.join(str_tools)
intent_response.async_set_card(
title="Changes",
content=f"Ran the following tools:\n{'\n'.join(str_tools)}"
content=f"Ran the following tools:\n{tools_str}"
)
return ConversationResult(
response=intent_response, conversation_id=user_input.conversation_id
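The str(x) change in the tool-call card above matters because str.join only accepts strings, while tool arguments are often numbers or other non-string values. A tiny illustration with made-up tool data:

# Made-up tool data; joining the raw values raises because 50 is an int.
tool_name = "HassLightSet"
tool_args = {"name": "kitchen light", "brightness": 50}

# ', '.join(tool_args.values())  -> TypeError: expected str instance, int found
summary = f"{tool_name}({', '.join(str(x) for x in tool_args.values())})"
print(summary)  # HassLightSet(kitchen light, 50)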

View File

@@ -7,11 +7,13 @@ import logging
import multiprocessing
import voluptuous as vol
import webcolors
from typing import Any, Dict, List, Sequence
from webcolors import CSS3
from importlib.metadata import version
from homeassistant.components import conversation
from homeassistant.helpers import config_validation as cv
from homeassistant.helpers import intent
from homeassistant.helpers import intent, llm
from homeassistant.requirements import pip_kwargs
from homeassistant.util import color
from homeassistant.util.package import install_package, is_installed
@@ -21,6 +23,13 @@ from .const import (
EMBEDDED_LLAMA_CPP_PYTHON_VERSION,
)
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from llama_cpp.llama_types import ChatCompletionRequestMessage, ChatCompletionTool
else:
ChatCompletionRequestMessage = Any
ChatCompletionTool = Any
_LOGGER = logging.getLogger(__name__)
CSS3_NAME_TO_RGB = {
@@ -238,4 +247,52 @@ def install_llama_cpp_python(config_dir: str):
return True
def format_url(*, hostname: str, port: str, ssl: bool, path: str):
return f"{'https' if ssl else 'http'}://{hostname}{ ':' + port if port else ''}{path}"
return f"{'https' if ssl else 'http'}://{hostname}{ ':' + port if port else ''}{path}"
def get_oai_formatted_tools(llm_api: llm.APIInstance) -> List[ChatCompletionTool]:
return [ {
"type": "function",
"function": {
"name": tool.name,
"description": tool.description or "",
"parameters": tool.parameters.schema
}
} for tool in llm_api.tools ]
def get_oai_formatted_messages(conversation: Sequence[conversation.Content]) -> List[ChatCompletionRequestMessage]:
messages = []
for message in conversation:
if message.role == "system":
messages.append({
"role": "system",
"content": message.content
})
elif message.role == "user":
messages.append({
"role": "user",
"content": [{ "type": "text", "text": message.content }]
})
elif message.role == "assistant":
if message.tool_calls:
messages.append({
"role": "assistant",
"content": message.content,
"tool_calls": [
{
"type" : "function",
"id": t.id,
"function": {
"arguments": t.tool_args,
"name": t.tool_name,
}
} for t in message.tool_calls
]
})
elif message.role == "tool_result":
messages.append({
"role": "tool",
"content": message.tool_result,
"tool_call_id": message.tool_call_id
})
return messages
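For reference, the rough shape these helpers produce for a short exchange, with made-up values (HassTurnOn stands in for whatever tools the LLM API actually exposes):

# Illustrative output only (values made up).
example_tools = [{
    "type": "function",
    "function": {
        "name": "HassTurnOn",
        "description": "Turns on a device or entity",
        "parameters": {"type": "object"},
    },
}]

example_messages = [
    {"role": "system", "content": "You are a voice assistant for Home Assistant."},
    {"role": "user", "content": [{"type": "text", "text": "Turn on the kitchen light."}]},
    {
        "role": "assistant",
        "content": None,
        "tool_calls": [{
            "type": "function",
            "id": "call_1",
            "function": {"name": "HassTurnOn", "arguments": {"name": "kitchen light"}},
        }],
    },
    {"role": "tool", "content": '{"success": true}', "tool_call_id": "call_1"},
]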

View File

@@ -675,8 +675,6 @@ def do_training_run(training_run_args: TrainingRunArguments):
if input("Something bad happened! Try and save it? (Y/n) ").lower().startswith("y"):
trainer._save_checkpoint(model, None)
print("Saved Checkpoint!")
exit(-1)
if __name__ == "__main__":
parser = HfArgumentParser([TrainingRunArguments])