Extract most of _async_generate to helper method

This commit is contained in:
Simon Redman
2025-05-31 14:57:55 -04:00
parent 82621674e4
commit 8a059fab29

View File

@@ -1251,6 +1251,50 @@ class BaseOpenAICompatibleAPIAgent(LocalLLMAgent):
else:
return choices[0]["text"]
async def _async_generate_with_parameters(self, conversation: dict, endpoint: str, additional_params: dict) -> str:
"""Generate a response using the OpenAI-compatible API"""
max_tokens = self.entry.options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS)
temperature = self.entry.options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE)
top_p = self.entry.options.get(CONF_TOP_P, DEFAULT_TOP_P)
timeout = self.entry.options.get(CONF_REQUEST_TIMEOUT, DEFAULT_REQUEST_TIMEOUT)
request_params = {
"model": self.model_name,
"max_tokens": max_tokens,
"temperature": temperature,
"top_p": top_p,
}
request_params.update(additional_params)
headers = {}
if self.api_key:
headers["Authorization"] = f"Bearer {self.api_key}"
session = async_get_clientsession(self.hass)
response = None
try:
async with session.post(
f"{self.api_host}{endpoint}",
json=request_params,
timeout=timeout,
headers=headers
) as response:
response.raise_for_status()
result = await response.json()
except asyncio.TimeoutError:
return "The generation request timed out! Please check your connection settings, increase the timeout in settings, or decrease the number of exposed entities."
except aiohttp.ClientError as err:
_LOGGER.debug(f"Err was: {err}")
_LOGGER.debug(f"Request was: {request_params}")
_LOGGER.debug(f"Result was: {response}")
return f"Failed to communicate with the API! {err}"
_LOGGER.debug(result)
return self._extract_response(result)
class GenericOpenAIAPIAgent(BaseOpenAICompatibleAPIAgent):
"""Implements the OpenAPI-compatible text completion and chat completion API backends."""
@@ -1273,53 +1317,16 @@ class GenericOpenAIAPIAgent(BaseOpenAICompatibleAPIAgent):
return endpoint, request_params
async def _async_generate(self, conversation: dict) -> str:
max_tokens = self.entry.options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS)
temperature = self.entry.options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE)
top_p = self.entry.options.get(CONF_TOP_P, DEFAULT_TOP_P)
timeout = self.entry.options.get(CONF_REQUEST_TIMEOUT, DEFAULT_REQUEST_TIMEOUT)
use_chat_api = self.entry.options.get(CONF_REMOTE_USE_CHAT_ENDPOINT, DEFAULT_REMOTE_USE_CHAT_ENDPOINT)
request_params = {
"model": self.model_name,
"max_tokens": max_tokens,
"temperature": temperature,
"top_p": top_p,
}
if use_chat_api:
endpoint, additional_params = self._chat_completion_params(conversation)
else:
endpoint, additional_params = self._completion_params(conversation)
request_params.update(additional_params)
result = await self._async_generate_with_parameters(conversation, endpoint, additional_params)
headers = {}
if self.api_key:
headers["Authorization"] = f"Bearer {self.api_key}"
session = async_get_clientsession(self.hass)
response = None
try:
async with session.post(
f"{self.api_host}{endpoint}",
json=request_params,
timeout=timeout,
headers=headers
) as response:
response.raise_for_status()
result = await response.json()
except asyncio.TimeoutError:
return "The generation request timed out! Please check your connection settings, increase the timeout in settings, or decrease the number of exposed entities."
except aiohttp.ClientError as err:
_LOGGER.debug(f"Err was: {err}")
_LOGGER.debug(f"Request was: {request_params}")
_LOGGER.debug(f"Result was: {response}")
return f"Failed to communicate with the API! {err}"
_LOGGER.debug(result)
return self._extract_response(result)
return result
class GenericOpenAIResponsesAPIAgent(BaseOpenAICompatibleAPIAgent):
"""Implements the OpenAPI-compatible Responses API backend."""
@@ -1334,50 +1341,13 @@ class GenericOpenAIResponsesAPIAgent(BaseOpenAICompatibleAPIAgent):
return endpoint, request_params
async def _async_generate(self, conversation: dict) -> str:
max_tokens = self.entry.options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS)
temperature = self.entry.options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE)
top_p = self.entry.options.get(CONF_TOP_P, DEFAULT_TOP_P)
timeout = self.entry.options.get(CONF_REQUEST_TIMEOUT, DEFAULT_REQUEST_TIMEOUT)
use_chat_api = self.entry.options.get(CONF_REMOTE_USE_CHAT_ENDPOINT, DEFAULT_REMOTE_USE_CHAT_ENDPOINT)
request_params = {
"model": self.model_name,
"max_tokens": max_tokens,
"temperature": temperature,
"top_p": top_p,
}
"""Generate a response using the OpenAI-compatible Responses API"""
endpoint, additional_params = self._responses_params(conversation)
request_params.update(additional_params)
result = await self._async_generate_with_parameters(conversation, endpoint, additional_params)
headers = {}
if self.api_key:
headers["Authorization"] = f"Bearer {self.api_key}"
session = async_get_clientsession(self.hass)
response = None
try:
async with session.post(
f"{self.api_host}{endpoint}",
json=request_params,
timeout=timeout,
headers=headers
) as response:
response.raise_for_status()
result = await response.json()
except asyncio.TimeoutError:
return "The generation request timed out! Please check your connection settings, increase the timeout in settings, or decrease the number of exposed entities."
except aiohttp.ClientError as err:
_LOGGER.debug(f"Err was: {err}")
_LOGGER.debug(f"Request was: {request_params}")
_LOGGER.debug(f"Result was: {response}")
return f"Failed to communicate with the API! {err}"
_LOGGER.debug(result)
return self._extract_response(result)
return result
class TextGenerationWebuiAgent(GenericOpenAIAPIAgent):
admin_key: str