mirror of https://github.com/acon96/home-llm.git (synced 2026-01-09 13:48:05 -05:00)
mostly working implementation. still needs debugging
@@ -34,7 +34,7 @@ from custom_components.llama_conversation.const import (
     DEFAULT_REMEMBER_CONVERSATION_TIME_MINUTES,
     DEFAULT_GENERIC_OPENAI_PATH,
 )
-from custom_components.llama_conversation.conversation import LocalLLMAgent, TextGenerationResult
+from custom_components.llama_conversation.conversation import LocalLLMAgent, TextGenerationResult, parse_raw_tool_call

 _LOGGER = logging.getLogger(__name__)

@@ -54,7 +54,7 @@ class BaseOpenAICompatibleAPIAgent(LocalLLMAgent):
         self.api_key = entry.data.get(CONF_OPENAI_API_KEY, "")
         self.model_name = entry.data.get(CONF_CHAT_MODEL, "")

-    async def _async_generate_with_parameters(self, endpoint: str, stream: bool, additional_params: dict):
+    async def _async_generate_with_parameters(self, endpoint: str, stream: bool, additional_params: dict, llm_api: llm.APIInstance | None, user_input: conversation.ConversationInput):
         """Generate a response using the OpenAI-compatible API"""

         max_tokens = self.entry.options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS)
@@ -89,10 +89,10 @@ class BaseOpenAICompatibleAPIAgent(LocalLLMAgent):
             if stream:
                 async for line_bytes in response.content:
                     chunk = line_bytes.decode("utf-8").strip()
-                    yield self._extract_response(json.loads(chunk))
+                    yield self._extract_response(json.loads(chunk), llm_api, user_input)
             else:
                 response_json = await response.json()
-                yield self._extract_response(response_json)
+                yield self._extract_response(response_json, llm_api, user_input)
         except asyncio.TimeoutError:
             yield TextGenerationResult(raise_error=True, error_msg="The generation request timed out! Please check your connection settings, increase the timeout in settings, or decrease the number of exposed entities.")
         except aiohttp.ClientError as err:
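Note: a minimal sketch (not part of this commit) of how the async generator above might be consumed; the caller, endpoint, and request parameters are placeholders:

    async def collect_results(agent, endpoint, params, llm_api, user_input):
        # Drain the stream, surfacing backend errors reported via raise_error.
        results = []
        async for result in agent._async_generate_with_parameters(
            endpoint, True, params, llm_api, user_input
        ):
            if result.raise_error:
                raise RuntimeError(result.error_msg)
            results.append(result)
        return results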
@@ -101,19 +101,34 @@ class BaseOpenAICompatibleAPIAgent(LocalLLMAgent):
             _LOGGER.debug(f"Result was: {response}")
             yield TextGenerationResult(raise_error=True, error_msg=f"Failed to communicate with the API! {err}")

-    def _extract_response(self, response_json: dict) -> TextGenerationResult:
+    def _extract_response(self, response_json: dict, llm_api: llm.APIInstance | None, user_input: conversation.ConversationInput) -> TextGenerationResult:
         raise NotImplementedError("Subclasses must implement _extract_response()")

 class GenericOpenAIAPIAgent(BaseOpenAICompatibleAPIAgent):
     """Implements the OpenAPI-compatible text completion and chat completion API backends."""

-    def _extract_response(self, response_json: dict) -> TextGenerationResult:
+    def _extract_response(self, response_json: dict, llm_api: llm.APIInstance | None, user_input: conversation.ConversationInput) -> TextGenerationResult:
         choice = response_json["choices"][0]
+        tool_calls = None
         if response_json["object"] == "chat.completion":
             response_text = choice["message"]["content"]
             streamed = False
         elif response_json["object"] == "chat.completion.chunk":
-            response_text = choice["message"]["content"]
+            response_text = choice["delta"].get("content", "")
+            if "tool_calls" in choice["delta"]:
+                tool_calls = []
+                for call in choice["delta"]["tool_calls"]:
+                    success, tool_args, to_say = parse_raw_tool_call(
+                        call["function"]["arguments"],
+                        call["function"]["name"],
+                        call["id"],
+                        llm_api, user_input)
+
+                    if success and tool_args:
+                        tool_calls.append(tool_args)
+
+                    if to_say:
+                        response_text += to_say
             streamed = True
         else:
             response_text = choice["text"]
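For reference, a rough illustration (not from the commit) of the streamed chunk shape that the new "chat.completion.chunk" branch reads; all field values here are invented for the example:

    example_chunk = {
        "object": "chat.completion.chunk",
        "choices": [{
            "delta": {
                "content": "",
                "tool_calls": [{
                    "id": "call_0",
                    "function": {
                        "name": "HassTurnOn",
                        "arguments": "{\"name\": \"kitchen light\"}",
                    },
                }],
            },
            "finish_reason": None,
        }],
    }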
@@ -127,9 +142,10 @@ class GenericOpenAIAPIAgent(BaseOpenAICompatibleAPIAgent):
             response=response_text,
             stop_reason=choice["finish_reason"],
             response_streamed=streamed,
+            tool_calls=tool_calls
         )

-    def _generate_stream(self, conversation: List[conversation.Content], llm_api: llm.APIInstance | None) -> AsyncGenerator[TextGenerationResult, None]:
+    def _generate_stream(self, conversation: List[conversation.Content], llm_api: llm.APIInstance | None, user_input: conversation.ConversationInput) -> AsyncGenerator[TextGenerationResult, None]:
         request_params = {}
         api_base_path = self.entry.data.get(CONF_GENERIC_OPENAI_PATH, DEFAULT_GENERIC_OPENAI_PATH)

@@ -139,7 +155,7 @@ class GenericOpenAIAPIAgent(BaseOpenAICompatibleAPIAgent):
         if llm_api:
             request_params["tools"] = get_oai_formatted_tools(llm_api)

-        return self._async_generate_with_parameters(endpoint, True, request_params)
+        return self._async_generate_with_parameters(endpoint, True, request_params, llm_api, user_input)

 class GenericOpenAIResponsesAPIAgent(BaseOpenAICompatibleAPIAgent):
     """Implements the OpenAPI-compatible Responses API backend."""
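A rough sketch (an assumption, not from the commit) of the OpenAI-style "tools" payload that get_oai_formatted_tools presumably builds from the exposed llm_api tools; the tool name and schema below are illustrative only:

    example_tools = [{
        "type": "function",
        "function": {
            "name": "HassTurnOn",
            "description": "Turn on a device or entity",
            "parameters": {
                "type": "object",
                "properties": {"name": {"type": "string"}},
                "required": ["name"],
            },
        },
    }]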
@@ -244,4 +260,4 @@ class GenericOpenAIResponsesAPIAgent(BaseOpenAICompatibleAPIAgent):
         """Generate a response using the OpenAI-compatible Responses API"""

         endpoint, additional_params = self._responses_params(conv)
-        return self._async_generate_with_parameters(endpoint, additional_params)
+        return self._async_generate_with_parameters(endpoint, False, additional_params)
@@ -107,8 +107,8 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry, async_add_e
     async_add_entities([entry.runtime_data])

     return True

-def _parse_raw_tool_call(raw_block: str, llm_api: llm.APIInstance, user_input: ConversationInput) -> tuple[bool, ConversationResult | llm.ToolInput, str | None]:
+def parse_raw_tool_call(raw_block: str, tool_name: str, tool_call_id: str, llm_api: llm.APIInstance, user_input: ConversationInput) -> tuple[bool, llm.ToolInput | None, str | None]:
     parsed_tool_call: dict = json.loads(raw_block)

     if llm_api.api.id == HOME_LLM_API_ID:
@@ -135,15 +135,7 @@ def _parse_raw_tool_call(raw_block: str, llm_api: llm.APIInstance, user_input: C
         schema_to_validate(parsed_tool_call)
     except vol.Error as ex:
         _LOGGER.info(f"LLM produced an improperly formatted response: {repr(ex)}")
-
-        intent_response = intent.IntentResponse(language=user_input.language)
-        intent_response.async_set_error(
-            intent.IntentResponseErrorCode.NO_INTENT_MATCH,
-            f"I'm sorry, I didn't produce a correctly formatted tool call! Please see the logs for more info.",
-        )
-        return False, ConversationResult(
-            response=intent_response, conversation_id=user_input.conversation_id
-        ), ""
+        return False, None, f"I'm sorry, I didn't produce a correctly formatted tool call! Please see the logs for more info.",

     # try to fix certain arguments
     args_dict = parsed_tool_call if llm_api.api.id == HOME_LLM_API_ID else parsed_tool_call["arguments"]
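A hedged sketch (not from the commit) of how a caller might consume the new return contract; the raw JSON, tool name, call id, and surrounding variables are placeholders:

    success, tool_input, to_say = parse_raw_tool_call(
        '{"name": "kitchen light"}',  # raw_block
        "HassTurnOn",                 # tool_name
        "call_0",                     # tool_call_id
        llm_api, user_input)
    if not success:
        # to_say now carries the user-facing error text instead of a ConversationResult
        response_text = to_say or ""
    elif tool_input:
        tool_calls.append(tool_input)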
@@ -267,11 +259,11 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
             self._load_model, entry
         )

-    def _generate_stream(self, conversation: List[conversation.Content], llm_api: llm.APIInstance | None) -> AsyncGenerator[TextGenerationResult, None]:
+    def _generate_stream(self, conversation: List[conversation.Content], llm_api: llm.APIInstance | None, user_input: conversation.ConversationInput) -> AsyncGenerator[TextGenerationResult, None]:
         """Async generator for streaming responses. Subclasses should implement."""
         raise NotImplementedError()

-    async def _generate(self, conversation: List[conversation.Content], llm_api: llm.APIInstance | None) -> TextGenerationResult:
+    async def _generate(self, conversation: List[conversation.Content], llm_api: llm.APIInstance | None, user_input: conversation.ConversationInput) -> TextGenerationResult:
         """Call the backend to generate a response from the conversation. Implemented by sub-classes"""
         raise NotImplementedError()

@@ -279,10 +271,10 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
         """Default implementation: if streaming is supported, consume the async generator and return the full result."""
         if hasattr(self, '_generate_stream'):
             # Try to stream and collect the full response
-            return await self._transform_result_stream(self._generate_stream(conv, chat_log.llm_api), user_input, chat_log)
+            return await self._transform_result_stream(self._generate_stream(conv, chat_log.llm_api, user_input), user_input, chat_log)

         # Fallback to "blocking" generate
-        blocking_result = await self._generate(conv, chat_log.llm_api)
+        blocking_result = await self._generate(conv, chat_log.llm_api, user_input)

         return chat_log.async_add_assistant_content(
             conversation.AssistantContent(
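A minimal sketch (assumed, not part of the commit) of a subclass satisfying the updated signatures; the backend call is a stand-in and the result text is invented:

    class ExampleAgent(LocalLLMAgent):
        def _generate_stream(self, conversation, llm_api, user_input):
            async def _stream():
                # A real backend would yield chunks here as they arrive.
                yield TextGenerationResult(response="Done.", response_streamed=False)
            return _stream()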
@@ -295,4 +295,4 @@ def get_oai_formatted_messages(conversation: Sequence[conversation.Content]) ->
             "tool_call_id": message.tool_call_id
         })

-    return messages
+    return messages