mirror of https://github.com/acon96/home-llm.git
synced 2026-01-09 13:48:05 -05:00
Merge pull request #279 from sredman/work/sredman/openai-responses
Implement basic OpenAI Responses API
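The diff below adds a GenericOpenAIResponsesAPIAgent backend that sends only the latest user message to a /responses endpoint and relies on previous_response_id (within a configurable time window) for conversation memory. As a rough, non-authoritative illustration of the request/response shape that code targets, here is a minimal sketch; the host, base path, model name, API key, and the generate() helper are placeholder assumptions, not code from this repository.

import requests

API_HOST = "http://localhost:8080"   # placeholder server
API_BASE_PATH = "v1"                 # maps to CONF_GENERIC_OPENAI_PATH in the integration
API_KEY = "sk-placeholder"           # placeholder key

def generate(user_message: str, previous_response_id: str | None = None) -> tuple[str, str]:
    """Send one turn to the /responses endpoint and return (text, response_id)."""
    payload = {
        "model": "my-model",     # placeholder model name
        "input": user_message,   # the agent sends only the latest user message
    }
    if previous_response_id:
        # The agent reuses the previous response id (if it is recent enough)
        # so the server carries the conversation context itself.
        payload["previous_response_id"] = previous_response_id

    resp = requests.post(
        f"{API_HOST}/{API_BASE_PATH}/responses",
        json=payload,
        headers={"Authorization": f"Bearer {API_KEY}"},
        timeout=90,
    )
    resp.raise_for_status()
    body = resp.json()

    # Mirrors the checks in _extract_response(): object must be "response", and the
    # first output is a "message" whose first content item is "output_text" or "refusal".
    if body["object"] != "response":
        raise ValueError(f"unexpected object type: {body['object']}")
    content = body["output"][0]["content"][0]
    text = content.get("text") or content.get("refusal", "")
    return text, body["id"]

# Example flow: the second call passes the id returned by the first one.
# reply, rid = generate("Turn on the kitchen lights")
# follow_up, _ = generate("And the living room too", previous_response_id=rid)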
@@ -26,11 +26,12 @@ from .const import (
BACKEND_TYPE_LLAMA_EXISTING,
BACKEND_TYPE_TEXT_GEN_WEBUI,
BACKEND_TYPE_GENERIC_OPENAI,
BACKEND_TYPE_GENERIC_OPENAI_RESPONSES,
BACKEND_TYPE_LLAMA_CPP_PYTHON_SERVER,
BACKEND_TYPE_OLLAMA,
)
from .conversation import LlamaCppAgent, GenericOpenAIAPIAgent, TextGenerationWebuiAgent, \
LlamaCppPythonAPIAgent, OllamaAPIAgent, LocalLLMAgent
from .conversation import LlamaCppAgent, GenericOpenAIAPIAgent, GenericOpenAIResponsesAPIAgent, \
TextGenerationWebuiAgent, LlamaCppPythonAPIAgent, OllamaAPIAgent, LocalLLMAgent

type LocalLLMConfigEntry = ConfigEntry[LocalLLMAgent]

@@ -56,13 +57,15 @@ async def async_setup_entry(hass: HomeAssistant, entry: LocalLLMConfigEntry) ->
agent_cls = LlamaCppAgent
elif backend_type == BACKEND_TYPE_GENERIC_OPENAI:
agent_cls = GenericOpenAIAPIAgent
elif backend_type == BACKEND_TYPE_GENERIC_OPENAI_RESPONSES:
agent_cls = GenericOpenAIResponsesAPIAgent
elif backend_type == BACKEND_TYPE_TEXT_GEN_WEBUI:
agent_cls = TextGenerationWebuiAgent
elif backend_type == BACKEND_TYPE_LLAMA_CPP_PYTHON_SERVER:
agent_cls = LlamaCppPythonAPIAgent
elif backend_type == BACKEND_TYPE_OLLAMA:
agent_cls = OllamaAPIAgent

return agent_cls(hass, entry)

# create the agent in an executor job because the constructor calls `open()`
@@ -74,7 +77,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: LocalLLMConfigEntry) ->

# forward setup to platform to register the entity
await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS)

return True

@@ -130,12 +133,12 @@ class HassServiceTool(llm.Tool):
domain, service = tuple(tool_input.tool_args["service"].split("."))
except ValueError:
return { "result": "unknown service" }

target_device = tool_input.tool_args["target_device"]

if domain not in self.ALLOWED_DOMAINS or service not in self.ALLOWED_SERVICES:
return { "result": "unknown service" }

if domain == "script" and service not in ["reload", "turn_on", "turn_off", "toggle"]:
return { "result": "unknown service" }

@@ -153,12 +156,12 @@ class HassServiceTool(llm.Tool):
except Exception:
_LOGGER.exception("Failed to execute service for model")
return { "result": "failed" }

return { "result": "success" }

class HomeLLMAPI(llm.API):
"""
An API that allows calling Home Assistant services to maintain compatibility
with the older (v3 and older) Home LLM models
"""
@@ -65,6 +65,7 @@ from .const import (
CONF_REFRESH_SYSTEM_PROMPT,
CONF_REMEMBER_CONVERSATION,
CONF_REMEMBER_NUM_INTERACTIONS,
CONF_REMEMBER_CONVERSATION_TIME_MINUTES,
CONF_PROMPT_CACHING_ENABLED,
CONF_PROMPT_CACHING_INTERVAL,
CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES,
@@ -133,6 +134,7 @@ from .const import (
BACKEND_TYPE_LLAMA_EXISTING,
BACKEND_TYPE_TEXT_GEN_WEBUI,
BACKEND_TYPE_GENERIC_OPENAI,
BACKEND_TYPE_GENERIC_OPENAI_RESPONSES,
BACKEND_TYPE_LLAMA_CPP_PYTHON_SERVER,
BACKEND_TYPE_OLLAMA,
PROMPT_TEMPLATE_DESCRIPTIONS,
@@ -164,10 +166,11 @@ def STEP_INIT_DATA_SCHEMA(backend_type=None):
CONF_BACKEND_TYPE,
default=backend_type if backend_type else DEFAULT_BACKEND_TYPE
): SelectSelector(SelectSelectorConfig(
options=[
BACKEND_TYPE_LLAMA_HF, BACKEND_TYPE_LLAMA_EXISTING,
BACKEND_TYPE_TEXT_GEN_WEBUI,
BACKEND_TYPE_GENERIC_OPENAI,
BACKEND_TYPE_GENERIC_OPENAI_RESPONSES,
BACKEND_TYPE_LLAMA_CPP_PYTHON_SERVER,
BACKEND_TYPE_OLLAMA
],
@@ -215,13 +218,13 @@ def STEP_REMOTE_SETUP_DATA_SCHEMA(backend_type: str, *, host=None, port=None, ss
extra1, extra2 = ({}, {})
default_port = DEFAULT_PORT

if backend_type == BACKEND_TYPE_TEXT_GEN_WEBUI:
extra2[vol.Optional(CONF_TEXT_GEN_WEBUI_ADMIN_KEY)] = TextSelector(TextSelectorConfig(type="password"))
elif backend_type == BACKEND_TYPE_LLAMA_CPP_PYTHON_SERVER:
default_port = "8000"
elif backend_type == BACKEND_TYPE_OLLAMA:
default_port = "11434"
elif backend_type == BACKEND_TYPE_GENERIC_OPENAI:
elif backend_type in [BACKEND_TYPE_GENERIC_OPENAI, BACKEND_TYPE_GENERIC_OPENAI_RESPONSES]:
default_port = ""
extra2[vol.Required(
CONF_GENERIC_OPENAI_PATH,
@@ -257,7 +260,7 @@ def STEP_REMOTE_SETUP_DATA_SCHEMA(backend_type: str, *, host=None, port=None, ss

class BaseLlamaConversationConfigFlow(FlowHandler, ABC):
"""Represent the base config flow for Z-Wave JS."""
"""Represent the base config flow for Local LLM."""

@property
@abstractmethod
@@ -333,7 +336,7 @@ class ConfigFlow(BaseLlamaConversationConfigFlow, config_entries.ConfigFlow, dom
"""Handle the initial step."""
self.model_config = {}
self.options = {}

# make sure the API is registered
if not any([x.id == HOME_LLM_API_ID for x in llm.async_get_apis(self.hass)]):
llm.async_register_api(self.hass, HomeLLMAPI(self.hass))
@@ -382,7 +385,7 @@ class ConfigFlow(BaseLlamaConversationConfigFlow, config_entries.ConfigFlow, dom
step_id="install_local_wheels",
progress_action="install_local_wheels",
)

if self.install_wheel_task and not self.install_wheel_task.done():
return self.async_show_progress(
progress_task=self.install_wheel_task,
@@ -489,7 +492,7 @@ class ConfigFlow(BaseLlamaConversationConfigFlow, config_entries.ConfigFlow, dom
step_id="download",
progress_action="download",
)

if self.download_task and not self.download_task.done():
return self.async_show_progress(
progress_task=self.download_task,
@@ -508,7 +511,7 @@ class ConfigFlow(BaseLlamaConversationConfigFlow, config_entries.ConfigFlow, dom

self.download_task = None
return self.async_show_progress_done(next_step_id=next_step)

async def _async_validate_generic_openai(self, user_input: dict) -> tuple:
"""
Validates a connection to an OpenAI compatible API server and that the model exists on the remote server
@@ -585,7 +588,7 @@ class ConfigFlow(BaseLlamaConversationConfigFlow, config_entries.ConfigFlow, dom
except Exception as ex:
_LOGGER.info("Connection error was: %s", repr(ex))
return "failed_to_connect", ex, []

async def _async_validate_ollama(self, user_input: dict) -> tuple:
"""
Validates a connection to ollama and that the model exists on the remote server
@@ -617,7 +620,7 @@ class ConfigFlow(BaseLlamaConversationConfigFlow, config_entries.ConfigFlow, dom
model_name = self.model_config[CONF_CHAT_MODEL]
if model["name"] == model_name:
return (None, None, [])

return "missing_model_api", None, [x["name"] for x in models_result["models"]]

except Exception as ex:
@@ -644,9 +647,9 @@ class ConfigFlow(BaseLlamaConversationConfigFlow, config_entries.ConfigFlow, dom
error_message, ex, possible_models = await self._async_validate_text_generation_webui(user_input)
elif backend_type == BACKEND_TYPE_OLLAMA:
error_message, ex, possible_models = await self._async_validate_ollama(user_input)
elif backend_type == BACKEND_TYPE_GENERIC_OPENAI and \
elif backend_type in [BACKEND_TYPE_GENERIC_OPENAI, BACKEND_TYPE_GENERIC_OPENAI_RESPONSES] and \
user_input.get(CONF_GENERIC_OPENAI_VALIDATE_MODEL, DEFAULT_GENERIC_OPENAI_VALIDATE_MODEL):
error_message, ex, possible_models = await self._async_validate_generic_openai(user_input)
else:
possible_models = []

@@ -674,7 +677,7 @@ class ConfigFlow(BaseLlamaConversationConfigFlow, config_entries.ConfigFlow, dom
return self.async_show_form(
step_id="remote_model", data_schema=schema, errors=errors, description_placeholders=description_placeholders, last_step=False
)

async def async_step_model_parameters(
self, user_input: dict[str, Any] | None = None
) -> FlowResult:
@@ -704,7 +707,7 @@ class ConfigFlow(BaseLlamaConversationConfigFlow, config_entries.ConfigFlow, dom
selected_default_options[CONF_PROMPT] = selected_default_options[CONF_PROMPT].replace("<tools>", tools)
selected_default_options[CONF_PROMPT] = selected_default_options[CONF_PROMPT].replace("<area>", area)
selected_default_options[CONF_PROMPT] = selected_default_options[CONF_PROMPT].replace("<user_instruction>", user_instruction)

schema = vol.Schema(local_llama_config_option_schema(self.hass, selected_default_options, backend_type))

if user_input:
@@ -716,7 +719,7 @@ class ConfigFlow(BaseLlamaConversationConfigFlow, config_entries.ConfigFlow, dom
if not os.path.isfile(os.path.join(os.path.dirname(__file__), filename)):
errors["base"] = "missing_gbnf_file"
description_placeholders["filename"] = filename

if user_input.get(CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES):
filename = user_input.get(CONF_IN_CONTEXT_EXAMPLES_FILE, DEFAULT_IN_CONTEXT_EXAMPLES_FILE)
if not os.path.isfile(os.path.join(os.path.dirname(__file__), filename)):
@@ -725,7 +728,7 @@ class ConfigFlow(BaseLlamaConversationConfigFlow, config_entries.ConfigFlow, dom

if user_input[CONF_LLM_HASS_API] == "none":
user_input.pop(CONF_LLM_HASS_API)

if len(errors) == 0:
try:
# validate input
@@ -792,7 +795,7 @@ class OptionsFlow(config_entries.OptionsFlow):
if not os.path.isfile(os.path.join(os.path.dirname(__file__), filename)):
errors["base"] = "missing_gbnf_file"
description_placeholders["filename"] = filename

if user_input.get(CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES):
filename = user_input.get(CONF_IN_CONTEXT_EXAMPLES_FILE, DEFAULT_IN_CONTEXT_EXAMPLES_FILE)
if not os.path.isfile(os.path.join(os.path.dirname(__file__), filename)):
@@ -804,7 +807,7 @@ class OptionsFlow(config_entries.OptionsFlow):

if len(errors) == 0:
return self.async_create_entry(title="Local LLM Conversation", data=user_input)

schema = local_llama_config_option_schema(
self.hass,
self.config_entry.options,
@@ -1063,7 +1066,7 @@ def local_llama_config_option_schema(hass: HomeAssistant, options: MappingProxyT
mode=SelectSelectorMode.DROPDOWN,
)),
})
elif backend_type == BACKEND_TYPE_GENERIC_OPENAI:
elif backend_type in BACKEND_TYPE_GENERIC_OPENAI:
result = insert_after_key(result, CONF_MAX_TOKENS, {
vol.Required(
CONF_TEMPERATURE,
@@ -1086,6 +1089,32 @@ def local_llama_config_option_schema(hass: HomeAssistant, options: MappingProxyT
default=DEFAULT_REMOTE_USE_CHAT_ENDPOINT,
): BooleanSelector(BooleanSelectorConfig()),
})
elif backend_type in BACKEND_TYPE_GENERIC_OPENAI_RESPONSES:
del result[CONF_REMEMBER_NUM_INTERACTIONS]
result = insert_after_key(result, CONF_REMEMBER_CONVERSATION, {
vol.Required(
CONF_REMEMBER_CONVERSATION_TIME_MINUTES,
description={"suggested_value": options.get(CONF_REMEMBER_CONVERSATION_TIME_MINUTES)},
default=DEFAULT_TOP_P,
): NumberSelector(NumberSelectorConfig(min=0, max=180, step=0.5, unit_of_measurement=UnitOfTime.MINUTES, mode=NumberSelectorMode.BOX)),
})
result = insert_after_key(result, CONF_MAX_TOKENS, {
vol.Required(
CONF_TEMPERATURE,
description={"suggested_value": options.get(CONF_TEMPERATURE)},
default=DEFAULT_TEMPERATURE,
): NumberSelector(NumberSelectorConfig(min=0, max=3, step=0.05)),
vol.Required(
CONF_TOP_P,
description={"suggested_value": options.get(CONF_TOP_P)},
default=DEFAULT_TOP_P,
): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)),
vol.Required(
CONF_REQUEST_TIMEOUT,
description={"suggested_value": options.get(CONF_REQUEST_TIMEOUT)},
default=DEFAULT_REQUEST_TIMEOUT,
): NumberSelector(NumberSelectorConfig(min=5, max=900, step=1, unit_of_measurement=UnitOfTime.SECONDS, mode=NumberSelectorMode.BOX)),
})
elif backend_type == BACKEND_TYPE_LLAMA_CPP_PYTHON_SERVER:
result = insert_after_key(result, CONF_MAX_TOKENS, {
vol.Required(
@@ -110,6 +110,7 @@ BACKEND_TYPE_LLAMA_HF = "llama_cpp_hf"
BACKEND_TYPE_LLAMA_EXISTING = "llama_cpp_existing"
BACKEND_TYPE_TEXT_GEN_WEBUI = "text-generation-webui_api"
BACKEND_TYPE_GENERIC_OPENAI = "generic_openai"
BACKEND_TYPE_GENERIC_OPENAI_RESPONSES = "generic_openai_responses"
BACKEND_TYPE_LLAMA_CPP_PYTHON_SERVER = "llama_cpp_python_server"
BACKEND_TYPE_OLLAMA = "ollama"
DEFAULT_BACKEND_TYPE = BACKEND_TYPE_LLAMA_HF
@@ -118,8 +119,8 @@ CONF_SELECTED_LANGUAGE_OPTIONS = [ "en", "de", "fr", "es", "pl"]
CONF_DOWNLOADED_MODEL_QUANTIZATION = "downloaded_model_quantization"
CONF_DOWNLOADED_MODEL_QUANTIZATION_OPTIONS = [
"Q4_0", "Q4_1", "Q5_0", "Q5_1", "IQ2_XXS", "IQ2_XS", "IQ2_S", "IQ2_M", "IQ1_S", "IQ1_M",
"Q2_K", "Q2_K_S", "IQ3_XXS", "IQ3_S", "IQ3_M", "Q3_K", "IQ3_XS", "Q3_K_S", "Q3_K_M", "Q3_K_L",
"IQ4_NL", "IQ4_XS", "Q4_K", "Q4_K_S", "Q4_K_M", "Q5_K", "Q5_K_S", "Q5_K_M", "Q6_K", "Q8_0",
"F16", "BF16", "F32"
]
DEFAULT_DOWNLOADED_MODEL_QUANTIZATION = "Q4_K_M"
@@ -242,6 +243,8 @@ CONF_REMEMBER_CONVERSATION = "remember_conversation"
DEFAULT_REMEMBER_CONVERSATION = True
CONF_REMEMBER_NUM_INTERACTIONS = "remember_num_interactions"
DEFAULT_REMEMBER_NUM_INTERACTIONS = 5
CONF_REMEMBER_CONVERSATION_TIME_MINUTES = "remember_conversation_time_minutes"
DEFAULT_REMEMBER_CONVERSATION_TIME_MINUTES = 2
CONF_PROMPT_CACHING_ENABLED = "prompt_caching"
DEFAULT_PROMPT_CACHING_ENABLED = False
CONF_PROMPT_CACHING_INTERVAL = "prompt_caching_interval"
@@ -4,6 +4,7 @@ from __future__ import annotations
import aiohttp
import asyncio
import csv
import datetime
import importlib
import json
import logging
@@ -64,6 +65,7 @@ from .const import (
CONF_REFRESH_SYSTEM_PROMPT,
CONF_REMEMBER_CONVERSATION,
CONF_REMEMBER_NUM_INTERACTIONS,
CONF_REMEMBER_CONVERSATION_TIME_MINUTES,
CONF_PROMPT_CACHING_ENABLED,
CONF_PROMPT_CACHING_INTERVAL,
CONF_SERVICE_CALL_REGEX,
@@ -98,6 +100,7 @@ from .const import (
DEFAULT_REFRESH_SYSTEM_PROMPT,
DEFAULT_REMEMBER_CONVERSATION,
DEFAULT_REMEMBER_NUM_INTERACTIONS,
DEFAULT_REMEMBER_CONVERSATION_TIME_MINUTES,
DEFAULT_PROMPT_CACHING_ENABLED,
DEFAULT_PROMPT_CACHING_INTERVAL,
DEFAULT_SERVICE_CALL_REGEX,
@@ -141,7 +144,7 @@ CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN)
async def update_listener(hass: HomeAssistant, entry: ConfigEntry):
"""Handle options update."""
hass.data[DOMAIN][entry.entry_id] = entry

# call update handler
agent: LocalLLMAgent = entry.runtime_data
await hass.async_add_executor_job(agent._update_options)
@@ -244,7 +247,7 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):

if set(self.in_context_examples[0].keys()) != set(["type", "request", "tool", "response" ]):
raise Exception("ICL csv file did not have 2 columns: service & response")

if len(self.in_context_examples) == 0:
_LOGGER.warning(f"There were no in context learning examples found in the file '{filename}'!")
self.in_context_examples = None
@@ -259,7 +262,7 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
self._attr_supported_features = (
conversation.ConversationEntityFeature.CONTROL
)

if self.entry.options.get(CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES, DEFAULT_USE_IN_CONTEXT_LEARNING_EXAMPLES):
self._load_icl_examples(self.entry.options.get(CONF_IN_CONTEXT_EXAMPLES_FILE, DEFAULT_IN_CONTEXT_EXAMPLES_FILE))
else:
@@ -276,7 +279,7 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
def supported_languages(self) -> list[str] | Literal["*"]:
"""Return a list of supported languages."""
return MATCH_ALL

def _load_model(self, entry: ConfigEntry) -> None:
"""Load the model on the backend. Implemented by sub-classes"""
raise NotImplementedError()
@@ -286,7 +289,7 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
return await self.hass.async_add_executor_job(
self._load_model, entry
)

def _generate(self, conversation: dict) -> str:
"""Call the backend to generate a response from the conversation. Implemented by sub-classes"""
raise NotImplementedError()
@@ -296,7 +299,7 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
return await self.hass.async_add_executor_job(
self._generate, conversation
)

def _warn_context_size(self):
num_entities = len(self._async_get_exposed_entities()[0])
context_size = self.entry.options.get(CONF_CONTEXT_LENGTH, DEFAULT_CONTEXT_LENGTH)
@@ -334,17 +337,17 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
service_call_pattern = re.compile(service_call_regex, flags=re.MULTILINE)
except Exception as err:
_LOGGER.exception("There was a problem compiling the service call regex")

intent_response = intent.IntentResponse(language=user_input.language)
intent_response.async_set_error(
intent.IntentResponseErrorCode.UNKNOWN,
f"Sorry, there was a problem compiling the service call regex: {err}",
)

return ConversationResult(
response=intent_response, conversation_id=user_input.conversation_id
)

llm_api: llm.APIInstance | None = None
if self.entry.options.get(CONF_LLM_HASS_API):
try:
@@ -372,7 +375,7 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
)

message_history = [ _convert_content(content) for content in chat_log.content ]

# re-generate prompt if necessary
if len(message_history) == 0 or refresh_system_prompt:
try:
@@ -387,9 +390,9 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
return ConversationResult(
response=intent_response, conversation_id=user_input.conversation_id
)

system_prompt = { "role": "system", "message": message }

if len(message_history) == 0:
message_history.append(system_prompt)
else:
@@ -403,7 +406,7 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):

except Exception as err:
_LOGGER.exception("There was a problem talking to the backend")

intent_response = intent.IntentResponse(language=user_input.language)
intent_response.async_set_error(
intent.IntentResponseErrorCode.FAILED_TO_HANDLE,
@@ -412,13 +415,13 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
return ConversationResult(
response=intent_response, conversation_id=user_input.conversation_id
)

# remove end of text token if it was returned
response = response.replace(template_desc["assistant"]["suffix"], "")

# remove think blocks
response = re.sub(rf"^.*?{template_desc["chain_of_thought"]["suffix"]}", "", response, flags=re.DOTALL)

message_history.append({"role": "assistant", "message": response})
if remember_conversation:
if remember_num_interactions and len(message_history) > (remember_num_interactions * 2) + 1:
@@ -460,7 +463,7 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
vol.Required("name"): str,
vol.Required("arguments"): dict,
})

try:
schema_to_validate(parsed_tool_call)
except vol.Error as ex:
@@ -487,7 +490,7 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
# convert string "tuple" to a list for RGB colors
if "rgb_color" in args_dict and isinstance(args_dict["rgb_color"], str):
args_dict["rgb_color"] = [ int(x) for x in args_dict["rgb_color"][1:-1].split(",") ]

if llm_api.api.id == HOME_LLM_API_ID:
to_say = to_say + parsed_tool_call.pop("to_say", "")
tool_input = llm.ToolInput(
@@ -534,7 +537,7 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):

except Exception as err:
_LOGGER.exception("There was a problem talking to the backend")

intent_response = intent.IntentResponse(language=user_input.language)
intent_response.async_set_error(
intent.IntentResponseErrorCode.FAILED_TO_HANDLE,
@@ -546,7 +549,7 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):

message_history.append({"role": "assistant", "message": response})
message_history.append({"role": "assistant", "message": to_say})

# generate intent response to Home Assistant
intent_response = intent.IntentResponse(language=user_input.language)
intent_response.async_set_speech(to_say.strip())
@@ -577,7 +580,7 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
if entity:
if entity.aliases:
attributes["aliases"] = entity.aliases

if entity.unit_of_measurement:
attributes["state"] = attributes["state"] + " " + entity.unit_of_measurement

@@ -587,13 +590,13 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
area_id = device.area_id
if entity and entity.area_id:
area_id = entity.area_id

if area_id:
area = area_registry.async_get_area(entity.area_id)
if area:
attributes["area_id"] = area.id
attributes["area_name"] = area.name

entity_states[state.entity_id] = attributes
domains.add(state.domain)
@@ -627,7 +630,7 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):

_LOGGER.debug(formatted_prompt)
return formatted_prompt

def _format_tool(self, name: str, parameters: vol.Schema, description: str):
style = self.entry.options.get(CONF_TOOL_FORMAT, DEFAULT_TOOL_FORMAT)

@@ -636,10 +639,10 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
if description:
result = result + f" - {description}"
return result

raw_parameters: list = voluptuous_serialize.convert(
parameters, custom_serializer=custom_custom_serializer)

# handle vol.Any in the key side of things
processed_parameters = []
for param in raw_parameters:
@@ -685,9 +688,9 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
}
}
}

raise Exception(f"Unknown tool format {style}")

def _generate_icl_examples(self, num_examples, entity_names):
entity_names = entity_names[:]
entity_domains = set([x.split(".")[0] for x in entity_names])
@@ -699,14 +702,14 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
x for x in self.in_context_examples
if x["type"] in entity_domains
]

random.shuffle(in_context_examples)
random.shuffle(entity_names)

num_examples_to_generate = min(num_examples, len(in_context_examples))
if num_examples_to_generate < num_examples:
_LOGGER.warning(f"Attempted to generate {num_examples} ICL examples for conversation, but only {len(in_context_examples)} are available!")

examples = []
for _ in range(num_examples_to_generate):
chosen_example = in_context_examples.pop()
@@ -748,7 +751,7 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
"arguments": tool_arguments
}
})

return examples

def _generate_system_prompt(self, prompt_template: str, llm_api: llm.APIInstance | None) -> str:
@@ -796,7 +799,7 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
state = attributes["state"]
exposed_attributes = expose_attributes(attributes)
str_attributes = ";".join([state] + exposed_attributes)

formatted_devices = formatted_devices + f"{name} '{attributes.get('friendly_name')}' = {str_attributes}\n"
devices.append({
"entity_id": name,
@@ -828,7 +831,7 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
for domain in domains:
if domain not in SERVICE_TOOL_ALLOWED_DOMAINS:
continue

# scripts show up as individual services
if domain == "script" and not scripts_added:
all_services.extend([
@@ -839,7 +842,7 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
])
scripts_added = True
continue

for name, service in service_dict.get(domain, {}).items():
if name not in SERVICE_TOOL_ALLOWED_SERVICES:
continue
@@ -856,13 +859,13 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
self._format_tool(*tool)
for tool in all_services
]

else:
tools = [
self._format_tool(tool.name, tool.parameters, tool.description)
for tool in llm_api.tools
]

if self.entry.options.get(CONF_TOOL_FORMAT, DEFAULT_TOOL_FORMAT) == TOOL_FORMAT_MINIMAL:
formatted_tools = ", ".join(tools)
else:
@@ -883,7 +886,7 @@ class LocalLLMAgent(ConversationEntity, AbstractConversationAgent):
if self.in_context_examples and llm_api:
num_examples = int(self.entry.options.get(CONF_NUM_IN_CONTEXT_EXAMPLES, DEFAULT_NUM_IN_CONTEXT_EXAMPLES))
render_variables["response_examples"] = self._generate_icl_examples(num_examples, list(entities_to_expose.keys()))

return template.Template(prompt_template, self.hass).async_render(
render_variables,
parse_result=False,
@@ -910,7 +913,7 @@ class LlamaCppAgent(LocalLLMAgent):

if not self.model_path:
raise Exception(f"Model was not found at '{self.model_path}'!")

validate_llama_cpp_python_installation()

# don't import it until now because the wheel is installed by config_flow.py
@@ -921,10 +924,10 @@ class LlamaCppAgent(LocalLLMAgent):
install_result = install_llama_cpp_python(self.hass.config.config_dir)
if not install_result == True:
raise ConfigEntryError("llama-cpp-python was not installed on startup and re-installing it led to an error!")

validate_llama_cpp_python_installation()
self.llama_cpp_module = importlib.import_module("llama_cpp")

Llama = getattr(self.llama_cpp_module, "Llama")

_LOGGER.debug(f"Loading model '{self.model_path}'...")
@@ -948,14 +951,14 @@ class LlamaCppAgent(LocalLLMAgent):
self.grammar = None
if entry.options.get(CONF_USE_GBNF_GRAMMAR, DEFAULT_USE_GBNF_GRAMMAR):
self._load_grammar(entry.options.get(CONF_GBNF_GRAMMAR_FILE, DEFAULT_GBNF_GRAMMAR_FILE))

# TODO: check about disk caching
# self.llm.set_cache(self.llama_cpp_module.LlamaDiskCache(
# capacity_bytes=(512 * 10e8),
# cache_dir=os.path.join(self.hass.config.media_dirs.get("local", self.hass.config.path("media")), "kv_cache")
# ))

self.remove_prompt_caching_listener = None
self.last_cache_prime = None
self.last_updated_entities = {}
@@ -1025,7 +1028,7 @@ class LlamaCppAgent(LocalLLMAgent):
if self.loaded_model_settings[CONF_PROMPT_CACHING_ENABLED] != self.entry.options.get(CONF_PROMPT_CACHING_ENABLED, DEFAULT_PROMPT_CACHING_ENABLED) or \
model_reloaded:
self.loaded_model_settings[CONF_PROMPT_CACHING_ENABLED] = self.entry.options.get(CONF_PROMPT_CACHING_ENABLED, DEFAULT_PROMPT_CACHING_ENABLED)

async def cache_current_prompt(_now):
await self._async_cache_prompt(None, None, None)
async_call_later(self.hass, 1.0, cache_current_prompt)
@@ -1039,7 +1042,7 @@ class LlamaCppAgent(LocalLLMAgent):
# ignore sorting if prompt caching is disabled
if not self.entry.options.get(CONF_PROMPT_CACHING_ENABLED, DEFAULT_PROMPT_CACHING_ENABLED):
return entities, domains

entity_order = { name: None for name in entities.keys() }
entity_order.update(self.last_updated_entities)

@@ -1050,7 +1053,7 @@ class LlamaCppAgent(LocalLLMAgent):
return (False, '', item_name)
else:
return (True, last_updated, '')

# Sort the items based on the sort_key function
sorted_items = sorted(list(entity_order.items()), key=sort_key)

@@ -1066,9 +1069,9 @@ class LlamaCppAgent(LocalLLMAgent):
if enabled and not self.remove_prompt_caching_listener:
_LOGGER.info("enabling prompt caching...")

entity_ids = [
state.entity_id for state in self.hass.states.async_all() \
if async_should_expose(self.hass, CONVERSATION_DOMAIN, state.entity_id)
]

_LOGGER.debug(f"watching entities: {entity_ids}")
@@ -1107,27 +1110,27 @@ class LlamaCppAgent(LocalLLMAgent):
# if a refresh is already scheduled then exit
if self.cache_refresh_after_cooldown:
return

# if we are inside the cooldown period, request a refresh and exit
current_time = time.time()
fastest_prime_interval = self.entry.options.get(CONF_PROMPT_CACHING_INTERVAL, DEFAULT_PROMPT_CACHING_INTERVAL)
if self.last_cache_prime and current_time - self.last_cache_prime < fastest_prime_interval:
self.cache_refresh_after_cooldown = True
return

# try to acquire the lock, if we are still running for some reason, request a refresh and exit
lock_acquired = self.model_lock.acquire(False)
if not lock_acquired:
self.cache_refresh_after_cooldown = True
return

try:
raw_prompt = self.entry.options.get(CONF_PROMPT, DEFAULT_PROMPT)
prompt = self._format_prompt([
{ "role": "system", "message": self._generate_system_prompt(raw_prompt, llm_api)},
{ "role": "user", "message": "" }
], include_generation_prompt=False)

input_tokens = self.llm.tokenize(
prompt.encode(), add_bos=False
)
@@ -1154,10 +1157,10 @@ class LlamaCppAgent(LocalLLMAgent):
finally:
self.model_lock.release()

# schedule a refresh using async_call_later
# if the flag is set after the delay then we do another refresh

@callback
async def refresh_if_requested(_now):
if self.cache_refresh_after_cooldown:
@@ -1172,8 +1175,8 @@ class LlamaCppAgent(LocalLLMAgent):

refresh_delay = self.entry.options.get(CONF_PROMPT_CACHING_INTERVAL, DEFAULT_PROMPT_CACHING_INTERVAL)
async_call_later(self.hass, float(refresh_delay), refresh_if_requested)

def _generate(self, conversation: dict) -> str:
prompt = self._format_prompt(conversation)

@@ -1224,8 +1227,8 @@ class LlamaCppAgent(LocalLLMAgent):
result = self.llm.detokenize(result_tokens).decode()

return result

class GenericOpenAIAPIAgent(LocalLLMAgent):
class BaseOpenAICompatibleAPIAgent(LocalLLMAgent):
api_host: str
api_key: str
model_name: str
@@ -1237,46 +1240,17 @@ class GenericOpenAIAPIAgent(LocalLLMAgent):
ssl=entry.data[CONF_SSL],
path=""
)

self.api_key = entry.data.get(CONF_OPENAI_API_KEY)
self.model_name = entry.data.get(CONF_CHAT_MODEL)

async def _async_generate_with_parameters(self, conversation: dict, endpoint: str, additional_params: dict) -> str:
"""Generate a response using the OpenAI-compatible API"""

def _chat_completion_params(self, conversation: dict) -> (str, dict):
request_params = {}
api_base_path = self.entry.data.get(CONF_GENERIC_OPENAI_PATH, DEFAULT_GENERIC_OPENAI_PATH)

endpoint = f"/{api_base_path}/chat/completions"
request_params["messages"] = [ { "role": x["role"], "content": x["message"] } for x in conversation ]

return endpoint, request_params

def _completion_params(self, conversation: dict) -> (str, dict):
request_params = {}
api_base_path = self.entry.data.get(CONF_GENERIC_OPENAI_PATH, DEFAULT_GENERIC_OPENAI_PATH)

endpoint = f"/{api_base_path}/completions"
request_params["prompt"] = self._format_prompt(conversation)

return endpoint, request_params

def _extract_response(self, response_json: dict) -> str:
choices = response_json["choices"]
if choices[0]["finish_reason"] != "stop":
_LOGGER.warning("Model response did not end on a stop token (unfinished sentence)")

if response_json["object"] in ["chat.completion", "chat.completion.chunk"]:
return choices[0]["message"]["content"]
else:
return choices[0]["text"]

async def _async_generate(self, conversation: dict) -> str:
max_tokens = self.entry.options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS)
temperature = self.entry.options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE)
top_p = self.entry.options.get(CONF_TOP_P, DEFAULT_TOP_P)
timeout = self.entry.options.get(CONF_REQUEST_TIMEOUT, DEFAULT_REQUEST_TIMEOUT)
use_chat_api = self.entry.options.get(CONF_REMOTE_USE_CHAT_ENDPOINT, DEFAULT_REMOTE_USE_CHAT_ENDPOINT)

request_params = {
"model": self.model_name,
@@ -1284,12 +1258,7 @@ class GenericOpenAIAPIAgent(LocalLLMAgent):
"temperature": temperature,
"top_p": top_p,
}

if use_chat_api:
endpoint, additional_params = self._chat_completion_params(conversation)
else:
endpoint, additional_params = self._completion_params(conversation)

request_params.update(additional_params)

headers = {}
@@ -1318,7 +1287,161 @@ class GenericOpenAIAPIAgent(LocalLLMAgent):
_LOGGER.debug(result)

return self._extract_response(result)

def _extract_response(self, response_json: dict) -> str:
raise NotImplementedError("Subclasses must implement _extract_response()")

class GenericOpenAIAPIAgent(BaseOpenAICompatibleAPIAgent):
"""Implements the OpenAPI-compatible text completion and chat completion API backends."""

def _chat_completion_params(self, conversation: dict) -> (str, dict):
request_params = {}
api_base_path = self.entry.data.get(CONF_GENERIC_OPENAI_PATH, DEFAULT_GENERIC_OPENAI_PATH)

endpoint = f"/{api_base_path}/chat/completions"
request_params["messages"] = [ { "role": x["role"], "content": x["message"] } for x in conversation ]

return endpoint, request_params

def _completion_params(self, conversation: dict) -> (str, dict):
request_params = {}
api_base_path = self.entry.data.get(CONF_GENERIC_OPENAI_PATH, DEFAULT_GENERIC_OPENAI_PATH)

endpoint = f"/{api_base_path}/completions"
request_params["prompt"] = self._format_prompt(conversation)

return endpoint, request_params

def _extract_response(self, response_json: dict) -> str:
choices = response_json["choices"]
if choices[0]["finish_reason"] != "stop":
_LOGGER.warning("Model response did not end on a stop token (unfinished sentence)")

if response_json["object"] in ["chat.completion", "chat.completion.chunk"]:
return choices[0]["message"]["content"]
else:
return choices[0]["text"]

async def _async_generate(self, conversation: dict) -> str:
use_chat_api = self.entry.options.get(CONF_REMOTE_USE_CHAT_ENDPOINT, DEFAULT_REMOTE_USE_CHAT_ENDPOINT)

if use_chat_api:
endpoint, additional_params = self._chat_completion_params(conversation)
else:
endpoint, additional_params = self._completion_params(conversation)

result = await self._async_generate_with_parameters(conversation, endpoint, additional_params)

return result

class GenericOpenAIResponsesAPIAgent(BaseOpenAICompatibleAPIAgent):
"""Implements the OpenAPI-compatible Responses API backend."""

_last_response_id: str | None = None
_last_response_id_time: datetime.datetime = None

def _responses_params(self, conversation: dict) -> (str, dict):
request_params = {}
api_base_path = self.entry.data.get(CONF_GENERIC_OPENAI_PATH, DEFAULT_GENERIC_OPENAI_PATH)

endpoint = f"/{api_base_path}/responses"
request_params["input"] = conversation[-1]["message"] # last message in the conversation is the user input

# Assign previous_response_id if relevant
if self._last_response_id and self.entry.options.get(CONF_REMEMBER_CONVERSATION, DEFAULT_REMEMBER_CONVERSATION):
# If the last response was generated recently, use it as a context
configured_memory_time: datetime.timedelta = datetime.timedelta(minutes=self.entry.options.get(CONF_REMEMBER_CONVERSATION_TIME_MINUTES, DEFAULT_REMEMBER_CONVERSATION_TIME_MINUTES))
last_conversation_age: datetime.timedelta = datetime.datetime.now() - self._last_response_id_time
_LOGGER.debug(f"Conversation ID age: {last_conversation_age}")
if last_conversation_age < configured_memory_time:
_LOGGER.debug(f"Using previous response ID {self._last_response_id} for context")
request_params["previous_response_id"] = self._last_response_id
else:
_LOGGER.debug(f"Previous response ID {self._last_response_id} is too old, not using it for context")

return endpoint, request_params

def _validate_response_payload(self, response_json: dict) -> bool:
"""
Validate that the payload given matches the expected structure for the Responses API.

API ref: https://platform.openai.com/docs/api-reference/responses/object

Returns True or raises an error
"""
required_response_keys = ["object", "output", "status", "id"]
missing_keys = [key for key in required_response_keys if key not in response_json]
if missing_keys:
raise ValueError(f"Response JSON is missing required keys: {', '.join(missing_keys)}")

if response_json["object"] != "response":
raise ValueError(f"Response JSON object is not 'response', got {response_json['object']}")

if "error" in response_json and response_json["error"] is not None:
error = response_json["error"]
_LOGGER.error(f"Response received error payload.")
if "message" not in error:
raise ValueError("Response JSON error is missing 'message' key")
raise ValueError(f"Response JSON error: {error['message']}")

return True

def _check_response_status(self, response_json: dict) -> None:
"""
Check the status of the response and logs a message if it is not 'completed'.

API ref: https://platform.openai.com/docs/api-reference/responses/object#responses_object-status
"""
if response_json["status"] != "completed":
_LOGGER.warning(f"Response status is not 'completed', got {response_json['status']}. Details: {response_json.get('incomplete_details', 'No details provided')}")

def _extract_response(self, response_json: dict) -> str:
self._validate_response_payload(response_json)
self._check_response_status(response_json)

outputs = response_json["output"]

if len(outputs) > 1:
_LOGGER.warning("Received multiple outputs from the Responses API, returning the first one.")

output = outputs[0]

if not output["type"] == "message":
raise NotImplementedError(f"Response output type is not 'message', got {output['type']}")

if len(output["content"]) > 1:
_LOGGER.warning("Received multiple content items in the response output, returning the first one.")

content = output["content"][0]

output_type = content["type"]

to_return: str | None = None

if output_type == "refusal":
_LOGGER.info("Received a refusal from the Responses API.")
to_return = content["refusal"]
elif output_type == "output_text":
to_return = content["text"]
else:
raise ValueError(f"Response output content type is not expected, got {output_type}")

# Save the response_id and return the successful response.
response_id = response_json["id"]
self._last_response_id = response_id
self._last_response_id_time = datetime.datetime.now()

return to_return

async def _async_generate(self, conversation: dict) -> str:
"""Generate a response using the OpenAI-compatible Responses API"""

endpoint, additional_params = self._responses_params(conversation)

result = await self._async_generate_with_parameters(conversation, endpoint, additional_params)

return result

class TextGenerationWebuiAgent(GenericOpenAIAPIAgent):
admin_key: str
@@ -1332,21 +1455,21 @@ class TextGenerationWebuiAgent(GenericOpenAIAPIAgent):

if self.admin_key:
headers["Authorization"] = f"Bearer {self.admin_key}"

async with session.get(
f"{self.api_host}/v1/internal/model/info",
headers=headers
) as response:
response.raise_for_status()
currently_loaded_result = await response.json()

loaded_model = currently_loaded_result["model_name"]
if loaded_model == self.model_name:
_LOGGER.info(f"Model {self.model_name} is already loaded on the remote backend.")
return
else:
_LOGGER.info(f"Model is not {self.model_name} loaded on the remote backend. Loading it now...")

async with session.post(
f"{self.api_host}/v1/internal/model/load",
json={
@@ -1381,7 +1504,7 @@ class TextGenerationWebuiAgent(GenericOpenAIAPIAgent):
request_params["typical_p"] = self.entry.options.get(CONF_TYPICAL_P, DEFAULT_TYPICAL_P)

return endpoint, request_params

def _completion_params(self, conversation: dict) -> (str, dict):
preset = self.entry.options.get(CONF_TEXT_GEN_WEBUI_PRESET)

@@ -1396,7 +1519,7 @@ class TextGenerationWebuiAgent(GenericOpenAIAPIAgent):
request_params["typical_p"] = self.entry.options.get(CONF_TYPICAL_P, DEFAULT_TYPICAL_P)

return endpoint, request_params

def _extract_response(self, response_json: dict) -> str:
choices = response_json["choices"]
if choices[0]["finish_reason"] != "stop":
@@ -1412,7 +1535,7 @@ class TextGenerationWebuiAgent(GenericOpenAIAPIAgent):
return choices[0]["message"]["content"]
else:
return choices[0]["text"]

class LlamaCppPythonAPIAgent(GenericOpenAIAPIAgent):
"""https://llama-cpp-python.readthedocs.io/en/latest/server/"""
grammar: str
@@ -1438,13 +1561,13 @@ class LlamaCppPythonAPIAgent(GenericOpenAIAPIAgent):
request_params["grammar"] = self.grammar

return endpoint, request_params

def _completion_params(self, conversation: dict) -> (str, dict):
top_k = int(self.entry.options.get(CONF_TOP_K, DEFAULT_TOP_K))
endpoint, request_params = super()._completion_params(conversation)

request_params["top_k"] = top_k

if self.entry.options.get(CONF_USE_GBNF_GRAMMAR, DEFAULT_USE_GBNF_GRAMMAR):
request_params["grammar"] = self.grammar

@@ -1471,14 +1594,14 @@ class OllamaAPIAgent(LocalLLMAgent):
session = async_get_clientsession(self.hass)
if self.api_key:
headers["Authorization"] = f"Bearer {self.api_key}"

async with session.get(
f"{self.api_host}/api/tags",
headers=headers,
) as response:
response.raise_for_status()
currently_downloaded_result = await response.json()

except Exception as ex:
_LOGGER.debug("Connection error was: %s", repr(ex))
raise ConfigEntryNotReady("There was a problem connecting to the remote server") from ex
@@ -1486,7 +1609,7 @@ class OllamaAPIAgent(LocalLLMAgent):
model_names = [ x["name"] for x in currently_downloaded_result["models"] ]
if ":" in self.model_name:
if not any([ name == self.model_name for name in model_names]):
raise ConfigEntryNotReady(f"Ollama server does not have the provided model: {self.model_name}")
elif not any([ name.split(":")[0] == self.model_name for name in model_names ]):
raise ConfigEntryNotReady(f"Ollama server does not have the provided model: {self.model_name}")

@@ -1506,11 +1629,11 @@ class OllamaAPIAgent(LocalLLMAgent):
request_params["raw"] = True # ignore prompt template

return endpoint, request_params

def _extract_response(self, response_json: dict) -> str:
if response_json["done"] not in ["true", True]:
_LOGGER.warning("Model response did not end on a stop token (unfinished sentence)")

# TODO: this doesn't work because ollama caches prompts and doesn't always return the full prompt length
# context_len = self.entry.options.get(CONF_CONTEXT_LENGTH, DEFAULT_CONTEXT_LENGTH)
# max_tokens = self.entry.options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS)
@@ -1521,7 +1644,7 @@ class OllamaAPIAgent(LocalLLMAgent):
return response_json["response"]
else:
return response_json["message"]["content"]

async def _async_generate(self, conversation: dict) -> str:
context_length = self.entry.options.get(CONF_CONTEXT_LENGTH, DEFAULT_CONTEXT_LENGTH)
max_tokens = self.entry.options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS)
@@ -1533,7 +1656,7 @@ class OllamaAPIAgent(LocalLLMAgent):
keep_alive = self.entry.options.get(CONF_OLLAMA_KEEP_ALIVE_MIN, DEFAULT_OLLAMA_KEEP_ALIVE_MIN)
use_chat_api = self.entry.options.get(CONF_REMOTE_USE_CHAT_ENDPOINT, DEFAULT_REMOTE_USE_CHAT_ENDPOINT)
json_mode = self.entry.options.get(CONF_OLLAMA_JSON_MODE, DEFAULT_OLLAMA_JSON_MODE)

request_params = {
"model": self.model_name,
"stream": False,
@@ -1550,18 +1673,18 @@ class OllamaAPIAgent(LocalLLMAgent):

if json_mode:
request_params["format"] = "json"

if use_chat_api:
endpoint, additional_params = self._chat_completion_params(conversation)
else:
endpoint, additional_params = self._completion_params(conversation)

request_params.update(additional_params)

headers = {}
if self.api_key:
headers["Authorization"] = f"Bearer {self.api_key}"

session = async_get_clientsession(self.hass)
response = None
try:
@@ -1580,7 +1703,7 @@ class OllamaAPIAgent(LocalLLMAgent):
_LOGGER.debug(f"Request was: {request_params}")
_LOGGER.debug(f"Result was: {response}")
return f"Failed to communicate with the API! {err}"

_LOGGER.debug(result)

return self._extract_response(result)
@@ -194,6 +194,7 @@
"llama_cpp_existing": "Llama.cpp (existing model)",
"text-generation-webui_api": "text-generation-webui API",
"generic_openai": "Generic OpenAI Compatible API",
"generic_openai_responses": "Generic OpenAPI Compatible Responses API",
"llama_cpp_python_server": "llama-cpp-python Server",
"ollama": "Ollama API"

@@ -194,6 +194,7 @@
"llama_cpp_existing": "Llama.cpp (istniejący model)",
"text-generation-webui_api": "text-generation-webui API",
"generic_openai": "Ogólne API kompatybilne z OpenAI",
"generic_openai_responses": "Ogólne API odpowiedzi kompatybilne z OpenAPI",
"llama_cpp_python_server": "llama-cpp-python Server",
"ollama": "Ollama API"
}