Offload responsibility for _extract_response to child classes

This commit is contained in:
Simon Redman
2025-05-31 16:30:48 -04:00
parent 92042f629d
commit 7ad8d03dd0

View File

@@ -1241,16 +1241,6 @@ class BaseOpenAICompatibleAPIAgent(LocalLLMAgent):
self.api_key = entry.data.get(CONF_OPENAI_API_KEY)
self.model_name = entry.data.get(CONF_CHAT_MODEL)
def _extract_response(self, response_json: dict) -> str:
choices = response_json["choices"]
if choices[0]["finish_reason"] != "stop":
_LOGGER.warning("Model response did not end on a stop token (unfinished sentence)")
if response_json["object"] in ["chat.completion", "chat.completion.chunk"]:
return choices[0]["message"]["content"]
else:
return choices[0]["text"]
async def _async_generate_with_parameters(self, conversation: dict, endpoint: str, additional_params: dict) -> str:
"""Generate a response using the OpenAI-compatible API"""
@@ -1295,6 +1285,9 @@ class BaseOpenAICompatibleAPIAgent(LocalLLMAgent):
return self._extract_response(result)
def _extract_response(self, response_json: dict) -> str:
raise NotImplementedError("Subclasses must implement _extract_response()")
class GenericOpenAIAPIAgent(BaseOpenAICompatibleAPIAgent):
"""Implements the OpenAPI-compatible text completion and chat completion API backends."""
@@ -1316,6 +1309,16 @@ class GenericOpenAIAPIAgent(BaseOpenAICompatibleAPIAgent):
return endpoint, request_params
def _extract_response(self, response_json: dict) -> str:
choices = response_json["choices"]
if choices[0]["finish_reason"] != "stop":
_LOGGER.warning("Model response did not end on a stop token (unfinished sentence)")
if response_json["object"] in ["chat.completion", "chat.completion.chunk"]:
return choices[0]["message"]["content"]
else:
return choices[0]["text"]
async def _async_generate(self, conversation: dict) -> str:
use_chat_api = self.entry.options.get(CONF_REMOTE_USE_CHAT_ENDPOINT, DEFAULT_REMOTE_USE_CHAT_ENDPOINT)