mirror of https://github.com/acon96/home-llm.git (synced 2026-01-08 21:28:05 -05:00)
actually working llamacpp agent
@@ -14,7 +14,7 @@ from homeassistant.const import CONF_HOST, CONF_PORT, CONF_SSL
 from homeassistant.helpers.aiohttp_client import async_get_clientsession
 from homeassistant.helpers import llm
 
-from custom_components.llama_conversation.utils import format_url, get_oai_formatted_messages, get_oai_formatted_tools
+from custom_components.llama_conversation.utils import format_url, get_oai_formatted_messages, get_oai_formatted_tools, parse_raw_tool_call
 from custom_components.llama_conversation.const import (
     CONF_CHAT_MODEL,
     CONF_MAX_TOKENS,
@@ -33,7 +33,7 @@ from custom_components.llama_conversation.const import (
     DEFAULT_REMEMBER_CONVERSATION_TIME_MINUTES,
     DEFAULT_GENERIC_OPENAI_PATH,
 )
-from custom_components.llama_conversation.conversation import LocalLLMAgent, TextGenerationResult, parse_raw_tool_call
+from custom_components.llama_conversation.conversation import LocalLLMAgent, TextGenerationResult
 
 _LOGGER = logging.getLogger(__name__)
 
@@ -154,7 +154,7 @@ class GenericOpenAIAPIAgent(BaseOpenAICompatibleAPIAgent):
         request_params["messages"] = get_oai_formatted_messages(conversation)
 
         if llm_api:
-            request_params["tools"] = get_oai_formatted_tools(llm_api)
+            request_params["tools"] = get_oai_formatted_tools(llm_api, self._async_get_all_exposed_domains())
 
         return self._async_generate_with_parameters(endpoint, True, request_params, llm_api, user_input)

@@ -21,7 +21,7 @@ from homeassistant.helpers.event import async_track_state_change, async_call_lat
 
 from llama_cpp import CreateChatCompletionStreamResponse
 
-from custom_components.llama_conversation.utils import install_llama_cpp_python, validate_llama_cpp_python_installation, get_oai_formatted_messages, get_oai_formatted_tools
+from custom_components.llama_conversation.utils import install_llama_cpp_python, validate_llama_cpp_python_installation, get_oai_formatted_messages, get_oai_formatted_tools, parse_raw_tool_call
 from custom_components.llama_conversation.const import (
     CONF_THINKING_PREFIX,
     CONF_THINKING_SUFFIX,
@@ -222,13 +222,13 @@ class LlamaCppAgent(LocalLLMAgent):
         else:
             self._set_prompt_caching(enabled=False)
 
-    def _async_get_exposed_entities(self) -> tuple[dict[str, str], list[str]]:
+    def _async_get_exposed_entities(self) -> dict[str, str]:
         """Takes the super class function results and sorts the entities with the recently updated at the end"""
-        entities, domains = LocalLLMAgent._async_get_exposed_entities(self)
+        entities = LocalLLMAgent._async_get_exposed_entities(self)
 
         # ignore sorting if prompt caching is disabled
         if not self.entry.options.get(CONF_PROMPT_CACHING_ENABLED, DEFAULT_PROMPT_CACHING_ENABLED):
-            return entities, domains
+            return entities
 
         entity_order = { name: None for name in entities.keys() }
         entity_order.update(self.last_updated_entities)
@@ -250,7 +250,7 @@ class LlamaCppAgent(LocalLLMAgent):
         for item_name, _ in sorted_items:
             sorted_entities[item_name] = entities[item_name]
 
-        return sorted_entities, domains
+        return sorted_entities
 
     def _set_prompt_caching(self, *, enabled=True):
         if enabled and not self.remove_prompt_caching_listener:
@@ -361,7 +361,7 @@ class LlamaCppAgent(LocalLLMAgent):
             refresh_delay = self.entry.options.get(CONF_PROMPT_CACHING_INTERVAL, DEFAULT_PROMPT_CACHING_INTERVAL)
            async_call_later(self.hass, float(refresh_delay), refresh_if_requested)
 
-    async def _async_generate_completion(self, chat_completion: Iterator[CreateChatCompletionStreamResponse]) -> AsyncGenerator[TextGenerationResult, None]:
+    async def _async_generate_completion(self, llm_api: llm.APIInstance | None, chat_completion: Iterator[CreateChatCompletionStreamResponse]) -> AsyncGenerator[TextGenerationResult, None]:
         think_prefix = self.entry.options.get(CONF_THINKING_PREFIX, DEFAULT_THINKING_PREFIX)
         think_suffix = self.entry.options.get(CONF_THINKING_SUFFIX, DEFAULT_THINKING_SUFFIX)
         tool_prefix = self.entry.options.get(CONF_TOOL_CALL_PREFIX, DEFAULT_TOOL_CALL_PREFIX)
@@ -377,6 +377,7 @@ class LlamaCppAgent(LocalLLMAgent):
         in_thinking = False
         in_tool_call = False
         tool_content = ""
+        last_5_tokens = []
         while chunk := await self.hass.async_add_executor_job(next_token):
             content = chunk["choices"][0]["delta"].get("content")
             tool_calls = chunk["choices"][0]["delta"].get("tool_calls")
@@ -389,37 +390,44 @@ class LlamaCppAgent(LocalLLMAgent):
                 tool_calls=None
             )
             if content:
-                if think_prefix in content and not in_thinking:
-                    in_thinking = True
-                elif think_suffix in content and in_thinking:
-                    in_thinking = False
-                    content = content.replace(think_suffix, "").strip()
-                elif tool_prefix in content and not in_tool_call:
-                    in_tool_call = True
-                elif tool_suffix in content and in_tool_call:
-                    in_tool_call = False
-                    tool_call = json.loads(tool_content.strip().removeprefix(tool_prefix).removesuffix(tool_suffix))
-                    result.tool_calls = [
-                        llm.ToolInput(
-                            tool_name=tool_call["name"],
-                            tool_args=tool_call["arguments"]
-                        )
-                    ]
-
-                    content = None
-
-                result.response = content
+                last_5_tokens.append(content)
+                if len(last_5_tokens) > 5:
+                    last_5_tokens.pop(0)
+
+                potential_block = "".join(last_5_tokens)
 
                 if in_tool_call:
                     tool_content += content
 
+                if think_prefix in potential_block and not in_thinking:
+                    in_thinking = True
+                    last_5_tokens.clear()
+                elif think_suffix in potential_block and in_thinking:
+                    in_thinking = False
+                    content = content.replace(think_suffix, "").strip()
+                elif tool_prefix in potential_block and not in_tool_call:
+                    in_tool_call = True
+                    last_5_tokens.clear()
+                elif tool_suffix in potential_block and in_tool_call:
+                    in_tool_call = False
+                    _LOGGER.debug("Tool content: %s", tool_content)
+                    tool_call, to_say = parse_raw_tool_call(tool_content.strip().removeprefix(tool_prefix).removesuffix(tool_suffix), llm_api)
+
+                    if tool_call:
+                        result.tool_calls = [tool_call]
+                    if to_say:
+                        content = to_say
+                    else:
+                        content = None
+
+                result.response = content
 
             if tool_calls:
-                result.tool_calls = [
-                    llm.ToolInput(
-                        tool_name=str(tool_calls[0]["function"]["name"]),
-                        tool_args=json.loads(tool_calls[0]["function"]["arguments"])
-                    )
-                ]
+                result.tool_calls = [llm.ToolInput(
+                    tool_name=str(tool_call["function"]["name"]),
+                    tool_args=json.loads(tool_call["function"]["arguments"])
+                ) for tool_call in tool_calls ]
 
             if not in_thinking and not in_tool_call:
                 yield result
@@ -436,7 +444,6 @@ class LlamaCppAgent(LocalLLMAgent):
         _LOGGER.debug(f"Options: {self.entry.options}")
 
         # TODO: re-enable the context length check
         # with self.model_lock:
         #     # FIXME: use the high level API so we can use the built-in prompt formatting
         #     input_tokens = self.llm.tokenize(
         #         prompt.encode(), add_bos=False
@@ -444,7 +451,7 @@ class LlamaCppAgent(LocalLLMAgent):
 
         #     context_len = self.entry.options.get(CONF_CONTEXT_LENGTH, DEFAULT_CONTEXT_LENGTH)
         #     if len(input_tokens) >= context_len:
-        #         num_entities = len(self._async_get_exposed_entities()[0])
+        #         num_entities = len(self._async_get_exposed_entities())
         #         context_size = self.entry.options.get(CONF_CONTEXT_LENGTH, DEFAULT_CONTEXT_LENGTH)
         #         self._warn_context_size()
         #         raise Exception(f"The model failed to produce a result because too many devices are exposed ({num_entities} devices) for the context size ({context_size} tokens)!")
@@ -456,11 +463,11 @@ class LlamaCppAgent(LocalLLMAgent):
         messages = get_oai_formatted_messages(conversation)
         tools = None
         if llm_api:
-            tools = get_oai_formatted_tools(llm_api)
+            tools = get_oai_formatted_tools(llm_api, self._async_get_all_exposed_domains())
 
         _LOGGER.debug(f"Generating completion with {len(messages)} messages and {len(tools) if tools else 0} tools...")
 
-        return self._async_generate_completion(self.llm.create_chat_completion(
+        return self._async_generate_completion(llm_api, self.llm.create_chat_completion(
            messages,
            tools=tools,
            temperature=temperature,

@@ -67,6 +67,7 @@ from .const import (
     CONF_REMEMBER_CONVERSATION,
     CONF_REMEMBER_NUM_INTERACTIONS,
     CONF_REMEMBER_CONVERSATION_TIME_MINUTES,
+    CONF_MAX_TOOL_CALL_ITERATIONS,
     CONF_PROMPT_CACHING_ENABLED,
     CONF_PROMPT_CACHING_INTERVAL,
     CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES,
@@ -113,6 +114,7 @@ from .const import (
     DEFAULT_REFRESH_SYSTEM_PROMPT,
     DEFAULT_REMEMBER_CONVERSATION,
     DEFAULT_REMEMBER_NUM_INTERACTIONS,
+    DEFAULT_MAX_TOOL_CALL_ITERATIONS,
     DEFAULT_PROMPT_CACHING_ENABLED,
     DEFAULT_PROMPT_CACHING_INTERVAL,
     DEFAULT_USE_IN_CONTEXT_LEARNING_EXAMPLES,
@@ -856,6 +858,11 @@ def local_llama_config_option_schema(hass: HomeAssistant, options: MappingProxyT
             description={"suggested_value": options.get(CONF_PROMPT)},
             default=options[CONF_PROMPT],
         ): TemplateSelector(),
+        vol.Required(
+            CONF_MAX_TOOL_CALL_ITERATIONS,
+            description={"suggested_value": options.get(CONF_MAX_TOOL_CALL_ITERATIONS)},
+            default=DEFAULT_MAX_TOOL_CALL_ITERATIONS,
+        ): int,
         vol.Required(
             CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES,
             description={"suggested_value": options.get(CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES)},

@@ -161,6 +161,8 @@ CONF_REMEMBER_NUM_INTERACTIONS = "remember_num_interactions"
 DEFAULT_REMEMBER_NUM_INTERACTIONS = 5
 CONF_REMEMBER_CONVERSATION_TIME_MINUTES = "remember_conversation_time_minutes"
 DEFAULT_REMEMBER_CONVERSATION_TIME_MINUTES = 2
+CONF_MAX_TOOL_CALL_ITERATIONS = "max_tool_call_iterations"
+DEFAULT_MAX_TOOL_CALL_ITERATIONS = 3
 CONF_PROMPT_CACHING_ENABLED = "prompt_caching"
 DEFAULT_PROMPT_CACHING_ENABLED = False
 CONF_PROMPT_CACHING_INTERVAL = "prompt_caching_interval"
@@ -226,42 +228,49 @@ OPTIONS_OVERRIDES = {
         CONF_TOOL_CALL_PREFIX: "```homeassistant",
         CONF_TOOL_CALL_SUFFIX: "```",
         CONF_CONTEXT_LENGTH: 131072,
+        CONF_MAX_TOOL_CALL_ITERATIONS: 1,
     },
     "home-3b-v3": {
         CONF_PROMPT: DEFAULT_PROMPT_BASE_LEGACY,
         CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES: False,
         CONF_TOOL_CALL_PREFIX: "```homeassistant",
         CONF_TOOL_CALL_SUFFIX: "```",
+        CONF_MAX_TOOL_CALL_ITERATIONS: 1,
     },
     "home-3b-v2": {
         CONF_PROMPT: DEFAULT_PROMPT_BASE_LEGACY,
         CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES: False,
         CONF_TOOL_CALL_PREFIX: "```homeassistant",
         CONF_TOOL_CALL_SUFFIX: "```",
+        CONF_MAX_TOOL_CALL_ITERATIONS: 1,
     },
     "home-3b-v1": {
         CONF_PROMPT: DEFAULT_PROMPT_BASE_LEGACY,
         CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES: False,
         CONF_TOOL_CALL_PREFIX: "```homeassistant",
         CONF_TOOL_CALL_SUFFIX: "```",
+        CONF_MAX_TOOL_CALL_ITERATIONS: 1,
     },
     "home-1b-v3": {
         CONF_PROMPT: DEFAULT_PROMPT_BASE_LEGACY,
         CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES: False,
         CONF_TOOL_CALL_PREFIX: "```homeassistant",
         CONF_TOOL_CALL_SUFFIX: "```",
+        CONF_MAX_TOOL_CALL_ITERATIONS: 1,
     },
     "home-1b-v2": {
         CONF_PROMPT: DEFAULT_PROMPT_BASE_LEGACY,
         CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES: False,
         CONF_TOOL_CALL_PREFIX: "```homeassistant",
         CONF_TOOL_CALL_SUFFIX: "```",
+        CONF_MAX_TOOL_CALL_ITERATIONS: 1,
     },
     "home-1b-v1": {
         CONF_PROMPT: DEFAULT_PROMPT_BASE_LEGACY,
         CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES: False,
         CONF_TOOL_CALL_PREFIX: "```homeassistant",
         CONF_TOOL_CALL_SUFFIX: "```",
+        CONF_MAX_TOOL_CALL_ITERATIONS: 1,
     },
     "mistral": {
         CONF_PROMPT: DEFAULT_PROMPT_BASE + ICL_NO_SYSTEM_PROMPT_EXTRAS,

@@ -40,6 +40,7 @@ from .const import (
     CONF_REFRESH_SYSTEM_PROMPT,
     CONF_REMEMBER_CONVERSATION,
     CONF_REMEMBER_NUM_INTERACTIONS,
+    CONF_MAX_TOOL_CALL_ITERATIONS,
     CONF_CONTEXT_LENGTH,
     DEFAULT_PROMPT,
     DEFAULT_BACKEND_TYPE,
@@ -50,13 +51,9 @@ from .const import (
     DEFAULT_REFRESH_SYSTEM_PROMPT,
     DEFAULT_REMEMBER_CONVERSATION,
     DEFAULT_REMEMBER_NUM_INTERACTIONS,
+    DEFAULT_MAX_TOOL_CALL_ITERATIONS,
     DEFAULT_CONTEXT_LENGTH,
     DOMAIN,
-    HOME_LLM_API_ID,
-    SERVICE_TOOL_NAME,
-    ALLOWED_SERVICE_CALL_ARGUMENTS,
-    SERVICE_TOOL_ALLOWED_SERVICES,
-    SERVICE_TOOL_ALLOWED_DOMAINS,
     CONF_BACKEND_TYPE,
     DEFAULT_BACKEND_TYPE,
 )
@@ -96,61 +93,6 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry, async_add_e
 
     return True
 
-def parse_raw_tool_call(raw_block: str, tool_name: str, tool_call_id: str, llm_api: llm.APIInstance, user_input: ConversationInput) -> tuple[bool, llm.ToolInput | None, str | None]:
-    parsed_tool_call: dict = json.loads(raw_block)
-
-    if llm_api.api.id == HOME_LLM_API_ID:
-        schema_to_validate = vol.Schema({
-            vol.Required('service'): str,
-            vol.Required('target_device'): str,
-            vol.Optional('rgb_color'): str,
-            vol.Optional('brightness'): vol.Coerce(float),
-            vol.Optional('temperature'): vol.Coerce(float),
-            vol.Optional('humidity'): vol.Coerce(float),
-            vol.Optional('fan_mode'): str,
-            vol.Optional('hvac_mode'): str,
-            vol.Optional('preset_mode'): str,
-            vol.Optional('duration'): str,
-            vol.Optional('item'): str,
-        })
-    else:
-        schema_to_validate = vol.Schema({
-            vol.Required("name"): str,
-            vol.Required("arguments"): dict,
-        })
-
-    try:
-        schema_to_validate(parsed_tool_call)
-    except vol.Error as ex:
-        _LOGGER.info(f"LLM produced an improperly formatted response: {repr(ex)}")
-        return False, None, f"I'm sorry, I didn't produce a correctly formatted tool call! Please see the logs for more info.",
-
-    # try to fix certain arguments
-    args_dict = parsed_tool_call if llm_api.api.id == HOME_LLM_API_ID else parsed_tool_call["arguments"]
-
-    # make sure brightness is 0-255 and not a percentage
-    if "brightness" in args_dict and 0.0 < args_dict["brightness"] <= 1.0:
-        args_dict["brightness"] = int(args_dict["brightness"] * 255)
-
-    # convert string "tuple" to a list for RGB colors
-    if "rgb_color" in args_dict and isinstance(args_dict["rgb_color"], str):
-        args_dict["rgb_color"] = [ int(x) for x in args_dict["rgb_color"][1:-1].split(",") ]
-
-    if llm_api.api.id == HOME_LLM_API_ID:
-        to_say = parsed_tool_call.pop("to_say", "")
-        tool_input = llm.ToolInput(
-            tool_name=SERVICE_TOOL_NAME,
-            tool_args=parsed_tool_call,
-        )
-    else:
-        to_say = ""
-        tool_input = llm.ToolInput(
-            tool_name=parsed_tool_call["name"],
-            tool_args=parsed_tool_call["arguments"],
-        )
-
-    return True, tool_input, to_say
-
 class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
     """Base Local LLM conversation agent."""
 
@@ -271,7 +213,7 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
         )
 
     def _warn_context_size(self):
-        num_entities = len(self._async_get_exposed_entities()[0])
+        num_entities = len(self._async_get_exposed_entities())
         context_size = self.entry.options.get(CONF_CONTEXT_LENGTH, DEFAULT_CONTEXT_LENGTH)
         _LOGGER.error("There were too many entities exposed when attempting to generate a response for " +
                       f"{self.entry.data[CONF_CHAT_MODEL]} and it exceeded the context size for the model. " +
@@ -315,6 +257,7 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
         refresh_system_prompt = self.entry.options.get(CONF_REFRESH_SYSTEM_PROMPT, DEFAULT_REFRESH_SYSTEM_PROMPT)
         remember_conversation = self.entry.options.get(CONF_REMEMBER_CONVERSATION, DEFAULT_REMEMBER_CONVERSATION)
         remember_num_interactions = self.entry.options.get(CONF_REMEMBER_NUM_INTERACTIONS, DEFAULT_REMEMBER_NUM_INTERACTIONS)
+        max_tool_call_iterations = self.entry.options.get(CONF_MAX_TOOL_CALL_ITERATIONS, DEFAULT_MAX_TOOL_CALL_ITERATIONS)
 
         llm_api: llm.APIInstance | None = None
         if self.entry.options.get(CONF_LLM_HASS_API):
@@ -368,9 +311,8 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
         else:
             message_history[0] = system_prompt
 
-        MAX_TOOL_CALL_ITERATIONS = 3 # FIXME: move to config option
         tool_calls: List[Tuple[llm.ToolInput, Any]] = []
-        for _ in range(MAX_TOOL_CALL_ITERATIONS):
+        for _ in range(max_tool_call_iterations):
             try:
                 _LOGGER.debug(message_history)
                 generation_result = await self._async_generate(message_history, user_input, chat_log)
@@ -409,16 +351,28 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
                 content=f"Ran the following tools:\n{tools_str}"
             )
 
-        intent_response.async_set_speech(message_history[-1].content)
+        for i in range(1, len(message_history)):
+            cur_msg = message_history[-1 * i]
+            if isinstance(cur_msg, conversation.AssistantContent) and cur_msg.content:
+                intent_response.async_set_speech(cur_msg.content)
+                break
 
         return ConversationResult(
             response=intent_response, conversation_id=user_input.conversation_id
         )
 
-    def _async_get_exposed_entities(self) -> tuple[dict[str, str], list[str]]:
+    def _async_get_all_exposed_domains(self) -> list[str]:
+        """Gather all exposed domains"""
+        domains = set()
+        for state in self.hass.states.async_all():
+            if async_should_expose(self.hass, CONVERSATION_DOMAIN, state.entity_id):
+                domains.add(state.domain)
+
+        return list(domains)
+
+    def _async_get_exposed_entities(self) -> dict[str, str]:
         """Gather exposed entity states"""
         entity_states = {}
-        domains = set()
         entity_registry = er.async_get(self.hass)
         device_registry = dr.async_get(self.hass)
         area_registry = ar.async_get(self.hass)
@@ -456,9 +410,8 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
                 attributes["area_name"] = area.name
 
             entity_states[state.entity_id] = attributes
-            domains.add(state.domain)
 
-        return entity_states, list(domains)
+        return entity_states
 
     def _generate_icl_examples(self, num_examples, entity_names):
         entity_names = entity_names[:]
@@ -529,7 +482,7 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
 
     def _generate_system_prompt(self, prompt_template: str, llm_api: llm.APIInstance | None) -> str:
         """Generate the system prompt with current entity states"""
-        entities_to_expose, domains = self._async_get_exposed_entities()
+        entities_to_expose = self._async_get_exposed_entities()
 
         extra_attributes_to_expose = self.entry.options \
             .get(CONF_EXTRA_ATTRIBUTES_TO_EXPOSE, DEFAULT_EXTRA_ATTRIBUTES_TO_EXPOSE)
@@ -596,62 +549,9 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
                     "is_alias": True
                 })
-
-        # if llm_api:
-        #     if llm_api.api.id == HOME_LLM_API_ID:
-        #         service_dict = self.hass.services.async_services()
-        #         all_services = []
-        #         scripts_added = False
-        #         for domain in domains:
-        #             if domain not in SERVICE_TOOL_ALLOWED_DOMAINS:
-        #                 continue
-
-        #             # scripts show up as individual services
-        #             if domain == "script" and not scripts_added:
-        #                 all_services.extend([
-        #                     ("script.reload", vol.Schema({}), ""),
-        #                     ("script.turn_on", vol.Schema({}), ""),
-        #                     ("script.turn_off", vol.Schema({}), ""),
-        #                     ("script.toggle", vol.Schema({}), ""),
-        #                 ])
-        #                 scripts_added = True
-        #                 continue
-
-        #             for name, service in service_dict.get(domain, {}).items():
-        #                 if name not in SERVICE_TOOL_ALLOWED_SERVICES:
-        #                     continue
-
-        #                 args = flatten_vol_schema(service.schema)
-        #                 args_to_expose = set(args).intersection(ALLOWED_SERVICE_CALL_ARGUMENTS)
-        #                 service_schema = vol.Schema({
-        #                     vol.Optional(arg): str for arg in args_to_expose
-        #                 })
-
-        #                 all_services.append((f"{domain}.{name}", service_schema, ""))
-
-        #         tools = [
-        #             self._format_tool(*tool)
-        #             for tool in all_services
-        #         ]
-
-        #     else:
-        #         tools = [
-        #             self._format_tool(tool.name, tool.parameters, tool.description)
-        #             for tool in llm_api.tools
-        #         ]
-
-        #     if self.entry.options.get(CONF_TOOL_FORMAT, DEFAULT_TOOL_FORMAT) == TOOL_FORMAT_MINIMAL:
-        #         formatted_tools = ", ".join(tools)
-        #     else:
-        #         formatted_tools = json.dumps(tools)
-        # else:
-        #     tools = ["No tools were provided. If the user requests you interact with a device, tell them you are unable to do so."]
-        #     formatted_tools = tools[0]
-
         render_variables = {
             "devices": devices,
             "formatted_devices": formatted_devices,
             # "tools": tools,
             # "formatted_tools": formatted_tools,
             "response_examples": []
         }

@@ -24,6 +24,11 @@ from voluptuous_openapi import convert
 from .const import (
     INTEGRATION_VERSION,
     EMBEDDED_LLAMA_CPP_PYTHON_VERSION,
+    ALLOWED_SERVICE_CALL_ARGUMENTS,
+    SERVICE_TOOL_ALLOWED_SERVICES,
+    SERVICE_TOOL_ALLOWED_DOMAINS,
+    HOME_LLM_API_ID,
+    SERVICE_TOOL_NAME
 )
 
 from typing import TYPE_CHECKING
@@ -252,15 +257,25 @@ def install_llama_cpp_python(config_dir: str):
 def format_url(*, hostname: str, port: str, ssl: bool, path: str):
     return f"{'https' if ssl else 'http'}://{hostname}{ ':' + port if port else ''}{path}"
 
-def get_oai_formatted_tools(llm_api: llm.APIInstance) -> List[ChatCompletionTool]:
-    result: List[ChatCompletionTool] = [ {
-        "type": "function",
-        "function": {
-            "name": tool.name,
-            "description": tool.description or "",
-            "parameters": convert(tool.parameters, custom_serializer=llm_api.custom_serializer)
-        }
-    } for tool in llm_api.tools ]
+def get_oai_formatted_tools(llm_api: llm.APIInstance, domains: list[str]) -> List[ChatCompletionTool]:
+    if llm_api.api.id == HOME_LLM_API_ID:
+        result: List[ChatCompletionTool] = [ {
+            "type": "function",
+            "function": {
+                "name": tool["name"],
+                "parameters": convert(tool["arguments"], custom_serializer=llm_api.custom_serializer)
+            }
+        } for tool in get_home_llm_tools(llm_api, domains) ]
+
+    else:
+        result: List[ChatCompletionTool] = [ {
+            "type": "function",
+            "function": {
+                "name": tool.name,
+                "description": tool.description or "",
+                "parameters": convert(tool.parameters, custom_serializer=llm_api.custom_serializer)
+            }
+        } for tool in llm_api.tools ]
 
     return result
@@ -307,3 +322,95 @@ def get_oai_formatted_messages(conversation: Sequence[conversation.Content], use
         })
 
     return messages
+
+def get_home_llm_tools(llm_api: llm.APIInstance, domains: list[str]) -> List[Dict[str, Any]]:
+    service_dict = llm_api.api.hass.services.async_services()
+    all_services = []
+    scripts_added = False
+    for domain in domains:
+        if domain not in SERVICE_TOOL_ALLOWED_DOMAINS:
+            continue
+
+        # scripts show up as individual services
+        if domain == "script" and not scripts_added:
+            all_services.extend([
+                ("script.reload", vol.Schema({})),
+                ("script.turn_on", vol.Schema({})),
+                ("script.turn_off", vol.Schema({})),
+                ("script.toggle", vol.Schema({})),
+            ])
+            scripts_added = True
+            continue
+
+        for name, service in service_dict.get(domain, {}).items():
+            if name not in SERVICE_TOOL_ALLOWED_SERVICES:
+                continue
+
+            args = flatten_vol_schema(service.schema)
+            args_to_expose = set(args).intersection(ALLOWED_SERVICE_CALL_ARGUMENTS)
+            service_schema = vol.Schema({
+                vol.Optional(arg): str for arg in args_to_expose
+            })
+
+            all_services.append((f"{domain}.{name}", service_schema))
+
+    tools: List[Dict[str, Any]] = [
+        { "name": service[0], "arguments": service[1] } for service in all_services
+    ]
+
+    return tools
+
+def parse_raw_tool_call(raw_block: str, llm_api: llm.APIInstance) -> tuple[llm.ToolInput | None, str | None]:
+    parsed_tool_call: dict = json.loads(raw_block)
+
+    if llm_api.api.id == HOME_LLM_API_ID:
+        schema_to_validate = vol.Schema({
+            vol.Required('service'): str,
+            vol.Required('target_device'): str,
+            vol.Optional('rgb_color'): str,
+            vol.Optional('brightness'): vol.Coerce(float),
+            vol.Optional('temperature'): vol.Coerce(float),
+            vol.Optional('humidity'): vol.Coerce(float),
+            vol.Optional('fan_mode'): str,
+            vol.Optional('hvac_mode'): str,
+            vol.Optional('preset_mode'): str,
+            vol.Optional('duration'): str,
+            vol.Optional('item'): str,
+        })
+    else:
+        schema_to_validate = vol.Schema({
+            vol.Required("name"): str,
+            vol.Required("arguments"): dict,
+        })
+
+    try:
+        schema_to_validate(parsed_tool_call)
+    except vol.Error as ex:
+        _LOGGER.info(f"LLM produced an improperly formatted response: {repr(ex)}")
+        raise # re-raise exception for now to force the LLM to try again
+
+    # try to fix certain arguments
+    args_dict = parsed_tool_call if llm_api.api.id == HOME_LLM_API_ID else parsed_tool_call["arguments"]
+
+    # make sure brightness is 0-255 and not a percentage
+    if "brightness" in args_dict and 0.0 < args_dict["brightness"] <= 1.0:
+        args_dict["brightness"] = int(args_dict["brightness"] * 255)
+
+    # convert string "tuple" to a list for RGB colors
+    if "rgb_color" in args_dict and isinstance(args_dict["rgb_color"], str):
+        args_dict["rgb_color"] = [ int(x) for x in args_dict["rgb_color"][1:-1].split(",") ]
+
+    if llm_api.api.id == HOME_LLM_API_ID:
+        to_say = parsed_tool_call.pop("to_say", "")
+        tool_input = llm.ToolInput(
+            tool_name=SERVICE_TOOL_NAME,
+            tool_args=parsed_tool_call,
+        )
+    else:
+        to_say = ""
+        tool_input = llm.ToolInput(
+            tool_name=parsed_tool_call["name"],
+            tool_args=parsed_tool_call["arguments"],
+        )
+
+    return tool_input, to_say