actually working llamacpp agent

Alex O'Connell
2025-09-15 22:10:25 -04:00
parent 69f5464bf7
commit 61d52ae4d1
6 changed files with 198 additions and 168 deletions

View File

@@ -14,7 +14,7 @@ from homeassistant.const import CONF_HOST, CONF_PORT, CONF_SSL
from homeassistant.helpers.aiohttp_client import async_get_clientsession
from homeassistant.helpers import llm
from custom_components.llama_conversation.utils import format_url, get_oai_formatted_messages, get_oai_formatted_tools
from custom_components.llama_conversation.utils import format_url, get_oai_formatted_messages, get_oai_formatted_tools, parse_raw_tool_call
from custom_components.llama_conversation.const import (
CONF_CHAT_MODEL,
CONF_MAX_TOKENS,
@@ -33,7 +33,7 @@ from custom_components.llama_conversation.const import (
DEFAULT_REMEMBER_CONVERSATION_TIME_MINUTES,
DEFAULT_GENERIC_OPENAI_PATH,
)
from custom_components.llama_conversation.conversation import LocalLLMAgent, TextGenerationResult, parse_raw_tool_call
from custom_components.llama_conversation.conversation import LocalLLMAgent, TextGenerationResult
_LOGGER = logging.getLogger(__name__)
@@ -154,7 +154,7 @@ class GenericOpenAIAPIAgent(BaseOpenAICompatibleAPIAgent):
request_params["messages"] = get_oai_formatted_messages(conversation)
if llm_api:
request_params["tools"] = get_oai_formatted_tools(llm_api)
request_params["tools"] = get_oai_formatted_tools(llm_api, self._async_get_all_exposed_domains())
return self._async_generate_with_parameters(endpoint, True, request_params, llm_api, user_input)

View File

@@ -21,7 +21,7 @@ from homeassistant.helpers.event import async_track_state_change, async_call_lat
from llama_cpp import CreateChatCompletionStreamResponse
from custom_components.llama_conversation.utils import install_llama_cpp_python, validate_llama_cpp_python_installation, get_oai_formatted_messages, get_oai_formatted_tools
from custom_components.llama_conversation.utils import install_llama_cpp_python, validate_llama_cpp_python_installation, get_oai_formatted_messages, get_oai_formatted_tools, parse_raw_tool_call
from custom_components.llama_conversation.const import (
CONF_THINKING_PREFIX,
CONF_THINKING_SUFFIX,
@@ -222,13 +222,13 @@ class LlamaCppAgent(LocalLLMAgent):
else:
self._set_prompt_caching(enabled=False)
def _async_get_exposed_entities(self) -> tuple[dict[str, str], list[str]]:
def _async_get_exposed_entities(self) -> dict[str, str]:
"""Takes the super class function results and sorts the entities with the recently updated at the end"""
entities, domains = LocalLLMAgent._async_get_exposed_entities(self)
entities = LocalLLMAgent._async_get_exposed_entities(self)
# ignore sorting if prompt caching is disabled
if not self.entry.options.get(CONF_PROMPT_CACHING_ENABLED, DEFAULT_PROMPT_CACHING_ENABLED):
return entities, domains
return entities
entity_order = { name: None for name in entities.keys() }
entity_order.update(self.last_updated_entities)
@@ -250,7 +250,7 @@ class LlamaCppAgent(LocalLLMAgent):
for item_name, _ in sorted_items:
sorted_entities[item_name] = entities[item_name]
return sorted_entities, domains
return sorted_entities
def _set_prompt_caching(self, *, enabled=True):
if enabled and not self.remove_prompt_caching_listener:
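The entity re-ordering in the hunk above keeps the start of the rendered prompt stable, so llama.cpp prompt caching can reuse the cached prefix; only recently changed entities move toward the end. A minimal, self-contained sketch of that ordering idea (the data and the timestamp map are illustrative, not the integration's actual structures):

# Illustrative sketch: push the most recently updated entities to the end so
# the start of the rendered prompt stays stable for llama.cpp prompt caching.
entities = {"light.kitchen": "on", "sensor.temp": "21", "light.porch": "off"}
last_updated = {"light.porch": 1694800000.0}  # assumed name -> last-change timestamp

order = {name: 0.0 for name in entities}      # unchanged entities keep an early sort key
order.update(last_updated)                    # recently changed entities get a later key
sorted_entities = {name: entities[name]
                   for name, _ in sorted(order.items(), key=lambda kv: kv[1])}
print(list(sorted_entities))  # ['light.kitchen', 'sensor.temp', 'light.porch']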
@@ -361,7 +361,7 @@ class LlamaCppAgent(LocalLLMAgent):
refresh_delay = self.entry.options.get(CONF_PROMPT_CACHING_INTERVAL, DEFAULT_PROMPT_CACHING_INTERVAL)
async_call_later(self.hass, float(refresh_delay), refresh_if_requested)
async def _async_generate_completion(self, chat_completion: Iterator[CreateChatCompletionStreamResponse]) -> AsyncGenerator[TextGenerationResult, None]:
async def _async_generate_completion(self, llm_api: llm.APIInstance | None, chat_completion: Iterator[CreateChatCompletionStreamResponse]) -> AsyncGenerator[TextGenerationResult, None]:
think_prefix = self.entry.options.get(CONF_THINKING_PREFIX, DEFAULT_THINKING_PREFIX)
think_suffix = self.entry.options.get(CONF_THINKING_SUFFIX, DEFAULT_THINKING_SUFFIX)
tool_prefix = self.entry.options.get(CONF_TOOL_CALL_PREFIX, DEFAULT_TOOL_CALL_PREFIX)
@@ -377,6 +377,7 @@ class LlamaCppAgent(LocalLLMAgent):
in_thinking = False
in_tool_call = False
tool_content = ""
last_5_tokens = []
while chunk := await self.hass.async_add_executor_job(next_token):
content = chunk["choices"][0]["delta"].get("content")
tool_calls = chunk["choices"][0]["delta"].get("tool_calls")
@@ -389,37 +390,44 @@ class LlamaCppAgent(LocalLLMAgent):
tool_calls=None
)
if content:
if think_prefix in content and not in_thinking:
in_thinking = True
elif think_suffix in content and in_thinking:
in_thinking = False
content = content.replace(think_suffix, "").strip()
elif tool_prefix in content and not in_tool_call:
in_tool_call = True
elif tool_suffix in content and in_tool_call:
in_tool_call = False
tool_call = json.loads(tool_content.strip().removeprefix(tool_prefix).removesuffix(tool_suffix))
result.tool_calls = [
llm.ToolInput(
tool_name=tool_call["name"],
tool_args=tool_call["arguments"]
)
]
content = None
last_5_tokens.append(content)
if len(last_5_tokens) > 5:
last_5_tokens.pop(0)
result.response = content
potential_block = "".join(last_5_tokens)
if in_tool_call:
tool_content += content
if think_prefix in potential_block and not in_thinking:
in_thinking = True
last_5_tokens.clear()
elif think_suffix in potential_block and in_thinking:
in_thinking = False
content = content.replace(think_suffix, "").strip()
elif tool_prefix in potential_block and not in_tool_call:
in_tool_call = True
last_5_tokens.clear()
elif tool_suffix in potential_block and in_tool_call:
in_tool_call = False
_LOGGER.debug("Tool content: %s", tool_content)
tool_call, to_say = parse_raw_tool_call(tool_content.strip().removeprefix(tool_prefix).removesuffix(tool_suffix), llm_api)
if tool_call:
result.tool_calls = [tool_call]
if to_say:
content = to_say
else:
content = None
result.response = content
if tool_calls:
result.tool_calls = [
llm.ToolInput(
tool_name=str(tool_calls[0]["function"]["name"]),
tool_args=json.loads(tool_calls[0]["function"]["arguments"])
)
]
result.tool_calls = [llm.ToolInput(
tool_name=str(tool_call["function"]["name"]),
tool_args=json.loads(tool_call["function"]["arguments"])
) for tool_call in tool_calls ]
if not in_thinking and not in_tool_call:
yield result
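The last_5_tokens buffer exists because markers like the thinking prefix or tool-call prefix can arrive split across several streamed chunks, so checking each chunk in isolation would miss them. A standalone sketch of that sliding-window detection, with an assumed marker and chunk stream rather than the integration's real option values:

THINK_PREFIX = "<think>"  # assumed marker; the real value comes from CONF_THINKING_PREFIX

def detect_marker(chunks, window=5):
    last_tokens = []
    in_thinking = False
    for chunk in chunks:
        last_tokens.append(chunk)
        if len(last_tokens) > window:
            last_tokens.pop(0)
        block = "".join(last_tokens)          # join recent chunks before matching
        if THINK_PREFIX in block and not in_thinking:
            in_thinking = True
            last_tokens.clear()               # drop the matched marker so it is not re-detected
        yield chunk, in_thinking

# "<think>" is split across two chunks but still detected once the chunks are joined:
for token, thinking in detect_marker(["Sure, ", "<th", "ink>", "let me check..."]):
    print(repr(token), thinking)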
@@ -436,7 +444,6 @@ class LlamaCppAgent(LocalLLMAgent):
_LOGGER.debug(f"Options: {self.entry.options}")
# TODO: re-enable the context length check
# with self.model_lock:
# # FIXME: use the high level API so we can use the built-in prompt formatting
# input_tokens = self.llm.tokenize(
# prompt.encode(), add_bos=False
@@ -444,7 +451,7 @@ class LlamaCppAgent(LocalLLMAgent):
# context_len = self.entry.options.get(CONF_CONTEXT_LENGTH, DEFAULT_CONTEXT_LENGTH)
# if len(input_tokens) >= context_len:
# num_entities = len(self._async_get_exposed_entities()[0])
# num_entities = len(self._async_get_exposed_entities())
# context_size = self.entry.options.get(CONF_CONTEXT_LENGTH, DEFAULT_CONTEXT_LENGTH)
# self._warn_context_size()
# raise Exception(f"The model failed to produce a result because too many devices are exposed ({num_entities} devices) for the context size ({context_size} tokens)!")
@@ -456,11 +463,11 @@ class LlamaCppAgent(LocalLLMAgent):
messages = get_oai_formatted_messages(conversation)
tools = None
if llm_api:
tools = get_oai_formatted_tools(llm_api)
tools = get_oai_formatted_tools(llm_api, self._async_get_all_exposed_domains())
_LOGGER.debug(f"Generating completion with {len(messages)} messages and {len(tools) if tools else 0} tools...")
return self._async_generate_completion(self.llm.create_chat_completion(
return self._async_generate_completion(llm_api, self.llm.create_chat_completion(
messages,
tools=tools,
temperature=temperature,

View File

@@ -67,6 +67,7 @@ from .const import (
CONF_REMEMBER_CONVERSATION,
CONF_REMEMBER_NUM_INTERACTIONS,
CONF_REMEMBER_CONVERSATION_TIME_MINUTES,
CONF_MAX_TOOL_CALL_ITERATIONS,
CONF_PROMPT_CACHING_ENABLED,
CONF_PROMPT_CACHING_INTERVAL,
CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES,
@@ -113,6 +114,7 @@ from .const import (
DEFAULT_REFRESH_SYSTEM_PROMPT,
DEFAULT_REMEMBER_CONVERSATION,
DEFAULT_REMEMBER_NUM_INTERACTIONS,
DEFAULT_MAX_TOOL_CALL_ITERATIONS,
DEFAULT_PROMPT_CACHING_ENABLED,
DEFAULT_PROMPT_CACHING_INTERVAL,
DEFAULT_USE_IN_CONTEXT_LEARNING_EXAMPLES,
@@ -856,6 +858,11 @@ def local_llama_config_option_schema(hass: HomeAssistant, options: MappingProxyT
description={"suggested_value": options.get(CONF_PROMPT)},
default=options[CONF_PROMPT],
): TemplateSelector(),
vol.Required(
CONF_MAX_TOOL_CALL_ITERATIONS,
description={"suggested_value": options.get(CONF_MAX_TOOL_CALL_ITERATIONS)},
default=DEFAULT_MAX_TOOL_CALL_ITERATIONS,
): int,
vol.Required(
CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES,
description={"suggested_value": options.get(CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES)},

View File

@@ -161,6 +161,8 @@ CONF_REMEMBER_NUM_INTERACTIONS = "remember_num_interactions"
DEFAULT_REMEMBER_NUM_INTERACTIONS = 5
CONF_REMEMBER_CONVERSATION_TIME_MINUTES = "remember_conversation_time_minutes"
DEFAULT_REMEMBER_CONVERSATION_TIME_MINUTES = 2
CONF_MAX_TOOL_CALL_ITERATIONS = "max_tool_call_iterations"
DEFAULT_MAX_TOOL_CALL_ITERATIONS = 3
CONF_PROMPT_CACHING_ENABLED = "prompt_caching"
DEFAULT_PROMPT_CACHING_ENABLED = False
CONF_PROMPT_CACHING_INTERVAL = "prompt_caching_interval"
@@ -226,42 +228,49 @@ OPTIONS_OVERRIDES = {
CONF_TOOL_CALL_PREFIX: "```homeassistant",
CONF_TOOL_CALL_SUFFIX: "```",
CONF_CONTEXT_LENGTH: 131072,
CONF_MAX_TOOL_CALL_ITERATIONS: 1,
},
"home-3b-v3": {
CONF_PROMPT: DEFAULT_PROMPT_BASE_LEGACY,
CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES: False,
CONF_TOOL_CALL_PREFIX: "```homeassistant",
CONF_TOOL_CALL_SUFFIX: "```",
CONF_MAX_TOOL_CALL_ITERATIONS: 1,
},
"home-3b-v2": {
CONF_PROMPT: DEFAULT_PROMPT_BASE_LEGACY,
CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES: False,
CONF_TOOL_CALL_PREFIX: "```homeassistant",
CONF_TOOL_CALL_SUFFIX: "```",
CONF_MAX_TOOL_CALL_ITERATIONS: 1,
},
"home-3b-v1": {
CONF_PROMPT: DEFAULT_PROMPT_BASE_LEGACY,
CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES: False,
CONF_TOOL_CALL_PREFIX: "```homeassistant",
CONF_TOOL_CALL_SUFFIX: "```",
CONF_MAX_TOOL_CALL_ITERATIONS: 1,
},
"home-1b-v3": {
CONF_PROMPT: DEFAULT_PROMPT_BASE_LEGACY,
CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES: False,
CONF_TOOL_CALL_PREFIX: "```homeassistant",
CONF_TOOL_CALL_SUFFIX: "```",
CONF_MAX_TOOL_CALL_ITERATIONS: 1,
},
"home-1b-v2": {
CONF_PROMPT: DEFAULT_PROMPT_BASE_LEGACY,
CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES: False,
CONF_TOOL_CALL_PREFIX: "```homeassistant",
CONF_TOOL_CALL_SUFFIX: "```",
CONF_MAX_TOOL_CALL_ITERATIONS: 1,
},
"home-1b-v1": {
CONF_PROMPT: DEFAULT_PROMPT_BASE_LEGACY,
CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES: False,
CONF_TOOL_CALL_PREFIX: "```homeassistant",
CONF_TOOL_CALL_SUFFIX: "```",
CONF_MAX_TOOL_CALL_ITERATIONS: 1,
},
"mistral": {
CONF_PROMPT: DEFAULT_PROMPT_BASE + ICL_NO_SYSTEM_PROMPT_EXTRAS,

View File

@@ -40,6 +40,7 @@ from .const import (
CONF_REFRESH_SYSTEM_PROMPT,
CONF_REMEMBER_CONVERSATION,
CONF_REMEMBER_NUM_INTERACTIONS,
CONF_MAX_TOOL_CALL_ITERATIONS,
CONF_CONTEXT_LENGTH,
DEFAULT_PROMPT,
DEFAULT_BACKEND_TYPE,
@@ -50,13 +51,9 @@ from .const import (
DEFAULT_REFRESH_SYSTEM_PROMPT,
DEFAULT_REMEMBER_CONVERSATION,
DEFAULT_REMEMBER_NUM_INTERACTIONS,
DEFAULT_MAX_TOOL_CALL_ITERATIONS,
DEFAULT_CONTEXT_LENGTH,
DOMAIN,
HOME_LLM_API_ID,
SERVICE_TOOL_NAME,
ALLOWED_SERVICE_CALL_ARGUMENTS,
SERVICE_TOOL_ALLOWED_SERVICES,
SERVICE_TOOL_ALLOWED_DOMAINS,
CONF_BACKEND_TYPE,
DEFAULT_BACKEND_TYPE,
)
@@ -96,61 +93,6 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry, async_add_e
return True
def parse_raw_tool_call(raw_block: str, tool_name: str, tool_call_id: str, llm_api: llm.APIInstance, user_input: ConversationInput) -> tuple[bool, llm.ToolInput | None, str | None]:
parsed_tool_call: dict = json.loads(raw_block)
if llm_api.api.id == HOME_LLM_API_ID:
schema_to_validate = vol.Schema({
vol.Required('service'): str,
vol.Required('target_device'): str,
vol.Optional('rgb_color'): str,
vol.Optional('brightness'): vol.Coerce(float),
vol.Optional('temperature'): vol.Coerce(float),
vol.Optional('humidity'): vol.Coerce(float),
vol.Optional('fan_mode'): str,
vol.Optional('hvac_mode'): str,
vol.Optional('preset_mode'): str,
vol.Optional('duration'): str,
vol.Optional('item'): str,
})
else:
schema_to_validate = vol.Schema({
vol.Required("name"): str,
vol.Required("arguments"): dict,
})
try:
schema_to_validate(parsed_tool_call)
except vol.Error as ex:
_LOGGER.info(f"LLM produced an improperly formatted response: {repr(ex)}")
return False, None, f"I'm sorry, I didn't produce a correctly formatted tool call! Please see the logs for more info.",
# try to fix certain arguments
args_dict = parsed_tool_call if llm_api.api.id == HOME_LLM_API_ID else parsed_tool_call["arguments"]
# make sure brightness is 0-255 and not a percentage
if "brightness" in args_dict and 0.0 < args_dict["brightness"] <= 1.0:
args_dict["brightness"] = int(args_dict["brightness"] * 255)
# convert string "tuple" to a list for RGB colors
if "rgb_color" in args_dict and isinstance(args_dict["rgb_color"], str):
args_dict["rgb_color"] = [ int(x) for x in args_dict["rgb_color"][1:-1].split(",") ]
if llm_api.api.id == HOME_LLM_API_ID:
to_say = parsed_tool_call.pop("to_say", "")
tool_input = llm.ToolInput(
tool_name=SERVICE_TOOL_NAME,
tool_args=parsed_tool_call,
)
else:
to_say = ""
tool_input = llm.ToolInput(
tool_name=parsed_tool_call["name"],
tool_args=parsed_tool_call["arguments"],
)
return True, tool_input, to_say
class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
"""Base Local LLM conversation agent."""
@@ -271,7 +213,7 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
)
def _warn_context_size(self):
num_entities = len(self._async_get_exposed_entities()[0])
num_entities = len(self._async_get_exposed_entities())
context_size = self.entry.options.get(CONF_CONTEXT_LENGTH, DEFAULT_CONTEXT_LENGTH)
_LOGGER.error("There were too many entities exposed when attempting to generate a response for " +
f"{self.entry.data[CONF_CHAT_MODEL]} and it exceeded the context size for the model. " +
@@ -315,6 +257,7 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
refresh_system_prompt = self.entry.options.get(CONF_REFRESH_SYSTEM_PROMPT, DEFAULT_REFRESH_SYSTEM_PROMPT)
remember_conversation = self.entry.options.get(CONF_REMEMBER_CONVERSATION, DEFAULT_REMEMBER_CONVERSATION)
remember_num_interactions = self.entry.options.get(CONF_REMEMBER_NUM_INTERACTIONS, DEFAULT_REMEMBER_NUM_INTERACTIONS)
max_tool_call_iterations = self.entry.options.get(CONF_MAX_TOOL_CALL_ITERATIONS, DEFAULT_MAX_TOOL_CALL_ITERATIONS)
llm_api: llm.APIInstance | None = None
if self.entry.options.get(CONF_LLM_HASS_API):
@@ -368,9 +311,8 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
else:
message_history[0] = system_prompt
MAX_TOOL_CALL_ITERATIONS = 3 # FIXME: move to config option
tool_calls: List[Tuple[llm.ToolInput, Any]] = []
for _ in range(MAX_TOOL_CALL_ITERATIONS):
for _ in range(max_tool_call_iterations):
try:
_LOGGER.debug(message_history)
generation_result = await self._async_generate(message_history, user_input, chat_log)
@@ -409,16 +351,28 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
content=f"Ran the following tools:\n{tools_str}"
)
intent_response.async_set_speech(message_history[-1].content)
for i in range(1, len(message_history)):
cur_msg = message_history[-1 * i]
if isinstance(cur_msg, conversation.AssistantContent) and cur_msg.content:
intent_response.async_set_speech(cur_msg.content)
break
return ConversationResult(
response=intent_response, conversation_id=user_input.conversation_id
)
def _async_get_all_exposed_domains(self) -> list[str]:
"""Gather all exposed domains"""
domains = set()
for state in self.hass.states.async_all():
if async_should_expose(self.hass, CONVERSATION_DOMAIN, state.entity_id):
domains.add(state.domain)
return list(domains)
def _async_get_exposed_entities(self) -> tuple[dict[str, str], list[str]]:

def _async_get_exposed_entities(self) -> dict[str, str]:
"""Gather exposed entity states"""
entity_states = {}
domains = set()
entity_registry = er.async_get(self.hass)
device_registry = dr.async_get(self.hass)
area_registry = ar.async_get(self.hass)
@@ -456,9 +410,8 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
attributes["area_name"] = area.name
entity_states[state.entity_id] = attributes
domains.add(state.domain)
return entity_states, list(domains)
return entity_states
def _generate_icl_examples(self, num_examples, entity_names):
entity_names = entity_names[:]
@@ -529,7 +482,7 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
def _generate_system_prompt(self, prompt_template: str, llm_api: llm.APIInstance | None) -> str:
"""Generate the system prompt with current entity states"""
entities_to_expose, domains = self._async_get_exposed_entities()
entities_to_expose = self._async_get_exposed_entities()
extra_attributes_to_expose = self.entry.options \
.get(CONF_EXTRA_ATTRIBUTES_TO_EXPOSE, DEFAULT_EXTRA_ATTRIBUTES_TO_EXPOSE)
@@ -596,62 +549,9 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
"is_alias": True
})
# if llm_api:
# if llm_api.api.id == HOME_LLM_API_ID:
# service_dict = self.hass.services.async_services()
# all_services = []
# scripts_added = False
# for domain in domains:
# if domain not in SERVICE_TOOL_ALLOWED_DOMAINS:
# continue
# # scripts show up as individual services
# if domain == "script" and not scripts_added:
# all_services.extend([
# ("script.reload", vol.Schema({}), ""),
# ("script.turn_on", vol.Schema({}), ""),
# ("script.turn_off", vol.Schema({}), ""),
# ("script.toggle", vol.Schema({}), ""),
# ])
# scripts_added = True
# continue
# for name, service in service_dict.get(domain, {}).items():
# if name not in SERVICE_TOOL_ALLOWED_SERVICES:
# continue
# args = flatten_vol_schema(service.schema)
# args_to_expose = set(args).intersection(ALLOWED_SERVICE_CALL_ARGUMENTS)
# service_schema = vol.Schema({
# vol.Optional(arg): str for arg in args_to_expose
# })
# all_services.append((f"{domain}.{name}", service_schema, ""))
# tools = [
# self._format_tool(*tool)
# for tool in all_services
# ]
# else:
# tools = [
# self._format_tool(tool.name, tool.parameters, tool.description)
# for tool in llm_api.tools
# ]
# if self.entry.options.get(CONF_TOOL_FORMAT, DEFAULT_TOOL_FORMAT) == TOOL_FORMAT_MINIMAL:
# formatted_tools = ", ".join(tools)
# else:
# formatted_tools = json.dumps(tools)
# else:
# tools = ["No tools were provided. If the user requests you interact with a device, tell them you are unable to do so."]
# formatted_tools = tools[0]
render_variables = {
"devices": devices,
"formatted_devices": formatted_devices,
# "tools": tools,
# "formatted_tools": formatted_tools,
"response_examples": []
}

View File

@@ -24,6 +24,11 @@ from voluptuous_openapi import convert
from .const import (
INTEGRATION_VERSION,
EMBEDDED_LLAMA_CPP_PYTHON_VERSION,
ALLOWED_SERVICE_CALL_ARGUMENTS,
SERVICE_TOOL_ALLOWED_SERVICES,
SERVICE_TOOL_ALLOWED_DOMAINS,
HOME_LLM_API_ID,
SERVICE_TOOL_NAME
)
from typing import TYPE_CHECKING
@@ -252,15 +257,25 @@ def install_llama_cpp_python(config_dir: str):
def format_url(*, hostname: str, port: str, ssl: bool, path: str):
return f"{'https' if ssl else 'http'}://{hostname}{ ':' + port if port else ''}{path}"
def get_oai_formatted_tools(llm_api: llm.APIInstance) -> List[ChatCompletionTool]:
result: List[ChatCompletionTool] = [ {
"type": "function",
"function": {
"name": tool.name,
"description": tool.description or "",
"parameters": convert(tool.parameters, custom_serializer=llm_api.custom_serializer)
}
} for tool in llm_api.tools ]
def get_oai_formatted_tools(llm_api: llm.APIInstance, domains: list[str]) -> List[ChatCompletionTool]:
if llm_api.api.id == HOME_LLM_API_ID:
result: List[ChatCompletionTool] = [ {
"type": "function",
"function": {
"name": tool["name"],
"parameters": convert(tool["arguments"], custom_serializer=llm_api.custom_serializer)
}
} for tool in get_home_llm_tools(llm_api, domains) ]
else:
result: List[ChatCompletionTool] = [ {
"type": "function",
"function": {
"name": tool.name,
"description": tool.description or "",
"parameters": convert(tool.parameters, custom_serializer=llm_api.custom_serializer)
}
} for tool in llm_api.tools ]
return result
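For reference, each entry in the list returned by get_oai_formatted_tools follows the OpenAI function-calling tool shape; the concrete name and parameters below are made up for illustration:

example_tool = {
    "type": "function",
    "function": {
        "name": "light.turn_on",          # illustrative tool name
        "description": "",
        "parameters": {                   # JSON schema produced by voluptuous_openapi.convert
            "type": "object",
            "properties": {"brightness": {"type": "string"}},
        },
    },
}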
@@ -307,3 +322,95 @@ def get_oai_formatted_messages(conversation: Sequence[conversation.Content], use
})
return messages
def get_home_llm_tools(llm_api: llm.APIInstance, domains: list[str]) -> List[Dict[str, Any]]:
service_dict = llm_api.api.hass.services.async_services()
all_services = []
scripts_added = False
for domain in domains:
if domain not in SERVICE_TOOL_ALLOWED_DOMAINS:
continue
# scripts show up as individual services
if domain == "script" and not scripts_added:
all_services.extend([
("script.reload", vol.Schema({})),
("script.turn_on", vol.Schema({})),
("script.turn_off", vol.Schema({})),
("script.toggle", vol.Schema({})),
])
scripts_added = True
continue
for name, service in service_dict.get(domain, {}).items():
if name not in SERVICE_TOOL_ALLOWED_SERVICES:
continue
args = flatten_vol_schema(service.schema)
args_to_expose = set(args).intersection(ALLOWED_SERVICE_CALL_ARGUMENTS)
service_schema = vol.Schema({
vol.Optional(arg): str for arg in args_to_expose
})
all_services.append((f"{domain}.{name}", service_schema))
tools: List[Dict[str, Any]] = [
{ "name": service[0], "arguments": service[1] } for service in all_services
]
return tools
def parse_raw_tool_call(raw_block: str, llm_api: llm.APIInstance) -> tuple[llm.ToolInput | None, str | None]:
parsed_tool_call: dict = json.loads(raw_block)
if llm_api.api.id == HOME_LLM_API_ID:
schema_to_validate = vol.Schema({
vol.Required('service'): str,
vol.Required('target_device'): str,
vol.Optional('rgb_color'): str,
vol.Optional('brightness'): vol.Coerce(float),
vol.Optional('temperature'): vol.Coerce(float),
vol.Optional('humidity'): vol.Coerce(float),
vol.Optional('fan_mode'): str,
vol.Optional('hvac_mode'): str,
vol.Optional('preset_mode'): str,
vol.Optional('duration'): str,
vol.Optional('item'): str,
})
else:
schema_to_validate = vol.Schema({
vol.Required("name"): str,
vol.Required("arguments"): dict,
})
try:
schema_to_validate(parsed_tool_call)
except vol.Error as ex:
_LOGGER.info(f"LLM produced an improperly formatted response: {repr(ex)}")
raise # re-raise exception for now to force the LLM to try again
# try to fix certain arguments
args_dict = parsed_tool_call if llm_api.api.id == HOME_LLM_API_ID else parsed_tool_call["arguments"]
# make sure brightness is 0-255 and not a percentage
if "brightness" in args_dict and 0.0 < args_dict["brightness"] <= 1.0:
args_dict["brightness"] = int(args_dict["brightness"] * 255)
# convert string "tuple" to a list for RGB colors
if "rgb_color" in args_dict and isinstance(args_dict["rgb_color"], str):
args_dict["rgb_color"] = [ int(x) for x in args_dict["rgb_color"][1:-1].split(",") ]
if llm_api.api.id == HOME_LLM_API_ID:
to_say = parsed_tool_call.pop("to_say", "")
tool_input = llm.ToolInput(
tool_name=SERVICE_TOOL_NAME,
tool_args=parsed_tool_call,
)
else:
to_say = ""
tool_input = llm.ToolInput(
tool_name=parsed_tool_call["name"],
tool_args=parsed_tool_call["arguments"],
)
return tool_input, to_say
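A small standalone sketch of the argument fix-ups this function applies before building the ToolInput; the raw block below is invented and the voluptuous schema validation step is omitted:

import json

raw = '{"service": "light.turn_on", "target_device": "light.kitchen", "brightness": 0.5, "rgb_color": "(255, 120, 0)"}'
args = json.loads(raw)

# brightness given as a fraction is rescaled to the 0-255 range
if "brightness" in args and 0.0 < args["brightness"] <= 1.0:
    args["brightness"] = int(args["brightness"] * 255)   # 0.5 -> 127

# a stringified tuple of RGB values becomes a real list of ints
if "rgb_color" in args and isinstance(args["rgb_color"], str):
    args["rgb_color"] = [int(x) for x in args["rgb_color"][1:-1].split(",")]  # -> [255, 120, 0]

print(args)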