pass llm_api and user_input through the OpenAI-compatible agents and parse tool calls from streamed chunks. mostly working implementation; still needs debugging

Alex O'Connell
2025-09-15 22:10:09 -04:00
parent da0a0e4dbc
commit 843f99d64a
3 changed files with 34 additions and 26 deletions

View File

@@ -34,7 +34,7 @@ from custom_components.llama_conversation.const import (
    DEFAULT_REMEMBER_CONVERSATION_TIME_MINUTES,
    DEFAULT_GENERIC_OPENAI_PATH,
)
-from custom_components.llama_conversation.conversation import LocalLLMAgent, TextGenerationResult
+from custom_components.llama_conversation.conversation import LocalLLMAgent, TextGenerationResult, parse_raw_tool_call
_LOGGER = logging.getLogger(__name__)
@@ -54,7 +54,7 @@ class BaseOpenAICompatibleAPIAgent(LocalLLMAgent):
        self.api_key = entry.data.get(CONF_OPENAI_API_KEY, "")
        self.model_name = entry.data.get(CONF_CHAT_MODEL, "")
-    async def _async_generate_with_parameters(self, endpoint: str, stream: bool, additional_params: dict):
+    async def _async_generate_with_parameters(self, endpoint: str, stream: bool, additional_params: dict, llm_api: llm.APIInstance | None, user_input: conversation.ConversationInput):
        """Generate a response using the OpenAI-compatible API"""
        max_tokens = self.entry.options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS)
@@ -89,10 +89,10 @@ class BaseOpenAICompatibleAPIAgent(LocalLLMAgent):
            if stream:
                async for line_bytes in response.content:
                    chunk = line_bytes.decode("utf-8").strip()
-                    yield self._extract_response(json.loads(chunk))
+                    yield self._extract_response(json.loads(chunk), llm_api, user_input)
            else:
                response_json = await response.json()
-                yield self._extract_response(response_json)
+                yield self._extract_response(response_json, llm_api, user_input)
        except asyncio.TimeoutError:
            yield TextGenerationResult(raise_error=True, error_msg="The generation request timed out! Please check your connection settings, increase the timeout in settings, or decrease the number of exposed entities.")
        except aiohttp.ClientError as err:
@@ -101,19 +101,34 @@ class BaseOpenAICompatibleAPIAgent(LocalLLMAgent):
_LOGGER.debug(f"Result was: {response}")
yield TextGenerationResult(raise_error=True, error_msg=f"Failed to communicate with the API! {err}")
def _extract_response(self, response_json: dict) -> TextGenerationResult:
def _extract_response(self, response_json: dict, llm_api: llm.APIInstance | None, user_input: conversation.ConversationInput) -> TextGenerationResult:
raise NotImplementedError("Subclasses must implement _extract_response()")
class GenericOpenAIAPIAgent(BaseOpenAICompatibleAPIAgent):
"""Implements the OpenAPI-compatible text completion and chat completion API backends."""
def _extract_response(self, response_json: dict) -> TextGenerationResult:
def _extract_response(self, response_json: dict, llm_api: llm.APIInstance | None, user_input: conversation.ConversationInput) -> TextGenerationResult:
choice = response_json["choices"][0]
tool_calls = None
if response_json["object"] == "chat.completion":
response_text = choice["message"]["content"]
streamed = False
elif response_json["object"] == "chat.completion.chunk":
response_text = choice["message"]["content"]
response_text = choice["delta"].get("content", "")
if "tool_calls" in choice["delta"]:
tool_calls = []
for call in choice["delta"]["tool_calls"]:
success, tool_args, to_say = parse_raw_tool_call(
call["function"]["arguments"],
call["function"]["name"],
call["id"],
llm_api, user_input)
if success and tool_args:
tool_calls.append(tool_args)
if to_say:
response_text += to_say
streamed = True
else:
response_text = choice["text"]
@@ -127,9 +142,10 @@ class GenericOpenAIAPIAgent(BaseOpenAICompatibleAPIAgent):
            response=response_text,
            stop_reason=choice["finish_reason"],
            response_streamed=streamed,
+            tool_calls=tool_calls
        )
-    def _generate_stream(self, conversation: List[conversation.Content], llm_api: llm.APIInstance | None) -> AsyncGenerator[TextGenerationResult, None]:
+    def _generate_stream(self, conversation: List[conversation.Content], llm_api: llm.APIInstance | None, user_input: conversation.ConversationInput) -> AsyncGenerator[TextGenerationResult, None]:
        request_params = {}
        api_base_path = self.entry.data.get(CONF_GENERIC_OPENAI_PATH, DEFAULT_GENERIC_OPENAI_PATH)
@@ -139,7 +155,7 @@ class GenericOpenAIAPIAgent(BaseOpenAICompatibleAPIAgent):
        if llm_api:
            request_params["tools"] = get_oai_formatted_tools(llm_api)
-        return self._async_generate_with_parameters(endpoint, True, request_params)
+        return self._async_generate_with_parameters(endpoint, True, request_params, llm_api, user_input)
class GenericOpenAIResponsesAPIAgent(BaseOpenAICompatibleAPIAgent):
    """Implements the OpenAPI-compatible Responses API backend."""
@@ -244,4 +260,4 @@ class GenericOpenAIResponsesAPIAgent(BaseOpenAICompatibleAPIAgent):
"""Generate a response using the OpenAI-compatible Responses API"""
endpoint, additional_params = self._responses_params(conv)
return self._async_generate_with_parameters(endpoint, additional_params)
return self._async_generate_with_parameters(endpoint, False, additional_params)

View File

@@ -107,8 +107,8 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry, async_add_e
    async_add_entities([entry.runtime_data])
    return True
-def _parse_raw_tool_call(raw_block: str, llm_api: llm.APIInstance, user_input: ConversationInput) -> tuple[bool, ConversationResult | llm.ToolInput, str | None]:
+def parse_raw_tool_call(raw_block: str, tool_name: str, tool_call_id: str, llm_api: llm.APIInstance, user_input: ConversationInput) -> tuple[bool, llm.ToolInput | None, str | None]:
    parsed_tool_call: dict = json.loads(raw_block)
    if llm_api.api.id == HOME_LLM_API_ID:
@@ -135,15 +135,7 @@ def _parse_raw_tool_call(raw_block: str, llm_api: llm.APIInstance, user_input: C
        schema_to_validate(parsed_tool_call)
    except vol.Error as ex:
        _LOGGER.info(f"LLM produced an improperly formatted response: {repr(ex)}")
-        intent_response = intent.IntentResponse(language=user_input.language)
-        intent_response.async_set_error(
-            intent.IntentResponseErrorCode.NO_INTENT_MATCH,
-            f"I'm sorry, I didn't produce a correctly formatted tool call! Please see the logs for more info.",
-        )
-        return False, ConversationResult(
-            response=intent_response, conversation_id=user_input.conversation_id
-        ), ""
+        return False, None, f"I'm sorry, I didn't produce a correctly formatted tool call! Please see the logs for more info.",
    # try to fix certain arguments
    args_dict = parsed_tool_call if llm_api.api.id == HOME_LLM_API_ID else parsed_tool_call["arguments"]
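For orientation, a minimal sketch of consuming the new (success, tool_input, to_say) return contract; it mirrors the streaming branch added to GenericOpenAIAPIAgent above, and the argument values are illustrative:

tool_calls: list[llm.ToolInput] = []
response_text = ""
success, tool_input, to_say = parse_raw_tool_call(
    '{"name": "kitchen light"}',  # raw argument JSON produced by the model (example)
    "HassTurnOn",                 # tool name reported by the backend (example)
    "call_0",                     # tool call id (example)
    llm_api,
    user_input,
)
if success and tool_input:
    tool_calls.append(tool_input)
if to_say:
    response_text += to_say       # surface the error/explanation text to the user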
@@ -267,11 +259,11 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
            self._load_model, entry
        )
-    def _generate_stream(self, conversation: List[conversation.Content], llm_api: llm.APIInstance | None) -> AsyncGenerator[TextGenerationResult, None]:
+    def _generate_stream(self, conversation: List[conversation.Content], llm_api: llm.APIInstance | None, user_input: conversation.ConversationInput) -> AsyncGenerator[TextGenerationResult, None]:
        """Async generator for streaming responses. Subclasses should implement."""
        raise NotImplementedError()
-    async def _generate(self, conversation: List[conversation.Content], llm_api: llm.APIInstance | None) -> TextGenerationResult:
+    async def _generate(self, conversation: List[conversation.Content], llm_api: llm.APIInstance | None, user_input: conversation.ConversationInput) -> TextGenerationResult:
        """Call the backend to generate a response from the conversation. Implemented by sub-classes"""
        raise NotImplementedError()
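As a usage sketch only (hypothetical subclass, not part of this change), a backend now implements the widened signatures like so:

class EchoAgent(LocalLLMAgent):
    """Hypothetical minimal backend illustrating the new signatures."""

    async def _generate_stream(self, conversation, llm_api, user_input):
        # A real backend streams partial results; this yields a single finished chunk.
        yield TextGenerationResult(
            response=f"You said: {user_input.text}",
            stop_reason="stop",
            response_streamed=False,
        )

    async def _generate(self, conversation, llm_api, user_input):
        return TextGenerationResult(
            response=f"You said: {user_input.text}",
            stop_reason="stop",
            response_streamed=False,
        )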
@@ -279,10 +271,10 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
"""Default implementation: if streaming is supported, consume the async generator and return the full result."""
if hasattr(self, '_generate_stream'):
# Try to stream and collect the full response
return await self._transform_result_stream(self._generate_stream(conv, chat_log.llm_api), user_input, chat_log)
return await self._transform_result_stream(self._generate_stream(conv, chat_log.llm_api, user_input), user_input, chat_log)
# Fallback to "blocking" generate
blocking_result = await self._generate(conv, chat_log.llm_api)
blocking_result = await self._generate(conv, chat_log.llm_api, user_input)
return chat_log.async_add_assistant_content(
conversation.AssistantContent(

View File

@@ -295,4 +295,4 @@ def get_oai_formatted_messages(conversation: Sequence[conversation.Content]) ->
"tool_call_id": message.tool_call_id
})
return messages
return messages
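For context, the tool-role entry built just above follows the standard OpenAI chat message shape, roughly (illustrative values):

{
    "role": "tool",
    "tool_call_id": "call_0",
    "content": "{\"success\": true}",
}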