"""Defines the OpenAI API compatible agents"""
|
|
from __future__ import annotations
|
|
import json
|
|
|
|
import aiohttp
|
|
import asyncio
|
|
import datetime
|
|
import logging
|
|
from typing import List, Dict, Tuple, AsyncGenerator, Any, Optional
|
|
|
|
from homeassistant.core import HomeAssistant
|
|
from homeassistant.exceptions import HomeAssistantError
|
|
from homeassistant.components import conversation
|
|
from homeassistant.config_entries import ConfigEntry
|
|
from homeassistant.const import CONF_HOST, CONF_PORT, CONF_SSL
|
|
from homeassistant.helpers.aiohttp_client import async_get_clientsession
|
|
from homeassistant.helpers import llm
|
|
|
|
from custom_components.llama_conversation.utils import format_url, get_oai_formatted_messages, get_oai_formatted_tools, parse_raw_tool_call
|
|
from custom_components.llama_conversation.const import (
|
|
CONF_CHAT_MODEL,
|
|
CONF_MAX_TOKENS,
|
|
CONF_TEMPERATURE,
|
|
CONF_TOP_P,
|
|
CONF_REQUEST_TIMEOUT,
|
|
CONF_OPENAI_API_KEY,
|
|
CONF_REMEMBER_CONVERSATION,
|
|
CONF_REMEMBER_CONVERSATION_TIME_MINUTES,
|
|
CONF_GENERIC_OPENAI_PATH,
|
|
CONF_ENABLE_LEGACY_TOOL_CALLING,
|
|
CONF_RESPONSE_JSON_SCHEMA,
|
|
DEFAULT_MAX_TOKENS,
|
|
DEFAULT_TEMPERATURE,
|
|
DEFAULT_TOP_P,
|
|
DEFAULT_REQUEST_TIMEOUT,
|
|
DEFAULT_REMEMBER_CONVERSATION,
|
|
DEFAULT_REMEMBER_CONVERSATION_TIME_MINUTES,
|
|
DEFAULT_GENERIC_OPENAI_PATH,
|
|
DEFAULT_ENABLE_LEGACY_TOOL_CALLING,
|
|
RECOMMENDED_CHAT_MODELS,
|
|
)
|
|
from custom_components.llama_conversation.entity import TextGenerationResult, LocalLLMClient
|
|
|
|
_LOGGER = logging.getLogger(__name__)
|
|
|
|
class GenericOpenAIAPIClient(LocalLLMClient):
    """Implements the OpenAI-compatible text completion and chat completion API backends."""

    api_host: str
    api_key: str

    _attr_supports_streaming = True

    def __init__(self, hass: HomeAssistant, client_options: dict[str, Any]) -> None:
        super().__init__(hass, client_options)
        self.api_host = format_url(
            hostname=client_options[CONF_HOST],
            port=client_options[CONF_PORT],
            ssl=client_options[CONF_SSL],
            path="/" + client_options.get(CONF_GENERIC_OPENAI_PATH, DEFAULT_GENERIC_OPENAI_PATH)
        )

        self.api_key = client_options.get(CONF_OPENAI_API_KEY, "")

    @staticmethod
    def get_name(client_options: dict[str, Any]):
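        """Return the display name for this backend, built from the configured base URL."""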
        host = client_options[CONF_HOST]
        port = client_options[CONF_PORT]
        ssl = client_options[CONF_SSL]
        # use .get() with the default path so a missing option doesn't raise KeyError
        path = "/" + client_options.get(CONF_GENERIC_OPENAI_PATH, DEFAULT_GENERIC_OPENAI_PATH)
        return f"Generic OpenAI at '{format_url(hostname=host, port=port, ssl=ssl, path=path)}'"

    @staticmethod
    async def async_validate_connection(hass: HomeAssistant, user_input: Dict[str, Any]) -> str | None:
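        """Probe the backend's /models endpoint; return None on success, otherwise a short error description."""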
        headers = {}
        api_key = user_input.get(CONF_OPENAI_API_KEY)
        api_base_path = user_input.get(CONF_GENERIC_OPENAI_PATH, DEFAULT_GENERIC_OPENAI_PATH)
        if api_key:
            headers["Authorization"] = f"Bearer {api_key}"

        try:
            session = async_get_clientsession(hass)
            async with session.get(
                format_url(
                    hostname=user_input[CONF_HOST],
                    port=user_input[CONF_PORT],
                    ssl=user_input[CONF_SSL],
                    path=f"/{api_base_path}/models"
                ),
                timeout=aiohttp.ClientTimeout(total=5),  # quick timeout
                headers=headers
            ) as response:
                if response.ok:
                    return None
                else:
                    return f"HTTP Status {response.status}"
        except Exception as ex:
            return str(ex)

    async def async_get_available_models(self) -> List[str]:
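        """Fetch the model IDs the backend advertises, falling back to the recommended models on failure."""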
        headers = {}
        if self.api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"

        try:
            session = async_get_clientsession(self.hass)
            async with session.get(
                f"{self.api_host}/models",
                timeout=aiohttp.ClientTimeout(total=5),  # quick timeout
                headers=headers
            ) as response:
                response.raise_for_status()
                models_result = await response.json()
        except (asyncio.TimeoutError, aiohttp.ClientError):
            # aiohttp.ClientError also covers connection failures, not just bad HTTP statuses
            _LOGGER.exception("Failed to get available models")
            return RECOMMENDED_CHAT_MODELS

        return [model["id"] for model in models_result["data"]]

    def _generate_stream(self,
        conversation: List[conversation.Content],
        llm_api: llm.APIInstance | None,
        agent_id: str,
        entity_options: dict[str, Any],
    ) -> AsyncGenerator[TextGenerationResult, None]:
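        """Stream a chat completion from the backend, yielding partial text and tool calls as they arrive."""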
        model_name = entity_options[CONF_CHAT_MODEL]
        temperature = entity_options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE)
        top_p = entity_options.get(CONF_TOP_P, DEFAULT_TOP_P)
        timeout = entity_options.get(CONF_REQUEST_TIMEOUT, DEFAULT_REQUEST_TIMEOUT)
        enable_legacy_tool_calling = entity_options.get(CONF_ENABLE_LEGACY_TOOL_CALLING, DEFAULT_ENABLE_LEGACY_TOOL_CALLING)

        endpoint, additional_params = self._chat_completion_params(entity_options)
        messages = get_oai_formatted_messages(conversation, user_content_as_list=True)

        request_params = {
            "model": model_name,
            "stream": True,
            "temperature": temperature,
            "top_p": top_p,
            "messages": messages
        }

        response_json_schema = entity_options.get(CONF_RESPONSE_JSON_SCHEMA)
        if response_json_schema:
            request_params["response_format"] = {
                "type": "json_schema",
                "json_schema": {
                    "name": "ha_task",
                    "schema": response_json_schema,
                    "strict": True,
                },
            }

        tools = None
        # "legacy" tool calling passes the tools directly as part of the system prompt instead of as "tools";
        # most local backends mangle the prompt formatting when native tool calling is used
        if llm_api and not enable_legacy_tool_calling:
            tools = get_oai_formatted_tools(llm_api, self._async_get_all_exposed_domains())
            request_params["tools"] = tools

        request_params.update(additional_params)

        headers = {}
        if self.api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"

        _LOGGER.debug(f"Generating completion with {len(messages)} messages and {len(tools) if tools else 0} tools...")

        session = async_get_clientsession(self.hass)

        async def anext_token() -> AsyncGenerator[Tuple[Optional[str], Optional[List[dict]]], None]:
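            """Post the request and yield (text, tool_calls) tuples parsed from the server-sent event stream."""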
            response = None
            chunk = None
            try:
                async with session.post(
                    f"{self.api_host}{endpoint}",
                    json=request_params,
                    timeout=aiohttp.ClientTimeout(total=timeout),
                    headers=headers
                ) as response:
                    response.raise_for_status()
                    async for line_bytes in response.content:
                        raw_line = line_bytes.decode("utf-8").strip()
                        if raw_line.startswith("error: "):
                            raise Exception(f"Error from server: {raw_line}")
                        # each SSE event line looks like 'data: {json}'; the stream ends with 'data: [DONE]'
                        chunk = raw_line.removeprefix("data: ")
                        if "[DONE]" in chunk:
                            break

                        if chunk and chunk.strip():
                            to_say, tool_calls = self._extract_response(json.loads(chunk))
                            if to_say or tool_calls:
                                yield to_say, tool_calls
            except asyncio.TimeoutError as err:
                raise HomeAssistantError("The generation request timed out! Please check your connection settings, increase the timeout in settings, or decrease the number of exposed entities.") from err
            except aiohttp.ClientError as err:
                raise HomeAssistantError(f"Failed to communicate with the API! {err}") from err

        return self._async_stream_parse_completion(llm_api, agent_id, entity_options, anext_token=anext_token())

    def _chat_completion_params(self, entity_options: dict[str, Any]) -> Tuple[str, Dict[str, Any]]:
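        """Return the completion endpoint and any extra request parameters to merge in (hook that subclasses may override)."""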
        request_params = {}
        endpoint = "/chat/completions"
        return endpoint, request_params

    def _extract_response(self, response_json: dict) -> Tuple[Optional[str], Optional[List[dict]]]:
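        """Extract the text and any tool calls from a chat completion response, streamed chunk, or text completion."""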
if "choices" not in response_json or len(response_json["choices"]) == 0: # finished
|
|
_LOGGER.warning("Response missing or empty 'choices'. Keys present: %s. Full response: %s",
|
|
list(response_json.keys()), response_json)
|
|
return None, None
|
|
|
|
choice = response_json["choices"][0]
|
|
tool_calls = None
|
|
if response_json["object"] == "chat.completion":
|
|
response_text = choice["message"]["content"]
|
|
streamed = False
|
|
elif response_json["object"] == "chat.completion.chunk":
|
|
response_text = choice["delta"].get("content", "")
|
|
if "tool_calls" in choice["delta"] and choice["delta"]["tool_calls"] is not None:
|
|
tool_calls = [call["function"] for call in choice["delta"]["tool_calls"]]
|
|
streamed = True
|
|
else:
|
|
response_text = choice["text"]
|
|
streamed = False
|
|
|
|
if not streamed or (streamed and choice.get("finish_reason")):
|
|
finish_reason = choice.get("finish_reason")
|
|
if finish_reason in ("length", "content_filter"):
|
|
_LOGGER.warning("Model response did not end on a stop token (unfinished sentence)")
|
|
|
|
return response_text, tool_calls
|
|
|
|
|
|
class GenericOpenAIResponsesAPIClient(LocalLLMClient):
    """Implements the OpenAI-compatible Responses API backend."""

    api_host: str
    api_key: str

    _attr_supports_streaming = False

    _last_response_id: str | None = None
    _last_response_id_time: datetime.datetime | None = None

    def __init__(self, hass: HomeAssistant, client_options: dict[str, Any]) -> None:
        super().__init__(hass, client_options)
        self.api_host = format_url(
            hostname=client_options[CONF_HOST],
            port=client_options[CONF_PORT],
            ssl=client_options[CONF_SSL],
            path="/" + client_options.get(CONF_GENERIC_OPENAI_PATH, DEFAULT_GENERIC_OPENAI_PATH)
        )

        self.api_key = client_options.get(CONF_OPENAI_API_KEY, "")

    @staticmethod
    def get_name(client_options: dict[str, Any]):
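        """Return the display name for this backend, built from the configured base URL."""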
        host = client_options[CONF_HOST]
        port = client_options[CONF_PORT]
        ssl = client_options[CONF_SSL]
        # use .get() with the default path so a missing option doesn't raise KeyError
        path = "/" + client_options.get(CONF_GENERIC_OPENAI_PATH, DEFAULT_GENERIC_OPENAI_PATH)
        return f"Generic OpenAI at '{format_url(hostname=host, port=port, ssl=ssl, path=path)}'"

    def _responses_params(self, conversation: List[conversation.Content], entity_options: Dict[str, Any]) -> Tuple[str, Dict[str, Any]]:
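        """Build the /responses request parameters: the latest user message as input, plus previous_response_id when recent conversation memory applies."""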
        request_params = {}

        endpoint = "/responses"
        # Find the last user message in the conversation and use its content as the input
        input_text: str | None = None
        for msg in reversed(conversation):
            try:
                if msg.role == "user":
                    input_text = msg.content
                    break  # stop at the most recent user message; without this we'd end up with the oldest one
            except Exception:
                continue

        if input_text is None:
            # fall back to the last message content
            input_text = getattr(conversation[-1], "content", "")

        request_params["input"] = input_text

        # Assign previous_response_id if relevant
        if self._last_response_id and self._last_response_id_time and entity_options.get(CONF_REMEMBER_CONVERSATION, DEFAULT_REMEMBER_CONVERSATION):
            # If the last response was generated recently, use it as context
            configured_memory_time: datetime.timedelta = datetime.timedelta(minutes=entity_options.get(CONF_REMEMBER_CONVERSATION_TIME_MINUTES, DEFAULT_REMEMBER_CONVERSATION_TIME_MINUTES))
            last_conversation_age: datetime.timedelta = datetime.datetime.now() - self._last_response_id_time
            _LOGGER.debug(f"Conversation ID age: {last_conversation_age}")
            if last_conversation_age < configured_memory_time:
                _LOGGER.debug(f"Using previous response ID {self._last_response_id} for context")
                request_params["previous_response_id"] = self._last_response_id
            else:
                _LOGGER.debug(f"Previous response ID {self._last_response_id} is too old, not using it for context")

        return endpoint, request_params

    def _validate_response_payload(self, response_json: dict) -> bool:
        """
        Validate that the given payload matches the expected structure for the Responses API.

        API ref: https://platform.openai.com/docs/api-reference/responses/object

        Returns True or raises a ValueError.
        """
        required_response_keys = ["object", "output", "status", "id"]
        missing_keys = [key for key in required_response_keys if key not in response_json]
        if missing_keys:
            raise ValueError(f"Response JSON is missing required keys: {', '.join(missing_keys)}")

        if response_json["object"] != "response":
            raise ValueError(f"Response JSON object is not 'response', got {response_json['object']}")

        if "error" in response_json and response_json["error"] is not None:
            error = response_json["error"]
            _LOGGER.error("Response received an error payload.")
            if "message" not in error:
                raise ValueError("Response JSON error is missing 'message' key")
            raise ValueError(f"Response JSON error: {error['message']}")

        return True

    def _check_response_status(self, response_json: dict) -> None:
        """
        Check the status of the response and log a warning if it is not 'completed'.

        API ref: https://platform.openai.com/docs/api-reference/responses/object#responses_object-status
        """
        if response_json["status"] != "completed":
            _LOGGER.warning(f"Response status is not 'completed', got {response_json['status']}. Details: {response_json.get('incomplete_details', 'No details provided')}")

    def _extract_response(self, response_json: dict) -> str | None:
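        """Pull the output text (or refusal) out of a Responses API payload and remember its response ID for follow-ups."""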
        self._validate_response_payload(response_json)
        self._check_response_status(response_json)

        outputs = response_json["output"]

        if len(outputs) > 1:
            _LOGGER.warning("Received multiple outputs from the Responses API, returning the first one.")

        output = outputs[0]

        if output["type"] != "message":
            raise NotImplementedError(f"Response output type is not 'message', got {output['type']}")

        if len(output["content"]) > 1:
            _LOGGER.warning("Received multiple content items in the response output, returning the first one.")

        content = output["content"][0]

        output_type = content["type"]

        to_return: str | None = None

        if output_type == "refusal":
            _LOGGER.info("Received a refusal from the Responses API.")
            to_return = content["refusal"]
        elif output_type == "output_text":
            to_return = content["text"]
        else:
            raise ValueError(f"Response output content type is not expected, got {output_type}")

        # Save the response_id and return the successful response.
        response_id = response_json["id"]
        self._last_response_id = response_id
        self._last_response_id_time = datetime.datetime.now()

        return to_return

    async def _generate(self,
        conversation: List[conversation.Content],
        llm_api: llm.APIInstance | None,
        agent_id: str,
        entity_options: dict[str, Any],
    ) -> TextGenerationResult:
        """Generate a response using the OpenAI-compatible Responses API (non-streaming)."""

        model_name = entity_options.get(CONF_CHAT_MODEL)
        timeout = entity_options.get(CONF_REQUEST_TIMEOUT, DEFAULT_REQUEST_TIMEOUT)
        endpoint, additional_params = self._responses_params(conversation, entity_options)

        request_params: Dict[str, Any] = {
            "model": model_name,
        }
        response_json_schema = entity_options.get(CONF_RESPONSE_JSON_SCHEMA)
        if response_json_schema:
            request_params["response_format"] = {
                "type": "json_schema",
                "json_schema": {
                    "name": "ha_task",
                    "schema": response_json_schema,
                    "strict": True,
                },
            }
        request_params.update(additional_params)

        headers: Dict[str, Any] = {}
        if self.api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"

        session = async_get_clientsession(self.hass)
        response = None

        try:
            async with session.post(
                f"{self.api_host}{endpoint}",
                json=request_params,
                timeout=aiohttp.ClientTimeout(total=timeout),
                headers=headers,
            ) as response:
                response.raise_for_status()
                response_json = await response.json()

            try:
                text = self._extract_response(response_json)
                if not text:
                    return TextGenerationResult(raise_error=True, error_msg="The Responses API returned an empty response.")
                # return await self._async_parse_completion(llm_api, agent_id, entity_options, text)
                return TextGenerationResult(response=text)  # Currently we don't extract any info from the response besides the raw model output
            except Exception as err:
                _LOGGER.exception("Failed to parse Responses API payload: %s", err)
                return TextGenerationResult(raise_error=True, error_msg=f"Failed to parse Responses API payload: {err}")
        except asyncio.TimeoutError:
            return TextGenerationResult(raise_error=True, error_msg="The generation request timed out! Please check your connection settings, increase the timeout in settings, or decrease the number of exposed entities.")
        except aiohttp.ClientError as err:
            _LOGGER.debug(f"Err was: {err}")
            _LOGGER.debug(f"Request was: {request_params}")
            _LOGGER.debug(f"Result was: {response}")
            return TextGenerationResult(raise_error=True, error_msg=f"Failed to communicate with the API! {err}")