mirror of
https://github.com/acon96/home-llm.git
synced 2026-01-08 21:28:05 -05:00
Extract most of _async_generate to helper method
This commit is contained in:
@@ -1251,6 +1251,50 @@ class BaseOpenAICompatibleAPIAgent(LocalLLMAgent):
|
||||
else:
|
||||
return choices[0]["text"]
|
||||
|
||||
async def _async_generate_with_parameters(self, conversation: dict, endpoint: str, additional_params: dict) -> str:
|
||||
"""Generate a response using the OpenAI-compatible API"""
|
||||
|
||||
max_tokens = self.entry.options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS)
|
||||
temperature = self.entry.options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE)
|
||||
top_p = self.entry.options.get(CONF_TOP_P, DEFAULT_TOP_P)
|
||||
timeout = self.entry.options.get(CONF_REQUEST_TIMEOUT, DEFAULT_REQUEST_TIMEOUT)
|
||||
|
||||
request_params = {
|
||||
"model": self.model_name,
|
||||
"max_tokens": max_tokens,
|
||||
"temperature": temperature,
|
||||
"top_p": top_p,
|
||||
}
|
||||
|
||||
request_params.update(additional_params)
|
||||
|
||||
headers = {}
|
||||
if self.api_key:
|
||||
headers["Authorization"] = f"Bearer {self.api_key}"
|
||||
|
||||
session = async_get_clientsession(self.hass)
|
||||
response = None
|
||||
try:
|
||||
async with session.post(
|
||||
f"{self.api_host}{endpoint}",
|
||||
json=request_params,
|
||||
timeout=timeout,
|
||||
headers=headers
|
||||
) as response:
|
||||
response.raise_for_status()
|
||||
result = await response.json()
|
||||
except asyncio.TimeoutError:
|
||||
return "The generation request timed out! Please check your connection settings, increase the timeout in settings, or decrease the number of exposed entities."
|
||||
except aiohttp.ClientError as err:
|
||||
_LOGGER.debug(f"Err was: {err}")
|
||||
_LOGGER.debug(f"Request was: {request_params}")
|
||||
_LOGGER.debug(f"Result was: {response}")
|
||||
return f"Failed to communicate with the API! {err}"
|
||||
|
||||
_LOGGER.debug(result)
|
||||
|
||||
return self._extract_response(result)
|
||||
|
||||
class GenericOpenAIAPIAgent(BaseOpenAICompatibleAPIAgent):
|
||||
"""Implements the OpenAPI-compatible text completion and chat completion API backends."""
|
||||
|
||||
@@ -1273,53 +1317,16 @@ class GenericOpenAIAPIAgent(BaseOpenAICompatibleAPIAgent):
|
||||
return endpoint, request_params
|
||||
|
||||
async def _async_generate(self, conversation: dict) -> str:
|
||||
max_tokens = self.entry.options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS)
|
||||
temperature = self.entry.options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE)
|
||||
top_p = self.entry.options.get(CONF_TOP_P, DEFAULT_TOP_P)
|
||||
timeout = self.entry.options.get(CONF_REQUEST_TIMEOUT, DEFAULT_REQUEST_TIMEOUT)
|
||||
use_chat_api = self.entry.options.get(CONF_REMOTE_USE_CHAT_ENDPOINT, DEFAULT_REMOTE_USE_CHAT_ENDPOINT)
|
||||
|
||||
|
||||
request_params = {
|
||||
"model": self.model_name,
|
||||
"max_tokens": max_tokens,
|
||||
"temperature": temperature,
|
||||
"top_p": top_p,
|
||||
}
|
||||
|
||||
if use_chat_api:
|
||||
endpoint, additional_params = self._chat_completion_params(conversation)
|
||||
else:
|
||||
endpoint, additional_params = self._completion_params(conversation)
|
||||
|
||||
request_params.update(additional_params)
|
||||
result = await self._async_generate_with_parameters(conversation, endpoint, additional_params)
|
||||
|
||||
headers = {}
|
||||
if self.api_key:
|
||||
headers["Authorization"] = f"Bearer {self.api_key}"
|
||||
|
||||
session = async_get_clientsession(self.hass)
|
||||
response = None
|
||||
try:
|
||||
async with session.post(
|
||||
f"{self.api_host}{endpoint}",
|
||||
json=request_params,
|
||||
timeout=timeout,
|
||||
headers=headers
|
||||
) as response:
|
||||
response.raise_for_status()
|
||||
result = await response.json()
|
||||
except asyncio.TimeoutError:
|
||||
return "The generation request timed out! Please check your connection settings, increase the timeout in settings, or decrease the number of exposed entities."
|
||||
except aiohttp.ClientError as err:
|
||||
_LOGGER.debug(f"Err was: {err}")
|
||||
_LOGGER.debug(f"Request was: {request_params}")
|
||||
_LOGGER.debug(f"Result was: {response}")
|
||||
return f"Failed to communicate with the API! {err}"
|
||||
|
||||
_LOGGER.debug(result)
|
||||
|
||||
return self._extract_response(result)
|
||||
return result
|
||||
|
||||
class GenericOpenAIResponsesAPIAgent(BaseOpenAICompatibleAPIAgent):
|
||||
"""Implements the OpenAPI-compatible Responses API backend."""
|
||||
@@ -1334,50 +1341,13 @@ class GenericOpenAIResponsesAPIAgent(BaseOpenAICompatibleAPIAgent):
|
||||
return endpoint, request_params
|
||||
|
||||
async def _async_generate(self, conversation: dict) -> str:
|
||||
max_tokens = self.entry.options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS)
|
||||
temperature = self.entry.options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE)
|
||||
top_p = self.entry.options.get(CONF_TOP_P, DEFAULT_TOP_P)
|
||||
timeout = self.entry.options.get(CONF_REQUEST_TIMEOUT, DEFAULT_REQUEST_TIMEOUT)
|
||||
use_chat_api = self.entry.options.get(CONF_REMOTE_USE_CHAT_ENDPOINT, DEFAULT_REMOTE_USE_CHAT_ENDPOINT)
|
||||
|
||||
|
||||
request_params = {
|
||||
"model": self.model_name,
|
||||
"max_tokens": max_tokens,
|
||||
"temperature": temperature,
|
||||
"top_p": top_p,
|
||||
}
|
||||
"""Generate a response using the OpenAI-compatible Responses API"""
|
||||
|
||||
endpoint, additional_params = self._responses_params(conversation)
|
||||
|
||||
request_params.update(additional_params)
|
||||
result = await self._async_generate_with_parameters(conversation, endpoint, additional_params)
|
||||
|
||||
headers = {}
|
||||
if self.api_key:
|
||||
headers["Authorization"] = f"Bearer {self.api_key}"
|
||||
|
||||
session = async_get_clientsession(self.hass)
|
||||
response = None
|
||||
try:
|
||||
async with session.post(
|
||||
f"{self.api_host}{endpoint}",
|
||||
json=request_params,
|
||||
timeout=timeout,
|
||||
headers=headers
|
||||
) as response:
|
||||
response.raise_for_status()
|
||||
result = await response.json()
|
||||
except asyncio.TimeoutError:
|
||||
return "The generation request timed out! Please check your connection settings, increase the timeout in settings, or decrease the number of exposed entities."
|
||||
except aiohttp.ClientError as err:
|
||||
_LOGGER.debug(f"Err was: {err}")
|
||||
_LOGGER.debug(f"Request was: {request_params}")
|
||||
_LOGGER.debug(f"Result was: {response}")
|
||||
return f"Failed to communicate with the API! {err}"
|
||||
|
||||
_LOGGER.debug(result)
|
||||
|
||||
return self._extract_response(result)
|
||||
return result
|
||||
|
||||
class TextGenerationWebuiAgent(GenericOpenAIAPIAgent):
|
||||
admin_key: str
|
||||
|
||||
Reference in New Issue
Block a user