Mirror of https://github.com/acon96/home-llm.git (synced 2026-01-09 13:48:05 -05:00)
remove intermediate dict format and pass around home assistant model object
Commit da0a0e4dbc, committed by Alex O'Connell (parent 53052af641)
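
In short, the agents stop passing an intermediate [{"role": ..., "message": ...}] dict format between methods and instead hand Home Assistant's own conversation.Content objects through the pipeline, converting to the OpenAI wire format only at the edge. A rough sketch of the shape change (illustrative only, not part of the commit; the example strings are invented):

from homeassistant.components import conversation

# Before: an intermediate dict format was built and passed around
old_history = [
    {"role": "system", "message": "You are a voice assistant."},
    {"role": "user", "message": "Turn on the kitchen light."},
]

# After: Home Assistant model objects are passed directly; conversion to the
# OpenAI message format happens in utils.get_oai_formatted_messages() (added below)
new_history = [
    conversation.SystemContent(content="You are a voice assistant."),
    conversation.UserContent(content="Turn on the kitchen light."),
]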
@@ -1,18 +1,20 @@
"""Defines the OpenAI API compatible agents"""
from __future__ import annotations
import json

import aiohttp
import asyncio
import datetime
import logging
from typing import List, Dict, Tuple, Optional, Any
from typing import List, Dict, Tuple, AsyncGenerator, Any

from homeassistant.components import conversation as conversation
from homeassistant.components import conversation
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_HOST, CONF_PORT, CONF_SSL
from homeassistant.helpers.aiohttp_client import async_get_clientsession
from homeassistant.helpers import llm

from custom_components.llama_conversation.utils import format_url
from custom_components.llama_conversation.utils import format_url, get_oai_formatted_messages, get_oai_formatted_tools
from custom_components.llama_conversation.const import (
    CONF_CHAT_MODEL,
    CONF_MAX_TOKENS,
@@ -30,7 +32,6 @@ from custom_components.llama_conversation.const import (
    DEFAULT_REQUEST_TIMEOUT,
    DEFAULT_REMEMBER_CONVERSATION,
    DEFAULT_REMEMBER_CONVERSATION_TIME_MINUTES,
    DEFAULT_REMOTE_USE_CHAT_ENDPOINT,
    DEFAULT_GENERIC_OPENAI_PATH,
)
from custom_components.llama_conversation.conversation import LocalLLMAgent, TextGenerationResult
@@ -50,10 +51,10 @@ class BaseOpenAICompatibleAPIAgent(LocalLLMAgent):
            path=""
        )

        self.api_key = entry.data.get(CONF_OPENAI_API_KEY)
        self.model_name = entry.data.get(CONF_CHAT_MODEL)
        self.api_key = entry.data.get(CONF_OPENAI_API_KEY, "")
        self.model_name = entry.data.get(CONF_CHAT_MODEL, "")

    async def _async_generate_with_parameters(self, endpoint: str, additional_params: dict) -> TextGenerationResult:
    async def _async_generate_with_parameters(self, endpoint: str, stream: bool, additional_params: dict):
        """Generate a response using the OpenAI-compatible API"""

        max_tokens = self.entry.options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS)
@@ -63,6 +64,7 @@ class BaseOpenAICompatibleAPIAgent(LocalLLMAgent):

        request_params = {
            "model": self.model_name,
            "stream": stream,
            "max_tokens": max_tokens,
            "temperature": temperature,
            "top_p": top_p,
@@ -84,18 +86,20 @@ class BaseOpenAICompatibleAPIAgent(LocalLLMAgent):
                headers=headers
            ) as response:
                response.raise_for_status()
                result = await response.json()
                if stream:
                    async for line_bytes in response.content:
                        chunk = line_bytes.decode("utf-8").strip()
                        yield self._extract_response(json.loads(chunk))
                else:
                    response_json = await response.json()
                    yield self._extract_response(response_json)
        except asyncio.TimeoutError:
            return TextGenerationResult(raise_error=True, error_msg="The generation request timed out! Please check your connection settings, increase the timeout in settings, or decrease the number of exposed entities.")
            yield TextGenerationResult(raise_error=True, error_msg="The generation request timed out! Please check your connection settings, increase the timeout in settings, or decrease the number of exposed entities.")
        except aiohttp.ClientError as err:
            _LOGGER.debug(f"Err was: {err}")
            _LOGGER.debug(f"Request was: {request_params}")
            _LOGGER.debug(f"Result was: {response}")
            return TextGenerationResult(raise_error=True, error_msg=f"Failed to communicate with the API! {err}")

        _LOGGER.debug(result)

        return self._extract_response(result)
            yield TextGenerationResult(raise_error=True, error_msg=f"Failed to communicate with the API! {err}")

    def _extract_response(self, response_json: dict) -> TextGenerationResult:
        raise NotImplementedError("Subclasses must implement _extract_response()")
@@ -103,24 +107,6 @@ class BaseOpenAICompatibleAPIAgent(LocalLLMAgent):
class GenericOpenAIAPIAgent(BaseOpenAICompatibleAPIAgent):
    """Implements the OpenAPI-compatible text completion and chat completion API backends."""

    def _chat_completion_params(self, conversation: List[Dict[str, str]]) -> Tuple[str, Dict[str, Any]]:
        request_params = {}
        api_base_path = self.entry.data.get(CONF_GENERIC_OPENAI_PATH, DEFAULT_GENERIC_OPENAI_PATH)

        endpoint = f"/{api_base_path}/chat/completions"
        request_params["messages"] = [ { "role": x["role"], "content": x["message"] } for x in conversation ]

        return endpoint, request_params

    def _completion_params(self, conversation: List[Dict[str, str]]) -> Tuple[str, Dict[str, Any]]:
        request_params = {}
        api_base_path = self.entry.data.get(CONF_GENERIC_OPENAI_PATH, DEFAULT_GENERIC_OPENAI_PATH)

        endpoint = f"/{api_base_path}/completions"
        request_params["prompt"] = self._format_prompt(conversation)

        return endpoint, request_params

    def _extract_response(self, response_json: dict) -> TextGenerationResult:
        choice = response_json["choices"][0]
        if response_json["object"] == "chat.completion":
@@ -140,26 +126,26 @@ class GenericOpenAIAPIAgent(BaseOpenAICompatibleAPIAgent):
        return TextGenerationResult(
            response=response_text,
            stop_reason=choice["finish_reason"],
            response_streamed=streamed
            response_streamed=streamed,
        )

    async def _async_generate(self, conversation: List[Dict[str, str]]) -> TextGenerationResult:
        use_chat_api = self.entry.options.get(CONF_REMOTE_USE_CHAT_ENDPOINT, DEFAULT_REMOTE_USE_CHAT_ENDPOINT)
    def _generate_stream(self, conversation: List[conversation.Content], llm_api: llm.APIInstance | None) -> AsyncGenerator[TextGenerationResult, None]:
        request_params = {}
        api_base_path = self.entry.data.get(CONF_GENERIC_OPENAI_PATH, DEFAULT_GENERIC_OPENAI_PATH)

        if use_chat_api:
            endpoint, additional_params = self._chat_completion_params(conversation)
        else:
            endpoint, additional_params = self._completion_params(conversation)
        endpoint = f"/{api_base_path}/chat/completions"
        request_params["messages"] = get_oai_formatted_messages(conversation)

        result = await self._async_generate_with_parameters(endpoint, additional_params)
        if llm_api:
            request_params["tools"] = get_oai_formatted_tools(llm_api)

        return result
        return self._async_generate_with_parameters(endpoint, True, request_params)

class GenericOpenAIResponsesAPIAgent(BaseOpenAICompatibleAPIAgent):
    """Implements the OpenAPI-compatible Responses API backend."""

    _last_response_id: str | None = None
    _last_response_id_time: datetime.datetime = None
    _last_response_id_time: datetime.datetime | None = None

    def _responses_params(self, conversation: List[Dict[str, str]]) -> Tuple[str, Dict[str, Any]]:
        request_params = {}
@@ -254,8 +240,8 @@ class GenericOpenAIResponsesAPIAgent(BaseOpenAICompatibleAPIAgent):

        return to_return

    async def _async_generate(self, conversation: List[Dict[str, str]]) -> TextGenerationResult:
    async def _async_generate(self, conv: List[Dict[str, str]], user_input: conversation.ConversationInput, chat_log: conversation.chat_log.ChatLog):
        """Generate a response using the OpenAI-compatible Responses API"""

        endpoint, additional_params = self._responses_params(conversation)
        return await self._async_generate_with_parameters(endpoint, additional_params)
        endpoint, additional_params = self._responses_params(conv)
        return self._async_generate_with_parameters(endpoint, additional_params)
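
For orientation, the chat-completions request body assembled by the new _generate_stream()/_async_generate_with_parameters() pair now has roughly this shape (an illustrative sketch only, not part of the commit; the model name, sampling values, message text, and tool are invented stand-ins):

# Sketch of the streaming chat-completions payload built above (invented values).
request_params = {
    "model": "Home-3B-v3",   # self.model_name (example value)
    "stream": True,          # the new _generate_stream() always requests a stream
    "max_tokens": 512,
    "temperature": 0.7,
    "top_p": 1.0,
    "messages": [            # produced by get_oai_formatted_messages(conversation)
        {"role": "system", "content": "You are a voice assistant."},
        {"role": "user", "content": [{"type": "text", "text": "Turn on the kitchen light."}]},
    ],
    "tools": [               # produced by get_oai_formatted_tools(llm_api) when an API is exposed
        {
            "type": "function",
            "function": {
                "name": "HassTurnOn",
                "description": "Turns on a device or entity",
                "parameters": {"type": "object", "properties": {"name": {"type": "string"}}},
            },
        },
    ],
}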
@@ -6,7 +6,7 @@ import logging
import os
import threading
import time
from typing import Any, Callable, Generator, Optional, List, Dict, AsyncIterable
from typing import Any, Callable, Generator, Optional, List, Dict, AsyncIterable, AsyncGenerator, Sequence

from homeassistant.components import conversation as conversation
from homeassistant.components.conversation.const import DOMAIN as CONVERSATION_DOMAIN
@@ -18,7 +18,7 @@ from homeassistant.exceptions import ConfigEntryError, HomeAssistantError
from homeassistant.helpers import llm
from homeassistant.helpers.event import async_track_state_change, async_call_later

from custom_components.llama_conversation.utils import install_llama_cpp_python, validate_llama_cpp_python_installation
from custom_components.llama_conversation.utils import install_llama_cpp_python, validate_llama_cpp_python_installation, get_oai_formatted_messages, get_oai_formatted_tools
from custom_components.llama_conversation.const import (
    CONF_MAX_TOKENS,
    CONF_PROMPT,
@@ -65,6 +65,7 @@ else:

_LOGGER = logging.getLogger(__name__)


class LlamaCppAgent(LocalLLMAgent):
    model_path: str
    llm: LlamaType
@@ -351,10 +352,20 @@ class LlamaCppAgent(LocalLLMAgent):
            refresh_delay = self.entry.options.get(CONF_PROMPT_CACHING_INTERVAL, DEFAULT_PROMPT_CACHING_INTERVAL)
            async_call_later(self.hass, float(refresh_delay), refresh_if_requested)

    async def _async_generate_completion(self, chat_completion) -> AsyncGenerator[TextGenerationResult, None]:
        for token in chat_completion:
            if isinstance(token, str):
                yield TextGenerationResult(
                    response=token,
                    response_streamed=True
                )
            else:
                token["choices"][0]["delta"].get("tool_calls")

    def _generate_stream(self, conversation: List[Dict[str, str]], user_input: conversation.ConversationInput, chat_log: conversation.ChatLog) -> TextGenerationResult:
        prompt = self._format_prompt(conversation)


    def _generate_stream(self, conversation: List[conversation.Content], llm_api: llm.APIInstance | None) -> AsyncGenerator[TextGenerationResult, None]:
        """Async generator that yields TextGenerationResult as tokens are produced."""
        # prompt = self._format_prompt(conversation)
        max_tokens = self.entry.options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS)
        temperature = self.entry.options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE)
        top_k = int(self.entry.options.get(CONF_TOP_K, DEFAULT_TOP_K))
@@ -364,49 +375,37 @@ class LlamaCppAgent(LocalLLMAgent):

        _LOGGER.debug(f"Options: {self.entry.options}")

        with self.model_lock:
            input_tokens = self.llm.tokenize(
                prompt.encode(), add_bos=False
            )
        # with self.model_lock:
        #     # FIXME: use the high level API so we can use the built-in prompt formatting
        #     input_tokens = self.llm.tokenize(
        #         prompt.encode(), add_bos=False
        #     )

        context_len = self.entry.options.get(CONF_CONTEXT_LENGTH, DEFAULT_CONTEXT_LENGTH)
        if len(input_tokens) >= context_len:
            num_entities = len(self._async_get_exposed_entities()[0])
            context_size = self.entry.options.get(CONF_CONTEXT_LENGTH, DEFAULT_CONTEXT_LENGTH)
            self._warn_context_size()
            raise Exception(f"The model failed to produce a result because too many devices are exposed ({num_entities} devices) for the context size ({context_size} tokens)!")
        if len(input_tokens) + max_tokens >= context_len:
            self._warn_context_size()
        # context_len = self.entry.options.get(CONF_CONTEXT_LENGTH, DEFAULT_CONTEXT_LENGTH)
        # if len(input_tokens) >= context_len:
        #     num_entities = len(self._async_get_exposed_entities()[0])
        #     context_size = self.entry.options.get(CONF_CONTEXT_LENGTH, DEFAULT_CONTEXT_LENGTH)
        #     self._warn_context_size()
        #     raise Exception(f"The model failed to produce a result because too many devices are exposed ({num_entities} devices) for the context size ({context_size} tokens)!")
        # if len(input_tokens) + max_tokens >= context_len:
        #     self._warn_context_size()

        _LOGGER.debug(f"Processing {len(input_tokens)} input tokens...")
        output_tokens = self.llm.generate(
            input_tokens,
            temp=temperature,
            top_k=top_k,
            top_p=top_p,
            min_p=min_p,
            typical_p=typical_p,
            grammar=self.grammar
        )
        # _LOGGER.debug(f"Processing {len(input_tokens)} input tokens...")

        result_tokens = []
        async def async_iterator():
            num_tokens = 0
            for token in output_tokens:
                result_tokens.append(token)
                yield TextGenerationResult(response=self.llm.detokenize([token]).decode(), response_streamed=True)
        messages = get_oai_formatted_messages(conversation)
        tools = None
        if llm_api:
            tools = get_oai_formatted_tools(llm_api)

                if token == self.llm.token_eos():
                    break
        return self._async_generate_completion(self.llm.create_chat_completion(
            messages,
            tools=tools,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            min_p=min_p,
            typical_p=typical_p,
            max_tokens=max_tokens,
            grammar=self.grammar
        ))

                if len(result_tokens) >= max_tokens:
                    break

                num_tokens += 1

        self._transform_result_stream(async_iterator(), user_input=user_input, chat_log=chat_log)

        response = TextGenerationResult(
            response=self.llm.detokenize(result_tokens).decode(),
            response_streamed=True,
        )
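
Aside: the new llama.cpp path leans on llama-cpp-python's high-level chat API instead of the manual tokenize/generate loop. A standalone sketch of a call with the same shape (illustrative only, not part of the commit; the model path, prompt, and sampling values are invented):

from llama_cpp import Llama

# Load any local GGUF model (path is a placeholder).
llm = Llama(model_path="/models/example.gguf")

completion = llm.create_chat_completion(
    [{"role": "user", "content": "Turn on the kitchen light."}],
    tools=None,          # or an OpenAI-style tool list when an LLM API is exposed
    temperature=0.7,
    top_k=40,
    top_p=1.0,
    min_p=0.0,
    typical_p=1.0,
    max_tokens=128,
)
print(completion["choices"][0]["message"]["content"])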
@@ -1,5 +1,6 @@
"""Defines the ollama compatible agent"""
from __future__ import annotations
from warnings import deprecated

import aiohttp
import asyncio
@@ -43,6 +44,7 @@ from custom_components.llama_conversation.conversation import LocalLLMAgent, Tex

_LOGGER = logging.getLogger(__name__)

@deprecated("Use the built-in Ollama integration instead")
class OllamaAPIAgent(LocalLLMAgent):
    api_host: str
    api_key: Optional[str]
@@ -107,37 +107,6 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry, async_add_e
    async_add_entities([entry.runtime_data])

    return True

def _convert_content(
    chat_content: conversation.Content
) -> dict[str, str]:
    """Create tool response content."""
    role_name = None
    if isinstance(chat_content, conversation.ToolResultContent):
        role_name = "tool"
    elif isinstance(chat_content, conversation.AssistantContent):
        role_name = "assistant"
    elif isinstance(chat_content, conversation.UserContent):
        role_name = "user"
    elif isinstance(chat_content, conversation.SystemContent):
        role_name = "system"
    else:
        raise ValueError(f"Unexpected content type: {type(chat_content)}")

    return { "role": role_name, "message": chat_content.content }

def _convert_content_back(
    agent_id: str,
    message_history_entry: dict[str, str]
) -> Optional[conversation.Content]:
    if message_history_entry["role"] == "tool":
        return conversation.ToolResultContent(agent_id=agent_id, content=message_history_entry["message"])
    if message_history_entry["role"] == "assistant":
        return conversation.AssistantContent(agent_id=agent_id, content=message_history_entry["message"])
    if message_history_entry["role"] == "user":
        return conversation.UserContent(content=message_history_entry["message"])
    if message_history_entry["role"] == "system":
        return conversation.SystemContent(content=message_history_entry["message"])

def _parse_raw_tool_call(raw_block: str, llm_api: llm.APIInstance, user_input: ConversationInput) -> tuple[bool, ConversationResult | llm.ToolInput, str | None]:
    parsed_tool_call: dict = json.loads(raw_block)
@@ -298,21 +267,29 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
            self._load_model, entry
        )

    def _generate_stream(self, conversation: List[Dict[str, str]]) -> TextGenerationResult:
    def _generate_stream(self, conversation: List[conversation.Content], llm_api: llm.APIInstance | None) -> AsyncGenerator[TextGenerationResult, None]:
        """Async generator for streaming responses. Subclasses should implement."""
        raise NotImplementedError()

    def _generate(self, conversation: List[Dict[str, str]]) -> TextGenerationResult:
    async def _generate(self, conversation: List[conversation.Content], llm_api: llm.APIInstance | None) -> TextGenerationResult:
        """Call the backend to generate a response from the conversation. Implemented by sub-classes"""
        raise NotImplementedError()

    async def _async_generate(self, conversation: List[Dict[str, str]]) -> TextGenerationResult:
        """Default implementation is to call _generate() which probably does blocking stuff"""
    async def _async_generate(self, conv: List[conversation.Content], user_input: ConversationInput, chat_log: conversation.chat_log.ChatLog):
        """Default implementation: if streaming is supported, consume the async generator and return the full result."""
        if hasattr(self, '_generate_stream'):
            return await self.hass.async_add_executor_job(
                self._generate_stream, conversation
            # Try to stream and collect the full response
            return await self._transform_result_stream(self._generate_stream(conv, chat_log.llm_api), user_input, chat_log)

        # Fallback to "blocking" generate
        blocking_result = await self._generate(conv, chat_log.llm_api)

        return chat_log.async_add_assistant_content(
            conversation.AssistantContent(
                agent_id=user_input.agent_id,
                content=blocking_result.response,
                tool_calls=blocking_result.tool_calls
            )
        return await self.hass.async_add_executor_job(
            self._generate, conversation
        )

    def _warn_context_size(self):
@@ -322,7 +299,7 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
            f"{self.entry.data[CONF_CHAT_MODEL]} and it exceeded the context size for the model. " +
            f"Please reduce the number of entities exposed ({num_entities}) or increase the model's context size ({int(context_size)})")

    def _transform_result_stream(
    async def _transform_result_stream(
        self,
        result: AsyncIterator[TextGenerationResult],
        user_input: ConversationInput,
@@ -335,8 +312,7 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
                    tool_calls=input_chunk.tool_calls
                )


        chat_log.async_add_delta_content_stream(user_input.agent_id, stream=async_iterator())
        return chat_log.async_add_delta_content_stream(user_input.agent_id, stream=async_iterator())

    async def async_process(
        self, user_input: ConversationInput
@@ -405,7 +381,7 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
        )

        if remember_conversation:
            message_history = [ _convert_content(content) for content in chat_log.content ]
            message_history = chat_log.content[:]
        else:
            message_history = []

@@ -436,16 +412,14 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):

        multi_turn_enabled = self.entry.options.get(CONF_TOOL_MULTI_TURN_CHAT, DEFAULT_TOOL_MULTI_TURN_CHAT)
        MAX_TOOL_CALL_ITERATIONS = 3 if multi_turn_enabled else 1
        tool_calls: List[Tuple[llm.ToolInput, Any]] = []
        for _ in range(MAX_TOOL_CALL_ITERATIONS):
            # generate a response
            try:
                _LOGGER.debug(message_history)
                generation_result = await self._async_generate(message_history)
                generation_result = await self._async_generate(message_history, user_input, chat_log)
                _LOGGER.debug(generation_result)

            except Exception as err:
                _LOGGER.exception("There was a problem talking to the backend")

                intent_response = intent.IntentResponse(language=user_input.language)
                intent_response.async_set_error(
                    intent.IntentResponseErrorCode.FAILED_TO_HANDLE,
@@ -455,67 +429,28 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
                    response=intent_response, conversation_id=user_input.conversation_id
                )

            response = generation_result.response or ""
            last_message_had_tool_calls = False
            async for message in generation_result:
                message_history.append(message)
                if message.role == "assistant":
                    if message.tool_calls and len(message.tool_calls) > 0:
                        last_message_had_tool_calls = True
                    else:
                        last_message_had_tool_calls = False

            # remove think blocks
            response = re.sub(rf"^.*?{template_desc["chain_of_thought"]["suffix"]}", "", response, flags=re.DOTALL)

            message_history.append({"role": "assistant", "message": response})
            if llm_api is None or (generation_result.tool_calls and len(generation_result.tool_calls) == 0):
                if not generation_result.response_streamed:
                    chat_log.async_add_assistant_content_without_tools(conversation.AssistantContent(
                        agent_id=user_input.agent_id,
                        content=response,
                    ))

                # return the output without messing with it if there is no API exposed to the model
                intent_response = intent.IntentResponse(language=user_input.language)
                return ConversationResult(
                    response=intent_response, conversation_id=user_input.conversation_id
                )

            # execute the tool calls
            tool_calls: List[Tuple[llm.ToolInput, Any]] = []
            for tool_input in generation_result.tool_calls or []:
                tool_response = None
                try:
                    tool_response = llm_api.async_call_tool(tool_input)
                    _LOGGER.debug("Tool response: %s", tool_response)

                    tool_calls.append((tool_input, tool_response))
                except (HomeAssistantError, vol.Invalid) as e:
                    tool_response = {"error": type(e).__name__}
                    if str(e):
                        tool_response["error_text"] = str(e)
                    _LOGGER.debug("Tool response: %s", tool_response)

                    intent_response = intent.IntentResponse(language=user_input.language)
                    intent_response.async_set_error(
                        intent.IntentResponseErrorCode.NO_INTENT_MATCH,
                        f"I'm sorry! I encountered an error calling the tool. See the logs for more info.",
                    )
                    return ConversationResult(
                        response=intent_response, conversation_id=user_input.conversation_id
                    )

                if tool_response and multi_turn_enabled:
                    async for tool_result in chat_log.async_add_assistant_content(
                        conversation.AssistantContent(
                            agent_id=user_input.agent_id,
                            content=response,
                            tool_calls=generation_result.tool_calls,
                        ),
                        tool_call_tasks={ x[0].tool_name: x[1] for x in tool_calls}
                    ):
                        message_history.append({"role": "tool", "content": json.dumps(tool_result.tool_result) })
            # If not multi-turn, break after first tool call
            # also break if no tool calls were made
            if not multi_turn_enabled or not last_message_had_tool_calls:
                break

        # generate intent response to Home Assistant
        intent_response = intent.IntentResponse(language=user_input.language)
        if len(tool_calls) > 0:
            str_tools = [f"{input.tool_name}({', '.join(input.tool_args.values())})" for input, response in tool_calls]
            str_tools = [f"{input.tool_name}({', '.join(str(x) for x in input.tool_args.values())})" for input, response in tool_calls]
            tools_str = '\n'.join(str_tools)
            intent_response.async_set_card(
                title="Changes",
                content=f"Ran the following tools:\n{'\n'.join(str_tools)}"
                content=f"Ran the following tools:\n{tools_str}"
            )
        return ConversationResult(
            response=intent_response, conversation_id=user_input.conversation_id
@@ -7,11 +7,13 @@ import logging
import multiprocessing
import voluptuous as vol
import webcolors
from typing import Any, Dict, List, Sequence
from webcolors import CSS3
from importlib.metadata import version

from homeassistant.components import conversation
from homeassistant.helpers import config_validation as cv
from homeassistant.helpers import intent
from homeassistant.helpers import intent, llm
from homeassistant.requirements import pip_kwargs
from homeassistant.util import color
from homeassistant.util.package import install_package, is_installed
@@ -21,6 +23,13 @@ from .const import (
    EMBEDDED_LLAMA_CPP_PYTHON_VERSION,
)

from typing import TYPE_CHECKING
if TYPE_CHECKING:
    from llama_cpp.llama_types import ChatCompletionRequestMessage, ChatCompletionTool
else:
    ChatCompletionRequestMessage = Any
    ChatCompletionTool = Any

_LOGGER = logging.getLogger(__name__)

CSS3_NAME_TO_RGB = {
@@ -238,4 +247,52 @@ def install_llama_cpp_python(config_dir: str):
    return True

def format_url(*, hostname: str, port: str, ssl: bool, path: str):
    return f"{'https' if ssl else 'http'}://{hostname}{ ':' + port if port else ''}{path}"
    return f"{'https' if ssl else 'http'}://{hostname}{ ':' + port if port else ''}{path}"

def get_oai_formatted_tools(llm_api: llm.APIInstance) -> List[ChatCompletionTool]:
    return [ {
        "type": "function",
        "function": {
            "name": tool.name,
            "description": tool.description or "",
            "parameters": tool.parameters.schema
        }
    } for tool in llm_api.tools ]

def get_oai_formatted_messages(conversation: Sequence[conversation.Content]) -> List[ChatCompletionRequestMessage]:
    messages = []
    for message in conversation:
        if message.role == "system":
            messages.append({
                "role": "system",
                "content": message.content
            })
        elif message.role == "user":
            messages.append({
                "role": "user",
                "content": [{ "type": "text", "text": message.content }]
            })
        elif message.role == "assistant":
            if message.tool_calls:
                messages.append({
                    "role": "assistant",
                    "content": message.content,
                    "tool_calls": [
                        {
                            "type" : "function",
                            "id": t.id,
                            "function": {
                                "arguments": t.tool_args,
                                "name": t.tool_name,
                            }
                        } for t in message.tool_calls
                    ]
                })
        elif message.role == "tool_result":
            messages.append({
                "role": "tool",
                "content": message.tool_result,
                "tool_call_id": message.tool_call_id
            })

    return messages
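
To make the two less obvious branches concrete, an assistant turn carrying a tool call and its matching tool result come out of get_oai_formatted_messages() roughly like this (a sketch only; the tool name, id, and arguments are invented):

# Illustrative shapes only (invented values), matching the branches above.
assistant_with_tool_call = {
    "role": "assistant",
    "content": "",
    "tool_calls": [
        {
            "type": "function",
            "id": "call_0",                              # t.id
            "function": {
                "name": "HassTurnOn",                    # t.tool_name
                "arguments": {"name": "kitchen light"},  # t.tool_args
            },
        }
    ],
}
tool_result = {
    "role": "tool",
    "content": {"success": True},  # message.tool_result (whatever the tool returned)
    "tool_call_id": "call_0",
}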
train.py
@@ -675,8 +675,6 @@ def do_training_run(training_run_args: TrainingRunArguments):
        if input("Something bad happened! Try and save it? (Y/n) ").lower().startswith("y"):
            trainer._save_checkpoint(model, None)
            print("Saved Checkpoint!")

        exit(-1)

if __name__ == "__main__":
    parser = HfArgumentParser([TrainingRunArguments])