"""Defines the various LLM Backend Agents"""

from __future__ import annotations

import logging
import threading
import importlib
from typing import Literal, Any, Callable, TYPE_CHECKING

import voluptuous as vol

import requests
import re
import os
import json
import csv
import random
import time

from homeassistant.components.conversation import ConversationInput, ConversationResult, AbstractConversationAgent
import homeassistant.components.conversation as ha_conversation
from homeassistant.components.conversation.const import DOMAIN as CONVERSATION_DOMAIN
from homeassistant.components.homeassistant.exposed_entities import async_should_expose
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import ATTR_ENTITY_ID, CONF_HOST, CONF_PORT, CONF_SSL, MATCH_ALL, CONF_LLM_HASS_API
from homeassistant.core import HomeAssistant, callback
from homeassistant.exceptions import ConfigEntryNotReady, ConfigEntryError, TemplateError, HomeAssistantError
from homeassistant.helpers import config_validation as cv, intent, template, entity_registry as er, llm, area_registry as ar
from homeassistant.helpers.event import async_track_state_change, async_call_later
from homeassistant.util import ulid, color

import voluptuous_serialize

from .utils import closest_color, flatten_vol_schema, custom_custom_serializer, install_llama_cpp_python, \
    validate_llama_cpp_python_installation, format_url
from .const import (
    CONF_CHAT_MODEL,
    CONF_MAX_TOKENS,
    CONF_PROMPT,
    CONF_TEMPERATURE,
    CONF_TOP_K,
    CONF_TOP_P,
    CONF_TYPICAL_P,
    CONF_MIN_P,
    CONF_REQUEST_TIMEOUT,
    CONF_BACKEND_TYPE,
    CONF_DOWNLOADED_MODEL_FILE,
    CONF_EXTRA_ATTRIBUTES_TO_EXPOSE,
    CONF_PROMPT_TEMPLATE,
    CONF_TOOL_FORMAT,
    CONF_TOOL_MULTI_TURN_CHAT,
    CONF_ENABLE_FLASH_ATTENTION,
    CONF_USE_GBNF_GRAMMAR,
    CONF_GBNF_GRAMMAR_FILE,
    CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES,
    CONF_IN_CONTEXT_EXAMPLES_FILE,
    CONF_NUM_IN_CONTEXT_EXAMPLES,
    CONF_TEXT_GEN_WEBUI_PRESET,
    CONF_OPENAI_API_KEY,
    CONF_TEXT_GEN_WEBUI_ADMIN_KEY,
    CONF_REFRESH_SYSTEM_PROMPT,
    CONF_REMEMBER_CONVERSATION,
    CONF_REMEMBER_NUM_INTERACTIONS,
    CONF_PROMPT_CACHING_ENABLED,
    CONF_PROMPT_CACHING_INTERVAL,
    CONF_SERVICE_CALL_REGEX,
    CONF_REMOTE_USE_CHAT_ENDPOINT,
    CONF_TEXT_GEN_WEBUI_CHAT_MODE,
    CONF_OLLAMA_KEEP_ALIVE_MIN,
    CONF_OLLAMA_JSON_MODE,
    CONF_CONTEXT_LENGTH,
    CONF_BATCH_SIZE,
    CONF_THREAD_COUNT,
    CONF_BATCH_THREAD_COUNT,
    DEFAULT_MAX_TOKENS,
    DEFAULT_PROMPT,
    DEFAULT_TEMPERATURE,
    DEFAULT_TOP_K,
    DEFAULT_TOP_P,
    DEFAULT_MIN_P,
    DEFAULT_TYPICAL_P,
    DEFAULT_BACKEND_TYPE,
    DEFAULT_REQUEST_TIMEOUT,
    DEFAULT_EXTRA_ATTRIBUTES_TO_EXPOSE,
    DEFAULT_PROMPT_TEMPLATE,
    DEFAULT_TOOL_FORMAT,
    DEFAULT_TOOL_MULTI_TURN_CHAT,
    DEFAULT_ENABLE_FLASH_ATTENTION,
    DEFAULT_USE_GBNF_GRAMMAR,
    DEFAULT_GBNF_GRAMMAR_FILE,
    DEFAULT_USE_IN_CONTEXT_LEARNING_EXAMPLES,
    DEFAULT_IN_CONTEXT_EXAMPLES_FILE,
    DEFAULT_NUM_IN_CONTEXT_EXAMPLES,
    DEFAULT_REFRESH_SYSTEM_PROMPT,
    DEFAULT_REMEMBER_CONVERSATION,
    DEFAULT_REMEMBER_NUM_INTERACTIONS,
    DEFAULT_PROMPT_CACHING_ENABLED,
    DEFAULT_PROMPT_CACHING_INTERVAL,
    DEFAULT_SERVICE_CALL_REGEX,
    DEFAULT_REMOTE_USE_CHAT_ENDPOINT,
    DEFAULT_TEXT_GEN_WEBUI_CHAT_MODE,
    DEFAULT_OLLAMA_KEEP_ALIVE_MIN,
    DEFAULT_OLLAMA_JSON_MODE,
    DEFAULT_CONTEXT_LENGTH,
    DEFAULT_BATCH_SIZE,
    DEFAULT_THREAD_COUNT,
    DEFAULT_BATCH_THREAD_COUNT,
    TEXT_GEN_WEBUI_CHAT_MODE_CHAT,
    TEXT_GEN_WEBUI_CHAT_MODE_INSTRUCT,
    TEXT_GEN_WEBUI_CHAT_MODE_CHAT_INSTRUCT,
    DOMAIN,
    HOME_LLM_API_ID,
    SERVICE_TOOL_NAME,
    PROMPT_TEMPLATE_DESCRIPTIONS,
    TOOL_FORMAT_FULL,
    TOOL_FORMAT_REDUCED,
    TOOL_FORMAT_MINIMAL,
    ALLOWED_SERVICE_CALL_ARGUMENTS,
)

# make type checking work for llama-cpp-python without importing it directly at runtime
if TYPE_CHECKING:
    from llama_cpp import Llama as LlamaType
else:
    LlamaType = Any

_LOGGER = logging.getLogger(__name__)

CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN)
|
|
|
|
class LocalLLMAgent(AbstractConversationAgent):
|
|
"""Base Local LLM conversation agent."""
|
|
|
|
hass: HomeAssistant
|
|
entry_id: str
|
|
history: dict[str, list[dict]]
|
|
in_context_examples: list[dict]
|
|
|
|
def __init__(self, hass: HomeAssistant, entry: ConfigEntry) -> None:
|
|
"""Initialize the agent."""
|
|
self.hass = hass
|
|
self.entry_id = entry.entry_id
|
|
self.history = {}
|
|
|
|
self.backend_type = entry.data.get(
|
|
CONF_BACKEND_TYPE, DEFAULT_BACKEND_TYPE
|
|
)
|
|
|
|
self.in_context_examples = None
|
|
if entry.options.get(CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES, DEFAULT_USE_IN_CONTEXT_LEARNING_EXAMPLES):
|
|
self._load_icl_examples(entry.options.get(CONF_IN_CONTEXT_EXAMPLES_FILE, DEFAULT_IN_CONTEXT_EXAMPLES_FILE))
|
|
|
|
self._load_model(entry)
|
|
|
|
def _load_icl_examples(self, filename: str):
|
|
"""Load info used for generating in context learning examples"""
|
|
try:
|
|
icl_filename = os.path.join(os.path.dirname(__file__), filename)
|
|
|
|
with open(icl_filename, encoding="utf-8-sig") as f:
|
|
self.in_context_examples = list(csv.DictReader(f))
|
|
|
|
            if set(self.in_context_examples[0].keys()) != {"type", "request", "tool", "response"}:
                raise Exception("ICL csv file did not have the expected columns: type, request, tool & response")
|
|
|
|
if len(self.in_context_examples) == 0:
|
|
_LOGGER.warning(f"There were no in context learning examples found in the file '{filename}'!")
|
|
self.in_context_examples = None
|
|
else:
|
|
_LOGGER.debug(f"Loaded {len(self.in_context_examples)} examples for ICL")
|
|
except Exception:
|
|
_LOGGER.exception("Failed to load in context learning examples!")
|
|
self.in_context_examples = None
|
|
|
|
def _update_options(self):
|
|
if self.entry.options.get(CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES, DEFAULT_USE_IN_CONTEXT_LEARNING_EXAMPLES):
|
|
self._load_icl_examples(self.entry.options.get(CONF_IN_CONTEXT_EXAMPLES_FILE, DEFAULT_IN_CONTEXT_EXAMPLES_FILE))
|
|
else:
|
|
self.in_context_examples = None
|
|
|
|
@property
|
|
def entry(self) -> ConfigEntry:
|
|
try:
|
|
return self.hass.data[DOMAIN][self.entry_id]
|
|
except KeyError as ex:
|
|
raise Exception("Attempted to use self.entry during startup.") from ex
|
|
|
|
@property
|
|
def supported_languages(self) -> list[str] | Literal["*"]:
|
|
"""Return a list of supported languages."""
|
|
return MATCH_ALL
|
|
|
|
def _load_model(self, entry: ConfigEntry) -> None:
|
|
"""Load the model on the backend. Implemented by sub-classes"""
|
|
raise NotImplementedError()
|
|
|
|
    def _generate(self, conversation: list[dict]) -> str:
|
|
"""Call the backend to generate a response from the conversation. Implemented by sub-classes"""
|
|
raise NotImplementedError()
|
|
|
|
    async def _async_generate(self, conversation: list[dict]) -> str:
|
|
"""Async wrapper for _generate()"""
|
|
return await self.hass.async_add_executor_job(
|
|
self._generate, conversation
|
|
)
|
|
|
|
def _warn_context_size(self):
|
|
num_entities = len(self._async_get_exposed_entities()[0])
|
|
context_size = self.entry.options.get(CONF_CONTEXT_LENGTH, DEFAULT_CONTEXT_LENGTH)
|
|
_LOGGER.error("There were too many entities exposed when attempting to generate a response for " +
|
|
f"{self.entry.data[CONF_CHAT_MODEL]} and it exceeded the context size for the model. " +
|
|
f"Please reduce the number of entities exposed ({num_entities}) or increase the model's context size ({int(context_size)})")
|
|
|
|
async def async_process(
|
|
self, user_input: ConversationInput
|
|
) -> ConversationResult:
|
|
"""Process a sentence."""
|
|
|
|
raw_prompt = self.entry.options.get(CONF_PROMPT, DEFAULT_PROMPT)
|
|
prompt_template = self.entry.options.get(CONF_PROMPT_TEMPLATE, DEFAULT_PROMPT_TEMPLATE)
|
|
template_desc = PROMPT_TEMPLATE_DESCRIPTIONS[prompt_template]
|
|
refresh_system_prompt = self.entry.options.get(CONF_REFRESH_SYSTEM_PROMPT, DEFAULT_REFRESH_SYSTEM_PROMPT)
|
|
remember_conversation = self.entry.options.get(CONF_REMEMBER_CONVERSATION, DEFAULT_REMEMBER_CONVERSATION)
|
|
remember_num_interactions = self.entry.options.get(CONF_REMEMBER_NUM_INTERACTIONS, DEFAULT_REMEMBER_NUM_INTERACTIONS)
|
|
service_call_regex = self.entry.options.get(CONF_SERVICE_CALL_REGEX, DEFAULT_SERVICE_CALL_REGEX)
|
|
|
|
        try:
            service_call_pattern = re.compile(service_call_regex)
        except Exception as err:
            _LOGGER.exception("There was a problem compiling the service call regex")

            intent_response = intent.IntentResponse(language=user_input.language)
            intent_response.async_set_error(
                intent.IntentResponseErrorCode.UNKNOWN,
                f"Sorry, there was a problem compiling the service call regex: {err}",
            )
            # conversation_id has not been assigned yet at this point, so fall
            # back to the id supplied with the user input
            return ConversationResult(
                response=intent_response, conversation_id=user_input.conversation_id
            )
|
|
|
|
llm_api: llm.APIInstance | None = None
|
|
if self.entry.options.get(CONF_LLM_HASS_API):
|
|
try:
|
|
llm_api = await llm.async_get_api(
|
|
self.hass,
|
|
self.entry.options[CONF_LLM_HASS_API],
|
|
llm_context=llm.LLMContext(
|
|
platform=DOMAIN,
|
|
context=user_input.context,
|
|
user_prompt=user_input.text,
|
|
language=user_input.language,
|
|
assistant=ha_conversation.DOMAIN,
|
|
device_id=user_input.device_id,
|
|
)
|
|
)
|
|
except HomeAssistantError as err:
|
|
_LOGGER.error("Error getting LLM API: %s", err)
|
|
intent_response = intent.IntentResponse(language=user_input.language)
|
|
intent_response.async_set_error(
|
|
intent.IntentResponseErrorCode.UNKNOWN,
|
|
f"Error preparing LLM API: {err}",
|
|
)
|
|
return ConversationResult(
|
|
response=intent_response, conversation_id=user_input.conversation_id
|
|
)
|
|
|
|
if user_input.conversation_id in self.history:
|
|
conversation_id = user_input.conversation_id
|
|
conversation = self.history[conversation_id] if remember_conversation else [self.history[conversation_id][0]]
|
|
else:
|
|
conversation_id = ulid.ulid()
|
|
conversation = []
|
|
|
|
if len(conversation) == 0 or refresh_system_prompt:
|
|
try:
|
|
message = self._generate_system_prompt(raw_prompt, llm_api)
|
|
except TemplateError as err:
|
|
_LOGGER.error("Error rendering prompt: %s", err)
|
|
intent_response = intent.IntentResponse(language=user_input.language)
|
|
intent_response.async_set_error(
|
|
intent.IntentResponseErrorCode.UNKNOWN,
|
|
f"Sorry, I had a problem with my template: {err}",
|
|
)
|
|
return ConversationResult(
|
|
response=intent_response, conversation_id=conversation_id
|
|
)
|
|
|
|
system_prompt = { "role": "system", "message": message }
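            # note: the internal conversation format uses the "message" key (not
            # OpenAI's "content"); remote backends convert it when building requests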
|
|
|
|
if len(conversation) == 0:
|
|
conversation.append(system_prompt)
|
|
if not remember_conversation:
|
|
self.history[conversation_id] = conversation
|
|
else:
|
|
conversation[0] = system_prompt
|
|
|
|
conversation.append({"role": "user", "message": user_input.text})
|
|
|
|
# generate a response
|
|
try:
|
|
_LOGGER.debug(conversation)
|
|
response = await self._async_generate(conversation)
|
|
_LOGGER.debug(response)
|
|
|
|
except Exception as err:
|
|
_LOGGER.exception("There was a problem talking to the backend")
|
|
|
|
intent_response = intent.IntentResponse(language=user_input.language)
|
|
intent_response.async_set_error(
|
|
intent.IntentResponseErrorCode.FAILED_TO_HANDLE,
|
|
f"Sorry, there was a problem talking to the backend: {repr(err)}",
|
|
)
|
|
return ConversationResult(
|
|
response=intent_response, conversation_id=conversation_id
|
|
)
|
|
|
|
conversation.append({"role": "assistant", "message": response})
|
|
if remember_conversation:
|
|
if remember_num_interactions and len(conversation) > (remember_num_interactions * 2) + 1:
|
|
for i in range(0,2):
|
|
conversation.pop(1)
|
|
self.history[conversation_id] = conversation
|
|
|
|
if llm_api is None:
|
|
# return the output without messing with it if there is no API exposed to the model
|
|
intent_response = intent.IntentResponse(language=user_input.language)
|
|
intent_response.async_set_speech(response.strip())
|
|
return ConversationResult(
|
|
response=intent_response, conversation_id=conversation_id
|
|
)
|
|
|
|
# parse response
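        # each regex match should be a JSON blob; illustrative (hypothetical) shapes:
        #   Home-LLM API: {"service": "light.turn_on", "target_device": "light.kitchen", "brightness": 0.5}
        #   other APIs:   {"name": "HassTurnOn", "arguments": {"name": "kitchen light"}}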
|
|
to_say = service_call_pattern.sub("", response.strip())
|
|
for block in service_call_pattern.findall(response.strip()):
|
|
parsed_tool_call: dict = json.loads(block)
|
|
|
|
if llm_api.api.id == HOME_LLM_API_ID:
|
|
schema_to_validate = vol.Schema({
|
|
vol.Required('service'): str,
|
|
vol.Required('target_device'): str,
|
|
vol.Optional('rgb_color'): str,
|
|
vol.Optional('brightness'): float,
|
|
vol.Optional('temperature'): float,
|
|
vol.Optional('humidity'): float,
|
|
vol.Optional('fan_mode'): str,
|
|
vol.Optional('hvac_mode'): str,
|
|
vol.Optional('preset_mode'): str,
|
|
vol.Optional('duration'): str,
|
|
vol.Optional('item'): str,
|
|
})
|
|
else:
|
|
schema_to_validate = vol.Schema({
|
|
vol.Required("name"): str,
|
|
vol.Required("arguments"): dict,
|
|
})
|
|
|
|
try:
|
|
schema_to_validate(parsed_tool_call)
|
|
except vol.Error as ex:
|
|
_LOGGER.info(f"LLM produced an improperly formatted response: {repr(ex)}")
|
|
|
|
intent_response = intent.IntentResponse(language=user_input.language)
|
|
intent_response.async_set_error(
|
|
intent.IntentResponseErrorCode.NO_INTENT_MATCH,
|
|
f"I'm sorry, I didn't produce a correctly formatted tool call! Please see the logs for more info.",
|
|
)
|
|
return ConversationResult(
|
|
response=intent_response, conversation_id=conversation_id
|
|
)
|
|
|
|
_LOGGER.info(f"calling tool: {block}")
|
|
|
|
# try to fix certain arguments
|
|
args_dict = parsed_tool_call if llm_api.api.id == HOME_LLM_API_ID else parsed_tool_call["arguments"]
|
|
|
|
# make sure brightness is 0-255 and not a percentage
|
|
if "brightness" in args_dict and 0.0 < args_dict["brightness"] <= 1.0:
|
|
args_dict["brightness"] = int(args_dict["brightness"] * 255)
|
|
|
|
# convert string "tuple" to a list for RGB colors
|
|
if "rgb_color" in args_dict and isinstance(args_dict["rgb_color"], str):
|
|
args_dict["rgb_color"] = [ int(x) for x in args_dict["rgb_color"][1:-1].split(",") ]
|
|
|
|
if llm_api.api.id == HOME_LLM_API_ID:
|
|
to_say = to_say + parsed_tool_call.pop("to_say", "")
|
|
tool_input = llm.ToolInput(
|
|
tool_name=SERVICE_TOOL_NAME,
|
|
tool_args=parsed_tool_call,
|
|
)
|
|
else:
|
|
tool_input = llm.ToolInput(
|
|
tool_name=parsed_tool_call["name"],
|
|
tool_args=parsed_tool_call["arguments"],
|
|
)
|
|
|
|
try:
|
|
tool_response = await llm_api.async_call_tool(tool_input)
|
|
except (HomeAssistantError, vol.Invalid) as e:
|
|
tool_response = {"error": type(e).__name__}
|
|
if str(e):
|
|
tool_response["error_text"] = str(e)
|
|
|
|
intent_response = intent.IntentResponse(language=user_input.language)
|
|
intent_response.async_set_error(
|
|
intent.IntentResponseErrorCode.NO_INTENT_MATCH,
|
|
f"I'm sorry! I encountered an error calling the tool. See the logs for more info.",
|
|
)
|
|
return ConversationResult(
|
|
response=intent_response, conversation_id=conversation_id
|
|
)
|
|
|
|
_LOGGER.debug("Tool response: %s", tool_response)
|
|
|
|
# handle models that generate a function call and wait for the result before providing a response
|
|
if self.entry.options.get(CONF_TOOL_MULTI_TURN_CHAT, DEFAULT_TOOL_MULTI_TURN_CHAT):
|
|
conversation.append({"role": "tool", "message": json.dumps(tool_response)})
|
|
|
|
# generate a response based on the tool result
|
|
try:
|
|
_LOGGER.debug(conversation)
|
|
to_say = await self._async_generate(conversation)
|
|
_LOGGER.debug(to_say)
|
|
|
|
except Exception as err:
|
|
_LOGGER.exception("There was a problem talking to the backend")
|
|
|
|
intent_response = intent.IntentResponse(language=user_input.language)
|
|
intent_response.async_set_error(
|
|
intent.IntentResponseErrorCode.FAILED_TO_HANDLE,
|
|
f"Sorry, there was a problem talking to the backend: {repr(err)}",
|
|
)
|
|
return ConversationResult(
|
|
response=intent_response, conversation_id=conversation_id
|
|
)
|
|
|
|
conversation.append({"role": "assistant", "message": response})
|
|
|
|
# generate intent response to Home Assistant
|
|
intent_response = intent.IntentResponse(language=user_input.language)
|
|
intent_response.async_set_speech(to_say)
|
|
return ConversationResult(
|
|
response=intent_response, conversation_id=conversation_id
|
|
)
|
|
|
|
def _async_get_exposed_entities(self) -> tuple[dict[str, str], list[str]]:
|
|
"""Gather exposed entity states"""
|
|
entity_states = {}
|
|
domains = set()
|
|
entity_registry = er.async_get(self.hass)
|
|
area_registry = ar.async_get(self.hass)
|
|
|
|
for state in self.hass.states.async_all():
|
|
if not async_should_expose(self.hass, CONVERSATION_DOMAIN, state.entity_id):
|
|
continue
|
|
|
|
entity = entity_registry.async_get(state.entity_id)
|
|
|
|
attributes = dict(state.attributes)
|
|
attributes["state"] = state.state
|
|
if entity and entity.aliases:
|
|
attributes["aliases"] = entity.aliases
|
|
|
|
            if entity and entity.area_id:
                area = area_registry.async_get_area(entity.area_id)
                # guard against a stale area_id pointing at a deleted area
                if area:
                    attributes["area_id"] = area.id
                    attributes["area_name"] = area.name
|
|
|
|
entity_states[state.entity_id] = attributes
|
|
domains.add(state.domain)
|
|
|
|
return entity_states, list(domains)
|
|
|
|
def _format_prompt(
|
|
self, prompt: list[dict], include_generation_prompt: bool = True
|
|
) -> str:
|
|
"""Format a conversation into a raw text completion using the model's prompt template"""
|
|
formatted_prompt = ""
|
|
|
|
prompt_template = self.entry.options.get(CONF_PROMPT_TEMPLATE, DEFAULT_PROMPT_TEMPLATE)
|
|
template_desc = PROMPT_TEMPLATE_DESCRIPTIONS[prompt_template]
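        # template_desc maps each role to its affixes plus a generation prompt,
        # e.g. (illustrative): {"user": {"prefix": "<|user|>\n", "suffix": "<|end|>"},
        #   ..., "generation_prompt": "<|assistant|>\n"}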
|
|
|
|
# handle models without a system prompt
|
|
if prompt[0]["role"] == "system" and "system" not in template_desc:
|
|
system_prompt = prompt.pop(0)
|
|
prompt[0]["message"] = system_prompt["message"] + prompt[0]["message"]
|
|
|
|
for message in prompt:
|
|
role = message["role"]
|
|
message = message["message"]
|
|
# fall back to the "user" role for unknown roles
|
|
role_desc = template_desc.get(role, template_desc["user"])
|
|
formatted_prompt = (
|
|
formatted_prompt + f"{role_desc['prefix']}{message}{role_desc['suffix']}\n"
|
|
)
|
|
|
|
if include_generation_prompt:
|
|
formatted_prompt = formatted_prompt + template_desc["generation_prompt"]
|
|
|
|
_LOGGER.debug(formatted_prompt)
|
|
return formatted_prompt
|
|
|
|
def _format_tool(self, name: str, parameters: vol.Schema, description: str):
|
|
style = self.entry.options.get(CONF_TOOL_FORMAT, DEFAULT_TOOL_FORMAT)
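        # illustrative outputs of the three styles (see the branches below):
        #   minimal: "light.turn_on(rgb_color,brightness) - description"
        #   reduced: {"name": ..., "description": ..., "parameters": {"properties": ...}}
        #   full:    OpenAI-style {"type": "function", "function": {...}}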
|
|
|
|
if style == TOOL_FORMAT_MINIMAL:
|
|
result = f"{name}({','.join(flatten_vol_schema(parameters))})"
|
|
if description:
|
|
result = result + f" - {description}"
|
|
return result
|
|
|
|
raw_parameters: list = voluptuous_serialize.convert(
|
|
parameters, custom_serializer=custom_custom_serializer)
|
|
|
|
# handle vol.Any in the key side of things
|
|
processed_parameters = []
|
|
for param in raw_parameters:
|
|
if isinstance(param["name"], vol.Any):
|
|
for possible_name in param["name"].validators:
|
|
actual_param = param.copy()
|
|
actual_param["name"] = possible_name
|
|
actual_param["required"] = False
|
|
processed_parameters.append(actual_param)
|
|
else:
|
|
processed_parameters.append(param)
|
|
|
|
if style == TOOL_FORMAT_REDUCED:
|
|
return {
|
|
"name": name,
|
|
"description": description,
|
|
"parameters": {
|
|
"properties": {
|
|
x["name"]: x.get("type", "string") for x in processed_parameters
|
|
},
|
|
"required": [
|
|
x["name"] for x in processed_parameters if x.get("required")
|
|
]
|
|
}
|
|
}
|
|
elif style == TOOL_FORMAT_FULL:
|
|
return {
|
|
"type": "function",
|
|
"function": {
|
|
"name": name,
|
|
"description": description,
|
|
"parameters": {
|
|
"type": "object",
|
|
"properties": {
|
|
x["name"]: {
|
|
"type": x.get("type", "string"),
|
|
"description": x.get("description", ""),
|
|
} for x in processed_parameters
|
|
},
|
|
"required": [
|
|
x["name"] for x in processed_parameters if x.get("required")
|
|
]
|
|
}
|
|
}
|
|
}
|
|
|
|
raise Exception(f"Unknown tool format {style}")
|
|
|
|
def _generate_icl_examples(self, num_examples, entity_names):
|
|
entity_names = entity_names[:]
|
|
entity_domains = set([x.split(".")[0] for x in entity_names])
|
|
|
|
area_registry = ar.async_get(self.hass)
|
|
all_areas = list(area_registry.async_list_areas())
|
|
|
|
in_context_examples = [
|
|
x for x in self.in_context_examples
|
|
if x["type"] in entity_domains
|
|
]
|
|
|
|
random.shuffle(in_context_examples)
|
|
random.shuffle(entity_names)
|
|
|
|
num_examples_to_generate = min(num_examples, len(in_context_examples))
|
|
if num_examples_to_generate < num_examples:
|
|
_LOGGER.warning(f"Attempted to generate {num_examples} ICL examples for conversation, but only {len(in_context_examples)} are available!")
|
|
|
|
examples = []
|
|
for _ in range(num_examples_to_generate):
|
|
chosen_example = in_context_examples.pop()
|
|
request = chosen_example["request"]
|
|
response = chosen_example["response"]
|
|
|
|
random_device = [ x for x in entity_names if x.split(".")[0] == chosen_example["type"] ][0]
|
|
random_area = random.choice(all_areas).name
|
|
random_brightness = round(random.random(), 2)
|
|
random_color = random.choice(list(color.COLORS.keys()))
|
|
|
|
tool_arguments = {}
|
|
|
|
if "<area>" in request:
|
|
request = request.replace("<area>", random_area)
|
|
response = response.replace("<area>", random_area)
|
|
tool_arguments["area"] = random_area
|
|
|
|
if "<name>" in request:
|
|
request = request.replace("<name>", random_device)
|
|
response = response.replace("<name>", random_device)
|
|
tool_arguments["name"] = random_device
|
|
|
|
if "<brightness>" in request:
|
|
request = request.replace("<brightness>", str(random_brightness))
|
|
response = response.replace("<brightness>", str(random_brightness))
|
|
tool_arguments["brightness"] = random_brightness
|
|
|
|
if "<color>" in request:
|
|
request = request.replace("<color>", random_color)
|
|
response = response.replace("<color>", random_color)
|
|
tool_arguments["color"] = random_color
|
|
|
|
examples.append({
|
|
"request": request,
|
|
"response": response,
|
|
"tool": {
|
|
"name": chosen_example["tool"],
|
|
"arguments": tool_arguments
|
|
}
|
|
})
|
|
|
|
return examples
|
|
|
|
def _generate_system_prompt(self, prompt_template: str, llm_api: llm.APIInstance) -> str:
|
|
"""Generate the system prompt with current entity states"""
|
|
entities_to_expose, domains = self._async_get_exposed_entities()
|
|
|
|
extra_attributes_to_expose = self.entry.options \
|
|
.get(CONF_EXTRA_ATTRIBUTES_TO_EXPOSE, DEFAULT_EXTRA_ATTRIBUTES_TO_EXPOSE)
|
|
|
|
def expose_attributes(attributes) -> list[str]:
|
|
result = []
|
|
for attribute_name in extra_attributes_to_expose:
|
|
if attribute_name not in attributes:
|
|
continue
|
|
|
|
_LOGGER.debug(f"{attribute_name} = {attributes[attribute_name]}")
|
|
|
|
value = attributes[attribute_name]
|
|
if value is not None:
|
|
if attribute_name == "temperature":
|
|
value = int(value)
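                        # heuristic: thermostat values above 50 are assumed Fahrenheit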
|
|
if value > 50:
|
|
value = f"{value}F"
|
|
else:
|
|
value = f"{value}C"
|
|
elif attribute_name == "rgb_color":
|
|
value = F"{closest_color(value)} {value}"
|
|
elif attribute_name == "volume_level":
|
|
value = f"vol={int(value*100)}"
|
|
elif attribute_name == "brightness":
|
|
value = f"{int(value/255*100)}%"
|
|
elif attribute_name == "humidity":
|
|
value = f"{value}%"
|
|
|
|
result.append(str(value))
|
|
return result
|
|
|
|
devices = []
|
|
formatted_devices = ""
|
|
|
|
# expose devices and their alias as well
|
|
for name, attributes in entities_to_expose.items():
|
|
state = attributes["state"]
|
|
exposed_attributes = expose_attributes(attributes)
|
|
str_attributes = ";".join([state] + exposed_attributes)
|
|
|
|
formatted_devices = formatted_devices + f"{name} '{attributes.get('friendly_name')}' = {str_attributes}\n"
|
|
devices.append({
|
|
"entity_id": name,
|
|
"name": attributes.get('friendly_name'),
|
|
"state": state,
|
|
"attributes": exposed_attributes,
|
|
"area_name": attributes.get("area_name"),
|
|
"area_id": attributes.get("area_id"),
|
|
"is_alias": False
|
|
})
|
|
if "aliases" in attributes:
|
|
for alias in attributes["aliases"]:
|
|
formatted_devices = formatted_devices + f"{name} '{alias}' = {str_attributes}\n"
|
|
devices.append({
|
|
"entity_id": name,
|
|
"name": alias,
|
|
"state": state,
|
|
"attributes": exposed_attributes,
|
|
"area_name": attributes.get("area_name"),
|
|
"area_id": attributes.get("area_id"),
|
|
"is_alias": True
|
|
})
|
|
|
|
if llm_api:
|
|
if llm_api.api.id == HOME_LLM_API_ID:
|
|
service_dict = self.hass.services.async_services()
|
|
all_services = []
|
|
scripts_added = False
|
|
for domain in domains:
|
|
# scripts show up as individual services
|
|
if domain == "script" and not scripts_added:
|
|
all_services.extend([
|
|
("script.reload", vol.Schema({}), ""),
|
|
("script.turn_on", vol.Schema({}), ""),
|
|
("script.turn_off", vol.Schema({}), ""),
|
|
("script.toggle", vol.Schema({}), ""),
|
|
])
|
|
scripts_added = True
|
|
continue
|
|
|
|
for name, service in service_dict.get(domain, {}).items():
|
|
args = flatten_vol_schema(service.schema)
|
|
args_to_expose = set(args).intersection(ALLOWED_SERVICE_CALL_ARGUMENTS)
|
|
service_schema = vol.Schema({
|
|
vol.Optional(arg): str for arg in args_to_expose
|
|
})
|
|
|
|
all_services.append((f"{domain}.{name}", service_schema, ""))
|
|
|
|
tools = [
|
|
self._format_tool(*tool)
|
|
for tool in all_services
|
|
]
|
|
|
|
else:
|
|
tools = [
|
|
self._format_tool(tool.name, tool.parameters, tool.description)
|
|
for tool in llm_api.tools
|
|
]
|
|
|
|
if self.entry.options.get(CONF_TOOL_FORMAT, DEFAULT_TOOL_FORMAT) == TOOL_FORMAT_MINIMAL:
|
|
formatted_tools = ", ".join(tools)
|
|
else:
|
|
formatted_tools = json.dumps(tools)
|
|
else:
|
|
tools = ["No tools were provided. If the user requests you interact with a device, tell them you are unable to do so."]
|
|
formatted_tools = tools[0]
|
|
|
|
render_variables = {
|
|
"devices": devices,
|
|
"formatted_devices": formatted_devices,
|
|
"tools": tools,
|
|
"formatted_tools": formatted_tools,
|
|
"response_examples": []
|
|
}
|
|
|
|
# only pass examples if there are loaded examples + an API was exposed
|
|
if self.in_context_examples and llm_api:
|
|
num_examples = int(self.entry.options.get(CONF_NUM_IN_CONTEXT_EXAMPLES, DEFAULT_NUM_IN_CONTEXT_EXAMPLES))
|
|
render_variables["response_examples"] = self._generate_icl_examples(num_examples, list(entities_to_expose.keys()))
|
|
|
|
return template.Template(prompt_template, self.hass).async_render(
|
|
render_variables,
|
|
parse_result=False,
|
|
)
|
|
|
|
class LlamaCppAgent(LocalLLMAgent):
|
|
model_path: str
|
|
llm: LlamaType
|
|
grammar: Any
|
|
llama_cpp_module: Any
|
|
remove_prompt_caching_listener: Callable
|
|
model_lock: threading.Lock
|
|
last_cache_prime: float
|
|
last_updated_entities: dict[str, float]
|
|
cache_refresh_after_cooldown: bool
|
|
loaded_model_settings: dict[str, Any]
|
|
|
|
def _load_model(self, entry: ConfigEntry) -> None:
|
|
self.model_path = entry.data.get(CONF_DOWNLOADED_MODEL_FILE)
|
|
|
|
_LOGGER.info(
|
|
"Using model file '%s'", self.model_path
|
|
)
|
|
|
|
if not self.model_path:
|
|
raise Exception(f"Model was not found at '{self.model_path}'!")
|
|
|
|
validate_llama_cpp_python_installation()
|
|
|
|
# don't import it until now because the wheel is installed by config_flow.py
|
|
try:
|
|
self.llama_cpp_module = importlib.import_module("llama_cpp")
|
|
except ModuleNotFoundError:
|
|
# attempt to re-install llama-cpp-python if it was uninstalled for some reason
|
|
install_result = install_llama_cpp_python(self.hass.config.config_dir)
|
|
            if install_result is not True:
                raise ConfigEntryError("llama-cpp-python was not installed on startup and re-installing it led to an error!")
|
|
|
|
validate_llama_cpp_python_installation()
|
|
self.llama_cpp_module = importlib.import_module("llama_cpp")
|
|
|
|
Llama = getattr(self.llama_cpp_module, "Llama")
|
|
|
|
_LOGGER.debug(f"Loading model '{self.model_path}'...")
|
|
self.loaded_model_settings = {}
|
|
self.loaded_model_settings[CONF_CONTEXT_LENGTH] = entry.options.get(CONF_CONTEXT_LENGTH, DEFAULT_CONTEXT_LENGTH)
|
|
self.loaded_model_settings[CONF_BATCH_SIZE] = entry.options.get(CONF_BATCH_SIZE, DEFAULT_BATCH_SIZE)
|
|
self.loaded_model_settings[CONF_THREAD_COUNT] = entry.options.get(CONF_THREAD_COUNT, DEFAULT_THREAD_COUNT)
|
|
self.loaded_model_settings[CONF_BATCH_THREAD_COUNT] = entry.options.get(CONF_BATCH_THREAD_COUNT, DEFAULT_BATCH_THREAD_COUNT)
|
|
self.loaded_model_settings[CONF_ENABLE_FLASH_ATTENTION] = entry.options.get(CONF_ENABLE_FLASH_ATTENTION, DEFAULT_ENABLE_FLASH_ATTENTION)
|
|
|
|
self.llm = Llama(
|
|
model_path=self.model_path,
|
|
n_ctx=int(self.loaded_model_settings[CONF_CONTEXT_LENGTH]),
|
|
n_batch=int(self.loaded_model_settings[CONF_BATCH_SIZE]),
|
|
n_threads=int(self.loaded_model_settings[CONF_THREAD_COUNT]),
|
|
n_threads_batch=int(self.loaded_model_settings[CONF_BATCH_THREAD_COUNT]),
|
|
flash_attn=self.loaded_model_settings[CONF_ENABLE_FLASH_ATTENTION],
|
|
)
|
|
_LOGGER.debug("Model loaded")
|
|
|
|
self.grammar = None
|
|
if entry.options.get(CONF_USE_GBNF_GRAMMAR, DEFAULT_USE_GBNF_GRAMMAR):
|
|
self._load_grammar(entry.options.get(CONF_GBNF_GRAMMAR_FILE, DEFAULT_GBNF_GRAMMAR_FILE))
|
|
|
|
|
|
# TODO: check about disk caching
|
|
# self.llm.set_cache(self.llama_cpp_module.LlamaDiskCache(
|
|
# capacity_bytes=(512 * 10e8),
|
|
# cache_dir=os.path.join(self.hass.config.media_dirs.get("local", self.hass.config.path("media")), "kv_cache")
|
|
# ))
|
|
|
|
self.remove_prompt_caching_listener = None
|
|
self.last_cache_prime = None
|
|
self.last_updated_entities = {}
|
|
self.cache_refresh_after_cooldown = False
|
|
self.model_lock = threading.Lock()
|
|
|
|
self.loaded_model_settings[CONF_PROMPT_CACHING_ENABLED] = entry.options.get(CONF_PROMPT_CACHING_ENABLED, DEFAULT_PROMPT_CACHING_ENABLED)
|
|
if self.loaded_model_settings[CONF_PROMPT_CACHING_ENABLED]:
|
|
@callback
|
|
async def enable_caching_after_startup(_now) -> None:
|
|
self._set_prompt_caching(enabled=True)
|
|
await self._async_cache_prompt(None, None, None)
|
|
async_call_later(self.hass, 5.0, enable_caching_after_startup)
|
|
|
|
def _load_grammar(self, filename: str):
|
|
LlamaGrammar = getattr(self.llama_cpp_module, "LlamaGrammar")
|
|
_LOGGER.debug(f"Loading grammar {filename}...")
|
|
try:
|
|
with open(os.path.join(os.path.dirname(__file__), filename)) as f:
|
|
grammar_str = "".join(f.readlines())
|
|
self.grammar = LlamaGrammar.from_string(grammar_str)
|
|
self.loaded_model_settings[CONF_GBNF_GRAMMAR_FILE] = filename
|
|
_LOGGER.debug("Loaded grammar")
|
|
except Exception:
|
|
_LOGGER.exception("Failed to load grammar!")
|
|
self.grammar = None
|
|
|
|
def _update_options(self):
|
|
LocalLLMAgent._update_options(self)
|
|
|
|
model_reloaded = False
|
|
if self.loaded_model_settings[CONF_CONTEXT_LENGTH] != self.entry.options.get(CONF_CONTEXT_LENGTH, DEFAULT_CONTEXT_LENGTH) or \
|
|
self.loaded_model_settings[CONF_BATCH_SIZE] != self.entry.options.get(CONF_BATCH_SIZE, DEFAULT_BATCH_SIZE) or \
|
|
self.loaded_model_settings[CONF_THREAD_COUNT] != self.entry.options.get(CONF_THREAD_COUNT, DEFAULT_THREAD_COUNT) or \
|
|
self.loaded_model_settings[CONF_BATCH_THREAD_COUNT] != self.entry.options.get(CONF_BATCH_THREAD_COUNT, DEFAULT_BATCH_THREAD_COUNT) or \
|
|
self.loaded_model_settings[CONF_ENABLE_FLASH_ATTENTION] != self.entry.options.get(CONF_ENABLE_FLASH_ATTENTION, DEFAULT_ENABLE_FLASH_ATTENTION):
|
|
|
|
_LOGGER.debug(f"Reloading model '{self.model_path}'...")
|
|
self.loaded_model_settings[CONF_CONTEXT_LENGTH] = self.entry.options.get(CONF_CONTEXT_LENGTH, DEFAULT_CONTEXT_LENGTH)
|
|
self.loaded_model_settings[CONF_BATCH_SIZE] = self.entry.options.get(CONF_BATCH_SIZE, DEFAULT_BATCH_SIZE)
|
|
self.loaded_model_settings[CONF_THREAD_COUNT] = self.entry.options.get(CONF_THREAD_COUNT, DEFAULT_THREAD_COUNT)
|
|
self.loaded_model_settings[CONF_BATCH_THREAD_COUNT] = self.entry.options.get(CONF_BATCH_THREAD_COUNT, DEFAULT_BATCH_THREAD_COUNT)
|
|
self.loaded_model_settings[CONF_ENABLE_FLASH_ATTENTION] = self.entry.options.get(CONF_ENABLE_FLASH_ATTENTION, DEFAULT_ENABLE_FLASH_ATTENTION)
|
|
|
|
Llama = getattr(self.llama_cpp_module, "Llama")
|
|
self.llm = Llama(
|
|
model_path=self.model_path,
|
|
n_ctx=int(self.loaded_model_settings[CONF_CONTEXT_LENGTH]),
|
|
n_batch=int(self.loaded_model_settings[CONF_BATCH_SIZE]),
|
|
n_threads=int(self.loaded_model_settings[CONF_THREAD_COUNT]),
|
|
n_threads_batch=int(self.loaded_model_settings[CONF_BATCH_THREAD_COUNT]),
|
|
flash_attn=self.loaded_model_settings[CONF_ENABLE_FLASH_ATTENTION],
|
|
)
|
|
_LOGGER.debug("Model loaded")
|
|
model_reloaded = True
|
|
|
|
if self.entry.options.get(CONF_USE_GBNF_GRAMMAR, DEFAULT_USE_GBNF_GRAMMAR):
|
|
current_grammar = self.entry.options.get(CONF_GBNF_GRAMMAR_FILE, DEFAULT_GBNF_GRAMMAR_FILE)
|
|
if not self.grammar or self.loaded_model_settings[CONF_GBNF_GRAMMAR_FILE] != current_grammar:
|
|
self._load_grammar(current_grammar)
|
|
else:
|
|
self.grammar = None
|
|
|
|
if self.entry.options.get(CONF_PROMPT_CACHING_ENABLED, DEFAULT_PROMPT_CACHING_ENABLED):
|
|
self._set_prompt_caching(enabled=True)
|
|
|
|
if self.loaded_model_settings[CONF_PROMPT_CACHING_ENABLED] != self.entry.options.get(CONF_PROMPT_CACHING_ENABLED, DEFAULT_PROMPT_CACHING_ENABLED) or \
|
|
model_reloaded:
|
|
self.loaded_model_settings[CONF_PROMPT_CACHING_ENABLED] = self.entry.options.get(CONF_PROMPT_CACHING_ENABLED, DEFAULT_PROMPT_CACHING_ENABLED)
|
|
|
|
async def cache_current_prompt(_now):
|
|
await self._async_cache_prompt(None, None, None)
|
|
async_call_later(self.hass, 1.0, cache_current_prompt)
|
|
else:
|
|
self._set_prompt_caching(enabled=False)
|
|
|
|
def _async_get_exposed_entities(self) -> tuple[dict[str, str], list[str]]:
|
|
"""Takes the super class function results and sorts the entities with the recently updated at the end"""
|
|
entities, domains = LocalLLMAgent._async_get_exposed_entities(self)
|
|
|
|
# ignore sorting if prompt caching is disabled
|
|
if not self.entry.options.get(CONF_PROMPT_CACHING_ENABLED, DEFAULT_PROMPT_CACHING_ENABLED):
|
|
return entities, domains
|
|
|
|
entity_order = { name: None for name in entities.keys() }
|
|
entity_order.update(self.last_updated_entities)
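        # entities with no recorded update sort first so the static part of the
        # prompt stays stable (maximizing KV cache reuse); recently updated
        # entities move to the end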
|
|
|
|
def sort_key(item):
|
|
item_name, last_updated = item
|
|
# Handle cases where last updated is None
|
|
if last_updated is None:
|
|
return (False, '', item_name)
|
|
else:
|
|
return (True, last_updated, '')
|
|
|
|
# Sort the items based on the sort_key function
|
|
sorted_items = sorted(list(entity_order.items()), key=sort_key)
|
|
|
|
_LOGGER.debug(f"sorted_items: {sorted_items}")
|
|
|
|
sorted_entities = {}
|
|
for item_name, _ in sorted_items:
|
|
sorted_entities[item_name] = entities[item_name]
|
|
|
|
return sorted_entities, domains
|
|
|
|
def _set_prompt_caching(self, *, enabled=True):
|
|
if enabled and not self.remove_prompt_caching_listener:
|
|
_LOGGER.info("enabling prompt caching...")
|
|
|
|
entity_ids = [
|
|
state.entity_id for state in self.hass.states.async_all() \
|
|
if async_should_expose(self.hass, CONVERSATION_DOMAIN, state.entity_id)
|
|
]
|
|
|
|
_LOGGER.debug(f"watching entities: {entity_ids}")
|
|
|
|
self.remove_prompt_caching_listener = async_track_state_change(self.hass, entity_ids, self._async_cache_prompt)
|
|
|
|
elif not enabled and self.remove_prompt_caching_listener:
|
|
_LOGGER.info("disabling prompt caching...")
|
|
self.remove_prompt_caching_listener()
|
|
|
|
@callback
|
|
async def _async_cache_prompt(self, entity, old_state, new_state):
|
|
refresh_start = time.time()
|
|
|
|
# track last update time so we can sort the context efficiently
|
|
if entity:
|
|
self.last_updated_entities[entity] = refresh_start
|
|
|
|
llm_api: llm.APIInstance | None = None
|
|
if self.entry.options.get(CONF_LLM_HASS_API):
|
|
try:
|
|
llm_api = await llm.async_get_api(
|
|
self.hass, self.entry.options[CONF_LLM_HASS_API]
|
|
)
|
|
except HomeAssistantError:
|
|
_LOGGER.exception("Failed to get LLM API when caching prompt!")
|
|
return
|
|
|
|
_LOGGER.debug(f"refreshing cached prompt because {entity} changed...")
|
|
await self.hass.async_add_executor_job(self._cache_prompt, llm_api)
|
|
|
|
refresh_end = time.time()
|
|
_LOGGER.debug(f"cache refresh took {(refresh_end - refresh_start):.2f} sec")
|
|
|
|
def _cache_prompt(self, llm_api: llm.API) -> None:
|
|
# if a refresh is already scheduled then exit
|
|
if self.cache_refresh_after_cooldown:
|
|
return
|
|
|
|
# if we are inside the cooldown period, request a refresh and exit
|
|
current_time = time.time()
|
|
fastest_prime_interval = self.entry.options.get(CONF_PROMPT_CACHING_INTERVAL, DEFAULT_PROMPT_CACHING_INTERVAL)
|
|
if self.last_cache_prime and current_time - self.last_cache_prime < fastest_prime_interval:
|
|
self.cache_refresh_after_cooldown = True
|
|
return
|
|
|
|
# try to acquire the lock, if we are still running for some reason, request a refresh and exit
|
|
lock_acquired = self.model_lock.acquire(False)
|
|
if not lock_acquired:
|
|
self.cache_refresh_after_cooldown = True
|
|
return
|
|
|
|
try:
|
|
raw_prompt = self.entry.options.get(CONF_PROMPT, DEFAULT_PROMPT)
|
|
prompt = self._format_prompt([
|
|
{ "role": "system", "message": self._generate_system_prompt(raw_prompt, llm_api)},
|
|
{ "role": "user", "message": "" }
|
|
], include_generation_prompt=False)
|
|
|
|
input_tokens = self.llm.tokenize(
|
|
prompt.encode(), add_bos=False
|
|
)
|
|
|
|
temperature = self.entry.options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE)
|
|
top_k = int(self.entry.options.get(CONF_TOP_K, DEFAULT_TOP_K))
|
|
top_p = self.entry.options.get(CONF_TOP_P, DEFAULT_TOP_P)
|
|
grammar = self.grammar if self.entry.options.get(CONF_USE_GBNF_GRAMMAR, DEFAULT_USE_GBNF_GRAMMAR) else None
|
|
|
|
_LOGGER.debug(f"Options: {self.entry.options}")
|
|
|
|
_LOGGER.debug(f"Processing {len(input_tokens)} input tokens...")
|
|
|
|
# grab just one token. should prime the kv cache with the system prompt
|
|
next(self.llm.generate(
|
|
input_tokens,
|
|
temp=temperature,
|
|
top_k=top_k,
|
|
top_p=top_p,
|
|
grammar=grammar
|
|
))
|
|
|
|
self.last_cache_prime = time.time()
|
|
finally:
|
|
self.model_lock.release()
|
|
|
|
|
|
# schedule a refresh using async_call_later
|
|
# if the flag is set after the delay then we do another refresh
|
|
|
|
@callback
|
|
async def refresh_if_requested(_now):
|
|
if self.cache_refresh_after_cooldown:
|
|
self.cache_refresh_after_cooldown = False
|
|
|
|
refresh_start = time.time()
|
|
_LOGGER.debug(f"refreshing cached prompt after cooldown...")
|
|
                # _cache_prompt requires an llm_api argument; pass None here since
                # the API instance is not available in this deferred refresh
                await self.hass.async_add_executor_job(self._cache_prompt, None)
|
|
|
|
refresh_end = time.time()
|
|
_LOGGER.debug(f"cache refresh took {(refresh_end - refresh_start):.2f} sec")
|
|
|
|
refresh_delay = self.entry.options.get(CONF_PROMPT_CACHING_INTERVAL, DEFAULT_PROMPT_CACHING_INTERVAL)
|
|
async_call_later(self.hass, float(refresh_delay), refresh_if_requested)
|
|
|
|
|
|
    def _generate(self, conversation: list[dict]) -> str:
|
|
prompt = self._format_prompt(conversation)
|
|
|
|
max_tokens = self.entry.options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS)
|
|
temperature = self.entry.options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE)
|
|
top_k = int(self.entry.options.get(CONF_TOP_K, DEFAULT_TOP_K))
|
|
top_p = self.entry.options.get(CONF_TOP_P, DEFAULT_TOP_P)
|
|
min_p = self.entry.options.get(CONF_MIN_P, DEFAULT_MIN_P)
|
|
typical_p = self.entry.options.get(CONF_TYPICAL_P, DEFAULT_TYPICAL_P)
|
|
|
|
_LOGGER.debug(f"Options: {self.entry.options}")
|
|
|
|
with self.model_lock:
|
|
input_tokens = self.llm.tokenize(
|
|
prompt.encode(), add_bos=False
|
|
)
|
|
|
|
context_len = self.entry.options.get(CONF_CONTEXT_LENGTH, DEFAULT_CONTEXT_LENGTH)
|
|
if len(input_tokens) + max_tokens >= context_len:
|
|
self._warn_context_size()
|
|
|
|
_LOGGER.debug(f"Processing {len(input_tokens)} input tokens...")
|
|
output_tokens = self.llm.generate(
|
|
input_tokens,
|
|
temp=temperature,
|
|
top_k=top_k,
|
|
top_p=top_p,
|
|
min_p=min_p,
|
|
typical_p=typical_p,
|
|
grammar=self.grammar
|
|
)
|
|
|
|
result_tokens = []
|
|
for token in output_tokens:
|
|
if token == self.llm.token_eos():
|
|
break
|
|
|
|
result_tokens.append(token)
|
|
|
|
if len(result_tokens) >= max_tokens:
|
|
break
|
|
|
|
result = self.llm.detokenize(result_tokens).decode()
|
|
|
|
return result
|
|
|
|
class GenericOpenAIAPIAgent(LocalLLMAgent):
|
|
api_host: str
|
|
api_key: str
|
|
model_name: str
|
|
|
|
def _load_model(self, entry: ConfigEntry) -> None:
|
|
self.api_host = format_url(
|
|
hostname=entry.data[CONF_HOST],
|
|
port=entry.data[CONF_PORT],
|
|
ssl=entry.data[CONF_SSL],
|
|
path=""
|
|
)
|
|
|
|
self.api_key = entry.data.get(CONF_OPENAI_API_KEY)
|
|
self.model_name = entry.data.get(CONF_CHAT_MODEL)
|
|
|
|
|
|
    def _chat_completion_params(self, conversation: list[dict]) -> tuple[str, dict]:
|
|
request_params = {}
|
|
|
|
endpoint = "/v1/chat/completions"
|
|
request_params["messages"] = [ { "role": x["role"], "content": x["message"] } for x in conversation ]
|
|
|
|
return endpoint, request_params
|
|
|
|
    def _completion_params(self, conversation: list[dict]) -> tuple[str, dict]:
|
|
request_params = {}
|
|
|
|
endpoint = "/v1/completions"
|
|
request_params["prompt"] = self._format_prompt(conversation)
|
|
|
|
return endpoint, request_params
|
|
|
|
def _extract_response(self, response_json: dict) -> str:
|
|
choices = response_json["choices"]
|
|
if choices[0]["finish_reason"] != "stop":
|
|
_LOGGER.warning("Model response did not end on a stop token (unfinished sentence)")
|
|
|
|
if response_json["object"] in ["chat.completion", "chat.completion.chunk"]:
|
|
return choices[0]["message"]["content"]
|
|
else:
|
|
return choices[0]["text"]
|
|
|
|
    def _generate(self, conversation: list[dict]) -> str:
|
|
max_tokens = self.entry.options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS)
|
|
temperature = self.entry.options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE)
|
|
top_p = self.entry.options.get(CONF_TOP_P, DEFAULT_TOP_P)
|
|
timeout = self.entry.options.get(CONF_REQUEST_TIMEOUT, DEFAULT_REQUEST_TIMEOUT)
|
|
use_chat_api = self.entry.options.get(CONF_REMOTE_USE_CHAT_ENDPOINT, DEFAULT_REMOTE_USE_CHAT_ENDPOINT)
|
|
|
|
|
|
request_params = {
|
|
"model": self.model_name,
|
|
"max_tokens": max_tokens,
|
|
"temperature": temperature,
|
|
"top_p": top_p,
|
|
}
|
|
|
|
if use_chat_api:
|
|
endpoint, additional_params = self._chat_completion_params(conversation)
|
|
else:
|
|
endpoint, additional_params = self._completion_params(conversation)
|
|
|
|
request_params.update(additional_params)
|
|
|
|
headers = {}
|
|
if self.api_key:
|
|
headers["Authorization"] = f"Bearer {self.api_key}"
|
|
|
|
try:
|
|
result = requests.post(
|
|
f"{self.api_host}{endpoint}",
|
|
json=request_params,
|
|
timeout=timeout,
|
|
headers=headers,
|
|
)
|
|
|
|
result.raise_for_status()
|
|
        except requests.exceptions.Timeout:
            return "The generation request timed out! Please check your connection settings, increase the timeout in settings, or decrease the number of exposed entities."
        except requests.RequestException as err:
            _LOGGER.debug(f"Err was: {err}")
            _LOGGER.debug(f"Request was: {request_params}")
            # 'result' is unbound when the request itself fails, so read the body
            # from the exception's response (if any) instead
            _LOGGER.debug(f"Result was: {err.response.text if err.response is not None else '(no response)'}")
            return f"Failed to communicate with the API! {err}"
|
|
|
|
_LOGGER.debug(result.json())
|
|
|
|
return self._extract_response(result.json())
|
|
|
|
class TextGenerationWebuiAgent(GenericOpenAIAPIAgent):
|
|
admin_key: str
|
|
|
|
def _load_model(self, entry: ConfigEntry) -> None:
|
|
super()._load_model(entry)
|
|
self.admin_key = entry.data.get(CONF_TEXT_GEN_WEBUI_ADMIN_KEY, self.api_key)
|
|
|
|
try:
|
|
headers = {}
|
|
if self.admin_key:
|
|
headers["Authorization"] = f"Bearer {self.admin_key}"
|
|
|
|
currently_loaded_result = requests.get(
|
|
f"{self.api_host}/v1/internal/model/info",
|
|
headers=headers,
|
|
)
|
|
currently_loaded_result.raise_for_status()
|
|
|
|
loaded_model = currently_loaded_result.json()["model_name"]
|
|
if loaded_model == self.model_name:
|
|
_LOGGER.info(f"Model {self.model_name} is already loaded on the remote backend.")
|
|
return
|
|
            else:
                _LOGGER.info(f"Model {self.model_name} is not loaded on the remote backend. Loading it now...")
|
|
|
|
load_result = requests.post(
|
|
f"{self.api_host}/v1/internal/model/load",
|
|
json={
|
|
"model_name": self.model_name,
|
|
# TODO: expose arguments to the user in home assistant UI
|
|
# "args": {},
|
|
},
|
|
headers=headers
|
|
)
|
|
load_result.raise_for_status()
|
|
|
|
except Exception as ex:
|
|
_LOGGER.debug("Connection error was: %s", repr(ex))
|
|
raise ConfigEntryNotReady("There was a problem connecting to the remote server") from ex
|
|
|
|
    def _chat_completion_params(self, conversation: list[dict]) -> tuple[str, dict]:
|
|
preset = self.entry.options.get(CONF_TEXT_GEN_WEBUI_PRESET)
|
|
chat_mode = self.entry.options.get(CONF_TEXT_GEN_WEBUI_CHAT_MODE, DEFAULT_TEXT_GEN_WEBUI_CHAT_MODE)
|
|
|
|
endpoint, request_params = super()._chat_completion_params(conversation)
|
|
|
|
request_params["mode"] = chat_mode
|
|
if chat_mode == TEXT_GEN_WEBUI_CHAT_MODE_CHAT or chat_mode == TEXT_GEN_WEBUI_CHAT_MODE_CHAT_INSTRUCT:
|
|
if preset:
|
|
request_params["character"] = preset
|
|
elif chat_mode == TEXT_GEN_WEBUI_CHAT_MODE_INSTRUCT:
|
|
request_params["instruction_template"] = self.entry.options.get(CONF_PROMPT_TEMPLATE, DEFAULT_PROMPT_TEMPLATE)
|
|
|
|
request_params["truncation_length"] = self.entry.options.get(CONF_CONTEXT_LENGTH, DEFAULT_CONTEXT_LENGTH)
|
|
request_params["top_k"] = self.entry.options.get(CONF_TOP_K, DEFAULT_TOP_K)
|
|
request_params["min_p"] = self.entry.options.get(CONF_MIN_P, DEFAULT_MIN_P)
|
|
request_params["typical_p"] = self.entry.options.get(CONF_TYPICAL_P, DEFAULT_TYPICAL_P)
|
|
|
|
return endpoint, request_params
|
|
|
|
    def _completion_params(self, conversation: list[dict]) -> tuple[str, dict]:
|
|
preset = self.entry.options.get(CONF_TEXT_GEN_WEBUI_PRESET)
|
|
|
|
endpoint, request_params = super()._completion_params(conversation)
|
|
|
|
if preset:
|
|
request_params["preset"] = preset
|
|
|
|
request_params["truncation_length"] = self.entry.options.get(CONF_CONTEXT_LENGTH, DEFAULT_CONTEXT_LENGTH)
|
|
request_params["top_k"] = self.entry.options.get(CONF_TOP_K, DEFAULT_TOP_K)
|
|
request_params["min_p"] = self.entry.options.get(CONF_MIN_P, DEFAULT_MIN_P)
|
|
request_params["typical_p"] = self.entry.options.get(CONF_TYPICAL_P, DEFAULT_TYPICAL_P)
|
|
|
|
return endpoint, request_params
|
|
|
|
def _extract_response(self, response_json: dict) -> str:
|
|
choices = response_json["choices"]
|
|
if choices[0]["finish_reason"] != "stop":
|
|
_LOGGER.warning("Model response did not end on a stop token (unfinished sentence)")
|
|
|
|
context_len = self.entry.options.get(CONF_CONTEXT_LENGTH, DEFAULT_CONTEXT_LENGTH)
|
|
max_tokens = self.entry.options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS)
|
|
if response_json["usage"]["prompt_tokens"] + max_tokens > context_len:
|
|
self._warn_context_size()
|
|
|
|
# text-gen-webui has a typo where it is 'chat.completions' not 'chat.completion'
|
|
if response_json["object"] == "chat.completions":
|
|
return choices[0]["message"]["content"]
|
|
else:
|
|
return choices[0]["text"]
|
|
|
|
class LlamaCppPythonAPIAgent(GenericOpenAIAPIAgent):
|
|
"""https://llama-cpp-python.readthedocs.io/en/latest/server/"""
|
|
grammar: str
|
|
|
|
def _load_model(self, entry: ConfigEntry):
|
|
super()._load_model(entry)
|
|
|
|
with open(os.path.join(os.path.dirname(__file__), DEFAULT_GBNF_GRAMMAR_FILE)) as f:
|
|
self.grammar = "".join(f.readlines())
|
|
|
|
    def _chat_completion_params(self, conversation: list[dict]) -> tuple[str, dict]:
|
|
top_k = int(self.entry.options.get(CONF_TOP_K, DEFAULT_TOP_K))
|
|
endpoint, request_params = super()._chat_completion_params(conversation)
|
|
|
|
request_params["top_k"] = top_k
|
|
|
|
if self.entry.options.get(CONF_USE_GBNF_GRAMMAR, DEFAULT_USE_GBNF_GRAMMAR):
|
|
request_params["grammar"] = self.grammar
|
|
|
|
return endpoint, request_params
|
|
|
|
    def _completion_params(self, conversation: list[dict]) -> tuple[str, dict]:
|
|
top_k = int(self.entry.options.get(CONF_TOP_K, DEFAULT_TOP_K))
|
|
endpoint, request_params = super()._completion_params(conversation)
|
|
|
|
request_params["top_k"] = top_k
|
|
|
|
if self.entry.options.get(CONF_USE_GBNF_GRAMMAR, DEFAULT_USE_GBNF_GRAMMAR):
|
|
request_params["grammar"] = self.grammar
|
|
|
|
return endpoint, request_params
|
|
|
|
class OllamaAPIAgent(LocalLLMAgent):
|
|
api_host: str
|
|
api_key: str
|
|
model_name: str
|
|
|
|
def _load_model(self, entry: ConfigEntry) -> None:
|
|
self.api_host = format_url(
|
|
hostname=entry.data[CONF_HOST],
|
|
port=entry.data[CONF_PORT],
|
|
ssl=entry.data[CONF_SSL],
|
|
path=""
|
|
)
|
|
self.api_key = entry.data.get(CONF_OPENAI_API_KEY)
|
|
self.model_name = entry.data.get(CONF_CHAT_MODEL)
|
|
|
|
# ollama handles loading for us so just make sure the model is available
|
|
try:
|
|
headers = {}
|
|
if self.api_key:
|
|
headers["Authorization"] = f"Bearer {self.api_key}"
|
|
|
|
currently_downloaded_result = requests.get(
|
|
f"{self.api_host}/api/tags",
|
|
headers=headers,
|
|
)
|
|
currently_downloaded_result.raise_for_status()
|
|
|
|
except Exception as ex:
|
|
_LOGGER.debug("Connection error was: %s", repr(ex))
|
|
raise ConfigEntryNotReady("There was a problem connecting to the remote server") from ex
|
|
|
|
model_names = [ x["name"] for x in currently_downloaded_result.json()["models"] ]
|
|
if ":" in self.model_name:
|
|
if not any([ name == self.model_name for name in model_names]):
|
|
raise ConfigEntryNotReady(f"Ollama server does not have the provided model: {self.model_name}")
|
|
elif not any([ name.split(":")[0] == self.model_name for name in model_names ]):
|
|
raise ConfigEntryNotReady(f"Ollama server does not have the provided model: {self.model_name}")
|
|
|
|
    def _chat_completion_params(self, conversation: list[dict]) -> tuple[str, dict]:
|
|
request_params = {}
|
|
|
|
endpoint = "/api/chat"
|
|
request_params["messages"] = [ { "role": x["role"], "content": x["message"] } for x in conversation ]
|
|
|
|
return endpoint, request_params
|
|
|
|
    def _completion_params(self, conversation: list[dict]) -> tuple[str, dict]:
|
|
request_params = {}
|
|
|
|
endpoint = "/api/generate"
|
|
request_params["prompt"] = self._format_prompt(conversation)
|
|
request_params["raw"] = True # ignore prompt template
|
|
|
|
return endpoint, request_params
|
|
|
|
def _extract_response(self, response_json: dict) -> str:
|
|
if response_json["done"] not in ["true", True]:
|
|
_LOGGER.warning("Model response did not end on a stop token (unfinished sentence)")
|
|
|
|
# TODO: this doesn't work because ollama caches prompts and doesn't always return the full prompt length
|
|
# context_len = self.entry.options.get(CONF_CONTEXT_LENGTH, DEFAULT_CONTEXT_LENGTH)
|
|
# max_tokens = self.entry.options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS)
|
|
# if response_json["prompt_eval_count"] + max_tokens > context_len:
|
|
# self._warn_context_size()
|
|
|
|
if "response" in response_json:
|
|
return response_json["response"]
|
|
else:
|
|
return response_json["message"]["content"]
|
|
|
|
    def _generate(self, conversation: list[dict]) -> str:
|
|
context_length = self.entry.options.get(CONF_CONTEXT_LENGTH, DEFAULT_CONTEXT_LENGTH)
|
|
max_tokens = self.entry.options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS)
|
|
temperature = self.entry.options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE)
|
|
top_p = self.entry.options.get(CONF_TOP_P, DEFAULT_TOP_P)
|
|
top_k = self.entry.options.get(CONF_TOP_K, DEFAULT_TOP_K)
|
|
typical_p = self.entry.options.get(CONF_TYPICAL_P, DEFAULT_TYPICAL_P)
|
|
timeout = self.entry.options.get(CONF_REQUEST_TIMEOUT, DEFAULT_REQUEST_TIMEOUT)
|
|
keep_alive = self.entry.options.get(CONF_OLLAMA_KEEP_ALIVE_MIN, DEFAULT_OLLAMA_KEEP_ALIVE_MIN)
|
|
use_chat_api = self.entry.options.get(CONF_REMOTE_USE_CHAT_ENDPOINT, DEFAULT_REMOTE_USE_CHAT_ENDPOINT)
|
|
json_mode = self.entry.options.get(CONF_OLLAMA_JSON_MODE, DEFAULT_OLLAMA_JSON_MODE)
|
|
|
|
request_params = {
|
|
"model": self.model_name,
|
|
"stream": False,
|
|
"keep_alive": f"{keep_alive}m", # prevent ollama from unloading the model
|
|
"options": {
|
|
"num_ctx": context_length,
|
|
"top_p": top_p,
|
|
"top_k": top_k,
|
|
"typical_p": typical_p,
|
|
"temperature": temperature,
|
|
"num_predict": max_tokens,
|
|
}
|
|
}
|
|
|
|
if json_mode:
|
|
request_params["format"] = "json"
|
|
|
|
if use_chat_api:
|
|
endpoint, additional_params = self._chat_completion_params(conversation)
|
|
else:
|
|
endpoint, additional_params = self._completion_params(conversation)
|
|
|
|
request_params.update(additional_params)
|
|
|
|
headers = {}
|
|
if self.api_key:
|
|
headers["Authorization"] = f"Bearer {self.api_key}"
|
|
|
|
try:
|
|
result = requests.post(
|
|
f"{self.api_host}{endpoint}",
|
|
json=request_params,
|
|
timeout=timeout,
|
|
headers=headers,
|
|
)
|
|
|
|
result.raise_for_status()
|
|
        except requests.exceptions.Timeout:
            return "The generation request timed out! Please check your connection settings, increase the timeout in settings, or decrease the number of exposed entities."
        except requests.RequestException as err:
            _LOGGER.debug(f"Err was: {err}")
            _LOGGER.debug(f"Request was: {request_params}")
            # 'result' is unbound when the request itself fails, so read the body
            # from the exception's response (if any) instead
            _LOGGER.debug(f"Result was: {err.response.text if err.response is not None else '(no response)'}")
            return f"Failed to communicate with the API! {err}"
|
|
|
|
_LOGGER.debug(result.json())
|
|
|
|
return self._extract_response(result.json())
|