backends should all work now
@@ -1,7 +1,7 @@
# TODO
- [x] proper tool calling support
- [ ] fix old GGUFs to support tool calling
- [ ] home assistant component text streaming support
- [x] home assistant component text streaming support
- [ ] new model based on qwen3 0.6b
- [ ] new model based on gemma3 270m
- [ ] support AI task API

@@ -59,6 +59,7 @@
- [x] generic openai responses backend
- [ ] fix and re-upload all compatible old models (+ upload all original safetensors)
- [x] config entry migration function
- [ ] re-write setup guide

## more complicated ideas
- [ ] "context requests"
@@ -36,6 +36,7 @@ from custom_components.llama_conversation.const import (
DEFAULT_REMEMBER_CONVERSATION_TIME_MINUTES,
DEFAULT_GENERIC_OPENAI_PATH,
DEFAULT_ENABLE_LEGACY_TOOL_CALLING,
RECOMMENDED_CHAT_MODELS,
)
from custom_components.llama_conversation.entity import TextGenerationResult, LocalLLMClient
@@ -100,14 +101,18 @@ class GenericOpenAIAPIClient(LocalLLMClient):
if self.api_key:
headers["Authorization"] = f"Bearer {self.api_key}"

session = async_get_clientsession(self.hass)
async with session.get(
f"{self.api_host}/models",
timeout=aiohttp.ClientTimeout(total=5), # quick timeout
headers=headers
) as response:
response.raise_for_status()
models_result = await response.json()
try:
session = async_get_clientsession(self.hass)
async with session.get(
f"{self.api_host}/models",
timeout=aiohttp.ClientTimeout(total=5), # quick timeout
headers=headers
) as response:
response.raise_for_status()
models_result = await response.json()
except:
_LOGGER.exception("Failed to get available models")
return RECOMMENDED_CHAT_MODELS

return [ model["id"] for model in models_result["data"] ]
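The pattern this hunk adopts (query the OpenAI-compatible `/models` endpoint, fall back to a recommended list on any error) can be sketched standalone roughly as below; `fetch_models` and the placeholder model names are illustrative, not the integration's actual code.

```python
# Illustrative sketch: fetch an OpenAI-compatible /models list with aiohttp and
# fall back to a default list on any failure, mirroring the try/except wrapper above.
import aiohttp
import logging

_LOGGER = logging.getLogger(__name__)

RECOMMENDED_CHAT_MODELS = ["llama3.1:8b", "qwen2.5:7b"]  # placeholder defaults


async def fetch_models(session: aiohttp.ClientSession, api_host: str, api_key: str | None = None) -> list[str]:
    headers = {"Authorization": f"Bearer {api_key}"} if api_key else {}
    try:
        async with session.get(
            f"{api_host}/models",
            timeout=aiohttp.ClientTimeout(total=5),  # quick timeout so the config flow stays responsive
            headers=headers,
        ) as response:
            response.raise_for_status()
            models_result = await response.json()
    except Exception:
        _LOGGER.exception("Failed to get available models")
        return RECOMMENDED_CHAT_MODELS
    return [model["id"] for model in models_result["data"]]
```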
@@ -152,7 +157,7 @@ class GenericOpenAIAPIClient(LocalLLMClient):

session = async_get_clientsession(self.hass)

async def anext_token() -> AsyncGenerator[Tuple[Optional[str], Optional[List]], None]:
async def anext_token() -> AsyncGenerator[Tuple[Optional[str], Optional[List[llm.ToolInput]]], None]:
response = None
chunk = None
try:

@@ -172,20 +177,22 @@ class GenericOpenAIAPIClient(LocalLLMClient):
break

if chunk and chunk.strip():
yield self._extract_response(json.loads(chunk), llm_api)
to_say, tool_calls = self._extract_response(json.loads(chunk), llm_api, user_input)
if to_say or tool_calls:
yield to_say, tool_calls
except asyncio.TimeoutError as err:
raise HomeAssistantError("The generation request timed out! Please check your connection settings, increase the timeout in settings, or decrease the number of exposed entities.") from err
except aiohttp.ClientError as err:
raise HomeAssistantError(f"Failed to communicate with the API! {err}") from err

return self._async_parse_completion(llm_api, entity_options, anext_token=anext_token())
return self._async_parse_completion(llm_api, user_input, entity_options, anext_token=anext_token())

def _chat_completion_params(self, entity_options: dict[str, Any]) -> Tuple[str, Dict[str, Any]]:
request_params = {}
endpoint = "/chat/completions"
return endpoint, request_params

def _extract_response(self, response_json: dict, llm_api: llm.APIInstance | None) -> Tuple[Optional[str], Optional[List]]:
def _extract_response(self, response_json: dict, llm_api: llm.APIInstance | None, user_input: conversation.ConversationInput) -> Tuple[Optional[str], Optional[List[llm.ToolInput]]]:
if len(response_json["choices"]) == 0: # finished
return None, None

@@ -200,7 +207,7 @@ class GenericOpenAIAPIClient(LocalLLMClient):
tool_calls = []
for call in choice["delta"]["tool_calls"]:
tool_args, to_say = parse_raw_tool_call(
call["function"], llm_api)
call["function"], llm_api, user_input)

if tool_args:
tool_calls.append(tool_args)
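The streaming path above yields a (text, tool_calls) tuple per chunk. A rough, self-contained sketch of that contract follows; `extract_delta` and the sample chunk are illustrative, not the integration's code.

```python
# Illustrative sketch: the shape of a streamed chat-completions chunk and how a
# (text, tool_calls) tuple can be pulled out of choices[0]["delta"].
import json
from typing import List, Optional, Tuple


def extract_delta(chunk: dict) -> Tuple[Optional[str], Optional[List[dict]]]:
    if not chunk["choices"]:  # an empty choices list signals the stream is finished
        return None, None
    delta = chunk["choices"][0]["delta"]
    return delta.get("content"), delta.get("tool_calls")


raw_line = json.dumps({
    "choices": [{
        "delta": {
            "content": None,
            "tool_calls": [{"function": {"name": "HassTurnOn", "arguments": json.dumps({"name": "kitchen light"})}}],
        },
        "finish_reason": None,
    }]
})

to_say, tool_calls = extract_delta(json.loads(raw_line))
print(to_say, tool_calls)
```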
@@ -18,7 +18,7 @@ from homeassistant.exceptions import ConfigEntryError, HomeAssistantError
from homeassistant.helpers import llm
from homeassistant.helpers.event import async_track_state_change, async_call_later

from custom_components.llama_conversation.utils import install_llama_cpp_python, validate_llama_cpp_python_installation, get_oai_formatted_messages, get_oai_formatted_tools, parse_raw_tool_call
from custom_components.llama_conversation.utils import install_llama_cpp_python, validate_llama_cpp_python_installation, get_oai_formatted_messages, get_oai_formatted_tools
from custom_components.llama_conversation.const import (
CONF_INSTALLED_LLAMACPP_VERSION,
CONF_CHAT_MODEL,

@@ -475,5 +475,5 @@ class LlamaCppClient(LocalLLMClient):
tool_calls = chunk["choices"][0]["delta"].get("tool_calls")
yield content, tool_calls

return self._async_parse_completion(llm_api, entity_options, next_token=next_token())
return self._async_parse_completion(llm_api, user_input, entity_options, next_token=next_token())
@@ -9,10 +9,9 @@ import logging
from typing import Optional, Tuple, Dict, List, Any, AsyncGenerator

from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.components import conversation as conversation
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_HOST, CONF_PORT, CONF_SSL
from homeassistant.exceptions import ConfigEntryNotReady
from homeassistant.helpers.aiohttp_client import async_get_clientsession
from homeassistant.helpers import llm
@@ -113,7 +112,7 @@ class OllamaAPIClient(LocalLLMClient):

return [x["name"] for x in models_result["models"]]

def _extract_response(self, response_json: Dict) -> TextGenerationResult:
def _extract_response(self, response_json: Dict) -> Tuple[Optional[str], Optional[List[llm.ToolInput]]]:
# TODO: this doesn't work because ollama caches prompts and doesn't always return the full prompt length
# context_len = self.entry.options.get(CONF_CONTEXT_LENGTH, DEFAULT_CONTEXT_LENGTH)
# max_tokens = self.entry.options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS)
@@ -128,39 +127,16 @@ class OllamaAPIClient(LocalLLMClient):
_LOGGER.warning("Model response did not end on a stop token (unfinished sentence)")
else:
response = response_json["message"]["content"]
tool_calls = response_json["message"].get("tool_calls")
raw_tool_calls = response_json["message"].get("tool_calls")
tool_calls = [ llm.ToolInput(tool_name=x["function"]["name"], tool_args=x["function"]["arguments"]) for x in raw_tool_calls] if raw_tool_calls else None
stop_reason = response_json.get("done_reason")

return TextGenerationResult(
response=response, tool_calls=tool_calls, stop_reason=stop_reason, response_streamed=True
)

async def _async_generate_with_parameters(self, endpoint: str, request_params: dict[str, Any], headers: dict[str, Any], timeout: int) -> AsyncGenerator[TextGenerationResult, None]:
session = async_get_clientsession(self.hass)
response = None
_LOGGER.debug(f"{response=} {tool_calls=}")

try:
async with session.post(
f"{self.api_host}{endpoint}",
json=request_params,
timeout=aiohttp.ClientTimeout(total=timeout),
headers=headers
) as response:
response.raise_for_status()

while True:
chunk = await response.content.readline()
if not chunk:
break

yield self._extract_response(json.loads(chunk))
except asyncio.TimeoutError:
yield TextGenerationResult(raise_error=True, error_msg="The generation request timed out! Please check your connection settings, increase the timeout in settings, or decrease the number of exposed entities.")
except aiohttp.ClientError as err:
_LOGGER.debug(f"Err was: {err}")
_LOGGER.debug(f"Request was: {request_params}")
_LOGGER.debug(f"Result was: {response}")
yield TextGenerationResult(raise_error=True, error_msg=f"Failed to communicate with the API! {err}")
return response, tool_calls
# return TextGenerationResult(
# response=response, tool_calls=tool_calls, stop_reason=stop_reason, response_streamed=True
# )

def _generate_stream(self, conversation: List[conversation.Content], llm_api: llm.APIInstance | None, user_input: conversation.ConversationInput, entity_options: Dict[str, Any]) -> AsyncGenerator[TextGenerationResult, None]:
model_name = entity_options.get(CONF_CHAT_MODEL, "")
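The Ollama response mapping above can be sketched standalone as follows; `ToolInput` here is a stand-in dataclass for `homeassistant.helpers.llm.ToolInput`, and the sample message is made up.

```python
# Illustrative sketch: how an Ollama /api/chat message maps onto the
# (response text, tool call list) pair the new _extract_response returns.
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class ToolInput:
    tool_name: str
    tool_args: dict[str, Any]


def extract(message: dict) -> tuple[Optional[str], Optional[list[ToolInput]]]:
    response = message.get("content")
    raw_tool_calls = message.get("tool_calls")
    # Ollama already returns arguments as an object, so no json.loads is needed here
    tool_calls = [
        ToolInput(tool_name=x["function"]["name"], tool_args=x["function"]["arguments"])
        for x in raw_tool_calls
    ] if raw_tool_calls else None
    return response, tool_calls


print(extract({
    "role": "assistant",
    "content": "",
    "tool_calls": [{"function": {"name": "HassLightSet", "arguments": {"name": "desk lamp", "brightness": 128}}}],
}))
```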
@@ -195,10 +171,35 @@ class OllamaAPIClient(LocalLLMClient):
request_params["tools"] = get_oai_formatted_tools(llm_api, self._async_get_all_exposed_domains())

endpoint = "/api/chat"
request_params["messages"] = get_oai_formatted_messages(conversation)
request_params["messages"] = get_oai_formatted_messages(conversation, tool_args_to_str=False)

headers = {}
if self.api_key:
headers["Authorization"] = f"Bearer {self.api_key}"

return self._async_generate_with_parameters(endpoint, request_params, headers, timeout)

session = async_get_clientsession(self.hass)

async def anext_token() -> AsyncGenerator[Tuple[Optional[str], Optional[List[llm.ToolInput]]], None]:
response = None
chunk = None
try:
async with session.post(
f"{self.api_host}{endpoint}",
json=request_params,
timeout=aiohttp.ClientTimeout(total=timeout),
headers=headers
) as response:
response.raise_for_status()

while True:
chunk = await response.content.readline()
if not chunk:
break

yield self._extract_response(json.loads(chunk))
except asyncio.TimeoutError as err:
raise HomeAssistantError("The generation request timed out! Please check your connection settings, increase the timeout in settings, or decrease the number of exposed entities.") from err
except aiohttp.ClientError as err:
raise HomeAssistantError(f"Failed to communicate with the API! {err}") from err

return self._async_parse_completion(llm_api, user_input, entity_options, anext_token=anext_token())
@@ -3,7 +3,6 @@ from __future__ import annotations

import logging
import os
from types import MappingProxyType
from typing import Any

import voluptuous as vol
@@ -553,7 +552,12 @@ def local_llama_config_option_schema(
CONF_TOOL_CALL_SUFFIX,
description={"suggested_value": options.get(CONF_TOOL_CALL_SUFFIX)},
default=DEFAULT_TOOL_CALL_SUFFIX,
): str
): str,
vol.Required(
CONF_ENABLE_LEGACY_TOOL_CALLING,
description={"suggested_value": options.get(CONF_ENABLE_LEGACY_TOOL_CALLING)},
default=DEFAULT_ENABLE_LEGACY_TOOL_CALLING
): bool,
}

if backend_type == BACKEND_TYPE_LLAMA_CPP:
@@ -593,7 +597,7 @@ def local_llama_config_option_schema(
CONF_CONTEXT_LENGTH,
description={"suggested_value": options.get(CONF_CONTEXT_LENGTH)},
default=DEFAULT_CONTEXT_LENGTH,
): NumberSelector(NumberSelectorConfig(min=512, max=32768, step=1)),
): NumberSelector(NumberSelectorConfig(min=512, max=1_048_576, step=512)),
vol.Required(
CONF_LLAMACPP_BATCH_SIZE,
description={"suggested_value": options.get(CONF_LLAMACPP_BATCH_SIZE)},
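A minimal sketch of how such a schema entry is assembled with Home Assistant's NumberSelector, assuming a Home Assistant dev environment; the constant names and defaults here are placeholders, not the integration's actual values.

```python
# Illustrative sketch: a voluptuous options-schema entry using the widened
# context-length range from the hunk above.
import voluptuous as vol
from homeassistant.helpers.selector import NumberSelector, NumberSelectorConfig

CONF_CONTEXT_LENGTH = "context_length"   # placeholder constant
DEFAULT_CONTEXT_LENGTH = 8192            # placeholder default


def context_length_schema(options: dict) -> dict:
    return {
        vol.Required(
            CONF_CONTEXT_LENGTH,
            description={"suggested_value": options.get(CONF_CONTEXT_LENGTH)},
            default=DEFAULT_CONTEXT_LENGTH,
        # step=512 keeps the slider usable now that the maximum is 1,048,576 tokens
        ): NumberSelector(NumberSelectorConfig(min=512, max=1_048_576, step=512)),
    }
```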
@@ -631,7 +635,7 @@ def local_llama_config_option_schema(
CONF_CONTEXT_LENGTH,
description={"suggested_value": options.get(CONF_CONTEXT_LENGTH)},
default=DEFAULT_CONTEXT_LENGTH,
): NumberSelector(NumberSelectorConfig(min=512, max=32768, step=1)),
): NumberSelector(NumberSelectorConfig(min=512, max=1_048_576, step=512)),
vol.Required(
CONF_TOP_K,
description={"suggested_value": options.get(CONF_TOP_K)},
@@ -684,11 +688,6 @@ def local_llama_config_option_schema(
description={"suggested_value": options.get(CONF_REQUEST_TIMEOUT)},
default=DEFAULT_REQUEST_TIMEOUT,
): NumberSelector(NumberSelectorConfig(min=5, max=900, step=1, unit_of_measurement=UnitOfTime.SECONDS, mode=NumberSelectorMode.BOX)),
vol.Required(
CONF_ENABLE_LEGACY_TOOL_CALLING,
description={"suggested_value": options.get(CONF_ENABLE_LEGACY_TOOL_CALLING)},
default=DEFAULT_ENABLE_LEGACY_TOOL_CALLING
): bool,
})
elif backend_type in BACKEND_TYPE_GENERIC_OPENAI_RESPONSES:
del result[CONF_REMEMBER_NUM_INTERACTIONS]
@@ -741,14 +740,14 @@ def local_llama_config_option_schema(
description={"suggested_value": options.get(CONF_REQUEST_TIMEOUT)},
default=DEFAULT_REQUEST_TIMEOUT,
): NumberSelector(NumberSelectorConfig(min=5, max=900, step=1, unit_of_measurement=UnitOfTime.SECONDS, mode=NumberSelectorMode.BOX)),
vol.Required(
CONF_ENABLE_LEGACY_TOOL_CALLING,
description={"suggested_value": options.get(CONF_ENABLE_LEGACY_TOOL_CALLING)},
default=DEFAULT_ENABLE_LEGACY_TOOL_CALLING
): bool,
})
elif backend_type == BACKEND_TYPE_OLLAMA:
result.update({
vol.Required(
CONF_CONTEXT_LENGTH,
description={"suggested_value": options.get(CONF_CONTEXT_LENGTH)},
default=DEFAULT_CONTEXT_LENGTH,
): NumberSelector(NumberSelectorConfig(min=512, max=1_048_576, step=512)),
vol.Required(
CONF_TOP_K,
description={"suggested_value": options.get(CONF_TOP_K)},
@@ -909,7 +908,7 @@ class LocalLLMSubentryFlowHandler(ConfigSubentryFlow):
description_placeholders = {}
entry = self._get_entry()

backend_type = entry.options.get(CONF_BACKEND_TYPE, DEFAULT_BACKEND_TYPE)
backend_type = entry.data[CONF_BACKEND_TYPE]
if backend_type == BACKEND_TYPE_LLAMA_CPP:
schema = STEP_LOCAL_MODEL_SELECTION_DATA_SCHEMA()
else:
@@ -20,7 +20,7 @@ from homeassistant.helpers import intent, template, entity_registry as er, llm,
area_registry as ar, device_registry as dr, entity
from homeassistant.util import color

from .utils import closest_color, parse_raw_tool_call, flatten_vol_schema
from .utils import closest_color, parse_raw_tool_call, flatten_vol_schema, MalformedToolCallException
from .const import (
CONF_CHAT_MODEL,
CONF_SELECTED_LANGUAGE,
@@ -261,35 +261,51 @@ class LocalLLMClient:
message_history[0] = system_prompt

tool_calls: List[Tuple[llm.ToolInput, Any]] = []
for _ in range(max_tool_call_iterations):
try:
_LOGGER.debug(message_history)
generation_result = await self._async_generate(message_history, user_input, chat_log, entity_options)
except Exception as err:
_LOGGER.exception("There was a problem talking to the backend")
for idx in range(max_tool_call_iterations):
generation_result = await self._async_generate(message_history, user_input, chat_log, entity_options)

last_generation_had_tool_calls = False
while True:
try:
message = await anext(generation_result)
message_history.append(message)
if message.role == "assistant":
if message.tool_calls and len(message.tool_calls) > 0:
last_generation_had_tool_calls = True
else:
last_generation_had_tool_calls = False
except StopAsyncIteration:
break
except MalformedToolCallException as err:
message_history.extend(err.as_tool_messages())
last_generation_had_tool_calls = True
except Exception as err:
_LOGGER.exception("There was a problem talking to the backend")
intent_response = intent.IntentResponse(language=user_input.language)
intent_response.async_set_error(
intent.IntentResponseErrorCode.FAILED_TO_HANDLE,
f"Sorry, there was a problem talking to the backend: {repr(err)}",
)
return ConversationResult(
response=intent_response, conversation_id=user_input.conversation_id
)

# If not multi-turn, break after first tool call
# also break if no tool calls were made
if not last_generation_had_tool_calls:
break

# return an error if we run out of attempts without succeeding
if idx == max_tool_call_iterations - 1:
intent_response = intent.IntentResponse(language=user_input.language)
intent_response.async_set_error(
intent.IntentResponseErrorCode.FAILED_TO_HANDLE,
f"Sorry, there was a problem talking to the backend: {repr(err)}",
f"Sorry, I ran out of attempts to handle your request",
)
return ConversationResult(
response=intent_response, conversation_id=user_input.conversation_id
)

last_message_had_tool_calls = False
async for message in generation_result:
message_history.append(message)
if message.role == "assistant":
if message.tool_calls and len(message.tool_calls) > 0:
last_message_had_tool_calls = True
else:
last_message_had_tool_calls = False

# If not multi-turn, break after first tool call
# also break if no tool calls were made
if not last_message_had_tool_calls:
break

# generate intent response to Home Assistant
intent_response = intent.IntentResponse(language=user_input.language)
if len(tool_calls) > 0:
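The control flow this hunk moves to can be sketched standalone: a bounded outer loop over generation attempts, an inner loop that drains the async generator, and malformed tool calls fed back into the history so the model can retry. All names below are stand-ins for the integration's classes, and the fake generator simply fails once and then succeeds.

```python
# Illustrative sketch of the bounded tool-call retry loop.
import asyncio
from dataclasses import dataclass, field


@dataclass
class Message:
    role: str
    content: str = ""
    tool_calls: list = field(default_factory=list)


class MalformedToolCall(Exception):
    def as_tool_messages(self) -> list[Message]:
        return [Message(role="tool_result", content="error: tool call was not valid JSON")]


async def fake_generate(history: list[Message]):
    # First attempt produces a malformed tool call, second attempt succeeds.
    if not any(m.role == "tool_result" for m in history):
        raise MalformedToolCall()
    yield Message(role="assistant", content="The light is on.")


async def converse(max_tool_call_iterations: int = 3) -> list[Message]:
    history: list[Message] = [Message(role="user", content="turn on the light")]
    for idx in range(max_tool_call_iterations):
        generation = fake_generate(history)
        had_tool_calls = False
        while True:
            try:
                message = await anext(generation)
            except StopAsyncIteration:
                break
            except MalformedToolCall as err:
                history.extend(err.as_tool_messages())
                had_tool_calls = True  # force another iteration so the model can retry
                break
            history.append(message)
            had_tool_calls = bool(message.tool_calls)
        if not had_tool_calls:
            break
    return history


print(asyncio.run(converse()))
```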
@@ -300,18 +316,25 @@ class LocalLLMClient:
content=f"Ran the following tools:\n{tools_str}"
)

has_speech = False
for i in range(1, len(message_history)):
cur_msg = message_history[-1 * i]
if isinstance(cur_msg, conversation.AssistantContent) and cur_msg.content:
intent_response.async_set_speech(cur_msg.content)
has_speech = True
break

if not has_speech:
intent_response.async_set_speech("I don't have anything to say right now")
_LOGGER.debug(message_history)

return ConversationResult(
response=intent_response, conversation_id=user_input.conversation_id
)

async def _async_parse_completion(
self, llm_api: llm.APIInstance | None,
user_input: ConversationInput,
entity_options: Dict[str, Any],
next_token: Optional[Generator[Tuple[Optional[str], Optional[List]]]] = None,
anext_token: Optional[AsyncGenerator[Tuple[Optional[str], Optional[List]]]] = None,
@@ -378,7 +401,7 @@ class LocalLLMClient:
if not llm_api:
_LOGGER.warning("Model attempted to call a tool but no LLM API was provided, ignoring tool calls")
else:
tool_call, to_say = parse_raw_tool_call(tool_content.strip().removeprefix(tool_prefix).removesuffix(tool_suffix), llm_api)
tool_call, to_say = parse_raw_tool_call(tool_content.strip().removeprefix(tool_prefix).removesuffix(tool_suffix), llm_api, user_input)
_LOGGER.debug("Tool call parsed: %s", tool_call)

if tool_call:

@@ -396,11 +419,14 @@ class LocalLLMClient:
else:
result.tool_calls = []
for raw_tool_call in tool_calls:
tool_input, to_say = parse_raw_tool_call(raw_tool_call["function"], llm_api)
if tool_input:
result.tool_calls.append(tool_input)
if to_say:
result.response = to_say
if isinstance(raw_tool_call, llm.ToolInput):
result.tool_calls.append(raw_tool_call)
else:
tool_input, to_say = parse_raw_tool_call(raw_tool_call["function"], llm_api, user_input)
if tool_input:
result.tool_calls.append(tool_input)
if to_say:
result.response = to_say

if not in_thinking and not in_tool_call:
yield result
@@ -51,6 +51,24 @@ class MissingQuantizationException(Exception):
self.missing_quant = missing_quant
self.available_quants = available_quants

class MalformedToolCallException(Exception):
def __init__(self, agent_id: str, tool_call_id: str, tool_name: str, tool_args: str, error_msg: str):
self.agent_id = agent_id
self.tool_call_id = tool_call_id
self.tool_name = tool_name
self.tool_args = tool_args
self.error_msg = error_msg

def as_tool_messages(self) -> Sequence[conversation.Content]:
return [
conversation.AssistantContent(
self.agent_id, tool_calls=[llm.ToolInput(self.tool_name, {})]
),
conversation.ToolResultContent(
self.agent_id, self.tool_call_id, self.tool_name,
{"error": f"Error occurred calling tool with args='{self.tool_args}': {self.error_msg}" }
)]

def closest_color(requested_color):
min_colors = {}
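To show what `as_tool_messages` is for, here is a rough standalone sketch: the unparseable tool call is packaged as an assistant tool-call message plus a tool-result message carrying the error text, so the failure lands back in the chat history for the next generation attempt. Plain dicts stand in for `conversation.AssistantContent` and `conversation.ToolResultContent`.

```python
# Illustrative sketch: converting a malformed tool call into feedback messages.
class MalformedToolCall(Exception):
    def __init__(self, agent_id: str, tool_call_id: str, tool_name: str, tool_args: str, error_msg: str):
        self.agent_id = agent_id
        self.tool_call_id = tool_call_id
        self.tool_name = tool_name
        self.tool_args = tool_args
        self.error_msg = error_msg

    def as_tool_messages(self) -> list[dict]:
        return [
            {"role": "assistant", "agent_id": self.agent_id,
             "tool_calls": [{"name": self.tool_name, "arguments": {}}]},
            {"role": "tool_result", "agent_id": self.agent_id, "tool_call_id": self.tool_call_id,
             "tool_name": self.tool_name,
             "result": {"error": f"Error occurred calling tool with args='{self.tool_args}': {self.error_msg}"}},
        ]


history: list[dict] = []
try:
    raise MalformedToolCall("conversation.local_llm", "", "unknown", "{not json", "Tool call was not properly formatted")
except MalformedToolCall as err:
    history.extend(err.as_tool_messages())

print(history)
```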
@@ -280,7 +298,7 @@ def get_oai_formatted_tools(llm_api: llm.APIInstance, domains: list[str]) -> Lis

return result

def get_oai_formatted_messages(conversation: Sequence[conversation.Content], user_content_as_list: bool = False) -> List[ChatCompletionRequestMessage]:
def get_oai_formatted_messages(conversation: Sequence[conversation.Content], user_content_as_list: bool = False, tool_args_to_str: bool = True) -> List[ChatCompletionRequestMessage]:
messages: List[ChatCompletionRequestMessage] = []
for message in conversation:
if message.role == "system":

@@ -309,7 +327,7 @@ def get_oai_formatted_messages(conversation: Sequence[conversation.Content], use
"type" : "function",
"id": t.id,
"function": {
"arguments": json.dumps(t.tool_args),
"arguments": cast(str, json.dumps(t.tool_args) if tool_args_to_str else t.tool_args),
"name": t.tool_name,
}
} for t in message.tool_calls
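The new `tool_args_to_str` flag exists because OpenAI-style chat completions expect tool-call arguments as a JSON string, while Ollama's /api/chat takes them as a plain object. A small illustrative sketch of the difference follows; `format_tool_call` is a made-up helper, not the integration's API.

```python
# Illustrative sketch: the two serializations of tool-call arguments.
import json


def format_tool_call(name: str, args: dict, call_id: str, tool_args_to_str: bool = True) -> dict:
    return {
        "type": "function",
        "id": call_id,
        "function": {
            "name": name,
            # json.dumps for OpenAI-style backends, pass-through for Ollama
            "arguments": json.dumps(args) if tool_args_to_str else args,
        },
    }


args = {"name": "kitchen light", "brightness": 128}
print(format_tool_call("HassLightSet", args, "call_0"))                          # arguments as a JSON string
print(format_tool_call("HassLightSet", args, "call_0", tool_args_to_str=False))  # arguments as a dict
```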
@@ -362,7 +380,7 @@ def get_home_llm_tools(llm_api: llm.APIInstance, domains: list[str]) -> List[Dic

return tools

def parse_raw_tool_call(raw_block: str | dict, llm_api: llm.APIInstance) -> tuple[llm.ToolInput | None, str | None]:
def parse_raw_tool_call(raw_block: str | dict, llm_api: llm.APIInstance, user_input: conversation.ConversationInput) -> tuple[llm.ToolInput | None, str | None]:
if isinstance(raw_block, dict):
parsed_tool_call = raw_block
else:

@@ -385,20 +403,27 @@ def parse_raw_tool_call(raw_block: str | dict, llm_api: llm.APIInstance) -> tupl
else:
schema_to_validate = vol.Schema({
vol.Required("name"): str,
vol.Required("arguments"): str | dict,
vol.Required("arguments"): vol.Union(str, dict),
})

try:
schema_to_validate(parsed_tool_call)
except vol.Error as ex:
_LOGGER.info(f"LLM produced an improperly formatted response: {repr(ex)}")
raise ex # re-raise exception for now to force the LLM to try again
raise MalformedToolCallException(user_input.agent_id, "", "unknown", str(raw_block), "Tool call was not properly formatted")

# try to fix certain arguments
args_dict = parsed_tool_call if llm_api.api.id == HOME_LLM_API_ID else parsed_tool_call["arguments"]
tool_name = parsed_tool_call.get("name", parsed_tool_call.get("service", ""))

if isinstance(args_dict, str):
args_dict = json.loads(args_dict)
if not args_dict.strip():
args_dict = {} # don't attempt to parse empty arguments
else:
try:
args_dict = json.loads(args_dict)
except json.JSONDecodeError:
raise MalformedToolCallException(user_input.agent_id, "", tool_name, str(args_dict), "Tool arguments were not properly formatted JSON")

# make sure brightness is 0-255 and not a percentage
if "brightness" in args_dict and 0.0 < args_dict["brightness"] <= 1.0:
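Putting the validation and argument fix-ups together, a rough standalone sketch follows. `vol.Any(str, dict)` is used here in place of the diff's `vol.Union(str, dict)` to accept arguments as either a JSON string or a dict, `MalformedToolCall` stands in for `MalformedToolCallException`, and the brightness rescaling is an assumed completion of the truncated final line above.

```python
# Illustrative sketch: validate a raw tool call, decode string arguments, and
# normalize a fractional brightness value.
import json
import voluptuous as vol


class MalformedToolCall(Exception):
    pass


TOOL_CALL_SCHEMA = vol.Schema({
    vol.Required("name"): str,
    vol.Required("arguments"): vol.Any(str, dict),
})


def normalize_tool_call(parsed_tool_call: dict) -> tuple[str, dict]:
    try:
        TOOL_CALL_SCHEMA(parsed_tool_call)
    except vol.Error:
        raise MalformedToolCall("Tool call was not properly formatted")

    tool_name = parsed_tool_call.get("name", parsed_tool_call.get("service", ""))
    args = parsed_tool_call["arguments"]

    if isinstance(args, str):
        if not args.strip():
            args = {}  # don't attempt to parse empty arguments
        else:
            try:
                args = json.loads(args)
            except json.JSONDecodeError:
                raise MalformedToolCall("Tool arguments were not properly formatted JSON")

    # assumed fix-up: rescale a fractional brightness to the 0-255 range
    if "brightness" in args and 0.0 < args["brightness"] <= 1.0:
        args["brightness"] = round(args["brightness"] * 255)

    return tool_name, args


print(normalize_tool_call({"name": "HassLightSet", "arguments": json.dumps({"name": "desk lamp", "brightness": 0.5})}))
```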
@@ -7,9 +7,6 @@ There are multiple backends to choose for running the model that the Home Assist
|-----------------------------------------------|----------------------------------------------------------------------------------------------------------------------------------------------|-----------------|
| LLM API | This is the set of tools that are provided to the LLM. Use Assist for the built-in API. If you are using Home-LLM v1, v2, or v3, then select the dedicated API | |
| System Prompt | [see here](./Model%20Prompting.md) | |
| Prompt Format | The format for the context of the model | |
| Tool Format | The format of the tools that are provided to the model. Full, Reduced, or Minimal | |
| Multi-Turn Tool Use | Enable this if the model you are using expects to receive the result from the tool call before responding to the user | |
| Maximum tokens to return in response | Limits the number of tokens that can be produced by each model response | 512 |
| Additional attribute to expose in the context | Extra attributes that will be exposed to the model via the `{{ devices }}` template variable | |
| Arguments allowed to be passed to service calls | Any arguments not listed here will be filtered out of service calls. Used to restrict the model from modifying certain parts of your home. | |
@@ -65,7 +62,6 @@ For details about the sampling parameters, see here: https://github.com/oobaboog
| Option Name | Description | Suggested Value |
|----------------------------------|----------------------------------------------------------------------------------------------------------------------------------------------------|-------------------------------------------------|
| Request Timeout | The maximum time in seconds that the integration will wait for a response from the remote server | 90 (higher if running on low resource hardware) |
| Use chat completions endpoint | If set, tells text-generation-webui to format the prompt instead of this extension. Prompt Format set here will not apply if this is enabled | |
| Generation Preset/Character Name | The preset or character name to pass to the backend. If none is provided then the settings that are currently selected in the UI will be applied | |
| Chat Mode | [see here](https://github.com/oobabooga/text-generation-webui/wiki/01-%E2%80%90-Chat-Tab#mode) | Instruct |
| Top K | Sampling parameter; see above link | 40 |
@@ -80,7 +76,6 @@ For details about the sampling parameters, see here: https://github.com/oobaboog
|-------------------------------|----------------------------------------------------------------------------------------------------------------------------------|-------------------------------------------------|
| Request Timeout | The maximum time in seconds that the integration will wait for a response from the remote server | 90 (higher if running on low resource hardware) |
| Keep Alive/Inactivity Timeout | The duration in minutes to keep the model loaded after each request. Set to a negative value to keep loaded forever | 30m |
| Use chat completions endpoint | If set, tells Ollama to format the prompt instead of this extension. Prompt Format set here will not apply if this is enabled | |
| JSON Mode | Restricts the model to only output valid JSON objects. Enable this if you are using ICL and are getting invalid JSON responses. | True |
| Top K | Sampling parameter; see above link | 40 |
| Top P | Sampling parameter; see above link | 1.0 |
@@ -92,6 +87,5 @@ For details about the sampling parameters, see here: https://github.com/oobaboog
| Option Name | Description | Suggested Value |
|-------------------------------|----------------------------------------------------------------------------------------------------|-------------------------------------------------|
| Request Timeout | The maximum time in seconds that the integration will wait for a response from the remote server | 90 (higher if running on low resource hardware) |
| Use chat completions endpoint | Flag to use `/v1/chat/completions` as the remote endpoint instead of `/v1/completions` | Backend Dependent |
| Top P | Sampling parameter; see above link | 1.0 |
| Temperature | Sampling parameter; see above link | 0.1 |
@@ -133,19 +133,3 @@ Vous êtes « Al », un assistant IA utile qui contrôle les appareils d'une m
Eres 'Al', un útil asistente de IA que controla los dispositivos de una casa. Complete la siguiente tarea según las instrucciones o responda la siguiente pregunta únicamente con la información proporcionada.
```
-->

## Prompt Format
On top of the system prompt, there is also a prompt "template" or prompt "format" that defines how text is passed to the model so that it follows its instruction fine-tuning. The prompt format should match the one specified by the model to achieve optimal results (see the ChatML example after the list below).

Currently supported prompt formats are:
1. ChatML
2. Vicuna
3. Alpaca
4. Mistral
5. Zephyr w/ eos token `<|endoftext|>`
6. Zephyr w/ eos token `</s>`
7. Zephyr w/ eos token `<|end|>`
8. Llama 3
9. Command-R
10. None (useful for foundation models)