mirror of
https://github.com/acon96/home-llm.git
synced 2026-01-08 21:28:05 -05:00
@@ -158,6 +158,7 @@ python3 train.py \
|
||||
## Version History
|
||||
| Version | Description |
|
||||
|---------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
||||
| v0.4.4 | Fix issue with OpenAI backends appending `/v1` to all URLs, and fix an issue with tools being serialized into the system prompt. |
|
||||
| v0.4.3 | Fix an issue with the integration not creating model configs properly during setup |
|
||||
| v0.4.2 | Fix the following issues: not correctly setting default model settings during initial setup, non-integers being allowed in numeric config fields, being too strict with finish_reason requirements, and not letting the user clear the active LLM API |
|
||||
| v0.4.1 | Fix an issue with using Llama.cpp models downloaded from HuggingFace |
|
||||
|
||||
@@ -57,7 +57,7 @@ _LOGGER = logging.getLogger(__name__)
|
||||
|
||||
CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN)
|
||||
|
||||
PLATFORMS = (Platform.CONVERSATION,)
|
||||
PLATFORMS = (Platform.CONVERSATION, ) # Platform.AI_TASK)
|
||||
|
||||
BACKEND_TO_CLS: dict[str, type[LocalLLMClient]] = {
|
||||
BACKEND_TYPE_LLAMA_CPP: LlamaCppClient,
|
||||
|
||||
126
custom_components/llama_conversation/ai_task.py
Normal file
126
custom_components/llama_conversation/ai_task.py
Normal file
@@ -0,0 +1,126 @@
|
||||
"""AI Task integration for Local LLMs."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from json import JSONDecodeError
|
||||
import logging
|
||||
from enum import StrEnum
|
||||
|
||||
from homeassistant.components import ai_task, conversation
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.core import HomeAssistant, Context
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
|
||||
from homeassistant.util.json import json_loads
|
||||
|
||||
from .entity import LocalLLMEntity, LocalLLMClient
|
||||
from .const import (
|
||||
CONF_PROMPT,
|
||||
DEFAULT_PROMPT,
|
||||
)
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def async_setup_entry(
|
||||
hass: HomeAssistant,
|
||||
config_entry: ConfigEntry[LocalLLMClient],
|
||||
async_add_entities: AddConfigEntryEntitiesCallback,
|
||||
) -> None:
|
||||
"""Set up AI Task entities."""
|
||||
for subentry in config_entry.subentries.values():
|
||||
if subentry.subentry_type != "ai_task_data":
|
||||
continue
|
||||
|
||||
async_add_entities(
|
||||
[LocalLLMTaskEntity(hass, config_entry, subentry, config_entry.runtime_data)],
|
||||
config_subentry_id=subentry.subentry_id,
|
||||
)
|
||||
|
||||
class ResultExtractionMethod(StrEnum):
|
||||
NONE = "none"
|
||||
STRUCTURED_OUTPUT = "structure"
|
||||
TOOL = "tool"
|
||||
|
||||
class LocalLLMTaskEntity(
|
||||
ai_task.AITaskEntity,
|
||||
LocalLLMEntity,
|
||||
):
|
||||
"""Ollama AI Task entity."""
|
||||
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
"""Initialize Ollama AI Task entity."""
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
if self.client._supports_vision(self.runtime_options):
|
||||
self._attr_supported_features = (
|
||||
ai_task.AITaskEntityFeature.GENERATE_DATA |
|
||||
ai_task.AITaskEntityFeature.SUPPORT_ATTACHMENTS
|
||||
)
|
||||
else:
|
||||
self._attr_supported_features = ai_task.AITaskEntityFeature.GENERATE_DATA
|
||||
|
||||
async def _async_generate_data(
|
||||
self,
|
||||
task: ai_task.GenDataTask,
|
||||
chat_log: conversation.ChatLog,
|
||||
) -> ai_task.GenDataTaskResult:
|
||||
"""Handle a generate data task."""
|
||||
|
||||
extraction_method = ResultExtractionMethod.NONE
|
||||
|
||||
try:
|
||||
raw_prompt = self.runtime_options.get(CONF_PROMPT, DEFAULT_PROMPT)
|
||||
|
||||
message_history = chat_log.content[:]
|
||||
|
||||
if not isinstance(message_history[0], conversation.SystemContent):
|
||||
system_prompt = conversation.SystemContent(content=self.client._generate_system_prompt(raw_prompt, None, self.runtime_options))
|
||||
message_history.insert(0, system_prompt)
|
||||
|
||||
_LOGGER.debug(f"Generating response for {task.name=}...")
|
||||
generation_result = await self.client._async_generate(message_history, self.entity_id, chat_log, self.runtime_options)
|
||||
|
||||
assistant_message = await anext(generation_result)
|
||||
if not isinstance(assistant_message, conversation.AssistantContent):
|
||||
raise HomeAssistantError("Last content in chat log is not an AssistantContent!")
|
||||
text = assistant_message.content
|
||||
|
||||
if not task.structure:
|
||||
return ai_task.GenDataTaskResult(
|
||||
conversation_id=chat_log.conversation_id,
|
||||
data=text,
|
||||
)
|
||||
|
||||
if extraction_method == ResultExtractionMethod.NONE:
|
||||
raise HomeAssistantError("Task structure provided but no extraction method was specified!")
|
||||
elif extraction_method == ResultExtractionMethod.STRUCTURED_OUTPUT:
|
||||
try:
|
||||
data = json_loads(text)
|
||||
except JSONDecodeError as err:
|
||||
_LOGGER.error(
|
||||
"Failed to parse JSON response: %s. Response: %s",
|
||||
err,
|
||||
text,
|
||||
)
|
||||
raise HomeAssistantError("Error with Local LLM structured response") from err
|
||||
elif extraction_method == ResultExtractionMethod.TOOL:
|
||||
try:
|
||||
data = assistant_message.tool_calls[0].tool_args
|
||||
except (IndexError, AttributeError) as err:
|
||||
_LOGGER.error(
|
||||
"Failed to extract tool arguments from response: %s. Response: %s",
|
||||
err,
|
||||
text,
|
||||
)
|
||||
raise HomeAssistantError("Error with Local LLM tool response") from err
|
||||
else:
|
||||
raise ValueError() # should not happen
|
||||
|
||||
return ai_task.GenDataTaskResult(
|
||||
conversation_id=chat_log.conversation_id,
|
||||
data=data,
|
||||
)
|
||||
except Exception as err:
|
||||
_LOGGER.exception("Unhandled exception while running AI Task '%s'", task.name)
|
||||
raise HomeAssistantError(f"Unhandled error while running AI Task '{task.name}'") from err
|
||||
@@ -119,7 +119,7 @@ class GenericOpenAIAPIClient(LocalLLMClient):
|
||||
def _generate_stream(self,
|
||||
conversation: List[conversation.Content],
|
||||
llm_api: llm.APIInstance | None,
|
||||
user_input: conversation.ConversationInput,
|
||||
agent_id: str,
|
||||
entity_options: dict[str, Any]) -> AsyncGenerator[TextGenerationResult, None]:
|
||||
model_name = entity_options[CONF_CHAT_MODEL]
|
||||
temperature = entity_options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE)
|
||||
@@ -128,7 +128,7 @@ class GenericOpenAIAPIClient(LocalLLMClient):
|
||||
enable_legacy_tool_calling = entity_options.get(CONF_ENABLE_LEGACY_TOOL_CALLING, DEFAULT_ENABLE_LEGACY_TOOL_CALLING)
|
||||
|
||||
endpoint, additional_params = self._chat_completion_params(entity_options)
|
||||
messages = get_oai_formatted_messages(conversation)
|
||||
messages = get_oai_formatted_messages(conversation, user_content_as_list=True)
|
||||
|
||||
request_params = {
|
||||
"model": model_name,
|
||||
@@ -175,7 +175,7 @@ class GenericOpenAIAPIClient(LocalLLMClient):
|
||||
break
|
||||
|
||||
if chunk and chunk.strip():
|
||||
to_say, tool_calls = self._extract_response(json.loads(chunk), llm_api, user_input)
|
||||
to_say, tool_calls = self._extract_response(json.loads(chunk), llm_api, agent_id)
|
||||
if to_say or tool_calls:
|
||||
yield to_say, tool_calls
|
||||
except asyncio.TimeoutError as err:
|
||||
@@ -183,14 +183,14 @@ class GenericOpenAIAPIClient(LocalLLMClient):
|
||||
except aiohttp.ClientError as err:
|
||||
raise HomeAssistantError(f"Failed to communicate with the API! {err}") from err
|
||||
|
||||
return self._async_parse_completion(llm_api, user_input, entity_options, anext_token=anext_token())
|
||||
return self._async_parse_completion(llm_api, agent_id, entity_options, anext_token=anext_token())
|
||||
|
||||
def _chat_completion_params(self, entity_options: dict[str, Any]) -> Tuple[str, Dict[str, Any]]:
|
||||
request_params = {}
|
||||
endpoint = "/chat/completions"
|
||||
return endpoint, request_params
|
||||
|
||||
def _extract_response(self, response_json: dict, llm_api: llm.APIInstance | None, user_input: conversation.ConversationInput) -> Tuple[Optional[str], Optional[List[llm.ToolInput]]]:
|
||||
def _extract_response(self, response_json: dict, llm_api: llm.APIInstance | None, agent_id: str) -> Tuple[Optional[str], Optional[List[llm.ToolInput]]]:
|
||||
if "choices" not in response_json or len(response_json["choices"]) == 0: # finished
|
||||
_LOGGER.warning("Response missing or empty 'choices'. Keys present: %s. Full response: %s",
|
||||
list(response_json.keys()), response_json)
|
||||
@@ -203,11 +203,11 @@ class GenericOpenAIAPIClient(LocalLLMClient):
|
||||
streamed = False
|
||||
elif response_json["object"] == "chat.completion.chunk":
|
||||
response_text = choice["delta"].get("content", "")
|
||||
if "tool_calls" in choice["delta"]:
|
||||
if "tool_calls" in choice["delta"] and choice["delta"]["tool_calls"] is not None:
|
||||
tool_calls = []
|
||||
for call in choice["delta"]["tool_calls"]:
|
||||
tool_call, to_say = parse_raw_tool_call(
|
||||
call["function"], llm_api, user_input)
|
||||
call["function"], llm_api, agent_id)
|
||||
|
||||
if tool_call:
|
||||
tool_calls.append(tool_call)
|
||||
@@ -366,7 +366,7 @@ class GenericOpenAIResponsesAPIClient(LocalLLMClient):
|
||||
async def _generate(self,
|
||||
conversation: List[conversation.Content],
|
||||
llm_api: llm.APIInstance | None,
|
||||
user_input: conversation.ConversationInput,
|
||||
agent_id: str,
|
||||
entity_options: dict[str, Any]) -> TextGenerationResult:
|
||||
"""Generate a response using the OpenAI-compatible Responses API (non-streaming endpoint wrapped as a single-chunk stream)."""
|
||||
|
||||
|
||||
@@ -419,7 +419,7 @@ class LlamaCppClient(LocalLLMClient):
|
||||
def _generate_stream(self,
|
||||
conversation: List[conversation.Content],
|
||||
llm_api: llm.APIInstance | None,
|
||||
user_input: conversation.ConversationInput,
|
||||
agent_id: str,
|
||||
entity_options: dict[str, Any],
|
||||
) -> AsyncGenerator[TextGenerationResult, None]:
|
||||
"""Async generator that yields TextGenerationResult as tokens are produced."""
|
||||
@@ -434,7 +434,7 @@ class LlamaCppClient(LocalLLMClient):
|
||||
|
||||
_LOGGER.debug(f"Options: {entity_options}")
|
||||
|
||||
messages = get_oai_formatted_messages(conversation)
|
||||
messages = get_oai_formatted_messages(conversation, user_content_as_list=True)
|
||||
tools = None
|
||||
if llm_api:
|
||||
tools = get_oai_formatted_tools(llm_api, self._async_get_all_exposed_domains())
|
||||
@@ -464,5 +464,5 @@ class LlamaCppClient(LocalLLMClient):
|
||||
tool_calls = chunk["choices"][0]["delta"].get("tool_calls")
|
||||
yield content, tool_calls
|
||||
|
||||
return self._async_parse_completion(llm_api, user_input, entity_options, next_token=next_token())
|
||||
return self._async_parse_completion(llm_api, agent_id, entity_options, next_token=next_token())
|
||||
|
||||
|
||||
@@ -135,7 +135,7 @@ class OllamaAPIClient(LocalLLMClient):
|
||||
|
||||
return response, tool_calls
|
||||
|
||||
def _generate_stream(self, conversation: List[conversation.Content], llm_api: llm.APIInstance | None, user_input: conversation.ConversationInput, entity_options: Dict[str, Any]) -> AsyncGenerator[TextGenerationResult, None]:
|
||||
def _generate_stream(self, conversation: List[conversation.Content], llm_api: llm.APIInstance | None, agent_id: str, entity_options: Dict[str, Any]) -> AsyncGenerator[TextGenerationResult, None]:
|
||||
model_name = entity_options.get(CONF_CHAT_MODEL, "")
|
||||
context_length = entity_options.get(CONF_CONTEXT_LENGTH, DEFAULT_CONTEXT_LENGTH)
|
||||
max_tokens = entity_options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS)
|
||||
@@ -199,4 +199,4 @@ class OllamaAPIClient(LocalLLMClient):
|
||||
except aiohttp.ClientError as err:
|
||||
raise HomeAssistantError(f"Failed to communicate with the API! {err}") from err
|
||||
|
||||
return self._async_parse_completion(llm_api, user_input, entity_options, anext_token=anext_token())
|
||||
return self._async_parse_completion(llm_api, agent_id, entity_options, anext_token=anext_token())
|
||||
|
||||
@@ -66,7 +66,7 @@ class TextGenerationWebuiClient(GenericOpenAIAPIClient):
|
||||
headers["Authorization"] = f"Bearer {self.admin_key}"
|
||||
|
||||
async with session.get(
|
||||
f"{self.api_host}/v1/internal/model/info",
|
||||
f"{self.api_host}/internal/model/info",
|
||||
headers=headers
|
||||
) as response:
|
||||
response.raise_for_status()
|
||||
@@ -80,7 +80,7 @@ class TextGenerationWebuiClient(GenericOpenAIAPIClient):
|
||||
_LOGGER.info(f"Model is not {model_name} loaded on the remote backend. Loading it now...")
|
||||
|
||||
async with session.post(
|
||||
f"{self.api_host}/v1/internal/model/load",
|
||||
f"{self.api_host}/internal/model/load",
|
||||
json={
|
||||
"model_name": model_name,
|
||||
},
|
||||
|
||||
@@ -400,6 +400,7 @@ class ConfigFlow(BaseConfigFlow, domain=DOMAIN):
|
||||
"""Return subentries supported by this integration."""
|
||||
return {
|
||||
"conversation": LocalLLMSubentryFlowHandler,
|
||||
# "ai_task_data": LocalLLMSubentryFlowHandler,
|
||||
}
|
||||
|
||||
|
||||
@@ -836,7 +837,7 @@ def local_llama_config_option_schema(
|
||||
})
|
||||
elif backend_type == BACKEND_TYPE_OLLAMA:
|
||||
result.update({
|
||||
vol.Required(
|
||||
vol.Required(
|
||||
CONF_MAX_TOKENS,
|
||||
description={"suggested_value": options.get(CONF_MAX_TOKENS)},
|
||||
default=DEFAULT_MAX_TOKENS,
|
||||
@@ -924,6 +925,8 @@ def local_llama_config_option_schema(
|
||||
default=DEFAULT_MAX_TOOL_CALL_ITERATIONS,
|
||||
): int,
|
||||
})
|
||||
elif subentry_type == "ai_task_data":
|
||||
pass # no additional options for ai_task_data for now
|
||||
|
||||
# sort the options
|
||||
global_order = [
|
||||
|
||||
@@ -337,5 +337,5 @@ def option_overrides(backend_type: str) -> dict[str, Any]:
|
||||
},
|
||||
}
|
||||
|
||||
INTEGRATION_VERSION = "0.4.3"
|
||||
INTEGRATION_VERSION = "0.4.4"
|
||||
EMBEDDED_LLAMA_CPP_PYTHON_VERSION = "0.3.16+b6153"
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"""Defines the various LLM Backend Agents"""
|
||||
from __future__ import annotations
|
||||
from typing import Literal
|
||||
from typing import Literal, List, Tuple, Any
|
||||
import logging
|
||||
|
||||
from homeassistant.components.conversation import ConversationInput, ConversationResult, ConversationEntity
|
||||
@@ -9,11 +9,24 @@ from homeassistant.components import conversation
|
||||
from homeassistant.config_entries import ConfigEntry, ConfigSubentry
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL
|
||||
from homeassistant.helpers import chat_session
|
||||
from homeassistant.exceptions import TemplateError, HomeAssistantError
|
||||
from homeassistant.helpers import chat_session, intent, llm
|
||||
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
|
||||
|
||||
from custom_components.llama_conversation.utils import MalformedToolCallException
|
||||
|
||||
from .entity import LocalLLMEntity, LocalLLMClient, LocalLLMConfigEntry
|
||||
from .const import (
|
||||
CONF_PROMPT,
|
||||
CONF_REFRESH_SYSTEM_PROMPT,
|
||||
CONF_REMEMBER_CONVERSATION,
|
||||
CONF_REMEMBER_NUM_INTERACTIONS,
|
||||
CONF_MAX_TOOL_CALL_ITERATIONS,
|
||||
DEFAULT_PROMPT,
|
||||
DEFAULT_REFRESH_SYSTEM_PROMPT,
|
||||
DEFAULT_REMEMBER_CONVERSATION,
|
||||
DEFAULT_REMEMBER_NUM_INTERACTIONS,
|
||||
DEFAULT_MAX_TOOL_CALL_ITERATIONS,
|
||||
DOMAIN,
|
||||
)
|
||||
|
||||
@@ -73,4 +86,135 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent, LocalLLMEntit
|
||||
) as session,
|
||||
conversation.async_get_chat_log(self.hass, session, user_input) as chat_log,
|
||||
):
|
||||
return await self.client._async_handle_message(user_input, chat_log, self.runtime_options)
|
||||
raw_prompt = self.runtime_options.get(CONF_PROMPT, DEFAULT_PROMPT)
|
||||
refresh_system_prompt = self.runtime_options.get(CONF_REFRESH_SYSTEM_PROMPT, DEFAULT_REFRESH_SYSTEM_PROMPT)
|
||||
remember_conversation = self.runtime_options.get(CONF_REMEMBER_CONVERSATION, DEFAULT_REMEMBER_CONVERSATION)
|
||||
remember_num_interactions = self.runtime_options.get(CONF_REMEMBER_NUM_INTERACTIONS, DEFAULT_REMEMBER_NUM_INTERACTIONS)
|
||||
max_tool_call_iterations = self.runtime_options.get(CONF_MAX_TOOL_CALL_ITERATIONS, DEFAULT_MAX_TOOL_CALL_ITERATIONS)
|
||||
llm_api: llm.APIInstance | None = None
|
||||
if self.runtime_options.get(CONF_LLM_HASS_API):
|
||||
try:
|
||||
llm_api = await llm.async_get_api(
|
||||
self.hass,
|
||||
self.runtime_options[CONF_LLM_HASS_API],
|
||||
llm_context=user_input.as_llm_context(DOMAIN)
|
||||
)
|
||||
except HomeAssistantError as err:
|
||||
_LOGGER.error("Error getting LLM API: %s", err)
|
||||
intent_response = intent.IntentResponse(language=user_input.language)
|
||||
intent_response.async_set_error(
|
||||
intent.IntentResponseErrorCode.UNKNOWN,
|
||||
f"Error preparing LLM API: {err}",
|
||||
)
|
||||
return ConversationResult(
|
||||
response=intent_response, conversation_id=user_input.conversation_id
|
||||
)
|
||||
|
||||
# ensure this chat log has the LLM API instance
|
||||
chat_log.llm_api = llm_api
|
||||
|
||||
if remember_conversation:
|
||||
message_history = chat_log.content[:]
|
||||
else:
|
||||
message_history = []
|
||||
|
||||
# trim message history before processing if necessary
|
||||
if remember_num_interactions and len(message_history) > (remember_num_interactions * 2) + 1:
|
||||
new_message_history = [message_history[0]] # copy system prompt
|
||||
new_message_history.extend(message_history[1:][-(remember_num_interactions * 2):])
|
||||
|
||||
# re-generate prompt if necessary
|
||||
if len(message_history) == 0 or refresh_system_prompt:
|
||||
try:
|
||||
system_prompt = conversation.SystemContent(content=self.client._generate_system_prompt(raw_prompt, llm_api, self.runtime_options))
|
||||
except TemplateError as err:
|
||||
_LOGGER.error("Error rendering prompt: %s", err)
|
||||
intent_response = intent.IntentResponse(language=user_input.language)
|
||||
intent_response.async_set_error(
|
||||
intent.IntentResponseErrorCode.UNKNOWN,
|
||||
f"Sorry, I had a problem with my template: {err}",
|
||||
)
|
||||
return ConversationResult(
|
||||
response=intent_response, conversation_id=user_input.conversation_id
|
||||
)
|
||||
|
||||
if len(message_history) == 0:
|
||||
message_history.append(system_prompt)
|
||||
else:
|
||||
message_history[0] = system_prompt
|
||||
|
||||
tool_calls: List[Tuple[llm.ToolInput, Any]] = []
|
||||
# if max tool calls is 0 then we expect to generate the response & tool call in one go
|
||||
for idx in range(max(1, max_tool_call_iterations)):
|
||||
_LOGGER.debug(f"Generating response for {user_input.text=}, iteration {idx+1}/{max_tool_call_iterations}")
|
||||
generation_result = await self.client._async_generate(message_history, user_input.agent_id, chat_log, self.runtime_options)
|
||||
|
||||
last_generation_had_tool_calls = False
|
||||
while True:
|
||||
try:
|
||||
message = await anext(generation_result)
|
||||
message_history.append(message)
|
||||
_LOGGER.debug("Added message to history: %s", message)
|
||||
if message.role == "assistant":
|
||||
if message.tool_calls and len(message.tool_calls) > 0:
|
||||
last_generation_had_tool_calls = True
|
||||
else:
|
||||
last_generation_had_tool_calls = False
|
||||
except StopAsyncIteration:
|
||||
break
|
||||
except MalformedToolCallException as err:
|
||||
message_history.extend(err.as_tool_messages())
|
||||
last_generation_had_tool_calls = True
|
||||
_LOGGER.debug("Malformed tool call produced", exc_info=err)
|
||||
except Exception as err:
|
||||
_LOGGER.exception("There was a problem talking to the backend")
|
||||
intent_response = intent.IntentResponse(language=user_input.language)
|
||||
intent_response.async_set_error(
|
||||
intent.IntentResponseErrorCode.FAILED_TO_HANDLE,
|
||||
f"Sorry, there was a problem talking to the backend: {repr(err)}",
|
||||
)
|
||||
return ConversationResult(
|
||||
response=intent_response, conversation_id=user_input.conversation_id
|
||||
)
|
||||
|
||||
# If not multi-turn, break after first tool call
|
||||
# also break if no tool calls were made
|
||||
if not last_generation_had_tool_calls:
|
||||
break
|
||||
|
||||
# return an error if we run out of attempt without succeeding
|
||||
if idx == max_tool_call_iterations - 1 and max_tool_call_iterations > 0:
|
||||
intent_response = intent.IntentResponse(language=user_input.language)
|
||||
intent_response.async_set_error(
|
||||
intent.IntentResponseErrorCode.FAILED_TO_HANDLE,
|
||||
f"Sorry, I ran out of attempts to handle your request",
|
||||
)
|
||||
return ConversationResult(
|
||||
response=intent_response, conversation_id=user_input.conversation_id
|
||||
)
|
||||
|
||||
# generate intent response to Home Assistant
|
||||
intent_response = intent.IntentResponse(language=user_input.language)
|
||||
if len(tool_calls) > 0:
|
||||
str_tools = [f"{input.tool_name}({', '.join(str(x) for x in input.tool_args.values())})" for input, response in tool_calls]
|
||||
tools_str = '\n'.join(str_tools)
|
||||
intent_response.async_set_card(
|
||||
title="Changes",
|
||||
content=f"Ran the following tools:\n{tools_str}"
|
||||
)
|
||||
|
||||
has_speech = False
|
||||
for i in range(1, len(message_history)):
|
||||
cur_msg = message_history[-1 * i]
|
||||
if isinstance(cur_msg, conversation.AssistantContent) and cur_msg.content:
|
||||
intent_response.async_set_speech(cur_msg.content)
|
||||
has_speech = True
|
||||
break
|
||||
|
||||
if not has_speech:
|
||||
intent_response.async_set_speech("I don't have anything to say right now")
|
||||
_LOGGER.debug(message_history)
|
||||
|
||||
return ConversationResult(
|
||||
response=intent_response, conversation_id=user_input.conversation_id
|
||||
)
|
||||
|
||||
@@ -8,45 +8,33 @@ import random
|
||||
from typing import Literal, Any, List, Dict, Optional, Tuple, AsyncIterator, Generator, AsyncGenerator
|
||||
from dataclasses import dataclass
|
||||
|
||||
from homeassistant.components.conversation import ConversationInput, ConversationResult
|
||||
from homeassistant.components import conversation
|
||||
from homeassistant.components.conversation.const import DOMAIN as CONVERSATION_DOMAIN
|
||||
from homeassistant.components.homeassistant.exposed_entities import async_should_expose
|
||||
from homeassistant.config_entries import ConfigEntry, ConfigSubentry
|
||||
from homeassistant.const import MATCH_ALL, CONF_LLM_HASS_API
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.exceptions import TemplateError, HomeAssistantError
|
||||
from homeassistant.helpers import intent, template, entity_registry as er, llm, \
|
||||
from homeassistant.helpers import template, entity_registry as er, llm, \
|
||||
area_registry as ar, device_registry as dr, entity
|
||||
from homeassistant.util import color
|
||||
|
||||
from .utils import closest_color, parse_raw_tool_call, flatten_vol_schema, MalformedToolCallException
|
||||
from .utils import closest_color, parse_raw_tool_call, flatten_vol_schema
|
||||
from .const import (
|
||||
CONF_CHAT_MODEL,
|
||||
CONF_SELECTED_LANGUAGE,
|
||||
CONF_PROMPT,
|
||||
CONF_EXTRA_ATTRIBUTES_TO_EXPOSE,
|
||||
CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES,
|
||||
CONF_IN_CONTEXT_EXAMPLES_FILE,
|
||||
CONF_NUM_IN_CONTEXT_EXAMPLES,
|
||||
CONF_REFRESH_SYSTEM_PROMPT,
|
||||
CONF_REMEMBER_CONVERSATION,
|
||||
CONF_REMEMBER_NUM_INTERACTIONS,
|
||||
CONF_MAX_TOOL_CALL_ITERATIONS,
|
||||
CONF_THINKING_PREFIX,
|
||||
CONF_THINKING_SUFFIX,
|
||||
CONF_TOOL_CALL_PREFIX,
|
||||
CONF_TOOL_CALL_SUFFIX,
|
||||
CONF_ENABLE_LEGACY_TOOL_CALLING,
|
||||
DEFAULT_PROMPT,
|
||||
DEFAULT_EXTRA_ATTRIBUTES_TO_EXPOSE,
|
||||
DEFAULT_USE_IN_CONTEXT_LEARNING_EXAMPLES,
|
||||
DEFAULT_IN_CONTEXT_EXAMPLES_FILE,
|
||||
DEFAULT_NUM_IN_CONTEXT_EXAMPLES,
|
||||
DEFAULT_REFRESH_SYSTEM_PROMPT,
|
||||
DEFAULT_REMEMBER_CONVERSATION,
|
||||
DEFAULT_REMEMBER_NUM_INTERACTIONS,
|
||||
DEFAULT_MAX_TOOL_CALL_ITERATIONS,
|
||||
DOMAIN,
|
||||
DEFAULT_THINKING_PREFIX,
|
||||
DEFAULT_THINKING_SUFFIX,
|
||||
@@ -145,27 +133,31 @@ class LocalLLMClient:
|
||||
await self.hass.async_add_executor_job(
|
||||
self._unload_model, entity_options
|
||||
)
|
||||
|
||||
def _supports_vision(self, entity_options: dict[str, Any]) -> bool:
|
||||
"""Determine if the backend supports vision inputs. Implemented by sub-classes"""
|
||||
return False
|
||||
|
||||
def _generate_stream(self, conversation: List[conversation.Content], llm_api: llm.APIInstance | None, user_input: conversation.ConversationInput, entity_options: dict[str, Any]) -> AsyncGenerator[TextGenerationResult, None]:
|
||||
def _generate_stream(self, conversation: List[conversation.Content], llm_api: llm.APIInstance | None, agent_id: str, entity_options: dict[str, Any]) -> AsyncGenerator[TextGenerationResult, None]:
|
||||
"""Async generator for streaming responses. Subclasses should implement."""
|
||||
raise NotImplementedError()
|
||||
|
||||
async def _generate(self, conversation: List[conversation.Content], llm_api: llm.APIInstance | None, user_input: conversation.ConversationInput, entity_options: dict[str, Any]) -> TextGenerationResult:
|
||||
async def _generate(self, conversation: List[conversation.Content], llm_api: llm.APIInstance | None, agent_id: str, entity_options: dict[str, Any]) -> TextGenerationResult:
|
||||
"""Call the backend to generate a response from the conversation. Implemented by sub-classes"""
|
||||
raise NotImplementedError()
|
||||
|
||||
async def _async_generate(self, conv: List[conversation.Content], user_input: ConversationInput, chat_log: conversation.chat_log.ChatLog, entity_options: dict[str, Any]):
|
||||
async def _async_generate(self, conv: List[conversation.Content], agent_id: str, chat_log: conversation.chat_log.ChatLog, entity_options: dict[str, Any]):
|
||||
"""Default implementation: if streaming is supported, consume the async generator and return the full result."""
|
||||
if hasattr(self, '_generate_stream'):
|
||||
# Try to stream and collect the full response
|
||||
return await self._transform_result_stream(self._generate_stream(conv, chat_log.llm_api, user_input, entity_options), user_input, chat_log)
|
||||
return await self._transform_result_stream(self._generate_stream(conv, chat_log.llm_api, agent_id, entity_options), agent_id, chat_log)
|
||||
|
||||
# Fallback to "blocking" generate
|
||||
blocking_result = await self._generate(conv, chat_log.llm_api, user_input, entity_options)
|
||||
blocking_result = await self._generate(conv, chat_log.llm_api, agent_id, entity_options)
|
||||
|
||||
return chat_log.async_add_assistant_content(
|
||||
conversation.AssistantContent(
|
||||
agent_id=user_input.agent_id,
|
||||
agent_id=agent_id,
|
||||
content=blocking_result.response,
|
||||
tool_calls=blocking_result.tool_calls
|
||||
)
|
||||
@@ -184,7 +176,7 @@ class LocalLLMClient:
|
||||
async def _transform_result_stream(
|
||||
self,
|
||||
result: AsyncIterator[TextGenerationResult],
|
||||
user_input: ConversationInput,
|
||||
agent_id: str,
|
||||
chat_log: conversation.chat_log.ChatLog
|
||||
):
|
||||
async def async_iterator():
|
||||
@@ -205,152 +197,11 @@ class LocalLLMClient:
|
||||
tool_calls=tool_calls
|
||||
)
|
||||
|
||||
return chat_log.async_add_delta_content_stream(user_input.agent_id, stream=async_iterator())
|
||||
|
||||
async def _async_handle_message(
|
||||
self,
|
||||
user_input: conversation.ConversationInput,
|
||||
chat_log: conversation.ChatLog,
|
||||
entity_options: Dict[str, Any],
|
||||
) -> conversation.ConversationResult:
|
||||
|
||||
raw_prompt = entity_options.get(CONF_PROMPT, DEFAULT_PROMPT)
|
||||
refresh_system_prompt = entity_options.get(CONF_REFRESH_SYSTEM_PROMPT, DEFAULT_REFRESH_SYSTEM_PROMPT)
|
||||
remember_conversation = entity_options.get(CONF_REMEMBER_CONVERSATION, DEFAULT_REMEMBER_CONVERSATION)
|
||||
remember_num_interactions = entity_options.get(CONF_REMEMBER_NUM_INTERACTIONS, DEFAULT_REMEMBER_NUM_INTERACTIONS)
|
||||
max_tool_call_iterations = entity_options.get(CONF_MAX_TOOL_CALL_ITERATIONS, DEFAULT_MAX_TOOL_CALL_ITERATIONS)
|
||||
|
||||
llm_api: llm.APIInstance | None = None
|
||||
if entity_options.get(CONF_LLM_HASS_API):
|
||||
try:
|
||||
llm_api = await llm.async_get_api(
|
||||
self.hass,
|
||||
entity_options[CONF_LLM_HASS_API],
|
||||
llm_context=user_input.as_llm_context(DOMAIN)
|
||||
)
|
||||
except HomeAssistantError as err:
|
||||
_LOGGER.error("Error getting LLM API: %s", err)
|
||||
intent_response = intent.IntentResponse(language=user_input.language)
|
||||
intent_response.async_set_error(
|
||||
intent.IntentResponseErrorCode.UNKNOWN,
|
||||
f"Error preparing LLM API: {err}",
|
||||
)
|
||||
return ConversationResult(
|
||||
response=intent_response, conversation_id=user_input.conversation_id
|
||||
)
|
||||
|
||||
# ensure this chat log has the LLM API instance
|
||||
chat_log.llm_api = llm_api
|
||||
|
||||
if remember_conversation:
|
||||
message_history = chat_log.content[:]
|
||||
else:
|
||||
message_history = []
|
||||
|
||||
# trim message history before processing if necessary
|
||||
if remember_num_interactions and len(message_history) > (remember_num_interactions * 2) + 1:
|
||||
new_message_history = [message_history[0]] # copy system prompt
|
||||
new_message_history.extend(message_history[1:][-(remember_num_interactions * 2):])
|
||||
|
||||
# re-generate prompt if necessary
|
||||
if len(message_history) == 0 or refresh_system_prompt:
|
||||
try:
|
||||
system_prompt = conversation.SystemContent(content=self._generate_system_prompt(raw_prompt, llm_api, entity_options))
|
||||
except TemplateError as err:
|
||||
_LOGGER.error("Error rendering prompt: %s", err)
|
||||
intent_response = intent.IntentResponse(language=user_input.language)
|
||||
intent_response.async_set_error(
|
||||
intent.IntentResponseErrorCode.UNKNOWN,
|
||||
f"Sorry, I had a problem with my template: {err}",
|
||||
)
|
||||
return ConversationResult(
|
||||
response=intent_response, conversation_id=user_input.conversation_id
|
||||
)
|
||||
|
||||
if len(message_history) == 0:
|
||||
message_history.append(system_prompt)
|
||||
else:
|
||||
message_history[0] = system_prompt
|
||||
|
||||
tool_calls: List[Tuple[llm.ToolInput, Any]] = []
|
||||
# if max tool calls is 0 then we expect to generate the response & tool call in one go
|
||||
for idx in range(max(1, max_tool_call_iterations)):
|
||||
_LOGGER.debug(f"Generating response for {user_input.text=}, iteration {idx+1}/{max_tool_call_iterations}")
|
||||
generation_result = await self._async_generate(message_history, user_input, chat_log, entity_options)
|
||||
|
||||
last_generation_had_tool_calls = False
|
||||
while True:
|
||||
try:
|
||||
message = await anext(generation_result)
|
||||
message_history.append(message)
|
||||
_LOGGER.debug("Added message to history: %s", message)
|
||||
if message.role == "assistant":
|
||||
if message.tool_calls and len(message.tool_calls) > 0:
|
||||
last_generation_had_tool_calls = True
|
||||
else:
|
||||
last_generation_had_tool_calls = False
|
||||
except StopAsyncIteration:
|
||||
break
|
||||
except MalformedToolCallException as err:
|
||||
message_history.extend(err.as_tool_messages())
|
||||
last_generation_had_tool_calls = True
|
||||
_LOGGER.debug("Malformed tool call produced", exc_info=err)
|
||||
except Exception as err:
|
||||
_LOGGER.exception("There was a problem talking to the backend")
|
||||
intent_response = intent.IntentResponse(language=user_input.language)
|
||||
intent_response.async_set_error(
|
||||
intent.IntentResponseErrorCode.FAILED_TO_HANDLE,
|
||||
f"Sorry, there was a problem talking to the backend: {repr(err)}",
|
||||
)
|
||||
return ConversationResult(
|
||||
response=intent_response, conversation_id=user_input.conversation_id
|
||||
)
|
||||
|
||||
# If not multi-turn, break after first tool call
|
||||
# also break if no tool calls were made
|
||||
if not last_generation_had_tool_calls:
|
||||
break
|
||||
|
||||
# return an error if we run out of attempt without succeeding
|
||||
if idx == max_tool_call_iterations - 1 and max_tool_call_iterations > 0:
|
||||
intent_response = intent.IntentResponse(language=user_input.language)
|
||||
intent_response.async_set_error(
|
||||
intent.IntentResponseErrorCode.FAILED_TO_HANDLE,
|
||||
f"Sorry, I ran out of attempts to handle your request",
|
||||
)
|
||||
return ConversationResult(
|
||||
response=intent_response, conversation_id=user_input.conversation_id
|
||||
)
|
||||
|
||||
# generate intent response to Home Assistant
|
||||
intent_response = intent.IntentResponse(language=user_input.language)
|
||||
if len(tool_calls) > 0:
|
||||
str_tools = [f"{input.tool_name}({', '.join(str(x) for x in input.tool_args.values())})" for input, response in tool_calls]
|
||||
tools_str = '\n'.join(str_tools)
|
||||
intent_response.async_set_card(
|
||||
title="Changes",
|
||||
content=f"Ran the following tools:\n{tools_str}"
|
||||
)
|
||||
|
||||
has_speech = False
|
||||
for i in range(1, len(message_history)):
|
||||
cur_msg = message_history[-1 * i]
|
||||
if isinstance(cur_msg, conversation.AssistantContent) and cur_msg.content:
|
||||
intent_response.async_set_speech(cur_msg.content)
|
||||
has_speech = True
|
||||
break
|
||||
|
||||
if not has_speech:
|
||||
intent_response.async_set_speech("I don't have anything to say right now")
|
||||
_LOGGER.debug(message_history)
|
||||
|
||||
return ConversationResult(
|
||||
response=intent_response, conversation_id=user_input.conversation_id
|
||||
)
|
||||
|
||||
return chat_log.async_add_delta_content_stream(agent_id, stream=async_iterator())
|
||||
|
||||
async def _async_parse_completion(
|
||||
self, llm_api: llm.APIInstance | None,
|
||||
user_input: ConversationInput,
|
||||
agent_id: str,
|
||||
entity_options: Dict[str, Any],
|
||||
next_token: Optional[Generator[Tuple[Optional[str], Optional[List]]]] = None,
|
||||
anext_token: Optional[AsyncGenerator[Tuple[Optional[str], Optional[List]]]] = None,
|
||||
@@ -442,9 +293,9 @@ class LocalLLMClient:
|
||||
parsed_tool_calls.append(raw_tool_call)
|
||||
else:
|
||||
if isinstance(raw_tool_call, str):
|
||||
tool_call, to_say = parse_raw_tool_call(raw_tool_call, llm_api, user_input)
|
||||
tool_call, to_say = parse_raw_tool_call(raw_tool_call, llm_api, agent_id)
|
||||
else:
|
||||
tool_call, to_say = parse_raw_tool_call(raw_tool_call["function"], llm_api, user_input)
|
||||
tool_call, to_say = parse_raw_tool_call(raw_tool_call["function"], llm_api, agent_id)
|
||||
|
||||
if tool_call:
|
||||
_LOGGER.debug("Tool call parsed: %s", tool_call)
|
||||
@@ -667,6 +518,10 @@ class LocalLLMClient:
|
||||
message = "No tools were provided. If the user requests you interact with a device, tell them you are unable to do so."
|
||||
render_variables["tools"] = [message]
|
||||
render_variables["formatted_tools"] = message
|
||||
else:
|
||||
# Tools are passed via the API not the prompt
|
||||
render_variables["tools"] = []
|
||||
render_variables["formatted_tools"] = ""
|
||||
|
||||
# only pass examples if there are loaded examples + an API was exposed
|
||||
if self.in_context_examples and llm_api:
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
{
|
||||
"domain": "llama_conversation",
|
||||
"name": "Local LLMs",
|
||||
"version": "0.4.3",
|
||||
"version": "0.4.4",
|
||||
"codeowners": ["@acon96"],
|
||||
"config_flow": true,
|
||||
"dependencies": ["conversation"],
|
||||
|
||||
@@ -35,10 +35,10 @@
|
||||
"config_subentries": {
|
||||
"conversation": {
|
||||
"initiate_flow": {
|
||||
"user": "Add conversation agent",
|
||||
"reconfigure": "Reconfigure conversation agent"
|
||||
"user": "Add Conversation Agent",
|
||||
"reconfigure": "Reconfigure Conversation Agent"
|
||||
},
|
||||
"entry_type": "Conversation agent",
|
||||
"entry_type": "Conversation Agent",
|
||||
"error": {
|
||||
"download_failed": "The download failed to complete: {exception}",
|
||||
"missing_quantization": "The GGUF quantization level {missing} does not exist in the provided HuggingFace repo. The following quantization levels were found: {available}",
|
||||
@@ -176,6 +176,145 @@
|
||||
"title": "Configure the selected model"
|
||||
}
|
||||
}
|
||||
},
|
||||
"ai_task_data": {
|
||||
"initiate_flow": {
|
||||
"user": "Add AI Task Handler",
|
||||
"reconfigure": "Reconfigure AI Task Handler"
|
||||
},
|
||||
"entry_type": "AI Task Handler",
|
||||
"error": {
|
||||
"download_failed": "The download failed to complete: {exception}",
|
||||
"missing_quantization": "The GGUF quantization level {missing} does not exist in the provided HuggingFace repo. The following quantization levels were found: {available}",
|
||||
"no_supported_ggufs": "The provided HuggingFace repo does not contain any compatible GGUF files!",
|
||||
"missing_model_api": "The selected model is not provided by this API. The available models have been populated in the dropdown.",
|
||||
"missing_model_file": "The provided file does not exist.",
|
||||
"other_existing_local": "Another model is already loaded locally. Please unload it or configure a remote model.",
|
||||
"unknown": "Unexpected error",
|
||||
"sys_refresh_caching_enabled": "System prompt refresh must be enabled for prompt caching to work!",
|
||||
"missing_gbnf_file": "The GBNF file was not found: {filename}",
|
||||
"missing_icl_file": "The in context learning example CSV file was not found: {filename}"
|
||||
},
|
||||
"progress": {
|
||||
"download": "Please wait while the model is being downloaded from HuggingFace. This can take a few minutes."
|
||||
},
|
||||
"abort": {
|
||||
"reconfigure_successful": "Successfully updated model options."
|
||||
},
|
||||
"step": {
|
||||
"pick_model": {
|
||||
"data": {
|
||||
"huggingface_model": "Model Name",
|
||||
"downloaded_model_file": "Local file name",
|
||||
"downloaded_model_quantization": "Downloaded model quantization"
|
||||
},
|
||||
"description": "Select a model to use. \n\n**Models supported out of the box:**\n1. [Home LLM](https://huggingface.co/collections/acon96/home-llm-6618762669211da33bb22c5a): Home 3B & Home 1B\n2. Mistral: [Mistral 7B](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2) or [Mixtral 8x7B](https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1)\n3. Llama 3: [8B](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) and [70B](https://huggingface.co/meta-llama/Meta-Llama-3-70B-Instruct)",
|
||||
"title": "Pick Model"
|
||||
},
|
||||
"model_parameters": {
|
||||
"data": {
|
||||
"max_new_tokens": "Maximum tokens to return in response",
|
||||
"prompt": "System Prompt",
|
||||
"temperature": "Temperature",
|
||||
"top_k": "Top K",
|
||||
"top_p": "Top P",
|
||||
"min_p": "Min P",
|
||||
"typical_p": "Typical P",
|
||||
"request_timeout": "Remote Request Timeout (seconds)",
|
||||
"ollama_keep_alive": "(ollama) Keep Alive/Inactivity Timeout (minutes)",
|
||||
"ollama_json_mode": "(ollama) JSON Output Mode",
|
||||
"extra_attributes_to_expose": "Additional attribute to expose in the context",
|
||||
"enable_flash_attention": "Enable Flash Attention",
|
||||
"gbnf_grammar": "Enable GBNF Grammar",
|
||||
"gbnf_grammar_file": "GBNF Grammar Filename",
|
||||
"openai_api_key": "API Key",
|
||||
"text_generation_webui_admin_key": "(text-generation-webui) Admin Key",
|
||||
"service_call_regex": "Service Call Regex",
|
||||
"in_context_examples": "Enable in context learning (ICL) examples",
|
||||
"in_context_examples_file": "In context learning examples CSV filename",
|
||||
"num_in_context_examples": "Number of ICL examples to generate",
|
||||
"text_generation_webui_preset": "(text-generation-webui) Generation Preset/Character Name",
|
||||
"text_generation_webui_chat_mode": "(text-generation-webui) Chat Mode",
|
||||
"prompt_caching": "Enable Prompt Caching",
|
||||
"prompt_caching_interval": "Prompt Caching fastest refresh interval (sec)",
|
||||
"context_length": "Context Length",
|
||||
"batch_size": "(llama.cpp) Batch Size",
|
||||
"n_threads": "(llama.cpp) Thread Count",
|
||||
"n_batch_threads": "(llama.cpp) Batch Thread Count",
|
||||
"thinking_prefix": "Reasoning Content Prefix",
|
||||
"thinking_suffix": "Reasoning Content Suffix",
|
||||
"tool_call_prefix": "Tool Call Prefix",
|
||||
"tool_call_suffix": "Tool Call Suffix",
|
||||
"enable_legacy_tool_calling": "Enable Legacy Tool Calling",
|
||||
"max_tool_call_iterations": "Maximum Tool Call Attempts"
|
||||
},
|
||||
"data_description": {
|
||||
"prompt": "See [here](https://github.com/acon96/home-llm/blob/develop/docs/Model%20Prompting.md) for more information on model prompting.",
|
||||
"in_context_examples": "If you are using a model that is not specifically fine-tuned for use with this integration: enable this",
|
||||
"extra_attributes_to_expose": "This is the list of Home Assistant 'attributes' that are exposed to the model. This limits how much information the model is able to see and answer questions on.",
|
||||
"gbnf_grammar": "Forces the model to output properly formatted responses. Ensure the file specified below exists in the integration directory.",
|
||||
"prompt_caching": "Prompt caching attempts to pre-process the prompt (house state) and cache the processing that needs to be done to understand the prompt. Enabling this will cause the model to re-process the prompt any time an entity state changes in the house, restricted by the interval below.",
|
||||
"enable_legacy_tool_calling": "Prefer to process tool calls locally rather than relying on the backend to handle the tool calling format. Can be more reliable, however it requires properly setting the tool call prefix and suffix.",
|
||||
"max_tool_call_iterations": "Set to 0 to generate the response and tool call in one attempt, without looping (use this for Home models v1-v3)."
|
||||
},
|
||||
"description": "Please configure the model according to how it should be prompted. There are many different options and selecting the correct ones for your model is essential to getting optimal performance. See [here](https://github.com/acon96/home-llm/blob/develop/docs/Backend%20Configuration.md) for more information about the options on this page.\n\n**Some defaults may have been chosen for you based on the name of the selected model name or filename.** If you renamed a file or are using a fine-tuning of a supported model, then the defaults may not have been detected.",
|
||||
"title": "Configure the selected model"
|
||||
},
|
||||
"reconfigure": {
|
||||
"data": {
|
||||
"max_new_tokens": "Maximum tokens to return in response",
|
||||
"llm_hass_api": "Selected LLM API",
|
||||
"prompt": "System Prompt",
|
||||
"temperature": "Temperature",
|
||||
"top_k": "Top K",
|
||||
"top_p": "Top P",
|
||||
"min_p": "Min P",
|
||||
"typical_p": "Typical P",
|
||||
"request_timeout": "Remote Request Timeout (seconds)",
|
||||
"ollama_keep_alive": "(ollama) Keep Alive/Inactivity Timeout (minutes)",
|
||||
"ollama_json_mode": "(ollama) JSON Output Mode",
|
||||
"extra_attributes_to_expose": "Additional attribute to expose in the context",
|
||||
"enable_flash_attention": "Enable Flash Attention",
|
||||
"gbnf_grammar": "Enable GBNF Grammar",
|
||||
"gbnf_grammar_file": "GBNF Grammar Filename",
|
||||
"openai_api_key": "API Key",
|
||||
"text_generation_webui_admin_key": "(text-generation-webui) Admin Key",
|
||||
"service_call_regex": "Service Call Regex",
|
||||
"refresh_prompt_per_turn": "Refresh System Prompt Every Turn",
|
||||
"remember_conversation": "Remember conversation",
|
||||
"remember_num_interactions": "Number of past interactions to remember",
|
||||
"in_context_examples": "Enable in context learning (ICL) examples",
|
||||
"in_context_examples_file": "In context learning examples CSV filename",
|
||||
"num_in_context_examples": "Number of ICL examples to generate",
|
||||
"text_generation_webui_preset": "(text-generation-webui) Generation Preset/Character Name",
|
||||
"text_generation_webui_chat_mode": "(text-generation-webui) Chat Mode",
|
||||
"prompt_caching": "Enable Prompt Caching",
|
||||
"prompt_caching_interval": "Prompt Caching fastest refresh interval (sec)",
|
||||
"context_length": "Context Length",
|
||||
"batch_size": "(llama.cpp) Batch Size",
|
||||
"n_threads": "(llama.cpp) Thread Count",
|
||||
"n_batch_threads": "(llama.cpp) Batch Thread Count",
|
||||
"thinking_prefix": "Reasoning Content Prefix",
|
||||
"thinking_suffix": "Reasoning Content Suffix",
|
||||
"tool_call_prefix": "Tool Call Prefix",
|
||||
"tool_call_suffix": "Tool Call Suffix",
|
||||
"enable_legacy_tool_calling": "Enable Legacy Tool Calling",
|
||||
"max_tool_call_iterations": "Maximum Tool Call Attempts"
|
||||
},
|
||||
"data_description": {
|
||||
"llm_hass_api": "Select 'Assist' if you want the model to be able to control devices. If you are using the Home-LLM v1, v2, or v3 model then select 'Home-LLM (v1-3)'",
|
||||
"prompt": "See [here](https://github.com/acon96/home-llm/blob/develop/docs/Model%20Prompting.md) for more information on model prompting.",
|
||||
"in_context_examples": "If you are using a model that is not specifically fine-tuned for use with this integration: enable this",
|
||||
"extra_attributes_to_expose": "This is the list of Home Assistant 'attributes' that are exposed to the model. This limits how much information the model is able to see and answer questions on.",
|
||||
"gbnf_grammar": "Forces the model to output properly formatted responses. Ensure the file specified below exists in the integration directory.",
|
||||
"prompt_caching": "Prompt caching attempts to pre-process the prompt (house state) and cache the processing that needs to be done to understand the prompt. Enabling this will cause the model to re-process the prompt any time an entity state changes in the house, restricted by the interval below.",
|
||||
"enable_legacy_tool_calling": "Prefer to process tool calls locally rather than relying on the backend to handle the tool calling format. Can be more reliable, however it requires properly setting the tool call prefix and suffix.",
|
||||
"max_tool_call_iterations": "Set to 0 to generate the response and tool call in one attempt, without looping (use this for Home models v1-v3)."
|
||||
},
|
||||
"description": "Please configure the model according to how it should be prompted. There are many different options and selecting the correct ones for your model is essential to getting optimal performance. See [here](https://github.com/acon96/home-llm/blob/develop/docs/Backend%20Configuration.md) for more information about the options on this page.\n\n**Some defaults may have been chosen for you based on the name of the selected model name or filename.** If you renamed a file or are using a fine-tuning of a supported model, then the defaults may not have been detected.",
|
||||
"title": "Configure the selected model"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"options": {
|
||||
|
||||
@@ -9,11 +9,14 @@ import multiprocessing
|
||||
import voluptuous as vol
|
||||
import webcolors
|
||||
import json
|
||||
import base64
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Sequence, Tuple, cast
|
||||
from webcolors import CSS3
|
||||
from importlib.metadata import version
|
||||
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.components import conversation
|
||||
from homeassistant.helpers import config_validation as cv
|
||||
from homeassistant.helpers import intent, llm, aiohttp_client
|
||||
@@ -24,12 +27,12 @@ from homeassistant.util.package import install_package, is_installed
|
||||
from voluptuous_openapi import convert
|
||||
|
||||
from .const import (
|
||||
DOMAIN,
|
||||
EMBEDDED_LLAMA_CPP_PYTHON_VERSION,
|
||||
ALLOWED_SERVICE_CALL_ARGUMENTS,
|
||||
SERVICE_TOOL_ALLOWED_SERVICES,
|
||||
SERVICE_TOOL_ALLOWED_DOMAINS,
|
||||
HOME_LLM_API_ID,
|
||||
SERVICE_TOOL_NAME
|
||||
)
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
@@ -296,48 +299,64 @@ def get_oai_formatted_tools(llm_api: llm.APIInstance, domains: list[str]) -> Lis
|
||||
return result
|
||||
|
||||
def get_oai_formatted_messages(conversation: Sequence[conversation.Content], user_content_as_list: bool = False, tool_args_to_str: bool = True) -> List[ChatCompletionRequestMessage]:
|
||||
messages: List[ChatCompletionRequestMessage] = []
|
||||
for message in conversation:
|
||||
if message.role == "system":
|
||||
messages.append({
|
||||
"role": "system",
|
||||
"content": message.content
|
||||
})
|
||||
elif message.role == "user":
|
||||
if user_content_as_list:
|
||||
messages.append({
|
||||
"role": "user",
|
||||
"content": [{ "type": "text", "text": message.content }]
|
||||
})
|
||||
else:
|
||||
messages.append({
|
||||
"role": "user",
|
||||
"content": message.content
|
||||
})
|
||||
elif message.role == "assistant":
|
||||
if message.tool_calls:
|
||||
messages.append({
|
||||
"role": "assistant",
|
||||
"content": str(message.content),
|
||||
"tool_calls": [
|
||||
{
|
||||
"type" : "function",
|
||||
"id": t.id,
|
||||
"function": {
|
||||
"arguments": cast(str, json.dumps(t.tool_args) if tool_args_to_str else t.tool_args),
|
||||
"name": t.tool_name,
|
||||
}
|
||||
} for t in message.tool_calls
|
||||
]
|
||||
})
|
||||
elif message.role == "tool_result":
|
||||
messages.append({
|
||||
"role": "tool",
|
||||
"content": json.dumps(message.tool_result),
|
||||
"tool_call_id": message.tool_call_id
|
||||
})
|
||||
messages: List[ChatCompletionRequestMessage] = []
|
||||
for message in conversation:
|
||||
if message.role == "system":
|
||||
messages.append({
|
||||
"role": "system",
|
||||
"content": message.content
|
||||
})
|
||||
elif message.role == "user":
|
||||
images: list[str] = []
|
||||
for attachment in message.attachments or ():
|
||||
if not attachment.mime_type.startswith("image/"):
|
||||
raise HomeAssistantError(
|
||||
translation_domain=DOMAIN,
|
||||
translation_key="unsupported_attachment_type",
|
||||
)
|
||||
images.append(get_file_contents_base64(attachment.path))
|
||||
|
||||
return messages
|
||||
if user_content_as_list:
|
||||
content = [{ "type": "text", "text": message.content }]
|
||||
for image in images:
|
||||
content.append({ "type": "image_url", "image_url": {"url": image } })
|
||||
|
||||
messages.append({
|
||||
"role": "user",
|
||||
"content": content
|
||||
})
|
||||
else:
|
||||
message = {
|
||||
"role": "user",
|
||||
"content": message.content
|
||||
}
|
||||
if images:
|
||||
message["images"] = images
|
||||
messages.append(message)
|
||||
elif message.role == "assistant":
|
||||
if message.tool_calls:
|
||||
messages.append({
|
||||
"role": "assistant",
|
||||
"content": str(message.content),
|
||||
"tool_calls": [
|
||||
{
|
||||
"type" : "function",
|
||||
"id": t.id,
|
||||
"function": {
|
||||
"arguments": cast(str, json.dumps(t.tool_args) if tool_args_to_str else t.tool_args),
|
||||
"name": t.tool_name,
|
||||
}
|
||||
} for t in message.tool_calls
|
||||
]
|
||||
})
|
||||
elif message.role == "tool_result":
|
||||
messages.append({
|
||||
"role": "tool",
|
||||
"content": json.dumps(message.tool_result),
|
||||
"tool_call_id": message.tool_call_id
|
||||
})
|
||||
|
||||
return messages
|
||||
|
||||
def get_home_llm_tools(llm_api: llm.APIInstance, domains: list[str]) -> List[Dict[str, Any]]:
|
||||
service_dict = llm_api.api.hass.services.async_services()
|
||||
@@ -377,7 +396,7 @@ def get_home_llm_tools(llm_api: llm.APIInstance, domains: list[str]) -> List[Dic
|
||||
|
||||
return tools
|
||||
|
||||
def parse_raw_tool_call(raw_block: str | dict, llm_api: llm.APIInstance, user_input: conversation.ConversationInput) -> tuple[llm.ToolInput | None, str | None]:
|
||||
def parse_raw_tool_call(raw_block: str | dict, llm_api: llm.APIInstance, agent_id: str) -> tuple[llm.ToolInput | None, str | None]:
|
||||
if isinstance(raw_block, dict):
|
||||
parsed_tool_call = raw_block
|
||||
else:
|
||||
@@ -407,7 +426,7 @@ def parse_raw_tool_call(raw_block: str | dict, llm_api: llm.APIInstance, user_in
|
||||
schema_to_validate(parsed_tool_call)
|
||||
except vol.Error as ex:
|
||||
_LOGGER.info(f"LLM produced an improperly formatted response: {repr(ex)}")
|
||||
raise MalformedToolCallException(user_input.agent_id, "", "unknown", str(raw_block), "Tool call was not properly formatted")
|
||||
raise MalformedToolCallException(agent_id, "", "unknown", str(raw_block), "Tool call was not properly formatted")
|
||||
|
||||
# try to fix certain arguments
|
||||
args_dict = parsed_tool_call if llm_api.api.id == HOME_LLM_API_ID else parsed_tool_call["arguments"]
|
||||
@@ -420,7 +439,7 @@ def parse_raw_tool_call(raw_block: str | dict, llm_api: llm.APIInstance, user_in
|
||||
try:
|
||||
args_dict = json.loads(args_dict)
|
||||
except json.JSONDecodeError:
|
||||
raise MalformedToolCallException(user_input.agent_id, "", tool_name, str(args_dict), "Tool arguments were not properly formatted JSON")
|
||||
raise MalformedToolCallException(agent_id, "", tool_name, str(args_dict), "Tool arguments were not properly formatted JSON")
|
||||
|
||||
# make sure brightness is 0-255 and not a percentage
|
||||
if "brightness" in args_dict and 0.0 < args_dict["brightness"] <= 1.0:
|
||||
@@ -477,4 +496,13 @@ def is_valid_hostname(host: str) -> bool:
|
||||
|
||||
domain_pattern = re.compile(r"^[a-z0-9]([a-z0-9\-]{0,61}[a-z0-9])?(\.[a-z0-9]([a-z0-9\-]{0,61}[a-z0-9])?)*\.[a-z]{2,}$")
|
||||
|
||||
return bool(domain_pattern.match(host))
|
||||
return bool(domain_pattern.match(host))
|
||||
|
||||
|
||||
def get_file_contents_base64(file_path: Path) -> str:
|
||||
"""Reads a file and returns its contents encoded in base64."""
|
||||
with open(file_path, "rb") as f:
|
||||
encoded_bytes = base64.b64encode(f.read())
|
||||
encoded_str = encoded_bytes.decode('utf-8')
|
||||
|
||||
return encoded_str
|
||||
@@ -1,5 +1,5 @@
|
||||
# types from Home Assistant
|
||||
homeassistant>=2024.8.3
|
||||
homeassistant>=2025.8.3
|
||||
hassil
|
||||
home-assistant-intents
|
||||
|
||||
|
||||
Reference in New Issue
Block a user