Mirror of https://github.com/acon96/home-llm.git (synced 2026-01-08 21:28:05 -05:00)
Offload responsibility for _extract_response to child classes
@@ -1241,16 +1241,6 @@ class BaseOpenAICompatibleAPIAgent(LocalLLMAgent):
         self.api_key = entry.data.get(CONF_OPENAI_API_KEY)
         self.model_name = entry.data.get(CONF_CHAT_MODEL)
 
-    def _extract_response(self, response_json: dict) -> str:
-        choices = response_json["choices"]
-        if choices[0]["finish_reason"] != "stop":
-            _LOGGER.warning("Model response did not end on a stop token (unfinished sentence)")
-
-        if response_json["object"] in ["chat.completion", "chat.completion.chunk"]:
-            return choices[0]["message"]["content"]
-        else:
-            return choices[0]["text"]
-
     async def _async_generate_with_parameters(self, conversation: dict, endpoint: str, additional_params: dict) -> str:
         """Generate a response using the OpenAI-compatible API"""
 
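Removing the parser from the base class is safe because the base's _async_generate_with_parameters (see the next hunk) calls self._extract_response(result), which Python resolves against the concrete subclass at call time. A minimal, hypothetical sketch of that dispatch, with Base and Child as stand-in names for the real agent classes:

# Hypothetical sketch of the dispatch pattern; Base/Child stand in for
# BaseOpenAICompatibleAPIAgent and its subclasses.
class Base:
    def generate(self, result: dict) -> str:
        # Defined in the base, but resolved against the subclass at runtime.
        return self._extract_response(result)

    def _extract_response(self, response_json: dict) -> str:
        raise NotImplementedError("Subclasses must implement _extract_response()")


class Child(Base):
    def _extract_response(self, response_json: dict) -> str:
        return response_json["choices"][0]["text"]


print(Child().generate({"choices": [{"text": "ok"}]}))  # -> ok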
@@ -1295,6 +1285,9 @@ class BaseOpenAICompatibleAPIAgent(LocalLLMAgent):
 
         return self._extract_response(result)
 
+    def _extract_response(self, response_json: dict) -> str:
+        raise NotImplementedError("Subclasses must implement _extract_response()")
+
 class GenericOpenAIAPIAgent(BaseOpenAICompatibleAPIAgent):
     """Implements the OpenAPI-compatible text completion and chat completion API backends."""
 
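Raising NotImplementedError keeps the base class instantiable and only fails when the hook is actually called. An alternative, not used in this commit, is to declare the hook with abc.abstractmethod so the error surfaces earlier, at instantiation. A short sketch under that assumption (the class names are simplified; the real base class inherits LocalLLMAgent):

# Sketch of an abc-based variant; this is an assumption for illustration,
# not what the commit does.
from abc import ABC, abstractmethod


class BaseAgent(ABC):
    @abstractmethod
    def _extract_response(self, response_json: dict) -> str:
        """Parse a backend-specific response payload into plain text."""


class GenericAgent(BaseAgent):
    def _extract_response(self, response_json: dict) -> str:
        return response_json["choices"][0]["message"]["content"]


# BaseAgent() would raise TypeError here; GenericAgent() instantiates fine.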
@@ -1316,6 +1309,16 @@ class GenericOpenAIAPIAgent(BaseOpenAICompatibleAPIAgent):
 
         return endpoint, request_params
 
+    def _extract_response(self, response_json: dict) -> str:
+        choices = response_json["choices"]
+        if choices[0]["finish_reason"] != "stop":
+            _LOGGER.warning("Model response did not end on a stop token (unfinished sentence)")
+
+        if response_json["object"] in ["chat.completion", "chat.completion.chunk"]:
+            return choices[0]["message"]["content"]
+        else:
+            return choices[0]["text"]
+
     async def _async_generate(self, conversation: dict) -> str:
         use_chat_api = self.entry.options.get(CONF_REMOTE_USE_CHAT_ENDPOINT, DEFAULT_REMOTE_USE_CHAT_ENDPOINT)
 
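For reference, a self-contained sketch of how the parser moved into GenericOpenAIAPIAgent behaves against the two response shapes it distinguishes (chat completion vs. legacy text completion). The sample payloads are made up for illustration, and the module-level logger stands in for the integration's _LOGGER:

# Stand-alone sketch of the _extract_response logic; payloads are hypothetical.
import logging

_LOGGER = logging.getLogger(__name__)


def extract_response(response_json: dict) -> str:
    choices = response_json["choices"]
    if choices[0]["finish_reason"] != "stop":
        _LOGGER.warning("Model response did not end on a stop token (unfinished sentence)")

    # Chat-style payloads nest the text under "message"; text-completion
    # payloads expose it directly as "text".
    if response_json["object"] in ["chat.completion", "chat.completion.chunk"]:
        return choices[0]["message"]["content"]
    else:
        return choices[0]["text"]


chat_payload = {
    "object": "chat.completion",
    "choices": [{"finish_reason": "stop", "message": {"content": "The light is on."}}],
}
text_payload = {
    "object": "text_completion",
    "choices": [{"finish_reason": "stop", "text": "The light is on."}],
}

assert extract_response(chat_payload) == extract_response(text_payload)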