From 7ad8d03dd0d9a4f6ffd3f29f7d1ef0bcd92fe3f7 Mon Sep 17 00:00:00 2001
From: Simon Redman
Date: Sat, 31 May 2025 16:30:48 -0400
Subject: [PATCH] Offload responsibility for _extract_response to child
 classes

---
 .../llama_conversation/conversation.py | 23 +++++++++++--------
 1 file changed, 13 insertions(+), 10 deletions(-)

diff --git a/custom_components/llama_conversation/conversation.py b/custom_components/llama_conversation/conversation.py
index d67cc30..78232b8 100644
--- a/custom_components/llama_conversation/conversation.py
+++ b/custom_components/llama_conversation/conversation.py
@@ -1241,16 +1241,6 @@ class BaseOpenAICompatibleAPIAgent(LocalLLMAgent):
         self.api_key = entry.data.get(CONF_OPENAI_API_KEY)
         self.model_name = entry.data.get(CONF_CHAT_MODEL)
 
-    def _extract_response(self, response_json: dict) -> str:
-        choices = response_json["choices"]
-        if choices[0]["finish_reason"] != "stop":
-            _LOGGER.warning("Model response did not end on a stop token (unfinished sentence)")
-
-        if response_json["object"] in ["chat.completion", "chat.completion.chunk"]:
-            return choices[0]["message"]["content"]
-        else:
-            return choices[0]["text"]
-
     async def _async_generate_with_parameters(self, conversation: dict, endpoint: str, additional_params: dict) -> str:
         """Generate a response using the OpenAI-compatible API"""
 
@@ -1295,6 +1285,9 @@ class BaseOpenAICompatibleAPIAgent(LocalLLMAgent):
 
         return self._extract_response(result)
 
+    def _extract_response(self, response_json: dict) -> str:
+        raise NotImplementedError("Subclasses must implement _extract_response()")
+
 class GenericOpenAIAPIAgent(BaseOpenAICompatibleAPIAgent):
     """Implements the OpenAPI-compatible text completion and chat completion API backends."""
 
@@ -1316,6 +1309,16 @@ class GenericOpenAIAPIAgent(BaseOpenAICompatibleAPIAgent):
 
         return endpoint, request_params
 
+    def _extract_response(self, response_json: dict) -> str:
+        choices = response_json["choices"]
+        if choices[0]["finish_reason"] != "stop":
+            _LOGGER.warning("Model response did not end on a stop token (unfinished sentence)")
+
+        if response_json["object"] in ["chat.completion", "chat.completion.chunk"]:
+            return choices[0]["message"]["content"]
+        else:
+            return choices[0]["text"]
+
     async def _async_generate(self, conversation: dict) -> str:
         use_chat_api = self.entry.options.get(CONF_REMOTE_USE_CHAT_ENDPOINT, DEFAULT_REMOTE_USE_CHAT_ENDPOINT)
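
For illustration, a minimal sketch of what the new contract asks of other backends: any child of BaseOpenAICompatibleAPIAgent must now override _extract_response() or the base class raises NotImplementedError. The TextCompletionOnlyAgent class below is hypothetical and not part of this patch; it assumes the module-level _LOGGER and the base class from conversation.py, and a response shaped like the legacy text-completion API.

    # Hypothetical subclass, not in this patch: shows the override each
    # child class must now provide instead of inheriting a shared one.
    class TextCompletionOnlyAgent(BaseOpenAICompatibleAPIAgent):
        """Example backend that only speaks a text-completion-style API."""

        def _extract_response(self, response_json: dict) -> str:
            choices = response_json["choices"]
            if choices[0]["finish_reason"] != "stop":
                _LOGGER.warning("Model response did not end on a stop token (unfinished sentence)")
            # Assumed text-completion shape: content lives under "text",
            # with no chat-style "message" wrapper to branch on.
            return choices[0]["text"]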